diff --git a/common/antireplay/antireplay.go b/common/antireplay/antireplay.go new file mode 100644 index 000000000..4d1a93196 --- /dev/null +++ b/common/antireplay/antireplay.go @@ -0,0 +1,6 @@ +package antireplay + +type GeneralizedReplayFilter interface { + Interval() int64 + Check(sum []byte) bool +} diff --git a/proxy/shadowsocks/config.go b/proxy/shadowsocks/config.go index dad16d952..8b59986ed 100644 --- a/proxy/shadowsocks/config.go +++ b/proxy/shadowsocks/config.go @@ -6,6 +6,7 @@ import ( "crypto/cipher" "crypto/md5" "crypto/sha1" + "github.com/v2fly/v2ray-core/v4/common/antireplay" "io" "golang.org/x/crypto/chacha20poly1305" @@ -21,6 +22,8 @@ import ( type MemoryAccount struct { Cipher Cipher Key []byte + + replayFilter antireplay.GeneralizedReplayFilter } // Equals implements protocol.Account.Equals(). @@ -31,6 +34,16 @@ func (a *MemoryAccount) Equals(another protocol.Account) bool { return false } +func (a *MemoryAccount) CheckIV(iv []byte) error { + if a.replayFilter == nil { + return nil + } + if a.replayFilter.Check(iv) { + return nil + } + return newError("IV is not unique") +} + func createAesGcm(key []byte) cipher.AEAD { block, err := aes.NewCipher(key) common.Must(err) @@ -81,6 +94,12 @@ func (a *Account) AsAccount() (protocol.Account, error) { return &MemoryAccount{ Cipher: Cipher, Key: passwordToCipherKey([]byte(a.Password), Cipher.KeySize()), + replayFilter: func() antireplay.GeneralizedReplayFilter { + if a.ReplayProtection { + return antireplay.NewReplayFilter(300) + } + return nil + }(), }, nil } diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index 325c0b0ef..f11982594 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -61,6 +61,12 @@ func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.Requ iv = append([]byte(nil), buffer.BytesTo(ivLen)...) } + if ivError := account.CheckIV(iv); ivError != nil { + readSizeRemain -= int(buffer.Len()) + DrainConnN(reader, readSizeRemain) + return nil, nil, newError("failed iv check").Base(ivError) + } + r, err := account.Cipher.NewDecryptionReader(account.Key, iv, reader) if err != nil { readSizeRemain -= int(buffer.Len()) @@ -111,6 +117,9 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri if account.Cipher.IVSize() > 0 { iv = make([]byte, account.Cipher.IVSize()) common.Must2(rand.Read(iv)) + if ivError := account.CheckIV(iv); ivError != nil { + return nil, newError("failed to mark outgoing iv").Base(ivError) + } if err := buf.WriteAllBytes(writer, iv); err != nil { return nil, newError("failed to write IV") } @@ -145,6 +154,10 @@ func ReadTCPResponse(user *protocol.MemoryUser, reader io.Reader) (buf.Reader, e } } + if ivError := account.CheckIV(iv); ivError != nil { + return nil, newError("failed iv check").Base(ivError) + } + return account.Cipher.NewDecryptionReader(account.Key, iv, reader) } @@ -156,6 +169,9 @@ func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Wr if account.Cipher.IVSize() > 0 { iv = make([]byte, account.Cipher.IVSize()) common.Must2(rand.Read(iv)) + if ivError := account.CheckIV(iv); ivError != nil { + return nil, newError("failed to mark outgoing iv").Base(ivError) + } if err := buf.WriteAllBytes(writer, iv); err != nil { return nil, newError("failed to write IV.").Base(err) }