168 lines
3.8 KiB
Go
168 lines
3.8 KiB
Go
package oidc
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/lestrrat-go/jwx/jwk"
|
|
"go.uber.org/ratelimit"
|
|
"golang.org/x/sync/semaphore"
|
|
)
|
|
|
|
type keyHandler struct {
|
|
sync.RWMutex
|
|
jwksURI string
|
|
disableKeyID bool
|
|
keySet jwk.Set
|
|
fetchTimeout time.Duration
|
|
keyUpdateSemaphore *semaphore.Weighted
|
|
keyUpdateChannel chan keyUpdate
|
|
keyUpdateCount int
|
|
keyUpdateLimiter ratelimit.Limiter
|
|
httpClient *http.Client
|
|
}
|
|
|
|
type keyUpdate struct {
|
|
keySet jwk.Set
|
|
err error
|
|
}
|
|
|
|
func newKeyHandler(httpClient *http.Client, jwksUri string, fetchTimeout time.Duration, keyUpdateRPS uint, disableKeyID bool) (*keyHandler, error) {
|
|
h := &keyHandler{
|
|
jwksURI: jwksUri,
|
|
disableKeyID: disableKeyID,
|
|
fetchTimeout: fetchTimeout,
|
|
keyUpdateSemaphore: semaphore.NewWeighted(int64(1)),
|
|
keyUpdateChannel: make(chan keyUpdate),
|
|
keyUpdateLimiter: ratelimit.New(int(keyUpdateRPS)),
|
|
httpClient: httpClient,
|
|
}
|
|
|
|
ctx := context.Background()
|
|
|
|
_, err := h.updateKeySet(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return h, nil
|
|
}
|
|
|
|
func (h *keyHandler) updateKeySet(ctx context.Context) (jwk.Set, error) {
|
|
ctx, cancel := context.WithTimeout(ctx, h.fetchTimeout)
|
|
defer cancel()
|
|
keySet, err := jwk.Fetch(ctx, h.jwksURI, jwk.WithHTTPClient(h.httpClient))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to fetch keys from %q: %w", h.jwksURI, err)
|
|
}
|
|
|
|
if h.disableKeyID && keySet.Len() != 1 {
|
|
return nil, fmt.Errorf("keyID is disabled, but received a keySet with more than one key: %d", keySet.Len())
|
|
}
|
|
|
|
h.Lock()
|
|
h.keySet = keySet
|
|
h.keyUpdateCount++
|
|
h.Unlock()
|
|
|
|
return keySet, nil
|
|
}
|
|
|
|
// waitForUpdateKeySetSet handles concurrent requests to update the jwks as well as rate limiting.
|
|
func (h *keyHandler) waitForUpdateKeySetAndGetKeySet(ctx context.Context) (jwk.Set, error) {
|
|
// ok will be false if there's already an update in progress.
|
|
ok := h.keyUpdateSemaphore.TryAcquire(1)
|
|
if ok {
|
|
defer h.keyUpdateSemaphore.Release(1)
|
|
_ = h.keyUpdateLimiter.Take()
|
|
keySet, err := h.updateKeySet(ctx)
|
|
|
|
result := keyUpdate{
|
|
keySet,
|
|
err,
|
|
}
|
|
|
|
// start go routine to handle all requests waiting for result.
|
|
go func(res keyUpdate) {
|
|
// for each request waiting for update, send result to them.
|
|
for {
|
|
select {
|
|
case h.keyUpdateChannel <- res:
|
|
default:
|
|
return
|
|
}
|
|
}
|
|
}(result)
|
|
|
|
return keySet, err
|
|
}
|
|
|
|
// wait for the request that is updating keys and return the result from it
|
|
result := <-h.keyUpdateChannel
|
|
return result.keySet, result.err
|
|
}
|
|
|
|
func (h *keyHandler) waitForUpdateKeySetAndGetKey(ctx context.Context) (jwk.Key, error) {
|
|
keySet, err := h.waitForUpdateKeySetAndGetKeySet(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
key, found := keySet.Get(0)
|
|
if !found {
|
|
return nil, fmt.Errorf("no key found")
|
|
}
|
|
|
|
return key, nil
|
|
}
|
|
|
|
func (h *keyHandler) getKey(ctx context.Context, keyID string) (jwk.Key, error) {
|
|
if h.disableKeyID {
|
|
return h.getKeyWithoutKeyID()
|
|
}
|
|
|
|
return h.getKeyFromID(ctx, keyID)
|
|
}
|
|
|
|
func (h *keyHandler) getKeySet() jwk.Set {
|
|
h.RLock()
|
|
defer h.RUnlock()
|
|
return h.keySet
|
|
}
|
|
|
|
func (h *keyHandler) getKeyFromID(ctx context.Context, keyID string) (jwk.Key, error) {
|
|
keySet := h.getKeySet()
|
|
|
|
key, found := keySet.LookupKeyID(keyID)
|
|
|
|
if !found {
|
|
updatedKeySet, err := h.waitForUpdateKeySetAndGetKeySet(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to update key set for key %q: %w", keyID, err)
|
|
}
|
|
|
|
updatedKey, found := updatedKeySet.LookupKeyID(keyID)
|
|
if !found {
|
|
return nil, fmt.Errorf("unable to find key %q", keyID)
|
|
}
|
|
|
|
return updatedKey, nil
|
|
}
|
|
|
|
return key, nil
|
|
}
|
|
|
|
func (h *keyHandler) getKeyWithoutKeyID() (jwk.Key, error) {
|
|
keySet := h.getKeySet()
|
|
|
|
key, found := keySet.Get(0)
|
|
if !found {
|
|
return nil, fmt.Errorf("no key found")
|
|
}
|
|
|
|
return key, nil
|
|
}
|