// Package sux implements an HTTP request router. Routes are stored in one
// tree per HTTP method and may contain static segments, named parameters
// (":name", matching exactly one segment) and catch-all wildcards ("*name",
// matching the remainder of the path).
package sux

import (
	"context"
	"net/http"
	"sort"
	"strings"
	"sync"
)

const (
	static nodeType = iota
	param
	wildcard
)

type (
	nodeType uint8

	// Params holds route parameters extracted from the URL.
	//
	// Keys and values live in parallel slices rather than a map[string]string
	// to reduce allocations for the small parameter counts typical of routes.
	Params struct {
		keys   []string
		values []string
	}

	// MiddlewareFunc wraps an http.Handler with additional behavior.
	MiddlewareFunc func(http.Handler) http.Handler

	// node is one segment of a per-method routing tree.
	node struct {
		path          string
		nodeType      nodeType
		handler       http.HandlerFunc
		children      map[string]*node // static children keyed by exact segment
		paramChild    *node            // child for ":name" segments
		wildcardChild *node            // child for "*name" segments
		paramName     string           // parameter name of a param/wildcard node
		middleware    []MiddlewareFunc // route-scoped middleware captured at registration
	}

	// Router dispatches requests to handlers registered per HTTP method.
	//
	// Create routers with New. Route registration mutates shared maps without
	// locking, so register all routes before the router starts serving.
	Router struct {
		trees                   map[string]*node // one tree per HTTP method, shared with groups
		middleware              []MiddlewareFunc // applied by ServeHTTP around every matched route
		routeMiddleware         []MiddlewareFunc // attached to routes registered through this group
		notFoundHandler         http.HandlerFunc
		methodNotAllowedHandler http.HandlerFunc
		paramsPool              *sync.Pool // pool of *Params scratch buffers
		segmentsPool            *sync.Pool // pool of *[]string path-segment buffers
		prefix                  string     // path prefix for route groups
		checkMethodNotAllowed   bool       // whether to probe other methods on 404
		isGroup                 bool       // set for routers created by Group
	}
)

// contextKey is a package-private type for context keys, so they cannot
// collide with keys defined by other packages.
type contextKey string

const paramsKey contextKey = "params"

// ParamsFromContext returns the route parameters that ServeHTTP stored in the
// request context. It returns an empty Params when none are present.
func ParamsFromContext(r *http.Request) Params {
	switch params := r.Context().Value(paramsKey).(type) {
	case Params:
		return params
	case *Params:
		if params != nil {
			return *params
		}
	}
	return Params{}
}

// Get returns the value stored for key, or "" if key is absent.
func (p Params) Get(key string) string {
	for i, k := range p.keys {
		if k == key {
			return p.values[i]
		}
	}
	return ""
}

// Set stores value under key, overwriting any existing value for key.
func (p *Params) Set(key, value string) {
	p.push(key, value)
}

// Reset clears all parameters while keeping the backing storage for reuse.
func (p *Params) Reset() {
	p.keys = p.keys[:0]
	p.values = p.values[:0]
}

// Len returns the number of parameters.
func (p Params) Len() int {
	return len(p.keys)
}

// push behaves like Set and reports what it changed, so pop can undo it.
// idx is the index of an overwritten entry (prev holds its old value), or -1
// when a new entry was appended.
func (p *Params) push(key, value string) (idx int, prev string) {
	for i, k := range p.keys {
		if k == key {
			prev = p.values[i]
			p.values[i] = value
			return i, prev
		}
	}
	p.keys = append(p.keys, key)
	p.values = append(p.values, value)
	return -1, ""
}

// pop reverts the most recent push using the values it returned.
func (p *Params) pop(idx int, prev string) {
	if idx >= 0 {
		p.values[idx] = prev
		return
	}
	p.keys = p.keys[:len(p.keys)-1]
	p.values = p.values[:len(p.values)-1]
}

// makeNode creates a new static node for the given path segment.
func makeNode(path string) *node {
	return &node{
		path:     path,
		children: make(map[string]*node),
	}
}

// addChild returns the static child for path, creating it if needed.
func (n *node) addChild(path string) *node {
	if child := n.findChild(path); child != nil {
		return child
	}
	child := makeNode(path)
	n.children[path] = child
	return child
}

// addParamChild returns the ":param" child, creating it if needed.
// A node has at most one param child; if one already exists, its original
// parameter name is kept and paramName is ignored.
func (n *node) addParamChild(paramName string) *node {
	if n.paramChild == nil {
		n.paramChild = makeNode(":" + paramName)
		n.paramChild.nodeType = param
		n.paramChild.paramName = paramName
	}
	return n.paramChild
}

