From e7ce2c71667210a7ed4856108e3e70d4837c93cf Mon Sep 17 00:00:00 2001 From: V2Ray Date: Mon, 14 Sep 2015 13:51:30 +0200 Subject: [PATCH] Close connection chan when socks client finishes sending data, and fix a performance issue in VMess decoding. --- io/vmess/decryptionreader.go | 72 ------------------------------- io/vmess/decryptionreader_test.go | 62 -------------------------- io/vmess/vmess.go | 10 ++++- net/freedom/freedom.go | 22 +++++++--- net/socks/socks.go | 22 ++++------ net/vmess/vmessin.go | 24 +++++++---- net/vmess/vmessout.go | 19 +++++--- ray.go | 6 ++- 8 files changed, 66 insertions(+), 171 deletions(-) delete mode 100644 io/vmess/decryptionreader.go delete mode 100644 io/vmess/decryptionreader_test.go diff --git a/io/vmess/decryptionreader.go b/io/vmess/decryptionreader.go deleted file mode 100644 index 9fb2d49e8..000000000 --- a/io/vmess/decryptionreader.go +++ /dev/null @@ -1,72 +0,0 @@ -package vmess - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "fmt" - "io" - - v2io "github.com/v2ray/v2ray-core/io" -) - -const ( - blockSize = 16 // Decryption block size, inherited from AES -) - -// DecryptionReader is a byte stream reader to decrypt AES-128 CBC (for now) -// encrypted content. -type DecryptionReader struct { - reader *v2io.CryptionReader - buffer *bytes.Buffer -} - -// NewDecryptionReader creates a new DescriptionReader by given byte Reader and -// AES key. -func NewDecryptionReader(reader io.Reader, key []byte, iv []byte) (*DecryptionReader, error) { - decryptionReader := new(DecryptionReader) - aesCipher, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - aesStream := cipher.NewCFBDecrypter(aesCipher, iv) - decryptionReader.reader = v2io.NewCryptionReader(aesStream, reader) - decryptionReader.buffer = bytes.NewBuffer(make([]byte, 0, 2*blockSize)) - return decryptionReader, nil -} - -func (reader *DecryptionReader) readBlock() error { - buffer := make([]byte, blockSize) - nBytes, err := reader.reader.Read(buffer) - if err != nil && err != io.EOF { - return err - } - if nBytes < blockSize { - return fmt.Errorf("Expected to read %d bytes, but got %d bytes", blockSize, nBytes) - } - reader.buffer.Write(buffer) - return err -} - -// Read returns decrypted bytes of given length -func (reader *DecryptionReader) Read(p []byte) (int, error) { - nBytes, err := reader.buffer.Read(p) - if err != nil && err != io.EOF { - return nBytes, err - } - if nBytes < len(p) { - err = reader.readBlock() - if err != nil { - return nBytes, err - } - moreBytes, err := reader.buffer.Read(p[nBytes:]) - if err != nil { - return nBytes, err - } - nBytes += moreBytes - if nBytes != len(p) { - return nBytes, fmt.Errorf("Unable to read %d bytes", len(p)) - } - } - return nBytes, err -} diff --git a/io/vmess/decryptionreader_test.go b/io/vmess/decryptionreader_test.go deleted file mode 100644 index 264db0342..000000000 --- a/io/vmess/decryptionreader_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package vmess - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - mrand "math/rand" - "testing" - - "github.com/v2ray/v2ray-core/testing/unit" -) - -func randomBytes(p []byte, t *testing.T) { - assert := unit.Assert(t) - - nBytes, err := rand.Read(p) - assert.Error(err).IsNil() - assert.Int(nBytes).Named("# bytes of random buffer").Equals(len(p)) -} - -func TestNormalReading(t *testing.T) { - assert := unit.Assert(t) - - testSize := 256 - plaintext := make([]byte, testSize) - randomBytes(plaintext, t) - - keySize := 16 - key := make([]byte, keySize) - randomBytes(key, t) - iv := make([]byte, keySize) - randomBytes(iv, t) - - aesBlock, err := aes.NewCipher(key) - assert.Error(err).IsNil() - - aesStream := cipher.NewCFBEncrypter(aesBlock, iv) - - ciphertext := make([]byte, testSize) - aesStream.XORKeyStream(ciphertext, plaintext) - - ciphertextcopy := make([]byte, testSize) - copy(ciphertextcopy, ciphertext) - - reader, err := NewDecryptionReader(bytes.NewReader(ciphertextcopy), key, iv) - assert.Error(err).IsNil() - - readtext := make([]byte, testSize) - readSize := 0 - for readSize < testSize { - nBytes := mrand.Intn(16) + 1 - if nBytes > testSize-readSize { - nBytes = testSize - readSize - } - bytesRead, err := reader.Read(readtext[readSize : readSize+nBytes]) - assert.Error(err).IsNil() - assert.Int(bytesRead).Equals(nBytes) - readSize += nBytes - } - assert.Bytes(readtext).Named("Plaintext").Equals(plaintext) -} diff --git a/io/vmess/vmess.go b/io/vmess/vmess.go index 69158aa3b..4873c5f07 100644 --- a/io/vmess/vmess.go +++ b/io/vmess/vmess.go @@ -22,6 +22,8 @@ const ( addrTypeDomain = byte(0x02) Version = byte(0x01) + + blockSize = 16 ) var ( @@ -75,8 +77,14 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { return nil, ErrorInvalidUser } request.UserId = *userId + + aesCipher, err := aes.NewCipher(userId.Hash([]byte("PWD"))) + if err != nil { + return nil, err + } + aesStream := cipher.NewCFBDecrypter(aesCipher, emptyIV) + decryptor := v2io.NewCryptionReader(aesStream, reader) - decryptor, err := NewDecryptionReader(reader, userId.Hash([]byte("PWD")), emptyIV) if err != nil { return nil, err } diff --git a/net/freedom/freedom.go b/net/freedom/freedom.go index dd496764c..83c6e7fb3 100644 --- a/net/freedom/freedom.go +++ b/net/freedom/freedom.go @@ -27,26 +27,34 @@ func (vconn *FreedomConnection) Start(ray core.OutboundRay) error { } log.Debug("Sending outbound tcp: %s", vconn.dest.String()) - finish := make(chan bool, 2) - go vconn.DumpInput(conn, input, finish) - go vconn.DumpOutput(conn, output, finish) - go vconn.CloseConn(conn, finish) + readFinish := make(chan bool) + writeFinish := make(chan bool) + + go vconn.DumpInput(conn, input, writeFinish) + go vconn.DumpOutput(conn, output, readFinish) + go vconn.CloseConn(conn, readFinish, writeFinish) return nil } func (vconn *FreedomConnection) DumpInput(conn net.Conn, input <-chan []byte, finish chan<- bool) { v2net.ChanToWriter(conn, input) + log.Debug("Freedom closing input") finish <- true } func (vconn *FreedomConnection) DumpOutput(conn net.Conn, output chan<- []byte, finish chan<- bool) { v2net.ReaderToChan(output, conn) close(output) + log.Debug("Freedom closing output") finish <- true } -func (vconn *FreedomConnection) CloseConn(conn net.Conn, finish <-chan bool) { - <-finish - <-finish +func (vconn *FreedomConnection) CloseConn(conn net.Conn, readFinish <-chan bool, writeFinish <-chan bool) { + <-writeFinish + if tcpConn, ok := conn.(*net.TCPConn); ok { + log.Debug("Closing freedom write.") + tcpConn.CloseWrite(); + } + <-readFinish conn.Close() } diff --git a/net/socks/socks.go b/net/socks/socks.go index ee2b1a272..65b5ab493 100644 --- a/net/socks/socks.go +++ b/net/socks/socks.go @@ -76,11 +76,10 @@ func (server *SocksServer) HandleConnection(connection net.Conn) error { authResponse := socksio.NewAuthenticationResponse(socksio.AuthNoMatchingMethod) socksio.WriteAuthentication(connection, authResponse) - log.Info("Client doesn't support allowed any auth methods.") + log.Warning("Client doesn't support allowed any auth methods.") return ErrorAuthenticationFailed } - log.Debug("Auth accepted, responding auth.") authResponse := socksio.NewAuthenticationResponse(socksio.AuthNotRequired) socksio.WriteAuthentication(connection, authResponse) @@ -96,7 +95,7 @@ func (server *SocksServer) HandleConnection(connection net.Conn) error { response := socksio.NewSocks5Response() response.Error = socksio.ErrorCommandNotSupported socksio.WriteResponse(connection, response) - log.Info("Unsupported socks command %d", request.Command) + log.Warning("Unsupported socks command %d", request.Command) return ErrorCommandNotSupported } @@ -111,17 +110,17 @@ func (server *SocksServer) HandleConnection(connection net.Conn) error { case socksio.AddrTypeDomain: response.Domain = request.Domain } - log.Debug("Socks response port = %d", response.Port) socksio.WriteResponse(connection, response) ray := server.vPoint.NewInboundConnectionAccepted(request.Destination()) input := ray.InboundInput() output := ray.InboundOutput() - finish := make(chan bool, 2) + readFinish := make(chan bool) + writeFinish := make(chan bool) - go server.dumpInput(connection, input, finish) - go server.dumpOutput(connection, output, finish) - server.waitForFinish(finish) + go server.dumpInput(connection, input, readFinish) + go server.dumpOutput(connection, output, writeFinish) + <-writeFinish return nil } @@ -129,15 +128,12 @@ func (server *SocksServer) HandleConnection(connection net.Conn) error { func (server *SocksServer) dumpInput(conn net.Conn, input chan<- []byte, finish chan<- bool) { v2net.ReaderToChan(input, conn) close(input) + log.Debug("Socks input closed") finish <- true } func (server *SocksServer) dumpOutput(conn net.Conn, output <-chan []byte, finish chan<- bool) { v2net.ChanToWriter(conn, output) + log.Debug("Socks output closed") finish <- true } - -func (server *SocksServer) waitForFinish(finish <-chan bool) { - <-finish - <-finish -} diff --git a/net/vmess/vmessin.go b/net/vmess/vmessin.go index 580597b63..b24355eca 100644 --- a/net/vmess/vmessin.go +++ b/net/vmess/vmessin.go @@ -50,6 +50,7 @@ func (handler *VMessInboundHandler) AcceptConnections(listener net.Listener) err func (handler *VMessInboundHandler) HandleConnection(connection net.Conn) error { defer connection.Close() + reader := vmessio.NewVMessRequestReader(handler.clients) request, err := reader.Read(connection) @@ -60,7 +61,6 @@ func (handler *VMessInboundHandler) HandleConnection(connection net.Conn) error response := vmessio.NewVMessResponse(request) nBytes, err := connection.Write(response[:]) - log.Debug("Writing VMess response %v", response) if err != nil { return log.Error("Failed to write VMess response (%d bytes): %v", nBytes, err) } @@ -83,11 +83,19 @@ func (handler *VMessInboundHandler) HandleConnection(connection net.Conn) error ray := handler.vPoint.NewInboundConnectionAccepted(request.Address) input := ray.InboundInput() output := ray.InboundOutput() - finish := make(chan bool, 2) + + readFinish := make(chan bool) + writeFinish := make(chan bool) - go handler.dumpInput(requestReader, input, finish) - go handler.dumpOutput(responseWriter, output, finish) - handler.waitForFinish(finish) + go handler.dumpInput(requestReader, input, readFinish) + go handler.dumpOutput(responseWriter, output, writeFinish) + + <-writeFinish + if tcpConn, ok := connection.(*net.TCPConn); ok { + log.Debug("VMessIn closing write") + tcpConn.CloseWrite(); + } + <-readFinish return nil } @@ -95,18 +103,16 @@ func (handler *VMessInboundHandler) HandleConnection(connection net.Conn) error func (handler *VMessInboundHandler) dumpInput(reader io.Reader, input chan<- []byte, finish chan<- bool) { v2net.ReaderToChan(input, reader) close(input) + log.Debug("VMessIn closing input") finish <- true } func (handler *VMessInboundHandler) dumpOutput(writer io.Writer, output <-chan []byte, finish chan<- bool) { v2net.ChanToWriter(writer, output) + log.Debug("VMessOut closing output") finish <- true } -func (handler *VMessInboundHandler) waitForFinish(finish <-chan bool) { - <-finish - <-finish -} type VMessInboundHandlerFactory struct { } diff --git a/net/vmess/vmessout.go b/net/vmess/vmessout.go index ec77383f9..4a1b40e43 100644 --- a/net/vmess/vmessout.go +++ b/net/vmess/vmessout.go @@ -67,7 +67,7 @@ func (handler *VMessOutboundHandler) Start(ray core.OutboundRay) error { } func (handler *VMessOutboundHandler) startCommunicate(request *vmessio.VMessRequest, dest v2net.Address, ray core.OutboundRay) error { - conn, err := net.Dial("tcp", dest.String()) + conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{dest.IP, int(dest.Port), ""}) log.Debug("VMessOutbound dialing tcp: %s", dest.String()) if err != nil { log.Error("Failed to open tcp (%s): %v", dest.String(), err) @@ -109,22 +109,29 @@ func (handler *VMessOutboundHandler) startCommunicate(request *vmessio.VMessRequ input := ray.OutboundInput() output := ray.OutboundOutput() - finish := make(chan bool, 2) + readFinish := make(chan bool) + writeFinish := make(chan bool) - go handler.dumpInput(encryptRequestWriter, input, finish) - go handler.dumpOutput(decryptResponseReader, output, finish) - handler.waitForFinish(finish) - return nil + go handler.dumpInput(encryptRequestWriter, input, readFinish) + go handler.dumpOutput(decryptResponseReader, output, writeFinish) + + <-readFinish + conn.CloseWrite() + log.Debug("VMessOut closing write") + <-writeFinish + return nil } func (handler *VMessOutboundHandler) dumpOutput(reader io.Reader, output chan<- []byte, finish chan<- bool) { v2net.ReaderToChan(output, reader) close(output) + log.Debug("VMessOut closing output") finish <- true } func (handler *VMessOutboundHandler) dumpInput(writer io.Writer, input <-chan []byte, finish chan<- bool) { v2net.ChanToWriter(writer, input) + log.Debug("VMessOut closing input") finish <- true } diff --git a/ray.go b/ray.go index bcfa013af..96ac01416 100644 --- a/ray.go +++ b/ray.go @@ -1,12 +1,16 @@ package core +const ( + bufferSize = 16 +) + type Ray struct { Input chan []byte Output chan []byte } func NewRay() Ray { - return Ray{make(chan []byte, 128), make(chan []byte, 128)} + return Ray{make(chan []byte, bufferSize), make(chan []byte, bufferSize)} } type OutboundRay interface {