diff --git a/common/serial/numbers.go b/common/serial/numbers.go index 47a1e3340..883b2b642 100644 --- a/common/serial/numbers.go +++ b/common/serial/numbers.go @@ -16,6 +16,10 @@ func Uint32ToBytes(value uint32, b []byte) []byte { return append(b, byte(value>>24), byte(value>>16), byte(value>>8), byte(value)) } +func Uint32ToString(value uint32) string { + return strconv.FormatUint(uint64(value), 10) +} + func IntToBytes(value int, b []byte) []byte { return append(b, byte(value>>24), byte(value>>16), byte(value>>8), byte(value)) } diff --git a/testing/assert/pointer.go b/testing/assert/pointer.go index 87cb3657f..f1e75a549 100644 --- a/testing/assert/pointer.go +++ b/testing/assert/pointer.go @@ -1,6 +1,8 @@ package assert import ( + "reflect" + "github.com/v2ray/v2ray-core/common/serial" ) @@ -26,7 +28,15 @@ func (subject *PointerSubject) Equals(expectation interface{}) { } func (subject *PointerSubject) IsNil() { - if subject.value != nil { + if subject.value == nil { + return + } + + valueType := reflect.TypeOf(subject.value) + nilType := reflect.Zero(valueType) + realValue := reflect.ValueOf(subject.value) + + if nilType != realValue { subject.Fail("is", "nil") } } @@ -35,4 +45,12 @@ func (subject *PointerSubject) IsNotNil() { if subject.value == nil { subject.Fail("is not", "nil") } + + valueType := reflect.TypeOf(subject.value) + nilType := reflect.Zero(valueType) + realValue := reflect.ValueOf(subject.value) + + if nilType == realValue { + subject.Fail("is not", "nil") + } } diff --git a/testing/assert/uint32.go b/testing/assert/uint32.go new file mode 100644 index 000000000..a032b5fce --- /dev/null +++ b/testing/assert/uint32.go @@ -0,0 +1,50 @@ +package assert + +import ( + "github.com/v2ray/v2ray-core/common/serial" +) + +func (this *Assert) Uint32(value uint32) *Uint32Subject { + return &Uint32Subject{ + Subject: Subject{ + a: this, + disp: serial.Uint32ToString(value), + }, + value: value, + } +} + +type Uint32Subject struct { + Subject + value uint32 +} + +func (subject *Uint32Subject) Equals(expectation uint32) { + if subject.value != expectation { + subject.Fail("is equal to", serial.Uint32ToString(expectation)) + } +} + +func (subject *Uint32Subject) GreaterThan(expectation uint32) { + if subject.value <= expectation { + subject.Fail("is greater than", serial.Uint32ToString(expectation)) + } +} + +func (subject *Uint32Subject) LessThan(expectation uint32) { + if subject.value >= expectation { + subject.Fail("is less than", serial.Uint32ToString(expectation)) + } +} + +func (subject *Uint32Subject) IsPositive() { + if subject.value <= 0 { + subject.Fail("is", "positive") + } +} + +func (subject *Uint32Subject) IsNegative() { + if subject.value >= 0 { + subject.Fail("is not", "negative") + } +} diff --git a/transport/internet/kcp/segment.go b/transport/internet/kcp/segment.go new file mode 100644 index 000000000..c30c5cf34 --- /dev/null +++ b/transport/internet/kcp/segment.go @@ -0,0 +1,171 @@ +package kcp + +import ( + "github.com/v2ray/v2ray-core/common/alloc" + "github.com/v2ray/v2ray-core/common/serial" +) + +type SegmentCommand byte + +const ( + SegmentCommandACK SegmentCommand = 0 + SegmentCommandData SegmentCommand = 1 + SegmentCommandTerminated SegmentCommand = 2 +) + +type SegmentOption byte + +const ( + SegmentOptionClose SegmentOption = 1 +) + +type ISegment interface { + ByteSize() int + Bytes([]byte) []byte +} + +type DataSegment struct { + Conv uint16 + Opt SegmentOption + ReceivingWindow uint32 + Timestamp uint32 + Number uint32 + Unacknowledged uint32 + Data *alloc.Buffer + + timeout uint32 + ackSkipped uint32 + transmit uint32 +} + +func (this *DataSegment) Bytes(b []byte) []byte { + b = serial.Uint16ToBytes(this.Conv, b) + b = append(b, byte(SegmentCommandData), byte(this.Opt)) + b = serial.Uint32ToBytes(this.ReceivingWindow, b) + b = serial.Uint32ToBytes(this.Timestamp, b) + b = serial.Uint32ToBytes(this.Number, b) + b = serial.Uint32ToBytes(this.Unacknowledged, b) + b = serial.Uint16ToBytes(uint16(this.Data.Len()), b) + b = append(b, this.Data.Value...) + return b +} + +func (this *DataSegment) ByteSize() int { + return 2 + 1 + 1 + 4 + 4 + 4 + 4 + 2 + this.Data.Len() +} + +type ACKSegment struct { + Conv uint16 + Opt SegmentOption + ReceivingWindow uint32 + Unacknowledged uint32 + Count byte + NumberList []uint32 + TimestampList []uint32 +} + +func (this *ACKSegment) ByteSize() int { + return 2 + 1 + 1 + 4 + 4 + 1 + len(this.NumberList)*4 + len(this.TimestampList)*4 +} + +func (this *ACKSegment) Bytes(b []byte) []byte { + b = serial.Uint16ToBytes(this.Conv, b) + b = append(b, byte(SegmentCommandACK), byte(this.Opt)) + b = serial.Uint32ToBytes(this.ReceivingWindow, b) + b = serial.Uint32ToBytes(this.Unacknowledged, b) + b = append(b, this.Count) + for i := byte(0); i < this.Count; i++ { + b = serial.Uint32ToBytes(this.NumberList[i], b) + b = serial.Uint32ToBytes(this.TimestampList[i], b) + } + return b +} + +type TerminationSegment struct { + Conv uint16 + Opt SegmentOption +} + +func (this *TerminationSegment) ByteSize() int { + return 2 + 1 + 1 +} + +func (this *TerminationSegment) Bytes(b []byte) []byte { + b = serial.Uint16ToBytes(this.Conv, b) + b = append(b, byte(SegmentCommandTerminated), byte(this.Opt)) + return b +} + +func ReadSegment(buf []byte) (ISegment, []byte) { + if len(buf) <= 12 { + return nil, nil + } + + conv := serial.BytesToUint16(buf) + buf = buf[2:] + + cmd := SegmentCommand(buf[0]) + opt := SegmentOption(buf[1]) + buf = buf[2:] + + if cmd == SegmentCommandData { + seg := &DataSegment{ + Conv: conv, + Opt: opt, + } + seg.ReceivingWindow = serial.BytesToUint32(buf) + buf = buf[4:] + + seg.Timestamp = serial.BytesToUint32(buf) + buf = buf[4:] + + seg.Number = serial.BytesToUint32(buf) + buf = buf[4:] + + seg.Unacknowledged = serial.BytesToUint32(buf) + buf = buf[4:] + + len := serial.BytesToUint16(buf) + buf = buf[2:] + + seg.Data = alloc.NewSmallBuffer().Clear().Append(buf[:len]) + buf = buf[len:] + + return seg, buf + } + + if cmd == SegmentCommandACK { + seg := &ACKSegment{ + Conv: conv, + Opt: opt, + } + seg.ReceivingWindow = serial.BytesToUint32(buf) + buf = buf[4:] + + seg.Unacknowledged = serial.BytesToUint32(buf) + buf = buf[4:] + + seg.Count = buf[0] + buf = buf[1:] + + seg.NumberList = make([]uint32, 0, seg.Count) + seg.TimestampList = make([]uint32, 0, seg.Count) + + for i := 0; i < int(seg.Count); i++ { + seg.NumberList = append(seg.NumberList, serial.BytesToUint32(buf)) + seg.TimestampList = append(seg.TimestampList, serial.BytesToUint32(buf[4:])) + buf = buf[8:] + } + + return seg, buf + } + + if cmd == SegmentCommandTerminated { + return &TerminationSegment{ + Conv: conv, + Opt: opt, + }, buf + } + + return nil, nil +} diff --git a/transport/internet/kcp/segment_test.go b/transport/internet/kcp/segment_test.go new file mode 100644 index 000000000..c1fc1ac5c --- /dev/null +++ b/transport/internet/kcp/segment_test.go @@ -0,0 +1,73 @@ +package kcp_test + +import ( + "testing" + + "github.com/v2ray/v2ray-core/common/alloc" + "github.com/v2ray/v2ray-core/testing/assert" + . "github.com/v2ray/v2ray-core/transport/internet/kcp" +) + +func TestBadSegment(t *testing.T) { + assert := assert.On(t) + + seg, buf := ReadSegment(nil) + assert.Pointer(seg).IsNil() + assert.Int(len(buf)).Equals(0) +} + +func TestDataSegment(t *testing.T) { + assert := assert.On(t) + + seg := &DataSegment{ + Conv: 1, + ReceivingWindow: 2, + Timestamp: 3, + Number: 4, + Unacknowledged: 5, + Data: alloc.NewSmallBuffer().Clear().Append([]byte{'a', 'b', 'c', 'd'}), + } + + nBytes := seg.ByteSize() + bytes := seg.Bytes(nil) + + assert.Int(len(bytes)).Equals(nBytes) + + iseg, _ := ReadSegment(bytes) + seg2 := iseg.(*DataSegment) + assert.Uint16(seg2.Conv).Equals(seg.Conv) + assert.Uint32(seg2.ReceivingWindow).Equals(seg.ReceivingWindow) + assert.Uint32(seg2.Timestamp).Equals(seg.Timestamp) + assert.Uint32(seg2.Unacknowledged).Equals(seg.Unacknowledged) + assert.Uint32(seg2.Number).Equals(seg.Number) + assert.Bytes(seg2.Data.Value).Equals(seg.Data.Value) +} + +func TestACKSegment(t *testing.T) { + assert := assert.On(t) + + seg := &ACKSegment{ + Conv: 1, + ReceivingWindow: 2, + Unacknowledged: 3, + Count: 5, + NumberList: []uint32{1, 3, 5, 7, 9}, + TimestampList: []uint32{2, 4, 6, 8, 10}, + } + + nBytes := seg.ByteSize() + bytes := seg.Bytes(nil) + + assert.Int(len(bytes)).Equals(nBytes) + + iseg, _ := ReadSegment(bytes) + seg2 := iseg.(*ACKSegment) + assert.Uint16(seg2.Conv).Equals(seg.Conv) + assert.Uint32(seg2.ReceivingWindow).Equals(seg.ReceivingWindow) + assert.Uint32(seg2.Unacknowledged).Equals(seg.Unacknowledged) + assert.Byte(seg2.Count).Equals(seg.Count) + for i := byte(0); i < seg2.Count; i++ { + assert.Uint32(seg2.TimestampList[i]).Equals(seg.TimestampList[i]) + assert.Uint32(seg2.NumberList[i]).Equals(seg.NumberList[i]) + } +}