diff --git a/pkg/definition/ast/transformer.go b/pkg/definition/ast/transformer.go index 987b022eb..8e064bbd6 100644 --- a/pkg/definition/ast/transformer.go +++ b/pkg/definition/ast/transformer.go @@ -21,7 +21,6 @@ import ( "strings" "cuelang.org/go/cue/ast" - "cuelang.org/go/cue/parser" "cuelang.org/go/cue/token" ) @@ -111,14 +110,15 @@ func unmarshalField[T ast.Node](field *ast.Field, key string, validator func(T) } unquoted := strings.TrimSpace(TrimCueRawString(basicLit.Value)) - expr, err := parser.ParseExpr("-", WrapCueStruct(unquoted)) - if err != nil { - return fmt.Errorf("unexpected error re-parsing validated %s string: %w", key, err) + + structLit, hasImports, hasPackage, parseErr := ParseCueContent(unquoted) + if parseErr != nil { + return fmt.Errorf("unexpected error re-parsing validated %s string: %w", key, parseErr) } - structLit, ok := expr.(*ast.StructLit) - if !ok { - return fmt.Errorf("expected struct after validation in field %s", key) + if hasImports || hasPackage { + // Keep as string literal to preserve imports/package + return nil } statusField.Value = structLit diff --git a/pkg/definition/ast/transformer_test.go b/pkg/definition/ast/transformer_test.go index 3c99b457c..30e6ede3f 100644 --- a/pkg/definition/ast/transformer_test.go +++ b/pkg/definition/ast/transformer_test.go @@ -22,6 +22,7 @@ import ( "cuelang.org/go/cue/ast" "cuelang.org/go/cue/format" "cuelang.org/go/cue/parser" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -176,6 +177,50 @@ func TestMarshalAndUnmarshalMetadata(t *testing.T) { `, expectContains: "$local", }, + { + name: "status details with import statement should work", + input: ` + attributes: { + status: { + details: #""" + import "strconv" + replicas: strconv.Atoi(context.output.status.replicas) + """# + } + } + `, + expectContains: "import \"strconv\"", + }, + { + name: "status details with package declaration", + input: ` + attributes: { + status: { + details: #""" + package status + + ready: true + phase: "Running" + """# + } + } + `, + expectContains: "package status", + }, + { + name: "status details with import cannot bypass validation", + input: ` + attributes: { + status: { + details: #""" + import "strings" + data: { nested: "structure" } + """# + } + } + `, + expectMarshalErr: "unsupported expression type", + }, } for _, tt := range tests { @@ -379,6 +424,21 @@ func TestMarshalAndUnmarshalHealthPolicy(t *testing.T) { `, expectContains: "isHealth", }, + { + name: "healthPolicy with package declaration", + input: ` + attributes: { + status: { + healthPolicy: #""" + package health + + isHealth: context.output.status.phase == "Running" + """# + } + } + `, + expectContains: "package health", + }, } for _, tt := range tests { @@ -610,6 +670,96 @@ func TestMarshalAndUnmarshalCustomStatus(t *testing.T) { `, expectContains: "message", }, + { + name: "customStatus with import statement should work", + input: ` + attributes: { + status: { + customStatus: #""" + import "strings" + message: strings.Join(["hello", "world"], ",") + """# + } + } + `, + expectContains: "import \"strings\"", + }, + { + name: "customStatus with multiple imports", + input: ` + attributes: { + status: { + customStatus: #""" + import "strings" + import "strconv" + count: strconv.Atoi("42") + message: strings.Join(["Count", strconv.FormatInt(count, 10)], ": ") + """# + } + } + `, + expectContains: "import \"strconv\"", + }, + { + name: "customStatus with import alias", + input: ` + attributes: { + status: { + customStatus: #""" + import str "strings" + message: str.ToUpper(str.Join(["hello", "world"], " ")) + """# + } + } + `, + expectContains: "import str \"strings\"", + }, + { + name: "customStatus with package declaration", + input: ` + attributes: { + status: { + customStatus: #""" + package mytest + + message: "Package test" + """# + } + } + `, + expectContains: "package mytest", + }, + { + name: "customStatus with package and imports", + input: ` + attributes: { + status: { + customStatus: #""" + package mytest + + import "strings" + + message: strings.ToUpper("hello world") + """# + } + } + `, + expectContains: "package mytest", + }, + { + name: "customStatus with import still requires message field", + input: ` + attributes: { + status: { + customStatus: #""" + import "strings" + someOtherField: "value" + """# + } + } + `, + expectMarshalErr: "customStatus must contain a 'message' field", + }, } for _, tt := range tests { @@ -951,6 +1101,57 @@ func TestCustomStatusEdgeCases(t *testing.T) { } } +func TestMixedFieldsWithAndWithoutImports(t *testing.T) { + input := ` + attributes: { + status: { + healthPolicy: #""" + isHealth: context.output.status.phase == "Running" + """# + customStatus: #""" + import "strings" + message: strings.ToLower(context.output.status.phase) + """# + } + } + ` + + file, err := parser.ParseFile("-", input) + require.NoError(t, err) + + var rootField *ast.Field + for _, decl := range file.Decls { + if f, ok := decl.(*ast.Field); ok { + rootField = f + break + } + } + require.NotNil(t, rootField) + + // Encode (struct -> string) + err = EncodeMetadata(rootField) + require.NoError(t, err) + + // Decode (string -> struct/string based on imports) + err = DecodeMetadata(rootField) + require.NoError(t, err) + + // Check healthPolicy (no imports) - should be decoded to struct + healthField, ok := GetFieldByPath(rootField, "attributes.status.healthPolicy") + require.True(t, ok) + _, isStruct := healthField.Value.(*ast.StructLit) + assert.True(t, isStruct, "healthPolicy without imports should be decoded to struct") + + // Check customStatus (has imports) - should remain as string + customField, ok := GetFieldByPath(rootField, "attributes.status.customStatus") + require.True(t, ok) + basicLit, isString := customField.Value.(*ast.BasicLit) + assert.True(t, isString, "customStatus with imports should remain as string") + if isString { + assert.Contains(t, basicLit.Value, "import \"strings\"") + } +} + func TestBackwardCompatibility(t *testing.T) { tests := []struct { name string diff --git a/pkg/definition/ast/utils.go b/pkg/definition/ast/utils.go index ba683cc2b..1063a72cb 100644 --- a/pkg/definition/ast/utils.go +++ b/pkg/definition/ast/utils.go @@ -152,18 +152,15 @@ func ValidateCueStringLiteral[T ast.Node](lit *ast.BasicLit, validator func(T) e return nil } - wrapped := WrapCueStruct(raw) - - expr, err := parser.ParseExpr("-", wrapped) + structLit, _, _, err := ParseCueContent(raw) if err != nil { return fmt.Errorf("invalid cue content in string literal: %w", err) } - node, ok := expr.(T) + node, ok := ast.Node(structLit).(T) if !ok { return fmt.Errorf("parsed expression is not of expected type %T", *new(T)) } - return validator(node) } @@ -197,6 +194,36 @@ func WrapCueStruct(s string) string { return fmt.Sprintf("{\n%s\n}", s) } +// ParseCueContent parses CUE content and extracts struct fields, skipping imports/packages +func ParseCueContent(content string) (*ast.StructLit, bool, bool, error) { + if strings.TrimSpace(content) == "" { + return &ast.StructLit{Elts: []ast.Decl{}}, false, false, nil + } + + file, err := parser.ParseFile("-", content) + if err != nil { + return nil, false, false, err + } + + hasImports := len(file.Imports) > 0 + hasPackage := file.PackageName() != "" + + structLit := &ast.StructLit{ + Elts: []ast.Decl{}, + } + + for _, decl := range file.Decls { + switch decl.(type) { + case *ast.ImportDecl, *ast.Package: + // Skip imports and package declarations + default: + structLit.Elts = append(structLit.Elts, decl) + } + } + + return structLit, hasImports, hasPackage, nil +} + // FindAndValidateField searches for a field at the top level or within top-level if statements func FindAndValidateField(sl *ast.StructLit, fieldName string, validator fieldValidator) (found bool, err error) { // First check top-level fields