oidc/cty.go

192 lines
3.7 KiB
Go

package oidc
import (
"fmt"
"github.com/zclconf/go-cty/cty"
"github.com/zclconf/go-cty/cty/gocty"
)
func getCtyValueWithImpliedType(a interface{}) (cty.Value, error) {
if a == nil {
return cty.NilVal, fmt.Errorf("input is nil")
}
valueType, err := gocty.ImpliedType(a)
if err != nil {
return cty.NilVal, fmt.Errorf("unable to get cty.Type: %w", err)
}
return getCtyValueWithType(a, valueType)
}
func getCtyValueWithType(a interface{}, vt cty.Type) (cty.Value, error) {
if a == nil {
return cty.NilVal, fmt.Errorf("input value is nil")
}
if vt == cty.NilType {
return cty.NilVal, fmt.Errorf("input type is nil")
}
value, err := gocty.ToCtyValue(a, vt)
if err != nil {
// we should never receive this error
return cty.NilVal, fmt.Errorf("unable to get cty.Value: %w", err)
}
return value, nil
}
func getCtyValues(a interface{}, b interface{}) (cty.Value, cty.Value, error) {
first, err := getCtyValueWithImpliedType(a)
if err != nil {
return cty.NilVal, cty.NilVal, err
}
second, err := getCtyValueWithType(b, first.Type())
if err != nil {
return cty.NilVal, cty.NilVal, err
}
return first, second, nil
}
func isCtyPrimitiveValueValid(a cty.Value, b cty.Value) bool {
if !isCtyTypeSame(a, b) {
return false
}
if getCtyType(a) != primitiveCtyType {
return false
}
return a.Equals(b) == cty.True
}
func isCtyListValid(a cty.Value, b cty.Value) bool {
if !isCtyTypeSame(a, b) {
return false
}
if getCtyType(a) != listCtyType {
return false
}
listA := a.AsValueSlice()
listB := b.AsValueSlice()
for i := range listA {
if !ctyListContains(listB, listA[i]) {
return false
}
}
return true
}
func isCtyMapValid(a cty.Value, b cty.Value) bool {
if !isCtyTypeSame(a, b) {
return false
}
if getCtyType(a) != mapCtyType {
return false
}
mapA := a.AsValueMap()
mapB := b.AsValueMap()
for k := range mapA {
mapBValue, ok := mapB[k]
if !ok {
return false
}
err := isCtyValueValid(mapA[k], mapBValue)
if err != nil {
return false
}
}
return true
}
func ctyListContains(a []cty.Value, b cty.Value) bool {
for i := range a {
err := isCtyValueValid(a[i], b)
if err == nil {
return true
}
}
return false
}
func isCtyTypeSame(a cty.Value, b cty.Value) bool {
return a.Type().Equals(b.Type())
}
func isCtyValueValid(a cty.Value, b cty.Value) error {
if !isCtyTypeSame(a, b) {
return fmt.Errorf("should be type %s, was type: %s", a.Type().GoString(), b.Type().GoString())
}
switch getCtyType(a) {
case primitiveCtyType:
valid := isCtyPrimitiveValueValid(a, b)
if !valid {
return fmt.Errorf("should be %s, was: %s", a.GoString(), b.GoString())
}
case listCtyType:
valid := isCtyListValid(a, b)
if !valid {
return fmt.Errorf("should contain %s, received: %s", a.GoString(), b.GoString())
}
case mapCtyType:
valid := isCtyMapValid(a, b)
if !valid {
return fmt.Errorf("should contain %s, received: %s", a.GoString(), b.GoString())
}
default:
return fmt.Errorf("non-implemented type - should be %s, received: %s", a.GoString(), b.GoString())
}
return nil
}
type ctyType int
const (
unknownCtyType = iota
primitiveCtyType
listCtyType
mapCtyType
)
func getCtyType(a cty.Value) ctyType {
if a.Type().IsPrimitiveType() {
return primitiveCtyType
}
switch {
case a.Type().IsListType():
return listCtyType
// Adding the other cases to make it easier in the
// future to build logic for more types.
case a.Type().IsMapType():
return mapCtyType
case a.Type().IsSetType():
return unknownCtyType
case a.Type().IsObjectType():
return unknownCtyType
case a.Type().IsTupleType():
return unknownCtyType
case a.Type().IsCapsuleType():
return unknownCtyType
}
return unknownCtyType
}