1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-02 23:47:07 -05:00

remove all vendor tests

This commit is contained in:
Darien Raymond 2018-11-20 23:59:01 +01:00
parent 84f8bca01c
commit 786290a31d
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
183 changed files with 0 additions and 33311 deletions

View File

@ -1,26 +0,0 @@
package benchmark
import (
"flag"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestBenchmark(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Benchmark Suite")
}
var (
size int // file size in MB, will be read from flags
samples int // number of samples for Measure, will be read from flags
)
func init() {
flag.IntVar(&size, "size", 50, "data length (in MB)")
flag.IntVar(&samples, "samples", 6, "number of samples")
flag.Parse()
}

View File

@ -1,87 +0,0 @@
package benchmark
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"math/rand"
"net"
quic "github.com/lucas-clemente/quic-go"
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
func init() {
var _ = Describe("Benchmarks", func() {
dataLen := size * /* MB */ 1e6
data := make([]byte, dataLen)
rand.Seed(GinkgoRandomSeed())
rand.Read(data) // no need to check for an error. math.Rand.Read never errors
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
Context(fmt.Sprintf("with version %s", version), func() {
Measure(fmt.Sprintf("transferring a %d MB file", size), func(b Benchmarker) {
var ln quic.Listener
serverAddr := make(chan net.Addr)
handshakeChan := make(chan struct{})
// start the server
go func() {
defer GinkgoRecover()
var err error
ln, err = quic.ListenAddr(
"localhost:0",
testdata.GetTLSConfig(),
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
serverAddr <- ln.Addr()
sess, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
// wait for the client to complete the handshake before sending the data
// this should not be necessary, but due to timing issues on the CIs, this is necessary to avoid sending too many undecryptable packets
<-handshakeChan
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write(data)
Expect(err).ToNot(HaveOccurred())
err = str.Close()
Expect(err).ToNot(HaveOccurred())
}()
// start the client
addr := <-serverAddr
sess, err := quic.DialAddr(
addr.String(),
&tls.Config{InsecureSkipVerify: true},
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
close(handshakeChan)
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
buf := &bytes.Buffer{}
// measure the time it takes to download the dataLen bytes
// note we're measuring the time for the transfer, i.e. excluding the handshake
runtime := b.Time("transfer time", func() {
_, err := io.Copy(buf, str)
Expect(err).NotTo(HaveOccurred())
})
Expect(buf.Bytes()).To(Equal(data))
b.RecordValue("transfer rate [MB/s]", float64(dataLen)/1e6/runtime.Seconds())
ln.Close()
sess.Close()
}, samples)
})
}
})
}

View File

@ -1,21 +0,0 @@
package quic
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Buffer Pool", func() {
It("returns buffers of cap", func() {
buf := *getPacketBuffer()
Expect(buf).To(HaveCap(int(protocol.MaxReceivePacketSize)))
})
It("panics if wrong-sized buffers are passed", func() {
Expect(func() {
putPacketBuffer(&[]byte{0})
}).To(Panic())
})
})

File diff suppressed because it is too large Load Diff

View File

@ -1,119 +0,0 @@
package quic
import (
"bytes"
"errors"
"net"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockPacketConn struct {
addr net.Addr
dataToRead chan []byte
dataReadFrom net.Addr
readErr error
dataWritten bytes.Buffer
dataWrittenTo net.Addr
closed bool
}
func newMockPacketConn() *mockPacketConn {
return &mockPacketConn{
dataToRead: make(chan []byte, 1000),
}
}
func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
if c.readErr != nil {
return 0, nil, c.readErr
}
data, ok := <-c.dataToRead
if !ok {
return 0, nil, errors.New("connection closed")
}
n := copy(b, data)
return n, c.dataReadFrom, nil
}
func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
c.dataWrittenTo = addr
return c.dataWritten.Write(b)
}
func (c *mockPacketConn) Close() error {
if !c.closed {
close(c.dataToRead)
}
c.closed = true
return nil
}
func (c *mockPacketConn) LocalAddr() net.Addr { return c.addr }
func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") }
func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") }
func (c *mockPacketConn) SetWriteDeadline(t time.Time) error { panic("not implemented") }
var _ net.PacketConn = &mockPacketConn{}
var _ = Describe("Connection", func() {
var c *conn
var packetConn *mockPacketConn
BeforeEach(func() {
addr := &net.UDPAddr{
IP: net.IPv4(192, 168, 100, 200),
Port: 1337,
}
packetConn = newMockPacketConn()
c = &conn{
currentAddr: addr,
pconn: packetConn,
}
})
It("writes", func() {
err := c.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(packetConn.dataWritten.Bytes()).To(Equal([]byte("foobar")))
Expect(packetConn.dataWrittenTo.String()).To(Equal("192.168.100.200:1337"))
})
It("reads", func() {
packetConn.dataToRead <- []byte("foo")
packetConn.dataReadFrom = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1336}
p := make([]byte, 10)
n, raddr, err := c.Read(p)
Expect(err).ToNot(HaveOccurred())
Expect(raddr.String()).To(Equal("127.0.0.1:1336"))
Expect(n).To(Equal(3))
Expect(p[0:3]).To(Equal([]byte("foo")))
})
It("gets the remote address", func() {
Expect(c.RemoteAddr().String()).To(Equal("192.168.100.200:1337"))
})
It("gets the local address", func() {
addr := &net.UDPAddr{
IP: net.IPv4(192, 168, 0, 1),
Port: 1234,
}
packetConn.addr = addr
Expect(c.LocalAddr()).To(Equal(addr))
})
It("changes the remote address", func() {
addr := &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 7331,
}
c.SetCurrentRemoteAddr(addr)
Expect(c.RemoteAddr().String()).To(Equal(addr.String()))
})
It("closes", func() {
err := c.Close()
Expect(err).ToNot(HaveOccurred())
Expect(packetConn.closed).To(BeTrue())
})
})

View File

@ -1,26 +0,0 @@
package quic
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Crypto Stream", func() {
var (
str *cryptoStreamImpl
mockSender *MockStreamSender
)
BeforeEach(func() {
mockSender = NewMockStreamSender(mockCtrl)
str = newCryptoStream(mockSender, nil, protocol.VersionWhatever).(*cryptoStreamImpl)
})
It("sets the read offset", func() {
str.setReadOffset(0x42)
Expect(str.receiveStream.readOffset).To(Equal(protocol.ByteCount(0x42)))
Expect(str.receiveStream.frameQueue.readPos).To(Equal(protocol.ByteCount(0x42)))
})
})

View File

@ -1,9 +0,0 @@
FROM scratch
VOLUME /certs
VOLUME /www
EXPOSE 6121
ADD main /main
CMD ["/main", "-bind=0.0.0.0", "-certpath=/certs/", "-www=/www"]

View File

@ -1,7 +0,0 @@
# About the certificate
Yes, this folder contains a private key and a certificate for quic.clemente.io.
Unfortunately we need a valid certificate for the integration tests with Chrome and `quic_client`. No important data is served on the "real" `quic.clemente.io` (only a test page), and the MITM problem is imho negligible.
If you figure out a way to test with Chrome without having a cert and key here, let us now in an issue.

View File

@ -1,66 +0,0 @@
package main
import (
"bytes"
"flag"
"io"
"net/http"
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
func main() {
verbose := flag.Bool("v", false, "verbose")
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
flag.Parse()
urls := flag.Args()
logger := utils.DefaultLogger
if *verbose {
logger.SetLogLevel(utils.LogLevelDebug)
} else {
logger.SetLogLevel(utils.LogLevelInfo)
}
logger.SetLogTimeFormat("")
versions := protocol.SupportedVersions
if *tls {
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
}
roundTripper := &h2quic.RoundTripper{
QuicConfig: &quic.Config{Versions: versions},
}
defer roundTripper.Close()
hclient := &http.Client{
Transport: roundTripper,
}
var wg sync.WaitGroup
wg.Add(len(urls))
for _, addr := range urls {
logger.Infof("GET %s", addr)
go func(addr string) {
rsp, err := hclient.Get(addr)
if err != nil {
panic(err)
}
logger.Infof("Got response for %s: %#v", addr, rsp)
body := &bytes.Buffer{}
_, err = io.Copy(body, rsp.Body)
if err != nil {
panic(err)
}
logger.Infof("Request Body:")
logger.Infof("%s", body.Bytes())
wg.Done()
}(addr)
}
wg.Wait()
}

View File

@ -1,105 +0,0 @@
package main
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"io"
"log"
"math/big"
quic "github.com/lucas-clemente/quic-go"
)
const addr = "localhost:4242"
const message = "foobar"
// We start a server echoing data on the first stream the client opens,
// then connect with a client, send the message, and wait for its receipt.
func main() {
go func() { log.Fatal(echoServer()) }()
err := clientMain()
if err != nil {
panic(err)
}
}
// Start a server that echos all data on the first stream opened by the client
func echoServer() error {
listener, err := quic.ListenAddr(addr, generateTLSConfig(), nil)
if err != nil {
return err
}
sess, err := listener.Accept()
if err != nil {
return err
}
stream, err := sess.AcceptStream()
if err != nil {
panic(err)
}
// Echo through the loggingWriter
_, err = io.Copy(loggingWriter{stream}, stream)
return err
}
func clientMain() error {
session, err := quic.DialAddr(addr, &tls.Config{InsecureSkipVerify: true}, nil)
if err != nil {
return err
}
stream, err := session.OpenStreamSync()
if err != nil {
return err
}
fmt.Printf("Client: Sending '%s'\n", message)
_, err = stream.Write([]byte(message))
if err != nil {
return err
}
buf := make([]byte, len(message))
_, err = io.ReadFull(stream, buf)
if err != nil {
return err
}
fmt.Printf("Client: Got '%s'\n", buf)
return nil
}
// A wrapper for io.Writer that also logs the message.
type loggingWriter struct{ io.Writer }
func (w loggingWriter) Write(b []byte) (int, error) {
fmt.Printf("Server: Got '%s'\n", string(b))
return w.Writer.Write(b)
}
// Setup a bare-bones TLS config for the server
func generateTLSConfig() *tls.Config {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
}
template := x509.Certificate{SerialNumber: big.NewInt(1)}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
panic(err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
panic(err)
}
return &tls.Config{Certificates: []tls.Certificate{tlsCert}}
}

View File

@ -1,62 +0,0 @@
-----BEGIN CERTIFICATE-----
MIIGCzCCBPOgAwIBAgISA6pwet1vq9IjS9ThcggZYGfyMA0GCSqGSIb3DQEBCwUA
MEoxCzAJBgNVBAYTAlVTMRYwFAYDVQQKEw1MZXQncyBFbmNyeXB0MSMwIQYDVQQD
ExpMZXQncyBFbmNyeXB0IEF1dGhvcml0eSBYMzAeFw0xODA2MDkwODI2MjVaFw0x
ODA5MDcwODI2MjVaMBsxGTAXBgNVBAMTEHF1aWMuY2xlbWVudGUuaW8wggEiMA0G
CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDqON9RKCMK4dHxLTLiv3n+b/v4KCbY
Fo38rxCcv4/qoWI1Zoz4XMyPnTZPrzVMz6rhDKWTve1v7g+itgQDp1OoRo5ihtqC
m9Dr1Ed+qb3hHtFRBBOTiWQy1Y0fvUDYGlSHQ4R7xXYVzCYkQ9zopqr2otJp/9ZA
1Yy3ATITSgcds9wJaAOpkSxCx+D7cqpGbaNtxogRZ4ZT1vVs3l7Kr26fvnALMM4K
nf2Rpq1ZjgPRoLNcfUWCZO7gr+VNCFaI2AQRaAzZkSXAk0B91waTtaFD87kYXl9Z
8Jgx1PhJnO8WBUz/7+OAO6M+kuT78e6FltMN004Q8WmuNww6PzmVwnmpAgMBAAGj
ggMYMIIDFDAOBgNVHQ8BAf8EBAMCBaAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsG
AQUFBwMCMAwGA1UdEwEB/wQCMAAwHQYDVR0OBBYEFIs3yJaLnEw7MRtNf9EbzESy
TrHwMB8GA1UdIwQYMBaAFKhKamMEfd265tE5t6ZFZe/zqOyhMG8GCCsGAQUFBwEB
BGMwYTAuBggrBgEFBQcwAYYiaHR0cDovL29jc3AuaW50LXgzLmxldHNlbmNyeXB0
Lm9yZzAvBggrBgEFBQcwAoYjaHR0cDovL2NlcnQuaW50LXgzLmxldHNlbmNyeXB0
Lm9yZy8wGwYDVR0RBBQwEoIQcXVpYy5jbGVtZW50ZS5pbzCB/gYDVR0gBIH2MIHz
MAgGBmeBDAECATCB5gYLKwYBBAGC3xMBAQEwgdYwJgYIKwYBBQUHAgEWGmh0dHA6
Ly9jcHMubGV0c2VuY3J5cHQub3JnMIGrBggrBgEFBQcCAjCBngyBm1RoaXMgQ2Vy
dGlmaWNhdGUgbWF5IG9ubHkgYmUgcmVsaWVkIHVwb24gYnkgUmVseWluZyBQYXJ0
aWVzIGFuZCBvbmx5IGluIGFjY29yZGFuY2Ugd2l0aCB0aGUgQ2VydGlmaWNhdGUg
UG9saWN5IGZvdW5kIGF0IGh0dHBzOi8vbGV0c2VuY3J5cHQub3JnL3JlcG9zaXRv
cnkvMIIBBAYKKwYBBAHWeQIEAgSB9QSB8gDwAHYA23Sv7ssp7LH+yj5xbSzluaq7
NveEcYPHXZ1PN7Yfv2QAAAFj495INQAABAMARzBFAiEApF0BFCWyGIUrJrsYuugt
tshGVdg2+7f6d4B1D/xF2s0CIGGsVL2nRlXJXrkk3aa83lH4HzP9vcQSnMFHdXOf
9XeZAHYAKTxRllTIOWW6qlD8WAfUt2+/WHopctykwwz05UVH9HgAAAFj495IQwAA
BAMARzBFAiEAkV4gbM6hucL7ZwqTzb5fKxYhk6WHr5y8pzClZD3qqv4CIBYH7MSA
P05CXv7tHiHEizIvhWJJvVa2E6XLjbDRQMnUMA0GCSqGSIb3DQEBCwUAA4IBAQA2
CALjtlxGXkkfKsRgKDoPpzg/IAl7crq5OrGGwW/bxbeDeiRHVt0Hlhr+0XPYlh/A
m8qBlg7TMHJa2zIt7wkG/MX3d9bO2bxXxsdjfMXLtxQu7eZ5nzGuXL8WGnZwtSqz
M5pOF/AU1JQojIhehKCeqqUi5UobxXUm+9D6OVmr8s732X3n/TL6pgsWRhay9tjB
kdZ9TSe0tLYyXnbwHp0rgwNWOMMN1Tc+Fpqc8UlrCq5REb0bLIQ6A2IJH08MEPWG
ukXHPAaHLz9oB1emR3flCoMKB0KrXUpDFXemOIPN6QVBO16LNNuRffKwXnzp60+b
MoB4Krxrab7TzlfT+HnP
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIEkjCCA3qgAwIBAgIQCgFBQgAAAVOFc2oLheynCDANBgkqhkiG9w0BAQsFADA/
MSQwIgYDVQQKExtEaWdpdGFsIFNpZ25hdHVyZSBUcnVzdCBDby4xFzAVBgNVBAMT
DkRTVCBSb290IENBIFgzMB4XDTE2MDMxNzE2NDA0NloXDTIxMDMxNzE2NDA0Nlow
SjELMAkGA1UEBhMCVVMxFjAUBgNVBAoTDUxldCdzIEVuY3J5cHQxIzAhBgNVBAMT
GkxldCdzIEVuY3J5cHQgQXV0aG9yaXR5IFgzMIIBIjANBgkqhkiG9w0BAQEFAAOC
AQ8AMIIBCgKCAQEAnNMM8FrlLke3cl03g7NoYzDq1zUmGSXhvb418XCSL7e4S0EF
q6meNQhY7LEqxGiHC6PjdeTm86dicbp5gWAf15Gan/PQeGdxyGkOlZHP/uaZ6WA8
SMx+yk13EiSdRxta67nsHjcAHJyse6cF6s5K671B5TaYucv9bTyWaN8jKkKQDIZ0
Z8h/pZq4UmEUEz9l6YKHy9v6Dlb2honzhT+Xhq+w3Brvaw2VFn3EK6BlspkENnWA
a6xK8xuQSXgvopZPKiAlKQTGdMDQMc2PMTiVFrqoM7hD8bEfwzB/onkxEz0tNvjj
/PIzark5McWvxI0NHWQWM6r6hCm21AvA2H3DkwIDAQABo4IBfTCCAXkwEgYDVR0T
AQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAYYwfwYIKwYBBQUHAQEEczBxMDIG
CCsGAQUFBzABhiZodHRwOi8vaXNyZy50cnVzdGlkLm9jc3AuaWRlbnRydXN0LmNv
bTA7BggrBgEFBQcwAoYvaHR0cDovL2FwcHMuaWRlbnRydXN0LmNvbS9yb290cy9k
c3Ryb290Y2F4My5wN2MwHwYDVR0jBBgwFoAUxKexpHsscfrb4UuQdf/EFWCFiRAw
VAYDVR0gBE0wSzAIBgZngQwBAgEwPwYLKwYBBAGC3xMBAQEwMDAuBggrBgEFBQcC
ARYiaHR0cDovL2Nwcy5yb290LXgxLmxldHNlbmNyeXB0Lm9yZzA8BgNVHR8ENTAz
MDGgL6AthitodHRwOi8vY3JsLmlkZW50cnVzdC5jb20vRFNUUk9PVENBWDNDUkwu
Y3JsMB0GA1UdDgQWBBSoSmpjBH3duubRObemRWXv86jsoTANBgkqhkiG9w0BAQsF
AAOCAQEA3TPXEfNjWDjdGBX7CVW+dla5cEilaUcne8IkCJLxWh9KEik3JHRRHGJo
uM2VcGfl96S8TihRzZvoroed6ti6WqEBmtzw3Wodatg+VyOeph4EYpr/1wXKtx8/
wApIvJSwtmVi4MFU5aMqrSDE6ea73Mj2tcMyo5jMd6jmeWUHK8so/joWUoHOUgwu
X4Po1QYz+3dszkDqMp4fklxBwXRsW10KXzPMTZ+sOPAveyxindmjkW8lGy+QsRlG
PfZ+G6Z6h7mjem0Y+iWlkYcV4PIWL1iwBi8saCbGS5jN2p8M+X+Q7UNKEkROb3N6
KOqkqm57TH2H3eDJAkSnh6/DNFu0Qg==
-----END CERTIFICATE-----

View File

@ -1,174 +0,0 @@
package main
import (
"crypto/md5"
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"mime/multipart"
"net/http"
"path"
"runtime"
"strings"
"sync"
_ "net/http/pprof"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type binds []string
func (b binds) String() string {
return strings.Join(b, ",")
}
func (b *binds) Set(v string) error {
*b = strings.Split(v, ",")
return nil
}
// Size is needed by the /demo/upload handler to determine the size of the uploaded file
type Size interface {
Size() int64
}
func init() {
http.HandleFunc("/demo/tile", func(w http.ResponseWriter, r *http.Request) {
// Small 40x40 png
w.Write([]byte{
0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, 0x00, 0x00, 0x00, 0x0d,
0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x28,
0x01, 0x03, 0x00, 0x00, 0x00, 0xb6, 0x30, 0x2a, 0x2e, 0x00, 0x00, 0x00,
0x03, 0x50, 0x4c, 0x54, 0x45, 0x5a, 0xc3, 0x5a, 0xad, 0x38, 0xaa, 0xdb,
0x00, 0x00, 0x00, 0x0b, 0x49, 0x44, 0x41, 0x54, 0x78, 0x01, 0x63, 0x18,
0x61, 0x00, 0x00, 0x00, 0xf0, 0x00, 0x01, 0xe2, 0xb8, 0x75, 0x22, 0x00,
0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, 0x44, 0xae, 0x42, 0x60, 0x82,
})
})
http.HandleFunc("/demo/tiles", func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, "<html><head><style>img{width:40px;height:40px;}</style></head><body>")
for i := 0; i < 200; i++ {
fmt.Fprintf(w, `<img src="/demo/tile?cachebust=%d">`, i)
}
io.WriteString(w, "</body></html>")
})
http.HandleFunc("/demo/echo", func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
fmt.Printf("error reading body while handling /echo: %s\n", err.Error())
}
w.Write(body)
})
// accept file uploads and return the MD5 of the uploaded file
// maximum accepted file size is 1 GB
http.HandleFunc("/demo/upload", func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost {
err := r.ParseMultipartForm(1 << 30) // 1 GB
if err == nil {
var file multipart.File
file, _, err = r.FormFile("uploadfile")
if err == nil {
var size int64
if sizeInterface, ok := file.(Size); ok {
size = sizeInterface.Size()
b := make([]byte, size)
file.Read(b)
md5 := md5.Sum(b)
fmt.Fprintf(w, "%x", md5)
return
}
err = errors.New("couldn't get uploaded file size")
}
}
if err != nil {
utils.DefaultLogger.Infof("Error receiving upload: %#v", err)
}
}
io.WriteString(w, `<html><body><form action="/demo/upload" method="post" enctype="multipart/form-data">
<input type="file" name="uploadfile"><br>
<input type="submit">
</form></body></html>`)
})
}
func getBuildDir() string {
_, filename, _, ok := runtime.Caller(0)
if !ok {
panic("Failed to get current frame")
}
return path.Dir(filename)
}
func main() {
// defer profile.Start().Stop()
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
// runtime.SetBlockProfileRate(1)
verbose := flag.Bool("v", false, "verbose")
bs := binds{}
flag.Var(&bs, "bind", "bind to")
certPath := flag.String("certpath", getBuildDir(), "certificate directory")
www := flag.String("www", "/var/www", "www data")
tcp := flag.Bool("tcp", false, "also listen on TCP")
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
flag.Parse()
logger := utils.DefaultLogger
if *verbose {
logger.SetLogLevel(utils.LogLevelDebug)
} else {
logger.SetLogLevel(utils.LogLevelInfo)
}
logger.SetLogTimeFormat("")
versions := protocol.SupportedVersions
if *tls {
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
}
certFile := *certPath + "/fullchain.pem"
keyFile := *certPath + "/privkey.pem"
http.Handle("/", http.FileServer(http.Dir(*www)))
if len(bs) == 0 {
bs = binds{"localhost:6121"}
}
var wg sync.WaitGroup
wg.Add(len(bs))
for _, b := range bs {
bCap := b
go func() {
var err error
if *tcp {
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil)
} else {
server := h2quic.Server{
Server: &http.Server{Addr: bCap},
QuicConfig: &quic.Config{Versions: versions},
}
err = server.ListenAndServeTLS(certFile, keyFile)
}
if err != nil {
fmt.Println(err)
}
wg.Done()
}()
}
wg.Wait()
}

View File

@ -1,28 +0,0 @@
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDqON9RKCMK4dHx
LTLiv3n+b/v4KCbYFo38rxCcv4/qoWI1Zoz4XMyPnTZPrzVMz6rhDKWTve1v7g+i
tgQDp1OoRo5ihtqCm9Dr1Ed+qb3hHtFRBBOTiWQy1Y0fvUDYGlSHQ4R7xXYVzCYk
Q9zopqr2otJp/9ZA1Yy3ATITSgcds9wJaAOpkSxCx+D7cqpGbaNtxogRZ4ZT1vVs
3l7Kr26fvnALMM4Knf2Rpq1ZjgPRoLNcfUWCZO7gr+VNCFaI2AQRaAzZkSXAk0B9
1waTtaFD87kYXl9Z8Jgx1PhJnO8WBUz/7+OAO6M+kuT78e6FltMN004Q8WmuNww6
PzmVwnmpAgMBAAECggEAAW7hpux48msZTsF5CzwisfTbdNRCEJZqvf4QOvVNGyFr
qWn8ONTQh5xtpaUrzVGD+SaLqNDDsCijvdohQih28ZOk8WNj2OK9L4Q3/8VoHQWE
QFunBwMTMuBtoaEV0XyvwbgfCmbV5yI9pYEoy9+hMisi4HUpSXJFDyWZudZ9Hqhl
FBf5UPxF1AjZViODVfnKkrWh85jaENyWjWrMEDSohzxP8IStFMK8E0feXaQb+G0f
uTelpG2JmVcO3xaYgtRVeS4p5liMQg5gE5oa2Jh7Vp3TwMhQVg0aLmslSLAYlPoh
hyBeOS3ucyFHoC/6Stnnx3jdOEf2lEUObJj3QVEeBQKBgQD3qptNY9R62UQFO8gT
pseEO5CAZRGuKG0VLPNqKKerYQreiT3xYOTPmwEy+xXzYV6DKHtlKDetHuOB862M
E1bKmjDedafQ3KMc4tywLW+U9I8GyooT8uoetzARzgoftRcMzdeu4fexWTsi+kOh
5/PeXUBnFph8E68nXHQR5+MWAwKBgQDyGnT9KUVuLvzrFAghe9Eu1iZJSDs/IrJj
we3XQ7loqMc/Qv34eVKATsgtb2cTSeTivQcSvQO+Uu9dgo17BoWAABDTaV8NAOZ6
cV2kedrWnxaTRXzB6Z5EKLg+DOMTpCV4Nf3OwzGq/mnKAe8cNM/hS+8HcZywKwr2
UwLKSq6n4wKBgQDXGUaOnVCSbZZlETnAz43i67ShvqXvY07yIDs8jRiqgLrm8b1p
oaS4JkCRXX8ABSYHtaYOAjLw2a3wVIn66WTsy6P74aWhga7szJ+tJ5kMfqal2Ey5
7LSnfqRyIkeqqCXfyfsz+S+dyQjSZRdOS90C2Gyx2+8NfC8YeXSZhJM2rwKBgQDB
pL28G/2nsrejQ2N5fLKE9s6qwLZ6ukLbHasiCc5L0uuDQw8mZcvCSsE77iYQvILx
hGYa68oJugYw0hJdu4qeJe9PWbGoEfdHKlPPEZQjJB4Hb4XpB/YJ6FPtdZtPA3Tg
4LaAYYnhjhqJc+CPvAIl3vlyB8Je+h6LhTvvF6r5JwKBgHBkpQ2Y+N3dMgfE8T38
+Vy+75i9ew3F0TB8o2MWJzYLYCzNLqrB4GkZieudoYP1Cmh50wYIifsk2zNpxOG0
27tUfMxza4ITUfmtVHS22DGWCqb1TgfQt7jajdioK4TUOZFEmiQiOYMZoFMenbzZ
miz6s3fWx1/b3wNd1J5uV2QS
-----END PRIVATE KEY-----

View File

@ -1,435 +0,0 @@
package quic
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("STREAM frame sorter", func() {
var s *frameSorter
checkGaps := func(expectedGaps []utils.ByteInterval) {
Expect(s.gaps.Len()).To(Equal(len(expectedGaps)))
var i int
for gap := s.gaps.Front(); gap != nil; gap = gap.Next() {
Expect(gap.Value).To(Equal(expectedGaps[i]))
i++
}
}
BeforeEach(func() {
s = newFrameSorter()
})
It("head returns nil when empty", func() {
Expect(s.Pop()).To(BeNil())
})
Context("Push", func() {
It("inserts and pops a single frame", func() {
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
data, fin := s.Pop()
Expect(data).To(Equal([]byte("foobar")))
Expect(fin).To(BeFalse())
Expect(s.Pop()).To(BeNil())
})
It("inserts and pops two consecutive frame", func() {
Expect(s.Push([]byte("foo"), 0, false)).To(Succeed())
Expect(s.Push([]byte("bar"), 3, false)).To(Succeed())
data, fin := s.Pop()
Expect(data).To(Equal([]byte("foo")))
Expect(fin).To(BeFalse())
data, fin = s.Pop()
Expect(data).To(Equal([]byte("bar")))
Expect(fin).To(BeFalse())
Expect(s.Pop()).To(BeNil())
})
It("ignores empty frames", func() {
Expect(s.Push(nil, 0, false)).To(Succeed())
Expect(s.Pop()).To(BeNil())
})
Context("FIN handling", func() {
It("saves a FIN at offset 0", func() {
Expect(s.Push(nil, 0, true)).To(Succeed())
data, fin := s.Pop()
Expect(data).To(BeEmpty())
Expect(fin).To(BeTrue())
data, fin = s.Pop()
Expect(data).To(BeNil())
Expect(fin).To(BeTrue())
})
It("saves a FIN frame at non-zero offset", func() {
Expect(s.Push([]byte("foobar"), 0, true)).To(Succeed())
data, fin := s.Pop()
Expect(data).To(Equal([]byte("foobar")))
Expect(fin).To(BeTrue())
data, fin = s.Pop()
Expect(data).To(BeNil())
Expect(fin).To(BeTrue())
})
It("sets the FIN if a stream is closed after receiving some data", func() {
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
Expect(s.Push(nil, 6, true)).To(Succeed())
data, fin := s.Pop()
Expect(data).To(Equal([]byte("foobar")))
Expect(fin).To(BeTrue())
data, fin = s.Pop()
Expect(data).To(BeNil())
Expect(fin).To(BeTrue())
})
})
Context("Gap handling", func() {
It("finds the first gap", func() {
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
checkGaps([]utils.ByteInterval{
{Start: 0, End: 10},
{Start: 16, End: protocol.MaxByteCount},
})
})
It("correctly sets the first gap for a frame with offset 0", func() {
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
checkGaps([]utils.ByteInterval{
{Start: 6, End: protocol.MaxByteCount},
})
})
It("finds the two gaps", func() {
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
Expect(s.Push([]byte("foobar"), 20, false)).To(Succeed())
checkGaps([]utils.ByteInterval{
{Start: 0, End: 10},
{Start: 16, End: 20},
{Start: 26, End: protocol.MaxByteCount},
})
})
It("finds the two gaps in reverse order", func() {
Expect(s.Push([]byte("foobar"), 20, false)).To(Succeed())
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
checkGaps([]utils.ByteInterval{
{Start: 0, End: 10},
{Start: 16, End: 20},
{Start: 26, End: protocol.MaxByteCount},
})
})
It("shrinks a gap when it is partially filled", func() {
Expect(s.Push([]byte("test"), 10, false)).To(Succeed())
Expect(s.Push([]byte("foobar"), 4, false)).To(Succeed())
checkGaps([]utils.ByteInterval{
{Start: 0, End: 4},
{Start: 14, End: protocol.MaxByteCount},
})
})
It("deletes a gap at the beginning, when it is filled", func() {
Expect(s.Push([]byte("test"), 6, false)).To(Succeed())
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
checkGaps([]utils.ByteInterval{
{Start: 10, End: protocol.MaxByteCount},
})
})
It("deletes a gap in the middle, when it is filled", func() {
Expect(s.Push([]byte("test"), 0, false)).To(Succeed())
Expect(s.Push([]byte("test2"), 10, false)).To(Succeed())
Expect(s.Push([]byte("foobar"), 4, false)).To(Succeed())
Expect(s.queue).To(HaveLen(3))
checkGaps([]utils.ByteInterval{
{Start: 15, End: protocol.MaxByteCount},
})
})
It("splits a gap into two", func() {
Expect(s.Push([]byte("test"), 100, false)).To(Succeed())
Expect(s.Push([]byte("foobar"), 50, false)).To(Succeed())
Expect(s.queue).To(HaveLen(2))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 50},
{Start: 56, End: 100},
{Start: 104, End: protocol.MaxByteCount},
})
})
Context("Overlapping Stream Data detection", func() {
// create gaps: 0-5, 10-15, 20-25, 30-inf
BeforeEach(func() {
Expect(s.Push([]byte("12345"), 5, false)).To(Succeed())
Expect(s.Push([]byte("12345"), 15, false)).To(Succeed())
Expect(s.Push([]byte("12345"), 25, false)).To(Succeed())
checkGaps([]utils.ByteInterval{
{Start: 0, End: 5},
{Start: 10, End: 15},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("cuts a frame with offset 0 that overlaps at the end", func() {
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
Expect(s.queue).To(HaveKey(protocol.ByteCount(0)))
Expect(s.queue[0]).To(Equal([]byte("fooba")))
Expect(s.queue[0]).To(HaveCap(5))
checkGaps([]utils.ByteInterval{
{Start: 10, End: 15},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("cuts a frame that overlaps at the end", func() {
// 4 to 7
Expect(s.Push([]byte("foo"), 4, false)).To(Succeed())
Expect(s.queue).To(HaveKey(protocol.ByteCount(4)))
Expect(s.queue[4]).To(Equal([]byte("f")))
Expect(s.queue[4]).To(HaveCap(1))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 4},
{Start: 10, End: 15},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("cuts a frame that completely fills a gap, but overlaps at the end", func() {
// 10 to 16
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
Expect(s.queue[10]).To(Equal([]byte("fooba")))
Expect(s.queue[10]).To(HaveCap(5))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 5},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("cuts a frame that overlaps at the beginning", func() {
// 8 to 14
Expect(s.Push([]byte("foobar"), 8, false)).To(Succeed())
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(8)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
Expect(s.queue[10]).To(Equal([]byte("obar")))
Expect(s.queue[10]).To(HaveCap(4))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 5},
{Start: 14, End: 15},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("processes a frame that overlaps at the beginning and at the end, starting in a gap", func() {
// 2 to 12
Expect(s.Push([]byte("1234567890"), 2, false)).To(Succeed())
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(2)))
Expect(s.queue[2]).To(Equal([]byte("1234567890")))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 2},
{Start: 12, End: 15},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("processes a frame that overlaps at the beginning and at the end, starting in a gap, ending in data", func() {
// 2 to 17
Expect(s.Push([]byte("123456789012345"), 2, false)).To(Succeed())
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(2)))
Expect(s.queue[2]).To(Equal([]byte("1234567890123")))
Expect(s.queue[2]).To(HaveCap(13))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 2},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("processes a frame that overlaps at the beginning and at the end, starting in a gap, ending in data", func() {
// 5 to 22
Expect(s.Push([]byte("12345678901234567"), 5, false)).To(Succeed())
Expect(s.queue).To(HaveKey(protocol.ByteCount(5)))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(15)))
Expect(s.queue[10]).To(Equal([]byte("678901234567")))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 5},
{Start: 22, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("processes a frame that closes multiple gaps", func() {
// 2 to 27
Expect(s.Push(bytes.Repeat([]byte{'e'}, 25), 2, false)).To(Succeed())
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(15)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(25)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(2)))
Expect(s.queue[2]).To(Equal(bytes.Repeat([]byte{'e'}, 23)))
Expect(s.queue[2]).To(HaveCap(23))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 2},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("processes a frame that closes multiple gaps", func() {
// 5 to 27
Expect(s.Push(bytes.Repeat([]byte{'d'}, 22), 5, false)).To(Succeed())
Expect(s.queue).To(HaveKey(protocol.ByteCount(5)))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(15)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(25)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
Expect(s.queue[10]).To(Equal(bytes.Repeat([]byte{'d'}, 15)))
Expect(s.queue[10]).To(HaveCap(15))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 5},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("processes a frame that covers multiple gaps and ends at the end of a gap", func() {
data := bytes.Repeat([]byte{'e'}, 14)
// 1 to 15
Expect(s.Push(data, 1, false)).To(Succeed())
Expect(s.queue).To(HaveKey(protocol.ByteCount(1)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(15)))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
Expect(s.queue[1]).To(Equal(data))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 1},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("processes a frame that closes all gaps (except for the last one)", func() {
data := bytes.Repeat([]byte{'f'}, 32)
// 0 to 32
Expect(s.Push(data, 0, false)).To(Succeed())
Expect(s.queue).To(HaveLen(1))
Expect(s.queue).To(HaveKey(protocol.ByteCount(0)))
Expect(s.queue[0]).To(Equal(data))
checkGaps([]utils.ByteInterval{
{Start: 32, End: protocol.MaxByteCount},
})
})
It("cuts a frame that overlaps at the beginning and at the end, starting in data already received", func() {
// 8 to 17
Expect(s.Push([]byte("123456789"), 8, false)).To(Succeed())
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(8)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
Expect(s.queue[10]).To(Equal([]byte("34567")))
Expect(s.queue[10]).To(HaveCap(5))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 5},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("cuts a frame that completely covers two gaps", func() {
// 10 to 20
Expect(s.Push([]byte("1234567890"), 10, false)).To(Succeed())
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
Expect(s.queue[10]).To(Equal([]byte("12345")))
Expect(s.queue[10]).To(HaveCap(5))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 5},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
})
Context("duplicate data", func() {
expectedGaps := []utils.ByteInterval{
{Start: 5, End: 10},
{Start: 15, End: protocol.MaxByteCount},
}
BeforeEach(func() {
// create gaps: 5-10, 15-inf
Expect(s.Push([]byte("12345"), 0, false)).To(Succeed())
Expect(s.Push([]byte("12345"), 10, false)).To(Succeed())
checkGaps(expectedGaps)
})
AfterEach(func() {
// check that the gaps were not modified
checkGaps(expectedGaps)
})
It("does not modify data when receiving a duplicate", func() {
err := s.push([]byte("fffff"), 0, false)
Expect(err).To(MatchError(errDuplicateStreamData))
Expect(s.queue[0]).ToNot(Equal([]byte("fffff")))
})
It("detects a duplicate frame that is smaller than the original, starting at the beginning", func() {
// 10 to 12
err := s.push([]byte("12"), 10, false)
Expect(err).To(MatchError(errDuplicateStreamData))
Expect(s.queue[10]).To(HaveLen(5))
})
It("detects a duplicate frame that is smaller than the original, somewhere in the middle", func() {
// 1 to 4
err := s.push([]byte("123"), 1, false)
Expect(err).To(MatchError(errDuplicateStreamData))
Expect(s.queue[0]).To(HaveLen(5))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(1)))
})
It("detects a duplicate frame that is smaller than the original, somewhere in the middle in the last block", func() {
// 11 to 14
err := s.push([]byte("123"), 11, false)
Expect(err).To(MatchError(errDuplicateStreamData))
Expect(s.queue[10]).To(HaveLen(5))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(11)))
})
It("detects a duplicate frame that is smaller than the original, with aligned end in the last block", func() {
// 11 to 15
err := s.push([]byte("1234"), 1, false)
Expect(err).To(MatchError(errDuplicateStreamData))
Expect(s.queue[10]).To(HaveLen(5))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(11)))
})
It("detects a duplicate frame that is smaller than the original, with aligned end", func() {
// 3 to 5
err := s.push([]byte("12"), 3, false)
Expect(err).To(MatchError(errDuplicateStreamData))
Expect(s.queue[0]).To(HaveLen(5))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(3)))
})
})
Context("DoS protection", func() {
It("errors when too many gaps are created", func() {
for i := 0; i < protocol.MaxStreamFrameSorterGaps; i++ {
Expect(s.Push([]byte("foobar"), protocol.ByteCount(i*7), false)).To(Succeed())
}
Expect(s.gaps.Len()).To(Equal(protocol.MaxStreamFrameSorterGaps))
err := s.Push([]byte("foobar"), protocol.ByteCount(protocol.MaxStreamFrameSorterGaps*7)+100, false)
Expect(err).To(MatchError("Too many gaps in received data"))
})
})
})
})
})

View File

@ -1,639 +0,0 @@
package h2quic
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"io"
"net/http"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Client", func() {
var (
client *client
session *mockSession
headerStream *mockStream
req *http.Request
origDialAddr = dialAddr
)
injectResponse := func(id protocol.StreamID, rsp *http.Response) {
EventuallyWithOffset(0, func() bool {
client.mutex.Lock()
defer client.mutex.Unlock()
_, ok := client.responses[id]
return ok
}).Should(BeTrue())
rspChan := client.responses[5]
ExpectWithOffset(0, rspChan).ToNot(BeClosed())
rspChan <- rsp
}
BeforeEach(func() {
origDialAddr = dialAddr
hostname := "quic.clemente.io:1337"
client = newClient(hostname, nil, &roundTripperOpts{}, nil, nil)
Expect(client.hostname).To(Equal(hostname))
session = newMockSession()
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
client.session = session
headerStream = newMockStream(3)
client.headerStream = headerStream
client.requestWriter = newRequestWriter(headerStream, utils.DefaultLogger)
var err error
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
dialAddr = origDialAddr
})
It("saves the TLS config", func() {
tlsConf := &tls.Config{InsecureSkipVerify: true}
client = newClient("", tlsConf, &roundTripperOpts{}, nil, nil)
Expect(client.tlsConf).To(Equal(tlsConf))
})
It("saves the QUIC config", func() {
quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf, nil)
Expect(client.config).To(Equal(quicConf))
})
It("uses the default QUIC config if none is give", func() {
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil, nil)
Expect(client.config).ToNot(BeNil())
Expect(client.config).To(Equal(defaultQuicConfig))
})
It("adds the port to the hostname, if none is given", func() {
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
})
It("dials", func() {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
close(headerStream.unblockRead)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
close(done)
// fmt.Println("done")
}()
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
It("errors when dialing fails", func() {
testErr := errors.New("handshake error")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr
}
_, err := client.RoundTrip(req)
Expect(err).To(MatchError(testErr))
})
It("uses the custom dialer, if provided", func() {
var tlsCfg *tls.Config
var qCfg *quic.Config
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
tlsCfg = tlsCfgP
qCfg = cfg
return session, nil
}
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, dialer)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
Expect(qCfg).To(Equal(client.config))
Expect(tlsCfg).To(Equal(client.tlsConf))
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
It("errors if it can't open a stream", func() {
testErr := errors.New("you shall not pass")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session.streamOpenErr = testErr
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
_, err := client.RoundTrip(req)
Expect(err).To(MatchError(testErr))
})
It("returns a request when dial fails", func() {
testErr := errors.New("dial error")
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr
}
request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
_, err := client.RoundTrip(request)
Expect(err).To(MatchError(testErr))
close(done)
}()
_, err = client.RoundTrip(request)
Expect(err).To(MatchError(testErr))
Eventually(done).Should(BeClosed())
})
Context("Doing requests", func() {
var request *http.Request
var dataStream *mockStream
getRequest := func(data []byte) *http2.MetaHeadersFrame {
r := bytes.NewReader(data)
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, r)
frame, err := h2framer.ReadFrame()
Expect(err).ToNot(HaveOccurred())
mhframe := &http2.MetaHeadersFrame{HeadersFrame: frame.(*http2.HeadersFrame)}
mhframe.Fields, err = decoder.DecodeFull(mhframe.HeadersFrame.HeaderBlockFragment())
Expect(err).ToNot(HaveOccurred())
return mhframe
}
getHeaderFields := func(f *http2.MetaHeadersFrame) map[string]string {
fields := make(map[string]string)
for _, hf := range f.Fields {
fields[hf.Name] = hf.Value
}
return fields
}
BeforeEach(func() {
var err error
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
dataStream = newMockStream(5)
session.streamsToOpen = []quic.Stream{headerStream, dataStream}
request, err = http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
})
It("does a request", func() {
teapot := &http.Response{
Status: "418 I'm a teapot",
StatusCode: 418,
}
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
Expect(rsp).To(Equal(teapot))
Expect(rsp.Body).To(Equal(dataStream))
Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
Expect(rsp.Request).To(Equal(request))
close(done)
}()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
injectResponse(5, teapot)
Expect(client.headerErrored).ToNot(BeClosed())
Eventually(done).Should(BeClosed())
})
It("errors if a request without a body is canceled", func() {
done := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
defer GinkgoRecover()
request = request.WithContext(ctx)
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(context.Canceled))
Expect(rsp).To(BeNil())
close(done)
}()
cancel()
Eventually(done).Should(BeClosed())
Expect(dataStream.reset).To(BeTrue())
Expect(dataStream.canceledWrite).To(BeTrue())
Expect(client.headerErrored).ToNot(BeClosed())
})
It("errors if a request with a body is canceled after the body is sent", func() {
done := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
defer GinkgoRecover()
request = request.WithContext(ctx)
request.Body = &mockBody{}
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(context.Canceled))
Expect(rsp).To(BeNil())
close(done)
}()
time.Sleep(10 * time.Millisecond)
cancel()
Eventually(done).Should(BeClosed())
Expect(dataStream.reset).To(BeTrue())
Expect(dataStream.canceledWrite).To(BeTrue())
Expect(client.headerErrored).ToNot(BeClosed())
})
It("errors if a request with a body is canceled before the body is sent", func() {
done := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
defer GinkgoRecover()
request = request.WithContext(ctx)
request.Body = &mockBody{}
cancel()
time.Sleep(10 * time.Millisecond)
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(context.Canceled))
Expect(rsp).To(BeNil())
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(dataStream.reset).To(BeTrue())
Expect(dataStream.canceledWrite).To(BeTrue())
Expect(client.headerErrored).ToNot(BeClosed())
})
It("closes the quic client when encountering an error on the header stream", func() {
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(client.headerErr))
Expect(rsp).To(BeNil())
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
})
It("returns subsequent request if there was an error on the header stream before", func() {
session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)}
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
_, err := client.RoundTrip(request)
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
// now that the first request failed due to an error on the header stream, try another request
_, nextErr := client.RoundTrip(request)
Expect(nextErr).To(MatchError(err))
})
It("blocks if no stream is available", func() {
session.streamsToOpen = []quic.Stream{headerStream, dataStream}
session.blockOpenStreamSync = true
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
client.Close()
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
Context("validating the address", func() {
It("refuses to do requests for the wrong host", func() {
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req)
Expect(err).To(MatchError("h2quic Client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
})
It("refuses to do plain HTTP requests", func() {
req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req)
Expect(err).To(MatchError("quic http2: unsupported scheme"))
})
It("adds the port for request URLs without one", func() {
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
// the client.RoundTrip will block, because the encryption level is still set to Unencrypted
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
})
It("sets the EndStream header for requests without a body", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
client.RoundTrip(request)
close(done)
}()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil())
mhf := getRequest(headerStream.dataWritten.Bytes())
Expect(mhf.HeadersFrame.StreamEnded()).To(BeTrue())
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
It("sets the EndStream header to false for requests with a body", func() {
request.Body = &mockBody{}
done := make(chan struct{})
go func() {
defer GinkgoRecover()
client.RoundTrip(request)
close(done)
}()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil())
mhf := getRequest(headerStream.dataWritten.Bytes())
Expect(mhf.HeadersFrame.StreamEnded()).To(BeFalse())
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
Context("requests containing a Body", func() {
var requestBody []byte
var response *http.Response
BeforeEach(func() {
requestBody = []byte("request body")
body := &mockBody{}
body.SetData(requestBody)
request.Body = body
response = &http.Response{
StatusCode: 200,
Header: http.Header{"Content-Length": []string{"1000"}},
}
// fake a handshake
client.dialOnce.Do(func() {})
session.streamsToOpen = []quic.Stream{dataStream}
})
It("sends a request", func() {
rspChan := make(chan *http.Response)
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
rspChan <- rsp
}()
injectResponse(5, response)
Eventually(rspChan).Should(Receive(Equal(response)))
Expect(dataStream.dataWritten.Bytes()).To(Equal(requestBody))
Expect(dataStream.closed).To(BeTrue())
Expect(request.Body.(*mockBody).closed).To(BeTrue())
})
It("returns the error that occurred when reading the body", func() {
testErr := errors.New("testErr")
request.Body.(*mockBody).readErr = testErr
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(testErr))
Expect(rsp).To(BeNil())
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(request.Body.(*mockBody).closed).To(BeTrue())
})
It("returns the error that occurred when closing the body", func() {
testErr := errors.New("testErr")
request.Body.(*mockBody).closeErr = testErr
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(testErr))
Expect(rsp).To(BeNil())
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(request.Body.(*mockBody).closed).To(BeTrue())
})
})
Context("gzip compression", func() {
var gzippedData []byte // a gzipped foobar
var response *http.Response
BeforeEach(func() {
var b bytes.Buffer
w := gzip.NewWriter(&b)
w.Write([]byte("foobar"))
w.Close()
gzippedData = b.Bytes()
response = &http.Response{
StatusCode: 200,
Header: http.Header{"Content-Length": []string{"1000"}},
}
})
It("adds the gzip header to requests", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
Expect(rsp).ToNot(BeNil())
Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
Expect(rsp.Header.Get("Content-Length")).To(BeEmpty())
data := make([]byte, 6)
_, err = io.ReadFull(rsp.Body, data)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("foobar")))
close(done)
}()
dataStream.dataToRead.Write(gzippedData)
response.Header.Add("Content-Encoding", "gzip")
injectResponse(5, response)
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
close(dataStream.unblockRead)
Eventually(done).Should(BeClosed())
})
It("doesn't add gzip if the header disable it", func() {
client.opts.DisableCompression = true
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).ToNot(HaveKey("accept-encoding"))
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
It("only decompresses the response if the response contains the right content-encoding header", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
Expect(rsp).ToNot(BeNil())
data := make([]byte, 11)
rsp.Body.Read(data)
Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1))
Expect(data).To(Equal([]byte("not gzipped")))
close(done)
}()
dataStream.dataToRead.Write([]byte("not gzipped"))
injectResponse(5, response)
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
Eventually(done).Should(BeClosed())
})
It("doesn't add the gzip header for requests that have the accept-enconding set", func() {
request.Header.Add("accept-encoding", "gzip")
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
data := make([]byte, 12)
_, err = rsp.Body.Read(data)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1))
Expect(data).To(Equal([]byte("gzipped data")))
close(done)
}()
dataStream.dataToRead.Write([]byte("gzipped data"))
injectResponse(5, response)
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
Eventually(done).Should(BeClosed())
})
})
Context("handling the header stream", func() {
var h2framer *http2.Framer
BeforeEach(func() {
h2framer = http2.NewFramer(&headerStream.dataToRead, nil)
client.responses[23] = make(chan *http.Response)
})
It("reads header values from a response", func() {
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
data := []byte{0x48, 0x03, 0x33, 0x30, 0x32, 0x58, 0x07, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x61, 0x1d, 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x31, 0x20, 0x4f, 0x63, 0x74, 0x20, 0x32, 0x30, 0x31, 0x33, 0x20, 0x32, 0x30, 0x3a, 0x31, 0x33, 0x3a, 0x32, 0x31, 0x20, 0x47, 0x4d, 0x54, 0x6e, 0x17, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d}
headerStream.dataToRead.Write([]byte{0x0, 0x0, byte(len(data)), 0x1, 0x5, 0x0, 0x0, 0x0, 23})
headerStream.dataToRead.Write(data)
go client.handleHeaderStream()
var rsp *http.Response
Eventually(client.responses[23]).Should(Receive(&rsp))
Expect(rsp).ToNot(BeNil())
Expect(rsp.Proto).To(Equal("HTTP/2.0"))
Expect(rsp.ProtoMajor).To(BeEquivalentTo(2))
Expect(rsp.StatusCode).To(BeEquivalentTo(302))
Expect(rsp.Status).To(Equal("302 Found"))
Expect(rsp.Header).To(HaveKeyWithValue("Location", []string{"https://www.example.com"}))
Expect(rsp.Header).To(HaveKeyWithValue("Cache-Control", []string{"private"}))
})
It("errors if the H2 frame is not a HeadersFrame", func() {
h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0})
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")))
})
It("errors if it can't read the HPACK encoded header fields", func() {
h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: 23,
EndHeaders: true,
BlockFragment: []byte("invalid HPACK data"),
})
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("cannot read header fields"))
})
It("errors if the stream cannot be found", func() {
var headers bytes.Buffer
enc := hpack.NewEncoder(&headers)
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: 1337,
EndHeaders: true,
BlockFragment: headers.Bytes(),
})
Expect(err).ToNot(HaveOccurred())
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("response channel for stream 1337 not found"))
})
})
})
})

View File

@ -1,13 +0,0 @@
package h2quic
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestH2quic(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "H2quic Suite")
}

View File

@ -1,39 +0,0 @@
package h2quic
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Request body", func() {
var (
stream *mockStream
rb *requestBody
)
BeforeEach(func() {
stream = &mockStream{}
stream.dataToRead.Write([]byte("foobar")) // provides data to be read
rb = newRequestBody(stream)
})
It("reads from the stream", func() {
b := make([]byte, 10)
n, _ := stream.Read(b)
Expect(n).To(Equal(6))
Expect(b[0:6]).To(Equal([]byte("foobar")))
})
It("saves if the stream was read from", func() {
Expect(rb.requestRead).To(BeFalse())
rb.Read(make([]byte, 1))
Expect(rb.requestRead).To(BeTrue())
})
It("doesn't close the stream when closing the request body", func() {
Expect(stream.closed).To(BeFalse())
err := rb.Close()
Expect(err).ToNot(HaveOccurred())
Expect(stream.closed).To(BeFalse())
})
})

View File

@ -1,121 +0,0 @@
package h2quic
import (
"net/http"
"net/url"
"golang.org/x/net/http2/hpack"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Request", func() {
It("populates request", func() {
headers := []hpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":authority", Value: "quic.clemente.io"},
{Name: ":method", Value: "GET"},
{Name: "content-length", Value: "42"},
}
req, err := requestFromHeaders(headers)
Expect(err).NotTo(HaveOccurred())
Expect(req.Method).To(Equal("GET"))
Expect(req.URL.Path).To(Equal("/foo"))
Expect(req.Proto).To(Equal("HTTP/2.0"))
Expect(req.ProtoMajor).To(Equal(2))
Expect(req.ProtoMinor).To(Equal(0))
Expect(req.ContentLength).To(Equal(int64(42)))
Expect(req.Header).To(BeEmpty())
Expect(req.Body).To(BeNil())
Expect(req.Host).To(Equal("quic.clemente.io"))
Expect(req.RequestURI).To(Equal("/foo"))
Expect(req.TLS).ToNot(BeNil())
})
It("concatenates the cookie headers", func() {
headers := []hpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":authority", Value: "quic.clemente.io"},
{Name: ":method", Value: "GET"},
{Name: "cookie", Value: "cookie1=foobar1"},
{Name: "cookie", Value: "cookie2=foobar2"},
}
req, err := requestFromHeaders(headers)
Expect(err).NotTo(HaveOccurred())
Expect(req.Header).To(Equal(http.Header{
"Cookie": []string{"cookie1=foobar1; cookie2=foobar2"},
}))
})
It("handles other headers", func() {
headers := []hpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":authority", Value: "quic.clemente.io"},
{Name: ":method", Value: "GET"},
{Name: "cache-control", Value: "max-age=0"},
{Name: "duplicate-header", Value: "1"},
{Name: "duplicate-header", Value: "2"},
}
req, err := requestFromHeaders(headers)
Expect(err).NotTo(HaveOccurred())
Expect(req.Header).To(Equal(http.Header{
"Cache-Control": []string{"max-age=0"},
"Duplicate-Header": []string{"1", "2"},
}))
})
It("errors with missing path", func() {
headers := []hpack.HeaderField{
{Name: ":authority", Value: "quic.clemente.io"},
{Name: ":method", Value: "GET"},
}
_, err := requestFromHeaders(headers)
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
})
It("errors with missing method", func() {
headers := []hpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":authority", Value: "quic.clemente.io"},
}
_, err := requestFromHeaders(headers)
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
})
It("errors with missing authority", func() {
headers := []hpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":method", Value: "GET"},
}
_, err := requestFromHeaders(headers)
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
})
Context("extracting the hostname from a request", func() {
var url *url.URL
BeforeEach(func() {
var err error
url, err = url.Parse("https://quic.clemente.io:1337")
Expect(err).ToNot(HaveOccurred())
})
It("uses req.URL.Host", func() {
req := &http.Request{URL: url}
Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337"))
})
It("uses req.URL.Host even if req.Host is available", func() {
req := &http.Request{
Host: "www.example.org",
URL: url,
}
Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337"))
})
It("returns an empty hostname if nothing is set", func() {
Expect(hostnameFromRequest(&http.Request{})).To(BeEmpty())
})
})
})

View File

@ -1,118 +0,0 @@
package h2quic
import (
"bytes"
"net/http"
"net/url"
"strconv"
"strings"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Request", func() {
var (
rw *requestWriter
headerStream *mockStream
decoder *hpack.Decoder
)
BeforeEach(func() {
headerStream = &mockStream{}
rw = newRequestWriter(headerStream, utils.DefaultLogger)
decoder = hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
})
decode := func(p []byte) (*http2.HeadersFrame, map[string] /* HeaderField.Name */ string /* HeaderField.Value */) {
framer := http2.NewFramer(nil, bytes.NewReader(p))
frame, err := framer.ReadFrame()
Expect(err).ToNot(HaveOccurred())
headerFrame := frame.(*http2.HeadersFrame)
fields, err := decoder.DecodeFull(headerFrame.HeaderBlockFragment())
Expect(err).ToNot(HaveOccurred())
values := make(map[string]string)
for _, headerField := range fields {
values[headerField.Name] = headerField.Value
}
return headerFrame, values
}
It("writes a GET request", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
Expect(err).ToNot(HaveOccurred())
rw.WriteRequest(req, 1337, true, false)
headerFrame, headerFields := decode(headerStream.dataWritten.Bytes())
Expect(headerFrame.StreamID).To(Equal(uint32(1337)))
Expect(headerFrame.HasPriority()).To(BeTrue())
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
Expect(headerFields).To(HaveKeyWithValue(":path", "/index.html?foo=bar"))
Expect(headerFields).To(HaveKeyWithValue(":scheme", "https"))
Expect(headerFields).ToNot(HaveKey("accept-encoding"))
})
It("sets the EndStream header", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
rw.WriteRequest(req, 1337, true, false)
headerFrame, _ := decode(headerStream.dataWritten.Bytes())
Expect(headerFrame.StreamEnded()).To(BeTrue())
})
It("doesn't set the EndStream header, if requested", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
rw.WriteRequest(req, 1337, false, false)
headerFrame, _ := decode(headerStream.dataWritten.Bytes())
Expect(headerFrame.StreamEnded()).To(BeFalse())
})
It("requests gzip compression, if requested", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
Expect(err).ToNot(HaveOccurred())
rw.WriteRequest(req, 1337, true, true)
_, headerFields := decode(headerStream.dataWritten.Bytes())
Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip"))
})
It("writes a POST request", func() {
form := url.Values{}
form.Add("foo", "bar")
req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", strings.NewReader(form.Encode()))
Expect(err).ToNot(HaveOccurred())
rw.WriteRequest(req, 5, true, false)
_, headerFields := decode(headerStream.dataWritten.Bytes())
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
Expect(headerFields).To(HaveKey("content-length"))
contentLength, err := strconv.Atoi(headerFields["content-length"])
Expect(err).ToNot(HaveOccurred())
Expect(contentLength).To(BeNumerically(">", 0))
})
It("sends cookies", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
cookie1 := &http.Cookie{
Name: "Cookie #1",
Value: "Value #1",
}
cookie2 := &http.Cookie{
Name: "Cookie #2",
Value: "Value #2",
}
req.AddCookie(cookie1)
req.AddCookie(cookie2)
rw.WriteRequest(req, 11, true, false)
_, headerFields := decode(headerStream.dataWritten.Bytes())
// TODO(lclemente): Remove Or() once we drop support for Go 1.8.
Expect(headerFields).To(Or(
HaveKeyWithValue("cookie", "Cookie #1=Value #1; Cookie #2=Value #2"),
HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`),
))
})
})

View File

@ -1,163 +0,0 @@
package h2quic
import (
"bytes"
"context"
"io"
"net/http"
"sync"
"time"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockStream struct {
id protocol.StreamID
dataToRead bytes.Buffer
dataWritten bytes.Buffer
reset bool
canceledWrite bool
closed bool
remoteClosed bool
unblockRead chan struct{}
ctx context.Context
ctxCancel context.CancelFunc
}
var _ quic.Stream = &mockStream{}
func newMockStream(id protocol.StreamID) *mockStream {
s := &mockStream{
id: id,
unblockRead: make(chan struct{}),
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
return s
}
func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil }
func (s *mockStream) CancelRead(quic.ErrorCode) error { s.reset = true; return nil }
func (s *mockStream) CancelWrite(quic.ErrorCode) error { s.canceledWrite = true; return nil }
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() }
func (s mockStream) StreamID() protocol.StreamID { return s.id }
func (s *mockStream) Context() context.Context { return s.ctx }
func (s *mockStream) SetDeadline(time.Time) error { panic("not implemented") }
func (s *mockStream) SetReadDeadline(time.Time) error { panic("not implemented") }
func (s *mockStream) SetWriteDeadline(time.Time) error { panic("not implemented") }
func (s *mockStream) Read(p []byte) (int, error) {
n, _ := s.dataToRead.Read(p)
if n == 0 { // block if there's no data
<-s.unblockRead
return 0, io.EOF
}
return n, nil // never return an EOF
}
func (s *mockStream) Write(p []byte) (int, error) { return s.dataWritten.Write(p) }
var _ = Describe("Response Writer", func() {
var (
w *responseWriter
headerStream *mockStream
dataStream *mockStream
)
BeforeEach(func() {
headerStream = &mockStream{}
dataStream = &mockStream{}
w = newResponseWriter(headerStream, &sync.Mutex{}, dataStream, 5, utils.DefaultLogger)
})
decodeHeaderFields := func() map[string][]string {
fields := make(map[string][]string)
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, bytes.NewReader(headerStream.dataWritten.Bytes()))
frame, err := h2framer.ReadFrame()
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&http2.HeadersFrame{}))
hframe := frame.(*http2.HeadersFrame)
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
Expect(mhframe.StreamID).To(BeEquivalentTo(5))
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
Expect(err).ToNot(HaveOccurred())
for _, p := range mhframe.Fields {
fields[p.Name] = append(fields[p.Name], p.Value)
}
return fields
}
It("writes status", func() {
w.WriteHeader(http.StatusTeapot)
fields := decodeHeaderFields()
Expect(fields).To(HaveLen(1))
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
})
It("writes headers", func() {
w.Header().Add("content-length", "42")
w.WriteHeader(http.StatusTeapot)
fields := decodeHeaderFields()
Expect(fields).To(HaveKeyWithValue("content-length", []string{"42"}))
})
It("writes multiple headers with the same name", func() {
const cookie1 = "test1=1; Max-Age=7200; path=/"
const cookie2 = "test2=2; Max-Age=7200; path=/"
w.Header().Add("set-cookie", cookie1)
w.Header().Add("set-cookie", cookie2)
w.WriteHeader(http.StatusTeapot)
fields := decodeHeaderFields()
Expect(fields).To(HaveKey("set-cookie"))
cookies := fields["set-cookie"]
Expect(cookies).To(ContainElement(cookie1))
Expect(cookies).To(ContainElement(cookie2))
})
It("writes data", func() {
n, err := w.Write([]byte("foobar"))
Expect(n).To(Equal(6))
Expect(err).ToNot(HaveOccurred())
// Should have written 200 on the header stream
fields := decodeHeaderFields()
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
// And foobar on the data stream
Expect(dataStream.dataWritten.Bytes()).To(Equal([]byte("foobar")))
})
It("writes data after WriteHeader is called", func() {
w.WriteHeader(http.StatusTeapot)
n, err := w.Write([]byte("foobar"))
Expect(n).To(Equal(6))
Expect(err).ToNot(HaveOccurred())
// Should have written 418 on the header stream
fields := decodeHeaderFields()
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
// And foobar on the data stream
Expect(dataStream.dataWritten.Bytes()).To(Equal([]byte("foobar")))
})
It("does not WriteHeader() twice", func() {
w.WriteHeader(200)
w.WriteHeader(500)
fields := decodeHeaderFields()
Expect(fields).To(HaveLen(1))
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
})
It("doesn't allow writes if the status code doesn't allow a body", func() {
w.WriteHeader(304)
n, err := w.Write([]byte("foobar"))
Expect(n).To(BeZero())
Expect(err).To(MatchError(http.ErrBodyNotAllowed))
Expect(dataStream.dataWritten.Bytes()).To(HaveLen(0))
})
})

View File

@ -1,218 +0,0 @@
package h2quic
import (
"bytes"
"crypto/tls"
"errors"
"io"
"net/http"
"time"
quic "github.com/lucas-clemente/quic-go"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockClient struct {
closed bool
}
func (m *mockClient) RoundTrip(req *http.Request) (*http.Response, error) {
return &http.Response{Request: req}, nil
}
func (m *mockClient) Close() error {
m.closed = true
return nil
}
var _ roundTripCloser = &mockClient{}
type mockBody struct {
reader bytes.Reader
readErr error
closeErr error
closed bool
}
func (m *mockBody) Read(p []byte) (int, error) {
if m.readErr != nil {
return 0, m.readErr
}
return m.reader.Read(p)
}
func (m *mockBody) SetData(data []byte) {
m.reader = *bytes.NewReader(data)
}
func (m *mockBody) Close() error {
m.closed = true
return m.closeErr
}
// make sure the mockBody can be used as a http.Request.Body
var _ io.ReadCloser = &mockBody{}
var _ = Describe("RoundTripper", func() {
var (
rt *RoundTripper
req1 *http.Request
)
BeforeEach(func() {
rt = &RoundTripper{}
var err error
req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
Expect(err).ToNot(HaveOccurred())
})
Context("dialing hosts", func() {
origDialAddr := dialAddr
streamOpenErr := errors.New("error opening stream")
BeforeEach(func() {
origDialAddr = dialAddr
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
// return an error when trying to open a stream
// we don't want to test all the dial logic here, just that dialing happens at all
return &mockSession{streamOpenErr: streamOpenErr}, nil
}
})
AfterEach(func() {
dialAddr = origDialAddr
})
It("creates new clients", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req)
Expect(err).To(MatchError(streamOpenErr))
Expect(rt.clients).To(HaveLen(1))
})
It("uses the quic.Config, if provided", func() {
config := &quic.Config{HandshakeTimeout: time.Millisecond}
var receivedConfig *quic.Config
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
receivedConfig = config
return nil, errors.New("err")
}
rt.QuicConfig = config
rt.RoundTrip(req1)
Expect(receivedConfig).To(Equal(config))
})
It("uses the custom dialer, if provided", func() {
var dialed bool
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
dialed = true
return nil, errors.New("err")
}
rt.Dial = dialer
rt.RoundTrip(req1)
Expect(dialed).To(BeTrue())
})
It("reuses existing clients", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req)
Expect(err).To(MatchError(streamOpenErr))
Expect(rt.clients).To(HaveLen(1))
req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req2)
Expect(err).To(MatchError(streamOpenErr))
Expect(rt.clients).To(HaveLen(1))
})
It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true})
Expect(err).To(MatchError(ErrNoCachedConn))
})
})
Context("validating request", func() {
It("rejects plain HTTP requests", func() {
req, err := http.NewRequest("GET", "http://www.example.org/", nil)
req.Body = &mockBody{}
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req)
Expect(err).To(MatchError("quic: unsupported protocol scheme: http"))
Expect(req.Body.(*mockBody).closed).To(BeTrue())
})
It("rejects requests without a URL", func() {
req1.URL = nil
req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: nil Request.URL"))
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
})
It("rejects request without a URL Host", func() {
req1.URL.Host = ""
req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: no Host in request URL"))
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
})
It("doesn't try to close the body if the request doesn't have one", func() {
req1.URL = nil
Expect(req1.Body).To(BeNil())
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: nil Request.URL"))
})
It("rejects requests without a header", func() {
req1.Header = nil
req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: nil Request.Header"))
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
})
It("rejects requests with invalid header name fields", func() {
req1.Header.Add("foobär", "value")
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: invalid http header field name \"foobär\""))
})
It("rejects requests with invalid header name values", func() {
req1.Header.Add("foo", string([]byte{0x7}))
_, err := rt.RoundTrip(req1)
Expect(err.Error()).To(ContainSubstring("quic: invalid http header field value"))
})
It("rejects requests with an invalid request method", func() {
req1.Method = "foobär"
req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: invalid method \"foobär\""))
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
})
})
Context("closing", func() {
It("closes", func() {
rt.clients = make(map[string]roundTripCloser)
cl := &mockClient{}
rt.clients["foo.bar"] = cl
err := rt.Close()
Expect(err).ToNot(HaveOccurred())
Expect(len(rt.clients)).To(BeZero())
Expect(cl.closed).To(BeTrue())
})
It("closes a RoundTripper that has never been used", func() {
Expect(len(rt.clients)).To(BeZero())
err := rt.Close()
Expect(err).ToNot(HaveOccurred())
Expect(len(rt.clients)).To(BeZero())
})
})
})

View File

@ -1,536 +0,0 @@
package h2quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockSession struct {
closed bool
closedWithError error
dataStream quic.Stream
streamToAccept quic.Stream
streamsToOpen []quic.Stream
blockOpenStreamSync bool
blockOpenStreamChan chan struct{} // close this chan (or call Close) to make OpenStreamSync return
streamOpenErr error
ctx context.Context
ctxCancel context.CancelFunc
}
func newMockSession() *mockSession {
return &mockSession{blockOpenStreamChan: make(chan struct{})}
}
func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (quic.Stream, error) {
return s.dataStream, nil
}
func (s *mockSession) AcceptStream() (quic.Stream, error) { return s.streamToAccept, nil }
func (s *mockSession) OpenStream() (quic.Stream, error) {
if s.streamOpenErr != nil {
return nil, s.streamOpenErr
}
str := s.streamsToOpen[0]
s.streamsToOpen = s.streamsToOpen[1:]
return str, nil
}
func (s *mockSession) OpenStreamSync() (quic.Stream, error) {
if s.blockOpenStreamSync {
<-s.blockOpenStreamChan
}
return s.OpenStream()
}
func (s *mockSession) Close() error {
s.ctxCancel()
if !s.closed {
close(s.blockOpenStreamChan)
}
s.closed = true
return nil
}
func (s *mockSession) CloseWithError(_ quic.ErrorCode, e error) error {
s.closedWithError = e
return s.Close()
}
func (s *mockSession) LocalAddr() net.Addr {
panic("not implemented")
}
func (s *mockSession) RemoteAddr() net.Addr {
return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42}
}
func (s *mockSession) Context() context.Context {
return s.ctx
}
func (s *mockSession) ConnectionState() quic.ConnectionState { panic("not implemented") }
func (s *mockSession) AcceptUniStream() (quic.ReceiveStream, error) { panic("not implemented") }
func (s *mockSession) OpenUniStream() (quic.SendStream, error) { panic("not implemented") }
func (s *mockSession) OpenUniStreamSync() (quic.SendStream, error) { panic("not implemented") }
var _ = Describe("H2 server", func() {
var (
s *Server
session *mockSession
dataStream *mockStream
origQuicListenAddr = quicListenAddr
)
BeforeEach(func() {
s = &Server{
Server: &http.Server{
TLSConfig: testdata.GetTLSConfig(),
},
logger: utils.DefaultLogger,
}
dataStream = newMockStream(0)
close(dataStream.unblockRead)
session = newMockSession()
session.dataStream = dataStream
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
origQuicListenAddr = quicListenAddr
})
AfterEach(func() {
quicListenAddr = origQuicListenAddr
})
Context("handling requests", func() {
var (
h2framer *http2.Framer
hpackDecoder *hpack.Decoder
headerStream *mockStream
)
BeforeEach(func() {
headerStream = &mockStream{}
hpackDecoder = hpack.NewDecoder(4096, nil)
h2framer = http2.NewFramer(nil, headerStream)
})
It("handles a sample GET request", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
Expect(r.Host).To(Equal("www.example.com"))
Expect(r.RemoteAddr).To(Equal("127.0.0.1:42"))
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
Expect(dataStream.remoteClosed).To(BeTrue())
Expect(dataStream.reset).To(BeFalse())
})
It("returns 200 with an empty handler", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() []byte {
return headerStream.dataWritten.Bytes()
}).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x88})) // 0x88 is 200
})
It("correctly handles a panicking handler", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("foobar")
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() []byte {
return headerStream.dataWritten.Bytes()
}).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x8e})) // 0x82 is 500
})
It("resets the dataStream when client sends a body in GET request", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
Eventually(func() bool { return dataStream.reset }).Should(BeTrue())
Expect(dataStream.remoteClosed).To(BeFalse())
})
It("resets the dataStream when the body of POST request is not read", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
Expect(r.Method).To(Equal("POST"))
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return dataStream.reset }).Should(BeTrue())
Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse())
Expect(handlerCalled).To(BeTrue())
})
It("handles a request for which the client immediately resets the data stream", func() {
session.dataStream = nil
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Consistently(func() bool { return handlerCalled }).Should(BeFalse())
})
It("resets the dataStream when the body of POST request is not read, and the request handler replaces the request.Body", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Body = struct {
io.Reader
io.Closer
}{}
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return dataStream.reset }).Should(BeTrue())
Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse())
Expect(handlerCalled).To(BeTrue())
})
It("closes the dataStream if the body of POST request was read", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
Expect(r.Method).To(Equal("POST"))
handlerCalled = true
// read the request body
b := make([]byte, 1000)
n, _ := r.Body.Read(b)
Expect(n).ToNot(BeZero())
})
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
dataStream.dataToRead.Write([]byte("foo=bar"))
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
Expect(dataStream.reset).To(BeFalse())
})
It("ignores PRIORITY frames", func() {
handlerCalled := make(chan struct{})
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(handlerCalled)
})
buf := &bytes.Buffer{}
framer := http2.NewFramer(buf, nil)
err := framer.WritePriority(10, http2.PriorityParam{Weight: 42})
Expect(err).ToNot(HaveOccurred())
Expect(buf.Bytes()).ToNot(BeEmpty())
headerStream.dataToRead.Write(buf.Bytes())
err = s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).ToNot(HaveOccurred())
Consistently(handlerCalled).ShouldNot(BeClosed())
Expect(dataStream.reset).To(BeFalse())
Expect(dataStream.closed).To(BeFalse())
})
It("errors when non-header frames are received", func() {
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x06, 0x0, 0x0, 0x0, 0x0, 0x0, 0x5,
'f', 'o', 'o', 'b', 'a', 'r',
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).To(MatchError("InvalidHeadersStreamData: expected a header frame"))
})
It("Cancels the request context when the datstream is closed", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
err := r.Context().Err()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("context canceled"))
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
dataStream.Close()
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
Expect(dataStream.remoteClosed).To(BeTrue())
Expect(dataStream.reset).To(BeFalse())
})
})
It("handles the header stream", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
handlerCalled = true
})
headerStream := &mockStream{id: 3}
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
session.streamToAccept = headerStream
go s.handleHeaderStream(session)
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
})
It("closes the connection if it encounters an error on the header stream", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
})
headerStream := &mockStream{id: 3}
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
session.streamToAccept = headerStream
go s.handleHeaderStream(session)
Consistently(func() bool { return handlerCalled }).Should(BeFalse())
Eventually(func() bool { return session.closed }).Should(BeTrue())
Expect(session.closedWithError).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")))
})
It("supports closing after first request", func() {
s.CloseAfterFirstRequest = true
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
headerStream := &mockStream{id: 3}
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
session.streamToAccept = headerStream
Expect(session.closed).To(BeFalse())
go s.handleHeaderStream(session)
Eventually(func() bool { return session.closed }).Should(BeTrue())
})
It("uses the default handler as fallback", func() {
var handlerCalled bool
http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
handlerCalled = true
}))
headerStream := &mockStream{id: 3}
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
session.streamToAccept = headerStream
go s.handleHeaderStream(session)
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
})
Context("setting http headers", func() {
var expected http.Header
getExpectedHeader := func(versions []protocol.VersionNumber) http.Header {
var versionsAsString []string
for _, v := range versions {
versionsAsString = append(versionsAsString, v.ToAltSvc())
}
return http.Header{
"Alt-Svc": {fmt.Sprintf(`quic=":443"; ma=2592000; v="%s"`, strings.Join(versionsAsString, ","))},
}
}
BeforeEach(func() {
Expect(getExpectedHeader([]protocol.VersionNumber{99, 90, 9})).To(Equal(http.Header{"Alt-Svc": {`quic=":443"; ma=2592000; v="99,90,9"`}}))
expected = getExpectedHeader(protocol.SupportedVersions)
})
It("sets proper headers with numeric port", func() {
s.Server.Addr = ":443"
hdr := http.Header{}
err := s.SetQuicHeaders(hdr)
Expect(err).NotTo(HaveOccurred())
Expect(hdr).To(Equal(expected))
})
It("sets proper headers with full addr", func() {
s.Server.Addr = "127.0.0.1:443"
hdr := http.Header{}
err := s.SetQuicHeaders(hdr)
Expect(err).NotTo(HaveOccurred())
Expect(hdr).To(Equal(expected))
})
It("sets proper headers with string port", func() {
s.Server.Addr = ":https"
hdr := http.Header{}
err := s.SetQuicHeaders(hdr)
Expect(err).NotTo(HaveOccurred())
Expect(hdr).To(Equal(expected))
})
It("works multiple times", func() {
s.Server.Addr = ":https"
hdr := http.Header{}
err := s.SetQuicHeaders(hdr)
Expect(err).NotTo(HaveOccurred())
Expect(hdr).To(Equal(expected))
hdr = http.Header{}
err = s.SetQuicHeaders(hdr)
Expect(err).NotTo(HaveOccurred())
Expect(hdr).To(Equal(expected))
})
})
It("should error when ListenAndServe is called with s.Server nil", func() {
err := (&Server{}).ListenAndServe()
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
})
It("should error when ListenAndServeTLS is called with s.Server nil", func() {
err := (&Server{}).ListenAndServeTLS(testdata.GetCertificatePaths())
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
})
It("should nop-Close() when s.server is nil", func() {
err := (&Server{}).Close()
Expect(err).NotTo(HaveOccurred())
})
It("errors when ListenAndServer is called after Close", func() {
serv := &Server{Server: &http.Server{}}
Expect(serv.Close()).To(Succeed())
err := serv.ListenAndServe()
Expect(err).To(MatchError("Server is already closed"))
})
Context("ListenAndServe", func() {
BeforeEach(func() {
s.Server.Addr = "localhost:0"
})
AfterEach(func() {
Expect(s.Close()).To(Succeed())
})
It("may only be called once", func() {
cErr := make(chan error)
for i := 0; i < 2; i++ {
go func() {
defer GinkgoRecover()
err := s.ListenAndServe()
if err != nil {
cErr <- err
}
}()
}
err := <-cErr
Expect(err).To(MatchError("ListenAndServe may only be called once"))
err = s.Close()
Expect(err).NotTo(HaveOccurred())
}, 0.5)
It("uses the quic.Config to start the quic server", func() {
conf := &quic.Config{HandshakeTimeout: time.Nanosecond}
var receivedConf *quic.Config
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
receivedConf = config
return nil, errors.New("listen err")
}
s.QuicConfig = conf
go s.ListenAndServe()
Eventually(func() *quic.Config { return receivedConf }).Should(Equal(conf))
})
})
Context("ListenAndServeTLS", func() {
BeforeEach(func() {
s.Server.Addr = "localhost:0"
})
AfterEach(func() {
err := s.Close()
Expect(err).NotTo(HaveOccurred())
})
It("may only be called once", func() {
cErr := make(chan error)
for i := 0; i < 2; i++ {
go func() {
defer GinkgoRecover()
err := s.ListenAndServeTLS(testdata.GetCertificatePaths())
if err != nil {
cErr <- err
}
}()
}
err := <-cErr
Expect(err).To(MatchError("ListenAndServe may only be called once"))
err = s.Close()
Expect(err).NotTo(HaveOccurred())
}, 0.5)
})
It("closes gracefully", func() {
err := s.CloseGracefully(0)
Expect(err).NotTo(HaveOccurred())
})
It("errors when listening fails", func() {
testErr := errors.New("listen error")
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
return nil, testErr
}
fullpem, privkey := testdata.GetCertificatePaths()
err := ListenAndServeQUIC("", fullpem, privkey, nil)
Expect(err).To(MatchError(testErr))
})
})

View File

@ -1,210 +0,0 @@
package chrome_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"sync/atomic"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gexec"
"testing"
)
const (
dataLen = 500 * 1024 // 500 KB
dataLongLen = 50 * 1024 * 1024 // 50 MB
)
var (
nFilesUploaded int32 // should be used atomically
doneCalled utils.AtomicBool
)
func TestChrome(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Chrome Suite")
}
func init() {
// Requires the len & num GET parameters, e.g. /uploadtest?len=100&num=1
http.HandleFunc("/uploadtest", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
response := uploadHTML
response = strings.Replace(response, "LENGTH", r.URL.Query().Get("len"), -1)
response = strings.Replace(response, "NUM", r.URL.Query().Get("num"), -1)
_, err := io.WriteString(w, response)
Expect(err).NotTo(HaveOccurred())
})
// Requires the len & num GET parameters, e.g. /downloadtest?len=100&num=1
http.HandleFunc("/downloadtest", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
response := downloadHTML
response = strings.Replace(response, "LENGTH", r.URL.Query().Get("len"), -1)
response = strings.Replace(response, "NUM", r.URL.Query().Get("num"), -1)
_, err := io.WriteString(w, response)
Expect(err).NotTo(HaveOccurred())
})
http.HandleFunc("/uploadhandler", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
l, err := strconv.Atoi(r.URL.Query().Get("len"))
Expect(err).NotTo(HaveOccurred())
defer r.Body.Close()
actual, err := ioutil.ReadAll(r.Body)
Expect(err).NotTo(HaveOccurred())
Expect(bytes.Equal(actual, testserver.GeneratePRData(l))).To(BeTrue())
atomic.AddInt32(&nFilesUploaded, 1)
})
http.HandleFunc("/done", func(w http.ResponseWriter, r *http.Request) {
doneCalled.Set(true)
})
}
var _ = AfterEach(func() {
testserver.StopQuicServer()
atomic.StoreInt32(&nFilesUploaded, 0)
doneCalled.Set(false)
})
func getChromePath() string {
if runtime.GOOS == "darwin" {
return "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
}
if path, err := exec.LookPath("google-chrome"); err == nil {
return path
}
if path, err := exec.LookPath("chromium-browser"); err == nil {
return path
}
Fail("No Chrome executable found.")
return ""
}
func chromeTest(version protocol.VersionNumber, url string, blockUntilDone func()) {
userDataDir, err := ioutil.TempDir("", "quic-go-test-chrome-dir")
Expect(err).NotTo(HaveOccurred())
defer os.RemoveAll(userDataDir)
path := getChromePath()
args := []string{
"--disable-gpu",
"--no-first-run=true",
"--no-default-browser-check=true",
"--user-data-dir=" + userDataDir,
"--enable-quic=true",
"--no-proxy-server=true",
"--no-sandbox",
"--origin-to-force-quic-on=quic.clemente.io:443",
fmt.Sprintf(`--host-resolver-rules=MAP quic.clemente.io:443 127.0.0.1:%s`, testserver.Port()),
fmt.Sprintf("--quic-version=QUIC_VERSION_%s", version.ToAltSvc()),
url,
}
utils.DefaultLogger.Infof("Running chrome: %s '%s'", getChromePath(), strings.Join(args, "' '"))
command := exec.Command(path, args...)
session, err := gexec.Start(command, nil, nil)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
blockUntilDone()
}
func waitForDone() {
Eventually(func() bool { return doneCalled.Get() }, 60).Should(BeTrue())
}
func waitForNUploaded(expected int) func() {
return func() {
Eventually(func() int32 {
return atomic.LoadInt32(&nFilesUploaded)
}, 60).Should(BeEquivalentTo(expected))
}
}
const commonJS = `
var buf = new ArrayBuffer(LENGTH);
var prng = new Uint8Array(buf);
var seed = 1;
for (var i = 0; i < LENGTH; i++) {
// https://en.wikipedia.org/wiki/Lehmer_random_number_generator
seed = seed * 48271 % 2147483647;
prng[i] = seed;
}
`
const uploadHTML = `
<html>
<body>
<script>
console.log("Running DL test...");
` + commonJS + `
for (var i = 0; i < NUM; i++) {
var req = new XMLHttpRequest();
req.open("POST", "/uploadhandler?len=" + LENGTH, true);
req.send(buf);
}
</script>
</body>
</html>
`
const downloadHTML = `
<html>
<body>
<script>
console.log("Running DL test...");
` + commonJS + `
function verify(data) {
if (data.length !== LENGTH) return false;
for (var i = 0; i < LENGTH; i++) {
if (data[i] !== prng[i]) return false;
}
return true;
}
var nOK = 0;
for (var i = 0; i < NUM; i++) {
let req = new XMLHttpRequest();
req.responseType = "arraybuffer";
req.open("POST", "/prdata?len=" + LENGTH, true);
req.onreadystatechange = function () {
if (req.readyState === XMLHttpRequest.DONE && req.status === 200) {
if (verify(new Uint8Array(req.response))) {
nOK++;
if (nOK === NUM) {
console.log("Done :)");
var reqDone = new XMLHttpRequest();
reqDone.open("GET", "/done");
reqDone.send();
}
}
}
};
req.send();
}
</script>
</body>
</html>
`

View File

@ -1,76 +0,0 @@
package chrome_test
import (
"fmt"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
)
var _ = Describe("Chrome tests", func() {
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
// TODO: activate Chrome integration tests with gQUIC 44
if version == protocol.Version44 {
continue
}
Context(fmt.Sprintf("with version %s", version), func() {
JustBeforeEach(func() {
testserver.StartQuicServer([]protocol.VersionNumber{version})
})
It("downloads a small file", func() {
chromeTest(
version,
fmt.Sprintf("https://quic.clemente.io/downloadtest?num=1&len=%d", dataLen),
waitForDone,
)
})
It("downloads a large file", func() {
chromeTest(
version,
fmt.Sprintf("https://quic.clemente.io/downloadtest?num=1&len=%d", dataLongLen),
waitForDone,
)
})
It("loads a large number of files", func() {
chromeTest(
version,
"https://quic.clemente.io/downloadtest?num=4&len=100",
waitForDone,
)
})
It("uploads a small file", func() {
chromeTest(
version,
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=1&len=%d", dataLen),
waitForNUploaded(1),
)
})
It("uploads a large file", func() {
chromeTest(
version,
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=1&len=%d", dataLongLen),
waitForNUploaded(1),
)
})
It("uploads many small files", func() {
num := protocol.DefaultMaxIncomingStreams + 20
chromeTest(
version,
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=%d&len=%d", num, dataLen),
waitForNUploaded(num),
)
})
})
}
})

View File

@ -1,137 +0,0 @@
package gquic_test
import (
"bytes"
"fmt"
mrand "math/rand"
"os/exec"
"strconv"
_ "github.com/lucas-clemente/quic-clients" // download clients
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
. "github.com/onsi/gomega/gbytes"
. "github.com/onsi/gomega/gexec"
)
var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
var _ = Describe("Drop tests", func() {
var proxy *quicproxy.QuicProxy
startProxy := func(dropCallback quicproxy.DropCallback, version protocol.VersionNumber) {
var err error
proxy, err = quicproxy.NewQuicProxy("localhost:0", version, &quicproxy.Opts{
RemoteAddr: "localhost:" + testserver.Port(),
DropPacket: dropCallback,
})
Expect(err).ToNot(HaveOccurred())
}
downloadFile := func(version protocol.VersionNumber) {
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+strconv.Itoa(proxy.LocalPort()),
"https://quic.clemente.io/prdata",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 20).Should(Exit(0))
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
}
downloadHello := func(version protocol.VersionNumber) {
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+strconv.Itoa(proxy.LocalPort()),
"https://quic.clemente.io/hello",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 20).Should(Exit(0))
Expect(session.Out).To(Say(":status 200"))
Expect(session.Out).To(Say("body: Hello, World!\n"))
}
deterministicDropper := func(p, interval, dropInARow uint64) bool {
return (p % interval) < dropInARow
}
stochasticDropper := func(freq int) bool {
return mrand.Int63n(int64(freq)) == 0
}
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
})
for _, v := range protocol.SupportedVersions {
version := v
Context(fmt.Sprintf("with QUIC version %s", version), func() {
Context("during the crypto handshake", func() {
for _, d := range directions {
direction := d
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return p == 1 && d.Is(direction)
}, version)
downloadHello(version)
})
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return p == 2 && d.Is(direction)
}, version)
downloadHello(version)
})
It(fmt.Sprintf("establishes a connection when 1/5 of the packets are lost in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return d.Is(direction) && stochasticDropper(5)
}, version)
downloadHello(version)
})
}
})
Context("after the crypto handshake", func() {
for _, d := range directions {
direction := d
It(fmt.Sprintf("downloads a file when every 5th packet is dropped in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return p >= 10 && d.Is(direction) && deterministicDropper(p, 5, 1)
}, version)
downloadFile(version)
})
It(fmt.Sprintf("downloads a file when 1/5th of all packet are dropped randomly in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return p >= 10 && d.Is(direction) && stochasticDropper(5)
}, version)
downloadFile(version)
})
It(fmt.Sprintf("downloads a file when 10 packets every 100 packet are dropped in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return p >= 10 && d.Is(direction) && deterministicDropper(p, 100, 10)
}, version)
downloadFile(version)
})
}
})
})
}
})

View File

@ -1,45 +0,0 @@
package gquic_test
import (
"fmt"
"math/rand"
"path/filepath"
"runtime"
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
var (
clientPath string
serverPath string
)
func TestIntegration(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "GQuic Tests Suite")
}
var _ = BeforeSuite(func() {
rand.Seed(GinkgoRandomSeed())
})
var _ = JustBeforeEach(func() {
testserver.StartQuicServer(nil)
})
var _ = AfterEach(testserver.StopQuicServer)
func init() {
_, thisfile, _, ok := runtime.Caller(0)
if !ok {
panic("Failed to get current path")
}
clientPath = filepath.Join(thisfile, fmt.Sprintf("../../../../quic-clients/client-%s-debug", runtime.GOOS))
serverPath = filepath.Join(thisfile, fmt.Sprintf("../../../../quic-clients/server-%s-debug", runtime.GOOS))
}

View File

@ -1,98 +0,0 @@
package gquic_test
import (
"bytes"
"fmt"
"os/exec"
"sync"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
_ "github.com/lucas-clemente/quic-clients" // download clients
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
. "github.com/onsi/gomega/gbytes"
. "github.com/onsi/gomega/gexec"
)
var _ = Describe("Integration tests", func() {
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
Context(fmt.Sprintf("with QUIC version %s", version), func() {
It("gets a simple file", func() {
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+testserver.Port(),
"https://quic.clemente.io/hello",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 5).Should(Exit(0))
Expect(session.Out).To(Say(":status 200"))
Expect(session.Out).To(Say("body: Hello, World!\n"))
})
It("posts and reads a body", func() {
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+testserver.Port(),
"--body=foo",
"https://quic.clemente.io/echo",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 5).Should(Exit(0))
Expect(session.Out).To(Say(":status 200"))
Expect(session.Out).To(Say("body: foo\n"))
})
It("gets a file", func() {
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+testserver.Port(),
"https://quic.clemente.io/prdata",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 10).Should(Exit(0))
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
})
It("gets many copies of a file in parallel", func() {
wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+testserver.Port(),
"https://quic.clemente.io/prdata",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 20).Should(Exit(0))
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
}()
}
wg.Wait()
})
})
}
})

View File

@ -1,95 +0,0 @@
package gquic_test
import (
"bytes"
"fmt"
"math/rand"
"os/exec"
"strconv"
"time"
_ "github.com/lucas-clemente/quic-clients" // download clients
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
. "github.com/onsi/gomega/gexec"
)
// get a random duration between min and max
func getRandomDuration(min, max time.Duration) time.Duration {
return min + time.Duration(rand.Int63n(int64(max-min)))
}
var _ = Describe("Random Duration Generator", func() {
It("gets a random RTT", func() {
var min time.Duration = time.Hour
var max time.Duration
var sum time.Duration
rep := 10000
for i := 0; i < rep; i++ {
val := getRandomDuration(100*time.Millisecond, 500*time.Millisecond)
sum += val
if val < min {
min = val
}
if val > max {
max = val
}
}
avg := sum / time.Duration(rep)
Expect(avg).To(BeNumerically("~", 300*time.Millisecond, 5*time.Millisecond))
Expect(min).To(BeNumerically(">=", 100*time.Millisecond))
Expect(min).To(BeNumerically("<", 105*time.Millisecond))
Expect(max).To(BeNumerically(">", 495*time.Millisecond))
Expect(max).To(BeNumerically("<=", 500*time.Millisecond))
})
})
var _ = Describe("Random RTT", func() {
var proxy *quicproxy.QuicProxy
runRTTTest := func(minRtt, maxRtt time.Duration, version protocol.VersionNumber) {
var err error
proxy, err = quicproxy.NewQuicProxy("localhost:", version, &quicproxy.Opts{
RemoteAddr: "localhost:" + testserver.Port(),
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration {
return getRandomDuration(minRtt, maxRtt)
},
})
Expect(err).ToNot(HaveOccurred())
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+strconv.Itoa(proxy.LocalPort()),
"https://quic.clemente.io/prdata",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 20).Should(Exit(0))
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
}
AfterEach(func() {
err := proxy.Close()
Expect(err).ToNot(HaveOccurred())
time.Sleep(time.Millisecond)
})
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
Context(fmt.Sprintf("with QUIC version %s", version), func() {
It("gets a file a random RTT between 10ms and 30ms", func() {
runRTTTest(10*time.Millisecond, 30*time.Millisecond, version)
})
})
}
})

View File

@ -1,66 +0,0 @@
package gquic_test
import (
"bytes"
"fmt"
"os/exec"
"strconv"
"time"
_ "github.com/lucas-clemente/quic-clients" // download clients
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
. "github.com/onsi/gomega/gexec"
)
var _ = Describe("non-zero RTT", func() {
var proxy *quicproxy.QuicProxy
runRTTTest := func(rtt time.Duration, version protocol.VersionNumber) {
var err error
proxy, err = quicproxy.NewQuicProxy("localhost:", version, &quicproxy.Opts{
RemoteAddr: "localhost:" + testserver.Port(),
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration {
return rtt / 2
},
})
Expect(err).ToNot(HaveOccurred())
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+strconv.Itoa(proxy.LocalPort()),
"https://quic.clemente.io/prdata",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 20).Should(Exit(0))
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
}
AfterEach(func() {
err := proxy.Close()
Expect(err).ToNot(HaveOccurred())
time.Sleep(time.Millisecond)
})
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
Context(fmt.Sprintf("with QUIC version %s", version), func() {
roundTrips := [...]int{10, 50, 100, 200}
for _, rtt := range roundTrips {
It(fmt.Sprintf("gets a 500kB file with %dms RTT", rtt), func() {
runRTTTest(time.Duration(rtt)*time.Millisecond, version)
})
}
})
}
})

View File

@ -1,218 +0,0 @@
package gquic_test
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"fmt"
"io/ioutil"
"math/big"
mrand "math/rand"
"net/http"
"os"
"os/exec"
"path/filepath"
"strconv"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
. "github.com/onsi/gomega/gexec"
)
var _ = Describe("Server tests", func() {
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
var (
serverPort string
tmpDir string
session *Session
client *http.Client
)
generateCA := func() (*rsa.PrivateKey, *x509.Certificate) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
Expect(err).ToNot(HaveOccurred())
templateRoot := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
IsCA: true,
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, templateRoot, templateRoot, &key.PublicKey, key)
Expect(err).ToNot(HaveOccurred())
cert, err := x509.ParseCertificate(certDER)
Expect(err).ToNot(HaveOccurred())
return key, cert
}
// prepare the file such that it can be by the quic_server
// some HTTP headers neeed to be prepended, see https://www.chromium.org/quic/playing-with-quic
createDownloadFile := func(filename string, data []byte) {
dataDir := filepath.Join(tmpDir, "quic.clemente.io")
err := os.Mkdir(dataDir, 0777)
Expect(err).ToNot(HaveOccurred())
f, err := os.Create(filepath.Join(dataDir, filename))
Expect(err).ToNot(HaveOccurred())
defer f.Close()
_, err = f.Write([]byte("HTTP/1.1 200 OK\n"))
Expect(err).ToNot(HaveOccurred())
_, err = f.Write([]byte("Content-Type: text/html\n"))
Expect(err).ToNot(HaveOccurred())
_, err = f.Write([]byte("X-Original-Url: https://quic.clemente.io:" + serverPort + "/" + filename + "\n"))
Expect(err).ToNot(HaveOccurred())
_, err = f.Write([]byte("Content-Length: " + strconv.Itoa(len(data)) + "\n\n"))
Expect(err).ToNot(HaveOccurred())
_, err = f.Write(data)
Expect(err).ToNot(HaveOccurred())
}
// download files must be create *before* the quic_server is started
// the quic_server reads its data dir on startup, and only serves those files that were already present then
startServer := func(version protocol.VersionNumber) {
defer GinkgoRecover()
var err error
command := exec.Command(
serverPath,
"--quic_response_cache_dir="+filepath.Join(tmpDir, "quic.clemente.io"),
"--key_file="+filepath.Join(tmpDir, "key.pkcs8"),
"--certificate_file="+filepath.Join(tmpDir, "cert.pem"),
"--quic-version="+strconv.Itoa(int(version)),
"--port="+serverPort,
)
session, err = Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
}
stopServer := func() {
session.Kill()
}
BeforeEach(func() {
serverPort = strconv.Itoa(20000 + int(mrand.Int31n(10000)))
var err error
tmpDir, err = ioutil.TempDir("", "quic-server-certs")
Expect(err).ToNot(HaveOccurred())
// generate an RSA key pair for the server
key, err := rsa.GenerateKey(rand.Reader, 1024)
Expect(err).ToNot(HaveOccurred())
// save the private key in PKCS8 format to disk (required by quic_server)
pkcs8key, err := asn1.Marshal(struct { // copied from the x509 package
Version int
Algo pkix.AlgorithmIdentifier
PrivateKey []byte
}{
PrivateKey: x509.MarshalPKCS1PrivateKey(key),
Algo: pkix.AlgorithmIdentifier{
Algorithm: asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1},
Parameters: asn1.RawValue{Tag: 5},
},
})
Expect(err).ToNot(HaveOccurred())
f, err := os.Create(filepath.Join(tmpDir, "key.pkcs8"))
Expect(err).ToNot(HaveOccurred())
_, err = f.Write(pkcs8key)
Expect(err).ToNot(HaveOccurred())
f.Close()
// generate a Certificate Authority
// this CA is used to sign the server's key
// it is set as a valid CA in the QUIC client
rootKey, CACert := generateCA()
// generate the server certificate
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-30 * time.Minute),
NotAfter: time.Now().Add(30 * time.Minute),
Subject: pkix.Name{CommonName: "quic.clemente.io"},
}
certDER, err := x509.CreateCertificate(rand.Reader, template, CACert, &key.PublicKey, rootKey)
Expect(err).ToNot(HaveOccurred())
// save the certificate to disk
certOut, err := os.Create(filepath.Join(tmpDir, "cert.pem"))
Expect(err).ToNot(HaveOccurred())
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
certOut.Close()
// prepare the h2quic.client
certPool := x509.NewCertPool()
certPool.AddCert(CACert)
client = &http.Client{
Transport: &h2quic.RoundTripper{
TLSClientConfig: &tls.Config{RootCAs: certPool},
QuicConfig: &quic.Config{
Versions: []protocol.VersionNumber{version},
},
},
}
})
AfterEach(func() {
Expect(tmpDir).ToNot(BeEmpty())
err := os.RemoveAll(tmpDir)
Expect(err).ToNot(HaveOccurred())
tmpDir = ""
})
Context(fmt.Sprintf("with QUIC version %s", version), func() {
It("downloads a hello", func() {
data := []byte("Hello world!\n")
createDownloadFile("hello", data)
startServer(version)
defer stopServer()
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/hello")
Expect(err).ToNot(HaveOccurred())
Expect(rsp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 5*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal(data))
})
It("downloads a small file", func() {
createDownloadFile("file.dat", testserver.PRData)
startServer(version)
defer stopServer()
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/file.dat")
Expect(err).ToNot(HaveOccurred())
Expect(rsp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 5*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal(testserver.PRData))
})
It("downloads a large file", func() {
createDownloadFile("file.dat", testserver.PRDataLong)
startServer(version)
defer stopServer()
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/file.dat")
Expect(err).ToNot(HaveOccurred())
Expect(rsp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 20*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal(testserver.PRDataLong))
})
})
}
})

View File

@ -1,97 +0,0 @@
package self_test
import (
"bytes"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
)
var _ = Describe("Client tests", func() {
var client *http.Client
// also run some tests with the TLS handshake
versions := append(protocol.SupportedVersions, protocol.VersionTLS)
BeforeEach(func() {
err := os.Setenv("HOSTALIASES", "quic.clemente.io 127.0.0.1")
Expect(err).ToNot(HaveOccurred())
addr, err := net.ResolveUDPAddr("udp4", "quic.clemente.io:0")
Expect(err).ToNot(HaveOccurred())
if addr.String() != "127.0.0.1:0" {
Fail("quic.clemente.io does not resolve to 127.0.0.1. Consider adding it to /etc/hosts.")
}
testserver.StartQuicServer(versions)
})
AfterEach(func() {
testserver.StopQuicServer()
})
for _, v := range versions {
version := v
Context(fmt.Sprintf("with QUIC version %s", version), func() {
BeforeEach(func() {
client = &http.Client{
Transport: &h2quic.RoundTripper{
QuicConfig: &quic.Config{
Versions: []protocol.VersionNumber{version},
},
},
}
})
It("downloads a hello", func() {
resp, err := client.Get("https://quic.clemente.io:" + testserver.Port() + "/hello")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal("Hello, World!\n"))
})
It("downloads a small file", func() {
resp, err := client.Get("https://quic.clemente.io:" + testserver.Port() + "/prdata")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal(testserver.PRData))
})
It("downloads a large file", func() {
resp, err := client.Get("https://quic.clemente.io:" + testserver.Port() + "/prdatalong")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 20*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal(testserver.PRDataLong))
})
It("uploads a file", func() {
resp, err := client.Post(
"https://quic.clemente.io:"+testserver.Port()+"/echo",
"text/plain",
bytes.NewReader(testserver.PRData),
)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(bytes.Equal(body, testserver.PRData)).To(BeTrue())
})
})
}
})

View File

@ -1,101 +0,0 @@
package self_test
import (
"crypto/tls"
"fmt"
"io/ioutil"
"math/rand"
"net"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Connection ID lengths tests", func() {
randomConnIDLen := func() int {
return 4 + int(rand.Int31n(15))
}
runServer := func(conf *quic.Config) quic.Listener {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength)))
ln, err := quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), conf)
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
for {
sess, err := ln.Accept()
if err != nil {
return
}
go func() {
defer GinkgoRecover()
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
defer str.Close()
_, err = str.Write(testserver.PRData)
Expect(err).ToNot(HaveOccurred())
}()
}
}()
return ln
}
runClient := func(addr net.Addr, conf *quic.Config) {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength)))
cl, err := quic.DialAddr(
fmt.Sprintf("quic.clemente.io:%d", addr.(*net.UDPAddr).Port),
&tls.Config{InsecureSkipVerify: true},
conf,
)
Expect(err).ToNot(HaveOccurred())
defer cl.Close()
str, err := cl.AcceptStream()
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(testserver.PRData))
}
Context("IETF QUIC", func() {
It("downloads a file using a 0-byte connection ID for the client", func() {
serverConf := &quic.Config{
ConnectionIDLength: randomConnIDLen(),
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
clientConf := &quic.Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), clientConf)
})
It("downloads a file when both client and server use a random connection ID length", func() {
serverConf := &quic.Config{
ConnectionIDLength: randomConnIDLen(),
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
clientConf := &quic.Config{
ConnectionIDLength: randomConnIDLen(),
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), clientConf)
})
})
Context("gQUIC", func() {
It("downloads a file using a 0-byte connection ID for the client", func() {
ln := runServer(&quic.Config{})
defer ln.Close()
runClient(ln.Addr(), &quic.Config{RequestConnectionIDOmission: true})
})
})
})

View File

@ -1,189 +0,0 @@
package self_test
import (
"fmt"
mrand "math/rand"
"net"
"time"
_ "github.com/lucas-clemente/quic-clients" // download clients
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
)
var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
type applicationProtocol struct {
name string
run func(protocol.VersionNumber)
}
var _ = Describe("Handshake drop tests", func() {
var (
proxy *quicproxy.QuicProxy
ln quic.Listener
)
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, version protocol.VersionNumber) {
var err error
ln, err = quic.ListenAddr(
"localhost:0",
testdata.GetTLSConfig(),
&quic.Config{
Versions: []protocol.VersionNumber{version},
},
)
Expect(err).ToNot(HaveOccurred())
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err = quicproxy.NewQuicProxy("localhost:0", version, &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DropPacket: dropCallback,
},
)
Expect(err).ToNot(HaveOccurred())
}
stochasticDropper := func(freq int) bool {
return mrand.Int63n(int64(freq)) == 0
}
clientSpeaksFirst := &applicationProtocol{
name: "client speaks first",
run: func(version protocol.VersionNumber) {
serverSessionChan := make(chan quic.Session)
go func() {
defer GinkgoRecover()
sess, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
defer sess.Close()
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(string(b)).To(Equal("foobar"))
serverSessionChan <- sess
}()
sess, err := quic.DialAddr(
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
nil,
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
var serverSession quic.Session
Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession))
sess.Close()
serverSession.Close()
},
}
serverSpeaksFirst := &applicationProtocol{
name: "server speaks first",
run: func(version protocol.VersionNumber) {
serverSessionChan := make(chan quic.Session)
go func() {
defer GinkgoRecover()
sess, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
serverSessionChan <- sess
}()
sess, err := quic.DialAddr(
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
nil,
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(string(b)).To(Equal("foobar"))
var serverSession quic.Session
Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession))
sess.Close()
serverSession.Close()
},
}
nobodySpeaks := &applicationProtocol{
name: "nobody speaks",
run: func(version protocol.VersionNumber) {
serverSessionChan := make(chan quic.Session)
go func() {
defer GinkgoRecover()
sess, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
serverSessionChan <- sess
}()
sess, err := quic.DialAddr(
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
nil,
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
var serverSession quic.Session
Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession))
// both server and client accepted a session. Close now.
sess.Close()
serverSession.Close()
},
}
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
})
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
version := v
Context(fmt.Sprintf("with QUIC version %s", version), func() {
for _, d := range directions {
direction := d
for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
app := a
Context(app.name, func() {
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", d), func() {
startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool {
return p == 1 && d.Is(direction)
}, version)
app.run(version)
})
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", d), func() {
startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool {
return p == 2 && d.Is(direction)
}, version)
app.run(version)
})
It(fmt.Sprintf("establishes a connection when 1/5 of the packets are lost in %s direction", d), func() {
startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool {
return d.Is(direction) && stochasticDropper(5)
}, version)
app.run(version)
})
})
}
}
})
}
})

View File

@ -1,213 +0,0 @@
package self_test
import (
"crypto/tls"
"net"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Handshake RTT tests", func() {
var (
proxy *quicproxy.QuicProxy
server quic.Listener
serverConfig *quic.Config
testStartedAt time.Time
acceptStopped chan struct{}
)
rtt := 400 * time.Millisecond
BeforeEach(func() {
acceptStopped = make(chan struct{})
serverConfig = &quic.Config{}
})
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
Expect(server.Close()).To(Succeed())
<-acceptStopped
})
runServerAndProxy := func() {
var err error
// start the server
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
// start the proxy
proxy, err = quicproxy.NewQuicProxy("localhost:0", protocol.VersionWhatever, &quicproxy.Opts{
RemoteAddr: server.Addr().String(),
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration { return rtt / 2 },
})
Expect(err).ToNot(HaveOccurred())
testStartedAt = time.Now()
go func() {
defer GinkgoRecover()
defer close(acceptStopped)
for {
_, err := server.Accept()
if err != nil {
return
}
}
}()
}
expectDurationInRTTs := func(num int) {
testDuration := time.Since(testStartedAt)
rtts := float32(testDuration) / float32(rtt)
Expect(rtts).To(SatisfyAll(
BeNumerically(">=", num),
BeNumerically("<", num+1),
))
}
It("fails when there's no matching version, after 1 RTT", func() {
if len(protocol.SupportedVersions) == 1 {
Skip("Test requires at least 2 supported versions.")
}
serverConfig.Versions = protocol.SupportedVersions[:1]
runServerAndProxy()
clientConfig := &quic.Config{
Versions: protocol.SupportedVersions[1:2],
}
_, err := quic.DialAddr(proxy.LocalAddr().String(), nil, clientConfig)
Expect(err).To(HaveOccurred())
Expect(err.(qerr.ErrorCode)).To(Equal(qerr.InvalidVersion))
expectDurationInRTTs(1)
})
Context("gQUIC", func() {
// 1 RTT for verifying the source address
// 1 RTT to become secure
// 1 RTT to become forward-secure
It("is forward-secure after 3 RTTs", func() {
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(3)
})
It("does version negotiation in 1 RTT, IETF QUIC => gQUIC", func() {
clientConfig := &quic.Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS, protocol.SupportedVersions[0]},
}
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
&tls.Config{InsecureSkipVerify: true},
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(4)
})
It("is forward-secure after 2 RTTs when the server doesn't require a Cookie", func() {
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
return true
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("doesn't complete the handshake when the server never accepts the Cookie", func() {
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
return false
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.CryptoTooManyRejects))
})
It("doesn't complete the handshake when the handshake timeout is too short", func() {
serverConfig.HandshakeTimeout = 2 * rtt
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
// 2 RTTs during the timeout
// plus 1 RTT: the timer starts 0.5 RTTs after sending the first packet, and the CONNECTION_CLOSE needs another 0.5 RTTs to reach the client
expectDurationInRTTs(3)
})
})
Context("IETF QUIC", func() {
var clientConfig *quic.Config
var clientTLSConfig *tls.Config
BeforeEach(func() {
serverConfig.Versions = []protocol.VersionNumber{protocol.VersionTLS}
clientConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
clientTLSConfig = &tls.Config{
InsecureSkipVerify: true,
ServerName: "quic.clemente.io",
}
})
// 1 RTT for verifying the source address
// 1 RTT for the TLS handshake
It("is forward-secure after 2 RTTs", func() {
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
clientTLSConfig,
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("does version negotiation in 1 RTT, gQUIC => IETF QUIC", func() {
clientConfig.Versions = []protocol.VersionNumber{protocol.SupportedVersions[0], protocol.VersionTLS}
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
clientTLSConfig,
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(3)
})
It("is forward-secure after 1 RTTs when the server doesn't require a Cookie", func() {
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
return true
}
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
clientTLSConfig,
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(1)
})
It("doesn't complete the handshake when the server never accepts the Cookie", func() {
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
return false
}
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
clientTLSConfig,
clientConfig,
)
Expect(err).To(HaveOccurred())
Expect(err.(qerr.ErrorCode)).To(Equal(qerr.CryptoTooManyRejects))
})
})
})

View File

@ -1,128 +0,0 @@
package self_test
import (
"crypto/tls"
"fmt"
"net"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type versioner interface {
GetVersion() protocol.VersionNumber
}
var _ = Describe("Handshake tests", func() {
var (
server quic.Listener
serverConfig *quic.Config
acceptStopped chan struct{}
)
BeforeEach(func() {
server = nil
acceptStopped = make(chan struct{})
serverConfig = &quic.Config{}
})
AfterEach(func() {
if server != nil {
server.Close()
<-acceptStopped
}
})
runServer := func() {
var err error
// start the server
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
defer close(acceptStopped)
for {
_, err := server.Accept()
if err != nil {
return
}
}
}()
}
Context("Version Negotiation", func() {
var supportedVersions []protocol.VersionNumber
BeforeEach(func() {
supportedVersions = protocol.SupportedVersions
protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{7, 8, 9, 10}...)
})
AfterEach(func() {
protocol.SupportedVersions = supportedVersions
})
It("when the server supports more versions than the client", func() {
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverConfig.Versions = []protocol.VersionNumber{protocol.SupportedVersions[1], 7, 8, 9}
runServer()
sess, err := quic.DialAddr(server.Addr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
Expect(sess.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[1]))
})
It("when the client supports more versions than the server supports", func() {
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverConfig.Versions = supportedVersions
runServer()
conf := &quic.Config{
Versions: []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[1], 10},
}
sess, err := quic.DialAddr(server.Addr().String(), &tls.Config{InsecureSkipVerify: true}, conf)
Expect(err).ToNot(HaveOccurred())
Expect(sess.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[1]))
})
})
Context("Certifiate validation", func() {
for _, v := range []protocol.VersionNumber{protocol.Version39, protocol.VersionTLS} {
version := v
Context(fmt.Sprintf("using %s", version), func() {
var clientConfig *quic.Config
BeforeEach(func() {
serverConfig.Versions = []protocol.VersionNumber{version}
clientConfig = &quic.Config{
Versions: []protocol.VersionNumber{version},
}
})
It("accepts the certificate", func() {
runServer()
_, err := quic.DialAddr(fmt.Sprintf("quic.clemente.io:%d", server.Addr().(*net.UDPAddr).Port), nil, clientConfig)
Expect(err).ToNot(HaveOccurred())
})
It("errors if the server name doesn't match", func() {
runServer()
_, err := quic.DialAddr(fmt.Sprintf("127.0.0.1:%d", server.Addr().(*net.UDPAddr).Port), nil, clientConfig)
Expect(err).To(HaveOccurred())
})
It("uses the ServerName in the tls.Config", func() {
runServer()
conf := &tls.Config{ServerName: "quic.clemente.io"}
_, err := quic.DialAddr(fmt.Sprintf("127.0.0.1:%d", server.Addr().(*net.UDPAddr).Port), conf, clientConfig)
Expect(err).ToNot(HaveOccurred())
})
})
}
})
})

View File

@ -1,232 +0,0 @@
package self_test
import (
"fmt"
"io/ioutil"
"net"
"os"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Multiplexing", func() {
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
version := v
// gQUIC 44 uses 0 byte connection IDs for packets sent to the client
// It's not possible to do demultiplexing.
if v == protocol.Version44 {
continue
}
Context(fmt.Sprintf("with QUIC version %s", version), func() {
runServer := func(ln quic.Listener) {
go func() {
defer GinkgoRecover()
for {
sess, err := ln.Accept()
if err != nil {
return
}
go func() {
defer GinkgoRecover()
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
defer str.Close()
_, err = str.Write(testserver.PRDataLong)
Expect(err).ToNot(HaveOccurred())
}()
}
}()
}
dial := func(conn net.PacketConn, addr net.Addr) {
sess, err := quic.Dial(
conn,
addr,
fmt.Sprintf("quic.clemente.io:%d", addr.(*net.UDPAddr).Port),
nil,
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(testserver.PRDataLong))
}
Context("multiplexing clients on the same conn", func() {
getListener := func() quic.Listener {
ln, err := quic.ListenAddr(
"localhost:0",
testdata.GetTLSConfig(),
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
return ln
}
It("multiplexes connections to the same server", func() {
server := getListener()
runServer(server)
defer server.Close()
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
done1 := make(chan struct{})
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn, server.Addr())
close(done1)
}()
go func() {
defer GinkgoRecover()
dial(conn, server.Addr())
close(done2)
}()
timeout := 30 * time.Second
if testlog.Debug() {
timeout = time.Minute
}
Eventually(done1, timeout).Should(BeClosed())
Eventually(done2, timeout).Should(BeClosed())
})
It("multiplexes connections to different servers", func() {
server1 := getListener()
runServer(server1)
defer server1.Close()
server2 := getListener()
runServer(server2)
defer server2.Close()
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
done1 := make(chan struct{})
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn, server1.Addr())
close(done1)
}()
go func() {
defer GinkgoRecover()
dial(conn, server2.Addr())
close(done2)
}()
timeout := 30 * time.Second
if testlog.Debug() {
timeout = time.Minute
}
Eventually(done1, timeout).Should(BeClosed())
Eventually(done2, timeout).Should(BeClosed())
})
})
Context("multiplexing server and client on the same conn", func() {
It("connects to itself", func() {
if version != protocol.VersionTLS {
Skip("Connecting to itself only works with IETF QUIC.")
}
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
server, err := quic.Listen(
conn,
testdata.GetTLSConfig(),
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
runServer(server)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn, server.Addr())
close(done)
}()
timeout := 30 * time.Second
if testlog.Debug() {
timeout = time.Minute
}
Eventually(done, timeout).Should(BeClosed())
})
It("runs a server and client on the same conn", func() {
if os.Getenv("CI") == "true" {
Skip("This test is flaky on CIs, see see https://github.com/golang/go/issues/17677.")
}
addr1, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn1, err := net.ListenUDP("udp", addr1)
Expect(err).ToNot(HaveOccurred())
defer conn1.Close()
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn2, err := net.ListenUDP("udp", addr2)
Expect(err).ToNot(HaveOccurred())
defer conn2.Close()
server1, err := quic.Listen(
conn1,
testdata.GetTLSConfig(),
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
runServer(server1)
defer server1.Close()
server2, err := quic.Listen(
conn2,
testdata.GetTLSConfig(),
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
runServer(server2)
defer server2.Close()
done1 := make(chan struct{})
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn2, server1.Addr())
close(done1)
}()
go func() {
defer GinkgoRecover()
dial(conn1, server2.Addr())
close(done2)
}()
timeout := 30 * time.Second
if testlog.Debug() {
timeout = time.Minute
}
Eventually(done1, timeout).Should(BeClosed())
Eventually(done2, timeout).Should(BeClosed())
})
})
})
}
})

View File

@ -1,83 +0,0 @@
package self
import (
"fmt"
"io/ioutil"
"net"
"time"
_ "github.com/lucas-clemente/quic-clients" // download clients
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("non-zero RTT", func() {
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
version := v
Context(fmt.Sprintf("with QUIC version %s", version), func() {
roundTrips := [...]time.Duration{
10 * time.Millisecond,
50 * time.Millisecond,
100 * time.Millisecond,
200 * time.Millisecond,
}
for _, r := range roundTrips {
rtt := r
It(fmt.Sprintf("downloads a message with %s RTT", rtt), func() {
ln, err := quic.ListenAddr(
"localhost:0",
testdata.GetTLSConfig(),
&quic.Config{
Versions: []protocol.VersionNumber{version},
},
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write(testserver.PRData)
Expect(err).ToNot(HaveOccurred())
str.Close()
close(done)
}()
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", version, &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(d quicproxy.Direction, p uint64) time.Duration {
return rtt / 2
},
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
sess, err := quic.DialAddr(
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
nil,
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(testserver.PRData))
sess.Close()
Eventually(done).Should(BeClosed())
})
}
})
}
})

View File

@ -1,20 +0,0 @@
package self_test
import (
"math/rand"
"testing"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
)
func TestSelf(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Self integration tests")
}
var _ = BeforeSuite(func() {
rand.Seed(GinkgoRandomSeed())
})

View File

@ -1,152 +0,0 @@
package self_test
import (
"fmt"
"io/ioutil"
"net"
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Bidirectional streams", func() {
const numStreams = 300
var (
server quic.Listener
serverAddr string
qconf *quic.Config
)
for _, v := range []protocol.VersionNumber{protocol.VersionTLS} {
version := v
Context(fmt.Sprintf("with QUIC %s", version), func() {
BeforeEach(func() {
var err error
qconf = &quic.Config{
Versions: []protocol.VersionNumber{version},
MaxIncomingStreams: 0,
}
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), qconf)
Expect(err).ToNot(HaveOccurred())
serverAddr = fmt.Sprintf("quic.clemente.io:%d", server.Addr().(*net.UDPAddr).Port)
})
AfterEach(func() {
server.Close()
})
runSendingPeer := func(sess quic.Session) {
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
str, err := sess.OpenStreamSync()
Expect(err).ToNot(HaveOccurred())
data := testserver.GeneratePRData(25 * i)
go func() {
defer GinkgoRecover()
_, err := str.Write(data)
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}()
go func() {
defer GinkgoRecover()
defer wg.Done()
dataRead, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(dataRead).To(Equal(data))
}()
}
wg.Wait()
}
runReceivingPeer := func(sess quic.Session) {
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
defer wg.Done()
// shouldn't use io.Copy here
// we should read from the stream as early as possible, to free flow control credit
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
_, err = str.Write(data)
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}()
}
wg.Wait()
}
It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() {
var sess quic.Session
go func() {
defer GinkgoRecover()
var err error
sess, err = server.Accept()
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(sess)
}()
client, err := quic.DialAddr(serverAddr, nil, qconf)
Expect(err).ToNot(HaveOccurred())
runSendingPeer(client)
})
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
Expect(err).ToNot(HaveOccurred())
runSendingPeer(sess)
sess.Close()
}()
client, err := quic.DialAddr(serverAddr, nil, qconf)
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(client)
Eventually(client.Context().Done()).Should(BeClosed())
})
It(fmt.Sprintf("client and server opening %d each and sending data to the peer", numStreams), func() {
done1 := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
runReceivingPeer(sess)
close(done)
}()
runSendingPeer(sess)
<-done
close(done1)
}()
client, err := quic.DialAddr(serverAddr, nil, qconf)
Expect(err).ToNot(HaveOccurred())
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
runSendingPeer(client)
close(done2)
}()
runReceivingPeer(client)
<-done1
<-done2
})
})
}
})

View File

@ -1,132 +0,0 @@
package self_test
import (
"fmt"
"io/ioutil"
"net"
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Unidirectional Streams", func() {
const numStreams = 500
var (
server quic.Listener
serverAddr string
qconf *quic.Config
)
BeforeEach(func() {
var err error
qconf = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), qconf)
Expect(err).ToNot(HaveOccurred())
serverAddr = fmt.Sprintf("quic.clemente.io:%d", server.Addr().(*net.UDPAddr).Port)
})
AfterEach(func() {
server.Close()
})
dataForStream := func(id protocol.StreamID) []byte {
return testserver.GeneratePRData(10 * int(id))
}
runSendingPeer := func(sess quic.Session) {
for i := 0; i < numStreams; i++ {
str, err := sess.OpenUniStreamSync()
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
_, err := str.Write(dataForStream(str.StreamID()))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}()
}
}
runReceivingPeer := func(sess quic.Session) {
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
str, err := sess.AcceptUniStream()
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
defer wg.Done()
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(dataForStream(str.StreamID())))
}()
}
wg.Wait()
}
It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() {
var sess quic.Session
go func() {
defer GinkgoRecover()
var err error
sess, err = server.Accept()
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(sess)
sess.Close()
}()
client, err := quic.DialAddr(serverAddr, nil, qconf)
Expect(err).ToNot(HaveOccurred())
runSendingPeer(client)
<-client.Context().Done()
})
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
Expect(err).ToNot(HaveOccurred())
runSendingPeer(sess)
}()
client, err := quic.DialAddr(serverAddr, nil, qconf)
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(client)
})
It(fmt.Sprintf("client and server opening %d streams each and sending data to the peer", numStreams), func() {
done1 := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
runReceivingPeer(sess)
close(done)
}()
runSendingPeer(sess)
<-done
close(done1)
}()
client, err := quic.DialAddr(serverAddr, nil, qconf)
Expect(err).ToNot(HaveOccurred())
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
runSendingPeer(client)
close(done2)
}()
runReceivingPeer(client)
<-done1
<-done2
})
})

View File

@ -1,270 +0,0 @@
package quicproxy
import (
"net"
"sync"
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// Connection is a UDP connection
type connection struct {
ClientAddr *net.UDPAddr // Address of the client
ServerConn *net.UDPConn // UDP connection to server
incomingPacketCounter uint64
outgoingPacketCounter uint64
}
// Direction is the direction a packet is sent.
type Direction int
const (
// DirectionIncoming is the direction from the client to the server.
DirectionIncoming Direction = iota
// DirectionOutgoing is the direction from the server to the client.
DirectionOutgoing
// DirectionBoth is both incoming and outgoing
DirectionBoth
)
func (d Direction) String() string {
switch d {
case DirectionIncoming:
return "incoming"
case DirectionOutgoing:
return "outgoing"
case DirectionBoth:
return "both"
default:
panic("unknown direction")
}
}
// Is says if one direction matches another direction.
// For example, incoming matches both incoming and both, but not outgoing.
func (d Direction) Is(dir Direction) bool {
if d == DirectionBoth || dir == DirectionBoth {
return true
}
return d == dir
}
// DropCallback is a callback that determines which packet gets dropped.
type DropCallback func(dir Direction, packetCount uint64) bool
// NoDropper doesn't drop packets.
var NoDropper DropCallback = func(Direction, uint64) bool {
return false
}
// DelayCallback is a callback that determines how much delay to apply to a packet.
type DelayCallback func(dir Direction, packetCount uint64) time.Duration
// NoDelay doesn't apply a delay.
var NoDelay DelayCallback = func(Direction, uint64) time.Duration {
return 0
}
// Opts are proxy options.
type Opts struct {
// The address this proxy proxies packets to.
RemoteAddr string
// DropPacket determines whether a packet gets dropped.
DropPacket DropCallback
// DelayPacket determines how long a packet gets delayed. This allows
// simulating a connection with non-zero RTTs.
// Note that the RTT is the sum of the delay for the incoming and the outgoing packet.
DelayPacket DelayCallback
}
// QuicProxy is a QUIC proxy that can drop and delay packets.
type QuicProxy struct {
mutex sync.Mutex
version protocol.VersionNumber
conn *net.UDPConn
serverAddr *net.UDPAddr
dropPacket DropCallback
delayPacket DelayCallback
// Mapping from client addresses (as host:port) to connection
clientDict map[string]*connection
logger utils.Logger
}
// NewQuicProxy creates a new UDP proxy
func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*QuicProxy, error) {
if opts == nil {
opts = &Opts{}
}
laddr, err := net.ResolveUDPAddr("udp", local)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", laddr)
if err != nil {
return nil, err
}
raddr, err := net.ResolveUDPAddr("udp", opts.RemoteAddr)
if err != nil {
return nil, err
}
packetDropper := NoDropper
if opts.DropPacket != nil {
packetDropper = opts.DropPacket
}
packetDelayer := NoDelay
if opts.DelayPacket != nil {
packetDelayer = opts.DelayPacket
}
p := QuicProxy{
clientDict: make(map[string]*connection),
conn: conn,
serverAddr: raddr,
dropPacket: packetDropper,
delayPacket: packetDelayer,
version: version,
logger: utils.DefaultLogger.WithPrefix("proxy"),
}
p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr)
go p.runProxy()
return &p, nil
}
// Close stops the UDP Proxy
func (p *QuicProxy) Close() error {
p.mutex.Lock()
defer p.mutex.Unlock()
for _, c := range p.clientDict {
if err := c.ServerConn.Close(); err != nil {
return err
}
}
return p.conn.Close()
}
// LocalAddr is the address the proxy is listening on.
func (p *QuicProxy) LocalAddr() net.Addr {
return p.conn.LocalAddr()
}
// LocalPort is the UDP port number the proxy is listening on.
func (p *QuicProxy) LocalPort() int {
return p.conn.LocalAddr().(*net.UDPAddr).Port
}
func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
srvudp, err := net.DialUDP("udp", nil, p.serverAddr)
if err != nil {
return nil, err
}
return &connection{
ClientAddr: cliAddr,
ServerConn: srvudp,
}, nil
}
// runProxy listens on the proxy address and handles incoming packets.
func (p *QuicProxy) runProxy() error {
for {
buffer := make([]byte, protocol.MaxReceivePacketSize)
n, cliaddr, err := p.conn.ReadFromUDP(buffer)
if err != nil {
return err
}
raw := buffer[0:n]
saddr := cliaddr.String()
p.mutex.Lock()
conn, ok := p.clientDict[saddr]
if !ok {
conn, err = p.newConnection(cliaddr)
if err != nil {
p.mutex.Unlock()
return err
}
p.clientDict[saddr] = conn
go p.runConnection(conn)
}
p.mutex.Unlock()
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
if p.dropPacket(DirectionIncoming, packetCount) {
if p.logger.Debug() {
p.logger.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n)
}
continue
}
// Send the packet to the server
delay := p.delayPacket(DirectionIncoming, packetCount)
if delay != 0 {
if p.logger.Debug() {
p.logger.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay)
}
time.AfterFunc(delay, func() {
// TODO: handle error
_, _ = conn.ServerConn.Write(raw)
})
} else {
if p.logger.Debug() {
p.logger.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr())
}
if _, err := conn.ServerConn.Write(raw); err != nil {
return err
}
}
}
}
// runConnection handles packets from server to a single client
func (p *QuicProxy) runConnection(conn *connection) error {
for {
buffer := make([]byte, protocol.MaxReceivePacketSize)
n, err := conn.ServerConn.Read(buffer)
if err != nil {
return err
}
raw := buffer[0:n]
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
if p.dropPacket(DirectionOutgoing, packetCount) {
if p.logger.Debug() {
p.logger.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n)
}
continue
}
delay := p.delayPacket(DirectionOutgoing, packetCount)
if delay != 0 {
if p.logger.Debug() {
p.logger.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay)
}
time.AfterFunc(delay, func() {
// TODO: handle error
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr)
})
} else {
if p.logger.Debug() {
p.logger.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr)
}
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
return err
}
}
}
}

View File

@ -1,13 +0,0 @@
package quicproxy
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestQuicGo(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "QUIC Proxy")
}

View File

@ -1,394 +0,0 @@
package quicproxy
import (
"bytes"
"net"
"runtime/pprof"
"strconv"
"strings"
"sync/atomic"
"time"
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type packetData []byte
var _ = Describe("QUIC Proxy", func() {
makePacket := func(p protocol.PacketNumber, payload []byte) []byte {
b := &bytes.Buffer{}
hdr := wire.Header{
PacketNumber: p,
PacketNumberLen: protocol.PacketNumberLen6,
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37},
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37},
}
hdr.Write(b, protocol.PerspectiveServer, protocol.VersionWhatever)
raw := b.Bytes()
raw = append(raw, payload...)
return raw
}
Context("Proxy setup and teardown", func() {
It("sets up the UDPProxy", func() {
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, nil)
Expect(err).ToNot(HaveOccurred())
Expect(proxy.clientDict).To(HaveLen(0))
// check that the proxy port is in use
addr, err := net.ResolveUDPAddr("udp", "localhost:"+strconv.Itoa(proxy.LocalPort()))
Expect(err).ToNot(HaveOccurred())
_, err = net.ListenUDP("udp", addr)
Expect(err).To(MatchError(fmt.Sprintf("listen udp 127.0.0.1:%d: bind: address already in use", proxy.LocalPort())))
Expect(proxy.Close()).To(Succeed()) // stopping is tested in the next test
})
It("stops the UDPProxy", func() {
isProxyRunning := func() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "proxy.(*QuicProxy).runProxy")
}
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, nil)
Expect(err).ToNot(HaveOccurred())
port := proxy.LocalPort()
Expect(isProxyRunning()).To(BeTrue())
err = proxy.Close()
Expect(err).ToNot(HaveOccurred())
// check that the proxy port is not in use anymore
addr, err := net.ResolveUDPAddr("udp", "localhost:"+strconv.Itoa(port))
Expect(err).ToNot(HaveOccurred())
// sometimes it takes a while for the OS to free the port
Eventually(func() error {
ln, err := net.ListenUDP("udp", addr)
if err != nil {
return err
}
ln.Close()
return nil
}).ShouldNot(HaveOccurred())
Eventually(isProxyRunning).Should(BeFalse())
})
It("stops listening for proxied connections", func() {
isConnRunning := func() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "proxy.(*QuicProxy).runConnection")
}
serverAddr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
serverConn, err := net.ListenUDP("udp", serverAddr)
Expect(err).ToNot(HaveOccurred())
defer serverConn.Close()
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, &Opts{RemoteAddr: serverConn.LocalAddr().String()})
Expect(err).ToNot(HaveOccurred())
Expect(isConnRunning()).To(BeFalse())
// check that the proxy port is not in use anymore
conn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
Expect(err).ToNot(HaveOccurred())
_, err = conn.Write(makePacket(1, []byte("foobar")))
Expect(err).ToNot(HaveOccurred())
Eventually(isConnRunning).Should(BeTrue())
Expect(proxy.Close()).To(Succeed())
Eventually(isConnRunning).Should(BeFalse())
})
It("has the correct LocalAddr and LocalPort", func() {
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, nil)
Expect(err).ToNot(HaveOccurred())
Expect(proxy.LocalAddr().String()).To(Equal("127.0.0.1:" + strconv.Itoa(proxy.LocalPort())))
Expect(proxy.LocalPort()).ToNot(BeZero())
Expect(proxy.Close()).To(Succeed())
})
})
Context("Proxy tests", func() {
var (
serverConn *net.UDPConn
serverNumPacketsSent int32
serverReceivedPackets chan packetData
clientConn *net.UDPConn
proxy *QuicProxy
)
startProxy := func(opts *Opts) {
var err error
proxy, err = NewQuicProxy("localhost:0", protocol.VersionWhatever, opts)
Expect(err).ToNot(HaveOccurred())
clientConn, err = net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
Expect(err).ToNot(HaveOccurred())
}
// getClientDict returns a copy of the clientDict map
getClientDict := func() map[string]*connection {
d := make(map[string]*connection)
proxy.mutex.Lock()
defer proxy.mutex.Unlock()
for k, v := range proxy.clientDict {
d[k] = v
}
return d
}
BeforeEach(func() {
serverReceivedPackets = make(chan packetData, 100)
atomic.StoreInt32(&serverNumPacketsSent, 0)
// setup a dump UDP server
// in production this would be a QUIC server
raddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
Expect(err).ToNot(HaveOccurred())
serverConn, err = net.ListenUDP("udp", raddr)
Expect(err).ToNot(HaveOccurred())
go func() {
for {
buf := make([]byte, protocol.MaxReceivePacketSize)
// the ReadFromUDP will error as soon as the UDP conn is closed
n, addr, err2 := serverConn.ReadFromUDP(buf)
if err2 != nil {
return
}
data := buf[0:n]
serverReceivedPackets <- packetData(data)
// echo the packet
serverConn.WriteToUDP(data, addr)
atomic.AddInt32(&serverNumPacketsSent, 1)
}
}()
})
AfterEach(func() {
err := proxy.Close()
Expect(err).ToNot(HaveOccurred())
err = serverConn.Close()
Expect(err).ToNot(HaveOccurred())
err = clientConn.Close()
Expect(err).ToNot(HaveOccurred())
time.Sleep(200 * time.Millisecond)
})
Context("no packet drop", func() {
It("relays packets from the client to the server", func() {
startProxy(&Opts{RemoteAddr: serverConn.LocalAddr().String()})
// send the first packet
_, err := clientConn.Write(makePacket(1, []byte("foobar")))
Expect(err).ToNot(HaveOccurred())
Eventually(getClientDict).Should(HaveLen(1))
var conn *connection
for _, conn = range getClientDict() {
Eventually(func() uint64 { return atomic.LoadUint64(&conn.incomingPacketCounter) }).Should(Equal(uint64(1)))
}
// send the second packet
_, err = clientConn.Write(makePacket(2, []byte("decafbad")))
Expect(err).ToNot(HaveOccurred())
Eventually(serverReceivedPackets).Should(HaveLen(2))
Expect(getClientDict()).To(HaveLen(1))
Expect(string(<-serverReceivedPackets)).To(ContainSubstring("foobar"))
Expect(string(<-serverReceivedPackets)).To(ContainSubstring("decafbad"))
})
It("relays packets from the server to the client", func() {
startProxy(&Opts{RemoteAddr: serverConn.LocalAddr().String()})
// send the first packet
_, err := clientConn.Write(makePacket(1, []byte("foobar")))
Expect(err).ToNot(HaveOccurred())
Eventually(getClientDict).Should(HaveLen(1))
var key string
var conn *connection
for key, conn = range getClientDict() {
Eventually(func() uint64 { return atomic.LoadUint64(&conn.outgoingPacketCounter) }).Should(Equal(uint64(1)))
}
// send the second packet
_, err = clientConn.Write(makePacket(2, []byte("decafbad")))
Expect(err).ToNot(HaveOccurred())
Expect(getClientDict()).To(HaveLen(1))
Eventually(func() uint64 {
conn := getClientDict()[key]
return atomic.LoadUint64(&conn.outgoingPacketCounter)
}).Should(BeEquivalentTo(2))
clientReceivedPackets := make(chan packetData, 2)
// receive the packets echoed by the server on client side
go func() {
for {
buf := make([]byte, protocol.MaxReceivePacketSize)
// the ReadFromUDP will error as soon as the UDP conn is closed
n, _, err2 := clientConn.ReadFromUDP(buf)
if err2 != nil {
return
}
data := buf[0:n]
clientReceivedPackets <- packetData(data)
}
}()
Eventually(serverReceivedPackets).Should(HaveLen(2))
Expect(atomic.LoadInt32(&serverNumPacketsSent)).To(BeEquivalentTo(2))
Eventually(clientReceivedPackets).Should(HaveLen(2))
Expect(string(<-clientReceivedPackets)).To(ContainSubstring("foobar"))
Expect(string(<-clientReceivedPackets)).To(ContainSubstring("decafbad"))
})
})
Context("Drop Callbacks", func() {
It("drops incoming packets", func() {
opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(),
DropPacket: func(d Direction, p uint64) bool {
return d == DirectionIncoming && p%2 == 0
},
}
startProxy(opts)
for i := 1; i <= 6; i++ {
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
Expect(err).ToNot(HaveOccurred())
}
Eventually(serverReceivedPackets).Should(HaveLen(3))
Consistently(serverReceivedPackets).Should(HaveLen(3))
})
It("drops outgoing packets", func() {
const numPackets = 6
opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(),
DropPacket: func(d Direction, p uint64) bool {
return d == DirectionOutgoing && p%2 == 0
},
}
startProxy(opts)
clientReceivedPackets := make(chan packetData, numPackets)
// receive the packets echoed by the server on client side
go func() {
for {
buf := make([]byte, protocol.MaxReceivePacketSize)
// the ReadFromUDP will error as soon as the UDP conn is closed
n, _, err2 := clientConn.ReadFromUDP(buf)
if err2 != nil {
return
}
data := buf[0:n]
clientReceivedPackets <- packetData(data)
}
}()
for i := 1; i <= numPackets; i++ {
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
Expect(err).ToNot(HaveOccurred())
}
Eventually(clientReceivedPackets).Should(HaveLen(numPackets / 2))
Consistently(clientReceivedPackets).Should(HaveLen(numPackets / 2))
})
})
Context("Delay Callback", func() {
expectDelay := func(startTime time.Time, rtt time.Duration, numRTTs int) {
expectedReceiveTime := startTime.Add(time.Duration(numRTTs) * rtt)
Expect(time.Now()).To(SatisfyAll(
BeTemporally(">=", expectedReceiveTime),
BeTemporally("<", expectedReceiveTime.Add(rtt/2)),
))
}
It("delays incoming packets", func() {
delay := 300 * time.Millisecond
opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(),
// delay packet 1 by 200 ms
// delay packet 2 by 400 ms
// ...
DelayPacket: func(d Direction, p uint64) time.Duration {
if d == DirectionOutgoing {
return 0
}
return time.Duration(p) * delay
},
}
startProxy(opts)
// send 3 packets
start := time.Now()
for i := 1; i <= 3; i++ {
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
Expect(err).ToNot(HaveOccurred())
}
Eventually(serverReceivedPackets).Should(HaveLen(1))
expectDelay(start, delay, 1)
Eventually(serverReceivedPackets).Should(HaveLen(2))
expectDelay(start, delay, 2)
Eventually(serverReceivedPackets).Should(HaveLen(3))
expectDelay(start, delay, 3)
})
It("delays outgoing packets", func() {
const numPackets = 3
delay := 300 * time.Millisecond
opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(),
// delay packet 1 by 200 ms
// delay packet 2 by 400 ms
// ...
DelayPacket: func(d Direction, p uint64) time.Duration {
if d == DirectionIncoming {
return 0
}
return time.Duration(p) * delay
},
}
startProxy(opts)
clientReceivedPackets := make(chan packetData, numPackets)
// receive the packets echoed by the server on client side
go func() {
for {
buf := make([]byte, protocol.MaxReceivePacketSize)
// the ReadFromUDP will error as soon as the UDP conn is closed
n, _, err2 := clientConn.ReadFromUDP(buf)
if err2 != nil {
return
}
data := buf[0:n]
clientReceivedPackets <- packetData(data)
}
}()
start := time.Now()
for i := 1; i <= numPackets; i++ {
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
Expect(err).ToNot(HaveOccurred())
}
// the packets should have arrived immediately at the server
Eventually(serverReceivedPackets).Should(HaveLen(3))
expectDelay(start, delay, 0)
Eventually(clientReceivedPackets).Should(HaveLen(1))
expectDelay(start, delay, 1)
Eventually(clientReceivedPackets).Should(HaveLen(2))
expectDelay(start, delay, 2)
Eventually(clientReceivedPackets).Should(HaveLen(3))
expectDelay(start, delay, 3)
})
})
})
})

View File

@ -1,46 +0,0 @@
package testlog
import (
"flag"
"log"
"os"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var (
logFileName string // the log file set in the ginkgo flags
logFile *os.File
)
// read the logfile command line flag
// to set call ginkgo -- -logfile=log.txt
func init() {
flag.StringVar(&logFileName, "logfile", "", "log file")
}
var _ = BeforeEach(func() {
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
if len(logFileName) > 0 {
var err error
logFile, err = os.Create(logFileName)
Expect(err).ToNot(HaveOccurred())
log.SetOutput(logFile)
utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
}
})
var _ = AfterEach(func() {
if len(logFileName) > 0 {
_ = logFile.Close()
}
})
// Debug says if this test is being logged
func Debug() bool {
return len(logFileName) > 0
}

View File

@ -1,119 +0,0 @@
package testserver
import (
"io"
"io/ioutil"
"net"
"net/http"
"strconv"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
const (
dataLen = 500 * 1024 // 500 KB
dataLenLong = 50 * 1024 * 1024 // 50 MB
)
var (
// PRData contains dataLen bytes of pseudo-random data.
PRData = GeneratePRData(dataLen)
// PRDataLong contains dataLenLong bytes of pseudo-random data.
PRDataLong = GeneratePRData(dataLenLong)
server *h2quic.Server
stoppedServing chan struct{}
port string
)
func init() {
http.HandleFunc("/prdata", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
sl := r.URL.Query().Get("len")
if sl != "" {
var err error
l, err := strconv.Atoi(sl)
Expect(err).NotTo(HaveOccurred())
_, err = w.Write(GeneratePRData(l))
Expect(err).NotTo(HaveOccurred())
} else {
_, err := w.Write(PRData)
Expect(err).NotTo(HaveOccurred())
}
})
http.HandleFunc("/prdatalong", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
_, err := w.Write(PRDataLong)
Expect(err).NotTo(HaveOccurred())
})
http.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
_, err := io.WriteString(w, "Hello, World!\n")
Expect(err).NotTo(HaveOccurred())
})
http.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
body, err := ioutil.ReadAll(r.Body)
Expect(err).NotTo(HaveOccurred())
_, err = w.Write(body)
Expect(err).NotTo(HaveOccurred())
})
}
// See https://en.wikipedia.org/wiki/Lehmer_random_number_generator
func GeneratePRData(l int) []byte {
res := make([]byte, l)
seed := uint64(1)
for i := 0; i < l; i++ {
seed = seed * 48271 % 2147483647
res[i] = byte(seed)
}
return res
}
// StartQuicServer starts a h2quic.Server.
// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used.
func StartQuicServer(versions []protocol.VersionNumber) {
server = &h2quic.Server{
Server: &http.Server{
TLSConfig: testdata.GetTLSConfig(),
},
QuicConfig: &quic.Config{
Versions: versions,
},
}
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
Expect(err).NotTo(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).NotTo(HaveOccurred())
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
stoppedServing = make(chan struct{})
go func() {
defer GinkgoRecover()
server.Serve(conn)
close(stoppedServing)
}()
}
// StopQuicServer stops the h2quic.Server.
func StopQuicServer() {
Expect(server.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing).Should(BeClosed())
}
// Port returns the UDP port of the QUIC server.
func Port() string {
return port
}

View File

@ -1,24 +0,0 @@
package ackhandler
import (
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestCrypto(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "AckHandler Suite")
}
var mockCtrl *gomock.Controller
var _ = BeforeEach(func() {
mockCtrl = gomock.NewController(GinkgoT())
})
var _ = AfterEach(func() {
mockCtrl.Finish()
})

View File

@ -1,382 +0,0 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("receivedPacketHandler", func() {
var (
handler *receivedPacketHandler
rttStats *congestion.RTTStats
)
BeforeEach(func() {
rttStats = &congestion.RTTStats{}
handler = NewReceivedPacketHandler(rttStats, utils.DefaultLogger, protocol.VersionWhatever).(*receivedPacketHandler)
})
Context("accepting packets", func() {
It("handles a packet that arrives late", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(3), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(2), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
})
It("saves the time when each packet arrived", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(3), time.Now(), true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond))
})
It("updates the largestObserved and the largestObservedReceivedTime", func() {
now := time.Now()
handler.largestObserved = 3
handler.largestObservedReceivedTime = now.Add(-1 * time.Second)
err := handler.ReceivedPacket(5, now, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
Expect(handler.largestObservedReceivedTime).To(Equal(now))
})
It("doesn't update the largestObserved and the largestObservedReceivedTime for a belated packet", func() {
now := time.Now()
timestamp := now.Add(-1 * time.Second)
handler.largestObserved = 5
handler.largestObservedReceivedTime = timestamp
err := handler.ReceivedPacket(4, now, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
Expect(handler.largestObservedReceivedTime).To(Equal(timestamp))
})
It("passes on errors from receivedPacketHistory", func() {
var err error
for i := protocol.PacketNumber(0); i < 5*protocol.MaxTrackedReceivedAckRanges; i++ {
err = handler.ReceivedPacket(2*i+1, time.Time{}, true)
// this will eventually return an error
// details about when exactly the receivedPacketHistory errors are tested there
if err != nil {
break
}
}
Expect(err).To(MatchError(errTooManyOutstandingReceivedAckRanges))
})
})
Context("ACKs", func() {
Context("queueing ACKs", func() {
receiveAndAck10Packets := func() {
for i := 1; i <= 10; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
}
Expect(handler.GetAckFrame()).ToNot(BeNil())
Expect(handler.ackQueued).To(BeFalse())
}
receiveAndAckPacketsUntilAckDecimation := func() {
for i := 1; i <= minReceivedBeforeAckDecimation; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
}
Expect(handler.GetAckFrame()).ToNot(BeNil())
Expect(handler.ackQueued).To(BeFalse())
}
It("always queues an ACK for the first packet", func() {
err := handler.ReceivedPacket(1, time.Time{}, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
Expect(handler.GetAlarmTimeout()).To(BeZero())
})
It("works with packet number 0", func() {
err := handler.ReceivedPacket(0, time.Time{}, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
Expect(handler.GetAlarmTimeout()).To(BeZero())
})
It("queues an ACK for every second retransmittable packet at the beginning", func() {
receiveAndAck10Packets()
p := protocol.PacketNumber(11)
for i := 0; i <= 20; i++ {
err := handler.ReceivedPacket(p, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
p++
err = handler.ReceivedPacket(p, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
p++
// dequeue the ACK frame
Expect(handler.GetAckFrame()).ToNot(BeNil())
}
})
It("queues an ACK for every 10 retransmittable packet, if they are arriving fast", func() {
receiveAndAck10Packets()
p := protocol.PacketNumber(10000)
for i := 0; i < 9; i++ {
err := handler.ReceivedPacket(p, time.Now(), true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
p++
}
Expect(handler.GetAlarmTimeout()).NotTo(BeZero())
err := handler.ReceivedPacket(p, time.Now(), true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
Expect(handler.GetAlarmTimeout()).To(BeZero())
})
It("only sets the timer when receiving a retransmittable packets", func() {
receiveAndAck10Packets()
err := handler.ReceivedPacket(11, time.Now(), false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
Expect(handler.GetAlarmTimeout()).To(BeZero())
rcvTime := time.Now().Add(10 * time.Millisecond)
err = handler.ReceivedPacket(12, rcvTime, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
Expect(handler.GetAlarmTimeout()).To(Equal(rcvTime.Add(ackSendDelay)))
})
It("queues an ACK if it was reported missing before", func() {
receiveAndAck10Packets()
err := handler.ReceivedPacket(11, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(13, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame() // ACK: 1-11 and 13, missing: 12
Expect(ack).ToNot(BeNil())
Expect(ack.HasMissingRanges()).To(BeTrue())
Expect(handler.ackQueued).To(BeFalse())
err = handler.ReceivedPacket(12, time.Time{}, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
})
It("doesn't queue an ACK if it was reported missing before, but is below the threshold", func() {
receiveAndAck10Packets()
// 11 is missing
err := handler.ReceivedPacket(12, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(13, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame() // ACK: 1-10, 12-13
Expect(ack).ToNot(BeNil())
// now receive 11
handler.IgnoreBelow(12)
err = handler.ReceivedPacket(11, time.Time{}, false)
Expect(err).ToNot(HaveOccurred())
ack = handler.GetAckFrame()
Expect(ack).To(BeNil())
})
It("doesn't queue an ACK if the packet closes a gap that was not yet reported", func() {
receiveAndAckPacketsUntilAckDecimation()
p := protocol.PacketNumber(minReceivedBeforeAckDecimation + 1)
err := handler.ReceivedPacket(p+1, time.Now(), true) // p is missing now
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
Expect(handler.GetAlarmTimeout()).ToNot(BeZero())
err = handler.ReceivedPacket(p, time.Now(), true) // p is not missing any more
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
})
It("sets an ACK alarm after 1/4 RTT if it creates a new missing range", func() {
now := time.Now().Add(-time.Hour)
rtt := 80 * time.Millisecond
rttStats.UpdateRTT(rtt, 0, now)
receiveAndAckPacketsUntilAckDecimation()
p := protocol.PacketNumber(minReceivedBeforeAckDecimation + 1)
for i := p; i < p+6; i++ {
err := handler.ReceivedPacket(i, now, true)
Expect(err).ToNot(HaveOccurred())
}
err := handler.ReceivedPacket(p+10, now, true) // we now know that packets p+7, p+8 and p+9
Expect(err).ToNot(HaveOccurred())
Expect(rttStats.MinRTT()).To(Equal(rtt))
Expect(handler.ackAlarm.Sub(now)).To(Equal(rtt / 8))
ack := handler.GetAckFrame()
Expect(ack.HasMissingRanges()).To(BeTrue())
Expect(ack).ToNot(BeNil())
})
})
Context("ACK generation", func() {
BeforeEach(func() {
handler.ackQueued = true
})
It("generates a simple ACK frame", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(2, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(2)))
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1)))
Expect(ack.HasMissingRanges()).To(BeFalse())
})
It("generates an ACK for packet number 0", func() {
err := handler.ReceivedPacket(0, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(0)))
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0)))
Expect(ack.HasMissingRanges()).To(BeFalse())
})
It("sets the delay time", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(2, time.Now().Add(-1337*time.Millisecond), true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.DelayTime).To(BeNumerically("~", 1337*time.Millisecond, 50*time.Millisecond))
})
It("saves the last sent ACK", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(handler.lastAck).To(Equal(ack))
err = handler.ReceivedPacket(2, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = true
ack = handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(handler.lastAck).To(Equal(ack))
})
It("generates an ACK frame with missing packets", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(4, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(4)))
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1)))
Expect(ack.AckRanges).To(Equal([]wire.AckRange{
{Smallest: 4, Largest: 4},
{Smallest: 1, Largest: 1},
}))
})
It("generates an ACK for packet number 0 and other packets", func() {
err := handler.ReceivedPacket(0, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(3, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(3)))
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0)))
Expect(ack.AckRanges).To(Equal([]wire.AckRange{
{Smallest: 3, Largest: 3},
{Smallest: 0, Largest: 1},
}))
})
It("accepts packets below the lower limit", func() {
handler.IgnoreBelow(6)
err := handler.ReceivedPacket(2, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
})
It("doesn't add delayed packets to the packetHistory", func() {
handler.IgnoreBelow(7)
err := handler.ReceivedPacket(4, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(10, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(10)))
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(10)))
})
It("deletes packets from the packetHistory when a lower limit is set", func() {
for i := 1; i <= 12; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
}
handler.IgnoreBelow(7)
// check that the packets were deleted from the receivedPacketHistory by checking the values in an ACK frame
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(12)))
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(7)))
Expect(ack.HasMissingRanges()).To(BeFalse())
})
// TODO: remove this test when dropping support for STOP_WAITINGs
It("handles a lower limit of 0", func() {
handler.IgnoreBelow(0)
err := handler.ReceivedPacket(1337, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(1337)))
})
It("resets all counters needed for the ACK queueing decision when sending an ACK", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
handler.ackAlarm = time.Now().Add(-time.Minute)
Expect(handler.GetAckFrame()).ToNot(BeNil())
Expect(handler.packetsReceivedSinceLastAck).To(BeZero())
Expect(handler.GetAlarmTimeout()).To(BeZero())
Expect(handler.retransmittablePacketsReceivedSinceLastAck).To(BeZero())
Expect(handler.ackQueued).To(BeFalse())
})
It("doesn't generate an ACK when none is queued and the timer is not set", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false
handler.ackAlarm = time.Time{}
Expect(handler.GetAckFrame()).To(BeNil())
})
It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false
handler.ackAlarm = time.Now().Add(time.Minute)
Expect(handler.GetAckFrame()).To(BeNil())
})
It("generates an ACK when the timer has expired", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false
handler.ackAlarm = time.Now().Add(-time.Minute)
Expect(handler.GetAckFrame()).ToNot(BeNil())
})
})
})
})

View File

@ -1,248 +0,0 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("receivedPacketHistory", func() {
var (
hist *receivedPacketHistory
)
BeforeEach(func() {
hist = newReceivedPacketHistory()
})
Context("ranges", func() {
It("adds the first packet", func() {
hist.ReceivedPacket(4)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
})
It("doesn't care about duplicate packets", func() {
hist.ReceivedPacket(4)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
})
It("adds a few consecutive packets", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
hist.ReceivedPacket(6)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
})
It("doesn't care about a duplicate packet contained in an existing range", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
hist.ReceivedPacket(6)
hist.ReceivedPacket(5)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
})
It("extends a range at the front", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(3)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 3, End: 4}))
})
It("creates a new range when a packet is lost", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(6)
Expect(hist.ranges.Len()).To(Equal(2))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 6}))
})
It("creates a new range in between two ranges", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(10)
Expect(hist.ranges.Len()).To(Equal(2))
hist.ReceivedPacket(7)
Expect(hist.ranges.Len()).To(Equal(3))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
Expect(hist.ranges.Front().Next().Value).To(Equal(utils.PacketInterval{Start: 7, End: 7}))
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
})
It("creates a new range before an existing range for a belated packet", func() {
hist.ReceivedPacket(6)
hist.ReceivedPacket(4)
Expect(hist.ranges.Len()).To(Equal(2))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 6}))
})
It("extends a previous range at the end", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(7)
hist.ReceivedPacket(5)
Expect(hist.ranges.Len()).To(Equal(2))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 5}))
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 7, End: 7}))
})
It("extends a range at the front", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(7)
hist.ReceivedPacket(6)
Expect(hist.ranges.Len()).To(Equal(2))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 7}))
})
It("closes a range", func() {
hist.ReceivedPacket(6)
hist.ReceivedPacket(4)
Expect(hist.ranges.Len()).To(Equal(2))
hist.ReceivedPacket(5)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
})
It("closes a range in the middle", func() {
hist.ReceivedPacket(1)
hist.ReceivedPacket(10)
hist.ReceivedPacket(4)
hist.ReceivedPacket(6)
Expect(hist.ranges.Len()).To(Equal(4))
hist.ReceivedPacket(5)
Expect(hist.ranges.Len()).To(Equal(3))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 1, End: 1}))
Expect(hist.ranges.Front().Next().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
})
})
Context("deleting", func() {
It("does nothing when the history is empty", func() {
hist.DeleteBelow(5)
Expect(hist.ranges.Len()).To(BeZero())
})
It("deletes a range", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
hist.ReceivedPacket(10)
hist.DeleteBelow(6)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
})
It("deletes multiple ranges", func() {
hist.ReceivedPacket(1)
hist.ReceivedPacket(5)
hist.ReceivedPacket(10)
hist.DeleteBelow(8)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
})
It("adjusts a range, if packets are delete from an existing range", func() {
hist.ReceivedPacket(3)
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
hist.ReceivedPacket(6)
hist.ReceivedPacket(7)
hist.DeleteBelow(5)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 7}))
})
It("adjusts a range, if only one packet remains in the range", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
hist.ReceivedPacket(10)
hist.DeleteBelow(5)
Expect(hist.ranges.Len()).To(Equal(2))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 5}))
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
})
It("keeps a one-packet range, if deleting up to the packet directly below", func() {
hist.ReceivedPacket(4)
hist.DeleteBelow(4)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
})
Context("DoS protection", func() {
It("doesn't create more than MaxTrackedReceivedAckRanges ranges", func() {
for i := protocol.PacketNumber(1); i <= protocol.MaxTrackedReceivedAckRanges; i++ {
err := hist.ReceivedPacket(2 * i)
Expect(err).ToNot(HaveOccurred())
}
err := hist.ReceivedPacket(2*protocol.MaxTrackedReceivedAckRanges + 2)
Expect(err).To(MatchError(errTooManyOutstandingReceivedAckRanges))
})
It("doesn't consider already deleted ranges for MaxTrackedReceivedAckRanges", func() {
for i := protocol.PacketNumber(1); i <= protocol.MaxTrackedReceivedAckRanges; i++ {
err := hist.ReceivedPacket(2 * i)
Expect(err).ToNot(HaveOccurred())
}
err := hist.ReceivedPacket(2*protocol.MaxTrackedReceivedAckRanges + 2)
Expect(err).To(MatchError(errTooManyOutstandingReceivedAckRanges))
hist.DeleteBelow(protocol.MaxTrackedReceivedAckRanges) // deletes about half of the ranges
err = hist.ReceivedPacket(2*protocol.MaxTrackedReceivedAckRanges + 4)
Expect(err).ToNot(HaveOccurred())
})
})
})
Context("ACK range export", func() {
It("returns nil if there are no ranges", func() {
Expect(hist.GetAckRanges()).To(BeNil())
})
It("gets a single ACK range", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
ackRanges := hist.GetAckRanges()
Expect(ackRanges).To(HaveLen(1))
Expect(ackRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5}))
})
It("gets multiple ACK ranges", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
hist.ReceivedPacket(6)
hist.ReceivedPacket(1)
hist.ReceivedPacket(11)
hist.ReceivedPacket(10)
hist.ReceivedPacket(2)
ackRanges := hist.GetAckRanges()
Expect(ackRanges).To(HaveLen(3))
Expect(ackRanges[0]).To(Equal(wire.AckRange{Smallest: 10, Largest: 11}))
Expect(ackRanges[1]).To(Equal(wire.AckRange{Smallest: 4, Largest: 6}))
Expect(ackRanges[2]).To(Equal(wire.AckRange{Smallest: 1, Largest: 2}))
})
})
Context("Getting the highest ACK range", func() {
It("returns the zero value if there are no ranges", func() {
Expect(hist.GetHighestAckRange()).To(BeZero())
})
It("gets a single ACK range", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
Expect(hist.GetHighestAckRange()).To(Equal(wire.AckRange{Smallest: 4, Largest: 5}))
})
It("gets the highest of multiple ACK ranges", func() {
hist.ReceivedPacket(3)
hist.ReceivedPacket(6)
hist.ReceivedPacket(7)
Expect(hist.GetHighestAckRange()).To(Equal(wire.AckRange{Smallest: 6, Largest: 7}))
})
})
})

View File

@ -1,45 +0,0 @@
package ackhandler
import (
"reflect"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("retransmittable frames", func() {
for fl, el := range map[wire.Frame]bool{
&wire.AckFrame{}: false,
&wire.StopWaitingFrame{}: false,
&wire.BlockedFrame{}: true,
&wire.ConnectionCloseFrame{}: true,
&wire.GoawayFrame{}: true,
&wire.PingFrame{}: true,
&wire.RstStreamFrame{}: true,
&wire.StreamFrame{}: true,
&wire.MaxDataFrame{}: true,
&wire.MaxStreamDataFrame{}: true,
} {
f := fl
e := el
fName := reflect.ValueOf(f).Elem().Type().Name()
It("works for "+fName, func() {
Expect(IsFrameRetransmittable(f)).To(Equal(e))
})
It("stripping non-retransmittable frames works for "+fName, func() {
s := []wire.Frame{f}
if e {
Expect(stripNonRetransmittableFrames(s)).To(Equal([]wire.Frame{f}))
} else {
Expect(stripNonRetransmittableFrames(s)).To(BeEmpty())
}
})
It("HasRetransmittableFrames works for "+fName, func() {
Expect(HasRetransmittableFrames([]wire.Frame{f})).To(Equal(e))
})
}
})

View File

@ -1,18 +0,0 @@
package ackhandler
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Send Mode", func() {
It("has a string representation", func() {
Expect(SendNone.String()).To(Equal("none"))
Expect(SendAny.String()).To(Equal("any"))
Expect(SendAck.String()).To(Equal("ack"))
Expect(SendRTO.String()).To(Equal("rto"))
Expect(SendTLP.String()).To(Equal("tlp"))
Expect(SendRetransmission.String()).To(Equal("retransmission"))
Expect(SendMode(123).String()).To(Equal("invalid send mode: 123"))
})
})

View File

@ -1,297 +0,0 @@
package ackhandler
import (
"errors"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("SentPacketHistory", func() {
var hist *sentPacketHistory
expectInHistory := func(packetNumbers []protocol.PacketNumber) {
ExpectWithOffset(1, hist.packetMap).To(HaveLen(len(packetNumbers)))
ExpectWithOffset(1, hist.packetList.Len()).To(Equal(len(packetNumbers)))
i := 0
hist.Iterate(func(p *Packet) (bool, error) {
pn := packetNumbers[i]
ExpectWithOffset(1, p.PacketNumber).To(Equal(pn))
ExpectWithOffset(1, hist.packetMap[pn].Value.PacketNumber).To(Equal(pn))
i++
return true, nil
})
}
BeforeEach(func() {
hist = newSentPacketHistory()
})
It("saves sent packets", func() {
hist.SentPacket(&Packet{PacketNumber: 1})
hist.SentPacket(&Packet{PacketNumber: 3})
hist.SentPacket(&Packet{PacketNumber: 4})
expectInHistory([]protocol.PacketNumber{1, 3, 4})
})
It("gets the length", func() {
hist.SentPacket(&Packet{PacketNumber: 1})
hist.SentPacket(&Packet{PacketNumber: 10})
Expect(hist.Len()).To(Equal(2))
})
Context("getting the first outstanding packet", func() {
It("gets nil, if there are no packets", func() {
Expect(hist.FirstOutstanding()).To(BeNil())
})
It("gets the first outstanding packet", func() {
hist.SentPacket(&Packet{PacketNumber: 2})
hist.SentPacket(&Packet{PacketNumber: 3})
front := hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(2)))
})
It("gets the second packet if the first one is retransmitted", func() {
hist.SentPacket(&Packet{PacketNumber: 1, canBeRetransmitted: true})
hist.SentPacket(&Packet{PacketNumber: 3, canBeRetransmitted: true})
hist.SentPacket(&Packet{PacketNumber: 4, canBeRetransmitted: true})
front := hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1)))
// Queue the first packet for retransmission.
// The first outstanding packet should now be 3.
err := hist.MarkCannotBeRetransmitted(1)
Expect(err).ToNot(HaveOccurred())
front = hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(3)))
})
It("gets the third packet if the first two are retransmitted", func() {
hist.SentPacket(&Packet{PacketNumber: 1, canBeRetransmitted: true})
hist.SentPacket(&Packet{PacketNumber: 3, canBeRetransmitted: true})
hist.SentPacket(&Packet{PacketNumber: 4, canBeRetransmitted: true})
front := hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1)))
// Queue the second packet for retransmission.
// The first outstanding packet should still be 3.
err := hist.MarkCannotBeRetransmitted(3)
Expect(err).ToNot(HaveOccurred())
front = hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1)))
// Queue the first packet for retransmission.
// The first outstanding packet should still be 4.
err = hist.MarkCannotBeRetransmitted(1)
Expect(err).ToNot(HaveOccurred())
front = hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(4)))
})
})
It("gets a packet by packet number", func() {
p := &Packet{PacketNumber: 2}
hist.SentPacket(p)
Expect(hist.GetPacket(2)).To(Equal(p))
})
It("returns nil if the packet doesn't exist", func() {
Expect(hist.GetPacket(1337)).To(BeNil())
})
It("removes packets", func() {
hist.SentPacket(&Packet{PacketNumber: 1})
hist.SentPacket(&Packet{PacketNumber: 4})
hist.SentPacket(&Packet{PacketNumber: 8})
err := hist.Remove(4)
Expect(err).ToNot(HaveOccurred())
expectInHistory([]protocol.PacketNumber{1, 8})
})
It("errors when trying to remove a non existing packet", func() {
hist.SentPacket(&Packet{PacketNumber: 1})
err := hist.Remove(2)
Expect(err).To(MatchError("packet 2 not found in sent packet history"))
})
Context("iterating", func() {
BeforeEach(func() {
hist.SentPacket(&Packet{PacketNumber: 10})
hist.SentPacket(&Packet{PacketNumber: 14})
hist.SentPacket(&Packet{PacketNumber: 18})
})
It("iterates over all packets", func() {
var iterations []protocol.PacketNumber
err := hist.Iterate(func(p *Packet) (bool, error) {
iterations = append(iterations, p.PacketNumber)
return true, nil
})
Expect(err).ToNot(HaveOccurred())
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14, 18}))
})
It("stops iterating", func() {
var iterations []protocol.PacketNumber
err := hist.Iterate(func(p *Packet) (bool, error) {
iterations = append(iterations, p.PacketNumber)
return p.PacketNumber != 14, nil
})
Expect(err).ToNot(HaveOccurred())
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14}))
})
It("returns the error", func() {
testErr := errors.New("test error")
var iterations []protocol.PacketNumber
err := hist.Iterate(func(p *Packet) (bool, error) {
iterations = append(iterations, p.PacketNumber)
if p.PacketNumber == 14 {
return false, testErr
}
return true, nil
})
Expect(err).To(MatchError(testErr))
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14}))
})
})
Context("retransmissions", func() {
BeforeEach(func() {
for i := protocol.PacketNumber(1); i <= 5; i++ {
hist.SentPacket(&Packet{PacketNumber: i})
}
})
It("errors if the packet doesn't exist", func() {
err := hist.MarkCannotBeRetransmitted(100)
Expect(err).To(MatchError("sent packet history: packet 100 not found"))
})
It("adds a sent packets as a retransmission", func() {
hist.SentPacketsAsRetransmission([]*Packet{{PacketNumber: 13}}, 2)
expectInHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 13})
Expect(hist.GetPacket(13).isRetransmission).To(BeTrue())
Expect(hist.GetPacket(13).retransmissionOf).To(Equal(protocol.PacketNumber(2)))
Expect(hist.GetPacket(2).retransmittedAs).To(Equal([]protocol.PacketNumber{13}))
})
It("adds multiple packets sent as a retransmission", func() {
hist.SentPacketsAsRetransmission([]*Packet{{PacketNumber: 13}, {PacketNumber: 15}}, 2)
expectInHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 13, 15})
Expect(hist.GetPacket(13).isRetransmission).To(BeTrue())
Expect(hist.GetPacket(13).retransmissionOf).To(Equal(protocol.PacketNumber(2)))
Expect(hist.GetPacket(15).retransmissionOf).To(Equal(protocol.PacketNumber(2)))
Expect(hist.GetPacket(2).retransmittedAs).To(Equal([]protocol.PacketNumber{13, 15}))
})
It("adds a packet as a normal packet if the retransmitted packet doesn't exist", func() {
hist.SentPacketsAsRetransmission([]*Packet{{PacketNumber: 13}}, 7)
expectInHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 13})
Expect(hist.GetPacket(13).isRetransmission).To(BeFalse())
Expect(hist.GetPacket(13).retransmissionOf).To(BeZero())
})
})
Context("outstanding packets", func() {
It("says if it has outstanding handshake packets", func() {
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
hist.SentPacket(&Packet{
EncryptionLevel: protocol.EncryptionUnencrypted,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
})
It("says if it has outstanding packets", func() {
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
Expect(hist.HasOutstandingPackets()).To(BeFalse())
hist.SentPacket(&Packet{
EncryptionLevel: protocol.EncryptionForwardSecure,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
Expect(hist.HasOutstandingPackets()).To(BeTrue())
})
It("doesn't consider non-retransmittable packets as outstanding", func() {
hist.SentPacket(&Packet{
EncryptionLevel: protocol.EncryptionUnencrypted,
})
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
Expect(hist.HasOutstandingPackets()).To(BeFalse())
})
It("accounts for deleted handshake packets", func() {
hist.SentPacket(&Packet{
PacketNumber: 5,
EncryptionLevel: protocol.EncryptionSecure,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
err := hist.Remove(5)
Expect(err).ToNot(HaveOccurred())
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
})
It("accounts for deleted packets", func() {
hist.SentPacket(&Packet{
PacketNumber: 10,
EncryptionLevel: protocol.EncryptionForwardSecure,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingPackets()).To(BeTrue())
err := hist.Remove(10)
Expect(err).ToNot(HaveOccurred())
Expect(hist.HasOutstandingPackets()).To(BeFalse())
})
It("doesn't count handshake packets marked as non-retransmittable", func() {
hist.SentPacket(&Packet{
PacketNumber: 5,
EncryptionLevel: protocol.EncryptionUnencrypted,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
err := hist.MarkCannotBeRetransmitted(5)
Expect(err).ToNot(HaveOccurred())
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
})
It("doesn't count packets marked as non-retransmittable", func() {
hist.SentPacket(&Packet{
PacketNumber: 10,
EncryptionLevel: protocol.EncryptionForwardSecure,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingPackets()).To(BeTrue())
err := hist.MarkCannotBeRetransmitted(10)
Expect(err).ToNot(HaveOccurred())
Expect(hist.HasOutstandingPackets()).To(BeFalse())
})
It("counts the number of packets", func() {
hist.SentPacket(&Packet{
PacketNumber: 10,
EncryptionLevel: protocol.EncryptionForwardSecure,
canBeRetransmitted: true,
})
hist.SentPacket(&Packet{
PacketNumber: 11,
EncryptionLevel: protocol.EncryptionForwardSecure,
canBeRetransmitted: true,
})
err := hist.Remove(11)
Expect(err).ToNot(HaveOccurred())
Expect(hist.HasOutstandingPackets()).To(BeTrue())
err = hist.Remove(10)
Expect(err).ToNot(HaveOccurred())
Expect(hist.HasOutstandingPackets()).To(BeFalse())
})
})
})

View File

@ -1,55 +0,0 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("StopWaitingManager", func() {
var manager *stopWaitingManager
BeforeEach(func() {
manager = &stopWaitingManager{}
})
It("returns nil in the beginning", func() {
Expect(manager.GetStopWaitingFrame(false)).To(BeNil())
Expect(manager.GetStopWaitingFrame(true)).To(BeNil())
})
It("returns a StopWaitingFrame, when a new ACK arrives", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
})
It("does not decrease the LeastUnacked", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 9}}})
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
})
It("does not send the same StopWaitingFrame twice", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
Expect(manager.GetStopWaitingFrame(false)).ToNot(BeNil())
Expect(manager.GetStopWaitingFrame(false)).To(BeNil())
})
It("gets the same StopWaitingFrame twice, if forced", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
Expect(manager.GetStopWaitingFrame(false)).ToNot(BeNil())
Expect(manager.GetStopWaitingFrame(true)).ToNot(BeNil())
Expect(manager.GetStopWaitingFrame(true)).ToNot(BeNil())
})
It("increases the LeastUnacked when a retransmission is queued", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
manager.QueuedRetransmissionForPacketNumber(20)
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 21}))
})
It("does not decrease the LeastUnacked when a retransmission is queued", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
manager.QueuedRetransmissionForPacketNumber(9)
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
})
})

View File

@ -1,14 +0,0 @@
package congestion
import (
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Bandwidth", func() {
It("converts from time delta", func() {
Expect(BandwidthFromDelta(1, time.Millisecond)).To(Equal(1000 * BytesPerSecond))
})
})

View File

@ -1,13 +0,0 @@
package congestion
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestCongestion(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Congestion Suite")
}

View File

@ -1,640 +0,0 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
const initialCongestionWindowPackets = 10
const defaultWindowTCP = protocol.ByteCount(initialCongestionWindowPackets) * protocol.DefaultTCPMSS
type mockClock time.Time
func (c *mockClock) Now() time.Time {
return time.Time(*c)
}
func (c *mockClock) Advance(d time.Duration) {
*c = mockClock(time.Time(*c).Add(d))
}
const MaxCongestionWindow protocol.ByteCount = 200 * protocol.DefaultTCPMSS
var _ = Describe("Cubic Sender", func() {
var (
sender SendAlgorithmWithDebugInfo
clock mockClock
bytesInFlight protocol.ByteCount
packetNumber protocol.PacketNumber
ackedPacketNumber protocol.PacketNumber
rttStats *RTTStats
)
BeforeEach(func() {
bytesInFlight = 0
packetNumber = 1
ackedPacketNumber = 0
clock = mockClock{}
rttStats = NewRTTStats()
sender = NewCubicSender(&clock, rttStats, true /*reno*/, initialCongestionWindowPackets*protocol.DefaultTCPMSS, MaxCongestionWindow)
})
canSend := func() bool {
return bytesInFlight < sender.GetCongestionWindow()
}
SendAvailableSendWindowLen := func(packetLength protocol.ByteCount) int {
packetsSent := 0
for canSend() {
sender.OnPacketSent(clock.Now(), bytesInFlight, packetNumber, packetLength, true)
packetNumber++
packetsSent++
bytesInFlight += packetLength
}
return packetsSent
}
// Normal is that TCP acks every other segment.
AckNPackets := func(n int) {
rttStats.UpdateRTT(60*time.Millisecond, 0, clock.Now())
sender.MaybeExitSlowStart()
for i := 0; i < n; i++ {
ackedPacketNumber++
sender.OnPacketAcked(ackedPacketNumber, protocol.DefaultTCPMSS, bytesInFlight, clock.Now())
}
bytesInFlight -= protocol.ByteCount(n) * protocol.DefaultTCPMSS
clock.Advance(time.Millisecond)
}
LoseNPacketsLen := func(n int, packetLength protocol.ByteCount) {
for i := 0; i < n; i++ {
ackedPacketNumber++
sender.OnPacketLost(ackedPacketNumber, packetLength, bytesInFlight)
}
bytesInFlight -= protocol.ByteCount(n) * packetLength
}
// Does not increment acked_packet_number_.
LosePacket := func(number protocol.PacketNumber) {
sender.OnPacketLost(number, protocol.DefaultTCPMSS, bytesInFlight)
bytesInFlight -= protocol.DefaultTCPMSS
}
SendAvailableSendWindow := func() int { return SendAvailableSendWindowLen(protocol.DefaultTCPMSS) }
LoseNPackets := func(n int) { LoseNPacketsLen(n, protocol.DefaultTCPMSS) }
It("has the right values at startup", func() {
// At startup make sure we are at the default.
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
// Make sure we can send.
Expect(sender.TimeUntilSend(0)).To(BeZero())
Expect(canSend()).To(BeTrue())
// And that window is un-affected.
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
// Fill the send window with data, then verify that we can't send.
SendAvailableSendWindow()
Expect(canSend()).To(BeFalse())
})
It("paces", func() {
clock.Advance(time.Hour)
// Fill the send window with data, then verify that we can't send.
SendAvailableSendWindow()
AckNPackets(1)
delay := sender.TimeUntilSend(bytesInFlight)
Expect(delay).ToNot(BeZero())
Expect(delay).ToNot(Equal(utils.InfDuration))
})
It("application limited slow start", func() {
// Send exactly 10 packets and ensure the CWND ends at 14 packets.
const numberOfAcks = 5
// At startup make sure we can send.
Expect(sender.TimeUntilSend(0)).To(BeZero())
// Make sure we can send.
Expect(sender.TimeUntilSend(0)).To(BeZero())
SendAvailableSendWindow()
for i := 0; i < numberOfAcks; i++ {
AckNPackets(2)
}
bytesToSend := sender.GetCongestionWindow()
// It's expected 2 acks will arrive when the bytes_in_flight are greater than
// half the CWND.
Expect(bytesToSend).To(Equal(defaultWindowTCP + protocol.DefaultTCPMSS*2*2))
})
It("exponential slow start", func() {
const numberOfAcks = 20
// At startup make sure we can send.
Expect(sender.TimeUntilSend(0)).To(BeZero())
Expect(sender.BandwidthEstimate()).To(BeZero())
// Make sure we can send.
Expect(sender.TimeUntilSend(0)).To(BeZero())
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
cwnd := sender.GetCongestionWindow()
Expect(cwnd).To(Equal(defaultWindowTCP + protocol.DefaultTCPMSS*2*numberOfAcks))
Expect(sender.BandwidthEstimate()).To(Equal(BandwidthFromDelta(cwnd, rttStats.SmoothedRTT())))
})
It("slow start packet loss", func() {
sender.SetNumEmulatedConnections(1)
const numberOfAcks = 10
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
SendAvailableSendWindow()
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Lose a packet to exit slow start.
LoseNPackets(1)
packetsInRecoveryWindow := expectedSendWindow / protocol.DefaultTCPMSS
// We should now have fallen out of slow start with a reduced window.
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Recovery phase. We need to ack every packet in the recovery window before
// we exit recovery.
numberOfPacketsInWindow := expectedSendWindow / protocol.DefaultTCPMSS
AckNPackets(int(packetsInRecoveryWindow))
SendAvailableSendWindow()
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// We need to ack an entire window before we increase CWND by 1.
AckNPackets(int(numberOfPacketsInWindow) - 2)
SendAvailableSendWindow()
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Next ack should increase cwnd by 1.
AckNPackets(1)
expectedSendWindow += protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Now RTO and ensure slow start gets reset.
Expect(sender.HybridSlowStart().Started()).To(BeTrue())
sender.OnRetransmissionTimeout(true)
Expect(sender.HybridSlowStart().Started()).To(BeFalse())
})
It("slow start packet loss with large reduction", func() {
sender.SetSlowStartLargeReduction(true)
sender.SetNumEmulatedConnections(1)
const numberOfAcks = 10
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
SendAvailableSendWindow()
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Lose a packet to exit slow start. We should now have fallen out of
// slow start with a window reduced by 1.
LoseNPackets(1)
expectedSendWindow -= protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Lose 5 packets in recovery and verify that congestion window is reduced
// further.
LoseNPackets(5)
expectedSendWindow -= 5 * protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
packetsInRecoveryWindow := expectedSendWindow / protocol.DefaultTCPMSS
// Recovery phase. We need to ack every packet in the recovery window before
// we exit recovery.
numberOfPacketsInWindow := expectedSendWindow / protocol.DefaultTCPMSS
AckNPackets(int(packetsInRecoveryWindow))
SendAvailableSendWindow()
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// We need to ack the rest of the window before cwnd increases by 1.
AckNPackets(int(numberOfPacketsInWindow - 1))
SendAvailableSendWindow()
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Next ack should increase cwnd by 1.
AckNPackets(1)
expectedSendWindow += protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Now RTO and ensure slow start gets reset.
Expect(sender.HybridSlowStart().Started()).To(BeTrue())
sender.OnRetransmissionTimeout(true)
Expect(sender.HybridSlowStart().Started()).To(BeFalse())
})
It("slow start half packet loss with large reduction", func() {
sender.SetSlowStartLargeReduction(true)
sender.SetNumEmulatedConnections(1)
const numberOfAcks = 10
for i := 0; i < numberOfAcks; i++ {
// Send our full send window in half sized packets.
SendAvailableSendWindowLen(protocol.DefaultTCPMSS / 2)
AckNPackets(2)
}
SendAvailableSendWindowLen(protocol.DefaultTCPMSS / 2)
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Lose a packet to exit slow start. We should now have fallen out of
// slow start with a window reduced by 1.
LoseNPackets(1)
expectedSendWindow -= protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Lose 10 packets in recovery and verify that congestion window is reduced
// by 5 packets.
LoseNPacketsLen(10, protocol.DefaultTCPMSS/2)
expectedSendWindow -= 5 * protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
})
// this test doesn't work any more after introducing the pacing needed for QUIC
PIt("no PRR when less than one packet in flight", func() {
SendAvailableSendWindow()
LoseNPackets(int(initialCongestionWindowPackets) - 1)
AckNPackets(1)
// PRR will allow 2 packets for every ack during recovery.
Expect(SendAvailableSendWindow()).To(Equal(2))
// Simulate abandoning all packets by supplying a bytes_in_flight of 0.
// PRR should now allow a packet to be sent, even though prr's state
// variables believe it has sent enough packets.
Expect(sender.TimeUntilSend(0)).To(BeZero())
})
It("slow start packet loss PRR", func() {
sender.SetNumEmulatedConnections(1)
// Test based on the first example in RFC6937.
// Ack 10 packets in 5 acks to raise the CWND to 20, as in the example.
const numberOfAcks = 5
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
SendAvailableSendWindow()
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
LoseNPackets(1)
// We should now have fallen out of slow start with a reduced window.
sendWindowBeforeLoss := expectedSendWindow
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Testing TCP proportional rate reduction.
// We should send packets paced over the received acks for the remaining
// outstanding packets. The number of packets before we exit recovery is the
// original CWND minus the packet that has been lost and the one which
// triggered the loss.
remainingPacketsInRecovery := sendWindowBeforeLoss/protocol.DefaultTCPMSS - 2
for i := protocol.ByteCount(0); i < remainingPacketsInRecovery; i++ {
AckNPackets(1)
SendAvailableSendWindow()
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
}
// We need to ack another window before we increase CWND by 1.
numberOfPacketsInWindow := expectedSendWindow / protocol.DefaultTCPMSS
for i := protocol.ByteCount(0); i < numberOfPacketsInWindow; i++ {
AckNPackets(1)
Expect(SendAvailableSendWindow()).To(Equal(1))
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
}
AckNPackets(1)
expectedSendWindow += protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
})
It("slow start burst packet loss PRR", func() {
sender.SetNumEmulatedConnections(1)
// Test based on the second example in RFC6937, though we also implement
// forward acknowledgements, so the first two incoming acks will trigger
// PRR immediately.
// Ack 20 packets in 10 acks to raise the CWND to 30.
const numberOfAcks = 10
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
SendAvailableSendWindow()
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Lose one more than the congestion window reduction, so that after loss,
// bytes_in_flight is lesser than the congestion window.
sendWindowAfterLoss := protocol.ByteCount(renoBeta * float32(expectedSendWindow))
numPacketsToLose := (expectedSendWindow-sendWindowAfterLoss)/protocol.DefaultTCPMSS + 1
LoseNPackets(int(numPacketsToLose))
// Immediately after the loss, ensure at least one packet can be sent.
// Losses without subsequent acks can occur with timer based loss detection.
Expect(sender.TimeUntilSend(bytesInFlight)).To(BeZero())
AckNPackets(1)
// We should now have fallen out of slow start with a reduced window.
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Only 2 packets should be allowed to be sent, per PRR-SSRB
Expect(SendAvailableSendWindow()).To(Equal(2))
// Ack the next packet, which triggers another loss.
LoseNPackets(1)
AckNPackets(1)
// Send 2 packets to simulate PRR-SSRB.
Expect(SendAvailableSendWindow()).To(Equal(2))
// Ack the next packet, which triggers another loss.
LoseNPackets(1)
AckNPackets(1)
// Send 2 packets to simulate PRR-SSRB.
Expect(SendAvailableSendWindow()).To(Equal(2))
// Exit recovery and return to sending at the new rate.
for i := 0; i < numberOfAcks; i++ {
AckNPackets(1)
Expect(SendAvailableSendWindow()).To(Equal(1))
}
})
It("RTO congestion window", func() {
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
Expect(sender.SlowstartThreshold()).To(Equal(MaxCongestionWindow))
// Expect the window to decrease to the minimum once the RTO fires
// and slow start threshold to be set to 1/2 of the CWND.
sender.OnRetransmissionTimeout(true)
Expect(sender.GetCongestionWindow()).To(Equal(2 * protocol.DefaultTCPMSS))
Expect(sender.SlowstartThreshold()).To(Equal(5 * protocol.DefaultTCPMSS))
})
It("RTO congestion window no retransmission", func() {
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
// Expect the window to remain unchanged if the RTO fires but no
// packets are retransmitted.
sender.OnRetransmissionTimeout(false)
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
})
It("tcp cubic reset epoch on quiescence", func() {
const maxCongestionWindow = 50
const maxCongestionWindowBytes = maxCongestionWindow * protocol.DefaultTCPMSS
sender = NewCubicSender(&clock, rttStats, false, initialCongestionWindowPackets*protocol.DefaultTCPMSS, maxCongestionWindowBytes)
numSent := SendAvailableSendWindow()
// Make sure we fall out of slow start.
savedCwnd := sender.GetCongestionWindow()
LoseNPackets(1)
Expect(savedCwnd).To(BeNumerically(">", sender.GetCongestionWindow()))
// Ack the rest of the outstanding packets to get out of recovery.
for i := 1; i < numSent; i++ {
AckNPackets(1)
}
Expect(bytesInFlight).To(BeZero())
// Send a new window of data and ack all; cubic growth should occur.
savedCwnd = sender.GetCongestionWindow()
numSent = SendAvailableSendWindow()
for i := 0; i < numSent; i++ {
AckNPackets(1)
}
Expect(savedCwnd).To(BeNumerically("<", sender.GetCongestionWindow()))
Expect(maxCongestionWindowBytes).To(BeNumerically(">", sender.GetCongestionWindow()))
Expect(bytesInFlight).To(BeZero())
// Quiescent time of 100 seconds
clock.Advance(100 * time.Second)
// Send new window of data and ack one packet. Cubic epoch should have
// been reset; ensure cwnd increase is not dramatic.
savedCwnd = sender.GetCongestionWindow()
SendAvailableSendWindow()
AckNPackets(1)
Expect(savedCwnd).To(BeNumerically("~", sender.GetCongestionWindow(), protocol.DefaultTCPMSS))
Expect(maxCongestionWindowBytes).To(BeNumerically(">", sender.GetCongestionWindow()))
})
It("multiple losses in one window", func() {
SendAvailableSendWindow()
initialWindow := sender.GetCongestionWindow()
LosePacket(ackedPacketNumber + 1)
postLossWindow := sender.GetCongestionWindow()
Expect(initialWindow).To(BeNumerically(">", postLossWindow))
LosePacket(ackedPacketNumber + 3)
Expect(sender.GetCongestionWindow()).To(Equal(postLossWindow))
LosePacket(packetNumber - 1)
Expect(sender.GetCongestionWindow()).To(Equal(postLossWindow))
// Lose a later packet and ensure the window decreases.
LosePacket(packetNumber)
Expect(postLossWindow).To(BeNumerically(">", sender.GetCongestionWindow()))
})
It("2 connection congestion avoidance at end of recovery", func() {
sender.SetNumEmulatedConnections(2)
// Ack 10 packets in 5 acks to raise the CWND to 20.
const numberOfAcks = 5
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
SendAvailableSendWindow()
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
LoseNPackets(1)
// We should now have fallen out of slow start with a reduced window.
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * sender.RenoBeta())
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// No congestion window growth should occur in recovery phase, i.e., until the
// currently outstanding 20 packets are acked.
for i := 0; i < 10; i++ {
// Send our full send window.
SendAvailableSendWindow()
Expect(sender.InRecovery()).To(BeTrue())
AckNPackets(2)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
}
Expect(sender.InRecovery()).To(BeFalse())
// Out of recovery now. Congestion window should not grow for half an RTT.
packetsInSendWindow := expectedSendWindow / protocol.DefaultTCPMSS
SendAvailableSendWindow()
AckNPackets(int(packetsInSendWindow/2 - 2))
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Next ack should increase congestion window by 1MSS.
SendAvailableSendWindow()
AckNPackets(2)
expectedSendWindow += protocol.DefaultTCPMSS
packetsInSendWindow++
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Congestion window should remain steady again for half an RTT.
SendAvailableSendWindow()
AckNPackets(int(packetsInSendWindow/2 - 1))
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Next ack should cause congestion window to grow by 1MSS.
SendAvailableSendWindow()
AckNPackets(2)
expectedSendWindow += protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
})
It("1 connection congestion avoidance at end of recovery", func() {
sender.SetNumEmulatedConnections(1)
// Ack 10 packets in 5 acks to raise the CWND to 20.
const numberOfAcks = 5
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
SendAvailableSendWindow()
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
LoseNPackets(1)
// We should now have fallen out of slow start with a reduced window.
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// No congestion window growth should occur in recovery phase, i.e., until the
// currently outstanding 20 packets are acked.
for i := 0; i < 10; i++ {
// Send our full send window.
SendAvailableSendWindow()
Expect(sender.InRecovery()).To(BeTrue())
AckNPackets(2)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
}
Expect(sender.InRecovery()).To(BeFalse())
// Out of recovery now. Congestion window should not grow during RTT.
for i := protocol.ByteCount(0); i < expectedSendWindow/protocol.DefaultTCPMSS-2; i += 2 {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
}
// Next ack should cause congestion window to grow by 1MSS.
SendAvailableSendWindow()
AckNPackets(2)
expectedSendWindow += protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
})
It("reset after connection migration", func() {
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
Expect(sender.SlowstartThreshold()).To(Equal(MaxCongestionWindow))
// Starts with slow start.
sender.SetNumEmulatedConnections(1)
const numberOfAcks = 10
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
SendAvailableSendWindow()
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Loses a packet to exit slow start.
LoseNPackets(1)
// We should now have fallen out of slow start with a reduced window. Slow
// start threshold is also updated.
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
Expect(sender.SlowstartThreshold()).To(Equal(expectedSendWindow))
// Resets cwnd and slow start threshold on connection migrations.
sender.OnConnectionMigration()
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
Expect(sender.SlowstartThreshold()).To(Equal(MaxCongestionWindow))
Expect(sender.HybridSlowStart().Started()).To(BeFalse())
})
It("default max cwnd", func() {
sender = NewCubicSender(&clock, rttStats, true /*reno*/, initialCongestionWindowPackets*protocol.DefaultTCPMSS, protocol.DefaultMaxCongestionWindow)
defaultMaxCongestionWindowPackets := protocol.DefaultMaxCongestionWindow / protocol.DefaultTCPMSS
for i := 1; i < int(defaultMaxCongestionWindowPackets); i++ {
sender.MaybeExitSlowStart()
sender.OnPacketAcked(protocol.PacketNumber(i), 1350, sender.GetCongestionWindow(), clock.Now())
}
Expect(sender.GetCongestionWindow()).To(Equal(protocol.DefaultMaxCongestionWindow))
})
It("limit cwnd increase in congestion avoidance", func() {
// Enable Cubic.
sender = NewCubicSender(&clock, rttStats, false, initialCongestionWindowPackets*protocol.DefaultTCPMSS, MaxCongestionWindow)
numSent := SendAvailableSendWindow()
// Make sure we fall out of slow start.
savedCwnd := sender.GetCongestionWindow()
LoseNPackets(1)
Expect(savedCwnd).To(BeNumerically(">", sender.GetCongestionWindow()))
// Ack the rest of the outstanding packets to get out of recovery.
for i := 1; i < numSent; i++ {
AckNPackets(1)
}
Expect(bytesInFlight).To(BeZero())
savedCwnd = sender.GetCongestionWindow()
SendAvailableSendWindow()
// Ack packets until the CWND increases.
for sender.GetCongestionWindow() == savedCwnd {
AckNPackets(1)
SendAvailableSendWindow()
}
// Bytes in flight may be larger than the CWND if the CWND isn't an exact
// multiple of the packet sizes being sent.
Expect(bytesInFlight).To(BeNumerically(">=", sender.GetCongestionWindow()))
savedCwnd = sender.GetCongestionWindow()
// Advance time 2 seconds waiting for an ack.
clock.Advance(2 * time.Second)
// Ack two packets. The CWND should increase by only one packet.
AckNPackets(2)
Expect(sender.GetCongestionWindow()).To(Equal(savedCwnd + protocol.DefaultTCPMSS))
})
})

View File

@ -1,236 +0,0 @@
package congestion
import (
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
const numConnections uint32 = 2
const nConnectionBeta float32 = (float32(numConnections) - 1 + beta) / float32(numConnections)
const nConnectionBetaLastMax float32 = (float32(numConnections) - 1 + betaLastMax) / float32(numConnections)
const nConnectionAlpha float32 = 3 * float32(numConnections) * float32(numConnections) * (1 - nConnectionBeta) / (1 + nConnectionBeta)
const maxCubicTimeInterval = 30 * time.Millisecond
var _ = Describe("Cubic", func() {
var (
clock mockClock
cubic *Cubic
)
BeforeEach(func() {
clock = mockClock{}
cubic = NewCubic(&clock)
})
renoCwnd := func(currentCwnd protocol.ByteCount) protocol.ByteCount {
return currentCwnd + protocol.ByteCount(float32(protocol.DefaultTCPMSS)*nConnectionAlpha*float32(protocol.DefaultTCPMSS)/float32(currentCwnd))
}
cubicConvexCwnd := func(initialCwnd protocol.ByteCount, rtt, elapsedTime time.Duration) protocol.ByteCount {
offset := protocol.ByteCount((elapsedTime+rtt)/time.Microsecond) << 10 / 1000000
deltaCongestionWindow := 410 * offset * offset * offset * protocol.DefaultTCPMSS >> 40
return initialCwnd + deltaCongestionWindow
}
It("works above origin (with tighter bounds)", func() {
// Convex growth.
const rttMin = 100 * time.Millisecond
const rttMinS = float32(rttMin/time.Millisecond) / 1000.0
currentCwnd := 10 * protocol.DefaultTCPMSS
initialCwnd := currentCwnd
clock.Advance(time.Millisecond)
initialTime := clock.Now()
expectedFirstCwnd := renoCwnd(currentCwnd)
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, initialTime)
Expect(expectedFirstCwnd).To(Equal(currentCwnd))
// Normal TCP phase.
// The maximum number of expected reno RTTs can be calculated by
// finding the point where the cubic curve and the reno curve meet.
maxRenoRtts := int(math.Sqrt(float64(nConnectionAlpha/(0.4*rttMinS*rttMinS*rttMinS))) - 2)
for i := 0; i < maxRenoRtts; i++ {
// Alternatively, we expect it to increase by one, every time we
// receive current_cwnd/Alpha acks back. (This is another way of
// saying we expect cwnd to increase by approximately Alpha once
// we receive current_cwnd number ofacks back).
numAcksThisEpoch := int(float32(currentCwnd/protocol.DefaultTCPMSS) / nConnectionAlpha)
initialCwndThisEpoch := currentCwnd
for n := 0; n < numAcksThisEpoch; n++ {
// Call once per ACK.
expectedNextCwnd := renoCwnd(currentCwnd)
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
Expect(currentCwnd).To(Equal(expectedNextCwnd))
}
// Our byte-wise Reno implementation is an estimate. We expect
// the cwnd to increase by approximately one MSS every
// cwnd/kDefaultTCPMSS/Alpha acks, but it may be off by as much as
// half a packet for smaller values of current_cwnd.
cwndChangeThisEpoch := currentCwnd - initialCwndThisEpoch
Expect(cwndChangeThisEpoch).To(BeNumerically("~", protocol.DefaultTCPMSS, protocol.DefaultTCPMSS/2))
clock.Advance(100 * time.Millisecond)
}
for i := 0; i < 54; i++ {
maxAcksThisEpoch := currentCwnd / protocol.DefaultTCPMSS
interval := time.Duration(100*1000/maxAcksThisEpoch) * time.Microsecond
for n := 0; n < int(maxAcksThisEpoch); n++ {
clock.Advance(interval)
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime))
// If we allow per-ack updates, every update is a small cubic update.
Expect(currentCwnd).To(Equal(expectedCwnd))
}
}
expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime))
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
Expect(currentCwnd).To(Equal(expectedCwnd))
})
It("works above the origin with fine grained cubing", func() {
// Start the test with an artificially large cwnd to prevent Reno
// from over-taking cubic.
currentCwnd := 1000 * protocol.DefaultTCPMSS
initialCwnd := currentCwnd
rttMin := 100 * time.Millisecond
clock.Advance(time.Millisecond)
initialTime := clock.Now()
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
clock.Advance(600 * time.Millisecond)
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
// We expect the algorithm to perform only non-zero, fine-grained cubic
// increases on every ack in this case.
for i := 0; i < 100; i++ {
clock.Advance(10 * time.Millisecond)
expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime))
nextCwnd := cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
// Make sure we are performing cubic increases.
Expect(nextCwnd).To(Equal(expectedCwnd))
// Make sure that these are non-zero, less-than-packet sized increases.
Expect(nextCwnd).To(BeNumerically(">", currentCwnd))
cwndDelta := nextCwnd - currentCwnd
Expect(protocol.DefaultTCPMSS / 10).To(BeNumerically(">", cwndDelta))
currentCwnd = nextCwnd
}
})
It("handles per ack updates", func() {
// Start the test with a large cwnd and RTT, to force the first
// increase to be a cubic increase.
initialCwndPackets := 150
currentCwnd := protocol.ByteCount(initialCwndPackets) * protocol.DefaultTCPMSS
rttMin := 350 * time.Millisecond
// Initialize the epoch
clock.Advance(time.Millisecond)
// Keep track of the growth of the reno-equivalent cwnd.
rCwnd := renoCwnd(currentCwnd)
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
initialCwnd := currentCwnd
// Simulate the return of cwnd packets in less than
// MaxCubicInterval() time.
maxAcks := int(float32(initialCwndPackets) / nConnectionAlpha)
interval := maxCubicTimeInterval / time.Duration(maxAcks+1)
// In this scenario, the first increase is dictated by the cubic
// equation, but it is less than one byte, so the cwnd doesn't
// change. Normally, without per-ack increases, any cwnd plateau
// will cause the cwnd to be pinned for MaxCubicTimeInterval(). If
// we enable per-ack updates, the cwnd will continue to grow,
// regardless of the temporary plateau.
clock.Advance(interval)
rCwnd = renoCwnd(rCwnd)
Expect(cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())).To(Equal(currentCwnd))
for i := 1; i < maxAcks; i++ {
clock.Advance(interval)
nextCwnd := cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
rCwnd = renoCwnd(rCwnd)
// The window shoud increase on every ack.
Expect(nextCwnd).To(BeNumerically(">", currentCwnd))
Expect(nextCwnd).To(Equal(rCwnd))
currentCwnd = nextCwnd
}
// After all the acks are returned from the epoch, we expect the
// cwnd to have increased by nearly one packet. (Not exactly one
// packet, because our byte-wise Reno algorithm is always a slight
// under-estimation). Without per-ack updates, the current_cwnd
// would otherwise be unchanged.
minimumExpectedIncrease := protocol.DefaultTCPMSS * 9 / 10
Expect(currentCwnd).To(BeNumerically(">", initialCwnd+minimumExpectedIncrease))
})
It("handles loss events", func() {
rttMin := 100 * time.Millisecond
currentCwnd := 422 * protocol.DefaultTCPMSS
expectedCwnd := renoCwnd(currentCwnd)
// Initialize the state.
clock.Advance(time.Millisecond)
Expect(cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())).To(Equal(expectedCwnd))
// On the first loss, the last max congestion window is set to the
// congestion window before the loss.
preLossCwnd := currentCwnd
Expect(cubic.lastMaxCongestionWindow).To(BeZero())
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
Expect(cubic.lastMaxCongestionWindow).To(Equal(preLossCwnd))
currentCwnd = expectedCwnd
// On the second loss, the current congestion window has not yet
// reached the last max congestion window. The last max congestion
// window will be reduced by an additional backoff factor to allow
// for competition.
preLossCwnd = currentCwnd
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
currentCwnd = expectedCwnd
Expect(preLossCwnd).To(BeNumerically(">", cubic.lastMaxCongestionWindow))
expectedLastMax := protocol.ByteCount(float32(preLossCwnd) * nConnectionBetaLastMax)
Expect(cubic.lastMaxCongestionWindow).To(Equal(expectedLastMax))
Expect(expectedCwnd).To(BeNumerically("<", cubic.lastMaxCongestionWindow))
// Simulate an increase, and check that we are below the origin.
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
Expect(cubic.lastMaxCongestionWindow).To(BeNumerically(">", currentCwnd))
// On the final loss, simulate the condition where the congestion
// window had a chance to grow nearly to the last congestion window.
currentCwnd = cubic.lastMaxCongestionWindow - 1
preLossCwnd = currentCwnd
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
expectedLastMax = preLossCwnd
Expect(cubic.lastMaxCongestionWindow).To(Equal(expectedLastMax))
})
It("works below origin", func() {
// Concave growth.
rttMin := 100 * time.Millisecond
currentCwnd := 422 * protocol.DefaultTCPMSS
expectedCwnd := renoCwnd(currentCwnd)
// Initialize the state.
clock.Advance(time.Millisecond)
Expect(cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())).To(Equal(expectedCwnd))
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
currentCwnd = expectedCwnd
// First update after loss to initialize the epoch.
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
// Cubic phase.
for i := 0; i < 40; i++ {
clock.Advance(100 * time.Millisecond)
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
}
expectedCwnd = 553632
Expect(currentCwnd).To(Equal(expectedCwnd))
})
})

View File

@ -1,75 +0,0 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Hybrid slow start", func() {
var (
slowStart HybridSlowStart
)
BeforeEach(func() {
slowStart = HybridSlowStart{}
})
It("works in a simple case", func() {
packetNumber := protocol.PacketNumber(1)
endPacketNumber := protocol.PacketNumber(3)
slowStart.StartReceiveRound(endPacketNumber)
packetNumber++
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
// Test duplicates.
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
packetNumber++
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
packetNumber++
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue())
// Test without a new registered end_packet_number;
packetNumber++
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue())
endPacketNumber = 20
slowStart.StartReceiveRound(endPacketNumber)
for packetNumber < endPacketNumber {
packetNumber++
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
}
packetNumber++
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue())
})
It("works with delay", func() {
rtt := 60 * time.Millisecond
// We expect to detect the increase at +1/8 of the RTT; hence at a typical
// RTT of 60ms the detection will happen at 67.5 ms.
const hybridStartMinSamples = 8 // Number of acks required to trigger.
endPacketNumber := protocol.PacketNumber(1)
endPacketNumber++
slowStart.StartReceiveRound(endPacketNumber)
// Will not trigger since our lowest RTT in our burst is the same as the long
// term RTT provided.
for n := 0; n < hybridStartMinSamples; n++ {
Expect(slowStart.ShouldExitSlowStart(rtt+time.Duration(n)*time.Millisecond, rtt, 100)).To(BeFalse())
}
endPacketNumber++
slowStart.StartReceiveRound(endPacketNumber)
for n := 1; n < hybridStartMinSamples; n++ {
Expect(slowStart.ShouldExitSlowStart(rtt+(time.Duration(n)+10)*time.Millisecond, rtt, 100)).To(BeFalse())
}
// Expect to trigger since all packets in this burst was above the long term
// RTT provided.
Expect(slowStart.ShouldExitSlowStart(rtt+10*time.Millisecond, rtt, 100)).To(BeTrue())
})
})

View File

@ -1,107 +0,0 @@
package congestion
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var _ = Describe("PRR sender", func() {
var (
prr PrrSender
)
BeforeEach(func() {
prr = PrrSender{}
})
It("single loss results in send on every other ack", func() {
numPacketsInFlight := protocol.ByteCount(50)
bytesInFlight := numPacketsInFlight * protocol.DefaultTCPMSS
sshthreshAfterLoss := numPacketsInFlight / 2
congestionWindow := sshthreshAfterLoss * protocol.DefaultTCPMSS
prr.OnPacketLost(bytesInFlight)
// Ack a packet. PRR allows one packet to leave immediately.
prr.OnPacketAcked(protocol.DefaultTCPMSS)
bytesInFlight -= protocol.DefaultTCPMSS
Expect(prr.CanSend(congestionWindow, bytesInFlight, sshthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeTrue())
// Send retransmission.
prr.OnPacketSent(protocol.DefaultTCPMSS)
// PRR shouldn't allow sending any more packets.
Expect(prr.CanSend(congestionWindow, bytesInFlight, sshthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeFalse())
// One packet is lost, and one ack was consumed above. PRR now paces
// transmissions through the remaining 48 acks. PRR will alternatively
// disallow and allow a packet to be sent in response to an ack.
for i := protocol.ByteCount(0); i < sshthreshAfterLoss-1; i++ {
// Ack a packet. PRR shouldn't allow sending a packet in response.
prr.OnPacketAcked(protocol.DefaultTCPMSS)
bytesInFlight -= protocol.DefaultTCPMSS
Expect(prr.CanSend(congestionWindow, bytesInFlight, sshthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeFalse())
// Ack another packet. PRR should now allow sending a packet in response.
prr.OnPacketAcked(protocol.DefaultTCPMSS)
bytesInFlight -= protocol.DefaultTCPMSS
Expect(prr.CanSend(congestionWindow, bytesInFlight, sshthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeTrue())
// Send a packet in response.
prr.OnPacketSent(protocol.DefaultTCPMSS)
bytesInFlight += protocol.DefaultTCPMSS
}
// Since bytes_in_flight is now equal to congestion_window, PRR now maintains
// packet conservation, allowing one packet to be sent in response to an ack.
Expect(bytesInFlight).To(Equal(congestionWindow))
for i := 0; i < 10; i++ {
// Ack a packet.
prr.OnPacketAcked(protocol.DefaultTCPMSS)
bytesInFlight -= protocol.DefaultTCPMSS
Expect(prr.CanSend(congestionWindow, bytesInFlight, sshthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeTrue())
// Send a packet in response, since PRR allows it.
prr.OnPacketSent(protocol.DefaultTCPMSS)
bytesInFlight += protocol.DefaultTCPMSS
// Since bytes_in_flight is equal to the congestion_window,
// PRR disallows sending.
Expect(bytesInFlight).To(Equal(congestionWindow))
Expect(prr.CanSend(congestionWindow, bytesInFlight, sshthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeFalse())
}
})
It("burst loss results in slow start", func() {
bytesInFlight := protocol.ByteCount(20 * protocol.DefaultTCPMSS)
const numPacketsLost = 13
const ssthreshAfterLoss = 10
const congestionWindow = ssthreshAfterLoss * protocol.DefaultTCPMSS
// Lose 13 packets.
bytesInFlight -= numPacketsLost * protocol.DefaultTCPMSS
prr.OnPacketLost(bytesInFlight)
// PRR-SSRB will allow the following 3 acks to send up to 2 packets.
for i := 0; i < 3; i++ {
prr.OnPacketAcked(protocol.DefaultTCPMSS)
bytesInFlight -= protocol.DefaultTCPMSS
// PRR-SSRB should allow two packets to be sent.
for j := 0; j < 2; j++ {
Expect(prr.CanSend(congestionWindow, bytesInFlight, ssthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeTrue())
// Send a packet in response.
prr.OnPacketSent(protocol.DefaultTCPMSS)
bytesInFlight += protocol.DefaultTCPMSS
}
// PRR should allow no more than 2 packets in response to an ack.
Expect(prr.CanSend(congestionWindow, bytesInFlight, ssthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeFalse())
}
// Out of SSRB mode, PRR allows one send in response to each ack.
for i := 0; i < 10; i++ {
prr.OnPacketAcked(protocol.DefaultTCPMSS)
bytesInFlight -= protocol.DefaultTCPMSS
Expect(prr.CanSend(congestionWindow, bytesInFlight, ssthreshAfterLoss*protocol.DefaultTCPMSS)).To(BeTrue())
// Send a packet in response.
prr.OnPacketSent(protocol.DefaultTCPMSS)
bytesInFlight += protocol.DefaultTCPMSS
}
})
})

View File

@ -1,131 +0,0 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("RTT stats", func() {
var (
rttStats *RTTStats
)
BeforeEach(func() {
rttStats = NewRTTStats()
})
It("DefaultsBeforeUpdate", func() {
Expect(rttStats.MinRTT()).To(Equal(time.Duration(0)))
Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(0)))
})
It("SmoothedRTT", func() {
// Verify that ack_delay is ignored in the first measurement.
rttStats.UpdateRTT((300 * time.Millisecond), (100 * time.Millisecond), time.Time{})
Expect(rttStats.LatestRTT()).To(Equal((300 * time.Millisecond)))
Expect(rttStats.SmoothedRTT()).To(Equal((300 * time.Millisecond)))
// Verify that Smoothed RTT includes max ack delay if it's reasonable.
rttStats.UpdateRTT((350 * time.Millisecond), (50 * time.Millisecond), time.Time{})
Expect(rttStats.LatestRTT()).To(Equal((300 * time.Millisecond)))
Expect(rttStats.SmoothedRTT()).To(Equal((300 * time.Millisecond)))
// Verify that large erroneous ack_delay does not change Smoothed RTT.
rttStats.UpdateRTT((200 * time.Millisecond), (300 * time.Millisecond), time.Time{})
Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond)))
Expect(rttStats.SmoothedRTT()).To(Equal((287500 * time.Microsecond)))
})
It("SmoothedOrInitialRTT", func() {
Expect(rttStats.SmoothedOrInitialRTT()).To(Equal(defaultInitialRTT))
rttStats.UpdateRTT((300 * time.Millisecond), (100 * time.Millisecond), time.Time{})
Expect(rttStats.SmoothedOrInitialRTT()).To(Equal((300 * time.Millisecond)))
})
It("MinRTT", func() {
rttStats.UpdateRTT((200 * time.Millisecond), 0, time.Time{})
Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond)))
rttStats.UpdateRTT((10 * time.Millisecond), 0, time.Time{}.Add((10 * time.Millisecond)))
Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond)))
rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((20 * time.Millisecond)))
Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond)))
rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((30 * time.Millisecond)))
Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond)))
rttStats.UpdateRTT((50 * time.Millisecond), 0, time.Time{}.Add((40 * time.Millisecond)))
Expect(rttStats.MinRTT()).To(Equal((10 * time.Millisecond)))
// Verify that ack_delay does not go into recording of MinRTT_.
rttStats.UpdateRTT((7 * time.Millisecond), (2 * time.Millisecond), time.Time{}.Add((50 * time.Millisecond)))
Expect(rttStats.MinRTT()).To(Equal((7 * time.Millisecond)))
})
It("ExpireSmoothedMetrics", func() {
initialRtt := (10 * time.Millisecond)
rttStats.UpdateRTT(initialRtt, 0, time.Time{})
Expect(rttStats.MinRTT()).To(Equal(initialRtt))
Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt))
Expect(rttStats.MeanDeviation()).To(Equal(initialRtt / 2))
// Update once with a 20ms RTT.
doubledRtt := initialRtt * (2)
rttStats.UpdateRTT(doubledRtt, 0, time.Time{})
Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(float32(initialRtt) * 1.125)))
// Expire the smoothed metrics, increasing smoothed rtt and mean deviation.
rttStats.ExpireSmoothedMetrics()
Expect(rttStats.SmoothedRTT()).To(Equal(doubledRtt))
Expect(rttStats.MeanDeviation()).To(Equal(time.Duration(float32(initialRtt) * 0.875)))
// Now go back down to 5ms and expire the smoothed metrics, and ensure the
// mean deviation increases to 15ms.
halfRtt := initialRtt / 2
rttStats.UpdateRTT(halfRtt, 0, time.Time{})
Expect(doubledRtt).To(BeNumerically(">", rttStats.SmoothedRTT()))
Expect(initialRtt).To(BeNumerically("<", rttStats.MeanDeviation()))
})
It("UpdateRTTWithBadSendDeltas", func() {
// Make sure we ignore bad RTTs.
// base::test::MockLog log;
initialRtt := (10 * time.Millisecond)
rttStats.UpdateRTT(initialRtt, 0, time.Time{})
Expect(rttStats.MinRTT()).To(Equal(initialRtt))
Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt))
badSendDeltas := []time.Duration{
0,
utils.InfDuration,
-1000 * time.Microsecond,
}
// log.StartCapturingLogs();
for _, badSendDelta := range badSendDeltas {
// SCOPED_TRACE(Message() << "bad_send_delta = "
// << bad_send_delta.ToMicroseconds());
// EXPECT_CALL(log, Log(LOG_WARNING, _, _, _, HasSubstr("Ignoring")));
rttStats.UpdateRTT(badSendDelta, 0, time.Time{})
Expect(rttStats.MinRTT()).To(Equal(initialRtt))
Expect(rttStats.SmoothedRTT()).To(Equal(initialRtt))
}
})
It("ResetAfterConnectionMigrations", func() {
rttStats.UpdateRTT((200 * time.Millisecond), 0, time.Time{})
Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond)))
Expect(rttStats.SmoothedRTT()).To(Equal((200 * time.Millisecond)))
Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond)))
rttStats.UpdateRTT((300 * time.Millisecond), (100 * time.Millisecond), time.Time{})
Expect(rttStats.LatestRTT()).To(Equal((200 * time.Millisecond)))
Expect(rttStats.SmoothedRTT()).To(Equal((200 * time.Millisecond)))
Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond)))
// Reset rtt stats on connection migrations.
rttStats.OnConnectionMigration()
Expect(rttStats.LatestRTT()).To(Equal(time.Duration(0)))
Expect(rttStats.SmoothedRTT()).To(Equal(time.Duration(0)))
Expect(rttStats.MinRTT()).To(Equal(time.Duration(0)))
})
})

View File

@ -1,69 +0,0 @@
package crypto
import (
"crypto/rand"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("AES-GCM", func() {
var (
alice, bob AEAD
keyAlice, keyBob, ivAlice, ivBob []byte
)
BeforeEach(func() {
keyAlice = make([]byte, 16)
keyBob = make([]byte, 16)
ivAlice = make([]byte, 4)
ivBob = make([]byte, 4)
rand.Reader.Read(keyAlice)
rand.Reader.Read(keyBob)
rand.Reader.Read(ivAlice)
rand.Reader.Read(ivBob)
var err error
alice, err = NewAEADAESGCM12(keyBob, keyAlice, ivBob, ivAlice)
Expect(err).ToNot(HaveOccurred())
bob, err = NewAEADAESGCM12(keyAlice, keyBob, ivAlice, ivBob)
Expect(err).ToNot(HaveOccurred())
})
It("seals and opens", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := bob.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("seals and opens reverse", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := alice.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("has the proper length", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
Expect(b).To(HaveLen(6 + bob.Overhead()))
})
It("fails with wrong aad", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
_, err := bob.Open(nil, b, 42, []byte("aad2"))
Expect(err).To(HaveOccurred())
})
It("rejects wrong key and iv sizes", func() {
var err error
e := "AES-GCM: expected 16-byte keys and 4-byte IVs"
_, err = NewAEADAESGCM12(keyBob[1:], keyAlice, ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADAESGCM12(keyBob, keyAlice[1:], ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADAESGCM12(keyBob, keyAlice, ivBob[1:], ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADAESGCM12(keyBob, keyAlice, ivBob, ivAlice[1:])
Expect(err).To(MatchError(e))
})
})

View File

@ -1,84 +0,0 @@
package crypto
import (
"crypto/rand"
"fmt"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("AES-GCM", func() {
var (
alice, bob AEAD
keyAlice, keyBob, ivAlice, ivBob []byte
)
BeforeEach(func() {
ivAlice = make([]byte, 12)
ivBob = make([]byte, 12)
})
// 16 bytes for TLS_AES_128_GCM_SHA256
// 32 bytes for TLS_AES_256_GCM_SHA384
for _, ks := range []int{16, 32} {
keySize := ks
Context(fmt.Sprintf("with %d byte keys", keySize), func() {
BeforeEach(func() {
keyAlice = make([]byte, keySize)
keyBob = make([]byte, keySize)
rand.Reader.Read(keyAlice)
rand.Reader.Read(keyBob)
rand.Reader.Read(ivAlice)
rand.Reader.Read(ivBob)
var err error
alice, err = NewAEADAESGCM(keyBob, keyAlice, ivBob, ivAlice)
Expect(err).ToNot(HaveOccurred())
bob, err = NewAEADAESGCM(keyAlice, keyBob, ivAlice, ivBob)
Expect(err).ToNot(HaveOccurred())
})
It("seals and opens", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := bob.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("seals and opens reverse", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := alice.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("has the proper length", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
Expect(b).To(HaveLen(6 + bob.Overhead()))
})
It("fails with wrong aad", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
_, err := bob.Open(nil, b, 42, []byte("aad2"))
Expect(err).To(HaveOccurred())
})
It("rejects wrong key and iv sizes", func() {
e := "AES-GCM: expected 12 byte IVs"
var err error
_, err = NewAEADAESGCM(keyBob, keyAlice, ivBob[1:], ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADAESGCM(keyBob, keyAlice, ivBob, ivAlice[1:])
Expect(err).To(MatchError(e))
})
})
}
It("errors when an invalid key size is used", func() {
keyAlice = make([]byte, 17)
keyBob = make([]byte, 17)
_, err := NewAEADAESGCM(keyBob, keyAlice, ivBob, ivAlice)
Expect(err).To(MatchError("crypto/aes: invalid key size 17"))
})
})

View File

@ -1,51 +0,0 @@
package crypto
import (
lru "github.com/hashicorp/golang-lru"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Certificate cache", func() {
BeforeEach(func() {
var err error
compressedCertsCache, err = lru.New(2)
Expect(err).NotTo(HaveOccurred())
})
It("gives a compressed cert", func() {
chain := [][]byte{{0xde, 0xca, 0xfb, 0xad}}
expected, err := compressChain(chain, nil, nil)
Expect(err).NotTo(HaveOccurred())
compressed, err := getCompressedCert(chain, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(compressed).To(Equal(expected))
})
It("gets the same result multiple times", func() {
chain := [][]byte{{0xde, 0xca, 0xfb, 0xad}}
compressed, err := getCompressedCert(chain, nil, nil)
Expect(err).NotTo(HaveOccurred())
compressed2, err := getCompressedCert(chain, nil, nil)
Expect(err).NotTo(HaveOccurred())
Expect(compressed).To(Equal(compressed2))
})
It("stores cached values", func() {
chain := [][]byte{{0xde, 0xca, 0xfb, 0xad}}
_, err := getCompressedCert(chain, nil, nil)
Expect(err).NotTo(HaveOccurred())
Expect(compressedCertsCache.Len()).To(Equal(1))
Expect(compressedCertsCache.Contains(uint64(3838929964809501833))).To(BeTrue())
})
It("evicts old values", func() {
_, err := getCompressedCert([][]byte{{0x00}}, nil, nil)
Expect(err).NotTo(HaveOccurred())
_, err = getCompressedCert([][]byte{{0x01}}, nil, nil)
Expect(err).NotTo(HaveOccurred())
_, err = getCompressedCert([][]byte{{0x02}}, nil, nil)
Expect(err).NotTo(HaveOccurred())
Expect(compressedCertsCache.Len()).To(Equal(2))
})
})

View File

@ -1,148 +0,0 @@
package crypto
import (
"bytes"
"compress/flate"
"compress/zlib"
"crypto/tls"
"reflect"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Proof", func() {
var (
cc *certChain
config *tls.Config
cert tls.Certificate
)
BeforeEach(func() {
cert = testdata.GetCertificate()
config = &tls.Config{}
cc = NewCertChain(config).(*certChain)
})
Context("certificate compression", func() {
It("compresses certs", func() {
cert := []byte{0xde, 0xca, 0xfb, 0xad}
certZlib := &bytes.Buffer{}
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
Expect(err).ToNot(HaveOccurred())
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert)
z.Close()
kd := &certChain{
config: &tls.Config{
Certificates: []tls.Certificate{
{Certificate: [][]byte{cert}},
},
},
}
certCompressed, err := kd.GetCertsCompressed("", nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(certCompressed).To(Equal(append([]byte{
0x01, 0x00,
0x08, 0x00, 0x00, 0x00,
}, certZlib.Bytes()...)))
})
It("errors when it can't retrieve a certificate", func() {
_, err := cc.GetCertsCompressed("invalid domain", nil, nil)
Expect(err).To(MatchError(errNoMatchingCertificate))
})
})
Context("signing server configs", func() {
It("errors when it can't retrieve a certificate for the requested SNI", func() {
_, err := cc.SignServerProof("invalid", []byte("chlo"), []byte("scfg"))
Expect(err).To(MatchError(errNoMatchingCertificate))
})
It("signs the server config", func() {
config.Certificates = []tls.Certificate{cert}
proof, err := cc.SignServerProof("", []byte("chlo"), []byte("scfg"))
Expect(err).ToNot(HaveOccurred())
Expect(proof).ToNot(BeEmpty())
})
})
Context("retrieving certificates", func() {
It("errors without certificates", func() {
_, err := cc.getCertForSNI("")
Expect(err).To(MatchError(errNoMatchingCertificate))
})
It("uses first certificate in config.Certificates", func() {
config.Certificates = []tls.Certificate{cert}
cert, err := cc.getCertForSNI("")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
It("uses NameToCertificate entries", func() {
config.Certificates = []tls.Certificate{cert, cert} // two entries so the long path is used
config.NameToCertificate = map[string]*tls.Certificate{
"quic.clemente.io": &cert,
}
cert, err := cc.getCertForSNI("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
It("uses NameToCertificate entries with wildcard", func() {
config.Certificates = []tls.Certificate{cert, cert} // two entries so the long path is used
config.NameToCertificate = map[string]*tls.Certificate{
"*.clemente.io": &cert,
}
cert, err := cc.getCertForSNI("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
It("uses GetCertificate", func() {
config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
Expect(clientHello.ServerName).To(Equal("quic.clemente.io"))
return &cert, nil
}
cert, err := cc.getCertForSNI("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
It("gets leaf certificates", func() {
config.Certificates = []tls.Certificate{cert}
cert2, err := cc.GetLeafCert("")
Expect(err).ToNot(HaveOccurred())
Expect(cert2).To(Equal(cert.Certificate[0]))
})
It("errors when it can't retrieve a leaf certificate", func() {
_, err := cc.GetLeafCert("invalid domain")
Expect(err).To(MatchError(errNoMatchingCertificate))
})
It("respects GetConfigForClient", func() {
if !reflect.ValueOf(tls.Config{}).FieldByName("GetConfigForClient").IsValid() {
// Pre 1.8, we don't have to do anything
return
}
nestedConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
l := func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
Expect(chi.ServerName).To(Equal("quic.clemente.io"))
return nestedConfig, nil
}
reflect.ValueOf(config).Elem().FieldByName("GetConfigForClient").Set(reflect.ValueOf(l))
resultCert, err := cc.getCertForSNI("quic.clemente.io")
Expect(err).NotTo(HaveOccurred())
Expect(*resultCert).To(Equal(cert))
})
})
})

View File

@ -1,294 +0,0 @@
package crypto
import (
"bytes"
"compress/flate"
"compress/zlib"
"encoding/binary"
"errors"
"hash/fnv"
"github.com/lucas-clemente/quic-go-certificates"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
func byteHash(d []byte) []byte {
h := fnv.New64a()
h.Write(d)
s := h.Sum64()
res := make([]byte, 8)
binary.LittleEndian.PutUint64(res, s)
return res
}
var _ = Describe("Cert compression and decompression", func() {
var certSetsOld map[uint64]certSet
BeforeEach(func() {
certSetsOld = make(map[uint64]certSet)
for s := range certSets {
certSetsOld[s] = certSets[s]
}
})
AfterEach(func() {
certSets = certSetsOld
})
It("compresses empty", func() {
compressed, err := compressChain(nil, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(compressed).To(Equal([]byte{0}))
})
It("decompresses empty", func() {
compressed, err := compressChain(nil, nil, nil)
Expect(err).ToNot(HaveOccurred())
uncompressed, err := decompressChain(compressed)
Expect(err).ToNot(HaveOccurred())
Expect(uncompressed).To(BeEmpty())
})
It("gives correct single cert", func() {
cert := []byte{0xde, 0xca, 0xfb, 0xad}
certZlib := &bytes.Buffer{}
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
Expect(err).ToNot(HaveOccurred())
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert)
z.Close()
chain := [][]byte{cert}
compressed, err := compressChain(chain, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(compressed).To(Equal(append([]byte{
0x01, 0x00,
0x08, 0x00, 0x00, 0x00,
}, certZlib.Bytes()...)))
})
It("decompresses a single cert", func() {
cert := []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}
chain := [][]byte{cert}
compressed, err := compressChain(chain, nil, nil)
Expect(err).ToNot(HaveOccurred())
uncompressed, err := decompressChain(compressed)
Expect(err).ToNot(HaveOccurred())
Expect(uncompressed).To(Equal(chain))
})
It("gives correct cert and intermediate", func() {
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
cert2 := []byte{0xde, 0xad, 0xbe, 0xef}
certZlib := &bytes.Buffer{}
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
Expect(err).ToNot(HaveOccurred())
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert1)
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert2)
z.Close()
chain := [][]byte{cert1, cert2}
compressed, err := compressChain(chain, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(compressed).To(Equal(append([]byte{
0x01, 0x01, 0x00,
0x10, 0x00, 0x00, 0x00,
}, certZlib.Bytes()...)))
})
It("decompresses the chain with a cert and an intermediate", func() {
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
cert2 := []byte{0xde, 0xad, 0xbe, 0xef}
chain := [][]byte{cert1, cert2}
compressed, err := compressChain(chain, nil, nil)
Expect(err).ToNot(HaveOccurred())
decompressed, err := decompressChain(compressed)
Expect(err).ToNot(HaveOccurred())
Expect(decompressed).To(Equal(chain))
})
It("uses cached certificates", func() {
cert := []byte{0xde, 0xca, 0xfb, 0xad}
certHash := byteHash(cert)
chain := [][]byte{cert}
compressed, err := compressChain(chain, nil, certHash)
Expect(err).ToNot(HaveOccurred())
expected := append([]byte{0x02}, certHash...)
expected = append(expected, 0x00)
Expect(compressed).To(Equal(expected))
})
It("uses cached certificates and compressed combined", func() {
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
cert2 := []byte{0xde, 0xad, 0xbe, 0xef}
cert2Hash := byteHash(cert2)
certZlib := &bytes.Buffer{}
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, append(cert2, certDictZlib...))
Expect(err).ToNot(HaveOccurred())
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert1)
z.Close()
chain := [][]byte{cert1, cert2}
compressed, err := compressChain(chain, nil, cert2Hash)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0x01, 0x02}
expected = append(expected, cert2Hash...)
expected = append(expected, 0x00)
expected = append(expected, []byte{0x08, 0, 0, 0}...)
expected = append(expected, certZlib.Bytes()...)
Expect(compressed).To(Equal(expected))
})
It("uses common certificate sets", func() {
cert := certsets.CertSet3[42]
setHash := make([]byte, 8)
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
chain := [][]byte{cert}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0x03}
expected = append(expected, setHash...)
expected = append(expected, []byte{42, 0, 0, 0}...)
expected = append(expected, 0x00)
Expect(compressed).To(Equal(expected))
})
It("decompresses a single cert form a common certificate set", func() {
cert := certsets.CertSet3[42]
setHash := make([]byte, 8)
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
chain := [][]byte{cert}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
decompressed, err := decompressChain(compressed)
Expect(err).ToNot(HaveOccurred())
Expect(decompressed).To(Equal(chain))
})
It("decompresses multiple certs form common certificate sets", func() {
cert1 := certsets.CertSet3[42]
cert2 := certsets.CertSet2[24]
setHash := make([]byte, 16)
binary.LittleEndian.PutUint64(setHash[0:8], certsets.CertSet3Hash)
binary.LittleEndian.PutUint64(setHash[8:16], certsets.CertSet2Hash)
chain := [][]byte{cert1, cert2}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
decompressed, err := decompressChain(compressed)
Expect(err).ToNot(HaveOccurred())
Expect(decompressed).To(Equal(chain))
})
It("ignores uncommon certificate sets", func() {
cert := []byte{0xde, 0xca, 0xfb, 0xad}
setHash := make([]byte, 8)
binary.LittleEndian.PutUint64(setHash, 0xdeadbeef)
chain := [][]byte{cert}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
certZlib := &bytes.Buffer{}
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
Expect(err).ToNot(HaveOccurred())
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert)
z.Close()
Expect(compressed).To(Equal(append([]byte{
0x01, 0x00,
0x08, 0x00, 0x00, 0x00,
}, certZlib.Bytes()...)))
})
It("errors if a common set does not exist", func() {
cert := certsets.CertSet3[42]
setHash := make([]byte, 8)
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
chain := [][]byte{cert}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
delete(certSets, certsets.CertSet3Hash)
_, err = decompressChain(compressed)
Expect(err).To(MatchError(errors.New("unknown certSet")))
})
It("errors if a cert in a common set does not exist", func() {
certSet := [][]byte{
{0x1, 0x2, 0x3, 0x4},
{0x5, 0x6, 0x7, 0x8},
}
certSets[0x1337] = certSet
cert := certSet[1]
setHash := make([]byte, 8)
binary.LittleEndian.PutUint64(setHash, 0x1337)
chain := [][]byte{cert}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
certSets[0x1337] = certSet[:1] // delete the last certificate from the certSet
_, err = decompressChain(compressed)
Expect(err).To(MatchError(errors.New("certificate not found in certSet")))
})
It("uses common certificates and compressed combined", func() {
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
cert2 := certsets.CertSet3[42]
setHash := make([]byte, 8)
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
certZlib := &bytes.Buffer{}
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, append(cert2, certDictZlib...))
Expect(err).ToNot(HaveOccurred())
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert1)
z.Close()
chain := [][]byte{cert1, cert2}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0x01, 0x03}
expected = append(expected, setHash...)
expected = append(expected, []byte{42, 0, 0, 0}...)
expected = append(expected, 0x00)
expected = append(expected, []byte{0x08, 0, 0, 0}...)
expected = append(expected, certZlib.Bytes()...)
Expect(compressed).To(Equal(expected))
})
It("decompresses a certficate from a common set and a compressed cert combined", func() {
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
cert2 := certsets.CertSet3[42]
setHash := make([]byte, 8)
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
chain := [][]byte{cert1, cert2}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
decompressed, err := decompressChain(compressed)
Expect(err).ToNot(HaveOccurred())
Expect(decompressed).To(Equal(chain))
})
It("rejects invalid CCS / CCRT hashes", func() {
cert := []byte{0xde, 0xca, 0xfb, 0xad}
chain := [][]byte{cert}
_, err := compressChain(chain, []byte("foo"), nil)
Expect(err).To(MatchError("expected a multiple of 8 bytes for CCS / CCRT hashes"))
_, err = compressChain(chain, nil, []byte("foo"))
Expect(err).To(MatchError("expected a multiple of 8 bytes for CCS / CCRT hashes"))
})
Context("common certificate hashes", func() {
It("gets the hashes", func() {
ccs := getCommonCertificateHashes()
Expect(ccs).ToNot(BeEmpty())
hashes, err := splitHashes(ccs)
Expect(err).ToNot(HaveOccurred())
for _, hash := range hashes {
Expect(certSets).To(HaveKey(hash))
}
})
It("returns an empty slice if there are not common sets", func() {
certSets = make(map[uint64]certSet)
ccs := getCommonCertificateHashes()
Expect(ccs).ToNot(BeNil())
Expect(ccs).To(HaveLen(0))
})
})
})

View File

@ -1,348 +0,0 @@
package crypto
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"math/big"
"runtime"
"time"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Cert Manager", func() {
var cm *certManager
var key1, key2 *rsa.PrivateKey
var cert1, cert2 []byte
BeforeEach(func() {
var err error
cm = NewCertManager(nil).(*certManager)
key1, err = rsa.GenerateKey(rand.Reader, 768)
Expect(err).ToNot(HaveOccurred())
key2, err = rsa.GenerateKey(rand.Reader, 768)
Expect(err).ToNot(HaveOccurred())
template := &x509.Certificate{SerialNumber: big.NewInt(1)}
cert1, err = x509.CreateCertificate(rand.Reader, template, template, &key1.PublicKey, key1)
Expect(err).ToNot(HaveOccurred())
cert2, err = x509.CreateCertificate(rand.Reader, template, template, &key2.PublicKey, key2)
Expect(err).ToNot(HaveOccurred())
})
It("saves a client TLS config", func() {
tlsConf := &tls.Config{ServerName: "quic.clemente.io"}
cm = NewCertManager(tlsConf).(*certManager)
Expect(cm.config.ServerName).To(Equal("quic.clemente.io"))
})
It("errors when given invalid data", func() {
err := cm.SetData([]byte("foobar"))
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")))
})
It("gets the common certificate hashes", func() {
ccs := cm.GetCommonCertificateHashes()
Expect(ccs).ToNot(BeEmpty())
})
Context("setting the data", func() {
It("decompresses a certificate chain", func() {
chain := [][]byte{cert1, cert2}
compressed, err := compressChain(chain, nil, nil)
Expect(err).ToNot(HaveOccurred())
err = cm.SetData(compressed)
Expect(err).ToNot(HaveOccurred())
Expect(cm.chain[0].Raw).To(Equal(cert1))
Expect(cm.chain[1].Raw).To(Equal(cert2))
})
It("errors if it can't decompress the chain", func() {
err := cm.SetData([]byte("invalid data"))
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")))
})
It("errors if it can't parse a certificate", func() {
chain := [][]byte{[]byte("cert1"), []byte("cert2")}
compressed, err := compressChain(chain, nil, nil)
Expect(err).ToNot(HaveOccurred())
err = cm.SetData(compressed)
_, ok := err.(asn1.StructuralError)
Expect(ok).To(BeTrue())
})
})
Context("getting the leaf cert", func() {
It("gets it", func() {
xcert1, err := x509.ParseCertificate(cert1)
Expect(err).ToNot(HaveOccurred())
xcert2, err := x509.ParseCertificate(cert2)
Expect(err).ToNot(HaveOccurred())
cm.chain = []*x509.Certificate{xcert1, xcert2}
leafCert := cm.GetLeafCert()
Expect(leafCert).To(Equal(cert1))
})
It("returns nil if the chain hasn't been set yet", func() {
leafCert := cm.GetLeafCert()
Expect(leafCert).To(BeNil())
})
})
Context("getting the leaf cert hash", func() {
It("calculates the FVN1a 64 hash", func() {
cm.chain = make([]*x509.Certificate, 1)
cm.chain[0] = &x509.Certificate{
Raw: []byte("test fnv hash"),
}
hash, err := cm.GetLeafCertHash()
Expect(err).ToNot(HaveOccurred())
// hash calculated on http://www.nitrxgen.net/hashgen/
Expect(hash).To(Equal(uint64(0x4770f6141fa0f5ad)))
})
It("errors if the certificate chain is not loaded", func() {
_, err := cm.GetLeafCertHash()
Expect(err).To(MatchError(errNoCertificateChain))
})
})
Context("verifying the server config signature", func() {
It("returns false when the chain hasn't been set yet", func() {
valid := cm.VerifyServerProof([]byte("proof"), []byte("chlo"), []byte("scfg"))
Expect(valid).To(BeFalse())
})
It("verifies the signature", func() {
chlo := []byte("client hello")
scfg := []byte("server config data")
xcert1, err := x509.ParseCertificate(cert1)
Expect(err).ToNot(HaveOccurred())
cm.chain = []*x509.Certificate{xcert1}
proof, err := signServerProof(&tls.Certificate{PrivateKey: key1}, chlo, scfg)
Expect(err).ToNot(HaveOccurred())
valid := cm.VerifyServerProof(proof, chlo, scfg)
Expect(valid).To(BeTrue())
})
It("rejects an invalid signature", func() {
xcert1, err := x509.ParseCertificate(cert1)
Expect(err).ToNot(HaveOccurred())
cm.chain = []*x509.Certificate{xcert1}
valid := cm.VerifyServerProof([]byte("invalid proof"), []byte("chlo"), []byte("scfg"))
Expect(valid).To(BeFalse())
})
})
Context("verifying the certificate chain", func() {
generateCertificate := func(template, parent *x509.Certificate, pubKey *rsa.PublicKey, privKey *rsa.PrivateKey) *x509.Certificate {
certDER, err := x509.CreateCertificate(rand.Reader, template, parent, pubKey, privKey)
Expect(err).ToNot(HaveOccurred())
cert, err := x509.ParseCertificate(certDER)
Expect(err).ToNot(HaveOccurred())
return cert
}
getCertificate := func(template *x509.Certificate) (*rsa.PrivateKey, *x509.Certificate) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
Expect(err).ToNot(HaveOccurred())
return key, generateCertificate(template, template, &key.PublicKey, key)
}
It("accepts a valid certificate", func() {
cc := NewCertChain(testdata.GetTLSConfig()).(*certChain)
tlsCert, err := cc.getCertForSNI("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
for _, data := range tlsCert.Certificate {
var cert *x509.Certificate
cert, err = x509.ParseCertificate(data)
Expect(err).ToNot(HaveOccurred())
cm.chain = append(cm.chain, cert)
}
err = cm.Verify("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
})
It("doesn't accept an expired certificate", func() {
if runtime.GOOS == "windows" {
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
Skip("windows")
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-25 * time.Hour),
NotAfter: time.Now().Add(-time.Hour),
}
_, leafCert := getCertificate(template)
cm.chain = []*x509.Certificate{leafCert}
err := cm.Verify("")
Expect(err).To(HaveOccurred())
Expect(err.(x509.CertificateInvalidError).Reason).To(Equal(x509.Expired))
})
It("doesn't accept a certificate that is not yet valid", func() {
if runtime.GOOS == "windows" {
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
Skip("windows")
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(time.Hour),
NotAfter: time.Now().Add(25 * time.Hour),
}
_, leafCert := getCertificate(template)
cm.chain = []*x509.Certificate{leafCert}
err := cm.Verify("")
Expect(err).To(HaveOccurred())
Expect(err.(x509.CertificateInvalidError).Reason).To(Equal(x509.Expired))
})
It("doesn't accept an certificate for the wrong hostname", func() {
if runtime.GOOS == "windows" {
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
Skip("windows")
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
Subject: pkix.Name{CommonName: "google.com"},
}
_, leafCert := getCertificate(template)
cm.chain = []*x509.Certificate{leafCert}
err := cm.Verify("quic.clemente.io")
Expect(err).To(HaveOccurred())
_, ok := err.(x509.HostnameError)
Expect(ok).To(BeTrue())
})
It("errors if the chain hasn't been set yet", func() {
err := cm.Verify("example.com")
Expect(err).To(HaveOccurred())
})
// this tests relies on LetsEncrypt not being contained in the Root CAs
It("rejects valid certificate with missing certificate chain", func() {
if runtime.GOOS == "windows" {
Skip("LetsEncrypt Root CA is included in Windows")
}
cert := testdata.GetCertificate()
xcert, err := x509.ParseCertificate(cert.Certificate[0])
Expect(err).ToNot(HaveOccurred())
cm.chain = []*x509.Certificate{xcert}
err = cm.Verify("quic.clemente.io")
_, ok := err.(x509.UnknownAuthorityError)
Expect(ok).To(BeTrue())
})
It("doesn't do any certificate verification if InsecureSkipVerify is set", func() {
if runtime.GOOS == "windows" {
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
Skip("windows")
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
}
_, leafCert := getCertificate(template)
cm.config = &tls.Config{
InsecureSkipVerify: true,
}
cm.chain = []*x509.Certificate{leafCert}
err := cm.Verify("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
})
It("uses the time specified in a client TLS config", func() {
if runtime.GOOS == "windows" {
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
Skip("windows")
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-25 * time.Hour),
NotAfter: time.Now().Add(-23 * time.Hour),
Subject: pkix.Name{CommonName: "quic.clemente.io"},
}
_, leafCert := getCertificate(template)
cm.chain = []*x509.Certificate{leafCert}
cm.config = &tls.Config{
Time: func() time.Time { return time.Now().Add(-24 * time.Hour) },
}
err := cm.Verify("quic.clemente.io")
_, ok := err.(x509.UnknownAuthorityError)
Expect(ok).To(BeTrue())
})
It("rejects certificates that are expired at the time specified in a client TLS config", func() {
if runtime.GOOS == "windows" {
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
Skip("windows")
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
}
_, leafCert := getCertificate(template)
cm.chain = []*x509.Certificate{leafCert}
cm.config = &tls.Config{
Time: func() time.Time { return time.Now().Add(-24 * time.Hour) },
}
err := cm.Verify("quic.clemente.io")
Expect(err.(x509.CertificateInvalidError).Reason).To(Equal(x509.Expired))
})
It("uses the Root CA given in the client config", func() {
if runtime.GOOS == "windows" {
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
Skip("windows")
}
templateRoot := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
IsCA: true,
BasicConstraintsValid: true,
}
rootKey, rootCert := getCertificate(templateRoot)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
Subject: pkix.Name{CommonName: "google.com"},
}
key, err := rsa.GenerateKey(rand.Reader, 1024)
Expect(err).ToNot(HaveOccurred())
leafCert := generateCertificate(template, rootCert, &key.PublicKey, rootKey)
rootCAPool := x509.NewCertPool()
rootCAPool.AddCert(rootCert)
cm.chain = []*x509.Certificate{leafCert}
cm.config = &tls.Config{
RootCAs: rootCAPool,
}
err = cm.Verify("google.com")
Expect(err).ToNot(HaveOccurred())
})
})
})

View File

@ -1,71 +0,0 @@
// +build ignore
package crypto
import (
"crypto/rand"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Chacha20poly1305", func() {
var (
alice, bob AEAD
keyAlice, keyBob, ivAlice, ivBob []byte
)
BeforeEach(func() {
keyAlice = make([]byte, 32)
keyBob = make([]byte, 32)
ivAlice = make([]byte, 4)
ivBob = make([]byte, 4)
rand.Reader.Read(keyAlice)
rand.Reader.Read(keyBob)
rand.Reader.Read(ivAlice)
rand.Reader.Read(ivBob)
var err error
alice, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice)
Expect(err).ToNot(HaveOccurred())
bob, err = NewAEADChacha20Poly1305(keyAlice, keyBob, ivAlice, ivBob)
Expect(err).ToNot(HaveOccurred())
})
It("seals and opens", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := bob.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("seals and opens reverse", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := alice.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("has the proper length", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
Expect(b).To(HaveLen(6 + 12))
})
It("fails with wrong aad", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
_, err := bob.Open(nil, b, 42, []byte("aad2"))
Expect(err).To(HaveOccurred())
})
It("rejects wrong key and iv sizes", func() {
var err error
e := "chacha20poly1305: expected 32-byte keys and 4-byte IVs"
_, err = NewAEADChacha20Poly1305(keyBob[1:], keyAlice, ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice[1:], ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob[1:], ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice[1:])
Expect(err).To(MatchError(e))
})
})

View File

@ -1,13 +0,0 @@
package crypto
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestCrypto(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Crypto Suite")
}

View File

@ -1,27 +0,0 @@
package crypto
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("ProofRsa", func() {
It("works", func() {
a, err := NewCurve25519KEX()
Expect(err).ToNot(HaveOccurred())
b, err := NewCurve25519KEX()
Expect(err).ToNot(HaveOccurred())
sA, err := a.CalculateSharedKey(b.PublicKey())
Expect(err).ToNot(HaveOccurred())
sB, err := b.CalculateSharedKey(a.PublicKey())
Expect(err).ToNot(HaveOccurred())
Expect(sA).To(Equal(sB))
})
It("rejects short public keys", func() {
a, err := NewCurve25519KEX()
Expect(err).ToNot(HaveOccurred())
_, err = a.CalculateSharedKey(nil)
Expect(err).To(MatchError("Curve25519: expected public key of 32 byte"))
})
})

View File

@ -1,197 +0,0 @@
package crypto
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("QUIC Crypto Key Derivation", func() {
// Context("chacha20poly1305", func() {
// It("derives non-fs keys", func() {
// aead, err := DeriveKeysChacha20(
// protocol.Version32,
// false,
// []byte("0123456789012345678901"),
// []byte("nonce"),
// protocol.ConnectionID(42),
// []byte("chlo"),
// []byte("scfg"),
// []byte("cert"),
// nil,
// )
// Expect(err).ToNot(HaveOccurred())
// chacha := aead.(*aeadChacha20Poly1305)
// // If the IVs match, the keys will match too, since the keys are read earlier
// Expect(chacha.myIV).To(Equal([]byte{0xf0, 0xf5, 0x4c, 0xa8}))
// Expect(chacha.otherIV).To(Equal([]byte{0x75, 0xd8, 0xa2, 0x8d}))
// })
//
// It("derives fs keys", func() {
// aead, err := DeriveKeysChacha20(
// protocol.Version32,
// true,
// []byte("0123456789012345678901"),
// []byte("nonce"),
// protocol.ConnectionID(42),
// []byte("chlo"),
// []byte("scfg"),
// []byte("cert"),
// nil,
// )
// Expect(err).ToNot(HaveOccurred())
// chacha := aead.(*aeadChacha20Poly1305)
// // If the IVs match, the keys will match too, since the keys are read earlier
// Expect(chacha.myIV).To(Equal([]byte{0xf5, 0x73, 0x11, 0x79}))
// Expect(chacha.otherIV).To(Equal([]byte{0xf7, 0x26, 0x4d, 0x2c}))
// })
//
// It("does not use diversification nonces in FS key derivation", func() {
// aead, err := DeriveKeysChacha20(
// protocol.Version33,
// true,
// []byte("0123456789012345678901"),
// []byte("nonce"),
// protocol.ConnectionID(42),
// []byte("chlo"),
// []byte("scfg"),
// []byte("cert"),
// []byte("divnonce"),
// )
// Expect(err).ToNot(HaveOccurred())
// chacha := aead.(*aeadChacha20Poly1305)
// // If the IVs match, the keys will match too, since the keys are read earlier
// Expect(chacha.myIV).To(Equal([]byte{0xf5, 0x73, 0x11, 0x79}))
// Expect(chacha.otherIV).To(Equal([]byte{0xf7, 0x26, 0x4d, 0x2c}))
// })
//
// It("uses diversification nonces in initial key derivation", func() {
// aead, err := DeriveKeysChacha20(
// protocol.Version33,
// false,
// []byte("0123456789012345678901"),
// []byte("nonce"),
// protocol.ConnectionID(42),
// []byte("chlo"),
// []byte("scfg"),
// []byte("cert"),
// []byte("divnonce"),
// )
// Expect(err).ToNot(HaveOccurred())
// chacha := aead.(*aeadChacha20Poly1305)
// // If the IVs match, the keys will match too, since the keys are read earlier
// Expect(chacha.myIV).To(Equal([]byte{0xc4, 0x12, 0x25, 0x64}))
// Expect(chacha.otherIV).To(Equal([]byte{0x75, 0xd8, 0xa2, 0x8d}))
// })
// })
Context("AES-GCM", func() {
It("derives non-forward secure keys", func() {
aead, err := DeriveQuicCryptoAESKeys(
false,
[]byte("0123456789012345678901"),
[]byte("nonce"),
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
[]byte("chlo"),
[]byte("scfg"),
[]byte("cert"),
[]byte("divnonce"),
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
aesgcm := aead.(*aeadAESGCM12)
// If the IVs match, the keys will match too, since the keys are read earlier
Expect(aesgcm.myIV).To(Equal([]byte{0x1c, 0xec, 0xac, 0x9b}))
Expect(aesgcm.otherIV).To(Equal([]byte{0x64, 0xef, 0x3c, 0x9}))
})
It("uses the diversification nonce when generating non-forwared secure keys", func() {
aead1, err := DeriveQuicCryptoAESKeys(
false,
[]byte("0123456789012345678901"),
[]byte("nonce"),
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
[]byte("chlo"),
[]byte("scfg"),
[]byte("cert"),
[]byte("divnonce"),
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
aead2, err := DeriveQuicCryptoAESKeys(
false,
[]byte("0123456789012345678901"),
[]byte("nonce"),
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
[]byte("chlo"),
[]byte("scfg"),
[]byte("cert"),
[]byte("ecnonvid"),
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
aesgcm1 := aead1.(*aeadAESGCM12)
aesgcm2 := aead2.(*aeadAESGCM12)
Expect(aesgcm1.myIV).ToNot(Equal(aesgcm2.myIV))
Expect(aesgcm1.otherIV).To(Equal(aesgcm2.otherIV))
})
It("derives non-forward secure keys, for the other side", func() {
aead, err := DeriveQuicCryptoAESKeys(
false,
[]byte("0123456789012345678901"),
[]byte("nonce"),
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
[]byte("chlo"),
[]byte("scfg"),
[]byte("cert"),
[]byte("divnonce"),
protocol.PerspectiveClient,
)
Expect(err).ToNot(HaveOccurred())
aesgcm := aead.(*aeadAESGCM12)
// If the IVs match, the keys will match too, since the keys are read earlier
Expect(aesgcm.otherIV).To(Equal([]byte{0x1c, 0xec, 0xac, 0x9b}))
Expect(aesgcm.myIV).To(Equal([]byte{0x64, 0xef, 0x3c, 0x9}))
})
It("derives forward secure keys", func() {
aead, err := DeriveQuicCryptoAESKeys(
true,
[]byte("0123456789012345678901"),
[]byte("nonce"),
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
[]byte("chlo"),
[]byte("scfg"),
[]byte("cert"),
nil,
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
aesgcm := aead.(*aeadAESGCM12)
// If the IVs match, the keys will match too, since the keys are read earlier
Expect(aesgcm.myIV).To(Equal([]byte{0x7, 0xad, 0xab, 0xb8}))
Expect(aesgcm.otherIV).To(Equal([]byte{0xf2, 0x7a, 0xcc, 0x42}))
})
It("does not use div-nonce for FS key derivation", func() {
aead, err := DeriveQuicCryptoAESKeys(
true,
[]byte("0123456789012345678901"),
[]byte("nonce"),
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
[]byte("chlo"),
[]byte("scfg"),
[]byte("cert"),
[]byte("divnonce"),
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
aesgcm := aead.(*aeadAESGCM12)
// If the IVs match, the keys will match too, since the keys are read earlier
Expect(aesgcm.myIV).To(Equal([]byte{0x7, 0xad, 0xab, 0xb8}))
Expect(aesgcm.otherIV).To(Equal([]byte{0xf2, 0x7a, 0xcc, 0x42}))
})
})
})

View File

@ -1,56 +0,0 @@
package crypto
import (
"crypto"
"errors"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockTLSExporter struct {
hash crypto.Hash
computerError error
}
var _ TLSExporter = &mockTLSExporter{}
func (c *mockTLSExporter) Handshake() mint.Alert { panic("not implemented") }
func (c *mockTLSExporter) ConnectionState() mint.ConnectionState {
return mint.ConnectionState{
CipherSuite: mint.CipherSuiteParams{
Hash: c.hash,
KeyLen: 32,
IvLen: 12,
},
}
}
func (c *mockTLSExporter) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
if c.computerError != nil {
return nil, c.computerError
}
return append([]byte(label), context...), nil
}
var _ = Describe("Key Derivation", func() {
It("derives keys", func() {
clientAEAD, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256}, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
serverAEAD, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256}, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
ciphertext := clientAEAD.Seal(nil, []byte("foobar"), 0, []byte("aad"))
data, err := serverAEAD.Open(nil, ciphertext, 0, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("foobar")))
})
It("fails when computing the exporter fails", func() {
testErr := errors.New("test error")
_, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256, computerError: testErr}, protocol.PerspectiveClient)
Expect(err).To(MatchError(testErr))
})
})

View File

@ -1,86 +0,0 @@
package crypto
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("NullAEAD using AES-GCM", func() {
// values taken from https://github.com/quicwg/base-drafts/wiki/Test-Vector-for-the-Clear-Text-AEAD-key-derivation
Context("using the test vector from the QUIC WG Wiki", func() {
connID := protocol.ConnectionID([]byte{0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08})
It("computes the secrets", func() {
clientSecret, serverSecret := computeSecrets(connID)
Expect(clientSecret).To(Equal([]byte{
0x83, 0x55, 0xf2, 0x1a, 0x3d, 0x8f, 0x83, 0xec,
0xb3, 0xd0, 0xf9, 0x71, 0x08, 0xd3, 0xf9, 0x5e,
0x0f, 0x65, 0xb4, 0xd8, 0xae, 0x88, 0xa0, 0x61,
0x1e, 0xe4, 0x9d, 0xb0, 0xb5, 0x23, 0x59, 0x1d,
}))
Expect(serverSecret).To(Equal([]byte{
0xf8, 0x0e, 0x57, 0x71, 0x48, 0x4b, 0x21, 0xcd,
0xeb, 0xb5, 0xaf, 0xe0, 0xa2, 0x56, 0xa3, 0x17,
0x41, 0xef, 0xe2, 0xb5, 0xc6, 0xb6, 0x17, 0xba,
0xe1, 0xb2, 0xf1, 0x5a, 0x83, 0x04, 0x83, 0xd6,
}))
})
It("computes the client key and IV", func() {
clientSecret, _ := computeSecrets(connID)
key, iv := computeNullAEADKeyAndIV(clientSecret)
Expect(key).To(Equal([]byte{
0x3a, 0xd0, 0x54, 0x2c, 0x4a, 0x85, 0x84, 0x74,
0x00, 0x63, 0x04, 0x9e, 0x3b, 0x3c, 0xaa, 0xb2,
}))
Expect(iv).To(Equal([]byte{
0xd1, 0xfd, 0x26, 0x05, 0x42, 0x75, 0x3a, 0xba,
0x38, 0x58, 0x9b, 0xad,
}))
})
It("computes the server key and IV", func() {
_, serverSecret := computeSecrets(connID)
key, iv := computeNullAEADKeyAndIV(serverSecret)
Expect(key).To(Equal([]byte{
0xbe, 0xe4, 0xc2, 0x4d, 0x2a, 0xf1, 0x33, 0x80,
0xa9, 0xfa, 0x24, 0xa5, 0xe2, 0xba, 0x2c, 0xff,
}))
Expect(iv).To(Equal([]byte{
0x25, 0xb5, 0x8e, 0x24, 0x6d, 0x9e, 0x7d, 0x5f,
0xfe, 0x43, 0x23, 0xfe,
}))
})
})
It("seals and opens", func() {
connectionID := protocol.ConnectionID([]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef})
clientAEAD, err := newNullAEADAESGCM(connectionID, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
serverAEAD, err := newNullAEADAESGCM(connectionID, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
clientMessage := clientAEAD.Seal(nil, []byte("foobar"), 42, []byte("aad"))
m, err := serverAEAD.Open(nil, clientMessage, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(m).To(Equal([]byte("foobar")))
serverMessage := serverAEAD.Seal(nil, []byte("raboof"), 99, []byte("daa"))
m, err = clientAEAD.Open(nil, serverMessage, 99, []byte("daa"))
Expect(err).ToNot(HaveOccurred())
Expect(m).To(Equal([]byte("raboof")))
})
It("doesn't work if initialized with different connection IDs", func() {
c1 := protocol.ConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 1})
c2 := protocol.ConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 2})
clientAEAD, err := newNullAEADAESGCM(c1, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
serverAEAD, err := newNullAEADAESGCM(c2, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
clientMessage := clientAEAD.Seal(nil, []byte("foobar"), 42, []byte("aad"))
_, err = serverAEAD.Open(nil, clientMessage, 42, []byte("aad"))
Expect(err).To(MatchError("cipher: message authentication failed"))
})
})

View File

@ -1,55 +0,0 @@
package crypto
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("NullAEAD using FNV128a", func() {
aad := []byte("All human beings are born free and equal in dignity and rights.")
plainText := []byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood.")
hash36 := []byte{0x98, 0x9b, 0x33, 0x3f, 0xe8, 0xde, 0x32, 0x5c, 0xa6, 0x7f, 0x9c, 0xf7}
var aeadServer AEAD
var aeadClient AEAD
BeforeEach(func() {
aeadServer = &nullAEADFNV128a{protocol.PerspectiveServer}
aeadClient = &nullAEADFNV128a{protocol.PerspectiveClient}
})
It("seals and opens, client => server", func() {
cipherText := aeadClient.Seal(nil, plainText, 0, aad)
res, err := aeadServer.Open(nil, cipherText, 0, aad)
Expect(err).ToNot(HaveOccurred())
Expect(res).To(Equal([]byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood.")))
})
It("seals and opens, server => client", func() {
cipherText := aeadServer.Seal(nil, plainText, 0, aad)
res, err := aeadClient.Open(nil, cipherText, 0, aad)
Expect(err).ToNot(HaveOccurred())
Expect(res).To(Equal([]byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood.")))
})
It("rejects short ciphertexts", func() {
_, err := aeadServer.Open(nil, nil, 0, nil)
Expect(err).To(MatchError("NullAEAD: ciphertext cannot be less than 12 bytes long"))
})
It("seals in-place", func() {
buf := make([]byte, 6, 12+6)
copy(buf, []byte("foobar"))
res := aeadServer.Seal(buf[0:0], buf, 0, nil)
buf = buf[:12+6]
Expect(buf[12:]).To(Equal([]byte("foobar")))
Expect(res[12:]).To(Equal([]byte("foobar")))
})
It("fails", func() {
cipherText := append(append(hash36, plainText...), byte(0x42))
_, err := aeadClient.Open(nil, cipherText, 0, aad)
Expect(err).To(HaveOccurred())
})
})

View File

@ -1,17 +0,0 @@
package crypto
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("NullAEAD", func() {
It("selects the right FVN variant", func() {
connID := protocol.ConnectionID([]byte{0x42, 0, 0, 0, 0, 0, 0, 0})
Expect(NewNullAEAD(protocol.PerspectiveClient, connID, protocol.Version39)).To(Equal(&nullAEADFNV128a{
perspective: protocol.PerspectiveClient,
}))
Expect(NewNullAEAD(protocol.PerspectiveClient, connID, protocol.VersionTLS)).To(BeAssignableToTypeOf(&aeadAESGCM{}))
})
})

View File

@ -1,127 +0,0 @@
package crypto
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"math/big"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Proof", func() {
It("gives valid signatures with the key in internal/testdata", func() {
key := &testdata.GetTLSConfig().Certificates[0]
signature, err := signServerProof(key, []byte{'C', 'H', 'L', 'O'}, []byte{'S', 'C', 'F', 'G'})
Expect(err).ToNot(HaveOccurred())
// Generated with:
// ruby -e 'require "digest"; p Digest::SHA256.digest("QUIC CHLO and server config signature\x00" + "\x20\x00\x00\x00" + Digest::SHA256.digest("CHLO") + "SCFG")'
data := []byte("W\xA6\xFC\xDE\xC7\xD2>c\xE6\xB5\xF6\tq\x9E|<~1\xA33\x01\xCA=\x19\xBD\xC1\xE4\xB0\xBA\x9B\x16%")
err = rsa.VerifyPSS(key.PrivateKey.(*rsa.PrivateKey).Public().(*rsa.PublicKey), crypto.SHA256, data, signature, &rsa.PSSOptions{SaltLength: 32})
Expect(err).ToNot(HaveOccurred())
})
Context("when using RSA", func() {
generateCert := func() (*rsa.PrivateKey, *x509.Certificate) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
Expect(err).NotTo(HaveOccurred())
certTemplate := x509.Certificate{SerialNumber: big.NewInt(1)}
certDER, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &key.PublicKey, key)
Expect(err).ToNot(HaveOccurred())
cert, err := x509.ParseCertificate(certDER)
Expect(err).ToNot(HaveOccurred())
return key, cert
}
It("verifies a signature", func() {
key, cert := generateCert()
chlo := []byte("chlo")
scfg := []byte("scfg")
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
Expect(err).ToNot(HaveOccurred())
Expect(verifyServerProof(signature, cert, chlo, scfg)).To(BeTrue())
})
It("rejects invalid signatures", func() {
key, cert := generateCert()
chlo := []byte("client hello")
scfg := []byte("sever config")
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
Expect(err).ToNot(HaveOccurred())
Expect(verifyServerProof(append(signature, byte(0x99)), cert, chlo, scfg)).To(BeFalse())
Expect(verifyServerProof(signature, cert, chlo[:len(chlo)-2], scfg)).To(BeFalse())
Expect(verifyServerProof(signature, cert, chlo, scfg[:len(scfg)-2])).To(BeFalse())
})
})
Context("when using ECDSA", func() {
generateCert := func() (*ecdsa.PrivateKey, *x509.Certificate) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
Expect(err).NotTo(HaveOccurred())
certTemplate := x509.Certificate{SerialNumber: big.NewInt(1)}
certDER, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &key.PublicKey, key)
Expect(err).ToNot(HaveOccurred())
cert, err := x509.ParseCertificate(certDER)
Expect(err).ToNot(HaveOccurred())
return key, cert
}
It("gives valid signatures", func() {
key, _ := generateCert()
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, []byte{'C', 'H', 'L', 'O'}, []byte{'S', 'C', 'F', 'G'})
Expect(err).ToNot(HaveOccurred())
// Generated with:
// ruby -e 'require "digest"; p Digest::SHA256.digest("QUIC CHLO and server config signature\x00" + "\x20\x00\x00\x00" + Digest::SHA256.digest("CHLO") + "SCFG")'
data := []byte("W\xA6\xFC\xDE\xC7\xD2>c\xE6\xB5\xF6\tq\x9E|<~1\xA33\x01\xCA=\x19\xBD\xC1\xE4\xB0\xBA\x9B\x16%")
s := &ecdsaSignature{}
_, err = asn1.Unmarshal(signature, s)
Expect(err).NotTo(HaveOccurred())
b := ecdsa.Verify(key.Public().(*ecdsa.PublicKey), data, s.R, s.S)
Expect(b).To(BeTrue())
})
It("verifies a signature", func() {
key, cert := generateCert()
chlo := []byte("chlo")
scfg := []byte("server config")
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
Expect(err).ToNot(HaveOccurred())
Expect(verifyServerProof(signature, cert, chlo, scfg)).To(BeTrue())
})
It("rejects invalid signatures", func() {
key, cert := generateCert()
chlo := []byte("client hello")
scfg := []byte("server config")
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
Expect(err).ToNot(HaveOccurred())
Expect(verifyServerProof(append(signature, byte(0x99)), cert, chlo, scfg)).To(BeFalse())
Expect(verifyServerProof(signature, cert, chlo[:len(chlo)-2], scfg)).To(BeFalse())
Expect(verifyServerProof(signature, cert, chlo, scfg[:len(scfg)-2])).To(BeFalse())
})
It("rejects signatures generated with a different certificate", func() {
key1, cert1 := generateCert()
key2, cert2 := generateCert()
Expect(key1.PublicKey).ToNot(Equal(key2))
Expect(cert1.Equal(cert2)).To(BeFalse())
chlo := []byte("chlo")
scfg := []byte("sfcg")
signature, err := signServerProof(&tls.Certificate{PrivateKey: key1}, chlo, scfg)
Expect(err).ToNot(HaveOccurred())
Expect(verifyServerProof(signature, cert2, chlo, scfg)).To(BeFalse())
})
})
})

View File

@ -1,235 +0,0 @@
package flowcontrol
import (
"os"
"strconv"
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// on the CIs, the timing is a lot less precise, so scale every duration by this factor
func scaleDuration(t time.Duration) time.Duration {
scaleFactor := 1
if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set
scaleFactor = f
}
Expect(scaleFactor).ToNot(BeZero())
return time.Duration(scaleFactor) * t
}
var _ = Describe("Base Flow controller", func() {
var controller *baseFlowController
BeforeEach(func() {
controller = &baseFlowController{}
controller.rttStats = &congestion.RTTStats{}
})
Context("send flow control", func() {
It("adds bytes sent", func() {
controller.bytesSent = 5
controller.AddBytesSent(6)
Expect(controller.bytesSent).To(Equal(protocol.ByteCount(5 + 6)))
})
It("gets the size of the remaining flow control window", func() {
controller.bytesSent = 5
controller.sendWindow = 12
Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(12 - 5)))
})
It("updates the size of the flow control window", func() {
controller.AddBytesSent(5)
controller.UpdateSendWindow(15)
Expect(controller.sendWindow).To(Equal(protocol.ByteCount(15)))
Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(15 - 5)))
})
It("says that the window size is 0 if we sent more than we were allowed to", func() {
controller.AddBytesSent(15)
controller.UpdateSendWindow(10)
Expect(controller.sendWindowSize()).To(BeZero())
})
It("does not decrease the flow control window", func() {
controller.UpdateSendWindow(20)
Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20)))
controller.UpdateSendWindow(10)
Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20)))
})
It("says when it's blocked", func() {
controller.UpdateSendWindow(100)
Expect(controller.IsNewlyBlocked()).To(BeFalse())
controller.AddBytesSent(100)
blocked, offset := controller.IsNewlyBlocked()
Expect(blocked).To(BeTrue())
Expect(offset).To(Equal(protocol.ByteCount(100)))
})
It("doesn't say that it's newly blocked multiple times for the same offset", func() {
controller.UpdateSendWindow(100)
controller.AddBytesSent(100)
newlyBlocked, offset := controller.IsNewlyBlocked()
Expect(newlyBlocked).To(BeTrue())
Expect(offset).To(Equal(protocol.ByteCount(100)))
newlyBlocked, _ = controller.IsNewlyBlocked()
Expect(newlyBlocked).To(BeFalse())
controller.UpdateSendWindow(150)
controller.AddBytesSent(150)
newlyBlocked, _ = controller.IsNewlyBlocked()
Expect(newlyBlocked).To(BeTrue())
})
})
Context("receive flow control", func() {
var (
receiveWindow protocol.ByteCount = 10000
receiveWindowSize protocol.ByteCount = 1000
)
BeforeEach(func() {
controller.bytesRead = receiveWindow - receiveWindowSize
controller.receiveWindow = receiveWindow
controller.receiveWindowSize = receiveWindowSize
})
It("adds bytes read", func() {
controller.bytesRead = 5
controller.AddBytesRead(6)
Expect(controller.bytesRead).To(Equal(protocol.ByteCount(5 + 6)))
})
It("triggers a window update when necessary", func() {
bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold + 1 // consumed 1 byte more than the threshold
bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed)
readPosition := receiveWindow - bytesRemaining
controller.bytesRead = readPosition
offset := controller.getWindowUpdate()
Expect(offset).To(Equal(readPosition + receiveWindowSize))
Expect(controller.receiveWindow).To(Equal(readPosition + receiveWindowSize))
})
It("doesn't trigger a window update when not necessary", func() {
bytesConsumed := float64(receiveWindowSize)*protocol.WindowUpdateThreshold - 1 // consumed 1 byte less than the threshold
bytesRemaining := receiveWindowSize - protocol.ByteCount(bytesConsumed)
readPosition := receiveWindow - bytesRemaining
controller.bytesRead = readPosition
offset := controller.getWindowUpdate()
Expect(offset).To(BeZero())
})
Context("receive window size auto-tuning", func() {
var oldWindowSize protocol.ByteCount
BeforeEach(func() {
oldWindowSize = controller.receiveWindowSize
controller.maxReceiveWindowSize = 5000
})
// update the congestion such that it returns a given value for the smoothed RTT
setRtt := func(t time.Duration) {
controller.rttStats.UpdateRTT(t, 0, time.Now())
Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked
}
It("doesn't increase the window size for a new stream", func() {
controller.maybeAdjustWindowSize()
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
})
It("doesn't increase the window size when no RTT estimate is available", func() {
setRtt(0)
controller.startNewAutoTuningEpoch()
controller.AddBytesRead(400)
offset := controller.getWindowUpdate()
Expect(offset).ToNot(BeZero()) // make sure a window update is sent
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
})
It("increases the window size if read so fast that the window would be consumed in less than 4 RTTs", func() {
bytesRead := controller.bytesRead
rtt := scaleDuration(20 * time.Millisecond)
setRtt(rtt)
// consume more than 2/3 of the window...
dataRead := receiveWindowSize*2/3 + 1
// ... in 4*2/3 of the RTT
controller.epochStartOffset = controller.bytesRead
controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3)
controller.AddBytesRead(dataRead)
offset := controller.getWindowUpdate()
Expect(offset).ToNot(BeZero())
// check that the window size was increased
newWindowSize := controller.receiveWindowSize
Expect(newWindowSize).To(Equal(2 * oldWindowSize))
// check that the new window size was used to increase the offset
Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + newWindowSize)))
})
It("doesn't increase the window size if data is read so fast that the window would be consumed in less than 4 RTTs, but less than half the window has been read", func() {
// this test only makes sense if a window update is triggered before half of the window has been consumed
Expect(protocol.WindowUpdateThreshold).To(BeNumerically(">", 1/3))
bytesRead := controller.bytesRead
rtt := scaleDuration(20 * time.Millisecond)
setRtt(rtt)
// consume more than 2/3 of the window...
dataRead := receiveWindowSize*1/3 + 1
// ... in 4*2/3 of the RTT
controller.epochStartOffset = controller.bytesRead
controller.epochStartTime = time.Now().Add(-rtt * 4 * 1 / 3)
controller.AddBytesRead(dataRead)
offset := controller.getWindowUpdate()
Expect(offset).ToNot(BeZero())
// check that the window size was not increased
newWindowSize := controller.receiveWindowSize
Expect(newWindowSize).To(Equal(oldWindowSize))
// check that the new window size was used to increase the offset
Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + newWindowSize)))
})
It("doesn't increase the window size if read too slowly", func() {
bytesRead := controller.bytesRead
rtt := scaleDuration(20 * time.Millisecond)
setRtt(rtt)
// consume less than 2/3 of the window...
dataRead := receiveWindowSize*2/3 - 1
// ... in 4*2/3 of the RTT
controller.epochStartOffset = controller.bytesRead
controller.epochStartTime = time.Now().Add(-rtt * 4 * 2 / 3)
controller.AddBytesRead(dataRead)
offset := controller.getWindowUpdate()
Expect(offset).ToNot(BeZero())
// check that the window size was not increased
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
// check that the new window size was used to increase the offset
Expect(offset).To(Equal(protocol.ByteCount(bytesRead + dataRead + oldWindowSize)))
})
It("doesn't increase the window size to a value higher than the maxReceiveWindowSize", func() {
resetEpoch := func() {
// make sure the next call to maybeAdjustWindowSize will increase the window
controller.epochStartTime = time.Now().Add(-time.Millisecond)
controller.epochStartOffset = controller.bytesRead
controller.AddBytesRead(controller.receiveWindowSize/2 + 1)
}
setRtt(scaleDuration(20 * time.Millisecond))
resetEpoch()
controller.maybeAdjustWindowSize()
Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize)) // 2000
// because the lastWindowUpdateTime is updated by MaybeTriggerWindowUpdate(), we can just call maybeAdjustWindowSize() multiple times and get an increase of the window size every time
resetEpoch()
controller.maybeAdjustWindowSize()
Expect(controller.receiveWindowSize).To(Equal(2 * 2 * oldWindowSize)) // 4000
resetEpoch()
controller.maybeAdjustWindowSize()
Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000
controller.maybeAdjustWindowSize()
Expect(controller.receiveWindowSize).To(Equal(controller.maxReceiveWindowSize)) // 5000
})
})
})
})

View File

@ -1,133 +0,0 @@
package flowcontrol
import (
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Connection Flow controller", func() {
var (
controller *connectionFlowController
queuedWindowUpdate bool
)
// update the congestion such that it returns a given value for the smoothed RTT
setRtt := func(t time.Duration) {
controller.rttStats.UpdateRTT(t, 0, time.Now())
Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked
}
BeforeEach(func() {
controller = &connectionFlowController{}
controller.rttStats = &congestion.RTTStats{}
controller.logger = utils.DefaultLogger
controller.queueWindowUpdate = func() { queuedWindowUpdate = true }
})
Context("Constructor", func() {
rttStats := &congestion.RTTStats{}
It("sets the send and receive windows", func() {
receiveWindow := protocol.ByteCount(2000)
maxReceiveWindow := protocol.ByteCount(3000)
fc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, nil, rttStats, utils.DefaultLogger).(*connectionFlowController)
Expect(fc.receiveWindow).To(Equal(receiveWindow))
Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow))
})
})
Context("receive flow control", func() {
It("increases the highestReceived by a given window size", func() {
controller.highestReceived = 1337
controller.IncrementHighestReceived(123)
Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337 + 123)))
})
Context("getting window updates", func() {
BeforeEach(func() {
controller.receiveWindow = 100
controller.receiveWindowSize = 60
controller.maxReceiveWindowSize = 1000
controller.bytesRead = 100 - 60
})
It("queues window updates", func() {
controller.MaybeQueueWindowUpdate()
Expect(queuedWindowUpdate).To(BeFalse())
controller.AddBytesRead(30)
controller.MaybeQueueWindowUpdate()
Expect(queuedWindowUpdate).To(BeTrue())
Expect(controller.GetWindowUpdate()).ToNot(BeZero())
queuedWindowUpdate = false
controller.MaybeQueueWindowUpdate()
Expect(queuedWindowUpdate).To(BeFalse())
})
It("gets a window update", func() {
windowSize := controller.receiveWindowSize
oldOffset := controller.bytesRead
dataRead := windowSize/2 - 1 // make sure not to trigger auto-tuning
controller.AddBytesRead(dataRead)
offset := controller.GetWindowUpdate()
Expect(offset).To(Equal(protocol.ByteCount(oldOffset + dataRead + 60)))
})
It("autotunes the window", func() {
oldOffset := controller.bytesRead
oldWindowSize := controller.receiveWindowSize
rtt := scaleDuration(20 * time.Millisecond)
setRtt(rtt)
controller.epochStartTime = time.Now().Add(-time.Millisecond)
controller.epochStartOffset = oldOffset
dataRead := oldWindowSize/2 + 1
controller.AddBytesRead(dataRead)
offset := controller.GetWindowUpdate()
newWindowSize := controller.receiveWindowSize
Expect(newWindowSize).To(Equal(2 * oldWindowSize))
Expect(offset).To(Equal(protocol.ByteCount(oldOffset + dataRead + newWindowSize)))
})
})
})
Context("setting the minimum window size", func() {
var (
oldWindowSize protocol.ByteCount
receiveWindow protocol.ByteCount = 10000
receiveWindowSize protocol.ByteCount = 1000
)
BeforeEach(func() {
controller.receiveWindow = receiveWindow
controller.receiveWindowSize = receiveWindowSize
oldWindowSize = controller.receiveWindowSize
controller.maxReceiveWindowSize = 3000
})
It("sets the minimum window window size", func() {
controller.EnsureMinimumWindowSize(1800)
Expect(controller.receiveWindowSize).To(Equal(protocol.ByteCount(1800)))
})
It("doesn't reduce the window window size", func() {
controller.EnsureMinimumWindowSize(1)
Expect(controller.receiveWindowSize).To(Equal(oldWindowSize))
})
It("doens't increase the window size beyond the maxReceiveWindowSize", func() {
max := controller.maxReceiveWindowSize
controller.EnsureMinimumWindowSize(2 * max)
Expect(controller.receiveWindowSize).To(Equal(max))
})
It("starts a new epoch after the window size was increased", func() {
controller.EnsureMinimumWindowSize(1912)
Expect(controller.epochStartTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond))
})
})
})

View File

@ -1,24 +0,0 @@
package flowcontrol
import (
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestCrypto(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "FlowControl Suite")
}
var mockCtrl *gomock.Controller
var _ = BeforeEach(func() {
mockCtrl = gomock.NewController(GinkgoT())
})
var _ = AfterEach(func() {
mockCtrl.Finish()
})

View File

@ -1,290 +0,0 @@
package flowcontrol
import (
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Stream Flow controller", func() {
var (
controller *streamFlowController
queuedWindowUpdate bool
queuedConnWindowUpdate bool
)
BeforeEach(func() {
queuedWindowUpdate = false
queuedConnWindowUpdate = false
rttStats := &congestion.RTTStats{}
controller = &streamFlowController{
streamID: 10,
connection: NewConnectionFlowController(1000, 1000, func() { queuedConnWindowUpdate = true }, rttStats, utils.DefaultLogger).(*connectionFlowController),
}
controller.maxReceiveWindowSize = 10000
controller.rttStats = rttStats
controller.logger = utils.DefaultLogger
controller.queueWindowUpdate = func() { queuedWindowUpdate = true }
})
Context("Constructor", func() {
rttStats := &congestion.RTTStats{}
receiveWindow := protocol.ByteCount(2000)
maxReceiveWindow := protocol.ByteCount(3000)
sendWindow := protocol.ByteCount(4000)
It("sets the send and receive windows", func() {
cc := NewConnectionFlowController(0, 0, nil, nil, utils.DefaultLogger)
fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, nil, rttStats, utils.DefaultLogger).(*streamFlowController)
Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
Expect(fc.receiveWindow).To(Equal(receiveWindow))
Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow))
Expect(fc.sendWindow).To(Equal(sendWindow))
Expect(fc.contributesToConnection).To(BeTrue())
})
It("queues window updates with the correction stream ID", func() {
var queued bool
queueWindowUpdate := func(id protocol.StreamID) {
Expect(id).To(Equal(protocol.StreamID(5)))
queued = true
}
cc := NewConnectionFlowController(0, 0, nil, nil, utils.DefaultLogger)
fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController)
fc.AddBytesRead(receiveWindow)
fc.MaybeQueueWindowUpdate()
Expect(queued).To(BeTrue())
})
})
Context("receiving data", func() {
Context("registering received offsets", func() {
var receiveWindow protocol.ByteCount = 10000
var receiveWindowSize protocol.ByteCount = 600
BeforeEach(func() {
controller.receiveWindow = receiveWindow
controller.receiveWindowSize = receiveWindowSize
})
It("updates the highestReceived", func() {
controller.highestReceived = 1337
err := controller.UpdateHighestReceived(1338, false)
Expect(err).ToNot(HaveOccurred())
Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1338)))
})
It("informs the connection flow controller about received data", func() {
controller.highestReceived = 10
controller.contributesToConnection = true
controller.connection.(*connectionFlowController).highestReceived = 100
err := controller.UpdateHighestReceived(20, false)
Expect(err).ToNot(HaveOccurred())
Expect(controller.connection.(*connectionFlowController).highestReceived).To(Equal(protocol.ByteCount(100 + 10)))
})
It("doesn't informs the connection flow controller about received data if it doesn't contribute", func() {
controller.highestReceived = 10
controller.connection.(*connectionFlowController).highestReceived = 100
err := controller.UpdateHighestReceived(20, false)
Expect(err).ToNot(HaveOccurred())
Expect(controller.connection.(*connectionFlowController).highestReceived).To(Equal(protocol.ByteCount(100)))
})
It("does not decrease the highestReceived", func() {
controller.highestReceived = 1337
err := controller.UpdateHighestReceived(1000, false)
Expect(err).ToNot(HaveOccurred())
Expect(controller.highestReceived).To(Equal(protocol.ByteCount(1337)))
})
It("does nothing when setting the same byte offset", func() {
controller.highestReceived = 1337
err := controller.UpdateHighestReceived(1337, false)
Expect(err).ToNot(HaveOccurred())
})
It("does not give a flow control violation when using the window completely", func() {
err := controller.UpdateHighestReceived(receiveWindow, false)
Expect(err).ToNot(HaveOccurred())
})
It("detects a flow control violation", func() {
err := controller.UpdateHighestReceived(receiveWindow+1, false)
Expect(err).To(MatchError("FlowControlReceivedTooMuchData: Received 10001 bytes on stream 10, allowed 10000 bytes"))
})
It("accepts a final offset higher than the highest received", func() {
controller.highestReceived = 100
err := controller.UpdateHighestReceived(101, true)
Expect(err).ToNot(HaveOccurred())
Expect(controller.highestReceived).To(Equal(protocol.ByteCount(101)))
})
It("errors when receiving a final offset smaller than the highest offset received so far", func() {
controller.highestReceived = 100
err := controller.UpdateHighestReceived(99, true)
Expect(err).To(MatchError(qerr.StreamDataAfterTermination))
})
It("accepts delayed data after receiving a final offset", func() {
err := controller.UpdateHighestReceived(300, true)
Expect(err).ToNot(HaveOccurred())
err = controller.UpdateHighestReceived(250, false)
Expect(err).ToNot(HaveOccurred())
})
It("errors when receiving a higher offset after receiving a final offset", func() {
err := controller.UpdateHighestReceived(200, true)
Expect(err).ToNot(HaveOccurred())
err = controller.UpdateHighestReceived(250, false)
Expect(err).To(MatchError(qerr.StreamDataAfterTermination))
})
It("accepts duplicate final offsets", func() {
err := controller.UpdateHighestReceived(200, true)
Expect(err).ToNot(HaveOccurred())
err = controller.UpdateHighestReceived(200, true)
Expect(err).ToNot(HaveOccurred())
Expect(controller.highestReceived).To(Equal(protocol.ByteCount(200)))
})
It("errors when receiving inconsistent final offsets", func() {
err := controller.UpdateHighestReceived(200, true)
Expect(err).ToNot(HaveOccurred())
err = controller.UpdateHighestReceived(201, true)
Expect(err).To(MatchError("StreamDataAfterTermination: Received inconsistent final offset for stream 10 (old: 200, new: 201 bytes)"))
})
})
Context("registering data read", func() {
It("saves when data is read, on a stream not contributing to the connection", func() {
controller.AddBytesRead(100)
Expect(controller.bytesRead).To(Equal(protocol.ByteCount(100)))
Expect(controller.connection.(*connectionFlowController).bytesRead).To(BeZero())
})
It("saves when data is read, on a stream not contributing to the connection", func() {
controller.contributesToConnection = true
controller.AddBytesRead(200)
Expect(controller.bytesRead).To(Equal(protocol.ByteCount(200)))
Expect(controller.connection.(*connectionFlowController).bytesRead).To(Equal(protocol.ByteCount(200)))
})
})
Context("generating window updates", func() {
var oldWindowSize protocol.ByteCount
// update the congestion such that it returns a given value for the smoothed RTT
setRtt := func(t time.Duration) {
controller.rttStats.UpdateRTT(t, 0, time.Now())
Expect(controller.rttStats.SmoothedRTT()).To(Equal(t)) // make sure it worked
}
BeforeEach(func() {
controller.receiveWindow = 100
controller.receiveWindowSize = 60
controller.bytesRead = 100 - 60
controller.connection.(*connectionFlowController).receiveWindow = 100
controller.connection.(*connectionFlowController).receiveWindowSize = 120
oldWindowSize = controller.receiveWindowSize
})
It("queues window updates", func() {
controller.MaybeQueueWindowUpdate()
Expect(queuedWindowUpdate).To(BeFalse())
controller.AddBytesRead(30)
controller.MaybeQueueWindowUpdate()
Expect(queuedWindowUpdate).To(BeTrue())
Expect(controller.GetWindowUpdate()).ToNot(BeZero())
queuedWindowUpdate = false
controller.MaybeQueueWindowUpdate()
Expect(queuedWindowUpdate).To(BeFalse())
})
It("queues connection-level window updates", func() {
controller.contributesToConnection = true
controller.MaybeQueueWindowUpdate()
Expect(queuedConnWindowUpdate).To(BeFalse())
controller.AddBytesRead(60)
controller.MaybeQueueWindowUpdate()
Expect(queuedConnWindowUpdate).To(BeTrue())
})
It("tells the connection flow controller when the window was autotuned", func() {
oldOffset := controller.bytesRead
controller.contributesToConnection = true
setRtt(scaleDuration(20 * time.Millisecond))
controller.epochStartOffset = oldOffset
controller.epochStartTime = time.Now().Add(-time.Millisecond)
controller.AddBytesRead(55)
offset := controller.GetWindowUpdate()
Expect(offset).To(Equal(protocol.ByteCount(oldOffset + 55 + 2*oldWindowSize)))
Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize))
Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(float64(controller.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)))
})
It("doesn't tell the connection flow controller if it doesn't contribute", func() {
oldOffset := controller.bytesRead
controller.contributesToConnection = false
setRtt(scaleDuration(20 * time.Millisecond))
controller.epochStartOffset = oldOffset
controller.epochStartTime = time.Now().Add(-time.Millisecond)
controller.AddBytesRead(55)
offset := controller.GetWindowUpdate()
Expect(offset).ToNot(BeZero())
Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize))
Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(2 * oldWindowSize))) // unchanged
})
It("doesn't increase the window after a final offset was already received", func() {
controller.AddBytesRead(30)
err := controller.UpdateHighestReceived(90, true)
Expect(err).ToNot(HaveOccurred())
controller.MaybeQueueWindowUpdate()
Expect(queuedWindowUpdate).To(BeFalse())
offset := controller.GetWindowUpdate()
Expect(offset).To(BeZero())
})
})
})
Context("sending data", func() {
It("gets the size of the send window", func() {
controller.UpdateSendWindow(15)
controller.AddBytesSent(5)
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(10)))
})
It("doesn't care about the connection-level window, if it doesn't contribute", func() {
controller.UpdateSendWindow(15)
controller.connection.UpdateSendWindow(1)
controller.AddBytesSent(5)
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(10)))
})
It("makes sure that it doesn't overflow the connection-level window", func() {
controller.contributesToConnection = true
controller.connection.UpdateSendWindow(12)
controller.UpdateSendWindow(20)
controller.AddBytesSent(10)
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(2)))
})
It("doesn't say that it's blocked, if only the connection is blocked", func() {
controller.contributesToConnection = true
controller.connection.UpdateSendWindow(50)
controller.UpdateSendWindow(100)
controller.AddBytesSent(50)
blocked, _ := controller.connection.IsNewlyBlocked()
Expect(blocked).To(BeTrue())
Expect(controller.IsNewlyBlocked()).To(BeFalse())
})
})
})

View File

@ -1,111 +0,0 @@
package handshake
import (
"encoding/asn1"
"net"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Cookie Generator", func() {
var cookieGen *CookieGenerator
BeforeEach(func() {
var err error
cookieGen, err = NewCookieGenerator()
Expect(err).ToNot(HaveOccurred())
})
It("generates a Cookie", func() {
ip := net.IPv4(127, 0, 0, 1)
token, err := cookieGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337})
Expect(err).ToNot(HaveOccurred())
Expect(token).ToNot(BeEmpty())
})
It("works with nil tokens", func() {
cookie, err := cookieGen.DecodeToken(nil)
Expect(err).ToNot(HaveOccurred())
Expect(cookie).To(BeNil())
})
It("accepts a valid cookie", func() {
ip := net.IPv4(192, 168, 0, 1)
token, err := cookieGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337})
Expect(err).ToNot(HaveOccurred())
cookie, err := cookieGen.DecodeToken(token)
Expect(err).ToNot(HaveOccurred())
Expect(cookie.RemoteAddr).To(Equal("192.168.0.1"))
// the time resolution of the Cookie is just 1 second
// if Cookie generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds
Expect(cookie.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second))
})
It("rejects invalid tokens", func() {
_, err := cookieGen.DecodeToken([]byte("invalid token"))
Expect(err).To(HaveOccurred())
})
It("rejects tokens that cannot be decoded", func() {
token, err := cookieGen.cookieProtector.NewToken([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
_, err = cookieGen.DecodeToken(token)
Expect(err).To(HaveOccurred())
})
It("rejects tokens that can be decoded, but have additional payload", func() {
t, err := asn1.Marshal(token{Data: []byte("foobar")})
Expect(err).ToNot(HaveOccurred())
t = append(t, []byte("rest")...)
enc, err := cookieGen.cookieProtector.NewToken(t)
Expect(err).ToNot(HaveOccurred())
_, err = cookieGen.DecodeToken(enc)
Expect(err).To(MatchError("rest when unpacking token: 4"))
})
// we don't generate tokens that have no data, but we should be able to handle them if we receive one for whatever reason
It("doesn't panic if a tokens has no data", func() {
t, err := asn1.Marshal(token{Data: []byte("")})
Expect(err).ToNot(HaveOccurred())
enc, err := cookieGen.cookieProtector.NewToken(t)
Expect(err).ToNot(HaveOccurred())
_, err = cookieGen.DecodeToken(enc)
Expect(err).ToNot(HaveOccurred())
})
It("works with an IPv6 addresses ", func() {
addresses := []string{
"2001:db8::68",
"2001:0000:4136:e378:8000:63bf:3fff:fdd2",
"2001::1",
"ff01:0:0:0:0:0:0:2",
}
for _, addr := range addresses {
ip := net.ParseIP(addr)
Expect(ip).ToNot(BeNil())
raddr := &net.UDPAddr{IP: ip, Port: 1337}
token, err := cookieGen.NewToken(raddr)
Expect(err).ToNot(HaveOccurred())
cookie, err := cookieGen.DecodeToken(token)
Expect(err).ToNot(HaveOccurred())
Expect(cookie.RemoteAddr).To(Equal(ip.String()))
// the time resolution of the Cookie is just 1 second
// if Cookie generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds
Expect(cookie.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second))
}
})
It("uses the string representation an address that is not a UDP address", func() {
raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
token, err := cookieGen.NewToken(raddr)
Expect(err).ToNot(HaveOccurred())
cookie, err := cookieGen.DecodeToken(token)
Expect(err).ToNot(HaveOccurred())
Expect(cookie.RemoteAddr).To(Equal("192.168.13.37:1337"))
// the time resolution of the Cookie is just 1 second
// if Cookie generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds
Expect(cookie.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second))
})
})

View File

@ -1,39 +0,0 @@
package handshake
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Cookie Protector", func() {
var cp cookieProtector
BeforeEach(func() {
var err error
cp, err = newCookieProtector()
Expect(err).ToNot(HaveOccurred())
})
It("encodes and decodes tokens", func() {
token, err := cp.NewToken([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(token).ToNot(ContainSubstring("foobar"))
decoded, err := cp.DecodeToken(token)
Expect(err).ToNot(HaveOccurred())
Expect(decoded).To(Equal([]byte("foobar")))
})
It("fails deconding invalid tokens", func() {
token, err := cp.NewToken([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
token = token[1:] // remove the first byte
_, err = cp.DecodeToken(token)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("message authentication failed"))
})
It("errors when decoding too short tokens", func() {
_, err := cp.DecodeToken([]byte("foobar"))
Expect(err).To(MatchError("Token too short: 6"))
})
})

File diff suppressed because it is too large Load Diff

View File

@ -1,731 +0,0 @@
package handshake
import (
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockKEX struct {
ephermal bool
sharedKeyError error
}
func (m *mockKEX) PublicKey() []byte {
if m.ephermal {
return []byte("ephermal pub")
}
return []byte("initial public")
}
func (m *mockKEX) CalculateSharedKey(otherPublic []byte) ([]byte, error) {
if m.sharedKeyError != nil {
return nil, m.sharedKeyError
}
if m.ephermal {
return []byte("shared ephermal"), nil
}
return []byte("shared key"), nil
}
type mockSigner struct {
gotCHLO bool
}
func (s *mockSigner) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) {
if len(chlo) > 0 {
s.gotCHLO = true
}
return []byte("proof"), nil
}
func (*mockSigner) GetCertsCompressed(sni string, common, cached []byte) ([]byte, error) {
return []byte("certcompressed"), nil
}
func (*mockSigner) GetLeafCert(sni string) ([]byte, error) {
return []byte("certuncompressed"), nil
}
func mockQuicCryptoKeyDerivation(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) {
return mockcrypto.NewMockAEAD(mockCtrl), nil
}
type mockStream struct {
unblockRead chan struct{}
dataToRead bytes.Buffer
dataWritten bytes.Buffer
}
var _ io.ReadWriter = &mockStream{}
var errMockStreamClosing = errors.New("mock stream closing")
func newMockStream() *mockStream {
return &mockStream{unblockRead: make(chan struct{})}
}
// call Close to make Read return
func (s *mockStream) Read(p []byte) (int, error) {
n, _ := s.dataToRead.Read(p)
if n == 0 { // block if there's no data
<-s.unblockRead
return 0, errMockStreamClosing
}
return n, nil // never return an EOF
}
func (s *mockStream) Write(p []byte) (int, error) {
return s.dataWritten.Write(p)
}
func (s *mockStream) close() {
close(s.unblockRead)
}
type mockCookieProtector struct {
decodeErr error
}
var _ cookieProtector = &mockCookieProtector{}
func (mockCookieProtector) NewToken(sourceAddr []byte) ([]byte, error) {
return append([]byte("token "), sourceAddr...), nil
}
func (s mockCookieProtector) DecodeToken(data []byte) ([]byte, error) {
if s.decodeErr != nil {
return nil, s.decodeErr
}
if len(data) < 6 {
return nil, errors.New("token too short")
}
return data[6:], nil
}
var _ = Describe("Server Crypto Setup", func() {
var (
kex *mockKEX
signer *mockSigner
scfg *ServerConfig
cs *cryptoSetupServer
stream *mockStream
paramsChan chan TransportParameters
handshakeEvent chan struct{}
nonce32 []byte
versionTag []byte
validSTK []byte
aead []byte
kexs []byte
version protocol.VersionNumber
supportedVersions []protocol.VersionNumber
sourceAddrValid bool
)
const (
expectedInitialNonceLen = 32
expectedFSNonceLen = 64
)
BeforeEach(func() {
var err error
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
paramsChan = make(chan TransportParameters, 1)
handshakeEvent = make(chan struct{}, 2)
stream = newMockStream()
kex = &mockKEX{}
signer = &mockSigner{}
scfg, err = NewServerConfig(kex, signer)
nonce32 = make([]byte, 32)
aead = []byte("AESG")
kexs = []byte("C255")
copy(nonce32[4:12], scfg.obit) // set the OBIT value at the right position
versionTag = make([]byte, 4)
binary.BigEndian.PutUint32(versionTag, uint32(protocol.VersionWhatever))
Expect(err).NotTo(HaveOccurred())
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
supportedVersions = []protocol.VersionNumber{version, 98, 99}
csInt, err := NewCryptoSetup(
stream,
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
remoteAddr,
version,
make([]byte, 32), // div nonce
scfg,
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
supportedVersions,
nil,
paramsChan,
handshakeEvent,
utils.DefaultLogger,
)
Expect(err).NotTo(HaveOccurred())
cs = csInt.(*cryptoSetupServer)
cs.scfg.cookieGenerator.cookieProtector = &mockCookieProtector{}
validSTK, err = cs.scfg.cookieGenerator.NewToken(remoteAddr)
Expect(err).NotTo(HaveOccurred())
sourceAddrValid = true
cs.acceptSTKCallback = func(_ net.Addr, _ *Cookie) bool { return sourceAddrValid }
cs.keyDerivation = mockQuicCryptoKeyDerivation
cs.keyExchange = func() (crypto.KeyExchange, error) { return &mockKEX{ephermal: true}, nil }
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
cs.cryptoStream = stream
})
Context("when responding to client messages", func() {
var cert []byte
var xlct []byte
var fullCHLO map[Tag][]byte
BeforeEach(func() {
xlct = make([]byte, 8)
var err error
cert, err = cs.scfg.certChain.GetLeafCert("")
Expect(err).ToNot(HaveOccurred())
binary.LittleEndian.PutUint64(xlct, crypto.HashCert(cert))
fullCHLO = map[Tag][]byte{
TagSCID: scfg.ID,
TagSNI: []byte("quic.clemente.io"),
TagNONC: nonce32,
TagSTK: validSTK,
TagXLCT: xlct,
TagAEAD: aead,
TagKEXS: kexs,
TagPUBS: bytes.Repeat([]byte{'e'}, 31),
TagVER: versionTag,
}
})
It("doesn't support Chrome's no STOP_WAITING experiment", func() {
HandshakeMessage{
Tag: TagCHLO,
Data: map[Tag][]byte{
TagNSTP: []byte("foobar"),
},
}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(ErrNSTPExperiment))
})
It("reads the transport parameters sent by the client", func() {
sourceAddrValid = true
fullCHLO[TagICSL] = []byte{0x37, 0x13, 0, 0}
_, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), fullCHLO)
Expect(err).ToNot(HaveOccurred())
var params TransportParameters
Expect(paramsChan).To(Receive(&params))
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
})
It("generates REJ messages", func() {
sourceAddrValid = false
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), nil)
Expect(err).ToNot(HaveOccurred())
Expect(response).To(HavePrefix("REJ"))
Expect(response).To(ContainSubstring("initial public"))
Expect(response).ToNot(ContainSubstring("certcompressed"))
Expect(response).ToNot(ContainSubstring("proof"))
Expect(signer.gotCHLO).To(BeFalse())
})
It("REJ messages don't include cert or proof without STK", func() {
sourceAddrValid = false
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), nil)
Expect(err).ToNot(HaveOccurred())
Expect(response).To(HavePrefix("REJ"))
Expect(response).ToNot(ContainSubstring("certcompressed"))
Expect(response).ToNot(ContainSubstring("proof"))
Expect(signer.gotCHLO).To(BeFalse())
})
It("REJ messages include cert and proof with valid STK", func() {
sourceAddrValid = true
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), map[Tag][]byte{
TagSTK: validSTK,
TagSNI: []byte("foo"),
})
Expect(err).ToNot(HaveOccurred())
Expect(response).To(HavePrefix("REJ"))
Expect(response).To(ContainSubstring("certcompressed"))
Expect(response).To(ContainSubstring("proof"))
Expect(signer.gotCHLO).To(BeTrue())
})
It("generates SHLO messages", func() {
var checkedSecure, checkedForwardSecure bool
cs.keyDerivation = func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) {
if forwardSecure {
Expect(nonces).To(HaveLen(expectedFSNonceLen))
checkedForwardSecure = true
Expect(sharedSecret).To(Equal([]byte("shared ephermal")))
} else {
Expect(nonces).To(HaveLen(expectedInitialNonceLen))
Expect(sharedSecret).To(Equal([]byte("shared key")))
checkedSecure = true
}
return mockcrypto.NewMockAEAD(mockCtrl), nil
}
response, err := cs.handleCHLO("", []byte("chlo-data"), map[Tag][]byte{
TagPUBS: []byte("pubs-c"),
TagNONC: nonce32,
TagAEAD: aead,
TagKEXS: kexs,
})
Expect(err).ToNot(HaveOccurred())
Expect(response).To(HavePrefix("SHLO"))
message, err := ParseHandshakeMessage(bytes.NewReader(response))
Expect(err).ToNot(HaveOccurred())
Expect(message.Data).To(HaveKeyWithValue(TagPUBS, []byte("ephermal pub")))
Expect(message.Data).To(HaveKey(TagSNO))
Expect(message.Data).To(HaveKey(TagVER))
Expect(message.Data[TagVER]).To(HaveLen(4 * len(supportedVersions)))
for _, v := range supportedVersions {
b := &bytes.Buffer{}
utils.BigEndian.WriteUint32(b, uint32(v))
Expect(message.Data[TagVER]).To(ContainSubstring(b.String()))
}
Expect(checkedSecure).To(BeTrue())
Expect(checkedForwardSecure).To(BeTrue())
})
It("handles long handshake", func() {
HandshakeMessage{
Tag: TagCHLO,
Data: map[Tag][]byte{
TagSNI: []byte("quic.clemente.io"),
TagSTK: validSTK,
TagPAD: bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize),
TagVER: versionTag,
},
}.Write(&stream.dataToRead)
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ"))
Expect(handshakeEvent).To(Receive()) // for the switch to secure
Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO"))
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
Expect(handshakeEvent).ToNot(BeClosed())
})
It("rejects client nonces that have the wrong length", func() {
fullCHLO[TagNONC] = []byte("too short client nonce")
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")))
})
It("rejects client nonces that have the wrong OBIT value", func() {
fullCHLO[TagNONC] = make([]byte, 32) // the OBIT value is nonce[4:12] and here just initialized to 0
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching")))
})
It("errors if it can't calculate a shared key", func() {
testErr := errors.New("test error")
kex.sharedKeyError = testErr
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(testErr))
})
It("handles 0-RTT handshake", func() {
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO"))
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
Expect(handshakeEvent).To(Receive()) // for the switch to secure
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
Expect(handshakeEvent).ToNot(BeClosed())
})
It("recognizes inchoate CHLOs missing SCID", func() {
delete(fullCHLO, TagSCID)
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})
It("recognizes inchoate CHLOs missing PUBS", func() {
delete(fullCHLO, TagPUBS)
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})
It("recognizes inchoate CHLOs with missing XLCT", func() {
delete(fullCHLO, TagXLCT)
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})
It("recognizes inchoate CHLOs with wrong length XLCT", func() {
fullCHLO[TagXLCT] = bytes.Repeat([]byte{'f'}, 7) // should be 8 bytes
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})
It("recognizes inchoate CHLOs with wrong XLCT", func() {
fullCHLO[TagXLCT] = bytes.Repeat([]byte{'f'}, 8)
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})
It("recognizes inchoate CHLOs with an invalid STK", func() {
testErr := errors.New("STK invalid")
cs.scfg.cookieGenerator.cookieProtector.(*mockCookieProtector).decodeErr = testErr
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})
It("recognizes proper CHLOs", func() {
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeFalse())
})
It("rejects CHLOs without the version tag", func() {
HandshakeMessage{
Tag: TagCHLO,
Data: map[Tag][]byte{
TagSCID: scfg.ID,
TagSNI: []byte("quic.clemente.io"),
},
}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag")))
})
It("rejects CHLOs with a version tag that has the wrong length", func() {
fullCHLO[TagVER] = []byte{0x13, 0x37} // should be 4 bytes
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag")))
})
It("detects version downgrade attacks", func() {
highestSupportedVersion := supportedVersions[len(supportedVersions)-1]
lowestSupportedVersion := supportedVersions[0]
Expect(highestSupportedVersion).ToNot(Equal(lowestSupportedVersion))
cs.version = highestSupportedVersion
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(lowestSupportedVersion))
fullCHLO[TagVER] = b
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")))
})
It("accepts a non-matching version tag in the CHLO, if it is an unsupported version", func() {
supportedVersion := protocol.SupportedVersions[0]
unsupportedVersion := supportedVersion + 1000
Expect(protocol.IsSupportedVersion(supportedVersions, unsupportedVersion)).To(BeFalse())
cs.version = supportedVersion
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(unsupportedVersion))
fullCHLO[TagVER] = b
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
})
It("errors if the AEAD tag is missing", func() {
delete(fullCHLO, TagAEAD)
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
})
It("errors if the AEAD tag has the wrong value", func() {
fullCHLO[TagAEAD] = []byte("wrong")
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
})
It("errors if the KEXS tag is missing", func() {
delete(fullCHLO, TagKEXS)
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
})
It("errors if the KEXS tag has the wrong value", func() {
fullCHLO[TagKEXS] = []byte("wrong")
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
})
})
It("errors without SNI", func() {
HandshakeMessage{
Tag: TagCHLO,
Data: map[Tag][]byte{
TagSTK: validSTK,
},
}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required"))
})
It("errors with empty SNI", func() {
HandshakeMessage{
Tag: TagCHLO,
Data: map[Tag][]byte{
TagSTK: validSTK,
TagSNI: nil,
},
}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required"))
})
It("errors with invalid message", func() {
stream.dataToRead.Write([]byte("invalid message"))
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.HandshakeFailed))
})
It("errors with non-CHLO message", func() {
HandshakeMessage{Tag: TagPAD, Data: nil}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.InvalidCryptoMessageType))
})
Context("escalating crypto", func() {
doCHLO := func() {
_, err := cs.handleCHLO("", []byte("chlo-data"), map[Tag][]byte{
TagPUBS: []byte("pubs-c"),
TagNONC: nonce32,
TagAEAD: aead,
TagKEXS: kexs,
})
Expect(err).ToNot(HaveOccurred())
Expect(handshakeEvent).To(Receive()) // for the switch to secure
close(cs.sentSHLO)
}
Context("null encryption", func() {
It("is used initially", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(10), []byte{}).Return([]byte("foobar signed"))
enc, sealer := cs.GetSealer()
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
d := sealer.Seal(nil, []byte("foobar"), 10, []byte{})
Expect(d).To(Equal([]byte("foobar signed")))
})
It("is used for the crypto stream", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(0), []byte{})
enc, sealer := cs.GetSealerForCryptoStream()
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
sealer.Seal(nil, []byte("foobar"), 0, []byte{})
})
It("is accepted initially", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(5), []byte{}).Return([]byte("decrypted"), nil)
d, enc, err := cs.Open(nil, []byte("unencrypted"), 5, []byte{})
Expect(err).ToNot(HaveOccurred())
Expect(d).To(Equal([]byte("decrypted")))
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
})
It("errors if the has the wrong hash", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("not unencrypted"), protocol.PacketNumber(5), []byte{}).Return(nil, errors.New("authentication failed"))
_, enc, err := cs.Open(nil, []byte("not unencrypted"), 5, []byte{})
Expect(err).To(MatchError("authentication failed"))
Expect(enc).To(Equal(protocol.EncryptionUnspecified))
})
It("is still accepted after CHLO", func() {
doCHLO()
// it tries forward secure and secure decryption first
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{})
Expect(cs.secureAEAD).ToNot(BeNil())
_, enc, err := cs.Open(nil, []byte("unencrypted"), 99, []byte{})
Expect(err).ToNot(HaveOccurred())
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
})
It("is not accepted after receiving secure packet", func() {
doCHLO()
// first receive a secure packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return(nil, errors.New("authentication failed"))
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return([]byte("decrypted"), nil)
d, enc, err := cs.Open(nil, []byte("encrypted"), 98, []byte{})
Expect(enc).To(Equal(protocol.EncryptionSecure))
Expect(err).ToNot(HaveOccurred())
Expect(d).To(Equal([]byte("decrypted")))
// now receive an unencrypted packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
_, enc, err = cs.Open(nil, []byte("unencrypted"), 99, []byte{})
Expect(err).To(MatchError("authentication failed"))
Expect(enc).To(Equal(protocol.EncryptionUnspecified))
})
It("is not used after CHLO", func() {
doCHLO()
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(0), []byte{})
enc, sealer := cs.GetSealer()
Expect(enc).ToNot(Equal(protocol.EncryptionUnencrypted))
sealer.Seal(nil, []byte("foobar"), 0, []byte{})
})
})
Context("initial encryption", func() {
It("is accepted after CHLO", func() {
doCHLO()
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return(nil, errors.New("authentication failed"))
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return([]byte("decrypted"), nil)
d, enc, err := cs.Open(nil, []byte("encrypted"), 98, []byte{})
Expect(enc).To(Equal(protocol.EncryptionSecure))
Expect(err).ToNot(HaveOccurred())
Expect(d).To(Equal([]byte("decrypted")))
})
It("is not accepted after receiving forward secure packet", func() {
doCHLO()
// receive a forward secure packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(11), []byte{})
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 11, []byte{})
Expect(err).ToNot(HaveOccurred())
// receive a secure packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(12), []byte{}).Return(nil, errors.New("authentication failed"))
_, enc, err := cs.Open(nil, []byte("encrypted"), 12, []byte{})
Expect(err).To(MatchError("authentication failed"))
Expect(enc).To(Equal(protocol.EncryptionUnspecified))
})
It("is used for the crypto stream", func() {
doCHLO()
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(1), []byte{}).Return([]byte("foobar crypto stream"))
enc, sealer := cs.GetSealerForCryptoStream()
Expect(enc).To(Equal(protocol.EncryptionSecure))
d := sealer.Seal(nil, []byte("foobar"), 1, []byte{})
Expect(d).To(Equal([]byte("foobar crypto stream")))
})
})
Context("forward secure encryption", func() {
It("is used after the CHLO", func() {
doCHLO()
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(20), []byte{}).Return([]byte("foobar forward sec"))
enc, sealer := cs.GetSealer()
Expect(enc).To(Equal(protocol.EncryptionForwardSecure))
d := sealer.Seal(nil, []byte("foobar"), 20, []byte{})
Expect(d).To(Equal([]byte("foobar forward sec")))
})
It("regards the handshake as complete once it receives a forward encrypted packet", func() {
doCHLO()
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{})
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{})
Expect(err).ToNot(HaveOccurred())
Expect(handshakeEvent).To(BeClosed())
})
})
Context("reporting the connection state", func() {
It("reports before the handshake completes", func() {
cs.sni = "server name"
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeFalse())
Expect(state.ServerName).To(Equal("server name"))
})
It("reports after the handshake completes", func() {
doCHLO()
// receive a forward secure packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(11), []byte{})
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 11, []byte{})
Expect(err).ToNot(HaveOccurred())
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeTrue())
})
})
Context("forcing encryption levels", func() {
It("forces null encryption", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(11), []byte{}).Return([]byte("foobar unencrypted"))
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted)
Expect(err).ToNot(HaveOccurred())
d := sealer.Seal(nil, []byte("foobar"), 11, []byte{})
Expect(d).To(Equal([]byte("foobar unencrypted")))
})
It("forces initial encryption", func() {
doCHLO()
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(12), []byte{}).Return([]byte("foobar secure"))
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure)
Expect(err).ToNot(HaveOccurred())
d := sealer.Seal(nil, []byte("foobar"), 12, []byte{})
Expect(d).To(Equal([]byte("foobar secure")))
})
It("errors if no AEAD for initial encryption is available", func() {
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure)
Expect(err).To(MatchError("CryptoSetupServer: no secureAEAD"))
Expect(sealer).To(BeNil())
})
It("forces forward-secure encryption", func() {
doCHLO()
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(13), []byte{}).Return([]byte("foobar forward sec"))
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
Expect(err).ToNot(HaveOccurred())
d := sealer.Seal(nil, []byte("foobar"), 13, []byte{})
Expect(d).To(Equal([]byte("foobar forward sec")))
})
It("errors of no AEAD for forward-secure encryption is available", func() {
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
Expect(err).To(MatchError("CryptoSetupServer: no forwardSecureAEAD"))
Expect(seal).To(BeNil())
})
It("errors if no encryption level is specified", func() {
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnspecified)
Expect(err).To(MatchError("CryptoSetupServer: no encryption level specified"))
Expect(seal).To(BeNil())
})
})
})
Context("STK verification and creation", func() {
It("requires STK", func() {
sourceAddrValid = false
done, err := cs.handleMessage(
bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize),
map[Tag][]byte{
TagSNI: []byte("foo"),
TagVER: versionTag,
},
)
Expect(err).ToNot(HaveOccurred())
Expect(done).To(BeFalse())
Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK)))
Expect(cs.sni).To(Equal("foo"))
})
It("works with proper STK", func() {
sourceAddrValid = true
done, err := cs.handleMessage(
bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize),
map[Tag][]byte{
TagSNI: []byte("foo"),
TagVER: versionTag,
},
)
Expect(err).ToNot(HaveOccurred())
Expect(done).To(BeFalse())
})
})
})

View File

@ -1,192 +0,0 @@
package handshake
import (
"bytes"
"errors"
"fmt"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) {
return mockcrypto.NewMockAEAD(mockCtrl), nil
}
var _ = Describe("TLS Crypto Setup", func() {
var (
cs *cryptoSetupTLS
handshakeEvent chan struct{}
)
BeforeEach(func() {
handshakeEvent = make(chan struct{}, 2)
css, err := NewCryptoSetupTLSServer(
newCryptoStreamConn(bytes.NewBuffer([]byte{})),
protocol.ConnectionID{},
&mint.Config{},
handshakeEvent,
protocol.VersionTLS,
)
Expect(err).ToNot(HaveOccurred())
cs = css.(*cryptoSetupTLS)
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
})
It("errors when the handshake fails", func() {
alert := mint.AlertBadRecordMAC
cs.tls = NewMockMintTLS(mockCtrl)
cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(alert)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)))
})
It("derives keys", func() {
cs.tls = NewMockMintTLS(mockCtrl)
cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected})
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
Expect(handshakeEvent).To(Receive())
Expect(handshakeEvent).To(BeClosed())
})
It("handshakes until it is connected", func() {
cs.tls = NewMockMintTLS(mockCtrl)
cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert).Times(10)
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerNegotiated}).Times(9)
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected})
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
Expect(handshakeEvent).To(Receive())
})
Context("reporting the handshake state", func() {
It("reports before the handshake compeletes", func() {
cs.tls = NewMockMintTLS(mockCtrl)
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{})
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeFalse())
Expect(state.PeerCertificates).To(BeNil())
})
It("reports after the handshake completes", func() {
cs.tls = NewMockMintTLS(mockCtrl)
cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected}).Times(2)
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeTrue())
Expect(state.PeerCertificates).To(BeNil())
})
})
Context("escalating crypto", func() {
doHandshake := func() {
cs.tls = NewMockMintTLS(mockCtrl)
cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected})
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
}
Context("null encryption", func() {
It("is used initially", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(5), []byte{}).Return([]byte("foobar signed"))
enc, sealer := cs.GetSealer()
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
d := sealer.Seal(nil, []byte("foobar"), 5, []byte{})
Expect(d).To(Equal([]byte("foobar signed")))
})
It("is used for opening", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("foobar enc"), protocol.PacketNumber(10), []byte{}).Return([]byte("foobar"), nil)
d, err := cs.OpenHandshake(nil, []byte("foobar enc"), 10, []byte{})
Expect(err).ToNot(HaveOccurred())
Expect(d).To(Equal([]byte("foobar")))
})
It("is used for crypto stream", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(20), []byte{}).Return([]byte("foobar signed"))
enc, sealer := cs.GetSealerForCryptoStream()
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
d := sealer.Seal(nil, []byte("foobar"), 20, []byte{})
Expect(d).To(Equal([]byte("foobar signed")))
})
It("errors if the has the wrong hash", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("foobar enc"), protocol.PacketNumber(10), []byte{}).Return(nil, errors.New("authentication failed"))
_, err := cs.OpenHandshake(nil, []byte("foobar enc"), 10, []byte{})
Expect(err).To(MatchError("authentication failed"))
})
})
Context("forward-secure encryption", func() {
It("is used for sealing after the handshake completes", func() {
doHandshake()
cs.aead.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(5), []byte{}).Return([]byte("foobar forward sec"))
enc, sealer := cs.GetSealer()
Expect(enc).To(Equal(protocol.EncryptionForwardSecure))
d := sealer.Seal(nil, []byte("foobar"), 5, []byte{})
Expect(d).To(Equal([]byte("foobar forward sec")))
})
It("is used for opening", func() {
doHandshake()
cs.aead.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(6), []byte{}).Return([]byte("decrypted"), nil)
d, err := cs.Open1RTT(nil, []byte("encrypted"), 6, []byte{})
Expect(err).ToNot(HaveOccurred())
Expect(d).To(Equal([]byte("decrypted")))
})
})
Context("forcing encryption levels", func() {
It("forces null encryption", func() {
doHandshake()
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(5), []byte{}).Return([]byte("foobar signed"))
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted)
Expect(err).ToNot(HaveOccurred())
d := sealer.Seal(nil, []byte("foobar"), 5, []byte{})
Expect(d).To(Equal([]byte("foobar signed")))
})
It("forces forward-secure encryption", func() {
doHandshake()
cs.aead.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(5), []byte{}).Return([]byte("foobar forward sec"))
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
Expect(err).ToNot(HaveOccurred())
d := sealer.Seal(nil, []byte("foobar"), 5, []byte{})
Expect(d).To(Equal([]byte("foobar forward sec")))
})
It("errors if the forward-secure AEAD is not available", func() {
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
Expect(err).To(MatchError("CryptoSetup: no sealer with encryption level forward-secure"))
Expect(sealer).To(BeNil())
})
It("never returns a secure AEAD (they don't exist with TLS)", func() {
doHandshake()
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure)
Expect(err).To(MatchError("CryptoSetup: no sealer with encryption level encrypted (not forward-secure)"))
Expect(sealer).To(BeNil())
})
It("errors if no encryption level is specified", func() {
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnspecified)
Expect(err).To(MatchError("CryptoSetup: no sealer with encryption level unknown"))
Expect(seal).To(BeNil())
})
})
})
})

View File

@ -1,41 +0,0 @@
package handshake
import (
"bytes"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Crypto Stream Conn", func() {
var (
stream *bytes.Buffer
csc *cryptoStreamConn
)
BeforeEach(func() {
stream = &bytes.Buffer{}
csc = newCryptoStreamConn(stream)
})
It("buffers writes", func() {
_, err := csc.Write([]byte("foo"))
Expect(err).ToNot(HaveOccurred())
Expect(stream.Len()).To(BeZero())
_, err = csc.Write([]byte("bar"))
Expect(err).ToNot(HaveOccurred())
Expect(stream.Len()).To(BeZero())
Expect(csc.Flush()).To(Succeed())
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
})
It("reads from the stream", func() {
stream.Write([]byte("foobar"))
b := make([]byte, 6)
n, err := csc.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(b).To(Equal([]byte("foobar")))
})
})

File diff suppressed because one or more lines are too long

View File

@ -1,35 +0,0 @@
package handshake
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Ephermal KEX", func() {
It("has a consistent KEX", func() {
kex1, err := getEphermalKEX()
Expect(err).ToNot(HaveOccurred())
Expect(kex1).ToNot(BeNil())
kex2, err := getEphermalKEX()
Expect(err).ToNot(HaveOccurred())
Expect(kex2).ToNot(BeNil())
Expect(kex1).To(Equal(kex2))
})
It("changes KEX", func() {
kexLifetime = 10 * time.Millisecond
defer func() {
kexLifetime = protocol.EphermalKeyLifetime
}()
kex, err := getEphermalKEX()
Expect(err).ToNot(HaveOccurred())
Expect(kex).ToNot(BeNil())
time.Sleep(kexLifetime)
kex2, err := getEphermalKEX()
Expect(err).ToNot(HaveOccurred())
Expect(kex2).ToNot(Equal(kex))
})
})

View File

@ -1,71 +0,0 @@
package handshake
import (
"bytes"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Handshake Message", func() {
Context("when parsing", func() {
It("parses sample CHLO message", func() {
msg, err := ParseHandshakeMessage(bytes.NewReader(sampleCHLO))
Expect(err).ToNot(HaveOccurred())
Expect(msg.Tag).To(Equal(TagCHLO))
Expect(msg.Data).To(Equal(sampleCHLOMap))
})
It("rejects large numbers of pairs", func() {
r := bytes.NewReader([]byte("CHLO\xff\xff\xff\xff"))
_, err := ParseHandshakeMessage(r)
Expect(err).To(MatchError(qerr.CryptoTooManyEntries))
})
It("rejects too long values", func() {
r := bytes.NewReader([]byte{
'C', 'H', 'L', 'O',
1, 0, 0, 0,
0, 0, 0, 0,
0xff, 0xff, 0xff, 0xff,
})
_, err := ParseHandshakeMessage(r)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoInvalidValueLength, "value too long")))
})
})
Context("when writing", func() {
It("writes sample message", func() {
b := &bytes.Buffer{}
HandshakeMessage{Tag: TagCHLO, Data: sampleCHLOMap}.Write(b)
Expect(b.Bytes()).To(Equal(sampleCHLO))
})
})
Context("string representation", func() {
It("has a string representation", func() {
str := HandshakeMessage{
Tag: TagSHLO,
Data: map[Tag][]byte{
TagAEAD: []byte("foobar"),
TagEXPY: []byte("raboof"),
},
}.String()
Expect(str[:4]).To(Equal("SHLO"))
Expect(str).To(ContainSubstring("AEAD: \"foobar\""))
Expect(str).To(ContainSubstring("EXPY: \"raboof\""))
})
It("lists padding separately", func() {
str := HandshakeMessage{
Tag: TagSHLO,
Data: map[Tag][]byte{
TagPAD: bytes.Repeat([]byte{0}, 1337),
},
}.String()
Expect(str).To(ContainSubstring("PAD"))
Expect(str).To(ContainSubstring("1337 bytes"))
})
})
})

View File

@ -1,24 +0,0 @@
package handshake
import (
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestQuicGo(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Handshake Suite")
}
var mockCtrl *gomock.Controller
var _ = BeforeEach(func() {
mockCtrl = gomock.NewController(GinkgoT())
})
var _ = AfterEach(func() {
mockCtrl.Finish()
})

View File

@ -1,72 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: MintTLS)
// Package handshake is a generated GoMock package.
package handshake
import (
reflect "reflect"
mint "github.com/bifurcation/mint"
gomock "github.com/golang/mock/gomock"
)
// MockMintTLS is a mock of MintTLS interface
type MockMintTLS struct {
ctrl *gomock.Controller
recorder *MockMintTLSMockRecorder
}
// MockMintTLSMockRecorder is the mock recorder for MockMintTLS
type MockMintTLSMockRecorder struct {
mock *MockMintTLS
}
// NewMockMintTLS creates a new mock instance
func NewMockMintTLS(ctrl *gomock.Controller) *MockMintTLS {
mock := &MockMintTLS{ctrl: ctrl}
mock.recorder = &MockMintTLSMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockMintTLS) EXPECT() *MockMintTLSMockRecorder {
return m.recorder
}
// ComputeExporter mocks base method
func (m *MockMintTLS) ComputeExporter(arg0 string, arg1 []byte, arg2 int) ([]byte, error) {
ret := m.ctrl.Call(m, "ComputeExporter", arg0, arg1, arg2)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ComputeExporter indicates an expected call of ComputeExporter
func (mr *MockMintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ComputeExporter", reflect.TypeOf((*MockMintTLS)(nil).ComputeExporter), arg0, arg1, arg2)
}
// ConnectionState mocks base method
func (m *MockMintTLS) ConnectionState() mint.ConnectionState {
ret := m.ctrl.Call(m, "ConnectionState")
ret0, _ := ret[0].(mint.ConnectionState)
return ret0
}
// ConnectionState indicates an expected call of ConnectionState
func (mr *MockMintTLSMockRecorder) ConnectionState() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockMintTLS)(nil).ConnectionState))
}
// Handshake mocks base method
func (m *MockMintTLS) Handshake() mint.Alert {
ret := m.ctrl.Call(m, "Handshake")
ret0, _ := ret[0].(mint.Alert)
return ret0
}
// Handshake indicates an expected call of Handshake
func (mr *MockMintTLSMockRecorder) Handshake() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handshake", reflect.TypeOf((*MockMintTLS)(nil).Handshake))
}

View File

@ -1,266 +0,0 @@
package handshake
import (
"bytes"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// This tagMap can be passed to parseValues and is garantueed to not cause any errors
func getDefaultServerConfigClient() map[Tag][]byte {
return map[Tag][]byte{
TagSCID: bytes.Repeat([]byte{'F'}, 16),
TagKEXS: []byte("C255"),
TagAEAD: []byte("AESG"),
TagPUBS: append([]byte{0x20, 0x00, 0x00}, bytes.Repeat([]byte{0}, 32)...),
TagOBIT: bytes.Repeat([]byte{0}, 8),
TagEXPY: {0x0, 0x6c, 0x57, 0x78, 0, 0, 0, 0}, // 2033-12-24
}
}
var _ = Describe("Server Config", func() {
var tagMap map[Tag][]byte
BeforeEach(func() {
tagMap = getDefaultServerConfigClient()
})
It("returns the parsed server config", func() {
tagMap[TagSCID] = []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
b := &bytes.Buffer{}
HandshakeMessage{Tag: TagSCFG, Data: tagMap}.Write(b)
scfg, err := parseServerConfig(b.Bytes())
Expect(err).ToNot(HaveOccurred())
Expect(scfg.ID).To(Equal(tagMap[TagSCID]))
})
It("saves the raw server config", func() {
b := &bytes.Buffer{}
HandshakeMessage{Tag: TagSCFG, Data: tagMap}.Write(b)
scfg, err := parseServerConfig(b.Bytes())
Expect(err).ToNot(HaveOccurred())
Expect(scfg.raw).To(Equal(b.Bytes()))
})
It("tells if a server config is expired", func() {
scfg := &serverConfigClient{}
scfg.expiry = time.Now().Add(-time.Second)
Expect(scfg.IsExpired()).To(BeTrue())
scfg.expiry = time.Now().Add(time.Second)
Expect(scfg.IsExpired()).To(BeFalse())
})
Context("parsing the server config", func() {
It("rejects a handshake message with the wrong message tag", func() {
var serverConfig bytes.Buffer
HandshakeMessage{Tag: TagCHLO, Data: make(map[Tag][]byte)}.Write(&serverConfig)
_, err := parseServerConfig(serverConfig.Bytes())
Expect(err).To(MatchError(errMessageNotServerConfig))
})
It("errors on invalid handshake messages", func() {
var serverConfig bytes.Buffer
HandshakeMessage{Tag: TagSCFG, Data: make(map[Tag][]byte)}.Write(&serverConfig)
_, err := parseServerConfig(serverConfig.Bytes()[:serverConfig.Len()-2])
Expect(err).To(MatchError("unexpected EOF"))
})
It("passes on errors encountered when reading the TagMap", func() {
var serverConfig bytes.Buffer
HandshakeMessage{Tag: TagSCFG, Data: make(map[Tag][]byte)}.Write(&serverConfig)
_, err := parseServerConfig(serverConfig.Bytes())
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SCID"))
})
It("reads an example Handshake Message", func() {
var serverConfig bytes.Buffer
HandshakeMessage{Tag: TagSCFG, Data: tagMap}.Write(&serverConfig)
scfg, err := parseServerConfig(serverConfig.Bytes())
Expect(err).ToNot(HaveOccurred())
Expect(scfg.ID).To(Equal(tagMap[TagSCID]))
Expect(scfg.obit).To(Equal(tagMap[TagOBIT]))
})
})
Context("Reading values from the TagMap", func() {
var scfg *serverConfigClient
BeforeEach(func() {
scfg = &serverConfigClient{}
})
Context("ServerConfig ID", func() {
It("parses the ServerConfig ID", func() {
id := []byte{0xb2, 0xa4, 0xbb, 0x8f, 0xf6, 0x51, 0x28, 0xfd, 0x4d, 0xf7, 0xb3, 0x9a, 0x91, 0xe7, 0x91, 0xfb}
tagMap[TagSCID] = id
err := scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
Expect(scfg.ID).To(Equal(id))
})
It("errors if the ServerConfig ID is missing", func() {
delete(tagMap, TagSCID)
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SCID"))
})
It("rejects ServerConfig IDs that have the wrong length", func() {
tagMap[TagSCID] = bytes.Repeat([]byte{'F'}, 17) // 1 byte too long
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoInvalidValueLength: SCID"))
})
})
Context("KEXS", func() {
It("rejects KEXS values that have the wrong length", func() {
tagMap[TagKEXS] = bytes.Repeat([]byte{'F'}, 5) // 1 byte too long
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoInvalidValueLength: KEXS"))
})
It("rejects KEXS values other than C255", func() {
tagMap[TagKEXS] = []byte("P256")
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoNoSupport: KEXS: Could not find C255, other key exchanges are not supported"))
})
It("errors if the KEXS is missing", func() {
delete(tagMap, TagKEXS)
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoMessageParameterNotFound: KEXS"))
})
})
Context("AEAD", func() {
It("rejects AEAD values that have the wrong length", func() {
tagMap[TagAEAD] = bytes.Repeat([]byte{'F'}, 5) // 1 byte too long
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoInvalidValueLength: AEAD"))
})
It("rejects AEAD values other than AESG", func() {
tagMap[TagAEAD] = []byte("S20P")
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoNoSupport: AEAD"))
})
It("recognizes AESG in the list of AEADs, at the first position", func() {
tagMap[TagAEAD] = []byte("AESGS20P")
err := scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
})
It("recognizes AESG in the list of AEADs, not at the first position", func() {
tagMap[TagAEAD] = []byte("S20PAESG")
err := scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
})
It("errors if the AEAD is missing", func() {
delete(tagMap, TagAEAD)
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoMessageParameterNotFound: AEAD"))
})
})
Context("PUBS", func() {
It("creates a Curve25519 key exchange", func() {
serverKex, err := crypto.NewCurve25519KEX()
Expect(err).ToNot(HaveOccurred())
tagMap[TagPUBS] = append([]byte{0x20, 0x00, 0x00}, serverKex.PublicKey()...)
err = scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
sharedSecret, err := serverKex.CalculateSharedKey(scfg.kex.PublicKey())
Expect(err).ToNot(HaveOccurred())
Expect(scfg.sharedSecret).To(Equal(sharedSecret))
})
It("rejects PUBS values that have the wrong length", func() {
tagMap[TagPUBS] = bytes.Repeat([]byte{'F'}, 100) // completely wrong length
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoInvalidValueLength: PUBS"))
})
It("rejects PUBS values that have a zero length", func() {
tagMap[TagPUBS] = bytes.Repeat([]byte{0}, 100) // completely wrong length
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoInvalidValueLength: PUBS"))
})
It("ensure that C255 Pubs must not be at the first index", func() {
serverKex, err := crypto.NewCurve25519KEX()
Expect(err).ToNot(HaveOccurred())
tagMap[TagKEXS] = []byte("P256C255") // have another KEXS before C255
// 3 byte len + 1 byte empty + C255
tagMap[TagPUBS] = append([]byte{0x01, 0x00, 0x00, 0x00}, append([]byte{0x20, 0x00, 0x00}, serverKex.PublicKey()...)...)
err = scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
sharedSecret, err := serverKex.CalculateSharedKey(scfg.kex.PublicKey())
Expect(err).ToNot(HaveOccurred())
Expect(scfg.sharedSecret).To(Equal(sharedSecret))
})
It("errors if the PUBS is missing", func() {
delete(tagMap, TagPUBS)
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoMessageParameterNotFound: PUBS"))
})
})
Context("OBIT", func() {
It("parses the OBIT value", func() {
obit := []byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}
tagMap[TagOBIT] = obit
err := scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
Expect(scfg.obit).To(Equal(obit))
})
It("errors if the OBIT is missing", func() {
delete(tagMap, TagOBIT)
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoMessageParameterNotFound: OBIT"))
})
It("rejets OBIT values that have the wrong length", func() {
tagMap[TagOBIT] = bytes.Repeat([]byte{'F'}, 7) // 1 byte too short
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoInvalidValueLength: OBIT"))
})
})
Context("EXPY", func() {
It("parses the expiry date", func() {
tagMap[TagEXPY] = []byte{0xdc, 0x89, 0x0e, 0x59, 0, 0, 0, 0} // UNIX Timestamp 0x590e89dc = 1494125020
err := scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
year, month, day := scfg.expiry.UTC().Date()
Expect(year).To(Equal(2017))
Expect(month).To(Equal(time.Month(5)))
Expect(day).To(Equal(7))
})
It("errors if the EXPY is missing", func() {
delete(tagMap, TagEXPY)
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoMessageParameterNotFound: EXPY"))
})
It("rejects EXPY values that have the wrong length", func() {
tagMap[TagEXPY] = bytes.Repeat([]byte{'F'}, 9) // 1 byte too long
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoInvalidValueLength: EXPY"))
})
It("deals with absurdly large timestamps", func() {
tagMap[TagEXPY] = bytes.Repeat([]byte{0xff}, 8) // this would overflow the int64
err := scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
Expect(scfg.expiry.After(time.Now())).To(BeTrue())
})
})
})
})

View File

@ -1,45 +0,0 @@
package handshake
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/crypto"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("ServerConfig", func() {
var (
kex crypto.KeyExchange
)
BeforeEach(func() {
var err error
kex, err = crypto.NewCurve25519KEX()
Expect(err).NotTo(HaveOccurred())
})
It("generates a random ID and OBIT", func() {
scfg1, err := NewServerConfig(kex, nil)
Expect(err).ToNot(HaveOccurred())
scfg2, err := NewServerConfig(kex, nil)
Expect(err).ToNot(HaveOccurred())
Expect(scfg1.ID).ToNot(Equal(scfg2.ID))
Expect(scfg1.obit).ToNot(Equal(scfg2.obit))
Expect(scfg1.cookieGenerator).ToNot(Equal(scfg2.cookieGenerator))
})
It("gets the proper binary representation", func() {
scfg, err := NewServerConfig(kex, nil)
Expect(err).NotTo(HaveOccurred())
expected := bytes.NewBuffer([]byte{0x53, 0x43, 0x46, 0x47, 0x6, 0x0, 0x0, 0x0, 0x41, 0x45, 0x41, 0x44, 0x4, 0x0, 0x0, 0x0, 0x53, 0x43, 0x49, 0x44, 0x14, 0x0, 0x0, 0x0, 0x50, 0x55, 0x42, 0x53, 0x37, 0x0, 0x0, 0x0, 0x4b, 0x45, 0x58, 0x53, 0x3b, 0x0, 0x0, 0x0, 0x4f, 0x42, 0x49, 0x54, 0x43, 0x0, 0x0, 0x0, 0x45, 0x58, 0x50, 0x59, 0x4b, 0x0, 0x0, 0x0, 0x41, 0x45, 0x53, 0x47})
expected.Write(scfg.ID)
expected.Write([]byte{0x20, 0x0, 0x0})
expected.Write(kex.PublicKey())
expected.Write([]byte{0x43, 0x32, 0x35, 0x35})
expected.Write(scfg.obit)
expected.Write([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
Expect(scfg.Get()).To(Equal(expected.Bytes()))
})
})

View File

@ -1,233 +0,0 @@
package handshake
import (
"bytes"
"fmt"
"time"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("TLS Extension Handler, for the client", func() {
var (
handler *extensionHandlerClient
el mint.ExtensionList
)
BeforeEach(func() {
handler = NewExtensionHandlerClient(&TransportParameters{}, protocol.VersionWhatever, nil, protocol.VersionWhatever, utils.DefaultLogger).(*extensionHandlerClient)
el = make(mint.ExtensionList, 0)
})
Context("sending", func() {
It("only adds TransportParameters for the ClientHello", func() {
// test 2 other handshake types
err := handler.Send(mint.HandshakeTypeCertificateRequest, &el)
Expect(err).ToNot(HaveOccurred())
Expect(el).To(BeEmpty())
err = handler.Send(mint.HandshakeTypeEndOfEarlyData, &el)
Expect(err).ToNot(HaveOccurred())
Expect(el).To(BeEmpty())
})
It("adds TransportParameters to the ClientHello", func() {
handler.initialVersion = 13
err := handler.Send(mint.HandshakeTypeClientHello, &el)
Expect(err).ToNot(HaveOccurred())
Expect(el).To(HaveLen(1))
ext := &tlsExtensionBody{}
found, err := el.Find(ext)
Expect(err).ToNot(HaveOccurred())
Expect(found).To(BeTrue())
chtp := &clientHelloTransportParameters{}
err = chtp.Unmarshal(ext.data)
Expect(err).ToNot(HaveOccurred())
Expect(chtp.InitialVersion).To(BeEquivalentTo(13))
})
})
Context("receiving", func() {
var fakeBody *tlsExtensionBody
var parameters TransportParameters
addEncryptedExtensionsWithParameters := func(params TransportParameters) {
body := (&encryptedExtensionsTransportParameters{
Parameters: params,
SupportedVersions: []protocol.VersionNumber{handler.version},
}).Marshal()
Expect(el.Add(&tlsExtensionBody{data: body})).To(Succeed())
}
BeforeEach(func() {
fakeBody = &tlsExtensionBody{data: []byte("foobar foobar")}
parameters = TransportParameters{
IdleTimeout: 0x1337 * time.Second,
StatelessResetToken: bytes.Repeat([]byte{0}, 16),
}
})
It("blocks until the transport parameters are read", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
addEncryptedExtensionsWithParameters(parameters)
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
Expect(handler.GetPeerParams()).To(Receive())
Eventually(done).Should(BeClosed())
})
It("accepts the TransportParameters on the EncryptedExtensions message", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
addEncryptedExtensionsWithParameters(parameters)
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
var params TransportParameters
Eventually(handler.GetPeerParams()).Should(Receive(&params))
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
Eventually(done).Should(BeClosed())
})
It("errors if the EncryptedExtensions message doesn't contain TransportParameters", func() {
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
Expect(err).To(MatchError("EncryptedExtensions message didn't contain a QUIC extension"))
})
It("rejects the TransportParameters on a wrong handshake types", func() {
err := el.Add(fakeBody)
Expect(err).ToNot(HaveOccurred())
err = handler.Receive(mint.HandshakeTypeCertificate, &el)
Expect(err).To(MatchError(fmt.Sprintf("Unexpected QUIC extension in handshake message %d", mint.HandshakeTypeCertificate)))
})
It("ignores messages without TransportParameters, if they are not required", func() {
err := handler.Receive(mint.HandshakeTypeCertificate, &el)
Expect(err).ToNot(HaveOccurred())
})
It("errors when it can't parse the TransportParameters", func() {
err := el.Add(fakeBody)
Expect(err).ToNot(HaveOccurred())
err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
Expect(err).To(HaveOccurred()) // this will be some kind of decoding error
})
It("rejects TransportParameters if they don't contain the stateless reset token", func() {
parameters.StatelessResetToken = nil
addEncryptedExtensionsWithParameters(parameters)
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
Expect(err).To(MatchError("server didn't sent stateless_reset_token"))
})
Context("Version Negotiation", func() {
It("accepts a valid version negotiation", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
Eventually(handler.GetPeerParams()).Should(Receive())
close(done)
}()
handler.initialVersion = 13
handler.version = 37
handler.supportedVersions = []protocol.VersionNumber{13, 37, 42}
body := (&encryptedExtensionsTransportParameters{
Parameters: parameters,
NegotiatedVersion: 37,
SupportedVersions: []protocol.VersionNumber{36, 37, 38},
}).Marshal()
err := el.Add(&tlsExtensionBody{data: body})
Expect(err).ToNot(HaveOccurred())
err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
})
It("errors if the current version doesn't match negotiated_version", func() {
handler.initialVersion = 13
handler.version = 37
handler.supportedVersions = []protocol.VersionNumber{13, 37, 42}
body := (&encryptedExtensionsTransportParameters{
Parameters: parameters,
NegotiatedVersion: 38,
SupportedVersions: []protocol.VersionNumber{36, 37, 38},
}).Marshal()
err := el.Add(&tlsExtensionBody{data: body})
Expect(err).ToNot(HaveOccurred())
err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
Expect(err).To(MatchError("VersionNegotiationMismatch: current version doesn't match negotiated_version"))
})
It("errors if the current version is not contained in the server's supported versions", func() {
handler.version = 42
body := (&encryptedExtensionsTransportParameters{
NegotiatedVersion: 42,
SupportedVersions: []protocol.VersionNumber{43, 44},
}).Marshal()
err := el.Add(&tlsExtensionBody{data: body})
Expect(err).ToNot(HaveOccurred())
err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
Expect(err).To(MatchError("VersionNegotiationMismatch: current version not included in the supported versions"))
})
It("errors if version negotiation was performed, but would have picked a different version based on the supported version list", func() {
handler.version = 42
handler.initialVersion = 41
handler.supportedVersions = []protocol.VersionNumber{43, 42, 41}
serverSupportedVersions := []protocol.VersionNumber{42, 43}
// check that version negotiation would have led us to pick version 43
ver, ok := protocol.ChooseSupportedVersion(handler.supportedVersions, serverSupportedVersions)
Expect(ok).To(BeTrue())
Expect(ver).To(Equal(protocol.VersionNumber(43)))
body := (&encryptedExtensionsTransportParameters{
NegotiatedVersion: 42,
SupportedVersions: serverSupportedVersions,
}).Marshal()
err := el.Add(&tlsExtensionBody{data: body})
Expect(err).ToNot(HaveOccurred())
err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
Expect(err).To(MatchError("VersionNegotiationMismatch: would have picked a different version"))
})
It("doesn't error if it would have picked a different version based on the supported version list, if no version negotiation was performed", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
Eventually(handler.GetPeerParams()).Should(Receive())
close(done)
}()
handler.version = 42
handler.initialVersion = 42 // version == initialVersion means no version negotiation was performed
handler.supportedVersions = []protocol.VersionNumber{43, 42, 41}
serverSupportedVersions := []protocol.VersionNumber{42, 43}
// check that version negotiation would have led us to pick version 43
ver, ok := protocol.ChooseSupportedVersion(handler.supportedVersions, serverSupportedVersions)
Expect(ok).To(BeTrue())
Expect(ver).To(Equal(protocol.VersionNumber(43)))
body := (&encryptedExtensionsTransportParameters{
Parameters: parameters,
NegotiatedVersion: 42,
SupportedVersions: serverSupportedVersions,
}).Marshal()
err := el.Add(&tlsExtensionBody{data: body})
Expect(err).ToNot(HaveOccurred())
err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
})
})
})
})

View File

@ -1,155 +0,0 @@
package handshake
import (
"bytes"
"fmt"
"time"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("TLS Extension Handler, for the server", func() {
var (
handler *extensionHandlerServer
el mint.ExtensionList
)
BeforeEach(func() {
handler = NewExtensionHandlerServer(&TransportParameters{}, nil, protocol.VersionWhatever, utils.DefaultLogger).(*extensionHandlerServer)
el = make(mint.ExtensionList, 0)
})
Context("sending", func() {
It("only adds TransportParameters for the ClientHello", func() {
// test 2 other handshake types
err := handler.Send(mint.HandshakeTypeCertificateRequest, &el)
Expect(err).ToNot(HaveOccurred())
Expect(el).To(BeEmpty())
err = handler.Send(mint.HandshakeTypeEndOfEarlyData, &el)
Expect(err).ToNot(HaveOccurred())
Expect(el).To(BeEmpty())
})
It("adds TransportParameters to the EncryptedExtensions message", func() {
handler.version = 666
versions := []protocol.VersionNumber{13, 37, 42}
handler.supportedVersions = versions
err := handler.Send(mint.HandshakeTypeEncryptedExtensions, &el)
Expect(err).ToNot(HaveOccurred())
Expect(el).To(HaveLen(1))
ext := &tlsExtensionBody{}
found, err := el.Find(ext)
Expect(err).ToNot(HaveOccurred())
Expect(found).To(BeTrue())
eetp := &encryptedExtensionsTransportParameters{}
err = eetp.Unmarshal(ext.data)
Expect(err).ToNot(HaveOccurred())
Expect(eetp.NegotiatedVersion).To(BeEquivalentTo(666))
// the SupportedVersions will contain one reserved version number
Expect(eetp.SupportedVersions).To(HaveLen(len(versions) + 1))
for _, version := range versions {
Expect(eetp.SupportedVersions).To(ContainElement(version))
}
})
})
Context("receiving", func() {
var (
fakeBody *tlsExtensionBody
parameters TransportParameters
)
addClientHelloWithParameters := func(params TransportParameters) {
body := (&clientHelloTransportParameters{Parameters: params}).Marshal()
Expect(el.Add(&tlsExtensionBody{data: body})).To(Succeed())
}
BeforeEach(func() {
fakeBody = &tlsExtensionBody{data: []byte("foobar foobar")}
parameters = TransportParameters{IdleTimeout: 0x1337 * time.Second}
})
It("accepts the TransportParameters on the EncryptedExtensions message", func() {
addClientHelloWithParameters(parameters)
err := handler.Receive(mint.HandshakeTypeClientHello, &el)
Expect(err).ToNot(HaveOccurred())
var params TransportParameters
Expect(handler.GetPeerParams()).To(Receive(&params))
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
})
It("errors if the ClientHello doesn't contain TransportParameters", func() {
err := handler.Receive(mint.HandshakeTypeClientHello, &el)
Expect(err).To(MatchError("ClientHello didn't contain a QUIC extension"))
})
It("ignores messages without TransportParameters, if they are not required", func() {
err := handler.Receive(mint.HandshakeTypeCertificate, &el)
Expect(err).ToNot(HaveOccurred())
})
It("errors if it can't unmarshal the TransportParameters", func() {
err := el.Add(fakeBody)
Expect(err).ToNot(HaveOccurred())
err = handler.Receive(mint.HandshakeTypeClientHello, &el)
Expect(err).To(HaveOccurred()) // this will be some kind of decoding error
})
It("rejects messages other than the ClientHello that contain TransportParameters", func() {
addClientHelloWithParameters(parameters)
err := handler.Receive(mint.HandshakeTypeCertificateRequest, &el)
Expect(err).To(MatchError(fmt.Sprintf("Unexpected QUIC extension in handshake message %d", mint.HandshakeTypeCertificateRequest)))
})
It("rejects messages that contain a stateless reset token", func() {
parameters.StatelessResetToken = bytes.Repeat([]byte{0}, 16)
addClientHelloWithParameters(parameters)
err := handler.Receive(mint.HandshakeTypeClientHello, &el)
Expect(err).To(MatchError("client sent a stateless reset token"))
})
Context("Version Negotiation", func() {
It("accepts a ClientHello, when no version negotiation was performed", func() {
handler.version = 42
body := (&clientHelloTransportParameters{
InitialVersion: 42,
Parameters: parameters,
}).Marshal()
err := el.Add(&tlsExtensionBody{data: body})
Expect(err).ToNot(HaveOccurred())
err = handler.Receive(mint.HandshakeTypeClientHello, &el)
Expect(err).ToNot(HaveOccurred())
})
It("accepts a valid version negotiation", func() {
handler.version = 42
handler.supportedVersions = []protocol.VersionNumber{13, 37, 42}
body := (&clientHelloTransportParameters{
InitialVersion: 22, // this must be an unsupported version
Parameters: parameters,
}).Marshal()
err := el.Add(&tlsExtensionBody{data: body})
Expect(err).ToNot(HaveOccurred())
err = handler.Receive(mint.HandshakeTypeClientHello, &el)
Expect(err).ToNot(HaveOccurred())
})
It("erros when a version negotiation was performed, although we already support the initial version", func() {
handler.supportedVersions = []protocol.VersionNumber{11, 12, 13}
handler.version = 13
body := (&clientHelloTransportParameters{
InitialVersion: 11, // this is an supported version
}).Marshal()
err := el.Add(&tlsExtensionBody{data: body})
Expect(err).ToNot(HaveOccurred())
err = handler.Receive(mint.HandshakeTypeClientHello, &el)
Expect(err).To(MatchError("VersionNegotiationMismatch: Client should have used the initial version"))
})
})
})
})

View File

@ -1,95 +0,0 @@
package handshake
import (
"math/rand"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("TLS extension body", func() {
Context("Client Hello Transport Parameters", func() {
It("marshals and unmarshals", func() {
chtp := &clientHelloTransportParameters{
InitialVersion: 0x123456,
Parameters: TransportParameters{
StreamFlowControlWindow: 0x42,
IdleTimeout: 0x1337 * time.Second,
},
}
chtp2 := &clientHelloTransportParameters{}
Expect(chtp2.Unmarshal(chtp.Marshal())).To(Succeed())
Expect(chtp2.InitialVersion).To(Equal(chtp.InitialVersion))
Expect(chtp2.Parameters.StreamFlowControlWindow).To(Equal(chtp.Parameters.StreamFlowControlWindow))
Expect(chtp2.Parameters.IdleTimeout).To(Equal(chtp.Parameters.IdleTimeout))
})
It("fuzzes", func() {
rand := rand.New(rand.NewSource(GinkgoRandomSeed()))
b := make([]byte, 100)
for i := 0; i < 1000; i++ {
rand.Read(b)
chtp := &clientHelloTransportParameters{}
chtp.Unmarshal(b[:int(rand.Int31n(100))])
}
})
})
Context("Encrypted Extensions Transport Parameters", func() {
It("marshals and unmarshals", func() {
eetp := &encryptedExtensionsTransportParameters{
NegotiatedVersion: 0x123456,
SupportedVersions: []protocol.VersionNumber{0x42, 0x4242},
Parameters: TransportParameters{
StreamFlowControlWindow: 0x42,
IdleTimeout: 0x1337 * time.Second,
},
}
eetp2 := &encryptedExtensionsTransportParameters{}
Expect(eetp2.Unmarshal(eetp.Marshal())).To(Succeed())
Expect(eetp2.NegotiatedVersion).To(Equal(eetp.NegotiatedVersion))
Expect(eetp2.SupportedVersions).To(Equal(eetp.SupportedVersions))
Expect(eetp2.Parameters.StreamFlowControlWindow).To(Equal(eetp.Parameters.StreamFlowControlWindow))
Expect(eetp2.Parameters.IdleTimeout).To(Equal(eetp.Parameters.IdleTimeout))
})
It("fuzzes", func() {
rand := rand.New(rand.NewSource(GinkgoRandomSeed()))
b := make([]byte, 100)
for i := 0; i < 1000; i++ {
rand.Read(b)
chtp := &encryptedExtensionsTransportParameters{}
chtp.Unmarshal(b[:int(rand.Int31n(100))])
}
})
})
Context("TLS Extension Body", func() {
var extBody *tlsExtensionBody
BeforeEach(func() {
extBody = &tlsExtensionBody{}
})
It("has the right TLS extension type", func() {
Expect(extBody.Type()).To(BeEquivalentTo(quicTLSExtensionType))
})
It("saves the body when unmarshalling", func() {
n, err := extBody.Unmarshal([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(extBody.data).To(Equal([]byte("foobar")))
})
It("returns the body when marshalling", func() {
extBody.data = []byte("foo")
data, err := extBody.Marshal()
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("foo")))
})
})
})

View File

@ -1,265 +0,0 @@
package handshake
import (
"bytes"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Transport Parameters", func() {
Context("for gQUIC", func() {
Context("parsing", func() {
It("sets all values", func() {
values := map[Tag][]byte{
TagSFCW: {0xad, 0xfb, 0xca, 0xde},
TagCFCW: {0xef, 0xbe, 0xad, 0xde},
TagICSL: {0x0d, 0xf0, 0xad, 0xba},
TagMIDS: {0xff, 0x10, 0x00, 0xc0},
}
params, err := readHelloMap(values)
Expect(err).ToNot(HaveOccurred())
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xdecafbad)))
Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0xdeadbeef)))
Expect(params.IdleTimeout).To(Equal(time.Duration(0xbaadf00d) * time.Second))
Expect(params.MaxStreams).To(Equal(uint32(0xc00010ff)))
Expect(params.OmitConnectionID).To(BeFalse())
})
It("reads if the connection ID should be omitted", func() {
values := map[Tag][]byte{TagTCID: {0, 0, 0, 0}}
params, err := readHelloMap(values)
Expect(err).ToNot(HaveOccurred())
Expect(params.OmitConnectionID).To(BeTrue())
})
It("doesn't allow idle timeouts below the minimum remote idle timeout", func() {
t := 2 * time.Second
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
values := map[Tag][]byte{
TagICSL: {uint8(t.Seconds()), 0, 0, 0},
}
params, err := readHelloMap(values)
Expect(err).ToNot(HaveOccurred())
Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout))
})
It("errors when given an invalid SFCW value", func() {
values := map[Tag][]byte{TagSFCW: {2, 0, 0}} // 1 byte too short
_, err := readHelloMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
It("errors when given an invalid CFCW value", func() {
values := map[Tag][]byte{TagCFCW: {2, 0, 0}} // 1 byte too short
_, err := readHelloMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
It("errors when given an invalid TCID value", func() {
values := map[Tag][]byte{TagTCID: {2, 0, 0}} // 1 byte too short
_, err := readHelloMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
It("errors when given an invalid ICSL value", func() {
values := map[Tag][]byte{TagICSL: {2, 0, 0}} // 1 byte too short
_, err := readHelloMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
It("errors when given an invalid MIDS value", func() {
values := map[Tag][]byte{TagMIDS: {2, 0, 0}} // 1 byte too short
_, err := readHelloMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
})
Context("writing", func() {
It("returns all necessary parameters ", func() {
params := &TransportParameters{
StreamFlowControlWindow: 0xdeadbeef,
ConnectionFlowControlWindow: 0xdecafbad,
IdleTimeout: 0xbaaaaaad * time.Second,
MaxStreams: 0x1337,
}
entryMap := params.getHelloMap()
Expect(entryMap).To(HaveLen(4))
Expect(entryMap).ToNot(HaveKey(TagTCID))
Expect(entryMap).To(HaveKeyWithValue(TagSFCW, []byte{0xef, 0xbe, 0xad, 0xde}))
Expect(entryMap).To(HaveKeyWithValue(TagCFCW, []byte{0xad, 0xfb, 0xca, 0xde}))
Expect(entryMap).To(HaveKeyWithValue(TagICSL, []byte{0xad, 0xaa, 0xaa, 0xba}))
Expect(entryMap).To(HaveKeyWithValue(TagMIDS, []byte{0x37, 0x13, 0, 0}))
})
It("requests omission of the connection ID", func() {
params := &TransportParameters{OmitConnectionID: true}
entryMap := params.getHelloMap()
Expect(entryMap).To(HaveKeyWithValue(TagTCID, []byte{0, 0, 0, 0}))
})
})
})
Context("for TLS", func() {
It("has a string representation", func() {
p := &TransportParameters{
StreamFlowControlWindow: 0x1234,
ConnectionFlowControlWindow: 0x4321,
MaxBidiStreams: 1337,
MaxUniStreams: 7331,
IdleTimeout: 42 * time.Second,
}
Expect(p.String()).To(Equal("&handshake.TransportParameters{StreamFlowControlWindow: 0x1234, ConnectionFlowControlWindow: 0x4321, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s}"))
})
Context("parsing", func() {
var (
params *TransportParameters
parameters map[transportParameterID][]byte
statelessResetToken []byte
)
marshal := func(p map[transportParameterID][]byte) []byte {
b := &bytes.Buffer{}
for id, val := range p {
utils.BigEndian.WriteUint16(b, uint16(id))
utils.BigEndian.WriteUint16(b, uint16(len(val)))
b.Write(val)
}
return b.Bytes()
}
BeforeEach(func() {
params = &TransportParameters{}
statelessResetToken = bytes.Repeat([]byte{42}, 16)
parameters = map[transportParameterID][]byte{
initialMaxStreamDataParameterID: {0x11, 0x22, 0x33, 0x44},
initialMaxDataParameterID: {0x22, 0x33, 0x44, 0x55},
initialMaxBidiStreamsParameterID: {0x33, 0x44},
initialMaxUniStreamsParameterID: {0x44, 0x55},
idleTimeoutParameterID: {0x13, 0x37},
maxPacketSizeParameterID: {0x73, 0x31},
disableMigrationParameterID: {},
statelessResetTokenParameterID: statelessResetToken,
}
})
It("reads parameters", func() {
err := params.unmarshal(marshal(parameters))
Expect(err).ToNot(HaveOccurred())
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0x11223344)))
Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0x22334455)))
Expect(params.MaxBidiStreams).To(Equal(uint16(0x3344)))
Expect(params.MaxUniStreams).To(Equal(uint16(0x4455)))
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
Expect(params.OmitConnectionID).To(BeFalse())
Expect(params.MaxPacketSize).To(Equal(protocol.ByteCount(0x7331)))
Expect(params.DisableMigration).To(BeTrue())
Expect(params.StatelessResetToken).To(Equal(statelessResetToken))
})
It("rejects the parameters if the idle_timeout is missing", func() {
delete(parameters, idleTimeoutParameterID)
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("missing parameter"))
})
It("doesn't allow values below the minimum remote idle timeout", func() {
t := 2 * time.Second
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
parameters[idleTimeoutParameterID] = []byte{0, uint8(t.Seconds())}
err := params.unmarshal(marshal(parameters))
Expect(err).ToNot(HaveOccurred())
Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout))
})
It("rejects the parameters if the initial_max_stream_data has the wrong length", func() {
parameters[initialMaxStreamDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for initial_max_stream_data: 3 (expected 4)"))
})
It("rejects the parameters if the initial_max_data has the wrong length", func() {
parameters[initialMaxDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for initial_max_data: 3 (expected 4)"))
})
It("rejects the parameters if the initial_max_stream_id_bidi has the wrong length", func() {
parameters[initialMaxBidiStreamsParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for initial_max_stream_id_bidi: 3 (expected 2)"))
})
It("rejects the parameters if the initial_max_stream_id_bidi has the wrong length", func() {
parameters[initialMaxUniStreamsParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for initial_max_stream_id_uni: 3 (expected 2)"))
})
It("rejects the parameters if the initial_idle_timeout has the wrong length", func() {
parameters[idleTimeoutParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for idle_timeout: 3 (expected 2)"))
})
It("rejects the parameters if max_packet_size has the wrong length", func() {
parameters[maxPacketSizeParameterID] = []byte{0x11} // should be 2 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for max_packet_size: 1 (expected 2)"))
})
It("rejects max_packet_sizes smaller than 1200 bytes", func() {
parameters[maxPacketSizeParameterID] = []byte{0x4, 0xaf} // 0x4af = 1199
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("invalid value for max_packet_size: 1199 (minimum 1200)"))
})
It("rejects the parameters if disable_connection_migration has the wrong length", func() {
parameters[disableMigrationParameterID] = []byte{0x11} // should empty
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for disable_migration: 1 (expected empty)"))
})
It("rejects the parameters if the stateless_reset_token has the wrong length", func() {
parameters[statelessResetTokenParameterID] = statelessResetToken[1:]
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for stateless_reset_token: 15 (expected 16)"))
})
It("ignores unknown parameters", func() {
parameters[1337] = []byte{42}
err := params.unmarshal(marshal(parameters))
Expect(err).ToNot(HaveOccurred())
})
})
Context("marshalling", func() {
It("marshals", func() {
params := &TransportParameters{
StreamFlowControlWindow: 0xdeadbeef,
ConnectionFlowControlWindow: 0xdecafbad,
IdleTimeout: 0xcafe * time.Second,
MaxBidiStreams: 0x1234,
MaxUniStreams: 0x4321,
DisableMigration: true,
StatelessResetToken: bytes.Repeat([]byte{100}, 16),
}
b := &bytes.Buffer{}
params.marshal(b)
p := &TransportParameters{}
Expect(p.unmarshal(b.Bytes())).To(Succeed())
Expect(p.StreamFlowControlWindow).To(Equal(params.StreamFlowControlWindow))
Expect(p.ConnectionFlowControlWindow).To(Equal(params.ConnectionFlowControlWindow))
Expect(p.MaxUniStreams).To(Equal(params.MaxUniStreams))
Expect(p.MaxBidiStreams).To(Equal(params.MaxBidiStreams))
Expect(p.IdleTimeout).To(Equal(params.IdleTimeout))
Expect(p.DisableMigration).To(Equal(params.DisableMigration))
Expect(p.StatelessResetToken).To(Equal(params.StatelessResetToken))
})
})
})
})

View File

@ -1,108 +0,0 @@
package protocol
import (
"bytes"
"io"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Connection ID generation", func() {
It("generates random connection IDs", func() {
c1, err := GenerateConnectionID(8)
Expect(err).ToNot(HaveOccurred())
Expect(c1).ToNot(BeZero())
c2, err := GenerateConnectionID(8)
Expect(err).ToNot(HaveOccurred())
Expect(c1).ToNot(Equal(c2))
})
It("generates connection IDs with the requested length", func() {
c, err := GenerateConnectionID(5)
Expect(err).ToNot(HaveOccurred())
Expect(c.Len()).To(Equal(5))
})
It("generates random length destination connection IDs", func() {
var has8ByteConnID, has18ByteConnID bool
for i := 0; i < 1000; i++ {
c, err := GenerateConnectionIDForInitial()
Expect(err).ToNot(HaveOccurred())
Expect(c.Len()).To(BeNumerically(">=", 8))
Expect(c.Len()).To(BeNumerically("<=", 18))
if c.Len() == 8 {
has8ByteConnID = true
}
if c.Len() == 18 {
has18ByteConnID = true
}
}
Expect(has8ByteConnID).To(BeTrue())
Expect(has18ByteConnID).To(BeTrue())
})
It("says if connection IDs are equal", func() {
c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
Expect(c1.Equal(c1)).To(BeTrue())
Expect(c2.Equal(c2)).To(BeTrue())
Expect(c1.Equal(c2)).To(BeFalse())
Expect(c2.Equal(c1)).To(BeFalse())
})
It("reads the connection ID", func() {
buf := bytes.NewBuffer([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
c, err := ReadConnectionID(buf, 9)
Expect(err).ToNot(HaveOccurred())
Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}))
})
It("returns io.EOF if there's not enough data to read", func() {
buf := bytes.NewBuffer([]byte{1, 2, 3, 4})
_, err := ReadConnectionID(buf, 5)
Expect(err).To(MatchError(io.EOF))
})
It("returns nil for a 0 length connection ID", func() {
buf := bytes.NewBuffer([]byte{1, 2, 3, 4})
c, err := ReadConnectionID(buf, 0)
Expect(err).ToNot(HaveOccurred())
Expect(c).To(BeNil())
})
It("returns the length", func() {
c := ConnectionID{1, 2, 3, 4, 5, 6, 7}
Expect(c.Len()).To(Equal(7))
})
It("has 0 length for the default value", func() {
var c ConnectionID
Expect(c.Len()).To(BeZero())
})
It("returns the bytes", func() {
c := ConnectionID([]byte{1, 2, 3, 4, 5, 6, 7})
Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7}))
})
It("returns a nil byte slice for the default value", func() {
var c ConnectionID
Expect(c.Bytes()).To(BeNil())
})
It("has a string representation", func() {
c := ConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42})
Expect(c.String()).To(Equal("0xdeadbeef42"))
})
It("has a long string representation", func() {
c := ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad}
Expect(c.String()).To(Equal("0x13370000decafbad"))
})
It("has a string representation for the default value", func() {
var c ConnectionID
Expect(c.String()).To(Equal("(empty)"))
})
})

View File

@ -1,15 +0,0 @@
package protocol
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Encryption Level", func() {
It("has the correct string representation", func() {
Expect(EncryptionUnspecified.String()).To(Equal("unknown"))
Expect(EncryptionUnencrypted.String()).To(Equal("unencrypted"))
Expect(EncryptionSecure.String()).To(Equal("encrypted (not forward-secure)"))
Expect(EncryptionForwardSecure.String()).To(Equal("forward-secure"))
})
})

View File

@ -1,244 +0,0 @@
package protocol
import (
"fmt"
"math"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// Tests taken and extended from chrome
var _ = Describe("packet number calculation", func() {
Context("infering a packet number", func() {
getEpoch := func(len PacketNumberLen, v VersionNumber) uint64 {
if v.UsesVarintPacketNumbers() {
switch len {
case PacketNumberLen1:
return uint64(1) << 7
case PacketNumberLen2:
return uint64(1) << 14
case PacketNumberLen4:
return uint64(1) << 30
default:
Fail("invalid packet number len")
}
}
return uint64(1) << (len * 8)
}
check := func(length PacketNumberLen, expected, last uint64, v VersionNumber) {
epoch := getEpoch(length, v)
epochMask := epoch - 1
wirePacketNumber := expected & epochMask
Expect(InferPacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber), v)).To(Equal(PacketNumber(expected)))
}
for _, v := range []VersionNumber{Version39, VersionTLS} {
version := v
Context(fmt.Sprintf("using varint packet numbers: %t", version.UsesVarintPacketNumbers()), func() {
for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4} {
length := l
Context(fmt.Sprintf("with %d bytes", length), func() {
epoch := getEpoch(length, version)
epochMask := epoch - 1
It("works near epoch start", func() {
// A few quick manual sanity check
check(length, 1, 0, version)
check(length, epoch+1, epochMask, version)
check(length, epoch, epochMask, version)
// Cases where the last number was close to the start of the range.
for last := uint64(0); last < 10; last++ {
// Small numbers should not wrap (even if they're out of order).
for j := uint64(0); j < 10; j++ {
check(length, j, last, version)
}
// Large numbers should not wrap either (because we're near 0 already).
for j := uint64(0); j < 10; j++ {
check(length, epoch-1-j, last, version)
}
}
})
It("works near epoch end", func() {
// Cases where the last number was close to the end of the range
for i := uint64(0); i < 10; i++ {
last := epoch - i
// Small numbers should wrap.
for j := uint64(0); j < 10; j++ {
check(length, epoch+j, last, version)
}
// Large numbers should not (even if they're out of order).
for j := uint64(0); j < 10; j++ {
check(length, epoch-1-j, last, version)
}
}
})
// Next check where we're in a non-zero epoch to verify we handle
// reverse wrapping, too.
It("works near previous epoch", func() {
prevEpoch := 1 * epoch
curEpoch := 2 * epoch
// Cases where the last number was close to the start of the range
for i := uint64(0); i < 10; i++ {
last := curEpoch + i
// Small number should not wrap (even if they're out of order).
for j := uint64(0); j < 10; j++ {
check(length, curEpoch+j, last, version)
}
// But large numbers should reverse wrap.
for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j
check(length, prevEpoch+num, last, version)
}
}
})
It("works near next epoch", func() {
curEpoch := 2 * epoch
nextEpoch := 3 * epoch
// Cases where the last number was close to the end of the range
for i := uint64(0); i < 10; i++ {
last := nextEpoch - 1 - i
// Small numbers should wrap.
for j := uint64(0); j < 10; j++ {
check(length, nextEpoch+j, last, version)
}
// but large numbers should not (even if they're out of order).
for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j
check(length, curEpoch+num, last, version)
}
}
})
It("works near next max", func() {
maxNumber := uint64(math.MaxUint64)
maxEpoch := maxNumber & ^epochMask
// Cases where the last number was close to the end of the range
for i := uint64(0); i < 10; i++ {
// Subtract 1, because the expected next packet number is 1 more than the
// last packet number.
last := maxNumber - i - 1
// Small numbers should not wrap, because they have nowhere to go.
for j := uint64(0); j < 10; j++ {
check(length, maxEpoch+j, last, version)
}
// Large numbers should not wrap either.
for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j
check(length, maxEpoch+num, last, version)
}
}
})
})
}
Context("shortening a packet number for the header", func() {
Context("shortening", func() {
It("sends out low packet numbers as 2 byte", func() {
length := GetPacketNumberLengthForHeader(4, 2, version)
Expect(length).To(Equal(PacketNumberLen2))
})
It("sends out high packet numbers as 2 byte, if all ACKs are received", func() {
length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1, version)
Expect(length).To(Equal(PacketNumberLen2))
})
It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() {
length := GetPacketNumberLengthForHeader(40000, 2, version)
Expect(length).To(Equal(PacketNumberLen4))
})
})
Context("self-consistency", func() {
It("works for small packet numbers", func() {
for i := uint64(1); i < 10000; i++ {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(1)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
Expect(inferedPacketNumber).To(Equal(packetNumber))
}
})
It("works for small packet numbers and increasing ACKed packets", func() {
for i := uint64(1); i < 10000; i++ {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(i / 2)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
epochMask := getEpoch(length, version) - 1
wirePacketNumber := uint64(packetNumber) & epochMask
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
Expect(inferedPacketNumber).To(Equal(packetNumber))
}
})
It("also works for larger packet numbers", func() {
var increment uint64
for i := uint64(1); i < getEpoch(PacketNumberLen4, version); i += increment {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(1)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
epochMask := getEpoch(length, version) - 1
wirePacketNumber := uint64(packetNumber) & epochMask
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
Expect(inferedPacketNumber).To(Equal(packetNumber))
increment = getEpoch(length, version) / 8
}
})
It("works for packet numbers larger than 2^48", func() {
for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(i - 1000)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
Expect(inferedPacketNumber).To(Equal(packetNumber))
}
})
})
})
})
}
})
Context("determining the minimum length of a packet number", func() {
It("1 byte", func() {
Expect(GetPacketNumberLength(0xFF)).To(Equal(PacketNumberLen1))
})
It("2 byte", func() {
Expect(GetPacketNumberLength(0xFFFF)).To(Equal(PacketNumberLen2))
})
It("4 byte", func() {
Expect(GetPacketNumberLength(0xFFFFFFFF)).To(Equal(PacketNumberLen4))
})
It("6 byte", func() {
Expect(GetPacketNumberLength(0xFFFFFFFFFFFF)).To(Equal(PacketNumberLen6))
})
})
})

View File

@ -1,19 +0,0 @@
package protocol
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Perspective", func() {
It("has a string representation", func() {
Expect(PerspectiveClient.String()).To(Equal("Client"))
Expect(PerspectiveServer.String()).To(Equal("Server"))
Expect(Perspective(0).String()).To(Equal("invalid perspective"))
})
It("returns the opposite", func() {
Expect(PerspectiveClient.Opposite()).To(Equal(PerspectiveServer))
Expect(PerspectiveServer.Opposite()).To(Equal(PerspectiveClient))
})
})

Some files were not shown because too many files have changed in this diff Show More