diff --git a/app/tun/handler_tcp.go b/app/tun/handler_tcp.go index 78314299e..11b4a71ef 100644 --- a/app/tun/handler_tcp.go +++ b/app/tun/handler_tcp.go @@ -31,7 +31,7 @@ type TCPHandler struct { stack *stack.Stack } -func SetTCPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyManager policy.Manager, config *Config) func(*stack.Stack) error { +func SetTCPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyManager policy.Manager, config *Config) StackOption { return func(s *stack.Stack) error { tcpForwarder := tcp.NewForwarder(s, rcvWnd, maxInFlight, func(r *tcp.ForwarderRequest) { wg := new(waiter.Queue) diff --git a/app/tun/nic.go b/app/tun/nic.go new file mode 100644 index 000000000..d24b748ce --- /dev/null +++ b/app/tun/nic.go @@ -0,0 +1,19 @@ +package tun + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +func CreateNIC(nicID tcpip.NICID, linkEndpoint stack.LinkEndpoint) StackOption { + return func(s *stack.Stack) error { + if err := s.CreateNICWithOptions(nicID, linkEndpoint, + stack.NICOptions{ + Disabled: false, + QDisc: nil, + }); err != nil { + return newError("failed to create NIC:", err) + } + return nil + } +} diff --git a/app/tun/stack.go b/app/tun/stack.go index c80887aa0..c127a9dfe 100644 --- a/app/tun/stack.go +++ b/app/tun/stack.go @@ -1,6 +1,7 @@ package tun import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -11,7 +12,7 @@ import ( type StackOption func(*stack.Stack) error -func (t *TUN) CreateStack(_ stack.LinkEndpoint) (*stack.Stack, error) { +func (t *TUN) CreateStack(linkedEndpoint stack.LinkEndpoint) (*stack.Stack, error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, @@ -25,8 +26,12 @@ func (t *TUN) CreateStack(_ stack.LinkEndpoint) (*stack.Stack, error) { }, }) + nicID := tcpip.NICID(s.UniqueID()) + opts := []StackOption{ SetTCPHandler(t.ctx, t.dispatcher, t.policyManager, t.config), + + CreateNIC(nicID, linkedEndpoint), } for _, opt := range opts { @@ -35,7 +40,5 @@ func (t *TUN) CreateStack(_ stack.LinkEndpoint) (*stack.Stack, error) { } } - // nicID := tcpip.NICID(s.UniqueID()) - return s, nil }