1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-09-28 14:56:33 -04:00

sync quic package

This commit is contained in:
Darien Raymond 2019-01-02 13:01:06 +01:00
parent d20f87da4b
commit ec89d42feb
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
61 changed files with 10569 additions and 539 deletions

View File

@ -10,9 +10,6 @@ environment:
- GOARCH: 386
- GOARCH: amd64
hosts:
quic.clemente.io: 127.0.0.1
clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
install:

View File

@ -3,24 +3,51 @@ package quic
import (
"sync"
"v2ray.com/core/common/bytespool"
"github.com/lucas-clemente/quic-go/internal/protocol"
"v2ray.com/core/common/bytespool"
)
type packetBuffer struct {
Slice []byte
// refCount counts how many packets the Slice is used in.
// It doesn't support concurrent use.
// It is > 1 when used for coalesced packet.
refCount int
}
// Split increases the refCount.
// It must be called when a packet buffer is used for more than one packet,
// e.g. when splitting coalesced packets.
func (b *packetBuffer) Split() {
b.refCount++
}
// Release decreases the refCount.
// It should be called when processing the packet is finished.
// When the refCount reaches 0, the packet buffer is put back into the pool.
func (b *packetBuffer) Release() {
if cap(b.Slice) < 2048 {
return
}
b.refCount--
if b.refCount < 0 {
panic("negative packetBuffer refCount")
}
// only put the packetBuffer back if it's not used any more
if b.refCount == 0 {
bufferPool.Put(b.Slice)
}
}
var bufferPool *sync.Pool
func getPacketBuffer() *[]byte {
b := bufferPool.Get().([]byte)
return &b
}
func putPacketBuffer(buf *[]byte) {
b := *buf
if cap(b) < 2048 {
return
func getPacketBuffer() *packetBuffer {
buffer := bufferPool.Get().([]byte)
return &packetBuffer{
refCount: 1,
Slice: buffer[:protocol.MaxReceivePacketSize],
}
bufferPool.Put(b[:cap(b)])
}
func init() {

View File

@ -3,7 +3,6 @@ package quic
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
@ -38,6 +37,8 @@ type client struct {
destConnID protocol.ConnectionID
origDestConnID protocol.ConnectionID // the destination conn ID used on the first Initial (before a Retry)
initialPacketNumber protocol.PacketNumber
initialVersion protocol.VersionNumber
version protocol.VersionNumber
@ -54,8 +55,6 @@ var (
// make it possible to mock connection ID generation in the tests
generateConnectionID = protocol.GenerateConnectionID
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
errCloseSessionForRetry = errors.New("closing session in response to a stateless retry")
)
// DialAddr establishes a new QUIC connection to a server.
@ -255,7 +254,7 @@ func (c *client) dial(ctx context.Context) error {
return err
}
err := c.establishSecureConnection(ctx)
if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
if err == errCloseForRecreating {
return c.dial(ctx)
}
return err
@ -263,8 +262,7 @@ func (c *client) dial(ctx context.Context) error {
// establishSecureConnection runs the session, and tries to establish a secure connection
// It returns:
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry
// - errCloseSessionRecreating when the server sends a version negotiation packet, or a stateless retry is performed
// - any other error that might occur
// - when the connection is forward-secure
func (c *client) establishSecureConnection(ctx context.Context) error {
@ -272,7 +270,7 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
go func() {
err := c.session.run() // returns as soon as the session is closed
if err != errCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
if err != errCloseForRecreating && c.createdPacketConn {
c.conn.Close()
}
errorChan <- err
@ -344,7 +342,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) {
c.version = newVersion
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.destroy(errCloseSessionForNewVersion)
c.initialPacketNumber = c.session.closeForRecreating()
}
func (c *client) handleRetryPacket(hdr *wire.Header) {
@ -370,7 +368,7 @@ func (c *client) handleRetryPacket(hdr *wire.Header) {
c.origDestConnID = c.destConnID
c.destConnID = hdr.SrcConnectionID
c.token = hdr.Token
c.session.destroy(errCloseSessionForRetry)
c.initialPacketNumber = c.session.closeForRecreating()
}
func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
@ -401,6 +399,7 @@ func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
c.srcConnID,
c.config,
c.tlsConf,
c.initialPacketNumber,
params,
c.initialVersion,
c.logger,

View File

@ -34,7 +34,7 @@ type sentPacketHandler struct {
packetNumberGenerator *packetNumberGenerator
lastSentRetransmittablePacketTime time.Time
lastSentHandshakePacketTime time.Time
lastSentCryptoPacketTime time.Time
nextPacketSendTime time.Time
@ -56,8 +56,8 @@ type sentPacketHandler struct {
rttStats *congestion.RTTStats
handshakeComplete bool
// The number of times the handshake packets have been retransmitted without receiving an ack.
handshakeCount uint32
// The number of times the crypto packets have been retransmitted without receiving an ack.
cryptoCount uint32
// The number of times a TLP has been sent without receiving an ack.
tlpCount uint32
@ -78,7 +78,11 @@ type sentPacketHandler struct {
}
// NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
func NewSentPacketHandler(
initialPacketNumber protocol.PacketNumber,
rttStats *congestion.RTTStats,
logger utils.Logger,
) SentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
@ -88,7 +92,7 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) Se
)
return &sentPacketHandler{
packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength),
packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength),
packetHistory: newSentPacketHistory(),
rttStats: rttStats,
congestion: congestion,
@ -104,21 +108,21 @@ func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
}
func (h *sentPacketHandler) SetHandshakeComplete() {
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
h.logger.Debugf("Handshake complete. Discarding all outstanding crypto packets.")
var queue []*Packet
for _, packet := range h.retransmissionQueue {
if packet.EncryptionLevel == protocol.Encryption1RTT {
queue = append(queue, packet)
}
}
var handshakePackets []*Packet
var cryptoPackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.EncryptionLevel != protocol.Encryption1RTT {
handshakePackets = append(handshakePackets, p)
cryptoPackets = append(cryptoPackets, p)
}
return true, nil
})
for _, p := range handshakePackets {
for _, p := range cryptoPackets {
h.packetHistory.Remove(p.PacketNumber)
}
h.retransmissionQueue = queue
@ -144,8 +148,10 @@ func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retra
}
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
h.logger.Debugf("Skipping packet number %#x", p)
if h.logger.Debug() && h.lastSentPacketNumber != 0 {
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
h.logger.Debugf("Skipping packet number %#x", p)
}
}
h.lastSentPacketNumber = packet.PacketNumber
@ -161,7 +167,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt
if isRetransmittable {
if packet.EncryptionLevel != protocol.Encryption1RTT {
h.lastSentHandshakePacketTime = packet.SendTime
h.lastSentCryptoPacketTime = packet.SendTime
}
h.lastSentRetransmittablePacketTime = packet.SendTime
packet.includedInBytesInFlight = true
@ -185,7 +191,7 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
}
// duplicate or out of order ACK
if withPacketNumber != 0 && withPacketNumber <= h.largestReceivedPacketWithAck {
if withPacketNumber != 0 && withPacketNumber < h.largestReceivedPacketWithAck {
h.logger.Debugf("Ignoring ACK frame (duplicate or out of order).")
return nil
}
@ -299,8 +305,8 @@ func (h *sentPacketHandler) updateLossDetectionAlarm() {
return
}
if h.packetHistory.HasOutstandingHandshakePackets() {
h.alarm = h.lastSentHandshakePacketTime.Add(h.computeHandshakeTimeout())
if h.packetHistory.HasOutstandingCryptoPackets() {
h.alarm = h.lastSentCryptoPacketTime.Add(h.computeCryptoTimeout())
} else if !h.lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = h.lossTime
@ -381,12 +387,12 @@ func (h *sentPacketHandler) OnAlarm() error {
func (h *sentPacketHandler) onVerifiedAlarm() error {
var err error
if h.packetHistory.HasOutstandingHandshakePackets() {
if h.packetHistory.HasOutstandingCryptoPackets() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in handshake mode. Handshake count: %d", h.handshakeCount)
h.logger.Debugf("Loss detection alarm fired in crypto mode. Crypto count: %d", h.cryptoCount)
}
h.handshakeCount++
err = h.queueHandshakePacketsForRetransmission()
h.cryptoCount++
err = h.queueCryptoPacketsForRetransmission()
} else if !h.lossTime.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", h.lossTime)
@ -456,7 +462,7 @@ func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error {
}
h.rtoCount = 0
h.tlpCount = 0
h.handshakeCount = 0
h.cryptoCount = 0
return h.packetHistory.Remove(p.PacketNumber)
}
@ -575,16 +581,16 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int {
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
}
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
var handshakePackets []*Packet
func (h *sentPacketHandler) queueCryptoPacketsForRetransmission() error {
var cryptoPackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.canBeRetransmitted && p.EncryptionLevel != protocol.Encryption1RTT {
handshakePackets = append(handshakePackets, p)
cryptoPackets = append(cryptoPackets, p)
}
return true, nil
})
for _, p := range handshakePackets {
h.logger.Debugf("Queueing packet %#x as a handshake retransmission", p.PacketNumber)
for _, p := range cryptoPackets {
h.logger.Debugf("Queueing packet %#x as a crypto retransmission", p.PacketNumber)
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
@ -603,11 +609,11 @@ func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
return nil
}
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
func (h *sentPacketHandler) computeCryptoTimeout() time.Duration {
duration := utils.MaxDuration(2*h.rttStats.SmoothedOrInitialRTT(), minTPLTimeout)
// exponential backoff
// There's an implicit limit to this set by the handshake timeout.
return duration << h.handshakeCount
// There's an implicit limit to this set by the crypto timeout.
return duration << h.cryptoCount
}
func (h *sentPacketHandler) computeTLPTimeout() time.Duration {

View File

@ -10,8 +10,8 @@ type sentPacketHistory struct {
packetList *PacketList
packetMap map[protocol.PacketNumber]*PacketElement
numOutstandingPackets int
numOutstandingHandshakePackets int
numOutstandingPackets int
numOutstandingCryptoPackets int
firstOutstanding *PacketElement
}
@ -36,7 +36,7 @@ func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
if p.canBeRetransmitted {
h.numOutstandingPackets++
if p.EncryptionLevel != protocol.Encryption1RTT {
h.numOutstandingHandshakePackets++
h.numOutstandingCryptoPackets++
}
}
return el
@ -107,8 +107,8 @@ func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber)
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
h.numOutstandingCryptoPackets--
if h.numOutstandingCryptoPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
}
@ -148,8 +148,8 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
h.numOutstandingCryptoPackets--
if h.numOutstandingCryptoPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
}
@ -163,6 +163,6 @@ func (h *sentPacketHistory) HasOutstandingPackets() bool {
return h.numOutstandingPackets > 0
}
func (h *sentPacketHistory) HasOutstandingHandshakePackets() bool {
return h.numOutstandingHandshakePackets > 0
func (h *sentPacketHistory) HasOutstandingCryptoPackets() bool {
return h.numOutstandingCryptoPackets > 0
}

View File

@ -8,26 +8,56 @@ import (
)
type sealer struct {
iv []byte
aead cipher.AEAD
iv []byte
aead cipher.AEAD
pnEncrypter cipher.Block
// use a single slice to avoid allocations
nonceBuf []byte
pnMask []byte
// short headers protect 5 bits in the first byte, long headers only 4
is1RTT bool
}
var _ Sealer = &sealer{}
func newSealer(aead cipher.AEAD, iv []byte) Sealer {
func newSealer(aead cipher.AEAD, iv []byte, pnEncrypter cipher.Block, is1RTT bool) Sealer {
return &sealer{
iv: iv,
aead: aead,
nonceBuf: make([]byte, aead.NonceSize()),
iv: iv,
aead: aead,
nonceBuf: make([]byte, aead.NonceSize()),
is1RTT: is1RTT,
pnEncrypter: pnEncrypter,
pnMask: make([]byte, pnEncrypter.BlockSize()),
}
}
func (s *sealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn))
return s.aead.Seal(dst, s.nonceBuf, src, ad)
for i := 0; i < len(s.nonceBuf); i++ {
s.nonceBuf[i] ^= s.iv[i]
}
sealed := s.aead.Seal(dst, s.nonceBuf, src, ad)
for i := 0; i < len(s.nonceBuf); i++ {
s.nonceBuf[i] = 0
}
return sealed
}
func (s *sealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
if len(sample) != s.pnEncrypter.BlockSize() {
panic("invalid sample size")
}
s.pnEncrypter.Encrypt(s.pnMask, sample)
if s.is1RTT {
*firstByte ^= s.pnMask[0] & 0x1f
} else {
*firstByte ^= s.pnMask[0] & 0xf
}
for i := range pnBytes {
pnBytes[i] ^= s.pnMask[i+1]
}
}
func (s *sealer) Overhead() int {
@ -35,24 +65,54 @@ func (s *sealer) Overhead() int {
}
type opener struct {
iv []byte
aead cipher.AEAD
iv []byte
aead cipher.AEAD
pnDecrypter cipher.Block
// use a single slice to avoid allocations
nonceBuf []byte
pnMask []byte
// short headers protect 5 bits in the first byte, long headers only 4
is1RTT bool
}
var _ Opener = &opener{}
func newOpener(aead cipher.AEAD, iv []byte) Opener {
func newOpener(aead cipher.AEAD, iv []byte, pnDecrypter cipher.Block, is1RTT bool) Opener {
return &opener{
iv: iv,
aead: aead,
nonceBuf: make([]byte, aead.NonceSize()),
iv: iv,
aead: aead,
nonceBuf: make([]byte, aead.NonceSize()),
is1RTT: is1RTT,
pnDecrypter: pnDecrypter,
pnMask: make([]byte, pnDecrypter.BlockSize()),
}
}
func (o *opener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
return o.aead.Open(dst, o.nonceBuf, src, ad)
for i := 0; i < len(o.nonceBuf); i++ {
o.nonceBuf[i] ^= o.iv[i]
}
opened, err := o.aead.Open(dst, o.nonceBuf, src, ad)
for i := 0; i < len(o.nonceBuf); i++ {
o.nonceBuf[i] = 0
}
return opened, err
}
func (o *opener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
if len(sample) != o.pnDecrypter.BlockSize() {
panic("invalid sample size")
}
o.pnDecrypter.Encrypt(o.pnMask, sample)
if o.is1RTT {
*firstByte ^= o.pnMask[0] & 0x1f
} else {
*firstByte ^= o.pnMask[0] & 0xf
}
for i := range pnBytes {
pnBytes[i] ^= o.pnMask[i+1]
}
}

View File

@ -1,12 +1,12 @@
package handshake
import (
"crypto/aes"
"crypto/tls"
"errors"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qtls"
@ -46,6 +46,11 @@ func (m messageType) String() string {
}
}
// ErrOpenerNotYetAvailable is returned when an opener is requested for an encryption level,
// but the corresponding opener has not yet been initialized
// This can happen when packets arrive out of order.
var ErrOpenerNotYetAvailable = errors.New("CryptoSetup: opener at this encryption level not yet available")
type cryptoSetup struct {
tlsConf *qtls.Config
@ -74,7 +79,8 @@ type cryptoSetup struct {
clientHelloWrittenChan chan struct{}
initialStream io.Writer
initialAEAD crypto.AEAD
initialOpener Opener
initialSealer Sealer
handshakeStream io.Writer
handshakeOpener Opener
@ -175,13 +181,14 @@ func newCryptoSetup(
logger utils.Logger,
perspective protocol.Perspective,
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
initialAEAD, err := crypto.NewNullAEAD(connID, perspective)
initialSealer, initialOpener, err := newInitialAEAD(connID, perspective)
if err != nil {
return nil, nil, err
}
cs := &cryptoSetup{
initialStream: initialStream,
initialAEAD: initialAEAD,
initialSealer: initialSealer,
initialOpener: initialOpener,
handshakeStream: handshakeStream,
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
@ -403,9 +410,19 @@ func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
}
func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
opener := newOpener(suite.AEAD(key, iv), iv)
key := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "key", suite.KeyLen())
iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "iv", suite.IVLen())
pnKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "pn", suite.KeyLen())
pnDecrypter, err := aes.NewCipher(pnKey)
if err != nil {
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
opener := newOpener(
suite.AEAD(key, iv),
iv,
pnDecrypter,
h.readEncLevel == protocol.Encryption1RTT,
)
switch h.readEncLevel {
case protocol.EncryptionInitial:
@ -423,9 +440,19 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte)
}
func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) {
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
sealer := newSealer(suite.AEAD(key, iv), iv)
key := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "key", suite.KeyLen())
iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "iv", suite.IVLen())
pnKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "pn", suite.KeyLen())
pnEncrypter, err := aes.NewCipher(pnKey)
if err != nil {
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
sealer := newSealer(
suite.AEAD(key, iv),
iv,
pnEncrypter,
h.writeEncLevel == protocol.Encryption1RTT,
)
switch h.writeEncLevel {
case protocol.EncryptionInitial:
@ -467,7 +494,7 @@ func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) {
if h.handshakeSealer != nil {
return protocol.EncryptionHandshake, h.handshakeSealer
}
return protocol.EncryptionInitial, h.initialAEAD
return protocol.EncryptionInitial, h.initialSealer
}
func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) {
@ -475,7 +502,7 @@ func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLeve
switch level {
case protocol.EncryptionInitial:
return h.initialAEAD, nil
return h.initialSealer, nil
case protocol.EncryptionHandshake:
if h.handshakeSealer == nil {
return nil, errNoSealer
@ -491,22 +518,23 @@ func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLeve
}
}
func (h *cryptoSetup) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
return h.initialAEAD.Open(dst, src, pn, ad)
}
func (h *cryptoSetup) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
if h.handshakeOpener == nil {
return nil, errors.New("no handshake opener")
func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error) {
switch level {
case protocol.EncryptionInitial:
return h.initialOpener, nil
case protocol.EncryptionHandshake:
if h.handshakeOpener == nil {
return nil, ErrOpenerNotYetAvailable
}
return h.handshakeOpener, nil
case protocol.Encryption1RTT:
if h.opener == nil {
return nil, ErrOpenerNotYetAvailable
}
return h.opener, nil
default:
return nil, fmt.Errorf("CryptoSetup: no opener with encryption level %s", level)
}
return h.handshakeOpener.Open(dst, src, pn, ad)
}
func (h *cryptoSetup) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
if h.opener == nil {
return nil, errors.New("no 1-RTT opener")
}
return h.opener.Open(dst, src, pn, ad)
}
func (h *cryptoSetup) ConnectionState() ConnectionState {

View File

@ -0,0 +1,66 @@
package handshake
import (
"crypto"
"crypto/aes"
"crypto/cipher"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls"
)
var quicVersion1Salt = []byte{0xef, 0x4f, 0xb0, 0xab, 0xb4, 0x74, 0x70, 0xc4, 0x1b, 0xef, 0xcf, 0x80, 0x31, 0x33, 0x4f, 0xae, 0x48, 0x5e, 0x09, 0xa0}
func newInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Sealer, Opener, error) {
clientSecret, serverSecret := computeSecrets(connID)
var mySecret, otherSecret []byte
if pers == protocol.PerspectiveClient {
mySecret = clientSecret
otherSecret = serverSecret
} else {
mySecret = serverSecret
otherSecret = clientSecret
}
myKey, myPNKey, myIV := computeInitialKeyAndIV(mySecret)
otherKey, otherPNKey, otherIV := computeInitialKeyAndIV(otherSecret)
encrypterCipher, err := aes.NewCipher(myKey)
if err != nil {
return nil, nil, err
}
encrypter, err := cipher.NewGCM(encrypterCipher)
if err != nil {
return nil, nil, err
}
pnEncrypter, err := aes.NewCipher(myPNKey)
if err != nil {
return nil, nil, err
}
decrypterCipher, err := aes.NewCipher(otherKey)
if err != nil {
return nil, nil, err
}
decrypter, err := cipher.NewGCM(decrypterCipher)
if err != nil {
return nil, nil, err
}
pnDecrypter, err := aes.NewCipher(otherPNKey)
if err != nil {
return nil, nil, err
}
return newSealer(encrypter, myIV, pnEncrypter, false), newOpener(decrypter, otherIV, pnDecrypter, false), nil
}
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
initialSecret := qtls.HkdfExtract(crypto.SHA256, connID, quicVersion1Salt)
clientSecret = qtls.HkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
serverSecret = qtls.HkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size())
return
}
func computeInitialKeyAndIV(secret []byte) (key, pnKey, iv []byte) {
key = qtls.HkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16)
pnKey = qtls.HkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic hp", 16)
iv = qtls.HkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12)
return
}

View File

@ -11,11 +11,13 @@ import (
// Opener opens a packet
type Opener interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
// Sealer seals a packet
type Sealer interface {
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
Overhead() int
}
@ -35,10 +37,7 @@ type CryptoSetup interface {
GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
GetOpener(protocol.EncryptionLevel) (Opener, error)
}
// ConnectionState records basic details about the QUIC connection.

View File

@ -59,6 +59,19 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState))
}
// GetOpener mocks base method
func (m *MockCryptoSetup) GetOpener(arg0 protocol.EncryptionLevel) (handshake.Opener, error) {
ret := m.ctrl.Call(m, "GetOpener", arg0)
ret0, _ := ret[0].(handshake.Opener)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOpener indicates an expected call of GetOpener
func (mr *MockCryptoSetupMockRecorder) GetOpener(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetOpener), arg0)
}
// GetSealer mocks base method
func (m *MockCryptoSetup) GetSealer() (protocol.EncryptionLevel, handshake.Sealer) {
ret := m.ctrl.Call(m, "GetSealer")
@ -97,45 +110,6 @@ func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1)
}
// Open1RTT mocks base method
func (m *MockCryptoSetup) Open1RTT(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
ret := m.ctrl.Call(m, "Open1RTT", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Open1RTT indicates an expected call of Open1RTT
func (mr *MockCryptoSetupMockRecorder) Open1RTT(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open1RTT", reflect.TypeOf((*MockCryptoSetup)(nil).Open1RTT), arg0, arg1, arg2, arg3)
}
// OpenHandshake mocks base method
func (m *MockCryptoSetup) OpenHandshake(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
ret := m.ctrl.Call(m, "OpenHandshake", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenHandshake indicates an expected call of OpenHandshake
func (mr *MockCryptoSetupMockRecorder) OpenHandshake(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).OpenHandshake), arg0, arg1, arg2, arg3)
}
// OpenInitial mocks base method
func (m *MockCryptoSetup) OpenInitial(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
ret := m.ctrl.Call(m, "OpenInitial", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenInitial indicates an expected call of OpenInitial
func (mr *MockCryptoSetupMockRecorder) OpenInitial(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenInitial", reflect.TypeOf((*MockCryptoSetup)(nil).OpenInitial), arg0, arg1, arg2, arg3)
}
// RunHandshake mocks base method
func (m *MockCryptoSetup) RunHandshake() error {
ret := m.ctrl.Call(m, "RunHandshake")

View File

@ -1,10 +1,10 @@
package mocks
//go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer"
//go:generate sh -c "../mockgen_internal.sh mocks opener.go github.com/lucas-clemente/quic-go/internal/handshake Opener"
//go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup"
//go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler"
//go:generate sh -c "../mockgen_internal.sh mocks congestion.go github.com/lucas-clemente/quic-go/internal/congestion SendAlgorithm"
//go:generate sh -c "../mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController"
//go:generate sh -c "../mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD"

View File

@ -0,0 +1,58 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: Opener)
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockOpener is a mock of Opener interface
type MockOpener struct {
ctrl *gomock.Controller
recorder *MockOpenerMockRecorder
}
// MockOpenerMockRecorder is the mock recorder for MockOpener
type MockOpenerMockRecorder struct {
mock *MockOpener
}
// NewMockOpener creates a new mock instance
func NewMockOpener(ctrl *gomock.Controller) *MockOpener {
mock := &MockOpener{ctrl: ctrl}
mock.recorder = &MockOpenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockOpener) EXPECT() *MockOpenerMockRecorder {
return m.recorder
}
// DecryptHeader mocks base method
func (m *MockOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2)
}
// DecryptHeader indicates an expected call of DecryptHeader
func (mr *MockOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockOpener)(nil).DecryptHeader), arg0, arg1, arg2)
}
// Open mocks base method
func (m *MockOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Open indicates an expected call of Open
func (mr *MockOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockOpener)(nil).Open), arg0, arg1, arg2, arg3)
}

View File

