1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-22 01:57:12 -05:00

update quic vendor

This commit is contained in:
Darien Raymond 2018-11-27 15:29:03 +01:00
parent 90ab42b1cb
commit 135bf169c0
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
15 changed files with 553 additions and 381 deletions

View File

@ -27,7 +27,7 @@ type client struct {
token []byte
versionNegotiated bool // has the server accepted our version
versionNegotiated utils.AtomicBool // has the server accepted our version
receivedVersionNegotiationPacket bool
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
@ -59,6 +59,7 @@ var (
)
// DialAddr establishes a new QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC session is closed.
// The hostname for SNI is taken from the given address.
func DialAddr(
addr string,
@ -69,7 +70,7 @@ func DialAddr(
}
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
// The hostname for SNI is taken from the given address.
// See DialAddr for details.
func DialAddrContext(
ctx context.Context,
addr string,
@ -88,6 +89,8 @@ func DialAddrContext(
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The same PacketConn can be used for multiple calls to Dial and Listen,
// QUIC connection IDs are used for demultiplexing the different connections.
// The host parameter is used for SNI.
func Dial(
pconn net.PacketConn,
@ -100,7 +103,7 @@ func Dial(
}
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
// The host parameter is used for SNI.
// See Dial for details.
func DialContext(
ctx context.Context,
pconn net.PacketConn,
@ -164,7 +167,18 @@ func newClient(
}
}
}
srcConnID, err := generateConnectionID(config.ConnectionIDLength)
if err != nil {
return nil, err
}
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return nil, err
}
c := &client{
srcConnID: srcConnID,
destConnID: destConnID,
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
createdPacketConn: createdPacketConn,
tlsConf: tlsConf,
@ -173,7 +187,7 @@ func newClient(
handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"),
}
return c, c.generateConnectionIDs()
return c, nil
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
@ -234,20 +248,6 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
}
}
func (c *client) generateConnectionIDs() error {
srcConnID, err := generateConnectionID(c.config.ConnectionIDLength)
if err != nil {
return err
}
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return err
}
c.srcConnID = srcConnID
c.destConnID = destConnID
return nil
}
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
@ -292,65 +292,49 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
}
func (c *client) handlePacket(p *receivedPacket) {
if err := c.handlePacketImpl(p); err != nil {
c.logger.Errorf("error handling packet: %s", err)
}
}
func (c *client) handlePacketImpl(p *receivedPacket) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// handle Version Negotiation Packets
if p.header.IsVersionNegotiation {
err := c.handleVersionNegotiationPacket(p.header)
if err != nil {
c.session.destroy(err)
}
// version negotiation packets have no payload
return err
if p.hdr.IsVersionNegotiation() {
go c.handleVersionNegotiationPacket(p.hdr)
return
}
// reject packets with the wrong connection ID
if !p.header.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
}
if p.header.Type == protocol.PacketTypeRetry {
c.handleRetryPacket(p.header)
return nil
if p.hdr.Type == protocol.PacketTypeRetry {
go c.handleRetryPacket(p.hdr)
return
}
// this is the first packet we are receiving
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
if !c.versionNegotiated {
c.versionNegotiated = true
if !c.versionNegotiated.Get() {
c.versionNegotiated.Set(true)
}
c.session.handlePacket(p)
return nil
}
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) {
c.mutex.Lock()
defer c.mutex.Unlock()
// ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
c.logger.Debugf("Received a delayed Version Negotiation Packet.")
return nil
if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() {
c.logger.Debugf("Received a delayed Version Negotiation packet.")
return
}
for _, v := range hdr.SupportedVersions {
if v == c.version {
// the version negotiation packet contains the version that we offered
// this might be a packet sent by an attacker (or by a terribly broken server implementation)
// ignore it
return nil
// The Version Negotiation packet contains the version that we offered.
// This might be a packet sent by an attacker (or by a terribly broken server implementation).
return
}
}
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions)
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok {
return qerr.InvalidVersion
c.session.destroy(qerr.InvalidVersion)
c.logger.Debugf("No compatible version found.")
return
}
c.receivedVersionNegotiationPacket = true
c.negotiatedVersions = hdr.SupportedVersions
@ -358,18 +342,17 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
// switch to negotiated version
c.initialVersion = c.version
c.version = newVersion
if err := c.generateConnectionIDs(); err != nil {
return err
}
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.destroy(errCloseSessionForNewVersion)
return nil
}
func (c *client) handleRetryPacket(hdr *wire.Header) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.logger.Debugf("<- Received Retry")
hdr.Log(c.logger)
(&wire.ExtendedHeader{Header: *hdr}).Log(c.logger)
if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
return

View File

