From f418b9bc20c2a5269b887640618c74ba3add378a Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Thu, 27 Apr 2017 22:20:29 +0200 Subject: [PATCH] swallow write error in mux --- app/proxyman/mux/mux.go | 20 ++++------------- app/proxyman/mux/reader.go | 46 ++++++++++++++++++++++++++++---------- common/buf/io.go | 45 ++++++++++++++++++++++++++++++++----- common/buf/writer.go | 11 +++++++++ 4 files changed, 88 insertions(+), 34 deletions(-) diff --git a/app/proxyman/mux/mux.go b/app/proxyman/mux/mux.go index 4236577a1..f22213765 100644 --- a/app/proxyman/mux/mux.go +++ b/app/proxyman/mux/mux.go @@ -175,22 +175,10 @@ func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool } func drain(reader *Reader) error { - data, err := reader.Read() - if err != nil { - return err - } - data.Release() + buf.Copy(signal.BackgroundTimer(), reader, buf.Discard) return nil } -func pipe(reader *Reader, writer buf.Writer) error { - data, err := reader.Read() - if err != nil { - return err - } - return writer.Write(data) -} - func (m *Client) handleStatueKeepAlive(meta *FrameMetadata, reader *Reader) error { if meta.Option.Has(OptionData) { return drain(reader) @@ -211,7 +199,7 @@ func (m *Client) handleStatusKeep(meta *FrameMetadata, reader *Reader) error { } if s, found := m.sessionManager.Get(meta.SessionID); found { - return pipe(reader, s.output) + return buf.Copy(signal.BackgroundTimer(), reader, s.output, buf.IgnoreWriterError()) } return drain(reader) } @@ -335,7 +323,7 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, w.sessionManager.Add(s) go handle(ctx, s, w.outboundRay.OutboundOutput()) if meta.Option.Has(OptionData) { - return pipe(reader, s.output) + return buf.Copy(signal.BackgroundTimer(), reader, s.output, buf.IgnoreWriterError()) } return nil } @@ -345,7 +333,7 @@ func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *Reader) err return nil } if s, found := w.sessionManager.Get(meta.SessionID); found { - return pipe(reader, s.output) + return buf.Copy(signal.BackgroundTimer(), reader, s.output, buf.IgnoreWriterError()) } return drain(reader) } diff --git a/app/proxyman/mux/reader.go b/app/proxyman/mux/reader.go index 8cede0a2b..278a05c49 100644 --- a/app/proxyman/mux/reader.go +++ b/app/proxyman/mux/reader.go @@ -8,18 +8,22 @@ import ( ) type Reader struct { - reader io.Reader - buffer *buf.Buffer + reader io.Reader + buffer *buf.Buffer + leftOver int } func NewReader(reader buf.Reader) *Reader { return &Reader{ - reader: buf.ToBytesReader(reader), - buffer: buf.NewLocal(1024), + reader: buf.ToBytesReader(reader), + buffer: buf.NewLocal(1024), + leftOver: -1, } } func (r *Reader) ReadMetadata() (*FrameMetadata, error) { + r.leftOver = -1 + b := r.buffer b.Clear() @@ -37,25 +41,43 @@ func (r *Reader) ReadMetadata() (*FrameMetadata, error) { return ReadFrameFrom(b.Bytes()) } -func (r *Reader) Read() (buf.MultiBuffer, error) { +func (r *Reader) readSize() error { if err := r.buffer.Reset(buf.ReadFullFrom(r.reader, 2)); err != nil { - return nil, err + return err + } + r.leftOver = int(serial.BytesToUint16(r.buffer.Bytes())) + return nil +} + +func (r *Reader) Read() (buf.MultiBuffer, error) { + if r.leftOver == 0 { + r.leftOver = -1 + return nil, io.EOF + } + if r.leftOver == -1 { + if err := r.readSize(); err != nil { + return nil, err + } } - dataLen := int(serial.BytesToUint16(r.buffer.Bytes())) mb := buf.NewMultiBuffer() - for dataLen > 0 { + for r.leftOver > 0 { readLen := buf.Size - if dataLen < readLen { - readLen = dataLen + if r.leftOver < readLen { + readLen = r.leftOver } b := buf.New() - if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, readLen)); err != nil { + if err := b.AppendSupplier(func(bb []byte) (int, error) { + return r.reader.Read(bb[:readLen]) + }); err != nil { mb.Release() return nil, err } - dataLen -= readLen + r.leftOver -= b.Len() mb.Append(b) + if b.Len() < readLen { + break + } } return mb, nil diff --git a/common/buf/io.go b/common/buf/io.go index 37d52e4a8..b18fe3768 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -47,13 +47,40 @@ func ReadAtLeastFrom(reader io.Reader, size int) Supplier { } } -func copyInternal(timer signal.ActivityTimer, reader Reader, writer Writer) error { +type copyHandler struct { + onReadError func(error) error + onData func() + onWriteError func(error) error +} + +type CopyOption func(*copyHandler) + +func IgnoreReaderError() CopyOption { + return func(handler *copyHandler) { + handler.onReadError = func(err error) error { + return nil + } + } +} + +func IgnoreWriterError() CopyOption { + return func(handler *copyHandler) { + handler.onWriteError = func(err error) error { + return nil + } + } +} + +func copyInternal(timer signal.ActivityTimer, reader Reader, writer Writer, handler copyHandler) error { for { buffer, err := reader.Read() if err != nil { - return err + if err = handler.onReadError(err); err != nil { + return err + } } + handler.onData() timer.Update() if buffer.IsEmpty() { @@ -62,16 +89,22 @@ func copyInternal(timer signal.ActivityTimer, reader Reader, writer Writer) erro } if err := writer.Write(buffer); err != nil { - buffer.Release() - return err + if err = handler.onWriteError(err); err != nil { + buffer.Release() + return err + } } } } // Copy dumps all payload from reader to writer or stops when an error occurs. // ActivityTimer gets updated as soon as there is a payload. -func Copy(timer signal.ActivityTimer, reader Reader, writer Writer) error { - err := copyInternal(timer, reader, writer) +func Copy(timer signal.ActivityTimer, reader Reader, writer Writer, options ...CopyOption) error { + handler := copyHandler{} + for _, option := range options { + option(&handler) + } + err := copyInternal(timer, reader, writer, handler) if err != nil && errors.Cause(err) != io.EOF { return err } diff --git a/common/buf/writer.go b/common/buf/writer.go index 988580d1c..ab7fbca5b 100644 --- a/common/buf/writer.go +++ b/common/buf/writer.go @@ -106,3 +106,14 @@ func (w *bytesToBufferWriter) ReadFrom(reader io.Reader) (int64, error) { } return totalBytes, nil } + +type noOpWriter struct{} + +func (noOpWriter) Write(b MultiBuffer) error { + b.Release() + return nil +} + +var ( + Discard Writer = noOpWriter{} +)