// addWildcardChild returns the "*param" child, creating it if needed.
// As with addParamChild, the first registered name wins.
func (n *node) addWildcardChild(paramName string) *node {
	if n.wildcardChild == nil {
		n.wildcardChild = makeNode("*" + paramName)
		n.wildcardChild.nodeType = wildcard
		n.wildcardChild.paramName = paramName
	}
	return n.wildcardChild
}

// findChild returns the static child for path, or nil.
func (n *node) findChild(path string) *node {
	return n.children[path]
}

// parsePath splits path into its non-empty slash-separated segments.
// It returns an empty, non-nil slice for "" and "/".
func parsePath(path string) []string {
	return parsePathInto(path, make([]string, 0, 4))
}

// parsePathInto splits path into its non-empty slash-separated segments,
// reusing buf's storage. Leading, trailing and repeated slashes are ignored,
// and a path without a leading slash is handled the same as one with it.
func parsePathInto(path string, buf []string) []string {
	buf = buf[:0]
	start := 0
	for i := 0; i <= len(path); i++ {
		if i == len(path) || path[i] == '/' {
			if start < i {
				buf = append(buf, path[start:i])
			}
			start = i + 1
		}
	}
	return buf
}

// parsePathZeroAlloc splits path into segments using a single exactly-bounded
// allocation. It returns nil for "" and "/".
func parsePathZeroAlloc(path string) []string {
	if path == "" || path == "/" {
		return nil
	}
	// Segment count is bounded by the number of separators plus one.
	return parsePathInto(path, make([]string, 0, strings.Count(path, "/")+1))
}

// isParam reports whether a route segment is a parameter (":name").
func isParam(segment string) bool {
	return len(segment) > 0 && segment[0] == ':'
}

// isWildcard reports whether a route segment is a wildcard ("*name").
func isWildcard(segment string) bool {
	return len(segment) > 0 && segment[0] == '*'
}

// getParamName strips the leading ':' or '*' from a route segment.
func getParamName(segment string) string {
	if len(segment) > 1 {
		return segment[1:]
	}
	return ""
}

// addRoute registers handler at segments without route middleware.
func (n *node) addRoute(segments []string, handler http.HandlerFunc) {
	n.addRouteWithMiddleware(segments, handler, nil)
}

// addRouteWithMiddleware walks/creates the nodes for segments below n and
// stores handler and middleware on the final node. No segments means n.
func (n *node) addRouteWithMiddleware(segments []string, handler http.HandlerFunc, middleware []MiddlewareFunc) {
	current := n
	for _, segment := range segments {
		switch {
		case isParam(segment):
			current = current.addParamChild(getParamName(segment))
		case isWildcard(segment):
			current = current.addWildcardChild(getParamName(segment))
		default:
			current = current.addChild(segment)
		}
	}
	current.handler = handler
	current.middleware = middleware
}

// find returns the node with a handler that matches segments, or nil.
//
// segments must come from splitting path; path is used to give wildcards the
// raw remainder of the URL. Static children are preferred over parameters,
// and parameters over wildcards, but a branch that dead-ends is abandoned so a
// lower-priority sibling can still match. Captured values go into params
// (which may be nil) and are rolled back when a branch is abandoned.
func (n *node) find(path string, segments []string, params *Params) *node {
	return n.match(path, 0, segments, params)
}

// match resolves segments below n; pos is the offset in path at or after
// which the next segment begins.
func (n *node) match(path string, pos int, segments []string, params *Params) *node {
	if len(segments) == 0 {
		if n.handler != nil {
			return n
		}
		return nil
	}

	segment, rest := segments[0], segments[1:]

	// Locate the segment in the raw path, skipping (possibly repeated) slashes.
	start := pos
	if start > len(path) {
		start = len(path)
	}
	for start < len(path) && path[start] == '/' {
		start++
	}
	next := start + len(segment)

	if child := n.children[segment]; child != nil {
		if found := child.match(path, next, rest, params); found != nil {
			return found
		}
	}

	if child := n.paramChild; child != nil {
		idx, prev := -1, ""
		if params != nil {
			idx, prev = params.push(child.paramName, segment)
		}
		if found := child.match(path, next, rest, params); found != nil {
			return found
		}
		if params != nil {
			params.pop(idx, prev)
		}
	}

	// A wildcard consumes everything from this segment onward.
	if child := n.wildcardChild; child != nil && child.handler != nil {
		if params != nil {
			params.Set(child.paramName, path[start:])
		}
		return child
	}
	return nil
}

