From 94b880d060fd59ed2a1ce2433c4393dcb0502f1f Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Tue, 18 Sep 2018 23:09:54 +0200 Subject: [PATCH] move session based info into session package --- app/dispatcher/default.go | 8 ++-- app/proxyman/inbound/worker.go | 30 ++++++++------ app/proxyman/mux/mux.go | 10 +++-- app/proxyman/outbound/handler.go | 11 ++++- app/router/condition.go | 56 +++++++++++++++++--------- app/router/condition_test.go | 25 +++++++----- app/router/router.go | 12 ++++-- app/router/router_test.go | 5 +-- common/session/context.go | 46 +++++++++++++++++++++ common/session/request.go | 6 --- common/session/session.go | 32 ++++++--------- proxy/context.go | 56 +------------------------- proxy/dokodemo/dokodemo.go | 7 ++-- proxy/freedom/freedom.go | 6 ++- proxy/mtproto/client.go | 7 ++-- proxy/shadowsocks/client.go | 5 ++- proxy/shadowsocks/server.go | 11 +++-- proxy/socks/client.go | 5 ++- proxy/socks/server.go | 26 ++++++------ proxy/vmess/outbound/outbound.go | 6 ++- transport/internet/context.go | 11 ----- transport/internet/dialer.go | 8 +++- transport/internet/dialer_test.go | 2 +- transport/internet/http/dialer.go | 2 +- transport/internet/kcp/dialer.go | 3 +- transport/internet/tcp/dialer.go | 4 +- transport/internet/udp/dialer.go | 3 +- transport/internet/websocket/dialer.go | 3 +- 28 files changed, 212 insertions(+), 194 deletions(-) create mode 100644 common/session/context.go delete mode 100644 common/session/request.go diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index 41bd983ca..e899bd869 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -16,7 +16,6 @@ import ( "v2ray.com/core/common/protocol" "v2ray.com/core/common/session" "v2ray.com/core/common/stats" - "v2ray.com/core/proxy" "v2ray.com/core/transport/pipe" ) @@ -165,7 +164,10 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin if !destination.IsValid() { panic("Dispatcher: Invalid destination.") } - ctx = proxy.ContextWithTarget(ctx, destination) + ob := &session.Outbound{ + Target: destination, + } + ctx = session.ContextWithOutbound(ctx, ob) inbound, outbound := d.getLink(ctx) sniffingConfig := proxyman.SniffingConfigFromContext(ctx) @@ -185,7 +187,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin domain := result.Domain() newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) destination.Address = net.ParseAddress(domain) - ctx = proxy.ContextWithTarget(ctx, destination) + ob.Target = destination } d.routedDispatch(ctx, outbound, destination) }() diff --git a/app/proxyman/inbound/worker.go b/app/proxyman/inbound/worker.go index bf00c3534..31cfe2180 100644 --- a/app/proxyman/inbound/worker.go +++ b/app/proxyman/inbound/worker.go @@ -55,14 +55,16 @@ func (w *tcpWorker) callback(conn internet.Connection) { newError("failed to get original destination").Base(err).WriteToLog(session.ExportIDToError(ctx)) } if dest.IsValid() { - ctx = proxy.ContextWithOriginalTarget(ctx, dest) + ctx = session.ContextWithOutbound(ctx, &session.Outbound{ + Target: dest, + }) } } - if len(w.tag) > 0 { - ctx = proxy.ContextWithInboundTag(ctx, w.tag) - } - ctx = proxy.ContextWithInboundEntryPoint(ctx, net.TCPDestination(w.address, w.port)) - ctx = proxy.ContextWithSource(ctx, net.DestinationFromAddr(conn.RemoteAddr())) + ctx = session.ContextWithInbound(ctx, &session.Inbound{ + Source: net.DestinationFromAddr(conn.RemoteAddr()), + Gateway: net.TCPDestination(w.address, w.port), + Tag: w.tag, + }) if w.sniffingConfig != nil { ctx = proxyman.ContextWithSniffingConfig(ctx, w.sniffingConfig) } @@ -268,15 +270,17 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest ctx = session.ContextWithID(ctx, sid) if originalDest.IsValid() { - ctx = proxy.ContextWithOriginalTarget(ctx, originalDest) + ctx = session.ContextWithOutbound(ctx, &session.Outbound{ + Target: originalDest, + }) } - if len(w.tag) > 0 { - ctx = proxy.ContextWithInboundTag(ctx, w.tag) - } - ctx = proxy.ContextWithSource(ctx, source) - ctx = proxy.ContextWithInboundEntryPoint(ctx, net.UDPDestination(w.address, w.port)) + ctx = session.ContextWithInbound(ctx, &session.Inbound{ + Source: source, + Gateway: net.UDPDestination(w.address, w.port), + Tag: w.tag, + }) if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil { - newError("connection ends").Base(err).WriteToLog() + newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx)) } conn.Close() // nolint: errcheck w.removeConn(id) diff --git a/app/proxyman/mux/mux.go b/app/proxyman/mux/mux.go index 9cb039d35..1f5b4209f 100644 --- a/app/proxyman/mux/mux.go +++ b/app/proxyman/mux/mux.go @@ -88,7 +88,9 @@ var muxCoolPort = net.Port(9527) // NewClient creates a new mux.Client. func NewClient(pctx context.Context, p proxy.Outbound, dialer proxy.Dialer, m *ClientManager) (*Client, error) { - ctx := proxy.ContextWithTarget(context.Background(), net.TCPDestination(muxCoolAddress, muxCoolPort)) + ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{ + Target: net.TCPDestination(muxCoolAddress, muxCoolPort), + }) ctx, cancel := context.WithCancel(ctx) opts := pipe.OptionsFromContext(pctx) @@ -160,7 +162,7 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error { } func fetchInput(ctx context.Context, s *Session, output buf.Writer) { - dest, _ := proxy.TargetFromContext(ctx) + dest := session.OutboundFromContext(ctx).Target transferType := protocol.TransferTypeStream if dest.Network == net.Network_UDP { transferType = protocol.TransferTypePacket @@ -367,8 +369,8 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, Status: log.AccessAccepted, Reason: "", } - if src, f := proxy.SourceFromContext(ctx); f { - msg.From = src + if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() { + msg.From = inbound.Source } log.Record(msg) } diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index 0267f38a6..10300e42f 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -107,7 +107,9 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Conn handler := h.outboundManager.GetHandler(tag) if handler != nil { newError("proxying to ", tag, " for dest ", dest).AtDebug().WriteToLog(session.ExportIDToError(ctx)) - ctx = proxy.ContextWithTarget(ctx, dest) + ctx = session.ContextWithOutbound(ctx, &session.Outbound{ + Target: dest, + }) opts := pipe.OptionsFromContext(ctx) uplinkReader, uplinkWriter := pipe.New(opts...) @@ -121,7 +123,12 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Conn } if h.senderSettings.Via != nil { - ctx = internet.ContextWithDialerSource(ctx, h.senderSettings.Via.AsAddress()) + outbound := session.OutboundFromContext(ctx) + if outbound == nil { + outbound = new(session.Outbound) + ctx = session.ContextWithOutbound(ctx, outbound) + } + outbound.Gateway = h.senderSettings.Via.AsAddress() } ctx = internet.ContextWithStreamSettings(ctx, h.streamSettings) diff --git a/app/router/condition.go b/app/router/condition.go index 9fcbbf33c..307c26120 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -4,6 +4,8 @@ import ( "context" "strings" + "v2ray.com/core/common/session" + "v2ray.com/core/app/dispatcher" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" @@ -110,11 +112,11 @@ func (m *DomainMatcher) ApplyDomain(domain string) bool { } func (m *DomainMatcher) Apply(ctx context.Context) bool { - dest, ok := proxy.TargetFromContext(ctx) - if !ok { + outbound := session.OutboundFromContext(ctx) + if outbound == nil || !outbound.Target.IsValid() { return false } - + dest := outbound.Target if !dest.Address.Family().IsDomain() { return false } @@ -137,6 +139,22 @@ func NewCIDRMatcher(ip []byte, mask uint32, onSource bool) (*CIDRMatcher, error) }, nil } +func sourceFromContext(ctx context.Context) net.Destination { + inbound := session.InboundFromContext(ctx) + if inbound == nil { + return net.Destination{} + } + return inbound.Source +} + +func targetFromContent(ctx context.Context) net.Destination { + outbound := session.OutboundFromContext(ctx) + if outbound == nil { + return net.Destination{} + } + return outbound.Target +} + func (v *CIDRMatcher) Apply(ctx context.Context) bool { ips := make([]net.IP, 0, 4) if resolver, ok := proxy.ResolvedIPsFromContext(ctx); ok { @@ -150,14 +168,13 @@ func (v *CIDRMatcher) Apply(ctx context.Context) bool { } var dest net.Destination - var ok bool if v.onSource { - dest, ok = proxy.SourceFromContext(ctx) + dest = sourceFromContext(ctx) } else { - dest, ok = proxy.TargetFromContext(ctx) + dest = targetFromContent(ctx) } - if ok && dest.Address.Family().IsIPv6() { + if dest.IsValid() && dest.Address.Family().IsIPv6() { ips = append(ips, dest.Address.IP()) } @@ -194,14 +211,13 @@ func (v *IPv4Matcher) Apply(ctx context.Context) bool { } var dest net.Destination - var ok bool if v.onSource { - dest, ok = proxy.SourceFromContext(ctx) + dest = sourceFromContext(ctx) } else { - dest, ok = proxy.TargetFromContext(ctx) + dest = targetFromContent(ctx) } - if ok && dest.Address.Family().IsIPv4() { + if dest.IsValid() && dest.Address.Family().IsIPv4() { ips = append(ips, dest.Address.IP()) } @@ -224,11 +240,11 @@ func NewPortMatcher(portRange net.PortRange) *PortMatcher { } func (v *PortMatcher) Apply(ctx context.Context) bool { - dest, ok := proxy.TargetFromContext(ctx) - if !ok { + outbound := session.OutboundFromContext(ctx) + if outbound == nil || !outbound.Target.IsValid() { return false } - return v.port.Contains(dest.Port) + return v.port.Contains(outbound.Target.Port) } type NetworkMatcher struct { @@ -242,11 +258,11 @@ func NewNetworkMatcher(network *net.NetworkList) *NetworkMatcher { } func (v *NetworkMatcher) Apply(ctx context.Context) bool { - dest, ok := proxy.TargetFromContext(ctx) - if !ok { + outbound := session.OutboundFromContext(ctx) + if outbound == nil || !outbound.Target.IsValid() { return false } - return v.network.HasNetwork(dest.Network) + return v.network.HasNetwork(outbound.Target.Network) } type UserMatcher struct { @@ -295,11 +311,11 @@ func NewInboundTagMatcher(tags []string) *InboundTagMatcher { } func (v *InboundTagMatcher) Apply(ctx context.Context) bool { - tag, ok := proxy.InboundTagFromContext(ctx) - if !ok { + inbound := session.InboundFromContext(ctx) + if inbound == nil || len(inbound.Tag) == 0 { return false } - + tag := inbound.Tag for _, t := range v.tags { if t == tag { return true diff --git a/app/router/condition_test.go b/app/router/condition_test.go index 0ecacce50..1dd8fb416 100644 --- a/app/router/condition_test.go +++ b/app/router/condition_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "v2ray.com/core/common/session" + proto "github.com/golang/protobuf/proto" "v2ray.com/core/app/dispatcher" . "v2ray.com/core/app/router" @@ -17,11 +19,14 @@ import ( "v2ray.com/core/common/platform" "v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol/http" - "v2ray.com/core/proxy" . "v2ray.com/ext/assert" "v2ray.com/ext/sysio" ) +func withOutbound(outbound *session.Outbound) context.Context { + return session.ContextWithOutbound(context.Background(), outbound) +} + func TestRoutingRule(t *testing.T) { assert := With(t) @@ -53,27 +58,27 @@ func TestRoutingRule(t *testing.T) { }, test: []ruleTest{ { - input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("v2ray.com"), 80)), + input: withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}), output: true, }, { - input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("www.v2ray.com.www"), 80)), + input: withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("www.v2ray.com.www"), 80)}), output: true, }, { - input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("v2ray.co"), 80)), + input: withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.co"), 80)}), output: false, }, { - input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("www.google.com"), 80)), + input: withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("www.google.com"), 80)}), output: true, }, { - input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("facebook.com"), 80)), + input: withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("facebook.com"), 80)}), output: true, }, { - input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("www.facebook.com"), 80)), + input: withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("www.facebook.com"), 80)}), output: false, }, { @@ -101,15 +106,15 @@ func TestRoutingRule(t *testing.T) { }, test: []ruleTest{ { - input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.ParseAddress("8.8.8.8"), 80)), + input: withOutbound(&session.Outbound{Target: net.TCPDestination(net.ParseAddress("8.8.8.8"), 80)}), output: true, }, { - input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.ParseAddress("8.8.4.4"), 80)), + input: withOutbound(&session.Outbound{Target: net.TCPDestination(net.ParseAddress("8.8.4.4"), 80)}), output: false, }, { - input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.ParseAddress("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), 80)), + input: withOutbound(&session.Outbound{Target: net.TCPDestination(net.ParseAddress("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), 80)}), output: true, }, { diff --git a/app/router/router.go b/app/router/router.go index fb30b419c..ac1d8db62 100644 --- a/app/router/router.go +++ b/app/router/router.go @@ -5,6 +5,8 @@ package router import ( "context" + "v2ray.com/core/common/session" + "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/net" @@ -75,9 +77,11 @@ func (r *Router) PickRoute(ctx context.Context) (string, error) { resolver := &ipResolver{ dns: r.dns, } + + outbound := session.OutboundFromContext(ctx) if r.domainStrategy == Config_IpOnDemand { - if dest, ok := proxy.TargetFromContext(ctx); ok && dest.Address.Family().IsDomain() { - resolver.domain = dest.Address.Domain() + if outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain() { + resolver.domain = outbound.Target.Address.Domain() ctx = proxy.ContextWithResolveIPs(ctx, resolver) } } @@ -88,11 +92,11 @@ func (r *Router) PickRoute(ctx context.Context) (string, error) { } } - dest, ok := proxy.TargetFromContext(ctx) - if !ok { + if outbound == nil || !outbound.Target.IsValid() { return "", core.ErrNoClue } + dest := outbound.Target if r.domainStrategy == Config_IpIfNonMatch && dest.Address.Family().IsDomain() { resolver.domain = dest.Address.Domain() ips := resolver.Resolve() diff --git a/app/router/router_test.go b/app/router/router_test.go index 856d5adb5..8e69d47d6 100644 --- a/app/router/router_test.go +++ b/app/router/router_test.go @@ -1,7 +1,6 @@ package router_test import ( - "context" "testing" "v2ray.com/core" @@ -12,7 +11,7 @@ import ( "v2ray.com/core/common" "v2ray.com/core/common/net" "v2ray.com/core/common/serial" - "v2ray.com/core/proxy" + "v2ray.com/core/common/session" . "v2ray.com/ext/assert" ) @@ -41,7 +40,7 @@ func TestSimpleRouter(t *testing.T) { r := v.Router() - ctx := proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("v2ray.com"), 80)) + ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) tag, err := r.PickRoute(ctx) assert(err, IsNil) assert(tag, Equals, "test") diff --git a/common/session/context.go b/common/session/context.go new file mode 100644 index 000000000..6574e9abc --- /dev/null +++ b/common/session/context.go @@ -0,0 +1,46 @@ +package session + +import "context" + +type sessionKey int + +const ( + idSessionKey sessionKey = iota + inboundSessionKey + outboundSessionKey +) + +// ContextWithID returns a new context with the given ID. +func ContextWithID(ctx context.Context, id ID) context.Context { + return context.WithValue(ctx, idSessionKey, id) +} + +// IDFromContext returns ID in this context, or 0 if not contained. +func IDFromContext(ctx context.Context) ID { + if id, ok := ctx.Value(idSessionKey).(ID); ok { + return id + } + return 0 +} + +func ContextWithInbound(ctx context.Context, inbound *Inbound) context.Context { + return context.WithValue(ctx, inboundSessionKey, inbound) +} + +func InboundFromContext(ctx context.Context) *Inbound { + if inbound, ok := ctx.Value(inboundSessionKey).(*Inbound); ok { + return inbound + } + return nil +} + +func ContextWithOutbound(ctx context.Context, outbound *Outbound) context.Context { + return context.WithValue(ctx, outboundSessionKey, outbound) +} + +func OutboundFromContext(ctx context.Context) *Outbound { + if outbound, ok := ctx.Value(outboundSessionKey).(*Outbound); ok { + return outbound + } + return nil +} diff --git a/common/session/request.go b/common/session/request.go deleted file mode 100644 index 2b54047aa..000000000 --- a/common/session/request.go +++ /dev/null @@ -1,6 +0,0 @@ -package session - -type Request struct { - //Destination net.Destination - DecodedLayers []interface{} -} diff --git a/common/session/session.go b/common/session/session.go index b7b9042e1..0f750884d 100644 --- a/common/session/session.go +++ b/common/session/session.go @@ -6,6 +6,7 @@ import ( "math/rand" "v2ray.com/core/common/errors" + "v2ray.com/core/common/net" ) // ID of a session. @@ -22,28 +23,21 @@ func NewID() ID { } } -type sessionKey int - -const ( - idSessionKey sessionKey = iota -) - -// ContextWithID returns a new context with the given ID. -func ContextWithID(ctx context.Context, id ID) context.Context { - return context.WithValue(ctx, idSessionKey, id) -} - -// IDFromContext returns ID in this context, or 0 if not contained. -func IDFromContext(ctx context.Context) ID { - if id, ok := ctx.Value(idSessionKey).(ID); ok { - return id - } - return 0 -} - func ExportIDToError(ctx context.Context) errors.ExportOption { id := IDFromContext(ctx) return func(h *errors.ExportOptionHolder) { h.SessionID = uint32(id) } } + +type Inbound struct { + Source net.Destination + Gateway net.Destination + Tag string +} + +type Outbound struct { + Target net.Destination + Gateway net.Address + ResolvedIPs []net.IP +} diff --git a/proxy/context.go b/proxy/context.go index 7bc551255..c15380c01 100644 --- a/proxy/context.go +++ b/proxy/context.go @@ -6,64 +6,12 @@ import ( "v2ray.com/core/common/net" ) -type key int +type key uint32 const ( - sourceKey key = iota - targetKey - originalTargetKey - inboundEntryPointKey - inboundTagKey - resolvedIPsKey + resolvedIPsKey key = iota ) -// ContextWithSource creates a new context with given source. -func ContextWithSource(ctx context.Context, src net.Destination) context.Context { - return context.WithValue(ctx, sourceKey, src) -} - -// SourceFromContext retrieves source from the given context. -func SourceFromContext(ctx context.Context) (net.Destination, bool) { - v, ok := ctx.Value(sourceKey).(net.Destination) - return v, ok -} - -func ContextWithOriginalTarget(ctx context.Context, dest net.Destination) context.Context { - return context.WithValue(ctx, originalTargetKey, dest) -} - -func OriginalTargetFromContext(ctx context.Context) (net.Destination, bool) { - v, ok := ctx.Value(originalTargetKey).(net.Destination) - return v, ok -} - -func ContextWithTarget(ctx context.Context, dest net.Destination) context.Context { - return context.WithValue(ctx, targetKey, dest) -} - -func TargetFromContext(ctx context.Context) (net.Destination, bool) { - v, ok := ctx.Value(targetKey).(net.Destination) - return v, ok -} - -func ContextWithInboundEntryPoint(ctx context.Context, dest net.Destination) context.Context { - return context.WithValue(ctx, inboundEntryPointKey, dest) -} - -func InboundEntryPointFromContext(ctx context.Context) (net.Destination, bool) { - v, ok := ctx.Value(inboundEntryPointKey).(net.Destination) - return v, ok -} - -func ContextWithInboundTag(ctx context.Context, tag string) context.Context { - return context.WithValue(ctx, inboundTagKey, tag) -} - -func InboundTagFromContext(ctx context.Context) (string, bool) { - v, ok := ctx.Value(inboundTagKey).(string) - return v, ok -} - type IPResolver interface { Resolve() []net.Address } diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index cc612519d..f347b05dd 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -13,7 +13,6 @@ import ( "v2ray.com/core/common/session" "v2ray.com/core/common/signal" "v2ray.com/core/common/task" - "v2ray.com/core/proxy" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/pipe" ) @@ -65,8 +64,8 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in Port: d.port, } if d.config.FollowRedirect { - if origDest, ok := proxy.OriginalTargetFromContext(ctx); ok { - dest = origDest + if outbound := session.OutboundFromContext(ctx); outbound != nil && outbound.Target.IsValid() { + dest = outbound.Target } else if handshake, ok := conn.(hasHandshakeAddress); ok { addr := handshake.HandshakeAddress() if addr != nil { @@ -117,7 +116,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in Tproxy: internet.SocketConfig_TProxy, }, }) - tConn, err := internet.DialSystem(tCtx, nil, net.DestinationFromAddr(conn.RemoteAddr())) + tConn, err := internet.DialSystem(tCtx, net.DestinationFromAddr(conn.RemoteAddr())) if err != nil { return err } diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index 7f1044e93..41cfecab3 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -76,7 +76,11 @@ func isValidAddress(addr *net.IPOrDomain) bool { // Process implements proxy.Outbound. func (h *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dialer) error { - destination, _ := proxy.TargetFromContext(ctx) + outbound := session.OutboundFromContext(ctx) + if outbound == nil || !outbound.Target.IsValid() { + return newError("target not specified.") + } + destination := outbound.Target if h.config.DestinationOverride != nil { server := h.config.DestinationOverride.Server if isValidAddress(server.Address) { diff --git a/proxy/mtproto/client.go b/proxy/mtproto/client.go index db6d35ad6..058e19e53 100644 --- a/proxy/mtproto/client.go +++ b/proxy/mtproto/client.go @@ -8,6 +8,7 @@ import ( "v2ray.com/core/common/buf" "v2ray.com/core/common/crypto" "v2ray.com/core/common/net" + "v2ray.com/core/common/session" "v2ray.com/core/common/task" "v2ray.com/core/proxy" ) @@ -20,11 +21,11 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { } func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dialer) error { - dest, ok := proxy.TargetFromContext(ctx) - if !ok { + outbound := session.OutboundFromContext(ctx) + if outbound == nil || !outbound.Target.IsValid() { return newError("unknown destination.") } - + dest := outbound.Target if dest.Network != net.Network_TCP { return newError("not TCP traffic", dest) } diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index 362bb7963..9ca16df1a 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -45,10 +45,11 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { // Process implements OutboundHandler.Process(). func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dialer) error { - destination, ok := proxy.TargetFromContext(ctx) - if !ok { + outbound := session.OutboundFromContext(ctx) + if outbound == nil || !outbound.Target.IsValid() { return newError("target not specified") } + destination := outbound.Target network := destination.Network var server *protocol.ServerSpec diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 0ed848292..f46013532 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -13,7 +13,6 @@ import ( "v2ray.com/core/common/session" "v2ray.com/core/common/signal" "v2ray.com/core/common/task" - "v2ray.com/core/proxy" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet/udp" "v2ray.com/core/transport/pipe" @@ -99,10 +98,10 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection for _, payload := range mpayload { request, data, err := DecodeUDPPacket(s.user, payload) if err != nil { - if source, ok := proxy.SourceFromContext(ctx); ok { - newError("dropping invalid UDP packet from: ", source).Base(err).WriteToLog(session.ExportIDToError(ctx)) + if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() { + newError("dropping invalid UDP packet from: ", inbound.Source).Base(err).WriteToLog(session.ExportIDToError(ctx)) log.Record(&log.AccessMessage{ - From: source, + From: inbound.Source, To: "", Status: log.AccessRejected, Reason: err, @@ -125,9 +124,9 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection } dest := request.Destination() - if source, ok := proxy.SourceFromContext(ctx); ok { + if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() { log.Record(&log.AccessMessage{ - From: source, + From: inbound.Source, To: dest, Status: log.AccessAccepted, Reason: "", diff --git a/proxy/socks/client.go b/proxy/socks/client.go index bd67ba39f..f6f789ddc 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -47,10 +47,11 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { // Process implements proxy.Outbound.Process. func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dialer) error { - destination, ok := proxy.TargetFromContext(ctx) - if !ok { + outbound := session.OutboundFromContext(ctx) + if outbound == nil || !outbound.Target.IsValid() { return newError("target not specified.") } + destination := outbound.Target var server *protocol.ServerSpec var conn internet.Connection diff --git a/proxy/socks/server.go b/proxy/socks/server.go index bb5d1f661..5821b3ab6 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -14,7 +14,6 @@ import ( "v2ray.com/core/common/session" "v2ray.com/core/common/signal" "v2ray.com/core/common/task" - "v2ray.com/core/proxy" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet/udp" "v2ray.com/core/transport/pipe" @@ -73,21 +72,22 @@ func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispa newError("failed to set deadline").Base(err).WriteToLog(session.ExportIDToError(ctx)) } - inboundDest, ok := proxy.InboundEntryPointFromContext(ctx) - if !ok { - return newError("inbound entry point not specified") + inbound := session.InboundFromContext(ctx) + if inbound == nil || !inbound.Gateway.IsValid() { + return newError("inbound gateway not specified") } + svrSession := &ServerSession{ config: s.config, - port: inboundDest.Port, + port: inbound.Gateway.Port, } reader := &buf.BufferedReader{Reader: buf.NewReader(conn)} request, err := svrSession.Handshake(reader, conn) if err != nil { - if source, ok := proxy.SourceFromContext(ctx); ok { + if inbound != nil && inbound.Source.IsValid() { log.Record(&log.AccessMessage{ - From: source, + From: inbound.Source, To: "", Status: log.AccessRejected, Reason: err, @@ -103,9 +103,9 @@ func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispa if request.Command == protocol.RequestCommandTCP { dest := request.Destination() newError("TCP Connect request to ", dest).WriteToLog(session.ExportIDToError(ctx)) - if source, ok := proxy.SourceFromContext(ctx); ok { + if inbound != nil && inbound.Source.IsValid() { log.Record(&log.AccessMessage{ - From: source, + From: inbound.Source, To: dest, Status: log.AccessAccepted, Reason: "", @@ -188,8 +188,8 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, conn.Write(udpMessage.Bytes()) // nolint: errcheck }) - if source, ok := proxy.SourceFromContext(ctx); ok { - newError("client UDP connection from ", source).WriteToLog(session.ExportIDToError(ctx)) + if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() { + newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx)) } reader := buf.NewReader(conn) @@ -214,9 +214,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, } newError("send packet to ", request.Destination(), " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx)) - if source, ok := proxy.SourceFromContext(ctx); ok { + if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() { log.Record(&log.AccessMessage{ - From: source, + From: inbound.Source, To: request.Destination(), Status: log.AccessAccepted, Reason: "", diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index bfce2c287..adad38aea 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -68,10 +68,12 @@ func (v *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia } defer conn.Close() //nolint: errcheck - target, ok := proxy.TargetFromContext(ctx) - if !ok { + outbound := session.OutboundFromContext(ctx) + if outbound == nil || !outbound.Target.IsValid() { return newError("target not specified").AtError() } + + target := outbound.Target newError("tunneling request to ", target, " via ", rec.Destination()).WriteToLog(session.ExportIDToError(ctx)) command := protocol.RequestCommandTCP diff --git a/transport/internet/context.go b/transport/internet/context.go index 6b059a6c4..da9168f6f 100644 --- a/transport/internet/context.go +++ b/transport/internet/context.go @@ -26,17 +26,6 @@ func StreamSettingsFromContext(ctx context.Context) *MemoryStreamConfig { return ss.(*MemoryStreamConfig) } -func ContextWithDialerSource(ctx context.Context, addr net.Address) context.Context { - return context.WithValue(ctx, dialerSrcKey, addr) -} - -func DialerSourceFromContext(ctx context.Context) net.Address { - if addr, ok := ctx.Value(dialerSrcKey).(net.Address); ok { - return addr - } - return net.AnyIP -} - func ContextWithBindAddress(ctx context.Context, dest net.Destination) context.Context { return context.WithValue(ctx, bindAddrKey, dest) } diff --git a/transport/internet/dialer.go b/transport/internet/dialer.go index 23633150f..e5bd13c0d 100644 --- a/transport/internet/dialer.go +++ b/transport/internet/dialer.go @@ -3,6 +3,8 @@ package internet import ( "context" + "v2ray.com/core/common/session" + "v2ray.com/core/common/net" ) @@ -53,6 +55,10 @@ func Dial(ctx context.Context, dest net.Destination) (Connection, error) { } // DialSystem calls system dialer to create a network connection. -func DialSystem(ctx context.Context, src net.Address, dest net.Destination) (net.Conn, error) { +func DialSystem(ctx context.Context, dest net.Destination) (net.Conn, error) { + var src net.Address + if outbound := session.OutboundFromContext(ctx); outbound != nil { + src = outbound.Gateway + } return effectiveSystemDialer.Dial(ctx, src, dest) } diff --git a/transport/internet/dialer_test.go b/transport/internet/dialer_test.go index dbc9e587c..42a02c9bc 100644 --- a/transport/internet/dialer_test.go +++ b/transport/internet/dialer_test.go @@ -18,7 +18,7 @@ func TestDialWithLocalAddr(t *testing.T) { assert(err, IsNil) defer server.Close() - conn, err := DialSystem(context.Background(), net.LocalHostIP, net.TCPDestination(net.LocalHostIP, dest.Port)) + conn, err := DialSystem(context.Background(), net.TCPDestination(net.LocalHostIP, dest.Port)) assert(err, IsNil) assert(conn.RemoteAddr().String(), Equals, "127.0.0.1:"+dest.Port.String()) conn.Close() diff --git a/transport/internet/http/dialer.go b/transport/internet/http/dialer.go index 99474e3fb..7cdd5bf14 100644 --- a/transport/internet/http/dialer.go +++ b/transport/internet/http/dialer.go @@ -53,7 +53,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination) (*http.Client, err } address := net.ParseAddress(rawHost) - pconn, err := internet.DialSystem(context.Background(), nil, net.TCPDestination(address, port)) + pconn, err := internet.DialSystem(context.Background(), net.TCPDestination(address, port)) if err != nil { return nil, err } diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index f4fd18d74..ae8c36fa5 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -49,8 +49,7 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er dest.Network = net.Network_UDP newError("dialing mKCP to ", dest).WriteToLog() - src := internet.DialerSourceFromContext(ctx) - rawConn, err := internet.DialSystem(ctx, src, dest) + rawConn, err := internet.DialSystem(ctx, dest) if err != nil { return nil, newError("failed to dial to dest: ", err).AtWarning().Base(err) } diff --git a/transport/internet/tcp/dialer.go b/transport/internet/tcp/dialer.go index ccec5812c..c445dbf5c 100644 --- a/transport/internet/tcp/dialer.go +++ b/transport/internet/tcp/dialer.go @@ -21,9 +21,7 @@ func getTCPSettingsFromContext(ctx context.Context) *Config { // Dial dials a new TCP connection to the given destination. func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) { newError("dialing TCP to ", dest).WriteToLog(session.ExportIDToError(ctx)) - src := internet.DialerSourceFromContext(ctx) - - conn, err := internet.DialSystem(ctx, src, dest) + conn, err := internet.DialSystem(ctx, dest) if err != nil { return nil, err } diff --git a/transport/internet/udp/dialer.go b/transport/internet/udp/dialer.go index 0e6d7e539..0fdb3bdc7 100644 --- a/transport/internet/udp/dialer.go +++ b/transport/internet/udp/dialer.go @@ -11,8 +11,7 @@ import ( func init() { common.Must(internet.RegisterTransportDialer(protocolName, func(ctx context.Context, dest net.Destination) (internet.Connection, error) { - src := internet.DialerSourceFromContext(ctx) - conn, err := internet.DialSystem(ctx, src, dest) + conn, err := internet.DialSystem(ctx, dest) if err != nil { return nil, err } diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index 17556b9c1..0cb2d0690 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -29,12 +29,11 @@ func init() { } func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error) { - src := internet.DialerSourceFromContext(ctx) wsSettings := internet.StreamSettingsFromContext(ctx).ProtocolSettings.(*Config) dialer := &websocket.Dialer{ NetDial: func(network, addr string) (net.Conn, error) { - return internet.DialSystem(ctx, src, dest) + return internet.DialSystem(ctx, dest) }, ReadBufferSize: 4 * 1024, WriteBufferSize: 4 * 1024,