From f7b96507f91b9b4c09b0a4cd748b37a02522d493 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Fri, 2 Nov 2018 21:34:04 +0100 Subject: [PATCH] simplify buffer extension --- app/dispatcher/default.go | 7 +- app/dns/udpns.go | 8 +-- common/buf/buffer.go | 33 +++++---- common/buf/buffer_test.go | 3 +- common/crypto/auth.go | 16 ++--- common/crypto/chunk.go | 5 +- common/serial/hash.go | 12 ---- proxy/blackhole/config.go | 3 +- proxy/shadowsocks/config.go | 23 +++---- proxy/shadowsocks/ota.go | 10 ++- proxy/shadowsocks/protocol.go | 7 +- proxy/shadowsocks/protocol_test.go | 7 +- proxy/vmess/encoding/client.go | 4 +- transport/internet/header.go | 2 +- transport/internet/headers/http/http.go | 40 ++++++----- transport/internet/headers/http/http_test.go | 4 +- transport/internet/headers/noop/noop.go | 6 +- transport/internet/headers/srtp/srtp.go | 5 +- transport/internet/headers/srtp/srtp_test.go | 2 +- transport/internet/headers/tls/dtls.go | 5 +- transport/internet/headers/tls/dtls_test.go | 2 +- transport/internet/headers/utp/utp.go | 5 +- transport/internet/headers/utp/utp_test.go | 2 +- transport/internet/headers/wechat/wechat.go | 5 +- .../internet/headers/wechat/wechat_test.go | 2 +- .../internet/headers/wireguard/wireguard.go | 10 +-- transport/internet/kcp/io.go | 15 ++--- transport/internet/kcp/output.go | 5 +- transport/internet/kcp/segment.go | 67 ++++++++----------- transport/internet/kcp/segment_test.go | 8 +-- transport/internet/udp/hub.go | 9 +-- 31 files changed, 139 insertions(+), 193 deletions(-) delete mode 100644 common/serial/hash.go diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index c38f67f49..75c7d7673 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -39,9 +39,10 @@ func (r *cachedReader) Cache(b *buf.Buffer) { if !mb.IsEmpty() { common.Must(r.cache.WriteMultiBuffer(mb)) } - common.Must(b.Reset(func(x []byte) (int, error) { - return r.cache.Copy(x), nil - })) + b.Clear() + rawBytes := b.Extend(buf.Size) + n := r.cache.Copy(rawBytes) + b.Resize(0, int32(n)) r.Unlock() } diff --git a/app/dns/udpns.go b/app/dns/udpns.go index e9b0ef4b7..20d99ffa2 100644 --- a/app/dns/udpns.go +++ b/app/dns/udpns.go @@ -260,13 +260,13 @@ func (s *ClassicNameServer) buildMsgs(domain string) []*dns.Msg { func msgToBuffer(msg *dns.Msg) (*buf.Buffer, error) { buffer := buf.New() - if err := buffer.Reset(func(b []byte) (int, error) { - writtenBuffer, err := msg.PackBuffer(b) - return len(writtenBuffer), err - }); err != nil { + rawBytes := buffer.Extend(buf.Size) + packed, err := msg.PackBuffer(rawBytes) + if err != nil { buffer.Release() return nil, err } + buffer.Resize(0, int32(len(packed))) return buffer, nil } diff --git a/common/buf/buffer.go b/common/buf/buffer.go index 5b28e7d45..8806ac80f 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -11,9 +11,6 @@ const ( Size = 2048 ) -// Supplier is a writer that writes contents into the given buffer. -type Supplier func([]byte) (int, error) - // Buffer is a recyclable allocation of a byte array. Buffer.Release() recycles // the buffer into an internal buffer pool, in order to recreate a buffer more // quickly. @@ -40,13 +37,6 @@ func (b *Buffer) Clear() { b.end = 0 } -// AppendSupplier appends the content of a BytesWriter to the buffer. -func (b *Buffer) AppendSupplier(writer Supplier) error { - nBytes, err := writer(b.v[b.end:]) - b.end += int32(nBytes) - return err -} - // Byte returns the bytes at index. func (b *Buffer) Byte(index int32) byte { return b.v[b.start+index] @@ -62,15 +52,16 @@ func (b *Buffer) Bytes() []byte { return b.v[b.start:b.end] } -// Reset resets the content of the Buffer with a supplier. -func (b *Buffer) Reset(writer Supplier) error { - nBytes, err := writer(b.v) - if nBytes > len(b.v) { - return newError("too many bytes written: ", nBytes, " > ", len(b.v)) +// Extend increases the buffer size by n bytes, and returns the extended part. +// It panics if result size is larger than buf.Size. +func (b *Buffer) Extend(n int32) []byte { + end := b.end + n + if end > int32(len(b.v)) { + panic(newError("out of bound: ", end)) } - b.start = 0 - b.end = int32(nBytes) - return err + ext := b.v[b.end:end] + b.end = end + return ext } // BytesRange returns a slice of this buffer with given from and to boundary. @@ -153,6 +144,11 @@ func (b *Buffer) WriteBytes(bytes ...byte) (int, error) { return b.Write(bytes) } +// WriteString implements io.StringWriter. +func (b *Buffer) WriteString(s string) (int, error) { + return b.Write([]byte(s)) +} + // Read implements io.Reader.Read(). func (b *Buffer) Read(data []byte) (int, error) { if b.Len() == 0 { @@ -174,6 +170,7 @@ func (b *Buffer) ReadFrom(reader io.Reader) (int64, error) { return int64(n), err } +// ReadFullFrom reads exact size of bytes from given reader, or until error occurs. func (b *Buffer) ReadFullFrom(reader io.Reader, size int32) (int64, error) { end := b.end + size if end > int32(len(b.v)) { diff --git a/common/buf/buffer_test.go b/common/buf/buffer_test.go index fa3973d70..d45677ab0 100644 --- a/common/buf/buffer_test.go +++ b/common/buf/buffer_test.go @@ -8,7 +8,6 @@ import ( "v2ray.com/core/common" . "v2ray.com/core/common/buf" "v2ray.com/core/common/compare" - "v2ray.com/core/common/serial" . "v2ray.com/ext/assert" ) @@ -41,7 +40,7 @@ func TestBufferString(t *testing.T) { buffer := New() defer buffer.Release() - assert(buffer.AppendSupplier(serial.WriteString("Test String")), IsNil) + common.Must2(buffer.WriteString("Test String")) assert(buffer.String(), Equals, "Test String") } diff --git a/common/crypto/auth.go b/common/crypto/auth.go index 4d81fc41f..44f7eff6a 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -245,10 +245,10 @@ func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, wr } func (w *AuthenticationWriter) seal(b *buf.Buffer) (*buf.Buffer, error) { - encryptedSize := int(b.Len()) + w.auth.Overhead() - var paddingSize int + encryptedSize := b.Len() + int32(w.auth.Overhead()) + var paddingSize int32 if w.padding != nil { - paddingSize = int(w.padding.NextPaddingLen()) + paddingSize = int32(w.padding.NextPaddingLen()) } totalSize := encryptedSize + paddingSize @@ -257,14 +257,8 @@ func (w *AuthenticationWriter) seal(b *buf.Buffer) (*buf.Buffer, error) { } eb := buf.New() - common.Must(eb.Reset(func(bb []byte) (int, error) { - w.sizeParser.Encode(uint16(encryptedSize+paddingSize), bb) - return int(w.sizeParser.SizeBytes()), nil - })) - if err := eb.AppendSupplier(func(bb []byte) (int, error) { - _, err := w.auth.Seal(bb[:0], b.Bytes()) - return encryptedSize, err - }); err != nil { + w.sizeParser.Encode(uint16(encryptedSize+paddingSize), eb.Extend(w.sizeParser.SizeBytes())) + if _, err := w.auth.Seal(eb.Extend(encryptedSize)[:0], b.Bytes()); err != nil { eb.Release() return nil, err } diff --git a/common/crypto/chunk.go b/common/crypto/chunk.go index b5314617f..5bea4b8f2 100755 --- a/common/crypto/chunk.go +++ b/common/crypto/chunk.go @@ -146,10 +146,7 @@ func (w *ChunkStreamWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { slice := mb.SliceBySize(sliceSize) b := buf.New() - common.Must(b.Reset(func(buffer []byte) (int, error) { - w.sizeEncoder.Encode(uint16(slice.Len()), buffer) - return int(w.sizeEncoder.SizeBytes()), nil - })) + w.sizeEncoder.Encode(uint16(slice.Len()), b.Extend(w.sizeEncoder.SizeBytes())) mb2Write.Append(b) mb2Write.AppendMulti(slice) diff --git a/common/serial/hash.go b/common/serial/hash.go deleted file mode 100644 index 627343d6c..000000000 --- a/common/serial/hash.go +++ /dev/null @@ -1,12 +0,0 @@ -package serial - -import ( - "hash" -) - -func WriteHash(h hash.Hash) func(b []byte) (int, error) { - return func(b []byte) (int, error) { - h.Sum(b[:0]) - return h.Size(), nil - } -} diff --git a/proxy/blackhole/config.go b/proxy/blackhole/config.go index 9ff564642..2a0da01d0 100644 --- a/proxy/blackhole/config.go +++ b/proxy/blackhole/config.go @@ -3,7 +3,6 @@ package blackhole import ( "v2ray.com/core/common" "v2ray.com/core/common/buf" - "v2ray.com/core/common/serial" ) const ( @@ -28,7 +27,7 @@ func (*NoneResponse) WriteTo(buf.Writer) int32 { return 0 } // WriteTo implements ResponseConfig.WriteTo(). func (*HTTPResponse) WriteTo(writer buf.Writer) int32 { b := buf.New() - common.Must(b.Reset(serial.WriteString(http403response))) + common.Must2(b.Write([]byte(http403response))) n := b.Len() writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)) return n diff --git a/proxy/shadowsocks/config.go b/proxy/shadowsocks/config.go index df720f8f4..f2a13e334 100644 --- a/proxy/shadowsocks/config.go +++ b/proxy/shadowsocks/config.go @@ -198,13 +198,10 @@ func (c *AEADCipher) EncodePacket(key []byte, b *buf.Buffer) error { ivLen := c.IVSize() payloadLen := b.Len() auth := c.createAuthenticator(key, b.BytesTo(ivLen)) - return b.Reset(func(bb []byte) (int, error) { - bbb, err := auth.Seal(bb[:ivLen], bb[ivLen:payloadLen]) - if err != nil { - return 0, err - } - return len(bbb), nil - }) + + b.Extend(int32(auth.Overhead())) + _, err := auth.Seal(b.BytesTo(ivLen), b.BytesRange(ivLen, payloadLen)) + return err } func (c *AEADCipher) DecodePacket(key []byte, b *buf.Buffer) error { @@ -214,16 +211,12 @@ func (c *AEADCipher) DecodePacket(key []byte, b *buf.Buffer) error { ivLen := c.IVSize() payloadLen := b.Len() auth := c.createAuthenticator(key, b.BytesTo(ivLen)) - if err := b.Reset(func(bb []byte) (int, error) { - bbb, err := auth.Open(bb[:ivLen], bb[ivLen:payloadLen]) - if err != nil { - return 0, err - } - return len(bbb), nil - }); err != nil { + + bbb, err := auth.Open(b.BytesTo(ivLen), b.BytesRange(ivLen, payloadLen)) + if err != nil { return err } - b.Advance(ivLen) + b.Resize(ivLen, ivLen+int32(len(bbb))) return nil } diff --git a/proxy/shadowsocks/ota.go b/proxy/shadowsocks/ota.go index 8b2f1e21c..3df85a9f5 100644 --- a/proxy/shadowsocks/ota.go +++ b/proxy/shadowsocks/ota.go @@ -30,13 +30,11 @@ func NewAuthenticator(keygen KeyGenerator) *Authenticator { } } -func (v *Authenticator) Authenticate(data []byte) buf.Supplier { +func (v *Authenticator) Authenticate(data []byte, dest []byte) { hasher := hmac.New(sha1.New, v.key()) common.Must2(hasher.Write(data)) res := hasher.Sum(nil) - return func(b []byte) (int, error) { - return copy(b, res[:AuthSize]), nil - } + copy(dest, res[:AuthSize]) } func HeaderKeyGenerator(key []byte, iv []byte) func() []byte { @@ -89,7 +87,7 @@ func (v *ChunkReader) ReadMultiBuffer() (buf.MultiBuffer, error) { payload := buffer[AuthSize:size] actualAuthBytes := make([]byte, AuthSize) - v.auth.Authenticate(payload)(actualAuthBytes) + v.auth.Authenticate(payload, actualAuthBytes) if !bytes.Equal(authBytes, actualAuthBytes) { return nil, newError("invalid auth") } @@ -121,7 +119,7 @@ func (w *ChunkWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { for { payloadLen, _ := mb.Read(w.buffer[2+AuthSize:]) binary.BigEndian.PutUint16(w.buffer, uint16(payloadLen)) - w.auth.Authenticate(w.buffer[2+AuthSize : 2+AuthSize+payloadLen])(w.buffer[2:]) + w.auth.Authenticate(w.buffer[2+AuthSize:2+AuthSize+payloadLen], w.buffer[2:]) if err := buf.WriteAllBytes(w.writer, w.buffer[:2+AuthSize+payloadLen]); err != nil { return err } diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index 6aba4b496..30bc7c6bc 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -142,7 +142,8 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri header.SetByte(0, header.Byte(0)|0x10) authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv)) - common.Must(header.AppendSupplier(authenticator.Authenticate(header.Bytes()))) + authBuffer := header.Extend(AuthSize) + authenticator.Authenticate(header.Bytes(), authBuffer) } if err := w.WriteMultiBuffer(buf.NewMultiBufferValue(header)); err != nil { @@ -210,7 +211,9 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv)) buffer.SetByte(ivLen, buffer.Byte(ivLen)|0x10) - common.Must(buffer.AppendSupplier(authenticator.Authenticate(buffer.BytesFrom(ivLen)))) + authPayload := buffer.BytesFrom(ivLen) + authBuffer := buffer.Extend(AuthSize) + authenticator.Authenticate(authPayload, authBuffer) } if err := account.Cipher.EncodePacket(account.Key, buffer); err != nil { return nil, newError("failed to encrypt UDP payload").Base(err) diff --git a/proxy/shadowsocks/protocol_test.go b/proxy/shadowsocks/protocol_test.go index 6cf6cdb0c..d60dbc7f4 100644 --- a/proxy/shadowsocks/protocol_test.go +++ b/proxy/shadowsocks/protocol_test.go @@ -7,7 +7,6 @@ import ( "v2ray.com/core/common/buf" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" - "v2ray.com/core/common/serial" . "v2ray.com/core/proxy/shadowsocks" . "v2ray.com/ext/assert" ) @@ -37,7 +36,7 @@ func TestUDPEncoding(t *testing.T) { } data := buf.New() - data.AppendSupplier(serial.WriteString("test string")) + common.Must2(data.WriteString("test string")) encodedData, err := EncodeUDPPacket(request, data.Bytes()) assert(err, IsNil) @@ -168,7 +167,7 @@ func TestUDPReaderWriter(t *testing.T) { { b := buf.New() - b.AppendSupplier(serial.WriteString("test payload")) + common.Must2(b.WriteString("test payload")) err := writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)) assert(err, IsNil) @@ -179,7 +178,7 @@ func TestUDPReaderWriter(t *testing.T) { { b := buf.New() - b.AppendSupplier(serial.WriteString("test payload 2")) + common.Must2(b.WriteString("test payload 2")) err := writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)) assert(err, IsNil) diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index d54c9e561..12356c4e2 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -16,7 +16,6 @@ import ( "v2ray.com/core/common/crypto" "v2ray.com/core/common/dice" "v2ray.com/core/common/protocol" - "v2ray.com/core/common/serial" "v2ray.com/core/common/vio" "v2ray.com/core/proxy/vmess" ) @@ -88,7 +87,8 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ { fnv1a := fnv.New32a() common.Must2(fnv1a.Write(buffer.Bytes())) - common.Must(buffer.AppendSupplier(serial.WriteHash(fnv1a))) + hashBytes := buffer.Extend(int32(fnv1a.Size())) + fnv1a.Sum(hashBytes[:0]) } iv := hashTimestamp(md5.New(), timestamp) diff --git a/transport/internet/header.go b/transport/internet/header.go index c6bde9258..9f71c60d5 100644 --- a/transport/internet/header.go +++ b/transport/internet/header.go @@ -9,7 +9,7 @@ import ( type PacketHeader interface { Size() int32 - Write([]byte) (int, error) + Serialize([]byte) } func CreatePacketHeader(config interface{}) (PacketHeader, error) { diff --git a/transport/internet/headers/http/http.go b/transport/internet/headers/http/http.go index 88d4ebf23..48c2c990d 100644 --- a/transport/internet/headers/http/http.go +++ b/transport/internet/headers/http/http.go @@ -13,7 +13,6 @@ import ( "v2ray.com/core/common" "v2ray.com/core/common/buf" - "v2ray.com/core/common/serial" ) const ( @@ -29,7 +28,6 @@ const ( var ( ErrHeaderToLong = newError("Header too long.") - writeCRLF = serial.WriteString(CRLF) ) type Reader interface { @@ -70,12 +68,12 @@ func (*HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) { endingDetected = true break } - if buffer.Len() >= int32(len(ENDING)) { - totalBytes += buffer.Len() - int32(len(ENDING)) - leftover := buffer.BytesFrom(-int32(len(ENDING))) - buffer.Reset(func(b []byte) (int, error) { - return copy(b, leftover), nil - }) + lenEnding := int32(len(ENDING)) + if buffer.Len() >= lenEnding { + totalBytes += buffer.Len() - lenEnding + leftover := buffer.BytesFrom(-lenEnding) + buffer.Clear() + copy(buffer.Extend(lenEnding), leftover) } } if buffer.IsEmpty() { @@ -175,20 +173,20 @@ func (c *HttpConn) Close() error { func formResponseHeader(config *ResponseConfig) *HeaderWriter { header := buf.New() - header.AppendSupplier(serial.WriteString(strings.Join([]string{config.GetFullVersion(), config.GetStatusValue().Code, config.GetStatusValue().Reason}, " "))) - header.AppendSupplier(writeCRLF) + common.Must2(header.WriteString(strings.Join([]string{config.GetFullVersion(), config.GetStatusValue().Code, config.GetStatusValue().Reason}, " "))) + common.Must2(header.WriteString(CRLF)) headers := config.PickHeaders() for _, h := range headers { - header.AppendSupplier(serial.WriteString(h)) - header.AppendSupplier(writeCRLF) + common.Must2(header.WriteString(h)) + common.Must2(header.WriteString(CRLF)) } if !config.HasHeader("Date") { - header.AppendSupplier(serial.WriteString("Date: ")) - header.AppendSupplier(serial.WriteString(time.Now().Format(http.TimeFormat))) - header.AppendSupplier(writeCRLF) + common.Must2(header.WriteString("Date: ")) + common.Must2(header.WriteString(time.Now().Format(http.TimeFormat))) + common.Must2(header.WriteString(CRLF)) } - header.AppendSupplier(writeCRLF) + common.Must2(header.WriteString(CRLF)) return &HeaderWriter{ header: header, } @@ -201,15 +199,15 @@ type HttpAuthenticator struct { func (a HttpAuthenticator) GetClientWriter() *HeaderWriter { header := buf.New() config := a.config.Request - header.AppendSupplier(serial.WriteString(strings.Join([]string{config.GetMethodValue(), config.PickUri(), config.GetFullVersion()}, " "))) - header.AppendSupplier(writeCRLF) + common.Must2(header.WriteString(strings.Join([]string{config.GetMethodValue(), config.PickUri(), config.GetFullVersion()}, " "))) + common.Must2(header.WriteString(CRLF)) headers := config.PickHeaders() for _, h := range headers { - header.AppendSupplier(serial.WriteString(h)) - header.AppendSupplier(writeCRLF) + common.Must2(header.WriteString(h)) + common.Must2(header.WriteString(CRLF)) } - header.AppendSupplier(writeCRLF) + common.Must2(header.WriteString(CRLF)) return &HeaderWriter{ header: header, } diff --git a/transport/internet/headers/http/http_test.go b/transport/internet/headers/http/http_test.go index e23ca12c1..610360473 100644 --- a/transport/internet/headers/http/http_test.go +++ b/transport/internet/headers/http/http_test.go @@ -5,9 +5,9 @@ import ( "testing" "time" + "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/net" - "v2ray.com/core/common/serial" . "v2ray.com/core/transport/internet/headers/http" . "v2ray.com/ext/assert" ) @@ -17,7 +17,7 @@ func TestReaderWriter(t *testing.T) { cache := buf.New() b := buf.New() - b.AppendSupplier(serial.WriteString("abcd" + ENDING)) + common.Must2(b.WriteString("abcd" + ENDING)) writer := NewHeaderWriter(b) err := writer.Write(cache) assert(err, IsNil) diff --git a/transport/internet/headers/noop/noop.go b/transport/internet/headers/noop/noop.go index a9c87385c..e622b35ef 100644 --- a/transport/internet/headers/noop/noop.go +++ b/transport/internet/headers/noop/noop.go @@ -13,10 +13,8 @@ func (NoOpHeader) Size() int32 { return 0 } -// Write implements io.Writer. -func (NoOpHeader) Write([]byte) (int, error) { - return 0, nil -} +// Serialize implements PacketHeader. +func (NoOpHeader) Serialize([]byte) {} func NewNoOpHeader(context.Context, interface{}) (interface{}, error) { return NoOpHeader{}, nil diff --git a/transport/internet/headers/srtp/srtp.go b/transport/internet/headers/srtp/srtp.go index 85fe9b086..f4145291d 100644 --- a/transport/internet/headers/srtp/srtp.go +++ b/transport/internet/headers/srtp/srtp.go @@ -17,12 +17,11 @@ func (*SRTP) Size() int32 { return 4 } -// Write implements io.Writer. -func (s *SRTP) Write(b []byte) (int, error) { +// Serialize implements PacketHeader. +func (s *SRTP) Serialize(b []byte) { s.number++ binary.BigEndian.PutUint16(b, s.number) binary.BigEndian.PutUint16(b[2:], s.number) - return 4, nil } // New returns a new SRTP instance based on the given config. diff --git a/transport/internet/headers/srtp/srtp_test.go b/transport/internet/headers/srtp/srtp_test.go index 9c787013b..7ff95f2b0 100644 --- a/transport/internet/headers/srtp/srtp_test.go +++ b/transport/internet/headers/srtp/srtp_test.go @@ -19,7 +19,7 @@ func TestSRTPWrite(t *testing.T) { srtp := srtpRaw.(*SRTP) payload := buf.New() - payload.AppendSupplier(srtp.Write) + srtp.Serialize(payload.Extend(srtp.Size())) payload.Write(content) assert(payload.Len(), Equals, int32(len(content))+srtp.Size()) diff --git a/transport/internet/headers/tls/dtls.go b/transport/internet/headers/tls/dtls.go index 06a61f69a..e48ab6e99 100644 --- a/transport/internet/headers/tls/dtls.go +++ b/transport/internet/headers/tls/dtls.go @@ -19,8 +19,8 @@ func (*DTLS) Size() int32 { return 1 + 2 + 2 + 6 + 2 } -// Write implements PacketHeader. -func (d *DTLS) Write(b []byte) (int, error) { +// Serialize implements PacketHeader. +func (d *DTLS) Serialize(b []byte) { b[0] = 23 // application data b[1] = 254 b[2] = 253 @@ -39,7 +39,6 @@ func (d *DTLS) Write(b []byte) (int, error) { if d.length > 100 { d.length -= 50 } - return 13, nil } // New creates a new UTP header for the given config. diff --git a/transport/internet/headers/tls/dtls_test.go b/transport/internet/headers/tls/dtls_test.go index b02cb841a..91e2aa82e 100644 --- a/transport/internet/headers/tls/dtls_test.go +++ b/transport/internet/headers/tls/dtls_test.go @@ -19,7 +19,7 @@ func TestDTLSWrite(t *testing.T) { dtls := dtlsRaw.(*DTLS) payload := buf.New() - payload.AppendSupplier(dtls.Write) + dtls.Serialize(payload.Extend(dtls.Size())) payload.Write(content) assert(payload.Len(), Equals, int32(len(content))+dtls.Size()) diff --git a/transport/internet/headers/utp/utp.go b/transport/internet/headers/utp/utp.go index d35811602..0ed164ca9 100644 --- a/transport/internet/headers/utp/utp.go +++ b/transport/internet/headers/utp/utp.go @@ -18,12 +18,11 @@ func (*UTP) Size() int32 { return 4 } -// Write implements io.Writer. -func (u *UTP) Write(b []byte) (int, error) { +// Serialize implements PacketHeader. +func (u *UTP) Serialize(b []byte) { binary.BigEndian.PutUint16(b, u.connectionId) b[2] = u.header b[3] = u.extension - return 4, nil } // New creates a new UTP header for the given config. diff --git a/transport/internet/headers/utp/utp_test.go b/transport/internet/headers/utp/utp_test.go index 0eb65dc53..1cdf365fd 100644 --- a/transport/internet/headers/utp/utp_test.go +++ b/transport/internet/headers/utp/utp_test.go @@ -19,7 +19,7 @@ func TestUTPWrite(t *testing.T) { utp := utpRaw.(*UTP) payload := buf.New() - payload.AppendSupplier(utp.Write) + utp.Serialize(payload.Extend(utp.Size())) payload.Write(content) assert(payload.Len(), Equals, int32(len(content))+utp.Size()) diff --git a/transport/internet/headers/wechat/wechat.go b/transport/internet/headers/wechat/wechat.go index 86106b561..8cbb603b3 100644 --- a/transport/internet/headers/wechat/wechat.go +++ b/transport/internet/headers/wechat/wechat.go @@ -16,8 +16,8 @@ func (vc *VideoChat) Size() int32 { return 13 } -// Write implements io.Writer. -func (vc *VideoChat) Write(b []byte) (int, error) { +// Serialize implements PacketHeader. +func (vc *VideoChat) Serialize(b []byte) { vc.sn++ b[0] = 0xa1 b[1] = 0x08 @@ -29,7 +29,6 @@ func (vc *VideoChat) Write(b []byte) (int, error) { b[10] = 0x30 b[11] = 0x22 b[12] = 0x30 - return 13, nil } // NewVideoChat returns a new VideoChat instance based on given config. diff --git a/transport/internet/headers/wechat/wechat_test.go b/transport/internet/headers/wechat/wechat_test.go index 80185d1e2..a44e460d2 100644 --- a/transport/internet/headers/wechat/wechat_test.go +++ b/transport/internet/headers/wechat/wechat_test.go @@ -18,7 +18,7 @@ func TestUTPWrite(t *testing.T) { video := videoRaw.(*VideoChat) payload := buf.New() - payload.AppendSupplier(video.Write) + video.Serialize(payload.Extend(video.Size())) assert(payload.Len(), Equals, video.Size()) } diff --git a/transport/internet/headers/wireguard/wireguard.go b/transport/internet/headers/wireguard/wireguard.go index eb8a8ed9e..466f48e6a 100644 --- a/transport/internet/headers/wireguard/wireguard.go +++ b/transport/internet/headers/wireguard/wireguard.go @@ -12,10 +12,12 @@ func (Wireguard) Size() int32 { return 4 } -// Write implements io.Writer. -func (Wireguard) Write(b []byte) (int, error) { - b = append(b[:0], 0x04, 0x00, 0x00, 0x00) - return 4, nil +// Serialize implements PacketHeader. +func (Wireguard) Serialize(b []byte) { + b[0] = 0x04 + b[1] = 0x00 + b[2] = 0x00 + b[3] = 0x00 } // NewWireguard returns a new VideoChat instance based on given config. diff --git a/transport/internet/kcp/io.go b/transport/internet/kcp/io.go index 6249a7e6c..ef14d2ed5 100644 --- a/transport/internet/kcp/io.go +++ b/transport/internet/kcp/io.go @@ -74,20 +74,15 @@ func (w *KCPPacketWriter) Write(b []byte) (int, error) { defer bb.Release() if w.Header != nil { - common.Must(bb.AppendSupplier(func(x []byte) (int, error) { - return w.Header.Write(x) - })) + w.Header.Serialize(bb.Extend(w.Header.Size())) } if w.Security != nil { nonceSize := w.Security.NonceSize() - common.Must(bb.AppendSupplier(func(x []byte) (int, error) { - return rand.Read(x[:nonceSize]) - })) + common.Must2(bb.ReadFullFrom(rand.Reader, int32(nonceSize))) nonce := bb.BytesFrom(int32(-nonceSize)) - common.Must(bb.AppendSupplier(func(x []byte) (int, error) { - eb := w.Security.Seal(x[:0], nonce, b, nil) - return len(eb), nil - })) + + encrypted := bb.Extend(int32(w.Security.Overhead() + len(b))) + w.Security.Seal(encrypted[:0], nonce, b, nil) } else { bb.Write(b) } diff --git a/transport/internet/kcp/output.go b/transport/internet/kcp/output.go index 9e00215dd..470b8e456 100644 --- a/transport/internet/kcp/output.go +++ b/transport/internet/kcp/output.go @@ -6,7 +6,6 @@ import ( "v2ray.com/core/common/retry" - "v2ray.com/core/common" "v2ray.com/core/common/buf" ) @@ -31,7 +30,9 @@ func (w *SimpleSegmentWriter) Write(seg Segment) error { w.Lock() defer w.Unlock() - common.Must(w.buffer.Reset(seg.Bytes())) + w.buffer.Clear() + rawBytes := w.buffer.Extend(seg.ByteSize()) + seg.Serialize(rawBytes) _, err := w.writer.Write(w.buffer.Bytes()) return err } diff --git a/transport/internet/kcp/segment.go b/transport/internet/kcp/segment.go index beaefecfd..2d5b1d26a 100755 --- a/transport/internet/kcp/segment.go +++ b/transport/internet/kcp/segment.go @@ -31,7 +31,7 @@ type Segment interface { Conversation() uint16 Command() Command ByteSize() int32 - Bytes() buf.Supplier + Serialize([]byte) parse(conv uint16, cmd Command, opt SegmentOption, buf []byte) (bool, []byte) } @@ -104,18 +104,15 @@ func (s *DataSegment) Data() *buf.Buffer { return s.payload } -func (s *DataSegment) Bytes() buf.Supplier { - return func(b []byte) (int, error) { - binary.BigEndian.PutUint16(b, s.Conv) - b[2] = byte(CommandData) - b[3] = byte(s.Option) - binary.BigEndian.PutUint32(b[4:], s.Timestamp) - binary.BigEndian.PutUint32(b[8:], s.Number) - binary.BigEndian.PutUint32(b[12:], s.SendingNext) - binary.BigEndian.PutUint16(b[16:], uint16(s.payload.Len())) - n := copy(b[18:], s.payload.Bytes()) - return 18 + n, nil - } +func (s *DataSegment) Serialize(b []byte) { + binary.BigEndian.PutUint16(b, s.Conv) + b[2] = byte(CommandData) + b[3] = byte(s.Option) + binary.BigEndian.PutUint32(b[4:], s.Timestamp) + binary.BigEndian.PutUint32(b[8:], s.Number) + binary.BigEndian.PutUint32(b[12:], s.SendingNext) + binary.BigEndian.PutUint16(b[16:], uint16(s.payload.Len())) + copy(b[18:], s.payload.Bytes()) } func (s *DataSegment) ByteSize() int32 { @@ -202,21 +199,18 @@ func (s *AckSegment) ByteSize() int32 { return 2 + 1 + 1 + 4 + 4 + 4 + 1 + int32(len(s.NumberList)*4) } -func (s *AckSegment) Bytes() buf.Supplier { - return func(b []byte) (int, error) { - binary.BigEndian.PutUint16(b, s.Conv) - b[2] = byte(CommandACK) - b[3] = byte(s.Option) - binary.BigEndian.PutUint32(b[4:], s.ReceivingWindow) - binary.BigEndian.PutUint32(b[8:], s.ReceivingNext) - binary.BigEndian.PutUint32(b[12:], s.Timestamp) - b[16] = byte(len(s.NumberList)) - n := 17 - for _, number := range s.NumberList { - binary.BigEndian.PutUint32(b[n:], number) - n += 4 - } - return n, nil +func (s *AckSegment) Serialize(b []byte) { + binary.BigEndian.PutUint16(b, s.Conv) + b[2] = byte(CommandACK) + b[3] = byte(s.Option) + binary.BigEndian.PutUint32(b[4:], s.ReceivingWindow) + binary.BigEndian.PutUint32(b[8:], s.ReceivingNext) + binary.BigEndian.PutUint32(b[12:], s.Timestamp) + b[16] = byte(len(s.NumberList)) + n := 17 + for _, number := range s.NumberList { + binary.BigEndian.PutUint32(b[n:], number) + n += 4 } } @@ -268,16 +262,13 @@ func (*CmdOnlySegment) ByteSize() int32 { return 2 + 1 + 1 + 4 + 4 + 4 } -func (s *CmdOnlySegment) Bytes() buf.Supplier { - return func(b []byte) (int, error) { - binary.BigEndian.PutUint16(b, s.Conv) - b[2] = byte(s.Cmd) - b[3] = byte(s.Option) - binary.BigEndian.PutUint32(b[4:], s.SendingNext) - binary.BigEndian.PutUint32(b[8:], s.ReceivingNext) - binary.BigEndian.PutUint32(b[12:], s.PeerRTO) - return 16, nil - } +func (s *CmdOnlySegment) Serialize(b []byte) { + binary.BigEndian.PutUint16(b, s.Conv) + b[2] = byte(s.Cmd) + b[3] = byte(s.Option) + binary.BigEndian.PutUint32(b[4:], s.SendingNext) + binary.BigEndian.PutUint32(b[8:], s.ReceivingNext) + binary.BigEndian.PutUint32(b[12:], s.PeerRTO) } func (*CmdOnlySegment) Release() {} diff --git a/transport/internet/kcp/segment_test.go b/transport/internet/kcp/segment_test.go index 9d2a60e61..b60d26ebb 100644 --- a/transport/internet/kcp/segment_test.go +++ b/transport/internet/kcp/segment_test.go @@ -28,7 +28,7 @@ func TestDataSegment(t *testing.T) { nBytes := seg.ByteSize() bytes := make([]byte, nBytes) - seg.Bytes()(bytes) + seg.Serialize(bytes) assert(int32(len(bytes)), Equals, nBytes) @@ -54,7 +54,7 @@ func Test1ByteDataSegment(t *testing.T) { nBytes := seg.ByteSize() bytes := make([]byte, nBytes) - seg.Bytes()(bytes) + seg.Serialize(bytes) assert(int32(len(bytes)), Equals, nBytes) @@ -80,7 +80,7 @@ func TestACKSegment(t *testing.T) { nBytes := seg.ByteSize() bytes := make([]byte, nBytes) - seg.Bytes()(bytes) + seg.Serialize(bytes) assert(int32(len(bytes)), Equals, nBytes) @@ -110,7 +110,7 @@ func TestCmdSegment(t *testing.T) { nBytes := seg.ByteSize() bytes := make([]byte, nBytes) - seg.Bytes()(bytes) + seg.Serialize(bytes) assert(int32(len(bytes)), Equals, nBytes) diff --git a/transport/internet/udp/hub.go b/transport/internet/udp/hub.go index 2da7a5cf2..f62f598e7 100644 --- a/transport/internet/udp/hub.go +++ b/transport/internet/udp/hub.go @@ -88,18 +88,15 @@ func (h *Hub) start() { buffer := buf.New() var noob int var addr *net.UDPAddr - err := buffer.Reset(func(b []byte) (int, error) { - n, nb, _, a, e := ReadUDPMsg(h.conn, b, oobBytes) - noob = nb - addr = a - return n, e - }) + rawBytes := buffer.Extend(buf.Size) + n, noob, _, addr, err := ReadUDPMsg(h.conn, rawBytes, oobBytes) if err != nil { newError("failed to read UDP msg").Base(err).WriteToLog() buffer.Release() break } + buffer.Resize(0, int32(n)) if buffer.IsEmpty() { buffer.Release()