// ServeHTTP implements http.Handler. It matches the request against the tree
// for its method and runs the handler wrapped in route middleware (inner) and
// router middleware (outer). Unmatched requests get 405 (with an Allow
// header) when another method matches and the check is enabled, else 404.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	path := req.URL.Path

	segBuf := r.segmentsPool.Get().(*[]string)
	segments := parsePathInto(path, *segBuf)
	defer func() {
		// Drop references to this request's path before pooling the buffer.
		for i := range segments {
			segments[i] = ""
		}
		*segBuf = segments[:0]
		r.segmentsPool.Put(segBuf)
	}()

	tree := r.trees[req.Method]
	if tree == nil {
		r.handleUnmatched(w, req, segments)
		return
	}

	params := r.paramsPool.Get().(*Params)
	params.Reset()
	matched := tree.find(path, segments, params)
	if matched == nil || matched.handler == nil {
		params.Reset()
		r.paramsPool.Put(params)
		r.handleUnmatched(w, req, segments)
		return
	}

	// Copy the parameters so the pooled buffer can be recycled immediately
	// while the context keeps stable values.
	var ctxParams Params
	if params.Len() > 0 {
		ctxParams.keys = append(ctxParams.keys, params.keys...)
		ctxParams.values = append(ctxParams.values, params.values...)
	}
	params.Reset()
	r.paramsPool.Put(params)
	req = req.WithContext(context.WithValue(req.Context(), paramsKey, ctxParams))

	var handler http.Handler = matched.handler
	for i := len(matched.middleware) - 1; i >= 0; i-- {
		handler = matched.middleware[i](handler)
	}
	for i := len(r.middleware) - 1; i >= 0; i-- {
		handler = r.middleware[i](handler)
	}
	handler.ServeHTTP(w, req)
}

// handleUnmatched answers a request that no route for its method matched.
func (r *Router) handleUnmatched(w http.ResponseWriter, req *http.Request, segments []string) {
	if r.checkMethodNotAllowed {
		if allowed := r.allowedMethods(req.URL.Path, segments, req.Method); len(allowed) > 0 {
			// RFC 9110 requires a 405 response to list the supported methods.
			w.Header().Set("Allow", strings.Join(allowed, ", "))
			r.handleMethodNotAllowed(w, req)
			return
		}
	}
	r.handleNotFound(w, req)
}

// allowedMethods returns, sorted, every method other than currentMethod whose
// tree has a handler for the path.
func (r *Router) allowedMethods(path string, segments []string, currentMethod string) []string {
	var allowed []string
	for method, tree := range r.trees {
		if method == currentMethod {
			continue
		}
		if n := tree.find(path, segments, nil); n != nil && n.handler != nil {
			allowed = append(allowed, method)
		}
	}
	sort.Strings(allowed)
	return allowed
}

// hasOtherMethodMatch reports whether a method other than currentMethod has
// a handler for the path.
func (r *Router) hasOtherMethodMatch(path string, segments []string, currentMethod string) bool {
	for method, tree := range r.trees {
		if method == currentMethod {
			continue
		}
		if n := tree.find(path, segments, nil); n != nil && n.handler != nil {
			return true
		}
	}
	return false
}

// handleNotFound writes a 404 via the custom handler or http.NotFound.
func (r *Router) handleNotFound(w http.ResponseWriter, req *http.Request) {
	if r.notFoundHandler != nil {
		r.notFoundHandler(w, req)
		return
	}
	http.NotFound(w, req)
}

// handleMethodNotAllowed writes a 405 via the custom handler or a plain body.
func (r *Router) handleMethodNotAllowed(w http.ResponseWriter, req *http.Request) {
	if r.methodNotAllowedHandler != nil {
		r.methodNotAllowedHandler(w, req)
		return
	}
	w.WriteHeader(http.StatusMethodNotAllowed)
	w.Write([]byte("Method Not Allowed")) // best effort: nothing to do on a failed write
}

// Use adds middleware. On a router returned by New it is applied by
// ServeHTTP around every matched route, including routes registered earlier.
// On a group it applies to routes registered on that group afterwards.
func (r *Router) Use(middleware ...MiddlewareFunc) {
	if r.isGroup {
		r.routeMiddleware = append(r.routeMiddleware, middleware...)
		return
	}
	r.middleware = append(r.middleware, middleware...)
}

