fix connection reading in UDP

This commit is contained in:
Darien Raymond 2019-01-06 00:34:38 +01:00
parent b52725cf65
commit 4e77570f36
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
5 changed files with 97 additions and 3 deletions

View File

@ -122,6 +122,17 @@ func SplitBytes(mb MultiBuffer, b []byte) (MultiBuffer, int) {
return mb, totalBytes
}
// SplitFirstBytes splits the first buffer from MultiBuffer, and then copy its content into the given slice.
func SplitFirstBytes(mb MultiBuffer, p []byte) (MultiBuffer, int) {
mb, b := SplitFirst(mb)
if b == nil {
return mb, 0
}
n := copy(p, b.Bytes())
b.Release()
return mb, n
}
// Compact returns another MultiBuffer by merging all content of the given one together.
func Compact(mb MultiBuffer) MultiBuffer {
if len(mb) == 0 {

View File

@ -58,6 +58,8 @@ type BufferedReader struct {
Reader Reader
// Buffer is the internal buffer to be read from first
Buffer MultiBuffer
// Spliter is a function to read bytes from MultiBuffer
Spliter func(MultiBuffer, []byte) (MultiBuffer, int)
}
// BufferedBytes returns the number of bytes that is cached in this reader.
@ -74,8 +76,13 @@ func (r *BufferedReader) ReadByte() (byte, error) {
// Read implements io.Reader. It reads from internal buffer first (if available) and then reads from the underlying reader.
func (r *BufferedReader) Read(b []byte) (int, error) {
spliter := r.Spliter
if spliter == nil {
spliter = SplitBytes
}
if !r.Buffer.IsEmpty() {
buffer, nBytes := SplitBytes(r.Buffer, b)
buffer, nBytes := spliter(r.Buffer, b)
r.Buffer = buffer
if r.Buffer.IsEmpty() {
r.Buffer = nil
@ -88,7 +95,7 @@ func (r *BufferedReader) Read(b []byte) (int, error) {
return 0, err
}
mb, nBytes := SplitBytes(mb, b)
mb, nBytes := spliter(mb, b)
if !mb.IsEmpty() {
r.Buffer = mb
}

View File

@ -48,6 +48,15 @@ func ConnectionOutputMulti(reader buf.Reader) ConnectionOption {
}
}
func ConnectionOutputMultiUDP(reader buf.Reader) ConnectionOption {
return func(c *connection) {
c.reader = &buf.BufferedReader{
Reader: reader,
Spliter: buf.SplitFirstBytes,
}
}
}
func ConnectionOnClose(n io.Closer) ConnectionOption {
return func(c *connection) {
c.onClose = n

View File

@ -53,7 +53,13 @@ func Dial(ctx context.Context, v *Instance, dest net.Destination) (net.Conn, err
if err != nil {
return nil, err
}
return net.NewConnection(net.ConnectionInputMulti(r.Writer), net.ConnectionOutputMulti(r.Reader)), nil
var readerOpt net.ConnectionOption
if dest.Network == net.Network_TCP {
readerOpt = net.ConnectionOutputMulti(r.Reader)
} else {
readerOpt = net.ConnectionOutputMultiUDP(r.Reader)
}
return net.NewConnection(net.ConnectionInputMulti(r.Writer), readerOpt), nil
}
// DialUDP provides a way to exchange UDP packets through V2Ray instance to remote servers.

View File

@ -5,6 +5,7 @@ import (
"crypto/rand"
"io"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/google/go-cmp/cmp"
@ -86,6 +87,66 @@ func TestV2RayDial(t *testing.T) {
}
}
func TestV2RayDialUDPConn(t *testing.T) {
udpServer := udp.Server{
MsgProcessor: xor,
}
dest, err := udpServer.Start()
common.Must(err)
defer udpServer.Close()
config := &core.Config{
App: []*serial.TypedMessage{
serial.ToTypedMessage(&dispatcher.Config{}),
serial.ToTypedMessage(&proxyman.InboundConfig{}),
serial.ToTypedMessage(&proxyman.OutboundConfig{}),
},
Outbound: []*core.OutboundHandlerConfig{
{
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
},
},
}
cfgBytes, err := proto.Marshal(config)
common.Must(err)
server, err := core.StartInstance("protobuf", cfgBytes)
common.Must(err)
defer server.Close()
conn, err := core.Dial(context.Background(), server, dest)
common.Must(err)
defer conn.Close()
const size = 1024
payload := make([]byte, size)
common.Must2(rand.Read(payload))
for i := 0; i < 2; i++ {
if _, err := conn.Write(payload); err != nil {
t.Fatal(err)
}
}
time.Sleep(time.Millisecond * 500)
receive := make([]byte, size*2)
for i := 0; i < 2; i++ {
n, err := conn.Read(receive)
if err != nil {
t.Fatal("expect no error, but got ", err)
}
if n != size {
t.Fatal("expect read size ", size, " but got ", n)
}
if r := cmp.Diff(xor(receive[:n]), payload); r != "" {
t.Fatal(r)
}
}
}
func TestV2RayDialUDP(t *testing.T) {
udpServer1 := udp.Server{
MsgProcessor: xor,