diff --git a/app/tun/handler.go b/app/tun/handler.go index 57d2a8b55..882b9e6e7 100644 --- a/app/tun/handler.go +++ b/app/tun/handler.go @@ -1,7 +1,31 @@ package tun -import "github.com/v2fly/v2ray-core/v5/common/net" +import ( + "github.com/v2fly/v2ray-core/v5/common/net" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) -type Handler interface { - Handle(conn net.Conn) error +var ( + tcpQueue = make(chan TCPConn) + udpQueue = make(chan UDPConn) +) + +type TCPConn interface { + net.Conn + + ID() *stack.TransportEndpointID +} + +type UDPConn interface { + net.Conn + + ID() *stack.TransportEndpointID +} + +func handleTCP(conn TCPConn) { + tcpQueue <- conn +} + +func handleUDP(conn UDPConn) { + udpQueue <- conn } diff --git a/app/tun/handler_tcp.go b/app/tun/handler_tcp.go index dcbba79b9..cae5515f7 100644 --- a/app/tun/handler_tcp.go +++ b/app/tun/handler_tcp.go @@ -23,6 +23,15 @@ const ( maxInFlight = 2 << 10 ) +type tcpConn struct { + *gonet.TCPConn + id stack.TransportEndpointID +} + +func (c *tcpConn) ID() *stack.TransportEndpointID { + return &c.id +} + type TCPHandler struct { ctx context.Context dispatcher routing.Dispatcher @@ -32,7 +41,7 @@ type TCPHandler struct { stack *stack.Stack } -func SetTCPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyManager policy.Manager, config *Config) StackOption { +func HandleTCP(handle func(TCPConn)) StackOption { return func(s *stack.Stack) error { tcpForwarder := tcp.NewForwarder(s, rcvWnd, maxInFlight, func(r *tcp.ForwarderRequest) { wg := new(waiter.Queue) @@ -45,19 +54,25 @@ func SetTCPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyMan // TODO: set sockopt - tcpHandler := &TCPHandler{ - ctx: ctx, - dispatcher: dispatcher, - policyManager: policyManager, - config: config, - stack: s, - } - - if err := tcpHandler.Handle(gonet.NewTCPConn(wg, linkedEndpoint)); err != nil { - // TODO: log - // return newError("failed to handle tcp connection").Base(err) + // tcpHandler := &TCPHandler{ + // ctx: ctx, + // dispatcher: dispatcher, + // policyManager: policyManager, + // config: config, + // stack: s, + // } + + // if err := tcpHandler.Handle(gonet.NewTCPConn(wg, linkedEndpoint)); err != nil { + // // TODO: log + // // return newError("failed to handle tcp connection").Base(err) + // } + + tcpConn := &tcpConn{ + TCPConn: gonet.NewTCPConn(wg, linkedEndpoint), + id: r.ID(), } + tcpQueue <- tcpConn }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) @@ -65,7 +80,20 @@ func SetTCPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyMan } } -func (h *TCPHandler) Handle(conn net.Conn) error { +func (h *TCPHandler) HandleQueue(ch chan TCPConn) { + for { + select { + case conn := <-ch: + if err := h.Handle(conn); err != nil { + newError(err).AtError().WriteToLog(session.ExportIDToError(h.ctx)) + } + case <-h.ctx.Done(): + return + } + } +} + +func (h *TCPHandler) Handle(conn TCPConn) error { ctx := session.ContextWithInbound(h.ctx, &session.Inbound{Tag: h.config.Tag}) sessionPolicy := h.policyManager.ForLevel(h.config.UserLevel) diff --git a/app/tun/handler_udp.go b/app/tun/handler_udp.go index 28401477c..df240c8b7 100644 --- a/app/tun/handler_udp.go +++ b/app/tun/handler_udp.go @@ -26,7 +26,16 @@ type UDPHandler struct { stack *stack.Stack } -func SetUDPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyManager policy.Manager, config *Config) StackOption { +type udpConn struct { + *gonet.UDPConn + id stack.TransportEndpointID +} + +func (c *udpConn) ID() *stack.TransportEndpointID { + return &c.id +} + +func HandleUDP(handle func(UDPConn)) StackOption { return func(s *stack.Stack) error { udpForwarder := gvisor_udp.NewForwarder(s, func(r *gvisor_udp.ForwarderRequest) { wg := new(waiter.Queue) @@ -36,21 +45,32 @@ func SetUDPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyMan return } - udpConn := gonet.NewUDPConn(s, wg, linkedEndpoint) - udpHandler := &UDPHandler{ - ctx: ctx, - dispatcher: dispatcher, - policyManager: policyManager, - config: config, - stack: s, + udpConn := &udpConn{ + UDPConn: gonet.NewUDPConn(s, wg, linkedEndpoint), + id: r.ID(), } - udpHandler.Handle(udpConn) + + handle(udpConn) }) s.SetTransportProtocolHandler(gvisor_udp.ProtocolNumber, udpForwarder.HandlePacket) return nil } } -func (h *UDPHandler) Handle(conn net.Conn) error { + +func (h *UDPHandler) HandleQueue(ch chan UDPConn) { + for { + select { + case <-h.ctx.Done(): + return + case conn := <-ch: + if err := h.Handle(conn); err != nil { + newError(err).AtError().WriteToLog(session.ExportIDToError(h.ctx)) + } + } + } +} + +func (h *UDPHandler) Handle(conn UDPConn) error { ctx := session.ContextWithInbound(h.ctx, &session.Inbound{Tag: h.config.Tag}) packetConn := conn.(net.PacketConn) diff --git a/app/tun/stack.go b/app/tun/stack.go index 6040a388c..63cb3f612 100644 --- a/app/tun/stack.go +++ b/app/tun/stack.go @@ -29,8 +29,8 @@ func (t *TUN) CreateStack(linkedEndpoint stack.LinkEndpoint) (*stack.Stack, erro nicID := tcpip.NICID(s.UniqueID()) opts := []StackOption{ - SetTCPHandler(t.ctx, t.dispatcher, t.policyManager, t.config), - SetUDPHandler(t.ctx, t.dispatcher, t.policyManager, t.config), + HandleTCP(handleTCP), + HandleUDP(handleUDP), CreateNIC(nicID, linkedEndpoint), AddProtocolAddress(nicID, t.config.Ips), diff --git a/app/tun/tun.go b/app/tun/tun.go index 5b1396663..b7fa20e69 100644 --- a/app/tun/tun.go +++ b/app/tun/tun.go @@ -46,6 +46,24 @@ func (t *TUN) Start() error { } t.stack = stack + tcpHandler := &TCPHandler{ + ctx: t.ctx, + dispatcher: t.dispatcher, + policyManager: t.policyManager, + config: t.config, + stack: stack, + } + go tcpHandler.Handle(<-tcpQueue) + + udpHander := &UDPHandler{ + ctx: t.ctx, + dispatcher: t.dispatcher, + policyManager: t.policyManager, + config: t.config, + stack: stack, + } + go udpHander.Handle(<-udpQueue) + return nil }