From eb05a92592d975be17282aabc1aab3d1e1cd111f Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Wed, 29 Aug 2018 23:00:01 +0200 Subject: [PATCH] dont start periodic task until necessary --- app/dns/server.go | 7 +-- app/dns/udpns.go | 12 +++-- app/proxyman/inbound/worker.go | 53 +++++++++++---------- app/router/condition.go | 8 +--- common/signal/pubsub/pubsub.go | 7 ++- common/strmatcher/benchmark_test.go | 17 ------- common/strmatcher/strmatcher.go | 73 ----------------------------- common/task/periodic.go | 44 ++++++++--------- common/task/periodic_test.go | 6 +++ proxy/vmess/encoding/server.go | 16 +++---- 10 files changed, 81 insertions(+), 162 deletions(-) diff --git a/app/dns/server.go b/app/dns/server.go index 7ffd39a90..bdd029444 100644 --- a/app/dns/server.go +++ b/app/dns/server.go @@ -81,12 +81,7 @@ func New(ctx context.Context, config *Config) (*Server, error) { } } - if domainMatcher.Size() > 64 { - server.domainMatcher = strmatcher.NewCachedMatcherGroup(domainMatcher) - } else { - server.domainMatcher = domainMatcher - } - + server.domainMatcher = domainMatcher server.domainIndexMap = domainIndexMap } diff --git a/app/dns/udpns.go b/app/dns/udpns.go index 5cc89e66d..5c4a972c3 100644 --- a/app/dns/udpns.go +++ b/app/dns/udpns.go @@ -59,13 +59,18 @@ func NewClassicNameServer(address net.Destination, dispatcher core.Dispatcher, c Execute: s.Cleanup, } s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse) - common.Must(s.cleanup.Start()) return s } func (s *ClassicNameServer) Cleanup() error { now := time.Now() s.Lock() + defer s.Unlock() + + if len(s.ips) == 0 && len(s.requests) == 0 { + return newError("nothing to do. stopping...") + } + for domain, ips := range s.ips { newIPs := make([]IPRecord, 0, len(ips)) for _, ip := range ips { @@ -94,7 +99,6 @@ func (s *ClassicNameServer) Cleanup() error { s.requests = make(map[uint16]pendingRequest) } - s.Unlock() return nil } @@ -151,7 +155,6 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, payload *buf.Buf func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) { s.Lock() - defer s.Unlock() newError("updating IP records for domain:", domain).AtDebug().WriteToLog() now := time.Now() @@ -163,6 +166,9 @@ func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) { } s.ips[domain] = ips s.pub.Publish(domain, nil) + + s.Unlock() + common.Must(s.cleanup.Start()) } func (s *ClassicNameServer) getMsgOptions() *dns.OPT { diff --git a/app/proxyman/inbound/worker.go b/app/proxyman/inbound/worker.go index 198b8b8fd..6695876e3 100644 --- a/app/proxyman/inbound/worker.go +++ b/app/proxyman/inbound/worker.go @@ -289,6 +289,8 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest } if !existing { + common.Must(w.checker.Start()) + go func() { ctx := context.Background() sid := session.NewID() @@ -324,40 +326,41 @@ func (w *udpWorker) handlePackets() { } } +func (w *udpWorker) clean() error { + nowSec := time.Now().Unix() + w.Lock() + defer w.Unlock() + + if len(w.activeConn) == 0 { + return newError("no more connections. stopping...") + } + + for addr, conn := range w.activeConn { + if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 8 { + delete(w.activeConn, addr) + conn.Close() // nolint: errcheck + } + } + + if len(w.activeConn) == 0 { + w.activeConn = make(map[connID]*udpConn, 16) + } + + return nil +} + func (w *udpWorker) Start() error { w.activeConn = make(map[connID]*udpConn, 16) h, err := udp.ListenUDP(w.address, w.port, udp.HubReceiveOriginalDestination(w.recvOrigDest), udp.HubCapacity(256)) if err != nil { return err } + w.checker = &task.Periodic{ Interval: time.Second * 16, - Execute: func() error { - nowSec := time.Now().Unix() - w.Lock() - defer w.Unlock() - - if len(w.activeConn) == 0 { - return nil - } - - for addr, conn := range w.activeConn { - if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 8 { - delete(w.activeConn, addr) - conn.Close() // nolint: errcheck - } - } - - if len(w.activeConn) == 0 { - w.activeConn = make(map[connID]*udpConn, 16) - } - - return nil - }, - } - if err := w.checker.Start(); err != nil { - return err + Execute: w.clean, } + w.hub = h go w.handlePackets() return nil diff --git a/app/router/condition.go b/app/router/condition.go index 9e1cd272a..9fcbbf33c 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -100,14 +100,8 @@ func NewDomainMatcher(domains []*Domain) (*DomainMatcher, error) { g.Add(m) } - if len(domains) < 64 { - return &DomainMatcher{ - matchers: g, - }, nil - } - return &DomainMatcher{ - matchers: strmatcher.NewCachedMatcherGroup(g), + matchers: g, }, nil } diff --git a/common/signal/pubsub/pubsub.go b/common/signal/pubsub/pubsub.go index 7306ef2f9..e71ed8698 100644 --- a/common/signal/pubsub/pubsub.go +++ b/common/signal/pubsub/pubsub.go @@ -1,6 +1,7 @@ package pubsub import ( + "errors" "sync" "time" @@ -47,7 +48,6 @@ func NewService() *Service { Execute: s.Cleanup, Interval: time.Second * 30, } - common.Must(s.ctask.Start()) return s } @@ -57,6 +57,10 @@ func (s *Service) Cleanup() error { s.Lock() defer s.Unlock() + if len(s.subs) == 0 { + return errors.New("nothing to do") + } + for name, subs := range s.subs { newSub := make([]*Subscriber, 0, len(s.subs)) for _, sub := range subs { @@ -86,6 +90,7 @@ func (s *Service) Subscribe(name string) *Subscriber { subs := append(s.subs[name], sub) s.subs[name] = subs s.Unlock() + common.Must(s.ctask.Start()) return sub } diff --git a/common/strmatcher/benchmark_test.go b/common/strmatcher/benchmark_test.go index 1e7ff5563..de5f5f626 100644 --- a/common/strmatcher/benchmark_test.go +++ b/common/strmatcher/benchmark_test.go @@ -47,20 +47,3 @@ func BenchmarkMarchGroup(b *testing.B) { _ = g.Match("0.v2ray.com") } } - -func BenchmarkCachedMarchGroup(b *testing.B) { - g := new(MatcherGroup) - for i := 1; i <= 1024; i++ { - m, err := Domain.New(strconv.Itoa(i) + ".v2ray.com") - common.Must(err) - g.Add(m) - } - - cg := NewCachedMatcherGroup(g) - _ = cg.Match("0.v2ray.com") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = cg.Match("0.v2ray.com") - } -} diff --git a/common/strmatcher/strmatcher.go b/common/strmatcher/strmatcher.go index 7f4eac527..fb63eda56 100644 --- a/common/strmatcher/strmatcher.go +++ b/common/strmatcher/strmatcher.go @@ -2,11 +2,6 @@ package strmatcher import ( "regexp" - "sync" - "time" - - "v2ray.com/core/common" - "v2ray.com/core/common/task" ) // Matcher is the interface to determine a string matches a pattern. @@ -114,71 +109,3 @@ func (g *MatcherGroup) Match(pattern string) uint32 { func (g *MatcherGroup) Size() uint32 { return g.count } - -type cacheEntry struct { - timestamp time.Time - result uint32 -} - -// CachedMatcherGroup is a IndexMatcher with cachable results. -type CachedMatcherGroup struct { - sync.RWMutex - group *MatcherGroup - cache map[string]cacheEntry - cleanup *task.Periodic -} - -// NewCachedMatcherGroup creats a new CachedMatcherGroup. -func NewCachedMatcherGroup(g *MatcherGroup) *CachedMatcherGroup { - r := &CachedMatcherGroup{ - group: g, - cache: make(map[string]cacheEntry), - } - r.cleanup = &task.Periodic{ - Interval: time.Second * 30, - Execute: func() error { - r.Lock() - defer r.Unlock() - - if len(r.cache) == 0 { - return nil - } - - expire := time.Now().Add(-1 * time.Second * 120) - for p, e := range r.cache { - if e.timestamp.Before(expire) { - delete(r.cache, p) - } - } - - if len(r.cache) == 0 { - r.cache = make(map[string]cacheEntry) - } - - return nil - }, - } - common.Must(r.cleanup.Start()) - return r -} - -// Match implements IndexMatcher.Match. -func (g *CachedMatcherGroup) Match(pattern string) uint32 { - g.RLock() - r, f := g.cache[pattern] - g.RUnlock() - if f { - return r.result - } - - mr := g.group.Match(pattern) - - g.Lock() - g.cache[pattern] = cacheEntry{ - result: mr, - timestamp: time.Now(), - } - g.Unlock() - - return mr -} diff --git a/common/task/periodic.go b/common/task/periodic.go index 305a4933f..23f559c1a 100644 --- a/common/task/periodic.go +++ b/common/task/periodic.go @@ -11,25 +11,17 @@ type Periodic struct { Interval time.Duration // Execute is the task function Execute func() error - // OnFailure will be called when Execute returns non-nil error - OnError func(error) - access sync.RWMutex - timer *time.Timer - closed bool -} - -func (t *Periodic) setClosed(f bool) { - t.access.Lock() - t.closed = f - t.access.Unlock() + access sync.Mutex + timer *time.Timer + running bool } func (t *Periodic) hasClosed() bool { - t.access.RLock() - defer t.access.RUnlock() + t.access.Lock() + defer t.access.Unlock() - return t.closed + return !t.running } func (t *Periodic) checkedExecute() error { @@ -38,31 +30,39 @@ func (t *Periodic) checkedExecute() error { } if err := t.Execute(); err != nil { + t.access.Lock() + t.running = false + t.access.Unlock() return err } t.access.Lock() defer t.access.Unlock() - if t.closed { + if !t.running { return nil } t.timer = time.AfterFunc(t.Interval, func() { - if err := t.checkedExecute(); err != nil && t.OnError != nil { - t.OnError(err) - } + t.checkedExecute() // nolint: errcheck }) return nil } -// Start implements common.Runnable. Start must not be called multiple times without Close being called. +// Start implements common.Runnable. func (t *Periodic) Start() error { - t.setClosed(false) + t.access.Lock() + if t.running { + return nil + } + t.running = true + t.access.Unlock() if err := t.checkedExecute(); err != nil { - t.setClosed(true) + t.access.Lock() + t.running = false + t.access.Unlock() return err } @@ -74,7 +74,7 @@ func (t *Periodic) Close() error { t.access.Lock() defer t.access.Unlock() - t.closed = true + t.running = false if t.timer != nil { t.timer.Stop() t.timer = nil diff --git a/common/task/periodic_test.go b/common/task/periodic_test.go index 1abfa1b95..a0bc5c6c7 100644 --- a/common/task/periodic_test.go +++ b/common/task/periodic_test.go @@ -27,4 +27,10 @@ func TestPeriodicTaskStop(t *testing.T) { assert(value, Equals, 3) time.Sleep(time.Second * 4) assert(value, Equals, 3) + common.Must(task.Start()) + time.Sleep(time.Second * 3) + if value != 5 { + t.Fatal("Expected 5, but ", value) + } + common.Must(task.Close()) } diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 1c027cea5..42117b58c 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -42,12 +42,8 @@ func NewSessionHistory() *SessionHistory { } h.task = &task.Periodic{ Interval: time.Second * 30, - Execute: func() error { - h.removeExpiredEntries() - return nil - }, + Execute: h.removeExpiredEntries, } - common.Must(h.task.Start()) return h } @@ -58,24 +54,26 @@ func (h *SessionHistory) Close() error { func (h *SessionHistory) addIfNotExits(session sessionId) bool { h.Lock() - defer h.Unlock() if expire, found := h.cache[session]; found && expire.After(time.Now()) { + h.Unlock() return false } h.cache[session] = time.Now().Add(time.Minute * 3) + h.Unlock() + common.Must(h.task.Start()) return true } -func (h *SessionHistory) removeExpiredEntries() { +func (h *SessionHistory) removeExpiredEntries() error { now := time.Now() h.Lock() defer h.Unlock() if len(h.cache) == 0 { - return + return newError("nothing to do") } for session, expire := range h.cache { @@ -87,6 +85,8 @@ func (h *SessionHistory) removeExpiredEntries() { if len(h.cache) == 0 { h.cache = make(map[sessionId]time.Time, 128) } + + return nil } // ServerSession keeps information for a session in VMess server.