diff --git a/common/serial/serial.go b/common/serial/serial.go index 3c196af1e..c964db974 100644 --- a/common/serial/serial.go +++ b/common/serial/serial.go @@ -3,34 +3,42 @@ package serial import ( "encoding/binary" "io" + "unsafe" + + "v2ray.com/core/common/stack" ) // ReadUint16 reads first two bytes from the reader, and then coverts them to an uint16 value. func ReadUint16(reader io.Reader) (uint16, error) { - var b [2]byte - if _, err := io.ReadFull(reader, b[:]); err != nil { + var b stack.TwoBytes + s := b[:] + p := uintptr(unsafe.Pointer(&s)) + v := (*[]byte)(unsafe.Pointer(p)) + + if _, err := io.ReadFull(reader, *v); err != nil { return 0, err } - return binary.BigEndian.Uint16(b[:]), nil + return binary.BigEndian.Uint16(*v), nil } // WriteUint16 writes an uint16 value into writer. func WriteUint16(writer io.Writer, value uint16) (int, error) { - var b [2]byte - binary.BigEndian.PutUint16(b[:], value) - return writer.Write(b[:]) -} + var b stack.TwoBytes + s := b[:] + p := uintptr(unsafe.Pointer(&s)) + v := (*[]byte)(unsafe.Pointer(p)) -// WriteUint32 writes an uint32 value into writer. -func WriteUint32(writer io.Writer, value uint32) (int, error) { - var b [4]byte - binary.BigEndian.PutUint32(b[:], value) - return writer.Write(b[:]) + binary.BigEndian.PutUint16(*v, value) + return writer.Write(*v) } // WriteUint64 writes an uint64 value into writer. func WriteUint64(writer io.Writer, value uint64) (int, error) { - var b [8]byte - binary.BigEndian.PutUint64(b[:], value) - return writer.Write(b[:]) + var b stack.EightBytes + s := b[:] + p := uintptr(unsafe.Pointer(&s)) + v := (*[]byte)(unsafe.Pointer(p)) + + binary.BigEndian.PutUint64(*v, value) + return writer.Write(*v) } diff --git a/common/serial/serial_test.go b/common/serial/serial_test.go index 1988c59ef..f74f96af5 100644 --- a/common/serial/serial_test.go +++ b/common/serial/serial_test.go @@ -1,6 +1,7 @@ package serial_test import ( + "bytes" "testing" "github.com/google/go-cmp/cmp" @@ -10,16 +11,78 @@ import ( "v2ray.com/core/common/serial" ) -func TestUint32Serial(t *testing.T) { +func TestUint16Serial(t *testing.T) { b := buf.New() defer b.Release() - n, err := serial.WriteUint32(b, 10) + n, err := serial.WriteUint16(b, 10) common.Must(err) - if n != 4 { - t.Error("expect 4 bytes writtng, but actually ", n) + if n != 2 { + t.Error("expect 2 bytes writtng, but actually ", n) } - if diff := cmp.Diff(b.Bytes(), []byte{0, 0, 0, 10}); diff != "" { + if diff := cmp.Diff(b.Bytes(), []byte{0, 10}); diff != "" { t.Error(diff) } } + +func TestUint64Serial(t *testing.T) { + b := buf.New() + defer b.Release() + + n, err := serial.WriteUint64(b, 10) + common.Must(err) + if n != 8 { + t.Error("expect 8 bytes writtng, but actually ", n) + } + if diff := cmp.Diff(b.Bytes(), []byte{0, 0, 0, 0, 0, 0, 0, 10}); diff != "" { + t.Error(diff) + } +} + +func TestReadUint16(t *testing.T) { + testCases := []struct { + Input []byte + Output uint16 + }{ + { + Input: []byte{0, 1}, + Output: 1, + }, + } + + for _, testCase := range testCases { + v, err := serial.ReadUint16(bytes.NewReader(testCase.Input)) + common.Must(err) + if v != testCase.Output { + t.Error("for input ", testCase.Input, " expect output ", testCase.Output, " but got ", v) + } + } +} + +func BenchmarkReadUint16(b *testing.B) { + reader := buf.New() + defer reader.Release() + + common.Must2(reader.Write([]byte{0, 1})) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := serial.ReadUint16(reader) + common.Must(err) + reader.Clear() + reader.Extend(2) + } +} + +func BenchmarkWriteUint64(b *testing.B) { + writer := buf.New() + defer writer.Release() + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := serial.WriteUint64(writer, 8) + common.Must(err) + writer.Clear() + } +}