1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-22 01:57:12 -05:00

update quic vendor

This commit is contained in:
Darien Raymond 2018-11-27 15:29:03 +01:00
parent 90ab42b1cb
commit 135bf169c0
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
15 changed files with 553 additions and 381 deletions

View File

@ -27,7 +27,7 @@ type client struct {
token []byte token []byte
versionNegotiated bool // has the server accepted our version versionNegotiated utils.AtomicBool // has the server accepted our version
receivedVersionNegotiationPacket bool receivedVersionNegotiationPacket bool
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
@ -59,6 +59,7 @@ var (
) )
// DialAddr establishes a new QUIC connection to a server. // DialAddr establishes a new QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC session is closed.
// The hostname for SNI is taken from the given address. // The hostname for SNI is taken from the given address.
func DialAddr( func DialAddr(
addr string, addr string,
@ -69,7 +70,7 @@ func DialAddr(
} }
// DialAddrContext establishes a new QUIC connection to a server using the provided context. // DialAddrContext establishes a new QUIC connection to a server using the provided context.
// The hostname for SNI is taken from the given address. // See DialAddr for details.
func DialAddrContext( func DialAddrContext(
ctx context.Context, ctx context.Context,
addr string, addr string,
@ -88,6 +89,8 @@ func DialAddrContext(
} }
// Dial establishes a new QUIC connection to a server using a net.PacketConn. // Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The same PacketConn can be used for multiple calls to Dial and Listen,
// QUIC connection IDs are used for demultiplexing the different connections.
// The host parameter is used for SNI. // The host parameter is used for SNI.
func Dial( func Dial(
pconn net.PacketConn, pconn net.PacketConn,
@ -100,7 +103,7 @@ func Dial(
} }
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context. // DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
// The host parameter is used for SNI. // See Dial for details.
func DialContext( func DialContext(
ctx context.Context, ctx context.Context,
pconn net.PacketConn, pconn net.PacketConn,
@ -164,7 +167,18 @@ func newClient(
} }
} }
} }
srcConnID, err := generateConnectionID(config.ConnectionIDLength)
if err != nil {
return nil, err
}
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return nil, err
}
c := &client{ c := &client{
srcConnID: srcConnID,
destConnID: destConnID,
conn: &conn{pconn: pconn, currentAddr: remoteAddr}, conn: &conn{pconn: pconn, currentAddr: remoteAddr},
createdPacketConn: createdPacketConn, createdPacketConn: createdPacketConn,
tlsConf: tlsConf, tlsConf: tlsConf,
@ -173,7 +187,7 @@ func newClient(
handshakeChan: make(chan struct{}), handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"), logger: utils.DefaultLogger.WithPrefix("client"),
} }
return c, c.generateConnectionIDs() return c, nil
} }
// populateClientConfig populates fields in the quic.Config with their default values, if none are set // populateClientConfig populates fields in the quic.Config with their default values, if none are set
@ -234,20 +248,6 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
} }
} }
func (c *client) generateConnectionIDs() error {
srcConnID, err := generateConnectionID(c.config.ConnectionIDLength)
if err != nil {
return err
}
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return err
}
c.srcConnID = srcConnID
c.destConnID = destConnID
return nil
}
func (c *client) dial(ctx context.Context) error { func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
@ -292,65 +292,49 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
} }
func (c *client) handlePacket(p *receivedPacket) { func (c *client) handlePacket(p *receivedPacket) {
if err := c.handlePacketImpl(p); err != nil { if p.hdr.IsVersionNegotiation() {
c.logger.Errorf("error handling packet: %s", err) go c.handleVersionNegotiationPacket(p.hdr)
} return
}
func (c *client) handlePacketImpl(p *receivedPacket) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// handle Version Negotiation Packets
if p.header.IsVersionNegotiation {
err := c.handleVersionNegotiationPacket(p.header)
if err != nil {
c.session.destroy(err)
}
// version negotiation packets have no payload
return err
} }
// reject packets with the wrong connection ID if p.hdr.Type == protocol.PacketTypeRetry {
if !p.header.DestConnectionID.Equal(c.srcConnID) { go c.handleRetryPacket(p.hdr)
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID) return
}
if p.header.Type == protocol.PacketTypeRetry {
c.handleRetryPacket(p.header)
return nil
} }
// this is the first packet we are receiving // this is the first packet we are receiving
// since it is not a Version Negotiation Packet, this means the server supports the suggested version // since it is not a Version Negotiation Packet, this means the server supports the suggested version
if !c.versionNegotiated { if !c.versionNegotiated.Get() {
c.versionNegotiated = true c.versionNegotiated.Set(true)
} }
c.session.handlePacket(p) c.session.handlePacket(p)
return nil
} }
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) {
c.mutex.Lock()
defer c.mutex.Unlock()
// ignore delayed / duplicated version negotiation packets // ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated { if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() {
c.logger.Debugf("Received a delayed Version Negotiation Packet.") c.logger.Debugf("Received a delayed Version Negotiation packet.")
return nil return
} }
for _, v := range hdr.SupportedVersions { for _, v := range hdr.SupportedVersions {
if v == c.version { if v == c.version {
// the version negotiation packet contains the version that we offered // The Version Negotiation packet contains the version that we offered.
// this might be a packet sent by an attacker (or by a terribly broken server implementation) // This might be a packet sent by an attacker (or by a terribly broken server implementation).
// ignore it return
return nil
} }
} }
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions) c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions)
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok { if !ok {
return qerr.InvalidVersion c.session.destroy(qerr.InvalidVersion)
c.logger.Debugf("No compatible version found.")
return
} }
c.receivedVersionNegotiationPacket = true c.receivedVersionNegotiationPacket = true
c.negotiatedVersions = hdr.SupportedVersions c.negotiatedVersions = hdr.SupportedVersions
@ -358,18 +342,17 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
// switch to negotiated version // switch to negotiated version
c.initialVersion = c.version c.initialVersion = c.version
c.version = newVersion c.version = newVersion
if err := c.generateConnectionIDs(); err != nil {
return err
}
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID) c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.destroy(errCloseSessionForNewVersion) c.session.destroy(errCloseSessionForNewVersion)
return nil
} }
func (c *client) handleRetryPacket(hdr *wire.Header) { func (c *client) handleRetryPacket(hdr *wire.Header) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.logger.Debugf("<- Received Retry") c.logger.Debugf("<- Received Retry")
hdr.Log(c.logger) (&wire.ExtendedHeader{Header: *hdr}).Log(c.logger)
if !hdr.OrigDestConnectionID.Equal(c.destConnID) { if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID) c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
return return

View File

@ -75,12 +75,10 @@ type sentPacketHandler struct {
alarm time.Time alarm time.Time
logger utils.Logger logger utils.Logger
version protocol.VersionNumber
} }
// NewSentPacketHandler creates a new sentPacketHandler // NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler { func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
congestion := congestion.NewCubicSender( congestion := congestion.NewCubicSender(
congestion.DefaultClock{}, congestion.DefaultClock{},
rttStats, rttStats,
@ -95,7 +93,6 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve
rttStats: rttStats, rttStats: rttStats,
congestion: congestion, congestion: congestion,
logger: logger, logger: logger,
version: version,
} }
} }
@ -516,7 +513,7 @@ func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) { func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
pn := h.packetNumberGenerator.Peek() pn := h.packetNumberGenerator.Peek()
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked(), h.version) return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked())
} }
func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber { func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber {

View File

@ -11,11 +11,13 @@ func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
c = &tls.Config{} c = &tls.Config{}
} }
// QUIC requires TLS 1.3 or newer // QUIC requires TLS 1.3 or newer
if c.MinVersion < qtls.VersionTLS13 { minVersion := c.MinVersion
c.MinVersion = qtls.VersionTLS13 if minVersion < qtls.VersionTLS13 {
minVersion = qtls.VersionTLS13
} }
if c.MaxVersion < qtls.VersionTLS13 { maxVersion := c.MaxVersion
c.MaxVersion = qtls.VersionTLS13 if maxVersion < qtls.VersionTLS13 {
maxVersion = qtls.VersionTLS13
} }
return &qtls.Config{ return &qtls.Config{
Rand: c.Rand, Rand: c.Rand,
@ -38,8 +40,8 @@ func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
PreferServerCipherSuites: c.PreferServerCipherSuites, PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled, SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey, SessionTicketKey: c.SessionTicketKey,
MinVersion: c.MinVersion, MinVersion: minVersion,
MaxVersion: c.MaxVersion, MaxVersion: maxVersion,
CurvePreferences: c.CurvePreferences, CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation, Renegotiation: c.Renegotiation,

View File

@ -1,20 +1,37 @@
package protocol package protocol
// PacketNumberLen is the length of the packet number in bytes
type PacketNumberLen uint8
const (
// PacketNumberLenInvalid is the default value and not a valid length for a packet number
PacketNumberLenInvalid PacketNumberLen = 0
// PacketNumberLen1 is a packet number length of 1 byte
PacketNumberLen1 PacketNumberLen = 1
// PacketNumberLen2 is a packet number length of 2 bytes
PacketNumberLen2 PacketNumberLen = 2
// PacketNumberLen3 is a packet number length of 3 bytes
PacketNumberLen3 PacketNumberLen = 3
// PacketNumberLen4 is a packet number length of 4 bytes
PacketNumberLen4 PacketNumberLen = 4
)
// InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number // InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
func InferPacketNumber( func InferPacketNumber(
packetNumberLength PacketNumberLen, packetNumberLength PacketNumberLen,
lastPacketNumber PacketNumber, lastPacketNumber PacketNumber,
wirePacketNumber PacketNumber, wirePacketNumber PacketNumber,
version VersionNumber,
) PacketNumber { ) PacketNumber {
var epochDelta PacketNumber var epochDelta PacketNumber
switch packetNumberLength { switch packetNumberLength {
case PacketNumberLen1: case PacketNumberLen1:
epochDelta = PacketNumber(1) << 7 epochDelta = PacketNumber(1) << 8
case PacketNumberLen2: case PacketNumberLen2:
epochDelta = PacketNumber(1) << 14 epochDelta = PacketNumber(1) << 16
case PacketNumberLen3:
epochDelta = PacketNumber(1) << 24
case PacketNumberLen4: case PacketNumberLen4:
epochDelta = PacketNumber(1) << 30 epochDelta = PacketNumber(1) << 32
} }
epoch := lastPacketNumber & ^(epochDelta - 1) epoch := lastPacketNumber & ^(epochDelta - 1)
prevEpochBegin := epoch - epochDelta prevEpochBegin := epoch - epochDelta
@ -42,11 +59,14 @@ func delta(a, b PacketNumber) PacketNumber {
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header // GetPacketNumberLengthForHeader gets the length of the packet number for the public header
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances // it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen { func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen {
diff := uint64(packetNumber - leastUnacked) diff := uint64(packetNumber - leastUnacked)
if diff < (1 << (14 - 1)) { if diff < (1 << (16 - 1)) {
return PacketNumberLen2 return PacketNumberLen2
} }
if diff < (1 << (24 - 1)) {
return PacketNumberLen3
}
return PacketNumberLen4 return PacketNumberLen4
} }
@ -58,5 +78,8 @@ func GetPacketNumberLength(packetNumber PacketNumber) PacketNumberLen {
if packetNumber < (1 << (uint8(PacketNumberLen2) * 8)) { if packetNumber < (1 << (uint8(PacketNumberLen2) * 8)) {
return PacketNumberLen2 return PacketNumberLen2
} }
if packetNumber < (1 << (uint8(PacketNumberLen3) * 8)) {
return PacketNumberLen3
}
return PacketNumberLen4 return PacketNumberLen4
} }

View File

@ -7,32 +7,18 @@ import (
// A PacketNumber in QUIC // A PacketNumber in QUIC
type PacketNumber uint64 type PacketNumber uint64
// PacketNumberLen is the length of the packet number in bytes
type PacketNumberLen uint8
const (
// PacketNumberLenInvalid is the default value and not a valid length for a packet number
PacketNumberLenInvalid PacketNumberLen = 0
// PacketNumberLen1 is a packet number length of 1 byte
PacketNumberLen1 PacketNumberLen = 1
// PacketNumberLen2 is a packet number length of 2 bytes
PacketNumberLen2 PacketNumberLen = 2
// PacketNumberLen4 is a packet number length of 4 bytes
PacketNumberLen4 PacketNumberLen = 4
)
// The PacketType is the Long Header Type // The PacketType is the Long Header Type
type PacketType uint8 type PacketType uint8
const ( const (
// PacketTypeInitial is the packet type of an Initial packet // PacketTypeInitial is the packet type of an Initial packet
PacketTypeInitial PacketType = 0x7f PacketTypeInitial PacketType = 1 + iota
// PacketTypeRetry is the packet type of a Retry packet // PacketTypeRetry is the packet type of a Retry packet
PacketTypeRetry PacketType = 0x7e PacketTypeRetry
// PacketTypeHandshake is the packet type of a Handshake packet // PacketTypeHandshake is the packet type of a Handshake packet
PacketTypeHandshake PacketType = 0x7d PacketTypeHandshake
// PacketType0RTT is the packet type of a 0-RTT packet // PacketType0RTT is the packet type of a 0-RTT packet
PacketType0RTT PacketType = 0x7c PacketType0RTT
) )
func (t PacketType) String() string { func (t PacketType) String() string {
@ -72,11 +58,5 @@ const DefaultTCPMSS ByteCount = 1460
// MinInitialPacketSize is the minimum size an Initial packet is required to have. // MinInitialPacketSize is the minimum size an Initial packet is required to have.
const MinInitialPacketSize = 1200 const MinInitialPacketSize = 1200
// MaxClientHellos is the maximum number of times we'll send a client hello
// The value 3 accounts for:
// * one failure due to an incorrect or missing source-address token
// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
const MaxClientHellos = 3
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet. // MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
const MinConnectionIDLenInitial = 8 const MinConnectionIDLenInitial = 8

View File

@ -8,11 +8,10 @@ import (
// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers. // A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers.
type ByteOrder interface { type ByteOrder interface {
ReadUintN(b io.ByteReader, length uint8) (uint64, error) ReadUintN(b io.ByteReader, length uint8) (uint64, error)
ReadUint64(io.ByteReader) (uint64, error)
ReadUint32(io.ByteReader) (uint32, error) ReadUint32(io.ByteReader) (uint32, error)
ReadUint16(io.ByteReader) (uint16, error) ReadUint16(io.ByteReader) (uint16, error)
WriteUint64(*bytes.Buffer, uint64) WriteUintN(b *bytes.Buffer, length uint8, value uint64)
WriteUint32(*bytes.Buffer, uint32) WriteUint32(*bytes.Buffer, uint32)
WriteUint16(*bytes.Buffer, uint16) WriteUint16(*bytes.Buffer, uint16)
} }

View File

@ -25,37 +25,6 @@ func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
return res, nil return res, nil
} }
// ReadUint64 reads a uint64
func (bigEndian) ReadUint64(b io.ByteReader) (uint64, error) {
var b1, b2, b3, b4, b5, b6, b7, b8 uint8
var err error
if b8, err = b.ReadByte(); err != nil {
return 0, err
}
if b7, err = b.ReadByte(); err != nil {
return 0, err
}
if b6, err = b.ReadByte(); err != nil {
return 0, err
}
if b5, err = b.ReadByte(); err != nil {
return 0, err
}
if b4, err = b.ReadByte(); err != nil {
return 0, err
}
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil
}
// ReadUint32 reads a uint32 // ReadUint32 reads a uint32
func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) { func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) {
var b1, b2, b3, b4 uint8 var b1, b2, b3, b4 uint8
@ -88,12 +57,10 @@ func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) {
return uint16(b1) + uint16(b2)<<8, nil return uint16(b1) + uint16(b2)<<8, nil
} }
// WriteUint64 writes a uint64 func (bigEndian) WriteUintN(b *bytes.Buffer, length uint8, i uint64) {
func (bigEndian) WriteUint64(b *bytes.Buffer, i uint64) { for j := length; j > 0; j-- {
b.Write([]byte{ b.WriteByte(uint8(i >> (8 * (j - 1))))
uint8(i >> 56), uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), }
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
})
} }
// WriteUint32 writes a uint32 // WriteUint32 writes a uint32

