1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-02 07:26:24 -05:00

remove dep of assert lib

This commit is contained in:
Darien Raymond 2019-02-10 15:02:28 +01:00
parent a84897b4b6
commit 98950d5ada
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
14 changed files with 416 additions and 326 deletions

View File

@ -8,19 +8,16 @@ import (
"v2ray.com/core/common" "v2ray.com/core/common"
. "v2ray.com/core/common/buf" . "v2ray.com/core/common/buf"
"v2ray.com/core/transport/pipe" "v2ray.com/core/transport/pipe"
. "v2ray.com/ext/assert"
) )
func TestBytesReaderWriteTo(t *testing.T) { func TestBytesReaderWriteTo(t *testing.T) {
assert := With(t)
pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024)) pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024))
reader := &BufferedReader{Reader: pReader} reader := &BufferedReader{Reader: pReader}
b1 := New() b1 := New()
b1.WriteString("abc") b1.WriteString("abc")
b2 := New() b2 := New()
b2.WriteString("efg") b2.WriteString("efg")
assert(pWriter.WriteMultiBuffer(MultiBuffer{b1, b2}), IsNil) common.Must(pWriter.WriteMultiBuffer(MultiBuffer{b1, b2}))
pWriter.Close() pWriter.Close()
pReader2, pWriter2 := pipe.New(pipe.WithSizeLimit(1024)) pReader2, pWriter2 := pipe.New(pipe.WithSizeLimit(1024))
@ -29,33 +26,33 @@ func TestBytesReaderWriteTo(t *testing.T) {
nBytes, err := io.Copy(writer, reader) nBytes, err := io.Copy(writer, reader)
common.Must(err) common.Must(err)
assert(nBytes, Equals, int64(6)) if nBytes != 6 {
t.Error("copy: ", nBytes)
}
mb, err := pReader2.ReadMultiBuffer() mb, err := pReader2.ReadMultiBuffer()
common.Must(err) common.Must(err)
assert(len(mb), Equals, 2) if s := mb.String(); s != "abcefg" {
assert(mb[0].String(), Equals, "abc") t.Error("content: ", s)
assert(mb[1].String(), Equals, "efg") }
} }
func TestBytesReaderMultiBuffer(t *testing.T) { func TestBytesReaderMultiBuffer(t *testing.T) {
assert := With(t)
pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024)) pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024))
reader := &BufferedReader{Reader: pReader} reader := &BufferedReader{Reader: pReader}
b1 := New() b1 := New()
b1.WriteString("abc") b1.WriteString("abc")
b2 := New() b2 := New()
b2.WriteString("efg") b2.WriteString("efg")
assert(pWriter.WriteMultiBuffer(MultiBuffer{b1, b2}), IsNil) common.Must(pWriter.WriteMultiBuffer(MultiBuffer{b1, b2}))
pWriter.Close() pWriter.Close()
mbReader := NewReader(reader) mbReader := NewReader(reader)
mb, err := mbReader.ReadMultiBuffer() mb, err := mbReader.ReadMultiBuffer()
common.Must(err) common.Must(err)
assert(len(mb), Equals, 2) if s := mb.String(); s != "abcefg" {
assert(mb[0].String(), Equals, "abc") t.Error("content: ", s)
assert(mb[1].String(), Equals, "efg") }
} }
func TestReadByte(t *testing.T) { func TestReadByte(t *testing.T) {

View File

@ -8,16 +8,15 @@ import (
"io" "io"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
. "v2ray.com/core/common/crypto" . "v2ray.com/core/common/crypto"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
. "v2ray.com/ext/assert"
) )
func TestAuthenticationReaderWriter(t *testing.T) { func TestAuthenticationReaderWriter(t *testing.T) {
assert := With(t)
key := make([]byte, 16) key := make([]byte, 16)
rand.Read(key) rand.Read(key)
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
@ -31,7 +30,9 @@ func TestAuthenticationReaderWriter(t *testing.T) {
rand.Read(rawPayload) rand.Read(rawPayload)
payload := buf.MergeBytes(nil, rawPayload) payload := buf.MergeBytes(nil, rawPayload)
assert(payload.Len(), Equals, int32(payloadSize)) if r := cmp.Diff(payload.Bytes(), rawPayload); r != "" {
t.Error(r)
}
cache := bytes.NewBuffer(nil) cache := bytes.NewBuffer(nil)
iv := make([]byte, 12) iv := make([]byte, 12)
@ -43,9 +44,11 @@ func TestAuthenticationReaderWriter(t *testing.T) {
AdditionalDataGenerator: GenerateEmptyBytes(), AdditionalDataGenerator: GenerateEmptyBytes(),
}, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream, nil) }, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream, nil)
assert(writer.WriteMultiBuffer(payload), IsNil) common.Must(writer.WriteMultiBuffer(payload))
assert(cache.Len(), Equals, int(82658)) if cache.Len() <= 1024*80 {
assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil) t.Error("cache len: ", cache.Len())
}
common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{}))
reader := NewAuthenticationReader(&AEADAuthenticator{ reader := NewAuthenticationReader(&AEADAuthenticator{
AEAD: aead, AEAD: aead,
@ -62,19 +65,23 @@ func TestAuthenticationReaderWriter(t *testing.T) {
mb, _ = buf.MergeMulti(mb, mb2) mb, _ = buf.MergeMulti(mb, mb2)
} }
assert(mb.Len(), Equals, int32(payloadSize)) if mb.Len() != payloadSize {
t.Error("mb len: ", mb.Len())
}
mbContent := make([]byte, payloadSize) mbContent := make([]byte, payloadSize)
buf.SplitBytes(mb, mbContent) buf.SplitBytes(mb, mbContent)
assert(mbContent, Equals, rawPayload) if r := cmp.Diff(mbContent, rawPayload); r != "" {
t.Error(r)
}
_, err = reader.ReadMultiBuffer() _, err = reader.ReadMultiBuffer()
assert(err, Equals, io.EOF) if err != io.EOF {
t.Error("error: ", err)
}
} }
func TestAuthenticationReaderWriterPacket(t *testing.T) { func TestAuthenticationReaderWriterPacket(t *testing.T) {
assert := With(t)
key := make([]byte, 16) key := make([]byte, 16)
common.Must2(rand.Read(key)) common.Must2(rand.Read(key))
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
@ -102,10 +109,12 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
pb2.Write([]byte("efgh")) pb2.Write([]byte("efgh"))
payload = append(payload, pb2) payload = append(payload, pb2)
assert(writer.WriteMultiBuffer(payload), IsNil) common.Must(writer.WriteMultiBuffer(payload))
assert(cache.Len(), GreaterThan, int32(0)) if cache.Len() == 0 {
assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil) t.Error("cache len: ", cache.Len())
common.Must(err) }
common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{}))
reader := NewAuthenticationReader(&AEADAuthenticator{ reader := NewAuthenticationReader(&AEADAuthenticator{
AEAD: aead, AEAD: aead,
@ -117,13 +126,21 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
common.Must(err) common.Must(err)
mb, b1 := buf.SplitFirst(mb) mb, b1 := buf.SplitFirst(mb)
assert(b1.String(), Equals, "abcd") if b1.String() != "abcd" {
t.Error("b1: ", b1.String())
}
mb, b2 := buf.SplitFirst(mb) mb, b2 := buf.SplitFirst(mb)
assert(b2.String(), Equals, "efgh") if b2.String() != "efgh" {
t.Error("b2: ", b2.String())
}
assert(mb.IsEmpty(), IsTrue) if !mb.IsEmpty() {
t.Error("not empty")
}
_, err = reader.ReadMultiBuffer() _, err = reader.ReadMultiBuffer()
assert(err, Equals, io.EOF) if err != io.EOF {
t.Error("error: ", err)
}
} }

View File

@ -8,12 +8,9 @@ import (
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
. "v2ray.com/core/common/crypto" . "v2ray.com/core/common/crypto"
. "v2ray.com/ext/assert"
) )
func TestChunkStreamIO(t *testing.T) { func TestChunkStreamIO(t *testing.T) {
assert := With(t)
cache := bytes.NewBuffer(make([]byte, 0, 8192)) cache := bytes.NewBuffer(make([]byte, 0, 8192))
writer := NewChunkStreamWriter(PlainChunkSizeParser{}, cache) writer := NewChunkStreamWriter(PlainChunkSizeParser{}, cache)
@ -36,14 +33,19 @@ func TestChunkStreamIO(t *testing.T) {
mb, err := reader.ReadMultiBuffer() mb, err := reader.ReadMultiBuffer()
common.Must(err) common.Must(err)
assert(mb.Len(), Equals, int32(4)) if s := mb.String(); s != "abcd" {
assert(mb[0].Bytes(), Equals, []byte("abcd")) t.Error("content: ", s)
}
mb, err = reader.ReadMultiBuffer() mb, err = reader.ReadMultiBuffer()
common.Must(err) common.Must(err)
assert(mb.Len(), Equals, int32(3))
assert(mb[0].Bytes(), Equals, []byte("efg")) if s := mb.String(); s != "efg" {
t.Error("content: ", s)
}
_, err = reader.ReadMultiBuffer() _, err = reader.ReadMultiBuffer()
assert(err, Equals, io.EOF) if err != io.EOF {
t.Error("error: ", err)
}
} }

View File

@ -2,31 +2,39 @@ package errors_test
import ( import (
"io" "io"
"strings"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
. "v2ray.com/core/common/errors" . "v2ray.com/core/common/errors"
"v2ray.com/core/common/log" "v2ray.com/core/common/log"
. "v2ray.com/ext/assert"
) )
func TestError(t *testing.T) { func TestError(t *testing.T) {
assert := With(t)
err := New("TestError") err := New("TestError")
assert(GetSeverity(err), Equals, log.Severity_Info) if v := GetSeverity(err); v != log.Severity_Info {
t.Error("severity: ", v)
}
err = New("TestError2").Base(io.EOF) err = New("TestError2").Base(io.EOF)
assert(GetSeverity(err), Equals, log.Severity_Info) if v := GetSeverity(err); v != log.Severity_Info {
t.Error("severity: ", v)
}
err = New("TestError3").Base(io.EOF).AtWarning() err = New("TestError3").Base(io.EOF).AtWarning()
assert(GetSeverity(err), Equals, log.Severity_Warning) if v := GetSeverity(err); v != log.Severity_Warning {
t.Error("severity: ", v)
}
err = New("TestError4").Base(io.EOF).AtWarning() err = New("TestError4").Base(io.EOF).AtWarning()
err = New("TestError5").Base(err) err = New("TestError5").Base(err)
assert(GetSeverity(err), Equals, log.Severity_Warning) if v := GetSeverity(err); v != log.Severity_Warning {
assert(err.Error(), HasSubstring, "EOF") t.Error("severity: ", v)
}
if v := err.Error(); !strings.Contains(v, "EOF") {
t.Error("error: ", v)
}
} }
type e struct{} type e struct{}

View File

@ -4,13 +4,14 @@ import (
"io" "io"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
. "v2ray.com/core/common/mux" . "v2ray.com/core/common/mux"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/transport/pipe" "v2ray.com/core/transport/pipe"
. "v2ray.com/ext/assert"
) )
func readAll(reader buf.Reader) (buf.MultiBuffer, error) { func readAll(reader buf.Reader) (buf.MultiBuffer, error) {
@ -29,8 +30,6 @@ func readAll(reader buf.Reader) (buf.MultiBuffer, error) {
} }
func TestReaderWriter(t *testing.T) { func TestReaderWriter(t *testing.T) {
assert := With(t)
pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024)) pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024))
dest := net.TCPDestination(net.DomainAddress("v2ray.com"), 80) dest := net.TCPDestination(net.DomainAddress("v2ray.com"), 80)
@ -48,94 +47,150 @@ func TestReaderWriter(t *testing.T) {
return writer.WriteMultiBuffer(buf.MultiBuffer{b}) return writer.WriteMultiBuffer(buf.MultiBuffer{b})
} }
assert(writePayload(writer, 'a', 'b', 'c', 'd'), IsNil) common.Must(writePayload(writer, 'a', 'b', 'c', 'd'))
assert(writePayload(writer2), IsNil) common.Must(writePayload(writer2))
assert(writePayload(writer, 'e', 'f', 'g', 'h'), IsNil) common.Must(writePayload(writer, 'e', 'f', 'g', 'h'))
assert(writePayload(writer3, 'x'), IsNil) common.Must(writePayload(writer3, 'x'))
writer.Close() writer.Close()
writer3.Close() writer3.Close()
assert(writePayload(writer2, 'y'), IsNil) common.Must(writePayload(writer2, 'y'))
writer2.Close() writer2.Close()
bytesReader := &buf.BufferedReader{Reader: pReader} bytesReader := &buf.BufferedReader{Reader: pReader}
var meta FrameMetadata {
err := meta.Unmarshal(bytesReader) var meta FrameMetadata
common.Must(err) common.Must(meta.Unmarshal(bytesReader))
assert(meta.SessionID, Equals, uint16(1)) if r := cmp.Diff(meta, FrameMetadata{
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusNew)) SessionID: 1,
assert(meta.Target, Equals, dest) SessionStatus: SessionStatusNew,
assert(byte(meta.Option), Equals, byte(OptionData)) Target: dest,
Option: OptionData,
}); r != "" {
t.Error("metadata: ", r)
}
data, err := readAll(NewStreamReader(bytesReader)) data, err := readAll(NewStreamReader(bytesReader))
common.Must(err) common.Must(err)
assert(len(data), Equals, 1) if s := data.String(); s != "abcd" {
assert(data[0].String(), Equals, "abcd") t.Error("data: ", s)
}
}
err = meta.Unmarshal(bytesReader) {
common.Must(err) var meta FrameMetadata
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusNew)) common.Must(meta.Unmarshal(bytesReader))
assert(meta.SessionID, Equals, uint16(2)) if r := cmp.Diff(meta, FrameMetadata{
assert(byte(meta.Option), Equals, byte(0)) SessionStatus: SessionStatusNew,
assert(meta.Target, Equals, dest2) SessionID: 2,
Option: 0,
Target: dest2,
}); r != "" {
t.Error("meta: ", r)
}
}
err = meta.Unmarshal(bytesReader) {
common.Must(err) var meta FrameMetadata
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusKeep)) common.Must(meta.Unmarshal(bytesReader))
assert(meta.SessionID, Equals, uint16(1)) if r := cmp.Diff(meta, FrameMetadata{
assert(byte(meta.Option), Equals, byte(1)) SessionID: 1,
SessionStatus: SessionStatusKeep,
Option: 1,
}); r != "" {
t.Error("meta: ", r)
}
data, err = readAll(NewStreamReader(bytesReader)) data, err := readAll(NewStreamReader(bytesReader))
common.Must(err) common.Must(err)
assert(len(data), Equals, 1) if s := data.String(); s != "efgh" {
assert(data[0].String(), Equals, "efgh") t.Error("data: ", s)
}
}
err = meta.Unmarshal(bytesReader) {
common.Must(err) var meta FrameMetadata
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusNew)) common.Must(meta.Unmarshal(bytesReader))
assert(meta.SessionID, Equals, uint16(3)) if r := cmp.Diff(meta, FrameMetadata{
assert(byte(meta.Option), Equals, byte(1)) SessionID: 3,
assert(meta.Target, Equals, dest3) SessionStatus: SessionStatusNew,
Option: 1,
Target: dest3,
}); r != "" {
t.Error("meta: ", r)
}
data, err = readAll(NewStreamReader(bytesReader)) data, err := readAll(NewStreamReader(bytesReader))
common.Must(err) common.Must(err)
assert(len(data), Equals, 1) if s := data.String(); s != "x" {
assert(data[0].String(), Equals, "x") t.Error("data: ", s)
}
}
err = meta.Unmarshal(bytesReader) {
common.Must(err) var meta FrameMetadata
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusEnd)) common.Must(meta.Unmarshal(bytesReader))
assert(meta.SessionID, Equals, uint16(1)) if r := cmp.Diff(meta, FrameMetadata{
assert(byte(meta.Option), Equals, byte(0)) SessionID: 1,
SessionStatus: SessionStatusEnd,
Option: 0,
}); r != "" {
t.Error("meta: ", r)
}
}
err = meta.Unmarshal(bytesReader) {
common.Must(err) var meta FrameMetadata
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusEnd)) common.Must(meta.Unmarshal(bytesReader))
assert(meta.SessionID, Equals, uint16(3)) if r := cmp.Diff(meta, FrameMetadata{
assert(byte(meta.Option), Equals, byte(0)) SessionID: 3,
SessionStatus: SessionStatusEnd,
Option: 0,
}); r != "" {
t.Error("meta: ", r)
}
}
err = meta.Unmarshal(bytesReader) {
common.Must(err) var meta FrameMetadata
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusKeep)) common.Must(meta.Unmarshal(bytesReader))
assert(meta.SessionID, Equals, uint16(2)) if r := cmp.Diff(meta, FrameMetadata{
assert(byte(meta.Option), Equals, byte(1)) SessionID: 2,
SessionStatus: SessionStatusKeep,
Option: 1,
}); r != "" {
t.Error("meta: ", r)
}
data, err = readAll(NewStreamReader(bytesReader)) data, err := readAll(NewStreamReader(bytesReader))
common.Must(err) common.Must(err)
assert(len(data), Equals, 1) if s := data.String(); s != "y" {
assert(data[0].String(), Equals, "y") t.Error("data: ", s)
}
}
err = meta.Unmarshal(bytesReader) {
common.Must(err) var meta FrameMetadata
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusEnd)) common.Must(meta.Unmarshal(bytesReader))
assert(meta.SessionID, Equals, uint16(2)) if r := cmp.Diff(meta, FrameMetadata{
assert(byte(meta.Option), Equals, byte(0)) SessionID: 2,
SessionStatus: SessionStatusEnd,
Option: 0,
}); r != "" {
t.Error("meta: ", r)
}
}
pWriter.Close() pWriter.Close()
err = meta.Unmarshal(bytesReader) {
assert(err, IsNotNil) var meta FrameMetadata
err := meta.Unmarshal(bytesReader)
if err == nil {
t.Error("nil error")
}
}
} }

