LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Intervals.cpp
Go to the documentation of this file.
1//===-- Intervals.cpp ---------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
14
15#include <llvm/ADT/SmallVector.h>
16
17using namespace mlir;
18
19namespace llzk {
20
21/* UnreducedInterval */
22
24 if (a > b) {
25 return Interval::Empty(field);
26 }
27 if (width() >= field.prime()) {
28 return Interval::Entire(field);
29 }
30 auto lhs = field.reduce(a), rhs = field.reduce(b);
31 if (rhs == lhs) {
32 return Interval::Degenerate(field, lhs);
33 }
34
35 const auto &half = field.half();
36 if (lhs <= rhs) {
37 if (lhs < half && rhs < half) {
38 return Interval::TypeA(field, lhs, rhs);
39 } else if (lhs < half) {
40 return Interval::TypeC(field, lhs, rhs);
41 } else {
42 return Interval::TypeB(field, lhs, rhs);
43 }
44 } else {
45 if (lhs >= half && rhs < half) {
46 return Interval::TypeF(field, lhs, rhs);
47 } else {
48 return Interval::Entire(field);
49 }
50 }
51}
52
54 const auto &lhs = *this;
55 return UnreducedInterval(std::max(lhs.a, rhs.a), std::min(lhs.b, rhs.b));
56}
57
59 const auto &lhs = *this;
60 return UnreducedInterval(std::min(lhs.a, rhs.a), std::max(lhs.b, rhs.b));
61}
62
64 if (isEmpty() || rhs.isEmpty()) {
65 return *this;
66 }
67 DynamicAPInt bound = rhs.b - 1;
68 return UnreducedInterval(a, std::min(b, bound));
69}
70
72 if (isEmpty() || rhs.isEmpty()) {
73 return *this;
74 }
75 return UnreducedInterval(a, std::min(b, rhs.b));
76}
77
79 if (isEmpty() || rhs.isEmpty()) {
80 return *this;
81 }
82 DynamicAPInt bound = rhs.a + 1;
83 return UnreducedInterval(std::max(a, bound), b);
84}
85
87 if (isEmpty() || rhs.isEmpty()) {
88 return *this;
89 }
90 return UnreducedInterval(std::max(a, rhs.a), b);
91}
92
94 if (isEmpty()) {
95 return *this;
96 }
97 return UnreducedInterval(-b, -a);
98}
99
101 DynamicAPInt low = lhs.a + rhs.a, high = lhs.b + rhs.b;
102 return UnreducedInterval(low, high);
103}
104
106 return lhs + (-rhs);
107}
108
110 DynamicAPInt v1 = lhs.a * rhs.a;
111 DynamicAPInt v2 = lhs.a * rhs.b;
112 DynamicAPInt v3 = lhs.b * rhs.a;
113 DynamicAPInt v4 = lhs.b * rhs.b;
114
115 auto minVal = std::min({v1, v2, v3, v4});
116 auto maxVal = std::max({v1, v2, v3, v4});
117
118 return UnreducedInterval(minVal, maxVal);
119}
120
122 return isNotEmpty() && rhs.isNotEmpty() && (b >= rhs.a) && (a <= rhs.b);
123}
124
125std::strong_ordering operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs) {
126 if ((lhs.a < rhs.a) || ((lhs.a == rhs.a) && (lhs.b < rhs.b))) {
127 return std::strong_ordering::less;
128 }
129 if ((lhs.a > rhs.a) || ((lhs.a == rhs.a) && (lhs.b > rhs.b))) {
130 return std::strong_ordering::greater;
131 }
132 return std::strong_ordering::equal;
133}
134
135DynamicAPInt UnreducedInterval::width() const {
136 DynamicAPInt w;
137 if (a > b) {
138 // This would be reduced to an empty Interval, so the width is just zero.
139 w = 0;
140 } else {
141 // Since the range is inclusive, we add one to the difference to get the true width.
142 w = (b - a) + 1;
143 }
144 ensure(w >= 0, "cannot have negative width");
145 return w;
146}
147
148/* Interval */
149
150const Field &checkFields(const Interval &lhs, const Interval &rhs) {
151 ensure(
152 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
153 );
154 return lhs.getField();
155}
156
157namespace {
158
159llvm::SmallVector<UnreducedInterval, 2> getUnsignedCanonicalParts(const Interval &iv) {
160 const Field &f = iv.getField();
161 llvm::SmallVector<UnreducedInterval, 2> parts;
162 if (iv.isEmpty()) {
163 return parts;
164 }
165 if (iv.isEntire()) {
166 parts.emplace_back(f.zero(), f.maxVal());
167 return parts;
168 }
169 if (iv.isTypeF()) {
170 parts.emplace_back(iv.lhs(), f.maxVal());
171 parts.emplace_back(f.zero(), iv.rhs());
172 return parts;
173 }
174
175 parts.emplace_back(iv.lhs(), iv.rhs());
176 return parts;
177}
178
179llvm::SmallVector<UnreducedInterval, 2> getSignedCanonicalParts(const Interval &iv) {
180 const Field &f = iv.getField();
181 llvm::SmallVector<UnreducedInterval, 2> parts;
182 if (iv.isEmpty()) {
183 return parts;
184 }
185 if (iv.isEntire()) {
186 parts.emplace_back(f.half() - f.prime(), f.half() - f.one());
187 return parts;
188 }
189 if (iv.isDegenerate()) {
190 DynamicAPInt v = iv.lhs();
191 if (v < f.half()) {
192 parts.emplace_back(v, v);
193 } else {
194 parts.emplace_back(v - f.prime(), v - f.prime());
195 }
196 return parts;
197 }
198 if (iv.isTypeA()) {
199 parts.emplace_back(iv.lhs(), iv.rhs());
200 return parts;
201 }
202 if (iv.isTypeB()) {
203 parts.emplace_back(iv.lhs() - f.prime(), iv.rhs() - f.prime());
204 return parts;
205 }
206 if (iv.isTypeC()) {
207 parts.emplace_back(f.half() - f.prime(), iv.rhs() - f.prime());
208 parts.emplace_back(iv.lhs(), f.half() - f.one());
209 return parts;
210 }
211
212 ensure(iv.isTypeF(), "expected TypeF interval");
213 parts.emplace_back(iv.lhs() - f.prime(), iv.rhs());
214 return parts;
215}
216
217bool containsZero(const UnreducedInterval &iv) { return iv.getLHS() <= 0 && iv.getRHS() >= 0; }
218
219llvm::DynamicAPInt absSigned(const llvm::DynamicAPInt &v) { return v < 0 ? -v : v; }
220
221Interval joinDivisionPiece(
222 const Field &f, const Interval &acc, const llvm::DynamicAPInt &q0, const llvm::DynamicAPInt &q1,
223 const llvm::DynamicAPInt &q2, const llvm::DynamicAPInt &q3
224) {
225 DynamicAPInt minQ = std::min({q0, q1, q2, q3});
226 DynamicAPInt maxQ = std::max({q0, q1, q2, q3});
227 Interval piece = UnreducedInterval(minQ, maxQ).reduce(f);
228 return acc.join(piece);
229}
230
231} // namespace
232
234 if (isEmpty()) {
235 // Since ranges are inclusive, empty is encoded as `[a, b]` where `a` > `b`.
236 // This matches the definition provided by UnreducedInterval::width().
237 return UnreducedInterval(field.get().one(), field.get().zero());
238 }
239 if (isEntire()) {
240 return UnreducedInterval(field.get().zero(), field.get().maxVal());
241 }
242 return UnreducedInterval(a, b);
243}
244
246 if (is<Type::TypeF>()) {
247 return UnreducedInterval(a - field.get().prime(), b);
248 }
249 return toUnreduced();
250}
251
253 ensure(is<Type::TypeA, Type::TypeB, Type::TypeC>(), "unsupported range type");
254 return UnreducedInterval(a - field.get().prime(), b - field.get().prime());
255}
256
258 const auto &lhs = *this;
259 const Field &f = checkFields(lhs, rhs);
260
261 // Trivial cases
262 if (lhs.isEntire() || rhs.isEntire()) {
263 return Interval::Entire(f);
264 }
265 if (lhs.isEmpty()) {
266 return rhs;
267 }
268 if (rhs.isEmpty()) {
269 return lhs;
270 }
271 if (lhs.isDegenerate() || rhs.isDegenerate()) {
272 return lhs.toUnreduced().doUnion(rhs.toUnreduced()).reduce(f);
273 }
274
275 // More complex cases
276 if (areOneOf<
279 auto newLhs = std::min(lhs.a, rhs.a);
280 auto newRhs = std::max(lhs.b, rhs.b);
281 if (newLhs == newRhs) {
282 return Interval::Degenerate(f, newLhs);
283 }
284 return Interval(rhs.ty, f, newLhs, newRhs);
285 }
287 auto lhsUnred = lhs.firstUnreduced();
288 auto opt1 = rhs.firstUnreduced().doUnion(lhsUnred);
289 auto opt2 = rhs.secondUnreduced().doUnion(lhsUnred);
290 if (opt1.width() <= opt2.width()) {
291 return opt1.reduce(f);
292 }
293 return opt2.reduce(f);
294 }
296 return lhs.firstUnreduced().doUnion(rhs.firstUnreduced()).reduce(f);
297 }
299 return lhs.secondUnreduced().doUnion(rhs.firstUnreduced()).reduce(f);
300 }
302 return Interval::Entire(f);
303 }
304 if (areOneOf<
307 lhs, rhs
308 )) {
309 return rhs.join(lhs);
310 }
311 llvm::report_fatal_error("unhandled join case");
312 return Interval::Entire(f);
313}
314
316 const auto &lhs = *this;
317 const Field &f = checkFields(lhs, rhs);
318 // Trivial cases
319 if (lhs == rhs) {
320 return lhs;
321 }
322 if (lhs.isEmpty() || rhs.isEmpty()) {
323 return Interval::Empty(f);
324 }
325 if (lhs.isEntire()) {
326 return rhs;
327 }
328 if (rhs.isEntire()) {
329 return lhs;
330 }
331 if (lhs.isDegenerate() && rhs.isDegenerate()) {
332 // These must not be equal
333 return Interval::Empty(f);
334 }
335 if (lhs.isDegenerate()) {
336 return Interval::TypeA(f, lhs.a, lhs.a).intersect(rhs);
337 }
338 if (rhs.isDegenerate()) {
339 return Interval::TypeA(f, rhs.a, rhs.a).intersect(lhs);
340 }
341
342 // More complex cases
343 if (areOneOf<
346 auto maxA = std::max(lhs.a, rhs.a);
347 auto minB = std::min(lhs.b, rhs.b);
348 if (maxA < minB) {
349 return Interval(lhs.ty, f, maxA, minB);
350 } else if (maxA == minB) {
351 return Interval::Degenerate(f, maxA);
352 } else {
353 return Interval::Empty(f);
354 }
355 }
357 return Interval::Empty(f);
358 }
360 return lhs.firstUnreduced().intersect(rhs.firstUnreduced()).reduce(f);
361 }
363 return lhs.secondUnreduced().intersect(rhs.firstUnreduced()).reduce(f);
364 }
366 auto rhsUnred = rhs.firstUnreduced();
367 auto opt1 = lhs.firstUnreduced().intersect(rhsUnred).reduce(f);
368 auto opt2 = lhs.secondUnreduced().intersect(rhsUnred).reduce(f);
369 ensure(!opt1.isEntire() && !opt2.isEntire(), "impossible intersection");
370 if (opt1.isEmpty()) {
371 return opt2;
372 }
373 if (opt2.isEmpty()) {
374 return opt1;
375 }
376 return opt1.join(opt2);
377 }
378 if (areOneOf<
381 lhs, rhs
382 )) {
383 return rhs.intersect(lhs);
384 }
385 return Interval::Empty(f);
386}
387
389 const Field &f = checkFields(*this, other);
390 // intersect checks that we're in the same field
392 if (intersection.isEmpty()) {
393 // There's nothing to remove, so just return this
394 return *this;
395 }
396
397 // Trivial cases with a non-empty intersection
398 if (isDegenerate() || other.isEntire()) {
399 return Interval::Empty(f);
400 }
401 if (isEntire()) {
402 // Since we don't support punching arbitrary holes in ranges, we only reduce
403 // entire ranges if other is [0, b] or [a, prime - 1]
404 if (other.a == f.zero()) {
405 return UnreducedInterval(other.b + f.one(), f.maxVal()).reduce(f);
406 }
407 if (other.b == f.maxVal()) {
408 return UnreducedInterval(f.zero(), other.a - f.one()).reduce(f);
409 }
410
411 return *this;
412 }
413
414 // Non-trivial cases
415 // - Internal+internal or external+external cases
418 areOneOf<{Type::TypeF, Type::TypeF}>(*this, intersection)) {
419 // The intersection needs to be at the end of the interval, otherwise we would
420 // split the interval in two, and we aren't set up to support multiple intervals
421 // per value.
422 if (a != intersection.a && b != intersection.b) {
423 return *this;
424 }
425 // Otherwise, remove the intersection and reduce
426 if (a == intersection.a) {
427 return UnreducedInterval(intersection.b + f.one(), b).reduce(f);
428 }
429 // else b == intersection.b
430 return UnreducedInterval(a, intersection.a - f.one()).reduce(f);
431 }
432 // - Mixed internal/external cases. We flip the comparison
433 if (isTypeF()) {
434 if (a != intersection.b && b != intersection.a) {
435 return *this;
436 }
437 // Otherwise, remove the intersection and reduce
438 if (a == intersection.b) {
439 return UnreducedInterval(intersection.a + f.one(), b).reduce(f);
440 }
441 // else b == intersection.a
442 return UnreducedInterval(a, intersection.b - f.one()).reduce(f);
443 }
444
445 // In cases we don't know how to handle, we over-approximate and return
446 // the original interval.
447 return *this;
448}
449
450Interval Interval::operator-() const { return (-firstUnreduced()).reduce(field.get()); }
451
453 return Interval::Degenerate(field.get(), field.get().one()) - *this;
454}
455
457 const Field &f = checkFields(lhs, rhs);
458 if (lhs.isEmpty() || rhs.isEntire()) {
459 return rhs;
460 }
461 if (rhs.isEmpty() || lhs.isEntire()) {
462 return lhs;
463 }
464 return (lhs.firstUnreduced() + rhs.firstUnreduced()).reduce(f);
465}
466
467Interval operator-(const Interval &lhs, const Interval &rhs) { return lhs + (-rhs); }
468
470 const Field &f = checkFields(lhs, rhs);
471 auto zeroInterval = Interval::Degenerate(f, f.zero());
472 if (lhs == zeroInterval || rhs == zeroInterval) {
473 return zeroInterval;
474 }
475 if (lhs.isEmpty() || rhs.isEmpty()) {
476 return Interval::Empty(f);
477 }
478 if (lhs.isEntire() || rhs.isEntire()) {
479 return Interval::Entire(f);
480 }
481
483 return (lhs.secondUnreduced() * rhs.secondUnreduced()).reduce(f);
484 }
485 return (lhs.firstUnreduced() * rhs.firstUnreduced()).reduce(f);
486}
487
488FailureOr<Interval> feltDiv(const Interval &lhs, const Interval &rhs) {
489 const Field &f = checkFields(lhs, rhs);
490 if (lhs.isEmpty() || rhs.isEmpty()) {
491 return success(Interval::Empty(f));
492 }
493 if (!rhs.isDegenerate() || rhs.lhs() == f.zero()) {
494 // Supporting arbitrary divisor intervals would require enumerating every
495 // possible divisor, inverting each value, and joining the products, which
496 // is too expensive. So, we return a failure in the non-degenerate case
497 // and in the divide-by-zero case.
498 return failure();
499 }
500 return success(lhs * Interval::Degenerate(f, f.inv(rhs.lhs())));
501}
502
503FailureOr<Interval> unsignedIntDiv(const Interval &lhs, const Interval &rhs) {
504 const Field &f = checkFields(lhs, rhs);
505 if (lhs.isEmpty() || rhs.isEmpty()) {
506 return success(Interval::Empty(f));
507 }
508
509 llvm::SmallVector<UnreducedInterval, 2> lhsParts = getUnsignedCanonicalParts(lhs);
510 llvm::SmallVector<UnreducedInterval, 2> rhsParts = getUnsignedCanonicalParts(rhs);
511 for (const UnreducedInterval &rhsPart : rhsParts) {
512 if (rhsPart.getLHS() == f.zero()) {
513 return failure();
514 }
515 }
516
517 Interval result = Interval::Empty(f);
518 for (const UnreducedInterval &lhsPart : lhsParts) {
519 for (const UnreducedInterval &rhsPart : rhsParts) {
520 result = joinDivisionPiece(
521 f, result, lhsPart.getLHS() / rhsPart.getRHS(), lhsPart.getLHS() / rhsPart.getLHS(),
522 lhsPart.getRHS() / rhsPart.getRHS(), lhsPart.getRHS() / rhsPart.getLHS()
523 );
524 }
525 }
526 return success(result);
527}
528
529FailureOr<Interval> signedIntDiv(const Interval &lhs, const Interval &rhs) {
530 const Field &f = checkFields(lhs, rhs);
531 if (lhs.isEmpty() || rhs.isEmpty()) {
532 return success(Interval::Empty(f));
533 }
534
535 llvm::SmallVector<UnreducedInterval, 2> lhsParts = getSignedCanonicalParts(lhs);
536 llvm::SmallVector<UnreducedInterval, 2> rhsParts = getSignedCanonicalParts(rhs);
537 for (const UnreducedInterval &rhsPart : rhsParts) {
538 if (containsZero(rhsPart)) {
539 return failure();
540 }
541 }
542
543 Interval result = Interval::Empty(f);
544 for (const UnreducedInterval &lhsPart : lhsParts) {
545 for (const UnreducedInterval &rhsPart : rhsParts) {
546 result = joinDivisionPiece(
547 f, result, lhsPart.getLHS() / rhsPart.getLHS(), lhsPart.getLHS() / rhsPart.getRHS(),
548 lhsPart.getRHS() / rhsPart.getLHS(), lhsPart.getRHS() / rhsPart.getRHS()
549 );
550 }
551 }
552 return success(result);
553}
554
555Interval signedMod(const Interval &lhs, const Interval &rhs) {
556 const Field &f = checkFields(lhs, rhs);
557 if (lhs.isEmpty() || rhs.isEmpty()) {
558 return Interval::Empty(f);
559 }
560
561 llvm::SmallVector<UnreducedInterval, 2> lhsParts = getSignedCanonicalParts(lhs);
562 llvm::SmallVector<UnreducedInterval, 2> rhsParts = getSignedCanonicalParts(rhs);
563 Interval result = Interval::Empty(f);
564
565 for (const UnreducedInterval &lhsPart : lhsParts) {
566 for (const UnreducedInterval &rhsPart : rhsParts) {
567 if (containsZero(rhsPart)) {
568 return Interval::Entire(f);
569 }
570
571 if (lhsPart.getLHS() == lhsPart.getRHS() && rhsPart.getLHS() == rhsPart.getRHS()) {
572 auto rem = lhsPart.getLHS() % rhsPart.getLHS();
573 result = result.join(UnreducedInterval(rem, rem).reduce(f));
574 continue;
575 }
576
577 llvm::DynamicAPInt maxAbsDivisor =
578 std::max(absSigned(rhsPart.getLHS()), absSigned(rhsPart.getRHS()));
579 if (maxAbsDivisor == 0) {
580 return Interval::Entire(f);
581 }
582
583 llvm::DynamicAPInt maxAbsRemainder = maxAbsDivisor - 1;
584 llvm::DynamicAPInt low = lhsPart.getLHS() < 0 ? std::max(lhsPart.getLHS(), -maxAbsRemainder)
585 : llvm::DynamicAPInt(0);
586 llvm::DynamicAPInt high = lhsPart.getRHS() > 0 ? std::min(lhsPart.getRHS(), maxAbsRemainder)
587 : llvm::DynamicAPInt(0);
588 result = result.join(UnreducedInterval(low, high).reduce(f));
589 }
590 }
591
592 return result;
593}
594
596 const Field &f = checkFields(lhs, rhs);
597 if (lhs.isEmpty() || rhs.isEmpty()) {
598 return Interval::Empty(f);
599 }
600
601 if (lhs.isDegenerate() && rhs.isDegenerate() && rhs.a != f.zero()) {
602 return Interval::Degenerate(f, lhs.a % rhs.a);
603 }
604
605 if (rhs.isDegenerate()) {
606 if (rhs.a == f.zero()) {
607 return Interval::Entire(f);
608 }
609 return UnreducedInterval(f.zero(), rhs.a - f.one()).reduce(f);
610 }
611
612 // For any interval modulus, the result is bounded by the largest value of
613 // the interval.
614 // Since TypeF wraps around, the interval is just Entire since the max value
615 // would be the prime field's max value.
616 if (rhs.isTypeF() || rhs.isEntire()) {
617 return Interval::Entire(f);
618 }
619 // Any possible division by zero also yields Entire
620 Interval zeroInt = Interval::Degenerate(f, f.zero());
621 if (rhs.intersect(zeroInt) == zeroInt) {
622 return Interval::Entire(f);
623 }
624
625 return UnreducedInterval(f.zero(), rhs.b - f.one()).reduce(f);
626}
627
629 const Field &f = checkFields(lhs, rhs);
630 if (lhs.isEmpty() || rhs.isEmpty()) {
631 return Interval::Empty(f);
632 }
633 if (lhs.isDegenerate() && rhs.isDegenerate()) {
634 return Interval::Degenerate(f, lhs.a & rhs.a);
635 } else if (lhs.isDegenerate()) {
636 return UnreducedInterval(f.zero(), lhs.a).reduce(f);
637 } else if (rhs.isDegenerate()) {
638 return UnreducedInterval(f.zero(), rhs.a).reduce(f);
639 }
640 return Interval::Entire(f);
641}
642
644 const Field &f = checkFields(lhs, rhs);
645 if (lhs.isEmpty() || rhs.isEmpty()) {
646 return Interval::Empty(f);
647 }
648 auto zeroInterval = Interval::Degenerate(f, f.zero());
649 if (lhs == zeroInterval) {
650 return rhs;
651 }
652 if (rhs == zeroInterval) {
653 return lhs;
654 }
655 if (lhs.isDegenerate() && rhs.isDegenerate()) {
656 return Interval::Degenerate(f, f.reduce(lhs.a | rhs.a));
657 }
658 return Interval::Entire(f);
659}
660
662 const Field &f = checkFields(lhs, rhs);
663 if (lhs.isEmpty() || rhs.isEmpty()) {
664 return Interval::Empty(f);
665 }
666 auto zeroInterval = Interval::Degenerate(f, f.zero());
667 if (lhs == zeroInterval) {
668 return rhs;
669 }
670 if (rhs == zeroInterval) {
671 return lhs;
672 }
673 if (lhs.isDegenerate() && rhs.isDegenerate()) {
674 return Interval::Degenerate(f, f.reduce(lhs.a ^ rhs.a));
675 }
676 return Interval::Entire(f);
677}
678
680 const Field &f = checkFields(lhs, rhs);
681 if (lhs.isEmpty() || rhs.isEmpty()) {
682 return Interval::Empty(f);
683 }
684 if (lhs.isDegenerate() && rhs.isDegenerate()) {
685 if (rhs.a > f.bitWidth()) {
686 return Interval::Entire(f);
687 }
688
689 DynamicAPInt v = lhs.a << rhs.a;
690 return UnreducedInterval(v, v).reduce(f);
691 }
692 return Interval::Entire(f);
693}
694
696 const Field &f = checkFields(lhs, rhs);
697 if (lhs.isEmpty() || rhs.isEmpty()) {
698 return Interval::Empty(f);
699 }
700 if (lhs.isDegenerate() && rhs.isDegenerate()) {
701 if (rhs.a > f.bitWidth()) {
702 return Interval::Degenerate(f, f.zero());
703 }
704
705 return Interval::Degenerate(f, lhs.a >> rhs.a);
706 }
707 return Interval::Entire(f);
708}
709
710DynamicAPInt Interval::width() const {
711 switch (ty) {
712 case Type::Empty:
713 return field.get().zero();
714 case Type::Degenerate:
715 return field.get().one();
716 case Type::Entire:
717 return field.get().prime();
718 default:
719 return field.get().reduce(toUnreduced().width());
720 }
721}
722
724 ensure(
725 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
726 );
727 ensure(lhs.isBoolean() && rhs.isBoolean(), "operation only supported for boolean-type intervals");
728 const auto &field = rhs.getField();
729
730 if (lhs.isBoolFalse() || rhs.isBoolFalse()) {
731 return Interval::False(field);
732 }
733 if (lhs.isBoolTrue() && rhs.isBoolTrue()) {
734 return Interval::True(field);
735 }
736
737 return Interval::Boolean(field);
738}
739
741 ensure(
742 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
743 );
744 ensure(lhs.isBoolean() && rhs.isBoolean(), "operation only supported for boolean-type intervals");
745 const auto &field = rhs.getField();
746
747 if (lhs.isBoolFalse() && rhs.isBoolFalse()) {
748 return Interval::False(field);
749 }
750 if (lhs.isBoolTrue() || rhs.isBoolTrue()) {
751 return Interval::True(field);
752 }
753
754 return Interval::Boolean(field);
755}
756
758 ensure(
759 lhs.getField() == rhs.getField(), "interval operations across differing fields is unsupported"
760 );
761 ensure(lhs.isBoolean() && rhs.isBoolean(), "operation only supported for boolean-type intervals");
762 const auto &field = rhs.getField();
763
764 // Xor-ing anything with [0, 1] could still result in either case, so just return
765 // the full boolean range.
766 if (lhs.isBoolEither() || rhs.isBoolEither()) {
767 return Interval::Boolean(lhs.getField());
768 }
769
770 if (lhs.isBoolTrue() && rhs.isBoolTrue()) {
771 return Interval::False(field);
772 }
773 if (lhs.isBoolTrue() || rhs.isBoolTrue()) {
774 return Interval::True(field);
775 }
776 if (lhs.isBoolFalse() && rhs.isBoolFalse()) {
777 return Interval::False(field);
778 }
779
780 return Interval::Boolean(field);
781}
782
784 ensure(iv.isBoolean(), "operation only supported for boolean-type intervals");
785 const auto &field = iv.getField();
786
787 if (iv.isBoolTrue()) {
788 return Interval::False(field);
789 }
790 if (iv.isBoolFalse()) {
791 return Interval::True(field);
792 }
793
794 return iv;
795}
796
797void Interval::print(mlir::raw_ostream &os) const {
798 os << TypeName(ty);
799 if (is<Type::Degenerate>()) {
800 os << '(' << a << ')';
801 } else if (!is<Type::Entire, Type::Empty>()) {
802 os << ":[ " << a << ", " << b << " ]";
803 }
804}
805
806} // namespace llzk
This file implements helper methods for constructing DynamicAPInts.
Information about the prime finite field used for the interval analysis.
Definition Field.h:36
llvm::DynamicAPInt half() const
Returns p / 2.
Definition Field.h:75
llvm::DynamicAPInt zero() const
Returns 0 at the bitwidth of the field.
Definition Field.h:81
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
Definition Field.h:72
llvm::DynamicAPInt reduce(const llvm::DynamicAPInt &i) const
Returns i mod p and reduces the result into the appropriate bitwidth.
llvm::DynamicAPInt one() const
Returns 1 at the bitwidth of the field.
Definition Field.h:84
llvm::DynamicAPInt inv(const llvm::DynamicAPInt &i) const
Returns the multiplicative inverse of i in prime field p.
unsigned bitWidth() const
Definition Field.h:107
static const Field & getField(llvm::StringRef fieldName, EmitErrorFn errFn)
Get a Field from a given field name string.
llvm::DynamicAPInt maxVal() const
Returns p - 1, which is the max value possible in a prime field described by p.
Definition Field.h:87
Intervals over a finite field.
Definition Intervals.h:206
bool isEmpty() const
Definition Intervals.h:314
static Interval True(const Field &f)
Definition Intervals.h:225
llvm::DynamicAPInt rhs() const
Definition Intervals.h:339
Interval intersect(const Interval &rhs) const
Intersect.
bool isBoolean() const
Definition Intervals.h:326
static std::string_view TypeName(Type t)
Definition Intervals.h:213
void print(llvm::raw_ostream &os) const
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
Definition Intervals.h:227
bool isBoolFalse() const
Definition Intervals.h:323
UnreducedInterval firstUnreduced() const
Get the first side of the interval for TypeF intervals, otherwise just get the full interval as an UnreducedInterval.
bool isDegenerate() const
Definition Intervals.h:316
const Field & getField() const
Definition Intervals.h:334
bool isBoolTrue() const
Definition Intervals.h:324
bool is() const
Definition Intervals.h:328
UnreducedInterval secondUnreduced() const
Get the second side of the interval for TypeA, TypeB, and TypeC intervals.
bool isTypeF() const
Definition Intervals.h:321
static Interval False(const Field &f)
Definition Intervals.h:223
static Interval Empty(const Field &f)
Definition Intervals.h:217
Interval operator~() const
llvm::DynamicAPInt lhs() const
Definition Intervals.h:338
static bool areOneOf(const Interval &a, const Interval &b)
Definition Intervals.h:263
Interval()
To satisfy the dataflow::ScalarLatticeValue requirements, this class must be default initializable.
Definition Intervals.h:249
static Interval Degenerate(const Field &f, const llvm::DynamicAPInt &val)
Definition Intervals.h:219
llvm::DynamicAPInt width() const
bool isEntire() const
Definition Intervals.h:317
Interval difference(const Interval &other) const
Computes and returns this - (this & other) if the operation produces a single interval.
Interval operator-() const
Interval join(const Interval &rhs) const
Union.
An inclusive interval [a, b] where a and b are arbitrary integers not necessarily bound to a given field.
Definition Intervals.h:26
UnreducedInterval operator-() const
Definition Intervals.cpp:93
UnreducedInterval intersect(const UnreducedInterval &rhs) const
Compute and return the intersection of this interval and the given RHS.
Definition Intervals.cpp:53
UnreducedInterval(const llvm::DynamicAPInt &x, const llvm::DynamicAPInt &y)
Definition Intervals.h:28
bool isEmpty() const
Returns true iff width() is zero.
Definition Intervals.h:120
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
Definition Intervals.cpp:63
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
Definition Intervals.cpp:86
bool isNotEmpty() const
Definition Intervals.h:122
bool overlaps(const UnreducedInterval &rhs) const
llvm::DynamicAPInt width() const
Compute the width of this interval within a given field f.
UnreducedInterval doUnion(const UnreducedInterval &rhs) const
Compute and return the union of this interval and the given RHS.
Definition Intervals.cpp:58
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Definition Intervals.cpp:78
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
Definition Intervals.cpp:23
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
Definition Intervals.cpp:71
ExpressionValue boolNot(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
ExpressionValue intersection(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
FailureOr< Interval > signedIntDiv(const Interval &lhs, const Interval &rhs)
Computes signed integer division with possibly non-Degenerate divisors.
Interval operator|(const Interval &lhs, const Interval &rhs)
Interval operator^(const Interval &lhs, const Interval &rhs)
Interval signedMod(const Interval &lhs, const Interval &rhs)
Computes signed integer remainder with possibly non-Degenerate divisors.
void ensure(bool condition, const llvm::Twine &errMsg)
ExpressionValue boolXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Interval operator%(const Interval &lhs, const Interval &rhs)
Interval operator<<(const Interval &lhs, const Interval &rhs)
ExpressionValue boolAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
FailureOr< Interval > unsignedIntDiv(const Interval &lhs, const Interval &rhs)
Computes unsigned integer division with possibly non-Degenerate divisors.
std::strong_ordering operator<=>(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
UnreducedInterval operator-(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
Interval operator>>(const Interval &lhs, const Interval &rhs)
Interval operator&(const Interval &lhs, const Interval &rhs)
UnreducedInterval operator*(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
const Field & checkFields(const Interval &lhs, const Interval &rhs)
UnreducedInterval operator+(const UnreducedInterval &lhs, const UnreducedInterval &rhs)
FailureOr< Interval > feltDiv(const Interval &lhs, const Interval &rhs)
Computes finite-field division by multiplying the dividend by the multiplicative inverse of the divisor.
ExpressionValue boolOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)