View File

@ -0,0 +1,205 @@
package wire
import (
"bytes"
"errors"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// ExtendedHeader is the header of a QUIC packet.
type ExtendedHeader struct {
Header
typeByte byte
Raw []byte
PacketNumberLen protocol.PacketNumberLen
PacketNumber protocol.PacketNumber
KeyPhase int
}
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
// read the (now unencrypted) first byte
var err error
h.typeByte, err = b.ReadByte()
if err != nil {
return nil, err
}
if _, err := b.Seek(int64(h.len)-1, io.SeekCurrent); err != nil {
return nil, err
}
if h.IsLongHeader {
return h.parseLongHeader(b, v)
}
return h.parseShortHeader(b, v)
}
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
if h.typeByte&0xc != 0 {
return nil, errors.New("5th and 6th bit must be 0")
}
if err := h.readPacketNumber(b); err != nil {
return nil, err
}
return h, nil
}
func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
if h.typeByte&0x18 != 0 {
return nil, errors.New("4th and 5th bit must be 0")
}
h.KeyPhase = int(h.typeByte&0x4) >> 2
if err := h.readPacketNumber(b); err != nil {
return nil, err
}
return h, nil
}
func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
pn, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen))
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(pn)
return nil
}
// Write writes the Header.
func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
if h.IsLongHeader {
return h.writeLongHeader(b, ver)
}
return h.writeShortHeader(b, ver)
}
func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
var packetType uint8
switch h.Type {
case protocol.PacketTypeInitial:
packetType = 0x0
case protocol.PacketType0RTT:
packetType = 0x1
case protocol.PacketTypeHandshake:
packetType = 0x2
case protocol.PacketTypeRetry:
packetType = 0x3
}
firstByte := 0xc0 | packetType<<4
if h.Type == protocol.PacketTypeRetry {
odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
if err != nil {
return err
}
firstByte |= odcil
} else { // Retry packets don't have a packet number
firstByte |= uint8(h.PacketNumberLen - 1)
}
b.WriteByte(firstByte)
utils.BigEndian.WriteUint32(b, uint32(h.Version))
connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
if err != nil {
return err
}
b.WriteByte(connIDLen)
b.Write(h.DestConnectionID.Bytes())
b.Write(h.SrcConnectionID.Bytes())
switch h.Type {
case protocol.PacketTypeRetry:
b.Write(h.OrigDestConnectionID.Bytes())
b.Write(h.Token)
return nil
case protocol.PacketTypeInitial:
utils.WriteVarInt(b, uint64(len(h.Token)))
b.Write(h.Token)
}
utils.WriteVarInt(b, uint64(h.Length))
return h.writePacketNumber(b)
}
// TODO: add support for the key phase
func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
typeByte := 0x40 | uint8(h.PacketNumberLen-1)
typeByte |= byte(h.KeyPhase << 2)
b.WriteByte(typeByte)
b.Write(h.DestConnectionID.Bytes())
return h.writePacketNumber(b)
}
func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error {
if h.PacketNumberLen == protocol.PacketNumberLenInvalid || h.PacketNumberLen > protocol.PacketNumberLen4 {
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
}
utils.BigEndian.WriteUintN(b, uint8(h.PacketNumberLen), uint64(h.PacketNumber))
return nil
}
// GetLength determines the length of the Header.
func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount {
if h.IsLongHeader {
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length))
if h.Type == protocol.PacketTypeInitial {
length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
}
return length
}
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
length += protocol.ByteCount(h.PacketNumberLen)
return length
}
// Log logs the Header
func (h *ExtendedHeader) Log(logger utils.Logger) {
if h.IsLongHeader {
var token string
if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
if len(h.Token) == 0 {
token = "Token: (empty), "
} else {
token = fmt.Sprintf("Token: %#x, ", h.Token)
}
if h.Type == protocol.PacketTypeRetry {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
return
}
}
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
} else {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
}
}
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
dcil, err := encodeSingleConnIDLen(dest)
if err != nil {
return 0, err
}
scil, err := encodeSingleConnIDLen(src)
if err != nil {
return 0, err
}
return scil | dcil<<4, nil
}
func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
len := id.Len()
if len == 0 {
return 0, nil
}
if len < 4 || len > 18 {
return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
}
return byte(len - 3), nil
}

