1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-21 17:46:58 -05:00

sync quic lib

This commit is contained in:
Darien Raymond 2019-01-16 11:48:15 +01:00
parent 35432832c4
commit 1cf07c3379
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
14 changed files with 353 additions and 214 deletions

View File

@ -27,6 +27,7 @@ type SentPacketHandler interface {
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
ShouldSendNumPackets() int
// only to be called once the handshake is complete
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
DequeuePacketForRetransmission() *Packet
DequeueProbePacket() (*Packet, error)
@ -40,9 +41,9 @@ type SentPacketHandler interface {
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
ReceivedPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error
IgnoreBelow(protocol.PacketNumber)
GetAlarmTimeout() time.Time
GetAckFrame() *wire.AckFrame
GetAckFrame(protocol.EncryptionLevel) *wire.AckFrame
}

View File

@ -1,6 +1,7 @@
package ackhandler
import (
"fmt"
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
@ -9,27 +10,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
)
type receivedPacketHandler struct {
largestObserved protocol.PacketNumber
ignoreBelow protocol.PacketNumber
largestObservedReceivedTime time.Time
packetHistory *receivedPacketHistory
ackSendDelay time.Duration
rttStats *congestion.RTTStats
packetsReceivedSinceLastAck int
retransmittablePacketsReceivedSinceLastAck int
ackQueued bool
ackAlarm time.Time
lastAck *wire.AckFrame
logger utils.Logger
version protocol.VersionNumber
}
const (
// maximum delay that can be applied to an ACK for a retransmittable packet
ackSendDelay = 25 * time.Millisecond
@ -53,6 +33,14 @@ const (
maxPacketsAfterNewMissing = 4
)
type receivedPacketHandler struct {
initialPackets *receivedPacketTracker
handshakePackets *receivedPacketTracker
oneRTTPackets *receivedPacketTracker
}
var _ ReceivedPacketHandler = &receivedPacketHandler{}
// NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler(
rttStats *congestion.RTTStats,
@ -60,156 +48,51 @@ func NewReceivedPacketHandler(
version protocol.VersionNumber,
) ReceivedPacketHandler {
return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(),
ackSendDelay: ackSendDelay,
rttStats: rttStats,
logger: logger,
version: version,
initialPackets: newReceivedPacketTracker(rttStats, logger, version),
handshakePackets: newReceivedPacketTracker(rttStats, logger, version),
oneRTTPackets: newReceivedPacketTracker(rttStats, logger, version),
}
}
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
if packetNumber < h.ignoreBelow {
func (h *receivedPacketHandler) ReceivedPacket(
pn protocol.PacketNumber,
encLevel protocol.EncryptionLevel,
rcvTime time.Time,
shouldInstigateAck bool,
) error {
switch encLevel {
case protocol.EncryptionInitial:
return h.initialPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck)
case protocol.EncryptionHandshake:
return h.handshakePackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck)
case protocol.Encryption1RTT:
return h.oneRTTPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck)
default:
return fmt.Errorf("received packet with unknown encryption level: %s", encLevel)
}
}
// only to be used with 1-RTT packets
func (h *receivedPacketHandler) IgnoreBelow(pn protocol.PacketNumber) {
h.oneRTTPackets.IgnoreBelow(pn)
}
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
initialAlarm := h.initialPackets.GetAlarmTimeout()
handshakeAlarm := h.handshakePackets.GetAlarmTimeout()
oneRTTAlarm := h.oneRTTPackets.GetAlarmTimeout()
return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm)
}
func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel) *wire.AckFrame {
switch encLevel {
case protocol.EncryptionInitial:
return h.initialPackets.GetAckFrame()
case protocol.EncryptionHandshake:
return h.handshakePackets.GetAckFrame()
case protocol.Encryption1RTT:
return h.oneRTTPackets.GetAckFrame()
default:
return nil
}
isMissing := h.isMissing(packetNumber)
if packetNumber >= h.largestObserved {
h.largestObserved = packetNumber
h.largestObservedReceivedTime = rcvTime
}
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err
}
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing)
return nil
}
// IgnoreBelow sets a lower limit for acking packets.
// Packets with packet numbers smaller than p will not be acked.
func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
if p <= h.ignoreBelow {
return
}
h.ignoreBelow = p
h.packetHistory.DeleteBelow(p)
if h.logger.Debug() {
h.logger.Debugf("\tIgnoring all packets below %#x.", p)
}
}
// isMissing says if a packet was reported missing in the last ACK.
func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool {
if h.lastAck == nil || p < h.ignoreBelow {
return false
}
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
}
func (h *receivedPacketHandler) hasNewMissingPackets() bool {
if h.lastAck == nil {
return false
}
highestRange := h.packetHistory.GetHighestAckRange()
return highestRange.Smallest >= h.lastAck.LargestAcked() && highestRange.Len() <= maxPacketsAfterNewMissing
}
// maybeQueueAck queues an ACK, if necessary.
// It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck()
// in ACK_DECIMATION_WITH_REORDERING mode.
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) {
h.packetsReceivedSinceLastAck++
// always ack the first packet
if h.lastAck == nil {
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
h.ackQueued = true
return
}
// Send an ACK if this packet was reported missing in an ACK sent before.
// Ack decimation with reordering relies on the timer to send an ACK, but if
// missing packets we reported in the previous ack, send an ACK immediately.
if wasMissing {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %#x was missing before.", packetNumber)
}
h.ackQueued = true
}
if !h.ackQueued && shouldInstigateAck {
h.retransmittablePacketsReceivedSinceLastAck++
if packetNumber > minReceivedBeforeAckDecimation {
// ack up to 10 packets at once
if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
h.ackQueued = true
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, retransmittablePacketsBeforeAck)
}
} else if h.ackAlarm.IsZero() {
// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
h.ackAlarm = rcvTime.Add(ackDelay)
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to min(1/4 min-RTT, max ack delay): %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
}
}
} else {
// send an ACK every 2 retransmittable packets
if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, initialRetransmittablePacketsBeforeAck)
}
h.ackQueued = true
} else if h.ackAlarm.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", ackSendDelay)
}
h.ackAlarm = rcvTime.Add(ackSendDelay)
}
}
// If there are new missing packets to report, set a short timer to send an ACK.
if h.hasNewMissingPackets() {
// wait the minimum of 1/8 min RTT and the existing ack time
ackDelay := time.Duration(float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay))
ackTime := rcvTime.Add(ackDelay)
if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
h.ackAlarm = ackTime
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to 1/8 min-RTT: %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
}
}
}
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
}
}
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
now := time.Now()
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
return nil
}
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
h.logger.Debugf("Sending ACK because the ACK timer expired.")
}
ack := &wire.AckFrame{
AckRanges: h.packetHistory.GetAckRanges(),
DelayTime: now.Sub(h.largestObservedReceivedTime),
}
h.lastAck = ack
h.ackAlarm = time.Time{}
h.ackQueued = false
h.packetsReceivedSinceLastAck = 0
h.retransmittablePacketsReceivedSinceLastAck = 0
return ack
}
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }

