package oidc

import (
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"fmt"
	"net/http"
	"testing"
	"time"

	"code.icod.de/dalu/oidc/options"
	"github.com/lestrrat-go/jwx/jwa"
	"github.com/lestrrat-go/jwx/jwk"
	"github.com/lestrrat-go/jwx/jws"
	"github.com/lestrrat-go/jwx/jwt"
	"github.com/stretchr/testify/require"
	"github.com/xenitab/dispans/server"
)

// TestGetHeadersFromTokenString checks getHeadersFromTokenString against a
// signed JWT with a typ header, a JWS with empty headers, a JWS carrying two
// signatures and a string that is not a token at all.
func TestGetHeadersFromTokenString(t *testing.T) {
	key, _ := testNewKey(t)

	// Test with KeyID and Type
	token1 := jwt.New()
	err := token1.Set("foo", "bar")
	require.NoError(t, err)

	headers1 := jws.NewHeaders()
	err = headers1.Set(jws.TypeKey, "JWT")
	require.NoError(t, err)

	signedTokenBytes1, err := jwt.Sign(token1, jwa.ES384, key, jwt.WithHeaders(headers1))
	require.NoError(t, err)

	signedToken1 := string(signedTokenBytes1)
	parsedHeaders1, err := getHeadersFromTokenString(signedToken1)
	require.NoError(t, err)
	require.Equal(t, key.KeyID(), parsedHeaders1.KeyID())
	require.Equal(t, headers1.Type(), parsedHeaders1.Type())

	// Test with empty headers
	payload1 := `{"foo":"bar"}`
	headers2 := jws.NewHeaders()
	signedTokenBytes2, err := jws.Sign([]byte(payload1), jwa.ES384, key, jws.WithHeaders(headers2))
	require.NoError(t, err)

	signedToken2 := string(signedTokenBytes2)
	parsedHeaders2, err := getHeadersFromTokenString(signedToken2)
	require.NoError(t, err)
	require.Empty(t, parsedHeaders2.Type())

	// Test with multiple signatures
	payload2 := `{"foo":"bar"}`
	signer1, err := jws.NewSigner(jwa.ES384)
	require.NoError(t, err)
	signer2, err := jws.NewSigner(jwa.ES384)
	require.NoError(t, err)

	signedTokenBytes3, err := jws.SignMulti([]byte(payload2), jws.WithSigner(signer1, key, nil, nil), jws.WithSigner(signer2, key, nil, nil))
	require.NoError(t, err)

	signedToken3 := string(signedTokenBytes3)
	_, err = getHeadersFromTokenString(signedToken3)
	require.Error(t, err)
	require.Equal(t, "more than one signature in token", err.Error())

	// Test with non-token string
	_, err = getHeadersFromTokenString("foo")
	require.Error(t, err)
	require.Contains(t, err.Error(), "unable to parse tokenString")
}

// TestGetKeyIDFromTokenString checks getKeyIDFromTokenString for a token
// signed with a key that has a kid, a token signed with a key whose kid was
// removed, and a string that is not a token.
func TestGetKeyIDFromTokenString(t *testing.T) {
	key, _ := testNewKey(t)

	// Test with KeyID
	token1 := jwt.New()
	err := token1.Set("foo", "bar")
	require.NoError(t, err)

	headers1 := jws.NewHeaders()
	signedTokenBytes1, err := jwt.Sign(token1, jwa.ES384, key, jwt.WithHeaders(headers1))
	require.NoError(t, err)

	signedToken1 := string(signedTokenBytes1)
	keyID, err := getKeyIDFromTokenString(signedToken1)
	require.NoError(t, err)
	require.Equal(t, key.KeyID(), keyID)

	// Test without KeyID
	// jwk.Key is an interface value, so this is an alias of key rather than a
	// copy; key is not used again after its kid is removed.
	keyWithoutKeyID := key
	err = keyWithoutKeyID.Remove(jwk.KeyIDKey)
	require.NoError(t, err)

	token2 := jwt.New()
	err = token2.Set("foo", "bar")
	require.NoError(t, err)

	headers2 := jws.NewHeaders()
	signedTokenBytes2, err := jwt.Sign(token2, jwa.ES384, keyWithoutKeyID, jwt.WithHeaders(headers2))
	require.NoError(t, err)

	signedToken2 := string(signedTokenBytes2)
	_, err = getKeyIDFromTokenString(signedToken2)
	require.Error(t, err)
	require.Equal(t, "token header does not contain key id (kid)", err.Error())

	// Test with non-token string
	_, err = getKeyIDFromTokenString("foo")
	require.Error(t, err)
	require.Contains(t, err.Error(), "unable to parse tokenString")
}

// TestGetTokenTypeFromTokenString checks getTokenTypeFromTokenString for a
// token with a typ header, a token signed without one, and a string that is
// not a token.
func TestGetTokenTypeFromTokenString(t *testing.T) {
	key, _ := testNewKey(t)

	// Test with Type
	token1 := jwt.New()
	err := token1.Set("foo", "bar")
	require.NoError(t, err)

	headers1 := jws.NewHeaders()
	err = headers1.Set(jws.TypeKey, "foo")
	require.NoError(t, err)

	signedTokenBytes1, err := jwt.Sign(token1, jwa.ES384, key, jwt.WithHeaders(headers1))
	require.NoError(t, err)

	signedToken1 := string(signedTokenBytes1)
	tokenType, err := getTokenTypeFromTokenString(signedToken1)
	require.NoError(t, err)
	require.Equal(t, headers1.Type(), tokenType)

	// Test without Type
	payload1 := `{"foo":"bar"}`
	signer1, err := jws.NewSigner(jwa.ES384)
	require.NoError(t, err)

	signedTokenBytes2, err := jws.SignMulti([]byte(payload1), jws.WithSigner(signer1, key, nil, nil))
	require.NoError(t, err)

	signedToken2 := string(signedTokenBytes2)
	_, err = getTokenTypeFromTokenString(signedToken2)
	require.Error(t, err)
	require.Equal(t, "token header does not contain type (typ)", err.Error())

	// Test with non-token string
	_, err = getTokenTypeFromTokenString("foo")
	require.Error(t, err)
	require.Contains(t, err.Error(), "unable to parse tokenString")
}

