1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-02 07:26:24 -05:00

VMess AEAD Experiment

This commit is contained in:
Shelikhoo 2020-06-06 17:11:30 +08:00
parent b610fc0a70
commit 9bf07b1f26
No known key found for this signature in database
GPG Key ID: C4D5E79D22B25316
11 changed files with 621 additions and 45 deletions

View File

@ -0,0 +1,50 @@
package antireplay
import (
cuckoo "github.com/seiflotfy/cuckoofilter"
"sync"
"time"
)
func NewAntiReplayWindow(AntiReplayTime int64) *AntiReplayWindow {
arw := &AntiReplayWindow{}
arw.AntiReplayTime = AntiReplayTime
return arw
}
type AntiReplayWindow struct {
lock sync.Mutex
poolA *cuckoo.Filter
poolB *cuckoo.Filter
lastSwapTime int64
PoolSwap bool
AntiReplayTime int64
}
func (aw *AntiReplayWindow) Check(sum []byte) bool {
aw.lock.Lock()
if aw.lastSwapTime == 0 {
aw.lastSwapTime = time.Now().Unix()
aw.poolA = cuckoo.NewFilter(100000)
aw.poolB = cuckoo.NewFilter(100000)
}
tnow := time.Now().Unix()
timediff := tnow - aw.lastSwapTime
if timediff >= aw.AntiReplayTime {
if aw.PoolSwap {
aw.PoolSwap = false
aw.poolA.Reset()
} else {
aw.PoolSwap = true
aw.poolB.Reset()
}
aw.lastSwapTime = tnow
}
ret := aw.poolA.InsertUnique(sum) && aw.poolB.InsertUnique(sum)
aw.lock.Unlock()
return ret
}

View File

