LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Builders.h
Go to the documentation of this file.
1//===-- Builders.h ----------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
14
15#include <mlir/IR/Builders.h>
16#include <mlir/IR/MLIRContext.h>
17
18#include <memory>
19#include <unordered_map>
20
21namespace llzk {
22
23inline mlir::Location getUnknownLoc(mlir::MLIRContext *context) {
24 return mlir::UnknownLoc::get(context);
25}
26
27mlir::OwningOpRef<mlir::ModuleOp> createLLZKModule(mlir::MLIRContext *context, mlir::Location loc);
28
29inline mlir::OwningOpRef<mlir::ModuleOp> createLLZKModule(mlir::MLIRContext *context) {
30 return createLLZKModule(context, getUnknownLoc(context));
31}
32
33void addLangAttrForLLZKDialect(mlir::ModuleOp mod);
34
36protected:
37 mlir::MLIRContext *context;
38
39public:
40 BaseBuilder(mlir::MLIRContext *ctx) : context(ctx) {}
41
42 inline mlir::Location getUnknownLoc() { return llzk::getUnknownLoc(context); }
43};
44
45template <typename Derived> class ModuleLikeBuilder : public BaseBuilder {
46protected:
47 // keyed on function name
48 std::unordered_map<std::string_view, function::FuncDefOp> freeFuncMap;
49 // keyed on struct name
50 std::unordered_map<std::string_view, component::StructDefOp> structMap;
51 // keyed on struct name
52 std::unordered_map<std::string_view, function::FuncDefOp> computeFnMap;
53 // keyed on struct name
54 std::unordered_map<std::string_view, function::FuncDefOp> constrainFnMap;
55 // keyed on struct name
56 std::unordered_map<std::string_view, function::FuncDefOp> productFnMap;
57
61 void ensureNoSuchFreeFunc(std::string_view funcName);
62
66 void ensureFreeFnExists(std::string_view funcName);
67
71 void ensureNoSuchStruct(std::string_view structName);
72
76 void ensureStructExists(std::string_view structName);
77
81 void ensureNoSuchComputeFn(std::string_view structName);
82
86 void ensureComputeFnExists(std::string_view structName);
87
91 void ensureNoSuchConstrainFn(std::string_view structName);
92
96 void ensureConstrainFnExists(std::string_view structName);
97
101 void ensureNoSuchProductFn(std::string_view structName);
102
106 void ensureProductFnExists(std::string_view structName);
107
108public:
109 ModuleLikeBuilder(mlir::MLIRContext *ctx) : BaseBuilder(ctx) {}
110
111 /* Getter methods */
112
113 inline mlir::Region &getBodyRegion() { return static_cast<Derived *>(this)->getBodyRegion(); }
114
115 mlir::FailureOr<component::StructDefOp> getStruct(std::string_view structName) const {
116 if (structMap.find(structName) != structMap.end()) {
117 return structMap.at(structName);
118 }
119 return mlir::failure();
120 }
121
122 mlir::FailureOr<function::FuncDefOp> getComputeFn(std::string_view structName) const {
123 if (computeFnMap.find(structName) != computeFnMap.end()) {
124 return computeFnMap.at(structName);
125 }
126 return mlir::failure();
127 }
128 inline mlir::FailureOr<function::FuncDefOp> getComputeFn(component::StructDefOp op) const {
129 return getComputeFn(op.getName());
130 }
131
132 mlir::FailureOr<function::FuncDefOp> getConstrainFn(std::string_view structName) const {
133 if (constrainFnMap.find(structName) != constrainFnMap.end()) {
134 return constrainFnMap.at(structName);
135 }
136 return mlir::failure();
137 }
138 inline mlir::FailureOr<function::FuncDefOp> getConstrainFn(component::StructDefOp op) const {
139 return getConstrainFn(op.getName());
140 }
141
142 mlir::FailureOr<function::FuncDefOp> getProductFn(std::string_view structName) const {
143 if (productFnMap.find(structName) != productFnMap.end()) {
144 return productFnMap.at(structName);
145 }
146 return mlir::failure();
147 }
148 inline mlir::FailureOr<function::FuncDefOp> getProductFn(component::StructDefOp op) const {
149 return getProductFn(op.getName());
150 }
151
152 mlir::FailureOr<function::FuncDefOp> getFreeFunc(std::string_view funcName) const {
153 if (freeFuncMap.find(funcName) != freeFuncMap.end()) {
154 return freeFuncMap.at(funcName);
155 }
156 return mlir::failure();
157 }
158
159 /* Builder methods */
160
161 Derived &insertEmptyStruct(std::string_view structName, mlir::Location loc);
162 inline Derived &insertEmptyStruct(std::string_view structName) {
163 return insertEmptyStruct(structName, getUnknownLoc());
164 }
165
167 std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc
168 ) {
169 insertEmptyStruct(structName, structLoc);
170 insertComputeFn(structName, computeLoc);
171 return static_cast<Derived &>(*this);
172 }
173
174 Derived &insertComputeOnlyStruct(std::string_view structName) {
175 auto unk = getUnknownLoc();
176 return insertComputeOnlyStruct(structName, unk, unk);
177 }
178
180 std::string_view structName, mlir::Location structLoc, mlir::Location constrainLoc
181 ) {
182 insertEmptyStruct(structName, structLoc);
183 insertConstrainFn(structName, constrainLoc);
184 return static_cast<Derived &>(*this);
185 }
186
187 Derived &insertConstrainOnlyStruct(std::string_view structName) {
188 auto unk = getUnknownLoc();
189 return insertConstrainOnlyStruct(structName, unk, unk);
190 }
191
193 std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc,
194 mlir::Location constrainLoc
195 ) {
196 insertEmptyStruct(structName, structLoc);
197 insertComputeFn(structName, computeLoc);
198 insertConstrainFn(structName, constrainLoc);
199 return static_cast<Derived &>(*this);
200 }
201
203 Derived &insertFullStruct(std::string_view structName) {
204 auto unk = getUnknownLoc();
205 return insertFullStruct(structName, unk, unk, unk);
206 }
207
209 std::string_view structName, mlir::Location structLoc, mlir::Location productLoc
210 ) {
211 insertEmptyStruct(structName, structLoc);
212 insertProductFn(structName, productLoc);
213 return static_cast<Derived &>(*this);
214 }
215
216 Derived &insertProductStruct(std::string_view structName) {
217 auto unk = getUnknownLoc();
218 return insertProductStruct(structName, unk, unk);
219 }
220
225 static function::FuncDefOp buildComputeFn(component::StructDefOp op, mlir::Location loc);
226 Derived &insertComputeFn(component::StructDefOp op, mlir::Location loc);
227 Derived &insertComputeFn(std::string_view structName, mlir::Location loc);
228 inline Derived &insertComputeFn(std::string_view structName) {
229 return insertComputeFn(structName, getUnknownLoc());
230 }
231
235 static function::FuncDefOp buildConstrainFn(component::StructDefOp op, mlir::Location loc);
236 Derived &insertConstrainFn(component::StructDefOp op, mlir::Location loc);
237 Derived &insertConstrainFn(std::string_view structName, mlir::Location loc);
238 inline Derived &insertConstrainFn(std::string_view structName) {
239 return insertConstrainFn(structName, getUnknownLoc());
240 }
241
246 static function::FuncDefOp buildProductFn(component::StructDefOp op, mlir::Location loc);
247 Derived &insertProductFn(component::StructDefOp op, mlir::Location loc);
248 Derived &insertProductFn(std::string_view structName, mlir::Location loc);
249 inline Derived &insertProductFn(std::string_view structName) {
250 return insertProductFn(structName, getUnknownLoc());
251 }
252
259 component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc
260 );
261 Derived &
262 insertComputeCall(std::string_view caller, std::string_view callee, mlir::Location callLoc);
263 Derived &insertComputeCall(std::string_view caller, std::string_view callee) {
264 return insertComputeCall(caller, callee, getUnknownLoc());
265 }
266
274 component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc,
275 mlir::Location memberDefLoc
276 );
278 std::string_view caller, std::string_view callee, mlir::Location callLoc,
279 mlir::Location memberDefLoc
280 );
281 Derived &insertConstrainCall(std::string_view caller, std::string_view callee) {
282 return insertConstrainCall(caller, callee, getUnknownLoc(), getUnknownLoc());
283 }
284
285 Derived &insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc);
286 inline Derived &insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type) {
287 return insertFreeFunc(funcName, type, getUnknownLoc());
288 }
289
290 Derived &
291 insertFreeCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc);
292 Derived &insertFreeCall(function::FuncDefOp caller, std::string_view callee) {
293 return insertFreeCall(caller, callee, getUnknownLoc());
294 }
295};
296
300class TemplateBuilder : public ModuleLikeBuilder<TemplateBuilder> {
301 polymorphic::TemplateOp myTemplate;
302
303public:
304 TemplateBuilder(polymorphic::TemplateOp t) : ModuleLikeBuilder(t.getContext()), myTemplate(t) {}
305
306 /* Getter methods */
307
308 mlir::Region &getBodyRegion() { return myTemplate.getBodyRegion(); }
309
311 polymorphic::TemplateOp &getTemplate() { return myTemplate; }
312
313 // TODO: other getters for template-specific ops like param/expr
314
315 /* Builder methods */
316
317 // TODO: other builders for template-specific ops like param/expr
318};
319
323class ModuleBuilder : public ModuleLikeBuilder<ModuleBuilder> {
324 mlir::ModuleOp myModule;
325
326 // keyed on template name
327 std::unordered_map<std::string_view, std::unique_ptr<TemplateBuilder>> templateMap;
328 // keyed on nested module name
329 std::unordered_map<std::string_view, std::unique_ptr<ModuleBuilder>> nestedModuleMap;
330
334 void ensureNoSuchTemplate(std::string_view templateName);
335
339 void ensureTemplateExists(std::string_view templateName);
340
344 void ensureNoSuchNestedModule(std::string_view moduleName);
345
349 void ensureNestedModuleExists(std::string_view moduleName);
350
351public:
352 ModuleBuilder(mlir::ModuleOp m) : ModuleLikeBuilder(m.getContext()), myModule(m) {}
353
354 /* Getter methods */
355
356 mlir::Region &getBodyRegion() { return myModule.getBodyRegion(); }
357
359 mlir::ModuleOp &getModule() { return myModule; }
360
361 mlir::FailureOr<TemplateBuilder *> getTemplate(std::string_view templateName) const {
362 auto it = templateMap.find(templateName);
363 if (it != templateMap.end()) {
364 return it->second.get();
365 }
366 return mlir::failure();
367 }
368
369 mlir::FailureOr<ModuleBuilder *> getNestedModule(std::string_view moduleName) const {
370 auto it = nestedModuleMap.find(moduleName);
371 if (it != nestedModuleMap.end()) {
372 return it->second.get();
373 }
374 return mlir::failure();
375 }
376
377 /* Builder methods */
378
380 insertTemplate(std::string_view templateName, mlir::Location loc, unsigned numParams = 0);
381 inline ModuleBuilder &insertTemplate(std::string_view templateName, unsigned numParams = 0) {
382 return insertTemplate(templateName, getUnknownLoc(), numParams);
383 }
384
385 ModuleBuilder &insertNestedModule(std::string_view moduleName, mlir::Location loc);
386 inline ModuleBuilder &insertNestedModule(std::string_view moduleName) {
387 return insertNestedModule(moduleName, getUnknownLoc());
388 }
389};
390
391} // namespace llzk
mlir::Location getUnknownLoc()
Definition Builders.h:42
mlir::MLIRContext * context
Definition Builders.h:37
BaseBuilder(mlir::MLIRContext *ctx)
Definition Builders.h:40
Builds out an LLZK-compliant module and provides utilities for populating that module.
Definition Builders.h:323
mlir::Region & getBodyRegion()
Definition Builders.h:356
ModuleBuilder(mlir::ModuleOp m)
Definition Builders.h:352
mlir::FailureOr< TemplateBuilder * > getTemplate(std::string_view templateName) const
Definition Builders.h:361
ModuleBuilder & insertNestedModule(std::string_view moduleName)
Definition Builders.h:386
ModuleBuilder & insertTemplate(std::string_view templateName, mlir::Location loc, unsigned numParams=0)
mlir::FailureOr< ModuleBuilder * > getNestedModule(std::string_view moduleName) const
Definition Builders.h:369
ModuleBuilder & insertTemplate(std::string_view templateName, unsigned numParams=0)
Definition Builders.h:381
ModuleBuilder & insertNestedModule(std::string_view moduleName, mlir::Location loc)
mlir::ModuleOp & getModule()
Get the associated module of this builder.
Definition Builders.h:359
mlir::Region & getBodyRegion()
Definition Builders.h:113
Derived & insertComputeFn(component::StructDefOp op, mlir::Location loc)
Derived & insertConstrainCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc, mlir::Location memberDefLoc)
To call a constraint function, you must:
std::unordered_map< std::string_view, function::FuncDefOp > productFnMap
Definition Builders.h:56
Derived & insertComputeFn(std::string_view structName, mlir::Location loc)
Derived & insertComputeFn(std::string_view structName)
Definition Builders.h:228
Derived & insertComputeOnlyStruct(std::string_view structName)
Definition Builders.h:174
Derived & insertProductStruct(std::string_view structName)
Definition Builders.h:216
mlir::FailureOr< function::FuncDefOp > getProductFn(component::StructDefOp op) const
Definition Builders.h:148
Derived & insertConstrainCall(std::string_view caller, std::string_view callee)
Definition Builders.h:281
Derived & insertProductFn(std::string_view structName, mlir::Location loc)
Derived & insertComputeCall(std::string_view caller, std::string_view callee)
Definition Builders.h:263
void ensureNoSuchFreeFunc(std::string_view funcName)
Ensure that a global function with the given funcName has not been added, reporting a fatal error otherwise.
Definition Builders.cpp:41
mlir::FailureOr< function::FuncDefOp > getProductFn(std::string_view structName) const
Definition Builders.h:142
Derived & insertProductStruct(std::string_view structName, mlir::Location structLoc, mlir::Location productLoc)
Definition Builders.h:208
mlir::FailureOr< function::FuncDefOp > getFreeFunc(std::string_view funcName) const
Definition Builders.h:152
mlir::FailureOr< function::FuncDefOp > getConstrainFn(component::StructDefOp op) const
Definition Builders.h:138
std::unordered_map< std::string_view, function::FuncDefOp > constrainFnMap
Definition Builders.h:54
Derived & insertFreeCall(function::FuncDefOp caller, std::string_view callee)
Definition Builders.h:292
Derived & insertEmptyStruct(std::string_view structName, mlir::Location loc)
static function::FuncDefOp buildComputeFn(component::StructDefOp op, mlir::Location loc)
compute returns the type of the struct that defines it.
Definition Builders.cpp:124
Derived & insertConstrainCall(std::string_view caller, std::string_view callee, mlir::Location callLoc, mlir::Location memberDefLoc)
Derived & insertFullStruct(std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc, mlir::Location constrainLoc)
Definition Builders.h:192
Derived & insertComputeCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc)
Only requirement for compute is the call itself.
mlir::FailureOr< function::FuncDefOp > getComputeFn(std::string_view structName) const
Definition Builders.h:122
Derived & insertConstrainOnlyStruct(std::string_view structName, mlir::Location structLoc, mlir::Location constrainLoc)
Definition Builders.h:179
mlir::FailureOr< function::FuncDefOp > getComputeFn(component::StructDefOp op) const
Definition Builders.h:128
Derived & insertConstrainFn(std::string_view structName, mlir::Location loc)
void ensureProductFnExists(std::string_view structName)
Ensure that the given struct has a product function, reporting a fatal error otherwise.
Definition Builders.cpp:104
void ensureFreeFnExists(std::string_view funcName)
Ensure that a global function with the given funcName has been added, reporting a fatal error otherwise.
Definition Builders.cpp:48
Derived & insertProductFn(component::StructDefOp op, mlir::Location loc)
Derived & insertConstrainFn(component::StructDefOp op, mlir::Location loc)
Derived & insertEmptyStruct(std::string_view structName)
Definition Builders.h:162
mlir::FailureOr< component::StructDefOp > getStruct(std::string_view structName) const
Definition Builders.h:115
void ensureNoSuchProductFn(std::string_view structName)
Ensure that the given struct does not have a product function, reporting a fatal error otherwise.
Definition Builders.cpp:97
void ensureStructExists(std::string_view structName)
Ensure that a struct with the given structName exists, reporting a fatal error otherwise.
Definition Builders.cpp:62
void ensureNoSuchConstrainFn(std::string_view structName)
Ensure that the given struct does not have a constrain function, reporting a fatal error otherwise.
Definition Builders.cpp:83
Derived & insertProductFn(std::string_view structName)
Definition Builders.h:249
Derived & insertFreeCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc)
mlir::FailureOr< function::FuncDefOp > getConstrainFn(std::string_view structName) const
Definition Builders.h:132
Derived & insertConstrainFn(std::string_view structName)
Definition Builders.h:238
std::unordered_map< std::string_view, function::FuncDefOp > freeFuncMap
Definition Builders.h:48
void ensureNoSuchComputeFn(std::string_view structName)
Ensure that the given struct does not have a compute function, reporting a fatal error otherwise.
Definition Builders.cpp:69
Derived & insertComputeCall(std::string_view caller, std::string_view callee, mlir::Location callLoc)
void ensureComputeFnExists(std::string_view structName)
Ensure that the given struct has a compute function, reporting a fatal error otherwise.
Definition Builders.cpp:76
void ensureConstrainFnExists(std::string_view structName)
Ensure that the given struct has a constrain function, reporting a fatal error otherwise.
Definition Builders.cpp:90
static function::FuncDefOp buildConstrainFn(component::StructDefOp op, mlir::Location loc)
constrain accepts the struct type as the first argument.
Definition Builders.cpp:150
std::unordered_map< std::string_view, function::FuncDefOp > computeFnMap
Definition Builders.h:52
std::unordered_map< std::string_view, component::StructDefOp > structMap
Definition Builders.h:50
Derived & insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type)
Definition Builders.h:286
Derived & insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc)
void ensureNoSuchStruct(std::string_view structName)
Ensure that a struct with the given structName has not been added, reporting a fatal error otherwise.
Definition Builders.cpp:55
ModuleLikeBuilder(mlir::MLIRContext *ctx)
Definition Builders.h:109
Derived & insertComputeOnlyStruct(std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc)
Definition Builders.h:166
static function::FuncDefOp buildProductFn(component::StructDefOp op, mlir::Location loc)
product returns the type of the struct that defines it.
Definition Builders.cpp:176
Derived & insertConstrainOnlyStruct(std::string_view structName)
Definition Builders.h:187
Derived & insertFullStruct(std::string_view structName)
Inserts a struct with both compute and constrain functions.
Definition Builders.h:203
TemplateBuilder(polymorphic::TemplateOp t)
Definition Builders.h:304
mlir::Region & getBodyRegion()
Definition Builders.h:308
polymorphic::TemplateOp & getTemplate()
Get the associated template of this builder.
Definition Builders.h:311
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
OwningOpRef< ModuleOp > createLLZKModule(MLIRContext *, Location loc)
Definition Builders.cpp:23
void addLangAttrForLLZKDialect(ModuleOp mod)
Definition Builders.cpp:29
mlir::Location getUnknownLoc(mlir::MLIRContext *context)
Definition Builders.h:23