diff --git a/common/collect/timed_map.go b/common/collect/timed_map.go new file mode 100644 index 000000000..0b017c086 --- /dev/null +++ b/common/collect/timed_map.go @@ -0,0 +1,106 @@ +package collect + +import ( + "container/heap" + "sync" + "time" +) + +type timedQueueEntry struct { + timeSec int64 + value interface{} +} + +type timedQueue []*timedQueueEntry + +func (queue timedQueue) Len() int { + return len(queue) +} + +func (queue timedQueue) Less(i, j int) bool { + return queue[i].timeSec < queue[j].timeSec +} + +func (queue timedQueue) Swap(i, j int) { + tmp := queue[i] + queue[i] = queue[j] + queue[j] = tmp +} + +func (queue *timedQueue) Push(value interface{}) { + entry := value.(*timedQueueEntry) + *queue = append(*queue, entry) +} + +func (queue *timedQueue) Pop() interface{} { + old := *queue + n := len(old) + v := old[n-1] + *queue = old[:n-1] + return v +} + +type TimedStringMap struct { + timedQueue + access sync.RWMutex + data map[string]interface{} + interval int +} + +func NewTimedStringMap(updateInterval int) *TimedStringMap { + m := &TimedStringMap{ + timedQueue: make([]*timedQueueEntry, 0, 1024), + access: sync.RWMutex{}, + data: make(map[string]interface{}, 1024), + interval: updateInterval, + } + m.initialize() + return m +} + +func (m *TimedStringMap) initialize() { + go m.cleanup(time.Tick(time.Duration(m.interval) * time.Second)) +} + +func (m *TimedStringMap) cleanup(tick <-chan time.Time) { + for { + now := <-tick + nowSec := now.UTC().Unix() + if m.timedQueue.Len() == 0 { + continue + } + for m.timedQueue.Len() > 0 { + entry := m.timedQueue[0] + if entry.timeSec > nowSec { + break + } + m.access.Lock() + entry = heap.Pop(&m.timedQueue).(*timedQueueEntry) + m.access.Unlock() + m.Remove(entry.value.(string)) + } + } +} + +func (m *TimedStringMap) Get(key string) (interface{}, bool) { + m.access.RLock() + value, ok := m.data[key] + m.access.RUnlock() + return value, ok +} + +func (m *TimedStringMap) Set(key string, value interface{}, time2Delete int64) { + m.access.Lock() + m.data[key] = value + heap.Push(&m.timedQueue, &timedQueueEntry{ + timeSec: time2Delete, + value: key, + }) + m.access.Unlock() +} + +func (m *TimedStringMap) Remove(key string) { + m.access.Lock() + delete(m.data, key) + m.access.Unlock() +} diff --git a/proxy/vmess/protocol/user/userset.go b/proxy/vmess/protocol/user/userset.go index 6088f685d..0575382cc 100644 --- a/proxy/vmess/protocol/user/userset.go +++ b/proxy/vmess/protocol/user/userset.go @@ -1,9 +1,9 @@ package user import ( - "container/heap" "time" + "github.com/v2ray/v2ray-core/common/collect" "github.com/v2ray/v2ray-core/common/log" ) @@ -19,8 +19,7 @@ type UserSet interface { type TimedUserSet struct { validUserIds []ID - userHashes map[string]indexTimePair - hash2Remove hashEntrySet + userHash *collect.TimedStringMap } type indexTimePair struct { @@ -28,58 +27,21 @@ type indexTimePair struct { timeSec int64 } -type hashEntry struct { - hash string - timeSec int64 -} - -type hashEntrySet []*hashEntry - -func (set hashEntrySet) Len() int { - return len(set) -} - -func (set hashEntrySet) Less(i, j int) bool { - return set[i].timeSec < set[j].timeSec -} - -func (set hashEntrySet) Swap(i, j int) { - tmp := set[i] - set[i] = set[j] - set[j] = tmp -} - -func (set *hashEntrySet) Push(value interface{}) { - entry := value.(*hashEntry) - *set = append(*set, entry) -} - -func (set *hashEntrySet) Pop() interface{} { - old := *set - n := len(old) - v := old[n-1] - *set = old[:n-1] - return v -} - func NewTimedUserSet() UserSet { - vuSet := new(TimedUserSet) - vuSet.validUserIds = make([]ID, 0, 16) - vuSet.userHashes = make(map[string]indexTimePair) - vuSet.hash2Remove = make(hashEntrySet, 0, cacheDurationSec*10) - - go vuSet.updateUserHash(time.Tick(updateIntervalSec * time.Second)) - return vuSet + tus := &TimedUserSet{ + validUserIds: make([]ID, 0, 16), + userHash: collect.NewTimedStringMap(updateIntervalSec), + } + go tus.updateUserHash(time.Tick(updateIntervalSec * time.Second)) + return tus } func (us *TimedUserSet) generateNewHashes(lastSec, nowSec int64, idx int, id ID) { idHash := NewTimeHash(HMACHash{}) for lastSec < nowSec+cacheDurationSec { - idHash := idHash.Hash(id.Bytes, lastSec) log.Debug("Valid User Hash: %v", idHash) - heap.Push(&us.hash2Remove, &hashEntry{string(idHash), lastSec}) - us.userHashes[string(idHash)] = indexTimePair{idx, lastSec} + us.userHash.Set(string(idHash), indexTimePair{idx, lastSec}, lastSec+2*cacheDurationSec) lastSec++ } } @@ -87,24 +49,14 @@ func (us *TimedUserSet) generateNewHashes(lastSec, nowSec int64, idx int, id ID) func (us *TimedUserSet) updateUserHash(tick <-chan time.Time) { now := time.Now().UTC() lastSec := now.Unix() - lastSec2Remove := now.Unix() for { now := <-tick nowSec := now.UTC().Unix() - - remove2Sec := nowSec - cacheDurationSec - if remove2Sec > lastSec2Remove { - for lastSec2Remove+1 < remove2Sec { - front := heap.Pop(&us.hash2Remove) - entry := front.(*hashEntry) - lastSec2Remove = entry.timeSec - delete(us.userHashes, entry.hash) - } - } for idx, id := range us.validUserIds { us.generateNewHashes(lastSec, nowSec, idx, id) } + lastSec = nowSec } } @@ -121,8 +73,9 @@ func (us *TimedUserSet) AddUser(user User) error { } func (us TimedUserSet) GetUser(userHash []byte) (*ID, int64, bool) { - pair, found := us.userHashes[string(userHash)] + rawPair, found := us.userHash.Get(string(userHash)) if found { + pair := rawPair.(indexTimePair) return &us.validUserIds[pair.index], pair.timeSec, true } return nil, 0, false