// TestIsTokenAudienceValid is a table-driven test of isTokenAudienceValid
// covering empty, matching and non-matching required audiences.
func TestIsTokenAudienceValid(t *testing.T) {
	cases := []struct {
		testDescription  string
		requiredAudience string
		tokenAudiences   []string
		expectedResult   bool
	}{
		{
			testDescription:  "empty requiredAudience, empty tokenAudiences",
			requiredAudience: "",
			tokenAudiences:   []string{},
			expectedResult:   true,
		},
		{
			testDescription:  "empty requiredAudience, one tokenAudiences",
			requiredAudience: "",
			tokenAudiences:   []string{"foo"},
			expectedResult:   true,
		},
		{
			testDescription:  "empty requiredAudience, two tokenAudiences",
			requiredAudience: "",
			tokenAudiences:   []string{"foo", "bar"},
			expectedResult:   true,
		},
		{
			testDescription:  "empty requiredAudience, three tokenAudiences",
			requiredAudience: "",
			tokenAudiences:   []string{"foo", "bar", "baz"},
			expectedResult:   true,
		},
		{
			testDescription:  "one tokenAudiences, same as requiredAudience",
			requiredAudience: "foo",
			tokenAudiences:   []string{"foo"},
			expectedResult:   true,
		},
		{
			testDescription:  "two tokenAudiences, first same as requiredAudience",
			requiredAudience: "foo",
			tokenAudiences:   []string{"foo", "bar"},
			expectedResult:   true,
		},
		{
			testDescription:  "two tokenAudiences, second same as requiredAudience",
			requiredAudience: "bar",
			tokenAudiences:   []string{"foo", "bar"},
			expectedResult:   true,
		},
		{
			testDescription:  "three tokenAudiences, third same as requiredAudience",
			requiredAudience: "baz",
			tokenAudiences:   []string{"foo", "bar", "baz"},
			expectedResult:   true,
		},
		{
			testDescription:  "set requiredAudience, empty tokenAudiences",
			requiredAudience: "foo",
			tokenAudiences:   []string{},
			expectedResult:   false,
		},
		{
			testDescription:  "one tokenAudience, not same as requiredAudience",
			requiredAudience: "foo",
			tokenAudiences:   []string{"bar"},
			expectedResult:   false,
		},
		{
			testDescription:  "two tokenAudience, none same as requiredAudience",
			requiredAudience: "foo",
			tokenAudiences:   []string{"bar", "baz"},
			expectedResult:   false,
		},
		{
			testDescription:  "three tokenAudience, none same as requiredAudience",
			requiredAudience: "foo",
			tokenAudiences:   []string{"bar", "baz", "foobar"},
			expectedResult:   false,
		},
	}

	for _, c := range cases {
		t.Run(c.testDescription, func(t *testing.T) {
			result := isTokenAudienceValid(c.requiredAudience, c.tokenAudiences)
			require.Equal(t, c.expectedResult, result)
		})
	}
}

// TestTokenExpirationValid is a table-driven test of isTokenExpirationValid
// using expirations computed relative to time.Now() and various allowed
// drifts.
func TestTokenExpirationValid(t *testing.T) {
	cases := []struct {
		testDescription string
		expiration      time.Time
		allowedDrift    time.Duration
		expectedResult  bool
	}{
		{
			testDescription: "expires now, 50 millisecond drift allowed",
			expiration:      time.Now(),
			allowedDrift:    50 * time.Millisecond,
			expectedResult:  true,
		},
		{
			testDescription: "expires now, 10 second drift allowed",
			expiration:      time.Now(),
			allowedDrift:    10 * time.Second,
			expectedResult:  true,
		},
		{
			testDescription: "expires in one hour, 10 second drift allowed",
			expiration:      time.Now().Add(1 * time.Hour),
			allowedDrift:    10 * time.Second,
			expectedResult:  true,
		},
		{
			testDescription: "expired 5 seconds ago, 10 second drift allowed",
			expiration:      time.Now().Add(-5 * time.Second),
			allowedDrift:    10 * time.Second,
			expectedResult:  true,
		},
		{
			testDescription: "expired 11 seconds ago, 10 second drift allowed",
			expiration:      time.Now().Add(-11 * time.Second),
			allowedDrift:    10 * time.Second,
			expectedResult:  false,
		},
		{
			testDescription: "expires now, no drift",
			expiration:      time.Now(),
			allowedDrift:    0,
			expectedResult:  false,
		},
		{
			testDescription: "expired an hour ago, no drift",
			expiration:      time.Now().Add(-1 * time.Hour),
			allowedDrift:    0,
			expectedResult:  false,
		},
		{
			testDescription: "expired an hour ago, 10 second drift",
			expiration:      time.Now().Add(-1 * time.Hour),
			allowedDrift:    10 * time.Second,
			expectedResult:  false,
		},
	}

	for _, c := range cases {
		t.Run(c.testDescription, func(t *testing.T) {
			result := isTokenExpirationValid(c.expiration, c.allowedDrift)
			require.Equal(t, c.expectedResult, result)
		})
	}
}

