diff --git a/io/aes.go b/io/aes.go index 8c0d4f357..f905324ed 100644 --- a/io/aes.go +++ b/io/aes.go @@ -6,7 +6,7 @@ import ( "io" ) -func NewAesDecryptReader(key []byte, iv []byte, reader io.Reader) (io.Reader, error) { +func NewAesDecryptReader(key []byte, iv []byte, reader io.Reader) (*CryptionReader, error) { aesBlock, err := aes.NewCipher(key) if err != nil { return nil, err @@ -16,7 +16,7 @@ func NewAesDecryptReader(key []byte, iv []byte, reader io.Reader) (io.Reader, er return NewCryptionReader(aesStream, reader), nil } -func NewAesEncryptWriter(key []byte, iv []byte, writer io.Writer) (io.Writer, error) { +func NewAesEncryptWriter(key []byte, iv []byte, writer io.Writer) (*CryptionWriter, error) { aesBlock, err := aes.NewCipher(key) if err != nil { return nil, err diff --git a/io/encryption.go b/io/encryption.go index c2fc40c87..100e9f24a 100644 --- a/io/encryption.go +++ b/io/encryption.go @@ -48,9 +48,13 @@ func NewCryptionWriter(stream cipher.Stream, writer io.Writer) *CryptionWriter { } } +func (writer CryptionWriter) Crypt(blocks []byte) { + writer.stream.XORKeyStream(blocks, blocks) +} + // Write writes the give blocks to underlying writer. The length of the blocks // must be a multiply of BlockSize() func (writer CryptionWriter) Write(blocks []byte) (int, error) { - writer.stream.XORKeyStream(blocks, blocks) + writer.Crypt(blocks) return writer.writer.Write(blocks) } diff --git a/io/vmess/vmess.go b/io/vmess/vmess.go index d319a4ef6..a7f7c6c7d 100644 --- a/io/vmess/vmess.go +++ b/io/vmess/vmess.go @@ -181,26 +181,14 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { return request, nil } -type VMessRequestWriter struct { - idHash v2hash.CounterHash - randomRangeInt64 v2math.RandomInt64InRange -} - -func NewVMessRequestWriter(idHash v2hash.CounterHash, randomRangeInt64 v2math.RandomInt64InRange) *VMessRequestWriter { - return &VMessRequestWriter{ - idHash: idHash, - randomRangeInt64: randomRangeInt64, - } -} - -func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) error { +func (request *VMessRequest) ToBytes(idHash v2hash.CounterHash, randomRangeInt64 v2math.RandomInt64InRange) ([]byte, error) { buffer := make([]byte, 0, 300) - counter := w.randomRangeInt64(time.Now().UTC().Unix(), 30) - idHash := w.idHash.Hash(request.UserId.Bytes, counter) + counter := randomRangeInt64(time.Now().UTC().Unix(), 30) + hash := idHash.Hash(request.UserId.Bytes, counter) - log.Debug("Writing userhash: %v", idHash) - buffer = append(buffer, idHash...) + log.Debug("Writing userhash: %v", hash) + buffer = append(buffer, hash...) encryptionBegin := len(buffer) @@ -208,7 +196,7 @@ func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) erro randomContent := make([]byte, randomLength) _, err := rand.Read(randomContent) if err != nil { - return err + return nil, err } buffer = append(buffer, byte(randomLength)) buffer = append(buffer, randomContent...) @@ -240,7 +228,7 @@ func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) erro paddingBuffer := make([]byte, paddingLength) _, err = rand.Read(paddingBuffer) if err != nil { - return err + return nil, err } buffer = append(buffer, byte(paddingLength)) buffer = append(buffer, paddingBuffer...) @@ -248,21 +236,12 @@ func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) erro aesCipher, err := aes.NewCipher(request.UserId.CmdKey()) if err != nil { - return err + return nil, err } aesStream := cipher.NewCFBEncrypter(aesCipher, v2hash.Int64Hash(counter)) - cWriter := v2io.NewCryptionWriter(aesStream, writer) + aesStream.XORKeyStream(buffer[encryptionBegin:encryptionEnd], buffer[encryptionBegin:encryptionEnd]) - _, err = writer.Write(buffer[0:encryptionBegin]) - if err != nil { - return err - } - _, err = cWriter.Write(buffer[encryptionBegin:encryptionEnd]) - if err != nil { - return err - } - - return nil + return buffer, nil } type VMessResponse [4]byte diff --git a/net/vmess/vmessin.go b/net/vmess/vmessin.go index 6c71ac637..ef9b9ceaa 100644 --- a/net/vmess/vmessin.go +++ b/net/vmess/vmessin.go @@ -77,6 +77,7 @@ func (handler *VMessInboundHandler) HandleConnection(connection net.Conn) error if err != nil { return log.Error("Failed to create encrypt writer: %v", err) } + //responseWriter.Write(response[:]) // Optimize for small response packet buffer := make([]byte, 0, 1024) @@ -87,7 +88,11 @@ func (handler *VMessInboundHandler) HandleConnection(connection net.Conn) error } responseWriter.Write(buffer) - go handleOutput(request, responseWriter, output, writeFinish) + if open { + go handleOutput(request, responseWriter, output, writeFinish) + } else { + close(writeFinish) + } <-writeFinish if tcpConn, ok := connection.(*net.TCPConn); ok { diff --git a/net/vmess/vmessout.go b/net/vmess/vmessout.go index 81e96e119..19552128f 100644 --- a/net/vmess/vmessout.go +++ b/net/vmess/vmessout.go @@ -98,19 +98,32 @@ func startCommunicate(request *vmessio.VMessRequest, dest v2net.Address, ray cor func handleRequest(conn *net.TCPConn, request *vmessio.VMessRequest, input <-chan []byte, finish chan<- bool) error { defer close(finish) - requestWriter := vmessio.NewVMessRequestWriter(v2hash.NewTimeHash(v2hash.HMACHash{}), v2math.GenerateRandomInt64InRange) - err := requestWriter.Write(conn, request) - if err != nil { - log.Error("Failed to write VMess request: %v", err) - return err - } - encryptRequestWriter, err := v2io.NewAesEncryptWriter(request.RequestKey[:], request.RequestIV[:], conn) if err != nil { log.Error("Failed to create encrypt writer: %v", err) return err } + buffer, err := request.ToBytes(v2hash.NewTimeHash(v2hash.HMACHash{}), v2math.GenerateRandomInt64InRange) + if err != nil { + log.Error("VMessOut: Failed to serialize VMess request: %v", err) + } + //conn.Write(buffer) + data, open := <-input + if open { + encryptRequestWriter.Crypt(data) + buffer = append(buffer, data...) + } + + _, err = conn.Write(buffer) + if err != nil { + log.Error("VMessOut: Failed to write VMess request: %v", err) + } + + if !open { + return nil + } + v2net.ChanToWriter(encryptRequestWriter, input) return nil }