View File

@ -4,36 +4,48 @@ import (
"testing" "testing"
. "v2ray.com/core/common/mux" . "v2ray.com/core/common/mux"
. "v2ray.com/ext/assert"
) )
func TestSessionManagerAdd(t *testing.T) { func TestSessionManagerAdd(t *testing.T) {
assert := With(t)
m := NewSessionManager() m := NewSessionManager()
s := m.Allocate() s := m.Allocate()
assert(s.ID, Equals, uint16(1)) if s.ID != 1 {
assert(m.Size(), Equals, 1) t.Error("id: ", s.ID)
}
if m.Size() != 1 {
t.Error("size: ", m.Size())
}
s = m.Allocate() s = m.Allocate()
assert(s.ID, Equals, uint16(2)) if s.ID != 2 {
assert(m.Size(), Equals, 2) t.Error("id: ", s.ID)
}
if m.Size() != 2 {
t.Error("size: ", m.Size())
}
s = &Session{ s = &Session{
ID: 4, ID: 4,
} }
m.Add(s) m.Add(s)
assert(s.ID, Equals, uint16(4)) if s.ID != 4 {
t.Error("id: ", s.ID)
}
if m.Size() != 3 {
t.Error("size: ", m.Size())
}
} }
func TestSessionManagerClose(t *testing.T) { func TestSessionManagerClose(t *testing.T) {
assert := With(t)
m := NewSessionManager() m := NewSessionManager()
s := m.Allocate() s := m.Allocate()
assert(m.CloseIfNoSession(), IsFalse) if m.CloseIfNoSession() {
t.Error("able to close")
}
m.Remove(s.ID) m.Remove(s.ID)
assert(m.CloseIfNoSession(), IsTrue) if !m.CloseIfNoSession() {
t.Error("not able to close")
}
} }

