LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Builders.cpp
Go to the documentation of this file.
1//===-- Builders.cpp - Operation builder implementations --------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
13
14#include <llvm/Support/ErrorHandling.h>
15
16namespace llzk {
17
18using namespace mlir;
19using namespace component;
20using namespace function;
21using namespace polymorphic;
22
23OwningOpRef<ModuleOp> createLLZKModule(MLIRContext * /*context*/, Location loc) {
24 auto mod = ModuleOp::create(loc);
26 return mod;
27}
28
30 MLIRContext *ctx = mod.getContext();
31 if (auto *dialect = ctx->getOrLoadDialect<LLZKDialect>()) {
32 mod->setAttr(LANG_ATTR_NAME, StringAttr::get(ctx, dialect->getNamespace()));
33 } else {
34 llvm::report_fatal_error("Could not load LLZK dialect!");
35 }
36}
37
38/* ModuleLikeBuilder */
39
40template <typename Derived>
41void ModuleLikeBuilder<Derived>::ensureNoSuchFreeFunc(std::string_view funcName) {
42 if (freeFuncMap.find(funcName) != freeFuncMap.end()) {
43 llvm::report_fatal_error("global function " + Twine(funcName) + " already exists!");
44 }
45}
46
47template <typename Derived>
48void ModuleLikeBuilder<Derived>::ensureFreeFnExists(std::string_view funcName) {
49 if (freeFuncMap.find(funcName) == freeFuncMap.end()) {
50 llvm::report_fatal_error("global function " + Twine(funcName) + " does not exist!");
51 }
52}
53
54template <typename Derived>
55void ModuleLikeBuilder<Derived>::ensureNoSuchStruct(std::string_view structName) {
56 if (structMap.find(structName) != structMap.end()) {
57 llvm::report_fatal_error("struct " + Twine(structName) + " already exists!");
58 }
59}
60
61template <typename Derived>
62void ModuleLikeBuilder<Derived>::ensureStructExists(std::string_view structName) {
63 if (structMap.find(structName) == structMap.end()) {
64 llvm::report_fatal_error("struct " + Twine(structName) + " does not exist!");
65 }
67
68template <typename Derived>
69void ModuleLikeBuilder<Derived>::ensureNoSuchComputeFn(std::string_view structName) {
70 if (computeFnMap.find(structName) != computeFnMap.end()) {
71 llvm::report_fatal_error("struct " + Twine(structName) + " already has a compute function!");
72 }
73}
74
75template <typename Derived>
76void ModuleLikeBuilder<Derived>::ensureComputeFnExists(std::string_view structName) {
77 if (computeFnMap.find(structName) == computeFnMap.end()) {
78 llvm::report_fatal_error("struct " + Twine(structName) + " has no compute function!");
79 }
80}
82template <typename Derived>
83void ModuleLikeBuilder<Derived>::ensureNoSuchConstrainFn(std::string_view structName) {
84 if (constrainFnMap.find(structName) != constrainFnMap.end()) {
85 llvm::report_fatal_error("struct " + Twine(structName) + " already has a constrain function!");
86 }
87}
88
89template <typename Derived>
90void ModuleLikeBuilder<Derived>::ensureConstrainFnExists(std::string_view structName) {
91 if (constrainFnMap.find(structName) == constrainFnMap.end()) {
92 llvm::report_fatal_error("struct " + Twine(structName) + " has no constrain function!");
93 }
94}
95
96template <typename Derived>
97void ModuleLikeBuilder<Derived>::ensureNoSuchProductFn(std::string_view structName) {
98 if (productFnMap.find(structName) != productFnMap.end()) {
99 llvm::report_fatal_error("struct " + Twine(structName) + " already has a product function!");
100 }
102
103template <typename Derived>
104void ModuleLikeBuilder<Derived>::ensureProductFnExists(std::string_view structName) {
105 if (productFnMap.find(structName) == productFnMap.end()) {
106 llvm::report_fatal_error("struct " + Twine(structName) + " has no product function!");
107 }
108}
109
110template <typename Derived>
111Derived &ModuleLikeBuilder<Derived>::insertEmptyStruct(std::string_view structName, Location loc) {
112 ensureNoSuchStruct(structName);
113
114 OpBuilder opBuilder(this->getBodyRegion());
115 auto structDef = opBuilder.create<StructDefOp>(loc, StringAttr::get(context, structName));
116 // populate the initial region
117 (void)structDef.getRegion().emplaceBlock();
118 structMap[structName] = structDef;
119
120 return static_cast<Derived &>(*this);
121}
122
123template <typename Derived>
125 MLIRContext *context = op.getContext();
126 OpBuilder opBuilder(op.getBodyRegion());
127 auto fnOp = opBuilder.create<FuncDefOp>(
128 loc, StringAttr::get(context, FUNC_NAME_COMPUTE),
129 FunctionType::get(context, {}, {op.getType()})
130 );
131 fnOp.setAllowWitnessAttr();
132 fnOp.addEntryBlock();
133 return fnOp;
134}
135
136template <typename Derived>
138 ensureNoSuchComputeFn(op.getName());
139 computeFnMap[op.getName()] = buildComputeFn(op, loc);
140 return static_cast<Derived &>(*this);
141}
142
143template <typename Derived>
144Derived &ModuleLikeBuilder<Derived>::insertComputeFn(std::string_view structName, Location loc) {
145 ensureStructExists(structName);
146 return insertComputeFn(structMap.at(structName), loc);
147}
148
149template <typename Derived>
151 MLIRContext *context = op.getContext();
152 OpBuilder opBuilder(op.getBodyRegion());
153 auto fnOp = opBuilder.create<FuncDefOp>(
154 loc, StringAttr::get(context, FUNC_NAME_CONSTRAIN),
155 FunctionType::get(context, {op.getType()}, {})
156 );
157 fnOp.setAllowConstraintAttr();
158 fnOp.addEntryBlock();
159 return fnOp;
160}
161
162template <typename Derived>
164 ensureNoSuchConstrainFn(op.getName());
165 constrainFnMap[op.getName()] = buildConstrainFn(op, loc);
166 return static_cast<Derived &>(*this);
167}
168
169template <typename Derived>
170Derived &ModuleLikeBuilder<Derived>::insertConstrainFn(std::string_view structName, Location loc) {
171 ensureStructExists(structName);
172 return insertConstrainFn(structMap.at(structName), loc);
173}
174
175template <typename Derived>
177 MLIRContext *context = op.getContext();
178 OpBuilder opBuilder(op.getBodyRegion());
179 auto fnOp = opBuilder.create<FuncDefOp>(
180 loc, StringAttr::get(context, FUNC_NAME_PRODUCT),
181 FunctionType::get(context, {}, {op.getType()})
182 );
183 fnOp.setAllowWitnessAttr();
184 fnOp.setAllowConstraintAttr();
185 fnOp.addEntryBlock();
186 return fnOp;
187}
188
189template <typename Derived>
191 ensureNoSuchProductFn(op.getName());
192 productFnMap[op.getName()] = buildProductFn(op, loc);
193 return static_cast<Derived &>(*this);
194}
195
196template <typename Derived>
197Derived &ModuleLikeBuilder<Derived>::insertProductFn(std::string_view structName, Location loc) {
198 ensureStructExists(structName);
199 return insertProductFn(structMap.at(structName), loc);
200}
201
202template <typename Derived>
204 StructDefOp caller, StructDefOp callee, Location callLoc
205) {
206 ensureComputeFnExists(caller.getName());
207 ensureComputeFnExists(callee.getName());
208
209 auto callerFn = computeFnMap.at(caller.getName());
210 auto calleeFn = computeFnMap.at(callee.getName());
211
212 OpBuilder builder(callerFn.getBody());
213 builder.create<CallOp>(callLoc, calleeFn);
214 return static_cast<Derived &>(*this);
215}
216
217template <typename Derived>
219 std::string_view caller, std::string_view callee, Location callLoc
220) {
221 ensureStructExists(caller);
222 ensureStructExists(callee);
223 return insertComputeCall(structMap.at(caller), structMap.at(callee), callLoc);
224}
226template <typename Derived>
228 StructDefOp caller, StructDefOp callee, Location callLoc, Location memberDefLoc
229) {
230 ensureConstrainFnExists(caller.getName());
231 ensureConstrainFnExists(callee.getName());
232
233 FuncDefOp callerFn = constrainFnMap.at(caller.getName());
234 FuncDefOp calleeFn = constrainFnMap.at(callee.getName());
235 StructType calleeTy = callee.getType();
236
237 size_t numOps = caller.getBody()->getOperations().size();
238 auto memberName = StringAttr::get(context, callee.getName().str() + std::to_string(numOps));
239
240 // Insert the member declaration op
241 {
242 OpBuilder builder(caller.getBodyRegion());
243 builder.create<MemberDefOp>(memberDefLoc, memberName, calleeTy);
244 }
245
246 // Insert the constrain function ops
247 {
248 OpBuilder builder(callerFn.getBody());
249
250 auto member = builder.create<MemberReadOp>(
251 callLoc, calleeTy, callerFn.getSelfValueFromConstrain(), memberName
252 );
253 builder.create<CallOp>(
254 callLoc, TypeRange {}, calleeFn.getFullyQualifiedName(), ValueRange {member}
255 );
256 }
257 return static_cast<Derived &>(*this);
258}
259
260template <typename Derived>
262 std::string_view caller, std::string_view callee, Location callLoc, Location memberDefLoc
263) {
264 ensureStructExists(caller);
265 ensureStructExists(callee);
266 return insertConstrainCall(structMap.at(caller), structMap.at(callee), callLoc, memberDefLoc);
267}
268
269template <typename Derived>
271 std::string_view funcName, FunctionType type, Location loc
272) {
273 ensureNoSuchFreeFunc(funcName);
274
275 OpBuilder opBuilder(this->getBodyRegion());
276 auto funcDef = opBuilder.create<FuncDefOp>(loc, funcName, type);
277 (void)funcDef.addEntryBlock();
278 freeFuncMap[funcName] = funcDef;
279
280 return static_cast<Derived &>(*this);
281}
282
283template <typename Derived>
285 FuncDefOp caller, std::string_view callee, Location callLoc
286) {
287 ensureFreeFnExists(callee);
288 FuncDefOp calleeFn = freeFuncMap.at(callee);
289
290 OpBuilder builder(caller.getBody());
291 builder.create<CallOp>(callLoc, calleeFn);
292 return static_cast<Derived &>(*this);
293}
294
295/* ModuleBuilder */
296
297void ModuleBuilder::ensureNoSuchTemplate(std::string_view templateName) {
298 if (templateMap.find(templateName) != templateMap.end()) {
299 llvm::report_fatal_error("template " + Twine(templateName) + " already exists!");
300 }
301}
302
303void ModuleBuilder::ensureTemplateExists(std::string_view templateName) {
304 if (templateMap.find(templateName) == templateMap.end()) {
305 llvm::report_fatal_error("template " + Twine(templateName) + " does not exist!");
306 }
307}
308
310ModuleBuilder::insertTemplate(std::string_view templateName, Location loc, unsigned numParams) {
311 ensureNoSuchTemplate(templateName);
312
313 OpBuilder opBuilder(myModule.getBodyRegion());
314 auto templateDef = opBuilder.create<TemplateOp>(loc, StringAttr::get(context, templateName));
315 opBuilder.setInsertionPointToStart(&templateDef.getBodyRegion().emplaceBlock());
316 for (unsigned i = 0; i < numParams; ++i) {
317 opBuilder.create<TemplateParamOp>(
318 loc, StringAttr::get(context, 'T' + std::to_string(i)), TypeAttr()
319 );
320 }
321
322 auto key = templateDef.getName();
323 templateMap.emplace(key, std::make_unique<TemplateBuilder>(templateDef));
324
325 return *this;
326}
327
328void ModuleBuilder::ensureNoSuchNestedModule(std::string_view moduleName) {
329 if (nestedModuleMap.find(moduleName) != nestedModuleMap.end()) {
330 llvm::report_fatal_error("nested module " + Twine(moduleName) + " already exists!");
331 }
332}
333
334void ModuleBuilder::ensureNestedModuleExists(std::string_view moduleName) {
335 if (nestedModuleMap.find(moduleName) == nestedModuleMap.end()) {
336 llvm::report_fatal_error("nested module " + Twine(moduleName) + " does not exist!");
337 }
338}
339
340ModuleBuilder &ModuleBuilder::insertNestedModule(std::string_view moduleName, Location loc) {
341 ensureNoSuchNestedModule(moduleName);
342
343 OpBuilder opBuilder(myModule.getBodyRegion());
344 auto nestedMod = opBuilder.create<ModuleOp>(loc);
345 nestedMod.setSymName(moduleName);
346
347 auto key = *nestedMod.getSymName();
348 nestedModuleMap.emplace(key, std::make_unique<ModuleBuilder>(nestedMod));
349
350 return *this;
351}
352
353/* Explicit template instantiations */
354
357
358} // namespace llzk
mlir::MLIRContext * context
Definition Builders.h:37
Builds out a LLZK-compliant module and provides utilities for populating that module.
Definition Builders.h:323
ModuleBuilder & insertTemplate(std::string_view templateName, mlir::Location loc, unsigned numParams=0)
ModuleBuilder & insertNestedModule(std::string_view moduleName, mlir::Location loc)
Derived & insertComputeFn(component::StructDefOp op, mlir::Location loc)
Derived & insertConstrainCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc, mlir::Location memberDefLoc)
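To call a constraint function, you must: declare a member of the callee's struct type in the caller struct, read that member in the caller's constrain function, and pass the read value to the callee's constrain function.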
std::unordered_map< std::string_view, function::FuncDefOp > productFnMap
Definition Builders.h:56
void ensureNoSuchFreeFunc(std::string_view funcName)
Ensure that a global function with the given funcName has not been added, reporting a fatal error otherwise.
Definition Builders.cpp:41
std::unordered_map< std::string_view, function::FuncDefOp > constrainFnMap
Definition Builders.h:54
Derived & insertEmptyStruct(std::string_view structName, mlir::Location loc)
static function::FuncDefOp buildComputeFn(component::StructDefOp op, mlir::Location loc)
compute returns the type of the struct that defines it.
Definition Builders.cpp:124
Derived & insertComputeCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc)
Only requirement for compute is the call itself.
void ensureProductFnExists(std::string_view structName)
Ensure that the given struct has a product function, reporting a fatal error otherwise.
Definition Builders.cpp:104
void ensureFreeFnExists(std::string_view funcName)
Ensure that a global function with the given funcName has been added, reporting a fatal error otherwise.
Definition Builders.cpp:48
Derived & insertProductFn(component::StructDefOp op, mlir::Location loc)
Derived & insertConstrainFn(component::StructDefOp op, mlir::Location loc)
void ensureNoSuchProductFn(std::string_view structName)
Ensure that the given struct does not have a product function, reporting a fatal error otherwise.
Definition Builders.cpp:97
void ensureStructExists(std::string_view structName)
Ensure that a struct with the given structName exists, reporting a fatal error otherwise.
Definition Builders.cpp:62
void ensureNoSuchConstrainFn(std::string_view structName)
Ensure that the given struct does not have a constrain function, reporting a fatal error otherwise.
Definition Builders.cpp:83
Derived & insertFreeCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc)
std::unordered_map< std::string_view, function::FuncDefOp > freeFuncMap
Definition Builders.h:48
void ensureNoSuchComputeFn(std::string_view structName)
Ensure that the given struct does not have a compute function, reporting a fatal error otherwise.
Definition Builders.cpp:69
void ensureComputeFnExists(std::string_view structName)
Ensure that the given struct has a compute function, reporting a fatal error otherwise.
Definition Builders.cpp:76
void ensureConstrainFnExists(std::string_view structName)
Ensure that the given struct has a constrain function, reporting a fatal error otherwise.
Definition Builders.cpp:90
static function::FuncDefOp buildConstrainFn(component::StructDefOp op, mlir::Location loc)
constrain accepts the struct type as the first argument.
Definition Builders.cpp:150
std::unordered_map< std::string_view, function::FuncDefOp > computeFnMap
Definition Builders.h:52
std::unordered_map< std::string_view, component::StructDefOp > structMap
Definition Builders.h:50
Derived & insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc)
void ensureNoSuchStruct(std::string_view structName)
Ensure that a struct with the given structName has not been added, reporting a fatal error otherwise.
Definition Builders.cpp:55
static function::FuncDefOp buildProductFn(component::StructDefOp op, mlir::Location loc)
product returns the type of the struct that defines it.
Definition Builders.cpp:176
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
::mlir::Region & getBodyRegion()
Definition Ops.h.inc:1189
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:382
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
Definition Ops.cpp:353
::mlir::Region & getBody()
Definition Ops.h.inc:690
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:16
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char LANG_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that identifies the ModuleOp as the root module and s...
Definition Constants.h:23
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:18
OwningOpRef< ModuleOp > createLLZKModule(MLIRContext *, Location loc)
Definition Builders.cpp:23
void addLangAttrForLLZKDialect(ModuleOp mod)
Definition Builders.cpp:29