From 54e1bb96cc1722590ccc4941857e8eafbff6e113 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Mon, 27 Aug 2018 00:11:32 +0200 Subject: [PATCH] introduce in-memory user --- app/proxyman/command/command.go | 6 ++- app/router/condition_test.go | 4 +- common/protocol/context.go | 6 +-- common/protocol/headers.go | 2 +- common/protocol/server_spec.go | 29 +++++++------- common/protocol/server_spec_test.go | 17 ++++----- common/protocol/user.go | 18 +++++++++ common/task/task.go | 54 +++++++++++++++++++++++---- proxy/proxy.go | 2 +- proxy/shadowsocks/client.go | 13 ++++--- proxy/shadowsocks/protocol.go | 46 ++++++----------------- proxy/shadowsocks/protocol_test.go | 26 ++++++++----- proxy/shadowsocks/server.go | 25 ++++++------- proxy/socks/client.go | 6 ++- proxy/socks/protocol.go | 6 +-- proxy/vmess/encoding/client.go | 9 ++--- proxy/vmess/encoding/encoding_test.go | 19 ++++++---- proxy/vmess/encoding/server.go | 6 +-- proxy/vmess/inbound/inbound.go | 43 +++++++++++---------- proxy/vmess/outbound/command.go | 10 +++-- proxy/vmess/outbound/outbound.go | 12 +++--- proxy/vmess/vmess.go | 20 ++++------ 22 files changed, 212 insertions(+), 167 deletions(-) diff --git a/app/proxyman/command/command.go b/app/proxyman/command/command.go index ca3db17f9..2b9b92c00 100755 --- a/app/proxyman/command/command.go +++ b/app/proxyman/command/command.go @@ -39,7 +39,11 @@ func (op *AddUserOperation) ApplyInbound(ctx context.Context, handler core.Inbou if !ok { return newError("proxy is not a UserManager") } - return um.AddUser(ctx, op.User) + mUser, err := op.User.ToMemoryUser() + if err != nil { + return newError("failed to parse user").Base(err) + } + return um.AddUser(ctx, mUser) } // ApplyInbound implements InboundOperation. diff --git a/app/router/condition_test.go b/app/router/condition_test.go index b1adc4187..0ecacce50 100644 --- a/app/router/condition_test.go +++ b/app/router/condition_test.go @@ -126,11 +126,11 @@ func TestRoutingRule(t *testing.T) { }, test: []ruleTest{ { - input: protocol.ContextWithUser(context.Background(), &protocol.User{Email: "admin@v2ray.com"}), + input: protocol.ContextWithUser(context.Background(), &protocol.MemoryUser{Email: "admin@v2ray.com"}), output: true, }, { - input: protocol.ContextWithUser(context.Background(), &protocol.User{Email: "love@v2ray.com"}), + input: protocol.ContextWithUser(context.Background(), &protocol.MemoryUser{Email: "love@v2ray.com"}), output: false, }, { diff --git a/common/protocol/context.go b/common/protocol/context.go index 85461de5e..42c980790 100755 --- a/common/protocol/context.go +++ b/common/protocol/context.go @@ -12,17 +12,17 @@ const ( ) // ContextWithUser returns a context combined with a User. -func ContextWithUser(ctx context.Context, user *User) context.Context { +func ContextWithUser(ctx context.Context, user *MemoryUser) context.Context { return context.WithValue(ctx, userKey, user) } // UserFromContext extracts a User from the given context, if any. -func UserFromContext(ctx context.Context) *User { +func UserFromContext(ctx context.Context) *MemoryUser { v := ctx.Value(userKey) if v == nil { return nil } - return v.(*User) + return v.(*MemoryUser) } func ContextWithRequestHeader(ctx context.Context, request *RequestHeader) context.Context { diff --git a/common/protocol/headers.go b/common/protocol/headers.go index f668255ad..83a5fdd6d 100644 --- a/common/protocol/headers.go +++ b/common/protocol/headers.go @@ -47,7 +47,7 @@ type RequestHeader struct { Security SecurityType Port net.Port Address net.Address - User *User + User *MemoryUser } func (h *RequestHeader) Destination() net.Destination { diff --git a/common/protocol/server_spec.go b/common/protocol/server_spec.go index 3f35d8f39..bcf05d691 100644 --- a/common/protocol/server_spec.go +++ b/common/protocol/server_spec.go @@ -46,11 +46,11 @@ func (s *timeoutValidStrategy) Invalidate() { type ServerSpec struct { sync.RWMutex dest net.Destination - users []*User + users []*MemoryUser valid ValidationStrategy } -func NewServerSpec(dest net.Destination, valid ValidationStrategy, users ...*User) *ServerSpec { +func NewServerSpec(dest net.Destination, valid ValidationStrategy, users ...*MemoryUser) *ServerSpec { return &ServerSpec{ dest: dest, users: users, @@ -58,33 +58,36 @@ func NewServerSpec(dest net.Destination, valid ValidationStrategy, users ...*Use } } -func NewServerSpecFromPB(spec ServerEndpoint) *ServerSpec { +func NewServerSpecFromPB(spec ServerEndpoint) (*ServerSpec, error) { dest := net.TCPDestination(spec.Address.AsAddress(), net.Port(spec.Port)) - return NewServerSpec(dest, AlwaysValid(), spec.User...) + mUsers := make([]*MemoryUser, len(spec.User)) + for idx, u := range spec.User { + mUser, err := u.ToMemoryUser() + if err != nil { + return nil, err + } + mUsers[idx] = mUser + } + return NewServerSpec(dest, AlwaysValid(), mUsers...), nil } func (s *ServerSpec) Destination() net.Destination { return s.dest } -func (s *ServerSpec) HasUser(user *User) bool { +func (s *ServerSpec) HasUser(user *MemoryUser) bool { s.RLock() defer s.RUnlock() - accountA, err := user.GetTypedAccount() - if err != nil { - return false - } for _, u := range s.users { - accountB, err := u.GetTypedAccount() - if err == nil && accountA.Equals(accountB) { + if u.Account.Equals(user.Account) { return true } } return false } -func (s *ServerSpec) AddUser(user *User) { +func (s *ServerSpec) AddUser(user *MemoryUser) { if s.HasUser(user) { return } @@ -95,7 +98,7 @@ func (s *ServerSpec) AddUser(user *User) { s.users = append(s.users, user) } -func (s *ServerSpec) PickUser() *User { +func (s *ServerSpec) PickUser() *MemoryUser { s.RLock() defer s.RUnlock() diff --git a/common/protocol/server_spec_test.go b/common/protocol/server_spec_test.go index b22647ad9..c56fd8548 100644 --- a/common/protocol/server_spec_test.go +++ b/common/protocol/server_spec_test.go @@ -6,7 +6,6 @@ import ( "v2ray.com/core/common/net" . "v2ray.com/core/common/protocol" - "v2ray.com/core/common/serial" "v2ray.com/core/common/uuid" "v2ray.com/core/proxy/vmess" . "v2ray.com/ext/assert" @@ -40,26 +39,26 @@ func TestUserInServerSpec(t *testing.T) { uuid1 := uuid.New() uuid2 := uuid.New() - spec := NewServerSpec(net.Destination{}, AlwaysValid(), &User{ + spec := NewServerSpec(net.Destination{}, AlwaysValid(), &MemoryUser{ Email: "test1@v2ray.com", - Account: serial.ToTypedMessage(&vmess.Account{Id: uuid1.String()}), + Account: &vmess.Account{Id: uuid1.String()}, }) - assert(spec.HasUser(&User{ + assert(spec.HasUser(&MemoryUser{ Email: "test1@v2ray.com", - Account: serial.ToTypedMessage(&vmess.Account{Id: uuid2.String()}), + Account: &vmess.Account{Id: uuid2.String()}, }), IsFalse) - spec.AddUser(&User{Email: "test2@v2ray.com"}) - assert(spec.HasUser(&User{ + spec.AddUser(&MemoryUser{Email: "test2@v2ray.com"}) + assert(spec.HasUser(&MemoryUser{ Email: "test1@v2ray.com", - Account: serial.ToTypedMessage(&vmess.Account{Id: uuid1.String()}), + Account: &vmess.Account{Id: uuid1.String()}, }), IsTrue) } func TestPickUser(t *testing.T) { assert := With(t) - spec := NewServerSpec(net.Destination{}, AlwaysValid(), &User{Email: "test1@v2ray.com"}, &User{Email: "test2@v2ray.com"}, &User{Email: "test3@v2ray.com"}) + spec := NewServerSpec(net.Destination{}, AlwaysValid(), &MemoryUser{Email: "test1@v2ray.com"}, &MemoryUser{Email: "test2@v2ray.com"}, &MemoryUser{Email: "test3@v2ray.com"}) user := spec.PickUser() assert(user.Email, HasSuffix, "@v2ray.com") } diff --git a/common/protocol/user.go b/common/protocol/user.go index 07f08ef94..30b78c80b 100644 --- a/common/protocol/user.go +++ b/common/protocol/user.go @@ -17,3 +17,21 @@ func (u *User) GetTypedAccount() (Account, error) { } return nil, newError("Unknown account type: ", u.Account.Type) } + +func (u *User) ToMemoryUser() (*MemoryUser, error) { + account, err := u.GetTypedAccount() + if err != nil { + return nil, err + } + return &MemoryUser{ + Account: account, + Email: u.Email, + Level: u.Level, + }, nil +} + +type MemoryUser struct { + Account Account + Email string + Level uint32 +} diff --git a/common/task/task.go b/common/task/task.go index 5a6763959..26227b856 100644 --- a/common/task/task.go +++ b/common/task/task.go @@ -2,12 +2,26 @@ package task import ( "context" + "strings" + "v2ray.com/core/common" "v2ray.com/core/common/signal/semaphore" ) type Task func() error +type MultiError []error + +func (e MultiError) Error() string { + var r strings.Builder + common.Must2(r.WriteString("multierr: ")) + for _, err := range e { + common.Must2(r.WriteString(err.Error())) + common.Must2(r.WriteString(" | ")) + } + return r.String() +} + type executionContext struct { ctx context.Context tasks []Task @@ -59,20 +73,44 @@ func Parallel(tasks ...Task) ExecutionOption { } } +// Sequential runs all tasks sequentially, and returns the first error encountered.Sequential +// Once a task returns an error, the following tasks will not run. func Sequential(tasks ...Task) ExecutionOption { return func(c *executionContext) { - if len(tasks) == 0 { + switch len(tasks) { + case 0: return - } - - if len(tasks) == 1 { + case 1: c.tasks = append(c.tasks, tasks[0]) - return + default: + c.tasks = append(c.tasks, func() error { + return execute(tasks...) + }) } + } +} - c.tasks = append(c.tasks, func() error { - return execute(tasks...) - }) +func SequentialAll(tasks ...Task) ExecutionOption { + return func(c *executionContext) { + switch len(tasks) { + case 0: + return + case 1: + c.tasks = append(c.tasks, tasks[0]) + default: + c.tasks = append(c.tasks, func() error { + var merr MultiError + for _, task := range tasks { + if err := task(); err != nil { + merr = append(merr, err) + } + } + if len(merr) == 0 { + return nil + } + return merr + }) + } } } diff --git a/proxy/proxy.go b/proxy/proxy.go index fbd48a9cc..96b22eabd 100755 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -38,7 +38,7 @@ type Dialer interface { // UserManager is the interface for Inbounds and Outbounds that can manage their users. type UserManager interface { // AddUser adds a new user. - AddUser(context.Context, *protocol.User) error + AddUser(context.Context, *protocol.MemoryUser) error // RemoveUser removes a user by email. RemoveUser(context.Context, string) error diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index 7c0232d40..362bb7963 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -27,7 +27,11 @@ type Client struct { func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { serverList := protocol.NewServerList() for _, rec := range config.Server { - serverList.AddServer(protocol.NewServerSpecFromPB(*rec)) + s, err := protocol.NewServerSpecFromPB(*rec) + if err != nil { + return nil, newError("failed to parse server spec").Base(err) + } + serverList.AddServer(s) } if serverList.Size() == 0 { return nil, newError("0 server") @@ -81,11 +85,10 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial } user := server.PickUser() - rawAccount, err := user.GetTypedAccount() - if err != nil { - return newError("failed to get a valid user account").AtWarning().Base(err) + account, ok := user.Account.(*MemoryAccount) + if !ok { + return newError("user account is not valid") } - account := rawAccount.(*MemoryAccount) request.User = user if account.OneTimeAuth == Account_Auto || account.OneTimeAuth == Account_Enabled { diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index ce6681956..a6f49a1c4 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -27,12 +27,8 @@ var addrParser = protocol.NewAddressParser( ) // ReadTCPSession reads a Shadowsocks TCP session from the given reader, returns its header and remaining parts. -func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) { - rawAccount, err := user.GetTypedAccount() - if err != nil { - return nil, nil, newError("failed to parse account").Base(err).AtError() - } - account := rawAccount.(*MemoryAccount) +func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) { + account := user.Account.(*MemoryAccount) buffer := buf.New() defer buffer.Release() @@ -116,11 +112,7 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea // WriteTCPRequest writes Shadowsocks request into the given writer, and returns a writer for body. func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) { user := request.User - rawAccount, err := user.GetTypedAccount() - if err != nil { - return nil, newError("failed to parse account").Base(err).AtError() - } - account := rawAccount.(*MemoryAccount) + account := user.Account.(*MemoryAccount) if account.Cipher.IsAEAD() { request.Option.Clear(RequestOptionOneTimeAuth) @@ -167,17 +159,13 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri return chunkWriter, nil } -func ReadTCPResponse(user *protocol.User, reader io.Reader) (buf.Reader, error) { - rawAccount, err := user.GetTypedAccount() - if err != nil { - return nil, newError("failed to parse account").Base(err).AtError() - } - account := rawAccount.(*MemoryAccount) +func ReadTCPResponse(user *protocol.MemoryUser, reader io.Reader) (buf.Reader, error) { + account := user.Account.(*MemoryAccount) var iv []byte if account.Cipher.IVSize() > 0 { iv = make([]byte, account.Cipher.IVSize()) - if _, err = io.ReadFull(reader, iv); err != nil { + if _, err := io.ReadFull(reader, iv); err != nil { return nil, newError("failed to read IV").Base(err) } } @@ -187,11 +175,7 @@ func ReadTCPResponse(user *protocol.User, reader io.Reader) (buf.Reader, error) func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) { user := request.User - rawAccount, err := user.GetTypedAccount() - if err != nil { - return nil, newError("failed to parse account.").Base(err).AtError() - } - account := rawAccount.(*MemoryAccount) + account := user.Account.(*MemoryAccount) var iv []byte if account.Cipher.IVSize() > 0 { @@ -207,11 +191,7 @@ func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Wr func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buffer, error) { user := request.User - rawAccount, err := user.GetTypedAccount() - if err != nil { - return nil, newError("failed to parse account.").Base(err).AtError() - } - account := rawAccount.(*MemoryAccount) + account := user.Account.(*MemoryAccount) buffer := buf.New() ivLen := account.Cipher.IVSize() @@ -239,12 +219,8 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff return buffer, nil } -func DecodeUDPPacket(user *protocol.User, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) { - rawAccount, err := user.GetTypedAccount() - if err != nil { - return nil, nil, newError("failed to parse account").Base(err).AtError() - } - account := rawAccount.(*MemoryAccount) +func DecodeUDPPacket(user *protocol.MemoryUser, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) { + account := user.Account.(*MemoryAccount) var iv []byte if !account.Cipher.IsAEAD() && account.Cipher.IVSize() > 0 { @@ -306,7 +282,7 @@ func DecodeUDPPacket(user *protocol.User, payload *buf.Buffer) (*protocol.Reques type UDPReader struct { Reader io.Reader - User *protocol.User + User *protocol.MemoryUser } func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { diff --git a/proxy/shadowsocks/protocol_test.go b/proxy/shadowsocks/protocol_test.go index 88ef221ba..6cf6cdb0c 100644 --- a/proxy/shadowsocks/protocol_test.go +++ b/proxy/shadowsocks/protocol_test.go @@ -12,6 +12,12 @@ import ( . "v2ray.com/ext/assert" ) +func toAccount(a *Account) protocol.Account { + account, err := a.AsAccount() + common.Must(err) + return account +} + func TestUDPEncoding(t *testing.T) { assert := With(t) @@ -20,9 +26,9 @@ func TestUDPEncoding(t *testing.T) { Command: protocol.RequestCommandUDP, Address: net.LocalHostIP, Port: 1234, - User: &protocol.User{ + User: &protocol.MemoryUser{ Email: "love@v2ray.com", - Account: serial.ToTypedMessage(&Account{ + Account: toAccount(&Account{ Password: "shadowsocks-password", CipherType: CipherType_AES_128_CFB, Ota: Account_Disabled, @@ -57,9 +63,9 @@ func TestTCPRequest(t *testing.T) { Address: net.LocalHostIP, Option: RequestOptionOneTimeAuth, Port: 1234, - User: &protocol.User{ + User: &protocol.MemoryUser{ Email: "love@v2ray.com", - Account: serial.ToTypedMessage(&Account{ + Account: toAccount(&Account{ Password: "tcp-password", CipherType: CipherType_CHACHA20, }), @@ -74,9 +80,9 @@ func TestTCPRequest(t *testing.T) { Address: net.LocalHostIPv6, Option: RequestOptionOneTimeAuth, Port: 1234, - User: &protocol.User{ + User: &protocol.MemoryUser{ Email: "love@v2ray.com", - Account: serial.ToTypedMessage(&Account{ + Account: toAccount(&Account{ Password: "password", CipherType: CipherType_AES_256_CFB, }), @@ -91,9 +97,9 @@ func TestTCPRequest(t *testing.T) { Address: net.DomainAddress("v2ray.com"), Option: RequestOptionOneTimeAuth, Port: 1234, - User: &protocol.User{ + User: &protocol.MemoryUser{ Email: "love@v2ray.com", - Account: serial.ToTypedMessage(&Account{ + Account: toAccount(&Account{ Password: "password", CipherType: CipherType_CHACHA20_IETF, }), @@ -135,8 +141,8 @@ func TestTCPRequest(t *testing.T) { func TestUDPReaderWriter(t *testing.T) { assert := With(t) - user := &protocol.User{ - Account: serial.ToTypedMessage(&Account{ + user := &protocol.MemoryUser{ + Account: toAccount(&Account{ Password: "test-password", CipherType: CipherType_CHACHA20_IETF, }), diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 401f188e9..0ed848292 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -20,10 +20,9 @@ import ( ) type Server struct { - config ServerConfig - user *protocol.User - account *MemoryAccount - v *core.Instance + config ServerConfig + user *protocol.MemoryUser + v *core.Instance } // NewServer create a new Shadowsocks server. @@ -32,17 +31,15 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { return nil, newError("user is not specified") } - rawAccount, err := config.User.GetTypedAccount() + mUser, err := config.User.ToMemoryUser() if err != nil { - return nil, newError("failed to get user account").Base(err) + return nil, newError("failed to parse user account").Base(err) } - account := rawAccount.(*MemoryAccount) s := &Server{ - config: *config, - user: config.GetUser(), - account: account, - v: core.MustFromContext(ctx), + config: *config, + user: mUser, + v: core.MustFromContext(ctx), } return s, nil @@ -90,6 +87,8 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection conn.Write(data.Bytes()) }) + account := s.user.Account.(*MemoryAccount) + reader := buf.NewReader(conn) for { mpayload, err := reader.ReadMultiBuffer() @@ -113,13 +112,13 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection continue } - if request.Option.Has(RequestOptionOneTimeAuth) && s.account.OneTimeAuth == Account_Disabled { + if request.Option.Has(RequestOptionOneTimeAuth) && account.OneTimeAuth == Account_Disabled { newError("client payload enables OTA but server doesn't allow it").WriteToLog(session.ExportIDToError(ctx)) payload.Release() continue } - if !request.Option.Has(RequestOptionOneTimeAuth) && s.account.OneTimeAuth == Account_Enabled { + if !request.Option.Has(RequestOptionOneTimeAuth) && account.OneTimeAuth == Account_Enabled { newError("client payload disables OTA but server forces it").WriteToLog(session.ExportIDToError(ctx)) payload.Release() continue diff --git a/proxy/socks/client.go b/proxy/socks/client.go index e837b53c2..bd67ba39f 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -28,7 +28,11 @@ type Client struct { func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { serverList := protocol.NewServerList() for _, rec := range config.Server { - serverList.AddServer(protocol.NewServerSpecFromPB(*rec)) + s, err := protocol.NewServerSpecFromPB(*rec) + if err != nil { + return nil, newError("failed to get server spec").Base(err) + } + serverList.AddServer(s) } if serverList.Size() == 0 { return nil, newError("0 target server") diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index a3c378503..35d08d798 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -350,11 +350,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i common.Must2(b.WriteBytes(socks5Version, 0x01, authByte)) if authByte == authPassword { - rawAccount, err := request.User.GetTypedAccount() - if err != nil { - return nil, err - } - account := rawAccount.(*Account) + account := request.User.Account.(*Account) common.Must2(b.WriteBytes(0x01, byte(len(account.Username)))) common.Must2(b.Write([]byte(account.Username))) diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index 31506dcda..596bb997f 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -58,11 +58,8 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession { func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error { timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)() - account, err := header.User.GetTypedAccount() - if err != nil { - return newError("failed to get user account: ", err).AtError() - } - idHash := c.idHash(account.(*vmess.InternalAccount).AnyValidID().Bytes()) + account := header.User.Account.(*vmess.InternalAccount) + idHash := c.idHash(account.AnyValidID().Bytes()) common.Must2(idHash.Write(timestamp.Bytes(nil))) common.Must2(writer.Write(idHash.Sum(nil))) @@ -97,7 +94,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ timestampHash := md5.New() common.Must2(timestampHash.Write(hashTimestamp(timestamp))) iv := timestampHash.Sum(nil) - aesStream := crypto.NewAesEncryptionStream(account.(*vmess.InternalAccount).ID.CmdKey(), iv) + aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv) aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) common.Must2(writer.Write(buffer.Bytes())) return nil diff --git a/proxy/vmess/encoding/encoding_test.go b/proxy/vmess/encoding/encoding_test.go index 785508c91..b7cc10495 100644 --- a/proxy/vmess/encoding/encoding_test.go +++ b/proxy/vmess/encoding/encoding_test.go @@ -7,17 +7,22 @@ import ( "v2ray.com/core/common/buf" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" - "v2ray.com/core/common/serial" "v2ray.com/core/common/uuid" "v2ray.com/core/proxy/vmess" . "v2ray.com/core/proxy/vmess/encoding" . "v2ray.com/ext/assert" ) +func toAccount(a *vmess.Account) protocol.Account { + account, err := a.AsAccount() + common.Must(err) + return account +} + func TestRequestSerialization(t *testing.T) { assert := With(t) - user := &protocol.User{ + user := &protocol.MemoryUser{ Level: 0, Email: "test@v2ray.com", } @@ -26,7 +31,7 @@ func TestRequestSerialization(t *testing.T) { Id: id.String(), AlterId: 0, } - user.Account = serial.ToTypedMessage(account) + user.Account = toAccount(account) expectedRequest := &protocol.RequestHeader{ Version: 1, @@ -70,7 +75,7 @@ func TestRequestSerialization(t *testing.T) { func TestInvalidRequest(t *testing.T) { assert := With(t) - user := &protocol.User{ + user := &protocol.MemoryUser{ Level: 0, Email: "test@v2ray.com", } @@ -79,7 +84,7 @@ func TestInvalidRequest(t *testing.T) { Id: id.String(), AlterId: 0, } - user.Account = serial.ToTypedMessage(account) + user.Account = toAccount(account) expectedRequest := &protocol.RequestHeader{ Version: 1, @@ -112,7 +117,7 @@ func TestInvalidRequest(t *testing.T) { func TestMuxRequest(t *testing.T) { assert := With(t) - user := &protocol.User{ + user := &protocol.MemoryUser{ Level: 0, Email: "test@v2ray.com", } @@ -121,7 +126,7 @@ func TestMuxRequest(t *testing.T) { Id: id.String(), AlterId: 0, } - user.Account = serial.ToTypedMessage(account) + user.Account = toAccount(account) expectedRequest := &protocol.RequestHeader{ Version: 1, diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 7bc185905..1c027cea5 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -139,11 +139,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request timestampHash := md5.New() common.Must2(timestampHash.Write(hashTimestamp(timestamp))) iv := timestampHash.Sum(nil) - account, err := user.GetTypedAccount() - if err != nil { - return nil, newError("failed to get user account").Base(err) - } - vmessAccount := account.(*vmess.InternalAccount) + vmessAccount := user.Account.(*vmess.InternalAccount) aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv) decryptor := crypto.NewCryptionReader(aesStream, reader) diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index e3220f35c..ae42015e1 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -16,7 +16,6 @@ import ( "v2ray.com/core/common/log" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" - "v2ray.com/core/common/serial" "v2ray.com/core/common/session" "v2ray.com/core/common/signal" "v2ray.com/core/common/task" @@ -29,20 +28,20 @@ import ( type userByEmail struct { sync.Mutex - cache map[string]*protocol.User + cache map[string]*protocol.MemoryUser defaultLevel uint32 defaultAlterIDs uint16 } func newUserByEmail(config *DefaultConfig) *userByEmail { return &userByEmail{ - cache: make(map[string]*protocol.User), + cache: make(map[string]*protocol.MemoryUser), defaultLevel: config.Level, defaultAlterIDs: uint16(config.AlterId), } } -func (v *userByEmail) addNoLock(u *protocol.User) bool { +func (v *userByEmail) addNoLock(u *protocol.MemoryUser) bool { email := strings.ToLower(u.Email) user, found := v.cache[email] if found { @@ -52,14 +51,14 @@ func (v *userByEmail) addNoLock(u *protocol.User) bool { return true } -func (v *userByEmail) Add(u *protocol.User) bool { +func (v *userByEmail) Add(u *protocol.MemoryUser) bool { v.Lock() defer v.Unlock() return v.addNoLock(u) } -func (v *userByEmail) Get(email string) (*protocol.User, bool) { +func (v *userByEmail) Get(email string) (*protocol.MemoryUser, bool) { email = strings.ToLower(email) v.Lock() @@ -68,14 +67,16 @@ func (v *userByEmail) Get(email string) (*protocol.User, bool) { user, found := v.cache[email] if !found { id := uuid.New() - account := &vmess.Account{ + rawAccount := &vmess.Account{ Id: id.String(), AlterId: uint32(v.defaultAlterIDs), } - user = &protocol.User{ + account, err := rawAccount.AsAccount() + common.Must(err) + user = &protocol.MemoryUser{ Level: v.defaultLevel, Email: email, - Account: serial.ToTypedMessage(account), + Account: account, } v.cache[email] = user } @@ -120,7 +121,12 @@ func New(ctx context.Context, config *Config) (*Handler, error) { } for _, user := range config.User { - if err := handler.AddUser(ctx, user); err != nil { + mUser, err := user.ToMemoryUser() + if err != nil { + return nil, newError("failed to get VMess user").Base(err) + } + + if err := handler.AddUser(ctx, mUser); err != nil { return nil, newError("failed to initiate user").Base(err) } } @@ -130,10 +136,9 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Close implements common.Closable. func (h *Handler) Close() error { - common.Close(h.clients) - common.Close(h.sessionHistory) - common.Close(h.usersByEmail) - return nil + return task.Run( + task.SequentialAll( + task.Close(h.clients), task.Close(h.sessionHistory), task.Close(h.usersByEmail)))() } // Network implements proxy.Inbound.Network(). @@ -143,7 +148,7 @@ func (*Handler) Network() net.NetworkList { } } -func (h *Handler) GetUser(email string) *protocol.User { +func (h *Handler) GetUser(email string) *protocol.MemoryUser { user, existing := h.usersByEmail.Get(email) if !existing { h.clients.Add(user) @@ -151,7 +156,7 @@ func (h *Handler) GetUser(email string) *protocol.User { return user } -func (h *Handler) AddUser(ctx context.Context, user *protocol.User) error { +func (h *Handler) AddUser(ctx context.Context, user *protocol.MemoryUser) error { if len(user.Email) > 0 && !h.usersByEmail.Add(user) { return newError("User ", user.Email, " already exists.") } @@ -325,11 +330,11 @@ func (h *Handler) generateCommand(ctx context.Context, request *protocol.Request if user == nil { return nil } - account, _ := user.GetTypedAccount() + account := user.Account.(*vmess.InternalAccount) return &protocol.CommandSwitchAccount{ Port: port, - ID: account.(*vmess.InternalAccount).ID.UUID(), - AlterIds: uint16(len(account.(*vmess.InternalAccount).AlterIDs)), + ID: account.ID.UUID(), + AlterIds: uint16(len(account.AlterIDs)), Level: user.Level, ValidMin: byte(availableMin), } diff --git a/proxy/vmess/outbound/command.go b/proxy/vmess/outbound/command.go index 19316acd3..94fc1bc17 100644 --- a/proxy/vmess/outbound/command.go +++ b/proxy/vmess/outbound/command.go @@ -3,14 +3,14 @@ package outbound import ( "time" + "v2ray.com/core/common" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" - "v2ray.com/core/common/serial" "v2ray.com/core/proxy/vmess" ) func (h *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) { - account := &vmess.Account{ + rawAccount := &vmess.Account{ Id: cmd.ID.String(), AlterId: uint32(cmd.AlterIds), SecuritySettings: &protocol.SecurityConfig{ @@ -18,10 +18,12 @@ func (h *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) { }, } - user := &protocol.User{ + account, err := rawAccount.AsAccount() + common.Must(err) + user := &protocol.MemoryUser{ Email: "", Level: cmd.Level, - Account: serial.ToTypedMessage(account), + Account: account, } dest := net.TCPDestination(cmd.Host, cmd.Port) until := time.Now().Add(time.Duration(cmd.ValidMin) * time.Minute) diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index 3801afdb1..bfce2c287 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -33,7 +33,11 @@ type Handler struct { func New(ctx context.Context, config *Config) (*Handler, error) { serverList := protocol.NewServerList() for _, rec := range config.Receiver { - serverList.AddServer(protocol.NewServerSpecFromPB(*rec)) + s, err := protocol.NewServerSpecFromPB(*rec) + if err != nil { + return nil, newError("failed to parse server spec").Base(err) + } + serverList.AddServer(s) } handler := &Handler{ serverList: serverList, @@ -87,11 +91,7 @@ func (v *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia Option: protocol.RequestOptionChunkStream, } - rawAccount, err := request.User.GetTypedAccount() - if err != nil { - return newError("failed to get user account").Base(err).AtWarning() - } - account := rawAccount.(*vmess.InternalAccount) + account := request.User.Account.(*vmess.InternalAccount) request.Security = account.Security if request.Security == protocol.SecurityType_AES128_GCM || request.Security == protocol.SecurityType_NONE || request.Security == protocol.SecurityType_CHACHA20_POLY1305 { diff --git a/proxy/vmess/vmess.go b/proxy/vmess/vmess.go index 0c7038b81..525f07e41 100644 --- a/proxy/vmess/vmess.go +++ b/proxy/vmess/vmess.go @@ -23,8 +23,7 @@ const ( ) type user struct { - user *protocol.User - account *InternalAccount + user *protocol.MemoryUser lastSec protocol.Timestamp } @@ -80,8 +79,10 @@ func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user * } } - genHashForID(user.account.ID) - for _, id := range user.account.AlterIDs { + account := user.user.Account.(*InternalAccount) + + genHashForID(account.ID) + for _, id := range account.AlterIDs { genHashForID(id) } user.lastSec = nowSec @@ -111,21 +112,14 @@ func (v *TimedUserValidator) updateUserHash() { } } -func (v *TimedUserValidator) Add(u *protocol.User) error { +func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error { v.Lock() defer v.Unlock() - rawAccount, err := u.GetTypedAccount() - if err != nil { - return err - } - account := rawAccount.(*InternalAccount) - nowSec := time.Now().Unix() uu := &user{ user: u, - account: account, lastSec: protocol.Timestamp(nowSec - cacheDurationSec), } v.users = append(v.users, uu) @@ -134,7 +128,7 @@ func (v *TimedUserValidator) Add(u *protocol.User) error { return nil } -func (v *TimedUserValidator) Get(userHash []byte) (*protocol.User, protocol.Timestamp, bool) { +func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool) { defer v.RUnlock() v.RLock()