1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-30 05:56:54 -05:00

cleanup buffer usage

This commit is contained in:
Darien Raymond 2017-11-09 22:33:15 +01:00
parent 6e61538b36
commit 594ec15c09
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
41 changed files with 358 additions and 529 deletions

View File

@ -161,7 +161,7 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
log.Trace(newError("dispatching request to ", dest))
data, _ := s.input.ReadTimeout(time.Millisecond * 500)
if err := writer.Write(data); err != nil {
if err := writer.WriteMultiBuffer(data); err != nil {
log.Trace(newError("failed to write first payload").Base(err))
return
}
@ -234,7 +234,7 @@ func (m *Client) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error {
func (m *Client) fetchOutput() {
defer m.cancel()
reader := buf.ToBytesReader(m.inboundRay.InboundOutput())
reader := buf.NewBufferedReader(m.inboundRay.InboundOutput())
for {
meta, err := ReadMetadata(reader)
@ -396,7 +396,7 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader io.Reader) error
func (w *ServerWorker) run(ctx context.Context) {
input := w.outboundRay.OutboundInput()
reader := buf.ToBytesReader(input)
reader := buf.NewBufferedReader(input)
defer w.sessionManager.Close()

View File

@ -16,7 +16,7 @@ import (
func readAll(reader buf.Reader) (buf.MultiBuffer, error) {
var mb buf.MultiBuffer
for {
b, err := reader.Read()
b, err := reader.ReadMultiBuffer()
if err == io.EOF {
break
}
@ -45,7 +45,7 @@ func TestReaderWriter(t *testing.T) {
writePayload := func(writer *Writer, payload ...byte) error {
b := buf.New()
b.Append(payload)
return writer.Write(buf.NewMultiBufferValue(b))
return writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
}
assert(writePayload(writer, 'a', 'b', 'c', 'd'), IsNil)
@ -60,7 +60,7 @@ func TestReaderWriter(t *testing.T) {
assert(writePayload(writer2, 'y'), IsNil)
writer2.Close()
bytesReader := buf.ToBytesReader(stream)
bytesReader := buf.NewBufferedReader(stream)
streamReader := NewStreamReader(bytesReader)
meta, err := ReadMetadata(bytesReader)

View File

@ -40,8 +40,8 @@ func NewPacketReader(reader io.Reader) *PacketReader {
}
}
// Read implements buf.Reader.
func (r *PacketReader) Read() (buf.MultiBuffer, error) {
// ReadMultiBuffer implements buf.Reader.
func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
if r.eof {
return nil, io.EOF
}
@ -79,8 +79,8 @@ func NewStreamReader(reader io.Reader) *StreamReader {
}
}
// Read implmenets buf.Reader.
func (r *StreamReader) Read() (buf.MultiBuffer, error) {
// ReadMultiBuffer implmenets buf.Reader.
func (r *StreamReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
if r.leftOver == 0 {
r.leftOver = -1
return nil, io.EOF

View File

@ -56,7 +56,7 @@ func (w *Writer) writeMetaOnly() error {
if err := b.Reset(meta.AsSupplier()); err != nil {
return err
}
return w.writer.Write(buf.NewMultiBufferValue(b))
return w.writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
}
func (w *Writer) writeData(mb buf.MultiBuffer) error {
@ -74,11 +74,11 @@ func (w *Writer) writeData(mb buf.MultiBuffer) error {
mb2 := buf.NewMultiBufferCap(len(mb) + 1)
mb2.Append(frame)
mb2.AppendMulti(mb)
return w.writer.Write(mb2)
return w.writer.WriteMultiBuffer(mb2)
}
// Write implements buf.MultiBufferWriter.
func (w *Writer) Write(mb buf.MultiBuffer) error {
// WriteMultiBuffer implements buf.Writer.
func (w *Writer) WriteMultiBuffer(mb buf.MultiBuffer) error {
defer mb.Release()
if mb.IsEmpty() {
@ -109,5 +109,5 @@ func (w *Writer) Close() {
frame := buf.New()
common.Must(frame.Reset(meta.AsSupplier()))
w.writer.Write(buf.NewMultiBufferValue(frame))
w.writer.WriteMultiBuffer(buf.NewMultiBufferValue(frame))
}

View File

@ -123,8 +123,8 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Conn
}
var (
_ buf.MultiBufferReader = (*Connection)(nil)
_ buf.MultiBufferWriter = (*Connection)(nil)
_ buf.Reader = (*Connection)(nil)
_ buf.Writer = (*Connection)(nil)
)
type Connection struct {
@ -133,8 +133,7 @@ type Connection struct {
localAddr net.Addr
remoteAddr net.Addr
bytesReader io.Reader
reader buf.Reader
reader *buf.BufferedReader
writer buf.Writer
}
@ -149,8 +148,7 @@ func NewConnection(stream ray.Ray) *Connection {
IP: []byte{0, 0, 0, 0},
Port: 0,
},
bytesReader: buf.ToBytesReader(stream.InboundOutput()),
reader: stream.InboundOutput(),
reader: buf.NewBufferedReader(stream.InboundOutput()),
writer: stream.InboundInput(),
}
}
@ -160,11 +158,11 @@ func (v *Connection) Read(b []byte) (int, error) {
if v.closed {
return 0, io.EOF
}
return v.bytesReader.Read(b)
return v.reader.Read(b)
}
func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
return v.reader.Read()
return v.reader.ReadMultiBuffer()
}
// Write implements net.Conn.Write().
@ -172,14 +170,19 @@ func (v *Connection) Write(b []byte) (int, error) {
if v.closed {
return 0, io.ErrClosedPipe
}
return buf.ToBytesWriter(v.writer).Write(b)
l := len(b)
mb := buf.NewMultiBufferCap(l/buf.Size + 1)
mb.Write(b)
return l, v.writer.WriteMultiBuffer(mb)
}
func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
if v.closed {
return io.ErrClosedPipe
}
return v.writer.Write(mb)
return v.writer.WriteMultiBuffer(mb)
}
// Close implements net.Conn.Close().

View File

@ -1,53 +0,0 @@
package buf
import (
"io"
)
// BufferedReader is a reader with internal cache.
type BufferedReader struct {
reader io.Reader
buffer *Buffer
buffered bool
}
// NewBufferedReader creates a new BufferedReader based on an io.Reader.
func NewBufferedReader(rawReader io.Reader) *BufferedReader {
return &BufferedReader{
reader: rawReader,
buffer: NewLocal(1024),
buffered: true,
}
}
// IsBuffered returns true if the internal cache is effective.
func (r *BufferedReader) IsBuffered() bool {
return r.buffered
}
// SetBuffered is to enable or disable internal cache. If cache is disabled,
// Read() calls will be delegated to the underlying io.Reader directly.
func (r *BufferedReader) SetBuffered(cached bool) {
r.buffered = cached
}
// Read implements io.Reader.Read().
func (r *BufferedReader) Read(b []byte) (int, error) {
if !r.buffered || r.buffer == nil {
if !r.buffer.IsEmpty() {
return r.buffer.Read(b)
}
return r.reader.Read(b)
}
if r.buffer.IsEmpty() {
if err := r.buffer.Reset(ReadFrom(r.reader)); err != nil {
return 0, err
}
}
if r.buffer.IsEmpty() {
return 0, nil
}
return r.buffer.Read(b)
}

View File

@ -1,36 +0,0 @@
package buf_test
import (
"crypto/rand"
"testing"
. "v2ray.com/core/common/buf"
. "v2ray.com/ext/assert"
)
func TestBufferedReader(t *testing.T) {
assert := With(t)
content := New()
assert(content.AppendSupplier(ReadFrom(rand.Reader)), IsNil)
len := content.Len()
reader := NewBufferedReader(content)
assert(reader.IsBuffered(), IsTrue)
payload := make([]byte, 16)
nBytes, err := reader.Read(payload)
assert(nBytes, Equals, 16)
assert(err, IsNil)
len2 := content.Len()
assert(len-len2, GreaterThan, 16)
nBytes, err = reader.Read(payload)
assert(nBytes, Equals, 16)
assert(err, IsNil)
assert(content.Len(), Equals, len2)
}

View File

@ -1,73 +0,0 @@
package buf
import "io"
// BufferedWriter is an io.Writer with internal buffer. It writes to underlying writer when buffer is full or on demand.
// This type is not thread safe.
type BufferedWriter struct {
writer io.Writer
buffer *Buffer
buffered bool
}
// NewBufferedWriter creates a new BufferedWriter.
func NewBufferedWriter(writer io.Writer) *BufferedWriter {
return NewBufferedWriterSize(writer, 1024)
}
// NewBufferedWriterSize creates a BufferedWriter with specified buffer size.
func NewBufferedWriterSize(writer io.Writer, size uint32) *BufferedWriter {
return &BufferedWriter{
writer: writer,
buffer: NewLocal(int(size)),
buffered: true,
}
}
// Write implements io.Writer.
func (w *BufferedWriter) Write(b []byte) (int, error) {
if !w.buffered || w.buffer == nil {
return w.writer.Write(b)
}
bytesWritten := 0
for bytesWritten < len(b) {
nBytes, err := w.buffer.Write(b[bytesWritten:])
if err != nil {
return bytesWritten, err
}
bytesWritten += nBytes
if w.buffer.IsFull() {
if err := w.Flush(); err != nil {
return bytesWritten, err
}
}
}
return bytesWritten, nil
}
// Flush writes all buffered content into underlying writer, if any.
func (w *BufferedWriter) Flush() error {
defer w.buffer.Clear()
for !w.buffer.IsEmpty() {
nBytes, err := w.writer.Write(w.buffer.Bytes())
if err != nil {
return err
}
w.buffer.SliceFrom(nBytes)
}
return nil
}
// IsBuffered returns true if this BufferedWriter holds a buffer.
func (w *BufferedWriter) IsBuffered() bool {
return w.buffered
}
// SetBuffered controls whether the BufferedWriter holds a buffer for writing. If not buffered, any write() calls into underlying writer directly.
func (w *BufferedWriter) SetBuffered(cached bool) error {
w.buffered = cached
if !cached && !w.buffer.IsEmpty() {
return w.Flush()
}
return nil
}

View File

@ -1,54 +0,0 @@
package buf_test
import (
"crypto/rand"
"testing"
"v2ray.com/core/common"
. "v2ray.com/core/common/buf"
. "v2ray.com/ext/assert"
)
func TestBufferedWriter(t *testing.T) {
assert := With(t)
content := New()
writer := NewBufferedWriter(content)
assert(writer.IsBuffered(), IsTrue)
payload := make([]byte, 16)
nBytes, err := writer.Write(payload)
assert(nBytes, Equals, 16)
assert(err, IsNil)
assert(content.IsEmpty(), IsTrue)
assert(writer.SetBuffered(false), IsNil)
assert(content.Len(), Equals, 16)
}
func TestBufferedWriterLargePayload(t *testing.T) {
assert := With(t)
content := NewLocal(128 * 1024)
writer := NewBufferedWriter(content)
assert(writer.IsBuffered(), IsTrue)
payload := make([]byte, 64*1024)
common.Must2(rand.Read(payload))
nBytes, err := writer.Write(payload[:512])
assert(nBytes, Equals, 512)
assert(err, IsNil)
assert(content.IsEmpty(), IsTrue)
nBytes, err = writer.Write(payload[512:])
assert(err, IsNil)
assert(writer.Flush(), IsNil)
assert(nBytes, Equals, 64*1024-512)
assert(content.Bytes(), Equals, payload)
}

View File

@ -17,7 +17,7 @@ type copyHandler struct {
}
func (h *copyHandler) readFrom(reader Reader) (MultiBuffer, error) {
mb, err := reader.Read()
mb, err := reader.ReadMultiBuffer()
if err != nil {
for _, handler := range h.onReadError {
err = handler(err)
@ -27,7 +27,7 @@ func (h *copyHandler) readFrom(reader Reader) (MultiBuffer, error) {
}
func (h *copyHandler) writeTo(writer Writer, mb MultiBuffer) error {
err := writer.Write(mb)
err := writer.WriteMultiBuffer(mb)
if err != nil {
for _, handler := range h.onWriteError {
err = handler(err)
@ -36,6 +36,10 @@ func (h *copyHandler) writeTo(writer Writer, mb MultiBuffer) error {
return err
}
type SizeCounter struct {
Size int64
}
type CopyOption func(*copyHandler)
func IgnoreReaderError() CopyOption {
@ -62,6 +66,14 @@ func UpdateActivity(timer signal.ActivityUpdater) CopyOption {
}
}
func CountSize(sc *SizeCounter) CopyOption {
return func(handler *copyHandler) {
handler.onData = append(handler.onData, func(b MultiBuffer) {
sc.Size += int64(b.Len())
})
}
}
func copyInternal(reader Reader, writer Writer, handler *copyHandler) error {
for {
buffer, err := handler.readFrom(reader)

View File

@ -5,10 +5,10 @@ import (
"time"
)
// Reader extends io.Reader with alloc.Buffer.
// Reader extends io.Reader with MultiBuffer.
type Reader interface {
// Read reads content from underlying reader, and put it into an alloc.Buffer.
Read() (MultiBuffer, error)
// ReadMultiBuffer reads content from underlying reader, and put it into a MultiBuffer.
ReadMultiBuffer() (MultiBuffer, error)
}
// ErrReadTimeout is an error that happens with IO timeout.
@ -19,10 +19,10 @@ type TimeoutReader interface {
ReadTimeout(time.Duration) (MultiBuffer, error)
}
// Writer extends io.Writer with alloc.Buffer.
// Writer extends io.Writer with MultiBuffer.
type Writer interface {
// Write writes an alloc.Buffer into underlying writer.
Write(MultiBuffer) error
// WriteMultiBuffer writes a MultiBuffer into underlying writer.
WriteMultiBuffer(MultiBuffer) error
}
// ReadFrom creates a Supplier to read from a given io.Reader.
@ -49,45 +49,21 @@ func ReadAtLeastFrom(reader io.Reader, size int) Supplier {
// NewReader creates a new Reader.
// The Reader instance doesn't take the ownership of reader.
func NewReader(reader io.Reader) Reader {
if mr, ok := reader.(MultiBufferReader); ok {
return &readerAdpater{
MultiBufferReader: mr,
}
if mr, ok := reader.(Reader); ok {
return mr
}
return &BytesToBufferReader{
reader: reader,
}
}
// ToBytesReader converts a Reaaer to io.Reader.
func ToBytesReader(stream Reader) io.Reader {
return &bufferToBytesReader{
stream: stream,
}
return NewBytesToBufferReader(reader)
}
// NewWriter creates a new Writer.
func NewWriter(writer io.Writer) Writer {
if mw, ok := writer.(MultiBufferWriter); ok {
return &writerAdapter{
writer: mw,
}
if mw, ok := writer.(Writer); ok {
return mw
}
return &BufferToBytesWriter{
writer: writer,
}
}
func NewMergingWriter(writer io.Writer) Writer {
return NewMergingWriterSize(writer, 4096)
}
func NewMergingWriterSize(writer io.Writer, size uint32) Writer {
return &mergingWriter{
writer: writer,
buffer: make([]byte, size),
Writer: writer,
}
}
@ -96,10 +72,3 @@ func NewSequentialWriter(writer io.Writer) Writer {
writer: writer,
}
}
// ToBytesWriter converts a Writer to io.Writer
func ToBytesWriter(writer Writer) io.Writer {
return &bytesToBufferWriter{
writer: writer,
}
}

View File

@ -8,16 +8,6 @@ import (
"v2ray.com/core/common/errors"
)
// MultiBufferWriter is a writer that writes MultiBuffer.
type MultiBufferWriter interface {
WriteMultiBuffer(MultiBuffer) error
}
// MultiBufferReader is a reader that reader payload as MultiBuffer.
type MultiBufferReader interface {
ReadMultiBuffer() (MultiBuffer, error)
}
// ReadAllToMultiBuffer reads all content from the reader into a MultiBuffer, until EOF.
func ReadAllToMultiBuffer(reader io.Reader) (MultiBuffer, error) {
mb := NewMultiBufferCap(128)

View File

@ -8,19 +8,19 @@ import (
// BytesToBufferReader is a Reader that adjusts its reading speed automatically.
type BytesToBufferReader struct {
reader io.Reader
io.Reader
buffer []byte
}
func NewBytesToBufferReader(reader io.Reader) Reader {
return &BytesToBufferReader{
reader: reader,
Reader: reader,
}
}
func (r *BytesToBufferReader) readSmall() (MultiBuffer, error) {
b := New()
if err := b.Reset(ReadFrom(r.reader)); err != nil {
if err := b.Reset(ReadFrom(r.Reader)); err != nil {
b.Release()
return nil, err
}
@ -30,13 +30,13 @@ func (r *BytesToBufferReader) readSmall() (MultiBuffer, error) {
return NewMultiBufferValue(b), nil
}
// Read implements Reader.Read().
func (r *BytesToBufferReader) Read() (MultiBuffer, error) {
// ReadMultiBuffer implements Reader.
func (r *BytesToBufferReader) ReadMultiBuffer() (MultiBuffer, error) {
if r.buffer == nil {
return r.readSmall()
}
nBytes, err := r.reader.Read(r.buffer)
nBytes, err := r.Reader.Read(r.buffer)
if err != nil {
return nil, err
}
@ -46,20 +46,33 @@ func (r *BytesToBufferReader) Read() (MultiBuffer, error) {
return mb, nil
}
type readerAdpater struct {
MultiBufferReader
}
func (r *readerAdpater) Read() (MultiBuffer, error) {
return r.ReadMultiBuffer()
}
type bufferToBytesReader struct {
type BufferedReader struct {
stream Reader
legacyReader io.Reader
leftOver MultiBuffer
buffered bool
}
func (r *bufferToBytesReader) Read(b []byte) (int, error) {
func NewBufferedReader(reader Reader) *BufferedReader {
r := &BufferedReader{
stream: reader,
buffered: true,
}
if lr, ok := reader.(io.Reader); ok {
r.legacyReader = lr
}
return r
}
func (r *BufferedReader) SetBuffered(f bool) {
r.buffered = f
}
func (r *BufferedReader) IsBuffered() bool {
return r.buffered
}
func (r *BufferedReader) Read(b []byte) (int, error) {
if r.leftOver != nil {
nBytes, _ := r.leftOver.Read(b)
if r.leftOver.IsEmpty() {
@ -69,7 +82,11 @@ func (r *bufferToBytesReader) Read(b []byte) (int, error) {
return nBytes, nil
}
mb, err := r.stream.Read()
if !r.buffered && r.legacyReader != nil {
return r.legacyReader.Read(b)
}
mb, err := r.stream.ReadMultiBuffer()
if err != nil {
return 0, err
}
@ -81,39 +98,39 @@ func (r *bufferToBytesReader) Read(b []byte) (int, error) {
return nBytes, nil
}
func (r *bufferToBytesReader) ReadMultiBuffer() (MultiBuffer, error) {
func (r *BufferedReader) ReadMultiBuffer() (MultiBuffer, error) {
if r.leftOver != nil {
mb := r.leftOver
r.leftOver = nil
return mb, nil
}
return r.stream.Read()
return r.stream.ReadMultiBuffer()
}
func (r *bufferToBytesReader) writeToInternal(writer io.Writer) (int64, error) {
func (r *BufferedReader) writeToInternal(writer io.Writer) (int64, error) {
mbWriter := NewWriter(writer)
totalBytes := int64(0)
if r.leftOver != nil {
totalBytes += int64(r.leftOver.Len())
if err := mbWriter.Write(r.leftOver); err != nil {
if err := mbWriter.WriteMultiBuffer(r.leftOver); err != nil {
return 0, err
}
}
for {
mb, err := r.stream.Read()
mb, err := r.stream.ReadMultiBuffer()
if err != nil {
return totalBytes, err
}
totalBytes += int64(mb.Len())
if err := mbWriter.Write(mb); err != nil {
if err := mbWriter.WriteMultiBuffer(mb); err != nil {
return totalBytes, err
}
}
}
func (r *bufferToBytesReader) WriteTo(writer io.Writer) (int64, error) {
func (r *BufferedReader) WriteTo(writer io.Writer) (int64, error) {
nBytes, err := r.writeToInternal(writer)
if errors.Cause(err) == io.EOF {
return nBytes, nil

View File

@ -15,11 +15,11 @@ func TestAdaptiveReader(t *testing.T) {
assert := With(t)
reader := NewReader(bytes.NewReader(make([]byte, 1024*1024)))
b, err := reader.Read()
b, err := reader.ReadMultiBuffer()
assert(err, IsNil)
assert(b.Len(), Equals, 2*1024)
b, err = reader.Read()
b, err = reader.ReadMultiBuffer()
assert(err, IsNil)
assert(b.Len(), Equals, 32*1024)
}
@ -28,22 +28,23 @@ func TestBytesReaderWriteTo(t *testing.T) {
assert := With(t)
stream := ray.NewStream(context.Background())
reader := ToBytesReader(stream)
reader := NewBufferedReader(stream)
b1 := New()
b1.AppendBytes('a', 'b', 'c')
b2 := New()
b2.AppendBytes('e', 'f', 'g')
assert(stream.Write(NewMultiBufferValue(b1, b2)), IsNil)
assert(stream.WriteMultiBuffer(NewMultiBufferValue(b1, b2)), IsNil)
stream.Close()
stream2 := ray.NewStream(context.Background())
writer := ToBytesWriter(stream2)
writer := NewBufferedWriter(stream2)
writer.SetBuffered(false)
nBytes, err := io.Copy(writer, reader)
assert(err, IsNil)
assert(nBytes, Equals, int64(6))
mb, err := stream2.Read()
mb, err := stream2.ReadMultiBuffer()
assert(err, IsNil)
assert(len(mb), Equals, 2)
assert(mb[0].String(), Equals, "abc")
@ -54,16 +55,16 @@ func TestBytesReaderMultiBuffer(t *testing.T) {
assert := With(t)
stream := ray.NewStream(context.Background())
reader := ToBytesReader(stream)
reader := NewBufferedReader(stream)
b1 := New()
b1.AppendBytes('a', 'b', 'c')
b2 := New()
b2.AppendBytes('e', 'f', 'g')
assert(stream.Write(NewMultiBufferValue(b1, b2)), IsNil)
assert(stream.WriteMultiBuffer(NewMultiBufferValue(b1, b2)), IsNil)
stream.Close()
mbReader := NewReader(reader)
mb, err := mbReader.Read()
mb, err := mbReader.ReadMultiBuffer()
assert(err, IsNil)
assert(len(mb), Equals, 2)
assert(mb[0].String(), Equals, "abc")

View File

@ -8,49 +8,142 @@ import (
// BufferToBytesWriter is a Writer that writes alloc.Buffer into underlying writer.
type BufferToBytesWriter struct {
writer io.Writer
io.Writer
}
// Write implements Writer.Write(). Write() takes ownership of the given buffer.
func (w *BufferToBytesWriter) Write(mb MultiBuffer) error {
func NewBufferToBytesWriter(writer io.Writer) *BufferToBytesWriter {
return &BufferToBytesWriter{
Writer: writer,
}
}
// WriteMultiBuffer implements Writer. This method takes ownership of the given buffer.
func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error {
defer mb.Release()
bs := mb.ToNetBuffers()
_, err := bs.WriteTo(w.writer)
_, err := bs.WriteTo(w)
return err
}
type writerAdapter struct {
writer MultiBufferWriter
func (w *BufferToBytesWriter) ReadFrom(reader io.Reader) (int64, error) {
if readerFrom, ok := w.Writer.(io.ReaderFrom); ok {
return readerFrom.ReadFrom(reader)
}
// Write implements buf.MultiBufferWriter.
func (w *writerAdapter) Write(mb MultiBuffer) error {
return w.writer.WriteMultiBuffer(mb)
var sc SizeCounter
err := Copy(NewReader(reader), w, CountSize(&sc))
return sc.Size, err
}
type mergingWriter struct {
writer io.Writer
buffer []byte
type BufferedWriter struct {
writer Writer
legacyWriter io.Writer
buffer *Buffer
buffered bool
}
func (w *mergingWriter) Write(mb MultiBuffer) error {
defer mb.Release()
func NewBufferedWriter(writer Writer) *BufferedWriter {
w := &BufferedWriter{
writer: writer,
buffer: New(),
buffered: true,
}
if lw, ok := writer.(io.Writer); ok {
w.legacyWriter = lw
}
return w
}
for !mb.IsEmpty() {
nBytes, _ := mb.Read(w.buffer)
if _, err := w.writer.Write(w.buffer[:nBytes]); err != nil {
func (w *BufferedWriter) Write(b []byte) (int, error) {
if !w.buffered && w.legacyWriter != nil {
return w.legacyWriter.Write(b)
}
totalBytes := 0
for len(b) > 0 {
nBytes, err := w.buffer.Write(b)
totalBytes += nBytes
if err != nil {
return totalBytes, err
}
if w.buffer.IsFull() {
if err := w.Flush(); err != nil {
return totalBytes, err
}
}
b = b[nBytes:]
}
return totalBytes, nil
}
func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error {
if !w.buffered {
return w.writer.WriteMultiBuffer(b)
}
defer b.Release()
for !b.IsEmpty() {
if err := w.buffer.AppendSupplier(ReadFrom(&b)); err != nil {
return err
}
if w.buffer.IsFull() {
if err := w.Flush(); err != nil {
return err
}
}
}
return nil
}
func (w *BufferedWriter) Flush() error {
if !w.buffer.IsEmpty() {
if err := w.writer.WriteMultiBuffer(NewMultiBufferValue(w.buffer)); err != nil {
return err
}
if w.buffered {
w.buffer = New()
} else {
w.buffer = nil
}
}
return nil
}
func (w *BufferedWriter) SetBuffered(f bool) error {
w.buffered = f
if !f {
return w.Flush()
}
return nil
}
func (w *BufferedWriter) ReadFrom(reader io.Reader) (int64, error) {
var sc SizeCounter
if !w.buffer.IsEmpty() {
sc.Size += int64(w.buffer.Len())
if err := w.Flush(); err != nil {
return sc.Size, err
}
}
if readerFrom, ok := w.writer.(io.ReaderFrom); ok {
return readerFrom.ReadFrom(reader)
}
w.buffered = false
err := Copy(NewReader(reader), w, CountSize(&sc))
return sc.Size, err
}
type seqWriter struct {
writer io.Writer
}
func (w *seqWriter) Write(mb MultiBuffer) error {
func (w *seqWriter) WriteMultiBuffer(mb MultiBuffer) error {
defer mb.Release()
for _, b := range mb {
@ -65,49 +158,9 @@ func (w *seqWriter) Write(mb MultiBuffer) error {
return nil
}
var (
_ MultiBufferWriter = (*bytesToBufferWriter)(nil)
)
type bytesToBufferWriter struct {
writer Writer
}
// Write implements io.Writer.
func (w *bytesToBufferWriter) Write(payload []byte) (int, error) {
mb := NewMultiBufferCap(len(payload)/Size + 1)
mb.Write(payload)
if err := w.writer.Write(mb); err != nil {
return 0, err
}
return len(payload), nil
}
func (w *bytesToBufferWriter) WriteMultiBuffer(mb MultiBuffer) error {
return w.writer.Write(mb)
}
func (w *bytesToBufferWriter) ReadFrom(reader io.Reader) (int64, error) {
mbReader := NewReader(reader)
totalBytes := int64(0)
for {
mb, err := mbReader.Read()
if errors.Cause(err) == io.EOF {
break
} else if err != nil {
return totalBytes, err
}
totalBytes += int64(mb.Len())
if err := w.writer.Write(mb); err != nil {
return totalBytes, err
}
}
return totalBytes, nil
}
type noOpWriter struct{}
func (noOpWriter) Write(b MultiBuffer) error {
func (noOpWriter) WriteMultiBuffer(b MultiBuffer) error {
b.Release()
return nil
}

View File

@ -25,9 +25,11 @@ func TestWriter(t *testing.T) {
writeBuffer := bytes.NewBuffer(make([]byte, 0, 1024*1024))
writer := NewWriter(NewBufferedWriter(writeBuffer))
err := writer.Write(NewMultiBufferValue(lb))
writer := NewBufferedWriter(NewWriter(writeBuffer))
writer.SetBuffered(false)
err := writer.WriteMultiBuffer(NewMultiBufferValue(lb))
assert(err, IsNil)
assert(writer.Flush(), IsNil)
assert(expectedBytes, Equals, writeBuffer.Bytes())
}
@ -36,20 +38,21 @@ func TestBytesWriterReadFrom(t *testing.T) {
cache := ray.NewStream(context.Background())
reader := bufio.NewReader(io.LimitReader(rand.Reader, 8192))
_, err := reader.WriteTo(ToBytesWriter(cache))
writer := NewBufferedWriter(cache)
writer.SetBuffered(false)
_, err := reader.WriteTo(writer)
assert(err, IsNil)
mb, err := cache.Read()
mb, err := cache.ReadMultiBuffer()
assert(err, IsNil)
assert(mb.Len(), Equals, 8192)
assert(len(mb), Equals, 4)
}
func TestDiscardBytes(t *testing.T) {
assert := With(t)
b := New()
common.Must(b.Reset(ReadFrom(rand.Reader)))
common.Must(b.Reset(ReadFullFrom(rand.Reader, Size)))
nBytes, err := io.Copy(DiscardBytes, b)
assert(nBytes, Equals, int64(Size))
@ -64,7 +67,7 @@ func TestDiscardBytesMultiBuffer(t *testing.T) {
common.Must2(buffer.ReadFrom(io.LimitReader(rand.Reader, size)))
r := NewReader(buffer)
nBytes, err := io.Copy(DiscardBytes, ToBytesReader(r))
nBytes, err := io.Copy(DiscardBytes, NewBufferedReader(r))
assert(nBytes, Equals, int64(size))
assert(err, IsNil)
}

View File

@ -151,7 +151,7 @@ func (r *AuthenticationReader) readChunk(waitForData bool) ([]byte, error) {
return b, nil
}
func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
b, err := r.readChunk(true)
if err != nil {
return nil, err
@ -193,81 +193,97 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
return mb, nil
}
const (
WriteSize = 1024
)
type AuthenticationWriter struct {
auth Authenticator
buffer []byte
payload []byte
writer *buf.BufferedWriter
writer buf.Writer
sizeParser ChunkSizeEncoder
transferType protocol.TransferType
}
func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, transferType protocol.TransferType) *AuthenticationWriter {
const payloadSize = 1024
return &AuthenticationWriter{
auth: auth,
buffer: make([]byte, payloadSize+sizeParser.SizeBytes()+auth.Overhead()),
payload: make([]byte, payloadSize),
writer: buf.NewBufferedWriterSize(writer, readerBufferSize),
writer: buf.NewWriter(writer),
sizeParser: sizeParser,
transferType: transferType,
}
}
func (w *AuthenticationWriter) append(b []byte) error {
encryptedSize := len(b) + w.auth.Overhead()
buffer := w.sizeParser.Encode(uint16(encryptedSize), w.buffer[:0])
func (w *AuthenticationWriter) seal(b *buf.Buffer) (*buf.Buffer, error) {
encryptedSize := b.Len() + w.auth.Overhead()
buffer, err := w.auth.Seal(buffer, b)
if err != nil {
return err
eb := buf.New()
common.Must(eb.Reset(func(bb []byte) (int, error) {
w.sizeParser.Encode(uint16(encryptedSize), bb[:0])
return w.sizeParser.SizeBytes(), nil
}))
if err := eb.AppendSupplier(func(bb []byte) (int, error) {
_, err := w.auth.Seal(bb[:0], b.Bytes())
return encryptedSize, err
}); err != nil {
eb.Release()
return nil, err
}
if _, err := w.writer.Write(buffer); err != nil {
return err
}
return nil
return eb, nil
}
func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error {
defer mb.Release()
mb2Write := buf.NewMultiBufferCap(len(mb) + 10)
for {
n, _ := mb.Read(w.payload)
if err := w.append(w.payload[:n]); err != nil {
b := buf.New()
common.Must(b.Reset(func(bb []byte) (int, error) {
return mb.Read(bb[:WriteSize])
}))
eb, err := w.seal(b)
b.Release()
if err != nil {
mb2Write.Release()
return err
}
mb2Write.Append(eb)
if mb.IsEmpty() {
break
}
}
return w.writer.Flush()
return w.writer.WriteMultiBuffer(mb2Write)
}
func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error {
defer mb.Release()
mb2Write := buf.NewMultiBufferCap(len(mb) * 2)
for {
b := mb.SplitFirst()
if b == nil {
b = buf.New()
}
if err := w.append(b.Bytes()); err != nil {
eb, err := w.seal(b)
b.Release()
if err != nil {
mb2Write.Release()
return err
}
b.Release()
mb2Write.Append(eb)
if mb.IsEmpty() {
break
}
}
return w.writer.Flush()
return w.writer.WriteMultiBuffer(mb2Write)
}
func (w *AuthenticationWriter) Write(mb buf.MultiBuffer) error {
func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
if w.transferType == protocol.TransferTypeStream {
return w.writeStream(mb)
}

View File

@ -42,9 +42,9 @@ func TestAuthenticationReaderWriter(t *testing.T) {
AdditionalDataGenerator: &NoOpBytesGenerator{},
}, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream)
assert(writer.Write(buf.NewMultiBufferValue(payload)), IsNil)
assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(payload)), IsNil)
assert(cache.Len(), Equals, 83360)
assert(writer.Write(buf.MultiBuffer{}), IsNil)
assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil)
assert(err, IsNil)
reader := NewAuthenticationReader(&AEADAuthenticator{
@ -58,7 +58,7 @@ func TestAuthenticationReaderWriter(t *testing.T) {
var mb buf.MultiBuffer
for mb.Len() < len(rawPayload) {
mb2, err := reader.Read()
mb2, err := reader.ReadMultiBuffer()
assert(err, IsNil)
mb.AppendMulti(mb2)
@ -68,7 +68,7 @@ func TestAuthenticationReaderWriter(t *testing.T) {
mb.Read(mbContent)
assert(mbContent, Equals, rawPayload)
_, err = reader.Read()
_, err = reader.ReadMultiBuffer()
assert(err, Equals, io.EOF)
}
@ -104,9 +104,9 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
pb2.Append([]byte("efgh"))
payload.Append(pb2)
assert(writer.Write(payload), IsNil)
assert(writer.WriteMultiBuffer(payload), IsNil)
assert(cache.Len(), GreaterThan, 0)
assert(writer.Write(buf.MultiBuffer{}), IsNil)
assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil)
assert(err, IsNil)
reader := NewAuthenticationReader(&AEADAuthenticator{
@ -117,7 +117,7 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
AdditionalDataGenerator: &NoOpBytesGenerator{},
}, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket)
mb, err := reader.Read()
mb, err := reader.ReadMultiBuffer()
assert(err, IsNil)
b1 := mb.SplitFirst()
@ -126,6 +126,6 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
assert(b2.String(), Equals, "efgh")
assert(mb.IsEmpty(), IsTrue)
_, err = reader.Read()
_, err = reader.ReadMultiBuffer()
assert(err, Equals, io.EOF)
}

View File

@ -56,7 +56,7 @@ func (r *ChunkStreamReader) readAtLeast(size int) error {
mb := r.leftOver
r.leftOver = nil
for mb.Len() < size {
extra, err := r.reader.Read()
extra, err := r.reader.ReadMultiBuffer()
if err != nil {
mb.Release()
return err
@ -78,7 +78,7 @@ func (r *ChunkStreamReader) readSize() (uint16, error) {
return r.sizeDecoder.Decode(r.buffer)
}
func (r *ChunkStreamReader) Read() (buf.MultiBuffer, error) {
func (r *ChunkStreamReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
size := r.leftOverSize
if size == 0 {
nextSize, err := r.readSize()
@ -129,7 +129,7 @@ func NewChunkStreamWriter(sizeEncoder ChunkSizeEncoder, writer io.Writer) *Chunk
}
}
func (w *ChunkStreamWriter) Write(mb buf.MultiBuffer) error {
func (w *ChunkStreamWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
const sliceSize = 8192
mbLen := mb.Len()
mb2Write := buf.NewMultiBufferCap(mbLen/buf.Size + mbLen/sliceSize + 2)
@ -150,5 +150,5 @@ func (w *ChunkStreamWriter) Write(mb buf.MultiBuffer) error {
}
}
return w.writer.Write(mb2Write)
return w.writer.WriteMultiBuffer(mb2Write)
}

View File

@ -19,26 +19,26 @@ func TestChunkStreamIO(t *testing.T) {
b := buf.New()
b.AppendBytes('a', 'b', 'c', 'd')
assert(writer.Write(buf.NewMultiBufferValue(b)), IsNil)
assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)), IsNil)
b = buf.New()
b.AppendBytes('e', 'f', 'g')
assert(writer.Write(buf.NewMultiBufferValue(b)), IsNil)
assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)), IsNil)
assert(writer.Write(buf.MultiBuffer{}), IsNil)
assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil)
assert(cache.Len(), Equals, 13)
mb, err := reader.Read()
mb, err := reader.ReadMultiBuffer()
assert(err, IsNil)
assert(mb.Len(), Equals, 4)
assert(mb[0].Bytes(), Equals, []byte("abcd"))
mb, err = reader.Read()
mb, err = reader.ReadMultiBuffer()
assert(err, IsNil)
assert(mb.Len(), Equals, 3)
assert(mb[0].Bytes(), Equals, []byte("efg"))
_, err = reader.Read()
_, err = reader.ReadMultiBuffer()
assert(err, Equals, io.EOF)
}

View File

@ -28,7 +28,7 @@ func (r *CryptionReader) Read(data []byte) (int, error) {
}
var (
_ buf.MultiBufferWriter = (*CryptionWriter)(nil)
_ buf.Writer = (*CryptionWriter)(nil)
)
type CryptionWriter struct {

View File

@ -29,7 +29,7 @@ func (*NoneResponse) WriteTo(buf.Writer) {}
func (*HTTPResponse) WriteTo(writer buf.Writer) {
b := buf.NewLocal(512)
common.Must(b.AppendSupplier(serial.WriteString(http403response)))
writer.Write(buf.NewMultiBufferValue(b))
writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
}
// GetInternalResponse converts response settings from proto to internal data structure.

View File

@ -255,15 +255,18 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, rea
requestDone := signal.ExecuteAsync(func() error {
request.Header.Set("Connection", "close")
requestWriter := buf.ToBytesWriter(ray.InboundInput())
requestWriter := buf.NewBufferedWriter(ray.InboundInput())
if err := request.Write(requestWriter); err != nil {
return err
}
if err := requestWriter.Flush(); err != nil {
return err
}
return nil
})
responseDone := signal.ExecuteAsync(func() error {
responseReader := bufio.NewReaderSize(buf.ToBytesReader(ray.InboundOutput()), 2048)
responseReader := bufio.NewReaderSize(buf.NewBufferedReader(ray.InboundOutput()), 2048)
response, err := http.ReadResponse(responseReader, request)
if err == nil {
StripHopByHopHeaders(response.Header)

View File

@ -93,7 +93,7 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5)
if request.Command == protocol.RequestCommandTCP {
bufferedWriter := buf.NewBufferedWriter(conn)
bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
bodyWriter, err := WriteTCPRequest(request, bufferedWriter)
if err != nil {
return newError("failed to write request").Base(err)

View File

@ -68,7 +68,7 @@ func NewChunkReader(reader io.Reader, auth *Authenticator) *ChunkReader {
}
}
func (v *ChunkReader) Read() (buf.MultiBuffer, error) {
func (v *ChunkReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
buffer := buf.New()
if err := buffer.AppendSupplier(buf.ReadFullFrom(v.reader, 2)); err != nil {
buffer.Release()
@ -117,8 +117,8 @@ func NewChunkWriter(writer io.Writer, auth *Authenticator) *ChunkWriter {
}
}
// Write implements buf.MultiBufferWriter.
func (w *ChunkWriter) Write(mb buf.MultiBuffer) error {
// WriteMultiBuffer implements buf.Writer.
func (w *ChunkWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
defer mb.Release()
for {

View File

@ -16,7 +16,7 @@ func TestNormalChunkReading(t *testing.T) {
0, 8, 39, 228, 69, 96, 133, 39, 254, 26, 201, 70, 11, 12, 13, 14, 15, 16, 17, 18)
reader := NewChunkReader(buffer, NewAuthenticator(ChunkKeyGenerator(
[]byte{21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36})))
payload, err := reader.Read()
payload, err := reader.ReadMultiBuffer()
assert(err, IsNil)
assert(payload[0].Bytes(), Equals, []byte{11, 12, 13, 14, 15, 16, 17, 18})
}
@ -30,7 +30,7 @@ func TestNormalChunkWriting(t *testing.T) {
b := buf.NewLocal(256)
b.Append([]byte{11, 12, 13, 14, 15, 16, 17, 18})
err := writer.Write(buf.NewMultiBufferValue(b))
err := writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
assert(err, IsNil)
assert(buffer.Bytes(), Equals, []byte{0, 8, 39, 228, 69, 96, 133, 39, 254, 26, 201, 70, 11, 12, 13, 14, 15, 16, 17, 18})
}

View File

@ -362,7 +362,7 @@ type UDPReader struct {
User *protocol.User
}
func (v *UDPReader) Read() (buf.MultiBuffer, error) {
func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
buffer := buf.New()
err := buffer.AppendSupplier(buf.ReadFrom(v.Reader))
if err != nil {

View File

@ -112,14 +112,14 @@ func TestTCPRequest(t *testing.T) {
writer, err := WriteTCPRequest(request, cache)
assert(err, IsNil)
assert(writer.Write(buf.NewMultiBufferValue(data)), IsNil)
assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(data)), IsNil)
decodedRequest, reader, err := ReadTCPSession(request.User, cache)
assert(err, IsNil)
assert(decodedRequest.Address, Equals, request.Address)
assert(decodedRequest.Port, Equals, request.Port)
decodedData, err := reader.Read()
decodedData, err := reader.ReadMultiBuffer()
assert(err, IsNil)
assert(decodedData[0].String(), Equals, string(payload))
}
@ -158,19 +158,19 @@ func TestUDPReaderWriter(t *testing.T) {
b := buf.New()
b.AppendSupplier(serial.WriteString("test payload"))
err := writer.Write(buf.NewMultiBufferValue(b))
err := writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
assert(err, IsNil)
payload, err := reader.Read()
payload, err := reader.ReadMultiBuffer()
assert(err, IsNil)
assert(payload[0].String(), Equals, "test payload")
b = buf.New()
b.AppendSupplier(serial.WriteString("test payload 2"))
err = writer.Write(buf.NewMultiBufferValue(b))
err = writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
assert(err, IsNil)
payload, err = reader.Read()
payload, err = reader.ReadMultiBuffer()
assert(err, IsNil)
assert(payload[0].String(), Equals, "test payload 2")
}

View File

@ -74,7 +74,7 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
reader := buf.NewReader(conn)
for {
mpayload, err := reader.Read()
mpayload, err := reader.ReadMultiBuffer()
if err != nil {
break
}
@ -129,7 +129,7 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, dispatcher dispatcher.Interface) error {
conn.SetReadDeadline(time.Now().Add(time.Second * 8))
bufferedReader := buf.NewBufferedReader(conn)
bufferedReader := buf.NewBufferedReader(buf.NewReader(conn))
request, bodyReader, err := ReadTCPSession(s.user, bufferedReader)
if err != nil {
log.Access(conn.RemoteAddr(), "", log.AccessRejected, err)
@ -153,17 +153,17 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
}
responseDone := signal.ExecuteAsync(func() error {
bufferedWriter := buf.NewBufferedWriter(conn)
bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
responseWriter, err := WriteTCPResponse(request, bufferedWriter)
if err != nil {
return newError("failed to write response").Base(err)
}
payload, err := ray.InboundOutput().Read()
payload, err := ray.InboundOutput().ReadMultiBuffer()
if err != nil {
return err
}
if err := responseWriter.Write(payload); err != nil {
if err := responseWriter.WriteMultiBuffer(payload); err != nil {
return err
}
payload.Release()

View File

@ -352,7 +352,7 @@ func NewUDPReader(reader io.Reader) *UDPReader {
return &UDPReader{reader: reader}
}
func (r *UDPReader) Read() (buf.MultiBuffer, error) {
func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
b := buf.New()
if err := b.AppendSupplier(buf.ReadFrom(r.reader)); err != nil {
return nil, err

View File

@ -24,11 +24,11 @@ func TestUDPEncoding(t *testing.T) {
content := []byte{'a'}
payload := buf.New()
payload.Append(content)
assert(writer.Write(buf.NewMultiBufferValue(payload)), IsNil)
assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(payload)), IsNil)
reader := NewUDPReader(b)
decodedPayload, err := reader.Read()
decodedPayload, err := reader.ReadMultiBuffer()
assert(err, IsNil)
assert(decodedPayload[0].Bytes(), Equals, content)
}

View File

@ -58,7 +58,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispatcher dispatcher.Interface) error {
conn.SetReadDeadline(time.Now().Add(time.Second * 8))
reader := buf.NewBufferedReader(conn)
reader := buf.NewBufferedReader(buf.NewReader(conn))
inboundDest, ok := proxy.InboundEntryPointFromContext(ctx)
if !ok {
@ -154,7 +154,7 @@ func (v *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
reader := buf.NewReader(conn)
for {
mpayload, err := reader.Read()
mpayload, err := reader.ReadMultiBuffer()
if err != nil {
return err
}

View File

@ -142,12 +142,12 @@ func transferResponse(timer signal.ActivityUpdater, session *encoding.ServerSess
bodyWriter := session.EncodeResponseBody(request, output)
// Optimize for small response packet
data, err := input.Read()
data, err := input.ReadMultiBuffer()
if err != nil {
return err
}
if err := bodyWriter.Write(data); err != nil {
if err := bodyWriter.WriteMultiBuffer(data); err != nil {
return err
}
data.Release()
@ -163,7 +163,7 @@ func transferResponse(timer signal.ActivityUpdater, session *encoding.ServerSess
}
if request.Option.Has(protocol.RequestOptionChunkStream) {
if err := bodyWriter.Write(buf.MultiBuffer{}); err != nil {
if err := bodyWriter.WriteMultiBuffer(buf.MultiBuffer{}); err != nil {
return err
}
}
@ -177,7 +177,7 @@ func (v *Handler) Process(ctx context.Context, network net.Network, connection i
return err
}
reader := buf.NewBufferedReader(connection)
reader := buf.NewBufferedReader(buf.NewReader(connection))
session := encoding.NewServerSession(v.clients, v.sessionHistory)
request, err := session.DecodeRequestHeader(reader)
@ -213,14 +213,12 @@ func (v *Handler) Process(ctx context.Context, network net.Network, connection i
input := ray.InboundInput()
output := ray.InboundOutput()
reader.SetBuffered(false)
requestDone := signal.ExecuteAsync(func() error {
return transferRequest(timer, session, request, reader, input)
})
responseDone := signal.ExecuteAsync(func() error {
writer := buf.NewBufferedWriter(connection)
writer := buf.NewBufferedWriter(buf.NewWriter(connection))
defer writer.Flush()
response := &protocol.ResponseHeader{

View File

@ -106,7 +106,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5)
requestDone := signal.ExecuteAsync(func() error {
writer := buf.NewBufferedWriter(conn)
writer := buf.NewBufferedWriter(buf.NewWriter(conn))
if err := session.EncodeRequestHeader(request, writer); err != nil {
return newError("failed to encode request").Base(err).AtWarning()
}
@ -117,7 +117,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
return newError("failed to get first payload").Base(err)
}
if !firstPayload.IsEmpty() {
if err := bodyWriter.Write(firstPayload); err != nil {
if err := bodyWriter.WriteMultiBuffer(firstPayload); err != nil {
return newError("failed to write first payload").Base(err)
}
firstPayload.Release()
@ -132,7 +132,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
}
if request.Option.Has(protocol.RequestOptionChunkStream) {
if err := bodyWriter.Write(buf.MultiBuffer{}); err != nil {
if err := bodyWriter.WriteMultiBuffer(buf.MultiBuffer{}); err != nil {
return err
}
}
@ -142,7 +142,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
responseDone := signal.ExecuteAsync(func() error {
defer output.Close()
reader := buf.NewBufferedReader(conn)
reader := buf.NewBufferedReader(buf.NewReader(conn))
header, err := session.DecodeResponseHeader(reader)
if err != nil {
return err

View File

@ -169,8 +169,7 @@ type SystemConnection interface {
}
var (
_ buf.MultiBufferReader = (*Connection)(nil)
_ buf.MultiBufferWriter = (*Connection)(nil)
_ buf.Reader = (*Connection)(nil)
)
// Connection is a KCP connection over UDP.
@ -265,7 +264,7 @@ func (v *Connection) OnDataOutput() {
}
}
// ReadMultiBuffer implements buf.MultiBufferReader.
// ReadMultiBuffer implements buf.Reader.
func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
if v == nil {
return nil, io.EOF
@ -375,13 +374,6 @@ func (v *Connection) Write(b []byte) (int, error) {
}
}
func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
if c.mergingWriter == nil {
c.mergingWriter = buf.NewMergingWriterSize(c, c.mss)
}
return c.mergingWriter.Write(mb)
}
func (v *Connection) SetState(state State) {
current := v.Elapsed()
atomic.StoreInt32((*int32)(&v.state), int32(state))

View File

@ -10,29 +10,23 @@ import (
//go:generate go run $GOPATH/src/v2ray.com/core/tools/generrorgen/main.go -pkg tls -path Transport,Internet,TLS
var (
_ buf.MultiBufferReader = (*conn)(nil)
_ buf.MultiBufferWriter = (*conn)(nil)
_ buf.Writer = (*conn)(nil)
)
type conn struct {
net.Conn
mergingReader buf.Reader
mergingWriter buf.Writer
}
func (c *conn) ReadMultiBuffer() (buf.MultiBuffer, error) {
if c.mergingReader == nil {
c.mergingReader = buf.NewBytesToBufferReader(c.Conn)
}
return c.mergingReader.Read()
mergingWriter *buf.BufferedWriter
}
func (c *conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
if c.mergingWriter == nil {
c.mergingWriter = buf.NewMergingWriter(c.Conn)
c.mergingWriter = buf.NewBufferedWriter(buf.NewWriter(c.Conn))
}
return c.mergingWriter.Write(mb)
if err := c.mergingWriter.WriteMultiBuffer(mb); err != nil {
return err
}
return c.mergingWriter.Flush()
}
func Client(c net.Conn, config *tls.Config) net.Conn {

View File

@ -57,7 +57,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
inboundRay, existing := v.getInboundRay(ctx, destination)
outputStream := inboundRay.InboundInput()
if outputStream != nil {
if err := outputStream.Write(buf.NewMultiBufferValue(payload)); err != nil {
if err := outputStream.WriteMultiBuffer(buf.NewMultiBufferValue(payload)); err != nil {
v.RemoveRay(destination)
}
}
@ -71,7 +71,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
func handleInput(input ray.InputStream, callback ResponseCallback) {
for {
mb, err := input.Read()
mb, err := input.ReadMultiBuffer()
if err != nil {
break
}

View File

@ -28,11 +28,11 @@ func TestSameDestinationDispatching(t *testing.T) {
link := ray.NewRay(ctx)
go func() {
for {
data, err := link.OutboundInput().Read()
data, err := link.OutboundInput().ReadMultiBuffer()
if err != nil {
break
}
err = link.OutboundOutput().Write(data)
err = link.OutboundOutput().WriteMultiBuffer(data)
assert(err, IsNil)
}
}()

View File

@ -11,8 +11,7 @@ import (
)
var (
_ buf.MultiBufferReader = (*connection)(nil)
_ buf.MultiBufferWriter = (*connection)(nil)
_ buf.Writer = (*connection)(nil)
)
// connection is a wrapper for net.Conn over WebSocket connection.
@ -20,8 +19,7 @@ type connection struct {
conn *websocket.Conn
reader io.Reader
mergingReader buf.Reader
mergingWriter buf.Writer
mergingWriter *buf.BufferedWriter
}
func newConnection(conn *websocket.Conn) *connection {
@ -47,13 +45,6 @@ func (c *connection) Read(b []byte) (int, error) {
}
}
func (c *connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
if c.mergingReader == nil {
c.mergingReader = buf.NewBytesToBufferReader(c)
}
return c.mergingReader.Read()
}
func (c *connection) getReader() (io.Reader, error) {
if c.reader != nil {
return c.reader, nil
@ -77,9 +68,12 @@ func (c *connection) Write(b []byte) (int, error) {
func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
if c.mergingWriter == nil {
c.mergingWriter = buf.NewMergingWriter(c)
c.mergingWriter = buf.NewBufferedWriter(buf.NewBufferToBytesWriter(c))
}
return c.mergingWriter.Write(mb)
if err := c.mergingWriter.WriteMultiBuffer(mb); err != nil {
return err
}
return c.mergingWriter.Flush()
}
func (c *connection) Close() error {

View File

@ -106,7 +106,7 @@ func (s *Stream) Peek(b *buf.Buffer) {
}
// Read reads data from the Stream.
func (s *Stream) Read() (buf.MultiBuffer, error) {
func (s *Stream) ReadMultiBuffer() (buf.MultiBuffer, error) {
for {
mb, err := s.getData()
if err != nil {
@ -178,7 +178,7 @@ func (s *Stream) waitForStreamSize() error {
}
// Write writes more data into the Stream.
func (s *Stream) Write(data buf.MultiBuffer) error {
func (s *Stream) WriteMultiBuffer(data buf.MultiBuffer) error {
if data.IsEmpty() {
return nil
}

View File

@ -16,18 +16,18 @@ func TestStreamIO(t *testing.T) {
stream := NewStream(context.Background())
b1 := buf.New()
b1.AppendBytes('a')
assert(stream.Write(buf.NewMultiBufferValue(b1)), IsNil)
assert(stream.WriteMultiBuffer(buf.NewMultiBufferValue(b1)), IsNil)
_, err := stream.Read()
_, err := stream.ReadMultiBuffer()
assert(err, IsNil)
stream.Close()
_, err = stream.Read()
_, err = stream.ReadMultiBuffer()
assert(err, Equals, io.EOF)
b2 := buf.New()
b2.AppendBytes('b')
err = stream.Write(buf.NewMultiBufferValue(b2))
err = stream.WriteMultiBuffer(buf.NewMultiBufferValue(b2))
assert(err, Equals, io.ErrClosedPipe)
}
@ -37,13 +37,13 @@ func TestStreamClose(t *testing.T) {
stream := NewStream(context.Background())
b1 := buf.New()
b1.AppendBytes('a')
assert(stream.Write(buf.NewMultiBufferValue(b1)), IsNil)
assert(stream.WriteMultiBuffer(buf.NewMultiBufferValue(b1)), IsNil)
stream.Close()
_, err := stream.Read()
_, err := stream.ReadMultiBuffer()
assert(err, IsNil)
_, err = stream.Read()
_, err = stream.ReadMultiBuffer()
assert(err, Equals, io.EOF)
}