14#include <llvm/Support/ErrorHandling.h>
24 auto mod = ModuleOp::create(loc);
30 MLIRContext *ctx =
mod.getContext();
31 if (
auto *dialect = ctx->getOrLoadDialect<
LLZKDialect>()) {
34 llvm::report_fatal_error(
"Could not load LLZK dialect!");
40template <
typename Derived>
43 llvm::report_fatal_error(
"global function " + Twine(funcName) +
" already exists!");
47template <
typename Derived>
50 llvm::report_fatal_error(
"global function " + Twine(funcName) +
" does not exist!");
54template <
typename Derived>
57 llvm::report_fatal_error(
"struct " + Twine(structName) +
" already exists!");
61template <
typename Derived>
64 llvm::report_fatal_error(
"struct " + Twine(structName) +
" does not exist!");
68template <
typename Derived>
71 llvm::report_fatal_error(
"struct " + Twine(structName) +
" already has a compute function!");
75template <
typename Derived>
78 llvm::report_fatal_error(
"struct " + Twine(structName) +
" has no compute function!");
82template <
typename Derived>
85 llvm::report_fatal_error(
"struct " + Twine(structName) +
" already has a constrain function!");
89template <
typename Derived>
92 llvm::report_fatal_error(
"struct " + Twine(structName) +
" has no constrain function!");
96template <
typename Derived>
99 llvm::report_fatal_error(
"struct " + Twine(structName) +
" already has a product function!");
103template <
typename Derived>
106 llvm::report_fatal_error(
"struct " + Twine(structName) +
" has no product function!");
110template <
typename Derived>
112 ensureNoSuchStruct(structName);
114 OpBuilder opBuilder(this->getBodyRegion());
115 auto structDef = opBuilder.create<StructDefOp>(loc, StringAttr::get(context, structName));
117 (void)structDef.getRegion().emplaceBlock();
118 structMap[structName] = structDef;
120 return static_cast<Derived &
>(*this);
123template <
typename Derived>
125 MLIRContext *
context = op.getContext();
131 fnOp.setAllowWitnessAttr();
132 fnOp.addEntryBlock();
136template <
typename Derived>
138 ensureNoSuchComputeFn(op.getName());
139 computeFnMap[op.getName()] = buildComputeFn(op, loc);
140 return static_cast<Derived &
>(*this);
143template <
typename Derived>
145 ensureStructExists(structName);
146 return insertComputeFn(structMap.at(structName), loc);
149template <
typename Derived>
151 MLIRContext *
context = op.getContext();
157 fnOp.setAllowConstraintAttr();
158 fnOp.addEntryBlock();
162template <
typename Derived>
164 ensureNoSuchConstrainFn(op.getName());
165 constrainFnMap[op.getName()] = buildConstrainFn(op, loc);
166 return static_cast<Derived &
>(*this);
169template <
typename Derived>
171 ensureStructExists(structName);
172 return insertConstrainFn(structMap.at(structName), loc);
175template <
typename Derived>
177 MLIRContext *
context = op.getContext();
183 fnOp.setAllowWitnessAttr();
184 fnOp.setAllowConstraintAttr();
185 fnOp.addEntryBlock();
189template <
typename Derived>
191 ensureNoSuchProductFn(op.getName());
192 productFnMap[op.getName()] = buildProductFn(op, loc);
193 return static_cast<Derived &
>(*this);
196template <
typename Derived>
198 ensureStructExists(structName);
199 return insertProductFn(structMap.at(structName), loc);
202template <
typename Derived>
206 ensureComputeFnExists(caller.getName());
207 ensureComputeFnExists(callee.getName());
209 auto callerFn = computeFnMap.at(caller.getName());
210 auto calleeFn = computeFnMap.at(callee.getName());
212 OpBuilder builder(callerFn.getBody());
213 builder.create<
CallOp>(callLoc, calleeFn);
214 return static_cast<Derived &
>(*this);
217template <
typename Derived>
219 std::string_view caller, std::string_view callee, Location callLoc
221 ensureStructExists(caller);
222 ensureStructExists(callee);
223 return insertComputeCall(structMap.at(caller), structMap.at(callee), callLoc);
226template <
typename Derived>
237 size_t numOps = caller.getBody()->getOperations().size();
238 auto memberName = StringAttr::get(
context, callee.getName().str() + std::to_string(numOps));
243 builder.create<
MemberDefOp>(memberDefLoc, memberName, calleeTy);
248 OpBuilder builder(callerFn.
getBody());
257 return static_cast<Derived &
>(*this);
260template <
typename Derived>
262 std::string_view caller, std::string_view callee, Location callLoc, Location memberDefLoc
264 ensureStructExists(caller);
265 ensureStructExists(callee);
266 return insertConstrainCall(structMap.at(caller), structMap.at(callee), callLoc, memberDefLoc);
269template <
typename Derived>
271 std::string_view funcName, FunctionType type, Location loc
273 ensureNoSuchFreeFunc(funcName);
275 OpBuilder opBuilder(this->getBodyRegion());
276 auto funcDef = opBuilder.create<
FuncDefOp>(loc, funcName, type);
277 (void)funcDef.addEntryBlock();
278 freeFuncMap[funcName] = funcDef;
280 return static_cast<Derived &
>(*this);
283template <
typename Derived>
285 FuncDefOp caller, std::string_view callee, Location callLoc
287 ensureFreeFnExists(callee);
288 FuncDefOp calleeFn = freeFuncMap.at(callee);
290 OpBuilder builder(caller.getBody());
291 builder.create<
CallOp>(callLoc, calleeFn);
292 return static_cast<Derived &
>(*this);
297void ModuleBuilder::ensureNoSuchTemplate(std::string_view templateName) {
298 if (templateMap.find(templateName) != templateMap.end()) {
299 llvm::report_fatal_error(
"template " + Twine(templateName) +
" already exists!");
303void ModuleBuilder::ensureTemplateExists(std::string_view templateName) {
304 if (templateMap.find(templateName) == templateMap.end()) {
305 llvm::report_fatal_error(
"template " + Twine(templateName) +
" does not exist!");
311 ensureNoSuchTemplate(templateName);
313 OpBuilder opBuilder(myModule.getBodyRegion());
314 auto templateDef = opBuilder.create<TemplateOp>(loc, StringAttr::get(
context, templateName));
315 opBuilder.setInsertionPointToStart(&templateDef.getBodyRegion().emplaceBlock());
316 for (
unsigned i = 0; i < numParams; ++i) {
317 opBuilder.create<TemplateParamOp>(
318 loc, StringAttr::get(
context,
'T' + std::to_string(i)), TypeAttr()
322 auto key = templateDef.getName();
323 templateMap.emplace(key, std::make_unique<TemplateBuilder>(templateDef));
328void ModuleBuilder::ensureNoSuchNestedModule(std::string_view moduleName) {
329 if (nestedModuleMap.find(moduleName) != nestedModuleMap.end()) {
330 llvm::report_fatal_error(
"nested module " + Twine(moduleName) +
" already exists!");
334void ModuleBuilder::ensureNestedModuleExists(std::string_view moduleName) {
335 if (nestedModuleMap.find(moduleName) == nestedModuleMap.end()) {
336 llvm::report_fatal_error(
"nested module " + Twine(moduleName) +
" does not exist!");
341 ensureNoSuchNestedModule(moduleName);
343 OpBuilder opBuilder(myModule.getBodyRegion());
344 auto nestedMod = opBuilder.create<ModuleOp>(loc);
345 nestedMod.setSymName(moduleName);
347 auto key = *nestedMod.getSymName();
348 nestedModuleMap.emplace(key, std::make_unique<ModuleBuilder>(nestedMod));
mlir::MLIRContext * context
Builds out a LLZK-compliant module and provides utilities for populating that module.
ModuleBuilder & insertTemplate(std::string_view templateName, mlir::Location loc, unsigned numParams=0)
ModuleBuilder & insertNestedModule(std::string_view moduleName, mlir::Location loc)
Derived & insertComputeFn(component::StructDefOp op, mlir::Location loc)
Derived & insertConstrainCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc, mlir::Location memberDefLoc)
To call a constraint function, you must:
std::unordered_map< std::string_view, function::FuncDefOp > productFnMap
void ensureNoSuchFreeFunc(std::string_view funcName)
Ensure that a global function with the given funcName has not been added, reporting a fatal error oth...
std::unordered_map< std::string_view, function::FuncDefOp > constrainFnMap
Derived & insertEmptyStruct(std::string_view structName, mlir::Location loc)
static function::FuncDefOp buildComputeFn(component::StructDefOp op, mlir::Location loc)
compute returns the type of the struct that defines it.
Derived & insertComputeCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc)
Only requirement for compute is the call itself.
void ensureProductFnExists(std::string_view structName)
Ensure that the given struct has a product function, reporting a fatal error otherwise.
void ensureFreeFnExists(std::string_view funcName)
Ensure that a global function with the given funcName has been added, reporting a fatal error otherwi...
Derived & insertProductFn(component::StructDefOp op, mlir::Location loc)
Derived & insertConstrainFn(component::StructDefOp op, mlir::Location loc)
void ensureNoSuchProductFn(std::string_view structName)
Ensure that the given struct does not have a product function, reporting a fatal error otherwise.
void ensureStructExists(std::string_view structName)
Ensure that a struct with the given structName exists, reporting a fatal error otherwise.
void ensureNoSuchConstrainFn(std::string_view structName)
Ensure that the given struct does not have a constrain function, reporting a fatal error otherwise.
Derived & insertFreeCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc)
std::unordered_map< std::string_view, function::FuncDefOp > freeFuncMap
void ensureNoSuchComputeFn(std::string_view structName)
Ensure that the given struct does not have a compute function, reporting a fatal error otherwise.
void ensureComputeFnExists(std::string_view structName)
Ensure that the given struct has a compute function, reporting a fatal error otherwise.
void ensureConstrainFnExists(std::string_view structName)
Ensure that the given struct has a constrain function, reporting a fatal error otherwise.
static function::FuncDefOp buildConstrainFn(component::StructDefOp op, mlir::Location loc)
constrain accepts the struct type as the first argument.
std::unordered_map< std::string_view, function::FuncDefOp > computeFnMap
std::unordered_map< std::string_view, component::StructDefOp > structMap
Derived & insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc)
void ensureNoSuchStruct(std::string_view structName)
Ensure that a struct with the given structName has not been added, reporting a fatal error otherwise.
static function::FuncDefOp buildProductFn(component::StructDefOp op, mlir::Location loc)
product returns the type of the struct that defines it.
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
::mlir::Region & getBodyRegion()
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
::mlir::Region & getBody()
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char LANG_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that identifies the ModuleOp as the root module and s...
constexpr char FUNC_NAME_CONSTRAIN[]
constexpr char FUNC_NAME_PRODUCT[]
OwningOpRef< ModuleOp > createLLZKModule(MLIRContext *, Location loc)
void addLangAttrForLLZKDialect(ModuleOp mod)