// Group returns a router that shares r's route trees, prefixes every path it
// registers with prefix, and wraps those routes in middleware (after any
// middleware inherited from enclosing groups). A missing leading '/' is added.
func (r *Router) Group(prefix string, middleware ...MiddlewareFunc) *Router {
	if prefix != "" && prefix[0] != '/' {
		prefix = "/" + prefix
	}
	if r.trees == nil {
		// Make sure the group and r share one map even on a zero-value Router.
		r.trees = make(map[string]*node)
	}

	// Build fresh slices: appending to r's slices could write into spare
	// capacity shared with sibling groups.
	routeMW := make([]MiddlewareFunc, 0, len(r.routeMiddleware)+len(middleware))
	routeMW = append(routeMW, r.routeMiddleware...)
	routeMW = append(routeMW, middleware...)

	return &Router{
		trees:                   r.trees,
		middleware:              append([]MiddlewareFunc(nil), r.middleware...),
		routeMiddleware:         routeMW,
		notFoundHandler:         r.notFoundHandler,
		methodNotAllowedHandler: r.methodNotAllowedHandler,
		paramsPool:              r.paramsPool,
		segmentsPool:            r.segmentsPool,
		prefix:                  r.prefix + prefix,
		checkMethodNotAllowed:   r.checkMethodNotAllowed,
		isGroup:                 true,
	}
}

// NotFound sets a custom 404 handler. Groups created earlier keep the
// handler they copied at creation time.
func (r *Router) NotFound(handler http.HandlerFunc) {
	r.notFoundHandler = handler
}

// MethodNotAllowed sets a custom 405 handler. The Allow header is already
// set when it runs.
func (r *Router) MethodNotAllowed(handler http.HandlerFunc) {
	r.methodNotAllowedHandler = handler
}

// EnableMethodNotAllowedCheck toggles the cross-method lookup performed for
// unmatched requests (405 instead of 404).
func (r *Router) EnableMethodNotAllowedCheck(enabled bool) {
	r.checkMethodNotAllowed = enabled
}

// addRoute registers handler for method at r's prefix plus path, attaching a
// snapshot of the group's route middleware.
func (r *Router) addRoute(method, path string, handler http.HandlerFunc) *Router {
	if r.trees == nil {
		r.trees = make(map[string]*node)
	}
	root := r.trees[method]
	if root == nil {
		root = makeNode("/")
		r.trees[method] = root
	}

	// Snapshot so later Use calls on the group cannot alter this route.
	routeMW := append([]MiddlewareFunc(nil), r.routeMiddleware...)
	root.addRouteWithMiddleware(parsePath(r.prefix+path), handler, routeMW)
	return r
}

// GET adds a GET route.
func (r *Router) GET(path string, handler http.HandlerFunc) *Router {
	return r.addRoute(http.MethodGet, path, handler)
}

// POST adds a POST route.
func (r *Router) POST(path string, handler http.HandlerFunc) *Router {
	return r.addRoute(http.MethodPost, path, handler)
}

// PUT adds a PUT route.
func (r *Router) PUT(path string, handler http.HandlerFunc) *Router {
	return r.addRoute(http.MethodPut, path, handler)
}

// PATCH adds a PATCH route.
func (r *Router) PATCH(path string, handler http.HandlerFunc) *Router {
	return r.addRoute(http.MethodPatch, path, handler)
}

// DELETE adds a DELETE route.
func (r *Router) DELETE(path string, handler http.HandlerFunc) *Router {
	return r.addRoute(http.MethodDelete, path, handler)
}

// OPTIONS adds an OPTIONS route.
func (r *Router) OPTIONS(path string, handler http.HandlerFunc) *Router {
	return r.addRoute(http.MethodOptions, path, handler)
}

// HEAD adds a HEAD route.
func (r *Router) HEAD(path string, handler http.HandlerFunc) *Router {
	return r.addRoute(http.MethodHead, path, handler)
}

// New creates a Router with empty trees for the common HTTP methods and the
// 405 check enabled.
func New() *Router {
	router := &Router{
		trees:      make(map[string]*node),
		middleware: make([]MiddlewareFunc, 0),
		paramsPool: &sync.Pool{
			New: func() interface{} {
				return &Params{
					keys:   make([]string, 0, 4),
					values: make([]string, 0, 4),
				}
			},
		},
		segmentsPool: &sync.Pool{
			// Pool a pointer so Put does not allocate to box the slice header.
			New: func() interface{} {
				buf := make([]string, 0, 4)
				return &buf
			},
		},
		prefix:                "",
		checkMethodNotAllowed: true,
	}

	methods := []string{
		http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch,
		http.MethodDelete, http.MethodOptions, http.MethodHead,
	}
	for _, method := range methods {
		router.trees[method] = makeNode("/")
	}
	return router
}