View File

@ -6,31 +6,36 @@ import (
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
. "v2ray.com/core/common/protocol" . "v2ray.com/core/common/protocol"
. "v2ray.com/ext/assert"
) )
func TestServerList(t *testing.T) { func TestServerList(t *testing.T) {
assert := With(t)
list := NewServerList() list := NewServerList()
list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(1)), AlwaysValid())) list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(1)), AlwaysValid()))
assert(list.Size(), Equals, uint32(1)) if list.Size() != 1 {
t.Error("list size: ", list.Size())
}
list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(2)), BeforeTime(time.Now().Add(time.Second)))) list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(2)), BeforeTime(time.Now().Add(time.Second))))
assert(list.Size(), Equals, uint32(2)) if list.Size() != 2 {
t.Error("list.size: ", list.Size())
}
server := list.GetServer(1) server := list.GetServer(1)
assert(server.Destination().Port, Equals, net.Port(2)) if server.Destination().Port != 2 {
t.Error("server: ", server.Destination())
}
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
server = list.GetServer(1) server = list.GetServer(1)
assert(server, IsNil) if server != nil {
t.Error("server: ", server)
}
server = list.GetServer(0) server = list.GetServer(0)
assert(server.Destination().Port, Equals, net.Port(1)) if server.Destination().Port != 1 {
t.Error("server: ", server.Destination())
}
} }
func TestServerPicker(t *testing.T) { func TestServerPicker(t *testing.T) {
assert := With(t)
list := NewServerList() list := NewServerList()
list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(1)), AlwaysValid())) list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(1)), AlwaysValid()))
list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(2)), BeforeTime(time.Now().Add(time.Second)))) list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(2)), BeforeTime(time.Now().Add(time.Second))))
@ -38,17 +43,29 @@ func TestServerPicker(t *testing.T) {
picker := NewRoundRobinServerPicker(list) picker := NewRoundRobinServerPicker(list)
server := picker.PickServer() server := picker.PickServer()
assert(server.Destination().Port, Equals, net.Port(1)) if server.Destination().Port != 1 {
t.Error("server: ", server.Destination())
}
server = picker.PickServer() server = picker.PickServer()
assert(server.Destination().Port, Equals, net.Port(2)) if server.Destination().Port != 2 {
t.Error("server: ", server.Destination())
}
server = picker.PickServer() server = picker.PickServer()
assert(server.Destination().Port, Equals, net.Port(3)) if server.Destination().Port != 3 {
t.Error("server: ", server.Destination())
}
server = picker.PickServer() server = picker.PickServer()
assert(server.Destination().Port, Equals, net.Port(1)) if server.Destination().Port != 1 {
t.Error("server: ", server.Destination())
}
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
server = picker.PickServer() server = picker.PickServer()
assert(server.Destination().Port, Equals, net.Port(1)) if server.Destination().Port != 1 {
t.Error("server: ", server.Destination())
}
server = picker.PickServer() server = picker.PickServer()
assert(server.Destination().Port, Equals, net.Port(1)) if server.Destination().Port != 1 {
t.Error("server: ", server.Destination())
}
} }

