From 788a736e067b0bc1cd9da5d8d3dced98193c435b Mon Sep 17 00:00:00 2001 From: Darko Luketic Date: Tue, 25 Nov 2025 15:59:31 +0100 Subject: [PATCH] Fix pooled params correctness and add context copy benchmark --- performance_test.go | 26 +++++- sux.go | 208 ++++++++++++++++++++++++++------------------ 2 files changed, 145 insertions(+), 89 deletions(-) diff --git a/performance_test.go b/performance_test.go index e80f74e..4bed62a 100644 --- a/performance_test.go +++ b/performance_test.go @@ -29,7 +29,7 @@ func BenchmarkParameterOperations(b *testing.B) { // BenchmarkPathParsing compares old vs new path parsing func BenchmarkPathParsingOld(b *testing.B) { path := "/users/123/posts/456/comments/789" - + b.ResetTimer() b.ReportAllocs() @@ -40,7 +40,7 @@ func BenchmarkPathParsingOld(b *testing.B) { func BenchmarkPathParsingNew(b *testing.B) { path := "/users/123/posts/456/comments/789" - + b.ResetTimer() b.ReportAllocs() @@ -197,4 +197,24 @@ func BenchmarkPoolEfficiency(b *testing.B) { w := httptest.NewRecorder() router.ServeHTTP(w, req) } -} \ No newline at end of file +} + +// BenchmarkContextParamCopy measures overhead of copying params into context +func BenchmarkContextParamCopy(b *testing.B) { + router := New() + router.GET("/users/:id/posts/:postId", func(w http.ResponseWriter, r *http.Request) { + params := ParamsFromContext(r) + _ = params.Get("id") + params.Get("postId") + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("GET", "/users/123/posts/456", nil) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + } +} diff --git a/sux.go b/sux.go index 810b8a3..47033d1 100644 --- a/sux.go +++ b/sux.go @@ -3,7 +3,6 @@ package sux import ( "context" "net/http" - "strings" "sync" ) @@ -44,8 +43,8 @@ type ( methodNotAllowedHandler http.HandlerFunc paramsPool *sync.Pool // Use pointer to avoid copying segmentsPool *sync.Pool // Pool for path segments - builderPool *sync.Pool // Pool for string builders prefix string // For route groups + checkMethodNotAllowed bool // Whether to probe other methods on 404 } ) @@ -56,8 +55,13 @@ const paramsKey contextKey = "params" // ParamsFromContext extracts route parameters from the request context func ParamsFromContext(r *http.Request) Params { - if params, ok := r.Context().Value(paramsKey).(Params); ok { + switch params := r.Context().Value(paramsKey).(type) { + case Params: return params + case *Params: + if params != nil { + return *params + } } return Params{} } @@ -148,7 +152,7 @@ func parsePath(path string) []string { } segments := make([]string, 0, 4) // Pre-allocate for common cases - start := 1 // Skip the leading slash + start := 1 // Skip the leading slash for i := 1; i < len(path); i++ { if path[i] == '/' { @@ -166,6 +170,31 @@ func parsePath(path string) []string { return segments } +// parsePathInto splits a path into segments using the provided buffer +func parsePathInto(path string, buf []string) []string { + if path == "" || path == "/" { + return buf[:0] + } + + buf = buf[:0] + start := 1 // Skip the leading slash + + for i := 1; i < len(path); i++ { + if path[i] == '/' { + if start < i { + buf = append(buf, path[start:i]) + } + start = i + 1 + } + } + + if start < len(path) { + buf = append(buf, path[start:]) + } + + return buf +} + // parsePathZeroAlloc splits a path into segments with minimal allocations // Uses string slicing to avoid allocations where possible func parsePathZeroAlloc(path string) []string { @@ -249,7 +278,6 @@ func (n *node) addRouteWithMiddleware(segments []string, handler http.HandlerFun } if isLast { - current.nodeType = static current.handler = handler current.middleware = middleware } @@ -257,11 +285,14 @@ func (n *node) addRouteWithMiddleware(segments []string, handler http.HandlerFun } // find searches for a route matching the given path -func (n *node) find(segments []string, params *Params) *node { +func (n *node) find(path string, segments []string, params *Params) *node { current := n + pos := 1 // Track offset into the original path for wildcard slicing for i, segment := range segments { isLast := i == len(segments)-1 + segmentStart := pos + pos += len(segment) + 1 // advance past this segment and trailing slash // Try static match first (fastest path) if child := current.findChild(segment); child != nil { @@ -274,7 +305,7 @@ func (n *node) find(segments []string, params *Params) *node { // Try parameter match if current.paramChild != nil { - if params != nil && params.keys != nil { + if params != nil { params.Set(current.paramChild.paramName, segment) } if isLast && current.paramChild.nodeType == param { @@ -286,30 +317,9 @@ func (n *node) find(segments []string, params *Params) *node { // Try wildcard match (catches everything) if current.wildcardChild != nil { - if params != nil && params.keys != nil { - // Wildcard captures the rest of the path - // Use pooled string builder for efficient concatenation - // Note: We can't use the pool here since we don't have access to router - // But we can optimize the string concatenation - var wildcardValue string - if i < len(segments)-1 { - // Pre-calculate capacity to avoid reallocations - capacity := len(segment) - for j := i + 1; j < len(segments); j++ { - capacity += 1 + len(segments[j]) // 1 for '/' - } - builder := strings.Builder{} - builder.Grow(capacity) - builder.WriteString(segment) - for j := i + 1; j < len(segments); j++ { - builder.WriteByte('/') - builder.WriteString(segments[j]) - } - wildcardValue = builder.String() - } else { - wildcardValue = segment - } - params.Set(current.wildcardChild.paramName, wildcardValue) + if params != nil { + // Wildcard captures the rest of the path directly from the original string + params.Set(current.wildcardChild.paramName, path[segmentStart:]) } return current.wildcardChild } @@ -326,63 +336,88 @@ func (n *node) find(segments []string, params *Params) *node { // ServeHTTP implements the http.Handler interface func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if tree, exists := r.trees[req.Method]; exists { - segments := parsePathZeroAlloc(req.URL.Path) + segments := r.segmentsPool.Get().([]string) + segments = segments[:0] + segments = parsePathInto(req.URL.Path, segments) + defer func() { + // Clear to allow GC of prior path strings before pooling + for i := range segments { + segments[i] = "" + } + r.segmentsPool.Put(segments[:0]) + }() - // Get params from pool for better performance - params := r.paramsPool.Get().(Params) + tree, exists := r.trees[req.Method] + if !exists { + if r.checkMethodNotAllowed && r.hasOtherMethodMatch(req.URL.Path, segments, req.Method) { + r.handleMethodNotAllowed(w, req) + } else { + r.handleNotFound(w, req) + } + return + } + + params := r.paramsPool.Get().(*Params) + params.Reset() + + node := tree.find(req.URL.Path, segments, params) + if node == nil || node.handler == nil { params.Reset() + r.paramsPool.Put(params) - node := tree.find(segments, ¶ms) - - if node == nil || node.handler == nil { + if r.checkMethodNotAllowed { // Check if the path exists for other methods to return 405 instead of 404 - hasOtherMethod := false - for method, otherTree := range r.trees { - if method != req.Method { - if otherNode := otherTree.find(segments, &Params{}); otherNode != nil { - hasOtherMethod = true - break - } - } - } - - // Return params to pool - r.paramsPool.Put(params) - - if hasOtherMethod { + if r.hasOtherMethodMatch(req.URL.Path, segments, req.Method) { r.handleMethodNotAllowed(w, req) } else { r.handleNotFound(w, req) } - return + } else { + r.handleNotFound(w, req) } - - // Add params to request context - ctx := context.WithValue(req.Context(), paramsKey, params) - req = req.WithContext(ctx) - - // Create handler chain with middleware - handler := http.Handler(http.HandlerFunc(node.handler)) - - // Apply node-specific middleware first, then router middleware - for i := len(node.middleware) - 1; i >= 0; i-- { - handler = node.middleware[i](handler) - } - for i := len(r.middleware) - 1; i >= 0; i-- { - handler = r.middleware[i](handler) - } - - // Defer returning params to pool after handler completes - defer func() { - r.paramsPool.Put(params) - }() - - handler.ServeHTTP(w, req) return } - r.handleMethodNotAllowed(w, req) + // Add params to request context as a copy to keep values stable after pooling + var ctxParams Params + if params.Len() > 0 { + ctxParams.keys = append(ctxParams.keys, params.keys...) + ctxParams.values = append(ctxParams.values, params.values...) + } + ctx := context.WithValue(req.Context(), paramsKey, ctxParams) + req = req.WithContext(ctx) + + // Create handler chain with middleware + handler := http.Handler(http.HandlerFunc(node.handler)) + + // Apply node-specific middleware first, then router middleware + for i := len(node.middleware) - 1; i >= 0; i-- { + handler = node.middleware[i](handler) + } + for i := len(r.middleware) - 1; i >= 0; i-- { + handler = r.middleware[i](handler) + } + + // Defer returning params to pool after handler completes + defer func() { + params.Reset() + r.paramsPool.Put(params) + }() + + handler.ServeHTTP(w, req) +} + +// hasOtherMethodMatch checks whether the same path exists on another method tree. +func (r *Router) hasOtherMethodMatch(path string, segments []string, currentMethod string) bool { + for method, otherTree := range r.trees { + if method == currentMethod { + continue + } + if otherNode := otherTree.find(path, segments, nil); otherNode != nil && otherNode.handler != nil { + return true + } + } + return false } // handleNotFound handles 404 responses @@ -421,10 +456,10 @@ func (r *Router) Group(prefix string, middleware ...MiddlewareFunc) *Router { middleware: append(r.middleware, middleware...), notFoundHandler: r.notFoundHandler, methodNotAllowedHandler: r.methodNotAllowedHandler, - paramsPool: r.paramsPool, // Already a pointer, just copy the reference + paramsPool: r.paramsPool, // Already a pointer, just copy the reference segmentsPool: r.segmentsPool, // Share the segments pool - builderPool: r.builderPool, // Share the builder pool prefix: r.prefix + prefix, + checkMethodNotAllowed: r.checkMethodNotAllowed, } } @@ -438,6 +473,11 @@ func (r *Router) MethodNotAllowed(handler http.HandlerFunc) { r.methodNotAllowedHandler = handler } +// EnableMethodNotAllowedCheck toggles costly cross-method lookup on 404 +func (r *Router) EnableMethodNotAllowedCheck(enabled bool) { + r.checkMethodNotAllowed = enabled +} + // addRoute is a helper method to add routes for different HTTP methods func (r *Router) addRoute(method, path string, handler http.HandlerFunc) *Router { if r.trees == nil { @@ -502,8 +542,8 @@ func New() *Router { middleware: make([]MiddlewareFunc, 0), paramsPool: &sync.Pool{ New: func() interface{} { - return Params{ - keys: make([]string, 0, 4), // Pre-allocate for common cases + return &Params{ + keys: make([]string, 0, 4), // Pre-allocate for common cases values: make([]string, 0, 4), } }, @@ -513,12 +553,8 @@ func New() *Router { return make([]string, 0, 4) // Pre-allocate for common cases }, }, - builderPool: &sync.Pool{ - New: func() interface{} { - return &strings.Builder{} - }, - }, - prefix: "", + prefix: "", + checkMethodNotAllowed: true, } // Initialize trees for common HTTP methods