From 498c7dafdf4326e751604a5a2c4324ffd736aabe Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Fri, 21 Apr 2017 14:51:09 +0200 Subject: [PATCH] clean udp writer --- common/buf/io.go | 6 ++++++ common/buf/writer.go | 19 +++++++++++++++++++ proxy/freedom/freedom.go | 19 +------------------ proxy/shadowsocks/client.go | 4 ++-- proxy/shadowsocks/protocol.go | 27 ++++++++------------------- proxy/shadowsocks/protocol_test.go | 6 +++--- proxy/shadowsocks/server.go | 2 +- proxy/socks/client.go | 2 +- proxy/socks/protocol.go | 16 ++++++---------- proxy/socks/protocol_test.go | 2 +- 10 files changed, 48 insertions(+), 55 deletions(-) diff --git a/common/buf/io.go b/common/buf/io.go index 731c8b7fa..f0024b48d 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -128,6 +128,12 @@ func NewMergingWriterSize(writer io.Writer, size uint32) Writer { } } +func NewSequentialWriter(writer io.Writer) Writer { + return &seqWriter{ + writer: writer, + } +} + // ToBytesWriter converts a Writer to io.Writer func ToBytesWriter(writer Writer) io.Writer { return &bytesToBufferWriter{ diff --git a/common/buf/writer.go b/common/buf/writer.go index dc2066a38..da777619c 100644 --- a/common/buf/writer.go +++ b/common/buf/writer.go @@ -42,6 +42,25 @@ func (w *mergingWriter) Write(mb MultiBuffer) error { return nil } +type seqWriter struct { + writer io.Writer +} + +func (w *seqWriter) Write(mb MultiBuffer) error { + defer mb.Release() + + for _, b := range mb { + if b.IsEmpty() { + continue + } + if _, err := w.writer.Write(b.Bytes()); err != nil { + return err + } + } + + return nil +} + type bytesToBufferWriter struct { writer Writer } diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index a2bd0182c..147ce1a1e 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -4,7 +4,6 @@ package freedom import ( "context" - "io" "runtime" "time" @@ -117,7 +116,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial if destination.Network == net.Network_TCP { writer = buf.NewWriter(conn) } else { - writer = &seqWriter{writer: conn} + writer = buf.NewSequentialWriter(conn) } if err := buf.Copy(timer, input, writer); err != nil { return newError("failed to process request").Base(err) @@ -151,19 +150,3 @@ func init() { return New(ctx, config.(*Config)) })) } - -type seqWriter struct { - writer io.Writer -} - -func (w *seqWriter) Write(mb buf.MultiBuffer) error { - defer mb.Release() - - for _, b := range mb { - if _, err := w.writer.Write(b.Bytes()); err != nil { - return err - } - } - - return nil -} diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index 91cdea448..e6fa70ade 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -135,10 +135,10 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale if request.Command == protocol.RequestCommandUDP { - writer := &UDPWriter{ + writer := buf.NewSequentialWriter(&UDPWriter{ Writer: conn, Request: request, - } + }) requestDone := signal.ExecuteAsync(func() error { if err := buf.Copy(timer, outboundRay.OutboundInput(), writer); err != nil { diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index d4f567bf1..baaab768a 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -238,7 +238,7 @@ func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Wr return buf.NewWriter(crypto.NewCryptionWriter(stream, writer)), nil } -func EncodeUDPPacket(request *protocol.RequestHeader, payload *buf.Buffer) (*buf.Buffer, error) { +func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buffer, error) { user := request.User rawAccount, err := user.GetTypedAccount() if err != nil { @@ -266,7 +266,7 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload *buf.Buffer) (*buf } buffer.AppendSupplier(serial.WriteUint16(uint16(request.Port))) - buffer.Append(payload.Bytes()) + buffer.Append(payload) if request.Option.Has(RequestOptionOneTimeAuth) { authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv)) @@ -382,23 +382,12 @@ type UDPWriter struct { Request *protocol.RequestHeader } -func (w *UDPWriter) Write(mb buf.MultiBuffer) error { - defer mb.Release() - - for _, b := range mb { - if err := w.writeInternal(b); err != nil { - return err - } - } - return nil -} - -func (w *UDPWriter) writeInternal(buffer *buf.Buffer) error { - payload, err := EncodeUDPPacket(w.Request, buffer) +func (w *UDPWriter) Write(payload []byte) (int, error) { + packet, err := EncodeUDPPacket(w.Request, payload) if err != nil { - return err + return 0, err } - _, err = w.Writer.Write(payload.Bytes()) - payload.Release() - return err + _, err = w.Writer.Write(packet.Bytes()) + packet.Release() + return len(payload), err } diff --git a/proxy/shadowsocks/protocol_test.go b/proxy/shadowsocks/protocol_test.go index 1326cd704..d3871efd1 100644 --- a/proxy/shadowsocks/protocol_test.go +++ b/proxy/shadowsocks/protocol_test.go @@ -31,7 +31,7 @@ func TestUDPEncoding(t *testing.T) { data := buf.NewLocal(256) data.AppendSupplier(serial.WriteString("test string")) - encodedData, err := EncodeUDPPacket(request, data) + encodedData, err := EncodeUDPPacket(request, data.Bytes()) assert.Error(err).IsNil() decodedRequest, decodedData, err := DecodeUDPPacket(request.User, encodedData) @@ -88,7 +88,7 @@ func TestUDPReaderWriter(t *testing.T) { }), } cache := buf.New() - writer := &UDPWriter{ + writer := buf.NewSequentialWriter(&UDPWriter{ Writer: cache, Request: &protocol.RequestHeader{ Version: Version, @@ -97,7 +97,7 @@ func TestUDPReaderWriter(t *testing.T) { User: user, Option: RequestOptionOneTimeAuth, }, - } + }) reader := &UDPReader{ Reader: cache, diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 6ced719ee..158e89be5 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -113,7 +113,7 @@ func (v *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection udpServer.Dispatch(ctx, dest, data, func(payload *buf.Buffer) { defer payload.Release() - data, err := EncodeUDPPacket(request, payload) + data, err := EncodeUDPPacket(request, payload.Bytes()) if err != nil { log.Trace(newError("failed to encode UDP packet").Base(err).AtWarning()) return diff --git a/proxy/socks/client.go b/proxy/socks/client.go index 97eaee6af..c8414da6d 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -103,7 +103,7 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay, dialer proxy. } defer udpConn.Close() requestFunc = func() error { - return buf.Copy(timer, ray.OutboundInput(), &UDPWriter{request: request, writer: udpConn}) + return buf.Copy(timer, ray.OutboundInput(), buf.NewSequentialWriter(NewUDPWriter(request, udpConn))) } responseFunc = func() error { defer ray.OutboundOutput().Close() diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index 930e22192..2529a7f22 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -369,17 +369,13 @@ func NewUDPWriter(request *protocol.RequestHeader, writer io.Writer) *UDPWriter } } -func (w *UDPWriter) Write(mb buf.MultiBuffer) error { - defer mb.Release() - - for _, b := range mb { - eb := EncodeUDPPacket(w.request, b.Bytes()) - defer eb.Release() - if _, err := w.writer.Write(eb.Bytes()); err != nil { - return err - } +func (w *UDPWriter) Write(b []byte) (int, error) { + eb := EncodeUDPPacket(w.request, b) + defer eb.Release() + if _, err := w.writer.Write(eb.Bytes()); err != nil { + return 0, err } - return nil + return len(b), nil } func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) { diff --git a/proxy/socks/protocol_test.go b/proxy/socks/protocol_test.go index 7e2575d3f..f0111abc4 100644 --- a/proxy/socks/protocol_test.go +++ b/proxy/socks/protocol_test.go @@ -19,7 +19,7 @@ func TestUDPEncoding(t *testing.T) { Address: net.IPAddress([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}), Port: 1024, } - writer := NewUDPWriter(request, b) + writer := buf.NewSequentialWriter(NewUDPWriter(request, b)) content := []byte{'a'} payload := buf.New()