@ -75,12 +75,10 @@ type sentPacketHandler struct {
alarm time.Time
logger utils.Logger
version protocol.VersionNumber
}
// NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler {
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
@ -95,7 +93,6 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve
rttStats: rttStats,
congestion: congestion,
logger: logger,
version: version,
}
}
@ -516,7 +513,7 @@ func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
pn := h.packetNumberGenerator.Peek()
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked(), h.version)
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked())
}
func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber {

View File

@ -11,11 +11,13 @@ func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
c = &tls.Config{}
}
// QUIC requires TLS 1.3 or newer
if c.MinVersion < qtls.VersionTLS13 {
c.MinVersion = qtls.VersionTLS13
minVersion := c.MinVersion
if minVersion < qtls.VersionTLS13 {
minVersion = qtls.VersionTLS13
}
if c.MaxVersion < qtls.VersionTLS13 {
c.MaxVersion = qtls.VersionTLS13
maxVersion := c.MaxVersion
if maxVersion < qtls.VersionTLS13 {
maxVersion = qtls.VersionTLS13
}
return &qtls.Config{
Rand: c.Rand,
@ -38,8 +40,8 @@ func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey,
MinVersion: c.MinVersion,
MaxVersion: c.MaxVersion,
MinVersion: minVersion,
MaxVersion: maxVersion,
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation,

View File

@ -1,20 +1,37 @@
package protocol
// PacketNumberLen is the length of the packet number in bytes
type PacketNumberLen uint8
const (
// PacketNumberLenInvalid is the default value and not a valid length for a packet number
PacketNumberLenInvalid PacketNumberLen = 0
// PacketNumberLen1 is a packet number length of 1 byte
PacketNumberLen1 PacketNumberLen = 1
// PacketNumberLen2 is a packet number length of 2 bytes
PacketNumberLen2 PacketNumberLen = 2
// PacketNumberLen3 is a packet number length of 3 bytes
PacketNumberLen3 PacketNumberLen = 3
// PacketNumberLen4 is a packet number length of 4 bytes
PacketNumberLen4 PacketNumberLen = 4
)
// InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
func InferPacketNumber(
packetNumberLength PacketNumberLen,
lastPacketNumber PacketNumber,
wirePacketNumber PacketNumber,
version VersionNumber,
) PacketNumber {
var epochDelta PacketNumber
switch packetNumberLength {
case PacketNumberLen1:
epochDelta = PacketNumber(1) << 7
epochDelta = PacketNumber(1) << 8
case PacketNumberLen2:
epochDelta = PacketNumber(1) << 14
epochDelta = PacketNumber(1) << 16
case PacketNumberLen3:
epochDelta = PacketNumber(1) << 24
case PacketNumberLen4:
epochDelta = PacketNumber(1) << 30
epochDelta = PacketNumber(1) << 32
}
epoch := lastPacketNumber & ^(epochDelta - 1)
prevEpochBegin := epoch - epochDelta
@ -42,11 +59,14 @@ func delta(a, b PacketNumber) PacketNumber {
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen {
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen {
diff := uint64(packetNumber - leastUnacked)
if diff < (1 << (14 - 1)) {
if diff < (1 << (16 - 1)) {
return PacketNumberLen2
}
if diff < (1 << (24 - 1)) {
return PacketNumberLen3
}
return PacketNumberLen4
}
@ -58,5 +78,8 @@ func GetPacketNumberLength(packetNumber PacketNumber) PacketNumberLen {
if packetNumber < (1 << (uint8(PacketNumberLen2) * 8)) {
return PacketNumberLen2
}
if packetNumber < (1 << (uint8(PacketNumberLen3) * 8)) {
return PacketNumberLen3
}
return PacketNumberLen4
}

View File

@ -7,32 +7,18 @@ import (
// A PacketNumber in QUIC
type PacketNumber uint64
// PacketNumberLen is the length of the packet number in bytes
type PacketNumberLen uint8
const (
// PacketNumberLenInvalid is the default value and not a valid length for a packet number
PacketNumberLenInvalid PacketNumberLen = 0
// PacketNumberLen1 is a packet number length of 1 byte
PacketNumberLen1 PacketNumberLen = 1
// PacketNumberLen2 is a packet number length of 2 bytes
PacketNumberLen2 PacketNumberLen = 2
// PacketNumberLen4 is a packet number length of 4 bytes
PacketNumberLen4 PacketNumberLen = 4
)
// The PacketType is the Long Header Type
type PacketType uint8
const (
// PacketTypeInitial is the packet type of an Initial packet
PacketTypeInitial PacketType = 0x7f
PacketTypeInitial PacketType = 1 + iota
// PacketTypeRetry is the packet type of a Retry packet
PacketTypeRetry PacketType = 0x7e
PacketTypeRetry
// PacketTypeHandshake is the packet type of a Handshake packet
PacketTypeHandshake PacketType = 0x7d
PacketTypeHandshake
// PacketType0RTT is the packet type of a 0-RTT packet
PacketType0RTT PacketType = 0x7c
PacketType0RTT
)
func (t PacketType) String() string {
@ -72,11 +58,5 @@ const DefaultTCPMSS ByteCount = 1460
// MinInitialPacketSize is the minimum size an Initial packet is required to have.
const MinInitialPacketSize = 1200
// MaxClientHellos is the maximum number of times we'll send a client hello
// The value 3 accounts for:
// * one failure due to an incorrect or missing source-address token
// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
const MaxClientHellos = 3
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
const MinConnectionIDLenInitial = 8

View File

@ -8,11 +8,10 @@ import (
// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers.
type ByteOrder interface {
ReadUintN(b io.ByteReader, length uint8) (uint64, error)
ReadUint64(io.ByteReader) (uint64, error)
ReadUint32(io.ByteReader) (uint32, error)
ReadUint16(io.ByteReader) (uint16, error)
WriteUint64(*bytes.Buffer, uint64)
WriteUintN(b *bytes.Buffer, length uint8, value uint64)
WriteUint32(*bytes.Buffer, uint32)
WriteUint16(*bytes.Buffer, uint16)
}

View File

@ -25,37 +25,6 @@ func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
return res, nil
}
// ReadUint64 reads a uint64
func (bigEndian) ReadUint64(b io.ByteReader) (uint64, error) {
var b1, b2, b3, b4, b5, b6, b7, b8 uint8
var err error
if b8, err = b.ReadByte(); err != nil {
return 0, err
}
if b7, err = b.ReadByte(); err != nil {
return 0, err
}
if b6, err = b.ReadByte(); err != nil {
return 0, err
}
if b5, err = b.ReadByte(); err != nil {
return 0, err
}
if b4, err = b.ReadByte(); err != nil {
return 0, err
}
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil
}
// ReadUint32 reads a uint32
func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) {
var b1, b2, b3, b4 uint8
@ -88,12 +57,10 @@ func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) {
return uint16(b1) + uint16(b2)<<8, nil
}
// WriteUint64 writes a uint64
func (bigEndian) WriteUint64(b *bytes.Buffer, i uint64) {
b.Write([]byte{
uint8(i >> 56), uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
})
func (bigEndian) WriteUintN(b *bytes.Buffer, length uint8, i uint64) {
for j := length; j > 0; j-- {
b.WriteByte(uint8(i >> (8 * (j - 1))))
}
}
// WriteUint32 writes a uint32

View File

@ -0,0 +1,205 @@
package wire
import (
"bytes"
"errors"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// ExtendedHeader is the header of a QUIC packet.
type ExtendedHeader struct {
Header
typeByte byte
Raw []byte
PacketNumberLen protocol.PacketNumberLen
PacketNumber protocol.PacketNumber
KeyPhase int
}
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
// read the (now unencrypted) first byte
var err error
h.typeByte, err = b.ReadByte()
if err != nil {
return nil, err
}
if _, err := b.Seek(int64(h.len)-1, io.SeekCurrent); err != nil {
return nil, err
}
if h.IsLongHeader {
return h.parseLongHeader(b, v)
}
return h.parseShortHeader(b, v)
}
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
if h.typeByte&0xc != 0 {
return nil, errors.New("5th and 6th bit must be 0")
}
if err := h.readPacketNumber(b); err != nil {
return nil, err
}
return h, nil
}
func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
if h.typeByte&0x18 != 0 {
return nil, errors.New("4th and 5th bit must be 0")
}
h.KeyPhase = int(h.typeByte&0x4) >> 2
if err := h.readPacketNumber(b); err != nil {
return nil, err
}
return h, nil
}
func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
pn, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen))
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(pn)
return nil
}
// Write writes the Header.
func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
if h.IsLongHeader {
return h.writeLongHeader(b, ver)
}
return h.writeShortHeader(b, ver)
}
func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
var packetType uint8
switch h.Type {
case protocol.PacketTypeInitial:
packetType = 0x0
case protocol.PacketType0RTT:
packetType = 0x1
case protocol.PacketTypeHandshake:
packetType = 0x2
case protocol.PacketTypeRetry:
packetType = 0x3
}
firstByte := 0xc0 | packetType<<4
if h.Type == protocol.PacketTypeRetry {
odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
if err != nil {
return err
}
firstByte |= odcil
} else { // Retry packets don't have a packet number
firstByte |= uint8(h.PacketNumberLen - 1)
}
b.WriteByte(firstByte)
utils.BigEndian.WriteUint32(b, uint32(h.Version))
connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
if err != nil {
return err
}
b.WriteByte(connIDLen)
b.Write(h.DestConnectionID.Bytes())
b.Write(h.SrcConnectionID.Bytes())
switch h.Type {
case protocol.PacketTypeRetry:
b.Write(h.OrigDestConnectionID.Bytes())
b.Write(h.Token)
return nil
case protocol.PacketTypeInitial:
utils.WriteVarInt(b, uint64(len(h.Token)))
b.Write(h.Token)
}
utils.WriteVarInt(b, uint64(h.Length))
return h.writePacketNumber(b)
}
// TODO: add support for the key phase
func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
typeByte := 0x40 | uint8(h.PacketNumberLen-1)
typeByte |= byte(h.KeyPhase << 2)
b.WriteByte(typeByte)
b.Write(h.DestConnectionID.Bytes())
return h.writePacketNumber(b)
}
func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error {
if h.PacketNumberLen == protocol.PacketNumberLenInvalid || h.PacketNumberLen > protocol.PacketNumberLen4 {
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
}
utils.BigEndian.WriteUintN(b, uint8(h.PacketNumberLen), uint64(h.PacketNumber))
return nil
}
// GetLength determines the length of the Header.
func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount {
if h.IsLongHeader {
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length))
if h.Type == protocol.PacketTypeInitial {
length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
}
return length
}
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
length += protocol.ByteCount(h.PacketNumberLen)
return length
}
// Log logs the Header
func (h *ExtendedHeader) Log(logger utils.Logger) {
if h.IsLongHeader {
var token string
if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
if len(h.Token) == 0 {
token = "Token: (empty), "
} else {
token = fmt.Sprintf("Token: %#x, ", h.Token)
}
if h.Type == protocol.PacketTypeRetry {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
return
}
}
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
} else {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
}
}
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
dcil, err := encodeSingleConnIDLen(dest)
if err != nil {
return 0, err
}
scil, err := encodeSingleConnIDLen(src)
if err != nil {
return 0, err
}
return scil | dcil<<4, nil
}
func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
len := id.Len()
if len == 0 {
return 0, nil
}
if len < 4 || len > 18 {
return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
}
return byte(len - 3), nil
}

