diff --git a/common/net/system.go b/common/net/system.go index cb5657e96..e205e4222 100644 --- a/common/net/system.go +++ b/common/net/system.go @@ -14,6 +14,7 @@ var ( DialUDP = net.DialUDP DialUnix = net.DialUnix FileConn = net.FileConn + FileListener = net.FileListener Listen = net.Listen ListenTCP = net.ListenTCP ListenUDP = net.ListenUDP diff --git a/transport/internet/system_listener.go b/transport/internet/system_listener.go index 84f9f1e8a..81d859c98 100644 --- a/transport/internet/system_listener.go +++ b/transport/internet/system_listener.go @@ -59,6 +59,8 @@ func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []co func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.Listener, error) { var lc net.ListenConfig var network, address string + var l net.Listener + var err error // callback is called after the Listen function returns // this is used to wrap the listener and do some post processing callback := func(l net.Listener, err error) (net.Listener, error) { @@ -93,6 +95,16 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S copy(fullAddr, address[1:]) address = string(fullAddr) } + } else if strings.HasPrefix(address, "/dev/fd/") { + fd, err := strconv.Atoi(address[8:]) + if err != nil { + return nil, err + } + _ = syscall.SetNonblock(fd, true) + l, err = net.FileListener(os.NewFile(uintptr(fd), address)) + if err != nil { + return nil, err + } } else { // normal unix domain socket var fileMode *os.FileMode @@ -133,13 +145,18 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S } } - l, err := lc.Listen(ctx, network, address) - l, err = callback(l, err) - if err == nil && sockopt != nil && sockopt.AcceptProxyProtocol { + if l == nil { + l, err = lc.Listen(ctx, network, address) + l, err = callback(l, err) + if err != nil { + return nil, err + } + } + if sockopt != nil && sockopt.AcceptProxyProtocol { policyFunc := func(upstream net.Addr) (proxyproto.Policy, error) { return proxyproto.REQUIRE, nil } l = &proxyproto.Listener{Listener: l, Policy: policyFunc} } - return l, err + return l, nil } func (dl *DefaultListener) ListenPacket(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.PacketConn, error) {