simplify ray stream

This commit is contained in:
Darien Raymond 2017-04-16 09:57:28 +02:00
parent d809973621
commit 2f565bfd5e
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
9 changed files with 120 additions and 168 deletions

View File

@ -1,34 +0,0 @@
package buf
type MergingReader struct {
reader Reader
timeoutReader TimeoutReader
}
func NewMergingReader(reader Reader) Reader {
return &MergingReader{
reader: reader,
timeoutReader: reader.(TimeoutReader),
}
}
func (r *MergingReader) Read() (MultiBuffer, error) {
mb, err := r.reader.Read()
if err != nil {
return nil, err
}
if r.timeoutReader == nil {
return mb, nil
}
for {
mb2, err := r.timeoutReader.ReadTimeout(0)
if err != nil {
break
}
mb.AppendMulti(mb2)
}
return mb, nil
}

View File

@ -1,33 +0,0 @@
package buf_test
import (
"testing"
"context"
. "v2ray.com/core/common/buf"
"v2ray.com/core/testing/assert"
"v2ray.com/core/transport/ray"
)
func TestMergingReader(t *testing.T) {
assert := assert.On(t)
stream := ray.NewStream(context.Background())
b1 := New()
b1.AppendBytes('a', 'b', 'c')
stream.Write(NewMultiBufferValue(b1))
b2 := New()
b2.AppendBytes('e', 'f', 'g')
stream.Write(NewMultiBufferValue(b2))
b3 := New()
b3.AppendBytes('h', 'i', 'j')
stream.Write(NewMultiBufferValue(b3))
reader := NewMergingReader(stream)
b, err := reader.Read()
assert.Error(err).IsNil()
assert.Int(b.Len()).Equals(9)
}

View File

