diff --git a/io/socks/socks_test.go b/io/socks/socks_test.go index 1576cb576..81a265500 100644 --- a/io/socks/socks_test.go +++ b/io/socks/socks_test.go @@ -7,6 +7,21 @@ import ( "github.com/v2ray/v2ray-core/testing/unit" ) +func TestHasAuthenticationMethod(t *testing.T) { + assert := unit.Assert(t) + + request := Socks5AuthenticationRequest{ + version: socksVersion, + nMethods: byte(0x02), + authMethods: [256]byte{0x01, 0x02}, + } + + assert.Bool(request.HasAuthMethod(byte(0x01))).IsTrue() + + request.authMethods[0] = byte(0x03) + assert.Bool(request.HasAuthMethod(byte(0x01))).IsFalse() +} + func TestAuthenticationRequestRead(t *testing.T) { assert := unit.Assert(t) @@ -22,6 +37,19 @@ func TestAuthenticationRequestRead(t *testing.T) { assert.Byte(request.authMethods[0]).Named("Auth Method").Equals(0x02) } +func TestAuthenticationResponseWrite(t *testing.T) { + assert := unit.Assert(t) + + response := Socks5AuthenticationResponse{ + version: socksVersion, + authMethod: byte(0x05), + } + + buffer := bytes.NewBuffer(make([]byte, 0, 10)) + WriteAuthentication(buffer, &response) + assert.Bytes(buffer.Bytes()).Equals([]byte{socksVersion, byte(0x05)}) +} + func TestRequestRead(t *testing.T) { assert := unit.Assert(t) diff --git a/testing/unit/assertions.go b/testing/unit/assertions.go index 92cccd546..bb85507a7 100644 --- a/testing/unit/assertions.go +++ b/testing/unit/assertions.go @@ -39,3 +39,7 @@ func (a *Assertion) String(value string) *StringSubject { func (a *Assertion) Error(value error) *ErrorSubject { return NewErrorSubject(NewSubject(a), value) } + +func (a *Assertion) Bool(value bool) *BoolSubject { + return NewBoolSubject(NewSubject(a), value) +} diff --git a/testing/unit/boolsubject.go b/testing/unit/boolsubject.go new file mode 100644 index 000000000..a3f9682e5 --- /dev/null +++ b/testing/unit/boolsubject.go @@ -0,0 +1,48 @@ +package unit + +import ( + "strconv" +) + +type BoolSubject struct { + *Subject + value bool +} + +func NewBoolSubject(base *Subject, value bool) *BoolSubject { + return &BoolSubject{ + Subject: base, + value: value, + } +} + +func (subject *BoolSubject) Named(name string) *BoolSubject { + subject.Subject.Named(name) + return subject +} + +func (subject *BoolSubject) Fail(verb string, other bool) { + subject.FailWithMessage("Not true that " + subject.DisplayString() + " " + verb + " <" + strconv.FormatBool(other) + ">.") +} + +func (subject *BoolSubject) DisplayString() string { + return subject.Subject.DisplayString(strconv.FormatBool(subject.value)) +} + +func (subject *BoolSubject) Equals(expectation bool) { + if subject.value != expectation { + subject.Fail("is equal to", expectation) + } +} + +func (subject *BoolSubject) IsTrue() { + if subject.value != true { + subject.Fail("is", true) + } +} + +func (subject *BoolSubject) IsFalse() { + if subject.value != false { + subject.Fail("is", false) + } +}