mirror of
https://github.com/v2fly/v2ray-core.git
synced 2025-01-02 23:47:07 -05:00
remove all vendor tests
This commit is contained in:
parent
84f8bca01c
commit
786290a31d
@ -1,26 +0,0 @@
|
|||||||
package benchmark
|
|
||||||
|
|
||||||
import (
|
|
||||||
"flag"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestBenchmark(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "Benchmark Suite")
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
size int // file size in MB, will be read from flags
|
|
||||||
samples int // number of samples for Measure, will be read from flags
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
flag.IntVar(&size, "size", 50, "data length (in MB)")
|
|
||||||
flag.IntVar(&samples, "samples", 6, "number of samples")
|
|
||||||
flag.Parse()
|
|
||||||
}
|
|
@ -1,87 +0,0 @@
|
|||||||
package benchmark
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
var _ = Describe("Benchmarks", func() {
|
|
||||||
dataLen := size * /* MB */ 1e6
|
|
||||||
data := make([]byte, dataLen)
|
|
||||||
rand.Seed(GinkgoRandomSeed())
|
|
||||||
rand.Read(data) // no need to check for an error. math.Rand.Read never errors
|
|
||||||
|
|
||||||
for i := range protocol.SupportedVersions {
|
|
||||||
version := protocol.SupportedVersions[i]
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("with version %s", version), func() {
|
|
||||||
Measure(fmt.Sprintf("transferring a %d MB file", size), func(b Benchmarker) {
|
|
||||||
var ln quic.Listener
|
|
||||||
serverAddr := make(chan net.Addr)
|
|
||||||
handshakeChan := make(chan struct{})
|
|
||||||
// start the server
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
var err error
|
|
||||||
ln, err = quic.ListenAddr(
|
|
||||||
"localhost:0",
|
|
||||||
testdata.GetTLSConfig(),
|
|
||||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
serverAddr <- ln.Addr()
|
|
||||||
sess, err := ln.Accept()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
// wait for the client to complete the handshake before sending the data
|
|
||||||
// this should not be necessary, but due to timing issues on the CIs, this is necessary to avoid sending too many undecryptable packets
|
|
||||||
<-handshakeChan
|
|
||||||
str, err := sess.OpenStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = str.Write(data)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = str.Close()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}()
|
|
||||||
|
|
||||||
// start the client
|
|
||||||
addr := <-serverAddr
|
|
||||||
sess, err := quic.DialAddr(
|
|
||||||
addr.String(),
|
|
||||||
&tls.Config{InsecureSkipVerify: true},
|
|
||||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
close(handshakeChan)
|
|
||||||
str, err := sess.AcceptStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
|
||||||
// measure the time it takes to download the dataLen bytes
|
|
||||||
// note we're measuring the time for the transfer, i.e. excluding the handshake
|
|
||||||
runtime := b.Time("transfer time", func() {
|
|
||||||
_, err := io.Copy(buf, str)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
})
|
|
||||||
Expect(buf.Bytes()).To(Equal(data))
|
|
||||||
|
|
||||||
b.RecordValue("transfer rate [MB/s]", float64(dataLen)/1e6/runtime.Seconds())
|
|
||||||
|
|
||||||
ln.Close()
|
|
||||||
sess.Close()
|
|
||||||
}, samples)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
@ -1,21 +0,0 @@
|
|||||||
package quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Buffer Pool", func() {
|
|
||||||
It("returns buffers of cap", func() {
|
|
||||||
buf := *getPacketBuffer()
|
|
||||||
Expect(buf).To(HaveCap(int(protocol.MaxReceivePacketSize)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("panics if wrong-sized buffers are passed", func() {
|
|
||||||
Expect(func() {
|
|
||||||
putPacketBuffer(&[]byte{0})
|
|
||||||
}).To(Panic())
|
|
||||||
})
|
|
||||||
})
|
|
1030
vendor/lucas-clemente/quic-go/client_test.go
vendored
1030
vendor/lucas-clemente/quic-go/client_test.go
vendored
File diff suppressed because it is too large
Load Diff
119
vendor/lucas-clemente/quic-go/conn_test.go
vendored
119
vendor/lucas-clemente/quic-go/conn_test.go
vendored
@ -1,119 +0,0 @@
|
|||||||
package quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockPacketConn struct {
|
|
||||||
addr net.Addr
|
|
||||||
dataToRead chan []byte
|
|
||||||
dataReadFrom net.Addr
|
|
||||||
readErr error
|
|
||||||
dataWritten bytes.Buffer
|
|
||||||
dataWrittenTo net.Addr
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMockPacketConn() *mockPacketConn {
|
|
||||||
return &mockPacketConn{
|
|
||||||
dataToRead: make(chan []byte, 1000),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
|
||||||
if c.readErr != nil {
|
|
||||||
return 0, nil, c.readErr
|
|
||||||
}
|
|
||||||
data, ok := <-c.dataToRead
|
|
||||||
if !ok {
|
|
||||||
return 0, nil, errors.New("connection closed")
|
|
||||||
}
|
|
||||||
n := copy(b, data)
|
|
||||||
return n, c.dataReadFrom, nil
|
|
||||||
}
|
|
||||||
func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
|
||||||
c.dataWrittenTo = addr
|
|
||||||
return c.dataWritten.Write(b)
|
|
||||||
}
|
|
||||||
func (c *mockPacketConn) Close() error {
|
|
||||||
if !c.closed {
|
|
||||||
close(c.dataToRead)
|
|
||||||
}
|
|
||||||
c.closed = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (c *mockPacketConn) LocalAddr() net.Addr { return c.addr }
|
|
||||||
func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") }
|
|
||||||
func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") }
|
|
||||||
func (c *mockPacketConn) SetWriteDeadline(t time.Time) error { panic("not implemented") }
|
|
||||||
|
|
||||||
var _ net.PacketConn = &mockPacketConn{}
|
|
||||||
|
|
||||||
var _ = Describe("Connection", func() {
|
|
||||||
var c *conn
|
|
||||||
var packetConn *mockPacketConn
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
addr := &net.UDPAddr{
|
|
||||||
IP: net.IPv4(192, 168, 100, 200),
|
|
||||||
Port: 1337,
|
|
||||||
}
|
|
||||||
packetConn = newMockPacketConn()
|
|
||||||
c = &conn{
|
|
||||||
currentAddr: addr,
|
|
||||||
pconn: packetConn,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("writes", func() {
|
|
||||||
err := c.Write([]byte("foobar"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(packetConn.dataWritten.Bytes()).To(Equal([]byte("foobar")))
|
|
||||||
Expect(packetConn.dataWrittenTo.String()).To(Equal("192.168.100.200:1337"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads", func() {
|
|
||||||
packetConn.dataToRead <- []byte("foo")
|
|
||||||
packetConn.dataReadFrom = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1336}
|
|
||||||
p := make([]byte, 10)
|
|
||||||
n, raddr, err := c.Read(p)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(raddr.String()).To(Equal("127.0.0.1:1336"))
|
|
||||||
Expect(n).To(Equal(3))
|
|
||||||
Expect(p[0:3]).To(Equal([]byte("foo")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets the remote address", func() {
|
|
||||||
Expect(c.RemoteAddr().String()).To(Equal("192.168.100.200:1337"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets the local address", func() {
|
|
||||||
addr := &net.UDPAddr{
|
|
||||||
IP: net.IPv4(192, 168, 0, 1),
|
|
||||||
Port: 1234,
|
|
||||||
}
|
|
||||||
packetConn.addr = addr
|
|
||||||
Expect(c.LocalAddr()).To(Equal(addr))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("changes the remote address", func() {
|
|
||||||
addr := &net.UDPAddr{
|
|
||||||
IP: net.IPv4(127, 0, 0, 1),
|
|
||||||
Port: 7331,
|
|
||||||
}
|
|
||||||
c.SetCurrentRemoteAddr(addr)
|
|
||||||
Expect(c.RemoteAddr().String()).To(Equal(addr.String()))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes", func() {
|
|
||||||
err := c.Close()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(packetConn.closed).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,26 +0,0 @@
|
|||||||
package quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Crypto Stream", func() {
|
|
||||||
var (
|
|
||||||
str *cryptoStreamImpl
|
|
||||||
mockSender *MockStreamSender
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
mockSender = NewMockStreamSender(mockCtrl)
|
|
||||||
str = newCryptoStream(mockSender, nil, protocol.VersionWhatever).(*cryptoStreamImpl)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets the read offset", func() {
|
|
||||||
str.setReadOffset(0x42)
|
|
||||||
Expect(str.receiveStream.readOffset).To(Equal(protocol.ByteCount(0x42)))
|
|
||||||
Expect(str.receiveStream.frameQueue.readPos).To(Equal(protocol.ByteCount(0x42)))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,9 +0,0 @@
|
|||||||
FROM scratch
|
|
||||||
|
|
||||||
VOLUME /certs
|
|
||||||
VOLUME /www
|
|
||||||
EXPOSE 6121
|
|
||||||
|
|
||||||
ADD main /main
|
|
||||||
|
|
||||||
CMD ["/main", "-bind=0.0.0.0", "-certpath=/certs/", "-www=/www"]
|
|
@ -1,7 +0,0 @@
|
|||||||
# About the certificate
|
|
||||||
|
|
||||||
Yes, this folder contains a private key and a certificate for quic.clemente.io.
|
|
||||||
|
|
||||||
Unfortunately we need a valid certificate for the integration tests with Chrome and `quic_client`. No important data is served on the "real" `quic.clemente.io` (only a test page), and the MITM problem is imho negligible.
|
|
||||||
|
|
||||||
If you figure out a way to test with Chrome without having a cert and key here, let us now in an issue.
|
|
@ -1,66 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"flag"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/h2quic"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
verbose := flag.Bool("v", false, "verbose")
|
|
||||||
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
|
|
||||||
flag.Parse()
|
|
||||||
urls := flag.Args()
|
|
||||||
|
|
||||||
logger := utils.DefaultLogger
|
|
||||||
|
|
||||||
if *verbose {
|
|
||||||
logger.SetLogLevel(utils.LogLevelDebug)
|
|
||||||
} else {
|
|
||||||
logger.SetLogLevel(utils.LogLevelInfo)
|
|
||||||
}
|
|
||||||
logger.SetLogTimeFormat("")
|
|
||||||
|
|
||||||
versions := protocol.SupportedVersions
|
|
||||||
if *tls {
|
|
||||||
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
|
|
||||||
}
|
|
||||||
|
|
||||||
roundTripper := &h2quic.RoundTripper{
|
|
||||||
QuicConfig: &quic.Config{Versions: versions},
|
|
||||||
}
|
|
||||||
defer roundTripper.Close()
|
|
||||||
hclient := &http.Client{
|
|
||||||
Transport: roundTripper,
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(len(urls))
|
|
||||||
for _, addr := range urls {
|
|
||||||
logger.Infof("GET %s", addr)
|
|
||||||
go func(addr string) {
|
|
||||||
rsp, err := hclient.Get(addr)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
logger.Infof("Got response for %s: %#v", addr, rsp)
|
|
||||||
|
|
||||||
body := &bytes.Buffer{}
|
|
||||||
_, err = io.Copy(body, rsp.Body)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
logger.Infof("Request Body:")
|
|
||||||
logger.Infof("%s", body.Bytes())
|
|
||||||
wg.Done()
|
|
||||||
}(addr)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
105
vendor/lucas-clemente/quic-go/example/echo/echo.go
vendored
105
vendor/lucas-clemente/quic-go/example/echo/echo.go
vendored
@ -1,105 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/pem"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"math/big"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
)
|
|
||||||
|
|
||||||
const addr = "localhost:4242"
|
|
||||||
|
|
||||||
const message = "foobar"
|
|
||||||
|
|
||||||
// We start a server echoing data on the first stream the client opens,
|
|
||||||
// then connect with a client, send the message, and wait for its receipt.
|
|
||||||
func main() {
|
|
||||||
go func() { log.Fatal(echoServer()) }()
|
|
||||||
|
|
||||||
err := clientMain()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start a server that echos all data on the first stream opened by the client
|
|
||||||
func echoServer() error {
|
|
||||||
listener, err := quic.ListenAddr(addr, generateTLSConfig(), nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
sess, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
stream, err := sess.AcceptStream()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
// Echo through the loggingWriter
|
|
||||||
_, err = io.Copy(loggingWriter{stream}, stream)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func clientMain() error {
|
|
||||||
session, err := quic.DialAddr(addr, &tls.Config{InsecureSkipVerify: true}, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
stream, err := session.OpenStreamSync()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Client: Sending '%s'\n", message)
|
|
||||||
_, err = stream.Write([]byte(message))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := make([]byte, len(message))
|
|
||||||
_, err = io.ReadFull(stream, buf)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
fmt.Printf("Client: Got '%s'\n", buf)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// A wrapper for io.Writer that also logs the message.
|
|
||||||
type loggingWriter struct{ io.Writer }
|
|
||||||
|
|
||||||
func (w loggingWriter) Write(b []byte) (int, error) {
|
|
||||||
fmt.Printf("Server: Got '%s'\n", string(b))
|
|
||||||
return w.Writer.Write(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup a bare-bones TLS config for the server
|
|
||||||
func generateTLSConfig() *tls.Config {
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
template := x509.Certificate{SerialNumber: big.NewInt(1)}
|
|
||||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
|
||||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
|
||||||
|
|
||||||
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return &tls.Config{Certificates: []tls.Certificate{tlsCert}}
|
|
||||||
}
|
|
@ -1,62 +0,0 @@
|
|||||||
-----BEGIN CERTIFICATE-----
|
|
||||||
MIIGCzCCBPOgAwIBAgISA6pwet1vq9IjS9ThcggZYGfyMA0GCSqGSIb3DQEBCwUA
|
|
||||||
MEoxCzAJBgNVBAYTAlVTMRYwFAYDVQQKEw1MZXQncyBFbmNyeXB0MSMwIQYDVQQD
|
|
||||||
ExpMZXQncyBFbmNyeXB0IEF1dGhvcml0eSBYMzAeFw0xODA2MDkwODI2MjVaFw0x
|
|
||||||
ODA5MDcwODI2MjVaMBsxGTAXBgNVBAMTEHF1aWMuY2xlbWVudGUuaW8wggEiMA0G
|
|
||||||
CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDqON9RKCMK4dHxLTLiv3n+b/v4KCbY
|
|
||||||
Fo38rxCcv4/qoWI1Zoz4XMyPnTZPrzVMz6rhDKWTve1v7g+itgQDp1OoRo5ihtqC
|
|
||||||
m9Dr1Ed+qb3hHtFRBBOTiWQy1Y0fvUDYGlSHQ4R7xXYVzCYkQ9zopqr2otJp/9ZA
|
|
||||||
1Yy3ATITSgcds9wJaAOpkSxCx+D7cqpGbaNtxogRZ4ZT1vVs3l7Kr26fvnALMM4K
|
|
||||||
nf2Rpq1ZjgPRoLNcfUWCZO7gr+VNCFaI2AQRaAzZkSXAk0B91waTtaFD87kYXl9Z
|
|
||||||
8Jgx1PhJnO8WBUz/7+OAO6M+kuT78e6FltMN004Q8WmuNww6PzmVwnmpAgMBAAGj
|
|
||||||
ggMYMIIDFDAOBgNVHQ8BAf8EBAMCBaAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsG
|
|
||||||
AQUFBwMCMAwGA1UdEwEB/wQCMAAwHQYDVR0OBBYEFIs3yJaLnEw7MRtNf9EbzESy
|
|
||||||
TrHwMB8GA1UdIwQYMBaAFKhKamMEfd265tE5t6ZFZe/zqOyhMG8GCCsGAQUFBwEB
|
|
||||||
BGMwYTAuBggrBgEFBQcwAYYiaHR0cDovL29jc3AuaW50LXgzLmxldHNlbmNyeXB0
|
|
||||||
Lm9yZzAvBggrBgEFBQcwAoYjaHR0cDovL2NlcnQuaW50LXgzLmxldHNlbmNyeXB0
|
|
||||||
Lm9yZy8wGwYDVR0RBBQwEoIQcXVpYy5jbGVtZW50ZS5pbzCB/gYDVR0gBIH2MIHz
|
|
||||||
MAgGBmeBDAECATCB5gYLKwYBBAGC3xMBAQEwgdYwJgYIKwYBBQUHAgEWGmh0dHA6
|
|
||||||
Ly9jcHMubGV0c2VuY3J5cHQub3JnMIGrBggrBgEFBQcCAjCBngyBm1RoaXMgQ2Vy
|
|
||||||
dGlmaWNhdGUgbWF5IG9ubHkgYmUgcmVsaWVkIHVwb24gYnkgUmVseWluZyBQYXJ0
|
|
||||||
aWVzIGFuZCBvbmx5IGluIGFjY29yZGFuY2Ugd2l0aCB0aGUgQ2VydGlmaWNhdGUg
|
|
||||||
UG9saWN5IGZvdW5kIGF0IGh0dHBzOi8vbGV0c2VuY3J5cHQub3JnL3JlcG9zaXRv
|
|
||||||
cnkvMIIBBAYKKwYBBAHWeQIEAgSB9QSB8gDwAHYA23Sv7ssp7LH+yj5xbSzluaq7
|
|
||||||
NveEcYPHXZ1PN7Yfv2QAAAFj495INQAABAMARzBFAiEApF0BFCWyGIUrJrsYuugt
|
|
||||||
tshGVdg2+7f6d4B1D/xF2s0CIGGsVL2nRlXJXrkk3aa83lH4HzP9vcQSnMFHdXOf
|
|
||||||
9XeZAHYAKTxRllTIOWW6qlD8WAfUt2+/WHopctykwwz05UVH9HgAAAFj495IQwAA
|
|
||||||
BAMARzBFAiEAkV4gbM6hucL7ZwqTzb5fKxYhk6WHr5y8pzClZD3qqv4CIBYH7MSA
|
|
||||||
P05CXv7tHiHEizIvhWJJvVa2E6XLjbDRQMnUMA0GCSqGSIb3DQEBCwUAA4IBAQA2
|
|
||||||
CALjtlxGXkkfKsRgKDoPpzg/IAl7crq5OrGGwW/bxbeDeiRHVt0Hlhr+0XPYlh/A
|
|
||||||
m8qBlg7TMHJa2zIt7wkG/MX3d9bO2bxXxsdjfMXLtxQu7eZ5nzGuXL8WGnZwtSqz
|
|
||||||
M5pOF/AU1JQojIhehKCeqqUi5UobxXUm+9D6OVmr8s732X3n/TL6pgsWRhay9tjB
|
|
||||||
kdZ9TSe0tLYyXnbwHp0rgwNWOMMN1Tc+Fpqc8UlrCq5REb0bLIQ6A2IJH08MEPWG
|
|
||||||
ukXHPAaHLz9oB1emR3flCoMKB0KrXUpDFXemOIPN6QVBO16LNNuRffKwXnzp60+b
|
|
||||||
MoB4Krxrab7TzlfT+HnP
|
|
||||||
-----END CERTIFICATE-----
|
|
||||||
-----BEGIN CERTIFICATE-----
|
|
||||||
MIIEkjCCA3qgAwIBAgIQCgFBQgAAAVOFc2oLheynCDANBgkqhkiG9w0BAQsFADA/
|
|
||||||
MSQwIgYDVQQKExtEaWdpdGFsIFNpZ25hdHVyZSBUcnVzdCBDby4xFzAVBgNVBAMT
|
|
||||||
DkRTVCBSb290IENBIFgzMB4XDTE2MDMxNzE2NDA0NloXDTIxMDMxNzE2NDA0Nlow
|
|
||||||
SjELMAkGA1UEBhMCVVMxFjAUBgNVBAoTDUxldCdzIEVuY3J5cHQxIzAhBgNVBAMT
|
|
||||||
GkxldCdzIEVuY3J5cHQgQXV0aG9yaXR5IFgzMIIBIjANBgkqhkiG9w0BAQEFAAOC
|
|
||||||
AQ8AMIIBCgKCAQEAnNMM8FrlLke3cl03g7NoYzDq1zUmGSXhvb418XCSL7e4S0EF
|
|
||||||
q6meNQhY7LEqxGiHC6PjdeTm86dicbp5gWAf15Gan/PQeGdxyGkOlZHP/uaZ6WA8
|
|
||||||
SMx+yk13EiSdRxta67nsHjcAHJyse6cF6s5K671B5TaYucv9bTyWaN8jKkKQDIZ0
|
|
||||||
Z8h/pZq4UmEUEz9l6YKHy9v6Dlb2honzhT+Xhq+w3Brvaw2VFn3EK6BlspkENnWA
|
|
||||||
a6xK8xuQSXgvopZPKiAlKQTGdMDQMc2PMTiVFrqoM7hD8bEfwzB/onkxEz0tNvjj
|
|
||||||
/PIzark5McWvxI0NHWQWM6r6hCm21AvA2H3DkwIDAQABo4IBfTCCAXkwEgYDVR0T
|
|
||||||
AQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAYYwfwYIKwYBBQUHAQEEczBxMDIG
|
|
||||||
CCsGAQUFBzABhiZodHRwOi8vaXNyZy50cnVzdGlkLm9jc3AuaWRlbnRydXN0LmNv
|
|
||||||
bTA7BggrBgEFBQcwAoYvaHR0cDovL2FwcHMuaWRlbnRydXN0LmNvbS9yb290cy9k
|
|
||||||
c3Ryb290Y2F4My5wN2MwHwYDVR0jBBgwFoAUxKexpHsscfrb4UuQdf/EFWCFiRAw
|
|
||||||
VAYDVR0gBE0wSzAIBgZngQwBAgEwPwYLKwYBBAGC3xMBAQEwMDAuBggrBgEFBQcC
|
|
||||||
ARYiaHR0cDovL2Nwcy5yb290LXgxLmxldHNlbmNyeXB0Lm9yZzA8BgNVHR8ENTAz
|
|
||||||
MDGgL6AthitodHRwOi8vY3JsLmlkZW50cnVzdC5jb20vRFNUUk9PVENBWDNDUkwu
|
|
||||||
Y3JsMB0GA1UdDgQWBBSoSmpjBH3duubRObemRWXv86jsoTANBgkqhkiG9w0BAQsF
|
|
||||||
AAOCAQEA3TPXEfNjWDjdGBX7CVW+dla5cEilaUcne8IkCJLxWh9KEik3JHRRHGJo
|
|
||||||
uM2VcGfl96S8TihRzZvoroed6ti6WqEBmtzw3Wodatg+VyOeph4EYpr/1wXKtx8/
|
|
||||||
wApIvJSwtmVi4MFU5aMqrSDE6ea73Mj2tcMyo5jMd6jmeWUHK8so/joWUoHOUgwu
|
|
||||||
X4Po1QYz+3dszkDqMp4fklxBwXRsW10KXzPMTZ+sOPAveyxindmjkW8lGy+QsRlG
|
|
||||||
PfZ+G6Z6h7mjem0Y+iWlkYcV4PIWL1iwBi8saCbGS5jN2p8M+X+Q7UNKEkROb3N6
|
|
||||||
KOqkqm57TH2H3eDJAkSnh6/DNFu0Qg==
|
|
||||||
-----END CERTIFICATE-----
|
|
174
vendor/lucas-clemente/quic-go/example/main.go
vendored
174
vendor/lucas-clemente/quic-go/example/main.go
vendored
@ -1,174 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/md5"
|
|
||||||
"errors"
|
|
||||||
"flag"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
|
||||||
"mime/multipart"
|
|
||||||
"net/http"
|
|
||||||
"path"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
_ "net/http/pprof"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/h2quic"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
type binds []string
|
|
||||||
|
|
||||||
func (b binds) String() string {
|
|
||||||
return strings.Join(b, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *binds) Set(v string) error {
|
|
||||||
*b = strings.Split(v, ",")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Size is needed by the /demo/upload handler to determine the size of the uploaded file
|
|
||||||
type Size interface {
|
|
||||||
Size() int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
http.HandleFunc("/demo/tile", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Small 40x40 png
|
|
||||||
w.Write([]byte{
|
|
||||||
0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, 0x00, 0x00, 0x00, 0x0d,
|
|
||||||
0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x28,
|
|
||||||
0x01, 0x03, 0x00, 0x00, 0x00, 0xb6, 0x30, 0x2a, 0x2e, 0x00, 0x00, 0x00,
|
|
||||||
0x03, 0x50, 0x4c, 0x54, 0x45, 0x5a, 0xc3, 0x5a, 0xad, 0x38, 0xaa, 0xdb,
|
|
||||||
0x00, 0x00, 0x00, 0x0b, 0x49, 0x44, 0x41, 0x54, 0x78, 0x01, 0x63, 0x18,
|
|
||||||
0x61, 0x00, 0x00, 0x00, 0xf0, 0x00, 0x01, 0xe2, 0xb8, 0x75, 0x22, 0x00,
|
|
||||||
0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, 0x44, 0xae, 0x42, 0x60, 0x82,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
http.HandleFunc("/demo/tiles", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
io.WriteString(w, "<html><head><style>img{width:40px;height:40px;}</style></head><body>")
|
|
||||||
for i := 0; i < 200; i++ {
|
|
||||||
fmt.Fprintf(w, `<img src="/demo/tile?cachebust=%d">`, i)
|
|
||||||
}
|
|
||||||
io.WriteString(w, "</body></html>")
|
|
||||||
})
|
|
||||||
|
|
||||||
http.HandleFunc("/demo/echo", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("error reading body while handling /echo: %s\n", err.Error())
|
|
||||||
}
|
|
||||||
w.Write(body)
|
|
||||||
})
|
|
||||||
|
|
||||||
// accept file uploads and return the MD5 of the uploaded file
|
|
||||||
// maximum accepted file size is 1 GB
|
|
||||||
http.HandleFunc("/demo/upload", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Method == http.MethodPost {
|
|
||||||
err := r.ParseMultipartForm(1 << 30) // 1 GB
|
|
||||||
if err == nil {
|
|
||||||
var file multipart.File
|
|
||||||
file, _, err = r.FormFile("uploadfile")
|
|
||||||
if err == nil {
|
|
||||||
var size int64
|
|
||||||
if sizeInterface, ok := file.(Size); ok {
|
|
||||||
size = sizeInterface.Size()
|
|
||||||
b := make([]byte, size)
|
|
||||||
file.Read(b)
|
|
||||||
md5 := md5.Sum(b)
|
|
||||||
fmt.Fprintf(w, "%x", md5)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = errors.New("couldn't get uploaded file size")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
utils.DefaultLogger.Infof("Error receiving upload: %#v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
io.WriteString(w, `<html><body><form action="/demo/upload" method="post" enctype="multipart/form-data">
|
|
||||||
<input type="file" name="uploadfile"><br>
|
|
||||||
<input type="submit">
|
|
||||||
</form></body></html>`)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func getBuildDir() string {
|
|
||||||
_, filename, _, ok := runtime.Caller(0)
|
|
||||||
if !ok {
|
|
||||||
panic("Failed to get current frame")
|
|
||||||
}
|
|
||||||
|
|
||||||
return path.Dir(filename)
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
// defer profile.Start().Stop()
|
|
||||||
go func() {
|
|
||||||
log.Println(http.ListenAndServe("localhost:6060", nil))
|
|
||||||
}()
|
|
||||||
// runtime.SetBlockProfileRate(1)
|
|
||||||
|
|
||||||
verbose := flag.Bool("v", false, "verbose")
|
|
||||||
bs := binds{}
|
|
||||||
flag.Var(&bs, "bind", "bind to")
|
|
||||||
certPath := flag.String("certpath", getBuildDir(), "certificate directory")
|
|
||||||
www := flag.String("www", "/var/www", "www data")
|
|
||||||
tcp := flag.Bool("tcp", false, "also listen on TCP")
|
|
||||||
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
logger := utils.DefaultLogger
|
|
||||||
|
|
||||||
if *verbose {
|
|
||||||
logger.SetLogLevel(utils.LogLevelDebug)
|
|
||||||
} else {
|
|
||||||
logger.SetLogLevel(utils.LogLevelInfo)
|
|
||||||
}
|
|
||||||
logger.SetLogTimeFormat("")
|
|
||||||
|
|
||||||
versions := protocol.SupportedVersions
|
|
||||||
if *tls {
|
|
||||||
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
|
|
||||||
}
|
|
||||||
|
|
||||||
certFile := *certPath + "/fullchain.pem"
|
|
||||||
keyFile := *certPath + "/privkey.pem"
|
|
||||||
|
|
||||||
http.Handle("/", http.FileServer(http.Dir(*www)))
|
|
||||||
|
|
||||||
if len(bs) == 0 {
|
|
||||||
bs = binds{"localhost:6121"}
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(len(bs))
|
|
||||||
for _, b := range bs {
|
|
||||||
bCap := b
|
|
||||||
go func() {
|
|
||||||
var err error
|
|
||||||
if *tcp {
|
|
||||||
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil)
|
|
||||||
} else {
|
|
||||||
server := h2quic.Server{
|
|
||||||
Server: &http.Server{Addr: bCap},
|
|
||||||
QuicConfig: &quic.Config{Versions: versions},
|
|
||||||
}
|
|
||||||
err = server.ListenAndServeTLS(certFile, keyFile)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
}
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
@ -1,28 +0,0 @@
|
|||||||
-----BEGIN PRIVATE KEY-----
|
|
||||||
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDqON9RKCMK4dHx
|
|
||||||
LTLiv3n+b/v4KCbYFo38rxCcv4/qoWI1Zoz4XMyPnTZPrzVMz6rhDKWTve1v7g+i
|
|
||||||
tgQDp1OoRo5ihtqCm9Dr1Ed+qb3hHtFRBBOTiWQy1Y0fvUDYGlSHQ4R7xXYVzCYk
|
|
||||||
Q9zopqr2otJp/9ZA1Yy3ATITSgcds9wJaAOpkSxCx+D7cqpGbaNtxogRZ4ZT1vVs
|
|
||||||
3l7Kr26fvnALMM4Knf2Rpq1ZjgPRoLNcfUWCZO7gr+VNCFaI2AQRaAzZkSXAk0B9
|
|
||||||
1waTtaFD87kYXl9Z8Jgx1PhJnO8WBUz/7+OAO6M+kuT78e6FltMN004Q8WmuNww6
|
|
||||||
PzmVwnmpAgMBAAECggEAAW7hpux48msZTsF5CzwisfTbdNRCEJZqvf4QOvVNGyFr
|
|
||||||
qWn8ONTQh5xtpaUrzVGD+SaLqNDDsCijvdohQih28ZOk8WNj2OK9L4Q3/8VoHQWE
|
|
||||||
QFunBwMTMuBtoaEV0XyvwbgfCmbV5yI9pYEoy9+hMisi4HUpSXJFDyWZudZ9Hqhl
|
|
||||||
FBf5UPxF1AjZViODVfnKkrWh85jaENyWjWrMEDSohzxP8IStFMK8E0feXaQb+G0f
|
|
||||||
uTelpG2JmVcO3xaYgtRVeS4p5liMQg5gE5oa2Jh7Vp3TwMhQVg0aLmslSLAYlPoh
|
|
||||||
hyBeOS3ucyFHoC/6Stnnx3jdOEf2lEUObJj3QVEeBQKBgQD3qptNY9R62UQFO8gT
|
|
||||||
pseEO5CAZRGuKG0VLPNqKKerYQreiT3xYOTPmwEy+xXzYV6DKHtlKDetHuOB862M
|
|
||||||
E1bKmjDedafQ3KMc4tywLW+U9I8GyooT8uoetzARzgoftRcMzdeu4fexWTsi+kOh
|
|
||||||
5/PeXUBnFph8E68nXHQR5+MWAwKBgQDyGnT9KUVuLvzrFAghe9Eu1iZJSDs/IrJj
|
|
||||||
we3XQ7loqMc/Qv34eVKATsgtb2cTSeTivQcSvQO+Uu9dgo17BoWAABDTaV8NAOZ6
|
|
||||||
cV2kedrWnxaTRXzB6Z5EKLg+DOMTpCV4Nf3OwzGq/mnKAe8cNM/hS+8HcZywKwr2
|
|
||||||
UwLKSq6n4wKBgQDXGUaOnVCSbZZlETnAz43i67ShvqXvY07yIDs8jRiqgLrm8b1p
|
|
||||||
oaS4JkCRXX8ABSYHtaYOAjLw2a3wVIn66WTsy6P74aWhga7szJ+tJ5kMfqal2Ey5
|
|
||||||
7LSnfqRyIkeqqCXfyfsz+S+dyQjSZRdOS90C2Gyx2+8NfC8YeXSZhJM2rwKBgQDB
|
|
||||||
pL28G/2nsrejQ2N5fLKE9s6qwLZ6ukLbHasiCc5L0uuDQw8mZcvCSsE77iYQvILx
|
|
||||||
hGYa68oJugYw0hJdu4qeJe9PWbGoEfdHKlPPEZQjJB4Hb4XpB/YJ6FPtdZtPA3Tg
|
|
||||||
4LaAYYnhjhqJc+CPvAIl3vlyB8Je+h6LhTvvF6r5JwKBgHBkpQ2Y+N3dMgfE8T38
|
|
||||||
+Vy+75i9ew3F0TB8o2MWJzYLYCzNLqrB4GkZieudoYP1Cmh50wYIifsk2zNpxOG0
|
|
||||||
27tUfMxza4ITUfmtVHS22DGWCqb1TgfQt7jajdioK4TUOZFEmiQiOYMZoFMenbzZ
|
|
||||||
miz6s3fWx1/b3wNd1J5uV2QS
|
|
||||||
-----END PRIVATE KEY-----
|
|
435
vendor/lucas-clemente/quic-go/frame_sorter_test.go
vendored
435
vendor/lucas-clemente/quic-go/frame_sorter_test.go
vendored
@ -1,435 +0,0 @@
|
|||||||
package quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("STREAM frame sorter", func() {
|
|
||||||
var s *frameSorter
|
|
||||||
|
|
||||||
checkGaps := func(expectedGaps []utils.ByteInterval) {
|
|
||||||
Expect(s.gaps.Len()).To(Equal(len(expectedGaps)))
|
|
||||||
var i int
|
|
||||||
for gap := s.gaps.Front(); gap != nil; gap = gap.Next() {
|
|
||||||
Expect(gap.Value).To(Equal(expectedGaps[i]))
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
s = newFrameSorter()
|
|
||||||
})
|
|
||||||
|
|
||||||
It("head returns nil when empty", func() {
|
|
||||||
Expect(s.Pop()).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Push", func() {
|
|
||||||
It("inserts and pops a single frame", func() {
|
|
||||||
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
|
|
||||||
data, fin := s.Pop()
|
|
||||||
Expect(data).To(Equal([]byte("foobar")))
|
|
||||||
Expect(fin).To(BeFalse())
|
|
||||||
Expect(s.Pop()).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("inserts and pops two consecutive frame", func() {
|
|
||||||
Expect(s.Push([]byte("foo"), 0, false)).To(Succeed())
|
|
||||||
Expect(s.Push([]byte("bar"), 3, false)).To(Succeed())
|
|
||||||
data, fin := s.Pop()
|
|
||||||
Expect(data).To(Equal([]byte("foo")))
|
|
||||||
Expect(fin).To(BeFalse())
|
|
||||||
data, fin = s.Pop()
|
|
||||||
Expect(data).To(Equal([]byte("bar")))
|
|
||||||
Expect(fin).To(BeFalse())
|
|
||||||
Expect(s.Pop()).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("ignores empty frames", func() {
|
|
||||||
Expect(s.Push(nil, 0, false)).To(Succeed())
|
|
||||||
Expect(s.Pop()).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("FIN handling", func() {
|
|
||||||
It("saves a FIN at offset 0", func() {
|
|
||||||
Expect(s.Push(nil, 0, true)).To(Succeed())
|
|
||||||
data, fin := s.Pop()
|
|
||||||
Expect(data).To(BeEmpty())
|
|
||||||
Expect(fin).To(BeTrue())
|
|
||||||
data, fin = s.Pop()
|
|
||||||
Expect(data).To(BeNil())
|
|
||||||
Expect(fin).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("saves a FIN frame at non-zero offset", func() {
|
|
||||||
Expect(s.Push([]byte("foobar"), 0, true)).To(Succeed())
|
|
||||||
data, fin := s.Pop()
|
|
||||||
Expect(data).To(Equal([]byte("foobar")))
|
|
||||||
Expect(fin).To(BeTrue())
|
|
||||||
data, fin = s.Pop()
|
|
||||||
Expect(data).To(BeNil())
|
|
||||||
Expect(fin).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets the FIN if a stream is closed after receiving some data", func() {
|
|
||||||
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
|
|
||||||
Expect(s.Push(nil, 6, true)).To(Succeed())
|
|
||||||
data, fin := s.Pop()
|
|
||||||
Expect(data).To(Equal([]byte("foobar")))
|
|
||||||
Expect(fin).To(BeTrue())
|
|
||||||
data, fin = s.Pop()
|
|
||||||
Expect(data).To(BeNil())
|
|
||||||
Expect(fin).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Gap handling", func() {
|
|
||||||
It("finds the first gap", func() {
|
|
||||||
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 10},
|
|
||||||
{Start: 16, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("correctly sets the first gap for a frame with offset 0", func() {
|
|
||||||
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 6, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("finds the two gaps", func() {
|
|
||||||
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
|
|
||||||
Expect(s.Push([]byte("foobar"), 20, false)).To(Succeed())
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 10},
|
|
||||||
{Start: 16, End: 20},
|
|
||||||
{Start: 26, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("finds the two gaps in reverse order", func() {
|
|
||||||
Expect(s.Push([]byte("foobar"), 20, false)).To(Succeed())
|
|
||||||
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 10},
|
|
||||||
{Start: 16, End: 20},
|
|
||||||
{Start: 26, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("shrinks a gap when it is partially filled", func() {
|
|
||||||
Expect(s.Push([]byte("test"), 10, false)).To(Succeed())
|
|
||||||
Expect(s.Push([]byte("foobar"), 4, false)).To(Succeed())
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 4},
|
|
||||||
{Start: 14, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("deletes a gap at the beginning, when it is filled", func() {
|
|
||||||
Expect(s.Push([]byte("test"), 6, false)).To(Succeed())
|
|
||||||
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 10, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("deletes a gap in the middle, when it is filled", func() {
|
|
||||||
Expect(s.Push([]byte("test"), 0, false)).To(Succeed())
|
|
||||||
Expect(s.Push([]byte("test2"), 10, false)).To(Succeed())
|
|
||||||
Expect(s.Push([]byte("foobar"), 4, false)).To(Succeed())
|
|
||||||
Expect(s.queue).To(HaveLen(3))
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 15, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("splits a gap into two", func() {
|
|
||||||
Expect(s.Push([]byte("test"), 100, false)).To(Succeed())
|
|
||||||
Expect(s.Push([]byte("foobar"), 50, false)).To(Succeed())
|
|
||||||
Expect(s.queue).To(HaveLen(2))
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 50},
|
|
||||||
{Start: 56, End: 100},
|
|
||||||
{Start: 104, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Overlapping Stream Data detection", func() {
|
|
||||||
// create gaps: 0-5, 10-15, 20-25, 30-inf
|
|
||||||
BeforeEach(func() {
|
|
||||||
Expect(s.Push([]byte("12345"), 5, false)).To(Succeed())
|
|
||||||
Expect(s.Push([]byte("12345"), 15, false)).To(Succeed())
|
|
||||||
Expect(s.Push([]byte("12345"), 25, false)).To(Succeed())
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 5},
|
|
||||||
{Start: 10, End: 15},
|
|
||||||
{Start: 20, End: 25},
|
|
||||||
{Start: 30, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("cuts a frame with offset 0 that overlaps at the end", func() {
|
|
||||||
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(0)))
|
|
||||||
Expect(s.queue[0]).To(Equal([]byte("fooba")))
|
|
||||||
Expect(s.queue[0]).To(HaveCap(5))
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 10, End: 15},
|
|
||||||
{Start: 20, End: 25},
|
|
||||||
{Start: 30, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("cuts a frame that overlaps at the end", func() {
|
|
||||||
// 4 to 7
|
|
||||||
Expect(s.Push([]byte("foo"), 4, false)).To(Succeed())
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(4)))
|
|
||||||
Expect(s.queue[4]).To(Equal([]byte("f")))
|
|
||||||
Expect(s.queue[4]).To(HaveCap(1))
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 4},
|
|
||||||
{Start: 10, End: 15},
|
|
||||||
{Start: 20, End: 25},
|
|
||||||
{Start: 30, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("cuts a frame that completely fills a gap, but overlaps at the end", func() {
|
|
||||||
// 10 to 16
|
|
||||||
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
|
|
||||||
Expect(s.queue[10]).To(Equal([]byte("fooba")))
|
|
||||||
Expect(s.queue[10]).To(HaveCap(5))
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 5},
|
|
||||||
{Start: 20, End: 25},
|
|
||||||
{Start: 30, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("cuts a frame that overlaps at the beginning", func() {
|
|
||||||
// 8 to 14
|
|
||||||
Expect(s.Push([]byte("foobar"), 8, false)).To(Succeed())
|
|
||||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(8)))
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
|
|
||||||
Expect(s.queue[10]).To(Equal([]byte("obar")))
|
|
||||||
Expect(s.queue[10]).To(HaveCap(4))
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 5},
|
|
||||||
{Start: 14, End: 15},
|
|
||||||
{Start: 20, End: 25},
|
|
||||||
{Start: 30, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("processes a frame that overlaps at the beginning and at the end, starting in a gap", func() {
|
|
||||||
// 2 to 12
|
|
||||||
Expect(s.Push([]byte("1234567890"), 2, false)).To(Succeed())
|
|
||||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(2)))
|
|
||||||
Expect(s.queue[2]).To(Equal([]byte("1234567890")))
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 2},
|
|
||||||
{Start: 12, End: 15},
|
|
||||||
{Start: 20, End: 25},
|
|
||||||
{Start: 30, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("processes a frame that overlaps at the beginning and at the end, starting in a gap, ending in data", func() {
|
|
||||||
// 2 to 17
|
|
||||||
Expect(s.Push([]byte("123456789012345"), 2, false)).To(Succeed())
|
|
||||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(2)))
|
|
||||||
Expect(s.queue[2]).To(Equal([]byte("1234567890123")))
|
|
||||||
Expect(s.queue[2]).To(HaveCap(13))
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 2},
|
|
||||||
{Start: 20, End: 25},
|
|
||||||
{Start: 30, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("processes a frame that overlaps at the beginning and at the end, starting in a gap, ending in data", func() {
|
|
||||||
// 5 to 22
|
|
||||||
Expect(s.Push([]byte("12345678901234567"), 5, false)).To(Succeed())
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(5)))
|
|
||||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(15)))
|
|
||||||
Expect(s.queue[10]).To(Equal([]byte("678901234567")))
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 5},
|
|
||||||
{Start: 22, End: 25},
|
|
||||||
{Start: 30, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("processes a frame that closes multiple gaps", func() {
|
|
||||||
// 2 to 27
|
|
||||||
Expect(s.Push(bytes.Repeat([]byte{'e'}, 25), 2, false)).To(Succeed())
|
|
||||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
|
|
||||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(15)))
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(25)))
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(2)))
|
|
||||||
Expect(s.queue[2]).To(Equal(bytes.Repeat([]byte{'e'}, 23)))
|
|
||||||
Expect(s.queue[2]).To(HaveCap(23))
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 2},
|
|
||||||
{Start: 30, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("processes a frame that closes multiple gaps", func() {
|
|
||||||
// 5 to 27
|
|
||||||
Expect(s.Push(bytes.Repeat([]byte{'d'}, 22), 5, false)).To(Succeed())
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(5)))
|
|
||||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(15)))
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(25)))
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
|
|
||||||
Expect(s.queue[10]).To(Equal(bytes.Repeat([]byte{'d'}, 15)))
|
|
||||||
Expect(s.queue[10]).To(HaveCap(15))
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 5},
|
|
||||||
{Start: 30, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("processes a frame that covers multiple gaps and ends at the end of a gap", func() {
|
|
||||||
data := bytes.Repeat([]byte{'e'}, 14)
|
|
||||||
// 1 to 15
|
|
||||||
Expect(s.Push(data, 1, false)).To(Succeed())
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(1)))
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(15)))
|
|
||||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
|
|
||||||
Expect(s.queue[1]).To(Equal(data))
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 1},
|
|
||||||
{Start: 20, End: 25},
|
|
||||||
{Start: 30, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("processes a frame that closes all gaps (except for the last one)", func() {
|
|
||||||
data := bytes.Repeat([]byte{'f'}, 32)
|
|
||||||
// 0 to 32
|
|
||||||
Expect(s.Push(data, 0, false)).To(Succeed())
|
|
||||||
Expect(s.queue).To(HaveLen(1))
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(0)))
|
|
||||||
Expect(s.queue[0]).To(Equal(data))
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 32, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("cuts a frame that overlaps at the beginning and at the end, starting in data already received", func() {
|
|
||||||
// 8 to 17
|
|
||||||
Expect(s.Push([]byte("123456789"), 8, false)).To(Succeed())
|
|
||||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(8)))
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
|
|
||||||
Expect(s.queue[10]).To(Equal([]byte("34567")))
|
|
||||||
Expect(s.queue[10]).To(HaveCap(5))
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 5},
|
|
||||||
{Start: 20, End: 25},
|
|
||||||
{Start: 30, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("cuts a frame that completely covers two gaps", func() {
|
|
||||||
// 10 to 20
|
|
||||||
Expect(s.Push([]byte("1234567890"), 10, false)).To(Succeed())
|
|
||||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
|
|
||||||
Expect(s.queue[10]).To(Equal([]byte("12345")))
|
|
||||||
Expect(s.queue[10]).To(HaveCap(5))
|
|
||||||
checkGaps([]utils.ByteInterval{
|
|
||||||
{Start: 0, End: 5},
|
|
||||||
{Start: 20, End: 25},
|
|
||||||
{Start: 30, End: protocol.MaxByteCount},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("duplicate data", func() {
|
|
||||||
expectedGaps := []utils.ByteInterval{
|
|
||||||
{Start: 5, End: 10},
|
|
||||||
{Start: 15, End: protocol.MaxByteCount},
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
// create gaps: 5-10, 15-inf
|
|
||||||
Expect(s.Push([]byte("12345"), 0, false)).To(Succeed())
|
|
||||||
Expect(s.Push([]byte("12345"), 10, false)).To(Succeed())
|
|
||||||
checkGaps(expectedGaps)
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
// check that the gaps were not modified
|
|
||||||
checkGaps(expectedGaps)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not modify data when receiving a duplicate", func() {
|
|
||||||
err := s.push([]byte("fffff"), 0, false)
|
|
||||||
Expect(err).To(MatchError(errDuplicateStreamData))
|
|
||||||
Expect(s.queue[0]).ToNot(Equal([]byte("fffff")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("detects a duplicate frame that is smaller than the original, starting at the beginning", func() {
|
|
||||||
// 10 to 12
|
|
||||||
err := s.push([]byte("12"), 10, false)
|
|
||||||
Expect(err).To(MatchError(errDuplicateStreamData))
|
|
||||||
Expect(s.queue[10]).To(HaveLen(5))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("detects a duplicate frame that is smaller than the original, somewhere in the middle", func() {
|
|
||||||
// 1 to 4
|
|
||||||
err := s.push([]byte("123"), 1, false)
|
|
||||||
Expect(err).To(MatchError(errDuplicateStreamData))
|
|
||||||
Expect(s.queue[0]).To(HaveLen(5))
|
|
||||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(1)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("detects a duplicate frame that is smaller than the original, somewhere in the middle in the last block", func() {
|
|
||||||
// 11 to 14
|
|
||||||
err := s.push([]byte("123"), 11, false)
|
|
||||||
Expect(err).To(MatchError(errDuplicateStreamData))
|
|
||||||
Expect(s.queue[10]).To(HaveLen(5))
|
|
||||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(11)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("detects a duplicate frame that is smaller than the original, with aligned end in the last block", func() {
|
|
||||||
// 11 to 15
|
|
||||||
err := s.push([]byte("1234"), 1, false)
|
|
||||||
Expect(err).To(MatchError(errDuplicateStreamData))
|
|
||||||
Expect(s.queue[10]).To(HaveLen(5))
|
|
||||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(11)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("detects a duplicate frame that is smaller than the original, with aligned end", func() {
|
|
||||||
// 3 to 5
|
|
||||||
err := s.push([]byte("12"), 3, false)
|
|
||||||
Expect(err).To(MatchError(errDuplicateStreamData))
|
|
||||||
Expect(s.queue[0]).To(HaveLen(5))
|
|
||||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(3)))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("DoS protection", func() {
|
|
||||||
It("errors when too many gaps are created", func() {
|
|
||||||
for i := 0; i < protocol.MaxStreamFrameSorterGaps; i++ {
|
|
||||||
Expect(s.Push([]byte("foobar"), protocol.ByteCount(i*7), false)).To(Succeed())
|
|
||||||
}
|
|
||||||
Expect(s.gaps.Len()).To(Equal(protocol.MaxStreamFrameSorterGaps))
|
|
||||||
err := s.Push([]byte("foobar"), protocol.ByteCount(protocol.MaxStreamFrameSorterGaps*7)+100, false)
|
|
||||||
Expect(err).To(MatchError("Too many gaps in received data"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
639
vendor/lucas-clemente/quic-go/h2quic/client_test.go
vendored
639
vendor/lucas-clemente/quic-go/h2quic/client_test.go
vendored
@ -1,639 +0,0 @@
|
|||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"compress/gzip"
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
|
|
||||||
"time"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Client", func() {
|
|
||||||
var (
|
|
||||||
client *client
|
|
||||||
session *mockSession
|
|
||||||
headerStream *mockStream
|
|
||||||
req *http.Request
|
|
||||||
origDialAddr = dialAddr
|
|
||||||
)
|
|
||||||
|
|
||||||
injectResponse := func(id protocol.StreamID, rsp *http.Response) {
|
|
||||||
EventuallyWithOffset(0, func() bool {
|
|
||||||
client.mutex.Lock()
|
|
||||||
defer client.mutex.Unlock()
|
|
||||||
_, ok := client.responses[id]
|
|
||||||
return ok
|
|
||||||
}).Should(BeTrue())
|
|
||||||
rspChan := client.responses[5]
|
|
||||||
ExpectWithOffset(0, rspChan).ToNot(BeClosed())
|
|
||||||
rspChan <- rsp
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
origDialAddr = dialAddr
|
|
||||||
hostname := "quic.clemente.io:1337"
|
|
||||||
client = newClient(hostname, nil, &roundTripperOpts{}, nil, nil)
|
|
||||||
Expect(client.hostname).To(Equal(hostname))
|
|
||||||
session = newMockSession()
|
|
||||||
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
|
|
||||||
client.session = session
|
|
||||||
|
|
||||||
headerStream = newMockStream(3)
|
|
||||||
client.headerStream = headerStream
|
|
||||||
client.requestWriter = newRequestWriter(headerStream, utils.DefaultLogger)
|
|
||||||
var err error
|
|
||||||
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
dialAddr = origDialAddr
|
|
||||||
})
|
|
||||||
|
|
||||||
It("saves the TLS config", func() {
|
|
||||||
tlsConf := &tls.Config{InsecureSkipVerify: true}
|
|
||||||
client = newClient("", tlsConf, &roundTripperOpts{}, nil, nil)
|
|
||||||
Expect(client.tlsConf).To(Equal(tlsConf))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("saves the QUIC config", func() {
|
|
||||||
quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
|
|
||||||
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf, nil)
|
|
||||||
Expect(client.config).To(Equal(quicConf))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the default QUIC config if none is give", func() {
|
|
||||||
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil, nil)
|
|
||||||
Expect(client.config).ToNot(BeNil())
|
|
||||||
Expect(client.config).To(Equal(defaultQuicConfig))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adds the port to the hostname, if none is given", func() {
|
|
||||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
|
||||||
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("dials", func() {
|
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
|
||||||
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
close(headerStream.unblockRead)
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
_, err := client.RoundTrip(req)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
close(done)
|
|
||||||
// fmt.Println("done")
|
|
||||||
}()
|
|
||||||
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
|
|
||||||
// make the go routine return
|
|
||||||
injectResponse(5, &http.Response{})
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when dialing fails", func() {
|
|
||||||
testErr := errors.New("handshake error")
|
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
|
||||||
return nil, testErr
|
|
||||||
}
|
|
||||||
_, err := client.RoundTrip(req)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the custom dialer, if provided", func() {
|
|
||||||
var tlsCfg *tls.Config
|
|
||||||
var qCfg *quic.Config
|
|
||||||
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
|
||||||
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
|
|
||||||
tlsCfg = tlsCfgP
|
|
||||||
qCfg = cfg
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, dialer)
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
_, err := client.RoundTrip(req)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
|
|
||||||
Expect(qCfg).To(Equal(client.config))
|
|
||||||
Expect(tlsCfg).To(Equal(client.tlsConf))
|
|
||||||
// make the go routine return
|
|
||||||
injectResponse(5, &http.Response{})
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if it can't open a stream", func() {
|
|
||||||
testErr := errors.New("you shall not pass")
|
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
|
||||||
session.streamOpenErr = testErr
|
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
_, err := client.RoundTrip(req)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns a request when dial fails", func() {
|
|
||||||
testErr := errors.New("dial error")
|
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
|
||||||
return nil, testErr
|
|
||||||
}
|
|
||||||
request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
_, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
_, err = client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Doing requests", func() {
|
|
||||||
var request *http.Request
|
|
||||||
var dataStream *mockStream
|
|
||||||
|
|
||||||
getRequest := func(data []byte) *http2.MetaHeadersFrame {
|
|
||||||
r := bytes.NewReader(data)
|
|
||||||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
|
||||||
h2framer := http2.NewFramer(nil, r)
|
|
||||||
frame, err := h2framer.ReadFrame()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: frame.(*http2.HeadersFrame)}
|
|
||||||
mhframe.Fields, err = decoder.DecodeFull(mhframe.HeadersFrame.HeaderBlockFragment())
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
return mhframe
|
|
||||||
}
|
|
||||||
|
|
||||||
getHeaderFields := func(f *http2.MetaHeadersFrame) map[string]string {
|
|
||||||
fields := make(map[string]string)
|
|
||||||
for _, hf := range f.Fields {
|
|
||||||
fields[hf.Name] = hf.Value
|
|
||||||
}
|
|
||||||
return fields
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
var err error
|
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
dataStream = newMockStream(5)
|
|
||||||
session.streamsToOpen = []quic.Stream{headerStream, dataStream}
|
|
||||||
request, err = http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does a request", func() {
|
|
||||||
teapot := &http.Response{
|
|
||||||
Status: "418 I'm a teapot",
|
|
||||||
StatusCode: 418,
|
|
||||||
}
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(rsp).To(Equal(teapot))
|
|
||||||
Expect(rsp.Body).To(Equal(dataStream))
|
|
||||||
Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
|
|
||||||
Expect(rsp.Request).To(Equal(request))
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
|
|
||||||
injectResponse(5, teapot)
|
|
||||||
Expect(client.headerErrored).ToNot(BeClosed())
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if a request without a body is canceled", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
request = request.WithContext(ctx)
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(context.Canceled))
|
|
||||||
Expect(rsp).To(BeNil())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
cancel()
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(dataStream.reset).To(BeTrue())
|
|
||||||
Expect(dataStream.canceledWrite).To(BeTrue())
|
|
||||||
Expect(client.headerErrored).ToNot(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if a request with a body is canceled after the body is sent", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
request = request.WithContext(ctx)
|
|
||||||
request.Body = &mockBody{}
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(context.Canceled))
|
|
||||||
Expect(rsp).To(BeNil())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
cancel()
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(dataStream.reset).To(BeTrue())
|
|
||||||
Expect(dataStream.canceledWrite).To(BeTrue())
|
|
||||||
Expect(client.headerErrored).ToNot(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if a request with a body is canceled before the body is sent", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
request = request.WithContext(ctx)
|
|
||||||
request.Body = &mockBody{}
|
|
||||||
cancel()
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(context.Canceled))
|
|
||||||
Expect(rsp).To(BeNil())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(dataStream.reset).To(BeTrue())
|
|
||||||
Expect(dataStream.canceledWrite).To(BeTrue())
|
|
||||||
Expect(client.headerErrored).ToNot(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes the quic client when encountering an error on the header stream", func() {
|
|
||||||
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(client.headerErr))
|
|
||||||
Expect(rsp).To(BeNil())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
|
|
||||||
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns subsequent request if there was an error on the header stream before", func() {
|
|
||||||
session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)}
|
|
||||||
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
|
||||||
_, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
|
|
||||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
|
|
||||||
// now that the first request failed due to an error on the header stream, try another request
|
|
||||||
_, nextErr := client.RoundTrip(request)
|
|
||||||
Expect(nextErr).To(MatchError(err))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("blocks if no stream is available", func() {
|
|
||||||
session.streamsToOpen = []quic.Stream{headerStream, dataStream}
|
|
||||||
session.blockOpenStreamSync = true
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
_, err := client.RoundTrip(request)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
|
||||||
// make the go routine return
|
|
||||||
client.Close()
|
|
||||||
injectResponse(5, &http.Response{})
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("validating the address", func() {
|
|
||||||
It("refuses to do requests for the wrong host", func() {
|
|
||||||
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = client.RoundTrip(req)
|
|
||||||
Expect(err).To(MatchError("h2quic Client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("refuses to do plain HTTP requests", func() {
|
|
||||||
req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = client.RoundTrip(req)
|
|
||||||
Expect(err).To(MatchError("quic http2: unsupported scheme"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adds the port for request URLs without one", func() {
|
|
||||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
|
||||||
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
// the client.RoundTrip will block, because the encryption level is still set to Unencrypted
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
_, err := client.RoundTrip(req)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
|
||||||
// make the go routine return
|
|
||||||
injectResponse(5, &http.Response{})
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets the EndStream header for requests without a body", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
client.RoundTrip(request)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil())
|
|
||||||
mhf := getRequest(headerStream.dataWritten.Bytes())
|
|
||||||
Expect(mhf.HeadersFrame.StreamEnded()).To(BeTrue())
|
|
||||||
// make the go routine return
|
|
||||||
injectResponse(5, &http.Response{})
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets the EndStream header to false for requests with a body", func() {
|
|
||||||
request.Body = &mockBody{}
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
client.RoundTrip(request)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil())
|
|
||||||
mhf := getRequest(headerStream.dataWritten.Bytes())
|
|
||||||
Expect(mhf.HeadersFrame.StreamEnded()).To(BeFalse())
|
|
||||||
// make the go routine return
|
|
||||||
injectResponse(5, &http.Response{})
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("requests containing a Body", func() {
|
|
||||||
var requestBody []byte
|
|
||||||
var response *http.Response
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
requestBody = []byte("request body")
|
|
||||||
body := &mockBody{}
|
|
||||||
body.SetData(requestBody)
|
|
||||||
request.Body = body
|
|
||||||
response = &http.Response{
|
|
||||||
StatusCode: 200,
|
|
||||||
Header: http.Header{"Content-Length": []string{"1000"}},
|
|
||||||
}
|
|
||||||
// fake a handshake
|
|
||||||
client.dialOnce.Do(func() {})
|
|
||||||
session.streamsToOpen = []quic.Stream{dataStream}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sends a request", func() {
|
|
||||||
rspChan := make(chan *http.Response)
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
rspChan <- rsp
|
|
||||||
}()
|
|
||||||
injectResponse(5, response)
|
|
||||||
Eventually(rspChan).Should(Receive(Equal(response)))
|
|
||||||
Expect(dataStream.dataWritten.Bytes()).To(Equal(requestBody))
|
|
||||||
Expect(dataStream.closed).To(BeTrue())
|
|
||||||
Expect(request.Body.(*mockBody).closed).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns the error that occurred when reading the body", func() {
|
|
||||||
testErr := errors.New("testErr")
|
|
||||||
request.Body.(*mockBody).readErr = testErr
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
Expect(rsp).To(BeNil())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(request.Body.(*mockBody).closed).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns the error that occurred when closing the body", func() {
|
|
||||||
testErr := errors.New("testErr")
|
|
||||||
request.Body.(*mockBody).closeErr = testErr
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
Expect(rsp).To(BeNil())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(request.Body.(*mockBody).closed).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("gzip compression", func() {
|
|
||||||
var gzippedData []byte // a gzipped foobar
|
|
||||||
var response *http.Response
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
var b bytes.Buffer
|
|
||||||
w := gzip.NewWriter(&b)
|
|
||||||
w.Write([]byte("foobar"))
|
|
||||||
w.Close()
|
|
||||||
gzippedData = b.Bytes()
|
|
||||||
response = &http.Response{
|
|
||||||
StatusCode: 200,
|
|
||||||
Header: http.Header{"Content-Length": []string{"1000"}},
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adds the gzip header to requests", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(rsp).ToNot(BeNil())
|
|
||||||
Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
|
|
||||||
Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
|
|
||||||
Expect(rsp.Header.Get("Content-Length")).To(BeEmpty())
|
|
||||||
data := make([]byte, 6)
|
|
||||||
_, err = io.ReadFull(rsp.Body, data)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(data).To(Equal([]byte("foobar")))
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
dataStream.dataToRead.Write(gzippedData)
|
|
||||||
response.Header.Add("Content-Encoding", "gzip")
|
|
||||||
injectResponse(5, response)
|
|
||||||
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
|
|
||||||
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
|
||||||
close(dataStream.unblockRead)
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't add gzip if the header disable it", func() {
|
|
||||||
client.opts.DisableCompression = true
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
_, err := client.RoundTrip(request)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
|
|
||||||
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
|
|
||||||
Expect(headers).ToNot(HaveKey("accept-encoding"))
|
|
||||||
// make the go routine return
|
|
||||||
injectResponse(5, &http.Response{})
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("only decompresses the response if the response contains the right content-encoding header", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(rsp).ToNot(BeNil())
|
|
||||||
data := make([]byte, 11)
|
|
||||||
rsp.Body.Read(data)
|
|
||||||
Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1))
|
|
||||||
Expect(data).To(Equal([]byte("not gzipped")))
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
dataStream.dataToRead.Write([]byte("not gzipped"))
|
|
||||||
injectResponse(5, response)
|
|
||||||
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
|
|
||||||
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't add the gzip header for requests that have the accept-enconding set", func() {
|
|
||||||
request.Header.Add("accept-encoding", "gzip")
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
data := make([]byte, 12)
|
|
||||||
_, err = rsp.Body.Read(data)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1))
|
|
||||||
Expect(data).To(Equal([]byte("gzipped data")))
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
dataStream.dataToRead.Write([]byte("gzipped data"))
|
|
||||||
injectResponse(5, response)
|
|
||||||
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
|
|
||||||
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("handling the header stream", func() {
|
|
||||||
var h2framer *http2.Framer
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
h2framer = http2.NewFramer(&headerStream.dataToRead, nil)
|
|
||||||
client.responses[23] = make(chan *http.Response)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads header values from a response", func() {
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
data := []byte{0x48, 0x03, 0x33, 0x30, 0x32, 0x58, 0x07, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x61, 0x1d, 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x31, 0x20, 0x4f, 0x63, 0x74, 0x20, 0x32, 0x30, 0x31, 0x33, 0x20, 0x32, 0x30, 0x3a, 0x31, 0x33, 0x3a, 0x32, 0x31, 0x20, 0x47, 0x4d, 0x54, 0x6e, 0x17, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d}
|
|
||||||
headerStream.dataToRead.Write([]byte{0x0, 0x0, byte(len(data)), 0x1, 0x5, 0x0, 0x0, 0x0, 23})
|
|
||||||
headerStream.dataToRead.Write(data)
|
|
||||||
go client.handleHeaderStream()
|
|
||||||
var rsp *http.Response
|
|
||||||
Eventually(client.responses[23]).Should(Receive(&rsp))
|
|
||||||
Expect(rsp).ToNot(BeNil())
|
|
||||||
Expect(rsp.Proto).To(Equal("HTTP/2.0"))
|
|
||||||
Expect(rsp.ProtoMajor).To(BeEquivalentTo(2))
|
|
||||||
Expect(rsp.StatusCode).To(BeEquivalentTo(302))
|
|
||||||
Expect(rsp.Status).To(Equal("302 Found"))
|
|
||||||
Expect(rsp.Header).To(HaveKeyWithValue("Location", []string{"https://www.example.com"}))
|
|
||||||
Expect(rsp.Header).To(HaveKeyWithValue("Cache-Control", []string{"private"}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the H2 frame is not a HeadersFrame", func() {
|
|
||||||
h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0})
|
|
||||||
client.handleHeaderStream()
|
|
||||||
Eventually(client.headerErrored).Should(BeClosed())
|
|
||||||
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if it can't read the HPACK encoded header fields", func() {
|
|
||||||
h2framer.WriteHeaders(http2.HeadersFrameParam{
|
|
||||||
StreamID: 23,
|
|
||||||
EndHeaders: true,
|
|
||||||
BlockFragment: []byte("invalid HPACK data"),
|
|
||||||
})
|
|
||||||
client.handleHeaderStream()
|
|
||||||
Eventually(client.headerErrored).Should(BeClosed())
|
|
||||||
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
|
|
||||||
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("cannot read header fields"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the stream cannot be found", func() {
|
|
||||||
var headers bytes.Buffer
|
|
||||||
enc := hpack.NewEncoder(&headers)
|
|
||||||
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
|
|
||||||
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
|
|
||||||
StreamID: 1337,
|
|
||||||
EndHeaders: true,
|
|
||||||
BlockFragment: headers.Bytes(),
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
client.handleHeaderStream()
|
|
||||||
Eventually(client.headerErrored).Should(BeClosed())
|
|
||||||
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
|
|
||||||
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("response channel for stream 1337 not found"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,13 +0,0 @@
|
|||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestH2quic(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "H2quic Suite")
|
|
||||||
}
|
|
@ -1,39 +0,0 @@
|
|||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Request body", func() {
|
|
||||||
var (
|
|
||||||
stream *mockStream
|
|
||||||
rb *requestBody
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
stream = &mockStream{}
|
|
||||||
stream.dataToRead.Write([]byte("foobar")) // provides data to be read
|
|
||||||
rb = newRequestBody(stream)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads from the stream", func() {
|
|
||||||
b := make([]byte, 10)
|
|
||||||
n, _ := stream.Read(b)
|
|
||||||
Expect(n).To(Equal(6))
|
|
||||||
Expect(b[0:6]).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("saves if the stream was read from", func() {
|
|
||||||
Expect(rb.requestRead).To(BeFalse())
|
|
||||||
rb.Read(make([]byte, 1))
|
|
||||||
Expect(rb.requestRead).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't close the stream when closing the request body", func() {
|
|
||||||
Expect(stream.closed).To(BeFalse())
|
|
||||||
err := rb.Close()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(stream.closed).To(BeFalse())
|
|
||||||
})
|
|
||||||
})
|
|
121
vendor/lucas-clemente/quic-go/h2quic/request_test.go
vendored
121
vendor/lucas-clemente/quic-go/h2quic/request_test.go
vendored
@ -1,121 +0,0 @@
|
|||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Request", func() {
|
|
||||||
It("populates request", func() {
|
|
||||||
headers := []hpack.HeaderField{
|
|
||||||
{Name: ":path", Value: "/foo"},
|
|
||||||
{Name: ":authority", Value: "quic.clemente.io"},
|
|
||||||
{Name: ":method", Value: "GET"},
|
|
||||||
{Name: "content-length", Value: "42"},
|
|
||||||
}
|
|
||||||
req, err := requestFromHeaders(headers)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(req.Method).To(Equal("GET"))
|
|
||||||
Expect(req.URL.Path).To(Equal("/foo"))
|
|
||||||
Expect(req.Proto).To(Equal("HTTP/2.0"))
|
|
||||||
Expect(req.ProtoMajor).To(Equal(2))
|
|
||||||
Expect(req.ProtoMinor).To(Equal(0))
|
|
||||||
Expect(req.ContentLength).To(Equal(int64(42)))
|
|
||||||
Expect(req.Header).To(BeEmpty())
|
|
||||||
Expect(req.Body).To(BeNil())
|
|
||||||
Expect(req.Host).To(Equal("quic.clemente.io"))
|
|
||||||
Expect(req.RequestURI).To(Equal("/foo"))
|
|
||||||
Expect(req.TLS).ToNot(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("concatenates the cookie headers", func() {
|
|
||||||
headers := []hpack.HeaderField{
|
|
||||||
{Name: ":path", Value: "/foo"},
|
|
||||||
{Name: ":authority", Value: "quic.clemente.io"},
|
|
||||||
{Name: ":method", Value: "GET"},
|
|
||||||
{Name: "cookie", Value: "cookie1=foobar1"},
|
|
||||||
{Name: "cookie", Value: "cookie2=foobar2"},
|
|
||||||
}
|
|
||||||
req, err := requestFromHeaders(headers)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(req.Header).To(Equal(http.Header{
|
|
||||||
"Cookie": []string{"cookie1=foobar1; cookie2=foobar2"},
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handles other headers", func() {
|
|
||||||
headers := []hpack.HeaderField{
|
|
||||||
{Name: ":path", Value: "/foo"},
|
|
||||||
{Name: ":authority", Value: "quic.clemente.io"},
|
|
||||||
{Name: ":method", Value: "GET"},
|
|
||||||
{Name: "cache-control", Value: "max-age=0"},
|
|
||||||
{Name: "duplicate-header", Value: "1"},
|
|
||||||
{Name: "duplicate-header", Value: "2"},
|
|
||||||
}
|
|
||||||
req, err := requestFromHeaders(headers)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(req.Header).To(Equal(http.Header{
|
|
||||||
"Cache-Control": []string{"max-age=0"},
|
|
||||||
"Duplicate-Header": []string{"1", "2"},
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors with missing path", func() {
|
|
||||||
headers := []hpack.HeaderField{
|
|
||||||
{Name: ":authority", Value: "quic.clemente.io"},
|
|
||||||
{Name: ":method", Value: "GET"},
|
|
||||||
}
|
|
||||||
_, err := requestFromHeaders(headers)
|
|
||||||
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors with missing method", func() {
|
|
||||||
headers := []hpack.HeaderField{
|
|
||||||
{Name: ":path", Value: "/foo"},
|
|
||||||
{Name: ":authority", Value: "quic.clemente.io"},
|
|
||||||
}
|
|
||||||
_, err := requestFromHeaders(headers)
|
|
||||||
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors with missing authority", func() {
|
|
||||||
headers := []hpack.HeaderField{
|
|
||||||
{Name: ":path", Value: "/foo"},
|
|
||||||
{Name: ":method", Value: "GET"},
|
|
||||||
}
|
|
||||||
_, err := requestFromHeaders(headers)
|
|
||||||
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("extracting the hostname from a request", func() {
|
|
||||||
var url *url.URL
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
var err error
|
|
||||||
url, err = url.Parse("https://quic.clemente.io:1337")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses req.URL.Host", func() {
|
|
||||||
req := &http.Request{URL: url}
|
|
||||||
Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses req.URL.Host even if req.Host is available", func() {
|
|
||||||
req := &http.Request{
|
|
||||||
Host: "www.example.org",
|
|
||||||
URL: url,
|
|
||||||
}
|
|
||||||
Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns an empty hostname if nothing is set", func() {
|
|
||||||
Expect(hostnameFromRequest(&http.Request{})).To(BeEmpty())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,118 +0,0 @@
|
|||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Request", func() {
|
|
||||||
var (
|
|
||||||
rw *requestWriter
|
|
||||||
headerStream *mockStream
|
|
||||||
decoder *hpack.Decoder
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
headerStream = &mockStream{}
|
|
||||||
rw = newRequestWriter(headerStream, utils.DefaultLogger)
|
|
||||||
decoder = hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
|
||||||
})
|
|
||||||
|
|
||||||
decode := func(p []byte) (*http2.HeadersFrame, map[string] /* HeaderField.Name */ string /* HeaderField.Value */) {
|
|
||||||
framer := http2.NewFramer(nil, bytes.NewReader(p))
|
|
||||||
frame, err := framer.ReadFrame()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
headerFrame := frame.(*http2.HeadersFrame)
|
|
||||||
fields, err := decoder.DecodeFull(headerFrame.HeaderBlockFragment())
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
values := make(map[string]string)
|
|
||||||
for _, headerField := range fields {
|
|
||||||
values[headerField.Name] = headerField.Value
|
|
||||||
}
|
|
||||||
return headerFrame, values
|
|
||||||
}
|
|
||||||
|
|
||||||
It("writes a GET request", func() {
|
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
rw.WriteRequest(req, 1337, true, false)
|
|
||||||
headerFrame, headerFields := decode(headerStream.dataWritten.Bytes())
|
|
||||||
Expect(headerFrame.StreamID).To(Equal(uint32(1337)))
|
|
||||||
Expect(headerFrame.HasPriority()).To(BeTrue())
|
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
|
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":path", "/index.html?foo=bar"))
|
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":scheme", "https"))
|
|
||||||
Expect(headerFields).ToNot(HaveKey("accept-encoding"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets the EndStream header", func() {
|
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
rw.WriteRequest(req, 1337, true, false)
|
|
||||||
headerFrame, _ := decode(headerStream.dataWritten.Bytes())
|
|
||||||
Expect(headerFrame.StreamEnded()).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't set the EndStream header, if requested", func() {
|
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
rw.WriteRequest(req, 1337, false, false)
|
|
||||||
headerFrame, _ := decode(headerStream.dataWritten.Bytes())
|
|
||||||
Expect(headerFrame.StreamEnded()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("requests gzip compression, if requested", func() {
|
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
rw.WriteRequest(req, 1337, true, true)
|
|
||||||
_, headerFields := decode(headerStream.dataWritten.Bytes())
|
|
||||||
Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("writes a POST request", func() {
|
|
||||||
form := url.Values{}
|
|
||||||
form.Add("foo", "bar")
|
|
||||||
req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", strings.NewReader(form.Encode()))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
rw.WriteRequest(req, 5, true, false)
|
|
||||||
_, headerFields := decode(headerStream.dataWritten.Bytes())
|
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
|
|
||||||
Expect(headerFields).To(HaveKey("content-length"))
|
|
||||||
contentLength, err := strconv.Atoi(headerFields["content-length"])
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(contentLength).To(BeNumerically(">", 0))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sends cookies", func() {
|
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cookie1 := &http.Cookie{
|
|
||||||
Name: "Cookie #1",
|
|
||||||
Value: "Value #1",
|
|
||||||
}
|
|
||||||
cookie2 := &http.Cookie{
|
|
||||||
Name: "Cookie #2",
|
|
||||||
Value: "Value #2",
|
|
||||||
}
|
|
||||||
req.AddCookie(cookie1)
|
|
||||||
req.AddCookie(cookie2)
|
|
||||||
rw.WriteRequest(req, 11, true, false)
|
|
||||||
_, headerFields := decode(headerStream.dataWritten.Bytes())
|
|
||||||
// TODO(lclemente): Remove Or() once we drop support for Go 1.8.
|
|
||||||
Expect(headerFields).To(Or(
|
|
||||||
HaveKeyWithValue("cookie", "Cookie #1=Value #1; Cookie #2=Value #2"),
|
|
||||||
HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`),
|
|
||||||
))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,163 +0,0 @@
|
|||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockStream struct {
|
|
||||||
id protocol.StreamID
|
|
||||||
dataToRead bytes.Buffer
|
|
||||||
dataWritten bytes.Buffer
|
|
||||||
reset bool
|
|
||||||
canceledWrite bool
|
|
||||||
closed bool
|
|
||||||
remoteClosed bool
|
|
||||||
|
|
||||||
unblockRead chan struct{}
|
|
||||||
ctx context.Context
|
|
||||||
ctxCancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ quic.Stream = &mockStream{}
|
|
||||||
|
|
||||||
func newMockStream(id protocol.StreamID) *mockStream {
|
|
||||||
s := &mockStream{
|
|
||||||
id: id,
|
|
||||||
unblockRead: make(chan struct{}),
|
|
||||||
}
|
|
||||||
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil }
|
|
||||||
func (s *mockStream) CancelRead(quic.ErrorCode) error { s.reset = true; return nil }
|
|
||||||
func (s *mockStream) CancelWrite(quic.ErrorCode) error { s.canceledWrite = true; return nil }
|
|
||||||
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() }
|
|
||||||
func (s mockStream) StreamID() protocol.StreamID { return s.id }
|
|
||||||
func (s *mockStream) Context() context.Context { return s.ctx }
|
|
||||||
func (s *mockStream) SetDeadline(time.Time) error { panic("not implemented") }
|
|
||||||
func (s *mockStream) SetReadDeadline(time.Time) error { panic("not implemented") }
|
|
||||||
func (s *mockStream) SetWriteDeadline(time.Time) error { panic("not implemented") }
|
|
||||||
|
|
||||||
func (s *mockStream) Read(p []byte) (int, error) {
|
|
||||||
n, _ := s.dataToRead.Read(p)
|
|
||||||
if n == 0 { // block if there's no data
|
|
||||||
<-s.unblockRead
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
return n, nil // never return an EOF
|
|
||||||
}
|
|
||||||
func (s *mockStream) Write(p []byte) (int, error) { return s.dataWritten.Write(p) }
|
|
||||||
|
|
||||||
var _ = Describe("Response Writer", func() {
|
|
||||||
var (
|
|
||||||
w *responseWriter
|
|
||||||
headerStream *mockStream
|
|
||||||
dataStream *mockStream
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
headerStream = &mockStream{}
|
|
||||||
dataStream = &mockStream{}
|
|
||||||
w = newResponseWriter(headerStream, &sync.Mutex{}, dataStream, 5, utils.DefaultLogger)
|
|
||||||
})
|
|
||||||
|
|
||||||
decodeHeaderFields := func() map[string][]string {
|
|
||||||
fields := make(map[string][]string)
|
|
||||||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
|
||||||
h2framer := http2.NewFramer(nil, bytes.NewReader(headerStream.dataWritten.Bytes()))
|
|
||||||
|
|
||||||
frame, err := h2framer.ReadFrame()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(frame).To(BeAssignableToTypeOf(&http2.HeadersFrame{}))
|
|
||||||
hframe := frame.(*http2.HeadersFrame)
|
|
||||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
|
||||||
Expect(mhframe.StreamID).To(BeEquivalentTo(5))
|
|
||||||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
for _, p := range mhframe.Fields {
|
|
||||||
fields[p.Name] = append(fields[p.Name], p.Value)
|
|
||||||
}
|
|
||||||
return fields
|
|
||||||
}
|
|
||||||
|
|
||||||
It("writes status", func() {
|
|
||||||
w.WriteHeader(http.StatusTeapot)
|
|
||||||
fields := decodeHeaderFields()
|
|
||||||
Expect(fields).To(HaveLen(1))
|
|
||||||
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("writes headers", func() {
|
|
||||||
w.Header().Add("content-length", "42")
|
|
||||||
w.WriteHeader(http.StatusTeapot)
|
|
||||||
fields := decodeHeaderFields()
|
|
||||||
Expect(fields).To(HaveKeyWithValue("content-length", []string{"42"}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("writes multiple headers with the same name", func() {
|
|
||||||
const cookie1 = "test1=1; Max-Age=7200; path=/"
|
|
||||||
const cookie2 = "test2=2; Max-Age=7200; path=/"
|
|
||||||
w.Header().Add("set-cookie", cookie1)
|
|
||||||
w.Header().Add("set-cookie", cookie2)
|
|
||||||
w.WriteHeader(http.StatusTeapot)
|
|
||||||
fields := decodeHeaderFields()
|
|
||||||
Expect(fields).To(HaveKey("set-cookie"))
|
|
||||||
cookies := fields["set-cookie"]
|
|
||||||
Expect(cookies).To(ContainElement(cookie1))
|
|
||||||
Expect(cookies).To(ContainElement(cookie2))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("writes data", func() {
|
|
||||||
n, err := w.Write([]byte("foobar"))
|
|
||||||
Expect(n).To(Equal(6))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
// Should have written 200 on the header stream
|
|
||||||
fields := decodeHeaderFields()
|
|
||||||
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
|
|
||||||
// And foobar on the data stream
|
|
||||||
Expect(dataStream.dataWritten.Bytes()).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("writes data after WriteHeader is called", func() {
|
|
||||||
w.WriteHeader(http.StatusTeapot)
|
|
||||||
n, err := w.Write([]byte("foobar"))
|
|
||||||
Expect(n).To(Equal(6))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
// Should have written 418 on the header stream
|
|
||||||
fields := decodeHeaderFields()
|
|
||||||
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
|
|
||||||
// And foobar on the data stream
|
|
||||||
Expect(dataStream.dataWritten.Bytes()).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not WriteHeader() twice", func() {
|
|
||||||
w.WriteHeader(200)
|
|
||||||
w.WriteHeader(500)
|
|
||||||
fields := decodeHeaderFields()
|
|
||||||
Expect(fields).To(HaveLen(1))
|
|
||||||
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't allow writes if the status code doesn't allow a body", func() {
|
|
||||||
w.WriteHeader(304)
|
|
||||||
n, err := w.Write([]byte("foobar"))
|
|
||||||
Expect(n).To(BeZero())
|
|
||||||
Expect(err).To(MatchError(http.ErrBodyNotAllowed))
|
|
||||||
Expect(dataStream.dataWritten.Bytes()).To(HaveLen(0))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,218 +0,0 @@
|
|||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockClient struct {
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockClient) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
return &http.Response{Request: req}, nil
|
|
||||||
}
|
|
||||||
func (m *mockClient) Close() error {
|
|
||||||
m.closed = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ roundTripCloser = &mockClient{}
|
|
||||||
|
|
||||||
type mockBody struct {
|
|
||||||
reader bytes.Reader
|
|
||||||
readErr error
|
|
||||||
closeErr error
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockBody) Read(p []byte) (int, error) {
|
|
||||||
if m.readErr != nil {
|
|
||||||
return 0, m.readErr
|
|
||||||
}
|
|
||||||
return m.reader.Read(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockBody) SetData(data []byte) {
|
|
||||||
m.reader = *bytes.NewReader(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockBody) Close() error {
|
|
||||||
m.closed = true
|
|
||||||
return m.closeErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure the mockBody can be used as a http.Request.Body
|
|
||||||
var _ io.ReadCloser = &mockBody{}
|
|
||||||
|
|
||||||
var _ = Describe("RoundTripper", func() {
|
|
||||||
var (
|
|
||||||
rt *RoundTripper
|
|
||||||
req1 *http.Request
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
rt = &RoundTripper{}
|
|
||||||
var err error
|
|
||||||
req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("dialing hosts", func() {
|
|
||||||
origDialAddr := dialAddr
|
|
||||||
streamOpenErr := errors.New("error opening stream")
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
origDialAddr = dialAddr
|
|
||||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
|
||||||
// return an error when trying to open a stream
|
|
||||||
// we don't want to test all the dial logic here, just that dialing happens at all
|
|
||||||
return &mockSession{streamOpenErr: streamOpenErr}, nil
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
dialAddr = origDialAddr
|
|
||||||
})
|
|
||||||
|
|
||||||
It("creates new clients", func() {
|
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = rt.RoundTrip(req)
|
|
||||||
Expect(err).To(MatchError(streamOpenErr))
|
|
||||||
Expect(rt.clients).To(HaveLen(1))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the quic.Config, if provided", func() {
|
|
||||||
config := &quic.Config{HandshakeTimeout: time.Millisecond}
|
|
||||||
var receivedConfig *quic.Config
|
|
||||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
|
||||||
receivedConfig = config
|
|
||||||
return nil, errors.New("err")
|
|
||||||
}
|
|
||||||
rt.QuicConfig = config
|
|
||||||
rt.RoundTrip(req1)
|
|
||||||
Expect(receivedConfig).To(Equal(config))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the custom dialer, if provided", func() {
|
|
||||||
var dialed bool
|
|
||||||
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
|
|
||||||
dialed = true
|
|
||||||
return nil, errors.New("err")
|
|
||||||
}
|
|
||||||
rt.Dial = dialer
|
|
||||||
rt.RoundTrip(req1)
|
|
||||||
Expect(dialed).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reuses existing clients", func() {
|
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = rt.RoundTrip(req)
|
|
||||||
Expect(err).To(MatchError(streamOpenErr))
|
|
||||||
Expect(rt.clients).To(HaveLen(1))
|
|
||||||
req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = rt.RoundTrip(req2)
|
|
||||||
Expect(err).To(MatchError(streamOpenErr))
|
|
||||||
Expect(rt.clients).To(HaveLen(1))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
|
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true})
|
|
||||||
Expect(err).To(MatchError(ErrNoCachedConn))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("validating request", func() {
|
|
||||||
It("rejects plain HTTP requests", func() {
|
|
||||||
req, err := http.NewRequest("GET", "http://www.example.org/", nil)
|
|
||||||
req.Body = &mockBody{}
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = rt.RoundTrip(req)
|
|
||||||
Expect(err).To(MatchError("quic: unsupported protocol scheme: http"))
|
|
||||||
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects requests without a URL", func() {
|
|
||||||
req1.URL = nil
|
|
||||||
req1.Body = &mockBody{}
|
|
||||||
_, err := rt.RoundTrip(req1)
|
|
||||||
Expect(err).To(MatchError("quic: nil Request.URL"))
|
|
||||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects request without a URL Host", func() {
|
|
||||||
req1.URL.Host = ""
|
|
||||||
req1.Body = &mockBody{}
|
|
||||||
_, err := rt.RoundTrip(req1)
|
|
||||||
Expect(err).To(MatchError("quic: no Host in request URL"))
|
|
||||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't try to close the body if the request doesn't have one", func() {
|
|
||||||
req1.URL = nil
|
|
||||||
Expect(req1.Body).To(BeNil())
|
|
||||||
_, err := rt.RoundTrip(req1)
|
|
||||||
Expect(err).To(MatchError("quic: nil Request.URL"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects requests without a header", func() {
|
|
||||||
req1.Header = nil
|
|
||||||
req1.Body = &mockBody{}
|
|
||||||
_, err := rt.RoundTrip(req1)
|
|
||||||
Expect(err).To(MatchError("quic: nil Request.Header"))
|
|
||||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects requests with invalid header name fields", func() {
|
|
||||||
req1.Header.Add("foobär", "value")
|
|
||||||
_, err := rt.RoundTrip(req1)
|
|
||||||
Expect(err).To(MatchError("quic: invalid http header field name \"foobär\""))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects requests with invalid header name values", func() {
|
|
||||||
req1.Header.Add("foo", string([]byte{0x7}))
|
|
||||||
_, err := rt.RoundTrip(req1)
|
|
||||||
Expect(err.Error()).To(ContainSubstring("quic: invalid http header field value"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects requests with an invalid request method", func() {
|
|
||||||
req1.Method = "foobär"
|
|
||||||
req1.Body = &mockBody{}
|
|
||||||
_, err := rt.RoundTrip(req1)
|
|
||||||
Expect(err).To(MatchError("quic: invalid method \"foobär\""))
|
|
||||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("closing", func() {
|
|
||||||
It("closes", func() {
|
|
||||||
rt.clients = make(map[string]roundTripCloser)
|
|
||||||
cl := &mockClient{}
|
|
||||||
rt.clients["foo.bar"] = cl
|
|
||||||
err := rt.Close()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(len(rt.clients)).To(BeZero())
|
|
||||||
Expect(cl.closed).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes a RoundTripper that has never been used", func() {
|
|
||||||
Expect(len(rt.clients)).To(BeZero())
|
|
||||||
err := rt.Close()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(len(rt.clients)).To(BeZero())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
536
vendor/lucas-clemente/quic-go/h2quic/server_test.go
vendored
536
vendor/lucas-clemente/quic-go/h2quic/server_test.go
vendored
@ -1,536 +0,0 @@
|
|||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockSession struct {
|
|
||||||
closed bool
|
|
||||||
closedWithError error
|
|
||||||
dataStream quic.Stream
|
|
||||||
streamToAccept quic.Stream
|
|
||||||
streamsToOpen []quic.Stream
|
|
||||||
blockOpenStreamSync bool
|
|
||||||
blockOpenStreamChan chan struct{} // close this chan (or call Close) to make OpenStreamSync return
|
|
||||||
streamOpenErr error
|
|
||||||
ctx context.Context
|
|
||||||
ctxCancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMockSession() *mockSession {
|
|
||||||
return &mockSession{blockOpenStreamChan: make(chan struct{})}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (quic.Stream, error) {
|
|
||||||
return s.dataStream, nil
|
|
||||||
}
|
|
||||||
func (s *mockSession) AcceptStream() (quic.Stream, error) { return s.streamToAccept, nil }
|
|
||||||
func (s *mockSession) OpenStream() (quic.Stream, error) {
|
|
||||||
if s.streamOpenErr != nil {
|
|
||||||
return nil, s.streamOpenErr
|
|
||||||
}
|
|
||||||
str := s.streamsToOpen[0]
|
|
||||||
s.streamsToOpen = s.streamsToOpen[1:]
|
|
||||||
return str, nil
|
|
||||||
}
|
|
||||||
func (s *mockSession) OpenStreamSync() (quic.Stream, error) {
|
|
||||||
if s.blockOpenStreamSync {
|
|
||||||
<-s.blockOpenStreamChan
|
|
||||||
}
|
|
||||||
return s.OpenStream()
|
|
||||||
}
|
|
||||||
func (s *mockSession) Close() error {
|
|
||||||
s.ctxCancel()
|
|
||||||
if !s.closed {
|
|
||||||
close(s.blockOpenStreamChan)
|
|
||||||
}
|
|
||||||
s.closed = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (s *mockSession) CloseWithError(_ quic.ErrorCode, e error) error {
|
|
||||||
s.closedWithError = e
|
|
||||||
return s.Close()
|
|
||||||
}
|
|
||||||
func (s *mockSession) LocalAddr() net.Addr {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
func (s *mockSession) RemoteAddr() net.Addr {
|
|
||||||
return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42}
|
|
||||||
}
|
|
||||||
func (s *mockSession) Context() context.Context {
|
|
||||||
return s.ctx
|
|
||||||
}
|
|
||||||
func (s *mockSession) ConnectionState() quic.ConnectionState { panic("not implemented") }
|
|
||||||
func (s *mockSession) AcceptUniStream() (quic.ReceiveStream, error) { panic("not implemented") }
|
|
||||||
func (s *mockSession) OpenUniStream() (quic.SendStream, error) { panic("not implemented") }
|
|
||||||
func (s *mockSession) OpenUniStreamSync() (quic.SendStream, error) { panic("not implemented") }
|
|
||||||
|
|
||||||
var _ = Describe("H2 server", func() {
|
|
||||||
var (
|
|
||||||
s *Server
|
|
||||||
session *mockSession
|
|
||||||
dataStream *mockStream
|
|
||||||
origQuicListenAddr = quicListenAddr
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
s = &Server{
|
|
||||||
Server: &http.Server{
|
|
||||||
TLSConfig: testdata.GetTLSConfig(),
|
|
||||||
},
|
|
||||||
logger: utils.DefaultLogger,
|
|
||||||
}
|
|
||||||
dataStream = newMockStream(0)
|
|
||||||
close(dataStream.unblockRead)
|
|
||||||
session = newMockSession()
|
|
||||||
session.dataStream = dataStream
|
|
||||||
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
|
|
||||||
origQuicListenAddr = quicListenAddr
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
quicListenAddr = origQuicListenAddr
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("handling requests", func() {
|
|
||||||
var (
|
|
||||||
h2framer *http2.Framer
|
|
||||||
hpackDecoder *hpack.Decoder
|
|
||||||
headerStream *mockStream
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
headerStream = &mockStream{}
|
|
||||||
hpackDecoder = hpack.NewDecoder(4096, nil)
|
|
||||||
h2framer = http2.NewFramer(nil, headerStream)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handles a sample GET request", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
Expect(r.Host).To(Equal("www.example.com"))
|
|
||||||
Expect(r.RemoteAddr).To(Equal("127.0.0.1:42"))
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
|
||||||
Expect(dataStream.remoteClosed).To(BeTrue())
|
|
||||||
Expect(dataStream.reset).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns 200 with an empty handler", func() {
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() []byte {
|
|
||||||
return headerStream.dataWritten.Bytes()
|
|
||||||
}).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x88})) // 0x88 is 200
|
|
||||||
})
|
|
||||||
|
|
||||||
It("correctly handles a panicking handler", func() {
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
panic("foobar")
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() []byte {
|
|
||||||
return headerStream.dataWritten.Bytes()
|
|
||||||
}).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x8e})) // 0x82 is 500
|
|
||||||
})
|
|
||||||
|
|
||||||
It("resets the dataStream when client sends a body in GET request", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
Expect(r.Host).To(Equal("www.example.com"))
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
|
||||||
Eventually(func() bool { return dataStream.reset }).Should(BeTrue())
|
|
||||||
Expect(dataStream.remoteClosed).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("resets the dataStream when the body of POST request is not read", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
Expect(r.Host).To(Equal("www.example.com"))
|
|
||||||
Expect(r.Method).To(Equal("POST"))
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() bool { return dataStream.reset }).Should(BeTrue())
|
|
||||||
Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse())
|
|
||||||
Expect(handlerCalled).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handles a request for which the client immediately resets the data stream", func() {
|
|
||||||
session.dataStream = nil
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Consistently(func() bool { return handlerCalled }).Should(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("resets the dataStream when the body of POST request is not read, and the request handler replaces the request.Body", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
r.Body = struct {
|
|
||||||
io.Reader
|
|
||||||
io.Closer
|
|
||||||
}{}
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() bool { return dataStream.reset }).Should(BeTrue())
|
|
||||||
Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse())
|
|
||||||
Expect(handlerCalled).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes the dataStream if the body of POST request was read", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
Expect(r.Host).To(Equal("www.example.com"))
|
|
||||||
Expect(r.Method).To(Equal("POST"))
|
|
||||||
handlerCalled = true
|
|
||||||
// read the request body
|
|
||||||
b := make([]byte, 1000)
|
|
||||||
n, _ := r.Body.Read(b)
|
|
||||||
Expect(n).ToNot(BeZero())
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
|
|
||||||
dataStream.dataToRead.Write([]byte("foo=bar"))
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
|
||||||
Expect(dataStream.reset).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("ignores PRIORITY frames", func() {
|
|
||||||
handlerCalled := make(chan struct{})
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
close(handlerCalled)
|
|
||||||
})
|
|
||||||
buf := &bytes.Buffer{}
|
|
||||||
framer := http2.NewFramer(buf, nil)
|
|
||||||
err := framer.WritePriority(10, http2.PriorityParam{Weight: 42})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(buf.Bytes()).ToNot(BeEmpty())
|
|
||||||
headerStream.dataToRead.Write(buf.Bytes())
|
|
||||||
err = s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Consistently(handlerCalled).ShouldNot(BeClosed())
|
|
||||||
Expect(dataStream.reset).To(BeFalse())
|
|
||||||
Expect(dataStream.closed).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when non-header frames are received", func() {
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x06, 0x0, 0x0, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
'f', 'o', 'o', 'b', 'a', 'r',
|
|
||||||
})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).To(MatchError("InvalidHeadersStreamData: expected a header frame"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("Cancels the request context when the datstream is closed", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
err := r.Context().Err()
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.Error()).To(Equal("context canceled"))
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
dataStream.Close()
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
|
||||||
Expect(dataStream.remoteClosed).To(BeTrue())
|
|
||||||
Expect(dataStream.reset).To(BeFalse())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handles the header stream", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
Expect(r.Host).To(Equal("www.example.com"))
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream := &mockStream{id: 3}
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
session.streamToAccept = headerStream
|
|
||||||
go s.handleHeaderStream(session)
|
|
||||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes the connection if it encounters an error on the header stream", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream := &mockStream{id: 3}
|
|
||||||
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
|
||||||
session.streamToAccept = headerStream
|
|
||||||
go s.handleHeaderStream(session)
|
|
||||||
Consistently(func() bool { return handlerCalled }).Should(BeFalse())
|
|
||||||
Eventually(func() bool { return session.closed }).Should(BeTrue())
|
|
||||||
Expect(session.closedWithError).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("supports closing after first request", func() {
|
|
||||||
s.CloseAfterFirstRequest = true
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
|
||||||
headerStream := &mockStream{id: 3}
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
session.streamToAccept = headerStream
|
|
||||||
Expect(session.closed).To(BeFalse())
|
|
||||||
go s.handleHeaderStream(session)
|
|
||||||
Eventually(func() bool { return session.closed }).Should(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the default handler as fallback", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
Expect(r.Host).To(Equal("www.example.com"))
|
|
||||||
handlerCalled = true
|
|
||||||
}))
|
|
||||||
headerStream := &mockStream{id: 3}
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
session.streamToAccept = headerStream
|
|
||||||
go s.handleHeaderStream(session)
|
|
||||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("setting http headers", func() {
|
|
||||||
var expected http.Header
|
|
||||||
|
|
||||||
getExpectedHeader := func(versions []protocol.VersionNumber) http.Header {
|
|
||||||
var versionsAsString []string
|
|
||||||
for _, v := range versions {
|
|
||||||
versionsAsString = append(versionsAsString, v.ToAltSvc())
|
|
||||||
}
|
|
||||||
return http.Header{
|
|
||||||
"Alt-Svc": {fmt.Sprintf(`quic=":443"; ma=2592000; v="%s"`, strings.Join(versionsAsString, ","))},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
Expect(getExpectedHeader([]protocol.VersionNumber{99, 90, 9})).To(Equal(http.Header{"Alt-Svc": {`quic=":443"; ma=2592000; v="99,90,9"`}}))
|
|
||||||
expected = getExpectedHeader(protocol.SupportedVersions)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets proper headers with numeric port", func() {
|
|
||||||
s.Server.Addr = ":443"
|
|
||||||
hdr := http.Header{}
|
|
||||||
err := s.SetQuicHeaders(hdr)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(hdr).To(Equal(expected))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets proper headers with full addr", func() {
|
|
||||||
s.Server.Addr = "127.0.0.1:443"
|
|
||||||
hdr := http.Header{}
|
|
||||||
err := s.SetQuicHeaders(hdr)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(hdr).To(Equal(expected))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets proper headers with string port", func() {
|
|
||||||
s.Server.Addr = ":https"
|
|
||||||
hdr := http.Header{}
|
|
||||||
err := s.SetQuicHeaders(hdr)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(hdr).To(Equal(expected))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works multiple times", func() {
|
|
||||||
s.Server.Addr = ":https"
|
|
||||||
hdr := http.Header{}
|
|
||||||
err := s.SetQuicHeaders(hdr)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(hdr).To(Equal(expected))
|
|
||||||
hdr = http.Header{}
|
|
||||||
err = s.SetQuicHeaders(hdr)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(hdr).To(Equal(expected))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("should error when ListenAndServe is called with s.Server nil", func() {
|
|
||||||
err := (&Server{}).ListenAndServe()
|
|
||||||
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("should error when ListenAndServeTLS is called with s.Server nil", func() {
|
|
||||||
err := (&Server{}).ListenAndServeTLS(testdata.GetCertificatePaths())
|
|
||||||
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("should nop-Close() when s.server is nil", func() {
|
|
||||||
err := (&Server{}).Close()
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when ListenAndServer is called after Close", func() {
|
|
||||||
serv := &Server{Server: &http.Server{}}
|
|
||||||
Expect(serv.Close()).To(Succeed())
|
|
||||||
err := serv.ListenAndServe()
|
|
||||||
Expect(err).To(MatchError("Server is already closed"))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("ListenAndServe", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
s.Server.Addr = "localhost:0"
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
Expect(s.Close()).To(Succeed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("may only be called once", func() {
|
|
||||||
cErr := make(chan error)
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
err := s.ListenAndServe()
|
|
||||||
if err != nil {
|
|
||||||
cErr <- err
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
err := <-cErr
|
|
||||||
Expect(err).To(MatchError("ListenAndServe may only be called once"))
|
|
||||||
err = s.Close()
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
}, 0.5)
|
|
||||||
|
|
||||||
It("uses the quic.Config to start the quic server", func() {
|
|
||||||
conf := &quic.Config{HandshakeTimeout: time.Nanosecond}
|
|
||||||
var receivedConf *quic.Config
|
|
||||||
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
|
|
||||||
receivedConf = config
|
|
||||||
return nil, errors.New("listen err")
|
|
||||||
}
|
|
||||||
s.QuicConfig = conf
|
|
||||||
go s.ListenAndServe()
|
|
||||||
Eventually(func() *quic.Config { return receivedConf }).Should(Equal(conf))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("ListenAndServeTLS", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
s.Server.Addr = "localhost:0"
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
err := s.Close()
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("may only be called once", func() {
|
|
||||||
cErr := make(chan error)
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
err := s.ListenAndServeTLS(testdata.GetCertificatePaths())
|
|
||||||
if err != nil {
|
|
||||||
cErr <- err
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
err := <-cErr
|
|
||||||
Expect(err).To(MatchError("ListenAndServe may only be called once"))
|
|
||||||
err = s.Close()
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
}, 0.5)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes gracefully", func() {
|
|
||||||
err := s.CloseGracefully(0)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when listening fails", func() {
|
|
||||||
testErr := errors.New("listen error")
|
|
||||||
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
|
|
||||||
return nil, testErr
|
|
||||||
}
|
|
||||||
fullpem, privkey := testdata.GetCertificatePaths()
|
|
||||||
err := ListenAndServeQUIC("", fullpem, privkey, nil)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,210 +0,0 @@
|
|||||||
package chrome_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"runtime"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
|
|
||||||
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
"github.com/onsi/gomega/gexec"
|
|
||||||
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
dataLen = 500 * 1024 // 500 KB
|
|
||||||
dataLongLen = 50 * 1024 * 1024 // 50 MB
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
nFilesUploaded int32 // should be used atomically
|
|
||||||
doneCalled utils.AtomicBool
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestChrome(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "Chrome Suite")
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
// Requires the len & num GET parameters, e.g. /uploadtest?len=100&num=1
|
|
||||||
http.HandleFunc("/uploadtest", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
response := uploadHTML
|
|
||||||
response = strings.Replace(response, "LENGTH", r.URL.Query().Get("len"), -1)
|
|
||||||
response = strings.Replace(response, "NUM", r.URL.Query().Get("num"), -1)
|
|
||||||
_, err := io.WriteString(w, response)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
// Requires the len & num GET parameters, e.g. /downloadtest?len=100&num=1
|
|
||||||
http.HandleFunc("/downloadtest", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
response := downloadHTML
|
|
||||||
response = strings.Replace(response, "LENGTH", r.URL.Query().Get("len"), -1)
|
|
||||||
response = strings.Replace(response, "NUM", r.URL.Query().Get("num"), -1)
|
|
||||||
_, err := io.WriteString(w, response)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
http.HandleFunc("/uploadhandler", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
|
|
||||||
l, err := strconv.Atoi(r.URL.Query().Get("len"))
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
|
|
||||||
defer r.Body.Close()
|
|
||||||
actual, err := ioutil.ReadAll(r.Body)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
|
|
||||||
Expect(bytes.Equal(actual, testserver.GeneratePRData(l))).To(BeTrue())
|
|
||||||
|
|
||||||
atomic.AddInt32(&nFilesUploaded, 1)
|
|
||||||
})
|
|
||||||
|
|
||||||
http.HandleFunc("/done", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
doneCalled.Set(true)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = AfterEach(func() {
|
|
||||||
testserver.StopQuicServer()
|
|
||||||
|
|
||||||
atomic.StoreInt32(&nFilesUploaded, 0)
|
|
||||||
doneCalled.Set(false)
|
|
||||||
})
|
|
||||||
|
|
||||||
func getChromePath() string {
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
return "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
|
|
||||||
}
|
|
||||||
if path, err := exec.LookPath("google-chrome"); err == nil {
|
|
||||||
return path
|
|
||||||
}
|
|
||||||
if path, err := exec.LookPath("chromium-browser"); err == nil {
|
|
||||||
return path
|
|
||||||
}
|
|
||||||
Fail("No Chrome executable found.")
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func chromeTest(version protocol.VersionNumber, url string, blockUntilDone func()) {
|
|
||||||
userDataDir, err := ioutil.TempDir("", "quic-go-test-chrome-dir")
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
defer os.RemoveAll(userDataDir)
|
|
||||||
path := getChromePath()
|
|
||||||
args := []string{
|
|
||||||
"--disable-gpu",
|
|
||||||
"--no-first-run=true",
|
|
||||||
"--no-default-browser-check=true",
|
|
||||||
"--user-data-dir=" + userDataDir,
|
|
||||||
"--enable-quic=true",
|
|
||||||
"--no-proxy-server=true",
|
|
||||||
"--no-sandbox",
|
|
||||||
"--origin-to-force-quic-on=quic.clemente.io:443",
|
|
||||||
fmt.Sprintf(`--host-resolver-rules=MAP quic.clemente.io:443 127.0.0.1:%s`, testserver.Port()),
|
|
||||||
fmt.Sprintf("--quic-version=QUIC_VERSION_%s", version.ToAltSvc()),
|
|
||||||
url,
|
|
||||||
}
|
|
||||||
utils.DefaultLogger.Infof("Running chrome: %s '%s'", getChromePath(), strings.Join(args, "' '"))
|
|
||||||
command := exec.Command(path, args...)
|
|
||||||
session, err := gexec.Start(command, nil, nil)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
defer session.Kill()
|
|
||||||
blockUntilDone()
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitForDone() {
|
|
||||||
Eventually(func() bool { return doneCalled.Get() }, 60).Should(BeTrue())
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitForNUploaded(expected int) func() {
|
|
||||||
return func() {
|
|
||||||
Eventually(func() int32 {
|
|
||||||
return atomic.LoadInt32(&nFilesUploaded)
|
|
||||||
}, 60).Should(BeEquivalentTo(expected))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const commonJS = `
|
|
||||||
var buf = new ArrayBuffer(LENGTH);
|
|
||||||
var prng = new Uint8Array(buf);
|
|
||||||
var seed = 1;
|
|
||||||
for (var i = 0; i < LENGTH; i++) {
|
|
||||||
// https://en.wikipedia.org/wiki/Lehmer_random_number_generator
|
|
||||||
seed = seed * 48271 % 2147483647;
|
|
||||||
prng[i] = seed;
|
|
||||||
}
|
|
||||||
`
|
|
||||||
|
|
||||||
const uploadHTML = `
|
|
||||||
<html>
|
|
||||||
<body>
|
|
||||||
<script>
|
|
||||||
console.log("Running DL test...");
|
|
||||||
|
|
||||||
` + commonJS + `
|
|
||||||
for (var i = 0; i < NUM; i++) {
|
|
||||||
var req = new XMLHttpRequest();
|
|
||||||
req.open("POST", "/uploadhandler?len=" + LENGTH, true);
|
|
||||||
req.send(buf);
|
|
||||||
}
|
|
||||||
</script>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
`
|
|
||||||
|
|
||||||
const downloadHTML = `
|
|
||||||
<html>
|
|
||||||
<body>
|
|
||||||
<script>
|
|
||||||
console.log("Running DL test...");
|
|
||||||
` + commonJS + `
|
|
||||||
|
|
||||||
function verify(data) {
|
|
||||||
if (data.length !== LENGTH) return false;
|
|
||||||
for (var i = 0; i < LENGTH; i++) {
|
|
||||||
if (data[i] !== prng[i]) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
var nOK = 0;
|
|
||||||
for (var i = 0; i < NUM; i++) {
|
|
||||||
let req = new XMLHttpRequest();
|
|
||||||
req.responseType = "arraybuffer";
|
|
||||||
req.open("POST", "/prdata?len=" + LENGTH, true);
|
|
||||||
req.onreadystatechange = function () {
|
|
||||||
if (req.readyState === XMLHttpRequest.DONE && req.status === 200) {
|
|
||||||
if (verify(new Uint8Array(req.response))) {
|
|
||||||
nOK++;
|
|
||||||
if (nOK === NUM) {
|
|
||||||
console.log("Done :)");
|
|
||||||
var reqDone = new XMLHttpRequest();
|
|
||||||
reqDone.open("GET", "/done");
|
|
||||||
reqDone.send();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
req.send();
|
|
||||||
}
|
|
||||||
</script>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
`
|
|
@ -1,76 +0,0 @@
|
|||||||
package chrome_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Chrome tests", func() {
|
|
||||||
for i := range protocol.SupportedVersions {
|
|
||||||
version := protocol.SupportedVersions[i]
|
|
||||||
|
|
||||||
// TODO: activate Chrome integration tests with gQUIC 44
|
|
||||||
if version == protocol.Version44 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("with version %s", version), func() {
|
|
||||||
JustBeforeEach(func() {
|
|
||||||
testserver.StartQuicServer([]protocol.VersionNumber{version})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("downloads a small file", func() {
|
|
||||||
chromeTest(
|
|
||||||
version,
|
|
||||||
fmt.Sprintf("https://quic.clemente.io/downloadtest?num=1&len=%d", dataLen),
|
|
||||||
waitForDone,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("downloads a large file", func() {
|
|
||||||
chromeTest(
|
|
||||||
version,
|
|
||||||
fmt.Sprintf("https://quic.clemente.io/downloadtest?num=1&len=%d", dataLongLen),
|
|
||||||
waitForDone,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("loads a large number of files", func() {
|
|
||||||
chromeTest(
|
|
||||||
version,
|
|
||||||
"https://quic.clemente.io/downloadtest?num=4&len=100",
|
|
||||||
waitForDone,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uploads a small file", func() {
|
|
||||||
chromeTest(
|
|
||||||
version,
|
|
||||||
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=1&len=%d", dataLen),
|
|
||||||
waitForNUploaded(1),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uploads a large file", func() {
|
|
||||||
chromeTest(
|
|
||||||
version,
|
|
||||||
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=1&len=%d", dataLongLen),
|
|
||||||
waitForNUploaded(1),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uploads many small files", func() {
|
|
||||||
num := protocol.DefaultMaxIncomingStreams + 20
|
|
||||||
chromeTest(
|
|
||||||
version,
|
|
||||||
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=%d&len=%d", num, dataLen),
|
|
||||||
waitForNUploaded(num),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
@ -1,137 +0,0 @@
|
|||||||
package gquic_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
mrand "math/rand"
|
|
||||||
"os/exec"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
. "github.com/onsi/gomega/gbytes"
|
|
||||||
. "github.com/onsi/gomega/gexec"
|
|
||||||
)
|
|
||||||
|
|
||||||
var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
|
|
||||||
|
|
||||||
var _ = Describe("Drop tests", func() {
|
|
||||||
var proxy *quicproxy.QuicProxy
|
|
||||||
|
|
||||||
startProxy := func(dropCallback quicproxy.DropCallback, version protocol.VersionNumber) {
|
|
||||||
var err error
|
|
||||||
proxy, err = quicproxy.NewQuicProxy("localhost:0", version, &quicproxy.Opts{
|
|
||||||
RemoteAddr: "localhost:" + testserver.Port(),
|
|
||||||
DropPacket: dropCallback,
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
|
|
||||||
downloadFile := func(version protocol.VersionNumber) {
|
|
||||||
command := exec.Command(
|
|
||||||
clientPath,
|
|
||||||
"--quic-version="+version.ToAltSvc(),
|
|
||||||
"--host=127.0.0.1",
|
|
||||||
"--port="+strconv.Itoa(proxy.LocalPort()),
|
|
||||||
"https://quic.clemente.io/prdata",
|
|
||||||
)
|
|
||||||
session, err := Start(command, nil, GinkgoWriter)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
defer session.Kill()
|
|
||||||
Eventually(session, 20).Should(Exit(0))
|
|
||||||
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
|
|
||||||
}
|
|
||||||
|
|
||||||
downloadHello := func(version protocol.VersionNumber) {
|
|
||||||
command := exec.Command(
|
|
||||||
clientPath,
|
|
||||||
"--quic-version="+version.ToAltSvc(),
|
|
||||||
"--host=127.0.0.1",
|
|
||||||
"--port="+strconv.Itoa(proxy.LocalPort()),
|
|
||||||
"https://quic.clemente.io/hello",
|
|
||||||
)
|
|
||||||
session, err := Start(command, nil, GinkgoWriter)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
defer session.Kill()
|
|
||||||
Eventually(session, 20).Should(Exit(0))
|
|
||||||
Expect(session.Out).To(Say(":status 200"))
|
|
||||||
Expect(session.Out).To(Say("body: Hello, World!\n"))
|
|
||||||
}
|
|
||||||
|
|
||||||
deterministicDropper := func(p, interval, dropInARow uint64) bool {
|
|
||||||
return (p % interval) < dropInARow
|
|
||||||
}
|
|
||||||
|
|
||||||
stochasticDropper := func(freq int) bool {
|
|
||||||
return mrand.Int63n(int64(freq)) == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
Expect(proxy.Close()).To(Succeed())
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, v := range protocol.SupportedVersions {
|
|
||||||
version := v
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
|
||||||
Context("during the crypto handshake", func() {
|
|
||||||
for _, d := range directions {
|
|
||||||
direction := d
|
|
||||||
|
|
||||||
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", d), func() {
|
|
||||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
|
||||||
return p == 1 && d.Is(direction)
|
|
||||||
}, version)
|
|
||||||
downloadHello(version)
|
|
||||||
})
|
|
||||||
|
|
||||||
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", d), func() {
|
|
||||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
|
||||||
return p == 2 && d.Is(direction)
|
|
||||||
}, version)
|
|
||||||
downloadHello(version)
|
|
||||||
})
|
|
||||||
|
|
||||||
It(fmt.Sprintf("establishes a connection when 1/5 of the packets are lost in %s direction", d), func() {
|
|
||||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
|
||||||
return d.Is(direction) && stochasticDropper(5)
|
|
||||||
}, version)
|
|
||||||
downloadHello(version)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("after the crypto handshake", func() {
|
|
||||||
for _, d := range directions {
|
|
||||||
direction := d
|
|
||||||
|
|
||||||
It(fmt.Sprintf("downloads a file when every 5th packet is dropped in %s direction", d), func() {
|
|
||||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
|
||||||
return p >= 10 && d.Is(direction) && deterministicDropper(p, 5, 1)
|
|
||||||
}, version)
|
|
||||||
downloadFile(version)
|
|
||||||
})
|
|
||||||
|
|
||||||
It(fmt.Sprintf("downloads a file when 1/5th of all packet are dropped randomly in %s direction", d), func() {
|
|
||||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
|
||||||
return p >= 10 && d.Is(direction) && stochasticDropper(5)
|
|
||||||
}, version)
|
|
||||||
downloadFile(version)
|
|
||||||
})
|
|
||||||
|
|
||||||
It(fmt.Sprintf("downloads a file when 10 packets every 100 packet are dropped in %s direction", d), func() {
|
|
||||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
|
||||||
return p >= 10 && d.Is(direction) && deterministicDropper(p, 100, 10)
|
|
||||||
}, version)
|
|
||||||
downloadFile(version)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
@ -1,45 +0,0 @@
|
|||||||
package gquic_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
clientPath string
|
|
||||||
serverPath string
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestIntegration(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "GQuic Tests Suite")
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = BeforeSuite(func() {
|
|
||||||
rand.Seed(GinkgoRandomSeed())
|
|
||||||
})
|
|
||||||
|
|
||||||
var _ = JustBeforeEach(func() {
|
|
||||||
testserver.StartQuicServer(nil)
|
|
||||||
})
|
|
||||||
|
|
||||||
var _ = AfterEach(testserver.StopQuicServer)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
_, thisfile, _, ok := runtime.Caller(0)
|
|
||||||
if !ok {
|
|
||||||
panic("Failed to get current path")
|
|
||||||
}
|
|
||||||
clientPath = filepath.Join(thisfile, fmt.Sprintf("../../../../quic-clients/client-%s-debug", runtime.GOOS))
|
|
||||||
serverPath = filepath.Join(thisfile, fmt.Sprintf("../../../../quic-clients/server-%s-debug", runtime.GOOS))
|
|
||||||
}
|
|
@ -1,98 +0,0 @@
|
|||||||
package gquic_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"os/exec"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
|
|
||||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
. "github.com/onsi/gomega/gbytes"
|
|
||||||
. "github.com/onsi/gomega/gexec"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Integration tests", func() {
|
|
||||||
for i := range protocol.SupportedVersions {
|
|
||||||
version := protocol.SupportedVersions[i]
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
|
||||||
It("gets a simple file", func() {
|
|
||||||
command := exec.Command(
|
|
||||||
clientPath,
|
|
||||||
"--quic-version="+version.ToAltSvc(),
|
|
||||||
"--host=127.0.0.1",
|
|
||||||
"--port="+testserver.Port(),
|
|
||||||
"https://quic.clemente.io/hello",
|
|
||||||
)
|
|
||||||
session, err := Start(command, nil, GinkgoWriter)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
defer session.Kill()
|
|
||||||
Eventually(session, 5).Should(Exit(0))
|
|
||||||
Expect(session.Out).To(Say(":status 200"))
|
|
||||||
Expect(session.Out).To(Say("body: Hello, World!\n"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("posts and reads a body", func() {
|
|
||||||
command := exec.Command(
|
|
||||||
clientPath,
|
|
||||||
"--quic-version="+version.ToAltSvc(),
|
|
||||||
"--host=127.0.0.1",
|
|
||||||
"--port="+testserver.Port(),
|
|
||||||
"--body=foo",
|
|
||||||
"https://quic.clemente.io/echo",
|
|
||||||
)
|
|
||||||
session, err := Start(command, nil, GinkgoWriter)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
defer session.Kill()
|
|
||||||
Eventually(session, 5).Should(Exit(0))
|
|
||||||
Expect(session.Out).To(Say(":status 200"))
|
|
||||||
Expect(session.Out).To(Say("body: foo\n"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets a file", func() {
|
|
||||||
command := exec.Command(
|
|
||||||
clientPath,
|
|
||||||
"--quic-version="+version.ToAltSvc(),
|
|
||||||
"--host=127.0.0.1",
|
|
||||||
"--port="+testserver.Port(),
|
|
||||||
"https://quic.clemente.io/prdata",
|
|
||||||
)
|
|
||||||
session, err := Start(command, nil, GinkgoWriter)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
defer session.Kill()
|
|
||||||
Eventually(session, 10).Should(Exit(0))
|
|
||||||
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets many copies of a file in parallel", func() {
|
|
||||||
wg := sync.WaitGroup{}
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
defer GinkgoRecover()
|
|
||||||
command := exec.Command(
|
|
||||||
clientPath,
|
|
||||||
"--quic-version="+version.ToAltSvc(),
|
|
||||||
"--host=127.0.0.1",
|
|
||||||
"--port="+testserver.Port(),
|
|
||||||
"https://quic.clemente.io/prdata",
|
|
||||||
)
|
|
||||||
session, err := Start(command, nil, GinkgoWriter)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
defer session.Kill()
|
|
||||||
Eventually(session, 20).Should(Exit(0))
|
|
||||||
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
@ -1,95 +0,0 @@
|
|||||||
package gquic_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"os/exec"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
. "github.com/onsi/gomega/gexec"
|
|
||||||
)
|
|
||||||
|
|
||||||
// get a random duration between min and max
|
|
||||||
func getRandomDuration(min, max time.Duration) time.Duration {
|
|
||||||
return min + time.Duration(rand.Int63n(int64(max-min)))
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("Random Duration Generator", func() {
|
|
||||||
It("gets a random RTT", func() {
|
|
||||||
var min time.Duration = time.Hour
|
|
||||||
var max time.Duration
|
|
||||||
|
|
||||||
var sum time.Duration
|
|
||||||
rep := 10000
|
|
||||||
for i := 0; i < rep; i++ {
|
|
||||||
val := getRandomDuration(100*time.Millisecond, 500*time.Millisecond)
|
|
||||||
sum += val
|
|
||||||
if val < min {
|
|
||||||
min = val
|
|
||||||
}
|
|
||||||
if val > max {
|
|
||||||
max = val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
avg := sum / time.Duration(rep)
|
|
||||||
Expect(avg).To(BeNumerically("~", 300*time.Millisecond, 5*time.Millisecond))
|
|
||||||
Expect(min).To(BeNumerically(">=", 100*time.Millisecond))
|
|
||||||
Expect(min).To(BeNumerically("<", 105*time.Millisecond))
|
|
||||||
Expect(max).To(BeNumerically(">", 495*time.Millisecond))
|
|
||||||
Expect(max).To(BeNumerically("<=", 500*time.Millisecond))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
var _ = Describe("Random RTT", func() {
|
|
||||||
var proxy *quicproxy.QuicProxy
|
|
||||||
|
|
||||||
runRTTTest := func(minRtt, maxRtt time.Duration, version protocol.VersionNumber) {
|
|
||||||
var err error
|
|
||||||
proxy, err = quicproxy.NewQuicProxy("localhost:", version, &quicproxy.Opts{
|
|
||||||
RemoteAddr: "localhost:" + testserver.Port(),
|
|
||||||
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration {
|
|
||||||
return getRandomDuration(minRtt, maxRtt)
|
|
||||||
},
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
command := exec.Command(
|
|
||||||
clientPath,
|
|
||||||
"--quic-version="+version.ToAltSvc(),
|
|
||||||
"--host=127.0.0.1",
|
|
||||||
"--port="+strconv.Itoa(proxy.LocalPort()),
|
|
||||||
"https://quic.clemente.io/prdata",
|
|
||||||
)
|
|
||||||
|
|
||||||
session, err := Start(command, nil, GinkgoWriter)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
defer session.Kill()
|
|
||||||
Eventually(session, 20).Should(Exit(0))
|
|
||||||
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
err := proxy.Close()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
time.Sleep(time.Millisecond)
|
|
||||||
})
|
|
||||||
|
|
||||||
for i := range protocol.SupportedVersions {
|
|
||||||
version := protocol.SupportedVersions[i]
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
|
||||||
It("gets a file a random RTT between 10ms and 30ms", func() {
|
|
||||||
runRTTTest(10*time.Millisecond, 30*time.Millisecond, version)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
@ -1,66 +0,0 @@
|
|||||||
package gquic_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"os/exec"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
. "github.com/onsi/gomega/gexec"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("non-zero RTT", func() {
|
|
||||||
var proxy *quicproxy.QuicProxy
|
|
||||||
|
|
||||||
runRTTTest := func(rtt time.Duration, version protocol.VersionNumber) {
|
|
||||||
var err error
|
|
||||||
proxy, err = quicproxy.NewQuicProxy("localhost:", version, &quicproxy.Opts{
|
|
||||||
RemoteAddr: "localhost:" + testserver.Port(),
|
|
||||||
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration {
|
|
||||||
return rtt / 2
|
|
||||||
},
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
command := exec.Command(
|
|
||||||
clientPath,
|
|
||||||
"--quic-version="+version.ToAltSvc(),
|
|
||||||
"--host=127.0.0.1",
|
|
||||||
"--port="+strconv.Itoa(proxy.LocalPort()),
|
|
||||||
"https://quic.clemente.io/prdata",
|
|
||||||
)
|
|
||||||
|
|
||||||
session, err := Start(command, nil, GinkgoWriter)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
defer session.Kill()
|
|
||||||
Eventually(session, 20).Should(Exit(0))
|
|
||||||
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
err := proxy.Close()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
time.Sleep(time.Millisecond)
|
|
||||||
})
|
|
||||||
|
|
||||||
for i := range protocol.SupportedVersions {
|
|
||||||
version := protocol.SupportedVersions[i]
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
|
||||||
roundTrips := [...]int{10, 50, 100, 200}
|
|
||||||
for _, rtt := range roundTrips {
|
|
||||||
It(fmt.Sprintf("gets a 500kB file with %dms RTT", rtt), func() {
|
|
||||||
runRTTTest(time.Duration(rtt)*time.Millisecond, version)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
@ -1,218 +0,0 @@
|
|||||||
package gquic_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"crypto/x509/pkix"
|
|
||||||
"encoding/asn1"
|
|
||||||
"encoding/pem"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"math/big"
|
|
||||||
mrand "math/rand"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/h2quic"
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
"github.com/onsi/gomega/gbytes"
|
|
||||||
. "github.com/onsi/gomega/gexec"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Server tests", func() {
|
|
||||||
for i := range protocol.SupportedVersions {
|
|
||||||
version := protocol.SupportedVersions[i]
|
|
||||||
|
|
||||||
var (
|
|
||||||
serverPort string
|
|
||||||
tmpDir string
|
|
||||||
session *Session
|
|
||||||
client *http.Client
|
|
||||||
)
|
|
||||||
|
|
||||||
generateCA := func() (*rsa.PrivateKey, *x509.Certificate) {
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
templateRoot := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(1),
|
|
||||||
NotBefore: time.Now().Add(-time.Hour),
|
|
||||||
NotAfter: time.Now().Add(time.Hour),
|
|
||||||
IsCA: true,
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
}
|
|
||||||
certDER, err := x509.CreateCertificate(rand.Reader, templateRoot, templateRoot, &key.PublicKey, key)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cert, err := x509.ParseCertificate(certDER)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
return key, cert
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepare the file such that it can be by the quic_server
|
|
||||||
// some HTTP headers neeed to be prepended, see https://www.chromium.org/quic/playing-with-quic
|
|
||||||
createDownloadFile := func(filename string, data []byte) {
|
|
||||||
dataDir := filepath.Join(tmpDir, "quic.clemente.io")
|
|
||||||
err := os.Mkdir(dataDir, 0777)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
f, err := os.Create(filepath.Join(dataDir, filename))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer f.Close()
|
|
||||||
_, err = f.Write([]byte("HTTP/1.1 200 OK\n"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = f.Write([]byte("Content-Type: text/html\n"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = f.Write([]byte("X-Original-Url: https://quic.clemente.io:" + serverPort + "/" + filename + "\n"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = f.Write([]byte("Content-Length: " + strconv.Itoa(len(data)) + "\n\n"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = f.Write(data)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
|
|
||||||
// download files must be create *before* the quic_server is started
|
|
||||||
// the quic_server reads its data dir on startup, and only serves those files that were already present then
|
|
||||||
startServer := func(version protocol.VersionNumber) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
var err error
|
|
||||||
command := exec.Command(
|
|
||||||
serverPath,
|
|
||||||
"--quic_response_cache_dir="+filepath.Join(tmpDir, "quic.clemente.io"),
|
|
||||||
"--key_file="+filepath.Join(tmpDir, "key.pkcs8"),
|
|
||||||
"--certificate_file="+filepath.Join(tmpDir, "cert.pem"),
|
|
||||||
"--quic-version="+strconv.Itoa(int(version)),
|
|
||||||
"--port="+serverPort,
|
|
||||||
)
|
|
||||||
session, err = Start(command, nil, GinkgoWriter)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
}
|
|
||||||
|
|
||||||
stopServer := func() {
|
|
||||||
session.Kill()
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
serverPort = strconv.Itoa(20000 + int(mrand.Int31n(10000)))
|
|
||||||
|
|
||||||
var err error
|
|
||||||
tmpDir, err = ioutil.TempDir("", "quic-server-certs")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
// generate an RSA key pair for the server
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
// save the private key in PKCS8 format to disk (required by quic_server)
|
|
||||||
pkcs8key, err := asn1.Marshal(struct { // copied from the x509 package
|
|
||||||
Version int
|
|
||||||
Algo pkix.AlgorithmIdentifier
|
|
||||||
PrivateKey []byte
|
|
||||||
}{
|
|
||||||
PrivateKey: x509.MarshalPKCS1PrivateKey(key),
|
|
||||||
Algo: pkix.AlgorithmIdentifier{
|
|
||||||
Algorithm: asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1},
|
|
||||||
Parameters: asn1.RawValue{Tag: 5},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
f, err := os.Create(filepath.Join(tmpDir, "key.pkcs8"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = f.Write(pkcs8key)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
f.Close()
|
|
||||||
|
|
||||||
// generate a Certificate Authority
|
|
||||||
// this CA is used to sign the server's key
|
|
||||||
// it is set as a valid CA in the QUIC client
|
|
||||||
rootKey, CACert := generateCA()
|
|
||||||
// generate the server certificate
|
|
||||||
template := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(1),
|
|
||||||
NotBefore: time.Now().Add(-30 * time.Minute),
|
|
||||||
NotAfter: time.Now().Add(30 * time.Minute),
|
|
||||||
Subject: pkix.Name{CommonName: "quic.clemente.io"},
|
|
||||||
}
|
|
||||||
certDER, err := x509.CreateCertificate(rand.Reader, template, CACert, &key.PublicKey, rootKey)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
// save the certificate to disk
|
|
||||||
certOut, err := os.Create(filepath.Join(tmpDir, "cert.pem"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
|
||||||
certOut.Close()
|
|
||||||
|
|
||||||
// prepare the h2quic.client
|
|
||||||
certPool := x509.NewCertPool()
|
|
||||||
certPool.AddCert(CACert)
|
|
||||||
client = &http.Client{
|
|
||||||
Transport: &h2quic.RoundTripper{
|
|
||||||
TLSClientConfig: &tls.Config{RootCAs: certPool},
|
|
||||||
QuicConfig: &quic.Config{
|
|
||||||
Versions: []protocol.VersionNumber{version},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
Expect(tmpDir).ToNot(BeEmpty())
|
|
||||||
err := os.RemoveAll(tmpDir)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
tmpDir = ""
|
|
||||||
})
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
|
||||||
It("downloads a hello", func() {
|
|
||||||
data := []byte("Hello world!\n")
|
|
||||||
createDownloadFile("hello", data)
|
|
||||||
|
|
||||||
startServer(version)
|
|
||||||
defer stopServer()
|
|
||||||
|
|
||||||
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/hello")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(rsp.StatusCode).To(Equal(200))
|
|
||||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 5*time.Second))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(body).To(Equal(data))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("downloads a small file", func() {
|
|
||||||
createDownloadFile("file.dat", testserver.PRData)
|
|
||||||
|
|
||||||
startServer(version)
|
|
||||||
defer stopServer()
|
|
||||||
|
|
||||||
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/file.dat")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(rsp.StatusCode).To(Equal(200))
|
|
||||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 5*time.Second))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(body).To(Equal(testserver.PRData))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("downloads a large file", func() {
|
|
||||||
createDownloadFile("file.dat", testserver.PRDataLong)
|
|
||||||
|
|
||||||
startServer(version)
|
|
||||||
defer stopServer()
|
|
||||||
|
|
||||||
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/file.dat")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(rsp.StatusCode).To(Equal(200))
|
|
||||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 20*time.Second))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(body).To(Equal(testserver.PRDataLong))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
@ -1,97 +0,0 @@
|
|||||||
package self_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/h2quic"
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
"github.com/onsi/gomega/gbytes"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Client tests", func() {
|
|
||||||
var client *http.Client
|
|
||||||
|
|
||||||
// also run some tests with the TLS handshake
|
|
||||||
versions := append(protocol.SupportedVersions, protocol.VersionTLS)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
err := os.Setenv("HOSTALIASES", "quic.clemente.io 127.0.0.1")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
addr, err := net.ResolveUDPAddr("udp4", "quic.clemente.io:0")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
if addr.String() != "127.0.0.1:0" {
|
|
||||||
Fail("quic.clemente.io does not resolve to 127.0.0.1. Consider adding it to /etc/hosts.")
|
|
||||||
}
|
|
||||||
testserver.StartQuicServer(versions)
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
testserver.StopQuicServer()
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, v := range versions {
|
|
||||||
version := v
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
client = &http.Client{
|
|
||||||
Transport: &h2quic.RoundTripper{
|
|
||||||
QuicConfig: &quic.Config{
|
|
||||||
Versions: []protocol.VersionNumber{version},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("downloads a hello", func() {
|
|
||||||
resp, err := client.Get("https://quic.clemente.io:" + testserver.Port() + "/hello")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(resp.StatusCode).To(Equal(200))
|
|
||||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(string(body)).To(Equal("Hello, World!\n"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("downloads a small file", func() {
|
|
||||||
resp, err := client.Get("https://quic.clemente.io:" + testserver.Port() + "/prdata")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(resp.StatusCode).To(Equal(200))
|
|
||||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(body).To(Equal(testserver.PRData))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("downloads a large file", func() {
|
|
||||||
resp, err := client.Get("https://quic.clemente.io:" + testserver.Port() + "/prdatalong")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(resp.StatusCode).To(Equal(200))
|
|
||||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 20*time.Second))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(body).To(Equal(testserver.PRDataLong))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uploads a file", func() {
|
|
||||||
resp, err := client.Post(
|
|
||||||
"https://quic.clemente.io:"+testserver.Port()+"/echo",
|
|
||||||
"text/plain",
|
|
||||||
bytes.NewReader(testserver.PRData),
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(resp.StatusCode).To(Equal(200))
|
|
||||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(bytes.Equal(body, testserver.PRData)).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
@ -1,101 +0,0 @@
|
|||||||
package self_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Connection ID lengths tests", func() {
|
|
||||||
randomConnIDLen := func() int {
|
|
||||||
return 4 + int(rand.Int31n(15))
|
|
||||||
}
|
|
||||||
|
|
||||||
runServer := func(conf *quic.Config) quic.Listener {
|
|
||||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength)))
|
|
||||||
ln, err := quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), conf)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
for {
|
|
||||||
sess, err := ln.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
str, err := sess.OpenStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer str.Close()
|
|
||||||
_, err = str.Write(testserver.PRData)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return ln
|
|
||||||
}
|
|
||||||
|
|
||||||
runClient := func(addr net.Addr, conf *quic.Config) {
|
|
||||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength)))
|
|
||||||
cl, err := quic.DialAddr(
|
|
||||||
fmt.Sprintf("quic.clemente.io:%d", addr.(*net.UDPAddr).Port),
|
|
||||||
&tls.Config{InsecureSkipVerify: true},
|
|
||||||
conf,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer cl.Close()
|
|
||||||
str, err := cl.AcceptStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
data, err := ioutil.ReadAll(str)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(data).To(Equal(testserver.PRData))
|
|
||||||
}
|
|
||||||
|
|
||||||
Context("IETF QUIC", func() {
|
|
||||||
It("downloads a file using a 0-byte connection ID for the client", func() {
|
|
||||||
serverConf := &quic.Config{
|
|
||||||
ConnectionIDLength: randomConnIDLen(),
|
|
||||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
|
||||||
}
|
|
||||||
clientConf := &quic.Config{
|
|
||||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
|
||||||
}
|
|
||||||
|
|
||||||
ln := runServer(serverConf)
|
|
||||||
defer ln.Close()
|
|
||||||
runClient(ln.Addr(), clientConf)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("downloads a file when both client and server use a random connection ID length", func() {
|
|
||||||
serverConf := &quic.Config{
|
|
||||||
ConnectionIDLength: randomConnIDLen(),
|
|
||||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
|
||||||
}
|
|
||||||
clientConf := &quic.Config{
|
|
||||||
ConnectionIDLength: randomConnIDLen(),
|
|
||||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
|
||||||
}
|
|
||||||
|
|
||||||
ln := runServer(serverConf)
|
|
||||||
defer ln.Close()
|
|
||||||
runClient(ln.Addr(), clientConf)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("gQUIC", func() {
|
|
||||||
It("downloads a file using a 0-byte connection ID for the client", func() {
|
|
||||||
ln := runServer(&quic.Config{})
|
|
||||||
defer ln.Close()
|
|
||||||
runClient(ln.Addr(), &quic.Config{RequestConnectionIDOmission: true})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,189 +0,0 @@
|
|||||||
package self_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
mrand "math/rand"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
|
||||||
"github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
"github.com/onsi/gomega/gbytes"
|
|
||||||
)
|
|
||||||
|
|
||||||
var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
|
|
||||||
|
|
||||||
type applicationProtocol struct {
|
|
||||||
name string
|
|
||||||
run func(protocol.VersionNumber)
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("Handshake drop tests", func() {
|
|
||||||
var (
|
|
||||||
proxy *quicproxy.QuicProxy
|
|
||||||
ln quic.Listener
|
|
||||||
)
|
|
||||||
|
|
||||||
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, version protocol.VersionNumber) {
|
|
||||||
var err error
|
|
||||||
ln, err = quic.ListenAddr(
|
|
||||||
"localhost:0",
|
|
||||||
testdata.GetTLSConfig(),
|
|
||||||
&quic.Config{
|
|
||||||
Versions: []protocol.VersionNumber{version},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
|
||||||
proxy, err = quicproxy.NewQuicProxy("localhost:0", version, &quicproxy.Opts{
|
|
||||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
|
||||||
DropPacket: dropCallback,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
|
|
||||||
stochasticDropper := func(freq int) bool {
|
|
||||||
return mrand.Int63n(int64(freq)) == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
clientSpeaksFirst := &applicationProtocol{
|
|
||||||
name: "client speaks first",
|
|
||||||
run: func(version protocol.VersionNumber) {
|
|
||||||
serverSessionChan := make(chan quic.Session)
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
sess, err := ln.Accept()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer sess.Close()
|
|
||||||
str, err := sess.AcceptStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
b := make([]byte, 6)
|
|
||||||
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(string(b)).To(Equal("foobar"))
|
|
||||||
serverSessionChan <- sess
|
|
||||||
}()
|
|
||||||
sess, err := quic.DialAddr(
|
|
||||||
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
|
|
||||||
nil,
|
|
||||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
str, err := sess.OpenStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = str.Write([]byte("foobar"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
var serverSession quic.Session
|
|
||||||
Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession))
|
|
||||||
sess.Close()
|
|
||||||
serverSession.Close()
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
serverSpeaksFirst := &applicationProtocol{
|
|
||||||
name: "server speaks first",
|
|
||||||
run: func(version protocol.VersionNumber) {
|
|
||||||
serverSessionChan := make(chan quic.Session)
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
sess, err := ln.Accept()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
str, err := sess.OpenStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = str.Write([]byte("foobar"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
serverSessionChan <- sess
|
|
||||||
}()
|
|
||||||
sess, err := quic.DialAddr(
|
|
||||||
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
|
|
||||||
nil,
|
|
||||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
str, err := sess.AcceptStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
b := make([]byte, 6)
|
|
||||||
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(string(b)).To(Equal("foobar"))
|
|
||||||
|
|
||||||
var serverSession quic.Session
|
|
||||||
Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession))
|
|
||||||
sess.Close()
|
|
||||||
serverSession.Close()
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
nobodySpeaks := &applicationProtocol{
|
|
||||||
name: "nobody speaks",
|
|
||||||
run: func(version protocol.VersionNumber) {
|
|
||||||
serverSessionChan := make(chan quic.Session)
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
sess, err := ln.Accept()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
serverSessionChan <- sess
|
|
||||||
}()
|
|
||||||
sess, err := quic.DialAddr(
|
|
||||||
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
|
|
||||||
nil,
|
|
||||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
var serverSession quic.Session
|
|
||||||
Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession))
|
|
||||||
// both server and client accepted a session. Close now.
|
|
||||||
sess.Close()
|
|
||||||
serverSession.Close()
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
Expect(proxy.Close()).To(Succeed())
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
|
|
||||||
version := v
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
|
||||||
for _, d := range directions {
|
|
||||||
direction := d
|
|
||||||
|
|
||||||
for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
|
|
||||||
app := a
|
|
||||||
|
|
||||||
Context(app.name, func() {
|
|
||||||
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", d), func() {
|
|
||||||
startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool {
|
|
||||||
return p == 1 && d.Is(direction)
|
|
||||||
}, version)
|
|
||||||
app.run(version)
|
|
||||||
})
|
|
||||||
|
|
||||||
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", d), func() {
|
|
||||||
startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool {
|
|
||||||
return p == 2 && d.Is(direction)
|
|
||||||
}, version)
|
|
||||||
app.run(version)
|
|
||||||
})
|
|
||||||
|
|
||||||
It(fmt.Sprintf("establishes a connection when 1/5 of the packets are lost in %s direction", d), func() {
|
|
||||||
startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool {
|
|
||||||
return d.Is(direction) && stochasticDropper(5)
|
|
||||||
}, version)
|
|
||||||
app.run(version)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
@ -1,213 +0,0 @@
|
|||||||
package self_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Handshake RTT tests", func() {
|
|
||||||
var (
|
|
||||||
proxy *quicproxy.QuicProxy
|
|
||||||
server quic.Listener
|
|
||||||
serverConfig *quic.Config
|
|
||||||
testStartedAt time.Time
|
|
||||||
acceptStopped chan struct{}
|
|
||||||
)
|
|
||||||
|
|
||||||
rtt := 400 * time.Millisecond
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
acceptStopped = make(chan struct{})
|
|
||||||
serverConfig = &quic.Config{}
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
Expect(proxy.Close()).To(Succeed())
|
|
||||||
Expect(server.Close()).To(Succeed())
|
|
||||||
<-acceptStopped
|
|
||||||
})
|
|
||||||
|
|
||||||
runServerAndProxy := func() {
|
|
||||||
var err error
|
|
||||||
// start the server
|
|
||||||
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
// start the proxy
|
|
||||||
proxy, err = quicproxy.NewQuicProxy("localhost:0", protocol.VersionWhatever, &quicproxy.Opts{
|
|
||||||
RemoteAddr: server.Addr().String(),
|
|
||||||
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration { return rtt / 2 },
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
testStartedAt = time.Now()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
defer close(acceptStopped)
|
|
||||||
for {
|
|
||||||
_, err := server.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
expectDurationInRTTs := func(num int) {
|
|
||||||
testDuration := time.Since(testStartedAt)
|
|
||||||
rtts := float32(testDuration) / float32(rtt)
|
|
||||||
Expect(rtts).To(SatisfyAll(
|
|
||||||
BeNumerically(">=", num),
|
|
||||||
BeNumerically("<", num+1),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
It("fails when there's no matching version, after 1 RTT", func() {
|
|
||||||
if len(protocol.SupportedVersions) == 1 {
|
|
||||||
Skip("Test requires at least 2 supported versions.")
|
|
||||||
}
|
|
||||||
serverConfig.Versions = protocol.SupportedVersions[:1]
|
|
||||||
runServerAndProxy()
|
|
||||||
clientConfig := &quic.Config{
|
|
||||||
Versions: protocol.SupportedVersions[1:2],
|
|
||||||
}
|
|
||||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), nil, clientConfig)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.(qerr.ErrorCode)).To(Equal(qerr.InvalidVersion))
|
|
||||||
expectDurationInRTTs(1)
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("gQUIC", func() {
|
|
||||||
// 1 RTT for verifying the source address
|
|
||||||
// 1 RTT to become secure
|
|
||||||
// 1 RTT to become forward-secure
|
|
||||||
It("is forward-secure after 3 RTTs", func() {
|
|
||||||
runServerAndProxy()
|
|
||||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
expectDurationInRTTs(3)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does version negotiation in 1 RTT, IETF QUIC => gQUIC", func() {
|
|
||||||
clientConfig := &quic.Config{
|
|
||||||
Versions: []protocol.VersionNumber{protocol.VersionTLS, protocol.SupportedVersions[0]},
|
|
||||||
}
|
|
||||||
runServerAndProxy()
|
|
||||||
_, err := quic.DialAddr(
|
|
||||||
proxy.LocalAddr().String(),
|
|
||||||
&tls.Config{InsecureSkipVerify: true},
|
|
||||||
clientConfig,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
expectDurationInRTTs(4)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("is forward-secure after 2 RTTs when the server doesn't require a Cookie", func() {
|
|
||||||
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
runServerAndProxy()
|
|
||||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
expectDurationInRTTs(2)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't complete the handshake when the server never accepts the Cookie", func() {
|
|
||||||
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
runServerAndProxy()
|
|
||||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.CryptoTooManyRejects))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't complete the handshake when the handshake timeout is too short", func() {
|
|
||||||
serverConfig.HandshakeTimeout = 2 * rtt
|
|
||||||
runServerAndProxy()
|
|
||||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
|
|
||||||
// 2 RTTs during the timeout
|
|
||||||
// plus 1 RTT: the timer starts 0.5 RTTs after sending the first packet, and the CONNECTION_CLOSE needs another 0.5 RTTs to reach the client
|
|
||||||
expectDurationInRTTs(3)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("IETF QUIC", func() {
|
|
||||||
var clientConfig *quic.Config
|
|
||||||
var clientTLSConfig *tls.Config
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
serverConfig.Versions = []protocol.VersionNumber{protocol.VersionTLS}
|
|
||||||
clientConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
|
|
||||||
clientTLSConfig = &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
ServerName: "quic.clemente.io",
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// 1 RTT for verifying the source address
|
|
||||||
// 1 RTT for the TLS handshake
|
|
||||||
It("is forward-secure after 2 RTTs", func() {
|
|
||||||
runServerAndProxy()
|
|
||||||
_, err := quic.DialAddr(
|
|
||||||
proxy.LocalAddr().String(),
|
|
||||||
clientTLSConfig,
|
|
||||||
clientConfig,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
expectDurationInRTTs(2)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does version negotiation in 1 RTT, gQUIC => IETF QUIC", func() {
|
|
||||||
clientConfig.Versions = []protocol.VersionNumber{protocol.SupportedVersions[0], protocol.VersionTLS}
|
|
||||||
runServerAndProxy()
|
|
||||||
_, err := quic.DialAddr(
|
|
||||||
proxy.LocalAddr().String(),
|
|
||||||
clientTLSConfig,
|
|
||||||
clientConfig,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
expectDurationInRTTs(3)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("is forward-secure after 1 RTTs when the server doesn't require a Cookie", func() {
|
|
||||||
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
runServerAndProxy()
|
|
||||||
_, err := quic.DialAddr(
|
|
||||||
proxy.LocalAddr().String(),
|
|
||||||
clientTLSConfig,
|
|
||||||
clientConfig,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
expectDurationInRTTs(1)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't complete the handshake when the server never accepts the Cookie", func() {
|
|
||||||
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
runServerAndProxy()
|
|
||||||
_, err := quic.DialAddr(
|
|
||||||
proxy.LocalAddr().String(),
|
|
||||||
clientTLSConfig,
|
|
||||||
clientConfig,
|
|
||||||
)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.(qerr.ErrorCode)).To(Equal(qerr.CryptoTooManyRejects))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,128 +0,0 @@
|
|||||||
package self_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
type versioner interface {
|
|
||||||
GetVersion() protocol.VersionNumber
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("Handshake tests", func() {
|
|
||||||
var (
|
|
||||||
server quic.Listener
|
|
||||||
serverConfig *quic.Config
|
|
||||||
acceptStopped chan struct{}
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
server = nil
|
|
||||||
acceptStopped = make(chan struct{})
|
|
||||||
serverConfig = &quic.Config{}
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
if server != nil {
|
|
||||||
server.Close()
|
|
||||||
<-acceptStopped
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
runServer := func() {
|
|
||||||
var err error
|
|
||||||
// start the server
|
|
||||||
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
defer close(acceptStopped)
|
|
||||||
for {
|
|
||||||
_, err := server.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
Context("Version Negotiation", func() {
|
|
||||||
var supportedVersions []protocol.VersionNumber
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
supportedVersions = protocol.SupportedVersions
|
|
||||||
protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{7, 8, 9, 10}...)
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
protocol.SupportedVersions = supportedVersions
|
|
||||||
})
|
|
||||||
|
|
||||||
It("when the server supports more versions than the client", func() {
|
|
||||||
// the server doesn't support the highest supported version, which is the first one the client will try
|
|
||||||
// but it supports a bunch of versions that the client doesn't speak
|
|
||||||
serverConfig.Versions = []protocol.VersionNumber{protocol.SupportedVersions[1], 7, 8, 9}
|
|
||||||
runServer()
|
|
||||||
sess, err := quic.DialAddr(server.Addr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(sess.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[1]))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("when the client supports more versions than the server supports", func() {
|
|
||||||
// the server doesn't support the highest supported version, which is the first one the client will try
|
|
||||||
// but it supports a bunch of versions that the client doesn't speak
|
|
||||||
serverConfig.Versions = supportedVersions
|
|
||||||
runServer()
|
|
||||||
conf := &quic.Config{
|
|
||||||
Versions: []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[1], 10},
|
|
||||||
}
|
|
||||||
sess, err := quic.DialAddr(server.Addr().String(), &tls.Config{InsecureSkipVerify: true}, conf)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(sess.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[1]))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Certifiate validation", func() {
|
|
||||||
for _, v := range []protocol.VersionNumber{protocol.Version39, protocol.VersionTLS} {
|
|
||||||
version := v
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("using %s", version), func() {
|
|
||||||
var clientConfig *quic.Config
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
serverConfig.Versions = []protocol.VersionNumber{version}
|
|
||||||
clientConfig = &quic.Config{
|
|
||||||
Versions: []protocol.VersionNumber{version},
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("accepts the certificate", func() {
|
|
||||||
runServer()
|
|
||||||
_, err := quic.DialAddr(fmt.Sprintf("quic.clemente.io:%d", server.Addr().(*net.UDPAddr).Port), nil, clientConfig)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the server name doesn't match", func() {
|
|
||||||
runServer()
|
|
||||||
_, err := quic.DialAddr(fmt.Sprintf("127.0.0.1:%d", server.Addr().(*net.UDPAddr).Port), nil, clientConfig)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the ServerName in the tls.Config", func() {
|
|
||||||
runServer()
|
|
||||||
conf := &tls.Config{ServerName: "quic.clemente.io"}
|
|
||||||
_, err := quic.DialAddr(fmt.Sprintf("127.0.0.1:%d", server.Addr().(*net.UDPAddr).Port), conf, clientConfig)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,232 +0,0 @@
|
|||||||
package self_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Multiplexing", func() {
|
|
||||||
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
|
|
||||||
version := v
|
|
||||||
|
|
||||||
// gQUIC 44 uses 0 byte connection IDs for packets sent to the client
|
|
||||||
// It's not possible to do demultiplexing.
|
|
||||||
if v == protocol.Version44 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
|
||||||
runServer := func(ln quic.Listener) {
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
for {
|
|
||||||
sess, err := ln.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
str, err := sess.OpenStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer str.Close()
|
|
||||||
_, err = str.Write(testserver.PRDataLong)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
dial := func(conn net.PacketConn, addr net.Addr) {
|
|
||||||
sess, err := quic.Dial(
|
|
||||||
conn,
|
|
||||||
addr,
|
|
||||||
fmt.Sprintf("quic.clemente.io:%d", addr.(*net.UDPAddr).Port),
|
|
||||||
nil,
|
|
||||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
str, err := sess.AcceptStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
data, err := ioutil.ReadAll(str)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(data).To(Equal(testserver.PRDataLong))
|
|
||||||
}
|
|
||||||
|
|
||||||
Context("multiplexing clients on the same conn", func() {
|
|
||||||
getListener := func() quic.Listener {
|
|
||||||
ln, err := quic.ListenAddr(
|
|
||||||
"localhost:0",
|
|
||||||
testdata.GetTLSConfig(),
|
|
||||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
return ln
|
|
||||||
}
|
|
||||||
|
|
||||||
It("multiplexes connections to the same server", func() {
|
|
||||||
server := getListener()
|
|
||||||
runServer(server)
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
conn, err := net.ListenUDP("udp", addr)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
done1 := make(chan struct{})
|
|
||||||
done2 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
dial(conn, server.Addr())
|
|
||||||
close(done1)
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
dial(conn, server.Addr())
|
|
||||||
close(done2)
|
|
||||||
}()
|
|
||||||
timeout := 30 * time.Second
|
|
||||||
if testlog.Debug() {
|
|
||||||
timeout = time.Minute
|
|
||||||
}
|
|
||||||
Eventually(done1, timeout).Should(BeClosed())
|
|
||||||
Eventually(done2, timeout).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("multiplexes connections to different servers", func() {
|
|
||||||
server1 := getListener()
|
|
||||||
runServer(server1)
|
|
||||||
defer server1.Close()
|
|
||||||
server2 := getListener()
|
|
||||||
runServer(server2)
|
|
||||||
defer server2.Close()
|
|
||||||
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
conn, err := net.ListenUDP("udp", addr)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
done1 := make(chan struct{})
|
|
||||||
done2 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
dial(conn, server1.Addr())
|
|
||||||
close(done1)
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
dial(conn, server2.Addr())
|
|
||||||
close(done2)
|
|
||||||
}()
|
|
||||||
timeout := 30 * time.Second
|
|
||||||
if testlog.Debug() {
|
|
||||||
timeout = time.Minute
|
|
||||||
}
|
|
||||||
Eventually(done1, timeout).Should(BeClosed())
|
|
||||||
Eventually(done2, timeout).Should(BeClosed())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("multiplexing server and client on the same conn", func() {
|
|
||||||
It("connects to itself", func() {
|
|
||||||
if version != protocol.VersionTLS {
|
|
||||||
Skip("Connecting to itself only works with IETF QUIC.")
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
conn, err := net.ListenUDP("udp", addr)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
server, err := quic.Listen(
|
|
||||||
conn,
|
|
||||||
testdata.GetTLSConfig(),
|
|
||||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
runServer(server)
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
dial(conn, server.Addr())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
timeout := 30 * time.Second
|
|
||||||
if testlog.Debug() {
|
|
||||||
timeout = time.Minute
|
|
||||||
}
|
|
||||||
Eventually(done, timeout).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("runs a server and client on the same conn", func() {
|
|
||||||
if os.Getenv("CI") == "true" {
|
|
||||||
Skip("This test is flaky on CIs, see see https://github.com/golang/go/issues/17677.")
|
|
||||||
}
|
|
||||||
addr1, err := net.ResolveUDPAddr("udp", "localhost:0")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
conn1, err := net.ListenUDP("udp", addr1)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer conn1.Close()
|
|
||||||
|
|
||||||
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
conn2, err := net.ListenUDP("udp", addr2)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer conn2.Close()
|
|
||||||
|
|
||||||
server1, err := quic.Listen(
|
|
||||||
conn1,
|
|
||||||
testdata.GetTLSConfig(),
|
|
||||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
runServer(server1)
|
|
||||||
defer server1.Close()
|
|
||||||
|
|
||||||
server2, err := quic.Listen(
|
|
||||||
conn2,
|
|
||||||
testdata.GetTLSConfig(),
|
|
||||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
runServer(server2)
|
|
||||||
defer server2.Close()
|
|
||||||
|
|
||||||
done1 := make(chan struct{})
|
|
||||||
done2 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
dial(conn2, server1.Addr())
|
|
||||||
close(done1)
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
dial(conn1, server2.Addr())
|
|
||||||
close(done2)
|
|
||||||
}()
|
|
||||||
timeout := 30 * time.Second
|
|
||||||
if testlog.Debug() {
|
|
||||||
timeout = time.Minute
|
|
||||||
}
|
|
||||||
Eventually(done1, timeout).Should(BeClosed())
|
|
||||||
Eventually(done2, timeout).Should(BeClosed())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
@ -1,83 +0,0 @@
|
|||||||
package self
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("non-zero RTT", func() {
|
|
||||||
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
|
|
||||||
version := v
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
|
||||||
roundTrips := [...]time.Duration{
|
|
||||||
10 * time.Millisecond,
|
|
||||||
50 * time.Millisecond,
|
|
||||||
100 * time.Millisecond,
|
|
||||||
200 * time.Millisecond,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range roundTrips {
|
|
||||||
rtt := r
|
|
||||||
|
|
||||||
It(fmt.Sprintf("downloads a message with %s RTT", rtt), func() {
|
|
||||||
ln, err := quic.ListenAddr(
|
|
||||||
"localhost:0",
|
|
||||||
testdata.GetTLSConfig(),
|
|
||||||
&quic.Config{
|
|
||||||
Versions: []protocol.VersionNumber{version},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
sess, err := ln.Accept()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
str, err := sess.OpenStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = str.Write(testserver.PRData)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
str.Close()
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
|
||||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", version, &quicproxy.Opts{
|
|
||||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
|
||||||
DelayPacket: func(d quicproxy.Direction, p uint64) time.Duration {
|
|
||||||
return rtt / 2
|
|
||||||
},
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer proxy.Close()
|
|
||||||
|
|
||||||
sess, err := quic.DialAddr(
|
|
||||||
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
|
|
||||||
nil,
|
|
||||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
str, err := sess.AcceptStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
data, err := ioutil.ReadAll(str)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(data).To(Equal(testserver.PRData))
|
|
||||||
sess.Close()
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
@ -1,20 +0,0 @@
|
|||||||
package self_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math/rand"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSelf(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "Self integration tests")
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = BeforeSuite(func() {
|
|
||||||
rand.Seed(GinkgoRandomSeed())
|
|
||||||
})
|
|
@ -1,152 +0,0 @@
|
|||||||
package self_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Bidirectional streams", func() {
|
|
||||||
const numStreams = 300
|
|
||||||
|
|
||||||
var (
|
|
||||||
server quic.Listener
|
|
||||||
serverAddr string
|
|
||||||
qconf *quic.Config
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, v := range []protocol.VersionNumber{protocol.VersionTLS} {
|
|
||||||
version := v
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("with QUIC %s", version), func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
var err error
|
|
||||||
qconf = &quic.Config{
|
|
||||||
Versions: []protocol.VersionNumber{version},
|
|
||||||
MaxIncomingStreams: 0,
|
|
||||||
}
|
|
||||||
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), qconf)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
serverAddr = fmt.Sprintf("quic.clemente.io:%d", server.Addr().(*net.UDPAddr).Port)
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
server.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
runSendingPeer := func(sess quic.Session) {
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(numStreams)
|
|
||||||
for i := 0; i < numStreams; i++ {
|
|
||||||
str, err := sess.OpenStreamSync()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
data := testserver.GeneratePRData(25 * i)
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
_, err := str.Write(data)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(str.Close()).To(Succeed())
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
defer wg.Done()
|
|
||||||
dataRead, err := ioutil.ReadAll(str)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(dataRead).To(Equal(data))
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
runReceivingPeer := func(sess quic.Session) {
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(numStreams)
|
|
||||||
for i := 0; i < numStreams; i++ {
|
|
||||||
str, err := sess.AcceptStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
defer wg.Done()
|
|
||||||
// shouldn't use io.Copy here
|
|
||||||
// we should read from the stream as early as possible, to free flow control credit
|
|
||||||
data, err := ioutil.ReadAll(str)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = str.Write(data)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(str.Close()).To(Succeed())
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() {
|
|
||||||
var sess quic.Session
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
var err error
|
|
||||||
sess, err = server.Accept()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
runReceivingPeer(sess)
|
|
||||||
}()
|
|
||||||
|
|
||||||
client, err := quic.DialAddr(serverAddr, nil, qconf)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
runSendingPeer(client)
|
|
||||||
})
|
|
||||||
|
|
||||||
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
sess, err := server.Accept()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
runSendingPeer(sess)
|
|
||||||
sess.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
client, err := quic.DialAddr(serverAddr, nil, qconf)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
runReceivingPeer(client)
|
|
||||||
Eventually(client.Context().Done()).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It(fmt.Sprintf("client and server opening %d each and sending data to the peer", numStreams), func() {
|
|
||||||
done1 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
sess, err := server.Accept()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
runReceivingPeer(sess)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
runSendingPeer(sess)
|
|
||||||
<-done
|
|
||||||
close(done1)
|
|
||||||
}()
|
|
||||||
|
|
||||||
client, err := quic.DialAddr(serverAddr, nil, qconf)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
done2 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
runSendingPeer(client)
|
|
||||||
close(done2)
|
|
||||||
}()
|
|
||||||
runReceivingPeer(client)
|
|
||||||
<-done1
|
|
||||||
<-done2
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
@ -1,132 +0,0 @@
|
|||||||
package self_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Unidirectional Streams", func() {
|
|
||||||
const numStreams = 500
|
|
||||||
|
|
||||||
var (
|
|
||||||
server quic.Listener
|
|
||||||
serverAddr string
|
|
||||||
qconf *quic.Config
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
var err error
|
|
||||||
qconf = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
|
|
||||||
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), qconf)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
serverAddr = fmt.Sprintf("quic.clemente.io:%d", server.Addr().(*net.UDPAddr).Port)
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
server.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
dataForStream := func(id protocol.StreamID) []byte {
|
|
||||||
return testserver.GeneratePRData(10 * int(id))
|
|
||||||
}
|
|
||||||
|
|
||||||
runSendingPeer := func(sess quic.Session) {
|
|
||||||
for i := 0; i < numStreams; i++ {
|
|
||||||
str, err := sess.OpenUniStreamSync()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
_, err := str.Write(dataForStream(str.StreamID()))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(str.Close()).To(Succeed())
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
runReceivingPeer := func(sess quic.Session) {
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(numStreams)
|
|
||||||
for i := 0; i < numStreams; i++ {
|
|
||||||
str, err := sess.AcceptUniStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
defer wg.Done()
|
|
||||||
data, err := ioutil.ReadAll(str)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(data).To(Equal(dataForStream(str.StreamID())))
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() {
|
|
||||||
var sess quic.Session
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
var err error
|
|
||||||
sess, err = server.Accept()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
runReceivingPeer(sess)
|
|
||||||
sess.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
client, err := quic.DialAddr(serverAddr, nil, qconf)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
runSendingPeer(client)
|
|
||||||
<-client.Context().Done()
|
|
||||||
})
|
|
||||||
|
|
||||||
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
sess, err := server.Accept()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
runSendingPeer(sess)
|
|
||||||
}()
|
|
||||||
|
|
||||||
client, err := quic.DialAddr(serverAddr, nil, qconf)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
runReceivingPeer(client)
|
|
||||||
})
|
|
||||||
|
|
||||||
It(fmt.Sprintf("client and server opening %d streams each and sending data to the peer", numStreams), func() {
|
|
||||||
done1 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
sess, err := server.Accept()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
runReceivingPeer(sess)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
runSendingPeer(sess)
|
|
||||||
<-done
|
|
||||||
close(done1)
|
|
||||||
}()
|
|
||||||
|
|
||||||
client, err := quic.DialAddr(serverAddr, nil, qconf)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
done2 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
runSendingPeer(client)
|
|
||||||
close(done2)
|
|
||||||
}()
|
|
||||||
runReceivingPeer(client)
|
|
||||||
<-done1
|
|
||||||
<-done2
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,270 +0,0 @@
|
|||||||
package quicproxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Connection is a UDP connection
|
|
||||||
type connection struct {
|
|
||||||
ClientAddr *net.UDPAddr // Address of the client
|
|
||||||
ServerConn *net.UDPConn // UDP connection to server
|
|
||||||
|
|
||||||
incomingPacketCounter uint64
|
|
||||||
outgoingPacketCounter uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
// Direction is the direction a packet is sent.
|
|
||||||
type Direction int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// DirectionIncoming is the direction from the client to the server.
|
|
||||||
DirectionIncoming Direction = iota
|
|
||||||
// DirectionOutgoing is the direction from the server to the client.
|
|
||||||
DirectionOutgoing
|
|
||||||
// DirectionBoth is both incoming and outgoing
|
|
||||||
DirectionBoth
|
|
||||||
)
|
|
||||||
|
|
||||||
func (d Direction) String() string {
|
|
||||||
switch d {
|
|
||||||
case DirectionIncoming:
|
|
||||||
return "incoming"
|
|
||||||
case DirectionOutgoing:
|
|
||||||
return "outgoing"
|
|
||||||
case DirectionBoth:
|
|
||||||
return "both"
|
|
||||||
default:
|
|
||||||
panic("unknown direction")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Is says if one direction matches another direction.
|
|
||||||
// For example, incoming matches both incoming and both, but not outgoing.
|
|
||||||
func (d Direction) Is(dir Direction) bool {
|
|
||||||
if d == DirectionBoth || dir == DirectionBoth {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return d == dir
|
|
||||||
}
|
|
||||||
|
|
||||||
// DropCallback is a callback that determines which packet gets dropped.
|
|
||||||
type DropCallback func(dir Direction, packetCount uint64) bool
|
|
||||||
|
|
||||||
// NoDropper doesn't drop packets.
|
|
||||||
var NoDropper DropCallback = func(Direction, uint64) bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// DelayCallback is a callback that determines how much delay to apply to a packet.
|
|
||||||
type DelayCallback func(dir Direction, packetCount uint64) time.Duration
|
|
||||||
|
|
||||||
// NoDelay doesn't apply a delay.
|
|
||||||
var NoDelay DelayCallback = func(Direction, uint64) time.Duration {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Opts are proxy options.
|
|
||||||
type Opts struct {
|
|
||||||
// The address this proxy proxies packets to.
|
|
||||||
RemoteAddr string
|
|
||||||
// DropPacket determines whether a packet gets dropped.
|
|
||||||
DropPacket DropCallback
|
|
||||||
// DelayPacket determines how long a packet gets delayed. This allows
|
|
||||||
// simulating a connection with non-zero RTTs.
|
|
||||||
// Note that the RTT is the sum of the delay for the incoming and the outgoing packet.
|
|
||||||
DelayPacket DelayCallback
|
|
||||||
}
|
|
||||||
|
|
||||||
// QuicProxy is a QUIC proxy that can drop and delay packets.
|
|
||||||
type QuicProxy struct {
|
|
||||||
mutex sync.Mutex
|
|
||||||
|
|
||||||
version protocol.VersionNumber
|
|
||||||
|
|
||||||
conn *net.UDPConn
|
|
||||||
serverAddr *net.UDPAddr
|
|
||||||
|
|
||||||
dropPacket DropCallback
|
|
||||||
delayPacket DelayCallback
|
|
||||||
|
|
||||||
// Mapping from client addresses (as host:port) to connection
|
|
||||||
clientDict map[string]*connection
|
|
||||||
|
|
||||||
logger utils.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewQuicProxy creates a new UDP proxy
|
|
||||||
func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*QuicProxy, error) {
|
|
||||||
if opts == nil {
|
|
||||||
opts = &Opts{}
|
|
||||||
}
|
|
||||||
laddr, err := net.ResolveUDPAddr("udp", local)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conn, err := net.ListenUDP("udp", laddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
raddr, err := net.ResolveUDPAddr("udp", opts.RemoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
packetDropper := NoDropper
|
|
||||||
if opts.DropPacket != nil {
|
|
||||||
packetDropper = opts.DropPacket
|
|
||||||
}
|
|
||||||
|
|
||||||
packetDelayer := NoDelay
|
|
||||||
if opts.DelayPacket != nil {
|
|
||||||
packetDelayer = opts.DelayPacket
|
|
||||||
}
|
|
||||||
|
|
||||||
p := QuicProxy{
|
|
||||||
clientDict: make(map[string]*connection),
|
|
||||||
conn: conn,
|
|
||||||
serverAddr: raddr,
|
|
||||||
dropPacket: packetDropper,
|
|
||||||
delayPacket: packetDelayer,
|
|
||||||
version: version,
|
|
||||||
logger: utils.DefaultLogger.WithPrefix("proxy"),
|
|
||||||
}
|
|
||||||
|
|
||||||
p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr)
|
|
||||||
go p.runProxy()
|
|
||||||
return &p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close stops the UDP Proxy
|
|
||||||
func (p *QuicProxy) Close() error {
|
|
||||||
p.mutex.Lock()
|
|
||||||
defer p.mutex.Unlock()
|
|
||||||
for _, c := range p.clientDict {
|
|
||||||
if err := c.ServerConn.Close(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// LocalAddr is the address the proxy is listening on.
|
|
||||||
func (p *QuicProxy) LocalAddr() net.Addr {
|
|
||||||
return p.conn.LocalAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
// LocalPort is the UDP port number the proxy is listening on.
|
|
||||||
func (p *QuicProxy) LocalPort() int {
|
|
||||||
return p.conn.LocalAddr().(*net.UDPAddr).Port
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
|
|
||||||
srvudp, err := net.DialUDP("udp", nil, p.serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &connection{
|
|
||||||
ClientAddr: cliAddr,
|
|
||||||
ServerConn: srvudp,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// runProxy listens on the proxy address and handles incoming packets.
|
|
||||||
func (p *QuicProxy) runProxy() error {
|
|
||||||
for {
|
|
||||||
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
|
||||||
n, cliaddr, err := p.conn.ReadFromUDP(buffer)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
raw := buffer[0:n]
|
|
||||||
|
|
||||||
saddr := cliaddr.String()
|
|
||||||
p.mutex.Lock()
|
|
||||||
conn, ok := p.clientDict[saddr]
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
conn, err = p.newConnection(cliaddr)
|
|
||||||
if err != nil {
|
|
||||||
p.mutex.Unlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
p.clientDict[saddr] = conn
|
|
||||||
go p.runConnection(conn)
|
|
||||||
}
|
|
||||||
p.mutex.Unlock()
|
|
||||||
|
|
||||||
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
|
|
||||||
|
|
||||||
if p.dropPacket(DirectionIncoming, packetCount) {
|
|
||||||
if p.logger.Debug() {
|
|
||||||
p.logger.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send the packet to the server
|
|
||||||
delay := p.delayPacket(DirectionIncoming, packetCount)
|
|
||||||
if delay != 0 {
|
|
||||||
if p.logger.Debug() {
|
|
||||||
p.logger.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay)
|
|
||||||
}
|
|
||||||
time.AfterFunc(delay, func() {
|
|
||||||
// TODO: handle error
|
|
||||||
_, _ = conn.ServerConn.Write(raw)
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
if p.logger.Debug() {
|
|
||||||
p.logger.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr())
|
|
||||||
}
|
|
||||||
if _, err := conn.ServerConn.Write(raw); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// runConnection handles packets from server to a single client
|
|
||||||
func (p *QuicProxy) runConnection(conn *connection) error {
|
|
||||||
for {
|
|
||||||
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
|
||||||
n, err := conn.ServerConn.Read(buffer)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
raw := buffer[0:n]
|
|
||||||
|
|
||||||
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
|
|
||||||
|
|
||||||
if p.dropPacket(DirectionOutgoing, packetCount) {
|
|
||||||
if p.logger.Debug() {
|
|
||||||
p.logger.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
delay := p.delayPacket(DirectionOutgoing, packetCount)
|
|
||||||
if delay != 0 {
|
|
||||||
if p.logger.Debug() {
|
|
||||||
p.logger.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay)
|
|
||||||
}
|
|
||||||
time.AfterFunc(delay, func() {
|
|
||||||
// TODO: handle error
|
|
||||||
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr)
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
if p.logger.Debug() {
|
|
||||||
p.logger.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr)
|
|
||||||
}
|
|
||||||
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,13 +0,0 @@
|
|||||||
package quicproxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestQuicGo(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "QUIC Proxy")
|
|
||||||
}
|
|
@ -1,394 +0,0 @@
|
|||||||
package quicproxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"net"
|
|
||||||
"runtime/pprof"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
type packetData []byte
|
|
||||||
|
|
||||||
var _ = Describe("QUIC Proxy", func() {
|
|
||||||
makePacket := func(p protocol.PacketNumber, payload []byte) []byte {
|
|
||||||
b := &bytes.Buffer{}
|
|
||||||
hdr := wire.Header{
|
|
||||||
PacketNumber: p,
|
|
||||||
PacketNumberLen: protocol.PacketNumberLen6,
|
|
||||||
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37},
|
|
||||||
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37},
|
|
||||||
}
|
|
||||||
hdr.Write(b, protocol.PerspectiveServer, protocol.VersionWhatever)
|
|
||||||
raw := b.Bytes()
|
|
||||||
raw = append(raw, payload...)
|
|
||||||
return raw
|
|
||||||
}
|
|
||||||
|
|
||||||
Context("Proxy setup and teardown", func() {
|
|
||||||
It("sets up the UDPProxy", func() {
|
|
||||||
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(proxy.clientDict).To(HaveLen(0))
|
|
||||||
|
|
||||||
// check that the proxy port is in use
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", "localhost:"+strconv.Itoa(proxy.LocalPort()))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = net.ListenUDP("udp", addr)
|
|
||||||
Expect(err).To(MatchError(fmt.Sprintf("listen udp 127.0.0.1:%d: bind: address already in use", proxy.LocalPort())))
|
|
||||||
Expect(proxy.Close()).To(Succeed()) // stopping is tested in the next test
|
|
||||||
})
|
|
||||||
|
|
||||||
It("stops the UDPProxy", func() {
|
|
||||||
isProxyRunning := func() bool {
|
|
||||||
var b bytes.Buffer
|
|
||||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
|
||||||
return strings.Contains(b.String(), "proxy.(*QuicProxy).runProxy")
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
port := proxy.LocalPort()
|
|
||||||
Expect(isProxyRunning()).To(BeTrue())
|
|
||||||
err = proxy.Close()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
// check that the proxy port is not in use anymore
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", "localhost:"+strconv.Itoa(port))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
// sometimes it takes a while for the OS to free the port
|
|
||||||
Eventually(func() error {
|
|
||||||
ln, err := net.ListenUDP("udp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ln.Close()
|
|
||||||
return nil
|
|
||||||
}).ShouldNot(HaveOccurred())
|
|
||||||
Eventually(isProxyRunning).Should(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("stops listening for proxied connections", func() {
|
|
||||||
isConnRunning := func() bool {
|
|
||||||
var b bytes.Buffer
|
|
||||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
|
||||||
return strings.Contains(b.String(), "proxy.(*QuicProxy).runConnection")
|
|
||||||
}
|
|
||||||
|
|
||||||
serverAddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
serverConn, err := net.ListenUDP("udp", serverAddr)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer serverConn.Close()
|
|
||||||
|
|
||||||
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, &Opts{RemoteAddr: serverConn.LocalAddr().String()})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(isConnRunning()).To(BeFalse())
|
|
||||||
|
|
||||||
// check that the proxy port is not in use anymore
|
|
||||||
conn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = conn.Write(makePacket(1, []byte("foobar")))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Eventually(isConnRunning).Should(BeTrue())
|
|
||||||
Expect(proxy.Close()).To(Succeed())
|
|
||||||
Eventually(isConnRunning).Should(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("has the correct LocalAddr and LocalPort", func() {
|
|
||||||
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
Expect(proxy.LocalAddr().String()).To(Equal("127.0.0.1:" + strconv.Itoa(proxy.LocalPort())))
|
|
||||||
Expect(proxy.LocalPort()).ToNot(BeZero())
|
|
||||||
|
|
||||||
Expect(proxy.Close()).To(Succeed())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Proxy tests", func() {
|
|
||||||
var (
|
|
||||||
serverConn *net.UDPConn
|
|
||||||
serverNumPacketsSent int32
|
|
||||||
serverReceivedPackets chan packetData
|
|
||||||
clientConn *net.UDPConn
|
|
||||||
proxy *QuicProxy
|
|
||||||
)
|
|
||||||
|
|
||||||
startProxy := func(opts *Opts) {
|
|
||||||
var err error
|
|
||||||
proxy, err = NewQuicProxy("localhost:0", protocol.VersionWhatever, opts)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
clientConn, err = net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
|
|
||||||
// getClientDict returns a copy of the clientDict map
|
|
||||||
getClientDict := func() map[string]*connection {
|
|
||||||
d := make(map[string]*connection)
|
|
||||||
proxy.mutex.Lock()
|
|
||||||
defer proxy.mutex.Unlock()
|
|
||||||
for k, v := range proxy.clientDict {
|
|
||||||
d[k] = v
|
|
||||||
}
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
serverReceivedPackets = make(chan packetData, 100)
|
|
||||||
atomic.StoreInt32(&serverNumPacketsSent, 0)
|
|
||||||
|
|
||||||
// setup a dump UDP server
|
|
||||||
// in production this would be a QUIC server
|
|
||||||
raddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
serverConn, err = net.ListenUDP("udp", raddr)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
buf := make([]byte, protocol.MaxReceivePacketSize)
|
|
||||||
// the ReadFromUDP will error as soon as the UDP conn is closed
|
|
||||||
n, addr, err2 := serverConn.ReadFromUDP(buf)
|
|
||||||
if err2 != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
data := buf[0:n]
|
|
||||||
serverReceivedPackets <- packetData(data)
|
|
||||||
// echo the packet
|
|
||||||
serverConn.WriteToUDP(data, addr)
|
|
||||||
atomic.AddInt32(&serverNumPacketsSent, 1)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
err := proxy.Close()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = serverConn.Close()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = clientConn.Close()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("no packet drop", func() {
|
|
||||||
It("relays packets from the client to the server", func() {
|
|
||||||
startProxy(&Opts{RemoteAddr: serverConn.LocalAddr().String()})
|
|
||||||
// send the first packet
|
|
||||||
_, err := clientConn.Write(makePacket(1, []byte("foobar")))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
Eventually(getClientDict).Should(HaveLen(1))
|
|
||||||
var conn *connection
|
|
||||||
for _, conn = range getClientDict() {
|
|
||||||
Eventually(func() uint64 { return atomic.LoadUint64(&conn.incomingPacketCounter) }).Should(Equal(uint64(1)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// send the second packet
|
|
||||||
_, err = clientConn.Write(makePacket(2, []byte("decafbad")))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
Eventually(serverReceivedPackets).Should(HaveLen(2))
|
|
||||||
Expect(getClientDict()).To(HaveLen(1))
|
|
||||||
Expect(string(<-serverReceivedPackets)).To(ContainSubstring("foobar"))
|
|
||||||
Expect(string(<-serverReceivedPackets)).To(ContainSubstring("decafbad"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("relays packets from the server to the client", func() {
|
|
||||||
startProxy(&Opts{RemoteAddr: serverConn.LocalAddr().String()})
|
|
||||||
// send the first packet
|
|
||||||
_, err := clientConn.Write(makePacket(1, []byte("foobar")))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
Eventually(getClientDict).Should(HaveLen(1))
|
|
||||||
var key string
|
|
||||||
var conn *connection
|
|
||||||
for key, conn = range getClientDict() {
|
|
||||||
Eventually(func() uint64 { return atomic.LoadUint64(&conn.outgoingPacketCounter) }).Should(Equal(uint64(1)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// send the second packet
|
|
||||||
_, err = clientConn.Write(makePacket(2, []byte("decafbad")))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
Expect(getClientDict()).To(HaveLen(1))
|
|
||||||
Eventually(func() uint64 {
|
|
||||||
conn := getClientDict()[key]
|
|
||||||
return atomic.LoadUint64(&conn.outgoingPacketCounter)
|
|
||||||
}).Should(BeEquivalentTo(2))
|
|
||||||
|
|
||||||
clientReceivedPackets := make(chan packetData, 2)
|
|
||||||
// receive the packets echoed by the server on client side
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
buf := make([]byte, protocol.MaxReceivePacketSize)
|
|
||||||
// the ReadFromUDP will error as soon as the UDP conn is closed
|
|
||||||
n, _, err2 := clientConn.ReadFromUDP(buf)
|
|
||||||
if err2 != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
data := buf[0:n]
|
|
||||||
clientReceivedPackets <- packetData(data)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
Eventually(serverReceivedPackets).Should(HaveLen(2))
|
|
||||||
Expect(atomic.LoadInt32(&serverNumPacketsSent)).To(BeEquivalentTo(2))
|
|
||||||
Eventually(clientReceivedPackets).Should(HaveLen(2))
|
|
||||||
Expect(string(<-clientReceivedPackets)).To(ContainSubstring("foobar"))
|
|
||||||
Expect(string(<-clientReceivedPackets)).To(ContainSubstring("decafbad"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Drop Callbacks", func() {
|
|
||||||
It("drops incoming packets", func() {
|
|
||||||
opts := &Opts{
|
|
||||||
RemoteAddr: serverConn.LocalAddr().String(),
|
|
||||||
DropPacket: func(d Direction, p uint64) bool {
|
|
||||||
return d == DirectionIncoming && p%2 == 0
|
|
||||||
},
|
|
||||||
}
|
|
||||||
startProxy(opts)
|
|
||||||
|
|
||||||
for i := 1; i <= 6; i++ {
|
|
||||||
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
Eventually(serverReceivedPackets).Should(HaveLen(3))
|
|
||||||
Consistently(serverReceivedPackets).Should(HaveLen(3))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("drops outgoing packets", func() {
|
|
||||||
const numPackets = 6
|
|
||||||
opts := &Opts{
|
|
||||||
RemoteAddr: serverConn.LocalAddr().String(),
|
|
||||||
DropPacket: func(d Direction, p uint64) bool {
|
|
||||||
return d == DirectionOutgoing && p%2 == 0
|
|
||||||
},
|
|
||||||
}
|
|
||||||
startProxy(opts)
|
|
||||||
|
|
||||||
clientReceivedPackets := make(chan packetData, numPackets)
|
|
||||||
// receive the packets echoed by the server on client side
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
buf := make([]byte, protocol.MaxReceivePacketSize)
|
|
||||||
// the ReadFromUDP will error as soon as the UDP conn is closed
|
|
||||||
n, _, err2 := clientConn.ReadFromUDP(buf)
|
|
||||||
if err2 != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
data := buf[0:n]
|
|
||||||
clientReceivedPackets <- packetData(data)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for i := 1; i <= numPackets; i++ {
|
|
||||||
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
|
|
||||||
Eventually(clientReceivedPackets).Should(HaveLen(numPackets / 2))
|
|
||||||
Consistently(clientReceivedPackets).Should(HaveLen(numPackets / 2))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Delay Callback", func() {
|
|
||||||
expectDelay := func(startTime time.Time, rtt time.Duration, numRTTs int) {
|
|
||||||
expectedReceiveTime := startTime.Add(time.Duration(numRTTs) * rtt)
|
|
||||||
Expect(time.Now()).To(SatisfyAll(
|
|
||||||
BeTemporally(">=", expectedReceiveTime),
|
|
||||||
BeTemporally("<", expectedReceiveTime.Add(rtt/2)),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
It("delays incoming packets", func() {
|
|
||||||
delay := 300 * time.Millisecond
|
|
||||||
opts := &Opts{
|
|
||||||
RemoteAddr: serverConn.LocalAddr().String(),
|
|
||||||
// delay packet 1 by 200 ms
|
|
||||||
// delay packet 2 by 400 ms
|
|
||||||
// ...
|
|
||||||
DelayPacket: func(d Direction, p uint64) time.Duration {
|
|
||||||
if d == DirectionOutgoing {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return time.Duration(p) * delay
|
|
||||||
},
|
|
||||||
}
|
|
||||||
startProxy(opts)
|
|
||||||
|
|
||||||
// send 3 packets
|
|
||||||
start := time.Now()
|
|
||||||
for i := 1; i <= 3; i++ {
|
|
||||||
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
Eventually(serverReceivedPackets).Should(HaveLen(1))
|
|
||||||
expectDelay(start, delay, 1)
|
|
||||||
Eventually(serverReceivedPackets).Should(HaveLen(2))
|
|
||||||
expectDelay(start, delay, 2)
|
|
||||||
Eventually(serverReceivedPackets).Should(HaveLen(3))
|
|
||||||
expectDelay(start, delay, 3)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("delays outgoing packets", func() {
|
|
||||||
const numPackets = 3
|
|
||||||
delay := 300 * time.Millisecond
|
|
||||||
opts := &Opts{
|
|
||||||
RemoteAddr: serverConn.LocalAddr().String(),
|
|
||||||
// delay packet 1 by 200 ms
|
|
||||||
// delay packet 2 by 400 ms
|
|
||||||
// ...
|
|
||||||
DelayPacket: func(d Direction, p uint64) time.Duration {
|
|
||||||
if d == DirectionIncoming {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return time.Duration(p) * delay
|
|
||||||
},
|
|
||||||
}
|
|
||||||
startProxy(opts)
|
|
||||||
|
|
||||||
clientReceivedPackets := make(chan packetData, numPackets)
|
|
||||||
// receive the packets echoed by the server on client side
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
buf := make([]byte, protocol.MaxReceivePacketSize)
|
|
||||||
// the ReadFromUDP will error as soon as the UDP conn is closed
|
|
||||||
n, _, err2 := clientConn.ReadFromUDP(buf)
|
|
||||||
if err2 != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
data := buf[0:n]
|
|
||||||
clientReceivedPackets <- packetData(data)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
for i := 1; i <= numPackets; i++ {
|
|
||||||
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
// the packets should have arrived immediately at the server
|
|
||||||
Eventually(serverReceivedPackets).Should(HaveLen(3))
|
|
||||||
expectDelay(start, delay, 0)
|
|
||||||
Eventually(clientReceivedPackets).Should(HaveLen(1))
|
|
||||||
expectDelay(start, delay, 1)
|
|
||||||
Eventually(clientReceivedPackets).Should(HaveLen(2))
|
|
||||||
expectDelay(start, delay, 2)
|
|
||||||
Eventually(clientReceivedPackets).Should(HaveLen(3))
|
|
||||||
expectDelay(start, delay, 3)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,46 +0,0 @@
|
|||||||
package testlog
|
|
||||||
|
|
||||||
import (
|
|
||||||
"flag"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
logFileName string // the log file set in the ginkgo flags
|
|
||||||
logFile *os.File
|
|
||||||
)
|
|
||||||
|
|
||||||
// read the logfile command line flag
|
|
||||||
// to set call ginkgo -- -logfile=log.txt
|
|
||||||
func init() {
|
|
||||||
flag.StringVar(&logFileName, "logfile", "", "log file")
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = BeforeEach(func() {
|
|
||||||
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
|
|
||||||
|
|
||||||
if len(logFileName) > 0 {
|
|
||||||
var err error
|
|
||||||
logFile, err = os.Create(logFileName)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
log.SetOutput(logFile)
|
|
||||||
utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
var _ = AfterEach(func() {
|
|
||||||
if len(logFileName) > 0 {
|
|
||||||
_ = logFile.Close()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Debug says if this test is being logged
|
|
||||||
func Debug() bool {
|
|
||||||
return len(logFileName) > 0
|
|
||||||
}
|
|
@ -1,119 +0,0 @@
|
|||||||
package testserver
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/h2quic"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
dataLen = 500 * 1024 // 500 KB
|
|
||||||
dataLenLong = 50 * 1024 * 1024 // 50 MB
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// PRData contains dataLen bytes of pseudo-random data.
|
|
||||||
PRData = GeneratePRData(dataLen)
|
|
||||||
// PRDataLong contains dataLenLong bytes of pseudo-random data.
|
|
||||||
PRDataLong = GeneratePRData(dataLenLong)
|
|
||||||
|
|
||||||
server *h2quic.Server
|
|
||||||
stoppedServing chan struct{}
|
|
||||||
port string
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
http.HandleFunc("/prdata", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
sl := r.URL.Query().Get("len")
|
|
||||||
if sl != "" {
|
|
||||||
var err error
|
|
||||||
l, err := strconv.Atoi(sl)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
_, err = w.Write(GeneratePRData(l))
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
} else {
|
|
||||||
_, err := w.Write(PRData)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
http.HandleFunc("/prdatalong", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
_, err := w.Write(PRDataLong)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
http.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
_, err := io.WriteString(w, "Hello, World!\n")
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
http.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
_, err = w.Write(body)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// See https://en.wikipedia.org/wiki/Lehmer_random_number_generator
|
|
||||||
func GeneratePRData(l int) []byte {
|
|
||||||
res := make([]byte, l)
|
|
||||||
seed := uint64(1)
|
|
||||||
for i := 0; i < l; i++ {
|
|
||||||
seed = seed * 48271 % 2147483647
|
|
||||||
res[i] = byte(seed)
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// StartQuicServer starts a h2quic.Server.
|
|
||||||
// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used.
|
|
||||||
func StartQuicServer(versions []protocol.VersionNumber) {
|
|
||||||
server = &h2quic.Server{
|
|
||||||
Server: &http.Server{
|
|
||||||
TLSConfig: testdata.GetTLSConfig(),
|
|
||||||
},
|
|
||||||
QuicConfig: &quic.Config{
|
|
||||||
Versions: versions,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
conn, err := net.ListenUDP("udp", addr)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
|
|
||||||
|
|
||||||
stoppedServing = make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
server.Serve(conn)
|
|
||||||
close(stoppedServing)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// StopQuicServer stops the h2quic.Server.
|
|
||||||
func StopQuicServer() {
|
|
||||||
Expect(server.Close()).NotTo(HaveOccurred())
|
|
||||||
Eventually(stoppedServing).Should(BeClosed())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Port returns the UDP port of the QUIC server.
|
|
||||||
func Port() string {
|
|
||||||
return port
|
|
||||||
}
|
|
@ -1,24 +0,0 @@
|
|||||||
package ackhandler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCrypto(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "AckHandler Suite")
|
|
||||||
}
|
|
||||||
|
|
||||||
var mockCtrl *gomock.Controller
|
|
||||||
|
|
||||||
var _ = BeforeEach(func() {
|
|
||||||
mockCtrl = gomock.NewController(GinkgoT())
|
|
||||||
})
|
|
||||||
|
|
||||||
var _ = AfterEach(func() {
|
|
||||||
mockCtrl.Finish()
|
|
||||||
})
|
|
@ -1,382 +0,0 @@
|
|||||||
package ackhandler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("receivedPacketHandler", func() {
|
|
||||||
var (
|
|
||||||
handler *receivedPacketHandler
|
|
||||||
rttStats *congestion.RTTStats
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
rttStats = &congestion.RTTStats{}
|
|
||||||
handler = NewReceivedPacketHandler(rttStats, utils.DefaultLogger, protocol.VersionWhatever).(*receivedPacketHandler)
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("accepting packets", func() {
|
|
||||||
It("handles a packet that arrives late", func() {
|
|
||||||
err := handler.ReceivedPacket(protocol.PacketNumber(1), time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.ReceivedPacket(protocol.PacketNumber(3), time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.ReceivedPacket(protocol.PacketNumber(2), time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("saves the time when each packet arrived", func() {
|
|
||||||
err := handler.ReceivedPacket(protocol.PacketNumber(3), time.Now(), true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("updates the largestObserved and the largestObservedReceivedTime", func() {
|
|
||||||
now := time.Now()
|
|
||||||
handler.largestObserved = 3
|
|
||||||
handler.largestObservedReceivedTime = now.Add(-1 * time.Second)
|
|
||||||
err := handler.ReceivedPacket(5, now, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
|
|
||||||
Expect(handler.largestObservedReceivedTime).To(Equal(now))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't update the largestObserved and the largestObservedReceivedTime for a belated packet", func() {
|
|
||||||
now := time.Now()
|
|
||||||
timestamp := now.Add(-1 * time.Second)
|
|
||||||
handler.largestObserved = 5
|
|
||||||
handler.largestObservedReceivedTime = timestamp
|
|
||||||
err := handler.ReceivedPacket(4, now, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
|
|
||||||
Expect(handler.largestObservedReceivedTime).To(Equal(timestamp))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("passes on errors from receivedPacketHistory", func() {
|
|
||||||
var err error
|
|
||||||
for i := protocol.PacketNumber(0); i < 5*protocol.MaxTrackedReceivedAckRanges; i++ {
|
|
||||||
err = handler.ReceivedPacket(2*i+1, time.Time{}, true)
|
|
||||||
// this will eventually return an error
|
|
||||||
// details about when exactly the receivedPacketHistory errors are tested there
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Expect(err).To(MatchError(errTooManyOutstandingReceivedAckRanges))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("ACKs", func() {
|
|
||||||
Context("queueing ACKs", func() {
|
|
||||||
receiveAndAck10Packets := func() {
|
|
||||||
for i := 1; i <= 10; i++ {
|
|
||||||
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
Expect(handler.GetAckFrame()).ToNot(BeNil())
|
|
||||||
Expect(handler.ackQueued).To(BeFalse())
|
|
||||||
}
|
|
||||||
|
|
||||||
receiveAndAckPacketsUntilAckDecimation := func() {
|
|
||||||
for i := 1; i <= minReceivedBeforeAckDecimation; i++ {
|
|
||||||
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
Expect(handler.GetAckFrame()).ToNot(BeNil())
|
|
||||||
Expect(handler.ackQueued).To(BeFalse())
|
|
||||||
}
|
|
||||||
|
|
||||||
It("always queues an ACK for the first packet", func() {
|
|
||||||
err := handler.ReceivedPacket(1, time.Time{}, false)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handler.ackQueued).To(BeTrue())
|
|
||||||
Expect(handler.GetAlarmTimeout()).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works with packet number 0", func() {
|
|
||||||
err := handler.ReceivedPacket(0, time.Time{}, false)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handler.ackQueued).To(BeTrue())
|
|
||||||
Expect(handler.GetAlarmTimeout()).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("queues an ACK for every second retransmittable packet at the beginning", func() {
|
|
||||||
receiveAndAck10Packets()
|
|
||||||
p := protocol.PacketNumber(11)
|
|
||||||
for i := 0; i <= 20; i++ {
|
|
||||||
err := handler.ReceivedPacket(p, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handler.ackQueued).To(BeFalse())
|
|
||||||
p++
|
|
||||||
err = handler.ReceivedPacket(p, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handler.ackQueued).To(BeTrue())
|
|
||||||
p++
|
|
||||||
// dequeue the ACK frame
|
|
||||||
Expect(handler.GetAckFrame()).ToNot(BeNil())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("queues an ACK for every 10 retransmittable packet, if they are arriving fast", func() {
|
|
||||||
receiveAndAck10Packets()
|
|
||||||
p := protocol.PacketNumber(10000)
|
|
||||||
for i := 0; i < 9; i++ {
|
|
||||||
err := handler.ReceivedPacket(p, time.Now(), true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handler.ackQueued).To(BeFalse())
|
|
||||||
p++
|
|
||||||
}
|
|
||||||
Expect(handler.GetAlarmTimeout()).NotTo(BeZero())
|
|
||||||
err := handler.ReceivedPacket(p, time.Now(), true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handler.ackQueued).To(BeTrue())
|
|
||||||
Expect(handler.GetAlarmTimeout()).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("only sets the timer when receiving a retransmittable packets", func() {
|
|
||||||
receiveAndAck10Packets()
|
|
||||||
err := handler.ReceivedPacket(11, time.Now(), false)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handler.ackQueued).To(BeFalse())
|
|
||||||
Expect(handler.GetAlarmTimeout()).To(BeZero())
|
|
||||||
rcvTime := time.Now().Add(10 * time.Millisecond)
|
|
||||||
err = handler.ReceivedPacket(12, rcvTime, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handler.ackQueued).To(BeFalse())
|
|
||||||
Expect(handler.GetAlarmTimeout()).To(Equal(rcvTime.Add(ackSendDelay)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("queues an ACK if it was reported missing before", func() {
|
|
||||||
receiveAndAck10Packets()
|
|
||||||
err := handler.ReceivedPacket(11, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.ReceivedPacket(13, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
ack := handler.GetAckFrame() // ACK: 1-11 and 13, missing: 12
|
|
||||||
Expect(ack).ToNot(BeNil())
|
|
||||||
Expect(ack.HasMissingRanges()).To(BeTrue())
|
|
||||||
Expect(handler.ackQueued).To(BeFalse())
|
|
||||||
err = handler.ReceivedPacket(12, time.Time{}, false)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handler.ackQueued).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't queue an ACK if it was reported missing before, but is below the threshold", func() {
|
|
||||||
receiveAndAck10Packets()
|
|
||||||
// 11 is missing
|
|
||||||
err := handler.ReceivedPacket(12, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.ReceivedPacket(13, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
ack := handler.GetAckFrame() // ACK: 1-10, 12-13
|
|
||||||
Expect(ack).ToNot(BeNil())
|
|
||||||
// now receive 11
|
|
||||||
handler.IgnoreBelow(12)
|
|
||||||
err = handler.ReceivedPacket(11, time.Time{}, false)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
ack = handler.GetAckFrame()
|
|
||||||
Expect(ack).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't queue an ACK if the packet closes a gap that was not yet reported", func() {
|
|
||||||
receiveAndAckPacketsUntilAckDecimation()
|
|
||||||
p := protocol.PacketNumber(minReceivedBeforeAckDecimation + 1)
|
|
||||||
err := handler.ReceivedPacket(p+1, time.Now(), true) // p is missing now
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handler.ackQueued).To(BeFalse())
|
|
||||||
Expect(handler.GetAlarmTimeout()).ToNot(BeZero())
|
|
||||||
err = handler.ReceivedPacket(p, time.Now(), true) // p is not missing any more
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handler.ackQueued).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets an ACK alarm after 1/4 RTT if it creates a new missing range", func() {
|
|
||||||
now := time.Now().Add(-time.Hour)
|
|
||||||
rtt := 80 * time.Millisecond
|
|
||||||
rttStats.UpdateRTT(rtt, 0, now)
|
|
||||||
receiveAndAckPacketsUntilAckDecimation()
|
|
||||||
p := protocol.PacketNumber(minReceivedBeforeAckDecimation + 1)
|
|
||||||
for i := p; i < p+6; i++ {
|
|
||||||
err := handler.ReceivedPacket(i, now, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
err := handler.ReceivedPacket(p+10, now, true) // we now know that packets p+7, p+8 and p+9
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(rttStats.MinRTT()).To(Equal(rtt))
|
|
||||||
Expect(handler.ackAlarm.Sub(now)).To(Equal(rtt / 8))
|
|
||||||
ack := handler.GetAckFrame()
|
|
||||||
Expect(ack.HasMissingRanges()).To(BeTrue())
|
|
||||||
Expect(ack).ToNot(BeNil())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("ACK generation", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
handler.ackQueued = true
|
|
||||||
})
|
|
||||||
|
|
||||||
It("generates a simple ACK frame", func() {
|
|
||||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.ReceivedPacket(2, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
ack := handler.GetAckFrame()
|
|
||||||
Expect(ack).ToNot(BeNil())
|
|
||||||
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(2)))
|
|
||||||
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1)))
|
|
||||||
Expect(ack.HasMissingRanges()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("generates an ACK for packet number 0", func() {
|
|
||||||
err := handler.ReceivedPacket(0, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
ack := handler.GetAckFrame()
|
|
||||||
Expect(ack).ToNot(BeNil())
|
|
||||||
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(0)))
|
|
||||||
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0)))
|
|
||||||
Expect(ack.HasMissingRanges()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets the delay time", func() {
|
|
||||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.ReceivedPacket(2, time.Now().Add(-1337*time.Millisecond), true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
ack := handler.GetAckFrame()
|
|
||||||
Expect(ack).ToNot(BeNil())
|
|
||||||
Expect(ack.DelayTime).To(BeNumerically("~", 1337*time.Millisecond, 50*time.Millisecond))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("saves the last sent ACK", func() {
|
|
||||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
ack := handler.GetAckFrame()
|
|
||||||
Expect(ack).ToNot(BeNil())
|
|
||||||
Expect(handler.lastAck).To(Equal(ack))
|
|
||||||
err = handler.ReceivedPacket(2, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
handler.ackQueued = true
|
|
||||||
ack = handler.GetAckFrame()
|
|
||||||
Expect(ack).ToNot(BeNil())
|
|
||||||
Expect(handler.lastAck).To(Equal(ack))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("generates an ACK frame with missing packets", func() {
|
|
||||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.ReceivedPacket(4, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
ack := handler.GetAckFrame()
|
|
||||||
Expect(ack).ToNot(BeNil())
|
|
||||||
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(4)))
|
|
||||||
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1)))
|
|
||||||
Expect(ack.AckRanges).To(Equal([]wire.AckRange{
|
|
||||||
{Smallest: 4, Largest: 4},
|
|
||||||
{Smallest: 1, Largest: 1},
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("generates an ACK for packet number 0 and other packets", func() {
|
|
||||||
err := handler.ReceivedPacket(0, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.ReceivedPacket(1, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.ReceivedPacket(3, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
ack := handler.GetAckFrame()
|
|
||||||
Expect(ack).ToNot(BeNil())
|
|
||||||
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(3)))
|
|
||||||
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0)))
|
|
||||||
Expect(ack.AckRanges).To(Equal([]wire.AckRange{
|
|
||||||
{Smallest: 3, Largest: 3},
|
|
||||||
{Smallest: 0, Largest: 1},
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("accepts packets below the lower limit", func() {
|
|
||||||
handler.IgnoreBelow(6)
|
|
||||||
err := handler.ReceivedPacket(2, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't add delayed packets to the packetHistory", func() {
|
|
||||||
handler.IgnoreBelow(7)
|
|
||||||
err := handler.ReceivedPacket(4, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.ReceivedPacket(10, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
ack := handler.GetAckFrame()
|
|
||||||
Expect(ack).ToNot(BeNil())
|
|
||||||
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(10)))
|
|
||||||
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(10)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("deletes packets from the packetHistory when a lower limit is set", func() {
|
|
||||||
for i := 1; i <= 12; i++ {
|
|
||||||
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
handler.IgnoreBelow(7)
|
|
||||||
// check that the packets were deleted from the receivedPacketHistory by checking the values in an ACK frame
|
|
||||||
ack := handler.GetAckFrame()
|
|
||||||
Expect(ack).ToNot(BeNil())
|
|
||||||
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(12)))
|
|
||||||
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(7)))
|
|
||||||
Expect(ack.HasMissingRanges()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
// TODO: remove this test when dropping support for STOP_WAITINGs
|
|
||||||
It("handles a lower limit of 0", func() {
|
|
||||||
handler.IgnoreBelow(0)
|
|
||||||
err := handler.ReceivedPacket(1337, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
ack := handler.GetAckFrame()
|
|
||||||
Expect(ack).ToNot(BeNil())
|
|
||||||
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(1337)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("resets all counters needed for the ACK queueing decision when sending an ACK", func() {
|
|
||||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
handler.ackAlarm = time.Now().Add(-time.Minute)
|
|
||||||
Expect(handler.GetAckFrame()).ToNot(BeNil())
|
|
||||||
Expect(handler.packetsReceivedSinceLastAck).To(BeZero())
|
|
||||||
Expect(handler.GetAlarmTimeout()).To(BeZero())
|
|
||||||
Expect(handler.retransmittablePacketsReceivedSinceLastAck).To(BeZero())
|
|
||||||
Expect(handler.ackQueued).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't generate an ACK when none is queued and the timer is not set", func() {
|
|
||||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
handler.ackQueued = false
|
|
||||||
handler.ackAlarm = time.Time{}
|
|
||||||
Expect(handler.GetAckFrame()).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() {
|
|
||||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
handler.ackQueued = false
|
|
||||||
handler.ackAlarm = time.Now().Add(time.Minute)
|
|
||||||
Expect(handler.GetAckFrame()).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("generates an ACK when the timer has expired", func() {
|
|
||||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
handler.ackQueued = false
|
|
||||||
handler.ackAlarm = time.Now().Add(-time.Minute)
|
|
||||||
Expect(handler.GetAckFrame()).ToNot(BeNil())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,248 +0,0 @@
|
|||||||
package ackhandler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("receivedPacketHistory", func() {
|
|
||||||
var (
|
|
||||||
hist *receivedPacketHistory
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
hist = newReceivedPacketHistory()
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("ranges", func() {
|
|
||||||
It("adds the first packet", func() {
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(1))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't care about duplicate packets", func() {
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(1))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adds a few consecutive packets", func() {
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
hist.ReceivedPacket(5)
|
|
||||||
hist.ReceivedPacket(6)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(1))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't care about a duplicate packet contained in an existing range", func() {
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
hist.ReceivedPacket(5)
|
|
||||||
hist.ReceivedPacket(6)
|
|
||||||
hist.ReceivedPacket(5)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(1))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("extends a range at the front", func() {
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
hist.ReceivedPacket(3)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(1))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 3, End: 4}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("creates a new range when a packet is lost", func() {
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
hist.ReceivedPacket(6)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(2))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
|
|
||||||
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 6}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("creates a new range in between two ranges", func() {
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
hist.ReceivedPacket(10)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(2))
|
|
||||||
hist.ReceivedPacket(7)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(3))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
|
|
||||||
Expect(hist.ranges.Front().Next().Value).To(Equal(utils.PacketInterval{Start: 7, End: 7}))
|
|
||||||
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("creates a new range before an existing range for a belated packet", func() {
|
|
||||||
hist.ReceivedPacket(6)
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(2))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
|
|
||||||
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 6}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("extends a previous range at the end", func() {
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
hist.ReceivedPacket(7)
|
|
||||||
hist.ReceivedPacket(5)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(2))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 5}))
|
|
||||||
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 7, End: 7}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("extends a range at the front", func() {
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
hist.ReceivedPacket(7)
|
|
||||||
hist.ReceivedPacket(6)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(2))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
|
|
||||||
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 7}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes a range", func() {
|
|
||||||
hist.ReceivedPacket(6)
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(2))
|
|
||||||
hist.ReceivedPacket(5)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(1))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes a range in the middle", func() {
|
|
||||||
hist.ReceivedPacket(1)
|
|
||||||
hist.ReceivedPacket(10)
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
hist.ReceivedPacket(6)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(4))
|
|
||||||
hist.ReceivedPacket(5)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(3))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 1, End: 1}))
|
|
||||||
Expect(hist.ranges.Front().Next().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
|
|
||||||
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("deleting", func() {
|
|
||||||
It("does nothing when the history is empty", func() {
|
|
||||||
hist.DeleteBelow(5)
|
|
||||||
Expect(hist.ranges.Len()).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("deletes a range", func() {
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
hist.ReceivedPacket(5)
|
|
||||||
hist.ReceivedPacket(10)
|
|
||||||
hist.DeleteBelow(6)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(1))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("deletes multiple ranges", func() {
|
|
||||||
hist.ReceivedPacket(1)
|
|
||||||
hist.ReceivedPacket(5)
|
|
||||||
hist.ReceivedPacket(10)
|
|
||||||
hist.DeleteBelow(8)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(1))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adjusts a range, if packets are delete from an existing range", func() {
|
|
||||||
hist.ReceivedPacket(3)
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
hist.ReceivedPacket(5)
|
|
||||||
hist.ReceivedPacket(6)
|
|
||||||
hist.ReceivedPacket(7)
|
|
||||||
hist.DeleteBelow(5)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(1))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 7}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adjusts a range, if only one packet remains in the range", func() {
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
hist.ReceivedPacket(5)
|
|
||||||
hist.ReceivedPacket(10)
|
|
||||||
hist.DeleteBelow(5)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(2))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 5}))
|
|
||||||
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("keeps a one-packet range, if deleting up to the packet directly below", func() {
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
hist.DeleteBelow(4)
|
|
||||||
Expect(hist.ranges.Len()).To(Equal(1))
|
|
||||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("DoS protection", func() {
|
|
||||||
It("doesn't create more than MaxTrackedReceivedAckRanges ranges", func() {
|
|
||||||
for i := protocol.PacketNumber(1); i <= protocol.MaxTrackedReceivedAckRanges; i++ {
|
|
||||||
err := hist.ReceivedPacket(2 * i)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
err := hist.ReceivedPacket(2*protocol.MaxTrackedReceivedAckRanges + 2)
|
|
||||||
Expect(err).To(MatchError(errTooManyOutstandingReceivedAckRanges))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't consider already deleted ranges for MaxTrackedReceivedAckRanges", func() {
|
|
||||||
for i := protocol.PacketNumber(1); i <= protocol.MaxTrackedReceivedAckRanges; i++ {
|
|
||||||
err := hist.ReceivedPacket(2 * i)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
err := hist.ReceivedPacket(2*protocol.MaxTrackedReceivedAckRanges + 2)
|
|
||||||
Expect(err).To(MatchError(errTooManyOutstandingReceivedAckRanges))
|
|
||||||
hist.DeleteBelow(protocol.MaxTrackedReceivedAckRanges) // deletes about half of the ranges
|
|
||||||
err = hist.ReceivedPacket(2*protocol.MaxTrackedReceivedAckRanges + 4)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("ACK range export", func() {
|
|
||||||
It("returns nil if there are no ranges", func() {
|
|
||||||
Expect(hist.GetAckRanges()).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets a single ACK range", func() {
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
hist.ReceivedPacket(5)
|
|
||||||
ackRanges := hist.GetAckRanges()
|
|
||||||
Expect(ackRanges).To(HaveLen(1))
|
|
||||||
Expect(ackRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets multiple ACK ranges", func() {
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
hist.ReceivedPacket(5)
|
|
||||||
hist.ReceivedPacket(6)
|
|
||||||
hist.ReceivedPacket(1)
|
|
||||||
hist.ReceivedPacket(11)
|
|
||||||
hist.ReceivedPacket(10)
|
|
||||||
hist.ReceivedPacket(2)
|
|
||||||
ackRanges := hist.GetAckRanges()
|
|
||||||
Expect(ackRanges).To(HaveLen(3))
|
|
||||||
Expect(ackRanges[0]).To(Equal(wire.AckRange{Smallest: 10, Largest: 11}))
|
|
||||||
Expect(ackRanges[1]).To(Equal(wire.AckRange{Smallest: 4, Largest: 6}))
|
|
||||||
Expect(ackRanges[2]).To(Equal(wire.AckRange{Smallest: 1, Largest: 2}))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Getting the highest ACK range", func() {
|
|
||||||
It("returns the zero value if there are no ranges", func() {
|
|
||||||
Expect(hist.GetHighestAckRange()).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets a single ACK range", func() {
|
|
||||||
hist.ReceivedPacket(4)
|
|
||||||
hist.ReceivedPacket(5)
|
|
||||||
Expect(hist.GetHighestAckRange()).To(Equal(wire.AckRange{Smallest: 4, Largest: 5}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets the highest of multiple ACK ranges", func() {
|
|
||||||
hist.ReceivedPacket(3)
|
|
||||||
hist.ReceivedPacket(6)
|
|
||||||
hist.ReceivedPacket(7)
|
|
||||||
Expect(hist.GetHighestAckRange()).To(Equal(wire.AckRange{Smallest: 6, Largest: 7}))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,45 +0,0 @@
|
|||||||
package ackhandler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("retransmittable frames", func() {
|
|
||||||
for fl, el := range map[wire.Frame]bool{
|
|
||||||
&wire.AckFrame{}: false,
|
|
||||||
&wire.StopWaitingFrame{}: false,
|
|
||||||
&wire.BlockedFrame{}: true,
|
|
||||||
&wire.ConnectionCloseFrame{}: true,
|
|
||||||
&wire.GoawayFrame{}: true,
|
|
||||||
&wire.PingFrame{}: true,
|
|
||||||
&wire.RstStreamFrame{}: true,
|
|
||||||
&wire.StreamFrame{}: true,
|
|
||||||
&wire.MaxDataFrame{}: true,
|
|
||||||
&wire.MaxStreamDataFrame{}: true,
|
|
||||||
} {
|
|
||||||
f := fl
|
|
||||||
e := el
|
|
||||||
fName := reflect.ValueOf(f).Elem().Type().Name()
|
|
||||||
|
|
||||||
It("works for "+fName, func() {
|
|
||||||
Expect(IsFrameRetransmittable(f)).To(Equal(e))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("stripping non-retransmittable frames works for "+fName, func() {
|
|
||||||
s := []wire.Frame{f}
|
|
||||||
if e {
|
|
||||||
Expect(stripNonRetransmittableFrames(s)).To(Equal([]wire.Frame{f}))
|
|
||||||
} else {
|
|
||||||
Expect(stripNonRetransmittableFrames(s)).To(BeEmpty())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("HasRetransmittableFrames works for "+fName, func() {
|
|
||||||
Expect(HasRetransmittableFrames([]wire.Frame{f})).To(Equal(e))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
@ -1,18 +0,0 @@
|
|||||||
package ackhandler
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Send Mode", func() {
|
|
||||||
It("has a string representation", func() {
|
|
||||||
Expect(SendNone.String()).To(Equal("none"))
|
|
||||||
Expect(SendAny.String()).To(Equal("any"))
|
|
||||||
Expect(SendAck.String()).To(Equal("ack"))
|
|
||||||
Expect(SendRTO.String()).To(Equal("rto"))
|
|
||||||
Expect(SendTLP.String()).To(Equal("tlp"))
|
|
||||||
Expect(SendRetransmission.String()).To(Equal("retransmission"))
|
|
||||||
Expect(SendMode(123).String()).To(Equal("invalid send mode: 123"))
|
|
||||||
})
|
|
||||||
})
|
|
File diff suppressed because it is too large
Load Diff
@ -1,297 +0,0 @@
|
|||||||
package ackhandler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("SentPacketHistory", func() {
|
|
||||||
var hist *sentPacketHistory
|
|
||||||
|
|
||||||
expectInHistory := func(packetNumbers []protocol.PacketNumber) {
|
|
||||||
ExpectWithOffset(1, hist.packetMap).To(HaveLen(len(packetNumbers)))
|
|
||||||
ExpectWithOffset(1, hist.packetList.Len()).To(Equal(len(packetNumbers)))
|
|
||||||
i := 0
|
|
||||||
hist.Iterate(func(p *Packet) (bool, error) {
|
|
||||||
pn := packetNumbers[i]
|
|
||||||
ExpectWithOffset(1, p.PacketNumber).To(Equal(pn))
|
|
||||||
ExpectWithOffset(1, hist.packetMap[pn].Value.PacketNumber).To(Equal(pn))
|
|
||||||
i++
|
|
||||||
return true, nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
hist = newSentPacketHistory()
|
|
||||||
})
|
|
||||||
|
|
||||||
It("saves sent packets", func() {
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 1})
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 3})
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 4})
|
|
||||||
expectInHistory([]protocol.PacketNumber{1, 3, 4})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets the length", func() {
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 1})
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 10})
|
|
||||||
Expect(hist.Len()).To(Equal(2))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("getting the first outstanding packet", func() {
|
|
||||||
It("gets nil, if there are no packets", func() {
|
|
||||||
Expect(hist.FirstOutstanding()).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets the first outstanding packet", func() {
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 2})
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 3})
|
|
||||||
front := hist.FirstOutstanding()
|
|
||||||
Expect(front).ToNot(BeNil())
|
|
||||||
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(2)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets the second packet if the first one is retransmitted", func() {
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 1, canBeRetransmitted: true})
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 3, canBeRetransmitted: true})
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 4, canBeRetransmitted: true})
|
|
||||||
front := hist.FirstOutstanding()
|
|
||||||
Expect(front).ToNot(BeNil())
|
|
||||||
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1)))
|
|
||||||
// Queue the first packet for retransmission.
|
|
||||||
// The first outstanding packet should now be 3.
|
|
||||||
err := hist.MarkCannotBeRetransmitted(1)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
front = hist.FirstOutstanding()
|
|
||||||
Expect(front).ToNot(BeNil())
|
|
||||||
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(3)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets the third packet if the first two are retransmitted", func() {
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 1, canBeRetransmitted: true})
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 3, canBeRetransmitted: true})
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 4, canBeRetransmitted: true})
|
|
||||||
front := hist.FirstOutstanding()
|
|
||||||
Expect(front).ToNot(BeNil())
|
|
||||||
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1)))
|
|
||||||
// Queue the second packet for retransmission.
|
|
||||||
// The first outstanding packet should still be 3.
|
|
||||||
err := hist.MarkCannotBeRetransmitted(3)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
front = hist.FirstOutstanding()
|
|
||||||
Expect(front).ToNot(BeNil())
|
|
||||||
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1)))
|
|
||||||
// Queue the first packet for retransmission.
|
|
||||||
// The first outstanding packet should still be 4.
|
|
||||||
err = hist.MarkCannotBeRetransmitted(1)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
front = hist.FirstOutstanding()
|
|
||||||
Expect(front).ToNot(BeNil())
|
|
||||||
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(4)))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets a packet by packet number", func() {
|
|
||||||
p := &Packet{PacketNumber: 2}
|
|
||||||
hist.SentPacket(p)
|
|
||||||
Expect(hist.GetPacket(2)).To(Equal(p))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns nil if the packet doesn't exist", func() {
|
|
||||||
Expect(hist.GetPacket(1337)).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("removes packets", func() {
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 1})
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 4})
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 8})
|
|
||||||
err := hist.Remove(4)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
expectInHistory([]protocol.PacketNumber{1, 8})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when trying to remove a non existing packet", func() {
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 1})
|
|
||||||
err := hist.Remove(2)
|
|
||||||
Expect(err).To(MatchError("packet 2 not found in sent packet history"))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("iterating", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 10})
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 14})
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: 18})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("iterates over all packets", func() {
|
|
||||||
var iterations []protocol.PacketNumber
|
|
||||||
err := hist.Iterate(func(p *Packet) (bool, error) {
|
|
||||||
iterations = append(iterations, p.PacketNumber)
|
|
||||||
return true, nil
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14, 18}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("stops iterating", func() {
|
|
||||||
var iterations []protocol.PacketNumber
|
|
||||||
err := hist.Iterate(func(p *Packet) (bool, error) {
|
|
||||||
iterations = append(iterations, p.PacketNumber)
|
|
||||||
return p.PacketNumber != 14, nil
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns the error", func() {
|
|
||||||
testErr := errors.New("test error")
|
|
||||||
var iterations []protocol.PacketNumber
|
|
||||||
err := hist.Iterate(func(p *Packet) (bool, error) {
|
|
||||||
iterations = append(iterations, p.PacketNumber)
|
|
||||||
if p.PacketNumber == 14 {
|
|
||||||
return false, testErr
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
})
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14}))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("retransmissions", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
for i := protocol.PacketNumber(1); i <= 5; i++ {
|
|
||||||
hist.SentPacket(&Packet{PacketNumber: i})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the packet doesn't exist", func() {
|
|
||||||
err := hist.MarkCannotBeRetransmitted(100)
|
|
||||||
Expect(err).To(MatchError("sent packet history: packet 100 not found"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adds a sent packets as a retransmission", func() {
|
|
||||||
hist.SentPacketsAsRetransmission([]*Packet{{PacketNumber: 13}}, 2)
|
|
||||||
expectInHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 13})
|
|
||||||
Expect(hist.GetPacket(13).isRetransmission).To(BeTrue())
|
|
||||||
Expect(hist.GetPacket(13).retransmissionOf).To(Equal(protocol.PacketNumber(2)))
|
|
||||||
Expect(hist.GetPacket(2).retransmittedAs).To(Equal([]protocol.PacketNumber{13}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adds multiple packets sent as a retransmission", func() {
|
|
||||||
hist.SentPacketsAsRetransmission([]*Packet{{PacketNumber: 13}, {PacketNumber: 15}}, 2)
|
|
||||||
expectInHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 13, 15})
|
|
||||||
Expect(hist.GetPacket(13).isRetransmission).To(BeTrue())
|
|
||||||
Expect(hist.GetPacket(13).retransmissionOf).To(Equal(protocol.PacketNumber(2)))
|
|
||||||
Expect(hist.GetPacket(15).retransmissionOf).To(Equal(protocol.PacketNumber(2)))
|
|
||||||
Expect(hist.GetPacket(2).retransmittedAs).To(Equal([]protocol.PacketNumber{13, 15}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adds a packet as a normal packet if the retransmitted packet doesn't exist", func() {
|
|
||||||
hist.SentPacketsAsRetransmission([]*Packet{{PacketNumber: 13}}, 7)
|
|
||||||
expectInHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 13})
|
|
||||||
Expect(hist.GetPacket(13).isRetransmission).To(BeFalse())
|
|
||||||
Expect(hist.GetPacket(13).retransmissionOf).To(BeZero())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("outstanding packets", func() {
|
|
||||||
It("says if it has outstanding handshake packets", func() {
|
|
||||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
|
||||||
hist.SentPacket(&Packet{
|
|
||||||
EncryptionLevel: protocol.EncryptionUnencrypted,
|
|
||||||
canBeRetransmitted: true,
|
|
||||||
})
|
|
||||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("says if it has outstanding packets", func() {
|
|
||||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
|
||||||
Expect(hist.HasOutstandingPackets()).To(BeFalse())
|
|
||||||
hist.SentPacket(&Packet{
|
|
||||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
|
||||||
canBeRetransmitted: true,
|
|
||||||
})
|
|
||||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
|
||||||
Expect(hist.HasOutstandingPackets()).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't consider non-retransmittable packets as outstanding", func() {
|
|
||||||
hist.SentPacket(&Packet{
|
|
||||||
EncryptionLevel: protocol.EncryptionUnencrypted,
|
|
||||||
})
|
|
||||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
|
||||||
Expect(hist.HasOutstandingPackets()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("accounts for deleted handshake packets", func() {
|
|
||||||
hist.SentPacket(&Packet{
|
|
||||||
PacketNumber: 5,
|
|
||||||
EncryptionLevel: protocol.EncryptionSecure,
|
|
||||||
canBeRetransmitted: true,
|
|
||||||
})
|
|
||||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
|
|
||||||
err := hist.Remove(5)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("accounts for deleted packets", func() {
|
|
||||||
hist.SentPacket(&Packet{
|
|
||||||
PacketNumber: 10,
|
|
||||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
|
||||||
canBeRetransmitted: true,
|
|
||||||
})
|
|
||||||
Expect(hist.HasOutstandingPackets()).To(BeTrue())
|
|
||||||
err := hist.Remove(10)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(hist.HasOutstandingPackets()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't count handshake packets marked as non-retransmittable", func() {
|
|
||||||
hist.SentPacket(&Packet{
|
|
||||||
PacketNumber: 5,
|
|
||||||
EncryptionLevel: protocol.EncryptionUnencrypted,
|
|
||||||
canBeRetransmitted: true,
|
|
||||||
})
|
|
||||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
|
|
||||||
err := hist.MarkCannotBeRetransmitted(5)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't count packets marked as non-retransmittable", func() {
|
|
||||||
hist.SentPacket(&Packet{
|
|
||||||
PacketNumber: 10,
|
|
||||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
|
||||||
canBeRetransmitted: true,
|
|
||||||
})
|
|
||||||
Expect(hist.HasOutstandingPackets()).To(BeTrue())
|
|
||||||
err := hist.MarkCannotBeRetransmitted(10)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(hist.HasOutstandingPackets()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("counts the number of packets", func() {
|
|
||||||
hist.SentPacket(&Packet{
|
|
||||||
PacketNumber: 10,
|
|
||||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
|
||||||
canBeRetransmitted: true,
|
|
||||||
})
|
|
||||||
hist.SentPacket(&Packet{
|
|
||||||
PacketNumber: 11,
|
|
||||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
|
||||||
canBeRetransmitted: true,
|
|
||||||
})
|
|
||||||
err := hist.Remove(11)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(hist.HasOutstandingPackets()).To(BeTrue())
|
|
||||||
err = hist.Remove(10)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(hist.HasOutstandingPackets()).To(BeFalse())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,55 +0,0 @@
|
|||||||
package ackhandler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("StopWaitingManager", func() {
|
|
||||||
var manager *stopWaitingManager
|
|
||||||
BeforeEach(func() {
|
|
||||||
manager = &stopWaitingManager{}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns nil in the beginning", func() {
|
|
||||||
Expect(manager.GetStopWaitingFrame(false)).To(BeNil())
|
|
||||||
Expect(manager.GetStopWaitingFrame(true)).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns a StopWaitingFrame, when a new ACK arrives", func() {
|
|
||||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
|
||||||
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not decrease the LeastUnacked", func() {
|
|
||||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
|
||||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 9}}})
|
|
||||||
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not send the same StopWaitingFrame twice", func() {
|
|
||||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
|
||||||
Expect(manager.GetStopWaitingFrame(false)).ToNot(BeNil())
|
|
||||||
Expect(manager.GetStopWaitingFrame(false)).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets the same StopWaitingFrame twice, if forced", func() {
|
|
||||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
|
||||||
Expect(manager.GetStopWaitingFrame(false)).ToNot(BeNil())
|
|
||||||
Expect(manager.GetStopWaitingFrame(true)).ToNot(BeNil())
|
|
||||||
Expect(manager.GetStopWaitingFrame(true)).ToNot(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("increases the LeastUnacked when a retransmission is queued", func() {
|
|
||||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
|
||||||
manager.QueuedRetransmissionForPacketNumber(20)
|
|
||||||
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 21}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not decrease the LeastUnacked when a retransmission is queued", func() {
|
|
||||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
|
||||||
manager.QueuedRetransmissionForPacketNumber(9)
|
|
||||||
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,14 +0,0 @@
|
|||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Bandwidth", func() {
|
|
||||||
It("converts from time delta", func() {
|
|
||||||
Expect(BandwidthFromDelta(1, time.Millisecond)).To(Equal(1000 * BytesPerSecond))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,13 +0,0 @@
|
|||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCongestion(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "Congestion Suite")
|
|
||||||
}
|
|
@ -1,640 +0,0 @@
|
|||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
const initialCongestionWindowPackets = 10
|
|
||||||
const defaultWindowTCP = protocol.ByteCount(initialCongestionWindowPackets) * protocol.DefaultTCPMSS
|
|
||||||
|
|
||||||
type mockClock time.Time
|
|
||||||
|
|
||||||
func (c *mockClock) Now() time.Time {
|
|
||||||
return time.Time(*c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *mockClock) Advance(d time.Duration) {
|
|
||||||
*c = mockClock(time.Time(*c).Add(d))
|
|
||||||
}
|
|
||||||
|
|
||||||
const MaxCongestionWindow protocol.ByteCount = 200 * protocol.DefaultTCPMSS
|
|
||||||
|
|
||||||
var _ = Describe("Cubic Sender", func() {
|
|
||||||
var (
|
|
||||||
sender SendAlgorithmWithDebugInfo
|
|
||||||
clock mockClock
|
|
||||||
bytesInFlight protocol.ByteCount
|
|
||||||
packetNumber protocol.PacketNumber
|
|
||||||
ackedPacketNumber protocol.PacketNumber
|
|
||||||
rttStats *RTTStats
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
bytesInFlight = 0
|
|
||||||
packetNumber = 1
|
|
||||||
ackedPacketNumber = 0
|
|
||||||
clock = mockClock{}
|
|
||||||
rttStats = NewRTTStats()
|
|
||||||
sender = NewCubicSender(&clock, rttStats, true /*reno*/, initialCongestionWindowPackets*protocol.DefaultTCPMSS, MaxCongestionWindow)
|
|
||||||
})
|
|
||||||
|
|
||||||
canSend := func() bool {
|
|
||||||
return bytesInFlight < sender.GetCongestionWindow()
|
|
||||||
}
|
|
||||||
|
|
||||||
SendAvailableSendWindowLen := func(packetLength protocol.ByteCount) int {
|
|
||||||
packetsSent := 0
|
|
||||||
for canSend() {
|
|
||||||
sender.OnPacketSent(clock.Now(), bytesInFlight, packetNumber, packetLength, true)
|
|
||||||
packetNumber++
|
|
||||||
packetsSent++
|
|
||||||
bytesInFlight += packetLength
|
|
||||||
}
|
|
||||||
return packetsSent
|
|
||||||
}
|
|
||||||
|
|
||||||
// Normal is that TCP acks every other segment.
|
|
||||||
AckNPackets := func(n int) {
|
|
||||||
rttStats.UpdateRTT(60*time.Millisecond, 0, clock.Now())
|
|
||||||
sender.MaybeExitSlowStart()
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
ackedPacketNumber++
|
|
||||||
sender.OnPacketAcked(ackedPacketNumber, protocol.DefaultTCPMSS, bytesInFlight, clock.Now())
|
|
||||||
}
|
|
||||||
bytesInFlight -= protocol.ByteCount(n) * protocol.DefaultTCPMSS
|
|
||||||
clock.Advance(time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
LoseNPacketsLen := func(n int, packetLength protocol.ByteCount) {
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
ackedPacketNumber++
|
|
||||||
sender.OnPacketLost(ackedPacketNumber, packetLength, bytesInFlight)
|
|
||||||
}
|
|
||||||
bytesInFlight -= protocol.ByteCount(n) * packetLength
|
|
||||||
}
|
|
||||||
|
|
||||||
// Does not increment acked_packet_number_.
|
|
||||||
LosePacket := func(number protocol.PacketNumber) {
|
|
||||||
sender.OnPacketLost(number, protocol.DefaultTCPMSS, bytesInFlight)
|
|
||||||
bytesInFlight -= protocol.DefaultTCPMSS
|
|
||||||
}
|
|
||||||
|
|
||||||
SendAvailableSendWindow := func() int { return SendAvailableSendWindowLen(protocol.DefaultTCPMSS) }
|
|
||||||
LoseNPackets := func(n int) { LoseNPacketsLen(n, protocol.DefaultTCPMSS) }
|
|
||||||
|
|
||||||
It("has the right values at startup", func() {
|
|
||||||
// At startup make sure we are at the default.
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
|
|
||||||
// Make sure we can send.
|
|
||||||
Expect(sender.TimeUntilSend(0)).To(BeZero())
|
|
||||||
Expect(canSend()).To(BeTrue())
|
|
||||||
// And that window is un-affected.
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
|
|
||||||
|
|
||||||
// Fill the send window with data, then verify that we can't send.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
Expect(canSend()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("paces", func() {
|
|
||||||
clock.Advance(time.Hour)
|
|
||||||
// Fill the send window with data, then verify that we can't send.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(1)
|
|
||||||
delay := sender.TimeUntilSend(bytesInFlight)
|
|
||||||
Expect(delay).ToNot(BeZero())
|
|
||||||
Expect(delay).ToNot(Equal(utils.InfDuration))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("application limited slow start", func() {
|
|
||||||
// Send exactly 10 packets and ensure the CWND ends at 14 packets.
|
|
||||||
const numberOfAcks = 5
|
|
||||||
// At startup make sure we can send.
|
|
||||||
Expect(sender.TimeUntilSend(0)).To(BeZero())
|
|
||||||
// Make sure we can send.
|
|
||||||
Expect(sender.TimeUntilSend(0)).To(BeZero())
|
|
||||||
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
for i := 0; i < numberOfAcks; i++ {
|
|
||||||
AckNPackets(2)
|
|
||||||
}
|
|
||||||
bytesToSend := sender.GetCongestionWindow()
|
|
||||||
// It's expected 2 acks will arrive when the bytes_in_flight are greater than
|
|
||||||
// half the CWND.
|
|
||||||
Expect(bytesToSend).To(Equal(defaultWindowTCP + protocol.DefaultTCPMSS*2*2))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("exponential slow start", func() {
|
|
||||||
const numberOfAcks = 20
|
|
||||||
// At startup make sure we can send.
|
|
||||||
Expect(sender.TimeUntilSend(0)).To(BeZero())
|
|
||||||
Expect(sender.BandwidthEstimate()).To(BeZero())
|
|
||||||
// Make sure we can send.
|
|
||||||
Expect(sender.TimeUntilSend(0)).To(BeZero())
|
|
||||||
|
|
||||||
for i := 0; i < numberOfAcks; i++ {
|
|
||||||
// Send our full send window.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(2)
|
|
||||||
}
|
|
||||||
cwnd := sender.GetCongestionWindow()
|
|
||||||
Expect(cwnd).To(Equal(defaultWindowTCP + protocol.DefaultTCPMSS*2*numberOfAcks))
|
|
||||||
Expect(sender.BandwidthEstimate()).To(Equal(BandwidthFromDelta(cwnd, rttStats.SmoothedRTT())))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("slow start packet loss", func() {
|
|
||||||
sender.SetNumEmulatedConnections(1)
|
|
||||||
const numberOfAcks = 10
|
|
||||||
for i := 0; i < numberOfAcks; i++ {
|
|
||||||
// Send our full send window.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(2)
|
|
||||||
}
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Lose a packet to exit slow start.
|
|
||||||
LoseNPackets(1)
|
|
||||||
packetsInRecoveryWindow := expectedSendWindow / protocol.DefaultTCPMSS
|
|
||||||
|
|
||||||
// We should now have fallen out of slow start with a reduced window.
|
|
||||||
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Recovery phase. We need to ack every packet in the recovery window before
|
|
||||||
// we exit recovery.
|
|
||||||
numberOfPacketsInWindow := expectedSendWindow / protocol.DefaultTCPMSS
|
|
||||||
AckNPackets(int(packetsInRecoveryWindow))
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// We need to ack an entire window before we increase CWND by 1.
|
|
||||||
AckNPackets(int(numberOfPacketsInWindow) - 2)
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Next ack should increase cwnd by 1.
|
|
||||||
AckNPackets(1)
|
|
||||||
expectedSendWindow += protocol.DefaultTCPMSS
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Now RTO and ensure slow start gets reset.
|
|
||||||
Expect(sender.HybridSlowStart().Started()).To(BeTrue())
|
|
||||||
sender.OnRetransmissionTimeout(true)
|
|
||||||
Expect(sender.HybridSlowStart().Started()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("slow start packet loss with large reduction", func() {
|
|
||||||
sender.SetSlowStartLargeReduction(true)
|
|
||||||
|
|
||||||
sender.SetNumEmulatedConnections(1)
|
|
||||||
const numberOfAcks = 10
|
|
||||||
for i := 0; i < numberOfAcks; i++ {
|
|
||||||
// Send our full send window.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(2)
|
|
||||||
}
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Lose a packet to exit slow start. We should now have fallen out of
|
|
||||||
// slow start with a window reduced by 1.
|
|
||||||
LoseNPackets(1)
|
|
||||||
expectedSendWindow -= protocol.DefaultTCPMSS
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Lose 5 packets in recovery and verify that congestion window is reduced
|
|
||||||
// further.
|
|
||||||
LoseNPackets(5)
|
|
||||||
expectedSendWindow -= 5 * protocol.DefaultTCPMSS
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
packetsInRecoveryWindow := expectedSendWindow / protocol.DefaultTCPMSS
|
|
||||||
|
|
||||||
// Recovery phase. We need to ack every packet in the recovery window before
|
|
||||||
// we exit recovery.
|
|
||||||
numberOfPacketsInWindow := expectedSendWindow / protocol.DefaultTCPMSS
|
|
||||||
AckNPackets(int(packetsInRecoveryWindow))
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// We need to ack the rest of the window before cwnd increases by 1.
|
|
||||||
AckNPackets(int(numberOfPacketsInWindow - 1))
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Next ack should increase cwnd by 1.
|
|
||||||
AckNPackets(1)
|
|
||||||
expectedSendWindow += protocol.DefaultTCPMSS
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Now RTO and ensure slow start gets reset.
|
|
||||||
Expect(sender.HybridSlowStart().Started()).To(BeTrue())
|
|
||||||
sender.OnRetransmissionTimeout(true)
|
|
||||||
Expect(sender.HybridSlowStart().Started()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("slow start half packet loss with large reduction", func() {
|
|
||||||
sender.SetSlowStartLargeReduction(true)
|
|
||||||
|
|
||||||
sender.SetNumEmulatedConnections(1)
|
|
||||||
const numberOfAcks = 10
|
|
||||||
for i := 0; i < numberOfAcks; i++ {
|
|
||||||
// Send our full send window in half sized packets.
|
|
||||||
SendAvailableSendWindowLen(protocol.DefaultTCPMSS / 2)
|
|
||||||
AckNPackets(2)
|
|
||||||
}
|
|
||||||
SendAvailableSendWindowLen(protocol.DefaultTCPMSS / 2)
|
|
||||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Lose a packet to exit slow start. We should now have fallen out of
|
|
||||||
// slow start with a window reduced by 1.
|
|
||||||
LoseNPackets(1)
|
|
||||||
expectedSendWindow -= protocol.DefaultTCPMSS
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Lose 10 packets in recovery and verify that congestion window is reduced
|
|
||||||
// by 5 packets.
|
|
||||||
LoseNPacketsLen(10, protocol.DefaultTCPMSS/2)
|
|
||||||
expectedSendWindow -= 5 * protocol.DefaultTCPMSS
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
})
|
|
||||||
|
|
||||||
// this test doesn't work any more after introducing the pacing needed for QUIC
|
|
||||||
PIt("no PRR when less than one packet in flight", func() {
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
LoseNPackets(int(initialCongestionWindowPackets) - 1)
|
|
||||||
AckNPackets(1)
|
|
||||||
// PRR will allow 2 packets for every ack during recovery.
|
|
||||||
Expect(SendAvailableSendWindow()).To(Equal(2))
|
|
||||||
// Simulate abandoning all packets by supplying a bytes_in_flight of 0.
|
|
||||||
// PRR should now allow a packet to be sent, even though prr's state
|
|
||||||
// variables believe it has sent enough packets.
|
|
||||||
Expect(sender.TimeUntilSend(0)).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("slow start packet loss PRR", func() {
|
|
||||||
sender.SetNumEmulatedConnections(1)
|
|
||||||
// Test based on the first example in RFC6937.
|
|
||||||
// Ack 10 packets in 5 acks to raise the CWND to 20, as in the example.
|
|
||||||
const numberOfAcks = 5
|
|
||||||
for i := 0; i < numberOfAcks; i++ {
|
|
||||||
// Send our full send window.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(2)
|
|
||||||
}
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
LoseNPackets(1)
|
|
||||||
|
|
||||||
// We should now have fallen out of slow start with a reduced window.
|
|
||||||
sendWindowBeforeLoss := expectedSendWindow
|
|
||||||
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Testing TCP proportional rate reduction.
|
|
||||||
// We should send packets paced over the received acks for the remaining
|
|
||||||
// outstanding packets. The number of packets before we exit recovery is the
|
|
||||||
// original CWND minus the packet that has been lost and the one which
|
|
||||||
// triggered the loss.
|
|
||||||
remainingPacketsInRecovery := sendWindowBeforeLoss/protocol.DefaultTCPMSS - 2
|
|
||||||
|
|
||||||
for i := protocol.ByteCount(0); i < remainingPacketsInRecovery; i++ {
|
|
||||||
AckNPackets(1)
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
}
|
|
||||||
|
|
||||||
// We need to ack another window before we increase CWND by 1.
|
|
||||||
numberOfPacketsInWindow := expectedSendWindow / protocol.DefaultTCPMSS
|
|
||||||
for i := protocol.ByteCount(0); i < numberOfPacketsInWindow; i++ {
|
|
||||||
AckNPackets(1)
|
|
||||||
Expect(SendAvailableSendWindow()).To(Equal(1))
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
}
|
|
||||||
|
|
||||||
AckNPackets(1)
|
|
||||||
expectedSendWindow += protocol.DefaultTCPMSS
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("slow start burst packet loss PRR", func() {
|
|
||||||
sender.SetNumEmulatedConnections(1)
|
|
||||||
// Test based on the second example in RFC6937, though we also implement
|
|
||||||
// forward acknowledgements, so the first two incoming acks will trigger
|
|
||||||
// PRR immediately.
|
|
||||||
// Ack 20 packets in 10 acks to raise the CWND to 30.
|
|
||||||
const numberOfAcks = 10
|
|
||||||
for i := 0; i < numberOfAcks; i++ {
|
|
||||||
// Send our full send window.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(2)
|
|
||||||
}
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Lose one more than the congestion window reduction, so that after loss,
|
|
||||||
// bytes_in_flight is lesser than the congestion window.
|
|
||||||
sendWindowAfterLoss := protocol.ByteCount(renoBeta * float32(expectedSendWindow))
|
|
||||||
numPacketsToLose := (expectedSendWindow-sendWindowAfterLoss)/protocol.DefaultTCPMSS + 1
|
|
||||||
LoseNPackets(int(numPacketsToLose))
|
|
||||||
// Immediately after the loss, ensure at least one packet can be sent.
|
|
||||||
// Losses without subsequent acks can occur with timer based loss detection.
|
|
||||||
Expect(sender.TimeUntilSend(bytesInFlight)).To(BeZero())
|
|
||||||
AckNPackets(1)
|
|
||||||
|
|
||||||
// We should now have fallen out of slow start with a reduced window.
|
|
||||||
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Only 2 packets should be allowed to be sent, per PRR-SSRB
|
|
||||||
Expect(SendAvailableSendWindow()).To(Equal(2))
|
|
||||||
|
|
||||||
// Ack the next packet, which triggers another loss.
|
|
||||||
LoseNPackets(1)
|
|
||||||
AckNPackets(1)
|
|
||||||
|
|
||||||
// Send 2 packets to simulate PRR-SSRB.
|
|
||||||
Expect(SendAvailableSendWindow()).To(Equal(2))
|
|
||||||
|
|
||||||
// Ack the next packet, which triggers another loss.
|
|
||||||
LoseNPackets(1)
|
|
||||||
AckNPackets(1)
|
|
||||||
|
|
||||||
// Send 2 packets to simulate PRR-SSRB.
|
|
||||||
Expect(SendAvailableSendWindow()).To(Equal(2))
|
|
||||||
|
|
||||||
// Exit recovery and return to sending at the new rate.
|
|
||||||
for i := 0; i < numberOfAcks; i++ {
|
|
||||||
AckNPackets(1)
|
|
||||||
Expect(SendAvailableSendWindow()).To(Equal(1))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("RTO congestion window", func() {
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
|
|
||||||
Expect(sender.SlowstartThreshold()).To(Equal(MaxCongestionWindow))
|
|
||||||
|
|
||||||
// Expect the window to decrease to the minimum once the RTO fires
|
|
||||||
// and slow start threshold to be set to 1/2 of the CWND.
|
|
||||||
sender.OnRetransmissionTimeout(true)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(2 * protocol.DefaultTCPMSS))
|
|
||||||
Expect(sender.SlowstartThreshold()).To(Equal(5 * protocol.DefaultTCPMSS))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("RTO congestion window no retransmission", func() {
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
|
|
||||||
|
|
||||||
// Expect the window to remain unchanged if the RTO fires but no
|
|
||||||
// packets are retransmitted.
|
|
||||||
sender.OnRetransmissionTimeout(false)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("tcp cubic reset epoch on quiescence", func() {
|
|
||||||
const maxCongestionWindow = 50
|
|
||||||
const maxCongestionWindowBytes = maxCongestionWindow * protocol.DefaultTCPMSS
|
|
||||||
sender = NewCubicSender(&clock, rttStats, false, initialCongestionWindowPackets*protocol.DefaultTCPMSS, maxCongestionWindowBytes)
|
|
||||||
|
|
||||||
numSent := SendAvailableSendWindow()
|
|
||||||
|
|
||||||
// Make sure we fall out of slow start.
|
|
||||||
savedCwnd := sender.GetCongestionWindow()
|
|
||||||
LoseNPackets(1)
|
|
||||||
Expect(savedCwnd).To(BeNumerically(">", sender.GetCongestionWindow()))
|
|
||||||
|
|
||||||
// Ack the rest of the outstanding packets to get out of recovery.
|
|
||||||
for i := 1; i < numSent; i++ {
|
|
||||||
AckNPackets(1)
|
|
||||||
}
|
|
||||||
Expect(bytesInFlight).To(BeZero())
|
|
||||||
|
|
||||||
// Send a new window of data and ack all; cubic growth should occur.
|
|
||||||
savedCwnd = sender.GetCongestionWindow()
|
|
||||||
numSent = SendAvailableSendWindow()
|
|
||||||
for i := 0; i < numSent; i++ {
|
|
||||||
AckNPackets(1)
|
|
||||||
}
|
|
||||||
Expect(savedCwnd).To(BeNumerically("<", sender.GetCongestionWindow()))
|
|
||||||
Expect(maxCongestionWindowBytes).To(BeNumerically(">", sender.GetCongestionWindow()))
|
|
||||||
Expect(bytesInFlight).To(BeZero())
|
|
||||||
|
|
||||||
// Quiescent time of 100 seconds
|
|
||||||
clock.Advance(100 * time.Second)
|
|
||||||
|
|
||||||
// Send new window of data and ack one packet. Cubic epoch should have
|
|
||||||
// been reset; ensure cwnd increase is not dramatic.
|
|
||||||
savedCwnd = sender.GetCongestionWindow()
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(1)
|
|
||||||
Expect(savedCwnd).To(BeNumerically("~", sender.GetCongestionWindow(), protocol.DefaultTCPMSS))
|
|
||||||
Expect(maxCongestionWindowBytes).To(BeNumerically(">", sender.GetCongestionWindow()))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("multiple losses in one window", func() {
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
initialWindow := sender.GetCongestionWindow()
|
|
||||||
LosePacket(ackedPacketNumber + 1)
|
|
||||||
postLossWindow := sender.GetCongestionWindow()
|
|
||||||
Expect(initialWindow).To(BeNumerically(">", postLossWindow))
|
|
||||||
LosePacket(ackedPacketNumber + 3)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(postLossWindow))
|
|
||||||
LosePacket(packetNumber - 1)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(postLossWindow))
|
|
||||||
|
|
||||||
// Lose a later packet and ensure the window decreases.
|
|
||||||
LosePacket(packetNumber)
|
|
||||||
Expect(postLossWindow).To(BeNumerically(">", sender.GetCongestionWindow()))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("2 connection congestion avoidance at end of recovery", func() {
|
|
||||||
sender.SetNumEmulatedConnections(2)
|
|
||||||
// Ack 10 packets in 5 acks to raise the CWND to 20.
|
|
||||||
const numberOfAcks = 5
|
|
||||||
for i := 0; i < numberOfAcks; i++ {
|
|
||||||
// Send our full send window.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(2)
|
|
||||||
}
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
LoseNPackets(1)
|
|
||||||
|
|
||||||
// We should now have fallen out of slow start with a reduced window.
|
|
||||||
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * sender.RenoBeta())
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// No congestion window growth should occur in recovery phase, i.e., until the
|
|
||||||
// currently outstanding 20 packets are acked.
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
// Send our full send window.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
Expect(sender.InRecovery()).To(BeTrue())
|
|
||||||
AckNPackets(2)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
}
|
|
||||||
Expect(sender.InRecovery()).To(BeFalse())
|
|
||||||
|
|
||||||
// Out of recovery now. Congestion window should not grow for half an RTT.
|
|
||||||
packetsInSendWindow := expectedSendWindow / protocol.DefaultTCPMSS
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(int(packetsInSendWindow/2 - 2))
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Next ack should increase congestion window by 1MSS.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(2)
|
|
||||||
expectedSendWindow += protocol.DefaultTCPMSS
|
|
||||||
packetsInSendWindow++
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Congestion window should remain steady again for half an RTT.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(int(packetsInSendWindow/2 - 1))
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Next ack should cause congestion window to grow by 1MSS.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(2)
|
|
||||||
expectedSendWindow += protocol.DefaultTCPMSS
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("1 connection congestion avoidance at end of recovery", func() {
|
|
||||||
sender.SetNumEmulatedConnections(1)
|
|
||||||
// Ack 10 packets in 5 acks to raise the CWND to 20.
|
|
||||||
const numberOfAcks = 5
|
|
||||||
for i := 0; i < numberOfAcks; i++ {
|
|
||||||
// Send our full send window.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(2)
|
|
||||||
}
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
LoseNPackets(1)
|
|
||||||
|
|
||||||
// We should now have fallen out of slow start with a reduced window.
|
|
||||||
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// No congestion window growth should occur in recovery phase, i.e., until the
|
|
||||||
// currently outstanding 20 packets are acked.
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
// Send our full send window.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
Expect(sender.InRecovery()).To(BeTrue())
|
|
||||||
AckNPackets(2)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
}
|
|
||||||
Expect(sender.InRecovery()).To(BeFalse())
|
|
||||||
|
|
||||||
// Out of recovery now. Congestion window should not grow during RTT.
|
|
||||||
for i := protocol.ByteCount(0); i < expectedSendWindow/protocol.DefaultTCPMSS-2; i += 2 {
|
|
||||||
// Send our full send window.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(2)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next ack should cause congestion window to grow by 1MSS.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(2)
|
|
||||||
expectedSendWindow += protocol.DefaultTCPMSS
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reset after connection migration", func() {
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
|
|
||||||
Expect(sender.SlowstartThreshold()).To(Equal(MaxCongestionWindow))
|
|
||||||
|
|
||||||
// Starts with slow start.
|
|
||||||
sender.SetNumEmulatedConnections(1)
|
|
||||||
const numberOfAcks = 10
|
|
||||||
for i := 0; i < numberOfAcks; i++ {
|
|
||||||
// Send our full send window.
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
AckNPackets(2)
|
|
||||||
}
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Loses a packet to exit slow start.
|
|
||||||
LoseNPackets(1)
|
|
||||||
|
|
||||||
// We should now have fallen out of slow start with a reduced window. Slow
|
|
||||||
// start threshold is also updated.
|
|
||||||
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
|
||||||
Expect(sender.SlowstartThreshold()).To(Equal(expectedSendWindow))
|
|
||||||
|
|
||||||
// Resets cwnd and slow start threshold on connection migrations.
|
|
||||||
sender.OnConnectionMigration()
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
|
|
||||||
Expect(sender.SlowstartThreshold()).To(Equal(MaxCongestionWindow))
|
|
||||||
Expect(sender.HybridSlowStart().Started()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("default max cwnd", func() {
|
|
||||||
sender = NewCubicSender(&clock, rttStats, true /*reno*/, initialCongestionWindowPackets*protocol.DefaultTCPMSS, protocol.DefaultMaxCongestionWindow)
|
|
||||||
|
|
||||||
defaultMaxCongestionWindowPackets := protocol.DefaultMaxCongestionWindow / protocol.DefaultTCPMSS
|
|
||||||
for i := 1; i < int(defaultMaxCongestionWindowPackets); i++ {
|
|
||||||
sender.MaybeExitSlowStart()
|
|
||||||
sender.OnPacketAcked(protocol.PacketNumber(i), 1350, sender.GetCongestionWindow(), clock.Now())
|
|
||||||
}
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(protocol.DefaultMaxCongestionWindow))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("limit cwnd increase in congestion avoidance", func() {
|
|
||||||
// Enable Cubic.
|
|
||||||
sender = NewCubicSender(&clock, rttStats, false, initialCongestionWindowPackets*protocol.DefaultTCPMSS, MaxCongestionWindow)
|
|
||||||
numSent := SendAvailableSendWindow()
|
|
||||||
|
|
||||||
// Make sure we fall out of slow start.
|
|
||||||
savedCwnd := sender.GetCongestionWindow()
|
|
||||||
LoseNPackets(1)
|
|
||||||
Expect(savedCwnd).To(BeNumerically(">", sender.GetCongestionWindow()))
|
|
||||||
|
|
||||||
// Ack the rest of the outstanding packets to get out of recovery.
|
|
||||||
for i := 1; i < numSent; i++ {
|
|
||||||
AckNPackets(1)
|
|
||||||
}
|
|
||||||
Expect(bytesInFlight).To(BeZero())
|
|
||||||
|
|
||||||
savedCwnd = sender.GetCongestionWindow()
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
|
|
||||||
// Ack packets until the CWND increases.
|
|
||||||
for sender.GetCongestionWindow() == savedCwnd {
|
|
||||||
AckNPackets(1)
|
|
||||||
SendAvailableSendWindow()
|
|
||||||
}
|
|
||||||
// Bytes in flight may be larger than the CWND if the CWND isn't an exact
|
|
||||||
// multiple of the packet sizes being sent.
|
|
||||||
Expect(bytesInFlight).To(BeNumerically(">=", sender.GetCongestionWindow()))
|
|
||||||
savedCwnd = sender.GetCongestionWindow()
|
|
||||||
|
|
||||||
// Advance time 2 seconds waiting for an ack.
|
|
||||||
clock.Advance(2 * time.Second)
|
|
||||||
|
|
||||||
// Ack two packets. The CWND should increase by only one packet.
|
|
||||||
AckNPackets(2)
|
|
||||||
Expect(sender.GetCongestionWindow()).To(Equal(savedCwnd + protocol.DefaultTCPMSS))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,236 +0,0 @@
|
|||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
const numConnections uint32 = 2
|
|
||||||
const nConnectionBeta float32 = (float32(numConnections) - 1 + beta) / float32(numConnections)
|
|
||||||
const nConnectionBetaLastMax float32 = (float32(numConnections) - 1 + betaLastMax) / float32(numConnections)
|
|
||||||
const nConnectionAlpha float32 = 3 * float32(numConnections) * float32(numConnections) * (1 - nConnectionBeta) / (1 + nConnectionBeta)
|
|
||||||
const maxCubicTimeInterval = 30 * time.Millisecond
|
|
||||||
|
|
||||||
var _ = Describe("Cubic", func() {
|
|
||||||
var (
|
|
||||||
clock mockClock
|
|
||||||
cubic *Cubic
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
clock = mockClock{}
|
|
||||||
cubic = NewCubic(&clock)
|
|
||||||
})
|
|
||||||
|
|
||||||
renoCwnd := func(currentCwnd protocol.ByteCount) protocol.ByteCount {
|
|
||||||
return currentCwnd + protocol.ByteCount(float32(protocol.DefaultTCPMSS)*nConnectionAlpha*float32(protocol.DefaultTCPMSS)/float32(currentCwnd))
|
|
||||||
}
|
|
||||||
|
|
||||||
cubicConvexCwnd := func(initialCwnd protocol.ByteCount, rtt, elapsedTime time.Duration) protocol.ByteCount {
|
|
||||||
offset := protocol.ByteCount((elapsedTime+rtt)/time.Microsecond) << 10 / 1000000
|
|
||||||
deltaCongestionWindow := 410 * offset * offset * offset * protocol.DefaultTCPMSS >> 40
|
|
||||||
return initialCwnd + deltaCongestionWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
It("works above origin (with tighter bounds)", func() {
|
|
||||||
// Convex growth.
|
|
||||||
const rttMin = 100 * time.Millisecond
|
|
||||||
const rttMinS = float32(rttMin/time.Millisecond) / 1000.0
|
|
||||||
currentCwnd := 10 * protocol.DefaultTCPMSS
|
|
||||||
initialCwnd := currentCwnd
|
|
||||||
|
|
||||||
clock.Advance(time.Millisecond)
|
|
||||||
initialTime := clock.Now()
|
|
||||||
expectedFirstCwnd := renoCwnd(currentCwnd)
|
|
||||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, initialTime)
|
|
||||||
Expect(expectedFirstCwnd).To(Equal(currentCwnd))
|
|
||||||
|
|
||||||
// Normal TCP phase.
|
|
||||||
// The maximum number of expected reno RTTs can be calculated by
|
|
||||||
// finding the point where the cubic curve and the reno curve meet.
|
|
||||||
maxRenoRtts := int(math.Sqrt(float64(nConnectionAlpha/(0.4*rttMinS*rttMinS*rttMinS))) - 2)
|
|
||||||
for i := 0; i < maxRenoRtts; i++ {
|
|
||||||
// Alternatively, we expect it to increase by one, every time we
|
|
||||||
// receive current_cwnd/Alpha acks back. (This is another way of
|
|
||||||
// saying we expect cwnd to increase by approximately Alpha once
|
|
||||||
// we receive current_cwnd number ofacks back).
|
|
||||||
numAcksThisEpoch := int(float32(currentCwnd/protocol.DefaultTCPMSS) / nConnectionAlpha)
|
|
||||||
|
|
||||||
initialCwndThisEpoch := currentCwnd
|
|
||||||
for n := 0; n < numAcksThisEpoch; n++ {
|
|
||||||
// Call once per ACK.
|
|
||||||
expectedNextCwnd := renoCwnd(currentCwnd)
|
|
||||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
|
||||||
Expect(currentCwnd).To(Equal(expectedNextCwnd))
|
|
||||||
}
|
|
||||||
// Our byte-wise Reno implementation is an estimate. We expect
|
|
||||||
// the cwnd to increase by approximately one MSS every
|
|
||||||
// cwnd/kDefaultTCPMSS/Alpha acks, but it may be off by as much as
|
|
||||||
// half a packet for smaller values of current_cwnd.
|
|
||||||
cwndChangeThisEpoch := currentCwnd - initialCwndThisEpoch
|
|
||||||
Expect(cwndChangeThisEpoch).To(BeNumerically("~", protocol.DefaultTCPMSS, protocol.DefaultTCPMSS/2))
|
|
||||||
clock.Advance(100 * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < 54; i++ {
|
|
||||||
maxAcksThisEpoch := currentCwnd / protocol.DefaultTCPMSS
|
|
||||||
interval := time.Duration(100*1000/maxAcksThisEpoch) * time.Microsecond
|
|
||||||
for n := 0; n < int(maxAcksThisEpoch); n++ {
|
|
||||||
clock.Advance(interval)
|
|
||||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
|
||||||
expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime))
|
|
||||||
// If we allow per-ack updates, every update is a small cubic update.
|
|
||||||
Expect(currentCwnd).To(Equal(expectedCwnd))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime))
|
|
||||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
|
||||||
Expect(currentCwnd).To(Equal(expectedCwnd))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works above the origin with fine grained cubing", func() {
|
|
||||||
// Start the test with an artificially large cwnd to prevent Reno
|
|
||||||
// from over-taking cubic.
|
|
||||||
currentCwnd := 1000 * protocol.DefaultTCPMSS
|
|
||||||
initialCwnd := currentCwnd
|
|
||||||
rttMin := 100 * time.Millisecond
|
|
||||||
clock.Advance(time.Millisecond)
|
|
||||||
initialTime := clock.Now()
|
|
||||||
|
|
||||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
|
||||||
clock.Advance(600 * time.Millisecond)
|
|
||||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
|
||||||
|
|
||||||
// We expect the algorithm to perform only non-zero, fine-grained cubic
|
|
||||||
// increases on every ack in this case.
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
clock.Advance(10 * time.Millisecond)
|
|
||||||
expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime))
|
|
||||||
nextCwnd := cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
|
||||||
// Make sure we are performing cubic increases.
|
|
||||||
Expect(nextCwnd).To(Equal(expectedCwnd))
|
|
||||||
// Make sure that these are non-zero, less-than-packet sized increases.
|
|
||||||
Expect(nextCwnd).To(BeNumerically(">", currentCwnd))
|
|
||||||
cwndDelta := nextCwnd - currentCwnd
|
|
||||||
Expect(protocol.DefaultTCPMSS / 10).To(BeNumerically(">", cwndDelta))
|
|
||||||
currentCwnd = nextCwnd
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handles per ack updates", func() {
|
|
||||||
// Start the test with a large cwnd and RTT, to force the first
|
|
||||||
// increase to be a cubic increase.
|
|
||||||
initialCwndPackets := 150
|
|
||||||
currentCwnd := protocol.ByteCount(initialCwndPackets) * protocol.DefaultTCPMSS
|
|
||||||
rttMin := 350 * time.Millisecond
|
|
||||||
|
|
||||||
// Initialize the epoch
|
|
||||||
clock.Advance(time.Millisecond)
|
|
||||||
// Keep track of the growth of the reno-equivalent cwnd.
|
|
||||||
rCwnd := renoCwnd(currentCwnd)
|
|
||||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
|
||||||
initialCwnd := currentCwnd
|
|
||||||
|
|
||||||
// Simulate the return of cwnd packets in less than
|
|
||||||
// MaxCubicInterval() time.
|
|
||||||
maxAcks := int(float32(initialCwndPackets) / nConnectionAlpha)
|
|
||||||
interval := maxCubicTimeInterval / time.Duration(maxAcks+1)
|
|
||||||
|
|
||||||
// In this scenario, the first increase is dictated by the cubic
|
|
||||||
// equation, but it is less than one byte, so the cwnd doesn't
|
|
||||||
// change. Normally, without per-ack increases, any cwnd plateau
|
|
||||||
// will cause the cwnd to be pinned for MaxCubicTimeInterval(). If
|
|
||||||
// we enable per-ack updates, the cwnd will continue to grow,
|
|
||||||
// regardless of the temporary plateau.
|
|
||||||
clock.Advance(interval)
|
|
||||||
rCwnd = renoCwnd(rCwnd)
|
|
||||||
Expect(cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())).To(Equal(currentCwnd))
|
|
||||||
for i := 1; i < maxAcks; i++ {
|
|
||||||
clock.Advance(interval)
|
|
||||||
nextCwnd := cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
|
||||||
rCwnd = renoCwnd(rCwnd)
|
|
||||||
// The window shoud increase on every ack.
|
|
||||||
Expect(nextCwnd).To(BeNumerically(">", currentCwnd))
|
|
||||||
Expect(nextCwnd).To(Equal(rCwnd))
|
|
||||||
currentCwnd = nextCwnd
|
|
||||||
}
|
|
||||||
|
|
||||||
// After all the acks are returned from the epoch, we expect the
|
|
||||||
// cwnd to have increased by nearly one packet. (Not exactly one
|
|
||||||
// packet, because our byte-wise Reno algorithm is always a slight
|
|
||||||
// under-estimation). Without per-ack updates, the current_cwnd
|
|
||||||
// would otherwise be unchanged.
|
|
||||||
minimumExpectedIncrease := protocol.DefaultTCPMSS * 9 / 10
|
|
||||||
Expect(currentCwnd).To(BeNumerically(">", initialCwnd+minimumExpectedIncrease))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handles loss events", func() {
|
|
||||||
rttMin := 100 * time.Millisecond
|
|
||||||
currentCwnd := 422 * protocol.DefaultTCPMSS
|
|
||||||
expectedCwnd := renoCwnd(currentCwnd)
|
|
||||||
// Initialize the state.
|
|
||||||
clock.Advance(time.Millisecond)
|
|
||||||
Expect(cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())).To(Equal(expectedCwnd))
|
|
||||||
|
|
||||||
// On the first loss, the last max congestion window is set to the
|
|
||||||
// congestion window before the loss.
|
|
||||||
preLossCwnd := currentCwnd
|
|
||||||
Expect(cubic.lastMaxCongestionWindow).To(BeZero())
|
|
||||||
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
|
|
||||||
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
|
|
||||||
Expect(cubic.lastMaxCongestionWindow).To(Equal(preLossCwnd))
|
|
||||||
currentCwnd = expectedCwnd
|
|
||||||
|
|
||||||
// On the second loss, the current congestion window has not yet
|
|
||||||
// reached the last max congestion window. The last max congestion
|
|
||||||
// window will be reduced by an additional backoff factor to allow
|
|
||||||
// for competition.
|
|
||||||
preLossCwnd = currentCwnd
|
|
||||||
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
|
|
||||||
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
|
|
||||||
currentCwnd = expectedCwnd
|
|
||||||
Expect(preLossCwnd).To(BeNumerically(">", cubic.lastMaxCongestionWindow))
|
|
||||||
expectedLastMax := protocol.ByteCount(float32(preLossCwnd) * nConnectionBetaLastMax)
|
|
||||||
Expect(cubic.lastMaxCongestionWindow).To(Equal(expectedLastMax))
|
|
||||||
Expect(expectedCwnd).To(BeNumerically("<", cubic.lastMaxCongestionWindow))
|
|
||||||
// Simulate an increase, and check that we are below the origin.
|
|
||||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
|
||||||
Expect(cubic.lastMaxCongestionWindow).To(BeNumerically(">", currentCwnd))
|
|
||||||
|
|
||||||
// On the final loss, simulate the condition where the congestion
|
|
||||||
// window had a chance to grow nearly to the last congestion window.
|
|
||||||
currentCwnd = cubic.lastMaxCongestionWindow - 1
|
|
||||||
preLossCwnd = currentCwnd
|
|
||||||
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
|
|
||||||
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
|
|
||||||
expectedLastMax = preLossCwnd
|
|
||||||
Expect(cubic.lastMaxCongestionWindow).To(Equal(expectedLastMax))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works below origin", func() {
|
|
||||||
// Concave growth.
|
|
||||||
rttMin := 100 * time.Millisecond
|
|
||||||
currentCwnd := 422 * protocol.DefaultTCPMSS
|
|
||||||
expectedCwnd := renoCwnd(currentCwnd)
|
|
||||||
// Initialize the state.
|
|
||||||
clock.Advance(time.Millisecond)
|
|
||||||
Expect(cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())).To(Equal(expectedCwnd))
|
|
||||||
|
|
||||||
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
|
|
||||||
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
|
|
||||||
currentCwnd = expectedCwnd
|
|
||||||
// First update after loss to initialize the epoch.
|
|
||||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
|
||||||
// Cubic phase.
|
|
||||||
for i := 0; i < 40; i++ {
|
|
||||||
clock.Advance(100 * time.Millisecond)
|
|
||||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
|
||||||
}
|
|
||||||
expectedCwnd = 553632
|
|
||||||
Expect(currentCwnd).To(Equal(expectedCwnd))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,75 +0,0 @@
|
|||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Hybrid slow start", func() {
|
|
||||||
var (
|
|
||||||
slowStart HybridSlowStart
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
slowStart = HybridSlowStart{}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works in a simple case", func() {
|
|
||||||
packetNumber := protocol.PacketNumber(1)
|
|
||||||
endPacketNumber := protocol.PacketNumber(3)
|
|
||||||
slowStart.StartReceiveRound(endPacketNumber)
|
|
||||||
|
|
||||||
packetNumber++
|
|
||||||
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
|
|
||||||
|
|
||||||
// Test duplicates.
|
|
||||||
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
|
|
||||||
|
|
||||||
packetNumber++
|
|
||||||
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
|
|
||||||
packetNumber++
|
|
||||||
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue())
|
|
||||||
|
|
||||||
// Test without a new registered end_packet_number;
|
|
||||||
packetNumber++
|
|
||||||
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue())
|
|
||||||
|
|
||||||
endPacketNumber = 20
|
|
||||||
slowStart.StartReceiveRound(endPacketNumber)
|
|
||||||
for packetNumber < endPacketNumber {
|
|
||||||
packetNumber++
|
|
||||||
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
|
|
||||||
}
|
|
||||||
packetNumber++
|
|
||||||
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works with delay", func() {
|
|
||||||
rtt := 60 * time.Millisecond
|
|
||||||
// We expect to detect the increase at +1/8 of the RTT; hence at a typical
|
|
||||||
// RTT of 60ms the detection will happen at 67.5 ms.
|
|
||||||
const hybridStartMinSamples = 8 // Number of acks required to trigger.
|
|
||||||
|
|
||||||
endPacketNumber := protocol.PacketNumber(1)
|
|
||||||
endPacketNumber++
|
|
||||||
slowStart.StartReceiveRound(endPacketNumber)
|
|
||||||
|
|
||||||
// Will not trigger since our lowest RTT in our burst is the same as the long
|
|
||||||
// term RTT provided.
|
|
||||||
for n := 0; n < hybridStartMinSamples; n++ {
|
|
||||||
Expect(slowStart.ShouldExitSlowStart(rtt+time.Duration(n)*time.Millisecond, rtt, 100)).To(BeFalse())
|
|
||||||
}
|
|
||||||
endPacketNumber++
|
|
||||||
slowStart.StartReceiveRound(endPacketNumber)
|
|
||||||
for n := 1; n < hybridStartMinSamples; n++ {
|
|
||||||
Expect(slowStart.ShouldExitSlowStart(rtt+(time.Duration(n)+10)*time.Millisecond, rtt, 100)).To(BeFalse())
|
|
||||||
}
|
|
||||||
// Expect to trigger since all packets in this burst was above the long term
|
|
||||||
// RTT provided.
|
|
||||||
Expect(slowStart.ShouldExitSlowStart(rtt+10*time.Millisecond, rtt, 100)).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
})
|
|
@ -1,107 +0,0 @@
|
|||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("PRR sender", func() {
|
|
||||||
var (
|
|
||||||
prr PrrSender
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
prr = PrrSender{}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("single loss results in send on every other ack", func() {
|
|
||||||
numPacketsInFlight := protocol.ByteCount(50)
|
|
||||||
bytesInFlight := numPacketsInFlight * protocol.DefaultTCPMSS
|
|
||||||
sshthreshAfterLoss := numPacketsInFlight / 2
|
|
||||||
congestionWindow := sshthreshAfterLoss * protocol.DefaultTCPMSS
|
|
||||||
|
|
||||||
prr.OnPacketLost(bytesInFlight)
|
|
||||||
// Ack a packet. PRR allows one packet to leave immediately.
|
|
||||||
prr.OnPacketAcked(protocol.DefaultTCPMSS)
|
|
||||||
bytesInFlight -= protocol.DefaultTCPMSS
|
|
||||||
Expect(prr.CanSend(congestionWindow, bytesInFlight, sshthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeTrue())
|
|
||||||
// Send retransmission.
|
|
||||||
prr.OnPacketSent(protocol.DefaultTCPMSS)
|
|
||||||
// PRR shouldn't allow sending any more packets.
|
|
||||||
Expect(prr.CanSend(congestionWindow, bytesInFlight, sshthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeFalse())
|
|
||||||
|
|
||||||
// One packet is lost, and one ack was consumed above. PRR now paces
|
|
||||||
// transmissions through the remaining 48 acks. PRR will alternatively
|
|
||||||
// disallow and allow a packet to be sent in response to an ack.
|
|
||||||
for i := protocol.ByteCount(0); i < sshthreshAfterLoss-1; i++ {
|
|
||||||
// Ack a packet. PRR shouldn't allow sending a packet in response.
|
|
||||||
prr.OnPacketAcked(protocol.DefaultTCPMSS)
|
|
||||||
bytesInFlight -= protocol.DefaultTCPMSS
|
|
||||||
Expect(prr.CanSend(congestionWindow, bytesInFlight, sshthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeFalse())
|
|
||||||
// Ack another packet. PRR should now allow sending a packet in response.
|
|
||||||
prr.OnPacketAcked(protocol.DefaultTCPMSS)
|
|
||||||
bytesInFlight -= protocol.DefaultTCPMSS
|
|
||||||
Expect(prr.CanSend(congestionWindow, bytesInFlight, sshthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeTrue())
|
|
||||||
// Send a packet in response.
|
|
||||||
prr.OnPacketSent(protocol.DefaultTCPMSS)
|
|
||||||
bytesInFlight += protocol.DefaultTCPMSS
|
|
||||||
}
|
|
||||||
|
|
||||||
// Since bytes_in_flight is now equal to congestion_window, PRR now maintains
|
|
||||||
// packet conservation, allowing one packet to be sent in response to an ack.
|
|
||||||
Expect(bytesInFlight).To(Equal(congestionWindow))
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
// Ack a packet.
|
|
||||||
prr.OnPacketAcked(protocol.DefaultTCPMSS)
|
|
||||||
bytesInFlight -= protocol.DefaultTCPMSS
|
|
||||||
Expect(prr.CanSend(congestionWindow, bytesInFlight, sshthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeTrue())
|
|
||||||
// Send a packet in response, since PRR allows it.
|
|
||||||
prr.OnPacketSent(protocol.DefaultTCPMSS)
|
|
||||||
bytesInFlight += protocol.DefaultTCPMSS
|
|
||||||
|
|
||||||
// Since bytes_in_flight is equal to the congestion_window,
|
|
||||||
// PRR disallows sending.
|
|
||||||
Expect(bytesInFlight).To(Equal(congestionWindow))
|
|
||||||
Expect(prr.CanSend(congestionWindow, bytesInFlight, sshthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeFalse())
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
It("burst loss results in slow start", func() {
|
|
||||||
bytesInFlight := protocol.ByteCount(20 * protocol.DefaultTCPMSS)
|
|
||||||
const numPacketsLost = 13
|
|
||||||
const ssthreshAfterLoss = 10
|
|
||||||
const congestionWindow = ssthreshAfterLoss * protocol.DefaultTCPMSS
|
|
||||||
|
|
||||||
// Lose 13 packets.
|
|
||||||
bytesInFlight -= numPacketsLost * protocol.DefaultTCPMSS
|
|
||||||
prr.OnPacketLost(bytesInFlight)
|
|
||||||
|
|
||||||
// PRR-SSRB will allow the following 3 acks to send up to 2 packets.
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
prr.OnPacketAcked(protocol.DefaultTCPMSS)
|
|
||||||
bytesInFlight -= protocol.DefaultTCPMSS
|
|
||||||
// PRR-SSRB should allow two packets to be sent.
|
|
||||||
for j := 0; j < 2; j++ {
|
|
||||||
Expect(prr.CanSend(congestionWindow, bytesInFlight, ssthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeTrue())
|
|
||||||
// Send a packet in response.
|
|
||||||
prr.OnPacketSent(protocol.DefaultTCPMSS)
|
|
||||||
bytesInFlight += protocol.DefaultTCPMSS
|
|
||||||
}
|
|
||||||
// PRR should allow no more than 2 packets in response to an ack.
|
|
||||||
Expect(prr.CanSend(congestionWindow, bytesInFlight, ssthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeFalse())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Out of SSRB mode, PRR allows one send in response to each ack.
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
prr.OnPacketAcked(protocol.DefaultTCPMSS)
|
|
||||||
bytesInFlight -= protocol.DefaultTCPMSS
|
|
||||||
Expect(prr.CanSend(congestionWindow, bytesInFlight, ssthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeTrue())
|
|
||||||
// Send a packet in response.
|
|
||||||
prr.OnPacketSent(protocol.DefaultTCPMSS)
|
|
||||||
bytesInFlight += protocol.DefaultTCPMSS
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,131 +0,0 @@
|
|||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("RTT stats", func() {
|
|
||||||
var (
|
|
||||||
rttStats *RTTStats
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
rttStats = NewRTTStats()
|
|
||||||
})
|
|
||||||
|
|
||||||
It("DefaultsBeforeUpdate", func() {
|
|
||||||
Expect(rttStats.MinRTT()).To(Equal(time.Duration(0)))
|
|
||||||
Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(0)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("SmoothedRTT", func() {
|
|
||||||
// Verify that ack_delay is ignored in the first measurement.
|
|
||||||
rttStats.UpdateRTT((300 * time.Millisecond), (100 * time.Millisecond), time.Time{})
|
|
||||||
Expect(rttStats.LatestRTT()).To(Equal((300 * time.Millisecond)))
|
|
||||||
Expect(rttStats.SmoothedRTT()).To(Equal((300 * time.Millisecond)))
|
|
||||||
// Verify that Smoothed RTT includes max ack delay if it's reasonable.
|
|
||||||
rttStats.UpdateRTT((350 * time.Millisecond), (50 * time.Millisecond), time.Time{})
|
|
||||||
Expect(rttStats.LatestRTT()).To(Equal((300 * time.Millisecond)))
|
|
||||||
Expect(rttStats.SmoothedRTT()).To(Equal((300 * time.Millisecond)))
|
|
||||||
// Verify that large erroneous ack_delay does not change Smoothed RTT.
|
|
||||||
rttStats.UpdateRTT((200 * time.Millisecond), (300 * time.Millisecond), time.Time{})
|
|
||||||
Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond)))
|
|
||||||
Expect(rttStats.SmoothedRTT()).To(Equal((287500 * time.Microsecond)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("SmoothedOrInitialRTT", func() {
|
|
||||||
Expect(rttStats.SmoothedOrInitialRTT()).To(Equal(defaultInitialRTT))
|
|
||||||
rttStats.UpdateRTT((300 * time.Millisecond), (100 * time.Millisecond), time.Time{})
|
|
||||||
Expect(rttStats.SmoothedOrInitialRTT()).To(Equal((300 * time.Millisecond)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("MinRTT", func() {
|
|
||||||
rttStats.UpdateRTT((200 * time.Millisecond), 0, time.Time{})
|
|
||||||
Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond)))
|
|
||||||
rttStats.UpdateRTT((10 * time.Millisecond), 0, time.Time{}.Add((10 * time.Millisecond)))
|
|
||||||
Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond)))
|
|
||||||
rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((20 * time.Millisecond)))
|
|
||||||
Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond)))
|
|
||||||
rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((30 * time.Millisecond)))
|
|
||||||
Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond)))
|
|
||||||
rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((40 * time.Millisecond)))
|
|
||||||
Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond)))
|
|
||||||
// Verify that ack_delay does not go into recording of MinRTT_.
|
|
||||||
rttStats.UpdateRTT((7 * time.Millisecond), (2 * time.Millisecond), time.Time{}.Add((50 * time.Millisecond)))
|
|
||||||
Expect(rttStats.MinRTT()).To(Equal((7 * time.Millisecond)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("ExpireSmoothedMetrics", func() {
|
|
||||||
initialRtt := (10 * time.Millisecond)
|
|
||||||
rttStats.UpdateRTT(initialRtt, 0, time.Time{})
|
|
||||||
Expect(rttStats.MinRTT()).To(Equal(initialRtt))
|
|
||||||
Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt))
|
|
||||||
|
|
||||||
Expect(rttStats.MeanDeviation()).To(Equal(initialRtt / 2))
|
|
||||||
|
|
||||||
// Update once with a 20ms RTT.
|
|
||||||
doubledRtt := initialRtt * (2)
|
|
||||||
rttStats.UpdateRTT(doubledRtt, 0, time.Time{})
|
|
||||||
Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(float32(initialRtt) * 1.125)))
|
|
||||||
|
|
||||||
// Expire the smoothed metrics, increasing smoothed rtt and mean deviation.
|
|
||||||
rttStats.ExpireSmoothedMetrics()
|
|
||||||
Expect(rttStats.SmoothedRTT()).To(Equal(doubledRtt))
|
|
||||||
Expect(rttStats.MeanDeviation()).To(Equal(time.Duration(float32(initialRtt) * 0.875)))
|
|
||||||
|
|
||||||
// Now go back down to 5ms and expire the smoothed metrics, and ensure the
|
|
||||||
// mean deviation increases to 15ms.
|
|
||||||
halfRtt := initialRtt / 2
|
|
||||||
rttStats.UpdateRTT(halfRtt, 0, time.Time{})
|
|
||||||
Expect(doubledRtt).To(BeNumerically(">", rttStats.SmoothedRTT()))
|
|
||||||
Expect(initialRtt).To(BeNumerically("<", rttStats.MeanDeviation()))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("UpdateRTTWithBadSendDeltas", func() {
|
|
||||||
// Make sure we ignore bad RTTs.
|
|
||||||
// base::test::MockLog log;
|
|
||||||
|
|
||||||
initialRtt := (10 * time.Millisecond)
|
|
||||||
rttStats.UpdateRTT(initialRtt, 0, time.Time{})
|
|
||||||
Expect(rttStats.MinRTT()).To(Equal(initialRtt))
|
|
||||||
Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt))
|
|
||||||
|
|
||||||
badSendDeltas := []time.Duration{
|
|
||||||
0,
|
|
||||||
utils.InfDuration,
|
|
||||||
-1000 * time.Microsecond,
|
|
||||||
}
|
|
||||||
// log.StartCapturingLogs();
|
|
||||||
|
|
||||||
for _, badSendDelta := range badSendDeltas {
|
|
||||||
// SCOPED_TRACE(Message() << "bad_send_delta = "
|
|
||||||
// << bad_send_delta.ToMicroseconds());
|
|
||||||
// EXPECT_CALL(log, Log(LOG_WARNING, _, _, _, HasSubstr("Ignoring")));
|
|
||||||
rttStats.UpdateRTT(badSendDelta, 0, time.Time{})
|
|
||||||
Expect(rttStats.MinRTT()).To(Equal(initialRtt))
|
|
||||||
Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("ResetAfterConnectionMigrations", func() {
|
|
||||||
rttStats.UpdateRTT((200 * time.Millisecond), 0, time.Time{})
|
|
||||||
Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond)))
|
|
||||||
Expect(rttStats.SmoothedRTT()).To(Equal((200 * time.Millisecond)))
|
|
||||||
Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond)))
|
|
||||||
rttStats.UpdateRTT((300 * time.Millisecond), (100 * time.Millisecond), time.Time{})
|
|
||||||
Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond)))
|
|
||||||
Expect(rttStats.SmoothedRTT()).To(Equal((200 * time.Millisecond)))
|
|
||||||
Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond)))
|
|
||||||
|
|
||||||
// Reset rtt stats on connection migrations.
|
|
||||||
rttStats.OnConnectionMigration()
|
|
||||||
Expect(rttStats.LatestRTT()).To(Equal(time.Duration(0)))
|
|
||||||
Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(0)))
|
|
||||||
Expect(rttStats.MinRTT()).To(Equal(time.Duration(0)))
|
|
||||||
})
|
|
||||||
|
|
||||||
})
|
|
@ -1,69 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("AES-GCM", func() {
|
|
||||||
var (
|
|
||||||
alice, bob AEAD
|
|
||||||
keyAlice, keyBob, ivAlice, ivBob []byte
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
keyAlice = make([]byte, 16)
|
|
||||||
keyBob = make([]byte, 16)
|
|
||||||
ivAlice = make([]byte, 4)
|
|
||||||
ivBob = make([]byte, 4)
|
|
||||||
rand.Reader.Read(keyAlice)
|
|
||||||
rand.Reader.Read(keyBob)
|
|
||||||
rand.Reader.Read(ivAlice)
|
|
||||||
rand.Reader.Read(ivBob)
|
|
||||||
var err error
|
|
||||||
alice, err = NewAEADAESGCM12(keyBob, keyAlice, ivBob, ivAlice)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
bob, err = NewAEADAESGCM12(keyAlice, keyBob, ivAlice, ivBob)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("seals and opens", func() {
|
|
||||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
text, err := bob.Open(nil, b, 42, []byte("aad"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(text).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("seals and opens reverse", func() {
|
|
||||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
text, err := alice.Open(nil, b, 42, []byte("aad"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(text).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("has the proper length", func() {
|
|
||||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
Expect(b).To(HaveLen(6 + bob.Overhead()))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fails with wrong aad", func() {
|
|
||||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
_, err := bob.Open(nil, b, 42, []byte("aad2"))
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects wrong key and iv sizes", func() {
|
|
||||||
var err error
|
|
||||||
e := "AES-GCM: expected 16-byte keys and 4-byte IVs"
|
|
||||||
_, err = NewAEADAESGCM12(keyBob[1:], keyAlice, ivBob, ivAlice)
|
|
||||||
Expect(err).To(MatchError(e))
|
|
||||||
_, err = NewAEADAESGCM12(keyBob, keyAlice[1:], ivBob, ivAlice)
|
|
||||||
Expect(err).To(MatchError(e))
|
|
||||||
_, err = NewAEADAESGCM12(keyBob, keyAlice, ivBob[1:], ivAlice)
|
|
||||||
Expect(err).To(MatchError(e))
|
|
||||||
_, err = NewAEADAESGCM12(keyBob, keyAlice, ivBob, ivAlice[1:])
|
|
||||||
Expect(err).To(MatchError(e))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,84 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("AES-GCM", func() {
|
|
||||||
var (
|
|
||||||
alice, bob AEAD
|
|
||||||
keyAlice, keyBob, ivAlice, ivBob []byte
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
ivAlice = make([]byte, 12)
|
|
||||||
ivBob = make([]byte, 12)
|
|
||||||
})
|
|
||||||
|
|
||||||
// 16 bytes for TLS_AES_128_GCM_SHA256
|
|
||||||
// 32 bytes for TLS_AES_256_GCM_SHA384
|
|
||||||
for _, ks := range []int{16, 32} {
|
|
||||||
keySize := ks
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("with %d byte keys", keySize), func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
keyAlice = make([]byte, keySize)
|
|
||||||
keyBob = make([]byte, keySize)
|
|
||||||
rand.Reader.Read(keyAlice)
|
|
||||||
rand.Reader.Read(keyBob)
|
|
||||||
rand.Reader.Read(ivAlice)
|
|
||||||
rand.Reader.Read(ivBob)
|
|
||||||
var err error
|
|
||||||
alice, err = NewAEADAESGCM(keyBob, keyAlice, ivBob, ivAlice)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
bob, err = NewAEADAESGCM(keyAlice, keyBob, ivAlice, ivBob)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("seals and opens", func() {
|
|
||||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
text, err := bob.Open(nil, b, 42, []byte("aad"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(text).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("seals and opens reverse", func() {
|
|
||||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
text, err := alice.Open(nil, b, 42, []byte("aad"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(text).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("has the proper length", func() {
|
|
||||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
Expect(b).To(HaveLen(6 + bob.Overhead()))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fails with wrong aad", func() {
|
|
||||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
_, err := bob.Open(nil, b, 42, []byte("aad2"))
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects wrong key and iv sizes", func() {
|
|
||||||
e := "AES-GCM: expected 12 byte IVs"
|
|
||||||
var err error
|
|
||||||
_, err = NewAEADAESGCM(keyBob, keyAlice, ivBob[1:], ivAlice)
|
|
||||||
Expect(err).To(MatchError(e))
|
|
||||||
_, err = NewAEADAESGCM(keyBob, keyAlice, ivBob, ivAlice[1:])
|
|
||||||
Expect(err).To(MatchError(e))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
It("errors when an invalid key size is used", func() {
|
|
||||||
keyAlice = make([]byte, 17)
|
|
||||||
keyBob = make([]byte, 17)
|
|
||||||
_, err := NewAEADAESGCM(keyBob, keyAlice, ivBob, ivAlice)
|
|
||||||
Expect(err).To(MatchError("crypto/aes: invalid key size 17"))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,51 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
lru "github.com/hashicorp/golang-lru"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Certificate cache", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
var err error
|
|
||||||
compressedCertsCache, err = lru.New(2)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gives a compressed cert", func() {
|
|
||||||
chain := [][]byte{{0xde, 0xca, 0xfb, 0xad}}
|
|
||||||
expected, err := compressChain(chain, nil, nil)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
compressed, err := getCompressedCert(chain, nil, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(compressed).To(Equal(expected))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets the same result multiple times", func() {
|
|
||||||
chain := [][]byte{{0xde, 0xca, 0xfb, 0xad}}
|
|
||||||
compressed, err := getCompressedCert(chain, nil, nil)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
compressed2, err := getCompressedCert(chain, nil, nil)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(compressed).To(Equal(compressed2))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("stores cached values", func() {
|
|
||||||
chain := [][]byte{{0xde, 0xca, 0xfb, 0xad}}
|
|
||||||
_, err := getCompressedCert(chain, nil, nil)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(compressedCertsCache.Len()).To(Equal(1))
|
|
||||||
Expect(compressedCertsCache.Contains(uint64(3838929964809501833))).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("evicts old values", func() {
|
|
||||||
_, err := getCompressedCert([][]byte{{0x00}}, nil, nil)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
_, err = getCompressedCert([][]byte{{0x01}}, nil, nil)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
_, err = getCompressedCert([][]byte{{0x02}}, nil, nil)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(compressedCertsCache.Len()).To(Equal(2))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,148 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"compress/flate"
|
|
||||||
"compress/zlib"
|
|
||||||
"crypto/tls"
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Proof", func() {
|
|
||||||
var (
|
|
||||||
cc *certChain
|
|
||||||
config *tls.Config
|
|
||||||
cert tls.Certificate
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
cert = testdata.GetCertificate()
|
|
||||||
config = &tls.Config{}
|
|
||||||
cc = NewCertChain(config).(*certChain)
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("certificate compression", func() {
|
|
||||||
It("compresses certs", func() {
|
|
||||||
cert := []byte{0xde, 0xca, 0xfb, 0xad}
|
|
||||||
certZlib := &bytes.Buffer{}
|
|
||||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
|
||||||
z.Write(cert)
|
|
||||||
z.Close()
|
|
||||||
kd := &certChain{
|
|
||||||
config: &tls.Config{
|
|
||||||
Certificates: []tls.Certificate{
|
|
||||||
{Certificate: [][]byte{cert}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
certCompressed, err := kd.GetCertsCompressed("", nil, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(certCompressed).To(Equal(append([]byte{
|
|
||||||
0x01, 0x00,
|
|
||||||
0x08, 0x00, 0x00, 0x00,
|
|
||||||
}, certZlib.Bytes()...)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when it can't retrieve a certificate", func() {
|
|
||||||
_, err := cc.GetCertsCompressed("invalid domain", nil, nil)
|
|
||||||
Expect(err).To(MatchError(errNoMatchingCertificate))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("signing server configs", func() {
|
|
||||||
It("errors when it can't retrieve a certificate for the requested SNI", func() {
|
|
||||||
_, err := cc.SignServerProof("invalid", []byte("chlo"), []byte("scfg"))
|
|
||||||
Expect(err).To(MatchError(errNoMatchingCertificate))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("signs the server config", func() {
|
|
||||||
config.Certificates = []tls.Certificate{cert}
|
|
||||||
proof, err := cc.SignServerProof("", []byte("chlo"), []byte("scfg"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(proof).ToNot(BeEmpty())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("retrieving certificates", func() {
|
|
||||||
It("errors without certificates", func() {
|
|
||||||
_, err := cc.getCertForSNI("")
|
|
||||||
Expect(err).To(MatchError(errNoMatchingCertificate))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses first certificate in config.Certificates", func() {
|
|
||||||
config.Certificates = []tls.Certificate{cert}
|
|
||||||
cert, err := cc.getCertForSNI("")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(cert.PrivateKey).ToNot(BeNil())
|
|
||||||
Expect(cert.Certificate[0]).ToNot(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses NameToCertificate entries", func() {
|
|
||||||
config.Certificates = []tls.Certificate{cert, cert} // two entries so the long path is used
|
|
||||||
config.NameToCertificate = map[string]*tls.Certificate{
|
|
||||||
"quic.clemente.io": &cert,
|
|
||||||
}
|
|
||||||
cert, err := cc.getCertForSNI("quic.clemente.io")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(cert.PrivateKey).ToNot(BeNil())
|
|
||||||
Expect(cert.Certificate[0]).ToNot(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses NameToCertificate entries with wildcard", func() {
|
|
||||||
config.Certificates = []tls.Certificate{cert, cert} // two entries so the long path is used
|
|
||||||
config.NameToCertificate = map[string]*tls.Certificate{
|
|
||||||
"*.clemente.io": &cert,
|
|
||||||
}
|
|
||||||
cert, err := cc.getCertForSNI("quic.clemente.io")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(cert.PrivateKey).ToNot(BeNil())
|
|
||||||
Expect(cert.Certificate[0]).ToNot(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses GetCertificate", func() {
|
|
||||||
config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
||||||
Expect(clientHello.ServerName).To(Equal("quic.clemente.io"))
|
|
||||||
return &cert, nil
|
|
||||||
}
|
|
||||||
cert, err := cc.getCertForSNI("quic.clemente.io")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(cert.PrivateKey).ToNot(BeNil())
|
|
||||||
Expect(cert.Certificate[0]).ToNot(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets leaf certificates", func() {
|
|
||||||
config.Certificates = []tls.Certificate{cert}
|
|
||||||
cert2, err := cc.GetLeafCert("")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(cert2).To(Equal(cert.Certificate[0]))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when it can't retrieve a leaf certificate", func() {
|
|
||||||
_, err := cc.GetLeafCert("invalid domain")
|
|
||||||
Expect(err).To(MatchError(errNoMatchingCertificate))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("respects GetConfigForClient", func() {
|
|
||||||
if !reflect.ValueOf(tls.Config{}).FieldByName("GetConfigForClient").IsValid() {
|
|
||||||
// Pre 1.8, we don't have to do anything
|
|
||||||
return
|
|
||||||
}
|
|
||||||
nestedConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
|
|
||||||
l := func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
||||||
Expect(chi.ServerName).To(Equal("quic.clemente.io"))
|
|
||||||
return nestedConfig, nil
|
|
||||||
}
|
|
||||||
reflect.ValueOf(config).Elem().FieldByName("GetConfigForClient").Set(reflect.ValueOf(l))
|
|
||||||
resultCert, err := cc.getCertForSNI("quic.clemente.io")
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(*resultCert).To(Equal(cert))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,294 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"compress/flate"
|
|
||||||
"compress/zlib"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"hash/fnv"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go-certificates"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
func byteHash(d []byte) []byte {
|
|
||||||
h := fnv.New64a()
|
|
||||||
h.Write(d)
|
|
||||||
s := h.Sum64()
|
|
||||||
res := make([]byte, 8)
|
|
||||||
binary.LittleEndian.PutUint64(res, s)
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("Cert compression and decompression", func() {
|
|
||||||
var certSetsOld map[uint64]certSet
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
certSetsOld = make(map[uint64]certSet)
|
|
||||||
for s := range certSets {
|
|
||||||
certSetsOld[s] = certSets[s]
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
certSets = certSetsOld
|
|
||||||
})
|
|
||||||
|
|
||||||
It("compresses empty", func() {
|
|
||||||
compressed, err := compressChain(nil, nil, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(compressed).To(Equal([]byte{0}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("decompresses empty", func() {
|
|
||||||
compressed, err := compressChain(nil, nil, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
uncompressed, err := decompressChain(compressed)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(uncompressed).To(BeEmpty())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gives correct single cert", func() {
|
|
||||||
cert := []byte{0xde, 0xca, 0xfb, 0xad}
|
|
||||||
certZlib := &bytes.Buffer{}
|
|
||||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
|
||||||
z.Write(cert)
|
|
||||||
z.Close()
|
|
||||||
chain := [][]byte{cert}
|
|
||||||
compressed, err := compressChain(chain, nil, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(compressed).To(Equal(append([]byte{
|
|
||||||
0x01, 0x00,
|
|
||||||
0x08, 0x00, 0x00, 0x00,
|
|
||||||
}, certZlib.Bytes()...)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("decompresses a single cert", func() {
|
|
||||||
cert := []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}
|
|
||||||
chain := [][]byte{cert}
|
|
||||||
compressed, err := compressChain(chain, nil, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
uncompressed, err := decompressChain(compressed)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(uncompressed).To(Equal(chain))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gives correct cert and intermediate", func() {
|
|
||||||
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
|
|
||||||
cert2 := []byte{0xde, 0xad, 0xbe, 0xef}
|
|
||||||
certZlib := &bytes.Buffer{}
|
|
||||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
|
||||||
z.Write(cert1)
|
|
||||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
|
||||||
z.Write(cert2)
|
|
||||||
z.Close()
|
|
||||||
chain := [][]byte{cert1, cert2}
|
|
||||||
compressed, err := compressChain(chain, nil, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(compressed).To(Equal(append([]byte{
|
|
||||||
0x01, 0x01, 0x00,
|
|
||||||
0x10, 0x00, 0x00, 0x00,
|
|
||||||
}, certZlib.Bytes()...)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("decompresses the chain with a cert and an intermediate", func() {
|
|
||||||
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
|
|
||||||
cert2 := []byte{0xde, 0xad, 0xbe, 0xef}
|
|
||||||
chain := [][]byte{cert1, cert2}
|
|
||||||
compressed, err := compressChain(chain, nil, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
decompressed, err := decompressChain(compressed)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(decompressed).To(Equal(chain))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses cached certificates", func() {
|
|
||||||
cert := []byte{0xde, 0xca, 0xfb, 0xad}
|
|
||||||
certHash := byteHash(cert)
|
|
||||||
chain := [][]byte{cert}
|
|
||||||
compressed, err := compressChain(chain, nil, certHash)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
expected := append([]byte{0x02}, certHash...)
|
|
||||||
expected = append(expected, 0x00)
|
|
||||||
Expect(compressed).To(Equal(expected))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses cached certificates and compressed combined", func() {
|
|
||||||
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
|
|
||||||
cert2 := []byte{0xde, 0xad, 0xbe, 0xef}
|
|
||||||
cert2Hash := byteHash(cert2)
|
|
||||||
certZlib := &bytes.Buffer{}
|
|
||||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, append(cert2, certDictZlib...))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
|
||||||
z.Write(cert1)
|
|
||||||
z.Close()
|
|
||||||
chain := [][]byte{cert1, cert2}
|
|
||||||
compressed, err := compressChain(chain, nil, cert2Hash)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
expected := []byte{0x01, 0x02}
|
|
||||||
expected = append(expected, cert2Hash...)
|
|
||||||
expected = append(expected, 0x00)
|
|
||||||
expected = append(expected, []byte{0x08, 0, 0, 0}...)
|
|
||||||
expected = append(expected, certZlib.Bytes()...)
|
|
||||||
Expect(compressed).To(Equal(expected))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses common certificate sets", func() {
|
|
||||||
cert := certsets.CertSet3[42]
|
|
||||||
setHash := make([]byte, 8)
|
|
||||||
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
|
|
||||||
chain := [][]byte{cert}
|
|
||||||
compressed, err := compressChain(chain, setHash, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
expected := []byte{0x03}
|
|
||||||
expected = append(expected, setHash...)
|
|
||||||
expected = append(expected, []byte{42, 0, 0, 0}...)
|
|
||||||
expected = append(expected, 0x00)
|
|
||||||
Expect(compressed).To(Equal(expected))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("decompresses a single cert form a common certificate set", func() {
|
|
||||||
cert := certsets.CertSet3[42]
|
|
||||||
setHash := make([]byte, 8)
|
|
||||||
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
|
|
||||||
chain := [][]byte{cert}
|
|
||||||
compressed, err := compressChain(chain, setHash, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
decompressed, err := decompressChain(compressed)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(decompressed).To(Equal(chain))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("decompresses multiple certs form common certificate sets", func() {
|
|
||||||
cert1 := certsets.CertSet3[42]
|
|
||||||
cert2 := certsets.CertSet2[24]
|
|
||||||
setHash := make([]byte, 16)
|
|
||||||
binary.LittleEndian.PutUint64(setHash[0:8], certsets.CertSet3Hash)
|
|
||||||
binary.LittleEndian.PutUint64(setHash[8:16], certsets.CertSet2Hash)
|
|
||||||
chain := [][]byte{cert1, cert2}
|
|
||||||
compressed, err := compressChain(chain, setHash, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
decompressed, err := decompressChain(compressed)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(decompressed).To(Equal(chain))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("ignores uncommon certificate sets", func() {
|
|
||||||
cert := []byte{0xde, 0xca, 0xfb, 0xad}
|
|
||||||
setHash := make([]byte, 8)
|
|
||||||
binary.LittleEndian.PutUint64(setHash, 0xdeadbeef)
|
|
||||||
chain := [][]byte{cert}
|
|
||||||
compressed, err := compressChain(chain, setHash, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
certZlib := &bytes.Buffer{}
|
|
||||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
|
||||||
z.Write(cert)
|
|
||||||
z.Close()
|
|
||||||
Expect(compressed).To(Equal(append([]byte{
|
|
||||||
0x01, 0x00,
|
|
||||||
0x08, 0x00, 0x00, 0x00,
|
|
||||||
}, certZlib.Bytes()...)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if a common set does not exist", func() {
|
|
||||||
cert := certsets.CertSet3[42]
|
|
||||||
setHash := make([]byte, 8)
|
|
||||||
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
|
|
||||||
chain := [][]byte{cert}
|
|
||||||
compressed, err := compressChain(chain, setHash, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
delete(certSets, certsets.CertSet3Hash)
|
|
||||||
_, err = decompressChain(compressed)
|
|
||||||
Expect(err).To(MatchError(errors.New("unknown certSet")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if a cert in a common set does not exist", func() {
|
|
||||||
certSet := [][]byte{
|
|
||||||
{0x1, 0x2, 0x3, 0x4},
|
|
||||||
{0x5, 0x6, 0x7, 0x8},
|
|
||||||
}
|
|
||||||
certSets[0x1337] = certSet
|
|
||||||
cert := certSet[1]
|
|
||||||
setHash := make([]byte, 8)
|
|
||||||
binary.LittleEndian.PutUint64(setHash, 0x1337)
|
|
||||||
chain := [][]byte{cert}
|
|
||||||
compressed, err := compressChain(chain, setHash, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
certSets[0x1337] = certSet[:1] // delete the last certificate from the certSet
|
|
||||||
_, err = decompressChain(compressed)
|
|
||||||
Expect(err).To(MatchError(errors.New("certificate not found in certSet")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses common certificates and compressed combined", func() {
|
|
||||||
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
|
|
||||||
cert2 := certsets.CertSet3[42]
|
|
||||||
setHash := make([]byte, 8)
|
|
||||||
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
|
|
||||||
certZlib := &bytes.Buffer{}
|
|
||||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, append(cert2, certDictZlib...))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
|
||||||
z.Write(cert1)
|
|
||||||
z.Close()
|
|
||||||
chain := [][]byte{cert1, cert2}
|
|
||||||
compressed, err := compressChain(chain, setHash, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
expected := []byte{0x01, 0x03}
|
|
||||||
expected = append(expected, setHash...)
|
|
||||||
expected = append(expected, []byte{42, 0, 0, 0}...)
|
|
||||||
expected = append(expected, 0x00)
|
|
||||||
expected = append(expected, []byte{0x08, 0, 0, 0}...)
|
|
||||||
expected = append(expected, certZlib.Bytes()...)
|
|
||||||
Expect(compressed).To(Equal(expected))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("decompresses a certficate from a common set and a compressed cert combined", func() {
|
|
||||||
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
|
|
||||||
cert2 := certsets.CertSet3[42]
|
|
||||||
setHash := make([]byte, 8)
|
|
||||||
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
|
|
||||||
chain := [][]byte{cert1, cert2}
|
|
||||||
compressed, err := compressChain(chain, setHash, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
decompressed, err := decompressChain(compressed)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(decompressed).To(Equal(chain))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects invalid CCS / CCRT hashes", func() {
|
|
||||||
cert := []byte{0xde, 0xca, 0xfb, 0xad}
|
|
||||||
chain := [][]byte{cert}
|
|
||||||
_, err := compressChain(chain, []byte("foo"), nil)
|
|
||||||
Expect(err).To(MatchError("expected a multiple of 8 bytes for CCS / CCRT hashes"))
|
|
||||||
_, err = compressChain(chain, nil, []byte("foo"))
|
|
||||||
Expect(err).To(MatchError("expected a multiple of 8 bytes for CCS / CCRT hashes"))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("common certificate hashes", func() {
|
|
||||||
It("gets the hashes", func() {
|
|
||||||
ccs := getCommonCertificateHashes()
|
|
||||||
Expect(ccs).ToNot(BeEmpty())
|
|
||||||
hashes, err := splitHashes(ccs)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
for _, hash := range hashes {
|
|
||||||
Expect(certSets).To(HaveKey(hash))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns an empty slice if there are not common sets", func() {
|
|
||||||
certSets = make(map[uint64]certSet)
|
|
||||||
ccs := getCommonCertificateHashes()
|
|
||||||
Expect(ccs).ToNot(BeNil())
|
|
||||||
Expect(ccs).To(HaveLen(0))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,348 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"crypto/x509/pkix"
|
|
||||||
"encoding/asn1"
|
|
||||||
"math/big"
|
|
||||||
"runtime"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Cert Manager", func() {
|
|
||||||
var cm *certManager
|
|
||||||
var key1, key2 *rsa.PrivateKey
|
|
||||||
var cert1, cert2 []byte
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
var err error
|
|
||||||
cm = NewCertManager(nil).(*certManager)
|
|
||||||
key1, err = rsa.GenerateKey(rand.Reader, 768)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
key2, err = rsa.GenerateKey(rand.Reader, 768)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
template := &x509.Certificate{SerialNumber: big.NewInt(1)}
|
|
||||||
cert1, err = x509.CreateCertificate(rand.Reader, template, template, &key1.PublicKey, key1)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cert2, err = x509.CreateCertificate(rand.Reader, template, template, &key2.PublicKey, key2)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("saves a client TLS config", func() {
|
|
||||||
tlsConf := &tls.Config{ServerName: "quic.clemente.io"}
|
|
||||||
cm = NewCertManager(tlsConf).(*certManager)
|
|
||||||
Expect(cm.config.ServerName).To(Equal("quic.clemente.io"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when given invalid data", func() {
|
|
||||||
err := cm.SetData([]byte("foobar"))
|
|
||||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets the common certificate hashes", func() {
|
|
||||||
ccs := cm.GetCommonCertificateHashes()
|
|
||||||
Expect(ccs).ToNot(BeEmpty())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("setting the data", func() {
|
|
||||||
It("decompresses a certificate chain", func() {
|
|
||||||
chain := [][]byte{cert1, cert2}
|
|
||||||
compressed, err := compressChain(chain, nil, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = cm.SetData(compressed)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(cm.chain[0].Raw).To(Equal(cert1))
|
|
||||||
Expect(cm.chain[1].Raw).To(Equal(cert2))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if it can't decompress the chain", func() {
|
|
||||||
err := cm.SetData([]byte("invalid data"))
|
|
||||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if it can't parse a certificate", func() {
|
|
||||||
chain := [][]byte{[]byte("cert1"), []byte("cert2")}
|
|
||||||
compressed, err := compressChain(chain, nil, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = cm.SetData(compressed)
|
|
||||||
_, ok := err.(asn1.StructuralError)
|
|
||||||
Expect(ok).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("getting the leaf cert", func() {
|
|
||||||
It("gets it", func() {
|
|
||||||
xcert1, err := x509.ParseCertificate(cert1)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
xcert2, err := x509.ParseCertificate(cert2)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cm.chain = []*x509.Certificate{xcert1, xcert2}
|
|
||||||
leafCert := cm.GetLeafCert()
|
|
||||||
Expect(leafCert).To(Equal(cert1))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns nil if the chain hasn't been set yet", func() {
|
|
||||||
leafCert := cm.GetLeafCert()
|
|
||||||
Expect(leafCert).To(BeNil())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("getting the leaf cert hash", func() {
|
|
||||||
It("calculates the FVN1a 64 hash", func() {
|
|
||||||
cm.chain = make([]*x509.Certificate, 1)
|
|
||||||
cm.chain[0] = &x509.Certificate{
|
|
||||||
Raw: []byte("test fnv hash"),
|
|
||||||
}
|
|
||||||
hash, err := cm.GetLeafCertHash()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
// hash calculated on http://www.nitrxgen.net/hashgen/
|
|
||||||
Expect(hash).To(Equal(uint64(0x4770f6141fa0f5ad)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the certificate chain is not loaded", func() {
|
|
||||||
_, err := cm.GetLeafCertHash()
|
|
||||||
Expect(err).To(MatchError(errNoCertificateChain))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("verifying the server config signature", func() {
|
|
||||||
It("returns false when the chain hasn't been set yet", func() {
|
|
||||||
valid := cm.VerifyServerProof([]byte("proof"), []byte("chlo"), []byte("scfg"))
|
|
||||||
Expect(valid).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("verifies the signature", func() {
|
|
||||||
chlo := []byte("client hello")
|
|
||||||
scfg := []byte("server config data")
|
|
||||||
xcert1, err := x509.ParseCertificate(cert1)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cm.chain = []*x509.Certificate{xcert1}
|
|
||||||
proof, err := signServerProof(&tls.Certificate{PrivateKey: key1}, chlo, scfg)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
valid := cm.VerifyServerProof(proof, chlo, scfg)
|
|
||||||
Expect(valid).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects an invalid signature", func() {
|
|
||||||
xcert1, err := x509.ParseCertificate(cert1)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cm.chain = []*x509.Certificate{xcert1}
|
|
||||||
valid := cm.VerifyServerProof([]byte("invalid proof"), []byte("chlo"), []byte("scfg"))
|
|
||||||
Expect(valid).To(BeFalse())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("verifying the certificate chain", func() {
|
|
||||||
generateCertificate := func(template, parent *x509.Certificate, pubKey *rsa.PublicKey, privKey *rsa.PrivateKey) *x509.Certificate {
|
|
||||||
certDER, err := x509.CreateCertificate(rand.Reader, template, parent, pubKey, privKey)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cert, err := x509.ParseCertificate(certDER)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
return cert
|
|
||||||
}
|
|
||||||
|
|
||||||
getCertificate := func(template *x509.Certificate) (*rsa.PrivateKey, *x509.Certificate) {
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
return key, generateCertificate(template, template, &key.PublicKey, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
It("accepts a valid certificate", func() {
|
|
||||||
cc := NewCertChain(testdata.GetTLSConfig()).(*certChain)
|
|
||||||
tlsCert, err := cc.getCertForSNI("quic.clemente.io")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
for _, data := range tlsCert.Certificate {
|
|
||||||
var cert *x509.Certificate
|
|
||||||
cert, err = x509.ParseCertificate(data)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cm.chain = append(cm.chain, cert)
|
|
||||||
}
|
|
||||||
err = cm.Verify("quic.clemente.io")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't accept an expired certificate", func() {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
|
|
||||||
Skip("windows")
|
|
||||||
}
|
|
||||||
|
|
||||||
template := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(1),
|
|
||||||
NotBefore: time.Now().Add(-25 * time.Hour),
|
|
||||||
NotAfter: time.Now().Add(-time.Hour),
|
|
||||||
}
|
|
||||||
_, leafCert := getCertificate(template)
|
|
||||||
|
|
||||||
cm.chain = []*x509.Certificate{leafCert}
|
|
||||||
err := cm.Verify("")
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.(x509.CertificateInvalidError).Reason).To(Equal(x509.Expired))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't accept a certificate that is not yet valid", func() {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
|
|
||||||
Skip("windows")
|
|
||||||
}
|
|
||||||
|
|
||||||
template := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(1),
|
|
||||||
NotBefore: time.Now().Add(time.Hour),
|
|
||||||
NotAfter: time.Now().Add(25 * time.Hour),
|
|
||||||
}
|
|
||||||
_, leafCert := getCertificate(template)
|
|
||||||
|
|
||||||
cm.chain = []*x509.Certificate{leafCert}
|
|
||||||
err := cm.Verify("")
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.(x509.CertificateInvalidError).Reason).To(Equal(x509.Expired))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't accept an certificate for the wrong hostname", func() {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
|
|
||||||
Skip("windows")
|
|
||||||
}
|
|
||||||
|
|
||||||
template := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(1),
|
|
||||||
NotBefore: time.Now().Add(-time.Hour),
|
|
||||||
NotAfter: time.Now().Add(time.Hour),
|
|
||||||
Subject: pkix.Name{CommonName: "google.com"},
|
|
||||||
}
|
|
||||||
_, leafCert := getCertificate(template)
|
|
||||||
|
|
||||||
cm.chain = []*x509.Certificate{leafCert}
|
|
||||||
err := cm.Verify("quic.clemente.io")
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
_, ok := err.(x509.HostnameError)
|
|
||||||
Expect(ok).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the chain hasn't been set yet", func() {
|
|
||||||
err := cm.Verify("example.com")
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
// this tests relies on LetsEncrypt not being contained in the Root CAs
|
|
||||||
It("rejects valid certificate with missing certificate chain", func() {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
Skip("LetsEncrypt Root CA is included in Windows")
|
|
||||||
}
|
|
||||||
|
|
||||||
cert := testdata.GetCertificate()
|
|
||||||
xcert, err := x509.ParseCertificate(cert.Certificate[0])
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cm.chain = []*x509.Certificate{xcert}
|
|
||||||
err = cm.Verify("quic.clemente.io")
|
|
||||||
_, ok := err.(x509.UnknownAuthorityError)
|
|
||||||
Expect(ok).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't do any certificate verification if InsecureSkipVerify is set", func() {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
|
|
||||||
Skip("windows")
|
|
||||||
}
|
|
||||||
|
|
||||||
template := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(1),
|
|
||||||
}
|
|
||||||
|
|
||||||
_, leafCert := getCertificate(template)
|
|
||||||
cm.config = &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
}
|
|
||||||
cm.chain = []*x509.Certificate{leafCert}
|
|
||||||
err := cm.Verify("quic.clemente.io")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the time specified in a client TLS config", func() {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
|
|
||||||
Skip("windows")
|
|
||||||
}
|
|
||||||
|
|
||||||
template := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(1),
|
|
||||||
NotBefore: time.Now().Add(-25 * time.Hour),
|
|
||||||
NotAfter: time.Now().Add(-23 * time.Hour),
|
|
||||||
Subject: pkix.Name{CommonName: "quic.clemente.io"},
|
|
||||||
}
|
|
||||||
_, leafCert := getCertificate(template)
|
|
||||||
cm.chain = []*x509.Certificate{leafCert}
|
|
||||||
cm.config = &tls.Config{
|
|
||||||
Time: func() time.Time { return time.Now().Add(-24 * time.Hour) },
|
|
||||||
}
|
|
||||||
err := cm.Verify("quic.clemente.io")
|
|
||||||
_, ok := err.(x509.UnknownAuthorityError)
|
|
||||||
Expect(ok).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects certificates that are expired at the time specified in a client TLS config", func() {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
|
|
||||||
Skip("windows")
|
|
||||||
}
|
|
||||||
|
|
||||||
template := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(1),
|
|
||||||
NotBefore: time.Now().Add(-time.Hour),
|
|
||||||
NotAfter: time.Now().Add(time.Hour),
|
|
||||||
}
|
|
||||||
_, leafCert := getCertificate(template)
|
|
||||||
cm.chain = []*x509.Certificate{leafCert}
|
|
||||||
cm.config = &tls.Config{
|
|
||||||
Time: func() time.Time { return time.Now().Add(-24 * time.Hour) },
|
|
||||||
}
|
|
||||||
err := cm.Verify("quic.clemente.io")
|
|
||||||
Expect(err.(x509.CertificateInvalidError).Reason).To(Equal(x509.Expired))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the Root CA given in the client config", func() {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
|
|
||||||
Skip("windows")
|
|
||||||
}
|
|
||||||
|
|
||||||
templateRoot := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(1),
|
|
||||||
NotBefore: time.Now().Add(-time.Hour),
|
|
||||||
NotAfter: time.Now().Add(time.Hour),
|
|
||||||
IsCA: true,
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
}
|
|
||||||
rootKey, rootCert := getCertificate(templateRoot)
|
|
||||||
template := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(1),
|
|
||||||
NotBefore: time.Now().Add(-time.Hour),
|
|
||||||
NotAfter: time.Now().Add(time.Hour),
|
|
||||||
Subject: pkix.Name{CommonName: "google.com"},
|
|
||||||
}
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
leafCert := generateCertificate(template, rootCert, &key.PublicKey, rootKey)
|
|
||||||
|
|
||||||
rootCAPool := x509.NewCertPool()
|
|
||||||
rootCAPool.AddCert(rootCert)
|
|
||||||
|
|
||||||
cm.chain = []*x509.Certificate{leafCert}
|
|
||||||
cm.config = &tls.Config{
|
|
||||||
RootCAs: rootCAPool,
|
|
||||||
}
|
|
||||||
err = cm.Verify("google.com")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,71 +0,0 @@
|
|||||||
// +build ignore
|
|
||||||
|
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Chacha20poly1305", func() {
|
|
||||||
var (
|
|
||||||
alice, bob AEAD
|
|
||||||
keyAlice, keyBob, ivAlice, ivBob []byte
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
keyAlice = make([]byte, 32)
|
|
||||||
keyBob = make([]byte, 32)
|
|
||||||
ivAlice = make([]byte, 4)
|
|
||||||
ivBob = make([]byte, 4)
|
|
||||||
rand.Reader.Read(keyAlice)
|
|
||||||
rand.Reader.Read(keyBob)
|
|
||||||
rand.Reader.Read(ivAlice)
|
|
||||||
rand.Reader.Read(ivBob)
|
|
||||||
var err error
|
|
||||||
alice, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
bob, err = NewAEADChacha20Poly1305(keyAlice, keyBob, ivAlice, ivBob)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("seals and opens", func() {
|
|
||||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
text, err := bob.Open(nil, b, 42, []byte("aad"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(text).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("seals and opens reverse", func() {
|
|
||||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
text, err := alice.Open(nil, b, 42, []byte("aad"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(text).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("has the proper length", func() {
|
|
||||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
Expect(b).To(HaveLen(6 + 12))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fails with wrong aad", func() {
|
|
||||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
_, err := bob.Open(nil, b, 42, []byte("aad2"))
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects wrong key and iv sizes", func() {
|
|
||||||
var err error
|
|
||||||
e := "chacha20poly1305: expected 32-byte keys and 4-byte IVs"
|
|
||||||
_, err = NewAEADChacha20Poly1305(keyBob[1:], keyAlice, ivBob, ivAlice)
|
|
||||||
Expect(err).To(MatchError(e))
|
|
||||||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice[1:], ivBob, ivAlice)
|
|
||||||
Expect(err).To(MatchError(e))
|
|
||||||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob[1:], ivAlice)
|
|
||||||
Expect(err).To(MatchError(e))
|
|
||||||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice[1:])
|
|
||||||
Expect(err).To(MatchError(e))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,13 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCrypto(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "Crypto Suite")
|
|
||||||
}
|
|
@ -1,27 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("ProofRsa", func() {
|
|
||||||
It("works", func() {
|
|
||||||
a, err := NewCurve25519KEX()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
b, err := NewCurve25519KEX()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
sA, err := a.CalculateSharedKey(b.PublicKey())
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
sB, err := b.CalculateSharedKey(a.PublicKey())
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(sA).To(Equal(sB))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects short public keys", func() {
|
|
||||||
a, err := NewCurve25519KEX()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = a.CalculateSharedKey(nil)
|
|
||||||
Expect(err).To(MatchError("Curve25519: expected public key of 32 byte"))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,197 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("QUIC Crypto Key Derivation", func() {
|
|
||||||
// Context("chacha20poly1305", func() {
|
|
||||||
// It("derives non-fs keys", func() {
|
|
||||||
// aead, err := DeriveKeysChacha20(
|
|
||||||
// protocol.Version32,
|
|
||||||
// false,
|
|
||||||
// []byte("0123456789012345678901"),
|
|
||||||
// []byte("nonce"),
|
|
||||||
// protocol.ConnectionID(42),
|
|
||||||
// []byte("chlo"),
|
|
||||||
// []byte("scfg"),
|
|
||||||
// []byte("cert"),
|
|
||||||
// nil,
|
|
||||||
// )
|
|
||||||
// Expect(err).ToNot(HaveOccurred())
|
|
||||||
// chacha := aead.(*aeadChacha20Poly1305)
|
|
||||||
// // If the IVs match, the keys will match too, since the keys are read earlier
|
|
||||||
// Expect(chacha.myIV).To(Equal([]byte{0xf0, 0xf5, 0x4c, 0xa8}))
|
|
||||||
// Expect(chacha.otherIV).To(Equal([]byte{0x75, 0xd8, 0xa2, 0x8d}))
|
|
||||||
// })
|
|
||||||
//
|
|
||||||
// It("derives fs keys", func() {
|
|
||||||
// aead, err := DeriveKeysChacha20(
|
|
||||||
// protocol.Version32,
|
|
||||||
// true,
|
|
||||||
// []byte("0123456789012345678901"),
|
|
||||||
// []byte("nonce"),
|
|
||||||
// protocol.ConnectionID(42),
|
|
||||||
// []byte("chlo"),
|
|
||||||
// []byte("scfg"),
|
|
||||||
// []byte("cert"),
|
|
||||||
// nil,
|
|
||||||
// )
|
|
||||||
// Expect(err).ToNot(HaveOccurred())
|
|
||||||
// chacha := aead.(*aeadChacha20Poly1305)
|
|
||||||
// // If the IVs match, the keys will match too, since the keys are read earlier
|
|
||||||
// Expect(chacha.myIV).To(Equal([]byte{0xf5, 0x73, 0x11, 0x79}))
|
|
||||||
// Expect(chacha.otherIV).To(Equal([]byte{0xf7, 0x26, 0x4d, 0x2c}))
|
|
||||||
// })
|
|
||||||
//
|
|
||||||
// It("does not use diversification nonces in FS key derivation", func() {
|
|
||||||
// aead, err := DeriveKeysChacha20(
|
|
||||||
// protocol.Version33,
|
|
||||||
// true,
|
|
||||||
// []byte("0123456789012345678901"),
|
|
||||||
// []byte("nonce"),
|
|
||||||
// protocol.ConnectionID(42),
|
|
||||||
// []byte("chlo"),
|
|
||||||
// []byte("scfg"),
|
|
||||||
// []byte("cert"),
|
|
||||||
// []byte("divnonce"),
|
|
||||||
// )
|
|
||||||
// Expect(err).ToNot(HaveOccurred())
|
|
||||||
// chacha := aead.(*aeadChacha20Poly1305)
|
|
||||||
// // If the IVs match, the keys will match too, since the keys are read earlier
|
|
||||||
// Expect(chacha.myIV).To(Equal([]byte{0xf5, 0x73, 0x11, 0x79}))
|
|
||||||
// Expect(chacha.otherIV).To(Equal([]byte{0xf7, 0x26, 0x4d, 0x2c}))
|
|
||||||
// })
|
|
||||||
//
|
|
||||||
// It("uses diversification nonces in initial key derivation", func() {
|
|
||||||
// aead, err := DeriveKeysChacha20(
|
|
||||||
// protocol.Version33,
|
|
||||||
// false,
|
|
||||||
// []byte("0123456789012345678901"),
|
|
||||||
// []byte("nonce"),
|
|
||||||
// protocol.ConnectionID(42),
|
|
||||||
// []byte("chlo"),
|
|
||||||
// []byte("scfg"),
|
|
||||||
// []byte("cert"),
|
|
||||||
// []byte("divnonce"),
|
|
||||||
// )
|
|
||||||
// Expect(err).ToNot(HaveOccurred())
|
|
||||||
// chacha := aead.(*aeadChacha20Poly1305)
|
|
||||||
// // If the IVs match, the keys will match too, since the keys are read earlier
|
|
||||||
// Expect(chacha.myIV).To(Equal([]byte{0xc4, 0x12, 0x25, 0x64}))
|
|
||||||
// Expect(chacha.otherIV).To(Equal([]byte{0x75, 0xd8, 0xa2, 0x8d}))
|
|
||||||
// })
|
|
||||||
// })
|
|
||||||
|
|
||||||
Context("AES-GCM", func() {
|
|
||||||
It("derives non-forward secure keys", func() {
|
|
||||||
aead, err := DeriveQuicCryptoAESKeys(
|
|
||||||
false,
|
|
||||||
[]byte("0123456789012345678901"),
|
|
||||||
[]byte("nonce"),
|
|
||||||
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
|
|
||||||
[]byte("chlo"),
|
|
||||||
[]byte("scfg"),
|
|
||||||
[]byte("cert"),
|
|
||||||
[]byte("divnonce"),
|
|
||||||
protocol.PerspectiveServer,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
aesgcm := aead.(*aeadAESGCM12)
|
|
||||||
// If the IVs match, the keys will match too, since the keys are read earlier
|
|
||||||
Expect(aesgcm.myIV).To(Equal([]byte{0x1c, 0xec, 0xac, 0x9b}))
|
|
||||||
Expect(aesgcm.otherIV).To(Equal([]byte{0x64, 0xef, 0x3c, 0x9}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the diversification nonce when generating non-forwared secure keys", func() {
|
|
||||||
aead1, err := DeriveQuicCryptoAESKeys(
|
|
||||||
false,
|
|
||||||
[]byte("0123456789012345678901"),
|
|
||||||
[]byte("nonce"),
|
|
||||||
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
|
|
||||||
[]byte("chlo"),
|
|
||||||
[]byte("scfg"),
|
|
||||||
[]byte("cert"),
|
|
||||||
[]byte("divnonce"),
|
|
||||||
protocol.PerspectiveServer,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
aead2, err := DeriveQuicCryptoAESKeys(
|
|
||||||
false,
|
|
||||||
[]byte("0123456789012345678901"),
|
|
||||||
[]byte("nonce"),
|
|
||||||
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
|
|
||||||
[]byte("chlo"),
|
|
||||||
[]byte("scfg"),
|
|
||||||
[]byte("cert"),
|
|
||||||
[]byte("ecnonvid"),
|
|
||||||
protocol.PerspectiveServer,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
aesgcm1 := aead1.(*aeadAESGCM12)
|
|
||||||
aesgcm2 := aead2.(*aeadAESGCM12)
|
|
||||||
Expect(aesgcm1.myIV).ToNot(Equal(aesgcm2.myIV))
|
|
||||||
Expect(aesgcm1.otherIV).To(Equal(aesgcm2.otherIV))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("derives non-forward secure keys, for the other side", func() {
|
|
||||||
aead, err := DeriveQuicCryptoAESKeys(
|
|
||||||
false,
|
|
||||||
[]byte("0123456789012345678901"),
|
|
||||||
[]byte("nonce"),
|
|
||||||
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
|
|
||||||
[]byte("chlo"),
|
|
||||||
[]byte("scfg"),
|
|
||||||
[]byte("cert"),
|
|
||||||
[]byte("divnonce"),
|
|
||||||
protocol.PerspectiveClient,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
aesgcm := aead.(*aeadAESGCM12)
|
|
||||||
// If the IVs match, the keys will match too, since the keys are read earlier
|
|
||||||
Expect(aesgcm.otherIV).To(Equal([]byte{0x1c, 0xec, 0xac, 0x9b}))
|
|
||||||
Expect(aesgcm.myIV).To(Equal([]byte{0x64, 0xef, 0x3c, 0x9}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("derives forward secure keys", func() {
|
|
||||||
aead, err := DeriveQuicCryptoAESKeys(
|
|
||||||
true,
|
|
||||||
[]byte("0123456789012345678901"),
|
|
||||||
[]byte("nonce"),
|
|
||||||
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
|
|
||||||
[]byte("chlo"),
|
|
||||||
[]byte("scfg"),
|
|
||||||
[]byte("cert"),
|
|
||||||
nil,
|
|
||||||
protocol.PerspectiveServer,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
aesgcm := aead.(*aeadAESGCM12)
|
|
||||||
// If the IVs match, the keys will match too, since the keys are read earlier
|
|
||||||
Expect(aesgcm.myIV).To(Equal([]byte{0x7, 0xad, 0xab, 0xb8}))
|
|
||||||
Expect(aesgcm.otherIV).To(Equal([]byte{0xf2, 0x7a, 0xcc, 0x42}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not use div-nonce for FS key derivation", func() {
|
|
||||||
aead, err := DeriveQuicCryptoAESKeys(
|
|
||||||
true,
|
|
||||||
[]byte("0123456789012345678901"),
|
|
||||||
[]byte("nonce"),
|
|
||||||
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
|
|
||||||
[]byte("chlo"),
|
|
||||||
[]byte("scfg"),
|
|
||||||
[]byte("cert"),
|
|
||||||
[]byte("divnonce"),
|
|
||||||
protocol.PerspectiveServer,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
aesgcm := aead.(*aeadAESGCM12)
|
|
||||||
// If the IVs match, the keys will match too, since the keys are read earlier
|
|
||||||
Expect(aesgcm.myIV).To(Equal([]byte{0x7, 0xad, 0xab, 0xb8}))
|
|
||||||
Expect(aesgcm.otherIV).To(Equal([]byte{0xf2, 0x7a, 0xcc, 0x42}))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,56 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockTLSExporter struct {
|
|
||||||
hash crypto.Hash
|
|
||||||
computerError error
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ TLSExporter = &mockTLSExporter{}
|
|
||||||
|
|
||||||
func (c *mockTLSExporter) Handshake() mint.Alert { panic("not implemented") }
|
|
||||||
|
|
||||||
func (c *mockTLSExporter) ConnectionState() mint.ConnectionState {
|
|
||||||
return mint.ConnectionState{
|
|
||||||
CipherSuite: mint.CipherSuiteParams{
|
|
||||||
Hash: c.hash,
|
|
||||||
KeyLen: 32,
|
|
||||||
IvLen: 12,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *mockTLSExporter) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
|
||||||
if c.computerError != nil {
|
|
||||||
return nil, c.computerError
|
|
||||||
}
|
|
||||||
return append([]byte(label), context...), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("Key Derivation", func() {
|
|
||||||
It("derives keys", func() {
|
|
||||||
clientAEAD, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256}, protocol.PerspectiveClient)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
serverAEAD, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256}, protocol.PerspectiveServer)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
ciphertext := clientAEAD.Seal(nil, []byte("foobar"), 0, []byte("aad"))
|
|
||||||
data, err := serverAEAD.Open(nil, ciphertext, 0, []byte("aad"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(data).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fails when computing the exporter fails", func() {
|
|
||||||
testErr := errors.New("test error")
|
|
||||||
_, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256, computerError: testErr}, protocol.PerspectiveClient)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,86 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("NullAEAD using AES-GCM", func() {
|
|
||||||
// values taken from https://github.com/quicwg/base-drafts/wiki/Test-Vector-for-the-Clear-Text-AEAD-key-derivation
|
|
||||||
Context("using the test vector from the QUIC WG Wiki", func() {
|
|
||||||
connID := protocol.ConnectionID([]byte{0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08})
|
|
||||||
|
|
||||||
It("computes the secrets", func() {
|
|
||||||
clientSecret, serverSecret := computeSecrets(connID)
|
|
||||||
Expect(clientSecret).To(Equal([]byte{
|
|
||||||
0x83, 0x55, 0xf2, 0x1a, 0x3d, 0x8f, 0x83, 0xec,
|
|
||||||
0xb3, 0xd0, 0xf9, 0x71, 0x08, 0xd3, 0xf9, 0x5e,
|
|
||||||
0x0f, 0x65, 0xb4, 0xd8, 0xae, 0x88, 0xa0, 0x61,
|
|
||||||
0x1e, 0xe4, 0x9d, 0xb0, 0xb5, 0x23, 0x59, 0x1d,
|
|
||||||
}))
|
|
||||||
Expect(serverSecret).To(Equal([]byte{
|
|
||||||
0xf8, 0x0e, 0x57, 0x71, 0x48, 0x4b, 0x21, 0xcd,
|
|
||||||
0xeb, 0xb5, 0xaf, 0xe0, 0xa2, 0x56, 0xa3, 0x17,
|
|
||||||
0x41, 0xef, 0xe2, 0xb5, 0xc6, 0xb6, 0x17, 0xba,
|
|
||||||
0xe1, 0xb2, 0xf1, 0x5a, 0x83, 0x04, 0x83, 0xd6,
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("computes the client key and IV", func() {
|
|
||||||
clientSecret, _ := computeSecrets(connID)
|
|
||||||
key, iv := computeNullAEADKeyAndIV(clientSecret)
|
|
||||||
Expect(key).To(Equal([]byte{
|
|
||||||
0x3a, 0xd0, 0x54, 0x2c, 0x4a, 0x85, 0x84, 0x74,
|
|
||||||
0x00, 0x63, 0x04, 0x9e, 0x3b, 0x3c, 0xaa, 0xb2,
|
|
||||||
}))
|
|
||||||
Expect(iv).To(Equal([]byte{
|
|
||||||
0xd1, 0xfd, 0x26, 0x05, 0x42, 0x75, 0x3a, 0xba,
|
|
||||||
0x38, 0x58, 0x9b, 0xad,
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("computes the server key and IV", func() {
|
|
||||||
_, serverSecret := computeSecrets(connID)
|
|
||||||
key, iv := computeNullAEADKeyAndIV(serverSecret)
|
|
||||||
Expect(key).To(Equal([]byte{
|
|
||||||
0xbe, 0xe4, 0xc2, 0x4d, 0x2a, 0xf1, 0x33, 0x80,
|
|
||||||
0xa9, 0xfa, 0x24, 0xa5, 0xe2, 0xba, 0x2c, 0xff,
|
|
||||||
}))
|
|
||||||
Expect(iv).To(Equal([]byte{
|
|
||||||
0x25, 0xb5, 0x8e, 0x24, 0x6d, 0x9e, 0x7d, 0x5f,
|
|
||||||
0xfe, 0x43, 0x23, 0xfe,
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("seals and opens", func() {
|
|
||||||
connectionID := protocol.ConnectionID([]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef})
|
|
||||||
clientAEAD, err := newNullAEADAESGCM(connectionID, protocol.PerspectiveClient)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
serverAEAD, err := newNullAEADAESGCM(connectionID, protocol.PerspectiveServer)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
clientMessage := clientAEAD.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
m, err := serverAEAD.Open(nil, clientMessage, 42, []byte("aad"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(m).To(Equal([]byte("foobar")))
|
|
||||||
serverMessage := serverAEAD.Seal(nil, []byte("raboof"), 99, []byte("daa"))
|
|
||||||
m, err = clientAEAD.Open(nil, serverMessage, 99, []byte("daa"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(m).To(Equal([]byte("raboof")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't work if initialized with different connection IDs", func() {
|
|
||||||
c1 := protocol.ConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 1})
|
|
||||||
c2 := protocol.ConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 2})
|
|
||||||
clientAEAD, err := newNullAEADAESGCM(c1, protocol.PerspectiveClient)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
serverAEAD, err := newNullAEADAESGCM(c2, protocol.PerspectiveServer)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
clientMessage := clientAEAD.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
_, err = serverAEAD.Open(nil, clientMessage, 42, []byte("aad"))
|
|
||||||
Expect(err).To(MatchError("cipher: message authentication failed"))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,55 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("NullAEAD using FNV128a", func() {
|
|
||||||
aad := []byte("All human beings are born free and equal in dignity and rights.")
|
|
||||||
plainText := []byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood.")
|
|
||||||
hash36 := []byte{0x98, 0x9b, 0x33, 0x3f, 0xe8, 0xde, 0x32, 0x5c, 0xa6, 0x7f, 0x9c, 0xf7}
|
|
||||||
|
|
||||||
var aeadServer AEAD
|
|
||||||
var aeadClient AEAD
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
aeadServer = &nullAEADFNV128a{protocol.PerspectiveServer}
|
|
||||||
aeadClient = &nullAEADFNV128a{protocol.PerspectiveClient}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("seals and opens, client => server", func() {
|
|
||||||
cipherText := aeadClient.Seal(nil, plainText, 0, aad)
|
|
||||||
res, err := aeadServer.Open(nil, cipherText, 0, aad)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(res).To(Equal([]byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood.")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("seals and opens, server => client", func() {
|
|
||||||
cipherText := aeadServer.Seal(nil, plainText, 0, aad)
|
|
||||||
res, err := aeadClient.Open(nil, cipherText, 0, aad)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(res).To(Equal([]byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood.")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects short ciphertexts", func() {
|
|
||||||
_, err := aeadServer.Open(nil, nil, 0, nil)
|
|
||||||
Expect(err).To(MatchError("NullAEAD: ciphertext cannot be less than 12 bytes long"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("seals in-place", func() {
|
|
||||||
buf := make([]byte, 6, 12+6)
|
|
||||||
copy(buf, []byte("foobar"))
|
|
||||||
res := aeadServer.Seal(buf[0:0], buf, 0, nil)
|
|
||||||
buf = buf[:12+6]
|
|
||||||
Expect(buf[12:]).To(Equal([]byte("foobar")))
|
|
||||||
Expect(res[12:]).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fails", func() {
|
|
||||||
cipherText := append(append(hash36, plainText...), byte(0x42))
|
|
||||||
_, err := aeadClient.Open(nil, cipherText, 0, aad)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,17 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("NullAEAD", func() {
|
|
||||||
It("selects the right FVN variant", func() {
|
|
||||||
connID := protocol.ConnectionID([]byte{0x42, 0, 0, 0, 0, 0, 0, 0})
|
|
||||||
Expect(NewNullAEAD(protocol.PerspectiveClient, connID, protocol.Version39)).To(Equal(&nullAEADFNV128a{
|
|
||||||
perspective: protocol.PerspectiveClient,
|
|
||||||
}))
|
|
||||||
Expect(NewNullAEAD(protocol.PerspectiveClient, connID, protocol.VersionTLS)).To(BeAssignableToTypeOf(&aeadAESGCM{}))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,127 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/asn1"
|
|
||||||
"math/big"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Proof", func() {
|
|
||||||
It("gives valid signatures with the key in internal/testdata", func() {
|
|
||||||
key := &testdata.GetTLSConfig().Certificates[0]
|
|
||||||
signature, err := signServerProof(key, []byte{'C', 'H', 'L', 'O'}, []byte{'S', 'C', 'F', 'G'})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
// Generated with:
|
|
||||||
// ruby -e 'require "digest"; p Digest::SHA256.digest("QUIC CHLO and server config signature\x00" + "\x20\x00\x00\x00" + Digest::SHA256.digest("CHLO") + "SCFG")'
|
|
||||||
data := []byte("W\xA6\xFC\xDE\xC7\xD2>c\xE6\xB5\xF6\tq\x9E|<~1\xA33\x01\xCA=\x19\xBD\xC1\xE4\xB0\xBA\x9B\x16%")
|
|
||||||
err = rsa.VerifyPSS(key.PrivateKey.(*rsa.PrivateKey).Public().(*rsa.PublicKey), crypto.SHA256, data, signature, &rsa.PSSOptions{SaltLength: 32})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("when using RSA", func() {
|
|
||||||
generateCert := func() (*rsa.PrivateKey, *x509.Certificate) {
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
|
|
||||||
certTemplate := x509.Certificate{SerialNumber: big.NewInt(1)}
|
|
||||||
certDER, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &key.PublicKey, key)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cert, err := x509.ParseCertificate(certDER)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
return key, cert
|
|
||||||
}
|
|
||||||
|
|
||||||
It("verifies a signature", func() {
|
|
||||||
key, cert := generateCert()
|
|
||||||
chlo := []byte("chlo")
|
|
||||||
scfg := []byte("scfg")
|
|
||||||
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(verifyServerProof(signature, cert, chlo, scfg)).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects invalid signatures", func() {
|
|
||||||
key, cert := generateCert()
|
|
||||||
chlo := []byte("client hello")
|
|
||||||
scfg := []byte("sever config")
|
|
||||||
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(verifyServerProof(append(signature, byte(0x99)), cert, chlo, scfg)).To(BeFalse())
|
|
||||||
Expect(verifyServerProof(signature, cert, chlo[:len(chlo)-2], scfg)).To(BeFalse())
|
|
||||||
Expect(verifyServerProof(signature, cert, chlo, scfg[:len(scfg)-2])).To(BeFalse())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("when using ECDSA", func() {
|
|
||||||
generateCert := func() (*ecdsa.PrivateKey, *x509.Certificate) {
|
|
||||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
|
|
||||||
certTemplate := x509.Certificate{SerialNumber: big.NewInt(1)}
|
|
||||||
certDER, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &key.PublicKey, key)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cert, err := x509.ParseCertificate(certDER)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
return key, cert
|
|
||||||
}
|
|
||||||
|
|
||||||
It("gives valid signatures", func() {
|
|
||||||
key, _ := generateCert()
|
|
||||||
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, []byte{'C', 'H', 'L', 'O'}, []byte{'S', 'C', 'F', 'G'})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
// Generated with:
|
|
||||||
// ruby -e 'require "digest"; p Digest::SHA256.digest("QUIC CHLO and server config signature\x00" + "\x20\x00\x00\x00" + Digest::SHA256.digest("CHLO") + "SCFG")'
|
|
||||||
data := []byte("W\xA6\xFC\xDE\xC7\xD2>c\xE6\xB5\xF6\tq\x9E|<~1\xA33\x01\xCA=\x19\xBD\xC1\xE4\xB0\xBA\x9B\x16%")
|
|
||||||
s := &ecdsaSignature{}
|
|
||||||
_, err = asn1.Unmarshal(signature, s)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
b := ecdsa.Verify(key.Public().(*ecdsa.PublicKey), data, s.R, s.S)
|
|
||||||
Expect(b).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("verifies a signature", func() {
|
|
||||||
key, cert := generateCert()
|
|
||||||
chlo := []byte("chlo")
|
|
||||||
scfg := []byte("server config")
|
|
||||||
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(verifyServerProof(signature, cert, chlo, scfg)).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects invalid signatures", func() {
|
|
||||||
key, cert := generateCert()
|
|
||||||
chlo := []byte("client hello")
|
|
||||||
scfg := []byte("server config")
|
|
||||||
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(verifyServerProof(append(signature, byte(0x99)), cert, chlo, scfg)).To(BeFalse())
|
|
||||||
Expect(verifyServerProof(signature, cert, chlo[:len(chlo)-2], scfg)).To(BeFalse())
|
|
||||||
Expect(verifyServerProof(signature, cert, chlo, scfg[:len(scfg)-2])).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects signatures generated with a different certificate", func() {
|
|
||||||
key1, cert1 := generateCert()
|
|
||||||
key2, cert2 := generateCert()
|
|
||||||
Expect(key1.PublicKey).ToNot(Equal(key2))
|
|
||||||
Expect(cert1.Equal(cert2)).To(BeFalse())
|
|
||||||
chlo := []byte("chlo")
|
|
||||||
scfg := []byte("sfcg")
|
|
||||||
signature, err := signServerProof(&tls.Certificate{PrivateKey: key1}, chlo, scfg)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(verifyServerProof(signature, cert2, chlo, scfg)).To(BeFalse())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,235 +0,0 @@
|
|||||||
package flowcontrol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
// on the CIs, the timing is a lot less precise, so scale every duration by this factor
|
|
||||||
func scaleDuration(t time.Duration) time.Duration {
|
|
||||||
scaleFactor := 1
|
|
||||||
if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set
|
|
||||||
scaleFactor = f
|
|
||||||
}
|
|
||||||
Expect(scaleFactor).ToNot(BeZero())
|
|
||||||
return time.Duration(scaleFactor) * t
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("Base Flow controller", func() {
|
|
||||||
var controller *baseFlowController
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
controller = &baseFlowController{}
|
|
||||||
controller.rttStats = &congestion.RTTStats{}
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("send flow control", func() {
|
|
||||||
It("adds bytes sent", func() {
|
|
||||||
controller.bytesSent = 5
|
|
||||||
controller.AddBytesSent(6)
|
|
||||||
Expect(controller.bytesSent).To(Equal(protocol.ByteCount(5 + 6)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets the size of the remaining flow control window", func() {
|
|
||||||
controller.bytesSent = 5
|
|
||||||
controller.sendWindow = 12
|
|
||||||
Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(12 - 5)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("updates the size of the flow control window", func() {
|
|
||||||
controller.AddBytesSent(5)
|
|
||||||
controller.UpdateSendWindow(15)
|
|
||||||
Expect(controller.sendWindow).To(Equal(protocol.ByteCount(15)))
|
|
||||||
Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(15 - 5)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("says that the window size is 0 if we sent more than we were allowed to", func() {
|
|
||||||
controller.AddBytesSent(15)
|
|
||||||
controller.UpdateSendWindow(10)
|
|
||||||
Expect(controller.sendWindowSize()).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not decrease the flow control window", func() {
|
|
||||||
controller.UpdateSendWindow(20)
|
|
||||||
Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20)))
|
|
||||||
controller.UpdateSendWindow(10)
|
|
||||||
Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("says when it's blocked", func() {
|
|
||||||
controller.UpdateSendWindow(100)
|
|
||||||
Expect(controller.IsNewlyBlocked()).To(BeFalse())
|
|
||||||
controller.AddBytesSent(100)
|
|
||||||
blocked, offset := controller.IsNewlyBlocked()
|
|
||||||
Expect(blocked).To(BeTrue())
|
|
||||||
Expect(offset).To(Equal(protocol.ByteCount(100)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't say that it's newly blocked multiple times for the same offset", func() {
|
|
||||||
controller.UpdateSendWindow(100)
|
|
||||||
controller.AddBytesSent(100)
|
|
||||||
newlyBlocked, offset := controller.IsNewlyBlocked()
|
|
||||||
Expect(newlyBlocked).To(BeTrue())
|
|
||||||
Expect(offset).To(Equal(protocol.ByteCount(100)))
|
|
||||||
newlyBlocked, _ = controller.IsNewlyBlocked()
|
|
||||||
Expect(newlyBlocked).To(BeFalse())
|
|
||||||
controller.UpdateSendWindow(150)
|
|
||||||
controller.AddBytesSent(150)
|
|
||||||
newlyBlocked, _ = controller.IsNewlyBlocked()
|
|
||||||
Expect(newlyBlocked).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("receive flow control", func() {
|
|
||||||
var (
|
|
||||||
receiveWindow protocol.ByteCount = 10000
|
|
||||||
receiveWindowSize protocol.ByteCount = 1000
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
controller.bytesRead = receiveWindow - receiveWindowSize
|
|
||||||
controller.receiveWindow = receiveWindow
|
|
||||||
controller.receiveWindowSize = receiveWindowSize
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adds bytes read", func() {
|
|
||||||
controller.bytesRead = 5
|
|
||||||
controller.AddBytesRead(6)
|
|
||||||
Expect(controller.bytesRead).To(Equal(protocol.ByteCount(5 + 6)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("triggers a window update when necessary", func() {
|
|
||||||
bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold + 1 // consumed 1 byte more than the threshold
|
|
||||||
bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed)
|
|
||||||
readPosition := receiveWindow - bytesRemaining
|
|
||||||
controller.bytesRead = readPosition
|
|
||||||
offset := controller.getWindowUpdate()
|
|
||||||
Expect(offset).To(Equal(readPosition + receiveWindowSize))
|
|
||||||
Expect(controller.receiveWindow).To(Equal(readPosition + receiveWindowSize))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't trigger a window update when not necessary", func() {
|
|
||||||
bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold - 1 // consumed 1 byte less than the threshold
|
|
||||||
bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed)
|
|
||||||
readPosition := receiveWindow - bytesRemaining
|
|
||||||
controller.bytesRead = readPosition
|
|
||||||
offset := controller.getWindowUpdate()
|
|
||||||
Expect(offset).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("receive window size auto-tuning", func() {
|
|
||||||
var oldWindowSize protocol.ByteCount
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
oldWindowSize = controller.receiveWindowSize
|
|
||||||
controller.maxReceiveWindowSize = 5000
|
|
||||||
})
|
|
||||||
|
|
||||||
// update the congestion such that it returns a given value for the smoothed RTT
|
|
||||||
setRtt := func(t time.Duration) {
|
|
||||||
controller.rttStats.UpdateRTT(t, 0, time.Now())
|
|
||||||
Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked
|
|
||||||
}
|
|
||||||
|
|
||||||
It("doesn't increase the window size for a new stream", func() {
|
|
||||||
controller.maybeAdjustWindowSize()
|
|
||||||
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't increase the window size when no RTT estimate is available", func() {
|
|
||||||
setRtt(0)
|
|
||||||
controller.startNewAutoTuningEpoch()
|
|
||||||
controller.AddBytesRead(400)
|
|
||||||
offset := controller.getWindowUpdate()
|
|
||||||
Expect(offset).ToNot(BeZero()) // make sure a window update is sent
|
|
||||||
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("increases the window size if read so fast that the window would be consumed in less than 4 RTTs", func() {
|
|
||||||
bytesRead := controller.bytesRead
|
|
||||||
rtt := scaleDuration(20 * time.Millisecond)
|
|
||||||
setRtt(rtt)
|
|
||||||
// consume more than 2/3 of the window...
|
|
||||||
dataRead := receiveWindowSize*2/3 + 1
|
|
||||||
// ... in 4*2/3 of the RTT
|
|
||||||
controller.epochStartOffset = controller.bytesRead
|
|
||||||
controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3)
|
|
||||||
controller.AddBytesRead(dataRead)
|
|
||||||
offset := controller.getWindowUpdate()
|
|
||||||
Expect(offset).ToNot(BeZero())
|
|
||||||
// check that the window size was increased
|
|
||||||
newWindowSize := controller.receiveWindowSize
|
|
||||||
Expect(newWindowSize).To(Equal(2 * oldWindowSize))
|
|
||||||
// check that the new window size was used to increase the offset
|
|
||||||
Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + newWindowSize)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't increase the window size if data is read so fast that the window would be consumed in less than 4 RTTs, but less than half the window has been read", func() {
|
|
||||||
// this test only makes sense if a window update is triggered before half of the window has been consumed
|
|
||||||
Expect(protocol.WindowUpdateThreshold).To(BeNumerically(">", 1/3))
|
|
||||||
bytesRead := controller.bytesRead
|
|
||||||
rtt := scaleDuration(20 * time.Millisecond)
|
|
||||||
setRtt(rtt)
|
|
||||||
// consume more than 2/3 of the window...
|
|
||||||
dataRead := receiveWindowSize*1/3 + 1
|
|
||||||
// ... in 4*2/3 of the RTT
|
|
||||||
controller.epochStartOffset = controller.bytesRead
|
|
||||||
controller.epochStartTime = time.Now().Add(-rtt * 4 * 1 / 3)
|
|
||||||
controller.AddBytesRead(dataRead)
|
|
||||||
offset := controller.getWindowUpdate()
|
|
||||||
Expect(offset).ToNot(BeZero())
|
|
||||||
// check that the window size was not increased
|
|
||||||
newWindowSize := controller.receiveWindowSize
|
|
||||||
Expect(newWindowSize).To(Equal(oldWindowSize))
|
|
||||||
// check that the new window size was used to increase the offset
|
|
||||||
Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + newWindowSize)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't increase the window size if read too slowly", func() {
|
|
||||||
bytesRead := controller.bytesRead
|
|
||||||
rtt := scaleDuration(20 * time.Millisecond)
|
|
||||||
setRtt(rtt)
|
|
||||||
// consume less than 2/3 of the window...
|
|
||||||
dataRead := receiveWindowSize*2/3 - 1
|
|
||||||
// ... in 4*2/3 of the RTT
|
|
||||||
controller.epochStartOffset = controller.bytesRead
|
|
||||||
controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3)
|
|
||||||
controller.AddBytesRead(dataRead)
|
|
||||||
offset := controller.getWindowUpdate()
|
|
||||||
Expect(offset).ToNot(BeZero())
|
|
||||||
// check that the window size was not increased
|
|
||||||
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
|
|
||||||
// check that the new window size was used to increase the offset
|
|
||||||
Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + oldWindowSize)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't increase the window size to a value higher than the maxReceiveWindowSize", func() {
|
|
||||||
resetEpoch := func() {
|
|
||||||
// make sure the next call to maybeAdjustWindowSize will increase the window
|
|
||||||
controller.epochStartTime = time.Now().Add(-time.Millisecond)
|
|
||||||
controller.epochStartOffset = controller.bytesRead
|
|
||||||
controller.AddBytesRead(controller.receiveWindowSize/2 + 1)
|
|
||||||
}
|
|
||||||
setRtt(scaleDuration(20 * time.Millisecond))
|
|
||||||
resetEpoch()
|
|
||||||
controller.maybeAdjustWindowSize()
|
|
||||||
Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize)) // 2000
|
|
||||||
// because the lastWindowUpdateTime is updated by MaybeTriggerWindowUpdate(), we can just call maybeAdjustWindowSize() multiple times and get an increase of the window size every time
|
|
||||||
resetEpoch()
|
|
||||||
controller.maybeAdjustWindowSize()
|
|
||||||
Expect(controller.receiveWindowSize).To(Equal(2 * 2 * oldWindowSize)) // 4000
|
|
||||||
resetEpoch()
|
|
||||||
controller.maybeAdjustWindowSize()
|
|
||||||
Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000
|
|
||||||
controller.maybeAdjustWindowSize()
|
|
||||||
Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,133 +0,0 @@
|
|||||||
package flowcontrol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Connection Flow controller", func() {
|
|
||||||
var (
|
|
||||||
controller *connectionFlowController
|
|
||||||
queuedWindowUpdate bool
|
|
||||||
)
|
|
||||||
|
|
||||||
// update the congestion such that it returns a given value for the smoothed RTT
|
|
||||||
setRtt := func(t time.Duration) {
|
|
||||||
controller.rttStats.UpdateRTT(t, 0, time.Now())
|
|
||||||
Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
controller = &connectionFlowController{}
|
|
||||||
controller.rttStats = &congestion.RTTStats{}
|
|
||||||
controller.logger = utils.DefaultLogger
|
|
||||||
controller.queueWindowUpdate = func() { queuedWindowUpdate = true }
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Constructor", func() {
|
|
||||||
rttStats := &congestion.RTTStats{}
|
|
||||||
|
|
||||||
It("sets the send and receive windows", func() {
|
|
||||||
receiveWindow := protocol.ByteCount(2000)
|
|
||||||
maxReceiveWindow := protocol.ByteCount(3000)
|
|
||||||
|
|
||||||
fc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, nil, rttStats, utils.DefaultLogger).(*connectionFlowController)
|
|
||||||
Expect(fc.receiveWindow).To(Equal(receiveWindow))
|
|
||||||
Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("receive flow control", func() {
|
|
||||||
It("increases the highestReceived by a given window size", func() {
|
|
||||||
controller.highestReceived = 1337
|
|
||||||
controller.IncrementHighestReceived(123)
|
|
||||||
Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337 + 123)))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("getting window updates", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
controller.receiveWindow = 100
|
|
||||||
controller.receiveWindowSize = 60
|
|
||||||
controller.maxReceiveWindowSize = 1000
|
|
||||||
controller.bytesRead = 100 - 60
|
|
||||||
})
|
|
||||||
|
|
||||||
It("queues window updates", func() {
|
|
||||||
controller.MaybeQueueWindowUpdate()
|
|
||||||
Expect(queuedWindowUpdate).To(BeFalse())
|
|
||||||
controller.AddBytesRead(30)
|
|
||||||
controller.MaybeQueueWindowUpdate()
|
|
||||||
Expect(queuedWindowUpdate).To(BeTrue())
|
|
||||||
Expect(controller.GetWindowUpdate()).ToNot(BeZero())
|
|
||||||
queuedWindowUpdate = false
|
|
||||||
controller.MaybeQueueWindowUpdate()
|
|
||||||
Expect(queuedWindowUpdate).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets a window update", func() {
|
|
||||||
windowSize := controller.receiveWindowSize
|
|
||||||
oldOffset := controller.bytesRead
|
|
||||||
dataRead := windowSize/2 - 1 // make sure not to trigger auto-tuning
|
|
||||||
controller.AddBytesRead(dataRead)
|
|
||||||
offset := controller.GetWindowUpdate()
|
|
||||||
Expect(offset).To(Equal(protocol.ByteCount(oldOffset + dataRead + 60)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("autotunes the window", func() {
|
|
||||||
oldOffset := controller.bytesRead
|
|
||||||
oldWindowSize := controller.receiveWindowSize
|
|
||||||
rtt := scaleDuration(20 * time.Millisecond)
|
|
||||||
setRtt(rtt)
|
|
||||||
controller.epochStartTime = time.Now().Add(-time.Millisecond)
|
|
||||||
controller.epochStartOffset = oldOffset
|
|
||||||
dataRead := oldWindowSize/2 + 1
|
|
||||||
controller.AddBytesRead(dataRead)
|
|
||||||
offset := controller.GetWindowUpdate()
|
|
||||||
newWindowSize := controller.receiveWindowSize
|
|
||||||
Expect(newWindowSize).To(Equal(2 * oldWindowSize))
|
|
||||||
Expect(offset).To(Equal(protocol.ByteCount(oldOffset + dataRead + newWindowSize)))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("setting the minimum window size", func() {
|
|
||||||
var (
|
|
||||||
oldWindowSize protocol.ByteCount
|
|
||||||
receiveWindow protocol.ByteCount = 10000
|
|
||||||
receiveWindowSize protocol.ByteCount = 1000
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
controller.receiveWindow = receiveWindow
|
|
||||||
controller.receiveWindowSize = receiveWindowSize
|
|
||||||
oldWindowSize = controller.receiveWindowSize
|
|
||||||
controller.maxReceiveWindowSize = 3000
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets the minimum window window size", func() {
|
|
||||||
controller.EnsureMinimumWindowSize(1800)
|
|
||||||
Expect(controller.receiveWindowSize).To(Equal(protocol.ByteCount(1800)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't reduce the window window size", func() {
|
|
||||||
controller.EnsureMinimumWindowSize(1)
|
|
||||||
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doens't increase the window size beyond the maxReceiveWindowSize", func() {
|
|
||||||
max := controller.maxReceiveWindowSize
|
|
||||||
controller.EnsureMinimumWindowSize(2 * max)
|
|
||||||
Expect(controller.receiveWindowSize).To(Equal(max))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("starts a new epoch after the window size was increased", func() {
|
|
||||||
controller.EnsureMinimumWindowSize(1912)
|
|
||||||
Expect(controller.epochStartTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,24 +0,0 @@
|
|||||||
package flowcontrol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCrypto(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "FlowControl Suite")
|
|
||||||
}
|
|
||||||
|
|
||||||
var mockCtrl *gomock.Controller
|
|
||||||
|
|
||||||
var _ = BeforeEach(func() {
|
|
||||||
mockCtrl = gomock.NewController(GinkgoT())
|
|
||||||
})
|
|
||||||
|
|
||||||
var _ = AfterEach(func() {
|
|
||||||
mockCtrl.Finish()
|
|
||||||
})
|
|
@ -1,290 +0,0 @@
|
|||||||
package flowcontrol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Stream Flow controller", func() {
|
|
||||||
var (
|
|
||||||
controller *streamFlowController
|
|
||||||
queuedWindowUpdate bool
|
|
||||||
queuedConnWindowUpdate bool
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
queuedWindowUpdate = false
|
|
||||||
queuedConnWindowUpdate = false
|
|
||||||
rttStats := &congestion.RTTStats{}
|
|
||||||
controller = &streamFlowController{
|
|
||||||
streamID: 10,
|
|
||||||
connection: NewConnectionFlowController(1000, 1000, func() { queuedConnWindowUpdate = true }, rttStats, utils.DefaultLogger).(*connectionFlowController),
|
|
||||||
}
|
|
||||||
controller.maxReceiveWindowSize = 10000
|
|
||||||
controller.rttStats = rttStats
|
|
||||||
controller.logger = utils.DefaultLogger
|
|
||||||
controller.queueWindowUpdate = func() { queuedWindowUpdate = true }
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Constructor", func() {
|
|
||||||
rttStats := &congestion.RTTStats{}
|
|
||||||
receiveWindow := protocol.ByteCount(2000)
|
|
||||||
maxReceiveWindow := protocol.ByteCount(3000)
|
|
||||||
sendWindow := protocol.ByteCount(4000)
|
|
||||||
|
|
||||||
It("sets the send and receive windows", func() {
|
|
||||||
cc := NewConnectionFlowController(0, 0, nil, nil, utils.DefaultLogger)
|
|
||||||
fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, nil, rttStats, utils.DefaultLogger).(*streamFlowController)
|
|
||||||
Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
|
|
||||||
Expect(fc.receiveWindow).To(Equal(receiveWindow))
|
|
||||||
Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow))
|
|
||||||
Expect(fc.sendWindow).To(Equal(sendWindow))
|
|
||||||
Expect(fc.contributesToConnection).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("queues window updates with the correction stream ID", func() {
|
|
||||||
var queued bool
|
|
||||||
queueWindowUpdate := func(id protocol.StreamID) {
|
|
||||||
Expect(id).To(Equal(protocol.StreamID(5)))
|
|
||||||
queued = true
|
|
||||||
}
|
|
||||||
|
|
||||||
cc := NewConnectionFlowController(0, 0, nil, nil, utils.DefaultLogger)
|
|
||||||
fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController)
|
|
||||||
fc.AddBytesRead(receiveWindow)
|
|
||||||
fc.MaybeQueueWindowUpdate()
|
|
||||||
Expect(queued).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("receiving data", func() {
|
|
||||||
Context("registering received offsets", func() {
|
|
||||||
var receiveWindow protocol.ByteCount = 10000
|
|
||||||
var receiveWindowSize protocol.ByteCount = 600
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
controller.receiveWindow = receiveWindow
|
|
||||||
controller.receiveWindowSize = receiveWindowSize
|
|
||||||
})
|
|
||||||
|
|
||||||
It("updates the highestReceived", func() {
|
|
||||||
controller.highestReceived = 1337
|
|
||||||
err := controller.UpdateHighestReceived(1338, false)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1338)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("informs the connection flow controller about received data", func() {
|
|
||||||
controller.highestReceived = 10
|
|
||||||
controller.contributesToConnection = true
|
|
||||||
controller.connection.(*connectionFlowController).highestReceived = 100
|
|
||||||
err := controller.UpdateHighestReceived(20, false)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(controller.connection.(*connectionFlowController).highestReceived).To(Equal(protocol.ByteCount(100 + 10)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't informs the connection flow controller about received data if it doesn't contribute", func() {
|
|
||||||
controller.highestReceived = 10
|
|
||||||
controller.connection.(*connectionFlowController).highestReceived = 100
|
|
||||||
err := controller.UpdateHighestReceived(20, false)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(controller.connection.(*connectionFlowController).highestReceived).To(Equal(protocol.ByteCount(100)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not decrease the highestReceived", func() {
|
|
||||||
controller.highestReceived = 1337
|
|
||||||
err := controller.UpdateHighestReceived(1000, false)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does nothing when setting the same byte offset", func() {
|
|
||||||
controller.highestReceived = 1337
|
|
||||||
err := controller.UpdateHighestReceived(1337, false)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not give a flow control violation when using the window completely", func() {
|
|
||||||
err := controller.UpdateHighestReceived(receiveWindow, false)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("detects a flow control violation", func() {
|
|
||||||
err := controller.UpdateHighestReceived(receiveWindow+1, false)
|
|
||||||
Expect(err).To(MatchError("FlowControlReceivedTooMuchData: Received 10001 bytes on stream 10, allowed 10000 bytes"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("accepts a final offset higher than the highest received", func() {
|
|
||||||
controller.highestReceived = 100
|
|
||||||
err := controller.UpdateHighestReceived(101, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(controller.highestReceived).To(Equal(protocol.ByteCount(101)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when receiving a final offset smaller than the highest offset received so far", func() {
|
|
||||||
controller.highestReceived = 100
|
|
||||||
err := controller.UpdateHighestReceived(99, true)
|
|
||||||
Expect(err).To(MatchError(qerr.StreamDataAfterTermination))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("accepts delayed data after receiving a final offset", func() {
|
|
||||||
err := controller.UpdateHighestReceived(300, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = controller.UpdateHighestReceived(250, false)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when receiving a higher offset after receiving a final offset", func() {
|
|
||||||
err := controller.UpdateHighestReceived(200, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = controller.UpdateHighestReceived(250, false)
|
|
||||||
Expect(err).To(MatchError(qerr.StreamDataAfterTermination))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("accepts duplicate final offsets", func() {
|
|
||||||
err := controller.UpdateHighestReceived(200, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = controller.UpdateHighestReceived(200, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(controller.highestReceived).To(Equal(protocol.ByteCount(200)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when receiving inconsistent final offsets", func() {
|
|
||||||
err := controller.UpdateHighestReceived(200, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = controller.UpdateHighestReceived(201, true)
|
|
||||||
Expect(err).To(MatchError("StreamDataAfterTermination: Received inconsistent final offset for stream 10 (old: 200, new: 201 bytes)"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("registering data read", func() {
|
|
||||||
It("saves when data is read, on a stream not contributing to the connection", func() {
|
|
||||||
controller.AddBytesRead(100)
|
|
||||||
Expect(controller.bytesRead).To(Equal(protocol.ByteCount(100)))
|
|
||||||
Expect(controller.connection.(*connectionFlowController).bytesRead).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("saves when data is read, on a stream not contributing to the connection", func() {
|
|
||||||
controller.contributesToConnection = true
|
|
||||||
controller.AddBytesRead(200)
|
|
||||||
Expect(controller.bytesRead).To(Equal(protocol.ByteCount(200)))
|
|
||||||
Expect(controller.connection.(*connectionFlowController).bytesRead).To(Equal(protocol.ByteCount(200)))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("generating window updates", func() {
|
|
||||||
var oldWindowSize protocol.ByteCount
|
|
||||||
|
|
||||||
// update the congestion such that it returns a given value for the smoothed RTT
|
|
||||||
setRtt := func(t time.Duration) {
|
|
||||||
controller.rttStats.UpdateRTT(t, 0, time.Now())
|
|
||||||
Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
controller.receiveWindow = 100
|
|
||||||
controller.receiveWindowSize = 60
|
|
||||||
controller.bytesRead = 100 - 60
|
|
||||||
controller.connection.(*connectionFlowController).receiveWindow = 100
|
|
||||||
controller.connection.(*connectionFlowController).receiveWindowSize = 120
|
|
||||||
oldWindowSize = controller.receiveWindowSize
|
|
||||||
})
|
|
||||||
|
|
||||||
It("queues window updates", func() {
|
|
||||||
controller.MaybeQueueWindowUpdate()
|
|
||||||
Expect(queuedWindowUpdate).To(BeFalse())
|
|
||||||
controller.AddBytesRead(30)
|
|
||||||
controller.MaybeQueueWindowUpdate()
|
|
||||||
Expect(queuedWindowUpdate).To(BeTrue())
|
|
||||||
Expect(controller.GetWindowUpdate()).ToNot(BeZero())
|
|
||||||
queuedWindowUpdate = false
|
|
||||||
controller.MaybeQueueWindowUpdate()
|
|
||||||
Expect(queuedWindowUpdate).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("queues connection-level window updates", func() {
|
|
||||||
controller.contributesToConnection = true
|
|
||||||
controller.MaybeQueueWindowUpdate()
|
|
||||||
Expect(queuedConnWindowUpdate).To(BeFalse())
|
|
||||||
controller.AddBytesRead(60)
|
|
||||||
controller.MaybeQueueWindowUpdate()
|
|
||||||
Expect(queuedConnWindowUpdate).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("tells the connection flow controller when the window was autotuned", func() {
|
|
||||||
oldOffset := controller.bytesRead
|
|
||||||
controller.contributesToConnection = true
|
|
||||||
setRtt(scaleDuration(20 * time.Millisecond))
|
|
||||||
controller.epochStartOffset = oldOffset
|
|
||||||
controller.epochStartTime = time.Now().Add(-time.Millisecond)
|
|
||||||
controller.AddBytesRead(55)
|
|
||||||
offset := controller.GetWindowUpdate()
|
|
||||||
Expect(offset).To(Equal(protocol.ByteCount(oldOffset + 55 + 2*oldWindowSize)))
|
|
||||||
Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize))
|
|
||||||
Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(float64(controller.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't tell the connection flow controller if it doesn't contribute", func() {
|
|
||||||
oldOffset := controller.bytesRead
|
|
||||||
controller.contributesToConnection = false
|
|
||||||
setRtt(scaleDuration(20 * time.Millisecond))
|
|
||||||
controller.epochStartOffset = oldOffset
|
|
||||||
controller.epochStartTime = time.Now().Add(-time.Millisecond)
|
|
||||||
controller.AddBytesRead(55)
|
|
||||||
offset := controller.GetWindowUpdate()
|
|
||||||
Expect(offset).ToNot(BeZero())
|
|
||||||
Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize))
|
|
||||||
Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(2 * oldWindowSize))) // unchanged
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't increase the window after a final offset was already received", func() {
|
|
||||||
controller.AddBytesRead(30)
|
|
||||||
err := controller.UpdateHighestReceived(90, true)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
controller.MaybeQueueWindowUpdate()
|
|
||||||
Expect(queuedWindowUpdate).To(BeFalse())
|
|
||||||
offset := controller.GetWindowUpdate()
|
|
||||||
Expect(offset).To(BeZero())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("sending data", func() {
|
|
||||||
It("gets the size of the send window", func() {
|
|
||||||
controller.UpdateSendWindow(15)
|
|
||||||
controller.AddBytesSent(5)
|
|
||||||
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(10)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't care about the connection-level window, if it doesn't contribute", func() {
|
|
||||||
controller.UpdateSendWindow(15)
|
|
||||||
controller.connection.UpdateSendWindow(1)
|
|
||||||
controller.AddBytesSent(5)
|
|
||||||
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(10)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("makes sure that it doesn't overflow the connection-level window", func() {
|
|
||||||
controller.contributesToConnection = true
|
|
||||||
controller.connection.UpdateSendWindow(12)
|
|
||||||
controller.UpdateSendWindow(20)
|
|
||||||
controller.AddBytesSent(10)
|
|
||||||
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(2)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't say that it's blocked, if only the connection is blocked", func() {
|
|
||||||
controller.contributesToConnection = true
|
|
||||||
controller.connection.UpdateSendWindow(50)
|
|
||||||
controller.UpdateSendWindow(100)
|
|
||||||
controller.AddBytesSent(50)
|
|
||||||
blocked, _ := controller.connection.IsNewlyBlocked()
|
|
||||||
Expect(blocked).To(BeTrue())
|
|
||||||
Expect(controller.IsNewlyBlocked()).To(BeFalse())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,111 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/asn1"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Cookie Generator", func() {
|
|
||||||
var cookieGen *CookieGenerator
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
var err error
|
|
||||||
cookieGen, err = NewCookieGenerator()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("generates a Cookie", func() {
|
|
||||||
ip := net.IPv4(127, 0, 0, 1)
|
|
||||||
token, err := cookieGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(token).ToNot(BeEmpty())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works with nil tokens", func() {
|
|
||||||
cookie, err := cookieGen.DecodeToken(nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(cookie).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("accepts a valid cookie", func() {
|
|
||||||
ip := net.IPv4(192, 168, 0, 1)
|
|
||||||
token, err := cookieGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cookie, err := cookieGen.DecodeToken(token)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(cookie.RemoteAddr).To(Equal("192.168.0.1"))
|
|
||||||
// the time resolution of the Cookie is just 1 second
|
|
||||||
// if Cookie generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds
|
|
||||||
Expect(cookie.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects invalid tokens", func() {
|
|
||||||
_, err := cookieGen.DecodeToken([]byte("invalid token"))
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects tokens that cannot be decoded", func() {
|
|
||||||
token, err := cookieGen.cookieProtector.NewToken([]byte("foobar"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = cookieGen.DecodeToken(token)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects tokens that can be decoded, but have additional payload", func() {
|
|
||||||
t, err := asn1.Marshal(token{Data: []byte("foobar")})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
t = append(t, []byte("rest")...)
|
|
||||||
enc, err := cookieGen.cookieProtector.NewToken(t)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = cookieGen.DecodeToken(enc)
|
|
||||||
Expect(err).To(MatchError("rest when unpacking token: 4"))
|
|
||||||
})
|
|
||||||
|
|
||||||
// we don't generate tokens that have no data, but we should be able to handle them if we receive one for whatever reason
|
|
||||||
It("doesn't panic if a tokens has no data", func() {
|
|
||||||
t, err := asn1.Marshal(token{Data: []byte("")})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
enc, err := cookieGen.cookieProtector.NewToken(t)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = cookieGen.DecodeToken(enc)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works with an IPv6 addresses ", func() {
|
|
||||||
addresses := []string{
|
|
||||||
"2001:db8::68",
|
|
||||||
"2001:0000:4136:e378:8000:63bf:3fff:fdd2",
|
|
||||||
"2001::1",
|
|
||||||
"ff01:0:0:0:0:0:0:2",
|
|
||||||
}
|
|
||||||
for _, addr := range addresses {
|
|
||||||
ip := net.ParseIP(addr)
|
|
||||||
Expect(ip).ToNot(BeNil())
|
|
||||||
raddr := &net.UDPAddr{IP: ip, Port: 1337}
|
|
||||||
token, err := cookieGen.NewToken(raddr)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cookie, err := cookieGen.DecodeToken(token)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(cookie.RemoteAddr).To(Equal(ip.String()))
|
|
||||||
// the time resolution of the Cookie is just 1 second
|
|
||||||
// if Cookie generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds
|
|
||||||
Expect(cookie.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the string representation an address that is not a UDP address", func() {
|
|
||||||
raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
|
|
||||||
token, err := cookieGen.NewToken(raddr)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cookie, err := cookieGen.DecodeToken(token)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(cookie.RemoteAddr).To(Equal("192.168.13.37:1337"))
|
|
||||||
// the time resolution of the Cookie is just 1 second
|
|
||||||
// if Cookie generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds
|
|
||||||
Expect(cookie.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,39 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Cookie Protector", func() {
|
|
||||||
var cp cookieProtector
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
var err error
|
|
||||||
cp, err = newCookieProtector()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("encodes and decodes tokens", func() {
|
|
||||||
token, err := cp.NewToken([]byte("foobar"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(token).ToNot(ContainSubstring("foobar"))
|
|
||||||
decoded, err := cp.DecodeToken(token)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(decoded).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fails deconding invalid tokens", func() {
|
|
||||||
token, err := cp.NewToken([]byte("foobar"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
token = token[1:] // remove the first byte
|
|
||||||
_, err = cp.DecodeToken(token)
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.Error()).To(ContainSubstring("message authentication failed"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when decoding too short tokens", func() {
|
|
||||||
_, err := cp.DecodeToken([]byte("foobar"))
|
|
||||||
Expect(err).To(MatchError("Token too short: 6"))
|
|
||||||
})
|
|
||||||
})
|
|
File diff suppressed because it is too large
Load Diff
@ -1,731 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockKEX struct {
|
|
||||||
ephermal bool
|
|
||||||
sharedKeyError error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockKEX) PublicKey() []byte {
|
|
||||||
if m.ephermal {
|
|
||||||
return []byte("ephermal pub")
|
|
||||||
}
|
|
||||||
return []byte("initial public")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockKEX) CalculateSharedKey(otherPublic []byte) ([]byte, error) {
|
|
||||||
if m.sharedKeyError != nil {
|
|
||||||
return nil, m.sharedKeyError
|
|
||||||
}
|
|
||||||
if m.ephermal {
|
|
||||||
return []byte("shared ephermal"), nil
|
|
||||||
}
|
|
||||||
return []byte("shared key"), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockSigner struct {
|
|
||||||
gotCHLO bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mockSigner) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) {
|
|
||||||
if len(chlo) > 0 {
|
|
||||||
s.gotCHLO = true
|
|
||||||
}
|
|
||||||
return []byte("proof"), nil
|
|
||||||
}
|
|
||||||
func (*mockSigner) GetCertsCompressed(sni string, common, cached []byte) ([]byte, error) {
|
|
||||||
return []byte("certcompressed"), nil
|
|
||||||
}
|
|
||||||
func (*mockSigner) GetLeafCert(sni string) ([]byte, error) {
|
|
||||||
return []byte("certuncompressed"), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func mockQuicCryptoKeyDerivation(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) {
|
|
||||||
return mockcrypto.NewMockAEAD(mockCtrl), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockStream struct {
|
|
||||||
unblockRead chan struct{}
|
|
||||||
dataToRead bytes.Buffer
|
|
||||||
dataWritten bytes.Buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ io.ReadWriter = &mockStream{}
|
|
||||||
|
|
||||||
var errMockStreamClosing = errors.New("mock stream closing")
|
|
||||||
|
|
||||||
func newMockStream() *mockStream {
|
|
||||||
return &mockStream{unblockRead: make(chan struct{})}
|
|
||||||
}
|
|
||||||
|
|
||||||
// call Close to make Read return
|
|
||||||
func (s *mockStream) Read(p []byte) (int, error) {
|
|
||||||
n, _ := s.dataToRead.Read(p)
|
|
||||||
if n == 0 { // block if there's no data
|
|
||||||
<-s.unblockRead
|
|
||||||
return 0, errMockStreamClosing
|
|
||||||
}
|
|
||||||
return n, nil // never return an EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mockStream) Write(p []byte) (int, error) {
|
|
||||||
return s.dataWritten.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mockStream) close() {
|
|
||||||
close(s.unblockRead)
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockCookieProtector struct {
|
|
||||||
decodeErr error
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ cookieProtector = &mockCookieProtector{}
|
|
||||||
|
|
||||||
func (mockCookieProtector) NewToken(sourceAddr []byte) ([]byte, error) {
|
|
||||||
return append([]byte("token "), sourceAddr...), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s mockCookieProtector) DecodeToken(data []byte) ([]byte, error) {
|
|
||||||
if s.decodeErr != nil {
|
|
||||||
return nil, s.decodeErr
|
|
||||||
}
|
|
||||||
if len(data) < 6 {
|
|
||||||
return nil, errors.New("token too short")
|
|
||||||
}
|
|
||||||
return data[6:], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("Server Crypto Setup", func() {
|
|
||||||
var (
|
|
||||||
kex *mockKEX
|
|
||||||
signer *mockSigner
|
|
||||||
scfg *ServerConfig
|
|
||||||
cs *cryptoSetupServer
|
|
||||||
stream *mockStream
|
|
||||||
paramsChan chan TransportParameters
|
|
||||||
handshakeEvent chan struct{}
|
|
||||||
nonce32 []byte
|
|
||||||
versionTag []byte
|
|
||||||
validSTK []byte
|
|
||||||
aead []byte
|
|
||||||
kexs []byte
|
|
||||||
version protocol.VersionNumber
|
|
||||||
supportedVersions []protocol.VersionNumber
|
|
||||||
sourceAddrValid bool
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
expectedInitialNonceLen = 32
|
|
||||||
expectedFSNonceLen = 64
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
var err error
|
|
||||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
|
|
||||||
|
|
||||||
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
|
|
||||||
paramsChan = make(chan TransportParameters, 1)
|
|
||||||
handshakeEvent = make(chan struct{}, 2)
|
|
||||||
stream = newMockStream()
|
|
||||||
kex = &mockKEX{}
|
|
||||||
signer = &mockSigner{}
|
|
||||||
scfg, err = NewServerConfig(kex, signer)
|
|
||||||
nonce32 = make([]byte, 32)
|
|
||||||
aead = []byte("AESG")
|
|
||||||
kexs = []byte("C255")
|
|
||||||
copy(nonce32[4:12], scfg.obit) // set the OBIT value at the right position
|
|
||||||
versionTag = make([]byte, 4)
|
|
||||||
binary.BigEndian.PutUint32(versionTag, uint32(protocol.VersionWhatever))
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
|
|
||||||
supportedVersions = []protocol.VersionNumber{version, 98, 99}
|
|
||||||
csInt, err := NewCryptoSetup(
|
|
||||||
stream,
|
|
||||||
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
|
||||||
remoteAddr,
|
|
||||||
version,
|
|
||||||
make([]byte, 32), // div nonce
|
|
||||||
scfg,
|
|
||||||
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
|
|
||||||
supportedVersions,
|
|
||||||
nil,
|
|
||||||
paramsChan,
|
|
||||||
handshakeEvent,
|
|
||||||
utils.DefaultLogger,
|
|
||||||
)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
cs = csInt.(*cryptoSetupServer)
|
|
||||||
cs.scfg.cookieGenerator.cookieProtector = &mockCookieProtector{}
|
|
||||||
validSTK, err = cs.scfg.cookieGenerator.NewToken(remoteAddr)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
sourceAddrValid = true
|
|
||||||
cs.acceptSTKCallback = func(_ net.Addr, _ *Cookie) bool { return sourceAddrValid }
|
|
||||||
cs.keyDerivation = mockQuicCryptoKeyDerivation
|
|
||||||
cs.keyExchange = func() (crypto.KeyExchange, error) { return &mockKEX{ephermal: true}, nil }
|
|
||||||
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
|
|
||||||
cs.cryptoStream = stream
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("when responding to client messages", func() {
|
|
||||||
var cert []byte
|
|
||||||
var xlct []byte
|
|
||||||
var fullCHLO map[Tag][]byte
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
xlct = make([]byte, 8)
|
|
||||||
var err error
|
|
||||||
cert, err = cs.scfg.certChain.GetLeafCert("")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
binary.LittleEndian.PutUint64(xlct, crypto.HashCert(cert))
|
|
||||||
fullCHLO = map[Tag][]byte{
|
|
||||||
TagSCID: scfg.ID,
|
|
||||||
TagSNI: []byte("quic.clemente.io"),
|
|
||||||
TagNONC: nonce32,
|
|
||||||
TagSTK: validSTK,
|
|
||||||
TagXLCT: xlct,
|
|
||||||
TagAEAD: aead,
|
|
||||||
TagKEXS: kexs,
|
|
||||||
TagPUBS: bytes.Repeat([]byte{'e'}, 31),
|
|
||||||
TagVER: versionTag,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't support Chrome's no STOP_WAITING experiment", func() {
|
|
||||||
HandshakeMessage{
|
|
||||||
Tag: TagCHLO,
|
|
||||||
Data: map[Tag][]byte{
|
|
||||||
TagNSTP: []byte("foobar"),
|
|
||||||
},
|
|
||||||
}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError(ErrNSTPExperiment))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads the transport parameters sent by the client", func() {
|
|
||||||
sourceAddrValid = true
|
|
||||||
fullCHLO[TagICSL] = []byte{0x37, 0x13, 0, 0}
|
|
||||||
_, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), fullCHLO)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
var params TransportParameters
|
|
||||||
Expect(paramsChan).To(Receive(¶ms))
|
|
||||||
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("generates REJ messages", func() {
|
|
||||||
sourceAddrValid = false
|
|
||||||
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(response).To(HavePrefix("REJ"))
|
|
||||||
Expect(response).To(ContainSubstring("initial public"))
|
|
||||||
Expect(response).ToNot(ContainSubstring("certcompressed"))
|
|
||||||
Expect(response).ToNot(ContainSubstring("proof"))
|
|
||||||
Expect(signer.gotCHLO).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("REJ messages don't include cert or proof without STK", func() {
|
|
||||||
sourceAddrValid = false
|
|
||||||
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(response).To(HavePrefix("REJ"))
|
|
||||||
Expect(response).ToNot(ContainSubstring("certcompressed"))
|
|
||||||
Expect(response).ToNot(ContainSubstring("proof"))
|
|
||||||
Expect(signer.gotCHLO).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("REJ messages include cert and proof with valid STK", func() {
|
|
||||||
sourceAddrValid = true
|
|
||||||
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), map[Tag][]byte{
|
|
||||||
TagSTK: validSTK,
|
|
||||||
TagSNI: []byte("foo"),
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(response).To(HavePrefix("REJ"))
|
|
||||||
Expect(response).To(ContainSubstring("certcompressed"))
|
|
||||||
Expect(response).To(ContainSubstring("proof"))
|
|
||||||
Expect(signer.gotCHLO).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("generates SHLO messages", func() {
|
|
||||||
var checkedSecure, checkedForwardSecure bool
|
|
||||||
cs.keyDerivation = func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) {
|
|
||||||
if forwardSecure {
|
|
||||||
Expect(nonces).To(HaveLen(expectedFSNonceLen))
|
|
||||||
checkedForwardSecure = true
|
|
||||||
Expect(sharedSecret).To(Equal([]byte("shared ephermal")))
|
|
||||||
} else {
|
|
||||||
Expect(nonces).To(HaveLen(expectedInitialNonceLen))
|
|
||||||
Expect(sharedSecret).To(Equal([]byte("shared key")))
|
|
||||||
checkedSecure = true
|
|
||||||
}
|
|
||||||
return mockcrypto.NewMockAEAD(mockCtrl), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
response, err := cs.handleCHLO("", []byte("chlo-data"), map[Tag][]byte{
|
|
||||||
TagPUBS: []byte("pubs-c"),
|
|
||||||
TagNONC: nonce32,
|
|
||||||
TagAEAD: aead,
|
|
||||||
TagKEXS: kexs,
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(response).To(HavePrefix("SHLO"))
|
|
||||||
message, err := ParseHandshakeMessage(bytes.NewReader(response))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(message.Data).To(HaveKeyWithValue(TagPUBS, []byte("ephermal pub")))
|
|
||||||
Expect(message.Data).To(HaveKey(TagSNO))
|
|
||||||
Expect(message.Data).To(HaveKey(TagVER))
|
|
||||||
Expect(message.Data[TagVER]).To(HaveLen(4 * len(supportedVersions)))
|
|
||||||
for _, v := range supportedVersions {
|
|
||||||
b := &bytes.Buffer{}
|
|
||||||
utils.BigEndian.WriteUint32(b, uint32(v))
|
|
||||||
Expect(message.Data[TagVER]).To(ContainSubstring(b.String()))
|
|
||||||
}
|
|
||||||
Expect(checkedSecure).To(BeTrue())
|
|
||||||
Expect(checkedForwardSecure).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handles long handshake", func() {
|
|
||||||
HandshakeMessage{
|
|
||||||
Tag: TagCHLO,
|
|
||||||
Data: map[Tag][]byte{
|
|
||||||
TagSNI: []byte("quic.clemente.io"),
|
|
||||||
TagSTK: validSTK,
|
|
||||||
TagPAD: bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize),
|
|
||||||
TagVER: versionTag,
|
|
||||||
},
|
|
||||||
}.Write(&stream.dataToRead)
|
|
||||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ"))
|
|
||||||
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
|
||||||
Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO"))
|
|
||||||
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
|
|
||||||
Expect(handshakeEvent).ToNot(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects client nonces that have the wrong length", func() {
|
|
||||||
fullCHLO[TagNONC] = []byte("too short client nonce")
|
|
||||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects client nonces that have the wrong OBIT value", func() {
|
|
||||||
fullCHLO[TagNONC] = make([]byte, 32) // the OBIT value is nonce[4:12] and here just initialized to 0
|
|
||||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if it can't calculate a shared key", func() {
|
|
||||||
testErr := errors.New("test error")
|
|
||||||
kex.sharedKeyError = testErr
|
|
||||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handles 0-RTT handshake", func() {
|
|
||||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO"))
|
|
||||||
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
|
|
||||||
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
|
||||||
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
|
|
||||||
Expect(handshakeEvent).ToNot(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("recognizes inchoate CHLOs missing SCID", func() {
|
|
||||||
delete(fullCHLO, TagSCID)
|
|
||||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("recognizes inchoate CHLOs missing PUBS", func() {
|
|
||||||
delete(fullCHLO, TagPUBS)
|
|
||||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("recognizes inchoate CHLOs with missing XLCT", func() {
|
|
||||||
delete(fullCHLO, TagXLCT)
|
|
||||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("recognizes inchoate CHLOs with wrong length XLCT", func() {
|
|
||||||
fullCHLO[TagXLCT] = bytes.Repeat([]byte{'f'}, 7) // should be 8 bytes
|
|
||||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("recognizes inchoate CHLOs with wrong XLCT", func() {
|
|
||||||
fullCHLO[TagXLCT] = bytes.Repeat([]byte{'f'}, 8)
|
|
||||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("recognizes inchoate CHLOs with an invalid STK", func() {
|
|
||||||
testErr := errors.New("STK invalid")
|
|
||||||
cs.scfg.cookieGenerator.cookieProtector.(*mockCookieProtector).decodeErr = testErr
|
|
||||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("recognizes proper CHLOs", func() {
|
|
||||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects CHLOs without the version tag", func() {
|
|
||||||
HandshakeMessage{
|
|
||||||
Tag: TagCHLO,
|
|
||||||
Data: map[Tag][]byte{
|
|
||||||
TagSCID: scfg.ID,
|
|
||||||
TagSNI: []byte("quic.clemente.io"),
|
|
||||||
},
|
|
||||||
}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects CHLOs with a version tag that has the wrong length", func() {
|
|
||||||
fullCHLO[TagVER] = []byte{0x13, 0x37} // should be 4 bytes
|
|
||||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("detects version downgrade attacks", func() {
|
|
||||||
highestSupportedVersion := supportedVersions[len(supportedVersions)-1]
|
|
||||||
lowestSupportedVersion := supportedVersions[0]
|
|
||||||
Expect(highestSupportedVersion).ToNot(Equal(lowestSupportedVersion))
|
|
||||||
cs.version = highestSupportedVersion
|
|
||||||
b := make([]byte, 4)
|
|
||||||
binary.BigEndian.PutUint32(b, uint32(lowestSupportedVersion))
|
|
||||||
fullCHLO[TagVER] = b
|
|
||||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError(qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("accepts a non-matching version tag in the CHLO, if it is an unsupported version", func() {
|
|
||||||
supportedVersion := protocol.SupportedVersions[0]
|
|
||||||
unsupportedVersion := supportedVersion + 1000
|
|
||||||
Expect(protocol.IsSupportedVersion(supportedVersions, unsupportedVersion)).To(BeFalse())
|
|
||||||
cs.version = supportedVersion
|
|
||||||
b := make([]byte, 4)
|
|
||||||
binary.BigEndian.PutUint32(b, uint32(unsupportedVersion))
|
|
||||||
fullCHLO[TagVER] = b
|
|
||||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the AEAD tag is missing", func() {
|
|
||||||
delete(fullCHLO, TagAEAD)
|
|
||||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the AEAD tag has the wrong value", func() {
|
|
||||||
fullCHLO[TagAEAD] = []byte("wrong")
|
|
||||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the KEXS tag is missing", func() {
|
|
||||||
delete(fullCHLO, TagKEXS)
|
|
||||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the KEXS tag has the wrong value", func() {
|
|
||||||
fullCHLO[TagKEXS] = []byte("wrong")
|
|
||||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors without SNI", func() {
|
|
||||||
HandshakeMessage{
|
|
||||||
Tag: TagCHLO,
|
|
||||||
Data: map[Tag][]byte{
|
|
||||||
TagSTK: validSTK,
|
|
||||||
},
|
|
||||||
}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors with empty SNI", func() {
|
|
||||||
HandshakeMessage{
|
|
||||||
Tag: TagCHLO,
|
|
||||||
Data: map[Tag][]byte{
|
|
||||||
TagSTK: validSTK,
|
|
||||||
TagSNI: nil,
|
|
||||||
},
|
|
||||||
}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors with invalid message", func() {
|
|
||||||
stream.dataToRead.Write([]byte("invalid message"))
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError(qerr.HandshakeFailed))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors with non-CHLO message", func() {
|
|
||||||
HandshakeMessage{Tag: TagPAD, Data: nil}.Write(&stream.dataToRead)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError(qerr.InvalidCryptoMessageType))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("escalating crypto", func() {
|
|
||||||
doCHLO := func() {
|
|
||||||
_, err := cs.handleCHLO("", []byte("chlo-data"), map[Tag][]byte{
|
|
||||||
TagPUBS: []byte("pubs-c"),
|
|
||||||
TagNONC: nonce32,
|
|
||||||
TagAEAD: aead,
|
|
||||||
TagKEXS: kexs,
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
|
||||||
close(cs.sentSHLO)
|
|
||||||
}
|
|
||||||
|
|
||||||
Context("null encryption", func() {
|
|
||||||
It("is used initially", func() {
|
|
||||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(10), []byte{}).Return([]byte("foobar signed"))
|
|
||||||
enc, sealer := cs.GetSealer()
|
|
||||||
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
|
|
||||||
d := sealer.Seal(nil, []byte("foobar"), 10, []byte{})
|
|
||||||
Expect(d).To(Equal([]byte("foobar signed")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("is used for the crypto stream", func() {
|
|
||||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(0), []byte{})
|
|
||||||
enc, sealer := cs.GetSealerForCryptoStream()
|
|
||||||
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
|
|
||||||
sealer.Seal(nil, []byte("foobar"), 0, []byte{})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("is accepted initially", func() {
|
|
||||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(5), []byte{}).Return([]byte("decrypted"), nil)
|
|
||||||
d, enc, err := cs.Open(nil, []byte("unencrypted"), 5, []byte{})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(d).To(Equal([]byte("decrypted")))
|
|
||||||
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the has the wrong hash", func() {
|
|
||||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("not unencrypted"), protocol.PacketNumber(5), []byte{}).Return(nil, errors.New("authentication failed"))
|
|
||||||
_, enc, err := cs.Open(nil, []byte("not unencrypted"), 5, []byte{})
|
|
||||||
Expect(err).To(MatchError("authentication failed"))
|
|
||||||
Expect(enc).To(Equal(protocol.EncryptionUnspecified))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("is still accepted after CHLO", func() {
|
|
||||||
doCHLO()
|
|
||||||
// it tries forward secure and secure decryption first
|
|
||||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
|
|
||||||
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
|
|
||||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{})
|
|
||||||
Expect(cs.secureAEAD).ToNot(BeNil())
|
|
||||||
_, enc, err := cs.Open(nil, []byte("unencrypted"), 99, []byte{})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("is not accepted after receiving secure packet", func() {
|
|
||||||
doCHLO()
|
|
||||||
// first receive a secure packet
|
|
||||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return(nil, errors.New("authentication failed"))
|
|
||||||
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return([]byte("decrypted"), nil)
|
|
||||||
d, enc, err := cs.Open(nil, []byte("encrypted"), 98, []byte{})
|
|
||||||
Expect(enc).To(Equal(protocol.EncryptionSecure))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(d).To(Equal([]byte("decrypted")))
|
|
||||||
// now receive an unencrypted packet
|
|
||||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
|
|
||||||
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
|
|
||||||
_, enc, err = cs.Open(nil, []byte("unencrypted"), 99, []byte{})
|
|
||||||
Expect(err).To(MatchError("authentication failed"))
|
|
||||||
Expect(enc).To(Equal(protocol.EncryptionUnspecified))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("is not used after CHLO", func() {
|
|
||||||
doCHLO()
|
|
||||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(0), []byte{})
|
|
||||||
enc, sealer := cs.GetSealer()
|
|
||||||
Expect(enc).ToNot(Equal(protocol.EncryptionUnencrypted))
|
|
||||||
sealer.Seal(nil, []byte("foobar"), 0, []byte{})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("initial encryption", func() {
|
|
||||||
It("is accepted after CHLO", func() {
|
|
||||||
doCHLO()
|
|
||||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return(nil, errors.New("authentication failed"))
|
|
||||||
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return([]byte("decrypted"), nil)
|
|
||||||
d, enc, err := cs.Open(nil, []byte("encrypted"), 98, []byte{})
|
|
||||||
Expect(enc).To(Equal(protocol.EncryptionSecure))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(d).To(Equal([]byte("decrypted")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("is not accepted after receiving forward secure packet", func() {
|
|
||||||
doCHLO()
|
|
||||||
// receive a forward secure packet
|
|
||||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(11), []byte{})
|
|
||||||
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 11, []byte{})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
// receive a secure packet
|
|
||||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(12), []byte{}).Return(nil, errors.New("authentication failed"))
|
|
||||||
_, enc, err := cs.Open(nil, []byte("encrypted"), 12, []byte{})
|
|
||||||
Expect(err).To(MatchError("authentication failed"))
|
|
||||||
Expect(enc).To(Equal(protocol.EncryptionUnspecified))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("is used for the crypto stream", func() {
|
|
||||||
doCHLO()
|
|
||||||
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(1), []byte{}).Return([]byte("foobar crypto stream"))
|
|
||||||
enc, sealer := cs.GetSealerForCryptoStream()
|
|
||||||
Expect(enc).To(Equal(protocol.EncryptionSecure))
|
|
||||||
d := sealer.Seal(nil, []byte("foobar"), 1, []byte{})
|
|
||||||
Expect(d).To(Equal([]byte("foobar crypto stream")))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("forward secure encryption", func() {
|
|
||||||
It("is used after the CHLO", func() {
|
|
||||||
doCHLO()
|
|
||||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(20), []byte{}).Return([]byte("foobar forward sec"))
|
|
||||||
enc, sealer := cs.GetSealer()
|
|
||||||
Expect(enc).To(Equal(protocol.EncryptionForwardSecure))
|
|
||||||
d := sealer.Seal(nil, []byte("foobar"), 20, []byte{})
|
|
||||||
Expect(d).To(Equal([]byte("foobar forward sec")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("regards the handshake as complete once it receives a forward encrypted packet", func() {
|
|
||||||
doCHLO()
|
|
||||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{})
|
|
||||||
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handshakeEvent).To(BeClosed())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("reporting the connection state", func() {
|
|
||||||
It("reports before the handshake completes", func() {
|
|
||||||
cs.sni = "server name"
|
|
||||||
state := cs.ConnectionState()
|
|
||||||
Expect(state.HandshakeComplete).To(BeFalse())
|
|
||||||
Expect(state.ServerName).To(Equal("server name"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reports after the handshake completes", func() {
|
|
||||||
doCHLO()
|
|
||||||
// receive a forward secure packet
|
|
||||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(11), []byte{})
|
|
||||||
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 11, []byte{})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
state := cs.ConnectionState()
|
|
||||||
Expect(state.HandshakeComplete).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("forcing encryption levels", func() {
|
|
||||||
It("forces null encryption", func() {
|
|
||||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(11), []byte{}).Return([]byte("foobar unencrypted"))
|
|
||||||
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
d := sealer.Seal(nil, []byte("foobar"), 11, []byte{})
|
|
||||||
Expect(d).To(Equal([]byte("foobar unencrypted")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("forces initial encryption", func() {
|
|
||||||
doCHLO()
|
|
||||||
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(12), []byte{}).Return([]byte("foobar secure"))
|
|
||||||
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
d := sealer.Seal(nil, []byte("foobar"), 12, []byte{})
|
|
||||||
Expect(d).To(Equal([]byte("foobar secure")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if no AEAD for initial encryption is available", func() {
|
|
||||||
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure)
|
|
||||||
Expect(err).To(MatchError("CryptoSetupServer: no secureAEAD"))
|
|
||||||
Expect(sealer).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("forces forward-secure encryption", func() {
|
|
||||||
doCHLO()
|
|
||||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(13), []byte{}).Return([]byte("foobar forward sec"))
|
|
||||||
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
d := sealer.Seal(nil, []byte("foobar"), 13, []byte{})
|
|
||||||
Expect(d).To(Equal([]byte("foobar forward sec")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors of no AEAD for forward-secure encryption is available", func() {
|
|
||||||
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
|
|
||||||
Expect(err).To(MatchError("CryptoSetupServer: no forwardSecureAEAD"))
|
|
||||||
Expect(seal).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if no encryption level is specified", func() {
|
|
||||||
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnspecified)
|
|
||||||
Expect(err).To(MatchError("CryptoSetupServer: no encryption level specified"))
|
|
||||||
Expect(seal).To(BeNil())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("STK verification and creation", func() {
|
|
||||||
It("requires STK", func() {
|
|
||||||
sourceAddrValid = false
|
|
||||||
done, err := cs.handleMessage(
|
|
||||||
bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize),
|
|
||||||
map[Tag][]byte{
|
|
||||||
TagSNI: []byte("foo"),
|
|
||||||
TagVER: versionTag,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(done).To(BeFalse())
|
|
||||||
Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK)))
|
|
||||||
Expect(cs.sni).To(Equal("foo"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works with proper STK", func() {
|
|
||||||
sourceAddrValid = true
|
|
||||||
done, err := cs.handleMessage(
|
|
||||||
bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize),
|
|
||||||
map[Tag][]byte{
|
|
||||||
TagSNI: []byte("foo"),
|
|
||||||
TagVER: versionTag,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(done).To(BeFalse())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,192 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) {
|
|
||||||
return mockcrypto.NewMockAEAD(mockCtrl), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("TLS Crypto Setup", func() {
|
|
||||||
var (
|
|
||||||
cs *cryptoSetupTLS
|
|
||||||
handshakeEvent chan struct{}
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
handshakeEvent = make(chan struct{}, 2)
|
|
||||||
css, err := NewCryptoSetupTLSServer(
|
|
||||||
newCryptoStreamConn(bytes.NewBuffer([]byte{})),
|
|
||||||
protocol.ConnectionID{},
|
|
||||||
&mint.Config{},
|
|
||||||
handshakeEvent,
|
|
||||||
protocol.VersionTLS,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cs = css.(*cryptoSetupTLS)
|
|
||||||
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when the handshake fails", func() {
|
|
||||||
alert := mint.AlertBadRecordMAC
|
|
||||||
cs.tls = NewMockMintTLS(mockCtrl)
|
|
||||||
cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(alert)
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("derives keys", func() {
|
|
||||||
cs.tls = NewMockMintTLS(mockCtrl)
|
|
||||||
cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
|
||||||
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected})
|
|
||||||
cs.keyDerivation = mockKeyDerivation
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handshakeEvent).To(Receive())
|
|
||||||
Expect(handshakeEvent).To(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handshakes until it is connected", func() {
|
|
||||||
cs.tls = NewMockMintTLS(mockCtrl)
|
|
||||||
cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert).Times(10)
|
|
||||||
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerNegotiated}).Times(9)
|
|
||||||
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected})
|
|
||||||
cs.keyDerivation = mockKeyDerivation
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(handshakeEvent).To(Receive())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("reporting the handshake state", func() {
|
|
||||||
It("reports before the handshake compeletes", func() {
|
|
||||||
cs.tls = NewMockMintTLS(mockCtrl)
|
|
||||||
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{})
|
|
||||||
state := cs.ConnectionState()
|
|
||||||
Expect(state.HandshakeComplete).To(BeFalse())
|
|
||||||
Expect(state.PeerCertificates).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reports after the handshake completes", func() {
|
|
||||||
cs.tls = NewMockMintTLS(mockCtrl)
|
|
||||||
cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
|
||||||
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected}).Times(2)
|
|
||||||
cs.keyDerivation = mockKeyDerivation
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
state := cs.ConnectionState()
|
|
||||||
Expect(state.HandshakeComplete).To(BeTrue())
|
|
||||||
Expect(state.PeerCertificates).To(BeNil())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("escalating crypto", func() {
|
|
||||||
doHandshake := func() {
|
|
||||||
cs.tls = NewMockMintTLS(mockCtrl)
|
|
||||||
cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
|
||||||
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected})
|
|
||||||
cs.keyDerivation = mockKeyDerivation
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
}
|
|
||||||
|
|
||||||
Context("null encryption", func() {
|
|
||||||
It("is used initially", func() {
|
|
||||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(5), []byte{}).Return([]byte("foobar signed"))
|
|
||||||
enc, sealer := cs.GetSealer()
|
|
||||||
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
|
|
||||||
d := sealer.Seal(nil, []byte("foobar"), 5, []byte{})
|
|
||||||
Expect(d).To(Equal([]byte("foobar signed")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("is used for opening", func() {
|
|
||||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("foobar enc"), protocol.PacketNumber(10), []byte{}).Return([]byte("foobar"), nil)
|
|
||||||
d, err := cs.OpenHandshake(nil, []byte("foobar enc"), 10, []byte{})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(d).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("is used for crypto stream", func() {
|
|
||||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(20), []byte{}).Return([]byte("foobar signed"))
|
|
||||||
enc, sealer := cs.GetSealerForCryptoStream()
|
|
||||||
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
|
|
||||||
d := sealer.Seal(nil, []byte("foobar"), 20, []byte{})
|
|
||||||
Expect(d).To(Equal([]byte("foobar signed")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the has the wrong hash", func() {
|
|
||||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("foobar enc"), protocol.PacketNumber(10), []byte{}).Return(nil, errors.New("authentication failed"))
|
|
||||||
_, err := cs.OpenHandshake(nil, []byte("foobar enc"), 10, []byte{})
|
|
||||||
Expect(err).To(MatchError("authentication failed"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("forward-secure encryption", func() {
|
|
||||||
It("is used for sealing after the handshake completes", func() {
|
|
||||||
doHandshake()
|
|
||||||
cs.aead.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(5), []byte{}).Return([]byte("foobar forward sec"))
|
|
||||||
enc, sealer := cs.GetSealer()
|
|
||||||
Expect(enc).To(Equal(protocol.EncryptionForwardSecure))
|
|
||||||
d := sealer.Seal(nil, []byte("foobar"), 5, []byte{})
|
|
||||||
Expect(d).To(Equal([]byte("foobar forward sec")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("is used for opening", func() {
|
|
||||||
doHandshake()
|
|
||||||
cs.aead.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(6), []byte{}).Return([]byte("decrypted"), nil)
|
|
||||||
d, err := cs.Open1RTT(nil, []byte("encrypted"), 6, []byte{})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(d).To(Equal([]byte("decrypted")))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("forcing encryption levels", func() {
|
|
||||||
It("forces null encryption", func() {
|
|
||||||
doHandshake()
|
|
||||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(5), []byte{}).Return([]byte("foobar signed"))
|
|
||||||
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
d := sealer.Seal(nil, []byte("foobar"), 5, []byte{})
|
|
||||||
Expect(d).To(Equal([]byte("foobar signed")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("forces forward-secure encryption", func() {
|
|
||||||
doHandshake()
|
|
||||||
cs.aead.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(5), []byte{}).Return([]byte("foobar forward sec"))
|
|
||||||
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
d := sealer.Seal(nil, []byte("foobar"), 5, []byte{})
|
|
||||||
Expect(d).To(Equal([]byte("foobar forward sec")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the forward-secure AEAD is not available", func() {
|
|
||||||
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
|
|
||||||
Expect(err).To(MatchError("CryptoSetup: no sealer with encryption level forward-secure"))
|
|
||||||
Expect(sealer).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("never returns a secure AEAD (they don't exist with TLS)", func() {
|
|
||||||
doHandshake()
|
|
||||||
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure)
|
|
||||||
Expect(err).To(MatchError("CryptoSetup: no sealer with encryption level encrypted (not forward-secure)"))
|
|
||||||
Expect(sealer).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if no encryption level is specified", func() {
|
|
||||||
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnspecified)
|
|
||||||
Expect(err).To(MatchError("CryptoSetup: no sealer with encryption level unknown"))
|
|
||||||
Expect(seal).To(BeNil())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,41 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Crypto Stream Conn", func() {
|
|
||||||
var (
|
|
||||||
stream *bytes.Buffer
|
|
||||||
csc *cryptoStreamConn
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
stream = &bytes.Buffer{}
|
|
||||||
csc = newCryptoStreamConn(stream)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("buffers writes", func() {
|
|
||||||
_, err := csc.Write([]byte("foo"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(stream.Len()).To(BeZero())
|
|
||||||
_, err = csc.Write([]byte("bar"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(stream.Len()).To(BeZero())
|
|
||||||
|
|
||||||
Expect(csc.Flush()).To(Succeed())
|
|
||||||
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads from the stream", func() {
|
|
||||||
stream.Write([]byte("foobar"))
|
|
||||||
b := make([]byte, 6)
|
|
||||||
n, err := csc.Read(b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(Equal(6))
|
|
||||||
Expect(b).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
})
|
|
File diff suppressed because one or more lines are too long
@ -1,35 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Ephermal KEX", func() {
|
|
||||||
It("has a consistent KEX", func() {
|
|
||||||
kex1, err := getEphermalKEX()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(kex1).ToNot(BeNil())
|
|
||||||
kex2, err := getEphermalKEX()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(kex2).ToNot(BeNil())
|
|
||||||
Expect(kex1).To(Equal(kex2))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("changes KEX", func() {
|
|
||||||
kexLifetime = 10 * time.Millisecond
|
|
||||||
defer func() {
|
|
||||||
kexLifetime = protocol.EphermalKeyLifetime
|
|
||||||
}()
|
|
||||||
kex, err := getEphermalKEX()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(kex).ToNot(BeNil())
|
|
||||||
time.Sleep(kexLifetime)
|
|
||||||
kex2, err := getEphermalKEX()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(kex2).ToNot(Equal(kex))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,71 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Handshake Message", func() {
|
|
||||||
Context("when parsing", func() {
|
|
||||||
It("parses sample CHLO message", func() {
|
|
||||||
msg, err := ParseHandshakeMessage(bytes.NewReader(sampleCHLO))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(msg.Tag).To(Equal(TagCHLO))
|
|
||||||
Expect(msg.Data).To(Equal(sampleCHLOMap))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects large numbers of pairs", func() {
|
|
||||||
r := bytes.NewReader([]byte("CHLO\xff\xff\xff\xff"))
|
|
||||||
_, err := ParseHandshakeMessage(r)
|
|
||||||
Expect(err).To(MatchError(qerr.CryptoTooManyEntries))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects too long values", func() {
|
|
||||||
r := bytes.NewReader([]byte{
|
|
||||||
'C', 'H', 'L', 'O',
|
|
||||||
1, 0, 0, 0,
|
|
||||||
0, 0, 0, 0,
|
|
||||||
0xff, 0xff, 0xff, 0xff,
|
|
||||||
})
|
|
||||||
_, err := ParseHandshakeMessage(r)
|
|
||||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoInvalidValueLength, "value too long")))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("when writing", func() {
|
|
||||||
It("writes sample message", func() {
|
|
||||||
b := &bytes.Buffer{}
|
|
||||||
HandshakeMessage{Tag: TagCHLO, Data: sampleCHLOMap}.Write(b)
|
|
||||||
Expect(b.Bytes()).To(Equal(sampleCHLO))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("string representation", func() {
|
|
||||||
It("has a string representation", func() {
|
|
||||||
str := HandshakeMessage{
|
|
||||||
Tag: TagSHLO,
|
|
||||||
Data: map[Tag][]byte{
|
|
||||||
TagAEAD: []byte("foobar"),
|
|
||||||
TagEXPY: []byte("raboof"),
|
|
||||||
},
|
|
||||||
}.String()
|
|
||||||
Expect(str[:4]).To(Equal("SHLO"))
|
|
||||||
Expect(str).To(ContainSubstring("AEAD: \"foobar\""))
|
|
||||||
Expect(str).To(ContainSubstring("EXPY: \"raboof\""))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("lists padding separately", func() {
|
|
||||||
str := HandshakeMessage{
|
|
||||||
Tag: TagSHLO,
|
|
||||||
Data: map[Tag][]byte{
|
|
||||||
TagPAD: bytes.Repeat([]byte{0}, 1337),
|
|
||||||
},
|
|
||||||
}.String()
|
|
||||||
Expect(str).To(ContainSubstring("PAD"))
|
|
||||||
Expect(str).To(ContainSubstring("1337 bytes"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,24 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestQuicGo(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "Handshake Suite")
|
|
||||||
}
|
|
||||||
|
|
||||||
var mockCtrl *gomock.Controller
|
|
||||||
|
|
||||||
var _ = BeforeEach(func() {
|
|
||||||
mockCtrl = gomock.NewController(GinkgoT())
|
|
||||||
})
|
|
||||||
|
|
||||||
var _ = AfterEach(func() {
|
|
||||||
mockCtrl.Finish()
|
|
||||||
})
|
|
@ -1,72 +0,0 @@
|
|||||||
// Code generated by MockGen. DO NOT EDIT.
|
|
||||||
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: MintTLS)
|
|
||||||
|
|
||||||
// Package handshake is a generated GoMock package.
|
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
reflect "reflect"
|
|
||||||
|
|
||||||
mint "github.com/bifurcation/mint"
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockMintTLS is a mock of MintTLS interface
|
|
||||||
type MockMintTLS struct {
|
|
||||||
ctrl *gomock.Controller
|
|
||||||
recorder *MockMintTLSMockRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockMintTLSMockRecorder is the mock recorder for MockMintTLS
|
|
||||||
type MockMintTLSMockRecorder struct {
|
|
||||||
mock *MockMintTLS
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMockMintTLS creates a new mock instance
|
|
||||||
func NewMockMintTLS(ctrl *gomock.Controller) *MockMintTLS {
|
|
||||||
mock := &MockMintTLS{ctrl: ctrl}
|
|
||||||
mock.recorder = &MockMintTLSMockRecorder{mock}
|
|
||||||
return mock
|
|
||||||
}
|
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use
|
|
||||||
func (m *MockMintTLS) EXPECT() *MockMintTLSMockRecorder {
|
|
||||||
return m.recorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// ComputeExporter mocks base method
|
|
||||||
func (m *MockMintTLS) ComputeExporter(arg0 string, arg1 []byte, arg2 int) ([]byte, error) {
|
|
||||||
ret := m.ctrl.Call(m, "ComputeExporter", arg0, arg1, arg2)
|
|
||||||
ret0, _ := ret[0].([]byte)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// ComputeExporter indicates an expected call of ComputeExporter
|
|
||||||
func (mr *MockMintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ComputeExporter", reflect.TypeOf((*MockMintTLS)(nil).ComputeExporter), arg0, arg1, arg2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnectionState mocks base method
|
|
||||||
func (m *MockMintTLS) ConnectionState() mint.ConnectionState {
|
|
||||||
ret := m.ctrl.Call(m, "ConnectionState")
|
|
||||||
ret0, _ := ret[0].(mint.ConnectionState)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnectionState indicates an expected call of ConnectionState
|
|
||||||
func (mr *MockMintTLSMockRecorder) ConnectionState() *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockMintTLS)(nil).ConnectionState))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handshake mocks base method
|
|
||||||
func (m *MockMintTLS) Handshake() mint.Alert {
|
|
||||||
ret := m.ctrl.Call(m, "Handshake")
|
|
||||||
ret0, _ := ret[0].(mint.Alert)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handshake indicates an expected call of Handshake
|
|
||||||
func (mr *MockMintTLSMockRecorder) Handshake() *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handshake", reflect.TypeOf((*MockMintTLS)(nil).Handshake))
|
|
||||||
}
|
|
@ -1,266 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
// This tagMap can be passed to parseValues and is garantueed to not cause any errors
|
|
||||||
func getDefaultServerConfigClient() map[Tag][]byte {
|
|
||||||
return map[Tag][]byte{
|
|
||||||
TagSCID: bytes.Repeat([]byte{'F'}, 16),
|
|
||||||
TagKEXS: []byte("C255"),
|
|
||||||
TagAEAD: []byte("AESG"),
|
|
||||||
TagPUBS: append([]byte{0x20, 0x00, 0x00}, bytes.Repeat([]byte{0}, 32)...),
|
|
||||||
TagOBIT: bytes.Repeat([]byte{0}, 8),
|
|
||||||
TagEXPY: {0x0, 0x6c, 0x57, 0x78, 0, 0, 0, 0}, // 2033-12-24
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ = Describe("Server Config", func() {
|
|
||||||
var tagMap map[Tag][]byte
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
tagMap = getDefaultServerConfigClient()
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns the parsed server config", func() {
|
|
||||||
tagMap[TagSCID] = []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
|
|
||||||
b := &bytes.Buffer{}
|
|
||||||
HandshakeMessage{Tag: TagSCFG, Data: tagMap}.Write(b)
|
|
||||||
scfg, err := parseServerConfig(b.Bytes())
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(scfg.ID).To(Equal(tagMap[TagSCID]))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("saves the raw server config", func() {
|
|
||||||
b := &bytes.Buffer{}
|
|
||||||
HandshakeMessage{Tag: TagSCFG, Data: tagMap}.Write(b)
|
|
||||||
scfg, err := parseServerConfig(b.Bytes())
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(scfg.raw).To(Equal(b.Bytes()))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("tells if a server config is expired", func() {
|
|
||||||
scfg := &serverConfigClient{}
|
|
||||||
scfg.expiry = time.Now().Add(-time.Second)
|
|
||||||
Expect(scfg.IsExpired()).To(BeTrue())
|
|
||||||
scfg.expiry = time.Now().Add(time.Second)
|
|
||||||
Expect(scfg.IsExpired()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("parsing the server config", func() {
|
|
||||||
It("rejects a handshake message with the wrong message tag", func() {
|
|
||||||
var serverConfig bytes.Buffer
|
|
||||||
HandshakeMessage{Tag: TagCHLO, Data: make(map[Tag][]byte)}.Write(&serverConfig)
|
|
||||||
_, err := parseServerConfig(serverConfig.Bytes())
|
|
||||||
Expect(err).To(MatchError(errMessageNotServerConfig))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors on invalid handshake messages", func() {
|
|
||||||
var serverConfig bytes.Buffer
|
|
||||||
HandshakeMessage{Tag: TagSCFG, Data: make(map[Tag][]byte)}.Write(&serverConfig)
|
|
||||||
_, err := parseServerConfig(serverConfig.Bytes()[:serverConfig.Len()-2])
|
|
||||||
Expect(err).To(MatchError("unexpected EOF"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("passes on errors encountered when reading the TagMap", func() {
|
|
||||||
var serverConfig bytes.Buffer
|
|
||||||
HandshakeMessage{Tag: TagSCFG, Data: make(map[Tag][]byte)}.Write(&serverConfig)
|
|
||||||
_, err := parseServerConfig(serverConfig.Bytes())
|
|
||||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SCID"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads an example Handshake Message", func() {
|
|
||||||
var serverConfig bytes.Buffer
|
|
||||||
HandshakeMessage{Tag: TagSCFG, Data: tagMap}.Write(&serverConfig)
|
|
||||||
scfg, err := parseServerConfig(serverConfig.Bytes())
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(scfg.ID).To(Equal(tagMap[TagSCID]))
|
|
||||||
Expect(scfg.obit).To(Equal(tagMap[TagOBIT]))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Reading values from the TagMap", func() {
|
|
||||||
var scfg *serverConfigClient
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
scfg = &serverConfigClient{}
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("ServerConfig ID", func() {
|
|
||||||
It("parses the ServerConfig ID", func() {
|
|
||||||
id := []byte{0xb2, 0xa4, 0xbb, 0x8f, 0xf6, 0x51, 0x28, 0xfd, 0x4d, 0xf7, 0xb3, 0x9a, 0x91, 0xe7, 0x91, 0xfb}
|
|
||||||
tagMap[TagSCID] = id
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(scfg.ID).To(Equal(id))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the ServerConfig ID is missing", func() {
|
|
||||||
delete(tagMap, TagSCID)
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SCID"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects ServerConfig IDs that have the wrong length", func() {
|
|
||||||
tagMap[TagSCID] = bytes.Repeat([]byte{'F'}, 17) // 1 byte too long
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).To(MatchError("CryptoInvalidValueLength: SCID"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("KEXS", func() {
|
|
||||||
It("rejects KEXS values that have the wrong length", func() {
|
|
||||||
tagMap[TagKEXS] = bytes.Repeat([]byte{'F'}, 5) // 1 byte too long
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).To(MatchError("CryptoInvalidValueLength: KEXS"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects KEXS values other than C255", func() {
|
|
||||||
tagMap[TagKEXS] = []byte("P256")
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).To(MatchError("CryptoNoSupport: KEXS: Could not find C255, other key exchanges are not supported"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the KEXS is missing", func() {
|
|
||||||
delete(tagMap, TagKEXS)
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: KEXS"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("AEAD", func() {
|
|
||||||
It("rejects AEAD values that have the wrong length", func() {
|
|
||||||
tagMap[TagAEAD] = bytes.Repeat([]byte{'F'}, 5) // 1 byte too long
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).To(MatchError("CryptoInvalidValueLength: AEAD"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects AEAD values other than AESG", func() {
|
|
||||||
tagMap[TagAEAD] = []byte("S20P")
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).To(MatchError("CryptoNoSupport: AEAD"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("recognizes AESG in the list of AEADs, at the first position", func() {
|
|
||||||
tagMap[TagAEAD] = []byte("AESGS20P")
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("recognizes AESG in the list of AEADs, not at the first position", func() {
|
|
||||||
tagMap[TagAEAD] = []byte("S20PAESG")
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the AEAD is missing", func() {
|
|
||||||
delete(tagMap, TagAEAD)
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: AEAD"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("PUBS", func() {
|
|
||||||
It("creates a Curve25519 key exchange", func() {
|
|
||||||
serverKex, err := crypto.NewCurve25519KEX()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
tagMap[TagPUBS] = append([]byte{0x20, 0x00, 0x00}, serverKex.PublicKey()...)
|
|
||||||
err = scfg.parseValues(tagMap)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
sharedSecret, err := serverKex.CalculateSharedKey(scfg.kex.PublicKey())
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(scfg.sharedSecret).To(Equal(sharedSecret))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects PUBS values that have the wrong length", func() {
|
|
||||||
tagMap[TagPUBS] = bytes.Repeat([]byte{'F'}, 100) // completely wrong length
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).To(MatchError("CryptoInvalidValueLength: PUBS"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects PUBS values that have a zero length", func() {
|
|
||||||
tagMap[TagPUBS] = bytes.Repeat([]byte{0}, 100) // completely wrong length
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).To(MatchError("CryptoInvalidValueLength: PUBS"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("ensure that C255 Pubs must not be at the first index", func() {
|
|
||||||
serverKex, err := crypto.NewCurve25519KEX()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
tagMap[TagKEXS] = []byte("P256C255") // have another KEXS before C255
|
|
||||||
// 3 byte len + 1 byte empty + C255
|
|
||||||
tagMap[TagPUBS] = append([]byte{0x01, 0x00, 0x00, 0x00}, append([]byte{0x20, 0x00, 0x00}, serverKex.PublicKey()...)...)
|
|
||||||
err = scfg.parseValues(tagMap)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
sharedSecret, err := serverKex.CalculateSharedKey(scfg.kex.PublicKey())
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(scfg.sharedSecret).To(Equal(sharedSecret))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the PUBS is missing", func() {
|
|
||||||
delete(tagMap, TagPUBS)
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: PUBS"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("OBIT", func() {
|
|
||||||
It("parses the OBIT value", func() {
|
|
||||||
obit := []byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}
|
|
||||||
tagMap[TagOBIT] = obit
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(scfg.obit).To(Equal(obit))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the OBIT is missing", func() {
|
|
||||||
delete(tagMap, TagOBIT)
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: OBIT"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejets OBIT values that have the wrong length", func() {
|
|
||||||
tagMap[TagOBIT] = bytes.Repeat([]byte{'F'}, 7) // 1 byte too short
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).To(MatchError("CryptoInvalidValueLength: OBIT"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("EXPY", func() {
|
|
||||||
It("parses the expiry date", func() {
|
|
||||||
tagMap[TagEXPY] = []byte{0xdc, 0x89, 0x0e, 0x59, 0, 0, 0, 0} // UNIX Timestamp 0x590e89dc = 1494125020
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
year, month, day := scfg.expiry.UTC().Date()
|
|
||||||
Expect(year).To(Equal(2017))
|
|
||||||
Expect(month).To(Equal(time.Month(5)))
|
|
||||||
Expect(day).To(Equal(7))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the EXPY is missing", func() {
|
|
||||||
delete(tagMap, TagEXPY)
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: EXPY"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects EXPY values that have the wrong length", func() {
|
|
||||||
tagMap[TagEXPY] = bytes.Repeat([]byte{'F'}, 9) // 1 byte too long
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).To(MatchError("CryptoInvalidValueLength: EXPY"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("deals with absurdly large timestamps", func() {
|
|
||||||
tagMap[TagEXPY] = bytes.Repeat([]byte{0xff}, 8) // this would overflow the int64
|
|
||||||
err := scfg.parseValues(tagMap)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(scfg.expiry.After(time.Now())).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,45 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("ServerConfig", func() {
|
|
||||||
var (
|
|
||||||
kex crypto.KeyExchange
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
var err error
|
|
||||||
kex, err = crypto.NewCurve25519KEX()
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("generates a random ID and OBIT", func() {
|
|
||||||
scfg1, err := NewServerConfig(kex, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
scfg2, err := NewServerConfig(kex, nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(scfg1.ID).ToNot(Equal(scfg2.ID))
|
|
||||||
Expect(scfg1.obit).ToNot(Equal(scfg2.obit))
|
|
||||||
Expect(scfg1.cookieGenerator).ToNot(Equal(scfg2.cookieGenerator))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("gets the proper binary representation", func() {
|
|
||||||
scfg, err := NewServerConfig(kex, nil)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
expected := bytes.NewBuffer([]byte{0x53, 0x43, 0x46, 0x47, 0x6, 0x0, 0x0, 0x0, 0x41, 0x45, 0x41, 0x44, 0x4, 0x0, 0x0, 0x0, 0x53, 0x43, 0x49, 0x44, 0x14, 0x0, 0x0, 0x0, 0x50, 0x55, 0x42, 0x53, 0x37, 0x0, 0x0, 0x0, 0x4b, 0x45, 0x58, 0x53, 0x3b, 0x0, 0x0, 0x0, 0x4f, 0x42, 0x49, 0x54, 0x43, 0x0, 0x0, 0x0, 0x45, 0x58, 0x50, 0x59, 0x4b, 0x0, 0x0, 0x0, 0x41, 0x45, 0x53, 0x47})
|
|
||||||
expected.Write(scfg.ID)
|
|
||||||
expected.Write([]byte{0x20, 0x0, 0x0})
|
|
||||||
expected.Write(kex.PublicKey())
|
|
||||||
expected.Write([]byte{0x43, 0x32, 0x35, 0x35})
|
|
||||||
expected.Write(scfg.obit)
|
|
||||||
expected.Write([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
|
|
||||||
Expect(scfg.Get()).To(Equal(expected.Bytes()))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,233 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("TLS Extension Handler, for the client", func() {
|
|
||||||
var (
|
|
||||||
handler *extensionHandlerClient
|
|
||||||
el mint.ExtensionList
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
handler = NewExtensionHandlerClient(&TransportParameters{}, protocol.VersionWhatever, nil, protocol.VersionWhatever, utils.DefaultLogger).(*extensionHandlerClient)
|
|
||||||
el = make(mint.ExtensionList, 0)
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("sending", func() {
|
|
||||||
It("only adds TransportParameters for the ClientHello", func() {
|
|
||||||
// test 2 other handshake types
|
|
||||||
err := handler.Send(mint.HandshakeTypeCertificateRequest, &el)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(el).To(BeEmpty())
|
|
||||||
err = handler.Send(mint.HandshakeTypeEndOfEarlyData, &el)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(el).To(BeEmpty())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adds TransportParameters to the ClientHello", func() {
|
|
||||||
handler.initialVersion = 13
|
|
||||||
err := handler.Send(mint.HandshakeTypeClientHello, &el)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(el).To(HaveLen(1))
|
|
||||||
ext := &tlsExtensionBody{}
|
|
||||||
found, err := el.Find(ext)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(found).To(BeTrue())
|
|
||||||
chtp := &clientHelloTransportParameters{}
|
|
||||||
err = chtp.Unmarshal(ext.data)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(chtp.InitialVersion).To(BeEquivalentTo(13))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("receiving", func() {
|
|
||||||
var fakeBody *tlsExtensionBody
|
|
||||||
var parameters TransportParameters
|
|
||||||
|
|
||||||
addEncryptedExtensionsWithParameters := func(params TransportParameters) {
|
|
||||||
body := (&encryptedExtensionsTransportParameters{
|
|
||||||
Parameters: params,
|
|
||||||
SupportedVersions: []protocol.VersionNumber{handler.version},
|
|
||||||
}).Marshal()
|
|
||||||
Expect(el.Add(&tlsExtensionBody{data: body})).To(Succeed())
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
fakeBody = &tlsExtensionBody{data: []byte("foobar foobar")}
|
|
||||||
parameters = TransportParameters{
|
|
||||||
IdleTimeout: 0x1337 * time.Second,
|
|
||||||
StatelessResetToken: bytes.Repeat([]byte{0}, 16),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("blocks until the transport parameters are read", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
addEncryptedExtensionsWithParameters(parameters)
|
|
||||||
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
|
||||||
Expect(handler.GetPeerParams()).To(Receive())
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("accepts the TransportParameters on the EncryptedExtensions message", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
addEncryptedExtensionsWithParameters(parameters)
|
|
||||||
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
var params TransportParameters
|
|
||||||
Eventually(handler.GetPeerParams()).Should(Receive(¶ms))
|
|
||||||
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the EncryptedExtensions message doesn't contain TransportParameters", func() {
|
|
||||||
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
|
|
||||||
Expect(err).To(MatchError("EncryptedExtensions message didn't contain a QUIC extension"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the TransportParameters on a wrong handshake types", func() {
|
|
||||||
err := el.Add(fakeBody)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.Receive(mint.HandshakeTypeCertificate, &el)
|
|
||||||
Expect(err).To(MatchError(fmt.Sprintf("Unexpected QUIC extension in handshake message %d", mint.HandshakeTypeCertificate)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("ignores messages without TransportParameters, if they are not required", func() {
|
|
||||||
err := handler.Receive(mint.HandshakeTypeCertificate, &el)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when it can't parse the TransportParameters", func() {
|
|
||||||
err := el.Add(fakeBody)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
|
|
||||||
Expect(err).To(HaveOccurred()) // this will be some kind of decoding error
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects TransportParameters if they don't contain the stateless reset token", func() {
|
|
||||||
parameters.StatelessResetToken = nil
|
|
||||||
addEncryptedExtensionsWithParameters(parameters)
|
|
||||||
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
|
|
||||||
Expect(err).To(MatchError("server didn't sent stateless_reset_token"))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Version Negotiation", func() {
|
|
||||||
It("accepts a valid version negotiation", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
Eventually(handler.GetPeerParams()).Should(Receive())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
handler.initialVersion = 13
|
|
||||||
handler.version = 37
|
|
||||||
handler.supportedVersions = []protocol.VersionNumber{13, 37, 42}
|
|
||||||
body := (&encryptedExtensionsTransportParameters{
|
|
||||||
Parameters: parameters,
|
|
||||||
NegotiatedVersion: 37,
|
|
||||||
SupportedVersions: []protocol.VersionNumber{36, 37, 38},
|
|
||||||
}).Marshal()
|
|
||||||
err := el.Add(&tlsExtensionBody{data: body})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the current version doesn't match negotiated_version", func() {
|
|
||||||
handler.initialVersion = 13
|
|
||||||
handler.version = 37
|
|
||||||
handler.supportedVersions = []protocol.VersionNumber{13, 37, 42}
|
|
||||||
body := (&encryptedExtensionsTransportParameters{
|
|
||||||
Parameters: parameters,
|
|
||||||
NegotiatedVersion: 38,
|
|
||||||
SupportedVersions: []protocol.VersionNumber{36, 37, 38},
|
|
||||||
}).Marshal()
|
|
||||||
err := el.Add(&tlsExtensionBody{data: body})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
|
|
||||||
Expect(err).To(MatchError("VersionNegotiationMismatch: current version doesn't match negotiated_version"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the current version is not contained in the server's supported versions", func() {
|
|
||||||
handler.version = 42
|
|
||||||
body := (&encryptedExtensionsTransportParameters{
|
|
||||||
NegotiatedVersion: 42,
|
|
||||||
SupportedVersions: []protocol.VersionNumber{43, 44},
|
|
||||||
}).Marshal()
|
|
||||||
err := el.Add(&tlsExtensionBody{data: body})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
|
|
||||||
Expect(err).To(MatchError("VersionNegotiationMismatch: current version not included in the supported versions"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if version negotiation was performed, but would have picked a different version based on the supported version list", func() {
|
|
||||||
handler.version = 42
|
|
||||||
handler.initialVersion = 41
|
|
||||||
handler.supportedVersions = []protocol.VersionNumber{43, 42, 41}
|
|
||||||
serverSupportedVersions := []protocol.VersionNumber{42, 43}
|
|
||||||
// check that version negotiation would have led us to pick version 43
|
|
||||||
ver, ok := protocol.ChooseSupportedVersion(handler.supportedVersions, serverSupportedVersions)
|
|
||||||
Expect(ok).To(BeTrue())
|
|
||||||
Expect(ver).To(Equal(protocol.VersionNumber(43)))
|
|
||||||
body := (&encryptedExtensionsTransportParameters{
|
|
||||||
NegotiatedVersion: 42,
|
|
||||||
SupportedVersions: serverSupportedVersions,
|
|
||||||
}).Marshal()
|
|
||||||
err := el.Add(&tlsExtensionBody{data: body})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
|
|
||||||
Expect(err).To(MatchError("VersionNegotiationMismatch: would have picked a different version"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't error if it would have picked a different version based on the supported version list, if no version negotiation was performed", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
Eventually(handler.GetPeerParams()).Should(Receive())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
handler.version = 42
|
|
||||||
handler.initialVersion = 42 // version == initialVersion means no version negotiation was performed
|
|
||||||
handler.supportedVersions = []protocol.VersionNumber{43, 42, 41}
|
|
||||||
serverSupportedVersions := []protocol.VersionNumber{42, 43}
|
|
||||||
// check that version negotiation would have led us to pick version 43
|
|
||||||
ver, ok := protocol.ChooseSupportedVersion(handler.supportedVersions, serverSupportedVersions)
|
|
||||||
Expect(ok).To(BeTrue())
|
|
||||||
Expect(ver).To(Equal(protocol.VersionNumber(43)))
|
|
||||||
body := (&encryptedExtensionsTransportParameters{
|
|
||||||
Parameters: parameters,
|
|
||||||
NegotiatedVersion: 42,
|
|
||||||
SupportedVersions: serverSupportedVersions,
|
|
||||||
}).Marshal()
|
|
||||||
err := el.Add(&tlsExtensionBody{data: body})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,155 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("TLS Extension Handler, for the server", func() {
|
|
||||||
var (
|
|
||||||
handler *extensionHandlerServer
|
|
||||||
el mint.ExtensionList
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
handler = NewExtensionHandlerServer(&TransportParameters{}, nil, protocol.VersionWhatever, utils.DefaultLogger).(*extensionHandlerServer)
|
|
||||||
el = make(mint.ExtensionList, 0)
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("sending", func() {
|
|
||||||
It("only adds TransportParameters for the ClientHello", func() {
|
|
||||||
// test 2 other handshake types
|
|
||||||
err := handler.Send(mint.HandshakeTypeCertificateRequest, &el)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(el).To(BeEmpty())
|
|
||||||
err = handler.Send(mint.HandshakeTypeEndOfEarlyData, &el)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(el).To(BeEmpty())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adds TransportParameters to the EncryptedExtensions message", func() {
|
|
||||||
handler.version = 666
|
|
||||||
versions := []protocol.VersionNumber{13, 37, 42}
|
|
||||||
handler.supportedVersions = versions
|
|
||||||
err := handler.Send(mint.HandshakeTypeEncryptedExtensions, &el)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(el).To(HaveLen(1))
|
|
||||||
ext := &tlsExtensionBody{}
|
|
||||||
found, err := el.Find(ext)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(found).To(BeTrue())
|
|
||||||
eetp := &encryptedExtensionsTransportParameters{}
|
|
||||||
err = eetp.Unmarshal(ext.data)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(eetp.NegotiatedVersion).To(BeEquivalentTo(666))
|
|
||||||
// the SupportedVersions will contain one reserved version number
|
|
||||||
Expect(eetp.SupportedVersions).To(HaveLen(len(versions) + 1))
|
|
||||||
for _, version := range versions {
|
|
||||||
Expect(eetp.SupportedVersions).To(ContainElement(version))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("receiving", func() {
|
|
||||||
var (
|
|
||||||
fakeBody *tlsExtensionBody
|
|
||||||
parameters TransportParameters
|
|
||||||
)
|
|
||||||
|
|
||||||
addClientHelloWithParameters := func(params TransportParameters) {
|
|
||||||
body := (&clientHelloTransportParameters{Parameters: params}).Marshal()
|
|
||||||
Expect(el.Add(&tlsExtensionBody{data: body})).To(Succeed())
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
fakeBody = &tlsExtensionBody{data: []byte("foobar foobar")}
|
|
||||||
parameters = TransportParameters{IdleTimeout: 0x1337 * time.Second}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("accepts the TransportParameters on the EncryptedExtensions message", func() {
|
|
||||||
addClientHelloWithParameters(parameters)
|
|
||||||
err := handler.Receive(mint.HandshakeTypeClientHello, &el)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
var params TransportParameters
|
|
||||||
Expect(handler.GetPeerParams()).To(Receive(¶ms))
|
|
||||||
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the ClientHello doesn't contain TransportParameters", func() {
|
|
||||||
err := handler.Receive(mint.HandshakeTypeClientHello, &el)
|
|
||||||
Expect(err).To(MatchError("ClientHello didn't contain a QUIC extension"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("ignores messages without TransportParameters, if they are not required", func() {
|
|
||||||
err := handler.Receive(mint.HandshakeTypeCertificate, &el)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if it can't unmarshal the TransportParameters", func() {
|
|
||||||
err := el.Add(fakeBody)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.Receive(mint.HandshakeTypeClientHello, &el)
|
|
||||||
Expect(err).To(HaveOccurred()) // this will be some kind of decoding error
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects messages other than the ClientHello that contain TransportParameters", func() {
|
|
||||||
addClientHelloWithParameters(parameters)
|
|
||||||
err := handler.Receive(mint.HandshakeTypeCertificateRequest, &el)
|
|
||||||
Expect(err).To(MatchError(fmt.Sprintf("Unexpected QUIC extension in handshake message %d", mint.HandshakeTypeCertificateRequest)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects messages that contain a stateless reset token", func() {
|
|
||||||
parameters.StatelessResetToken = bytes.Repeat([]byte{0}, 16)
|
|
||||||
addClientHelloWithParameters(parameters)
|
|
||||||
err := handler.Receive(mint.HandshakeTypeClientHello, &el)
|
|
||||||
Expect(err).To(MatchError("client sent a stateless reset token"))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Version Negotiation", func() {
|
|
||||||
It("accepts a ClientHello, when no version negotiation was performed", func() {
|
|
||||||
handler.version = 42
|
|
||||||
body := (&clientHelloTransportParameters{
|
|
||||||
InitialVersion: 42,
|
|
||||||
Parameters: parameters,
|
|
||||||
}).Marshal()
|
|
||||||
err := el.Add(&tlsExtensionBody{data: body})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.Receive(mint.HandshakeTypeClientHello, &el)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("accepts a valid version negotiation", func() {
|
|
||||||
handler.version = 42
|
|
||||||
handler.supportedVersions = []protocol.VersionNumber{13, 37, 42}
|
|
||||||
body := (&clientHelloTransportParameters{
|
|
||||||
InitialVersion: 22, // this must be an unsupported version
|
|
||||||
Parameters: parameters,
|
|
||||||
}).Marshal()
|
|
||||||
err := el.Add(&tlsExtensionBody{data: body})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.Receive(mint.HandshakeTypeClientHello, &el)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("erros when a version negotiation was performed, although we already support the initial version", func() {
|
|
||||||
handler.supportedVersions = []protocol.VersionNumber{11, 12, 13}
|
|
||||||
handler.version = 13
|
|
||||||
body := (&clientHelloTransportParameters{
|
|
||||||
InitialVersion: 11, // this is an supported version
|
|
||||||
}).Marshal()
|
|
||||||
err := el.Add(&tlsExtensionBody{data: body})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = handler.Receive(mint.HandshakeTypeClientHello, &el)
|
|
||||||
Expect(err).To(MatchError("VersionNegotiationMismatch: Client should have used the initial version"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,95 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math/rand"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("TLS extension body", func() {
|
|
||||||
Context("Client Hello Transport Parameters", func() {
|
|
||||||
It("marshals and unmarshals", func() {
|
|
||||||
chtp := &clientHelloTransportParameters{
|
|
||||||
InitialVersion: 0x123456,
|
|
||||||
Parameters: TransportParameters{
|
|
||||||
StreamFlowControlWindow: 0x42,
|
|
||||||
IdleTimeout: 0x1337 * time.Second,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
chtp2 := &clientHelloTransportParameters{}
|
|
||||||
Expect(chtp2.Unmarshal(chtp.Marshal())).To(Succeed())
|
|
||||||
Expect(chtp2.InitialVersion).To(Equal(chtp.InitialVersion))
|
|
||||||
Expect(chtp2.Parameters.StreamFlowControlWindow).To(Equal(chtp.Parameters.StreamFlowControlWindow))
|
|
||||||
Expect(chtp2.Parameters.IdleTimeout).To(Equal(chtp.Parameters.IdleTimeout))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fuzzes", func() {
|
|
||||||
rand := rand.New(rand.NewSource(GinkgoRandomSeed()))
|
|
||||||
b := make([]byte, 100)
|
|
||||||
for i := 0; i < 1000; i++ {
|
|
||||||
rand.Read(b)
|
|
||||||
chtp := &clientHelloTransportParameters{}
|
|
||||||
chtp.Unmarshal(b[:int(rand.Int31n(100))])
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Encrypted Extensions Transport Parameters", func() {
|
|
||||||
It("marshals and unmarshals", func() {
|
|
||||||
eetp := &encryptedExtensionsTransportParameters{
|
|
||||||
NegotiatedVersion: 0x123456,
|
|
||||||
SupportedVersions: []protocol.VersionNumber{0x42, 0x4242},
|
|
||||||
Parameters: TransportParameters{
|
|
||||||
StreamFlowControlWindow: 0x42,
|
|
||||||
IdleTimeout: 0x1337 * time.Second,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
eetp2 := &encryptedExtensionsTransportParameters{}
|
|
||||||
Expect(eetp2.Unmarshal(eetp.Marshal())).To(Succeed())
|
|
||||||
Expect(eetp2.NegotiatedVersion).To(Equal(eetp.NegotiatedVersion))
|
|
||||||
Expect(eetp2.SupportedVersions).To(Equal(eetp.SupportedVersions))
|
|
||||||
Expect(eetp2.Parameters.StreamFlowControlWindow).To(Equal(eetp.Parameters.StreamFlowControlWindow))
|
|
||||||
Expect(eetp2.Parameters.IdleTimeout).To(Equal(eetp.Parameters.IdleTimeout))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fuzzes", func() {
|
|
||||||
rand := rand.New(rand.NewSource(GinkgoRandomSeed()))
|
|
||||||
b := make([]byte, 100)
|
|
||||||
for i := 0; i < 1000; i++ {
|
|
||||||
rand.Read(b)
|
|
||||||
chtp := &encryptedExtensionsTransportParameters{}
|
|
||||||
chtp.Unmarshal(b[:int(rand.Int31n(100))])
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("TLS Extension Body", func() {
|
|
||||||
var extBody *tlsExtensionBody
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
extBody = &tlsExtensionBody{}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("has the right TLS extension type", func() {
|
|
||||||
Expect(extBody.Type()).To(BeEquivalentTo(quicTLSExtensionType))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("saves the body when unmarshalling", func() {
|
|
||||||
n, err := extBody.Unmarshal([]byte("foobar"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(Equal(6))
|
|
||||||
Expect(extBody.data).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns the body when marshalling", func() {
|
|
||||||
extBody.data = []byte("foo")
|
|
||||||
data, err := extBody.Marshal()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(data).To(Equal([]byte("foo")))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,265 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Transport Parameters", func() {
|
|
||||||
Context("for gQUIC", func() {
|
|
||||||
Context("parsing", func() {
|
|
||||||
It("sets all values", func() {
|
|
||||||
values := map[Tag][]byte{
|
|
||||||
TagSFCW: {0xad, 0xfb, 0xca, 0xde},
|
|
||||||
TagCFCW: {0xef, 0xbe, 0xad, 0xde},
|
|
||||||
TagICSL: {0x0d, 0xf0, 0xad, 0xba},
|
|
||||||
TagMIDS: {0xff, 0x10, 0x00, 0xc0},
|
|
||||||
}
|
|
||||||
params, err := readHelloMap(values)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xdecafbad)))
|
|
||||||
Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0xdeadbeef)))
|
|
||||||
Expect(params.IdleTimeout).To(Equal(time.Duration(0xbaadf00d) * time.Second))
|
|
||||||
Expect(params.MaxStreams).To(Equal(uint32(0xc00010ff)))
|
|
||||||
Expect(params.OmitConnectionID).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads if the connection ID should be omitted", func() {
|
|
||||||
values := map[Tag][]byte{TagTCID: {0, 0, 0, 0}}
|
|
||||||
params, err := readHelloMap(values)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(params.OmitConnectionID).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't allow idle timeouts below the minimum remote idle timeout", func() {
|
|
||||||
t := 2 * time.Second
|
|
||||||
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
|
|
||||||
values := map[Tag][]byte{
|
|
||||||
TagICSL: {uint8(t.Seconds()), 0, 0, 0},
|
|
||||||
}
|
|
||||||
params, err := readHelloMap(values)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when given an invalid SFCW value", func() {
|
|
||||||
values := map[Tag][]byte{TagSFCW: {2, 0, 0}} // 1 byte too short
|
|
||||||
_, err := readHelloMap(values)
|
|
||||||
Expect(err).To(MatchError(errMalformedTag))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when given an invalid CFCW value", func() {
|
|
||||||
values := map[Tag][]byte{TagCFCW: {2, 0, 0}} // 1 byte too short
|
|
||||||
_, err := readHelloMap(values)
|
|
||||||
Expect(err).To(MatchError(errMalformedTag))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when given an invalid TCID value", func() {
|
|
||||||
values := map[Tag][]byte{TagTCID: {2, 0, 0}} // 1 byte too short
|
|
||||||
_, err := readHelloMap(values)
|
|
||||||
Expect(err).To(MatchError(errMalformedTag))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when given an invalid ICSL value", func() {
|
|
||||||
values := map[Tag][]byte{TagICSL: {2, 0, 0}} // 1 byte too short
|
|
||||||
_, err := readHelloMap(values)
|
|
||||||
Expect(err).To(MatchError(errMalformedTag))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when given an invalid MIDS value", func() {
|
|
||||||
values := map[Tag][]byte{TagMIDS: {2, 0, 0}} // 1 byte too short
|
|
||||||
_, err := readHelloMap(values)
|
|
||||||
Expect(err).To(MatchError(errMalformedTag))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("writing", func() {
|
|
||||||
It("returns all necessary parameters ", func() {
|
|
||||||
params := &TransportParameters{
|
|
||||||
StreamFlowControlWindow: 0xdeadbeef,
|
|
||||||
ConnectionFlowControlWindow: 0xdecafbad,
|
|
||||||
IdleTimeout: 0xbaaaaaad * time.Second,
|
|
||||||
MaxStreams: 0x1337,
|
|
||||||
}
|
|
||||||
entryMap := params.getHelloMap()
|
|
||||||
Expect(entryMap).To(HaveLen(4))
|
|
||||||
Expect(entryMap).ToNot(HaveKey(TagTCID))
|
|
||||||
Expect(entryMap).To(HaveKeyWithValue(TagSFCW, []byte{0xef, 0xbe, 0xad, 0xde}))
|
|
||||||
Expect(entryMap).To(HaveKeyWithValue(TagCFCW, []byte{0xad, 0xfb, 0xca, 0xde}))
|
|
||||||
Expect(entryMap).To(HaveKeyWithValue(TagICSL, []byte{0xad, 0xaa, 0xaa, 0xba}))
|
|
||||||
Expect(entryMap).To(HaveKeyWithValue(TagMIDS, []byte{0x37, 0x13, 0, 0}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("requests omission of the connection ID", func() {
|
|
||||||
params := &TransportParameters{OmitConnectionID: true}
|
|
||||||
entryMap := params.getHelloMap()
|
|
||||||
Expect(entryMap).To(HaveKeyWithValue(TagTCID, []byte{0, 0, 0, 0}))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("for TLS", func() {
|
|
||||||
It("has a string representation", func() {
|
|
||||||
p := &TransportParameters{
|
|
||||||
StreamFlowControlWindow: 0x1234,
|
|
||||||
ConnectionFlowControlWindow: 0x4321,
|
|
||||||
MaxBidiStreams: 1337,
|
|
||||||
MaxUniStreams: 7331,
|
|
||||||
IdleTimeout: 42 * time.Second,
|
|
||||||
}
|
|
||||||
Expect(p.String()).To(Equal("&handshake.TransportParameters{StreamFlowControlWindow: 0x1234, ConnectionFlowControlWindow: 0x4321, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s}"))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("parsing", func() {
|
|
||||||
var (
|
|
||||||
params *TransportParameters
|
|
||||||
parameters map[transportParameterID][]byte
|
|
||||||
statelessResetToken []byte
|
|
||||||
)
|
|
||||||
|
|
||||||
marshal := func(p map[transportParameterID][]byte) []byte {
|
|
||||||
b := &bytes.Buffer{}
|
|
||||||
for id, val := range p {
|
|
||||||
utils.BigEndian.WriteUint16(b, uint16(id))
|
|
||||||
utils.BigEndian.WriteUint16(b, uint16(len(val)))
|
|
||||||
b.Write(val)
|
|
||||||
}
|
|
||||||
return b.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
params = &TransportParameters{}
|
|
||||||
statelessResetToken = bytes.Repeat([]byte{42}, 16)
|
|
||||||
parameters = map[transportParameterID][]byte{
|
|
||||||
initialMaxStreamDataParameterID: {0x11, 0x22, 0x33, 0x44},
|
|
||||||
initialMaxDataParameterID: {0x22, 0x33, 0x44, 0x55},
|
|
||||||
initialMaxBidiStreamsParameterID: {0x33, 0x44},
|
|
||||||
initialMaxUniStreamsParameterID: {0x44, 0x55},
|
|
||||||
idleTimeoutParameterID: {0x13, 0x37},
|
|
||||||
maxPacketSizeParameterID: {0x73, 0x31},
|
|
||||||
disableMigrationParameterID: {},
|
|
||||||
statelessResetTokenParameterID: statelessResetToken,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
It("reads parameters", func() {
|
|
||||||
err := params.unmarshal(marshal(parameters))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0x11223344)))
|
|
||||||
Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0x22334455)))
|
|
||||||
Expect(params.MaxBidiStreams).To(Equal(uint16(0x3344)))
|
|
||||||
Expect(params.MaxUniStreams).To(Equal(uint16(0x4455)))
|
|
||||||
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
|
|
||||||
Expect(params.OmitConnectionID).To(BeFalse())
|
|
||||||
Expect(params.MaxPacketSize).To(Equal(protocol.ByteCount(0x7331)))
|
|
||||||
Expect(params.DisableMigration).To(BeTrue())
|
|
||||||
Expect(params.StatelessResetToken).To(Equal(statelessResetToken))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if the idle_timeout is missing", func() {
|
|
||||||
delete(parameters, idleTimeoutParameterID)
|
|
||||||
err := params.unmarshal(marshal(parameters))
|
|
||||||
Expect(err).To(MatchError("missing parameter"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't allow values below the minimum remote idle timeout", func() {
|
|
||||||
t := 2 * time.Second
|
|
||||||
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
|
|
||||||
parameters[idleTimeoutParameterID] = []byte{0, uint8(t.Seconds())}
|
|
||||||
err := params.unmarshal(marshal(parameters))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if the initial_max_stream_data has the wrong length", func() {
|
|
||||||
parameters[initialMaxStreamDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
|
|
||||||
err := params.unmarshal(marshal(parameters))
|
|
||||||
Expect(err).To(MatchError("wrong length for initial_max_stream_data: 3 (expected 4)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if the initial_max_data has the wrong length", func() {
|
|
||||||
parameters[initialMaxDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
|
|
||||||
err := params.unmarshal(marshal(parameters))
|
|
||||||
Expect(err).To(MatchError("wrong length for initial_max_data: 3 (expected 4)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if the initial_max_stream_id_bidi has the wrong length", func() {
|
|
||||||
parameters[initialMaxBidiStreamsParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
|
|
||||||
err := params.unmarshal(marshal(parameters))
|
|
||||||
Expect(err).To(MatchError("wrong length for initial_max_stream_id_bidi: 3 (expected 2)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if the initial_max_stream_id_bidi has the wrong length", func() {
|
|
||||||
parameters[initialMaxUniStreamsParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
|
|
||||||
err := params.unmarshal(marshal(parameters))
|
|
||||||
Expect(err).To(MatchError("wrong length for initial_max_stream_id_uni: 3 (expected 2)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if the initial_idle_timeout has the wrong length", func() {
|
|
||||||
parameters[idleTimeoutParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
|
|
||||||
err := params.unmarshal(marshal(parameters))
|
|
||||||
Expect(err).To(MatchError("wrong length for idle_timeout: 3 (expected 2)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if max_packet_size has the wrong length", func() {
|
|
||||||
parameters[maxPacketSizeParameterID] = []byte{0x11} // should be 2 bytes
|
|
||||||
err := params.unmarshal(marshal(parameters))
|
|
||||||
Expect(err).To(MatchError("wrong length for max_packet_size: 1 (expected 2)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects max_packet_sizes smaller than 1200 bytes", func() {
|
|
||||||
parameters[maxPacketSizeParameterID] = []byte{0x4, 0xaf} // 0x4af = 1199
|
|
||||||
err := params.unmarshal(marshal(parameters))
|
|
||||||
Expect(err).To(MatchError("invalid value for max_packet_size: 1199 (minimum 1200)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if disable_connection_migration has the wrong length", func() {
|
|
||||||
parameters[disableMigrationParameterID] = []byte{0x11} // should empty
|
|
||||||
err := params.unmarshal(marshal(parameters))
|
|
||||||
Expect(err).To(MatchError("wrong length for disable_migration: 1 (expected empty)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if the stateless_reset_token has the wrong length", func() {
|
|
||||||
parameters[statelessResetTokenParameterID] = statelessResetToken[1:]
|
|
||||||
err := params.unmarshal(marshal(parameters))
|
|
||||||
Expect(err).To(MatchError("wrong length for stateless_reset_token: 15 (expected 16)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("ignores unknown parameters", func() {
|
|
||||||
parameters[1337] = []byte{42}
|
|
||||||
err := params.unmarshal(marshal(parameters))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("marshalling", func() {
|
|
||||||
It("marshals", func() {
|
|
||||||
params := &TransportParameters{
|
|
||||||
StreamFlowControlWindow: 0xdeadbeef,
|
|
||||||
ConnectionFlowControlWindow: 0xdecafbad,
|
|
||||||
IdleTimeout: 0xcafe * time.Second,
|
|
||||||
MaxBidiStreams: 0x1234,
|
|
||||||
MaxUniStreams: 0x4321,
|
|
||||||
DisableMigration: true,
|
|
||||||
StatelessResetToken: bytes.Repeat([]byte{100}, 16),
|
|
||||||
}
|
|
||||||
b := &bytes.Buffer{}
|
|
||||||
params.marshal(b)
|
|
||||||
|
|
||||||
p := &TransportParameters{}
|
|
||||||
Expect(p.unmarshal(b.Bytes())).To(Succeed())
|
|
||||||
Expect(p.StreamFlowControlWindow).To(Equal(params.StreamFlowControlWindow))
|
|
||||||
Expect(p.ConnectionFlowControlWindow).To(Equal(params.ConnectionFlowControlWindow))
|
|
||||||
Expect(p.MaxUniStreams).To(Equal(params.MaxUniStreams))
|
|
||||||
Expect(p.MaxBidiStreams).To(Equal(params.MaxBidiStreams))
|
|
||||||
Expect(p.IdleTimeout).To(Equal(params.IdleTimeout))
|
|
||||||
Expect(p.DisableMigration).To(Equal(params.DisableMigration))
|
|
||||||
Expect(p.StatelessResetToken).To(Equal(params.StatelessResetToken))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,108 +0,0 @@
|
|||||||
package protocol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Connection ID generation", func() {
|
|
||||||
It("generates random connection IDs", func() {
|
|
||||||
c1, err := GenerateConnectionID(8)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(c1).ToNot(BeZero())
|
|
||||||
c2, err := GenerateConnectionID(8)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(c1).ToNot(Equal(c2))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("generates connection IDs with the requested length", func() {
|
|
||||||
c, err := GenerateConnectionID(5)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(c.Len()).To(Equal(5))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("generates random length destination connection IDs", func() {
|
|
||||||
var has8ByteConnID, has18ByteConnID bool
|
|
||||||
for i := 0; i < 1000; i++ {
|
|
||||||
c, err := GenerateConnectionIDForInitial()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(c.Len()).To(BeNumerically(">=", 8))
|
|
||||||
Expect(c.Len()).To(BeNumerically("<=", 18))
|
|
||||||
if c.Len() == 8 {
|
|
||||||
has8ByteConnID = true
|
|
||||||
}
|
|
||||||
if c.Len() == 18 {
|
|
||||||
has18ByteConnID = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Expect(has8ByteConnID).To(BeTrue())
|
|
||||||
Expect(has18ByteConnID).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("says if connection IDs are equal", func() {
|
|
||||||
c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
|
||||||
c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
|
||||||
Expect(c1.Equal(c1)).To(BeTrue())
|
|
||||||
Expect(c2.Equal(c2)).To(BeTrue())
|
|
||||||
Expect(c1.Equal(c2)).To(BeFalse())
|
|
||||||
Expect(c2.Equal(c1)).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads the connection ID", func() {
|
|
||||||
buf := bytes.NewBuffer([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
|
||||||
c, err := ReadConnectionID(buf, 9)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns io.EOF if there's not enough data to read", func() {
|
|
||||||
buf := bytes.NewBuffer([]byte{1, 2, 3, 4})
|
|
||||||
_, err := ReadConnectionID(buf, 5)
|
|
||||||
Expect(err).To(MatchError(io.EOF))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns nil for a 0 length connection ID", func() {
|
|
||||||
buf := bytes.NewBuffer([]byte{1, 2, 3, 4})
|
|
||||||
c, err := ReadConnectionID(buf, 0)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(c).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns the length", func() {
|
|
||||||
c := ConnectionID{1, 2, 3, 4, 5, 6, 7}
|
|
||||||
Expect(c.Len()).To(Equal(7))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("has 0 length for the default value", func() {
|
|
||||||
var c ConnectionID
|
|
||||||
Expect(c.Len()).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns the bytes", func() {
|
|
||||||
c := ConnectionID([]byte{1, 2, 3, 4, 5, 6, 7})
|
|
||||||
Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns a nil byte slice for the default value", func() {
|
|
||||||
var c ConnectionID
|
|
||||||
Expect(c.Bytes()).To(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("has a string representation", func() {
|
|
||||||
c := ConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42})
|
|
||||||
Expect(c.String()).To(Equal("0xdeadbeef42"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("has a long string representation", func() {
|
|
||||||
c := ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad}
|
|
||||||
Expect(c.String()).To(Equal("0x13370000decafbad"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("has a string representation for the default value", func() {
|
|
||||||
var c ConnectionID
|
|
||||||
Expect(c.String()).To(Equal("(empty)"))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,15 +0,0 @@
|
|||||||
package protocol
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Encryption Level", func() {
|
|
||||||
It("has the correct string representation", func() {
|
|
||||||
Expect(EncryptionUnspecified.String()).To(Equal("unknown"))
|
|
||||||
Expect(EncryptionUnencrypted.String()).To(Equal("unencrypted"))
|
|
||||||
Expect(EncryptionSecure.String()).To(Equal("encrypted (not forward-secure)"))
|
|
||||||
Expect(EncryptionForwardSecure.String()).To(Equal("forward-secure"))
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,244 +0,0 @@
|
|||||||
package protocol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Tests taken and extended from chrome
|
|
||||||
var _ = Describe("packet number calculation", func() {
|
|
||||||
Context("infering a packet number", func() {
|
|
||||||
getEpoch := func(len PacketNumberLen, v VersionNumber) uint64 {
|
|
||||||
if v.UsesVarintPacketNumbers() {
|
|
||||||
switch len {
|
|
||||||
case PacketNumberLen1:
|
|
||||||
return uint64(1) << 7
|
|
||||||
case PacketNumberLen2:
|
|
||||||
return uint64(1) << 14
|
|
||||||
case PacketNumberLen4:
|
|
||||||
return uint64(1) << 30
|
|
||||||
default:
|
|
||||||
Fail("invalid packet number len")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return uint64(1) << (len * 8)
|
|
||||||
}
|
|
||||||
check := func(length PacketNumberLen, expected, last uint64, v VersionNumber) {
|
|
||||||
epoch := getEpoch(length, v)
|
|
||||||
epochMask := epoch - 1
|
|
||||||
wirePacketNumber := expected & epochMask
|
|
||||||
Expect(InferPacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber), v)).To(Equal(PacketNumber(expected)))
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, v := range []VersionNumber{Version39, VersionTLS} {
|
|
||||||
version := v
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("using varint packet numbers: %t", version.UsesVarintPacketNumbers()), func() {
|
|
||||||
for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4} {
|
|
||||||
length := l
|
|
||||||
|
|
||||||
Context(fmt.Sprintf("with %d bytes", length), func() {
|
|
||||||
epoch := getEpoch(length, version)
|
|
||||||
epochMask := epoch - 1
|
|
||||||
|
|
||||||
It("works near epoch start", func() {
|
|
||||||
// A few quick manual sanity check
|
|
||||||
check(length, 1, 0, version)
|
|
||||||
check(length, epoch+1, epochMask, version)
|
|
||||||
check(length, epoch, epochMask, version)
|
|
||||||
|
|
||||||
// Cases where the last number was close to the start of the range.
|
|
||||||
for last := uint64(0); last < 10; last++ {
|
|
||||||
// Small numbers should not wrap (even if they're out of order).
|
|
||||||
for j := uint64(0); j < 10; j++ {
|
|
||||||
check(length, j, last, version)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Large numbers should not wrap either (because we're near 0 already).
|
|
||||||
for j := uint64(0); j < 10; j++ {
|
|
||||||
check(length, epoch-1-j, last, version)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works near epoch end", func() {
|
|
||||||
// Cases where the last number was close to the end of the range
|
|
||||||
for i := uint64(0); i < 10; i++ {
|
|
||||||
last := epoch - i
|
|
||||||
|
|
||||||
// Small numbers should wrap.
|
|
||||||
for j := uint64(0); j < 10; j++ {
|
|
||||||
check(length, epoch+j, last, version)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Large numbers should not (even if they're out of order).
|
|
||||||
for j := uint64(0); j < 10; j++ {
|
|
||||||
check(length, epoch-1-j, last, version)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Next check where we're in a non-zero epoch to verify we handle
|
|
||||||
// reverse wrapping, too.
|
|
||||||
It("works near previous epoch", func() {
|
|
||||||
prevEpoch := 1 * epoch
|
|
||||||
curEpoch := 2 * epoch
|
|
||||||
// Cases where the last number was close to the start of the range
|
|
||||||
for i := uint64(0); i < 10; i++ {
|
|
||||||
last := curEpoch + i
|
|
||||||
// Small number should not wrap (even if they're out of order).
|
|
||||||
for j := uint64(0); j < 10; j++ {
|
|
||||||
check(length, curEpoch+j, last, version)
|
|
||||||
}
|
|
||||||
|
|
||||||
// But large numbers should reverse wrap.
|
|
||||||
for j := uint64(0); j < 10; j++ {
|
|
||||||
num := epoch - 1 - j
|
|
||||||
check(length, prevEpoch+num, last, version)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works near next epoch", func() {
|
|
||||||
curEpoch := 2 * epoch
|
|
||||||
nextEpoch := 3 * epoch
|
|
||||||
// Cases where the last number was close to the end of the range
|
|
||||||
for i := uint64(0); i < 10; i++ {
|
|
||||||
last := nextEpoch - 1 - i
|
|
||||||
|
|
||||||
// Small numbers should wrap.
|
|
||||||
for j := uint64(0); j < 10; j++ {
|
|
||||||
check(length, nextEpoch+j, last, version)
|
|
||||||
}
|
|
||||||
|
|
||||||
// but large numbers should not (even if they're out of order).
|
|
||||||
for j := uint64(0); j < 10; j++ {
|
|
||||||
num := epoch - 1 - j
|
|
||||||
check(length, curEpoch+num, last, version)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works near next max", func() {
|
|
||||||
maxNumber := uint64(math.MaxUint64)
|
|
||||||
maxEpoch := maxNumber & ^epochMask
|
|
||||||
|
|
||||||
// Cases where the last number was close to the end of the range
|
|
||||||
for i := uint64(0); i < 10; i++ {
|
|
||||||
// Subtract 1, because the expected next packet number is 1 more than the
|
|
||||||
// last packet number.
|
|
||||||
last := maxNumber - i - 1
|
|
||||||
|
|
||||||
// Small numbers should not wrap, because they have nowhere to go.
|
|
||||||
for j := uint64(0); j < 10; j++ {
|
|
||||||
check(length, maxEpoch+j, last, version)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Large numbers should not wrap either.
|
|
||||||
for j := uint64(0); j < 10; j++ {
|
|
||||||
num := epoch - 1 - j
|
|
||||||
check(length, maxEpoch+num, last, version)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
Context("shortening a packet number for the header", func() {
|
|
||||||
Context("shortening", func() {
|
|
||||||
It("sends out low packet numbers as 2 byte", func() {
|
|
||||||
length := GetPacketNumberLengthForHeader(4, 2, version)
|
|
||||||
Expect(length).To(Equal(PacketNumberLen2))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sends out high packet numbers as 2 byte, if all ACKs are received", func() {
|
|
||||||
length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1, version)
|
|
||||||
Expect(length).To(Equal(PacketNumberLen2))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() {
|
|
||||||
length := GetPacketNumberLengthForHeader(40000, 2, version)
|
|
||||||
Expect(length).To(Equal(PacketNumberLen4))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("self-consistency", func() {
|
|
||||||
It("works for small packet numbers", func() {
|
|
||||||
for i := uint64(1); i < 10000; i++ {
|
|
||||||
packetNumber := PacketNumber(i)
|
|
||||||
leastUnacked := PacketNumber(1)
|
|
||||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
|
|
||||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
|
||||||
|
|
||||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
|
|
||||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works for small packet numbers and increasing ACKed packets", func() {
|
|
||||||
for i := uint64(1); i < 10000; i++ {
|
|
||||||
packetNumber := PacketNumber(i)
|
|
||||||
leastUnacked := PacketNumber(i / 2)
|
|
||||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
|
|
||||||
epochMask := getEpoch(length, version) - 1
|
|
||||||
wirePacketNumber := uint64(packetNumber) & epochMask
|
|
||||||
|
|
||||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
|
|
||||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("also works for larger packet numbers", func() {
|
|
||||||
var increment uint64
|
|
||||||
for i := uint64(1); i < getEpoch(PacketNumberLen4, version); i += increment {
|
|
||||||
packetNumber := PacketNumber(i)
|
|
||||||
leastUnacked := PacketNumber(1)
|
|
||||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
|
|
||||||
epochMask := getEpoch(length, version) - 1
|
|
||||||
wirePacketNumber := uint64(packetNumber) & epochMask
|
|
||||||
|
|
||||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
|
|
||||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
|
||||||
|
|
||||||
increment = getEpoch(length, version) / 8
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works for packet numbers larger than 2^48", func() {
|
|
||||||
for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) {
|
|
||||||
packetNumber := PacketNumber(i)
|
|
||||||
leastUnacked := PacketNumber(i - 1000)
|
|
||||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
|
|
||||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
|
||||||
|
|
||||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
|
|
||||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("determining the minimum length of a packet number", func() {
|
|
||||||
It("1 byte", func() {
|
|
||||||
Expect(GetPacketNumberLength(0xFF)).To(Equal(PacketNumberLen1))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("2 byte", func() {
|
|
||||||
Expect(GetPacketNumberLength(0xFFFF)).To(Equal(PacketNumberLen2))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("4 byte", func() {
|
|
||||||
Expect(GetPacketNumberLength(0xFFFFFFFF)).To(Equal(PacketNumberLen4))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("6 byte", func() {
|
|
||||||
Expect(GetPacketNumberLength(0xFFFFFFFFFFFF)).To(Equal(PacketNumberLen6))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
@ -1,19 +0,0 @@
|
|||||||
package protocol
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Perspective", func() {
|
|
||||||
It("has a string representation", func() {
|
|
||||||
Expect(PerspectiveClient.String()).To(Equal("Client"))
|
|
||||||
Expect(PerspectiveServer.String()).To(Equal("Server"))
|
|
||||||
Expect(Perspective(0).String()).To(Equal("invalid perspective"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns the opposite", func() {
|
|
||||||
Expect(PerspectiveClient.Opposite()).To(Equal(PerspectiveServer))
|
|
||||||
Expect(PerspectiveServer.Opposite()).To(Equal(PerspectiveClient))
|
|
||||||
})
|
|
||||||
})
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user