View File

@ -8,6 +8,7 @@ import (
)
// The receivedPacketHistory stores if a packet number has already been received.
// It generates ACK ranges which can be used to assemble an ACK frame.
// It does not store packet contents.
type receivedPacketHistory struct {
ranges *utils.PacketIntervalList

View File

@ -0,0 +1,191 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type receivedPacketTracker struct {
largestObserved protocol.PacketNumber
ignoreBelow protocol.PacketNumber
largestObservedReceivedTime time.Time
packetHistory *receivedPacketHistory
ackSendDelay time.Duration
rttStats *congestion.RTTStats
packetsReceivedSinceLastAck int
retransmittablePacketsReceivedSinceLastAck int
ackQueued bool
ackAlarm time.Time
lastAck *wire.AckFrame
logger utils.Logger
version protocol.VersionNumber
}
func newReceivedPacketTracker(
rttStats *congestion.RTTStats,
logger utils.Logger,
version protocol.VersionNumber,
) *receivedPacketTracker {
return &receivedPacketTracker{
packetHistory: newReceivedPacketHistory(),
ackSendDelay: ackSendDelay,
rttStats: rttStats,
logger: logger,
version: version,
}
}
func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
if packetNumber < h.ignoreBelow {
return nil
}
isMissing := h.isMissing(packetNumber)
if packetNumber >= h.largestObserved {
h.largestObserved = packetNumber
h.largestObservedReceivedTime = rcvTime
}
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err
}
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing)
return nil
}
// IgnoreBelow sets a lower limit for acking packets.
// Packets with packet numbers smaller than p will not be acked.
func (h *receivedPacketTracker) IgnoreBelow(p protocol.PacketNumber) {
if p <= h.ignoreBelow {
return
}
h.ignoreBelow = p
h.packetHistory.DeleteBelow(p)
if h.logger.Debug() {
h.logger.Debugf("\tIgnoring all packets below %#x.", p)
}
}
// isMissing says if a packet was reported missing in the last ACK.
func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
if h.lastAck == nil || p < h.ignoreBelow {
return false
}
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
}
func (h *receivedPacketTracker) hasNewMissingPackets() bool {
if h.lastAck == nil {
return false
}
highestRange := h.packetHistory.GetHighestAckRange()
return highestRange.Smallest >= h.lastAck.LargestAcked() && highestRange.Len() <= maxPacketsAfterNewMissing
}
// maybeQueueAck queues an ACK, if necessary.
// It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck()
// in ACK_DECIMATION_WITH_REORDERING mode.
func (h *receivedPacketTracker) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) {
h.packetsReceivedSinceLastAck++
// always ack the first packet
if h.lastAck == nil {
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
h.ackQueued = true
return
}
// Send an ACK if this packet was reported missing in an ACK sent before.
// Ack decimation with reordering relies on the timer to send an ACK, but if
// missing packets we reported in the previous ack, send an ACK immediately.
if wasMissing {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %#x was missing before.", packetNumber)
}
h.ackQueued = true
}
if !h.ackQueued && shouldInstigateAck {
h.retransmittablePacketsReceivedSinceLastAck++
if packetNumber > minReceivedBeforeAckDecimation {
// ack up to 10 packets at once
if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
h.ackQueued = true
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, retransmittablePacketsBeforeAck)
}
} else if h.ackAlarm.IsZero() {
// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
h.ackAlarm = rcvTime.Add(ackDelay)
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to min(1/4 min-RTT, max ack delay): %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
}
}
} else {
// send an ACK every 2 retransmittable packets
if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, initialRetransmittablePacketsBeforeAck)
}
h.ackQueued = true
} else if h.ackAlarm.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", ackSendDelay)
}
h.ackAlarm = rcvTime.Add(ackSendDelay)
}
}
// If there are new missing packets to report, set a short timer to send an ACK.
if h.hasNewMissingPackets() {
// wait the minimum of 1/8 min RTT and the existing ack time
ackDelay := time.Duration(float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay))
ackTime := rcvTime.Add(ackDelay)
if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
h.ackAlarm = ackTime
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to 1/8 min-RTT: %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
}
}
}
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
}
}
func (h *receivedPacketTracker) GetAckFrame() *wire.AckFrame {
now := time.Now()
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
return nil
}
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
h.logger.Debugf("Sending ACK because the ACK timer expired.")
}
ack := &wire.AckFrame{
AckRanges: h.packetHistory.GetAckRanges(),
DelayTime: now.Sub(h.largestObservedReceivedTime),
}
h.lastAck = ack
h.ackAlarm = time.Time{}
h.ackQueued = false
h.packetsReceivedSinceLastAck = 0
h.retransmittablePacketsReceivedSinceLastAck = 0
return ack
}
func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }

View File

@ -359,6 +359,12 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
case <-h.handshakeErrChan:
return false
}
// get the handshake write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return false
}
return true
case typeEncryptedExtensions:
select {
@ -372,12 +378,6 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
// nothing to do
return false
case typeFinished:
// get the handshake write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return false
}
// While the order of these two is not defined by the TLS spec,
// we have to do it on the same order as our TLS library does it.
// get the handshake write key

View File

@ -37,15 +37,15 @@ func (m *MockReceivedPacketHandler) EXPECT() *MockReceivedPacketHandlerMockRecor
}
// GetAckFrame mocks base method
func (m *MockReceivedPacketHandler) GetAckFrame() *wire.AckFrame {
ret := m.ctrl.Call(m, "GetAckFrame")
func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel) *wire.AckFrame {
ret := m.ctrl.Call(m, "GetAckFrame", arg0)
ret0, _ := ret[0].(*wire.AckFrame)
return ret0
}
// GetAckFrame indicates an expected call of GetAckFrame
func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame))
func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame), arg0)
}
// GetAlarmTimeout mocks base method
@ -71,13 +71,13 @@ func (mr *MockReceivedPacketHandlerMockRecorder) IgnoreBelow(arg0 interface{}) *
}
// ReceivedPacket mocks base method
func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 time.Time, arg2 bool) error {
ret := m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2)
func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 protocol.EncryptionLevel, arg2 time.Time, arg3 bool) error {
ret := m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(error)
return ret0
}
// ReceivedPacket indicates an expected call of ReceivedPacket
func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2)
func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2, arg3)
}