View File

@ -2,150 +2,183 @@ package wire
import ( import (
"bytes" "bytes"
"crypto/rand" "errors"
"fmt" "io"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
// Header is the header of a QUIC packet. // The Header is the version independent part of the header
type Header struct { type Header struct {
Raw []byte Version protocol.VersionNumber
SrcConnectionID protocol.ConnectionID
DestConnectionID protocol.ConnectionID
Version protocol.VersionNumber
DestConnectionID protocol.ConnectionID
SrcConnectionID protocol.ConnectionID
OrigDestConnectionID protocol.ConnectionID // only needed in the Retry packet
PacketNumberLen protocol.PacketNumberLen
PacketNumber protocol.PacketNumber
IsVersionNegotiation bool
SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server
Type protocol.PacketType
IsLongHeader bool IsLongHeader bool
KeyPhase int Type protocol.PacketType
Length protocol.ByteCount Length protocol.ByteCount
Token []byte
Token []byte
SupportedVersions []protocol.VersionNumber // sent in a Version Negotiation Packet
OrigDestConnectionID protocol.ConnectionID // sent in the Retry packet
typeByte byte
len int // how many bytes were read while parsing this header
} }
// Write writes the Header. // ParseHeader parses the header.
func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error { // For short header packets: up to the packet number.
if h.IsLongHeader { // For long header packets:
return h.writeLongHeader(b, ver) // * if we understand the version: up to the packet number
// * if not, only the invariant part of the header
func ParseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
startLen := b.Len()
h, err := parseHeaderImpl(b, shortHeaderConnIDLen)
if err != nil {
return nil, err
} }
return h.writeShortHeader(b, ver) h.len = startLen - b.Len()
return h, nil
} }
// TODO: add support for the key phase func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error { typeByte, err := b.ReadByte()
b.WriteByte(byte(0x80 | h.Type)) if err != nil {
utils.BigEndian.WriteUint32(b, uint32(h.Version)) return nil, err
connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID) }
h := &Header{
typeByte: typeByte,
IsLongHeader: typeByte&0x80 > 0,
}
if !h.IsLongHeader {
if h.typeByte&0x40 == 0 {
return nil, errors.New("not a QUIC packet")
}
if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil {
return nil, err
}
return h, nil
}
if err := h.parseLongHeader(b); err != nil {
return nil, err
}
return h, nil
}
func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error {
var err error
h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
return err
}
func (h *Header) parseLongHeader(b *bytes.Reader) error {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil { if err != nil {
return err return err
} }
b.WriteByte(connIDLen) h.Version = protocol.VersionNumber(v)
b.Write(h.DestConnectionID.Bytes()) if !h.IsVersionNegotiation() && h.typeByte&0x40 == 0 {
b.Write(h.SrcConnectionID.Bytes()) return errors.New("not a QUIC packet")
if h.Type == protocol.PacketTypeInitial {
utils.WriteVarInt(b, uint64(len(h.Token)))
b.Write(h.Token)
} }
connIDLenByte, err := b.ReadByte()
if h.Type == protocol.PacketTypeRetry { if err != nil {
odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID) return err
if err != nil { }
return err dcil, scil := decodeConnIDLen(connIDLenByte)
} h.DestConnectionID, err = protocol.ReadConnectionID(b, dcil)
// randomize the first 4 bits if err != nil {
odcilByte := make([]byte, 1) return err
_, _ = rand.Read(odcilByte) // it's safe to ignore the error here }
odcilByte[0] = (odcilByte[0] & 0xf0) | odcil h.SrcConnectionID, err = protocol.ReadConnectionID(b, scil)
b.Write(odcilByte) if err != nil {
b.Write(h.OrigDestConnectionID.Bytes()) return err
b.Write(h.Token) }
if h.Version == 0 {
return h.parseVersionNegotiationPacket(b)
}
// If we don't understand the version, we have no idea how to interpret the rest of the bytes
if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
return nil return nil
} }
utils.WriteVarInt(b, uint64(h.Length)) switch (h.typeByte & 0x30) >> 4 {
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen) case 0x0:
} h.Type = protocol.PacketTypeInitial
case 0x1:
h.Type = protocol.PacketType0RTT
case 0x2:
h.Type = protocol.PacketTypeHandshake
case 0x3:
h.Type = protocol.PacketTypeRetry
}
func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error { if h.Type == protocol.PacketTypeRetry {
typeByte := byte(0x30) odcil := decodeSingleConnIDLen(h.typeByte & 0xf)
typeByte |= byte(h.KeyPhase << 6) h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil)
if err != nil {
b.WriteByte(typeByte) return err
b.Write(h.DestConnectionID.Bytes())
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
// GetLength determines the length of the Header.
func (h *Header) GetLength(v protocol.VersionNumber) protocol.ByteCount {
if h.IsLongHeader {
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length))
if h.Type == protocol.PacketTypeInitial {
length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
} }
return length h.Token = make([]byte, b.Len())
} if _, err := io.ReadFull(b, h.Token); err != nil {
return err
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
length += protocol.ByteCount(h.PacketNumberLen)
return length
}
// Log logs the Header
func (h *Header) Log(logger utils.Logger) {
if h.IsLongHeader {
if h.Version == 0 {
logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
} else {
var token string
if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
if len(h.Token) == 0 {
token = "Token: (empty), "
} else {
token = fmt.Sprintf("Token: %#x, ", h.Token)
}
}
if h.Type == protocol.PacketTypeRetry {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
return
}
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
} }
} else { return nil
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
} }
if h.Type == protocol.PacketTypeInitial {
tokenLen, err := utils.ReadVarInt(b)
if err != nil {
return err
}
if tokenLen > uint64(b.Len()) {
return io.EOF
}
h.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, h.Token); err != nil {
return err
}
}
pl, err := utils.ReadVarInt(b)
if err != nil {
return err
}
h.Length = protocol.ByteCount(pl)
return nil
} }
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) { func (h *Header) parseVersionNegotiationPacket(b *bytes.Reader) error {
dcil, err := encodeSingleConnIDLen(dest) if b.Len() == 0 {
if err != nil { return qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
return 0, err
} }
scil, err := encodeSingleConnIDLen(src) h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4)
if err != nil { for i := 0; b.Len() > 0; i++ {
return 0, err v, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return qerr.InvalidVersionNegotiationPacket
}
h.SupportedVersions[i] = protocol.VersionNumber(v)
} }
return scil | dcil<<4, nil return nil
} }
func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) { // IsVersionNegotiation says if this a version negotiation packet
len := id.Len() func (h *Header) IsVersionNegotiation() bool {
if len == 0 { return h.IsLongHeader && h.Version == 0
return 0, nil }
}
if len < 4 || len > 18 { // ParseExtended parses the version dependent part of the header.
return 0, fmt.Errorf("invalid connection ID length: %d bytes", len) // The Reader has to be set such that it points to the first byte of the header.
} func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
return byte(len - 3), nil return h.toExtendedHeader().parse(b, ver)
}
func (h *Header) toExtendedHeader() *ExtendedHeader {
return &ExtendedHeader{Header: *h}
} }
func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) { func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {

View File

@ -162,75 +162,46 @@ func (h *packetHandlerMap) listen() {
} }
func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error { func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
rcvTime := time.Now()
r := bytes.NewReader(data) r := bytes.NewReader(data)
iHdr, err := wire.ParseInvariantHeader(r, h.connIDLen) hdr, err := wire.ParseHeader(r, h.connIDLen)
// drop the packet if we can't parse the header // drop the packet if we can't parse the header
if err != nil {
return fmt.Errorf("error parsing invariant header: %s", err)
}
h.mutex.RLock()
handlerEntry, handlerFound := h.handlers[string(iHdr.DestConnectionID)]
server := h.server
var sentBy protocol.Perspective
var version protocol.VersionNumber
var handlePacket func(*receivedPacket)
if handlerFound { // existing session
handler := handlerEntry.handler
sentBy = handler.GetPerspective().Opposite()
version = handler.GetVersion()
handlePacket = handler.handlePacket
} else { // no session found
// this might be a stateless reset
if !iHdr.IsLongHeader {
if len(data) >= protocol.MinStatelessResetSize {
var token [16]byte
copy(token[:], data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok {
h.mutex.RUnlock()
sess.destroy(errors.New("received a stateless reset"))
return nil
}
}
// TODO(#943): send a stateless reset
return fmt.Errorf("received a short header packet with an unexpected connection ID %s", iHdr.DestConnectionID)
}
if server == nil { // no server set
h.mutex.RUnlock()
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
}
handlePacket = server.handlePacket
sentBy = protocol.PerspectiveClient
version = iHdr.Version
}
h.mutex.RUnlock()
hdr, err := iHdr.Parse(r, sentBy, version)
if err != nil { if err != nil {
return fmt.Errorf("error parsing header: %s", err) return fmt.Errorf("error parsing header: %s", err)
} }
hdr.Raw = data[:len(data)-r.Len()]
packetData := data[len(data)-r.Len():]
if hdr.IsLongHeader { p := &receivedPacket{
if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) { remoteAddr: addr,
return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen) hdr: hdr,
} data: data,
if protocol.ByteCount(len(packetData))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length { rcvTime: time.Now(),
return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(packetData)+int(hdr.PacketNumberLen), hdr.Length)
}
packetData = packetData[:int(hdr.Length)-int(hdr.PacketNumberLen)]
// TODO(#1312): implement parsing of compound packets
} }
handlePacket(&receivedPacket{ h.mutex.RLock()
remoteAddr: addr, defer h.mutex.RUnlock()
header: hdr,
data: packetData, handlerEntry, handlerFound := h.handlers[string(hdr.DestConnectionID)]
rcvTime: rcvTime,
}) if handlerFound { // existing session
handlerEntry.handler.handlePacket(p)
return nil
}
// No session found.
// This might be a stateless reset.
if !hdr.IsLongHeader {
if len(data) >= protocol.MinStatelessResetSize {
var token [16]byte
copy(token[:], data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok {
sess.destroy(errors.New("received a stateless reset"))
return nil
}
}
// TODO(#943): send a stateless reset
return fmt.Errorf("received a short header packet with an unexpected connection ID %s", hdr.DestConnectionID)
}
if h.server == nil { // no server set
return fmt.Errorf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID)
}
h.server.handlePacket(p)
return nil return nil
} }

