1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-21 01:27:03 -05:00

Make isAEAD more efficient

This commit is contained in:
RPRX 2020-09-21 01:10:56 +00:00 committed by GitHub
parent 5f620256b2
commit 470dc8523b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 38 deletions

View File

@ -12,8 +12,6 @@ import (
"hash" "hash"
"hash/fnv" "hash/fnv"
"io" "io"
"os"
vmessaead "v2ray.com/core/proxy/vmess/aead"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
@ -25,6 +23,7 @@ import (
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial" "v2ray.com/core/common/serial"
"v2ray.com/core/proxy/vmess" "v2ray.com/core/proxy/vmess"
vmessaead "v2ray.com/core/proxy/vmess/aead"
) )
func hashTimestamp(h hash.Hash, t protocol.Timestamp) []byte { func hashTimestamp(h hash.Hash, t protocol.Timestamp) []byte {
@ -37,6 +36,7 @@ func hashTimestamp(h hash.Hash, t protocol.Timestamp) []byte {
// ClientSession stores connection session info for VMess client. // ClientSession stores connection session info for VMess client.
type ClientSession struct { type ClientSession struct {
isAEAD bool
idHash protocol.IDHash idHash protocol.IDHash
requestBodyKey [16]byte requestBodyKey [16]byte
requestBodyIV [16]byte requestBodyIV [16]byte
@ -44,35 +44,23 @@ type ClientSession struct {
responseBodyIV [16]byte responseBodyIV [16]byte
responseReader io.Reader responseReader io.Reader
responseHeader byte responseHeader byte
isAEADRequest bool
} }
// NewClientSession creates a new ClientSession. // NewClientSession creates a new ClientSession.
func NewClientSession(idHash protocol.IDHash, ctx context.Context) *ClientSession { func NewClientSession(isAEAD bool, idHash protocol.IDHash, ctx context.Context) *ClientSession {
session := &ClientSession{
isAEAD: isAEAD,
idHash: idHash,
}
randomBytes := make([]byte, 33) // 16 + 16 + 1 randomBytes := make([]byte, 33) // 16 + 16 + 1
common.Must2(rand.Read(randomBytes)) common.Must2(rand.Read(randomBytes))
session := &ClientSession{}
session.isAEADRequest = false
if ctxValueAlterID := ctx.Value(vmess.AlterID); ctxValueAlterID != nil {
if ctxValueAlterID == 0 {
session.isAEADRequest = true
}
}
if vmessAeadDisable, vmessAeadDisableFound := os.LookupEnv("V2RAY_VMESS_AEAD_DISABLED"); vmessAeadDisableFound {
if vmessAeadDisable == "true" {
session.isAEADRequest = false
}
}
copy(session.requestBodyKey[:], randomBytes[:16]) copy(session.requestBodyKey[:], randomBytes[:16])
copy(session.requestBodyIV[:], randomBytes[16:32]) copy(session.requestBodyIV[:], randomBytes[16:32])
session.responseHeader = randomBytes[32] session.responseHeader = randomBytes[32]
if !session.isAEADRequest {
if !session.isAEAD {
session.responseBodyKey = md5.Sum(session.requestBodyKey[:]) session.responseBodyKey = md5.Sum(session.requestBodyKey[:])
session.responseBodyIV = md5.Sum(session.requestBodyIV[:]) session.responseBodyIV = md5.Sum(session.requestBodyIV[:])
} else { } else {
@ -82,15 +70,13 @@ func NewClientSession(idHash protocol.IDHash, ctx context.Context) *ClientSessio
copy(session.responseBodyIV[:], BodyIV[:16]) copy(session.responseBodyIV[:], BodyIV[:16])
} }
session.idHash = idHash
return session return session
} }
func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error { func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error {
timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)() timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)()
account := header.User.Account.(*vmess.MemoryAccount) account := header.User.Account.(*vmess.MemoryAccount)
if !c.isAEADRequest { if !c.isAEAD {
idHash := c.idHash(account.AnyValidID().Bytes()) idHash := c.idHash(account.AnyValidID().Bytes())
common.Must2(serial.WriteUint64(idHash, uint64(timestamp))) common.Must2(serial.WriteUint64(idHash, uint64(timestamp)))
common.Must2(writer.Write(idHash.Sum(nil))) common.Must2(writer.Write(idHash.Sum(nil)))
@ -126,7 +112,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
fnv1a.Sum(hashBytes[:0]) fnv1a.Sum(hashBytes[:0])
} }
if !c.isAEADRequest { if !c.isAEAD {
iv := hashTimestamp(md5.New(), timestamp) iv := hashTimestamp(md5.New(), timestamp)
aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv[:]) aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv[:])
aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
@ -203,7 +189,7 @@ func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
} }
func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) { func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) {
if !c.isAEADRequest { if !c.isAEAD {
aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:]) aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
c.responseReader = crypto.NewCryptionReader(aesStream, reader) c.responseReader = crypto.NewCryptionReader(aesStream, reader)
} else { } else {
@ -274,7 +260,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
header.Command = command header.Command = command
} }
} }
if c.isAEADRequest { if c.isAEAD {
aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:]) aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
c.responseReader = crypto.NewCryptionReader(aesStream, reader) c.responseReader = crypto.NewCryptionReader(aesStream, reader)
} }

