192 lines
3.7 KiB
Go
192 lines
3.7 KiB
Go
|
package oidc
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
|
||
|
"github.com/zclconf/go-cty/cty"
|
||
|
"github.com/zclconf/go-cty/cty/gocty"
|
||
|
)
|
||
|
|
||
|
func getCtyValueWithImpliedType(a interface{}) (cty.Value, error) {
|
||
|
if a == nil {
|
||
|
return cty.NilVal, fmt.Errorf("input is nil")
|
||
|
}
|
||
|
|
||
|
valueType, err := gocty.ImpliedType(a)
|
||
|
if err != nil {
|
||
|
return cty.NilVal, fmt.Errorf("unable to get cty.Type: %w", err)
|
||
|
}
|
||
|
|
||
|
return getCtyValueWithType(a, valueType)
|
||
|
}
|
||
|
|
||
|
func getCtyValueWithType(a interface{}, vt cty.Type) (cty.Value, error) {
|
||
|
if a == nil {
|
||
|
return cty.NilVal, fmt.Errorf("input value is nil")
|
||
|
}
|
||
|
|
||
|
if vt == cty.NilType {
|
||
|
return cty.NilVal, fmt.Errorf("input type is nil")
|
||
|
}
|
||
|
|
||
|
value, err := gocty.ToCtyValue(a, vt)
|
||
|
if err != nil {
|
||
|
// we should never receive this error
|
||
|
return cty.NilVal, fmt.Errorf("unable to get cty.Value: %w", err)
|
||
|
}
|
||
|
|
||
|
return value, nil
|
||
|
}
|
||
|
|
||
|
func getCtyValues(a interface{}, b interface{}) (cty.Value, cty.Value, error) {
|
||
|
first, err := getCtyValueWithImpliedType(a)
|
||
|
if err != nil {
|
||
|
return cty.NilVal, cty.NilVal, err
|
||
|
}
|
||
|
|
||
|
second, err := getCtyValueWithType(b, first.Type())
|
||
|
if err != nil {
|
||
|
return cty.NilVal, cty.NilVal, err
|
||
|
}
|
||
|
|
||
|
return first, second, nil
|
||
|
}
|
||
|
|
||
|
func isCtyPrimitiveValueValid(a cty.Value, b cty.Value) bool {
|
||
|
if !isCtyTypeSame(a, b) {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
if getCtyType(a) != primitiveCtyType {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
return a.Equals(b) == cty.True
|
||
|
}
|
||
|
|
||
|
func isCtyListValid(a cty.Value, b cty.Value) bool {
|
||
|
if !isCtyTypeSame(a, b) {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
if getCtyType(a) != listCtyType {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
listA := a.AsValueSlice()
|
||
|
listB := b.AsValueSlice()
|
||
|
|
||
|
for i := range listA {
|
||
|
if !ctyListContains(listB, listA[i]) {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func isCtyMapValid(a cty.Value, b cty.Value) bool {
|
||
|
if !isCtyTypeSame(a, b) {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
if getCtyType(a) != mapCtyType {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
mapA := a.AsValueMap()
|
||
|
mapB := b.AsValueMap()
|
||
|
|
||
|
for k := range mapA {
|
||
|
mapBValue, ok := mapB[k]
|
||
|
if !ok {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
err := isCtyValueValid(mapA[k], mapBValue)
|
||
|
if err != nil {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func ctyListContains(a []cty.Value, b cty.Value) bool {
|
||
|
for i := range a {
|
||
|
err := isCtyValueValid(a[i], b)
|
||
|
if err == nil {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
func isCtyTypeSame(a cty.Value, b cty.Value) bool {
|
||
|
return a.Type().Equals(b.Type())
|
||
|
}
|
||
|
|
||
|
func isCtyValueValid(a cty.Value, b cty.Value) error {
|
||
|
if !isCtyTypeSame(a, b) {
|
||
|
return fmt.Errorf("should be type %s, was type: %s", a.Type().GoString(), b.Type().GoString())
|
||
|
}
|
||
|
|
||
|
switch getCtyType(a) {
|
||
|
case primitiveCtyType:
|
||
|
valid := isCtyPrimitiveValueValid(a, b)
|
||
|
if !valid {
|
||
|
return fmt.Errorf("should be %s, was: %s", a.GoString(), b.GoString())
|
||
|
}
|
||
|
case listCtyType:
|
||
|
valid := isCtyListValid(a, b)
|
||
|
if !valid {
|
||
|
return fmt.Errorf("should contain %s, received: %s", a.GoString(), b.GoString())
|
||
|
}
|
||
|
case mapCtyType:
|
||
|
valid := isCtyMapValid(a, b)
|
||
|
if !valid {
|
||
|
return fmt.Errorf("should contain %s, received: %s", a.GoString(), b.GoString())
|
||
|
}
|
||
|
default:
|
||
|
return fmt.Errorf("non-implemented type - should be %s, received: %s", a.GoString(), b.GoString())
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
type ctyType int
|
||
|
|
||
|
const (
|
||
|
unknownCtyType = iota
|
||
|
primitiveCtyType
|
||
|
listCtyType
|
||
|
mapCtyType
|
||
|
)
|
||
|
|
||
|
func getCtyType(a cty.Value) ctyType {
|
||
|
if a.Type().IsPrimitiveType() {
|
||
|
return primitiveCtyType
|
||
|
}
|
||
|
|
||
|
switch {
|
||
|
case a.Type().IsListType():
|
||
|
return listCtyType
|
||
|
|
||
|
// Adding the other cases to make it easier in the
|
||
|
// future to build logic for more types.
|
||
|
case a.Type().IsMapType():
|
||
|
return mapCtyType
|
||
|
case a.Type().IsSetType():
|
||
|
return unknownCtyType
|
||
|
case a.Type().IsObjectType():
|
||
|
return unknownCtyType
|
||
|
case a.Type().IsTupleType():
|
||
|
return unknownCtyType
|
||
|
case a.Type().IsCapsuleType():
|
||
|
return unknownCtyType
|
||
|
}
|
||
|
|
||
|
return unknownCtyType
|
||
|
}
|