cors/config.go

84 lines
1.7 KiB
Go
Raw Normal View History

2015-11-11 22:16:49 +01:00
package cors
import (
"net/http"
2017-03-18 13:45:46 +01:00
"github.com/gin-gonic/gin"
2015-11-11 22:16:49 +01:00
)
// cors holds the state that newCors derives from a Config. The *cors
// methods consult it to decide whether a request origin is accepted and
// which CORS response headers to write.
type cors struct {
	// allowAllOrigins makes validateOrigin accept every origin and stops
	// applyCors from echoing the request origin back.
	allowAllOrigins bool
	// allowCredentials mirrors Config.AllowCredentials.
	allowCredentials bool

	// allowOriginFunc, when non-nil, is the fallback predicate consulted for
	// origins that are not in allowOrigins.
	allowOriginFunc func(string) bool
	// allowOrigins is the explicit allow-list (as returned by normalize in
	// newCors); validateOrigin matches it by exact string equality.
	allowOrigins []string
	// exposeHeaders is not set by newCors nor read by the methods here.
	// NOTE(review): looks unused — confirm before relying on it.
	exposeHeaders []string

	// normalHeaders and preflightHeaders are precomputed header sets written
	// onto responses for non-OPTIONS and OPTIONS requests respectively.
	normalHeaders    http.Header
	preflightHeaders http.Header
}
2015-11-12 01:17:15 +01:00
func newCors(config Config) *cors {
if err := config.Validate(); err != nil {
2015-11-11 22:16:49 +01:00
panic(err.Error())
}
return &cors{
2015-11-12 15:14:38 +01:00
allowOriginFunc: config.AllowOriginFunc,
allowAllOrigins: config.AllowAllOrigins,
allowCredentials: config.AllowCredentials,
2015-11-12 15:14:38 +01:00
allowOrigins: normalize(config.AllowOrigins),
normalHeaders: generateNormalHeaders(config),
preflightHeaders: generatePreflightHeaders(config),
2015-11-11 22:16:49 +01:00
}
}
func (cors *cors) applyCors(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
if len(origin) == 0 {
// request is not a CORS request
return
}
if !cors.validateOrigin(origin) {
2015-11-12 01:17:15 +01:00
c.AbortWithStatus(http.StatusForbidden)
return
2015-11-11 22:16:49 +01:00
}
if c.Request.Method == "OPTIONS" {
2015-11-12 01:17:15 +01:00
cors.handlePreflight(c)
2016-06-26 19:04:15 +02:00
defer c.AbortWithStatus(200)
2015-11-11 22:16:49 +01:00
} else {
2015-11-12 01:17:15 +01:00
cors.handleNormal(c)
2015-11-11 22:16:49 +01:00
}
if !cors.allowAllOrigins {
2015-11-12 01:17:15 +01:00
c.Header("Access-Control-Allow-Origin", origin)
}
2015-11-11 22:16:49 +01:00
}
func (cors *cors) validateOrigin(origin string) bool {
if cors.allowAllOrigins {
return true
}
2015-11-12 15:14:38 +01:00
for _, value := range cors.allowOrigins {
2015-11-11 22:16:49 +01:00
if value == origin {
return true
}
}
2015-11-12 15:14:38 +01:00
if cors.allowOriginFunc != nil {
return cors.allowOriginFunc(origin)
2015-11-11 22:16:49 +01:00
}
return false
}
2015-11-12 01:17:15 +01:00
func (cors *cors) handlePreflight(c *gin.Context) {
header := c.Writer.Header()
2015-11-11 22:16:49 +01:00
for key, value := range cors.preflightHeaders {
2015-11-12 01:17:15 +01:00
header[key] = value
2015-11-11 22:16:49 +01:00
}
}
2015-11-12 01:17:15 +01:00
func (cors *cors) handleNormal(c *gin.Context) {
header := c.Writer.Header()
2015-11-11 22:16:49 +01:00
for key, value := range cors.normalHeaders {
2015-11-12 01:17:15 +01:00
header[key] = value
2015-11-11 22:16:49 +01:00
}
}