diff --git a/common/buf/multi_buffer.go b/common/buf/multi_buffer.go index a3b639fca..71ca773be 100644 --- a/common/buf/multi_buffer.go +++ b/common/buf/multi_buffer.go @@ -122,6 +122,17 @@ func SplitBytes(mb MultiBuffer, b []byte) (MultiBuffer, int) { return mb, totalBytes } +// SplitFirstBytes splits the first buffer from MultiBuffer, and then copy its content into the given slice. +func SplitFirstBytes(mb MultiBuffer, p []byte) (MultiBuffer, int) { + mb, b := SplitFirst(mb) + if b == nil { + return mb, 0 + } + n := copy(p, b.Bytes()) + b.Release() + return mb, n +} + // Compact returns another MultiBuffer by merging all content of the given one together. func Compact(mb MultiBuffer) MultiBuffer { if len(mb) == 0 { diff --git a/common/buf/reader.go b/common/buf/reader.go index ee13355b7..f0c408b1c 100644 --- a/common/buf/reader.go +++ b/common/buf/reader.go @@ -58,6 +58,8 @@ type BufferedReader struct { Reader Reader // Buffer is the internal buffer to be read from first Buffer MultiBuffer + // Spliter is a function to read bytes from MultiBuffer + Spliter func(MultiBuffer, []byte) (MultiBuffer, int) } // BufferedBytes returns the number of bytes that is cached in this reader. @@ -74,8 +76,13 @@ func (r *BufferedReader) ReadByte() (byte, error) { // Read implements io.Reader. It reads from internal buffer first (if available) and then reads from the underlying reader. func (r *BufferedReader) Read(b []byte) (int, error) { + spliter := r.Spliter + if spliter == nil { + spliter = SplitBytes + } + if !r.Buffer.IsEmpty() { - buffer, nBytes := SplitBytes(r.Buffer, b) + buffer, nBytes := spliter(r.Buffer, b) r.Buffer = buffer if r.Buffer.IsEmpty() { r.Buffer = nil @@ -88,7 +95,7 @@ func (r *BufferedReader) Read(b []byte) (int, error) { return 0, err } - mb, nBytes := SplitBytes(mb, b) + mb, nBytes := spliter(mb, b) if !mb.IsEmpty() { r.Buffer = mb } diff --git a/common/net/connection.go b/common/net/connection.go index 41fc55e9d..2a1d73099 100644 --- a/common/net/connection.go +++ b/common/net/connection.go @@ -48,6 +48,15 @@ func ConnectionOutputMulti(reader buf.Reader) ConnectionOption { } } +func ConnectionOutputMultiUDP(reader buf.Reader) ConnectionOption { + return func(c *connection) { + c.reader = &buf.BufferedReader{ + Reader: reader, + Spliter: buf.SplitFirstBytes, + } + } +} + func ConnectionOnClose(n io.Closer) ConnectionOption { return func(c *connection) { c.onClose = n diff --git a/functions.go b/functions.go index f8d7f000b..6394b3d40 100644 --- a/functions.go +++ b/functions.go @@ -53,7 +53,13 @@ func Dial(ctx context.Context, v *Instance, dest net.Destination) (net.Conn, err if err != nil { return nil, err } - return net.NewConnection(net.ConnectionInputMulti(r.Writer), net.ConnectionOutputMulti(r.Reader)), nil + var readerOpt net.ConnectionOption + if dest.Network == net.Network_TCP { + readerOpt = net.ConnectionOutputMulti(r.Reader) + } else { + readerOpt = net.ConnectionOutputMultiUDP(r.Reader) + } + return net.NewConnection(net.ConnectionInputMulti(r.Writer), readerOpt), nil } // DialUDP provides a way to exchange UDP packets through V2Ray instance to remote servers. diff --git a/functions_test.go b/functions_test.go index b5a25ed32..4fd3112e2 100644 --- a/functions_test.go +++ b/functions_test.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "io" "testing" + "time" "github.com/golang/protobuf/proto" "github.com/google/go-cmp/cmp" @@ -86,6 +87,66 @@ func TestV2RayDial(t *testing.T) { } } +func TestV2RayDialUDPConn(t *testing.T) { + udpServer := udp.Server{ + MsgProcessor: xor, + } + dest, err := udpServer.Start() + common.Must(err) + defer udpServer.Close() + + config := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&dispatcher.Config{}), + serial.ToTypedMessage(&proxyman.InboundConfig{}), + serial.ToTypedMessage(&proxyman.OutboundConfig{}), + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&freedom.Config{}), + }, + }, + } + + cfgBytes, err := proto.Marshal(config) + common.Must(err) + + server, err := core.StartInstance("protobuf", cfgBytes) + common.Must(err) + defer server.Close() + + conn, err := core.Dial(context.Background(), server, dest) + common.Must(err) + defer conn.Close() + + const size = 1024 + payload := make([]byte, size) + common.Must2(rand.Read(payload)) + + for i := 0; i < 2; i++ { + if _, err := conn.Write(payload); err != nil { + t.Fatal(err) + } + } + + time.Sleep(time.Millisecond * 500) + + receive := make([]byte, size*2) + for i := 0; i < 2; i++ { + n, err := conn.Read(receive) + if err != nil { + t.Fatal("expect no error, but got ", err) + } + if n != size { + t.Fatal("expect read size ", size, " but got ", n) + } + + if r := cmp.Diff(xor(receive[:n]), payload); r != "" { + t.Fatal(r) + } + } +} + func TestV2RayDialUDP(t *testing.T) { udpServer1 := udp.Server{ MsgProcessor: xor,