oidc/oidc.go

392 lines
9.4 KiB
Go
Raw Normal View History

package oidc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"git.icod.de/dalu/oidc/options"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jws"
"github.com/lestrrat-go/jwx/jwt"
)
var (
errSignatureVerification = fmt.Errorf("failed to verify signature")
)
type handler struct {
issuer string
discoveryUri string
jwksUri string
jwksFetchTimeout time.Duration
jwksRateLimit uint
fallbackSignatureAlgorithm jwa.SignatureAlgorithm
allowedTokenDrift time.Duration
requiredAudience string
requiredTokenType string
requiredClaims map[string]interface{}
disableKeyID bool
httpClient *http.Client
keyHandler *keyHandler
}
func NewHandler(setters ...options.Option) (*handler, error) {
opts := options.New(setters...)
h := &handler{
issuer: opts.Issuer,
discoveryUri: opts.DiscoveryUri,
jwksUri: opts.JwksUri,
jwksFetchTimeout: opts.JwksFetchTimeout,
jwksRateLimit: opts.JwksRateLimit,
allowedTokenDrift: opts.AllowedTokenDrift,
requiredTokenType: opts.RequiredTokenType,
requiredAudience: opts.RequiredAudience,
requiredClaims: opts.RequiredClaims,
disableKeyID: opts.DisableKeyID,
httpClient: opts.HttpClient,
}
if h.issuer == "" {
return nil, fmt.Errorf("issuer is empty")
}
if h.discoveryUri == "" {
h.discoveryUri = GetDiscoveryUriFromIssuer(h.issuer)
}
if opts.FallbackSignatureAlgorithm != "" {
alg, err := getSignatureAlgorithmFromString(opts.FallbackSignatureAlgorithm)
if err != nil {
return nil, fmt.Errorf("FallbackSignatureAlgorithm not accepted: %w", err)
}
h.fallbackSignatureAlgorithm = alg
}
if !opts.LazyLoadJwks {
err := h.loadJwks()
if err != nil {
return nil, fmt.Errorf("unable to load jwks: %w", err)
}
}
return h, nil
}
func (h *handler) loadJwks() error {
if h.jwksUri == "" {
jwksUri, err := getJwksUriFromDiscoveryUri(h.httpClient, h.discoveryUri, 5*time.Second)
if err != nil {
return fmt.Errorf("unable to fetch jwksUri from discoveryUri (%s): %w", h.discoveryUri, err)
}
h.jwksUri = jwksUri
}
keyHandler, err := newKeyHandler(h.httpClient, h.jwksUri, h.jwksFetchTimeout, h.jwksRateLimit, h.disableKeyID)
if err != nil {
return fmt.Errorf("unable to initialize keyHandler: %w", err)
}
h.keyHandler = keyHandler
return nil
}
func (h *handler) SetIssuer(issuer string) {
h.issuer = issuer
}
func (h *handler) SetDiscoveryUri(discoveryUri string) {
h.discoveryUri = discoveryUri
}
type ParseTokenFunc func(ctx context.Context, tokenString string) (jwt.Token, error)
func (h *handler) ParseToken(ctx context.Context, tokenString string) (jwt.Token, error) {
if h.keyHandler == nil {
err := h.loadJwks()
if err != nil {
return nil, fmt.Errorf("unable to load jwks: %w", err)
}
}
tokenTypeValid := isTokenTypeValid(h.requiredTokenType, tokenString)
if !tokenTypeValid {
return nil, fmt.Errorf("token type %q required", h.requiredTokenType)
}
keyID := ""
if !h.disableKeyID {
var err error
keyID, err = getKeyIDFromTokenString(tokenString)
if err != nil {
return nil, err
}
}
key, err := h.keyHandler.getKey(ctx, keyID)
if err != nil {
return nil, fmt.Errorf("unable to get public key: %w", err)
}
alg, err := getSignatureAlgorithm(key.KeyType(), key.Algorithm(), h.fallbackSignatureAlgorithm)
if err != nil {
return nil, err
}
token, err := getAndValidateTokenFromString(tokenString, key, alg)
if err != nil {
if h.disableKeyID && errors.Is(err, errSignatureVerification) {
updatedKey, err := h.keyHandler.waitForUpdateKeySetAndGetKey(ctx)
if err != nil {
return nil, err
}
alg, err := getSignatureAlgorithm(key.KeyType(), key.Algorithm(), h.fallbackSignatureAlgorithm)
if err != nil {
return nil, err
}
token, err = getAndValidateTokenFromString(tokenString, updatedKey, alg)
if err != nil {
return nil, err
}
} else {
return nil, err
}
}
validExpiration := isTokenExpirationValid(token.Expiration(), h.allowedTokenDrift)
if !validExpiration {
return nil, fmt.Errorf("token has expired: %s", token.Expiration())
}
validIssuer := isTokenIssuerValid(h.issuer, token.Issuer())
if !validIssuer {
return nil, fmt.Errorf("required issuer %q was not found, received: %s", h.issuer, token.Issuer())
}
validAudience := isTokenAudienceValid(h.requiredAudience, token.Audience())
if !validAudience {
return nil, fmt.Errorf("required audience %q was not found, received: %v", h.requiredAudience, token.Audience())
}
if h.requiredClaims != nil {
tokenClaims, err := token.AsMap(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get token claims: %w", err)
}
err = isRequiredClaimsValid(h.requiredClaims, tokenClaims)
if err != nil {
return nil, fmt.Errorf("unable to validate required claims: %w", err)
}
}
return token, nil
}
func GetDiscoveryUriFromIssuer(issuer string) string {
return fmt.Sprintf("%s/.well-known/openid-configuration", strings.TrimSuffix(issuer, "/"))
}
func getJwksUriFromDiscoveryUri(httpClient *http.Client, discoveryUri string, fetchTimeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), fetchTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryUri, nil)
if err != nil {
return "", err
}
req.Header.Set("Accept", "application/json")
res, err := httpClient.Do(req)
if err != nil {
return "", err
}
bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
return "", err
}
err = res.Body.Close()
if err != nil {
return "", err
}
var discoveryData struct {
JwksUri string `json:"jwks_uri"`
}
err = json.Unmarshal(bodyBytes, &discoveryData)
if err != nil {
return "", err
}
if discoveryData.JwksUri == "" {
return "", fmt.Errorf("JwksUri is empty")
}
return discoveryData.JwksUri, nil
}
func getKeyIDFromTokenString(tokenString string) (string, error) {
headers, err := getHeadersFromTokenString(tokenString)
if err != nil {
return "", err
}
keyID := headers.KeyID()
if keyID == "" {
return "", fmt.Errorf("token header does not contain key id (kid)")
}
return keyID, nil
}
func getTokenTypeFromTokenString(tokenString string) (string, error) {
headers, err := getHeadersFromTokenString(tokenString)
if err != nil {
return "", err
}
tokenType := headers.Type()
if tokenType == "" {
return "", fmt.Errorf("token header does not contain type (typ)")
}
return tokenType, nil
}
func getHeadersFromTokenString(tokenString string) (jws.Headers, error) {
msg, err := jws.ParseString(tokenString)
if err != nil {
return nil, fmt.Errorf("unable to parse tokenString: %w", err)
}
signatures := msg.Signatures()
if len(signatures) != 1 {
return nil, fmt.Errorf("more than one signature in token")
}
headers := signatures[0].ProtectedHeaders()
return headers, nil
}
func isTokenAudienceValid(requiredAudience string, audiences []string) bool {
if requiredAudience == "" {
return true
}
for _, audience := range audiences {
if audience == requiredAudience {
return true
}
}
return false
}
func isTokenExpirationValid(expiration time.Time, allowedDrift time.Duration) bool {
expirationWithAllowedDrift := expiration.Round(0).Add(allowedDrift)
return expirationWithAllowedDrift.After(time.Now())
}
func isTokenIssuerValid(requiredIssuer string, tokenIssuer string) bool {
if requiredIssuer == "" {
return false
}
return tokenIssuer == requiredIssuer
}
func isTokenTypeValid(requiredTokenType string, tokenString string) bool {
if requiredTokenType == "" {
return true
}
tokenType, err := getTokenTypeFromTokenString(tokenString)
if err != nil {
return false
}
if tokenType != requiredTokenType {
return false
}
return true
}
func isRequiredClaimsValid(requiredClaims map[string]interface{}, tokenClaims map[string]interface{}) error {
for requiredKey, requiredValue := range requiredClaims {
tokenValue, ok := tokenClaims[requiredKey]
if !ok {
return fmt.Errorf("token does not have the claim: %s", requiredKey)
}
required, received, err := getCtyValues(requiredValue, tokenValue)
if err != nil {
return err
}
err = isCtyValueValid(required, received)
if err != nil {
return fmt.Errorf("claim %q not valid: %w", requiredKey, err)
}
}
return nil
}
func getAndValidateTokenFromString(tokenString string, key jwk.Key, alg jwa.SignatureAlgorithm) (jwt.Token, error) {
token, err := jwt.ParseString(tokenString, jwt.WithVerify(alg, key))
if err != nil {
if strings.Contains(err.Error(), errSignatureVerification.Error()) {
return nil, errSignatureVerification
}
return nil, err
}
return token, nil
}
func getSignatureAlgorithm(kty jwa.KeyType, keyAlg string, fallbackAlg jwa.SignatureAlgorithm) (jwa.SignatureAlgorithm, error) {
if keyAlg != "" {
return getSignatureAlgorithmFromString(keyAlg)
}
if fallbackAlg != "" {
return fallbackAlg, nil
}
switch kty {
case jwa.RSA:
return jwa.RS256, nil
case jwa.EC:
return jwa.ES256, nil
default:
return "", fmt.Errorf("unable to get signature algorithm with kty=%s, alg=%s, fallbackAlg=%s", kty, keyAlg, fallbackAlg)
}
}
func getSignatureAlgorithmFromString(s string) (jwa.SignatureAlgorithm, error) {
var alg jwa.SignatureAlgorithm
err := alg.Accept(s)
if err != nil {
return "", err
}
return alg, nil
}