340 lines
8.4 KiB
Go
340 lines
8.4 KiB
Go
|
package oidc
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"net/http"
|
||
|
"net/http/httptest"
|
||
|
"sync"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/lestrrat-go/jwx/jwk"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
"github.com/xenitab/dispans/server"
|
||
|
)
|
||
|
|
||
|
func TestNewKeyHandler(t *testing.T) {
|
||
|
ctx := context.Background()
|
||
|
|
||
|
op := server.NewTesting(t)
|
||
|
issuer := op.GetURL(t)
|
||
|
discoveryUri := GetDiscoveryUriFromIssuer(issuer)
|
||
|
jwksUri, err := getJwksUriFromDiscoveryUri(http.DefaultClient, discoveryUri, 10*time.Millisecond)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
keyHandler, err := newKeyHandler(http.DefaultClient, jwksUri, 10*time.Millisecond, 100, false)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
keySet1 := keyHandler.getKeySet()
|
||
|
require.Equal(t, 1, keySet1.Len())
|
||
|
|
||
|
expectedKey1, ok := keySet1.Get(0)
|
||
|
require.True(t, ok)
|
||
|
|
||
|
token1 := op.GetToken(t)
|
||
|
keyID1, err := getKeyIDFromTokenString(token1.AccessToken)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// Test valid key id
|
||
|
key1, err := keyHandler.getKeyFromID(ctx, keyID1)
|
||
|
require.NoError(t, err)
|
||
|
require.Equal(t, expectedKey1, key1)
|
||
|
|
||
|
// Test invalid key id
|
||
|
_, err = keyHandler.getKeyFromID(ctx, "foo")
|
||
|
require.Error(t, err)
|
||
|
|
||
|
// Test with rotated keys
|
||
|
op.RotateKeys(t)
|
||
|
|
||
|
token2 := op.GetToken(t)
|
||
|
keyID2, err := getKeyIDFromTokenString(token2.AccessToken)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
key2, err := keyHandler.getKeyFromID(ctx, keyID2)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
keySet2 := keyHandler.getKeySet()
|
||
|
require.Equal(t, 1, keySet2.Len())
|
||
|
|
||
|
expectedKey2, ok := keySet2.Get(0)
|
||
|
require.True(t, ok)
|
||
|
|
||
|
require.Equal(t, expectedKey2, key2)
|
||
|
|
||
|
// Test that old key doesn't match new key
|
||
|
require.NotEqual(t, key1, key2)
|
||
|
|
||
|
// Validate that error is returned when using fake jwks uri
|
||
|
_, err = newKeyHandler(http.DefaultClient, "http://foo.bar/baz", 10*time.Millisecond, 100, false)
|
||
|
require.Error(t, err)
|
||
|
|
||
|
// Validate that error is returned when keys are rotated,
|
||
|
// new token with new key and jwks uri isn't accessible
|
||
|
op.RotateKeys(t)
|
||
|
token3 := op.GetToken(t)
|
||
|
keyID3, err := getKeyIDFromTokenString(token3.AccessToken)
|
||
|
require.NoError(t, err)
|
||
|
op.Close(t)
|
||
|
_, err = keyHandler.getKeyFromID(ctx, keyID3)
|
||
|
require.Error(t, err)
|
||
|
}
|
||
|
|
||
|
func TestUpdate(t *testing.T) {
|
||
|
ctx := context.Background()
|
||
|
|
||
|
op := server.NewTesting(t)
|
||
|
issuer := op.GetURL(t)
|
||
|
discoveryUri := GetDiscoveryUriFromIssuer(issuer)
|
||
|
jwksUri, err := getJwksUriFromDiscoveryUri(http.DefaultClient, discoveryUri, 10*time.Millisecond)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
rateLimit := uint(10)
|
||
|
keyHandler, err := newKeyHandler(http.DefaultClient, jwksUri, 10*time.Millisecond, rateLimit, false)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
require.Equal(t, 1, keyHandler.keyUpdateCount)
|
||
|
|
||
|
_, err = keyHandler.waitForUpdateKeySetAndGetKeySet(ctx)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
require.Equal(t, 2, keyHandler.keyUpdateCount)
|
||
|
|
||
|
concurrentUpdate := func(workers int) {
|
||
|
wg1 := sync.WaitGroup{}
|
||
|
wg1.Add(1)
|
||
|
|
||
|
wg2 := sync.WaitGroup{}
|
||
|
for i := 0; i < workers; i++ {
|
||
|
wg2.Add(1)
|
||
|
go func() {
|
||
|
wg1.Wait()
|
||
|
_, err := keyHandler.waitForUpdateKeySetAndGetKeySet(ctx)
|
||
|
require.NoError(t, err)
|
||
|
wg2.Done()
|
||
|
}()
|
||
|
}
|
||
|
wg1.Done()
|
||
|
wg2.Wait()
|
||
|
}
|
||
|
|
||
|
concurrentUpdate(100)
|
||
|
require.Equal(t, 3, keyHandler.keyUpdateCount)
|
||
|
concurrentUpdate(100)
|
||
|
require.Equal(t, 4, keyHandler.keyUpdateCount)
|
||
|
concurrentUpdate(100)
|
||
|
require.Equal(t, 5, keyHandler.keyUpdateCount)
|
||
|
|
||
|
multipleConcurrentUpdates := func() {
|
||
|
wg1 := sync.WaitGroup{}
|
||
|
wg1.Add(1)
|
||
|
|
||
|
wg2 := sync.WaitGroup{}
|
||
|
for i := 0; i < 10; i++ {
|
||
|
wg2.Add(1)
|
||
|
go func() {
|
||
|
wg1.Wait()
|
||
|
concurrentUpdate(10)
|
||
|
wg2.Done()
|
||
|
}()
|
||
|
}
|
||
|
wg1.Done()
|
||
|
wg2.Wait()
|
||
|
}
|
||
|
|
||
|
multipleConcurrentUpdates()
|
||
|
require.Equal(t, 6, keyHandler.keyUpdateCount)
|
||
|
|
||
|
// test rate limit
|
||
|
time.Sleep(10 * time.Millisecond)
|
||
|
start := time.Now()
|
||
|
_, err = keyHandler.waitForUpdateKeySetAndGetKeySet(ctx)
|
||
|
require.NoError(t, err)
|
||
|
stop := time.Now()
|
||
|
expectedStop := start.Add(time.Second / time.Duration(rateLimit))
|
||
|
|
||
|
require.WithinDuration(t, expectedStop, stop, 20*time.Millisecond)
|
||
|
|
||
|
require.Equal(t, 7, keyHandler.keyUpdateCount)
|
||
|
}
|
||
|
|
||
|
func TestNewKeyHandlerWithKeyIDDisabled(t *testing.T) {
|
||
|
disableKeyID := true
|
||
|
keySets := testNewTestKeySet(t)
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
testServer := testNewJwksServer(t, keySets)
|
||
|
defer testServer.Close()
|
||
|
|
||
|
_, err := newKeyHandler(http.DefaultClient, testServer.URL, 10*time.Millisecond, 100, disableKeyID)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 2, disableKeyID))
|
||
|
|
||
|
_, err = newKeyHandler(http.DefaultClient, testServer.URL, 10*time.Millisecond, 100, disableKeyID)
|
||
|
require.Error(t, err)
|
||
|
}
|
||
|
|
||
|
func TestNewKeyHandlerWithKeyIDEnabled(t *testing.T) {
|
||
|
disableKeyID := false
|
||
|
keySets := testNewTestKeySet(t)
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
testServer := testNewJwksServer(t, keySets)
|
||
|
defer testServer.Close()
|
||
|
|
||
|
_, err := newKeyHandler(http.DefaultClient, testServer.URL, 10*time.Millisecond, 100, disableKeyID)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 2, disableKeyID))
|
||
|
|
||
|
_, err = newKeyHandler(http.DefaultClient, testServer.URL, 10*time.Millisecond, 100, disableKeyID)
|
||
|
require.NoError(t, err)
|
||
|
}
|
||
|
|
||
|
func TestUpdateKeySetWithKeyIDDisabled(t *testing.T) {
|
||
|
ctx := context.Background()
|
||
|
|
||
|
disableKeyID := true
|
||
|
keySets := testNewTestKeySet(t)
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
testServer := testNewJwksServer(t, keySets)
|
||
|
defer testServer.Close()
|
||
|
|
||
|
keyHandler, err := newKeyHandler(http.DefaultClient, testServer.URL, 10*time.Millisecond, 100, disableKeyID)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
_, err = keyHandler.updateKeySet(ctx)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 2, disableKeyID))
|
||
|
|
||
|
_, err = keyHandler.updateKeySet(ctx)
|
||
|
require.Error(t, err)
|
||
|
}
|
||
|
|
||
|
func TestUpdateKeySetWithKeyIDEnabled(t *testing.T) {
|
||
|
ctx := context.Background()
|
||
|
|
||
|
disableKeyID := false
|
||
|
keySets := testNewTestKeySet(t)
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
testServer := testNewJwksServer(t, keySets)
|
||
|
defer testServer.Close()
|
||
|
|
||
|
keyHandler, err := newKeyHandler(http.DefaultClient, testServer.URL, 100*time.Millisecond, 100, disableKeyID)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
_, err = keyHandler.updateKeySet(ctx)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 2, disableKeyID))
|
||
|
|
||
|
_, err = keyHandler.updateKeySet(ctx)
|
||
|
require.NoError(t, err)
|
||
|
}
|
||
|
|
||
|
func TestWaitForUpdateKeySetWithKeyIDDisabled(t *testing.T) {
|
||
|
ctx := context.Background()
|
||
|
|
||
|
disableKeyID := true
|
||
|
keySets := testNewTestKeySet(t)
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
testServer := testNewJwksServer(t, keySets)
|
||
|
defer testServer.Close()
|
||
|
|
||
|
keyHandler, err := newKeyHandler(http.DefaultClient, testServer.URL, 10*time.Millisecond, 100, disableKeyID)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
_, err = keyHandler.waitForUpdateKeySetAndGetKey(ctx)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 2, disableKeyID))
|
||
|
|
||
|
_, err = keyHandler.waitForUpdateKeySetAndGetKey(ctx)
|
||
|
require.Error(t, err)
|
||
|
}
|
||
|
|
||
|
func TestWaitForUpdateKeySetWithKeyIDEnabled(t *testing.T) {
|
||
|
ctx := context.Background()
|
||
|
|
||
|
disableKeyID := false
|
||
|
keySets := testNewTestKeySet(t)
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
testServer := testNewJwksServer(t, keySets)
|
||
|
defer testServer.Close()
|
||
|
|
||
|
keyHandler, err := newKeyHandler(http.DefaultClient, testServer.URL, 10*time.Millisecond, 100, disableKeyID)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
_, err = keyHandler.waitForUpdateKeySetAndGetKey(ctx)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 2, disableKeyID))
|
||
|
|
||
|
_, err = keyHandler.waitForUpdateKeySetAndGetKey(ctx)
|
||
|
require.NoError(t, err)
|
||
|
}
|
||
|
|
||
|
func testNewJwksServer(t *testing.T, keySets *testKeySets) *httptest.Server {
|
||
|
t.Helper()
|
||
|
|
||
|
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
w.Header().Set("Content-Type", "application/json")
|
||
|
err := json.NewEncoder(w).Encode(keySets.publicKeySet)
|
||
|
require.NoError(t, err)
|
||
|
}))
|
||
|
|
||
|
return testServer
|
||
|
}
|
||
|
|
||
|
type testKeySets struct {
|
||
|
privateKeySet jwk.Set
|
||
|
publicKeySet jwk.Set
|
||
|
}
|
||
|
|
||
|
func testNewTestKeySet(t *testing.T) *testKeySets {
|
||
|
t.Helper()
|
||
|
|
||
|
return &testKeySets{}
|
||
|
}
|
||
|
|
||
|
func (k *testKeySets) setKeys(privKeySet jwk.Set, pubKeySet jwk.Set) {
|
||
|
k.privateKeySet = privKeySet
|
||
|
k.publicKeySet = pubKeySet
|
||
|
}
|
||
|
|
||
|
func testNewKeySet(t *testing.T, numKeys int, disableKeyID bool) (jwk.Set, jwk.Set) {
|
||
|
t.Helper()
|
||
|
|
||
|
privKeySet := jwk.NewSet()
|
||
|
pubKeySet := jwk.NewSet()
|
||
|
for i := 0; i < numKeys; i++ {
|
||
|
privKey, pubKey := testNewKey(t)
|
||
|
|
||
|
if disableKeyID {
|
||
|
err := privKey.Remove(jwk.KeyIDKey)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
err = pubKey.Remove(jwk.KeyIDKey)
|
||
|
require.NoError(t, err)
|
||
|
}
|
||
|
|
||
|
privKeySet.Add(privKey)
|
||
|
pubKeySet.Add(pubKey)
|
||
|
}
|
||
|
|
||
|
return privKeySet, pubKeySet
|
||
|
}
|