Skip to content

Commit b755d9b

Browse files
committed
Merge remote-tracking branch 'origin/master' into develop
2 parents a24191a + bc4cdac commit b755d9b

File tree

6 files changed

+208
-10
lines changed

6 files changed

+208
-10
lines changed

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[bumpversion]
22
commit = True
33
tag = True
4-
current_version = 0.3.0
4+
current_version = 0.3.1
55

66
[bumpversion:file:encode.go]
77

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
strategy:
1414
fail-fast: false
1515
matrix:
16-
go: ['1.11', '1.12', '1.13', '1.14', '1.15', '1.16', '1.17']
16+
go: ['1.11', '1.12', '1.13', '1.14', '1.15', '1.16', '1.17', '1.18']
1717

1818
steps:
1919
- uses: actions/setup-go@v1

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Changelog
22

3+
## [0.3.1] (2022-04-09)
4+
5+
* fix Decode: don't fill value for struct fields that don't exist in header
6+
7+
38
## [0.3.0] (2021-01-24)
49

510
* add `func Decode(header http.Header, v interface{}) error` to support decoding headers into struct
@@ -21,3 +26,4 @@
2126
[0.2.0]: https://github.com/mozillazg/go-httpheader/compare/v0.1.0...v0.2.0
2227
[0.2.1]: https://github.com/mozillazg/go-httpheader/compare/v0.2.0...v0.2.1
2328
[0.3.0]: https://github.com/mozillazg/go-httpheader/compare/v0.2.1...v0.3.0
29+
[0.3.1]: https://github.com/mozillazg/go-httpheader/compare/v0.3.0...v0.3.1

