1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-21 09:36:34 -05:00

multi buffer

This commit is contained in:
Darien Raymond 2017-04-15 21:07:23 +02:00
parent 0ef9143ffd
commit f506a39d32
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
29 changed files with 390 additions and 297 deletions

View File

@ -180,31 +180,20 @@ func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool
}
func drain(reader *Reader) error {
for {
data, more, err := reader.Read()
if err != nil {
return err
}
data.Release()
if !more {
return nil
}
data, err := reader.Read()
if err != nil {
return err
}
data.Release()
return nil
}
func pipe(reader *Reader, writer buf.Writer) error {
for {
data, more, err := reader.Read()
if err != nil {
return err
}
if err := writer.Write(data); err != nil {
return err
}
if !more {
return nil
}
data, err := reader.Read()
if err != nil {
return err
}
return writer.Write(data)
}
func (m *Client) handleStatueKeepAlive(meta *FrameMetadata, reader *Reader) error {

View File

@ -20,7 +20,7 @@ func TestReaderWriter(t *testing.T) {
payload := buf.New()
payload.AppendBytes('a', 'b', 'c', 'd')
assert.Error(writer.Write(payload)).IsNil()
assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil()
writer.Close()
@ -32,10 +32,9 @@ func TestReaderWriter(t *testing.T) {
assert.Destination(meta.Target).Equals(dest)
assert.Byte(byte(meta.Option)).Equals(byte(OptionData))
data, more, err := reader.Read()
data, err := reader.Read()
assert.Error(err).IsNil()
assert.Bool(more).IsFalse()
assert.String(data.String()).Equals("abcd")
assert.String(data[0].String()).Equals("abcd")
meta, err = reader.ReadMetadata()
assert.Error(err).IsNil()

View File

@ -8,9 +8,8 @@ import (
)
type Reader struct {
reader io.Reader
remainingLength int
buffer *buf.Buffer
reader io.Reader
buffer *buf.Buffer
}
func NewReader(reader buf.Reader) *Reader {
@ -38,28 +37,27 @@ func (r *Reader) ReadMetadata() (*FrameMetadata, error) {
return ReadFrameFrom(b.Bytes())
}
func (r *Reader) Read() (*buf.Buffer, bool, error) {
b := buf.New()
var dataLen int
if r.remainingLength > 0 {
dataLen = r.remainingLength
r.remainingLength = 0
} else {
if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, 2)); err != nil {
return nil, false, err
func (r *Reader) Read() (buf.MultiBuffer, error) {
r.buffer.Clear()
if err := r.buffer.AppendSupplier(buf.ReadFullFrom(r.reader, 2)); err != nil {
return nil, err
}
dataLen := int(serial.BytesToUint16(r.buffer.Bytes()))
mb := buf.NewMultiBuffer()
for dataLen > 0 {
b := buf.New()
readLen := buf.Size
if dataLen < readLen {
readLen = dataLen
}
dataLen = int(serial.BytesToUint16(b.Bytes()))
b.Clear()
if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, readLen)); err != nil {
mb.Release()
return nil, err
}
dataLen -= readLen
mb.Append(b)
}
if dataLen > buf.Size {
r.remainingLength = dataLen - buf.Size
dataLen = buf.Size
}
if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, dataLen)); err != nil {
return nil, false, err
}
return b, (r.remainingLength > 0), nil
return mb, nil
}

View File

@ -29,7 +29,7 @@ func NewResponseWriter(id uint16, writer buf.Writer) *Writer {
}
}
func (w *Writer) writeInternal(b *buf.Buffer) error {
func (w *Writer) Write(mb buf.MultiBuffer) error {
meta := FrameMetadata{
SessionID: w.id,
Target: w.dest,
@ -41,42 +41,21 @@ func (w *Writer) writeInternal(b *buf.Buffer) error {
meta.SessionStatus = SessionStatusNew
}
if b.Len() > 0 {
if mb.Len() > 0 {
meta.Option.Add(OptionData)
}
frame := buf.New()
frame.AppendSupplier(meta.AsSupplier())
if b.Len() > 0 {
frame.AppendSupplier(serial.WriteUint16(0))
lengthBytes := frame.BytesFrom(-2)
mb2 := buf.NewMultiBuffer()
mb2.Append(frame)
nBytes, err := frame.Write(b.Bytes())
if err != nil {
frame.Release()
return err
}
serial.Uint16ToBytes(uint16(nBytes), lengthBytes[:0])
b.SliceFrom(nBytes)
if mb.Len() > 0 {
frame.AppendSupplier(serial.WriteUint16(uint16(mb.Len())))
mb2.AppendMulti(mb)
}
return w.writer.Write(frame)
}
func (w *Writer) Write(b *buf.Buffer) error {
defer b.Release()
if err := w.writeInternal(b); err != nil {
return err
}
for !b.IsEmpty() {
if err := w.writeInternal(b); err != nil {
return err
}
}
return nil
return w.writer.Write(mb2)
}
func (w *Writer) Close() {
@ -88,5 +67,8 @@ func (w *Writer) Close() {
frame := buf.New()
frame.AppendSupplier(meta.AsSupplier())
w.writer.Write(frame)
mb := buf.NewMultiBuffer()
mb.Append(frame)
w.writer.Write(mb)
}

View File

@ -129,7 +129,7 @@ type Connection struct {
remoteAddr net.Addr
reader io.Reader
writer io.Writer
writer buf.Writer
}
func NewConnection(stream ray.Ray) *Connection {
@ -144,7 +144,7 @@ func NewConnection(stream ray.Ray) *Connection {
Port: 0,
},
reader: buf.ToBytesReader(stream.InboundOutput()),
writer: buf.ToBytesWriter(stream.InboundInput()),
writer: stream.InboundInput(),
}
}
@ -161,7 +161,14 @@ func (v *Connection) Write(b []byte) (int, error) {
if v.closed {
return 0, io.ErrClosedPipe
}
return v.writer.Write(b)
return buf.ToBytesWriter(v.writer).Write(b)
}
func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) (int, error) {
if v.closed {
return 0, io.ErrClosedPipe
}
return mb.Len(), v.writer.Write(mb)
}
// Close implements net.Conn.Close().

View File

@ -11,19 +11,19 @@ import (
// Reader extends io.Reader with alloc.Buffer.
type Reader interface {
// Read reads content from underlying reader, and put it into an alloc.Buffer.
Read() (*Buffer, error)
Read() (MultiBuffer, error)
}
var ErrReadTimeout = newError("IO timeout")
type TimeoutReader interface {
ReadTimeout(time.Duration) (*Buffer, error)
ReadTimeout(time.Duration) (MultiBuffer, error)
}
// Writer extends io.Writer with alloc.Buffer.
type Writer interface {
// Write writes an alloc.Buffer into underlying writer.
Write(*Buffer) error
Write(MultiBuffer) error
}
// ReadFrom creates a Supplier to read from a given io.Reader.
@ -78,6 +78,7 @@ func PipeUntilEOF(timer signal.ActivityTimer, reader Reader, writer Writer) erro
func NewReader(reader io.Reader) Reader {
return &BytesToBufferReader{
reader: reader,
buffer: NewLocal(32 * 1024),
}
}

View File

@ -3,7 +3,6 @@ package buf
type MergingReader struct {
reader Reader
timeoutReader TimeoutReader
leftover *Buffer
}
func NewMergingReader(reader Reader) Reader {
@ -13,41 +12,23 @@ func NewMergingReader(reader Reader) Reader {
}
}
func (r *MergingReader) Read() (*Buffer, error) {
if r.leftover != nil {
b := r.leftover
r.leftover = nil
return b, nil
}
b, err := r.reader.Read()
func (r *MergingReader) Read() (MultiBuffer, error) {
mb, err := r.reader.Read()
if err != nil {
return nil, err
}
if b.IsFull() {
return b, nil
}
if r.timeoutReader == nil {
return b, nil
return mb, nil
}
for {
b2, err := r.timeoutReader.ReadTimeout(0)
mb2, err := r.timeoutReader.ReadTimeout(0)
if err != nil {
break
}
nBytes := b.Append(b2.Bytes())
b2.SliceFrom(nBytes)
if b2.IsEmpty() {
b2.Release()
} else {
r.leftover = b2
break
}
mb.AppendMulti(mb2)
}
return b, nil
return mb, nil
}

View File

@ -16,18 +16,18 @@ func TestMergingReader(t *testing.T) {
stream := ray.NewStream(context.Background())
b1 := New()
b1.AppendBytes('a', 'b', 'c')
stream.Write(b1)
stream.Write(NewMultiBufferValue(b1))
b2 := New()
b2.AppendBytes('e', 'f', 'g')
stream.Write(b2)
stream.Write(NewMultiBufferValue(b2))
b3 := New()
b3.AppendBytes('h', 'i', 'j')
stream.Write(b3)
stream.Write(NewMultiBufferValue(b3))
reader := NewMergingReader(stream)
b, err := reader.Read()
assert.Error(err).IsNil()
assert.String(b.String()).Equals("abcefghij")
assert.Int(b.Len()).Equals(9)
}

View File

@ -0,0 +1,88 @@
package buf
import (
"io"
"net"
)
type MultiBufferWriter interface {
WriteMultiBuffer(MultiBuffer) (int, error)
}
type MultiBuffer []*Buffer
func NewMultiBuffer() MultiBuffer {
return MultiBuffer(make([]*Buffer, 0, 8))
}
func NewMultiBufferValue(b ...*Buffer) MultiBuffer {
return MultiBuffer(b)
}
func (b *MultiBuffer) Append(buf *Buffer) {
*b = append(*b, buf)
}
func (b *MultiBuffer) AppendMulti(mb MultiBuffer) {
*b = append(*b, mb...)
}
func (mb *MultiBuffer) Read(b []byte) (int, error) {
if len(*mb) == 0 {
return 0, io.EOF
}
endIndex := len(*mb)
totalBytes := 0
for i, bb := range *mb {
nBytes, err := bb.Read(b)
totalBytes += nBytes
if err != nil {
return totalBytes, err
}
b = b[nBytes:]
if bb.IsEmpty() {
bb.Release()
} else {
endIndex = i
break
}
}
*mb = (*mb)[endIndex:]
return totalBytes, nil
}
func (mb MultiBuffer) WriteTo(writer io.Writer) (int, error) {
if mw, ok := writer.(MultiBufferWriter); ok {
return mw.WriteMultiBuffer(mb)
}
bs := make([][]byte, len(mb))
for i, b := range mb {
bs[i] = b.Bytes()
}
nbs := net.Buffers(bs)
nBytes, err := nbs.WriteTo(writer)
return int(nBytes), err
}
func (mb MultiBuffer) Len() int {
size := 0
for _, b := range mb {
size += b.Len()
}
return size
}
func (mb MultiBuffer) IsEmpty() bool {
for _, b := range mb {
if !b.IsEmpty() {
return false
}
}
return true
}
func (mb MultiBuffer) Release() {
for _, b := range mb {
b.Release()
}
}

View File

@ -0,0 +1,25 @@
package buf_test
import (
"testing"
. "v2ray.com/core/common/buf"
"v2ray.com/core/testing/assert"
)
func TestMultiBufferRead(t *testing.T) {
assert := assert.On(t)
b1 := New()
b1.AppendBytes('a', 'b')
b2 := New()
b2.AppendBytes('c', 'd')
mb := NewMultiBufferValue(b1, b2)
bs := make([]byte, 32)
nBytes, err := mb.Read(bs)
assert.Error(err).IsNil()
assert.Int(nBytes).Equals(4)
assert.Bytes(bs[:nBytes]).Equals([]byte("abcd"))
}

View File

@ -4,48 +4,28 @@ import "io"
// BytesToBufferReader is a Reader that adjusts its reading speed automatically.
type BytesToBufferReader struct {
reader io.Reader
largeBuffer *Buffer
highVolumn bool
reader io.Reader
buffer *Buffer
}
// Read implements Reader.Read().
func (v *BytesToBufferReader) Read() (*Buffer, error) {
if v.highVolumn && v.largeBuffer.IsEmpty() {
if v.largeBuffer == nil {
v.largeBuffer = NewLocal(32 * 1024)
}
err := v.largeBuffer.AppendSupplier(ReadFrom(v.reader))
if err != nil {
return nil, err
}
if v.largeBuffer.Len() < Size {
v.highVolumn = false
}
}
buffer := New()
if !v.largeBuffer.IsEmpty() {
err := buffer.AppendSupplier(ReadFrom(v.largeBuffer))
return buffer, err
}
err := buffer.AppendSupplier(ReadFrom(v.reader))
if err != nil {
buffer.Release()
func (v *BytesToBufferReader) Read() (MultiBuffer, error) {
if err := v.buffer.Reset(ReadFrom(v.reader)); err != nil {
return nil, err
}
if buffer.IsFull() {
v.highVolumn = true
mb := NewMultiBuffer()
for !v.buffer.IsEmpty() {
b := New()
b.AppendSupplier(ReadFrom(v.buffer))
mb.Append(b)
}
return buffer, nil
return mb, nil
}
type bufferToBytesReader struct {
stream Reader
current *Buffer
current MultiBuffer
err error
}

View File

@ -15,14 +15,7 @@ func TestAdaptiveReader(t *testing.T) {
buffer := bytes.NewBuffer(rawContent)
reader := NewReader(buffer)
b1, err := reader.Read()
b, err := reader.Read()
assert.Error(err).IsNil()
assert.Bool(b1.IsFull()).IsTrue()
assert.Int(b1.Len()).Equals(Size)
assert.Int(buffer.Len()).Equals(cap(rawContent) - Size)
b2, err := reader.Read()
assert.Error(err).IsNil()
assert.Bool(b2.IsFull()).IsTrue()
assert.Int(buffer.Len()).Equals(1007616)
assert.Int(b.Len()).Equals(32 * 1024)
}

View File

@ -8,39 +8,30 @@ type BufferToBytesWriter struct {
}
// Write implements Writer.Write(). Write() takes ownership of the given buffer.
func (v *BufferToBytesWriter) Write(buffer *Buffer) error {
defer buffer.Release()
for {
nBytes, err := v.writer.Write(buffer.Bytes())
if err != nil {
return err
}
if nBytes == buffer.Len() {
break
}
buffer.SliceFrom(nBytes)
}
return nil
func (v *BufferToBytesWriter) Write(buffer MultiBuffer) error {
_, err := buffer.WriteTo(v.writer)
//buffer.Release()
return err
}
type bytesToBufferWriter struct {
writer Writer
}
func (v *bytesToBufferWriter) Write(payload []byte) (int, error) {
bytesWritten := 0
size := len(payload)
for size > 0 {
buffer := New()
nBytes, _ := buffer.Write(payload)
size -= nBytes
payload = payload[nBytes:]
bytesWritten += nBytes
err := v.writer.Write(buffer)
if err != nil {
return bytesWritten, err
}
func (w *bytesToBufferWriter) Write(payload []byte) (int, error) {
mb := NewMultiBuffer()
for p := payload; len(p) > 0; {
b := New()
nBytes, _ := b.Write(p)
p = p[nBytes:]
mb.Append(b)
}
return bytesWritten, nil
if err := w.writer.Write(mb); err != nil {
return 0, err
}
return len(payload), nil
}
func (w *bytesToBufferWriter) WriteMulteBuffer(mb MultiBuffer) (int, error) {
return mb.Len(), w.writer.Write(mb)
}

View File

@ -20,7 +20,7 @@ func TestWriter(t *testing.T) {
writeBuffer := bytes.NewBuffer(make([]byte, 0, 1024*1024))
writer := NewWriter(NewBufferedWriter(writeBuffer))
err := writer.Write(lb)
err := writer.Write(NewMultiBufferValue(lb))
assert.Error(err).IsNil()
assert.Bytes(expectedBytes).Equals(writeBuffer.Bytes())
}

View File

@ -215,14 +215,34 @@ func NewAuthenticationWriter(auth Authenticator, writer io.Writer, sizeMask Uint
}
}
func (v *AuthenticationWriter) Write(b []byte) (int, error) {
cipherChunk, err := v.auth.Seal(v.buffer[2:2], b)
func (w *AuthenticationWriter) Write(b []byte) (int, error) {
cipherChunk, err := w.auth.Seal(w.buffer[2:2], b)
if err != nil {
return 0, err
}
size := uint16(len(cipherChunk)) ^ v.sizeMask.Next()
serial.Uint16ToBytes(size, v.buffer[:0])
_, err = v.writer.Write(v.buffer[:2+len(cipherChunk)])
size := uint16(len(cipherChunk)) ^ w.sizeMask.Next()
serial.Uint16ToBytes(size, w.buffer[:0])
_, err = w.writer.Write(w.buffer[:2+len(cipherChunk)])
return len(b), err
}
func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) (int, error) {
const StartIndex = 17 * 1024
var totalBytes int
for {
payloadLen, err := mb.Read(w.buffer[StartIndex:])
if err != nil {
return 0, err
}
nBytes, err := w.Write(w.buffer[StartIndex : StartIndex+payloadLen])
totalBytes += nBytes
if err != nil {
return totalBytes, err
}
if mb.IsEmpty() {
break
}
}
return totalBytes, nil
}

View File

@ -28,7 +28,9 @@ func (v *NoneResponse) WriteTo(buf.Writer) {}
func (v *HTTPResponse) WriteTo(writer buf.Writer) {
b := buf.NewLocal(512)
b.AppendSupplier(serial.WriteString(http403response))
writer.Write(b)
mb := buf.NewMultiBuffer()
mb.Append(b)
writer.Write(mb)
}
// GetInternalResponse converts response settings from proto to internal data structure.

View File

@ -68,7 +68,7 @@ func NewChunkReader(reader io.Reader, auth *Authenticator) *ChunkReader {
}
}
func (v *ChunkReader) Read() (*buf.Buffer, error) {
func (v *ChunkReader) Read() (buf.MultiBuffer, error) {
buffer := buf.New()
if err := buffer.AppendSupplier(buf.ReadFullFrom(v.reader, 2)); err != nil {
buffer.Release()
@ -100,7 +100,10 @@ func (v *ChunkReader) Read() (*buf.Buffer, error) {
}
buffer.SliceFrom(AuthSize)
return buffer, nil
mb := buf.NewMultiBuffer()
mb.Append(buffer)
return mb, nil
}
type ChunkWriter struct {
@ -117,11 +120,22 @@ func NewChunkWriter(writer io.Writer, auth *Authenticator) *ChunkWriter {
}
}
func (v *ChunkWriter) Write(payload *buf.Buffer) error {
func (w *ChunkWriter) Write(mb buf.MultiBuffer) error {
defer mb.Release()
for _, b := range mb {
if err := w.writeInternal(b); err != nil {
return err
}
}
return nil
}
func (w *ChunkWriter) writeInternal(payload *buf.Buffer) error {
totalLength := payload.Len()
serial.Uint16ToBytes(uint16(totalLength), v.buffer[:0])
v.auth.Authenticate(payload.Bytes())(v.buffer[2:])
copy(v.buffer[2+AuthSize:], payload.Bytes())
_, err := v.writer.Write(v.buffer[:2+AuthSize+payload.Len()])
serial.Uint16ToBytes(uint16(totalLength), w.buffer[:0])
w.auth.Authenticate(payload.Bytes())(w.buffer[2:])
copy(w.buffer[2+AuthSize:], payload.Bytes())
_, err := w.writer.Write(w.buffer[:2+AuthSize+payload.Len()])
return err
}

View File

@ -18,7 +18,7 @@ func TestNormalChunkReading(t *testing.T) {
[]byte{21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36})))
payload, err := reader.Read()
assert.Error(err).IsNil()
assert.Bytes(payload.Bytes()).Equals([]byte{11, 12, 13, 14, 15, 16, 17, 18})
assert.Bytes(payload[0].Bytes()).Equals([]byte{11, 12, 13, 14, 15, 16, 17, 18})
}
func TestNormalChunkWriting(t *testing.T) {
@ -30,7 +30,7 @@ func TestNormalChunkWriting(t *testing.T) {
b := buf.NewLocal(256)
b.Append([]byte{11, 12, 13, 14, 15, 16, 17, 18})
err := writer.Write(b)
err := writer.Write(buf.NewMultiBufferValue(b))
assert.Error(err).IsNil()
assert.Bytes(buffer.Bytes()).Equals([]byte{0, 8, 39, 228, 69, 96, 133, 39, 254, 26, 201, 70, 11, 12, 13, 14, 15, 16, 17, 18})
}

View File

@ -362,7 +362,7 @@ type UDPReader struct {
User *protocol.User
}
func (v *UDPReader) Read() (*buf.Buffer, error) {
func (v *UDPReader) Read() (buf.MultiBuffer, error) {
buffer := buf.NewSmall()
err := buffer.AppendSupplier(buf.ReadFrom(v.Reader))
if err != nil {
@ -374,7 +374,9 @@ func (v *UDPReader) Read() (*buf.Buffer, error) {
buffer.Release()
return nil, err
}
return payload, nil
mb := buf.NewMultiBuffer()
mb.Append(payload)
return mb, nil
}
type UDPWriter struct {
@ -382,12 +384,21 @@ type UDPWriter struct {
Request *protocol.RequestHeader
}
func (v *UDPWriter) Write(buffer *buf.Buffer) error {
payload, err := EncodeUDPPacket(v.Request, buffer)
func (w *UDPWriter) Write(mb buf.MultiBuffer) error {
for _, b := range mb {
if err := w.writeInternal(b); err != nil {
return err
}
}
return nil
}
func (w *UDPWriter) writeInternal(buffer *buf.Buffer) error {
payload, err := EncodeUDPPacket(w.Request, buffer)
if err != nil {
return err
}
_, err = v.Writer.Write(payload.Bytes())
_, err = w.Writer.Write(payload.Bytes())
payload.Release()
return err
}

View File

@ -66,7 +66,7 @@ func TestTCPRequest(t *testing.T) {
writer, err := WriteTCPRequest(request, cache)
assert.Error(err).IsNil()
writer.Write(data)
writer.Write(buf.NewMultiBufferValue(data))
decodedRequest, reader, err := ReadTCPSession(request.User, cache)
assert.Error(err).IsNil()
@ -75,7 +75,7 @@ func TestTCPRequest(t *testing.T) {
decodedData, err := reader.Read()
assert.Error(err).IsNil()
assert.String(decodedData.String()).Equals("test string")
assert.String(decodedData[0].String()).Equals("test string")
}
func TestUDPReaderWriter(t *testing.T) {
@ -106,19 +106,19 @@ func TestUDPReaderWriter(t *testing.T) {
b := buf.New()
b.AppendSupplier(serial.WriteString("test payload"))
err := writer.Write(b)
err := writer.Write(buf.NewMultiBufferValue(b))
assert.Error(err).IsNil()
payload, err := reader.Read()
assert.Error(err).IsNil()
assert.String(payload.String()).Equals("test payload")
assert.String(payload[0].String()).Equals("test payload")
b = buf.New()
b.AppendSupplier(serial.WriteString("test payload 2"))
err = writer.Write(b)
err = writer.Write(buf.NewMultiBufferValue(b))
assert.Error(err).IsNil()
payload, err = reader.Read()
assert.Error(err).IsNil()
assert.String(payload.String()).Equals("test payload 2")
assert.String(payload[0].String()).Equals("test payload 2")
}

View File

@ -75,52 +75,54 @@ func (v *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
reader := buf.NewReader(conn)
for {
payload, err := reader.Read()
mpayload, err := reader.Read()
if err != nil {
break
}
request, data, err := DecodeUDPPacket(v.user, payload)
if err != nil {
if source, ok := proxy.SourceFromContext(ctx); ok {
log.Trace(newError("dropping invalid UDP packet from: ", source).Base(err))
log.Access(source, "", log.AccessRejected, err)
}
payload.Release()
continue
}
if request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Disabled {
log.Trace(newError("client payload enables OTA but server doesn't allow it"))
payload.Release()
continue
}
if !request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Enabled {
log.Trace(newError("client payload disables OTA but server forces it"))
payload.Release()
continue
}
dest := request.Destination()
if source, ok := proxy.SourceFromContext(ctx); ok {
log.Access(source, dest, log.AccessAccepted, "")
}
log.Trace(newError("tunnelling request to ", dest))
ctx = protocol.ContextWithUser(ctx, request.User)
udpServer.Dispatch(ctx, dest, data, func(payload *buf.Buffer) {
defer payload.Release()
data, err := EncodeUDPPacket(request, payload)
for _, payload := range mpayload {
request, data, err := DecodeUDPPacket(v.user, payload)
if err != nil {
log.Trace(newError("failed to encode UDP packet").Base(err).AtWarning())
return
if source, ok := proxy.SourceFromContext(ctx); ok {
log.Trace(newError("dropping invalid UDP packet from: ", source).Base(err))
log.Access(source, "", log.AccessRejected, err)
}
payload.Release()
continue
}
defer data.Release()
conn.Write(data.Bytes())
})
if request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Disabled {
log.Trace(newError("client payload enables OTA but server doesn't allow it"))
payload.Release()
continue
}
if !request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Enabled {
log.Trace(newError("client payload disables OTA but server forces it"))
payload.Release()
continue
}
dest := request.Destination()
if source, ok := proxy.SourceFromContext(ctx); ok {
log.Access(source, dest, log.AccessAccepted, "")
}
log.Trace(newError("tunnelling request to ", dest))
ctx = protocol.ContextWithUser(ctx, request.User)
udpServer.Dispatch(ctx, dest, data, func(payload *buf.Buffer) {
defer payload.Release()
data, err := EncodeUDPPacket(request, payload)
if err != nil {
log.Trace(newError("failed to encode UDP packet").Base(err).AtWarning())
return
}
defer data.Release()
conn.Write(data.Bytes())
})
}
}
return nil

View File

@ -347,7 +347,7 @@ func NewUDPReader(reader io.Reader) *UDPReader {
return &UDPReader{reader: reader}
}
func (r *UDPReader) Read() (*buf.Buffer, error) {
func (r *UDPReader) Read() (buf.MultiBuffer, error) {
b := buf.NewSmall()
if err := b.AppendSupplier(buf.ReadFrom(r.reader)); err != nil {
return nil, err
@ -358,7 +358,9 @@ func (r *UDPReader) Read() (*buf.Buffer, error) {
}
b.Clear()
b.Append(data)
return b, nil
mb := buf.NewMultiBuffer()
mb.Append(b)
return mb, nil
}
type UDPWriter struct {
@ -373,12 +375,15 @@ func NewUDPWriter(request *protocol.RequestHeader, writer io.Writer) *UDPWriter
}
}
func (w *UDPWriter) Write(b *buf.Buffer) error {
eb := EncodeUDPPacket(w.request, b.Bytes())
b.Release()
defer eb.Release()
if _, err := w.writer.Write(eb.Bytes()); err != nil {
return err
func (w *UDPWriter) Write(mb buf.MultiBuffer) error {
defer mb.Release()
for _, b := range mb {
eb := EncodeUDPPacket(w.request, b.Bytes())
defer eb.Release()
if _, err := w.writer.Write(eb.Bytes()); err != nil {
return err
}
}
return nil
}

View File

@ -24,11 +24,11 @@ func TestUDPEncoding(t *testing.T) {
content := []byte{'a'}
payload := buf.New()
payload.Append(content)
assert.Error(writer.Write(payload)).IsNil()
assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil()
reader := NewUDPReader(b)
decodedPayload, err := reader.Read()
assert.Error(err).IsNil()
assert.Bytes(decodedPayload.Bytes()).Equals(content)
assert.Bytes(decodedPayload[0].Bytes()).Equals(content)
}

View File

@ -159,38 +159,41 @@ func (v *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
reader := buf.NewReader(conn)
for {
payload, err := reader.Read()
mpayload, err := reader.Read()
if err != nil {
return err
}
request, data, err := DecodeUDPPacket(payload.Bytes())
if err != nil {
log.Trace(newError("failed to parse UDP request").Base(err))
continue
for _, payload := range mpayload {
request, data, err := DecodeUDPPacket(payload.Bytes())
if err != nil {
log.Trace(newError("failed to parse UDP request").Base(err))
continue
}
if len(data) == 0 {
continue
}
log.Trace(newError("send packet to ", request.Destination(), " with ", len(data), " bytes").AtDebug())
if source, ok := proxy.SourceFromContext(ctx); ok {
log.Access(source, request.Destination, log.AccessAccepted, "")
}
dataBuf := buf.NewSmall()
dataBuf.Append(data)
udpServer.Dispatch(ctx, request.Destination(), dataBuf, func(payload *buf.Buffer) {
defer payload.Release()
log.Trace(newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug())
udpMessage := EncodeUDPPacket(request, payload.Bytes())
defer udpMessage.Release()
conn.Write(udpMessage.Bytes())
})
}
if len(data) == 0 {
continue
}
log.Trace(newError("send packet to ", request.Destination(), " with ", len(data), " bytes").AtDebug())
if source, ok := proxy.SourceFromContext(ctx); ok {
log.Access(source, request.Destination, log.AccessAccepted, "")
}
dataBuf := buf.NewSmall()
dataBuf.Append(data)
udpServer.Dispatch(ctx, request.Destination(), dataBuf, func(payload *buf.Buffer) {
defer payload.Release()
log.Trace(newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug())
udpMessage := EncodeUDPPacket(request, payload.Bytes())
defer udpMessage.Release()
conn.Write(udpMessage.Bytes())
})
}
}

View File

@ -166,7 +166,7 @@ func transferResponse(timer signal.ActivityTimer, session *encoding.ServerSessio
}
if request.Option.Has(protocol.RequestOptionChunkStream) {
if err := bodyWriter.Write(buf.NewLocal(8)); err != nil {
if err := bodyWriter.Write(buf.NewMultiBuffer()); err != nil {
return err
}
}

View File

@ -133,7 +133,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
}
if request.Option.Has(protocol.RequestOptionChunkStream) {
if err := bodyWriter.Write(buf.NewLocal(8)); err != nil {
if err := bodyWriter.Write(buf.NewMultiBuffer()); err != nil {
return err
}
}

View File

@ -57,7 +57,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination v2net.Destination
inboundRay, existing := v.getInboundRay(ctx, destination)
outputStream := inboundRay.InboundInput()
if outputStream != nil {
if err := outputStream.Write(payload); err != nil {
if err := outputStream.Write(buf.NewMultiBufferValue(payload)); err != nil {
v.RemoveRay(destination)
}
}
@ -71,10 +71,12 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination v2net.Destination
func handleInput(input ray.InputStream, callback ResponseCallback) {
for {
data, err := input.Read()
mb, err := input.Read()
if err != nil {
break
}
callback(data)
for _, b := range mb {
callback(b)
}
}
}

View File

@ -42,7 +42,7 @@ func (v *directRay) InboundOutput() InputStream {
}
type Stream struct {
buffer chan *buf.Buffer
buffer chan buf.MultiBuffer
ctx context.Context
close chan bool
err chan bool
@ -51,13 +51,13 @@ type Stream struct {
func NewStream(ctx context.Context) *Stream {
return &Stream{
ctx: ctx,
buffer: make(chan *buf.Buffer, bufferSize),
buffer: make(chan buf.MultiBuffer, bufferSize),
close: make(chan bool),
err: make(chan bool),
}
}
func (v *Stream) Read() (*buf.Buffer, error) {
func (v *Stream) Read() (buf.MultiBuffer, error) {
select {
case <-v.ctx.Done():
return nil, io.ErrClosedPipe
@ -79,7 +79,7 @@ func (v *Stream) Read() (*buf.Buffer, error) {
}
}
func (v *Stream) ReadTimeout(timeout time.Duration) (*buf.Buffer, error) {
func (v *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
select {
case <-v.ctx.Done():
return nil, io.ErrClosedPipe
@ -107,7 +107,7 @@ func (v *Stream) ReadTimeout(timeout time.Duration) (*buf.Buffer, error) {
}
}
func (v *Stream) Write(data *buf.Buffer) (err error) {
func (v *Stream) Write(data buf.MultiBuffer) (err error) {
if data.IsEmpty() {
return
}

View File

@ -16,7 +16,7 @@ func TestStreamIO(t *testing.T) {
stream := NewStream(context.Background())
b1 := buf.New()
b1.AppendBytes('a')
assert.Error(stream.Write(b1)).IsNil()
assert.Error(stream.Write(buf.NewMultiBufferValue(b1))).IsNil()
_, err := stream.Read()
assert.Error(err).IsNil()
@ -27,7 +27,7 @@ func TestStreamIO(t *testing.T) {
b2 := buf.New()
b2.AppendBytes('b')
err = stream.Write(b2)
err = stream.Write(buf.NewMultiBufferValue(b2))
assert.Error(err).Equals(io.ErrClosedPipe)
}
@ -37,7 +37,7 @@ func TestStreamClose(t *testing.T) {
stream := NewStream(context.Background())
b1 := buf.New()
b1.AppendBytes('a')
assert.Error(stream.Write(b1)).IsNil()
assert.Error(stream.Write(buf.NewMultiBufferValue(b1))).IsNil()
stream.Close()