diff --git a/app/proxyman/mux/mux.go b/app/proxyman/mux/mux.go index 17007b241..5470a2937 100644 --- a/app/proxyman/mux/mux.go +++ b/app/proxyman/mux/mux.go @@ -161,7 +161,7 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) { log.Trace(newError("dispatching request to ", dest)) data, _ := s.input.ReadTimeout(time.Millisecond * 500) - if err := writer.Write(data); err != nil { + if err := writer.WriteMultiBuffer(data); err != nil { log.Trace(newError("failed to write first payload").Base(err)) return } @@ -234,7 +234,7 @@ func (m *Client) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error { func (m *Client) fetchOutput() { defer m.cancel() - reader := buf.ToBytesReader(m.inboundRay.InboundOutput()) + reader := buf.NewBufferedReader(m.inboundRay.InboundOutput()) for { meta, err := ReadMetadata(reader) @@ -396,7 +396,7 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader io.Reader) error func (w *ServerWorker) run(ctx context.Context) { input := w.outboundRay.OutboundInput() - reader := buf.ToBytesReader(input) + reader := buf.NewBufferedReader(input) defer w.sessionManager.Close() diff --git a/app/proxyman/mux/mux_test.go b/app/proxyman/mux/mux_test.go index eacdd7d2d..bb6425c60 100644 --- a/app/proxyman/mux/mux_test.go +++ b/app/proxyman/mux/mux_test.go @@ -16,7 +16,7 @@ import ( func readAll(reader buf.Reader) (buf.MultiBuffer, error) { var mb buf.MultiBuffer for { - b, err := reader.Read() + b, err := reader.ReadMultiBuffer() if err == io.EOF { break } @@ -45,7 +45,7 @@ func TestReaderWriter(t *testing.T) { writePayload := func(writer *Writer, payload ...byte) error { b := buf.New() b.Append(payload) - return writer.Write(buf.NewMultiBufferValue(b)) + return writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)) } assert(writePayload(writer, 'a', 'b', 'c', 'd'), IsNil) @@ -60,7 +60,7 @@ func TestReaderWriter(t *testing.T) { assert(writePayload(writer2, 'y'), IsNil) writer2.Close() - bytesReader := buf.ToBytesReader(stream) + bytesReader := buf.NewBufferedReader(stream) streamReader := NewStreamReader(bytesReader) meta, err := ReadMetadata(bytesReader) diff --git a/app/proxyman/mux/reader.go b/app/proxyman/mux/reader.go index acdfe75b0..80192dc9d 100644 --- a/app/proxyman/mux/reader.go +++ b/app/proxyman/mux/reader.go @@ -40,8 +40,8 @@ func NewPacketReader(reader io.Reader) *PacketReader { } } -// Read implements buf.Reader. -func (r *PacketReader) Read() (buf.MultiBuffer, error) { +// ReadMultiBuffer implements buf.Reader. +func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { if r.eof { return nil, io.EOF } @@ -79,8 +79,8 @@ func NewStreamReader(reader io.Reader) *StreamReader { } } -// Read implmenets buf.Reader. -func (r *StreamReader) Read() (buf.MultiBuffer, error) { +// ReadMultiBuffer implmenets buf.Reader. +func (r *StreamReader) ReadMultiBuffer() (buf.MultiBuffer, error) { if r.leftOver == 0 { r.leftOver = -1 return nil, io.EOF diff --git a/app/proxyman/mux/writer.go b/app/proxyman/mux/writer.go index 39a3158c3..64c290bdf 100644 --- a/app/proxyman/mux/writer.go +++ b/app/proxyman/mux/writer.go @@ -56,7 +56,7 @@ func (w *Writer) writeMetaOnly() error { if err := b.Reset(meta.AsSupplier()); err != nil { return err } - return w.writer.Write(buf.NewMultiBufferValue(b)) + return w.writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)) } func (w *Writer) writeData(mb buf.MultiBuffer) error { @@ -74,11 +74,11 @@ func (w *Writer) writeData(mb buf.MultiBuffer) error { mb2 := buf.NewMultiBufferCap(len(mb) + 1) mb2.Append(frame) mb2.AppendMulti(mb) - return w.writer.Write(mb2) + return w.writer.WriteMultiBuffer(mb2) } -// Write implements buf.MultiBufferWriter. -func (w *Writer) Write(mb buf.MultiBuffer) error { +// WriteMultiBuffer implements buf.Writer. +func (w *Writer) WriteMultiBuffer(mb buf.MultiBuffer) error { defer mb.Release() if mb.IsEmpty() { @@ -109,5 +109,5 @@ func (w *Writer) Close() { frame := buf.New() common.Must(frame.Reset(meta.AsSupplier())) - w.writer.Write(buf.NewMultiBufferValue(frame)) + w.writer.WriteMultiBuffer(buf.NewMultiBufferValue(frame)) } diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index 8c6d4bbd3..03027bdb7 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -123,8 +123,8 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Conn } var ( - _ buf.MultiBufferReader = (*Connection)(nil) - _ buf.MultiBufferWriter = (*Connection)(nil) + _ buf.Reader = (*Connection)(nil) + _ buf.Writer = (*Connection)(nil) ) type Connection struct { @@ -133,9 +133,8 @@ type Connection struct { localAddr net.Addr remoteAddr net.Addr - bytesReader io.Reader - reader buf.Reader - writer buf.Writer + reader *buf.BufferedReader + writer buf.Writer } func NewConnection(stream ray.Ray) *Connection { @@ -149,9 +148,8 @@ func NewConnection(stream ray.Ray) *Connection { IP: []byte{0, 0, 0, 0}, Port: 0, }, - bytesReader: buf.ToBytesReader(stream.InboundOutput()), - reader: stream.InboundOutput(), - writer: stream.InboundInput(), + reader: buf.NewBufferedReader(stream.InboundOutput()), + writer: stream.InboundInput(), } } @@ -160,11 +158,11 @@ func (v *Connection) Read(b []byte) (int, error) { if v.closed { return 0, io.EOF } - return v.bytesReader.Read(b) + return v.reader.Read(b) } func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) { - return v.reader.Read() + return v.reader.ReadMultiBuffer() } // Write implements net.Conn.Write(). @@ -172,14 +170,19 @@ func (v *Connection) Write(b []byte) (int, error) { if v.closed { return 0, io.ErrClosedPipe } - return buf.ToBytesWriter(v.writer).Write(b) + + l := len(b) + mb := buf.NewMultiBufferCap(l/buf.Size + 1) + mb.Write(b) + return l, v.writer.WriteMultiBuffer(mb) } func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { if v.closed { return io.ErrClosedPipe } - return v.writer.Write(mb) + + return v.writer.WriteMultiBuffer(mb) } // Close implements net.Conn.Close(). diff --git a/common/buf/buffered_reader.go b/common/buf/buffered_reader.go deleted file mode 100644 index 66ef5c065..000000000 --- a/common/buf/buffered_reader.go +++ /dev/null @@ -1,53 +0,0 @@ -package buf - -import ( - "io" -) - -// BufferedReader is a reader with internal cache. -type BufferedReader struct { - reader io.Reader - buffer *Buffer - buffered bool -} - -// NewBufferedReader creates a new BufferedReader based on an io.Reader. -func NewBufferedReader(rawReader io.Reader) *BufferedReader { - return &BufferedReader{ - reader: rawReader, - buffer: NewLocal(1024), - buffered: true, - } -} - -// IsBuffered returns true if the internal cache is effective. -func (r *BufferedReader) IsBuffered() bool { - return r.buffered -} - -// SetBuffered is to enable or disable internal cache. If cache is disabled, -// Read() calls will be delegated to the underlying io.Reader directly. -func (r *BufferedReader) SetBuffered(cached bool) { - r.buffered = cached -} - -// Read implements io.Reader.Read(). -func (r *BufferedReader) Read(b []byte) (int, error) { - if !r.buffered || r.buffer == nil { - if !r.buffer.IsEmpty() { - return r.buffer.Read(b) - } - return r.reader.Read(b) - } - if r.buffer.IsEmpty() { - if err := r.buffer.Reset(ReadFrom(r.reader)); err != nil { - return 0, err - } - } - - if r.buffer.IsEmpty() { - return 0, nil - } - - return r.buffer.Read(b) -} diff --git a/common/buf/buffered_reader_test.go b/common/buf/buffered_reader_test.go deleted file mode 100644 index d646ee034..000000000 --- a/common/buf/buffered_reader_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package buf_test - -import ( - "crypto/rand" - "testing" - - . "v2ray.com/core/common/buf" - . "v2ray.com/ext/assert" -) - -func TestBufferedReader(t *testing.T) { - assert := With(t) - - content := New() - assert(content.AppendSupplier(ReadFrom(rand.Reader)), IsNil) - - len := content.Len() - - reader := NewBufferedReader(content) - assert(reader.IsBuffered(), IsTrue) - - payload := make([]byte, 16) - - nBytes, err := reader.Read(payload) - assert(nBytes, Equals, 16) - assert(err, IsNil) - - len2 := content.Len() - assert(len-len2, GreaterThan, 16) - - nBytes, err = reader.Read(payload) - assert(nBytes, Equals, 16) - assert(err, IsNil) - - assert(content.Len(), Equals, len2) -} diff --git a/common/buf/buffered_writer.go b/common/buf/buffered_writer.go deleted file mode 100644 index b3ff77a37..000000000 --- a/common/buf/buffered_writer.go +++ /dev/null @@ -1,73 +0,0 @@ -package buf - -import "io" - -// BufferedWriter is an io.Writer with internal buffer. It writes to underlying writer when buffer is full or on demand. -// This type is not thread safe. -type BufferedWriter struct { - writer io.Writer - buffer *Buffer - buffered bool -} - -// NewBufferedWriter creates a new BufferedWriter. -func NewBufferedWriter(writer io.Writer) *BufferedWriter { - return NewBufferedWriterSize(writer, 1024) -} - -// NewBufferedWriterSize creates a BufferedWriter with specified buffer size. -func NewBufferedWriterSize(writer io.Writer, size uint32) *BufferedWriter { - return &BufferedWriter{ - writer: writer, - buffer: NewLocal(int(size)), - buffered: true, - } -} - -// Write implements io.Writer. -func (w *BufferedWriter) Write(b []byte) (int, error) { - if !w.buffered || w.buffer == nil { - return w.writer.Write(b) - } - bytesWritten := 0 - for bytesWritten < len(b) { - nBytes, err := w.buffer.Write(b[bytesWritten:]) - if err != nil { - return bytesWritten, err - } - bytesWritten += nBytes - if w.buffer.IsFull() { - if err := w.Flush(); err != nil { - return bytesWritten, err - } - } - } - return bytesWritten, nil -} - -// Flush writes all buffered content into underlying writer, if any. -func (w *BufferedWriter) Flush() error { - defer w.buffer.Clear() - for !w.buffer.IsEmpty() { - nBytes, err := w.writer.Write(w.buffer.Bytes()) - if err != nil { - return err - } - w.buffer.SliceFrom(nBytes) - } - return nil -} - -// IsBuffered returns true if this BufferedWriter holds a buffer. -func (w *BufferedWriter) IsBuffered() bool { - return w.buffered -} - -// SetBuffered controls whether the BufferedWriter holds a buffer for writing. If not buffered, any write() calls into underlying writer directly. -func (w *BufferedWriter) SetBuffered(cached bool) error { - w.buffered = cached - if !cached && !w.buffer.IsEmpty() { - return w.Flush() - } - return nil -} diff --git a/common/buf/buffered_writer_test.go b/common/buf/buffered_writer_test.go deleted file mode 100644 index 9368633e4..000000000 --- a/common/buf/buffered_writer_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package buf_test - -import ( - "crypto/rand" - "testing" - - "v2ray.com/core/common" - . "v2ray.com/core/common/buf" - . "v2ray.com/ext/assert" -) - -func TestBufferedWriter(t *testing.T) { - assert := With(t) - - content := New() - - writer := NewBufferedWriter(content) - assert(writer.IsBuffered(), IsTrue) - - payload := make([]byte, 16) - - nBytes, err := writer.Write(payload) - assert(nBytes, Equals, 16) - assert(err, IsNil) - - assert(content.IsEmpty(), IsTrue) - - assert(writer.SetBuffered(false), IsNil) - assert(content.Len(), Equals, 16) -} - -func TestBufferedWriterLargePayload(t *testing.T) { - assert := With(t) - - content := NewLocal(128 * 1024) - - writer := NewBufferedWriter(content) - assert(writer.IsBuffered(), IsTrue) - - payload := make([]byte, 64*1024) - common.Must2(rand.Read(payload)) - - nBytes, err := writer.Write(payload[:512]) - assert(nBytes, Equals, 512) - assert(err, IsNil) - - assert(content.IsEmpty(), IsTrue) - - nBytes, err = writer.Write(payload[512:]) - assert(err, IsNil) - assert(writer.Flush(), IsNil) - assert(nBytes, Equals, 64*1024-512) - assert(content.Bytes(), Equals, payload) -} diff --git a/common/buf/copy.go b/common/buf/copy.go index 873e187aa..f8d4577e3 100644 --- a/common/buf/copy.go +++ b/common/buf/copy.go @@ -17,7 +17,7 @@ type copyHandler struct { } func (h *copyHandler) readFrom(reader Reader) (MultiBuffer, error) { - mb, err := reader.Read() + mb, err := reader.ReadMultiBuffer() if err != nil { for _, handler := range h.onReadError { err = handler(err) @@ -27,7 +27,7 @@ func (h *copyHandler) readFrom(reader Reader) (MultiBuffer, error) { } func (h *copyHandler) writeTo(writer Writer, mb MultiBuffer) error { - err := writer.Write(mb) + err := writer.WriteMultiBuffer(mb) if err != nil { for _, handler := range h.onWriteError { err = handler(err) @@ -36,6 +36,10 @@ func (h *copyHandler) writeTo(writer Writer, mb MultiBuffer) error { return err } +type SizeCounter struct { + Size int64 +} + type CopyOption func(*copyHandler) func IgnoreReaderError() CopyOption { @@ -62,6 +66,14 @@ func UpdateActivity(timer signal.ActivityUpdater) CopyOption { } } +func CountSize(sc *SizeCounter) CopyOption { + return func(handler *copyHandler) { + handler.onData = append(handler.onData, func(b MultiBuffer) { + sc.Size += int64(b.Len()) + }) + } +} + func copyInternal(reader Reader, writer Writer, handler *copyHandler) error { for { buffer, err := handler.readFrom(reader) diff --git a/common/buf/io.go b/common/buf/io.go index a7abeff1d..17debb2b5 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -5,10 +5,10 @@ import ( "time" ) -// Reader extends io.Reader with alloc.Buffer. +// Reader extends io.Reader with MultiBuffer. type Reader interface { - // Read reads content from underlying reader, and put it into an alloc.Buffer. - Read() (MultiBuffer, error) + // ReadMultiBuffer reads content from underlying reader, and put it into a MultiBuffer. + ReadMultiBuffer() (MultiBuffer, error) } // ErrReadTimeout is an error that happens with IO timeout. @@ -19,10 +19,10 @@ type TimeoutReader interface { ReadTimeout(time.Duration) (MultiBuffer, error) } -// Writer extends io.Writer with alloc.Buffer. +// Writer extends io.Writer with MultiBuffer. type Writer interface { - // Write writes an alloc.Buffer into underlying writer. - Write(MultiBuffer) error + // WriteMultiBuffer writes a MultiBuffer into underlying writer. + WriteMultiBuffer(MultiBuffer) error } // ReadFrom creates a Supplier to read from a given io.Reader. @@ -49,45 +49,21 @@ func ReadAtLeastFrom(reader io.Reader, size int) Supplier { // NewReader creates a new Reader. // The Reader instance doesn't take the ownership of reader. func NewReader(reader io.Reader) Reader { - if mr, ok := reader.(MultiBufferReader); ok { - return &readerAdpater{ - MultiBufferReader: mr, - } + if mr, ok := reader.(Reader); ok { + return mr } - return &BytesToBufferReader{ - reader: reader, - } -} - -// ToBytesReader converts a Reaaer to io.Reader. -func ToBytesReader(stream Reader) io.Reader { - return &bufferToBytesReader{ - stream: stream, - } + return NewBytesToBufferReader(reader) } // NewWriter creates a new Writer. func NewWriter(writer io.Writer) Writer { - if mw, ok := writer.(MultiBufferWriter); ok { - return &writerAdapter{ - writer: mw, - } + if mw, ok := writer.(Writer); ok { + return mw } return &BufferToBytesWriter{ - writer: writer, - } -} - -func NewMergingWriter(writer io.Writer) Writer { - return NewMergingWriterSize(writer, 4096) -} - -func NewMergingWriterSize(writer io.Writer, size uint32) Writer { - return &mergingWriter{ - writer: writer, - buffer: make([]byte, size), + Writer: writer, } } @@ -96,10 +72,3 @@ func NewSequentialWriter(writer io.Writer) Writer { writer: writer, } } - -// ToBytesWriter converts a Writer to io.Writer -func ToBytesWriter(writer Writer) io.Writer { - return &bytesToBufferWriter{ - writer: writer, - } -} diff --git a/common/buf/multi_buffer.go b/common/buf/multi_buffer.go index 724466307..b2814e5f8 100644 --- a/common/buf/multi_buffer.go +++ b/common/buf/multi_buffer.go @@ -8,16 +8,6 @@ import ( "v2ray.com/core/common/errors" ) -// MultiBufferWriter is a writer that writes MultiBuffer. -type MultiBufferWriter interface { - WriteMultiBuffer(MultiBuffer) error -} - -// MultiBufferReader is a reader that reader payload as MultiBuffer. -type MultiBufferReader interface { - ReadMultiBuffer() (MultiBuffer, error) -} - // ReadAllToMultiBuffer reads all content from the reader into a MultiBuffer, until EOF. func ReadAllToMultiBuffer(reader io.Reader) (MultiBuffer, error) { mb := NewMultiBufferCap(128) diff --git a/common/buf/reader.go b/common/buf/reader.go index 6d52ea0a8..8a327b3ee 100644 --- a/common/buf/reader.go +++ b/common/buf/reader.go @@ -8,19 +8,19 @@ import ( // BytesToBufferReader is a Reader that adjusts its reading speed automatically. type BytesToBufferReader struct { - reader io.Reader + io.Reader buffer []byte } func NewBytesToBufferReader(reader io.Reader) Reader { return &BytesToBufferReader{ - reader: reader, + Reader: reader, } } func (r *BytesToBufferReader) readSmall() (MultiBuffer, error) { b := New() - if err := b.Reset(ReadFrom(r.reader)); err != nil { + if err := b.Reset(ReadFrom(r.Reader)); err != nil { b.Release() return nil, err } @@ -30,13 +30,13 @@ func (r *BytesToBufferReader) readSmall() (MultiBuffer, error) { return NewMultiBufferValue(b), nil } -// Read implements Reader.Read(). -func (r *BytesToBufferReader) Read() (MultiBuffer, error) { +// ReadMultiBuffer implements Reader. +func (r *BytesToBufferReader) ReadMultiBuffer() (MultiBuffer, error) { if r.buffer == nil { return r.readSmall() } - nBytes, err := r.reader.Read(r.buffer) + nBytes, err := r.Reader.Read(r.buffer) if err != nil { return nil, err } @@ -46,20 +46,33 @@ func (r *BytesToBufferReader) Read() (MultiBuffer, error) { return mb, nil } -type readerAdpater struct { - MultiBufferReader +type BufferedReader struct { + stream Reader + legacyReader io.Reader + leftOver MultiBuffer + buffered bool } -func (r *readerAdpater) Read() (MultiBuffer, error) { - return r.ReadMultiBuffer() +func NewBufferedReader(reader Reader) *BufferedReader { + r := &BufferedReader{ + stream: reader, + buffered: true, + } + if lr, ok := reader.(io.Reader); ok { + r.legacyReader = lr + } + return r } -type bufferToBytesReader struct { - stream Reader - leftOver MultiBuffer +func (r *BufferedReader) SetBuffered(f bool) { + r.buffered = f } -func (r *bufferToBytesReader) Read(b []byte) (int, error) { +func (r *BufferedReader) IsBuffered() bool { + return r.buffered +} + +func (r *BufferedReader) Read(b []byte) (int, error) { if r.leftOver != nil { nBytes, _ := r.leftOver.Read(b) if r.leftOver.IsEmpty() { @@ -69,7 +82,11 @@ func (r *bufferToBytesReader) Read(b []byte) (int, error) { return nBytes, nil } - mb, err := r.stream.Read() + if !r.buffered && r.legacyReader != nil { + return r.legacyReader.Read(b) + } + + mb, err := r.stream.ReadMultiBuffer() if err != nil { return 0, err } @@ -81,39 +98,39 @@ func (r *bufferToBytesReader) Read(b []byte) (int, error) { return nBytes, nil } -func (r *bufferToBytesReader) ReadMultiBuffer() (MultiBuffer, error) { +func (r *BufferedReader) ReadMultiBuffer() (MultiBuffer, error) { if r.leftOver != nil { mb := r.leftOver r.leftOver = nil return mb, nil } - return r.stream.Read() + return r.stream.ReadMultiBuffer() } -func (r *bufferToBytesReader) writeToInternal(writer io.Writer) (int64, error) { +func (r *BufferedReader) writeToInternal(writer io.Writer) (int64, error) { mbWriter := NewWriter(writer) totalBytes := int64(0) if r.leftOver != nil { totalBytes += int64(r.leftOver.Len()) - if err := mbWriter.Write(r.leftOver); err != nil { + if err := mbWriter.WriteMultiBuffer(r.leftOver); err != nil { return 0, err } } for { - mb, err := r.stream.Read() + mb, err := r.stream.ReadMultiBuffer() if err != nil { return totalBytes, err } totalBytes += int64(mb.Len()) - if err := mbWriter.Write(mb); err != nil { + if err := mbWriter.WriteMultiBuffer(mb); err != nil { return totalBytes, err } } } -func (r *bufferToBytesReader) WriteTo(writer io.Writer) (int64, error) { +func (r *BufferedReader) WriteTo(writer io.Writer) (int64, error) { nBytes, err := r.writeToInternal(writer) if errors.Cause(err) == io.EOF { return nBytes, nil diff --git a/common/buf/reader_test.go b/common/buf/reader_test.go index f33ed4ced..35c144323 100644 --- a/common/buf/reader_test.go +++ b/common/buf/reader_test.go @@ -15,11 +15,11 @@ func TestAdaptiveReader(t *testing.T) { assert := With(t) reader := NewReader(bytes.NewReader(make([]byte, 1024*1024))) - b, err := reader.Read() + b, err := reader.ReadMultiBuffer() assert(err, IsNil) assert(b.Len(), Equals, 2*1024) - b, err = reader.Read() + b, err = reader.ReadMultiBuffer() assert(err, IsNil) assert(b.Len(), Equals, 32*1024) } @@ -28,22 +28,23 @@ func TestBytesReaderWriteTo(t *testing.T) { assert := With(t) stream := ray.NewStream(context.Background()) - reader := ToBytesReader(stream) + reader := NewBufferedReader(stream) b1 := New() b1.AppendBytes('a', 'b', 'c') b2 := New() b2.AppendBytes('e', 'f', 'g') - assert(stream.Write(NewMultiBufferValue(b1, b2)), IsNil) + assert(stream.WriteMultiBuffer(NewMultiBufferValue(b1, b2)), IsNil) stream.Close() stream2 := ray.NewStream(context.Background()) - writer := ToBytesWriter(stream2) + writer := NewBufferedWriter(stream2) + writer.SetBuffered(false) nBytes, err := io.Copy(writer, reader) assert(err, IsNil) assert(nBytes, Equals, int64(6)) - mb, err := stream2.Read() + mb, err := stream2.ReadMultiBuffer() assert(err, IsNil) assert(len(mb), Equals, 2) assert(mb[0].String(), Equals, "abc") @@ -54,16 +55,16 @@ func TestBytesReaderMultiBuffer(t *testing.T) { assert := With(t) stream := ray.NewStream(context.Background()) - reader := ToBytesReader(stream) + reader := NewBufferedReader(stream) b1 := New() b1.AppendBytes('a', 'b', 'c') b2 := New() b2.AppendBytes('e', 'f', 'g') - assert(stream.Write(NewMultiBufferValue(b1, b2)), IsNil) + assert(stream.WriteMultiBuffer(NewMultiBufferValue(b1, b2)), IsNil) stream.Close() mbReader := NewReader(reader) - mb, err := mbReader.Read() + mb, err := mbReader.ReadMultiBuffer() assert(err, IsNil) assert(len(mb), Equals, 2) assert(mb[0].String(), Equals, "abc") diff --git a/common/buf/writer.go b/common/buf/writer.go index 32122a153..4825db6f5 100644 --- a/common/buf/writer.go +++ b/common/buf/writer.go @@ -8,49 +8,142 @@ import ( // BufferToBytesWriter is a Writer that writes alloc.Buffer into underlying writer. type BufferToBytesWriter struct { - writer io.Writer + io.Writer } -// Write implements Writer.Write(). Write() takes ownership of the given buffer. -func (w *BufferToBytesWriter) Write(mb MultiBuffer) error { +func NewBufferToBytesWriter(writer io.Writer) *BufferToBytesWriter { + return &BufferToBytesWriter{ + Writer: writer, + } +} + +// WriteMultiBuffer implements Writer. This method takes ownership of the given buffer. +func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error { defer mb.Release() bs := mb.ToNetBuffers() - _, err := bs.WriteTo(w.writer) + _, err := bs.WriteTo(w) return err } -type writerAdapter struct { - writer MultiBufferWriter +func (w *BufferToBytesWriter) ReadFrom(reader io.Reader) (int64, error) { + if readerFrom, ok := w.Writer.(io.ReaderFrom); ok { + return readerFrom.ReadFrom(reader) + } + + var sc SizeCounter + err := Copy(NewReader(reader), w, CountSize(&sc)) + return sc.Size, err } -// Write implements buf.MultiBufferWriter. -func (w *writerAdapter) Write(mb MultiBuffer) error { - return w.writer.WriteMultiBuffer(mb) +type BufferedWriter struct { + writer Writer + legacyWriter io.Writer + buffer *Buffer + buffered bool } -type mergingWriter struct { - writer io.Writer - buffer []byte +func NewBufferedWriter(writer Writer) *BufferedWriter { + w := &BufferedWriter{ + writer: writer, + buffer: New(), + buffered: true, + } + if lw, ok := writer.(io.Writer); ok { + w.legacyWriter = lw + } + return w } -func (w *mergingWriter) Write(mb MultiBuffer) error { - defer mb.Release() +func (w *BufferedWriter) Write(b []byte) (int, error) { + if !w.buffered && w.legacyWriter != nil { + return w.legacyWriter.Write(b) + } - for !mb.IsEmpty() { - nBytes, _ := mb.Read(w.buffer) - if _, err := w.writer.Write(w.buffer[:nBytes]); err != nil { + totalBytes := 0 + for len(b) > 0 { + nBytes, err := w.buffer.Write(b) + totalBytes += nBytes + if err != nil { + return totalBytes, err + } + if w.buffer.IsFull() { + if err := w.Flush(); err != nil { + return totalBytes, err + } + } + b = b[nBytes:] + } + return totalBytes, nil +} + +func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error { + if !w.buffered { + return w.writer.WriteMultiBuffer(b) + } + + defer b.Release() + + for !b.IsEmpty() { + if err := w.buffer.AppendSupplier(ReadFrom(&b)); err != nil { return err } + if w.buffer.IsFull() { + if err := w.Flush(); err != nil { + return err + } + } + } + + return nil +} + +func (w *BufferedWriter) Flush() error { + if !w.buffer.IsEmpty() { + if err := w.writer.WriteMultiBuffer(NewMultiBufferValue(w.buffer)); err != nil { + return err + } + + if w.buffered { + w.buffer = New() + } else { + w.buffer = nil + } } return nil } +func (w *BufferedWriter) SetBuffered(f bool) error { + w.buffered = f + if !f { + return w.Flush() + } + return nil +} + +func (w *BufferedWriter) ReadFrom(reader io.Reader) (int64, error) { + var sc SizeCounter + if !w.buffer.IsEmpty() { + sc.Size += int64(w.buffer.Len()) + if err := w.Flush(); err != nil { + return sc.Size, err + } + } + + if readerFrom, ok := w.writer.(io.ReaderFrom); ok { + return readerFrom.ReadFrom(reader) + } + + w.buffered = false + err := Copy(NewReader(reader), w, CountSize(&sc)) + return sc.Size, err +} + type seqWriter struct { writer io.Writer } -func (w *seqWriter) Write(mb MultiBuffer) error { +func (w *seqWriter) WriteMultiBuffer(mb MultiBuffer) error { defer mb.Release() for _, b := range mb { @@ -65,49 +158,9 @@ func (w *seqWriter) Write(mb MultiBuffer) error { return nil } -var ( - _ MultiBufferWriter = (*bytesToBufferWriter)(nil) -) - -type bytesToBufferWriter struct { - writer Writer -} - -// Write implements io.Writer. -func (w *bytesToBufferWriter) Write(payload []byte) (int, error) { - mb := NewMultiBufferCap(len(payload)/Size + 1) - mb.Write(payload) - if err := w.writer.Write(mb); err != nil { - return 0, err - } - return len(payload), nil -} - -func (w *bytesToBufferWriter) WriteMultiBuffer(mb MultiBuffer) error { - return w.writer.Write(mb) -} - -func (w *bytesToBufferWriter) ReadFrom(reader io.Reader) (int64, error) { - mbReader := NewReader(reader) - totalBytes := int64(0) - for { - mb, err := mbReader.Read() - if errors.Cause(err) == io.EOF { - break - } else if err != nil { - return totalBytes, err - } - totalBytes += int64(mb.Len()) - if err := w.writer.Write(mb); err != nil { - return totalBytes, err - } - } - return totalBytes, nil -} - type noOpWriter struct{} -func (noOpWriter) Write(b MultiBuffer) error { +func (noOpWriter) WriteMultiBuffer(b MultiBuffer) error { b.Release() return nil } diff --git a/common/buf/writer_test.go b/common/buf/writer_test.go index 0dfef9b03..3f64411dd 100644 --- a/common/buf/writer_test.go +++ b/common/buf/writer_test.go @@ -25,9 +25,11 @@ func TestWriter(t *testing.T) { writeBuffer := bytes.NewBuffer(make([]byte, 0, 1024*1024)) - writer := NewWriter(NewBufferedWriter(writeBuffer)) - err := writer.Write(NewMultiBufferValue(lb)) + writer := NewBufferedWriter(NewWriter(writeBuffer)) + writer.SetBuffered(false) + err := writer.WriteMultiBuffer(NewMultiBufferValue(lb)) assert(err, IsNil) + assert(writer.Flush(), IsNil) assert(expectedBytes, Equals, writeBuffer.Bytes()) } @@ -36,20 +38,21 @@ func TestBytesWriterReadFrom(t *testing.T) { cache := ray.NewStream(context.Background()) reader := bufio.NewReader(io.LimitReader(rand.Reader, 8192)) - _, err := reader.WriteTo(ToBytesWriter(cache)) + writer := NewBufferedWriter(cache) + writer.SetBuffered(false) + _, err := reader.WriteTo(writer) assert(err, IsNil) - mb, err := cache.Read() + mb, err := cache.ReadMultiBuffer() assert(err, IsNil) assert(mb.Len(), Equals, 8192) - assert(len(mb), Equals, 4) } func TestDiscardBytes(t *testing.T) { assert := With(t) b := New() - common.Must(b.Reset(ReadFrom(rand.Reader))) + common.Must(b.Reset(ReadFullFrom(rand.Reader, Size))) nBytes, err := io.Copy(DiscardBytes, b) assert(nBytes, Equals, int64(Size)) @@ -64,7 +67,7 @@ func TestDiscardBytesMultiBuffer(t *testing.T) { common.Must2(buffer.ReadFrom(io.LimitReader(rand.Reader, size))) r := NewReader(buffer) - nBytes, err := io.Copy(DiscardBytes, ToBytesReader(r)) + nBytes, err := io.Copy(DiscardBytes, NewBufferedReader(r)) assert(nBytes, Equals, int64(size)) assert(err, IsNil) } diff --git a/common/crypto/auth.go b/common/crypto/auth.go index 289e14f29..cd29a6a05 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -151,7 +151,7 @@ func (r *AuthenticationReader) readChunk(waitForData bool) ([]byte, error) { return b, nil } -func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) { +func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) { b, err := r.readChunk(true) if err != nil { return nil, err @@ -193,81 +193,97 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) { return mb, nil } +const ( + WriteSize = 1024 +) + type AuthenticationWriter struct { auth Authenticator - buffer []byte - payload []byte - writer *buf.BufferedWriter + writer buf.Writer sizeParser ChunkSizeEncoder transferType protocol.TransferType } func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, transferType protocol.TransferType) *AuthenticationWriter { - const payloadSize = 1024 return &AuthenticationWriter{ auth: auth, - buffer: make([]byte, payloadSize+sizeParser.SizeBytes()+auth.Overhead()), - payload: make([]byte, payloadSize), - writer: buf.NewBufferedWriterSize(writer, readerBufferSize), + writer: buf.NewWriter(writer), sizeParser: sizeParser, transferType: transferType, } } -func (w *AuthenticationWriter) append(b []byte) error { - encryptedSize := len(b) + w.auth.Overhead() - buffer := w.sizeParser.Encode(uint16(encryptedSize), w.buffer[:0]) +func (w *AuthenticationWriter) seal(b *buf.Buffer) (*buf.Buffer, error) { + encryptedSize := b.Len() + w.auth.Overhead() - buffer, err := w.auth.Seal(buffer, b) - if err != nil { - return err + eb := buf.New() + common.Must(eb.Reset(func(bb []byte) (int, error) { + w.sizeParser.Encode(uint16(encryptedSize), bb[:0]) + return w.sizeParser.SizeBytes(), nil + })) + if err := eb.AppendSupplier(func(bb []byte) (int, error) { + _, err := w.auth.Seal(bb[:0], b.Bytes()) + return encryptedSize, err + }); err != nil { + eb.Release() + return nil, err } - if _, err := w.writer.Write(buffer); err != nil { - return err - } - - return nil + return eb, nil } func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error { defer mb.Release() + mb2Write := buf.NewMultiBufferCap(len(mb) + 10) + for { - n, _ := mb.Read(w.payload) - if err := w.append(w.payload[:n]); err != nil { + b := buf.New() + common.Must(b.Reset(func(bb []byte) (int, error) { + return mb.Read(bb[:WriteSize]) + })) + eb, err := w.seal(b) + b.Release() + + if err != nil { + mb2Write.Release() return err } + mb2Write.Append(eb) if mb.IsEmpty() { break } } - return w.writer.Flush() + return w.writer.WriteMultiBuffer(mb2Write) } func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error { defer mb.Release() + mb2Write := buf.NewMultiBufferCap(len(mb) * 2) + for { b := mb.SplitFirst() if b == nil { b = buf.New() } - if err := w.append(b.Bytes()); err != nil { - b.Release() + eb, err := w.seal(b) + b.Release() + if err != nil { + mb2Write.Release() return err } - b.Release() + mb2Write.Append(eb) if mb.IsEmpty() { break } } - return w.writer.Flush() + return w.writer.WriteMultiBuffer(mb2Write) } -func (w *AuthenticationWriter) Write(mb buf.MultiBuffer) error { +func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { if w.transferType == protocol.TransferTypeStream { return w.writeStream(mb) } diff --git a/common/crypto/auth_test.go b/common/crypto/auth_test.go index 602fa34ec..4016dc51b 100644 --- a/common/crypto/auth_test.go +++ b/common/crypto/auth_test.go @@ -42,9 +42,9 @@ func TestAuthenticationReaderWriter(t *testing.T) { AdditionalDataGenerator: &NoOpBytesGenerator{}, }, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream) - assert(writer.Write(buf.NewMultiBufferValue(payload)), IsNil) + assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(payload)), IsNil) assert(cache.Len(), Equals, 83360) - assert(writer.Write(buf.MultiBuffer{}), IsNil) + assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil) assert(err, IsNil) reader := NewAuthenticationReader(&AEADAuthenticator{ @@ -58,7 +58,7 @@ func TestAuthenticationReaderWriter(t *testing.T) { var mb buf.MultiBuffer for mb.Len() < len(rawPayload) { - mb2, err := reader.Read() + mb2, err := reader.ReadMultiBuffer() assert(err, IsNil) mb.AppendMulti(mb2) @@ -68,7 +68,7 @@ func TestAuthenticationReaderWriter(t *testing.T) { mb.Read(mbContent) assert(mbContent, Equals, rawPayload) - _, err = reader.Read() + _, err = reader.ReadMultiBuffer() assert(err, Equals, io.EOF) } @@ -104,9 +104,9 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) { pb2.Append([]byte("efgh")) payload.Append(pb2) - assert(writer.Write(payload), IsNil) + assert(writer.WriteMultiBuffer(payload), IsNil) assert(cache.Len(), GreaterThan, 0) - assert(writer.Write(buf.MultiBuffer{}), IsNil) + assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil) assert(err, IsNil) reader := NewAuthenticationReader(&AEADAuthenticator{ @@ -117,7 +117,7 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) { AdditionalDataGenerator: &NoOpBytesGenerator{}, }, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket) - mb, err := reader.Read() + mb, err := reader.ReadMultiBuffer() assert(err, IsNil) b1 := mb.SplitFirst() @@ -126,6 +126,6 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) { assert(b2.String(), Equals, "efgh") assert(mb.IsEmpty(), IsTrue) - _, err = reader.Read() + _, err = reader.ReadMultiBuffer() assert(err, Equals, io.EOF) } diff --git a/common/crypto/chunk.go b/common/crypto/chunk.go index af43e346c..c66a107bb 100644 --- a/common/crypto/chunk.go +++ b/common/crypto/chunk.go @@ -48,7 +48,7 @@ func NewChunkStreamReader(sizeDecoder ChunkSizeDecoder, reader io.Reader) *Chunk sizeDecoder: sizeDecoder, reader: buf.NewReader(reader), buffer: make([]byte, sizeDecoder.SizeBytes()), - leftOver: buf.NewMultiBufferCap(16), + leftOver: buf.NewMultiBufferCap(16), } } @@ -56,7 +56,7 @@ func (r *ChunkStreamReader) readAtLeast(size int) error { mb := r.leftOver r.leftOver = nil for mb.Len() < size { - extra, err := r.reader.Read() + extra, err := r.reader.ReadMultiBuffer() if err != nil { mb.Release() return err @@ -78,7 +78,7 @@ func (r *ChunkStreamReader) readSize() (uint16, error) { return r.sizeDecoder.Decode(r.buffer) } -func (r *ChunkStreamReader) Read() (buf.MultiBuffer, error) { +func (r *ChunkStreamReader) ReadMultiBuffer() (buf.MultiBuffer, error) { size := r.leftOverSize if size == 0 { nextSize, err := r.readSize() @@ -129,10 +129,10 @@ func NewChunkStreamWriter(sizeEncoder ChunkSizeEncoder, writer io.Writer) *Chunk } } -func (w *ChunkStreamWriter) Write(mb buf.MultiBuffer) error { +func (w *ChunkStreamWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { const sliceSize = 8192 mbLen := mb.Len() - mb2Write := buf.NewMultiBufferCap(mbLen / buf.Size + mbLen / sliceSize + 2) + mb2Write := buf.NewMultiBufferCap(mbLen/buf.Size + mbLen/sliceSize + 2) for { slice := mb.SliceBySize(sliceSize) @@ -150,5 +150,5 @@ func (w *ChunkStreamWriter) Write(mb buf.MultiBuffer) error { } } - return w.writer.Write(mb2Write) + return w.writer.WriteMultiBuffer(mb2Write) } diff --git a/common/crypto/chunk_test.go b/common/crypto/chunk_test.go index 2c9763cb6..4131ebfb2 100644 --- a/common/crypto/chunk_test.go +++ b/common/crypto/chunk_test.go @@ -19,26 +19,26 @@ func TestChunkStreamIO(t *testing.T) { b := buf.New() b.AppendBytes('a', 'b', 'c', 'd') - assert(writer.Write(buf.NewMultiBufferValue(b)), IsNil) + assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)), IsNil) b = buf.New() b.AppendBytes('e', 'f', 'g') - assert(writer.Write(buf.NewMultiBufferValue(b)), IsNil) + assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)), IsNil) - assert(writer.Write(buf.MultiBuffer{}), IsNil) + assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil) assert(cache.Len(), Equals, 13) - mb, err := reader.Read() + mb, err := reader.ReadMultiBuffer() assert(err, IsNil) assert(mb.Len(), Equals, 4) assert(mb[0].Bytes(), Equals, []byte("abcd")) - mb, err = reader.Read() + mb, err = reader.ReadMultiBuffer() assert(err, IsNil) assert(mb.Len(), Equals, 3) assert(mb[0].Bytes(), Equals, []byte("efg")) - _, err = reader.Read() + _, err = reader.ReadMultiBuffer() assert(err, Equals, io.EOF) } diff --git a/common/crypto/io.go b/common/crypto/io.go index cc595f17e..295ba4725 100644 --- a/common/crypto/io.go +++ b/common/crypto/io.go @@ -28,7 +28,7 @@ func (r *CryptionReader) Read(data []byte) (int, error) { } var ( - _ buf.MultiBufferWriter = (*CryptionWriter)(nil) + _ buf.Writer = (*CryptionWriter)(nil) ) type CryptionWriter struct { diff --git a/proxy/blackhole/config.go b/proxy/blackhole/config.go index 88c58c9a3..5bfd3290e 100644 --- a/proxy/blackhole/config.go +++ b/proxy/blackhole/config.go @@ -29,7 +29,7 @@ func (*NoneResponse) WriteTo(buf.Writer) {} func (*HTTPResponse) WriteTo(writer buf.Writer) { b := buf.NewLocal(512) common.Must(b.AppendSupplier(serial.WriteString(http403response))) - writer.Write(buf.NewMultiBufferValue(b)) + writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)) } // GetInternalResponse converts response settings from proto to internal data structure. diff --git a/proxy/http/server.go b/proxy/http/server.go index 9c3f06af3..484389ddf 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -255,15 +255,18 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, rea requestDone := signal.ExecuteAsync(func() error { request.Header.Set("Connection", "close") - requestWriter := buf.ToBytesWriter(ray.InboundInput()) + requestWriter := buf.NewBufferedWriter(ray.InboundInput()) if err := request.Write(requestWriter); err != nil { return err } + if err := requestWriter.Flush(); err != nil { + return err + } return nil }) responseDone := signal.ExecuteAsync(func() error { - responseReader := bufio.NewReaderSize(buf.ToBytesReader(ray.InboundOutput()), 2048) + responseReader := bufio.NewReaderSize(buf.NewBufferedReader(ray.InboundOutput()), 2048) response, err := http.ReadResponse(responseReader, request) if err == nil { StripHopByHopHeaders(response.Header) diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index 0470b646b..045813577 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -93,7 +93,7 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5) if request.Command == protocol.RequestCommandTCP { - bufferedWriter := buf.NewBufferedWriter(conn) + bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn)) bodyWriter, err := WriteTCPRequest(request, bufferedWriter) if err != nil { return newError("failed to write request").Base(err) diff --git a/proxy/shadowsocks/ota.go b/proxy/shadowsocks/ota.go index eb8e37072..91549b829 100644 --- a/proxy/shadowsocks/ota.go +++ b/proxy/shadowsocks/ota.go @@ -68,7 +68,7 @@ func NewChunkReader(reader io.Reader, auth *Authenticator) *ChunkReader { } } -func (v *ChunkReader) Read() (buf.MultiBuffer, error) { +func (v *ChunkReader) ReadMultiBuffer() (buf.MultiBuffer, error) { buffer := buf.New() if err := buffer.AppendSupplier(buf.ReadFullFrom(v.reader, 2)); err != nil { buffer.Release() @@ -117,8 +117,8 @@ func NewChunkWriter(writer io.Writer, auth *Authenticator) *ChunkWriter { } } -// Write implements buf.MultiBufferWriter. -func (w *ChunkWriter) Write(mb buf.MultiBuffer) error { +// WriteMultiBuffer implements buf.Writer. +func (w *ChunkWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { defer mb.Release() for { diff --git a/proxy/shadowsocks/ota_test.go b/proxy/shadowsocks/ota_test.go index b93d55a69..95064444f 100644 --- a/proxy/shadowsocks/ota_test.go +++ b/proxy/shadowsocks/ota_test.go @@ -16,7 +16,7 @@ func TestNormalChunkReading(t *testing.T) { 0, 8, 39, 228, 69, 96, 133, 39, 254, 26, 201, 70, 11, 12, 13, 14, 15, 16, 17, 18) reader := NewChunkReader(buffer, NewAuthenticator(ChunkKeyGenerator( []byte{21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}))) - payload, err := reader.Read() + payload, err := reader.ReadMultiBuffer() assert(err, IsNil) assert(payload[0].Bytes(), Equals, []byte{11, 12, 13, 14, 15, 16, 17, 18}) } @@ -30,7 +30,7 @@ func TestNormalChunkWriting(t *testing.T) { b := buf.NewLocal(256) b.Append([]byte{11, 12, 13, 14, 15, 16, 17, 18}) - err := writer.Write(buf.NewMultiBufferValue(b)) + err := writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)) assert(err, IsNil) assert(buffer.Bytes(), Equals, []byte{0, 8, 39, 228, 69, 96, 133, 39, 254, 26, 201, 70, 11, 12, 13, 14, 15, 16, 17, 18}) } diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index 529b2cb27..c44548957 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -362,7 +362,7 @@ type UDPReader struct { User *protocol.User } -func (v *UDPReader) Read() (buf.MultiBuffer, error) { +func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { buffer := buf.New() err := buffer.AppendSupplier(buf.ReadFrom(v.Reader)) if err != nil { diff --git a/proxy/shadowsocks/protocol_test.go b/proxy/shadowsocks/protocol_test.go index 763bd26b2..f2cb3c0a3 100644 --- a/proxy/shadowsocks/protocol_test.go +++ b/proxy/shadowsocks/protocol_test.go @@ -112,14 +112,14 @@ func TestTCPRequest(t *testing.T) { writer, err := WriteTCPRequest(request, cache) assert(err, IsNil) - assert(writer.Write(buf.NewMultiBufferValue(data)), IsNil) + assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(data)), IsNil) decodedRequest, reader, err := ReadTCPSession(request.User, cache) assert(err, IsNil) assert(decodedRequest.Address, Equals, request.Address) assert(decodedRequest.Port, Equals, request.Port) - decodedData, err := reader.Read() + decodedData, err := reader.ReadMultiBuffer() assert(err, IsNil) assert(decodedData[0].String(), Equals, string(payload)) } @@ -158,19 +158,19 @@ func TestUDPReaderWriter(t *testing.T) { b := buf.New() b.AppendSupplier(serial.WriteString("test payload")) - err := writer.Write(buf.NewMultiBufferValue(b)) + err := writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)) assert(err, IsNil) - payload, err := reader.Read() + payload, err := reader.ReadMultiBuffer() assert(err, IsNil) assert(payload[0].String(), Equals, "test payload") b = buf.New() b.AppendSupplier(serial.WriteString("test payload 2")) - err = writer.Write(buf.NewMultiBufferValue(b)) + err = writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)) assert(err, IsNil) - payload, err = reader.Read() + payload, err = reader.ReadMultiBuffer() assert(err, IsNil) assert(payload[0].String(), Equals, "test payload 2") } diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index be28b5dd3..0e52ed3b6 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -74,7 +74,7 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection reader := buf.NewReader(conn) for { - mpayload, err := reader.Read() + mpayload, err := reader.ReadMultiBuffer() if err != nil { break } @@ -129,7 +129,7 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, dispatcher dispatcher.Interface) error { conn.SetReadDeadline(time.Now().Add(time.Second * 8)) - bufferedReader := buf.NewBufferedReader(conn) + bufferedReader := buf.NewBufferedReader(buf.NewReader(conn)) request, bodyReader, err := ReadTCPSession(s.user, bufferedReader) if err != nil { log.Access(conn.RemoteAddr(), "", log.AccessRejected, err) @@ -153,17 +153,17 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, } responseDone := signal.ExecuteAsync(func() error { - bufferedWriter := buf.NewBufferedWriter(conn) + bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn)) responseWriter, err := WriteTCPResponse(request, bufferedWriter) if err != nil { return newError("failed to write response").Base(err) } - payload, err := ray.InboundOutput().Read() + payload, err := ray.InboundOutput().ReadMultiBuffer() if err != nil { return err } - if err := responseWriter.Write(payload); err != nil { + if err := responseWriter.WriteMultiBuffer(payload); err != nil { return err } payload.Release() diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index f956b6f0c..f5f255203 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -352,7 +352,7 @@ func NewUDPReader(reader io.Reader) *UDPReader { return &UDPReader{reader: reader} } -func (r *UDPReader) Read() (buf.MultiBuffer, error) { +func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { b := buf.New() if err := b.AppendSupplier(buf.ReadFrom(r.reader)); err != nil { return nil, err diff --git a/proxy/socks/protocol_test.go b/proxy/socks/protocol_test.go index 61f2d8dbd..a5cb18cfb 100644 --- a/proxy/socks/protocol_test.go +++ b/proxy/socks/protocol_test.go @@ -24,11 +24,11 @@ func TestUDPEncoding(t *testing.T) { content := []byte{'a'} payload := buf.New() payload.Append(content) - assert(writer.Write(buf.NewMultiBufferValue(payload)), IsNil) + assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(payload)), IsNil) reader := NewUDPReader(b) - decodedPayload, err := reader.Read() + decodedPayload, err := reader.ReadMultiBuffer() assert(err, IsNil) assert(decodedPayload[0].Bytes(), Equals, content) } diff --git a/proxy/socks/server.go b/proxy/socks/server.go index defa77916..ec3d74e28 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -58,7 +58,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispatcher dispatcher.Interface) error { conn.SetReadDeadline(time.Now().Add(time.Second * 8)) - reader := buf.NewBufferedReader(conn) + reader := buf.NewBufferedReader(buf.NewReader(conn)) inboundDest, ok := proxy.InboundEntryPointFromContext(ctx) if !ok { @@ -154,7 +154,7 @@ func (v *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, reader := buf.NewReader(conn) for { - mpayload, err := reader.Read() + mpayload, err := reader.ReadMultiBuffer() if err != nil { return err } diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index faff13407..46b5a6b3e 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -142,12 +142,12 @@ func transferResponse(timer signal.ActivityUpdater, session *encoding.ServerSess bodyWriter := session.EncodeResponseBody(request, output) // Optimize for small response packet - data, err := input.Read() + data, err := input.ReadMultiBuffer() if err != nil { return err } - if err := bodyWriter.Write(data); err != nil { + if err := bodyWriter.WriteMultiBuffer(data); err != nil { return err } data.Release() @@ -163,7 +163,7 @@ func transferResponse(timer signal.ActivityUpdater, session *encoding.ServerSess } if request.Option.Has(protocol.RequestOptionChunkStream) { - if err := bodyWriter.Write(buf.MultiBuffer{}); err != nil { + if err := bodyWriter.WriteMultiBuffer(buf.MultiBuffer{}); err != nil { return err } } @@ -177,7 +177,7 @@ func (v *Handler) Process(ctx context.Context, network net.Network, connection i return err } - reader := buf.NewBufferedReader(connection) + reader := buf.NewBufferedReader(buf.NewReader(connection)) session := encoding.NewServerSession(v.clients, v.sessionHistory) request, err := session.DecodeRequestHeader(reader) @@ -213,14 +213,12 @@ func (v *Handler) Process(ctx context.Context, network net.Network, connection i input := ray.InboundInput() output := ray.InboundOutput() - reader.SetBuffered(false) - requestDone := signal.ExecuteAsync(func() error { return transferRequest(timer, session, request, reader, input) }) responseDone := signal.ExecuteAsync(func() error { - writer := buf.NewBufferedWriter(connection) + writer := buf.NewBufferedWriter(buf.NewWriter(connection)) defer writer.Flush() response := &protocol.ResponseHeader{ diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index a883d1fb5..aac6094d4 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -106,7 +106,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5) requestDone := signal.ExecuteAsync(func() error { - writer := buf.NewBufferedWriter(conn) + writer := buf.NewBufferedWriter(buf.NewWriter(conn)) if err := session.EncodeRequestHeader(request, writer); err != nil { return newError("failed to encode request").Base(err).AtWarning() } @@ -117,7 +117,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial return newError("failed to get first payload").Base(err) } if !firstPayload.IsEmpty() { - if err := bodyWriter.Write(firstPayload); err != nil { + if err := bodyWriter.WriteMultiBuffer(firstPayload); err != nil { return newError("failed to write first payload").Base(err) } firstPayload.Release() @@ -132,7 +132,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial } if request.Option.Has(protocol.RequestOptionChunkStream) { - if err := bodyWriter.Write(buf.MultiBuffer{}); err != nil { + if err := bodyWriter.WriteMultiBuffer(buf.MultiBuffer{}); err != nil { return err } } @@ -142,7 +142,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial responseDone := signal.ExecuteAsync(func() error { defer output.Close() - reader := buf.NewBufferedReader(conn) + reader := buf.NewBufferedReader(buf.NewReader(conn)) header, err := session.DecodeResponseHeader(reader) if err != nil { return err diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index f20ab562e..5551fab6e 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -169,8 +169,7 @@ type SystemConnection interface { } var ( - _ buf.MultiBufferReader = (*Connection)(nil) - _ buf.MultiBufferWriter = (*Connection)(nil) + _ buf.Reader = (*Connection)(nil) ) // Connection is a KCP connection over UDP. @@ -265,7 +264,7 @@ func (v *Connection) OnDataOutput() { } } -// ReadMultiBuffer implements buf.MultiBufferReader. +// ReadMultiBuffer implements buf.Reader. func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) { if v == nil { return nil, io.EOF @@ -375,13 +374,6 @@ func (v *Connection) Write(b []byte) (int, error) { } } -func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { - if c.mergingWriter == nil { - c.mergingWriter = buf.NewMergingWriterSize(c, c.mss) - } - return c.mergingWriter.Write(mb) -} - func (v *Connection) SetState(state State) { current := v.Elapsed() atomic.StoreInt32((*int32)(&v.state), int32(state)) diff --git a/transport/internet/tls/tls.go b/transport/internet/tls/tls.go index 42ade5f79..c9b6e6f4d 100644 --- a/transport/internet/tls/tls.go +++ b/transport/internet/tls/tls.go @@ -10,29 +10,23 @@ import ( //go:generate go run $GOPATH/src/v2ray.com/core/tools/generrorgen/main.go -pkg tls -path Transport,Internet,TLS var ( - _ buf.MultiBufferReader = (*conn)(nil) - _ buf.MultiBufferWriter = (*conn)(nil) + _ buf.Writer = (*conn)(nil) ) type conn struct { net.Conn - mergingReader buf.Reader - mergingWriter buf.Writer -} - -func (c *conn) ReadMultiBuffer() (buf.MultiBuffer, error) { - if c.mergingReader == nil { - c.mergingReader = buf.NewBytesToBufferReader(c.Conn) - } - return c.mergingReader.Read() + mergingWriter *buf.BufferedWriter } func (c *conn) WriteMultiBuffer(mb buf.MultiBuffer) error { if c.mergingWriter == nil { - c.mergingWriter = buf.NewMergingWriter(c.Conn) + c.mergingWriter = buf.NewBufferedWriter(buf.NewWriter(c.Conn)) } - return c.mergingWriter.Write(mb) + if err := c.mergingWriter.WriteMultiBuffer(mb); err != nil { + return err + } + return c.mergingWriter.Flush() } func Client(c net.Conn, config *tls.Config) net.Conn { diff --git a/transport/internet/udp/dispatcher.go b/transport/internet/udp/dispatcher.go index ab9e8b006..8b116614a 100644 --- a/transport/internet/udp/dispatcher.go +++ b/transport/internet/udp/dispatcher.go @@ -57,7 +57,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, inboundRay, existing := v.getInboundRay(ctx, destination) outputStream := inboundRay.InboundInput() if outputStream != nil { - if err := outputStream.Write(buf.NewMultiBufferValue(payload)); err != nil { + if err := outputStream.WriteMultiBuffer(buf.NewMultiBufferValue(payload)); err != nil { v.RemoveRay(destination) } } @@ -71,7 +71,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, func handleInput(input ray.InputStream, callback ResponseCallback) { for { - mb, err := input.Read() + mb, err := input.ReadMultiBuffer() if err != nil { break } diff --git a/transport/internet/udp/dispatcher_test.go b/transport/internet/udp/dispatcher_test.go index b780ab16a..76a939e77 100644 --- a/transport/internet/udp/dispatcher_test.go +++ b/transport/internet/udp/dispatcher_test.go @@ -28,11 +28,11 @@ func TestSameDestinationDispatching(t *testing.T) { link := ray.NewRay(ctx) go func() { for { - data, err := link.OutboundInput().Read() + data, err := link.OutboundInput().ReadMultiBuffer() if err != nil { break } - err = link.OutboundOutput().Write(data) + err = link.OutboundOutput().WriteMultiBuffer(data) assert(err, IsNil) } }() diff --git a/transport/internet/websocket/connection.go b/transport/internet/websocket/connection.go index bcdde3cff..d4e774326 100644 --- a/transport/internet/websocket/connection.go +++ b/transport/internet/websocket/connection.go @@ -11,8 +11,7 @@ import ( ) var ( - _ buf.MultiBufferReader = (*connection)(nil) - _ buf.MultiBufferWriter = (*connection)(nil) + _ buf.Writer = (*connection)(nil) ) // connection is a wrapper for net.Conn over WebSocket connection. @@ -20,8 +19,7 @@ type connection struct { conn *websocket.Conn reader io.Reader - mergingReader buf.Reader - mergingWriter buf.Writer + mergingWriter *buf.BufferedWriter } func newConnection(conn *websocket.Conn) *connection { @@ -47,13 +45,6 @@ func (c *connection) Read(b []byte) (int, error) { } } -func (c *connection) ReadMultiBuffer() (buf.MultiBuffer, error) { - if c.mergingReader == nil { - c.mergingReader = buf.NewBytesToBufferReader(c) - } - return c.mergingReader.Read() -} - func (c *connection) getReader() (io.Reader, error) { if c.reader != nil { return c.reader, nil @@ -77,9 +68,12 @@ func (c *connection) Write(b []byte) (int, error) { func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) error { if c.mergingWriter == nil { - c.mergingWriter = buf.NewMergingWriter(c) + c.mergingWriter = buf.NewBufferedWriter(buf.NewBufferToBytesWriter(c)) } - return c.mergingWriter.Write(mb) + if err := c.mergingWriter.WriteMultiBuffer(mb); err != nil { + return err + } + return c.mergingWriter.Flush() } func (c *connection) Close() error { diff --git a/transport/ray/direct.go b/transport/ray/direct.go index 2378e394b..2df368b49 100644 --- a/transport/ray/direct.go +++ b/transport/ray/direct.go @@ -106,7 +106,7 @@ func (s *Stream) Peek(b *buf.Buffer) { } // Read reads data from the Stream. -func (s *Stream) Read() (buf.MultiBuffer, error) { +func (s *Stream) ReadMultiBuffer() (buf.MultiBuffer, error) { for { mb, err := s.getData() if err != nil { @@ -178,7 +178,7 @@ func (s *Stream) waitForStreamSize() error { } // Write writes more data into the Stream. -func (s *Stream) Write(data buf.MultiBuffer) error { +func (s *Stream) WriteMultiBuffer(data buf.MultiBuffer) error { if data.IsEmpty() { return nil } diff --git a/transport/ray/direct_test.go b/transport/ray/direct_test.go index 4350d7ca1..64c0ab1d0 100644 --- a/transport/ray/direct_test.go +++ b/transport/ray/direct_test.go @@ -16,18 +16,18 @@ func TestStreamIO(t *testing.T) { stream := NewStream(context.Background()) b1 := buf.New() b1.AppendBytes('a') - assert(stream.Write(buf.NewMultiBufferValue(b1)), IsNil) + assert(stream.WriteMultiBuffer(buf.NewMultiBufferValue(b1)), IsNil) - _, err := stream.Read() + _, err := stream.ReadMultiBuffer() assert(err, IsNil) stream.Close() - _, err = stream.Read() + _, err = stream.ReadMultiBuffer() assert(err, Equals, io.EOF) b2 := buf.New() b2.AppendBytes('b') - err = stream.Write(buf.NewMultiBufferValue(b2)) + err = stream.WriteMultiBuffer(buf.NewMultiBufferValue(b2)) assert(err, Equals, io.ErrClosedPipe) } @@ -37,13 +37,13 @@ func TestStreamClose(t *testing.T) { stream := NewStream(context.Background()) b1 := buf.New() b1.AppendBytes('a') - assert(stream.Write(buf.NewMultiBufferValue(b1)), IsNil) + assert(stream.WriteMultiBuffer(buf.NewMultiBufferValue(b1)), IsNil) stream.Close() - _, err := stream.Read() + _, err := stream.ReadMultiBuffer() assert(err, IsNil) - _, err = stream.Read() + _, err = stream.ReadMultiBuffer() assert(err, Equals, io.EOF) }