View File

@ -1,6 +1,7 @@
package protocol_test package protocol_test
import ( import (
"strings"
"testing" "testing"
"time" "time"
@ -9,34 +10,37 @@ import (
. "v2ray.com/core/common/protocol" . "v2ray.com/core/common/protocol"
"v2ray.com/core/common/uuid" "v2ray.com/core/common/uuid"
"v2ray.com/core/proxy/vmess" "v2ray.com/core/proxy/vmess"
. "v2ray.com/ext/assert"
) )
func TestAlwaysValidStrategy(t *testing.T) { func TestAlwaysValidStrategy(t *testing.T) {
assert := With(t)
strategy := AlwaysValid() strategy := AlwaysValid()
assert(strategy.IsValid(), IsTrue) if !strategy.IsValid() {
t.Error("strategy not valid")
}
strategy.Invalidate() strategy.Invalidate()
assert(strategy.IsValid(), IsTrue) if !strategy.IsValid() {
t.Error("strategy not valid")
}
} }
func TestTimeoutValidStrategy(t *testing.T) { func TestTimeoutValidStrategy(t *testing.T) {
assert := With(t)
strategy := BeforeTime(time.Now().Add(2 * time.Second)) strategy := BeforeTime(time.Now().Add(2 * time.Second))
assert(strategy.IsValid(), IsTrue) if !strategy.IsValid() {
t.Error("strategy not valid")
}
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
assert(strategy.IsValid(), IsFalse) if strategy.IsValid() {
t.Error("strategy is valid")
}
strategy = BeforeTime(time.Now().Add(2 * time.Second)) strategy = BeforeTime(time.Now().Add(2 * time.Second))
strategy.Invalidate() strategy.Invalidate()
assert(strategy.IsValid(), IsFalse) if strategy.IsValid() {
t.Error("strategy is valid")
}
} }
func TestUserInServerSpec(t *testing.T) { func TestUserInServerSpec(t *testing.T) {
assert := With(t)
uuid1 := uuid.New() uuid1 := uuid.New()
uuid2 := uuid.New() uuid2 := uuid.New()
@ -50,22 +54,26 @@ func TestUserInServerSpec(t *testing.T) {
Email: "test1@v2ray.com", Email: "test1@v2ray.com",
Account: toAccount(&vmess.Account{Id: uuid1.String()}), Account: toAccount(&vmess.Account{Id: uuid1.String()}),
}) })
assert(spec.HasUser(&MemoryUser{ if spec.HasUser(&MemoryUser{
Email: "test1@v2ray.com", Email: "test1@v2ray.com",
Account: toAccount(&vmess.Account{Id: uuid2.String()}), Account: toAccount(&vmess.Account{Id: uuid2.String()}),
}), IsFalse) }) {
t.Error("has user: ", uuid2)
}
spec.AddUser(&MemoryUser{Email: "test2@v2ray.com"}) spec.AddUser(&MemoryUser{Email: "test2@v2ray.com"})
assert(spec.HasUser(&MemoryUser{ if !spec.HasUser(&MemoryUser{
Email: "test1@v2ray.com", Email: "test1@v2ray.com",
Account: toAccount(&vmess.Account{Id: uuid1.String()}), Account: toAccount(&vmess.Account{Id: uuid1.String()}),
}), IsTrue) }) {
t.Error("not having user: ", uuid1)
}
} }
func TestPickUser(t *testing.T) { func TestPickUser(t *testing.T) {
assert := With(t)
spec := NewServerSpec(net.Destination{}, AlwaysValid(), &MemoryUser{Email: "test1@v2ray.com"}, &MemoryUser{Email: "test2@v2ray.com"}, &MemoryUser{Email: "test3@v2ray.com"}) spec := NewServerSpec(net.Destination{}, AlwaysValid(), &MemoryUser{Email: "test1@v2ray.com"}, &MemoryUser{Email: "test2@v2ray.com"}, &MemoryUser{Email: "test3@v2ray.com"})
user := spec.PickUser() user := spec.PickUser()
assert(user.Email, HasSuffix, "@v2ray.com") if !strings.HasSuffix(user.Email, "@v2ray.com") {
t.Error("user: ", user.Email)
}
} }

