From ade88fd5c7feffa7e68fa4218504085745c5d4d4 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Wed, 24 May 2017 00:54:30 +0200 Subject: [PATCH] reuse buffered writer in auth writer --- common/buf/buffered_writer.go | 31 +++++++------ common/buf/buffered_writer_test.go | 1 + common/crypto/auth.go | 70 +++++++++++++----------------- proxy/vmess/inbound/inbound.go | 4 +- testing/scenarios/dokodemo_test.go | 13 ++++++ 5 files changed, 63 insertions(+), 56 deletions(-) diff --git a/common/buf/buffered_writer.go b/common/buf/buffered_writer.go index 6faed0dbd..2ab504dbf 100644 --- a/common/buf/buffered_writer.go +++ b/common/buf/buffered_writer.go @@ -11,10 +11,14 @@ type BufferedWriter struct { } // NewBufferedWriter creates a new BufferedWriter. -func NewBufferedWriter(rawWriter io.Writer) *BufferedWriter { +func NewBufferedWriter(writer io.Writer) *BufferedWriter { + return NewBufferedWriterSize(writer, 1024) +} + +func NewBufferedWriterSize(writer io.Writer, size uint32) *BufferedWriter { return &BufferedWriter{ - writer: rawWriter, - buffer: NewLocal(1024), + writer: writer, + buffer: NewLocal(int(size)), buffered: true, } } @@ -24,21 +28,20 @@ func (w *BufferedWriter) Write(b []byte) (int, error) { if !w.buffered || w.buffer == nil { return w.writer.Write(b) } - nBytes, err := w.buffer.Write(b) - if err != nil { - return 0, err - } - if w.buffer.IsFull() { - if err := w.Flush(); err != nil { - return 0, err + bytesWritten := 0 + for bytesWritten < len(b) { + nBytes, err := w.buffer.Write(b[bytesWritten:]) + if err != nil { + return bytesWritten, err } - if nBytes < len(b) { - if _, err := w.writer.Write(b[nBytes:]); err != nil { - return nBytes, err + bytesWritten += nBytes + if w.buffer.IsFull() { + if err := w.Flush(); err != nil { + return bytesWritten, err } } } - return len(b), nil + return bytesWritten, nil } // Flush writes all buffered content into underlying writer, if any. diff --git a/common/buf/buffered_writer_test.go b/common/buf/buffered_writer_test.go index d22db443b..c2254dd54 100644 --- a/common/buf/buffered_writer_test.go +++ b/common/buf/buffered_writer_test.go @@ -47,6 +47,7 @@ func TestBufferedWriterLargePayload(t *testing.T) { nBytes, err = writer.Write(payload[512:]) assert.Error(err).IsNil() + assert.Error(writer.Flush()).IsNil() assert.Int(nBytes).Equals(64*1024 - 512) assert.Bytes(content.Bytes()).Equals(payload) } diff --git a/common/crypto/auth.go b/common/crypto/auth.go index fd1b735ea..d268809f2 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -4,6 +4,7 @@ import ( "crypto/cipher" "io" + "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/protocol" ) @@ -123,7 +124,12 @@ func (r *AuthenticationReader) readChunk(waitForData bool) ([]byte, error) { if !waitForData { return nil, io.ErrNoProgress } - r.buffer.Reset(buf.ReadFrom(r.buffer)) + + if r.buffer.IsEmpty() { + r.buffer.Clear() + } else { + common.Must(r.buffer.Reset(buf.ReadFrom(r.buffer))) + } delta := r.size - r.buffer.Len() if err := r.buffer.AppendSupplier(buf.ReadAtLeastFrom(r.reader, delta)); err != nil { @@ -184,42 +190,39 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) { type AuthenticationWriter struct { auth Authenticator + buffer []byte payload []byte - buffer *buf.Buffer - writer io.Writer + writer *buf.BufferedWriter sizeParser ChunkSizeEncoder transferType protocol.TransferType } func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, transferType protocol.TransferType) *AuthenticationWriter { + const payloadSize = 1024 return &AuthenticationWriter{ auth: auth, - payload: make([]byte, 1024), - buffer: buf.NewLocal(readerBufferSize), - writer: writer, + buffer: make([]byte, payloadSize+sizeParser.SizeBytes()+auth.Overhead()), + payload: make([]byte, payloadSize), + writer: buf.NewBufferedWriterSize(writer, readerBufferSize), sizeParser: sizeParser, transferType: transferType, } } -func (w *AuthenticationWriter) append(b []byte) { +func (w *AuthenticationWriter) append(b []byte) error { encryptedSize := len(b) + w.auth.Overhead() + buffer := w.sizeParser.Encode(uint16(encryptedSize), w.buffer[:0]) - w.buffer.AppendSupplier(func(bb []byte) (int, error) { - w.sizeParser.Encode(uint16(encryptedSize), bb[:0]) - return w.sizeParser.SizeBytes(), nil - }) + buffer, err := w.auth.Seal(buffer, b) + if err != nil { + return err + } - w.buffer.AppendSupplier(func(bb []byte) (int, error) { - w.auth.Seal(bb[:0], b) - return encryptedSize, nil - }) -} + if _, err := w.writer.Write(buffer); err != nil { + return err + } -func (w *AuthenticationWriter) flush() error { - _, err := w.writer.Write(w.buffer.Bytes()) - w.buffer.Clear() - return err + return nil } func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error { @@ -227,21 +230,15 @@ func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error { for { n, _ := mb.Read(w.payload) - w.append(w.payload[:n]) - if w.buffer.Len() > readerBufferSize-2*1024 { - if err := w.flush(); err != nil { - return err - } + if err := w.append(w.payload[:n]); err != nil { + return err } if mb.IsEmpty() { break } } - if !w.buffer.IsEmpty() { - return w.flush() - } - return nil + return w.writer.Flush() } func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error { @@ -252,24 +249,17 @@ func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error { if b == nil { b = buf.New() } - if w.buffer.Len() > readerBufferSize-b.Len()-128 { - if err := w.flush(); err != nil { - b.Release() - return err - } + if err := w.append(b.Bytes()); err != nil { + b.Release() + return err } - w.append(b.Bytes()) b.Release() if mb.IsEmpty() { break } } - if !w.buffer.IsEmpty() { - return w.flush() - } - - return nil + return w.writer.Flush() } func (w *AuthenticationWriter) Write(mb buf.MultiBuffer) error { diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 9e7c49a5d..149753de8 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -181,7 +181,7 @@ func (v *Handler) Process(ctx context.Context, network net.Network, connection i if err != nil { if errors.Cause(err) != io.EOF { log.Access(connection.RemoteAddr(), "", log.AccessRejected, err) - log.Trace(newError("invalid request from ", connection.RemoteAddr(), ": ", err)) + log.Trace(newError("invalid request from ", connection.RemoteAddr(), ": ", err).AtInfo()) } return err } @@ -194,7 +194,7 @@ func (v *Handler) Process(ctx context.Context, network net.Network, connection i log.Access(connection.RemoteAddr(), request.Destination(), log.AccessAccepted, "") log.Trace(newError("received request for ", request.Destination())) - connection.SetReadDeadline(time.Time{}) + common.Must(connection.SetReadDeadline(time.Time{})) userSettings := request.User.GetSettings() diff --git a/testing/scenarios/dokodemo_test.go b/testing/scenarios/dokodemo_test.go index 8e4834e03..20ce2528c 100644 --- a/testing/scenarios/dokodemo_test.go +++ b/testing/scenarios/dokodemo_test.go @@ -5,6 +5,7 @@ import ( "testing" "v2ray.com/core" + "v2ray.com/core/app/log" "v2ray.com/core/app/proxyman" v2net "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" @@ -55,6 +56,12 @@ func TestDokodemoTCP(t *testing.T) { ProxySettings: serial.ToTypedMessage(&freedom.Config{}), }, }, + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&log.Config{ + ErrorLogLevel: log.LogLevel_Debug, + ErrorLogType: log.LogType_Console, + }), + }, } clientPort := uint32(pickPort()) @@ -94,6 +101,12 @@ func TestDokodemoTCP(t *testing.T) { }), }, }, + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&log.Config{ + ErrorLogLevel: log.LogLevel_Debug, + ErrorLogType: log.LogType_Console, + }), + }, } servers, err := InitializeServerConfigs(serverConfig, clientConfig)