diff --git a/common/dice/dice.go b/common/dice/dice.go index 6dc94b716..f5a625592 100644 --- a/common/dice/dice.go +++ b/common/dice/dice.go @@ -28,6 +28,25 @@ func RollUint16() uint16 { return uint16(rand.Intn(65536)) } +func RollUint64() uint64 { + return rand.Uint64() +} + +func NewDeterministicDice(seed int64) *deterministicDice { + return &deterministicDice{rand.New(rand.NewSource(seed))} +} + +type deterministicDice struct { + *rand.Rand +} + +func (dd *deterministicDice) Roll(n int) int { + if n == 1 { + return 0 + } + return dd.Intn(n) +} + func init() { rand.Seed(time.Now().Unix()) } diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 6c04fa33c..a7b6e81ba 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -125,7 +125,26 @@ func parseSecurityType(b byte) protocol.SecurityType { // DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream. func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) { buffer := buf.New() - defer buffer.Release() + behaviorRand := dice.NewDeterministicDice(int64(s.userValidator.GetBehaviorSeed())) + DrainSize := behaviorRand.Roll(387) + 16 + 38 + readSizeRemain := DrainSize + + drainConnection := func(e error) error { + //We read a deterministic generated length of data before closing the connection to offset padding read pattern + readSizeRemain -= int(buffer.Len()) + if readSizeRemain > 0 { + err := s.DrainConnN(reader, readSizeRemain) + if err != nil { + return newError("failed to drain connection").Base(err).Base(e) + } + return newError("connection drained DrainSize = ", DrainSize).Base(e) + } + return e + } + + defer func() { + buffer.Release() + }() if _, err := buffer.ReadFullFrom(reader, protocol.IDBytesLen); err != nil { return nil, newError("failed to read request header").Base(err) @@ -133,7 +152,8 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request user, timestamp, valid := s.userValidator.Get(buffer.Bytes()) if !valid { - return nil, newError("invalid user") + //It is possible that we are under attack described in https://github.com/v2ray/v2ray-core/issues/2523 + return nil, drainConnection(newError("invalid user")) } iv := hashTimestamp(md5.New(), timestamp) @@ -142,6 +162,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv[:]) decryptor := crypto.NewCryptionReader(aesStream, reader) + readSizeRemain -= int(buffer.Len()) buffer.Clear() if _, err := buffer.ReadFullFrom(decryptor, 38); err != nil { return nil, newError("failed to read request header").Base(err) @@ -197,12 +218,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request if actualHash != expectedHash { //It is possible that we are under attack described in https://github.com/v2ray/v2ray-core/issues/2523 - //We read a deterministic generated length of data before closing the connection to offset padding read pattern - drainSum := dice.RollDeterministic(48, int64(actualHash)) - if err := s.DrainConnN(reader, drainSum); err != nil { - return nil, newError("invalid auth, failed to drain connection").Base(err) - } - return nil, newError("invalid auth, connection drained") + return nil, drainConnection(newError("invalid auth")) } if request.Address == nil { diff --git a/proxy/vmess/validator.go b/proxy/vmess/validator.go index 1690fac72..00d14d440 100644 --- a/proxy/vmess/validator.go +++ b/proxy/vmess/validator.go @@ -3,9 +3,11 @@ package vmess import ( + "hash/crc64" "strings" "sync" "time" + "v2ray.com/core/common/dice" "v2ray.com/core/common" "v2ray.com/core/common/protocol" @@ -26,11 +28,13 @@ type user struct { // TimedUserValidator is a user Validator based on time. type TimedUserValidator struct { sync.RWMutex - users []*user - userHash map[[16]byte]indexTimePair - hasher protocol.IDHash - baseTime protocol.Timestamp - task *task.Periodic + users []*user + userHash map[[16]byte]indexTimePair + hasher protocol.IDHash + baseTime protocol.Timestamp + task *task.Periodic + behaviorSeed uint64 + behaviorFused bool } type indexTimePair struct { @@ -124,6 +128,11 @@ func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error { v.users = append(v.users, uu) v.generateNewHashes(protocol.Timestamp(nowSec), uu) + if v.behaviorFused == false { + account := uu.user.Account.(*MemoryAccount) + crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), account.ID.Bytes()) + } + return nil } @@ -131,6 +140,8 @@ func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protoco defer v.RUnlock() v.RLock() + v.behaviorFused = true + var fixedSizeHash [16]byte copy(fixedSizeHash[:], userHash) pair, found := v.userHash[fixedSizeHash] @@ -170,3 +181,13 @@ func (v *TimedUserValidator) Remove(email string) bool { func (v *TimedUserValidator) Close() error { return v.task.Close() } + +func (v *TimedUserValidator) GetBehaviorSeed() uint64 { + v.Lock() + defer v.Unlock() + v.behaviorFused = true + if v.behaviorSeed == 0 { + v.behaviorSeed = dice.RollUint64() + } + return v.behaviorSeed +}