mirror of
https://github.com/v2fly/v2ray-core.git
synced 2025-01-17 14:57:44 -05:00
remove closure on ReadFullFrom
This commit is contained in:
parent
9360448c59
commit
58e2ed3381
@ -6,15 +6,15 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"v2ray.com/core/common/session"
|
||||
"v2ray.com/core/features/routing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/session"
|
||||
"v2ray.com/core/common/signal/pubsub"
|
||||
"v2ray.com/core/common/task"
|
||||
"v2ray.com/core/features/routing"
|
||||
"v2ray.com/core/transport/internet/udp"
|
||||
)
|
||||
|
||||
|
@ -167,6 +167,23 @@ func (b *Buffer) Read(data []byte) (int, error) {
|
||||
return nBytes, nil
|
||||
}
|
||||
|
||||
// ReadFrom implements io.ReaderFrom.
|
||||
func (b *Buffer) ReadFrom(reader io.Reader) (int64, error) {
|
||||
n, err := reader.Read(b.v[b.end:])
|
||||
b.end += int32(n)
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
func (b *Buffer) ReadFullFrom(reader io.Reader, size int32) (int64, error) {
|
||||
end := b.end + size
|
||||
if end > int32(len(b.v)) {
|
||||
return 0, newError("out of bound: ", end)
|
||||
}
|
||||
n, err := io.ReadFull(reader, b.v[b.end:end])
|
||||
b.end += int32(n)
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
// String returns the string form of this Buffer.
|
||||
func (b *Buffer) String() string {
|
||||
return string(b.Bytes())
|
||||
|
@ -1,12 +1,13 @@
|
||||
package buf_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/compare"
|
||||
|
||||
. "v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/common/compare"
|
||||
"v2ray.com/core/common/serial"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
@ -73,6 +74,23 @@ func TestBufferSlice(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBufferReadFullFrom(t *testing.T) {
|
||||
payload := make([]byte, 1024)
|
||||
common.Must2(rand.Read(payload))
|
||||
|
||||
reader := bytes.NewReader(payload)
|
||||
b := New()
|
||||
n, err := b.ReadFullFrom(reader, 1024)
|
||||
common.Must(err)
|
||||
if n != 1024 {
|
||||
t.Error("expect reading 1024 bytes, but actually ", n)
|
||||
}
|
||||
|
||||
if err := compare.BytesEqualWithDetail(payload, b.Bytes()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewBuffer(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
buffer := New()
|
||||
|
@ -26,20 +26,6 @@ type Writer interface {
|
||||
WriteMultiBuffer(MultiBuffer) error
|
||||
}
|
||||
|
||||
// ReadFrom creates a Supplier to read from a given io.Reader.
|
||||
func ReadFrom(reader io.Reader) Supplier {
|
||||
return func(b []byte) (int, error) {
|
||||
return reader.Read(b)
|
||||
}
|
||||
}
|
||||
|
||||
// ReadFullFrom creates a Supplier to read full buffer from a given io.Reader.
|
||||
func ReadFullFrom(reader io.Reader, size int32) Supplier {
|
||||
return func(b []byte) (int, error) {
|
||||
return io.ReadFull(reader, b[:size])
|
||||
}
|
||||
}
|
||||
|
||||
// WriteAllBytes ensures all bytes are written into the given writer.
|
||||
func WriteAllBytes(writer io.Writer, payload []byte) error {
|
||||
for len(payload) > 0 {
|
||||
|
@ -79,7 +79,7 @@ func (mb *MultiBuffer) ReadFrom(reader io.Reader) (int64, error) {
|
||||
|
||||
for {
|
||||
b := New()
|
||||
err := b.Reset(ReadFullFrom(reader, Size))
|
||||
_, err := b.ReadFullFrom(reader, Size)
|
||||
if b.IsEmpty() {
|
||||
b.Release()
|
||||
} else {
|
||||
@ -220,7 +220,7 @@ func (mb *MultiBuffer) SliceBySize(size int32) MultiBuffer {
|
||||
*mb = (*mb)[endIndex:]
|
||||
if endIndex == 0 && len(*mb) > 0 {
|
||||
b := New()
|
||||
common.Must(b.Reset(ReadFullFrom((*mb)[0], size)))
|
||||
common.Must2(b.ReadFullFrom((*mb)[0], size))
|
||||
return NewMultiBufferValue(b)
|
||||
}
|
||||
return slice
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
func readOne(r io.Reader) (*Buffer, error) {
|
||||
b := New()
|
||||
for i := 0; i < 64; i++ {
|
||||
err := b.Reset(ReadFrom(r))
|
||||
_, err := b.ReadFrom(r)
|
||||
if !b.IsEmpty() {
|
||||
return b, nil
|
||||
}
|
||||
|
@ -140,7 +140,7 @@ func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error {
|
||||
if w.buffer == nil {
|
||||
w.buffer = New()
|
||||
}
|
||||
if err := w.buffer.AppendSupplier(ReadFrom(&b)); err != nil {
|
||||
if _, err := w.buffer.ReadFrom(&b); err != nil {
|
||||
return err
|
||||
}
|
||||
if w.buffer.IsFull() {
|
||||
@ -248,7 +248,8 @@ func (noOpWriter) ReadFrom(reader io.Reader) (int64, error) {
|
||||
|
||||
totalBytes := int64(0)
|
||||
for {
|
||||
err := b.Reset(ReadFrom(reader))
|
||||
b.Clear()
|
||||
_, err := b.ReadFrom(reader)
|
||||
totalBytes += int64(b.Len())
|
||||
if err != nil {
|
||||
if errors.Cause(err) == io.EOF {
|
||||
|
@ -17,7 +17,7 @@ func TestWriter(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
lb := New()
|
||||
assert(lb.AppendSupplier(ReadFrom(rand.Reader)), IsNil)
|
||||
common.Must2(lb.ReadFrom(rand.Reader))
|
||||
|
||||
expectedBytes := append([]byte(nil), lb.Bytes()...)
|
||||
|
||||
@ -54,7 +54,7 @@ func TestDiscardBytes(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
b := New()
|
||||
common.Must(b.Reset(ReadFullFrom(rand.Reader, Size)))
|
||||
common.Must2(b.ReadFullFrom(rand.Reader, Size))
|
||||
|
||||
nBytes, err := io.Copy(DiscardBytes, b)
|
||||
assert(nBytes, Equals, int64(Size))
|
||||
|
@ -132,7 +132,7 @@ var errSoft = newError("waiting for more data")
|
||||
|
||||
func (r *AuthenticationReader) readBuffer(size int32, padding int32) (*buf.Buffer, error) {
|
||||
b := buf.New()
|
||||
if err := b.Reset(buf.ReadFullFrom(r.reader, size)); err != nil {
|
||||
if _, err := b.ReadFullFrom(r.reader, size); err != nil {
|
||||
b.Release()
|
||||
return nil, err
|
||||
}
|
||||
@ -270,7 +270,7 @@ func (w *AuthenticationWriter) seal(b *buf.Buffer) (*buf.Buffer, error) {
|
||||
}
|
||||
if paddingSize > 0 {
|
||||
// With size of the chunk and padding length encrypted, the content of padding doesn't matter much.
|
||||
common.Must(eb.AppendSupplier(buf.ReadFullFrom(w.randReader, int32(paddingSize))))
|
||||
common.Must2(eb.ReadFullFrom(w.randReader, int32(paddingSize)))
|
||||
}
|
||||
|
||||
return eb, nil
|
||||
@ -289,9 +289,7 @@ func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error {
|
||||
|
||||
for {
|
||||
b := buf.New()
|
||||
common.Must(b.Reset(func(bb []byte) (int, error) {
|
||||
return mb.Read(bb[:payloadSize])
|
||||
}))
|
||||
common.Must2(b.ReadFrom(io.LimitReader(&mb, int64(payloadSize))))
|
||||
eb, err := w.seal(b)
|
||||
b.Release()
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"v2ray.com/core/common"
|
||||
@ -9,6 +10,7 @@ import (
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/protocol"
|
||||
"v2ray.com/core/common/serial"
|
||||
"v2ray.com/core/common/vio"
|
||||
)
|
||||
|
||||
type SessionStatus byte
|
||||
@ -60,11 +62,11 @@ type FrameMetadata struct {
|
||||
}
|
||||
|
||||
func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
|
||||
lenBytes := b.Bytes()
|
||||
common.Must2(b.WriteBytes(0x00, 0x00))
|
||||
lenBytes := b.Bytes()
|
||||
|
||||
len0 := b.Len()
|
||||
if err := b.AppendSupplier(serial.WriteUint16(f.SessionID)); err != nil {
|
||||
if _, err := vio.WriteUint16(b, f.SessionID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -84,7 +86,7 @@ func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
|
||||
}
|
||||
|
||||
len1 := b.Len()
|
||||
serial.Uint16ToBytes(uint16(len1-len0), lenBytes)
|
||||
binary.BigEndian.PutUint16(lenBytes, uint16(len1-len0))
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -101,7 +103,7 @@ func (f *FrameMetadata) Unmarshal(reader io.Reader) error {
|
||||
b := buf.New()
|
||||
defer b.Release()
|
||||
|
||||
if err := b.Reset(buf.ReadFullFrom(reader, int32(metaLen))); err != nil {
|
||||
if _, err := b.ReadFullFrom(reader, int32(metaLen)); err != nil {
|
||||
return err
|
||||
}
|
||||
return f.UnmarshalFromBuffer(b)
|
||||
|
@ -38,7 +38,7 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
||||
}
|
||||
|
||||
b := buf.New()
|
||||
if err := b.Reset(buf.ReadFullFrom(r.reader, int32(size))); err != nil {
|
||||
if _, err := b.ReadFullFrom(r.reader, int32(size)); err != nil {
|
||||
b.Release()
|
||||
return nil, err
|
||||
}
|
||||
|
@ -5,7 +5,7 @@ import (
|
||||
"v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/protocol"
|
||||
"v2ray.com/core/common/serial"
|
||||
"v2ray.com/core/common/vio"
|
||||
)
|
||||
|
||||
type Writer struct {
|
||||
@ -66,7 +66,7 @@ func writeMetaWithFrame(writer buf.Writer, meta FrameMetadata, data buf.MultiBuf
|
||||
if err := meta.WriteTo(frame); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := frame.AppendSupplier(serial.WriteUint16(uint16(data.Len()))); err != nil {
|
||||
if _, err := vio.WriteUint16(frame, uint16(data.Len())); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -53,7 +53,7 @@ func NewAddressParser(options ...AddressOption) *AddressParser {
|
||||
}
|
||||
|
||||
func (p *AddressParser) readPort(b *buf.Buffer, reader io.Reader) (net.Port, error) {
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil {
|
||||
if _, err := b.ReadFullFrom(reader, 2); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return net.PortFromBytes(b.BytesFrom(-2)), nil
|
||||
@ -73,7 +73,7 @@ func isValidDomain(d string) bool {
|
||||
}
|
||||
|
||||
func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Address, error) {
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
|
||||
if _, err := b.ReadFullFrom(reader, 1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -89,21 +89,21 @@ func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Addres
|
||||
|
||||
switch addrFamily {
|
||||
case net.AddressFamilyIPv4:
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil {
|
||||
if _, err := b.ReadFullFrom(reader, 4); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.IPAddress(b.BytesFrom(-4)), nil
|
||||
case net.AddressFamilyIPv6:
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil {
|
||||
if _, err := b.ReadFullFrom(reader, 16); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.IPAddress(b.BytesFrom(-16)), nil
|
||||
case net.AddressFamilyDomain:
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
|
||||
if _, err := b.ReadFullFrom(reader, 1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
domainLength := int32(b.Byte(b.Len() - 1))
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil {
|
||||
if _, err := b.ReadFullFrom(reader, domainLength); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
domain := string(b.BytesFrom(-domainLength))
|
||||
|
@ -20,13 +20,6 @@ func ReadUint16(reader io.Reader) (uint16, error) {
|
||||
return BytesToUint16(b[:]), nil
|
||||
}
|
||||
|
||||
func WriteUint16(value uint16) func([]byte) (int, error) {
|
||||
return func(b []byte) (int, error) {
|
||||
Uint16ToBytes(value, b[:0])
|
||||
return 2, nil
|
||||
}
|
||||
}
|
||||
|
||||
func Uint32ToBytes(value uint32, b []byte) []byte {
|
||||
return append(b, byte(value>>24), byte(value>>16), byte(value>>8), byte(value))
|
||||
}
|
||||
|
18
common/vio/serial.go
Normal file
18
common/vio/serial.go
Normal file
@ -0,0 +1,18 @@
|
||||
package vio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
)
|
||||
|
||||
func WriteUint32(writer io.Writer, value uint32) (int, error) {
|
||||
var b [4]byte
|
||||
binary.BigEndian.PutUint32(b[:], value)
|
||||
return writer.Write(b[:])
|
||||
}
|
||||
|
||||
func WriteUint16(writer io.Writer, value uint16) (int, error) {
|
||||
var b [2]byte
|
||||
binary.BigEndian.PutUint16(b[:], value)
|
||||
return writer.Write(b[:])
|
||||
}
|
24
common/vio/serial_test.go
Normal file
24
common/vio/serial_test.go
Normal file
@ -0,0 +1,24 @@
|
||||
package vio_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/common/compare"
|
||||
"v2ray.com/core/common/vio"
|
||||
)
|
||||
|
||||
func TestUint32Serial(t *testing.T) {
|
||||
b := buf.New()
|
||||
defer b.Release()
|
||||
|
||||
n, err := vio.WriteUint32(b, 10)
|
||||
common.Must(err)
|
||||
if n != 4 {
|
||||
t.Error("expect 4 bytes writtng, but actually ", n)
|
||||
}
|
||||
if err := compare.BytesEqualWithDetail(b.Bytes(), []byte{0, 0, 0, 10}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
@ -36,7 +36,7 @@ func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.Requ
|
||||
ivLen := account.Cipher.IVSize()
|
||||
var iv []byte
|
||||
if ivLen > 0 {
|
||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, ivLen)); err != nil {
|
||||
if _, err := buffer.ReadFullFrom(reader, ivLen); err != nil {
|
||||
return nil, nil, newError("failed to read IV").Base(err)
|
||||
}
|
||||
|
||||
@ -85,7 +85,7 @@ func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.Requ
|
||||
actualAuth := make([]byte, AuthSize)
|
||||
authenticator.Authenticate(buffer.Bytes())(actualAuth)
|
||||
|
||||
err := buffer.AppendSupplier(buf.ReadFullFrom(br, AuthSize))
|
||||
_, err := buffer.ReadFullFrom(br, AuthSize)
|
||||
if err != nil {
|
||||
return nil, nil, newError("Failed to read OTA").Base(err)
|
||||
}
|
||||
@ -196,7 +196,7 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff
|
||||
buffer := buf.New()
|
||||
ivLen := account.Cipher.IVSize()
|
||||
if ivLen > 0 {
|
||||
common.Must(buffer.Reset(buf.ReadFullFrom(rand.Reader, ivLen)))
|
||||
common.Must2(buffer.ReadFullFrom(rand.Reader, ivLen))
|
||||
}
|
||||
iv := buffer.Bytes()
|
||||
|
||||
@ -287,7 +287,7 @@ type UDPReader struct {
|
||||
|
||||
func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
||||
buffer := buf.New()
|
||||
err := buffer.AppendSupplier(buf.ReadFrom(v.Reader))
|
||||
_, err := buffer.ReadFrom(v.Reader)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return nil, err
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
"v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/protocol"
|
||||
"v2ray.com/core/common/serial"
|
||||
"v2ray.com/core/common/vio"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -49,7 +49,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
|
||||
|
||||
request := new(protocol.RequestHeader)
|
||||
|
||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil {
|
||||
if _, err := buffer.ReadFullFrom(reader, 2); err != nil {
|
||||
return nil, newError("insufficient header").Base(err)
|
||||
}
|
||||
|
||||
@ -60,7 +60,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
|
||||
return nil, newError("socks 4 is not allowed when auth is required.")
|
||||
}
|
||||
|
||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 6)); err != nil {
|
||||
if _, err := buffer.ReadFullFrom(reader, 6); err != nil {
|
||||
return nil, newError("insufficient header").Base(err)
|
||||
}
|
||||
port := net.PortFromBytes(buffer.BytesRange(2, 4))
|
||||
@ -94,7 +94,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
|
||||
|
||||
if version == socks5Version {
|
||||
nMethod := int32(buffer.Byte(1))
|
||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, nMethod)); err != nil {
|
||||
if _, err := buffer.ReadFullFrom(reader, nMethod); err != nil {
|
||||
return nil, newError("failed to read auth methods").Base(err)
|
||||
}
|
||||
|
||||
@ -127,7 +127,9 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
|
||||
return nil, newError("failed to write auth response").Base(err)
|
||||
}
|
||||
}
|
||||
if err := buffer.Reset(buf.ReadFullFrom(reader, 3)); err != nil {
|
||||
|
||||
buffer.Clear()
|
||||
if _, err := buffer.ReadFullFrom(reader, 3); err != nil {
|
||||
return nil, newError("failed to read request").Base(err)
|
||||
}
|
||||
|
||||
@ -185,21 +187,25 @@ func readUsernamePassword(reader io.Reader) (string, string, error) {
|
||||
buffer := buf.New()
|
||||
defer buffer.Release()
|
||||
|
||||
if err := buffer.Reset(buf.ReadFullFrom(reader, 2)); err != nil {
|
||||
if _, err := buffer.ReadFullFrom(reader, 2); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
nUsername := int32(buffer.Byte(1))
|
||||
|
||||
if err := buffer.Reset(buf.ReadFullFrom(reader, nUsername)); err != nil {
|
||||
buffer.Clear()
|
||||
if _, err := buffer.ReadFullFrom(reader, nUsername); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
username := buffer.String()
|
||||
|
||||
if err := buffer.Reset(buf.ReadFullFrom(reader, 1)); err != nil {
|
||||
buffer.Clear()
|
||||
if _, err := buffer.ReadFullFrom(reader, 1); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
nPassword := int32(buffer.Byte(0))
|
||||
if err := buffer.Reset(buf.ReadFullFrom(reader, nPassword)); err != nil {
|
||||
|
||||
buffer.Clear()
|
||||
if _, err := buffer.ReadFullFrom(reader, nPassword); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
password := buffer.String()
|
||||
@ -254,7 +260,7 @@ func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, po
|
||||
defer buffer.Release()
|
||||
|
||||
common.Must2(buffer.WriteBytes(0x00, errCode))
|
||||
common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value())))
|
||||
common.Must2(vio.WriteUint16(buffer, port.Value()))
|
||||
common.Must2(buffer.Write(address.IP()))
|
||||
return buf.WriteAllBytes(writer, buffer.Bytes())
|
||||
}
|
||||
@ -305,7 +311,7 @@ func NewUDPReader(reader io.Reader) *UDPReader {
|
||||
|
||||
func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
||||
b := buf.New()
|
||||
if err := b.AppendSupplier(buf.ReadFrom(r.reader)); err != nil {
|
||||
if _, err := b.ReadFrom(r.reader); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := DecodeUDPPacket(b); err != nil {
|
||||
@ -362,7 +368,8 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := b.Reset(buf.ReadFullFrom(reader, 2)); err != nil {
|
||||
b.Clear()
|
||||
if _, err := b.ReadFullFrom(reader, 2); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -374,7 +381,8 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
|
||||
}
|
||||
|
||||
if authByte == authPassword {
|
||||
if err := b.Reset(buf.ReadFullFrom(reader, 2)); err != nil {
|
||||
b.Clear()
|
||||
if _, err := b.ReadFullFrom(reader, 2); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if b.Byte(1) != 0x00 {
|
||||
@ -398,7 +406,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
|
||||
}
|
||||
|
||||
b.Clear()
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 3)); err != nil {
|
||||
if _, err := b.ReadFullFrom(reader, 3); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -80,7 +80,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
|
||||
}
|
||||
|
||||
if padingLen > 0 {
|
||||
common.Must(buffer.AppendSupplier(buf.ReadFullFrom(rand.Reader, int32(padingLen))))
|
||||
common.Must2(buffer.ReadFullFrom(rand.Reader, int32(padingLen)))
|
||||
}
|
||||
|
||||
{
|
||||
@ -164,7 +164,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
|
||||
buffer := buf.New()
|
||||
defer buffer.Release()
|
||||
|
||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(c.responseReader, 4)); err != nil {
|
||||
if _, err := buffer.ReadFullFrom(c.responseReader, 4); err != nil {
|
||||
return nil, newError("failed to read response header").Base(err)
|
||||
}
|
||||
|
||||
@ -180,7 +180,8 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
|
||||
cmdID := buffer.Byte(2)
|
||||
dataLen := int32(buffer.Byte(3))
|
||||
|
||||
if err := buffer.Reset(buf.ReadFullFrom(c.responseReader, dataLen)); err != nil {
|
||||
buffer.Clear()
|
||||
if _, err := buffer.ReadFullFrom(c.responseReader, dataLen); err != nil {
|
||||
return nil, newError("failed to read response command").Base(err)
|
||||
}
|
||||
command, err := UnmarshalCommand(cmdID, buffer.Bytes())
|
||||
|
@ -125,7 +125,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
|
||||
buffer := buf.New()
|
||||
defer buffer.Release()
|
||||
|
||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, protocol.IDBytesLen)); err != nil {
|
||||
if _, err := buffer.ReadFullFrom(reader, protocol.IDBytesLen); err != nil {
|
||||
return nil, newError("failed to read request header").Base(err)
|
||||
}
|
||||
|
||||
@ -140,7 +140,8 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
|
||||
aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv[:])
|
||||
decryptor := crypto.NewCryptionReader(aesStream, reader)
|
||||
|
||||
if err := buffer.Reset(buf.ReadFullFrom(decryptor, 38)); err != nil {
|
||||
buffer.Clear()
|
||||
if _, err := buffer.ReadFullFrom(decryptor, 38); err != nil {
|
||||
return nil, newError("failed to read request header").Base(err)
|
||||
}
|
||||
|
||||
@ -178,12 +179,12 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
|
||||
}
|
||||
|
||||
if padingLen > 0 {
|
||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, int32(padingLen))); err != nil {
|
||||
if _, err := buffer.ReadFullFrom(decryptor, int32(padingLen)); err != nil {
|
||||
return nil, newError("failed to read padding").Base(err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 4)); err != nil {
|
||||
if _, err := buffer.ReadFullFrom(decryptor, 4); err != nil {
|
||||
return nil, newError("failed to read checksum").Base(err)
|
||||
}
|
||||
|
||||
|
@ -69,7 +69,7 @@ func (server *Server) handleConnection(conn net.Conn) {
|
||||
|
||||
for {
|
||||
b := buf.New()
|
||||
if err := b.AppendSupplier(buf.ReadFrom(conn)); err != nil {
|
||||
if _, err := b.ReadFrom(conn); err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ func TestListen(t *testing.T) {
|
||||
defer conn.Close()
|
||||
|
||||
b := buf.New()
|
||||
common.Must(b.Reset(buf.ReadFrom(conn)))
|
||||
common.Must2(b.ReadFrom(conn))
|
||||
assert(b.String(), Equals, "Request")
|
||||
|
||||
common.Must2(conn.Write([]byte("Response")))
|
||||
@ -44,7 +44,7 @@ func TestListen(t *testing.T) {
|
||||
assert(err, IsNil)
|
||||
|
||||
b := buf.New()
|
||||
common.Must(b.Reset(buf.ReadFrom(conn)))
|
||||
common.Must2(b.ReadFrom(conn))
|
||||
|
||||
assert(b.String(), Equals, "Response")
|
||||
}
|
||||
@ -67,7 +67,7 @@ func TestListenAbstract(t *testing.T) {
|
||||
defer conn.Close()
|
||||
|
||||
b := buf.New()
|
||||
common.Must(b.Reset(buf.ReadFrom(conn)))
|
||||
common.Must2(b.ReadFrom(conn))
|
||||
assert(b.String(), Equals, "Request")
|
||||
|
||||
common.Must2(conn.Write([]byte("Response")))
|
||||
@ -83,7 +83,7 @@ func TestListenAbstract(t *testing.T) {
|
||||
assert(err, IsNil)
|
||||
|
||||
b := buf.New()
|
||||
common.Must(b.Reset(buf.ReadFrom(conn)))
|
||||
common.Must2(b.ReadFrom(conn))
|
||||
|
||||
assert(b.String(), Equals, "Response")
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ func (*HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
|
||||
totalBytes := int32(0)
|
||||
endingDetected := false
|
||||
for totalBytes < maxHeaderLength {
|
||||
err := buffer.AppendSupplier(buf.ReadFrom(reader))
|
||||
_, err := buffer.ReadFrom(reader)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return nil, err
|
||||
|
@ -39,7 +39,7 @@ func TestHTTPConnection(t *testing.T) {
|
||||
defer b.Release()
|
||||
|
||||
for {
|
||||
if err := b.Reset(buf.ReadFrom(conn)); err != nil {
|
||||
if _, err := b.ReadFrom(conn); err != nil {
|
||||
return
|
||||
}
|
||||
nBytes, err := conn.Write(b.Bytes())
|
||||
@ -76,13 +76,15 @@ func TestHTTPConnection(t *testing.T) {
|
||||
assert(nBytes, Equals, N)
|
||||
assert(err, IsNil)
|
||||
|
||||
assert(b2.Reset(buf.ReadFullFrom(conn, N)), IsNil)
|
||||
b2.Clear()
|
||||
common.Must2(b2.ReadFullFrom(conn, N))
|
||||
assert(b2.Bytes(), Equals, b1)
|
||||
|
||||
nBytes, err = conn.Write(b1)
|
||||
assert(nBytes, Equals, N)
|
||||
assert(err, IsNil)
|
||||
|
||||
assert(b2.Reset(buf.ReadFullFrom(conn, N)), IsNil)
|
||||
b2.Clear()
|
||||
common.Must2(b2.ReadFullFrom(conn, N))
|
||||
assert(b2.Bytes(), Equals, b1)
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ func fetchInput(ctx context.Context, input io.Reader, reader PacketReader, conn
|
||||
go func() {
|
||||
for {
|
||||
payload := buf.New()
|
||||
if err := payload.Reset(buf.ReadFrom(input)); err != nil {
|
||||
if _, err := payload.ReadFrom(input); err != nil {
|
||||
payload.Release()
|
||||
close(cache)
|
||||
return
|
||||
|
@ -2,6 +2,7 @@ package kcp
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"v2ray.com/core/common"
|
||||
@ -274,9 +275,7 @@ func (w *SendingWorker) Push(mb *buf.MultiBuffer) bool {
|
||||
}
|
||||
|
||||
b := buf.New()
|
||||
common.Must(b.Reset(func(v []byte) (int, error) {
|
||||
return mb.Read(v[:w.conn.mss])
|
||||
}))
|
||||
common.Must2(b.ReadFrom(io.LimitReader(mb, int64(w.conn.mss))))
|
||||
w.window.Push(w.nextNumber, b)
|
||||
w.nextNumber++
|
||||
return true
|
||||
|
@ -40,7 +40,7 @@ func TestTCPFastOpen(t *testing.T) {
|
||||
common.Must(err)
|
||||
|
||||
b := buf.New()
|
||||
common.Must(b.Reset(buf.ReadFrom(conn)))
|
||||
common.Must2(b.ReadFrom(conn))
|
||||
if err := compare.BytesEqualWithDetail(b.Bytes(), []byte("abcd")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user