diff --git a/common/alloc/buffer.go b/common/alloc/buffer.go index 79ac0196e..33c5a3d57 100644 --- a/common/alloc/buffer.go +++ b/common/alloc/buffer.go @@ -220,7 +220,14 @@ func (b *Buffer) Read(data []byte) (int, error) { func (b *Buffer) FillFrom(reader io.Reader) (int, error) { begin := b.Len() nBytes, err := reader.Read(b.head[b.offset+begin:]) - b.Value = b.head[:b.offset+begin+nBytes] + b.Value = b.head[b.offset : b.offset+begin+nBytes] + return nBytes, err +} + +func (b *Buffer) FillFullFrom(reader io.Reader, amount int) (int, error) { + begin := b.Len() + nBytes, err := io.ReadFull(reader, b.head[b.offset+begin:b.offset+begin+amount]) + b.Value = b.head[b.offset : b.offset+begin+nBytes] return nBytes, err } diff --git a/proxy/shadowsocks/ota.go b/proxy/shadowsocks/ota.go index 15698a175..c6c3f367a 100644 --- a/proxy/shadowsocks/ota.go +++ b/proxy/shadowsocks/ota.go @@ -71,8 +71,8 @@ func (v *ChunkReader) Release() { } func (v *ChunkReader) Read() (*alloc.Buffer, error) { - buffer := alloc.NewBuffer() - if _, err := io.ReadFull(v.reader, buffer.BytesTo(2)); err != nil { + buffer := alloc.NewBuffer().Clear() + if _, err := buffer.FillFullFrom(v.reader, 2); err != nil { buffer.Release() return nil, err } @@ -84,11 +84,12 @@ func (v *ChunkReader) Read() (*alloc.Buffer, error) { buffer.Release() buffer = alloc.NewLocalBuffer(int(length) + 128) } - if _, err := io.ReadFull(v.reader, buffer.BytesTo(int(length))); err != nil { + + buffer.Clear() + if _, err := buffer.FillFullFrom(v.reader, int(length)); err != nil { buffer.Release() return nil, err } - buffer.Slice(0, int(length)) authBytes := buffer.BytesTo(AuthSize) payload := buffer.BytesFrom(AuthSize) diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index 733cca732..fa40cd107 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -28,11 +28,11 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea } account := rawAccount.(*ShadowsocksAccount) - buffer := alloc.NewLocalBuffer(512) + buffer := alloc.NewLocalBuffer(512).Clear() defer buffer.Release() ivLen := account.Cipher.IVSize() - _, err = io.ReadFull(reader, buffer.Bytes()[:ivLen]) + _, err = buffer.FillFullFrom(reader, ivLen) if err != nil { return nil, nil, errors.Base(err).Message("Shadowsocks|TCP: Failed to read IV.") } @@ -52,8 +52,8 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea Command: protocol.RequestCommandTCP, } - lenBuffer := 1 - _, err = io.ReadFull(reader, buffer.Bytes()[:1]) + buffer.Clear() + _, err = buffer.FillFullFrom(reader, 1) if err != nil { return nil, nil, errors.Base(err).Message("Shadowsocks|TCP: Failed to read address type.") } @@ -73,53 +73,47 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea switch addrType { case AddrTypeIPv4: - _, err := io.ReadFull(reader, buffer.BytesRange(lenBuffer, lenBuffer+4)) + _, err := buffer.FillFullFrom(reader, 4) if err != nil { return nil, nil, errors.Base(err).Message("Shadowsocks|TCP: Failed to read IPv4 address.") } - request.Address = v2net.IPAddress(buffer.BytesRange(lenBuffer, lenBuffer+4)) - lenBuffer += 4 + request.Address = v2net.IPAddress(buffer.BytesFrom(-4)) case AddrTypeIPv6: - _, err := io.ReadFull(reader, buffer.BytesRange(lenBuffer, lenBuffer+16)) + _, err := buffer.FillFullFrom(reader, 16) if err != nil { return nil, nil, errors.Base(err).Message("Shadowsocks|TCP: Failed to read IPv6 address.") } - request.Address = v2net.IPAddress(buffer.BytesRange(lenBuffer, lenBuffer+16)) - lenBuffer += 16 + request.Address = v2net.IPAddress(buffer.BytesFrom(-16)) case AddrTypeDomain: - _, err := io.ReadFull(reader, buffer.BytesRange(lenBuffer, lenBuffer+1)) + _, err := buffer.FillFullFrom(reader, 1) if err != nil { return nil, nil, errors.Base(err).Message("Shadowsocks|TCP: Failed to read domain lenth.") } - domainLength := int(buffer.Bytes()[lenBuffer]) - lenBuffer++ - _, err = io.ReadFull(reader, buffer.BytesRange(lenBuffer, lenBuffer+domainLength)) + domainLength := int(buffer.BytesFrom(-1)[0]) + _, err = buffer.FillFullFrom(reader, domainLength) if err != nil { return nil, nil, errors.Base(err).Message("Shadowsocks|TCP: Failed to read domain.") } - request.Address = v2net.DomainAddress(string(buffer.BytesRange(lenBuffer, lenBuffer+domainLength))) - lenBuffer += domainLength + request.Address = v2net.DomainAddress(string(buffer.BytesFrom(-domainLength))) default: return nil, nil, errors.New("Shadowsocks|TCP: Unknown address type: ", addrType) } - _, err = io.ReadFull(reader, buffer.BytesRange(lenBuffer, lenBuffer+2)) + _, err = buffer.FillFullFrom(reader, 2) if err != nil { return nil, nil, errors.Base(err).Message("Shadowsocks|TCP: Failed to read port.") } - - request.Port = v2net.PortFromBytes(buffer.BytesRange(lenBuffer, lenBuffer+2)) - lenBuffer += 2 + request.Port = v2net.PortFromBytes(buffer.BytesFrom(-2)) if request.Option.Has(RequestOptionOneTimeAuth) { - authBytes := buffer.BytesRange(lenBuffer, lenBuffer+AuthSize) - _, err = io.ReadFull(reader, authBytes) + actualAuth := authenticator.Authenticate(nil, buffer.Bytes()) + + _, err := buffer.FillFullFrom(reader, AuthSize) if err != nil { return nil, nil, errors.Base(err).Message("Shadowsocks|TCP: Failed to read OTA.") } - actualAuth := authenticator.Authenticate(nil, buffer.BytesTo(lenBuffer)) - if !bytes.Equal(actualAuth, authBytes) { + if !bytes.Equal(actualAuth, buffer.BytesFrom(-AuthSize)) { return nil, nil, errors.New("Shadowsocks|TCP: Invalid OTA") } } diff --git a/proxy/socks/protocol/socks.go b/proxy/socks/protocol/socks.go index ecd6fc6ce..b5af88463 100644 --- a/proxy/socks/protocol/socks.go +++ b/proxy/socks/protocol/socks.go @@ -185,10 +185,10 @@ type Socks5Request struct { } func ReadRequest(reader io.Reader) (request *Socks5Request, err error) { - buffer := alloc.NewLocalBuffer(512) + buffer := alloc.NewLocalBuffer(512).Clear() defer buffer.Release() - _, err = io.ReadFull(reader, buffer.Value[:4]) + _, err = buffer.FillFullFrom(reader, 4) if err != nil { return } @@ -206,17 +206,18 @@ func ReadRequest(reader io.Reader) (request *Socks5Request, err error) { return } case AddrTypeDomain: - _, err = io.ReadFull(reader, buffer.Value[0:1]) + buffer.Clear() + _, err = buffer.FillFullFrom(reader, 1) if err != nil { return } - domainLength := buffer.Value[0] - _, err = io.ReadFull(reader, buffer.Value[:domainLength]) + domainLength := int(buffer.Byte(0)) + _, err = buffer.FillFullFrom(reader, domainLength) if err != nil { return } - request.Domain = string(append([]byte(nil), buffer.Value[:domainLength]...)) + request.Domain = string(buffer.BytesFrom(-domainLength)) case AddrTypeIPv6: _, err = io.ReadFull(reader, request.IPv6[:]) if err != nil { @@ -227,12 +228,12 @@ func ReadRequest(reader io.Reader) (request *Socks5Request, err error) { return } - _, err = io.ReadFull(reader, buffer.Value[:2]) + _, err = buffer.FillFullFrom(reader, 2) if err != nil { return } - request.Port = v2net.PortFromBytes(buffer.Value[:2]) + request.Port = v2net.PortFromBytes(buffer.BytesFrom(-2)) return }