View File

@ -2,150 +2,183 @@ package wire
import (
"bytes"
"crypto/rand"
"fmt"
"errors"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// Header is the header of a QUIC packet.
// The Header is the version independent part of the header
type Header struct {
Raw []byte
Version protocol.VersionNumber
SrcConnectionID protocol.ConnectionID
DestConnectionID protocol.ConnectionID
Version protocol.VersionNumber
DestConnectionID protocol.ConnectionID
SrcConnectionID protocol.ConnectionID
OrigDestConnectionID protocol.ConnectionID // only needed in the Retry packet
PacketNumberLen protocol.PacketNumberLen
PacketNumber protocol.PacketNumber
IsVersionNegotiation bool
SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server
Type protocol.PacketType
IsLongHeader bool
KeyPhase int
Type protocol.PacketType
Length protocol.ByteCount
Token []byte
Token []byte
SupportedVersions []protocol.VersionNumber // sent in a Version Negotiation Packet
OrigDestConnectionID protocol.ConnectionID // sent in the Retry packet
typeByte byte
len int // how many bytes were read while parsing this header
}
// Write writes the Header.
func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error {
if h.IsLongHeader {
return h.writeLongHeader(b, ver)
// ParseHeader parses the header.
// For short header packets: up to the packet number.
// For long header packets:
// * if we understand the version: up to the packet number
// * if not, only the invariant part of the header
func ParseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
startLen := b.Len()
h, err := parseHeaderImpl(b, shortHeaderConnIDLen)
if err != nil {
return nil, err
}
return h.writeShortHeader(b, ver)
h.len = startLen - b.Len()
return h, nil
}
// TODO: add support for the key phase
func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
b.WriteByte(byte(0x80 | h.Type))
utils.BigEndian.WriteUint32(b, uint32(h.Version))
connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
typeByte, err := b.ReadByte()
if err != nil {
return nil, err
}
h := &Header{
typeByte: typeByte,
IsLongHeader: typeByte&0x80 > 0,
}
if !h.IsLongHeader {
if h.typeByte&0x40 == 0 {
return nil, errors.New("not a QUIC packet")
}
if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil {
return nil, err
}
return h, nil
}
if err := h.parseLongHeader(b); err != nil {
return nil, err
}
return h, nil
}
func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error {
var err error
h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
return err
}
func (h *Header) parseLongHeader(b *bytes.Reader) error {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return err
}
b.WriteByte(connIDLen)
b.Write(h.DestConnectionID.Bytes())
b.Write(h.SrcConnectionID.Bytes())
if h.Type == protocol.PacketTypeInitial {
utils.WriteVarInt(b, uint64(len(h.Token)))
b.Write(h.Token)
h.Version = protocol.VersionNumber(v)
if !h.IsVersionNegotiation() && h.typeByte&0x40 == 0 {
return errors.New("not a QUIC packet")
}
if h.Type == protocol.PacketTypeRetry {
odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
if err != nil {
return err
}
// randomize the first 4 bits
odcilByte := make([]byte, 1)
_, _ = rand.Read(odcilByte) // it's safe to ignore the error here
odcilByte[0] = (odcilByte[0] & 0xf0) | odcil
b.Write(odcilByte)
b.Write(h.OrigDestConnectionID.Bytes())
b.Write(h.Token)
connIDLenByte, err := b.ReadByte()
if err != nil {
return err
}
dcil, scil := decodeConnIDLen(connIDLenByte)
h.DestConnectionID, err = protocol.ReadConnectionID(b, dcil)
if err != nil {
return err
}
h.SrcConnectionID, err = protocol.ReadConnectionID(b, scil)
if err != nil {
return err
}
if h.Version == 0 {
return h.parseVersionNegotiationPacket(b)
}
// If we don't understand the version, we have no idea how to interpret the rest of the bytes
if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
return nil
}
utils.WriteVarInt(b, uint64(h.Length))
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
switch (h.typeByte & 0x30) >> 4 {
case 0x0:
h.Type = protocol.PacketTypeInitial
case 0x1:
h.Type = protocol.PacketType0RTT
case 0x2:
h.Type = protocol.PacketTypeHandshake
case 0x3:
h.Type = protocol.PacketTypeRetry
}
func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
typeByte := byte(0x30)
typeByte |= byte(h.KeyPhase << 6)
b.WriteByte(typeByte)
b.Write(h.DestConnectionID.Bytes())
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
// GetLength determines the length of the Header.
func (h *Header) GetLength(v protocol.VersionNumber) protocol.ByteCount {
if h.IsLongHeader {
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length))
if h.Type == protocol.PacketTypeInitial {
length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
if h.Type == protocol.PacketTypeRetry {
odcil := decodeSingleConnIDLen(h.typeByte & 0xf)
h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil)
if err != nil {
return err
}
return length
}
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
length += protocol.ByteCount(h.PacketNumberLen)
return length
}
// Log logs the Header
func (h *Header) Log(logger utils.Logger) {
if h.IsLongHeader {
if h.Version == 0 {
logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
} else {
var token string
if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
if len(h.Token) == 0 {
token = "Token: (empty), "
} else {
token = fmt.Sprintf("Token: %#x, ", h.Token)
}
}
if h.Type == protocol.PacketTypeRetry {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
return
}
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
h.Token = make([]byte, b.Len())
if _, err := io.ReadFull(b, h.Token); err != nil {
return err
}
} else {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
return nil
}
if h.Type == protocol.PacketTypeInitial {
tokenLen, err := utils.ReadVarInt(b)
if err != nil {
return err
}
if tokenLen > uint64(b.Len()) {
return io.EOF
}
h.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, h.Token); err != nil {
return err
}
}
pl, err := utils.ReadVarInt(b)
if err != nil {
return err
}
h.Length = protocol.ByteCount(pl)
return nil
}
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
dcil, err := encodeSingleConnIDLen(dest)
if err != nil {
return 0, err
func (h *Header) parseVersionNegotiationPacket(b *bytes.Reader) error {
if b.Len() == 0 {
return qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
}
scil, err := encodeSingleConnIDLen(src)
if err != nil {
return 0, err
h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4)
for i := 0; b.Len() > 0; i++ {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return qerr.InvalidVersionNegotiationPacket
}
h.SupportedVersions[i] = protocol.VersionNumber(v)
}
return scil | dcil<<4, nil
return nil
}
func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
len := id.Len()
if len == 0 {
return 0, nil
}
if len < 4 || len > 18 {
return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
}
return byte(len - 3), nil
// IsVersionNegotiation says if this a version negotiation packet
func (h *Header) IsVersionNegotiation() bool {
return h.IsLongHeader && h.Version == 0
}
// ParseExtended parses the version dependent part of the header.
// The Reader has to be set such that it points to the first byte of the header.
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
return h.toExtendedHeader().parse(b, ver)
}
func (h *Header) toExtendedHeader() *ExtendedHeader {
return &ExtendedHeader{Header: *h}
}
func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {

View File

@ -162,75 +162,46 @@ func (h *packetHandlerMap) listen() {
}
func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
rcvTime := time.Now()
r := bytes.NewReader(data)
iHdr, err := wire.ParseInvariantHeader(r, h.connIDLen)
hdr, err := wire.ParseHeader(r, h.connIDLen)
// drop the packet if we can't parse the header
if err != nil {
return fmt.Errorf("error parsing invariant header: %s", err)
}
h.mutex.RLock()
handlerEntry, handlerFound := h.handlers[string(iHdr.DestConnectionID)]
server := h.server
var sentBy protocol.Perspective
var version protocol.VersionNumber
var handlePacket func(*receivedPacket)
if handlerFound { // existing session
handler := handlerEntry.handler
sentBy = handler.GetPerspective().Opposite()
version = handler.GetVersion()
handlePacket = handler.handlePacket
} else { // no session found
// this might be a stateless reset
if !iHdr.IsLongHeader {
if len(data) >= protocol.MinStatelessResetSize {
var token [16]byte
copy(token[:], data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok {
h.mutex.RUnlock()
sess.destroy(errors.New("received a stateless reset"))
return nil
}
}
// TODO(#943): send a stateless reset
return fmt.Errorf("received a short header packet with an unexpected connection ID %s", iHdr.DestConnectionID)
}
if server == nil { // no server set
h.mutex.RUnlock()
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
}
handlePacket = server.handlePacket
sentBy = protocol.PerspectiveClient
version = iHdr.Version
}
h.mutex.RUnlock()
hdr, err := iHdr.Parse(r, sentBy, version)
if err != nil {
return fmt.Errorf("error parsing header: %s", err)
}
hdr.Raw = data[:len(data)-r.Len()]
packetData := data[len(data)-r.Len():]
if hdr.IsLongHeader {
if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) {
return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen)
}
if protocol.ByteCount(len(packetData))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length {
return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(packetData)+int(hdr.PacketNumberLen), hdr.Length)
}
packetData = packetData[:int(hdr.Length)-int(hdr.PacketNumberLen)]
// TODO(#1312): implement parsing of compound packets
p := &receivedPacket{
remoteAddr: addr,
hdr: hdr,
data: data,
rcvTime: time.Now(),
}
handlePacket(&receivedPacket{
remoteAddr: addr,
header: hdr,
data: packetData,
rcvTime: rcvTime,
})
h.mutex.RLock()
defer h.mutex.RUnlock()
handlerEntry, handlerFound := h.handlers[string(hdr.DestConnectionID)]
if handlerFound { // existing session
handlerEntry.handler.handlePacket(p)
return nil
}
// No session found.
// This might be a stateless reset.
if !hdr.IsLongHeader {
if len(data) >= protocol.MinStatelessResetSize {
var token [16]byte
copy(token[:], data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok {
sess.destroy(errors.New("received a stateless reset"))
return nil
}
}
// TODO(#943): send a stateless reset
return fmt.Errorf("received a short header packet with an unexpected connection ID %s", hdr.DestConnectionID)
}
if h.server == nil { // no server set
return fmt.Errorf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID)
}
h.server.handlePacket(p)
return nil
}