View File

@ -3,6 +3,8 @@ package encoding_test
import ( import (
"testing" "testing"
"github.com/google/go-cmp/cmp"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
@ -10,7 +12,6 @@ import (
"v2ray.com/core/common/uuid" "v2ray.com/core/common/uuid"
"v2ray.com/core/proxy/vmess" "v2ray.com/core/proxy/vmess"
. "v2ray.com/core/proxy/vmess/encoding" . "v2ray.com/core/proxy/vmess/encoding"
. "v2ray.com/ext/assert"
) )
func toAccount(a *vmess.Account) protocol.Account { func toAccount(a *vmess.Account) protocol.Account {
@ -20,8 +21,6 @@ func toAccount(a *vmess.Account) protocol.Account {
} }
func TestRequestSerialization(t *testing.T) { func TestRequestSerialization(t *testing.T) {
assert := With(t)
user := &protocol.MemoryUser{ user := &protocol.MemoryUser{
Level: 0, Level: 0,
Email: "test@v2ray.com", Email: "test@v2ray.com",
@ -60,21 +59,18 @@ func TestRequestSerialization(t *testing.T) {
actualRequest, err := server.DecodeRequestHeader(buffer) actualRequest, err := server.DecodeRequestHeader(buffer)
common.Must(err) common.Must(err)
assert(expectedRequest.Version, Equals, actualRequest.Version) if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {
assert(byte(expectedRequest.Command), Equals, byte(actualRequest.Command)) t.Error(r)
assert(byte(expectedRequest.Option), Equals, byte(actualRequest.Option)) }
assert(expectedRequest.Address, Equals, actualRequest.Address)
assert(expectedRequest.Port, Equals, actualRequest.Port)
assert(byte(expectedRequest.Security), Equals, byte(actualRequest.Security))
_, err = server.DecodeRequestHeader(buffer2) _, err = server.DecodeRequestHeader(buffer2)
// anti replay attack // anti replay attack
assert(err, IsNotNil) if err == nil {
t.Error("nil error")
}
} }
func TestInvalidRequest(t *testing.T) { func TestInvalidRequest(t *testing.T) {
assert := With(t)
user := &protocol.MemoryUser{ user := &protocol.MemoryUser{
Level: 0, Level: 0,
Email: "test@v2ray.com", Email: "test@v2ray.com",
@ -111,12 +107,12 @@ func TestInvalidRequest(t *testing.T) {
server := NewServerSession(userValidator, sessionHistory) server := NewServerSession(userValidator, sessionHistory)
_, err := server.DecodeRequestHeader(buffer) _, err := server.DecodeRequestHeader(buffer)
assert(err, IsNotNil) if err == nil {
t.Error("nil error")
}
} }
func TestMuxRequest(t *testing.T) { func TestMuxRequest(t *testing.T) {
assert := With(t)
user := &protocol.MemoryUser{ user := &protocol.MemoryUser{
Level: 0, Level: 0,
Email: "test@v2ray.com", Email: "test@v2ray.com",
@ -133,6 +129,7 @@ func TestMuxRequest(t *testing.T) {
User: user, User: user,
Command: protocol.RequestCommandMux, Command: protocol.RequestCommandMux,
Security: protocol.SecurityType_AES128_GCM, Security: protocol.SecurityType_AES128_GCM,
Address: net.DomainAddress("v1.mux.cool"),
} }
buffer := buf.New() buffer := buf.New()
@ -153,8 +150,7 @@ func TestMuxRequest(t *testing.T) {
actualRequest, err := server.DecodeRequestHeader(buffer) actualRequest, err := server.DecodeRequestHeader(buffer)
common.Must(err) common.Must(err)
assert(expectedRequest.Version, Equals, actualRequest.Version) if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {
assert(byte(expectedRequest.Command), Equals, byte(actualRequest.Command)) t.Error(r)
assert(byte(expectedRequest.Option), Equals, byte(actualRequest.Option)) }
assert(byte(expectedRequest.Security), Equals, byte(actualRequest.Security))
} }

View File

@ -6,6 +6,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
@ -14,12 +16,9 @@ import (
"v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet"
. "v2ray.com/core/transport/internet/http" . "v2ray.com/core/transport/internet/http"
"v2ray.com/core/transport/internet/tls" "v2ray.com/core/transport/internet/tls"
. "v2ray.com/ext/assert"
) )
func TestHTTPConnection(t *testing.T) { func TestHTTPConnection(t *testing.T) {
assert := With(t)
port := tcp.PickPort() port := tcp.PickPort()
listener, err := Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{ listener, err := Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{
@ -40,9 +39,8 @@ func TestHTTPConnection(t *testing.T) {
if _, err := b.ReadFrom(conn); err != nil { if _, err := b.ReadFrom(conn); err != nil {
return return
} }
nBytes, err := conn.Write(b.Bytes()) _, err := conn.Write(b.Bytes())
common.Must(err) common.Must(err)
assert(int32(nBytes), Equals, b.Len())
} }
}() }()
}) })
@ -71,18 +69,26 @@ func TestHTTPConnection(t *testing.T) {
b2 := buf.New() b2 := buf.New()
nBytes, err := conn.Write(b1) nBytes, err := conn.Write(b1)
assert(nBytes, Equals, N)
common.Must(err) common.Must(err)
if nBytes != N {
t.Error("write: ", nBytes)
}
b2.Clear() b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N)) common.Must2(b2.ReadFullFrom(conn, N))
assert(b2.Bytes(), Equals, b1) if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}
nBytes, err = conn.Write(b1) nBytes, err = conn.Write(b1)
assert(nBytes, Equals, N)
common.Must(err) common.Must(err)
if nBytes != N {
t.Error("write: ", nBytes)
}
b2.Clear() b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N)) common.Must2(b2.ReadFullFrom(conn, N))
assert(b2.Bytes(), Equals, b1) if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}
} }