@ -38,8 +38,6 @@ const (
RequestOptionChunkMasking bitmask.Byte = 0x04 RequestOptionChunkMasking bitmask.Byte = 0x04
RequestOptionGlobalPadding bitmask.Byte = 0x08 RequestOptionGlobalPadding bitmask.Byte = 0x08
RequestOptionEarlyChecksum bitmask.Byte = 0x16
) )
type RequestHeader struct { type RequestHeader struct {

2
go.mod
View File

@ -1,12 +1,14 @@
module v2ray.com/core module v2ray.com/core
require ( require (
github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc // indirect
github.com/golang/mock v1.2.0 github.com/golang/mock v1.2.0
github.com/golang/protobuf v1.3.2 github.com/golang/protobuf v1.3.2
github.com/google/go-cmp v0.2.0 github.com/google/go-cmp v0.2.0
github.com/gorilla/websocket v1.4.1 github.com/gorilla/websocket v1.4.1
github.com/miekg/dns v1.1.4 github.com/miekg/dns v1.1.4
github.com/refraction-networking/utls v0.0.0-20190909200633-43c36d3c1f57 github.com/refraction-networking/utls v0.0.0-20190909200633-43c36d3c1f57
github.com/seiflotfy/cuckoofilter v0.0.0-20200511222245-56093a4d3841
go.starlark.net v0.0.0-20190919145610-979af19b165c go.starlark.net v0.0.0-20190919145610-979af19b165c
golang.org/x/crypto v0.0.0-20191029031824-8986dd9e96cf golang.org/x/crypto v0.0.0-20191029031824-8986dd9e96cf
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3

4
go.sum
View File

@ -1,6 +1,8 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc h1:8WFBn63wegobsYAX0YjD+8suexZDga5CctH4CCTx2+8=
github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
@ -16,6 +18,8 @@ github.com/miekg/dns v1.1.4 h1:rCMZsU2ScVSYcAsOXgmC6+AKOK+6pmQTOcw03nfwYV0=
github.com/miekg/dns v1.1.4/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.1.4/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/refraction-networking/utls v0.0.0-20190909200633-43c36d3c1f57 h1:SL1K0QAuC1b54KoY1pjPWe6kSlsFHwK9/oC960fKrTY= github.com/refraction-networking/utls v0.0.0-20190909200633-43c36d3c1f57 h1:SL1K0QAuC1b54KoY1pjPWe6kSlsFHwK9/oC960fKrTY=
github.com/refraction-networking/utls v0.0.0-20190909200633-43c36d3c1f57/go.mod h1:tz9gX959MEFfFN5whTIocCLUG57WiILqtdVxI8c6Wj0= github.com/refraction-networking/utls v0.0.0-20190909200633-43c36d3c1f57/go.mod h1:tz9gX959MEFfFN5whTIocCLUG57WiILqtdVxI8c6Wj0=
github.com/seiflotfy/cuckoofilter v0.0.0-20200511222245-56093a4d3841 h1:pnfutQFsV7ySmHUeX6ANGfPsBo29RctUvDn8G3rmJVw=
github.com/seiflotfy/cuckoofilter v0.0.0-20200511222245-56093a4d3841/go.mod h1:ET5mVvNjwaGXRgZxO9UZr7X+8eAf87AfIYNwRSp9s4Y=
go.starlark.net v0.0.0-20190919145610-979af19b165c h1:WR7X1xgXJlXhQBdorVc9Db3RhwG+J/kp6bLuMyJjfVw= go.starlark.net v0.0.0-20190919145610-979af19b165c h1:WR7X1xgXJlXhQBdorVc9Db3RhwG+J/kp6bLuMyJjfVw=
go.starlark.net v0.0.0-20190919145610-979af19b165c/go.mod h1:c1/X6cHgvdXj6pUlmWKMkuqRnW4K8x2vwt6JAaaircg= go.starlark.net v0.0.0-20190919145610-979af19b165c/go.mod h1:c1/X6cHgvdXj6pUlmWKMkuqRnW4K8x2vwt6JAaaircg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=

114
proxy/vmess/aead/authid.go Normal file
View File

@ -0,0 +1,114 @@
package aead
import (
"bytes"
"crypto/aes"
"crypto/cipher"
rand3 "crypto/rand"
"encoding/binary"
"errors"
"hash/crc32"
"io"
"math"
"time"
"v2ray.com/core/common"
antiReplayWindow "v2ray.com/core/common/antireplay"
)
func CreateAuthID(cmdKey []byte, time int64) [16]byte {
buf := bytes.NewBuffer(nil)
common.Must(binary.Write(buf, binary.BigEndian, time))
var zero uint32
common.Must2(io.CopyN(buf, rand3.Reader, 4))
zero = crc32.ChecksumIEEE(buf.Bytes())
common.Must(binary.Write(buf, binary.BigEndian, zero))
aesBlock := NewCipherFromKey(cmdKey)
if buf.Len() != 16 {
panic("Size unexpected")
}
var result [16]byte
aesBlock.Encrypt(result[:], buf.Bytes())
return result
}
func NewCipherFromKey(cmdKey []byte) cipher.Block {
aesBlock, err := aes.NewCipher(KDF16(cmdKey, "AES Auth ID Encryption"))
if err != nil {
panic(err)
}
return aesBlock
}
type AuthIDDecoder struct {
s cipher.Block
}
func NewAuthIDDecoder(cmdKey []byte) *AuthIDDecoder {
return &AuthIDDecoder{NewCipherFromKey(cmdKey)}
}
func (aidd *AuthIDDecoder) Decode(data [16]byte) (int64, uint32, int32, []byte) {
aidd.s.Decrypt(data[:], data[:])
var t int64
var zero uint32
var rand int32
reader := bytes.NewReader(data[:])
common.Must(binary.Read(reader, binary.BigEndian, &t))
common.Must(binary.Read(reader, binary.BigEndian, &rand))
common.Must(binary.Read(reader, binary.BigEndian, &zero))
return t, zero, rand, data[:]
}
func NewAuthIDDecoderHolder() *AuthIDDecoderHolder {
return &AuthIDDecoderHolder{make(map[string]*AuthIDDecoderItem), antiReplayWindow.NewAntiReplayWindow(120)}
}
type AuthIDDecoderHolder struct {
aidhi map[string]*AuthIDDecoderItem
apw *antiReplayWindow.AntiReplayWindow
}
type AuthIDDecoderItem struct {
dec *AuthIDDecoder
ticket interface{}
}
func NewAuthIDDecoderItem(key [16]byte, ticket interface{}) *AuthIDDecoderItem {
return &AuthIDDecoderItem{
dec: NewAuthIDDecoder(key[:]),
ticket: ticket,
}
}
func (a *AuthIDDecoderHolder) AddUser(key [16]byte, ticket interface{}) {
a.aidhi[string(key[:])] = NewAuthIDDecoderItem(key, ticket)
}
func (a *AuthIDDecoderHolder) RemoveUser(key [16]byte) {
delete(a.aidhi, string(key[:]))
}
func (a *AuthIDDecoderHolder) Match(AuthID [16]byte) (interface{}, error) {
if !a.apw.Check(AuthID[:]) {
return nil, errReplay
}
for _, v := range a.aidhi {
t, z, r, d := v.dec.Decode(AuthID)
if z != crc32.ChecksumIEEE(d[:12]) {
continue
}
if math.Abs(float64(t-time.Now().Unix())) > 120 {
continue
}
_ = r
return v.ticket, nil
}
return nil, errNotFound
}
var errNotFound = errors.New("user do not exist")
var errReplay = errors.New("replayed request")

141
proxy/vmess/aead/encrypt.go Normal file
View File

@ -0,0 +1,141 @@
package aead
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"encoding/binary"
"errors"
"io"
"time"
"v2ray.com/core/common"
)
func SealVMessAEADHeader(key [16]byte, data []byte) []byte {
authid := CreateAuthID(key[:], time.Now().Unix())
nonce := make([]byte, 8)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
panic(err.Error())
}
lengthbuf := bytes.NewBuffer(nil)
var HeaderDataLen uint16
HeaderDataLen = uint16(len(data))
common.Must(binary.Write(lengthbuf, binary.BigEndian, HeaderDataLen))
authidCheck := KDF16(key[:], "VMess AuthID Check Value", string(authid[:]), string(lengthbuf.Bytes()), string(nonce))
lengthbufb := lengthbuf.Bytes()
LengthMask := KDF16(key[:], "VMess AuthID Mask Value", string(authid[:]), string(nonce[:]))[:2]
lengthbufb[0] = lengthbufb[0] ^ LengthMask[0]
lengthbufb[1] = lengthbufb[1] ^ LengthMask[1]
HeaderAEADKey := KDF16(key[:], "VMess Header AEAD Key", string(authid[:]), string(nonce))
HeaderAEADNonce := KDF(key[:], "VMess Header AEAD Nonce", string(authid[:]), string(nonce))[:12]
block, err := aes.NewCipher(HeaderAEADKey)
if err != nil {
panic(err.Error())
}
headerAEAD, err := cipher.NewGCM(block)
if err != nil {
panic(err.Error())
}
headerSealed := headerAEAD.Seal(nil, HeaderAEADNonce, data, authid[:])
var outPutBuf = bytes.NewBuffer(nil)
common.Must2(outPutBuf.Write(authid[:])) //16
common.Must2(outPutBuf.Write(authidCheck)) //16
common.Must2(outPutBuf.Write(lengthbufb)) //2
common.Must2(outPutBuf.Write(nonce)) //8
common.Must2(outPutBuf.Write(headerSealed))
return outPutBuf.Bytes()
}
func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte, bool, error, int) {
var authidCheck [16]byte
var lengthbufb [2]byte
var nonce [8]byte
n, err := io.ReadFull(data, authidCheck[:])
if err != nil {
return nil, false, err, n
}
n2, err := io.ReadFull(data, lengthbufb[:])
if err != nil {
return nil, false, err, n + n2
}
n4, err := io.ReadFull(data, nonce[:])
if err != nil {
return nil, false, err, n + n2 + n4
}
//Unmask Length
LengthMask := KDF16(key[:], "VMess AuthID Mask Value", string(authid[:]), string(nonce[:]))[:2]
lengthbufb[0] = lengthbufb[0] ^ LengthMask[0]
lengthbufb[1] = lengthbufb[1] ^ LengthMask[1]
authidCheckV := KDF16(key[:], "VMess AuthID Check Value", string(authid[:]), string(lengthbufb[:]), string(nonce[:]))
if !hmac.Equal(authidCheckV, authidCheck[:]) {
return nil, true, errCheckMismatch, n + n2 + n4
}
var length uint16
common.Must(binary.Read(bytes.NewReader(lengthbufb[:]), binary.BigEndian, &length))
HeaderAEADKey := KDF16(key[:], "VMess Header AEAD Key", string(authid[:]), string(nonce[:]))
HeaderAEADNonce := KDF(key[:], "VMess Header AEAD Nonce", string(authid[:]), string(nonce[:]))[:12]
//16 == AEAD Tag size
header := make([]byte, length+16)
n3, err := io.ReadFull(data, header)
if err != nil {
return nil, false, err, n + n2 + n3 + n4
}
block, err := aes.NewCipher(HeaderAEADKey)
if err != nil {
panic(err.Error())
}
headerAEAD, err := cipher.NewGCM(block)
if err != nil {
panic(err.Error())
}
out, erropenAEAD := headerAEAD.Open(nil, HeaderAEADNonce, header, authid[:])
if erropenAEAD != nil {
return nil, true, erropenAEAD, n + n2 + n3 + n4
}
return out, false, nil, n + n2 + n3 + n4
}
var errCheckMismatch = errors.New("check verify failed")

26
proxy/vmess/aead/kdf.go Normal file
View File

@ -0,0 +1,26 @@
package aead
import (
"crypto/hmac"
"crypto/sha256"
"hash"
)
func KDF(key []byte, path ...string) []byte {
hmacf := hmac.New(func() hash.Hash {
return sha256.New()
}, []byte("VMess AEAD KDF"))
for _, v := range path {
hmacf = hmac.New(func() hash.Hash {
return hmacf
}, []byte(v))
}
hmacf.Write(key)
return hmacf.Sum(nil)
}
func KDF16(key []byte, path ...string) []byte {
r := KDF(key, path...)
return r[:16]
}

View File

@ -1,12 +1,19 @@
package encoding package encoding
import ( import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/md5" "crypto/md5"
"crypto/rand" "crypto/rand"
"crypto/sha256"
"encoding/binary" "encoding/binary"
"fmt"
"hash" "hash"
"hash/fnv" "hash/fnv"
"io" "io"
"os"
vmessaead "v2ray.com/core/proxy/vmess/aead"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
@ -37,6 +44,8 @@ type ClientSession struct {
responseBodyIV [16]byte responseBodyIV [16]byte
responseReader io.Reader responseReader io.Reader
responseHeader byte responseHeader byte
isAEADRequest bool
} }
// NewClientSession creates a new ClientSession. // NewClientSession creates a new ClientSession.
@ -45,11 +54,29 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession {
common.Must2(rand.Read(randomBytes)) common.Must2(rand.Read(randomBytes))
session := &ClientSession{} session := &ClientSession{}
session.isAEADRequest = false
if vmessexp, vmessexp_found := os.LookupEnv("VMESSAEADEXPERIMENT"); vmessexp_found {
if vmessexp == "y" {
session.isAEADRequest = true
fmt.Println("=======VMESSAEADEXPERIMENT ENABLED========")
}
}
copy(session.requestBodyKey[:], randomBytes[:16]) copy(session.requestBodyKey[:], randomBytes[:16])
copy(session.requestBodyIV[:], randomBytes[16:32]) copy(session.requestBodyIV[:], randomBytes[16:32])
session.responseHeader = randomBytes[32] session.responseHeader = randomBytes[32]
session.responseBodyKey = md5.Sum(session.requestBodyKey[:]) if !session.isAEADRequest {
session.responseBodyIV = md5.Sum(session.requestBodyIV[:]) session.responseBodyKey = md5.Sum(session.requestBodyKey[:])
session.responseBodyIV = md5.Sum(session.requestBodyIV[:])
} else {
BodyKey := sha256.Sum256(session.requestBodyKey[:])
copy(session.responseBodyKey[:], BodyKey[:16])
BodyIV := sha256.Sum256(session.requestBodyKey[:])
copy(session.responseBodyIV[:], BodyIV[:16])
}
session.idHash = idHash session.idHash = idHash
return session return session
@ -58,9 +85,11 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession {
func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error { func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error {
timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)() timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)()
account := header.User.Account.(*vmess.MemoryAccount) account := header.User.Account.(*vmess.MemoryAccount)
idHash := c.idHash(account.AnyValidID().Bytes()) if !c.isAEADRequest {
common.Must2(serial.WriteUint64(idHash, uint64(timestamp))) idHash := c.idHash(account.AnyValidID().Bytes())
common.Must2(writer.Write(idHash.Sum(nil))) common.Must2(serial.WriteUint64(idHash, uint64(timestamp)))
common.Must2(writer.Write(idHash.Sum(nil)))
}
buffer := buf.New() buffer := buf.New()
defer buffer.Release() defer buffer.Release()
@ -92,10 +121,18 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
fnv1a.Sum(hashBytes[:0]) fnv1a.Sum(hashBytes[:0])
} }
iv := hashTimestamp(md5.New(), timestamp) if !c.isAEADRequest {
aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv[:]) iv := hashTimestamp(md5.New(), timestamp)
aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv[:])
common.Must2(writer.Write(buffer.Bytes())) aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
common.Must2(writer.Write(buffer.Bytes()))
} else {
var fixedLengthCmdKey [16]byte
copy(fixedLengthCmdKey[:], account.ID.CmdKey())
vmessout := vmessaead.SealVMessAEADHeader(fixedLengthCmdKey, buffer.Bytes())
common.Must2(io.Copy(writer, bytes.NewReader(vmessout)))
}
return nil return nil
} }
@ -161,8 +198,49 @@ func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
} }
func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) { func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) {
aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:]) if !c.isAEADRequest {
c.responseReader = crypto.NewCryptionReader(aesStream, reader) aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
c.responseReader = crypto.NewCryptionReader(aesStream, reader)
} else {
resph := vmessaead.KDF16(c.responseBodyKey[:], "AEAD Resp Header Len Key")
respi := vmessaead.KDF(c.responseBodyIV[:], "AEAD Resp Header Len IV")[:12]
aesblock := common.Must2(aes.NewCipher(resph)).(cipher.Block)
aeadHeader := common.Must2(cipher.NewGCM(aesblock)).(cipher.AEAD)
var AEADLen [18]byte
var lenresp int
var lenrespr uint16
if _, err := io.ReadFull(reader, AEADLen[:]); err != nil {
return nil, newError("Unable to Read Header Len").Base(err)
}
if AEADLend, err := aeadHeader.Open(nil, respi, AEADLen[:], nil); err != nil {
return nil, newError("Failed To Decrypt Length").Base(err)
} else {
common.Must(binary.Read(bytes.NewReader(AEADLend), binary.BigEndian, &lenrespr))
lenresp = int(lenrespr)
}
resphc := vmessaead.KDF16(c.responseBodyKey[:], "AEAD Resp Header Key")
respic := vmessaead.KDF(c.responseBodyIV[:], "AEAD Resp Header IV")[:12]
aesblockc := common.Must2(aes.NewCipher(resphc)).(cipher.Block)
aeadHeaderc := common.Must2(cipher.NewGCM(aesblockc)).(cipher.AEAD)
respPayload := make([]byte, lenresp+16)
if _, err := io.ReadFull(reader, respPayload); err != nil {
return nil, newError("Unable to Read Header Data").Base(err)
}
if AEADData, err := aeadHeaderc.Open(nil, respic, respPayload, nil); err != nil {
return nil, newError("Failed To Decrypt Payload").Base(err)
} else {
c.responseReader = bytes.NewReader(AEADData)
}
}
buffer := buf.StackNew() buffer := buf.StackNew()
defer buffer.Release() defer buffer.Release()
@ -192,7 +270,10 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
header.Command = command header.Command = command
} }
} }
if c.isAEADRequest {
aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
c.responseReader = crypto.NewCryptionReader(aesStream, reader)
}
return header, nil return header, nil
} }

