diff --git a/common/dice/dice.go b/common/dice/dice.go new file mode 100644 index 000000000..01d2ca804 --- /dev/null +++ b/common/dice/dice.go @@ -0,0 +1,12 @@ +package dice + +import ( + "math/rand" +) + +func Roll(n int) int { + if n == 1 { + return 0 + } + return rand.Intn(n) +} diff --git a/proxy/vmess/outbound/receiver.go b/proxy/vmess/outbound/receiver.go index 31e20d644..591f08274 100644 --- a/proxy/vmess/outbound/receiver.go +++ b/proxy/vmess/outbound/receiver.go @@ -1,10 +1,10 @@ package outbound import ( - "math/rand" "sync" "time" + "github.com/v2ray/v2ray-core/common/dice" v2net "github.com/v2ray/v2ray-core/common/net" "github.com/v2ray/v2ray-core/proxy/vmess" ) @@ -44,12 +44,7 @@ func (this *Receiver) AddUser(user *vmess.User) { } func (this *Receiver) PickUser() *vmess.User { - userLen := len(this.Accounts) - userIdx := 0 - if userLen > 1 { - userIdx = rand.Intn(userLen) - } - return this.Accounts[userIdx] + return this.Accounts[dice.Roll(len(this.Accounts))] } type ExpiringReceiver struct { @@ -108,11 +103,7 @@ func (this *ReceiverManager) pickDetour() *Receiver { return nil } this.detourAccess.RLock() - idx := 0 - detourLen := len(this.detours) - if detourLen > 1 { - idx = rand.Intn(detourLen) - } + idx := dice.Roll(len(this.detours)) rec := this.detours[idx] this.detourAccess.RUnlock() @@ -129,14 +120,7 @@ func (this *ReceiverManager) pickDetour() *Receiver { } func (this *ReceiverManager) pickStdReceiver() *Receiver { - receiverLen := len(this.receivers) - - receiverIdx := 0 - if receiverLen > 1 { - receiverIdx = rand.Intn(receiverLen) - } - - return this.receivers[receiverIdx] + return this.receivers[dice.Roll(len(this.receivers))] } func (this *ReceiverManager) PickReceiver() (v2net.Destination, *vmess.User) { diff --git a/proxy/vmess/user.go b/proxy/vmess/user.go index 700670e13..31cd0b6ec 100644 --- a/proxy/vmess/user.go +++ b/proxy/vmess/user.go @@ -1,7 +1,7 @@ package vmess import ( - "math/rand" + "github.com/v2ray/v2ray-core/common/dice" ) type UserLevel byte @@ -39,11 +39,7 @@ func (this *User) AnyValidID() *ID { if len(this.AlterIDs) == 0 { return this.ID } - if len(this.AlterIDs) == 1 { - return this.AlterIDs[0] - } - idx := rand.Intn(len(this.AlterIDs)) - return this.AlterIDs[idx] + return this.AlterIDs[dice.Roll(len(this.AlterIDs))] } type UserSettings struct { diff --git a/shell/point/inbound_detour_always.go b/shell/point/inbound_detour_always.go index de95c0e6c..dbcf2f841 100644 --- a/shell/point/inbound_detour_always.go +++ b/shell/point/inbound_detour_always.go @@ -1,9 +1,8 @@ package point import ( - "math/rand" - "github.com/v2ray/v2ray-core/app" + "github.com/v2ray/v2ray-core/common/dice" "github.com/v2ray/v2ray-core/common/log" v2net "github.com/v2ray/v2ray-core/common/net" "github.com/v2ray/v2ray-core/common/retry" @@ -46,8 +45,7 @@ func NewInboundDetourHandlerAlways(space app.Space, config *InboundDetourConfig) } func (this *InboundDetourHandlerAlways) GetConnectionHandler() (proxy.InboundConnectionHandler, int) { - idx := rand.Intn(len(this.ich)) - ich := this.ich[idx] + ich := this.ich[dice.Roll(len(this.ich))] return ich.handler, this.config.Allocation.Refresh } diff --git a/transport/dialer/dialer.go b/transport/dialer/dialer.go index fdbf05e36..306a7b92b 100644 --- a/transport/dialer/dialer.go +++ b/transport/dialer/dialer.go @@ -2,9 +2,9 @@ package dialer import ( "errors" - "math/rand" "net" + "github.com/v2ray/v2ray-core/common/dice" v2net "github.com/v2ray/v2ray-core/common/net" ) @@ -24,11 +24,7 @@ func Dial(dest v2net.Destination) (net.Conn, error) { if len(ips) == 0 { return nil, ErrInvalidHost } - if len(ips) == 1 { - ip = ips[0] - } else { - ip = ips[rand.Intn(len(ips))] - } + ip = ips[dice.Roll(len(ips))] } if dest.IsTCP() { return net.DialTCP("tcp", nil, &net.TCPAddr{ diff --git a/transport/dialer/dialer_test.go b/transport/dialer/dialer_test.go index e3ffa1e10..358af891b 100644 --- a/transport/dialer/dialer_test.go +++ b/transport/dialer/dialer_test.go @@ -19,6 +19,7 @@ func TestDialDomain(t *testing.T) { } dest, err := server.Start() assert.Error(err).IsNil() + defer server.Close() conn, err := Dial(v2net.TCPDestination(v2net.DomainAddress("local.v2ray.com"), dest.Port())) assert.Error(err).IsNil()