diff --git a/transport/internet/quic/conn.go b/transport/internet/quic/conn.go index 5d0937275..5123414e7 100644 --- a/transport/internet/quic/conn.go +++ b/transport/internet/quic/conn.go @@ -35,7 +35,7 @@ func wrapSysConn(rawConn net.PacketConn, config *Config) (*sysConn, error) { }, nil } -var errCipherError = errors.New("cipher error") +var errInvalidPacket = errors.New("invalid packet") func (c *sysConn) readFromInternal(p []byte) (int, net.Addr, error) { buffer := getBuffer() @@ -48,6 +48,9 @@ func (c *sysConn) readFromInternal(p []byte) (int, net.Addr, error) { payload := buffer[:nBytes] if c.header != nil { + if len(payload) <= int(c.header.Size()) { + return 0, nil, errInvalidPacket + } payload = payload[c.header.Size():] } @@ -56,12 +59,16 @@ func (c *sysConn) readFromInternal(p []byte) (int, net.Addr, error) { return n, addr, nil } + if len(payload) <= c.auth.NonceSize() { + return 0, nil, errInvalidPacket + } + nonce := payload[:c.auth.NonceSize()] payload = payload[c.auth.NonceSize():] p, err = c.auth.Open(p[:0], nonce, payload, nil) if err != nil { - return 0, nil, errCipherError + return 0, nil, errInvalidPacket } return len(p), addr, nil @@ -74,7 +81,7 @@ func (c *sysConn) ReadFrom(p []byte) (int, net.Addr, error) { for { n, addr, err := c.readFromInternal(p) - if err != nil && err != errCipherError { + if err != nil && err != errInvalidPacket { return 0, nil, err } if err == nil {