View File

@ -43,7 +43,7 @@ func TestRequestSerialization(t *testing.T) {
} }
buffer := buf.New() buffer := buf.New()
client := NewClientSession(protocol.DefaultIDHash, context.TODO()) client := NewClientSession(true, protocol.DefaultIDHash, context.TODO())
common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
buffer2 := buf.New() buffer2 := buf.New()
@ -93,7 +93,7 @@ func TestInvalidRequest(t *testing.T) {
} }
buffer := buf.New() buffer := buf.New()
client := NewClientSession(protocol.DefaultIDHash, context.TODO()) client := NewClientSession(true, protocol.DefaultIDHash, context.TODO())
common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
buffer2 := buf.New() buffer2 := buf.New()
@ -134,7 +134,7 @@ func TestMuxRequest(t *testing.T) {
} }
buffer := buf.New() buffer := buf.New()
client := NewClientSession(protocol.DefaultIDHash, context.TODO()) client := NewClientSession(true, protocol.DefaultIDHash, context.TODO())
common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
buffer2 := buf.New() buffer2 := buf.New()

View File

@ -6,6 +6,7 @@ package outbound
import ( import (
"context" "context"
"os"
"time" "time"
"v2ray.com/core" "v2ray.com/core"
@ -30,6 +31,7 @@ type Handler struct {
serverList *protocol.ServerList serverList *protocol.ServerList
serverPicker protocol.ServerPicker serverPicker protocol.ServerPicker
policyManager policy.Manager policyManager policy.Manager
aead_disabled bool
} }
// New creates a new VMess outbound handler. // New creates a new VMess outbound handler.
@ -50,16 +52,20 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
} }
if disabled, _ := os.LookupEnv("V2RAY_VMESS_AEAD_DISABLED"); disabled == "true" {
handler.aead_disabled = true
}
return handler, nil return handler, nil
} }
// Process implements proxy.Outbound.Process(). // Process implements proxy.Outbound.Process().
func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
var rec *protocol.ServerSpec var rec *protocol.ServerSpec
var conn internet.Connection var conn internet.Connection
err := retry.ExponentialBackoff(5, 200).On(func() error { err := retry.ExponentialBackoff(5, 200).On(func() error {
rec = v.serverPicker.PickServer() rec = h.serverPicker.PickServer()
rawConn, err := dialer.Dial(ctx, rec.Destination()) rawConn, err := dialer.Dial(ctx, rec.Destination())
if err != nil { if err != nil {
return err return err
@ -113,10 +119,13 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
input := link.Reader input := link.Reader
output := link.Writer output := link.Writer
ctx = context.WithValue(ctx, vmess.AlterID, len(account.AlterIDs)) isAEAD := false
if !h.aead_disabled && len(account.AlterIDs) == 0 {
isAEAD = true
}
session := encoding.NewClientSession(protocol.DefaultIDHash, ctx) session := encoding.NewClientSession(isAEAD, protocol.DefaultIDHash, ctx)
sessionPolicy := v.policyManager.ForLevel(request.User.Level) sessionPolicy := h.policyManager.ForLevel(request.User.Level)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
@ -159,7 +168,7 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
if err != nil { if err != nil {
return newError("failed to read header").Base(err) return newError("failed to read header").Base(err)
} }
v.handleCommand(rec.Destination(), header.Command) h.handleCommand(rec.Destination(), header.Command)
bodyReader := session.DecodeResponseBody(request, reader) bodyReader := session.DecodeResponseBody(request, reader)

View File

@ -1,3 +1,4 @@
package vmess package vmess
// example
const AlterID = "VMessCtxInterface_AlterID" const AlterID = "VMessCtxInterface_AlterID"