decode.go

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ type Decoder interface {
2020
func Decode(header http.Header, v interface{}) error {
2121
val := reflect.ValueOf(v)
2222
if val.Kind() != reflect.Ptr || val.IsNil() {
23-
return fmt.Errorf("v should be point and should not be nil")
23+
return fmt.Errorf("v should be a pointer and should not be nil")
2424
}
2525

2626
for val.Kind() == reflect.Ptr {
@@ -33,7 +33,12 @@ func Decode(header http.Header, v interface{}) error {
3333
return parseValue(header, val)
3434
}
3535

36+
// parseValue populates the struct fields in val from the header fields.
37+
// Embedded structs are followed recursively (using the rules defined in the
38+
// Values function documentation) breadth-first.
3639
func parseValue(header http.Header, val reflect.Value) error {
40+
var embedded []reflect.Value
41+
3742
typ := val.Type()
3843
for i := 0; i < typ.NumField(); i++ {
3944
sf := typ.Field(i)
@@ -49,6 +54,8 @@ func parseValue(header http.Header, val reflect.Value) error {
4954
name, opts := parseTag(tag)
5055
if name == "" {
5156
if sf.Anonymous && sv.Kind() == reflect.Struct {
57+
// save embedded struct for later processing
58+
embedded = append(embedded, sv)
5259
continue
5360
}
5461
name = sf.Name
@@ -73,16 +80,24 @@ func parseValue(header http.Header, val reflect.Value) error {
7380
}
7481

7582
if sv.Kind() == reflect.Ptr {
83+
valArr, exist := headerValues(header, name)
84+
if !exist {
85+
continue
86+
}
7687
ve := reflect.New(sv.Type().Elem())
77-
if err := fillValues(ve, opts, headerValues(header, name)); err != nil {
88+
if err := fillValues(ve, opts, valArr); err != nil {
7889
return err
7990
}
8091
sv.Set(ve)
8192
continue
8293
}
8394

8495
if sv.Type() == timeType {
85-
if err := fillValues(sv, opts, headerValues(header, name)); err != nil {
96+
valArr, exist := headerValues(header, name)
97+
if !exist {
98+
continue
99+
}
100+
if err := fillValues(sv, opts, valArr); err != nil {
86101
return err
87102
}
88103
continue
@@ -95,7 +110,35 @@ func parseValue(header http.Header, val reflect.Value) error {
95110
continue
96111
}
97112

98-
if err := fillValues(sv, opts, headerValues(header, name)); err != nil {
113+
if sv.Kind() != reflect.Slice && sv.Kind() != reflect.Array && sv.Kind() != reflect.Interface {
114+
vals := header.Values(name)
115+
if len(vals) > 0 {
116+
v := vals[0]
117+
vals = vals[1:]
118+
119+
if err := fillValues(sv, opts, []string{v}); err != nil {
120+
return err
121+
}
122+
123+
header.Del(name)
124+
for _, v := range vals {
125+
header.Add(name, v)
126+
}
127+
}
128+
continue
129+
}
130+
131+
valArr, exist := headerValues(header, name)
132+
if !exist {
133+
continue
134+
}
135+
if err := fillValues(sv, opts, valArr); err != nil {
136+
return err
137+
}
138+
}
139+
140+
for _, f := range embedded {
141+
if err := parseValue(header, f); err != nil {
99142
return err
100143
}
101144
}
@@ -249,6 +292,7 @@ func fillValues(sv reflect.Value, opts tagOptions, valArr []string) error {
249292
return nil
250293
}
251294

252-
func headerValues(h http.Header, key string) []string {
253-
return textproto.MIMEHeader(h)[textproto.CanonicalMIMEHeaderKey(key)]
295+
func headerValues(h http.Header, key string) ([]string, bool) {
296+
vs, ok := textproto.MIMEHeader(h)[textproto.CanonicalMIMEHeaderKey(key)]
297+
return vs, ok
254298
}

decode_test.go

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ func TestDecodeHeader_UnmarshalerWithNilPointer(t *testing.T) {
136136

137137
type simpleStruct struct {
138138
Foo string
139+
Bar int
139140
}
140141

141142
type fullTypeStruct struct {
@@ -243,7 +244,7 @@ func TestDecodeHeader_more_data_type(t *testing.T) {
243244
t.Errorf("Decode returned error: %#v", err)
244245
}
245246
if !reflect.DeepEqual(want, got) {
246-
t.Errorf("want %#v, but got %#v", want, got)
247+
t.Errorf("want/got:\n%#v\n%#v", want, got)
247248
}
248249
}
249250

@@ -388,6 +389,42 @@ func Test_fillValues_errors(t *testing.T) {
388389
},
389390
wantErr: true,
390391
},
392+
{
393+
name: "slice",
394+
args: args{
395+
sv: reflect.New(reflect.TypeOf([]int{})),
396+
opts: tagOptions{},
397+
valArr: []string{"a"},
398+
},
399+
wantErr: true,
400+
},
401+
{
402+
name: "array",
403+
args: args{
404+
sv: reflect.New(reflect.TypeOf([1]int{})),
405+
opts: tagOptions{},
406+
valArr: []string{"a"},
407+
},
408+
wantErr: true,
409+
},
410+
{
411+
name: "time",
412+
args: args{
413+
sv: reflect.New(reflect.TypeOf(time.Time{})),
414+
opts: tagOptions{},
415+
valArr: []string{"a"},
416+
},
417+
wantErr: true,
418+
},
419+
{
420+
name: "time unix",
421+
args: args{
422+
sv: reflect.New(reflect.TypeOf(time.Time{})),
423+
opts: tagOptions{"unix"},
424+
valArr: []string{"a"},
425+
},
426+
wantErr: true,
427+
},
391428
}
392429
for _, tt := range tests {
393430
t.Run(tt.name, func(t *testing.T) {
@@ -397,3 +434,114 @@ func Test_fillValues_errors(t *testing.T) {
397434
})
398435
}
399436
}
437+
438+
func TestDecode_check_header_key_not_present_no_point(t *testing.T) {
439+
h := http.Header{}
440+
h.Set("Length", "100")
441+
442+
var got fullTypeStruct
443+
err := Decode(h, &got)
444+
if err != nil {
445+
t.Errorf("Decode returned error: %#v", err)
446+
}
447+
448+
var want fullTypeStruct
449+
if !reflect.DeepEqual(want, got) {
450+
t.Errorf("want %#v, but got %#v", want, got)
451+
}
452+
}
453+
454+
func TestDecode_check_header_key_not_present_point(t *testing.T) {
455+
type testStruct struct {
456+
A *string
457+
B *fullTypeStruct
458+
C *int
459+
D *[]string
460+
E *[2]string
461+
F interface{}
462+
G *time.Time
463+
}
464+
h := http.Header{}
465+
h.Set("Length", "100")
466+
467+
var got testStruct
468+
err := Decode(h, &got)
469+
if err != nil {
470+
t.Errorf("Decode returned error: %#v", err)
471+
}
472+
473+
var want testStruct
474+
if !reflect.DeepEqual(want, got) {
475+
t.Errorf("want %#v, but got %#v", want, got)
476+
}
477+
if got.A != nil || got.B != nil || got.C != nil || got.D != nil || got.E != nil || got.F != nil || got.G != nil {
478+
t.Error("all fields should be nil")
479+
}
480+
}
481+
482+
func TestDecode_error(t *testing.T) {
483+
h := http.Header{
484+
"Int": []string{"abc"},
485+
}
486+
var got fullTypeStruct
487+
err := Decode(h, &got)
488+
if err == nil {
489+
t.Errorf("expect error, got : %#v", got)
490+
}
491+
}
492+
493+
func TestDecodeHeader_embeddedStructs(t *testing.T) {
494+
tests := []struct {
495+
in http.Header
496+
decode func(http.Header) (interface{}, error)
497+
want interface{}
498+
}{
499+
{
500+
http.Header{"C": []string{"foo"}},
501+
func(h http.Header) (interface{}, error) {
502+
var a A
503+
err := Decode(h, &a)
504+
return a, err
505+
},
506+
A{B{C: "foo"}},
507+
},
508+
{
509+
http.Header{"C": []string{"foo"}},
510+
func(h http.Header) (interface{}, error) {
511+
var d D
512+
err := Decode(h, &d)
513+
return d, err
514+
},
515+
D{B: B{C: ""}, C: "foo"},
516+
},
517+
{
518+
http.Header{"C": []string{"foo", "bar"}},
519+
func(h http.Header) (interface{}, error) {
520+
var d D
521+
err := Decode(h, &d)
522+
return d, err
523+
},
524+
D{B: B{C: "bar"}, C: "foo"},
525+
},
526+
{
527+
http.Header{"C": []string{"foo", "bar"}},
528+
func(h http.Header) (interface{}, error) {
529+
var f F
530+
err := Decode(h, &f)
531+
return f, err
532+
},
533+
F{e{B: B{C: "bar"}, C: "foo"}}, // With unexported embed
534+
},
535+
}
536+
537+
for i, tt := range tests {
538+
v, err := tt.decode(tt.in)
539+
if err != nil {
540+
t.Errorf("%d. Header(%+v) returned error: %v", i, tt.in, err)
541+
}
542+
543+
if !reflect.DeepEqual(tt.want, v) {
544+
t.Errorf("%d. Header(%+v) returned/want:\n%#+v\n%#+v", i, tt.in, v, tt.want)
545+
}
546+
}
547+
}

encode.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import (
2929
const tagName = "header"
3030

3131
// Version ...
32-
const Version = "0.3.0"
32+
const Version = "0.3.1"
3333

3434
var timeType = reflect.TypeOf(time.Time{})
3535
var headerType = reflect.TypeOf(http.Header{})

0 commit comments

Comments
 (0)