@ -34,6 +34,16 @@ func (m *MockSealer) EXPECT() *MockSealerMockRecorder {
return m.recorder
}
// EncryptHeader mocks base method
func (m *MockSealer) EncryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
m.ctrl.Call(m, "EncryptHeader", arg0, arg1, arg2)
}
// EncryptHeader indicates an expected call of EncryptHeader
func (mr *MockSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockSealer)(nil).EncryptHeader), arg0, arg1, arg2)
}
// Overhead mocks base method
func (m *MockSealer) Overhead() int {
ret := m.ctrl.Call(m, "Overhead")

View File

@ -16,8 +16,8 @@ const (
PacketNumberLen4 PacketNumberLen = 4
)
// InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
func InferPacketNumber(
// DecodePacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
func DecodePacketNumber(
packetNumberLength PacketNumberLen,
lastPacketNumber PacketNumber,
wirePacketNumber PacketNumber,

View File

@ -0,0 +1,18 @@
-----BEGIN CERTIFICATE-----
MIIC0DCCAbgCCQCmiwJpSoekpDANBgkqhkiG9w0BAQsFADAqMRMwEQYDVQQKDApx
dWljLWdvIENBMRMwEQYDVQQLDApxdWljLWdvIENBMB4XDTE4MTIwODA2NDIyMVoX
DTI4MTIwNTA2NDIyMVowKjETMBEGA1UECgwKcXVpYy1nbyBDQTETMBEGA1UECwwK
cXVpYy1nbyBDQTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAN5MxI09
i01xRON732BFIuxO2SGjA9jYkvUvNXK886gifp2BfWLcOW1DHkXxBnhWMqfpcIWM
GviF4G2Mp0HEJDMe+4LBxje/1e2WA+nzQlIZD6LaDi98nXJaAcCMM4a64Vm0i8Z3
+4c+O93+5TekPn507nl7QA1IaEEtoek7w7wDw4ZF3ET+nns2HwVpV/ugfuYOQbTJ
8Np+zO8EfPMTUjEpKdl4bp/yqcouWD+oIhoxmx1V+LxshcpSwtzHIAi6gjHUDCEe
bk5Y2GBT4VR5WKmNGvlfe9L0Gn0ZLJoeXDshrunF0xEmSv8MxlHcKH/u4IHiO+6x
+5sdslqY7uEPEhkCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAhvXUMiatkgsnoRHc
UobKraGttETivxvtKpc48o1TSkR+kCKbMnygmrvc5niEqc9iDg8JI6HjBKJ3/hfA
uKdyiR8cQNcQRgJ/3FVx0n3KGDUbHJSuIQzFvXom2ZPdlAHFqAT+8AVrz42v8gct
gyiGdFCSNisDbevOiRHuJtZ0m8YsGgtfU48wqGOaSSsRz4mYD6kqBFd0+Ja3/EGv
vl24L5xMCy1zGGl6wKPa7TT7ok4TfD1YmIXOfmWYop6cTLwePLj1nHrLi0AlsSn1
2pFlosc9/qEbO5drqNoxUZfeF0L9RUSuArHRSO779dW/AmOtFdK3yaBGqflg0r7p
lYombA==
-----END CERTIFICATE-----

View File

@ -2,6 +2,9 @@ package testdata
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"io/ioutil"
"path"
"runtime"
)
@ -14,13 +17,12 @@ func init() {
panic("Failed to get current frame")
}
certPath = path.Join(path.Dir(path.Dir(path.Dir(filename))), "example")
certPath = path.Dir(filename)
}
// GetCertificatePaths returns the paths to 'fullchain.pem' and 'privkey.pem' for the
// quic.clemente.io cert.
// GetCertificatePaths returns the paths to certificate and key
func GetCertificatePaths() (string, string) {
return path.Join(certPath, "fullchain.pem"), path.Join(certPath, "privkey.pem")
return path.Join(certPath, "cert.pem"), path.Join(certPath, "priv.key")
}
// GetTLSConfig returns a tls config for quic.clemente.io
@ -34,11 +36,22 @@ func GetTLSConfig() *tls.Config {
}
}
// GetCertificate returns a certificate for quic.clemente.io
func GetCertificate() tls.Certificate {
cert, err := tls.LoadX509KeyPair(GetCertificatePaths())
// GetRootCA returns an x509.CertPool containing the CA certificate
func GetRootCA() *x509.CertPool {
caCertPath := path.Join(certPath, "ca.pem")
caCertRaw, err := ioutil.ReadFile(caCertPath)
if err != nil {
panic(err)
}
return cert
p, _ := pem.Decode(caCertRaw)
if p.Type != "CERTIFICATE" {
panic("expected a certificate")
}
caCert, err := x509.ParseCertificate(p.Bytes)
if err != nil {
panic(err)
}
certPool := x509.NewCertPool()
certPool.AddCert(caCert)
return certPool
}

View File

@ -0,0 +1,18 @@
-----BEGIN CERTIFICATE-----
MIIC3jCCAcYCCQCV4BOv+SRo4zANBgkqhkiG9w0BAQUFADAqMRMwEQYDVQQKDApx
dWljLWdvIENBMRMwEQYDVQQLDApxdWljLWdvIENBMB4XDTE4MTIwODA2NDMwMloX
DTI4MTIwNTA2NDMwMlowODEQMA4GA1UECgwHcXVpYy1nbzEQMA4GA1UECwwHcXVp
Yy1nbzESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
MIIBCgKCAQEAyc/hS8XHkOJaLrdPOSTZFUBVyHNSfQUX/3dEpmccPlLQLgopYZZO
W/cVhkxAfQ3e68xKkuZKfZN5Hytn5V/AOSk281BqxFxpfCcKVYqVpDZH99+jaVfG
ImPp5Y22qCnbSEwYrMTcLiK8PVa4MkpKf1KNacVlqawU+ZWI5fevAFGTtmrMJ4S+
qZY7tAaVkax+OiKWWfhLQjJCsN3IIDysTfbWao6cYKgtTfqVChEddzS7LRJVRaB+
+huUbB87tRBJbCuJX65yB7Fw77YiKoFjc5r2845fcS2Ew4+w29mbXoj7M7g6eup5
SnCydsCvyNy6VkgaSlWS0DXvxuzWshwUrwIDAQABMA0GCSqGSIb3DQEBBQUAA4IB
AQBWgmFunf44X3/NIjNvVLeQsfGW+4L/lCi2F5tqa70Hkda+xhKACnQQGB2qCSCF
Jfxj4iKrFJ7+JB8GnribWthLuDq49PQrTI+1wKFd9c2b8DXzJLz4Onw+mPX97pZm
TflQSIxXRaFAIQuUWNTArZZEe1ESSlnaBuE5w77LMf4GMFD3P3jzSHKUyM1sF97j
gRbIt8Jw7Uyd8vlXk6m2wvO5H3hZrrhJUJH3WW13a7wLJRnff2meKU90hkLQwuxO
kyh0k/h158/r2ibiahTmQEgHs9vQaCM+HXuk5P+Tzq5Zl/n0dMFZMfkqNkD4nym/
nu7zfdwMlcBjKt9g3BGw+KE3
-----END CERTIFICATE-----

View File

@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEAyc/hS8XHkOJaLrdPOSTZFUBVyHNSfQUX/3dEpmccPlLQLgop
YZZOW/cVhkxAfQ3e68xKkuZKfZN5Hytn5V/AOSk281BqxFxpfCcKVYqVpDZH99+j
aVfGImPp5Y22qCnbSEwYrMTcLiK8PVa4MkpKf1KNacVlqawU+ZWI5fevAFGTtmrM
J4S+qZY7tAaVkax+OiKWWfhLQjJCsN3IIDysTfbWao6cYKgtTfqVChEddzS7LRJV
RaB++huUbB87tRBJbCuJX65yB7Fw77YiKoFjc5r2845fcS2Ew4+w29mbXoj7M7g6
eup5SnCydsCvyNy6VkgaSlWS0DXvxuzWshwUrwIDAQABAoIBADunQwVO1Qqync2p
SbWueqyZc8HotL1XwBw3eQdm+yZA/GBfiJPcBhWRF7+20mkkrHwuyuxZPjOYX/ki
r3dRslQzJpcNckHQvy1/rMJUUJ9VnDhc1sTQuTR5LC46kX9rv/HC7JhFKIBKrDHF
bHURGKxCDqLxQnfA8gJEfU7cw9HnxMxmKv7qJ3O7EHYMuTQstkYsGOr60zX/C+Zm
7YA+d7nx1LpL0m2lKs70iz5MzGg+KgKyrkMWQ30gpxILBxNzzuQr7Kv/+63/3+G9
nfCGeLmwGakPFpm6/GwiABE0yGa71YNAQs18iUTZwP/ZEDw3KB2SoG8wcqWjNAd+
cUF2PgECgYEA5Xe/OZouw9h0NBo0Zut+HC0YOuUfY72Ug9Fm8bAS6wDuPiO3jIvK
J40d+ZHNp4AakfTuugiqEDJRlV7T/F2K/KHDWvXTg5ZpAC8dsZKJMxyyAp8EniYQ
vsoFWeHBfsD83rCVKLcjDB3hbQH+MSoT3lsqjZRNiNUMK13gyuX7k28CgYEA4SWF
ySRXUqUezX5D8kV5rQVYLcw6WVB3czYd7cKf8zHy4xJX0ZicyZjohknMmKCkdx+M
1mrxlqUO7EBGokM8vs87m/4rz6bjgZffpWzUmP/x1+3f3j/wIZeqNilW8NqY5nLi
tj3JxMwaesU86rOekSy27BlX4sjQ8NRs7Z2d8sECgYBKAD8kBWwVbqWy88x4cHOA
BK7ut1tTIB1YEVzgjobbULaERaJ46c/sx16mUHYBEZf///xI9Ghbxs52nFlC5qve
4xAMMoDey8/a5lbuIDKs0BE8NSoZEm+OB7qIDP0IspYZ/tprgfwEeVJshBsEoew8
Ziwn8m66tPIyvhizdk2WcwKBgH2M8RgDffaGQbESEk3N1FZZvpx7YKZhqtrCeNoX
SB7T4cAigHpPAk+hRzlref46xrvvChiftmztSm8QQNNHb15wLauFh2Taic/Ao2Sa
VcukHnbtHYPQX9Y7vx1I3ESfgdgwhKBfwF5P+wwvZRL0ax5FsxPh5hJ/LZS+wKeY
13WBAoGAXSqG3ANmCyvSLVmAXGIbr0Tuixf/a25sPrlq7Im1H1OnqLrcyxWCLV3E
6gprhG5An0Zlr/FFRxVojf0TKmtJZs9B70/6WPwVvFtBduCM1zuUuCQYU9opTJQL
ElMIP4VfjABm4tm1fqGIy1PQP0Osb6/qb2DPPJqsFiW0oRByyMA=
-----END RSA PRIVATE KEY-----

View File

@ -1,6 +1,9 @@
package utils
import "time"
import (
"math"
"time"
)
// A Timer wrapper that behaves correctly when resetting
type Timer struct {
@ -11,7 +14,7 @@ type Timer struct {
// NewTimer creates a new timer that is not set
func NewTimer() *Timer {
return &Timer{t: time.NewTimer(0)}
return &Timer{t: time.NewTimer(time.Duration(math.MaxInt64))}
}
// Chan returns the channel of the wrapped timer
@ -31,7 +34,9 @@ func (t *Timer) Reset(deadline time.Time) {
if !t.t.Stop() && !t.read {
<-t.t.C
}
t.t.Reset(time.Until(deadline))
if !deadline.IsZero() {
t.t.Reset(time.Until(deadline))
}
t.read = false
t.deadline = deadline

View File

@ -30,7 +30,7 @@ func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*Exte
if err != nil {
return nil, err
}
if _, err := b.Seek(int64(h.len)-1, io.SeekCurrent); err != nil {
if _, err := b.Seek(int64(h.ParsedLen())-1, io.SeekCurrent); err != nil {
return nil, err
}
if h.IsLongHeader {

View File

@ -24,8 +24,8 @@ type Header struct {
SupportedVersions []protocol.VersionNumber // sent in a Version Negotiation Packet
OrigDestConnectionID protocol.ConnectionID // sent in the Retry packet
typeByte byte
len int // how many bytes were read while parsing this header
typeByte byte
parsedLen protocol.ByteCount // how many bytes were read while parsing this header
}
// ParseHeader parses the header.
@ -39,7 +39,7 @@ func ParseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
if err != nil {
return nil, err
}
h.len = startLen - b.Len()
h.parsedLen = protocol.ByteCount(startLen - b.Len())
return h, nil
}
@ -171,6 +171,11 @@ func (h *Header) IsVersionNegotiation() bool {
return h.IsLongHeader && h.Version == 0
}
// ParsedLen returns the number of bytes that were consumed when parsing the header
func (h *Header) ParsedLen() protocol.ByteCount {
return h.parsedLen
}
// ParseExtended parses the version dependent part of the header.
// The Reader has to be set such that it points to the first byte of the header.
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {

View File

@ -13,7 +13,6 @@ package quic
//go:generate sh -c "./mockgen_private.sh quic mock_sealing_manager_test.go github.com/lucas-clemente/quic-go sealingManager"
//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker"
//go:generate sh -c "./mockgen_private.sh quic mock_packer_test.go github.com/lucas-clemente/quic-go packer"
//go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD"
//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner"
//go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession"
//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler"

View File

@ -144,8 +144,8 @@ func (h *packetHandlerMap) close(e error) error {
func (h *packetHandlerMap) listen() {
for {
data := *getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize]
buffer := getPacketBuffer()
data := buffer.Slice
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
n, addr, err := h.conn.ReadFrom(data)
@ -153,55 +153,110 @@ func (h *packetHandlerMap) listen() {
h.close(err)
return
}
data = data[:n]
if err := h.handlePacket(addr, data); err != nil {
h.logger.Debugf("error handling packet from %s: %s", addr, err)
}
h.handlePacket(addr, buffer, data[:n])
}
}
func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
r := bytes.NewReader(data)
hdr, err := wire.ParseHeader(r, h.connIDLen)
// drop the packet if we can't parse the header
func (h *packetHandlerMap) handlePacket(
addr net.Addr,
buffer *packetBuffer,
data []byte,
) {
packets, err := h.parsePacket(addr, buffer, data)
if err != nil {
return fmt.Errorf("error parsing header: %s", err)
h.logger.Debugf("error parsing packets from %s: %s", addr, err)
// This is just the error from parsing the last packet.
// We still need to process the packets that were successfully parsed before.
}
p := &receivedPacket{
remoteAddr: addr,
hdr: hdr,
data: data,
rcvTime: time.Now(),
if len(packets) == 0 {
buffer.Release()
return
}
h.handleParsedPackets(packets)
}
func (h *packetHandlerMap) parsePacket(
addr net.Addr,
buffer *packetBuffer,
data []byte,
) ([]*receivedPacket, error) {
rcvTime := time.Now()
packets := make([]*receivedPacket, 0, 1)
var counter int
var lastConnID protocol.ConnectionID
for len(data) > 0 {
if counter > 0 && h.logger.Debug() {
h.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes", counter, len(packets[counter-1].data))
}
hdr, err := wire.ParseHeader(bytes.NewReader(data), h.connIDLen)
// drop the packet if we can't parse the header
if err != nil {
return packets, fmt.Errorf("error parsing header: %s", err)
}
if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) {
return packets, fmt.Errorf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID)
}
lastConnID = hdr.DestConnectionID
var rest []byte
if hdr.IsLongHeader {
if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length {
return packets, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
}
packetLen := int(hdr.ParsedLen() + hdr.Length)
rest = data[packetLen:]
data = data[:packetLen]
}
if counter > 0 {
buffer.Split()
}
counter++
packets = append(packets, &receivedPacket{
remoteAddr: addr,
hdr: hdr,
rcvTime: rcvTime,
data: data,
buffer: buffer,
})
data = rest
}
return packets, nil
}
func (h *packetHandlerMap) handleParsedPackets(packets []*receivedPacket) {
h.mutex.RLock()
defer h.mutex.RUnlock()
handlerEntry, handlerFound := h.handlers[string(hdr.DestConnectionID)]
// coalesced packets all have the same destination connection ID
handlerEntry, handlerFound := h.handlers[string(packets[0].hdr.DestConnectionID)]
if handlerFound { // existing session
handlerEntry.handler.handlePacket(p)
return nil
}
// No session found.
// This might be a stateless reset.
if !hdr.IsLongHeader {
if len(data) >= protocol.MinStatelessResetSize {
var token [16]byte
copy(token[:], data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok {
sess.destroy(errors.New("received a stateless reset"))
return nil
}
for _, p := range packets {
if handlerFound { // existing session
handlerEntry.handler.handlePacket(p)
continue
}
// TODO(#943): send a stateless reset
return fmt.Errorf("received a short header packet with an unexpected connection ID %s", hdr.DestConnectionID)
// No session found.
// This might be a stateless reset.
if !p.hdr.IsLongHeader {
if len(p.data) >= protocol.MinStatelessResetSize {
var token [16]byte
copy(token[:], p.data[len(p.data)-16:])
if sess, ok := h.resetTokens[token]; ok {
sess.destroy(errors.New("received a stateless reset"))
continue
}
}
// TODO(#943): send a stateless reset
h.logger.Debugf("received a short header packet with an unexpected connection ID %s", p.hdr.DestConnectionID)
break // a short header packet is always the last in a coalesced packet
}
if h.server != nil { // no server set
h.server.handlePacket(p)
}
h.logger.Debugf("received a packet with an unexpected connection ID %s", p.hdr.DestConnectionID)
}
if h.server == nil { // no server set
return fmt.Errorf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID)
}
h.server.handlePacket(p)
return nil
}

View File

@ -25,10 +25,25 @@ type packer interface {
}
type packedPacket struct {
header *wire.ExtendedHeader
raw []byte
frames []wire.Frame
encryptionLevel protocol.EncryptionLevel
header *wire.ExtendedHeader
raw []byte
frames []wire.Frame
buffer *packetBuffer
}
func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel {
if !p.header.IsLongHeader {
return protocol.Encryption1RTT
}
switch p.header.Type {
case protocol.PacketTypeInitial:
return protocol.EncryptionInitial
case protocol.PacketTypeHandshake:
return protocol.EncryptionHandshake
default:
return protocol.EncryptionUnspecified
}
}
func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet {
@ -37,7 +52,7 @@ func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet {
PacketType: p.header.Type,
Frames: p.frames,
Length: protocol.ByteCount(len(p.raw)),
EncryptionLevel: p.encryptionLevel,
EncryptionLevel: p.EncryptionLevel(),
SendTime: time.Now(),
}
}
@ -136,13 +151,7 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac
frames := []wire.Frame{ccf}
encLevel, sealer := p.cryptoSetup.GetSealer()
header := p.getHeader(encLevel)
raw, err := p.writeAndSealPacket(header, frames, sealer)
return &packedPacket{
header: header,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, err
return p.writeAndSealPacket(header, frames, sealer)
}
func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
@ -154,13 +163,7 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
encLevel, sealer := p.cryptoSetup.GetSealer()
header := p.getHeader(encLevel)
frames := []wire.Frame{ack}
raw, err := p.writeAndSealPacket(header, frames, sealer)
return &packedPacket{
header: header,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, err
return p.writeAndSealPacket(header, frames, sealer)
}
// PackRetransmission packs a retransmission
@ -227,16 +230,11 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP
if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok {
sf.DataLenPresent = false
}
raw, err := p.writeAndSealPacket(header, frames, sealer)
p, err := p.writeAndSealPacket(header, frames, sealer)
if err != nil {
return nil, err
}
packets = append(packets, &packedPacket{
header: header,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
})
packets = append(packets, p)
}
return packets, nil
}
@ -281,16 +279,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
p.numNonRetransmittableAcks = 0
}
raw, err := p.writeAndSealPacket(header, frames, sealer)
if err != nil {
return nil, err
}
return &packedPacket{
header: header,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, nil
return p.writeAndSealPacket(header, frames, sealer)
}
func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
@ -320,16 +309,7 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
}
cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length)
frames = append(frames, cf)
raw, err := p.writeAndSealPacket(hdr, frames, sealer)
if err != nil {
return nil, err
}
return &packedPacket{
header: hdr,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, nil
return p.writeAndSealPacket(hdr, frames, sealer)
}
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wire.Frame, error) {
@ -395,9 +375,9 @@ func (p *packetPacker) writeAndSealPacket(
header *wire.ExtendedHeader,
frames []wire.Frame,
sealer handshake.Sealer,
) ([]byte, error) {
raw := *getPacketBuffer()
buffer := bytes.NewBuffer(raw[:0])
) (*packedPacket, error) {
packetBuffer := getPacketBuffer()
buffer := bytes.NewBuffer(packetBuffer.Slice[:0])
addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial
@ -421,7 +401,7 @@ func (p *packetPacker) writeAndSealPacket(
if err := header.Write(buffer, p.version); err != nil {
return nil, err
}
payloadStartIndex := buffer.Len()
payloadOffset := buffer.Len()
// write all frames but the last one
for _, frame := range frames[:len(frames)-1] {
@ -436,7 +416,7 @@ func (p *packetPacker) writeAndSealPacket(
sf.DataLenPresent = true
}
} else {
payloadLen := buffer.Len() - payloadStartIndex + int(lastFrame.Length(p.version))
payloadLen := buffer.Len() - payloadOffset + int(lastFrame.Length(p.version))
if paddingLen := 4 - int(header.PacketNumberLen) - payloadLen; paddingLen > 0 {
// Pad the packet such that packet number length + payload length is 4 bytes.
// This is needed to enable the peer to get a 16 byte sample for header protection.
@ -458,15 +438,27 @@ func (p *packetPacker) writeAndSealPacket(
return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
}
raw = raw[0:buffer.Len()]
_ = sealer.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], header.PacketNumber, raw[:payloadStartIndex])
raw := buffer.Bytes()
_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[:payloadOffset])
raw = raw[0 : buffer.Len()+sealer.Overhead()]
pnOffset := payloadOffset - int(header.PacketNumberLen)
sealer.EncryptHeader(
raw[pnOffset+4:pnOffset+4+16],
&raw[0],
raw[pnOffset:payloadOffset],
)
num := p.pnManager.PopPacketNumber()
if num != header.PacketNumber {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
return raw, nil
return &packedPacket{
header: header,
raw: raw,
frames: frames,
buffer: packetBuffer,
}, nil
}
func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) {

View File

@ -4,62 +4,94 @@ import (
"bytes"
"fmt"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type unpackedPacket struct {
packetNumber protocol.PacketNumber // the decoded packet number
hdr *wire.ExtendedHeader
encryptionLevel protocol.EncryptionLevel
frames []wire.Frame
}
type quicAEAD interface {
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
}
// The packetUnpacker unpacks QUIC packets.
type packetUnpacker struct {
aead quicAEAD
cs handshake.CryptoSetup
largestRcvdPacketNumber protocol.PacketNumber
version protocol.VersionNumber
}
var _ unpacker = &packetUnpacker{}
func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker {
func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) unpacker {
return &packetUnpacker{
aead: aead,
cs: cs,
version: version,
}
}
func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
buf := *getPacketBuffer()
buf = buf[:0]
defer putPacketBuffer(&buf)
func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
r := bytes.NewReader(data)
var decrypted []byte
var encryptionLevel protocol.EncryptionLevel
var err error
var encLevel protocol.EncryptionLevel
switch hdr.Type {
case protocol.PacketTypeInitial:
decrypted, err = u.aead.OpenInitial(buf, data, hdr.PacketNumber, headerBinary)
encryptionLevel = protocol.EncryptionInitial
encLevel = protocol.EncryptionInitial
case protocol.PacketTypeHandshake:
decrypted, err = u.aead.OpenHandshake(buf, data, hdr.PacketNumber, headerBinary)
encryptionLevel = protocol.EncryptionHandshake
encLevel = protocol.EncryptionHandshake
default:
if hdr.IsLongHeader {
return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
}
decrypted, err = u.aead.Open1RTT(buf, data, hdr.PacketNumber, headerBinary)
encryptionLevel = protocol.Encryption1RTT
encLevel = protocol.Encryption1RTT
}
opener, err := u.cs.GetOpener(encLevel)
if err != nil {
return nil, qerr.Error(qerr.DecryptionFailure, err.Error())
return nil, err
}
hdrLen := int(hdr.ParsedLen())
// The packet number can be up to 4 bytes long, but we won't know the length until we decrypt it.
// 1. save a copy of the 4 bytes
origPNBytes := make([]byte, 4)
copy(origPNBytes, data[hdrLen:hdrLen+4])
// 2. decrypt the header, assuming a 4 byte packet number
opener.DecryptHeader(
data[hdrLen+4:hdrLen+4+16],
&data[0],
data[hdrLen:hdrLen+4],
)
// 3. parse the header (and learn the actual length of the packet number)
extHdr, err := hdr.ParseExtended(r, u.version)
if err != nil {
return nil, fmt.Errorf("error parsing extended header: %s", err)
}
extHdr.Raw = data[:hdrLen+int(extHdr.PacketNumberLen)]
// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
if extHdr.PacketNumberLen != protocol.PacketNumberLen4 {
copy(data[hdrLen+int(extHdr.PacketNumberLen):hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
}
data = data[hdrLen+int(extHdr.PacketNumberLen):]
pn := protocol.DecodePacketNumber(
extHdr.PacketNumberLen,
u.largestRcvdPacketNumber,
extHdr.PacketNumber,
)
decrypted, err := opener.Open(data[:0], data, pn, extHdr.Raw)
if err != nil {
return nil, err
}
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, pn)
fs, err := u.parseFrames(decrypted)
if err != nil {
@ -67,7 +99,9 @@ func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, d
}
return &unpackedPacket{
encryptionLevel: encryptionLevel,
hdr: extHdr,
packetNumber: pn,
encryptionLevel: encLevel,
frames: fs,
}, nil
}

View File

@ -8,6 +8,7 @@ import (
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
@ -43,9 +44,8 @@ type receiveStream struct {
canceledRead bool // set when CancelRead() is called
resetRemotely bool // set when HandleResetStreamFrame() is called
readChan chan struct{}
deadline time.Time
deadlineTimer *time.Timer // initialized by SetReadDeadline()
readChan chan struct{}
deadline time.Time
flowController flowcontrol.StreamFlowController
version protocol.VersionNumber
@ -116,6 +116,7 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
return false, bytesRead, s.closeForShutdownErr
}
var deadlineTimer *utils.Timer
for {
// Stop waiting on errors
if s.closedForShutdown {
@ -128,8 +129,15 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
return false, bytesRead, s.resetRemotelyErr
}
if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
return false, bytesRead, errDeadline
deadline := s.deadline
if !deadline.IsZero() {
if !time.Now().Before(deadline) {
return false, bytesRead, errDeadline
}
if deadlineTimer == nil {
deadlineTimer = utils.NewTimer()
}
deadlineTimer.Reset(deadline)
}
if s.currentFrame != nil || s.currentFrameIsLast {
@ -137,12 +145,13 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
}
s.mutex.Unlock()
if s.deadline.IsZero() {
if deadline.IsZero() {
<-s.readChan
} else {
select {
case <-s.readChan:
case <-s.deadlineTimer.C:
case <-deadlineTimer.Chan():
deadlineTimer.SetRead()
}
}
s.mutex.Lock()
@ -259,22 +268,9 @@ func (s *receiveStream) CloseRemote(offset protocol.ByteCount) {
func (s *receiveStream) SetReadDeadline(t time.Time) error {
s.mutex.Lock()
defer s.mutex.Unlock()
s.deadline = t
if s.deadline.IsZero() { // skip if there's no deadline to set
s.signalRead()
return nil
}
// Lazily initialize the deadline timer.
if s.deadlineTimer == nil {
s.deadlineTimer = time.NewTimer(time.Until(t))
return nil
}
// reset the timer to the new deadline
if !s.deadlineTimer.Stop() {
<-s.deadlineTimer.C
}
s.deadlineTimer.Reset(time.Until(t))
s.mutex.Unlock()
s.signalRead()
return nil
}

View File

@ -42,9 +42,8 @@ type sendStream struct {
dataForWriting []byte
writeChan chan struct{}
deadline time.Time
deadlineTimer *time.Timer // initialized by SetReadDeadline()
writeChan chan struct{}
deadline time.Time
flowController flowcontrol.StreamFlowController
@ -95,41 +94,53 @@ func (s *sendStream) Write(p []byte) (int, error) {
return 0, nil
}
s.dataForWriting = make([]byte, len(p))
copy(s.dataForWriting, p)
go s.sender.onHasStreamData(s.streamID)
s.dataForWriting = p
var bytesWritten int
var err error
var (
deadlineTimer *utils.Timer
bytesWritten int
notifiedSender bool
)
for {
bytesWritten = len(p) - len(s.dataForWriting)
if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
s.dataForWriting = nil
err = errDeadline
break
deadline := s.deadline
if !deadline.IsZero() {
if !time.Now().Before(deadline) {
s.dataForWriting = nil
return bytesWritten, errDeadline
}
if deadlineTimer == nil {
deadlineTimer = utils.NewTimer()
}
deadlineTimer.Reset(deadline)
}
if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown {
break
}
s.mutex.Unlock()
if s.deadline.IsZero() {
if !notifiedSender {
s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
notifiedSender = true
}
if deadline.IsZero() {
<-s.writeChan
} else {
select {
case <-s.writeChan:
case <-s.deadlineTimer.C:
case <-deadlineTimer.Chan():
deadlineTimer.SetRead()
}
}
s.mutex.Lock()
}
if s.closeForShutdownErr != nil {
err = s.closeForShutdownErr
return bytesWritten, s.closeForShutdownErr
} else if s.cancelWriteErr != nil {
err = s.cancelWriteErr
return bytesWritten, s.cancelWriteErr
}
return bytesWritten, err
return bytesWritten, nil
}
// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
@ -202,10 +213,12 @@ func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, boo
var ret []byte
if protocol.ByteCount(len(s.dataForWriting)) > maxBytes {
ret = s.dataForWriting[:maxBytes]
ret = make([]byte, int(maxBytes))
copy(ret, s.dataForWriting[:maxBytes])
s.dataForWriting = s.dataForWriting[maxBytes:]
} else {
ret = s.dataForWriting
ret = make([]byte, len(s.dataForWriting))
copy(ret, s.dataForWriting)
s.dataForWriting = nil
s.signalWrite()
}
@ -216,13 +229,14 @@ func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, boo
func (s *sendStream) Close() error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.canceledWrite {
s.mutex.Unlock()
return fmt.Errorf("Close called for canceled stream %d", s.streamID)
}
s.finishedWriting = true
go s.sender.onHasStreamData(s.streamID) // need to send the FIN
s.mutex.Unlock()
s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
s.ctxCancel()
return nil
}
@ -233,7 +247,7 @@ func (s *sendStream) CancelWrite(errorCode protocol.ApplicationErrorCode) error
s.mutex.Unlock()
if completed {
s.sender.onStreamCompleted(s.streamID)
s.sender.onStreamCompleted(s.streamID) // must be called without holding the mutex
}
return err
}
@ -266,14 +280,11 @@ func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
}
func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) {
s.flowController.UpdateSendWindow(frame.ByteOffset)
s.mutex.Lock()
hasData := false
if s.dataForWriting != nil {
hasData = true
}
hasStreamData := s.dataForWriting != nil
s.mutex.Unlock()
if hasData {
s.flowController.UpdateSendWindow(frame.ByteOffset)
if hasStreamData {
s.sender.onHasStreamData(s.streamID)
}
}
@ -298,22 +309,9 @@ func (s *sendStream) Context() context.Context {
func (s *sendStream) SetWriteDeadline(t time.Time) error {
s.mutex.Lock()
defer s.mutex.Unlock()
s.deadline = t
if s.deadline.IsZero() { // skip if there's no deadline to set
s.signalWrite()
return nil
}
// Lazily initialize the deadline timer.
if s.deadlineTimer == nil {
s.deadlineTimer = time.NewTimer(time.Until(t))
return nil
}
// reset the timer to the new deadline
if !s.deadlineTimer.Stop() {
<-s.deadlineTimer.C
}
s.deadlineTimer.Reset(time.Until(t))
s.mutex.Unlock()
s.signalWrite()
return nil
}

View File

@ -43,6 +43,7 @@ type quicSession interface {
GetVersion() protocol.VersionNumber
run() error
destroy(error)
closeForRecreating() protocol.PacketNumber
closeRemote(error)
}
@ -317,21 +318,27 @@ func (s *server) handlePacket(p *receivedPacket) {
}
if hdr.Type == protocol.PacketTypeInitial {
go s.handleInitial(p)
return
}
// TODO(#943): send Stateless Reset
p.buffer.Release()
}
func (s *server) handleInitial(p *receivedPacket) {
// TODO: add a check that DestConnID == SrcConnID
s.logger.Debugf("<- Received Initial packet.")
sess, connID, err := s.handleInitialImpl(p)
if err != nil {
p.buffer.Release()
s.logger.Errorf("Error occurred handling initial packet: %s", err)
return
}
if sess == nil { // a retry was done
p.buffer.Release()
return
}
// Don't put the packet buffer back if a new session was created.
// The session will handle the packet and take of that.
serverSession := newServerSession(sess, s.config, s.logger)
s.sessionHandler.Add(connID, serverSession)
}
@ -454,6 +461,7 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
}
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) {
defer p.buffer.Release()
hdr := p.hdr
s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)

View File

@ -1,7 +1,6 @@
package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
@ -22,7 +21,7 @@ import (
)
type unpacker interface {
Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error)
Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error)
}
type streamGetter interface {
@ -54,8 +53,10 @@ type cryptoStreamHandler interface {
type receivedPacket struct {
remoteAddr net.Addr
hdr *wire.Header
data []byte
rcvTime time.Time
data []byte
buffer *packetBuffer
}
type closeError struct {
@ -64,6 +65,8 @@ type closeError struct {
sendClose bool
}
var errCloseForRecreating = errors.New("closing session in order to recreate it")
// A Session is a QUIC session
type session struct {
sessionRunner sessionRunner
@ -112,9 +115,8 @@ type session struct {
handshakeCompleteChan chan struct{} // is closed when the handshake completes
handshakeComplete bool
receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this
receivedFirstPacket bool
receivedFirstForwardSecurePacket bool
largestRcvdPacketNumber protocol.PacketNumber // used to calculate the next packet number
sessionCreationTime time.Time
lastNetworkActivityTime time.Time
@ -158,6 +160,7 @@ var newSession = func(
version: v,
}
s.preSetup()
s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.logger)
initialStream := newCryptoStream()
handshakeStream := newCryptoStream()
s.streamsMap = newStreamsMap(
@ -185,7 +188,6 @@ var newSession = func(
return nil, err
}
s.cryptoStreamHandler = cs
s.framer = newFramer(s.streamsMap, s.version)
s.packer = newPacketPacker(
s.destConnID,
s.srcConnID,
@ -219,6 +221,7 @@ var newClientSession = func(
srcConnID protocol.ConnectionID,
conf *Config,
tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber,
params *handshake.TransportParameters,
initialVersion protocol.VersionNumber,
logger utils.Logger,
@ -236,6 +239,7 @@ var newClientSession = func(
version: v,
}
s.preSetup()
s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.logger)
initialStream := newCryptoStream()
handshakeStream := newCryptoStream()
cs, clientHelloWritten, err := handshake.NewCryptoSetupClient(
@ -287,7 +291,6 @@ var newClientSession = func(
func (s *session) preSetup() {
s.rttStats = &congestion.RTTStats{}
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.InitialMaxData,
@ -361,18 +364,12 @@ runLoop:
// We do all the interesting stuff after the switch statement, so
// nothing to see here.
case p := <-s.receivedPackets:
err := s.handlePacketImpl(p)
if err != nil {
if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure {
s.tryQueueingUndecryptablePacket(p)
continue
}
s.closeLocal(err)
// Only reset the timers if this packet was actually processed.
// This avoids modifying any state when handling undecryptable packets,
// which could be injected by an attacker.
if wasProcessed := s.handlePacketImpl(p); !wasProcessed {
continue
}
// This is a bit unclean, but works properly, since the packet always
// begins with the public header and we never copy it.
// TODO: putPacketBuffer(&p.extHdr.Raw)
case <-s.handshakeCompleteChan:
s.handleHandshakeComplete()
}
@ -476,64 +473,61 @@ func (s *session) handleHandshakeComplete() {
}
}
func (s *session) handlePacketImpl(p *receivedPacket) error {
func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet successfully processed */ {
var wasQueued bool
defer func() {
// Put back the packet buffer if the packet wasn't queued for later decryption.
if !wasQueued {
p.buffer.Release()
}
}()
// The server can change the source connection ID with the first Handshake packet.
// After this, all packets with a different source connection have to be ignored.
if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) {
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.hdr.SrcConnectionID, s.destConnID)
return nil
return false
}
// drop 0-RTT packets
if p.hdr.Type == protocol.PacketType0RTT {
return false
}
data := p.data
r := bytes.NewReader(data)
hdr, err := p.hdr.ParseExtended(r, s.version)
if err != nil {
return fmt.Errorf("error parsing extended header: %s", err)
}
hdr.Raw = data[:len(data)-r.Len()]
data = data[len(data)-r.Len():]
if hdr.IsLongHeader {
if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) {
return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen)
}
if protocol.ByteCount(len(data))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length {
return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(hdr.PacketNumberLen), hdr.Length)
}
data = data[:int(hdr.Length)-int(hdr.PacketNumberLen)]
// TODO(#1312): implement parsing of compound packets
}
// Calculate packet number
hdr.PacketNumber = protocol.InferPacketNumber(
hdr.PacketNumberLen,
s.largestRcvdPacketNumber,
hdr.PacketNumber,
)
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)
if s.logger.Debug() {
if err != nil {
s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID)
} else {
s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s, %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID, packet.encryptionLevel)
}
hdr.Log(s.logger)
}
packet, err := s.unpacker.Unpack(p.hdr, p.data)
// if the decryption failed, this might be a packet sent by an attacker
if err != nil {
return err
if err == handshake.ErrOpenerNotYetAvailable {
wasQueued = true
s.tryQueueingUndecryptablePacket(p)
return false
}
s.closeLocal(err)
return false
}
if s.logger.Debug() {
s.logger.Debugf("<- Reading packet %#x (%d bytes) for connection %s, %s", packet.packetNumber, len(p.data), p.hdr.DestConnectionID, packet.encryptionLevel)
packet.hdr.Log(s.logger)
}
if err := s.handleUnpackedPacket(packet, p.rcvTime); err != nil {
s.closeLocal(err)
return false
}
return true
}
func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time) error {
// The server can change the source connection ID with the first Handshake packet.
if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) {
s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", hdr.SrcConnectionID)
s.destConnID = hdr.SrcConnectionID
if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.destConnID) {
s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", packet.hdr.SrcConnectionID)
s.destConnID = packet.hdr.SrcConnectionID
s.packer.ChangeDestConnectionID(s.destConnID)
}
s.receivedFirstPacket = true
s.lastNetworkActivityTime = p.rcvTime
s.lastNetworkActivityTime = rcvTime
s.keepAlivePingSent = false
// The client completes the handshake first (after sending the CFIN).
@ -545,19 +539,16 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
}
}
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber)
// If this is a Retry packet, there's no need to send an ACK.
// The session will be closed and recreated as soon as the crypto setup processed the HRR.
if hdr.Type != protocol.PacketTypeRetry {
if packet.hdr.Type != protocol.PacketTypeRetry {
isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames)
if err := s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, p.rcvTime, isRetransmittable); err != nil {
if err := s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, rcvTime, isRetransmittable); err != nil {
return err
}
}
return s.handleFrames(packet.frames, hdr.PacketNumber, packet.encryptionLevel)
return s.handleFrames(packet.frames, packet.packetNumber, packet.encryptionLevel)
}
func (s *session) handleFrames(fs []wire.Frame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error {
@ -740,6 +731,14 @@ func (s *session) destroy(e error) {
})
}
// closeForRecreating closes the session in order to recreate it immediately afterwards
// It returns the first packet number that should be used in the new session.
func (s *session) closeForRecreating() protocol.PacketNumber {
s.destroy(errCloseForRecreating)
nextPN, _ := s.sentPacketHandler.PeekPacketNumber()
return nextPN
}
func (s *session) closeRemote(e error) {
s.closeOnce.Do(func() {
s.sessionRunner.removeConnectionID(s.srcConnID)
@ -963,7 +962,7 @@ func (s *session) sendPacket() (bool, error) {
}
func (s *session) sendPackedPacket(packet *packedPacket) error {
defer putPacketBuffer(&packet.raw)
defer packet.buffer.Release()
s.logPacket(packet)
return s.conn.Write(packet.raw)
}
@ -986,7 +985,7 @@ func (s *session) logPacket(packet *packedPacket) {
// We don't need to allocate the slices for calling the format functions
return
}
s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.srcConnID, packet.encryptionLevel)
s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.srcConnID, packet.EncryptionLevel())
packet.header.Log(s.logger)
for _, frame := range packet.frames {
wire.LogFrame(s.logger, frame, true)

View File

@ -0,0 +1,57 @@
Copyright (c) 2017 Cloudflare. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Cloudflare nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
========================================================================
The x64 field arithmetic implementation was derived from the Microsoft Research
SIDH implementation, <https://github.com/Microsoft/PQCrypto-SIDH>, available
under the following license:
========================================================================
MIT License
Copyright (c) Microsoft Corporation. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE

View File

@ -0,0 +1,66 @@
// +build noasm !amd64
package internal
// helper used for Uint128 representation
type Uint128 struct {
H, L uint64
}
// Adds 2 64bit digits in constant time.
// Returns result and carry (1 or 0)
func Addc64(cin, a, b uint64) (ret, cout uint64) {
t := a + cin
ret = b + t
cout = ((a & b) | ((a | b) & (^ret))) >> 63
return
}
// Substracts 2 64bit digits in constant time.
// Returns result and borrow (1 or 0)
func Subc64(bIn, a, b uint64) (ret, bOut uint64) {
var tmp1 = a - b
// Set bOut if bIn!=0 and tmp1==0 in constant time
bOut = bIn & (1 ^ ((tmp1 | uint64(0-tmp1)) >> 63))
// Constant time check if x<y
bOut |= (a ^ ((a ^ b) | (uint64(a-b) ^ b))) >> 63
ret = tmp1 - bIn
return
}
// Multiplies 2 64bit digits in constant time
func Mul64(a, b uint64) (res Uint128) {
var al, bl, ah, bh, albl, albh, ahbl, ahbh uint64
var res1, res2, res3 uint64
var carry, maskL, maskH, temp uint64
maskL = (^maskL) >> 32
maskH = ^maskL
al = a & maskL
ah = a >> 32
bl = b & maskL
bh = b >> 32
albl = al * bl
albh = al * bh
ahbl = ah * bl
ahbh = ah * bh
res.L = albl & maskL
res1 = albl >> 32
res2 = ahbl & maskL
res3 = albh & maskL
temp = res1 + res2 + res3
carry = temp >> 32
res.L ^= temp << 32
res1 = ahbl >> 32
res2 = albh >> 32
res3 = ahbh & maskL
temp = res1 + res2 + res3 + carry
res.H = temp & maskL
carry = temp & maskH
res.H ^= (ahbh & maskH) + carry
return
}

View File

@ -0,0 +1,440 @@
package internal
type CurveOperations struct {
Params *SidhParams
}
// Computes j-invariant for a curve y2=x3+A/Cx+x with A,C in F_(p^2). Result
// is returned in jBytes buffer, encoded in little-endian format. Caller
// provided jBytes buffer has to be big enough to j-invariant value. In case
// of SIDH, buffer size must be at least size of shared secret.
// Implementation corresponds to Algorithm 9 from SIKE.
func (c *CurveOperations) Jinvariant(cparams *ProjectiveCurveParameters, jBytes []byte) {
var j, t0, t1 Fp2Element
op := c.Params.Op
op.Square(&j, &cparams.A) // j = A^2
op.Square(&t1, &cparams.C) // t1 = C^2
op.Add(&t0, &t1, &t1) // t0 = t1 + t1
op.Sub(&t0, &j, &t0) // t0 = j - t0
op.Sub(&t0, &t0, &t1) // t0 = t0 - t1
op.Sub(&j, &t0, &t1) // t0 = t0 - t1
op.Square(&t1, &t1) // t1 = t1^2
op.Mul(&j, &j, &t1) // j = j * t1
op.Add(&t0, &t0, &t0) // t0 = t0 + t0
op.Add(&t0, &t0, &t0) // t0 = t0 + t0
op.Square(&t1, &t0) // t1 = t0^2
op.Mul(&t0, &t0, &t1) // t0 = t0 * t1
op.Add(&t0, &t0, &t0) // t0 = t0 + t0
op.Add(&t0, &t0, &t0) // t0 = t0 + t0
op.Inv(&j, &j) // j = 1/j
op.Mul(&j, &t0, &j) // j = t0 * j
c.Fp2ToBytes(jBytes, &j)
}
// Given affine points x(P), x(Q) and x(Q-P) in a extension field F_{p^2}, function
// recorvers projective coordinate A of a curve. This is Algorithm 10 from SIKE.
func (c *CurveOperations) RecoverCoordinateA(curve *ProjectiveCurveParameters, xp, xq, xr *Fp2Element) {
var t0, t1 Fp2Element
op := c.Params.Op
op.Add(&t1, xp, xq) // t1 = Xp + Xq
op.Mul(&t0, xp, xq) // t0 = Xp * Xq
op.Mul(&curve.A, xr, &t1) // A = X(q-p) * t1
op.Add(&curve.A, &curve.A, &t0) // A = A + t0
op.Mul(&t0, &t0, xr) // t0 = t0 * X(q-p)
op.Sub(&curve.A, &curve.A, &c.Params.OneFp2) // A = A - 1
op.Add(&t0, &t0, &t0) // t0 = t0 + t0
op.Add(&t1, &t1, xr) // t1 = t1 + X(q-p)
op.Add(&t0, &t0, &t0) // t0 = t0 + t0
op.Square(&curve.A, &curve.A) // A = A^2
op.Inv(&t0, &t0) // t0 = 1/t0
op.Mul(&curve.A, &curve.A, &t0) // A = A * t0
op.Sub(&curve.A, &curve.A, &t1) // A = A - t1
}
// Computes equivalence (A:C) ~ (A+2C : A-2C)
func (c *CurveOperations) CalcCurveParamsEquiv3(cparams *ProjectiveCurveParameters) CurveCoefficientsEquiv {
var coef CurveCoefficientsEquiv
var c2 Fp2Element
var op = c.Params.Op
op.Add(&c2, &cparams.C, &cparams.C)
// A24p = A+2*C
op.Add(&coef.A, &cparams.A, &c2)
// A24m = A-2*C
op.Sub(&coef.C, &cparams.A, &c2)
return coef
}
// Computes equivalence (A:C) ~ (A+2C : 4C)
func (c *CurveOperations) CalcCurveParamsEquiv4(cparams *ProjectiveCurveParameters) CurveCoefficientsEquiv {
var coefEq CurveCoefficientsEquiv
var op = c.Params.Op
op.Add(&coefEq.C, &cparams.C, &cparams.C)
// A24p = A+2C
op.Add(&coefEq.A, &cparams.A, &coefEq.C)
// C24 = 4*C
op.Add(&coefEq.C, &coefEq.C, &coefEq.C)
return coefEq
}
// Helper function for RightToLeftLadder(). Returns A+2C / 4.
func (c *CurveOperations) CalcAplus2Over4(cparams *ProjectiveCurveParameters) (ret Fp2Element) {
var tmp Fp2Element
var op = c.Params.Op
// 2C
op.Add(&tmp, &cparams.C, &cparams.C)
// A+2C
op.Add(&ret, &cparams.A, &tmp)
// 1/4C
op.Add(&tmp, &tmp, &tmp)
op.Inv(&tmp, &tmp)
// A+2C/4C
op.Mul(&ret, &ret, &tmp)
return
}
// Recovers (A:C) curve parameters from projectively equivalent (A+2C:A-2C).
func (c *CurveOperations) RecoverCurveCoefficients3(cparams *ProjectiveCurveParameters, coefEq *CurveCoefficientsEquiv) {
var op = c.Params.Op
op.Add(&cparams.A, &coefEq.A, &coefEq.C)
// cparams.A = 2*(A+2C+A-2C) = 4A
op.Add(&cparams.A, &cparams.A, &cparams.A)
// cparams.C = (A+2C-A+2C) = 4C
op.Sub(&cparams.C, &coefEq.A, &coefEq.C)
return
}
// Recovers (A:C) curve parameters from projectively equivalent (A+2C:4C).
func (c *CurveOperations) RecoverCurveCoefficients4(cparams *ProjectiveCurveParameters, coefEq *CurveCoefficientsEquiv) {
var op = c.Params.Op
// cparams.C = (4C)*1/2=2C
op.Mul(&cparams.C, &coefEq.C, &c.Params.HalfFp2)
// cparams.A = A+2C - 2C = A
op.Sub(&cparams.A, &coefEq.A, &cparams.C)
// cparams.C = 2C * 1/2 = C
op.Mul(&cparams.C, &cparams.C, &c.Params.HalfFp2)
return
}
// Combined coordinate doubling and differential addition. Takes projective points
// P,Q,Q-P and (A+2C)/4C curve E coefficient. Returns 2*P and P+Q calculated on E.
// Function is used only by RightToLeftLadder. Corresponds to Algorithm 5 of SIKE
func (c *CurveOperations) xDblAdd(P, Q, QmP *ProjectivePoint, a24 *Fp2Element) (dblP, PaQ ProjectivePoint) {
var t0, t1, t2 Fp2Element
var op = c.Params.Op
xQmP, zQmP := &QmP.X, &QmP.Z
xPaQ, zPaQ := &PaQ.X, &PaQ.Z
x2P, z2P := &dblP.X, &dblP.Z
xP, zP := &P.X, &P.Z
xQ, zQ := &Q.X, &Q.Z
op.Add(&t0, xP, zP) // t0 = Xp+Zp
op.Sub(&t1, xP, zP) // t1 = Xp-Zp
op.Square(x2P, &t0) // 2P.X = t0^2
op.Sub(&t2, xQ, zQ) // t2 = Xq-Zq
op.Add(xPaQ, xQ, zQ) // Xp+q = Xq+Zq
op.Mul(&t0, &t0, &t2) // t0 = t0 * t2
op.Mul(z2P, &t1, &t1) // 2P.Z = t1 * t1
op.Mul(&t1, &t1, xPaQ) // t1 = t1 * Xp+q
op.Sub(&t2, x2P, z2P) // t2 = 2P.X - 2P.Z
op.Mul(x2P, x2P, z2P) // 2P.X = 2P.X * 2P.Z
op.Mul(xPaQ, a24, &t2) // Xp+q = A24 * t2
op.Sub(zPaQ, &t0, &t1) // Zp+q = t0 - t1
op.Add(z2P, xPaQ, z2P) // 2P.Z = Xp+q + 2P.Z
op.Add(xPaQ, &t0, &t1) // Xp+q = t0 + t1
op.Mul(z2P, z2P, &t2) // 2P.Z = 2P.Z * t2
op.Square(zPaQ, zPaQ) // Zp+q = Zp+q ^ 2
op.Square(xPaQ, xPaQ) // Xp+q = Xp+q ^ 2
op.Mul(zPaQ, xQmP, zPaQ) // Zp+q = Xq-p * Zp+q
op.Mul(xPaQ, zQmP, xPaQ) // Xp+q = Zq-p * Xp+q
return
}
// Given the curve parameters, xP = x(P), computes xP = x([2^k]P)
// Safe to overlap xP, x2P.
func (c *CurveOperations) Pow2k(xP *ProjectivePoint, params *CurveCoefficientsEquiv, k uint32) {
var t0, t1 Fp2Element
var op = c.Params.Op
x, z := &xP.X, &xP.Z
for i := uint32(0); i < k; i++ {
op.Sub(&t0, x, z) // t0 = Xp - Zp
op.Add(&t1, x, z) // t1 = Xp + Zp
op.Square(&t0, &t0) // t0 = t0 ^ 2
op.Square(&t1, &t1) // t1 = t1 ^ 2
op.Mul(z, &params.C, &t0) // Z2p = C24 * t0
op.Mul(x, z, &t1) // X2p = Z2p * t1
op.Sub(&t1, &t1, &t0) // t1 = t1 - t0
op.Mul(&t0, &params.A, &t1) // t0 = A24+ * t1
op.Add(z, z, &t0) // Z2p = Z2p + t0
op.Mul(z, z, &t1) // Zp = Z2p * t1
}
}
// Given the curve parameters, xP = x(P), and k >= 0, compute xP = x([3^k]P).
//
// Safe to overlap xP, xR.
func (c *CurveOperations) Pow3k(xP *ProjectivePoint, params *CurveCoefficientsEquiv, k uint32) {
var t0, t1, t2, t3, t4, t5, t6 Fp2Element
var op = c.Params.Op
x, z := &xP.X, &xP.Z
for i := uint32(0); i < k; i++ {
op.Sub(&t0, x, z) // t0 = Xp - Zp
op.Square(&t2, &t0) // t2 = t0^2
op.Add(&t1, x, z) // t1 = Xp + Zp
op.Square(&t3, &t1) // t3 = t1^2
op.Add(&t4, &t1, &t0) // t4 = t1 + t0
op.Sub(&t0, &t1, &t0) // t0 = t1 - t0
op.Square(&t1, &t4) // t1 = t4^2
op.Sub(&t1, &t1, &t3) // t1 = t1 - t3
op.Sub(&t1, &t1, &t2) // t1 = t1 - t2
op.Mul(&t5, &t3, &params.A) // t5 = t3 * A24+
op.Mul(&t3, &t3, &t5) // t3 = t5 * t3
op.Mul(&t6, &t2, &params.C) // t6 = t2 * A24-
op.Mul(&t2, &t2, &t6) // t2 = t2 * t6
op.Sub(&t3, &t2, &t3) // t3 = t2 - t3
op.Sub(&t2, &t5, &t6) // t2 = t5 - t6
op.Mul(&t1, &t2, &t1) // t1 = t2 * t1
op.Add(&t2, &t3, &t1) // t2 = t3 + t1
op.Square(&t2, &t2) // t2 = t2^2
op.Mul(x, &t2, &t4) // X3p = t2 * t4
op.Sub(&t1, &t3, &t1) // t1 = t3 - t1
op.Square(&t1, &t1) // t1 = t1^2
op.Mul(z, &t1, &t0) // Z3p = t1 * t0
}
}
// Set (y1, y2, y3) = (1/x1, 1/x2, 1/x3).
//
// All xi, yi must be distinct.
func (c *CurveOperations) Fp2Batch3Inv(x1, x2, x3, y1, y2, y3 *Fp2Element) {
var x1x2, t Fp2Element
var op = c.Params.Op
op.Mul(&x1x2, x1, x2) // x1*x2
op.Mul(&t, &x1x2, x3) // 1/(x1*x2*x3)
op.Inv(&t, &t)
op.Mul(y1, &t, x2) // 1/x1
op.Mul(y1, y1, x3)
op.Mul(y2, &t, x1) // 1/x2
op.Mul(y2, y2, x3)
op.Mul(y3, &t, &x1x2) // 1/x3
}
// ScalarMul3Pt is a right-to-left point multiplication that given the
// x-coordinate of P, Q and P-Q calculates the x-coordinate of R=Q+[scalar]P.
// nbits must be smaller or equal to len(scalar).
func (c *CurveOperations) ScalarMul3Pt(cparams *ProjectiveCurveParameters, P, Q, PmQ *ProjectivePoint, nbits uint, scalar []uint8) ProjectivePoint {
var R0, R2, R1 ProjectivePoint
var op = c.Params.Op
aPlus2Over4 := c.CalcAplus2Over4(cparams)
R1 = *P
R2 = *PmQ
R0 = *Q
// Iterate over the bits of the scalar, bottom to top
prevBit := uint8(0)
for i := uint(0); i < nbits; i++ {
bit := (scalar[i>>3] >> (i & 7) & 1)
swap := prevBit ^ bit
prevBit = bit
op.CondSwap(&R1.X, &R1.Z, &R2.X, &R2.Z, swap)
R0, R2 = c.xDblAdd(&R0, &R2, &R1, &aPlus2Over4)
}
op.CondSwap(&R1.X, &R1.Z, &R2.X, &R2.Z, prevBit)
return R1
}
// Convert the input to wire format.
//
// The output byte slice must be at least 2*bytelen(p) bytes long.
func (c *CurveOperations) Fp2ToBytes(output []byte, fp2 *Fp2Element) {
if len(output) < 2*c.Params.Bytelen {
panic("output byte slice too short")
}
var a Fp2Element
c.Params.Op.FromMontgomery(fp2, &a)
// convert to bytes in little endian form
for i := 0; i < c.Params.Bytelen; i++ {
// set i = j*8 + k
fp2 := i / 8
k := uint64(i % 8)
output[i] = byte(a.A[fp2] >> (8 * k))
output[i+c.Params.Bytelen] = byte(a.B[fp2] >> (8 * k))
}
}
// Read 2*bytelen(p) bytes into the given ExtensionFieldElement.
//
// It is an error to call this function if the input byte slice is less than 2*bytelen(p) bytes long.
func (c *CurveOperations) Fp2FromBytes(fp2 *Fp2Element, input []byte) {
if len(input) < 2*c.Params.Bytelen {
panic("input byte slice too short")
}
for i := 0; i < c.Params.Bytelen; i++ {
j := i / 8
k := uint64(i % 8)
fp2.A[j] |= uint64(input[i]) << (8 * k)
fp2.B[j] |= uint64(input[i+c.Params.Bytelen]) << (8 * k)
}
c.Params.Op.ToMontgomery(fp2)
}
/* -------------------------------------------------------------------------
Mechnisms used for isogeny calculations
-------------------------------------------------------------------------*/
// Constructs isogeny3 objects
func Newisogeny3(op FieldOps) Isogeny {
return &isogeny3{Field: op}
}
// Constructs isogeny4 objects
func Newisogeny4(op FieldOps) Isogeny {
return &isogeny4{isogeny3: isogeny3{Field: op}}
}
// Given a three-torsion point p = x(PB) on the curve E_(A:C), construct the
// three-isogeny phi : E_(A:C) -> E_(A:C)/<P_3> = E_(A':C').
//
// Input: (XP_3: ZP_3), where P_3 has exact order 3 on E_A/C
// Output: * Curve coordinates (A' + 2C', A' - 2C') corresponding to E_A'/C' = A_E/C/<P3>
// * Isogeny phi with constants in F_p^2
func (phi *isogeny3) GenerateCurve(p *ProjectivePoint) CurveCoefficientsEquiv {
var t0, t1, t2, t3, t4 Fp2Element
var coefEq CurveCoefficientsEquiv
var K1, K2 = &phi.K1, &phi.K2
op := phi.Field
op.Sub(K1, &p.X, &p.Z) // K1 = XP3 - ZP3
op.Square(&t0, K1) // t0 = K1^2
op.Add(K2, &p.X, &p.Z) // K2 = XP3 + ZP3
op.Square(&t1, K2) // t1 = K2^2
op.Add(&t2, &t0, &t1) // t2 = t0 + t1
op.Add(&t3, K1, K2) // t3 = K1 + K2
op.Square(&t3, &t3) // t3 = t3^2
op.Sub(&t3, &t3, &t2) // t3 = t3 - t2
op.Add(&t2, &t1, &t3) // t2 = t1 + t3
op.Add(&t3, &t3, &t0) // t3 = t3 + t0
op.Add(&t4, &t3, &t0) // t4 = t3 + t0
op.Add(&t4, &t4, &t4) // t4 = t4 + t4
op.Add(&t4, &t1, &t4) // t4 = t1 + t4
op.Mul(&coefEq.C, &t2, &t4) // A24m = t2 * t4
op.Add(&t4, &t1, &t2) // t4 = t1 + t2
op.Add(&t4, &t4, &t4) // t4 = t4 + t4
op.Add(&t4, &t0, &t4) // t4 = t0 + t4
op.Mul(&t4, &t3, &t4) // t4 = t3 * t4
op.Sub(&t0, &t4, &coefEq.C) // t0 = t4 - A24m
op.Add(&coefEq.A, &coefEq.C, &t0) // A24p = A24m + t0
return coefEq
}
// Given a 3-isogeny phi and a point pB = x(PB), compute x(QB), the x-coordinate
// of the image QB = phi(PB) of PB under phi : E_(A:C) -> E_(A':C').
//
// The output xQ = x(Q) is then a point on the curve E_(A':C'); the curve
// parameters are returned by the GenerateCurve function used to construct phi.
func (phi *isogeny3) EvaluatePoint(p *ProjectivePoint) ProjectivePoint {
var t0, t1, t2 Fp2Element
var q ProjectivePoint
var K1, K2 = &phi.K1, &phi.K2
var px, pz = &p.X, &p.Z
op := phi.Field
op.Add(&t0, px, pz) // t0 = XQ + ZQ
op.Sub(&t1, px, pz) // t1 = XQ - ZQ
op.Mul(&t0, K1, &t0) // t2 = K1 * t0
op.Mul(&t1, K2, &t1) // t1 = K2 * t1
op.Add(&t2, &t0, &t1) // t2 = t0 + t1
op.Sub(&t0, &t1, &t0) // t0 = t1 - t0
op.Square(&t2, &t2) // t2 = t2 ^ 2
op.Square(&t0, &t0) // t0 = t0 ^ 2
op.Mul(&q.X, px, &t2) // XQ'= XQ * t2
op.Mul(&q.Z, pz, &t0) // ZQ'= ZQ * t0
return q
}
// Given a four-torsion point p = x(PB) on the curve E_(A:C), construct the
// four-isogeny phi : E_(A:C) -> E_(A:C)/<P_4> = E_(A':C').
//
// Input: (XP_4: ZP_4), where P_4 has exact order 4 on E_A/C
// Output: * Curve coordinates (A' + 2C', 4C') corresponding to E_A'/C' = A_E/C/<P4>
// * Isogeny phi with constants in F_p^2
func (phi *isogeny4) GenerateCurve(p *ProjectivePoint) CurveCoefficientsEquiv {
var coefEq CurveCoefficientsEquiv
var xp4, zp4 = &p.X, &p.Z
var K1, K2, K3 = &phi.K1, &phi.K2, &phi.K3
op := phi.Field
op.Sub(K2, xp4, zp4)
op.Add(K3, xp4, zp4)
op.Square(K1, zp4)
op.Add(K1, K1, K1)
op.Square(&coefEq.C, K1)
op.Add(K1, K1, K1)
op.Square(&coefEq.A, xp4)
op.Add(&coefEq.A, &coefEq.A, &coefEq.A)
op.Square(&coefEq.A, &coefEq.A)
return coefEq
}
// Given a 4-isogeny phi and a point xP = x(P), compute x(Q), the x-coordinate
// of the image Q = phi(P) of P under phi : E_(A:C) -> E_(A':C').
//
// Input: Isogeny returned by GenerateCurve and point q=(Qx,Qz) from E0_A/C
// Output: Corresponding point q from E1_A'/C', where E1 is 4-isogenous to E0
func (phi *isogeny4) EvaluatePoint(p *ProjectivePoint) ProjectivePoint {
var t0, t1 Fp2Element
var q = *p
var xq, zq = &q.X, &q.Z
var K1, K2, K3 = &phi.K1, &phi.K2, &phi.K3
op := phi.Field
op.Add(&t0, xq, zq)
op.Sub(&t1, xq, zq)
op.Mul(xq, &t0, K2)
op.Mul(zq, &t1, K3)
op.Mul(&t0, &t0, &t1)
op.Mul(&t0, &t0, K1)
op.Add(&t1, xq, zq)
op.Sub(zq, xq, zq)
op.Square(&t1, &t1)
op.Square(zq, zq)
op.Add(xq, &t0, &t1)
op.Sub(&t0, zq, &t0)
op.Mul(xq, xq, &t1)
op.Mul(zq, zq, &t0)
return q
}
/* -------------------------------------------------------------------------
Utils
-------------------------------------------------------------------------*/
func (point *ProjectivePoint) ToAffine(c *CurveOperations) *Fp2Element {
var affine_x Fp2Element
c.Params.Op.Inv(&affine_x, &point.Z)
c.Params.Op.Mul(&affine_x, &affine_x, &point.X)
return &affine_x
}
// Cleans data in fp
func (fp *Fp2Element) Zeroize() {
// Zeroizing in 2 seperated loops tells compiler to
// use fast runtime.memclr()
for i := range fp.A {
fp.A[i] = 0
}
for i := range fp.B {
fp.B[i] = 0
}
}

View File

@ -0,0 +1,140 @@
package internal
const (
FP_MAX_WORDS = 12 // Currently p751.NumWords
)
// Representation of an element of the base field F_p.
//
// No particular meaning is assigned to the representation -- it could represent
// an element in Montgomery form, or not. Tracking the meaning of the field
// element is left to higher types.
type FpElement [FP_MAX_WORDS]uint64
// Represents an intermediate product of two elements of the base field F_p.
type FpElementX2 [2 * FP_MAX_WORDS]uint64
// Represents an element of the extended field Fp^2 = Fp(x+i)
type Fp2Element struct {
A FpElement
B FpElement
}
type DomainParams struct {
// P, Q and R=P-Q base points
Affine_P, Affine_Q, Affine_R Fp2Element
// Size of a compuatation strategy for x-torsion group
IsogenyStrategy []uint32
// Max size of secret key for x-torsion group
SecretBitLen uint
// Max size of secret key for x-torsion group
SecretByteLen uint
}
type SidhParams struct {
Id uint8
// Bytelen of P
Bytelen int
// The public key size, in bytes.
PublicKeySize int
// The shared secret size, in bytes.
SharedSecretSize uint
// 2- and 3-torsion group parameter definitions
A, B DomainParams
// Precomputed identity element in the Fp2 in Montgomery domain
OneFp2 Fp2Element
// Precomputed 1/2 in the Fp2 in Montgomery domain
HalfFp2 Fp2Element
// Length of SIKE secret message. Must be one of {24,32,40},
// depending on size of prime field used (see [SIKE], 1.4 and 5.1)
MsgLen uint
// Length of SIKE ephemeral KEM key (see [SIKE], 1.4 and 5.1)
KemSize uint
// Access to field arithmetic
Op FieldOps
}
// Interface for working with isogenies.
type Isogeny interface {
// Given a torsion point on a curve computes isogenous curve.
// Returns curve coefficients (A:C), so that E_(A/C) = E_(A/C)/<P>,
// where P is a provided projective point. Sets also isogeny constants
// that are needed for isogeny evaluation.
GenerateCurve(*ProjectivePoint) CurveCoefficientsEquiv
// Evaluates isogeny at caller provided point. Requires isogeny curve constants
// to be earlier computed by GenerateCurve.
EvaluatePoint(*ProjectivePoint) ProjectivePoint
}
// Stores curve projective parameters equivalent to A/C. Meaning of the
// values depends on the context. When working with isogenies over
// subgroup that are powers of:
// * three then (A:C) ~ (A+2C:A-2C)
// * four then (A:C) ~ (A+2C: 4C)
// See Appendix A of SIKE for more details
type CurveCoefficientsEquiv struct {
A Fp2Element
C Fp2Element
}
// A point on the projective line P^1(F_{p^2}).
//
// This represents a point on the Kummer line of a Montgomery curve. The
// curve is specified by a ProjectiveCurveParameters struct.
type ProjectivePoint struct {
X Fp2Element
Z Fp2Element
}
// A point on the projective line P^1(F_{p^2}).
//
// This is used to work projectively with the curve coefficients.
type ProjectiveCurveParameters struct {
A Fp2Element
C Fp2Element
}
// Stores Isogeny 3 curve constants
type isogeny3 struct {
Field FieldOps
K1 Fp2Element
K2 Fp2Element
}
// Stores Isogeny 4 curve constants
type isogeny4 struct {
isogeny3
K3 Fp2Element
}
type FieldOps interface {
// Set res = lhs + rhs.
//
// Allowed to overlap lhs or rhs with res.
Add(res, lhs, rhs *Fp2Element)
// Set res = lhs - rhs.
//
// Allowed to overlap lhs or rhs with res.
Sub(res, lhs, rhs *Fp2Element)
// Set res = lhs * rhs.
//
// Allowed to overlap lhs or rhs with res.
Mul(res, lhs, rhs *Fp2Element)
// Set res = x * x
//
// Allowed to overlap res with x.
Square(res, x *Fp2Element)
// Set res = 1/x
//
// Allowed to overlap res with x.
Inv(res, x *Fp2Element)
// If choice = 1u8, set (x,y) = (y,x). If choice = 0u8, set (x,y) = (x,y).
CondSwap(xPx, xPz, xQx, xQz *Fp2Element, choice uint8)
// Converts Fp2Element to Montgomery domain (x*R mod p)
ToMontgomery(x *Fp2Element)
// Converts 'a' in montgomery domain to element from Fp2Element
// and stores it in 'x'
FromMontgomery(x *Fp2Element, a *Fp2Element)
}

View File

@ -0,0 +1,11 @@
package utils
type x86 struct {
// Signals support for MULX which is in BMI2
HasBMI2 bool
// Signals support for ADX
HasADX bool
}
var X86 x86

View File

@ -0,0 +1,29 @@
// +build amd64,!noasm
// Sets capabilities flags for x86 according to information received from
// CPUID. It was written in accordance with
// "Intel® 64 and IA-32 Architectures Developer's Manual: Vol. 2A".
// https://www.intel.com/content/www/us/en/architecture-and-technology/64-ia-32-architectures-software-developer-vol-2a-manual.html
package utils
// Performs CPUID and returns values of registers
// go:nosplit
func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32)
// Returns true in case bit 'n' in 'bits' is set, otherwise false
func bitn(bits uint32, n uint8) bool {
return (bits>>n)&1 == 1
}
func init() {
// CPUID returns max possible input that can be requested
max, _, _, _ := cpuid(0, 0)
if max < 7 {
return
}
_, ebx, _, _ := cpuid(7, 0)
X86.HasBMI2 = bitn(ebx, 8)
X86.HasADX = bitn(ebx, 19)
}

View File

@ -0,0 +1,13 @@
// +build amd64,!noasm
#include "textflag.h"
TEXT ·cpuid(SB), NOSPLIT, $0-4
MOVL eaxArg+0(FP), AX
MOVL ecxArg+4(FP), CX
CPUID
MOVL AX, eax+8(FP)
MOVL BX, ebx+12(FP)
MOVL CX, ecx+16(FP)
MOVL DX, edx+20(FP)
RET

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,802 @@
// +build arm64,!noasm
#include "textflag.h"
TEXT ·fp503ConditionalSwap(SB), NOSPLIT, $0-17
MOVD x+0(FP), R0
MOVD y+8(FP), R1
MOVB choice+16(FP), R2
// Set flags
// If choice is not 0 or 1, this implementation will swap completely
CMP $0, R2
LDP 0(R0), (R3, R4)
LDP 0(R1), (R5, R6)
CSEL EQ, R3, R5, R7
CSEL EQ, R4, R6, R8
STP (R7, R8), 0(R0)
CSEL NE, R3, R5, R9
CSEL NE, R4, R6, R10
STP (R9, R10), 0(R1)
LDP 16(R0), (R3, R4)
LDP 16(R1), (R5, R6)
CSEL EQ, R3, R5, R7
CSEL EQ, R4, R6, R8
STP (R7, R8), 16(R0)
CSEL NE, R3, R5, R9
CSEL NE, R4, R6, R10
STP (R9, R10), 16(R1)
LDP 32(R0), (R3, R4)
LDP 32(R1), (R5, R6)
CSEL EQ, R3, R5, R7
CSEL EQ, R4, R6, R8
STP (R7, R8), 32(R0)
CSEL NE, R3, R5, R9
CSEL NE, R4, R6, R10
STP (R9, R10), 32(R1)
LDP 48(R0), (R3, R4)
LDP 48(R1), (R5, R6)
CSEL EQ, R3, R5, R7
CSEL EQ, R4, R6, R8
STP (R7, R8), 48(R0)
CSEL NE, R3, R5, R9
CSEL NE, R4, R6, R10
STP (R9, R10), 48(R1)
RET
TEXT ·fp503AddReduced(SB), NOSPLIT, $0-24
MOVD z+0(FP), R2
MOVD x+8(FP), R0
MOVD y+16(FP), R1
// Load first summand into R3-R10
// Add first summand and second summand and store result in R3-R10
LDP 0(R0), (R3, R4)
LDP 0(R1), (R11, R12)
LDP 16(R0), (R5, R6)
LDP 16(R1), (R13, R14)
ADDS R11, R3
ADCS R12, R4
ADCS R13, R5
ADCS R14, R6
LDP 32(R0), (R7, R8)
LDP 32(R1), (R11, R12)
LDP 48(R0), (R9, R10)
LDP 48(R1), (R13, R14)
ADCS R11, R7
ADCS R12, R8
ADCS R13, R9
ADC R14, R10
// Subtract 2 * p503 in R11-R17 from the result in R3-R10
LDP ·p503x2+0(SB), (R11, R12)
LDP ·p503x2+24(SB), (R13, R14)
SUBS R11, R3
SBCS R12, R4
LDP ·p503x2+40(SB), (R15, R16)
SBCS R12, R5
SBCS R13, R6
MOVD ·p503x2+56(SB), R17
SBCS R14, R7
SBCS R15, R8
SBCS R16, R9
SBCS R17, R10
SBC ZR, ZR, R19
// If x + y - 2 * p503 < 0, R19 is 1 and 2 * p503 should be added
AND R19, R11
AND R19, R12
AND R19, R13
AND R19, R14
AND R19, R15
AND R19, R16
AND R19, R17
ADDS R11, R3
ADCS R12, R4
STP (R3, R4), 0(R2)
ADCS R12, R5
ADCS R13, R6
STP (R5, R6), 16(R2)
ADCS R14, R7
ADCS R15, R8
STP (R7, R8), 32(R2)
ADCS R16, R9
ADC R17, R10
STP (R9, R10), 48(R2)
RET
TEXT ·fp503SubReduced(SB), NOSPLIT, $0-24
MOVD z+0(FP), R2
MOVD x+8(FP), R0
MOVD y+16(FP), R1
// Load x into R3-R10
// Subtract y from x and store result in R3-R10
LDP 0(R0), (R3, R4)
LDP 0(R1), (R11, R12)
LDP 16(R0), (R5, R6)
LDP 16(R1), (R13, R14)
SUBS R11, R3
SBCS R12, R4
SBCS R13, R5
SBCS R14, R6
LDP 32(R0), (R7, R8)
LDP 32(R1), (R11, R12)
LDP 48(R0), (R9, R10)
LDP 48(R1), (R13, R14)
SBCS R11, R7
SBCS R12, R8
SBCS R13, R9
SBCS R14, R10
SBC ZR, ZR, R19
// If x - y < 0, R19 is 1 and 2 * p503 should be added
LDP ·p503x2+0(SB), (R11, R12)
LDP ·p503x2+24(SB), (R13, R14)
AND R19, R11
AND R19, R12
LDP ·p503x2+40(SB), (R15, R16)
AND R19, R13
AND R19, R14
MOVD ·p503x2+56(SB), R17
AND R19, R15
AND R19, R16
AND R19, R17
ADDS R11, R3
ADCS R12, R4
STP (R3, R4), 0(R2)
ADCS R12, R5
ADCS R13, R6
STP (R5, R6), 16(R2)
ADCS R14, R7
ADCS R15, R8
STP (R7, R8), 32(R2)
ADCS R16, R9
ADC R17, R10
STP (R9, R10), 48(R2)
RET
TEXT ·fp503AddLazy(SB), NOSPLIT, $0-24
MOVD z+0(FP), R2
MOVD x+8(FP), R0
MOVD y+16(FP), R1
// Load first summand into R3-R10
// Add first summand and second summand and store result in R3-R10
LDP 0(R0), (R3, R4)
LDP 0(R1), (R11, R12)
LDP 16(R0), (R5, R6)
LDP 16(R1), (R13, R14)
ADDS R11, R3
ADCS R12, R4
STP (R3, R4), 0(R2)
ADCS R13, R5
ADCS R14, R6
STP (R5, R6), 16(R2)
LDP 32(R0), (R7, R8)
LDP 32(R1), (R11, R12)
LDP 48(R0), (R9, R10)
LDP 48(R1), (R13, R14)
ADCS R11, R7
ADCS R12, R8
STP (R7, R8), 32(R2)
ADCS R13, R9
ADC R14, R10
STP (R9, R10), 48(R2)
RET
TEXT ·fp503X2AddLazy(SB), NOSPLIT, $0-24
MOVD z+0(FP), R2
MOVD x+8(FP), R0
MOVD y+16(FP), R1
LDP 0(R0), (R3, R4)
LDP 0(R1), (R11, R12)
LDP 16(R0), (R5, R6)
LDP 16(R1), (R13, R14)
ADDS R11, R3
ADCS R12, R4
STP (R3, R4), 0(R2)
ADCS R13, R5
ADCS R14, R6
STP (R5, R6), 16(R2)
LDP 32(R0), (R7, R8)
LDP 32(R1), (R11, R12)
LDP 48(R0), (R9, R10)
LDP 48(R1), (R13, R14)
ADCS R11, R7
ADCS R12, R8
STP (R7, R8), 32(R2)
ADCS R13, R9
ADCS R14, R10
STP (R9, R10), 48(R2)
LDP 64(R0), (R3, R4)
LDP 64(R1), (R11, R12)
LDP 80(R0), (R5, R6)
LDP 80(R1), (R13, R14)
ADCS R11, R3
ADCS R12, R4
STP (R3, R4), 64(R2)
ADCS R13, R5
ADCS R14, R6
STP (R5, R6), 80(R2)
LDP 96(R0), (R7, R8)
LDP 96(R1), (R11, R12)
LDP 112(R0), (R9, R10)
LDP 112(R1), (R13, R14)
ADCS R11, R7
ADCS R12, R8
STP (R7, R8), 96(R2)
ADCS R13, R9
ADC R14, R10
STP (R9, R10), 112(R2)
RET
TEXT ·fp503X2SubLazy(SB), NOSPLIT, $0-24
MOVD z+0(FP), R2
MOVD x+8(FP), R0
MOVD y+16(FP), R1
LDP 0(R0), (R3, R4)
LDP 0(R1), (R11, R12)
LDP 16(R0), (R5, R6)
LDP 16(R1), (R13, R14)
SUBS R11, R3
SBCS R12, R4
STP (R3, R4), 0(R2)
SBCS R13, R5
SBCS R14, R6
STP (R5, R6), 16(R2)
LDP 32(R0), (R7, R8)
LDP 32(R1), (R11, R12)
LDP 48(R0), (R9, R10)
LDP 48(R1), (R13, R14)
SBCS R11, R7
SBCS R12, R8
STP (R7, R8), 32(R2)
SBCS R13, R9
SBCS R14, R10
STP (R9, R10), 48(R2)
LDP 64(R0), (R3, R4)
LDP 64(R1), (R11, R12)
LDP 80(R0), (R5, R6)
LDP 80(R1), (R13, R14)
SBCS R11, R3
SBCS R12, R4
SBCS R13, R5
SBCS R14, R6
LDP 96(R0), (R7, R8)
LDP 96(R1), (R11, R12)
LDP 112(R0), (R9, R10)
LDP 112(R1), (R13, R14)
SBCS R11, R7
SBCS R12, R8
SBCS R13, R9
SBCS R14, R10
SBC ZR, ZR, R15
// If x - y < 0, R15 is 1 and p503 should be added
LDP ·p503+16(SB), (R16, R17)
LDP ·p503+32(SB), (R19, R20)
AND R15, R16
AND R15, R17
LDP ·p503+48(SB), (R21, R22)
AND R15, R19
AND R15, R20
AND R15, R21
AND R15, R22
ADDS R16, R3
ADCS R16, R4
STP (R3, R4), 64(R2)
ADCS R16, R5
ADCS R17, R6
STP (R5, R6), 80(R2)
ADCS R19, R7
ADCS R20, R8
STP (R7, R8), 96(R2)
ADCS R21, R9
ADC R22, R10
STP (R9, R10), 112(R2)
RET
// Expects that X0*Y0 is already in Z0(low),Z3(high) and X0*Y1 in Z1(low),Z2(high)
// Z0 is not actually touched
// Result of (X0-X1) * (Y0-Y1) will be in Z0-Z3
// Inputs get overwritten, except for X1
#define mul128x128comba(X0, X1, Y0, Y1, Z0, Z1, Z2, Z3, T0) \
MUL X1, Y0, X0 \
UMULH X1, Y0, Y0 \
ADDS Z3, Z1 \
ADC ZR, Z2 \
\
MUL Y1, X1, T0 \
UMULH Y1, X1, Y1 \
ADDS X0, Z1 \
ADCS Y0, Z2 \
ADC ZR, ZR, Z3 \
\
ADDS T0, Z2 \
ADC Y1, Z3
// Expects that X points to (X0-X1)
// Result of (X0-X3) * (Y0-Y3) will be in Z0-Z7
// Inputs get overwritten, except X2-X3 and Y2-Y3
#define mul256x256karatsuba(X, X0, X1, X2, X3, Y0, Y1, Y2, Y3, Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7, T0, T1)\
ADDS X2, X0 \ // xH + xL, destroys xL
ADCS X3, X1 \
ADCS ZR, ZR, T0 \
\
ADDS Y2, Y0, Z6 \ // yH + yL
ADCS Y3, Y1, T1 \
ADC ZR, ZR, Z7 \
\
SUB T0, ZR, Z2 \
SUB Z7, ZR, Z3 \
AND Z7, T0 \ // combined carry
\
AND Z2, Z6, Z0 \ // masked(yH + yL)
AND Z2, T1, Z1 \
\
AND Z3, X0, Z4 \ // masked(xH + xL)
AND Z3, X1, Z5 \
\
MUL Z6, X0, Z2 \
MUL T1, X0, Z3 \
\
ADDS Z4, Z0 \
UMULH T1, X0, Z4 \
ADCS Z5, Z1 \
UMULH Z6, X0, Z5 \
ADC ZR, T0 \
\ // (xH + xL) * (yH + yL)
mul128x128comba(X0, X1, Z6, T1, Z2, Z3, Z4, Z5, Z7)\
\
LDP 0+X, (X0, X1) \
\
ADDS Z0, Z4 \
UMULH Y0, X0, Z7 \
UMULH Y1, X0, T1 \
ADCS Z1, Z5 \
MUL Y0, X0, Z0 \
MUL Y1, X0, Z1 \
ADC ZR, T0 \
\ // xL * yL
mul128x128comba(X0, X1, Y0, Y1, Z0, Z1, T1, Z7, Z6)\
\
MUL Y2, X2, X0 \
UMULH Y2, X2, Y0 \
SUBS Z0, Z2 \ // (xH + xL) * (yH + yL) - xL * yL
SBCS Z1, Z3 \
SBCS T1, Z4 \
MUL Y3, X2, X1 \
UMULH Y3, X2, Z6 \
SBCS Z7, Z5 \
SBCS ZR, T0 \
\ // xH * yH
mul128x128comba(X2, X3, Y2, Y3, X0, X1, Z6, Y0, Y1)\
\
SUBS X0, Z2 \ // (xH + xL) * (yH + yL) - xL * yL - xH * yH
SBCS X1, Z3 \
SBCS Z6, Z4 \
SBCS Y0, Z5 \
SBCS ZR, T0 \
\
ADDS T1, Z2 \ // (xH * yH) * 2^256 + ((xH + xL) * (yH + yL) - xL * yL - xH * yH) * 2^128 + xL * yL
ADCS Z7, Z3 \
ADCS X0, Z4 \
ADCS X1, Z5 \
ADCS T0, Z6 \
ADC Y0, ZR, Z7
// This implements two-level Karatsuba with a 128x128 Comba multiplier
// at the bottom
TEXT ·fp503Mul(SB), NOSPLIT, $0-24
MOVD z+0(FP), R2
MOVD x+8(FP), R0
MOVD y+16(FP), R1
// Load xL in R3-R6, xH in R7-R10
// (xH + xL) in R25-R29
LDP 0(R0), (R3, R4)
LDP 32(R0), (R7, R8)
ADDS R3, R7, R25
ADCS R4, R8, R26
LDP 16(R0), (R5, R6)
LDP 48(R0), (R9, R10)
ADCS R5, R9, R27
ADCS R6, R10, R29
ADC ZR, ZR, R7
// Load yL in R11-R14, yH in R15-19
// (yH + yL) in R11-R14, destroys yL
LDP 0(R1), (R11, R12)
LDP 32(R1), (R15, R16)
ADDS R15, R11
ADCS R16, R12
LDP 16(R1), (R13, R14)
LDP 48(R1), (R17, R19)
ADCS R17, R13
ADCS R19, R14
ADC ZR, ZR, R8
// Compute maskes and combined carry
SUB R7, ZR, R9
SUB R8, ZR, R10
AND R8, R7
// masked(yH + yL)
AND R9, R11, R15
AND R9, R12, R16
AND R9, R13, R17
AND R9, R14, R19
// masked(xH + xL)
AND R10, R25, R20
AND R10, R26, R21
AND R10, R27, R22
AND R10, R29, R23
// masked(xH + xL) + masked(yH + yL) in R15-R19
ADDS R20, R15
ADCS R21, R16
ADCS R22, R17
ADCS R23, R19
ADC ZR, R7
// Use z as temporary storage
STP (R25, R26), 0(R2)
// (xH + xL) * (yH + yL)
mul256x256karatsuba(0(R2), R25, R26, R27, R29, R11, R12, R13, R14, R8, R9, R10, R20, R21, R22, R23, R24, R0, R1)
MOVD x+8(FP), R0
MOVD y+16(FP), R1
ADDS R21, R15
ADCS R22, R16
ADCS R23, R17
ADCS R24, R19
ADC ZR, R7
// Load yL in R11-R14
LDP 0(R1), (R11, R12)
LDP 16(R1), (R13, R14)
// xL * yL
mul256x256karatsuba(0(R0), R3, R4, R5, R6, R11, R12, R13, R14, R21, R22, R23, R24, R25, R26, R27, R29, R1, R2)
MOVD z+0(FP), R2
MOVD y+16(FP), R1
// (xH + xL) * (yH + yL) - xL * yL
SUBS R21, R8
SBCS R22, R9
STP (R21, R22), 0(R2)
SBCS R23, R10
SBCS R24, R20
STP (R23, R24), 16(R2)
SBCS R25, R15
SBCS R26, R16
SBCS R27, R17
SBCS R29, R19
SBC ZR, R7
// Load xH in R3-R6, yH in R11-R14
LDP 32(R0), (R3, R4)
LDP 48(R0), (R5, R6)
LDP 32(R1), (R11, R12)
LDP 48(R1), (R13, R14)
ADDS R25, R8
ADCS R26, R9
ADCS R27, R10
ADCS R29, R20
ADC ZR, ZR, R1
MOVD R20, 32(R2)
// xH * yH
mul256x256karatsuba(32(R0), R3, R4, R5, R6, R11, R12, R13, R14, R21, R22, R23, R24, R25, R26, R27, R29, R2, R20)
NEG R1, R1
MOVD z+0(FP), R2
MOVD 32(R2), R20
// (xH + xL) * (yH + yL) - xL * yL - xH * yH in R8-R10,R20,R15-R19
// Store lower half in z, that's done
SUBS R21, R8
SBCS R22, R9
STP (R8, R9), 32(R2)
SBCS R23, R10
SBCS R24, R20
STP (R10, R20), 48(R2)
SBCS R25, R15
SBCS R26, R16
SBCS R27, R17
SBCS R29, R19
SBC ZR, R7
// (xH * yH) * 2^512 + ((xH + xL) * (yH + yL) - xL * yL - xH * yH) * 2^256 + xL * yL
// Store remaining limbs in z
ADDS $1, R1
ADCS R21, R15
ADCS R22, R16
STP (R15, R16), 64(R2)
ADCS R23, R17
ADCS R24, R19
STP (R17, R19), 80(R2)
ADCS R7, R25
ADCS ZR, R26
STP (R25, R26), 96(R2)
ADCS ZR, R27
ADC ZR, R29
STP (R27, R29), 112(R2)
RET
// Expects that X0*Y0 is already in Z0(low),Z3(high) and X0*Y1 in Z1(low),Z2(high)
// Z0 is not actually touched
// Result of (X0-X1) * (Y0-Y3) will be in Z0-Z5
// Inputs remain intact
#define mul128x256comba(X0, X1, Y0, Y1, Y2, Y3, Z0, Z1, Z2, Z3, Z4, Z5, T0, T1, T2, T3)\
MUL X1, Y0, T0 \
UMULH X1, Y0, T1 \
ADDS Z3, Z1 \
ADC ZR, Z2 \
\
MUL X0, Y2, T2 \
UMULH X0, Y2, T3 \
ADDS T0, Z1 \
ADCS T1, Z2 \
ADC ZR, ZR, Z3 \
\
MUL X1, Y1, T0 \
UMULH X1, Y1, T1 \
ADDS T2, Z2 \
ADCS T3, Z3 \
ADC ZR, ZR, Z4 \
\
MUL X0, Y3, T2 \
UMULH X0, Y3, T3 \
ADDS T0, Z2 \
ADCS T1, Z3 \
ADC ZR, Z4 \
\
MUL X1, Y2, T0 \
UMULH X1, Y2, T1 \
ADDS T2, Z3 \
ADCS T3, Z4 \
ADC ZR, ZR, Z5 \
\
MUL X1, Y3, T2 \
UMULH X1, Y3, T3 \
ADDS T0, Z3 \
ADCS T1, Z4 \
ADC ZR, Z5 \
ADDS T2, Z4 \
ADC T3, Z5
// This implements the shifted 2^(B*w) Montgomery reduction from
// https://eprint.iacr.org/2016/986.pdf, section Section 3.2, with
// B = 4, w = 64. Performance results were reported in
// https://eprint.iacr.org/2018/700.pdf Section 6.
TEXT ·fp503MontgomeryReduce(SB), NOSPLIT, $0-16
MOVD x+8(FP), R0
// Load x0-x1
LDP 0(R0), (R2, R3)
// Load the prime constant in R25-R29
LDP ·p503p1s8+32(SB), (R25, R26)
LDP ·p503p1s8+48(SB), (R27, R29)
// [x0,x1] * p503p1s8 to R4-R9
MUL R2, R25, R4 // x0 * p503p1s8[0]
UMULH R2, R25, R7
MUL R2, R26, R5 // x0 * p503p1s8[1]
UMULH R2, R26, R6
mul128x256comba(R2, R3, R25, R26, R27, R29, R4, R5, R6, R7, R8, R9, R10, R11, R12, R13)
LDP 16(R0), (R3, R11) // x2
LDP 32(R0), (R12, R13)
LDP 48(R0), (R14, R15)
// Left-shift result in R4-R9 by 56 to R4-R10
ORR R9>>8, ZR, R10
LSL $56, R9
ORR R8>>8, R9
LSL $56, R8
ORR R7>>8, R8
LSL $56, R7
ORR R6>>8, R7
LSL $56, R6
ORR R5>>8, R6
LSL $56, R5
ORR R4>>8, R5
LSL $56, R4
ADDS R4, R11 // x3
ADCS R5, R12 // x4
ADCS R6, R13
ADCS R7, R14
ADCS R8, R15
LDP 64(R0), (R16, R17)
LDP 80(R0), (R19, R20)
MUL R3, R25, R4 // x2 * p503p1s8[0]
UMULH R3, R25, R7
ADCS R9, R16
ADCS R10, R17
ADCS ZR, R19
ADCS ZR, R20
LDP 96(R0), (R21, R22)
LDP 112(R0), (R23, R24)
MUL R3, R26, R5 // x2 * p503p1s8[1]
UMULH R3, R26, R6
ADCS ZR, R21
ADCS ZR, R22
ADCS ZR, R23
ADC ZR, R24
// [x2,x3] * p503p1s8 to R4-R9
mul128x256comba(R3, R11, R25, R26, R27, R29, R4, R5, R6, R7, R8, R9, R10, R0, R1, R2)
ORR R9>>8, ZR, R10
LSL $56, R9
ORR R8>>8, R9
LSL $56, R8
ORR R7>>8, R8
LSL $56, R7
ORR R6>>8, R7
LSL $56, R6
ORR R5>>8, R6
LSL $56, R5
ORR R4>>8, R5
LSL $56, R4
ADDS R4, R13 // x5
ADCS R5, R14 // x6
ADCS R6, R15
ADCS R7, R16
MUL R12, R25, R4 // x4 * p503p1s8[0]
UMULH R12, R25, R7
ADCS R8, R17
ADCS R9, R19
ADCS R10, R20
ADCS ZR, R21
MUL R12, R26, R5 // x4 * p503p1s8[1]
UMULH R12, R26, R6
ADCS ZR, R22
ADCS ZR, R23
ADC ZR, R24
// [x4,x5] * p503p1s8 to R4-R9
mul128x256comba(R12, R13, R25, R26, R27, R29, R4, R5, R6, R7, R8, R9, R10, R0, R1, R2)
ORR R9>>8, ZR, R10
LSL $56, R9
ORR R8>>8, R9
LSL $56, R8
ORR R7>>8, R8
LSL $56, R7
ORR R6>>8, R7
LSL $56, R6
ORR R5>>8, R6
LSL $56, R5
ORR R4>>8, R5
LSL $56, R4
ADDS R4, R15 // x7
ADCS R5, R16 // x8
ADCS R6, R17
ADCS R7, R19
MUL R14, R25, R4 // x6 * p503p1s8[0]
UMULH R14, R25, R7
ADCS R8, R20
ADCS R9, R21
ADCS R10, R22
MUL R14, R26, R5 // x6 * p503p1s8[1]
UMULH R14, R26, R6
ADCS ZR, R23
ADC ZR, R24
// [x6,x7] * p503p1s8 to R4-R9
mul128x256comba(R14, R15, R25, R26, R27, R29, R4, R5, R6, R7, R8, R9, R10, R0, R1, R2)
ORR R9>>8, ZR, R10
LSL $56, R9
ORR R8>>8, R9
LSL $56, R8
ORR R7>>8, R8
LSL $56, R7
ORR R6>>8, R7
LSL $56, R6
ORR R5>>8, R6
LSL $56, R5
ORR R4>>8, R5
LSL $56, R4
MOVD z+0(FP), R0
ADDS R4, R17
ADCS R5, R19
STP (R16, R17), 0(R0) // Store final result to z
ADCS R6, R20
ADCS R7, R21
STP (R19, R20), 16(R0)
ADCS R8, R22
ADCS R9, R23
STP (R21, R22), 32(R0)
ADC R10, R24
STP (R23, R24), 48(R0)
RET
TEXT ·fp503StrongReduce(SB), NOSPLIT, $0-8
MOVD x+0(FP), R0
// Keep x in R1-R8, p503 in R9-R14, subtract to R1-R8
LDP ·p503+16(SB), (R9, R10)
LDP 0(R0), (R1, R2)
LDP 16(R0), (R3, R4)
SUBS R9, R1
SBCS R9, R2
LDP 32(R0), (R5, R6)
LDP ·p503+32(SB), (R11, R12)
SBCS R9, R3
SBCS R10, R4
LDP 48(R0), (R7, R8)
LDP ·p503+48(SB), (R13, R14)
SBCS R11, R5
SBCS R12, R6
SBCS R13, R7
SBCS R14, R8
SBC ZR, ZR, R15
// Mask with the borrow and add p503
AND R15, R9
AND R15, R10
AND R15, R11
AND R15, R12
AND R15, R13
AND R15, R14
ADDS R9, R1
ADCS R9, R2
STP (R1, R2), 0(R0)
ADCS R9, R3
ADCS R10, R4
STP (R3, R4), 16(R0)
ADCS R11, R5
ADCS R12, R6
STP (R5, R6), 32(R0)
ADCS R13, R7
ADCS R14, R8
STP (R7, R8), 48(R0)
RET

View File

@ -0,0 +1,46 @@
// +build amd64,!noasm arm64,!noasm
package p503
import (
. "github.com/cloudflare/sidh/internal/isogeny"
)
// If choice = 0, leave x,y unchanged. If choice = 1, set x,y = y,x.
// If choice is neither 0 nor 1 then behaviour is undefined.
// This function executes in constant time.
//go:noescape
func fp503ConditionalSwap(x, y *FpElement, choice uint8)
// Compute z = x + y (mod p).
//go:noescape
func fp503AddReduced(z, x, y *FpElement)
// Compute z = x - y (mod p).
//go:noescape
func fp503SubReduced(z, x, y *FpElement)
// Compute z = x + y, without reducing mod p.
//go:noescape
func fp503AddLazy(z, x, y *FpElement)
// Compute z = x + y, without reducing mod p.
//go:noescape
func fp503X2AddLazy(z, x, y *FpElementX2)
// Compute z = x - y, without reducing mod p.
//go:noescape
func fp503X2SubLazy(z, x, y *FpElementX2)
// Reduce a field element in [0, 2*p) to one in [0,p).
//go:noescape
func fp503StrongReduce(x *FpElement)
// Computes z = x * y.
//go:noescape
func fp503Mul(z *FpElementX2, x, y *FpElement)
// Computes the Montgomery reduction z = x R^{-1} (mod 2*p). On return value
// of x may be changed. z=x not allowed.
//go:noescape
func fp503MontgomeryReduce(z *FpElement, x *FpElementX2)

View File

@ -0,0 +1,197 @@
// +build noasm !amd64,!arm64
package p503
import (
. "github.com/cloudflare/sidh/internal/arith"
. "github.com/cloudflare/sidh/internal/isogeny"
)
// Compute z = x + y (mod p).
func fp503AddReduced(z, x, y *FpElement) {
var carry uint64
// z=x+y % p503
for i := 0; i < NumWords; i++ {
z[i], carry = Addc64(carry, x[i], y[i])
}
// z = z - p503x2
carry = 0
for i := 0; i < NumWords; i++ {
z[i], carry = Subc64(carry, z[i], p503x2[i])
}
// if z<0 add p503x2 back
mask := uint64(0 - carry)
carry = 0
for i := 0; i < NumWords; i++ {
z[i], carry = Addc64(carry, z[i], p503x2[i]&mask)
}
}
// Compute z = x - y (mod p).
func fp503SubReduced(z, x, y *FpElement) {
var borrow uint64
// z = z - p503x2
for i := 0; i < NumWords; i++ {
z[i], borrow = Subc64(borrow, x[i], y[i])
}
// if z<0 add p503x2 back
mask := uint64(0 - borrow)
borrow = 0
for i := 0; i < NumWords; i++ {
z[i], borrow = Addc64(borrow, z[i], p503x2[i]&mask)
}
}
// Conditionally swaps bits in x and y in constant time.
// mask indicates bits to be swapped (set bits are swapped)
// For details see "Hackers Delight, 2.20"
//
// Implementation doesn't actually depend on a prime field.
func fp503ConditionalSwap(x, y *FpElement, mask uint8) {
var tmp, mask64 uint64
mask64 = 0 - uint64(mask)
for i := 0; i < NumWords; i++ {
tmp = mask64 & (x[i] ^ y[i])
x[i] = tmp ^ x[i]
y[i] = tmp ^ y[i]
}
}
// Perform Montgomery reduction: set z = x R^{-1} (mod 2*p)
// with R=2^512. Destroys the input value.
func fp503MontgomeryReduce(z *FpElement, x *FpElementX2) {
var carry, t, u, v uint64
var uv Uint128
var count int
count = 3 // number of 0 digits in the least significat part of p503 + 1
for i := 0; i < NumWords; i++ {
for j := 0; j < i; j++ {
if j < (i - count + 1) {
uv = Mul64(z[j], p503p1[i-j])
v, carry = Addc64(0, uv.L, v)
u, carry = Addc64(carry, uv.H, u)
t += carry
}
}
v, carry = Addc64(0, v, x[i])
u, carry = Addc64(carry, u, 0)
t += carry
z[i] = v
v = u
u = t
t = 0
}
for i := NumWords; i < 2*NumWords-1; i++ {
if count > 0 {
count--
}
for j := i - NumWords + 1; j < NumWords; j++ {
if j < (NumWords - count) {
uv = Mul64(z[j], p503p1[i-j])
v, carry = Addc64(0, uv.L, v)
u, carry = Addc64(carry, uv.H, u)
t += carry
}
}
v, carry = Addc64(0, v, x[i])
u, carry = Addc64(carry, u, 0)
t += carry
z[i-NumWords] = v
v = u
u = t
t = 0
}
v, carry = Addc64(0, v, x[2*NumWords-1])
z[NumWords-1] = v
}
// Compute z = x * y.
func fp503Mul(z *FpElementX2, x, y *FpElement) {
var u, v, t uint64
var carry uint64
var uv Uint128
for i := uint64(0); i < NumWords; i++ {
for j := uint64(0); j <= i; j++ {
uv = Mul64(x[j], y[i-j])
v, carry = Addc64(0, uv.L, v)
u, carry = Addc64(carry, uv.H, u)
t += carry
}
z[i] = v
v = u
u = t
t = 0
}
for i := NumWords; i < (2*NumWords)-1; i++ {
for j := i - NumWords + 1; j < NumWords; j++ {
uv = Mul64(x[j], y[i-j])
v, carry = Addc64(0, uv.L, v)
u, carry = Addc64(carry, uv.H, u)
t += carry
}
z[i] = v
v = u
u = t
t = 0
}
z[2*NumWords-1] = v
}
// Compute z = x + y, without reducing mod p.
func fp503AddLazy(z, x, y *FpElement) {
var carry uint64
for i := 0; i < NumWords; i++ {
z[i], carry = Addc64(carry, x[i], y[i])
}
}
// Compute z = x + y, without reducing mod p.
func fp503X2AddLazy(z, x, y *FpElementX2) {
var carry uint64
for i := 0; i < 2*NumWords; i++ {
z[i], carry = Addc64(carry, x[i], y[i])
}
}
// Reduce a field element in [0, 2*p) to one in [0,p).
func fp503StrongReduce(x *FpElement) {
var borrow, mask uint64
for i := 0; i < NumWords; i++ {
x[i], borrow = Subc64(borrow, x[i], p503[i])
}
// Sets all bits if borrow = 1
mask = 0 - borrow
borrow = 0
for i := 0; i < NumWords; i++ {
x[i], borrow = Addc64(borrow, x[i], p503[i]&mask)
}
}
// Compute z = x - y, without reducing mod p.
func fp503X2SubLazy(z, x, y *FpElementX2) {
var borrow, mask uint64
for i := 0; i < 2*NumWords; i++ {
z[i], borrow = Subc64(borrow, x[i], y[i])
}
// Sets all bits if borrow = 1
mask = 0 - borrow
borrow = 0
for i := NumWords; i < 2*NumWords; i++ {
z[i], borrow = Addc64(borrow, z[i], p503[i-NumWords]&mask)
}
}

View File

@ -0,0 +1,178 @@
package p503
import (
. "github.com/cloudflare/sidh/internal/isogeny"
cpu "github.com/cloudflare/sidh/internal/utils"
)
const (
// SIDH public key byte size
P503_PublicKeySize = 378
// SIDH shared secret byte size.
P503_SharedSecretSize = 126
// Max size of secret key for 2-torsion group, corresponds to 2^e2 - 1
P503_SecretBitLenA = 250
// Size of secret key for 3-torsion group, corresponds to log_2(3^e3) - 1
P503_SecretBitLenB = 252
// Size of a compuatation strategy for 2-torsion group
strategySizeA = 124
// Size of a compuatation strategy for 3-torsion group
strategySizeB = 158
// ceil(503+7/8)
P503_Bytelen = 63
// Number of limbs for a field element
NumWords = 8
)
// CPU Capabilities. Those flags are referred by assembly code. According to
// https://github.com/golang/go/issues/28230, variables referred from the
// assembly must be in the same package.
// We declare them variables not constants in order to facilitate testing.
var (
// Signals support for MULX which is in BMI2
HasBMI2 = cpu.X86.HasBMI2
// Signals support for ADX and BMI2
HasADXandBMI2 = cpu.X86.HasBMI2 && cpu.X86.HasADX
)
// The x-coordinate of PA
var P503_affine_PA = Fp2Element{
A: FpElement{
0xE7EF4AA786D855AF, 0xED5758F03EB34D3B, 0x09AE172535A86AA9, 0x237B9CC07D622723,
0xE3A284CBA4E7932D, 0x27481D9176C5E63F, 0x6A323FF55C6E71BF, 0x002ECC31A6FB8773,
},
B: FpElement{
0x64D02E4E90A620B8, 0xDAB8128537D4B9F1, 0x4BADF77B8A228F98, 0x0F5DBDF9D1FB7D1B,
0xBEC4DB288E1A0DCC, 0xE76A8665E80675DB, 0x6D6F252E12929463, 0x003188BD1463FACC,
},
}
// The x-coordinate of QA
var P503_affine_QA = Fp2Element{
A: FpElement{
0xB79D41025DE85D56, 0x0B867DA9DF169686, 0x740E5368021C827D, 0x20615D72157BF25C,
0xFF1590013C9B9F5B, 0xC884DCADE8C16CEA, 0xEBD05E53BF724E01, 0x0032FEF8FDA5748C,
},
B: FpElement{
0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
},
}
// The x-coordinate of RA = PA-QA
var P503_affine_RA = Fp2Element{
A: FpElement{
0x12E2E849AA0A8006, 0x41CF47008635A1E8, 0x9CD720A70798AED7, 0x42A820B42FCF04CF,
0x7BF9BAD32AAE88B1, 0xF619127A54090BBE, 0x1CB10D8F56408EAA, 0x001D6B54C3C0EDEB,
},
B: FpElement{
0x34DB54931CBAAC36, 0x420A18CB8DD5F0C4, 0x32008C1A48C0F44D, 0x3B3BA772B1CFD44D,
0xA74B058FDAF13515, 0x095FC9CA7EEC17B4, 0x448E829D28F120F8, 0x00261EC3ED16A489,
},
}
// The x-coordinate of PB
var P503_affine_PB = Fp2Element{
A: FpElement{
0x7EDE37F4FA0BC727, 0xF7F8EC5C8598941C, 0xD15519B516B5F5C8, 0xF6D5AC9B87A36282,
0x7B19F105B30E952E, 0x13BD8B2025B4EBEE, 0x7B96D27F4EC579A2, 0x00140850CAB7E5DE,
},
B: FpElement{
0x7764909DAE7B7B2D, 0x578ABB16284911AB, 0x76E2BFD146A6BF4D, 0x4824044B23AA02F0,
0x1105048912A321F3, 0xB8A2E482CF0F10C1, 0x42FF7D0BE2152085, 0x0018E599C5223352,
},
}
// The x-coordinate of QB
var P503_affine_QB = Fp2Element{
A: FpElement{
0x4256C520FB388820, 0x744FD7C3BAAF0A13, 0x4B6A2DDDB12CBCB8, 0xE46826E27F427DF8,
0xFE4A663CD505A61B, 0xD6B3A1BAF025C695, 0x7C3BB62B8FCC00BD, 0x003AFDDE4A35746C,
},
B: FpElement{
0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
},
}
// The x-coordinate of RB = PB - QB
var P503_affine_RB = Fp2Element{
A: FpElement{
0x75601CD1E6C0DFCB, 0x1A9007239B58F93E, 0xC1F1BE80C62107AC, 0x7F513B898F29FF08,
0xEA0BEDFF43E1F7B2, 0x2C6D94018CBAE6D0, 0x3A430D31BCD84672, 0x000D26892ECCFE83,
},
B: FpElement{
0x1119D62AEA3007A1, 0xE3702AA4E04BAE1B, 0x9AB96F7D59F990E7, 0xF58440E8B43319C0,
0xAF8134BEE1489775, 0xE7F7774E905192AA, 0xF54AE09308E98039, 0x001EF7A041A86112,
},
}
// 2-torsion group computation strategy
var P503_AliceIsogenyStrategy = [strategySizeA]uint32{
0x3D, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02,
0x01, 0x01, 0x08, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01,
0x01, 0x10, 0x08, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01,
0x01, 0x08, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01,
0x1D, 0x10, 0x08, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01,
0x01, 0x08, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01,
0x0D, 0x08, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01,
0x05, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x01}
// 3-torsion group computation strategy
var P503_BobIsogenyStrategy = [strategySizeB]uint32{
0x47, 0x26, 0x15, 0x0D, 0x08, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01,
0x02, 0x01, 0x01, 0x05, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x01, 0x09,
0x05, 0x03, 0x02, 0x01, 0x01, 0x01, 0x01, 0x02, 0x01, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x01,
0x02, 0x01, 0x01, 0x11, 0x09, 0x05, 0x03, 0x02, 0x01, 0x01, 0x01, 0x01, 0x02, 0x01, 0x01, 0x01,
0x04, 0x02, 0x01, 0x01, 0x01, 0x02, 0x01, 0x01, 0x08, 0x04, 0x02, 0x01, 0x01, 0x01, 0x02, 0x01,
0x01, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x21, 0x11, 0x09, 0x05, 0x03, 0x02, 0x01, 0x01,
0x01, 0x01, 0x02, 0x01, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x01, 0x02, 0x01, 0x01, 0x08, 0x04,
0x02, 0x01, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x10, 0x08,
0x04, 0x02, 0x01, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x08,
0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01}
// Used internally by this package
// -------------------------------
var p503 = FpElement{
0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xABFFFFFFFFFFFFFF,
0x13085BDA2211E7A0, 0x1B9BF6C87B7E7DAF, 0x6045C6BDDA77A4D0, 0x004066F541811E1E,
}
// 2*503
var p503x2 = FpElement{
0xFFFFFFFFFFFFFFFE, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x57FFFFFFFFFFFFFF,
0x2610B7B44423CF41, 0x3737ED90F6FCFB5E, 0xC08B8D7BB4EF49A0, 0x0080CDEA83023C3C,
}
// p503 + 1
var p503p1 = FpElement{
0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0xAC00000000000000,
0x13085BDA2211E7A0, 0x1B9BF6C87B7E7DAF, 0x6045C6BDDA77A4D0, 0x004066F541811E1E,
}
// R^2=(2^512)^2 mod p
var p503R2 = FpElement{
0x5289A0CF641D011F, 0x9B88257189FED2B9, 0xA3B365D58DC8F17A, 0x5BC57AB6EFF168EC,
0x9E51998BD84D4423, 0xBF8999CBAC3B5695, 0x46E9127BCE14CDB6, 0x003F6CFCE8B81771,
}
// p503 + 1 left-shifted by 8, assuming little endianness
var p503p1s8 = FpElement{
0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
0x085BDA2211E7A0AC, 0x9BF6C87B7E7DAF13, 0x45C6BDDA77A4D01B, 0x4066F541811E1E60,
}
// 1*R mod p
var P503_OneFp2 = Fp2Element{
A: FpElement{
0x00000000000003F9, 0x0000000000000000, 0x0000000000000000, 0xB400000000000000,
0x63CB1A6EA6DED2B4, 0x51689D8D667EB37D, 0x8ACD77C71AB24142, 0x0026FBAEC60F5953},
}
// 1/2 * R mod p
var P503_HalfFp2 = Fp2Element{
A: FpElement{
0x00000000000001FC, 0x0000000000000000, 0x0000000000000000, 0xB000000000000000,
0x3B69BB2464785D2A, 0x36824A2AF0FE9896, 0xF5899F427A94F309, 0x0033B15203C83BB8},
}

View File

@ -0,0 +1,249 @@
package p503
import (
. "github.com/cloudflare/sidh/internal/isogeny"
)
type fp503Ops struct{}
func FieldOperations() FieldOps {
return &fp503Ops{}
}
func (fp503Ops) Add(dest, lhs, rhs *Fp2Element) {
fp503AddReduced(&dest.A, &lhs.A, &rhs.A)
fp503AddReduced(&dest.B, &lhs.B, &rhs.B)
}
func (fp503Ops) Sub(dest, lhs, rhs *Fp2Element) {
fp503SubReduced(&dest.A, &lhs.A, &rhs.A)
fp503SubReduced(&dest.B, &lhs.B, &rhs.B)
}
func (fp503Ops) Mul(dest, lhs, rhs *Fp2Element) {
// Let (a,b,c,d) = (lhs.a,lhs.b,rhs.a,rhs.b).
a := &lhs.A
b := &lhs.B
c := &rhs.A
d := &rhs.B
// We want to compute
//
// (a + bi)*(c + di) = (a*c - b*d) + (a*d + b*c)i
//
// Use Karatsuba's trick: note that
//
// (b - a)*(c - d) = (b*c + a*d) - a*c - b*d
//
// so (a*d + b*c) = (b-a)*(c-d) + a*c + b*d.
var ac, bd FpElementX2
fp503Mul(&ac, a, c) // = a*c*R*R
fp503Mul(&bd, b, d) // = b*d*R*R
var b_minus_a, c_minus_d FpElement
fp503SubReduced(&b_minus_a, b, a) // = (b-a)*R
fp503SubReduced(&c_minus_d, c, d) // = (c-d)*R
var ad_plus_bc FpElementX2
fp503Mul(&ad_plus_bc, &b_minus_a, &c_minus_d) // = (b-a)*(c-d)*R*R
fp503X2AddLazy(&ad_plus_bc, &ad_plus_bc, &ac) // = ((b-a)*(c-d) + a*c)*R*R
fp503X2AddLazy(&ad_plus_bc, &ad_plus_bc, &bd) // = ((b-a)*(c-d) + a*c + b*d)*R*R
fp503MontgomeryReduce(&dest.B, &ad_plus_bc) // = (a*d + b*c)*R mod p
var ac_minus_bd FpElementX2
fp503X2SubLazy(&ac_minus_bd, &ac, &bd) // = (a*c - b*d)*R*R
fp503MontgomeryReduce(&dest.A, &ac_minus_bd) // = (a*c - b*d)*R mod p
}
// Set dest = 1/x
//
// Allowed to overlap dest with x.
//
// Returns dest to allow chaining operations.
func (fp503Ops) Inv(dest, x *Fp2Element) {
a := &x.A
b := &x.B
// We want to compute
//
// 1 1 (a - bi) (a - bi)
// -------- = -------- -------- = -----------
// (a + bi) (a + bi) (a - bi) (a^2 + b^2)
//
// Letting c = 1/(a^2 + b^2), this is
//
// 1/(a+bi) = a*c - b*ci.
var asq_plus_bsq primeFieldElement
var asq, bsq FpElementX2
fp503Mul(&asq, a, a) // = a*a*R*R
fp503Mul(&bsq, b, b) // = b*b*R*R
fp503X2AddLazy(&asq, &asq, &bsq) // = (a^2 + b^2)*R*R
fp503MontgomeryReduce(&asq_plus_bsq.A, &asq) // = (a^2 + b^2)*R mod p
// Now asq_plus_bsq = a^2 + b^2
inv := asq_plus_bsq
inv.Mul(&asq_plus_bsq, &asq_plus_bsq)
inv.P34(&inv)
inv.Mul(&inv, &inv)
inv.Mul(&inv, &asq_plus_bsq)
var ac FpElementX2
fp503Mul(&ac, a, &inv.A)
fp503MontgomeryReduce(&dest.A, &ac)
var minus_b FpElement
fp503SubReduced(&minus_b, &minus_b, b)
var minus_bc FpElementX2
fp503Mul(&minus_bc, &minus_b, &inv.A)
fp503MontgomeryReduce(&dest.B, &minus_bc)
}
func (fp503Ops) Square(dest, x *Fp2Element) {
a := &x.A
b := &x.B
// We want to compute
//
// (a + bi)*(a + bi) = (a^2 - b^2) + 2abi.
var a2, a_plus_b, a_minus_b FpElement
fp503AddReduced(&a2, a, a) // = a*R + a*R = 2*a*R
fp503AddReduced(&a_plus_b, a, b) // = a*R + b*R = (a+b)*R
fp503SubReduced(&a_minus_b, a, b) // = a*R - b*R = (a-b)*R
var asq_minus_bsq, ab2 FpElementX2
fp503Mul(&asq_minus_bsq, &a_plus_b, &a_minus_b) // = (a+b)*(a-b)*R*R = (a^2 - b^2)*R*R
fp503Mul(&ab2, &a2, b) // = 2*a*b*R*R
fp503MontgomeryReduce(&dest.A, &asq_minus_bsq) // = (a^2 - b^2)*R mod p
fp503MontgomeryReduce(&dest.B, &ab2) // = 2*a*b*R mod p
}
// In case choice == 1, performs following swap in constant time:
// xPx <-> xQx
// xPz <-> xQz
// Otherwise returns xPx, xPz, xQx, xQz unchanged
func (fp503Ops) CondSwap(xPx, xPz, xQx, xQz *Fp2Element, choice uint8) {
fp503ConditionalSwap(&xPx.A, &xQx.A, choice)
fp503ConditionalSwap(&xPx.B, &xQx.B, choice)
fp503ConditionalSwap(&xPz.A, &xQz.A, choice)
fp503ConditionalSwap(&xPz.B, &xQz.B, choice)
}
// Converts values in x.A and x.B to Montgomery domain
// x.A = x.A * R mod p
// x.B = x.B * R mod p
// Performs v = v*R^2*R^(-1) mod p, for both x.A and x.B
func (fp503Ops) ToMontgomery(x *Fp2Element) {
var aRR FpElementX2
// convert to montgomery domain
fp503Mul(&aRR, &x.A, &p503R2) // = a*R*R
fp503MontgomeryReduce(&x.A, &aRR) // = a*R mod p
fp503Mul(&aRR, &x.B, &p503R2)
fp503MontgomeryReduce(&x.B, &aRR)
}
// Converts values in x.A and x.B from Montgomery domain
// a = x.A mod p
// b = x.B mod p
//
// After returning from the call x is not modified.
func (fp503Ops) FromMontgomery(x *Fp2Element, out *Fp2Element) {
var aR FpElementX2
// convert from montgomery domain
// TODO: make fpXXXMontgomeryReduce use stack instead of reusing aR
// so that we don't have do this copy here
copy(aR[:], x.A[:])
fp503MontgomeryReduce(&out.A, &aR) // = a mod p in [0, 2p)
fp503StrongReduce(&out.A) // = a mod p in [0, p)
for i := range aR {
aR[i] = 0
}
copy(aR[:], x.B[:])
fp503MontgomeryReduce(&out.B, &aR)
fp503StrongReduce(&out.B)
}
//------------------------------------------------------------------------------
// Prime Field
//------------------------------------------------------------------------------
// Represents an element of the prime field F_p.
type primeFieldElement struct {
// This field element is in Montgomery form, so that the value `A` is
// represented by `aR mod p`.
A FpElement
}
// Set dest = lhs * rhs.
//
// Allowed to overlap lhs or rhs with dest.
//
// Returns dest to allow chaining operations.
func (dest *primeFieldElement) Mul(lhs, rhs *primeFieldElement) *primeFieldElement {
a := &lhs.A // = a*R
b := &rhs.A // = b*R
var ab FpElementX2
fp503Mul(&ab, a, b) // = a*b*R*R
fp503MontgomeryReduce(&dest.A, &ab) // = a*b*R mod p
return dest
}
// Set dest = x^(2^k), for k >= 1, by repeated squarings.
//
// Allowed to overlap x with dest.
//
// Returns dest to allow chaining operations.
func (dest *primeFieldElement) Pow2k(x *primeFieldElement, k uint8) *primeFieldElement {
dest.Mul(x, x)
for i := uint8(1); i < k; i++ {
dest.Mul(dest, dest)
}
return dest
}
// Set dest = x^((p-3)/4). If x is square, this is 1/sqrt(x).
// Uses variation of sliding-window algorithm from with window size
// of 5 and least to most significant bit sliding (left-to-right)
// See HAC 14.85 for general description.
//
// Allowed to overlap x with dest.
//
// Returns dest to allow chaining operations.
func (dest *primeFieldElement) P34(x *primeFieldElement) *primeFieldElement {
// Sliding-window strategy computed with etc/scripts/sliding_window_strat_calc.py
//
// This performs sum(powStrategy) + 1 squarings and len(lookup) + len(mulStrategy)
// multiplications.
powStrategy := []uint8{1, 12, 5, 5, 2, 7, 11, 3, 8, 4, 11, 4, 7, 5, 6, 3, 7, 5, 7, 2, 12, 5, 6, 4, 6, 8, 6, 4, 7, 5, 5, 8, 5, 8, 5, 5, 8, 9, 3, 6, 2, 10, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 3}
mulStrategy := []uint8{0, 12, 11, 10, 0, 1, 8, 3, 7, 1, 8, 3, 6, 7, 14, 2, 14, 14, 9, 0, 13, 9, 15, 5, 12, 7, 13, 7, 15, 6, 7, 9, 0, 5, 7, 6, 8, 8, 3, 7, 0, 10, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 3}
// Precompute lookup table of odd multiples of x for window
// size k=5.
lookup := [16]primeFieldElement{}
xx := &primeFieldElement{}
xx.Mul(x, x)
lookup[0] = *x
for i := 1; i < 16; i++ {
lookup[i].Mul(&lookup[i-1], xx)
}
// Now lookup = {x, x^3, x^5, ... }
// so that lookup[i] = x^{2*i + 1}
// so that lookup[k/2] = x^k, for odd k
*dest = lookup[mulStrategy[0]]
for i := uint8(1); i < uint8(len(powStrategy)); i++ {
dest.Pow2k(dest, powStrategy[i])
dest.Mul(dest, &lookup[mulStrategy[i]])
}
return dest
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,46 @@
// +build amd64,!noasm arm64,!noasm
package p751
import (
. "github.com/cloudflare/sidh/internal/isogeny"
)
// If choice = 0, leave x,y unchanged. If choice = 1, set x,y = y,x.
// If choice is neither 0 nor 1 then behaviour is undefined.
// This function executes in constant time.
//go:noescape
func fp751ConditionalSwap(x, y *FpElement, choice uint8)
// Compute z = x + y (mod p).
//go:noescape
func fp751AddReduced(z, x, y *FpElement)
// Compute z = x - y (mod p).
//go:noescape
func fp751SubReduced(z, x, y *FpElement)
// Compute z = x + y, without reducing mod p.
//go:noescape
func fp751AddLazy(z, x, y *FpElement)
// Compute z = x + y, without reducing mod p.
//go:noescape
func fp751X2AddLazy(z, x, y *FpElementX2)
// Compute z = x - y, without reducing mod p.
//go:noescape
func fp751X2SubLazy(z, x, y *FpElementX2)
// Compute z = x * y.
//go:noescape
func fp751Mul(z *FpElementX2, x, y *FpElement)
// Compute Montgomery reduction: set z = x * R^{-1} (mod 2*p).
// It may destroy the input value.
//go:noescape
func fp751MontgomeryReduce(z *FpElement, x *FpElementX2)
// Reduce a field element in [0, 2*p) to one in [0,p).
//go:noescape
func fp751StrongReduce(x *FpElement)

View File

@ -0,0 +1,196 @@
// +build noasm !amd64,!arm64
package p751
import (
. "github.com/cloudflare/sidh/internal/arith"
. "github.com/cloudflare/sidh/internal/isogeny"
)
// Compute z = x + y (mod p).
func fp751AddReduced(z, x, y *FpElement) {
var carry uint64
// z=x+y % p751
for i := 0; i < NumWords; i++ {
z[i], carry = Addc64(carry, x[i], y[i])
}
// z = z - p751x2
carry = 0
for i := 0; i < NumWords; i++ {
z[i], carry = Subc64(carry, z[i], p751x2[i])
}
// z = z + p751x2
mask := uint64(0 - carry)
carry = 0
for i := 0; i < NumWords; i++ {
z[i], carry = Addc64(carry, z[i], p751x2[i]&mask)
}
}
// Compute z = x - y (mod p).
func fp751SubReduced(z, x, y *FpElement) {
var borrow uint64
for i := 0; i < NumWords; i++ {
z[i], borrow = Subc64(borrow, x[i], y[i])
}
mask := uint64(0 - borrow)
borrow = 0
for i := 0; i < NumWords; i++ {
z[i], borrow = Addc64(borrow, z[i], p751x2[i]&mask)
}
}
// Conditionally swaps bits in x and y in constant time.
// mask indicates bits to be swaped (set bits are swapped)
// For details see "Hackers Delight, 2.20"
//
// Implementation doesn't actually depend on a prime field.
func fp751ConditionalSwap(x, y *FpElement, mask uint8) {
var tmp, mask64 uint64
mask64 = 0 - uint64(mask)
for i := 0; i < len(x); i++ {
tmp = mask64 & (x[i] ^ y[i])
x[i] = tmp ^ x[i]
y[i] = tmp ^ y[i]
}
}
// Perform Montgomery reduction: set z = x R^{-1} (mod 2*p)
// with R=2^768. Destroys the input value.
func fp751MontgomeryReduce(z *FpElement, x *FpElementX2) {
var carry, t, u, v uint64
var uv Uint128
var count int
count = 5 // number of 0 digits in the least significat part of p751 + 1
for i := 0; i < NumWords; i++ {
for j := 0; j < i; j++ {
if j < (i - count + 1) {
uv = Mul64(z[j], p751p1[i-j])
v, carry = Addc64(0, uv.L, v)
u, carry = Addc64(carry, uv.H, u)
t += carry
}
}
v, carry = Addc64(0, v, x[i])
u, carry = Addc64(carry, u, 0)
t += carry
z[i] = v
v = u
u = t
t = 0
}
for i := NumWords; i < 2*NumWords-1; i++ {
if count > 0 {
count--
}
for j := i - NumWords + 1; j < NumWords; j++ {
if j < (NumWords - count) {
uv = Mul64(z[j], p751p1[i-j])
v, carry = Addc64(0, uv.L, v)
u, carry = Addc64(carry, uv.H, u)
t += carry
}
}
v, carry = Addc64(0, v, x[i])
u, carry = Addc64(carry, u, 0)
t += carry
z[i-NumWords] = v
v = u
u = t
t = 0
}
v, carry = Addc64(0, v, x[2*NumWords-1])
z[NumWords-1] = v
}
// Compute z = x * y.
func fp751Mul(z *FpElementX2, x, y *FpElement) {
var u, v, t uint64
var carry uint64
var uv Uint128
for i := uint64(0); i < NumWords; i++ {
for j := uint64(0); j <= i; j++ {
uv = Mul64(x[j], y[i-j])
v, carry = Addc64(0, uv.L, v)
u, carry = Addc64(carry, uv.H, u)
t += carry
}
z[i] = v
v = u
u = t
t = 0
}
for i := NumWords; i < (2*NumWords)-1; i++ {
for j := i - NumWords + 1; j < NumWords; j++ {
uv = Mul64(x[j], y[i-j])
v, carry = Addc64(0, uv.L, v)
u, carry = Addc64(carry, uv.H, u)
t += carry
}
z[i] = v
v = u
u = t
t = 0
}
z[2*NumWords-1] = v
}
// Compute z = x + y, without reducing mod p.
func fp751AddLazy(z, x, y *FpElement) {
var carry uint64
for i := 0; i < NumWords; i++ {
z[i], carry = Addc64(carry, x[i], y[i])
}
}
// Compute z = x + y, without reducing mod p.
func fp751X2AddLazy(z, x, y *FpElementX2) {
var carry uint64
for i := 0; i < 2*NumWords; i++ {
z[i], carry = Addc64(carry, x[i], y[i])
}
}
// Reduce a field element in [0, 2*p) to one in [0,p).
func fp751StrongReduce(x *FpElement) {
var borrow, mask uint64
for i := 0; i < NumWords; i++ {
x[i], borrow = Subc64(borrow, x[i], p751[i])
}
// Sets all bits if borrow = 1
mask = 0 - borrow
borrow = 0
for i := 0; i < NumWords; i++ {
x[i], borrow = Addc64(borrow, x[i], p751[i]&mask)
}
}
// Compute z = x - y, without reducing mod p.
func fp751X2SubLazy(z, x, y *FpElementX2) {
var borrow, mask uint64
for i := 0; i < len(z); i++ {
z[i], borrow = Subc64(borrow, x[i], y[i])
}
// Sets all bits if borrow = 1
mask = 0 - borrow
borrow = 0
for i := NumWords; i < len(z); i++ {
z[i], borrow = Addc64(borrow, z[i], p751[i-NumWords]&mask)
}
}

View File

@ -0,0 +1,227 @@
package p751
import (
. "github.com/cloudflare/sidh/internal/isogeny"
cpu "github.com/cloudflare/sidh/internal/utils"
)
const (
// SIDH public key byte size
P751_PublicKeySize = 564
// SIDH shared secret byte size.
P751_SharedSecretSize = 188
// Max size of secret key for 2-torsion group, corresponds to 2^e2
P751_SecretBitLenA = 372
// Size of secret key for 3-torsion group, corresponds to floor(log_2(3^e3))
P751_SecretBitLenB = 378
// P751 bytelen ceil(751/8)
P751_Bytelen = 94
// Size of a compuatation strategy for 2-torsion group
strategySizeA = 185
// Size of a compuatation strategy for 3-torsion group
strategySizeB = 238
// Number of 64-bit limbs used to store Fp element
NumWords = 12
)
// CPU Capabilities. Those flags are referred by assembly code. According to
// https://github.com/golang/go/issues/28230, variables referred from the
// assembly must be in the same package.
// We declare them variables not constants in order to facilitate testing.
var (
// Signals support for MULX which is in BMI2
HasBMI2 = cpu.X86.HasBMI2
// Signals support for ADX and BMI2
HasADXandBMI2 = cpu.X86.HasBMI2 && cpu.X86.HasADX
)
// The x-coordinate of PA
var P751_affine_PA = Fp2Element{
A: FpElement{
0xC2FC08CEAB50AD8B, 0x1D7D710F55E457B1, 0xE8738D92953DCD6E,
0xBAA7EBEE8A3418AA, 0xC9A288345F03F46F, 0xC8D18D167CFE2616,
0x02043761F6B1C045, 0xAA1975E13180E7E9, 0x9E13D3FDC6690DE6,
0x3A024640A3A3BB4F, 0x4E5AD44E6ACBBDAE, 0x0000544BEB561DAD,
},
B: FpElement{
0xE6CC41D21582E411, 0x07C2ECB7C5DF400A, 0xE8E34B521432AEC4,
0x50761E2AB085167D, 0x032CFBCAA6094B3C, 0x6C522F5FDF9DDD71,
0x1319217DC3A1887D, 0xDC4FB25803353A86, 0x362C8D7B63A6AB09,
0x39DCDFBCE47EA488, 0x4C27C99A2C28D409, 0x00003CB0075527C4,
},
}
// The x-coordinate of QA
var P751_affine_QA = Fp2Element{
A: FpElement{
0xD56FE52627914862, 0x1FAD60DC96B5BAEA, 0x01E137D0BF07AB91,
0x404D3E9252161964, 0x3C5385E4CD09A337, 0x4476426769E4AF73,
0x9790C6DB989DFE33, 0xE06E1C04D2AA8B5E, 0x38C08185EDEA73B9,
0xAA41F678A4396CA6, 0x92B9259B2229E9A0, 0x00002F9326818BE0,
},
B: FpElement{
0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
},
}
// The x-coordinate of RA = PA-QA
var P751_affine_RA = Fp2Element{
A: FpElement{
0x0BB84441DFFD19B3, 0x84B4DEA99B48C18E, 0x692DE648AD313805,
0xE6D72761B6DFAEE0, 0x223975C672C3058D, 0xA0FDE0C3CBA26FDC,
0xA5326132A922A3CA, 0xCA5E7F5D5EA96FA4, 0x127C7EFE33FFA8C6,
0x4749B1567E2A23C4, 0x2B7DF5B4AF413BFA, 0x0000656595B9623C,
},
B: FpElement{
0xED78C17F1EC71BE8, 0xF824D6DF753859B1, 0x33A10839B2A8529F,
0xFC03E9E25FDEA796, 0xC4708A8054DF1762, 0x4034F2EC034C6467,
0xABFB70FBF06ECC79, 0xDABE96636EC108B7, 0x49CBCFB090605FD3,
0x20B89711819A45A7, 0xFB8E1590B2B0F63E, 0x0000556A5F964AB2,
},
}
// The x-coordinate of PB
var P751_affine_PB = Fp2Element{
A: FpElement{
0xCFB6D71EF867AB0B, 0x4A5FDD76E9A45C76, 0x38B1EE69194B1F03,
0xF6E7B18A7761F3F0, 0xFCF01A486A52C84C, 0xCBE2F63F5AA75466,
0x6487BCE837B5E4D6, 0x7747F5A8C622E9B8, 0x4CBFE1E4EE6AEBBA,
0x8A8616A13FA91512, 0x53DB980E1579E0A5, 0x000058FEBFF3BE69,
},
B: FpElement{
0xA492034E7C075CC3, 0x677BAF00B04AA430, 0x3AAE0C9A755C94C8,
0x1DC4B064E9EBB08B, 0x3684EDD04E826C66, 0x9BAA6CB661F01B22,
0x20285A00AD2EFE35, 0xDCE95ABD0497065F, 0x16C7FBB3778E3794,
0x26B3AC29CEF25AAF, 0xFB3C28A31A30AC1D, 0x000046ED190624EE,
},
}
// The x-coordinate of QB
var P751_affine_QB = Fp2Element{
A: FpElement{
0xF1A8C9ED7B96C4AB, 0x299429DA5178486E, 0xEF4926F20CD5C2F4,
0x683B2E2858B4716A, 0xDDA2FBCC3CAC3EEB, 0xEC055F9F3A600460,
0xD5A5A17A58C3848B, 0x4652D836F42EAED5, 0x2F2E71ED78B3A3B3,
0xA771C057180ADD1D, 0xC780A5D2D835F512, 0x0000114EA3B55AC1,
},
B: FpElement{
0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
},
}
// The x-coordinate of RB = PB - QB
var P751_affine_RB = Fp2Element{
A: FpElement{
0x1C0D6733769D0F31, 0xF084C3086E2659D1, 0xE23D5DA27BCBD133,
0xF38EC9A8D5864025, 0x6426DC781B3B645B, 0x4B24E8E3C9FB03EE,
0x6432792F9D2CEA30, 0x7CC8E8B1AE76E857, 0x7F32BFB626BB8963,
0xB9F05995B48D7B74, 0x4D71200A7D67E042, 0x0000228457AF0637,
},
B: FpElement{
0x4AE37E7D8F72BD95, 0xDD2D504B3E993488, 0x5D14E7FA1ECB3C3E,
0x127610CEB75D6350, 0x255B4B4CAC446B11, 0x9EA12336C1F70CAF,
0x79FA68A2147BC2F8, 0x11E895CFDADBBC49, 0xE4B9D3C4D6356C18,
0x44B25856A67F951C, 0x5851541F61308D0B, 0x00002FFD994F7E4C,
},
}
// 2-torsion group computation strategy
var P751_AliceIsogenyStrategy = [strategySizeA]uint32{
0x50, 0x30, 0x1B, 0x0F, 0x08, 0x04, 0x02, 0x01, 0x01, 0x02,
0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x07,
0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x03, 0x02, 0x01,
0x01, 0x01, 0x01, 0x0C, 0x07, 0x04, 0x02, 0x01, 0x01, 0x02,
0x01, 0x01, 0x03, 0x02, 0x01, 0x01, 0x01, 0x01, 0x05, 0x03,
0x02, 0x01, 0x01, 0x01, 0x01, 0x02, 0x01, 0x01, 0x01, 0x15,
0x0C, 0x07, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x03,
0x02, 0x01, 0x01, 0x01, 0x01, 0x05, 0x03, 0x02, 0x01, 0x01,
0x01, 0x01, 0x02, 0x01, 0x01, 0x01, 0x09, 0x05, 0x03, 0x02,
0x01, 0x01, 0x01, 0x01, 0x02, 0x01, 0x01, 0x01, 0x04, 0x02,
0x01, 0x01, 0x01, 0x02, 0x01, 0x01, 0x21, 0x14, 0x0C, 0x07,
0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x03, 0x02, 0x01,
0x01, 0x01, 0x01, 0x05, 0x03, 0x02, 0x01, 0x01, 0x01, 0x01,
0x02, 0x01, 0x01, 0x01, 0x08, 0x05, 0x03, 0x02, 0x01, 0x01,
0x01, 0x01, 0x02, 0x01, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01,
0x02, 0x01, 0x01, 0x10, 0x08, 0x04, 0x02, 0x01, 0x01, 0x01,
0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01,
0x08, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02,
0x01, 0x01, 0x02, 0x01, 0x01}
// 3-torsion group computation strategy
var P751_BobIsogenyStrategy = [strategySizeB]uint32{
0x70, 0x3F, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01, 0x01, 0x02,
0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x08,
0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02, 0x01,
0x01, 0x02, 0x01, 0x01, 0x10, 0x08, 0x04, 0x02, 0x01, 0x01,
0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01,
0x08, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02,
0x01, 0x01, 0x02, 0x01, 0x01, 0x1F, 0x10, 0x08, 0x04, 0x02,
0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02,
0x01, 0x01, 0x08, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01,
0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x0F, 0x08, 0x04,
0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01,
0x02, 0x01, 0x01, 0x07, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01,
0x01, 0x03, 0x02, 0x01, 0x01, 0x01, 0x01, 0x31, 0x1F, 0x10,
0x08, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04, 0x02,
0x01, 0x01, 0x02, 0x01, 0x01, 0x08, 0x04, 0x02, 0x01, 0x01,
0x02, 0x01, 0x01, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01,
0x0F, 0x08, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x04,
0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x07, 0x04, 0x02, 0x01,
0x01, 0x02, 0x01, 0x01, 0x03, 0x02, 0x01, 0x01, 0x01, 0x01,
0x15, 0x0C, 0x08, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01,
0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x05, 0x03, 0x02,
0x01, 0x01, 0x01, 0x01, 0x02, 0x01, 0x01, 0x01, 0x09, 0x05,
0x03, 0x02, 0x01, 0x01, 0x01, 0x01, 0x02, 0x01, 0x01, 0x01,
0x04, 0x02, 0x01, 0x01, 0x01, 0x02, 0x01, 0x01}
// Used internally by this package. Not consts as Go doesn't allow arrays to be consts
// -------------------------------
// p751
var p751 = FpElement{
0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff,
0xffffffffffffffff, 0xffffffffffffffff, 0xeeafffffffffffff,
0xe3ec968549f878a8, 0xda959b1a13f7cc76, 0x084e9867d6ebe876,
0x8562b5045cb25748, 0x0e12909f97badc66, 0x00006fe5d541f71c}
// 2*p751
var p751x2 = FpElement{
0xFFFFFFFFFFFFFFFE, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF,
0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xDD5FFFFFFFFFFFFF,
0xC7D92D0A93F0F151, 0xB52B363427EF98ED, 0x109D30CFADD7D0ED,
0x0AC56A08B964AE90, 0x1C25213F2F75B8CD, 0x0000DFCBAA83EE38}
// p751 + 1
var p751p1 = FpElement{
0x0000000000000000, 0x0000000000000000, 0x0000000000000000,
0x0000000000000000, 0x0000000000000000, 0xeeb0000000000000,
0xe3ec968549f878a8, 0xda959b1a13f7cc76, 0x084e9867d6ebe876,
0x8562b5045cb25748, 0x0e12909f97badc66, 0x00006fe5d541f71c}
// R^2 = (2^768)^2 mod p
var p751R2 = FpElement{
2535603850726686808, 15780896088201250090, 6788776303855402382,
17585428585582356230, 5274503137951975249, 2266259624764636289,
11695651972693921304, 13072885652150159301, 4908312795585420432,
6229583484603254826, 488927695601805643, 72213483953973}
// 1*R mod p
var P751_OneFp2 = Fp2Element{
A: FpElement{
0x249ad, 0x0, 0x0, 0x0, 0x0, 0x8310000000000000, 0x5527b1e4375c6c66, 0x697797bf3f4f24d0, 0xc89db7b2ac5c4e2e, 0x4ca4b439d2076956, 0x10f7926c7512c7e9, 0x2d5b24bce5e2},
}
// 1/2 * R mod p
var P751_HalfFp2 = Fp2Element{
A: FpElement{
0x00000000000124D6, 0x0000000000000000, 0x0000000000000000,
0x0000000000000000, 0x0000000000000000, 0xB8E0000000000000,
0x9C8A2434C0AA7287, 0xA206996CA9A378A3, 0x6876280D41A41B52,
0xE903B49F175CE04F, 0x0F8511860666D227, 0x00004EA07CFF6E7F},
}

View File

@ -0,0 +1,254 @@
package p751
import . "github.com/cloudflare/sidh/internal/isogeny"
// 2*p751
var ()
//------------------------------------------------------------------------------
// Implementtaion of FieldOperations
//------------------------------------------------------------------------------
// Implements FieldOps
type fp751Ops struct{}
func FieldOperations() FieldOps {
return &fp751Ops{}
}
func (fp751Ops) Add(dest, lhs, rhs *Fp2Element) {
fp751AddReduced(&dest.A, &lhs.A, &rhs.A)
fp751AddReduced(&dest.B, &lhs.B, &rhs.B)
}
func (fp751Ops) Sub(dest, lhs, rhs *Fp2Element) {
fp751SubReduced(&dest.A, &lhs.A, &rhs.A)
fp751SubReduced(&dest.B, &lhs.B, &rhs.B)
}
func (fp751Ops) Mul(dest, lhs, rhs *Fp2Element) {
// Let (a,b,c,d) = (lhs.a,lhs.b,rhs.a,rhs.b).
a := &lhs.A
b := &lhs.B
c := &rhs.A
d := &rhs.B
// We want to compute
//
// (a + bi)*(c + di) = (a*c - b*d) + (a*d + b*c)i
//
// Use Karatsuba's trick: note that
//
// (b - a)*(c - d) = (b*c + a*d) - a*c - b*d
//
// so (a*d + b*c) = (b-a)*(c-d) + a*c + b*d.
var ac, bd FpElementX2
fp751Mul(&ac, a, c) // = a*c*R*R
fp751Mul(&bd, b, d) // = b*d*R*R
var b_minus_a, c_minus_d FpElement
fp751SubReduced(&b_minus_a, b, a) // = (b-a)*R
fp751SubReduced(&c_minus_d, c, d) // = (c-d)*R
var ad_plus_bc FpElementX2
fp751Mul(&ad_plus_bc, &b_minus_a, &c_minus_d) // = (b-a)*(c-d)*R*R
fp751X2AddLazy(&ad_plus_bc, &ad_plus_bc, &ac) // = ((b-a)*(c-d) + a*c)*R*R
fp751X2AddLazy(&ad_plus_bc, &ad_plus_bc, &bd) // = ((b-a)*(c-d) + a*c + b*d)*R*R
fp751MontgomeryReduce(&dest.B, &ad_plus_bc) // = (a*d + b*c)*R mod p
var ac_minus_bd FpElementX2
fp751X2SubLazy(&ac_minus_bd, &ac, &bd) // = (a*c - b*d)*R*R
fp751MontgomeryReduce(&dest.A, &ac_minus_bd) // = (a*c - b*d)*R mod p
}
func (fp751Ops) Square(dest, x *Fp2Element) {
a := &x.A
b := &x.B
// We want to compute
//
// (a + bi)*(a + bi) = (a^2 - b^2) + 2abi.
var a2, a_plus_b, a_minus_b FpElement
fp751AddReduced(&a2, a, a) // = a*R + a*R = 2*a*R
fp751AddReduced(&a_plus_b, a, b) // = a*R + b*R = (a+b)*R
fp751SubReduced(&a_minus_b, a, b) // = a*R - b*R = (a-b)*R
var asq_minus_bsq, ab2 FpElementX2
fp751Mul(&asq_minus_bsq, &a_plus_b, &a_minus_b) // = (a+b)*(a-b)*R*R = (a^2 - b^2)*R*R
fp751Mul(&ab2, &a2, b) // = 2*a*b*R*R
fp751MontgomeryReduce(&dest.A, &asq_minus_bsq) // = (a^2 - b^2)*R mod p
fp751MontgomeryReduce(&dest.B, &ab2) // = 2*a*b*R mod p
}
// Set dest = 1/x
//
// Allowed to overlap dest with x.
//
// Returns dest to allow chaining operations.
func (fp751Ops) Inv(dest, x *Fp2Element) {
a := &x.A
b := &x.B
// We want to compute
//
// 1 1 (a - bi) (a - bi)
// -------- = -------- -------- = -----------
// (a + bi) (a + bi) (a - bi) (a^2 + b^2)
//
// Letting c = 1/(a^2 + b^2), this is
//
// 1/(a+bi) = a*c - b*ci.
var asq_plus_bsq primeFieldElement
var asq, bsq FpElementX2
fp751Mul(&asq, a, a) // = a*a*R*R
fp751Mul(&bsq, b, b) // = b*b*R*R
fp751X2AddLazy(&asq, &asq, &bsq) // = (a^2 + b^2)*R*R
fp751MontgomeryReduce(&asq_plus_bsq.A, &asq) // = (a^2 + b^2)*R mod p
// Now asq_plus_bsq = a^2 + b^2
// Invert asq_plus_bsq
inv := asq_plus_bsq
inv.Mul(&asq_plus_bsq, &asq_plus_bsq)
inv.P34(&inv)
inv.Mul(&inv, &inv)
inv.Mul(&inv, &asq_plus_bsq)
var ac FpElementX2
fp751Mul(&ac, a, &inv.A)
fp751MontgomeryReduce(&dest.A, &ac)
var minus_b FpElement
fp751SubReduced(&minus_b, &minus_b, b)
var minus_bc FpElementX2
fp751Mul(&minus_bc, &minus_b, &inv.A)
fp751MontgomeryReduce(&dest.B, &minus_bc)
}
// In case choice == 1, performs following swap in constant time:
// xPx <-> xQx
// xPz <-> xQz
// Otherwise returns xPx, xPz, xQx, xQz unchanged
func (fp751Ops) CondSwap(xPx, xPz, xQx, xQz *Fp2Element, choice uint8) {
fp751ConditionalSwap(&xPx.A, &xQx.A, choice)
fp751ConditionalSwap(&xPx.B, &xQx.B, choice)
fp751ConditionalSwap(&xPz.A, &xQz.A, choice)
fp751ConditionalSwap(&xPz.B, &xQz.B, choice)
}
// Converts values in x.A and x.B to Montgomery domain
// x.A = x.A * R mod p
// x.B = x.B * R mod p
func (fp751Ops) ToMontgomery(x *Fp2Element) {
var aRR FpElementX2
// convert to montgomery domain
fp751Mul(&aRR, &x.A, &p751R2) // = a*R*R
fp751MontgomeryReduce(&x.A, &aRR) // = a*R mod p
fp751Mul(&aRR, &x.B, &p751R2)
fp751MontgomeryReduce(&x.B, &aRR)
}
// Converts values in x.A and x.B from Montgomery domain
// a = x.A mod p
// b = x.B mod p
//
// After returning from the call x is not modified.
func (fp751Ops) FromMontgomery(x *Fp2Element, out *Fp2Element) {
var aR FpElementX2
// convert from montgomery domain
copy(aR[:], x.A[:])
fp751MontgomeryReduce(&out.A, &aR) // = a mod p in [0, 2p)
fp751StrongReduce(&out.A) // = a mod p in [0, p)
for i := range aR {
aR[i] = 0
}
copy(aR[:], x.B[:])
fp751MontgomeryReduce(&out.B, &aR)
fp751StrongReduce(&out.B)
}
//------------------------------------------------------------------------------
// Prime Field
//------------------------------------------------------------------------------
// Represents an element of the prime field F_p in Montgomery domain
type primeFieldElement struct {
// The value `A`is represented by `aR mod p`.
A FpElement
}
// Set dest = lhs * rhs.
//
// Allowed to overlap lhs or rhs with dest.
//
// Returns dest to allow chaining operations.
func (dest *primeFieldElement) Mul(lhs, rhs *primeFieldElement) *primeFieldElement {
a := &lhs.A // = a*R
b := &rhs.A // = b*R
var ab FpElementX2
fp751Mul(&ab, a, b) // = a*b*R*R
fp751MontgomeryReduce(&dest.A, &ab) // = a*b*R mod p
return dest
}
// Set dest = x^(2^k), for k >= 1, by repeated squarings.
//
// Allowed to overlap x with dest.
//
// Returns dest to allow chaining operations.
func (dest *primeFieldElement) Pow2k(x *primeFieldElement, k uint8) *primeFieldElement {
dest.Mul(x, x)
for i := uint8(1); i < k; i++ {
dest.Mul(dest, dest)
}
return dest
}
// Set dest = x^((p-3)/4). If x is square, this is 1/sqrt(x).
//
// Allowed to overlap x with dest.
//
// Returns dest to allow chaining operations.
func (dest *primeFieldElement) P34(x *primeFieldElement) *primeFieldElement {
// Sliding-window strategy computed with Sage, awk, sed, and tr.
//
// This performs sum(powStrategy) = 744 squarings and len(mulStrategy)
// = 137 multiplications, in addition to 1 squaring and 15
// multiplications to build a lookup table.
//
// In total this is 745 squarings, 152 multiplications. Since squaring
// is not implemented for the prime field, this is 897 multiplications
// in total.
powStrategy := [137]uint8{5, 7, 6, 2, 10, 4, 6, 9, 8, 5, 9, 4, 7, 5, 5, 4, 8, 3, 9, 5, 5, 4, 10, 4, 6, 6, 6, 5, 8, 9, 3, 4, 9, 4, 5, 6, 6, 2, 9, 4, 5, 5, 5, 7, 7, 9, 4, 6, 4, 8, 5, 8, 6, 6, 2, 9, 7, 4, 8, 8, 8, 4, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 2}
mulStrategy := [137]uint8{31, 23, 21, 1, 31, 7, 7, 7, 9, 9, 19, 15, 23, 23, 11, 7, 25, 5, 21, 17, 11, 5, 17, 7, 11, 9, 23, 9, 1, 19, 5, 3, 25, 15, 11, 29, 31, 1, 29, 11, 13, 9, 11, 27, 13, 19, 15, 31, 3, 29, 23, 31, 25, 11, 1, 21, 19, 15, 15, 21, 29, 13, 23, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 3}
initialMul := uint8(27)
// Build a lookup table of odd multiples of x.
lookup := [16]primeFieldElement{}
xx := &primeFieldElement{}
xx.Mul(x, x) // Set xx = x^2
lookup[0] = *x
for i := 1; i < 16; i++ {
lookup[i].Mul(&lookup[i-1], xx)
}
// Now lookup = {x, x^3, x^5, ... }
// so that lookup[i] = x^{2*i + 1}
// so that lookup[k/2] = x^k, for odd k
*dest = lookup[initialMul/2]
for i := uint8(0); i < 137; i++ {
dest.Pow2k(dest, powStrategy[i])
dest.Mul(dest, &lookup[mulStrategy[i]/2])
}
return dest
}

View File

@ -0,0 +1,226 @@
package sidh
import (
"errors"
. "github.com/cloudflare/sidh/internal/isogeny"
"io"
)
// I keep it bool in order to be able to apply logical NOT
type KeyVariant uint
// Id's correspond to bitlength of the prime field characteristic
// Currently FP_751 is the only one supported by this implementation
const (
FP_503 uint8 = iota
FP_751
FP_964
maxPrimeFieldId
)
const (
// First 2 bits identify SIDH variant third bit indicates
// wether key is a SIKE variant (set) or SIDH (not set)
// 001 - SIDH: corresponds to 2-torsion group
KeyVariant_SIDH_A KeyVariant = 1 << 0
// 010 - SIDH: corresponds to 3-torsion group
KeyVariant_SIDH_B = 1 << 1
// 110 - SIKE
KeyVariant_SIKE = 1<<2 | KeyVariant_SIDH_B
)
// Base type for public and private key. Used mainly to carry domain
// parameters.
type key struct {
// Domain parameters of the algorithm to be used with a key
params *SidhParams
// Flag indicates wether corresponds to 2-, 3-torsion group or SIKE
keyVariant KeyVariant
}
// Defines operations on public key
type PublicKey struct {
key
affine_xP Fp2Element
affine_xQ Fp2Element
affine_xQmP Fp2Element
}
// Defines operations on private key
type PrivateKey struct {
key
// Secret key
Scalar []byte
// Used only by KEM
S []byte
}
// Accessor to the domain parameters
func (key *key) Params() *SidhParams {
return key.params
}
// Accessor to key variant
func (key *key) Variant() KeyVariant {
return key.keyVariant
}
// NewPrivateKey initializes private key.
// Usage of this function guarantees that the object is correctly initialized.
func NewPrivateKey(id uint8, v KeyVariant) *PrivateKey {
prv := &PrivateKey{key: key{params: Params(id), keyVariant: v}}
if (v & KeyVariant_SIDH_A) == KeyVariant_SIDH_A {
prv.Scalar = make([]byte, prv.params.A.SecretByteLen)
} else {
prv.Scalar = make([]byte, prv.params.B.SecretByteLen)
}
if v == KeyVariant_SIKE {
prv.S = make([]byte, prv.params.MsgLen)
}
return prv
}
// NewPublicKey initializes public key.
// Usage of this function guarantees that the object is correctly initialized.
func NewPublicKey(id uint8, v KeyVariant) *PublicKey {
return &PublicKey{key: key{params: Params(id), keyVariant: v}}
}
// Import clears content of the public key currently stored in the structure
// and imports key stored in the byte string. Returns error in case byte string
// size is wrong. Doesn't perform any validation.
func (pub *PublicKey) Import(input []byte) error {
if len(input) != pub.Size() {
return errors.New("sidh: input to short")
}
op := CurveOperations{Params: pub.params}
ssSz := pub.params.SharedSecretSize
op.Fp2FromBytes(&pub.affine_xP, input[0:ssSz])
op.Fp2FromBytes(&pub.affine_xQ, input[ssSz:2*ssSz])
op.Fp2FromBytes(&pub.affine_xQmP, input[2*ssSz:3*ssSz])
return nil
}
// Exports currently stored key. In case structure hasn't been filled with key data
// returned byte string is filled with zeros.
func (pub *PublicKey) Export() []byte {
output := make([]byte, pub.params.PublicKeySize)
op := CurveOperations{Params: pub.params}
ssSz := pub.params.SharedSecretSize
op.Fp2ToBytes(output[0:ssSz], &pub.affine_xP)
op.Fp2ToBytes(output[ssSz:2*ssSz], &pub.affine_xQ)
op.Fp2ToBytes(output[2*ssSz:3*ssSz], &pub.affine_xQmP)
return output
}
// Size returns size of the public key in bytes
func (pub *PublicKey) Size() int {
return pub.params.PublicKeySize
}
// Exports currently stored key. In case structure hasn't been filled with key data
// returned byte string is filled with zeros.
func (prv *PrivateKey) Export() []byte {
ret := make([]byte, len(prv.Scalar)+len(prv.S))
copy(ret, prv.S)
copy(ret[len(prv.S):], prv.Scalar)
return ret
}
// Size returns size of the private key in bytes
func (prv *PrivateKey) Size() int {
tmp := len(prv.Scalar)
if prv.Variant() == KeyVariant_SIKE {
tmp += int(prv.params.MsgLen)
}
return tmp
}
// Import clears content of the private key currently stored in the structure
// and imports key from octet string. In case of SIKE, the random value 'S'
// must be prepended to the value of actual private key (see SIKE spec for details).
// Function doesn't import public key value to PrivateKey object.
func (prv *PrivateKey) Import(input []byte) error {
if len(input) != prv.Size() {
return errors.New("sidh: input to short")
}
copy(prv.S, input[:len(prv.S)])
copy(prv.Scalar, input[len(prv.S):])
return nil
}
// Generates random private key for SIDH or SIKE. Generated value is
// formed as little-endian integer from key-space <2^(e2-1)..2^e2 - 1>
// for KeyVariant_A or <2^(s-1)..2^s - 1>, where s = floor(log_2(3^e3)),
// for KeyVariant_B.
//
// Returns error in case user provided RNG fails.
func (prv *PrivateKey) Generate(rand io.Reader) error {
var err error
var dp *DomainParams
if (prv.keyVariant & KeyVariant_SIDH_A) == KeyVariant_SIDH_A {
dp = &prv.params.A
} else {
dp = &prv.params.B
}
if prv.keyVariant == KeyVariant_SIKE && err == nil {
_, err = io.ReadFull(rand, prv.S)
}
// Private key generation takes advantage of the fact that keyspace for secret
// key is (0, 2^x - 1), for some possitivite value of 'x' (see SIKE, 1.3.8).
// It means that all bytes in the secret key, but the last one, can take any
// value between <0x00,0xFF>. Similarily for the last byte, but generation
// needs to chop off some bits, to make sure generated value is an element of
// a key-space.
_, err = io.ReadFull(rand, prv.Scalar)
if err != nil {
return err
}
prv.Scalar[len(prv.Scalar)-1] &= (1 << (dp.SecretBitLen % 8)) - 1
// Make sure scalar is SecretBitLen long. SIKE spec says that key
// space starts from 0, but I'm not confortable with having low
// value scalars used for private keys. It is still secrure as per
// table 5.1 in [SIKE].
prv.Scalar[len(prv.Scalar)-1] |= 1 << ((dp.SecretBitLen % 8) - 1)
return err
}
// Generates public key.
//
// Constant time.
func (prv *PrivateKey) GeneratePublicKey() *PublicKey {
if (prv.keyVariant & KeyVariant_SIDH_A) == KeyVariant_SIDH_A {
return publicKeyGenA(prv)
}
return publicKeyGenB(prv)
}
// Computes a shared secret which is a j-invariant. Function requires that pub has
// different KeyVariant than prv. Length of returned output is 2*ceil(log_2 P)/8),
// where P is a prime defining finite field.
//
// It's important to notice that each keypair must not be used more than once
// to calculate shared secret.
//
// Function may return error. This happens only in case provided input is invalid.
// Constant time for properly initialized private and public key.
func DeriveSecret(prv *PrivateKey, pub *PublicKey) ([]byte, error) {
if (pub == nil) || (prv == nil) {
return nil, errors.New("sidh: invalid arguments")
}
if (pub.keyVariant == prv.keyVariant) || (pub.params.Id != prv.params.Id) {
return nil, errors.New("sidh: public and private are incompatbile")
}
if (prv.keyVariant & KeyVariant_SIDH_A) == KeyVariant_SIDH_A {
return deriveSecretA(prv, pub), nil
} else {
return deriveSecretB(prv, pub), nil
}
}

View File

@ -0,0 +1,82 @@
package sidh
import (
. "github.com/cloudflare/sidh/internal/isogeny"
p503 "github.com/cloudflare/sidh/p503"
p751 "github.com/cloudflare/sidh/p751"
)
// Keeps mapping: SIDH prime field ID to domain parameters
var sidhParams = make(map[uint8]SidhParams)
// Params returns domain parameters corresponding to finite field and identified by
// `id` provieded by the caller. Function panics in case `id` wasn't registered earlier.
func Params(id uint8) *SidhParams {
if val, ok := sidhParams[id]; ok {
return &val
}
panic("sidh: SIDH Params ID unregistered")
}
func init() {
p503 := SidhParams{
Id: FP_503,
PublicKeySize: p503.P503_PublicKeySize,
SharedSecretSize: p503.P503_SharedSecretSize,
A: DomainParams{
Affine_P: p503.P503_affine_PA,
Affine_Q: p503.P503_affine_QA,
Affine_R: p503.P503_affine_RA,
SecretBitLen: p503.P503_SecretBitLenA,
SecretByteLen: uint((p503.P503_SecretBitLenA + 7) / 8),
IsogenyStrategy: p503.P503_AliceIsogenyStrategy[:],
},
B: DomainParams{
Affine_P: p503.P503_affine_PB,
Affine_Q: p503.P503_affine_QB,
Affine_R: p503.P503_affine_RB,
SecretBitLen: p503.P503_SecretBitLenB,
SecretByteLen: uint((p503.P503_SecretBitLenB + 7) / 8),
IsogenyStrategy: p503.P503_BobIsogenyStrategy[:],
},
OneFp2: p503.P503_OneFp2,
HalfFp2: p503.P503_HalfFp2,
MsgLen: 24,
// SIKEp751 provides 128 bit of classical security ([SIKE], 5.1)
KemSize: 16,
Bytelen: p503.P503_Bytelen,
Op: p503.FieldOperations(),
}
p751 := SidhParams{
Id: FP_751,
PublicKeySize: p751.P751_PublicKeySize,
SharedSecretSize: p751.P751_SharedSecretSize,
A: DomainParams{
Affine_P: p751.P751_affine_PA,
Affine_Q: p751.P751_affine_QA,
Affine_R: p751.P751_affine_RA,
IsogenyStrategy: p751.P751_AliceIsogenyStrategy[:],
SecretBitLen: p751.P751_SecretBitLenA,
SecretByteLen: uint((p751.P751_SecretBitLenA + 7) / 8),
},
B: DomainParams{
Affine_P: p751.P751_affine_PB,
Affine_Q: p751.P751_affine_QB,
Affine_R: p751.P751_affine_RB,
IsogenyStrategy: p751.P751_BobIsogenyStrategy[:],
SecretBitLen: p751.P751_SecretBitLenB,
SecretByteLen: uint((p751.P751_SecretBitLenB + 7) / 8),
},
OneFp2: p751.P751_OneFp2,
HalfFp2: p751.P751_HalfFp2,
MsgLen: 32,
// SIKEp751 provides 192 bit of classical security ([SIKE], 5.1)
KemSize: 24,
Bytelen: p751.P751_Bytelen,
Op: p751.FieldOperations(),
}
sidhParams[FP_503] = p503
sidhParams[FP_751] = p751
}

View File

@ -0,0 +1,302 @@
package sidh
import (
. "github.com/cloudflare/sidh/internal/isogeny"
)
// -----------------------------------------------------------------------------
// Functions for traversing isogeny trees acoording to strategy. Key type 'A' is
//
// Traverses isogeny tree in order to compute xR, xP, xQ and xQmP needed
// for public key generation.
func traverseTreePublicKeyA(curve *ProjectiveCurveParameters, xR, phiP, phiQ, phiR *ProjectivePoint, pub *PublicKey) {
var points = make([]ProjectivePoint, 0, 8)
var indices = make([]int, 0, 8)
var i, sidx int
var op = CurveOperations{Params: pub.params}
cparam := op.CalcCurveParamsEquiv4(curve)
phi := Newisogeny4(op.Params.Op)
strat := pub.params.A.IsogenyStrategy
stratSz := len(strat)
for j := 1; j <= stratSz; j++ {
for i <= stratSz-j {
points = append(points, *xR)
indices = append(indices, i)
k := strat[sidx]
sidx++
op.Pow2k(xR, &cparam, 2*k)
i += int(k)
}
cparam = phi.GenerateCurve(xR)
for k := 0; k < len(points); k++ {
points[k] = phi.EvaluatePoint(&points[k])
}
*phiP = phi.EvaluatePoint(phiP)
*phiQ = phi.EvaluatePoint(phiQ)
*phiR = phi.EvaluatePoint(phiR)
// pop xR from points
*xR, points = points[len(points)-1], points[:len(points)-1]
i, indices = int(indices[len(indices)-1]), indices[:len(indices)-1]
}
}
// Traverses isogeny tree in order to compute xR needed
// for public key generation.
func traverseTreeSharedKeyA(curve *ProjectiveCurveParameters, xR *ProjectivePoint, pub *PublicKey) {
var points = make([]ProjectivePoint, 0, 8)
var indices = make([]int, 0, 8)
var i, sidx int
var op = CurveOperations{Params: pub.params}
cparam := op.CalcCurveParamsEquiv4(curve)
phi := Newisogeny4(op.Params.Op)
strat := pub.params.A.IsogenyStrategy
stratSz := len(strat)
for j := 1; j <= stratSz; j++ {
for i <= stratSz-j {
points = append(points, *xR)
indices = append(indices, i)
k := strat[sidx]
sidx++
op.Pow2k(xR, &cparam, 2*k)
i += int(k)
}
cparam = phi.GenerateCurve(xR)
for k := 0; k < len(points); k++ {
points[k] = phi.EvaluatePoint(&points[k])
}
// pop xR from points
*xR, points = points[len(points)-1], points[:len(points)-1]
i, indices = int(indices[len(indices)-1]), indices[:len(indices)-1]
}
}
// Traverses isogeny tree in order to compute xR, xP, xQ and xQmP needed
// for public key generation.
func traverseTreePublicKeyB(curve *ProjectiveCurveParameters, xR, phiP, phiQ, phiR *ProjectivePoint, pub *PublicKey) {
var points = make([]ProjectivePoint, 0, 8)
var indices = make([]int, 0, 8)
var i, sidx int
var op = CurveOperations{Params: pub.params}
cparam := op.CalcCurveParamsEquiv3(curve)
phi := Newisogeny3(op.Params.Op)
strat := pub.params.B.IsogenyStrategy
stratSz := len(strat)
for j := 1; j <= stratSz; j++ {
for i <= stratSz-j {
points = append(points, *xR)
indices = append(indices, i)
k := strat[sidx]
sidx++
op.Pow3k(xR, &cparam, k)
i += int(k)
}
cparam = phi.GenerateCurve(xR)
for k := 0; k < len(points); k++ {
points[k] = phi.EvaluatePoint(&points[k])
}
*phiP = phi.EvaluatePoint(phiP)
*phiQ = phi.EvaluatePoint(phiQ)
*phiR = phi.EvaluatePoint(phiR)
// pop xR from points
*xR, points = points[len(points)-1], points[:len(points)-1]
i, indices = int(indices[len(indices)-1]), indices[:len(indices)-1]
}
}
// Traverses isogeny tree in order to compute xR, xP, xQ and xQmP needed
// for public key generation.
func traverseTreeSharedKeyB(curve *ProjectiveCurveParameters, xR *ProjectivePoint, pub *PublicKey) {
var points = make([]ProjectivePoint, 0, 8)
var indices = make([]int, 0, 8)
var i, sidx int
var op = CurveOperations{Params: pub.params}
cparam := op.CalcCurveParamsEquiv3(curve)
phi := Newisogeny3(op.Params.Op)
strat := pub.params.B.IsogenyStrategy
stratSz := len(strat)
for j := 1; j <= stratSz; j++ {
for i <= stratSz-j {
points = append(points, *xR)
indices = append(indices, i)
k := strat[sidx]
sidx++
op.Pow3k(xR, &cparam, k)
i += int(k)
}
cparam = phi.GenerateCurve(xR)
for k := 0; k < len(points); k++ {
points[k] = phi.EvaluatePoint(&points[k])
}
// pop xR from points
*xR, points = points[len(points)-1], points[:len(points)-1]
i, indices = int(indices[len(indices)-1]), indices[:len(indices)-1]
}
}
// Generate a public key in the 2-torsion group
func publicKeyGenA(prv *PrivateKey) (pub *PublicKey) {
var xPA, xQA, xRA ProjectivePoint
var xPB, xQB, xRB, xR ProjectivePoint
var invZP, invZQ, invZR Fp2Element
var tmp ProjectiveCurveParameters
pub = NewPublicKey(prv.params.Id, KeyVariant_SIDH_A)
var op = CurveOperations{Params: pub.params}
var phi = Newisogeny4(op.Params.Op)
// Load points for A
xPA = ProjectivePoint{X: prv.params.A.Affine_P, Z: prv.params.OneFp2}
xQA = ProjectivePoint{X: prv.params.A.Affine_Q, Z: prv.params.OneFp2}
xRA = ProjectivePoint{X: prv.params.A.Affine_R, Z: prv.params.OneFp2}
// Load points for B
xRB = ProjectivePoint{X: prv.params.B.Affine_R, Z: prv.params.OneFp2}
xQB = ProjectivePoint{X: prv.params.B.Affine_Q, Z: prv.params.OneFp2}
xPB = ProjectivePoint{X: prv.params.B.Affine_P, Z: prv.params.OneFp2}
// Find isogeny kernel
tmp.C = pub.params.OneFp2
xR = op.ScalarMul3Pt(&tmp, &xPA, &xQA, &xRA, prv.params.A.SecretBitLen, prv.Scalar)
// Reset params object and travers isogeny tree
tmp.C = pub.params.OneFp2
tmp.A.Zeroize()
traverseTreePublicKeyA(&tmp, &xR, &xPB, &xQB, &xRB, pub)
// Secret isogeny
phi.GenerateCurve(&xR)
xPA = phi.EvaluatePoint(&xPB)
xQA = phi.EvaluatePoint(&xQB)
xRA = phi.EvaluatePoint(&xRB)
op.Fp2Batch3Inv(&xPA.Z, &xQA.Z, &xRA.Z, &invZP, &invZQ, &invZR)
op.Params.Op.Mul(&pub.affine_xP, &xPA.X, &invZP)
op.Params.Op.Mul(&pub.affine_xQ, &xQA.X, &invZQ)
op.Params.Op.Mul(&pub.affine_xQmP, &xRA.X, &invZR)
return
}
// Generate a public key in the 3-torsion group
func publicKeyGenB(prv *PrivateKey) (pub *PublicKey) {
var xPB, xQB, xRB, xR ProjectivePoint
var xPA, xQA, xRA ProjectivePoint
var invZP, invZQ, invZR Fp2Element
var tmp ProjectiveCurveParameters
pub = NewPublicKey(prv.params.Id, prv.keyVariant)
var op = CurveOperations{Params: pub.params}
var phi = Newisogeny3(op.Params.Op)
// Load points for B
xRB = ProjectivePoint{X: prv.params.B.Affine_R, Z: prv.params.OneFp2}
xQB = ProjectivePoint{X: prv.params.B.Affine_Q, Z: prv.params.OneFp2}
xPB = ProjectivePoint{X: prv.params.B.Affine_P, Z: prv.params.OneFp2}
// Load points for A
xPA = ProjectivePoint{X: prv.params.A.Affine_P, Z: prv.params.OneFp2}
xQA = ProjectivePoint{X: prv.params.A.Affine_Q, Z: prv.params.OneFp2}
xRA = ProjectivePoint{X: prv.params.A.Affine_R, Z: prv.params.OneFp2}
tmp.C = pub.params.OneFp2
xR = op.ScalarMul3Pt(&tmp, &xPB, &xQB, &xRB, prv.params.B.SecretBitLen, prv.Scalar)
tmp.C = pub.params.OneFp2
tmp.A.Zeroize()
traverseTreePublicKeyB(&tmp, &xR, &xPA, &xQA, &xRA, pub)
phi.GenerateCurve(&xR)
xPB = phi.EvaluatePoint(&xPA)
xQB = phi.EvaluatePoint(&xQA)
xRB = phi.EvaluatePoint(&xRA)
op.Fp2Batch3Inv(&xPB.Z, &xQB.Z, &xRB.Z, &invZP, &invZQ, &invZR)
op.Params.Op.Mul(&pub.affine_xP, &xPB.X, &invZP)
op.Params.Op.Mul(&pub.affine_xQ, &xQB.X, &invZQ)
op.Params.Op.Mul(&pub.affine_xQmP, &xRB.X, &invZR)
return
}
// -----------------------------------------------------------------------------
// Key agreement functions
//
// Establishing shared keys in in 2-torsion group
func deriveSecretA(prv *PrivateKey, pub *PublicKey) []byte {
var sharedSecret = make([]byte, pub.params.SharedSecretSize)
var cparam ProjectiveCurveParameters
var xP, xQ, xQmP ProjectivePoint
var xR ProjectivePoint
var op = CurveOperations{Params: prv.params}
var phi = Newisogeny4(op.Params.Op)
// Recover curve coefficients
cparam.C = pub.params.OneFp2
op.RecoverCoordinateA(&cparam, &pub.affine_xP, &pub.affine_xQ, &pub.affine_xQmP)
// Find kernel of the morphism
xP = ProjectivePoint{X: pub.affine_xP, Z: pub.params.OneFp2}
xQ = ProjectivePoint{X: pub.affine_xQ, Z: pub.params.OneFp2}
xQmP = ProjectivePoint{X: pub.affine_xQmP, Z: pub.params.OneFp2}
xR = op.ScalarMul3Pt(&cparam, &xP, &xQ, &xQmP, pub.params.A.SecretBitLen, prv.Scalar)
// Traverse isogeny tree
traverseTreeSharedKeyA(&cparam, &xR, pub)
// Calculate j-invariant on isogeneus curve
c := phi.GenerateCurve(&xR)
op.RecoverCurveCoefficients4(&cparam, &c)
op.Jinvariant(&cparam, sharedSecret)
return sharedSecret
}
// Establishing shared keys in in 3-torsion group
func deriveSecretB(prv *PrivateKey, pub *PublicKey) []byte {
var sharedSecret = make([]byte, pub.params.SharedSecretSize)
var xP, xQ, xQmP ProjectivePoint
var xR ProjectivePoint
var cparam ProjectiveCurveParameters
var op = CurveOperations{Params: prv.params}
var phi = Newisogeny3(op.Params.Op)
// Recover curve coefficients
cparam.C = pub.params.OneFp2
op.RecoverCoordinateA(&cparam, &pub.affine_xP, &pub.affine_xQ, &pub.affine_xQmP)
// Find kernel of the morphism
xP = ProjectivePoint{X: pub.affine_xP, Z: pub.params.OneFp2}
xQ = ProjectivePoint{X: pub.affine_xQ, Z: pub.params.OneFp2}
xQmP = ProjectivePoint{X: pub.affine_xQmP, Z: pub.params.OneFp2}
xR = op.ScalarMul3Pt(&cparam, &xP, &xQ, &xQmP, pub.params.B.SecretBitLen, prv.Scalar)
// Traverse isogeny tree
traverseTreeSharedKeyB(&cparam, &xR, pub)
// Calculate j-invariant on isogeneus curve
c := phi.GenerateCurve(&xR)
op.RecoverCurveCoefficients3(&cparam, &c)
op.Jinvariant(&cparam, sharedSecret)
return sharedSecret
}

View File

@ -22,6 +22,7 @@ import (
"sync/atomic"
"time"
"github.com/cloudflare/sidh/sidh"
"golang.org/x/crypto/curve25519"
)
@ -31,6 +32,17 @@ const numSessionTickets = 2
type secretLabel int
const (
x25519SharedSecretSz = 32
P503PubKeySz = 378
P503PrvKeySz = 32
P503SharedSecretSz = 126
SIDHp503Curve25519PubKeySz = x25519SharedSecretSz + P503PubKeySz
SIDHp503Curve25519PrvKeySz = x25519SharedSecretSz + P503PrvKeySz
SIDHp503Curve25519SharedKeySz = x25519SharedSecretSz + P503SharedSecretSz
)
const (
secretResumptionPskBinder secretLabel = iota
secretEarlyClient
@ -50,6 +62,40 @@ type keySchedule13 struct {
config *Config // Used for KeyLogWriter callback, nil if keylogging is disabled.
}
// Interface implemented by DH key exchange strategies
type dhKex interface {
// c - context of current TLS handshake, groupId - ID of an algorithm
// (curve/field) being chosen for key agreement. Methods implmenting an
// interface always assume that provided groupId is correct.
//
// In case of success, function returns secret key and ephemeral key. Otherwise
// error is set.
generate(c *Conn, groupId CurveID) ([]byte, keyShare, error)
// c - context of current TLS handshake, ks - public key received
// from the other side of the connection, secretKey - is a private key
// used for DH key agreement. Function returns shared secret in case
// of success or empty slice otherwise.
derive(c *Conn, ks keyShare, secretKey []byte) []byte
}
// Key Exchange strategies per curve type
type kexNist struct{} // Used by NIST curves; P-256, P-384, P-512
type kexX25519 struct{} // Used by X25519
type kexSIDHp503 struct{} // Used by SIDH/P503
type kexHybridSIDHp503X25519 struct {
classicKEX kexX25519
pqKEX kexSIDHp503
} // Used by SIDH-ECDH hybrid scheme
// Routing map for key exchange strategies
var dhKexStrat = map[CurveID]dhKex{
CurveP256: &kexNist{},
CurveP384: &kexNist{},
CurveP521: &kexNist{},
X25519: &kexX25519{},
HybridSIDHp503Curve25519: &kexHybridSIDHp503X25519{},
}
func newKeySchedule13(suite *cipherSuite, config *Config, clientRandom []byte) *keySchedule13 {
if config.KeyLogWriter == nil {
clientRandom = nil
@ -70,9 +116,18 @@ func (ks *keySchedule13) setSecret(secret []byte) {
salt := ks.secret
if salt != nil {
h0 := hash.New().Sum(nil)
salt = hkdfExpandLabel(hash, salt, h0, "derived", hash.Size())
salt = HkdfExpandLabel(hash, salt, h0, "derived", hash.Size())
}
ks.secret = hkdfExtract(hash, secret, salt)
ks.secret = HkdfExtract(hash, secret, salt)
}
// Depending on role returns pair of key variant to be used by
// local and remote process.
func getSidhKeyVariant(isClient bool) (sidh.KeyVariant, sidh.KeyVariant) {
if isClient {
return sidh.KeyVariant_SIDH_A, sidh.KeyVariant_SIDH_B
}
return sidh.KeyVariant_SIDH_B, sidh.KeyVariant_SIDH_A
}
// write appends the data to the transcript hash context.
@ -113,7 +168,7 @@ func (ks *keySchedule13) deriveSecret(secretLabel secretLabel) []byte {
ks.handshakeCtx = ks.transcriptHash.Sum(nil)
}
hash := hashForSuite(ks.suite)
secret := hkdfExpandLabel(hash, ks.secret, ks.handshakeCtx, label, hash.Size())
secret := HkdfExpandLabel(hash, ks.secret, ks.handshakeCtx, label, hash.Size())
if keylogType != "" && ks.config != nil {
ks.config.writeKeyLog(keylogType, ks.clientRandom, secret)
}
@ -122,8 +177,8 @@ func (ks *keySchedule13) deriveSecret(secretLabel secretLabel) []byte {
func (ks *keySchedule13) prepareCipher(trafficSecret []byte) cipher.AEAD {
hash := hashForSuite(ks.suite)
key := hkdfExpandLabel(hash, trafficSecret, nil, "key", ks.suite.keyLen)
iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", ks.suite.ivLen)
key := HkdfExpandLabel(hash, trafficSecret, nil, "key", ks.suite.keyLen)
iv := HkdfExpandLabel(hash, trafficSecret, nil, "iv", ks.suite.ivLen)
return ks.suite.aead(key, iv)
}
@ -152,7 +207,7 @@ CurvePreferenceLoop:
return errors.New("tls: HelloRetryRequest not implemented") // TODO(filippo)
}
privateKey, serverKS, err := config.generateKeyShare(ks.group)
privateKey, serverKS, err := c.generateKeyShare(ks.group)
if err != nil {
c.sendAlert(alertInternalError)
return err
@ -180,7 +235,7 @@ CurvePreferenceLoop:
earlyClientTrafficSecret := hs.keySchedule.deriveSecret(secretEarlyClient)
ecdheSecret := deriveECDHESecret(ks, privateKey)
ecdheSecret := c.deriveDHESecret(ks, privateKey)
if ecdheSecret == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: bad ECDHE client share")
@ -201,8 +256,8 @@ CurvePreferenceLoop:
hsServerTrafficSecret := hs.keySchedule.deriveSecret(secretHandshakeServer)
c.out.setKey(c.vers, hs.keySchedule.suite, hsServerTrafficSecret)
serverFinishedKey := hkdfExpandLabel(hash, hsServerTrafficSecret, nil, "finished", hashSize)
hs.clientFinishedKey = hkdfExpandLabel(hash, hs.hsClientTrafficSecret, nil, "finished", hashSize)
serverFinishedKey := HkdfExpandLabel(hash, hsServerTrafficSecret, nil, "finished", hashSize)
hs.clientFinishedKey = HkdfExpandLabel(hash, hs.hsClientTrafficSecret, nil, "finished", hashSize)
// EncryptedExtensions
hs.keySchedule.write(hs.hello13Enc.marshal())
@ -296,11 +351,11 @@ func (hs *serverHandshakeState) readClientFinished13(hasConfirmLock bool) error
}
// client authentication
if certMsg, ok := msg.(*certificateMsg13); ok {
// (4.4.2) Client MUST send certificate msg if requested by server
if c.config.ClientAuth < RequestClientCert {
c.sendAlert(alertUnexpectedMessage)
// (4.4.2) Client MUST send certificate msg if requested by server
if c.config.ClientAuth >= RequestClientCert && !c.didResume {
certMsg, ok := msg.(*certificateMsg13)
if !ok {
c.sendAlert(alertCertificateRequired)
return unexpectedMessageError(certMsg, msg)
}
@ -311,39 +366,37 @@ func (hs *serverHandshakeState) readClientFinished13(hasConfirmLock bool) error
return err
}
// 4.4.3: CertificateVerify MUST appear immediately after Certificate msg
msg, err = c.readHandshake()
if err != nil {
return err
}
if len(certs) > 0 {
// 4.4.3: CertificateVerify MUST appear immediately after Certificate msg
msg, err = c.readHandshake()
if err != nil {
return err
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerify, msg)
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerify, msg)
}
err, alertCode := verifyPeerHandshakeSignature(
certVerify,
pubKey,
supportedSignatureAlgorithms13,
hs.keySchedule.transcriptHash.Sum(nil),
"TLS 1.3, client CertificateVerify")
if err != nil {
c.sendAlert(alertCode)
return err
err, alertCode := verifyPeerHandshakeSignature(
certVerify,
pubKey,
supportedSignatureAlgorithms13,
hs.keySchedule.transcriptHash.Sum(nil),
"TLS 1.3, client CertificateVerify")
if err != nil {
c.sendAlert(alertCode)
return err
}
hs.keySchedule.write(certVerify.marshal())
}
hs.keySchedule.write(certVerify.marshal())
// Read next chunk
msg, err = c.readHandshake()
if err != nil {
return err
}
} else if (c.config.ClientAuth >= RequestClientCert) && !c.didResume {
c.sendAlert(alertCertificateRequired)
return unexpectedMessageError(certMsg, msg)
}
clientFinished, ok := msg.(*finishedMsg)
@ -545,64 +598,26 @@ func prepareDigitallySigned(hash crypto.Hash, context string, data []byte) []byt
return h.Sum(nil)
}
func (c *Config) generateKeyShare(curveID CurveID) ([]byte, keyShare, error) {
if curveID == X25519 {
var scalar, public [32]byte
if _, err := io.ReadFull(c.rand(), scalar[:]); err != nil {
return nil, keyShare{}, err
}
curve25519.ScalarBaseMult(&public, &scalar)
return scalar[:], keyShare{group: curveID, data: public[:]}, nil
// generateKeyShare generates keypair. Private key is returned as first argument, public key
// is returned in keyShare.data. keyshare.curveID stores ID of the scheme used.
func (c *Conn) generateKeyShare(curveID CurveID) ([]byte, keyShare, error) {
if val, ok := dhKexStrat[curveID]; ok {
return val.generate(c, curveID)
}
curve, ok := curveForCurveID(curveID)
if !ok {
return nil, keyShare{}, errors.New("tls: preferredCurves includes unsupported curve")
}
privateKey, x, y, err := elliptic.GenerateKey(curve, c.rand())
if err != nil {
return nil, keyShare{}, err
}
ecdhePublic := elliptic.Marshal(curve, x, y)
return privateKey, keyShare{group: curveID, data: ecdhePublic}, nil
return nil, keyShare{}, errors.New("tls: preferredCurves includes unsupported curve")
}
func deriveECDHESecret(ks keyShare, secretKey []byte) []byte {
if ks.group == X25519 {
if len(ks.data) != 32 {
return nil
}
var theirPublic, sharedKey, scalar [32]byte
copy(theirPublic[:], ks.data)
copy(scalar[:], secretKey)
curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
return sharedKey[:]
// DH key agreement. ks stores public key, secretKey stores private key used for ephemeral
// key agreement. Function returns shared secret in case of success or empty slice otherwise.
func (c *Conn) deriveDHESecret(ks keyShare, secretKey []byte) []byte {
if val, ok := dhKexStrat[ks.group]; ok {
return val.derive(c, ks, secretKey)
}
curve, ok := curveForCurveID(ks.group)
if !ok {
return nil
}
x, y := elliptic.Unmarshal(curve, ks.data)
if x == nil {
return nil
}
x, _ = curve.ScalarMult(x, y, secretKey)
xBytes := x.Bytes()
curveSize := (curve.Params().BitSize + 8 - 1) >> 3
if len(xBytes) == curveSize {
return xBytes
}
buf := make([]byte, curveSize)
copy(buf[len(buf)-len(xBytes):], xBytes)
return buf
return nil
}
func hkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte {
// HkdfExpandLabel HKDF expands a label
func HkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte {
prefix := "tls13 "
hkdfLabel := make([]byte, 4+len(prefix)+len(label)+len(hashValue))
hkdfLabel[0] = byte(L >> 8)
@ -695,7 +710,7 @@ func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) {
hs.keySchedule.setSecret(s.pskSecret)
binderKey := hs.keySchedule.deriveSecret(secretResumptionPskBinder)
binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize)
binderFinishedKey := HkdfExpandLabel(hash, binderKey, nil, "finished", hashSize)
chHash := hash.New()
chHash.Write(hs.clientHello.rawTruncated)
expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)
@ -766,7 +781,7 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
// tickets might have the same PSK which could be a problem if
// one of them is compromised.
ticketNonce := []byte{byte(i)}
sessionState.pskSecret = hkdfExpandLabel(hash, resumptionMasterSecret, ticketNonce, "resumption", hash.Size())
sessionState.pskSecret = HkdfExpandLabel(hash, resumptionMasterSecret, ticketNonce, "resumption", hash.Size())
ticket := sessionState.marshal()
var err error
if c.config.SessionTicketSealer != nil {
@ -978,7 +993,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
// 0-RTT is not supported yet, so use an empty PSK.
hs.keySchedule.setSecret(nil)
ecdheSecret := deriveECDHESecret(serverHello.keyShare, hs.privateKey)
ecdheSecret := c.deriveDHESecret(serverHello.keyShare, hs.privateKey)
if ecdheSecret == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: bad ECDHE server share")
@ -996,8 +1011,8 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
c.in.setKey(c.vers, hs.keySchedule.suite, serverHandshakeSecret)
// Calculate MAC key for Finished messages.
serverFinishedKey := hkdfExpandLabel(hash, serverHandshakeSecret, nil, "finished", hashSize)
clientFinishedKey := hkdfExpandLabel(hash, clientHandshakeSecret, nil, "finished", hashSize)
serverFinishedKey := HkdfExpandLabel(hash, serverHandshakeSecret, nil, "finished", hashSize)
clientFinishedKey := HkdfExpandLabel(hash, clientHandshakeSecret, nil, "finished", hashSize)
msg, err := c.readHandshake()
if err != nil {
@ -1160,3 +1175,138 @@ func supportedSigAlgorithmsCert(schemes []SignatureScheme) (ret []SignatureSchem
}
return
}
// Functions below implement dhKex interface for different DH shared secret agreements
// KEX: P-256, P-384, P-512 KEX
func (kexNist) generate(c *Conn, groupId CurveID) (private []byte, ks keyShare, err error) {
// never fails
curve, _ := curveForCurveID(groupId)
private, x, y, err := elliptic.GenerateKey(curve, c.config.rand())
if err != nil {
return nil, keyShare{}, err
}
ks.group = groupId
ks.data = elliptic.Marshal(curve, x, y)
return
}
func (kexNist) derive(c *Conn, ks keyShare, secretKey []byte) []byte {
// never fails
curve, _ := curveForCurveID(ks.group)
x, y := elliptic.Unmarshal(curve, ks.data)
if x == nil {
return nil
}
x, _ = curve.ScalarMult(x, y, secretKey)
xBytes := x.Bytes()
curveSize := (curve.Params().BitSize + 8 - 1) >> 3
if len(xBytes) == curveSize {
return xBytes
}
buf := make([]byte, curveSize)
copy(buf[len(buf)-len(xBytes):], xBytes)
return buf
}
// KEX: X25519
func (kexX25519) generate(c *Conn, groupId CurveID) ([]byte, keyShare, error) {
var scalar, public [x25519SharedSecretSz]byte
if _, err := io.ReadFull(c.config.rand(), scalar[:]); err != nil {
return nil, keyShare{}, err
}
curve25519.ScalarBaseMult(&public, &scalar)
return scalar[:], keyShare{group: X25519, data: public[:]}, nil
}
func (kexX25519) derive(c *Conn, ks keyShare, secretKey []byte) []byte {
var theirPublic, sharedKey, scalar [x25519SharedSecretSz]byte
if len(ks.data) != x25519SharedSecretSz {
return nil
}
copy(theirPublic[:], ks.data)
copy(scalar[:], secretKey)
curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
return sharedKey[:]
}
// KEX: SIDH/503
func (kexSIDHp503) generate(c *Conn, groupId CurveID) ([]byte, keyShare, error) {
var variant, _ = getSidhKeyVariant(c.isClient)
var prvKey = sidh.NewPrivateKey(sidh.FP_503, variant)
if prvKey.Generate(c.config.rand()) != nil {
return nil, keyShare{}, errors.New("tls: private SIDH key generation failed")
}
pubKey := prvKey.GeneratePublicKey()
return prvKey.Export(), keyShare{group: 0 /*UNUSED*/, data: pubKey.Export()}, nil
}
func (kexSIDHp503) derive(c *Conn, ks keyShare, key []byte) []byte {
var prvVariant, pubVariant = getSidhKeyVariant(c.isClient)
var prvKeySize = P503PrvKeySz
if len(ks.data) != P503PubKeySz || len(key) != prvKeySize {
return nil
}
prvKey := sidh.NewPrivateKey(sidh.FP_503, prvVariant)
pubKey := sidh.NewPublicKey(sidh.FP_503, pubVariant)
if err := prvKey.Import(key); err != nil {
return nil
}
if err := pubKey.Import(ks.data); err != nil {
return nil
}
// Never fails
sharedKey, _ := sidh.DeriveSecret(prvKey, pubKey)
return sharedKey
}
// KEX Hybrid SIDH/503-X25519
func (kex *kexHybridSIDHp503X25519) generate(c *Conn, groupId CurveID) (private []byte, ks keyShare, err error) {
var pubHybrid [SIDHp503Curve25519PubKeySz]byte
var prvHybrid [SIDHp503Curve25519PrvKeySz]byte
// Generate ephemeral key for classic x25519
private, ks, err = kex.classicKEX.generate(c, groupId)
if err != nil {
return
}
copy(prvHybrid[:], private)
copy(pubHybrid[:], ks.data)
// Generate PQ ephemeral key for SIDH
private, ks, err = kex.pqKEX.generate(c, groupId)
if err != nil {
return
}
copy(prvHybrid[x25519SharedSecretSz:], private)
copy(pubHybrid[x25519SharedSecretSz:], ks.data)
return prvHybrid[:], keyShare{group: HybridSIDHp503Curve25519, data: pubHybrid[:]}, nil
}
func (kex *kexHybridSIDHp503X25519) derive(c *Conn, ks keyShare, key []byte) []byte {
var sharedKey [SIDHp503Curve25519SharedKeySz]byte
var ret []byte
var tmpKs keyShare
// Key agreement for classic
tmpKs.group = X25519
tmpKs.data = ks.data[:x25519SharedSecretSz]
ret = kex.classicKEX.derive(c, tmpKs, key[:x25519SharedSecretSz])
if ret == nil {
return nil
}
copy(sharedKey[:], ret)
// Key agreement for PQ
tmpKs.group = 0 /*UNUSED*/
tmpKs.data = ks.data[x25519SharedSecretSz:]
ret = kex.pqKEX.derive(c, tmpKs, key[x25519SharedSecretSz:])
if ret == nil {
return nil
}
copy(sharedKey[x25519SharedSecretSz:], ret)
return sharedKey[:]
}

View File

@ -0,0 +1,63 @@
Copyright (c) 2018 Cloudflare. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Cloudflare nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
========================================================================
The code for TLSv1.2 and older TLS versions was derived from the
Golang standard library <https://golang.org/src/crypto/tls/>, available
under the following BSD license:
========================================================================
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -14,9 +14,11 @@ import (
)
// pickSignatureAlgorithm selects a signature algorithm that is compatible with
// the given public key and the list of algorithms from the peer and this side.
// the given public key and the list of algorithms from both sides of connection.
// The lists of signature algorithms (peerSigAlgs and ourSigAlgs) are ignored
// for tlsVersion < VersionTLS12.
//
// The returned SignatureScheme codepoint is only meaningful for TLS 1.2,
// The returned SignatureScheme codepoint is only meaningful for TLS 1.2 and newer
// previous TLS versions have a fixed hash function.
func pickSignatureAlgorithm(pubkey crypto.PublicKey, peerSigAlgs, ourSigAlgs []SignatureScheme, tlsVersion uint16) (SignatureScheme, uint8, crypto.Hash, error) {
if tlsVersion < VersionTLS12 || len(peerSigAlgs) == 0 {

View File

@ -124,6 +124,9 @@ const (
CurveP384 = tls.CurveP384
CurveP521 = tls.CurveP521
X25519 = tls.X25519
// Experimental KEX
HybridSIDHp503Curve25519 CurveID = 0xFE30
)
// TLS 1.3 Key Share
@ -168,9 +171,10 @@ const (
// Rest of these are reserved by the TLS spec
)
// Signature algorithms for TLS 1.2 (See RFC 5246, section A.4.1)
// Signature algorithms (for internal signaling use). Starting at 16 to avoid overlap with
// TLS 1.2 codepoints (RFC 5246, section A.4.1), with which these have nothing to do.
const (
signaturePKCS1v15 uint8 = iota + 1
signaturePKCS1v15 uint8 = iota + 16
signatureECDSA
signatureRSAPSS
)
@ -517,7 +521,8 @@ type Config struct {
PreferServerCipherSuites bool
// SessionTicketsDisabled may be set to true to disable session ticket
// (resumption) support.
// (resumption) support. Note that on clients, session ticket support is
// also disabled if ClientSessionCache is nil.
SessionTicketsDisabled bool
// SessionTicketKey is used by TLS servers to provide session
@ -531,7 +536,7 @@ type Config struct {
SessionTicketKey [32]byte
// ClientSessionCache is a cache of ClientSessionState entries for TLS
// session resumption.
// session resumption. It is only used by clients.
ClientSessionCache ClientSessionCache
// MinVersion contains the minimum SSL/TLS version that is acceptable.
@ -1106,9 +1111,19 @@ func defaultTLS13CipherSuites() []uint16 {
func initDefaultCipherSuites() {
var topCipherSuites, topTLS13CipherSuites []uint16
// TODO: check for hardware support
// This used to be: if cipherhw.AESGCMSupport() {
// However, cipherhw is an internal package
// Check the cpu flags for each platform that has optimized GCM implementations.
// Worst case, these variables will just all be false
// hasGCMAsmAMD64 := cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ
// hasGCMAsmARM64 := cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
// // Keep in sync with crypto/aes/cipher_s390x.go.
// hasGCMAsmS390X := cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM)
// hasGCMAsm := hasGCMAsmAMD64 || hasGCMAsmARM64 || hasGCMAsmS390X
if true {
// If AES-GCM hardware is provided then prioritise AES-GCM
// cipher suites.

View File

@ -241,8 +241,8 @@ func (hc *halfConn) setKey(version uint16, suite *cipherSuite, trafficSecret []b
}
hc.version = version
hash := hashForSuite(suite)
key := hkdfExpandLabel(hash, trafficSecret, nil, "key", suite.keyLen)
iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", suite.ivLen)
key := HkdfExpandLabel(hash, trafficSecret, nil, "key", suite.keyLen)
iv := HkdfExpandLabel(hash, trafficSecret, nil, "iv", suite.ivLen)
hc.cipher = suite.aead(key, iv)
for i := range hc.seq {
hc.seq[i] = 0

View File

@ -198,7 +198,7 @@ func (c *Conn) clientHandshake() error {
// Create one keyshare for the first default curve. If it is not
// appropriate, the server should raise a HRR.
defaultGroup := c.config.curvePreferences()[0]
hs.privateKey, clientKS, err = c.config.generateKeyShare(defaultGroup)
hs.privateKey, clientKS, err = c.generateKeyShare(defaultGroup)
if err != nil {
c.sendAlert(alertInternalError)
return err

View File

@ -684,7 +684,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) alert {
return alertDecodeError
}
case extensionKeyShare:
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.5
// https://tools.ietf.org/html/rfc8446#section-4.2.8
if length < 2 {
return alertDecodeError
}

View File

@ -588,15 +588,6 @@ func (hs *serverHandshakeState) doFullHandshake() error {
}
hs.finishedHash.Write(certMsg.marshal())
if len(certMsg.certificates) == 0 {
// The client didn't actually send a certificate
switch c.config.ClientAuth {
case RequireAnyClientCert, RequireAndVerifyClientCert:
c.sendAlert(alertBadCertificate)
return errors.New("tls: client didn't provide a certificate")
}
}
pub, err = hs.processCertsFromClient(certMsg.certificates)
if err != nil {
return err
@ -797,6 +788,15 @@ func (hs *serverHandshakeState) sendFinished(out []byte) error {
func (hs *serverHandshakeState) processCertsFromClient(certificates [][]byte) (crypto.PublicKey, error) {
c := hs.c
if len(certificates) == 0 {
// The client didn't actually send a certificate
switch c.config.ClientAuth {
case RequireAnyClientCert, RequireAndVerifyClientCert:
c.sendAlert(alertBadCertificate)
return nil, errors.New("tls: client didn't provide a certificate")
}
}
hs.certsFromClient = certificates
certs := make([]*x509.Certificate, len(certificates))
var err error

View File

@ -45,7 +45,8 @@ func hkdfExpand(hash crypto.Hash, prk, info []byte, l int) []byte {
return res
}
func hkdfExtract(hash crypto.Hash, secret, salt []byte) []byte {
// HkdfExtract generates a pseudorandom key for use with Expand from an input secret and an optional independent salt.
func HkdfExtract(hash crypto.Hash, secret, salt []byte) []byte {
if salt == nil {
salt = make([]byte, hash.Size())
}

View File

@ -237,15 +237,14 @@ func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
}
var err error
cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
// We don't need to parse the public key for TLS, but we so do anyway
// to check that it looks sane and matches the private key.
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return fail(err)
}
// We don't need to parse the public key for TLS, but we so do anyway
// to check that it looks sane and matches the private key.
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
if err != nil {
return fail(err)
}