Fix pooled params correctness and add context copy benchmark
This commit is contained in:
@@ -29,7 +29,7 @@ func BenchmarkParameterOperations(b *testing.B) {
|
||||
// BenchmarkPathParsing compares old vs new path parsing
|
||||
func BenchmarkPathParsingOld(b *testing.B) {
|
||||
path := "/users/123/posts/456/comments/789"
|
||||
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
@@ -40,7 +40,7 @@ func BenchmarkPathParsingOld(b *testing.B) {
|
||||
|
||||
func BenchmarkPathParsingNew(b *testing.B) {
|
||||
path := "/users/123/posts/456/comments/789"
|
||||
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
@@ -197,4 +197,24 @@ func BenchmarkPoolEfficiency(b *testing.B) {
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkContextParamCopy measures overhead of copying params into context
|
||||
func BenchmarkContextParamCopy(b *testing.B) {
|
||||
router := New()
|
||||
router.GET("/users/:id/posts/:postId", func(w http.ResponseWriter, r *http.Request) {
|
||||
params := ParamsFromContext(r)
|
||||
_ = params.Get("id") + params.Get("postId")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/users/123/posts/456", nil)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
208
sux.go
208
sux.go
@@ -3,7 +3,6 @@ package sux
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@@ -44,8 +43,8 @@ type (
|
||||
methodNotAllowedHandler http.HandlerFunc
|
||||
paramsPool *sync.Pool // Use pointer to avoid copying
|
||||
segmentsPool *sync.Pool // Pool for path segments
|
||||
builderPool *sync.Pool // Pool for string builders
|
||||
prefix string // For route groups
|
||||
checkMethodNotAllowed bool // Whether to probe other methods on 404
|
||||
}
|
||||
)
|
||||
|
||||
@@ -56,8 +55,13 @@ const paramsKey contextKey = "params"
|
||||
|
||||
// ParamsFromContext extracts route parameters from the request context
|
||||
func ParamsFromContext(r *http.Request) Params {
|
||||
if params, ok := r.Context().Value(paramsKey).(Params); ok {
|
||||
switch params := r.Context().Value(paramsKey).(type) {
|
||||
case Params:
|
||||
return params
|
||||
case *Params:
|
||||
if params != nil {
|
||||
return *params
|
||||
}
|
||||
}
|
||||
return Params{}
|
||||
}
|
||||
@@ -148,7 +152,7 @@ func parsePath(path string) []string {
|
||||
}
|
||||
|
||||
segments := make([]string, 0, 4) // Pre-allocate for common cases
|
||||
start := 1 // Skip the leading slash
|
||||
start := 1 // Skip the leading slash
|
||||
|
||||
for i := 1; i < len(path); i++ {
|
||||
if path[i] == '/' {
|
||||
@@ -166,6 +170,31 @@ func parsePath(path string) []string {
|
||||
return segments
|
||||
}
|
||||
|
||||
// parsePathInto splits a path into segments using the provided buffer
|
||||
func parsePathInto(path string, buf []string) []string {
|
||||
if path == "" || path == "/" {
|
||||
return buf[:0]
|
||||
}
|
||||
|
||||
buf = buf[:0]
|
||||
start := 1 // Skip the leading slash
|
||||
|
||||
for i := 1; i < len(path); i++ {
|
||||
if path[i] == '/' {
|
||||
if start < i {
|
||||
buf = append(buf, path[start:i])
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
|
||||
if start < len(path) {
|
||||
buf = append(buf, path[start:])
|
||||
}
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// parsePathZeroAlloc splits a path into segments with minimal allocations
|
||||
// Uses string slicing to avoid allocations where possible
|
||||
func parsePathZeroAlloc(path string) []string {
|
||||
@@ -249,7 +278,6 @@ func (n *node) addRouteWithMiddleware(segments []string, handler http.HandlerFun
|
||||
}
|
||||
|
||||
if isLast {
|
||||
current.nodeType = static
|
||||
current.handler = handler
|
||||
current.middleware = middleware
|
||||
}
|
||||
@@ -257,11 +285,14 @@ func (n *node) addRouteWithMiddleware(segments []string, handler http.HandlerFun
|
||||
}
|
||||
|
||||
// find searches for a route matching the given path
|
||||
func (n *node) find(segments []string, params *Params) *node {
|
||||
func (n *node) find(path string, segments []string, params *Params) *node {
|
||||
current := n
|
||||
pos := 1 // Track offset into the original path for wildcard slicing
|
||||
|
||||
for i, segment := range segments {
|
||||
isLast := i == len(segments)-1
|
||||
segmentStart := pos
|
||||
pos += len(segment) + 1 // advance past this segment and trailing slash
|
||||
|
||||
// Try static match first (fastest path)
|
||||
if child := current.findChild(segment); child != nil {
|
||||
@@ -274,7 +305,7 @@ func (n *node) find(segments []string, params *Params) *node {
|
||||
|
||||
// Try parameter match
|
||||
if current.paramChild != nil {
|
||||
if params != nil && params.keys != nil {
|
||||
if params != nil {
|
||||
params.Set(current.paramChild.paramName, segment)
|
||||
}
|
||||
if isLast && current.paramChild.nodeType == param {
|
||||
@@ -286,30 +317,9 @@ func (n *node) find(segments []string, params *Params) *node {
|
||||
|
||||
// Try wildcard match (catches everything)
|
||||
if current.wildcardChild != nil {
|
||||
if params != nil && params.keys != nil {
|
||||
// Wildcard captures the rest of the path
|
||||
// Use pooled string builder for efficient concatenation
|
||||
// Note: We can't use the pool here since we don't have access to router
|
||||
// But we can optimize the string concatenation
|
||||
var wildcardValue string
|
||||
if i < len(segments)-1 {
|
||||
// Pre-calculate capacity to avoid reallocations
|
||||
capacity := len(segment)
|
||||
for j := i + 1; j < len(segments); j++ {
|
||||
capacity += 1 + len(segments[j]) // 1 for '/'
|
||||
}
|
||||
builder := strings.Builder{}
|
||||
builder.Grow(capacity)
|
||||
builder.WriteString(segment)
|
||||
for j := i + 1; j < len(segments); j++ {
|
||||
builder.WriteByte('/')
|
||||
builder.WriteString(segments[j])
|
||||
}
|
||||
wildcardValue = builder.String()
|
||||
} else {
|
||||
wildcardValue = segment
|
||||
}
|
||||
params.Set(current.wildcardChild.paramName, wildcardValue)
|
||||
if params != nil {
|
||||
// Wildcard captures the rest of the path directly from the original string
|
||||
params.Set(current.wildcardChild.paramName, path[segmentStart:])
|
||||
}
|
||||
return current.wildcardChild
|
||||
}
|
||||
@@ -326,63 +336,88 @@ func (n *node) find(segments []string, params *Params) *node {
|
||||
|
||||
// ServeHTTP implements the http.Handler interface
|
||||
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if tree, exists := r.trees[req.Method]; exists {
|
||||
segments := parsePathZeroAlloc(req.URL.Path)
|
||||
segments := r.segmentsPool.Get().([]string)
|
||||
segments = segments[:0]
|
||||
segments = parsePathInto(req.URL.Path, segments)
|
||||
defer func() {
|
||||
// Clear to allow GC of prior path strings before pooling
|
||||
for i := range segments {
|
||||
segments[i] = ""
|
||||
}
|
||||
r.segmentsPool.Put(segments[:0])
|
||||
}()
|
||||
|
||||
// Get params from pool for better performance
|
||||
params := r.paramsPool.Get().(Params)
|
||||
tree, exists := r.trees[req.Method]
|
||||
if !exists {
|
||||
if r.checkMethodNotAllowed && r.hasOtherMethodMatch(req.URL.Path, segments, req.Method) {
|
||||
r.handleMethodNotAllowed(w, req)
|
||||
} else {
|
||||
r.handleNotFound(w, req)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
params := r.paramsPool.Get().(*Params)
|
||||
params.Reset()
|
||||
|
||||
node := tree.find(req.URL.Path, segments, params)
|
||||
if node == nil || node.handler == nil {
|
||||
params.Reset()
|
||||
r.paramsPool.Put(params)
|
||||
|
||||
node := tree.find(segments, ¶ms)
|
||||
|
||||
if node == nil || node.handler == nil {
|
||||
if r.checkMethodNotAllowed {
|
||||
// Check if the path exists for other methods to return 405 instead of 404
|
||||
hasOtherMethod := false
|
||||
for method, otherTree := range r.trees {
|
||||
if method != req.Method {
|
||||
if otherNode := otherTree.find(segments, &Params{}); otherNode != nil {
|
||||
hasOtherMethod = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return params to pool
|
||||
r.paramsPool.Put(params)
|
||||
|
||||
if hasOtherMethod {
|
||||
if r.hasOtherMethodMatch(req.URL.Path, segments, req.Method) {
|
||||
r.handleMethodNotAllowed(w, req)
|
||||
} else {
|
||||
r.handleNotFound(w, req)
|
||||
}
|
||||
return
|
||||
} else {
|
||||
r.handleNotFound(w, req)
|
||||
}
|
||||
|
||||
// Add params to request context
|
||||
ctx := context.WithValue(req.Context(), paramsKey, params)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
// Create handler chain with middleware
|
||||
handler := http.Handler(http.HandlerFunc(node.handler))
|
||||
|
||||
// Apply node-specific middleware first, then router middleware
|
||||
for i := len(node.middleware) - 1; i >= 0; i-- {
|
||||
handler = node.middleware[i](handler)
|
||||
}
|
||||
for i := len(r.middleware) - 1; i >= 0; i-- {
|
||||
handler = r.middleware[i](handler)
|
||||
}
|
||||
|
||||
// Defer returning params to pool after handler completes
|
||||
defer func() {
|
||||
r.paramsPool.Put(params)
|
||||
}()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
return
|
||||
}
|
||||
|
||||
r.handleMethodNotAllowed(w, req)
|
||||
// Add params to request context as a copy to keep values stable after pooling
|
||||
var ctxParams Params
|
||||
if params.Len() > 0 {
|
||||
ctxParams.keys = append(ctxParams.keys, params.keys...)
|
||||
ctxParams.values = append(ctxParams.values, params.values...)
|
||||
}
|
||||
ctx := context.WithValue(req.Context(), paramsKey, ctxParams)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
// Create handler chain with middleware
|
||||
handler := http.Handler(http.HandlerFunc(node.handler))
|
||||
|
||||
// Apply node-specific middleware first, then router middleware
|
||||
for i := len(node.middleware) - 1; i >= 0; i-- {
|
||||
handler = node.middleware[i](handler)
|
||||
}
|
||||
for i := len(r.middleware) - 1; i >= 0; i-- {
|
||||
handler = r.middleware[i](handler)
|
||||
}
|
||||
|
||||
// Defer returning params to pool after handler completes
|
||||
defer func() {
|
||||
params.Reset()
|
||||
r.paramsPool.Put(params)
|
||||
}()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// hasOtherMethodMatch checks whether the same path exists on another method tree.
|
||||
func (r *Router) hasOtherMethodMatch(path string, segments []string, currentMethod string) bool {
|
||||
for method, otherTree := range r.trees {
|
||||
if method == currentMethod {
|
||||
continue
|
||||
}
|
||||
if otherNode := otherTree.find(path, segments, nil); otherNode != nil && otherNode.handler != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// handleNotFound handles 404 responses
|
||||
@@ -421,10 +456,10 @@ func (r *Router) Group(prefix string, middleware ...MiddlewareFunc) *Router {
|
||||
middleware: append(r.middleware, middleware...),
|
||||
notFoundHandler: r.notFoundHandler,
|
||||
methodNotAllowedHandler: r.methodNotAllowedHandler,
|
||||
paramsPool: r.paramsPool, // Already a pointer, just copy the reference
|
||||
paramsPool: r.paramsPool, // Already a pointer, just copy the reference
|
||||
segmentsPool: r.segmentsPool, // Share the segments pool
|
||||
builderPool: r.builderPool, // Share the builder pool
|
||||
prefix: r.prefix + prefix,
|
||||
checkMethodNotAllowed: r.checkMethodNotAllowed,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -438,6 +473,11 @@ func (r *Router) MethodNotAllowed(handler http.HandlerFunc) {
|
||||
r.methodNotAllowedHandler = handler
|
||||
}
|
||||
|
||||
// EnableMethodNotAllowedCheck toggles costly cross-method lookup on 404
|
||||
func (r *Router) EnableMethodNotAllowedCheck(enabled bool) {
|
||||
r.checkMethodNotAllowed = enabled
|
||||
}
|
||||
|
||||
// addRoute is a helper method to add routes for different HTTP methods
|
||||
func (r *Router) addRoute(method, path string, handler http.HandlerFunc) *Router {
|
||||
if r.trees == nil {
|
||||
@@ -502,8 +542,8 @@ func New() *Router {
|
||||
middleware: make([]MiddlewareFunc, 0),
|
||||
paramsPool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return Params{
|
||||
keys: make([]string, 0, 4), // Pre-allocate for common cases
|
||||
return &Params{
|
||||
keys: make([]string, 0, 4), // Pre-allocate for common cases
|
||||
values: make([]string, 0, 4),
|
||||
}
|
||||
},
|
||||
@@ -513,12 +553,8 @@ func New() *Router {
|
||||
return make([]string, 0, 4) // Pre-allocate for common cases
|
||||
},
|
||||
},
|
||||
builderPool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &strings.Builder{}
|
||||
},
|
||||
},
|
||||
prefix: "",
|
||||
prefix: "",
|
||||
checkMethodNotAllowed: true,
|
||||
}
|
||||
|
||||
// Initialize trees for common HTTP methods
|
||||
|
||||
Reference in New Issue
Block a user