mirror of
https://github.com/v2fly/v2ray-core.git
synced 2024-12-21 09:36:34 -05:00
multi buffer
This commit is contained in:
parent
0ef9143ffd
commit
f506a39d32
@ -180,31 +180,20 @@ func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool
|
||||
}
|
||||
|
||||
func drain(reader *Reader) error {
|
||||
for {
|
||||
data, more, err := reader.Read()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data.Release()
|
||||
if !more {
|
||||
return nil
|
||||
}
|
||||
data, err := reader.Read()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data.Release()
|
||||
return nil
|
||||
}
|
||||
|
||||
func pipe(reader *Reader, writer buf.Writer) error {
|
||||
for {
|
||||
data, more, err := reader.Read()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writer.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
if !more {
|
||||
return nil
|
||||
}
|
||||
data, err := reader.Read()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writer.Write(data)
|
||||
}
|
||||
|
||||
func (m *Client) handleStatueKeepAlive(meta *FrameMetadata, reader *Reader) error {
|
||||
|
@ -20,7 +20,7 @@ func TestReaderWriter(t *testing.T) {
|
||||
|
||||
payload := buf.New()
|
||||
payload.AppendBytes('a', 'b', 'c', 'd')
|
||||
assert.Error(writer.Write(payload)).IsNil()
|
||||
assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil()
|
||||
|
||||
writer.Close()
|
||||
|
||||
@ -32,10 +32,9 @@ func TestReaderWriter(t *testing.T) {
|
||||
assert.Destination(meta.Target).Equals(dest)
|
||||
assert.Byte(byte(meta.Option)).Equals(byte(OptionData))
|
||||
|
||||
data, more, err := reader.Read()
|
||||
data, err := reader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Bool(more).IsFalse()
|
||||
assert.String(data.String()).Equals("abcd")
|
||||
assert.String(data[0].String()).Equals("abcd")
|
||||
|
||||
meta, err = reader.ReadMetadata()
|
||||
assert.Error(err).IsNil()
|
||||
|
@ -8,9 +8,8 @@ import (
|
||||
)
|
||||
|
||||
type Reader struct {
|
||||
reader io.Reader
|
||||
remainingLength int
|
||||
buffer *buf.Buffer
|
||||
reader io.Reader
|
||||
buffer *buf.Buffer
|
||||
}
|
||||
|
||||
func NewReader(reader buf.Reader) *Reader {
|
||||
@ -38,28 +37,27 @@ func (r *Reader) ReadMetadata() (*FrameMetadata, error) {
|
||||
return ReadFrameFrom(b.Bytes())
|
||||
}
|
||||
|
||||
func (r *Reader) Read() (*buf.Buffer, bool, error) {
|
||||
b := buf.New()
|
||||
var dataLen int
|
||||
if r.remainingLength > 0 {
|
||||
dataLen = r.remainingLength
|
||||
r.remainingLength = 0
|
||||
} else {
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, 2)); err != nil {
|
||||
return nil, false, err
|
||||
func (r *Reader) Read() (buf.MultiBuffer, error) {
|
||||
r.buffer.Clear()
|
||||
if err := r.buffer.AppendSupplier(buf.ReadFullFrom(r.reader, 2)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dataLen := int(serial.BytesToUint16(r.buffer.Bytes()))
|
||||
mb := buf.NewMultiBuffer()
|
||||
for dataLen > 0 {
|
||||
b := buf.New()
|
||||
readLen := buf.Size
|
||||
if dataLen < readLen {
|
||||
readLen = dataLen
|
||||
}
|
||||
dataLen = int(serial.BytesToUint16(b.Bytes()))
|
||||
b.Clear()
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, readLen)); err != nil {
|
||||
mb.Release()
|
||||
return nil, err
|
||||
}
|
||||
dataLen -= readLen
|
||||
mb.Append(b)
|
||||
}
|
||||
|
||||
if dataLen > buf.Size {
|
||||
r.remainingLength = dataLen - buf.Size
|
||||
dataLen = buf.Size
|
||||
}
|
||||
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, dataLen)); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return b, (r.remainingLength > 0), nil
|
||||
return mb, nil
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ func NewResponseWriter(id uint16, writer buf.Writer) *Writer {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Writer) writeInternal(b *buf.Buffer) error {
|
||||
func (w *Writer) Write(mb buf.MultiBuffer) error {
|
||||
meta := FrameMetadata{
|
||||
SessionID: w.id,
|
||||
Target: w.dest,
|
||||
@ -41,42 +41,21 @@ func (w *Writer) writeInternal(b *buf.Buffer) error {
|
||||
meta.SessionStatus = SessionStatusNew
|
||||
}
|
||||
|
||||
if b.Len() > 0 {
|
||||
if mb.Len() > 0 {
|
||||
meta.Option.Add(OptionData)
|
||||
}
|
||||
|
||||
frame := buf.New()
|
||||
frame.AppendSupplier(meta.AsSupplier())
|
||||
|
||||
if b.Len() > 0 {
|
||||
frame.AppendSupplier(serial.WriteUint16(0))
|
||||
lengthBytes := frame.BytesFrom(-2)
|
||||
mb2 := buf.NewMultiBuffer()
|
||||
mb2.Append(frame)
|
||||
|
||||
nBytes, err := frame.Write(b.Bytes())
|
||||
if err != nil {
|
||||
frame.Release()
|
||||
return err
|
||||
}
|
||||
|
||||
serial.Uint16ToBytes(uint16(nBytes), lengthBytes[:0])
|
||||
b.SliceFrom(nBytes)
|
||||
if mb.Len() > 0 {
|
||||
frame.AppendSupplier(serial.WriteUint16(uint16(mb.Len())))
|
||||
mb2.AppendMulti(mb)
|
||||
}
|
||||
|
||||
return w.writer.Write(frame)
|
||||
}
|
||||
|
||||
func (w *Writer) Write(b *buf.Buffer) error {
|
||||
defer b.Release()
|
||||
|
||||
if err := w.writeInternal(b); err != nil {
|
||||
return err
|
||||
}
|
||||
for !b.IsEmpty() {
|
||||
if err := w.writeInternal(b); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return w.writer.Write(mb2)
|
||||
}
|
||||
|
||||
func (w *Writer) Close() {
|
||||
@ -88,5 +67,8 @@ func (w *Writer) Close() {
|
||||
frame := buf.New()
|
||||
frame.AppendSupplier(meta.AsSupplier())
|
||||
|
||||
w.writer.Write(frame)
|
||||
mb := buf.NewMultiBuffer()
|
||||
mb.Append(frame)
|
||||
|
||||
w.writer.Write(mb)
|
||||
}
|
||||
|
@ -129,7 +129,7 @@ type Connection struct {
|
||||
remoteAddr net.Addr
|
||||
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
writer buf.Writer
|
||||
}
|
||||
|
||||
func NewConnection(stream ray.Ray) *Connection {
|
||||
@ -144,7 +144,7 @@ func NewConnection(stream ray.Ray) *Connection {
|
||||
Port: 0,
|
||||
},
|
||||
reader: buf.ToBytesReader(stream.InboundOutput()),
|
||||
writer: buf.ToBytesWriter(stream.InboundInput()),
|
||||
writer: stream.InboundInput(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -161,7 +161,14 @@ func (v *Connection) Write(b []byte) (int, error) {
|
||||
if v.closed {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return v.writer.Write(b)
|
||||
return buf.ToBytesWriter(v.writer).Write(b)
|
||||
}
|
||||
|
||||
func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) (int, error) {
|
||||
if v.closed {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return mb.Len(), v.writer.Write(mb)
|
||||
}
|
||||
|
||||
// Close implements net.Conn.Close().
|
||||
|
@ -11,19 +11,19 @@ import (
|
||||
// Reader extends io.Reader with alloc.Buffer.
|
||||
type Reader interface {
|
||||
// Read reads content from underlying reader, and put it into an alloc.Buffer.
|
||||
Read() (*Buffer, error)
|
||||
Read() (MultiBuffer, error)
|
||||
}
|
||||
|
||||
var ErrReadTimeout = newError("IO timeout")
|
||||
|
||||
type TimeoutReader interface {
|
||||
ReadTimeout(time.Duration) (*Buffer, error)
|
||||
ReadTimeout(time.Duration) (MultiBuffer, error)
|
||||
}
|
||||
|
||||
// Writer extends io.Writer with alloc.Buffer.
|
||||
type Writer interface {
|
||||
// Write writes an alloc.Buffer into underlying writer.
|
||||
Write(*Buffer) error
|
||||
Write(MultiBuffer) error
|
||||
}
|
||||
|
||||
// ReadFrom creates a Supplier to read from a given io.Reader.
|
||||
@ -78,6 +78,7 @@ func PipeUntilEOF(timer signal.ActivityTimer, reader Reader, writer Writer) erro
|
||||
func NewReader(reader io.Reader) Reader {
|
||||
return &BytesToBufferReader{
|
||||
reader: reader,
|
||||
buffer: NewLocal(32 * 1024),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3,7 +3,6 @@ package buf
|
||||
type MergingReader struct {
|
||||
reader Reader
|
||||
timeoutReader TimeoutReader
|
||||
leftover *Buffer
|
||||
}
|
||||
|
||||
func NewMergingReader(reader Reader) Reader {
|
||||
@ -13,41 +12,23 @@ func NewMergingReader(reader Reader) Reader {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MergingReader) Read() (*Buffer, error) {
|
||||
if r.leftover != nil {
|
||||
b := r.leftover
|
||||
r.leftover = nil
|
||||
return b, nil
|
||||
}
|
||||
|
||||
b, err := r.reader.Read()
|
||||
func (r *MergingReader) Read() (MultiBuffer, error) {
|
||||
mb, err := r.reader.Read()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if b.IsFull() {
|
||||
return b, nil
|
||||
}
|
||||
|
||||
if r.timeoutReader == nil {
|
||||
return b, nil
|
||||
return mb, nil
|
||||
}
|
||||
|
||||
for {
|
||||
b2, err := r.timeoutReader.ReadTimeout(0)
|
||||
mb2, err := r.timeoutReader.ReadTimeout(0)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
nBytes := b.Append(b2.Bytes())
|
||||
b2.SliceFrom(nBytes)
|
||||
if b2.IsEmpty() {
|
||||
b2.Release()
|
||||
} else {
|
||||
r.leftover = b2
|
||||
break
|
||||
}
|
||||
mb.AppendMulti(mb2)
|
||||
}
|
||||
|
||||
return b, nil
|
||||
return mb, nil
|
||||
}
|
||||
|
@ -16,18 +16,18 @@ func TestMergingReader(t *testing.T) {
|
||||
stream := ray.NewStream(context.Background())
|
||||
b1 := New()
|
||||
b1.AppendBytes('a', 'b', 'c')
|
||||
stream.Write(b1)
|
||||
stream.Write(NewMultiBufferValue(b1))
|
||||
|
||||
b2 := New()
|
||||
b2.AppendBytes('e', 'f', 'g')
|
||||
stream.Write(b2)
|
||||
stream.Write(NewMultiBufferValue(b2))
|
||||
|
||||
b3 := New()
|
||||
b3.AppendBytes('h', 'i', 'j')
|
||||
stream.Write(b3)
|
||||
stream.Write(NewMultiBufferValue(b3))
|
||||
|
||||
reader := NewMergingReader(stream)
|
||||
b, err := reader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.String(b.String()).Equals("abcefghij")
|
||||
assert.Int(b.Len()).Equals(9)
|
||||
}
|
||||
|
88
common/buf/multi_buffer.go
Normal file
88
common/buf/multi_buffer.go
Normal file
@ -0,0 +1,88 @@
|
||||
package buf
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type MultiBufferWriter interface {
|
||||
WriteMultiBuffer(MultiBuffer) (int, error)
|
||||
}
|
||||
|
||||
type MultiBuffer []*Buffer
|
||||
|
||||
func NewMultiBuffer() MultiBuffer {
|
||||
return MultiBuffer(make([]*Buffer, 0, 8))
|
||||
}
|
||||
|
||||
func NewMultiBufferValue(b ...*Buffer) MultiBuffer {
|
||||
return MultiBuffer(b)
|
||||
}
|
||||
|
||||
func (b *MultiBuffer) Append(buf *Buffer) {
|
||||
*b = append(*b, buf)
|
||||
}
|
||||
|
||||
func (b *MultiBuffer) AppendMulti(mb MultiBuffer) {
|
||||
*b = append(*b, mb...)
|
||||
}
|
||||
|
||||
func (mb *MultiBuffer) Read(b []byte) (int, error) {
|
||||
if len(*mb) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
endIndex := len(*mb)
|
||||
totalBytes := 0
|
||||
for i, bb := range *mb {
|
||||
nBytes, err := bb.Read(b)
|
||||
totalBytes += nBytes
|
||||
if err != nil {
|
||||
return totalBytes, err
|
||||
}
|
||||
b = b[nBytes:]
|
||||
if bb.IsEmpty() {
|
||||
bb.Release()
|
||||
} else {
|
||||
endIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
*mb = (*mb)[endIndex:]
|
||||
return totalBytes, nil
|
||||
}
|
||||
|
||||
func (mb MultiBuffer) WriteTo(writer io.Writer) (int, error) {
|
||||
if mw, ok := writer.(MultiBufferWriter); ok {
|
||||
return mw.WriteMultiBuffer(mb)
|
||||
}
|
||||
bs := make([][]byte, len(mb))
|
||||
for i, b := range mb {
|
||||
bs[i] = b.Bytes()
|
||||
}
|
||||
nbs := net.Buffers(bs)
|
||||
nBytes, err := nbs.WriteTo(writer)
|
||||
return int(nBytes), err
|
||||
}
|
||||
|
||||
func (mb MultiBuffer) Len() int {
|
||||
size := 0
|
||||
for _, b := range mb {
|
||||
size += b.Len()
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func (mb MultiBuffer) IsEmpty() bool {
|
||||
for _, b := range mb {
|
||||
if !b.IsEmpty() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (mb MultiBuffer) Release() {
|
||||
for _, b := range mb {
|
||||
b.Release()
|
||||
}
|
||||
}
|
25
common/buf/multi_buffer_test.go
Normal file
25
common/buf/multi_buffer_test.go
Normal file
@ -0,0 +1,25 @@
|
||||
package buf_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/testing/assert"
|
||||
)
|
||||
|
||||
func TestMultiBufferRead(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
|
||||
b1 := New()
|
||||
b1.AppendBytes('a', 'b')
|
||||
|
||||
b2 := New()
|
||||
b2.AppendBytes('c', 'd')
|
||||
mb := NewMultiBufferValue(b1, b2)
|
||||
|
||||
bs := make([]byte, 32)
|
||||
nBytes, err := mb.Read(bs)
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int(nBytes).Equals(4)
|
||||
assert.Bytes(bs[:nBytes]).Equals([]byte("abcd"))
|
||||
}
|
@ -4,48 +4,28 @@ import "io"
|
||||
|
||||
// BytesToBufferReader is a Reader that adjusts its reading speed automatically.
|
||||
type BytesToBufferReader struct {
|
||||
reader io.Reader
|
||||
largeBuffer *Buffer
|
||||
highVolumn bool
|
||||
reader io.Reader
|
||||
buffer *Buffer
|
||||
}
|
||||
|
||||
// Read implements Reader.Read().
|
||||
func (v *BytesToBufferReader) Read() (*Buffer, error) {
|
||||
if v.highVolumn && v.largeBuffer.IsEmpty() {
|
||||
if v.largeBuffer == nil {
|
||||
v.largeBuffer = NewLocal(32 * 1024)
|
||||
}
|
||||
err := v.largeBuffer.AppendSupplier(ReadFrom(v.reader))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v.largeBuffer.Len() < Size {
|
||||
v.highVolumn = false
|
||||
}
|
||||
}
|
||||
|
||||
buffer := New()
|
||||
if !v.largeBuffer.IsEmpty() {
|
||||
err := buffer.AppendSupplier(ReadFrom(v.largeBuffer))
|
||||
return buffer, err
|
||||
}
|
||||
|
||||
err := buffer.AppendSupplier(ReadFrom(v.reader))
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
func (v *BytesToBufferReader) Read() (MultiBuffer, error) {
|
||||
if err := v.buffer.Reset(ReadFrom(v.reader)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if buffer.IsFull() {
|
||||
v.highVolumn = true
|
||||
mb := NewMultiBuffer()
|
||||
for !v.buffer.IsEmpty() {
|
||||
b := New()
|
||||
b.AppendSupplier(ReadFrom(v.buffer))
|
||||
mb.Append(b)
|
||||
}
|
||||
|
||||
return buffer, nil
|
||||
return mb, nil
|
||||
}
|
||||
|
||||
type bufferToBytesReader struct {
|
||||
stream Reader
|
||||
current *Buffer
|
||||
current MultiBuffer
|
||||
err error
|
||||
}
|
||||
|
||||
|
@ -15,14 +15,7 @@ func TestAdaptiveReader(t *testing.T) {
|
||||
buffer := bytes.NewBuffer(rawContent)
|
||||
|
||||
reader := NewReader(buffer)
|
||||
b1, err := reader.Read()
|
||||
b, err := reader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Bool(b1.IsFull()).IsTrue()
|
||||
assert.Int(b1.Len()).Equals(Size)
|
||||
assert.Int(buffer.Len()).Equals(cap(rawContent) - Size)
|
||||
|
||||
b2, err := reader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Bool(b2.IsFull()).IsTrue()
|
||||
assert.Int(buffer.Len()).Equals(1007616)
|
||||
assert.Int(b.Len()).Equals(32 * 1024)
|
||||
}
|
||||
|
@ -8,39 +8,30 @@ type BufferToBytesWriter struct {
|
||||
}
|
||||
|
||||
// Write implements Writer.Write(). Write() takes ownership of the given buffer.
|
||||
func (v *BufferToBytesWriter) Write(buffer *Buffer) error {
|
||||
defer buffer.Release()
|
||||
for {
|
||||
nBytes, err := v.writer.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nBytes == buffer.Len() {
|
||||
break
|
||||
}
|
||||
buffer.SliceFrom(nBytes)
|
||||
}
|
||||
return nil
|
||||
func (v *BufferToBytesWriter) Write(buffer MultiBuffer) error {
|
||||
_, err := buffer.WriteTo(v.writer)
|
||||
//buffer.Release()
|
||||
return err
|
||||
}
|
||||
|
||||
type bytesToBufferWriter struct {
|
||||
writer Writer
|
||||
}
|
||||
|
||||
func (v *bytesToBufferWriter) Write(payload []byte) (int, error) {
|
||||
bytesWritten := 0
|
||||
size := len(payload)
|
||||
for size > 0 {
|
||||
buffer := New()
|
||||
nBytes, _ := buffer.Write(payload)
|
||||
size -= nBytes
|
||||
payload = payload[nBytes:]
|
||||
bytesWritten += nBytes
|
||||
err := v.writer.Write(buffer)
|
||||
if err != nil {
|
||||
return bytesWritten, err
|
||||
}
|
||||
func (w *bytesToBufferWriter) Write(payload []byte) (int, error) {
|
||||
mb := NewMultiBuffer()
|
||||
for p := payload; len(p) > 0; {
|
||||
b := New()
|
||||
nBytes, _ := b.Write(p)
|
||||
p = p[nBytes:]
|
||||
mb.Append(b)
|
||||
}
|
||||
|
||||
return bytesWritten, nil
|
||||
if err := w.writer.Write(mb); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(payload), nil
|
||||
}
|
||||
|
||||
func (w *bytesToBufferWriter) WriteMulteBuffer(mb MultiBuffer) (int, error) {
|
||||
return mb.Len(), w.writer.Write(mb)
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ func TestWriter(t *testing.T) {
|
||||
writeBuffer := bytes.NewBuffer(make([]byte, 0, 1024*1024))
|
||||
|
||||
writer := NewWriter(NewBufferedWriter(writeBuffer))
|
||||
err := writer.Write(lb)
|
||||
err := writer.Write(NewMultiBufferValue(lb))
|
||||
assert.Error(err).IsNil()
|
||||
assert.Bytes(expectedBytes).Equals(writeBuffer.Bytes())
|
||||
}
|
||||
|
@ -215,14 +215,34 @@ func NewAuthenticationWriter(auth Authenticator, writer io.Writer, sizeMask Uint
|
||||
}
|
||||
}
|
||||
|
||||
func (v *AuthenticationWriter) Write(b []byte) (int, error) {
|
||||
cipherChunk, err := v.auth.Seal(v.buffer[2:2], b)
|
||||
func (w *AuthenticationWriter) Write(b []byte) (int, error) {
|
||||
cipherChunk, err := w.auth.Seal(w.buffer[2:2], b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
size := uint16(len(cipherChunk)) ^ v.sizeMask.Next()
|
||||
serial.Uint16ToBytes(size, v.buffer[:0])
|
||||
_, err = v.writer.Write(v.buffer[:2+len(cipherChunk)])
|
||||
size := uint16(len(cipherChunk)) ^ w.sizeMask.Next()
|
||||
serial.Uint16ToBytes(size, w.buffer[:0])
|
||||
_, err = w.writer.Write(w.buffer[:2+len(cipherChunk)])
|
||||
return len(b), err
|
||||
}
|
||||
|
||||
func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) (int, error) {
|
||||
const StartIndex = 17 * 1024
|
||||
var totalBytes int
|
||||
for {
|
||||
payloadLen, err := mb.Read(w.buffer[StartIndex:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
nBytes, err := w.Write(w.buffer[StartIndex : StartIndex+payloadLen])
|
||||
totalBytes += nBytes
|
||||
if err != nil {
|
||||
return totalBytes, err
|
||||
}
|
||||
if mb.IsEmpty() {
|
||||
break
|
||||
}
|
||||
}
|
||||
return totalBytes, nil
|
||||
}
|
||||
|
@ -28,7 +28,9 @@ func (v *NoneResponse) WriteTo(buf.Writer) {}
|
||||
func (v *HTTPResponse) WriteTo(writer buf.Writer) {
|
||||
b := buf.NewLocal(512)
|
||||
b.AppendSupplier(serial.WriteString(http403response))
|
||||
writer.Write(b)
|
||||
mb := buf.NewMultiBuffer()
|
||||
mb.Append(b)
|
||||
writer.Write(mb)
|
||||
}
|
||||
|
||||
// GetInternalResponse converts response settings from proto to internal data structure.
|
||||
|
@ -68,7 +68,7 @@ func NewChunkReader(reader io.Reader, auth *Authenticator) *ChunkReader {
|
||||
}
|
||||
}
|
||||
|
||||
func (v *ChunkReader) Read() (*buf.Buffer, error) {
|
||||
func (v *ChunkReader) Read() (buf.MultiBuffer, error) {
|
||||
buffer := buf.New()
|
||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(v.reader, 2)); err != nil {
|
||||
buffer.Release()
|
||||
@ -100,7 +100,10 @@ func (v *ChunkReader) Read() (*buf.Buffer, error) {
|
||||
}
|
||||
buffer.SliceFrom(AuthSize)
|
||||
|
||||
return buffer, nil
|
||||
mb := buf.NewMultiBuffer()
|
||||
mb.Append(buffer)
|
||||
|
||||
return mb, nil
|
||||
}
|
||||
|
||||
type ChunkWriter struct {
|
||||
@ -117,11 +120,22 @@ func NewChunkWriter(writer io.Writer, auth *Authenticator) *ChunkWriter {
|
||||
}
|
||||
}
|
||||
|
||||
func (v *ChunkWriter) Write(payload *buf.Buffer) error {
|
||||
func (w *ChunkWriter) Write(mb buf.MultiBuffer) error {
|
||||
defer mb.Release()
|
||||
|
||||
for _, b := range mb {
|
||||
if err := w.writeInternal(b); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ChunkWriter) writeInternal(payload *buf.Buffer) error {
|
||||
totalLength := payload.Len()
|
||||
serial.Uint16ToBytes(uint16(totalLength), v.buffer[:0])
|
||||
v.auth.Authenticate(payload.Bytes())(v.buffer[2:])
|
||||
copy(v.buffer[2+AuthSize:], payload.Bytes())
|
||||
_, err := v.writer.Write(v.buffer[:2+AuthSize+payload.Len()])
|
||||
serial.Uint16ToBytes(uint16(totalLength), w.buffer[:0])
|
||||
w.auth.Authenticate(payload.Bytes())(w.buffer[2:])
|
||||
copy(w.buffer[2+AuthSize:], payload.Bytes())
|
||||
_, err := w.writer.Write(w.buffer[:2+AuthSize+payload.Len()])
|
||||
return err
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ func TestNormalChunkReading(t *testing.T) {
|
||||
[]byte{21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36})))
|
||||
payload, err := reader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Bytes(payload.Bytes()).Equals([]byte{11, 12, 13, 14, 15, 16, 17, 18})
|
||||
assert.Bytes(payload[0].Bytes()).Equals([]byte{11, 12, 13, 14, 15, 16, 17, 18})
|
||||
}
|
||||
|
||||
func TestNormalChunkWriting(t *testing.T) {
|
||||
@ -30,7 +30,7 @@ func TestNormalChunkWriting(t *testing.T) {
|
||||
|
||||
b := buf.NewLocal(256)
|
||||
b.Append([]byte{11, 12, 13, 14, 15, 16, 17, 18})
|
||||
err := writer.Write(b)
|
||||
err := writer.Write(buf.NewMultiBufferValue(b))
|
||||
assert.Error(err).IsNil()
|
||||
assert.Bytes(buffer.Bytes()).Equals([]byte{0, 8, 39, 228, 69, 96, 133, 39, 254, 26, 201, 70, 11, 12, 13, 14, 15, 16, 17, 18})
|
||||
}
|
||||
|
@ -362,7 +362,7 @@ type UDPReader struct {
|
||||
User *protocol.User
|
||||
}
|
||||
|
||||
func (v *UDPReader) Read() (*buf.Buffer, error) {
|
||||
func (v *UDPReader) Read() (buf.MultiBuffer, error) {
|
||||
buffer := buf.NewSmall()
|
||||
err := buffer.AppendSupplier(buf.ReadFrom(v.Reader))
|
||||
if err != nil {
|
||||
@ -374,7 +374,9 @@ func (v *UDPReader) Read() (*buf.Buffer, error) {
|
||||
buffer.Release()
|
||||
return nil, err
|
||||
}
|
||||
return payload, nil
|
||||
mb := buf.NewMultiBuffer()
|
||||
mb.Append(payload)
|
||||
return mb, nil
|
||||
}
|
||||
|
||||
type UDPWriter struct {
|
||||
@ -382,12 +384,21 @@ type UDPWriter struct {
|
||||
Request *protocol.RequestHeader
|
||||
}
|
||||
|
||||
func (v *UDPWriter) Write(buffer *buf.Buffer) error {
|
||||
payload, err := EncodeUDPPacket(v.Request, buffer)
|
||||
func (w *UDPWriter) Write(mb buf.MultiBuffer) error {
|
||||
for _, b := range mb {
|
||||
if err := w.writeInternal(b); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *UDPWriter) writeInternal(buffer *buf.Buffer) error {
|
||||
payload, err := EncodeUDPPacket(w.Request, buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = v.Writer.Write(payload.Bytes())
|
||||
_, err = w.Writer.Write(payload.Bytes())
|
||||
payload.Release()
|
||||
return err
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ func TestTCPRequest(t *testing.T) {
|
||||
writer, err := WriteTCPRequest(request, cache)
|
||||
assert.Error(err).IsNil()
|
||||
|
||||
writer.Write(data)
|
||||
writer.Write(buf.NewMultiBufferValue(data))
|
||||
|
||||
decodedRequest, reader, err := ReadTCPSession(request.User, cache)
|
||||
assert.Error(err).IsNil()
|
||||
@ -75,7 +75,7 @@ func TestTCPRequest(t *testing.T) {
|
||||
|
||||
decodedData, err := reader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.String(decodedData.String()).Equals("test string")
|
||||
assert.String(decodedData[0].String()).Equals("test string")
|
||||
}
|
||||
|
||||
func TestUDPReaderWriter(t *testing.T) {
|
||||
@ -106,19 +106,19 @@ func TestUDPReaderWriter(t *testing.T) {
|
||||
|
||||
b := buf.New()
|
||||
b.AppendSupplier(serial.WriteString("test payload"))
|
||||
err := writer.Write(b)
|
||||
err := writer.Write(buf.NewMultiBufferValue(b))
|
||||
assert.Error(err).IsNil()
|
||||
|
||||
payload, err := reader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.String(payload.String()).Equals("test payload")
|
||||
assert.String(payload[0].String()).Equals("test payload")
|
||||
|
||||
b = buf.New()
|
||||
b.AppendSupplier(serial.WriteString("test payload 2"))
|
||||
err = writer.Write(b)
|
||||
err = writer.Write(buf.NewMultiBufferValue(b))
|
||||
assert.Error(err).IsNil()
|
||||
|
||||
payload, err = reader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.String(payload.String()).Equals("test payload 2")
|
||||
assert.String(payload[0].String()).Equals("test payload 2")
|
||||
}
|
||||
|
@ -75,52 +75,54 @@ func (v *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
|
||||
|
||||
reader := buf.NewReader(conn)
|
||||
for {
|
||||
payload, err := reader.Read()
|
||||
mpayload, err := reader.Read()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
request, data, err := DecodeUDPPacket(v.user, payload)
|
||||
if err != nil {
|
||||
if source, ok := proxy.SourceFromContext(ctx); ok {
|
||||
log.Trace(newError("dropping invalid UDP packet from: ", source).Base(err))
|
||||
log.Access(source, "", log.AccessRejected, err)
|
||||
}
|
||||
payload.Release()
|
||||
continue
|
||||
}
|
||||
|
||||
if request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Disabled {
|
||||
log.Trace(newError("client payload enables OTA but server doesn't allow it"))
|
||||
payload.Release()
|
||||
continue
|
||||
}
|
||||
|
||||
if !request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Enabled {
|
||||
log.Trace(newError("client payload disables OTA but server forces it"))
|
||||
payload.Release()
|
||||
continue
|
||||
}
|
||||
|
||||
dest := request.Destination()
|
||||
if source, ok := proxy.SourceFromContext(ctx); ok {
|
||||
log.Access(source, dest, log.AccessAccepted, "")
|
||||
}
|
||||
log.Trace(newError("tunnelling request to ", dest))
|
||||
|
||||
ctx = protocol.ContextWithUser(ctx, request.User)
|
||||
udpServer.Dispatch(ctx, dest, data, func(payload *buf.Buffer) {
|
||||
defer payload.Release()
|
||||
|
||||
data, err := EncodeUDPPacket(request, payload)
|
||||
for _, payload := range mpayload {
|
||||
request, data, err := DecodeUDPPacket(v.user, payload)
|
||||
if err != nil {
|
||||
log.Trace(newError("failed to encode UDP packet").Base(err).AtWarning())
|
||||
return
|
||||
if source, ok := proxy.SourceFromContext(ctx); ok {
|
||||
log.Trace(newError("dropping invalid UDP packet from: ", source).Base(err))
|
||||
log.Access(source, "", log.AccessRejected, err)
|
||||
}
|
||||
payload.Release()
|
||||
continue
|
||||
}
|
||||
defer data.Release()
|
||||
|
||||
conn.Write(data.Bytes())
|
||||
})
|
||||
if request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Disabled {
|
||||
log.Trace(newError("client payload enables OTA but server doesn't allow it"))
|
||||
payload.Release()
|
||||
continue
|
||||
}
|
||||
|
||||
if !request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Enabled {
|
||||
log.Trace(newError("client payload disables OTA but server forces it"))
|
||||
payload.Release()
|
||||
continue
|
||||
}
|
||||
|
||||
dest := request.Destination()
|
||||
if source, ok := proxy.SourceFromContext(ctx); ok {
|
||||
log.Access(source, dest, log.AccessAccepted, "")
|
||||
}
|
||||
log.Trace(newError("tunnelling request to ", dest))
|
||||
|
||||
ctx = protocol.ContextWithUser(ctx, request.User)
|
||||
udpServer.Dispatch(ctx, dest, data, func(payload *buf.Buffer) {
|
||||
defer payload.Release()
|
||||
|
||||
data, err := EncodeUDPPacket(request, payload)
|
||||
if err != nil {
|
||||
log.Trace(newError("failed to encode UDP packet").Base(err).AtWarning())
|
||||
return
|
||||
}
|
||||
defer data.Release()
|
||||
|
||||
conn.Write(data.Bytes())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -347,7 +347,7 @@ func NewUDPReader(reader io.Reader) *UDPReader {
|
||||
return &UDPReader{reader: reader}
|
||||
}
|
||||
|
||||
func (r *UDPReader) Read() (*buf.Buffer, error) {
|
||||
func (r *UDPReader) Read() (buf.MultiBuffer, error) {
|
||||
b := buf.NewSmall()
|
||||
if err := b.AppendSupplier(buf.ReadFrom(r.reader)); err != nil {
|
||||
return nil, err
|
||||
@ -358,7 +358,9 @@ func (r *UDPReader) Read() (*buf.Buffer, error) {
|
||||
}
|
||||
b.Clear()
|
||||
b.Append(data)
|
||||
return b, nil
|
||||
mb := buf.NewMultiBuffer()
|
||||
mb.Append(b)
|
||||
return mb, nil
|
||||
}
|
||||
|
||||
type UDPWriter struct {
|
||||
@ -373,12 +375,15 @@ func NewUDPWriter(request *protocol.RequestHeader, writer io.Writer) *UDPWriter
|
||||
}
|
||||
}
|
||||
|
||||
func (w *UDPWriter) Write(b *buf.Buffer) error {
|
||||
eb := EncodeUDPPacket(w.request, b.Bytes())
|
||||
b.Release()
|
||||
defer eb.Release()
|
||||
if _, err := w.writer.Write(eb.Bytes()); err != nil {
|
||||
return err
|
||||
func (w *UDPWriter) Write(mb buf.MultiBuffer) error {
|
||||
defer mb.Release()
|
||||
|
||||
for _, b := range mb {
|
||||
eb := EncodeUDPPacket(w.request, b.Bytes())
|
||||
defer eb.Release()
|
||||
if _, err := w.writer.Write(eb.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -24,11 +24,11 @@ func TestUDPEncoding(t *testing.T) {
|
||||
content := []byte{'a'}
|
||||
payload := buf.New()
|
||||
payload.Append(content)
|
||||
assert.Error(writer.Write(payload)).IsNil()
|
||||
assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil()
|
||||
|
||||
reader := NewUDPReader(b)
|
||||
|
||||
decodedPayload, err := reader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Bytes(decodedPayload.Bytes()).Equals(content)
|
||||
assert.Bytes(decodedPayload[0].Bytes()).Equals(content)
|
||||
}
|
||||
|
@ -159,38 +159,41 @@ func (v *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
|
||||
|
||||
reader := buf.NewReader(conn)
|
||||
for {
|
||||
payload, err := reader.Read()
|
||||
mpayload, err := reader.Read()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
request, data, err := DecodeUDPPacket(payload.Bytes())
|
||||
|
||||
if err != nil {
|
||||
log.Trace(newError("failed to parse UDP request").Base(err))
|
||||
continue
|
||||
for _, payload := range mpayload {
|
||||
request, data, err := DecodeUDPPacket(payload.Bytes())
|
||||
|
||||
if err != nil {
|
||||
log.Trace(newError("failed to parse UDP request").Base(err))
|
||||
continue
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Trace(newError("send packet to ", request.Destination(), " with ", len(data), " bytes").AtDebug())
|
||||
if source, ok := proxy.SourceFromContext(ctx); ok {
|
||||
log.Access(source, request.Destination, log.AccessAccepted, "")
|
||||
}
|
||||
|
||||
dataBuf := buf.NewSmall()
|
||||
dataBuf.Append(data)
|
||||
udpServer.Dispatch(ctx, request.Destination(), dataBuf, func(payload *buf.Buffer) {
|
||||
defer payload.Release()
|
||||
|
||||
log.Trace(newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug())
|
||||
|
||||
udpMessage := EncodeUDPPacket(request, payload.Bytes())
|
||||
defer udpMessage.Release()
|
||||
|
||||
conn.Write(udpMessage.Bytes())
|
||||
})
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Trace(newError("send packet to ", request.Destination(), " with ", len(data), " bytes").AtDebug())
|
||||
if source, ok := proxy.SourceFromContext(ctx); ok {
|
||||
log.Access(source, request.Destination, log.AccessAccepted, "")
|
||||
}
|
||||
|
||||
dataBuf := buf.NewSmall()
|
||||
dataBuf.Append(data)
|
||||
udpServer.Dispatch(ctx, request.Destination(), dataBuf, func(payload *buf.Buffer) {
|
||||
defer payload.Release()
|
||||
|
||||
log.Trace(newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug())
|
||||
|
||||
udpMessage := EncodeUDPPacket(request, payload.Bytes())
|
||||
defer udpMessage.Release()
|
||||
|
||||
conn.Write(udpMessage.Bytes())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -166,7 +166,7 @@ func transferResponse(timer signal.ActivityTimer, session *encoding.ServerSessio
|
||||
}
|
||||
|
||||
if request.Option.Has(protocol.RequestOptionChunkStream) {
|
||||
if err := bodyWriter.Write(buf.NewLocal(8)); err != nil {
|
||||
if err := bodyWriter.Write(buf.NewMultiBuffer()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -133,7 +133,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
|
||||
}
|
||||
|
||||
if request.Option.Has(protocol.RequestOptionChunkStream) {
|
||||
if err := bodyWriter.Write(buf.NewLocal(8)); err != nil {
|
||||
if err := bodyWriter.Write(buf.NewMultiBuffer()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -57,7 +57,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination v2net.Destination
|
||||
inboundRay, existing := v.getInboundRay(ctx, destination)
|
||||
outputStream := inboundRay.InboundInput()
|
||||
if outputStream != nil {
|
||||
if err := outputStream.Write(payload); err != nil {
|
||||
if err := outputStream.Write(buf.NewMultiBufferValue(payload)); err != nil {
|
||||
v.RemoveRay(destination)
|
||||
}
|
||||
}
|
||||
@ -71,10 +71,12 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination v2net.Destination
|
||||
|
||||
func handleInput(input ray.InputStream, callback ResponseCallback) {
|
||||
for {
|
||||
data, err := input.Read()
|
||||
mb, err := input.Read()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
callback(data)
|
||||
for _, b := range mb {
|
||||
callback(b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ func (v *directRay) InboundOutput() InputStream {
|
||||
}
|
||||
|
||||
type Stream struct {
|
||||
buffer chan *buf.Buffer
|
||||
buffer chan buf.MultiBuffer
|
||||
ctx context.Context
|
||||
close chan bool
|
||||
err chan bool
|
||||
@ -51,13 +51,13 @@ type Stream struct {
|
||||
func NewStream(ctx context.Context) *Stream {
|
||||
return &Stream{
|
||||
ctx: ctx,
|
||||
buffer: make(chan *buf.Buffer, bufferSize),
|
||||
buffer: make(chan buf.MultiBuffer, bufferSize),
|
||||
close: make(chan bool),
|
||||
err: make(chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (v *Stream) Read() (*buf.Buffer, error) {
|
||||
func (v *Stream) Read() (buf.MultiBuffer, error) {
|
||||
select {
|
||||
case <-v.ctx.Done():
|
||||
return nil, io.ErrClosedPipe
|
||||
@ -79,7 +79,7 @@ func (v *Stream) Read() (*buf.Buffer, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (v *Stream) ReadTimeout(timeout time.Duration) (*buf.Buffer, error) {
|
||||
func (v *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
|
||||
select {
|
||||
case <-v.ctx.Done():
|
||||
return nil, io.ErrClosedPipe
|
||||
@ -107,7 +107,7 @@ func (v *Stream) ReadTimeout(timeout time.Duration) (*buf.Buffer, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (v *Stream) Write(data *buf.Buffer) (err error) {
|
||||
func (v *Stream) Write(data buf.MultiBuffer) (err error) {
|
||||
if data.IsEmpty() {
|
||||
return
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ func TestStreamIO(t *testing.T) {
|
||||
stream := NewStream(context.Background())
|
||||
b1 := buf.New()
|
||||
b1.AppendBytes('a')
|
||||
assert.Error(stream.Write(b1)).IsNil()
|
||||
assert.Error(stream.Write(buf.NewMultiBufferValue(b1))).IsNil()
|
||||
|
||||
_, err := stream.Read()
|
||||
assert.Error(err).IsNil()
|
||||
@ -27,7 +27,7 @@ func TestStreamIO(t *testing.T) {
|
||||
|
||||
b2 := buf.New()
|
||||
b2.AppendBytes('b')
|
||||
err = stream.Write(b2)
|
||||
err = stream.Write(buf.NewMultiBufferValue(b2))
|
||||
assert.Error(err).Equals(io.ErrClosedPipe)
|
||||
}
|
||||
|
||||
@ -37,7 +37,7 @@ func TestStreamClose(t *testing.T) {
|
||||
stream := NewStream(context.Background())
|
||||
b1 := buf.New()
|
||||
b1.AppendBytes('a')
|
||||
assert.Error(stream.Write(b1)).IsNil()
|
||||
assert.Error(stream.Write(buf.NewMultiBufferValue(b1))).IsNil()
|
||||
|
||||
stream.Close()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user