View File

@ -122,6 +122,18 @@ func MinTime(a, b time.Time) time.Time {
return a
}
// MinNonZeroTime returns the earlist time that is not time.Time{}
// If both a and b are time.Time{}, it returns time.Time{}
func MinNonZeroTime(a, b time.Time) time.Time {
if a.IsZero() {
return b
}
if b.IsZero() {
return a
}
return MinTime(a, b)
}
// MaxTime returns the later time
func MaxTime(a, b time.Time) time.Time {
if a.After(b) {

View File

@ -15,6 +15,7 @@ var (
type multiplexer interface {
AddConn(net.PacketConn, int) (packetHandlerManager, error)
RemoveConn(net.PacketConn) error
}
type connManager struct {
@ -61,3 +62,15 @@ func (m *connMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandle
}
return p.manager, nil
}
func (m *connMultiplexer) RemoveConn(c net.PacketConn) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, ok := m.conns[c]; !ok {
return fmt.Errorf("cannote remove connection, connection is unknown")
}
delete(m.conns, c)
return nil
}

View File

@ -139,7 +139,7 @@ func (h *packetHandlerMap) close(e error) error {
}
h.mutex.Unlock()
wg.Wait()
return nil
return getMultiplexer().RemoveConn(h.conn)
}
func (h *packetHandlerMap) listen() {

View File

@ -90,7 +90,7 @@ type frameSource interface {
}
type ackFrameSource interface {
GetAckFrame() *wire.AckFrame
GetAckFrame(protocol.EncryptionLevel) *wire.AckFrame
}
type packetPacker struct {
@ -155,7 +155,7 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac
}
func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
ack := p.acks.GetAckFrame()
ack := p.acks.GetAckFrame(protocol.Encryption1RTT)
if ack == nil {
return nil, nil
}
@ -285,30 +285,41 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
var s cryptoStream
var encLevel protocol.EncryptionLevel
if p.initialStream.HasData() {
hasData := p.initialStream.HasData()
ack := p.acks.GetAckFrame(protocol.EncryptionInitial)
if hasData || ack != nil {
s = p.initialStream
encLevel = protocol.EncryptionInitial
} else if p.handshakeStream.HasData() {
s = p.handshakeStream
encLevel = protocol.EncryptionHandshake
} else {
hasData = p.handshakeStream.HasData()
ack = p.acks.GetAckFrame(protocol.EncryptionHandshake)
if hasData || ack != nil {
s = p.handshakeStream
encLevel = protocol.EncryptionHandshake
}
}
if s == nil {
return nil, nil
}
hdr := p.getHeader(encLevel)
hdrLen := hdr.GetLength(p.version)
sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel)
if err != nil {
// The sealer
return nil, err
}
hdr := p.getHeader(encLevel)
hdrLen := hdr.GetLength(p.version)
var length protocol.ByteCount
frames := make([]wire.Frame, 0, 2)
if ack := p.acks.GetAckFrame(); ack != nil {
if ack != nil {
frames = append(frames, ack)
length += ack.Length(p.version)
}
cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length)
frames = append(frames, cf)
if hasData {
cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length)
frames = append(frames, cf)
}
return p.writeAndSealPacket(hdr, frames, sealer)
}
@ -317,7 +328,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wir
var frames []wire.Frame
// ACKs need to go first, so that the sentPacketHandler will recognize them
if ack := p.acks.GetAckFrame(); ack != nil {
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT); ack != nil {
frames = append(frames, ack)
length += ack.Length(p.version)
}

