diff --git a/id.go b/id.go index dd4bc5c97..6d09a6a46 100644 --- a/id.go +++ b/id.go @@ -54,7 +54,10 @@ func (v ID) TimeHash(timeSec int64) []byte { } func (v ID) Hash(data []byte) []byte { - return v.hasher.Sum(data) + v.hasher.Write(data) + hash := v.hasher.Sum(nil) + v.hasher.Reset() + return hash } var byteGroups = []int{8, 4, 4, 4, 12} diff --git a/id_test.go b/id_test.go index 1d9d74f04..72fa958e0 100644 --- a/id_test.go +++ b/id_test.go @@ -12,6 +12,6 @@ func TestUUIDToID(t *testing.T) { uuid := "2418d087-648d-4990-86e8-19dca1d006d3" expectedBytes := []byte{0x24, 0x18, 0xd0, 0x87, 0x64, 0x8d, 0x49, 0x90, 0x86, 0xe8, 0x19, 0xdc, 0xa1, 0xd0, 0x06, 0xd3} - actualBytes, _ := UUIDToID(uuid) - assert.Bytes(actualBytes.Bytes()).Named("UUID").Equals(expectedBytes) + actualBytes, _ := NewID(uuid) + assert.Bytes(actualBytes.Bytes).Named("UUID").Equals(expectedBytes) } diff --git a/io/vmess/vmess.go b/io/vmess/vmess.go index dbaab0389..1aea3d58e 100644 --- a/io/vmess/vmess.go +++ b/io/vmess/vmess.go @@ -49,10 +49,10 @@ type VMessRequest struct { } type VMessRequestReader struct { - vUserSet *core.UserSet + vUserSet core.UserSet } -func NewVMessRequestReader(vUserSet *core.UserSet) *VMessRequestReader { +func NewVMessRequestReader(vUserSet core.UserSet) *VMessRequestReader { reader := new(VMessRequestReader) reader.vUserSet = vUserSet return reader @@ -74,7 +74,7 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { return nil, err } - userId, valid := r.vUserSet.IsValidUserId(buffer[:nBytes]) + userId, valid := r.vUserSet.GetUser(buffer[:nBytes]) if !valid { return nil, ErrorInvalidUser } diff --git a/io/vmess/vmess_test.go b/io/vmess/vmess_test.go index 43a37cfed..1d1900da5 100644 --- a/io/vmess/vmess_test.go +++ b/io/vmess/vmess_test.go @@ -7,18 +7,19 @@ import ( "github.com/v2ray/v2ray-core" v2net "github.com/v2ray/v2ray-core/net" + "github.com/v2ray/v2ray-core/testing/mocks" "github.com/v2ray/v2ray-core/testing/unit" ) func TestVMessSerialization(t *testing.T) { assert := unit.Assert(t) - userId, err := core.UUIDToID("2b2966ac-16aa-4fbf-8d81-c5f172a3da51") + userId, err := core.NewID("2b2966ac-16aa-4fbf-8d81-c5f172a3da51") if err != nil { t.Fatal(err) } - userSet := core.NewUserSet() + userSet := mocks.MockUserSet{[]core.ID{}, make(map[string]int)} userSet.AddUser(core.User{userId}) request := new(VMessRequest) @@ -50,14 +51,16 @@ func TestVMessSerialization(t *testing.T) { t.Fatal(err) } - requestReader := NewVMessRequestReader(userSet) + userSet.UserHashes[string(buffer.Bytes()[1:17])] = 0 + + requestReader := NewVMessRequestReader(&userSet) actualRequest, err := requestReader.Read(buffer) if err != nil { t.Fatal(err) } assert.Byte(actualRequest.Version).Named("Version").Equals(byte(0x01)) - assert.Bytes(actualRequest.UserId[:]).Named("UserId").Equals(request.UserId[:]) + assert.String(actualRequest.UserId.String).Named("UserId").Equals(request.UserId.String) assert.Bytes(actualRequest.RequestIV[:]).Named("RequestIV").Equals(request.RequestIV[:]) assert.Bytes(actualRequest.RequestKey[:]).Named("RequestKey").Equals(request.RequestKey[:]) assert.Bytes(actualRequest.ResponseHeader[:]).Named("ResponseHeader").Equals(request.ResponseHeader[:]) diff --git a/net/vmess/vmessin.go b/net/vmess/vmessin.go index 8cc036953..09d0bbf51 100644 --- a/net/vmess/vmessin.go +++ b/net/vmess/vmessin.go @@ -15,11 +15,11 @@ import ( type VMessInboundHandler struct { vPoint *core.Point - clients *core.UserSet + clients core.UserSet accepting bool } -func NewVMessInboundHandler(vp *core.Point, clients *core.UserSet) *VMessInboundHandler { +func NewVMessInboundHandler(vp *core.Point, clients core.UserSet) *VMessInboundHandler { handler := new(VMessInboundHandler) handler.vPoint = vp handler.clients = clients @@ -121,7 +121,7 @@ func (factory *VMessInboundHandlerFactory) Create(vp *core.Point, rawConfig []by if err != nil { panic(log.Error("Failed to load VMess inbound config: %v", err)) } - allowedClients := core.NewUserSet() + allowedClients := core.NewTimedUserSet() for _, client := range config.AllowedClients { user, err := client.ToUser() if err != nil { diff --git a/testing/mocks/mockuserset.go b/testing/mocks/mockuserset.go new file mode 100644 index 000000000..5c994118b --- /dev/null +++ b/testing/mocks/mockuserset.go @@ -0,0 +1,23 @@ +package mocks + +import ( + "github.com/v2ray/v2ray-core" +) + +type MockUserSet struct { + UserIds []core.ID + UserHashes map[string]int +} + +func (us *MockUserSet) AddUser(user core.User) error { + us.UserIds = append(us.UserIds, user.Id) + return nil +} + +func (us *MockUserSet) GetUser(userhash []byte) (*core.ID, bool) { + idx, found := us.UserHashes[string(userhash)] + if found { + return &us.UserIds[idx], true + } + return nil, false +} diff --git a/userset.go b/userset.go index f3e98d9eb..4c466ed2f 100644 --- a/userset.go +++ b/userset.go @@ -9,7 +9,12 @@ const ( cacheDurationSec = 120 ) -type UserSet struct { +type UserSet interface { + AddUser(user User) error + GetUser(timeHash []byte) (*ID, bool) +} + +type TimedUserSet struct { validUserIds []ID userHashes map[string]int } @@ -19,8 +24,8 @@ type hashEntry struct { timeSec int64 } -func NewUserSet() *UserSet { - vuSet := new(UserSet) +func NewTimedUserSet() UserSet { + vuSet := new(TimedUserSet) vuSet.validUserIds = make([]ID, 0, 16) vuSet.userHashes = make(map[string]int) @@ -28,7 +33,7 @@ func NewUserSet() *UserSet { return vuSet } -func (us *UserSet) updateUserHash(tick <-chan time.Time) { +func (us *TimedUserSet) updateUserHash(tick <-chan time.Time) { now := time.Now().UTC() lastSec := now.Unix() - cacheDurationSec @@ -57,13 +62,13 @@ func (us *UserSet) updateUserHash(tick <-chan time.Time) { } } -func (us *UserSet) AddUser(user User) error { +func (us *TimedUserSet) AddUser(user User) error { id := user.Id us.validUserIds = append(us.validUserIds, id) return nil } -func (us UserSet) IsValidUserId(userHash []byte) (*ID, bool) { +func (us TimedUserSet) GetUser(userHash []byte) (*ID, bool) { idIndex, found := us.userHashes[string(userHash)] if found { return &us.validUserIds[idIndex], true