1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-21 09:36:34 -05:00

implement remove user in vmess

This commit is contained in:
Darien Raymond 2018-02-09 11:32:12 +01:00
parent c1fc7c738a
commit 87ba7dd0d1
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
4 changed files with 181 additions and 70 deletions

View File

@ -3,4 +3,5 @@ package protocol
type UserValidator interface {
Add(user *User) error
Get(timeHash []byte) (*User, Timestamp, bool)
Remove(email string) bool
}

View File

@ -5,6 +5,7 @@ package inbound
import (
"context"
"io"
"strings"
"sync"
"time"
@ -25,7 +26,7 @@ import (
)
type userByEmail struct {
sync.RWMutex
sync.Mutex
cache map[string]*protocol.User
defaultLevel uint32
defaultAlterIDs uint16
@ -34,7 +35,7 @@ type userByEmail struct {
func newUserByEmail(users []*protocol.User, config *DefaultConfig) *userByEmail {
cache := make(map[string]*protocol.User)
for _, user := range users {
cache[user.Email] = user
cache[strings.ToLower(user.Email)] = user
}
return &userByEmail{
cache: cache,
@ -43,33 +44,59 @@ func newUserByEmail(users []*protocol.User, config *DefaultConfig) *userByEmail
}
}
func (v *userByEmail) addNoLock(u *protocol.User) bool {
email := strings.ToLower(u.Email)
user, found := v.cache[email]
if found {
return false
}
v.cache[email] = user
return true
}
func (v *userByEmail) Add(u *protocol.User) bool {
v.Lock()
defer v.Unlock()
return v.addNoLock(u)
}
func (v *userByEmail) Get(email string) (*protocol.User, bool) {
var user *protocol.User
var found bool
v.RLock()
user, found = v.cache[email]
v.RUnlock()
email = strings.ToLower(email)
v.Lock()
defer v.Unlock()
user, found := v.cache[email]
if !found {
v.Lock()
user, found = v.cache[email]
if !found {
id := uuid.New()
account := &vmess.Account{
Id: id.String(),
AlterId: uint32(v.defaultAlterIDs),
}
user = &protocol.User{
Level: v.defaultLevel,
Email: email,
Account: serial.ToTypedMessage(account),
}
v.cache[email] = user
id := uuid.New()
account := &vmess.Account{
Id: id.String(),
AlterId: uint32(v.defaultAlterIDs),
}
v.Unlock()
user = &protocol.User{
Level: v.defaultLevel,
Email: email,
Account: serial.ToTypedMessage(account),
}
v.cache[email] = user
}
return user, found
}
func (v *userByEmail) Remove(email string) bool {
email = strings.ToLower(email)
v.Lock()
defer v.Unlock()
if _, found := v.cache[email]; !found {
return false
}
delete(v.cache, email)
return true
}
// Handler is an inbound connection handler that handles messages in VMess protocol.
type Handler struct {
policyManager core.PolicyManager
@ -129,11 +156,18 @@ func (h *Handler) GetUser(email string) *protocol.User {
}
func (h *Handler) AddUser(ctx context.Context, user *protocol.User) error {
if !h.usersByEmail.Add(user) {
return newError("User ", user.Email, " already exists.")
}
return h.clients.Add(user)
}
func (h *Handler) RemoveUser(ctx context.Context, email string) error {
return newError("not implemented")
if !h.usersByEmail.Remove(email) {
return newError("User ", email, " not found.")
}
h.clients.Remove(email)
return nil
}
func transferRequest(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output ray.OutputStream) error {

View File

@ -8,6 +8,7 @@ package vmess
//go:generate go run $GOPATH/src/v2ray.com/core/common/errors/errorgen/main.go -pkg vmess -path Proxy,VMess
import (
"strings"
"sync"
"time"
@ -21,34 +22,32 @@ const (
cacheDurationSec = 120
)
type idEntry struct {
id *protocol.ID
userIdx int
type user struct {
user *protocol.User
account *InternalAccount
lastSec protocol.Timestamp
}
type TimedUserValidator struct {
sync.RWMutex
validUsers []*protocol.User
userHash map[[16]byte]indexTimePair
ids []*idEntry
hasher protocol.IDHash
baseTime protocol.Timestamp
task *signal.PeriodicTask
users []*user
userHash map[[16]byte]indexTimePair
hasher protocol.IDHash
baseTime protocol.Timestamp
task *signal.PeriodicTask
}
type indexTimePair struct {
index int
user *user
timeInc uint32
}
func NewTimedUserValidator(hasher protocol.IDHash) protocol.UserValidator {
tuv := &TimedUserValidator{
validUsers: make([]*protocol.User, 0, 16),
userHash: make(map[[16]byte]indexTimePair, 512),
ids: make([]*idEntry, 0, 512),
hasher: hasher,
baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*3),
users: make([]*user, 0, 16),
userHash: make(map[[16]byte]indexTimePair, 1024),
hasher: hasher,
baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*3),
}
tuv.task = &signal.PeriodicTask{
Interval: updateInterval,
@ -61,21 +60,27 @@ func NewTimedUserValidator(hasher protocol.IDHash) protocol.UserValidator {
return tuv
}
func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, idx int, entry *idEntry) {
func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *user) {
var hashValue [16]byte
idHash := v.hasher(entry.id.Bytes())
for entry.lastSec <= nowSec {
common.Must2(idHash.Write(entry.lastSec.Bytes(nil)))
idHash.Sum(hashValue[:0])
idHash.Reset()
genHashForID := func(id *protocol.ID) {
idHash := v.hasher(id.Bytes())
for ts := user.lastSec; ts <= nowSec; ts++ {
common.Must2(idHash.Write(ts.Bytes(nil)))
idHash.Sum(hashValue[:0])
idHash.Reset()
v.userHash[hashValue] = indexTimePair{
index: idx,
timeInc: uint32(entry.lastSec - v.baseTime),
v.userHash[hashValue] = indexTimePair{
user: user,
timeInc: uint32(ts - v.baseTime),
}
}
entry.lastSec++
}
genHashForID(user.account.ID)
for _, id := range user.account.AlterIDs {
genHashForID(id)
}
user.lastSec = nowSec
}
func (v *TimedUserValidator) removeExpiredHashes(expire uint32) {
@ -92,8 +97,8 @@ func (v *TimedUserValidator) updateUserHash() {
v.Lock()
defer v.Unlock()
for _, entry := range v.ids {
v.generateNewHashes(nowSec, entry.userIdx, entry)
for _, user := range v.users {
v.generateNewHashes(nowSec, user)
}
expire := protocol.Timestamp(now.Unix() - cacheDurationSec*3)
@ -102,13 +107,11 @@ func (v *TimedUserValidator) updateUserHash() {
}
}
func (v *TimedUserValidator) Add(user *protocol.User) error {
func (v *TimedUserValidator) Add(u *protocol.User) error {
v.Lock()
defer v.Unlock()
idx := len(v.validUsers)
v.validUsers = append(v.validUsers, user)
rawAccount, err := user.GetTypedAccount()
rawAccount, err := u.GetTypedAccount()
if err != nil {
return err
}
@ -116,22 +119,13 @@ func (v *TimedUserValidator) Add(user *protocol.User) error {
nowSec := time.Now().Unix()
entry := &idEntry{
id: account.ID,
userIdx: idx,
uu := &user{
user: u,
account: account,
lastSec: protocol.Timestamp(nowSec - cacheDurationSec),
}
v.generateNewHashes(protocol.Timestamp(nowSec+cacheDurationSec), idx, entry)
v.ids = append(v.ids, entry)
for _, alterid := range account.AlterIDs {
entry := &idEntry{
id: alterid,
userIdx: idx,
lastSec: protocol.Timestamp(nowSec - cacheDurationSec),
}
v.generateNewHashes(protocol.Timestamp(nowSec+cacheDurationSec), idx, entry)
v.ids = append(v.ids, entry)
}
v.users = append(v.users, uu)
v.generateNewHashes(protocol.Timestamp(nowSec+cacheDurationSec), uu)
return nil
}
@ -144,11 +138,35 @@ func (v *TimedUserValidator) Get(userHash []byte) (*protocol.User, protocol.Time
copy(fixedSizeHash[:], userHash)
pair, found := v.userHash[fixedSizeHash]
if found {
return v.validUsers[pair.index], protocol.Timestamp(pair.timeInc) + v.baseTime, true
return pair.user.user, protocol.Timestamp(pair.timeInc) + v.baseTime, true
}
return nil, 0, false
}
func (v *TimedUserValidator) Remove(email string) bool {
v.Lock()
defer v.Unlock()
email = strings.ToLower(email)
idx := -1
for i, u := range v.users {
if strings.ToLower(u.user.Email) == email {
idx = i
break
}
}
if idx == -1 {
return false
}
ulen := len(v.users)
if idx < len(v.users) {
v.users[idx] = v.users[ulen-1]
v.users[ulen-1] = nil
v.users = v.users[:ulen-1]
}
return true
}
// Close implements common.Closable.
func (v *TimedUserValidator) Close() error {
return v.task.Close()

58
proxy/vmess/vmess_test.go Normal file
View File

@ -0,0 +1,58 @@
package vmess_test
import (
"testing"
"time"
"v2ray.com/core/common"
"v2ray.com/core/common/serial"
"v2ray.com/core/common/uuid"
"v2ray.com/core/common/protocol"
. "v2ray.com/core/proxy/vmess"
. "v2ray.com/ext/assert"
)
func TestUserValidator(t *testing.T) {
assert := With(t)
hasher := protocol.DefaultIDHash
v := NewTimedUserValidator(hasher)
defer common.Close(v)
id := uuid.New()
user := &protocol.User{
Email: "test",
Account: serial.ToTypedMessage(&Account{
Id: id.String(),
AlterId: 8,
}),
}
common.Must(v.Add(user))
{
ts := protocol.Timestamp(time.Now().Unix())
idHash := hasher(id.Bytes())
idHash.Write(ts.Bytes(nil))
userHash := idHash.Sum(nil)
euser, ets, found := v.Get(userHash)
assert(found, IsTrue)
assert(euser.Email, Equals, user.Email)
assert(int64(ets), Equals, int64(ts))
}
{
ts := protocol.Timestamp(time.Now().Add(time.Second * 500).Unix())
idHash := hasher(id.Bytes())
idHash.Write(ts.Bytes(nil))
userHash := idHash.Sum(nil)
euser, _, found := v.Get(userHash)
assert(found, IsFalse)
assert(euser, IsNil)
}
assert(v.Remove(user.Email), IsTrue)
assert(v.Remove(user.Email), IsFalse)
}