diff --git a/cors_test.go b/cors_test.go index 7cd934b..76b7f35 100644 --- a/cors_test.go +++ b/cors_test.go @@ -14,8 +14,26 @@ func init() { gin.SetMode(gin.TestMode) } -func performRequest(r http.Handler, method, path string) *httptest.ResponseRecorder { - req, _ := http.NewRequest(method, path, nil) +func newTestRouter(config Config) *gin.Engine { + router := gin.New() + router.Use(New(config)) + router.GET("/", func(c *gin.Context) { + c.String(200, "get") + }) + router.POST("/", func(c *gin.Context) { + c.String(200, "post") + }) + router.PATCH("/", func(c *gin.Context) { + c.String(200, "patch") + }) + return router +} + +func performRequest(r http.Handler, method, origin string) *httptest.ResponseRecorder { + req, _ := http.NewRequest(method, "/", nil) + if len(origin) > 0 { + req.Header.Set("Origin", origin) + } w := httptest.NewRecorder() r.ServeHTTP(w, req) return w @@ -164,10 +182,8 @@ func TestValidateOrigin(t *testing.T) { assert.False(t, cors.validateOrigin("google.com")) } -func TestPasses0(t *testing.T) { - called := false - router := gin.New() - router.Use(New(Config{ +func TestPassesAllowedOrigins(t *testing.T) { + router := newTestRouter(Config{ AllowedOrigins: []string{"http://google.com"}, AllowedMethods: []string{" GeT ", "get", "post", "PUT ", "Head", "POST"}, AllowedHeaders: []string{"Content-type", "timeStamp "}, @@ -177,68 +193,78 @@ func TestPasses0(t *testing.T) { AllowOriginFunc: func(origin string) bool { return origin == "http://github.com" }, - })) - router.GET("/", func(c *gin.Context) { - called = true }) - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/", nil) - router.ServeHTTP(w, req) - assert.True(t, called) - assert.NotContains(t, w.Header(), "Access-Control-Allow-Origin") - assert.NotContains(t, w.Header(), "Access-Control-Allow-Credentials") - assert.NotContains(t, w.Header(), "Access-Control-Expose-Headers") + // no CORS request, origin == "" + w := performRequest(router, "GET", "") + assert.Equal(t, w.Body.String(), "get") + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) + assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers")) - called = false - w = httptest.NewRecorder() - req, _ = http.NewRequest("GET", "/", nil) - req.Header.Set("Origin", "http://google.com") - router.ServeHTTP(w, req) - assert.True(t, called) + // allowed CORS request + w = performRequest(router, "GET", "http://google.com") + assert.Equal(t, w.Body.String(), "get") assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "http://google.com") assert.Equal(t, w.Header().Get("Access-Control-Allow-Credentials"), "true") assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "data, x-user") - called = false - w = httptest.NewRecorder() - req, _ = http.NewRequest("GET", "/", nil) - req.Header.Set("Origin", "https://google.com") - router.ServeHTTP(w, req) - assert.False(t, called) - assert.NotContains(t, w.Header(), "Access-Control-Allow-Origin") - assert.NotContains(t, w.Header(), "Access-Control-Allow-Credentials") - assert.NotContains(t, w.Header(), "Access-Control-Expose-Headers") + // deny CORS request + w = performRequest(router, "GET", "https://google.com") + assert.Equal(t, w.Code, 403) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) + assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers")) - called = false - w = httptest.NewRecorder() - req, _ = http.NewRequest("OPTIONS", "/", nil) - req.Header.Set("Origin", "http://github.com") - router.ServeHTTP(w, req) - assert.False(t, called) + // allowed CORS prefligh request + w = performRequest(router, "OPTIONS", "http://github.com") + assert.Equal(t, w.Code, 200) assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "http://github.com") assert.Equal(t, w.Header().Get("Access-Control-Allow-Credentials"), "true") assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "get, post, put, head") assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "content-type, timestamp") assert.Equal(t, w.Header().Get("Access-Control-Max-Age"), "43200") - called = false - w = httptest.NewRecorder() - req, _ = http.NewRequest("OPTIONS", "/", nil) - req.Header.Set("Origin", "http://example.com") - router.ServeHTTP(w, req) - assert.False(t, called) - assert.NotContains(t, w.Header(), "Access-Control-Allow-Origin") - assert.NotContains(t, w.Header(), "Access-Control-Allow-Credentials") - assert.NotContains(t, w.Header(), "Access-Control-Allow-Methods") - assert.NotContains(t, w.Header(), "Access-Control-Allow-Headers") - assert.NotContains(t, w.Header(), "Access-Control-Max-Age") + // deny CORS prefligh request + w = performRequest(router, "OPTIONS", "http://example.com") + assert.Equal(t, w.Code, 403) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers")) + assert.Empty(t, w.Header().Get("Access-Control-Max-Age")) } -func TestPasses1(t *testing.T) { - -} - -func TestPasses2(t *testing.T) { +func TestPassesAllowedAllOrigins(t *testing.T) { + router := newTestRouter(Config{ + AllowAllOrigins: true, + AllowedMethods: []string{" Patch ", "get", "post", "POST"}, + AllowedHeaders: []string{"Content-type", " testheader "}, + ExposedHeaders: []string{"Data2", "x-User2"}, + AllowCredentials: false, + MaxAge: 10 * time.Hour, + }) + // no CORS request, origin == "" + w := performRequest(router, "GET", "") + assert.Equal(t, w.Body.String(), "get") + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) + assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers")) + + // allowed CORS request + w = performRequest(router, "POST", "example.com") + assert.Equal(t, w.Body.String(), "post") + assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "*") + assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "data2, x-user2") + assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) + + // allowed CORS prefligh request + w = performRequest(router, "OPTIONS", "https://facebook.com") + assert.Equal(t, w.Code, 200) + assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "*") + assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "patch, get, post") + assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "content-type, testheader") + assert.Equal(t, w.Header().Get("Access-Control-Max-Age"), "36000") + assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) }