diff --git a/app/stats/command/command.go b/app/stats/command/command.go index 0423e5330..21c854465 100644 --- a/app/stats/command/command.go +++ b/app/stats/command/command.go @@ -63,7 +63,7 @@ func (s *statsServer) QueryStats(ctx context.Context, request *QueryStatsRequest return nil, newError("QueryStats only works its own stats.Manager.") } - manager.Visit(func(name string, c feature_stats.Counter) bool { + manager.VisitCounters(func(name string, c feature_stats.Counter) bool { if matcher.Match(name) { var value int64 if request.Reset_ { diff --git a/app/stats/stats.go b/app/stats/stats.go index d06cc39c5..7bb56b287 100644 --- a/app/stats/stats.go +++ b/app/stats/stats.go @@ -8,6 +8,7 @@ import ( "context" "sync" "sync/atomic" + "time" "v2ray.com/core/features/stats" ) @@ -32,15 +33,76 @@ func (c *Counter) Add(delta int64) int64 { return atomic.AddInt64(&c.value, delta) } +// Channel is an implementation of stats.Channel +type Channel struct { + channel chan interface{} + subscribers []chan interface{} + access sync.RWMutex +} + +// Channel implements stats.Channel +func (c *Channel) Channel() chan interface{} { + return c.channel +} + +// Subscribers implements stats.Channel +func (c *Channel) Subscribers() []chan interface{} { + c.access.RLock() + defer c.access.RUnlock() + return c.subscribers +} + +// Subscribe implements stats.Channel +func (c *Channel) Subscribe() chan interface{} { + c.access.Lock() + defer c.access.Unlock() + ch := make(chan interface{}) + c.subscribers = append(c.subscribers, ch) + return ch +} + +// Unsubscribe implements stats.Channel +func (c *Channel) Unsubscribe(ch chan interface{}) { + c.access.Lock() + defer c.access.Unlock() + for i, s := range c.subscribers { + if s == ch { + // Copy to new memory block to prevent modifying original data + subscribers := make([]chan interface{}, len(c.subscribers)-1) + copy(subscribers[:i], c.subscribers[:i]) + copy(subscribers[i:], c.subscribers[i+1:]) + c.subscribers = subscribers + return + } + } +} + +// Start starts the channel for listening to messsages +func (c *Channel) Start() { + for message := range c.Channel() { + subscribers := c.Subscribers() // Store a copy of slice value for concurrency safety + for _, sub := range subscribers { + select { + case sub <- message: // Successfully sent message + case <-time.After(100 * time.Millisecond): + c.Unsubscribe(sub) // Remove timeout subscriber + close(sub) // Actively close subscriber as notification + } + } + } +} + // Manager is an implementation of stats.Manager. type Manager struct { access sync.RWMutex counters map[string]*Counter + channels map[string]*Channel } func NewManager(ctx context.Context, config *Config) (*Manager, error) { m := &Manager{ counters: make(map[string]*Counter), + channels: make(map[string]*Channel), } return m, nil @@ -50,6 +112,7 @@ func (*Manager) Type() interface{} { return stats.ManagerType() } +// RegisterCounter implements stats.Manager. func (m *Manager) RegisterCounter(name string) (stats.Counter, error) { m.access.Lock() defer m.access.Unlock() @@ -63,6 +126,7 @@ func (m *Manager) RegisterCounter(name string) (stats.Counter, error) { return c, nil } +// GetCounter implements stats.Manager. func (m *Manager) GetCounter(name string) stats.Counter { m.access.RLock() defer m.access.RUnlock() @@ -73,7 +137,8 @@ func (m *Manager) GetCounter(name string) stats.Counter { return nil } -func (m *Manager) Visit(visitor func(string, stats.Counter) bool) { +// VisitCounters calls visitor function on all managed counters. +func (m *Manager) VisitCounters(visitor func(string, stats.Counter) bool) { m.access.RLock() defer m.access.RUnlock() @@ -84,6 +149,32 @@ func (m *Manager) Visit(visitor func(string, stats.Counter) bool) { } } +// RegisterChannel implements stats.Manager. +func (m *Manager) RegisterChannel(name string) (stats.Channel, error) { + m.access.Lock() + defer m.access.Unlock() + + if _, found := m.channels[name]; found { + return nil, newError("Channel ", name, " already registered.") + } + newError("create new channel ", name).AtDebug().WriteToLog() + c := &Channel{channel: make(chan interface{})} + m.channels[name] = c + go c.Start() + return c, nil +} + +// GetChannel implements stats.Manager. +func (m *Manager) GetChannel(name string) stats.Channel { + m.access.RLock() + defer m.access.RUnlock() + + if c, found := m.channels[name]; found { + return c + } + return nil +} + // Start implements common.Runnable. func (m *Manager) Start() error { return nil diff --git a/app/stats/stats_test.go b/app/stats/stats_test.go index 2a0d9b91a..0c724257b 100644 --- a/app/stats/stats_test.go +++ b/app/stats/stats_test.go @@ -2,14 +2,16 @@ package stats_test import ( "context" + "fmt" "testing" + "time" . "v2ray.com/core/app/stats" "v2ray.com/core/common" "v2ray.com/core/features/stats" ) -func TestInternface(t *testing.T) { +func TestInterface(t *testing.T) { _ = (stats.Manager)(new(Manager)) } @@ -33,3 +35,317 @@ func TestStatsCounter(t *testing.T) { t.Fatal("unexpected Value() return: ", v, ", wanted ", 0) } } + +func TestStatsChannel(t *testing.T) { + raw, err := common.CreateObject(context.Background(), &Config{}) + common.Must(err) + + m := raw.(stats.Manager) + c, err := m.RegisterChannel("test.channel") + common.Must(err) + + source := c.Channel() + a := c.Subscribe() + b := c.Subscribe() + defer c.Unsubscribe(a) + defer c.Unsubscribe(b) + + stopCh := make(chan struct{}) + errCh := make(chan string) + + go func() { + source <- 1 + source <- 2 + source <- "3" + source <- []int{4} + source <- nil // Dummy messsage with no subscriber receiving + select { + case source <- nil: // Source should be blocked here, for last message was not cleared + errCh <- fmt.Sprint("unexpected non-blocked source") + default: + close(stopCh) + } + }() + + go func() { + if v, ok := (<-a).(int); !ok || v != 1 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + } + if v, ok := (<-a).(int); !ok || v != 2 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) + } + if v, ok := (<-a).(string); !ok || v != "3" { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", "3") + } + if v, ok := (<-a).([]int); !ok || v[0] != 4 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", []int{4}) + } + }() + + go func() { + if v, ok := (<-b).(int); !ok || v != 1 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + } + if v, ok := (<-b).(int); !ok || v != 2 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) + } + if v, ok := (<-b).(string); !ok || v != "3" { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", "3") + } + if v, ok := (<-b).([]int); !ok || v[0] != 4 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", []int{4}) + } + }() + + select { + case <-time.After(2 * time.Second): + t.Fatal("Test timeout after 2s") + case e := <-errCh: + t.Fatal(e) + case <-stopCh: + } +} + +func TestStatsChannelUnsubcribe(t *testing.T) { + raw, err := common.CreateObject(context.Background(), &Config{}) + common.Must(err) + + m := raw.(stats.Manager) + c, err := m.RegisterChannel("test.channel") + common.Must(err) + + source := c.Channel() + a := c.Subscribe() + b := c.Subscribe() + defer c.Unsubscribe(a) + + pauseCh := make(chan struct{}) + stopCh := make(chan struct{}) + errCh := make(chan string) + + { + var aSet, bSet bool + for _, s := range c.Subscribers() { + if s == a { + aSet = true + } + if s == b { + bSet = true + } + } + if !(aSet && bSet) { + t.Fatal("unexpected subscribers: ", c.Subscribers()) + } + } + + go func() { + source <- 1 + <-pauseCh // Wait for `b` goroutine to resume sending message + source <- 2 + }() + + go func() { + if v, ok := (<-a).(int); !ok || v != 1 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + } + if v, ok := (<-a).(int); !ok || v != 2 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) + } + }() + + go func() { + if v, ok := (<-b).(int); !ok || v != 1 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + } + // Unsubscribe `b` while `source`'s messaging is paused + c.Unsubscribe(b) + { // Test `b` is not in subscribers + var aSet, bSet bool + for _, s := range c.Subscribers() { + if s == a { + aSet = true + } + if s == b { + bSet = true + } + } + if !(aSet && !bSet) { + errCh <- fmt.Sprint("unexpected subscribers: ", c.Subscribers()) + } + } + // Resume `source`'s progress + close(pauseCh) + // Test `b` is neither closed nor able to receive any data + select { + case v, ok := <-b: + if ok { + errCh <- fmt.Sprint("unexpected data received: ", v) + } else { + errCh <- fmt.Sprint("unexpected closed channel: ", b) + } + default: + } + close(stopCh) + }() + + select { + case <-time.After(2 * time.Second): + t.Fatal("Test timeout after 2s") + case e := <-errCh: + t.Fatal(e) + case <-stopCh: + } +} + +func TestStatsChannelTimeout(t *testing.T) { + raw, err := common.CreateObject(context.Background(), &Config{}) + common.Must(err) + + m := raw.(stats.Manager) + c, err := m.RegisterChannel("test.channel") + common.Must(err) + + source := c.Channel() + a := c.Subscribe() + b := c.Subscribe() + defer c.Unsubscribe(a) + defer c.Unsubscribe(b) + + stopCh := make(chan struct{}) + errCh := make(chan string) + + go func() { + source <- 1 + source <- 2 + }() + + go func() { + if v, ok := (<-a).(int); !ok || v != 1 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + } + if v, ok := (<-a).(int); !ok || v != 2 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) + } + { // Test `b` is still in subscribers yet (because `a` receives 2 first) + var aSet, bSet bool + for _, s := range c.Subscribers() { + if s == a { + aSet = true + } + if s == b { + bSet = true + } + } + if !(aSet && bSet) { + errCh <- fmt.Sprint("unexpected subscribers: ", c.Subscribers()) + } + } + }() + + go func() { + if v, ok := (<-b).(int); !ok || v != 1 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + } + // Block `b` channel for a time longer than `source`'s timeout + <-time.After(150 * time.Millisecond) + { // Test `b` has been unsubscribed by source + var aSet, bSet bool + for _, s := range c.Subscribers() { + if s == a { + aSet = true + } + if s == b { + bSet = true + } + } + if !(aSet && !bSet) { + errCh <- fmt.Sprint("unexpected subscribers: ", c.Subscribers()) + } + } + select { // Test `b` has been closed by source + case v, ok := <-b: + if ok { + errCh <- fmt.Sprint("unexpected data received: ", v) + } + default: + } + close(stopCh) + }() + + select { + case <-time.After(2 * time.Second): + t.Fatal("Test timeout after 2s") + case e := <-errCh: + t.Fatal(e) + case <-stopCh: + } +} + +func TestStatsChannelConcurrency(t *testing.T) { + raw, err := common.CreateObject(context.Background(), &Config{}) + common.Must(err) + + m := raw.(stats.Manager) + c, err := m.RegisterChannel("test.channel") + common.Must(err) + + source := c.Channel() + a := c.Subscribe() + b := c.Subscribe() + defer c.Unsubscribe(a) + + stopCh := make(chan struct{}) + errCh := make(chan string) + + go func() { + source <- 1 + source <- 2 + }() + + go func() { + if v, ok := (<-a).(int); !ok || v != 1 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 1) + } + if v, ok := (<-a).(int); !ok || v != 2 { + errCh <- fmt.Sprint("unexpected receiving: ", v, ", wanted ", 2) + } + }() + + go func() { + // Block `b` for a time shorter than `source`'s timeout + // So as to ensure source channel is trying to send message to `b`. + <-time.After(25 * time.Millisecond) + // This causes concurrency scenario: unsubscribe `b` while trying to send message to it + c.Unsubscribe(b) + // Test `b` is not closed and can still receive data 1: + // Because unsubscribe won't affect the ongoing process of sending message. + select { + case v, ok := <-b: + if v1, ok1 := v.(int); !(ok && ok1 && v1 == 1) { + errCh <- fmt.Sprint("unexpected failure in receiving data: ", 1) + } + default: + errCh <- fmt.Sprint("unexpected block from receiving data: ", 1) + } + // Test `b` is not closed but cannot receive data 2: + // Becuase in a new round of messaging, `b` has been unsubscribed. + select { + case v, ok := <-b: + if ok { + errCh <- fmt.Sprint("unexpected receving: ", v) + } else { + errCh <- fmt.Sprint("unexpected closing of channel") + } + default: + } + close(stopCh) + }() + + select { + case <-time.After(2 * time.Second): + t.Fatal("Test timeout after 2s") + case e := <-errCh: + t.Fatal(e) + case <-stopCh: + } +} diff --git a/features/stats/stats.go b/features/stats/stats.go index b6d99620f..2a0acec2b 100644 --- a/features/stats/stats.go +++ b/features/stats/stats.go @@ -16,6 +16,20 @@ type Counter interface { Add(int64) int64 } +// Channel is the interface for stats channel +// +// v2ray:api:stable +type Channel interface { + // Channel returns the underlying go channel. + Channel() chan interface{} + // SubscriberCount returns the number of the subscribers. + Subscribers() []chan interface{} + // Subscribe registers for listening to channel stream and returns a new listener channel. + Subscribe() chan interface{} + // Unsubscribe unregisters a listener channel from current Channel object. + Unsubscribe(chan interface{}) +} + // Manager is the interface for stats manager. // // v2ray:api:stable @@ -26,6 +40,11 @@ type Manager interface { RegisterCounter(string) (Counter, error) // GetCounter returns a counter by its identifier. GetCounter(string) Counter + + // RegisterChannel registers a new channel to the manager. The identifier string must not be empty, and unique among other channels. + RegisterChannel(string) (Channel, error) + // GetChannel returns a channel by its identifier. + GetChannel(string) Channel } // GetOrRegisterCounter tries to get the StatCounter first. If not exist, it then tries to create a new counter. @@ -38,6 +57,16 @@ func GetOrRegisterCounter(m Manager, name string) (Counter, error) { return m.RegisterCounter(name) } +// GetOrRegisterChannel tries to get the StatChannel first. If not exist, it then tries to create a new channel. +func GetOrRegisterChannel(m Manager, name string) (Channel, error) { + channel := m.GetChannel(name) + if channel != nil { + return channel, nil + } + + return m.RegisterChannel(name) +} + // ManagerType returns the type of Manager interface. Can be used to implement common.HasType. // // v2ray:api:stable @@ -63,6 +92,16 @@ func (NoopManager) GetCounter(string) Counter { return nil } +// RegisterChannel implements Manager. +func (NoopManager) RegisterChannel(string) (Channel, error) { + return nil, newError("not implemented") +} + +// GetChannel implements Manager. +func (NoopManager) GetChannel(string) Channel { + return nil +} + // Start implements common.Runnable. func (NoopManager) Start() error { return nil }