diff --git a/app/dns/nameserver_udp.go b/app/dns/nameserver_udp.go index b16c344a5..2621571c2 100644 --- a/app/dns/nameserver_udp.go +++ b/app/dns/nameserver_udp.go @@ -56,7 +56,7 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher Interval: time.Minute, Execute: s.Cleanup, } - s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse) + s.udpServer = udp.NewSplitDispatcher(dispatcher, s.HandleResponse) newError("DNS: created UDP client initialized for ", address.NetAddr()).AtInfo().WriteToLog() return s } diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index d5fda1df9..f78440016 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -70,7 +70,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet } func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error { - udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { + udpServer := udp.NewSplitDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { request := protocol.RequestHeaderFromContext(ctx) if request == nil { return diff --git a/proxy/socks/server.go b/proxy/socks/server.go index ad2927f59..f974df33e 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -186,7 +186,7 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ } func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error { - udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { + udpServer := udp.NewSplitDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { payload := packet.Payload newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx)) diff --git a/proxy/trojan/server.go b/proxy/trojan/server.go index 909cf93e5..0892f619f 100644 --- a/proxy/trojan/server.go +++ b/proxy/trojan/server.go @@ -204,7 +204,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet } func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error { - udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { + udpServer := udp.NewSplitDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { if err := clientWriter.WriteMultiBufferWithMetadata(buf.MultiBuffer{packet.Payload}, packet.Source); err != nil { newError("failed to write response").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) } diff --git a/transport/internet/udp/dispatcher.go b/transport/internet/udp/dispatcher.go index 2ae1151fa..63d79bbdc 100644 --- a/transport/internet/udp/dispatcher.go +++ b/transport/internet/udp/dispatcher.go @@ -2,196 +2,10 @@ package udp import ( "context" - "io" - "sync" - "time" - - "github.com/v2fly/v2ray-core/v5/common" "github.com/v2fly/v2ray-core/v5/common/buf" "github.com/v2fly/v2ray-core/v5/common/net" - "github.com/v2fly/v2ray-core/v5/common/protocol/udp" - "github.com/v2fly/v2ray-core/v5/common/session" - "github.com/v2fly/v2ray-core/v5/common/signal" - "github.com/v2fly/v2ray-core/v5/common/signal/done" - "github.com/v2fly/v2ray-core/v5/features/routing" - "github.com/v2fly/v2ray-core/v5/transport" ) -type ResponseCallback func(ctx context.Context, packet *udp.Packet) - -type connEntry struct { - link *transport.Link - timer signal.ActivityUpdater - cancel context.CancelFunc -} - -type Dispatcher struct { - sync.RWMutex - conns map[net.Destination]*connEntry - dispatcher routing.Dispatcher - callback ResponseCallback -} - -func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher { - return &Dispatcher{ - conns: make(map[net.Destination]*connEntry), - dispatcher: dispatcher, - callback: callback, - } -} - -func (v *Dispatcher) RemoveRay(dest net.Destination) { - v.Lock() - defer v.Unlock() - if conn, found := v.conns[dest]; found { - common.Close(conn.link.Reader) - common.Close(conn.link.Writer) - delete(v.conns, dest) - } -} - -func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) *connEntry { - v.Lock() - defer v.Unlock() - - if entry, found := v.conns[dest]; found { - return entry - } - - newError("establishing new connection for ", dest).WriteToLog() - - ctx, cancel := context.WithCancel(ctx) - removeRay := func() { - cancel() - v.RemoveRay(dest) - } - timer := signal.CancelAfterInactivity(ctx, removeRay, time.Second*4) - link, _ := v.dispatcher.Dispatch(ctx, dest) - entry := &connEntry{ - link: link, - timer: timer, - cancel: removeRay, - } - v.conns[dest] = entry - go handleInput(ctx, entry, dest, v.callback) - return entry -} - -func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) { - // TODO: Add user to destString - newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx)) - - conn := v.getInboundRay(ctx, destination) - outputStream := conn.link.Writer - if outputStream != nil { - if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil { - newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx)) - conn.cancel() - return - } - } -} - -func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback) { - defer conn.cancel() - - input := conn.link.Reader - timer := conn.timer - - for { - select { - case <-ctx.Done(): - return - default: - } - - mb, err := input.ReadMultiBuffer() - if err != nil { - newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx)) - return - } - timer.Update() - for _, b := range mb { - callback(ctx, &udp.Packet{ - Payload: b, - Source: dest, - }) - } - } -} - -type dispatcherConn struct { - dispatcher *Dispatcher - cache chan *udp.Packet - done *done.Instance -} - -func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) { - c := &dispatcherConn{ - cache: make(chan *udp.Packet, 16), - done: done.New(), - } - - d := NewDispatcher(dispatcher, c.callback) - c.dispatcher = d - return c, nil -} - -func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) { - select { - case <-c.done.Wait(): - packet.Payload.Release() - return - case c.cache <- packet: - default: - packet.Payload.Release() - return - } -} - -func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) { - select { - case <-c.done.Wait(): - return 0, nil, io.EOF - case packet := <-c.cache: - n := copy(p, packet.Payload.Bytes()) - return n, &net.UDPAddr{ - IP: packet.Source.Address.IP(), - Port: int(packet.Source.Port), - }, nil - } -} - -func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) { - buffer := buf.New() - raw := buffer.Extend(buf.Size) - n := copy(raw, p) - buffer.Resize(0, int32(n)) - - ctx := context.Background() - c.dispatcher.Dispatch(ctx, net.DestinationFromAddr(addr), buffer) - return n, nil -} - -func (c *dispatcherConn) Close() error { - return c.done.Close() -} - -func (c *dispatcherConn) LocalAddr() net.Addr { - return &net.UDPAddr{ - IP: []byte{0, 0, 0, 0}, - Port: 0, - } -} - -func (c *dispatcherConn) SetDeadline(t time.Time) error { - return nil -} - -func (c *dispatcherConn) SetReadDeadline(t time.Time) error { - return nil -} - -func (c *dispatcherConn) SetWriteDeadline(t time.Time) error { - return nil +type DispatcherI interface { + Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) } diff --git a/transport/internet/udp/dispatcher_split.go b/transport/internet/udp/dispatcher_split.go new file mode 100644 index 000000000..ae3030cd5 --- /dev/null +++ b/transport/internet/udp/dispatcher_split.go @@ -0,0 +1,197 @@ +package udp + +import ( + "context" + "io" + "sync" + "time" + + "github.com/v2fly/v2ray-core/v5/common" + "github.com/v2fly/v2ray-core/v5/common/buf" + "github.com/v2fly/v2ray-core/v5/common/net" + "github.com/v2fly/v2ray-core/v5/common/protocol/udp" + "github.com/v2fly/v2ray-core/v5/common/session" + "github.com/v2fly/v2ray-core/v5/common/signal" + "github.com/v2fly/v2ray-core/v5/common/signal/done" + "github.com/v2fly/v2ray-core/v5/features/routing" + "github.com/v2fly/v2ray-core/v5/transport" +) + +type ResponseCallback func(ctx context.Context, packet *udp.Packet) + +type connEntry struct { + link *transport.Link + timer signal.ActivityUpdater + cancel context.CancelFunc +} + +type Dispatcher struct { + sync.RWMutex + conns map[net.Destination]*connEntry + dispatcher routing.Dispatcher + callback ResponseCallback +} + +func NewSplitDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher { + return &Dispatcher{ + conns: make(map[net.Destination]*connEntry), + dispatcher: dispatcher, + callback: callback, + } +} + +func (v *Dispatcher) RemoveRay(dest net.Destination) { + v.Lock() + defer v.Unlock() + if conn, found := v.conns[dest]; found { + common.Close(conn.link.Reader) + common.Close(conn.link.Writer) + delete(v.conns, dest) + } +} + +func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) *connEntry { + v.Lock() + defer v.Unlock() + + if entry, found := v.conns[dest]; found { + return entry + } + + newError("establishing new connection for ", dest).WriteToLog() + + ctx, cancel := context.WithCancel(ctx) + removeRay := func() { + cancel() + v.RemoveRay(dest) + } + timer := signal.CancelAfterInactivity(ctx, removeRay, time.Second*4) + link, _ := v.dispatcher.Dispatch(ctx, dest) + entry := &connEntry{ + link: link, + timer: timer, + cancel: removeRay, + } + v.conns[dest] = entry + go handleInput(ctx, entry, dest, v.callback) + return entry +} + +func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) { + // TODO: Add user to destString + newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx)) + + conn := v.getInboundRay(ctx, destination) + outputStream := conn.link.Writer + if outputStream != nil { + if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil { + newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx)) + conn.cancel() + return + } + } +} + +func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback) { + defer conn.cancel() + + input := conn.link.Reader + timer := conn.timer + + for { + select { + case <-ctx.Done(): + return + default: + } + + mb, err := input.ReadMultiBuffer() + if err != nil { + newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx)) + return + } + timer.Update() + for _, b := range mb { + callback(ctx, &udp.Packet{ + Payload: b, + Source: dest, + }) + } + } +} + +type dispatcherConn struct { + dispatcher *Dispatcher + cache chan *udp.Packet + done *done.Instance +} + +func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) { + c := &dispatcherConn{ + cache: make(chan *udp.Packet, 16), + done: done.New(), + } + + d := NewSplitDispatcher(dispatcher, c.callback) + c.dispatcher = d + return c, nil +} + +func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) { + select { + case <-c.done.Wait(): + packet.Payload.Release() + return + case c.cache <- packet: + default: + packet.Payload.Release() + return + } +} + +func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) { + select { + case <-c.done.Wait(): + return 0, nil, io.EOF + case packet := <-c.cache: + n := copy(p, packet.Payload.Bytes()) + return n, &net.UDPAddr{ + IP: packet.Source.Address.IP(), + Port: int(packet.Source.Port), + }, nil + } +} + +func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) { + buffer := buf.New() + raw := buffer.Extend(buf.Size) + n := copy(raw, p) + buffer.Resize(0, int32(n)) + + ctx := context.Background() + c.dispatcher.Dispatch(ctx, net.DestinationFromAddr(addr), buffer) + return n, nil +} + +func (c *dispatcherConn) Close() error { + return c.done.Close() +} + +func (c *dispatcherConn) LocalAddr() net.Addr { + return &net.UDPAddr{ + IP: []byte{0, 0, 0, 0}, + Port: 0, + } +} + +func (c *dispatcherConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *dispatcherConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *dispatcherConn) SetWriteDeadline(t time.Time) error { + return nil +} diff --git a/transport/internet/udp/dispatcher_test.go b/transport/internet/udp/dispatcher_split_test.go similarity index 95% rename from transport/internet/udp/dispatcher_test.go rename to transport/internet/udp/dispatcher_split_test.go index 12c9e7c8a..b704b57ff 100644 --- a/transport/internet/udp/dispatcher_test.go +++ b/transport/internet/udp/dispatcher_split_test.go @@ -65,7 +65,7 @@ func TestSameDestinationDispatching(t *testing.T) { b.WriteString("abcd") var msgCount uint32 - dispatcher := NewDispatcher(td, func(ctx context.Context, packet *udp.Packet) { + dispatcher := NewSplitDispatcher(td, func(ctx context.Context, packet *udp.Packet) { atomic.AddUint32(&msgCount, 1) })