diff --git a/common/net/port.go b/common/net/port.go index 1e665550a..97c3b1126 100644 --- a/common/net/port.go +++ b/common/net/port.go @@ -1,15 +1,38 @@ package net import ( + "errors" + "strconv" + "github.com/v2ray/v2ray-core/common/serial" ) +var ( + // ErrorInvalidPortRage indicates an error during port range parsing. + ErrorInvalidPortRange = errors.New("Invalid port range.") +) + type Port serial.Uint16Literal func PortFromBytes(port []byte) Port { return Port(serial.BytesLiteral(port).Uint16Value()) } +func PortFromInt(v int) (Port, error) { + if v <= 0 || v > 65535 { + return Port(0), ErrorInvalidPortRange + } + return Port(v), nil +} + +func PortFromString(s string) (Port, error) { + v, err := strconv.Atoi(s) + if err != nil { + return Port(0), ErrorInvalidPortRange + } + return PortFromInt(v) +} + func (this Port) Value() uint16 { return uint16(this) } diff --git a/common/net/port_json.go b/common/net/port_json.go index c2458e271..51137a038 100644 --- a/common/net/port_json.go +++ b/common/net/port_json.go @@ -4,66 +4,66 @@ package net import ( "encoding/json" - "errors" - "strconv" "strings" "github.com/v2ray/v2ray-core/common/log" - "github.com/v2ray/v2ray-core/common/serial" ) -var ( - ErrorInvalidPortRange = errors.New("Invalid port range.") -) +func parseIntPort(data []byte) (Port, error) { + var intPort int + err := json.Unmarshal(data, &intPort) + if err != nil { + return Port(0), err + } + return PortFromInt(intPort) +} +func parseStringPort(data []byte) (Port, Port, error) { + var s string + err := json.Unmarshal(data, &s) + if err != nil { + return Port(0), Port(0), err + } + pair := strings.SplitN(s, "-", 2) + if len(pair) == 0 { + return Port(0), Port(0), ErrorInvalidPortRange + } + if len(pair) == 1 { + port, err := PortFromString(pair[0]) + return port, port, err + } + + fromPort, err := PortFromString(pair[0]) + if err != nil { + return Port(0), Port(0), err + } + toPort, err := PortFromString(pair[1]) + if err != nil { + return Port(0), Port(0), err + } + return fromPort, toPort, nil +} + +// UnmarshalJSON implements encoding/json.Unmarshaler.UnmarshalJSON func (this *PortRange) UnmarshalJSON(data []byte) error { - var maybeint int - err := json.Unmarshal(data, &maybeint) + port, err := parseIntPort(data) if err == nil { - if maybeint <= 0 || maybeint >= 65535 { - log.Error("Invalid port [", serial.BytesLiteral(data), "]") - return ErrorInvalidPortRange - } - this.From = Port(maybeint) - this.To = Port(maybeint) + this.From = port + this.To = port return nil } - var maybestring string - err = json.Unmarshal(data, &maybestring) + from, to, err := parseStringPort(data) if err == nil { - pair := strings.SplitN(maybestring, "-", 2) - if len(pair) == 1 { - value, err := strconv.Atoi(pair[0]) - if err != nil || value <= 0 || value >= 65535 { - log.Error("Invalid from port ", pair[0]) - return ErrorInvalidPortRange - } - this.From = Port(value) - this.To = Port(value) - return nil - } else if len(pair) == 2 { - from, err := strconv.Atoi(pair[0]) - if err != nil || from <= 0 || from >= 65535 { - log.Error("Invalid from port ", pair[0]) - return ErrorInvalidPortRange - } - this.From = Port(from) - - to, err := strconv.Atoi(pair[1]) - if err != nil || to <= 0 || to >= 65535 { - log.Error("Invalid to port ", pair[1]) - return ErrorInvalidPortRange - } - this.To = Port(to) - - if this.From > this.To { - log.Error("Invalid port range ", this.From, " -> ", this.To) - return ErrorInvalidPortRange - } - return nil + this.From = from + this.To = to + if this.From > this.To { + log.Error("Invalid port range ", this.From, " -> ", this.To) + return ErrorInvalidPortRange } + return nil } + log.Error("Invalid port range: ", string(data)) return ErrorInvalidPortRange }