From 2f565bfd5eb40dc98dcf8defeec437922be68a1b Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sun, 16 Apr 2017 09:57:28 +0200 Subject: [PATCH] simplify ray stream --- common/buf/merge_reader.go | 34 ------- common/buf/merge_reader_test.go | 33 ------ common/buf/multi_buffer.go | 2 +- proxy/freedom/freedom.go | 26 ++++- proxy/shadowsocks/client.go | 3 +- proxy/shadowsocks/server.go | 5 +- proxy/vmess/inbound/inbound.go | 8 +- proxy/vmess/outbound/outbound.go | 7 +- transport/ray/direct.go | 170 ++++++++++++++++--------------- 9 files changed, 120 insertions(+), 168 deletions(-) delete mode 100644 common/buf/merge_reader.go delete mode 100644 common/buf/merge_reader_test.go diff --git a/common/buf/merge_reader.go b/common/buf/merge_reader.go deleted file mode 100644 index 9d3f43752..000000000 --- a/common/buf/merge_reader.go +++ /dev/null @@ -1,34 +0,0 @@ -package buf - -type MergingReader struct { - reader Reader - timeoutReader TimeoutReader -} - -func NewMergingReader(reader Reader) Reader { - return &MergingReader{ - reader: reader, - timeoutReader: reader.(TimeoutReader), - } -} - -func (r *MergingReader) Read() (MultiBuffer, error) { - mb, err := r.reader.Read() - if err != nil { - return nil, err - } - - if r.timeoutReader == nil { - return mb, nil - } - - for { - mb2, err := r.timeoutReader.ReadTimeout(0) - if err != nil { - break - } - mb.AppendMulti(mb2) - } - - return mb, nil -} diff --git a/common/buf/merge_reader_test.go b/common/buf/merge_reader_test.go deleted file mode 100644 index 48a0aefec..000000000 --- a/common/buf/merge_reader_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package buf_test - -import ( - "testing" - - "context" - - . "v2ray.com/core/common/buf" - "v2ray.com/core/testing/assert" - "v2ray.com/core/transport/ray" -) - -func TestMergingReader(t *testing.T) { - assert := assert.On(t) - - stream := ray.NewStream(context.Background()) - b1 := New() - b1.AppendBytes('a', 'b', 'c') - stream.Write(NewMultiBufferValue(b1)) - - b2 := New() - b2.AppendBytes('e', 'f', 'g') - stream.Write(NewMultiBufferValue(b2)) - - b3 := New() - b3.AppendBytes('h', 'i', 'j') - stream.Write(NewMultiBufferValue(b3)) - - reader := NewMergingReader(stream) - b, err := reader.Read() - assert.Error(err).IsNil() - assert.Int(b.Len()).Equals(9) -} diff --git a/common/buf/multi_buffer.go b/common/buf/multi_buffer.go index d1fe389b2..dd6b706ca 100644 --- a/common/buf/multi_buffer.go +++ b/common/buf/multi_buffer.go @@ -13,7 +13,7 @@ type MultiBufferReader interface { type MultiBuffer []*Buffer func NewMultiBuffer() MultiBuffer { - return MultiBuffer(make([]*Buffer, 0, 8)) + return MultiBuffer(make([]*Buffer, 0, 32)) } func NewMultiBufferValue(b ...*Buffer) MultiBuffer { diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index a1d7cef31..9f3209b3a 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -4,6 +4,7 @@ package freedom import ( "context" + "io" "runtime" "time" @@ -112,8 +113,13 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial ctx, timer := signal.CancelAfterInactivity(ctx, timeout) requestDone := signal.ExecuteAsync(func() error { - v2writer := buf.NewWriter(conn) - if err := buf.PipeUntilEOF(timer, input, v2writer); err != nil { + var writer buf.Writer + if destination.Network == net.Network_TCP { + writer = buf.NewWriter(conn) + } else { + writer = &seqWriter{writer: conn} + } + if err := buf.PipeUntilEOF(timer, input, writer); err != nil { return err } return nil @@ -145,3 +151,19 @@ func init() { return New(ctx, config.(*Config)) })) } + +type seqWriter struct { + writer io.Writer +} + +func (w *seqWriter) Write(mb buf.MultiBuffer) error { + defer mb.Release() + + for _, b := range mb { + if _, err := w.writer.Write(b.Bytes()); err != nil { + return err + } + } + + return nil +} diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index 8b78adaa6..5da0d01c0 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -105,8 +105,7 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale } requestDone := signal.ExecuteAsync(func() error { - mergedInput := buf.NewMergingReader(outboundRay.OutboundInput()) - if err := buf.PipeUntilEOF(timer, mergedInput, bodyWriter); err != nil { + if err := buf.PipeUntilEOF(timer, outboundRay.OutboundInput(), bodyWriter); err != nil { return err } return nil diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index a969f857f..f22a19f9b 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -160,8 +160,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, return newError("failed to write response").Base(err) } - mergeReader := buf.NewMergingReader(ray.InboundOutput()) - payload, err := mergeReader.Read() + payload, err := ray.InboundOutput().Read() if err != nil { return err } @@ -174,7 +173,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, return err } - if err := buf.PipeUntilEOF(timer, mergeReader, responseWriter); err != nil { + if err := buf.PipeUntilEOF(timer, ray.InboundOutput(), responseWriter); err != nil { return newError("failed to transport all TCP response").Base(err) } diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index d6137364f..773b4a3e2 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -140,12 +140,8 @@ func transferResponse(timer signal.ActivityTimer, session *encoding.ServerSessio bodyWriter := session.EncodeResponseBody(request, output) - var reader buf.Reader = input - if request.Command == protocol.RequestCommandTCP { - reader = buf.NewMergingReader(input) - } // Optimize for small response packet - data, err := reader.Read() + data, err := input.Read() if err != nil { return err } @@ -161,7 +157,7 @@ func transferResponse(timer signal.ActivityTimer, session *encoding.ServerSessio } } - if err := buf.PipeUntilEOF(timer, reader, bodyWriter); err != nil { + if err := buf.PipeUntilEOF(timer, input, bodyWriter); err != nil { return err } diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index 13ac94bca..544c4d322 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -123,12 +123,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial return err } - var inputReader buf.Reader = input - if request.Command == protocol.RequestCommandTCP { - inputReader = buf.NewMergingReader(input) - } - - if err := buf.PipeUntilEOF(timer, inputReader, bodyWriter); err != nil { + if err := buf.PipeUntilEOF(timer, input, bodyWriter); err != nil { return err } diff --git a/transport/ray/direct.go b/transport/ray/direct.go index 9dd0a6253..743852aeb 100644 --- a/transport/ray/direct.go +++ b/transport/ray/direct.go @@ -3,15 +3,12 @@ package ray import ( "context" "io" + "sync" "time" "v2ray.com/core/common/buf" ) -const ( - bufferSize = 512 -) - // NewRay creates a new Ray for direct traffic transport. func NewRay(ctx context.Context) Ray { return &directRay{ @@ -42,121 +39,132 @@ func (v *directRay) InboundOutput() InputStream { } type Stream struct { - buffer chan buf.MultiBuffer + access sync.Mutex + data buf.MultiBuffer ctx context.Context - close chan bool - err chan bool + wakeup chan bool + close bool + err bool } func NewStream(ctx context.Context) *Stream { return &Stream{ ctx: ctx, - buffer: make(chan buf.MultiBuffer, bufferSize), - close: make(chan bool), - err: make(chan bool), + wakeup: make(chan bool, 1), } } -func (v *Stream) Read() (buf.MultiBuffer, error) { - select { - case <-v.ctx.Done(): +func (s *Stream) getData() (buf.MultiBuffer, error) { + s.access.Lock() + defer s.access.Unlock() + + if s.data != nil { + mb := s.data + s.data = nil + return mb, nil + } + + if s.close { + return nil, io.EOF + } + + if s.err { return nil, io.ErrClosedPipe - case <-v.err: - return nil, io.ErrClosedPipe - case b := <-v.buffer: - return b, nil - default: + } + + return nil, nil +} + +func (s *Stream) Read() (buf.MultiBuffer, error) { + for { + mb, err := s.getData() + if err != nil { + return nil, err + } + + if mb != nil { + return mb, nil + } + select { - case <-v.ctx.Done(): - return nil, io.ErrClosedPipe - case b := <-v.buffer: - return b, nil - case <-v.close: - return nil, io.EOF - case <-v.err: + case <-s.ctx.Done(): return nil, io.ErrClosedPipe + case <-s.wakeup: } } } -func (v *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) { - select { - case <-v.ctx.Done(): - return nil, io.ErrClosedPipe - case <-v.err: - return nil, io.ErrClosedPipe - case b := <-v.buffer: - return b, nil - default: - if timeout == 0 { - return nil, buf.ErrReadTimeout +func (s *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) { + for { + mb, err := s.getData() + if err != nil { + return nil, err + } + + if mb != nil { + return mb, nil } select { - case <-v.ctx.Done(): - return nil, io.ErrClosedPipe - case b := <-v.buffer: - return b, nil - case <-v.close: - return nil, io.EOF - case <-v.err: + case <-s.ctx.Done(): return nil, io.ErrClosedPipe case <-time.After(timeout): return nil, buf.ErrReadTimeout + case <-s.wakeup: } } } -func (v *Stream) Write(data buf.MultiBuffer) (err error) { +func (s *Stream) Write(data buf.MultiBuffer) (err error) { if data.IsEmpty() { return } + s.access.Lock() + defer s.access.Unlock() + + if s.err { + data.Release() + return io.ErrClosedPipe + } + if s.close { + data.Release() + return io.ErrClosedPipe + } + + if s.data == nil { + s.data = data + } else { + s.data.AppendMulti(data) + } + s.wakeUp() + + return nil +} + +func (s *Stream) wakeUp() { select { - case <-v.ctx.Done(): - return io.ErrClosedPipe - case <-v.err: - return io.ErrClosedPipe - case <-v.close: - return io.ErrClosedPipe + case s.wakeup <- true: default: - select { - case <-v.ctx.Done(): - return io.ErrClosedPipe - case <-v.err: - return io.ErrClosedPipe - case <-v.close: - return io.ErrClosedPipe - case v.buffer <- data: - return nil - } } } -func (v *Stream) Close() { - defer swallowPanic() - - close(v.close) +func (s *Stream) Close() { + s.access.Lock() + s.close = true + s.wakeUp() + s.access.Unlock() } -func (v *Stream) CloseError() { - defer swallowPanic() - - close(v.err) - - n := len(v.buffer) - for i := 0; i < n; i++ { - select { - case b := <-v.buffer: - b.Release() - default: - return - } +func (s *Stream) CloseError() { + s.access.Lock() + s.err = true + if s.data != nil { + s.data.Release() + s.data = nil } + s.wakeUp() + s.access.Unlock() } func (v *Stream) Release() {} - -func swallowPanic() { - recover() -}