From 87ba7dd0d152844c5f87274cbd2b16326b481a30 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Fri, 9 Feb 2018 11:32:12 +0100 Subject: [PATCH] implement remove user in vmess --- common/protocol/user_validator.go | 1 + proxy/vmess/inbound/inbound.go | 80 +++++++++++++++------ proxy/vmess/vmess.go | 112 +++++++++++++++++------------- proxy/vmess/vmess_test.go | 58 ++++++++++++++++ 4 files changed, 181 insertions(+), 70 deletions(-) create mode 100644 proxy/vmess/vmess_test.go diff --git a/common/protocol/user_validator.go b/common/protocol/user_validator.go index dfe779ed6..47816895c 100644 --- a/common/protocol/user_validator.go +++ b/common/protocol/user_validator.go @@ -3,4 +3,5 @@ package protocol type UserValidator interface { Add(user *User) error Get(timeHash []byte) (*User, Timestamp, bool) + Remove(email string) bool } diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 26deba018..5f32b1231 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -5,6 +5,7 @@ package inbound import ( "context" "io" + "strings" "sync" "time" @@ -25,7 +26,7 @@ import ( ) type userByEmail struct { - sync.RWMutex + sync.Mutex cache map[string]*protocol.User defaultLevel uint32 defaultAlterIDs uint16 @@ -34,7 +35,7 @@ type userByEmail struct { func newUserByEmail(users []*protocol.User, config *DefaultConfig) *userByEmail { cache := make(map[string]*protocol.User) for _, user := range users { - cache[user.Email] = user + cache[strings.ToLower(user.Email)] = user } return &userByEmail{ cache: cache, @@ -43,33 +44,59 @@ func newUserByEmail(users []*protocol.User, config *DefaultConfig) *userByEmail } } +func (v *userByEmail) addNoLock(u *protocol.User) bool { + email := strings.ToLower(u.Email) + user, found := v.cache[email] + if found { + return false + } + v.cache[email] = user + return true +} + +func (v *userByEmail) Add(u *protocol.User) bool { + v.Lock() + defer v.Unlock() + + return v.addNoLock(u) +} + func (v *userByEmail) Get(email string) (*protocol.User, bool) { - var user *protocol.User - var found bool - v.RLock() - user, found = v.cache[email] - v.RUnlock() + email = strings.ToLower(email) + + v.Lock() + defer v.Unlock() + + user, found := v.cache[email] if !found { - v.Lock() - user, found = v.cache[email] - if !found { - id := uuid.New() - account := &vmess.Account{ - Id: id.String(), - AlterId: uint32(v.defaultAlterIDs), - } - user = &protocol.User{ - Level: v.defaultLevel, - Email: email, - Account: serial.ToTypedMessage(account), - } - v.cache[email] = user + id := uuid.New() + account := &vmess.Account{ + Id: id.String(), + AlterId: uint32(v.defaultAlterIDs), } - v.Unlock() + user = &protocol.User{ + Level: v.defaultLevel, + Email: email, + Account: serial.ToTypedMessage(account), + } + v.cache[email] = user } return user, found } +func (v *userByEmail) Remove(email string) bool { + email = strings.ToLower(email) + + v.Lock() + defer v.Unlock() + + if _, found := v.cache[email]; !found { + return false + } + delete(v.cache, email) + return true +} + // Handler is an inbound connection handler that handles messages in VMess protocol. type Handler struct { policyManager core.PolicyManager @@ -129,11 +156,18 @@ func (h *Handler) GetUser(email string) *protocol.User { } func (h *Handler) AddUser(ctx context.Context, user *protocol.User) error { + if !h.usersByEmail.Add(user) { + return newError("User ", user.Email, " already exists.") + } return h.clients.Add(user) } func (h *Handler) RemoveUser(ctx context.Context, email string) error { - return newError("not implemented") + if !h.usersByEmail.Remove(email) { + return newError("User ", email, " not found.") + } + h.clients.Remove(email) + return nil } func transferRequest(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output ray.OutputStream) error { diff --git a/proxy/vmess/vmess.go b/proxy/vmess/vmess.go index 4c1cc5c35..52ead21f2 100644 --- a/proxy/vmess/vmess.go +++ b/proxy/vmess/vmess.go @@ -8,6 +8,7 @@ package vmess //go:generate go run $GOPATH/src/v2ray.com/core/common/errors/errorgen/main.go -pkg vmess -path Proxy,VMess import ( + "strings" "sync" "time" @@ -21,34 +22,32 @@ const ( cacheDurationSec = 120 ) -type idEntry struct { - id *protocol.ID - userIdx int +type user struct { + user *protocol.User + account *InternalAccount lastSec protocol.Timestamp } type TimedUserValidator struct { sync.RWMutex - validUsers []*protocol.User - userHash map[[16]byte]indexTimePair - ids []*idEntry - hasher protocol.IDHash - baseTime protocol.Timestamp - task *signal.PeriodicTask + users []*user + userHash map[[16]byte]indexTimePair + hasher protocol.IDHash + baseTime protocol.Timestamp + task *signal.PeriodicTask } type indexTimePair struct { - index int + user *user timeInc uint32 } func NewTimedUserValidator(hasher protocol.IDHash) protocol.UserValidator { tuv := &TimedUserValidator{ - validUsers: make([]*protocol.User, 0, 16), - userHash: make(map[[16]byte]indexTimePair, 512), - ids: make([]*idEntry, 0, 512), - hasher: hasher, - baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*3), + users: make([]*user, 0, 16), + userHash: make(map[[16]byte]indexTimePair, 1024), + hasher: hasher, + baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*3), } tuv.task = &signal.PeriodicTask{ Interval: updateInterval, @@ -61,21 +60,27 @@ func NewTimedUserValidator(hasher protocol.IDHash) protocol.UserValidator { return tuv } -func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, idx int, entry *idEntry) { +func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *user) { var hashValue [16]byte - idHash := v.hasher(entry.id.Bytes()) - for entry.lastSec <= nowSec { - common.Must2(idHash.Write(entry.lastSec.Bytes(nil))) - idHash.Sum(hashValue[:0]) - idHash.Reset() + genHashForID := func(id *protocol.ID) { + idHash := v.hasher(id.Bytes()) + for ts := user.lastSec; ts <= nowSec; ts++ { + common.Must2(idHash.Write(ts.Bytes(nil))) + idHash.Sum(hashValue[:0]) + idHash.Reset() - v.userHash[hashValue] = indexTimePair{ - index: idx, - timeInc: uint32(entry.lastSec - v.baseTime), + v.userHash[hashValue] = indexTimePair{ + user: user, + timeInc: uint32(ts - v.baseTime), + } } - - entry.lastSec++ } + + genHashForID(user.account.ID) + for _, id := range user.account.AlterIDs { + genHashForID(id) + } + user.lastSec = nowSec } func (v *TimedUserValidator) removeExpiredHashes(expire uint32) { @@ -92,8 +97,8 @@ func (v *TimedUserValidator) updateUserHash() { v.Lock() defer v.Unlock() - for _, entry := range v.ids { - v.generateNewHashes(nowSec, entry.userIdx, entry) + for _, user := range v.users { + v.generateNewHashes(nowSec, user) } expire := protocol.Timestamp(now.Unix() - cacheDurationSec*3) @@ -102,13 +107,11 @@ func (v *TimedUserValidator) updateUserHash() { } } -func (v *TimedUserValidator) Add(user *protocol.User) error { +func (v *TimedUserValidator) Add(u *protocol.User) error { v.Lock() defer v.Unlock() - idx := len(v.validUsers) - v.validUsers = append(v.validUsers, user) - rawAccount, err := user.GetTypedAccount() + rawAccount, err := u.GetTypedAccount() if err != nil { return err } @@ -116,22 +119,13 @@ func (v *TimedUserValidator) Add(user *protocol.User) error { nowSec := time.Now().Unix() - entry := &idEntry{ - id: account.ID, - userIdx: idx, + uu := &user{ + user: u, + account: account, lastSec: protocol.Timestamp(nowSec - cacheDurationSec), } - v.generateNewHashes(protocol.Timestamp(nowSec+cacheDurationSec), idx, entry) - v.ids = append(v.ids, entry) - for _, alterid := range account.AlterIDs { - entry := &idEntry{ - id: alterid, - userIdx: idx, - lastSec: protocol.Timestamp(nowSec - cacheDurationSec), - } - v.generateNewHashes(protocol.Timestamp(nowSec+cacheDurationSec), idx, entry) - v.ids = append(v.ids, entry) - } + v.users = append(v.users, uu) + v.generateNewHashes(protocol.Timestamp(nowSec+cacheDurationSec), uu) return nil } @@ -144,11 +138,35 @@ func (v *TimedUserValidator) Get(userHash []byte) (*protocol.User, protocol.Time copy(fixedSizeHash[:], userHash) pair, found := v.userHash[fixedSizeHash] if found { - return v.validUsers[pair.index], protocol.Timestamp(pair.timeInc) + v.baseTime, true + return pair.user.user, protocol.Timestamp(pair.timeInc) + v.baseTime, true } return nil, 0, false } +func (v *TimedUserValidator) Remove(email string) bool { + v.Lock() + defer v.Unlock() + + email = strings.ToLower(email) + idx := -1 + for i, u := range v.users { + if strings.ToLower(u.user.Email) == email { + idx = i + break + } + } + if idx == -1 { + return false + } + ulen := len(v.users) + if idx < len(v.users) { + v.users[idx] = v.users[ulen-1] + v.users[ulen-1] = nil + v.users = v.users[:ulen-1] + } + return true +} + // Close implements common.Closable. func (v *TimedUserValidator) Close() error { return v.task.Close() diff --git a/proxy/vmess/vmess_test.go b/proxy/vmess/vmess_test.go new file mode 100644 index 000000000..b3c805591 --- /dev/null +++ b/proxy/vmess/vmess_test.go @@ -0,0 +1,58 @@ +package vmess_test + +import ( + "testing" + "time" + + "v2ray.com/core/common" + "v2ray.com/core/common/serial" + "v2ray.com/core/common/uuid" + + "v2ray.com/core/common/protocol" + . "v2ray.com/core/proxy/vmess" + . "v2ray.com/ext/assert" +) + +func TestUserValidator(t *testing.T) { + assert := With(t) + + hasher := protocol.DefaultIDHash + v := NewTimedUserValidator(hasher) + defer common.Close(v) + + id := uuid.New() + user := &protocol.User{ + Email: "test", + Account: serial.ToTypedMessage(&Account{ + Id: id.String(), + AlterId: 8, + }), + } + common.Must(v.Add(user)) + + { + ts := protocol.Timestamp(time.Now().Unix()) + idHash := hasher(id.Bytes()) + idHash.Write(ts.Bytes(nil)) + userHash := idHash.Sum(nil) + + euser, ets, found := v.Get(userHash) + assert(found, IsTrue) + assert(euser.Email, Equals, user.Email) + assert(int64(ets), Equals, int64(ts)) + } + + { + ts := protocol.Timestamp(time.Now().Add(time.Second * 500).Unix()) + idHash := hasher(id.Bytes()) + idHash.Write(ts.Bytes(nil)) + userHash := idHash.Sum(nil) + + euser, _, found := v.Get(userHash) + assert(found, IsFalse) + assert(euser, IsNil) + } + + assert(v.Remove(user.Email), IsTrue) + assert(v.Remove(user.Email), IsFalse) +}