// TestIsTokenIssuerValid is a table-driven test of isTokenIssuerValid covering
// equal, different and empty issuers.
func TestIsTokenIssuerValid(t *testing.T) {
	cases := []struct {
		testDescription string
		requiredIssuer  string
		tokenIssuer     string
		expectedResult  bool
	}{
		{
			testDescription: "both requiredIssuer and tokenIssuer are the same",
			requiredIssuer:  "foo",
			tokenIssuer:     "foo",
			expectedResult:  true,
		},
		{
			testDescription: "requiredIssuer and tokenIssuer are not the same",
			requiredIssuer:  "foo",
			tokenIssuer:     "bar",
			expectedResult:  false,
		},
		{
			testDescription: "both requiredIssuer and tokenIssuer are empty",
			requiredIssuer:  "",
			tokenIssuer:     "",
			expectedResult:  false,
		},
		{
			testDescription: "requiredIssuer is empty and tokenIssuer is set",
			requiredIssuer:  "",
			tokenIssuer:     "foo",
			expectedResult:  false,
		},
		{
			testDescription: "requiredIssuer is set and tokenIssuer is empty",
			requiredIssuer:  "foo",
			tokenIssuer:     "",
			expectedResult:  false,
		},
	}

	for _, c := range cases {
		t.Run(c.testDescription, func(t *testing.T) {
			result := isTokenIssuerValid(c.requiredIssuer, c.tokenIssuer)
			require.Equal(t, c.expectedResult, result)
		})
	}
}

// TestIsTokenTypeValid is a table-driven test of isTokenTypeValid. Each case
// signs a fresh JWS, with a typ header only when the case's tokenType is
// non-empty, and compares the result with the expectation.
func TestIsTokenTypeValid(t *testing.T) {
	cases := []struct {
		testDescription   string
		requiredTokenType string
		tokenType         string
		expectedResult    bool
	}{
		{
			testDescription:   "both requiredTokenType and tokenType are empty",
			requiredTokenType: "",
			tokenType:         "",
			expectedResult:    true,
		},
		{
			testDescription:   "requiredTokenType is empty and tokenType is set",
			requiredTokenType: "",
			tokenType:         "foo",
			expectedResult:    true,
		},
		{
			testDescription:   "both requiredTokenType and tokenType are set to the same",
			requiredTokenType: "foo",
			tokenType:         "foo",
			expectedResult:    true,
		},
		{
			testDescription:   "requiredTokenType and tokenType are set to different",
			requiredTokenType: "foo",
			tokenType:         "bar",
			expectedResult:    false,
		},
		{
			testDescription:   "requiredTokenType and tokenType are set to different but tokenType contains requiredTokenType",
			requiredTokenType: "foo",
			tokenType:         "foobar",
			expectedResult:    false,
		},
	}

	for _, c := range cases {
		t.Run(c.testDescription, func(t *testing.T) {
			key, _ := testNewKey(t)
			payload := `{"foo":"bar"}`

			signer, err := jws.NewSigner(jwa.ES384)
			require.NoError(t, err)

			var signedTokenBytes []byte
			if c.tokenType == "" {
				signedTokenBytes, err = jws.SignMulti([]byte(payload), jws.WithSigner(signer, key, nil, nil))
				require.NoError(t, err)
			} else {
				headers := jws.NewHeaders()
				err = headers.Set(jws.TypeKey, c.tokenType)
				require.NoError(t, err)

				signedTokenBytes, err = jws.SignMulti([]byte(payload), jws.WithSigner(signer, key, nil, headers))
				require.NoError(t, err)
			}

			token := string(signedTokenBytes)
			result := isTokenTypeValid(c.requiredTokenType, token)
			require.Equal(t, c.expectedResult, result)
		})
	}
}

// TestGetAndValidateTokenFromString validates tokens issued by a test OpenID
// provider, as well as a locally signed token, against different keys.
func TestGetAndValidateTokenFromString(t *testing.T) {
	op := server.NewTesting(t)
	defer op.Close(t)

	issuer := op.GetURL(t)
	discoveryUri := GetDiscoveryUriFromIssuer(issuer)
	jwksUri, err := getJwksUriFromDiscoveryUri(http.DefaultClient, discoveryUri, 10*time.Millisecond)
	require.NoError(t, err)

	keyHandler, err := newKeyHandler(http.DefaultClient, jwksUri, 50*time.Millisecond, 100, false)
	require.NoError(t, err)

	validKey, ok := keyHandler.getKeySet().Get(0)
	require.True(t, ok)

	validAccessToken := op.GetToken(t).AccessToken
	require.NotEmpty(t, validAccessToken)

	validIDToken, ok := op.GetToken(t).Extra("id_token").(string)
	require.True(t, ok)
	require.NotEmpty(t, validIDToken)

	// "invalid" below means "not issued by the test provider": the token is
	// signed with a locally generated key, so it verifies against that key's
	// public half but not against the provider's key.
	invalidKey, invalidPubKey := testNewKey(t)

	invalidToken := jwt.New()
	err = invalidToken.Set("foo", "bar")
	require.NoError(t, err)

	invalidHeaders := jws.NewHeaders()
	err = invalidHeaders.Set(jws.TypeKey, "JWT")
	require.NoError(t, err)

	invalidTokenBytes, err := jwt.Sign(invalidToken, jwa.ES384, invalidKey, jwt.WithHeaders(invalidHeaders))
	require.NoError(t, err)

	invalidSignedToken := string(invalidTokenBytes)

	cases := []struct {
		testDescription string
		tokenString     string
		key             jwk.Key
		expectedError   bool
	}{
		{
			testDescription: "valid access token, valid key",
			tokenString:     validAccessToken,
			key:             validKey,
			expectedError:   false,
		},
		{
			testDescription: "valid id token, valid key",
			tokenString:     validIDToken,
			key:             validKey,
			expectedError:   false,
		},
		{
			testDescription: "empty string, valid key",
			tokenString:     "",
			key:             validKey,
			expectedError:   true,
		},
		{
			testDescription: "random string, valid key",
			tokenString:     "foobar",
			key:             validKey,
			expectedError:   true,
		},
		{
			testDescription: "invalid token, valid key",
			tokenString:     invalidSignedToken,
			key:             validKey,
			expectedError:   true,
		},
		{
			testDescription: "invalid token, invalid key",
			tokenString:     invalidSignedToken,
			key:             invalidPubKey,
			expectedError:   false,
		},
	}

	for _, c := range cases {
		t.Run(c.testDescription, func(t *testing.T) {
			alg, err := getSignatureAlgorithm(c.key.KeyType(), c.key.Algorithm(), jwa.ES384)
			require.NoError(t, err)

			token, err := getAndValidateTokenFromString(c.tokenString, c.key, alg)
			if c.expectedError {
				require.Error(t, err)
				return
			}

			require.NoError(t, err)
			require.NotEmpty(t, token)
		})
	}
}

