diff --git a/common/protocol/user_validator.go b/common/protocol/user_validator.go new file mode 100644 index 000000000..e8dc3526b --- /dev/null +++ b/common/protocol/user_validator.go @@ -0,0 +1,123 @@ +package protocol + +import ( + "hash" + "sync" + "time" +) + +const ( + updateIntervalSec = 10 + cacheDurationSec = 120 +) + +type IDHash func(key []byte) hash.Hash + +type idEntry struct { + id *ID + userIdx int + lastSec Timestamp + lastSecRemoval Timestamp +} + +type UserValidator interface { + Add(user *User) error + Get(timeHash []byte) (*User, Timestamp, bool) +} + +type TimedUserValidator struct { + validUsers []*User + userHash map[[16]byte]*indexTimePair + ids []*idEntry + access sync.RWMutex + hasher IDHash +} + +type indexTimePair struct { + index int + timeSec Timestamp +} + +func NewTimedUserValidator(hasher IDHash) UserValidator { + tus := &TimedUserValidator{ + validUsers: make([]*User, 0, 16), + userHash: make(map[[16]byte]*indexTimePair, 512), + access: sync.RWMutex{}, + ids: make([]*idEntry, 0, 512), + hasher: hasher, + } + go tus.updateUserHash(time.Tick(updateIntervalSec * time.Second)) + return tus +} + +func (this *TimedUserValidator) generateNewHashes(nowSec Timestamp, idx int, entry *idEntry) { + var hashValue [16]byte + var hashValueRemoval [16]byte + idHash := this.hasher(entry.id.Bytes()) + for entry.lastSec <= nowSec { + idHash.Write(entry.lastSec.Bytes()) + idHash.Sum(hashValue[:0]) + idHash.Reset() + + idHash.Write(entry.lastSecRemoval.Bytes()) + idHash.Sum(hashValueRemoval[:0]) + idHash.Reset() + + this.access.Lock() + this.userHash[hashValue] = &indexTimePair{idx, entry.lastSec} + delete(this.userHash, hashValueRemoval) + this.access.Unlock() + + entry.lastSec++ + entry.lastSecRemoval++ + } +} + +func (this *TimedUserValidator) updateUserHash(tick <-chan time.Time) { + for now := range tick { + nowSec := Timestamp(now.Unix() + cacheDurationSec) + for _, entry := range this.ids { + this.generateNewHashes(nowSec, entry.userIdx, entry) + } + } +} + +func (this *TimedUserValidator) Add(user *User) error { + idx := len(this.validUsers) + this.validUsers = append(this.validUsers, user) + + nowSec := time.Now().Unix() + + entry := &idEntry{ + id: user.ID, + userIdx: idx, + lastSec: Timestamp(nowSec - cacheDurationSec), + lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3), + } + this.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry) + this.ids = append(this.ids, entry) + for _, alterid := range user.AlterIDs { + entry := &idEntry{ + id: alterid, + userIdx: idx, + lastSec: Timestamp(nowSec - cacheDurationSec), + lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3), + } + this.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry) + this.ids = append(this.ids, entry) + } + + return nil +} + +func (this *TimedUserValidator) Get(userHash []byte) (*User, Timestamp, bool) { + defer this.access.RUnlock() + this.access.RLock() + var fixedSizeHash [16]byte + copy(fixedSizeHash[:], userHash) + pair, found := this.userHash[fixedSizeHash] + if found { + return this.validUsers[pair.index], pair.timeSec, true + } + return nil, 0, false +} diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 15c9319cc..2b6ea9391 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -66,7 +66,7 @@ type VMessInboundHandler struct { sync.Mutex packetDispatcher dispatcher.PacketDispatcher inboundHandlerManager proxyman.InboundHandlerManager - clients protocol.UserSet + clients proto.UserValidator usersByEmail *userByEmail accepting bool listener *hub.TCPHub @@ -91,7 +91,7 @@ func (this *VMessInboundHandler) Close() { func (this *VMessInboundHandler) GetUser(email string) *proto.User { user, existing := this.usersByEmail.Get(email) if !existing { - this.clients.AddUser(user) + this.clients.Add(user) } return user } @@ -211,9 +211,9 @@ func init() { } config := rawConfig.(*Config) - allowedClients := protocol.NewTimedUserSet() + allowedClients := proto.NewTimedUserValidator(protocol.IDHash) for _, user := range config.AllowedUsers { - allowedClients.AddUser(user) + allowedClients.Add(user) } handler := &VMessInboundHandler{ diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index 03f7ecd8a..58138ff1c 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -14,6 +14,7 @@ import ( v2io "github.com/v2ray/v2ray-core/common/io" "github.com/v2ray/v2ray-core/common/log" v2net "github.com/v2ray/v2ray-core/common/net" + proto "github.com/v2ray/v2ray-core/common/protocol" "github.com/v2ray/v2ray-core/proxy" "github.com/v2ray/v2ray-core/proxy/internal" vmessio "github.com/v2ray/v2ray-core/proxy/vmess/io" @@ -106,7 +107,7 @@ func (this *VMessOutboundHandler) handleRequest(conn net.Conn, request *protocol buffer := alloc.NewBuffer().Clear() defer buffer.Release() - buffer, err = request.ToBytes(protocol.NewRandomTimestampGenerator(protocol.Timestamp(time.Now().Unix()), 30), buffer) + buffer, err = request.ToBytes(protocol.NewRandomTimestampGenerator(proto.Timestamp(time.Now().Unix()), 30), buffer) if err != nil { log.Error("VMessOut: Failed to serialize VMess request: ", err) return diff --git a/proxy/vmess/protocol/rand.go b/proxy/vmess/protocol/rand.go index d9e1c2528..ba6f75be0 100644 --- a/proxy/vmess/protocol/rand.go +++ b/proxy/vmess/protocol/rand.go @@ -2,25 +2,27 @@ package protocol import ( "math/rand" + + "github.com/v2ray/v2ray-core/common/protocol" ) type RandomTimestampGenerator interface { - Next() Timestamp + Next() protocol.Timestamp } type RealRandomTimestampGenerator struct { - base Timestamp + base protocol.Timestamp delta int } -func NewRandomTimestampGenerator(base Timestamp, delta int) RandomTimestampGenerator { +func NewRandomTimestampGenerator(base protocol.Timestamp, delta int) RandomTimestampGenerator { return &RealRandomTimestampGenerator{ base: base, delta: delta, } } -func (this *RealRandomTimestampGenerator) Next() Timestamp { +func (this *RealRandomTimestampGenerator) Next() protocol.Timestamp { rangeInDelta := rand.Intn(this.delta*2) - this.delta - return this.base + Timestamp(rangeInDelta) + return this.base + protocol.Timestamp(rangeInDelta) } diff --git a/proxy/vmess/protocol/rand_test.go b/proxy/vmess/protocol/rand_test.go index 2e502cde9..137e8b356 100644 --- a/proxy/vmess/protocol/rand_test.go +++ b/proxy/vmess/protocol/rand_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/v2ray/v2ray-core/common/protocol" . "github.com/v2ray/v2ray-core/proxy/vmess/protocol" v2testing "github.com/v2ray/v2ray-core/testing" "github.com/v2ray/v2ray-core/testing/assert" @@ -14,7 +15,7 @@ func TestGenerateRandomInt64InRange(t *testing.T) { base := time.Now().Unix() delta := 100 - generator := NewRandomTimestampGenerator(Timestamp(base), delta) + generator := NewRandomTimestampGenerator(protocol.Timestamp(base), delta) for i := 0; i < 100; i++ { v := int64(generator.Next()) diff --git a/proxy/vmess/protocol/testing/mockuserset.go b/proxy/vmess/protocol/testing/mockuserset.go index e12c1f772..28a54dfb2 100644 --- a/proxy/vmess/protocol/testing/mockuserset.go +++ b/proxy/vmess/protocol/testing/mockuserset.go @@ -1,22 +1,21 @@ package mocks import ( - proto "github.com/v2ray/v2ray-core/common/protocol" - "github.com/v2ray/v2ray-core/proxy/vmess/protocol" + "github.com/v2ray/v2ray-core/common/protocol" ) type MockUserSet struct { - Users []*proto.User + Users []*protocol.User UserHashes map[string]int Timestamps map[string]protocol.Timestamp } -func (us *MockUserSet) AddUser(user *proto.User) error { +func (us *MockUserSet) Add(user *protocol.User) error { us.Users = append(us.Users, user) return nil } -func (us *MockUserSet) GetUser(userhash []byte) (*proto.User, protocol.Timestamp, bool) { +func (us *MockUserSet) Get(userhash []byte) (*protocol.User, protocol.Timestamp, bool) { idx, found := us.UserHashes[string(userhash)] if found { return us.Users[idx], us.Timestamps[string(userhash)], true diff --git a/proxy/vmess/protocol/testing/static_userset.go b/proxy/vmess/protocol/testing/static_userset.go index b0fa8c2bc..6289854bd 100644 --- a/proxy/vmess/protocol/testing/static_userset.go +++ b/proxy/vmess/protocol/testing/static_userset.go @@ -1,21 +1,20 @@ package mocks import ( - proto "github.com/v2ray/v2ray-core/common/protocol" + "github.com/v2ray/v2ray-core/common/protocol" "github.com/v2ray/v2ray-core/common/uuid" - "github.com/v2ray/v2ray-core/proxy/vmess/protocol" ) type StaticUserSet struct { } -func (us *StaticUserSet) AddUser(user *proto.User) error { +func (us *StaticUserSet) Add(user *protocol.User) error { return nil } -func (us *StaticUserSet) GetUser(userhash []byte) (*proto.User, protocol.Timestamp, bool) { +func (us *StaticUserSet) Get(userhash []byte) (*protocol.User, protocol.Timestamp, bool) { id, _ := uuid.ParseString("703e9102-eb57-499c-8b59-faf4f371bb21") - return &proto.User{ - ID: proto.NewID(id), + return &protocol.User{ + ID: protocol.NewID(id), }, 0, true } diff --git a/proxy/vmess/protocol/userset.go b/proxy/vmess/protocol/userset.go deleted file mode 100644 index ef4dc30bb..000000000 --- a/proxy/vmess/protocol/userset.go +++ /dev/null @@ -1,137 +0,0 @@ -package protocol - -import ( - "sync" - "time" - - proto "github.com/v2ray/v2ray-core/common/protocol" - "github.com/v2ray/v2ray-core/common/serial" -) - -const ( - updateIntervalSec = 10 - cacheDurationSec = 120 -) - -type Timestamp int64 - -func (this Timestamp) Bytes() []byte { - return serial.Int64Literal(this).Bytes() -} - -func (this Timestamp) HashBytes() []byte { - once := this.Bytes() - bytes := make([]byte, 0, 32) - bytes = append(bytes, once...) - bytes = append(bytes, once...) - bytes = append(bytes, once...) - bytes = append(bytes, once...) - return bytes -} - -type idEntry struct { - id *proto.ID - userIdx int - lastSec Timestamp - lastSecRemoval Timestamp -} - -type UserSet interface { - AddUser(user *proto.User) error - GetUser(timeHash []byte) (*proto.User, Timestamp, bool) -} - -type TimedUserSet struct { - validUsers []*proto.User - userHash map[[16]byte]*indexTimePair - ids []*idEntry - access sync.RWMutex -} - -type indexTimePair struct { - index int - timeSec Timestamp -} - -func NewTimedUserSet() UserSet { - tus := &TimedUserSet{ - validUsers: make([]*proto.User, 0, 16), - userHash: make(map[[16]byte]*indexTimePair, 512), - access: sync.RWMutex{}, - ids: make([]*idEntry, 0, 512), - } - go tus.updateUserHash(time.Tick(updateIntervalSec * time.Second)) - return tus -} - -func (us *TimedUserSet) generateNewHashes(nowSec Timestamp, idx int, entry *idEntry) { - var hashValue [16]byte - var hashValueRemoval [16]byte - idHash := IDHash(entry.id.Bytes()) - for entry.lastSec <= nowSec { - idHash.Write(entry.lastSec.Bytes()) - idHash.Sum(hashValue[:0]) - idHash.Reset() - - idHash.Write(entry.lastSecRemoval.Bytes()) - idHash.Sum(hashValueRemoval[:0]) - idHash.Reset() - - us.access.Lock() - us.userHash[hashValue] = &indexTimePair{idx, entry.lastSec} - delete(us.userHash, hashValueRemoval) - us.access.Unlock() - - entry.lastSec++ - entry.lastSecRemoval++ - } -} - -func (us *TimedUserSet) updateUserHash(tick <-chan time.Time) { - for now := range tick { - nowSec := Timestamp(now.Unix() + cacheDurationSec) - for _, entry := range us.ids { - us.generateNewHashes(nowSec, entry.userIdx, entry) - } - } -} - -func (us *TimedUserSet) AddUser(user *proto.User) error { - idx := len(us.validUsers) - us.validUsers = append(us.validUsers, user) - - nowSec := time.Now().Unix() - - entry := &idEntry{ - id: user.ID, - userIdx: idx, - lastSec: Timestamp(nowSec - cacheDurationSec), - lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3), - } - us.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry) - us.ids = append(us.ids, entry) - for _, alterid := range user.AlterIDs { - entry := &idEntry{ - id: alterid, - userIdx: idx, - lastSec: Timestamp(nowSec - cacheDurationSec), - lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3), - } - us.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry) - us.ids = append(us.ids, entry) - } - - return nil -} - -func (us *TimedUserSet) GetUser(userHash []byte) (*proto.User, Timestamp, bool) { - defer us.access.RUnlock() - us.access.RLock() - var fixedSizeHash [16]byte - copy(fixedSizeHash[:], userHash) - pair, found := us.userHash[fixedSizeHash] - if found { - return us.validUsers[pair.index], pair.timeSec, true - } - return nil, 0, false -} diff --git a/proxy/vmess/protocol/vmess.go b/proxy/vmess/protocol/vmess.go index daa2d3947..d569985ad 100644 --- a/proxy/vmess/protocol/vmess.go +++ b/proxy/vmess/protocol/vmess.go @@ -31,6 +31,16 @@ const ( blockSize = 16 ) +func hashTimestamp(t proto.Timestamp) []byte { + once := t.Bytes() + bytes := make([]byte, 0, 32) + bytes = append(bytes, once...) + bytes = append(bytes, once...) + bytes = append(bytes, once...) + bytes = append(bytes, once...) + return bytes +} + // VMessRequest implements the request message of VMess protocol. It only contains the header of a // request message. The data part will be handled by connection handler directly, in favor of data // streaming. @@ -61,11 +71,11 @@ func (this *VMessRequest) IsChunkStream() bool { // VMessRequestReader is a parser to read VMessRequest from a byte stream. type VMessRequestReader struct { - vUserSet UserSet + vUserSet proto.UserValidator } // NewVMessRequestReader creates a new VMessRequestReader with a given UserSet -func NewVMessRequestReader(vUserSet UserSet) *VMessRequestReader { +func NewVMessRequestReader(vUserSet proto.UserValidator) *VMessRequestReader { return &VMessRequestReader{ vUserSet: vUserSet, } @@ -82,13 +92,13 @@ func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { return nil, err } - userObj, timeSec, valid := this.vUserSet.GetUser(buffer.Value[:nBytes]) + userObj, timeSec, valid := this.vUserSet.Get(buffer.Value[:nBytes]) if !valid { return nil, proxy.ErrorInvalidAuthentication } timestampHash := TimestampHash() - timestampHash.Write(timeSec.HashBytes()) + timestampHash.Write(hashTimestamp(timeSec)) iv := timestampHash.Sum(nil) aesStream, err := v2crypto.NewAesDecryptionStream(userObj.ID.CmdKey(), iv) if err != nil { @@ -223,7 +233,7 @@ func (this *VMessRequest) ToBytes(timestampGenerator RandomTimestampGenerator, b encryptionEnd += 4 timestampHash := md5.New() - timestampHash.Write(timestamp.HashBytes()) + timestampHash.Write(hashTimestamp(timestamp)) iv := timestampHash.Sum(nil) aesStream, err := v2crypto.NewAesEncryptionStream(this.User.ID.CmdKey(), iv) if err != nil { diff --git a/proxy/vmess/protocol/vmess_test.go b/proxy/vmess/protocol/vmess_test.go index 1e3b7411e..34f706a14 100644 --- a/proxy/vmess/protocol/vmess_test.go +++ b/proxy/vmess/protocol/vmess_test.go @@ -17,10 +17,10 @@ import ( ) type FakeTimestampGenerator struct { - timestamp Timestamp + timestamp proto.Timestamp } -func (this *FakeTimestampGenerator) Next() Timestamp { +func (this *FakeTimestampGenerator) Next() proto.Timestamp { return this.timestamp } @@ -36,8 +36,8 @@ func TestVMessSerialization(t *testing.T) { ID: userId, } - userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]Timestamp)} - userSet.AddUser(testUser) + userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]proto.Timestamp)} + userSet.Add(testUser) request := new(VMessRequest) request.Version = byte(0x01) @@ -54,7 +54,7 @@ func TestVMessSerialization(t *testing.T) { request.Address = v2net.DomainAddress("v2ray.com") request.Port = v2net.Port(80) - mockTime := Timestamp(1823730) + mockTime := proto.Timestamp(1823730) buffer, err := request.ToBytes(&FakeTimestampGenerator{timestamp: mockTime}, nil) if err != nil { @@ -92,12 +92,12 @@ func BenchmarkVMessRequestWriting(b *testing.B) { assert.Error(err).IsNil() userId := proto.NewID(id) - userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]Timestamp)} + userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]proto.Timestamp)} testUser := &proto.User{ ID: userId, } - userSet.AddUser(testUser) + userSet.Add(testUser) request := new(VMessRequest) request.Version = byte(0x01) @@ -114,6 +114,6 @@ func BenchmarkVMessRequestWriting(b *testing.B) { request.Port = v2net.Port(80) for i := 0; i < b.N; i++ { - request.ToBytes(NewRandomTimestampGenerator(Timestamp(time.Now().Unix()), 30), nil) + request.ToBytes(NewRandomTimestampGenerator(proto.Timestamp(time.Now().Unix()), 30), nil) } }