1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-02 07:26:24 -05:00

simplify buffer extension

This commit is contained in:
Darien Raymond 2018-11-02 21:34:04 +01:00
parent 35ccc3a49c
commit f7b96507f9
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
31 changed files with 139 additions and 193 deletions

View File

@ -39,9 +39,10 @@ func (r *cachedReader) Cache(b *buf.Buffer) {
if !mb.IsEmpty() { if !mb.IsEmpty() {
common.Must(r.cache.WriteMultiBuffer(mb)) common.Must(r.cache.WriteMultiBuffer(mb))
} }
common.Must(b.Reset(func(x []byte) (int, error) { b.Clear()
return r.cache.Copy(x), nil rawBytes := b.Extend(buf.Size)
})) n := r.cache.Copy(rawBytes)
b.Resize(0, int32(n))
r.Unlock() r.Unlock()
} }

View File

@ -260,13 +260,13 @@ func (s *ClassicNameServer) buildMsgs(domain string) []*dns.Msg {
func msgToBuffer(msg *dns.Msg) (*buf.Buffer, error) { func msgToBuffer(msg *dns.Msg) (*buf.Buffer, error) {
buffer := buf.New() buffer := buf.New()
if err := buffer.Reset(func(b []byte) (int, error) { rawBytes := buffer.Extend(buf.Size)
writtenBuffer, err := msg.PackBuffer(b) packed, err := msg.PackBuffer(rawBytes)
return len(writtenBuffer), err if err != nil {
}); err != nil {
buffer.Release() buffer.Release()
return nil, err return nil, err
} }
buffer.Resize(0, int32(len(packed)))
return buffer, nil return buffer, nil
} }

View File

@ -11,9 +11,6 @@ const (
Size = 2048 Size = 2048
) )
// Supplier is a writer that writes contents into the given buffer.
type Supplier func([]byte) (int, error)
// Buffer is a recyclable allocation of a byte array. Buffer.Release() recycles // Buffer is a recyclable allocation of a byte array. Buffer.Release() recycles
// the buffer into an internal buffer pool, in order to recreate a buffer more // the buffer into an internal buffer pool, in order to recreate a buffer more
// quickly. // quickly.
@ -40,13 +37,6 @@ func (b *Buffer) Clear() {
b.end = 0 b.end = 0
} }
// AppendSupplier appends the content of a BytesWriter to the buffer.
func (b *Buffer) AppendSupplier(writer Supplier) error {
nBytes, err := writer(b.v[b.end:])
b.end += int32(nBytes)
return err
}
// Byte returns the bytes at index. // Byte returns the bytes at index.
func (b *Buffer) Byte(index int32) byte { func (b *Buffer) Byte(index int32) byte {
return b.v[b.start+index] return b.v[b.start+index]
@ -62,15 +52,16 @@ func (b *Buffer) Bytes() []byte {
return b.v[b.start:b.end] return b.v[b.start:b.end]
} }
// Reset resets the content of the Buffer with a supplier. // Extend increases the buffer size by n bytes, and returns the extended part.
func (b *Buffer) Reset(writer Supplier) error { // It panics if result size is larger than buf.Size.
nBytes, err := writer(b.v) func (b *Buffer) Extend(n int32) []byte {
if nBytes > len(b.v) { end := b.end + n
return newError("too many bytes written: ", nBytes, " > ", len(b.v)) if end > int32(len(b.v)) {
panic(newError("out of bound: ", end))
} }
b.start = 0 ext := b.v[b.end:end]
b.end = int32(nBytes) b.end = end
return err return ext
} }
// BytesRange returns a slice of this buffer with given from and to boundary. // BytesRange returns a slice of this buffer with given from and to boundary.
@ -153,6 +144,11 @@ func (b *Buffer) WriteBytes(bytes ...byte) (int, error) {
return b.Write(bytes) return b.Write(bytes)
} }
// WriteString implements io.StringWriter.
func (b *Buffer) WriteString(s string) (int, error) {
return b.Write([]byte(s))
}
// Read implements io.Reader.Read(). // Read implements io.Reader.Read().
func (b *Buffer) Read(data []byte) (int, error) { func (b *Buffer) Read(data []byte) (int, error) {
if b.Len() == 0 { if b.Len() == 0 {
@ -174,6 +170,7 @@ func (b *Buffer) ReadFrom(reader io.Reader) (int64, error) {
return int64(n), err return int64(n), err
} }
// ReadFullFrom reads exact size of bytes from given reader, or until error occurs.
func (b *Buffer) ReadFullFrom(reader io.Reader, size int32) (int64, error) { func (b *Buffer) ReadFullFrom(reader io.Reader, size int32) (int64, error) {
end := b.end + size end := b.end + size
if end > int32(len(b.v)) { if end > int32(len(b.v)) {

View File

@ -8,7 +8,6 @@ import (
"v2ray.com/core/common" "v2ray.com/core/common"
. "v2ray.com/core/common/buf" . "v2ray.com/core/common/buf"
"v2ray.com/core/common/compare" "v2ray.com/core/common/compare"
"v2ray.com/core/common/serial"
. "v2ray.com/ext/assert" . "v2ray.com/ext/assert"
) )
@ -41,7 +40,7 @@ func TestBufferString(t *testing.T) {
buffer := New() buffer := New()
defer buffer.Release() defer buffer.Release()
assert(buffer.AppendSupplier(serial.WriteString("Test String")), IsNil) common.Must2(buffer.WriteString("Test String"))
assert(buffer.String(), Equals, "Test String") assert(buffer.String(), Equals, "Test String")
} }

View File

@ -245,10 +245,10 @@ func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, wr
} }
func (w *AuthenticationWriter) seal(b *buf.Buffer) (*buf.Buffer, error) { func (w *AuthenticationWriter) seal(b *buf.Buffer) (*buf.Buffer, error) {
encryptedSize := int(b.Len()) + w.auth.Overhead() encryptedSize := b.Len() + int32(w.auth.Overhead())
var paddingSize int var paddingSize int32
if w.padding != nil { if w.padding != nil {
paddingSize = int(w.padding.NextPaddingLen()) paddingSize = int32(w.padding.NextPaddingLen())
} }
totalSize := encryptedSize + paddingSize totalSize := encryptedSize + paddingSize
@ -257,14 +257,8 @@ func (w *AuthenticationWriter) seal(b *buf.Buffer) (*buf.Buffer, error) {
} }
eb := buf.New() eb := buf.New()
common.Must(eb.Reset(func(bb []byte) (int, error) { w.sizeParser.Encode(uint16(encryptedSize+paddingSize), eb.Extend(w.sizeParser.SizeBytes()))
w.sizeParser.Encode(uint16(encryptedSize+paddingSize), bb) if _, err := w.auth.Seal(eb.Extend(encryptedSize)[:0], b.Bytes()); err != nil {
return int(w.sizeParser.SizeBytes()), nil
}))
if err := eb.AppendSupplier(func(bb []byte) (int, error) {
_, err := w.auth.Seal(bb[:0], b.Bytes())
return encryptedSize, err
}); err != nil {
eb.Release() eb.Release()
return nil, err return nil, err
} }

View File

@ -146,10 +146,7 @@ func (w *ChunkStreamWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
slice := mb.SliceBySize(sliceSize) slice := mb.SliceBySize(sliceSize)
b := buf.New() b := buf.New()
common.Must(b.Reset(func(buffer []byte) (int, error) { w.sizeEncoder.Encode(uint16(slice.Len()), b.Extend(w.sizeEncoder.SizeBytes()))
w.sizeEncoder.Encode(uint16(slice.Len()), buffer)
return int(w.sizeEncoder.SizeBytes()), nil
}))
mb2Write.Append(b) mb2Write.Append(b)
mb2Write.AppendMulti(slice) mb2Write.AppendMulti(slice)

View File

@ -1,12 +0,0 @@
package serial
import (
"hash"
)
func WriteHash(h hash.Hash) func(b []byte) (int, error) {
return func(b []byte) (int, error) {
h.Sum(b[:0])
return h.Size(), nil
}
}

View File

@ -3,7 +3,6 @@ package blackhole
import ( import (
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/serial"
) )
const ( const (
@ -28,7 +27,7 @@ func (*NoneResponse) WriteTo(buf.Writer) int32 { return 0 }
// WriteTo implements ResponseConfig.WriteTo(). // WriteTo implements ResponseConfig.WriteTo().
func (*HTTPResponse) WriteTo(writer buf.Writer) int32 { func (*HTTPResponse) WriteTo(writer buf.Writer) int32 {
b := buf.New() b := buf.New()
common.Must(b.Reset(serial.WriteString(http403response))) common.Must2(b.Write([]byte(http403response)))
n := b.Len() n := b.Len()
writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)) writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
return n return n

View File

@ -198,13 +198,10 @@ func (c *AEADCipher) EncodePacket(key []byte, b *buf.Buffer) error {
ivLen := c.IVSize() ivLen := c.IVSize()
payloadLen := b.Len() payloadLen := b.Len()
auth := c.createAuthenticator(key, b.BytesTo(ivLen)) auth := c.createAuthenticator(key, b.BytesTo(ivLen))
return b.Reset(func(bb []byte) (int, error) {
bbb, err := auth.Seal(bb[:ivLen], bb[ivLen:payloadLen]) b.Extend(int32(auth.Overhead()))
if err != nil { _, err := auth.Seal(b.BytesTo(ivLen), b.BytesRange(ivLen, payloadLen))
return 0, err return err
}
return len(bbb), nil
})
} }
func (c *AEADCipher) DecodePacket(key []byte, b *buf.Buffer) error { func (c *AEADCipher) DecodePacket(key []byte, b *buf.Buffer) error {
@ -214,16 +211,12 @@ func (c *AEADCipher) DecodePacket(key []byte, b *buf.Buffer) error {
ivLen := c.IVSize() ivLen := c.IVSize()
payloadLen := b.Len() payloadLen := b.Len()
auth := c.createAuthenticator(key, b.BytesTo(ivLen)) auth := c.createAuthenticator(key, b.BytesTo(ivLen))
if err := b.Reset(func(bb []byte) (int, error) {
bbb, err := auth.Open(bb[:ivLen], bb[ivLen:payloadLen]) bbb, err := auth.Open(b.BytesTo(ivLen), b.BytesRange(ivLen, payloadLen))
if err != nil { if err != nil {
return 0, err
}
return len(bbb), nil
}); err != nil {
return err return err
} }
b.Advance(ivLen) b.Resize(ivLen, ivLen+int32(len(bbb)))
return nil return nil
} }

View File

@ -30,13 +30,11 @@ func NewAuthenticator(keygen KeyGenerator) *Authenticator {
} }
} }
func (v *Authenticator) Authenticate(data []byte) buf.Supplier { func (v *Authenticator) Authenticate(data []byte, dest []byte) {
hasher := hmac.New(sha1.New, v.key()) hasher := hmac.New(sha1.New, v.key())
common.Must2(hasher.Write(data)) common.Must2(hasher.Write(data))
res := hasher.Sum(nil) res := hasher.Sum(nil)
return func(b []byte) (int, error) { copy(dest, res[:AuthSize])
return copy(b, res[:AuthSize]), nil
}
} }
func HeaderKeyGenerator(key []byte, iv []byte) func() []byte { func HeaderKeyGenerator(key []byte, iv []byte) func() []byte {
@ -89,7 +87,7 @@ func (v *ChunkReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
payload := buffer[AuthSize:size] payload := buffer[AuthSize:size]
actualAuthBytes := make([]byte, AuthSize) actualAuthBytes := make([]byte, AuthSize)
v.auth.Authenticate(payload)(actualAuthBytes) v.auth.Authenticate(payload, actualAuthBytes)
if !bytes.Equal(authBytes, actualAuthBytes) { if !bytes.Equal(authBytes, actualAuthBytes) {
return nil, newError("invalid auth") return nil, newError("invalid auth")
} }
@ -121,7 +119,7 @@ func (w *ChunkWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
for { for {
payloadLen, _ := mb.Read(w.buffer[2+AuthSize:]) payloadLen, _ := mb.Read(w.buffer[2+AuthSize:])
binary.BigEndian.PutUint16(w.buffer, uint16(payloadLen)) binary.BigEndian.PutUint16(w.buffer, uint16(payloadLen))
w.auth.Authenticate(w.buffer[2+AuthSize : 2+AuthSize+payloadLen])(w.buffer[2:]) w.auth.Authenticate(w.buffer[2+AuthSize:2+AuthSize+payloadLen], w.buffer[2:])
if err := buf.WriteAllBytes(w.writer, w.buffer[:2+AuthSize+payloadLen]); err != nil { if err := buf.WriteAllBytes(w.writer, w.buffer[:2+AuthSize+payloadLen]); err != nil {
return err return err
} }

View File

@ -142,7 +142,8 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri
header.SetByte(0, header.Byte(0)|0x10) header.SetByte(0, header.Byte(0)|0x10)
authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv)) authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv))
common.Must(header.AppendSupplier(authenticator.Authenticate(header.Bytes()))) authBuffer := header.Extend(AuthSize)
authenticator.Authenticate(header.Bytes(), authBuffer)
} }
if err := w.WriteMultiBuffer(buf.NewMultiBufferValue(header)); err != nil { if err := w.WriteMultiBuffer(buf.NewMultiBufferValue(header)); err != nil {
@ -210,7 +211,9 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff
authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv)) authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv))
buffer.SetByte(ivLen, buffer.Byte(ivLen)|0x10) buffer.SetByte(ivLen, buffer.Byte(ivLen)|0x10)
common.Must(buffer.AppendSupplier(authenticator.Authenticate(buffer.BytesFrom(ivLen)))) authPayload := buffer.BytesFrom(ivLen)
authBuffer := buffer.Extend(AuthSize)
authenticator.Authenticate(authPayload, authBuffer)
} }
if err := account.Cipher.EncodePacket(account.Key, buffer); err != nil { if err := account.Cipher.EncodePacket(account.Key, buffer); err != nil {
return nil, newError("failed to encrypt UDP payload").Base(err) return nil, newError("failed to encrypt UDP payload").Base(err)

View File

@ -7,7 +7,6 @@ import (
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial"
. "v2ray.com/core/proxy/shadowsocks" . "v2ray.com/core/proxy/shadowsocks"
. "v2ray.com/ext/assert" . "v2ray.com/ext/assert"
) )
@ -37,7 +36,7 @@ func TestUDPEncoding(t *testing.T) {
} }
data := buf.New() data := buf.New()
data.AppendSupplier(serial.WriteString("test string")) common.Must2(data.WriteString("test string"))
encodedData, err := EncodeUDPPacket(request, data.Bytes()) encodedData, err := EncodeUDPPacket(request, data.Bytes())
assert(err, IsNil) assert(err, IsNil)
@ -168,7 +167,7 @@ func TestUDPReaderWriter(t *testing.T) {
{ {
b := buf.New() b := buf.New()
b.AppendSupplier(serial.WriteString("test payload")) common.Must2(b.WriteString("test payload"))
err := writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)) err := writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
assert(err, IsNil) assert(err, IsNil)
@ -179,7 +178,7 @@ func TestUDPReaderWriter(t *testing.T) {
{ {
b := buf.New() b := buf.New()
b.AppendSupplier(serial.WriteString("test payload 2")) common.Must2(b.WriteString("test payload 2"))
err := writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)) err := writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
assert(err, IsNil) assert(err, IsNil)

View File

@ -16,7 +16,6 @@ import (
"v2ray.com/core/common/crypto" "v2ray.com/core/common/crypto"
"v2ray.com/core/common/dice" "v2ray.com/core/common/dice"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial"
"v2ray.com/core/common/vio" "v2ray.com/core/common/vio"
"v2ray.com/core/proxy/vmess" "v2ray.com/core/proxy/vmess"
) )
@ -88,7 +87,8 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
{ {
fnv1a := fnv.New32a() fnv1a := fnv.New32a()
common.Must2(fnv1a.Write(buffer.Bytes())) common.Must2(fnv1a.Write(buffer.Bytes()))
common.Must(buffer.AppendSupplier(serial.WriteHash(fnv1a))) hashBytes := buffer.Extend(int32(fnv1a.Size()))
fnv1a.Sum(hashBytes[:0])
} }
iv := hashTimestamp(md5.New(), timestamp) iv := hashTimestamp(md5.New(), timestamp)

View File

@ -9,7 +9,7 @@ import (
type PacketHeader interface { type PacketHeader interface {
Size() int32 Size() int32
Write([]byte) (int, error) Serialize([]byte)
} }
func CreatePacketHeader(config interface{}) (PacketHeader, error) { func CreatePacketHeader(config interface{}) (PacketHeader, error) {

View File

@ -13,7 +13,6 @@ import (
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/serial"
) )
const ( const (
@ -29,7 +28,6 @@ const (
var ( var (
ErrHeaderToLong = newError("Header too long.") ErrHeaderToLong = newError("Header too long.")
writeCRLF = serial.WriteString(CRLF)
) )
type Reader interface { type Reader interface {
@ -70,12 +68,12 @@ func (*HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
endingDetected = true endingDetected = true
break break
} }
if buffer.Len() >= int32(len(ENDING)) { lenEnding := int32(len(ENDING))
totalBytes += buffer.Len() - int32(len(ENDING)) if buffer.Len() >= lenEnding {
leftover := buffer.BytesFrom(-int32(len(ENDING))) totalBytes += buffer.Len() - lenEnding
buffer.Reset(func(b []byte) (int, error) { leftover := buffer.BytesFrom(-lenEnding)
return copy(b, leftover), nil buffer.Clear()
}) copy(buffer.Extend(lenEnding), leftover)
} }
} }
if buffer.IsEmpty() { if buffer.IsEmpty() {
@ -175,20 +173,20 @@ func (c *HttpConn) Close() error {
func formResponseHeader(config *ResponseConfig) *HeaderWriter { func formResponseHeader(config *ResponseConfig) *HeaderWriter {
header := buf.New() header := buf.New()
header.AppendSupplier(serial.WriteString(strings.Join([]string{config.GetFullVersion(), config.GetStatusValue().Code, config.GetStatusValue().Reason}, " "))) common.Must2(header.WriteString(strings.Join([]string{config.GetFullVersion(), config.GetStatusValue().Code, config.GetStatusValue().Reason}, " ")))
header.AppendSupplier(writeCRLF) common.Must2(header.WriteString(CRLF))
headers := config.PickHeaders() headers := config.PickHeaders()
for _, h := range headers { for _, h := range headers {
header.AppendSupplier(serial.WriteString(h)) common.Must2(header.WriteString(h))
header.AppendSupplier(writeCRLF) common.Must2(header.WriteString(CRLF))
} }
if !config.HasHeader("Date") { if !config.HasHeader("Date") {
header.AppendSupplier(serial.WriteString("Date: ")) common.Must2(header.WriteString("Date: "))
header.AppendSupplier(serial.WriteString(time.Now().Format(http.TimeFormat))) common.Must2(header.WriteString(time.Now().Format(http.TimeFormat)))
header.AppendSupplier(writeCRLF) common.Must2(header.WriteString(CRLF))
} }
header.AppendSupplier(writeCRLF) common.Must2(header.WriteString(CRLF))
return &HeaderWriter{ return &HeaderWriter{
header: header, header: header,
} }
@ -201,15 +199,15 @@ type HttpAuthenticator struct {
func (a HttpAuthenticator) GetClientWriter() *HeaderWriter { func (a HttpAuthenticator) GetClientWriter() *HeaderWriter {
header := buf.New() header := buf.New()
config := a.config.Request config := a.config.Request
header.AppendSupplier(serial.WriteString(strings.Join([]string{config.GetMethodValue(), config.PickUri(), config.GetFullVersion()}, " "))) common.Must2(header.WriteString(strings.Join([]string{config.GetMethodValue(), config.PickUri(), config.GetFullVersion()}, " ")))
header.AppendSupplier(writeCRLF) common.Must2(header.WriteString(CRLF))
headers := config.PickHeaders() headers := config.PickHeaders()
for _, h := range headers { for _, h := range headers {
header.AppendSupplier(serial.WriteString(h)) common.Must2(header.WriteString(h))
header.AppendSupplier(writeCRLF) common.Must2(header.WriteString(CRLF))
} }
header.AppendSupplier(writeCRLF) common.Must2(header.WriteString(CRLF))
return &HeaderWriter{ return &HeaderWriter{
header: header, header: header,
} }

View File

@ -5,9 +5,9 @@ import (
"testing" "testing"
"time" "time"
"v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/serial"
. "v2ray.com/core/transport/internet/headers/http" . "v2ray.com/core/transport/internet/headers/http"
. "v2ray.com/ext/assert" . "v2ray.com/ext/assert"
) )
@ -17,7 +17,7 @@ func TestReaderWriter(t *testing.T) {
cache := buf.New() cache := buf.New()
b := buf.New() b := buf.New()
b.AppendSupplier(serial.WriteString("abcd" + ENDING)) common.Must2(b.WriteString("abcd" + ENDING))
writer := NewHeaderWriter(b) writer := NewHeaderWriter(b)
err := writer.Write(cache) err := writer.Write(cache)
assert(err, IsNil) assert(err, IsNil)

View File

@ -13,10 +13,8 @@ func (NoOpHeader) Size() int32 {
return 0 return 0
} }
// Write implements io.Writer. // Serialize implements PacketHeader.
func (NoOpHeader) Write([]byte) (int, error) { func (NoOpHeader) Serialize([]byte) {}
return 0, nil
}
func NewNoOpHeader(context.Context, interface{}) (interface{}, error) { func NewNoOpHeader(context.Context, interface{}) (interface{}, error) {
return NoOpHeader{}, nil return NoOpHeader{}, nil

View File

@ -17,12 +17,11 @@ func (*SRTP) Size() int32 {
return 4 return 4
} }
// Write implements io.Writer. // Serialize implements PacketHeader.
func (s *SRTP) Write(b []byte) (int, error) { func (s *SRTP) Serialize(b []byte) {
s.number++ s.number++
binary.BigEndian.PutUint16(b, s.number) binary.BigEndian.PutUint16(b, s.number)
binary.BigEndian.PutUint16(b[2:], s.number) binary.BigEndian.PutUint16(b[2:], s.number)
return 4, nil
} }
// New returns a new SRTP instance based on the given config. // New returns a new SRTP instance based on the given config.

View File

@ -19,7 +19,7 @@ func TestSRTPWrite(t *testing.T) {
srtp := srtpRaw.(*SRTP) srtp := srtpRaw.(*SRTP)
payload := buf.New() payload := buf.New()
payload.AppendSupplier(srtp.Write) srtp.Serialize(payload.Extend(srtp.Size()))
payload.Write(content) payload.Write(content)
assert(payload.Len(), Equals, int32(len(content))+srtp.Size()) assert(payload.Len(), Equals, int32(len(content))+srtp.Size())

View File

@ -19,8 +19,8 @@ func (*DTLS) Size() int32 {
return 1 + 2 + 2 + 6 + 2 return 1 + 2 + 2 + 6 + 2
} }
// Write implements PacketHeader. // Serialize implements PacketHeader.
func (d *DTLS) Write(b []byte) (int, error) { func (d *DTLS) Serialize(b []byte) {
b[0] = 23 // application data b[0] = 23 // application data
b[1] = 254 b[1] = 254
b[2] = 253 b[2] = 253
@ -39,7 +39,6 @@ func (d *DTLS) Write(b []byte) (int, error) {
if d.length > 100 { if d.length > 100 {
d.length -= 50 d.length -= 50
} }
return 13, nil
} }
// New creates a new UTP header for the given config. // New creates a new UTP header for the given config.

View File

@ -19,7 +19,7 @@ func TestDTLSWrite(t *testing.T) {
dtls := dtlsRaw.(*DTLS) dtls := dtlsRaw.(*DTLS)
payload := buf.New() payload := buf.New()
payload.AppendSupplier(dtls.Write) dtls.Serialize(payload.Extend(dtls.Size()))
payload.Write(content) payload.Write(content)
assert(payload.Len(), Equals, int32(len(content))+dtls.Size()) assert(payload.Len(), Equals, int32(len(content))+dtls.Size())

View File

@ -18,12 +18,11 @@ func (*UTP) Size() int32 {
return 4 return 4
} }
// Write implements io.Writer. // Serialize implements PacketHeader.
func (u *UTP) Write(b []byte) (int, error) { func (u *UTP) Serialize(b []byte) {
binary.BigEndian.PutUint16(b, u.connectionId) binary.BigEndian.PutUint16(b, u.connectionId)
b[2] = u.header b[2] = u.header
b[3] = u.extension b[3] = u.extension
return 4, nil
} }
// New creates a new UTP header for the given config. // New creates a new UTP header for the given config.

View File

@ -19,7 +19,7 @@ func TestUTPWrite(t *testing.T) {
utp := utpRaw.(*UTP) utp := utpRaw.(*UTP)
payload := buf.New() payload := buf.New()
payload.AppendSupplier(utp.Write) utp.Serialize(payload.Extend(utp.Size()))
payload.Write(content) payload.Write(content)
assert(payload.Len(), Equals, int32(len(content))+utp.Size()) assert(payload.Len(), Equals, int32(len(content))+utp.Size())

View File

@ -16,8 +16,8 @@ func (vc *VideoChat) Size() int32 {
return 13 return 13
} }
// Write implements io.Writer. // Serialize implements PacketHeader.
func (vc *VideoChat) Write(b []byte) (int, error) { func (vc *VideoChat) Serialize(b []byte) {
vc.sn++ vc.sn++
b[0] = 0xa1 b[0] = 0xa1
b[1] = 0x08 b[1] = 0x08
@ -29,7 +29,6 @@ func (vc *VideoChat) Write(b []byte) (int, error) {
b[10] = 0x30 b[10] = 0x30
b[11] = 0x22 b[11] = 0x22
b[12] = 0x30 b[12] = 0x30
return 13, nil
} }
// NewVideoChat returns a new VideoChat instance based on given config. // NewVideoChat returns a new VideoChat instance based on given config.

View File

@ -18,7 +18,7 @@ func TestUTPWrite(t *testing.T) {
video := videoRaw.(*VideoChat) video := videoRaw.(*VideoChat)
payload := buf.New() payload := buf.New()
payload.AppendSupplier(video.Write) video.Serialize(payload.Extend(video.Size()))
assert(payload.Len(), Equals, video.Size()) assert(payload.Len(), Equals, video.Size())
} }

View File

@ -12,10 +12,12 @@ func (Wireguard) Size() int32 {
return 4 return 4
} }
// Write implements io.Writer. // Serialize implements PacketHeader.
func (Wireguard) Write(b []byte) (int, error) { func (Wireguard) Serialize(b []byte) {
b = append(b[:0], 0x04, 0x00, 0x00, 0x00) b[0] = 0x04
return 4, nil b[1] = 0x00
b[2] = 0x00
b[3] = 0x00
} }
// NewWireguard returns a new VideoChat instance based on given config. // NewWireguard returns a new VideoChat instance based on given config.

View File

@ -74,20 +74,15 @@ func (w *KCPPacketWriter) Write(b []byte) (int, error) {
defer bb.Release() defer bb.Release()
if w.Header != nil { if w.Header != nil {
common.Must(bb.AppendSupplier(func(x []byte) (int, error) { w.Header.Serialize(bb.Extend(w.Header.Size()))
return w.Header.Write(x)
}))
} }
if w.Security != nil { if w.Security != nil {
nonceSize := w.Security.NonceSize() nonceSize := w.Security.NonceSize()
common.Must(bb.AppendSupplier(func(x []byte) (int, error) { common.Must2(bb.ReadFullFrom(rand.Reader, int32(nonceSize)))
return rand.Read(x[:nonceSize])
}))
nonce := bb.BytesFrom(int32(-nonceSize)) nonce := bb.BytesFrom(int32(-nonceSize))
common.Must(bb.AppendSupplier(func(x []byte) (int, error) {
eb := w.Security.Seal(x[:0], nonce, b, nil) encrypted := bb.Extend(int32(w.Security.Overhead() + len(b)))
return len(eb), nil w.Security.Seal(encrypted[:0], nonce, b, nil)
}))
} else { } else {
bb.Write(b) bb.Write(b)
} }

View File

@ -6,7 +6,6 @@ import (
"v2ray.com/core/common/retry" "v2ray.com/core/common/retry"
"v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
) )
@ -31,7 +30,9 @@ func (w *SimpleSegmentWriter) Write(seg Segment) error {
w.Lock() w.Lock()
defer w.Unlock() defer w.Unlock()
common.Must(w.buffer.Reset(seg.Bytes())) w.buffer.Clear()
rawBytes := w.buffer.Extend(seg.ByteSize())
seg.Serialize(rawBytes)
_, err := w.writer.Write(w.buffer.Bytes()) _, err := w.writer.Write(w.buffer.Bytes())
return err return err
} }

View File

@ -31,7 +31,7 @@ type Segment interface {
Conversation() uint16 Conversation() uint16
Command() Command Command() Command
ByteSize() int32 ByteSize() int32
Bytes() buf.Supplier Serialize([]byte)
parse(conv uint16, cmd Command, opt SegmentOption, buf []byte) (bool, []byte) parse(conv uint16, cmd Command, opt SegmentOption, buf []byte) (bool, []byte)
} }
@ -104,8 +104,7 @@ func (s *DataSegment) Data() *buf.Buffer {
return s.payload return s.payload
} }
func (s *DataSegment) Bytes() buf.Supplier { func (s *DataSegment) Serialize(b []byte) {
return func(b []byte) (int, error) {
binary.BigEndian.PutUint16(b, s.Conv) binary.BigEndian.PutUint16(b, s.Conv)
b[2] = byte(CommandData) b[2] = byte(CommandData)
b[3] = byte(s.Option) b[3] = byte(s.Option)
@ -113,9 +112,7 @@ func (s *DataSegment) Bytes() buf.Supplier {
binary.BigEndian.PutUint32(b[8:], s.Number) binary.BigEndian.PutUint32(b[8:], s.Number)
binary.BigEndian.PutUint32(b[12:], s.SendingNext) binary.BigEndian.PutUint32(b[12:], s.SendingNext)
binary.BigEndian.PutUint16(b[16:], uint16(s.payload.Len())) binary.BigEndian.PutUint16(b[16:], uint16(s.payload.Len()))
n := copy(b[18:], s.payload.Bytes()) copy(b[18:], s.payload.Bytes())
return 18 + n, nil
}
} }
func (s *DataSegment) ByteSize() int32 { func (s *DataSegment) ByteSize() int32 {
@ -202,8 +199,7 @@ func (s *AckSegment) ByteSize() int32 {
return 2 + 1 + 1 + 4 + 4 + 4 + 1 + int32(len(s.NumberList)*4) return 2 + 1 + 1 + 4 + 4 + 4 + 1 + int32(len(s.NumberList)*4)
} }
func (s *AckSegment) Bytes() buf.Supplier { func (s *AckSegment) Serialize(b []byte) {
return func(b []byte) (int, error) {
binary.BigEndian.PutUint16(b, s.Conv) binary.BigEndian.PutUint16(b, s.Conv)
b[2] = byte(CommandACK) b[2] = byte(CommandACK)
b[3] = byte(s.Option) b[3] = byte(s.Option)
@ -216,8 +212,6 @@ func (s *AckSegment) Bytes() buf.Supplier {
binary.BigEndian.PutUint32(b[n:], number) binary.BigEndian.PutUint32(b[n:], number)
n += 4 n += 4
} }
return n, nil
}
} }
func (s *AckSegment) Release() {} func (s *AckSegment) Release() {}
@ -268,16 +262,13 @@ func (*CmdOnlySegment) ByteSize() int32 {
return 2 + 1 + 1 + 4 + 4 + 4 return 2 + 1 + 1 + 4 + 4 + 4
} }
func (s *CmdOnlySegment) Bytes() buf.Supplier { func (s *CmdOnlySegment) Serialize(b []byte) {
return func(b []byte) (int, error) {
binary.BigEndian.PutUint16(b, s.Conv) binary.BigEndian.PutUint16(b, s.Conv)
b[2] = byte(s.Cmd) b[2] = byte(s.Cmd)
b[3] = byte(s.Option) b[3] = byte(s.Option)
binary.BigEndian.PutUint32(b[4:], s.SendingNext) binary.BigEndian.PutUint32(b[4:], s.SendingNext)
binary.BigEndian.PutUint32(b[8:], s.ReceivingNext) binary.BigEndian.PutUint32(b[8:], s.ReceivingNext)
binary.BigEndian.PutUint32(b[12:], s.PeerRTO) binary.BigEndian.PutUint32(b[12:], s.PeerRTO)
return 16, nil
}
} }
func (*CmdOnlySegment) Release() {} func (*CmdOnlySegment) Release() {}

View File

@ -28,7 +28,7 @@ func TestDataSegment(t *testing.T) {
nBytes := seg.ByteSize() nBytes := seg.ByteSize()
bytes := make([]byte, nBytes) bytes := make([]byte, nBytes)
seg.Bytes()(bytes) seg.Serialize(bytes)
assert(int32(len(bytes)), Equals, nBytes) assert(int32(len(bytes)), Equals, nBytes)
@ -54,7 +54,7 @@ func Test1ByteDataSegment(t *testing.T) {
nBytes := seg.ByteSize() nBytes := seg.ByteSize()
bytes := make([]byte, nBytes) bytes := make([]byte, nBytes)
seg.Bytes()(bytes) seg.Serialize(bytes)
assert(int32(len(bytes)), Equals, nBytes) assert(int32(len(bytes)), Equals, nBytes)
@ -80,7 +80,7 @@ func TestACKSegment(t *testing.T) {
nBytes := seg.ByteSize() nBytes := seg.ByteSize()
bytes := make([]byte, nBytes) bytes := make([]byte, nBytes)
seg.Bytes()(bytes) seg.Serialize(bytes)
assert(int32(len(bytes)), Equals, nBytes) assert(int32(len(bytes)), Equals, nBytes)
@ -110,7 +110,7 @@ func TestCmdSegment(t *testing.T) {
nBytes := seg.ByteSize() nBytes := seg.ByteSize()
bytes := make([]byte, nBytes) bytes := make([]byte, nBytes)
seg.Bytes()(bytes) seg.Serialize(bytes)
assert(int32(len(bytes)), Equals, nBytes) assert(int32(len(bytes)), Equals, nBytes)

View File

@ -88,18 +88,15 @@ func (h *Hub) start() {
buffer := buf.New() buffer := buf.New()
var noob int var noob int
var addr *net.UDPAddr var addr *net.UDPAddr
err := buffer.Reset(func(b []byte) (int, error) { rawBytes := buffer.Extend(buf.Size)
n, nb, _, a, e := ReadUDPMsg(h.conn, b, oobBytes)
noob = nb
addr = a
return n, e
})
n, noob, _, addr, err := ReadUDPMsg(h.conn, rawBytes, oobBytes)
if err != nil { if err != nil {
newError("failed to read UDP msg").Base(err).WriteToLog() newError("failed to read UDP msg").Base(err).WriteToLog()
buffer.Release() buffer.Release()
break break
} }
buffer.Resize(0, int32(n))
if buffer.IsEmpty() { if buffer.IsEmpty() {
buffer.Release() buffer.Release()