package sux import ( "context" "net/http" "strings" "sync" ) const ( static nodeType = iota param wildcard ) type ( nodeType uint8 // Params holds route parameters extracted from the URL // Optimized to reduce allocations compared to map[string]string Params struct { keys []string values []string } // MiddlewareFunc represents a middleware function MiddlewareFunc func(http.Handler) http.Handler node struct { path string nodeType nodeType handler http.HandlerFunc children map[string]*node // Hash map for O(1) static child lookup paramChild *node // For :param routes wildcardChild *node // For *param routes paramName string // Name of the parameter middleware []MiddlewareFunc } Router struct { trees map[string]*node // One tree per HTTP method middleware []MiddlewareFunc notFoundHandler http.HandlerFunc methodNotAllowedHandler http.HandlerFunc paramsPool *sync.Pool // Use pointer to avoid copying segmentsPool *sync.Pool // Pool for path segments builderPool *sync.Pool // Pool for string builders prefix string // For route groups } ) // Context key for storing route parameters type contextKey string const paramsKey contextKey = "params" // ParamsFromContext extracts route parameters from the request context func ParamsFromContext(r *http.Request) Params { if params, ok := r.Context().Value(paramsKey).(Params); ok { return params } return Params{} } // Get returns the value for the given key func (p Params) Get(key string) string { for i, k := range p.keys { if k == key { return p.values[i] } } return "" } // Set sets the value for the given key func (p *Params) Set(key, value string) { // Check if key already exists for i, k := range p.keys { if k == key { p.values[i] = value return } } // Add new key-value pair p.keys = append(p.keys, key) p.values = append(p.values, value) } // Reset clears all parameters for reuse func (p *Params) Reset() { p.keys = p.keys[:0] p.values = p.values[:0] } // Len returns the number of parameters func (p Params) Len() int { return len(p.keys) } // makeNode creates a new node with the given path func makeNode(path string) *node { return &node{ path: path, children: make(map[string]*node), } } // addChild adds a child node for the given path segment func (n *node) addChild(path string) *node { if child := n.findChild(path); child == nil { newNode := makeNode(path) n.children[path] = newNode return newNode } else { return child } } // addParamChild adds a parameter child node func (n *node) addParamChild(paramName string) *node { if n.paramChild == nil { n.paramChild = makeNode(":" + paramName) n.paramChild.nodeType = param n.paramChild.paramName = paramName } return n.paramChild } // addWildcardChild adds a wildcard child node func (n *node) addWildcardChild(paramName string) *node { if n.wildcardChild == nil { n.wildcardChild = makeNode("*" + paramName) n.wildcardChild.nodeType = wildcard n.wildcardChild.paramName = paramName } return n.wildcardChild } // findChild finds a static child node by path func (n *node) findChild(path string) *node { return n.children[path] } // parsePath splits a path into segments and identifies parameters func parsePath(path string) []string { if path == "" || path == "/" { return []string{} } segments := make([]string, 0, 4) // Pre-allocate for common cases start := 1 // Skip the leading slash for i := 1; i < len(path); i++ { if path[i] == '/' { if start < i { segments = append(segments, path[start:i]) } start = i + 1 } } if start < len(path) { segments = append(segments, path[start:]) } return segments } // parsePathZeroAlloc splits a path into segments with minimal allocations // Uses string slicing to avoid allocations where possible func parsePathZeroAlloc(path string) []string { if path == "" || path == "/" { return nil } // Count segments first to pre-allocate exact size segmentCount := 1 for i := 1; i < len(path); i++ { if path[i] == '/' { segmentCount++ } } segments := make([]string, 0, segmentCount) start := 1 // Skip the leading slash for i := 1; i < len(path); i++ { if path[i] == '/' { if start < i { segments = append(segments, path[start:i]) } start = i + 1 } } if start < len(path) { segments = append(segments, path[start:]) } return segments } // isParam checks if a path segment is a parameter (:param) func isParam(segment string) bool { return len(segment) > 0 && segment[0] == ':' } // isWildcard checks if a path segment is a wildcard (*param) func isWildcard(segment string) bool { return len(segment) > 0 && segment[0] == '*' } // getParamName extracts the parameter name from a segment func getParamName(segment string) string { if len(segment) > 1 { return segment[1:] } return "" } // addRoute adds a route to the tree func (n *node) addRoute(segments []string, handler http.HandlerFunc) { n.addRouteWithMiddleware(segments, handler, nil) } // addRouteWithMiddleware adds a route to the tree with middleware func (n *node) addRouteWithMiddleware(segments []string, handler http.HandlerFunc, middleware []MiddlewareFunc) { current := n // Handle root path ("/") case if len(segments) == 0 { n.nodeType = static n.handler = handler n.middleware = middleware return } for i, segment := range segments { isLast := i == len(segments)-1 if isParam(segment) { paramName := getParamName(segment) current = current.addParamChild(paramName) } else if isWildcard(segment) { paramName := getParamName(segment) current = current.addWildcardChild(paramName) } else { current = current.addChild(segment) } if isLast { current.nodeType = static current.handler = handler current.middleware = middleware } } } // find searches for a route matching the given path func (n *node) find(segments []string, params *Params) *node { current := n for i, segment := range segments { isLast := i == len(segments)-1 // Try static match first (fastest path) if child := current.findChild(segment); child != nil { if isLast && child.nodeType == static { return child } current = child continue } // Try parameter match if current.paramChild != nil { if params != nil && params.keys != nil { params.Set(current.paramChild.paramName, segment) } if isLast && current.paramChild.nodeType == param { return current.paramChild } current = current.paramChild continue } // Try wildcard match (catches everything) if current.wildcardChild != nil { if params != nil && params.keys != nil { // Wildcard captures the rest of the path // Use pooled string builder for efficient concatenation // Note: We can't use the pool here since we don't have access to router // But we can optimize the string concatenation var wildcardValue string if i < len(segments)-1 { // Pre-calculate capacity to avoid reallocations capacity := len(segment) for j := i + 1; j < len(segments); j++ { capacity += 1 + len(segments[j]) // 1 for '/' } builder := strings.Builder{} builder.Grow(capacity) builder.WriteString(segment) for j := i + 1; j < len(segments); j++ { builder.WriteByte('/') builder.WriteString(segments[j]) } wildcardValue = builder.String() } else { wildcardValue = segment } params.Set(current.wildcardChild.paramName, wildcardValue) } return current.wildcardChild } return nil // No match found } if current.nodeType == static { return current } return nil } // ServeHTTP implements the http.Handler interface func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { if tree, exists := r.trees[req.Method]; exists { segments := parsePathZeroAlloc(req.URL.Path) // Get params from pool for better performance params := r.paramsPool.Get().(Params) params.Reset() node := tree.find(segments, ¶ms) if node == nil || node.handler == nil { // Check if the path exists for other methods to return 405 instead of 404 hasOtherMethod := false for method, otherTree := range r.trees { if method != req.Method { if otherNode := otherTree.find(segments, &Params{}); otherNode != nil { hasOtherMethod = true break } } } // Return params to pool r.paramsPool.Put(params) if hasOtherMethod { r.handleMethodNotAllowed(w, req) } else { r.handleNotFound(w, req) } return } // Add params to request context ctx := context.WithValue(req.Context(), paramsKey, params) req = req.WithContext(ctx) // Create handler chain with middleware handler := http.Handler(http.HandlerFunc(node.handler)) // Apply node-specific middleware first, then router middleware for i := len(node.middleware) - 1; i >= 0; i-- { handler = node.middleware[i](handler) } for i := len(r.middleware) - 1; i >= 0; i-- { handler = r.middleware[i](handler) } // Defer returning params to pool after handler completes defer func() { r.paramsPool.Put(params) }() handler.ServeHTTP(w, req) return } r.handleMethodNotAllowed(w, req) } // handleNotFound handles 404 responses func (r *Router) handleNotFound(w http.ResponseWriter, req *http.Request) { if r.notFoundHandler != nil { r.notFoundHandler(w, req) } else { http.NotFound(w, req) } } // handleMethodNotAllowed handles 405 responses func (r *Router) handleMethodNotAllowed(w http.ResponseWriter, req *http.Request) { if r.methodNotAllowedHandler != nil { r.methodNotAllowedHandler(w, req) } else { w.WriteHeader(http.StatusMethodNotAllowed) w.Write([]byte("Method Not Allowed")) } } // Use adds global middleware to the router func (r *Router) Use(middleware ...MiddlewareFunc) { r.middleware = append(r.middleware, middleware...) } // Group creates a new route group with a prefix and optional middleware func (r *Router) Group(prefix string, middleware ...MiddlewareFunc) *Router { // Ensure prefix starts with / if len(prefix) > 0 && prefix[0] != '/' { prefix = "/" + prefix } return &Router{ trees: r.trees, // Share the same trees middleware: append(r.middleware, middleware...), notFoundHandler: r.notFoundHandler, methodNotAllowedHandler: r.methodNotAllowedHandler, paramsPool: r.paramsPool, // Already a pointer, just copy the reference segmentsPool: r.segmentsPool, // Share the segments pool builderPool: r.builderPool, // Share the builder pool prefix: r.prefix + prefix, } } // NotFound sets a custom 404 handler func (r *Router) NotFound(handler http.HandlerFunc) { r.notFoundHandler = handler } // MethodNotAllowed sets a custom 405 handler func (r *Router) MethodNotAllowed(handler http.HandlerFunc) { r.methodNotAllowedHandler = handler } // addRoute is a helper method to add routes for different HTTP methods func (r *Router) addRoute(method, path string, handler http.HandlerFunc) *Router { if r.trees == nil { r.trees = make(map[string]*node) } // Find the root tree for this method (from the main router or create new) var rootTree *node if tree, exists := r.trees[method]; exists { rootTree = tree } else { rootTree = makeNode("/") r.trees[method] = rootTree } // Apply group prefix if present fullPath := r.prefix + path segments := parsePath(fullPath) rootTree.addRouteWithMiddleware(segments, handler, r.middleware) return r } // GET adds a GET route func (r *Router) GET(path string, handler http.HandlerFunc) *Router { return r.addRoute("GET", path, handler) } // POST adds a POST route func (r *Router) POST(path string, handler http.HandlerFunc) *Router { return r.addRoute("POST", path, handler) } // PUT adds a PUT route func (r *Router) PUT(path string, handler http.HandlerFunc) *Router { return r.addRoute("PUT", path, handler) } // PATCH adds a PATCH route func (r *Router) PATCH(path string, handler http.HandlerFunc) *Router { return r.addRoute("PATCH", path, handler) } // DELETE adds a DELETE route func (r *Router) DELETE(path string, handler http.HandlerFunc) *Router { return r.addRoute("DELETE", path, handler) } // OPTIONS adds an OPTIONS route func (r *Router) OPTIONS(path string, handler http.HandlerFunc) *Router { return r.addRoute("OPTIONS", path, handler) } // HEAD adds a HEAD route func (r *Router) HEAD(path string, handler http.HandlerFunc) *Router { return r.addRoute("HEAD", path, handler) } // New creates a new Router instance func New() *Router { router := &Router{ trees: make(map[string]*node), middleware: make([]MiddlewareFunc, 0), paramsPool: &sync.Pool{ New: func() interface{} { return Params{ keys: make([]string, 0, 4), // Pre-allocate for common cases values: make([]string, 0, 4), } }, }, segmentsPool: &sync.Pool{ New: func() interface{} { return make([]string, 0, 4) // Pre-allocate for common cases }, }, builderPool: &sync.Pool{ New: func() interface{} { return &strings.Builder{} }, }, prefix: "", } // Initialize trees for common HTTP methods methods := []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"} for _, method := range methods { router.trees[method] = makeNode("/") } return router }