From e00d80eac432b4265f8d4929ec45db30a1a41de6 Mon Sep 17 00:00:00 2001 From: Shelikhoo Date: Thu, 1 Jul 2021 18:58:13 +0100 Subject: [PATCH] cancel failed grpc connection --- transport/internet/grpc/dial.go | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/transport/internet/grpc/dial.go b/transport/internet/grpc/dial.go index 6ef93678d..a86247cf3 100644 --- a/transport/internet/grpc/dial.go +++ b/transport/internet/grpc/dial.go @@ -36,6 +36,8 @@ func init() { common.Must(internet.RegisterTransportDialer(protocolName, Dial)) } +type dialerCanceller func() + var ( globalDialerMap map[net.Destination]*grpc.ClientConn globalDialerAccess sync.Mutex @@ -51,19 +53,20 @@ func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *interne dialOption = grpc.WithTransportCredentials(credentials.NewTLS(config.GetTLSConfig())) } - conn, err := getGrpcClient(ctx, dest, dialOption) + conn, canceller, err := getGrpcClient(ctx, dest, dialOption) if err != nil { return nil, newError("Cannot dial grpc").Base(err) } client := encoding.NewGunServiceClient(conn) gunService, err := client.(encoding.GunServiceClientX).TunCustomName(ctx, grpcSettings.ServiceName) if err != nil { + canceller() return nil, newError("Cannot dial grpc").Base(err) } return encoding.NewGunConn(gunService, nil), nil } -func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.DialOption) (*grpc.ClientConn, error) { +func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.DialOption) (*grpc.ClientConn, dialerCanceller, error) { globalDialerAccess.Lock() defer globalDialerAccess.Unlock() @@ -71,9 +74,15 @@ func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.Di globalDialerMap = make(map[net.Destination]*grpc.ClientConn) } + canceller := func() { + globalDialerAccess.Lock() + defer globalDialerAccess.Unlock() + delete(globalDialerMap, dest) + } + // TODO Should support chain proxy to the same destination if client, found := globalDialerMap[dest]; found && client.GetState() != connectivity.Shutdown { - return client, nil + return client, canceller, nil } conn, err := grpc.Dial( @@ -106,5 +115,5 @@ func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.Di }), ) globalDialerMap[dest] = conn - return conn, err + return conn, canceller, err }