LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
IntervalAnalysis.cpp
Go to the documentation of this file.
1//===-- IntervalAnalysis.cpp - Interval analysis implementation -*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
13#include "llzk/Util/Debug.h"
15
16#include <mlir/Dialect/SCF/IR/SCF.h>
17
18#include <llvm/ADT/TypeSwitch.h>
19
20using namespace mlir;
21
22namespace llzk {
23
24using namespace array;
25using namespace boolean;
26using namespace cast;
27using namespace component;
28using namespace constrain;
29using namespace felt;
30using namespace function;
31
32/* ExpressionValue */
33
35 if (expr == nullptr && rhs.expr == nullptr) {
36 return i == rhs.i;
37 }
38 if (expr == nullptr || rhs.expr == nullptr) {
39 return false;
40 }
41 return i == rhs.i && *expr == *rhs.expr;
42}
43
45boolToFelt(llvm::SMTSolverRef solver, const ExpressionValue &expr, unsigned bitwidth) {
46 llvm::SMTExprRef zero = solver->mkBitvector(mlir::APSInt::get(0), bitwidth);
47 llvm::SMTExprRef one = solver->mkBitvector(mlir::APSInt::get(1), bitwidth);
48 llvm::SMTExprRef boolToFeltConv = solver->mkIte(expr.getExpr(), one, zero);
49 return expr.withExpression(boolToFeltConv);
50}
51
52ExpressionValue
53intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
54 Interval res = lhs.i.intersect(rhs.i);
55 auto exprEq = solver->mkEqual(lhs.expr, rhs.expr);
56 return ExpressionValue(exprEq, res);
57}
58
60add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
62 res.i = lhs.i + rhs.i;
63 res.expr = solver->mkBVAdd(lhs.expr, rhs.expr);
64 return res;
65}
66
68sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
70 res.i = lhs.i - rhs.i;
71 res.expr = solver->mkBVSub(lhs.expr, rhs.expr);
72 return res;
73}
74
76mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
78 res.i = lhs.i * rhs.i;
79 res.expr = solver->mkBVMul(lhs.expr, rhs.expr);
80 return res;
81}
82
84div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs,
85 const ExpressionValue &rhs) {
87 auto divRes = lhs.i / rhs.i;
88 if (failed(divRes)) {
89 op->emitWarning(
90 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
91 " Range of division result will be treated as unbounded."
92 )
93 .report();
94 res.i = Interval::Entire(lhs.getField());
95 } else {
96 res.i = *divRes;
97 }
98 res.expr = solver->mkBVUDiv(lhs.expr, rhs.expr);
99 return res;
100}
101
103mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
104 ExpressionValue res;
105 res.i = lhs.i % rhs.i;
106 res.expr = solver->mkBVURem(lhs.expr, rhs.expr);
107 return res;
108}
109
111bitAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
112 ExpressionValue res;
113 res.i = lhs.i & rhs.i;
114 res.expr = solver->mkBVAnd(lhs.expr, rhs.expr);
115 return res;
116}
117
119shiftLeft(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
120 ExpressionValue res;
121 res.i = lhs.i << rhs.i;
122 res.expr = solver->mkBVShl(lhs.expr, rhs.expr);
123 return res;
124}
125
127shiftRight(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
128 ExpressionValue res;
129 res.i = lhs.i >> rhs.i;
130 res.expr = solver->mkBVLshr(lhs.expr, rhs.expr);
131 return res;
132}
133
135cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs) {
136 ExpressionValue res;
137 const Field &f = lhs.getField();
138 // Default result is any boolean output for when we are unsure about the comparison result.
139 res.i = Interval::Boolean(f);
140 switch (op.getPredicate()) {
141 case FeltCmpPredicate::EQ:
142 res.expr = solver->mkEqual(lhs.expr, rhs.expr);
143 if (lhs.i.isDegenerate() && rhs.i.isDegenerate()) {
144 res.i = lhs.i == rhs.i ? Interval::True(f) : Interval::False(f);
145 } else if (lhs.i.intersect(rhs.i).isEmpty()) {
146 res.i = Interval::False(f);
147 }
148 break;
149 case FeltCmpPredicate::NE:
150 res.expr = solver->mkNot(solver->mkEqual(lhs.expr, rhs.expr));
151 if (lhs.i.isDegenerate() && rhs.i.isDegenerate()) {
152 res.i = lhs.i != rhs.i ? Interval::True(f) : Interval::False(f);
153 } else if (lhs.i.intersect(rhs.i).isEmpty()) {
154 res.i = Interval::True(f);
155 }
156 break;
157 case FeltCmpPredicate::LT:
158 res.expr = solver->mkBVUlt(lhs.expr, rhs.expr);
159 if (lhs.i.toUnreduced().computeGEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
160 res.i = Interval::True(f);
161 }
162 if (lhs.i.toUnreduced().computeLTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
163 res.i = Interval::False(f);
164 }
165 break;
166 case FeltCmpPredicate::LE:
167 res.expr = solver->mkBVUle(lhs.expr, rhs.expr);
168 if (lhs.i.toUnreduced().computeGTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
169 res.i = Interval::True(f);
170 }
171 if (lhs.i.toUnreduced().computeLEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
172 res.i = Interval::False(f);
173 }
174 break;
175 case FeltCmpPredicate::GT:
176 res.expr = solver->mkBVUgt(lhs.expr, rhs.expr);
177 if (lhs.i.toUnreduced().computeLEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
178 res.i = Interval::True(f);
179 }
180 if (lhs.i.toUnreduced().computeGTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
181 res.i = Interval::False(f);
182 }
183 break;
184 case FeltCmpPredicate::GE:
185 res.expr = solver->mkBVUge(lhs.expr, rhs.expr);
186 if (lhs.i.toUnreduced().computeLTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
187 res.i = Interval::True(f);
188 }
189 if (lhs.i.toUnreduced().computeGEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
190 res.i = Interval::False(f);
191 }
192 break;
193 }
194 return res;
195}
196
198boolAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
199 ExpressionValue res;
200 res.i = boolAnd(lhs.i, rhs.i);
201 res.expr = solver->mkAnd(lhs.expr, rhs.expr);
202 return res;
203}
204
206boolOr(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
207 ExpressionValue res;
208 res.i = boolOr(lhs.i, rhs.i);
209 res.expr = solver->mkOr(lhs.expr, rhs.expr);
210 return res;
211}
212
214boolXor(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
215 ExpressionValue res;
216 res.i = boolXor(lhs.i, rhs.i);
217 // There's no Xor, so we do (L || R) && !(L && R)
218 res.expr = solver->mkAnd(
219 solver->mkOr(lhs.expr, rhs.expr), solver->mkNot(solver->mkAnd(lhs.expr, rhs.expr))
220 );
221 return res;
222}
223
225 llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs
226) {
227 ExpressionValue res;
228 res.i = Interval::Entire(lhs.getField());
229 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
230 .Case<OrFeltOp>([&](auto) { return solver->mkBVOr(lhs.expr, rhs.expr); })
231 .Case<XorFeltOp>([&](auto) {
232 return solver->mkBVXor(lhs.expr, rhs.expr);
233 }).Default([&](auto *unsupported) {
234 llvm::report_fatal_error(
235 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
236 );
237 return nullptr;
238 });
239
240 return res;
241}
242
243ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val) {
244 ExpressionValue res;
245 res.i = -val.i;
246 res.expr = solver->mkBVNeg(val.expr);
247 return res;
248}
249
250ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val) {
251 ExpressionValue res;
252 res.i = ~val.i;
253 res.expr = solver->mkBVNot(val.expr);
254 return res;
255}
256
257ExpressionValue boolNot(llvm::SMTSolverRef solver, const ExpressionValue &val) {
258 ExpressionValue res;
259 res.i = boolNot(val.i);
260 res.expr = solver->mkNot(val.expr);
261 return res;
262}
263
265fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val) {
266 const Field &field = val.getField();
267 ExpressionValue res;
268 res.i = Interval::Entire(field);
269 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
270 .Case<InvFeltOp>([&](auto) {
271 // The definition of an inverse X^-1 is Y s.t. XY % prime = 1.
272 // To create this expression, we create a new symbol for Y and add the
273 // XY % prime = 1 constraint to the solver.
274 std::string symName = buildStringViaInsertionOp(*op);
275 llvm::SMTExprRef invSym = field.createSymbol(solver, symName.c_str());
276 llvm::SMTExprRef one = solver->mkBitvector(APSInt::get(1), field.bitWidth());
277 llvm::SMTExprRef prime = solver->mkBitvector(toAPSInt(field.prime()), field.bitWidth());
278 llvm::SMTExprRef mult = solver->mkBVMul(val.getExpr(), invSym);
279 llvm::SMTExprRef mod = solver->mkBVURem(mult, prime);
280 llvm::SMTExprRef constraint = solver->mkEqual(mod, one);
281 solver->addConstraint(constraint);
282 return invSym;
283 }).Default([](Operation *unsupported) {
284 llvm::report_fatal_error(
285 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
286 );
287 return nullptr;
288 });
289
290 return res;
291}
292
293void ExpressionValue::print(mlir::raw_ostream &os) const {
294 if (expr) {
295 expr->print(os);
296 } else {
297 os << "<null expression>";
298 }
299
300 os << " ( interval: " << i << " )";
301}
302
303/* IntervalAnalysisLattice */
304
305ChangeResult IntervalAnalysisLattice::join(const AbstractSparseLattice &other) {
306 const auto *rhs = dynamic_cast<const IntervalAnalysisLattice *>(&other);
307 if (!rhs) {
308 llvm::report_fatal_error("invalid join lattice type");
309 }
310 ChangeResult res = val.update(rhs->getValue());
311 for (auto &v : rhs->constraints) {
312 if (!constraints.contains(v)) {
313 constraints.insert(v);
314 res |= ChangeResult::Change;
315 }
316 }
317 return res;
318}
319
320ChangeResult IntervalAnalysisLattice::meet(const AbstractSparseLattice &other) {
321 const auto *rhs = dynamic_cast<const IntervalAnalysisLattice *>(&other);
322 if (!rhs) {
323 llvm::report_fatal_error("invalid join lattice type");
324 }
325 // Intersect the intervals
326 ExpressionValue lhsExpr = val.getScalarValue();
327 ExpressionValue rhsExpr = rhs->getValue().getScalarValue();
328 Interval newInterval = lhsExpr.getInterval().intersect(rhsExpr.getInterval());
329 ChangeResult res = setValue(lhsExpr.withInterval(newInterval));
330 for (auto &v : rhs->constraints) {
331 if (!constraints.contains(v)) {
332 constraints.insert(v);
333 res |= ChangeResult::Change;
334 }
335 }
336 return res;
337}
338
339void IntervalAnalysisLattice::print(mlir::raw_ostream &os) const {
340 os << "IntervalAnalysisLattice { " << val << " }";
341}
342
344 if (val == newVal) {
345 return ChangeResult::NoChange;
346 }
347 val = newVal;
348 return ChangeResult::Change;
349}
350
352 LatticeValue newVal(e);
353 return setValue(newVal);
354}
355
357 if (!constraints.contains(e)) {
358 constraints.insert(e);
359 return ChangeResult::Change;
360 }
361 return ChangeResult::NoChange;
362}
363
364/* IntervalDataFlowAnalysis */
365
366const SourceRefLattice *
367IntervalDataFlowAnalysis::getSourceRefLattice(Operation *baseOp, Value val) {
368 ProgramPoint *pp = _dataflowSolver.getProgramPointAfter(baseOp);
369 auto defaultSourceRefLattice = _dataflowSolver.lookupState<SourceRefLattice>(pp);
370 ensure(defaultSourceRefLattice, "failed to get lattice");
371 if (Operation *defOp = val.getDefiningOp()) {
372 ProgramPoint *defPoint = _dataflowSolver.getProgramPointAfter(defOp);
373 auto sourceRefLattice = _dataflowSolver.lookupState<SourceRefLattice>(defPoint);
374 ensure(sourceRefLattice, "failed to get SourceRefLattice for value");
375 return sourceRefLattice;
376 }
377 return defaultSourceRefLattice;
378}
379
381 Operation *op, ArrayRef<const Lattice *> operands, ArrayRef<Lattice *> results
382) {
383 // We only perform the visitation on operations within functions
384 FuncDefOp fn = op->getParentOfType<FuncDefOp>();
385 if (!fn) {
386 return success();
387 }
388
389 // If there are no operands or results, skip.
390 if (operands.empty() && results.empty()) {
391 return success();
392 }
393
394 // Get the values or defaults from the operand lattices
395 llvm::SmallVector<LatticeValue> operandVals;
396 llvm::SmallVector<std::optional<SourceRef>> operandRefs;
397 for (unsigned opNum = 0; opNum < op->getNumOperands(); ++opNum) {
398 Value val = op->getOperand(opNum);
399 SourceRefLatticeValue refSet = getSourceRefLattice(op, val)->getOrDefault(val);
400 if (refSet.isSingleValue()) {
401 operandRefs.push_back(refSet.getSingleValue());
402 } else {
403 operandRefs.push_back(std::nullopt);
404 }
405 // First, lookup the operand value after it is initialized
406 auto priorState = operands[opNum]->getValue();
407 if (priorState.getScalarValue().getExpr() != nullptr) {
408 operandVals.push_back(priorState);
409 continue;
410 }
411
412 // Else, look up the stored value by `SourceRef`.
413 // We only care about scalar type values, so we ignore composite types, which
414 // are currently limited to structs and arrays.
415 Type valTy = val.getType();
416 if (llvm::isa<ArrayType, StructType>(valTy)) {
417 ExpressionValue anyVal(field.get(), createFeltSymbol(val));
418 operandVals.emplace_back(anyVal);
419 continue;
420 }
421
422 ensure(refSet.isScalar(), "should have ruled out array values already");
423
424 if (refSet.getScalarValue().empty()) {
425 // If we can't compute the reference, then there must be some unsupported
426 // op the reference analysis cannot handle. We emit a warning and return
427 // early, since there's no meaningful computation we can do for this op.
428 op->emitWarning()
429 .append(
430 "state of ", val, " is empty; defining operation is unsupported by SourceRef analysis"
431 )
432 .report();
433 // We still return success so we can return overapproximated and partial
434 // results to the user.
435 return success();
436 } else if (!refSet.isSingleValue()) {
437 std::string warning;
438 debug::Appender(warning) << "operand " << val << " is not a single value " << refSet
439 << ", overapproximating";
440 op->emitWarning(warning).report();
441 // Here, we will override the prior lattice value with a new symbol, representing
442 // "any" value, then use that value for the operands.
443 ExpressionValue anyVal(field.get(), createFeltSymbol(val));
444 operandVals.emplace_back(anyVal);
445 } else {
446 const SourceRef &ref = refSet.getSingleValue();
447 // See if we've written the value before. If so, use that.
448 if (auto it = memberWriteResults.find(ref); it != memberWriteResults.end()) {
449 operandVals.emplace_back(it->second);
450 } else {
451 ExpressionValue exprVal(field.get(), getOrCreateSymbol(ref));
452 operandVals.emplace_back(exprVal);
453 }
454 }
455
456 // Since we initialized a value that was not found in the before lattice,
457 // update that value in the lattice so we can find it later, but we don't
458 // need to propagate the changes, since we already have what we need.
459 Lattice *operandLattice = getLatticeElement(val);
460 (void)operandLattice->setValue(operandVals[opNum]);
461 }
462
463 // Now, the way we update is dependent on the type of the operation.
464 if (isConstOp(op)) {
465 llvm::DynamicAPInt constVal = getConst(op);
466 llvm::SMTExprRef expr;
467 if (isBoolConstOp(op)) {
468 expr = createConstBoolExpr(constVal != 0);
469 } else {
470 expr = createConstBitvectorExpr(constVal);
471 }
472
473 ExpressionValue latticeVal(field.get(), expr, constVal);
474 propagateIfChanged(results[0], results[0]->setValue(latticeVal));
475 } else if (isArithmeticOp(op)) {
476 ExpressionValue result;
477 if (operands.size() == 2) {
478 result = performBinaryArithmetic(op, operandVals[0], operandVals[1]);
479 } else {
480 result = performUnaryArithmetic(op, operandVals[0]);
481 }
482 // Also intersect with prior interval, if it's initialized
483 const ExpressionValue &prior = results[0]->getValue().getScalarValue();
484 if (prior.getExpr()) {
485 result = result.withInterval(result.getInterval().intersect(prior.getInterval()));
486 }
487 propagateIfChanged(results[0], results[0]->setValue(result));
488 } else if (EmitEqualityOp emitEq = llvm::dyn_cast<EmitEqualityOp>(op)) {
489 Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs();
490 ExpressionValue lhsExpr = operandVals[0].getScalarValue();
491 ExpressionValue rhsExpr = operandVals[1].getScalarValue();
492
493 // Special handling for generalized (s - c0) * (s - c1) * ... * (s - cN) = 0 patterns.
494 // These patterns enforce that s is one of c0, ..., cN.
495 auto res = getGeneralizedDecompInterval(op, lhsVal, rhsVal);
496 if (succeeded(res)) {
497 for (Value signalVal : res->first) {
498 applyInterval(emitEq, signalVal, res->second);
499 }
500 }
501
502 ExpressionValue constraint = intersection(smtSolver, lhsExpr, rhsExpr);
503 // Update the LHS and RHS to the same value, but restricted intervals
504 // based on the constraints.
505 const Interval &constrainInterval = constraint.getInterval();
506 applyInterval(emitEq, lhsVal, constrainInterval);
507 applyInterval(emitEq, rhsVal, constrainInterval);
508 } else if (auto assertOp = llvm::dyn_cast<AssertOp>(op)) {
509 // assert enforces that the operand is true. So we apply an interval of [1, 1]
510 // to the operand.
511 Value cond = assertOp.getCondition();
512 applyInterval(assertOp, cond, Interval::True(field.get()));
513 // Also add the solver constraint that the expression must be true.
514 auto assertExpr = operandVals[0].getScalarValue();
515 // No need to propagate the constraint
516 (void)getLatticeElement(cond)->addSolverConstraint(assertExpr);
517 } else if (auto writem = llvm::dyn_cast<MemberWriteOp>(op)) {
518 // Update values stored in a member
519 ExpressionValue writeVal = operandVals[1].getScalarValue();
520 auto cmp = writem.getComponent();
521 // We also need to update the interval on the assigned symbol
522 SourceRefLatticeValue refSet = getSourceRefLattice(op, cmp)->getOrDefault(cmp);
523 if (refSet.isSingleValue()) {
524 auto memberDefRes = writem.getMemberDefOp(tables);
525 if (succeeded(memberDefRes)) {
526 SourceRefIndex idx(memberDefRes.value());
527 SourceRef memberRef = refSet.getSingleValue().createChild(idx);
528 llvm::SMTExprRef expr = getOrCreateSymbol(memberRef);
529 ExpressionValue written(expr, writeVal.getInterval());
530
531 if (auto it = memberWriteResults.find(memberRef); it != memberWriteResults.end()) {
532 const ExpressionValue &old = it->second;
533 Interval combinedWrite = old.getInterval().join(written.getInterval());
534 memberWriteResults[memberRef] = old.withInterval(combinedWrite);
535 } else {
536 memberWriteResults[memberRef] = written;
537 }
538
539 // Propagate to all member readers we've collected so far.
540 for (Lattice *readerLattice : memberReadResults[memberRef]) {
541 ExpressionValue prior = readerLattice->getValue().getScalarValue();
544 propagateIfChanged(readerLattice, readerLattice->setValue(newVal));
545 }
546 }
547 }
548 } else if (isa<IntToFeltOp, FeltToIndexOp>(op)) {
549 // Casts don't modify the intervals, but they do modify the SMT types.
550 ExpressionValue expr = operandVals[0].getScalarValue();
551 // We treat all ints and indexes as felts with the exception of comparison
552 // results, which are bools. So if `expr` is a bool, this cast needs to
553 // upcast to a felt.
554 if (expr.isBoolSort(smtSolver)) {
555 expr = boolToFelt(smtSolver, expr, field.get().bitWidth());
556 }
557 propagateIfChanged(results[0], results[0]->setValue(expr));
558 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
559 // Fetch the lattice for after the parent operation so we can propagate
560 // the yielded value to subsequent operations.
561 Operation *parent = op->getParentOp();
562 ensure(parent, "yield operation must have parent operation");
563 // Bind the operand values to the result values of the parent
564 for (unsigned idx = 0; idx < yieldOp.getResults().size(); ++idx) {
565 Value parentRes = parent->getResult(idx);
566 Lattice *resLattice = getLatticeElement(parentRes);
567 // Merge with the existing value, if present (e.g., another branch)
568 // has possible value that must be merged.
569 ExpressionValue exprVal = resLattice->getValue().getScalarValue();
570 ExpressionValue newResVal = operandVals[idx].getScalarValue();
571 if (exprVal.getExpr() != nullptr) {
572 newResVal = exprVal.withInterval(exprVal.getInterval().join(newResVal.getInterval()));
573 } else {
574 newResVal = ExpressionValue(createFeltSymbol(parentRes), newResVal.getInterval());
575 }
576 propagateIfChanged(resLattice, resLattice->setValue(newResVal));
577 }
578 } else if (
579 // We do not need to explicitly handle read ops since they are resolved at the operand value
580 // step where `SourceRef`s are queries.
581 !isReadOp(op)
582 // We do not currently handle return ops as the analysis is currently limited to constrain
583 // functions, which return no value.
584 && !isReturnOp(op)
585 // The analysis ignores definition ops.
586 && !isDefinitionOp(op)
587 // We do not need to analyze the creation of structs.
588 && !llvm::isa<CreateStructOp>(op)
589 ) {
590 op->emitWarning("unhandled operation, analysis may be incomplete").report();
591 }
592
593 return success();
594}
595
597 auto it = refSymbols.find(r);
598 if (it != refSymbols.end()) {
599 return it->second;
600 }
601 llvm::SMTExprRef sym = createFeltSymbol(r);
602 refSymbols[r] = sym;
603 return sym;
604}
605
606llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(const SourceRef &r) const {
607 return createFeltSymbol(buildStringViaPrint(r).c_str());
608}
609
610llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(Value v) const {
611 return createFeltSymbol(buildStringViaPrint(v).c_str());
612}
613
614llvm::SMTExprRef IntervalDataFlowAnalysis::createFeltSymbol(const char *name) const {
615 return field.get().createSymbol(smtSolver, name);
616}
617
618llvm::DynamicAPInt IntervalDataFlowAnalysis::getConst(Operation *op) const {
619 ensure(isConstOp(op), "op is not a const op");
620
621 // NOTE: I think clang-format makes these hard to read by default
622 // clang-format off
623 llvm::DynamicAPInt fieldConst = TypeSwitch<Operation *, llvm::DynamicAPInt>(op)
624 .Case<FeltConstantOp>([&](FeltConstantOp feltConst) {
625 llvm::APSInt constOpVal(feltConst.getValue());
626 return field.get().reduce(constOpVal);
627 })
628 .Case<arith::ConstantIndexOp>([&](arith::ConstantIndexOp indexConst) {
629 return DynamicAPInt(indexConst.value());
630 })
631 .Case<arith::ConstantIntOp>([&](arith::ConstantIntOp intConst) {
632 auto valAttr = dyn_cast<IntegerAttr>(intConst.getValue());
633 ensure(valAttr != nullptr, "arith::ConstantIntOp must have an IntegerAttr as its value");
634 return toDynamicAPInt(valAttr.getValue());
635 })
636 .Default([](Operation *illegalOp) {
637 std::string err;
638 debug::Appender(err) << "unhandled getConst case: " << *illegalOp;
639 llvm::report_fatal_error(Twine(err));
640 return llvm::DynamicAPInt();
641 });
642 // clang-format on
643 return fieldConst;
644}
645
646ExpressionValue IntervalDataFlowAnalysis::performBinaryArithmetic(
647 Operation *op, const LatticeValue &a, const LatticeValue &b
648) {
649 ensure(isArithmeticOp(op), "is not arithmetic op");
650
651 auto lhs = a.getScalarValue(), rhs = b.getScalarValue();
652 ensure(lhs.getExpr(), "cannot perform arithmetic over null lhs smt expr");
653 ensure(rhs.getExpr(), "cannot perform arithmetic over null rhs smt expr");
654
655 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
656 .Case<AddFeltOp>([&](auto _) { return add(smtSolver, lhs, rhs); })
657 .Case<SubFeltOp>([&](auto _) { return sub(smtSolver, lhs, rhs); })
658 .Case<MulFeltOp>([&](auto _) { return mul(smtSolver, lhs, rhs); })
659 .Case<DivFeltOp>([&](auto divOp) { return div(smtSolver, divOp, lhs, rhs); })
660 .Case<UnsignedModFeltOp>([&](auto _) { return mod(smtSolver, lhs, rhs); })
661 .Case<AndFeltOp>([&](auto _) { return bitAnd(smtSolver, lhs, rhs); })
662 .Case<ShlFeltOp>([&](auto _) { return shiftLeft(smtSolver, lhs, rhs); })
663 .Case<ShrFeltOp>([&](auto _) { return shiftRight(smtSolver, lhs, rhs); })
664 .Case<CmpOp>([&](auto cmpOp) { return cmp(smtSolver, cmpOp, lhs, rhs); })
665 .Case<AndBoolOp>([&](auto _) { return boolAnd(smtSolver, lhs, rhs); })
666 .Case<OrBoolOp>([&](auto _) { return boolOr(smtSolver, lhs, rhs); })
667 .Case<XorBoolOp>([&](auto _) {
668 return boolXor(smtSolver, lhs, rhs);
669 }).Default([&](auto *unsupported) {
670 unsupported
671 ->emitWarning(
672 "unsupported binary arithmetic operation, defaulting to over-approximated intervals"
673 )
674 .report();
675 return fallbackBinaryOp(smtSolver, unsupported, lhs, rhs);
676 });
677
678 ensure(res.getExpr(), "arithmetic produced null smt expr");
679 return res;
680}
681
683IntervalDataFlowAnalysis::performUnaryArithmetic(Operation *op, const LatticeValue &a) {
684 ensure(isArithmeticOp(op), "is not arithmetic op");
685
686 auto val = a.getScalarValue();
687 ensure(val.getExpr(), "cannot perform arithmetic over null smt expr");
688
689 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
690 .Case<NegFeltOp>([&](auto _) { return neg(smtSolver, val); })
691 .Case<NotFeltOp>([&](auto _) { return notOp(smtSolver, val); })
692 .Case<NotBoolOp>([&](auto _) { return boolNot(smtSolver, val); })
693 // The inverse op is currently overapproximated
694 .Case<InvFeltOp>([&](auto inv) {
695 return fallbackUnaryOp(smtSolver, inv, val);
696 }).Default([&](auto *unsupported) {
697 unsupported
698 ->emitWarning(
699 "unsupported unary arithmetic operation, defaulting to over-approximated interval"
700 )
701 .report();
702 return fallbackUnaryOp(smtSolver, unsupported, val);
703 });
704
705 ensure(res.getExpr(), "arithmetic produced null smt expr");
706 return res;
707}
708
709void IntervalDataFlowAnalysis::applyInterval(Operation *valUser, Value val, Interval newInterval) {
710 Lattice *valLattice = getLatticeElement(val);
711 ExpressionValue oldLatticeVal = valLattice->getValue().getScalarValue();
712 // Intersect with the current value to accumulate restrictions across constraints.
713 Interval intersection = oldLatticeVal.getInterval().intersect(newInterval);
714 ExpressionValue newLatticeVal = oldLatticeVal.withInterval(intersection);
715 ChangeResult changed = valLattice->setValue(newLatticeVal);
716
717 if (auto blockArg = llvm::dyn_cast<BlockArgument>(val)) {
718 auto fnOp = dyn_cast<FuncDefOp>(blockArg.getOwner()->getParentOp());
719
720 // Apply the interval from the constrain function inputs to the compute function inputs
721 if (propagateInputConstraints && fnOp && fnOp.isStructConstrain() &&
722 blockArg.getArgNumber() > 0 && !newInterval.isEntire()) {
723 auto structOp = fnOp->getParentOfType<StructDefOp>();
724 FuncDefOp computeFn = structOp.getComputeFuncOp();
725 BlockArgument computeArg = computeFn.getArgument(blockArg.getArgNumber() - 1);
726 Lattice *computeEntryLattice = getLatticeElement(computeArg);
727
728 SourceRef ref(computeArg);
729 ExpressionValue newArgVal(getOrCreateSymbol(ref), newInterval);
730 propagateIfChanged(computeEntryLattice, computeEntryLattice->setValue(newArgVal));
731 }
732 }
733
734 // Now we descend into val's operands, if it has any.
735 Operation *definingOp = val.getDefiningOp();
736 if (!definingOp) {
737 propagateIfChanged(valLattice, changed);
738 return;
739 }
740
741 const Field &f = field.get();
742
743 // This is a rules-based operation. If we have a rule for a given operation,
744 // then we can make some kind of update, otherwise we leave the intervals
745 // as is.
746 // - First we'll define all the rules so the type switch can be less messy
747
748 // cmp.<pred> restricts each side of the comparison if the result is known.
749 auto cmpCase = [&](CmpOp cmpOp) {
750 // Cmp output range is [0, 1], so in order to do something, we must have newInterval
751 // either "true" (1) or "false" (0).
752 // -- In the case of a contradictory circuit, however, the cmp result is allowed
753 // to be empty.
754 ensure(
755 newInterval.isBoolean() || newInterval.isEmpty(),
756 "new interval for CmpOp is not boolean or empty"
757 );
758 if (!newInterval.isDegenerate()) {
759 // The comparison result is unknown, so we can't update the operand ranges
760 return;
761 }
762
763 bool cmpTrue = newInterval.rhs() == f.one();
764
765 Value lhs = cmpOp.getLhs(), rhs = cmpOp.getRhs();
766 auto lhsLat = getLatticeElement(lhs), rhsLat = getLatticeElement(rhs);
767 ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(),
768 rhsExpr = rhsLat->getValue().getScalarValue();
769
770 Interval newLhsInterval, newRhsInterval;
771 const Interval &lhsInterval = lhsExpr.getInterval();
772 const Interval &rhsInterval = rhsExpr.getInterval();
773
774 FeltCmpPredicate pred = cmpOp.getPredicate();
775 // predicate cases
776 auto eqCase = [&]() {
777 return (pred == FeltCmpPredicate::EQ && cmpTrue) ||
778 (pred == FeltCmpPredicate::NE && !cmpTrue);
779 };
780 auto neCase = [&]() {
781 return (pred == FeltCmpPredicate::NE && cmpTrue) ||
782 (pred == FeltCmpPredicate::EQ && !cmpTrue);
783 };
784 auto ltCase = [&]() {
785 return (pred == FeltCmpPredicate::LT && cmpTrue) ||
786 (pred == FeltCmpPredicate::GE && !cmpTrue);
787 };
788 auto leCase = [&]() {
789 return (pred == FeltCmpPredicate::LE && cmpTrue) ||
790 (pred == FeltCmpPredicate::GT && !cmpTrue);
791 };
792 auto gtCase = [&]() {
793 return (pred == FeltCmpPredicate::GT && cmpTrue) ||
794 (pred == FeltCmpPredicate::LE && !cmpTrue);
795 };
796 auto geCase = [&]() {
797 return (pred == FeltCmpPredicate::GE && cmpTrue) ||
798 (pred == FeltCmpPredicate::LT && !cmpTrue);
799 };
800
801 // new intervals based on case
802 if (eqCase()) {
803 newLhsInterval = newRhsInterval = lhsInterval.intersect(rhsInterval);
804 } else if (neCase()) {
805 if (lhsInterval.isDegenerate() && rhsInterval.isDegenerate() && lhsInterval == rhsInterval) {
806 // In this case, we know lhs and rhs cannot satisfy this assertion, so they have
807 // an empty value range.
808 newLhsInterval = newRhsInterval = Interval::Empty(f);
809 } else if (lhsInterval.isDegenerate()) {
810 // rhs must not overlap with lhs
811 newLhsInterval = lhsInterval;
812 newRhsInterval = rhsInterval.difference(lhsInterval);
813 } else if (rhsInterval.isDegenerate()) {
814 // lhs must not overlap with rhs
815 newLhsInterval = lhsInterval.difference(rhsInterval);
816 newRhsInterval = rhsInterval;
817 } else {
818 // Leave unchanged
819 newLhsInterval = lhsInterval;
820 newRhsInterval = rhsInterval;
821 }
822 } else if (ltCase()) {
823 newLhsInterval = lhsInterval.toUnreduced().computeLTPart(rhsInterval.toUnreduced()).reduce(f);
824 newRhsInterval = rhsInterval.toUnreduced().computeGEPart(lhsInterval.toUnreduced()).reduce(f);
825 } else if (leCase()) {
826 newLhsInterval = lhsInterval.toUnreduced().computeLEPart(rhsInterval.toUnreduced()).reduce(f);
827 newRhsInterval = rhsInterval.toUnreduced().computeGTPart(lhsInterval.toUnreduced()).reduce(f);
828 } else if (gtCase()) {
829 newLhsInterval = lhsInterval.toUnreduced().computeGTPart(rhsInterval.toUnreduced()).reduce(f);
830 newRhsInterval = rhsInterval.toUnreduced().computeLEPart(lhsInterval.toUnreduced()).reduce(f);
831 } else if (geCase()) {
832 newLhsInterval = lhsInterval.toUnreduced().computeGEPart(rhsInterval.toUnreduced()).reduce(f);
833 newRhsInterval = rhsInterval.toUnreduced().computeLTPart(lhsInterval.toUnreduced()).reduce(f);
834 } else {
835 cmpOp->emitWarning("unhandled cmp predicate").report();
836 return;
837 }
838
839 // Now we recurse to each operand
840 applyInterval(cmpOp, lhs, newLhsInterval);
841 applyInterval(cmpOp, rhs, newRhsInterval);
842 };
843
844 // Multiplication cases:
845 // - If the result of a multiplication is non-zero, then both operands must be
846 // non-zero.
847 // - If one operand is a constant, we can propagate the new interval when multiplied
848 // by the multiplicative inverse of the constant.
849 auto mulCase = [&](MulFeltOp mulOp) {
850 // We check for the constant case first.
851 auto constCase = [&](FeltConstantOp constOperand, Value multiplicand) {
852 auto latVal = getLatticeElement(multiplicand)->getValue().getScalarValue();
853 APInt constVal = constOperand.getValue();
854 if (constVal.isZero()) {
855 // There's no inverse for zero, so we do nothing.
856 return;
857 }
858 Interval updatedInterval = newInterval * Interval::Degenerate(f, f.inv(constVal));
859 applyInterval(mulOp, multiplicand, updatedInterval);
860 };
861
862 Value lhs = mulOp.getLhs(), rhs = mulOp.getRhs();
863
864 auto lhsConstOp = dyn_cast_if_present<FeltConstantOp>(lhs.getDefiningOp());
865 auto rhsConstOp = dyn_cast_if_present<FeltConstantOp>(rhs.getDefiningOp());
866 // If both are consts, we don't need to do anything
867 if (lhsConstOp && rhsConstOp) {
868 return;
869 } else if (lhsConstOp) {
870 constCase(lhsConstOp, rhs);
871 return;
872 } else if (rhsConstOp) {
873 constCase(rhsConstOp, lhs);
874 return;
875 }
876
877 // Otherwise, try to propagate non-zero information.
878 auto zeroInt = Interval::Degenerate(f, f.zero());
879 if (newInterval.intersect(zeroInt).isNotEmpty()) {
880 // The multiplication may be zero, so we can't reduce the operands to be non-zero
881 return;
882 }
883
884 auto lhsLat = getLatticeElement(lhs), rhsLat = getLatticeElement(rhs);
885 ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(),
886 rhsExpr = rhsLat->getValue().getScalarValue();
887 Interval newLhsInterval = lhsExpr.getInterval().difference(zeroInt);
888 Interval newRhsInterval = rhsExpr.getInterval().difference(zeroInt);
889 applyInterval(mulOp, lhs, newLhsInterval);
890 applyInterval(mulOp, rhs, newRhsInterval);
891 };
892
893 auto addCase = [&](AddFeltOp addOp) {
894 Value lhs = addOp.getLhs(), rhs = addOp.getRhs();
895 Lattice *lhsLat = getLatticeElement(lhs), *rhsLat = getLatticeElement(rhs);
896 ExpressionValue lhsVal = lhsLat->getValue().getScalarValue();
897 ExpressionValue rhsVal = rhsLat->getValue().getScalarValue();
898
899 const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval();
900
901 Interval derivedLhsInt = newInterval - currRhsInt;
902 Interval derivedRhsInt = newInterval - currLhsInt;
903
904 Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt);
905 Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt);
906
907 applyInterval(addOp, lhs, finalLhsInt);
908 applyInterval(addOp, rhs, finalRhsInt);
909 };
910
911 auto subCase = [&](SubFeltOp subOp) {
912 Value lhs = subOp.getLhs(), rhs = subOp.getRhs();
913 Lattice *lhsLat = getLatticeElement(lhs), *rhsLat = getLatticeElement(rhs);
914 ExpressionValue lhsVal = lhsLat->getValue().getScalarValue();
915 ExpressionValue rhsVal = rhsLat->getValue().getScalarValue();
916
917 const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval();
918
919 Interval derivedLhsInt = newInterval + currRhsInt;
920 Interval derivedRhsInt = currLhsInt - newInterval;
921
922 Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt);
923 Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt);
924
925 applyInterval(subOp, lhs, finalLhsInt);
926 applyInterval(subOp, rhs, finalRhsInt);
927 };
928
929 auto readmCase = [&](MemberReadOp readmOp) {
930 const SourceRefLattice *sourceRefLattice = getSourceRefLattice(valUser, val);
931 SourceRefLatticeValue sourceRefVal = sourceRefLattice->getOrDefault(val);
932
933 if (sourceRefVal.isSingleValue()) {
934 const SourceRef &ref = sourceRefVal.getSingleValue();
935 memberReadResults[ref].insert(valLattice);
936
937 // Also propagate to all other member read results for this member
938 for (Lattice *l : memberReadResults[ref]) {
939 if (l != valLattice) {
940 propagateIfChanged(l, l->setValue(newLatticeVal));
941 }
942 }
943 }
944 };
945
946 auto readArrCase = [&](ReadArrayOp _) {
947 const SourceRefLattice *sourceRefLattice = getSourceRefLattice(valUser, val);
948 SourceRefLatticeValue sourceRefVal = sourceRefLattice->getOrDefault(val);
949
950 if (sourceRefVal.isSingleValue()) {
951 const SourceRef &ref = sourceRefVal.getSingleValue();
952 memberReadResults[ref].insert(valLattice);
953
954 // Also propagate to all other member read results for this member
955 for (Lattice *l : memberReadResults[ref]) {
956 if (l != valLattice) {
957 propagateIfChanged(l, l->setValue(newLatticeVal));
958 }
959 }
960 }
961 };
962
963 // For casts, just pass the interval along to the cast's operand.
964 auto castCase = [&](Operation *op) { applyInterval(op, op->getOperand(0), newInterval); };
965
966 // - Apply the rules given the op.
967 // NOTE: disabling clang-format for this because it makes the last case statement
968 // look ugly.
969 // clang-format off
970 TypeSwitch<Operation *>(definingOp)
971 .Case<CmpOp>([&](auto op) { cmpCase(op); })
972 .Case<AddFeltOp>([&](auto op) { return addCase(op); })
973 .Case<SubFeltOp>([&](auto op) { return subCase(op); })
974 .Case<MulFeltOp>([&](auto op) { mulCase(op); })
975 .Case<MemberReadOp>([&](auto op){ readmCase(op); })
976 .Case<ReadArrayOp>([&](auto op){ readArrCase(op); })
977 .Case<IntToFeltOp, FeltToIndexOp>([&](auto op) { castCase(op); })
978 .Default([&](Operation *) { });
979 // clang-format on
980
981 // Propagate after recursion to avoid having recursive calls unset the value.
982 propagateIfChanged(valLattice, changed);
983}
984
985FailureOr<std::pair<DenseSet<Value>, Interval>>
986IntervalDataFlowAnalysis::getGeneralizedDecompInterval(Operation *baseOp, Value lhs, Value rhs) {
987 auto isZeroConst = [this](Value v) {
988 Operation *op = v.getDefiningOp();
989 if (!op) {
990 return false;
991 }
992 if (!isConstOp(op)) {
993 return false;
994 }
995 return getConst(op) == field.get().zero();
996 };
997 bool lhsIsZero = isZeroConst(lhs), rhsIsZero = isZeroConst(rhs);
998 Value exprTree = nullptr;
999 if (lhsIsZero && !rhsIsZero) {
1000 exprTree = rhs;
1001 } else if (!lhsIsZero && rhsIsZero) {
1002 exprTree = lhs;
1003 } else {
1004 return failure();
1005 }
1006
1007 // We now explore the expression tree for multiplications of subtractions/signal values.
1008 std::optional<SourceRef> signalRef = std::nullopt;
1009 DenseSet<Value> signalVals;
1010 SmallVector<DynamicAPInt> consts;
1011 SmallVector<Value> frontier {exprTree};
1012 while (!frontier.empty()) {
1013 Value v = frontier.back();
1014 frontier.pop_back();
1015 Operation *op = v.getDefiningOp();
1016
1017 FeltConstantOp c;
1018 Value signalVal;
1019 auto handleRefValue = [this, &baseOp, &signalRef, &signalVal, &signalVals]() {
1020 SourceRefLatticeValue refSet =
1021 getSourceRefLattice(baseOp, signalVal)->getOrDefault(signalVal);
1022 if (!refSet.isScalar() || !refSet.isSingleValue()) {
1023 return failure();
1024 }
1025 SourceRef r = refSet.getSingleValue();
1026 if (signalRef.has_value() && signalRef.value() != r) {
1027 return failure();
1028 } else if (!signalRef.has_value()) {
1029 signalRef = r;
1030 }
1031 signalVals.insert(signalVal);
1032 return success();
1033 };
1034
1035 auto subPattern = m_CommutativeOp<SubFeltOp>(m_RefValue(&signalVal), m_Constant(&c));
1036 if (op && matchPattern(op, subPattern)) {
1037 if (failed(handleRefValue())) {
1038 return failure();
1039 }
1040 auto constInt = APSInt(c.getValue());
1041 consts.push_back(field.get().reduce(constInt));
1042 continue;
1043 } else if (m_RefValue(&signalVal).match(v)) {
1044 if (failed(handleRefValue())) {
1045 return failure();
1046 }
1047 consts.push_back(field.get().zero());
1048 continue;
1049 }
1050
1051 Value a, b;
1052 auto mulPattern = m_CommutativeOp<MulFeltOp>(matchers::m_Any(&a), matchers::m_Any(&b));
1053 if (op && matchPattern(op, mulPattern)) {
1054 frontier.push_back(a);
1055 frontier.push_back(b);
1056 continue;
1057 }
1058
1059 return failure();
1060 }
1061
1062 // Now, we aggregate the Interval. If we have sparse values (e.g., 0, 2, 4),
1063 // we will create a larger range of [0, 4], since we don't support multiple intervals.
1064 std::sort(consts.begin(), consts.end());
1065 Interval iv = UnreducedInterval(consts.front(), consts.back()).reduce(field.get());
1066 return std::make_pair(std::move(signalVals), iv);
1067}
1068
1069/* StructIntervals */
1070
1072 mlir::DataFlowSolver &solver, const IntervalAnalysisContext &ctx
1073) {
1074
1075 auto computeIntervalsImpl = [&solver, &ctx, this](
1076 FuncDefOp fn, llvm::MapVector<SourceRef, Interval> &memberRanges,
1077 llvm::SetVector<ExpressionValue> &solverConstraints
1078 ) {
1079 // Since every lattice value does not contain every value, we will traverse
1080 // the function backwards (from most up-to-date to least-up-to-date lattices)
1081 // searching for the source refs. Once a source ref is found, we remove it
1082 // from the search set.
1083
1084 SourceRefSet searchSet;
1085 for (const auto &ref : SourceRef::getAllSourceRefs(structDef, fn)) {
1086 // We only want to compute intervals for field elements and not composite types.
1087 if (!ref.isScalar()) {
1088 continue;
1089 }
1090 searchSet.insert(ref);
1091 }
1092
1093 // Iterate over arguments
1094 for (BlockArgument arg : fn.getArguments()) {
1095 SourceRef ref {arg};
1096 if (searchSet.erase(ref)) {
1097 const IntervalAnalysisLattice *lattice = solver.lookupState<IntervalAnalysisLattice>(arg);
1098 // If we never referenced this argument, use a default value
1099 ExpressionValue expr = lattice->getValue().getScalarValue();
1100 if (!expr.getExpr()) {
1101 expr = expr.withInterval(Interval::Entire(ctx.getField()));
1102 }
1103 memberRanges[ref] = expr.getInterval();
1104 assert(memberRanges[ref].getField() == ctx.getField() && "bad interval defaults");
1105 }
1106 }
1107
1108 // Iterate over members that were touched by the analysis
1109 for (const auto &[ref, lattices] : ctx.intervalDFA->getMemberReadResults()) {
1110 // All lattices should have the same value, so we can get the front.
1111 if (!lattices.empty() && searchSet.erase(ref)) {
1112 const IntervalAnalysisLattice *lattice = *lattices.begin();
1113 memberRanges[ref] = lattice->getValue().getScalarValue().getInterval();
1114 assert(memberRanges[ref].getField() == ctx.getField() && "bad interval defaults");
1115 }
1116 }
1117
1118 for (const auto &[ref, val] : ctx.intervalDFA->getMemberWriteResults()) {
1119 if (searchSet.erase(ref)) {
1120 memberRanges[ref] = val.getInterval();
1121 assert(memberRanges[ref].getField() == ctx.getField() && "bad interval defaults");
1122 }
1123 }
1124
1125 // For all unfound refs, default to the entire range.
1126 for (const auto &ref : searchSet) {
1127 memberRanges[ref] = Interval::Entire(ctx.getField());
1128 }
1129
1130 // Sort the outputs since we assembled things out of order.
1131 llvm::sort(memberRanges, [](auto a, auto b) { return std::get<0>(a) < std::get<0>(b); });
1132 };
1133
1134 computeIntervalsImpl(structDef.getComputeFuncOp(), computeMemberRanges, computeSolverConstraints);
1135 computeIntervalsImpl(
1136 structDef.getConstrainFuncOp(), constrainMemberRanges, constrainSolverConstraints
1137 );
1138
1139 return success();
1140}
1141
1142void StructIntervals::print(mlir::raw_ostream &os, bool withConstraints, bool printCompute) const {
1143 auto writeIntervals =
1144 [&os, &withConstraints](
1145 const char *fnName, const llvm::MapVector<SourceRef, Interval> &memberRanges,
1146 const llvm::SetVector<ExpressionValue> &solverConstraints, bool printName
1147 ) {
1148 int indent = 4;
1149 if (printName) {
1150 os << '\n';
1151 os.indent(indent) << fnName << " {";
1152 indent += 4;
1153 }
1154
1155 if (memberRanges.empty()) {
1156 os << "}\n";
1157 return;
1158 }
1159
1160 for (auto &[ref, interval] : memberRanges) {
1161 os << '\n';
1162 os.indent(indent) << ref << " in " << interval;
1163 }
1164
1165 if (withConstraints) {
1166 os << "\n\n";
1167 os.indent(indent) << "Solver Constraints { ";
1168 if (solverConstraints.empty()) {
1169 os << "}\n";
1170 } else {
1171 for (const auto &e : solverConstraints) {
1172 os << '\n';
1173 os.indent(indent + 4);
1174 e.getExpr()->print(os);
1175 }
1176 os << '\n';
1177 os.indent(indent) << '}';
1178 }
1179 }
1180
1181 if (printName) {
1182 os << '\n';
1183 os.indent(indent - 4) << '}';
1184 }
1185 };
1186
1187 os << "StructIntervals { ";
1188 if (constrainMemberRanges.empty() && (!printCompute || computeMemberRanges.empty())) {
1189 os << "}\n";
1190 return;
1191 }
1192
1193 if (printCompute) {
1194 writeIntervals(FUNC_NAME_COMPUTE, computeMemberRanges, computeSolverConstraints, printCompute);
1195 }
1196 writeIntervals(
1197 FUNC_NAME_CONSTRAIN, constrainMemberRanges, constrainSolverConstraints, printCompute
1198 );
1199
1200 os << "\n}\n";
1201}
1202
1203} // namespace llzk
Tracks a solver expression and an interval range for that expression.
ExpressionValue withExpression(const llvm::SMTExprRef &newExpr) const
Return the current expression with a new SMT expression.
const Interval & getInterval() const
bool isBoolSort(llvm::SMTSolverRef solver) const
ExpressionValue withInterval(const Interval &newInterval) const
Return the current expression with a new interval.
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
llvm::SMTExprRef getExpr() const
const Field & getField() const
Information about the prime finite field used for the interval analysis.
Definition Field.h:27
llvm::SMTExprRef createSymbol(llvm::SMTSolverRef solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
Definition Field.h:71
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
Definition Field.h:39
unsigned bitWidth() const
Definition Field.h:68
const LatticeValue & getValue() const
mlir::ChangeResult setValue(const LatticeValue &val)
IntervalAnalysisLatticeValue LatticeValue
mlir::ChangeResult meet(const AbstractSparseLattice &other) override
mlir::ChangeResult addSolverConstraint(ExpressionValue e)
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult join(const AbstractSparseLattice &other) override
mlir::LogicalResult visitOperation(mlir::Operation *op, mlir::ArrayRef< const Lattice * > operands, mlir::ArrayRef< Lattice * > results) override
Visit an operation with the lattices of its operands.
llvm::SMTExprRef getOrCreateSymbol(const SourceRef &r)
Either return the existing SMT expression that corresponds to the SourceRef, or create one.
const llvm::DenseMap< SourceRef, llvm::DenseSet< Lattice * > > & getMemberReadResults() const
const llvm::DenseMap< SourceRef, ExpressionValue > & getMemberWriteResults() const
Intervals over a finite field.
Definition Intervals.h:200
bool isEmpty() const
Definition Intervals.h:304
static Interval True(const Field &f)
Definition Intervals.h:219
Interval intersect(const Interval &rhs) const
Intersect.
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
Definition Intervals.h:221
static Interval Entire(const Field &f)
Definition Intervals.h:223
bool isDegenerate() const
Definition Intervals.h:306
static Interval False(const Field &f)
Definition Intervals.h:217
Interval join(const Interval &rhs) const
Union.
Defines an index into an LLZK object.
Definition SourceRef.h:36
A value at a given point of the SourceRefLattice.
const SourceRef & getSingleValue() const
A lattice for use in dense analysis.
A reference to a "source", which is the base value from which other SSA values are derived.
Definition SourceRef.h:127
SourceRef createChild(SourceRefIndex r) const
Definition SourceRef.h:251
static std::vector< SourceRef > getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, SourceRef root)
Produce all possible SourceRefs that are present starting from the given root.
void print(mlir::raw_ostream &os, bool withConstraints=false, bool printCompute=false) const
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, const IntervalAnalysisContext &ctx)
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
Definition Intervals.cpp:60
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
Definition Intervals.cpp:83
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Definition Intervals.cpp:75
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
Definition Intervals.cpp:20
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
Definition Intervals.cpp:68
::llzk::boolean::FeltCmpPredicate getPredicate()
Definition Ops.cpp.inc:601
std::variant< ScalarTy, ArrayTy > & getValue()
IntervalAnalysisLattice * getLatticeElement(mlir::Value value) override
auto m_RefValue()
Definition Matchers.h:69
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp. constraint generation) functions within a component.
Definition Constants.h:16
ExpressionValue div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
void ensure(bool condition, const llvm::Twine &errMsg)
ExpressionValue boolToFelt(llvm::SMTSolverRef solver, const ExpressionValue &expr, unsigned bitwidth)
DynamicAPInt toDynamicAPInt(StringRef str)
ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
auto m_Constant()
Definition Matchers.h:89
ExpressionValue shiftLeft(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned string.
ExpressionValue shiftRight(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaInsertionOp(Args &&...args)
Generate a string by using the insertion operator (<<) to append all args to a stream backed by the returned string.
ExpressionValue boolXor(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue bitAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue cmp(llvm::SMTSolverRef solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue fallbackBinaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue intersection(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
APSInt toAPSInt(const DynamicAPInt &i)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
auto m_CommutativeOp(LhsMatcher lhs, RhsMatcher rhs)
Definition Matchers.h:47
ExpressionValue notOp(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolNot(llvm::SMTSolverRef solver, const ExpressionValue &val)
ExpressionValue boolOr(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue boolAnd(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue fallbackUnaryOp(llvm::SMTSolverRef solver, Operation *op, const ExpressionValue &val)
Parameters and shared objects to pass to child analyses.
const Field & getField() const
IntervalDataFlowAnalysis * intervalDFA