diff --git a/config.go b/config.go index ad144b1..34fa64a 100644 --- a/config.go +++ b/config.go @@ -2,10 +2,6 @@ package cors import ( "net/http" - "net/textproto" - "strconv" - "strings" - "time" "github.com/gin-gonic/gin" ) @@ -14,25 +10,21 @@ type cors struct { allowAllOrigins bool allowedOriginFunc func(string) bool allowedOrigins []string - allowedMethods []string - allowedHeaders []string exposedHeaders []string normalHeaders http.Header preflightHeaders http.Header } -func newCors(c Config) *cors { - if err := c.Validate(); err != nil { +func newCors(config Config) *cors { + if err := config.Validate(); err != nil { panic(err.Error()) } return &cors{ - allowedOriginFunc: c.AllowOriginFunc, - allowAllOrigins: c.AllowAllOrigins, - allowedOrigins: normalize(c.AllowedOrigins), - allowedMethods: normalize(c.AllowedMethods), - allowedHeaders: normalize(c.AllowedHeaders), - normalHeaders: generateNormalHeaders(c), - preflightHeaders: generatePreflightHeaders(c), + allowedOriginFunc: config.AllowOriginFunc, + allowAllOrigins: config.AllowAllOrigins, + allowedOrigins: normalize(config.AllowedOrigins), + normalHeaders: generateNormalHeaders(config), + preflightHeaders: generatePreflightHeaders(config), } } @@ -43,135 +35,47 @@ func (cors *cors) applyCors(c *gin.Context) { return } if !cors.validateOrigin(origin) { - goto failed + c.AbortWithStatus(http.StatusForbidden) + return } if c.Request.Method == "OPTIONS" { - if !cors.handlePreflight(c) { - goto failed - } - } else if !cors.handleNormal(c) { - goto failed - } - if cors.allowAllOrigins { - c.Header("Access-Control-Allow-Origin", "*") + cors.handlePreflight(c) } else { + cors.handleNormal(c) + } + + if !cors.allowAllOrigins { c.Header("Access-Control-Allow-Origin", origin) } - return - -failed: - c.AbortWithStatus(http.StatusForbidden) } func (cors *cors) validateOrigin(origin string) bool { if cors.allowAllOrigins { return true } - if cors.allowedOriginFunc != nil { - return cors.allowedOriginFunc(origin) - } for _, value := range cors.allowedOrigins { if value == origin { return true } } - return false -} - -func (cors *cors) validateMethod(method string) bool { - for _, value := range cors.allowedMethods { - if strings.EqualFold(value, method) { - return true - } + if cors.allowedOriginFunc != nil { + return cors.allowedOriginFunc(origin) } return false } -func (cors *cors) validateHeader(header string) bool { - for _, value := range cors.allowedHeaders { - if strings.EqualFold(value, header) { - return true - } - } - return false -} - -func (cors *cors) handlePreflight(c *gin.Context) bool { +func (cors *cors) handlePreflight(c *gin.Context) { c.AbortWithStatus(200) - if !cors.validateMethod(c.Request.Header.Get("Access-Control-Request-Method")) { - return false - } - if !cors.validateHeader(c.Request.Header.Get("Access-Control-Request-Header")) { - return false - } + header := c.Writer.Header() for key, value := range cors.preflightHeaders { - c.Writer.Header()[key] = value + header[key] = value } - return true } -func (cors *cors) handleNormal(c *gin.Context) bool { +func (cors *cors) handleNormal(c *gin.Context) { + header := c.Writer.Header() for key, value := range cors.normalHeaders { - c.Writer.Header()[key] = value + header[key] = value } - return true -} - -func generateNormalHeaders(c Config) http.Header { - headers := make(http.Header) - if c.AllowCredentials { - headers.Set("Access-Control-Allow-Credentials", "true") - } - if len(c.ExposedHeaders) > 0 { - headers.Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ", ")) - } - if c.AllowAllOrigins { - headers.Set("Access-Control-Allow-Origin", "*") - } else { - headers.Set("Vary", "Origin") - } - return headers -} - -func generatePreflightHeaders(c Config) http.Header { - headers := make(http.Header) - if c.AllowCredentials { - headers.Set("Access-Control-Allow-Credentials", "true") - } - if len(c.AllowedMethods) > 0 { - value := strings.Join(c.AllowedMethods, ", ") - headers.Set("Access-Control-Allow-Methods", value) - } - if len(c.AllowedHeaders) > 0 { - value := strings.Join(c.AllowedHeaders, ", ") - headers.Set("Access-Control-Allow-Headers", value) - } - if c.MaxAge > time.Duration(0) { - value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10) - headers.Set("Access-Control-Max-Age", value) - } - if c.AllowAllOrigins { - headers.Set("Access-Control-Allow-Origin", "*") - } else { - headers.Set("Vary", "Origin") - } - return headers -} - -func normalize(values []string) []string { - if values == nil { - return nil - } - distinctMap := make(map[string]bool, len(values)) - normalized := make([]string, 0, len(values)) - for _, value := range values { - value = strings.TrimSpace(value) - value = textproto.CanonicalMIMEHeaderKey(value) - if _, seen := distinctMap[value]; !seen { - normalized = append(normalized, value) - distinctMap[value] = true - } - } - return normalized } diff --git a/cors.go b/cors.go index c3b3642..b80f389 100644 --- a/cors.go +++ b/cors.go @@ -9,7 +9,6 @@ import ( ) type Config struct { - AbortOnError bool AllowAllOrigins bool // AllowedOrigins is a list of origins a cross-domain request can be executed from. @@ -64,9 +63,6 @@ func (c Config) Validate() error { if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowedOrigins) == 0 { return errors.New("conflict settings: all origins disabled") } - if c.AllowOriginFunc != nil && len(c.AllowedOrigins) > 0 { - return errors.New("conflict settings: if a allow origin func is provided, AllowedOrigins is not needed") - } for _, origin := range c.AllowedOrigins { if !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") { return errors.New("bad origin: origins must include http:// or https://") @@ -77,11 +73,9 @@ func (c Config) Validate() error { func DefaultConfig() Config { return Config{ - AbortOnError: false, - AllowAllOrigins: true, - AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "HEAD"}, - AllowedHeaders: []string{"Content-Type"}, - //ExposedHeaders: "", + AllowAllOrigins: true, + AllowedMethods: []string{"GET", "POST", "PUT", "HEAD"}, + AllowedHeaders: []string{"Content-Type"}, AllowCredentials: false, MaxAge: 12 * time.Hour, } @@ -92,10 +86,8 @@ func Default() gin.HandlerFunc { } func New(config Config) gin.HandlerFunc { - s := newCors(config) - - // Algorithm based in http://www.html5rocks.com/static/images/cors_server_flowchart.png + cors := newCors(config) return func(c *gin.Context) { - s.applyCors(c) + cors.applyCors(c) } } diff --git a/cors_test.go b/cors_test.go index 425bffd..7cd934b 100644 --- a/cors_test.go +++ b/cors_test.go @@ -4,6 +4,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" @@ -34,12 +35,6 @@ func TestBadConfig(t *testing.T) { AllowOriginFunc: func(origin string) bool { return false }, }) }) - assert.Panics(t, func() { - New(Config{ - AllowedOrigins: []string{"http://google.com"}, - AllowOriginFunc: func(origin string) bool { return false }, - }) - }) assert.Panics(t, func() { New(Config{ AllowedOrigins: []string{"google.com"}, @@ -49,10 +44,10 @@ func TestBadConfig(t *testing.T) { func TestNormalize(t *testing.T) { values := normalize([]string{ - "http-access ", "post", "POST", " poSt ", + "http-Access ", "Post", "POST", " poSt ", "HTTP-Access", "", }) - assert.Equal(t, values, []string{"Http-Access", "Post", ""}) + assert.Equal(t, values, []string{"http-access", "post", ""}) values = normalize(nil) assert.Nil(t, values) @@ -61,91 +56,189 @@ func TestNormalize(t *testing.T) { assert.Equal(t, values, []string{}) } -func TestGenerateNormalHeaders(t *testing.T) { +func TestGenerateNormalHeaders_AllowAllOrigins(t *testing.T) { header := generateNormalHeaders(Config{ AllowAllOrigins: false, }) - assert.Contains(t, header.Get("Access-Control-Allow-Origin"), "") - assert.Contains(t, header.Get("Vary"), "Origin") + assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "") + assert.Equal(t, header.Get("Vary"), "Origin") + assert.Len(t, header, 1) header = generateNormalHeaders(Config{ AllowAllOrigins: true, }) - assert.Contains(t, header.Get("Access-Control-Allow-Origin"), "*") - assert.Contains(t, header.Get("Vary"), "") - - header = generateNormalHeaders(Config{ - AllowCredentials: true, - }) - assert.Contains(t, header.Get("Access-Control-Allow-Credentials"), "true") - - header = generateNormalHeaders(Config{ - AllowCredentials: false, - }) - assert.Contains(t, header.Get("Access-Control-Allow-Credentials"), "") - - header = generateNormalHeaders(Config{ - ExposedHeaders: []string{"x-user", "xpassword"}, - }) - assert.Contains(t, header.Get("Access-Control-Expose-Headers"), "x-user, xpassword") + assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "*") + assert.Equal(t, header.Get("Vary"), "") + assert.Len(t, header, 1) } -// -// func TestDeny0(t *testing.T) { -// called := false -// -// router := gin.New() -// router.Use(New(Config{ -// AllowedOrigins: []string{"http://example.com"}, -// })) -// router.GET("/", func(c *gin.Context) { -// called = true -// }) -// w := httptest.NewRecorder() -// req, _ := http.NewRequest("GET", "/", nil) -// req.Header.Set("Origin", "https://example.com") -// router.ServeHTTP(w, req) -// -// assert.True(t, called) -// assert.NotContains(t, w.Header(), "Access-Control") -// } -// -// func TestDenyAbortOnError(t *testing.T) { -// called := false -// -// router := gin.New() -// router.Use(New(Config{ -// AbortOnError: true, -// AllowedOrigins: []string{"http://example.com"}, -// })) -// router.GET("/", func(c *gin.Context) { -// called = true -// }) -// -// w := httptest.NewRecorder() -// req, _ := http.NewRequest("GET", "/", nil) -// req.Header.Set("Origin", "https://example.com") -// router.ServeHTTP(w, req) -// -// assert.False(t, called) -// assert.NotContains(t, w.Header(), "Access-Control") -// } -// -// func TestDeny2(t *testing.T) { -// -// } -// func TestDeny3(t *testing.T) { -// -// } -// -// func TestPasses0(t *testing.T) { -// -// } -// -// func TestPasses1(t *testing.T) { -// -// } -// -// func TestPasses2(t *testing.T) { -// -// } +func TestGenerateNormalHeaders_AllowCredentials(t *testing.T) { + header := generateNormalHeaders(Config{ + AllowCredentials: true, + }) + assert.Equal(t, header.Get("Access-Control-Allow-Credentials"), "true") + assert.Equal(t, header.Get("Vary"), "Origin") + assert.Len(t, header, 2) +} + +func TestGenerateNormalHeaders_ExposedHeaders(t *testing.T) { + header := generateNormalHeaders(Config{ + ExposedHeaders: []string{"X-user", "xPassword"}, + }) + assert.Equal(t, header.Get("Access-Control-Expose-Headers"), "x-user, xpassword") + assert.Equal(t, header.Get("Vary"), "Origin") + assert.Len(t, header, 2) +} + +func TestGeneratePreflightHeaders(t *testing.T) { + header := generatePreflightHeaders(Config{ + AllowAllOrigins: false, + }) + assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "") + assert.Equal(t, header.Get("Vary"), "Origin") + assert.Len(t, header, 1) + + header = generateNormalHeaders(Config{ + AllowAllOrigins: true, + }) + assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "*") + assert.Equal(t, header.Get("Vary"), "") + assert.Len(t, header, 1) +} + +func TestGeneratePreflightHeaders_AllowCredentials(t *testing.T) { + header := generatePreflightHeaders(Config{ + AllowCredentials: true, + }) + assert.Equal(t, header.Get("Access-Control-Allow-Credentials"), "true") + assert.Equal(t, header.Get("Vary"), "Origin") + assert.Len(t, header, 2) +} + +func TestGeneratePreflightHeaders_AllowedMethods(t *testing.T) { + header := generatePreflightHeaders(Config{ + AllowedMethods: []string{"GET ", "post", "PUT", " put "}, + }) + assert.Equal(t, header.Get("Access-Control-Allow-Methods"), "get, post, put") + assert.Equal(t, header.Get("Vary"), "Origin") + assert.Len(t, header, 2) +} + +func TestGeneratePreflightHeaders_AllowedHeaders(t *testing.T) { + header := generatePreflightHeaders(Config{ + AllowedHeaders: []string{"X-user", "Content-Type"}, + }) + assert.Equal(t, header.Get("Access-Control-Allow-Headers"), "x-user, content-type") + assert.Equal(t, header.Get("Vary"), "Origin") + assert.Len(t, header, 2) +} + +func TestGeneratePreflightHeaders_MaxAge(t *testing.T) { + header := generatePreflightHeaders(Config{ + MaxAge: 12 * time.Hour, + }) + assert.Equal(t, header.Get("Access-Control-Max-Age"), "43200") // 12*60*60 + assert.Equal(t, header.Get("Vary"), "Origin") + assert.Len(t, header, 2) +} + +func TestValidateOrigin(t *testing.T) { + cors := newCors(Config{ + AllowAllOrigins: true, + }) + assert.True(t, cors.validateOrigin("http://google.com")) + assert.True(t, cors.validateOrigin("https://google.com")) + assert.True(t, cors.validateOrigin("example.com")) + + cors = newCors(Config{ + AllowedOrigins: []string{"https://google.com", "https://github.com"}, + AllowOriginFunc: func(origin string) bool { + return (origin == "http://news.ycombinator.com") + }, + }) + assert.False(t, cors.validateOrigin("http://google.com")) + assert.True(t, cors.validateOrigin("https://google.com")) + assert.True(t, cors.validateOrigin("https://github.com")) + assert.True(t, cors.validateOrigin("http://news.ycombinator.com")) + assert.False(t, cors.validateOrigin("http://example.com")) + assert.False(t, cors.validateOrigin("google.com")) +} + +func TestPasses0(t *testing.T) { + called := false + router := gin.New() + router.Use(New(Config{ + AllowedOrigins: []string{"http://google.com"}, + AllowedMethods: []string{" GeT ", "get", "post", "PUT ", "Head", "POST"}, + AllowedHeaders: []string{"Content-type", "timeStamp "}, + ExposedHeaders: []string{"Data", "x-User"}, + AllowCredentials: true, + MaxAge: 12 * time.Hour, + AllowOriginFunc: func(origin string) bool { + return origin == "http://github.com" + }, + })) + router.GET("/", func(c *gin.Context) { + called = true + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + router.ServeHTTP(w, req) + assert.True(t, called) + assert.NotContains(t, w.Header(), "Access-Control-Allow-Origin") + assert.NotContains(t, w.Header(), "Access-Control-Allow-Credentials") + assert.NotContains(t, w.Header(), "Access-Control-Expose-Headers") + + called = false + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/", nil) + req.Header.Set("Origin", "http://google.com") + router.ServeHTTP(w, req) + assert.True(t, called) + assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "http://google.com") + assert.Equal(t, w.Header().Get("Access-Control-Allow-Credentials"), "true") + assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "data, x-user") + + called = false + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/", nil) + req.Header.Set("Origin", "https://google.com") + router.ServeHTTP(w, req) + assert.False(t, called) + assert.NotContains(t, w.Header(), "Access-Control-Allow-Origin") + assert.NotContains(t, w.Header(), "Access-Control-Allow-Credentials") + assert.NotContains(t, w.Header(), "Access-Control-Expose-Headers") + + called = false + w = httptest.NewRecorder() + req, _ = http.NewRequest("OPTIONS", "/", nil) + req.Header.Set("Origin", "http://github.com") + router.ServeHTTP(w, req) + assert.False(t, called) + assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "http://github.com") + assert.Equal(t, w.Header().Get("Access-Control-Allow-Credentials"), "true") + assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "get, post, put, head") + assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "content-type, timestamp") + assert.Equal(t, w.Header().Get("Access-Control-Max-Age"), "43200") + + called = false + w = httptest.NewRecorder() + req, _ = http.NewRequest("OPTIONS", "/", nil) + req.Header.Set("Origin", "http://example.com") + router.ServeHTTP(w, req) + assert.False(t, called) + assert.NotContains(t, w.Header(), "Access-Control-Allow-Origin") + assert.NotContains(t, w.Header(), "Access-Control-Allow-Credentials") + assert.NotContains(t, w.Header(), "Access-Control-Allow-Methods") + assert.NotContains(t, w.Header(), "Access-Control-Allow-Headers") + assert.NotContains(t, w.Header(), "Access-Control-Max-Age") +} + +func TestPasses1(t *testing.T) { + +} + +func TestPasses2(t *testing.T) { + +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..62a123a --- /dev/null +++ b/utils.go @@ -0,0 +1,69 @@ +package cors + +import ( + "net/http" + "strconv" + "strings" + "time" +) + +func generateNormalHeaders(c Config) http.Header { + headers := make(http.Header) + if c.AllowCredentials { + headers.Set("Access-Control-Allow-Credentials", "true") + } + if len(c.ExposedHeaders) > 0 { + exposedHeaders := normalize(c.ExposedHeaders) + headers.Set("Access-Control-Expose-Headers", strings.Join(exposedHeaders, ", ")) + } + if c.AllowAllOrigins { + headers.Set("Access-Control-Allow-Origin", "*") + } else { + headers.Set("Vary", "Origin") + } + return headers +} + +func generatePreflightHeaders(c Config) http.Header { + headers := make(http.Header) + if c.AllowCredentials { + headers.Set("Access-Control-Allow-Credentials", "true") + } + if len(c.AllowedMethods) > 0 { + allowedMethods := normalize(c.AllowedMethods) + value := strings.Join(allowedMethods, ", ") + headers.Set("Access-Control-Allow-Methods", value) + } + if len(c.AllowedHeaders) > 0 { + allowedHeaders := normalize(c.AllowedHeaders) + value := strings.Join(allowedHeaders, ", ") + headers.Set("Access-Control-Allow-Headers", value) + } + if c.MaxAge > time.Duration(0) { + value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10) + headers.Set("Access-Control-Max-Age", value) + } + if c.AllowAllOrigins { + headers.Set("Access-Control-Allow-Origin", "*") + } else { + headers.Set("Vary", "Origin") + } + return headers +} + +func normalize(values []string) []string { + if values == nil { + return nil + } + distinctMap := make(map[string]bool, len(values)) + normalized := make([]string, 0, len(values)) + for _, value := range values { + value = strings.TrimSpace(value) + value = strings.ToLower(value) + if _, seen := distinctMap[value]; !seen { + normalized = append(normalized, value) + distinctMap[value] = true + } + } + return normalized +}