1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-06-27 17:55:23 +00:00

introduce in-memory user

This commit is contained in:
Darien Raymond 2018-08-27 00:11:32 +02:00
parent b4d065610a
commit 54e1bb96cc
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
22 changed files with 212 additions and 167 deletions

View File

@ -39,7 +39,11 @@ func (op *AddUserOperation) ApplyInbound(ctx context.Context, handler core.Inbou
if !ok { if !ok {
return newError("proxy is not a UserManager") return newError("proxy is not a UserManager")
} }
return um.AddUser(ctx, op.User) mUser, err := op.User.ToMemoryUser()
if err != nil {
return newError("failed to parse user").Base(err)
}
return um.AddUser(ctx, mUser)
} }
// ApplyInbound implements InboundOperation. // ApplyInbound implements InboundOperation.

View File

@ -126,11 +126,11 @@ func TestRoutingRule(t *testing.T) {
}, },
test: []ruleTest{ test: []ruleTest{
{ {
input: protocol.ContextWithUser(context.Background(), &protocol.User{Email: "admin@v2ray.com"}), input: protocol.ContextWithUser(context.Background(), &protocol.MemoryUser{Email: "admin@v2ray.com"}),
output: true, output: true,
}, },
{ {
input: protocol.ContextWithUser(context.Background(), &protocol.User{Email: "love@v2ray.com"}), input: protocol.ContextWithUser(context.Background(), &protocol.MemoryUser{Email: "love@v2ray.com"}),
output: false, output: false,
}, },
{ {

View File

@ -12,17 +12,17 @@ const (
) )
// ContextWithUser returns a context combined with a User. // ContextWithUser returns a context combined with a User.
func ContextWithUser(ctx context.Context, user *User) context.Context { func ContextWithUser(ctx context.Context, user *MemoryUser) context.Context {
return context.WithValue(ctx, userKey, user) return context.WithValue(ctx, userKey, user)
} }
// UserFromContext extracts a User from the given context, if any. // UserFromContext extracts a User from the given context, if any.
func UserFromContext(ctx context.Context) *User { func UserFromContext(ctx context.Context) *MemoryUser {
v := ctx.Value(userKey) v := ctx.Value(userKey)
if v == nil { if v == nil {
return nil return nil
} }
return v.(*User) return v.(*MemoryUser)
} }
func ContextWithRequestHeader(ctx context.Context, request *RequestHeader) context.Context { func ContextWithRequestHeader(ctx context.Context, request *RequestHeader) context.Context {

View File

@ -47,7 +47,7 @@ type RequestHeader struct {
Security SecurityType Security SecurityType
Port net.Port Port net.Port
Address net.Address Address net.Address
User *User User *MemoryUser
} }
func (h *RequestHeader) Destination() net.Destination { func (h *RequestHeader) Destination() net.Destination {

View File

@ -46,11 +46,11 @@ func (s *timeoutValidStrategy) Invalidate() {
type ServerSpec struct { type ServerSpec struct {
sync.RWMutex sync.RWMutex
dest net.Destination dest net.Destination
users []*User users []*MemoryUser
valid ValidationStrategy valid ValidationStrategy
} }
func NewServerSpec(dest net.Destination, valid ValidationStrategy, users ...*User) *ServerSpec { func NewServerSpec(dest net.Destination, valid ValidationStrategy, users ...*MemoryUser) *ServerSpec {
return &ServerSpec{ return &ServerSpec{
dest: dest, dest: dest,
users: users, users: users,
@ -58,33 +58,36 @@ func NewServerSpec(dest net.Destination, valid ValidationStrategy, users ...*Use
} }
} }
func NewServerSpecFromPB(spec ServerEndpoint) *ServerSpec { func NewServerSpecFromPB(spec ServerEndpoint) (*ServerSpec, error) {
dest := net.TCPDestination(spec.Address.AsAddress(), net.Port(spec.Port)) dest := net.TCPDestination(spec.Address.AsAddress(), net.Port(spec.Port))
return NewServerSpec(dest, AlwaysValid(), spec.User...) mUsers := make([]*MemoryUser, len(spec.User))
for idx, u := range spec.User {
mUser, err := u.ToMemoryUser()
if err != nil {
return nil, err
}
mUsers[idx] = mUser
}
return NewServerSpec(dest, AlwaysValid(), mUsers...), nil
} }
func (s *ServerSpec) Destination() net.Destination { func (s *ServerSpec) Destination() net.Destination {
return s.dest return s.dest
} }
func (s *ServerSpec) HasUser(user *User) bool { func (s *ServerSpec) HasUser(user *MemoryUser) bool {
s.RLock() s.RLock()
defer s.RUnlock() defer s.RUnlock()
accountA, err := user.GetTypedAccount()
if err != nil {
return false
}
for _, u := range s.users { for _, u := range s.users {
accountB, err := u.GetTypedAccount() if u.Account.Equals(user.Account) {
if err == nil && accountA.Equals(accountB) {
return true return true
} }
} }
return false return false
} }
func (s *ServerSpec) AddUser(user *User) { func (s *ServerSpec) AddUser(user *MemoryUser) {
if s.HasUser(user) { if s.HasUser(user) {
return return
} }
@ -95,7 +98,7 @@ func (s *ServerSpec) AddUser(user *User) {
s.users = append(s.users, user) s.users = append(s.users, user)
} }
func (s *ServerSpec) PickUser() *User { func (s *ServerSpec) PickUser() *MemoryUser {
s.RLock() s.RLock()
defer s.RUnlock() defer s.RUnlock()

View File

@ -6,7 +6,6 @@ import (
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
. "v2ray.com/core/common/protocol" . "v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial"
"v2ray.com/core/common/uuid" "v2ray.com/core/common/uuid"
"v2ray.com/core/proxy/vmess" "v2ray.com/core/proxy/vmess"
. "v2ray.com/ext/assert" . "v2ray.com/ext/assert"
@ -40,26 +39,26 @@ func TestUserInServerSpec(t *testing.T) {
uuid1 := uuid.New() uuid1 := uuid.New()
uuid2 := uuid.New() uuid2 := uuid.New()
spec := NewServerSpec(net.Destination{}, AlwaysValid(), &User{ spec := NewServerSpec(net.Destination{}, AlwaysValid(), &MemoryUser{
Email: "test1@v2ray.com", Email: "test1@v2ray.com",
Account: serial.ToTypedMessage(&vmess.Account{Id: uuid1.String()}), Account: &vmess.Account{Id: uuid1.String()},
}) })
assert(spec.HasUser(&User{ assert(spec.HasUser(&MemoryUser{
Email: "test1@v2ray.com", Email: "test1@v2ray.com",
Account: serial.ToTypedMessage(&vmess.Account{Id: uuid2.String()}), Account: &vmess.Account{Id: uuid2.String()},
}), IsFalse) }), IsFalse)
spec.AddUser(&User{Email: "test2@v2ray.com"}) spec.AddUser(&MemoryUser{Email: "test2@v2ray.com"})
assert(spec.HasUser(&User{ assert(spec.HasUser(&MemoryUser{
Email: "test1@v2ray.com", Email: "test1@v2ray.com",
Account: serial.ToTypedMessage(&vmess.Account{Id: uuid1.String()}), Account: &vmess.Account{Id: uuid1.String()},
}), IsTrue) }), IsTrue)
} }
func TestPickUser(t *testing.T) { func TestPickUser(t *testing.T) {
assert := With(t) assert := With(t)
spec := NewServerSpec(net.Destination{}, AlwaysValid(), &User{Email: "test1@v2ray.com"}, &User{Email: "test2@v2ray.com"}, &User{Email: "test3@v2ray.com"}) spec := NewServerSpec(net.Destination{}, AlwaysValid(), &MemoryUser{Email: "test1@v2ray.com"}, &MemoryUser{Email: "test2@v2ray.com"}, &MemoryUser{Email: "test3@v2ray.com"})
user := spec.PickUser() user := spec.PickUser()
assert(user.Email, HasSuffix, "@v2ray.com") assert(user.Email, HasSuffix, "@v2ray.com")
} }

View File

@ -17,3 +17,21 @@ func (u *User) GetTypedAccount() (Account, error) {
} }
return nil, newError("Unknown account type: ", u.Account.Type) return nil, newError("Unknown account type: ", u.Account.Type)
} }
func (u *User) ToMemoryUser() (*MemoryUser, error) {
account, err := u.GetTypedAccount()
if err != nil {
return nil, err
}
return &MemoryUser{
Account: account,
Email: u.Email,
Level: u.Level,
}, nil
}
type MemoryUser struct {
Account Account
Email string
Level uint32
}

View File

@ -2,12 +2,26 @@ package task
import ( import (
"context" "context"
"strings"
"v2ray.com/core/common"
"v2ray.com/core/common/signal/semaphore" "v2ray.com/core/common/signal/semaphore"
) )
type Task func() error type Task func() error
type MultiError []error
func (e MultiError) Error() string {
var r strings.Builder
common.Must2(r.WriteString("multierr: "))
for _, err := range e {
common.Must2(r.WriteString(err.Error()))
common.Must2(r.WriteString(" | "))
}
return r.String()
}
type executionContext struct { type executionContext struct {
ctx context.Context ctx context.Context
tasks []Task tasks []Task
@ -59,20 +73,44 @@ func Parallel(tasks ...Task) ExecutionOption {
} }
} }
// Sequential runs all tasks sequentially, and returns the first error encountered.Sequential
// Once a task returns an error, the following tasks will not run.
func Sequential(tasks ...Task) ExecutionOption { func Sequential(tasks ...Task) ExecutionOption {
return func(c *executionContext) { return func(c *executionContext) {
if len(tasks) == 0 { switch len(tasks) {
case 0:
return return
} case 1:
if len(tasks) == 1 {
c.tasks = append(c.tasks, tasks[0]) c.tasks = append(c.tasks, tasks[0])
return default:
c.tasks = append(c.tasks, func() error {
return execute(tasks...)
})
} }
}
}
c.tasks = append(c.tasks, func() error { func SequentialAll(tasks ...Task) ExecutionOption {
return execute(tasks...) return func(c *executionContext) {
}) switch len(tasks) {
case 0:
return
case 1:
c.tasks = append(c.tasks, tasks[0])
default:
c.tasks = append(c.tasks, func() error {
var merr MultiError
for _, task := range tasks {
if err := task(); err != nil {
merr = append(merr, err)
}
}
if len(merr) == 0 {
return nil
}
return merr
})
}
} }
} }

View File

@ -38,7 +38,7 @@ type Dialer interface {
// UserManager is the interface for Inbounds and Outbounds that can manage their users. // UserManager is the interface for Inbounds and Outbounds that can manage their users.
type UserManager interface { type UserManager interface {
// AddUser adds a new user. // AddUser adds a new user.
AddUser(context.Context, *protocol.User) error AddUser(context.Context, *protocol.MemoryUser) error
// RemoveUser removes a user by email. // RemoveUser removes a user by email.
RemoveUser(context.Context, string) error RemoveUser(context.Context, string) error

View File

@ -27,7 +27,11 @@ type Client struct {
func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
serverList := protocol.NewServerList() serverList := protocol.NewServerList()
for _, rec := range config.Server { for _, rec := range config.Server {
serverList.AddServer(protocol.NewServerSpecFromPB(*rec)) s, err := protocol.NewServerSpecFromPB(*rec)
if err != nil {
return nil, newError("failed to parse server spec").Base(err)
}
serverList.AddServer(s)
} }
if serverList.Size() == 0 { if serverList.Size() == 0 {
return nil, newError("0 server") return nil, newError("0 server")
@ -81,11 +85,10 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial
} }
user := server.PickUser() user := server.PickUser()
rawAccount, err := user.GetTypedAccount() account, ok := user.Account.(*MemoryAccount)
if err != nil { if !ok {
return newError("failed to get a valid user account").AtWarning().Base(err) return newError("user account is not valid")
} }
account := rawAccount.(*MemoryAccount)
request.User = user request.User = user
if account.OneTimeAuth == Account_Auto || account.OneTimeAuth == Account_Enabled { if account.OneTimeAuth == Account_Auto || account.OneTimeAuth == Account_Enabled {

View File

@ -27,12 +27,8 @@ var addrParser = protocol.NewAddressParser(
) )
// ReadTCPSession reads a Shadowsocks TCP session from the given reader, returns its header and remaining parts. // ReadTCPSession reads a Shadowsocks TCP session from the given reader, returns its header and remaining parts.
func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) { func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) {
rawAccount, err := user.GetTypedAccount() account := user.Account.(*MemoryAccount)
if err != nil {
return nil, nil, newError("failed to parse account").Base(err).AtError()
}
account := rawAccount.(*MemoryAccount)
buffer := buf.New() buffer := buf.New()
defer buffer.Release() defer buffer.Release()
@ -116,11 +112,7 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
// WriteTCPRequest writes Shadowsocks request into the given writer, and returns a writer for body. // WriteTCPRequest writes Shadowsocks request into the given writer, and returns a writer for body.
func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) { func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) {
user := request.User user := request.User
rawAccount, err := user.GetTypedAccount() account := user.Account.(*MemoryAccount)
if err != nil {
return nil, newError("failed to parse account").Base(err).AtError()
}
account := rawAccount.(*MemoryAccount)
if account.Cipher.IsAEAD() { if account.Cipher.IsAEAD() {
request.Option.Clear(RequestOptionOneTimeAuth) request.Option.Clear(RequestOptionOneTimeAuth)
@ -167,17 +159,13 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri
return chunkWriter, nil return chunkWriter, nil
} }
func ReadTCPResponse(user *protocol.User, reader io.Reader) (buf.Reader, error) { func ReadTCPResponse(user *protocol.MemoryUser, reader io.Reader) (buf.Reader, error) {
rawAccount, err := user.GetTypedAccount() account := user.Account.(*MemoryAccount)
if err != nil {
return nil, newError("failed to parse account").Base(err).AtError()
}
account := rawAccount.(*MemoryAccount)
var iv []byte var iv []byte
if account.Cipher.IVSize() > 0 { if account.Cipher.IVSize() > 0 {
iv = make([]byte, account.Cipher.IVSize()) iv = make([]byte, account.Cipher.IVSize())
if _, err = io.ReadFull(reader, iv); err != nil { if _, err := io.ReadFull(reader, iv); err != nil {
return nil, newError("failed to read IV").Base(err) return nil, newError("failed to read IV").Base(err)
} }
} }
@ -187,11 +175,7 @@ func ReadTCPResponse(user *protocol.User, reader io.Reader) (buf.Reader, error)
func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) { func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) {
user := request.User user := request.User
rawAccount, err := user.GetTypedAccount() account := user.Account.(*MemoryAccount)
if err != nil {
return nil, newError("failed to parse account.").Base(err).AtError()
}
account := rawAccount.(*MemoryAccount)
var iv []byte var iv []byte
if account.Cipher.IVSize() > 0 { if account.Cipher.IVSize() > 0 {
@ -207,11 +191,7 @@ func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Wr
func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buffer, error) { func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buffer, error) {
user := request.User user := request.User
rawAccount, err := user.GetTypedAccount() account := user.Account.(*MemoryAccount)
if err != nil {
return nil, newError("failed to parse account.").Base(err).AtError()
}
account := rawAccount.(*MemoryAccount)
buffer := buf.New() buffer := buf.New()
ivLen := account.Cipher.IVSize() ivLen := account.Cipher.IVSize()
@ -239,12 +219,8 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff
return buffer, nil return buffer, nil
} }
func DecodeUDPPacket(user *protocol.User, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) { func DecodeUDPPacket(user *protocol.MemoryUser, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) {
rawAccount, err := user.GetTypedAccount() account := user.Account.(*MemoryAccount)
if err != nil {
return nil, nil, newError("failed to parse account").Base(err).AtError()
}
account := rawAccount.(*MemoryAccount)
var iv []byte var iv []byte
if !account.Cipher.IsAEAD() && account.Cipher.IVSize() > 0 { if !account.Cipher.IsAEAD() && account.Cipher.IVSize() > 0 {
@ -306,7 +282,7 @@ func DecodeUDPPacket(user *protocol.User, payload *buf.Buffer) (*protocol.Reques
type UDPReader struct { type UDPReader struct {
Reader io.Reader Reader io.Reader
User *protocol.User User *protocol.MemoryUser
} }
func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {

View File

@ -12,6 +12,12 @@ import (
. "v2ray.com/ext/assert" . "v2ray.com/ext/assert"
) )
func toAccount(a *Account) protocol.Account {
account, err := a.AsAccount()
common.Must(err)
return account
}
func TestUDPEncoding(t *testing.T) { func TestUDPEncoding(t *testing.T) {
assert := With(t) assert := With(t)
@ -20,9 +26,9 @@ func TestUDPEncoding(t *testing.T) {
Command: protocol.RequestCommandUDP, Command: protocol.RequestCommandUDP,
Address: net.LocalHostIP, Address: net.LocalHostIP,
Port: 1234, Port: 1234,
User: &protocol.User{ User: &protocol.MemoryUser{
Email: "love@v2ray.com", Email: "love@v2ray.com",
Account: serial.ToTypedMessage(&Account{ Account: toAccount(&Account{
Password: "shadowsocks-password", Password: "shadowsocks-password",
CipherType: CipherType_AES_128_CFB, CipherType: CipherType_AES_128_CFB,
Ota: Account_Disabled, Ota: Account_Disabled,
@ -57,9 +63,9 @@ func TestTCPRequest(t *testing.T) {
Address: net.LocalHostIP, Address: net.LocalHostIP,
Option: RequestOptionOneTimeAuth, Option: RequestOptionOneTimeAuth,
Port: 1234, Port: 1234,
User: &protocol.User{ User: &protocol.MemoryUser{
Email: "love@v2ray.com", Email: "love@v2ray.com",
Account: serial.ToTypedMessage(&Account{ Account: toAccount(&Account{
Password: "tcp-password", Password: "tcp-password",
CipherType: CipherType_CHACHA20, CipherType: CipherType_CHACHA20,
}), }),
@ -74,9 +80,9 @@ func TestTCPRequest(t *testing.T) {
Address: net.LocalHostIPv6, Address: net.LocalHostIPv6,
Option: RequestOptionOneTimeAuth, Option: RequestOptionOneTimeAuth,
Port: 1234, Port: 1234,
User: &protocol.User{ User: &protocol.MemoryUser{
Email: "love@v2ray.com", Email: "love@v2ray.com",
Account: serial.ToTypedMessage(&Account{ Account: toAccount(&Account{
Password: "password", Password: "password",
CipherType: CipherType_AES_256_CFB, CipherType: CipherType_AES_256_CFB,
}), }),
@ -91,9 +97,9 @@ func TestTCPRequest(t *testing.T) {
Address: net.DomainAddress("v2ray.com"), Address: net.DomainAddress("v2ray.com"),
Option: RequestOptionOneTimeAuth, Option: RequestOptionOneTimeAuth,
Port: 1234, Port: 1234,
User: &protocol.User{ User: &protocol.MemoryUser{
Email: "love@v2ray.com", Email: "love@v2ray.com",
Account: serial.ToTypedMessage(&Account{ Account: toAccount(&Account{
Password: "password", Password: "password",
CipherType: CipherType_CHACHA20_IETF, CipherType: CipherType_CHACHA20_IETF,
}), }),
@ -135,8 +141,8 @@ func TestTCPRequest(t *testing.T) {
func TestUDPReaderWriter(t *testing.T) { func TestUDPReaderWriter(t *testing.T) {
assert := With(t) assert := With(t)
user := &protocol.User{ user := &protocol.MemoryUser{
Account: serial.ToTypedMessage(&Account{ Account: toAccount(&Account{
Password: "test-password", Password: "test-password",
CipherType: CipherType_CHACHA20_IETF, CipherType: CipherType_CHACHA20_IETF,
}), }),

View File

@ -20,10 +20,9 @@ import (
) )
type Server struct { type Server struct {
config ServerConfig config ServerConfig
user *protocol.User user *protocol.MemoryUser
account *MemoryAccount v *core.Instance
v *core.Instance
} }
// NewServer create a new Shadowsocks server. // NewServer create a new Shadowsocks server.
@ -32,17 +31,15 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
return nil, newError("user is not specified") return nil, newError("user is not specified")
} }
rawAccount, err := config.User.GetTypedAccount() mUser, err := config.User.ToMemoryUser()
if err != nil { if err != nil {
return nil, newError("failed to get user account").Base(err) return nil, newError("failed to parse user account").Base(err)
} }
account := rawAccount.(*MemoryAccount)
s := &Server{ s := &Server{
config: *config, config: *config,
user: config.GetUser(), user: mUser,
account: account, v: core.MustFromContext(ctx),
v: core.MustFromContext(ctx),
} }
return s, nil return s, nil
@ -90,6 +87,8 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
conn.Write(data.Bytes()) conn.Write(data.Bytes())
}) })
account := s.user.Account.(*MemoryAccount)
reader := buf.NewReader(conn) reader := buf.NewReader(conn)
for { for {
mpayload, err := reader.ReadMultiBuffer() mpayload, err := reader.ReadMultiBuffer()
@ -113,13 +112,13 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
continue continue
} }
if request.Option.Has(RequestOptionOneTimeAuth) && s.account.OneTimeAuth == Account_Disabled { if request.Option.Has(RequestOptionOneTimeAuth) && account.OneTimeAuth == Account_Disabled {
newError("client payload enables OTA but server doesn't allow it").WriteToLog(session.ExportIDToError(ctx)) newError("client payload enables OTA but server doesn't allow it").WriteToLog(session.ExportIDToError(ctx))
payload.Release() payload.Release()
continue continue
} }
if !request.Option.Has(RequestOptionOneTimeAuth) && s.account.OneTimeAuth == Account_Enabled { if !request.Option.Has(RequestOptionOneTimeAuth) && account.OneTimeAuth == Account_Enabled {
newError("client payload disables OTA but server forces it").WriteToLog(session.ExportIDToError(ctx)) newError("client payload disables OTA but server forces it").WriteToLog(session.ExportIDToError(ctx))
payload.Release() payload.Release()
continue continue

View File

@ -28,7 +28,11 @@ type Client struct {
func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
serverList := protocol.NewServerList() serverList := protocol.NewServerList()
for _, rec := range config.Server { for _, rec := range config.Server {
serverList.AddServer(protocol.NewServerSpecFromPB(*rec)) s, err := protocol.NewServerSpecFromPB(*rec)
if err != nil {
return nil, newError("failed to get server spec").Base(err)
}
serverList.AddServer(s)
} }
if serverList.Size() == 0 { if serverList.Size() == 0 {
return nil, newError("0 target server") return nil, newError("0 target server")

View File

@ -350,11 +350,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
common.Must2(b.WriteBytes(socks5Version, 0x01, authByte)) common.Must2(b.WriteBytes(socks5Version, 0x01, authByte))
if authByte == authPassword { if authByte == authPassword {
rawAccount, err := request.User.GetTypedAccount() account := request.User.Account.(*Account)
if err != nil {
return nil, err
}
account := rawAccount.(*Account)
common.Must2(b.WriteBytes(0x01, byte(len(account.Username)))) common.Must2(b.WriteBytes(0x01, byte(len(account.Username))))
common.Must2(b.Write([]byte(account.Username))) common.Must2(b.Write([]byte(account.Username)))

View File

@ -58,11 +58,8 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession {
func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error { func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error {
timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)() timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)()
account, err := header.User.GetTypedAccount() account := header.User.Account.(*vmess.InternalAccount)
if err != nil { idHash := c.idHash(account.AnyValidID().Bytes())
return newError("failed to get user account: ", err).AtError()
}
idHash := c.idHash(account.(*vmess.InternalAccount).AnyValidID().Bytes())
common.Must2(idHash.Write(timestamp.Bytes(nil))) common.Must2(idHash.Write(timestamp.Bytes(nil)))
common.Must2(writer.Write(idHash.Sum(nil))) common.Must2(writer.Write(idHash.Sum(nil)))
@ -97,7 +94,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
timestampHash := md5.New() timestampHash := md5.New()
common.Must2(timestampHash.Write(hashTimestamp(timestamp))) common.Must2(timestampHash.Write(hashTimestamp(timestamp)))
iv := timestampHash.Sum(nil) iv := timestampHash.Sum(nil)
aesStream := crypto.NewAesEncryptionStream(account.(*vmess.InternalAccount).ID.CmdKey(), iv) aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv)
aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
common.Must2(writer.Write(buffer.Bytes())) common.Must2(writer.Write(buffer.Bytes()))
return nil return nil

View File

@ -7,17 +7,22 @@ import (
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial"
"v2ray.com/core/common/uuid" "v2ray.com/core/common/uuid"
"v2ray.com/core/proxy/vmess" "v2ray.com/core/proxy/vmess"
. "v2ray.com/core/proxy/vmess/encoding" . "v2ray.com/core/proxy/vmess/encoding"
. "v2ray.com/ext/assert" . "v2ray.com/ext/assert"
) )
func toAccount(a *vmess.Account) protocol.Account {
account, err := a.AsAccount()
common.Must(err)
return account
}
func TestRequestSerialization(t *testing.T) { func TestRequestSerialization(t *testing.T) {
assert := With(t) assert := With(t)
user := &protocol.User{ user := &protocol.MemoryUser{
Level: 0, Level: 0,
Email: "test@v2ray.com", Email: "test@v2ray.com",
} }
@ -26,7 +31,7 @@ func TestRequestSerialization(t *testing.T) {
Id: id.String(), Id: id.String(),
AlterId: 0, AlterId: 0,
} }
user.Account = serial.ToTypedMessage(account) user.Account = toAccount(account)
expectedRequest := &protocol.RequestHeader{ expectedRequest := &protocol.RequestHeader{
Version: 1, Version: 1,
@ -70,7 +75,7 @@ func TestRequestSerialization(t *testing.T) {
func TestInvalidRequest(t *testing.T) { func TestInvalidRequest(t *testing.T) {
assert := With(t) assert := With(t)
user := &protocol.User{ user := &protocol.MemoryUser{
Level: 0, Level: 0,
Email: "test@v2ray.com", Email: "test@v2ray.com",
} }
@ -79,7 +84,7 @@ func TestInvalidRequest(t *testing.T) {
Id: id.String(), Id: id.String(),
AlterId: 0, AlterId: 0,
} }
user.Account = serial.ToTypedMessage(account) user.Account = toAccount(account)
expectedRequest := &protocol.RequestHeader{ expectedRequest := &protocol.RequestHeader{
Version: 1, Version: 1,
@ -112,7 +117,7 @@ func TestInvalidRequest(t *testing.T) {
func TestMuxRequest(t *testing.T) { func TestMuxRequest(t *testing.T) {
assert := With(t) assert := With(t)
user := &protocol.User{ user := &protocol.MemoryUser{
Level: 0, Level: 0,
Email: "test@v2ray.com", Email: "test@v2ray.com",
} }
@ -121,7 +126,7 @@ func TestMuxRequest(t *testing.T) {
Id: id.String(), Id: id.String(),
AlterId: 0, AlterId: 0,
} }
user.Account = serial.ToTypedMessage(account) user.Account = toAccount(account)
expectedRequest := &protocol.RequestHeader{ expectedRequest := &protocol.RequestHeader{
Version: 1, Version: 1,

View File

@ -139,11 +139,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
timestampHash := md5.New() timestampHash := md5.New()
common.Must2(timestampHash.Write(hashTimestamp(timestamp))) common.Must2(timestampHash.Write(hashTimestamp(timestamp)))
iv := timestampHash.Sum(nil) iv := timestampHash.Sum(nil)
account, err := user.GetTypedAccount() vmessAccount := user.Account.(*vmess.InternalAccount)
if err != nil {
return nil, newError("failed to get user account").Base(err)
}
vmessAccount := account.(*vmess.InternalAccount)
aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv) aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv)
decryptor := crypto.NewCryptionReader(aesStream, reader) decryptor := crypto.NewCryptionReader(aesStream, reader)

View File

@ -16,7 +16,6 @@ import (
"v2ray.com/core/common/log" "v2ray.com/core/common/log"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial"
"v2ray.com/core/common/session" "v2ray.com/core/common/session"
"v2ray.com/core/common/signal" "v2ray.com/core/common/signal"
"v2ray.com/core/common/task" "v2ray.com/core/common/task"
@ -29,20 +28,20 @@ import (
type userByEmail struct { type userByEmail struct {
sync.Mutex sync.Mutex
cache map[string]*protocol.User cache map[string]*protocol.MemoryUser
defaultLevel uint32 defaultLevel uint32
defaultAlterIDs uint16 defaultAlterIDs uint16
} }
func newUserByEmail(config *DefaultConfig) *userByEmail { func newUserByEmail(config *DefaultConfig) *userByEmail {
return &userByEmail{ return &userByEmail{
cache: make(map[string]*protocol.User), cache: make(map[string]*protocol.MemoryUser),
defaultLevel: config.Level, defaultLevel: config.Level,
defaultAlterIDs: uint16(config.AlterId), defaultAlterIDs: uint16(config.AlterId),
} }
} }
func (v *userByEmail) addNoLock(u *protocol.User) bool { func (v *userByEmail) addNoLock(u *protocol.MemoryUser) bool {
email := strings.ToLower(u.Email) email := strings.ToLower(u.Email)
user, found := v.cache[email] user, found := v.cache[email]
if found { if found {
@ -52,14 +51,14 @@ func (v *userByEmail) addNoLock(u *protocol.User) bool {
return true return true
} }
func (v *userByEmail) Add(u *protocol.User) bool { func (v *userByEmail) Add(u *protocol.MemoryUser) bool {
v.Lock() v.Lock()
defer v.Unlock() defer v.Unlock()
return v.addNoLock(u) return v.addNoLock(u)
} }
func (v *userByEmail) Get(email string) (*protocol.User, bool) { func (v *userByEmail) Get(email string) (*protocol.MemoryUser, bool) {
email = strings.ToLower(email) email = strings.ToLower(email)
v.Lock() v.Lock()
@ -68,14 +67,16 @@ func (v *userByEmail) Get(email string) (*protocol.User, bool) {
user, found := v.cache[email] user, found := v.cache[email]
if !found { if !found {
id := uuid.New() id := uuid.New()
account := &vmess.Account{ rawAccount := &vmess.Account{
Id: id.String(), Id: id.String(),
AlterId: uint32(v.defaultAlterIDs), AlterId: uint32(v.defaultAlterIDs),
} }
user = &protocol.User{ account, err := rawAccount.AsAccount()
common.Must(err)
user = &protocol.MemoryUser{
Level: v.defaultLevel, Level: v.defaultLevel,
Email: email, Email: email,
Account: serial.ToTypedMessage(account), Account: account,
} }
v.cache[email] = user v.cache[email] = user
} }
@ -120,7 +121,12 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
} }
for _, user := range config.User { for _, user := range config.User {
if err := handler.AddUser(ctx, user); err != nil { mUser, err := user.ToMemoryUser()
if err != nil {
return nil, newError("failed to get VMess user").Base(err)
}
if err := handler.AddUser(ctx, mUser); err != nil {
return nil, newError("failed to initiate user").Base(err) return nil, newError("failed to initiate user").Base(err)
} }
} }
@ -130,10 +136,9 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
// Close implements common.Closable. // Close implements common.Closable.
func (h *Handler) Close() error { func (h *Handler) Close() error {
common.Close(h.clients) return task.Run(
common.Close(h.sessionHistory) task.SequentialAll(
common.Close(h.usersByEmail) task.Close(h.clients), task.Close(h.sessionHistory), task.Close(h.usersByEmail)))()
return nil
} }
// Network implements proxy.Inbound.Network(). // Network implements proxy.Inbound.Network().
@ -143,7 +148,7 @@ func (*Handler) Network() net.NetworkList {
} }
} }
func (h *Handler) GetUser(email string) *protocol.User { func (h *Handler) GetUser(email string) *protocol.MemoryUser {
user, existing := h.usersByEmail.Get(email) user, existing := h.usersByEmail.Get(email)
if !existing { if !existing {
h.clients.Add(user) h.clients.Add(user)
@ -151,7 +156,7 @@ func (h *Handler) GetUser(email string) *protocol.User {
return user return user
} }
func (h *Handler) AddUser(ctx context.Context, user *protocol.User) error { func (h *Handler) AddUser(ctx context.Context, user *protocol.MemoryUser) error {
if len(user.Email) > 0 && !h.usersByEmail.Add(user) { if len(user.Email) > 0 && !h.usersByEmail.Add(user) {
return newError("User ", user.Email, " already exists.") return newError("User ", user.Email, " already exists.")
} }
@ -325,11 +330,11 @@ func (h *Handler) generateCommand(ctx context.Context, request *protocol.Request
if user == nil { if user == nil {
return nil return nil
} }
account, _ := user.GetTypedAccount() account := user.Account.(*vmess.InternalAccount)
return &protocol.CommandSwitchAccount{ return &protocol.CommandSwitchAccount{
Port: port, Port: port,
ID: account.(*vmess.InternalAccount).ID.UUID(), ID: account.ID.UUID(),
AlterIds: uint16(len(account.(*vmess.InternalAccount).AlterIDs)), AlterIds: uint16(len(account.AlterIDs)),
Level: user.Level, Level: user.Level,
ValidMin: byte(availableMin), ValidMin: byte(availableMin),
} }

View File

@ -3,14 +3,14 @@ package outbound
import ( import (
"time" "time"
"v2ray.com/core/common"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial"
"v2ray.com/core/proxy/vmess" "v2ray.com/core/proxy/vmess"
) )
func (h *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) { func (h *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) {
account := &vmess.Account{ rawAccount := &vmess.Account{
Id: cmd.ID.String(), Id: cmd.ID.String(),
AlterId: uint32(cmd.AlterIds), AlterId: uint32(cmd.AlterIds),
SecuritySettings: &protocol.SecurityConfig{ SecuritySettings: &protocol.SecurityConfig{
@ -18,10 +18,12 @@ func (h *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) {
}, },
} }
user := &protocol.User{ account, err := rawAccount.AsAccount()
common.Must(err)
user := &protocol.MemoryUser{
Email: "", Email: "",
Level: cmd.Level, Level: cmd.Level,
Account: serial.ToTypedMessage(account), Account: account,
} }
dest := net.TCPDestination(cmd.Host, cmd.Port) dest := net.TCPDestination(cmd.Host, cmd.Port)
until := time.Now().Add(time.Duration(cmd.ValidMin) * time.Minute) until := time.Now().Add(time.Duration(cmd.ValidMin) * time.Minute)

View File

@ -33,7 +33,11 @@ type Handler struct {
func New(ctx context.Context, config *Config) (*Handler, error) { func New(ctx context.Context, config *Config) (*Handler, error) {
serverList := protocol.NewServerList() serverList := protocol.NewServerList()
for _, rec := range config.Receiver { for _, rec := range config.Receiver {
serverList.AddServer(protocol.NewServerSpecFromPB(*rec)) s, err := protocol.NewServerSpecFromPB(*rec)
if err != nil {
return nil, newError("failed to parse server spec").Base(err)
}
serverList.AddServer(s)
} }
handler := &Handler{ handler := &Handler{
serverList: serverList, serverList: serverList,
@ -87,11 +91,7 @@ func (v *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia
Option: protocol.RequestOptionChunkStream, Option: protocol.RequestOptionChunkStream,
} }
rawAccount, err := request.User.GetTypedAccount() account := request.User.Account.(*vmess.InternalAccount)
if err != nil {
return newError("failed to get user account").Base(err).AtWarning()
}
account := rawAccount.(*vmess.InternalAccount)
request.Security = account.Security request.Security = account.Security
if request.Security == protocol.SecurityType_AES128_GCM || request.Security == protocol.SecurityType_NONE || request.Security == protocol.SecurityType_CHACHA20_POLY1305 { if request.Security == protocol.SecurityType_AES128_GCM || request.Security == protocol.SecurityType_NONE || request.Security == protocol.SecurityType_CHACHA20_POLY1305 {

View File

@ -23,8 +23,7 @@ const (
) )
type user struct { type user struct {
user *protocol.User user *protocol.MemoryUser
account *InternalAccount
lastSec protocol.Timestamp lastSec protocol.Timestamp
} }
@ -80,8 +79,10 @@ func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *
} }
} }
genHashForID(user.account.ID) account := user.user.Account.(*InternalAccount)
for _, id := range user.account.AlterIDs {
genHashForID(account.ID)
for _, id := range account.AlterIDs {
genHashForID(id) genHashForID(id)
} }
user.lastSec = nowSec user.lastSec = nowSec
@ -111,21 +112,14 @@ func (v *TimedUserValidator) updateUserHash() {
} }
} }
func (v *TimedUserValidator) Add(u *protocol.User) error { func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error {
v.Lock() v.Lock()
defer v.Unlock() defer v.Unlock()
rawAccount, err := u.GetTypedAccount()
if err != nil {
return err
}
account := rawAccount.(*InternalAccount)
nowSec := time.Now().Unix() nowSec := time.Now().Unix()
uu := &user{ uu := &user{
user: u, user: u,
account: account,
lastSec: protocol.Timestamp(nowSec - cacheDurationSec), lastSec: protocol.Timestamp(nowSec - cacheDurationSec),
} }
v.users = append(v.users, uu) v.users = append(v.users, uu)
@ -134,7 +128,7 @@ func (v *TimedUserValidator) Add(u *protocol.User) error {
return nil return nil
} }
func (v *TimedUserValidator) Get(userHash []byte) (*protocol.User, protocol.Timestamp, bool) { func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool) {
defer v.RUnlock() defer v.RUnlock()
v.RLock() v.RLock()