diff --git a/transport/internet/system_listener.go b/transport/internet/system_listener.go index 347694701..ef54228d6 100644 --- a/transport/internet/system_listener.go +++ b/transport/internet/system_listener.go @@ -56,9 +56,14 @@ func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []co } } -func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (l net.Listener, err error) { +func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.Listener, error) { var lc net.ListenConfig var network, address string + // callback is called after the Listen function returns + // this is used to wrap the listener and do some post processing + callback := func(l net.Listener, err error) (net.Listener, error) { + return l, err + } switch addr := addr.(type) { case *net.TCPAddr: network = addr.Network() @@ -81,6 +86,7 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S } } else { // normal unix domain socket + var fileMode *os.FileMode // parse file mode from address if s := strings.Split(address, ","); len(s) == 2 { fMode, err := strconv.ParseUint(s[1], 8, 32) @@ -88,18 +94,8 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S return nil, newError("failed to parse file mode").Base(err) } address = s[0] - // set file mode for unix domain socket when it is created - defer func(name string, mode os.FileMode) { - if err != nil { - return - } - if cerr := os.Chmod(name, mode); cerr != nil { - // failed to set file mode, close the listener - l.Close() - l = nil - err = newError("failed to set file mode for file: ", name).Base(cerr) - } - }(address, os.FileMode(fMode)) + fm := os.FileMode(fMode) + fileMode = &fm } // normal unix domain socket needs lock locker := &FileLocker{ @@ -108,20 +104,29 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S if err := locker.Acquire(); err != nil { return nil, err } - defer func(locker *FileLocker) { - // combine listener and locker - if err == nil { - l = &combinedListener{Listener: l, locker: locker} - } else { - // failed to create listener, release the locker + // set file mode for unix domain socket when it is created + callback = func(l net.Listener, err error) (net.Listener, error) { + if err != nil { locker.Release() + return nil, err } - }(locker) + l = &combinedListener{Listener: l, locker: locker} + if fileMode == nil { + return l, err + } + if cerr := os.Chmod(address, *fileMode); cerr != nil { + // failed to set file mode, close the listener + l.Close() + return nil, newError("failed to set file mode for file: ", address).Base(cerr) + } + return l, err + } } } - l, err = lc.Listen(ctx, network, address) - if sockopt != nil && sockopt.AcceptProxyProtocol { + l, err := lc.Listen(ctx, network, address) + l, err = callback(l, err) + if err == nil && sockopt != nil && sockopt.AcceptProxyProtocol { policyFunc := func(upstream net.Addr) (proxyproto.Policy, error) { return proxyproto.REQUIRE, nil } l = &proxyproto.Listener{Listener: l, Policy: policyFunc} }