diff --git a/common/protocol/address.go b/common/protocol/address.go index c7058865a..372d3702a 100644 --- a/common/protocol/address.go +++ b/common/protocol/address.go @@ -7,19 +7,21 @@ import ( "v2ray.com/core/common/buf" "v2ray.com/core/common/net" "v2ray.com/core/common/serial" - "v2ray.com/core/common/task" ) -type AddressOption func(*AddressParser) +type AddressOption func(*option) func PortThenAddress() AddressOption { - return func(p *AddressParser) { + return func(p *option) { p.portFirst = true } } func AddressFamilyByte(b byte, f net.AddressFamily) AddressOption { - return func(p *AddressParser) { + if b >= 16 { + panic("address family byte too big") + } + return func(p *option) { p.addrTypeMap[b] = f p.addrByteMap[f] = b } @@ -28,38 +30,127 @@ func AddressFamilyByte(b byte, f net.AddressFamily) AddressOption { type AddressTypeParser func(byte) byte func WithAddressTypeParser(atp AddressTypeParser) AddressOption { - return func(p *AddressParser) { + return func(p *option) { p.typeParser = atp } } -// AddressParser is a utility for reading and writer addresses. -type AddressParser struct { - addrTypeMap map[byte]net.AddressFamily - addrByteMap map[net.AddressFamily]byte +type AddressSerializer interface { + ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error) + + WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error +} + +const afInvalid = 255 + +type option struct { + addrTypeMap [16]net.AddressFamily + addrByteMap [16]byte portFirst bool typeParser AddressTypeParser } // NewAddressParser creates a new AddressParser -func NewAddressParser(options ...AddressOption) *AddressParser { - p := &AddressParser{ - addrTypeMap: make(map[byte]net.AddressFamily, 8), - addrByteMap: make(map[net.AddressFamily]byte, 8), +func NewAddressParser(options ...AddressOption) AddressSerializer { + var o option + for i := range o.addrByteMap { + o.addrByteMap[i] = afInvalid + } + for i := range o.addrTypeMap { + o.addrTypeMap[i] = net.AddressFamily(afInvalid) } for _, opt := range options { - opt(p) + opt(&o) } - return p + + ap := &addressParser{ + addrByteMap: o.addrByteMap, + addrTypeMap: o.addrTypeMap, + } + + if o.typeParser != nil { + ap.typeParser = o.typeParser + } + + if o.portFirst { + return portFirstAddressParser{ap: ap} + } + + return portLastAddressParser{ap: ap} } -func (p *AddressParser) readPort(b *buf.Buffer, reader io.Reader) (net.Port, error) { +type portFirstAddressParser struct { + ap *addressParser +} + +func (p portFirstAddressParser) ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error) { + if buffer == nil { + buffer = buf.New() + defer buffer.Release() + } + + port, err := readPort(buffer, input) + if err != nil { + return nil, 0, err + } + + addr, err := p.ap.readAddress(buffer, input) + if err != nil { + return nil, 0, err + } + return addr, port, nil +} + +func (p portFirstAddressParser) WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error { + if err := writePort(writer, port); err != nil { + return err + } + + return p.ap.writeAddress(writer, addr) +} + +type portLastAddressParser struct { + ap *addressParser +} + +func (p portLastAddressParser) ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error) { + if buffer == nil { + buffer = buf.New() + defer buffer.Release() + } + + addr, err := p.ap.readAddress(buffer, input) + if err != nil { + return nil, 0, err + } + + port, err := readPort(buffer, input) + if err != nil { + return nil, 0, err + } + + return addr, port, nil +} + +func (p portLastAddressParser) WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error { + if err := p.ap.writeAddress(writer, addr); err != nil { + return err + } + + return writePort(writer, port) +} + +func readPort(b *buf.Buffer, reader io.Reader) (net.Port, error) { if _, err := b.ReadFullFrom(reader, 2); err != nil { return 0, err } return net.PortFromBytes(b.BytesFrom(-2)), nil } +func writePort(writer io.Writer, port net.Port) error { + return common.Error2(serial.WriteUint16(writer, port.Value())) +} + func maybeIPPrefix(b byte) bool { return b == '[' || (b >= '0' && b <= '9') } @@ -73,7 +164,13 @@ func isValidDomain(d string) bool { return true } -func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Address, error) { +type addressParser struct { + addrTypeMap [16]net.AddressFamily + addrByteMap [16]byte + typeParser AddressTypeParser +} + +func (p *addressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Address, error) { if _, err := b.ReadFullFrom(reader, 1); err != nil { return nil, err } @@ -83,8 +180,12 @@ func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Addres addrType = p.typeParser(addrType) } - addrFamily, valid := p.addrTypeMap[addrType] - if !valid { + if addrType >= 16 { + return nil, newError("unknown address type: ", addrType) + } + + addrFamily := p.addrTypeMap[addrType] + if addrFamily == net.AddressFamily(afInvalid) { return nil, newError("unknown address type: ", addrType) } @@ -123,93 +224,34 @@ func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Addres } } -// ReadAddressPort reads address and port from the given input. -func (p *AddressParser) ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error) { - if buffer == nil { - buffer = buf.New() - defer buffer.Release() - } - - var addr net.Address - var port net.Port - - pTask := func() error { - lp, err := p.readPort(buffer, input) - if err != nil { - return err - } - port = lp - return nil - } - - aTask := func() error { - a, err := p.readAddress(buffer, input) - if err != nil { - return err - } - addr = a - return nil - } - - var err error - - if p.portFirst { - err = task.Run(task.Sequential(pTask, aTask))() - } else { - err = task.Run(task.Sequential(aTask, pTask))() - } - - if err != nil { - return nil, 0, err - } - - return addr, port, nil -} - -func (p *AddressParser) writePort(writer io.Writer, port net.Port) error { - return common.Error2(serial.WriteUint16(writer, port.Value())) -} - -func (p *AddressParser) writeAddress(writer io.Writer, address net.Address) error { - tb, valid := p.addrByteMap[address.Family()] - if !valid { +func (p *addressParser) writeAddress(writer io.Writer, address net.Address) error { + tb := p.addrByteMap[address.Family()] + if tb == afInvalid { return newError("unknown address family", address.Family()) } switch address.Family() { case net.AddressFamilyIPv4, net.AddressFamilyIPv6: - return task.Run(task.Sequential(func() error { - return common.Error2(writer.Write([]byte{tb})) - }, func() error { - return common.Error2(writer.Write(address.IP())) - }))() + if _, err := writer.Write([]byte{tb}); err != nil { + return err + } + if _, err := writer.Write(address.IP()); err != nil { + return err + } case net.AddressFamilyDomain: domain := address.Domain() if isDomainTooLong(domain) { return newError("Super long domain is not supported: ", domain) } - return task.Run(task.Sequential(func() error { - return common.Error2(writer.Write([]byte{tb, byte(len(domain))})) - }, func() error { - return common.Error2(writer.Write([]byte(domain))) - }))() + if _, err := writer.Write([]byte{tb, byte(len(domain))}); err != nil { + return err + } + if _, err := writer.Write([]byte(domain)); err != nil { + return err + } default: panic("Unknown family type.") } -} - -// WriteAddressPort writes address and port into the given writer. -func (p *AddressParser) WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error { - pTask := func() error { - return p.writePort(writer, port) - } - aTask := func() error { - return p.writeAddress(writer, addr) - } - - if p.portFirst { - return task.Run(task.Sequential(pTask, aTask))() - } - - return task.Run(task.Sequential(aTask, pTask))() + + return nil } diff --git a/common/protocol/address_test.go b/common/protocol/address_test.go index c9ea0e174..9f9dd97ef 100644 --- a/common/protocol/address_test.go +++ b/common/protocol/address_test.go @@ -35,6 +35,12 @@ func TestAddressReading(t *testing.T) { Address: net.IPAddress([]byte{0, 0, 0, 0}), Port: net.Port(53), }, + { + Options: []AddressOption{AddressFamilyByte(0x01, net.AddressFamilyIPv4), PortThenAddress()}, + Input: []byte{0, 53, 1, 0, 0, 0, 0}, + Address: net.IPAddress([]byte{0, 0, 0, 0}), + Port: net.Port(53), + }, { Options: []AddressOption{AddressFamilyByte(0x01, net.AddressFamilyIPv4)}, Input: []byte{1, 0, 0, 0, 0}, @@ -134,3 +140,102 @@ func TestAddressWriting(t *testing.T) { } } } + +func BenchmarkAddressReadingIPv4(b *testing.B) { + parser := NewAddressParser(AddressFamilyByte(0x01, net.AddressFamilyIPv4)) + cache := buf.New() + defer cache.Release() + + payload := buf.New() + defer payload.Release() + + raw := []byte{1, 0, 0, 0, 0, 0, 53} + payload.Write(raw) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, err := parser.ReadAddressPort(cache, payload) + common.Must(err) + cache.Clear() + payload.Clear() + payload.Extend(int32(len(raw))) + } +} + +func BenchmarkAddressReadingIPv6(b *testing.B) { + parser := NewAddressParser(AddressFamilyByte(0x04, net.AddressFamilyIPv6)) + cache := buf.New() + defer cache.Release() + + payload := buf.New() + defer payload.Release() + + raw := []byte{4, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 80} + payload.Write(raw) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, err := parser.ReadAddressPort(cache, payload) + common.Must(err) + cache.Clear() + payload.Clear() + payload.Extend(int32(len(raw))) + } +} + +func BenchmarkAddressReadingDomain(b *testing.B) { + parser := NewAddressParser(AddressFamilyByte(0x03, net.AddressFamilyDomain)) + cache := buf.New() + defer cache.Release() + + payload := buf.New() + defer payload.Release() + + raw := []byte{3, 9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0, 80} + payload.Write(raw) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, err := parser.ReadAddressPort(cache, payload) + common.Must(err) + cache.Clear() + payload.Clear() + payload.Extend(int32(len(raw))) + } +} + +func BenchmarkAddressWritingIPv4(b *testing.B) { + parser := NewAddressParser(AddressFamilyByte(0x01, net.AddressFamilyIPv4)) + writer := buf.New() + defer writer.Release() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + common.Must(parser.WriteAddressPort(writer, net.LocalHostIP, net.Port(80))) + writer.Clear() + } +} + +func BenchmarkAddressWritingIPv6(b *testing.B) { + parser := NewAddressParser(AddressFamilyByte(0x04, net.AddressFamilyIPv6)) + writer := buf.New() + defer writer.Release() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + common.Must(parser.WriteAddressPort(writer, net.LocalHostIPv6, net.Port(80))) + writer.Clear() + } +} + +func BenchmarkAddressWritingDomain(b *testing.B) { + parser := NewAddressParser(AddressFamilyByte(0x02, net.AddressFamilyDomain)) + writer := buf.New() + defer writer.Release() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + common.Must(parser.WriteAddressPort(writer, net.DomainAddress("www.v2ray.com"), net.Port(80))) + writer.Clear() + } +}