diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index cdff7c8a2..78945532d 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -12,8 +12,6 @@ import ( "hash" "hash/fnv" "io" - "os" - vmessaead "v2ray.com/core/proxy/vmess/aead" "golang.org/x/crypto/chacha20poly1305" @@ -25,6 +23,7 @@ import ( "v2ray.com/core/common/protocol" "v2ray.com/core/common/serial" "v2ray.com/core/proxy/vmess" + vmessaead "v2ray.com/core/proxy/vmess/aead" ) func hashTimestamp(h hash.Hash, t protocol.Timestamp) []byte { @@ -37,6 +36,7 @@ func hashTimestamp(h hash.Hash, t protocol.Timestamp) []byte { // ClientSession stores connection session info for VMess client. type ClientSession struct { + isAEAD bool idHash protocol.IDHash requestBodyKey [16]byte requestBodyIV [16]byte @@ -44,35 +44,23 @@ type ClientSession struct { responseBodyIV [16]byte responseReader io.Reader responseHeader byte - - isAEADRequest bool } // NewClientSession creates a new ClientSession. -func NewClientSession(idHash protocol.IDHash, ctx context.Context) *ClientSession { +func NewClientSession(isAEAD bool, idHash protocol.IDHash, ctx context.Context) *ClientSession { + + session := &ClientSession{ + isAEAD: isAEAD, + idHash: idHash, + } + randomBytes := make([]byte, 33) // 16 + 16 + 1 common.Must2(rand.Read(randomBytes)) - - session := &ClientSession{} - - session.isAEADRequest = false - - if ctxValueAlterID := ctx.Value(vmess.AlterID); ctxValueAlterID != nil { - if ctxValueAlterID == 0 { - session.isAEADRequest = true - } - } - - if vmessAeadDisable, vmessAeadDisableFound := os.LookupEnv("V2RAY_VMESS_AEAD_DISABLED"); vmessAeadDisableFound { - if vmessAeadDisable == "true" { - session.isAEADRequest = false - } - } - copy(session.requestBodyKey[:], randomBytes[:16]) copy(session.requestBodyIV[:], randomBytes[16:32]) session.responseHeader = randomBytes[32] - if !session.isAEADRequest { + + if !session.isAEAD { session.responseBodyKey = md5.Sum(session.requestBodyKey[:]) session.responseBodyIV = md5.Sum(session.requestBodyIV[:]) } else { @@ -82,15 +70,13 @@ func NewClientSession(idHash protocol.IDHash, ctx context.Context) *ClientSessio copy(session.responseBodyIV[:], BodyIV[:16]) } - session.idHash = idHash - return session } func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error { timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)() account := header.User.Account.(*vmess.MemoryAccount) - if !c.isAEADRequest { + if !c.isAEAD { idHash := c.idHash(account.AnyValidID().Bytes()) common.Must2(serial.WriteUint64(idHash, uint64(timestamp))) common.Must2(writer.Write(idHash.Sum(nil))) @@ -126,7 +112,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ fnv1a.Sum(hashBytes[:0]) } - if !c.isAEADRequest { + if !c.isAEAD { iv := hashTimestamp(md5.New(), timestamp) aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv[:]) aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) @@ -203,7 +189,7 @@ func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write } func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) { - if !c.isAEADRequest { + if !c.isAEAD { aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:]) c.responseReader = crypto.NewCryptionReader(aesStream, reader) } else { @@ -274,7 +260,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon header.Command = command } } - if c.isAEADRequest { + if c.isAEAD { aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:]) c.responseReader = crypto.NewCryptionReader(aesStream, reader) } diff --git a/proxy/vmess/encoding/encoding_test.go b/proxy/vmess/encoding/encoding_test.go index bc7eecd33..c0f938b7d 100644 --- a/proxy/vmess/encoding/encoding_test.go +++ b/proxy/vmess/encoding/encoding_test.go @@ -43,7 +43,7 @@ func TestRequestSerialization(t *testing.T) { } buffer := buf.New() - client := NewClientSession(protocol.DefaultIDHash, context.TODO()) + client := NewClientSession(true, protocol.DefaultIDHash, context.TODO()) common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) buffer2 := buf.New() @@ -93,7 +93,7 @@ func TestInvalidRequest(t *testing.T) { } buffer := buf.New() - client := NewClientSession(protocol.DefaultIDHash, context.TODO()) + client := NewClientSession(true, protocol.DefaultIDHash, context.TODO()) common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) buffer2 := buf.New() @@ -134,7 +134,7 @@ func TestMuxRequest(t *testing.T) { } buffer := buf.New() - client := NewClientSession(protocol.DefaultIDHash, context.TODO()) + client := NewClientSession(true, protocol.DefaultIDHash, context.TODO()) common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) buffer2 := buf.New() diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index cbbbf585f..d76ca8680 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -6,6 +6,7 @@ package outbound import ( "context" + "os" "time" "v2ray.com/core" @@ -30,6 +31,7 @@ type Handler struct { serverList *protocol.ServerList serverPicker protocol.ServerPicker policyManager policy.Manager + aead_disabled bool } // New creates a new VMess outbound handler. @@ -50,16 +52,20 @@ func New(ctx context.Context, config *Config) (*Handler, error) { policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), } + if disabled, _ := os.LookupEnv("V2RAY_VMESS_AEAD_DISABLED"); disabled == "true" { + handler.aead_disabled = true + } + return handler, nil } // Process implements proxy.Outbound.Process(). -func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { +func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { var rec *protocol.ServerSpec var conn internet.Connection err := retry.ExponentialBackoff(5, 200).On(func() error { - rec = v.serverPicker.PickServer() + rec = h.serverPicker.PickServer() rawConn, err := dialer.Dial(ctx, rec.Destination()) if err != nil { return err @@ -113,10 +119,13 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte input := link.Reader output := link.Writer - ctx = context.WithValue(ctx, vmess.AlterID, len(account.AlterIDs)) + isAEAD := false + if !h.aead_disabled && len(account.AlterIDs) == 0 { + isAEAD = true + } - session := encoding.NewClientSession(protocol.DefaultIDHash, ctx) - sessionPolicy := v.policyManager.ForLevel(request.User.Level) + session := encoding.NewClientSession(isAEAD, protocol.DefaultIDHash, ctx) + sessionPolicy := h.policyManager.ForLevel(request.User.Level) ctx, cancel := context.WithCancel(ctx) timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) @@ -159,7 +168,7 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte if err != nil { return newError("failed to read header").Base(err) } - v.handleCommand(rec.Destination(), header.Command) + h.handleCommand(rec.Destination(), header.Command) bodyReader := session.DecodeResponseBody(request, reader) diff --git a/proxy/vmess/vmessCtxInterface.go b/proxy/vmess/vmessCtxInterface.go index dbfb5b72e..5d26f9e5d 100644 --- a/proxy/vmess/vmessCtxInterface.go +++ b/proxy/vmess/vmessCtxInterface.go @@ -1,3 +1,4 @@ package vmess +// example const AlterID = "VMessCtxInterface_AlterID"