package cors import ( "net/http" "github.com/gin-gonic/gin" ) type cors struct { allowAllOrigins bool allowedOriginFunc func(string) bool allowedOrigins []string exposedHeaders []string normalHeaders http.Header preflightHeaders http.Header } func newCors(config Config) *cors { if err := config.Validate(); err != nil { panic(err.Error()) } return &cors{ allowedOriginFunc: config.AllowOriginFunc, allowAllOrigins: config.AllowAllOrigins, allowedOrigins: normalize(config.AllowedOrigins), normalHeaders: generateNormalHeaders(config), preflightHeaders: generatePreflightHeaders(config), } } func (cors *cors) applyCors(c *gin.Context) { origin := c.Request.Header.Get("Origin") if len(origin) == 0 { // request is not a CORS request return } if !cors.validateOrigin(origin) { c.AbortWithStatus(http.StatusForbidden) return } if c.Request.Method == "OPTIONS" { cors.handlePreflight(c) } else { cors.handleNormal(c) } if !cors.allowAllOrigins { c.Header("Access-Control-Allow-Origin", origin) } } func (cors *cors) validateOrigin(origin string) bool { if cors.allowAllOrigins { return true } for _, value := range cors.allowedOrigins { if value == origin { return true } } if cors.allowedOriginFunc != nil { return cors.allowedOriginFunc(origin) } return false } func (cors *cors) handlePreflight(c *gin.Context) { c.AbortWithStatus(200) header := c.Writer.Header() for key, value := range cors.preflightHeaders { header[key] = value } } func (cors *cors) handleNormal(c *gin.Context) { header := c.Writer.Header() for key, value := range cors.normalHeaders { header[key] = value } }