From 2034d54bab7c2a9169dff2ca2156218d539f7fbc Mon Sep 17 00:00:00 2001 From: v2ray Date: Mon, 25 Jul 2016 17:36:24 +0200 Subject: [PATCH] rename VMessAccount to vmess.Account --- common/protocol/account.go | 25 ---- common/protocol/account_json.go | 36 ------ common/protocol/server_spec_test.go | 29 +++-- common/protocol/user.go | 7 +- common/protocol/user_json.go | 12 +- common/protocol/user_json_test.go | 15 +-- common/protocol/user_validator.go | 151 ---------------------- proxy/vmess/account_json.go | 31 +++++ proxy/vmess/encoding/client.go | 5 +- proxy/vmess/encoding/encoding_test.go | 11 +- proxy/vmess/encoding/server.go | 3 +- proxy/vmess/inbound/command.go | 5 +- proxy/vmess/inbound/config_json.go | 24 +++- proxy/vmess/inbound/inbound.go | 8 +- proxy/vmess/outbound/command.go | 6 +- proxy/vmess/outbound/config_json.go | 17 ++- proxy/vmess/vmess.go | 177 ++++++++++++++++++++++++++ 17 files changed, 288 insertions(+), 274 deletions(-) delete mode 100644 common/protocol/account_json.go create mode 100644 proxy/vmess/account_json.go diff --git a/common/protocol/account.go b/common/protocol/account.go index e70ff4f71..f78b0f49f 100644 --- a/common/protocol/account.go +++ b/common/protocol/account.go @@ -1,30 +1,5 @@ package protocol -import ( - "github.com/v2ray/v2ray-core/common/dice" -) - type Account interface { Equals(Account) bool } - -type VMessAccount struct { - ID *ID - AlterIDs []*ID -} - -func (this *VMessAccount) AnyValidID() *ID { - if len(this.AlterIDs) == 0 { - return this.ID - } - return this.AlterIDs[dice.Roll(len(this.AlterIDs))] -} - -func (this *VMessAccount) Equals(account Account) bool { - vmessAccount, ok := account.(*VMessAccount) - if !ok { - return false - } - // TODO: handle AlterIds difference - return this.ID.Equals(vmessAccount.ID) -} diff --git a/common/protocol/account_json.go b/common/protocol/account_json.go deleted file mode 100644 index f17421d78..000000000 --- a/common/protocol/account_json.go +++ /dev/null @@ -1,36 +0,0 @@ -// +build json - -package protocol - -import ( - "errors" - - "github.com/v2ray/v2ray-core/common/uuid" -) - -type AccountJson struct { - ID string `json:"id"` - AlterIds uint16 `json:"alterId"` - - Username string `json:"user"` - Password string `json:"pass"` -} - -func (this *AccountJson) GetAccount() (Account, error) { - if len(this.ID) > 0 { - id, err := uuid.ParseString(this.ID) - if err != nil { - return nil, err - } - - primaryID := NewID(id) - alterIDs := NewAlterIDs(primaryID, this.AlterIds) - - return &VMessAccount{ - ID: primaryID, - AlterIDs: alterIDs, - }, nil - } - - return nil, errors.New("Protocol: Malformed account.") -} diff --git a/common/protocol/server_spec_test.go b/common/protocol/server_spec_test.go index c71ed58c9..ee53c351e 100644 --- a/common/protocol/server_spec_test.go +++ b/common/protocol/server_spec_test.go @@ -5,30 +5,33 @@ import ( v2net "github.com/v2ray/v2ray-core/common/net" . "github.com/v2ray/v2ray-core/common/protocol" - "github.com/v2ray/v2ray-core/common/uuid" "github.com/v2ray/v2ray-core/testing/assert" ) +type TestAccount struct { + id int +} + +func (this *TestAccount) Equals(account Account) bool { + return account.(*TestAccount).id == this.id +} + func TestReceiverUser(t *testing.T) { assert := assert.On(t) - id := NewID(uuid.New()) - alters := NewAlterIDs(id, 100) - account := &VMessAccount{ - ID: id, - AlterIDs: alters, + account := &TestAccount{ + id: 0, } - user := NewUser(account, UserLevel(0), "") + user := NewUser(UserLevel(0), "") + user.Account = account rec := NewServerSpec(v2net.TCPDestination(v2net.DomainAddress("v2ray.com"), 80), AlwaysValid(), user) assert.Bool(rec.HasUser(user)).IsTrue() - id2 := NewID(uuid.New()) - alters2 := NewAlterIDs(id2, 100) - account2 := &VMessAccount{ - ID: id2, - AlterIDs: alters2, + account2 := &TestAccount{ + id: 1, } - user2 := NewUser(account2, UserLevel(0), "") + user2 := NewUser(UserLevel(0), "") + user2.Account = account2 assert.Bool(rec.HasUser(user2)).IsFalse() rec.AddUser(user2) diff --git a/common/protocol/user.go b/common/protocol/user.go index 379b03579..73c8425e2 100644 --- a/common/protocol/user.go +++ b/common/protocol/user.go @@ -13,11 +13,10 @@ type User struct { Email string } -func NewUser(account Account, level UserLevel, email string) *User { +func NewUser(level UserLevel, email string) *User { return &User{ - Account: account, - Level: level, - Email: email, + Level: level, + Email: email, } } diff --git a/common/protocol/user_json.go b/common/protocol/user_json.go index d171871bc..cbfc77274 100644 --- a/common/protocol/user_json.go +++ b/common/protocol/user_json.go @@ -14,16 +14,8 @@ func (u *User) UnmarshalJSON(data []byte) error { return err } - var rawAccount AccountJson - if err := json.Unmarshal(data, &rawAccount); err != nil { - return err - } - account, err := rawAccount.GetAccount() - if err != nil { - return err - } - - *u = *NewUser(account, UserLevel(rawUserValue.LevelByte), rawUserValue.EmailString) + u.Email = rawUserValue.EmailString + u.Level = UserLevel(rawUserValue.LevelByte) return nil } diff --git a/common/protocol/user_json_test.go b/common/protocol/user_json_test.go index 385900c50..de72387f7 100644 --- a/common/protocol/user_json_test.go +++ b/common/protocol/user_json_test.go @@ -22,24 +22,13 @@ func TestUserParsing(t *testing.T) { }`), user) assert.Error(err).IsNil() assert.Byte(byte(user.Level)).Equals(1) - - account, ok := user.Account.(*VMessAccount) - assert.Bool(ok).IsTrue() - assert.String(account.ID.String()).Equals("96edb838-6d68-42ef-a933-25f7ac3a9d09") + assert.String(user.Email).Equals("love@v2ray.com") } func TestInvalidUserJson(t *testing.T) { assert := assert.On(t) user := new(User) - err := json.Unmarshal([]byte(`{"id": 1234}`), user) - assert.Error(err).IsNotNil() -} - -func TestInvalidIdJson(t *testing.T) { - assert := assert.On(t) - - user := new(User) - err := json.Unmarshal([]byte(`{"id": "1234"}`), user) + err := json.Unmarshal([]byte(`{"email": 1234}`), user) assert.Error(err).IsNotNil() } diff --git a/common/protocol/user_validator.go b/common/protocol/user_validator.go index 4b665abae..dda6673c4 100644 --- a/common/protocol/user_validator.go +++ b/common/protocol/user_validator.go @@ -1,163 +1,12 @@ package protocol import ( - "sync" - "time" - "github.com/v2ray/v2ray-core/common" - "github.com/v2ray/v2ray-core/common/signal" ) -const ( - updateIntervalSec = 10 - cacheDurationSec = 120 -) - -type idEntry struct { - id *ID - userIdx int - lastSec Timestamp - lastSecRemoval Timestamp -} - type UserValidator interface { common.Releasable Add(user *User) error Get(timeHash []byte) (*User, Timestamp, bool) } - -type TimedUserValidator struct { - sync.RWMutex - running bool - validUsers []*User - userHash map[[16]byte]*indexTimePair - ids []*idEntry - hasher IDHash - cancel *signal.CancelSignal -} - -type indexTimePair struct { - index int - timeSec Timestamp -} - -func NewTimedUserValidator(hasher IDHash) UserValidator { - tus := &TimedUserValidator{ - validUsers: make([]*User, 0, 16), - userHash: make(map[[16]byte]*indexTimePair, 512), - ids: make([]*idEntry, 0, 512), - hasher: hasher, - running: true, - cancel: signal.NewCloseSignal(), - } - go tus.updateUserHash(updateIntervalSec * time.Second) - return tus -} - -func (this *TimedUserValidator) Release() { - if !this.running { - return - } - - this.cancel.Cancel() - <-this.cancel.WaitForDone() - - this.Lock() - defer this.Unlock() - - if !this.running { - return - } - - this.running = false - this.validUsers = nil - this.userHash = nil - this.ids = nil - this.hasher = nil - this.cancel = nil -} - -func (this *TimedUserValidator) generateNewHashes(nowSec Timestamp, idx int, entry *idEntry) { - var hashValue [16]byte - var hashValueRemoval [16]byte - idHash := this.hasher(entry.id.Bytes()) - for entry.lastSec <= nowSec { - idHash.Write(entry.lastSec.Bytes(nil)) - idHash.Sum(hashValue[:0]) - idHash.Reset() - - idHash.Write(entry.lastSecRemoval.Bytes(nil)) - idHash.Sum(hashValueRemoval[:0]) - idHash.Reset() - - this.Lock() - this.userHash[hashValue] = &indexTimePair{idx, entry.lastSec} - delete(this.userHash, hashValueRemoval) - this.Unlock() - - entry.lastSec++ - entry.lastSecRemoval++ - } -} - -func (this *TimedUserValidator) updateUserHash(interval time.Duration) { -L: - for { - select { - case now := <-time.After(interval): - nowSec := Timestamp(now.Unix() + cacheDurationSec) - for _, entry := range this.ids { - this.generateNewHashes(nowSec, entry.userIdx, entry) - } - case <-this.cancel.WaitForCancel(): - break L - } - } - this.cancel.Done() -} - -func (this *TimedUserValidator) Add(user *User) error { - idx := len(this.validUsers) - this.validUsers = append(this.validUsers, user) - account := user.Account.(*VMessAccount) - - nowSec := time.Now().Unix() - - entry := &idEntry{ - id: account.ID, - userIdx: idx, - lastSec: Timestamp(nowSec - cacheDurationSec), - lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3), - } - this.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry) - this.ids = append(this.ids, entry) - for _, alterid := range account.AlterIDs { - entry := &idEntry{ - id: alterid, - userIdx: idx, - lastSec: Timestamp(nowSec - cacheDurationSec), - lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3), - } - this.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry) - this.ids = append(this.ids, entry) - } - - return nil -} - -func (this *TimedUserValidator) Get(userHash []byte) (*User, Timestamp, bool) { - defer this.RUnlock() - this.RLock() - - if !this.running { - return nil, 0, false - } - var fixedSizeHash [16]byte - copy(fixedSizeHash[:], userHash) - pair, found := this.userHash[fixedSizeHash] - if found { - return this.validUsers[pair.index], pair.timeSec, true - } - return nil, 0, false -} diff --git a/proxy/vmess/account_json.go b/proxy/vmess/account_json.go new file mode 100644 index 000000000..546a3386f --- /dev/null +++ b/proxy/vmess/account_json.go @@ -0,0 +1,31 @@ +// +build json + +package vmess + +import ( + "encoding/json" + + "github.com/v2ray/v2ray-core/common/log" + "github.com/v2ray/v2ray-core/common/protocol" + "github.com/v2ray/v2ray-core/common/uuid" +) + +func (u *Account) UnmarshalJSON(data []byte) error { + type JsonConfig struct { + ID string `json:"id"` + AlterIds uint16 `json:"alterId"` + } + var rawConfig JsonConfig + if err := json.Unmarshal(data, &rawConfig); err != nil { + return err + } + id, err := uuid.ParseString(rawConfig.ID) + if err != nil { + log.Error("VMess: Failed to parse ID: ", err) + return err + } + u.ID = protocol.NewID(id) + u.AlterIDs = protocol.NewAlterIDs(u.ID, rawConfig.AlterIds) + + return nil +} diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index c04325858..b15ad1dc5 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -9,6 +9,7 @@ import ( "github.com/v2ray/v2ray-core/common/crypto" "github.com/v2ray/v2ray-core/common/log" "github.com/v2ray/v2ray-core/common/protocol" + "github.com/v2ray/v2ray-core/proxy/vmess" "github.com/v2ray/v2ray-core/transport" ) @@ -50,7 +51,7 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession { func (this *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) { timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)() - idHash := this.idHash(header.User.Account.(*protocol.VMessAccount).AnyValidID().Bytes()) + idHash := this.idHash(header.User.Account.(*vmess.Account).AnyValidID().Bytes()) idHash.Write(timestamp.Bytes(nil)) writer.Write(idHash.Sum(nil)) @@ -81,7 +82,7 @@ func (this *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, w timestampHash := md5.New() timestampHash.Write(hashTimestamp(timestamp)) iv := timestampHash.Sum(nil) - account := header.User.Account.(*protocol.VMessAccount) + account := header.User.Account.(*vmess.Account) aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv) aesStream.XORKeyStream(buffer, buffer) writer.Write(buffer) diff --git a/proxy/vmess/encoding/encoding_test.go b/proxy/vmess/encoding/encoding_test.go index 9cc0b225f..fcf72edfb 100644 --- a/proxy/vmess/encoding/encoding_test.go +++ b/proxy/vmess/encoding/encoding_test.go @@ -7,6 +7,7 @@ import ( v2net "github.com/v2ray/v2ray-core/common/net" "github.com/v2ray/v2ray-core/common/protocol" "github.com/v2ray/v2ray-core/common/uuid" + "github.com/v2ray/v2ray-core/proxy/vmess" . "github.com/v2ray/v2ray-core/proxy/vmess/encoding" "github.com/v2ray/v2ray-core/testing/assert" ) @@ -15,12 +16,12 @@ func TestRequestSerialization(t *testing.T) { assert := assert.On(t) user := protocol.NewUser( - &protocol.VMessAccount{ - ID: protocol.NewID(uuid.New()), - AlterIDs: nil, - }, protocol.UserLevelUntrusted, "test@v2ray.com") + user.Account = &vmess.Account{ + ID: protocol.NewID(uuid.New()), + AlterIDs: nil, + } expectedRequest := &protocol.RequestHeader{ Version: 1, @@ -35,7 +36,7 @@ func TestRequestSerialization(t *testing.T) { client := NewClientSession(protocol.DefaultIDHash) client.EncodeRequestHeader(expectedRequest, buffer) - userValidator := protocol.NewTimedUserValidator(protocol.DefaultIDHash) + userValidator := vmess.NewTimedUserValidator(protocol.DefaultIDHash) userValidator.Add(user) server := NewServerSession(userValidator) diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index f63de500b..d93332f0f 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -10,6 +10,7 @@ import ( v2net "github.com/v2ray/v2ray-core/common/net" "github.com/v2ray/v2ray-core/common/protocol" "github.com/v2ray/v2ray-core/common/serial" + "github.com/v2ray/v2ray-core/proxy/vmess" "github.com/v2ray/v2ray-core/transport" ) @@ -58,7 +59,7 @@ func (this *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Requ timestampHash := md5.New() timestampHash.Write(hashTimestamp(timestamp)) iv := timestampHash.Sum(nil) - account := user.Account.(*protocol.VMessAccount) + account := user.Account.(*vmess.Account) aesStream := crypto.NewAesDecryptionStream(account.ID.CmdKey(), iv) decryptor := crypto.NewCryptionReader(aesStream, reader) diff --git a/proxy/vmess/inbound/command.go b/proxy/vmess/inbound/command.go index c4b7a14cd..ffaa66a2a 100644 --- a/proxy/vmess/inbound/command.go +++ b/proxy/vmess/inbound/command.go @@ -3,6 +3,7 @@ package inbound import ( "github.com/v2ray/v2ray-core/common/log" "github.com/v2ray/v2ray-core/common/protocol" + "github.com/v2ray/v2ray-core/proxy/vmess" ) func (this *VMessInboundHandler) generateCommand(request *protocol.RequestHeader) protocol.ResponseCommand { @@ -23,8 +24,8 @@ func (this *VMessInboundHandler) generateCommand(request *protocol.RequestHeader } return &protocol.CommandSwitchAccount{ Port: inboundHandler.Port(), - ID: user.Account.(*protocol.VMessAccount).ID.UUID(), - AlterIds: uint16(len(user.Account.(*protocol.VMessAccount).AlterIDs)), + ID: user.Account.(*vmess.Account).ID.UUID(), + AlterIds: uint16(len(user.Account.(*vmess.Account).AlterIDs)), Level: user.Level, ValidMin: byte(availableMin), } diff --git a/proxy/vmess/inbound/config_json.go b/proxy/vmess/inbound/config_json.go index 0538d235d..8a5944748 100644 --- a/proxy/vmess/inbound/config_json.go +++ b/proxy/vmess/inbound/config_json.go @@ -8,6 +8,7 @@ import ( "github.com/v2ray/v2ray-core/common/protocol" "github.com/v2ray/v2ray-core/proxy/internal" + "github.com/v2ray/v2ray-core/proxy/vmess" ) func (this *DetourConfig) UnmarshalJSON(data []byte) error { @@ -53,16 +54,15 @@ func (this *DefaultConfig) UnmarshalJSON(data []byte) error { func (this *Config) UnmarshalJSON(data []byte) error { type JsonConfig struct { - Users []*protocol.User `json:"clients"` - Features *FeaturesConfig `json:"features"` - Defaults *DefaultConfig `json:"default"` - DetourConfig *DetourConfig `json:"detour"` + Users []json.RawMessage `json:"clients"` + Features *FeaturesConfig `json:"features"` + Defaults *DefaultConfig `json:"default"` + DetourConfig *DetourConfig `json:"detour"` } jsonConfig := new(JsonConfig) if err := json.Unmarshal(data, jsonConfig); err != nil { return errors.New("VMessIn: Failed to parse config: " + err.Error()) } - this.AllowedUsers = jsonConfig.Users this.Features = jsonConfig.Features // Backward compatibility this.Defaults = jsonConfig.Defaults if this.Defaults == nil { @@ -76,6 +76,20 @@ func (this *Config) UnmarshalJSON(data []byte) error { if this.Features != nil && this.DetourConfig == nil { this.DetourConfig = this.Features.Detour } + this.AllowedUsers = make([]*protocol.User, len(jsonConfig.Users)) + for idx, rawData := range jsonConfig.Users { + user := new(protocol.User) + if err := json.Unmarshal(rawData, user); err != nil { + return errors.New("VMess|Inbound: Invalid user: " + err.Error()) + } + account := new(vmess.Account) + if err := json.Unmarshal(rawData, account); err != nil { + return errors.New("VMess|Inbound: Invalid user: " + err.Error()) + } + user.Account = account + this.AllowedUsers[idx] = user + } + return nil } diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 7a9941825..14117cf04 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -15,6 +15,7 @@ import ( "github.com/v2ray/v2ray-core/common/uuid" "github.com/v2ray/v2ray-core/proxy" "github.com/v2ray/v2ray-core/proxy/internal" + "github.com/v2ray/v2ray-core/proxy/vmess" "github.com/v2ray/v2ray-core/proxy/vmess/encoding" vmessio "github.com/v2ray/v2ray-core/proxy/vmess/io" "github.com/v2ray/v2ray-core/transport/internet" @@ -51,11 +52,12 @@ func (this *userByEmail) Get(email string) (*protocol.User, bool) { if !found { id := protocol.NewID(uuid.New()) alterIDs := protocol.NewAlterIDs(id, this.defaultAlterIDs) - account := &protocol.VMessAccount{ + account := &vmess.Account{ ID: id, AlterIDs: alterIDs, } - user = protocol.NewUser(account, this.defaultLevel, email) + user = protocol.NewUser(this.defaultLevel, email) + user.Account = account this.cache[email] = user } this.Unlock() @@ -249,7 +251,7 @@ func (this *Factory) Create(space app.Space, rawConfig interface{}, meta *proxy. } config := rawConfig.(*Config) - allowedClients := protocol.NewTimedUserValidator(protocol.DefaultIDHash) + allowedClients := vmess.NewTimedUserValidator(protocol.DefaultIDHash) for _, user := range config.AllowedUsers { allowedClients.Add(user) } diff --git a/proxy/vmess/outbound/command.go b/proxy/vmess/outbound/command.go index ce19d61db..3ffefbb9a 100644 --- a/proxy/vmess/outbound/command.go +++ b/proxy/vmess/outbound/command.go @@ -5,16 +5,18 @@ import ( v2net "github.com/v2ray/v2ray-core/common/net" "github.com/v2ray/v2ray-core/common/protocol" + "github.com/v2ray/v2ray-core/proxy/vmess" ) func (this *VMessOutboundHandler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) { primary := protocol.NewID(cmd.ID) alters := protocol.NewAlterIDs(primary, cmd.AlterIds) - account := &protocol.VMessAccount{ + account := &vmess.Account{ ID: primary, AlterIDs: alters, } - user := protocol.NewUser(account, cmd.Level, "") + user := protocol.NewUser(cmd.Level, "") + user.Account = account dest := v2net.TCPDestination(cmd.Host, cmd.Port) until := time.Now().Add(time.Duration(cmd.ValidMin) * time.Minute) this.serverList.AddServer(protocol.NewServerSpec(dest, protocol.BeforeTime(until), user)) diff --git a/proxy/vmess/outbound/config_json.go b/proxy/vmess/outbound/config_json.go index 9aa6f0d67..fa94c8787 100644 --- a/proxy/vmess/outbound/config_json.go +++ b/proxy/vmess/outbound/config_json.go @@ -11,13 +11,14 @@ import ( "github.com/v2ray/v2ray-core/common/protocol" "github.com/v2ray/v2ray-core/common/serial" "github.com/v2ray/v2ray-core/proxy/internal" + "github.com/v2ray/v2ray-core/proxy/vmess" ) func (this *Config) UnmarshalJSON(data []byte) error { type RawConfigTarget struct { Address *v2net.AddressJson `json:"address"` Port v2net.Port `json:"port"` - Users []*protocol.User `json:"users"` + Users []json.RawMessage `json:"users"` } type RawOutbound struct { Receivers []*RawConfigTarget `json:"vnext"` @@ -45,7 +46,19 @@ func (this *Config) UnmarshalJSON(data []byte) error { rec.Address.Address = v2net.IPAddress(serial.Uint32ToBytes(2891346854, nil)) } spec := protocol.NewServerSpec(v2net.TCPDestination(rec.Address.Address, rec.Port), protocol.AlwaysValid()) - for _, user := range rec.Users { + for _, rawUser := range rec.Users { + user := new(protocol.User) + if err := json.Unmarshal(rawUser, user); err != nil { + log.Error("VMess|Outbound: Invalid user: ", err) + return err + } + account := new(vmess.Account) + if err := json.Unmarshal(rawUser, account); err != nil { + log.Error("VMess|Outbound: Invalid user: ", err) + return err + } + user.Account = account + spec.AddUser(user) } serverSpecs[idx] = spec diff --git a/proxy/vmess/vmess.go b/proxy/vmess/vmess.go index 5a6936957..d8e08bbff 100644 --- a/proxy/vmess/vmess.go +++ b/proxy/vmess/vmess.go @@ -4,3 +4,180 @@ // together with 'freedom' to talk to final destination, while VMess outbound is usually used on // clients with 'socks' for proxying. package vmess // import "github.com/v2ray/v2ray-core/proxy/vmess" + +import ( + "sync" + "time" + + "github.com/v2ray/v2ray-core/common/dice" + "github.com/v2ray/v2ray-core/common/protocol" + "github.com/v2ray/v2ray-core/common/signal" +) + +type Account struct { + ID *protocol.ID + AlterIDs []*protocol.ID +} + +func (this *Account) AnyValidID() *protocol.ID { + if len(this.AlterIDs) == 0 { + return this.ID + } + return this.AlterIDs[dice.Roll(len(this.AlterIDs))] +} + +func (this *Account) Equals(account protocol.Account) bool { + vmessAccount, ok := account.(*Account) + if !ok { + return false + } + // TODO: handle AlterIds difference + return this.ID.Equals(vmessAccount.ID) +} + +const ( + updateIntervalSec = 10 + cacheDurationSec = 120 +) + +type idEntry struct { + id *protocol.ID + userIdx int + lastSec protocol.Timestamp + lastSecRemoval protocol.Timestamp +} + +type TimedUserValidator struct { + sync.RWMutex + running bool + validUsers []*protocol.User + userHash map[[16]byte]*indexTimePair + ids []*idEntry + hasher protocol.IDHash + cancel *signal.CancelSignal +} + +type indexTimePair struct { + index int + timeSec protocol.Timestamp +} + +func NewTimedUserValidator(hasher protocol.IDHash) protocol.UserValidator { + tus := &TimedUserValidator{ + validUsers: make([]*protocol.User, 0, 16), + userHash: make(map[[16]byte]*indexTimePair, 512), + ids: make([]*idEntry, 0, 512), + hasher: hasher, + running: true, + cancel: signal.NewCloseSignal(), + } + go tus.updateUserHash(updateIntervalSec * time.Second) + return tus +} + +func (this *TimedUserValidator) Release() { + if !this.running { + return + } + + this.cancel.Cancel() + <-this.cancel.WaitForDone() + + this.Lock() + defer this.Unlock() + + if !this.running { + return + } + + this.running = false + this.validUsers = nil + this.userHash = nil + this.ids = nil + this.hasher = nil + this.cancel = nil +} + +func (this *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, idx int, entry *idEntry) { + var hashValue [16]byte + var hashValueRemoval [16]byte + idHash := this.hasher(entry.id.Bytes()) + for entry.lastSec <= nowSec { + idHash.Write(entry.lastSec.Bytes(nil)) + idHash.Sum(hashValue[:0]) + idHash.Reset() + + idHash.Write(entry.lastSecRemoval.Bytes(nil)) + idHash.Sum(hashValueRemoval[:0]) + idHash.Reset() + + this.Lock() + this.userHash[hashValue] = &indexTimePair{idx, entry.lastSec} + delete(this.userHash, hashValueRemoval) + this.Unlock() + + entry.lastSec++ + entry.lastSecRemoval++ + } +} + +func (this *TimedUserValidator) updateUserHash(interval time.Duration) { +L: + for { + select { + case now := <-time.After(interval): + nowSec := protocol.Timestamp(now.Unix() + cacheDurationSec) + for _, entry := range this.ids { + this.generateNewHashes(nowSec, entry.userIdx, entry) + } + case <-this.cancel.WaitForCancel(): + break L + } + } + this.cancel.Done() +} + +func (this *TimedUserValidator) Add(user *protocol.User) error { + idx := len(this.validUsers) + this.validUsers = append(this.validUsers, user) + account := user.Account.(*Account) + + nowSec := time.Now().Unix() + + entry := &idEntry{ + id: account.ID, + userIdx: idx, + lastSec: protocol.Timestamp(nowSec - cacheDurationSec), + lastSecRemoval: protocol.Timestamp(nowSec - cacheDurationSec*3), + } + this.generateNewHashes(protocol.Timestamp(nowSec+cacheDurationSec), idx, entry) + this.ids = append(this.ids, entry) + for _, alterid := range account.AlterIDs { + entry := &idEntry{ + id: alterid, + userIdx: idx, + lastSec: protocol.Timestamp(nowSec - cacheDurationSec), + lastSecRemoval: protocol.Timestamp(nowSec - cacheDurationSec*3), + } + this.generateNewHashes(protocol.Timestamp(nowSec+cacheDurationSec), idx, entry) + this.ids = append(this.ids, entry) + } + + return nil +} + +func (this *TimedUserValidator) Get(userHash []byte) (*protocol.User, protocol.Timestamp, bool) { + defer this.RUnlock() + this.RLock() + + if !this.running { + return nil, 0, false + } + var fixedSizeHash [16]byte + copy(fixedSizeHash[:], userHash) + pair, found := this.userHash[fixedSizeHash] + if found { + return this.validUsers[pair.index], pair.timeSec, true + } + return nil, 0, false +}