mirror of
https://github.com/v2fly/v2ray-core.git
synced 2025-01-05 00:47:51 -05:00
466 lines
13 KiB
Go
466 lines
13 KiB
Go
package quic
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
|
)
|
|
|
|
// packetHandler handles packets
|
|
type packetHandler interface {
|
|
handlePacket(*receivedPacket)
|
|
io.Closer
|
|
destroy(error)
|
|
GetVersion() protocol.VersionNumber
|
|
GetPerspective() protocol.Perspective
|
|
}
|
|
|
|
type unknownPacketHandler interface {
|
|
handlePacket(*receivedPacket)
|
|
closeWithError(error) error
|
|
}
|
|
|
|
type packetHandlerManager interface {
|
|
Add(protocol.ConnectionID, packetHandler)
|
|
Retire(protocol.ConnectionID)
|
|
Remove(protocol.ConnectionID)
|
|
SetServer(unknownPacketHandler)
|
|
CloseServer()
|
|
}
|
|
|
|
type quicSession interface {
|
|
Session
|
|
handlePacket(*receivedPacket)
|
|
GetVersion() protocol.VersionNumber
|
|
run() error
|
|
destroy(error)
|
|
closeRemote(error)
|
|
}
|
|
|
|
type sessionRunner interface {
|
|
onHandshakeComplete(Session)
|
|
retireConnectionID(protocol.ConnectionID)
|
|
removeConnectionID(protocol.ConnectionID)
|
|
}
|
|
|
|
type runner struct {
|
|
onHandshakeCompleteImpl func(Session)
|
|
retireConnectionIDImpl func(protocol.ConnectionID)
|
|
removeConnectionIDImpl func(protocol.ConnectionID)
|
|
}
|
|
|
|
func (r *runner) onHandshakeComplete(s Session) { r.onHandshakeCompleteImpl(s) }
|
|
func (r *runner) retireConnectionID(c protocol.ConnectionID) { r.retireConnectionIDImpl(c) }
|
|
func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) }
|
|
|
|
var _ sessionRunner = &runner{}
|
|
|
|
// A Listener of QUIC
|
|
type server struct {
|
|
mutex sync.Mutex
|
|
|
|
tlsConf *tls.Config
|
|
config *Config
|
|
|
|
conn net.PacketConn
|
|
// If the server is started with ListenAddr, we create a packet conn.
|
|
// If it is started with Listen, we take a packet conn as a parameter.
|
|
createdPacketConn bool
|
|
|
|
cookieGenerator *handshake.CookieGenerator
|
|
|
|
sessionHandler packetHandlerManager
|
|
|
|
// set as a member, so they can be set in the tests
|
|
newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, *Config, *tls.Config, *handshake.TransportParameters, utils.Logger, protocol.VersionNumber) (quicSession, error)
|
|
|
|
serverError error
|
|
errorChan chan struct{}
|
|
closed bool
|
|
|
|
sessionQueue chan Session
|
|
|
|
sessionRunner sessionRunner
|
|
|
|
logger utils.Logger
|
|
}
|
|
|
|
var _ Listener = &server{}
|
|
var _ unknownPacketHandler = &server{}
|
|
|
|
// ListenAddr creates a QUIC server listening on a given address.
|
|
// The tls.Config must not be nil, the quic.Config may be nil.
|
|
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
|
|
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
conn, err := net.ListenUDP("udp", udpAddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
serv, err := listen(conn, tlsConf, config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
serv.createdPacketConn = true
|
|
return serv, nil
|
|
}
|
|
|
|
// Listen listens for QUIC connections on a given net.PacketConn.
|
|
// The tls.Config must not be nil, the quic.Config may be nil.
|
|
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
|
|
return listen(conn, tlsConf, config)
|
|
}
|
|
|
|
func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) {
|
|
config = populateServerConfig(config)
|
|
for _, v := range config.Versions {
|
|
if !protocol.IsValidVersion(v) {
|
|
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
|
|
}
|
|
}
|
|
|
|
sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s := &server{
|
|
conn: conn,
|
|
tlsConf: tlsConf,
|
|
config: config,
|
|
sessionHandler: sessionHandler,
|
|
sessionQueue: make(chan Session, 5),
|
|
errorChan: make(chan struct{}),
|
|
newSession: newSession,
|
|
logger: utils.DefaultLogger.WithPrefix("server"),
|
|
}
|
|
if err := s.setup(); err != nil {
|
|
return nil, err
|
|
}
|
|
sessionHandler.SetServer(s)
|
|
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
|
|
return s, nil
|
|
}
|
|
|
|
func (s *server) setup() error {
|
|
s.sessionRunner = &runner{
|
|
onHandshakeCompleteImpl: func(sess Session) { s.sessionQueue <- sess },
|
|
retireConnectionIDImpl: s.sessionHandler.Retire,
|
|
removeConnectionIDImpl: s.sessionHandler.Remove,
|
|
}
|
|
cookieGenerator, err := handshake.NewCookieGenerator()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.cookieGenerator = cookieGenerator
|
|
return nil
|
|
}
|
|
|
|
var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool {
|
|
if cookie == nil {
|
|
return false
|
|
}
|
|
if time.Now().After(cookie.SentTime.Add(protocol.CookieExpiryTime)) {
|
|
return false
|
|
}
|
|
var sourceAddr string
|
|
if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
|
|
sourceAddr = udpAddr.IP.String()
|
|
} else {
|
|
sourceAddr = clientAddr.String()
|
|
}
|
|
return sourceAddr == cookie.RemoteAddr
|
|
}
|
|
|
|
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
|
|
// it may be called with nil
|
|
func populateServerConfig(config *Config) *Config {
|
|
if config == nil {
|
|
config = &Config{}
|
|
}
|
|
versions := config.Versions
|
|
if len(versions) == 0 {
|
|
versions = protocol.SupportedVersions
|
|
}
|
|
|
|
vsa := defaultAcceptCookie
|
|
if config.AcceptCookie != nil {
|
|
vsa = config.AcceptCookie
|
|
}
|
|
|
|
handshakeTimeout := protocol.DefaultHandshakeTimeout
|
|
if config.HandshakeTimeout != 0 {
|
|
handshakeTimeout = config.HandshakeTimeout
|
|
}
|
|
idleTimeout := protocol.DefaultIdleTimeout
|
|
if config.IdleTimeout != 0 {
|
|
idleTimeout = config.IdleTimeout
|
|
}
|
|
|
|
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
|
|
if maxReceiveStreamFlowControlWindow == 0 {
|
|
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
|
|
}
|
|
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
|
|
if maxReceiveConnectionFlowControlWindow == 0 {
|
|
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
|
|
}
|
|
maxIncomingStreams := config.MaxIncomingStreams
|
|
if maxIncomingStreams == 0 {
|
|
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
|
|
} else if maxIncomingStreams < 0 {
|
|
maxIncomingStreams = 0
|
|
}
|
|
maxIncomingUniStreams := config.MaxIncomingUniStreams
|
|
if maxIncomingUniStreams == 0 {
|
|
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
|
|
} else if maxIncomingUniStreams < 0 {
|
|
maxIncomingUniStreams = 0
|
|
}
|
|
connIDLen := config.ConnectionIDLength
|
|
if connIDLen == 0 {
|
|
connIDLen = protocol.DefaultConnectionIDLength
|
|
}
|
|
|
|
return &Config{
|
|
Versions: versions,
|
|
HandshakeTimeout: handshakeTimeout,
|
|
IdleTimeout: idleTimeout,
|
|
AcceptCookie: vsa,
|
|
KeepAlive: config.KeepAlive,
|
|
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
|
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
|
MaxIncomingStreams: maxIncomingStreams,
|
|
MaxIncomingUniStreams: maxIncomingUniStreams,
|
|
ConnectionIDLength: connIDLen,
|
|
}
|
|
}
|
|
|
|
// Accept returns newly openend sessions
|
|
func (s *server) Accept() (Session, error) {
|
|
var sess Session
|
|
select {
|
|
case sess = <-s.sessionQueue:
|
|
return sess, nil
|
|
case <-s.errorChan:
|
|
return nil, s.serverError
|
|
}
|
|
}
|
|
|
|
// Close the server
|
|
func (s *server) Close() error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
if s.closed {
|
|
return nil
|
|
}
|
|
return s.closeWithMutex()
|
|
}
|
|
|
|
func (s *server) closeWithMutex() error {
|
|
s.sessionHandler.CloseServer()
|
|
if s.serverError == nil {
|
|
s.serverError = errors.New("server closed")
|
|
}
|
|
var err error
|
|
// If the server was started with ListenAddr, we created the packet conn.
|
|
// We need to close it in order to make the go routine reading from that conn return.
|
|
if s.createdPacketConn {
|
|
err = s.conn.Close()
|
|
}
|
|
s.closed = true
|
|
close(s.errorChan)
|
|
return err
|
|
}
|
|
|
|
func (s *server) closeWithError(e error) error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
if s.closed {
|
|
return nil
|
|
}
|
|
s.serverError = e
|
|
return s.closeWithMutex()
|
|
}
|
|
|
|
// Addr returns the server's network address
|
|
func (s *server) Addr() net.Addr {
|
|
return s.conn.LocalAddr()
|
|
}
|
|
|
|
func (s *server) handlePacket(p *receivedPacket) {
|
|
if err := s.handlePacketImpl(p); err != nil {
|
|
s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
|
|
}
|
|
}
|
|
|
|
func (s *server) handlePacketImpl(p *receivedPacket) error {
|
|
hdr := p.header
|
|
|
|
// send a Version Negotiation Packet if the client is speaking a different protocol version
|
|
if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
|
|
return s.sendVersionNegotiationPacket(p)
|
|
}
|
|
if hdr.Type == protocol.PacketTypeInitial {
|
|
go s.handleInitial(p)
|
|
}
|
|
// TODO(#943): send Stateless Reset
|
|
return nil
|
|
}
|
|
|
|
func (s *server) handleInitial(p *receivedPacket) {
|
|
// TODO: add a check that DestConnID == SrcConnID
|
|
s.logger.Debugf("<- Received Initial packet.")
|
|
sess, connID, err := s.handleInitialImpl(p)
|
|
if err != nil {
|
|
s.logger.Errorf("Error occurred handling initial packet: %s", err)
|
|
return
|
|
}
|
|
if sess == nil { // a retry was done
|
|
return
|
|
}
|
|
serverSession := newServerSession(sess, s.config, s.logger)
|
|
s.sessionHandler.Add(connID, serverSession)
|
|
}
|
|
|
|
func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) {
|
|
hdr := p.header
|
|
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
|
return nil, nil, errors.New("dropping Initial packet with too short connection ID")
|
|
}
|
|
if len(hdr.Raw)+len(p.data) < protocol.MinInitialPacketSize {
|
|
return nil, nil, errors.New("dropping too small Initial packet")
|
|
}
|
|
|
|
var cookie *Cookie
|
|
var origDestConnectionID protocol.ConnectionID
|
|
if len(hdr.Token) > 0 {
|
|
c, err := s.cookieGenerator.DecodeToken(hdr.Token)
|
|
if err == nil {
|
|
cookie = &Cookie{
|
|
RemoteAddr: c.RemoteAddr,
|
|
SentTime: c.SentTime,
|
|
}
|
|
origDestConnectionID = c.OriginalDestConnectionID
|
|
}
|
|
}
|
|
if !s.config.AcceptCookie(p.remoteAddr, cookie) {
|
|
// Log the Initial packet now.
|
|
// If no Retry is sent, the packet will be logged by the session.
|
|
p.header.Log(s.logger)
|
|
return nil, nil, s.sendRetry(p.remoteAddr, hdr)
|
|
}
|
|
|
|
connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
s.logger.Debugf("Changing connection ID to %s.", connID)
|
|
sess, err := s.createNewSession(
|
|
p.remoteAddr,
|
|
origDestConnectionID,
|
|
hdr.DestConnectionID,
|
|
hdr.SrcConnectionID,
|
|
connID,
|
|
hdr.Version,
|
|
)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
sess.handlePacket(p)
|
|
return sess, connID, nil
|
|
}
|
|
|
|
func (s *server) createNewSession(
|
|
remoteAddr net.Addr,
|
|
origDestConnID protocol.ConnectionID,
|
|
clientDestConnID protocol.ConnectionID,
|
|
destConnID protocol.ConnectionID,
|
|
srcConnID protocol.ConnectionID,
|
|
version protocol.VersionNumber,
|
|
) (quicSession, error) {
|
|
params := &handshake.TransportParameters{
|
|
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
|
|
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
|
|
InitialMaxStreamDataUni: protocol.InitialMaxStreamData,
|
|
InitialMaxData: protocol.InitialMaxData,
|
|
IdleTimeout: s.config.IdleTimeout,
|
|
MaxBidiStreams: uint64(s.config.MaxIncomingStreams),
|
|
MaxUniStreams: uint64(s.config.MaxIncomingUniStreams),
|
|
DisableMigration: true,
|
|
// TODO(#855): generate a real token
|
|
StatelessResetToken: bytes.Repeat([]byte{42}, 16),
|
|
OriginalConnectionID: origDestConnID,
|
|
}
|
|
sess, err := s.newSession(
|
|
&conn{pconn: s.conn, currentAddr: remoteAddr},
|
|
s.sessionRunner,
|
|
clientDestConnID,
|
|
destConnID,
|
|
srcConnID,
|
|
s.config,
|
|
s.tlsConf,
|
|
params,
|
|
s.logger,
|
|
version,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
go sess.run()
|
|
return sess, nil
|
|
}
|
|
|
|
func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
|
|
token, err := s.cookieGenerator.NewToken(remoteAddr, hdr.DestConnectionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
replyHdr := &wire.Header{
|
|
IsLongHeader: true,
|
|
Type: protocol.PacketTypeRetry,
|
|
Version: hdr.Version,
|
|
SrcConnectionID: connID,
|
|
DestConnectionID: hdr.SrcConnectionID,
|
|
OrigDestConnectionID: hdr.DestConnectionID,
|
|
Token: token,
|
|
}
|
|
s.logger.Debugf("Changing connection ID to %s.\n-> Sending Retry", connID)
|
|
replyHdr.Log(s.logger)
|
|
buf := &bytes.Buffer{}
|
|
if err := replyHdr.Write(buf, protocol.PerspectiveServer, hdr.Version); err != nil {
|
|
return err
|
|
}
|
|
if _, err := s.conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
|
|
s.logger.Debugf("Error sending Retry: %s", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error {
|
|
hdr := p.header
|
|
s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
|
|
|
|
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = s.conn.WriteTo(data, p.remoteAddr)
|
|
return err
|
|
}
|