diff --git a/app/dns/udpns.go b/app/dns/udpns.go index 4ccb4b59d..e9b0ef4b7 100644 --- a/app/dns/udpns.go +++ b/app/dns/udpns.go @@ -6,15 +6,15 @@ import ( "sync/atomic" "time" - "v2ray.com/core/common/session" - "v2ray.com/core/features/routing" - "github.com/miekg/dns" + "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/net" + "v2ray.com/core/common/session" "v2ray.com/core/common/signal/pubsub" "v2ray.com/core/common/task" + "v2ray.com/core/features/routing" "v2ray.com/core/transport/internet/udp" ) diff --git a/common/buf/buffer.go b/common/buf/buffer.go index 44fe656c2..5b28e7d45 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -167,6 +167,23 @@ func (b *Buffer) Read(data []byte) (int, error) { return nBytes, nil } +// ReadFrom implements io.ReaderFrom. +func (b *Buffer) ReadFrom(reader io.Reader) (int64, error) { + n, err := reader.Read(b.v[b.end:]) + b.end += int32(n) + return int64(n), err +} + +func (b *Buffer) ReadFullFrom(reader io.Reader, size int32) (int64, error) { + end := b.end + size + if end > int32(len(b.v)) { + return 0, newError("out of bound: ", end) + } + n, err := io.ReadFull(reader, b.v[b.end:end]) + b.end += int32(n) + return int64(n), err +} + // String returns the string form of this Buffer. func (b *Buffer) String() string { return string(b.Bytes()) diff --git a/common/buf/buffer_test.go b/common/buf/buffer_test.go index a7a461d07..fa3973d70 100644 --- a/common/buf/buffer_test.go +++ b/common/buf/buffer_test.go @@ -1,12 +1,13 @@ package buf_test import ( + "bytes" + "crypto/rand" "testing" "v2ray.com/core/common" - "v2ray.com/core/common/compare" - . "v2ray.com/core/common/buf" + "v2ray.com/core/common/compare" "v2ray.com/core/common/serial" . "v2ray.com/ext/assert" ) @@ -73,6 +74,23 @@ func TestBufferSlice(t *testing.T) { } } +func TestBufferReadFullFrom(t *testing.T) { + payload := make([]byte, 1024) + common.Must2(rand.Read(payload)) + + reader := bytes.NewReader(payload) + b := New() + n, err := b.ReadFullFrom(reader, 1024) + common.Must(err) + if n != 1024 { + t.Error("expect reading 1024 bytes, but actually ", n) + } + + if err := compare.BytesEqualWithDetail(payload, b.Bytes()); err != nil { + t.Error(err) + } +} + func BenchmarkNewBuffer(b *testing.B) { for i := 0; i < b.N; i++ { buffer := New() diff --git a/common/buf/io.go b/common/buf/io.go index 5527529ae..0c58a53ff 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -26,20 +26,6 @@ type Writer interface { WriteMultiBuffer(MultiBuffer) error } -// ReadFrom creates a Supplier to read from a given io.Reader. -func ReadFrom(reader io.Reader) Supplier { - return func(b []byte) (int, error) { - return reader.Read(b) - } -} - -// ReadFullFrom creates a Supplier to read full buffer from a given io.Reader. -func ReadFullFrom(reader io.Reader, size int32) Supplier { - return func(b []byte) (int, error) { - return io.ReadFull(reader, b[:size]) - } -} - // WriteAllBytes ensures all bytes are written into the given writer. func WriteAllBytes(writer io.Writer, payload []byte) error { for len(payload) > 0 { diff --git a/common/buf/multi_buffer.go b/common/buf/multi_buffer.go index 2d42ae2ff..3f48cc5d7 100644 --- a/common/buf/multi_buffer.go +++ b/common/buf/multi_buffer.go @@ -79,7 +79,7 @@ func (mb *MultiBuffer) ReadFrom(reader io.Reader) (int64, error) { for { b := New() - err := b.Reset(ReadFullFrom(reader, Size)) + _, err := b.ReadFullFrom(reader, Size) if b.IsEmpty() { b.Release() } else { @@ -220,7 +220,7 @@ func (mb *MultiBuffer) SliceBySize(size int32) MultiBuffer { *mb = (*mb)[endIndex:] if endIndex == 0 && len(*mb) > 0 { b := New() - common.Must(b.Reset(ReadFullFrom((*mb)[0], size))) + common.Must2(b.ReadFullFrom((*mb)[0], size)) return NewMultiBufferValue(b) } return slice diff --git a/common/buf/reader.go b/common/buf/reader.go index 1d535b4a2..409c77a68 100644 --- a/common/buf/reader.go +++ b/common/buf/reader.go @@ -10,7 +10,7 @@ import ( func readOne(r io.Reader) (*Buffer, error) { b := New() for i := 0; i < 64; i++ { - err := b.Reset(ReadFrom(r)) + _, err := b.ReadFrom(r) if !b.IsEmpty() { return b, nil } diff --git a/common/buf/writer.go b/common/buf/writer.go index 80c15f841..650519d8a 100644 --- a/common/buf/writer.go +++ b/common/buf/writer.go @@ -140,7 +140,7 @@ func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error { if w.buffer == nil { w.buffer = New() } - if err := w.buffer.AppendSupplier(ReadFrom(&b)); err != nil { + if _, err := w.buffer.ReadFrom(&b); err != nil { return err } if w.buffer.IsFull() { @@ -248,7 +248,8 @@ func (noOpWriter) ReadFrom(reader io.Reader) (int64, error) { totalBytes := int64(0) for { - err := b.Reset(ReadFrom(reader)) + b.Clear() + _, err := b.ReadFrom(reader) totalBytes += int64(b.Len()) if err != nil { if errors.Cause(err) == io.EOF { diff --git a/common/buf/writer_test.go b/common/buf/writer_test.go index 7adcf610a..51cc2dfe3 100644 --- a/common/buf/writer_test.go +++ b/common/buf/writer_test.go @@ -17,7 +17,7 @@ func TestWriter(t *testing.T) { assert := With(t) lb := New() - assert(lb.AppendSupplier(ReadFrom(rand.Reader)), IsNil) + common.Must2(lb.ReadFrom(rand.Reader)) expectedBytes := append([]byte(nil), lb.Bytes()...) @@ -54,7 +54,7 @@ func TestDiscardBytes(t *testing.T) { assert := With(t) b := New() - common.Must(b.Reset(ReadFullFrom(rand.Reader, Size))) + common.Must2(b.ReadFullFrom(rand.Reader, Size)) nBytes, err := io.Copy(DiscardBytes, b) assert(nBytes, Equals, int64(Size)) diff --git a/common/crypto/auth.go b/common/crypto/auth.go index 59509cdeb..51722c112 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -132,7 +132,7 @@ var errSoft = newError("waiting for more data") func (r *AuthenticationReader) readBuffer(size int32, padding int32) (*buf.Buffer, error) { b := buf.New() - if err := b.Reset(buf.ReadFullFrom(r.reader, size)); err != nil { + if _, err := b.ReadFullFrom(r.reader, size); err != nil { b.Release() return nil, err } @@ -270,7 +270,7 @@ func (w *AuthenticationWriter) seal(b *buf.Buffer) (*buf.Buffer, error) { } if paddingSize > 0 { // With size of the chunk and padding length encrypted, the content of padding doesn't matter much. - common.Must(eb.AppendSupplier(buf.ReadFullFrom(w.randReader, int32(paddingSize)))) + common.Must2(eb.ReadFullFrom(w.randReader, int32(paddingSize))) } return eb, nil @@ -289,9 +289,7 @@ func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error { for { b := buf.New() - common.Must(b.Reset(func(bb []byte) (int, error) { - return mb.Read(bb[:payloadSize]) - })) + common.Must2(b.ReadFrom(io.LimitReader(&mb, int64(payloadSize)))) eb, err := w.seal(b) b.Release() diff --git a/common/mux/frame.go b/common/mux/frame.go index b2caa144f..5bfed2f43 100644 --- a/common/mux/frame.go +++ b/common/mux/frame.go @@ -1,6 +1,7 @@ package mux import ( + "encoding/binary" "io" "v2ray.com/core/common" @@ -9,6 +10,7 @@ import ( "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" "v2ray.com/core/common/serial" + "v2ray.com/core/common/vio" ) type SessionStatus byte @@ -60,11 +62,11 @@ type FrameMetadata struct { } func (f FrameMetadata) WriteTo(b *buf.Buffer) error { - lenBytes := b.Bytes() common.Must2(b.WriteBytes(0x00, 0x00)) + lenBytes := b.Bytes() len0 := b.Len() - if err := b.AppendSupplier(serial.WriteUint16(f.SessionID)); err != nil { + if _, err := vio.WriteUint16(b, f.SessionID); err != nil { return err } @@ -84,7 +86,7 @@ func (f FrameMetadata) WriteTo(b *buf.Buffer) error { } len1 := b.Len() - serial.Uint16ToBytes(uint16(len1-len0), lenBytes) + binary.BigEndian.PutUint16(lenBytes, uint16(len1-len0)) return nil } @@ -101,7 +103,7 @@ func (f *FrameMetadata) Unmarshal(reader io.Reader) error { b := buf.New() defer b.Release() - if err := b.Reset(buf.ReadFullFrom(reader, int32(metaLen))); err != nil { + if _, err := b.ReadFullFrom(reader, int32(metaLen)); err != nil { return err } return f.UnmarshalFromBuffer(b) diff --git a/common/mux/reader.go b/common/mux/reader.go index fcd243bff..d9ace3dbf 100644 --- a/common/mux/reader.go +++ b/common/mux/reader.go @@ -38,7 +38,7 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { } b := buf.New() - if err := b.Reset(buf.ReadFullFrom(r.reader, int32(size))); err != nil { + if _, err := b.ReadFullFrom(r.reader, int32(size)); err != nil { b.Release() return nil, err } diff --git a/common/mux/writer.go b/common/mux/writer.go index bc5c48c06..d879cb1fa 100644 --- a/common/mux/writer.go +++ b/common/mux/writer.go @@ -5,7 +5,7 @@ import ( "v2ray.com/core/common/buf" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" - "v2ray.com/core/common/serial" + "v2ray.com/core/common/vio" ) type Writer struct { @@ -66,7 +66,7 @@ func writeMetaWithFrame(writer buf.Writer, meta FrameMetadata, data buf.MultiBuf if err := meta.WriteTo(frame); err != nil { return err } - if err := frame.AppendSupplier(serial.WriteUint16(uint16(data.Len()))); err != nil { + if _, err := vio.WriteUint16(frame, uint16(data.Len())); err != nil { return err } diff --git a/common/protocol/address.go b/common/protocol/address.go index 9d391e8e3..e1cfb3aeb 100644 --- a/common/protocol/address.go +++ b/common/protocol/address.go @@ -53,7 +53,7 @@ func NewAddressParser(options ...AddressOption) *AddressParser { } func (p *AddressParser) readPort(b *buf.Buffer, reader io.Reader) (net.Port, error) { - if err := b.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil { + if _, err := b.ReadFullFrom(reader, 2); err != nil { return 0, err } return net.PortFromBytes(b.BytesFrom(-2)), nil @@ -73,7 +73,7 @@ func isValidDomain(d string) bool { } func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Address, error) { - if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil { + if _, err := b.ReadFullFrom(reader, 1); err != nil { return nil, err } @@ -89,21 +89,21 @@ func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Addres switch addrFamily { case net.AddressFamilyIPv4: - if err := b.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil { + if _, err := b.ReadFullFrom(reader, 4); err != nil { return nil, err } return net.IPAddress(b.BytesFrom(-4)), nil case net.AddressFamilyIPv6: - if err := b.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil { + if _, err := b.ReadFullFrom(reader, 16); err != nil { return nil, err } return net.IPAddress(b.BytesFrom(-16)), nil case net.AddressFamilyDomain: - if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil { + if _, err := b.ReadFullFrom(reader, 1); err != nil { return nil, err } domainLength := int32(b.Byte(b.Len() - 1)) - if err := b.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil { + if _, err := b.ReadFullFrom(reader, domainLength); err != nil { return nil, err } domain := string(b.BytesFrom(-domainLength)) diff --git a/common/serial/numbers.go b/common/serial/numbers.go index 3b15c38ff..a4473ddeb 100755 --- a/common/serial/numbers.go +++ b/common/serial/numbers.go @@ -20,13 +20,6 @@ func ReadUint16(reader io.Reader) (uint16, error) { return BytesToUint16(b[:]), nil } -func WriteUint16(value uint16) func([]byte) (int, error) { - return func(b []byte) (int, error) { - Uint16ToBytes(value, b[:0]) - return 2, nil - } -} - func Uint32ToBytes(value uint32, b []byte) []byte { return append(b, byte(value>>24), byte(value>>16), byte(value>>8), byte(value)) } diff --git a/common/vio/serial.go b/common/vio/serial.go new file mode 100644 index 000000000..7dcfa36a5 --- /dev/null +++ b/common/vio/serial.go @@ -0,0 +1,18 @@ +package vio + +import ( + "encoding/binary" + "io" +) + +func WriteUint32(writer io.Writer, value uint32) (int, error) { + var b [4]byte + binary.BigEndian.PutUint32(b[:], value) + return writer.Write(b[:]) +} + +func WriteUint16(writer io.Writer, value uint16) (int, error) { + var b [2]byte + binary.BigEndian.PutUint16(b[:], value) + return writer.Write(b[:]) +} diff --git a/common/vio/serial_test.go b/common/vio/serial_test.go new file mode 100644 index 000000000..ebbab130b --- /dev/null +++ b/common/vio/serial_test.go @@ -0,0 +1,24 @@ +package vio_test + +import ( + "testing" + + "v2ray.com/core/common" + "v2ray.com/core/common/buf" + "v2ray.com/core/common/compare" + "v2ray.com/core/common/vio" +) + +func TestUint32Serial(t *testing.T) { + b := buf.New() + defer b.Release() + + n, err := vio.WriteUint32(b, 10) + common.Must(err) + if n != 4 { + t.Error("expect 4 bytes writtng, but actually ", n) + } + if err := compare.BytesEqualWithDetail(b.Bytes(), []byte{0, 0, 0, 10}); err != nil { + t.Error(err) + } +} diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index a6f49a1c4..6aba4b496 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -36,7 +36,7 @@ func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.Requ ivLen := account.Cipher.IVSize() var iv []byte if ivLen > 0 { - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, ivLen)); err != nil { + if _, err := buffer.ReadFullFrom(reader, ivLen); err != nil { return nil, nil, newError("failed to read IV").Base(err) } @@ -85,7 +85,7 @@ func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.Requ actualAuth := make([]byte, AuthSize) authenticator.Authenticate(buffer.Bytes())(actualAuth) - err := buffer.AppendSupplier(buf.ReadFullFrom(br, AuthSize)) + _, err := buffer.ReadFullFrom(br, AuthSize) if err != nil { return nil, nil, newError("Failed to read OTA").Base(err) } @@ -196,7 +196,7 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff buffer := buf.New() ivLen := account.Cipher.IVSize() if ivLen > 0 { - common.Must(buffer.Reset(buf.ReadFullFrom(rand.Reader, ivLen))) + common.Must2(buffer.ReadFullFrom(rand.Reader, ivLen)) } iv := buffer.Bytes() @@ -287,7 +287,7 @@ type UDPReader struct { func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { buffer := buf.New() - err := buffer.AppendSupplier(buf.ReadFrom(v.Reader)) + _, err := buffer.ReadFrom(v.Reader) if err != nil { buffer.Release() return nil, err diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index 35d08d798..a7411e9ad 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -7,7 +7,7 @@ import ( "v2ray.com/core/common/buf" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" - "v2ray.com/core/common/serial" + "v2ray.com/core/common/vio" ) const ( @@ -49,7 +49,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol request := new(protocol.RequestHeader) - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil { + if _, err := buffer.ReadFullFrom(reader, 2); err != nil { return nil, newError("insufficient header").Base(err) } @@ -60,7 +60,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol return nil, newError("socks 4 is not allowed when auth is required.") } - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 6)); err != nil { + if _, err := buffer.ReadFullFrom(reader, 6); err != nil { return nil, newError("insufficient header").Base(err) } port := net.PortFromBytes(buffer.BytesRange(2, 4)) @@ -94,7 +94,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol if version == socks5Version { nMethod := int32(buffer.Byte(1)) - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, nMethod)); err != nil { + if _, err := buffer.ReadFullFrom(reader, nMethod); err != nil { return nil, newError("failed to read auth methods").Base(err) } @@ -127,7 +127,9 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol return nil, newError("failed to write auth response").Base(err) } } - if err := buffer.Reset(buf.ReadFullFrom(reader, 3)); err != nil { + + buffer.Clear() + if _, err := buffer.ReadFullFrom(reader, 3); err != nil { return nil, newError("failed to read request").Base(err) } @@ -185,21 +187,25 @@ func readUsernamePassword(reader io.Reader) (string, string, error) { buffer := buf.New() defer buffer.Release() - if err := buffer.Reset(buf.ReadFullFrom(reader, 2)); err != nil { + if _, err := buffer.ReadFullFrom(reader, 2); err != nil { return "", "", err } nUsername := int32(buffer.Byte(1)) - if err := buffer.Reset(buf.ReadFullFrom(reader, nUsername)); err != nil { + buffer.Clear() + if _, err := buffer.ReadFullFrom(reader, nUsername); err != nil { return "", "", err } username := buffer.String() - if err := buffer.Reset(buf.ReadFullFrom(reader, 1)); err != nil { + buffer.Clear() + if _, err := buffer.ReadFullFrom(reader, 1); err != nil { return "", "", err } nPassword := int32(buffer.Byte(0)) - if err := buffer.Reset(buf.ReadFullFrom(reader, nPassword)); err != nil { + + buffer.Clear() + if _, err := buffer.ReadFullFrom(reader, nPassword); err != nil { return "", "", err } password := buffer.String() @@ -254,7 +260,7 @@ func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, po defer buffer.Release() common.Must2(buffer.WriteBytes(0x00, errCode)) - common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value()))) + common.Must2(vio.WriteUint16(buffer, port.Value())) common.Must2(buffer.Write(address.IP())) return buf.WriteAllBytes(writer, buffer.Bytes()) } @@ -305,7 +311,7 @@ func NewUDPReader(reader io.Reader) *UDPReader { func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { b := buf.New() - if err := b.AppendSupplier(buf.ReadFrom(r.reader)); err != nil { + if _, err := b.ReadFrom(r.reader); err != nil { return nil, err } if _, err := DecodeUDPPacket(b); err != nil { @@ -362,7 +368,8 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i return nil, err } - if err := b.Reset(buf.ReadFullFrom(reader, 2)); err != nil { + b.Clear() + if _, err := b.ReadFullFrom(reader, 2); err != nil { return nil, err } @@ -374,7 +381,8 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i } if authByte == authPassword { - if err := b.Reset(buf.ReadFullFrom(reader, 2)); err != nil { + b.Clear() + if _, err := b.ReadFullFrom(reader, 2); err != nil { return nil, err } if b.Byte(1) != 0x00 { @@ -398,7 +406,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i } b.Clear() - if err := b.AppendSupplier(buf.ReadFullFrom(reader, 3)); err != nil { + if _, err := b.ReadFullFrom(reader, 3); err != nil { return nil, err } diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index ce318f582..7005d1752 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -80,7 +80,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ } if padingLen > 0 { - common.Must(buffer.AppendSupplier(buf.ReadFullFrom(rand.Reader, int32(padingLen)))) + common.Must2(buffer.ReadFullFrom(rand.Reader, int32(padingLen))) } { @@ -164,7 +164,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon buffer := buf.New() defer buffer.Release() - if err := buffer.AppendSupplier(buf.ReadFullFrom(c.responseReader, 4)); err != nil { + if _, err := buffer.ReadFullFrom(c.responseReader, 4); err != nil { return nil, newError("failed to read response header").Base(err) } @@ -180,7 +180,8 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon cmdID := buffer.Byte(2) dataLen := int32(buffer.Byte(3)) - if err := buffer.Reset(buf.ReadFullFrom(c.responseReader, dataLen)); err != nil { + buffer.Clear() + if _, err := buffer.ReadFullFrom(c.responseReader, dataLen); err != nil { return nil, newError("failed to read response command").Base(err) } command, err := UnmarshalCommand(cmdID, buffer.Bytes()) diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 5231f4cbf..c95673e0e 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -125,7 +125,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request buffer := buf.New() defer buffer.Release() - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, protocol.IDBytesLen)); err != nil { + if _, err := buffer.ReadFullFrom(reader, protocol.IDBytesLen); err != nil { return nil, newError("failed to read request header").Base(err) } @@ -140,7 +140,8 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv[:]) decryptor := crypto.NewCryptionReader(aesStream, reader) - if err := buffer.Reset(buf.ReadFullFrom(decryptor, 38)); err != nil { + buffer.Clear() + if _, err := buffer.ReadFullFrom(decryptor, 38); err != nil { return nil, newError("failed to read request header").Base(err) } @@ -178,12 +179,12 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request } if padingLen > 0 { - if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, int32(padingLen))); err != nil { + if _, err := buffer.ReadFullFrom(decryptor, int32(padingLen)); err != nil { return nil, newError("failed to read padding").Base(err) } } - if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 4)); err != nil { + if _, err := buffer.ReadFullFrom(decryptor, 4); err != nil { return nil, newError("failed to read checksum").Base(err) } diff --git a/testing/servers/tcp/tcp.go b/testing/servers/tcp/tcp.go index 6f49a34df..83b1ac564 100644 --- a/testing/servers/tcp/tcp.go +++ b/testing/servers/tcp/tcp.go @@ -69,7 +69,7 @@ func (server *Server) handleConnection(conn net.Conn) { for { b := buf.New() - if err := b.AppendSupplier(buf.ReadFrom(conn)); err != nil { + if _, err := b.ReadFrom(conn); err != nil { if err == io.EOF { return nil } diff --git a/transport/internet/domainsocket/listener_test.go b/transport/internet/domainsocket/listener_test.go index d5408c2ba..b6c64c066 100644 --- a/transport/internet/domainsocket/listener_test.go +++ b/transport/internet/domainsocket/listener_test.go @@ -28,7 +28,7 @@ func TestListen(t *testing.T) { defer conn.Close() b := buf.New() - common.Must(b.Reset(buf.ReadFrom(conn))) + common.Must2(b.ReadFrom(conn)) assert(b.String(), Equals, "Request") common.Must2(conn.Write([]byte("Response"))) @@ -44,7 +44,7 @@ func TestListen(t *testing.T) { assert(err, IsNil) b := buf.New() - common.Must(b.Reset(buf.ReadFrom(conn))) + common.Must2(b.ReadFrom(conn)) assert(b.String(), Equals, "Response") } @@ -67,7 +67,7 @@ func TestListenAbstract(t *testing.T) { defer conn.Close() b := buf.New() - common.Must(b.Reset(buf.ReadFrom(conn))) + common.Must2(b.ReadFrom(conn)) assert(b.String(), Equals, "Request") common.Must2(conn.Write([]byte("Response"))) @@ -83,7 +83,7 @@ func TestListenAbstract(t *testing.T) { assert(err, IsNil) b := buf.New() - common.Must(b.Reset(buf.ReadFrom(conn))) + common.Must2(b.ReadFrom(conn)) assert(b.String(), Equals, "Response") } diff --git a/transport/internet/headers/http/http.go b/transport/internet/headers/http/http.go index b1b28fda5..88d4ebf23 100644 --- a/transport/internet/headers/http/http.go +++ b/transport/internet/headers/http/http.go @@ -60,7 +60,7 @@ func (*HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) { totalBytes := int32(0) endingDetected := false for totalBytes < maxHeaderLength { - err := buffer.AppendSupplier(buf.ReadFrom(reader)) + _, err := buffer.ReadFrom(reader) if err != nil { buffer.Release() return nil, err diff --git a/transport/internet/http/http_test.go b/transport/internet/http/http_test.go index 6627c3bce..9d54b375f 100644 --- a/transport/internet/http/http_test.go +++ b/transport/internet/http/http_test.go @@ -39,7 +39,7 @@ func TestHTTPConnection(t *testing.T) { defer b.Release() for { - if err := b.Reset(buf.ReadFrom(conn)); err != nil { + if _, err := b.ReadFrom(conn); err != nil { return } nBytes, err := conn.Write(b.Bytes()) @@ -76,13 +76,15 @@ func TestHTTPConnection(t *testing.T) { assert(nBytes, Equals, N) assert(err, IsNil) - assert(b2.Reset(buf.ReadFullFrom(conn, N)), IsNil) + b2.Clear() + common.Must2(b2.ReadFullFrom(conn, N)) assert(b2.Bytes(), Equals, b1) nBytes, err = conn.Write(b1) assert(nBytes, Equals, N) assert(err, IsNil) - assert(b2.Reset(buf.ReadFullFrom(conn, N)), IsNil) + b2.Clear() + common.Must2(b2.ReadFullFrom(conn, N)) assert(b2.Bytes(), Equals, b1) } diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index ae8c36fa5..5d1547996 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -23,7 +23,7 @@ func fetchInput(ctx context.Context, input io.Reader, reader PacketReader, conn go func() { for { payload := buf.New() - if err := payload.Reset(buf.ReadFrom(input)); err != nil { + if _, err := payload.ReadFrom(input); err != nil { payload.Release() close(cache) return diff --git a/transport/internet/kcp/sending.go b/transport/internet/kcp/sending.go index ca1fb6bbb..c734ea364 100644 --- a/transport/internet/kcp/sending.go +++ b/transport/internet/kcp/sending.go @@ -2,6 +2,7 @@ package kcp import ( "container/list" + "io" "sync" "v2ray.com/core/common" @@ -274,9 +275,7 @@ func (w *SendingWorker) Push(mb *buf.MultiBuffer) bool { } b := buf.New() - common.Must(b.Reset(func(v []byte) (int, error) { - return mb.Read(v[:w.conn.mss]) - })) + common.Must2(b.ReadFrom(io.LimitReader(mb, int64(w.conn.mss)))) w.window.Push(w.nextNumber, b) w.nextNumber++ return true diff --git a/transport/internet/sockopt_test.go b/transport/internet/sockopt_test.go index 4e37c2dec..294327754 100644 --- a/transport/internet/sockopt_test.go +++ b/transport/internet/sockopt_test.go @@ -40,7 +40,7 @@ func TestTCPFastOpen(t *testing.T) { common.Must(err) b := buf.New() - common.Must(b.Reset(buf.ReadFrom(conn))) + common.Must2(b.ReadFrom(conn)) if err := compare.BytesEqualWithDetail(b.Bytes(), []byte("abcd")); err != nil { t.Fatal(err) }