@ -13,7 +13,7 @@ type MultiBufferReader interface {
type MultiBuffer []*Buffer
func NewMultiBuffer() MultiBuffer {
return MultiBuffer(make([]*Buffer, 0, 8))
return MultiBuffer(make([]*Buffer, 0, 32))
}
func NewMultiBufferValue(b ...*Buffer) MultiBuffer {

View File

@ -4,6 +4,7 @@ package freedom
import (
"context"
"io"
"runtime"
"time"
@ -112,8 +113,13 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
ctx, timer := signal.CancelAfterInactivity(ctx, timeout)
requestDone := signal.ExecuteAsync(func() error {
v2writer := buf.NewWriter(conn)
if err := buf.PipeUntilEOF(timer, input, v2writer); err != nil {
var writer buf.Writer
if destination.Network == net.Network_TCP {
writer = buf.NewWriter(conn)
} else {
writer = &seqWriter{writer: conn}
}
if err := buf.PipeUntilEOF(timer, input, writer); err != nil {
return err
}
return nil
@ -145,3 +151,19 @@ func init() {
return New(ctx, config.(*Config))
}))
}
type seqWriter struct {
writer io.Writer
}
func (w *seqWriter) Write(mb buf.MultiBuffer) error {
defer mb.Release()
for _, b := range mb {
if _, err := w.writer.Write(b.Bytes()); err != nil {
return err
}
}
return nil
}

View File

@ -105,8 +105,7 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
}
requestDone := signal.ExecuteAsync(func() error {
mergedInput := buf.NewMergingReader(outboundRay.OutboundInput())
if err := buf.PipeUntilEOF(timer, mergedInput, bodyWriter); err != nil {
if err := buf.PipeUntilEOF(timer, outboundRay.OutboundInput(), bodyWriter); err != nil {
return err
}
return nil

View File

@ -160,8 +160,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
return newError("failed to write response").Base(err)
}
mergeReader := buf.NewMergingReader(ray.InboundOutput())
payload, err := mergeReader.Read()
payload, err := ray.InboundOutput().Read()
if err != nil {
return err
}
@ -174,7 +173,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
return err
}
if err := buf.PipeUntilEOF(timer, mergeReader, responseWriter); err != nil {
if err := buf.PipeUntilEOF(timer, ray.InboundOutput(), responseWriter); err != nil {
return newError("failed to transport all TCP response").Base(err)
}

View File

@ -140,12 +140,8 @@ func transferResponse(timer signal.ActivityTimer, session *encoding.ServerSessio
bodyWriter := session.EncodeResponseBody(request, output)
var reader buf.Reader = input
if request.Command == protocol.RequestCommandTCP {
reader = buf.NewMergingReader(input)
}
// Optimize for small response packet
data, err := reader.Read()
data, err := input.Read()
if err != nil {
return err
}
@ -161,7 +157,7 @@ func transferResponse(timer signal.ActivityTimer, session *encoding.ServerSessio
}
}
if err := buf.PipeUntilEOF(timer, reader, bodyWriter); err != nil {
if err := buf.PipeUntilEOF(timer, input, bodyWriter); err != nil {
return err
}

View File

@ -123,12 +123,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
return err
}
var inputReader buf.Reader = input
if request.Command == protocol.RequestCommandTCP {
inputReader = buf.NewMergingReader(input)
}
if err := buf.PipeUntilEOF(timer, inputReader, bodyWriter); err != nil {
if err := buf.PipeUntilEOF(timer, input, bodyWriter); err != nil {
return err
}

View File

@ -3,15 +3,12 @@ package ray
import (
"context"
"io"
"sync"
"time"
"v2ray.com/core/common/buf"
)
const (
bufferSize = 512
)
// NewRay creates a new Ray for direct traffic transport.
func NewRay(ctx context.Context) Ray {
return &directRay{
@ -42,121 +39,132 @@ func (v *directRay) InboundOutput() InputStream {
}
type Stream struct {
buffer chan buf.MultiBuffer
access sync.Mutex
data buf.MultiBuffer
ctx context.Context
close chan bool
err chan bool
wakeup chan bool
close bool
err bool
}
func NewStream(ctx context.Context) *Stream {
return &Stream{
ctx: ctx,
buffer: make(chan buf.MultiBuffer, bufferSize),
close: make(chan bool),
err: make(chan bool),
wakeup: make(chan bool, 1),
}
}
func (v *Stream) Read() (buf.MultiBuffer, error) {
select {
case <-v.ctx.Done():
func (s *Stream) getData() (buf.MultiBuffer, error) {
s.access.Lock()
defer s.access.Unlock()
if s.data != nil {
mb := s.data
s.data = nil
return mb, nil
}
if s.close {
return nil, io.EOF
}
if s.err {
return nil, io.ErrClosedPipe
case <-v.err:
return nil, io.ErrClosedPipe
case b := <-v.buffer:
return b, nil
default:
}
return nil, nil
}
func (s *Stream) Read() (buf.MultiBuffer, error) {
for {
mb, err := s.getData()
if err != nil {
return nil, err
}
if mb != nil {
return mb, nil
}
select {
case <-v.ctx.Done():
return nil, io.ErrClosedPipe
case b := <-v.buffer:
return b, nil
case <-v.close:
return nil, io.EOF
case <-v.err:
case <-s.ctx.Done():
return nil, io.ErrClosedPipe
case <-s.wakeup:
}
}
}
func (v *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
select {
case <-v.ctx.Done():
return nil, io.ErrClosedPipe
case <-v.err:
return nil, io.ErrClosedPipe
case b := <-v.buffer:
return b, nil
default:
if timeout == 0 {
return nil, buf.ErrReadTimeout
func (s *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
for {
mb, err := s.getData()
if err != nil {
return nil, err
}
if mb != nil {
return mb, nil
}
select {
case <-v.ctx.Done():
return nil, io.ErrClosedPipe
case b := <-v.buffer:
return b, nil
case <-v.close:
return nil, io.EOF
case <-v.err:
case <-s.ctx.Done():
return nil, io.ErrClosedPipe
case <-time.After(timeout):
return nil, buf.ErrReadTimeout
case <-s.wakeup:
}
}
}
func (v *Stream) Write(data buf.MultiBuffer) (err error) {
func (s *Stream) Write(data buf.MultiBuffer) (err error) {
if data.IsEmpty() {
return
}
s.access.Lock()
defer s.access.Unlock()
if s.err {
data.Release()
return io.ErrClosedPipe
}
if s.close {
data.Release()
return io.ErrClosedPipe
}
if s.data == nil {
s.data = data
} else {
s.data.AppendMulti(data)
}
s.wakeUp()
return nil
}
func (s *Stream) wakeUp() {
select {
case <-v.ctx.Done():
return io.ErrClosedPipe
case <-v.err:
return io.ErrClosedPipe
case <-v.close:
return io.ErrClosedPipe
case s.wakeup <- true:
default:
select {
case <-v.ctx.Done():
return io.ErrClosedPipe
case <-v.err:
return io.ErrClosedPipe
case <-v.close:
return io.ErrClosedPipe
case v.buffer <- data:
return nil
}
}
}
func (v *Stream) Close() {
defer swallowPanic()
close(v.close)
func (s *Stream) Close() {
s.access.Lock()
s.close = true
s.wakeUp()
s.access.Unlock()
}
func (v *Stream) CloseError() {
defer swallowPanic()
close(v.err)
n := len(v.buffer)
for i := 0; i < n; i++ {
select {
case b := <-v.buffer:
b.Release()
default:
return
}
func (s *Stream) CloseError() {
s.access.Lock()
s.err = true
if s.data != nil {
s.data.Release()
s.data = nil
}
s.wakeUp()
s.access.Unlock()
}
func (v *Stream) Release() {}
func swallowPanic() {
recover()
}