1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-19 15:57:04 -05:00

Require Type() for Feature

This commit is contained in:
Darien Raymond 2018-10-12 23:57:56 +02:00
parent dcd26ec61f
commit d730637239
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
30 changed files with 162 additions and 49 deletions

View File

@ -21,7 +21,7 @@ type Commander struct {
server *grpc.Server server *grpc.Server
config Config config Config
v *core.Instance v *core.Instance
ohm outbound.HandlerManager ohm outbound.Manager
} }
// NewCommander creates a new Commander based on the given config. // NewCommander creates a new Commander based on the given config.
@ -32,7 +32,7 @@ func NewCommander(ctx context.Context, config *Config) (*Commander, error) {
ohm: v.OutboundHandlerManager(), ohm: v.OutboundHandlerManager(),
v: v, v: v,
} }
if err := v.RegisterFeature((*Commander)(nil), c); err != nil { if err := v.RegisterFeature(c); err != nil {
return nil, err return nil, err
} }
return c, nil return c, nil

View File

@ -84,7 +84,7 @@ func (r *cachedReader) CloseError() {
// DefaultDispatcher is a default implementation of Dispatcher. // DefaultDispatcher is a default implementation of Dispatcher.
type DefaultDispatcher struct { type DefaultDispatcher struct {
ohm outbound.HandlerManager ohm outbound.Manager
router routing.Router router routing.Router
policy policy.Manager policy policy.Manager
stats feature_stats.Manager stats feature_stats.Manager
@ -100,12 +100,16 @@ func NewDefaultDispatcher(ctx context.Context, config *Config) (*DefaultDispatch
stats: v.Stats(), stats: v.Stats(),
} }
if err := v.RegisterFeature((*routing.Dispatcher)(nil), d); err != nil { if err := v.RegisterFeature(d); err != nil {
return nil, newError("unable to register Dispatcher").Base(err) return nil, newError("unable to register Dispatcher").Base(err)
} }
return d, nil return d, nil
} }
func (*DefaultDispatcher) Type() interface{} {
return routing.DispatcherType()
}
// Start implements common.Runnable. // Start implements common.Runnable.
func (*DefaultDispatcher) Start() error { func (*DefaultDispatcher) Start() error {
return nil return nil

View File

@ -41,7 +41,7 @@ func New(ctx context.Context, config *Config) (*Server, error) {
server.hosts = hosts server.hosts = hosts
v := core.MustFromContext(ctx) v := core.MustFromContext(ctx)
if err := v.RegisterFeature((*dns.Client)(nil), server); err != nil { if err := v.RegisterFeature(server); err != nil {
return nil, newError("unable to register DNSClient.").Base(err) return nil, newError("unable to register DNSClient.").Base(err)
} }
@ -97,6 +97,10 @@ func New(ctx context.Context, config *Config) (*Server, error) {
return server, nil return server, nil
} }
func (*Server) Type() interface{} {
return dns.ClientType()
}
// Start implements common.Runnable. // Start implements common.Runnable.
func (s *Server) Start() error { func (s *Server) Start() error {
return nil return nil

View File

@ -30,7 +30,7 @@ func New(ctx context.Context, config *Config) (*Instance, error) {
v := core.FromContext(ctx) v := core.FromContext(ctx)
if v != nil { if v != nil {
common.Must(v.RegisterFeature((*log.Handler)(nil), g)) common.Must(v.RegisterFeature(g))
} }
return g, nil return g, nil

View File

@ -30,7 +30,7 @@ func New(ctx context.Context, config *Config) (*Instance, error) {
v := core.FromContext(ctx) v := core.FromContext(ctx)
if v != nil { if v != nil {
if err := v.RegisterFeature((*policy.Manager)(nil), m); err != nil { if err := v.RegisterFeature(m); err != nil {
return nil, newError("unable to register PolicyManager in core").Base(err).AtError() return nil, newError("unable to register PolicyManager in core").Base(err).AtError()
} }
} }
@ -38,6 +38,11 @@ func New(ctx context.Context, config *Config) (*Instance, error) {
return m, nil return m, nil
} }
// Type implements common.HasType.
func (*Instance) Type() interface{} {
return policy.ManagerType()
}
// ForLevel implements policy.Manager. // ForLevel implements policy.Manager.
func (m *Instance) ForLevel(level uint32) policy.Session { func (m *Instance) ForLevel(level uint32) policy.Session {
if p, ok := m.levels[level]; ok { if p, ok := m.levels[level]; ok {

View File

@ -64,7 +64,7 @@ func (op *RemoveUserOperation) ApplyInbound(ctx context.Context, handler inbound
type handlerServer struct { type handlerServer struct {
s *core.Instance s *core.Instance
ihm inbound.Manager ihm inbound.Manager
ohm outbound.HandlerManager ohm outbound.Manager
} }
func (s *handlerServer) AddInbound(ctx context.Context, request *AddInboundRequest) (*AddInboundResponse, error) { func (s *handlerServer) AddInbound(ctx context.Context, request *AddInboundRequest) (*AddInboundResponse, error) {

View File

@ -28,12 +28,17 @@ func New(ctx context.Context, config *proxyman.InboundConfig) (*Manager, error)
taggedHandlers: make(map[string]inbound.Handler), taggedHandlers: make(map[string]inbound.Handler),
} }
v := core.MustFromContext(ctx) v := core.MustFromContext(ctx)
if err := v.RegisterFeature((*inbound.Manager)(nil), m); err != nil { if err := v.RegisterFeature(m); err != nil {
return nil, newError("unable to register InboundHandlerManager").Base(err) return nil, newError("unable to register InboundHandlerManager").Base(err)
} }
return m, nil return m, nil
} }
// Type implements common.HasType.
func (*Manager) Type() interface{} {
return inbound.ManagerType()
}
// AddHandler implements inbound.Manager. // AddHandler implements inbound.Manager.
func (m *Manager) AddHandler(ctx context.Context, handler inbound.Handler) error { func (m *Manager) AddHandler(ctx context.Context, handler inbound.Handler) error {
m.access.Lock() m.access.Lock()

View File

@ -310,6 +310,10 @@ func NewServer(ctx context.Context) *Server {
return s return s
} }
func (s *Server) Type() interface{} {
return s.dispatcher.Type()
}
func (s *Server) Dispatch(ctx context.Context, dest net.Destination) (*vio.Link, error) { func (s *Server) Dispatch(ctx context.Context, dest net.Destination) (*vio.Link, error) {
if dest.Address != muxCoolAddress { if dest.Address != muxCoolAddress {
return s.dispatcher.Dispatch(ctx, dest) return s.dispatcher.Dispatch(ctx, dest)

View File

@ -21,7 +21,7 @@ type Handler struct {
senderSettings *proxyman.SenderConfig senderSettings *proxyman.SenderConfig
streamSettings *internet.MemoryStreamConfig streamSettings *internet.MemoryStreamConfig
proxy proxy.Outbound proxy proxy.Outbound
outboundManager outbound.HandlerManager outboundManager outbound.Manager
mux *mux.ClientManager mux *mux.ClientManager
} }

View File

@ -12,5 +12,5 @@ func TestInterfaces(t *testing.T) {
assert := With(t) assert := With(t)
assert((*Handler)(nil), Implements, (*outbound.Handler)(nil)) assert((*Handler)(nil), Implements, (*outbound.Handler)(nil))
assert((*Manager)(nil), Implements, (*outbound.HandlerManager)(nil)) assert((*Manager)(nil), Implements, (*outbound.Manager)(nil))
} }

View File

@ -27,12 +27,16 @@ func New(ctx context.Context, config *proxyman.OutboundConfig) (*Manager, error)
taggedHandler: make(map[string]outbound.Handler), taggedHandler: make(map[string]outbound.Handler),
} }
v := core.MustFromContext(ctx) v := core.MustFromContext(ctx)
if err := v.RegisterFeature((*outbound.HandlerManager)(nil), m); err != nil { if err := v.RegisterFeature(m); err != nil {
return nil, newError("unable to register OutboundHandlerManager").Base(err) return nil, newError("unable to register outbound.Manager").Base(err)
} }
return m, nil return m, nil
} }
func (m *Manager) Type() interface{} {
return outbound.ManagerType()
}
// Start implements core.Feature // Start implements core.Feature
func (m *Manager) Start() error { func (m *Manager) Start() error {
m.access.Lock() m.access.Lock()
@ -73,7 +77,7 @@ func (m *Manager) Close() error {
return nil return nil
} }
// GetDefaultHandler implements outbound.HandlerManager. // GetDefaultHandler implements outbound.Manager.
func (m *Manager) GetDefaultHandler() outbound.Handler { func (m *Manager) GetDefaultHandler() outbound.Handler {
m.access.RLock() m.access.RLock()
defer m.access.RUnlock() defer m.access.RUnlock()
@ -84,7 +88,7 @@ func (m *Manager) GetDefaultHandler() outbound.Handler {
return m.defaultHandler return m.defaultHandler
} }
// GetHandler implements outbound.HandlerManager. // GetHandler implements outbound.Manager.
func (m *Manager) GetHandler(tag string) outbound.Handler { func (m *Manager) GetHandler(tag string) outbound.Handler {
m.access.RLock() m.access.RLock()
defer m.access.RUnlock() defer m.access.RUnlock()
@ -94,7 +98,7 @@ func (m *Manager) GetHandler(tag string) outbound.Handler {
return nil return nil
} }
// AddHandler implements outbound.HandlerManager. // AddHandler implements outbound.Manager.
func (m *Manager) AddHandler(ctx context.Context, handler outbound.Handler) error { func (m *Manager) AddHandler(ctx context.Context, handler outbound.Handler) error {
m.access.Lock() m.access.Lock()
defer m.access.Unlock() defer m.access.Unlock()
@ -117,7 +121,7 @@ func (m *Manager) AddHandler(ctx context.Context, handler outbound.Handler) erro
return nil return nil
} }
// RemoveHandler implements outbound.HandlerManager. // RemoveHandler implements outbound.Manager.
func (m *Manager) RemoveHandler(ctx context.Context, tag string) error { func (m *Manager) RemoveHandler(ctx context.Context, tag string) error {
if len(tag) == 0 { if len(tag) == 0 {
return common.ErrNoClue return common.ErrNoClue

View File

@ -39,7 +39,7 @@ func NewRouter(ctx context.Context, config *Config) (*Router, error) {
r.rules[idx].Condition = cond r.rules[idx].Condition = cond
} }
if err := v.RegisterFeature((*routing.Router)(nil), r); err != nil { if err := v.RegisterFeature(r); err != nil {
return nil, newError("unable to register Router").Base(err) return nil, newError("unable to register Router").Base(err)
} }
return r, nil return r, nil
@ -124,6 +124,11 @@ func (*Router) Close() error {
return nil return nil
} }
// Type implement common.HasType.
func (*Router) Type() interface{} {
return routing.RouterType()
}
func init() { func init() {
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
return NewRouter(ctx, config.(*Config)) return NewRouter(ctx, config.(*Config))

View File

@ -44,7 +44,7 @@ func NewManager(ctx context.Context, config *Config) (*Manager, error) {
v := core.FromContext(ctx) v := core.FromContext(ctx)
if v != nil { if v != nil {
if err := v.RegisterFeature((*stats.Manager)(nil), m); err != nil { if err := v.RegisterFeature(m); err != nil {
return nil, newError("failed to register StatManager").Base(err) return nil, newError("failed to register StatManager").Base(err)
} }
} }
@ -52,6 +52,10 @@ func NewManager(ctx context.Context, config *Config) (*Manager, error) {
return m, nil return m, nil
} }
func (*Manager) Type() interface{} {
return stats.ManagerType()
}
func (m *Manager) RegisterCounter(name string) (stats.Counter, error) { func (m *Manager) RegisterCounter(name string) (stats.Counter, error) {
m.access.Lock() m.access.Lock()
defer m.access.Unlock() defer m.access.Unlock()

View File

@ -25,6 +25,7 @@ type Runnable interface {
// HasType is the interface for objects that knows its type. // HasType is the interface for objects that knows its type.
type HasType interface { type HasType interface {
// Type returns the type of the object. // Type returns the type of the object.
// Usually it returns (*Type)(nil) of the object.
Type() interface{} Type() interface{}
} }

4
dns.go
View File

@ -13,6 +13,10 @@ type syncDNSClient struct {
dns.Client dns.Client
} }
func (d *syncDNSClient) Type() interface{} {
return dns.ClientType()
}
func (d *syncDNSClient) LookupIP(host string) ([]net.IP, error) { func (d *syncDNSClient) LookupIP(host string) ([]net.IP, error) {
d.RLock() d.RLock()
defer d.RUnlock() defer d.RUnlock()

View File

@ -10,3 +10,7 @@ type Client interface {
features.Feature features.Feature
LookupIP(host string) ([]net.IP, error) LookupIP(host string) ([]net.IP, error)
} }
func ClientType() interface{} {
return (*Client)(nil)
}

View File

@ -5,5 +5,6 @@ import "v2ray.com/core/common"
// Feature is the interface for V2Ray features. All features must implement this interface. // Feature is the interface for V2Ray features. All features must implement this interface.
// All existing features have an implementation in app directory. These features can be replaced by third-party ones. // All existing features have an implementation in app directory. These features can be replaced by third-party ones.
type Feature interface { type Feature interface {
common.HasType
common.Runnable common.Runnable
} }

View File

@ -23,9 +23,13 @@ type Manager interface {
features.Feature features.Feature
// GetHandlers returns an InboundHandler for the given tag. // GetHandlers returns an InboundHandler for the given tag.
GetHandler(ctx context.Context, tag string) (Handler, error) GetHandler(ctx context.Context, tag string) (Handler, error)
// AddHandler adds the given handler into this InboundHandlerManager. // AddHandler adds the given handler into this Manager.
AddHandler(ctx context.Context, handler Handler) error AddHandler(ctx context.Context, handler Handler) error
// RemoveHandler removes a handler from InboundHandlerManager. // RemoveHandler removes a handler from Manager.
RemoveHandler(ctx context.Context, tag string) error RemoveHandler(ctx context.Context, tag string) error
} }
func ManagerType() interface{} {
return (*Manager)(nil)
}

View File

@ -15,16 +15,20 @@ type Handler interface {
Dispatch(ctx context.Context, link *vio.Link) Dispatch(ctx context.Context, link *vio.Link)
} }
// HandlerManager is a feature that manages outbound.Handlers. // Manager is a feature that manages outbound.Handlers.
type HandlerManager interface { type Manager interface {
features.Feature features.Feature
// GetHandler returns an outbound.Handler for the given tag. // GetHandler returns an outbound.Handler for the given tag.
GetHandler(tag string) Handler GetHandler(tag string) Handler
// GetDefaultHandler returns the default outbound.Handler. It is usually the first outbound.Handler specified in the configuration. // GetDefaultHandler returns the default outbound.Handler. It is usually the first outbound.Handler specified in the configuration.
GetDefaultHandler() Handler GetDefaultHandler() Handler
// AddHandler adds a handler into this outbound.HandlerManager. // AddHandler adds a handler into this outbound.Manager.
AddHandler(ctx context.Context, handler Handler) error AddHandler(ctx context.Context, handler Handler) error
// RemoveHandler removes a handler from outbound.HandlerManager. // RemoveHandler removes a handler from outbound.Manager.
RemoveHandler(ctx context.Context, tag string) error RemoveHandler(ctx context.Context, tag string) error
} }
func ManagerType() interface{} {
return (*Manager)(nil)
}

View File

@ -67,6 +67,10 @@ type Manager interface {
ForSystem() System ForSystem() System
} }
func ManagerType() interface{} {
return (*Manager)(nil)
}
var defaultBufferSize int32 var defaultBufferSize int32
func init() { func init() {

View File

@ -16,3 +16,7 @@ type Dispatcher interface {
// Dispatch returns a Ray for transporting data for the given request. // Dispatch returns a Ray for transporting data for the given request.
Dispatch(ctx context.Context, dest net.Destination) (*vio.Link, error) Dispatch(ctx context.Context, dest net.Destination) (*vio.Link, error)
} }
func DispatcherType() interface{} {
return (*Dispatcher)(nil)
}

View File

@ -13,3 +13,7 @@ type Router interface {
// PickRoute returns a tag of an OutboundHandler based on the given context. // PickRoute returns a tag of an OutboundHandler based on the given context.
PickRoute(ctx context.Context) (string, error) PickRoute(ctx context.Context) (string, error)
} }
func RouterType() interface{} {
return (*Router)(nil)
}

View File

@ -24,3 +24,7 @@ func GetOrRegisterCounter(m Manager, name string) (Counter, error) {
return m.RegisterCounter(name) return m.RegisterCounter(name)
} }
func ManagerType() interface{} {
return (*Manager)(nil)
}

View File

@ -14,6 +14,10 @@ type syncInboundHandlerManager struct {
inbound.Manager inbound.Manager
} }
func (*syncInboundHandlerManager) Type() interface{} {
return inbound.ManagerType()
}
func (m *syncInboundHandlerManager) GetHandler(ctx context.Context, tag string) (inbound.Handler, error) { func (m *syncInboundHandlerManager) GetHandler(ctx context.Context, tag string) (inbound.Handler, error) {
m.RLock() m.RLock()
defer m.RUnlock() defer m.RUnlock()
@ -68,61 +72,65 @@ func (m *syncInboundHandlerManager) Set(manager inbound.Manager) {
type syncOutboundHandlerManager struct { type syncOutboundHandlerManager struct {
sync.RWMutex sync.RWMutex
outbound.HandlerManager outbound.Manager
}
func (*syncOutboundHandlerManager) Type() interface{} {
return outbound.ManagerType()
} }
func (m *syncOutboundHandlerManager) GetHandler(tag string) outbound.Handler { func (m *syncOutboundHandlerManager) GetHandler(tag string) outbound.Handler {
m.RLock() m.RLock()
defer m.RUnlock() defer m.RUnlock()
if m.HandlerManager == nil { if m.Manager == nil {
return nil return nil
} }
return m.HandlerManager.GetHandler(tag) return m.Manager.GetHandler(tag)
} }
func (m *syncOutboundHandlerManager) GetDefaultHandler() outbound.Handler { func (m *syncOutboundHandlerManager) GetDefaultHandler() outbound.Handler {
m.RLock() m.RLock()
defer m.RUnlock() defer m.RUnlock()
if m.HandlerManager == nil { if m.Manager == nil {
return nil return nil
} }
return m.HandlerManager.GetDefaultHandler() return m.Manager.GetDefaultHandler()
} }
func (m *syncOutboundHandlerManager) AddHandler(ctx context.Context, handler outbound.Handler) error { func (m *syncOutboundHandlerManager) AddHandler(ctx context.Context, handler outbound.Handler) error {
m.RLock() m.RLock()
defer m.RUnlock() defer m.RUnlock()
if m.HandlerManager == nil { if m.Manager == nil {
return newError("OutboundHandlerManager not set.").AtError() return newError("OutboundHandlerManager not set.").AtError()
} }
return m.HandlerManager.AddHandler(ctx, handler) return m.Manager.AddHandler(ctx, handler)
} }
func (m *syncOutboundHandlerManager) Start() error { func (m *syncOutboundHandlerManager) Start() error {
m.RLock() m.RLock()
defer m.RUnlock() defer m.RUnlock()
if m.HandlerManager == nil { if m.Manager == nil {
return newError("OutboundHandlerManager not set.").AtError() return newError("OutboundHandlerManager not set.").AtError()
} }
return m.HandlerManager.Start() return m.Manager.Start()
} }
func (m *syncOutboundHandlerManager) Close() error { func (m *syncOutboundHandlerManager) Close() error {
m.RLock() m.RLock()
defer m.RUnlock() defer m.RUnlock()
return common.Close(m.HandlerManager) return common.Close(m.Manager)
} }
func (m *syncOutboundHandlerManager) Set(manager outbound.HandlerManager) { func (m *syncOutboundHandlerManager) Set(manager outbound.Manager) {
if manager == nil { if manager == nil {
return return
} }
@ -130,6 +138,6 @@ func (m *syncOutboundHandlerManager) Set(manager outbound.HandlerManager) {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
common.Close(m.HandlerManager) // nolint: errcheck common.Close(m.Manager) // nolint: errcheck
m.HandlerManager = manager m.Manager = manager
} }

View File

@ -13,6 +13,10 @@ type syncPolicyManager struct {
policy.Manager policy.Manager
} }
func (*syncPolicyManager) Type() interface{} {
return policy.ManagerType()
}
func (m *syncPolicyManager) ForLevel(level uint32) policy.Session { func (m *syncPolicyManager) ForLevel(level uint32) policy.Session {
m.RLock() m.RLock()
defer m.RUnlock() defer m.RUnlock()

View File

@ -15,6 +15,10 @@ type syncDispatcher struct {
routing.Dispatcher routing.Dispatcher
} }
func (*syncDispatcher) Type() interface{} {
return routing.DispatcherType()
}
func (d *syncDispatcher) Dispatch(ctx context.Context, dest net.Destination) (*vio.Link, error) { func (d *syncDispatcher) Dispatch(ctx context.Context, dest net.Destination) (*vio.Link, error) {
d.RLock() d.RLock()
defer d.RUnlock() defer d.RUnlock()
@ -61,6 +65,10 @@ type syncRouter struct {
routing.Router routing.Router
} }
func (*syncRouter) Type() interface{} {
return routing.RouterType()
}
func (r *syncRouter) PickRoute(ctx context.Context) (string, error) { func (r *syncRouter) PickRoute(ctx context.Context) (string, error) {
r.RLock() r.RLock()
defer r.RUnlock() defer r.RUnlock()

View File

@ -11,6 +11,10 @@ type syncStatManager struct {
stats.Manager stats.Manager
} }
func (*syncStatManager) Type() interface{} {
return stats.ManagerType()
}
func (s *syncStatManager) Start() error { func (s *syncStatManager) Start() error {
s.RLock() s.RLock()
defer s.RUnlock() defer s.RUnlock()

View File

@ -3,12 +3,12 @@ package scenarios
import ( import (
"testing" "testing"
"v2ray.com/core/app/router"
xproxy "golang.org/x/net/proxy" xproxy "golang.org/x/net/proxy"
socks4 "h12.io/socks" socks4 "h12.io/socks"
"v2ray.com/core" "v2ray.com/core"
"v2ray.com/core/app/proxyman" "v2ray.com/core/app/proxyman"
"v2ray.com/core/app/router"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial" "v2ray.com/core/common/serial"

View File

@ -9,6 +9,7 @@ import (
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/vio" "v2ray.com/core/common/vio"
"v2ray.com/core/features/routing"
. "v2ray.com/core/transport/internet/udp" . "v2ray.com/core/transport/internet/udp"
"v2ray.com/core/transport/pipe" "v2ray.com/core/transport/pipe"
. "v2ray.com/ext/assert" . "v2ray.com/ext/assert"
@ -30,6 +31,10 @@ func (d *TestDispatcher) Close() error {
return nil return nil
} }
func (*TestDispatcher) Type() interface{} {
return routing.DispatcherType()
}
func TestSameDestinationDispatching(t *testing.T) { func TestSameDestinationDispatching(t *testing.T) {
assert := With(t) assert := With(t)

View File

@ -140,10 +140,10 @@ func (s *Instance) Start() error {
// RegisterFeature registers the given feature into V2Ray. // RegisterFeature registers the given feature into V2Ray.
// If feature is one of the following types, the corresponding feature in this Instance // If feature is one of the following types, the corresponding feature in this Instance
// will be replaced: DNSClient, PolicyManager, Router, Dispatcher, InboundHandlerManager, OutboundHandlerManager. // will be replaced: DNSClient, PolicyManager, Router, Dispatcher, InboundHandlerManager, OutboundHandlerManager.
func (s *Instance) RegisterFeature(feature interface{}, instance features.Feature) error { func (s *Instance) RegisterFeature(instance features.Feature) error {
running := false running := false
switch feature.(type) { switch instance.Type().(type) {
case dns.Client, *dns.Client: case dns.Client, *dns.Client:
s.dnsClient.Set(instance.(dns.Client)) s.dnsClient.Set(instance.(dns.Client))
case policy.Manager, *policy.Manager: case policy.Manager, *policy.Manager:
@ -154,8 +154,8 @@ func (s *Instance) RegisterFeature(feature interface{}, instance features.Featur
s.dispatcher.Set(instance.(routing.Dispatcher)) s.dispatcher.Set(instance.(routing.Dispatcher))
case inbound.Manager, *inbound.Manager: case inbound.Manager, *inbound.Manager:
s.ihm.Set(instance.(inbound.Manager)) s.ihm.Set(instance.(inbound.Manager))
case outbound.HandlerManager, *outbound.HandlerManager: case outbound.Manager, *outbound.Manager:
s.ohm.Set(instance.(outbound.HandlerManager)) s.ohm.Set(instance.(outbound.Manager))
case stats.Manager, *stats.Manager: case stats.Manager, *stats.Manager:
s.stats.Set(instance.(stats.Manager)) s.stats.Set(instance.(stats.Manager))
default: default:
@ -178,14 +178,29 @@ func (s *Instance) allFeatures() []features.Feature {
// GetFeature returns a feature that was registered in this Instance. Nil if not found. // GetFeature returns a feature that was registered in this Instance. Nil if not found.
// The returned Feature must implement common.HasType and whose type equals to the given feature type. // The returned Feature must implement common.HasType and whose type equals to the given feature type.
func (s *Instance) GetFeature(featureType interface{}) features.Feature { func (s *Instance) GetFeature(featureType interface{}) features.Feature {
for _, f := range s.features { switch featureType.(type) {
if hasType, ok := f.(common.HasType); ok { case dns.Client, *dns.Client:
if hasType.Type() == featureType { return s.DNSClient()
case policy.Manager, *policy.Manager:
return s.PolicyManager()
case routing.Router, *routing.Router:
return s.Router()
case routing.Dispatcher, *routing.Dispatcher:
return s.Dispatcher()
case inbound.Manager, *inbound.Manager:
return s.InboundHandlerManager()
case outbound.Manager, *outbound.Manager:
return s.OutboundHandlerManager()
case stats.Manager, *stats.Manager:
return s.Stats()
default:
for _, f := range s.features {
if f.Type() == featureType {
return f return f
} }
} }
return nil
} }
return nil
} }
// DNSClient returns the dns.Client used by this Instance. The returned dns.Client is always functional. // DNSClient returns the dns.Client used by this Instance. The returned dns.Client is always functional.
@ -214,7 +229,7 @@ func (s *Instance) InboundHandlerManager() inbound.Manager {
} }
// OutboundHandlerManager returns the OutboundHandlerManager used by this Instance. If OutboundHandlerManager was not registered before, the returned value doesn't work. // OutboundHandlerManager returns the OutboundHandlerManager used by this Instance. If OutboundHandlerManager was not registered before, the returned value doesn't work.
func (s *Instance) OutboundHandlerManager() outbound.HandlerManager { func (s *Instance) OutboundHandlerManager() outbound.Manager {
return &(s.ohm) return &(s.ohm)
} }