From f506a39d3256aa2a2864ca0e35032666353e5102 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sat, 15 Apr 2017 21:07:23 +0200 Subject: [PATCH] multi buffer --- app/proxyman/mux/mux.go | 29 +++------ app/proxyman/mux/mux_test.go | 7 +-- app/proxyman/mux/reader.go | 46 +++++++-------- app/proxyman/mux/writer.go | 42 ++++--------- app/proxyman/outbound/handler.go | 13 +++- common/buf/io.go | 7 ++- common/buf/merge_reader.go | 31 ++-------- common/buf/merge_reader_test.go | 8 +-- common/buf/multi_buffer.go | 88 ++++++++++++++++++++++++++++ common/buf/multi_buffer_test.go | 25 ++++++++ common/buf/reader.go | 42 ++++--------- common/buf/reader_test.go | 11 +--- common/buf/writer.go | 47 ++++++--------- common/buf/writer_test.go | 2 +- common/crypto/auth.go | 30 ++++++++-- proxy/blackhole/config.go | 4 +- proxy/shadowsocks/ota.go | 28 ++++++--- proxy/shadowsocks/ota_test.go | 4 +- proxy/shadowsocks/protocol.go | 21 +++++-- proxy/shadowsocks/protocol_test.go | 12 ++-- proxy/shadowsocks/server.go | 80 +++++++++++++------------ proxy/socks/protocol.go | 21 ++++--- proxy/socks/protocol_test.go | 4 +- proxy/socks/server.go | 57 +++++++++--------- proxy/vmess/inbound/inbound.go | 2 +- proxy/vmess/outbound/outbound.go | 2 +- transport/internet/udp/dispatcher.go | 8 ++- transport/ray/direct.go | 10 ++-- transport/ray/direct_test.go | 6 +- 29 files changed, 390 insertions(+), 297 deletions(-) create mode 100644 common/buf/multi_buffer.go create mode 100644 common/buf/multi_buffer_test.go diff --git a/app/proxyman/mux/mux.go b/app/proxyman/mux/mux.go index c3606f9c2..f8377dd61 100644 --- a/app/proxyman/mux/mux.go +++ b/app/proxyman/mux/mux.go @@ -180,31 +180,20 @@ func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool } func drain(reader *Reader) error { - for { - data, more, err := reader.Read() - if err != nil { - return err - } - data.Release() - if !more { - return nil - } + data, err := reader.Read() + if err != nil { + return err } + data.Release() + return nil } func pipe(reader *Reader, writer buf.Writer) error { - for { - data, more, err := reader.Read() - if err != nil { - return err - } - if err := writer.Write(data); err != nil { - return err - } - if !more { - return nil - } + data, err := reader.Read() + if err != nil { + return err } + return writer.Write(data) } func (m *Client) handleStatueKeepAlive(meta *FrameMetadata, reader *Reader) error { diff --git a/app/proxyman/mux/mux_test.go b/app/proxyman/mux/mux_test.go index aa95ad8be..75c647b85 100644 --- a/app/proxyman/mux/mux_test.go +++ b/app/proxyman/mux/mux_test.go @@ -20,7 +20,7 @@ func TestReaderWriter(t *testing.T) { payload := buf.New() payload.AppendBytes('a', 'b', 'c', 'd') - assert.Error(writer.Write(payload)).IsNil() + assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil() writer.Close() @@ -32,10 +32,9 @@ func TestReaderWriter(t *testing.T) { assert.Destination(meta.Target).Equals(dest) assert.Byte(byte(meta.Option)).Equals(byte(OptionData)) - data, more, err := reader.Read() + data, err := reader.Read() assert.Error(err).IsNil() - assert.Bool(more).IsFalse() - assert.String(data.String()).Equals("abcd") + assert.String(data[0].String()).Equals("abcd") meta, err = reader.ReadMetadata() assert.Error(err).IsNil() diff --git a/app/proxyman/mux/reader.go b/app/proxyman/mux/reader.go index 74ad05bc3..ad18c3f04 100644 --- a/app/proxyman/mux/reader.go +++ b/app/proxyman/mux/reader.go @@ -8,9 +8,8 @@ import ( ) type Reader struct { - reader io.Reader - remainingLength int - buffer *buf.Buffer + reader io.Reader + buffer *buf.Buffer } func NewReader(reader buf.Reader) *Reader { @@ -38,28 +37,27 @@ func (r *Reader) ReadMetadata() (*FrameMetadata, error) { return ReadFrameFrom(b.Bytes()) } -func (r *Reader) Read() (*buf.Buffer, bool, error) { - b := buf.New() - var dataLen int - if r.remainingLength > 0 { - dataLen = r.remainingLength - r.remainingLength = 0 - } else { - if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, 2)); err != nil { - return nil, false, err +func (r *Reader) Read() (buf.MultiBuffer, error) { + r.buffer.Clear() + if err := r.buffer.AppendSupplier(buf.ReadFullFrom(r.reader, 2)); err != nil { + return nil, err + } + + dataLen := int(serial.BytesToUint16(r.buffer.Bytes())) + mb := buf.NewMultiBuffer() + for dataLen > 0 { + b := buf.New() + readLen := buf.Size + if dataLen < readLen { + readLen = dataLen } - dataLen = int(serial.BytesToUint16(b.Bytes())) - b.Clear() + if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, readLen)); err != nil { + mb.Release() + return nil, err + } + dataLen -= readLen + mb.Append(b) } - if dataLen > buf.Size { - r.remainingLength = dataLen - buf.Size - dataLen = buf.Size - } - - if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, dataLen)); err != nil { - return nil, false, err - } - - return b, (r.remainingLength > 0), nil + return mb, nil } diff --git a/app/proxyman/mux/writer.go b/app/proxyman/mux/writer.go index 9927b3008..30b76e28e 100644 --- a/app/proxyman/mux/writer.go +++ b/app/proxyman/mux/writer.go @@ -29,7 +29,7 @@ func NewResponseWriter(id uint16, writer buf.Writer) *Writer { } } -func (w *Writer) writeInternal(b *buf.Buffer) error { +func (w *Writer) Write(mb buf.MultiBuffer) error { meta := FrameMetadata{ SessionID: w.id, Target: w.dest, @@ -41,42 +41,21 @@ func (w *Writer) writeInternal(b *buf.Buffer) error { meta.SessionStatus = SessionStatusNew } - if b.Len() > 0 { + if mb.Len() > 0 { meta.Option.Add(OptionData) } frame := buf.New() frame.AppendSupplier(meta.AsSupplier()) - if b.Len() > 0 { - frame.AppendSupplier(serial.WriteUint16(0)) - lengthBytes := frame.BytesFrom(-2) + mb2 := buf.NewMultiBuffer() + mb2.Append(frame) - nBytes, err := frame.Write(b.Bytes()) - if err != nil { - frame.Release() - return err - } - - serial.Uint16ToBytes(uint16(nBytes), lengthBytes[:0]) - b.SliceFrom(nBytes) + if mb.Len() > 0 { + frame.AppendSupplier(serial.WriteUint16(uint16(mb.Len()))) + mb2.AppendMulti(mb) } - - return w.writer.Write(frame) -} - -func (w *Writer) Write(b *buf.Buffer) error { - defer b.Release() - - if err := w.writeInternal(b); err != nil { - return err - } - for !b.IsEmpty() { - if err := w.writeInternal(b); err != nil { - return err - } - } - return nil + return w.writer.Write(mb2) } func (w *Writer) Close() { @@ -88,5 +67,8 @@ func (w *Writer) Close() { frame := buf.New() frame.AppendSupplier(meta.AsSupplier()) - w.writer.Write(frame) + mb := buf.NewMultiBuffer() + mb.Append(frame) + + w.writer.Write(mb) } diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index 036a46347..0e037a1b6 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -129,7 +129,7 @@ type Connection struct { remoteAddr net.Addr reader io.Reader - writer io.Writer + writer buf.Writer } func NewConnection(stream ray.Ray) *Connection { @@ -144,7 +144,7 @@ func NewConnection(stream ray.Ray) *Connection { Port: 0, }, reader: buf.ToBytesReader(stream.InboundOutput()), - writer: buf.ToBytesWriter(stream.InboundInput()), + writer: stream.InboundInput(), } } @@ -161,7 +161,14 @@ func (v *Connection) Write(b []byte) (int, error) { if v.closed { return 0, io.ErrClosedPipe } - return v.writer.Write(b) + return buf.ToBytesWriter(v.writer).Write(b) +} + +func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) (int, error) { + if v.closed { + return 0, io.ErrClosedPipe + } + return mb.Len(), v.writer.Write(mb) } // Close implements net.Conn.Close(). diff --git a/common/buf/io.go b/common/buf/io.go index 9beb2ccc4..18d7029c8 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -11,19 +11,19 @@ import ( // Reader extends io.Reader with alloc.Buffer. type Reader interface { // Read reads content from underlying reader, and put it into an alloc.Buffer. - Read() (*Buffer, error) + Read() (MultiBuffer, error) } var ErrReadTimeout = newError("IO timeout") type TimeoutReader interface { - ReadTimeout(time.Duration) (*Buffer, error) + ReadTimeout(time.Duration) (MultiBuffer, error) } // Writer extends io.Writer with alloc.Buffer. type Writer interface { // Write writes an alloc.Buffer into underlying writer. - Write(*Buffer) error + Write(MultiBuffer) error } // ReadFrom creates a Supplier to read from a given io.Reader. @@ -78,6 +78,7 @@ func PipeUntilEOF(timer signal.ActivityTimer, reader Reader, writer Writer) erro func NewReader(reader io.Reader) Reader { return &BytesToBufferReader{ reader: reader, + buffer: NewLocal(32 * 1024), } } diff --git a/common/buf/merge_reader.go b/common/buf/merge_reader.go index a272163de..9d3f43752 100644 --- a/common/buf/merge_reader.go +++ b/common/buf/merge_reader.go @@ -3,7 +3,6 @@ package buf type MergingReader struct { reader Reader timeoutReader TimeoutReader - leftover *Buffer } func NewMergingReader(reader Reader) Reader { @@ -13,41 +12,23 @@ func NewMergingReader(reader Reader) Reader { } } -func (r *MergingReader) Read() (*Buffer, error) { - if r.leftover != nil { - b := r.leftover - r.leftover = nil - return b, nil - } - - b, err := r.reader.Read() +func (r *MergingReader) Read() (MultiBuffer, error) { + mb, err := r.reader.Read() if err != nil { return nil, err } - if b.IsFull() { - return b, nil - } - if r.timeoutReader == nil { - return b, nil + return mb, nil } for { - b2, err := r.timeoutReader.ReadTimeout(0) + mb2, err := r.timeoutReader.ReadTimeout(0) if err != nil { break } - - nBytes := b.Append(b2.Bytes()) - b2.SliceFrom(nBytes) - if b2.IsEmpty() { - b2.Release() - } else { - r.leftover = b2 - break - } + mb.AppendMulti(mb2) } - return b, nil + return mb, nil } diff --git a/common/buf/merge_reader_test.go b/common/buf/merge_reader_test.go index 57cbdfc48..48a0aefec 100644 --- a/common/buf/merge_reader_test.go +++ b/common/buf/merge_reader_test.go @@ -16,18 +16,18 @@ func TestMergingReader(t *testing.T) { stream := ray.NewStream(context.Background()) b1 := New() b1.AppendBytes('a', 'b', 'c') - stream.Write(b1) + stream.Write(NewMultiBufferValue(b1)) b2 := New() b2.AppendBytes('e', 'f', 'g') - stream.Write(b2) + stream.Write(NewMultiBufferValue(b2)) b3 := New() b3.AppendBytes('h', 'i', 'j') - stream.Write(b3) + stream.Write(NewMultiBufferValue(b3)) reader := NewMergingReader(stream) b, err := reader.Read() assert.Error(err).IsNil() - assert.String(b.String()).Equals("abcefghij") + assert.Int(b.Len()).Equals(9) } diff --git a/common/buf/multi_buffer.go b/common/buf/multi_buffer.go new file mode 100644 index 000000000..3c5352427 --- /dev/null +++ b/common/buf/multi_buffer.go @@ -0,0 +1,88 @@ +package buf + +import ( + "io" + "net" +) + +type MultiBufferWriter interface { + WriteMultiBuffer(MultiBuffer) (int, error) +} + +type MultiBuffer []*Buffer + +func NewMultiBuffer() MultiBuffer { + return MultiBuffer(make([]*Buffer, 0, 8)) +} + +func NewMultiBufferValue(b ...*Buffer) MultiBuffer { + return MultiBuffer(b) +} + +func (b *MultiBuffer) Append(buf *Buffer) { + *b = append(*b, buf) +} + +func (b *MultiBuffer) AppendMulti(mb MultiBuffer) { + *b = append(*b, mb...) +} + +func (mb *MultiBuffer) Read(b []byte) (int, error) { + if len(*mb) == 0 { + return 0, io.EOF + } + endIndex := len(*mb) + totalBytes := 0 + for i, bb := range *mb { + nBytes, err := bb.Read(b) + totalBytes += nBytes + if err != nil { + return totalBytes, err + } + b = b[nBytes:] + if bb.IsEmpty() { + bb.Release() + } else { + endIndex = i + break + } + } + *mb = (*mb)[endIndex:] + return totalBytes, nil +} + +func (mb MultiBuffer) WriteTo(writer io.Writer) (int, error) { + if mw, ok := writer.(MultiBufferWriter); ok { + return mw.WriteMultiBuffer(mb) + } + bs := make([][]byte, len(mb)) + for i, b := range mb { + bs[i] = b.Bytes() + } + nbs := net.Buffers(bs) + nBytes, err := nbs.WriteTo(writer) + return int(nBytes), err +} + +func (mb MultiBuffer) Len() int { + size := 0 + for _, b := range mb { + size += b.Len() + } + return size +} + +func (mb MultiBuffer) IsEmpty() bool { + for _, b := range mb { + if !b.IsEmpty() { + return false + } + } + return true +} + +func (mb MultiBuffer) Release() { + for _, b := range mb { + b.Release() + } +} diff --git a/common/buf/multi_buffer_test.go b/common/buf/multi_buffer_test.go new file mode 100644 index 000000000..9b9e0c938 --- /dev/null +++ b/common/buf/multi_buffer_test.go @@ -0,0 +1,25 @@ +package buf_test + +import ( + "testing" + + . "v2ray.com/core/common/buf" + "v2ray.com/core/testing/assert" +) + +func TestMultiBufferRead(t *testing.T) { + assert := assert.On(t) + + b1 := New() + b1.AppendBytes('a', 'b') + + b2 := New() + b2.AppendBytes('c', 'd') + mb := NewMultiBufferValue(b1, b2) + + bs := make([]byte, 32) + nBytes, err := mb.Read(bs) + assert.Error(err).IsNil() + assert.Int(nBytes).Equals(4) + assert.Bytes(bs[:nBytes]).Equals([]byte("abcd")) +} diff --git a/common/buf/reader.go b/common/buf/reader.go index df5bd67cb..4f050b690 100644 --- a/common/buf/reader.go +++ b/common/buf/reader.go @@ -4,48 +4,28 @@ import "io" // BytesToBufferReader is a Reader that adjusts its reading speed automatically. type BytesToBufferReader struct { - reader io.Reader - largeBuffer *Buffer - highVolumn bool + reader io.Reader + buffer *Buffer } // Read implements Reader.Read(). -func (v *BytesToBufferReader) Read() (*Buffer, error) { - if v.highVolumn && v.largeBuffer.IsEmpty() { - if v.largeBuffer == nil { - v.largeBuffer = NewLocal(32 * 1024) - } - err := v.largeBuffer.AppendSupplier(ReadFrom(v.reader)) - if err != nil { - return nil, err - } - if v.largeBuffer.Len() < Size { - v.highVolumn = false - } - } - - buffer := New() - if !v.largeBuffer.IsEmpty() { - err := buffer.AppendSupplier(ReadFrom(v.largeBuffer)) - return buffer, err - } - - err := buffer.AppendSupplier(ReadFrom(v.reader)) - if err != nil { - buffer.Release() +func (v *BytesToBufferReader) Read() (MultiBuffer, error) { + if err := v.buffer.Reset(ReadFrom(v.reader)); err != nil { return nil, err } - if buffer.IsFull() { - v.highVolumn = true + mb := NewMultiBuffer() + for !v.buffer.IsEmpty() { + b := New() + b.AppendSupplier(ReadFrom(v.buffer)) + mb.Append(b) } - - return buffer, nil + return mb, nil } type bufferToBytesReader struct { stream Reader - current *Buffer + current MultiBuffer err error } diff --git a/common/buf/reader_test.go b/common/buf/reader_test.go index b6234a061..c0fce205a 100644 --- a/common/buf/reader_test.go +++ b/common/buf/reader_test.go @@ -15,14 +15,7 @@ func TestAdaptiveReader(t *testing.T) { buffer := bytes.NewBuffer(rawContent) reader := NewReader(buffer) - b1, err := reader.Read() + b, err := reader.Read() assert.Error(err).IsNil() - assert.Bool(b1.IsFull()).IsTrue() - assert.Int(b1.Len()).Equals(Size) - assert.Int(buffer.Len()).Equals(cap(rawContent) - Size) - - b2, err := reader.Read() - assert.Error(err).IsNil() - assert.Bool(b2.IsFull()).IsTrue() - assert.Int(buffer.Len()).Equals(1007616) + assert.Int(b.Len()).Equals(32 * 1024) } diff --git a/common/buf/writer.go b/common/buf/writer.go index 7a2601a99..84fa55ba6 100644 --- a/common/buf/writer.go +++ b/common/buf/writer.go @@ -8,39 +8,30 @@ type BufferToBytesWriter struct { } // Write implements Writer.Write(). Write() takes ownership of the given buffer. -func (v *BufferToBytesWriter) Write(buffer *Buffer) error { - defer buffer.Release() - for { - nBytes, err := v.writer.Write(buffer.Bytes()) - if err != nil { - return err - } - if nBytes == buffer.Len() { - break - } - buffer.SliceFrom(nBytes) - } - return nil +func (v *BufferToBytesWriter) Write(buffer MultiBuffer) error { + _, err := buffer.WriteTo(v.writer) + //buffer.Release() + return err } type bytesToBufferWriter struct { writer Writer } -func (v *bytesToBufferWriter) Write(payload []byte) (int, error) { - bytesWritten := 0 - size := len(payload) - for size > 0 { - buffer := New() - nBytes, _ := buffer.Write(payload) - size -= nBytes - payload = payload[nBytes:] - bytesWritten += nBytes - err := v.writer.Write(buffer) - if err != nil { - return bytesWritten, err - } +func (w *bytesToBufferWriter) Write(payload []byte) (int, error) { + mb := NewMultiBuffer() + for p := payload; len(p) > 0; { + b := New() + nBytes, _ := b.Write(p) + p = p[nBytes:] + mb.Append(b) } - - return bytesWritten, nil + if err := w.writer.Write(mb); err != nil { + return 0, err + } + return len(payload), nil +} + +func (w *bytesToBufferWriter) WriteMulteBuffer(mb MultiBuffer) (int, error) { + return mb.Len(), w.writer.Write(mb) } diff --git a/common/buf/writer_test.go b/common/buf/writer_test.go index c2ad61f30..afd49f49e 100644 --- a/common/buf/writer_test.go +++ b/common/buf/writer_test.go @@ -20,7 +20,7 @@ func TestWriter(t *testing.T) { writeBuffer := bytes.NewBuffer(make([]byte, 0, 1024*1024)) writer := NewWriter(NewBufferedWriter(writeBuffer)) - err := writer.Write(lb) + err := writer.Write(NewMultiBufferValue(lb)) assert.Error(err).IsNil() assert.Bytes(expectedBytes).Equals(writeBuffer.Bytes()) } diff --git a/common/crypto/auth.go b/common/crypto/auth.go index e297b46c2..4e21592e1 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -215,14 +215,34 @@ func NewAuthenticationWriter(auth Authenticator, writer io.Writer, sizeMask Uint } } -func (v *AuthenticationWriter) Write(b []byte) (int, error) { - cipherChunk, err := v.auth.Seal(v.buffer[2:2], b) +func (w *AuthenticationWriter) Write(b []byte) (int, error) { + cipherChunk, err := w.auth.Seal(w.buffer[2:2], b) if err != nil { return 0, err } - size := uint16(len(cipherChunk)) ^ v.sizeMask.Next() - serial.Uint16ToBytes(size, v.buffer[:0]) - _, err = v.writer.Write(v.buffer[:2+len(cipherChunk)]) + size := uint16(len(cipherChunk)) ^ w.sizeMask.Next() + serial.Uint16ToBytes(size, w.buffer[:0]) + _, err = w.writer.Write(w.buffer[:2+len(cipherChunk)]) return len(b), err } + +func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) (int, error) { + const StartIndex = 17 * 1024 + var totalBytes int + for { + payloadLen, err := mb.Read(w.buffer[StartIndex:]) + if err != nil { + return 0, err + } + nBytes, err := w.Write(w.buffer[StartIndex : StartIndex+payloadLen]) + totalBytes += nBytes + if err != nil { + return totalBytes, err + } + if mb.IsEmpty() { + break + } + } + return totalBytes, nil +} diff --git a/proxy/blackhole/config.go b/proxy/blackhole/config.go index b87e9e405..7633fc85d 100644 --- a/proxy/blackhole/config.go +++ b/proxy/blackhole/config.go @@ -28,7 +28,9 @@ func (v *NoneResponse) WriteTo(buf.Writer) {} func (v *HTTPResponse) WriteTo(writer buf.Writer) { b := buf.NewLocal(512) b.AppendSupplier(serial.WriteString(http403response)) - writer.Write(b) + mb := buf.NewMultiBuffer() + mb.Append(b) + writer.Write(mb) } // GetInternalResponse converts response settings from proto to internal data structure. diff --git a/proxy/shadowsocks/ota.go b/proxy/shadowsocks/ota.go index 31a1794ad..3f2238638 100644 --- a/proxy/shadowsocks/ota.go +++ b/proxy/shadowsocks/ota.go @@ -68,7 +68,7 @@ func NewChunkReader(reader io.Reader, auth *Authenticator) *ChunkReader { } } -func (v *ChunkReader) Read() (*buf.Buffer, error) { +func (v *ChunkReader) Read() (buf.MultiBuffer, error) { buffer := buf.New() if err := buffer.AppendSupplier(buf.ReadFullFrom(v.reader, 2)); err != nil { buffer.Release() @@ -100,7 +100,10 @@ func (v *ChunkReader) Read() (*buf.Buffer, error) { } buffer.SliceFrom(AuthSize) - return buffer, nil + mb := buf.NewMultiBuffer() + mb.Append(buffer) + + return mb, nil } type ChunkWriter struct { @@ -117,11 +120,22 @@ func NewChunkWriter(writer io.Writer, auth *Authenticator) *ChunkWriter { } } -func (v *ChunkWriter) Write(payload *buf.Buffer) error { +func (w *ChunkWriter) Write(mb buf.MultiBuffer) error { + defer mb.Release() + + for _, b := range mb { + if err := w.writeInternal(b); err != nil { + return err + } + } + return nil +} + +func (w *ChunkWriter) writeInternal(payload *buf.Buffer) error { totalLength := payload.Len() - serial.Uint16ToBytes(uint16(totalLength), v.buffer[:0]) - v.auth.Authenticate(payload.Bytes())(v.buffer[2:]) - copy(v.buffer[2+AuthSize:], payload.Bytes()) - _, err := v.writer.Write(v.buffer[:2+AuthSize+payload.Len()]) + serial.Uint16ToBytes(uint16(totalLength), w.buffer[:0]) + w.auth.Authenticate(payload.Bytes())(w.buffer[2:]) + copy(w.buffer[2+AuthSize:], payload.Bytes()) + _, err := w.writer.Write(w.buffer[:2+AuthSize+payload.Len()]) return err } diff --git a/proxy/shadowsocks/ota_test.go b/proxy/shadowsocks/ota_test.go index 8a93eb341..dda777e1d 100644 --- a/proxy/shadowsocks/ota_test.go +++ b/proxy/shadowsocks/ota_test.go @@ -18,7 +18,7 @@ func TestNormalChunkReading(t *testing.T) { []byte{21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}))) payload, err := reader.Read() assert.Error(err).IsNil() - assert.Bytes(payload.Bytes()).Equals([]byte{11, 12, 13, 14, 15, 16, 17, 18}) + assert.Bytes(payload[0].Bytes()).Equals([]byte{11, 12, 13, 14, 15, 16, 17, 18}) } func TestNormalChunkWriting(t *testing.T) { @@ -30,7 +30,7 @@ func TestNormalChunkWriting(t *testing.T) { b := buf.NewLocal(256) b.Append([]byte{11, 12, 13, 14, 15, 16, 17, 18}) - err := writer.Write(b) + err := writer.Write(buf.NewMultiBufferValue(b)) assert.Error(err).IsNil() assert.Bytes(buffer.Bytes()).Equals([]byte{0, 8, 39, 228, 69, 96, 133, 39, 254, 26, 201, 70, 11, 12, 13, 14, 15, 16, 17, 18}) } diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index 0269c1a02..52a60058f 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -362,7 +362,7 @@ type UDPReader struct { User *protocol.User } -func (v *UDPReader) Read() (*buf.Buffer, error) { +func (v *UDPReader) Read() (buf.MultiBuffer, error) { buffer := buf.NewSmall() err := buffer.AppendSupplier(buf.ReadFrom(v.Reader)) if err != nil { @@ -374,7 +374,9 @@ func (v *UDPReader) Read() (*buf.Buffer, error) { buffer.Release() return nil, err } - return payload, nil + mb := buf.NewMultiBuffer() + mb.Append(payload) + return mb, nil } type UDPWriter struct { @@ -382,12 +384,21 @@ type UDPWriter struct { Request *protocol.RequestHeader } -func (v *UDPWriter) Write(buffer *buf.Buffer) error { - payload, err := EncodeUDPPacket(v.Request, buffer) +func (w *UDPWriter) Write(mb buf.MultiBuffer) error { + for _, b := range mb { + if err := w.writeInternal(b); err != nil { + return err + } + } + return nil +} + +func (w *UDPWriter) writeInternal(buffer *buf.Buffer) error { + payload, err := EncodeUDPPacket(w.Request, buffer) if err != nil { return err } - _, err = v.Writer.Write(payload.Bytes()) + _, err = w.Writer.Write(payload.Bytes()) payload.Release() return err } diff --git a/proxy/shadowsocks/protocol_test.go b/proxy/shadowsocks/protocol_test.go index 7d1ee6ff8..1326cd704 100644 --- a/proxy/shadowsocks/protocol_test.go +++ b/proxy/shadowsocks/protocol_test.go @@ -66,7 +66,7 @@ func TestTCPRequest(t *testing.T) { writer, err := WriteTCPRequest(request, cache) assert.Error(err).IsNil() - writer.Write(data) + writer.Write(buf.NewMultiBufferValue(data)) decodedRequest, reader, err := ReadTCPSession(request.User, cache) assert.Error(err).IsNil() @@ -75,7 +75,7 @@ func TestTCPRequest(t *testing.T) { decodedData, err := reader.Read() assert.Error(err).IsNil() - assert.String(decodedData.String()).Equals("test string") + assert.String(decodedData[0].String()).Equals("test string") } func TestUDPReaderWriter(t *testing.T) { @@ -106,19 +106,19 @@ func TestUDPReaderWriter(t *testing.T) { b := buf.New() b.AppendSupplier(serial.WriteString("test payload")) - err := writer.Write(b) + err := writer.Write(buf.NewMultiBufferValue(b)) assert.Error(err).IsNil() payload, err := reader.Read() assert.Error(err).IsNil() - assert.String(payload.String()).Equals("test payload") + assert.String(payload[0].String()).Equals("test payload") b = buf.New() b.AppendSupplier(serial.WriteString("test payload 2")) - err = writer.Write(b) + err = writer.Write(buf.NewMultiBufferValue(b)) assert.Error(err).IsNil() payload, err = reader.Read() assert.Error(err).IsNil() - assert.String(payload.String()).Equals("test payload 2") + assert.String(payload[0].String()).Equals("test payload 2") } diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 925917e45..a969f857f 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -75,52 +75,54 @@ func (v *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection reader := buf.NewReader(conn) for { - payload, err := reader.Read() + mpayload, err := reader.Read() if err != nil { break } - request, data, err := DecodeUDPPacket(v.user, payload) - if err != nil { - if source, ok := proxy.SourceFromContext(ctx); ok { - log.Trace(newError("dropping invalid UDP packet from: ", source).Base(err)) - log.Access(source, "", log.AccessRejected, err) - } - payload.Release() - continue - } - - if request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Disabled { - log.Trace(newError("client payload enables OTA but server doesn't allow it")) - payload.Release() - continue - } - - if !request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Enabled { - log.Trace(newError("client payload disables OTA but server forces it")) - payload.Release() - continue - } - - dest := request.Destination() - if source, ok := proxy.SourceFromContext(ctx); ok { - log.Access(source, dest, log.AccessAccepted, "") - } - log.Trace(newError("tunnelling request to ", dest)) - - ctx = protocol.ContextWithUser(ctx, request.User) - udpServer.Dispatch(ctx, dest, data, func(payload *buf.Buffer) { - defer payload.Release() - - data, err := EncodeUDPPacket(request, payload) + for _, payload := range mpayload { + request, data, err := DecodeUDPPacket(v.user, payload) if err != nil { - log.Trace(newError("failed to encode UDP packet").Base(err).AtWarning()) - return + if source, ok := proxy.SourceFromContext(ctx); ok { + log.Trace(newError("dropping invalid UDP packet from: ", source).Base(err)) + log.Access(source, "", log.AccessRejected, err) + } + payload.Release() + continue } - defer data.Release() - conn.Write(data.Bytes()) - }) + if request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Disabled { + log.Trace(newError("client payload enables OTA but server doesn't allow it")) + payload.Release() + continue + } + + if !request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Enabled { + log.Trace(newError("client payload disables OTA but server forces it")) + payload.Release() + continue + } + + dest := request.Destination() + if source, ok := proxy.SourceFromContext(ctx); ok { + log.Access(source, dest, log.AccessAccepted, "") + } + log.Trace(newError("tunnelling request to ", dest)) + + ctx = protocol.ContextWithUser(ctx, request.User) + udpServer.Dispatch(ctx, dest, data, func(payload *buf.Buffer) { + defer payload.Release() + + data, err := EncodeUDPPacket(request, payload) + if err != nil { + log.Trace(newError("failed to encode UDP packet").Base(err).AtWarning()) + return + } + defer data.Release() + + conn.Write(data.Bytes()) + }) + } } return nil diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index 37274a648..238c58033 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -347,7 +347,7 @@ func NewUDPReader(reader io.Reader) *UDPReader { return &UDPReader{reader: reader} } -func (r *UDPReader) Read() (*buf.Buffer, error) { +func (r *UDPReader) Read() (buf.MultiBuffer, error) { b := buf.NewSmall() if err := b.AppendSupplier(buf.ReadFrom(r.reader)); err != nil { return nil, err @@ -358,7 +358,9 @@ func (r *UDPReader) Read() (*buf.Buffer, error) { } b.Clear() b.Append(data) - return b, nil + mb := buf.NewMultiBuffer() + mb.Append(b) + return mb, nil } type UDPWriter struct { @@ -373,12 +375,15 @@ func NewUDPWriter(request *protocol.RequestHeader, writer io.Writer) *UDPWriter } } -func (w *UDPWriter) Write(b *buf.Buffer) error { - eb := EncodeUDPPacket(w.request, b.Bytes()) - b.Release() - defer eb.Release() - if _, err := w.writer.Write(eb.Bytes()); err != nil { - return err +func (w *UDPWriter) Write(mb buf.MultiBuffer) error { + defer mb.Release() + + for _, b := range mb { + eb := EncodeUDPPacket(w.request, b.Bytes()) + defer eb.Release() + if _, err := w.writer.Write(eb.Bytes()); err != nil { + return err + } } return nil } diff --git a/proxy/socks/protocol_test.go b/proxy/socks/protocol_test.go index d71a9b3b0..7e2575d3f 100644 --- a/proxy/socks/protocol_test.go +++ b/proxy/socks/protocol_test.go @@ -24,11 +24,11 @@ func TestUDPEncoding(t *testing.T) { content := []byte{'a'} payload := buf.New() payload.Append(content) - assert.Error(writer.Write(payload)).IsNil() + assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil() reader := NewUDPReader(b) decodedPayload, err := reader.Read() assert.Error(err).IsNil() - assert.Bytes(decodedPayload.Bytes()).Equals(content) + assert.Bytes(decodedPayload[0].Bytes()).Equals(content) } diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 910666788..901c9f8f2 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -159,38 +159,41 @@ func (v *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, reader := buf.NewReader(conn) for { - payload, err := reader.Read() + mpayload, err := reader.Read() if err != nil { return err } - request, data, err := DecodeUDPPacket(payload.Bytes()) - if err != nil { - log.Trace(newError("failed to parse UDP request").Base(err)) - continue + for _, payload := range mpayload { + request, data, err := DecodeUDPPacket(payload.Bytes()) + + if err != nil { + log.Trace(newError("failed to parse UDP request").Base(err)) + continue + } + + if len(data) == 0 { + continue + } + + log.Trace(newError("send packet to ", request.Destination(), " with ", len(data), " bytes").AtDebug()) + if source, ok := proxy.SourceFromContext(ctx); ok { + log.Access(source, request.Destination, log.AccessAccepted, "") + } + + dataBuf := buf.NewSmall() + dataBuf.Append(data) + udpServer.Dispatch(ctx, request.Destination(), dataBuf, func(payload *buf.Buffer) { + defer payload.Release() + + log.Trace(newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug()) + + udpMessage := EncodeUDPPacket(request, payload.Bytes()) + defer udpMessage.Release() + + conn.Write(udpMessage.Bytes()) + }) } - - if len(data) == 0 { - continue - } - - log.Trace(newError("send packet to ", request.Destination(), " with ", len(data), " bytes").AtDebug()) - if source, ok := proxy.SourceFromContext(ctx); ok { - log.Access(source, request.Destination, log.AccessAccepted, "") - } - - dataBuf := buf.NewSmall() - dataBuf.Append(data) - udpServer.Dispatch(ctx, request.Destination(), dataBuf, func(payload *buf.Buffer) { - defer payload.Release() - - log.Trace(newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug()) - - udpMessage := EncodeUDPPacket(request, payload.Bytes()) - defer udpMessage.Release() - - conn.Write(udpMessage.Bytes()) - }) } } diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index b4864dc7b..d6137364f 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -166,7 +166,7 @@ func transferResponse(timer signal.ActivityTimer, session *encoding.ServerSessio } if request.Option.Has(protocol.RequestOptionChunkStream) { - if err := bodyWriter.Write(buf.NewLocal(8)); err != nil { + if err := bodyWriter.Write(buf.NewMultiBuffer()); err != nil { return err } } diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index d3c286ed9..13ac94bca 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -133,7 +133,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial } if request.Option.Has(protocol.RequestOptionChunkStream) { - if err := bodyWriter.Write(buf.NewLocal(8)); err != nil { + if err := bodyWriter.Write(buf.NewMultiBuffer()); err != nil { return err } } diff --git a/transport/internet/udp/dispatcher.go b/transport/internet/udp/dispatcher.go index 678f92cc2..01a71b455 100644 --- a/transport/internet/udp/dispatcher.go +++ b/transport/internet/udp/dispatcher.go @@ -57,7 +57,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination v2net.Destination inboundRay, existing := v.getInboundRay(ctx, destination) outputStream := inboundRay.InboundInput() if outputStream != nil { - if err := outputStream.Write(payload); err != nil { + if err := outputStream.Write(buf.NewMultiBufferValue(payload)); err != nil { v.RemoveRay(destination) } } @@ -71,10 +71,12 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination v2net.Destination func handleInput(input ray.InputStream, callback ResponseCallback) { for { - data, err := input.Read() + mb, err := input.Read() if err != nil { break } - callback(data) + for _, b := range mb { + callback(b) + } } } diff --git a/transport/ray/direct.go b/transport/ray/direct.go index c069bc5b0..9dd0a6253 100644 --- a/transport/ray/direct.go +++ b/transport/ray/direct.go @@ -42,7 +42,7 @@ func (v *directRay) InboundOutput() InputStream { } type Stream struct { - buffer chan *buf.Buffer + buffer chan buf.MultiBuffer ctx context.Context close chan bool err chan bool @@ -51,13 +51,13 @@ type Stream struct { func NewStream(ctx context.Context) *Stream { return &Stream{ ctx: ctx, - buffer: make(chan *buf.Buffer, bufferSize), + buffer: make(chan buf.MultiBuffer, bufferSize), close: make(chan bool), err: make(chan bool), } } -func (v *Stream) Read() (*buf.Buffer, error) { +func (v *Stream) Read() (buf.MultiBuffer, error) { select { case <-v.ctx.Done(): return nil, io.ErrClosedPipe @@ -79,7 +79,7 @@ func (v *Stream) Read() (*buf.Buffer, error) { } } -func (v *Stream) ReadTimeout(timeout time.Duration) (*buf.Buffer, error) { +func (v *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) { select { case <-v.ctx.Done(): return nil, io.ErrClosedPipe @@ -107,7 +107,7 @@ func (v *Stream) ReadTimeout(timeout time.Duration) (*buf.Buffer, error) { } } -func (v *Stream) Write(data *buf.Buffer) (err error) { +func (v *Stream) Write(data buf.MultiBuffer) (err error) { if data.IsEmpty() { return } diff --git a/transport/ray/direct_test.go b/transport/ray/direct_test.go index 4f1fd8679..20d0d443f 100644 --- a/transport/ray/direct_test.go +++ b/transport/ray/direct_test.go @@ -16,7 +16,7 @@ func TestStreamIO(t *testing.T) { stream := NewStream(context.Background()) b1 := buf.New() b1.AppendBytes('a') - assert.Error(stream.Write(b1)).IsNil() + assert.Error(stream.Write(buf.NewMultiBufferValue(b1))).IsNil() _, err := stream.Read() assert.Error(err).IsNil() @@ -27,7 +27,7 @@ func TestStreamIO(t *testing.T) { b2 := buf.New() b2.AppendBytes('b') - err = stream.Write(b2) + err = stream.Write(buf.NewMultiBufferValue(b2)) assert.Error(err).Equals(io.ErrClosedPipe) } @@ -37,7 +37,7 @@ func TestStreamClose(t *testing.T) { stream := NewStream(context.Background()) b1 := buf.New() b1.AppendBytes('a') - assert.Error(stream.Write(b1)).IsNil() + assert.Error(stream.Write(buf.NewMultiBufferValue(b1))).IsNil() stream.Close()