1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-04 16:37:12 -05:00

remove closure on ReadFullFrom

This commit is contained in:
Darien Raymond 2018-11-02 15:01:33 +01:00
parent 9360448c59
commit 58e2ed3381
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
27 changed files with 158 additions and 90 deletions

View File

@ -6,15 +6,15 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"v2ray.com/core/common/session"
"v2ray.com/core/features/routing"
"github.com/miekg/dns" "github.com/miekg/dns"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/session"
"v2ray.com/core/common/signal/pubsub" "v2ray.com/core/common/signal/pubsub"
"v2ray.com/core/common/task" "v2ray.com/core/common/task"
"v2ray.com/core/features/routing"
"v2ray.com/core/transport/internet/udp" "v2ray.com/core/transport/internet/udp"
) )

View File

@ -167,6 +167,23 @@ func (b *Buffer) Read(data []byte) (int, error) {
return nBytes, nil return nBytes, nil
} }
// ReadFrom implements io.ReaderFrom.
func (b *Buffer) ReadFrom(reader io.Reader) (int64, error) {
n, err := reader.Read(b.v[b.end:])
b.end += int32(n)
return int64(n), err
}
func (b *Buffer) ReadFullFrom(reader io.Reader, size int32) (int64, error) {
end := b.end + size
if end > int32(len(b.v)) {
return 0, newError("out of bound: ", end)
}
n, err := io.ReadFull(reader, b.v[b.end:end])
b.end += int32(n)
return int64(n), err
}
// String returns the string form of this Buffer. // String returns the string form of this Buffer.
func (b *Buffer) String() string { func (b *Buffer) String() string {
return string(b.Bytes()) return string(b.Bytes())

View File

@ -1,12 +1,13 @@
package buf_test package buf_test
import ( import (
"bytes"
"crypto/rand"
"testing" "testing"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/compare"
. "v2ray.com/core/common/buf" . "v2ray.com/core/common/buf"
"v2ray.com/core/common/compare"
"v2ray.com/core/common/serial" "v2ray.com/core/common/serial"
. "v2ray.com/ext/assert" . "v2ray.com/ext/assert"
) )
@ -73,6 +74,23 @@ func TestBufferSlice(t *testing.T) {
} }
} }
func TestBufferReadFullFrom(t *testing.T) {
payload := make([]byte, 1024)
common.Must2(rand.Read(payload))
reader := bytes.NewReader(payload)
b := New()
n, err := b.ReadFullFrom(reader, 1024)
common.Must(err)
if n != 1024 {
t.Error("expect reading 1024 bytes, but actually ", n)
}
if err := compare.BytesEqualWithDetail(payload, b.Bytes()); err != nil {
t.Error(err)
}
}
func BenchmarkNewBuffer(b *testing.B) { func BenchmarkNewBuffer(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
buffer := New() buffer := New()

View File

@ -26,20 +26,6 @@ type Writer interface {
WriteMultiBuffer(MultiBuffer) error WriteMultiBuffer(MultiBuffer) error
} }
// ReadFrom creates a Supplier to read from a given io.Reader.
func ReadFrom(reader io.Reader) Supplier {
return func(b []byte) (int, error) {
return reader.Read(b)
}
}
// ReadFullFrom creates a Supplier to read full buffer from a given io.Reader.
func ReadFullFrom(reader io.Reader, size int32) Supplier {
return func(b []byte) (int, error) {
return io.ReadFull(reader, b[:size])
}
}
// WriteAllBytes ensures all bytes are written into the given writer. // WriteAllBytes ensures all bytes are written into the given writer.
func WriteAllBytes(writer io.Writer, payload []byte) error { func WriteAllBytes(writer io.Writer, payload []byte) error {
for len(payload) > 0 { for len(payload) > 0 {

View File

@ -79,7 +79,7 @@ func (mb *MultiBuffer) ReadFrom(reader io.Reader) (int64, error) {
for { for {
b := New() b := New()
err := b.Reset(ReadFullFrom(reader, Size)) _, err := b.ReadFullFrom(reader, Size)
if b.IsEmpty() { if b.IsEmpty() {
b.Release() b.Release()
} else { } else {
@ -220,7 +220,7 @@ func (mb *MultiBuffer) SliceBySize(size int32) MultiBuffer {
*mb = (*mb)[endIndex:] *mb = (*mb)[endIndex:]
if endIndex == 0 && len(*mb) > 0 { if endIndex == 0 && len(*mb) > 0 {
b := New() b := New()
common.Must(b.Reset(ReadFullFrom((*mb)[0], size))) common.Must2(b.ReadFullFrom((*mb)[0], size))
return NewMultiBufferValue(b) return NewMultiBufferValue(b)
} }
return slice return slice

View File

@ -10,7 +10,7 @@ import (
func readOne(r io.Reader) (*Buffer, error) { func readOne(r io.Reader) (*Buffer, error) {
b := New() b := New()
for i := 0; i < 64; i++ { for i := 0; i < 64; i++ {
err := b.Reset(ReadFrom(r)) _, err := b.ReadFrom(r)
if !b.IsEmpty() { if !b.IsEmpty() {
return b, nil return b, nil
} }

View File

@ -140,7 +140,7 @@ func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error {
if w.buffer == nil { if w.buffer == nil {
w.buffer = New() w.buffer = New()
} }
if err := w.buffer.AppendSupplier(ReadFrom(&b)); err != nil { if _, err := w.buffer.ReadFrom(&b); err != nil {
return err return err
} }
if w.buffer.IsFull() { if w.buffer.IsFull() {
@ -248,7 +248,8 @@ func (noOpWriter) ReadFrom(reader io.Reader) (int64, error) {
totalBytes := int64(0) totalBytes := int64(0)
for { for {
err := b.Reset(ReadFrom(reader)) b.Clear()
_, err := b.ReadFrom(reader)
totalBytes += int64(b.Len()) totalBytes += int64(b.Len())
if err != nil { if err != nil {
if errors.Cause(err) == io.EOF { if errors.Cause(err) == io.EOF {

View File

@ -17,7 +17,7 @@ func TestWriter(t *testing.T) {
assert := With(t) assert := With(t)
lb := New() lb := New()
assert(lb.AppendSupplier(ReadFrom(rand.Reader)), IsNil) common.Must2(lb.ReadFrom(rand.Reader))
expectedBytes := append([]byte(nil), lb.Bytes()...) expectedBytes := append([]byte(nil), lb.Bytes()...)
@ -54,7 +54,7 @@ func TestDiscardBytes(t *testing.T) {
assert := With(t) assert := With(t)
b := New() b := New()
common.Must(b.Reset(ReadFullFrom(rand.Reader, Size))) common.Must2(b.ReadFullFrom(rand.Reader, Size))
nBytes, err := io.Copy(DiscardBytes, b) nBytes, err := io.Copy(DiscardBytes, b)
assert(nBytes, Equals, int64(Size)) assert(nBytes, Equals, int64(Size))

View File

@ -132,7 +132,7 @@ var errSoft = newError("waiting for more data")
func (r *AuthenticationReader) readBuffer(size int32, padding int32) (*buf.Buffer, error) { func (r *AuthenticationReader) readBuffer(size int32, padding int32) (*buf.Buffer, error) {
b := buf.New() b := buf.New()
if err := b.Reset(buf.ReadFullFrom(r.reader, size)); err != nil { if _, err := b.ReadFullFrom(r.reader, size); err != nil {
b.Release() b.Release()
return nil, err return nil, err
} }
@ -270,7 +270,7 @@ func (w *AuthenticationWriter) seal(b *buf.Buffer) (*buf.Buffer, error) {
} }
if paddingSize > 0 { if paddingSize > 0 {
// With size of the chunk and padding length encrypted, the content of padding doesn't matter much. // With size of the chunk and padding length encrypted, the content of padding doesn't matter much.
common.Must(eb.AppendSupplier(buf.ReadFullFrom(w.randReader, int32(paddingSize)))) common.Must2(eb.ReadFullFrom(w.randReader, int32(paddingSize)))
} }
return eb, nil return eb, nil
@ -289,9 +289,7 @@ func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error {
for { for {
b := buf.New() b := buf.New()
common.Must(b.Reset(func(bb []byte) (int, error) { common.Must2(b.ReadFrom(io.LimitReader(&mb, int64(payloadSize))))
return mb.Read(bb[:payloadSize])
}))
eb, err := w.seal(b) eb, err := w.seal(b)
b.Release() b.Release()

View File

@ -1,6 +1,7 @@
package mux package mux
import ( import (
"encoding/binary"
"io" "io"
"v2ray.com/core/common" "v2ray.com/core/common"
@ -9,6 +10,7 @@ import (
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial" "v2ray.com/core/common/serial"
"v2ray.com/core/common/vio"
) )
type SessionStatus byte type SessionStatus byte
@ -60,11 +62,11 @@ type FrameMetadata struct {
} }
func (f FrameMetadata) WriteTo(b *buf.Buffer) error { func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
lenBytes := b.Bytes()
common.Must2(b.WriteBytes(0x00, 0x00)) common.Must2(b.WriteBytes(0x00, 0x00))
lenBytes := b.Bytes()
len0 := b.Len() len0 := b.Len()
if err := b.AppendSupplier(serial.WriteUint16(f.SessionID)); err != nil { if _, err := vio.WriteUint16(b, f.SessionID); err != nil {
return err return err
} }
@ -84,7 +86,7 @@ func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
} }
len1 := b.Len() len1 := b.Len()
serial.Uint16ToBytes(uint16(len1-len0), lenBytes) binary.BigEndian.PutUint16(lenBytes, uint16(len1-len0))
return nil return nil
} }
@ -101,7 +103,7 @@ func (f *FrameMetadata) Unmarshal(reader io.Reader) error {
b := buf.New() b := buf.New()
defer b.Release() defer b.Release()
if err := b.Reset(buf.ReadFullFrom(reader, int32(metaLen))); err != nil { if _, err := b.ReadFullFrom(reader, int32(metaLen)); err != nil {
return err return err
} }
return f.UnmarshalFromBuffer(b) return f.UnmarshalFromBuffer(b)

View File

@ -38,7 +38,7 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
} }
b := buf.New() b := buf.New()
if err := b.Reset(buf.ReadFullFrom(r.reader, int32(size))); err != nil { if _, err := b.ReadFullFrom(r.reader, int32(size)); err != nil {
b.Release() b.Release()
return nil, err return nil, err
} }

View File

@ -5,7 +5,7 @@ import (
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial" "v2ray.com/core/common/vio"
) )
type Writer struct { type Writer struct {
@ -66,7 +66,7 @@ func writeMetaWithFrame(writer buf.Writer, meta FrameMetadata, data buf.MultiBuf
if err := meta.WriteTo(frame); err != nil { if err := meta.WriteTo(frame); err != nil {
return err return err
} }
if err := frame.AppendSupplier(serial.WriteUint16(uint16(data.Len()))); err != nil { if _, err := vio.WriteUint16(frame, uint16(data.Len())); err != nil {
return err return err
} }

View File

@ -53,7 +53,7 @@ func NewAddressParser(options ...AddressOption) *AddressParser {
} }
func (p *AddressParser) readPort(b *buf.Buffer, reader io.Reader) (net.Port, error) { func (p *AddressParser) readPort(b *buf.Buffer, reader io.Reader) (net.Port, error) {
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil { if _, err := b.ReadFullFrom(reader, 2); err != nil {
return 0, err return 0, err
} }
return net.PortFromBytes(b.BytesFrom(-2)), nil return net.PortFromBytes(b.BytesFrom(-2)), nil
@ -73,7 +73,7 @@ func isValidDomain(d string) bool {
} }
func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Address, error) { func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Address, error) {
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil { if _, err := b.ReadFullFrom(reader, 1); err != nil {
return nil, err return nil, err
} }
@ -89,21 +89,21 @@ func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Addres
switch addrFamily { switch addrFamily {
case net.AddressFamilyIPv4: case net.AddressFamilyIPv4:
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil { if _, err := b.ReadFullFrom(reader, 4); err != nil {
return nil, err return nil, err
} }
return net.IPAddress(b.BytesFrom(-4)), nil return net.IPAddress(b.BytesFrom(-4)), nil
case net.AddressFamilyIPv6: case net.AddressFamilyIPv6:
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil { if _, err := b.ReadFullFrom(reader, 16); err != nil {
return nil, err return nil, err
} }
return net.IPAddress(b.BytesFrom(-16)), nil return net.IPAddress(b.BytesFrom(-16)), nil
case net.AddressFamilyDomain: case net.AddressFamilyDomain:
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil { if _, err := b.ReadFullFrom(reader, 1); err != nil {
return nil, err return nil, err
} }
domainLength := int32(b.Byte(b.Len() - 1)) domainLength := int32(b.Byte(b.Len() - 1))
if err := b.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil { if _, err := b.ReadFullFrom(reader, domainLength); err != nil {
return nil, err return nil, err
} }
domain := string(b.BytesFrom(-domainLength)) domain := string(b.BytesFrom(-domainLength))

View File

@ -20,13 +20,6 @@ func ReadUint16(reader io.Reader) (uint16, error) {
return BytesToUint16(b[:]), nil return BytesToUint16(b[:]), nil
} }
func WriteUint16(value uint16) func([]byte) (int, error) {
return func(b []byte) (int, error) {
Uint16ToBytes(value, b[:0])
return 2, nil
}
}
func Uint32ToBytes(value uint32, b []byte) []byte { func Uint32ToBytes(value uint32, b []byte) []byte {
return append(b, byte(value>>24), byte(value>>16), byte(value>>8), byte(value)) return append(b, byte(value>>24), byte(value>>16), byte(value>>8), byte(value))
} }

18
common/vio/serial.go Normal file
View File

@ -0,0 +1,18 @@
package vio
import (
"encoding/binary"
"io"
)
func WriteUint32(writer io.Writer, value uint32) (int, error) {
var b [4]byte
binary.BigEndian.PutUint32(b[:], value)
return writer.Write(b[:])
}
func WriteUint16(writer io.Writer, value uint16) (int, error) {
var b [2]byte
binary.BigEndian.PutUint16(b[:], value)
return writer.Write(b[:])
}

24
common/vio/serial_test.go Normal file
View File

@ -0,0 +1,24 @@
package vio_test
import (
"testing"
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
"v2ray.com/core/common/compare"
"v2ray.com/core/common/vio"
)
func TestUint32Serial(t *testing.T) {
b := buf.New()
defer b.Release()
n, err := vio.WriteUint32(b, 10)
common.Must(err)
if n != 4 {
t.Error("expect 4 bytes writtng, but actually ", n)
}
if err := compare.BytesEqualWithDetail(b.Bytes(), []byte{0, 0, 0, 10}); err != nil {
t.Error(err)
}
}

View File

@ -36,7 +36,7 @@ func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.Requ
ivLen := account.Cipher.IVSize() ivLen := account.Cipher.IVSize()
var iv []byte var iv []byte
if ivLen > 0 { if ivLen > 0 {
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, ivLen)); err != nil { if _, err := buffer.ReadFullFrom(reader, ivLen); err != nil {
return nil, nil, newError("failed to read IV").Base(err) return nil, nil, newError("failed to read IV").Base(err)
} }
@ -85,7 +85,7 @@ func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.Requ
actualAuth := make([]byte, AuthSize) actualAuth := make([]byte, AuthSize)
authenticator.Authenticate(buffer.Bytes())(actualAuth) authenticator.Authenticate(buffer.Bytes())(actualAuth)
err := buffer.AppendSupplier(buf.ReadFullFrom(br, AuthSize)) _, err := buffer.ReadFullFrom(br, AuthSize)
if err != nil { if err != nil {
return nil, nil, newError("Failed to read OTA").Base(err) return nil, nil, newError("Failed to read OTA").Base(err)
} }
@ -196,7 +196,7 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff
buffer := buf.New() buffer := buf.New()
ivLen := account.Cipher.IVSize() ivLen := account.Cipher.IVSize()
if ivLen > 0 { if ivLen > 0 {
common.Must(buffer.Reset(buf.ReadFullFrom(rand.Reader, ivLen))) common.Must2(buffer.ReadFullFrom(rand.Reader, ivLen))
} }
iv := buffer.Bytes() iv := buffer.Bytes()
@ -287,7 +287,7 @@ type UDPReader struct {
func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
buffer := buf.New() buffer := buf.New()
err := buffer.AppendSupplier(buf.ReadFrom(v.Reader)) _, err := buffer.ReadFrom(v.Reader)
if err != nil { if err != nil {
buffer.Release() buffer.Release()
return nil, err return nil, err

View File

@ -7,7 +7,7 @@ import (
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial" "v2ray.com/core/common/vio"
) )
const ( const (
@ -49,7 +49,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
request := new(protocol.RequestHeader) request := new(protocol.RequestHeader)
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil { if _, err := buffer.ReadFullFrom(reader, 2); err != nil {
return nil, newError("insufficient header").Base(err) return nil, newError("insufficient header").Base(err)
} }
@ -60,7 +60,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
return nil, newError("socks 4 is not allowed when auth is required.") return nil, newError("socks 4 is not allowed when auth is required.")
} }
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 6)); err != nil { if _, err := buffer.ReadFullFrom(reader, 6); err != nil {
return nil, newError("insufficient header").Base(err) return nil, newError("insufficient header").Base(err)
} }
port := net.PortFromBytes(buffer.BytesRange(2, 4)) port := net.PortFromBytes(buffer.BytesRange(2, 4))
@ -94,7 +94,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
if version == socks5Version { if version == socks5Version {
nMethod := int32(buffer.Byte(1)) nMethod := int32(buffer.Byte(1))
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, nMethod)); err != nil { if _, err := buffer.ReadFullFrom(reader, nMethod); err != nil {
return nil, newError("failed to read auth methods").Base(err) return nil, newError("failed to read auth methods").Base(err)
} }
@ -127,7 +127,9 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
return nil, newError("failed to write auth response").Base(err) return nil, newError("failed to write auth response").Base(err)
} }
} }
if err := buffer.Reset(buf.ReadFullFrom(reader, 3)); err != nil {
buffer.Clear()
if _, err := buffer.ReadFullFrom(reader, 3); err != nil {
return nil, newError("failed to read request").Base(err) return nil, newError("failed to read request").Base(err)
} }
@ -185,21 +187,25 @@ func readUsernamePassword(reader io.Reader) (string, string, error) {
buffer := buf.New() buffer := buf.New()
defer buffer.Release() defer buffer.Release()
if err := buffer.Reset(buf.ReadFullFrom(reader, 2)); err != nil { if _, err := buffer.ReadFullFrom(reader, 2); err != nil {
return "", "", err return "", "", err
} }
nUsername := int32(buffer.Byte(1)) nUsername := int32(buffer.Byte(1))
if err := buffer.Reset(buf.ReadFullFrom(reader, nUsername)); err != nil { buffer.Clear()
if _, err := buffer.ReadFullFrom(reader, nUsername); err != nil {
return "", "", err return "", "", err
} }
username := buffer.String() username := buffer.String()
if err := buffer.Reset(buf.ReadFullFrom(reader, 1)); err != nil { buffer.Clear()
if _, err := buffer.ReadFullFrom(reader, 1); err != nil {
return "", "", err return "", "", err
} }
nPassword := int32(buffer.Byte(0)) nPassword := int32(buffer.Byte(0))
if err := buffer.Reset(buf.ReadFullFrom(reader, nPassword)); err != nil {
buffer.Clear()
if _, err := buffer.ReadFullFrom(reader, nPassword); err != nil {
return "", "", err return "", "", err
} }
password := buffer.String() password := buffer.String()
@ -254,7 +260,7 @@ func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, po
defer buffer.Release() defer buffer.Release()
common.Must2(buffer.WriteBytes(0x00, errCode)) common.Must2(buffer.WriteBytes(0x00, errCode))
common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value()))) common.Must2(vio.WriteUint16(buffer, port.Value()))
common.Must2(buffer.Write(address.IP())) common.Must2(buffer.Write(address.IP()))
return buf.WriteAllBytes(writer, buffer.Bytes()) return buf.WriteAllBytes(writer, buffer.Bytes())
} }
@ -305,7 +311,7 @@ func NewUDPReader(reader io.Reader) *UDPReader {
func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
b := buf.New() b := buf.New()
if err := b.AppendSupplier(buf.ReadFrom(r.reader)); err != nil { if _, err := b.ReadFrom(r.reader); err != nil {
return nil, err return nil, err
} }
if _, err := DecodeUDPPacket(b); err != nil { if _, err := DecodeUDPPacket(b); err != nil {
@ -362,7 +368,8 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
return nil, err return nil, err
} }
if err := b.Reset(buf.ReadFullFrom(reader, 2)); err != nil { b.Clear()
if _, err := b.ReadFullFrom(reader, 2); err != nil {
return nil, err return nil, err
} }
@ -374,7 +381,8 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
} }
if authByte == authPassword { if authByte == authPassword {
if err := b.Reset(buf.ReadFullFrom(reader, 2)); err != nil { b.Clear()
if _, err := b.ReadFullFrom(reader, 2); err != nil {
return nil, err return nil, err
} }
if b.Byte(1) != 0x00 { if b.Byte(1) != 0x00 {
@ -398,7 +406,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
} }
b.Clear() b.Clear()
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 3)); err != nil { if _, err := b.ReadFullFrom(reader, 3); err != nil {
return nil, err return nil, err
} }

View File

@ -80,7 +80,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
} }
if padingLen > 0 { if padingLen > 0 {
common.Must(buffer.AppendSupplier(buf.ReadFullFrom(rand.Reader, int32(padingLen)))) common.Must2(buffer.ReadFullFrom(rand.Reader, int32(padingLen)))
} }
{ {
@ -164,7 +164,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
buffer := buf.New() buffer := buf.New()
defer buffer.Release() defer buffer.Release()
if err := buffer.AppendSupplier(buf.ReadFullFrom(c.responseReader, 4)); err != nil { if _, err := buffer.ReadFullFrom(c.responseReader, 4); err != nil {
return nil, newError("failed to read response header").Base(err) return nil, newError("failed to read response header").Base(err)
} }
@ -180,7 +180,8 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
cmdID := buffer.Byte(2) cmdID := buffer.Byte(2)
dataLen := int32(buffer.Byte(3)) dataLen := int32(buffer.Byte(3))
if err := buffer.Reset(buf.ReadFullFrom(c.responseReader, dataLen)); err != nil { buffer.Clear()
if _, err := buffer.ReadFullFrom(c.responseReader, dataLen); err != nil {
return nil, newError("failed to read response command").Base(err) return nil, newError("failed to read response command").Base(err)
} }
command, err := UnmarshalCommand(cmdID, buffer.Bytes()) command, err := UnmarshalCommand(cmdID, buffer.Bytes())

View File

@ -125,7 +125,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
buffer := buf.New() buffer := buf.New()
defer buffer.Release() defer buffer.Release()
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, protocol.IDBytesLen)); err != nil { if _, err := buffer.ReadFullFrom(reader, protocol.IDBytesLen); err != nil {
return nil, newError("failed to read request header").Base(err) return nil, newError("failed to read request header").Base(err)
} }
@ -140,7 +140,8 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv[:]) aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv[:])
decryptor := crypto.NewCryptionReader(aesStream, reader) decryptor := crypto.NewCryptionReader(aesStream, reader)
if err := buffer.Reset(buf.ReadFullFrom(decryptor, 38)); err != nil { buffer.Clear()
if _, err := buffer.ReadFullFrom(decryptor, 38); err != nil {
return nil, newError("failed to read request header").Base(err) return nil, newError("failed to read request header").Base(err)
} }
@ -178,12 +179,12 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
} }
if padingLen > 0 { if padingLen > 0 {
if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, int32(padingLen))); err != nil { if _, err := buffer.ReadFullFrom(decryptor, int32(padingLen)); err != nil {
return nil, newError("failed to read padding").Base(err) return nil, newError("failed to read padding").Base(err)
} }
} }
if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 4)); err != nil { if _, err := buffer.ReadFullFrom(decryptor, 4); err != nil {
return nil, newError("failed to read checksum").Base(err) return nil, newError("failed to read checksum").Base(err)
} }

View File

@ -69,7 +69,7 @@ func (server *Server) handleConnection(conn net.Conn) {
for { for {
b := buf.New() b := buf.New()
if err := b.AppendSupplier(buf.ReadFrom(conn)); err != nil { if _, err := b.ReadFrom(conn); err != nil {
if err == io.EOF { if err == io.EOF {
return nil return nil
} }

View File

@ -28,7 +28,7 @@ func TestListen(t *testing.T) {
defer conn.Close() defer conn.Close()
b := buf.New() b := buf.New()
common.Must(b.Reset(buf.ReadFrom(conn))) common.Must2(b.ReadFrom(conn))
assert(b.String(), Equals, "Request") assert(b.String(), Equals, "Request")
common.Must2(conn.Write([]byte("Response"))) common.Must2(conn.Write([]byte("Response")))
@ -44,7 +44,7 @@ func TestListen(t *testing.T) {
assert(err, IsNil) assert(err, IsNil)
b := buf.New() b := buf.New()
common.Must(b.Reset(buf.ReadFrom(conn))) common.Must2(b.ReadFrom(conn))
assert(b.String(), Equals, "Response") assert(b.String(), Equals, "Response")
} }
@ -67,7 +67,7 @@ func TestListenAbstract(t *testing.T) {
defer conn.Close() defer conn.Close()
b := buf.New() b := buf.New()
common.Must(b.Reset(buf.ReadFrom(conn))) common.Must2(b.ReadFrom(conn))
assert(b.String(), Equals, "Request") assert(b.String(), Equals, "Request")
common.Must2(conn.Write([]byte("Response"))) common.Must2(conn.Write([]byte("Response")))
@ -83,7 +83,7 @@ func TestListenAbstract(t *testing.T) {
assert(err, IsNil) assert(err, IsNil)
b := buf.New() b := buf.New()
common.Must(b.Reset(buf.ReadFrom(conn))) common.Must2(b.ReadFrom(conn))
assert(b.String(), Equals, "Response") assert(b.String(), Equals, "Response")
} }

View File

@ -60,7 +60,7 @@ func (*HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
totalBytes := int32(0) totalBytes := int32(0)
endingDetected := false endingDetected := false
for totalBytes < maxHeaderLength { for totalBytes < maxHeaderLength {
err := buffer.AppendSupplier(buf.ReadFrom(reader)) _, err := buffer.ReadFrom(reader)
if err != nil { if err != nil {
buffer.Release() buffer.Release()
return nil, err return nil, err

View File

@ -39,7 +39,7 @@ func TestHTTPConnection(t *testing.T) {
defer b.Release() defer b.Release()
for { for {
if err := b.Reset(buf.ReadFrom(conn)); err != nil { if _, err := b.ReadFrom(conn); err != nil {
return return
} }
nBytes, err := conn.Write(b.Bytes()) nBytes, err := conn.Write(b.Bytes())
@ -76,13 +76,15 @@ func TestHTTPConnection(t *testing.T) {
assert(nBytes, Equals, N) assert(nBytes, Equals, N)
assert(err, IsNil) assert(err, IsNil)
assert(b2.Reset(buf.ReadFullFrom(conn, N)), IsNil) b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N))
assert(b2.Bytes(), Equals, b1) assert(b2.Bytes(), Equals, b1)
nBytes, err = conn.Write(b1) nBytes, err = conn.Write(b1)
assert(nBytes, Equals, N) assert(nBytes, Equals, N)
assert(err, IsNil) assert(err, IsNil)
assert(b2.Reset(buf.ReadFullFrom(conn, N)), IsNil) b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N))
assert(b2.Bytes(), Equals, b1) assert(b2.Bytes(), Equals, b1)
} }

View File

@ -23,7 +23,7 @@ func fetchInput(ctx context.Context, input io.Reader, reader PacketReader, conn
go func() { go func() {
for { for {
payload := buf.New() payload := buf.New()
if err := payload.Reset(buf.ReadFrom(input)); err != nil { if _, err := payload.ReadFrom(input); err != nil {
payload.Release() payload.Release()
close(cache) close(cache)
return return

View File

@ -2,6 +2,7 @@ package kcp
import ( import (
"container/list" "container/list"
"io"
"sync" "sync"
"v2ray.com/core/common" "v2ray.com/core/common"
@ -274,9 +275,7 @@ func (w *SendingWorker) Push(mb *buf.MultiBuffer) bool {
} }
b := buf.New() b := buf.New()
common.Must(b.Reset(func(v []byte) (int, error) { common.Must2(b.ReadFrom(io.LimitReader(mb, int64(w.conn.mss))))
return mb.Read(v[:w.conn.mss])
}))
w.window.Push(w.nextNumber, b) w.window.Push(w.nextNumber, b)
w.nextNumber++ w.nextNumber++
return true return true

View File

@ -40,7 +40,7 @@ func TestTCPFastOpen(t *testing.T) {
common.Must(err) common.Must(err)
b := buf.New() b := buf.New()
common.Must(b.Reset(buf.ReadFrom(conn))) common.Must2(b.ReadFrom(conn))
if err := compare.BytesEqualWithDetail(b.Bytes(), []byte("abcd")); err != nil { if err := compare.BytesEqualWithDetail(b.Bytes(), []byte("abcd")); err != nil {
t.Fatal(err) t.Fatal(err)
} }