diff --git a/transport/pipe/impl.go b/transport/pipe/impl.go index 933f7208f..8956bc18b 100644 --- a/transport/pipe/impl.go +++ b/transport/pipe/impl.go @@ -1,6 +1,7 @@ package pipe import ( + "errors" "io" "sync" "time" @@ -26,9 +27,14 @@ type pipe struct { state state } +var errBufferFull = errors.New("buffer full") + func (p *pipe) getState(forRead bool) error { switch p.state { case open: + if !forRead && p.limit >= 0 && p.data.Len() > p.limit { + return errBufferFull + } return nil case closed: if forRead { @@ -105,9 +111,10 @@ func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error { } for { - if p.limit < 0 || p.data.Len()+mb.Len() <= p.limit { - defer p.readSignal.Signal() - return p.writeMultiBufferInternal(mb) + err := p.writeMultiBufferInternal(mb) + if err == nil || err != errBufferFull { + p.readSignal.Signal() + return err } <-p.writeSignal.Wait() diff --git a/transport/pipe/pipe_test.go b/transport/pipe/pipe_test.go new file mode 100644 index 000000000..9020ea303 --- /dev/null +++ b/transport/pipe/pipe_test.go @@ -0,0 +1,23 @@ +package pipe_test + +import ( + "testing" + + "v2ray.com/core/common/buf" + . "v2ray.com/core/transport/pipe" + . "v2ray.com/ext/assert" +) + +func TestPipeReadWrite(t *testing.T) { + assert := With(t) + + pReader, pWriter := New() + payload := []byte{'a', 'b', 'c', 'd'} + b := buf.New() + b.Append(payload) + assert(pWriter.WriteMultiBuffer(buf.NewMultiBufferValue(b)), IsNil) + + rb, err := pReader.ReadMultiBuffer() + assert(err, IsNil) + assert(rb.String(), Equals, b.String()) +}