From 595f3d685ed9c21a8a95b1055d287e0dd44e1132 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Mon, 15 Oct 2018 08:36:50 +0200 Subject: [PATCH] merge user info inbound metadata --- app/dispatcher/default.go | 7 ++++++- app/router/condition.go | 8 ++++++-- app/router/condition_test.go | 8 ++++++-- common/protocol/context.go | 17 +---------------- common/protocol/user.go | 2 ++ common/session/session.go | 3 +++ proxy/shadowsocks/server.go | 16 ++++++++++++---- proxy/vmess/inbound/inbound.go | 7 ++++++- 8 files changed, 42 insertions(+), 26 deletions(-) diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index 8517d82eb..1113f3fc6 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -133,7 +133,12 @@ func (d *DefaultDispatcher) getLink(ctx context.Context) (*vio.Link, *vio.Link) Writer: downlinkWriter, } - user := protocol.UserFromContext(ctx) + sessionInbound := session.InboundFromContext(ctx) + var user *protocol.MemoryUser + if sessionInbound != nil { + user = sessionInbound.User + } + if user != nil && len(user.Email) > 0 { p := d.policy.ForLevel(user.Level) if p.Stats.UserUplink { diff --git a/app/router/condition.go b/app/router/condition.go index 307c26120..120dd39b5 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -8,7 +8,6 @@ import ( "v2ray.com/core/app/dispatcher" "v2ray.com/core/common/net" - "v2ray.com/core/common/protocol" "v2ray.com/core/common/strmatcher" "v2ray.com/core/proxy" ) @@ -282,7 +281,12 @@ func NewUserMatcher(users []string) *UserMatcher { } func (v *UserMatcher) Apply(ctx context.Context) bool { - user := protocol.UserFromContext(ctx) + inbound := session.InboundFromContext(ctx) + if inbound == nil { + return false + } + + user := inbound.User if user == nil { return false } diff --git a/app/router/condition_test.go b/app/router/condition_test.go index 1dd8fb416..96fe096c7 100644 --- a/app/router/condition_test.go +++ b/app/router/condition_test.go @@ -27,6 +27,10 @@ func withOutbound(outbound *session.Outbound) context.Context { return session.ContextWithOutbound(context.Background(), outbound) } +func withInbound(inbound *session.Inbound) context.Context { + return session.ContextWithInbound(context.Background(), inbound) +} + func TestRoutingRule(t *testing.T) { assert := With(t) @@ -131,11 +135,11 @@ func TestRoutingRule(t *testing.T) { }, test: []ruleTest{ { - input: protocol.ContextWithUser(context.Background(), &protocol.MemoryUser{Email: "admin@v2ray.com"}), + input: withInbound(&session.Inbound{User: &protocol.MemoryUser{Email: "admin@v2ray.com"}}), output: true, }, { - input: protocol.ContextWithUser(context.Background(), &protocol.MemoryUser{Email: "love@v2ray.com"}), + input: withInbound(&session.Inbound{User: &protocol.MemoryUser{Email: "love@v2ray.com"}}), output: false, }, { diff --git a/common/protocol/context.go b/common/protocol/context.go index 42c980790..6bb510424 100755 --- a/common/protocol/context.go +++ b/common/protocol/context.go @@ -7,24 +7,9 @@ import ( type key int const ( - userKey key = iota - requestKey + requestKey key = iota ) -// ContextWithUser returns a context combined with a User. -func ContextWithUser(ctx context.Context, user *MemoryUser) context.Context { - return context.WithValue(ctx, userKey, user) -} - -// UserFromContext extracts a User from the given context, if any. -func UserFromContext(ctx context.Context) *MemoryUser { - v := ctx.Value(userKey) - if v == nil { - return nil - } - return v.(*MemoryUser) -} - func ContextWithRequestHeader(ctx context.Context, request *RequestHeader) context.Context { return context.WithValue(ctx, requestKey, request) } diff --git a/common/protocol/user.go b/common/protocol/user.go index 30b78c80b..8325f5551 100644 --- a/common/protocol/user.go +++ b/common/protocol/user.go @@ -30,7 +30,9 @@ func (u *User) ToMemoryUser() (*MemoryUser, error) { }, nil } +// MemoryUser is a parsed form of User, to reduce number of parsing of Account proto. type MemoryUser struct { + // Account is the parsed account of the protocol. Account Account Email string Level uint32 diff --git a/common/session/session.go b/common/session/session.go index 0f750884d..7c018f0b0 100644 --- a/common/session/session.go +++ b/common/session/session.go @@ -7,6 +7,7 @@ import ( "v2ray.com/core/common/errors" "v2ray.com/core/common/net" + "v2ray.com/core/common/protocol" ) // ID of a session. @@ -34,6 +35,8 @@ type Inbound struct { Source net.Destination Gateway net.Destination Tag string + // User is the user that authencates for the inbound. May be nil if the protocol allows anounymous traffic. + User *protocol.MemoryUser } type Outbound struct { diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index fefd4c814..40776035e 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -89,6 +89,11 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection }) account := s.user.Account.(*MemoryAccount) + inbound := session.InboundFromContext(ctx) + if inbound == nil { + panic("no inbound metadata") + } + inbound.User = s.user reader := buf.NewReader(conn) for { @@ -126,7 +131,7 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection } dest := request.Destination() - if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() { + if inbound.Source.IsValid() { log.Record(&log.AccessMessage{ From: inbound.Source, To: dest, @@ -136,7 +141,6 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection } newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(ctx)) - ctx = protocol.ContextWithUser(ctx, request.User) ctx = protocol.ContextWithRequestHeader(ctx, request) udpServer.Dispatch(ctx, dest, data) } @@ -162,6 +166,12 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, } conn.SetReadDeadline(time.Time{}) + inbound := session.InboundFromContext(ctx) + if inbound == nil { + panic("no inbound metadata") + } + inbound.User = s.user + dest := request.Destination() log.Record(&log.AccessMessage{ From: conn.RemoteAddr(), @@ -171,8 +181,6 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, }) newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(ctx)) - ctx = protocol.ContextWithUser(ctx, request.User) - ctx, cancel := context.WithCancel(ctx) timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 8f6e9d815..f7aa925a4 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -264,8 +264,13 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i newError("unable to set back read deadline").Base(err).WriteToLog(session.ExportIDToError(ctx)) } + inbound := session.InboundFromContext(ctx) + if inbound == nil { + panic("no inbound metadata") + } + inbound.User = request.User + sessionPolicy = h.policyManager.ForLevel(request.User.Level) - ctx = protocol.ContextWithUser(ctx, request.User) ctx, cancel := context.WithCancel(ctx) timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)