From aa40b8b8352615e636fcfcdd5c707952152f4fad Mon Sep 17 00:00:00 2001 From: rurirei <72071920+rurirei@users.noreply.github.com> Date: Sun, 4 Apr 2021 19:28:00 +0800 Subject: [PATCH] Feat: core.ToContext(ctx, v) for ctx initialization (#852) --- context.go | 16 ++++++++++++++++ context_test.go | 13 ++++++++++++- functions.go | 10 +++------- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/context.go b/context.go index 84cd2396b..3146689ac 100644 --- a/context.go +++ b/context.go @@ -27,3 +27,19 @@ func MustFromContext(ctx context.Context) *Instance { } return v } + +// ToContext returns ctx from the given context, or creates an Instance if the context doesn't find that. +func ToContext(ctx context.Context, v *Instance) context.Context { + if FromContext(ctx) != v { + ctx = context.WithValue(ctx, v2rayKey, v) + } + return ctx +} + +// MustToContext returns ctx from the given context, or panics if not found that. +func MustToContext(ctx context.Context, v *Instance) context.Context { + if c := ToContext(ctx, v); c != ctx { + panic("V is not in context.") + } + return ctx +} diff --git a/context_test.go b/context_test.go index f52a48de5..cb069fd8c 100644 --- a/context_test.go +++ b/context_test.go @@ -7,7 +7,7 @@ import ( . "github.com/v2fly/v2ray-core/v4" ) -func TestContextPanic(t *testing.T) { +func TestFromContextPanic(t *testing.T) { defer func() { r := recover() if r == nil { @@ -17,3 +17,14 @@ func TestContextPanic(t *testing.T) { MustFromContext(context.Background()) } + +func TestToContextPanic(t *testing.T) { + defer func() { + r := recover() + if r == nil { + t.Error("expect panic, but nil") + } + }() + + MustToContext(context.Background(), &Instance{}) +} diff --git a/functions.go b/functions.go index 836c1bb0e..115df1b08 100644 --- a/functions.go +++ b/functions.go @@ -16,7 +16,7 @@ import ( func CreateObject(v *Instance, config interface{}) (interface{}, error) { var ctx context.Context if v != nil { - ctx = context.WithValue(v.ctx, v2rayKey, v) + ctx = ToContext(v.ctx, v) } return common.CreateObject(ctx, config) } @@ -47,9 +47,7 @@ func StartInstance(configFormat string, configBytes []byte) (*Instance, error) { // // v2ray:api:stable func Dial(ctx context.Context, v *Instance, dest net.Destination) (net.Conn, error) { - if FromContext(ctx) == nil { - ctx = context.WithValue(ctx, v2rayKey, v) - } + ctx = ToContext(ctx, v) dispatcher := v.GetFeature(routing.DispatcherType()) if dispatcher == nil { @@ -76,9 +74,7 @@ func Dial(ctx context.Context, v *Instance, dest net.Destination) (net.Conn, err // // v2ray:api:beta func DialUDP(ctx context.Context, v *Instance) (net.PacketConn, error) { - if FromContext(ctx) == nil { - ctx = context.WithValue(ctx, v2rayKey, v) - } + ctx = ToContext(ctx, v) dispatcher := v.GetFeature(routing.DispatcherType()) if dispatcher == nil {