From 29ad2cbbdb4445b1a8d554d102ef2ac9c58655dd Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Thu, 12 Jul 2018 23:38:10 +0200 Subject: [PATCH] function to compare byte array --- common/compare/bytes.go | 29 +++++++++++++++++++ common/compare/bytes_test.go | 43 ++++++++++++++++++++++++++++ common/net/address.go | 4 +-- common/peer/latency.go | 30 +++++++++++++++++++ common/peer/peer.go | 1 + common/predicate/arrays.go | 10 ------- common/predicate/predicate.go | 39 ------------------------- common/protocol/id.go | 5 ++-- common/protocol/id_test.go | 4 +-- proxy/mtproto/server.go | 4 +-- testing/scenarios/command_test.go | 20 +++++++++---- transport/internet/kcp/connection.go | 13 +++++---- 12 files changed, 132 insertions(+), 70 deletions(-) create mode 100644 common/compare/bytes.go create mode 100644 common/compare/bytes_test.go create mode 100644 common/peer/latency.go create mode 100644 common/peer/peer.go delete mode 100644 common/predicate/arrays.go delete mode 100644 common/predicate/predicate.go diff --git a/common/compare/bytes.go b/common/compare/bytes.go new file mode 100644 index 000000000..fe7202bae --- /dev/null +++ b/common/compare/bytes.go @@ -0,0 +1,29 @@ +package compare + +import "v2ray.com/core/common/errors" + +func BytesEqualWithDetail(a []byte, b []byte) error { + if len(a) != len(b) { + return errors.New("mismatch array length ", len(a), " vs ", len(b)) + } + for idx, v := range a { + if b[idx] != v { + return errors.New("mismatch array value at index [", idx, "]: ", v, " vs ", b[idx]) + } + } + return nil +} + +func BytesEqual(a []byte, b []byte) bool { + return BytesEqualWithDetail(a, b) == nil +} + +func BytesAll(arr []byte, value byte) bool { + for _, v := range arr { + if v != value { + return false + } + } + + return true +} diff --git a/common/compare/bytes_test.go b/common/compare/bytes_test.go new file mode 100644 index 000000000..21dffa196 --- /dev/null +++ b/common/compare/bytes_test.go @@ -0,0 +1,43 @@ +package compare_test + +import ( + "testing" + + . "v2ray.com/core/common/compare" +) + +func TestBytesEqual(t *testing.T) { + testCases := []struct { + Input1 []byte + Input2 []byte + Result bool + }{ + { + Input1: []byte{}, + Input2: []byte{1}, + Result: false, + }, + { + Input1: nil, + Input2: []byte{}, + Result: true, + }, + { + Input1: []byte{1}, + Input2: []byte{1}, + Result: true, + }, + { + Input1: []byte{1, 2}, + Input2: []byte{1, 3}, + Result: false, + }, + } + + for _, testCase := range testCases { + cmp := BytesEqual(testCase.Input1, testCase.Input2) + if cmp != testCase.Result { + t.Errorf("unexpected result %v from %v", cmp, testCase) + } + } +} diff --git a/common/net/address.go b/common/net/address.go index 04d40f0c7..5918cf415 100644 --- a/common/net/address.go +++ b/common/net/address.go @@ -4,7 +4,7 @@ import ( "net" "strings" - "v2ray.com/core/common/predicate" + "v2ray.com/core/common/compare" ) var ( @@ -94,7 +94,7 @@ func IPAddress(ip []byte) Address { var addr ipv4Address = [4]byte{ip[0], ip[1], ip[2], ip[3]} return addr case net.IPv6len: - if predicate.BytesAll(ip[0:10], 0) && predicate.BytesAll(ip[10:12], 0xff) { + if compare.BytesAll(ip[0:10], 0) && compare.BytesAll(ip[10:12], 0xff) { return IPAddress(ip[12:16]) } var addr ipv6Address = [16]byte{ diff --git a/common/peer/latency.go b/common/peer/latency.go new file mode 100644 index 000000000..aae292ede --- /dev/null +++ b/common/peer/latency.go @@ -0,0 +1,30 @@ +package peer + +import ( + "sync" +) + +type Latency interface { + Value() uint64 +} + +type HasLatency interface { + ConnectionLatency() Latency + HandshakeLatency() Latency +} + +type AverageLatency struct { + access sync.Mutex + value uint64 +} + +func (al *AverageLatency) Update(newValue uint64) { + al.access.Lock() + defer al.access.Unlock() + + al.value = (al.value + newValue*2) / 3 +} + +func (al *AverageLatency) Value() uint64 { + return al.value +} diff --git a/common/peer/peer.go b/common/peer/peer.go new file mode 100644 index 000000000..333defff1 --- /dev/null +++ b/common/peer/peer.go @@ -0,0 +1 @@ +package peer diff --git a/common/predicate/arrays.go b/common/predicate/arrays.go deleted file mode 100644 index 03e043212..000000000 --- a/common/predicate/arrays.go +++ /dev/null @@ -1,10 +0,0 @@ -package predicate - -func BytesAll(array []byte, b byte) bool { - for _, val := range array { - if val != b { - return false - } - } - return true -} diff --git a/common/predicate/predicate.go b/common/predicate/predicate.go deleted file mode 100644 index fbeb703e4..000000000 --- a/common/predicate/predicate.go +++ /dev/null @@ -1,39 +0,0 @@ -package predicate // import "v2ray.com/core/common/predicate" - -type Predicate func() bool - -func (v Predicate) And(predicate Predicate) Predicate { - return All(v, predicate) -} - -func (v Predicate) Or(predicate Predicate) Predicate { - return Any(v, predicate) -} - -func All(predicates ...Predicate) Predicate { - return func() bool { - for _, p := range predicates { - if !p() { - return false - } - } - return true - } -} - -func Any(predicates ...Predicate) Predicate { - return func() bool { - for _, p := range predicates { - if p() { - return true - } - } - return false - } -} - -func Not(predicate Predicate) Predicate { - return func() bool { - return !predicate() - } -} diff --git a/common/protocol/id.go b/common/protocol/id.go index f5827a971..f7a32e06f 100755 --- a/common/protocol/id.go +++ b/common/protocol/id.go @@ -56,7 +56,7 @@ func NewID(uuid uuid.UUID) *ID { return id } -func nextId(u *uuid.UUID) uuid.UUID { +func nextID(u *uuid.UUID) uuid.UUID { md5hash := md5.New() common.Must2(md5hash.Write(u.Bytes())) common.Must2(md5hash.Write([]byte("16167dc8-16b6-4e6d-b8bb-65dd68113a81"))) @@ -74,8 +74,7 @@ func NewAlterIDs(primary *ID, alterIDCount uint16) []*ID { alterIDs := make([]*ID, alterIDCount) prevID := primary.UUID() for idx := range alterIDs { - newid := nextId(&prevID) - // TODO: check duplicates + newid := nextID(&prevID) alterIDs[idx] = NewID(newid) prevID = newid } diff --git a/common/protocol/id_test.go b/common/protocol/id_test.go index 59b4a0055..a412cef9a 100644 --- a/common/protocol/id_test.go +++ b/common/protocol/id_test.go @@ -3,7 +3,7 @@ package protocol_test import ( "testing" - "v2ray.com/core/common/predicate" + "v2ray.com/core/common/compare" . "v2ray.com/core/common/protocol" "v2ray.com/core/common/uuid" . "v2ray.com/ext/assert" @@ -13,7 +13,7 @@ func TestCmdKey(t *testing.T) { assert := With(t) id := NewID(uuid.New()) - assert(predicate.BytesAll(id.CmdKey(), 0), IsFalse) + assert(compare.BytesAll(id.CmdKey(), 0), IsFalse) } func TestIdEquals(t *testing.T) { diff --git a/proxy/mtproto/server.go b/proxy/mtproto/server.go index 8d16225d4..a6f980057 100644 --- a/proxy/mtproto/server.go +++ b/proxy/mtproto/server.go @@ -7,9 +7,9 @@ import ( "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/buf" + "v2ray.com/core/common/compare" "v2ray.com/core/common/crypto" "v2ray.com/core/common/net" - "v2ray.com/core/common/predicate" "v2ray.com/core/common/protocol" "v2ray.com/core/common/session" "v2ray.com/core/common/signal" @@ -85,7 +85,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet decryptor := crypto.NewAesCTRStream(auth.DecodingKey[:], auth.DecodingNonce[:]) decryptor.XORKeyStream(auth.Header[:], auth.Header[:]) - if !predicate.BytesAll(auth.Header[56:60], 0xef) { + if !compare.BytesAll(auth.Header[56:60], 0xef) { return newError("invalid connection type: ", auth.Header[56:60]) } diff --git a/testing/scenarios/command_test.go b/testing/scenarios/command_test.go index 426d8e083..e6fa38b1d 100644 --- a/testing/scenarios/command_test.go +++ b/testing/scenarios/command_test.go @@ -18,6 +18,7 @@ import ( "v2ray.com/core/app/router" "v2ray.com/core/app/stats" statscmd "v2ray.com/core/app/stats/command" + "v2ray.com/core/common/compare" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" "v2ray.com/core/common/serial" @@ -101,12 +102,17 @@ func TestCommanderRemoveHandler(t *testing.T) { servers, err := InitializeServerConfigs(clientConfig) assert(err, IsNil) + defer CloseAllServers(servers) + { conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{ IP: []byte{127, 0, 0, 1}, Port: int(clientPort), }) - assert(err, IsNil) + if err != nil { + t.Fatal(err) + } + defer conn.Close() // nolint: errcheck payload := "commander request." nBytes, err := conn.Write([]byte(payload)) @@ -116,8 +122,9 @@ func TestCommanderRemoveHandler(t *testing.T) { response := make([]byte, 1024) nBytes, err = conn.Read(response) assert(err, IsNil) - assert(response[:nBytes], Equals, xor([]byte(payload))) - assert(conn.Close(), IsNil) + if err := compare.BytesEqualWithDetail(response[:nBytes], xor([]byte(payload))); err != nil { + t.Fatal(err) + } } cmdConn, err := grpc.Dial(fmt.Sprintf("127.0.0.1:%d", cmdPort), grpc.WithInsecure(), grpc.WithBlock()) @@ -137,8 +144,6 @@ func TestCommanderRemoveHandler(t *testing.T) { }) assert(err, IsNotNil) } - - CloseAllServers(servers) } func TestCommanderAddRemoveUser(t *testing.T) { @@ -487,7 +492,10 @@ func TestCommanderStats(t *testing.T) { } servers, err := InitializeServerConfigs(serverConfig, clientConfig) - assert(err, IsNil) + if err != nil { + t.Fatal("Failed to create all servers", err) + } + defer CloseAllServers(servers) conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{ IP: []byte{127, 0, 0, 1}, diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 55f6fa896..033698dc7 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -8,7 +8,6 @@ import ( "time" "v2ray.com/core/common/buf" - "v2ray.com/core/common/predicate" "v2ray.com/core/common/signal" "v2ray.com/core/common/signal/semaphore" ) @@ -119,13 +118,13 @@ func (info *RoundTripInfo) SmoothedTime() uint32 { type Updater struct { interval int64 - shouldContinue predicate.Predicate - shouldTerminate predicate.Predicate + shouldContinue func() bool + shouldTerminate func() bool updateFunc func() notifier *semaphore.Instance } -func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTerminate predicate.Predicate, updateFunc func()) *Updater { +func NewUpdater(interval uint32, shouldContinue func() bool, shouldTerminate func() bool, updateFunc func()) *Updater { u := &Updater{ interval: int64(time.Duration(interval) * time.Millisecond), shouldContinue: shouldContinue, @@ -230,12 +229,14 @@ func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, con } conn.dataUpdater = NewUpdater( config.GetTTIValue(), - predicate.Not(isTerminating).And(predicate.Any(conn.sendingWorker.UpdateNecessary, conn.receivingWorker.UpdateNecessary)), + func() bool { + return !isTerminating() && (conn.sendingWorker.UpdateNecessary() || conn.receivingWorker.UpdateNecessary()) + }, isTerminating, conn.updateTask) conn.pingUpdater = NewUpdater( 5000, // 5 seconds - predicate.Not(isTerminated), + func() bool { return !isTerminated() }, isTerminated, conn.updateTask) conn.pingUpdater.WakeUp()