1485 lines
37 KiB
Go
1485 lines
37 KiB
Go
|
package oidc
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto"
|
||
|
"crypto/ecdsa"
|
||
|
"crypto/elliptic"
|
||
|
"crypto/rand"
|
||
|
"fmt"
|
||
|
"net/http"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"git.icod.de/dalu/oidc/options"
|
||
|
"github.com/lestrrat-go/jwx/jwa"
|
||
|
"github.com/lestrrat-go/jwx/jwk"
|
||
|
"github.com/lestrrat-go/jwx/jws"
|
||
|
"github.com/lestrrat-go/jwx/jwt"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
"github.com/xenitab/dispans/server"
|
||
|
)
|
||
|
|
||
|
func TestGetHeadersFromTokenString(t *testing.T) {
|
||
|
key, _ := testNewKey(t)
|
||
|
|
||
|
// Test with KeyID and Type
|
||
|
token1 := jwt.New()
|
||
|
err := token1.Set("foo", "bar")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
headers1 := jws.NewHeaders()
|
||
|
err = headers1.Set(jws.TypeKey, "JWT")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
signedTokenBytes1, err := jwt.Sign(token1, jwa.ES384, key, jwt.WithHeaders(headers1))
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
signedToken1 := string(signedTokenBytes1)
|
||
|
parsedHeaders1, err := getHeadersFromTokenString(signedToken1)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
require.Equal(t, key.KeyID(), parsedHeaders1.KeyID())
|
||
|
require.Equal(t, headers1.Type(), parsedHeaders1.Type())
|
||
|
|
||
|
// Test with empty headers
|
||
|
payload1 := `{"foo":"bar"}`
|
||
|
|
||
|
headers2 := jws.NewHeaders()
|
||
|
|
||
|
signedTokenBytes2, err := jws.Sign([]byte(payload1), jwa.ES384, key, jws.WithHeaders(headers2))
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
signedToken2 := string(signedTokenBytes2)
|
||
|
parsedHeaders2, err := getHeadersFromTokenString(signedToken2)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
require.Empty(t, parsedHeaders2.Type())
|
||
|
|
||
|
// Test with multiple signatures
|
||
|
payload2 := `{"foo":"bar"}`
|
||
|
|
||
|
signer1, err := jws.NewSigner(jwa.ES384)
|
||
|
require.NoError(t, err)
|
||
|
signer2, err := jws.NewSigner(jwa.ES384)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
signedTokenBytes3, err := jws.SignMulti([]byte(payload2), jws.WithSigner(signer1, key, nil, nil), jws.WithSigner(signer2, key, nil, nil))
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
signedToken3 := string(signedTokenBytes3)
|
||
|
|
||
|
_, err = getHeadersFromTokenString(signedToken3)
|
||
|
require.Error(t, err)
|
||
|
require.Equal(t, "more than one signature in token", err.Error())
|
||
|
|
||
|
// Test with non-token string
|
||
|
_, err = getHeadersFromTokenString("foo")
|
||
|
require.Error(t, err)
|
||
|
require.Contains(t, err.Error(), "unable to parse tokenString")
|
||
|
}
|
||
|
|
||
|
func TestGetKeyIDFromTokenString(t *testing.T) {
|
||
|
key, _ := testNewKey(t)
|
||
|
|
||
|
// Test with KeyID
|
||
|
token1 := jwt.New()
|
||
|
err := token1.Set("foo", "bar")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
headers1 := jws.NewHeaders()
|
||
|
|
||
|
signedTokenBytes1, err := jwt.Sign(token1, jwa.ES384, key, jwt.WithHeaders(headers1))
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
signedToken1 := string(signedTokenBytes1)
|
||
|
keyID, err := getKeyIDFromTokenString(signedToken1)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
require.Equal(t, key.KeyID(), keyID)
|
||
|
|
||
|
// Test without KeyID
|
||
|
keyWithoutKeyID := key
|
||
|
err = keyWithoutKeyID.Remove(jwk.KeyIDKey)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
token2 := jwt.New()
|
||
|
err = token2.Set("foo", "bar")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
headers2 := jws.NewHeaders()
|
||
|
|
||
|
signedTokenBytes2, err := jwt.Sign(token2, jwa.ES384, keyWithoutKeyID, jwt.WithHeaders(headers2))
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
signedToken2 := string(signedTokenBytes2)
|
||
|
_, err = getKeyIDFromTokenString(signedToken2)
|
||
|
require.Error(t, err)
|
||
|
require.Equal(t, "token header does not contain key id (kid)", err.Error())
|
||
|
|
||
|
// Test with non-token string
|
||
|
_, err = getKeyIDFromTokenString("foo")
|
||
|
require.Error(t, err)
|
||
|
require.Contains(t, err.Error(), "unable to parse tokenString")
|
||
|
}
|
||
|
|
||
|
func TestGetTokenTypeFromTokenString(t *testing.T) {
|
||
|
key, _ := testNewKey(t)
|
||
|
|
||
|
// Test with Type
|
||
|
token1 := jwt.New()
|
||
|
err := token1.Set("foo", "bar")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
headers1 := jws.NewHeaders()
|
||
|
err = headers1.Set(jws.TypeKey, "foo")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
signedTokenBytes1, err := jwt.Sign(token1, jwa.ES384, key, jwt.WithHeaders(headers1))
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
signedToken1 := string(signedTokenBytes1)
|
||
|
tokenType, err := getTokenTypeFromTokenString(signedToken1)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
require.Equal(t, headers1.Type(), tokenType)
|
||
|
|
||
|
// Test without KeyID
|
||
|
payload1 := `{"foo":"bar"}`
|
||
|
|
||
|
signer1, err := jws.NewSigner(jwa.ES384)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
signedTokenBytes2, err := jws.SignMulti([]byte(payload1), jws.WithSigner(signer1, key, nil, nil))
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
signedToken2 := string(signedTokenBytes2)
|
||
|
_, err = getTokenTypeFromTokenString(signedToken2)
|
||
|
require.Error(t, err)
|
||
|
require.Equal(t, "token header does not contain type (typ)", err.Error())
|
||
|
|
||
|
// Test with non-token string
|
||
|
_, err = getTokenTypeFromTokenString("foo")
|
||
|
require.Error(t, err)
|
||
|
require.Contains(t, err.Error(), "unable to parse tokenString")
|
||
|
}
|
||
|
|
||
|
func TestIsTokenAudienceValid(t *testing.T) {
|
||
|
cases := []struct {
|
||
|
testDescription string
|
||
|
requiredAudience string
|
||
|
tokenAudiences []string
|
||
|
expectedResult bool
|
||
|
}{
|
||
|
{
|
||
|
testDescription: "empty requiredAudience, empty tokenAudiences",
|
||
|
requiredAudience: "",
|
||
|
tokenAudiences: []string{},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "empty requiredAudience, one tokenAudiences",
|
||
|
requiredAudience: "",
|
||
|
tokenAudiences: []string{"foo"},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "empty requiredAudience, two tokenAudiences",
|
||
|
requiredAudience: "",
|
||
|
tokenAudiences: []string{"foo", "bar"},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "empty requiredAudience, three tokenAudiences",
|
||
|
requiredAudience: "",
|
||
|
tokenAudiences: []string{"foo", "bar", "baz"},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "one tokenAudiences, same as requiredAudience",
|
||
|
requiredAudience: "foo",
|
||
|
tokenAudiences: []string{"foo"},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "two tokenAudiences, first same as requiredAudience",
|
||
|
requiredAudience: "foo",
|
||
|
tokenAudiences: []string{"foo", "bar"},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "two tokenAudiences, second same as requiredAudience",
|
||
|
requiredAudience: "bar",
|
||
|
tokenAudiences: []string{"foo", "bar"},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "three tokenAudiences, third same as requiredAudience",
|
||
|
requiredAudience: "baz",
|
||
|
tokenAudiences: []string{"foo", "bar", "baz"},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "set requiredAudience, empty tokenAudiences",
|
||
|
requiredAudience: "foo",
|
||
|
tokenAudiences: []string{},
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "one tokenAudience, not same as requiredAudience",
|
||
|
requiredAudience: "foo",
|
||
|
tokenAudiences: []string{"bar"},
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "two tokenAudience, none same as requiredAudience",
|
||
|
requiredAudience: "foo",
|
||
|
tokenAudiences: []string{"bar", "baz"},
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "three tokenAudience, none same as requiredAudience",
|
||
|
requiredAudience: "foo",
|
||
|
tokenAudiences: []string{"bar", "baz", "foobar"},
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for i, c := range cases {
|
||
|
t.Logf("Test iteration %d: %s", i, c.testDescription)
|
||
|
result := isTokenAudienceValid(c.requiredAudience, c.tokenAudiences)
|
||
|
require.Equal(t, c.expectedResult, result)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestTokenExpirationValid(t *testing.T) {
|
||
|
cases := []struct {
|
||
|
testDescription string
|
||
|
expiration time.Time
|
||
|
allowedDrift time.Duration
|
||
|
expectedResult bool
|
||
|
}{
|
||
|
{
|
||
|
testDescription: "expires now, 50 millisecond drift allowed",
|
||
|
expiration: time.Now(),
|
||
|
allowedDrift: 50 * time.Millisecond,
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "expires now, 10 second drift allowed",
|
||
|
expiration: time.Now(),
|
||
|
allowedDrift: 10 * time.Second,
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "expires in one hour, 10 second drift allowed",
|
||
|
expiration: time.Now().Add(1 * time.Hour),
|
||
|
allowedDrift: 10 * time.Second,
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "expired 5 seconds ago, 10 second drift allowed",
|
||
|
expiration: time.Now().Add(-5 * time.Second),
|
||
|
allowedDrift: 10 * time.Second,
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "expired 11 seconds ago, 10 second drift allowed",
|
||
|
expiration: time.Now().Add(-11 * time.Second),
|
||
|
allowedDrift: 10 * time.Second,
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "expires now, no drift",
|
||
|
expiration: time.Now(),
|
||
|
allowedDrift: 0,
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "expired an hour ago, no drift",
|
||
|
expiration: time.Now().Add(-1 * time.Hour),
|
||
|
allowedDrift: 0,
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "expired an hour ago, 10 second drift",
|
||
|
expiration: time.Now().Add(-1 * time.Hour),
|
||
|
allowedDrift: 10 * time.Second,
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for i, c := range cases {
|
||
|
t.Logf("Test iteration %d: %s", i, c.testDescription)
|
||
|
result := isTokenExpirationValid(c.expiration, c.allowedDrift)
|
||
|
require.Equal(t, c.expectedResult, result)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestIsTokenIssuerValid(t *testing.T) {
|
||
|
cases := []struct {
|
||
|
testDescription string
|
||
|
requiredIssuer string
|
||
|
tokenIssuer string
|
||
|
expectedResult bool
|
||
|
}{
|
||
|
{
|
||
|
testDescription: "both requiredIssuer and tokenIssuer are the same",
|
||
|
requiredIssuer: "foo",
|
||
|
tokenIssuer: "foo",
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "requiredIssuer and tokenIssuer are not the same",
|
||
|
requiredIssuer: "foo",
|
||
|
tokenIssuer: "bar",
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "both requiredIssuer and tokenIssuer are empty",
|
||
|
requiredIssuer: "",
|
||
|
tokenIssuer: "",
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "requiredIssuer is empty and tokenIssuer is set",
|
||
|
requiredIssuer: "",
|
||
|
tokenIssuer: "foo",
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "requiredIssuer is set and tokenIssuer is empty",
|
||
|
requiredIssuer: "foo",
|
||
|
tokenIssuer: "",
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for i, c := range cases {
|
||
|
t.Logf("Test iteration %d: %s", i, c.testDescription)
|
||
|
result := isTokenIssuerValid(c.requiredIssuer, c.tokenIssuer)
|
||
|
require.Equal(t, c.expectedResult, result)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestIsTokenTypeValid(t *testing.T) {
|
||
|
cases := []struct {
|
||
|
testDescription string
|
||
|
requiredTokenType string
|
||
|
tokenType string
|
||
|
expectedResult bool
|
||
|
}{
|
||
|
{
|
||
|
testDescription: "both requiredTokenType and tokenType are empty",
|
||
|
requiredTokenType: "",
|
||
|
tokenType: "",
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "requiredTokenType is empty and tokenType is set",
|
||
|
requiredTokenType: "",
|
||
|
tokenType: "foo",
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "both requiredTokenType and tokenType are set to the same",
|
||
|
requiredTokenType: "foo",
|
||
|
tokenType: "foo",
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "requiredTokenType and tokenType are set to different",
|
||
|
requiredTokenType: "foo",
|
||
|
tokenType: "bar",
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "requiredTokenType and tokenType are set to different but tokenType contains requiredTokenType",
|
||
|
requiredTokenType: "foo",
|
||
|
tokenType: "foobar",
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for i, c := range cases {
|
||
|
t.Logf("Test iteration %d: %s", i, c.testDescription)
|
||
|
|
||
|
key, _ := testNewKey(t)
|
||
|
payload := `{"foo":"bar"}`
|
||
|
|
||
|
signer, err := jws.NewSigner(jwa.ES384)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
var signedTokenBytes []byte
|
||
|
if c.tokenType == "" {
|
||
|
signedTokenBytes, err = jws.SignMulti([]byte(payload), jws.WithSigner(signer, key, nil, nil))
|
||
|
require.NoError(t, err)
|
||
|
} else {
|
||
|
headers := jws.NewHeaders()
|
||
|
err = headers.Set(jws.TypeKey, c.tokenType)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
signedTokenBytes, err = jws.SignMulti([]byte(payload), jws.WithSigner(signer, key, nil, headers))
|
||
|
require.NoError(t, err)
|
||
|
}
|
||
|
|
||
|
token := string(signedTokenBytes)
|
||
|
|
||
|
result := isTokenTypeValid(c.requiredTokenType, token)
|
||
|
require.Equal(t, c.expectedResult, result)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestGetAndValidateTokenFromString(t *testing.T) {
|
||
|
op := server.NewTesting(t)
|
||
|
defer op.Close(t)
|
||
|
|
||
|
issuer := op.GetURL(t)
|
||
|
discoveryUri := GetDiscoveryUriFromIssuer(issuer)
|
||
|
jwksUri, err := getJwksUriFromDiscoveryUri(http.DefaultClient, discoveryUri, 10*time.Millisecond)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
keyHandler, err := newKeyHandler(http.DefaultClient, jwksUri, 50*time.Millisecond, 100, false)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
validKey, ok := keyHandler.getKeySet().Get(0)
|
||
|
require.True(t, ok)
|
||
|
|
||
|
validAccessToken := op.GetToken(t).AccessToken
|
||
|
require.NotEmpty(t, validAccessToken)
|
||
|
|
||
|
validIDToken, ok := op.GetToken(t).Extra("id_token").(string)
|
||
|
require.True(t, ok)
|
||
|
require.NotEmpty(t, validIDToken)
|
||
|
|
||
|
invalidKey, invalidPubKey := testNewKey(t)
|
||
|
|
||
|
invalidToken := jwt.New()
|
||
|
err = invalidToken.Set("foo", "bar")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
invalidHeaders := jws.NewHeaders()
|
||
|
err = invalidHeaders.Set(jws.TypeKey, "JWT")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
invalidTokenBytes, err := jwt.Sign(invalidToken, jwa.ES384, invalidKey, jwt.WithHeaders(invalidHeaders))
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
invalidSignedToken := string(invalidTokenBytes)
|
||
|
|
||
|
cases := []struct {
|
||
|
testDescription string
|
||
|
tokenString string
|
||
|
key jwk.Key
|
||
|
expectedError bool
|
||
|
}{
|
||
|
{
|
||
|
testDescription: "valid access token, valid key",
|
||
|
tokenString: validAccessToken,
|
||
|
key: validKey,
|
||
|
expectedError: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "valid id token, valid key",
|
||
|
tokenString: validIDToken,
|
||
|
key: validKey,
|
||
|
expectedError: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "empty string, valid key",
|
||
|
tokenString: "",
|
||
|
key: validKey,
|
||
|
expectedError: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "random string, valid key",
|
||
|
tokenString: "foobar",
|
||
|
key: validKey,
|
||
|
expectedError: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "invalid token, valid key",
|
||
|
tokenString: invalidSignedToken,
|
||
|
key: validKey,
|
||
|
expectedError: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "invalid token, invalid key",
|
||
|
tokenString: invalidSignedToken,
|
||
|
key: invalidPubKey,
|
||
|
expectedError: false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for i, c := range cases {
|
||
|
t.Logf("Test iteration %d: %s", i, c.testDescription)
|
||
|
|
||
|
alg, err := getSignatureAlgorithm(c.key.KeyType(), c.key.Algorithm(), jwa.ES384)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
token, err := getAndValidateTokenFromString(c.tokenString, c.key, alg)
|
||
|
if c.expectedError {
|
||
|
require.Error(t, err)
|
||
|
} else {
|
||
|
require.NoError(t, err)
|
||
|
require.NotEmpty(t, token)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestParseToken(t *testing.T) {
|
||
|
keySets := testNewTestKeySet(t)
|
||
|
testServer := testNewJwksServer(t, keySets)
|
||
|
defer testServer.Close()
|
||
|
|
||
|
cases := []struct {
|
||
|
testDescription string
|
||
|
options []options.Option
|
||
|
numKeys int
|
||
|
customIssuer string
|
||
|
customExpirationMinutes int
|
||
|
customClaims map[string]string
|
||
|
expectedErrorContains string
|
||
|
}{
|
||
|
{
|
||
|
testDescription: "successful parse with keyID, one key",
|
||
|
options: []options.Option{
|
||
|
options.WithIssuer("http://foo.bar"),
|
||
|
options.WithDiscoveryUri("http://foo.bar"),
|
||
|
options.WithJwksUri(testServer.URL),
|
||
|
options.WithDisableKeyID(false),
|
||
|
options.WithJwksRateLimit(100),
|
||
|
},
|
||
|
numKeys: 1,
|
||
|
expectedErrorContains: "",
|
||
|
},
|
||
|
{
|
||
|
testDescription: "successful parse without keyID, one key",
|
||
|
options: []options.Option{
|
||
|
options.WithIssuer("http://foo.bar"),
|
||
|
options.WithDiscoveryUri("http://foo.bar"),
|
||
|
options.WithJwksUri(testServer.URL),
|
||
|
options.WithDisableKeyID(true),
|
||
|
options.WithJwksRateLimit(100),
|
||
|
},
|
||
|
numKeys: 1,
|
||
|
expectedErrorContains: "",
|
||
|
},
|
||
|
{
|
||
|
testDescription: "successful parse with keyID, two keys",
|
||
|
options: []options.Option{
|
||
|
options.WithIssuer("http://foo.bar"),
|
||
|
options.WithDiscoveryUri("http://foo.bar"),
|
||
|
options.WithJwksUri(testServer.URL),
|
||
|
options.WithDisableKeyID(false),
|
||
|
options.WithJwksRateLimit(100),
|
||
|
},
|
||
|
numKeys: 2,
|
||
|
expectedErrorContains: "",
|
||
|
},
|
||
|
{
|
||
|
// without lazyLoad, New() panics
|
||
|
testDescription: "unsuccessful parse without keyID, two keys with lazyLoad",
|
||
|
options: []options.Option{
|
||
|
options.WithIssuer("http://foo.bar"),
|
||
|
options.WithDiscoveryUri("http://foo.bar"),
|
||
|
options.WithJwksUri(testServer.URL),
|
||
|
options.WithDisableKeyID(true),
|
||
|
options.WithJwksRateLimit(100),
|
||
|
options.WithLazyLoadJwks(true),
|
||
|
},
|
||
|
numKeys: 2,
|
||
|
expectedErrorContains: "keyID is disabled, but received a keySet with more than one key",
|
||
|
},
|
||
|
{
|
||
|
testDescription: "wrong issuer, with keyID",
|
||
|
options: []options.Option{
|
||
|
options.WithIssuer("http://foo.bar"),
|
||
|
options.WithDiscoveryUri("http://foo.bar"),
|
||
|
options.WithJwksUri(testServer.URL),
|
||
|
options.WithDisableKeyID(false),
|
||
|
},
|
||
|
numKeys: 1,
|
||
|
customIssuer: "http://wrong.issuer",
|
||
|
expectedErrorContains: "required issuer \"http://foo.bar\" was not found",
|
||
|
},
|
||
|
{
|
||
|
testDescription: "wrong issuer, without keyID",
|
||
|
options: []options.Option{
|
||
|
options.WithIssuer("http://foo.bar"),
|
||
|
options.WithDiscoveryUri("http://foo.bar"),
|
||
|
options.WithJwksUri(testServer.URL),
|
||
|
options.WithDisableKeyID(true),
|
||
|
},
|
||
|
numKeys: 1,
|
||
|
customIssuer: "http://wrong.issuer",
|
||
|
expectedErrorContains: "required issuer \"http://foo.bar\" was not found",
|
||
|
},
|
||
|
{
|
||
|
testDescription: "expired token, with keyID",
|
||
|
options: []options.Option{
|
||
|
options.WithIssuer("http://foo.bar"),
|
||
|
options.WithDiscoveryUri("http://foo.bar"),
|
||
|
options.WithJwksUri(testServer.URL),
|
||
|
options.WithDisableKeyID(false),
|
||
|
},
|
||
|
numKeys: 1,
|
||
|
customExpirationMinutes: -1,
|
||
|
expectedErrorContains: "token has expired",
|
||
|
},
|
||
|
{
|
||
|
testDescription: "expired token, without keyID",
|
||
|
options: []options.Option{
|
||
|
options.WithIssuer("http://foo.bar"),
|
||
|
options.WithDiscoveryUri("http://foo.bar"),
|
||
|
options.WithJwksUri(testServer.URL),
|
||
|
options.WithDisableKeyID(true),
|
||
|
},
|
||
|
numKeys: 1,
|
||
|
customExpirationMinutes: -1,
|
||
|
expectedErrorContains: "token has expired",
|
||
|
},
|
||
|
{
|
||
|
testDescription: "correct requiredClaim",
|
||
|
options: []options.Option{
|
||
|
options.WithIssuer("http://foo.bar"),
|
||
|
options.WithDiscoveryUri("http://foo.bar"),
|
||
|
options.WithJwksUri(testServer.URL),
|
||
|
options.WithRequiredClaims(map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
}),
|
||
|
options.WithDisableKeyID(false),
|
||
|
},
|
||
|
numKeys: 1,
|
||
|
expectedErrorContains: "",
|
||
|
},
|
||
|
{
|
||
|
testDescription: "correct requiredClaim",
|
||
|
options: []options.Option{
|
||
|
options.WithIssuer("http://foo.bar"),
|
||
|
options.WithDiscoveryUri("http://foo.bar"),
|
||
|
options.WithJwksUri(testServer.URL),
|
||
|
options.WithRequiredClaims(map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
}),
|
||
|
options.WithDisableKeyID(false),
|
||
|
},
|
||
|
numKeys: 1,
|
||
|
customClaims: map[string]string{
|
||
|
"foo": "baz",
|
||
|
},
|
||
|
expectedErrorContains: "unable to validate required claims",
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for i, c := range cases {
|
||
|
t.Logf("Test iteration %d: %s", i, c.testDescription)
|
||
|
|
||
|
opts := &options.Options{}
|
||
|
|
||
|
for _, setter := range c.options {
|
||
|
setter(opts)
|
||
|
}
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, c.numKeys, opts.DisableKeyID))
|
||
|
|
||
|
h, err := NewHandler(c.options...)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
parseTokenFunc := h.ParseToken
|
||
|
|
||
|
issuer := opts.Issuer
|
||
|
if c.customIssuer != "" {
|
||
|
issuer = c.customIssuer
|
||
|
}
|
||
|
|
||
|
expirationMinutes := 1
|
||
|
if c.customExpirationMinutes != 0 {
|
||
|
expirationMinutes = c.customExpirationMinutes
|
||
|
}
|
||
|
|
||
|
customClaims := make(map[string]string)
|
||
|
customClaims["foo"] = "bar"
|
||
|
if c.customClaims != nil {
|
||
|
customClaims = c.customClaims
|
||
|
}
|
||
|
|
||
|
token := testNewCustomTokenString(t, keySets.privateKeySet, issuer, expirationMinutes, customClaims)
|
||
|
|
||
|
ctx := context.Background()
|
||
|
|
||
|
_, err = parseTokenFunc(ctx, token)
|
||
|
|
||
|
if c.expectedErrorContains == "" {
|
||
|
require.NoError(t, err)
|
||
|
} else {
|
||
|
require.Contains(t, err.Error(), c.expectedErrorContains)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestParseTokenWithKeyID(t *testing.T) {
|
||
|
disableKeyID := false
|
||
|
keySets := testNewTestKeySet(t)
|
||
|
testServer := testNewJwksServer(t, keySets)
|
||
|
defer testServer.Close()
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
opts := []options.Option{
|
||
|
options.WithIssuer("http://foo.bar"),
|
||
|
options.WithDiscoveryUri("http://foo.bar"),
|
||
|
options.WithJwksUri(testServer.URL),
|
||
|
options.WithDisableKeyID(disableKeyID),
|
||
|
options.WithJwksRateLimit(100),
|
||
|
}
|
||
|
|
||
|
h, err := NewHandler(opts...)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
parseTokenFunc := h.ParseToken
|
||
|
|
||
|
// first token should succeed
|
||
|
token1 := testNewTokenString(t, keySets.privateKeySet)
|
||
|
|
||
|
ctx := context.Background()
|
||
|
|
||
|
_, err = parseTokenFunc(ctx, token1)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// second token should succeed, rotation successful
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
token2 := testNewTokenString(t, keySets.privateKeySet)
|
||
|
|
||
|
_, err = parseTokenFunc(ctx, token2)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// after rotation, first token should fail
|
||
|
_, err = parseTokenFunc(ctx, token1)
|
||
|
require.Error(t, err)
|
||
|
|
||
|
// third token should succeed with two keys
|
||
|
keySets.setKeys(testNewKeySet(t, 2, disableKeyID))
|
||
|
|
||
|
token3 := testNewTokenString(t, keySets.privateKeySet)
|
||
|
|
||
|
_, err = parseTokenFunc(ctx, token3)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// fourth token should fail since they token doesn't contain keyID
|
||
|
keySets.setKeys(testNewKeySet(t, 1, true))
|
||
|
|
||
|
token4 := testNewTokenString(t, keySets.privateKeySet)
|
||
|
|
||
|
_, err = parseTokenFunc(ctx, token4)
|
||
|
require.Error(t, err)
|
||
|
|
||
|
// fifth token should fail since it's the wrong key but correct keyID
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
currentPrivateKey, found := keySets.privateKeySet.Get(0)
|
||
|
require.True(t, found)
|
||
|
|
||
|
currentKeyID := currentPrivateKey.KeyID()
|
||
|
invalidPrivKey, _ := testNewKey(t)
|
||
|
|
||
|
err = invalidPrivKey.Set(jwk.KeyIDKey, currentKeyID)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
invalidKeySet := jwk.NewSet()
|
||
|
invalidKeySet.Add(invalidPrivKey)
|
||
|
|
||
|
token5 := testNewTokenString(t, invalidKeySet)
|
||
|
|
||
|
_, err = parseTokenFunc(ctx, token5)
|
||
|
require.ErrorIs(t, err, errSignatureVerification)
|
||
|
|
||
|
// sixth token should fail since the jwks can't be refreshed
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
token6 := testNewTokenString(t, keySets.privateKeySet)
|
||
|
|
||
|
testServer.Close()
|
||
|
|
||
|
_, err = parseTokenFunc(ctx, token6)
|
||
|
require.Error(t, err)
|
||
|
}
|
||
|
|
||
|
func TestParseTokenWithoutKeyID(t *testing.T) {
|
||
|
disableKeyID := true
|
||
|
keySets := testNewTestKeySet(t)
|
||
|
testServer := testNewJwksServer(t, keySets)
|
||
|
defer testServer.Close()
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
opts := []options.Option{
|
||
|
options.WithIssuer("http://foo.bar"),
|
||
|
options.WithDiscoveryUri("http://foo.bar"),
|
||
|
options.WithJwksUri(testServer.URL),
|
||
|
options.WithDisableKeyID(disableKeyID),
|
||
|
options.WithJwksRateLimit(100),
|
||
|
}
|
||
|
|
||
|
h, err := NewHandler(opts...)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
parseTokenFunc := h.ParseToken
|
||
|
|
||
|
// first token should succeed
|
||
|
token1 := testNewTokenString(t, keySets.privateKeySet)
|
||
|
|
||
|
ctx := context.Background()
|
||
|
|
||
|
_, err = parseTokenFunc(ctx, token1)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// second token should succeed, with key rotation
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
token2 := testNewTokenString(t, keySets.privateKeySet)
|
||
|
|
||
|
_, err = parseTokenFunc(ctx, token2)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
// after rotation, first token should fail
|
||
|
_, err = parseTokenFunc(ctx, token1)
|
||
|
require.Error(t, err)
|
||
|
|
||
|
// third token should fail since there are two keys present
|
||
|
keySets.setKeys(testNewKeySet(t, 2, disableKeyID))
|
||
|
|
||
|
token3 := testNewTokenString(t, keySets.privateKeySet)
|
||
|
|
||
|
_, err = parseTokenFunc(ctx, token3)
|
||
|
require.Error(t, err)
|
||
|
|
||
|
// fourth token should fail since the jwks can't be refreshed
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
token4 := testNewTokenString(t, keySets.privateKeySet)
|
||
|
|
||
|
testServer.Close()
|
||
|
|
||
|
_, err = parseTokenFunc(ctx, token4)
|
||
|
require.Error(t, err)
|
||
|
}
|
||
|
|
||
|
func TestGetAndValidateTokenFromStringWithKeyID(t *testing.T) {
|
||
|
disableKeyID := false
|
||
|
keySets := testNewTestKeySet(t)
|
||
|
testServer := testNewJwksServer(t, keySets)
|
||
|
defer testServer.Close()
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
keyHandler, err := newKeyHandler(http.DefaultClient, testServer.URL, 10*time.Millisecond, 100, disableKeyID)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
token1 := testNewTokenString(t, keySets.privateKeySet)
|
||
|
|
||
|
keyID, err := getKeyIDFromTokenString(token1)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
pubKey, err := keyHandler.getKey(context.Background(), keyID)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
alg, err := getSignatureAlgorithm(pubKey.KeyType(), pubKey.Algorithm(), jwa.ES384)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
_, err = getAndValidateTokenFromString(token1, pubKey, alg)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
token2 := testNewTokenString(t, keySets.privateKeySet)
|
||
|
|
||
|
_, err = getAndValidateTokenFromString(token2, pubKey, alg)
|
||
|
require.Error(t, err)
|
||
|
}
|
||
|
|
||
|
func TestGetAndValidateTokenFromStringWithoutKeyID(t *testing.T) {
|
||
|
disableKeyID := true
|
||
|
keySets := testNewTestKeySet(t)
|
||
|
testServer := testNewJwksServer(t, keySets)
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
keyHandler, err := newKeyHandler(http.DefaultClient, testServer.URL, 10*time.Millisecond, 100, disableKeyID)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
token1 := testNewTokenString(t, keySets.privateKeySet)
|
||
|
|
||
|
pubKey, err := keyHandler.getKey(context.Background(), "")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
alg, err := getSignatureAlgorithm(pubKey.KeyType(), pubKey.Algorithm(), jwa.ES384)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
_, err = getAndValidateTokenFromString(token1, pubKey, alg)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
keySets.setKeys(testNewKeySet(t, 1, disableKeyID))
|
||
|
|
||
|
token2 := testNewTokenString(t, keySets.privateKeySet)
|
||
|
|
||
|
_, err = getAndValidateTokenFromString(token2, pubKey, alg)
|
||
|
require.ErrorIs(t, err, errSignatureVerification)
|
||
|
}
|
||
|
|
||
|
func TestIsRequiredClaimsValid(t *testing.T) {
|
||
|
cases := []struct {
|
||
|
testDescription string
|
||
|
requiredClaims map[string]interface{}
|
||
|
tokenClaims map[string]interface{}
|
||
|
expectedResult bool
|
||
|
}{
|
||
|
{
|
||
|
testDescription: "both are nil",
|
||
|
requiredClaims: nil,
|
||
|
tokenClaims: nil,
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "both are empty",
|
||
|
requiredClaims: map[string]interface{}{},
|
||
|
tokenClaims: map[string]interface{}{},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "required claims are nil",
|
||
|
requiredClaims: nil,
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "required claims are empty",
|
||
|
requiredClaims: map[string]interface{}{},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "token claims are nil",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
},
|
||
|
tokenClaims: nil,
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "token claims are empty",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{},
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "required is string, token is int",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": 1337,
|
||
|
},
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "matching with string",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "matching with string and int",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "matching with string and int in different orders",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"bar": 1337,
|
||
|
"foo": "bar",
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "matching with string, int and float",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
"baz": 13.37,
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
"baz": 13.37,
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "not matching with string, int and float",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
"baz": 13.37,
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
"baz": 12.27,
|
||
|
},
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "matching slice",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
"baz": []string{"foo"},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
"baz": []string{"foo"},
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "matching slice with multiple values",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"oof": []string{"foo", "bar"},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"oof": []string{"foo", "bar", "baz"},
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "required slice contains in token slice",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
"baz": []string{"foo"},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
"baz": []string{"foo", "bar", "baz"},
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "not matching slice",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
"baz": []string{"foo"},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
"baz": []string{"bar"},
|
||
|
},
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "matching map",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": map[string]string{
|
||
|
"foo": "bar",
|
||
|
},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": map[string]string{
|
||
|
"foo": "bar",
|
||
|
},
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "matching map with multiple values",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": map[string]string{
|
||
|
"foo": "bar",
|
||
|
"bar": "foo",
|
||
|
},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": map[string]string{
|
||
|
"a": "b",
|
||
|
"foo": "bar",
|
||
|
"bar": "foo",
|
||
|
"c": "d",
|
||
|
},
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "matching map with multiple keys in token claims",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": map[string]string{
|
||
|
"foo": "bar",
|
||
|
},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": map[string]string{
|
||
|
"a": "b",
|
||
|
"foo": "bar",
|
||
|
"c": "d",
|
||
|
},
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "not matching map",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": map[string]string{
|
||
|
"foo": "bar",
|
||
|
},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": map[string]int{
|
||
|
"foo": 1337,
|
||
|
},
|
||
|
},
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "matching map with string slice",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": map[string][]string{
|
||
|
"foo": {"bar"},
|
||
|
},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": map[string][]string{
|
||
|
"foo": {"foo", "bar", "baz"},
|
||
|
},
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "not matching map with string slice",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": map[string][]string{
|
||
|
"foo": {"foobar"},
|
||
|
},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": map[string][]string{
|
||
|
"foo": {"foo", "bar", "baz"},
|
||
|
},
|
||
|
},
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "matching slice with map",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": []map[string]string{
|
||
|
{"bar": "baz"},
|
||
|
},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": []map[string]string{
|
||
|
{"bar": "baz"},
|
||
|
},
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "not matching slice with map",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": []map[string]string{
|
||
|
{"bar": "foobar"},
|
||
|
},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": []map[string]string{
|
||
|
{"bar": "baz"},
|
||
|
},
|
||
|
},
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "matching primitive types, slice and map",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
"baz": []string{"foo"},
|
||
|
"oof": []map[string]string{
|
||
|
{"bar": "baz"},
|
||
|
},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
"baz": []string{"foo"},
|
||
|
"oof": []map[string]string{
|
||
|
{"bar": "baz"},
|
||
|
},
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "matching primitive types, slice and map where token contains multiple values",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
"baz": []string{"bar"},
|
||
|
"oof": []map[string]string{
|
||
|
{"bar": "baz"},
|
||
|
},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"bar": 1337,
|
||
|
"baz": []string{"foo", "bar", "baz"},
|
||
|
"oof": []map[string]string{
|
||
|
{"a": "b"},
|
||
|
{"bar": "baz"},
|
||
|
{"c": "d"},
|
||
|
},
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "valid interface list in an interface map",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": map[string][]string{
|
||
|
"bar": {"baz"},
|
||
|
},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": map[string]interface{}{
|
||
|
"bar": []interface{}{
|
||
|
"uno",
|
||
|
"dos",
|
||
|
"baz",
|
||
|
"tres",
|
||
|
},
|
||
|
},
|
||
|
},
|
||
|
expectedResult: true,
|
||
|
},
|
||
|
{
|
||
|
testDescription: "invalid interface list in an interface map",
|
||
|
requiredClaims: map[string]interface{}{
|
||
|
"foo": map[string][]string{
|
||
|
"bar": {"baz"},
|
||
|
},
|
||
|
},
|
||
|
tokenClaims: map[string]interface{}{
|
||
|
"foo": map[string]interface{}{
|
||
|
"bar": []interface{}{
|
||
|
"uno",
|
||
|
"dos",
|
||
|
"tres",
|
||
|
},
|
||
|
},
|
||
|
},
|
||
|
expectedResult: false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for i, c := range cases {
|
||
|
t.Logf("Test iteration %d: %s", i, c.testDescription)
|
||
|
|
||
|
err := isRequiredClaimsValid(c.requiredClaims, c.tokenClaims)
|
||
|
|
||
|
if c.expectedResult {
|
||
|
require.NoError(t, err)
|
||
|
} else {
|
||
|
require.Error(t, err)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestGetSignatureAlgorithm(t *testing.T) {
|
||
|
cases := []struct {
|
||
|
inputKty jwa.KeyType
|
||
|
inputAlg string
|
||
|
inputFallbackAlg jwa.SignatureAlgorithm
|
||
|
expectedResult jwa.SignatureAlgorithm
|
||
|
expectedError bool
|
||
|
}{
|
||
|
{
|
||
|
inputKty: jwa.RSA,
|
||
|
inputAlg: "RS256",
|
||
|
inputFallbackAlg: "",
|
||
|
expectedResult: jwa.RS256,
|
||
|
expectedError: false,
|
||
|
},
|
||
|
{
|
||
|
inputKty: jwa.EC,
|
||
|
inputAlg: "ES256",
|
||
|
inputFallbackAlg: "",
|
||
|
expectedResult: jwa.ES256,
|
||
|
expectedError: false,
|
||
|
},
|
||
|
{
|
||
|
inputKty: jwa.RSA,
|
||
|
inputAlg: "",
|
||
|
inputFallbackAlg: "",
|
||
|
expectedResult: jwa.RS256,
|
||
|
expectedError: false,
|
||
|
},
|
||
|
{
|
||
|
inputKty: jwa.EC,
|
||
|
inputAlg: "",
|
||
|
inputFallbackAlg: "",
|
||
|
expectedResult: jwa.ES256,
|
||
|
expectedError: false,
|
||
|
},
|
||
|
{
|
||
|
inputKty: "",
|
||
|
inputAlg: "",
|
||
|
inputFallbackAlg: "",
|
||
|
expectedResult: "",
|
||
|
expectedError: true,
|
||
|
},
|
||
|
{
|
||
|
inputKty: "",
|
||
|
inputAlg: "foobar",
|
||
|
inputFallbackAlg: "",
|
||
|
expectedResult: "",
|
||
|
expectedError: true,
|
||
|
},
|
||
|
{
|
||
|
inputKty: "",
|
||
|
inputAlg: "",
|
||
|
inputFallbackAlg: jwa.ES384,
|
||
|
expectedResult: jwa.ES384,
|
||
|
expectedError: false,
|
||
|
},
|
||
|
{
|
||
|
inputKty: jwa.RSA,
|
||
|
inputAlg: "",
|
||
|
inputFallbackAlg: jwa.ES384,
|
||
|
expectedResult: jwa.ES384,
|
||
|
expectedError: false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for i, c := range cases {
|
||
|
t.Logf("Test iteration %d: inputKty=%s, inputAlg=%s, inputFallbackAlg=%s", i, c.inputKty, c.inputAlg, c.inputFallbackAlg)
|
||
|
|
||
|
result, err := getSignatureAlgorithm(c.inputKty, c.inputAlg, c.inputFallbackAlg)
|
||
|
require.Equal(t, c.expectedResult, result)
|
||
|
|
||
|
if !c.expectedError {
|
||
|
require.NoError(t, err)
|
||
|
} else {
|
||
|
require.Error(t, err)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func testNewKey(tb testing.TB) (jwk.Key, jwk.Key) {
|
||
|
tb.Helper()
|
||
|
|
||
|
ecdsaKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||
|
require.NoError(tb, err)
|
||
|
|
||
|
key, err := jwk.New(ecdsaKey)
|
||
|
require.NoError(tb, err)
|
||
|
|
||
|
_, ok := key.(jwk.ECDSAPrivateKey)
|
||
|
require.True(tb, ok)
|
||
|
|
||
|
thumbprint, err := key.Thumbprint(crypto.SHA256)
|
||
|
require.NoError(tb, err)
|
||
|
|
||
|
keyID := fmt.Sprintf("%x", thumbprint)
|
||
|
err = key.Set(jwk.KeyIDKey, keyID)
|
||
|
require.NoError(tb, err)
|
||
|
|
||
|
pubKey, err := jwk.New(ecdsaKey.PublicKey)
|
||
|
require.NoError(tb, err)
|
||
|
|
||
|
_, ok = pubKey.(jwk.ECDSAPublicKey)
|
||
|
require.True(tb, ok)
|
||
|
|
||
|
err = pubKey.Set(jwk.KeyIDKey, keyID)
|
||
|
require.NoError(tb, err)
|
||
|
|
||
|
err = pubKey.Set(jwk.AlgorithmKey, jwa.ES384)
|
||
|
require.NoError(tb, err)
|
||
|
|
||
|
return key, pubKey
|
||
|
}
|
||
|
|
||
|
func testNewTokenString(t *testing.T, privKeySet jwk.Set) string {
|
||
|
t.Helper()
|
||
|
|
||
|
jwtToken := jwt.New()
|
||
|
err := jwtToken.Set(jwt.IssuerKey, "http://foo.bar")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
err = jwtToken.Set(jwt.ExpirationKey, time.Now().Add(1*time.Minute).Unix())
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
err = jwtToken.Set("foo", "bar")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
headers := jws.NewHeaders()
|
||
|
err = headers.Set(jws.TypeKey, "JWT")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
privKey, found := privKeySet.Get(0)
|
||
|
require.True(t, found)
|
||
|
|
||
|
tokenBytes, err := jwt.Sign(jwtToken, jwa.ES384, privKey, jwt.WithHeaders(headers))
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
return string(tokenBytes)
|
||
|
}
|
||
|
|
||
|
func testNewCustomTokenString(t *testing.T, privKeySet jwk.Set, issuer string, expirationMinutes int, customClaims map[string]string) string {
|
||
|
t.Helper()
|
||
|
|
||
|
jwtToken := jwt.New()
|
||
|
err := jwtToken.Set(jwt.IssuerKey, issuer)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
err = jwtToken.Set(jwt.ExpirationKey, time.Now().Add(time.Duration(expirationMinutes)*time.Minute).Unix())
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
for k, v := range customClaims {
|
||
|
err := jwtToken.Set(k, v)
|
||
|
require.NoError(t, err)
|
||
|
}
|
||
|
|
||
|
headers := jws.NewHeaders()
|
||
|
|
||
|
err = headers.Set(jws.TypeKey, "JWT")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
privKey, found := privKeySet.Get(0)
|
||
|
require.True(t, found)
|
||
|
|
||
|
tokenBytes, err := jwt.Sign(jwtToken, jwa.ES384, privKey, jwt.WithHeaders(headers))
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
return string(tokenBytes)
|
||
|
}
|