diff --git a/infra/conf/v2ray.go b/infra/conf/v2ray.go index 58c9c6dad..afe1fee6d 100644 --- a/infra/conf/v2ray.go +++ b/infra/conf/v2ray.go @@ -304,6 +304,86 @@ type Config struct { Reverse *ReverseConfig `json:"reverse"` } +func (c *Config) findInboundTag(tag string) int { + found := -1 + for idx, ib := range c.InboundConfigs { + if ib.Tag == tag { + found = idx + break + } + } + return found +} + +func (c *Config) findOutboundTag(tag string) int { + found := -1 + for idx, ob := range c.OutboundConfigs { + if ob.Tag == tag { + found = idx + break + } + } + return found +} + +// Override method accepts another Config overrides the current attribute +func (c *Config) Override(o *Config) { + + // only process the non-deprecated members + + if o.LogConfig != nil { + c.LogConfig = o.LogConfig + } + if o.RouterConfig != nil { + c.RouterConfig = o.RouterConfig + } + if o.DNSConfig != nil { + c.DNSConfig = o.DNSConfig + } + if o.Transport != nil { + c.Transport = o.Transport + } + if o.Policy != nil { + c.Policy = o.Policy + } + if o.Api != nil { + c.Api = o.Api + } + if o.Stats != nil { + c.Stats = o.Stats + } + if o.Reverse != nil { + c.Reverse = o.Reverse + } + + // update the Inbound in slice if the only one in overide config has same tag + if len(o.InboundConfigs) > 0 { + if len(c.InboundConfigs) > 0 && len(o.InboundConfigs) == 1 { + if idx := c.findInboundTag(o.InboundConfigs[0].Tag); idx > -1 { + c.InboundConfigs[idx] = o.InboundConfigs[0] + newError("updated inbound with tag: ", o.InboundConfigs[0].Tag).AtInfo().WriteToLog() + } else { + c.InboundConfigs = append(c.InboundConfigs, o.InboundConfigs[0]) + } + } else { + c.InboundConfigs = o.InboundConfigs + } + } + + // update the Outbound in slice if the only one in overide config has same tag + if len(o.OutboundConfigs) > 0 { + if len(c.OutboundConfigs) > 0 && len(o.OutboundConfigs) == 1 { + if idx := c.findOutboundTag(o.OutboundConfigs[0].Tag); idx > -1 { + c.OutboundConfigs[idx] = o.OutboundConfigs[0] + } else { + c.OutboundConfigs = append(c.OutboundConfigs, o.OutboundConfigs[0]) + } + } else { + c.OutboundConfigs = o.OutboundConfigs + } + } +} + func applyTransportConfig(s *StreamConfig, t *TransportConfig) { if s.TCPSettings == nil { s.TCPSettings = t.TCPConfig diff --git a/infra/conf/v2ray_test.go b/infra/conf/v2ray_test.go index d324d3ea6..9381af62e 100644 --- a/infra/conf/v2ray_test.go +++ b/infra/conf/v2ray_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/golang/protobuf/proto" + "github.com/google/go-cmp/cmp" "v2ray.com/core" "v2ray.com/core/app/dispatcher" "v2ray.com/core/app/log" @@ -369,3 +370,72 @@ func TestMuxConfig_Build(t *testing.T) { }) } } + +func TestConfig_Override(t *testing.T) { + tests := []struct { + name string + orig *Config + over *Config + want *Config + }{ + {"combine/empty", + &Config{}, + &Config{ + LogConfig: &LogConfig{}, + RouterConfig: &RouterConfig{}, + DNSConfig: &DnsConfig{}, + Transport: &TransportConfig{}, + Policy: &PolicyConfig{}, + Api: &ApiConfig{}, + Stats: &StatsConfig{}, + Reverse: &ReverseConfig{}, + }, + &Config{ + LogConfig: &LogConfig{}, + RouterConfig: &RouterConfig{}, + DNSConfig: &DnsConfig{}, + Transport: &TransportConfig{}, + Policy: &PolicyConfig{}, + Api: &ApiConfig{}, + Stats: &StatsConfig{}, + Reverse: &ReverseConfig{}, + }, + }, + {"combine/newattr", + &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "old"}}}, + &Config{LogConfig: &LogConfig{}}, + &Config{LogConfig: &LogConfig{}, InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "old"}}}}, + {"replace/inbounds", + &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "pos0"}, InboundDetourConfig{Protocol: "vmess", Tag: "pos1"}}}, + &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "pos1", Protocol: "kcp"}}}, + &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "pos0"}, InboundDetourConfig{Tag: "pos1", Protocol: "kcp"}}}}, + {"replace/inbounds-all", + &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "pos0"}, InboundDetourConfig{Protocol: "vmess", Tag: "pos1"}}}, + &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "pos1", Protocol: "kcp"}, InboundDetourConfig{Tag: "pos2", Protocol: "kcp"}}}, + &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "pos1", Protocol: "kcp"}, InboundDetourConfig{Tag: "pos2", Protocol: "kcp"}}}}, + {"replace/notag", + &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{}, InboundDetourConfig{Protocol: "vmess"}}}, + &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{Tag: "pos1", Protocol: "kcp"}}}, + &Config{InboundConfigs: []InboundDetourConfig{InboundDetourConfig{}, InboundDetourConfig{Protocol: "vmess"}, InboundDetourConfig{Tag: "pos1", Protocol: "kcp"}}}}, + {"replace/outbounds", + &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos0"}, OutboundDetourConfig{Protocol: "vmess", Tag: "pos1"}}}, + &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos1", Protocol: "kcp"}}}, + &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos0"}, OutboundDetourConfig{Tag: "pos1", Protocol: "kcp"}}}}, + {"replace/outbounds-all", + &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos0"}, OutboundDetourConfig{Protocol: "vmess", Tag: "pos1"}}}, + &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos1", Protocol: "kcp"}, OutboundDetourConfig{Tag: "pos2", Protocol: "kcp"}}}, + &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos1", Protocol: "kcp"}, OutboundDetourConfig{Tag: "pos2", Protocol: "kcp"}}}}, + {"replace/outbound-notag", + &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{}, OutboundDetourConfig{Protocol: "vmess"}}}, + &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{Tag: "pos1", Protocol: "kcp"}}}, + &Config{OutboundConfigs: []OutboundDetourConfig{OutboundDetourConfig{}, OutboundDetourConfig{Protocol: "vmess"}, OutboundDetourConfig{Tag: "pos1", Protocol: "kcp"}}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.orig.Override(tt.over) + if r := cmp.Diff(tt.orig, tt.want); r != "" { + t.Error(r) + } + }) + } +}