Files
sux/sux.go

568 lines
14 KiB
Go
Raw Normal View History

2021-10-31 09:03:21 +01:00
package sux
import (
"context"
2021-10-31 09:03:21 +01:00
"net/http"
"sync"
2021-10-31 09:03:21 +01:00
)
const (
static nodeType = iota
param
wildcard
2021-10-31 09:03:21 +01:00
)
// nodeType identifies what kind of path segment a routing-tree node matches.
type nodeType uint8

// Params holds route parameters extracted from the URL.
//
// Names and values live in parallel slices (keys[i] is paired with
// values[i]); for the handful of parameters a route typically has, this
// allocates less than a map[string]string would.
type Params struct {
	keys   []string
	values []string
}

// MiddlewareFunc decorates an http.Handler, returning the handler to call in
// its place.
type MiddlewareFunc func(http.Handler) http.Handler

// node is one path segment in a per-method routing tree.
type node struct {
	path          string           // segment as registered; ":name" / "*name" for dynamic nodes
	nodeType      nodeType         // static, param or wildcard
	handler       http.HandlerFunc // set on the node where a route ends
	children      map[string]*node // static children keyed by segment text
	paramChild    *node            // the single ":name" child, if any
	wildcardChild *node            // the single "*name" child, if any
	paramName     string           // parameter name for param/wildcard nodes
	middleware    []MiddlewareFunc // middleware attached to the route ending here
}

// Router dispatches requests using one routing tree per HTTP method.
// Routers returned by Group share the trees and pools of their parent.
type Router struct {
	trees                   map[string]*node // HTTP method -> root of that method's tree
	middleware              []MiddlewareFunc // middleware registered on this router
	notFoundHandler         http.HandlerFunc // optional custom 404 handler
	methodNotAllowedHandler http.HandlerFunc // optional custom 405 handler
	paramsPool              *sync.Pool       // reusable *Params; a pointer so copies of Router share it
	segmentsPool            *sync.Pool       // reusable []string buffers for path segments
	prefix                  string           // path prefix added by route groups
	checkMethodNotAllowed   bool             // probe other methods' trees on a miss to answer 405
}
// contextKey is an unexported key type for this package's context values, so
// its keys cannot collide with keys defined by other packages.
type contextKey string

