diff --git a/common/registry/registry.go b/common/registry/registry.go index 8b510aea8..34c73bce1 100644 --- a/common/registry/registry.go +++ b/common/registry/registry.go @@ -2,9 +2,11 @@ package registry import ( "bytes" + "context" "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" "github.com/v2fly/v2ray-core/v4/common/protoext" + "github.com/v2fly/v2ray-core/v4/common/protofilter" "github.com/v2fly/v2ray-core/v4/common/serial" protov2 "google.golang.org/protobuf/proto" "reflect" @@ -34,7 +36,7 @@ func (i *implementationRegistry) findImplementationByAlias(interfaceType, alias return implSet.findImplementationByAlias(alias) } -func (i *implementationRegistry) LoadImplementationByAlias(interfaceType, alias string, data []byte) (proto.Message, error) { +func (i *implementationRegistry) LoadImplementationByAlias(ctx context.Context, interfaceType, alias string, data []byte) (proto.Message, error) { var implementationFullName string if strings.HasPrefix(alias, "#") { @@ -61,6 +63,11 @@ func (i *implementationRegistry) LoadImplementationByAlias(interfaceType, alias return nil, newError("unable to parse json content").Base(err) } + implementationConfigInstancev2 := proto.MessageV2(implementationConfigInstance) + if err := protofilter.FilterProtoConfig(ctx, implementationConfigInstancev2); err != nil { + return nil, err + } + return implementationConfigInstance.(proto.Message), nil } @@ -108,14 +115,14 @@ func registerImplementation(proto interface{}, loader CustomLoader) error { } type LoadByAlias interface { - LoadImplementationByAlias(interfaceType, alias string, data []byte) (proto.Message, error) + LoadImplementationByAlias(ctx context.Context, interfaceType, alias string, data []byte) (proto.Message, error) } -func LoadImplementationByAlias(interfaceType, alias string, data []byte) (proto.Message, error) { +func LoadImplementationByAlias(ctx context.Context, interfaceType, alias string, data []byte) (proto.Message, error) { initialized.Do(func() { for _, v := range registerRequests { registerImplementation(v.proto, v.loader) } }) - return globalImplementationRegistry.LoadImplementationByAlias(interfaceType, alias, data) + return globalImplementationRegistry.LoadImplementationByAlias(ctx, interfaceType, alias, data) } diff --git a/common/session/context.go b/common/session/context.go index 2e43b9205..bc9ef3a70 100644 --- a/common/session/context.go +++ b/common/session/context.go @@ -15,6 +15,7 @@ const ( sockoptSessionKey trackedConnectionErrorKey handlerSessionKey + environmentKey ) // ContextWithID returns a new context with the given ID. @@ -133,3 +134,14 @@ func SubmitOutboundErrorToOriginator(ctx context.Context, err error) { func TrackedConnectionError(ctx context.Context, tracker TrackedRequestErrorFeedback) context.Context { return context.WithValue(ctx, trackedConnectionErrorKey, tracker) } + +func ContextWithEnvironment(ctx context.Context, environment interface{}) context.Context { + return context.WithValue(ctx, environmentKey, environment) +} + +func EnvironmentFromContext(ctx context.Context) interface{} { + if environment, ok := ctx.Value(environmentKey).(interface{}); ok { + return environment + } + return nil +}