1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-17 23:06:30 -05:00

Move userset to protocol

This commit is contained in:
v2ray 2016-02-25 16:40:43 +01:00
parent 2d82bb8d4d
commit 791ac307a2
10 changed files with 170 additions and 172 deletions

View File

@ -0,0 +1,123 @@
package protocol
import (
"hash"
"sync"
"time"
)
const (
updateIntervalSec = 10
cacheDurationSec = 120
)
type IDHash func(key []byte) hash.Hash
type idEntry struct {
id *ID
userIdx int
lastSec Timestamp
lastSecRemoval Timestamp
}
type UserValidator interface {
Add(user *User) error
Get(timeHash []byte) (*User, Timestamp, bool)
}
type TimedUserValidator struct {
validUsers []*User
userHash map[[16]byte]*indexTimePair
ids []*idEntry
access sync.RWMutex
hasher IDHash
}
type indexTimePair struct {
index int
timeSec Timestamp
}
func NewTimedUserValidator(hasher IDHash) UserValidator {
tus := &TimedUserValidator{
validUsers: make([]*User, 0, 16),
userHash: make(map[[16]byte]*indexTimePair, 512),
access: sync.RWMutex{},
ids: make([]*idEntry, 0, 512),
hasher: hasher,
}
go tus.updateUserHash(time.Tick(updateIntervalSec * time.Second))
return tus
}
func (this *TimedUserValidator) generateNewHashes(nowSec Timestamp, idx int, entry *idEntry) {
var hashValue [16]byte
var hashValueRemoval [16]byte
idHash := this.hasher(entry.id.Bytes())
for entry.lastSec <= nowSec {
idHash.Write(entry.lastSec.Bytes())
idHash.Sum(hashValue[:0])
idHash.Reset()
idHash.Write(entry.lastSecRemoval.Bytes())
idHash.Sum(hashValueRemoval[:0])
idHash.Reset()
this.access.Lock()
this.userHash[hashValue] = &indexTimePair{idx, entry.lastSec}
delete(this.userHash, hashValueRemoval)
this.access.Unlock()
entry.lastSec++
entry.lastSecRemoval++
}
}
func (this *TimedUserValidator) updateUserHash(tick <-chan time.Time) {
for now := range tick {
nowSec := Timestamp(now.Unix() + cacheDurationSec)
for _, entry := range this.ids {
this.generateNewHashes(nowSec, entry.userIdx, entry)
}
}
}
func (this *TimedUserValidator) Add(user *User) error {
idx := len(this.validUsers)
this.validUsers = append(this.validUsers, user)
nowSec := time.Now().Unix()
entry := &idEntry{
id: user.ID,
userIdx: idx,
lastSec: Timestamp(nowSec - cacheDurationSec),
lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
}
this.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
this.ids = append(this.ids, entry)
for _, alterid := range user.AlterIDs {
entry := &idEntry{
id: alterid,
userIdx: idx,
lastSec: Timestamp(nowSec - cacheDurationSec),
lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
}
this.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
this.ids = append(this.ids, entry)
}
return nil
}
func (this *TimedUserValidator) Get(userHash []byte) (*User, Timestamp, bool) {
defer this.access.RUnlock()
this.access.RLock()
var fixedSizeHash [16]byte
copy(fixedSizeHash[:], userHash)
pair, found := this.userHash[fixedSizeHash]
if found {
return this.validUsers[pair.index], pair.timeSec, true
}
return nil, 0, false
}

View File