View File

@ -25,7 +25,7 @@ type packer interface {
}
type packedPacket struct {
header *wire.Header
header *wire.ExtendedHeader
raw []byte
frames []wire.Frame
encryptionLevel protocol.EncryptionLevel
@ -397,14 +397,13 @@ func (p *packetPacker) composeNextPacket(
return frames, nil
}
func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header {
func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader {
pn, pnLen := p.pnManager.PeekPacketNumber()
header := &wire.Header{
PacketNumber: pn,
PacketNumberLen: pnLen,
Version: p.version,
DestConnectionID: p.destConnID,
}
header := &wire.ExtendedHeader{}
header.PacketNumber = pn
header.PacketNumberLen = pnLen
header.Version = p.version
header.DestConnectionID = p.destConnID
if encLevel != protocol.Encryption1RTT {
header.IsLongHeader = true
@ -424,8 +423,7 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
}
func (p *packetPacker) writeAndSealPacket(
header *wire.Header,
frames []wire.Frame,
header *wire.ExtendedHeader, frames []wire.Frame,
sealer handshake.Sealer,
) ([]byte, error) {
raw := *getPacketBuffer()
@ -450,7 +448,7 @@ func (p *packetPacker) writeAndSealPacket(
}
}
if err := header.Write(buffer, p.perspective, p.version); err != nil {
if err := header.Write(buffer, p.version); err != nil {
return nil, err
}
payloadStartIndex := buffer.Len()

View File

@ -35,7 +35,7 @@ func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker {
}
}
func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) {
func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
buf := *getPacketBuffer()
buf = buf[:0]
defer putPacketBuffer(&buf)

View File

@ -21,7 +21,6 @@ type packetHandler interface {
handlePacket(*receivedPacket)
io.Closer
destroy(error)
GetVersion() protocol.VersionNumber
GetPerspective() protocol.Perspective
}
@ -99,7 +98,8 @@ var _ Listener = &server{}
var _ unknownPacketHandler = &server{}
// ListenAddr creates a QUIC server listening on a given address.
// The tls.Config must not be nil, the quic.Config may be nil.
// The tls.Config must not be nil and must contain a certificate configuration.
// The quic.Config may be nil, in that case the default values will be used.
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
@ -118,7 +118,11 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err
}
// Listen listens for QUIC connections on a given net.PacketConn.
// The tls.Config must not be nil, the quic.Config may be nil.
// A single PacketConn only be used for a single call to Listen.
// The PacketConn can be used for simultaneous calls to Dial.
// QUIC connection IDs are used for demultiplexing the different connections.
// The tls.Config must not be nil and must contain a certificate configuration.
// The quic.Config may be nil, in that case the default values will be used.
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
return listen(conn, tlsConf, config)
}
@ -300,23 +304,17 @@ func (s *server) Addr() net.Addr {
}
func (s *server) handlePacket(p *receivedPacket) {
if err := s.handlePacketImpl(p); err != nil {
s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
}
}
func (s *server) handlePacketImpl(p *receivedPacket) error {
hdr := p.header
hdr := p.hdr
// send a Version Negotiation Packet if the client is speaking a different protocol version
if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
return s.sendVersionNegotiationPacket(p)
go s.sendVersionNegotiationPacket(p)
return
}
if hdr.Type == protocol.PacketTypeInitial {
go s.handleInitial(p)
}
// TODO(#943): send Stateless Reset
return nil
}
func (s *server) handleInitial(p *receivedPacket) {
@ -335,11 +333,11 @@ func (s *server) handleInitial(p *receivedPacket) {
}
func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) {
hdr := p.header
hdr := p.hdr
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
return nil, nil, errors.New("dropping Initial packet with too short connection ID")
}
if len(hdr.Raw)+len(p.data) < protocol.MinInitialPacketSize {
if len(p.data) < protocol.MinInitialPacketSize {
return nil, nil, errors.New("dropping too small Initial packet")
}
@ -358,7 +356,7 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con
if !s.config.AcceptCookie(p.remoteAddr, cookie) {
// Log the Initial packet now.
// If no Retry is sent, the packet will be logged by the session.
p.header.Log(s.logger)
(&wire.ExtendedHeader{Header: *p.hdr}).Log(s.logger)
return nil, nil, s.sendRetry(p.remoteAddr, hdr)
}
@ -431,19 +429,18 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
if err != nil {
return err
}
replyHdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
Version: hdr.Version,
SrcConnectionID: connID,
DestConnectionID: hdr.SrcConnectionID,
OrigDestConnectionID: hdr.DestConnectionID,
Token: token,
}
replyHdr := &wire.ExtendedHeader{}
replyHdr.IsLongHeader = true
replyHdr.Type = protocol.PacketTypeRetry
replyHdr.Version = hdr.Version
replyHdr.SrcConnectionID = connID
replyHdr.DestConnectionID = hdr.SrcConnectionID
replyHdr.OrigDestConnectionID = hdr.DestConnectionID
replyHdr.Token = token
s.logger.Debugf("Changing connection ID to %s.\n-> Sending Retry", connID)
replyHdr.Log(s.logger)
buf := &bytes.Buffer{}
if err := replyHdr.Write(buf, protocol.PerspectiveServer, hdr.Version); err != nil {
if err := replyHdr.Write(buf, hdr.Version); err != nil {
return err
}
if _, err := s.conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
@ -452,14 +449,15 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
return nil
}
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error {
hdr := p.header
s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) {
hdr := p.hdr
s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
if err != nil {
return err
s.logger.Debugf("Error composing Version Negotiation: %s", err)
return
}
if _, err := s.conn.WriteTo(data, p.remoteAddr); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err)
}
_, err = s.conn.WriteTo(data, p.remoteAddr)
return err
}

