diff --git a/common/buf/io.go b/common/buf/io.go index 0c58a53ff..1a56bf75f 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -2,6 +2,7 @@ package buf import ( "io" + "net" "syscall" "time" ) @@ -38,6 +39,11 @@ func WriteAllBytes(writer io.Writer, payload []byte) error { return nil } +func isPacketReader(reader io.Reader) bool { + _, ok := reader.(net.PacketConn) + return ok +} + // NewReader creates a new Reader. // The Reader instance doesn't take the ownership of reader. func NewReader(reader io.Reader) Reader { @@ -45,6 +51,12 @@ func NewReader(reader io.Reader) Reader { return mr } + if isPacketReader(reader) { + return &PacketReader{ + Reader: reader, + } + } + if useReadv { if sc, ok := reader.(syscall.Conn); ok { rawConn, err := sc.SyscallConn() @@ -61,14 +73,25 @@ func NewReader(reader io.Reader) Reader { } } +func isPacketWriter(writer io.Writer) bool { + if _, ok := writer.(net.PacketConn); ok { + return true + } + + // If the writer doesn't implement syscall.Conn, it is probably not a TCP connection. + if _, ok := writer.(syscall.Conn); !ok { + return true + } + return false +} + // NewWriter creates a new Writer. func NewWriter(writer io.Writer) Writer { if mw, ok := writer.(Writer); ok { return mw } - if _, ok := writer.(syscall.Conn); !ok { - // If the writer doesn't implement syscall.Conn, it is probably not a TCP connection. + if isPacketWriter(writer) { return &SequentialWriter{ Writer: writer, } diff --git a/common/buf/reader.go b/common/buf/reader.go index 2c740a0cb..dbc968443 100644 --- a/common/buf/reader.go +++ b/common/buf/reader.go @@ -7,6 +7,23 @@ import ( "v2ray.com/core/common/errors" ) +func readOneUDP(r io.Reader) (*Buffer, error) { + b := New() + for i := 0; i < 64; i++ { + _, err := b.ReadFrom(r) + if !b.IsEmpty() { + return b, nil + } + if err != nil { + b.Release() + return nil, err + } + } + + b.Release() + return nil, newError("Reader returns too many empty payloads.") +} + func readOne(r io.Reader) (*Buffer, error) { // Use an one-byte buffer to wait for incoming payload. var firstByte [1]byte @@ -152,3 +169,17 @@ func (r *SingleReader) ReadMultiBuffer() (MultiBuffer, error) { } return MultiBuffer{b}, nil } + +// PacketReader is a Reader that read one Buffer every time. +type PacketReader struct { + io.Reader +} + +// ReadMultiBuffer implements Reader. +func (r *PacketReader) ReadMultiBuffer() (MultiBuffer, error) { + b, err := readOneUDP(r.Reader) + if err != nil { + return nil, err + } + return MultiBuffer{b}, nil +}