// TestParseToken is a table-driven test of Handler.ParseToken. Each case
// publishes a fresh key set on the JWKS test server, builds a handler from the
// case's options and parses a token signed with the first private key.
func TestParseToken(t *testing.T) {
	keySets := testNewTestKeySet(t)
	testServer := testNewJwksServer(t, keySets)
	defer testServer.Close()

	cases := []struct {
		testDescription         string
		options                 []options.Option
		numKeys                 int
		customIssuer            string
		customExpirationMinutes int
		customClaims            map[string]string
		expectedErrorContains   string
	}{
		{
			testDescription: "successful parse with keyID, one key",
			options: []options.Option{
				options.WithIssuer("http://foo.bar"),
				options.WithDiscoveryUri("http://foo.bar"),
				options.WithJwksUri(testServer.URL),
				options.WithDisableKeyID(false),
				options.WithJwksRateLimit(100),
			},
			numKeys:               1,
			expectedErrorContains: "",
		},
		{
			testDescription: "successful parse without keyID, one key",
			options: []options.Option{
				options.WithIssuer("http://foo.bar"),
				options.WithDiscoveryUri("http://foo.bar"),
				options.WithJwksUri(testServer.URL),
				options.WithDisableKeyID(true),
				options.WithJwksRateLimit(100),
			},
			numKeys:               1,
			expectedErrorContains: "",
		},
		{
			testDescription: "successful parse with keyID, two keys",
			options: []options.Option{
				options.WithIssuer("http://foo.bar"),
				options.WithDiscoveryUri("http://foo.bar"),
				options.WithJwksUri(testServer.URL),
				options.WithDisableKeyID(false),
				options.WithJwksRateLimit(100),
			},
			numKeys:               2,
			expectedErrorContains: "",
		},
		{
			// without lazyLoad, New() panics
			testDescription: "unsuccessful parse without keyID, two keys with lazyLoad",
			options: []options.Option{
				options.WithIssuer("http://foo.bar"),
				options.WithDiscoveryUri("http://foo.bar"),
				options.WithJwksUri(testServer.URL),
				options.WithDisableKeyID(true),
				options.WithJwksRateLimit(100),
				options.WithLazyLoadJwks(true),
			},
			numKeys:               2,
			expectedErrorContains: "keyID is disabled, but received a keySet with more than one key",
		},
		{
			testDescription: "wrong issuer, with keyID",
			options: []options.Option{
				options.WithIssuer("http://foo.bar"),
				options.WithDiscoveryUri("http://foo.bar"),
				options.WithJwksUri(testServer.URL),
				options.WithDisableKeyID(false),
			},
			numKeys:               1,
			customIssuer:          "http://wrong.issuer",
			expectedErrorContains: "required issuer \"http://foo.bar\" was not found",
		},
		{
			testDescription: "wrong issuer, without keyID",
			options: []options.Option{
				options.WithIssuer("http://foo.bar"),
				options.WithDiscoveryUri("http://foo.bar"),
				options.WithJwksUri(testServer.URL),
				options.WithDisableKeyID(true),
			},
			numKeys:               1,
			customIssuer:          "http://wrong.issuer",
			expectedErrorContains: "required issuer \"http://foo.bar\" was not found",
		},
		{
			testDescription: "expired token, with keyID",
			options: []options.Option{
				options.WithIssuer("http://foo.bar"),
				options.WithDiscoveryUri("http://foo.bar"),
				options.WithJwksUri(testServer.URL),
				options.WithDisableKeyID(false),
			},
			numKeys:                 1,
			customExpirationMinutes: -1,
			expectedErrorContains:   "token has expired",
		},
		{
			testDescription: "expired token, without keyID",
			options: []options.Option{
				options.WithIssuer("http://foo.bar"),
				options.WithDiscoveryUri("http://foo.bar"),
				options.WithJwksUri(testServer.URL),
				options.WithDisableKeyID(true),
			},
			numKeys:                 1,
			customExpirationMinutes: -1,
			expectedErrorContains:   "token has expired",
		},
		{
			testDescription: "correct requiredClaim",
			options: []options.Option{
				options.WithIssuer("http://foo.bar"),
				options.WithDiscoveryUri("http://foo.bar"),
				options.WithJwksUri(testServer.URL),
				options.WithRequiredClaims(map[string]interface{}{
					"foo": "bar",
				}),
				options.WithDisableKeyID(false),
			},
			numKeys:               1,
			expectedErrorContains: "",
		},
		{
			testDescription: "incorrect requiredClaim",
			options: []options.Option{
				options.WithIssuer("http://foo.bar"),
				options.WithDiscoveryUri("http://foo.bar"),
				options.WithJwksUri(testServer.URL),
				options.WithRequiredClaims(map[string]interface{}{
					"foo": "bar",
				}),
				options.WithDisableKeyID(false),
			},
			numKeys:               1,
			customClaims:          map[string]string{"foo": "baz"},
			expectedErrorContains: "unable to validate required claims",
		},
	}

	for _, c := range cases {
		t.Run(c.testDescription, func(t *testing.T) {
			// Apply the option setters locally so the test can read back the
			// issuer and DisableKeyID values the handler is configured with.
			opts := &options.Options{}
			for _, setter := range c.options {
				setter(opts)
			}

			keySets.setKeys(testNewKeySet(t, c.numKeys, opts.DisableKeyID))

			h, err := NewHandler(c.options...)
			require.NoError(t, err)

			issuer := opts.Issuer
			if c.customIssuer != "" {
				issuer = c.customIssuer
			}

			expirationMinutes := 1
			if c.customExpirationMinutes != 0 {
				expirationMinutes = c.customExpirationMinutes
			}

			customClaims := map[string]string{"foo": "bar"}
			if c.customClaims != nil {
				customClaims = c.customClaims
			}

			token := testNewCustomTokenString(t, keySets.privateKeySet, issuer, expirationMinutes, customClaims)

			_, err = h.ParseToken(context.Background(), token)
			if c.expectedErrorContains == "" {
				require.NoError(t, err)
				return
			}

			// Assert the error exists before calling err.Error(), so an
			// unexpected success fails the test instead of panicking.
			require.Error(t, err)
			require.Contains(t, err.Error(), c.expectedErrorContains)
		})
	}
}

