LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Builders.h
Go to the documentation of this file.
1//===-- Builders.h ----------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
14
15#include <mlir/IR/Builders.h>
16#include <mlir/IR/MLIRContext.h>
17
18#include <memory>
19#include <unordered_map>
20
21namespace llzk {
22
23inline mlir::Location getUnknownLoc(mlir::MLIRContext *context) {
24 return mlir::UnknownLoc::get(context);
25}
26
27mlir::OwningOpRef<mlir::ModuleOp> createLLZKModule(mlir::MLIRContext *context, mlir::Location loc);
28
29inline mlir::OwningOpRef<mlir::ModuleOp> createLLZKModule(mlir::MLIRContext *context) {
30 return createLLZKModule(context, getUnknownLoc(context));
31}
32
33void addLangAttrForLLZKDialect(mlir::ModuleOp mod);
34
36protected:
37 mlir::MLIRContext *context;
38
39public:
40 BaseBuilder(mlir::MLIRContext *ctx) : context(ctx) {}
41
42 inline mlir::Location getUnknownLoc() { return llzk::getUnknownLoc(context); }
43};
44
45template <typename Derived> class ModuleLikeBuilder : public BaseBuilder {
46private:
47 friend Derived;
48
49 ModuleLikeBuilder(mlir::MLIRContext *ctx) : BaseBuilder(ctx) {}
50
51protected:
52 // keyed on function name
53 std::unordered_map<std::string_view, function::FuncDefOp> freeFuncMap;
54 // keyed on struct name
55 std::unordered_map<std::string_view, component::StructDefOp> structMap;
56 // keyed on struct name
57 std::unordered_map<std::string_view, function::FuncDefOp> computeFnMap;
58 // keyed on struct name
59 std::unordered_map<std::string_view, function::FuncDefOp> constrainFnMap;
60 // keyed on struct name
61 std::unordered_map<std::string_view, function::FuncDefOp> productFnMap;
62
66 void ensureNoSuchFreeFunc(std::string_view funcName);
67
71 void ensureFreeFnExists(std::string_view funcName);
72
76 void ensureNoSuchStruct(std::string_view structName);
77
81 void ensureStructExists(std::string_view structName);
82
86 void ensureNoSuchComputeFn(std::string_view structName);
87
91 void ensureComputeFnExists(std::string_view structName);
92
96 void ensureNoSuchConstrainFn(std::string_view structName);
97
101 void ensureConstrainFnExists(std::string_view structName);
102
106 void ensureNoSuchProductFn(std::string_view structName);
107
111 void ensureProductFnExists(std::string_view structName);
112
113public:
114 /* Getter methods */
115
116 inline mlir::Region &getBodyRegion() { return static_cast<Derived *>(this)->getBodyRegion(); }
117
118 mlir::FailureOr<component::StructDefOp> getStruct(std::string_view structName) const {
119 if (structMap.find(structName) != structMap.end()) {
120 return structMap.at(structName);
121 }
122 return mlir::failure();
123 }
124
125 mlir::FailureOr<function::FuncDefOp> getComputeFn(std::string_view structName) const {
126 if (computeFnMap.find(structName) != computeFnMap.end()) {
127 return computeFnMap.at(structName);
128 }
129 return mlir::failure();
130 }
131 inline mlir::FailureOr<function::FuncDefOp> getComputeFn(component::StructDefOp op) const {
132 return getComputeFn(op.getName());
133 }
134
135 mlir::FailureOr<function::FuncDefOp> getConstrainFn(std::string_view structName) const {
136 if (constrainFnMap.find(structName) != constrainFnMap.end()) {
137 return constrainFnMap.at(structName);
138 }
139 return mlir::failure();
140 }
141 inline mlir::FailureOr<function::FuncDefOp> getConstrainFn(component::StructDefOp op) const {
142 return getConstrainFn(op.getName());
143 }
144
145 mlir::FailureOr<function::FuncDefOp> getProductFn(std::string_view structName) const {
146 if (productFnMap.find(structName) != productFnMap.end()) {
147 return productFnMap.at(structName);
148 }
149 return mlir::failure();
150 }
151 inline mlir::FailureOr<function::FuncDefOp> getProductFn(component::StructDefOp op) const {
152 return getProductFn(op.getName());
153 }
154
155 mlir::FailureOr<function::FuncDefOp> getFreeFunc(std::string_view funcName) const {
156 if (freeFuncMap.find(funcName) != freeFuncMap.end()) {
157 return freeFuncMap.at(funcName);
158 }
159 return mlir::failure();
160 }
161
162 /* Builder methods */
163
164 Derived &insertEmptyStruct(std::string_view structName, mlir::Location loc);
165 inline Derived &insertEmptyStruct(std::string_view structName) {
166 return insertEmptyStruct(structName, getUnknownLoc());
167 }
168
170 std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc
171 ) {
172 insertEmptyStruct(structName, structLoc);
173 insertComputeFn(structName, computeLoc);
174 return static_cast<Derived &>(*this);
175 }
176
177 Derived &insertComputeOnlyStruct(std::string_view structName) {
178 auto unk = getUnknownLoc();
179 return insertComputeOnlyStruct(structName, unk, unk);
180 }
181
183 std::string_view structName, mlir::Location structLoc, mlir::Location constrainLoc
184 ) {
185 insertEmptyStruct(structName, structLoc);
186 insertConstrainFn(structName, constrainLoc);
187 return static_cast<Derived &>(*this);
188 }
189
190 Derived &insertConstrainOnlyStruct(std::string_view structName) {
191 auto unk = getUnknownLoc();
192 return insertConstrainOnlyStruct(structName, unk, unk);
193 }
194
196 std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc,
197 mlir::Location constrainLoc
198 ) {
199 insertEmptyStruct(structName, structLoc);
200 insertComputeFn(structName, computeLoc);
201 insertConstrainFn(structName, constrainLoc);
202 return static_cast<Derived &>(*this);
203 }
204
206 Derived &insertFullStruct(std::string_view structName) {
207 auto unk = getUnknownLoc();
208 return insertFullStruct(structName, unk, unk, unk);
209 }
210
212 std::string_view structName, mlir::Location structLoc, mlir::Location productLoc
213 ) {
214 insertEmptyStruct(structName, structLoc);
215 insertProductFn(structName, productLoc);
216 return static_cast<Derived &>(*this);
217 }
218
219 Derived &insertProductStruct(std::string_view structName) {
220 auto unk = getUnknownLoc();
221 return insertProductStruct(structName, unk, unk);
222 }
223
228 static function::FuncDefOp buildComputeFn(component::StructDefOp op, mlir::Location loc);
229 Derived &insertComputeFn(component::StructDefOp op, mlir::Location loc);
230 Derived &insertComputeFn(std::string_view structName, mlir::Location loc);
231 inline Derived &insertComputeFn(std::string_view structName) {
232 return insertComputeFn(structName, getUnknownLoc());
233 }
234
238 static function::FuncDefOp buildConstrainFn(component::StructDefOp op, mlir::Location loc);
239 Derived &insertConstrainFn(component::StructDefOp op, mlir::Location loc);
240 Derived &insertConstrainFn(std::string_view structName, mlir::Location loc);
241 inline Derived &insertConstrainFn(std::string_view structName) {
242 return insertConstrainFn(structName, getUnknownLoc());
243 }
244
249 static function::FuncDefOp buildProductFn(component::StructDefOp op, mlir::Location loc);
250 Derived &insertProductFn(component::StructDefOp op, mlir::Location loc);
251 Derived &insertProductFn(std::string_view structName, mlir::Location loc);
252 inline Derived &insertProductFn(std::string_view structName) {
253 return insertProductFn(structName, getUnknownLoc());
254 }
255
262 component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc
263 );
264 Derived &
265 insertComputeCall(std::string_view caller, std::string_view callee, mlir::Location callLoc);
266 Derived &insertComputeCall(std::string_view caller, std::string_view callee) {
267 return insertComputeCall(caller, callee, getUnknownLoc());
268 }
269
277 component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc,
278 mlir::Location memberDefLoc
279 );
281 std::string_view caller, std::string_view callee, mlir::Location callLoc,
282 mlir::Location memberDefLoc
283 );
284 Derived &insertConstrainCall(std::string_view caller, std::string_view callee) {
285 return insertConstrainCall(caller, callee, getUnknownLoc(), getUnknownLoc());
286 }
287
288 Derived &insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc);
289 inline Derived &insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type) {
290 return insertFreeFunc(funcName, type, getUnknownLoc());
291 }
292
293 Derived &
294 insertFreeCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc);
295 Derived &insertFreeCall(function::FuncDefOp caller, std::string_view callee) {
296 return insertFreeCall(caller, callee, getUnknownLoc());
297 }
298};
299
303class TemplateBuilder : public ModuleLikeBuilder<TemplateBuilder> {
304 polymorphic::TemplateOp myTemplate;
305
306public:
307 TemplateBuilder(polymorphic::TemplateOp t) : ModuleLikeBuilder(t.getContext()), myTemplate(t) {}
308
309 /* Getter methods */
310
311 mlir::Region &getBodyRegion() { return myTemplate.getBodyRegion(); }
312
314 polymorphic::TemplateOp &getTemplate() { return myTemplate; }
315
316 mlir::FailureOr<polymorphic::TemplateParamOp> getParam(std::string_view name) {
317 auto op = myTemplate.getConstNamed<polymorphic::TemplateParamOp>(name);
318 if (op) {
319 return op;
320 }
321 return mlir::failure();
322 }
323
324 mlir::FailureOr<polymorphic::TemplateExprOp> getExpr(std::string_view name) {
325 auto op = myTemplate.getConstNamed<polymorphic::TemplateExprOp>(name);
326 if (op) {
327 return op;
328 }
329 return mlir::failure();
330 }
331
333 insertParam(std::string_view name, mlir::Location loc, mlir::TypeAttr type = {}) {
334 if (succeeded(getParam(name))) {
335 llvm::report_fatal_error("Duplicate TemplateParamOp insertion attempted");
336 }
337
338 mlir::OpBuilder builder(context);
339
340 auto &region = getBodyRegion();
341 if (region.empty()) {
342 region.emplaceBlock();
343 }
344
345 builder.setInsertionPointToEnd(&region.front());
346
347 auto nameAttr = builder.getStringAttr(name);
348
349 builder.create<polymorphic::TemplateParamOp>(loc, nameAttr, type);
350
351 return *this;
352 }
353
354 inline TemplateBuilder &insertParam(std::string_view name) {
355 return insertParam(name, getUnknownLoc());
356 }
357
358 TemplateBuilder &insertExpr(std::string_view name, mlir::Location loc) {
359 if (succeeded(getExpr(name))) {
360 llvm::report_fatal_error("Duplicate TemplateExprOp insertion attempted");
361 }
362
363 mlir::OpBuilder builder(context);
364
365 auto &region = getBodyRegion();
366 if (region.empty()) {
367 region.emplaceBlock();
368 }
369
370 builder.setInsertionPointToEnd(&region.front());
371
372 auto nameAttr = builder.getStringAttr(name);
373
374 builder.create<polymorphic::TemplateExprOp>(loc, nameAttr);
375
376 return *this;
377 }
378
379 inline TemplateBuilder &insertExpr(std::string_view name) {
380 return insertExpr(name, getUnknownLoc());
381 }
382};
383
387class ModuleBuilder : public ModuleLikeBuilder<ModuleBuilder> {
388 mlir::ModuleOp myModule;
389
390 // keyed on template name
391 std::unordered_map<std::string_view, std::unique_ptr<TemplateBuilder>> templateMap;
392 // keyed on nested module name
393 std::unordered_map<std::string_view, std::unique_ptr<ModuleBuilder>> nestedModuleMap;
394
398 void ensureNoSuchTemplate(std::string_view templateName);
399
403 void ensureTemplateExists(std::string_view templateName);
404
408 void ensureNoSuchNestedModule(std::string_view moduleName);
409
413 void ensureNestedModuleExists(std::string_view moduleName);
414
415public:
416 ModuleBuilder(mlir::ModuleOp m) : ModuleLikeBuilder(m.getContext()), myModule(m) {}
417
418 /* Getter methods */
419
420 mlir::Region &getBodyRegion() { return myModule.getBodyRegion(); }
421
423 mlir::ModuleOp &getModule() { return myModule; }
424
425 mlir::FailureOr<TemplateBuilder *> getTemplate(std::string_view templateName) const {
426 auto it = templateMap.find(templateName);
427 if (it != templateMap.end()) {
428 return it->second.get();
429 }
430 return mlir::failure();
431 }
432
433 mlir::FailureOr<ModuleBuilder *> getNestedModule(std::string_view moduleName) const {
434 auto it = nestedModuleMap.find(moduleName);
435 if (it != nestedModuleMap.end()) {
436 return it->second.get();
437 }
438 return mlir::failure();
439 }
440
441 /* Builder methods */
442
444 insertTemplate(std::string_view templateName, mlir::Location loc, unsigned numParams = 0);
445 inline ModuleBuilder &insertTemplate(std::string_view templateName, unsigned numParams = 0) {
446 return insertTemplate(templateName, getUnknownLoc(), numParams);
447 }
448
449 ModuleBuilder &insertNestedModule(std::string_view moduleName, mlir::Location loc);
450 inline ModuleBuilder &insertNestedModule(std::string_view moduleName) {
451 return insertNestedModule(moduleName, getUnknownLoc());
452 }
453};
454
455} // namespace llzk
mlir::Location getUnknownLoc()
Definition Builders.h:42
mlir::MLIRContext * context
Definition Builders.h:37
BaseBuilder(mlir::MLIRContext *ctx)
Definition Builders.h:40
Builds out an LLZK-compliant module and provides utilities for populating that module.
Definition Builders.h:387
mlir::Region & getBodyRegion()
Definition Builders.h:420
ModuleBuilder(mlir::ModuleOp m)
Definition Builders.h:416
mlir::FailureOr< TemplateBuilder * > getTemplate(std::string_view templateName) const
Definition Builders.h:425
ModuleBuilder & insertNestedModule(std::string_view moduleName)
Definition Builders.h:450
ModuleBuilder & insertTemplate(std::string_view templateName, mlir::Location loc, unsigned numParams=0)
mlir::FailureOr< ModuleBuilder * > getNestedModule(std::string_view moduleName) const
Definition Builders.h:433
ModuleBuilder & insertTemplate(std::string_view templateName, unsigned numParams=0)
Definition Builders.h:445
ModuleBuilder & insertNestedModule(std::string_view moduleName, mlir::Location loc)
mlir::ModuleOp & getModule()
Get the associated module of this builder.
Definition Builders.h:423
mlir::Region & getBodyRegion()
Definition Builders.h:116
Derived & insertComputeFn(component::StructDefOp op, mlir::Location loc)
Derived & insertConstrainCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc, mlir::Location memberDefLoc)
To call a constraint function, you must:
std::unordered_map< std::string_view, function::FuncDefOp > productFnMap
Definition Builders.h:61
Derived & insertComputeFn(std::string_view structName, mlir::Location loc)
Derived & insertComputeFn(std::string_view structName)
Definition Builders.h:231
Derived & insertComputeOnlyStruct(std::string_view structName)
Definition Builders.h:177
Derived & insertProductStruct(std::string_view structName)
Definition Builders.h:219
mlir::FailureOr< function::FuncDefOp > getProductFn(component::StructDefOp op) const
Definition Builders.h:151
Derived & insertConstrainCall(std::string_view caller, std::string_view callee)
Definition Builders.h:284
Derived & insertProductFn(std::string_view structName, mlir::Location loc)
Derived & insertComputeCall(std::string_view caller, std::string_view callee)
Definition Builders.h:266
void ensureNoSuchFreeFunc(std::string_view funcName)
Ensure that a global function with the given funcName has not been added, reporting a fatal error otherwise.
Definition Builders.cpp:41
mlir::FailureOr< function::FuncDefOp > getProductFn(std::string_view structName) const
Definition Builders.h:145
Derived & insertProductStruct(std::string_view structName, mlir::Location structLoc, mlir::Location productLoc)
Definition Builders.h:211
mlir::FailureOr< function::FuncDefOp > getFreeFunc(std::string_view funcName) const
Definition Builders.h:155
mlir::FailureOr< function::FuncDefOp > getConstrainFn(component::StructDefOp op) const
Definition Builders.h:141
std::unordered_map< std::string_view, function::FuncDefOp > constrainFnMap
Definition Builders.h:59
Derived & insertFreeCall(function::FuncDefOp caller, std::string_view callee)
Definition Builders.h:295
Derived & insertEmptyStruct(std::string_view structName, mlir::Location loc)
static function::FuncDefOp buildComputeFn(component::StructDefOp op, mlir::Location loc)
compute returns the type of the struct that defines it.
Definition Builders.cpp:124
Derived & insertConstrainCall(std::string_view caller, std::string_view callee, mlir::Location callLoc, mlir::Location memberDefLoc)
Derived & insertFullStruct(std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc, mlir::Location constrainLoc)
Definition Builders.h:195
Derived & insertComputeCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc)
Only requirement for compute is the call itself.
mlir::FailureOr< function::FuncDefOp > getComputeFn(std::string_view structName) const
Definition Builders.h:125
Derived & insertConstrainOnlyStruct(std::string_view structName, mlir::Location structLoc, mlir::Location constrainLoc)
Definition Builders.h:182
mlir::FailureOr< function::FuncDefOp > getComputeFn(component::StructDefOp op) const
Definition Builders.h:131
Derived & insertConstrainFn(std::string_view structName, mlir::Location loc)
void ensureProductFnExists(std::string_view structName)
Ensure that the given struct has a product function, reporting a fatal error otherwise.
Definition Builders.cpp:104
void ensureFreeFnExists(std::string_view funcName)
Ensure that a global function with the given funcName has been added, reporting a fatal error otherwise.
Definition Builders.cpp:48
Derived & insertProductFn(component::StructDefOp op, mlir::Location loc)
Derived & insertConstrainFn(component::StructDefOp op, mlir::Location loc)
Derived & insertEmptyStruct(std::string_view structName)
Definition Builders.h:165
mlir::FailureOr< component::StructDefOp > getStruct(std::string_view structName) const
Definition Builders.h:118
void ensureNoSuchProductFn(std::string_view structName)
Ensure that the given struct does not have a product function, reporting a fatal error otherwise.
Definition Builders.cpp:97
void ensureStructExists(std::string_view structName)
Ensure that a struct with the given structName exists, reporting a fatal error otherwise.
Definition Builders.cpp:62
void ensureNoSuchConstrainFn(std::string_view structName)
Ensure that the given struct does not have a constrain function, reporting a fatal error otherwise.
Definition Builders.cpp:83
Derived & insertProductFn(std::string_view structName)
Definition Builders.h:252
Derived & insertFreeCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc)
mlir::FailureOr< function::FuncDefOp > getConstrainFn(std::string_view structName) const
Definition Builders.h:135
Derived & insertConstrainFn(std::string_view structName)
Definition Builders.h:241
std::unordered_map< std::string_view, function::FuncDefOp > freeFuncMap
Definition Builders.h:53
void ensureNoSuchComputeFn(std::string_view structName)
Ensure that the given struct does not have a compute function, reporting a fatal error otherwise.
Definition Builders.cpp:69
Derived & insertComputeCall(std::string_view caller, std::string_view callee, mlir::Location callLoc)
void ensureComputeFnExists(std::string_view structName)
Ensure that the given struct has a compute function, reporting a fatal error otherwise.
Definition Builders.cpp:76
void ensureConstrainFnExists(std::string_view structName)
Ensure that the given struct has a constrain function, reporting a fatal error otherwise.
Definition Builders.cpp:90
static function::FuncDefOp buildConstrainFn(component::StructDefOp op, mlir::Location loc)
constrain accepts the struct type as the first argument.
Definition Builders.cpp:150
std::unordered_map< std::string_view, function::FuncDefOp > computeFnMap
Definition Builders.h:57
std::unordered_map< std::string_view, component::StructDefOp > structMap
Definition Builders.h:55
Derived & insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type)
Definition Builders.h:289
Derived & insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc)
void ensureNoSuchStruct(std::string_view structName)
Ensure that a struct with the given structName has not been added, reporting a fatal error otherwise.
Definition Builders.cpp:55
Derived & insertComputeOnlyStruct(std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc)
Definition Builders.h:169
static function::FuncDefOp buildProductFn(component::StructDefOp op, mlir::Location loc)
product returns the type of the struct that defines it.
Definition Builders.cpp:176
Derived & insertConstrainOnlyStruct(std::string_view structName)
Definition Builders.h:190
Derived & insertFullStruct(std::string_view structName)
Inserts a struct with both compute and constrain functions.
Definition Builders.h:206
Builds out an LLZK-compliant template and provides utilities for populating that template.
Definition Builders.h:303
TemplateBuilder & insertParam(std::string_view name, mlir::Location loc, mlir::TypeAttr type={})
Definition Builders.h:333
TemplateBuilder & insertExpr(std::string_view name)
Definition Builders.h:379
TemplateBuilder(polymorphic::TemplateOp t)
Definition Builders.h:307
mlir::FailureOr< polymorphic::TemplateParamOp > getParam(std::string_view name)
Definition Builders.h:316
mlir::Region & getBodyRegion()
Definition Builders.h:311
TemplateBuilder & insertParam(std::string_view name)
Definition Builders.h:354
mlir::FailureOr< polymorphic::TemplateExprOp > getExpr(std::string_view name)
Definition Builders.h:324
TemplateBuilder & insertExpr(std::string_view name, mlir::Location loc)
Definition Builders.h:358
polymorphic::TemplateOp & getTemplate()
Get the associated template of this builder.
Definition Builders.h:314
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
OwningOpRef< ModuleOp > createLLZKModule(MLIRContext *, Location loc)
Definition Builders.cpp:23
void addLangAttrForLLZKDialect(ModuleOp mod)
Definition Builders.cpp:29
mlir::Location getUnknownLoc(mlir::MLIRContext *context)
Definition Builders.h:23