diff --git a/app/proxyman/inbound/worker.go b/app/proxyman/inbound/worker.go index 804092fec..296eb0f85 100644 --- a/app/proxyman/inbound/worker.go +++ b/app/proxyman/inbound/worker.go @@ -244,10 +244,7 @@ func (w *udpWorker) removeConn(id connID) { func (w *udpWorker) Start() error { w.activeConn = make(map[connID]*udpConn, 16) w.done = signal.NewDone() - h, err := udp.ListenUDP(w.address, w.port, udp.ListenOption{ - Callback: w.callback, - ReceiveOriginalDest: w.recvOrigDest, - }) + h, err := udp.ListenUDP(w.address, w.port, w.callback, udp.HubReceiveOriginalDestination(w.recvOrigDest)) if err != nil { return err } diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index 3fdcbf9f4..f35f57eff 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -61,7 +61,7 @@ func NewListener(ctx context.Context, address net.Address, port net.Port, addCon l.tlsConfig = config.GetTLSConfig() } - hub, err := udp.ListenUDP(address, port, udp.ListenOption{Callback: l.OnReceive, Concurrency: 2}) + hub, err := udp.ListenUDP(address, port, l.OnReceive, udp.HubCapacity(64)) if err != nil { return nil, err } diff --git a/transport/internet/udp/hub.go b/transport/internet/udp/hub.go index 5651d1f50..52b2b510f 100644 --- a/transport/internet/udp/hub.go +++ b/transport/internet/udp/hub.go @@ -2,7 +2,6 @@ package udp import ( "v2ray.com/core/common/buf" - "v2ray.com/core/common/dice" "v2ray.com/core/common/net" ) @@ -16,71 +15,28 @@ type Payload struct { // PayloadHandler is function to handle Payload. type PayloadHandler func(payload *buf.Buffer, source net.Destination, originalDest net.Destination) -// PayloadQueue is a queue of Payload. -type PayloadQueue struct { - queue []chan Payload - callback PayloadHandler -} +type HubOption func(h *Hub) -// NewPayloadQueue returns a new PayloadQueue. -func NewPayloadQueue(option ListenOption) *PayloadQueue { - queue := &PayloadQueue{ - callback: option.Callback, - queue: make([]chan Payload, option.Concurrency), - } - for i := range queue.queue { - queue.queue[i] = make(chan Payload, 64) - go queue.Dequeue(queue.queue[i]) - } - return queue -} - -// Enqueue adds the payload to the end of this queue. -func (q *PayloadQueue) Enqueue(payload Payload) { - size := len(q.queue) - idx := 0 - if size > 1 { - idx = dice.Roll(size) - } - for i := 0; i < size; i++ { - select { - case q.queue[idx%size] <- payload: - return - default: - idx++ - } +func HubCapacity(cap int) HubOption { + return func(h *Hub) { + h.capacity = cap } } -func (q *PayloadQueue) Dequeue(queue <-chan Payload) { - for payload := range queue { - q.callback(payload.payload, payload.source, payload.originalDest) +func HubReceiveOriginalDestination(r bool) HubOption { + return func(h *Hub) { + h.recvOrigDest = r } } -func (q *PayloadQueue) Close() error { - for _, queue := range q.queue { - close(queue) - } - return nil -} - -type ListenOption struct { - Callback PayloadHandler - ReceiveOriginalDest bool - Concurrency int -} - type Hub struct { - conn *net.UDPConn - queue *PayloadQueue - option ListenOption + conn *net.UDPConn + callback PayloadHandler + capacity int + recvOrigDest bool } -func ListenUDP(address net.Address, port net.Port, option ListenOption) (*Hub, error) { - if option.Concurrency < 1 { - option.Concurrency = 1 - } +func ListenUDP(address net.Address, port net.Port, callback PayloadHandler, options ...HubOption) (*Hub, error) { udpConn, err := net.ListenUDP("udp", &net.UDPAddr{ IP: address.IP(), Port: int(port), @@ -89,7 +45,17 @@ func ListenUDP(address net.Address, port net.Port, option ListenOption) (*Hub, e return nil, err } newError("listening UDP on ", address, ":", port).WriteToLog() - if option.ReceiveOriginalDest { + hub := &Hub{ + conn: udpConn, + capacity: 16, + callback: callback, + recvOrigDest: false, + } + for _, opt := range options { + opt(hub) + } + + if hub.recvOrigDest { rawConn, err := udpConn.SyscallConn() if err != nil { return nil, newError("failed to get fd").Base(err) @@ -103,12 +69,11 @@ func ListenUDP(address net.Address, port net.Port, option ListenOption) (*Hub, e return nil, newError("failed to control socket").Base(err) } } - hub := &Hub{ - conn: udpConn, - queue: NewPayloadQueue(option), - option: option, - } - go hub.start() + + c := make(chan *Payload, hub.capacity) + + go hub.start(c) + go hub.process(c) return hub, nil } @@ -125,7 +90,15 @@ func (h *Hub) WriteTo(payload []byte, dest net.Destination) (int, error) { }) } -func (h *Hub) start() { +func (h *Hub) process(c <-chan *Payload) { + for p := range c { + h.callback(p.payload, p.source, p.originalDest) + } +} + +func (h *Hub) start(c chan<- *Payload) { + defer close(c) + oobBytes := make([]byte, 256) for { @@ -145,11 +118,11 @@ func (h *Hub) start() { break } - payload := Payload{ + payload := &Payload{ payload: buffer, } payload.source = net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port)) - if h.option.ReceiveOriginalDest && noob > 0 { + if h.recvOrigDest && noob > 0 { payload.originalDest = RetrieveOriginalDest(oobBytes[:noob]) if payload.originalDest.IsValid() { newError("UDP original destination: ", payload.originalDest).AtDebug().WriteToLog() @@ -157,9 +130,13 @@ func (h *Hub) start() { newError("failed to read UDP original destination").WriteToLog() } } - h.queue.Enqueue(payload) + + select { + case c <- payload: + default: + } + } - h.queue.Close() } // Addr implements net.Listener.