// TestParseTokenWithKeyID walks Handler.ParseToken through a sequence of key
// rotations on the JWKS test server with key IDs enabled.
func TestParseTokenWithKeyID(t *testing.T) {
	disableKeyID := false

	keySets := testNewTestKeySet(t)
	testServer := testNewJwksServer(t, keySets)
	defer testServer.Close()

	keySets.setKeys(testNewKeySet(t, 1, disableKeyID))

	opts := []options.Option{
		options.WithIssuer("http://foo.bar"),
		options.WithDiscoveryUri("http://foo.bar"),
		options.WithJwksUri(testServer.URL),
		options.WithDisableKeyID(disableKeyID),
		options.WithJwksRateLimit(100),
	}

	h, err := NewHandler(opts...)
	require.NoError(t, err)

	parseTokenFunc := h.ParseToken
	ctx := context.Background()

	// first token should succeed
	token1 := testNewTokenString(t, keySets.privateKeySet)
	_, err = parseTokenFunc(ctx, token1)
	require.NoError(t, err)

	// second token should succeed, rotation successful
	keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
	token2 := testNewTokenString(t, keySets.privateKeySet)
	_, err = parseTokenFunc(ctx, token2)
	require.NoError(t, err)

	// after rotation, first token should fail
	_, err = parseTokenFunc(ctx, token1)
	require.Error(t, err)

	// third token should succeed with two keys
	keySets.setKeys(testNewKeySet(t, 2, disableKeyID))
	token3 := testNewTokenString(t, keySets.privateKeySet)
	_, err = parseTokenFunc(ctx, token3)
	require.NoError(t, err)

	// fourth token should fail since the token doesn't contain keyID
	keySets.setKeys(testNewKeySet(t, 1, true))
	token4 := testNewTokenString(t, keySets.privateKeySet)
	_, err = parseTokenFunc(ctx, token4)
	require.Error(t, err)

	// fifth token should fail since it's the wrong key but correct keyID
	keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
	currentPrivateKey, found := keySets.privateKeySet.Get(0)
	require.True(t, found)
	currentKeyID := currentPrivateKey.KeyID()

	invalidPrivKey, _ := testNewKey(t)
	err = invalidPrivKey.Set(jwk.KeyIDKey, currentKeyID)
	require.NoError(t, err)

	invalidKeySet := jwk.NewSet()
	invalidKeySet.Add(invalidPrivKey)

	token5 := testNewTokenString(t, invalidKeySet)
	_, err = parseTokenFunc(ctx, token5)
	require.ErrorIs(t, err, errSignatureVerification)

	// sixth token should fail since the jwks can't be refreshed
	keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
	token6 := testNewTokenString(t, keySets.privateKeySet)
	testServer.Close()
	_, err = parseTokenFunc(ctx, token6)
	require.Error(t, err)
}

// TestParseTokenWithoutKeyID walks Handler.ParseToken through a sequence of
// key rotations on the JWKS test server with key IDs disabled.
func TestParseTokenWithoutKeyID(t *testing.T) {
	disableKeyID := true

	keySets := testNewTestKeySet(t)
	testServer := testNewJwksServer(t, keySets)
	defer testServer.Close()

	keySets.setKeys(testNewKeySet(t, 1, disableKeyID))

	opts := []options.Option{
		options.WithIssuer("http://foo.bar"),
		options.WithDiscoveryUri("http://foo.bar"),
		options.WithJwksUri(testServer.URL),
		options.WithDisableKeyID(disableKeyID),
		options.WithJwksRateLimit(100),
	}

	h, err := NewHandler(opts...)
	require.NoError(t, err)

	parseTokenFunc := h.ParseToken
	ctx := context.Background()

	// first token should succeed
	token1 := testNewTokenString(t, keySets.privateKeySet)
	_, err = parseTokenFunc(ctx, token1)
	require.NoError(t, err)

	// second token should succeed, with key rotation
	keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
	token2 := testNewTokenString(t, keySets.privateKeySet)
	_, err = parseTokenFunc(ctx, token2)
	require.NoError(t, err)

	// after rotation, first token should fail
	_, err = parseTokenFunc(ctx, token1)
	require.Error(t, err)

	// third token should fail since there are two keys present
	keySets.setKeys(testNewKeySet(t, 2, disableKeyID))
	token3 := testNewTokenString(t, keySets.privateKeySet)
	_, err = parseTokenFunc(ctx, token3)
	require.Error(t, err)

	// fourth token should fail since the jwks can't be refreshed
	keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
	token4 := testNewTokenString(t, keySets.privateKeySet)
	testServer.Close()
	_, err = parseTokenFunc(ctx, token4)
	require.Error(t, err)
}

