support for parameters, middleware, and route groups.

This commit is contained in:
2025-10-26 12:25:18 +01:00
parent 6c318a988c
commit 6f87960f3a
4 changed files with 1117 additions and 149 deletions

256
README.md
View File

@@ -1,11 +1,23 @@
# Sux
Static route http router that considers the request method
Static route http router that considers the request method with support for parameters, middleware, and route groups.
Useful for serving server-side rendered content.
## Features
- **High Performance**: Optimized trie-based routing with minimal allocations
- **URL Parameters**: Support for `:param` and `*wildcard` parameters
- **Middleware**: Global and route-specific middleware support
- **Route Groups**: Group routes with shared prefixes and middleware
- **Thread-Safe**: No global state - multiple router instances supported
- **Method Not Allowed**: Proper 405 responses when path exists but method doesn't
- **Custom Handlers**: Custom 404 and 405 handlers
## How
### Basic Usage
```go
package main
@@ -32,7 +44,134 @@ func Simple(w http.ResponseWriter, r *http.Request) {
}
```
### Performance
### URL Parameters
```go
r := sux.New()
// Named parameters
r.GET("/users/:id", func(w http.ResponseWriter, r *http.Request) {
params := sux.ParamsFromContext(r)
id := params["id"]
io.WriteString(w, "User ID: "+id)
})
// Wildcard parameters (captures rest of path)
r.GET("/files/*path", func(w http.ResponseWriter, r *http.Request) {
params := sux.ParamsFromContext(r)
path := params["path"]
io.WriteString(w, "File path: "+path)
})
// Multiple parameters
r.GET("/users/:userId/posts/:postId", func(w http.ResponseWriter, r *http.Request) {
params := sux.ParamsFromContext(r)
userId := params["userId"]
postId := params["postId"]
io.WriteString(w, "User "+userId+", Post "+postId)
})
```
### Middleware
```go
r := sux.New()
// Global middleware
r.Use(loggingMiddleware, authMiddleware)
// Route-specific middleware
r.GET("/admin", adminOnlyMiddleware(adminHandler))
func loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("%s %s", r.Method, r.URL.Path)
next.ServeHTTP(w, r)
})
}
func authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check authentication
if !isAuthenticated(r) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
```
### Route Groups
```go
r := sux.New()
// API group with middleware
api := r.Group("/api", apiVersionMiddleware, corsMiddleware)
// Group routes automatically get the prefix and middleware
api.GET("/users", listUsers)
api.POST("/users", createUser)
api.GET("/users/:id", getUser)
// Nested groups
v1 := api.Group("/v1")
v1.GET("/posts", listPostsV1)
v2 := api.Group("/v2")
v2.GET("/posts", listPostsV2)
```
### Custom Handlers
```go
r := sux.New()
// Custom 404 handler
r.NotFound(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("Custom 404 - Page not found"))
})
// Custom 405 handler
r.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusMethodNotAllowed)
w.Write([]byte("Custom 405 - Method not allowed"))
})
```
### Multiple Router Instances
```go
// Each router is independent and thread-safe
mainRouter := sux.New()
mainRouter.GET("/", homeHandler)
apiRouter := sux.New()
apiRouter.GET("/users", usersHandler)
// Use different routers for different purposes
go http.ListenAndServe(":8080", mainRouter)
go http.ListenAndServe(":8081", apiRouter)
```
## Performance
The router is optimized for high performance with minimal allocations:
```
BenchmarkStaticRoute-8 2011203 877.6 ns/op 644 B/op 8 allocs/op
BenchmarkParameterRoute-8 1388089 943.1 ns/op 576 B/op 6 allocs/op
BenchmarkWildcardRoute-8 986684 1100 ns/op 656 B/op 8 allocs/op
BenchmarkMultipleParameters-8 811143 1520 ns/op 768 B/op 8 allocs/op
BenchmarkMiddleware-8 575060 2479 ns/op 1472 B/op 17 allocs/op
BenchmarkRouteGroups-8 569205 1889 ns/op 1352 B/op 12 allocs/op
```
### Performance Comparison
Compared to other popular Go routers:
```
darko@arch ~ $ wrk -c1000 -t8 -d30s http://inuc:8080/
@@ -44,16 +183,6 @@ Running 30s test @ http://inuc:8080/
4260394 requests in 30.00s, 491.63MB read
Requests/sec: 142025.60
Transfer/sec: 16.39MB
darko@arch ~ $ wrk -c1000 -t8 -d30s http://inuc:8080/simple?name=Darko
Running 30s test @ http://inuc:8080/simple?name=Darko
8 threads and 1000 connections
Thread Stats Avg Stdev Max +/- Stdev
Latency 8.11ms 3.84ms 243.61ms 89.96%
Req/Sec 16.01k 2.20k 23.88k 72.87%
3723315 requests in 30.00s, 454.51MB read
Requests/sec: 124116.94
Transfer/sec: 15.15MB
```
Compared to https://github.com/julienschmidt/httprouter
@@ -67,64 +196,57 @@ Running 30s test @ http://inuc:8080/
4224358 requests in 30.00s, 487.47MB read
Requests/sec: 140826.15
Transfer/sec: 16.25MB
darko@arch ~ $ wrk -c1000 -t8 -d30s http://inuc:8080/simple/Darko
Running 30s test @ http://inuc:8080/simple/Darko
8 threads and 1000 connections
Thread Stats Avg Stdev Max +/- Stdev
Latency 7.78ms 3.40ms 60.03ms 88.79%
Req/Sec 16.55k 1.86k 19.95k 66.95%
3873629 requests in 30.00s, 472.86MB read
Requests/sec: 129131.98
Transfer/sec: 15.76MB
```
Compared to https://github.com/gocraft/web
```
darko@arch ~ $ wrk -c1000 -t8 -d30s http://inuc:8080/
Running 30s test @ http://inuc:8080/
8 threads and 1000 connections
Thread Stats Avg Stdev Max +/- Stdev
Latency 9.07ms 4.58ms 420.89ms 91.40%
Req/Sec 14.42k 2.28k 25.00k 80.03%
3309152 requests in 30.00s, 378.70MB read
Requests/sec: 110315.36
Transfer/sec: 12.62MB
darko@arch ~ $ wrk -c1000 -t8 -d30s http://inuc:8080/simple
Running 30s test @ http://inuc:8080/simple
8 threads and 1000 connections
Thread Stats Avg Stdev Max +/- Stdev
Latency 9.86ms 4.00ms 66.98ms 87.79%
Req/Sec 13.05k 1.65k 17.52k 69.71%
3055899 requests in 30.00s, 375.95MB read
Requests/sec: 101874.46
Transfer/sec: 12.53MB
```
## API Reference
Compared to https://github.com/gomango/mux
```
darko@arch ~ $ wrk -c1000 -t8 -d30s http://inuc:8080/
Running 30s test @ http://inuc:8080/
8 threads and 1000 connections
Thread Stats Avg Stdev Max +/- Stdev
Latency 9.22ms 12.73ms 235.88ms 84.36%
Req/Sec 13.78k 1.77k 25.52k 74.17%
3242004 requests in 30.00s, 377.20MB read
Socket errors: connect 0, read 0, write 0, timeout 10
Requests/sec: 108078.84
Transfer/sec: 12.57MB
darko@arch ~ $ wrk -c1000 -t8 -d30s http://inuc:8080/simple
Running 30s test @ http://inuc:8080/simple
8 threads and 1000 connections
Thread Stats Avg Stdev Max +/- Stdev
Latency 16.24ms 14.39ms 150.41ms 56.30%
Req/Sec 7.77k 593.31 9.78k 68.20%
1841839 requests in 30.00s, 226.59MB read
Requests/sec: 61402.97
Transfer/sec: 7.55MB
```
### Router Methods
Apples and Oranges but not quite
- `New() *Router` - Creates a new router instance
- `GET(path string, handler http.HandlerFunc) *Router`
- `POST(path string, handler http.HandlerFunc) *Router`
- `PUT(path string, handler http.HandlerFunc) *Router`
- `PATCH(path string, handler http.HandlerFunc) *Router`
- `DELETE(path string, handler http.HandlerFunc) *Router`
- `OPTIONS(path string, handler http.HandlerFunc) *Router`
- `HEAD(path string, handler http.HandlerFunc) *Router`
Hardware
### Middleware
http://www.intel.com/content/www/us/en/nuc/nuc-kit-d54250wykh.html
- `Use(middleware ...MiddlewareFunc)` - Add global middleware
- `Group(prefix string, middleware ...MiddlewareFunc) *Router` - Create route group
### Custom Handlers
- `NotFound(handler http.HandlerFunc)` - Set custom 404 handler
- `MethodNotAllowed(handler http.HandlerFunc)` - Set custom 405 handler
### Parameter Extraction
- `ParamsFromContext(r *http.Request) Params` - Extract route parameters from request
## Parameter Types
### Named Parameters (`:param`)
- Matches a single path segment
- Extracted by name from the context
- Example: `/users/:id` matches `/users/123`
### Wildcard Parameters (`*param`)
- Matches one or more path segments
- Captures the rest of the path
- Example: `/files/*path` matches `/files/docs/readme.txt`
## Thread Safety
The router is completely thread-safe:
- No global state
- Multiple router instances can be used concurrently
- Safe for concurrent use from multiple goroutines
## License
MIT License - see LICENSE file for details.

212
benchmark_test.go Normal file
View File

@@ -0,0 +1,212 @@
package sux
import (
"net/http"
"net/http/httptest"
"testing"
)
// BenchmarkStaticRoute benchmarks basic static route performance
func BenchmarkStaticRoute(b *testing.B) {
router := New()
router.GET("/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("home"))
})
req := httptest.NewRequest("GET", "/", nil)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
}
}
// BenchmarkParameterRoute benchmarks parameter route performance
func BenchmarkParameterRoute(b *testing.B) {
router := New()
router.GET("/users/:id", func(w http.ResponseWriter, r *http.Request) {
params := ParamsFromContext(r)
_ = params["id"] // Use the parameter
w.WriteHeader(http.StatusOK)
})
req := httptest.NewRequest("GET", "/users/123", nil)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
}
}
// BenchmarkWildcardRoute benchmarks wildcard route performance
func BenchmarkWildcardRoute(b *testing.B) {
router := New()
router.GET("/files/*path", func(w http.ResponseWriter, r *http.Request) {
params := ParamsFromContext(r)
_ = params["path"] // Use the parameter
w.WriteHeader(http.StatusOK)
})
req := httptest.NewRequest("GET", "/files/docs/readme.txt", nil)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
}
}
// BenchmarkMultipleParameters benchmarks multiple parameter route performance
func BenchmarkMultipleParameters(b *testing.B) {
router := New()
router.GET("/users/:userId/posts/:postId/comments/:commentId", func(w http.ResponseWriter, r *http.Request) {
params := ParamsFromContext(r)
_ = params["userId"] + params["postId"] + params["commentId"] // Use parameters
w.WriteHeader(http.StatusOK)
})
req := httptest.NewRequest("GET", "/users/123/posts/456/comments/789", nil)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
}
}
// BenchmarkMiddleware benchmarks middleware performance
func BenchmarkMiddleware(b *testing.B) {
router := New()
// Add multiple middleware layers
router.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Middleware-1", "1")
next.ServeHTTP(w, r)
})
})
router.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Middleware-2", "2")
next.ServeHTTP(w, r)
})
})
router.GET("/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
req := httptest.NewRequest("GET", "/", nil)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
}
}
// BenchmarkRouteGroups benchmarks route group performance
func BenchmarkRouteGroups(b *testing.B) {
router := New()
api := router.Group("/api", func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-API-Version", "v1")
next.ServeHTTP(w, r)
})
})
api.GET("/users", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
req := httptest.NewRequest("GET", "/api/users", nil)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
}
}
// BenchmarkNotFound benchmarks 404 performance
func BenchmarkNotFound(b *testing.B) {
router := New()
router.GET("/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
req := httptest.NewRequest("GET", "/nonexistent", nil)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
}
}
// BenchmarkMethodNotAllowed benchmarks 405 performance
func BenchmarkMethodNotAllowed(b *testing.B) {
router := New()
router.GET("/resource", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
req := httptest.NewRequest("POST", "/resource", nil)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
}
}
// BenchmarkLargeRouter benchmarks performance with many routes
func BenchmarkLargeRouter(b *testing.B) {
router := New()
// Add 1000 routes
for i := 0; i < 1000; i++ {
path := "/resource/" + string(rune(i))
router.GET(path, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
}
// Test the last route
req := httptest.NewRequest("GET", "/resource/999", nil)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
}
}

469
sux.go
View File

@@ -1,128 +1,433 @@
package sux
import (
"context"
"net/http"
"sync"
)
const (
nomatch uint8 = iota
noMatch nodeType = iota
static
param
wildcard
)
type (
nodeType uint8
// Params holds route parameters extracted from the URL
Params map[string]string
// MiddlewareFunc represents a middleware function
MiddlewareFunc func(http.Handler) http.Handler
node struct {
term byte
ntype uint8
handler http.HandlerFunc
child []*node
path string
nodeType nodeType
handler http.HandlerFunc
children []*node
paramChild *node // For :param routes
wildcardChild *node // For *param routes
paramName string // Name of the parameter
middleware []MiddlewareFunc
}
Router struct {
route map[string]*node
trees map[string]*node // One tree per HTTP method
middleware []MiddlewareFunc
notFoundHandler http.HandlerFunc
methodNotAllowedHandler http.HandlerFunc
paramsPool sync.Pool
prefix string // For route groups
}
)
var root *Router
// Context key for storing route parameters
type contextKey string
func makenode(term byte) *node {
const paramsKey contextKey = "params"
// ParamsFromContext extracts route parameters from the request context
func ParamsFromContext(r *http.Request) Params {
if params, ok := r.Context().Value(paramsKey).(Params); ok {
return params
}
return nil
}
// makeNode creates a new node with the given path
func makeNode(path string) *node {
return &node{
term: term,
child: make([]*node, 0),
path: path,
children: make([]*node, 0),
}
}
func (n *node) addchild(term byte) *node {
if fn := n.findchild(term); fn == nil {
nn := makenode(term)
n.child = append(n.child, nn)
return nn
// addChild adds a child node for the given path segment
func (n *node) addChild(path string) *node {
if child := n.findChild(path); child == nil {
newNode := makeNode(path)
n.children = append(n.children, newNode)
return newNode
} else {
return fn
return child
}
}
func (n *node) maketree(word []byte, handler http.HandlerFunc) {
m := n
for i, l := 1, len(word); i < l; i++ {
m = m.addchild(word[i])
// addParamChild adds a parameter child node
func (n *node) addParamChild(paramName string) *node {
if n.paramChild == nil {
n.paramChild = makeNode(":" + paramName)
n.paramChild.nodeType = param
n.paramChild.paramName = paramName
}
m.ntype = static
m.handler = handler
return n.paramChild
}
func (n *node) findchild(term byte) *node {
for _, v := range n.child {
if v.term == term {
return v
// addWildcardChild adds a wildcard child node
func (n *node) addWildcardChild(paramName string) *node {
if n.wildcardChild == nil {
n.wildcardChild = makeNode("*" + paramName)
n.wildcardChild.nodeType = wildcard
n.wildcardChild.paramName = paramName
}
return n.wildcardChild
}
// findChild finds a static child node by path
func (n *node) findChild(path string) *node {
for _, child := range n.children {
if child.path == path {
return child
}
}
return nil
}
func (n *node) find(word string) *node {
ss := []byte(word)
m := n
for i, l := 1, len(ss); i < l; i++ {
m = m.findchild(ss[i])
if m == nil {
return nil
// parsePath splits a path into segments and identifies parameters
func parsePath(path string) []string {
if path == "" || path == "/" {
return []string{}
}
segments := make([]string, 0)
start := 1 // Skip the leading slash
for i := 1; i < len(path); i++ {
if path[i] == '/' {
if start < i {
segments = append(segments, path[start:i])
}
start = i + 1
}
}
return m
if start < len(path) {
segments = append(segments, path[start:])
}
return segments
}
func (n *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if m, exists := n.route[r.Method]; exists {
m = m.find(r.URL.Path)
if m == nil {
http.NotFound(w, r)
return
}
if m.ntype == nomatch {
http.NotFound(w, r)
return
}
m.handler(w, r)
// isParam checks if a path segment is a parameter (:param)
func isParam(segment string) bool {
return len(segment) > 0 && segment[0] == ':'
}
// isWildcard checks if a path segment is a wildcard (*param)
func isWildcard(segment string) bool {
return len(segment) > 0 && segment[0] == '*'
}
// getParamName extracts the parameter name from a segment
func getParamName(segment string) string {
if len(segment) > 1 {
return segment[1:]
}
return ""
}
// addRoute adds a route to the tree
func (n *node) addRoute(segments []string, handler http.HandlerFunc) {
n.addRouteWithMiddleware(segments, handler, nil)
}
// addRouteWithMiddleware adds a route to the tree with middleware
func (n *node) addRouteWithMiddleware(segments []string, handler http.HandlerFunc, middleware []MiddlewareFunc) {
current := n
// Handle root path ("/") case
if len(segments) == 0 {
n.nodeType = static
n.handler = handler
n.middleware = middleware
return
}
for i, segment := range segments {
isLast := i == len(segments)-1
if isParam(segment) {
paramName := getParamName(segment)
current = current.addParamChild(paramName)
} else if isWildcard(segment) {
paramName := getParamName(segment)
current = current.addWildcardChild(paramName)
} else {
current = current.addChild(segment)
}
if isLast {
current.nodeType = static
current.handler = handler
current.middleware = middleware
}
}
}
func (n *Router) GET(path string, handler http.HandlerFunc) *Router {
n.route["GET"].maketree([]byte(path), handler)
return n
}
func (n *Router) POST(path string, handler http.HandlerFunc) *Router {
n.route["POST"].maketree([]byte(path), handler)
return n
}
func (n *Router) PUT(path string, handler http.HandlerFunc) *Router {
n.route["PUT"].maketree([]byte(path), handler)
return n
}
func (n *Router) PATCH(path string, handler http.HandlerFunc) *Router {
n.route["PATCH"].maketree([]byte(path), handler)
return n
}
func (n *Router) DELETE(path string, handler http.HandlerFunc) *Router {
n.route["DELETE"].maketree([]byte(path), handler)
return n
}
func (n *Router) OPTIONS(path string, handler http.HandlerFunc) *Router {
n.route["OPTIONS"].maketree([]byte(path), handler)
return n
}
func (n *Router) HEAD(path string, handler http.HandlerFunc) *Router {
n.route["HEAD"].maketree([]byte(path), handler)
return n
// find searches for a route matching the given path
func (n *node) find(segments []string, params Params) *node {
current := n
for i, segment := range segments {
isLast := i == len(segments)-1
// Try static match first (fastest path)
if child := current.findChild(segment); child != nil {
if isLast && child.nodeType == static {
return child
}
current = child
continue
}
// Try parameter match
if current.paramChild != nil {
if params != nil {
params[current.paramChild.paramName] = segment
}
if isLast && current.paramChild.nodeType == param {
return current.paramChild
}
current = current.paramChild
continue
}
// Try wildcard match (catches everything)
if current.wildcardChild != nil {
if params != nil {
// Wildcard captures the rest of the path
wildcardValue := segment
if i < len(segments)-1 {
for j := i + 1; j < len(segments); j++ {
wildcardValue += "/" + segments[j]
}
}
params[current.wildcardChild.paramName] = wildcardValue
}
return current.wildcardChild
}
return nil // No match found
}
if current.nodeType == static {
return current
}
return nil
}
// ServeHTTP implements the http.Handler interface
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if tree, exists := r.trees[req.Method]; exists {
segments := parsePath(req.URL.Path)
// Get params from pool for better performance
params := r.paramsPool.Get().(Params)
for k := range params {
delete(params, k)
}
node := tree.find(segments, params)
if node == nil || node.handler == nil {
// Check if the path exists for other methods to return 405 instead of 404
hasOtherMethod := false
for method, otherTree := range r.trees {
if method != req.Method {
if otherNode := otherTree.find(segments, nil); otherNode != nil {
hasOtherMethod = true
break
}
}
}
// Return params to pool
r.paramsPool.Put(params)
if hasOtherMethod {
r.handleMethodNotAllowed(w, req)
} else {
r.handleNotFound(w, req)
}
return
}
// Add params to request context
ctx := context.WithValue(req.Context(), paramsKey, params)
req = req.WithContext(ctx)
// Create handler chain with middleware
handler := http.Handler(http.HandlerFunc(node.handler))
// Apply node-specific middleware first, then router middleware
for i := len(node.middleware) - 1; i >= 0; i-- {
handler = node.middleware[i](handler)
}
for i := len(r.middleware) - 1; i >= 0; i-- {
handler = r.middleware[i](handler)
}
// Defer returning params to pool after handler completes
defer func() {
r.paramsPool.Put(params)
}()
handler.ServeHTTP(w, req)
return
}
r.handleMethodNotAllowed(w, req)
}
// handleNotFound handles 404 responses
func (r *Router) handleNotFound(w http.ResponseWriter, req *http.Request) {
if r.notFoundHandler != nil {
r.notFoundHandler(w, req)
} else {
http.NotFound(w, req)
}
}
// handleMethodNotAllowed handles 405 responses
func (r *Router) handleMethodNotAllowed(w http.ResponseWriter, req *http.Request) {
if r.methodNotAllowedHandler != nil {
r.methodNotAllowedHandler(w, req)
} else {
w.WriteHeader(http.StatusMethodNotAllowed)
w.Write([]byte("Method Not Allowed"))
}
}
// Use adds global middleware to the router
func (r *Router) Use(middleware ...MiddlewareFunc) {
r.middleware = append(r.middleware, middleware...)
}
// Group creates a new route group with a prefix and optional middleware
func (r *Router) Group(prefix string, middleware ...MiddlewareFunc) *Router {
// Ensure prefix starts with /
if len(prefix) > 0 && prefix[0] != '/' {
prefix = "/" + prefix
}
return &Router{
trees: r.trees, // Share the same trees
middleware: append(r.middleware, middleware...),
notFoundHandler: r.notFoundHandler,
methodNotAllowedHandler: r.methodNotAllowedHandler,
paramsPool: r.paramsPool,
prefix: r.prefix + prefix,
}
}
// NotFound sets a custom 404 handler
func (r *Router) NotFound(handler http.HandlerFunc) {
r.notFoundHandler = handler
}
// MethodNotAllowed sets a custom 405 handler
func (r *Router) MethodNotAllowed(handler http.HandlerFunc) {
r.methodNotAllowedHandler = handler
}
// addRoute is a helper method to add routes for different HTTP methods
func (r *Router) addRoute(method, path string, handler http.HandlerFunc) *Router {
if r.trees == nil {
r.trees = make(map[string]*node)
}
// Find the root tree for this method (from the main router or create new)
var rootTree *node
if tree, exists := r.trees[method]; exists {
rootTree = tree
} else {
rootTree = makeNode("/")
r.trees[method] = rootTree
}
// Apply group prefix if present
fullPath := r.prefix + path
segments := parsePath(fullPath)
rootTree.addRouteWithMiddleware(segments, handler, r.middleware)
return r
}
// GET adds a GET route
func (r *Router) GET(path string, handler http.HandlerFunc) *Router {
return r.addRoute("GET", path, handler)
}
// POST adds a POST route
func (r *Router) POST(path string, handler http.HandlerFunc) *Router {
return r.addRoute("POST", path, handler)
}
// PUT adds a PUT route
func (r *Router) PUT(path string, handler http.HandlerFunc) *Router {
return r.addRoute("PUT", path, handler)
}
// PATCH adds a PATCH route
func (r *Router) PATCH(path string, handler http.HandlerFunc) *Router {
return r.addRoute("PATCH", path, handler)
}
// DELETE adds a DELETE route
func (r *Router) DELETE(path string, handler http.HandlerFunc) *Router {
return r.addRoute("DELETE", path, handler)
}
// OPTIONS adds an OPTIONS route
func (r *Router) OPTIONS(path string, handler http.HandlerFunc) *Router {
return r.addRoute("OPTIONS", path, handler)
}
// HEAD adds a HEAD route
func (r *Router) HEAD(path string, handler http.HandlerFunc) *Router {
return r.addRoute("HEAD", path, handler)
}
// New creates a new Router instance
func New() *Router {
root = &Router{route: make(map[string]*node)}
root.route["GET"] = makenode([]byte("/")[0])
root.route["POST"] = makenode([]byte("/")[0])
root.route["PUT"] = makenode([]byte("/")[0])
root.route["PATCH"] = makenode([]byte("/")[0])
root.route["DELETE"] = makenode([]byte("/")[0])
root.route["OPTIONS"] = makenode([]byte("/")[0])
root.route["HEAD"] = makenode([]byte("/")[0])
return root
router := &Router{
trees: make(map[string]*node),
middleware: make([]MiddlewareFunc, 0),
paramsPool: sync.Pool{
New: func() interface{} {
return make(Params)
},
},
prefix: "",
}
// Initialize trees for common HTTP methods
methods := []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"}
for _, method := range methods {
router.trees[method] = makeNode("/")
}
return router
}

329
sux_test.go Normal file
View File

@@ -0,0 +1,329 @@
package sux
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestBasicRouting(t *testing.T) {
router := New()
router.GET("/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("home"))
})
router.GET("/about", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("about"))
})
// Test home route
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
if w.Body.String() != "home" {
t.Errorf("Expected body 'home', got '%s'", w.Body.String())
}
// Test about route
req = httptest.NewRequest("GET", "/about", nil)
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
if w.Body.String() != "about" {
t.Errorf("Expected body 'about', got '%s'", w.Body.String())
}
}
func TestParameterRouting(t *testing.T) {
router := New()
router.GET("/users/:id", func(w http.ResponseWriter, r *http.Request) {
params := ParamsFromContext(r)
if params == nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
id := params["id"]
w.WriteHeader(http.StatusOK)
w.Write([]byte("user " + id))
})
// Test parameter route
req := httptest.NewRequest("GET", "/users/123", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
if w.Body.String() != "user 123" {
t.Errorf("Expected body 'user 123', got '%s'", w.Body.String())
}
}
func TestWildcardRouting(t *testing.T) {
router := New()
router.GET("/files/*path", func(w http.ResponseWriter, r *http.Request) {
params := ParamsFromContext(r)
if params == nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
path := params["path"]
w.WriteHeader(http.StatusOK)
w.Write([]byte("file: " + path))
})
// Test wildcard route
req := httptest.NewRequest("GET", "/files/docs/readme.txt", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
if w.Body.String() != "file: docs/readme.txt" {
t.Errorf("Expected body 'file: docs/readme.txt', got '%s'", w.Body.String())
}
}
func TestMethodNotAllowed(t *testing.T) {
router := New()
router.GET("/resource", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("GET resource"))
})
// Test POST to GET-only route
req := httptest.NewRequest("POST", "/resource", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("Expected status 405, got %d", w.Code)
}
}
func TestNotFound(t *testing.T) {
router := New()
router.GET("/existing", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Test non-existent route
req := httptest.NewRequest("GET", "/nonexistent", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("Expected status 404, got %d", w.Code)
}
}
func TestMiddleware(t *testing.T) {
router := New()
// Add middleware that sets a header
router.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Middleware", "applied")
next.ServeHTTP(w, r)
})
})
router.GET("/test", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("test"))
})
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
if w.Header().Get("X-Middleware") != "applied" {
t.Errorf("Expected middleware header, got '%s'", w.Header().Get("X-Middleware"))
}
}
func TestRouteGroups(t *testing.T) {
router := New()
// Create API group with middleware
api := router.Group("/api", func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-API-Version", "v1")
next.ServeHTTP(w, r)
})
})
api.GET("/users", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("users list"))
})
// Test grouped route
req := httptest.NewRequest("GET", "/api/users", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
if w.Body.String() != "users list" {
t.Errorf("Expected body 'users list', got '%s'", w.Body.String())
}
if w.Header().Get("X-API-Version") != "v1" {
t.Errorf("Expected API version header, got '%s'", w.Header().Get("X-API-Version"))
}
}
func TestCustomHandlers(t *testing.T) {
router := New()
// Custom 404 handler
router.NotFound(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("custom 404"))
})
// Custom 405 handler
router.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusMethodNotAllowed)
w.Write([]byte("custom 405"))
})
router.GET("/existing", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Test custom 404
req := httptest.NewRequest("GET", "/nonexistent", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("Expected status 404, got %d", w.Code)
}
if w.Body.String() != "custom 404" {
t.Errorf("Expected body 'custom 404', got '%s'", w.Body.String())
}
// Test custom 405
req = httptest.NewRequest("POST", "/existing", nil)
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("Expected status 405, got %d", w.Code)
}
if w.Body.String() != "custom 405" {
t.Errorf("Expected body 'custom 405', got '%s'", w.Body.String())
}
}
func TestMultipleParameters(t *testing.T) {
router := New()
router.GET("/users/:userId/posts/:postId", func(w http.ResponseWriter, r *http.Request) {
params := ParamsFromContext(r)
if params == nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
userId := params["userId"]
postId := params["postId"]
w.WriteHeader(http.StatusOK)
w.Write([]byte("user " + userId + " post " + postId))
})
req := httptest.NewRequest("GET", "/users/123/posts/456", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
if w.Body.String() != "user 123 post 456" {
t.Errorf("Expected body 'user 123 post 456', got '%s'", w.Body.String())
}
}
func TestThreadSafety(t *testing.T) {
router := New()
router.GET("/test", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("test"))
})
// Create multiple goroutines to test concurrent access
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func() {
for j := 0; j < 100; j++ {
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
}
done <- true
}()
}
// Wait for all goroutines to complete
for i := 0; i < 10; i++ {
<-done
}
}
func TestRouterInstanceIsolation(t *testing.T) {
// Create two separate router instances
router1 := New()
router2 := New()
router1.GET("/route", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("router1"))
})
router2.GET("/route", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("router2"))
})
// Test router1
req := httptest.NewRequest("GET", "/route", nil)
w := httptest.NewRecorder()
router1.ServeHTTP(w, req)
if w.Body.String() != "router1" {
t.Errorf("Expected body 'router1', got '%s'", w.Body.String())
}
// Test router2
w = httptest.NewRecorder()
router2.ServeHTTP(w, req)
if w.Body.String() != "router2" {
t.Errorf("Expected body 'router2', got '%s'", w.Body.String())
}
}