diff --git a/common/antireplay/antireplay.go b/common/antireplay/antireplay.go new file mode 100644 index 000000000..bdf22f266 --- /dev/null +++ b/common/antireplay/antireplay.go @@ -0,0 +1,50 @@ +package antireplay + +import ( + cuckoo "github.com/seiflotfy/cuckoofilter" + "sync" + "time" +) + +func NewAntiReplayWindow(AntiReplayTime int64) *AntiReplayWindow { + arw := &AntiReplayWindow{} + arw.AntiReplayTime = AntiReplayTime + return arw +} + +type AntiReplayWindow struct { + lock sync.Mutex + poolA *cuckoo.Filter + poolB *cuckoo.Filter + lastSwapTime int64 + PoolSwap bool + AntiReplayTime int64 +} + +func (aw *AntiReplayWindow) Check(sum []byte) bool { + aw.lock.Lock() + + if aw.lastSwapTime == 0 { + aw.lastSwapTime = time.Now().Unix() + aw.poolA = cuckoo.NewFilter(100000) + aw.poolB = cuckoo.NewFilter(100000) + } + + tnow := time.Now().Unix() + timediff := tnow - aw.lastSwapTime + + if timediff >= aw.AntiReplayTime { + if aw.PoolSwap { + aw.PoolSwap = false + aw.poolA.Reset() + } else { + aw.PoolSwap = true + aw.poolB.Reset() + } + aw.lastSwapTime = tnow + } + + ret := aw.poolA.InsertUnique(sum) && aw.poolB.InsertUnique(sum) + aw.lock.Unlock() + return ret +} diff --git a/common/protocol/headers.go b/common/protocol/headers.go index 007ff5aa4..83a5fdd6d 100644 --- a/common/protocol/headers.go +++ b/common/protocol/headers.go @@ -38,8 +38,6 @@ const ( RequestOptionChunkMasking bitmask.Byte = 0x04 RequestOptionGlobalPadding bitmask.Byte = 0x08 - - RequestOptionEarlyChecksum bitmask.Byte = 0x16 ) type RequestHeader struct { diff --git a/go.mod b/go.mod index 861c55e62..c9b1737c9 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,14 @@ module v2ray.com/core require ( + github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc // indirect github.com/golang/mock v1.2.0 github.com/golang/protobuf v1.3.2 github.com/google/go-cmp v0.2.0 github.com/gorilla/websocket v1.4.1 github.com/miekg/dns v1.1.4 github.com/refraction-networking/utls v0.0.0-20190909200633-43c36d3c1f57 + github.com/seiflotfy/cuckoofilter v0.0.0-20200511222245-56093a4d3841 go.starlark.net v0.0.0-20190919145610-979af19b165c golang.org/x/crypto v0.0.0-20191029031824-8986dd9e96cf golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 diff --git a/go.sum b/go.sum index ee9c25bda..35cc556fd 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc h1:8WFBn63wegobsYAX0YjD+8suexZDga5CctH4CCTx2+8= +github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= @@ -16,6 +18,8 @@ github.com/miekg/dns v1.1.4 h1:rCMZsU2ScVSYcAsOXgmC6+AKOK+6pmQTOcw03nfwYV0= github.com/miekg/dns v1.1.4/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/refraction-networking/utls v0.0.0-20190909200633-43c36d3c1f57 h1:SL1K0QAuC1b54KoY1pjPWe6kSlsFHwK9/oC960fKrTY= github.com/refraction-networking/utls v0.0.0-20190909200633-43c36d3c1f57/go.mod h1:tz9gX959MEFfFN5whTIocCLUG57WiILqtdVxI8c6Wj0= +github.com/seiflotfy/cuckoofilter v0.0.0-20200511222245-56093a4d3841 h1:pnfutQFsV7ySmHUeX6ANGfPsBo29RctUvDn8G3rmJVw= +github.com/seiflotfy/cuckoofilter v0.0.0-20200511222245-56093a4d3841/go.mod h1:ET5mVvNjwaGXRgZxO9UZr7X+8eAf87AfIYNwRSp9s4Y= go.starlark.net v0.0.0-20190919145610-979af19b165c h1:WR7X1xgXJlXhQBdorVc9Db3RhwG+J/kp6bLuMyJjfVw= go.starlark.net v0.0.0-20190919145610-979af19b165c/go.mod h1:c1/X6cHgvdXj6pUlmWKMkuqRnW4K8x2vwt6JAaaircg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= diff --git a/proxy/vmess/aead/authid.go b/proxy/vmess/aead/authid.go new file mode 100644 index 000000000..d4b3f446e --- /dev/null +++ b/proxy/vmess/aead/authid.go @@ -0,0 +1,114 @@ +package aead + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + rand3 "crypto/rand" + "encoding/binary" + "errors" + "hash/crc32" + "io" + "math" + "time" + "v2ray.com/core/common" + antiReplayWindow "v2ray.com/core/common/antireplay" +) + +func CreateAuthID(cmdKey []byte, time int64) [16]byte { + buf := bytes.NewBuffer(nil) + common.Must(binary.Write(buf, binary.BigEndian, time)) + var zero uint32 + common.Must2(io.CopyN(buf, rand3.Reader, 4)) + zero = crc32.ChecksumIEEE(buf.Bytes()) + common.Must(binary.Write(buf, binary.BigEndian, zero)) + aesBlock := NewCipherFromKey(cmdKey) + if buf.Len() != 16 { + panic("Size unexpected") + } + var result [16]byte + aesBlock.Encrypt(result[:], buf.Bytes()) + return result +} + +func NewCipherFromKey(cmdKey []byte) cipher.Block { + aesBlock, err := aes.NewCipher(KDF16(cmdKey, "AES Auth ID Encryption")) + if err != nil { + panic(err) + } + return aesBlock +} + +type AuthIDDecoder struct { + s cipher.Block +} + +func NewAuthIDDecoder(cmdKey []byte) *AuthIDDecoder { + return &AuthIDDecoder{NewCipherFromKey(cmdKey)} +} + +func (aidd *AuthIDDecoder) Decode(data [16]byte) (int64, uint32, int32, []byte) { + aidd.s.Decrypt(data[:], data[:]) + var t int64 + var zero uint32 + var rand int32 + reader := bytes.NewReader(data[:]) + common.Must(binary.Read(reader, binary.BigEndian, &t)) + common.Must(binary.Read(reader, binary.BigEndian, &rand)) + common.Must(binary.Read(reader, binary.BigEndian, &zero)) + return t, zero, rand, data[:] +} + +func NewAuthIDDecoderHolder() *AuthIDDecoderHolder { + return &AuthIDDecoderHolder{make(map[string]*AuthIDDecoderItem), antiReplayWindow.NewAntiReplayWindow(120)} +} + +type AuthIDDecoderHolder struct { + aidhi map[string]*AuthIDDecoderItem + apw *antiReplayWindow.AntiReplayWindow +} + +type AuthIDDecoderItem struct { + dec *AuthIDDecoder + ticket interface{} +} + +func NewAuthIDDecoderItem(key [16]byte, ticket interface{}) *AuthIDDecoderItem { + return &AuthIDDecoderItem{ + dec: NewAuthIDDecoder(key[:]), + ticket: ticket, + } +} + +func (a *AuthIDDecoderHolder) AddUser(key [16]byte, ticket interface{}) { + a.aidhi[string(key[:])] = NewAuthIDDecoderItem(key, ticket) +} + +func (a *AuthIDDecoderHolder) RemoveUser(key [16]byte) { + delete(a.aidhi, string(key[:])) +} + +func (a *AuthIDDecoderHolder) Match(AuthID [16]byte) (interface{}, error) { + if !a.apw.Check(AuthID[:]) { + return nil, errReplay + } + for _, v := range a.aidhi { + + t, z, r, d := v.dec.Decode(AuthID) + if z != crc32.ChecksumIEEE(d[:12]) { + continue + } + if math.Abs(float64(t-time.Now().Unix())) > 120 { + continue + } + _ = r + + return v.ticket, nil + + } + return nil, errNotFound +} + +var errNotFound = errors.New("user do not exist") + +var errReplay = errors.New("replayed request") diff --git a/proxy/vmess/aead/encrypt.go b/proxy/vmess/aead/encrypt.go new file mode 100644 index 000000000..86ab4bc3e --- /dev/null +++ b/proxy/vmess/aead/encrypt.go @@ -0,0 +1,141 @@ +package aead + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "encoding/binary" + "errors" + "io" + "time" + "v2ray.com/core/common" +) + +func SealVMessAEADHeader(key [16]byte, data []byte) []byte { + authid := CreateAuthID(key[:], time.Now().Unix()) + + nonce := make([]byte, 8) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + panic(err.Error()) + } + + lengthbuf := bytes.NewBuffer(nil) + + var HeaderDataLen uint16 + HeaderDataLen = uint16(len(data)) + + common.Must(binary.Write(lengthbuf, binary.BigEndian, HeaderDataLen)) + + authidCheck := KDF16(key[:], "VMess AuthID Check Value", string(authid[:]), string(lengthbuf.Bytes()), string(nonce)) + + lengthbufb := lengthbuf.Bytes() + + LengthMask := KDF16(key[:], "VMess AuthID Mask Value", string(authid[:]), string(nonce[:]))[:2] + + lengthbufb[0] = lengthbufb[0] ^ LengthMask[0] + lengthbufb[1] = lengthbufb[1] ^ LengthMask[1] + + HeaderAEADKey := KDF16(key[:], "VMess Header AEAD Key", string(authid[:]), string(nonce)) + + HeaderAEADNonce := KDF(key[:], "VMess Header AEAD Nonce", string(authid[:]), string(nonce))[:12] + + block, err := aes.NewCipher(HeaderAEADKey) + if err != nil { + panic(err.Error()) + } + + headerAEAD, err := cipher.NewGCM(block) + + if err != nil { + panic(err.Error()) + } + + headerSealed := headerAEAD.Seal(nil, HeaderAEADNonce, data, authid[:]) + + var outPutBuf = bytes.NewBuffer(nil) + + common.Must2(outPutBuf.Write(authid[:])) //16 + + common.Must2(outPutBuf.Write(authidCheck)) //16 + + common.Must2(outPutBuf.Write(lengthbufb)) //2 + + common.Must2(outPutBuf.Write(nonce)) //8 + + common.Must2(outPutBuf.Write(headerSealed)) + + return outPutBuf.Bytes() +} + +func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte, bool, error, int) { + var authidCheck [16]byte + var lengthbufb [2]byte + var nonce [8]byte + + n, err := io.ReadFull(data, authidCheck[:]) + if err != nil { + return nil, false, err, n + } + + n2, err := io.ReadFull(data, lengthbufb[:]) + if err != nil { + return nil, false, err, n + n2 + } + + n4, err := io.ReadFull(data, nonce[:]) + if err != nil { + return nil, false, err, n + n2 + n4 + } + + //Unmask Length + + LengthMask := KDF16(key[:], "VMess AuthID Mask Value", string(authid[:]), string(nonce[:]))[:2] + + lengthbufb[0] = lengthbufb[0] ^ LengthMask[0] + lengthbufb[1] = lengthbufb[1] ^ LengthMask[1] + + authidCheckV := KDF16(key[:], "VMess AuthID Check Value", string(authid[:]), string(lengthbufb[:]), string(nonce[:])) + + if !hmac.Equal(authidCheckV, authidCheck[:]) { + return nil, true, errCheckMismatch, n + n2 + n4 + } + + var length uint16 + + common.Must(binary.Read(bytes.NewReader(lengthbufb[:]), binary.BigEndian, &length)) + + HeaderAEADKey := KDF16(key[:], "VMess Header AEAD Key", string(authid[:]), string(nonce[:])) + + HeaderAEADNonce := KDF(key[:], "VMess Header AEAD Nonce", string(authid[:]), string(nonce[:]))[:12] + + //16 == AEAD Tag size + header := make([]byte, length+16) + + n3, err := io.ReadFull(data, header) + if err != nil { + return nil, false, err, n + n2 + n3 + n4 + } + + block, err := aes.NewCipher(HeaderAEADKey) + if err != nil { + panic(err.Error()) + } + + headerAEAD, err := cipher.NewGCM(block) + + if err != nil { + panic(err.Error()) + } + + out, erropenAEAD := headerAEAD.Open(nil, HeaderAEADNonce, header, authid[:]) + + if erropenAEAD != nil { + return nil, true, erropenAEAD, n + n2 + n3 + n4 + } + + return out, false, nil, n + n2 + n3 + n4 +} + +var errCheckMismatch = errors.New("check verify failed") diff --git a/proxy/vmess/aead/kdf.go b/proxy/vmess/aead/kdf.go new file mode 100644 index 000000000..0872db9de --- /dev/null +++ b/proxy/vmess/aead/kdf.go @@ -0,0 +1,26 @@ +package aead + +import ( + "crypto/hmac" + "crypto/sha256" + "hash" +) + +func KDF(key []byte, path ...string) []byte { + hmacf := hmac.New(func() hash.Hash { + return sha256.New() + }, []byte("VMess AEAD KDF")) + + for _, v := range path { + hmacf = hmac.New(func() hash.Hash { + return hmacf + }, []byte(v)) + } + hmacf.Write(key) + return hmacf.Sum(nil) +} + +func KDF16(key []byte, path ...string) []byte { + r := KDF(key, path...) + return r[:16] +} diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index 8773e6089..4d7e17765 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -1,12 +1,19 @@ package encoding import ( + "bytes" + "crypto/aes" + "crypto/cipher" "crypto/md5" "crypto/rand" + "crypto/sha256" "encoding/binary" + "fmt" "hash" "hash/fnv" "io" + "os" + vmessaead "v2ray.com/core/proxy/vmess/aead" "golang.org/x/crypto/chacha20poly1305" @@ -37,6 +44,8 @@ type ClientSession struct { responseBodyIV [16]byte responseReader io.Reader responseHeader byte + + isAEADRequest bool } // NewClientSession creates a new ClientSession. @@ -45,11 +54,29 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession { common.Must2(rand.Read(randomBytes)) session := &ClientSession{} + + session.isAEADRequest = false + + if vmessexp, vmessexp_found := os.LookupEnv("VMESSAEADEXPERIMENT"); vmessexp_found { + if vmessexp == "y" { + session.isAEADRequest = true + fmt.Println("=======VMESSAEADEXPERIMENT ENABLED========") + } + } + copy(session.requestBodyKey[:], randomBytes[:16]) copy(session.requestBodyIV[:], randomBytes[16:32]) session.responseHeader = randomBytes[32] - session.responseBodyKey = md5.Sum(session.requestBodyKey[:]) - session.responseBodyIV = md5.Sum(session.requestBodyIV[:]) + if !session.isAEADRequest { + session.responseBodyKey = md5.Sum(session.requestBodyKey[:]) + session.responseBodyIV = md5.Sum(session.requestBodyIV[:]) + } else { + BodyKey := sha256.Sum256(session.requestBodyKey[:]) + copy(session.responseBodyKey[:], BodyKey[:16]) + BodyIV := sha256.Sum256(session.requestBodyKey[:]) + copy(session.responseBodyIV[:], BodyIV[:16]) + } + session.idHash = idHash return session @@ -58,9 +85,11 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession { func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error { timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)() account := header.User.Account.(*vmess.MemoryAccount) - idHash := c.idHash(account.AnyValidID().Bytes()) - common.Must2(serial.WriteUint64(idHash, uint64(timestamp))) - common.Must2(writer.Write(idHash.Sum(nil))) + if !c.isAEADRequest { + idHash := c.idHash(account.AnyValidID().Bytes()) + common.Must2(serial.WriteUint64(idHash, uint64(timestamp))) + common.Must2(writer.Write(idHash.Sum(nil))) + } buffer := buf.New() defer buffer.Release() @@ -92,10 +121,18 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ fnv1a.Sum(hashBytes[:0]) } - iv := hashTimestamp(md5.New(), timestamp) - aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv[:]) - aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) - common.Must2(writer.Write(buffer.Bytes())) + if !c.isAEADRequest { + iv := hashTimestamp(md5.New(), timestamp) + aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv[:]) + aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) + common.Must2(writer.Write(buffer.Bytes())) + } else { + var fixedLengthCmdKey [16]byte + copy(fixedLengthCmdKey[:], account.ID.CmdKey()) + vmessout := vmessaead.SealVMessAEADHeader(fixedLengthCmdKey, buffer.Bytes()) + common.Must2(io.Copy(writer, bytes.NewReader(vmessout))) + } + return nil } @@ -161,8 +198,49 @@ func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write } func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) { - aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:]) - c.responseReader = crypto.NewCryptionReader(aesStream, reader) + if !c.isAEADRequest { + aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:]) + c.responseReader = crypto.NewCryptionReader(aesStream, reader) + } else { + resph := vmessaead.KDF16(c.responseBodyKey[:], "AEAD Resp Header Len Key") + respi := vmessaead.KDF(c.responseBodyIV[:], "AEAD Resp Header Len IV")[:12] + + aesblock := common.Must2(aes.NewCipher(resph)).(cipher.Block) + aeadHeader := common.Must2(cipher.NewGCM(aesblock)).(cipher.AEAD) + + var AEADLen [18]byte + var lenresp int + + var lenrespr uint16 + + if _, err := io.ReadFull(reader, AEADLen[:]); err != nil { + return nil, newError("Unable to Read Header Len").Base(err) + } + if AEADLend, err := aeadHeader.Open(nil, respi, AEADLen[:], nil); err != nil { + return nil, newError("Failed To Decrypt Length").Base(err) + } else { + common.Must(binary.Read(bytes.NewReader(AEADLend), binary.BigEndian, &lenrespr)) + lenresp = int(lenrespr) + } + + resphc := vmessaead.KDF16(c.responseBodyKey[:], "AEAD Resp Header Key") + respic := vmessaead.KDF(c.responseBodyIV[:], "AEAD Resp Header IV")[:12] + + aesblockc := common.Must2(aes.NewCipher(resphc)).(cipher.Block) + aeadHeaderc := common.Must2(cipher.NewGCM(aesblockc)).(cipher.AEAD) + + respPayload := make([]byte, lenresp+16) + + if _, err := io.ReadFull(reader, respPayload); err != nil { + return nil, newError("Unable to Read Header Data").Base(err) + } + + if AEADData, err := aeadHeaderc.Open(nil, respic, respPayload, nil); err != nil { + return nil, newError("Failed To Decrypt Payload").Base(err) + } else { + c.responseReader = bytes.NewReader(AEADData) + } + } buffer := buf.StackNew() defer buffer.Release() @@ -192,7 +270,10 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon header.Command = command } } - + if c.isAEADRequest { + aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:]) + c.responseReader = crypto.NewCryptionReader(aesStream, reader) + } return header, nil } diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 6e6a5e4bf..2a6e1dea5 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -1,7 +1,11 @@ package encoding import ( + "bytes" + "crypto/aes" + "crypto/cipher" "crypto/md5" + "crypto/sha256" "encoding/binary" "hash/fnv" "io" @@ -9,6 +13,7 @@ import ( "sync" "time" "v2ray.com/core/common/dice" + vmessaead "v2ray.com/core/proxy/vmess/aead" "golang.org/x/crypto/chacha20poly1305" @@ -99,6 +104,10 @@ type ServerSession struct { responseBodyIV [16]byte responseWriter io.Writer responseHeader byte + + isAEADRequest bool + + isAEADForced bool } // NewServerSession creates a new ServerSession, using the given UserValidator. @@ -153,17 +162,44 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request return nil, newError("failed to read request header").Base(err) } - user, timestamp, valid := s.userValidator.Get(buffer.Bytes()) - if !valid { + var decryptor io.Reader + var vmessAccount *vmess.MemoryAccount + + user, foundAEAD := s.userValidator.GetAEAD(buffer.Bytes()) + + var fixedSizeAuthID [16]byte + copy(fixedSizeAuthID[:], buffer.Bytes()) + + if foundAEAD == true { + vmessAccount = user.Account.(*vmess.MemoryAccount) + var fixedSizeCmdKey [16]byte + copy(fixedSizeCmdKey[:], vmessAccount.ID.CmdKey()) + aeadData, shouldDrain, errorReason, bytesRead := vmessaead.OpenVMessAEADHeader(fixedSizeCmdKey, fixedSizeAuthID, reader) + if errorReason != nil { + if shouldDrain { + readSizeRemain -= bytesRead + return nil, drainConnection(newError("AEAD read failed").Base(errorReason)) + } else { + return nil, drainConnection(newError("AEAD read failed, drain skiped").Base(errorReason)) + } + } + decryptor = bytes.NewReader(aeadData) + s.isAEADRequest = true + } else if !s.isAEADForced { + userLegacy, timestamp, valid, userValidationError := s.userValidator.Get(buffer.Bytes()) + if !valid || userValidationError != nil { + return nil, drainConnection(newError("invalid user").Base(userValidationError)) + } + user = userLegacy + iv := hashTimestamp(md5.New(), timestamp) + vmessAccount = userLegacy.Account.(*vmess.MemoryAccount) + + aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv[:]) + decryptor = crypto.NewCryptionReader(aesStream, reader) + } else { return nil, drainConnection(newError("invalid user")) } - iv := hashTimestamp(md5.New(), timestamp) - vmessAccount := user.Account.(*vmess.MemoryAccount) - - aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv[:]) - decryptor := crypto.NewCryptionReader(aesStream, reader) - readSizeRemain -= int(buffer.Len()) buffer.Clear() if _, err := buffer.ReadFullFrom(decryptor, 38); err != nil { @@ -182,7 +218,16 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request sid.key = s.requestBodyKey sid.nonce = s.requestBodyIV if !s.sessionHistory.addIfNotExits(sid) { - return nil, drainConnection(newError("duplicated session id, possibly under replay attack")) + if !s.isAEADRequest { + drainErr := s.userValidator.BurnTaintFuse(fixedSizeAuthID[:]) + if drainErr != nil { + return nil, drainConnection(newError("duplicated session id, possibly under replay attack, and failed to taint userHash").Base(drainErr)) + } + return nil, drainConnection(newError("duplicated session id, possibly under replay attack, userHash tainted")) + } else { + return nil, newError("duplicated session id, possibly under replay attack, but this is a AEAD request") + } + } s.responseHeader = buffer.Byte(33) // 1 byte @@ -205,11 +250,25 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request if padingLen > 0 { if _, err := buffer.ReadFullFrom(decryptor, int32(padingLen)); err != nil { + if !s.isAEADRequest { + burnErr := s.userValidator.BurnTaintFuse(fixedSizeAuthID[:]) + if burnErr != nil { + return nil, newError("failed to read padding, failed to taint userHash").Base(burnErr).Base(err) + } + return nil, newError("failed to read padding, userHash tainted").Base(err) + } return nil, newError("failed to read padding").Base(err) } } if _, err := buffer.ReadFullFrom(decryptor, 4); err != nil { + if !s.isAEADRequest { + burnErr := s.userValidator.BurnTaintFuse(fixedSizeAuthID[:]) + if burnErr != nil { + return nil, newError("failed to read checksum, failed to taint userHash").Base(burnErr).Base(err) + } + return nil, newError("failed to read checksum, userHash tainted").Base(err) + } return nil, newError("failed to read checksum").Base(err) } @@ -219,8 +278,18 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request expectedHash := binary.BigEndian.Uint32(buffer.BytesFrom(-4)) if actualHash != expectedHash { - //It is possible that we are under attack described in https://github.com/v2ray/v2ray-core/issues/2523 - return nil, drainConnection(newError("invalid auth")) + if !s.isAEADRequest { + Autherr := newError("invalid auth, legacy userHash tainted") + burnErr := s.userValidator.BurnTaintFuse(fixedSizeAuthID[:]) + if burnErr != nil { + Autherr = newError("invalid auth, can't taint legacy userHash").Base(burnErr) + } + //It is possible that we are under attack described in https://github.com/v2ray/v2ray-core/issues/2523 + return nil, drainConnection(Autherr) + } else { + return nil, newError("invalid auth, but this is a AEAD request") + } + } if request.Address == nil { @@ -299,18 +368,60 @@ func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade // EncodeResponseHeader writes encoded response header into the given writer. func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, writer io.Writer) { - s.responseBodyKey = md5.Sum(s.requestBodyKey[:]) - s.responseBodyIV = md5.Sum(s.requestBodyIV[:]) + var encryptionWriter io.Writer + if !s.isAEADRequest { + s.responseBodyKey = md5.Sum(s.requestBodyKey[:]) + s.responseBodyIV = md5.Sum(s.requestBodyIV[:]) + } else { + BodyKey := sha256.Sum256(s.requestBodyKey[:]) + copy(s.responseBodyKey[:], BodyKey[:16]) + BodyIV := sha256.Sum256(s.requestBodyKey[:]) + copy(s.responseBodyIV[:], BodyIV[:16]) + } aesStream := crypto.NewAesEncryptionStream(s.responseBodyKey[:], s.responseBodyIV[:]) - encryptionWriter := crypto.NewCryptionWriter(aesStream, writer) + encryptionWriter = crypto.NewCryptionWriter(aesStream, writer) s.responseWriter = encryptionWriter + aeadBuffer := bytes.NewBuffer(nil) + + if s.isAEADRequest { + encryptionWriter = aeadBuffer + } + common.Must2(encryptionWriter.Write([]byte{s.responseHeader, byte(header.Option)})) err := MarshalCommand(header.Command, encryptionWriter) if err != nil { common.Must2(encryptionWriter.Write([]byte{0x00, 0x00})) } + + if s.isAEADRequest { + + resph := vmessaead.KDF16(s.responseBodyKey[:], "AEAD Resp Header Len Key") + respi := vmessaead.KDF(s.responseBodyIV[:], "AEAD Resp Header Len IV")[:12] + + aesblock := common.Must2(aes.NewCipher(resph)).(cipher.Block) + aeadHeader := common.Must2(cipher.NewGCM(aesblock)).(cipher.AEAD) + + aeadlenBuf := bytes.NewBuffer(nil) + + var aeadLen uint16 + aeadLen = uint16(aeadBuffer.Len()) + + common.Must(binary.Write(aeadlenBuf, binary.BigEndian, aeadLen)) + + sealedLen := aeadHeader.Seal(nil, respi, aeadlenBuf.Bytes(), nil) + common.Must2(io.Copy(writer, bytes.NewReader(sealedLen))) + + resphc := vmessaead.KDF16(s.responseBodyKey[:], "AEAD Resp Header Key") + respic := vmessaead.KDF(s.responseBodyIV[:], "AEAD Resp Header IV")[:12] + + aesblockc := common.Must2(aes.NewCipher(resphc)).(cipher.Block) + aeadHeaderc := common.Must2(cipher.NewGCM(aesblockc)).(cipher.AEAD) + + sealed := aeadHeaderc.Seal(nil, respic, aeadBuffer.Bytes(), nil) + common.Must2(io.Copy(writer, bytes.NewReader(sealed))) + } } // EncodeResponseBody returns a Writer that auto-encrypt content written by caller. diff --git a/proxy/vmess/validator.go b/proxy/vmess/validator.go index 1682300f9..d9956608d 100644 --- a/proxy/vmess/validator.go +++ b/proxy/vmess/validator.go @@ -8,6 +8,7 @@ import ( "sync" "time" "v2ray.com/core/common/dice" + "v2ray.com/core/proxy/vmess/aead" "v2ray.com/core/common" "v2ray.com/core/common/protocol" @@ -28,27 +29,33 @@ type user struct { // TimedUserValidator is a user Validator based on time. type TimedUserValidator struct { sync.RWMutex - users []*user - userHash map[[16]byte]indexTimePair - hasher protocol.IDHash - baseTime protocol.Timestamp - task *task.Periodic + users []*user + userHash map[[16]byte]indexTimePair + hasher protocol.IDHash + baseTime protocol.Timestamp + task *task.Periodic + behaviorSeed uint64 behaviorFused bool + + aeadDecoderHolder *aead.AuthIDDecoderHolder } type indexTimePair struct { user *user timeInc uint32 + + taintedFuse *bool } // NewTimedUserValidator creates a new TimedUserValidator. func NewTimedUserValidator(hasher protocol.IDHash) *TimedUserValidator { tuv := &TimedUserValidator{ - users: make([]*user, 0, 16), - userHash: make(map[[16]byte]indexTimePair, 1024), - hasher: hasher, - baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*2), + users: make([]*user, 0, 16), + userHash: make(map[[16]byte]indexTimePair, 1024), + hasher: hasher, + baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*2), + aeadDecoderHolder: aead.NewAuthIDDecoderHolder(), } tuv.task = &task.Periodic{ Interval: updateInterval, @@ -76,8 +83,9 @@ func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user * idHash.Reset() v.userHash[hashValue] = indexTimePair{ - user: user, - timeInc: uint32(ts - v.baseTime), + user: user, + timeInc: uint32(ts - v.baseTime), + taintedFuse: new(bool), } } } @@ -128,15 +136,19 @@ func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error { v.users = append(v.users, uu) v.generateNewHashes(protocol.Timestamp(nowSec), uu) + account := uu.user.Account.(*MemoryAccount) if v.behaviorFused == false { - account := uu.user.Account.(*MemoryAccount) v.behaviorSeed = crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), account.ID.Bytes()) } + var cmdkeyfl [16]byte + copy(cmdkeyfl[:], account.ID.CmdKey()) + v.aeadDecoderHolder.AddUser(cmdkeyfl, u) + return nil } -func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool) { +func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool, error) { defer v.RUnlock() v.RLock() @@ -148,9 +160,25 @@ func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protoco if found { var user protocol.MemoryUser user = pair.user.user - return &user, protocol.Timestamp(pair.timeInc) + v.baseTime, true + if *pair.taintedFuse == false { + return &user, protocol.Timestamp(pair.timeInc) + v.baseTime, true, nil + } + return nil, 0, false, ErrTainted } - return nil, 0, false + return nil, 0, false, ErrNotFound +} + +func (v *TimedUserValidator) GetAEAD(userHash []byte) (*protocol.MemoryUser, bool) { + defer v.RUnlock() + v.RLock() + var userHashFL [16]byte + copy(userHashFL[:], userHash) + + userd, err := v.aeadDecoderHolder.Match(userHashFL) + if err != nil { + return nil, false + } + return userd.(*protocol.MemoryUser), true } func (v *TimedUserValidator) Remove(email string) bool { @@ -162,6 +190,9 @@ func (v *TimedUserValidator) Remove(email string) bool { for i, u := range v.users { if strings.EqualFold(u.user.Email, email) { idx = i + var cmdkeyfl [16]byte + copy(cmdkeyfl[:], u.user.Account.(*MemoryAccount).ID.CmdKey()) + v.aeadDecoderHolder.RemoveUser(cmdkeyfl) break } } @@ -191,3 +222,21 @@ func (v *TimedUserValidator) GetBehaviorSeed() uint64 { } return v.behaviorSeed } + +func (v *TimedUserValidator) BurnTaintFuse(userHash []byte) error { + v.Lock() + defer v.Unlock() + var userHashFL [16]byte + copy(userHashFL[:], userHash) + + pair, found := v.userHash[userHashFL] + if found { + *pair.taintedFuse = true + return nil + } + return ErrNotFound +} + +var ErrNotFound = newError("Not Found") + +var ErrTainted = newError("ErrTainted") diff --git a/proxy/vmess/validator_test.go b/proxy/vmess/validator_test.go index 1b5160243..25c1cf6ab 100644 --- a/proxy/vmess/validator_test.go +++ b/proxy/vmess/validator_test.go @@ -39,7 +39,7 @@ func TestUserValidator(t *testing.T) { common.Must2(serial.WriteUint64(idHash, uint64(ts))) userHash := idHash.Sum(nil) - euser, ets, found := v.Get(userHash) + euser, ets, found, _ := v.Get(userHash) if !found { t.Fatal("user not found") } @@ -67,7 +67,7 @@ func TestUserValidator(t *testing.T) { common.Must2(serial.WriteUint64(idHash, uint64(ts))) userHash := idHash.Sum(nil) - euser, _, found := v.Get(userHash) + euser, _, found, _ := v.Get(userHash) if found || euser != nil { t.Error("unexpected user") }