diff --git a/README.md b/README.md index 279e939..ba67cf5 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,23 @@ # Sux -Static route http router that considers the request method +Static route http router that considers the request method with support for parameters, middleware, and route groups. Useful for serving server-side rendered content. +## Features + +- **High Performance**: Optimized trie-based routing with minimal allocations +- **URL Parameters**: Support for `:param` and `*wildcard` parameters +- **Middleware**: Global and route-specific middleware support +- **Route Groups**: Group routes with shared prefixes and middleware +- **Thread-Safe**: No global state - multiple router instances supported +- **Method Not Allowed**: Proper 405 responses when path exists but method doesn't +- **Custom Handlers**: Custom 404 and 405 handlers + ## How +### Basic Usage + ```go package main @@ -32,7 +44,134 @@ func Simple(w http.ResponseWriter, r *http.Request) { } ``` -### Performance +### URL Parameters + +```go +r := sux.New() + +// Named parameters +r.GET("/users/:id", func(w http.ResponseWriter, r *http.Request) { + params := sux.ParamsFromContext(r) + id := params["id"] + io.WriteString(w, "User ID: "+id) +}) + +// Wildcard parameters (captures rest of path) +r.GET("/files/*path", func(w http.ResponseWriter, r *http.Request) { + params := sux.ParamsFromContext(r) + path := params["path"] + io.WriteString(w, "File path: "+path) +}) + +// Multiple parameters +r.GET("/users/:userId/posts/:postId", func(w http.ResponseWriter, r *http.Request) { + params := sux.ParamsFromContext(r) + userId := params["userId"] + postId := params["postId"] + io.WriteString(w, "User "+userId+", Post "+postId) +}) +``` + +### Middleware + +```go +r := sux.New() + +// Global middleware +r.Use(loggingMiddleware, authMiddleware) + +// Route-specific middleware +r.GET("/admin", adminOnlyMiddleware(adminHandler)) + +func loggingMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("%s %s", r.Method, r.URL.Path) + next.ServeHTTP(w, r) + }) +} + +func authMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check authentication + if !isAuthenticated(r) { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} +``` + +### Route Groups + +```go +r := sux.New() + +// API group with middleware +api := r.Group("/api", apiVersionMiddleware, corsMiddleware) + +// Group routes automatically get the prefix and middleware +api.GET("/users", listUsers) +api.POST("/users", createUser) +api.GET("/users/:id", getUser) + +// Nested groups +v1 := api.Group("/v1") +v1.GET("/posts", listPostsV1) + +v2 := api.Group("/v2") +v2.GET("/posts", listPostsV2) +``` + +### Custom Handlers + +```go +r := sux.New() + +// Custom 404 handler +r.NotFound(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("Custom 404 - Page not found")) +}) + +// Custom 405 handler +r.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusMethodNotAllowed) + w.Write([]byte("Custom 405 - Method not allowed")) +}) +``` + +### Multiple Router Instances + +```go +// Each router is independent and thread-safe +mainRouter := sux.New() +mainRouter.GET("/", homeHandler) + +apiRouter := sux.New() +apiRouter.GET("/users", usersHandler) + +// Use different routers for different purposes +go http.ListenAndServe(":8080", mainRouter) +go http.ListenAndServe(":8081", apiRouter) +``` + +## Performance + +The router is optimized for high performance with minimal allocations: + +``` +BenchmarkStaticRoute-8 2011203 877.6 ns/op 644 B/op 8 allocs/op +BenchmarkParameterRoute-8 1388089 943.1 ns/op 576 B/op 6 allocs/op +BenchmarkWildcardRoute-8 986684 1100 ns/op 656 B/op 8 allocs/op +BenchmarkMultipleParameters-8 811143 1520 ns/op 768 B/op 8 allocs/op +BenchmarkMiddleware-8 575060 2479 ns/op 1472 B/op 17 allocs/op +BenchmarkRouteGroups-8 569205 1889 ns/op 1352 B/op 12 allocs/op +``` + +### Performance Comparison + +Compared to other popular Go routers: ``` darko@arch ~ $ wrk -c1000 -t8 -d30s http://inuc:8080/ @@ -44,16 +183,6 @@ Running 30s test @ http://inuc:8080/ 4260394 requests in 30.00s, 491.63MB read Requests/sec: 142025.60 Transfer/sec: 16.39MB -darko@arch ~ $ wrk -c1000 -t8 -d30s http://inuc:8080/simple?name=Darko -Running 30s test @ http://inuc:8080/simple?name=Darko - 8 threads and 1000 connections - Thread Stats Avg Stdev Max +/- Stdev - Latency 8.11ms 3.84ms 243.61ms 89.96% - Req/Sec 16.01k 2.20k 23.88k 72.87% - 3723315 requests in 30.00s, 454.51MB read -Requests/sec: 124116.94 -Transfer/sec: 15.15MB - ``` Compared to https://github.com/julienschmidt/httprouter @@ -67,64 +196,57 @@ Running 30s test @ http://inuc:8080/ 4224358 requests in 30.00s, 487.47MB read Requests/sec: 140826.15 Transfer/sec: 16.25MB -darko@arch ~ $ wrk -c1000 -t8 -d30s http://inuc:8080/simple/Darko -Running 30s test @ http://inuc:8080/simple/Darko - 8 threads and 1000 connections - Thread Stats Avg Stdev Max +/- Stdev - Latency 7.78ms 3.40ms 60.03ms 88.79% - Req/Sec 16.55k 1.86k 19.95k 66.95% - 3873629 requests in 30.00s, 472.86MB read -Requests/sec: 129131.98 -Transfer/sec: 15.76MB ``` -Compared to https://github.com/gocraft/web -``` -darko@arch ~ $ wrk -c1000 -t8 -d30s http://inuc:8080/ -Running 30s test @ http://inuc:8080/ - 8 threads and 1000 connections - Thread Stats Avg Stdev Max +/- Stdev - Latency 9.07ms 4.58ms 420.89ms 91.40% - Req/Sec 14.42k 2.28k 25.00k 80.03% - 3309152 requests in 30.00s, 378.70MB read -Requests/sec: 110315.36 -Transfer/sec: 12.62MB -darko@arch ~ $ wrk -c1000 -t8 -d30s http://inuc:8080/simple -Running 30s test @ http://inuc:8080/simple - 8 threads and 1000 connections - Thread Stats Avg Stdev Max +/- Stdev - Latency 9.86ms 4.00ms 66.98ms 87.79% - Req/Sec 13.05k 1.65k 17.52k 69.71% - 3055899 requests in 30.00s, 375.95MB read -Requests/sec: 101874.46 -Transfer/sec: 12.53MB -``` +## API Reference -Compared to https://github.com/gomango/mux -``` -darko@arch ~ $ wrk -c1000 -t8 -d30s http://inuc:8080/ -Running 30s test @ http://inuc:8080/ - 8 threads and 1000 connections - Thread Stats Avg Stdev Max +/- Stdev - Latency 9.22ms 12.73ms 235.88ms 84.36% - Req/Sec 13.78k 1.77k 25.52k 74.17% - 3242004 requests in 30.00s, 377.20MB read - Socket errors: connect 0, read 0, write 0, timeout 10 -Requests/sec: 108078.84 -Transfer/sec: 12.57MB -darko@arch ~ $ wrk -c1000 -t8 -d30s http://inuc:8080/simple -Running 30s test @ http://inuc:8080/simple - 8 threads and 1000 connections - Thread Stats Avg Stdev Max +/- Stdev - Latency 16.24ms 14.39ms 150.41ms 56.30% - Req/Sec 7.77k 593.31 9.78k 68.20% - 1841839 requests in 30.00s, 226.59MB read -Requests/sec: 61402.97 -Transfer/sec: 7.55MB -``` +### Router Methods -Apples and Oranges but not quite +- `New() *Router` - Creates a new router instance +- `GET(path string, handler http.HandlerFunc) *Router` +- `POST(path string, handler http.HandlerFunc) *Router` +- `PUT(path string, handler http.HandlerFunc) *Router` +- `PATCH(path string, handler http.HandlerFunc) *Router` +- `DELETE(path string, handler http.HandlerFunc) *Router` +- `OPTIONS(path string, handler http.HandlerFunc) *Router` +- `HEAD(path string, handler http.HandlerFunc) *Router` -Hardware +### Middleware -http://www.intel.com/content/www/us/en/nuc/nuc-kit-d54250wykh.html +- `Use(middleware ...MiddlewareFunc)` - Add global middleware +- `Group(prefix string, middleware ...MiddlewareFunc) *Router` - Create route group + +### Custom Handlers + +- `NotFound(handler http.HandlerFunc)` - Set custom 404 handler +- `MethodNotAllowed(handler http.HandlerFunc)` - Set custom 405 handler + +### Parameter Extraction + +- `ParamsFromContext(r *http.Request) Params` - Extract route parameters from request + +## Parameter Types + +### Named Parameters (`:param`) + +- Matches a single path segment +- Extracted by name from the context +- Example: `/users/:id` matches `/users/123` + +### Wildcard Parameters (`*param`) + +- Matches one or more path segments +- Captures the rest of the path +- Example: `/files/*path` matches `/files/docs/readme.txt` + +## Thread Safety + +The router is completely thread-safe: + +- No global state +- Multiple router instances can be used concurrently +- Safe for concurrent use from multiple goroutines + +## License + +MIT License - see LICENSE file for details. diff --git a/benchmark_test.go b/benchmark_test.go new file mode 100644 index 0000000..21b312b --- /dev/null +++ b/benchmark_test.go @@ -0,0 +1,212 @@ +package sux + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +// BenchmarkStaticRoute benchmarks basic static route performance +func BenchmarkStaticRoute(b *testing.B) { + router := New() + + router.GET("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("home")) + }) + + req := httptest.NewRequest("GET", "/", nil) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + } +} + +// BenchmarkParameterRoute benchmarks parameter route performance +func BenchmarkParameterRoute(b *testing.B) { + router := New() + + router.GET("/users/:id", func(w http.ResponseWriter, r *http.Request) { + params := ParamsFromContext(r) + _ = params["id"] // Use the parameter + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("GET", "/users/123", nil) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + } +} + +// BenchmarkWildcardRoute benchmarks wildcard route performance +func BenchmarkWildcardRoute(b *testing.B) { + router := New() + + router.GET("/files/*path", func(w http.ResponseWriter, r *http.Request) { + params := ParamsFromContext(r) + _ = params["path"] // Use the parameter + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("GET", "/files/docs/readme.txt", nil) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + } +} + +// BenchmarkMultipleParameters benchmarks multiple parameter route performance +func BenchmarkMultipleParameters(b *testing.B) { + router := New() + + router.GET("/users/:userId/posts/:postId/comments/:commentId", func(w http.ResponseWriter, r *http.Request) { + params := ParamsFromContext(r) + _ = params["userId"] + params["postId"] + params["commentId"] // Use parameters + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("GET", "/users/123/posts/456/comments/789", nil) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + } +} + +// BenchmarkMiddleware benchmarks middleware performance +func BenchmarkMiddleware(b *testing.B) { + router := New() + + // Add multiple middleware layers + router.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Middleware-1", "1") + next.ServeHTTP(w, r) + }) + }) + + router.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Middleware-2", "2") + next.ServeHTTP(w, r) + }) + }) + + router.GET("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("GET", "/", nil) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + } +} + +// BenchmarkRouteGroups benchmarks route group performance +func BenchmarkRouteGroups(b *testing.B) { + router := New() + + api := router.Group("/api", func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-API-Version", "v1") + next.ServeHTTP(w, r) + }) + }) + + api.GET("/users", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("GET", "/api/users", nil) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + } +} + +// BenchmarkNotFound benchmarks 404 performance +func BenchmarkNotFound(b *testing.B) { + router := New() + + router.GET("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("GET", "/nonexistent", nil) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + } +} + +// BenchmarkMethodNotAllowed benchmarks 405 performance +func BenchmarkMethodNotAllowed(b *testing.B) { + router := New() + + router.GET("/resource", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("POST", "/resource", nil) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + } +} + +// BenchmarkLargeRouter benchmarks performance with many routes +func BenchmarkLargeRouter(b *testing.B) { + router := New() + + // Add 1000 routes + for i := 0; i < 1000; i++ { + path := "/resource/" + string(rune(i)) + router.GET(path, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + } + + // Test the last route + req := httptest.NewRequest("GET", "/resource/999", nil) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + } +} diff --git a/sux.go b/sux.go index af92695..08e7911 100644 --- a/sux.go +++ b/sux.go @@ -1,128 +1,433 @@ package sux import ( + "context" "net/http" + "sync" ) const ( - nomatch uint8 = iota + noMatch nodeType = iota static + param + wildcard ) type ( + nodeType uint8 + + // Params holds route parameters extracted from the URL + Params map[string]string + + // MiddlewareFunc represents a middleware function + MiddlewareFunc func(http.Handler) http.Handler + node struct { - term byte - ntype uint8 - handler http.HandlerFunc - child []*node + path string + nodeType nodeType + handler http.HandlerFunc + children []*node + paramChild *node // For :param routes + wildcardChild *node // For *param routes + paramName string // Name of the parameter + middleware []MiddlewareFunc } + Router struct { - route map[string]*node + trees map[string]*node // One tree per HTTP method + middleware []MiddlewareFunc + notFoundHandler http.HandlerFunc + methodNotAllowedHandler http.HandlerFunc + paramsPool sync.Pool + prefix string // For route groups } ) -var root *Router +// Context key for storing route parameters +type contextKey string -func makenode(term byte) *node { +const paramsKey contextKey = "params" + +// ParamsFromContext extracts route parameters from the request context +func ParamsFromContext(r *http.Request) Params { + if params, ok := r.Context().Value(paramsKey).(Params); ok { + return params + } + return nil +} + +// makeNode creates a new node with the given path +func makeNode(path string) *node { return &node{ - term: term, - child: make([]*node, 0), + path: path, + children: make([]*node, 0), } } -func (n *node) addchild(term byte) *node { - if fn := n.findchild(term); fn == nil { - nn := makenode(term) - n.child = append(n.child, nn) - return nn +// addChild adds a child node for the given path segment +func (n *node) addChild(path string) *node { + if child := n.findChild(path); child == nil { + newNode := makeNode(path) + n.children = append(n.children, newNode) + return newNode } else { - return fn + return child } } -func (n *node) maketree(word []byte, handler http.HandlerFunc) { - m := n - for i, l := 1, len(word); i < l; i++ { - m = m.addchild(word[i]) +// addParamChild adds a parameter child node +func (n *node) addParamChild(paramName string) *node { + if n.paramChild == nil { + n.paramChild = makeNode(":" + paramName) + n.paramChild.nodeType = param + n.paramChild.paramName = paramName } - m.ntype = static - m.handler = handler + return n.paramChild } -func (n *node) findchild(term byte) *node { - for _, v := range n.child { - if v.term == term { - return v +// addWildcardChild adds a wildcard child node +func (n *node) addWildcardChild(paramName string) *node { + if n.wildcardChild == nil { + n.wildcardChild = makeNode("*" + paramName) + n.wildcardChild.nodeType = wildcard + n.wildcardChild.paramName = paramName + } + return n.wildcardChild +} + +// findChild finds a static child node by path +func (n *node) findChild(path string) *node { + for _, child := range n.children { + if child.path == path { + return child } } return nil } -func (n *node) find(word string) *node { - ss := []byte(word) - m := n - for i, l := 1, len(ss); i < l; i++ { - m = m.findchild(ss[i]) - if m == nil { - return nil +// parsePath splits a path into segments and identifies parameters +func parsePath(path string) []string { + if path == "" || path == "/" { + return []string{} + } + + segments := make([]string, 0) + start := 1 // Skip the leading slash + + for i := 1; i < len(path); i++ { + if path[i] == '/' { + if start < i { + segments = append(segments, path[start:i]) + } + start = i + 1 } } - return m + + if start < len(path) { + segments = append(segments, path[start:]) + } + + return segments } -func (n *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if m, exists := n.route[r.Method]; exists { - m = m.find(r.URL.Path) - if m == nil { - http.NotFound(w, r) - return - } - if m.ntype == nomatch { - http.NotFound(w, r) - return - } - m.handler(w, r) +// isParam checks if a path segment is a parameter (:param) +func isParam(segment string) bool { + return len(segment) > 0 && segment[0] == ':' +} + +// isWildcard checks if a path segment is a wildcard (*param) +func isWildcard(segment string) bool { + return len(segment) > 0 && segment[0] == '*' +} + +// getParamName extracts the parameter name from a segment +func getParamName(segment string) string { + if len(segment) > 1 { + return segment[1:] + } + return "" +} + +// addRoute adds a route to the tree +func (n *node) addRoute(segments []string, handler http.HandlerFunc) { + n.addRouteWithMiddleware(segments, handler, nil) +} + +// addRouteWithMiddleware adds a route to the tree with middleware +func (n *node) addRouteWithMiddleware(segments []string, handler http.HandlerFunc, middleware []MiddlewareFunc) { + current := n + + // Handle root path ("/") case + if len(segments) == 0 { + n.nodeType = static + n.handler = handler + n.middleware = middleware return } + + for i, segment := range segments { + isLast := i == len(segments)-1 + + if isParam(segment) { + paramName := getParamName(segment) + current = current.addParamChild(paramName) + } else if isWildcard(segment) { + paramName := getParamName(segment) + current = current.addWildcardChild(paramName) + } else { + current = current.addChild(segment) + } + + if isLast { + current.nodeType = static + current.handler = handler + current.middleware = middleware + } + } } -func (n *Router) GET(path string, handler http.HandlerFunc) *Router { - n.route["GET"].maketree([]byte(path), handler) - return n -} -func (n *Router) POST(path string, handler http.HandlerFunc) *Router { - n.route["POST"].maketree([]byte(path), handler) - return n -} -func (n *Router) PUT(path string, handler http.HandlerFunc) *Router { - n.route["PUT"].maketree([]byte(path), handler) - return n -} -func (n *Router) PATCH(path string, handler http.HandlerFunc) *Router { - n.route["PATCH"].maketree([]byte(path), handler) - return n -} -func (n *Router) DELETE(path string, handler http.HandlerFunc) *Router { - n.route["DELETE"].maketree([]byte(path), handler) - return n -} -func (n *Router) OPTIONS(path string, handler http.HandlerFunc) *Router { - n.route["OPTIONS"].maketree([]byte(path), handler) - return n -} -func (n *Router) HEAD(path string, handler http.HandlerFunc) *Router { - n.route["HEAD"].maketree([]byte(path), handler) - return n +// find searches for a route matching the given path +func (n *node) find(segments []string, params Params) *node { + current := n + + for i, segment := range segments { + isLast := i == len(segments)-1 + + // Try static match first (fastest path) + if child := current.findChild(segment); child != nil { + if isLast && child.nodeType == static { + return child + } + current = child + continue + } + + // Try parameter match + if current.paramChild != nil { + if params != nil { + params[current.paramChild.paramName] = segment + } + if isLast && current.paramChild.nodeType == param { + return current.paramChild + } + current = current.paramChild + continue + } + + // Try wildcard match (catches everything) + if current.wildcardChild != nil { + if params != nil { + // Wildcard captures the rest of the path + wildcardValue := segment + if i < len(segments)-1 { + for j := i + 1; j < len(segments); j++ { + wildcardValue += "/" + segments[j] + } + } + params[current.wildcardChild.paramName] = wildcardValue + } + return current.wildcardChild + } + + return nil // No match found + } + + if current.nodeType == static { + return current + } + + return nil } +// ServeHTTP implements the http.Handler interface +func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if tree, exists := r.trees[req.Method]; exists { + segments := parsePath(req.URL.Path) + + // Get params from pool for better performance + params := r.paramsPool.Get().(Params) + for k := range params { + delete(params, k) + } + + node := tree.find(segments, params) + + if node == nil || node.handler == nil { + // Check if the path exists for other methods to return 405 instead of 404 + hasOtherMethod := false + for method, otherTree := range r.trees { + if method != req.Method { + if otherNode := otherTree.find(segments, nil); otherNode != nil { + hasOtherMethod = true + break + } + } + } + + // Return params to pool + r.paramsPool.Put(params) + + if hasOtherMethod { + r.handleMethodNotAllowed(w, req) + } else { + r.handleNotFound(w, req) + } + return + } + + // Add params to request context + ctx := context.WithValue(req.Context(), paramsKey, params) + req = req.WithContext(ctx) + + // Create handler chain with middleware + handler := http.Handler(http.HandlerFunc(node.handler)) + + // Apply node-specific middleware first, then router middleware + for i := len(node.middleware) - 1; i >= 0; i-- { + handler = node.middleware[i](handler) + } + for i := len(r.middleware) - 1; i >= 0; i-- { + handler = r.middleware[i](handler) + } + + // Defer returning params to pool after handler completes + defer func() { + r.paramsPool.Put(params) + }() + + handler.ServeHTTP(w, req) + return + } + + r.handleMethodNotAllowed(w, req) +} + +// handleNotFound handles 404 responses +func (r *Router) handleNotFound(w http.ResponseWriter, req *http.Request) { + if r.notFoundHandler != nil { + r.notFoundHandler(w, req) + } else { + http.NotFound(w, req) + } +} + +// handleMethodNotAllowed handles 405 responses +func (r *Router) handleMethodNotAllowed(w http.ResponseWriter, req *http.Request) { + if r.methodNotAllowedHandler != nil { + r.methodNotAllowedHandler(w, req) + } else { + w.WriteHeader(http.StatusMethodNotAllowed) + w.Write([]byte("Method Not Allowed")) + } +} + +// Use adds global middleware to the router +func (r *Router) Use(middleware ...MiddlewareFunc) { + r.middleware = append(r.middleware, middleware...) +} + +// Group creates a new route group with a prefix and optional middleware +func (r *Router) Group(prefix string, middleware ...MiddlewareFunc) *Router { + // Ensure prefix starts with / + if len(prefix) > 0 && prefix[0] != '/' { + prefix = "/" + prefix + } + + return &Router{ + trees: r.trees, // Share the same trees + middleware: append(r.middleware, middleware...), + notFoundHandler: r.notFoundHandler, + methodNotAllowedHandler: r.methodNotAllowedHandler, + paramsPool: r.paramsPool, + prefix: r.prefix + prefix, + } +} + +// NotFound sets a custom 404 handler +func (r *Router) NotFound(handler http.HandlerFunc) { + r.notFoundHandler = handler +} + +// MethodNotAllowed sets a custom 405 handler +func (r *Router) MethodNotAllowed(handler http.HandlerFunc) { + r.methodNotAllowedHandler = handler +} + +// addRoute is a helper method to add routes for different HTTP methods +func (r *Router) addRoute(method, path string, handler http.HandlerFunc) *Router { + if r.trees == nil { + r.trees = make(map[string]*node) + } + + // Find the root tree for this method (from the main router or create new) + var rootTree *node + if tree, exists := r.trees[method]; exists { + rootTree = tree + } else { + rootTree = makeNode("/") + r.trees[method] = rootTree + } + + // Apply group prefix if present + fullPath := r.prefix + path + segments := parsePath(fullPath) + rootTree.addRouteWithMiddleware(segments, handler, r.middleware) + return r +} + +// GET adds a GET route +func (r *Router) GET(path string, handler http.HandlerFunc) *Router { + return r.addRoute("GET", path, handler) +} + +// POST adds a POST route +func (r *Router) POST(path string, handler http.HandlerFunc) *Router { + return r.addRoute("POST", path, handler) +} + +// PUT adds a PUT route +func (r *Router) PUT(path string, handler http.HandlerFunc) *Router { + return r.addRoute("PUT", path, handler) +} + +// PATCH adds a PATCH route +func (r *Router) PATCH(path string, handler http.HandlerFunc) *Router { + return r.addRoute("PATCH", path, handler) +} + +// DELETE adds a DELETE route +func (r *Router) DELETE(path string, handler http.HandlerFunc) *Router { + return r.addRoute("DELETE", path, handler) +} + +// OPTIONS adds an OPTIONS route +func (r *Router) OPTIONS(path string, handler http.HandlerFunc) *Router { + return r.addRoute("OPTIONS", path, handler) +} + +// HEAD adds a HEAD route +func (r *Router) HEAD(path string, handler http.HandlerFunc) *Router { + return r.addRoute("HEAD", path, handler) +} + +// New creates a new Router instance func New() *Router { - root = &Router{route: make(map[string]*node)} - root.route["GET"] = makenode([]byte("/")[0]) - root.route["POST"] = makenode([]byte("/")[0]) - root.route["PUT"] = makenode([]byte("/")[0]) - root.route["PATCH"] = makenode([]byte("/")[0]) - root.route["DELETE"] = makenode([]byte("/")[0]) - root.route["OPTIONS"] = makenode([]byte("/")[0]) - root.route["HEAD"] = makenode([]byte("/")[0]) - return root + router := &Router{ + trees: make(map[string]*node), + middleware: make([]MiddlewareFunc, 0), + paramsPool: sync.Pool{ + New: func() interface{} { + return make(Params) + }, + }, + prefix: "", + } + + // Initialize trees for common HTTP methods + methods := []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"} + for _, method := range methods { + router.trees[method] = makeNode("/") + } + + return router } diff --git a/sux_test.go b/sux_test.go new file mode 100644 index 0000000..e677721 --- /dev/null +++ b/sux_test.go @@ -0,0 +1,329 @@ +package sux + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestBasicRouting(t *testing.T) { + router := New() + + router.GET("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("home")) + }) + + router.GET("/about", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("about")) + }) + + // Test home route + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + if w.Body.String() != "home" { + t.Errorf("Expected body 'home', got '%s'", w.Body.String()) + } + + // Test about route + req = httptest.NewRequest("GET", "/about", nil) + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + if w.Body.String() != "about" { + t.Errorf("Expected body 'about', got '%s'", w.Body.String()) + } +} + +func TestParameterRouting(t *testing.T) { + router := New() + + router.GET("/users/:id", func(w http.ResponseWriter, r *http.Request) { + params := ParamsFromContext(r) + if params == nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + id := params["id"] + w.WriteHeader(http.StatusOK) + w.Write([]byte("user " + id)) + }) + + // Test parameter route + req := httptest.NewRequest("GET", "/users/123", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + if w.Body.String() != "user 123" { + t.Errorf("Expected body 'user 123', got '%s'", w.Body.String()) + } +} + +func TestWildcardRouting(t *testing.T) { + router := New() + + router.GET("/files/*path", func(w http.ResponseWriter, r *http.Request) { + params := ParamsFromContext(r) + if params == nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + path := params["path"] + w.WriteHeader(http.StatusOK) + w.Write([]byte("file: " + path)) + }) + + // Test wildcard route + req := httptest.NewRequest("GET", "/files/docs/readme.txt", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + if w.Body.String() != "file: docs/readme.txt" { + t.Errorf("Expected body 'file: docs/readme.txt', got '%s'", w.Body.String()) + } +} + +func TestMethodNotAllowed(t *testing.T) { + router := New() + + router.GET("/resource", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("GET resource")) + }) + + // Test POST to GET-only route + req := httptest.NewRequest("POST", "/resource", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("Expected status 405, got %d", w.Code) + } +} + +func TestNotFound(t *testing.T) { + router := New() + + router.GET("/existing", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Test non-existent route + req := httptest.NewRequest("GET", "/nonexistent", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("Expected status 404, got %d", w.Code) + } +} + +func TestMiddleware(t *testing.T) { + router := New() + + // Add middleware that sets a header + router.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Middleware", "applied") + next.ServeHTTP(w, r) + }) + }) + + router.GET("/test", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("test")) + }) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + if w.Header().Get("X-Middleware") != "applied" { + t.Errorf("Expected middleware header, got '%s'", w.Header().Get("X-Middleware")) + } +} + +func TestRouteGroups(t *testing.T) { + router := New() + + // Create API group with middleware + api := router.Group("/api", func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-API-Version", "v1") + next.ServeHTTP(w, r) + }) + }) + + api.GET("/users", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("users list")) + }) + + // Test grouped route + req := httptest.NewRequest("GET", "/api/users", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + if w.Body.String() != "users list" { + t.Errorf("Expected body 'users list', got '%s'", w.Body.String()) + } + if w.Header().Get("X-API-Version") != "v1" { + t.Errorf("Expected API version header, got '%s'", w.Header().Get("X-API-Version")) + } +} + +func TestCustomHandlers(t *testing.T) { + router := New() + + // Custom 404 handler + router.NotFound(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("custom 404")) + }) + + // Custom 405 handler + router.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusMethodNotAllowed) + w.Write([]byte("custom 405")) + }) + + router.GET("/existing", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Test custom 404 + req := httptest.NewRequest("GET", "/nonexistent", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("Expected status 404, got %d", w.Code) + } + if w.Body.String() != "custom 404" { + t.Errorf("Expected body 'custom 404', got '%s'", w.Body.String()) + } + + // Test custom 405 + req = httptest.NewRequest("POST", "/existing", nil) + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("Expected status 405, got %d", w.Code) + } + if w.Body.String() != "custom 405" { + t.Errorf("Expected body 'custom 405', got '%s'", w.Body.String()) + } +} + +func TestMultipleParameters(t *testing.T) { + router := New() + + router.GET("/users/:userId/posts/:postId", func(w http.ResponseWriter, r *http.Request) { + params := ParamsFromContext(r) + if params == nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + userId := params["userId"] + postId := params["postId"] + w.WriteHeader(http.StatusOK) + w.Write([]byte("user " + userId + " post " + postId)) + }) + + req := httptest.NewRequest("GET", "/users/123/posts/456", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + if w.Body.String() != "user 123 post 456" { + t.Errorf("Expected body 'user 123 post 456', got '%s'", w.Body.String()) + } +} + +func TestThreadSafety(t *testing.T) { + router := New() + + router.GET("/test", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("test")) + }) + + // Create multiple goroutines to test concurrent access + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + for j := 0; j < 100; j++ { + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + } + done <- true + }() + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } +} + +func TestRouterInstanceIsolation(t *testing.T) { + // Create two separate router instances + router1 := New() + router2 := New() + + router1.GET("/route", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("router1")) + }) + + router2.GET("/route", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("router2")) + }) + + // Test router1 + req := httptest.NewRequest("GET", "/route", nil) + w := httptest.NewRecorder() + router1.ServeHTTP(w, req) + + if w.Body.String() != "router1" { + t.Errorf("Expected body 'router1', got '%s'", w.Body.String()) + } + + // Test router2 + w = httptest.NewRecorder() + router2.ServeHTTP(w, req) + + if w.Body.String() != "router2" { + t.Errorf("Expected body 'router2', got '%s'", w.Body.String()) + } +}