From 867bbb429e0565a1f36e74b109525ac12c7a5e9e Mon Sep 17 00:00:00 2001 From: Shelikhoo Date: Sat, 6 Mar 2021 14:33:20 +0000 Subject: [PATCH] create session content in the context if do not exist yet --- app/dispatcher/default.go | 2 +- common/session/context.go | 16 +++++++++++++--- transport/internet/tagged/taggedimpl/impl.go | 2 +- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index 40beed96d..45f9f5cb0 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -295,7 +295,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport. var handler outbound.Handler if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" { - session.SetForcedOutboundTagToContext(ctx, "") + ctx = session.SetForcedOutboundTagToContext(ctx, "") if h := d.ohm.GetHandler(forcedOutboundTag); h != nil { newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) handler = h diff --git a/common/session/context.go b/common/session/context.go index 8fdb01fa6..4300b1f6c 100644 --- a/common/session/context.go +++ b/common/session/context.go @@ -1,6 +1,8 @@ package session -import "context" +import ( + "context" +) type sessionKey int @@ -92,8 +94,12 @@ func GetTransportLayerProxyTagFromContext(ctx context.Context) string { return ContentFromContext(ctx).Attribute("transportLayerOutgoingTag") } -func SetTransportLayerProxyTagToContext(ctx context.Context, tag string) { +func SetTransportLayerProxyTagToContext(ctx context.Context, tag string) context.Context { + if contentFromContext := ContentFromContext(ctx); contentFromContext == nil { + ctx = ContextWithContent(ctx, &Content{}) + } ContentFromContext(ctx).SetAttribute("transportLayerOutgoingTag", tag) + return ctx } func GetForcedOutboundTagFromContext(ctx context.Context) string { @@ -103,6 +109,10 @@ func GetForcedOutboundTagFromContext(ctx context.Context) string { return ContentFromContext(ctx).Attribute("forcedOutboundTag") } -func SetForcedOutboundTagToContext(ctx context.Context, tag string) { +func SetForcedOutboundTagToContext(ctx context.Context, tag string) context.Context { + if contentFromContext := ContentFromContext(ctx); contentFromContext == nil { + ctx = ContextWithContent(ctx, &Content{}) + } ContentFromContext(ctx).SetAttribute("forcedOutboundTag", tag) + return ctx } diff --git a/transport/internet/tagged/taggedimpl/impl.go b/transport/internet/tagged/taggedimpl/impl.go index bf5fde74a..6d32be081 100644 --- a/transport/internet/tagged/taggedimpl/impl.go +++ b/transport/internet/tagged/taggedimpl/impl.go @@ -26,7 +26,7 @@ func DialTaggedOutbound(ctx context.Context, dest net.Destination, tag string) ( content.SkipDNSResolve = true ctx = session.ContextWithContent(ctx, content) - session.SetForcedOutboundTagToContext(ctx, tag) + ctx = session.SetForcedOutboundTagToContext(ctx, tag) r, err := dispatcher.Dispatch(ctx, dest) if err != nil {