1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-22 18:17:52 -05:00
v2fly/app/tun/handler_tcp.go
2023-10-22 22:44:28 +01:00

117 lines
3.2 KiB
Go

package tun
import (
"context"
"github.com/v2fly/v2ray-core/v5/common"
"github.com/v2fly/v2ray-core/v5/common/buf"
"github.com/v2fly/v2ray-core/v5/common/log"
"github.com/v2fly/v2ray-core/v5/common/net"
"github.com/v2fly/v2ray-core/v5/common/session"
"github.com/v2fly/v2ray-core/v5/common/signal"
"github.com/v2fly/v2ray-core/v5/common/task"
"github.com/v2fly/v2ray-core/v5/features/policy"
"github.com/v2fly/v2ray-core/v5/features/routing"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
)
// Tuning parameters handed to tcp.NewForwarder by SetTCPHandler.
const (
	// rcvWnd is the receive window passed to the forwarder; zero keeps the
	// forwarder's default settings.
	rcvWnd = 0

	// maxInFlight is the in-flight limit passed to the forwarder (2048).
	maxInFlight = 1 << 11
)
// TCPHandler relays a single TCP connection taken from the TUN network stack
// to the routing dispatcher. SetTCPHandler builds one per accepted connection.
type TCPHandler struct {
// ctx is the base context from which Handle derives the per-connection context.
ctx context.Context
// dispatcher opens the outbound link for the connection's destination.
dispatcher routing.Dispatcher
// policyManager supplies the session policy (timeouts) for config.UserLevel.
policyManager policy.Manager
// config provides the inbound Tag and UserLevel read by Handle.
config *Config
// stack is the network stack the handler was installed on; it is stored by
// SetTCPHandler and not read by Handle.
stack *stack.Stack
}
// SetTCPHandler returns a StackOption that registers a TCP forwarder on the
// given stack. Every forwarded connection request is turned into an endpoint,
// wrapped as a net.Conn and relayed through a fresh TCPHandler built from ctx,
// dispatcher, policyManager and config.
//
// The returned option always reports a nil error; failures of individual
// connections are written to the log instead of being propagated, because the
// forwarder callback has no caller to return them to.
func SetTCPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyManager policy.Manager, config *Config) StackOption {
	return func(s *stack.Stack) error {
		tcpForwarder := tcp.NewForwarder(s, rcvWnd, maxInFlight, func(r *tcp.ForwarderRequest) {
			wq := new(waiter.Queue)
			linkedEndpoint, err := r.CreateEndpoint(wq)
			if err != nil {
				// No endpoint could be created: complete the request and ask
				// the forwarder to reset the peer.
				r.Complete(true)
				return
			}
			// The request is only completed once Handle returns, i.e. when the
			// relayed connection has ended.
			defer r.Complete(false)

			// TODO: set sockopt

			tcpHandler := &TCPHandler{
				ctx:           ctx,
				dispatcher:    dispatcher,
				policyManager: policyManager,
				config:        config,
				stack:         s,
			}

			if err := tcpHandler.Handle(gonet.NewTCPConn(wq, linkedEndpoint)); err != nil {
				// Record the failure rather than dropping it silently.
				newError("failed to handle tcp connection").Base(err).WriteToLog()
			}
		})
		s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)

		return nil
	}
}
// Handle relays one TCP connection between conn and the outbound link obtained
// from the dispatcher, and closes conn before returning.
//
// The connection's remote address is treated as the traffic source and its
// local address as the destination to dispatch to: for a connection accepted
// by the stack's TCP forwarder the local side is the address the peer was
// trying to reach.
//
// It returns an error when dispatching fails or when either copy direction
// ends abnormally (including cancellation by the inactivity timer); it returns
// nil when both directions finish cleanly.
func (h *TCPHandler) Handle(conn net.Conn) error {
	// Release the underlying endpoint on every exit path. Closing also unblocks
	// a copy goroutine that is still reading from conn after task.Run returned.
	defer conn.Close()

	sessionPolicy := h.policyManager.ForLevel(h.config.UserLevel)

	source := conn.RemoteAddr()
	dest := net.DestinationFromAddr(conn.LocalAddr())

	// Each step derives from the previous ctx so that the inbound session is
	// still attached when the access message is added.
	ctx := session.ContextWithInbound(h.ctx, &session.Inbound{Tag: h.config.Tag})
	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
		From:   source,
		To:     dest,
		Status: log.AccessAccepted,
		Reason: "",
	})

	ctx, cancel := context.WithCancel(ctx)
	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)

	link, err := h.dispatcher.Dispatch(ctx, dest)
	if err != nil {
		// Nothing will use ctx any more; stop the inactivity timer right away.
		cancel()
		return newError("failed to dispatch").Base(err)
	}

	// Downlink: outbound link -> conn. Once it ends, only the uplink remains,
	// so the timer is switched to the uplink-only timeout.
	responseDone := func() error {
		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)

		if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
			return newError("failed to transport all TCP response").Base(err)
		}

		return nil
	}

	// Uplink: conn -> outbound link. Once it ends, only the downlink remains,
	// so the timer is switched to the downlink-only timeout.
	requestDone := func() error {
		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)

		if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
			return newError("failed to transport all TCP request").Base(err)
		}

		return nil
	}

	// Close the link writer after a clean uplink so the outbound side sees EOF.
	requestDoneAndCloseWriter := task.OnSuccess(requestDone, task.Close(link.Writer))

	// Run under the cancellable ctx so that an inactivity timeout ends the relay.
	if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil {
		common.Interrupt(link.Reader)
		common.Interrupt(link.Writer)
		return newError("connection ends").Base(err)
	}

	return nil
}