View File

@ -4,20 +4,20 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"io" "io"
"sync"
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"golang.org/x/sync/errgroup"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/errors"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet"
. "v2ray.com/core/transport/internet/kcp" . "v2ray.com/core/transport/internet/kcp"
. "v2ray.com/ext/assert"
) )
func TestDialAndListen(t *testing.T) { func TestDialAndListen(t *testing.T) {
assert := With(t)
listerner, err := NewListener(context.Background(), net.LocalHostIP, net.Port(0), &internet.MemoryStreamConfig{ listerner, err := NewListener(context.Background(), net.LocalHostIP, net.Port(0), &internet.MemoryStreamConfig{
ProtocolName: "mkcp", ProtocolName: "mkcp",
ProtocolSettings: &Config{}, ProtocolSettings: &Config{},
@ -38,42 +38,48 @@ func TestDialAndListen(t *testing.T) {
}(conn) }(conn)
}) })
common.Must(err) common.Must(err)
defer listerner.Close()
port := net.Port(listerner.Addr().(*net.UDPAddr).Port) port := net.Port(listerner.Addr().(*net.UDPAddr).Port)
wg := new(sync.WaitGroup) var errg errgroup.Group
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
clientConn, err := DialKCP(context.Background(), net.UDPDestination(net.LocalHostIP, port), &internet.MemoryStreamConfig{ errg.Go(func() error {
ProtocolName: "mkcp", clientConn, err := DialKCP(context.Background(), net.UDPDestination(net.LocalHostIP, port), &internet.MemoryStreamConfig{
ProtocolSettings: &Config{}, ProtocolName: "mkcp",
}) ProtocolSettings: &Config{},
common.Must(err) })
wg.Add(1) if err != nil {
return err
}
defer clientConn.Close()
go func() {
clientSend := make([]byte, 1024*1024) clientSend := make([]byte, 1024*1024)
rand.Read(clientSend) rand.Read(clientSend)
go clientConn.Write(clientSend) go clientConn.Write(clientSend)
clientReceived := make([]byte, 1024*1024) clientReceived := make([]byte, 1024*1024)
nBytes, _ := io.ReadFull(clientConn, clientReceived) common.Must2(io.ReadFull(clientConn, clientReceived))
assert(nBytes, Equals, len(clientReceived))
clientConn.Close()
clientExpected := make([]byte, 1024*1024) clientExpected := make([]byte, 1024*1024)
for idx, b := range clientSend { for idx, b := range clientSend {
clientExpected[idx] = b ^ 'c' clientExpected[idx] = b ^ 'c'
} }
assert(clientReceived, Equals, clientExpected) if r := cmp.Diff(clientReceived, clientExpected); r != "" {
return errors.New(r)
wg.Done() }
}() return nil
})
}
if err := errg.Wait(); err != nil {
t.Fatal(err)
} }
wg.Wait()
for i := 0; i < 60 && listerner.ActiveConnections() > 0; i++ { for i := 0; i < 60 && listerner.ActiveConnections() > 0; i++ {
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
} }
assert(listerner.ActiveConnections(), Equals, 0) if v := listerner.ActiveConnections(); v != 0 {
t.Error("active connections: ", v)
listerner.Close() }
} }

View File