// TestGetAndValidateTokenFromStringWithKeyID validates a token against the
// public key looked up by its key ID, then checks that a token signed after a
// key rotation no longer validates against that same public key.
func TestGetAndValidateTokenFromStringWithKeyID(t *testing.T) {
	disableKeyID := false

	keySets := testNewTestKeySet(t)
	testServer := testNewJwksServer(t, keySets)
	defer testServer.Close()

	keySets.setKeys(testNewKeySet(t, 1, disableKeyID))

	keyHandler, err := newKeyHandler(http.DefaultClient, testServer.URL, 10*time.Millisecond, 100, disableKeyID)
	require.NoError(t, err)

	token1 := testNewTokenString(t, keySets.privateKeySet)

	keyID, err := getKeyIDFromTokenString(token1)
	require.NoError(t, err)

	pubKey, err := keyHandler.getKey(context.Background(), keyID)
	require.NoError(t, err)

	alg, err := getSignatureAlgorithm(pubKey.KeyType(), pubKey.Algorithm(), jwa.ES384)
	require.NoError(t, err)

	_, err = getAndValidateTokenFromString(token1, pubKey, alg)
	require.NoError(t, err)

	keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
	token2 := testNewTokenString(t, keySets.privateKeySet)

	_, err = getAndValidateTokenFromString(token2, pubKey, alg)
	require.Error(t, err)
}

// TestGetAndValidateTokenFromStringWithoutKeyID validates a token against the
// key returned for an empty key ID, then checks that a token signed after a
// key rotation fails with errSignatureVerification against that same key.
func TestGetAndValidateTokenFromStringWithoutKeyID(t *testing.T) {
	disableKeyID := true

	keySets := testNewTestKeySet(t)
	testServer := testNewJwksServer(t, keySets)
	// Close the JWKS server when the test ends, as the sibling tests do.
	defer testServer.Close()

	keySets.setKeys(testNewKeySet(t, 1, disableKeyID))

	keyHandler, err := newKeyHandler(http.DefaultClient, testServer.URL, 10*time.Millisecond, 100, disableKeyID)
	require.NoError(t, err)

	token1 := testNewTokenString(t, keySets.privateKeySet)

	pubKey, err := keyHandler.getKey(context.Background(), "")
	require.NoError(t, err)

	alg, err := getSignatureAlgorithm(pubKey.KeyType(), pubKey.Algorithm(), jwa.ES384)
	require.NoError(t, err)

	_, err = getAndValidateTokenFromString(token1, pubKey, alg)
	require.NoError(t, err)

	keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
	token2 := testNewTokenString(t, keySets.privateKeySet)

	_, err = getAndValidateTokenFromString(token2, pubKey, alg)
	require.ErrorIs(t, err, errSignatureVerification)
}

