diff --git a/common/protocol/id.go b/common/protocol/id.go index 3e6346b6b..271315e53 100644 --- a/common/protocol/id.go +++ b/common/protocol/id.go @@ -57,3 +57,15 @@ func NewID(uuid *uuid.UUID) *ID { md5hash.Sum(id.cmdKey[:0]) return id } + +func NewAlterIDs(primary *ID, alterIDCount uint16) []*ID { + alterIDs := make([]*ID, alterIDCount) + prevID := primary.UUID() + for idx := range alterIDs { + newid := prevID.Next() + // TODO: check duplicates + alterIDs[idx] = NewID(newid) + prevID = newid + } + return alterIDs +} diff --git a/common/protocol/raw/encoding_test.go b/common/protocol/raw/encoding_test.go index 8e4aa0ffc..1b6712939 100644 --- a/common/protocol/raw/encoding_test.go +++ b/common/protocol/raw/encoding_test.go @@ -18,8 +18,8 @@ func TestRequestSerialization(t *testing.T) { user := protocol.NewUser( protocol.NewID(uuid.New()), + nil, protocol.UserLevelUntrusted, - 0, "test@v2ray.com") expectedRequest := &protocol.RequestHeader{ diff --git a/common/protocol/user.go b/common/protocol/user.go index b4363ed1d..4c765f6c7 100644 --- a/common/protocol/user.go +++ b/common/protocol/user.go @@ -18,23 +18,13 @@ type User struct { Email string } -func NewUser(id *ID, level UserLevel, alterIdCount uint16, email string) *User { - u := &User{ - ID: id, - Level: level, - Email: email, +func NewUser(primary *ID, secondary []*ID, level UserLevel, email string) *User { + return &User{ + ID: primary, + AlterIDs: secondary, + Level: level, + Email: email, } - if alterIdCount > 0 { - u.AlterIDs = make([]*ID, alterIdCount) - prevId := id.UUID() - for idx := range u.AlterIDs { - newid := prevId.Next() - // TODO: check duplicate - u.AlterIDs[idx] = NewID(newid) - prevId = newid - } - } - return u } func (this *User) AnyValidID() *ID { diff --git a/common/protocol/user_json.go b/common/protocol/user_json.go index 175772440..6b1351dda 100644 --- a/common/protocol/user_json.go +++ b/common/protocol/user_json.go @@ -23,7 +23,9 @@ func (u *User) UnmarshalJSON(data []byte) error { if err != nil { return err } - *u = *NewUser(NewID(id), UserLevel(rawUserValue.LevelByte), rawUserValue.AlterIdCount, rawUserValue.EmailString) + primaryID := NewID(id) + alterIDs := NewAlterIDs(primaryID, rawUserValue.AlterIdCount) + *u = *NewUser(primaryID, alterIDs, UserLevel(rawUserValue.LevelByte), rawUserValue.EmailString) return nil } diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 0bb551418..bab194ece 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -50,7 +50,8 @@ func (this *userByEmail) Get(email string) (*protocol.User, bool) { user, found = this.cache[email] if !found { id := protocol.NewID(uuid.New()) - user = protocol.NewUser(id, this.defaultLevel, this.defaultAlterIDs, email) + alterIDs := protocol.NewAlterIDs(id, this.defaultAlterIDs) + user = protocol.NewUser(id, alterIDs, this.defaultLevel, email) this.cache[email] = user } this.Unlock() diff --git a/proxy/vmess/outbound/command.go b/proxy/vmess/outbound/command.go index f775f469a..f7fabf6e0 100644 --- a/proxy/vmess/outbound/command.go +++ b/proxy/vmess/outbound/command.go @@ -6,7 +6,9 @@ import ( ) func (this *VMessOutboundHandler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) { - user := protocol.NewUser(protocol.NewID(cmd.ID), cmd.Level, cmd.AlterIds.Value(), "") + primary := protocol.NewID(cmd.ID) + alters := protocol.NewAlterIDs(primary, cmd.AlterIds.Value()) + user := protocol.NewUser(primary, alters, cmd.Level, "") dest := v2net.TCPDestination(cmd.Host, cmd.Port) this.receiverManager.AddDetour(NewReceiver(dest, user), cmd.ValidMin) } diff --git a/proxy/vmess/outbound/receiver_test.go b/proxy/vmess/outbound/receiver_test.go index f0b2e1431..53fe8a0e5 100644 --- a/proxy/vmess/outbound/receiver_test.go +++ b/proxy/vmess/outbound/receiver_test.go @@ -15,13 +15,15 @@ func TestReceiverUser(t *testing.T) { v2testing.Current(t) id := protocol.NewID(uuid.New()) - user := protocol.NewUser(id, protocol.UserLevel(0), 100, "") + alters := protocol.NewAlterIDs(id, 100) + user := protocol.NewUser(id, alters, protocol.UserLevel(0), "") rec := NewReceiver(v2net.TCPDestination(v2net.DomainAddress("v2ray.com"), 80), user) assert.Bool(rec.HasUser(user)).IsTrue() assert.Int(len(rec.Accounts)).Equals(1) id2 := protocol.NewID(uuid.New()) - user2 := protocol.NewUser(id2, protocol.UserLevel(0), 100, "") + alters2 := protocol.NewAlterIDs(id2, 100) + user2 := protocol.NewUser(id2, alters2, protocol.UserLevel(0), "") assert.Bool(rec.HasUser(user2)).IsFalse() rec.AddUser(user2)