View File

@ -25,7 +25,7 @@ type packer interface {
} }
type packedPacket struct { type packedPacket struct {
header *wire.Header header *wire.ExtendedHeader
raw []byte raw []byte
frames []wire.Frame frames []wire.Frame
encryptionLevel protocol.EncryptionLevel encryptionLevel protocol.EncryptionLevel
@ -397,14 +397,13 @@ func (p *packetPacker) composeNextPacket(
return frames, nil return frames, nil
} }
func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header { func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader {
pn, pnLen := p.pnManager.PeekPacketNumber() pn, pnLen := p.pnManager.PeekPacketNumber()
header := &wire.Header{ header := &wire.ExtendedHeader{}
PacketNumber: pn, header.PacketNumber = pn
PacketNumberLen: pnLen, header.PacketNumberLen = pnLen
Version: p.version, header.Version = p.version
DestConnectionID: p.destConnID, header.DestConnectionID = p.destConnID
}
if encLevel != protocol.Encryption1RTT { if encLevel != protocol.Encryption1RTT {
header.IsLongHeader = true header.IsLongHeader = true
@ -424,8 +423,7 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
} }
func (p *packetPacker) writeAndSealPacket( func (p *packetPacker) writeAndSealPacket(
header *wire.Header, header *wire.ExtendedHeader, frames []wire.Frame,
frames []wire.Frame,
sealer handshake.Sealer, sealer handshake.Sealer,
) ([]byte, error) { ) ([]byte, error) {
raw := *getPacketBuffer() raw := *getPacketBuffer()
@ -450,7 +448,7 @@ func (p *packetPacker) writeAndSealPacket(
} }
} }
if err := header.Write(buffer, p.perspective, p.version); err != nil { if err := header.Write(buffer, p.version); err != nil {
return nil, err return nil, err
} }
payloadStartIndex := buffer.Len() payloadStartIndex := buffer.Len()

View File

@ -35,7 +35,7 @@ func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker {
} }
} }
func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) { func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
buf := *getPacketBuffer() buf := *getPacketBuffer()
buf = buf[:0] buf = buf[:0]
defer putPacketBuffer(&buf) defer putPacketBuffer(&buf)

View File

@ -21,7 +21,6 @@ type packetHandler interface {
handlePacket(*receivedPacket) handlePacket(*receivedPacket)
io.Closer io.Closer
destroy(error) destroy(error)
GetVersion() protocol.VersionNumber
GetPerspective() protocol.Perspective GetPerspective() protocol.Perspective
} }
@ -99,7 +98,8 @@ var _ Listener = &server{}
var _ unknownPacketHandler = &server{} var _ unknownPacketHandler = &server{}
// ListenAddr creates a QUIC server listening on a given address. // ListenAddr creates a QUIC server listening on a given address.
// The tls.Config must not be nil, the quic.Config may be nil. // The tls.Config must not be nil and must contain a certificate configuration.
// The quic.Config may be nil, in that case the default values will be used.
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) { func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr) udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil { if err != nil {
@ -118,7 +118,11 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err
} }
// Listen listens for QUIC connections on a given net.PacketConn. // Listen listens for QUIC connections on a given net.PacketConn.
// The tls.Config must not be nil, the quic.Config may be nil. // A single PacketConn only be used for a single call to Listen.
// The PacketConn can be used for simultaneous calls to Dial.
// QUIC connection IDs are used for demultiplexing the different connections.
// The tls.Config must not be nil and must contain a certificate configuration.
// The quic.Config may be nil, in that case the default values will be used.
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
return listen(conn, tlsConf, config) return listen(conn, tlsConf, config)
} }
@ -300,23 +304,17 @@ func (s *server) Addr() net.Addr {
} }
func (s *server) handlePacket(p *receivedPacket) { func (s *server) handlePacket(p *receivedPacket) {
if err := s.handlePacketImpl(p); err != nil { hdr := p.hdr
s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
}
}
func (s *server) handlePacketImpl(p *receivedPacket) error {
hdr := p.header
// send a Version Negotiation Packet if the client is speaking a different protocol version // send a Version Negotiation Packet if the client is speaking a different protocol version
if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
return s.sendVersionNegotiationPacket(p) go s.sendVersionNegotiationPacket(p)
return
} }
if hdr.Type == protocol.PacketTypeInitial { if hdr.Type == protocol.PacketTypeInitial {
go s.handleInitial(p) go s.handleInitial(p)
} }
// TODO(#943): send Stateless Reset // TODO(#943): send Stateless Reset
return nil
} }
func (s *server) handleInitial(p *receivedPacket) { func (s *server) handleInitial(p *receivedPacket) {
@ -335,11 +333,11 @@ func (s *server) handleInitial(p *receivedPacket) {
} }
func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) { func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) {
hdr := p.header hdr := p.hdr
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
return nil, nil, errors.New("dropping Initial packet with too short connection ID") return nil, nil, errors.New("dropping Initial packet with too short connection ID")
} }
if len(hdr.Raw)+len(p.data) < protocol.MinInitialPacketSize { if len(p.data) < protocol.MinInitialPacketSize {
return nil, nil, errors.New("dropping too small Initial packet") return nil, nil, errors.New("dropping too small Initial packet")
} }
@ -358,7 +356,7 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con
if !s.config.AcceptCookie(p.remoteAddr, cookie) { if !s.config.AcceptCookie(p.remoteAddr, cookie) {
// Log the Initial packet now. // Log the Initial packet now.
// If no Retry is sent, the packet will be logged by the session. // If no Retry is sent, the packet will be logged by the session.
p.header.Log(s.logger) (&wire.ExtendedHeader{Header: *p.hdr}).Log(s.logger)
return nil, nil, s.sendRetry(p.remoteAddr, hdr) return nil, nil, s.sendRetry(p.remoteAddr, hdr)
} }
@ -431,19 +429,18 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
if err != nil { if err != nil {
return err return err
} }
replyHdr := &wire.Header{ replyHdr := &wire.ExtendedHeader{}
IsLongHeader: true, replyHdr.IsLongHeader = true
Type: protocol.PacketTypeRetry, replyHdr.Type = protocol.PacketTypeRetry
Version: hdr.Version, replyHdr.Version = hdr.Version
SrcConnectionID: connID, replyHdr.SrcConnectionID = connID
DestConnectionID: hdr.SrcConnectionID, replyHdr.DestConnectionID = hdr.SrcConnectionID
OrigDestConnectionID: hdr.DestConnectionID, replyHdr.OrigDestConnectionID = hdr.DestConnectionID
Token: token, replyHdr.Token = token
}
s.logger.Debugf("Changing connection ID to %s.\n-> Sending Retry", connID) s.logger.Debugf("Changing connection ID to %s.\n-> Sending Retry", connID)
replyHdr.Log(s.logger) replyHdr.Log(s.logger)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
if err := replyHdr.Write(buf, protocol.PerspectiveServer, hdr.Version); err != nil { if err := replyHdr.Write(buf, hdr.Version); err != nil {
return err return err
} }
if _, err := s.conn.WriteTo(buf.Bytes(), remoteAddr); err != nil { if _, err := s.conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
@ -452,14 +449,15 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
return nil return nil
} }
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error { func (s *server) sendVersionNegotiationPacket(p *receivedPacket) {
hdr := p.header hdr := p.hdr
s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
if err != nil { if err != nil {
return err s.logger.Debugf("Error composing Version Negotiation: %s", err)
return
}
if _, err := s.conn.WriteTo(data, p.remoteAddr); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err)
} }
_, err = s.conn.WriteTo(data, p.remoteAddr)
return err
} }

