From 8f0cb97e895cf4b98d8bbe2dd3517173df0d10d1 Mon Sep 17 00:00:00 2001 From: V2Ray Date: Mon, 28 Sep 2015 01:11:40 +0200 Subject: [PATCH] Refactor timed queue --- common/collect/timed_map.go | 111 --------------------------- common/collect/timed_map_test.go | 48 ------------ common/collect/timed_queue.go | 89 +++++++++++++++++++++ common/collect/timed_queue_test.go | 60 +++++++++++++++ proxy/vmess/protocol/user/userset.go | 33 ++++++-- 5 files changed, 175 insertions(+), 166 deletions(-) delete mode 100644 common/collect/timed_map.go delete mode 100644 common/collect/timed_map_test.go create mode 100644 common/collect/timed_queue.go create mode 100644 common/collect/timed_queue_test.go diff --git a/common/collect/timed_map.go b/common/collect/timed_map.go deleted file mode 100644 index 47f4d9e53..000000000 --- a/common/collect/timed_map.go +++ /dev/null @@ -1,111 +0,0 @@ -package collect - -import ( - "container/heap" - "sync" - "time" -) - -type timedQueueEntry struct { - timeSec int64 - value interface{} -} - -type timedQueue []*timedQueueEntry - -func (queue timedQueue) Len() int { - return len(queue) -} - -func (queue timedQueue) Less(i, j int) bool { - return queue[i].timeSec < queue[j].timeSec -} - -func (queue timedQueue) Swap(i, j int) { - tmp := queue[i] - queue[i] = queue[j] - queue[j] = tmp -} - -func (queue *timedQueue) Push(value interface{}) { - entry := value.(*timedQueueEntry) - *queue = append(*queue, entry) -} - -func (queue *timedQueue) Pop() interface{} { - old := *queue - n := len(old) - v := old[n-1] - *queue = old[:n-1] - return v -} - -type TimedStringMap struct { - timedQueue - queueMutex sync.Mutex - dataMutext sync.RWMutex - data map[string]interface{} - interval int -} - -func NewTimedStringMap(updateInterval int) *TimedStringMap { - m := &TimedStringMap{ - timedQueue: make([]*timedQueueEntry, 0, 1024), - queueMutex: sync.Mutex{}, - dataMutext: sync.RWMutex{}, - data: make(map[string]interface{}, 1024), - interval: updateInterval, - } - m.initialize() - return m -} - -func (m *TimedStringMap) initialize() { - go m.cleanup(time.Tick(time.Duration(m.interval) * time.Second)) -} - -func (m *TimedStringMap) cleanup(tick <-chan time.Time) { - for { - now := <-tick - nowSec := now.UTC().Unix() - if m.timedQueue.Len() == 0 { - continue - } - for m.timedQueue.Len() > 0 { - entry := m.timedQueue[0] - if entry.timeSec > nowSec { - break - } - m.queueMutex.Lock() - entry = heap.Pop(&m.timedQueue).(*timedQueueEntry) - m.queueMutex.Unlock() - m.Remove(entry.value.(string)) - } - } -} - -func (m *TimedStringMap) Get(key string) (interface{}, bool) { - m.dataMutext.RLock() - value, ok := m.data[key] - m.dataMutext.RUnlock() - return value, ok -} - -func (m *TimedStringMap) Set(key string, value interface{}, time2Delete int64) { - m.dataMutext.Lock() - m.data[key] = value - m.dataMutext.Unlock() - - m.queueMutex.Lock() - heap.Push(&m.timedQueue, &timedQueueEntry{ - timeSec: time2Delete, - value: key, - }) - m.queueMutex.Unlock() -} - -func (m *TimedStringMap) Remove(key string) { - m.dataMutext.Lock() - delete(m.data, key) - m.dataMutext.Unlock() -} diff --git a/common/collect/timed_map_test.go b/common/collect/timed_map_test.go deleted file mode 100644 index 4f24946f8..000000000 --- a/common/collect/timed_map_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package collect - -import ( - "testing" - "time" - - "github.com/v2ray/v2ray-core/testing/unit" -) - -func TestTimedStringMap(t *testing.T) { - assert := unit.Assert(t) - - nowSec := time.Now().UTC().Unix() - m := NewTimedStringMap(2) - m.Set("Key1", "Value1", nowSec) - m.Set("Key2", "Value2", nowSec+5) - - v1, ok := m.Get("Key1") - assert.Bool(ok).IsTrue() - assert.String(v1.(string)).Equals("Value1") - - v2, ok := m.Get("Key2") - assert.Bool(ok).IsTrue() - assert.String(v2.(string)).Equals("Value2") - - tick := time.Tick(4 * time.Second) - <-tick - - v1, ok = m.Get("Key1") - assert.Bool(ok).IsFalse() - - v2, ok = m.Get("Key2") - assert.Bool(ok).IsTrue() - assert.String(v2.(string)).Equals("Value2") - - <-tick - v2, ok = m.Get("Key2") - assert.Bool(ok).IsFalse() - - <-tick - v2, ok = m.Get("Key2") - assert.Bool(ok).IsFalse() - - m.Set("Key1", "Value1", time.Now().UTC().Unix()+10) - v1, ok = m.Get("Key1") - assert.Bool(ok).IsTrue() - assert.String(v1.(string)).Equals("Value1") -} diff --git a/common/collect/timed_queue.go b/common/collect/timed_queue.go new file mode 100644 index 000000000..4505db476 --- /dev/null +++ b/common/collect/timed_queue.go @@ -0,0 +1,89 @@ +package collect + +import ( + "container/heap" + "sync" + "time" +) + +type timedQueueEntry struct { + timeSec int64 + value interface{} +} + +type timedQueueImpl []*timedQueueEntry + +func (queue timedQueueImpl) Len() int { + return len(queue) +} + +func (queue timedQueueImpl) Less(i, j int) bool { + return queue[i].timeSec < queue[j].timeSec +} + +func (queue timedQueueImpl) Swap(i, j int) { + tmp := queue[i] + queue[i] = queue[j] + queue[j] = tmp +} + +func (queue *timedQueueImpl) Push(value interface{}) { + entry := value.(*timedQueueEntry) + *queue = append(*queue, entry) +} + +func (queue *timedQueueImpl) Pop() interface{} { + old := *queue + n := len(old) + v := old[n-1] + *queue = old[:n-1] + return v +} + +type TimedQueue struct { + queue timedQueueImpl + access sync.Mutex + removed chan interface{} +} + +func NewTimedQueue(updateInterval int) *TimedQueue { + queue := &TimedQueue{ + queue: make([]*timedQueueEntry, 0, 256), + removed: make(chan interface{}, 16), + access: sync.Mutex{}, + } + go queue.cleanup(time.Tick(time.Duration(updateInterval) * time.Second)) + return queue +} + +func (queue *TimedQueue) Add(value interface{}, time2Remove int64) { + queue.access.Lock() + heap.Push(&queue.queue, &timedQueueEntry{ + timeSec: time2Remove, + value: value, + }) + queue.access.Unlock() +} + +func (queue *TimedQueue) RemovedEntries() <-chan interface{} { + return queue.removed +} + +func (queue *TimedQueue) cleanup(tick <-chan time.Time) { + for { + now := <-tick + if queue.queue.Len() == 0 { + continue + } + nowSec := now.UTC().Unix() + entry := queue.queue[0] + if entry.timeSec > nowSec { + continue + } + queue.access.Lock() + entry = heap.Pop(&queue.queue).(*timedQueueEntry) + queue.access.Unlock() + + queue.removed <- entry.value + } +} diff --git a/common/collect/timed_queue_test.go b/common/collect/timed_queue_test.go new file mode 100644 index 000000000..fb3c5f99e --- /dev/null +++ b/common/collect/timed_queue_test.go @@ -0,0 +1,60 @@ +package collect + +import ( + "testing" + "time" + + "github.com/v2ray/v2ray-core/testing/unit" +) + +func TestTimedQueue(t *testing.T) { + assert := unit.Assert(t) + + removed := make(map[string]bool) + + nowSec := time.Now().UTC().Unix() + q := NewTimedQueue(2) + + go func() { + for { + entry := <-q.RemovedEntries() + removed[entry.(string)] = true + } + }() + + q.Add("Value1", nowSec) + q.Add("Value2", nowSec+5) + + v1, ok := removed["Value1"] + assert.Bool(ok).IsFalse() + + v2, ok := removed["Value2"] + assert.Bool(ok).IsFalse() + + tick := time.Tick(4 * time.Second) + <-tick + + v1, ok = removed["Value1"] + assert.Bool(ok).IsTrue() + assert.Bool(v1).IsTrue() + removed["Value1"] = false + + v2, ok = removed["Value2"] + assert.Bool(ok).IsFalse() + + <-tick + v2, ok = removed["Value2"] + assert.Bool(ok).IsTrue() + assert.Bool(v2).IsTrue() + removed["Value2"] = false + + <-tick + assert.Bool(removed["Values"]).IsFalse() + + q.Add("Value1", time.Now().UTC().Unix()+10) + + <-tick + v1, ok = removed["Value1"] + assert.Bool(ok).IsTrue() + assert.Bool(v1).IsFalse() +} diff --git a/proxy/vmess/protocol/user/userset.go b/proxy/vmess/protocol/user/userset.go index 2e2d2c9be..ca8c2ea21 100644 --- a/proxy/vmess/protocol/user/userset.go +++ b/proxy/vmess/protocol/user/userset.go @@ -1,6 +1,7 @@ package user import ( + "sync" "time" "github.com/v2ray/v2ray-core/common/collect" @@ -18,8 +19,10 @@ type UserSet interface { } type TimedUserSet struct { - validUserIds []ID - userHash *collect.TimedStringMap + validUserIds []ID + userHash map[string]indexTimePair + userHashDeleteQueue *collect.TimedQueue + access sync.RWMutex } type indexTimePair struct { @@ -29,19 +32,34 @@ type indexTimePair struct { func NewTimedUserSet() UserSet { tus := &TimedUserSet{ - validUserIds: make([]ID, 0, 16), - userHash: collect.NewTimedStringMap(updateIntervalSec), + validUserIds: make([]ID, 0, 16), + userHash: make(map[string]indexTimePair, 512), + userHashDeleteQueue: collect.NewTimedQueue(updateIntervalSec), + access: sync.RWMutex{}, } go tus.updateUserHash(time.Tick(updateIntervalSec * time.Second)) + go tus.removeEntries(tus.userHashDeleteQueue.RemovedEntries()) return tus } +func (us *TimedUserSet) removeEntries(entries <-chan interface{}) { + for { + entry := <-entries + us.access.Lock() + delete(us.userHash, entry.(string)) + us.access.Unlock() + } +} + func (us *TimedUserSet) generateNewHashes(lastSec, nowSec int64, idx int, id ID) { idHash := NewTimeHash(HMACHash{}) for lastSec < nowSec+cacheDurationSec { idHash := idHash.Hash(id.Bytes[:], lastSec) log.Debug("Valid User Hash: %v", idHash) - us.userHash.Set(string(idHash), indexTimePair{idx, lastSec}, lastSec+2*cacheDurationSec) + us.access.Lock() + us.userHash[string(idHash)] = indexTimePair{idx, lastSec} + us.access.Unlock() + us.userHashDeleteQueue.Add(string(idHash), lastSec+2*cacheDurationSec) lastSec++ } } @@ -73,9 +91,10 @@ func (us *TimedUserSet) AddUser(user User) error { } func (us TimedUserSet) GetUser(userHash []byte) (*ID, int64, bool) { - rawPair, found := us.userHash.Get(string(userHash)) + defer us.access.RUnlock() + us.access.RLock() + pair, found := us.userHash[string(userHash)] if found { - pair := rawPair.(indexTimePair) return &us.validUserIds[pair.index], pair.timeSec, true } return nil, 0, false