1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-22 18:17:52 -05:00
v2fly/vendor/github.com/lucas-clemente/quic-go/server.go
2018-11-21 10:49:41 +01:00

421 lines
11 KiB
Go

package quic
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// packetHandler handles packets
type packetHandler interface {
handlePacket(*receivedPacket)
io.Closer
destroy(error)
GetVersion() protocol.VersionNumber
GetPerspective() protocol.Perspective
}
type unknownPacketHandler interface {
handlePacket(*receivedPacket)
closeWithError(error) error
}
type packetHandlerManager interface {
Add(protocol.ConnectionID, packetHandler)
SetServer(unknownPacketHandler)
Remove(protocol.ConnectionID)
CloseServer()
}
type quicSession interface {
Session
handlePacket(*receivedPacket)
GetVersion() protocol.VersionNumber
run() error
destroy(error)
closeRemote(error)
}
type sessionRunner interface {
onHandshakeComplete(Session)
removeConnectionID(protocol.ConnectionID)
}
type runner struct {
onHandshakeCompleteImpl func(Session)
removeConnectionIDImpl func(protocol.ConnectionID)
}
func (r *runner) onHandshakeComplete(s Session) { r.onHandshakeCompleteImpl(s) }
func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) }
var _ sessionRunner = &runner{}
// A Listener of QUIC
type server struct {
mutex sync.Mutex
tlsConf *tls.Config
config *Config
conn net.PacketConn
// If the server is started with ListenAddr, we create a packet conn.
// If it is started with Listen, we take a packet conn as a parameter.
createdPacketConn bool
supportsTLS bool
serverTLS *serverTLS
certChain crypto.CertChain
scfg *handshake.ServerConfig
sessionHandler packetHandlerManager
serverError error
errorChan chan struct{}
closed bool
sessionQueue chan Session
sessionRunner sessionRunner
// set as a member, so they can be set in the tests
newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (quicSession, error)
logger utils.Logger
}
var _ Listener = &server{}
var _ unknownPacketHandler = &server{}
// ListenAddr creates a QUIC server listening on a given address.
// The tls.Config must not be nil, the quic.Config may be nil.
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return nil, err
}
serv, err := listen(conn, tlsConf, config)
if err != nil {
return nil, err
}
serv.createdPacketConn = true
return serv, nil
}
// Listen listens for QUIC connections on a given net.PacketConn.
// The tls.Config must not be nil, the quic.Config may be nil.
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
return listen(conn, tlsConf, config)
}
func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) {
certChain := crypto.NewCertChain(tlsConf)
kex, err := crypto.NewCurve25519KEX()
if err != nil {
return nil, err
}
scfg, err := handshake.NewServerConfig(kex, certChain)
if err != nil {
return nil, err
}
config = populateServerConfig(config)
var supportsTLS bool
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
// check if any of the supported versions supports TLS
if v.UsesTLS() {
supportsTLS = true
break
}
}
sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength)
if err != nil {
return nil, err
}
s := &server{
conn: conn,
tlsConf: tlsConf,
config: config,
certChain: certChain,
scfg: scfg,
newSession: newSession,
sessionHandler: sessionHandler,
sessionQueue: make(chan Session, 5),
errorChan: make(chan struct{}),
supportsTLS: supportsTLS,
logger: utils.DefaultLogger.WithPrefix("server"),
}
s.setup()
if supportsTLS {
if err := s.setupTLS(); err != nil {
return nil, err
}
}
sessionHandler.SetServer(s)
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
return s, nil
}
func (s *server) setup() {
s.sessionRunner = &runner{
onHandshakeCompleteImpl: func(sess Session) { s.sessionQueue <- sess },
removeConnectionIDImpl: s.sessionHandler.Remove,
}
}
func (s *server) setupTLS() error {
serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, s.sessionRunner, s.tlsConf, s.logger)
if err != nil {
return err
}
s.serverTLS = serverTLS
// handle TLS connection establishment statelessly
go func() {
for {
select {
case <-s.errorChan:
return
case tlsSession := <-sessionChan:
// The connection ID is a randomly chosen value.
// It is safe to assume that it doesn't collide with other randomly chosen values.
serverSession := newServerSession(tlsSession.sess, s.config, s.logger)
s.sessionHandler.Add(tlsSession.connID, serverSession)
}
}
}()
return nil
}
var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool {
if cookie == nil {
return false
}
if time.Now().After(cookie.SentTime.Add(protocol.CookieExpiryTime)) {
return false
}
var sourceAddr string
if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
sourceAddr = udpAddr.IP.String()
} else {
sourceAddr = clientAddr.String()
}
return sourceAddr == cookie.RemoteAddr
}
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateServerConfig(config *Config) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions
if len(versions) == 0 {
versions = protocol.SupportedVersions
}
vsa := defaultAcceptCookie
if config.AcceptCookie != nil {
vsa = config.AcceptCookie
}
handshakeTimeout := protocol.DefaultHandshakeTimeout
if config.HandshakeTimeout != 0 {
handshakeTimeout = config.HandshakeTimeout
}
idleTimeout := protocol.DefaultIdleTimeout
if config.IdleTimeout != 0 {
idleTimeout = config.IdleTimeout
}
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
if maxReceiveStreamFlowControlWindow == 0 {
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowServer
}
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowServer
}
maxIncomingStreams := config.MaxIncomingStreams
if maxIncomingStreams == 0 {
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
} else if maxIncomingStreams < 0 {
maxIncomingStreams = 0
}
maxIncomingUniStreams := config.MaxIncomingUniStreams
if maxIncomingUniStreams == 0 {
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
connIDLen := config.ConnectionIDLength
if connIDLen == 0 {
connIDLen = protocol.DefaultConnectionIDLength
}
for _, v := range versions {
if v == protocol.Version44 {
connIDLen = protocol.ConnectionIDLenGQUIC
}
}
return &Config{
Versions: versions,
HandshakeTimeout: handshakeTimeout,
IdleTimeout: idleTimeout,
AcceptCookie: vsa,
KeepAlive: config.KeepAlive,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
ConnectionIDLength: connIDLen,
}
}
// Accept returns newly openend sessions
func (s *server) Accept() (Session, error) {
var sess Session
select {
case sess = <-s.sessionQueue:
return sess, nil
case <-s.errorChan:
return nil, s.serverError
}
}
// Close the server
func (s *server) Close() error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.closed {
return nil
}
return s.closeWithMutex()
}
func (s *server) closeWithMutex() error {
s.sessionHandler.CloseServer()
if s.serverError == nil {
s.serverError = errors.New("server closed")
}
var err error
// If the server was started with ListenAddr, we created the packet conn.
// We need to close it in order to make the go routine reading from that conn return.
if s.createdPacketConn {
err = s.conn.Close()
}
s.closed = true
close(s.errorChan)
return err
}
func (s *server) closeWithError(e error) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.closed {
return nil
}
s.serverError = e
return s.closeWithMutex()
}
// Addr returns the server's network address
func (s *server) Addr() net.Addr {
return s.conn.LocalAddr()
}
func (s *server) handlePacket(p *receivedPacket) {
if err := s.handlePacketImpl(p); err != nil {
s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
}
}
func (s *server) handlePacketImpl(p *receivedPacket) error {
hdr := p.header
if hdr.VersionFlag || hdr.IsLongHeader {
// send a Version Negotiation Packet if the client is speaking a different protocol version
if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
return s.sendVersionNegotiationPacket(p)
}
}
if hdr.Type == protocol.PacketTypeInitial && hdr.Version.UsesTLS() {
go s.serverTLS.HandleInitial(p)
return nil
}
// TODO(#943): send Stateless Reset, if this an IETF QUIC packet
if !hdr.VersionFlag && !hdr.Version.UsesIETFHeaderFormat() {
_, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), p.remoteAddr)
return err
}
// This is (potentially) a Client Hello.
// Make sure it has the minimum required size before spending any more ressources on it.
if len(p.data) < protocol.MinClientHelloSize {
return errors.New("dropping small packet for unknown connection")
}
var destConnID, srcConnID protocol.ConnectionID
if hdr.Version.UsesIETFHeaderFormat() {
srcConnID = hdr.DestConnectionID
} else {
destConnID = hdr.DestConnectionID
srcConnID = hdr.DestConnectionID
}
s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, hdr.Version, p.remoteAddr)
sess, err := s.newSession(
&conn{pconn: s.conn, currentAddr: p.remoteAddr},
s.sessionRunner,
hdr.Version,
destConnID,
srcConnID,
s.scfg,
s.tlsConf,
s.config,
s.logger,
)
if err != nil {
return err
}
s.sessionHandler.Add(hdr.DestConnectionID, newServerSession(sess, s.config, s.logger))
go sess.run()
sess.handlePacket(p)
return nil
}
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error {
hdr := p.header
s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
var data []byte
if hdr.IsPublicHeader {
data = wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions)
} else {
var err error
data, err = wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
if err != nil {
return err
}
}
_, err := s.conn.WriteTo(data, p.remoteAddr)
return err
}