// TestIsRequiredClaimsValid is a table-driven test of isRequiredClaimsValid
// covering nil and empty inputs, primitives, slices, maps and nested
// combinations of them.
func TestIsRequiredClaimsValid(t *testing.T) {
	cases := []struct {
		testDescription string
		requiredClaims  map[string]interface{}
		tokenClaims     map[string]interface{}
		expectedResult  bool
	}{
		{
			testDescription: "both are nil",
			requiredClaims:  nil,
			tokenClaims:     nil,
			expectedResult:  true,
		},
		{
			testDescription: "both are empty",
			requiredClaims:  map[string]interface{}{},
			tokenClaims:     map[string]interface{}{},
			expectedResult:  true,
		},
		{
			testDescription: "required claims are nil",
			requiredClaims:  nil,
			tokenClaims: map[string]interface{}{
				"foo": "bar",
			},
			expectedResult: true,
		},
		{
			testDescription: "required claims are empty",
			requiredClaims:  map[string]interface{}{},
			tokenClaims: map[string]interface{}{
				"foo": "bar",
			},
			expectedResult: true,
		},
		{
			testDescription: "token claims are nil",
			requiredClaims: map[string]interface{}{
				"foo": "bar",
			},
			tokenClaims:    nil,
			expectedResult: false,
		},
		{
			testDescription: "token claims are empty",
			requiredClaims: map[string]interface{}{
				"foo": "bar",
			},
			tokenClaims:    map[string]interface{}{},
			expectedResult: false,
		},
		{
			testDescription: "required is string, token is int",
			requiredClaims: map[string]interface{}{
				"foo": "bar",
			},
			tokenClaims: map[string]interface{}{
				"foo": 1337,
			},
			expectedResult: false,
		},
		{
			testDescription: "matching with string",
			requiredClaims: map[string]interface{}{
				"foo": "bar",
			},
			tokenClaims: map[string]interface{}{
				"foo": "bar",
			},
			expectedResult: true,
		},
		{
			testDescription: "matching with string and int",
			requiredClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
			},
			tokenClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
			},
			expectedResult: true,
		},
		{
			testDescription: "matching with string and int in different orders",
			requiredClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
			},
			tokenClaims: map[string]interface{}{
				"bar": 1337,
				"foo": "bar",
			},
			expectedResult: true,
		},
		{
			testDescription: "matching with string, int and float",
			requiredClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
				"baz": 13.37,
			},
			tokenClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
				"baz": 13.37,
			},
			expectedResult: true,
		},
		{
			testDescription: "not matching with string, int and float",
			requiredClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
				"baz": 13.37,
			},
			tokenClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
				"baz": 12.27,
			},
			expectedResult: false,
		},
		{
			testDescription: "matching slice",
			requiredClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
				"baz": []string{"foo"},
			},
			tokenClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
				"baz": []string{"foo"},
			},
			expectedResult: true,
		},
		{
			testDescription: "matching slice with multiple values",
			requiredClaims: map[string]interface{}{
				"oof": []string{"foo", "bar"},
			},
			tokenClaims: map[string]interface{}{
				"oof": []string{"foo", "bar", "baz"},
			},
			expectedResult: true,
		},
		{
			testDescription: "required slice contains in token slice",
			requiredClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
				"baz": []string{"foo"},
			},
			tokenClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
				"baz": []string{"foo", "bar", "baz"},
			},
			expectedResult: true,
		},
		{
			testDescription: "not matching slice",
			requiredClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
				"baz": []string{"foo"},
			},
			tokenClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
				"baz": []string{"bar"},
			},
			expectedResult: false,
		},
		{
			testDescription: "matching map",
			requiredClaims: map[string]interface{}{
				"foo": map[string]string{
					"foo": "bar",
				},
			},
			tokenClaims: map[string]interface{}{
				"foo": map[string]string{
					"foo": "bar",
				},
			},
			expectedResult: true,
		},
		{
			testDescription: "matching map with multiple values",
			requiredClaims: map[string]interface{}{
				"foo": map[string]string{
					"foo": "bar",
					"bar": "foo",
				},
			},
			tokenClaims: map[string]interface{}{
				"foo": map[string]string{
					"a":   "b",
					"foo": "bar",
					"bar": "foo",
					"c":   "d",
				},
			},
			expectedResult: true,
		},
		{
			testDescription: "matching map with multiple keys in token claims",
			requiredClaims: map[string]interface{}{
				"foo": map[string]string{
					"foo": "bar",
				},
			},
			tokenClaims: map[string]interface{}{
				"foo": map[string]string{
					"a":   "b",
					"foo": "bar",
					"c":   "d",
				},
			},
			expectedResult: true,
		},
		{
			testDescription: "not matching map",
			requiredClaims: map[string]interface{}{
				"foo": map[string]string{
					"foo": "bar",
				},
			},
			tokenClaims: map[string]interface{}{
				"foo": map[string]int{
					"foo": 1337,
				},
			},
			expectedResult: false,
		},
		{
			testDescription: "matching map with string slice",
			requiredClaims: map[string]interface{}{
				"foo": map[string][]string{
					"foo": {"bar"},
				},
			},
			tokenClaims: map[string]interface{}{
				"foo": map[string][]string{
					"foo": {"foo", "bar", "baz"},
				},
			},
			expectedResult: true,
		},
		{
			testDescription: "not matching map with string slice",
			requiredClaims: map[string]interface{}{
				"foo": map[string][]string{
					"foo": {"foobar"},
				},
			},
			tokenClaims: map[string]interface{}{
				"foo": map[string][]string{
					"foo": {"foo", "bar", "baz"},
				},
			},
			expectedResult: false,
		},
		{
			testDescription: "matching slice with map",
			requiredClaims: map[string]interface{}{
				"foo": []map[string]string{
					{"bar": "baz"},
				},
			},
			tokenClaims: map[string]interface{}{
				"foo": []map[string]string{
					{"bar": "baz"},
				},
			},
			expectedResult: true,
		},
		{
			testDescription: "not matching slice with map",
			requiredClaims: map[string]interface{}{
				"foo": []map[string]string{
					{"bar": "foobar"},
				},
			},
			tokenClaims: map[string]interface{}{
				"foo": []map[string]string{
					{"bar": "baz"},
				},
			},
			expectedResult: false,
		},
		{
			testDescription: "matching primitive types, slice and map",
			requiredClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
				"baz": []string{"foo"},
				"oof": []map[string]string{
					{"bar": "baz"},
				},
			},
			tokenClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
				"baz": []string{"foo"},
				"oof": []map[string]string{
					{"bar": "baz"},
				},
			},
			expectedResult: true,
		},
		{
			testDescription: "matching primitive types, slice and map where token contains multiple values",
			requiredClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
				"baz": []string{"bar"},
				"oof": []map[string]string{
					{"bar": "baz"},
				},
			},
			tokenClaims: map[string]interface{}{
				"foo": "bar",
				"bar": 1337,
				"baz": []string{"foo", "bar", "baz"},
				"oof": []map[string]string{
					{"a": "b"},
					{"bar": "baz"},
					{"c": "d"},
				},
			},
			expectedResult: true,
		},
		{
			testDescription: "valid interface list in an interface map",
			requiredClaims: map[string]interface{}{
				"foo": map[string][]string{
					"bar": {"baz"},
				},
			},
			tokenClaims: map[string]interface{}{
				"foo": map[string]interface{}{
					"bar": []interface{}{
						"uno",
						"dos",
						"baz",
						"tres",
					},
				},
			},
			expectedResult: true,
		},
		{
			testDescription: "invalid interface list in an interface map",
			requiredClaims: map[string]interface{}{
				"foo": map[string][]string{
					"bar": {"baz"},
				},
			},
			tokenClaims: map[string]interface{}{
				"foo": map[string]interface{}{
					"bar": []interface{}{
						"uno",
						"dos",
						"tres",
					},
				},
			},
			expectedResult: false,
		},
	}

	for _, c := range cases {
		t.Run(c.testDescription, func(t *testing.T) {
			err := isRequiredClaimsValid(c.requiredClaims, c.tokenClaims)
			if c.expectedResult {
				require.NoError(t, err)
				return
			}

			require.Error(t, err)
		})
	}
}

