mirror of
https://github.com/v2fly/v2ray-core.git
synced 2024-12-21 01:27:03 -05:00
Make isAEAD more efficient
This commit is contained in:
parent
5f620256b2
commit
470dc8523b
@ -12,8 +12,6 @@ import (
|
|||||||
"hash"
|
"hash"
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
|
||||||
vmessaead "v2ray.com/core/proxy/vmess/aead"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
|
||||||
@ -25,6 +23,7 @@ import (
|
|||||||
"v2ray.com/core/common/protocol"
|
"v2ray.com/core/common/protocol"
|
||||||
"v2ray.com/core/common/serial"
|
"v2ray.com/core/common/serial"
|
||||||
"v2ray.com/core/proxy/vmess"
|
"v2ray.com/core/proxy/vmess"
|
||||||
|
vmessaead "v2ray.com/core/proxy/vmess/aead"
|
||||||
)
|
)
|
||||||
|
|
||||||
func hashTimestamp(h hash.Hash, t protocol.Timestamp) []byte {
|
func hashTimestamp(h hash.Hash, t protocol.Timestamp) []byte {
|
||||||
@ -37,6 +36,7 @@ func hashTimestamp(h hash.Hash, t protocol.Timestamp) []byte {
|
|||||||
|
|
||||||
// ClientSession stores connection session info for VMess client.
|
// ClientSession stores connection session info for VMess client.
|
||||||
type ClientSession struct {
|
type ClientSession struct {
|
||||||
|
isAEAD bool
|
||||||
idHash protocol.IDHash
|
idHash protocol.IDHash
|
||||||
requestBodyKey [16]byte
|
requestBodyKey [16]byte
|
||||||
requestBodyIV [16]byte
|
requestBodyIV [16]byte
|
||||||
@ -44,35 +44,23 @@ type ClientSession struct {
|
|||||||
responseBodyIV [16]byte
|
responseBodyIV [16]byte
|
||||||
responseReader io.Reader
|
responseReader io.Reader
|
||||||
responseHeader byte
|
responseHeader byte
|
||||||
|
|
||||||
isAEADRequest bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientSession creates a new ClientSession.
|
// NewClientSession creates a new ClientSession.
|
||||||
func NewClientSession(idHash protocol.IDHash, ctx context.Context) *ClientSession {
|
func NewClientSession(isAEAD bool, idHash protocol.IDHash, ctx context.Context) *ClientSession {
|
||||||
|
|
||||||
|
session := &ClientSession{
|
||||||
|
isAEAD: isAEAD,
|
||||||
|
idHash: idHash,
|
||||||
|
}
|
||||||
|
|
||||||
randomBytes := make([]byte, 33) // 16 + 16 + 1
|
randomBytes := make([]byte, 33) // 16 + 16 + 1
|
||||||
common.Must2(rand.Read(randomBytes))
|
common.Must2(rand.Read(randomBytes))
|
||||||
|
|
||||||
session := &ClientSession{}
|
|
||||||
|
|
||||||
session.isAEADRequest = false
|
|
||||||
|
|
||||||
if ctxValueAlterID := ctx.Value(vmess.AlterID); ctxValueAlterID != nil {
|
|
||||||
if ctxValueAlterID == 0 {
|
|
||||||
session.isAEADRequest = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if vmessAeadDisable, vmessAeadDisableFound := os.LookupEnv("V2RAY_VMESS_AEAD_DISABLED"); vmessAeadDisableFound {
|
|
||||||
if vmessAeadDisable == "true" {
|
|
||||||
session.isAEADRequest = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
copy(session.requestBodyKey[:], randomBytes[:16])
|
copy(session.requestBodyKey[:], randomBytes[:16])
|
||||||
copy(session.requestBodyIV[:], randomBytes[16:32])
|
copy(session.requestBodyIV[:], randomBytes[16:32])
|
||||||
session.responseHeader = randomBytes[32]
|
session.responseHeader = randomBytes[32]
|
||||||
if !session.isAEADRequest {
|
|
||||||
|
if !session.isAEAD {
|
||||||
session.responseBodyKey = md5.Sum(session.requestBodyKey[:])
|
session.responseBodyKey = md5.Sum(session.requestBodyKey[:])
|
||||||
session.responseBodyIV = md5.Sum(session.requestBodyIV[:])
|
session.responseBodyIV = md5.Sum(session.requestBodyIV[:])
|
||||||
} else {
|
} else {
|
||||||
@ -82,15 +70,13 @@ func NewClientSession(idHash protocol.IDHash, ctx context.Context) *ClientSessio
|
|||||||
copy(session.responseBodyIV[:], BodyIV[:16])
|
copy(session.responseBodyIV[:], BodyIV[:16])
|
||||||
}
|
}
|
||||||
|
|
||||||
session.idHash = idHash
|
|
||||||
|
|
||||||
return session
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error {
|
func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error {
|
||||||
timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)()
|
timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)()
|
||||||
account := header.User.Account.(*vmess.MemoryAccount)
|
account := header.User.Account.(*vmess.MemoryAccount)
|
||||||
if !c.isAEADRequest {
|
if !c.isAEAD {
|
||||||
idHash := c.idHash(account.AnyValidID().Bytes())
|
idHash := c.idHash(account.AnyValidID().Bytes())
|
||||||
common.Must2(serial.WriteUint64(idHash, uint64(timestamp)))
|
common.Must2(serial.WriteUint64(idHash, uint64(timestamp)))
|
||||||
common.Must2(writer.Write(idHash.Sum(nil)))
|
common.Must2(writer.Write(idHash.Sum(nil)))
|
||||||
@ -126,7 +112,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
|
|||||||
fnv1a.Sum(hashBytes[:0])
|
fnv1a.Sum(hashBytes[:0])
|
||||||
}
|
}
|
||||||
|
|
||||||
if !c.isAEADRequest {
|
if !c.isAEAD {
|
||||||
iv := hashTimestamp(md5.New(), timestamp)
|
iv := hashTimestamp(md5.New(), timestamp)
|
||||||
aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv[:])
|
aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv[:])
|
||||||
aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
|
aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
|
||||||
@ -203,7 +189,7 @@ func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) {
|
func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) {
|
||||||
if !c.isAEADRequest {
|
if !c.isAEAD {
|
||||||
aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
|
aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
|
||||||
c.responseReader = crypto.NewCryptionReader(aesStream, reader)
|
c.responseReader = crypto.NewCryptionReader(aesStream, reader)
|
||||||
} else {
|
} else {
|
||||||
@ -274,7 +260,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
|
|||||||
header.Command = command
|
header.Command = command
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if c.isAEADRequest {
|
if c.isAEAD {
|
||||||
aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
|
aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
|
||||||
c.responseReader = crypto.NewCryptionReader(aesStream, reader)
|
c.responseReader = crypto.NewCryptionReader(aesStream, reader)
|
||||||
}
|
}
|
||||||
|
@ -43,7 +43,7 @@ func TestRequestSerialization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
buffer := buf.New()
|
buffer := buf.New()
|
||||||
client := NewClientSession(protocol.DefaultIDHash, context.TODO())
|
client := NewClientSession(true, protocol.DefaultIDHash, context.TODO())
|
||||||
common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
|
common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
|
||||||
|
|
||||||
buffer2 := buf.New()
|
buffer2 := buf.New()
|
||||||
@ -93,7 +93,7 @@ func TestInvalidRequest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
buffer := buf.New()
|
buffer := buf.New()
|
||||||
client := NewClientSession(protocol.DefaultIDHash, context.TODO())
|
client := NewClientSession(true, protocol.DefaultIDHash, context.TODO())
|
||||||
common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
|
common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
|
||||||
|
|
||||||
buffer2 := buf.New()
|
buffer2 := buf.New()
|
||||||
@ -134,7 +134,7 @@ func TestMuxRequest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
buffer := buf.New()
|
buffer := buf.New()
|
||||||
client := NewClientSession(protocol.DefaultIDHash, context.TODO())
|
client := NewClientSession(true, protocol.DefaultIDHash, context.TODO())
|
||||||
common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
|
common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
|
||||||
|
|
||||||
buffer2 := buf.New()
|
buffer2 := buf.New()
|
||||||
|
@ -6,6 +6,7 @@ package outbound
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"v2ray.com/core"
|
"v2ray.com/core"
|
||||||
@ -30,6 +31,7 @@ type Handler struct {
|
|||||||
serverList *protocol.ServerList
|
serverList *protocol.ServerList
|
||||||
serverPicker protocol.ServerPicker
|
serverPicker protocol.ServerPicker
|
||||||
policyManager policy.Manager
|
policyManager policy.Manager
|
||||||
|
aead_disabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new VMess outbound handler.
|
// New creates a new VMess outbound handler.
|
||||||
@ -50,16 +52,20 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
|
|||||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if disabled, _ := os.LookupEnv("V2RAY_VMESS_AEAD_DISABLED"); disabled == "true" {
|
||||||
|
handler.aead_disabled = true
|
||||||
|
}
|
||||||
|
|
||||||
return handler, nil
|
return handler, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process implements proxy.Outbound.Process().
|
// Process implements proxy.Outbound.Process().
|
||||||
func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
|
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
|
||||||
var rec *protocol.ServerSpec
|
var rec *protocol.ServerSpec
|
||||||
var conn internet.Connection
|
var conn internet.Connection
|
||||||
|
|
||||||
err := retry.ExponentialBackoff(5, 200).On(func() error {
|
err := retry.ExponentialBackoff(5, 200).On(func() error {
|
||||||
rec = v.serverPicker.PickServer()
|
rec = h.serverPicker.PickServer()
|
||||||
rawConn, err := dialer.Dial(ctx, rec.Destination())
|
rawConn, err := dialer.Dial(ctx, rec.Destination())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -113,10 +119,13 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
|||||||
input := link.Reader
|
input := link.Reader
|
||||||
output := link.Writer
|
output := link.Writer
|
||||||
|
|
||||||
ctx = context.WithValue(ctx, vmess.AlterID, len(account.AlterIDs))
|
isAEAD := false
|
||||||
|
if !h.aead_disabled && len(account.AlterIDs) == 0 {
|
||||||
|
isAEAD = true
|
||||||
|
}
|
||||||
|
|
||||||
session := encoding.NewClientSession(protocol.DefaultIDHash, ctx)
|
session := encoding.NewClientSession(isAEAD, protocol.DefaultIDHash, ctx)
|
||||||
sessionPolicy := v.policyManager.ForLevel(request.User.Level)
|
sessionPolicy := h.policyManager.ForLevel(request.User.Level)
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
|
timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
|
||||||
@ -159,7 +168,7 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return newError("failed to read header").Base(err)
|
return newError("failed to read header").Base(err)
|
||||||
}
|
}
|
||||||
v.handleCommand(rec.Destination(), header.Command)
|
h.handleCommand(rec.Destination(), header.Command)
|
||||||
|
|
||||||
bodyReader := session.DecodeResponseBody(request, reader)
|
bodyReader := session.DecodeResponseBody(request, reader)
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
package vmess
|
package vmess
|
||||||
|
|
||||||
|
// example
|
||||||
const AlterID = "VMessCtxInterface_AlterID"
|
const AlterID = "VMessCtxInterface_AlterID"
|
||||||
|
Loading…
Reference in New Issue
Block a user