Skip to content

Commit 3b3a438

Browse files
Separate unnest optimization from composer to capture type info (#1138)
* Separate unnest optimization from composer to capture type info * Simplify the variable tracking during unnest
1 parent 9855c70 commit 3b3a438

3 files changed

Lines changed: 114 additions & 48 deletions

File tree

policy/compiler_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ func TestRuleComposerUnnest(t *testing.T) {
7979
if normalize(unparsed) != normalize(tc.composed) {
8080
t.Errorf("cel.AstToString() got %s, wanted %s", unparsed, tc.composed)
8181
}
82+
if !ast.OutputType().IsEquivalentType(tc.outputType) {
83+
t.Errorf("ast.OutputType() got %v, wanted %v", ast.OutputType(), tc.outputType)
84+
}
8285
r.setup(t, env, ast)
8386
r.run(t)
8487
})

policy/composer.go

Lines changed: 102 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,21 @@ type RuleComposer struct {
7272
// Compose stitches together a set of expressions within a CompiledRule into a single CEL ast.
7373
func (c *RuleComposer) Compose(r *CompiledRule) (*cel.Ast, *cel.Issues) {
7474
ruleRoot, _ := c.env.Compile("true")
75-
opt := cel.NewStaticOptimizer(
76-
&ruleComposerImpl{
77-
rule: r,
78-
varIndices: []varIndex{},
79-
exprUnnestHeight: c.exprUnnestHeight,
80-
})
81-
return opt.Optimize(c.env, ruleRoot)
75+
composer := &ruleComposerImpl{
76+
rule: r,
77+
varIndices: []varIndex{},
78+
}
79+
opt := cel.NewStaticOptimizer(composer)
80+
ast, iss := opt.Optimize(c.env, ruleRoot)
81+
if iss.Err() != nil {
82+
return nil, iss
83+
}
84+
unnester := &ruleUnnesterImpl{
85+
varIndices: []varIndex{},
86+
exprUnnestHeight: c.exprUnnestHeight,
87+
}
88+
opt = cel.NewStaticOptimizer(unnester)
89+
return opt.Optimize(c.env, ast)
8290
}
8391

8492
type varIndex struct {
@@ -93,8 +101,6 @@ type ruleComposerImpl struct {
93101
rule *CompiledRule
94102
nextVarIndex int
95103
varIndices []varIndex
96-
97-
exprUnnestHeight int
98104
}
99105

100106
// Optimize implements an AST optimizer for CEL which composes an expression graph into a single
@@ -103,21 +109,16 @@ func (opt *ruleComposerImpl) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *as
103109
// The input to optimize is a dummy expression which is completely replaced according
104110
// to the configuration of the rule composition graph.
105111
ruleExpr := opt.optimizeRule(ctx, opt.rule)
106-
// If the rule is deeply nested, it may need to be unnested. This process may generate
107-
// additional variables that are included in the `sortedVariables` list.
108-
ruleExpr = opt.maybeUnnestRule(ctx, ruleExpr)
109112

110-
// Collect all variables associated with the rule expression.
111-
allVars := opt.sortedVariables()
112113
// If there were no variables, return the expression.
113-
if len(allVars) == 0 {
114+
if len(opt.varIndices) == 0 {
114115
return ctx.NewAST(ruleExpr)
115116
}
116117

117118
// Otherwise populate the cel.@block with the variable declarations and wrap the expression
118119
// in the block.
119-
varExprs := make([]ast.Expr, len(allVars))
120-
for i, vi := range allVars {
120+
varExprs := make([]ast.Expr, len(opt.varIndices))
121+
for i, vi := range opt.varIndices {
121122
varExprs[i] = vi.expr
122123
err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.celType))
123124
if err != nil {
@@ -197,15 +198,90 @@ func (opt *ruleComposerImpl) rewriteVariableName(ctx *cel.OptimizerContext) ast.
197198
})
198199
}
199200

200-
func (opt *ruleComposerImpl) maybeUnnestRule(ctx *cel.OptimizerContext, ruleExpr ast.Expr) ast.Expr {
201-
// Split the expr into local variables based on expression height
202-
ruleAST := ctx.NewAST(ruleExpr)
203-
ruleNav := ast.NavigateAST(ruleAST)
201+
// registerVariable creates an entry for a variable name within the cel.@block used to enumerate
202+
// variables within composed policy expression.
203+
func (opt *ruleComposerImpl) registerVariable(ctx *cel.OptimizerContext, v *CompiledVariable) {
204+
varName := fmt.Sprintf("variables.%s", v.Name())
205+
indexVar := fmt.Sprintf("@index%d", opt.nextVarIndex)
206+
varExpr := ctx.CopyASTAndMetadata(v.Expr().NativeRep())
207+
ast.PostOrderVisit(varExpr, opt.rewriteVariableName(ctx))
208+
vi := varIndex{
209+
index: opt.nextVarIndex,
210+
indexVar: indexVar,
211+
localVar: varName,
212+
expr: varExpr,
213+
celType: v.Declaration().Type()}
214+
opt.varIndices = append(opt.varIndices, vi)
215+
opt.nextVarIndex++
216+
}
217+
218+
type ruleUnnesterImpl struct {
219+
nextVarIndex int
220+
varIndices []varIndex
221+
exprUnnestHeight int
222+
}
223+
224+
func (opt *ruleUnnesterImpl) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST {
225+
// Since the optimizer is based on the original environment provided to the composer,
226+
// a second pass on the `cel.@block` will require a rebuilding of the cel environment
227+
ruleExpr := ast.NavigateAST(a)
228+
var varExprs []ast.Expr
229+
var varDecls []cel.EnvOption
230+
if ruleExpr.Kind() == ast.CallKind && ruleExpr.AsCall().FunctionName() == "cel.@block" {
231+
// Extract the expr from the cel.@block, args[1], as a navigable expr value.
232+
// Also extract the variable declarations and all associated types from the cel.@block as
233+
// varIndex values, but without doing any rewrites as the types are all correct already.
234+
block := ruleExpr.AsCall()
235+
ruleExpr = block.Args()[1].(ast.NavigableExpr)
236+
237+
// Collect the list of variables associated with the block
238+
blockList := block.Args()[0].(ast.NavigableExpr)
239+
vars := blockList.AsList()
240+
varExprs = make([]ast.Expr, vars.Size())
241+
varDecls = make([]cel.EnvOption, vars.Size())
242+
copy(varExprs, vars.Elements())
243+
for i, v := range varExprs {
244+
// Track the variable he varDecls set.
245+
indexVar := fmt.Sprintf("@index%d", i)
246+
celType := a.GetType(v.ID())
247+
varDecls[i] = cel.Variable(indexVar, celType)
248+
opt.nextVarIndex++
249+
}
250+
}
251+
if len(varDecls) != 0 {
252+
err := ctx.ExtendEnv(varDecls...)
253+
if err != nil {
254+
ctx.ReportErrorAtID(ruleExpr.ID(), "%s", err.Error())
255+
}
256+
}
257+
258+
// Attempt to unnest the rule.
259+
ruleExpr = opt.maybeUnnestRule(ctx, ruleExpr)
260+
// If there were no variables, return the expression.
261+
if len(opt.varIndices) == 0 {
262+
return a
263+
}
264+
265+
// Otherwise populate the cel.@block with the variable declarations and wrap the expression
266+
// in the block.
267+
for i := 0; i < len(opt.varIndices); i++ {
268+
vi := opt.varIndices[i]
269+
varExprs = append(varExprs, vi.expr)
270+
err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.celType))
271+
if err != nil {
272+
ctx.ReportErrorAtID(ruleExpr.ID(), "%s", err.Error())
273+
}
274+
}
275+
blockExpr := ctx.NewCall("cel.@block", ctx.NewList(varExprs, []int32{}), ruleExpr)
276+
return ctx.NewAST(blockExpr)
277+
}
278+
279+
func (opt *ruleUnnesterImpl) maybeUnnestRule(ctx *cel.OptimizerContext, ruleExpr ast.NavigableExpr) ast.NavigableExpr {
204280
// Unnest expressions are ordered from leaf to root via the ast.MatchDescendants call.
205-
heights := ast.Heights(ruleAST)
281+
heights := ast.Heights(ast.NewAST(ruleExpr, nil))
206282
unnestMap := map[int64]bool{}
207283
unnestExprs := []ast.NavigableExpr{}
208-
ast.MatchDescendants(ruleNav, func(e ast.NavigableExpr) bool {
284+
ast.MatchDescendants(ruleExpr, func(e ast.NavigableExpr) bool {
209285
// If the expression is a comprehension, then all unnest candidates captured previously that relate
210286
// to the comprehension body should be removed from the list of candidate branches for unnesting.
211287
if e.Kind() == ast.ComprehensionKind {
@@ -243,31 +319,14 @@ func (opt *ruleComposerImpl) maybeUnnestRule(ctx *cel.OptimizerContext, ruleExpr
243319
continue
244320
}
245321
reduceHeight(heights, e, opt.exprUnnestHeight)
246-
opt.registerBranchVariable(ctx, e)
322+
opt.registerUnnestVariable(ctx, e)
247323
}
248324
return ruleExpr
249325
}
250326

251-
// registerVariable creates an entry for a variable name within the cel.@block used to enumerate
252-
// variables within composed policy expression.
253-
func (opt *ruleComposerImpl) registerVariable(ctx *cel.OptimizerContext, v *CompiledVariable) {
254-
varName := fmt.Sprintf("variables.%s", v.Name())
255-
indexVar := fmt.Sprintf("@index%d", opt.nextVarIndex)
256-
varExpr := ctx.CopyASTAndMetadata(v.Expr().NativeRep())
257-
ast.PostOrderVisit(varExpr, opt.rewriteVariableName(ctx))
258-
vi := varIndex{
259-
index: opt.nextVarIndex,
260-
indexVar: indexVar,
261-
localVar: varName,
262-
expr: varExpr,
263-
celType: v.Declaration().Type()}
264-
opt.varIndices = append(opt.varIndices, vi)
265-
opt.nextVarIndex++
266-
}
267-
268-
// registerBranchVariable creates an entry for a variable name within the cel.@block used to unnest
327+
// registerUnnestVariable creates an entry for a variable name within the cel.@block used to unnest
269328
// a deeply nested logical branch or logical operator.
270-
func (opt *ruleComposerImpl) registerBranchVariable(ctx *cel.OptimizerContext, varExpr ast.NavigableExpr) {
329+
func (opt *ruleUnnesterImpl) registerUnnestVariable(ctx *cel.OptimizerContext, varExpr ast.NavigableExpr) {
271330
indexVar := fmt.Sprintf("@index%d", opt.nextVarIndex)
272331
varExprCopy := ctx.CopyASTAndMetadata(ctx.NewAST(varExpr))
273332
vi := varIndex{
@@ -281,11 +340,6 @@ func (opt *ruleComposerImpl) registerBranchVariable(ctx *cel.OptimizerContext, v
281340
opt.nextVarIndex++
282341
}
283342

284-
// sortedVariables returns the variables ordered by their declaration index.
285-
func (opt *ruleComposerImpl) sortedVariables() []varIndex {
286-
return opt.varIndices
287-
}
288-
289343
// compositionStep interface represents an intermediate stage of rule and match expression composition
290344
//
291345
// The CompiledRule and CompiledMatch types are meant to represent standalone tuples of condition

policy/helper_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ var (
216216
expr string
217217
composed string
218218
composerOpts []ComposerOption
219+
outputType *cel.Type
219220
}{
220221
{
221222
name: "unnest",
@@ -233,6 +234,7 @@ var (
233234
: @index3],
234235
@index2 ? optional.of("some divisible by 2") : @index4)
235236
`,
237+
outputType: cel.OptionalType(cel.StringType),
236238
},
237239
{
238240
name: "required_labels",
@@ -248,6 +250,7 @@ var (
248250
"invalid values provided on one or more labels: %s".format([@index2])],
249251
@index3 ? optional.of(@index4) : (@index5 ? optional.of(@index6) : optional.none()))
250252
`,
253+
outputType: cel.OptionalType(cel.StringType),
251254
},
252255
{
253256
name: "required_labels",
@@ -264,6 +267,7 @@ var (
264267
(@index1.size() > 0)
265268
? optional.of("missing one or more required labels: %s".format([@index1]))
266269
: @index3)`,
270+
outputType: cel.OptionalType(cel.StringType),
267271
},
268272
{
269273
name: "nested_rule2",
@@ -277,6 +281,7 @@ var (
277281
resource.?user.orValue("").startsWith("bad")
278282
? (@index2 ? {"banned": "restricted_region"} : {"banned": "bad_actor"})
279283
: @index3)`,
284+
outputType: cel.MapType(cel.StringType, cel.StringType),
280285
},
281286
{
282287
name: "nested_rule2",
@@ -293,6 +298,7 @@ var (
293298
: (!(resource.origin in @index0)
294299
? {"banned": "unconfigured_region"}
295300
: {}))`,
301+
outputType: cel.MapType(cel.StringType, cel.StringType),
296302
},
297303
{
298304
name: "limits",
@@ -310,6 +316,7 @@ var (
310316
? ((now.getHours() < 21) ? optional.of(@index4 + "!") :
311317
((now.getHours() < 22) ? optional.of(@index4 + "!!") : @index5))
312318
: @index6)`,
319+
outputType: cel.OptionalType(cel.StringType),
313320
},
314321
{
315322
name: "limits",
@@ -327,6 +334,7 @@ var (
327334
? ((now.getHours() < 21) ? optional.of(@index4 + "!") : @index5)
328335
: optional.of(@index3.format([@index0, @index2])))
329336
`,
337+
outputType: cel.OptionalType(cel.StringType),
330338
},
331339
{
332340
name: "limits",
@@ -342,6 +350,7 @@ var (
342350
((now.getHours() < 22) ? optional.of(@index4 + "!!") :
343351
((now.getHours() < 24) ? optional.of(@index4 + "!!!") : optional.none()))],
344352
(now.getHours() >= 20) ? @index5 : optional.of(@index3.format([@index0, @index2])))`,
353+
outputType: cel.OptionalType(cel.StringType),
345354
},
346355
}
347356

0 commit comments

Comments
 (0)