From 2e26cf65874213cbbbe1ad997c0ffbf4588e2862 Mon Sep 17 00:00:00 2001 From: Shelikhoo Date: Wed, 28 Apr 2021 15:43:43 +0100 Subject: [PATCH] fix: make sure the ctx is propagated to connections --- app/dns/nameserver_quic.go | 12 ++++++------ app/dns/nameserver_udp.go | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/app/dns/nameserver_quic.go b/app/dns/nameserver_quic.go index 902123a8c..7854a5a3e 100644 --- a/app/dns/nameserver_quic.go +++ b/app/dns/nameserver_quic.go @@ -330,7 +330,7 @@ func isActive(s quic.Session) bool { } } -func (s *QUICNameServer) getSession() (quic.Session, error) { +func (s *QUICNameServer) getSession(ctx context.Context) (quic.Session, error) { var session quic.Session s.RLock() session = s.session @@ -348,14 +348,14 @@ func (s *QUICNameServer) getSession() (quic.Session, error) { defer s.Unlock() var err error - session, err = s.openSession() + session, err = s.openSession(ctx) if err != nil { // This does not look too nice, but QUIC (or maybe quic-go) // doesn't seem stable enough. // Maybe retransmissions aren't fully implemented in quic-go? // Anyways, the simple solution is to make a second try when // it fails to open the QUIC session. - session, err = s.openSession() + session, err = s.openSession(ctx) if err != nil { return nil, err } @@ -364,13 +364,13 @@ func (s *QUICNameServer) getSession() (quic.Session, error) { return session, nil } -func (s *QUICNameServer) openSession() (quic.Session, error) { +func (s *QUICNameServer) openSession(ctx context.Context) (quic.Session, error) { tlsConfig := tls.Config{} quicConfig := &quic.Config{ HandshakeIdleTimeout: handshakeIdleTimeout, } - session, err := quic.DialAddrContext(context.Background(), s.destination.NetAddr(), tlsConfig.GetTLSConfig(tls.WithNextProto("http/1.1", http2.NextProtoTLS, NextProtoDQ)), quicConfig) + session, err := quic.DialAddrContext(ctx, s.destination.NetAddr(), tlsConfig.GetTLSConfig(tls.WithNextProto("http/1.1", http2.NextProtoTLS, NextProtoDQ)), quicConfig) if err != nil { return nil, err } @@ -379,7 +379,7 @@ func (s *QUICNameServer) openSession() (quic.Session, error) { } func (s *QUICNameServer) openStream(ctx context.Context) (quic.Stream, error) { - session, err := s.getSession() + session, err := s.getSession(ctx) if err != nil { return nil, err } diff --git a/app/dns/nameserver_udp.go b/app/dns/nameserver_udp.go index 0debdee6e..323009f40 100644 --- a/app/dns/nameserver_udp.go +++ b/app/dns/nameserver_udp.go @@ -192,7 +192,7 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, client for _, req := range reqs { s.addPendingRequest(req) b, _ := dns.PackMessage(req.msg) - udpCtx := context.Background() + udpCtx := ctx if inbound := session.InboundFromContext(ctx); inbound != nil { udpCtx = session.ContextWithInbound(udpCtx, inbound) }