1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-09-10 05:54:22 -04:00

introduce in-memory user

This commit is contained in:
Darien Raymond 2018-08-27 00:11:32 +02:00
parent b4d065610a
commit 54e1bb96cc
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
22 changed files with 212 additions and 167 deletions

View File

@ -39,7 +39,11 @@ func (op *AddUserOperation) ApplyInbound(ctx context.Context, handler core.Inbou
if !ok {
return newError("proxy is not a UserManager")
}
return um.AddUser(ctx, op.User)
mUser, err := op.User.ToMemoryUser()
if err != nil {
return newError("failed to parse user").Base(err)
}
return um.AddUser(ctx, mUser)
}
// ApplyInbound implements InboundOperation.

View File

@ -126,11 +126,11 @@ func TestRoutingRule(t *testing.T) {
},
test: []ruleTest{
{
input: protocol.ContextWithUser(context.Background(), &protocol.User{Email: "admin@v2ray.com"}),
input: protocol.ContextWithUser(context.Background(), &protocol.MemoryUser{Email: "admin@v2ray.com"}),
output: true,
},
{
input: protocol.ContextWithUser(context.Background(), &protocol.User{Email: "love@v2ray.com"}),
input: protocol.ContextWithUser(context.Background(), &protocol.MemoryUser{Email: "love@v2ray.com"}),
output: false,
},
{

View File

@ -12,17 +12,17 @@ const (
)
// ContextWithUser returns a context combined with a User.
func ContextWithUser(ctx context.Context, user *User) context.Context {
func ContextWithUser(ctx context.Context, user *MemoryUser) context.Context {
return context.WithValue(ctx, userKey, user)
}
// UserFromContext extracts a User from the given context, if any.
func UserFromContext(ctx context.Context) *User {
func UserFromContext(ctx context.Context) *MemoryUser {
v := ctx.Value(userKey)
if v == nil {
return nil
}
return v.(*User)
return v.(*MemoryUser)
}
func ContextWithRequestHeader(ctx context.Context, request *RequestHeader) context.Context {

View File

@ -47,7 +47,7 @@ type RequestHeader struct {
Security SecurityType
Port net.Port
Address net.Address
User *User
User *MemoryUser
}
func (h *RequestHeader) Destination() net.Destination {

View File

@ -46,11 +46,11 @@ func (s *timeoutValidStrategy) Invalidate() {
type ServerSpec struct {
sync.RWMutex
dest net.Destination
users []*User
users []*MemoryUser
valid ValidationStrategy
}
func NewServerSpec(dest net.Destination, valid ValidationStrategy, users ...*User) *ServerSpec {
func NewServerSpec(dest net.Destination, valid ValidationStrategy, users ...*MemoryUser) *ServerSpec {
return &ServerSpec{
dest: dest,
users: users,
@ -58,33 +58,36 @@ func NewServerSpec(dest net.Destination, valid ValidationStrategy, users ...*Use
}
}
func NewServerSpecFromPB(spec ServerEndpoint) *ServerSpec {
func NewServerSpecFromPB(spec ServerEndpoint) (*ServerSpec, error) {
dest := net.TCPDestination(spec.Address.AsAddress(), net.Port(spec.Port))
return NewServerSpec(dest, AlwaysValid(), spec.User...)
mUsers := make([]*MemoryUser, len(spec.User))
for idx, u := range spec.User {
mUser, err := u.ToMemoryUser()
if err != nil {
return nil, err
}
mUsers[idx] = mUser
}
return NewServerSpec(dest, AlwaysValid(), mUsers...), nil
}
func (s *ServerSpec) Destination() net.Destination {
return s.dest
}
func (s *ServerSpec) HasUser(user *User) bool {
func (s *ServerSpec) HasUser(user *MemoryUser) bool {
s.RLock()
defer s.RUnlock()
accountA, err := user.GetTypedAccount()
if err != nil {
return false
}
for _, u := range s.users {
accountB, err := u.GetTypedAccount()
if err == nil && accountA.Equals(accountB) {
if u.Account.Equals(user.Account) {
return true
}
}
return false
}
func (s *ServerSpec) AddUser(user *User) {
func (s *ServerSpec) AddUser(user *MemoryUser) {
if s.HasUser(user) {
return
}
@ -95,7 +98,7 @@ func (s *ServerSpec) AddUser(user *User) {
s.users = append(s.users, user)
}
func (s *ServerSpec) PickUser() *User {
func (s *ServerSpec) PickUser() *MemoryUser {
s.RLock()
defer s.RUnlock()

View File

@ -6,7 +6,6 @@ import (
"v2ray.com/core/common/net"
. "v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial"
"v2ray.com/core/common/uuid"
"v2ray.com/core/proxy/vmess"
. "v2ray.com/ext/assert"
@ -40,26 +39,26 @@ func TestUserInServerSpec(t *testing.T) {
uuid1 := uuid.New()
uuid2 := uuid.New()
spec := NewServerSpec(net.Destination{}, AlwaysValid(), &User{
spec := NewServerSpec(net.Destination{}, AlwaysValid(), &MemoryUser{
Email: "test1@v2ray.com",
Account: serial.ToTypedMessage(&vmess.Account{Id: uuid1.String()}),
Account: &vmess.Account{Id: uuid1.String()},
})
assert(spec.HasUser(&User{
assert(spec.HasUser(&MemoryUser{
Email: "test1@v2ray.com",
Account: serial.ToTypedMessage(&vmess.Account{Id: uuid2.String()}),
Account: &vmess.Account{Id: uuid2.String()},
}), IsFalse)
spec.AddUser(&User{Email: "test2@v2ray.com"})
assert(spec.HasUser(&User{
spec.AddUser(&MemoryUser{Email: "test2@v2ray.com"})
assert(spec.HasUser(&MemoryUser{
Email: "test1@v2ray.com",
Account: serial.ToTypedMessage(&vmess.Account{Id: uuid1.String()}),
Account: &vmess.Account{Id: uuid1.String()},
}), IsTrue)
}
func TestPickUser(t *testing.T) {
assert := With(t)
spec := NewServerSpec(net.Destination{}, AlwaysValid(), &User{Email: "test1@v2ray.com"}, &User{Email: "test2@v2ray.com"}, &User{Email: "test3@v2ray.com"})
spec := NewServerSpec(net.Destination{}, AlwaysValid(), &MemoryUser{Email: "test1@v2ray.com"}, &MemoryUser{Email: "test2@v2ray.com"}, &MemoryUser{Email: "test3@v2ray.com"})
user := spec.PickUser()
assert(user.Email, HasSuffix, "@v2ray.com")
}

View File

@ -17,3 +17,21 @@ func (u *User) GetTypedAccount() (Account, error) {
}
return nil, newError("Unknown account type: ", u.Account.Type)
}
func (u *User) ToMemoryUser() (*MemoryUser, error) {
account, err := u.GetTypedAccount()
if err != nil {
return nil, err
}
return &MemoryUser{
Account: account,
Email: u.Email,
Level: u.Level,
}, nil
}
type MemoryUser struct {
Account Account
Email string
Level uint32
}

View File

@ -2,12 +2,26 @@ package task
import (
"context"
"strings"
"v2ray.com/core/common"
"v2ray.com/core/common/signal/semaphore"
)
type Task func() error
type MultiError []error
func (e MultiError) Error() string {
var r strings.Builder
common.Must2(r.WriteString("multierr: "))
for _, err := range e {
common.Must2(r.WriteString(err.Error()))
common.Must2(r.WriteString(" | "))
}
return r.String()
}
type executionContext struct {
ctx context.Context
tasks []Task
@ -59,20 +73,44 @@ func Parallel(tasks ...Task) ExecutionOption {
}
}
// Sequential runs all tasks sequentially, and returns the first error encountered.Sequential
// Once a task returns an error, the following tasks will not run.
func Sequential(tasks ...Task) ExecutionOption {
return func(c *executionContext) {
if len(tasks) == 0 {
switch len(tasks) {
case 0:
return
}
if len(tasks) == 1 {
case 1:
c.tasks = append(c.tasks, tasks[0])
return
default:
c.tasks = append(c.tasks, func() error {
return execute(tasks...)
})
}
}
}
c.tasks = append(c.tasks, func() error {
return execute(tasks...)
})
func SequentialAll(tasks ...Task) ExecutionOption {
return func(c *executionContext) {
switch len(tasks) {
case 0:
return
case 1:
c.tasks = append(c.tasks, tasks[0])
default:
c.tasks = append(c.tasks, func() error {
var merr MultiError
for _, task := range tasks {
if err := task(); err != nil {
merr = append(merr, err)
}
}
if len(merr) == 0 {
return nil
}
return merr
})
}
}
}

View File

@ -38,7 +38,7 @@ type Dialer interface {
// UserManager is the interface for Inbounds and Outbounds that can manage their users.
type UserManager interface {
// AddUser adds a new user.
AddUser(context.Context, *protocol.User) error
AddUser(context.Context, *protocol.MemoryUser) error
// RemoveUser removes a user by email.
RemoveUser(context.Context, string) error

View File

@ -27,7 +27,11 @@ type Client struct {
func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
serverList := protocol.NewServerList()
for _, rec := range config.Server {
serverList.AddServer(protocol.NewServerSpecFromPB(*rec))
s, err := protocol.NewServerSpecFromPB(*rec)
if err != nil {
return nil, newError("failed to parse server spec").Base(err)
}
serverList.AddServer(s)
}
if serverList.Size() == 0 {
return nil, newError("0 server")
@ -81,11 +85,10 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial
}
user := server.PickUser()
rawAccount, err := user.GetTypedAccount()
if err != nil {
return newError("failed to get a valid user account").AtWarning().Base(err)
account, ok := user.Account.(*MemoryAccount)
if !ok {
return newError("user account is not valid")
}
account := rawAccount.(*MemoryAccount)
request.User = user
if account.OneTimeAuth == Account_Auto || account.OneTimeAuth == Account_Enabled {

View File

@ -27,12 +27,8 @@ var addrParser = protocol.NewAddressParser(
)
// ReadTCPSession reads a Shadowsocks TCP session from the given reader, returns its header and remaining parts.
func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) {
rawAccount, err := user.GetTypedAccount()
if err != nil {
return nil, nil, newError("failed to parse account").Base(err).AtError()
}
account := rawAccount.(*MemoryAccount)
func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) {
account := user.Account.(*MemoryAccount)
buffer := buf.New()
defer buffer.Release()
@ -116,11 +112,7 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
// WriteTCPRequest writes Shadowsocks request into the given writer, and returns a writer for body.
func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) {
user := request.User
rawAccount, err := user.GetTypedAccount()
if err != nil {
return nil, newError("failed to parse account").Base(err).AtError()
}
account := rawAccount.(*MemoryAccount)
account := user.Account.(*MemoryAccount)
if account.Cipher.IsAEAD() {
request.Option.Clear(RequestOptionOneTimeAuth)
@ -167,17 +159,13 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri
return chunkWriter, nil
}
func ReadTCPResponse(user *protocol.User, reader io.Reader) (buf.Reader, error) {
rawAccount, err := user.GetTypedAccount()
if err != nil {
return nil, newError("failed to parse account").Base(err).AtError()
}
account := rawAccount.(*MemoryAccount)
func ReadTCPResponse(user *protocol.MemoryUser, reader io.Reader) (buf.Reader, error) {
account := user.Account.(*MemoryAccount)
var iv []byte
if account.Cipher.IVSize() > 0 {
iv = make([]byte, account.Cipher.IVSize())
if _, err = io.ReadFull(reader, iv); err != nil {
if _, err := io.ReadFull(reader, iv); err != nil {
return nil, newError("failed to read IV").Base(err)
}
}
@ -187,11 +175,7 @@ func ReadTCPResponse(user *protocol.User, reader io.Reader) (buf.Reader, error)
func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) {
user := request.User
rawAccount, err := user.GetTypedAccount()
if err != nil {
return nil, newError("failed to parse account.").Base(err).AtError()
}
account := rawAccount.(*MemoryAccount)
account := user.Account.(*MemoryAccount)
var iv []byte
if account.Cipher.IVSize() > 0 {
@ -207,11 +191,7 @@ func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Wr
func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buffer, error) {
user := request.User
rawAccount, err := user.GetTypedAccount()
if err != nil {
return nil, newError("failed to parse account.").Base(err).AtError()
}
account := rawAccount.(*MemoryAccount)
account := user.Account.(*MemoryAccount)
buffer := buf.New()
ivLen := account.Cipher.IVSize()
@ -239,12 +219,8 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff
return buffer, nil
}
func DecodeUDPPacket(user *protocol.User, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) {
rawAccount, err := user.GetTypedAccount()
if err != nil {
return nil, nil, newError("failed to parse account").Base(err).AtError()
}
account := rawAccount.(*MemoryAccount)
func DecodeUDPPacket(user *protocol.MemoryUser, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) {
account := user.Account.(*MemoryAccount)
var iv []byte
if !account.Cipher.IsAEAD() && account.Cipher.IVSize() > 0 {
@ -306,7 +282,7 @@ func DecodeUDPPacket(user *protocol.User, payload *buf.Buffer) (*protocol.Reques
type UDPReader struct {
Reader io.Reader
User *protocol.User
User *protocol.MemoryUser
}
func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {

View File

@ -12,6 +12,12 @@ import (
. "v2ray.com/ext/assert"
)
func toAccount(a *Account) protocol.Account {
account, err := a.AsAccount()
common.Must(err)
return account
}
func TestUDPEncoding(t *testing.T) {
assert := With(t)
@ -20,9 +26,9 @@ func TestUDPEncoding(t *testing.T) {
Command: protocol.RequestCommandUDP,
Address: net.LocalHostIP,
Port: 1234,
User: &protocol.User{
User: &protocol.MemoryUser{
Email: "love@v2ray.com",
Account: serial.ToTypedMessage(&Account{
Account: toAccount(&Account{
Password: "shadowsocks-password",
CipherType: CipherType_AES_128_CFB,
Ota: Account_Disabled,
@ -57,9 +63,9 @@ func TestTCPRequest(t *testing.T) {
Address: net.LocalHostIP,
Option: RequestOptionOneTimeAuth,
Port: 1234,
User: &protocol.User{
User: &protocol.MemoryUser{
Email: "love@v2ray.com",
Account: serial.ToTypedMessage(&Account{
Account: toAccount(&Account{
Password: "tcp-password",
CipherType: CipherType_CHACHA20,
}),
@ -74,9 +80,9 @@ func TestTCPRequest(t *testing.T) {
Address: net.LocalHostIPv6,
Option: RequestOptionOneTimeAuth,
Port: 1234,
User: &protocol.User{
User: &protocol.MemoryUser{
Email: "love@v2ray.com",
Account: serial.ToTypedMessage(&Account{
Account: toAccount(&Account{
Password: "password",
CipherType: CipherType_AES_256_CFB,
}),
@ -91,9 +97,9 @@ func TestTCPRequest(t *testing.T) {
Address: net.DomainAddress("v2ray.com"),
Option: RequestOptionOneTimeAuth,
Port: 1234,
User: &protocol.User{
User: &protocol.MemoryUser{
Email: "love@v2ray.com",
Account: serial.ToTypedMessage(&Account{
Account: toAccount(&Account{
Password: "password",
CipherType: CipherType_CHACHA20_IETF,
}),
@ -135,8 +141,8 @@ func TestTCPRequest(t *testing.T) {
func TestUDPReaderWriter(t *testing.T) {
assert := With(t)
user := &protocol.User{
Account: serial.ToTypedMessage(&Account{
user := &protocol.MemoryUser{
Account: toAccount(&Account{
Password: "test-password",
CipherType: CipherType_CHACHA20_IETF,
}),

View File

@ -20,10 +20,9 @@ import (
)
type Server struct {
config ServerConfig
user *protocol.User
account *MemoryAccount
v *core.Instance
config ServerConfig
user *protocol.MemoryUser
v *core.Instance
}
// NewServer create a new Shadowsocks server.
@ -32,17 +31,15 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
return nil, newError("user is not specified")
}
rawAccount, err := config.User.GetTypedAccount()
mUser, err := config.User.ToMemoryUser()
if err != nil {
return nil, newError("failed to get user account").Base(err)
return nil, newError("failed to parse user account").Base(err)
}
account := rawAccount.(*MemoryAccount)
s := &Server{
config: *config,
user: config.GetUser(),
account: account,
v: core.MustFromContext(ctx),
config: *config,
user: mUser,
v: core.MustFromContext(ctx),
}
return s, nil
@ -90,6 +87,8 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
conn.Write(data.Bytes())
})
account := s.user.Account.(*MemoryAccount)
reader := buf.NewReader(conn)
for {
mpayload, err := reader.ReadMultiBuffer()
@ -113,13 +112,13 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
continue
}
if request.Option.Has(RequestOptionOneTimeAuth) && s.account.OneTimeAuth == Account_Disabled {
if request.Option.Has(RequestOptionOneTimeAuth) && account.OneTimeAuth == Account_Disabled {
newError("client payload enables OTA but server doesn't allow it").WriteToLog(session.ExportIDToError(ctx))
payload.Release()
continue
}
if !request.Option.Has(RequestOptionOneTimeAuth) && s.account.OneTimeAuth == Account_Enabled {
if !request.Option.Has(RequestOptionOneTimeAuth) && account.OneTimeAuth == Account_Enabled {
newError("client payload disables OTA but server forces it").WriteToLog(session.ExportIDToError(ctx))
payload.Release()
continue

View File

@ -28,7 +28,11 @@ type Client struct {
func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
serverList := protocol.NewServerList()
for _, rec := range config.Server {
serverList.AddServer(protocol.NewServerSpecFromPB(*rec))
s, err := protocol.NewServerSpecFromPB(*rec)
if err != nil {
return nil, newError("failed to get server spec").Base(err)
}
serverList.AddServer(s)
}
if serverList.Size() == 0 {
return nil, newError("0 target server")

View File

@ -350,11 +350,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
common.Must2(b.WriteBytes(socks5Version, 0x01, authByte))
if authByte == authPassword {
rawAccount, err := request.User.GetTypedAccount()
if err != nil {
return nil, err
}
account := rawAccount.(*Account)
account := request.User.Account.(*Account)
common.Must2(b.WriteBytes(0x01, byte(len(account.Username))))
common.Must2(b.Write([]byte(account.Username)))

View File

@ -58,11 +58,8 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession {
func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error {
timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)()
account, err := header.User.GetTypedAccount()
if err != nil {
return newError("failed to get user account: ", err).AtError()
}
idHash := c.idHash(account.(*vmess.InternalAccount).AnyValidID().Bytes())
account := header.User.Account.(*vmess.InternalAccount)
idHash := c.idHash(account.AnyValidID().Bytes())
common.Must2(idHash.Write(timestamp.Bytes(nil)))
common.Must2(writer.Write(idHash.Sum(nil)))
@ -97,7 +94,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
timestampHash := md5.New()
common.Must2(timestampHash.Write(hashTimestamp(timestamp)))
iv := timestampHash.Sum(nil)
aesStream := crypto.NewAesEncryptionStream(account.(*vmess.InternalAccount).ID.CmdKey(), iv)
aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv)
aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
common.Must2(writer.Write(buffer.Bytes()))
return nil

View File

@ -7,17 +7,22 @@ import (
"v2ray.com/core/common/buf"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial"
"v2ray.com/core/common/uuid"
"v2ray.com/core/proxy/vmess"
. "v2ray.com/core/proxy/vmess/encoding"
. "v2ray.com/ext/assert"
)
func toAccount(a *vmess.Account) protocol.Account {
account, err := a.AsAccount()
common.Must(err)
return account
}
func TestRequestSerialization(t *testing.T) {
assert := With(t)
user := &protocol.User{
user := &protocol.MemoryUser{
Level: 0,
Email: "test@v2ray.com",
}
@ -26,7 +31,7 @@ func TestRequestSerialization(t *testing.T) {
Id: id.String(),
AlterId: 0,
}
user.Account = serial.ToTypedMessage(account)
user.Account = toAccount(account)
expectedRequest := &protocol.RequestHeader{
Version: 1,
@ -70,7 +75,7 @@ func TestRequestSerialization(t *testing.T) {
func TestInvalidRequest(t *testing.T) {
assert := With(t)
user := &protocol.User{
user := &protocol.MemoryUser{
Level: 0,
Email: "test@v2ray.com",
}
@ -79,7 +84,7 @@ func TestInvalidRequest(t *testing.T) {
Id: id.String(),
AlterId: 0,
}
user.Account = serial.ToTypedMessage(account)
user.Account = toAccount(account)
expectedRequest := &protocol.RequestHeader{
Version: 1,
@ -112,7 +117,7 @@ func TestInvalidRequest(t *testing.T) {
func TestMuxRequest(t *testing.T) {
assert := With(t)
user := &protocol.User{
user := &protocol.MemoryUser{
Level: 0,
Email: "test@v2ray.com",
}
@ -121,7 +126,7 @@ func TestMuxRequest(t *testing.T) {
Id: id.String(),
AlterId: 0,
}
user.Account = serial.ToTypedMessage(account)
user.Account = toAccount(account)
expectedRequest := &protocol.RequestHeader{
Version: 1,

View File

@ -139,11 +139,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
timestampHash := md5.New()
common.Must2(timestampHash.Write(hashTimestamp(timestamp)))
iv := timestampHash.Sum(nil)
account, err := user.GetTypedAccount()
if err != nil {
return nil, newError("failed to get user account").Base(err)
}
vmessAccount := account.(*vmess.InternalAccount)
vmessAccount := user.Account.(*vmess.InternalAccount)
aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv)
decryptor := crypto.NewCryptionReader(aesStream, reader)

View File

@ -16,7 +16,6 @@ import (
"v2ray.com/core/common/log"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial"
"v2ray.com/core/common/session"
"v2ray.com/core/common/signal"
"v2ray.com/core/common/task"
@ -29,20 +28,20 @@ import (
type userByEmail struct {
sync.Mutex
cache map[string]*protocol.User
cache map[string]*protocol.MemoryUser
defaultLevel uint32
defaultAlterIDs uint16
}
func newUserByEmail(config *DefaultConfig) *userByEmail {
return &userByEmail{
cache: make(map[string]*protocol.User),
cache: make(map[string]*protocol.MemoryUser),
defaultLevel: config.Level,
defaultAlterIDs: uint16(config.AlterId),
}
}
func (v *userByEmail) addNoLock(u *protocol.User) bool {
func (v *userByEmail) addNoLock(u *protocol.MemoryUser) bool {
email := strings.ToLower(u.Email)
user, found := v.cache[email]
if found {
@ -52,14 +51,14 @@ func (v *userByEmail) addNoLock(u *protocol.User) bool {
return true
}
func (v *userByEmail) Add(u *protocol.User) bool {
func (v *userByEmail) Add(u *protocol.MemoryUser) bool {
v.Lock()
defer v.Unlock()
return v.addNoLock(u)
}
func (v *userByEmail) Get(email string) (*protocol.User, bool) {
func (v *userByEmail) Get(email string) (*protocol.MemoryUser, bool) {
email = strings.ToLower(email)
v.Lock()
@ -68,14 +67,16 @@ func (v *userByEmail) Get(email string) (*protocol.User, bool) {
user, found := v.cache[email]
if !found {
id := uuid.New()
account := &vmess.Account{
rawAccount := &vmess.Account{
Id: id.String(),
AlterId: uint32(v.defaultAlterIDs),
}
user = &protocol.User{
account, err := rawAccount.AsAccount()
common.Must(err)
user = &protocol.MemoryUser{
Level: v.defaultLevel,
Email: email,
Account: serial.ToTypedMessage(account),
Account: account,
}
v.cache[email] = user
}
@ -120,7 +121,12 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
}
for _, user := range config.User {
if err := handler.AddUser(ctx, user); err != nil {
mUser, err := user.ToMemoryUser()
if err != nil {
return nil, newError("failed to get VMess user").Base(err)
}
if err := handler.AddUser(ctx, mUser); err != nil {
return nil, newError("failed to initiate user").Base(err)
}
}
@ -130,10 +136,9 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
// Close implements common.Closable.
func (h *Handler) Close() error {
common.Close(h.clients)
common.Close(h.sessionHistory)
common.Close(h.usersByEmail)
return nil
return task.Run(
task.SequentialAll(
task.Close(h.clients), task.Close(h.sessionHistory), task.Close(h.usersByEmail)))()
}
// Network implements proxy.Inbound.Network().
@ -143,7 +148,7 @@ func (*Handler) Network() net.NetworkList {
}
}
func (h *Handler) GetUser(email string) *protocol.User {
func (h *Handler) GetUser(email string) *protocol.MemoryUser {
user, existing := h.usersByEmail.Get(email)
if !existing {
h.clients.Add(user)
@ -151,7 +156,7 @@ func (h *Handler) GetUser(email string) *protocol.User {
return user
}
func (h *Handler) AddUser(ctx context.Context, user *protocol.User) error {
func (h *Handler) AddUser(ctx context.Context, user *protocol.MemoryUser) error {
if len(user.Email) > 0 && !h.usersByEmail.Add(user) {
return newError("User ", user.Email, " already exists.")
}
@ -325,11 +330,11 @@ func (h *Handler) generateCommand(ctx context.Context, request *protocol.Request
if user == nil {
return nil
}
account, _ := user.GetTypedAccount()
account := user.Account.(*vmess.InternalAccount)
return &protocol.CommandSwitchAccount{
Port: port,
ID: account.(*vmess.InternalAccount).ID.UUID(),
AlterIds: uint16(len(account.(*vmess.InternalAccount).AlterIDs)),
ID: account.ID.UUID(),
AlterIds: uint16(len(account.AlterIDs)),
Level: user.Level,
ValidMin: byte(availableMin),
}

View File

@ -3,14 +3,14 @@ package outbound
import (
"time"
"v2ray.com/core/common"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial"
"v2ray.com/core/proxy/vmess"
)
func (h *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) {
account := &vmess.Account{
rawAccount := &vmess.Account{
Id: cmd.ID.String(),
AlterId: uint32(cmd.AlterIds),
SecuritySettings: &protocol.SecurityConfig{
@ -18,10 +18,12 @@ func (h *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) {
},
}
user := &protocol.User{
account, err := rawAccount.AsAccount()
common.Must(err)
user := &protocol.MemoryUser{
Email: "",
Level: cmd.Level,
Account: serial.ToTypedMessage(account),
Account: account,
}
dest := net.TCPDestination(cmd.Host, cmd.Port)
until := time.Now().Add(time.Duration(cmd.ValidMin) * time.Minute)

View File

@ -33,7 +33,11 @@ type Handler struct {
func New(ctx context.Context, config *Config) (*Handler, error) {
serverList := protocol.NewServerList()
for _, rec := range config.Receiver {
serverList.AddServer(protocol.NewServerSpecFromPB(*rec))
s, err := protocol.NewServerSpecFromPB(*rec)
if err != nil {
return nil, newError("failed to parse server spec").Base(err)
}
serverList.AddServer(s)
}
handler := &Handler{
serverList: serverList,
@ -87,11 +91,7 @@ func (v *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia
Option: protocol.RequestOptionChunkStream,
}
rawAccount, err := request.User.GetTypedAccount()
if err != nil {
return newError("failed to get user account").Base(err).AtWarning()
}
account := rawAccount.(*vmess.InternalAccount)
account := request.User.Account.(*vmess.InternalAccount)
request.Security = account.Security
if request.Security == protocol.SecurityType_AES128_GCM || request.Security == protocol.SecurityType_NONE || request.Security == protocol.SecurityType_CHACHA20_POLY1305 {

View File

@ -23,8 +23,7 @@ const (
)
type user struct {
user *protocol.User
account *InternalAccount
user *protocol.MemoryUser
lastSec protocol.Timestamp
}
@ -80,8 +79,10 @@ func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *
}
}
genHashForID(user.account.ID)
for _, id := range user.account.AlterIDs {
account := user.user.Account.(*InternalAccount)
genHashForID(account.ID)
for _, id := range account.AlterIDs {
genHashForID(id)
}
user.lastSec = nowSec
@ -111,21 +112,14 @@ func (v *TimedUserValidator) updateUserHash() {
}
}
func (v *TimedUserValidator) Add(u *protocol.User) error {
func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error {
v.Lock()
defer v.Unlock()
rawAccount, err := u.GetTypedAccount()
if err != nil {
return err
}
account := rawAccount.(*InternalAccount)
nowSec := time.Now().Unix()
uu := &user{
user: u,
account: account,
lastSec: protocol.Timestamp(nowSec - cacheDurationSec),
}
v.users = append(v.users, uu)
@ -134,7 +128,7 @@ func (v *TimedUserValidator) Add(u *protocol.User) error {
return nil
}
func (v *TimedUserValidator) Get(userHash []byte) (*protocol.User, protocol.Timestamp, bool) {
func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool) {
defer v.RUnlock()
v.RLock()