View File

@ -32,7 +32,7 @@ func (s *serverSession) handlePacket(p *receivedPacket) {
} }
func (s *serverSession) handlePacketImpl(p *receivedPacket) error { func (s *serverSession) handlePacketImpl(p *receivedPacket) error {
hdr := p.header hdr := p.hdr
// Probably an old packet that was sent by the client before the version was negotiated. // Probably an old packet that was sent by the client before the version was negotiated.
// It is safe to drop it. // It is safe to drop it.

View File

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
@ -21,7 +22,7 @@ import (
) )
type unpacker interface { type unpacker interface {
Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error)
} }
type streamGetter interface { type streamGetter interface {
@ -52,7 +53,7 @@ type cryptoStreamHandler interface {
type receivedPacket struct { type receivedPacket struct {
remoteAddr net.Addr remoteAddr net.Addr
header *wire.Header hdr *wire.Header
data []byte data []byte
rcvTime time.Time rcvTime time.Time
} }
@ -113,7 +114,6 @@ type session struct {
receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this
receivedFirstForwardSecurePacket bool receivedFirstForwardSecurePacket bool
lastRcvdPacketNumber protocol.PacketNumber
// Used to calculate the next packet number from the truncated wire // Used to calculate the next packet number from the truncated wire
// representation, and sent back in public reset packets // representation, and sent back in public reset packets
largestRcvdPacketNumber protocol.PacketNumber largestRcvdPacketNumber protocol.PacketNumber
@ -289,7 +289,7 @@ var newClientSession = func(
func (s *session) preSetup() { func (s *session) preSetup() {
s.rttStats = &congestion.RTTStats{} s.rttStats = &congestion.RTTStats{}
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger, s.version) s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
s.connFlowController = flowcontrol.NewConnectionFlowController( s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.InitialMaxData, protocol.InitialMaxData,
@ -374,7 +374,7 @@ runLoop:
} }
// This is a bit unclean, but works properly, since the packet always // This is a bit unclean, but works properly, since the packet always
// begins with the public header and we never copy it. // begins with the public header and we never copy it.
putPacketBuffer(&p.header.Raw) // TODO: putPacketBuffer(&p.extHdr.Raw)
case <-s.handshakeCompleteChan: case <-s.handshakeCompleteChan:
s.handleHandshakeComplete() s.handleHandshakeComplete()
} }
@ -479,24 +479,41 @@ func (s *session) handleHandshakeComplete() {
} }
func (s *session) handlePacketImpl(p *receivedPacket) error { func (s *session) handlePacketImpl(p *receivedPacket) error {
hdr := p.header
// The server can change the source connection ID with the first Handshake packet. // The server can change the source connection ID with the first Handshake packet.
// After this, all packets with a different source connection have to be ignored. // After this, all packets with a different source connection have to be ignored.
if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) { if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) {
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.header.SrcConnectionID, s.destConnID) s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.hdr.SrcConnectionID, s.destConnID)
return nil return nil
} }
p.rcvTime = time.Now() data := p.data
r := bytes.NewReader(data)
hdr, err := p.hdr.ParseExtended(r, s.version)
if err != nil {
return fmt.Errorf("error parsing extended header: %s", err)
}
hdr.Raw = data[:len(data)-r.Len()]
data = data[len(data)-r.Len():]
if hdr.IsLongHeader {
if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) {
return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen)
}
if protocol.ByteCount(len(data))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length {
return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(hdr.PacketNumberLen), hdr.Length)
}
data = data[:int(hdr.Length)-int(hdr.PacketNumberLen)]
// TODO(#1312): implement parsing of compound packets
}
// Calculate packet number // Calculate packet number
hdr.PacketNumber = protocol.InferPacketNumber( hdr.PacketNumber = protocol.InferPacketNumber(
hdr.PacketNumberLen, hdr.PacketNumberLen,
s.largestRcvdPacketNumber, s.largestRcvdPacketNumber,
hdr.PacketNumber, hdr.PacketNumber,
s.version,
) )
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, p.data) packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)
if s.logger.Debug() { if s.logger.Debug() {
if err != nil { if err != nil {
s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID) s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID)
@ -530,7 +547,6 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
} }
} }
s.lastRcvdPacketNumber = hdr.PacketNumber
// Only do this after decrypting, so we are sure the packet is not attacker-controlled // Only do this after decrypting, so we are sure the packet is not attacker-controlled
s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber) s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber)
@ -543,10 +559,10 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
} }
} }
return s.handleFrames(packet.frames, packet.encryptionLevel) return s.handleFrames(packet.frames, hdr.PacketNumber, packet.encryptionLevel)
} }
func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLevel) error { func (s *session) handleFrames(fs []wire.Frame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error {
for _, ff := range fs { for _, ff := range fs {
var err error var err error
wire.LogFrame(s.logger, ff, false) wire.LogFrame(s.logger, ff, false)
@ -556,7 +572,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
case *wire.StreamFrame: case *wire.StreamFrame:
err = s.handleStreamFrame(frame, encLevel) err = s.handleStreamFrame(frame, encLevel)
case *wire.AckFrame: case *wire.AckFrame:
err = s.handleAckFrame(frame, encLevel) err = s.handleAckFrame(frame, pn, encLevel)
case *wire.ConnectionCloseFrame: case *wire.ConnectionCloseFrame:
s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase)) s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase))
case *wire.ResetStreamFrame: case *wire.ResetStreamFrame:
@ -702,8 +718,8 @@ func (s *session) handlePathChallengeFrame(frame *wire.PathChallengeFrame) {
s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data}) s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data})
} }
func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { func (s *session) handleAckFrame(frame *wire.AckFrame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error {
if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil { if err := s.sentPacketHandler.ReceivedAck(frame, pn, encLevel, s.lastNetworkActivityTime); err != nil {
return err return err
} }
s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked()) s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked())
@ -1065,14 +1081,14 @@ func (s *session) scheduleSending() {
func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) {
if s.handshakeComplete { if s.handshakeComplete {
s.logger.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.header, len(p.data)) s.logger.Debugf("Received undecryptable packet from %s after the handshake (%d bytes)", p.remoteAddr.String(), len(p.data))
return return
} }
if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets {
s.logger.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.header.PacketNumber) s.logger.Infof("Dropping undecrytable packet (%d bytes). Undecryptable packet queue full.", len(p.data))
return return
} }
s.logger.Infof("Queueing packet 0x%x for later decryption", p.header.PacketNumber) s.logger.Infof("Queueing packet (%d bytes) for later decryption", len(p.data))
s.undecryptablePackets = append(s.undecryptablePackets, p) s.undecryptablePackets = append(s.undecryptablePackets, p)
} }