@ -66,7 +66,7 @@ type VMessInboundHandler struct {
sync.Mutex sync.Mutex
packetDispatcher dispatcher.PacketDispatcher packetDispatcher dispatcher.PacketDispatcher
inboundHandlerManager proxyman.InboundHandlerManager inboundHandlerManager proxyman.InboundHandlerManager
clients protocol.UserSet clients proto.UserValidator
usersByEmail *userByEmail usersByEmail *userByEmail
accepting bool accepting bool
listener *hub.TCPHub listener *hub.TCPHub
@ -91,7 +91,7 @@ func (this *VMessInboundHandler) Close() {
func (this *VMessInboundHandler) GetUser(email string) *proto.User { func (this *VMessInboundHandler) GetUser(email string) *proto.User {
user, existing := this.usersByEmail.Get(email) user, existing := this.usersByEmail.Get(email)
if !existing { if !existing {
this.clients.AddUser(user) this.clients.Add(user)
} }
return user return user
} }
@ -211,9 +211,9 @@ func init() {
} }
config := rawConfig.(*Config) config := rawConfig.(*Config)
allowedClients := protocol.NewTimedUserSet() allowedClients := proto.NewTimedUserValidator(protocol.IDHash)
for _, user := range config.AllowedUsers { for _, user := range config.AllowedUsers {
allowedClients.AddUser(user) allowedClients.Add(user)
} }
handler := &VMessInboundHandler{ handler := &VMessInboundHandler{

View File

@ -14,6 +14,7 @@ import (
v2io "github.com/v2ray/v2ray-core/common/io" v2io "github.com/v2ray/v2ray-core/common/io"
"github.com/v2ray/v2ray-core/common/log" "github.com/v2ray/v2ray-core/common/log"
v2net "github.com/v2ray/v2ray-core/common/net" v2net "github.com/v2ray/v2ray-core/common/net"
proto "github.com/v2ray/v2ray-core/common/protocol"
"github.com/v2ray/v2ray-core/proxy" "github.com/v2ray/v2ray-core/proxy"
"github.com/v2ray/v2ray-core/proxy/internal" "github.com/v2ray/v2ray-core/proxy/internal"
vmessio "github.com/v2ray/v2ray-core/proxy/vmess/io" vmessio "github.com/v2ray/v2ray-core/proxy/vmess/io"
@ -106,7 +107,7 @@ func (this *VMessOutboundHandler) handleRequest(conn net.Conn, request *protocol
buffer := alloc.NewBuffer().Clear() buffer := alloc.NewBuffer().Clear()
defer buffer.Release() defer buffer.Release()
buffer, err = request.ToBytes(protocol.NewRandomTimestampGenerator(protocol.Timestamp(time.Now().Unix()), 30), buffer) buffer, err = request.ToBytes(protocol.NewRandomTimestampGenerator(proto.Timestamp(time.Now().Unix()), 30), buffer)
if err != nil { if err != nil {
log.Error("VMessOut: Failed to serialize VMess request: ", err) log.Error("VMessOut: Failed to serialize VMess request: ", err)
return return

View File

@ -2,25 +2,27 @@ package protocol
import ( import (
"math/rand" "math/rand"
"github.com/v2ray/v2ray-core/common/protocol"
) )
type RandomTimestampGenerator interface { type RandomTimestampGenerator interface {
Next() Timestamp Next() protocol.Timestamp
} }
type RealRandomTimestampGenerator struct { type RealRandomTimestampGenerator struct {
base Timestamp base protocol.Timestamp
delta int delta int
} }
func NewRandomTimestampGenerator(base Timestamp, delta int) RandomTimestampGenerator { func NewRandomTimestampGenerator(base protocol.Timestamp, delta int) RandomTimestampGenerator {
return &RealRandomTimestampGenerator{ return &RealRandomTimestampGenerator{
base: base, base: base,
delta: delta, delta: delta,
} }
} }
func (this *RealRandomTimestampGenerator) Next() Timestamp { func (this *RealRandomTimestampGenerator) Next() protocol.Timestamp {
rangeInDelta := rand.Intn(this.delta*2) - this.delta rangeInDelta := rand.Intn(this.delta*2) - this.delta
return this.base + Timestamp(rangeInDelta) return this.base + protocol.Timestamp(rangeInDelta)
} }

View File

@ -4,6 +4,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/v2ray/v2ray-core/common/protocol"
. "github.com/v2ray/v2ray-core/proxy/vmess/protocol" . "github.com/v2ray/v2ray-core/proxy/vmess/protocol"
v2testing "github.com/v2ray/v2ray-core/testing" v2testing "github.com/v2ray/v2ray-core/testing"
"github.com/v2ray/v2ray-core/testing/assert" "github.com/v2ray/v2ray-core/testing/assert"
@ -14,7 +15,7 @@ func TestGenerateRandomInt64InRange(t *testing.T) {
base := time.Now().Unix() base := time.Now().Unix()
delta := 100 delta := 100
generator := NewRandomTimestampGenerator(Timestamp(base), delta) generator := NewRandomTimestampGenerator(protocol.Timestamp(base), delta)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
v := int64(generator.Next()) v := int64(generator.Next())

View File

@ -1,22 +1,21 @@
package mocks package mocks
import ( import (
proto "github.com/v2ray/v2ray-core/common/protocol" "github.com/v2ray/v2ray-core/common/protocol"
"github.com/v2ray/v2ray-core/proxy/vmess/protocol"
) )
type MockUserSet struct { type MockUserSet struct {
Users []*proto.User Users []*protocol.User
UserHashes map[string]int UserHashes map[string]int
Timestamps map[string]protocol.Timestamp Timestamps map[string]protocol.Timestamp
} }
func (us *MockUserSet) AddUser(user *proto.User) error { func (us *MockUserSet) Add(user *protocol.User) error {
us.Users = append(us.Users, user) us.Users = append(us.Users, user)
return nil return nil
} }
func (us *MockUserSet) GetUser(userhash []byte) (*proto.User, protocol.Timestamp, bool) { func (us *MockUserSet) Get(userhash []byte) (*protocol.User, protocol.Timestamp, bool) {
idx, found := us.UserHashes[string(userhash)] idx, found := us.UserHashes[string(userhash)]
if found { if found {
return us.Users[idx], us.Timestamps[string(userhash)], true return us.Users[idx], us.Timestamps[string(userhash)], true

View File

@ -1,21 +1,20 @@
package mocks package mocks
import ( import (
proto "github.com/v2ray/v2ray-core/common/protocol" "github.com/v2ray/v2ray-core/common/protocol"
"github.com/v2ray/v2ray-core/common/uuid" "github.com/v2ray/v2ray-core/common/uuid"
"github.com/v2ray/v2ray-core/proxy/vmess/protocol"
) )
type StaticUserSet struct { type StaticUserSet struct {
} }
func (us *StaticUserSet) AddUser(user *proto.User) error { func (us *StaticUserSet) Add(user *protocol.User) error {
return nil return nil
} }
func (us *StaticUserSet) GetUser(userhash []byte) (*proto.User, protocol.Timestamp, bool) { func (us *StaticUserSet) Get(userhash []byte) (*protocol.User, protocol.Timestamp, bool) {
id, _ := uuid.ParseString("703e9102-eb57-499c-8b59-faf4f371bb21") id, _ := uuid.ParseString("703e9102-eb57-499c-8b59-faf4f371bb21")
return &proto.User{ return &protocol.User{
ID: proto.NewID(id), ID: protocol.NewID(id),
}, 0, true }, 0, true
} }

View File

@ -1,137 +0,0 @@
package protocol
import (
"sync"
"time"
proto "github.com/v2ray/v2ray-core/common/protocol"
"github.com/v2ray/v2ray-core/common/serial"
)
const (
updateIntervalSec = 10
cacheDurationSec = 120
)
type Timestamp int64
func (this Timestamp) Bytes() []byte {
return serial.Int64Literal(this).Bytes()
}
func (this Timestamp) HashBytes() []byte {
once := this.Bytes()
bytes := make([]byte, 0, 32)
bytes = append(bytes, once...)
bytes = append(bytes, once...)
bytes = append(bytes, once...)
bytes = append(bytes, once...)
return bytes
}
type idEntry struct {
id *proto.ID
userIdx int
lastSec Timestamp
lastSecRemoval Timestamp
}
type UserSet interface {
AddUser(user *proto.User) error
GetUser(timeHash []byte) (*proto.User, Timestamp, bool)
}
type TimedUserSet struct {
validUsers []*proto.User
userHash map[[16]byte]*indexTimePair
ids []*idEntry
access sync.RWMutex
}
type indexTimePair struct {
index int
timeSec Timestamp
}
func NewTimedUserSet() UserSet {
tus := &TimedUserSet{
validUsers: make([]*proto.User, 0, 16),
userHash: make(map[[16]byte]*indexTimePair, 512),
access: sync.RWMutex{},
ids: make([]*idEntry, 0, 512),
}
go tus.updateUserHash(time.Tick(updateIntervalSec * time.Second))
return tus
}
func (us *TimedUserSet) generateNewHashes(nowSec Timestamp, idx int, entry *idEntry) {
var hashValue [16]byte
var hashValueRemoval [16]byte
idHash := IDHash(entry.id.Bytes())
for entry.lastSec <= nowSec {
idHash.Write(entry.lastSec.Bytes())
idHash.Sum(hashValue[:0])
idHash.Reset()
idHash.Write(entry.lastSecRemoval.Bytes())
idHash.Sum(hashValueRemoval[:0])
idHash.Reset()
us.access.Lock()
us.userHash[hashValue] = &indexTimePair{idx, entry.lastSec}
delete(us.userHash, hashValueRemoval)
us.access.Unlock()
entry.lastSec++
entry.lastSecRemoval++
}
}
func (us *TimedUserSet) updateUserHash(tick <-chan time.Time) {
for now := range tick {
nowSec := Timestamp(now.Unix() + cacheDurationSec)
for _, entry := range us.ids {
us.generateNewHashes(nowSec, entry.userIdx, entry)
}
}
}
func (us *TimedUserSet) AddUser(user *proto.User) error {
idx := len(us.validUsers)
us.validUsers = append(us.validUsers, user)
nowSec := time.Now().Unix()
entry := &idEntry{
id: user.ID,
userIdx: idx,
lastSec: Timestamp(nowSec - cacheDurationSec),
lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
}
us.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
us.ids = append(us.ids, entry)
for _, alterid := range user.AlterIDs {
entry := &idEntry{
id: alterid,
userIdx: idx,
lastSec: Timestamp(nowSec - cacheDurationSec),
lastSecRemoval: Timestamp(nowSec - cacheDurationSec*3),
}
us.generateNewHashes(Timestamp(nowSec+cacheDurationSec), idx, entry)
us.ids = append(us.ids, entry)
}
return nil
}
func (us *TimedUserSet) GetUser(userHash []byte) (*proto.User, Timestamp, bool) {
defer us.access.RUnlock()
us.access.RLock()
var fixedSizeHash [16]byte
copy(fixedSizeHash[:], userHash)
pair, found := us.userHash[fixedSizeHash]
if found {
return us.validUsers[pair.index], pair.timeSec, true
}
return nil, 0, false
}

View File

@ -31,6 +31,16 @@ const (
blockSize = 16 blockSize = 16
) )
func hashTimestamp(t proto.Timestamp) []byte {
once := t.Bytes()
bytes := make([]byte, 0, 32)
bytes = append(bytes, once...)
bytes = append(bytes, once...)
bytes = append(bytes, once...)
bytes = append(bytes, once...)
return bytes
}
// VMessRequest implements the request message of VMess protocol. It only contains the header of a // VMessRequest implements the request message of VMess protocol. It only contains the header of a
// request message. The data part will be handled by connection handler directly, in favor of data // request message. The data part will be handled by connection handler directly, in favor of data
// streaming. // streaming.
@ -61,11 +71,11 @@ func (this *VMessRequest) IsChunkStream() bool {
// VMessRequestReader is a parser to read VMessRequest from a byte stream. // VMessRequestReader is a parser to read VMessRequest from a byte stream.
type VMessRequestReader struct { type VMessRequestReader struct {
vUserSet UserSet vUserSet proto.UserValidator
} }
// NewVMessRequestReader creates a new VMessRequestReader with a given UserSet // NewVMessRequestReader creates a new VMessRequestReader with a given UserSet
func NewVMessRequestReader(vUserSet UserSet) *VMessRequestReader { func NewVMessRequestReader(vUserSet proto.UserValidator) *VMessRequestReader {
return &VMessRequestReader{ return &VMessRequestReader{
vUserSet: vUserSet, vUserSet: vUserSet,
} }
@ -82,13 +92,13 @@ func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
return nil, err return nil, err
} }
userObj, timeSec, valid := this.vUserSet.GetUser(buffer.Value[:nBytes]) userObj, timeSec, valid := this.vUserSet.Get(buffer.Value[:nBytes])
if !valid { if !valid {
return nil, proxy.ErrorInvalidAuthentication return nil, proxy.ErrorInvalidAuthentication
} }
timestampHash := TimestampHash() timestampHash := TimestampHash()
timestampHash.Write(timeSec.HashBytes()) timestampHash.Write(hashTimestamp(timeSec))
iv := timestampHash.Sum(nil) iv := timestampHash.Sum(nil)
aesStream, err := v2crypto.NewAesDecryptionStream(userObj.ID.CmdKey(), iv) aesStream, err := v2crypto.NewAesDecryptionStream(userObj.ID.CmdKey(), iv)
if err != nil { if err != nil {
@ -223,7 +233,7 @@ func (this *VMessRequest) ToBytes(timestampGenerator RandomTimestampGenerator, b
encryptionEnd += 4 encryptionEnd += 4
timestampHash := md5.New() timestampHash := md5.New()
timestampHash.Write(timestamp.HashBytes()) timestampHash.Write(hashTimestamp(timestamp))
iv := timestampHash.Sum(nil) iv := timestampHash.Sum(nil)
aesStream, err := v2crypto.NewAesEncryptionStream(this.User.ID.CmdKey(), iv) aesStream, err := v2crypto.NewAesEncryptionStream(this.User.ID.CmdKey(), iv)
if err != nil { if err != nil {

View File

@ -17,10 +17,10 @@ import (
) )
type FakeTimestampGenerator struct { type FakeTimestampGenerator struct {
timestamp Timestamp timestamp proto.Timestamp
} }
func (this *FakeTimestampGenerator) Next() Timestamp { func (this *FakeTimestampGenerator) Next() proto.Timestamp {
return this.timestamp return this.timestamp
} }
@ -36,8 +36,8 @@ func TestVMessSerialization(t *testing.T) {
ID: userId, ID: userId,
} }
userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]Timestamp)} userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]proto.Timestamp)}
userSet.AddUser(testUser) userSet.Add(testUser)
request := new(VMessRequest) request := new(VMessRequest)
request.Version = byte(0x01) request.Version = byte(0x01)
@ -54,7 +54,7 @@ func TestVMessSerialization(t *testing.T) {
request.Address = v2net.DomainAddress("v2ray.com") request.Address = v2net.DomainAddress("v2ray.com")
request.Port = v2net.Port(80) request.Port = v2net.Port(80)
mockTime := Timestamp(1823730) mockTime := proto.Timestamp(1823730)
buffer, err := request.ToBytes(&FakeTimestampGenerator{timestamp: mockTime}, nil) buffer, err := request.ToBytes(&FakeTimestampGenerator{timestamp: mockTime}, nil)
if err != nil { if err != nil {
@ -92,12 +92,12 @@ func BenchmarkVMessRequestWriting(b *testing.B) {
assert.Error(err).IsNil() assert.Error(err).IsNil()
userId := proto.NewID(id) userId := proto.NewID(id)
userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]Timestamp)} userSet := protocoltesting.MockUserSet{[]*proto.User{}, make(map[string]int), make(map[string]proto.Timestamp)}
testUser := &proto.User{ testUser := &proto.User{
ID: userId, ID: userId,
} }
userSet.AddUser(testUser) userSet.Add(testUser)
request := new(VMessRequest) request := new(VMessRequest)
request.Version = byte(0x01) request.Version = byte(0x01)
@ -114,6 +114,6 @@ func BenchmarkVMessRequestWriting(b *testing.B) {
request.Port = v2net.Port(80) request.Port = v2net.Port(80)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
request.ToBytes(NewRandomTimestampGenerator(Timestamp(time.Now().Unix()), 30), nil) request.ToBytes(NewRandomTimestampGenerator(proto.Timestamp(time.Now().Unix()), 30), nil)
} }
} }