// TestGetSignatureAlgorithm is a table-driven test of getSignatureAlgorithm
// over combinations of key type, key algorithm and fallback algorithm.
func TestGetSignatureAlgorithm(t *testing.T) {
	cases := []struct {
		inputKty         jwa.KeyType
		inputAlg         string
		inputFallbackAlg jwa.SignatureAlgorithm
		expectedResult   jwa.SignatureAlgorithm
		expectedError    bool
	}{
		{
			inputKty:         jwa.RSA,
			inputAlg:         "RS256",
			inputFallbackAlg: "",
			expectedResult:   jwa.RS256,
			expectedError:    false,
		},
		{
			inputKty:         jwa.EC,
			inputAlg:         "ES256",
			inputFallbackAlg: "",
			expectedResult:   jwa.ES256,
			expectedError:    false,
		},
		{
			inputKty:         jwa.RSA,
			inputAlg:         "",
			inputFallbackAlg: "",
			expectedResult:   jwa.RS256,
			expectedError:    false,
		},
		{
			inputKty:         jwa.EC,
			inputAlg:         "",
			inputFallbackAlg: "",
			expectedResult:   jwa.ES256,
			expectedError:    false,
		},
		{
			inputKty:         "",
			inputAlg:         "",
			inputFallbackAlg: "",
			expectedResult:   "",
			expectedError:    true,
		},
		{
			inputKty:         "",
			inputAlg:         "foobar",
			inputFallbackAlg: "",
			expectedResult:   "",
			expectedError:    true,
		},
		{
			inputKty:         "",
			inputAlg:         "",
			inputFallbackAlg: jwa.ES384,
			expectedResult:   jwa.ES384,
			expectedError:    false,
		},
		{
			inputKty:         jwa.RSA,
			inputAlg:         "",
			inputFallbackAlg: jwa.ES384,
			expectedResult:   jwa.ES384,
			expectedError:    false,
		},
	}

	for i, c := range cases {
		name := fmt.Sprintf("%d: inputKty=%s, inputAlg=%s, inputFallbackAlg=%s", i, c.inputKty, c.inputAlg, c.inputFallbackAlg)
		t.Run(name, func(t *testing.T) {
			result, err := getSignatureAlgorithm(c.inputKty, c.inputAlg, c.inputFallbackAlg)
			if c.expectedError {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}

			require.Equal(t, c.expectedResult, result)
		})
	}
}

// testNewKey generates a P-384 ECDSA key pair and returns it as a private and
// a public JWK. Both carry the same key ID, the hex-encoded SHA-256 thumbprint
// of the private JWK, and the public key's algorithm is set to ES384.
func testNewKey(tb testing.TB) (jwk.Key, jwk.Key) {
	tb.Helper()

	ecdsaKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
	require.NoError(tb, err)

	key, err := jwk.New(ecdsaKey)
	require.NoError(tb, err)

	_, ok := key.(jwk.ECDSAPrivateKey)
	require.True(tb, ok)

	thumbprint, err := key.Thumbprint(crypto.SHA256)
	require.NoError(tb, err)

	keyID := fmt.Sprintf("%x", thumbprint)
	err = key.Set(jwk.KeyIDKey, keyID)
	require.NoError(tb, err)

	pubKey, err := jwk.New(ecdsaKey.PublicKey)
	require.NoError(tb, err)

	_, ok = pubKey.(jwk.ECDSAPublicKey)
	require.True(tb, ok)

	err = pubKey.Set(jwk.KeyIDKey, keyID)
	require.NoError(tb, err)

	err = pubKey.Set(jwk.AlgorithmKey, jwa.ES384)
	require.NoError(tb, err)

	return key, pubKey
}

// testNewTokenString returns a signed token with issuer "http://foo.bar", an
// expiration one minute from now and the claim "foo"="bar", signed with the
// first key of privKeySet.
func testNewTokenString(t *testing.T, privKeySet jwk.Set) string {
	t.Helper()

	return testNewCustomTokenString(t, privKeySet, "http://foo.bar", 1, map[string]string{"foo": "bar"})
}

// testNewCustomTokenString returns an ES384-signed JWT with a "JWT" typ header,
// the given issuer, an expiration expirationMinutes from now (negative values
// produce an already expired token) and every entry of customClaims set as a
// claim. The token is signed with the first key of privKeySet.
func testNewCustomTokenString(t *testing.T, privKeySet jwk.Set, issuer string, expirationMinutes int, customClaims map[string]string) string {
	t.Helper()

	jwtToken := jwt.New()

	err := jwtToken.Set(jwt.IssuerKey, issuer)
	require.NoError(t, err)

	err = jwtToken.Set(jwt.ExpirationKey, time.Now().Add(time.Duration(expirationMinutes)*time.Minute).Unix())
	require.NoError(t, err)

	for k, v := range customClaims {
		err = jwtToken.Set(k, v)
		require.NoError(t, err)
	}

	headers := jws.NewHeaders()
	err = headers.Set(jws.TypeKey, "JWT")
	require.NoError(t, err)

	privKey, found := privKeySet.Get(0)
	require.True(t, found)

	tokenBytes, err := jwt.Sign(jwtToken, jwa.ES384, privKey, jwt.WithHeaders(headers))
	require.NoError(t, err)

	return string(tokenBytes)
}