15#include <mlir/IR/Builders.h>
16#include <mlir/IR/MLIRContext.h>
19#include <unordered_map>
24 return mlir::UnknownLoc::get(context);
27mlir::OwningOpRef<mlir::ModuleOp>
createLLZKModule(mlir::MLIRContext *context, mlir::Location loc);
45template <
typename Derived>
class ModuleLikeBuilder :
public BaseBuilder {
49 ModuleLikeBuilder(mlir::MLIRContext *ctx) :
BaseBuilder(ctx) {}
53 std::unordered_map<std::string_view, function::FuncDefOp>
freeFuncMap;
55 std::unordered_map<std::string_view, component::StructDefOp>
structMap;
57 std::unordered_map<std::string_view, function::FuncDefOp>
computeFnMap;
61 std::unordered_map<std::string_view, function::FuncDefOp>
productFnMap;
118 mlir::FailureOr<component::StructDefOp>
getStruct(std::string_view structName)
const {
122 return mlir::failure();
125 mlir::FailureOr<function::FuncDefOp>
getComputeFn(std::string_view structName)
const {
129 return mlir::failure();
135 mlir::FailureOr<function::FuncDefOp>
getConstrainFn(std::string_view structName)
const {
139 return mlir::failure();
145 mlir::FailureOr<function::FuncDefOp>
getProductFn(std::string_view structName)
const {
149 return mlir::failure();
155 mlir::FailureOr<function::FuncDefOp>
getFreeFunc(std::string_view funcName)
const {
159 return mlir::failure();
170 std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc
174 return static_cast<Derived &
>(*this);
183 std::string_view structName, mlir::Location structLoc, mlir::Location constrainLoc
187 return static_cast<Derived &
>(*this);
196 std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc,
197 mlir::Location constrainLoc
202 return static_cast<Derived &
>(*this);
212 std::string_view structName, mlir::Location structLoc, mlir::Location productLoc
216 return static_cast<Derived &
>(*this);
278 mlir::Location memberDefLoc
281 std::string_view caller, std::string_view callee, mlir::Location callLoc,
282 mlir::Location memberDefLoc
288 Derived &
insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc);
289 inline Derived &
insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type) {
316 mlir::FailureOr<polymorphic::TemplateParamOp>
getParam(std::string_view name) {
321 return mlir::failure();
324 mlir::FailureOr<polymorphic::TemplateExprOp>
getExpr(std::string_view name) {
329 return mlir::failure();
333 insertParam(std::string_view name, mlir::Location loc, mlir::TypeAttr type = {}) {
335 llvm::report_fatal_error(
"Duplicate TemplateParamOp insertion attempted");
338 mlir::OpBuilder builder(
context);
341 if (region.empty()) {
342 region.emplaceBlock();
345 builder.setInsertionPointToEnd(®ion.front());
347 auto nameAttr = builder.getStringAttr(name);
349 builder.create<polymorphic::TemplateParamOp>(loc, nameAttr, type);
359 if (succeeded(
getExpr(name))) {
360 llvm::report_fatal_error(
"Duplicate TemplateExprOp insertion attempted");
363 mlir::OpBuilder builder(
context);
366 if (region.empty()) {
367 region.emplaceBlock();
370 builder.setInsertionPointToEnd(®ion.front());
372 auto nameAttr = builder.getStringAttr(name);
388 mlir::ModuleOp myModule;
391 std::unordered_map<std::string_view, std::unique_ptr<TemplateBuilder>> templateMap;
393 std::unordered_map<std::string_view, std::unique_ptr<ModuleBuilder>> nestedModuleMap;
398 void ensureNoSuchTemplate(std::string_view templateName);
403 void ensureTemplateExists(std::string_view templateName);
408 void ensureNoSuchNestedModule(std::string_view moduleName);
413 void ensureNestedModuleExists(std::string_view moduleName);
416 ModuleBuilder(mlir::ModuleOp m) : ModuleLikeBuilder(m.getContext()), myModule(m) {}
425 mlir::FailureOr<TemplateBuilder *>
getTemplate(std::string_view templateName)
const {
426 auto it = templateMap.find(templateName);
427 if (it != templateMap.end()) {
428 return it->second.get();
430 return mlir::failure();
434 auto it = nestedModuleMap.find(moduleName);
435 if (it != nestedModuleMap.end()) {
436 return it->second.get();
438 return mlir::failure();
444 insertTemplate(std::string_view templateName, mlir::Location loc,
unsigned numParams = 0);
mlir::Location getUnknownLoc()
mlir::MLIRContext * context
BaseBuilder(mlir::MLIRContext *ctx)
Builds out a LLZK-compliant module and provides utilities for populating that module.
mlir::Region & getBodyRegion()
ModuleBuilder(mlir::ModuleOp m)
mlir::FailureOr< TemplateBuilder * > getTemplate(std::string_view templateName) const
ModuleBuilder & insertNestedModule(std::string_view moduleName)
ModuleBuilder & insertTemplate(std::string_view templateName, mlir::Location loc, unsigned numParams=0)
mlir::FailureOr< ModuleBuilder * > getNestedModule(std::string_view moduleName) const
ModuleBuilder & insertTemplate(std::string_view templateName, unsigned numParams=0)
ModuleBuilder & insertNestedModule(std::string_view moduleName, mlir::Location loc)
mlir::ModuleOp & getModule()
Get the associated module of this builder.
mlir::Region & getBodyRegion()
Derived & insertComputeFn(component::StructDefOp op, mlir::Location loc)
Derived & insertConstrainCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc, mlir::Location memberDefLoc)
To call a constraint function, you must:
std::unordered_map< std::string_view, function::FuncDefOp > productFnMap
Derived & insertComputeFn(std::string_view structName, mlir::Location loc)
Derived & insertComputeFn(std::string_view structName)
Derived & insertComputeOnlyStruct(std::string_view structName)
Derived & insertProductStruct(std::string_view structName)
mlir::FailureOr< function::FuncDefOp > getProductFn(component::StructDefOp op) const
Derived & insertConstrainCall(std::string_view caller, std::string_view callee)
Derived & insertProductFn(std::string_view structName, mlir::Location loc)
Derived & insertComputeCall(std::string_view caller, std::string_view callee)
void ensureNoSuchFreeFunc(std::string_view funcName)
Ensure that a global function with the given funcName has not been added, reporting a fatal error oth...
mlir::FailureOr< function::FuncDefOp > getProductFn(std::string_view structName) const
Derived & insertProductStruct(std::string_view structName, mlir::Location structLoc, mlir::Location productLoc)
mlir::FailureOr< function::FuncDefOp > getFreeFunc(std::string_view funcName) const
mlir::FailureOr< function::FuncDefOp > getConstrainFn(component::StructDefOp op) const
std::unordered_map< std::string_view, function::FuncDefOp > constrainFnMap
Derived & insertFreeCall(function::FuncDefOp caller, std::string_view callee)
Derived & insertEmptyStruct(std::string_view structName, mlir::Location loc)
static function::FuncDefOp buildComputeFn(component::StructDefOp op, mlir::Location loc)
compute returns the type of the struct that defines it.
Derived & insertConstrainCall(std::string_view caller, std::string_view callee, mlir::Location callLoc, mlir::Location memberDefLoc)
Derived & insertFullStruct(std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc, mlir::Location constrainLoc)
Derived & insertComputeCall(component::StructDefOp caller, component::StructDefOp callee, mlir::Location callLoc)
Only requirement for compute is the call itself.
mlir::FailureOr< function::FuncDefOp > getComputeFn(std::string_view structName) const
Derived & insertConstrainOnlyStruct(std::string_view structName, mlir::Location structLoc, mlir::Location constrainLoc)
mlir::FailureOr< function::FuncDefOp > getComputeFn(component::StructDefOp op) const
Derived & insertConstrainFn(std::string_view structName, mlir::Location loc)
void ensureProductFnExists(std::string_view structName)
Ensure that the given struct has a product function, reporting a fatal error otherwise.
void ensureFreeFnExists(std::string_view funcName)
Ensure that a global function with the given funcName has been added, reporting a fatal error otherwi...
Derived & insertProductFn(component::StructDefOp op, mlir::Location loc)
Derived & insertConstrainFn(component::StructDefOp op, mlir::Location loc)
Derived & insertEmptyStruct(std::string_view structName)
mlir::FailureOr< component::StructDefOp > getStruct(std::string_view structName) const
void ensureNoSuchProductFn(std::string_view structName)
Ensure that the given struct does not have a product function, reporting a fatal error otherwise.
void ensureStructExists(std::string_view structName)
Ensure that a struct with the given structName exists, reporting a fatal error otherwise.
void ensureNoSuchConstrainFn(std::string_view structName)
Ensure that the given struct does not have a constrain function, reporting a fatal error otherwise.
Derived & insertProductFn(std::string_view structName)
Derived & insertFreeCall(function::FuncDefOp caller, std::string_view callee, mlir::Location callLoc)
mlir::FailureOr< function::FuncDefOp > getConstrainFn(std::string_view structName) const
Derived & insertConstrainFn(std::string_view structName)
std::unordered_map< std::string_view, function::FuncDefOp > freeFuncMap
void ensureNoSuchComputeFn(std::string_view structName)
Ensure that the given struct does not have a compute function, reporting a fatal error otherwise.
Derived & insertComputeCall(std::string_view caller, std::string_view callee, mlir::Location callLoc)
void ensureComputeFnExists(std::string_view structName)
Ensure that the given struct has a compute function, reporting a fatal error otherwise.
void ensureConstrainFnExists(std::string_view structName)
Ensure that the given struct has a constrain function, reporting a fatal error otherwise.
static function::FuncDefOp buildConstrainFn(component::StructDefOp op, mlir::Location loc)
constrain accepts the struct type as the first argument.
std::unordered_map< std::string_view, function::FuncDefOp > computeFnMap
std::unordered_map< std::string_view, component::StructDefOp > structMap
Derived & insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type)
Derived & insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc)
void ensureNoSuchStruct(std::string_view structName)
Ensure that a struct with the given structName has not been added, reporting a fatal error otherwise.
Derived & insertComputeOnlyStruct(std::string_view structName, mlir::Location structLoc, mlir::Location computeLoc)
static function::FuncDefOp buildProductFn(component::StructDefOp op, mlir::Location loc)
product returns the type of the struct that defines it.
Derived & insertConstrainOnlyStruct(std::string_view structName)
Derived & insertFullStruct(std::string_view structName)
Inserts a struct with both compute and constrain functions.
Builds out a LLZK-compliant template and provides utilities for populating that template.
TemplateBuilder & insertParam(std::string_view name, mlir::Location loc, mlir::TypeAttr type={})
TemplateBuilder & insertExpr(std::string_view name)
TemplateBuilder(polymorphic::TemplateOp t)
mlir::FailureOr< polymorphic::TemplateParamOp > getParam(std::string_view name)
mlir::Region & getBodyRegion()
TemplateBuilder & insertParam(std::string_view name)
mlir::FailureOr< polymorphic::TemplateExprOp > getExpr(std::string_view name)
TemplateBuilder & insertExpr(std::string_view name, mlir::Location loc)
polymorphic::TemplateOp & getTemplate()
Get the associated template of this builder.
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
OwningOpRef< ModuleOp > createLLZKModule(MLIRContext *, Location loc)
void addLangAttrForLLZKDialect(ModuleOp mod)
mlir::Location getUnknownLoc(mlir::MLIRContext *context)