diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 4ed3bd732..908c3bafa 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -13,7 +13,8 @@ import ( ) type DokodemoDoor struct { - sync.Mutex + tcpMutex sync.RWMutex + udpMutex sync.RWMutex config Config accepting bool address v2net.Address @@ -35,16 +36,16 @@ func NewDokodemoDoor(space app.Space, config Config) *DokodemoDoor { func (this *DokodemoDoor) Close() { this.accepting = false if this.tcpListener != nil { - this.Lock() this.tcpListener.Close() + this.tcpMutex.Lock() this.tcpListener = nil - this.Unlock() + this.tcpMutex.Unlock() } if this.udpConn != nil { - this.Lock() this.udpConn.Close() + this.udpMutex.Lock() this.udpConn = nil - this.Unlock() + this.udpMutex.Unlock() } } @@ -84,13 +85,12 @@ func (this *DokodemoDoor) ListenUDP(port v2net.Port) error { func (this *DokodemoDoor) handleUDPPackets() { for this.accepting { buffer := alloc.NewBuffer() - var udpConn *net.UDPConn - this.Lock() - if this.udpConn != nil { - udpConn = this.udpConn + this.udpMutex.RLock() + if !this.accepting { + return } - this.Unlock() - nBytes, addr, err := udpConn.ReadFromUDP(buffer.Value) + nBytes, addr, err := this.udpConn.ReadFromUDP(buffer.Value) + this.udpMutex.RUnlock() buffer.Slice(0, nBytes) if err != nil { buffer.Release() @@ -103,7 +103,13 @@ func (this *DokodemoDoor) handleUDPPackets() { close(ray.InboundInput()) for payload := range ray.InboundOutput() { - udpConn.WriteToUDP(payload.Value, addr) + this.udpMutex.RLock() + if !this.accepting { + this.udpMutex.RUnlock() + return + } + this.udpConn.WriteToUDP(payload.Value, addr) + this.udpMutex.RUnlock() } } } @@ -126,8 +132,11 @@ func (this *DokodemoDoor) ListenTCP(port v2net.Port) error { func (this *DokodemoDoor) AcceptTCPConnections() { for this.accepting { retry.Timed(100, 100).On(func() error { - this.Lock() - defer this.Unlock() + if !this.accepting { + return nil + } + this.tcpMutex.RLock() + defer this.tcpMutex.RUnlock() if this.tcpListener != nil { connection, err := this.tcpListener.AcceptTCP() if err != nil { diff --git a/proxy/http/http.go b/proxy/http/http.go index 69c9b03c7..dedff2f50 100644 --- a/proxy/http/http.go +++ b/proxy/http/http.go @@ -35,8 +35,8 @@ func NewHttpProxyServer(space app.Space, config Config) *HttpProxyServer { func (this *HttpProxyServer) Close() { this.accepting = false if this.tcpListener != nil { - this.Lock() this.tcpListener.Close() + this.Lock() this.tcpListener = nil this.Unlock() } diff --git a/proxy/socks/socks.go b/proxy/socks/socks.go index 44a34fdd3..c304eab2e 100644 --- a/proxy/socks/socks.go +++ b/proxy/socks/socks.go @@ -23,7 +23,8 @@ var ( // SocksServer is a SOCKS 5 proxy server type SocksServer struct { - sync.RWMutex + tcpMutex sync.RWMutex + udpMutex sync.RWMutex accepting bool space app.Space config Config @@ -42,20 +43,16 @@ func NewSocksServer(space app.Space, config Config) *SocksServer { func (this *SocksServer) Close() { this.accepting = false if this.tcpListener != nil { - this.Lock() - if this.tcpListener != nil { - this.tcpListener.Close() - this.tcpListener = nil - } - this.Unlock() + this.tcpListener.Close() + this.tcpMutex.Lock() + this.tcpListener = nil + this.tcpMutex.Unlock() } if this.udpConn != nil { - this.Lock() - if this.udpConn != nil { - this.udpConn.Close() - this.udpConn = nil - } - this.Unlock() + this.udpConn.Close() + this.udpMutex.Lock() + this.udpConn = nil + this.udpMutex.Unlock() } } @@ -81,12 +78,16 @@ func (this *SocksServer) Listen(port v2net.Port) error { func (this *SocksServer) AcceptConnections() { for this.accepting { retry.Timed(100 /* times */, 100 /* ms */).On(func() error { - this.RLock() - defer this.RUnlock() if !this.accepting { return nil } + this.tcpMutex.RLock() + if this.tcpListener == nil { + this.tcpMutex.RUnlock() + return nil + } connection, err := this.tcpListener.AcceptTCP() + this.tcpMutex.RUnlock() if err != nil { log.Error("Socks failed to accept new connection %v", err) return err diff --git a/proxy/socks/udp.go b/proxy/socks/udp.go index 3d2dce3e5..ee69a6740 100644 --- a/proxy/socks/udp.go +++ b/proxy/socks/udp.go @@ -30,13 +30,13 @@ func (this *SocksServer) ListenUDP(port v2net.Port) error { func (this *SocksServer) AcceptPackets() error { for this.accepting { buffer := alloc.NewBuffer() - this.RLock() + this.udpMutex.RLock() if !this.accepting { - this.RUnlock() + this.udpMutex.RUnlock() return nil } nBytes, addr, err := this.udpConn.ReadFromUDP(buffer.Value) - this.RUnlock() + this.udpMutex.RUnlock() if err != nil { log.Error("Socks failed to read UDP packets: %v", err) buffer.Release() @@ -82,13 +82,13 @@ func (this *SocksServer) handlePacket(packet v2net.Packet, clientAddr *net.UDPAd udpMessage := alloc.NewSmallBuffer().Clear() response.Write(udpMessage) - this.RLock() + this.udpMutex.RLock() if !this.accepting { - this.RUnlock() + this.udpMutex.RUnlock() return } nBytes, err := this.udpConn.WriteToUDP(udpMessage.Value, clientAddr) - this.RUnlock() + this.udpMutex.RUnlock() udpMessage.Release() response.Data.Release() if err != nil { diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 36c40fba4..4affb9619 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -38,11 +38,9 @@ func NewVMessInboundHandler(space app.Space, clients user.UserSet) *VMessInbound func (this *VMessInboundHandler) Close() { this.accepting = false if this.listener != nil { + this.listener.Close() this.Lock() - if this.listener != nil { - this.listener.Close() - this.listener = nil - } + this.listener = nil this.Unlock() } } diff --git a/shell/point/inbound_detour.go b/shell/point/inbound_detour.go index 912dec687..0803fc0cb 100644 --- a/shell/point/inbound_detour.go +++ b/shell/point/inbound_detour.go @@ -39,6 +39,12 @@ func (this *InboundDetourHandler) Initialize() error { return nil } +func (this *InboundDetourHandler) Close() { + for _, ich := range this.ich { + ich.handler.Close() + } +} + // Starts the inbound connection handler. func (this *InboundDetourHandler) Start() error { for _, ich := range this.ich { diff --git a/shell/point/point.go b/shell/point/point.go index ed4ceffda..f42134089 100644 --- a/shell/point/point.go +++ b/shell/point/point.go @@ -113,6 +113,13 @@ func NewPoint(pConfig PointConfig) (*Point, error) { return vpoint, nil } +func (this *Point) Close() { + this.ich.Close() + for _, idh := range this.idh { + idh.Close() + } +} + // Start starts the Point server, and return any error during the process. // In the case of any errors, the state of the server is unpredicatable. func (this *Point) Start() error { diff --git a/testing/scenarios/dokodemo_test.go b/testing/scenarios/dokodemo_test.go index b581c3692..3c4958db5 100644 --- a/testing/scenarios/dokodemo_test.go +++ b/testing/scenarios/dokodemo_test.go @@ -50,4 +50,6 @@ func TestDokodemoTCP(t *testing.T) { assert.StringLiteral("Processed: " + payload).Equals(string(response[:nBytes])) conn.Close() } + + CloseAllServers() } diff --git a/testing/scenarios/router_test.go b/testing/scenarios/router_test.go index 3226fd4b6..7313a0722 100644 --- a/testing/scenarios/router_test.go +++ b/testing/scenarios/router_test.go @@ -76,4 +76,6 @@ func TestRouter(t *testing.T) { assert.Int(nBytes).Equals(0) assert.Bool(tcpServer2Accessed).IsFalse() conn.Close() + + CloseAllServers() } diff --git a/testing/scenarios/server_env.go b/testing/scenarios/server_env.go index 88f4d3fa5..e6db94f03 100644 --- a/testing/scenarios/server_env.go +++ b/testing/scenarios/server_env.go @@ -27,7 +27,7 @@ import ( ) var ( - serverup = make(map[string]bool) + runningServers = make([]*point.Point, 0, 10) ) func TestFile(filename string) string { @@ -35,9 +35,6 @@ func TestFile(filename string) string { } func InitializeServerSetOnce(testcase string) error { - if up, found := serverup[testcase]; found && up { - return nil - } err := InitializeServer(TestFile(testcase + "_server.json")) if err != nil { return err @@ -46,7 +43,6 @@ func InitializeServerSetOnce(testcase string) error { if err != nil { return err } - serverup[testcase] = true return nil } @@ -68,6 +64,14 @@ func InitializeServer(configFile string) error { log.Error("Error starting Point server: %v", err) return err } + runningServers = append(runningServers, vPoint) return nil } + +func CloseAllServers() { + for _, server := range runningServers { + server.Close() + } + runningServers = make([]*point.Point, 0, 10) +} diff --git a/testing/scenarios/socks_end_test.go b/testing/scenarios/socks_end_test.go index e0a36c95b..e5b69c764 100644 --- a/testing/scenarios/socks_end_test.go +++ b/testing/scenarios/socks_end_test.go @@ -12,10 +12,6 @@ import ( "github.com/v2ray/v2ray-core/testing/servers/udp" ) -var ( - serverUp = false -) - func TestTCPConnection(t *testing.T) { v2testing.Current(t) @@ -86,6 +82,8 @@ func TestTCPConnection(t *testing.T) { conn.Close() } + + CloseAllServers() } func TestTCPBind(t *testing.T) { @@ -135,6 +133,8 @@ func TestTCPBind(t *testing.T) { assert.Bytes(connectResponse[:nBytes]).Equals([]byte{socks5Version, 7, 0, 1, 0, 0, 0, 0, 0, 0}) conn.Close() + + CloseAllServers() } func TestUDPAssociate(t *testing.T) { @@ -204,4 +204,6 @@ func TestUDPAssociate(t *testing.T) { udpConn.Close() conn.Close() + + CloseAllServers() }