1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-02 15:36:41 -05:00

implement remove user in vmess

This commit is contained in:
Darien Raymond 2018-02-09 11:32:12 +01:00
parent c1fc7c738a
commit 87ba7dd0d1
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
4 changed files with 181 additions and 70 deletions

View File

@ -3,4 +3,5 @@ package protocol
type UserValidator interface { type UserValidator interface {
Add(user *User) error Add(user *User) error
Get(timeHash []byte) (*User, Timestamp, bool) Get(timeHash []byte) (*User, Timestamp, bool)
Remove(email string) bool
} }

View File

@ -5,6 +5,7 @@ package inbound
import ( import (
"context" "context"
"io" "io"
"strings"
"sync" "sync"
"time" "time"
@ -25,7 +26,7 @@ import (
) )
type userByEmail struct { type userByEmail struct {
sync.RWMutex sync.Mutex
cache map[string]*protocol.User cache map[string]*protocol.User
defaultLevel uint32 defaultLevel uint32
defaultAlterIDs uint16 defaultAlterIDs uint16
@ -34,7 +35,7 @@ type userByEmail struct {
func newUserByEmail(users []*protocol.User, config *DefaultConfig) *userByEmail { func newUserByEmail(users []*protocol.User, config *DefaultConfig) *userByEmail {
cache := make(map[string]*protocol.User) cache := make(map[string]*protocol.User)
for _, user := range users { for _, user := range users {
cache[user.Email] = user cache[strings.ToLower(user.Email)] = user
} }
return &userByEmail{ return &userByEmail{
cache: cache, cache: cache,
@ -43,33 +44,59 @@ func newUserByEmail(users []*protocol.User, config *DefaultConfig) *userByEmail
} }
} }
func (v *userByEmail) addNoLock(u *protocol.User) bool {
email := strings.ToLower(u.Email)
user, found := v.cache[email]
if found {
return false
}
v.cache[email] = user
return true
}
func (v *userByEmail) Add(u *protocol.User) bool {
v.Lock()
defer v.Unlock()
return v.addNoLock(u)
}
func (v *userByEmail) Get(email string) (*protocol.User, bool) { func (v *userByEmail) Get(email string) (*protocol.User, bool) {
var user *protocol.User email = strings.ToLower(email)
var found bool
v.RLock() v.Lock()
user, found = v.cache[email] defer v.Unlock()
v.RUnlock()
user, found := v.cache[email]
if !found { if !found {
v.Lock() id := uuid.New()
user, found = v.cache[email] account := &vmess.Account{
if !found { Id: id.String(),
id := uuid.New() AlterId: uint32(v.defaultAlterIDs),
account := &vmess.Account{
Id: id.String(),
AlterId: uint32(v.defaultAlterIDs),
}
user = &protocol.User{
Level: v.defaultLevel,
Email: email,
Account: serial.ToTypedMessage(account),
}
v.cache[email] = user
} }
v.Unlock() user = &protocol.User{
Level: v.defaultLevel,
Email: email,
Account: serial.ToTypedMessage(account),
}
v.cache[email] = user
} }
return user, found return user, found
} }
func (v *userByEmail) Remove(email string) bool {
email = strings.ToLower(email)
v.Lock()
defer v.Unlock()
if _, found := v.cache[email]; !found {
return false
}
delete(v.cache, email)
return true
}
// Handler is an inbound connection handler that handles messages in VMess protocol. // Handler is an inbound connection handler that handles messages in VMess protocol.
type Handler struct { type Handler struct {
policyManager core.PolicyManager policyManager core.PolicyManager
@ -129,11 +156,18 @@ func (h *Handler) GetUser(email string) *protocol.User {
} }
func (h *Handler) AddUser(ctx context.Context, user *protocol.User) error { func (h *Handler) AddUser(ctx context.Context, user *protocol.User) error {
if !h.usersByEmail.Add(user) {
return newError("User ", user.Email, " already exists.")
}
return h.clients.Add(user) return h.clients.Add(user)
} }
func (h *Handler) RemoveUser(ctx context.Context, email string) error { func (h *Handler) RemoveUser(ctx context.Context, email string) error {
return newError("not implemented") if !h.usersByEmail.Remove(email) {
return newError("User ", email, " not found.")
}
h.clients.Remove(email)
return nil
} }
func transferRequest(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output ray.OutputStream) error { func transferRequest(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output ray.OutputStream) error {

View File

@ -8,6 +8,7 @@ package vmess
//go:generate go run $GOPATH/src/v2ray.com/core/common/errors/errorgen/main.go -pkg vmess -path Proxy,VMess //go:generate go run $GOPATH/src/v2ray.com/core/common/errors/errorgen/main.go -pkg vmess -path Proxy,VMess
import ( import (
"strings"
"sync" "sync"
"time" "time"
@ -21,34 +22,32 @@ const (
cacheDurationSec = 120 cacheDurationSec = 120
) )
type idEntry struct { type user struct {
id *protocol.ID user *protocol.User
userIdx int account *InternalAccount
lastSec protocol.Timestamp lastSec protocol.Timestamp
} }
type TimedUserValidator struct { type TimedUserValidator struct {
sync.RWMutex sync.RWMutex
validUsers []*protocol.User users []*user
userHash map[[16]byte]indexTimePair userHash map[[16]byte]indexTimePair
ids []*idEntry hasher protocol.IDHash
hasher protocol.IDHash baseTime protocol.Timestamp
baseTime protocol.Timestamp task *signal.PeriodicTask
task *signal.PeriodicTask
} }
type indexTimePair struct { type indexTimePair struct {
index int user *user
timeInc uint32 timeInc uint32
} }
func NewTimedUserValidator(hasher protocol.IDHash) protocol.UserValidator { func NewTimedUserValidator(hasher protocol.IDHash) protocol.UserValidator {
tuv := &TimedUserValidator{ tuv := &TimedUserValidator{
validUsers: make([]*protocol.User, 0, 16), users: make([]*user, 0, 16),
userHash: make(map[[16]byte]indexTimePair, 512), userHash: make(map[[16]byte]indexTimePair, 1024),
ids: make([]*idEntry, 0, 512), hasher: hasher,
hasher: hasher, baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*3),
baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*3),
} }
tuv.task = &signal.PeriodicTask{ tuv.task = &signal.PeriodicTask{
Interval: updateInterval, Interval: updateInterval,
@ -61,21 +60,27 @@ func NewTimedUserValidator(hasher protocol.IDHash) protocol.UserValidator {
return tuv return tuv
} }
func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, idx int, entry *idEntry) { func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *user) {
var hashValue [16]byte var hashValue [16]byte
idHash := v.hasher(entry.id.Bytes()) genHashForID := func(id *protocol.ID) {
for entry.lastSec <= nowSec { idHash := v.hasher(id.Bytes())
common.Must2(idHash.Write(entry.lastSec.Bytes(nil))) for ts := user.lastSec; ts <= nowSec; ts++ {
idHash.Sum(hashValue[:0]) common.Must2(idHash.Write(ts.Bytes(nil)))
idHash.Reset() idHash.Sum(hashValue[:0])
idHash.Reset()
v.userHash[hashValue] = indexTimePair{ v.userHash[hashValue] = indexTimePair{
index: idx, user: user,
timeInc: uint32(entry.lastSec - v.baseTime), timeInc: uint32(ts - v.baseTime),
}
} }
entry.lastSec++
} }
genHashForID(user.account.ID)
for _, id := range user.account.AlterIDs {
genHashForID(id)
}
user.lastSec = nowSec
} }
func (v *TimedUserValidator) removeExpiredHashes(expire uint32) { func (v *TimedUserValidator) removeExpiredHashes(expire uint32) {
@ -92,8 +97,8 @@ func (v *TimedUserValidator) updateUserHash() {
v.Lock() v.Lock()
defer v.Unlock() defer v.Unlock()
for _, entry := range v.ids { for _, user := range v.users {
v.generateNewHashes(nowSec, entry.userIdx, entry) v.generateNewHashes(nowSec, user)
} }
expire := protocol.Timestamp(now.Unix() - cacheDurationSec*3) expire := protocol.Timestamp(now.Unix() - cacheDurationSec*3)
@ -102,13 +107,11 @@ func (v *TimedUserValidator) updateUserHash() {
} }
} }
func (v *TimedUserValidator) Add(user *protocol.User) error { func (v *TimedUserValidator) Add(u *protocol.User) error {
v.Lock() v.Lock()
defer v.Unlock() defer v.Unlock()
idx := len(v.validUsers) rawAccount, err := u.GetTypedAccount()
v.validUsers = append(v.validUsers, user)
rawAccount, err := user.GetTypedAccount()
if err != nil { if err != nil {
return err return err
} }
@ -116,22 +119,13 @@ func (v *TimedUserValidator) Add(user *protocol.User) error {
nowSec := time.Now().Unix() nowSec := time.Now().Unix()
entry := &idEntry{ uu := &user{
id: account.ID, user: u,
userIdx: idx, account: account,
lastSec: protocol.Timestamp(nowSec - cacheDurationSec), lastSec: protocol.Timestamp(nowSec - cacheDurationSec),
} }
v.generateNewHashes(protocol.Timestamp(nowSec+cacheDurationSec), idx, entry) v.users = append(v.users, uu)
v.ids = append(v.ids, entry) v.generateNewHashes(protocol.Timestamp(nowSec+cacheDurationSec), uu)
for _, alterid := range account.AlterIDs {
entry := &idEntry{
id: alterid,
userIdx: idx,
lastSec: protocol.Timestamp(nowSec - cacheDurationSec),
}
v.generateNewHashes(protocol.Timestamp(nowSec+cacheDurationSec), idx, entry)
v.ids = append(v.ids, entry)
}
return nil return nil
} }
@ -144,11 +138,35 @@ func (v *TimedUserValidator) Get(userHash []byte) (*protocol.User, protocol.Time
copy(fixedSizeHash[:], userHash) copy(fixedSizeHash[:], userHash)
pair, found := v.userHash[fixedSizeHash] pair, found := v.userHash[fixedSizeHash]
if found { if found {
return v.validUsers[pair.index], protocol.Timestamp(pair.timeInc) + v.baseTime, true return pair.user.user, protocol.Timestamp(pair.timeInc) + v.baseTime, true
} }
return nil, 0, false return nil, 0, false
} }
func (v *TimedUserValidator) Remove(email string) bool {
v.Lock()
defer v.Unlock()
email = strings.ToLower(email)
idx := -1
for i, u := range v.users {
if strings.ToLower(u.user.Email) == email {
idx = i
break
}
}
if idx == -1 {
return false
}
ulen := len(v.users)
if idx < len(v.users) {
v.users[idx] = v.users[ulen-1]
v.users[ulen-1] = nil
v.users = v.users[:ulen-1]
}
return true
}
// Close implements common.Closable. // Close implements common.Closable.
func (v *TimedUserValidator) Close() error { func (v *TimedUserValidator) Close() error {
return v.task.Close() return v.task.Close()

58
proxy/vmess/vmess_test.go Normal file
View File

@ -0,0 +1,58 @@
package vmess_test
import (
"testing"
"time"
"v2ray.com/core/common"
"v2ray.com/core/common/serial"
"v2ray.com/core/common/uuid"
"v2ray.com/core/common/protocol"
. "v2ray.com/core/proxy/vmess"
. "v2ray.com/ext/assert"
)
func TestUserValidator(t *testing.T) {
assert := With(t)
hasher := protocol.DefaultIDHash
v := NewTimedUserValidator(hasher)
defer common.Close(v)
id := uuid.New()
user := &protocol.User{
Email: "test",
Account: serial.ToTypedMessage(&Account{
Id: id.String(),
AlterId: 8,
}),
}
common.Must(v.Add(user))
{
ts := protocol.Timestamp(time.Now().Unix())
idHash := hasher(id.Bytes())
idHash.Write(ts.Bytes(nil))
userHash := idHash.Sum(nil)
euser, ets, found := v.Get(userHash)
assert(found, IsTrue)
assert(euser.Email, Equals, user.Email)
assert(int64(ets), Equals, int64(ts))
}
{
ts := protocol.Timestamp(time.Now().Add(time.Second * 500).Unix())
idHash := hasher(id.Bytes())
idHash.Write(ts.Bytes(nil))
userHash := idHash.Sum(nil)
euser, _, found := v.Get(userHash)
assert(found, IsFalse)
assert(euser, IsNil)
}
assert(v.Remove(user.Email), IsTrue)
assert(v.Remove(user.Email), IsFalse)
}