@ -3,21 +3,23 @@ package kcp_test
import ( import (
"testing" "testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
. "v2ray.com/core/transport/internet/kcp" . "v2ray.com/core/transport/internet/kcp"
. "v2ray.com/ext/assert"
) )
func TestBadSegment(t *testing.T) { func TestBadSegment(t *testing.T) {
assert := With(t)
seg, buf := ReadSegment(nil) seg, buf := ReadSegment(nil)
assert(seg, IsNil) if seg != nil {
assert(len(buf), Equals, 0) t.Error("non-nil seg")
}
if len(buf) != 0 {
t.Error("buf len: ", len(buf))
}
} }
func TestDataSegment(t *testing.T) { func TestDataSegment(t *testing.T) {
assert := With(t)
seg := &DataSegment{ seg := &DataSegment{
Conv: 1, Conv: 1,
Timestamp: 3, Timestamp: 3,
@ -30,20 +32,17 @@ func TestDataSegment(t *testing.T) {
bytes := make([]byte, nBytes) bytes := make([]byte, nBytes)
seg.Serialize(bytes) seg.Serialize(bytes)
assert(int32(len(bytes)), Equals, nBytes)
iseg, _ := ReadSegment(bytes) iseg, _ := ReadSegment(bytes)
seg2 := iseg.(*DataSegment) seg2 := iseg.(*DataSegment)
assert(seg2.Conv, Equals, seg.Conv) if r := cmp.Diff(seg2, seg, cmpopts.IgnoreUnexported(DataSegment{})); r != "" {
assert(seg2.Timestamp, Equals, seg.Timestamp) t.Error(r)
assert(seg2.SendingNext, Equals, seg.SendingNext) }
assert(seg2.Number, Equals, seg.Number) if r := cmp.Diff(seg2.Data().Bytes(), seg.Data().Bytes()); r != "" {
assert(seg2.Data().Bytes(), Equals, seg.Data().Bytes()) t.Error(r)
}
} }
func Test1ByteDataSegment(t *testing.T) { func Test1ByteDataSegment(t *testing.T) {
assert := With(t)
seg := &DataSegment{ seg := &DataSegment{
Conv: 1, Conv: 1,
Timestamp: 3, Timestamp: 3,
@ -56,20 +55,17 @@ func Test1ByteDataSegment(t *testing.T) {
bytes := make([]byte, nBytes) bytes := make([]byte, nBytes)
seg.Serialize(bytes) seg.Serialize(bytes)
assert(int32(len(bytes)), Equals, nBytes)
iseg, _ := ReadSegment(bytes) iseg, _ := ReadSegment(bytes)
seg2 := iseg.(*DataSegment) seg2 := iseg.(*DataSegment)
assert(seg2.Conv, Equals, seg.Conv) if r := cmp.Diff(seg2, seg, cmpopts.IgnoreUnexported(DataSegment{})); r != "" {
assert(seg2.Timestamp, Equals, seg.Timestamp) t.Error(r)
assert(seg2.SendingNext, Equals, seg.SendingNext) }
assert(seg2.Number, Equals, seg.Number) if r := cmp.Diff(seg2.Data().Bytes(), seg.Data().Bytes()); r != "" {
assert(seg2.Data().Bytes(), Equals, seg.Data().Bytes()) t.Error(r)
}
} }
func TestACKSegment(t *testing.T) { func TestACKSegment(t *testing.T) {
assert := With(t)
seg := &AckSegment{ seg := &AckSegment{
Conv: 1, Conv: 1,
ReceivingWindow: 2, ReceivingWindow: 2,
@ -82,23 +78,14 @@ func TestACKSegment(t *testing.T) {
bytes := make([]byte, nBytes) bytes := make([]byte, nBytes)
seg.Serialize(bytes) seg.Serialize(bytes)
assert(int32(len(bytes)), Equals, nBytes)
iseg, _ := ReadSegment(bytes) iseg, _ := ReadSegment(bytes)
seg2 := iseg.(*AckSegment) seg2 := iseg.(*AckSegment)
assert(seg2.Conv, Equals, seg.Conv) if r := cmp.Diff(seg2, seg); r != "" {
assert(seg2.ReceivingWindow, Equals, seg.ReceivingWindow) t.Error(r)
assert(seg2.ReceivingNext, Equals, seg.ReceivingNext)
assert(len(seg2.NumberList), Equals, len(seg.NumberList))
assert(seg2.Timestamp, Equals, seg.Timestamp)
for i, number := range seg2.NumberList {
assert(number, Equals, seg.NumberList[i])
} }
} }
func TestCmdSegment(t *testing.T) { func TestCmdSegment(t *testing.T) {
assert := With(t)
seg := &CmdOnlySegment{ seg := &CmdOnlySegment{
Conv: 1, Conv: 1,
Cmd: CommandPing, Cmd: CommandPing,
@ -112,14 +99,9 @@ func TestCmdSegment(t *testing.T) {
bytes := make([]byte, nBytes) bytes := make([]byte, nBytes)
seg.Serialize(bytes) seg.Serialize(bytes)
assert(int32(len(bytes)), Equals, nBytes)
iseg, _ := ReadSegment(bytes) iseg, _ := ReadSegment(bytes)
seg2 := iseg.(*CmdOnlySegment) seg2 := iseg.(*CmdOnlySegment)
assert(seg2.Conv, Equals, seg.Conv) if r := cmp.Diff(seg2, seg); r != "" {
assert(byte(seg2.Command()), Equals, byte(seg.Command())) t.Error(r)
assert(byte(seg2.Option), Equals, byte(seg.Option)) }
assert(seg2.SendingNext, Equals, seg.SendingNext)
assert(seg2.ReceivingNext, Equals, seg.ReceivingNext)
assert(seg2.PeerRTO, Equals, seg.PeerRTO)
} }

View File

@ -6,6 +6,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
@ -17,12 +19,9 @@ import (
"v2ray.com/core/transport/internet/headers/wireguard" "v2ray.com/core/transport/internet/headers/wireguard"
"v2ray.com/core/transport/internet/quic" "v2ray.com/core/transport/internet/quic"
"v2ray.com/core/transport/internet/tls" "v2ray.com/core/transport/internet/tls"
. "v2ray.com/ext/assert"
) )
func TestQuicConnection(t *testing.T) { func TestQuicConnection(t *testing.T) {
assert := With(t)
port := udp.PickPort() port := udp.PickPort()
listener, err := quic.Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{ listener, err := quic.Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{
@ -44,9 +43,7 @@ func TestQuicConnection(t *testing.T) {
if _, err := b.ReadFrom(conn); err != nil { if _, err := b.ReadFrom(conn); err != nil {
return return
} }
nBytes, err := conn.Write(b.Bytes()) common.Must2(conn.Write(b.Bytes()))
common.Must(err)
assert(int32(nBytes), Equals, b.Len())
} }
}() }()
}) })
@ -74,26 +71,24 @@ func TestQuicConnection(t *testing.T) {
common.Must2(rand.Read(b1)) common.Must2(rand.Read(b1))
b2 := buf.New() b2 := buf.New()
nBytes, err := conn.Write(b1) common.Must2(conn.Write(b1))
assert(nBytes, Equals, N)
common.Must(err)
b2.Clear() b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N)) common.Must2(b2.ReadFullFrom(conn, N))
assert(b2.Bytes(), Equals, b1) if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}
nBytes, err = conn.Write(b1) common.Must2(conn.Write(b1))
assert(nBytes, Equals, N)
common.Must(err)
b2.Clear() b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N)) common.Must2(b2.ReadFullFrom(conn, N))
assert(b2.Bytes(), Equals, b1) if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}
} }
func TestQuicConnectionWithoutTLS(t *testing.T) { func TestQuicConnectionWithoutTLS(t *testing.T) {
assert := With(t)
port := udp.PickPort() port := udp.PickPort()
listener, err := quic.Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{ listener, err := quic.Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{
@ -111,9 +106,7 @@ func TestQuicConnectionWithoutTLS(t *testing.T) {
if _, err := b.ReadFrom(conn); err != nil { if _, err := b.ReadFrom(conn); err != nil {
return return
} }
nBytes, err := conn.Write(b.Bytes()) common.Must2(conn.Write(b.Bytes()))
common.Must(err)
assert(int32(nBytes), Equals, b.Len())
} }
}() }()
}) })
@ -136,26 +129,24 @@ func TestQuicConnectionWithoutTLS(t *testing.T) {
common.Must2(rand.Read(b1)) common.Must2(rand.Read(b1))
b2 := buf.New() b2 := buf.New()
nBytes, err := conn.Write(b1) common.Must2(conn.Write(b1))
assert(nBytes, Equals, N)
common.Must(err)
b2.Clear() b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N)) common.Must2(b2.ReadFullFrom(conn, N))
assert(b2.Bytes(), Equals, b1) if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}
nBytes, err = conn.Write(b1) common.Must2(conn.Write(b1))
assert(nBytes, Equals, N)
common.Must(err)
b2.Clear() b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N)) common.Must2(b2.ReadFullFrom(conn, N))
assert(b2.Bytes(), Equals, b1) if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}
} }
func TestQuicConnectionAuthHeader(t *testing.T) { func TestQuicConnectionAuthHeader(t *testing.T) {
assert := With(t)
port := udp.PickPort() port := udp.PickPort()
listener, err := quic.Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{ listener, err := quic.Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{
@ -179,9 +170,7 @@ func TestQuicConnectionAuthHeader(t *testing.T) {
if _, err := b.ReadFrom(conn); err != nil { if _, err := b.ReadFrom(conn); err != nil {
return return
} }
nBytes, err := conn.Write(b.Bytes()) common.Must2(conn.Write(b.Bytes()))
common.Must(err)
assert(int32(nBytes), Equals, b.Len())
} }
}() }()
}) })
@ -210,19 +199,19 @@ func TestQuicConnectionAuthHeader(t *testing.T) {
common.Must2(rand.Read(b1)) common.Must2(rand.Read(b1))
b2 := buf.New() b2 := buf.New()
nBytes, err := conn.Write(b1) common.Must2(conn.Write(b1))
assert(nBytes, Equals, N)
common.Must(err)
b2.Clear() b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N)) common.Must2(b2.ReadFullFrom(conn, N))
assert(b2.Bytes(), Equals, b1) if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}
nBytes, err = conn.Write(b1) common.Must2(conn.Write(b1))
assert(nBytes, Equals, N)
common.Must(err)
b2.Clear() b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N)) common.Must2(b2.ReadFullFrom(conn, N))
assert(b2.Bytes(), Equals, b1) if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}
} }

