Fix pooled params correctness and add context copy benchmark

This commit is contained in:
2025-11-25 15:59:31 +01:00
parent 3b94964068
commit 788a736e06
2 changed files with 145 additions and 89 deletions

View File

@@ -198,3 +198,23 @@ func BenchmarkPoolEfficiency(b *testing.B) {
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
} }
} }
// BenchmarkContextParamCopy measures overhead of copying params into context
func BenchmarkContextParamCopy(b *testing.B) {
router := New()
router.GET("/users/:id/posts/:postId", func(w http.ResponseWriter, r *http.Request) {
params := ParamsFromContext(r)
_ = params.Get("id") + params.Get("postId")
w.WriteHeader(http.StatusOK)
})
req := httptest.NewRequest("GET", "/users/123/posts/456", nil)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
}
}

166
sux.go
View File

@@ -3,7 +3,6 @@ package sux
import ( import (
"context" "context"
"net/http" "net/http"
"strings"
"sync" "sync"
) )
@@ -44,8 +43,8 @@ type (
methodNotAllowedHandler http.HandlerFunc methodNotAllowedHandler http.HandlerFunc
paramsPool *sync.Pool // Use pointer to avoid copying paramsPool *sync.Pool // Use pointer to avoid copying
segmentsPool *sync.Pool // Pool for path segments segmentsPool *sync.Pool // Pool for path segments
builderPool *sync.Pool // Pool for string builders
prefix string // For route groups prefix string // For route groups
checkMethodNotAllowed bool // Whether to probe other methods on 404
} }
) )
@@ -56,8 +55,13 @@ const paramsKey contextKey = "params"
// ParamsFromContext extracts route parameters from the request context // ParamsFromContext extracts route parameters from the request context
func ParamsFromContext(r *http.Request) Params { func ParamsFromContext(r *http.Request) Params {
if params, ok := r.Context().Value(paramsKey).(Params); ok { switch params := r.Context().Value(paramsKey).(type) {
case Params:
return params return params
case *Params:
if params != nil {
return *params
}
} }
return Params{} return Params{}
} }
@@ -166,6 +170,31 @@ func parsePath(path string) []string {
return segments return segments
} }
// parsePathInto splits a path into segments using the provided buffer
func parsePathInto(path string, buf []string) []string {
if path == "" || path == "/" {
return buf[:0]
}
buf = buf[:0]
start := 1 // Skip the leading slash
for i := 1; i < len(path); i++ {
if path[i] == '/' {
if start < i {
buf = append(buf, path[start:i])
}
start = i + 1
}
}
if start < len(path) {
buf = append(buf, path[start:])
}
return buf
}
// parsePathZeroAlloc splits a path into segments with minimal allocations // parsePathZeroAlloc splits a path into segments with minimal allocations
// Uses string slicing to avoid allocations where possible // Uses string slicing to avoid allocations where possible
func parsePathZeroAlloc(path string) []string { func parsePathZeroAlloc(path string) []string {
@@ -249,7 +278,6 @@ func (n *node) addRouteWithMiddleware(segments []string, handler http.HandlerFun
} }
if isLast { if isLast {
current.nodeType = static
current.handler = handler current.handler = handler
current.middleware = middleware current.middleware = middleware
} }
@@ -257,11 +285,14 @@ func (n *node) addRouteWithMiddleware(segments []string, handler http.HandlerFun
} }
// find searches for a route matching the given path // find searches for a route matching the given path
func (n *node) find(segments []string, params *Params) *node { func (n *node) find(path string, segments []string, params *Params) *node {
current := n current := n
pos := 1 // Track offset into the original path for wildcard slicing
for i, segment := range segments { for i, segment := range segments {
isLast := i == len(segments)-1 isLast := i == len(segments)-1
segmentStart := pos
pos += len(segment) + 1 // advance past this segment and trailing slash
// Try static match first (fastest path) // Try static match first (fastest path)
if child := current.findChild(segment); child != nil { if child := current.findChild(segment); child != nil {
@@ -274,7 +305,7 @@ func (n *node) find(segments []string, params *Params) *node {
// Try parameter match // Try parameter match
if current.paramChild != nil { if current.paramChild != nil {
if params != nil && params.keys != nil { if params != nil {
params.Set(current.paramChild.paramName, segment) params.Set(current.paramChild.paramName, segment)
} }
if isLast && current.paramChild.nodeType == param { if isLast && current.paramChild.nodeType == param {
@@ -286,30 +317,9 @@ func (n *node) find(segments []string, params *Params) *node {
// Try wildcard match (catches everything) // Try wildcard match (catches everything)
if current.wildcardChild != nil { if current.wildcardChild != nil {
if params != nil && params.keys != nil { if params != nil {
// Wildcard captures the rest of the path // Wildcard captures the rest of the path directly from the original string
// Use pooled string builder for efficient concatenation params.Set(current.wildcardChild.paramName, path[segmentStart:])
// Note: We can't use the pool here since we don't have access to router
// But we can optimize the string concatenation
var wildcardValue string
if i < len(segments)-1 {
// Pre-calculate capacity to avoid reallocations
capacity := len(segment)
for j := i + 1; j < len(segments); j++ {
capacity += 1 + len(segments[j]) // 1 for '/'
}
builder := strings.Builder{}
builder.Grow(capacity)
builder.WriteString(segment)
for j := i + 1; j < len(segments); j++ {
builder.WriteByte('/')
builder.WriteString(segments[j])
}
wildcardValue = builder.String()
} else {
wildcardValue = segment
}
params.Set(current.wildcardChild.paramName, wildcardValue)
} }
return current.wildcardChild return current.wildcardChild
} }
@@ -326,31 +336,20 @@ func (n *node) find(segments []string, params *Params) *node {
// ServeHTTP implements the http.Handler interface // ServeHTTP implements the http.Handler interface
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if tree, exists := r.trees[req.Method]; exists { segments := r.segmentsPool.Get().([]string)
segments := parsePathZeroAlloc(req.URL.Path) segments = segments[:0]
segments = parsePathInto(req.URL.Path, segments)
// Get params from pool for better performance defer func() {
params := r.paramsPool.Get().(Params) // Clear to allow GC of prior path strings before pooling
params.Reset() for i := range segments {
segments[i] = ""
node := tree.find(segments, &params)
if node == nil || node.handler == nil {
// Check if the path exists for other methods to return 405 instead of 404
hasOtherMethod := false
for method, otherTree := range r.trees {
if method != req.Method {
if otherNode := otherTree.find(segments, &Params{}); otherNode != nil {
hasOtherMethod = true
break
}
}
} }
r.segmentsPool.Put(segments[:0])
}()
// Return params to pool tree, exists := r.trees[req.Method]
r.paramsPool.Put(params) if !exists {
if r.checkMethodNotAllowed && r.hasOtherMethodMatch(req.URL.Path, segments, req.Method) {
if hasOtherMethod {
r.handleMethodNotAllowed(w, req) r.handleMethodNotAllowed(w, req)
} else { } else {
r.handleNotFound(w, req) r.handleNotFound(w, req)
@@ -358,8 +357,34 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return return
} }
// Add params to request context params := r.paramsPool.Get().(*Params)
ctx := context.WithValue(req.Context(), paramsKey, params) params.Reset()
node := tree.find(req.URL.Path, segments, params)
if node == nil || node.handler == nil {
params.Reset()
r.paramsPool.Put(params)
if r.checkMethodNotAllowed {
// Check if the path exists for other methods to return 405 instead of 404
if r.hasOtherMethodMatch(req.URL.Path, segments, req.Method) {
r.handleMethodNotAllowed(w, req)
} else {
r.handleNotFound(w, req)
}
} else {
r.handleNotFound(w, req)
}
return
}
// Add params to request context as a copy to keep values stable after pooling
var ctxParams Params
if params.Len() > 0 {
ctxParams.keys = append(ctxParams.keys, params.keys...)
ctxParams.values = append(ctxParams.values, params.values...)
}
ctx := context.WithValue(req.Context(), paramsKey, ctxParams)
req = req.WithContext(ctx) req = req.WithContext(ctx)
// Create handler chain with middleware // Create handler chain with middleware
@@ -375,14 +400,24 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Defer returning params to pool after handler completes // Defer returning params to pool after handler completes
defer func() { defer func() {
params.Reset()
r.paramsPool.Put(params) r.paramsPool.Put(params)
}() }()
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
return }
}
r.handleMethodNotAllowed(w, req) // hasOtherMethodMatch checks whether the same path exists on another method tree.
func (r *Router) hasOtherMethodMatch(path string, segments []string, currentMethod string) bool {
for method, otherTree := range r.trees {
if method == currentMethod {
continue
}
if otherNode := otherTree.find(path, segments, nil); otherNode != nil && otherNode.handler != nil {
return true
}
}
return false
} }
// handleNotFound handles 404 responses // handleNotFound handles 404 responses
@@ -423,8 +458,8 @@ func (r *Router) Group(prefix string, middleware ...MiddlewareFunc) *Router {
methodNotAllowedHandler: r.methodNotAllowedHandler, methodNotAllowedHandler: r.methodNotAllowedHandler,
paramsPool: r.paramsPool, // Already a pointer, just copy the reference paramsPool: r.paramsPool, // Already a pointer, just copy the reference
segmentsPool: r.segmentsPool, // Share the segments pool segmentsPool: r.segmentsPool, // Share the segments pool
builderPool: r.builderPool, // Share the builder pool
prefix: r.prefix + prefix, prefix: r.prefix + prefix,
checkMethodNotAllowed: r.checkMethodNotAllowed,
} }
} }
@@ -438,6 +473,11 @@ func (r *Router) MethodNotAllowed(handler http.HandlerFunc) {
r.methodNotAllowedHandler = handler r.methodNotAllowedHandler = handler
} }
// EnableMethodNotAllowedCheck toggles costly cross-method lookup on 404
func (r *Router) EnableMethodNotAllowedCheck(enabled bool) {
r.checkMethodNotAllowed = enabled
}
// addRoute is a helper method to add routes for different HTTP methods // addRoute is a helper method to add routes for different HTTP methods
func (r *Router) addRoute(method, path string, handler http.HandlerFunc) *Router { func (r *Router) addRoute(method, path string, handler http.HandlerFunc) *Router {
if r.trees == nil { if r.trees == nil {
@@ -502,7 +542,7 @@ func New() *Router {
middleware: make([]MiddlewareFunc, 0), middleware: make([]MiddlewareFunc, 0),
paramsPool: &sync.Pool{ paramsPool: &sync.Pool{
New: func() interface{} { New: func() interface{} {
return Params{ return &Params{
keys: make([]string, 0, 4), // Pre-allocate for common cases keys: make([]string, 0, 4), // Pre-allocate for common cases
values: make([]string, 0, 4), values: make([]string, 0, 4),
} }
@@ -513,12 +553,8 @@ func New() *Router {
return make([]string, 0, 4) // Pre-allocate for common cases return make([]string, 0, 4) // Pre-allocate for common cases
}, },
}, },
builderPool: &sync.Pool{
New: func() interface{} {
return &strings.Builder{}
},
},
prefix: "", prefix: "",
checkMethodNotAllowed: true,
} }
// Initialize trees for common HTTP methods // Initialize trees for common HTTP methods