From 888e1b9c7df69e6ee8c5272f31c629535b548925 Mon Sep 17 00:00:00 2001 From: Colin Henry Date: Wed, 14 Sep 2022 15:39:41 -0700 Subject: [PATCH] new functions to extract certain types of parameters --- net/http/params.go | 112 +++++++++++++++ net/http/params_test.go | 312 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 424 insertions(+) create mode 100644 net/http/params.go create mode 100644 net/http/params_test.go diff --git a/net/http/params.go b/net/http/params.go new file mode 100644 index 0000000..085b26f --- /dev/null +++ b/net/http/params.go @@ -0,0 +1,112 @@ +package http + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/url" + "strconv" + "strings" +) + +func PathParam(ctx context.Context, Param func(ctx context.Context, paramName string) string, paramName string, required bool, dt string) (p any, err error) { + s := Param(ctx, paramName) + if s == "" && required { + return nil, errors.New("missing required parameter") + } + + switch dt { + case "int64": + p, err = strconv.ParseInt(s, 10, 64) + case "int32": + var x int64 + x, err = strconv.ParseInt(s, 10, 32) + p = int32(x) + case "string": + p = s + default: + err = errors.New("no match for type") + } + + return +} + +func BodyParam(body io.ReadCloser, p any, v func(p any) error) (err error) { + d := json.NewDecoder(body) + if err = d.Decode(p); err == nil { + err = v(p) + } + return +} + +func mappedParam(m map[string][]string, paramName string, required bool, dt string) (p any, err error) { + + var s string + q, exists := m[paramName] + if !exists { // intentionally left empty + } else if len(q) > 1 { + s = strings.Join(q, ",") + } else { + s = q[0] + } + + if s == "" && required { + return nil, errors.New("missing required parameter") + } + + switch dt { + case "int64": + p, err = strconv.ParseInt(s, 10, 64) + case "int32": + var x int64 + x, err = strconv.ParseInt(s, 10, 32) + p = int32(x) + case "bool": + var b bool + b, err = strconv.ParseBool(s) + p = bool(b) + case "string": + p = s + case "[]int64": + str := strings.Split(s, ",") + ints := make([]int64, len(str)) + for i, s := range str { + if v, err := strconv.ParseInt(s, 10, 64); err != nil { + return nil, err + } else { + ints[i] = v + } + } + p = ints + case "[]int32": + str := strings.Split(s, ",") + ints := make([]int32, len(str)) + for i, s := range str { + if v, err := strconv.ParseInt(s, 10, 32); err != nil { + return nil, err + } else { + ints[i] = int32(v) + } + } + p = ints + case "[]string": + p = strings.Split(s, ",") + default: + err = errors.New("no match for type") + } + return +} + +func QueryParam(query url.Values, paramName string, required bool, dt string) (p any, err error) { + return mappedParam(query, paramName, required, dt) +} + +func HeaderParam(h http.Header, paramName string, required bool, dt string) (p any, err error) { + return mappedParam(h, paramName, required, dt) +} + +func FormParam(form url.Values, paramName string, required bool, dt string) (p any, err error) { + return mappedParam(form, paramName, required, dt) +} diff --git a/net/http/params_test.go b/net/http/params_test.go new file mode 100644 index 0000000..8317495 --- /dev/null +++ b/net/http/params_test.go @@ -0,0 +1,312 @@ +package http + +import ( + "context" + "errors" + "io" + "net/url" + "reflect" + "strings" + "testing" +) + +func TestPathParam(t *testing.T) { + type args struct { + ctx context.Context + Param func(ctx context.Context, paramName string) string + paramName string + required bool + dt string + } + tests := []struct { + name string + args args + wantP any + wantErr bool + }{ + { + name: "test int64 parse", + args: args{ + context.WithValue(context.Background(), contextKey("int64id"), "123"), + Param, + "int64id", + true, + "int64", + }, + wantP: int64(123), + wantErr: false, + }, + { + name: "test int32 parse", + args: args{ + context.WithValue(context.Background(), contextKey("int32id"), "123"), + Param, + "int32id", + true, + "int32", + }, + wantP: int32(123), + wantErr: false, + }, + { + name: "test string parse", + args: args{ + context.WithValue(context.Background(), contextKey("stringid"), "foo"), + Param, + "stringid", + true, + "string", + }, + wantP: string("foo"), + wantErr: false, + }, + { + name: "test missing required parameter", + args: args{ + context.WithValue(context.Background(), contextKey("stringid"), ""), + Param, + "stringid", + true, + "string", + }, + wantP: nil, + wantErr: true, + }, + { + name: "test unknown type parameter", + args: args{ + context.WithValue(context.Background(), contextKey("stringid"), "foo"), + Param, + "stringid", + true, + "not_a_real_type", + }, + wantP: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotP, err := PathParam(tt.args.ctx, tt.args.Param, tt.args.paramName, tt.args.required, tt.args.dt) + if (err != nil) != tt.wantErr { + t.Errorf("PathParam() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotP, tt.wantP) { + t.Errorf("PathParam() = %v, want %v", gotP, tt.wantP) + } + }) + } +} + +func toValues(s string) url.Values { + v, _ := url.ParseQuery(s) + return v +} +func TestMappedParam(t *testing.T) { + + type args struct { + query url.Values + paramName string + required bool + dt string + } + tests := []struct { + name string + args args + wantP any + wantErr bool + }{ + { + name: "test int64 parse", + args: args{ + toValues("x=123"), + "x", + true, + "int64", + }, + wantP: int64(123), + wantErr: false, + }, + { + name: "test int32 parse", + args: args{ + toValues("x=123"), + "x", + true, + "int32", + }, + wantP: int32(123), + wantErr: false, + }, + { + name: "test bool parse", + args: args{ + toValues("x=true"), + "x", + true, + "bool", + }, + wantP: bool(true), + wantErr: false, + }, + { + name: "test string parse", + args: args{ + toValues("x=foobar"), + "x", + true, + "string", + }, + wantP: string("foobar"), + wantErr: false, + }, + { + name: "test []int64 parse", + args: args{ + toValues("x=123&x=456"), + "x", + true, + "[]int64", + }, + wantP: []int64{int64(123), int64(456)}, + wantErr: false, + }, + { + name: "test []int64 bad parse", + args: args{ + toValues("x=123&x=4q56"), + "x", + true, + "[]int64", + }, + wantP: nil, + wantErr: true, + }, + { + name: "test []int32 parse", + args: args{ + toValues("x=123&x=456"), + "x", + true, + "[]int32", + }, + wantP: []int32{int32(123), int32(456)}, + wantErr: false, + }, + { + name: "test []int32 bad parse", + args: args{ + toValues("x=123&x=4q56"), + "x", + true, + "[]int32", + }, + wantP: nil, + wantErr: true, + }, + { + name: "test []string parse", + args: args{ + toValues("x=foo&x=bar"), + "x", + true, + "[]string", + }, + wantP: []string{"foo", "bar"}, + wantErr: false, + }, + { + name: "test missing required parameter", + args: args{ + toValues("y=hello"), + "x", + true, + "string", + }, + wantP: nil, + wantErr: true, + }, + { + name: "test unknown type parameter", + args: args{ + toValues("x=hello"), + "x", + true, + "not_a_real_type", + }, + wantP: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotP, err := mappedParam(tt.args.query, tt.args.paramName, tt.args.required, tt.args.dt) + if (err != nil) != tt.wantErr { + t.Errorf("QueryParam() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotP, tt.wantP) { + t.Errorf("QueryParam() = %v, want %v", gotP, tt.wantP) + } + }) + } +} + +func TestBodyParam(t *testing.T) { + type args struct { + body io.ReadCloser + p any + v func(p any) error + } + + type x struct{} + var y x + + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "test happy path", + args: args{ + body: io.NopCloser(strings.NewReader("{}")), + p: &y, + v: func(p any) error { + return nil + }, + }, + wantErr: false, + }, + { + name: "test bad json", + args: args{ + body: io.NopCloser(strings.NewReader("}")), + p: &y, + v: func(p any) error { + return nil + }, + }, + wantErr: true, + }, + { + name: "test validation failed", + args: args{ + body: io.NopCloser(strings.NewReader("{}")), + p: &y, + v: func(p any) error { + return errors.New("validation failed") + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := BodyParam(tt.args.body, tt.args.p, tt.args.v); (err != nil) != tt.wantErr { + t.Errorf("BodyParam() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}