diff --git a/common/registry/registry.go b/common/registry/registry.go index 26280b6c6..8b510aea8 100644 --- a/common/registry/registry.go +++ b/common/registry/registry.go @@ -6,8 +6,10 @@ import ( "github.com/golang/protobuf/proto" "github.com/v2fly/v2ray-core/v4/common/protoext" "github.com/v2fly/v2ray-core/v4/common/serial" - "google.golang.org/protobuf/reflect/protoreflect" + protov2 "google.golang.org/protobuf/proto" + "reflect" "strings" + "sync" ) type implementationRegistry struct { @@ -69,10 +71,33 @@ func newImplementationRegistry() *implementationRegistry { var globalImplementationRegistry = newImplementationRegistry() +var initialized = &sync.Once{} + +type registerRequest struct { + proto interface{} + loader CustomLoader +} + +var registerRequests []registerRequest + // RegisterImplementation register an implementation of a type of interface // loader(CustomLoader) is a private API, its interface is subject to breaking changes -func RegisterImplementation(proto protoreflect.MessageDescriptor, loader CustomLoader) error { - msgDesc := proto +func RegisterImplementation(proto interface{}, loader CustomLoader) error { + registerRequests = append(registerRequests, registerRequest{ + proto: proto, + loader: loader, + }) + return nil +} + +func registerImplementation(proto interface{}, loader CustomLoader) error { + protoReflect := reflect.New(reflect.TypeOf(proto).Elem()) + var proto2 protov2.Message + assignMessage := func(msg protov2.Message) { + proto2 = msg + } + reflect.ValueOf(assignMessage).Call([]reflect.Value{protoReflect}) + msgDesc := proto2.ProtoReflect().Descriptor() fullName := string(msgDesc.FullName()) msgOpts, err := protoext.GetMessageOptions(msgDesc) if err != nil { @@ -87,5 +112,10 @@ type LoadByAlias interface { } func LoadImplementationByAlias(interfaceType, alias string, data []byte) (proto.Message, error) { + initialized.Do(func() { + for _, v := range registerRequests { + registerImplementation(v.proto, v.loader) + } + }) return globalImplementationRegistry.LoadImplementationByAlias(interfaceType, alias, data) }