LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Intervals.cpp
Go to the documentation of this file.
1//===-- Intervals.cpp ---------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
14
15#include <llvm/ADT/SmallVector.h>
16
17using namespace mlir;
18
19namespace llzk {
20
21/* UnreducedInterval */
22
24 if (a > b) {
25 return Interval::Empty(field);
26 }
27 if (width() >= field.prime()) {
28 return Interval::Entire(field);
29 }
30 auto lhs = field.reduce(a), rhs = field.reduce(b);
31 if (rhs == lhs) {
32 return Interval::Degenerate(field, lhs);
33 }
34
35 const auto &half = field.half();
36 if (lhs <= rhs) {
37 if (lhs < half && rhs < half) {
38 return Interval::TypeA(field, lhs, rhs);
39 } else if (lhs < half) {
40 return Interval::TypeC(field, lhs, rhs);
41 } else {
42 return Interval::TypeB(field, lhs, rhs);
43 }
44 } else {
45 if (lhs >= half && rhs < half) {
46 return Interval::TypeF(field, lhs, rhs);
47 } else {
48 return Interval::Entire(field);
49 }
50 }
51}
52
54 const auto &lhs = *this;
55 return UnreducedInterval(std::max(lhs.a, rhs.a), std::min(lhs.b, rhs.b));
56}
57
59 const auto &lhs = *this;
60 return UnreducedInterval(std::min(lhs.a, rhs.a), std::max(lhs.b, rhs.b));
61}
62
64 if (isEmpty() || rhs.isEmpty()) {
65 return *this;
66 }
67 DynamicAPInt bound = rhs.b - 1;
68 return UnreducedInterval(a, std::min(b, bound));
69}
70
72 if (isEmpty() || rhs.isEmpty()) {
73 return *this;
74 }
75 return UnreducedInterval(a, std::min(b, rhs.b));
76}
77
79 if (isEmpty() || rhs.isEmpty()) {
80 return *this;
81 }
82 DynamicAPInt bound = rhs.a + 1;
83 return UnreducedInterval(std::max(a, bound), b);
84}
85
87 if (isEmpty() || rhs.isEmpty()) {
88 return *this;
89 }
90 return UnreducedInterval(std::max(a, rhs.a), b);
91}
92
94 if (isEmpty()) {
95 return *this;
96 }
97 return UnreducedInterval(-b, -a);
98}
99
101 DynamicAPInt low = lhs.a + rhs.a, high = lhs.b + rhs.b;
102 return UnreducedInterval(low, high);
103}
104
106 return lhs + (-rhs);
107}
108
110 DynamicAPInt v1 = lhs.a * rhs.a;
111 DynamicAPInt v2 = lhs.a * rhs.b;
112 DynamicAPInt v3 = lhs.b * rhs.a;
113 DynamicAPInt v4 = lhs.b * rhs.b;
114
115 auto minVal = std::min({v1, v2, v3, v4});
116 auto maxVal = std::max({v1, v2, v3, v4});
117
118 return UnreducedInterval(minVal, maxVal);
119}
120
122 return isNotEmpty() && rhs.isNotEmpty() && (b >= rhs.a) && (a <= rhs.b);
123}
124
125std::strong_ordering operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs) {
126 if ((lhs.a < rhs.a) || ((lhs.a == rhs.a) && (lhs.b < rhs.b))) {
127 return std::strong_ordering::less;
128 }
129 if ((lhs.a > rhs.a) || ((lhs.a == rhs.a) && (lhs.b > rhs.b))) {
130 return std::strong_ordering::greater;
131 }
132 return std::strong_ordering::equal;
133}
134
135DynamicAPInt UnreducedInterval::width() const {
136 DynamicAPInt w;
137 if (a > b) {
138 // This would be reduced to an empty Interval, so the width is just zero.
139 w = 0;
140 } else {
141 // Since the range is inclusive, we add one to the difference to get the true width.
142 w = (b - a) + 1;
143 }
144 ensure(w >= 0, "cannot have negative width");
145 return w;
146}
147
148/* Interval */
149
150const Field &checkFields(const Interval &lhs, const Interval &rhs) {
151 ensure(
152 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
153 );
154 return lhs.getField();
155}
156
157namespace {
158
159llvm::SmallVector<UnreducedInterval, 2> getUnsignedCanonicalParts(const Interval &iv) {
160 const Field &f = iv.getField();
161 llvm::SmallVector<UnreducedInterval, 2> parts;
162 if (iv.isEmpty()) {
163 return parts;
164 }
165 if (iv.isEntire()) {
166 parts.emplace_back(f.zero(), f.maxVal());
167 return parts;
168 }
169 if (iv.isTypeF()) {
170 parts.emplace_back(iv.lhs(), f.maxVal());
171 parts.emplace_back(f.zero(), iv.rhs());
172 return parts;
173 }
174
175 parts.emplace_back(iv.lhs(), iv.rhs());
176 return parts;
177}
178
179llvm::SmallVector<UnreducedInterval, 2> getSignedCanonicalParts(const Interval &iv) {
180 const Field &f = iv.getField();
181 llvm::SmallVector<UnreducedInterval, 2> parts;
182 if (iv.isEmpty()) {
183 return parts;
184 }
185 if (iv.isEntire()) {
186 parts.emplace_back(f.half() - f.prime(), f.half() - f.one());
187 return parts;
188 }
189 if (iv.isDegenerate()) {
190 DynamicAPInt v = iv.lhs();
191 if (v < f.half()) {
192 parts.emplace_back(v, v);
193 } else {
194 parts.emplace_back(v - f.prime(), v - f.prime());
195 }
196 return parts;
197 }
198 if (iv.isTypeA()) {
199 parts.emplace_back(iv.lhs(), iv.rhs());
200 return parts;
201 }
202 if (iv.isTypeB()) {
203 parts.emplace_back(iv.lhs() - f.prime(), iv.rhs() - f.prime());
204 return parts;
205 }
206 if (iv.isTypeC()) {
207 parts.emplace_back(f.half() - f.prime(), iv.rhs() - f.prime());
208 parts.emplace_back(iv.lhs(), f.half() - f.one());
209 return parts;
210 }
211
212 ensure(iv.isTypeF(), "expected TypeF interval");
213 parts.emplace_back(iv.lhs() - f.prime(), iv.rhs());
214 return parts;
215}
216
217bool containsZero(const UnreducedInterval &iv) { return iv.getLHS() <= 0 && iv.getRHS() >= 0; }
218
219Interval joinDivisionPiece(
220 const Field &f, Interval acc, const llvm::DynamicAPInt &q0, const llvm::DynamicAPInt &q1,
221 const llvm::DynamicAPInt &q2, const llvm::DynamicAPInt &q3
222) {
223 DynamicAPInt minQ = std::min({q0, q1, q2, q3});
224 DynamicAPInt maxQ = std::max({q0, q1, q2, q3});
225 Interval piece = UnreducedInterval(minQ, maxQ).reduce(f);
226 return acc.join(piece);
227}
228
229} // namespace
230
232 if (isEmpty()) {
233 // Since ranges are inclusive, empty is encoded as `[a, b]` where `a` > `b`.
234 // This matches the definition provided by UnreducedInterval::width().
235 return UnreducedInterval(field.get().one(), field.get().zero());
236 }
237 if (isEntire()) {
238 return UnreducedInterval(field.get().zero(), field.get().maxVal());
239 }
240 return UnreducedInterval(a, b);
241}
242
244 if (is<Type::TypeF>()) {
245 return UnreducedInterval(a - field.get().prime(), b);
246 }
247 return toUnreduced();
248}
249
251 ensure(is<Type::TypeA, Type::TypeB, Type::TypeC>(), "unsupported range type");
252 return UnreducedInterval(a - field.get().prime(), b - field.get().prime());
253}
254
256 const auto &lhs = *this;
257 const Field &f = checkFields(lhs, rhs);
258
259 // Trivial cases
260 if (lhs.isEntire() || rhs.isEntire()) {
261 return Interval::Entire(f);
262 }
263 if (lhs.isEmpty()) {
264 return rhs;
265 }
266 if (rhs.isEmpty()) {
267 return lhs;
268 }
269 if (lhs.isDegenerate() || rhs.isDegenerate()) {
270 return lhs.toUnreduced().doUnion(rhs.toUnreduced()).reduce(f);
271 }
272
273 // More complex cases
274 if (areOneOf<
277 auto newLhs = std::min(lhs.a, rhs.a);
278 auto newRhs = std::max(lhs.b, rhs.b);
279 if (newLhs == newRhs) {
280 return Interval::Degenerate(f, newLhs);
281 }
282 return Interval(rhs.ty, f, newLhs, newRhs);
283 }
285 auto lhsUnred = lhs.firstUnreduced();
286 auto opt1 = rhs.firstUnreduced().doUnion(lhsUnred);
287 auto opt2 = rhs.secondUnreduced().doUnion(lhsUnred);
288 if (opt1.width() <= opt2.width()) {
289 return opt1.reduce(f);
290 }
291 return opt2.reduce(f);
292 }
294 return lhs.firstUnreduced().doUnion(rhs.firstUnreduced()).reduce(f);
295 }
297 return lhs.secondUnreduced().doUnion(rhs.firstUnreduced()).reduce(f);
298 }
300 return Interval::Entire(f);
301 }
302 if (areOneOf<
305 lhs, rhs
306 )) {
307 return rhs.join(lhs);
308 }
309 llvm::report_fatal_error("unhandled join case");
310 return Interval::Entire(f);
311}
312
314 const auto &lhs = *this;
315 const Field &f = checkFields(lhs, rhs);
316 // Trivial cases
317 if (lhs == rhs) {
318 return lhs;
319 }
320 if (lhs.isEmpty() || rhs.isEmpty()) {
321 return Interval::Empty(f);
322 }
323 if (lhs.isEntire()) {
324 return rhs;
325 }
326 if (rhs.isEntire()) {
327 return lhs;
328 }
329 if (lhs.isDegenerate() && rhs.isDegenerate()) {
330 // These must not be equal
331 return Interval::Empty(f);
332 }
333 if (lhs.isDegenerate()) {
334 return Interval::TypeA(f, lhs.a, lhs.a).intersect(rhs);
335 }
336 if (rhs.isDegenerate()) {
337 return Interval::TypeA(f, rhs.a, rhs.a).intersect(lhs);
338 }
339
340 // More complex cases
341 if (areOneOf<
344 auto maxA = std::max(lhs.a, rhs.a);
345 auto minB = std::min(lhs.b, rhs.b);
346 if (maxA < minB) {
347 return Interval(lhs.ty, f, maxA, minB);
348 } else if (maxA == minB) {
349 return Interval::Degenerate(f, maxA);
350 } else {
351 return Interval::Empty(f);
352 }
353 }
355 return Interval::Empty(f);
356 }
358 return lhs.firstUnreduced().intersect(rhs.firstUnreduced()).reduce(f);
359 }
361 return lhs.secondUnreduced().intersect(rhs.firstUnreduced()).reduce(f);
362 }
364 auto rhsUnred = rhs.firstUnreduced();
365 auto opt1 = lhs.firstUnreduced().intersect(rhsUnred).reduce(f);
366 auto opt2 = lhs.secondUnreduced().intersect(rhsUnred).reduce(f);
367 ensure(!opt1.isEntire() && !opt2.isEntire(), "impossible intersection");
368 if (opt1.isEmpty()) {
369 return opt2;
370 }
371 if (opt2.isEmpty()) {
372 return opt1;
373 }
374 return opt1.join(opt2);
375 }
376 if (areOneOf<
379 lhs, rhs
380 )) {
381 return rhs.intersect(lhs);
382 }
383 return Interval::Empty(f);
384}
385
387 const Field &f = checkFields(*this, other);
388 // intersect checks that we're in the same field
390 if (intersection.isEmpty()) {
391 // There's nothing to remove, so just return this
392 return *this;
393 }
394
395 // Trivial cases with a non-empty intersection
396 if (isDegenerate() || other.isEntire()) {
397 return Interval::Empty(f);
398 }
399 if (isEntire()) {
400 // Since we don't support punching arbitrary holes in ranges, we only reduce
401 // entire ranges if other is [0, b] or [a, prime - 1]
402 if (other.a == f.zero()) {
403 return UnreducedInterval(other.b + f.one(), f.maxVal()).reduce(f);
404 }
405 if (other.b == f.maxVal()) {
406 return UnreducedInterval(f.zero(), other.a - f.one()).reduce(f);
407 }
408
409 return *this;
410 }
411
412 // Non-trivial cases
413 // - Internal+internal or external+external cases
416 areOneOf<{Type::TypeF, Type::TypeF}>(*this, intersection)) {
417 // The intersection needs to be at the end of the interval, otherwise we would
418 // split the interval in two, and we aren't set up to support multiple intervals
419 // per value.
420 if (a != intersection.a && b != intersection.b) {
421 return *this;
422 }
423 // Otherwise, remove the intersection and reduce
424 if (a == intersection.a) {
425 return UnreducedInterval(intersection.b + f.one(), b).reduce(f);
426 }
427 // else b == intersection.b
428 return UnreducedInterval(a, intersection.a - f.one()).reduce(f);
429 }
430 // - Mixed internal/external cases. We flip the comparison
431 if (isTypeF()) {
432 if (a != intersection.b && b != intersection.a) {
433 return *this;
434 }
435 // Otherwise, remove the intersection and reduce
436 if (a == intersection.b) {
437 return UnreducedInterval(intersection.a + f.one(), b).reduce(f);
438 }
439 // else b == intersection.a
440 return UnreducedInterval(a, intersection.b - f.one()).reduce(f);
441 }
442
443 // In cases we don't know how to handle, we over-approximate and return
444 // the original interval.
445 return *this;
446}
447
448Interval Interval::operator-() const { return (-firstUnreduced()).reduce(field.get()); }
449
451 return Interval::Degenerate(field.get(), field.get().one()) - *this;
452}
453
455 const Field &f = checkFields(lhs, rhs);
456 if (lhs.isEmpty() || rhs.isEntire()) {
457 return rhs;
458 }
459 if (rhs.isEmpty() || lhs.isEntire()) {
460 return lhs;
461 }
462 return (lhs.firstUnreduced() + rhs.firstUnreduced()).reduce(f);
463}
464
465Interval operator-(const Interval &lhs, const Interval &rhs) { return lhs + (-rhs); }
466
468 const Field &f = checkFields(lhs, rhs);
469 auto zeroInterval = Interval::Degenerate(f, f.zero());
470 if (lhs == zeroInterval || rhs == zeroInterval) {
471 return zeroInterval;
472 }
473 if (lhs.isEmpty() || rhs.isEmpty()) {
474 return Interval::Empty(f);
475 }
476 if (lhs.isEntire() || rhs.isEntire()) {
477 return Interval::Entire(f);
478 }
479
481 return (lhs.secondUnreduced() * rhs.secondUnreduced()).reduce(f);
482 }
483 return (lhs.firstUnreduced() * rhs.firstUnreduced()).reduce(f);
484}
485
486FailureOr<Interval> feltDiv(const Interval &lhs, const Interval &rhs) {
487 const Field &f = checkFields(lhs, rhs);
488 if (lhs.isEmpty() || rhs.isEmpty()) {
489 return success(Interval::Empty(f));
490 }
491 if (!rhs.isDegenerate() || rhs.lhs() == f.zero()) {
492 // Supporting arbitrary divisor intervals would require enumerating every
493 // possible divisor, inverting each value, and joining the products, which
494 // is too expensive. So, we return a failure in the non-degenerate case
495 // and in the divide-by-zero case.
496 return failure();
497 }
498 return success(lhs * Interval::Degenerate(f, f.inv(rhs.lhs())));
499}
500
501FailureOr<Interval> unsignedIntDiv(const Interval &lhs, const Interval &rhs) {
502 const Field &f = checkFields(lhs, rhs);
503 if (lhs.isEmpty() || rhs.isEmpty()) {
504 return success(Interval::Empty(f));
505 }
506
507 llvm::SmallVector<UnreducedInterval, 2> lhsParts = getUnsignedCanonicalParts(lhs);
508 llvm::SmallVector<UnreducedInterval, 2> rhsParts = getUnsignedCanonicalParts(rhs);
509 for (const UnreducedInterval &rhsPart : rhsParts) {
510 if (rhsPart.getLHS() == f.zero()) {
511 return failure();
512 }
513 }
514
515 Interval result = Interval::Empty(f);
516 for (const UnreducedInterval &lhsPart : lhsParts) {
517 for (const UnreducedInterval &rhsPart : rhsParts) {
518 result = joinDivisionPiece(
519 f, result, lhsPart.getLHS() / rhsPart.getRHS(), lhsPart.getLHS() / rhsPart.getLHS(),
520 lhsPart.getRHS() / rhsPart.getRHS(), lhsPart.getRHS() / rhsPart.getLHS()
521 );
522 }
523 }
524 return success(result);
525}
526
527FailureOr<Interval> signedIntDiv(const Interval &lhs, const Interval &rhs) {
528 const Field &f = checkFields(lhs, rhs);
529 if (lhs.isEmpty() || rhs.isEmpty()) {
530 return success(Interval::Empty(f));
531 }
532
533 llvm::SmallVector<UnreducedInterval, 2> lhsParts = getSignedCanonicalParts(lhs);
534 llvm::SmallVector<UnreducedInterval, 2> rhsParts = getSignedCanonicalParts(rhs);
535 for (const UnreducedInterval &rhsPart : rhsParts) {
536 if (containsZero(rhsPart)) {
537 return failure();
538 }
539 }
540
541 Interval result = Interval::Empty(f);
542 for (const UnreducedInterval &lhsPart : lhsParts) {
543 for (const UnreducedInterval &rhsPart : rhsParts) {
544 result = joinDivisionPiece(
545 f, result, lhsPart.getLHS() / rhsPart.getLHS(), lhsPart.getLHS() / rhsPart.getRHS(),
546 lhsPart.getRHS() / rhsPart.getLHS(), lhsPart.getRHS() / rhsPart.getRHS()
547 );
548 }
549 }
550 return success(result);
551}
552
554 const Field &f = checkFields(lhs, rhs);
555 if (lhs.isEmpty() || rhs.isEmpty()) {
556 return Interval::Empty(f);
557 }
558
559 if (lhs.isDegenerate() && rhs.isDegenerate() && rhs.a != f.zero()) {
560 return Interval::Degenerate(f, lhs.a % rhs.a);
561 }
562
563 if (rhs.isDegenerate()) {
564 if (rhs.a == f.zero()) {
565 return Interval::Entire(f);
566 }
567 return UnreducedInterval(f.zero(), rhs.a - f.one()).reduce(f);
568 }
569
570 // For any interval modulus, the result is bounded by the largest value of
571 // the interval.
572 // Since TypeF wraps around, the interval is just Entire since the max value
573 // would be the prime field's max value.
574 if (rhs.isTypeF() || rhs.isEntire()) {
575 return Interval::Entire(f);
576 }
577 // Any possible division by zero also yields Entire
578 Interval zeroInt = Interval::Degenerate(f, f.zero());
579 if (rhs.intersect(zeroInt) == zeroInt) {
580 return Interval::Entire(f);
581 }
582
583 return UnreducedInterval(f.zero(), rhs.b - f.one()).reduce(f);
584}
585
587 const Field &f = checkFields(lhs, rhs);
588 if (lhs.isEmpty() || rhs.isEmpty()) {
589 return Interval::Empty(f);
590 }
591 if (lhs.isDegenerate() && rhs.isDegenerate()) {
592 return Interval::Degenerate(f, lhs.a & rhs.a);
593 } else if (lhs.isDegenerate()) {
594 return UnreducedInterval(f.zero(), lhs.a).reduce(f);
595 } else if (rhs.isDegenerate()) {
596 return UnreducedInterval(f.zero(), rhs.a).reduce(f);
597 }
598 return Interval::Entire(f);
599}
600
602 const Field &f = checkFields(lhs, rhs);
603 if (lhs.isEmpty() || rhs.isEmpty()) {
604 return Interval::Empty(f);
605 }
606 auto zeroInterval = Interval::Degenerate(f, f.zero());
607 if (lhs == zeroInterval) {
608 return rhs;
609 }
610 if (rhs == zeroInterval) {
611 return lhs;
612 }
613 if (lhs.isDegenerate() && rhs.isDegenerate()) {
614 return Interval::Degenerate(f, f.reduce(lhs.a | rhs.a));
615 }
616 return Interval::Entire(f);
617}
618
620 const Field &f = checkFields(lhs, rhs);
621 if (lhs.isEmpty() || rhs.isEmpty()) {
622 return Interval::Empty(f);
623 }
624 auto zeroInterval = Interval::Degenerate(f, f.zero());
625 if (lhs == zeroInterval) {
626 return rhs;
627 }
628 if (rhs == zeroInterval) {
629 return lhs;
630 }
631 if (lhs.isDegenerate() && rhs.isDegenerate()) {
632 return Interval::Degenerate(f, f.reduce(lhs.a ^ rhs.a));
633 }
634 return Interval::Entire(f);
635}
636
638 const Field &f = checkFields(lhs, rhs);
639 if (lhs.isEmpty() || rhs.isEmpty()) {
640 return Interval::Empty(f);
641 }
642 if (lhs.isDegenerate() && rhs.isDegenerate()) {
643 if (rhs.a > f.bitWidth()) {
644 return Interval::Entire(f);
645 }
646
647 DynamicAPInt v = lhs.a << rhs.a;
648 return UnreducedInterval(v, v).reduce(f);
649 }
650 return Interval::Entire(f);
651}
652
654 const Field &f = checkFields(lhs, rhs);
655 if (lhs.isEmpty() || rhs.isEmpty()) {
656 return Interval::Empty(f);
657 }
658 if (lhs.isDegenerate() && rhs.isDegenerate()) {
659 if (rhs.a > f.bitWidth()) {
660 return Interval::Degenerate(f, f.zero());
661 }
662
663 return Interval::Degenerate(f, lhs.a >> rhs.a);
664 }
665 return Interval::Entire(f);
666}
667
668DynamicAPInt Interval::width() const {
669 switch (ty) {
670 case Type::Empty:
671 return field.get().zero();
672 case Type::Degenerate:
673 return field.get().one();
674 case Type::Entire:
675 return field.get().prime();
676 default:
677 return field.get().reduce(toUnreduced().width());
678 }
679}
680
682 ensure(
683 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
684 );
685 ensure(lhs.isBoolean() && rhs.isBoolean(), "operation only supported for boolean-type intervals");
686 const auto &field = rhs.getField();
687
688 if (lhs.isBoolFalse() || rhs.isBoolFalse()) {
689 return Interval::False(field);
690 }
691 if (lhs.isBoolTrue() && rhs.isBoolTrue()) {
692 return Interval::True(field);
693 }
694
695 return Interval::Boolean(field);
696}
697
699 ensure(
700 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
701 );
702 ensure(lhs.isBoolean() && rhs.isBoolean(), "operation only supported for boolean-type intervals");
703 const auto &field = rhs.getField();
704
705 if (lhs.isBoolFalse() && rhs.isBoolFalse()) {
706 return Interval::False(field);
707 }
708 if (lhs.isBoolTrue() || rhs.isBoolTrue()) {
709 return Interval::True(field);
710 }
711
712 return Interval::Boolean(field);
713}
714
716 ensure(
717 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
718 );
719 ensure(lhs.isBoolean() && rhs.isBoolean(), "operation only supported for boolean-type intervals");
720 const auto &field = rhs.getField();
721
722 // Xor-ing anything with [0, 1] could still result in either case, so just return
723 // the full boolean range.
724 if (lhs.isBoolEither() || rhs.isBoolEither()) {
725 return Interval::Boolean(lhs.getField());
726 }
727
728 if (lhs.isBoolTrue() && rhs.isBoolTrue()) {
729 return Interval::False(field);
730 }
731 if (lhs.isBoolTrue() || rhs.isBoolTrue()) {
732 return Interval::True(field);
733 }
734 if (lhs.isBoolFalse() && rhs.isBoolFalse()) {
735 return Interval::False(field);
736 }
737
738 return Interval::Boolean(field);
739}
740
742 ensure(iv.isBoolean(), "operation only supported for boolean-type intervals");
743 const auto &field = iv.getField();
744
745 if (iv.isBoolTrue()) {
746 return Interval::False(field);
747 }
748 if (iv.isBoolFalse()) {
749 return Interval::True(field);
750 }
751
752 return iv;
753}
754
755void Interval::print(mlir::raw_ostream &os) const {
756 os << TypeName(ty);
757 if (is<Type::Degenerate>()) {
758 os << '(' << a << ')';
759 } else if (!is<Type::Entire, Type::Empty>()) {
760 os << ":[ " << a << ", " << b << " ]";
761 }
762}
763
764} // namespace llzk
This file implements helper methods for constructing DynamicAPInts.
Information about the prime finite field used for the interval analysis.
Definition Field.h:35
llvm::DynamicAPInt half() const
Returns p / 2.
Definition Field.h:74
llvm::DynamicAPInt zero() const
Returns 0 at the bitwidth of the field.
Definition Field.h:80
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
Definition Field.h:71
llvm::DynamicAPInt reduce(const llvm::DynamicAPInt &i) const
Returns i mod p and reduces the result into the appropriate bitwidth.
llvm::DynamicAPInt one() const
Returns 1 at the bitwidth of the field.
Definition Field.h:83
llvm::DynamicAPInt inv(const llvm::DynamicAPInt &i) const
Returns the multiplicative inverse of i in prime field p.
unsigned bitWidth() const
Definition Field.h:106
static const Field & getField(llvm::StringRef fieldName, EmitErrorFn errFn)
Get a Field from a given field name string.
llvm::DynamicAPInt maxVal() const
Returns p - 1, which is the max value possible in a prime field described by p.
Definition Field.h:86
Intervals over a finite field.
Definition Intervals.h:200
bool isEmpty() const
Definition Intervals.h:308
static Interval True(const Field &f)
Definition Intervals.h:219
llvm::DynamicAPInt rhs() const
Definition Intervals.h:333
Interval intersect(const Interval &rhs) const
Intersect.
bool isBoolean() const
Definition Intervals.h:320
static std::string_view TypeName(Type t)
Definition Intervals.h:207
void print(llvm::raw_ostream &os) const
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
Definition Intervals.h:221
bool isBoolFalse() const
Definition Intervals.h:317
UnreducedInterval firstUnreduced() const
Get the first side of the interval for TypeF intervals, otherwise just get the full interval as an UnreducedInterval (with toUnreduced()).
bool isDegenerate() const
Definition Intervals.h:310
const Field & getField() const
Definition Intervals.h:328
bool isBoolTrue() const
Definition Intervals.h:318
bool is() const
Definition Intervals.h:322
UnreducedInterval secondUnreduced() const
Get the second side of the interval for TypeA, TypeB, and TypeC intervals.
bool isTypeF() const
Definition Intervals.h:315
static Interval False(const Field &f)
Definition Intervals.h:217
static Interval Empty(const Field &f)
Definition Intervals.h:211
Interval operator~() const
llvm::DynamicAPInt lhs() const
Definition Intervals.h:332
static bool areOneOf(const Interval &a, const Interval &b)
Definition Intervals.h:257
Interval()
To satisfy the dataflow::ScalarLatticeValue requirements, this class must be default initializable.
Definition Intervals.h:243
static Interval Degenerate(const Field &f, const llvm::DynamicAPInt &val)
Definition Intervals.h:213
llvm::DynamicAPInt width() const
bool isEntire() const
Definition Intervals.h:311
Interval difference(const Interval &other) const
Computes and returns this - (this & other) if the operation produces a single interval.
Interval operator-() const
Interval join(const Interval &rhs) const
Union.
An inclusive interval [a, b] where a and b are arbitrary integers not necessarily bound to a given field.
Definition Intervals.h:26
UnreducedInterval operator-() const
Definition Intervals.cpp:93
UnreducedInterval intersect(const UnreducedInterval &rhs) const
Compute and return the intersection of this interval and the given RHS.
Definition Intervals.cpp:53
UnreducedInterval(const llvm::DynamicAPInt &x, const llvm::DynamicAPInt &y)
Definition Intervals.h:28
bool isEmpty() const
Returns true iff width() is zero.
Definition Intervals.h:114
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
Definition Intervals.cpp:63
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
Definition Intervals.cpp:86
bool isNotEmpty() const
Definition Intervals.h:116
bool overlaps(const UnreducedInterval &rhs) const
llvm::DynamicAPInt width() const
Compute the width of this inclusive interval: (b - a) + 1, or zero when a > b (the encoding of an empty interval).
UnreducedInterval doUnion(const UnreducedInterval &rhs) const
Compute and return the union of this interval and the given RHS.
Definition Intervals.cpp:58
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Definition Intervals.cpp:78
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
Definition Intervals.cpp:23
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
Definition Intervals.cpp:71
ExpressionValue boolNot(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
ExpressionValue intersection(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
FailureOr< Interval > signedIntDiv(const Interval &lhs, const Interval &rhs)
Computes signed integer division with possibly non-Degenerate divisors.
Interval operator|(const Interval &lhs, const Interval &rhs)
Interval operator^(const Interval &lhs, const Interval &rhs)
void ensure(bool condition, const llvm::Twine &errMsg)
ExpressionValue boolXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Interval operator%(const Interval &lhs, const Interval &rhs)
Interval operator<<(const Interval &lhs, const Interval &rhs)
ExpressionValue boolAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
FailureOr< Interval > unsignedIntDiv(const Interval &lhs, const Interval &rhs)
Computes unsigned integer division with possibly non-Degenerate divisors.
std::strong_ordering operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
UnreducedInterval operator-(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
Interval operator>>(const Interval &lhs, const Interval &rhs)
Interval operator&(const Interval &lhs, const Interval &rhs)
UnreducedInterval operator*(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
const Field & checkFields(const Interval &lhs, const Interval &rhs)
UnreducedInterval operator+(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
FailureOr< Interval > feltDiv(const Interval &lhs, const Interval &rhs)
Computes finite-field division by multiplying the dividend by the multiplicative inverse of the divisor.
ExpressionValue boolOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)