diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 601ba0f1a..08ed6439a 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -59,7 +59,7 @@ jobs: cache: false # don't use cache for self-hosted runners - name: Unit Test - run: go test -covermode=atomic -coverprofile=coverage.txt ./... + run: go test -coverpkg=./... -coverprofile=coverage.txt ./... - uses: codecov/codecov-action@v5 with: diff --git a/go.mod b/go.mod index 0e8d1ef10..91d9fbfca 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ require ( github.com/bytedance/sonic v1.14.0 github.com/cloudwego/netpoll v0.7.0 github.com/fsnotify/fsnotify v1.5.4 - github.com/nyaruka/phonenumbers v1.0.55 github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.14.4 golang.org/x/sync v0.8.0 @@ -20,7 +19,6 @@ require ( github.com/cloudwego/base64x v0.1.5 // indirect github.com/cloudwego/gopkg v0.1.4 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/golang/protobuf v1.5.0 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect diff --git a/go.sum b/go.sum index 8c9c2c748..090ed49c5 100644 --- a/go.sum +++ b/go.sum @@ -18,11 +18,7 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= @@ -31,8 +27,6 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= -github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -106,8 +100,6 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/tagexpr/README.md b/internal/tagexpr/README.md index b248ac3b4..cead3658b 100644 --- a/internal/tagexpr/README.md +++ b/internal/tagexpr/README.md @@ -1,3 +1,3 @@ # go-tagexpr -originally from https://github.com/bytedance/go-tagexpr +originally from https://github.com/bytedance/go-tagexpr v2.9.2 diff --git a/internal/tagexpr/tagexpr.go b/internal/tagexpr/tagexpr.go index 3842797cd..f88c20718 100644 --- a/internal/tagexpr/tagexpr.go +++ b/internal/tagexpr/tagexpr.go @@ -820,11 +820,10 @@ func FakeBool(v interface{}) bool { } return bol default: - vv := dereferenceValue(reflect.ValueOf(v)) - if vv.IsValid() || vv.IsZero() { - return false - } - return true + // https://github.com/bytedance/go-tagexpr/blob/v2.9.2/tagexpr.go#L801 + // the original implementation either returns false or panics for default case + // we always return false for unsupported types to avoid introducing new behavior + return false } } diff --git a/internal/tagexpr/tagexpr_test.go b/internal/tagexpr/tagexpr_test.go index bdc93b639..bd83ccad3 100644 --- a/internal/tagexpr/tagexpr_test.go +++ b/internal/tagexpr/tagexpr_test.go @@ -16,6 +16,7 @@ package tagexpr_test import ( "encoding/json" + "errors" "fmt" "reflect" "strconv" @@ -869,3 +870,65 @@ func TestHertzIssue1410(t *testing.T) { t.Fatal(err) } } + +func TestFakeBool(t *testing.T) { + // Numeric types - zero values should be false, non-zero should be true + tests := []struct { + input interface{} + expected bool + }{ + // Float types + {float64(0), false}, + {float64(3.14), true}, + {float32(0), false}, + {float32(2.5), true}, + + // Integer types + {int(0), false}, {int(42), true}, + {int8(0), false}, {int8(127), true}, + {int16(0), false}, {int16(32767), true}, + {int32(0), false}, {int32(2147483647), true}, + {int64(0), false}, {int64(9223372036854775807), true}, + + // Unsigned integer types + {uint(0), false}, {uint(1), true}, + {uint8(0), false}, {uint8(255), true}, + {uint16(0), false}, {uint16(65535), true}, + {uint32(0), false}, {uint32(4294967295), true}, + {uint64(0), false}, {uint64(18446744073709551615), true}, + + // String type + {"", false}, + {"hello", true}, + + // Boolean type + {true, true}, + {false, false}, + + // Nil and error types + {nil, false}, + {errors.New("test"), false}, + + // Slice of interfaces - all elements must be truthy for true + {[]interface{}{}, true}, // empty slice -> true + {[]interface{}{1, "hello", true}, true}, // all truthy -> true + {[]interface{}{1, "", true}, false}, // one falsy -> false + {[]interface{}{0, "", false}, false}, // all falsy -> false + {[]interface{}{nil, nil}, false}, // nil values are falsy -> false + + // Unsupported types should return false + {struct{}{}, false}, + {new(int), false}, + {make(chan int), false}, + {func() {}, false}, + {map[string]int{}, false}, + {[3]int{1, 2, 3}, false}, + } + + for _, tt := range tests { + result := tagexpr.FakeBool(tt.input) + if result != tt.expected { + t.Errorf("FakeBool(%v) = %v; want %v", tt.input, result, tt.expected) + } + } +} diff --git a/internal/tagexpr/validator/README.md b/internal/tagexpr/validator/README.md index b3321a671..b827ced93 100644 --- a/internal/tagexpr/validator/README.md +++ b/internal/tagexpr/validator/README.md @@ -1,204 +1,3 @@ -# validator [![GoDoc](https://img.shields.io/badge/godoc-reference-blue.svg?style=flat-square)](http://godoc.org/github.com/bytedance/go-tagexpr/v2/validator) +# validator -A powerful validator that supports struct tag expression. - -## Feature - -- Support for a variety of common operator -- Support for accessing arrays, slices, members of the dictionary -- Support access to any field in the current structure -- Support access to nested fields, non-exported fields, etc. -- Support registers validator function expression -- Built-in len, sprintf, regexp, email, phone functions -- Support simple mode, or specify error message mode -- Use offset pointers to directly take values, better performance -- Required go version ≥1.9 - -## Example - -```go -package validator_test - -import ( - "fmt" - - vd "github.com/bytedance/go-tagexpr/v2/validator" -) - -func Example() { - type InfoRequest struct { - Name string `vd:"($!='Alice'||(Age)$==18) && regexp('\\w')"` - Age int `vd:"$>0"` - Email string `vd:"email($)"` - Phone1 string `vd:"phone($)"` - OtherPhones []string `vd:"range($, phone(#v,'CN'))"` - *InfoRequest `vd:"?"` - Info1 *InfoRequest `vd:"?"` - Info2 *InfoRequest `vd:"-"` - } - info := &InfoRequest{ - Name: "Alice", - Age: 18, - Email: "henrylee2cn@gmail.com", - Phone1: "+8618812345678", - OtherPhones: []string{"18812345679", "18812345680"}, - } - fmt.Println(vd.Validate(info)) - - type A struct { - A int `vd:"$<0||$>=100"` - Info interface{} - } - info.Email = "xxx" - a := &A{A: 107, Info: info} - fmt.Println(vd.Validate(a)) - type B struct { - B string `vd:"len($)>1 && regexp('^\\w*$')"` - } - b := &B{"abc"} - fmt.Println(vd.Validate(b) == nil) - - type C struct { - C bool `vd:"@:(S.A)$>0 && !$; msg:'C must be false when S.A>0'"` - S *A - } - c := &C{C: true, S: a} - fmt.Println(vd.Validate(c)) - - type D struct { - d []string `vd:"@:len($)>0 && $[0]=='D'; msg:sprintf('invalid d: %v',$)"` - } - d := &D{d: []string{"x", "y"}} - fmt.Println(vd.Validate(d)) - - type E struct { - e map[string]int `vd:"len($)==$['len']"` - } - e := &E{map[string]int{"len": 2}} - fmt.Println(vd.Validate(e)) - - // Customizes the factory of validation error. - vd.SetErrorFactory(func(failPath, msg string) error { - return fmt.Errorf(`{"succ":false, "error":"validation failed: %s"}`, failPath) - }) - - type F struct { - f struct { - g int `vd:"$%3==0"` - } - } - f := &F{} - f.f.g = 10 - fmt.Println(vd.Validate(f)) - - fmt.Println(vd.Validate(map[string]*F{"a": f})) - fmt.Println(vd.Validate(map[string]map[string]*F{"a": {"b": f}})) - fmt.Println(vd.Validate([]map[string]*F{{"a": f}})) - fmt.Println(vd.Validate(struct { - A []map[string]*F - }{A: []map[string]*F{{"x": f}}})) - fmt.Println(vd.Validate(map[*F]int{f: 1})) - fmt.Println(vd.Validate([][1]*F{{f}})) - fmt.Println(vd.Validate((*F)(nil))) - fmt.Println(vd.Validate(map[string]*F{})) - fmt.Println(vd.Validate(map[string]map[string]*F{})) - fmt.Println(vd.Validate([]map[string]*F{})) - fmt.Println(vd.Validate([]*F{})) - - // Output: - // - // email format is incorrect - // true - // C must be false when S.A>0 - // invalid d: [x y] - // invalid parameter: e - // {"succ":false, "error":"validation failed: f.g"} - // {"succ":false, "error":"validation failed: {v for k=a}.f.g"} - // {"succ":false, "error":"validation failed: {v for k=a}{v for k=b}.f.g"} - // {"succ":false, "error":"validation failed: [0]{v for k=a}.f.g"} - // {"succ":false, "error":"validation failed: A[0]{v for k=x}.f.g"} - // {"succ":false, "error":"validation failed: {k}.f.g"} - // {"succ":false, "error":"validation failed: [0][0].f.g"} - // unsupported data: nil - // - // - // - // -} -``` - -## Syntax - -Struct tag syntax spec: - -``` -type T struct { - // Simple model - Field1 T1 `tagName:"expression"` - // Specify error message mode - Field2 T2 `tagName:"@:expression; msg:expression2"` - // Omit it - Field3 T3 `tagName:"-"` - // Omit it when it is nil - Field4 T4 `tagName:"?"` - ... -} -``` - -|Operator or Operand|Explain| -|-----|---------| -|`true` `false`|boolean| -|`0` `0.0`|float64 "0"| -|`''`|String| -|`\\'`| Escape `'` delims in string| -|`\"`| Escape `"` delims in string| -|`nil`|nil, undefined| -|`!`|not| -|`+`|Digital addition or string splicing| -|`-`|Digital subtraction or negative| -|`*`|Digital multiplication| -|`/`|Digital division| -|`%`|division remainder, as: `float64(int64(a)%int64(b))`| -|`==`|`eq`| -|`!=`|`ne`| -|`>`|`gt`| -|`>=`|`ge`| -|`<`|`lt`| -|`<=`|`le`| -|`&&`|Logic `and`| -|`\|\|`|Logic `or`| -|`()`|Expression group| -|`(X)$`|Struct field value named X| -|`(X.Y)$`|Struct field value named X.Y| -|`$`|Shorthand for `(X)$`, omit `(X)` to indicate current struct field value| -|`(X)$['A']`|Map value with key A or struct A sub-field in the struct field X| -|`(X)$[0]`|The 0th element or sub-field of the struct field X(type: map, slice, array, struct)| -|`len((X)$)`|Built-in function `len`, the length of struct field X| -|`mblen((X)$)`|the length of string field X (character number)| -|`regexp('^\\w*$', (X)$)`|Regular match the struct field X, return boolean| -|`regexp('^\\w*$')`|Regular match the current struct field, return boolean| -|`sprintf('X value: %v', (X)$)`|`fmt.Sprintf`, format the value of struct field X| -|`range(KvExpr, forEachExpr)`|Iterate over an array, slice, or dictionary
- `#k` is the element key var
- `#v` is the element value var
- `##` is the number of elements
- e.g. [example](../spec_range_test.go)| -|`in((X)$, enum_1, ...enum_n)`|Check if the first parameter is one of the enumerated parameters| -|`email((X)$)`|Regular match the struct field X, return true if it is email| -|`phone((X)$,<'defaultRegion'>)`|Regular match the struct field X, return true if it is phone| - - - - - -Operator priority(high -> low): - -* `()` `!` `bool` `float64` `string` `nil` -* `*` `/` `%` -* `+` `-` -* `<` `<=` `>` `>=` -* `==` `!=` -* `&&` -* `||` +originally from https://github.com/bytedance/go-tagexpr v2.9.2 diff --git a/internal/tagexpr/validator/func.go b/internal/tagexpr/validator/func.go index 17800cc65..1ff998122 100644 --- a/internal/tagexpr/validator/func.go +++ b/internal/tagexpr/validator/func.go @@ -18,8 +18,6 @@ import ( "errors" "regexp" - "github.com/nyaruka/phonenumbers" - "github.com/cloudwego/hertz/internal/tagexpr" ) @@ -79,6 +77,19 @@ func init() { }, true) } +// Phone validation always returns true. +// +// Removed github.com/nyaruka/phonenumbers dependency for the following reasons: +// 1. The tagexpr validator package is deprecated +// 2. The phonenumbers library has unresolved issues requiring upgrades +// 3. The phonenumbers library is memory-heavy (loads many objects into memory even when unused) +// +// Since this validator is deprecated, we simply return true for all phone numbers +// instead of maintaining complex validation logic. +func validatePhone(numberToParse, region string) bool { + return true +} + func init() { // phone: defaultRegion is 'CN' MustRegFunc("phone", func(args ...interface{}) error { @@ -102,13 +113,7 @@ func init() { if defaultRegion == "" { defaultRegion = "CN" } - num, err := phonenumbers.Parse(numberToParse, defaultRegion) - if err != nil { - return err - } - matched := phonenumbers.IsValidNumber(num) - if !matched { - // return ErrInvalidWithoutMsg + if !validatePhone(numberToParse, defaultRegion) { return errors.New("phone format is incorrect") } return nil diff --git a/internal/tagexpr/validator/validator_test.go b/internal/tagexpr/validator/validator_test.go index 5cc2d7fb1..0260852e6 100644 --- a/internal/tagexpr/validator/validator_test.go +++ b/internal/tagexpr/validator/validator_test.go @@ -282,7 +282,7 @@ func TestIssue24(t *testing.T) { } data := &SubmitDoctorImportRequest{SubmitDoctorImport: []*SubmitDoctorImportItem{{}}} err := vd.Validate(data, true) - assertEqualError(t, err, "invalid parameter: SubmitDoctorImport[0].Idcard\tinvalid parameter: SubmitDoctorImport[0].PracCertNo\temail format is incorrect\tthe phone number supplied is not a number") + assertEqualError(t, err, "invalid parameter: SubmitDoctorImport[0].Idcard\tinvalid parameter: SubmitDoctorImport[0].PracCertNo\temail format is incorrect") } func TestStructSliceMap(t *testing.T) { diff --git a/internal/testutils/testutils.go b/internal/testutils/testutils.go index 917d34b35..cec5fa15a 100644 --- a/internal/testutils/testutils.go +++ b/internal/testutils/testutils.go @@ -18,85 +18,10 @@ package testutils import ( "net" - "path" - "reflect" + "testing" "time" - "unsafe" - - "github.com/cloudwego/hertz/pkg/network" - "github.com/cloudwego/hertz/pkg/route" -) - -var ( - engineType = reflect.TypeOf((*route.Engine)(nil)).Elem() - transportType = reflect.TypeOf((*network.Transporter)(nil)).Elem() ) -func unwrapValue(rv reflect.Value) reflect.Value { - for rv.Kind() == reflect.Ptr || rv.Kind() == reflect.Interface { - rv = rv.Elem() - } - return rv -} - -func getEngine(rv reflect.Value) reflect.Value { - rv = unwrapValue(rv) - if rv.Type() == engineType { - return rv - } - for i := 0; i < rv.NumField(); i++ { - f := unwrapValue(rv.Field(i)) - if f.Type() == engineType { - return f - } - } - panic("not found *route.Engine") -} - -func getNetworkTransporter(rv reflect.Value) reflect.Value { - rv = unwrapValue(rv) - for i := 0; i < rv.NumField(); i++ { - f := rv.Field(i) - if f.Type() == transportType { - return f - } - } - panic("not found network.Transporter") -} - -func getUnexportedField(f reflect.Value) interface{} { - return reflect.NewAt(f.Type(), unsafe.Pointer(f.Addr().Pointer())).Elem().Interface() -} - -// GetListener extracts net.Listener from network.Transporter in route.Engine -func GetListener(v interface{}) net.Listener { - rv := getEngine(reflect.ValueOf(v)) // *route.Engine - rv = getNetworkTransporter(rv) // network.Transporter - - // implemented by network/netpoll & standard - type ListenerIface interface { - Listener() net.Listener - } - - // NOTE: do not access net.Listener directly, it may cause race issue. - // use `Listener()` method - trans := getUnexportedField(rv) - if p, ok := trans.(ListenerIface); ok { - return p.Listener() - } - panic("network.Transporter has no Listener() method") -} - -// GetListenerAddr is shortcut of GetListener(e).Addr().String() -func GetListenerAddr(v interface{}) string { - return GetListener(v).Addr().String() -} - -// GetURL ... -func GetURL(v interface{}, p string) string { - return "http://" + path.Join(GetListenerAddr(v), p) -} - type RouteEngine interface { IsRunning() bool } @@ -110,3 +35,15 @@ func WaitEngineRunning(e RouteEngine) { } panic("not running") } + +// NewTestListener creates a TCP listener on a random available port. +// It calls tb.Fatal if the listener cannot be created. +// The caller is responsible for closing the listener (usually via defer). +func NewTestListener(tb testing.TB) net.Listener { + tb.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + tb.Fatalf("failed to create test listener: %s", err) + } + return ln +} diff --git a/internal/testutils/testutils_test.go b/internal/testutils/testutils_test.go index 8317d6b70..41afd40b8 100644 --- a/internal/testutils/testutils_test.go +++ b/internal/testutils/testutils_test.go @@ -17,50 +17,15 @@ package testutils import ( - "context" - "io" - "net/http" "sync/atomic" "testing" "time" - - "github.com/cloudwego/hertz/pkg/app" - "github.com/cloudwego/hertz/pkg/common/config" - "github.com/cloudwego/hertz/pkg/protocol/consts" - "github.com/cloudwego/hertz/pkg/route" ) -func TestGetListener(t *testing.T) { - msg := "world" - e := route.NewEngine(&config.Options{Network: "tcp", Addr: "127.0.0.1:0"}) - e.GET("/hello", func(ctx context.Context, c *app.RequestContext) { - c.String(consts.StatusOK, msg) - }) - defer e.Shutdown(context.Background()) - - go e.Run() - time.Sleep(20 * time.Millisecond) - - type AppServer struct { - *route.Engine - } - if GetURL(e, "") != GetURL(&AppServer{e}, "") { - t.Fatal("ne") - } - - resp, err := http.Get(GetURL(e, "/hello")) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - if s := string(body); s != msg { - t.Fatal(s, "!=", msg) - } +func TestNewTestListener(t *testing.T) { + ln := NewTestListener(t) + defer ln.Close() + t.Log(ln.Addr()) } type routeEngine struct { diff --git a/pkg/app/client/client.go b/pkg/app/client/client.go index 09a371a7d..7948a20c6 100644 --- a/pkg/app/client/client.go +++ b/pkg/app/client/client.go @@ -547,6 +547,9 @@ func (c *Client) CloseIdleConnections() { for _, v := range c.m { v.CloseIdleConnections() } + for _, v := range c.ms { + v.CloseIdleConnections() + } c.mLock.Unlock() } diff --git a/pkg/app/client/client_test.go b/pkg/app/client/client_test.go index e408aaff5..21518e9b9 100644 --- a/pkg/app/client/client_test.go +++ b/pkg/app/client/client_test.go @@ -54,11 +54,11 @@ import ( "net/http/httptest" "net/url" "os" + "path" "path/filepath" "reflect" "regexp" "runtime" - "strconv" "strings" "sync" "sync/atomic" @@ -92,32 +92,26 @@ func assertNil(err error) { } } -var unixsockPath string - -func TestMain(m *testing.M) { - dir, err := os.MkdirTemp("", "tests-*") - assertNil(err) - unixsockPath = dir - defer os.RemoveAll(dir) - - m.Run() +func waitEngineRunning(e *route.Engine) { + testutils.WaitEngineRunning(e) } -var nextUnixSockID = int32(10000) - -func nextUnixSock() string { - n := atomic.AddInt32(&nextUnixSockID, 1) - return filepath.Join(unixsockPath, strconv.Itoa(int(n))+".sock") +func newTestOptions(t *testing.T) (*config.Options, net.Listener) { + ln := testutils.NewTestListener(t) + opt := config.NewOptions([]config.Option{}) + opt.Listener = ln + opt.Addr = ln.Addr().String() + opt.Network = "tcp" + return opt, ln } -func waitEngineRunning(e *route.Engine) { - testutils.WaitEngineRunning(e) +func fullURL(ln net.Listener, p string) string { + return "http://" + path.Join(ln.Addr().String(), p) } func TestCloseIdleConnections(t *testing.T) { - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) go engine.Run() @@ -163,10 +157,53 @@ func TestCloseIdleConnections(t *testing.T) { }() } +func TestCloseIdleTLSConnections(t *testing.T) { + httpsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("https response")) + })) + defer httpsServer.Close() + + c, _ := NewClient( + WithTLSConfig(httpsServer.Client().Transport.(*http.Transport).TLSClientConfig), + WithDialTimeout(1*time.Second), + ) + + httpsReq, httpsResp := protocol.AcquireRequest(), protocol.AcquireResponse() + defer func() { + protocol.ReleaseRequest(httpsReq) + protocol.ReleaseResponse(httpsResp) + }() + httpsReq.SetRequestURI(httpsServer.URL) + if err := c.Do(context.Background(), httpsReq, httpsResp); err != nil { + t.Fatalf("HTTPS request failed: %v", err) + } + + c.CloseIdleConnections() + + c.mLock.Lock() + var totalConns int + for _, hc := range c.ms { + totalConns += hc.ConnectionCount() + } + c.mLock.Unlock() + + if totalConns > 0 { + t.Errorf("expected 0 HTTPS idle connections after close, got %d", totalConns) + } + + c.cleanHostClients(true) + + c.mLock.Lock() + defer c.mLock.Unlock() + if len(c.ms) != 0 { + t.Errorf("expected 0 HTTPS host clients, got %d", len(c.ms)) + } +} + func TestClientInvalidURI(t *testing.T) { - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() requests := int64(0) engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { @@ -196,9 +233,8 @@ func TestClientInvalidURI(t *testing.T) { } func TestClientGetWithBody(t *testing.T) { - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { body := ctx.Request.Body() @@ -229,9 +265,8 @@ func TestClientGetWithBody(t *testing.T) { } func TestClientPostBodyStream(t *testing.T) { - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { body := ctx.Request.Body() @@ -269,9 +304,8 @@ func TestClientURLAuth(t *testing.T) { } ch := make(chan string, 1) - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.GET("/foo/bar", func(c context.Context, ctx *app.RequestContext) { ch <- string(ctx.Request.Header.Peek(consts.HeaderAuthorization)) @@ -301,9 +335,8 @@ func TestClientURLAuth(t *testing.T) { } func TestClientNilResp(t *testing.T) { - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { @@ -328,8 +361,11 @@ func TestClientNilResp(t *testing.T) { } func TestClientParseConn(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() + opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { }) @@ -338,7 +374,7 @@ func TestClientParseConn(t *testing.T) { engine.Close() }() waitEngineRunning(engine) - opt.Addr = testutils.GetListenerAddr(engine) + opt.Addr = ln.Addr().String() c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req, res := protocol.AcquireRequest(), protocol.AcquireResponse() @@ -364,9 +400,8 @@ func TestClientParseConn(t *testing.T) { } func TestClientPostArgs(t *testing.T) { - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { body := ctx.Request.Body() @@ -402,9 +437,8 @@ func TestClientPostArgs(t *testing.T) { } func TestClientHeaderCase(t *testing.T) { - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { zw := ctx.GetWriter() @@ -434,9 +468,8 @@ func TestClientHeaderCase(t *testing.T) { } func TestClientReadTimeout(t *testing.T) { - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) readtimeout := 50 * time.Millisecond @@ -510,9 +543,8 @@ func TestClientReadTimeout(t *testing.T) { } func TestClientDefaultUserAgent(t *testing.T) { - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { @@ -541,9 +573,8 @@ func TestClientDefaultUserAgent(t *testing.T) { } func TestClientSetUserAgent(t *testing.T) { - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { @@ -572,9 +603,8 @@ func TestClientSetUserAgent(t *testing.T) { } func TestClientNoUserAgent(t *testing.T) { - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { @@ -612,9 +642,8 @@ func TestClientDoWithCustomHeaders(t *testing.T) { "a-b-c-d-f": "", } body := "request body" - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.POST("/foo/bar/baz", func(c context.Context, ctx *app.RequestContext) { @@ -691,9 +720,8 @@ func TestClientDoWithCustomHeaders(t *testing.T) { } func TestClientDoTimeoutDisablePathNormalizing(t *testing.T) { - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.Use(func(c context.Context, ctx *app.RequestContext) { @@ -730,9 +758,8 @@ func TestHostClientPendingRequests(t *testing.T) { const concurrency = 10 doneCh := make(chan struct{}) readyCh := make(chan struct{}, concurrency) - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.GET("/baz", func(c context.Context, ctx *app.RequestContext) { @@ -817,9 +844,8 @@ func TestHostClientMaxConnsWithDeadline(t *testing.T) { timeout = 50 * time.Millisecond wg sync.WaitGroup ) - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.POST("/baz", func(c context.Context, ctx *app.RequestContext) { @@ -884,9 +910,8 @@ func TestHostClientMaxConnsWithDeadline(t *testing.T) { func TestHostClientMaxConnDuration(t *testing.T) { connectionCloseCount := uint32(0) - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.GET("/bbb/cc", func(c context.Context, ctx *app.RequestContext) { @@ -929,9 +954,8 @@ func TestHostClientMaxConnDuration(t *testing.T) { } func TestHostClientMultipleAddrs(t *testing.T) { - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.GET("/baz/aaa", func(c context.Context, ctx *app.RequestContext) { @@ -978,9 +1002,8 @@ func TestHostClientMultipleAddrs(t *testing.T) { } func TestClientFollowRedirects(t *testing.T) { - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) handler := func(c context.Context, ctx *app.RequestContext) { @@ -1084,9 +1107,8 @@ func TestHostClientMaxConnWaitTimeoutSuccess(t *testing.T) { emptyBodyCount uint8 wg sync.WaitGroup ) - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.POST("/baz", func(c context.Context, ctx *app.RequestContext) { @@ -1152,9 +1174,8 @@ func TestHostClientMaxConnWaitTimeoutError(t *testing.T) { emptyBodyCount uint8 wg sync.WaitGroup ) - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.POST("/baz", func(c context.Context, ctx *app.RequestContext) { @@ -1223,8 +1244,11 @@ func TestHostClientMaxConnWaitTimeoutError(t *testing.T) { } func TestNewClient(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() + opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) engine.GET("/ping", func(c context.Context, ctx *app.RequestContext) { ctx.SetBodyString("pong") @@ -1240,7 +1264,7 @@ func TestNewClient(t *testing.T) { t.Fatal(err) return } - status, resp, err := client.Get(context.Background(), nil, testutils.GetURL(engine, "/ping")) + status, resp, err := client.Get(context.Background(), nil, fullURL(ln, "/ping")) if err != nil { t.Fatal(err) return @@ -1252,8 +1276,11 @@ func TestNewClient(t *testing.T) { } func TestUseShortConnection(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() + opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { }) @@ -1262,7 +1289,7 @@ func TestUseShortConnection(t *testing.T) { engine.Close() }() waitEngineRunning(engine) - opt.Addr = testutils.GetListenerAddr(engine) + opt.Addr = ln.Addr().String() c, _ := NewClient(WithKeepAlive(false)) var wg sync.WaitGroup @@ -1270,7 +1297,7 @@ func TestUseShortConnection(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - if _, _, err := c.Get(context.Background(), nil, testutils.GetURL(engine, "")); err != nil { + if _, _, err := c.Get(context.Background(), nil, fullURL(ln, "")); err != nil { t.Error(err) return } @@ -1294,8 +1321,11 @@ func TestUseShortConnection(t *testing.T) { } func TestPostWithFormData(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() + opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { var ans string @@ -1331,7 +1361,7 @@ func TestPostWithFormData(t *testing.T) { "a": []string{"d", "e"}, "c": []string{"f"}, }) - req.SetRequestURI(testutils.GetURL(engine, "")) + req.SetRequestURI(fullURL(ln, "")) req.SetMethod(consts.MethodPost) err := client.Do(context.Background(), req, rsp) if err != nil { @@ -1347,8 +1377,11 @@ func TestPostWithFormData(t *testing.T) { } func TestPostWithMultipartField(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() + opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { if string(ctx.FormValue("a")) != "1" { @@ -1377,7 +1410,7 @@ func TestPostWithMultipartField(t *testing.T) { "b": "2", } req.SetMethod(consts.MethodPost) - req.SetRequestURI(testutils.GetURL(engine, "")) + req.SetRequestURI(fullURL(ln, "")) req.SetMultipartFormData(data) req.SetMultipartFormData(map[string]string{ "c": "3", @@ -1389,8 +1422,11 @@ func TestPostWithMultipartField(t *testing.T) { } func TestSetFiles(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() + opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { form, _ := ctx.MultipartForm() @@ -1419,7 +1455,7 @@ func TestSetFiles(t *testing.T) { protocol.ReleaseResponse(rsp) }() req.SetMethod(consts.MethodPost) - req.SetRequestURI(testutils.GetURL(engine, "")) + req.SetRequestURI(fullURL(ln, "")) files := []string{"../../common/testdata/test.txt", "../../common/testdata/proto/test.proto", "../../common/testdata/test.png", "../../common/testdata/proto/test.pb.go"} defer func() { for _, file := range files { @@ -1439,8 +1475,11 @@ func TestSetFiles(t *testing.T) { } func TestSetMultipartFields(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() + opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { t.Log(req.GetHTTP1Request(&ctx.Request).String()) @@ -1493,7 +1532,7 @@ func TestSetMultipartFields(t *testing.T) { }() req.SetMultipartFields(fields...) req.SetMultipartFormData(map[string]string{"a": "1", "b": "2"}) - req.SetRequestURI(testutils.GetURL(engine, "")) + req.SetRequestURI(fullURL(ln, "")) req.SetMethod(consts.MethodPost) err := client.DoTimeout(context.Background(), req, rsp, 1*time.Second) if err != nil { @@ -1505,8 +1544,11 @@ func TestClientReadResponseBodyStream(t *testing.T) { part1 := "abcdef" part2 := "ghij" + ln := testutils.NewTestListener(t) + defer ln.Close() + opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) engine.POST("/", func(ctx context.Context, c *app.RequestContext) { c.String(consts.StatusOK, part1+part2) @@ -1523,7 +1565,7 @@ func TestClientReadResponseBodyStream(t *testing.T) { protocol.ReleaseRequest(req) protocol.ReleaseResponse(resp) }() - req.SetRequestURI(testutils.GetURL(engine, "")) + req.SetRequestURI(fullURL(ln, "")) req.SetMethod(consts.MethodPost) err := client.Do(context.Background(), req, resp) if err != nil { @@ -1549,8 +1591,11 @@ func TestClientReadResponseBodyStream(t *testing.T) { } func TestWithBasicAuth(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() + opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { auth := ctx.GetHeader(consts.HeaderAuthorization) @@ -1580,7 +1625,7 @@ func TestWithBasicAuth(t *testing.T) { // Success req.SetBasicAuth("myuser", "basicauth") - req.SetRequestURI(testutils.GetURL(engine, "")) + req.SetRequestURI(fullURL(ln, "")) req.SetMethod(consts.MethodGet) err := client.Do(context.Background(), req, rsp) if err != nil { @@ -1593,7 +1638,7 @@ func TestWithBasicAuth(t *testing.T) { // Fail req.Reset() rsp.Reset() - req.SetRequestURI(testutils.GetURL(engine, "")) + req.SetRequestURI(fullURL(ln, "")) req.SetMethod(consts.MethodGet) err = client.Do(context.Background(), req, rsp) if err != nil { @@ -1934,8 +1979,11 @@ func TestClientReadResponseBodyStreamWithDoubleRequest(t *testing.T) { } part2 := "ghij" + ln := testutils.NewTestListener(t) + defer ln.Close() + opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) engine.POST("/", func(ctx context.Context, c *app.RequestContext) { c.String(consts.StatusOK, part1+part2) @@ -1952,7 +2000,7 @@ func TestClientReadResponseBodyStreamWithDoubleRequest(t *testing.T) { protocol.ReleaseRequest(req) protocol.ReleaseResponse(resp) }() - req.SetRequestURI(testutils.GetURL(engine, "")) + req.SetRequestURI(fullURL(ln, "")) req.SetMethod(consts.MethodPost) err := client.Do(context.Background(), req, resp) if err != nil { @@ -1979,7 +2027,7 @@ func TestClientReadResponseBodyStreamWithDoubleRequest(t *testing.T) { protocol.ReleaseRequest(req1) protocol.ReleaseResponse(resp1) }() - req1.SetRequestURI(testutils.GetURL(engine, "")) + req1.SetRequestURI(fullURL(ln, "")) req1.SetMethod(consts.MethodPost) err = client.Do(context.Background(), req1, resp1) if err != nil { @@ -2007,8 +2055,11 @@ func TestClientReadResponseBodyStreamWithConnectionClose(t *testing.T) { part1 += "a" } + ln := testutils.NewTestListener(t) + defer ln.Close() + opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) engine.POST("/", func(ctx context.Context, c *app.RequestContext) { c.String(consts.StatusOK, part1) @@ -2029,7 +2080,7 @@ func TestClientReadResponseBodyStreamWithConnectionClose(t *testing.T) { }() req.SetConnectionClose() req.SetMethod(consts.MethodPost) - req.SetRequestURI(testutils.GetURL(engine, "")) + req.SetRequestURI(fullURL(ln, "")) err := client.Do(context.Background(), req, resp) if err != nil { @@ -2046,7 +2097,7 @@ func TestClientReadResponseBodyStreamWithConnectionClose(t *testing.T) { }() req1.SetConnectionClose() req1.SetMethod(consts.MethodPost) - req1.SetRequestURI(testutils.GetURL(engine, "")) + req1.SetRequestURI(fullURL(ln, "")) err = client.Do(context.Background(), req1, resp1) if err != nil { @@ -2292,9 +2343,8 @@ func TestClientDoWithDialFunc(t *testing.T) { ch := make(chan error, 1) uri := "/foo/bar/baz" body := "request body" - opt := config.NewOptions([]config.Option{}) - opt.Addr = nextUnixSock() - opt.Network = "unix" + opt, ln := newTestOptions(t) + defer ln.Close() engine := route.NewEngine(opt) engine.POST("/foo/bar/baz", func(c context.Context, ctx *app.RequestContext) { @@ -2353,15 +2403,18 @@ func TestClientDoWithDialFunc(t *testing.T) { } func TestClientState(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() + opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) - opt.Addr = testutils.GetListenerAddr(engine) + opt.Addr = ln.Addr().String() var wg sync.WaitGroup wg.Add(2) @@ -2393,8 +2446,11 @@ func TestClientState(t *testing.T) { func TestClientRetryErr(t *testing.T) { t.Run("200", func(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() + opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) var l sync.Mutex retryNum := 0 @@ -2411,7 +2467,7 @@ func TestClientRetryErr(t *testing.T) { waitEngineRunning(engine) c, _ := NewClient(WithRetryConfig(retry.WithMaxAttemptTimes(3))) - _, _, err := c.Get(context.Background(), nil, testutils.GetURL(engine, "/ping")) + _, _, err := c.Get(context.Background(), nil, fullURL(ln, "/ping")) assert.Nil(t, err) l.Lock() assert.DeepEqual(t, 1, retryNum) @@ -2419,8 +2475,11 @@ func TestClientRetryErr(t *testing.T) { }) t.Run("502", func(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() + opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) var l sync.Mutex retryNum := 0 @@ -2440,7 +2499,7 @@ func TestClientRetryErr(t *testing.T) { c.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { return resp.StatusCode() == 502 }) - _, _, err := c.Get(context.Background(), nil, testutils.GetURL(engine, "/ping")) + _, _, err := c.Get(context.Background(), nil, fullURL(ln, "/ping")) assert.Nil(t, err) l.Lock() assert.DeepEqual(t, 3, retryNum) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index f7fd59221..7eca5b10e 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -152,12 +152,15 @@ func TestBind_SliceType(t *testing.T) { ID *[]int `query:"id"` Str [3]string `query:"str"` Byte []byte `query:"b"` + HH []string `header:"h"` } IDs := []int{11, 12, 13} Strs := [3]string{"qwe", "asd", "zxc"} Bytes := []byte("123") + Headers := []string{"header"} req := newMockRequest(). + SetHeaders("H", Headers[0]). SetRequestURI(fmt.Sprintf("http://foobar.com?id=%d&id=%d&id=%d&str=%s&str=%s&str=%s&b=%d&b=%d&b=%d", IDs[0], IDs[1], IDs[2], Strs[0], Strs[1], Strs[2], Bytes[0], Bytes[1], Bytes[2])) var result Req @@ -178,6 +181,7 @@ func TestBind_SliceType(t *testing.T) { for idx, val := range Bytes { assert.DeepEqual(t, val, result.Byte[idx]) } + assert.DeepEqual(t, Headers, result.HH) } func TestBind_StructType(t *testing.T) { diff --git a/pkg/app/server/binding/internal/decoder/slice_getter.go b/pkg/app/server/binding/internal/decoder/slice_getter.go index 27d2b4174..763c00817 100644 --- a/pkg/app/server/binding/internal/decoder/slice_getter.go +++ b/pkg/app/server/binding/internal/decoder/slice_getter.go @@ -122,16 +122,10 @@ func cookieSlice(req *protocol.Request, params param.Params, key string, default } func headerSlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { - req.Header.VisitAll(func(headerKey, value []byte) { - if bytesconv.B2s(headerKey) == key { - ret = append(ret, string(value)) - } - }) - - if len(ret) == 0 && len(defaultValue) != 0 { - ret = append(ret, defaultValue...) + ret = defaultValue + if vv := req.Header.GetAll(key); len(vv) > 0 { + ret = vv } - return } diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index 1d331e163..4f177eb21 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -470,7 +470,7 @@ func TestNonstruct(t *testing.T) { } t.Logf("%s", b) - bodyReader = strings.NewReader("b=334ddddd&token=yoMba34uspjVQEbhflgTRe2ceeDFUK32&type=url_verification") + bodyReader = strings.NewReader("b=334ddddd&token=mymocktoken&type=url_verification") header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8") req = newRequest("", header, nil, bodyReader) recv = nil diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index 5a2fd4d5f..1636a9984 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -26,6 +26,7 @@ import ( "io/ioutil" "net" "net/http" + "path" "strings" "sync" "sync/atomic" @@ -40,7 +41,6 @@ import ( "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" - "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/common/utils" @@ -60,8 +60,14 @@ func waitEngineRunning(e routeEngine) { testutils.WaitEngineRunning(e) } +func fullURL(ln net.Listener, p string) string { + return "http://" + path.Join(ln.Addr().String(), p) +} + func TestHertz_Run(t *testing.T) { - hertz := Default(WithHostPorts("127.0.0.1:0")) + ln := testutils.NewTestListener(t) + defer ln.Close() + hertz := Default(WithListener(ln)) hertz.GET("/test", func(c context.Context, ctx *app.RequestContext) { time.Sleep(time.Second) path := ctx.Request.URI().PathOriginal() @@ -79,16 +85,18 @@ func TestHertz_Run(t *testing.T) { waitEngineRunning(hertz) hertz.Close() - resp, err := http.Get(testutils.GetURL(hertz, "/test")) + resp, err := http.Get(fullURL(ln, "/test")) assert.NotNil(t, err) assert.Nil(t, resp) assert.DeepEqual(t, uint32(0), atomic.LoadUint32(&testint)) } func TestHertz_GracefulShutdown(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() handling := make(chan struct{}) closing := make(chan struct{}) - engine := New(WithHostPorts("127.0.0.1:0")) + engine := New(WithListener(ln)) engine.GET("/test", func(c context.Context, ctx *app.RequestContext) { close(handling) <-closing @@ -123,7 +131,7 @@ func TestHertz_GracefulShutdown(t *testing.T) { defer ticker.Stop() for range ticker.C { t.Logf("[%v]begin listening\n", time.Now()) - _, err2 := hc.Get(testutils.GetURL(engine, "/test2")) + _, err2 := hc.Get(fullURL(ln, "/test2")) if err2 != nil { t.Logf("[%v]listening closed: %v", time.Now(), err2) ch2 <- struct{}{} @@ -133,7 +141,7 @@ func TestHertz_GracefulShutdown(t *testing.T) { }() go func() { t.Logf("[%v]begin request\n", time.Now()) - resp, err = http.Get(testutils.GetURL(engine, "/test")) + resp, err = http.Get(fullURL(ln, "/test")) t.Logf("[%v]end request\n", time.Now()) ch <- struct{}{} }() @@ -162,7 +170,9 @@ func TestHertz_GracefulShutdown(t *testing.T) { } func TestLoadHTMLGlob(t *testing.T) { - engine := New(WithMaxRequestBodySize(15), WithHostPorts("127.0.0.1:0")) + ln := testutils.NewTestListener(t) + defer ln.Close() + engine := New(WithMaxRequestBodySize(15), WithListener(ln)) engine.Delims("{[{", "}]}") engine.LoadHTMLGlob("../../common/testdata/template/index.tmpl") engine.GET("/index", func(c context.Context, ctx *app.RequestContext) { @@ -176,7 +186,7 @@ func TestLoadHTMLGlob(t *testing.T) { }() waitEngineRunning(engine) - resp, _ := http.Get(testutils.GetURL(engine, "/index")) + resp, _ := http.Get(fullURL(ln, "/index")) assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) b := make([]byte, 100) n, _ := resp.Body.Read(b) @@ -186,7 +196,9 @@ func TestLoadHTMLGlob(t *testing.T) { } func TestLoadHTMLFiles(t *testing.T) { - engine := New(WithMaxRequestBodySize(15), WithHostPorts("127.0.0.1:0")) + ln := testutils.NewTestListener(t) + defer ln.Close() + engine := New(WithMaxRequestBodySize(15), WithListener(ln)) engine.Delims("{[{", "}]}") engine.SetFuncMap(template.FuncMap{ "formatAsDate": formatAsDate, @@ -204,7 +216,7 @@ func TestLoadHTMLFiles(t *testing.T) { }() waitEngineRunning(engine) - resp, _ := http.Get(testutils.GetURL(engine, "/raw")) + resp, _ := http.Get(fullURL(ln, "/raw")) assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) b := make([]byte, 100) n, _ := resp.Body.Read(b) @@ -239,29 +251,31 @@ func Test_getServerName(t *testing.T) { } func TestServer_Run(t *testing.T) { - hertz := New(WithHostPorts("127.0.0.1:0")) + ln := testutils.NewTestListener(t) + defer ln.Close() + hertz := New(WithListener(ln)) hertz.GET("/test", func(c context.Context, ctx *app.RequestContext) { path := ctx.Request.URI().PathOriginal() ctx.SetBodyString(string(path)) }) hertz.POST("/redirect", func(c context.Context, ctx *app.RequestContext) { - ctx.Redirect(consts.StatusMovedPermanently, []byte(testutils.GetURL(hertz, "/test"))) + ctx.Redirect(consts.StatusMovedPermanently, []byte(fullURL(ln, "/test"))) }) go hertz.Run() waitEngineRunning(hertz) - resp, err := http.Get(testutils.GetURL(hertz, "/test")) + resp, err := http.Get(fullURL(ln, "/test")) assert.Nil(t, err) assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) b := make([]byte, 5) resp.Body.Read(b) assert.DeepEqual(t, "/test", string(b)) - resp, err = http.Get(testutils.GetURL(hertz, "/foo")) + resp, err = http.Get(fullURL(ln, "/foo")) assert.Nil(t, err) assert.DeepEqual(t, consts.StatusNotFound, resp.StatusCode) - resp, err = http.Post(testutils.GetURL(hertz, "/redirect"), "", nil) + resp, err = http.Post(fullURL(ln, "/redirect"), "", nil) assert.Nil(t, err) assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) b = make([]byte, 5) @@ -274,7 +288,9 @@ func TestServer_Run(t *testing.T) { } func TestNotAbsolutePath(t *testing.T) { - engine := New(WithHostPorts("127.0.0.1:0")) + ln := testutils.NewTestListener(t) + defer ln.Close() + engine := New(WithListener(ln)) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { ctx.Write(ctx.Request.Body()) }) @@ -311,7 +327,9 @@ func TestNotAbsolutePath(t *testing.T) { } func TestNotAbsolutePathWithRawPath(t *testing.T) { - engine := New(WithHostPorts("127.0.0.1:0"), WithUseRawPath(true)) + ln := testutils.NewTestListener(t) + defer ln.Close() + engine := New(WithListener(ln), WithUseRawPath(true)) const ( MiddlewareKey = "middleware_key" MiddlewareValue = "middleware_value" @@ -357,7 +375,9 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) { } func TestNotValidHost(t *testing.T) { - engine := New(WithHostPorts("127.0.0.1:0")) + ln := testutils.NewTestListener(t) + defer ln.Close() + engine := New(WithListener(ln)) const ( MiddlewareKey = "middleware_key" MiddlewareValue = "middleware_value" @@ -398,7 +418,9 @@ func TestNotValidHost(t *testing.T) { } func TestWithBasePath(t *testing.T) { - engine := New(WithBasePath("/hertz"), WithHostPorts("127.0.0.1:0")) + ln := testutils.NewTestListener(t) + defer ln.Close() + engine := New(WithBasePath("/hertz"), WithListener(ln)) engine.POST("/test", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() @@ -411,13 +433,15 @@ func TestWithBasePath(t *testing.T) { r.ParseForm() r.Form.Add("xxxxxx", "xxx") body := strings.NewReader(r.Form.Encode()) - resp, err := http.Post(testutils.GetURL(engine, "/hertz/test"), "application/x-www-form-urlencoded", body) + resp, err := http.Post(fullURL(ln, "/hertz/test"), "application/x-www-form-urlencoded", body) assert.Nil(t, err) assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) } func TestNotEnoughBodySize(t *testing.T) { - engine := New(WithMaxRequestBodySize(5), WithHostPorts("127.0.0.1:0")) + ln := testutils.NewTestListener(t) + defer ln.Close() + engine := New(WithMaxRequestBodySize(5), WithListener(ln)) engine.POST("/test", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() @@ -430,7 +454,7 @@ func TestNotEnoughBodySize(t *testing.T) { r.ParseForm() r.Form.Add("xxxxxx", "xxx") body := strings.NewReader(r.Form.Encode()) - resp, err := http.Post(testutils.GetURL(engine, "/test"), "application/x-www-form-urlencoded", body) + resp, err := http.Post(fullURL(ln, "/test"), "application/x-www-form-urlencoded", body) assert.Nil(t, err) assert.DeepEqual(t, 413, resp.StatusCode) bodyBytes, _ := ioutil.ReadAll(resp.Body) @@ -438,7 +462,9 @@ func TestNotEnoughBodySize(t *testing.T) { } func TestEnoughBodySize(t *testing.T) { - engine := New(WithMaxRequestBodySize(15), WithHostPorts("127.0.0.1:0")) + ln := testutils.NewTestListener(t) + defer ln.Close() + engine := New(WithMaxRequestBodySize(15), WithListener(ln)) engine.POST("/test", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() @@ -451,7 +477,7 @@ func TestEnoughBodySize(t *testing.T) { r.ParseForm() r.Form.Add("xxxxxx", "xxx") body := strings.NewReader(r.Form.Encode()) - resp, _ := http.Post(testutils.GetURL(engine, "/test"), "application/x-www-form-urlencoded", body) + resp, _ := http.Post(fullURL(ln, "/test"), "application/x-www-form-urlencoded", body) assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) } @@ -557,8 +583,10 @@ func verifyResponseHeader(t *testing.T, h *protocol.ResponseHeader, expectedStat } func TestParamInconsist(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() mapS := sync.Map{} - h := New(WithHostPorts("127.0.0.1:0")) + h := New(WithListener(ln)) h.GET("/:label", func(c context.Context, ctx *app.RequestContext) { label := ctx.Param("label") x, _ := mapS.LoadOrStore(label, label) @@ -575,13 +603,13 @@ func TestParamInconsist(t *testing.T) { tr := func() { defer wg.Done() for i := 0; i < 500; i++ { - client.Get(context.Background(), nil, testutils.GetURL(h, "/test1")) + client.Get(context.Background(), nil, fullURL(ln, "/test1")) } } ti := func() { defer wg.Done() for i := 0; i < 500; i++ { - client.Get(context.Background(), nil, testutils.GetURL(h, "/test2")) + client.Get(context.Background(), nil, fullURL(ln, "/test2")) } } @@ -594,7 +622,9 @@ func TestParamInconsist(t *testing.T) { } func TestDuplicateReleaseBodyStream(t *testing.T) { - h := New(WithStreamBody(true), WithHostPorts("127.0.0.1:0")) + ln := testutils.NewTestListener(t) + defer ln.Close() + h := New(WithStreamBody(true), WithListener(ln)) h.POST("/test", func(ctx context.Context, c *app.RequestContext) { stream := c.RequestBodyStream() c.Response.SetBodyStream(stream, -1) @@ -617,7 +647,7 @@ func TestDuplicateReleaseBodyStream(t *testing.T) { wg := sync.WaitGroup{} testFunc := func() { defer wg.Done() - r := protocol.NewRequest("POST", testutils.GetURL(h, "/test"), nil) + r := protocol.NewRequest("POST", fullURL(ln, "/test"), nil) r.SetBodyString(body) resp := protocol.AcquireResponse() err := client.Do(context.Background(), r, resp) @@ -639,6 +669,8 @@ func TestDuplicateReleaseBodyStream(t *testing.T) { func TestServiceRegisterFailed(t *testing.T) { t.Parallel() // slow test, make it parallel + ln := testutils.NewTestListener(t) + defer ln.Close() mockRegErr := errors.New("mock register error") var rCount int32 var drCount int32 @@ -654,7 +686,7 @@ func TestServiceRegisterFailed(t *testing.T) { } var opts []config.Option opts = append(opts, WithRegistry(mockRegistry, nil)) - opts = append(opts, WithHostPorts("127.0.0.1:0")) + opts = append(opts, WithListener(ln)) srv := New(opts...) srv.Spin() assert.Assert(t, atomic.LoadInt32(&rCount) == 1) @@ -663,6 +695,8 @@ func TestServiceRegisterFailed(t *testing.T) { func TestServiceDeregisterFailed(t *testing.T) { t.Parallel() // slow test, make it parallel + ln := testutils.NewTestListener(t) + defer ln.Close() mockDeregErr := errors.New("mock deregister error") var wg sync.WaitGroup @@ -684,7 +718,7 @@ func TestServiceDeregisterFailed(t *testing.T) { var opts []config.Option opts = append(opts, WithRegistry(mockRegistry, nil)) - opts = append(opts, WithHostPorts("127.0.0.1:0")) + opts = append(opts, WithListener(ln)) srv := New(opts...) go srv.Spin() waitEngineRunning(srv) @@ -701,6 +735,8 @@ func TestServiceDeregisterFailed(t *testing.T) { func TestServiceRegistryInfo(t *testing.T) { t.Parallel() // slow test, make it parallel + ln := testutils.NewTestListener(t) + defer ln.Close() registryInfo := ®istry.Info{ Weight: 100, Tags: map[string]string{"aa": "bb"}, @@ -733,7 +769,7 @@ func TestServiceRegistryInfo(t *testing.T) { } var opts []config.Option opts = append(opts, WithRegistry(mockRegistry, registryInfo)) - opts = append(opts, WithHostPorts("127.0.0.1:0")) + opts = append(opts, WithListener(ln)) srv := New(opts...) go srv.Spin() waitEngineRunning(srv) @@ -749,6 +785,8 @@ func TestServiceRegistryInfo(t *testing.T) { func TestServiceRegistryNoInitInfo(t *testing.T) { t.Parallel() // slow test, make it parallel + ln := testutils.NewTestListener(t) + defer ln.Close() checkInfo := func(info *registry.Info) { assert.Assert(t, info == nil) } @@ -773,7 +811,7 @@ func TestServiceRegistryNoInitInfo(t *testing.T) { } var opts []config.Option opts = append(opts, WithRegistry(mockRegistry, nil)) - opts = append(opts, WithHostPorts("127.0.0.1:0")) + opts = append(opts, WithListener(ln)) srv := New(opts...) go srv.Spin() waitEngineRunning(srv) @@ -801,7 +839,9 @@ func (t testTracer) Start(ctx context.Context, c *app.RequestContext) context.Co func (t testTracer) Finish(ctx context.Context, c *app.RequestContext) {} func TestReuseCtx(t *testing.T) { - h := New(WithTracer(testTracer{}), WithHostPorts("127.0.0.1:0")) + ln := testutils.NewTestListener(t) + defer ln.Close() + h := New(WithTracer(testTracer{}), WithListener(ln)) h.GET("/ping", func(ctx context.Context, c *app.RequestContext) { assert.DeepEqual(t, 0, ctx.Value("testKey").(int)) }) @@ -810,7 +850,7 @@ func TestReuseCtx(t *testing.T) { waitEngineRunning(h) for i := 0; i < 1000; i++ { - _, _, err := c.Get(context.Background(), nil, testutils.GetURL(h, "/ping")) + _, _, err := c.Get(context.Background(), nil, fullURL(ln, "/ping")) assert.Nil(t, err) } } @@ -820,8 +860,10 @@ type CloseWithoutResetBuffer interface { } func TestOnprepare(t *testing.T) { + ln1 := testutils.NewTestListener(t) + defer ln1.Close() h1 := New( - WithHostPorts("127.0.0.1:0"), + WithListener(ln1), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context { b, err := conn.Peek(3) assert.Nil(t, err) @@ -840,33 +882,37 @@ func TestOnprepare(t *testing.T) { go h1.Spin() waitEngineRunning(h1) - _, _, err := c.Get(context.Background(), nil, testutils.GetURL(h1, "/ping")) + _, _, err := c.Get(context.Background(), nil, fullURL(ln1, "/ping")) assert.DeepEqual(t, "the server closed connection before returning the first response byte. Make sure the server returns 'Connection: close' response header before closing the connection", err.Error()) + ln2 := testutils.NewTestListener(t) + defer ln2.Close() h2 := New( WithOnAccept(func(conn net.Conn) context.Context { conn.Close() return context.Background() }), - WithHostPorts("127.0.0.1:0")) + WithListener(ln2)) h2.GET("/ping", func(ctx context.Context, c *app.RequestContext) { c.JSON(consts.StatusOK, utils.H{"ping": "pong"}) }) go h2.Spin() waitEngineRunning(h2) - _, _, err = c.Get(context.Background(), nil, testutils.GetURL(h2, "/ping")) + _, _, err = c.Get(context.Background(), nil, fullURL(ln2, "/ping")) if err == nil { t.Fatalf("err should not be nil") } + ln3 := testutils.NewTestListener(t) + defer ln3.Close() var h3 *Hertz h3 = New( WithOnAccept(func(conn net.Conn) context.Context { - assert.DeepEqual(t, conn.LocalAddr().String(), testutils.GetListenerAddr(h3)) + assert.DeepEqual(t, conn.LocalAddr().String(), ln3.Addr().String()) return context.Background() }), - WithHostPorts("127.0.0.1:0"), + WithListener(ln3), WithTransport(standard.NewTransporter)) h3.GET("/ping", func(ctx context.Context, c *app.RequestContext) { c.JSON(consts.StatusOK, utils.H{"ping": "pong"}) @@ -874,7 +920,7 @@ func TestOnprepare(t *testing.T) { go h3.Spin() waitEngineRunning(h3) - c.Get(context.Background(), nil, testutils.GetURL(h3, "/ping")) + c.Get(context.Background(), nil, fullURL(ln3, "/ping")) } type lockBuffer struct { @@ -894,32 +940,11 @@ func (l *lockBuffer) String() string { return l.b.String() } -func TestSilentMode(t *testing.T) { - hlog.SetSilentMode(true) - b := &lockBuffer{b: bytes.Buffer{}} - - hlog.SetOutput(b) - - h := New(WithHostPorts("127.0.0.1:0"), WithTransport(standard.NewTransporter)) - h.GET("/ping", func(c context.Context, ctx *app.RequestContext) { - ctx.Write([]byte("hello, world")) - }) - go h.Spin() - waitEngineRunning(h) - - d := standard.NewDialer() - conn, _ := d.DialConnection("tcp", testutils.GetListenerAddr(h), 0, nil) - conn.Write([]byte("aaa")) - conn.Close() - - if strings.Contains(b.String(), "Error") { - t.Fatalf("unexpected error in log: %s", b.String()) - } -} - func TestHertzDisableHeaderNamesNormalizing(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() h := New( - WithHostPorts("127.0.0.1:0"), + WithListener(ln), WithDisableHeaderNamesNormalizing(true), ) headerName := "CASE-senSITive-HEAder-NAME" @@ -944,7 +969,7 @@ func TestHertzDisableHeaderNamesNormalizing(t *testing.T) { cli, _ := c.NewClient(c.WithDisableHeaderNamesNormalizing(true)) - r := protocol.NewRequest("GET", testutils.GetURL(h, "/test"), nil) + r := protocol.NewRequest("GET", fullURL(ln, "/test"), nil) r.Header.DisableNormalizing() r.Header.Set(headerName, headerValue) res := protocol.AcquireResponse() @@ -959,8 +984,10 @@ func TestBindConfig(t *testing.T) { } bindConfig := binding.NewBindConfig() bindConfig.LooseZeroMode = true + ln := testutils.NewTestListener(t) + defer ln.Close() h := New( - WithHostPorts("127.0.0.1:0"), + WithListener(ln), WithBindConfig(bindConfig)) h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req @@ -974,13 +1001,15 @@ func TestBindConfig(t *testing.T) { waitEngineRunning(h) hc := http.Client{Timeout: time.Second} - _, err := hc.Get(testutils.GetURL(h, "/bind?a=")) + _, err := hc.Get(fullURL(ln, "/bind?a=")) assert.Nil(t, err) bindConfig = binding.NewBindConfig() bindConfig.LooseZeroMode = false + ln2 := testutils.NewTestListener(t) + defer ln2.Close() h2 := New( - WithHostPorts("127.0.0.1:0"), + WithListener(ln2), WithBindConfig(bindConfig)) h2.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req @@ -993,7 +1022,7 @@ func TestBindConfig(t *testing.T) { go h2.Spin() waitEngineRunning(h2) - _, err = hc.Get(testutils.GetURL(h2, "/bind?a=")) + _, err = hc.Get(fullURL(ln2, "/bind?a=")) assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } @@ -1002,8 +1031,10 @@ func TestCustomBinder(t *testing.T) { type Req struct { A int `query:"a"` } + ln := testutils.NewTestListener(t) + defer ln.Close() h := New( - WithHostPorts("127.0.0.1:0"), + WithListener(ln), WithCustomBinder(binder.NewBinderWithValidateError(errors.New("test binder")))) h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req @@ -1018,7 +1049,7 @@ func TestCustomBinder(t *testing.T) { waitEngineRunning(h) hc := http.Client{Timeout: time.Second} - _, err := hc.Get(testutils.GetURL(h, "/bind?a=")) + _, err := hc.Get(fullURL(ln, "/bind?a=")) assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } @@ -1031,7 +1062,9 @@ func TestValidateConfigRegValidateFunc(t *testing.T) { validateConfig.MustRegValidateFunc("f", func(args ...interface{}) error { return fmt.Errorf("test validator") }) - h := New(WithHostPorts("127.0.0.1:0")) + ln := testutils.NewTestListener(t) + defer ln.Close() + h := New(WithListener(ln)) h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) @@ -1045,7 +1078,7 @@ func TestValidateConfigRegValidateFunc(t *testing.T) { waitEngineRunning(h) hc := http.Client{Timeout: time.Second} - _, err := hc.Get(testutils.GetURL(h, "/bind?a=2")) + _, err := hc.Get(fullURL(ln, "/bind?a=2")) assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } @@ -1054,8 +1087,10 @@ func TestCustomValidator(t *testing.T) { type Req struct { A int `query:"a" vd:"f($)"` } + ln := testutils.NewTestListener(t) + defer ln.Close() h := New( - WithHostPorts("127.0.0.1:0"), + WithListener(ln), WithCustomValidatorFunc(func(_ *protocol.Request, _ interface{}) error { return errors.New("test mock validator") })) @@ -1071,7 +1106,7 @@ func TestCustomValidator(t *testing.T) { go h.Spin() time.Sleep(100 * time.Millisecond) hc := http.Client{Timeout: time.Second} - _, err := hc.Get(testutils.GetURL(h, "/bind?a=2")) + _, err := hc.Get(fullURL(ln, "/bind?a=2")) assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } @@ -1103,8 +1138,10 @@ func TestValidateConfigSetSetErrorFactory(t *testing.T) { } validateConfig := binding.NewValidateConfig() validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc) + ln := testutils.NewTestListener(t) + defer ln.Close() h := New( - WithHostPorts("127.0.0.1:0"), + WithListener(ln), WithValidateConfig(validateConfig)) h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req TestValidate @@ -1119,7 +1156,7 @@ func TestValidateConfigSetSetErrorFactory(t *testing.T) { waitEngineRunning(h) hc := http.Client{Timeout: time.Second} - _, err := hc.Get(testutils.GetURL(h, "/bind?b=1")) + _, err := hc.Get(fullURL(ln, "/bind?b=1")) assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } @@ -1130,8 +1167,10 @@ func TestValidateConfigAndBindConfig(t *testing.T) { } validateConfig := binding.NewValidateConfig() validateConfig.ValidateTag = "vt" + ln := testutils.NewTestListener(t) + defer ln.Close() h := New( - WithHostPorts("127.0.0.1:0"), + WithListener(ln), WithValidateConfig(validateConfig)) h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req @@ -1146,14 +1185,16 @@ func TestValidateConfigAndBindConfig(t *testing.T) { waitEngineRunning(h) hc := http.Client{Timeout: time.Second} - _, err := hc.Get(testutils.GetURL(h, "/bind?a=135")) + _, err := hc.Get(fullURL(ln, "/bind?a=135")) assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } func TestWithDisableDefaultDate(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() h := New( - WithHostPorts("127.0.0.1:0"), + WithListener(ln), WithDisableDefaultDate(true), ) h.GET("/", func(_ context.Context, c *app.RequestContext) {}) @@ -1161,13 +1202,15 @@ func TestWithDisableDefaultDate(t *testing.T) { waitEngineRunning(h) hc := http.Client{Timeout: time.Second} - r, _ := hc.Get(testutils.GetURL(h, "")) //nolint:errcheck + r, _ := hc.Get(fullURL(ln, "")) //nolint:errcheck assert.DeepEqual(t, "", r.Header.Get("Date")) } func TestWithDisableDefaultContentType(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() h := New( - WithHostPorts("127.0.0.1:0"), + WithListener(ln), WithDisableDefaultContentType(true), ) h.GET("/", func(_ context.Context, c *app.RequestContext) {}) @@ -1175,13 +1218,15 @@ func TestWithDisableDefaultContentType(t *testing.T) { waitEngineRunning(h) hc := http.Client{Timeout: time.Second} - r, _ := hc.Get(testutils.GetURL(h, "")) //nolint:errcheck + r, _ := hc.Get(fullURL(ln, "")) //nolint:errcheck assert.DeepEqual(t, "", r.Header.Get("Content-Type")) } func TestWithSenseClientDisconnection(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() var closeFlag int32 - h := New(WithHostPorts("127.0.0.1:0"), WithSenseClientDisconnection(true)) + h := New(WithListener(ln), WithSenseClientDisconnection(true)) h.GET("/ping", func(c context.Context, ctx *app.RequestContext) { assert.DeepEqual(t, "aa", string(ctx.Host())) ch := make(chan struct{}) @@ -1195,7 +1240,7 @@ func TestWithSenseClientDisconnection(t *testing.T) { go h.Spin() waitEngineRunning(h) - con, err := net.Dial("tcp", testutils.GetListenerAddr(h)) + con, err := net.Dial("tcp", ln.Addr().String()) assert.Nil(t, err) _, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n")) assert.Nil(t, err) @@ -1207,8 +1252,10 @@ func TestWithSenseClientDisconnection(t *testing.T) { } func TestWithSenseClientDisconnectionAndWithOnConnect(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() var closeFlag int32 - h := New(WithHostPorts("127.0.0.1:0"), WithSenseClientDisconnection(true), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context { + h := New(WithListener(ln), WithSenseClientDisconnection(true), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context { return ctx })) h.GET("/ping", func(c context.Context, ctx *app.RequestContext) { @@ -1224,7 +1271,7 @@ func TestWithSenseClientDisconnectionAndWithOnConnect(t *testing.T) { go h.Spin() waitEngineRunning(h) - con, err := net.Dial("tcp", testutils.GetListenerAddr(h)) + con, err := net.Dial("tcp", ln.Addr().String()) assert.Nil(t, err) _, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n")) assert.Nil(t, err) @@ -1236,7 +1283,9 @@ func TestWithSenseClientDisconnectionAndWithOnConnect(t *testing.T) { } func TestServerReturns413And431OnSizeLimits(t *testing.T) { - h := Default(WithHostPorts("127.0.0.1:0"), WithMaxHeaderBytes(500), WithMaxRequestBodySize(1000)) + ln := testutils.NewTestListener(t) + defer ln.Close() + h := Default(WithListener(ln), WithMaxHeaderBytes(500), WithMaxRequestBodySize(1000)) h.GET("/test", func(c context.Context, ctx *app.RequestContext) { ctx.String(consts.StatusOK, "success") @@ -1249,7 +1298,7 @@ func TestServerReturns413And431OnSizeLimits(t *testing.T) { waitEngineRunning(h) defer h.Shutdown(context.Background()) - addr := testutils.GetListenerAddr(h) + addr := ln.Addr().String() client := &http.Client{Timeout: 2 * time.Second} // Test 431 - Request Header Fields Too Large diff --git a/pkg/app/server/hertz_unix_test.go b/pkg/app/server/hertz_unix_test.go index 53376666b..369931cec 100644 --- a/pkg/app/server/hertz_unix_test.go +++ b/pkg/app/server/hertz_unix_test.go @@ -81,7 +81,10 @@ func TestReusePorts(t *testing.T) { } func TestHertz_Spin(t *testing.T) { - engine := New(WithHostPorts("127.0.0.1:0")) + ln := testutils.NewTestListener(t) + defer ln.Close() + + engine := New(WithListener(ln)) engine.GET("/test", func(c context.Context, ctx *app.RequestContext) { time.Sleep(40 * time.Millisecond) path := ctx.Request.URI().PathOriginal() @@ -106,7 +109,7 @@ func TestHertz_Spin(t *testing.T) { ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() for range ticker.C { - _, err := hc.Get(testutils.GetURL(engine, "/test2")) + _, err := hc.Get(fullURL(ln, "/test2")) t.Logf("[%v]begin listening\n", time.Now()) if err != nil { t.Logf("[%v]listening closed: %v", time.Now(), err) @@ -117,7 +120,7 @@ func TestHertz_Spin(t *testing.T) { }() go func() { t.Logf("[%v]begin request\n", time.Now()) - resp, err = http.Get(testutils.GetURL(engine, "/test")) + resp, err = http.Get(fullURL(ln, "/test")) t.Logf("[%v]end request\n", time.Now()) ch <- struct{}{} }() diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index 7b44136f3..892250f00 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -162,6 +162,26 @@ func WithHostPorts(hp string) config.Option { }} } +// WithListener sets the listener to use. +// +// If set, the server will use this listener instead of creating a new one. +// This is useful for custom listener implementations or testing. +// Note: This will update Network and Addr based on the listener's address, +// and reset ListenConfig since it's not needed when a listener is provided. +// +// WARNING: Custom net.Listener implementations may not be supported by cloudwego/netpoll. +// If your custom listener doesn't support netpoll, you need to explicitly set the transporter to the standard library: +// +// WithListener(customListener), WithTransport(standard.NewTransporter) +func WithListener(ln net.Listener) config.Option { + return config.Option{F: func(o *config.Options) { + o.Listener = ln + o.Network = ln.Addr().Network() + o.Addr = ln.Addr().String() + o.ListenConfig = nil + }} +} + // WithBasePath sets basePath.Must be "/" prefix and suffix,If not the default concatenate "/" func WithBasePath(basePath string) config.Option { return config.Option{F: func(o *config.Options) { diff --git a/pkg/app/server/option_test.go b/pkg/app/server/option_test.go index baf5cd4cd..b0d7e4f0d 100644 --- a/pkg/app/server/option_test.go +++ b/pkg/app/server/option_test.go @@ -24,6 +24,7 @@ import ( "testing" "time" + "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/test/assert" @@ -150,6 +151,29 @@ func TestDefaultOptions(t *testing.T) { assert.DeepEqual(t, opt.MaxHeaderBytes, 1<<20) } +func TestWithListener(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() + + cfg := &net.ListenConfig{} + opt := config.NewOptions([]config.Option{ + WithHostPorts("127.0.0.1:8888"), + WithNetwork("udp"), + WithListenConfig(cfg), + WithListener(ln), + }) + + // Listener should be set + assert.DeepEqual(t, opt.Listener, ln) + + // Network and Addr should be updated from listener + assert.DeepEqual(t, opt.Network, ln.Addr().Network()) + assert.DeepEqual(t, opt.Addr, ln.Addr().String()) + + // ListenConfig should be reset + assert.Assert(t, opt.ListenConfig == nil) +} + type mockTransporter struct{} func (m *mockTransporter) ListenAndServe(onData network.OnData) (err error) { diff --git a/pkg/app/server/server_bench_test.go b/pkg/app/server/server_bench_test.go new file mode 100644 index 000000000..89855b413 --- /dev/null +++ b/pkg/app/server/server_bench_test.go @@ -0,0 +1,80 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "bufio" + "context" + "net" + "testing" + + "github.com/cloudwego/hertz/internal/testutils" + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/network/standard" +) + +func BenchmarkServerHelloWorld(b *testing.B) { + ln := testutils.NewTestListener(b) + defer ln.Close() + + h := Default(WithListener(ln), WithTransport(standard.NewTransporter)) + h.GET("/hello", func(c context.Context, ctx *app.RequestContext) { + ctx.SetBodyString("hello world") + }) + + go h.Run() + waitEngineRunning(h) + defer h.Engine.Close() + + addr := ln.Addr().String() + + // Pre-create connection pool with keep-alive + const poolSize = 10 + connPool := make([]net.Conn, poolSize) + readerPool := make([]*bufio.Reader, poolSize) + for i := 0; i < poolSize; i++ { + conn, err := net.Dial("tcp", addr) + if err != nil { + b.Fatalf("failed to dial: %s", err) + } + connPool[i] = conn + readerPool[i] = bufio.NewReader(conn) + defer conn.Close() + } + + request := []byte("GET /hello HTTP/1.1\r\nHost: localhost\r\nConnection: keep-alive\r\n\r\n") + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + conn := connPool[i%poolSize] + reader := readerPool[i%poolSize] + _, err := conn.Write(request) + if err != nil { + b.Fatalf("write error: %s", err) + } + _, err = reader.Peek(1) + if err != nil { + b.Fatal(err) + } + _, err = reader.Discard(reader.Buffered()) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/pkg/common/adaptor/handler.go b/pkg/common/adaptor/handler.go index 2af3c4f45..9a2759db1 100644 --- a/pkg/common/adaptor/handler.go +++ b/pkg/common/adaptor/handler.go @@ -22,6 +22,8 @@ import ( "errors" "net" "net/http" + "runtime" + "sync" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/app" @@ -86,7 +88,14 @@ func HertzHandler(h http.Handler) app.HandlerFunc { // coz it's server response // no need to copy anything from hertz Response w := &httpResponseWriter{rc: rc} + h.ServeHTTP(w, req) + if w.hijacked != nil { + // wait for hijacked conn to close before returning, + // otherwise either hertz will close the conn + // or netpoll may reuse the conn for next request. + <-w.hijacked + } } } @@ -97,8 +106,9 @@ type httpResponseWriter struct { err error wroteHeader bool - hijacked bool skipBody bool + + hijacked chan struct{} // != nil if hijacked } var errConnHijacked = errors.New("hertz net/http adaptor: conn hijacked") @@ -111,8 +121,9 @@ func (p *httpResponseWriter) Header() http.Header { return p.header } -func (p *httpResponseWriter) Write(b []byte) (n int, _ error) { - if p.hijacked { +// Write implements http.ResponseWriter.Write +func (p *httpResponseWriter) Write(b []byte) (n int, err error) { + if p.hijacked != nil { return 0, errConnHijacked } if !p.wroteHeader { @@ -128,8 +139,9 @@ func (p *httpResponseWriter) Write(b []byte) (n int, _ error) { return n, p.err } +// WriteHeader implements http.ResponseWriter.WriteHeader func (p *httpResponseWriter) WriteHeader(statusCode int) { - if p.wroteHeader || p.hijacked { + if p.wroteHeader || p.hijacked != nil { return } p.wroteHeader = true @@ -153,35 +165,50 @@ func (p *httpResponseWriter) WriteHeader(statusCode int) { // must be set for hertz request loop or it would write header and body after handler returns r.HijackWriter(noopWriter{}) p.err = resp.WriteHeader(&r.Header, w) - return - } - if r.Header.ContentLength() < 0 { - r.HijackWriter(resp.NewChunkedBodyWriter(r, w)) + } else if r.Header.ContentLength() < 0 { + // For chunked encoding, write headers immediately + cw := resp.NewChunkedBodyWriter(r, w) + r.HijackWriter(cw) + type chunkedBodyWriter interface { + WriteHeader() error + } + p.err = cw.(chunkedBodyWriter).WriteHeader() } else { - p.err = resp.WriteHeader(&r.Header, w) // use Writer directly instead of keep buffering data in resp.BodyBuffer() // you never know how much data would be written to response r.HijackWriter(writer2writerExt(w)) + p.err = resp.WriteHeader(&r.Header, w) } } var _ http.Flusher = (*httpResponseWriter)(nil) -// Flush implements the [http.Flusher] +// Flush implements http.Flusher and captures any flush errors func (p *httpResponseWriter) Flush() { - _ = p.rc.GetWriter().Flush() + if p.err == nil { + p.err = p.rc.GetWriter().Flush() + } } var _ http.Hijacker = (*httpResponseWriter)(nil) -// Hijack implements the [net/http.Hijacker] +// Hijack implements http.Hijacker func (p *httpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - if p.hijacked { + if p.hijacked != nil { return nil, nil, errConnHijacked } - p.hijacked = true - - conn := p.rc.GetConn() + if p.err != nil { + return nil, nil, p.err + } + // If headers were already written, flush the buffer to avoid losing + // any pending data before hijacking the connection + if p.wroteHeader { + if p.err = p.rc.GetWriter().Flush(); p.err != nil { + return nil, nil, p.err + } + } + conn := newHijackedConn(p.rc.GetConn()) + p.hijacked = conn.closeCh // reset timeout if any _ = conn.SetReadTimeout(0) @@ -213,3 +240,29 @@ var _ network.ExtWriter = noopWriter{} func (noopWriter) Write(b []byte) (int, error) { return len(b), nil } func (noopWriter) Flush() error { return nil } func (noopWriter) Finalize() error { return nil } + +type hijackedConn struct { + network.Conn + + closeOnce sync.Once + closeCh chan struct{} +} + +func newHijackedConn(conn network.Conn) *hijackedConn { + c := &hijackedConn{Conn: conn, closeCh: make(chan struct{})} + runtime.SetFinalizer(c, hijackedConnFinalizer) + return c +} + +func hijackedConnFinalizer(c *hijackedConn) { + _ = c.Close() +} + +func (c *hijackedConn) Close() error { + runtime.SetFinalizer(c, nil) + err := c.Conn.Close() + c.closeOnce.Do(func() { + close(c.closeCh) + }) + return err +} diff --git a/pkg/common/adaptor/handler_test.go b/pkg/common/adaptor/handler_test.go index 0facc2b60..cbfbb8626 100644 --- a/pkg/common/adaptor/handler_test.go +++ b/pkg/common/adaptor/handler_test.go @@ -25,6 +25,7 @@ import ( "mime/multipart" "net" "net/http" + "runtime" "strings" "sync" "testing" @@ -41,13 +42,14 @@ import ( var adaptorFiles embed.FS func runEngine(onCreate func(*route.Engine)) (string, *route.Engine) { + ln := testutils.NewTestListener(&testing.T{}) opt := config.NewOptions(nil) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) onCreate(engine) go engine.Run() testutils.WaitEngineRunning(engine) - return testutils.GetListenerAddr(engine), engine + return ln.Addr().String(), engine } func TestHertzHandler_BodyStream(t *testing.T) { @@ -114,23 +116,57 @@ func TestHertzHandler_Chunked(t *testing.T) { } } +func TestHertzHandler_WriteHeader(t *testing.T) { + h := HertzHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", "0") + w.WriteHeader(500) + w.(http.Flusher).Flush() + time.Sleep(time.Second) // Simulate long-running handler + })) + addr, e := runEngine(func(e *route.Engine) { + e.GET("/test", h) + }) + defer e.Close() + + conn, err := net.Dial("tcp", addr) + assert.Nil(t, err) + defer conn.Close() + + _, err = conn.Write([]byte("GET /test HTTP/1.1\r\nHost: example.com\r\n\r\n")) + assert.Nil(t, err) + + // Set a short read deadline to verify headers arrive quickly after WriteHeader() + conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) + b := make([]byte, 200) + n, err := conn.Read(b) + assert.Nil(t, err) + assert.Assert(t, strings.HasPrefix(string(b[:n]), "HTTP/1.1 500 ")) +} + func TestHertzHandler_Hijack(t *testing.T) { h := HertzHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, rw, err := w.(http.Hijacker).Hijack() + conn, rw, err := w.(http.Hijacker).Hijack() assert.Nil(t, err) _, _, err = w.(http.Hijacker).Hijack() // hijacked assert.NotNil(t, err) w.WriteHeader(500) // hijacked, noop - _, err = w.Write([]byte("hello")) - assert.Assert(t, err == errConnHijacked) - rw.Write([]byte("hello")) - rw.Flush() - b := make([]byte, 10) - n, err := rw.Read(b) - assert.Nil(t, err) - assert.Assert(t, string(b[:n]) == "world") + go func() { + defer conn.Close() + time.Sleep(50 * time.Millisecond) + _, err = w.Write([]byte("hello")) + assert.Assert(t, err == errConnHijacked) + + _, err = rw.Write([]byte("hello")) + assert.Nil(t, err) + err = rw.Flush() + assert.Nil(t, err) + b := make([]byte, 10) + n, err := rw.Read(b) + assert.Nil(t, err) + assert.Assert(t, string(b[:n]) == "world") + }() })) addr, e := runEngine(func(e *route.Engine) { e.GET("/test", h) @@ -147,12 +183,80 @@ func TestHertzHandler_Hijack(t *testing.T) { assert.Assert(t, string(b[:n]) == "hello", string(b[:n])) _, err = conn.Write([]byte("world")) assert.Nil(t, err) - n, err = conn.Read(b) // Keep-Alive will not work if hijacked assert.Assert(t, err == io.EOF) assert.Assert(t, n == 0) } +// TestHertzHandler_HijackGC tests that hijacked conn is closed by GC finalizer +// when user forgets to call Close() +func TestHertzHandler_HijackGC(t *testing.T) { + h := HertzHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _, err := w.(http.Hijacker).Hijack() + assert.Nil(t, err) + // intentionally not closing conn, let GC handle it + runtime.GC() + runtime.GC() // make sure the net.Conn is closed by GC + })) + addr, e := runEngine(func(e *route.Engine) { + e.GET("/test", h) + }) + defer e.Close() + + conn, err := net.Dial("tcp", addr) + assert.Nil(t, err) + defer conn.Close() + conn.Write([]byte("GET /test HTTP/1.1\r\nHost: example.com\r\n\r\n")) + + b := make([]byte, 100) + n, err := conn.Read(b) // conn should be closed by finalizer + assert.Assert(t, err == io.EOF, err) + assert.Assert(t, n == 0) +} + +// TestHertzHandler_WriteHeader_Hijack verifies that headers are properly flushed +// before hijacking the connection. This test ensures that when WriteHeader is called +// before Hijack, the headers are correctly sent and the connection can be taken over. +func TestHertzHandler_WriteHeader_Hijack(t *testing.T) { + h := HertzHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Custom-Header", "test-value") + w.WriteHeader(200) + + conn, rw, err := w.(http.Hijacker).Hijack() + assert.Nil(t, err) + defer conn.Close() + + _, err = rw.WriteString("hijacked response body") + assert.Nil(t, err) + assert.Nil(t, rw.Flush()) + })) + addr, e := runEngine(func(e *route.Engine) { + e.GET("/test", h) + }) + defer e.Close() + + conn, err := net.Dial("tcp", addr) + assert.Nil(t, err) + defer conn.Close() + + _, err = conn.Write([]byte("GET /test HTTP/1.1\r\nHost: example.com\r\n\r\n")) + assert.Nil(t, err) + + // Wait briefly for server to process and send response + time.Sleep(50 * time.Millisecond) + conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) + + b := make([]byte, 1024) + n, err := conn.Read(b) + assert.Nil(t, err) + + response := string(b[:n]) + t.Logf("Response: %q", response) + assert.Assert(t, strings.Contains(response, "HTTP/1.1 200 OK"), response) + assert.Assert(t, strings.Contains(response, "X-Custom-Header: test-value"), response) + assert.Assert(t, strings.Contains(response, "hijacked response body"), response) +} + func TestHertzHandler_FSEmbed(t *testing.T) { addr, e := runEngine(func(e *route.Engine) { h := HertzHandler(http.FileServer(http.FS(adaptorFiles))) diff --git a/pkg/common/adaptor/request_test.go b/pkg/common/adaptor/request_test.go index 3e0a37dfd..e9220dd8a 100644 --- a/pkg/common/adaptor/request_test.go +++ b/pkg/common/adaptor/request_test.go @@ -19,8 +19,10 @@ package adaptor import ( "context" "io/ioutil" + "net" "net/http" "net/url" + "path" "strings" "testing" "time" @@ -33,6 +35,10 @@ import ( "github.com/cloudwego/hertz/pkg/protocol/consts" ) +func fullURL(ln net.Listener, p string) string { + return "http://" + path.Join(ln.Addr().String(), p) +} + func TestCompatResponse_WriteHeader(t *testing.T) { var testHeader http.Header var testBody string @@ -47,7 +53,10 @@ func TestCompatResponse_WriteHeader(t *testing.T) { testBody = "test body" - h := server.New(server.WithHostPorts("127.0.0.1:0")) + ln := testutils.NewTestListener(t) + defer ln.Close() + + h := server.New(server.WithListener(ln)) h.POST("/test1", func(c context.Context, ctx *app.RequestContext) { req, _ := GetCompatRequest(&ctx.Request) resp := GetCompatResponseWriter(&ctx.Response) @@ -63,8 +72,8 @@ func TestCompatResponse_WriteHeader(t *testing.T) { go h.Spin() time.Sleep(100 * time.Millisecond) - testUrl1 := testutils.GetURL(h, "/test1") - testUrl2 := testutils.GetURL(h, "/test2") + testUrl1 := fullURL(ln, "/test1") + testUrl2 := fullURL(ln, "/test2") makeACall(t, http.MethodPost, testUrl1, testHeader, testBody, testStatusCode, []byte(testCookieValue)) makeACall(t, http.MethodPost, testUrl2, testHeader, testBody, consts.StatusOK, []byte(testCookieValue)) } diff --git a/pkg/common/config/option.go b/pkg/common/config/option.go index c84377edd..eaa37ea3f 100644 --- a/pkg/common/config/option.go +++ b/pkg/common/config/option.go @@ -76,6 +76,7 @@ type Options struct { ALPN bool Tracers []interface{} TraceLevel interface{} + Listener net.Listener ListenConfig *net.ListenConfig BindConfig interface{} CustomBinder interface{} diff --git a/pkg/network/netpoll/dial_test.go b/pkg/network/netpoll/dial_test.go index f867425b2..c82001f6d 100644 --- a/pkg/network/netpoll/dial_test.go +++ b/pkg/network/netpoll/dial_test.go @@ -23,23 +23,19 @@ import ( "testing" "time" + "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" - "github.com/cloudwego/hertz/pkg/network" ) -func getListenerAddr(trans network.Transporter) string { - return trans.(*transporter).Listener().Addr().String() -} - func TestDial(t *testing.T) { t.Run("NetpollDial", func(t *testing.T) { - const nw = "tcp" - var addr = "127.0.0.1:0" + ln := testutils.NewTestListener(t) + defer ln.Close() + transporter := NewTransporter(&config.Options{ - Addr: addr, - Network: nw, + Listener: ln, }) go transporter.ListenAndServe(func(ctx context.Context, conn interface{}) error { return nil @@ -48,18 +44,19 @@ func TestDial(t *testing.T) { time.Sleep(100 * time.Millisecond) dial := NewDialer() + addr := ln.Addr().String() + nw := ln.Addr().Network() + // DialConnection - _, err := dial.DialConnection("tcp", "localhost:10101", time.Second, nil) // wrong addr + _, err := dial.DialConnection(nw, "localhost:10101", time.Second, nil) // wrong addr assert.NotNil(t, err) - - addr = getListenerAddr(transporter) nwConn, err := dial.DialConnection(nw, addr, time.Second, nil) assert.Nil(t, err) defer nwConn.Close() _, err = nwConn.Write([]byte("abcdef")) assert.Nil(t, err) // DialTimeout - nConn, err := dial.DialTimeout(nw, addr, time.Second, nil) + nConn, err := dial.DialTimeout("tcp", addr, time.Second, nil) assert.Nil(t, err) defer nConn.Close() }) diff --git a/pkg/network/netpoll/transport.go b/pkg/network/netpoll/transport.go index 64fee4102..0200cf8eb 100644 --- a/pkg/network/netpoll/transport.go +++ b/pkg/network/netpoll/transport.go @@ -70,6 +70,7 @@ func NewTransporter(options *config.Options) network.Transporter { keepAliveTimeout: options.KeepAliveTimeout, readTimeout: options.ReadTimeout, writeTimeout: options.WriteTimeout, + ln: options.Listener, listenConfig: options.ListenConfig, OnAccept: options.OnAccept, OnConnect: options.OnConnect, @@ -88,18 +89,20 @@ func (t *transporter) ListenAndServe(onReq network.OnData) (err error) { network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck t.mu.Lock() - if t.listenConfig != nil { - t.ln, err = t.listenConfig.Listen(context.Background(), t.network, t.addr) - } else { - t.ln, err = net.Listen(t.network, t.addr) + if t.ln == nil { + if t.listenConfig != nil { + t.ln, err = t.listenConfig.Listen(context.Background(), t.network, t.addr) + } else { + t.ln, err = net.Listen(t.network, t.addr) + } + if err != nil { + t.mu.Unlock() + panic("create netpoll listener fail: " + err.Error()) + } } ln := t.ln t.mu.Unlock() - if err != nil { - panic("create netpoll listener fail: " + err.Error()) - } - // Initialize custom option for EventLoop opts := []netpoll.Option{ netpoll.WithIdleTimeout(t.keepAliveTimeout), diff --git a/pkg/network/netpoll/transport_test.go b/pkg/network/netpoll/transport_test.go index 2a1e4ed49..1d8dde009 100644 --- a/pkg/network/netpoll/transport_test.go +++ b/pkg/network/netpoll/transport_test.go @@ -25,6 +25,7 @@ import ( "testing" "time" + "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/network" @@ -32,13 +33,13 @@ import ( ) func TestTransport(t *testing.T) { - const nw = "tcp" t.Run("TestDefault", func(t *testing.T) { - var addr = "127.0.0.1:0" + ln := testutils.NewTestListener(t) + defer ln.Close() + var onConnFlag, onAcceptFlag, onDataFlag int32 transporter := NewTransporter(&config.Options{ - Addr: addr, - Network: nw, + Listener: ln, OnConnect: func(ctx context.Context, conn network.Conn) context.Context { atomic.StoreInt32(&onConnFlag, 1) return ctx @@ -56,7 +57,8 @@ func TestTransport(t *testing.T) { defer transporter.Close() time.Sleep(100 * time.Millisecond) - addr = getListenerAddr(transporter) + addr := ln.Addr().String() + nw := ln.Addr().Network() dial := NewDialer() conn, err := dial.DialConnection(nw, addr, time.Second, nil) @@ -71,11 +73,12 @@ func TestTransport(t *testing.T) { }) t.Run("TestSenseClientDisconnection", func(t *testing.T) { - var addr = "127.0.0.1:0" + ln := testutils.NewTestListener(t) + defer ln.Close() + var onReqFlag int32 transporter := NewTransporter(&config.Options{ - Addr: addr, - Network: nw, + Listener: ln, SenseClientDisconnection: true, }) @@ -88,7 +91,8 @@ func TestTransport(t *testing.T) { defer transporter.Close() time.Sleep(100 * time.Millisecond) - addr = getListenerAddr(transporter) + addr := ln.Addr().String() + nw := ln.Addr().Network() dial := NewDialer() conn, err := dial.DialConnection(nw, addr, time.Second, nil) @@ -110,8 +114,8 @@ func TestTransport(t *testing.T) { }) }} transporter := NewTransporter(&config.Options{ + Network: "tcp", Addr: "127.0.0.1:0", - Network: nw, ListenConfig: listenCfg, }) go transporter.ListenAndServe(func(ctx context.Context, conn interface{}) error { @@ -130,4 +134,35 @@ func TestTransport(t *testing.T) { }) }) }) + + t.Run("TestWithListener", func(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() + + var onDataFlag int32 + trans := NewTransporter(&config.Options{ + Listener: ln, + }).(*transporter) + go trans.ListenAndServe(func(ctx context.Context, conn interface{}) error { + atomic.StoreInt32(&onDataFlag, 1) + return nil + }) + defer trans.Close() + time.Sleep(100 * time.Millisecond) + + // Verify listener is used + assert.DeepEqual(t, ln.Addr().String(), trans.Listener().Addr().String()) + + nw := ln.Addr().Network() + + // Connect and send data + dial := NewDialer() + conn, err := dial.DialConnection(nw, ln.Addr().String(), time.Second, nil) + assert.Nil(t, err) + _, err = conn.Write([]byte("test")) + assert.Nil(t, err) + time.Sleep(100 * time.Millisecond) + + assert.Assert(t, atomic.LoadInt32(&onDataFlag) == 1) + }) } diff --git a/pkg/network/standard/dial_test.go b/pkg/network/standard/dial_test.go index 66f0fcf65..fcb06e3fe 100644 --- a/pkg/network/standard/dial_test.go +++ b/pkg/network/standard/dial_test.go @@ -29,21 +29,19 @@ import ( "testing" "time" + "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/test/assert" - "github.com/cloudwego/hertz/pkg/network" ) -func getListenerAddr(trans network.Transporter) string { - return trans.(*transport).Listener().Addr().String() -} - func TestDial(t *testing.T) { const nw = "tcp" - addr := "127.0.0.1:0" + ln := testutils.NewTestListener(t) + defer ln.Close() + transporter := NewTransporter(&config.Options{ - Addr: addr, - Network: nw, + Listener: ln, + Network: nw, }) go transporter.ListenAndServe(func(ctx context.Context, conn interface{}) error { @@ -52,7 +50,7 @@ func TestDial(t *testing.T) { defer transporter.Close() time.Sleep(time.Millisecond * 100) - addr = getListenerAddr(transporter) + addr := ln.Addr().String() dial := NewDialer() _, err := dial.DialConnection(nw, addr, time.Second, nil) diff --git a/pkg/network/standard/transport.go b/pkg/network/standard/transport.go index 79d5ea655..ec89b9df0 100644 --- a/pkg/network/standard/transport.go +++ b/pkg/network/standard/transport.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "errors" "net" + "strings" "sync" "sync/atomic" "time" @@ -46,7 +47,8 @@ type transport struct { OnConnect func(ctx context.Context, conn network.Conn) context.Context // active connections. it +1 after accept and -1 after handler returns - active int32 + active int32 + shuttingDown int32 mu sync.RWMutex ln net.Listener @@ -61,24 +63,33 @@ func (t *transport) Listener() net.Listener { func (t *transport) serve() (err error) { network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck t.mu.Lock() - if t.listenConfig != nil { - t.ln, err = t.listenConfig.Listen(context.Background(), t.network, t.addr) - } else { - t.ln, err = net.Listen(t.network, t.addr) + if t.ln == nil { + if t.listenConfig != nil { + t.ln, err = t.listenConfig.Listen(context.Background(), t.network, t.addr) + } else { + t.ln, err = net.Listen(t.network, t.addr) + } + if err != nil { + t.mu.Unlock() + return err + } } // fix concurrency issue // normally listener must not be changed during serve() ln := t.ln t.mu.Unlock() - if err != nil { - return err - } hlog.SystemLogger().Infof("HTTP server listening on address=%s", ln.Addr().String()) for { ctx := context.Background() conn, err := ln.Accept() if err != nil { - hlog.SystemLogger().Errorf("Error=%s", err.Error()) + if atomic.LoadInt32(&t.shuttingDown) > 0 { + return nil + } + if strings.Contains(err.Error(), "closed") { + return nil + } + hlog.SystemLogger().Errorf("Accept err: %v", err) return err } t.updateActive(1) @@ -135,6 +146,8 @@ var ( ) func (t *transport) Shutdown(ctx context.Context) error { + atomic.StoreInt32(&t.shuttingDown, 1) + defer func() { network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck }() @@ -180,6 +193,7 @@ func NewTransporter(options *config.Options) network.Transporter { readTimeout: options.ReadTimeout, senseClientDisconnection: options.SenseClientDisconnection, tls: options.TLS, + ln: options.Listener, listenConfig: options.ListenConfig, OnAccept: options.OnAccept, OnConnect: options.OnConnect, diff --git a/pkg/network/standard/transport_test.go b/pkg/network/standard/transport_test.go index a38861a9f..d6868e51f 100644 --- a/pkg/network/standard/transport_test.go +++ b/pkg/network/standard/transport_test.go @@ -24,6 +24,7 @@ import ( "time" internalNetwork "github.com/cloudwego/hertz/internal/network" + "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/network" @@ -112,3 +113,22 @@ func TestTransporter(t *testing.T) { time.Sleep(10 * time.Millisecond) // wait handler returns, and active conn to be updated. checkActiveConn(0) } + +func TestAcceptError(t *testing.T) { + ln := testutils.NewTestListener(t) + defer ln.Close() + + trans := NewTransporter(&config.Options{Listener: ln}).(*transport) + errCh := make(chan error, 1) + go func() { errCh <- trans.ListenAndServe(func(context.Context, interface{}) error { return nil }) }() + + time.Sleep(10 * time.Millisecond) // Wait for serve to start + + // Close listener to trigger error + ln.Close() + + // Wait for serve to exit with error + if err := <-errCh; err != nil { + t.Fatal("expected nil after listener close") + } +} diff --git a/pkg/protocol/header.go b/pkg/protocol/header.go index 89072e37e..89a36e085 100644 --- a/pkg/protocol/header.go +++ b/pkg/protocol/header.go @@ -1644,8 +1644,8 @@ func (h *ResponseHeader) Get(key string) string { // GetAll returns all header value for the given key // it is concurrent safety and long lifetime. func (h *RequestHeader) GetAll(key string) []string { - res := make([]string, 0) headers := h.PeekAll(key) + res := make([]string, 0, len(headers)) for _, header := range headers { res = append(res, string(header)) } @@ -1655,8 +1655,8 @@ func (h *RequestHeader) GetAll(key string) []string { // GetAll returns all header value for the given key and is concurrent safety. // it is concurrent safety and long lifetime. func (h *ResponseHeader) GetAll(key string) []string { - res := make([]string, 0) headers := h.PeekAll(key) + res := make([]string, 0, len(headers)) for _, header := range headers { res = append(res, string(header)) } diff --git a/pkg/protocol/http1/client_unix_test.go b/pkg/protocol/http1/client_unix_test.go index 1a162cad9..1ca24ec82 100644 --- a/pkg/protocol/http1/client_unix_test.go +++ b/pkg/protocol/http1/client_unix_test.go @@ -19,7 +19,6 @@ package http1 import ( "context" "errors" - "net" "net/http" "runtime" "sync" @@ -27,6 +26,7 @@ import ( "testing" "time" + "github.com/cloudwego/hertz/internal/testutils" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/network/netpoll" @@ -35,8 +35,7 @@ import ( ) func TestGcBodyStream(t *testing.T) { - ln, err := net.Listen("tcp", "127.0.0.1:0") - assert.Nil(t, err) + ln := testutils.NewTestListener(t) defer ln.Close() srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { for range [1024]int{} { @@ -72,8 +71,7 @@ func TestGcBodyStream(t *testing.T) { } func TestMaxConn(t *testing.T) { - ln, err := net.Listen("tcp", "127.0.0.1:0") - assert.Nil(t, err) + ln := testutils.NewTestListener(t) defer ln.Close() srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello world\n")) diff --git a/pkg/protocol/http1/resp/writer.go b/pkg/protocol/http1/resp/writer.go index 1150826ac..1bbe498ec 100644 --- a/pkg/protocol/http1/resp/writer.go +++ b/pkg/protocol/http1/resp/writer.go @@ -55,7 +55,7 @@ func (c *chunkedBodyWriter) Write(p []byte) (n int, err error) { if c.err != nil { return 0, c.err } - if err := c.writeHeader(); err != nil { + if err := c.WriteHeader(); err != nil { return 0, err } if len(p) == 0 { @@ -70,9 +70,10 @@ func (c *chunkedBodyWriter) Write(p []byte) (n int, err error) { return len(p), nil } -func (c *chunkedBodyWriter) writeHeader() error { +// WriteHeader writes the response header for chunked encoding +func (c *chunkedBodyWriter) WriteHeader() error { if c.wroteHeader { - return nil + return c.err } c.wroteHeader = true c.r.Header.SetContentLength(-1) @@ -100,7 +101,7 @@ func (c *chunkedBodyWriter) Finalize() error { return c.err } c.finalized = true - if err := c.writeHeader(); err != nil { + if err := c.WriteHeader(); err != nil { return err } // zero-len chunk diff --git a/pkg/protocol/sse/example_test.go b/pkg/protocol/sse/example_test.go index 6b1a143b2..e6919059a 100644 --- a/pkg/protocol/sse/example_test.go +++ b/pkg/protocol/sse/example_test.go @@ -19,9 +19,9 @@ package sse import ( "context" "fmt" + "net" "time" - "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app/client" "github.com/cloudwego/hertz/pkg/common/config" @@ -32,8 +32,11 @@ import ( // Example demonstrates a simple SSE server and client interaction. func Example() { // --- SSE Server --- + ln, _ := net.Listen("tcp", "127.0.0.1:0") + defer ln.Close() + opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:0" + opt.Listener = ln engine := route.NewEngine(opt) engine.GET("/", func(ctx context.Context, c *app.RequestContext) { println("Server Got LastEventID", GetLastEventID(&c.Request)) @@ -49,7 +52,7 @@ func Example() { go engine.Run() defer engine.Close() time.Sleep(20 * time.Millisecond) // wait for server to start - opt.Addr = testutils.GetListenerAddr(engine) + opt.Addr = ln.Addr().String() // --- SSE Client --- c, _ := client.NewClient() diff --git a/pkg/route/engine.go b/pkg/route/engine.go index 7dfe043fa..5f2ed4bf2 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -508,9 +508,9 @@ func (engine *Engine) Serve(c context.Context, conn network.Conn) (err error) { if err != nil { logError(conn, err) } - if !errors.Is(err, errs.ErrHijacked) { - _ = conn.Close() - } + // always close conn before Serve returns, + // some implementations (e.g., netpoll) may reuse conn if not closed + _ = conn.Close() }() // H2C path