From 322e348743f7c58cba9530981c523dab57660109 Mon Sep 17 00:00:00 2001 From: "Manu Mtz.-Almeida" Date: Wed, 11 Nov 2015 22:16:49 +0100 Subject: [PATCH] Initial commit --- config.go | 177 +++++++++++++++++++++++++++++++++++++++++++++++++++ cors.go | 101 +++++++++++++++++++++++++++++ cors_test.go | 151 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 429 insertions(+) create mode 100644 config.go create mode 100644 cors.go create mode 100644 cors_test.go diff --git a/config.go b/config.go new file mode 100644 index 0000000..ad144b1 --- /dev/null +++ b/config.go @@ -0,0 +1,177 @@ +package cors + +import ( + "net/http" + "net/textproto" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +type cors struct { + allowAllOrigins bool + allowedOriginFunc func(string) bool + allowedOrigins []string + allowedMethods []string + allowedHeaders []string + exposedHeaders []string + normalHeaders http.Header + preflightHeaders http.Header +} + +func newCors(c Config) *cors { + if err := c.Validate(); err != nil { + panic(err.Error()) + } + return &cors{ + allowedOriginFunc: c.AllowOriginFunc, + allowAllOrigins: c.AllowAllOrigins, + allowedOrigins: normalize(c.AllowedOrigins), + allowedMethods: normalize(c.AllowedMethods), + allowedHeaders: normalize(c.AllowedHeaders), + normalHeaders: generateNormalHeaders(c), + preflightHeaders: generatePreflightHeaders(c), + } +} + +func (cors *cors) applyCors(c *gin.Context) { + origin := c.Request.Header.Get("Origin") + if len(origin) == 0 { + // request is not a CORS request + return + } + if !cors.validateOrigin(origin) { + goto failed + } + + if c.Request.Method == "OPTIONS" { + if !cors.handlePreflight(c) { + goto failed + } + } else if !cors.handleNormal(c) { + goto failed + } + if cors.allowAllOrigins { + c.Header("Access-Control-Allow-Origin", "*") + } else { + c.Header("Access-Control-Allow-Origin", origin) + } + return + +failed: + c.AbortWithStatus(http.StatusForbidden) +} + +func (cors *cors) validateOrigin(origin string) bool { + if cors.allowAllOrigins { + return true + } + if cors.allowedOriginFunc != nil { + return cors.allowedOriginFunc(origin) + } + for _, value := range cors.allowedOrigins { + if value == origin { + return true + } + } + return false +} + +func (cors *cors) validateMethod(method string) bool { + for _, value := range cors.allowedMethods { + if strings.EqualFold(value, method) { + return true + } + } + return false +} + +func (cors *cors) validateHeader(header string) bool { + for _, value := range cors.allowedHeaders { + if strings.EqualFold(value, header) { + return true + } + } + return false +} + +func (cors *cors) handlePreflight(c *gin.Context) bool { + c.AbortWithStatus(200) + if !cors.validateMethod(c.Request.Header.Get("Access-Control-Request-Method")) { + return false + } + if !cors.validateHeader(c.Request.Header.Get("Access-Control-Request-Header")) { + return false + } + for key, value := range cors.preflightHeaders { + c.Writer.Header()[key] = value + } + return true +} + +func (cors *cors) handleNormal(c *gin.Context) bool { + for key, value := range cors.normalHeaders { + c.Writer.Header()[key] = value + } + return true +} + +func generateNormalHeaders(c Config) http.Header { + headers := make(http.Header) + if c.AllowCredentials { + headers.Set("Access-Control-Allow-Credentials", "true") + } + if len(c.ExposedHeaders) > 0 { + headers.Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ", ")) + } + if c.AllowAllOrigins { + headers.Set("Access-Control-Allow-Origin", "*") + } else { + headers.Set("Vary", "Origin") + } + return headers +} + +func generatePreflightHeaders(c Config) http.Header { + headers := make(http.Header) + if c.AllowCredentials { + headers.Set("Access-Control-Allow-Credentials", "true") + } + if len(c.AllowedMethods) > 0 { + value := strings.Join(c.AllowedMethods, ", ") + headers.Set("Access-Control-Allow-Methods", value) + } + if len(c.AllowedHeaders) > 0 { + value := strings.Join(c.AllowedHeaders, ", ") + headers.Set("Access-Control-Allow-Headers", value) + } + if c.MaxAge > time.Duration(0) { + value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10) + headers.Set("Access-Control-Max-Age", value) + } + if c.AllowAllOrigins { + headers.Set("Access-Control-Allow-Origin", "*") + } else { + headers.Set("Vary", "Origin") + } + return headers +} + +func normalize(values []string) []string { + if values == nil { + return nil + } + distinctMap := make(map[string]bool, len(values)) + normalized := make([]string, 0, len(values)) + for _, value := range values { + value = strings.TrimSpace(value) + value = textproto.CanonicalMIMEHeaderKey(value) + if _, seen := distinctMap[value]; !seen { + normalized = append(normalized, value) + distinctMap[value] = true + } + } + return normalized +} diff --git a/cors.go b/cors.go new file mode 100644 index 0000000..c3b3642 --- /dev/null +++ b/cors.go @@ -0,0 +1,101 @@ +package cors + +import ( + "errors" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +type Config struct { + AbortOnError bool + AllowAllOrigins bool + + // AllowedOrigins is a list of origins a cross-domain request can be executed from. + // If the special "*" value is present in the list, all origins will be allowed. + // Default value is ["*"] + AllowedOrigins []string + + // AllowOriginFunc is a custom function to validate the origin. It take the origin + // as argument and returns true if allowed or false otherwise. If this option is + // set, the content of AllowedOrigins is ignored. + AllowOriginFunc func(origin string) bool + + // AllowedMethods is a list of methods the client is allowed to use with + // cross-domain requests. Default value is simple methods (GET and POST) + AllowedMethods []string + + // AllowedHeaders is list of non simple headers the client is allowed to use with + // cross-domain requests. + // If the special "*" value is present in the list, all headers will be allowed. + // Default value is [] but "Origin" is always appended to the list. + AllowedHeaders []string + + // ExposedHeaders indicates which headers are safe to expose to the API of a CORS + // API specification + ExposedHeaders []string + + // AllowCredentials indicates whether the request can include user credentials like + // cookies, HTTP authentication or client side SSL certificates. + AllowCredentials bool + + // MaxAge indicates how long (in seconds) the results of a preflight request + // can be cached + MaxAge time.Duration +} + +func (c *Config) AddAllowedMethods(methods ...string) { + c.AllowedMethods = append(c.AllowedMethods, methods...) +} + +func (c *Config) AddAllowedHeaders(headers ...string) { + c.AllowedHeaders = append(c.AllowedHeaders, headers...) +} + +func (c *Config) AddExposedHeaders(headers ...string) { + c.ExposedHeaders = append(c.ExposedHeaders, headers...) +} + +func (c Config) Validate() error { + if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowedOrigins) > 0) { + return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed") + } + if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowedOrigins) == 0 { + return errors.New("conflict settings: all origins disabled") + } + if c.AllowOriginFunc != nil && len(c.AllowedOrigins) > 0 { + return errors.New("conflict settings: if a allow origin func is provided, AllowedOrigins is not needed") + } + for _, origin := range c.AllowedOrigins { + if !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") { + return errors.New("bad origin: origins must include http:// or https://") + } + } + return nil +} + +func DefaultConfig() Config { + return Config{ + AbortOnError: false, + AllowAllOrigins: true, + AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "HEAD"}, + AllowedHeaders: []string{"Content-Type"}, + //ExposedHeaders: "", + AllowCredentials: false, + MaxAge: 12 * time.Hour, + } +} + +func Default() gin.HandlerFunc { + return New(DefaultConfig()) +} + +func New(config Config) gin.HandlerFunc { + s := newCors(config) + + // Algorithm based in http://www.html5rocks.com/static/images/cors_server_flowchart.png + return func(c *gin.Context) { + s.applyCors(c) + } +} diff --git a/cors_test.go b/cors_test.go new file mode 100644 index 0000000..425bffd --- /dev/null +++ b/cors_test.go @@ -0,0 +1,151 @@ +package cors + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func performRequest(r http.Handler, method, path string) *httptest.ResponseRecorder { + req, _ := http.NewRequest(method, path, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + return w +} + +func TestBadConfig(t *testing.T) { + assert.Panics(t, func() { New(Config{}) }) + assert.Panics(t, func() { + New(Config{ + AllowAllOrigins: true, + AllowedOrigins: []string{"http://google.com"}, + }) + }) + assert.Panics(t, func() { + New(Config{ + AllowAllOrigins: true, + AllowOriginFunc: func(origin string) bool { return false }, + }) + }) + assert.Panics(t, func() { + New(Config{ + AllowedOrigins: []string{"http://google.com"}, + AllowOriginFunc: func(origin string) bool { return false }, + }) + }) + assert.Panics(t, func() { + New(Config{ + AllowedOrigins: []string{"google.com"}, + }) + }) +} + +func TestNormalize(t *testing.T) { + values := normalize([]string{ + "http-access ", "post", "POST", " poSt ", + "HTTP-Access", "", + }) + assert.Equal(t, values, []string{"Http-Access", "Post", ""}) + + values = normalize(nil) + assert.Nil(t, values) + + values = normalize([]string{}) + assert.Equal(t, values, []string{}) +} + +func TestGenerateNormalHeaders(t *testing.T) { + header := generateNormalHeaders(Config{ + AllowAllOrigins: false, + }) + assert.Contains(t, header.Get("Access-Control-Allow-Origin"), "") + assert.Contains(t, header.Get("Vary"), "Origin") + + header = generateNormalHeaders(Config{ + AllowAllOrigins: true, + }) + assert.Contains(t, header.Get("Access-Control-Allow-Origin"), "*") + assert.Contains(t, header.Get("Vary"), "") + + header = generateNormalHeaders(Config{ + AllowCredentials: true, + }) + assert.Contains(t, header.Get("Access-Control-Allow-Credentials"), "true") + + header = generateNormalHeaders(Config{ + AllowCredentials: false, + }) + assert.Contains(t, header.Get("Access-Control-Allow-Credentials"), "") + + header = generateNormalHeaders(Config{ + ExposedHeaders: []string{"x-user", "xpassword"}, + }) + assert.Contains(t, header.Get("Access-Control-Expose-Headers"), "x-user, xpassword") +} + +// +// func TestDeny0(t *testing.T) { +// called := false +// +// router := gin.New() +// router.Use(New(Config{ +// AllowedOrigins: []string{"http://example.com"}, +// })) +// router.GET("/", func(c *gin.Context) { +// called = true +// }) +// w := httptest.NewRecorder() +// req, _ := http.NewRequest("GET", "/", nil) +// req.Header.Set("Origin", "https://example.com") +// router.ServeHTTP(w, req) +// +// assert.True(t, called) +// assert.NotContains(t, w.Header(), "Access-Control") +// } +// +// func TestDenyAbortOnError(t *testing.T) { +// called := false +// +// router := gin.New() +// router.Use(New(Config{ +// AbortOnError: true, +// AllowedOrigins: []string{"http://example.com"}, +// })) +// router.GET("/", func(c *gin.Context) { +// called = true +// }) +// +// w := httptest.NewRecorder() +// req, _ := http.NewRequest("GET", "/", nil) +// req.Header.Set("Origin", "https://example.com") +// router.ServeHTTP(w, req) +// +// assert.False(t, called) +// assert.NotContains(t, w.Header(), "Access-Control") +// } +// +// func TestDeny2(t *testing.T) { +// +// } +// func TestDeny3(t *testing.T) { +// +// } +// +// func TestPasses0(t *testing.T) { +// +// } +// +// func TestPasses1(t *testing.T) { +// +// } +// +// func TestPasses2(t *testing.T) { +// +// }