View File

@ -1,7 +1,6 @@
package websocket_test package websocket_test
import ( import (
"bytes"
"context" "context"
"runtime" "runtime"
"testing" "testing"
@ -13,12 +12,9 @@ import (
"v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/tls" "v2ray.com/core/transport/internet/tls"
. "v2ray.com/core/transport/internet/websocket" . "v2ray.com/core/transport/internet/websocket"
. "v2ray.com/ext/assert"
) )
func Test_listenWSAndDial(t *testing.T) { func Test_listenWSAndDial(t *testing.T) {
assert := With(t)
listen, err := ListenWS(context.Background(), net.LocalHostIP, 13146, &internet.MemoryStreamConfig{ listen, err := ListenWS(context.Background(), net.LocalHostIP, 13146, &internet.MemoryStreamConfig{
ProtocolName: "websocket", ProtocolName: "websocket",
ProtocolSettings: &Config{ ProtocolSettings: &Config{
@ -29,15 +25,12 @@ func Test_listenWSAndDial(t *testing.T) {
defer c.Close() defer c.Close()
var b [1024]byte var b [1024]byte
n, err := c.Read(b[:]) _, err := c.Read(b[:])
//common.Must(err)
if err != nil { if err != nil {
return return
} }
assert(bytes.HasPrefix(b[:n], []byte("Test connection")), IsTrue)
_, err = c.Write([]byte("Response")) common.Must2(c.Write([]byte("Response")))
common.Must(err)
}(conn) }(conn)
}) })
common.Must(err) common.Must(err)
@ -56,9 +49,11 @@ func Test_listenWSAndDial(t *testing.T) {
var b [1024]byte var b [1024]byte
n, err := conn.Read(b[:]) n, err := conn.Read(b[:])
common.Must(err) common.Must(err)
assert(string(b[:n]), Equals, "Response") if string(b[:n]) != "Response" {
t.Error("response: ", string(b[:n]))
}
assert(conn.Close(), IsNil) common.Must(conn.Close())
<-time.After(time.Second * 5) <-time.After(time.Second * 5)
conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146), streamSettings) conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146), streamSettings)
common.Must(err) common.Must(err)
@ -66,14 +61,15 @@ func Test_listenWSAndDial(t *testing.T) {
common.Must(err) common.Must(err)
n, err = conn.Read(b[:]) n, err = conn.Read(b[:])
common.Must(err) common.Must(err)
assert(string(b[:n]), Equals, "Response") if string(b[:n]) != "Response" {
assert(conn.Close(), IsNil) t.Error("response: ", string(b[:n]))
}
common.Must(conn.Close())
assert(listen.Close(), IsNil) common.Must(listen.Close())
} }
func TestDialWithRemoteAddr(t *testing.T) { func TestDialWithRemoteAddr(t *testing.T) {
assert := With(t)
listen, err := ListenWS(context.Background(), net.LocalHostIP, 13148, &internet.MemoryStreamConfig{ listen, err := ListenWS(context.Background(), net.LocalHostIP, 13148, &internet.MemoryStreamConfig{
ProtocolName: "websocket", ProtocolName: "websocket",
ProtocolSettings: &Config{ ProtocolSettings: &Config{
@ -83,15 +79,12 @@ func TestDialWithRemoteAddr(t *testing.T) {
go func(c internet.Connection) { go func(c internet.Connection) {
defer c.Close() defer c.Close()
assert(c.RemoteAddr().String(), HasPrefix, "1.1.1.1")
var b [1024]byte var b [1024]byte
n, err := c.Read(b[:]) _, err := c.Read(b[:])
//common.Must(err) //common.Must(err)
if err != nil { if err != nil {
return return
} }
assert(bytes.HasPrefix(b[:n], []byte("Test connection")), IsTrue)
_, err = c.Write([]byte("Response")) _, err = c.Write([]byte("Response"))
common.Must(err) common.Must(err)
@ -111,9 +104,11 @@ func TestDialWithRemoteAddr(t *testing.T) {
var b [1024]byte var b [1024]byte
n, err := conn.Read(b[:]) n, err := conn.Read(b[:])
common.Must(err) common.Must(err)
assert(string(b[:n]), Equals, "Response") if string(b[:n]) != "Response" {
t.Error("response: ", string(b[:n]))
}
assert(listen.Close(), IsNil) common.Must(listen.Close())
} }
func Test_listenWSAndDial_TLS(t *testing.T) { func Test_listenWSAndDial_TLS(t *testing.T) {
@ -121,8 +116,6 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
return return
} }
assert := With(t)
start := time.Now() start := time.Now()
streamSettings := &internet.MemoryStreamConfig{ streamSettings := &internet.MemoryStreamConfig{
@ -149,5 +142,7 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
_ = conn.Close() _ = conn.Close()
end := time.Now() end := time.Now()
assert(end.Before(start.Add(time.Second*5)), IsTrue) if !end.Before(start.Add(time.Second * 5)) {
t.Error("end: ", end, " start: ", start)
}
} }