View File

@ -566,7 +566,7 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time
}
}
if err := s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, rcvTime, isRetransmittable); err != nil {
if err := s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, packet.encryptionLevel, rcvTime, isRetransmittable); err != nil {
return err
}
return nil
@ -726,7 +726,9 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, pn protocol.PacketNumber,
if err := s.sentPacketHandler.ReceivedAck(frame, pn, encLevel, s.lastNetworkActivityTime); err != nil {
return err
}
s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked())
if encLevel == protocol.Encryption1RTT {
s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked())
}
return nil
}

View File

@ -116,9 +116,9 @@ func (ks *keySchedule13) setSecret(secret []byte) {
salt := ks.secret
if salt != nil {
h0 := hash.New().Sum(nil)
salt = HkdfExpandLabel(hash, salt, h0, "derived", hash.Size())
salt = hkdfExpandLabel(hash, salt, h0, "derived", hash.Size())
}
ks.secret = HkdfExtract(hash, secret, salt)
ks.secret = hkdfExtract(hash, secret, salt)
}
// Depending on role returns pair of key variant to be used by
@ -168,7 +168,7 @@ func (ks *keySchedule13) deriveSecret(secretLabel secretLabel) []byte {
ks.handshakeCtx = ks.transcriptHash.Sum(nil)
}
hash := hashForSuite(ks.suite)
secret := HkdfExpandLabel(hash, ks.secret, ks.handshakeCtx, label, hash.Size())
secret := hkdfExpandLabel(hash, ks.secret, ks.handshakeCtx, label, hash.Size())
if keylogType != "" && ks.config != nil {
ks.config.writeKeyLog(keylogType, ks.clientRandom, secret)
}
@ -177,8 +177,8 @@ func (ks *keySchedule13) deriveSecret(secretLabel secretLabel) []byte {
func (ks *keySchedule13) prepareCipher(trafficSecret []byte) cipher.AEAD {
hash := hashForSuite(ks.suite)
key := HkdfExpandLabel(hash, trafficSecret, nil, "key", ks.suite.keyLen)
iv := HkdfExpandLabel(hash, trafficSecret, nil, "iv", ks.suite.ivLen)
key := hkdfExpandLabel(hash, trafficSecret, nil, "key", ks.suite.keyLen)
iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", ks.suite.ivLen)
return ks.suite.aead(key, iv)
}
@ -254,10 +254,11 @@ CurvePreferenceLoop:
hs.keySchedule.setSecret(ecdheSecret)
hs.hsClientTrafficSecret = hs.keySchedule.deriveSecret(secretHandshakeClient)
hsServerTrafficSecret := hs.keySchedule.deriveSecret(secretHandshakeServer)
c.out.exportKey(hs.keySchedule.suite, hsServerTrafficSecret)
c.out.setKey(c.vers, hs.keySchedule.suite, hsServerTrafficSecret)
serverFinishedKey := HkdfExpandLabel(hash, hsServerTrafficSecret, nil, "finished", hashSize)
hs.clientFinishedKey = HkdfExpandLabel(hash, hs.hsClientTrafficSecret, nil, "finished", hashSize)
serverFinishedKey := hkdfExpandLabel(hash, hsServerTrafficSecret, nil, "finished", hashSize)
hs.clientFinishedKey = hkdfExpandLabel(hash, hs.hsClientTrafficSecret, nil, "finished", hashSize)
// EncryptedExtensions
hs.keySchedule.write(hs.hello13Enc.marshal())
@ -296,6 +297,7 @@ CurvePreferenceLoop:
hs.keySchedule.setSecret(nil) // derive master secret
serverAppTrafficSecret := hs.keySchedule.deriveSecret(secretApplicationServer)
c.out.exportKey(hs.keySchedule.suite, serverAppTrafficSecret)
c.out.setKey(c.vers, hs.keySchedule.suite, serverAppTrafficSecret)
if c.hand.Len() > 0 {
@ -303,9 +305,11 @@ CurvePreferenceLoop:
}
hs.appClientTrafficSecret = hs.keySchedule.deriveSecret(secretApplicationClient)
if hs.hello13Enc.earlyData {
c.in.exportKey(hs.keySchedule.suite, earlyClientTrafficSecret)
c.in.setKey(c.vers, hs.keySchedule.suite, earlyClientTrafficSecret)
c.phase = readingEarlyData
} else {
c.in.exportKey(hs.keySchedule.suite, hs.hsClientTrafficSecret)
c.in.setKey(c.vers, hs.keySchedule.suite, hs.hsClientTrafficSecret)
if hs.clientHello.earlyData {
c.phase = discardingEarlyData
@ -418,6 +422,7 @@ func (hs *serverHandshakeState) readClientFinished13(hasConfirmLock bool) error
if c.hand.Len() > 0 {
return c.sendAlert(alertUnexpectedMessage)
}
c.in.exportKey(hs.keySchedule.suite, hs.appClientTrafficSecret)
c.in.setKey(c.vers, hs.keySchedule.suite, hs.appClientTrafficSecret)
c.in.traceErr, c.out.traceErr = nil, nil
c.phase = handshakeConfirmed
@ -514,6 +519,7 @@ func (c *Conn) handleEndOfEarlyData() error {
}
c.hs.keySchedule.write(endOfEarlyData.marshal())
c.phase = waitingClientFinished
c.in.exportKey(c.hs.keySchedule.suite, c.hs.hsClientTrafficSecret)
c.in.setKey(c.vers, c.hs.keySchedule.suite, c.hs.hsClientTrafficSecret)
return nil
}
@ -618,6 +624,10 @@ func (c *Conn) deriveDHESecret(ks keyShare, secretKey []byte) []byte {
// HkdfExpandLabel HKDF expands a label
func HkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte {
return hkdfExpandLabel(hash, secret, hashValue, label, L)
}
func hkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte {
prefix := "tls13 "
hkdfLabel := make([]byte, 4+len(prefix)+len(label)+len(hashValue))
hkdfLabel[0] = byte(L >> 8)
@ -710,7 +720,7 @@ func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) {
hs.keySchedule.setSecret(s.pskSecret)
binderKey := hs.keySchedule.deriveSecret(secretResumptionPskBinder)
binderFinishedKey := HkdfExpandLabel(hash, binderKey, nil, "finished", hashSize)
binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize)
chHash := hash.New()
chHash.Write(hs.clientHello.rawTruncated)
expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)
@ -781,7 +791,7 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
// tickets might have the same PSK which could be a problem if
// one of them is compromised.
ticketNonce := []byte{byte(i)}
sessionState.pskSecret = HkdfExpandLabel(hash, resumptionMasterSecret, ticketNonce, "resumption", hash.Size())
sessionState.pskSecret = hkdfExpandLabel(hash, resumptionMasterSecret, ticketNonce, "resumption", hash.Size())
ticket := sessionState.marshal()
var err error
if c.config.SessionTicketSealer != nil {
@ -1006,13 +1016,17 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: unexpected data after Server Hello")
}
// Do not change the sender key yet, the server must authenticate first.
serverHandshakeSecret := hs.keySchedule.deriveSecret(secretHandshakeServer)
c.in.exportKey(hs.keySchedule.suite, serverHandshakeSecret)
// Already the sender key yet, when using an alternative record layer.
// QUIC needs the handshake write key in order to acknowlege Handshake packets.
c.out.exportKey(hs.keySchedule.suite, clientHandshakeSecret)
// Do not change the sender key yet, the server must authenticate first.
c.in.setKey(c.vers, hs.keySchedule.suite, serverHandshakeSecret)
// Calculate MAC key for Finished messages.
serverFinishedKey := HkdfExpandLabel(hash, serverHandshakeSecret, nil, "finished", hashSize)
clientFinishedKey := HkdfExpandLabel(hash, clientHandshakeSecret, nil, "finished", hashSize)
serverFinishedKey := hkdfExpandLabel(hash, serverHandshakeSecret, nil, "finished", hashSize)
clientFinishedKey := hkdfExpandLabel(hash, clientHandshakeSecret, nil, "finished", hashSize)
msg, err := c.readHandshake()
if err != nil {
@ -1155,11 +1169,13 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
// Handshake done, set application traffic secret
// TODO store initial traffic secret key for KeyUpdate GH #85
c.out.exportKey(hs.keySchedule.suite, clientAppTrafficSecret)
c.out.setKey(c.vers, hs.keySchedule.suite, clientAppTrafficSecret)
if c.hand.Len() > 0 {
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: unexpected data after handshake")
}
c.in.exportKey(hs.keySchedule.suite, serverAppTrafficSecret)
c.in.setKey(c.vers, hs.keySchedule.suite, serverAppTrafficSecret)
return nil
}

View File

@ -234,15 +234,20 @@ func (hc *halfConn) changeCipherSpec() error {
return nil
}
func (hc *halfConn) setKey(version uint16, suite *cipherSuite, trafficSecret []byte) {
func (hc *halfConn) exportKey(suite *cipherSuite, trafficSecret []byte) {
if hc.setKeyCallback != nil {
hc.setKeyCallback(&CipherSuite{*suite}, trafficSecret)
}
}
func (hc *halfConn) setKey(version uint16, suite *cipherSuite, trafficSecret []byte) {
if hc.setKeyCallback != nil {
return
}
hc.version = version
hash := hashForSuite(suite)
key := HkdfExpandLabel(hash, trafficSecret, nil, "key", suite.keyLen)
iv := HkdfExpandLabel(hash, trafficSecret, nil, "iv", suite.ivLen)
key := hkdfExpandLabel(hash, trafficSecret, nil, "key", suite.keyLen)
iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", suite.ivLen)
hc.cipher = suite.aead(key, iv)
for i := range hc.seq {
hc.seq[i] = 0

View File

@ -47,6 +47,10 @@ func hkdfExpand(hash crypto.Hash, prk, info []byte, l int) []byte {
// HkdfExtract generates a pseudorandom key for use with Expand from an input secret and an optional independent salt.
func HkdfExtract(hash crypto.Hash, secret, salt []byte) []byte {
return hkdfExtract(hash, secret, salt)
}
func hkdfExtract(hash crypto.Hash, secret, salt []byte) []byte {
if salt == nil {
salt = make([]byte, hash.Size())
}