diff --git a/app/policy/config.go b/app/policy/config.go index 589a2c48e..4fc14f8c5 100644 --- a/app/policy/config.go +++ b/app/policy/config.go @@ -14,6 +14,40 @@ func (s *Second) Duration() time.Duration { return time.Second * time.Duration(s.Value) } +func defaultPolicy() *Policy { + p := core.DefaultPolicy() + + return &Policy{ + Timeout: &Policy_Timeout{ + Handshake: &Second{Value: uint32(p.Timeouts.Handshake / time.Second)}, + ConnectionIdle: &Second{Value: uint32(p.Timeouts.ConnectionIdle / time.Second)}, + UplinkOnly: &Second{Value: uint32(p.Timeouts.UplinkOnly / time.Second)}, + DownlinkOnly: &Second{Value: uint32(p.Timeouts.DownlinkOnly / time.Second)}, + }, + } +} + +func (p *Policy_Timeout) overrideWith(another *Policy_Timeout) { + if another.Handshake != nil { + p.Handshake = &Second{Value: another.Handshake.Value} + } + if another.ConnectionIdle != nil { + p.ConnectionIdle = &Second{Value: another.ConnectionIdle.Value} + } + if another.UplinkOnly != nil { + p.UplinkOnly = &Second{Value: another.UplinkOnly.Value} + } + if another.DownlinkOnly != nil { + p.DownlinkOnly = &Second{Value: another.DownlinkOnly.Value} + } +} + +func (p *Policy) overrideWith(another *Policy) { + if another.Timeout != nil { + p.Timeout.overrideWith(another.Timeout) + } +} + func (p *Policy) ToCorePolicy() core.Policy { var cp core.Policy if p.Timeout != nil { diff --git a/app/policy/manager.go b/app/policy/manager.go index 14a1c4c79..20e785154 100644 --- a/app/policy/manager.go +++ b/app/policy/manager.go @@ -9,17 +9,19 @@ import ( // Instance is an instance of Policy manager. type Instance struct { - levels map[uint32]core.Policy + levels map[uint32]*Policy } // New creates new Policy manager instance. func New(ctx context.Context, config *Config) (*Instance, error) { m := &Instance{ - levels: make(map[uint32]core.Policy), + levels: make(map[uint32]*Policy), } if len(config.Level) > 0 { for lv, p := range config.Level { - m.levels[lv] = p.ToCorePolicy().OverrideWith(core.DefaultPolicy()) + pp := defaultPolicy() + pp.overrideWith(p) + m.levels[lv] = pp } } @@ -36,7 +38,7 @@ func New(ctx context.Context, config *Config) (*Instance, error) { // ForLevel implements core.PolicyManager. func (m *Instance) ForLevel(level uint32) core.Policy { if p, ok := m.levels[level]; ok { - return p + return p.ToCorePolicy() } return core.DefaultPolicy() } diff --git a/policy.go b/policy.go index a333a84b0..1ac9ac1fb 100644 --- a/policy.go +++ b/policy.go @@ -13,40 +13,17 @@ type TimeoutPolicy struct { Handshake time.Duration // Timeout for connection being idle, i.e., there is no egress or ingress traffic in this connection. ConnectionIdle time.Duration - // Timeout for an uplink only connection, i.e., the downlink of the connection has ben closed. + // Timeout for an uplink only connection, i.e., the downlink of the connection has been closed. UplinkOnly time.Duration - // Timeout for an downlink only connection, i.e., the uplink of the connection has ben closed. + // Timeout for an downlink only connection, i.e., the uplink of the connection has been closed. DownlinkOnly time.Duration } -// OverrideWith overrides the current TimeoutPolicy with another one. All timeouts with zero value will be overridden with the new value. -func (p TimeoutPolicy) OverrideWith(another TimeoutPolicy) TimeoutPolicy { - if p.Handshake == 0 { - p.Handshake = another.Handshake - } - if p.ConnectionIdle == 0 { - p.ConnectionIdle = another.ConnectionIdle - } - if p.UplinkOnly == 0 { - p.UplinkOnly = another.UplinkOnly - } - if p.DownlinkOnly == 0 { - p.DownlinkOnly = another.DownlinkOnly - } - return p -} - // Policy is session based settings for controlling V2Ray requests. It contains various settings (or limits) that may differ for different users in the context. type Policy struct { Timeouts TimeoutPolicy // Timeout settings } -// OverrideWith overrides the current Policy with another one. All values with default value will be overridden. -func (p Policy) OverrideWith(another Policy) Policy { - p.Timeouts = p.Timeouts.OverrideWith(another.Timeouts) - return p -} - // PolicyManager is a feature that provides Policy for the given user by its id or level. type PolicyManager interface { Feature diff --git a/testing/scenarios/policy_test.go b/testing/scenarios/policy_test.go new file mode 100644 index 000000000..b3187cc64 --- /dev/null +++ b/testing/scenarios/policy_test.go @@ -0,0 +1,168 @@ +package scenarios + +import ( + "io" + "testing" + "time" + + "v2ray.com/core" + "v2ray.com/core/app/policy" + "v2ray.com/core/app/proxyman" + "v2ray.com/core/common/net" + "v2ray.com/core/common/protocol" + "v2ray.com/core/common/serial" + "v2ray.com/core/common/uuid" + "v2ray.com/core/proxy/dokodemo" + "v2ray.com/core/proxy/freedom" + "v2ray.com/core/proxy/vmess" + "v2ray.com/core/proxy/vmess/inbound" + "v2ray.com/core/proxy/vmess/outbound" + . "v2ray.com/ext/assert" +) + +func startQuickClosingTCPServer() (net.Listener, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, err + } + go func() { + for { + conn, err := listener.Accept() + if err != nil { + break + } + b := make([]byte, 1024) + conn.Read(b) + conn.Close() + } + }() + return listener, nil +} + +func TestVMessClosing(t *testing.T) { + assert := With(t) + + tcpServer, err := startQuickClosingTCPServer() + assert(err, IsNil) + defer tcpServer.Close() + + dest := net.DestinationFromAddr(tcpServer.Addr()) + + userID := protocol.NewID(uuid.New()) + serverPort := pickPort() + serverConfig := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&policy.Config{ + Level: map[uint32]*policy.Policy{ + 0: &policy.Policy{ + Timeout: &policy.Policy_Timeout{ + UplinkOnly: &policy.Second{Value: 0}, + DownlinkOnly: &policy.Second{Value: 0}, + }, + }, + }, + }), + }, + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortRange: net.SinglePortRange(serverPort), + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&inbound.Config{ + User: []*protocol.User{ + { + Account: serial.ToTypedMessage(&vmess.Account{ + Id: userID.String(), + AlterId: 64, + }), + }, + }, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&freedom.Config{}), + }, + }, + } + + clientPort := pickPort() + clientConfig := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&policy.Config{ + Level: map[uint32]*policy.Policy{ + 0: &policy.Policy{ + Timeout: &policy.Policy_Timeout{ + UplinkOnly: &policy.Second{Value: 0}, + DownlinkOnly: &policy.Second{Value: 0}, + }, + }, + }, + }), + }, + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortRange: net.SinglePortRange(clientPort), + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&dokodemo.Config{ + Address: net.NewIPOrDomain(dest.Address), + Port: uint32(dest.Port), + NetworkList: &net.NetworkList{ + Network: []net.Network{net.Network_TCP}, + }, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&outbound.Config{ + Receiver: []*protocol.ServerEndpoint{ + { + Address: net.NewIPOrDomain(net.LocalHostIP), + Port: uint32(serverPort), + User: []*protocol.User{ + { + Account: serial.ToTypedMessage(&vmess.Account{ + Id: userID.String(), + AlterId: 64, + SecuritySettings: &protocol.SecurityConfig{ + Type: protocol.SecurityType_AES128_GCM, + }, + }), + }, + }, + }, + }, + }), + }, + }, + } + + servers, err := InitializeServerConfigs(serverConfig, clientConfig) + assert(err, IsNil) + + defer CloseAllServers(servers) + + conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{ + IP: []byte{127, 0, 0, 1}, + Port: int(clientPort), + }) + assert(err, IsNil) + + conn.SetDeadline(time.Now().Add(time.Second * 2)) + + nBytes, err := conn.Write([]byte("test payload")) + assert(nBytes, GreaterThan, 0) + assert(err, IsNil) + + resp := make([]byte, 1024) + nBytes, err = conn.Read(resp) + assert(err, Equals, io.EOF) + assert(nBytes, Equals, 0) + + CloseAllServers(servers) +}