diff --git a/net/http/params.go b/net/http/params.go index 085b26f..2fcca93 100644 --- a/net/http/params.go +++ b/net/http/params.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "net/http" "net/url" @@ -11,23 +12,28 @@ import ( "strings" ) -func PathParam(ctx context.Context, Param func(ctx context.Context, paramName string) string, paramName string, required bool, dt string) (p any, err error) { +func PathParam(ctx context.Context, Param func(ctx context.Context, paramName string) string, p interface{}, paramName string, required bool) (err error) { s := Param(ctx, paramName) if s == "" && required { - return nil, errors.New("missing required parameter") + switch v := p.(type) { + case *string: + p = v + p = nil + } + return errors.New("missing required parameter") } - switch dt { - case "int64": - p, err = strconv.ParseInt(s, 10, 64) - case "int32": + switch v := p.(type) { + case *int64: + *v, err = strconv.ParseInt(s, 10, 64) + case *int32: var x int64 x, err = strconv.ParseInt(s, 10, 32) - p = int32(x) - case "string": - p = s + *v = int32(x) + case *string: + *v = s default: - err = errors.New("no match for type") + err = fmt.Errorf("no match for pointer type %T", v) } return @@ -41,8 +47,7 @@ func BodyParam(body io.ReadCloser, p any, v func(p any) error) (err error) { return } -func mappedParam(m map[string][]string, paramName string, required bool, dt string) (p any, err error) { - +func mappedParam(m map[string][]string, paramName string, p interface{}, required bool) (err error) { var s string q, exists := m[paramName] if !exists { // intentionally left empty @@ -53,60 +58,57 @@ func mappedParam(m map[string][]string, paramName string, required bool, dt stri } if s == "" && required { - return nil, errors.New("missing required parameter") + return errors.New("missing required parameter") } - switch dt { - case "int64": - p, err = strconv.ParseInt(s, 10, 64) - case "int32": + switch v := p.(type) { + case *int64: + *v, err = strconv.ParseInt(s, 10, 64) + case *int32: var x int64 x, err = strconv.ParseInt(s, 10, 32) - p = int32(x) - case "bool": - var b bool - b, err = strconv.ParseBool(s) - p = bool(b) - case "string": - p = s - case "[]int64": + *v = int32(x) + case *bool: + *v, err = strconv.ParseBool(s) + case *string: + *v = s + case *[]int64: str := strings.Split(s, ",") - ints := make([]int64, len(str)) - for i, s := range str { - if v, err := strconv.ParseInt(s, 10, 64); err != nil { - return nil, err + for _, s := range str { + if e, err := strconv.ParseInt(s, 10, 64); err != nil { + return err } else { - ints[i] = v + *v = append(*v, e) + // ints[i] = e } } - p = ints - case "[]int32": + case *[]int32: str := strings.Split(s, ",") - ints := make([]int32, len(str)) - for i, s := range str { - if v, err := strconv.ParseInt(s, 10, 32); err != nil { - return nil, err + for _, s := range str { + if e, err := strconv.ParseInt(s, 10, 32); err != nil { + return err } else { - ints[i] = int32(v) + *v = append(*v, int32(e)) + // ints[i] = e } } - p = ints - case "[]string": - p = strings.Split(s, ",") + case *[]string: + *v = strings.Split(s, ",") default: - err = errors.New("no match for type") + err = fmt.Errorf("no match for pointer type %T", v) } + return } -func QueryParam(query url.Values, paramName string, required bool, dt string) (p any, err error) { - return mappedParam(query, paramName, required, dt) +func QueryParam(query url.Values, paramName string, p interface{}, required bool) (err error) { + return mappedParam(query, paramName, p, required) } -func HeaderParam(h http.Header, paramName string, required bool, dt string) (p any, err error) { - return mappedParam(h, paramName, required, dt) +func HeaderParam(h http.Header, paramName string, p interface{}, required bool) (err error) { + return mappedParam(h, paramName, p, required) } -func FormParam(form url.Values, paramName string, required bool, dt string) (p any, err error) { - return mappedParam(form, paramName, required, dt) +func FormParam(form url.Values, paramName string, p interface{}, required bool) (err error) { + return mappedParam(form, paramName, p, required) } diff --git a/net/http/params_test.go b/net/http/params_test.go index 8317495..78d0f59 100644 --- a/net/http/params_test.go +++ b/net/http/params_test.go @@ -11,246 +11,213 @@ import ( ) func TestPathParam(t *testing.T) { - type args struct { - ctx context.Context - Param func(ctx context.Context, paramName string) string - paramName string - required bool - dt string - } - tests := []struct { - name string - args args - wantP any - wantErr bool - }{ - { - name: "test int64 parse", - args: args{ - context.WithValue(context.Background(), contextKey("int64id"), "123"), - Param, - "int64id", - true, - "int64", - }, - wantP: int64(123), - wantErr: false, - }, - { - name: "test int32 parse", - args: args{ - context.WithValue(context.Background(), contextKey("int32id"), "123"), - Param, - "int32id", - true, - "int32", - }, - wantP: int32(123), - wantErr: false, - }, - { - name: "test string parse", - args: args{ - context.WithValue(context.Background(), contextKey("stringid"), "foo"), - Param, - "stringid", - true, - "string", - }, - wantP: string("foo"), - wantErr: false, - }, - { - name: "test missing required parameter", - args: args{ - context.WithValue(context.Background(), contextKey("stringid"), ""), - Param, - "stringid", - true, - "string", - }, - wantP: nil, - wantErr: true, - }, - { - name: "test unknown type parameter", - args: args{ - context.WithValue(context.Background(), contextKey("stringid"), "foo"), - Param, - "stringid", - true, - "not_a_real_type", - }, - wantP: nil, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotP, err := PathParam(tt.args.ctx, tt.args.Param, tt.args.paramName, tt.args.required, tt.args.dt) - if (err != nil) != tt.wantErr { - t.Errorf("PathParam() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(gotP, tt.wantP) { - t.Errorf("PathParam() = %v, want %v", gotP, tt.wantP) - } - }) - } + t.Run("test int64 parse", func(t *testing.T) { + var p int64 + err := PathParam(context.WithValue(context.Background(), contextKey("int64id"), "123"), Param, &p, "int64id", true) + if (err != nil) != false { + t.Errorf("PathParam() error = %v, wantErr %v", err, false) + return + } + if !reflect.DeepEqual(p, int64(123)) { + t.Errorf("PathParam() = %v, want %v", p, int64(123)) + } + }) + + t.Run("test int32 parse", func(t *testing.T) { + var p int32 + err := PathParam(context.WithValue(context.Background(), contextKey("int32id"), "123"), Param, &p, "int32id", true) + if (err != nil) != false { + t.Errorf("PathParam() error = %v, wantErr %v", err, false) + return + } + if !reflect.DeepEqual(p, int32(123)) { + t.Errorf("PathParam() = %v, want %v", p, int32(123)) + } + }) + + t.Run("test string parse", func(t *testing.T) { + var p string + err := PathParam( + context.WithValue( + context.Background(), + contextKey("stringid"), + "foo"), + Param, + &p, + "stringid", + true) + if (err != nil) != false { + t.Errorf("PathParam() error = %v, wantErr %v", err, false) + return + } + if !reflect.DeepEqual(p, "foo") { + t.Errorf("PathParam() = %v, want %v", p, "foo") + } + }) + + t.Run("test missing required parameter", func(t *testing.T) { + var p string + err := PathParam(context.Background(), Param, &p, "stringid", true) + if (err != nil) != true { + t.Errorf("PathParam() error = %v, wantErr %v", err, true) + return + } + if !reflect.DeepEqual(p, "") { + t.Errorf("PathParam() = %v, want %v", p, "") + } + }) + + t.Run("test unknown type parameter", func(t *testing.T) { + var p complex64 + err := PathParam(context.WithValue(context.Background(), contextKey("stringid"), "foo"), + Param, p, "stringid", true) + if (err != nil) != true { + t.Errorf("PathParam() error = %v, wantErr %v", err, true) + return + } + if !reflect.DeepEqual(p, complex64(0)) { + t.Errorf("PathParam() = %v, want %v", p, complex64(0)) + } + }) } func toValues(s string) url.Values { v, _ := url.ParseQuery(s) return v } -func TestMappedParam(t *testing.T) { - type args struct { - query url.Values - paramName string - required bool - dt string - } - tests := []struct { - name string - args args - wantP any - wantErr bool - }{ - { - name: "test int64 parse", - args: args{ - toValues("x=123"), - "x", - true, - "int64", - }, - wantP: int64(123), - wantErr: false, - }, - { - name: "test int32 parse", - args: args{ - toValues("x=123"), - "x", - true, - "int32", - }, - wantP: int32(123), - wantErr: false, - }, - { - name: "test bool parse", - args: args{ - toValues("x=true"), - "x", - true, - "bool", - }, - wantP: bool(true), - wantErr: false, - }, - { - name: "test string parse", - args: args{ - toValues("x=foobar"), - "x", - true, - "string", - }, - wantP: string("foobar"), - wantErr: false, - }, - { - name: "test []int64 parse", - args: args{ - toValues("x=123&x=456"), - "x", - true, - "[]int64", - }, - wantP: []int64{int64(123), int64(456)}, - wantErr: false, - }, - { - name: "test []int64 bad parse", - args: args{ - toValues("x=123&x=4q56"), - "x", - true, - "[]int64", - }, - wantP: nil, - wantErr: true, - }, - { - name: "test []int32 parse", - args: args{ - toValues("x=123&x=456"), - "x", - true, - "[]int32", - }, - wantP: []int32{int32(123), int32(456)}, - wantErr: false, - }, - { - name: "test []int32 bad parse", - args: args{ - toValues("x=123&x=4q56"), - "x", - true, - "[]int32", - }, - wantP: nil, - wantErr: true, - }, - { - name: "test []string parse", - args: args{ - toValues("x=foo&x=bar"), - "x", - true, - "[]string", - }, - wantP: []string{"foo", "bar"}, - wantErr: false, - }, - { - name: "test missing required parameter", - args: args{ - toValues("y=hello"), - "x", - true, - "string", - }, - wantP: nil, - wantErr: true, - }, - { - name: "test unknown type parameter", - args: args{ - toValues("x=hello"), - "x", - true, - "not_a_real_type", - }, - wantP: nil, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotP, err := mappedParam(tt.args.query, tt.args.paramName, tt.args.required, tt.args.dt) - if (err != nil) != tt.wantErr { - t.Errorf("QueryParam() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(gotP, tt.wantP) { - t.Errorf("QueryParam() = %v, want %v", gotP, tt.wantP) - } - }) - } +func TestMappedParam(t *testing.T) { + t.Run("test int64 parse", func(t *testing.T) { + var p int64 + err := mappedParam(toValues("x=123"), "x", &p, true) + if (err != nil) != false { + t.Errorf("QueryParam() error = %v, wantErr %v", err, false) + return + } + if !reflect.DeepEqual(p, int64(123)) { + t.Errorf("QueryParam() = %v, want %v", p, int64(123)) + } + }) + + t.Run("test int32 parse", func(t *testing.T) { + var p int32 + err := mappedParam(toValues("x=123"), "x", &p, true) + if (err != nil) != false { + t.Errorf("QueryParam() error = %v, wantErr %v", err, false) + return + } + if !reflect.DeepEqual(p, int32(123)) { + t.Errorf("QueryParam() = %v, want %v", p, int32(123)) + } + }) + + t.Run("test bool parse", func(t *testing.T) { + var p bool + err := mappedParam(toValues("x=true"), "x", &p, true) + if (err != nil) != false { + t.Errorf("QueryParam() error = %v, wantErr %v", err, false) + return + } + if !reflect.DeepEqual(p, true) { + t.Errorf("QueryParam() = %v, want %v", p, true) + } + }) + + t.Run("test string parse", func(t *testing.T) { + var p string + err := mappedParam(toValues("x=foobar"), "x", &p, true) + if (err != nil) != false { + t.Errorf("QueryParam() error = %v, wantErr %v", err, false) + return + } + if !reflect.DeepEqual(p, "foobar") { + t.Errorf("QueryParam() = %v, want %v", p, "foobar") + } + }) + + t.Run("test []int64 parse", func(t *testing.T) { + var p []int64 + err := mappedParam(toValues("x=123&x=456"), "x", &p, true) + if (err != nil) != false { + t.Errorf("QueryParam() error = %v, wantErr %v", err, false) + return + } + if !reflect.DeepEqual(p, []int64{int64(123), int64(456)}) { + t.Errorf("QueryParam() = %v, want %v", p, []int64{int64(123), int64(456)}) + } + }) + + t.Run("test []int64 bad parse", func(t *testing.T) { + var p []int64 + err := mappedParam(toValues("x=123&x=4q56"), "x", &p, true) + if (err != nil) != true { + t.Errorf("QueryParam() error = %v, wantErr %v", err, true) + return + } + if !reflect.DeepEqual(p, []int64{123}) { + t.Errorf("QueryParam() = %v, want %v", p, []int64{}) + } + }) + + t.Run("test []int32 parse", func(t *testing.T) { + var p []int32 + err := mappedParam(toValues("x=123&x=456"), "x", &p, true) + if (err != nil) != false { + t.Errorf("QueryParam() error = %v, wantErr %v", err, false) + return + } + if !reflect.DeepEqual(p, []int32{int32(123), int32(456)}) { + t.Errorf("QueryParam() = %v, want %v", p, []int32{int32(123), int32(456)}) + } + }) + + t.Run("test []int32 bad parse", func(t *testing.T) { + var p []int32 + err := mappedParam(toValues("x=123&x=4q56"), "x", &p, true) + if (err != nil) != true { + t.Errorf("QueryParam() error = %v, wantErr %v", err, true) + return + } + if !reflect.DeepEqual(p, []int32{123}) { + t.Errorf("QueryParam() = %v, want %v", p, []int32{123}) + } + }) + + t.Run("test []string parse", func(t *testing.T) { + var p []string + err := mappedParam(toValues("x=foo&x=bar"), "x", &p, true) + if (err != nil) != false { + t.Errorf("QueryParam() error = %v, wantErr %v", err, false) + return + } + if !reflect.DeepEqual(p, []string{"foo", "bar"}) { + t.Errorf("QueryParam() = %v, want %v", p, []string{"foo", "bar"}) + } + }) + + t.Run("test missing required parameter", func(t *testing.T) { + var p string + err := mappedParam(toValues("y=hello"), "x", &p, true) + if (err != nil) != true { + t.Errorf("QueryParam() error = %v, wantErr %v", err, true) + return + } + if !reflect.DeepEqual(p, "") { + t.Errorf("QueryParam() = %v, want %v", p, "") + } + }) + + t.Run("test unknown type parameter", func(t *testing.T) { + var p complex64 + err := mappedParam(toValues("x=hello"), "x", &p, true) + if (err != nil) != true { + t.Errorf("QueryParam() error = %v, wantErr %v", err, true) + return + } + if !reflect.DeepEqual(p, complex64(0)) { + t.Errorf("QueryParam() = %v, want %v", p, complex64(0)) + } + }) } func TestBodyParam(t *testing.T) { @@ -310,3 +277,11 @@ func TestBodyParam(t *testing.T) { }) } } + +func Ptr[T any](v T) *T { + return &v +} + +func NilString() *string { + return nil +}