diff --git a/app/reverse/bridge.go b/app/reverse/bridge.go index c38bf5af0..a34309dee 100644 --- a/app/reverse/bridge.go +++ b/app/reverse/bridge.go @@ -17,6 +17,7 @@ import ( // Bridge is a component in reverse proxy, that relays connections from Portal to local address. type Bridge struct { + ctx context.Context dispatcher routing.Dispatcher tag string domain string @@ -25,7 +26,7 @@ type Bridge struct { } // NewBridge creates a new Bridge instance. -func NewBridge(config *BridgeConfig, dispatcher routing.Dispatcher) (*Bridge, error) { +func NewBridge(ctx context.Context, config *BridgeConfig, dispatcher routing.Dispatcher) (*Bridge, error) { if config.Tag == "" { return nil, newError("bridge tag is empty") } @@ -34,6 +35,7 @@ func NewBridge(config *BridgeConfig, dispatcher routing.Dispatcher) (*Bridge, er } b := &Bridge{ + ctx: ctx, dispatcher: dispatcher, tag: config.Tag, domain: config.Domain, @@ -73,7 +75,7 @@ func (b *Bridge) monitor() error { } if numWorker == 0 || numConnections/numWorker > 16 { - worker, err := NewBridgeWorker(b.domain, b.tag, b.dispatcher) + worker, err := NewBridgeWorker(b.ctx, b.domain, b.tag, b.dispatcher) if err != nil { newError("failed to create bridge worker").Base(err).AtWarning().WriteToLog() return nil @@ -99,12 +101,11 @@ type BridgeWorker struct { state Control_State } -func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWorker, error) { - ctx := context.Background() - ctx = session.ContextWithInbound(ctx, &session.Inbound{ +func NewBridgeWorker(ctx context.Context, domain string, tag string, d routing.Dispatcher) (*BridgeWorker, error) { + bridgeCtx := session.ContextWithInbound(ctx, &session.Inbound{ Tag: tag, }) - link, err := d.Dispatch(ctx, net.Destination{ + link, err := d.Dispatch(bridgeCtx, net.Destination{ Network: net.Network_TCP, Address: net.DomainAddress(domain), Port: 0, @@ -118,7 +119,7 @@ func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWo tag: tag, } - worker, err := mux.NewServerWorker(context.Background(), w, link) + worker, err := mux.NewServerWorker(ctx, w, link) if err != nil { return nil, err } diff --git a/app/reverse/portal.go b/app/reverse/portal.go index 9a6f57421..02e915e76 100644 --- a/app/reverse/portal.go +++ b/app/reverse/portal.go @@ -19,6 +19,7 @@ import ( ) type Portal struct { + ctx context.Context ohm outbound.Manager tag string domain string @@ -26,7 +27,7 @@ type Portal struct { client *mux.ClientManager } -func NewPortal(config *PortalConfig, ohm outbound.Manager) (*Portal, error) { +func NewPortal(ctx context.Context, config *PortalConfig, ohm outbound.Manager) (*Portal, error) { if config.Tag == "" { return nil, newError("portal tag is empty") } @@ -41,6 +42,7 @@ func NewPortal(config *PortalConfig, ohm outbound.Manager) (*Portal, error) { } return &Portal{ + ctx: ctx, ohm: ohm, tag: config.Tag, domain: config.Domain, @@ -52,14 +54,14 @@ func NewPortal(config *PortalConfig, ohm outbound.Manager) (*Portal, error) { } func (p *Portal) Start() error { - return p.ohm.AddHandler(context.Background(), &Outbound{ + return p.ohm.AddHandler(p.ctx, &Outbound{ portal: p, tag: p.tag, }) } func (p *Portal) Close() error { - return p.ohm.RemoveHandler(context.Background(), p.tag) + return p.ohm.RemoveHandler(p.ctx, p.tag) } func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) error { @@ -74,7 +76,7 @@ func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) err return newError("failed to create mux client worker").Base(err).AtWarning() } - worker, err := NewPortalWorker(muxClient) + worker, err := NewPortalWorker(ctx, muxClient) if err != nil { return newError("failed to create portal worker").Base(err) } @@ -198,12 +200,11 @@ type PortalWorker struct { draining bool } -func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) { +func NewPortalWorker(ctx context.Context, client *mux.ClientWorker) (*PortalWorker, error) { opt := []pipe.Option{pipe.WithSizeLimit(16 * 1024)} uplinkReader, uplinkWriter := pipe.New(opt...) downlinkReader, downlinkWriter := pipe.New(opt...) - ctx := context.Background() ctx = session.ContextWithOutbound(ctx, &session.Outbound{ Target: net.UDPDestination(net.DomainAddress(internalDomain), 0), }) diff --git a/app/reverse/reverse.go b/app/reverse/reverse.go index e9c23e33d..3beb18ede 100644 --- a/app/reverse/reverse.go +++ b/app/reverse/reverse.go @@ -29,7 +29,7 @@ func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { r := new(Reverse) if err := core.RequireFeatures(ctx, func(d routing.Dispatcher, om outbound.Manager) error { - return r.Init(config.(*Config), d, om) + return r.Init(ctx, config.(*Config), d, om) }); err != nil { return nil, err } @@ -42,9 +42,9 @@ type Reverse struct { portals []*Portal } -func (r *Reverse) Init(config *Config, d routing.Dispatcher, ohm outbound.Manager) error { +func (r *Reverse) Init(ctx context.Context, config *Config, d routing.Dispatcher, ohm outbound.Manager) error { for _, bConfig := range config.BridgeConfig { - b, err := NewBridge(bConfig, d) + b, err := NewBridge(ctx, bConfig, d) if err != nil { return err } @@ -52,7 +52,7 @@ func (r *Reverse) Init(config *Config, d routing.Dispatcher, ohm outbound.Manage } for _, pConfig := range config.PortalConfig { - p, err := NewPortal(pConfig, ohm) + p, err := NewPortal(ctx, pConfig, ohm) if err != nil { return err }