diff --git a/common/antireplay/antireplay.go b/common/antireplay/antireplay.go deleted file mode 100644 index 9ac4300fd..000000000 --- a/common/antireplay/antireplay.go +++ /dev/null @@ -1,51 +0,0 @@ -package antireplay - -import ( - "sync" - "time" - - cuckoo "github.com/seiflotfy/cuckoofilter" -) - -func NewAntiReplayWindow(antiReplayTime int64) *AntiReplayWindow { - arw := &AntiReplayWindow{} - arw.AntiReplayTime = antiReplayTime - return arw -} - -type AntiReplayWindow struct { - lock sync.Mutex - poolA *cuckoo.Filter - poolB *cuckoo.Filter - lastSwapTime int64 - PoolSwap bool - AntiReplayTime int64 -} - -func (aw *AntiReplayWindow) Check(sum []byte) bool { - aw.lock.Lock() - - if aw.lastSwapTime == 0 { - aw.lastSwapTime = time.Now().Unix() - aw.poolA = cuckoo.NewFilter(100000) - aw.poolB = cuckoo.NewFilter(100000) - } - - tnow := time.Now().Unix() - timediff := tnow - aw.lastSwapTime - - if timediff >= aw.AntiReplayTime { - if aw.PoolSwap { - aw.PoolSwap = false - aw.poolA.Reset() - } else { - aw.PoolSwap = true - aw.poolB.Reset() - } - aw.lastSwapTime = tnow - } - - ret := aw.poolA.InsertUnique(sum) && aw.poolB.InsertUnique(sum) - aw.lock.Unlock() - return ret -} diff --git a/common/antireplay/replayfilter.go b/common/antireplay/replayfilter.go new file mode 100644 index 000000000..4e0783d12 --- /dev/null +++ b/common/antireplay/replayfilter.go @@ -0,0 +1,58 @@ +package antireplay + +import ( + "sync" + "time" + + cuckoo "github.com/seiflotfy/cuckoofilter" +) + +const replayFilterCapacity = 100000 + +// ReplayFilter check for replay attacks. +type ReplayFilter struct { + lock sync.Mutex + poolA *cuckoo.Filter + poolB *cuckoo.Filter + poolSwap bool + lastSwap int64 + interval int64 +} + +// NewReplayFilter create a new filter with specifying the expiration time interval in seconds. +func NewReplayFilter(interval int64) *ReplayFilter { + filter := &ReplayFilter{} + filter.interval = interval + return filter +} + +// Interval in second for expiration time for duplicate records. +func (filter *ReplayFilter) Interval() int64 { + return filter.interval +} + +// Check determine if there are duplicate records. +func (filter *ReplayFilter) Check(sum []byte) bool { + filter.lock.Lock() + defer filter.lock.Unlock() + + now := time.Now().Unix() + if filter.lastSwap == 0 { + filter.lastSwap = now + filter.poolA = cuckoo.NewFilter(replayFilterCapacity) + filter.poolB = cuckoo.NewFilter(replayFilterCapacity) + } + + elapsed := now - filter.lastSwap + if elapsed >= filter.Interval() { + if filter.poolSwap { + filter.poolA.Reset() + } else { + filter.poolB.Reset() + } + filter.poolSwap = !filter.poolSwap + filter.lastSwap = now + } + + return filter.poolA.InsertUnique(sum) && filter.poolB.InsertUnique(sum) +} diff --git a/proxy/vmess/aead/authid.go b/proxy/vmess/aead/authid.go index adde14a75..88e9b1bd3 100644 --- a/proxy/vmess/aead/authid.go +++ b/proxy/vmess/aead/authid.go @@ -13,7 +13,7 @@ import ( "time" "v2ray.com/core/common" - antiReplayWindow "v2ray.com/core/common/antireplay" + "v2ray.com/core/common/antireplay" ) var ( @@ -66,12 +66,12 @@ func (aidd *AuthIDDecoder) Decode(data [16]byte) (int64, uint32, int32, []byte) } func NewAuthIDDecoderHolder() *AuthIDDecoderHolder { - return &AuthIDDecoderHolder{make(map[string]*AuthIDDecoderItem), antiReplayWindow.NewAntiReplayWindow(120)} + return &AuthIDDecoderHolder{make(map[string]*AuthIDDecoderItem), antireplay.NewReplayFilter(120)} } type AuthIDDecoderHolder struct { - aidhi map[string]*AuthIDDecoderItem - apw *antiReplayWindow.AntiReplayWindow + decoders map[string]*AuthIDDecoderItem + filter *antireplay.ReplayFilter } type AuthIDDecoderItem struct { @@ -87,16 +87,16 @@ func NewAuthIDDecoderItem(key [16]byte, ticket interface{}) *AuthIDDecoderItem { } func (a *AuthIDDecoderHolder) AddUser(key [16]byte, ticket interface{}) { - a.aidhi[string(key[:])] = NewAuthIDDecoderItem(key, ticket) + a.decoders[string(key[:])] = NewAuthIDDecoderItem(key, ticket) } func (a *AuthIDDecoderHolder) RemoveUser(key [16]byte) { - delete(a.aidhi, string(key[:])) + delete(a.decoders, string(key[:])) } func (a *AuthIDDecoderHolder) Match(authID [16]byte) (interface{}, error) { - for _, v := range a.aidhi { - t, z, r, d := v.dec.Decode(authID) + for _, v := range a.decoders { + t, z, _, d := v.dec.Decode(authID) if z != crc32.ChecksumIEEE(d[:12]) { continue } @@ -109,12 +109,10 @@ func (a *AuthIDDecoderHolder) Match(authID [16]byte) (interface{}, error) { continue } - if !a.apw.Check(authID[:]) { + if !a.filter.Check(authID[:]) { return nil, ErrReplay } - _ = r - return v.ticket, nil } return nil, ErrNotFound