oidc/keyhandler.go

168 lines
3.8 KiB
Go

package oidc
import (
"context"
"fmt"
"net/http"
"sync"
"time"
"github.com/lestrrat-go/jwx/jwk"
"go.uber.org/ratelimit"
"golang.org/x/sync/semaphore"
)
type keyHandler struct {
sync.RWMutex
jwksURI string
disableKeyID bool
keySet jwk.Set
fetchTimeout time.Duration
keyUpdateSemaphore *semaphore.Weighted
keyUpdateChannel chan keyUpdate
keyUpdateCount int
keyUpdateLimiter ratelimit.Limiter
httpClient *http.Client
}
type keyUpdate struct {
keySet jwk.Set
err error
}
func newKeyHandler(httpClient *http.Client, jwksUri string, fetchTimeout time.Duration, keyUpdateRPS uint, disableKeyID bool) (*keyHandler, error) {
h := &keyHandler{
jwksURI: jwksUri,
disableKeyID: disableKeyID,
fetchTimeout: fetchTimeout,
keyUpdateSemaphore: semaphore.NewWeighted(int64(1)),
keyUpdateChannel: make(chan keyUpdate),
keyUpdateLimiter: ratelimit.New(int(keyUpdateRPS)),
httpClient: httpClient,
}
ctx := context.Background()
_, err := h.updateKeySet(ctx)
if err != nil {
return nil, err
}
return h, nil
}
func (h *keyHandler) updateKeySet(ctx context.Context) (jwk.Set, error) {
ctx, cancel := context.WithTimeout(ctx, h.fetchTimeout)
defer cancel()
keySet, err := jwk.Fetch(ctx, h.jwksURI, jwk.WithHTTPClient(h.httpClient))
if err != nil {
return nil, fmt.Errorf("unable to fetch keys from %q: %w", h.jwksURI, err)
}
if h.disableKeyID && keySet.Len() != 1 {
return nil, fmt.Errorf("keyID is disabled, but received a keySet with more than one key: %d", keySet.Len())
}
h.Lock()
h.keySet = keySet
h.keyUpdateCount++
h.Unlock()
return keySet, nil
}
// waitForUpdateKeySetSet handles concurrent requests to update the jwks as well as rate limiting.
func (h *keyHandler) waitForUpdateKeySetAndGetKeySet(ctx context.Context) (jwk.Set, error) {
// ok will be false if there's already an update in progress.
ok := h.keyUpdateSemaphore.TryAcquire(1)
if ok {
defer h.keyUpdateSemaphore.Release(1)
_ = h.keyUpdateLimiter.Take()
keySet, err := h.updateKeySet(ctx)
result := keyUpdate{
keySet,
err,
}
// start go routine to handle all requests waiting for result.
go func(res keyUpdate) {
// for each request waiting for update, send result to them.
for {
select {
case h.keyUpdateChannel <- res:
default:
return
}
}
}(result)
return keySet, err
}
// wait for the request that is updating keys and return the result from it
result := <-h.keyUpdateChannel
return result.keySet, result.err
}
func (h *keyHandler) waitForUpdateKeySetAndGetKey(ctx context.Context) (jwk.Key, error) {
keySet, err := h.waitForUpdateKeySetAndGetKeySet(ctx)
if err != nil {
return nil, err
}
key, found := keySet.Get(0)
if !found {
return nil, fmt.Errorf("no key found")
}
return key, nil
}
func (h *keyHandler) getKey(ctx context.Context, keyID string) (jwk.Key, error) {
if h.disableKeyID {
return h.getKeyWithoutKeyID()
}
return h.getKeyFromID(ctx, keyID)
}
func (h *keyHandler) getKeySet() jwk.Set {
h.RLock()
defer h.RUnlock()
return h.keySet
}
func (h *keyHandler) getKeyFromID(ctx context.Context, keyID string) (jwk.Key, error) {
keySet := h.getKeySet()
key, found := keySet.LookupKeyID(keyID)
if !found {
updatedKeySet, err := h.waitForUpdateKeySetAndGetKeySet(ctx)
if err != nil {
return nil, fmt.Errorf("unable to update key set for key %q: %w", keyID, err)
}
updatedKey, found := updatedKeySet.LookupKeyID(keyID)
if !found {
return nil, fmt.Errorf("unable to find key %q", keyID)
}
return updatedKey, nil
}
return key, nil
}
func (h *keyHandler) getKeyWithoutKeyID() (jwk.Key, error) {
keySet := h.getKeySet()
key, found := keySet.Get(0)
if !found {
return nil, fmt.Errorf("no key found")
}
return key, nil
}