View File

@ -32,7 +32,7 @@ func (s *serverSession) handlePacket(p *receivedPacket) {
}
func (s *serverSession) handlePacketImpl(p *receivedPacket) error {
hdr := p.header
hdr := p.hdr
// Probably an old packet that was sent by the client before the version was negotiated.
// It is safe to drop it.

View File

@ -1,6 +1,7 @@
package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
@ -21,7 +22,7 @@ import (
)
type unpacker interface {
Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error)
Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error)
}
type streamGetter interface {
@ -52,7 +53,7 @@ type cryptoStreamHandler interface {
type receivedPacket struct {
remoteAddr net.Addr
header *wire.Header
hdr *wire.Header
data []byte
rcvTime time.Time
}
@ -113,7 +114,6 @@ type session struct {
receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this
receivedFirstForwardSecurePacket bool
lastRcvdPacketNumber protocol.PacketNumber
// Used to calculate the next packet number from the truncated wire
// representation, and sent back in public reset packets
largestRcvdPacketNumber protocol.PacketNumber
@ -289,7 +289,7 @@ var newClientSession = func(
func (s *session) preSetup() {
s.rttStats = &congestion.RTTStats{}
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger, s.version)
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.InitialMaxData,
@ -374,7 +374,7 @@ runLoop:
}
// This is a bit unclean, but works properly, since the packet always
// begins with the public header and we never copy it.
putPacketBuffer(&p.header.Raw)
// TODO: putPacketBuffer(&p.extHdr.Raw)
case <-s.handshakeCompleteChan:
s.handleHandshakeComplete()
}
@ -479,24 +479,41 @@ func (s *session) handleHandshakeComplete() {
}
func (s *session) handlePacketImpl(p *receivedPacket) error {
hdr := p.header
// The server can change the source connection ID with the first Handshake packet.
// After this, all packets with a different source connection have to be ignored.
if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) {
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.header.SrcConnectionID, s.destConnID)
if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) {
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.hdr.SrcConnectionID, s.destConnID)
return nil
}
p.rcvTime = time.Now()
data := p.data
r := bytes.NewReader(data)
hdr, err := p.hdr.ParseExtended(r, s.version)
if err != nil {
return fmt.Errorf("error parsing extended header: %s", err)
}
hdr.Raw = data[:len(data)-r.Len()]
data = data[len(data)-r.Len():]
if hdr.IsLongHeader {
if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) {
return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen)
}
if protocol.ByteCount(len(data))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length {
return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(hdr.PacketNumberLen), hdr.Length)
}
data = data[:int(hdr.Length)-int(hdr.PacketNumberLen)]
// TODO(#1312): implement parsing of compound packets
}
// Calculate packet number
hdr.PacketNumber = protocol.InferPacketNumber(
hdr.PacketNumberLen,
s.largestRcvdPacketNumber,
hdr.PacketNumber,
s.version,
)
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, p.data)
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)
if s.logger.Debug() {
if err != nil {
s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID)
@ -530,7 +547,6 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
}
}
s.lastRcvdPacketNumber = hdr.PacketNumber
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber)
@ -543,10 +559,10 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
}
}
return s.handleFrames(packet.frames, packet.encryptionLevel)
return s.handleFrames(packet.frames, hdr.PacketNumber, packet.encryptionLevel)
}
func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLevel) error {
func (s *session) handleFrames(fs []wire.Frame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error {
for _, ff := range fs {
var err error
wire.LogFrame(s.logger, ff, false)
@ -556,7 +572,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
case *wire.StreamFrame:
err = s.handleStreamFrame(frame, encLevel)
case *wire.AckFrame:
err = s.handleAckFrame(frame, encLevel)
err = s.handleAckFrame(frame, pn, encLevel)
case *wire.ConnectionCloseFrame:
s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase))
case *wire.ResetStreamFrame:
@ -702,8 +718,8 @@ func (s *session) handlePathChallengeFrame(frame *wire.PathChallengeFrame) {
s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data})
}
func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error {
if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil {
func (s *session) handleAckFrame(frame *wire.AckFrame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error {
if err := s.sentPacketHandler.ReceivedAck(frame, pn, encLevel, s.lastNetworkActivityTime); err != nil {
return err
}
s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked())
@ -1065,14 +1081,14 @@ func (s *session) scheduleSending() {
func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) {
if s.handshakeComplete {
s.logger.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.header, len(p.data))
s.logger.Debugf("Received undecryptable packet from %s after the handshake (%d bytes)", p.remoteAddr.String(), len(p.data))
return
}
if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets {
s.logger.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.header.PacketNumber)
s.logger.Infof("Dropping undecrytable packet (%d bytes). Undecryptable packet queue full.", len(p.data))
return
}
s.logger.Infof("Queueing packet 0x%x for later decryption", p.header.PacketNumber)
s.logger.Infof("Queueing packet (%d bytes) for later decryption", len(p.data))
s.undecryptablePackets = append(s.undecryptablePackets, p)
}