From 494f431c376a18c60ff13587e0c5d2363a2fb017 Mon Sep 17 00:00:00 2001 From: v2ray Date: Thu, 28 Jan 2016 23:58:23 +0100 Subject: [PATCH] remove ReadAllBytes in favor of io.ReadFull --- common/net/transport.go | 14 -------------- proxy/shadowsocks/protocol.go | 12 ++++++------ proxy/shadowsocks/shadowsocks.go | 3 ++- proxy/vmess/outbound/outbound.go | 7 ++++--- proxy/vmess/protocol/vmess.go | 14 +++++++------- 5 files changed, 19 insertions(+), 31 deletions(-) diff --git a/common/net/transport.go b/common/net/transport.go index cb581cd6f..daa0dd1d3 100644 --- a/common/net/transport.go +++ b/common/net/transport.go @@ -17,20 +17,6 @@ func ReadFrom(reader io.Reader, buffer *alloc.Buffer) (*alloc.Buffer, error) { return buffer, err } -// ReadAllBytes reads all bytes required from reader, if no error happens. -func ReadAllBytes(reader io.Reader, buffer []byte) (int, error) { - bytesRead := 0 - bytesAsked := len(buffer) - for bytesRead < bytesAsked { - nBytes, err := reader.Read(buffer[bytesRead:]) - bytesRead += nBytes - if err != nil { - return bytesRead, err - } - } - return bytesRead, nil -} - // ReaderToChan dumps all content from a given reader to a chan by constantly reading it until EOF. func ReaderToChan(stream chan<- *alloc.Buffer, reader io.Reader) error { allocate := alloc.NewBuffer diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index 8df98b4bf..067cc4f8f 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -24,7 +24,7 @@ func ReadRequest(reader io.Reader) (*Request, error) { buffer := alloc.NewSmallBuffer() defer buffer.Release() - _, err := v2net.ReadAllBytes(reader, buffer.Value[:1]) + _, err := io.ReadFull(reader, buffer.Value[:1]) if err != nil { log.Error("Shadowsocks: Failed to read address type: ", err) return nil, transport.CorruptedPacket @@ -35,27 +35,27 @@ func ReadRequest(reader io.Reader) (*Request, error) { addrType := buffer.Value[0] switch addrType { case AddrTypeIPv4: - _, err := v2net.ReadAllBytes(reader, buffer.Value[:4]) + _, err := io.ReadFull(reader, buffer.Value[:4]) if err != nil { log.Error("Shadowsocks: Failed to read IPv4 address: ", err) return nil, transport.CorruptedPacket } request.Address = v2net.IPAddress(buffer.Value[:4]) case AddrTypeIPv6: - _, err := v2net.ReadAllBytes(reader, buffer.Value[:16]) + _, err := io.ReadFull(reader, buffer.Value[:16]) if err != nil { log.Error("Shadowsocks: Failed to read IPv6 address: ", err) return nil, transport.CorruptedPacket } request.Address = v2net.IPAddress(buffer.Value[:16]) case AddrTypeDomain: - _, err := v2net.ReadAllBytes(reader, buffer.Value[:1]) + _, err := io.ReadFull(reader, buffer.Value[:1]) if err != nil { log.Error("Shadowsocks: Failed to read domain lenth: ", err) return nil, transport.CorruptedPacket } domainLength := int(buffer.Value[0]) - _, err = v2net.ReadAllBytes(reader, buffer.Value[:domainLength]) + _, err = io.ReadFull(reader, buffer.Value[:domainLength]) if err != nil { log.Error("Shadowsocks: Failed to read domain: ", err) return nil, transport.CorruptedPacket @@ -66,7 +66,7 @@ func ReadRequest(reader io.Reader) (*Request, error) { return nil, transport.CorruptedPacket } - _, err = v2net.ReadAllBytes(reader, buffer.Value[:2]) + _, err = io.ReadFull(reader, buffer.Value[:2]) if err != nil { log.Error("Shadowsocks: Failed to read port: ", err) return nil, transport.CorruptedPacket diff --git a/proxy/shadowsocks/shadowsocks.go b/proxy/shadowsocks/shadowsocks.go index 57b2c08b5..4c0332120 100644 --- a/proxy/shadowsocks/shadowsocks.go +++ b/proxy/shadowsocks/shadowsocks.go @@ -4,6 +4,7 @@ package shadowsocks import ( "crypto/rand" + "io" "sync" "github.com/v2ray/v2ray-core/app" @@ -127,7 +128,7 @@ func (this *Shadowsocks) handleConnection(conn *hub.TCPConn) { buffer := alloc.NewSmallBuffer() defer buffer.Release() - _, err := v2net.ReadAllBytes(conn, buffer.Value[:this.config.Cipher.IVSize()]) + _, err := io.ReadFull(conn, buffer.Value[:this.config.Cipher.IVSize()]) if err != nil { log.Error("Shadowsocks: Failed to read IV: ", err) return diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index 74b5fd9a4..da96bfd83 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -3,6 +3,7 @@ package outbound import ( "crypto/md5" "crypto/rand" + "io" "net" "sync" "time" @@ -39,8 +40,8 @@ func (this *VMessOutboundHandler) Dispatch(firstPacket v2net.Packet, ray ray.Out } buffer := alloc.NewSmallBuffer() - defer buffer.Release() // Buffer is released after communication finishes. - v2net.ReadAllBytes(rand.Reader, buffer.Value[:33]) // 16 + 16 + 1 + defer buffer.Release() // Buffer is released after communication finishes. + io.ReadFull(rand.Reader, buffer.Value[:33]) // 16 + 16 + 1 request.RequestIV = buffer.Value[:16] request.RequestKey = buffer.Value[16:32] request.ResponseHeader = buffer.Value[32] @@ -170,7 +171,7 @@ func (this *VMessOutboundHandler) handleResponse(conn net.Conn, request *protoco dataLen := int(buffer.Value[3]) if buffer.Len() < dataLen+4 { // Rare case diffBuffer := make([]byte, dataLen+4-buffer.Len()) - v2net.ReadAllBytes(decryptResponseReader, diffBuffer) + io.ReadFull(decryptResponseReader, diffBuffer) buffer.Append(diffBuffer) } command := buffer.Value[2] diff --git a/proxy/vmess/protocol/vmess.go b/proxy/vmess/protocol/vmess.go index fc613f661..6f608d877 100644 --- a/proxy/vmess/protocol/vmess.go +++ b/proxy/vmess/protocol/vmess.go @@ -69,7 +69,7 @@ func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { buffer := alloc.NewSmallBuffer() defer buffer.Release() - nBytes, err := v2net.ReadAllBytes(reader, buffer.Value[:vmess.IDBytesLen]) + nBytes, err := io.ReadFull(reader, buffer.Value[:vmess.IDBytesLen]) if err != nil { log.Debug("VMess: Failed to read request ID (", nBytes, " bytes): ", err) return nil, err @@ -91,7 +91,7 @@ func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { decryptor := v2crypto.NewCryptionReader(aesStream, reader) - nBytes, err = v2net.ReadAllBytes(decryptor, buffer.Value[:41]) + nBytes, err = io.ReadFull(decryptor, buffer.Value[:41]) if err != nil { log.Debug("VMess: Failed to read request header (", nBytes, " bytes): ", err) return nil, err @@ -117,7 +117,7 @@ func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { switch buffer.Value[40] { case addrTypeIPv4: - nBytes, err = v2net.ReadAllBytes(decryptor, buffer.Value[41:45]) // 4 bytes + nBytes, err = io.ReadFull(decryptor, buffer.Value[41:45]) // 4 bytes bufferLen += 4 if err != nil { log.Debug("VMess: Failed to read target IPv4 (", nBytes, " bytes): ", err) @@ -125,7 +125,7 @@ func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { } request.Address = v2net.IPAddress(buffer.Value[41:45]) case addrTypeIPv6: - nBytes, err = v2net.ReadAllBytes(decryptor, buffer.Value[41:57]) // 16 bytes + nBytes, err = io.ReadFull(decryptor, buffer.Value[41:57]) // 16 bytes bufferLen += 16 if err != nil { log.Debug("VMess: Failed to read target IPv6 (", nBytes, " bytes): ", nBytes, err) @@ -133,7 +133,7 @@ func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { } request.Address = v2net.IPAddress(buffer.Value[41:57]) case addrTypeDomain: - nBytes, err = v2net.ReadAllBytes(decryptor, buffer.Value[41:42]) + nBytes, err = io.ReadFull(decryptor, buffer.Value[41:42]) if err != nil { log.Debug("VMess: Failed to read target domain (", nBytes, " bytes): ", nBytes, err) return nil, err @@ -142,7 +142,7 @@ func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { if domainLength == 0 { return nil, transport.CorruptedPacket } - nBytes, err = v2net.ReadAllBytes(decryptor, buffer.Value[42:42+domainLength]) + nBytes, err = io.ReadFull(decryptor, buffer.Value[42:42+domainLength]) if err != nil { log.Debug("VMess: Failed to read target domain (", nBytes, " bytes): ", nBytes, err) return nil, err @@ -152,7 +152,7 @@ func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { request.Address = v2net.DomainAddress(string(domainBytes)) } - nBytes, err = v2net.ReadAllBytes(decryptor, buffer.Value[bufferLen:bufferLen+4]) + nBytes, err = io.ReadFull(decryptor, buffer.Value[bufferLen:bufferLen+4]) if err != nil { log.Debug("VMess: Failed to read checksum (", nBytes, " bytes): ", nBytes, err) return nil, err