View File

@ -1,7 +1,11 @@
package encoding package encoding
import ( import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/md5" "crypto/md5"
"crypto/sha256"
"encoding/binary" "encoding/binary"
"hash/fnv" "hash/fnv"
"io" "io"
@ -9,6 +13,7 @@ import (
"sync" "sync"
"time" "time"
"v2ray.com/core/common/dice" "v2ray.com/core/common/dice"
vmessaead "v2ray.com/core/proxy/vmess/aead"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
@ -99,6 +104,10 @@ type ServerSession struct {
responseBodyIV [16]byte responseBodyIV [16]byte
responseWriter io.Writer responseWriter io.Writer
responseHeader byte responseHeader byte
isAEADRequest bool
isAEADForced bool
} }
// NewServerSession creates a new ServerSession, using the given UserValidator. // NewServerSession creates a new ServerSession, using the given UserValidator.
@ -153,17 +162,44 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
return nil, newError("failed to read request header").Base(err) return nil, newError("failed to read request header").Base(err)
} }
user, timestamp, valid := s.userValidator.Get(buffer.Bytes()) var decryptor io.Reader
if !valid { var vmessAccount *vmess.MemoryAccount
user, foundAEAD := s.userValidator.GetAEAD(buffer.Bytes())
var fixedSizeAuthID [16]byte
copy(fixedSizeAuthID[:], buffer.Bytes())
if foundAEAD == true {
vmessAccount = user.Account.(*vmess.MemoryAccount)
var fixedSizeCmdKey [16]byte
copy(fixedSizeCmdKey[:], vmessAccount.ID.CmdKey())
aeadData, shouldDrain, errorReason, bytesRead := vmessaead.OpenVMessAEADHeader(fixedSizeCmdKey, fixedSizeAuthID, reader)
if errorReason != nil {
if shouldDrain {
readSizeRemain -= bytesRead
return nil, drainConnection(newError("AEAD read failed").Base(errorReason))
} else {
return nil, drainConnection(newError("AEAD read failed, drain skiped").Base(errorReason))
}
}
decryptor = bytes.NewReader(aeadData)
s.isAEADRequest = true
} else if !s.isAEADForced {
userLegacy, timestamp, valid, userValidationError := s.userValidator.Get(buffer.Bytes())
if !valid || userValidationError != nil {
return nil, drainConnection(newError("invalid user").Base(userValidationError))
}
user = userLegacy
iv := hashTimestamp(md5.New(), timestamp)
vmessAccount = userLegacy.Account.(*vmess.MemoryAccount)
aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv[:])
decryptor = crypto.NewCryptionReader(aesStream, reader)
} else {
return nil, drainConnection(newError("invalid user")) return nil, drainConnection(newError("invalid user"))
} }
iv := hashTimestamp(md5.New(), timestamp)
vmessAccount := user.Account.(*vmess.MemoryAccount)
aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv[:])
decryptor := crypto.NewCryptionReader(aesStream, reader)
readSizeRemain -= int(buffer.Len()) readSizeRemain -= int(buffer.Len())
buffer.Clear() buffer.Clear()
if _, err := buffer.ReadFullFrom(decryptor, 38); err != nil { if _, err := buffer.ReadFullFrom(decryptor, 38); err != nil {
@ -182,7 +218,16 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
sid.key = s.requestBodyKey sid.key = s.requestBodyKey
sid.nonce = s.requestBodyIV sid.nonce = s.requestBodyIV
if !s.sessionHistory.addIfNotExits(sid) { if !s.sessionHistory.addIfNotExits(sid) {
return nil, drainConnection(newError("duplicated session id, possibly under replay attack")) if !s.isAEADRequest {
drainErr := s.userValidator.BurnTaintFuse(fixedSizeAuthID[:])
if drainErr != nil {
return nil, drainConnection(newError("duplicated session id, possibly under replay attack, and failed to taint userHash").Base(drainErr))
}
return nil, drainConnection(newError("duplicated session id, possibly under replay attack, userHash tainted"))
} else {
return nil, newError("duplicated session id, possibly under replay attack, but this is a AEAD request")
}
} }
s.responseHeader = buffer.Byte(33) // 1 byte s.responseHeader = buffer.Byte(33) // 1 byte
@ -205,11 +250,25 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
if padingLen > 0 { if padingLen > 0 {
if _, err := buffer.ReadFullFrom(decryptor, int32(padingLen)); err != nil { if _, err := buffer.ReadFullFrom(decryptor, int32(padingLen)); err != nil {
if !s.isAEADRequest {
burnErr := s.userValidator.BurnTaintFuse(fixedSizeAuthID[:])
if burnErr != nil {
return nil, newError("failed to read padding, failed to taint userHash").Base(burnErr).Base(err)
}
return nil, newError("failed to read padding, userHash tainted").Base(err)
}
return nil, newError("failed to read padding").Base(err) return nil, newError("failed to read padding").Base(err)
} }
} }
if _, err := buffer.ReadFullFrom(decryptor, 4); err != nil { if _, err := buffer.ReadFullFrom(decryptor, 4); err != nil {
if !s.isAEADRequest {
burnErr := s.userValidator.BurnTaintFuse(fixedSizeAuthID[:])
if burnErr != nil {
return nil, newError("failed to read checksum, failed to taint userHash").Base(burnErr).Base(err)
}
return nil, newError("failed to read checksum, userHash tainted").Base(err)
}
return nil, newError("failed to read checksum").Base(err) return nil, newError("failed to read checksum").Base(err)
} }
@ -219,8 +278,18 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
expectedHash := binary.BigEndian.Uint32(buffer.BytesFrom(-4)) expectedHash := binary.BigEndian.Uint32(buffer.BytesFrom(-4))
if actualHash != expectedHash { if actualHash != expectedHash {
//It is possible that we are under attack described in https://github.com/v2ray/v2ray-core/issues/2523 if !s.isAEADRequest {
return nil, drainConnection(newError("invalid auth")) Autherr := newError("invalid auth, legacy userHash tainted")
burnErr := s.userValidator.BurnTaintFuse(fixedSizeAuthID[:])
if burnErr != nil {
Autherr = newError("invalid auth, can't taint legacy userHash").Base(burnErr)
}
//It is possible that we are under attack described in https://github.com/v2ray/v2ray-core/issues/2523
return nil, drainConnection(Autherr)
} else {
return nil, newError("invalid auth, but this is a AEAD request")
}
} }
if request.Address == nil { if request.Address == nil {
@ -299,18 +368,60 @@ func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade
// EncodeResponseHeader writes encoded response header into the given writer. // EncodeResponseHeader writes encoded response header into the given writer.
func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, writer io.Writer) { func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, writer io.Writer) {
s.responseBodyKey = md5.Sum(s.requestBodyKey[:]) var encryptionWriter io.Writer
s.responseBodyIV = md5.Sum(s.requestBodyIV[:]) if !s.isAEADRequest {
s.responseBodyKey = md5.Sum(s.requestBodyKey[:])
s.responseBodyIV = md5.Sum(s.requestBodyIV[:])
} else {
BodyKey := sha256.Sum256(s.requestBodyKey[:])
copy(s.responseBodyKey[:], BodyKey[:16])
BodyIV := sha256.Sum256(s.requestBodyKey[:])
copy(s.responseBodyIV[:], BodyIV[:16])
}
aesStream := crypto.NewAesEncryptionStream(s.responseBodyKey[:], s.responseBodyIV[:]) aesStream := crypto.NewAesEncryptionStream(s.responseBodyKey[:], s.responseBodyIV[:])
encryptionWriter := crypto.NewCryptionWriter(aesStream, writer) encryptionWriter = crypto.NewCryptionWriter(aesStream, writer)
s.responseWriter = encryptionWriter s.responseWriter = encryptionWriter
aeadBuffer := bytes.NewBuffer(nil)
if s.isAEADRequest {
encryptionWriter = aeadBuffer
}
common.Must2(encryptionWriter.Write([]byte{s.responseHeader, byte(header.Option)})) common.Must2(encryptionWriter.Write([]byte{s.responseHeader, byte(header.Option)}))
err := MarshalCommand(header.Command, encryptionWriter) err := MarshalCommand(header.Command, encryptionWriter)
if err != nil { if err != nil {
common.Must2(encryptionWriter.Write([]byte{0x00, 0x00})) common.Must2(encryptionWriter.Write([]byte{0x00, 0x00}))
} }
if s.isAEADRequest {
resph := vmessaead.KDF16(s.responseBodyKey[:], "AEAD Resp Header Len Key")
respi := vmessaead.KDF(s.responseBodyIV[:], "AEAD Resp Header Len IV")[:12]
aesblock := common.Must2(aes.NewCipher(resph)).(cipher.Block)
aeadHeader := common.Must2(cipher.NewGCM(aesblock)).(cipher.AEAD)
aeadlenBuf := bytes.NewBuffer(nil)
var aeadLen uint16
aeadLen = uint16(aeadBuffer.Len())
common.Must(binary.Write(aeadlenBuf, binary.BigEndian, aeadLen))
sealedLen := aeadHeader.Seal(nil, respi, aeadlenBuf.Bytes(), nil)
common.Must2(io.Copy(writer, bytes.NewReader(sealedLen)))
resphc := vmessaead.KDF16(s.responseBodyKey[:], "AEAD Resp Header Key")
respic := vmessaead.KDF(s.responseBodyIV[:], "AEAD Resp Header IV")[:12]
aesblockc := common.Must2(aes.NewCipher(resphc)).(cipher.Block)
aeadHeaderc := common.Must2(cipher.NewGCM(aesblockc)).(cipher.AEAD)
sealed := aeadHeaderc.Seal(nil, respic, aeadBuffer.Bytes(), nil)
common.Must2(io.Copy(writer, bytes.NewReader(sealed)))
}
} }
// EncodeResponseBody returns a Writer that auto-encrypt content written by caller. // EncodeResponseBody returns a Writer that auto-encrypt content written by caller.

View File

@ -8,6 +8,7 @@ import (
"sync" "sync"
"time" "time"
"v2ray.com/core/common/dice" "v2ray.com/core/common/dice"
"v2ray.com/core/proxy/vmess/aead"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
@ -28,27 +29,33 @@ type user struct {
// TimedUserValidator is a user Validator based on time. // TimedUserValidator is a user Validator based on time.
type TimedUserValidator struct { type TimedUserValidator struct {
sync.RWMutex sync.RWMutex
users []*user users []*user
userHash map[[16]byte]indexTimePair userHash map[[16]byte]indexTimePair
hasher protocol.IDHash hasher protocol.IDHash
baseTime protocol.Timestamp baseTime protocol.Timestamp
task *task.Periodic task *task.Periodic
behaviorSeed uint64 behaviorSeed uint64
behaviorFused bool behaviorFused bool
aeadDecoderHolder *aead.AuthIDDecoderHolder
} }
type indexTimePair struct { type indexTimePair struct {
user *user user *user
timeInc uint32 timeInc uint32
taintedFuse *bool
} }
// NewTimedUserValidator creates a new TimedUserValidator. // NewTimedUserValidator creates a new TimedUserValidator.
func NewTimedUserValidator(hasher protocol.IDHash) *TimedUserValidator { func NewTimedUserValidator(hasher protocol.IDHash) *TimedUserValidator {
tuv := &TimedUserValidator{ tuv := &TimedUserValidator{
users: make([]*user, 0, 16), users: make([]*user, 0, 16),
userHash: make(map[[16]byte]indexTimePair, 1024), userHash: make(map[[16]byte]indexTimePair, 1024),
hasher: hasher, hasher: hasher,
baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*2), baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*2),
aeadDecoderHolder: aead.NewAuthIDDecoderHolder(),
} }
tuv.task = &task.Periodic{ tuv.task = &task.Periodic{
Interval: updateInterval, Interval: updateInterval,
@ -76,8 +83,9 @@ func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *
idHash.Reset() idHash.Reset()
v.userHash[hashValue] = indexTimePair{ v.userHash[hashValue] = indexTimePair{
user: user, user: user,
timeInc: uint32(ts - v.baseTime), timeInc: uint32(ts - v.baseTime),
taintedFuse: new(bool),
} }
} }
} }
@ -128,15 +136,19 @@ func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error {
v.users = append(v.users, uu) v.users = append(v.users, uu)
v.generateNewHashes(protocol.Timestamp(nowSec), uu) v.generateNewHashes(protocol.Timestamp(nowSec), uu)
account := uu.user.Account.(*MemoryAccount)
if v.behaviorFused == false { if v.behaviorFused == false {
account := uu.user.Account.(*MemoryAccount)
v.behaviorSeed = crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), account.ID.Bytes()) v.behaviorSeed = crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), account.ID.Bytes())
} }
var cmdkeyfl [16]byte
copy(cmdkeyfl[:], account.ID.CmdKey())
v.aeadDecoderHolder.AddUser(cmdkeyfl, u)
return nil return nil
} }
func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool) { func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool, error) {
defer v.RUnlock() defer v.RUnlock()
v.RLock() v.RLock()
@ -148,9 +160,25 @@ func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protoco
if found { if found {
var user protocol.MemoryUser var user protocol.MemoryUser
user = pair.user.user user = pair.user.user
return &user, protocol.Timestamp(pair.timeInc) + v.baseTime, true if *pair.taintedFuse == false {
return &user, protocol.Timestamp(pair.timeInc) + v.baseTime, true, nil
}
return nil, 0, false, ErrTainted
} }
return nil, 0, false return nil, 0, false, ErrNotFound
}
func (v *TimedUserValidator) GetAEAD(userHash []byte) (*protocol.MemoryUser, bool) {
defer v.RUnlock()
v.RLock()
var userHashFL [16]byte
copy(userHashFL[:], userHash)
userd, err := v.aeadDecoderHolder.Match(userHashFL)
if err != nil {
return nil, false
}
return userd.(*protocol.MemoryUser), true
} }
func (v *TimedUserValidator) Remove(email string) bool { func (v *TimedUserValidator) Remove(email string) bool {
@ -162,6 +190,9 @@ func (v *TimedUserValidator) Remove(email string) bool {
for i, u := range v.users { for i, u := range v.users {
if strings.EqualFold(u.user.Email, email) { if strings.EqualFold(u.user.Email, email) {
idx = i idx = i
var cmdkeyfl [16]byte
copy(cmdkeyfl[:], u.user.Account.(*MemoryAccount).ID.CmdKey())
v.aeadDecoderHolder.RemoveUser(cmdkeyfl)
break break
} }
} }
@ -191,3 +222,21 @@ func (v *TimedUserValidator) GetBehaviorSeed() uint64 {
} }
return v.behaviorSeed return v.behaviorSeed
} }
func (v *TimedUserValidator) BurnTaintFuse(userHash []byte) error {
v.Lock()
defer v.Unlock()
var userHashFL [16]byte
copy(userHashFL[:], userHash)
pair, found := v.userHash[userHashFL]
if found {
*pair.taintedFuse = true
return nil
}
return ErrNotFound
}
var ErrNotFound = newError("Not Found")
var ErrTainted = newError("ErrTainted")

View File

@ -39,7 +39,7 @@ func TestUserValidator(t *testing.T) {
common.Must2(serial.WriteUint64(idHash, uint64(ts))) common.Must2(serial.WriteUint64(idHash, uint64(ts)))
userHash := idHash.Sum(nil) userHash := idHash.Sum(nil)
euser, ets, found := v.Get(userHash) euser, ets, found, _ := v.Get(userHash)
if !found { if !found {
t.Fatal("user not found") t.Fatal("user not found")
} }
@ -67,7 +67,7 @@ func TestUserValidator(t *testing.T) {
common.Must2(serial.WriteUint64(idHash, uint64(ts))) common.Must2(serial.WriteUint64(idHash, uint64(ts)))
userHash := idHash.Sum(nil) userHash := idHash.Sum(nil)
euser, _, found := v.Get(userHash) euser, _, found, _ := v.Get(userHash)
if found || euser != nil { if found || euser != nil {
t.Error("unexpected user") t.Error("unexpected user")
} }