// paramsKey is the context key under which ServeHTTP stores route Params.
const paramsKey = contextKey("params")
// ParamsFromContext returns the route parameters stored in r's context.
//
// Both a Params value and a non-nil *Params are accepted. When neither is
// present, an empty Params is returned, so callers can use Get without a
// nil check.
func ParamsFromContext(r *http.Request) Params {
	stored := r.Context().Value(paramsKey)
	if p, ok := stored.(Params); ok {
		return p
	}
	if p, ok := stored.(*Params); ok && p != nil {
		return *p
	}
	return Params{}
}
// Get returns the value stored under key, or "" when key is absent.
// If key somehow appears more than once, the earliest entry wins.
func (p Params) Get(key string) string {
	for i := 0; i < len(p.keys); i++ {
		if p.keys[i] != key {
			continue
		}
		return p.values[i]
	}
	return ""
}
// Set stores value under key. It overwrites the first existing entry for key,
// or appends a new key/value pair when key is not present yet.
func (p *Params) Set(key, value string) {
	for i := range p.keys {
		if p.keys[i] != key {
			continue
		}
		p.values[i] = value
		return
	}
	p.keys = append(p.keys, key)
	p.values = append(p.values, value)
}
// Reset empties p for reuse. Both slices are truncated to length zero and
// keep their backing arrays, so refilling p does not reallocate.
func (p *Params) Reset() {
	p.keys, p.values = p.keys[:0], p.values[:0]
}
// Len reports how many parameters p holds (the number of stored keys).
func (p Params) Len() int {
	count := len(p.keys)
	return count
}
// makeNode creates a new node with the given path
func makeNode(path string) *node {
2021-10-31 09:03:21 +01:00
return &node{
path: path,
children: make(map[string]*node),
2021-10-31 09:03:21 +01:00
}
}
// addChild adds a child node for the given path segment
func (n *node) addChild(path string) *node {
if child := n.findChild(path); child == nil {
newNode := makeNode(path)
n.children[path] = newNode
return newNode
2021-10-31 09:03:21 +01:00
} else {
return child
}
}
// addParamChild adds a parameter child node
func (n *node) addParamChild(paramName string) *node {
if n.paramChild == nil {
n.paramChild = makeNode(":" + paramName)
n.paramChild.nodeType = param
n.paramChild.paramName = paramName
2021-10-31 09:03:21 +01:00
}
return n.paramChild
2021-10-31 09:03:21 +01:00
}
// addWildcardChild adds a wildcard child node
func (n *node) addWildcardChild(paramName string) *node {
if n.wildcardChild == nil {
n.wildcardChild = makeNode("*" + paramName)
n.wildcardChild.nodeType = wildcard
n.wildcardChild.paramName = paramName
2021-10-31 09:03:21 +01:00
}
return n.wildcardChild
2021-10-31 09:03:21 +01:00
}
// findChild finds a static child node by path
func (n *node) findChild(path string) *node {
return n.children[path]
2021-10-31 09:03:21 +01:00
}
// parsePath splits a route path into its '/'-separated segments, dropping
// empty segments (from leading, trailing or repeated slashes). Parameter
// markers such as ":id" and "*rest" are kept verbatim for the caller to
// interpret.
//
// Unlike a fixed "skip the first byte" scan, this does not assume the path
// begins with '/'. A registration path such as "users/:id" therefore yields
// ["users", ":id"] rather than losing its first character.
//
// "" and "/" yield an empty, non-nil slice.
func parsePath(path string) []string {
	if path == "" || path == "/" {
		return []string{}
	}
	segments := make([]string, 0, 4) // most routes are shallow
	start := 0
	for i := 0; i < len(path); i++ {
		if path[i] != '/' {
			continue
		}
		if start < i {
			segments = append(segments, path[start:i])
		}
		start = i + 1
	}
	if start < len(path) {
		segments = append(segments, path[start:])
	}
	return segments
}
// parsePathInto splits path into its non-empty '/'-separated segments,
// appending them to buf[:0] so that buf's backing array is reused when it is
// large enough. The first byte of path (normally the leading '/') is always
// skipped. For "" and "/" the result is buf[:0].
func parsePathInto(path string, buf []string) []string {
	out := buf[:0]
	if path == "" || path == "/" {
		return out
	}
	segStart := 1
	// Iterate one past the end so the final segment is flushed by the same
	// code path as the ones terminated by a '/'.
	for i := segStart; i <= len(path); i++ {
		if i < len(path) && path[i] != '/' {
			continue
		}
		if segStart < i {
			out = append(out, path[segStart:i])
		}
		segStart = i + 1
	}
	return out
}
// parsePathZeroAlloc splits path into its non-empty '/'-separated segments.
// The first byte of path (normally the leading '/') is always skipped.
//
// Despite the name it performs one allocation: the result slice, sized up
// front to an upper bound on the segment count (one more than the number of
// slashes after the first byte). Segments are substrings of path, so no
// string data is copied. Returns nil for "" and "/".
func parsePathZeroAlloc(path string) []string {
	if len(path) == 0 || path == "/" {
		return nil
	}
	bound := 1
	for _, c := range []byte(path[1:]) {
		if c == '/' {
			bound++
		}
	}
	out := make([]string, 0, bound)
	begin := 1
	for i := 1; i <= len(path); i++ {
		if i < len(path) && path[i] != '/' {
			continue
		}
		if begin < i {
			out = append(out, path[begin:i])
		}
		begin = i + 1
	}
	return out
}
// isParam reports whether segment declares a named parameter, i.e. starts
// with ':' (as in ":id").
func isParam(segment string) bool {
	return len(segment) != 0 && segment[:1] == ":"
}
// isWildcard reports whether segment declares a catch-all parameter, i.e.
// starts with '*' (as in "*filepath").
func isWildcard(segment string) bool {
	return len(segment) != 0 && segment[:1] == "*"
}
// getParamName strips the one-byte marker (':' or '*') from a parameter
// segment and returns the remaining name. Segments of length 0 or 1 yield "".
func getParamName(segment string) string {
	if len(segment) <= 1 {
		return ""
	}
	return segment[1:]
}
// addRoute registers handler at the tree position described by segments,
// without any route-specific middleware.
func (n *node) addRoute(segments []string, handler http.HandlerFunc) {
	var noMiddleware []MiddlewareFunc
	n.addRouteWithMiddleware(segments, handler, noMiddleware)
}
// addRouteWithMiddleware walks (and extends) the tree along segments and
// attaches handler and middleware to the final node. Segments starting with
// ':' become parameter nodes, segments starting with '*' become wildcard
// nodes, and everything else becomes a static node.
//
// An empty segments slice addresses n itself (the root route "/").
// Registering the same path again replaces its handler and middleware.
func (n *node) addRouteWithMiddleware(segments []string, handler http.HandlerFunc, middleware []MiddlewareFunc) {
	if len(segments) == 0 {
		n.nodeType = static
		n.handler = handler
		n.middleware = middleware
		return
	}

	target := n
	for _, seg := range segments {
		switch {
		case isParam(seg):
			target = target.addParamChild(getParamName(seg))
		case isWildcard(seg):
			target = target.addWildcardChild(getParamName(seg))
		default:
			target = target.addChild(seg)
		}
	}
	target.handler = handler
	target.middleware = middleware
}
// find searches for a route matching the given path
func (n *node) find(path string, segments []string, params *Params) *node {
current := n
pos := 1 // Track offset into the original path for wildcard slicing
for i, segment := range segments {
isLast := i == len(segments)-1
segmentStart := pos
pos += len(segment) + 1 // advance past this segment and trailing slash
// Try static match first (fastest path)
if child := current.findChild(segment); child != nil {
if isLast && child.nodeType == static {
return child
}
current = child
continue
}
// Try parameter match
if current.paramChild != nil {
if params != nil {
params.Set(current.paramChild.paramName, segment)
}
if isLast && current.paramChild.nodeType == param {
return current.paramChild
}
current = current.paramChild
continue
2021-10-31 09:03:21 +01:00
}
// Try wildcard match (catches everything)
if current.wildcardChild != nil {
if params != nil {
// Wildcard captures the rest of the path directly from the original string
params.Set(current.wildcardChild.paramName, path[segmentStart:])
}
return current.wildcardChild
}
return nil // No match found
}
if current.nodeType == static {
return current
}
return nil
}
// ServeHTTP implements the http.Handler interface
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
segments := r.segmentsPool.Get().([]string)
segments = segments[:0]
segments = parsePathInto(req.URL.Path, segments)
defer func() {
// Clear to allow GC of prior path strings before pooling
for i := range segments {
segments[i] = ""
}
r.segmentsPool.Put(segments[:0])
}()
tree, exists := r.trees[req.Method]
if !exists {
if r.checkMethodNotAllowed && r.hasOtherMethodMatch(req.URL.Path, segments, req.Method) {
r.handleMethodNotAllowed(w, req)
} else {
r.handleNotFound(w, req)
}
return
}
params := r.paramsPool.Get().(*Params)
params.Reset()
node := tree.find(req.URL.Path, segments, params)
if node == nil || node.handler == nil {
params.Reset()
r.paramsPool.Put(params)
if r.checkMethodNotAllowed {
// Check if the path exists for other methods to return 405 instead of 404
if r.hasOtherMethodMatch(req.URL.Path, segments, req.Method) {
r.handleMethodNotAllowed(w, req)
} else {
r.handleNotFound(w, req)
}
} else {
r.handleNotFound(w, req)
2021-10-31 09:03:21 +01:00
}
return
}
// Add params to request context as a copy to keep values stable after pooling
var ctxParams Params
if params.Len() > 0 {
ctxParams.keys = append(ctxParams.keys, params.keys...)
ctxParams.values = append(ctxParams.values, params.values...)
}
ctx := context.WithValue(req.Context(), paramsKey, ctxParams)
req = req.WithContext(ctx)
// Create handler chain with middleware
handler := http.Handler(http.HandlerFunc(node.handler))
// Apply node-specific middleware first, then router middleware
for i := len(node.middleware) - 1; i >= 0; i-- {
handler = node.middleware[i](handler)
}
for i := len(r.middleware) - 1; i >= 0; i-- {
handler = r.middleware[i](handler)
}
// Defer returning params to pool after handler completes
defer func() {
params.Reset()
r.paramsPool.Put(params)
}()
handler.ServeHTTP(w, req)
}
// hasOtherMethodMatch checks whether the same path exists on another method tree.
func (r *Router) hasOtherMethodMatch(path string, segments []string, currentMethod string) bool {
for method, otherTree := range r.trees {
if method == currentMethod {
continue
}
if otherNode := otherTree.find(path, segments, nil); otherNode != nil && otherNode.handler != nil {
return true
}
}
return false
2021-10-31 09:03:21 +01:00
}
// handleNotFound handles 404 responses
func (r *Router) handleNotFound(w http.ResponseWriter, req *http.Request) {
if r.notFoundHandler != nil {
r.notFoundHandler(w, req)
} else {
http.NotFound(w, req)
}
2021-10-31 09:03:21 +01:00
}
// handleMethodNotAllowed handles 405 responses
func (r *Router) handleMethodNotAllowed(w http.ResponseWriter, req *http.Request) {
if r.methodNotAllowedHandler != nil {
r.methodNotAllowedHandler(w, req)
} else {
w.WriteHeader(http.StatusMethodNotAllowed)
w.Write([]byte("Method Not Allowed"))
}
2021-10-31 09:03:21 +01:00
}
// Use adds global middleware to the router
func (r *Router) Use(middleware ...MiddlewareFunc) {
r.middleware = append(r.middleware, middleware...)
2021-10-31 09:03:21 +01:00
}
// Group creates a new route group with a prefix and optional middleware
func (r *Router) Group(prefix string, middleware ...MiddlewareFunc) *Router {
// Ensure prefix starts with /
if len(prefix) > 0 && prefix[0] != '/' {
prefix = "/" + prefix
}
return &Router{
trees: r.trees, // Share the same trees
middleware: append(r.middleware, middleware...),
notFoundHandler: r.notFoundHandler,
methodNotAllowedHandler: r.methodNotAllowedHandler,
paramsPool: r.paramsPool, // Already a pointer, just copy the reference
segmentsPool: r.segmentsPool, // Share the segments pool
prefix: r.prefix + prefix,
checkMethodNotAllowed: r.checkMethodNotAllowed,
}
2021-10-31 09:03:21 +01:00
}
// NotFound sets a custom 404 handler
func (r *Router) NotFound(handler http.HandlerFunc) {
r.notFoundHandler = handler
2021-10-31 09:03:21 +01:00
}
// MethodNotAllowed sets a custom 405 handler
func (r *Router) MethodNotAllowed(handler http.HandlerFunc) {
r.methodNotAllowedHandler = handler
2021-10-31 09:03:21 +01:00
}
// EnableMethodNotAllowedCheck turns the 405 probe on or off (New turns it
// on). When enabled, a request that misses its own method's tree is looked up
// in every other method's tree, and a hit yields 405 instead of 404. That
// lookup costs extra work on every miss.
func (r *Router) EnableMethodNotAllowedCheck(enabled bool) {
	r.checkMethodNotAllowed = enabled
}
// addRoute registers handler for method and path (relative to the router's
// prefix) and returns r so calls can be chained. The tree for method is
// created on first use.
//
// The route records a snapshot of the router's current middleware.
// Middleware added with Use afterwards is not recorded on this route.
func (r *Router) addRoute(method, path string, handler http.HandlerFunc) *Router {
	if r.trees == nil {
		r.trees = make(map[string]*node)
	}

	root, ok := r.trees[method]
	if !ok {
		root = makeNode("/")
		r.trees[method] = root
	}

	// Normalise the leading slash the same way Group does for prefixes.
	// Without this, "users" registered under the prefix "/api" would become
	// "/apiusers".
	if len(path) > 0 && path[0] != '/' {
		path = "/" + path
	}

	// Copy the middleware into its own array so later appends to r's slice
	// (or to a sibling group's slice sharing its array) cannot change what
	// this route runs.
	middleware := make([]MiddlewareFunc, len(r.middleware))
	copy(middleware, r.middleware)

	root.addRouteWithMiddleware(parsePath(r.prefix+path), handler, middleware)
	return r
}
// GET registers handler for GET requests to path and returns the router for
// chaining.
func (r *Router) GET(path string, handler http.HandlerFunc) *Router {
	return r.addRoute(http.MethodGet, path, handler)
}
// POST registers handler for POST requests to path and returns the router
// for chaining.
func (r *Router) POST(path string, handler http.HandlerFunc) *Router {
	return r.addRoute(http.MethodPost, path, handler)
}
// PUT adds a PUT route
func (r *Router) PUT(path string, handler http.HandlerFunc) *Router {
return r.addRoute("PUT", path, handler)
2021-10-31 09:03:21 +01:00
}
// PATCH registers handler for PATCH requests to path and returns the router
// for chaining.
func (r *Router) PATCH(path string, handler http.HandlerFunc) *Router {
	return r.addRoute(http.MethodPatch, path, handler)
}
// DELETE registers handler for DELETE requests to path and returns the
// router for chaining.
func (r *Router) DELETE(path string, handler http.HandlerFunc) *Router {
	return r.addRoute(http.MethodDelete, path, handler)
}
// OPTIONS registers handler for OPTIONS requests to path and returns the
// router for chaining.
func (r *Router) OPTIONS(path string, handler http.HandlerFunc) *Router {
	return r.addRoute(http.MethodOptions, path, handler)
}
// HEAD registers handler for HEAD requests to path and returns the router
// for chaining.
func (r *Router) HEAD(path string, handler http.HandlerFunc) *Router {
	return r.addRoute(http.MethodHead, path, handler)
}
// New creates a new Router instance
2021-10-31 09:03:21 +01:00
func New() *Router {
router := &Router{
trees: make(map[string]*node),
middleware: make([]MiddlewareFunc, 0),
paramsPool: &sync.Pool{
New: func() interface{} {
return &Params{
keys: make([]string, 0, 4), // Pre-allocate for common cases
values: make([]string, 0, 4),
}
},
},
segmentsPool: &sync.Pool{
New: func() interface{} {
return make([]string, 0, 4) // Pre-allocate for common cases
},
},
prefix: "",
checkMethodNotAllowed: true,
}
// Initialize trees for common HTTP methods
methods := []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"}
for _, method := range methods {
router.trees[method] = makeNode("/")
}
return router
2021-10-31 09:03:21 +01:00
}