initial split from 04756eee65e20001b3a44af6249b60c52af446b7
This commit is contained in:
167
keyhandler.go
Normal file
167
keyhandler.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwk"
|
||||
"go.uber.org/ratelimit"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
type keyHandler struct {
|
||||
sync.RWMutex
|
||||
jwksURI string
|
||||
disableKeyID bool
|
||||
keySet jwk.Set
|
||||
fetchTimeout time.Duration
|
||||
keyUpdateSemaphore *semaphore.Weighted
|
||||
keyUpdateChannel chan keyUpdate
|
||||
keyUpdateCount int
|
||||
keyUpdateLimiter ratelimit.Limiter
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
type keyUpdate struct {
|
||||
keySet jwk.Set
|
||||
err error
|
||||
}
|
||||
|
||||
func newKeyHandler(httpClient *http.Client, jwksUri string, fetchTimeout time.Duration, keyUpdateRPS uint, disableKeyID bool) (*keyHandler, error) {
|
||||
h := &keyHandler{
|
||||
jwksURI: jwksUri,
|
||||
disableKeyID: disableKeyID,
|
||||
fetchTimeout: fetchTimeout,
|
||||
keyUpdateSemaphore: semaphore.NewWeighted(int64(1)),
|
||||
keyUpdateChannel: make(chan keyUpdate),
|
||||
keyUpdateLimiter: ratelimit.New(int(keyUpdateRPS)),
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := h.updateKeySet(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (h *keyHandler) updateKeySet(ctx context.Context) (jwk.Set, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, h.fetchTimeout)
|
||||
defer cancel()
|
||||
keySet, err := jwk.Fetch(ctx, h.jwksURI, jwk.WithHTTPClient(h.httpClient))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to fetch keys from %q: %w", h.jwksURI, err)
|
||||
}
|
||||
|
||||
if h.disableKeyID && keySet.Len() != 1 {
|
||||
return nil, fmt.Errorf("keyID is disabled, but received a keySet with more than one key: %d", keySet.Len())
|
||||
}
|
||||
|
||||
h.Lock()
|
||||
h.keySet = keySet
|
||||
h.keyUpdateCount++
|
||||
h.Unlock()
|
||||
|
||||
return keySet, nil
|
||||
}
|
||||
|
||||
// waitForUpdateKeySetSet handles concurrent requests to update the jwks as well as rate limiting.
|
||||
func (h *keyHandler) waitForUpdateKeySetAndGetKeySet(ctx context.Context) (jwk.Set, error) {
|
||||
// ok will be false if there's already an update in progress.
|
||||
ok := h.keyUpdateSemaphore.TryAcquire(1)
|
||||
if ok {
|
||||
defer h.keyUpdateSemaphore.Release(1)
|
||||
_ = h.keyUpdateLimiter.Take()
|
||||
keySet, err := h.updateKeySet(ctx)
|
||||
|
||||
result := keyUpdate{
|
||||
keySet,
|
||||
err,
|
||||
}
|
||||
|
||||
// start go routine to handle all requests waiting for result.
|
||||
go func(res keyUpdate) {
|
||||
// for each request waiting for update, send result to them.
|
||||
for {
|
||||
select {
|
||||
case h.keyUpdateChannel <- res:
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}(result)
|
||||
|
||||
return keySet, err
|
||||
}
|
||||
|
||||
// wait for the request that is updating keys and return the result from it
|
||||
result := <-h.keyUpdateChannel
|
||||
return result.keySet, result.err
|
||||
}
|
||||
|
||||
func (h *keyHandler) waitForUpdateKeySetAndGetKey(ctx context.Context) (jwk.Key, error) {
|
||||
keySet, err := h.waitForUpdateKeySetAndGetKeySet(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, found := keySet.Get(0)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("no key found")
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (h *keyHandler) getKey(ctx context.Context, keyID string) (jwk.Key, error) {
|
||||
if h.disableKeyID {
|
||||
return h.getKeyWithoutKeyID()
|
||||
}
|
||||
|
||||
return h.getKeyFromID(ctx, keyID)
|
||||
}
|
||||
|
||||
func (h *keyHandler) getKeySet() jwk.Set {
|
||||
h.RLock()
|
||||
defer h.RUnlock()
|
||||
return h.keySet
|
||||
}
|
||||
|
||||
func (h *keyHandler) getKeyFromID(ctx context.Context, keyID string) (jwk.Key, error) {
|
||||
keySet := h.getKeySet()
|
||||
|
||||
key, found := keySet.LookupKeyID(keyID)
|
||||
|
||||
if !found {
|
||||
updatedKeySet, err := h.waitForUpdateKeySetAndGetKeySet(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to update key set for key %q: %w", keyID, err)
|
||||
}
|
||||
|
||||
updatedKey, found := updatedKeySet.LookupKeyID(keyID)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("unable to find key %q", keyID)
|
||||
}
|
||||
|
||||
return updatedKey, nil
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (h *keyHandler) getKeyWithoutKeyID() (jwk.Key, error) {
|
||||
keySet := h.getKeySet()
|
||||
|
||||
key, found := keySet.Get(0)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("no key found")
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
Reference in New Issue
Block a user