LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Func and call op implementations --------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8// Adapted from the LLVM Project's lib/Dialect/Func/IR/FuncOps.cpp
9// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
10// See https://llvm.org/LICENSE.txt for license information.
11// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
12//
13//===----------------------------------------------------------------------===//
14
16
27#include "llzk/Util/Compare.h"
32
33#include <mlir/IR/IRMapping.h>
34#include <mlir/IR/OpImplementation.h>
35#include <mlir/Interfaces/FunctionImplementation.h>
36
37#include <llvm/ADT/DenseSet.h>
38#include <llvm/ADT/MapVector.h>
39
40// TableGen'd implementation files
41#define GET_OP_CLASSES
43
44using namespace mlir;
45using namespace llzk::felt;
46using namespace llzk::component;
47using namespace llzk::polymorphic;
48
49namespace llzk::function {
50
// Classifies a function by its symbol name. The reserved struct-function names
// FUNC_NAME_COMPUTE, FUNC_NAME_CONSTRAIN, and FUNC_NAME_PRODUCT each take their own branch;
// every other name is classified as FunctionKind::Free.
// NOTE(review): the three reserved branches presumably return FunctionKind::StructCompute,
// StructConstrain, and StructProduct respectively — confirm against the FunctionKind enum.
51FunctionKind fnNameToKind(mlir::StringRef name) {
52 if (FUNC_NAME_COMPUTE == name) {
54 } else if (FUNC_NAME_CONSTRAIN == name) {
56 } else if (FUNC_NAME_PRODUCT == name) {
58 } else {
59 return FunctionKind::Free;
60 }
61}
62
63namespace {
// Checks that all symbol references used in `funcType` resolve. It passes the input types and
// then the result types, as a list of type lists, to the ArrayRef-based resolution check.
// NOTE(review): presumably forwards to llzk::verifyTypeResolution, like the other callers in
// this file — confirm.
65inline LogicalResult
66verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, FunctionType funcType) {
68 tables, origin, ArrayRef<ArrayRef<Type>> {funcType.getInputs(), funcType.getResults()}
69 );
70}
71} // namespace
72
73//===----------------------------------------------------------------------===//
74// FuncDefOp
75//===----------------------------------------------------------------------===//
76
78 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs
79) {
// Create a FuncDefOp by forwarding all arguments to the op's build() via delegate_to_build.
80 return delegate_to_build<FuncDefOp>(location, name, type, attrs);
81}
82
84 Location location, StringRef name, FunctionType type, Operation::dialect_attr_range attrs
85) {
// Copy the dialect attribute range into contiguous storage so it can be passed to the
// ArrayRef<NamedAttribute> overload of create().
86 SmallVector<NamedAttribute, 8> attrRef(attrs);
87 return create(location, name, type, llvm::ArrayRef(attrRef));
88}
89
91 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs,
92 ArrayRef<DictionaryAttr> argAttrs
93) {
// Create the function first, then attach `argAttrs` as its argument attribute dictionaries.
94 FuncDefOp func = create(location, name, type, attrs);
95 func.setAllArgAttrs(argAttrs);
96 return func;
97}
98
100 OpBuilder &builder, OperationState &state, StringRef name, FunctionType type,
101 ArrayRef<NamedAttribute> attrs, ArrayRef<DictionaryAttr> argAttrs
102) {
// Record the symbol name and the function type, then append the caller-supplied attributes.
103 state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name));
104 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
105 state.attributes.append(attrs.begin(), attrs.end());
// Add the (initially empty) body region.
106 state.addRegion();
107
// Argument attributes are optional. When provided, there must be exactly one dictionary per
// function input (asserted below). No result attributes are recorded.
108 if (argAttrs.empty()) {
109 return;
110 }
111 assert(type.getNumInputs() == argAttrs.size());
112 function_interface_impl::addArgAndResultAttrs(
113 builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name),
114 getResAttrsAttrName(state.name)
115 );
116}
117
118ParseResult FuncDefOp::parse(OpAsmParser &parser, OperationState &result) {
119 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
120 function_interface_impl::VariadicFlag,
121 std::string &) { return builder.getFunctionType(argTypes, results); };
122
123 return function_interface_impl::parseFunctionOp(
124 parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType,
125 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)
126 );
127}
128
// Prints using the generic MLIR function-op syntax (never variadic). It uses this op's attribute
// names for the function type and the argument attributes (and, presumably, the result
// attributes).
129void FuncDefOp::print(OpAsmPrinter &p) {
130 function_interface_impl::printFunctionOp(
131 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(),
133 );
134}
135
138void FuncDefOp::cloneInto(FuncDefOp dest, IRMapping &mapper) {
139 // Add the attributes of this function to dest.
140 llvm::MapVector<StringAttr, Attribute> newAttrMap;
141 for (const auto &attr : dest->getAttrs()) {
142 newAttrMap.insert({attr.getName(), attr.getValue()});
143 }
144 for (const auto &attr : (*this)->getAttrs()) {
145 newAttrMap.insert({attr.getName(), attr.getValue()});
146 }
147
148 auto newAttrs =
149 llvm::to_vector(llvm::map_range(newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
150 return NamedAttribute(attrPair.first, attrPair.second);
151 }));
152 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
153
154 // Clone the body.
155 getBody().cloneInto(&dest.getBody(), mapper);
156}
157
163FuncDefOp FuncDefOp::clone(IRMapping &mapper) {
164 // Create the new function.
165 FuncDefOp newFunc = llvm::cast<FuncDefOp>(getOperation()->cloneWithoutRegions());
166
167 // If the function has a body, then the user might be deleting arguments to
168 // the function by specifying them in the mapper. If so, we don't add the
169 // argument to the input type vector.
170 if (!isExternal()) {
171 FunctionType oldType = getFunctionType();
172
173 unsigned oldNumArgs = oldType.getNumInputs();
174 SmallVector<Type, 4> newInputs;
175 newInputs.reserve(oldNumArgs);
176 for (unsigned i = 0; i != oldNumArgs; ++i) {
177 if (!mapper.contains(getArgument(i))) {
178 newInputs.push_back(oldType.getInput(i));
179 }
180 }
181
184 if (newInputs.size() != oldNumArgs) {
185 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, oldType.getResults()));
186
187 if (ArrayAttr argAttrs = getAllArgAttrs()) {
188 SmallVector<Attribute> newArgAttrs;
189 newArgAttrs.reserve(newInputs.size());
190 for (unsigned i = 0; i != oldNumArgs; ++i) {
191 if (!mapper.contains(getArgument(i))) {
192 newArgAttrs.push_back(argAttrs[i]);
193 }
194 }
195 newFunc.setAllArgAttrs(newArgAttrs);
196 }
197 }
198 }
199
201 cloneInto(newFunc, mapper);
202 return newFunc;
203}
204
// Clone with an empty value mapping, so no arguments are dropped from the clone's signature.
206 IRMapping mapper;
207 return clone(mapper);
208}
209
// The flag is stored as a UnitAttr: it is set when `newValue` is true and removed otherwise.
211 if (newValue) {
212 getOperation()->setAttr(AllowConstraintAttr::name, UnitAttr::get(getContext()));
213 } else {
214 getOperation()->removeAttr(AllowConstraintAttr::name);
215 }
216}
217
// The flag is stored as a UnitAttr: it is set when `newValue` is true and removed otherwise.
219 if (newValue) {
220 getOperation()->setAttr(AllowWitnessAttr::name, UnitAttr::get(getContext()));
221 } else {
222 getOperation()->removeAttr(AllowWitnessAttr::name);
223 }
224}
225
// The flag is stored as a UnitAttr: it is set when `newValue` is true and removed otherwise.
227 if (newValue) {
228 getOperation()->setAttr(AllowNonNativeFieldOpsAttr::name, UnitAttr::get(getContext()));
229 } else {
230 getOperation()->removeAttr(AllowNonNativeFieldOpsAttr::name);
231 }
232}
233
234bool FuncDefOp::hasArgPublicAttr(unsigned index) {
235 if (index < this->getNumArguments()) {
236 DictionaryAttr res = function_interface_impl::getArgAttrDict(*this, index);
237 return res ? res.contains(PublicAttr::name) : false;
238 } else {
239 // TODO: print error? requested attribute for non-existant argument index
240 return false;
241 }
242}
243
244LogicalResult FuncDefOp::verify() {
245 OwningEmitErrorFn emitErrorFunc = getEmitOpErrFn(this);
246 // Ensure that only valid LLZK types are used for arguments and return. Additionally, the struct
247 // functions may not use AffineMapAttrs in their parameter types. If such a scenario seems to make
248 // sense when generating LLZK IR, it's likely better to introduce a struct parameter to use
249 // instead and instantiate the struct with that AffineMapAttr.
250 FunctionType type = getFunctionType();
251 for (Type t : type.getInputs()) {
252 if (llzk::checkValidType(emitErrorFunc, t).failed()) {
253 return failure();
254 }
255 if (isInStruct() && hasAffineMapAttr(t)) {
256 return emitErrorFunc().append(
257 "\"@", getName(), "\" parameters cannot contain affine map attributes but found ", t
258 );
259 }
260 }
261 for (Type t : type.getResults()) {
262 if (llzk::checkValidType(emitErrorFunc, t).failed()) {
263 return failure();
264 }
265 }
266 // Ensure that the function does not contain nested modules.
267 // Functions also cannot contain nested structs, but this check is handled
268 // via struct.def's requirement of having module as a parent.
269 WalkResult res = this->walk<WalkOrder::PreOrder>([this](ModuleOp nestedMod) {
270 getEmitOpErrFn(nestedMod)().append(
271 "cannot be nested within '", getOperation()->getName(), "' operations"
272 );
273 return WalkResult::interrupt();
274 });
275 if (res.wasInterrupted()) {
276 return failure();
277 }
278
279 return success();
280}
281
282namespace {
283
284LogicalResult
285verifyFuncTypeCompute(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
286 FunctionType funcType = origin.getFunctionType();
287 llvm::ArrayRef<Type> resTypes = funcType.getResults();
288 // Must return type of parent struct
289 if (resTypes.size() != 1) {
290 return origin.emitOpError().append(
291 "\"@", FUNC_NAME_COMPUTE, "\" must have exactly one return type"
292 );
293 }
294 if (failed(checkSelfType(tables, parent, resTypes.front(), origin, "return"))) {
295 return failure();
296 }
297
298 // After the more specific checks (to ensure more specific error messages would be produced if
299 // necessary), do the general check that all symbol references in the types are valid. The return
300 // types were already checked so just check the input types.
301 return llzk::verifyTypeResolution(tables, origin, funcType.getInputs());
302}
303
304LogicalResult
305verifyFuncTypeProduct(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
306 // The signature for @product is the same as the signature for @compute
307 return verifyFuncTypeCompute(origin, tables, parent);
308}
309
310LogicalResult
311verifyFuncTypeConstrain(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
312 FunctionType funcType = origin.getFunctionType();
313 // Must return '()' type, i.e., have no return types
314 if (funcType.getResults().size() != 0) {
315 return origin.emitOpError() << "\"@" << FUNC_NAME_CONSTRAIN << "\" must have no return type";
316 }
317
318 // Type of the first parameter must match the parent StructDefOp of the current operation.
319 llvm::ArrayRef<Type> inputTypes = funcType.getInputs();
320 if (inputTypes.size() < 1) {
321 return origin.emitOpError() << "\"@" << FUNC_NAME_CONSTRAIN
322 << "\" must have at least one input type";
323 }
324 if (failed(checkSelfType(tables, parent, inputTypes.front(), origin, "first input"))) {
325 return failure();
326 }
327
328 // After the more specific checks (to ensure more specific error messages would be produced if
329 // necessary), do the general check that all symbol references in the types are valid. There are
330 // no return types, just check the remaining input types (the first was already checked via
331 // the checkSelfType() call above).
332 return llzk::verifyTypeResolution(tables, origin, inputTypes.drop_front());
333}
334
335} // namespace
336
337LogicalResult FuncDefOp::verifySymbolUses(SymbolTableCollection &tables) {
338 // Additional checks for the compute/constrain/product functions within a struct
339 if (StructDefOp parentStructOpt = getParentOfType<StructDefOp>(*this)) {
340 // Verify return type restrictions for functions within a StructDefOp
341 if (nameIsCompute()) {
342 return verifyFuncTypeCompute(*this, tables, parentStructOpt);
343 } else if (nameIsConstrain()) {
344 return verifyFuncTypeConstrain(*this, tables, parentStructOpt);
345 } else if (nameIsProduct()) {
346 return verifyFuncTypeProduct(*this, tables, parentStructOpt);
347 }
348 }
349 // In the general case, verify symbol resolution in all input and output types.
350 return verifyTypeResolution(tables, *this, getFunctionType());
351}
352
353SymbolRefAttr FuncDefOp::getFullyQualifiedName(bool requireParent) {
354 // If the parent is not present and not required, just return the symbol name
355 if (!requireParent && getOperation()->getParentOp() == nullptr) {
356 return SymbolRefAttr::get(getOperation());
357 }
358 auto res = getPathFromRoot(*this);
359 assert(succeeded(res));
360 return res.value();
361}
362
// Returns the value produced by this @compute function: the first operand of the ReturnOp that
// terminates the last block of its body. If that terminator is not a ReturnOp, it prints a
// message and hits llvm_unreachable.
// NOTE(review): relies on the verifiers (single result type, matching return operand count) to
// guarantee that the ReturnOp has at least one operand.
364 assert(nameIsCompute()); // skip inStruct check to allow dangling functions
365 // Get the single block of the function body
366 Region &body = getBody();
367 assert(!body.empty() && "compute() function body is empty");
368 Block &block = body.back();
369
370 // The terminator should be the return op
371 Operation *terminator = block.getTerminator();
372 assert(terminator && "compute() function has no terminator");
373 auto retOp = llvm::dyn_cast<ReturnOp>(terminator);
374 if (!retOp) {
375 llvm::errs() << "Expected '" << ReturnOp::getOperationName() << "' but found '"
376 << terminator->getName() << "'\n";
377 llvm_unreachable("compute() function must end with ReturnOp");
378 }
379 return retOp.getOperands().front();
380}
381
// Returns the `self` value of this @constrain function, i.e. its first argument.
// verifyFuncTypeConstrain() requires that first input to have the parent struct's type.
383 assert(nameIsConstrain()); // skip inStruct check to allow dangling functions
384 return getArguments().front();
385}
386
// Pre-condition: only valid on a struct's @compute function (enforced by the assertion below).
388 assert(isStructCompute() && "violated implementation pre-condition");
390}
391
392//===----------------------------------------------------------------------===//
393// ReturnOp
394//===----------------------------------------------------------------------===//
395
396LogicalResult ReturnOp::verify() {
397 auto function = getParentOp<FuncDefOp>(); // parent is FuncDefOp per ODS
398
399 // The operand number and types must match the function signature.
400 const auto results = function.getFunctionType().getResults();
401 if (getNumOperands() != results.size()) {
402 return emitOpError("has ") << getNumOperands() << " operands, but enclosing function (@"
403 << function.getName() << ") returns " << results.size();
404 }
405
406 for (unsigned i = 0, e = results.size(); i != e; ++i) {
407 if (!typesUnify(getOperand(i).getType(), results[i])) {
408 return emitError() << "type of return operand " << i << " (" << getOperand(i).getType()
409 << ") doesn't match function result type (" << results[i] << ")"
410 << " in function @" << function.getName();
411 }
412 }
413
414 return success();
415}
416
417//===----------------------------------------------------------------------===//
418// CallOp
419//===----------------------------------------------------------------------===//
420
421// Custom implementation to deserialize bytecode produced prior to version 2 which added optional
422// `OptionalAttr<ArrayAttr>:$templateParams`.
423LogicalResult CallOp::readProperties(DialectBytecodeReader &reader, OperationState &state) {
424 auto &prop = state.getOrAddProperties<Properties>();
425 if (failed(reader.readAttribute(prop.callee)) ||
426 failed(reader.readAttribute(prop.mapOpGroupSizes)) ||
427 failed(reader.readOptionalAttribute(prop.numDimsPerMap))) {
428 return failure();
429 }
430
431 if (reader.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
432 auto &propStorage = prop.operandSegmentSizes;
433 DenseI32ArrayAttr attr;
434 if (failed(reader.readAttribute(attr))) {
435 return failure();
436 }
437 if (attr.size() > static_cast<int64_t>(sizeof(propStorage) / sizeof(int32_t))) {
438 reader.emitError("size mismatch for operand/result_segment_size");
439 return failure();
440 }
441 llvm::copy(ArrayRef<int32_t>(attr), propStorage.begin());
442 }
443
444 // The `templateParams` is only available in version 2 or later.
445 auto versionOpt = reader.getDialectVersion<FunctionDialect>();
446 if (succeeded(versionOpt)) {
447 const auto &ver = static_cast<const LLZKDialectVersion &>(**versionOpt);
448 if (ver.majorVersion >= 2) {
449 if (failed(reader.readOptionalAttribute(prop.templateParams))) {
450 return failure();
451 }
452 }
453 }
454
455 if (reader.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6) {
456 return reader.readSparseArray(MutableArrayRef(prop.operandSegmentSizes));
457 };
458 return success();
459}
460
461// Same as tablegen would generate to serialize version 2 IR.
462void CallOp::writeProperties(DialectBytecodeWriter &writer) {
463 auto &prop = getProperties();
464 writer.writeAttribute(prop.callee);
465 writer.writeAttribute(prop.mapOpGroupSizes);
466 writer.writeOptionalAttribute(prop.numDimsPerMap);
467
468 if (writer.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
469 auto &propStorage = prop.operandSegmentSizes;
470 writer.writeAttribute(DenseI32ArrayAttr::get(this->getContext(), propStorage));
471 }
472
473 writer.writeOptionalAttribute(prop.templateParams);
474
475 auto &propStorage = prop.operandSegmentSizes;
476 if (writer.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6) {
477 writer.writeSparseArray(ArrayRef(propStorage));
478 }
479}
480
481static void addTemplateParams(
482 OpBuilder &odsBuilder, CallOp::Properties &props, ArrayRef<Attribute> templateParams
483) {
484 if (!templateParams.empty()) {
485 // Must attempt to convert attribute types but `build()` functions do not have a failure path or
486 // error reporting. That comes during validation of the constructed op so ignore errors here.
487 FailureOr<SmallVector<Attribute>> r = llzk::forceIntAttrTypes(templateParams, [&odsBuilder]() {
488 return InFlightDiagnosticWrapper::createSilent(odsBuilder.getContext());
489 });
490 ArrayRef<Attribute> converted = succeeded(r) ? r.value() : templateParams;
491 props.setTemplateParams(odsBuilder.getArrayAttr(converted));
492 }
493}
494
// Builds a CallOp without affine map instantiations. It records the result types and argument
// operands, initializes the properties with the operand segment sizes sized by the number of
// argument operands (checked-cast to int32_t), sets the callee, and attaches any template
// parameter instantiations.
// NOTE(review): `props` presumably refers to the properties returned by the initializing helper
// called with (odsBuilder, odsState, argOperands.size()) — confirm.
495void CallOp::build(
496 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
497 ValueRange argOperands, ArrayRef<Attribute> templateParams
498) {
499 odsState.addTypes(resultTypes);
500 odsState.addOperands(argOperands);
502 odsBuilder, odsState, llzk::checkedCast<int32_t>(argOperands.size())
503 );
504 props.setCallee(callee);
505 addTemplateParams(odsBuilder, props, templateParams);
506}
507
// Builds a CallOp that carries affine map instantiations. It records the result types and
// argument operands, initializes the properties from `mapOperands`, `numDimsPerMap`, and the
// argument operand count, sets the callee, and attaches any template parameter instantiations.
// NOTE(review): `props` presumably refers to the properties returned by the initializing helper
// that also adds the map operands — confirm.
508void CallOp::build(
509 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
510 ArrayRef<ValueRange> mapOperands, DenseI32ArrayAttr numDimsPerMap, ValueRange argOperands,
511 ArrayRef<Attribute> templateParams
512) {
513 odsState.addTypes(resultTypes);
514 odsState.addOperands(argOperands);
516 odsBuilder, odsState, mapOperands, numDimsPerMap,
517 llzk::checkedCast<int32_t>(argOperands.size())
518 );
519 props.setCallee(callee);
520 addTemplateParams(odsBuilder, props, templateParams);
521}
522
523LogicalResult
524CallOp::verifyTemplateParamCompatibility(Attribute paramFromCallOp, TemplateParamOp targetParam) {
525 if (std::optional<Type> declaredType = targetParam.getTypeOpt()) {
526 // Note: `declaredType` is restricted by `isValidConstReadType()`
527 bool compatible = false;
528 if (llvm::isa<TypeVarType>(*declaredType)) {
529 compatible = llvm::isa<TypeAttr>(paramFromCallOp);
530 } else if (llvm::isa<FeltType>(*declaredType)) {
531 compatible = llvm::isa<FeltConstAttr, IntegerAttr>(paramFromCallOp) &&
532 isValidConstReadType(llvm::cast<TypedAttr>(paramFromCallOp).getType());
533 } else if (llvm::isa<IndexType, IntegerType>(*declaredType)) {
534 // Note: Just like struct type instantiation, there is no restriction on passing a
535 // larger value to an `i1`. The flattening pass will treat 0 as false and any other
536 // value as true (but give a warning if it's not 1).
537 compatible = llvm::isa<IntegerAttr>(paramFromCallOp) &&
538 isValidConstReadType(llvm::cast<TypedAttr>(paramFromCallOp).getType());
539 } else {
540 llvm_unreachable("inconsistent with `isValidConstReadType()`");
541 }
542 if (!compatible) {
543 // Tested in call_with_template_params_fail.llzk
544 return this->emitOpError().append(
545 "instantiation value '", paramFromCallOp, "' is not compatible with parameter \"@",
546 targetParam.getName(), "\" type restriction ", *declaredType
547 );
548 }
549 }
550 return success();
551}
552
554 llvm::iterator_range<Region::op_iterator<TemplateParamOp>> targetParamDefs
555) {
// Checks each instantiation value against the template parameter at the same position and
// returns on the first incompatibility. The caller must ensure the instantiation list is
// non-empty and has one entry per parameter (asserted below).
556 ArrayAttr callParams = this->getTemplateParamsAttr();
557 assert(!isNullOrEmpty(callParams) && "pre-condition");
558 assert((callParams.size() == llvm::range_size(targetParamDefs)) && "pre-condition");
559
560 for (auto [paramOp, attr] : llvm::zip_equal(targetParamDefs, callParams.getValue())) {
561 if (failed(verifyTemplateParamCompatibility(attr, paramOp))) {
562 return failure();
563 }
564 }
565 return success();
566}
567
569 llvm::iterator_range<Region::op_iterator<TemplateParamOp>> targetParamDefs,
570 const UnificationMap &unifications
571) {
// Some template parameters have a value inferred by unifying the function type signature; these
// are looked up by flat symbol name on the RHS side of `unifications`. For each such parameter,
// the explicitly provided instantiation value must unify with the inferred value. Parameters
// with no inferred value are not checked. Returns on the first conflict.
572 ArrayAttr callParams = this->getTemplateParamsAttr();
573 assert(!isNullOrEmpty(callParams) && "pre-condition");
574 assert((callParams.size() == llvm::range_size(targetParamDefs)) && "pre-condition");
575
576 for (auto [paramOp, attr] : llvm::zip_equal(targetParamDefs, callParams.getValue())) {
577 auto it = unifications.find({FlatSymbolRefAttr::get(paramOp.getNameAttr()), Side::RHS});
578 if (it != unifications.end() && !typeParamsUnify({attr}, {it->second})) {
579 // Tested in call_with_template_params_fail.llzk
580 return this->emitOpError().append(
581 "template instantiation value '", attr, "' for parameter \"@", paramOp.getName(),
582 "\" conflicts with value '", it->second, "' inferred from function type signature"
583 );
584 }
585 }
586 return success();
587}
588
589namespace {
590
591struct CallOpVerifier {
592 CallOpVerifier(CallOp *c, FunctionKind tgtFuncKind) : callOp(c), tgtKind(tgtFuncKind) {}
593 CallOpVerifier(CallOp *c, StringRef tgtName) : CallOpVerifier(c, fnNameToKind(tgtName)) {}
594 virtual ~CallOpVerifier() = default;
595
596 LogicalResult verify() {
597 // Rather than immediately returning on failure, we check all verifier steps and aggregate to
598 // provide as many errors are possible in a single verifier run.
599 LogicalResult aggregateResult = success();
600 if (failed(verifyTargetAttributes())) {
601 aggregateResult = failure();
602 }
603 if (failed(verifyInputs())) {
604 aggregateResult = failure();
605 }
606 if (failed(verifyOutputs())) {
607 aggregateResult = failure();
608 }
609 if (failed(verifyTemplateParams())) {
610 aggregateResult = failure();
611 }
612 if (failed(verifyAffineMapParams())) {
613 aggregateResult = failure();
614 }
615 return aggregateResult;
616 }
617
618protected:
619 CallOp *callOp;
620 FunctionKind tgtKind;
621
622 virtual LogicalResult verifyTargetAttributes() = 0;
623 virtual LogicalResult verifyInputs() = 0;
624 virtual LogicalResult verifyOutputs() = 0;
625 virtual LogicalResult verifyTemplateParams() = 0;
626 virtual LogicalResult verifyAffineMapParams() = 0;
627
629 LogicalResult verifyTargetAttributesMatch(FuncDefOp target) {
630 LogicalResult aggregateRes = success();
631 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
632 auto emitAttrErr = [&](StringLiteral attrName) {
633 aggregateRes = callOp->emitOpError()
634 << "target '@" << target.getName() << "' has '" << attrName
635 << "' attribute, which is not specified by the caller '@" << caller.getName()
636 << '\'';
637 };
638
639 if (target.hasAllowConstraintAttr() && !caller.hasAllowConstraintAttr()) {
640 emitAttrErr(AllowConstraintAttr::name);
641 }
642 if (target.hasAllowWitnessAttr() && !caller.hasAllowWitnessAttr()) {
643 emitAttrErr(AllowWitnessAttr::name);
644 }
645 if (target.hasAllowNonNativeFieldOpsAttr() && !caller.hasAllowNonNativeFieldOpsAttr()) {
646 emitAttrErr(AllowNonNativeFieldOpsAttr::name);
647 }
648 }
649 return aggregateRes;
650 }
651
652 LogicalResult verifyNoTemplateInstantiations() {
653 if (!isNullOrEmpty(callOp->getTemplateParamsAttr())) {
654 // Tested in call_with_template_params_fail.llzk
655 return callOp->emitOpError().append(
656 "can only have template instantiations when targeting a templated free function"
657 );
658 }
659 return success();
660 }
661
662 LogicalResult verifyNoAffineMapInstantiations() {
663 if (!isNullOrEmpty(callOp->getMapOpGroupSizesAttr())) {
664 // Tested in call_with_affinemap_fail.llzk
665 return callOp->emitOpError().append(
666 "can only have affine map instantiations when targeting a \"@", FUNC_NAME_COMPUTE,
667 "\" function"
668 );
669 }
670 // ASSERT: the check above is sufficient due to VerifySizesForMultiAffineOps trait.
671 assert(isNullOrEmpty(callOp->getNumDimsPerMapAttr()));
672 assert(callOp->getMapOperands().empty());
673 return success();
674 }
675};
676
677struct KnownTargetVerifier : public CallOpVerifier {
678 KnownTargetVerifier(CallOp *c, SymbolLookupResult<FuncDefOp> &&tgtRes)
679 : CallOpVerifier(c, tgtRes.get().getSymName()), tgt(*tgtRes), tgtType(tgt.getFunctionType()),
680 includeSymNames(tgtRes.getNamespace()) {}
681
682 LogicalResult verifyTargetAttributes() override {
683 return CallOpVerifier::verifyTargetAttributesMatch(tgt);
684 }
685
686 LogicalResult verifyInputs() override {
687 return verifyTypesMatch(callOp->getArgOperands().getTypes(), tgtType.getInputs(), "operand");
688 }
689
690 LogicalResult verifyOutputs() override {
691 return verifyTypesMatch(callOp->getResultTypes(), tgtType.getResults(), "result");
692 }
693
694 LogicalResult verifyTemplateParams() override {
695 auto tgtOp = tgt.getOperation();
696 if (isInStruct(tgtOp)) {
697 // Struct function calls cannot contain template parameter instantiations.
698 return verifyNoTemplateInstantiations();
699 } else if (TemplateOp tgtOpParent = getParentOfType<TemplateOp>(tgtOp)) {
700 // When the target function is a free function within a TemplateOp, the CallOp may have
701 // template parameter instantiations that must be checked against the template parameters.
702 // - If the function type signature references all template parameters, then the parameter
703 // instantiation list on the CallOp is optional, otherwise it's required.
704 // - If present, the instantiation list must provide a value for every template parameter
705 // and the value must be type-compatible with the parameter's declared type (if any).
706 // - If present, the instantiation list must result in a function type signature that can
707 // be unified with the CallOp's operand and result types.
708 auto realParams = tgtOpParent.getConstOps<TemplateParamOp>();
709 ArrayAttr callParams = callOp->getTemplateParamsAttr();
710
711 // When there is no instantiation list, just ensure that it's not required.
712 if (isNullOrEmpty(callParams)) {
713 llvm::SmallDenseSet<SymbolRefAttr> referencedInSignature;
714 llzk::getSymbolsUsedIn(tgtType.getInputs(), referencedInSignature);
715 llzk::getSymbolsUsedIn(tgtType.getResults(), referencedInSignature);
716
717 bool allParamsReferenced = llvm::all_of(realParams, [&](TemplateParamOp p) {
718 return referencedInSignature.contains(FlatSymbolRefAttr::get(p.getNameAttr()));
719 });
720 if (allParamsReferenced) {
721 return success();
722 }
723 // Tested in call_with_template_params_fail.llzk
724 return callOp->emitOpError().append(
725 "must provide template instantiation parameters when calling \"@", tgt.getSymName(),
726 "\" because not all template parameters of \"@", tgtOpParent.getSymName(),
727 "\" appear in the function type signature"
728 );
729 }
730
731 // Ensure `forceIntAttrTypes()` was successful on the CallOp's template parameters.
732 if (failed(llzk::forceIntAttrTypes(callParams.getValue(), [this] {
733 return llzk::InFlightDiagnosticWrapper(this->callOp->emitOpError());
734 }))) {
735 return failure();
736 }
737
738 // The instantiation list is present. Check it has exactly one entry per template param.
739 size_t numTemplateParams = llvm::range_size(realParams);
740 if (callParams.size() != numTemplateParams) {
741 // Tested in call_with_template_params_fail.llzk
742 return callOp->emitOpError().append(
743 "template instantiation has ", callParams.size(), " parameter(s) but \"@",
744 tgtOpParent.getSymName(), "\" expects ", numTemplateParams, " template parameter(s)"
745 );
746 }
747
748 // Check type compatibility of each provided value with the declared parameter type (if any).
749 if (failed(callOp->verifyTemplateParamCompatibility(realParams))) {
750 return failure();
751 }
752
753 // Check that the provided instantiation values are consistent with what type unification
754 // of the target function types against the call's operand and result types would determine.
755 FailureOr<UnificationMap> unifyResult = callOp->unifyTypeSignature(tgtType);
756 assert(succeeded(unifyResult) && "already checked by `verifyInputs()` and `verifyOutputs()`");
757 return callOp->verifyTemplateParamsMatchInferred(realParams, unifyResult.value());
758 } else {
759 // Non-template functions cannot contain template parameter instantiations.
760 return verifyNoTemplateInstantiations();
761 }
762 }
763
// Affine map instantiations are only allowed on calls to a struct's @compute or @product function.
// There, the AffineMapAttr parameters of the returned struct type are collected and checked
// against the call's map operands and per-map dimension counts. Any other target must not carry
// affine map instantiations.
// NOTE(review): the detailed check is delegated to a helper (presumably the affine-map
// instantiation verifier) — confirm its exact semantics at its definition.
764 LogicalResult verifyAffineMapParams() override {
765 if ((FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) &&
766 isInStruct(tgt.getOperation())) {
767 // Return type should be a single StructType. If that is not the case here, just bail without
768 // producing an error. The combination of this KnownTargetVerifier resolving the callee to a
769 // specific FuncDefOp and verifyFuncTypeCompute() ensuring all FUNC_NAME_COMPUTE FuncOps have
770 // a single StructType return value will produce a more relevant error message in that case.
771 if (StructType retTy = callOp->getSingleResultTypeOfWitnessGen()) {
772 if (ArrayAttr params = retTy.getParams()) {
773 // Collect the struct parameters that are defined via AffineMapAttr
774 SmallVector<AffineMapAttr> mapAttrs;
775 for (Attribute a : params) {
776 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(a)) {
777 mapAttrs.push_back(m);
778 }
779 }
781 callOp->getMapOperands(), callOp->getNumDimsPerMap(), mapAttrs, *callOp
782 );
783 }
784 }
785 return success();
786 } else {
787 // Global functions and constrain functions cannot have affine map instantiations.
788 return verifyNoAffineMapInstantiations();
789 }
790 }
791
792private:
793 template <typename T>
794 LogicalResult
795 verifyTypesMatch(ValueTypeRange<T> callOpTypes, ArrayRef<Type> tgtTypes, const char *aspect) {
796 if (tgtTypes.size() != callOpTypes.size()) {
797 return callOp->emitOpError()
798 .append("incorrect number of ", aspect, "s for callee, expected ", tgtTypes.size())
799 .attachNote(tgt.getLoc())
800 .append("callee defined here");
801 }
802 for (unsigned i = 0, e = tgtTypes.size(); i != e; ++i) {
803 if (!typesUnify(callOpTypes[i], tgtTypes[i], includeSymNames)) {
804 return callOp->emitOpError().append(
805 aspect, " type mismatch: expected type ", tgtTypes[i], ", but found ", callOpTypes[i],
806 " for ", aspect, " number ", i
807 );
808 }
809 }
810 return success();
811 }
812
813 FuncDefOp tgt;
814 FunctionType tgtType;
815 std::vector<llvm::StringRef> includeSymNames;
816};
817
820LogicalResult checkSelfTypeUnknownTarget(
821 StringAttr expectedParamName, Type actualType, CallOp *origin, const char *aspect
822) {
823 if (!llvm::isa<TypeVarType>(actualType) ||
824 llvm::cast<TypeVarType>(actualType).getRefName() != expectedParamName) {
825 // Tested in function_restrictions_fail.llzk:
826 // Non-tvar for constrain input via "call_target_constrain_without_self_non_struct"
827 // Non-tvar for compute output via "call_target_compute_wrong_type_ret"
828 // Wrong tvar for constrain input via "call_target_constrain_without_self_wrong_tvar_param"
829 // Wrong tvar for compute output via "call_target_compute_wrong_tvar_param_ret"
830 return origin->emitOpError().append(
831 "target \"@", origin->getCallee().getLeafReference().getValue(), "\" expected ", aspect,
832 " type '!", TypeVarType::name, "<@", expectedParamName.getValue(), ">' but found ",
833 actualType
834 );
835 }
836 return success();
837}
838
848struct UnknownTargetVerifier : public CallOpVerifier {
849 UnknownTargetVerifier(CallOp *c, FunctionKind tgtFuncKind, SymbolRefAttr callee)
850 : CallOpVerifier(c, tgtFuncKind), calleeAttr(callee) {
851 assert(
852 tgtFuncKind == FunctionKind::StructCompute ||
853 tgtFuncKind == FunctionKind::StructConstrain || tgtFuncKind == FunctionKind::StructProduct
854 ); // pre-condition mentioned above
855 }
856
// Without a resolved target, the allow_* requirements are derived from the target kind alone.
// The caller must carry the attribute(s) implied by that kind of struct function, and one error
// is emitted per missing attribute.
// NOTE(review): the three case labels are presumably StructConstrain (allow-constraint),
// StructCompute (allow-witness), and StructProduct (both) — confirm.
857 LogicalResult verifyTargetAttributes() override {
858 // Based on the precondition of this verifier, the target must be either a
859 // struct compute, constrain, or product function.
860 LogicalResult aggregateRes = success();
861 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
862 auto emitAttrErr = [&](StringLiteral attrName) {
863 aggregateRes = callOp->emitOpError()
864 << "target '" << calleeAttr << "' has '" << attrName
865 << "' attribute, which is not specified by the caller '@" << caller.getName()
866 << '\'';
867 };
868
869 switch (tgtKind) {
871 if (!caller.hasAllowConstraintAttr()) {
872 emitAttrErr(AllowConstraintAttr::name);
873 }
874 break;
876 if (!caller.hasAllowWitnessAttr()) {
877 emitAttrErr(AllowWitnessAttr::name);
878 }
879 break;
881 if (!caller.hasAllowWitnessAttr()) {
882 emitAttrErr(AllowWitnessAttr::name);
883 }
884 if (!caller.hasAllowConstraintAttr()) {
885 emitAttrErr(AllowConstraintAttr::name);
886 }
887 break;
888 default:
889 break;
890 }
891 }
892 return aggregateRes;
893 }
894
895 LogicalResult verifyInputs() override {
896 if (FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) {
897 // Without known target, no additional checks can be done.
898 } else if (FunctionKind::StructConstrain == tgtKind) {
899 // Without known target, this can only check that the first input is VarType using the same
900 // struct parameter as the base of the callee (later replaced with the target struct's type).
901 Operation::operand_type_range inputTypes = callOp->getArgOperands().getTypes();
902 if (inputTypes.size() < 1) {
903 // Tested in function_restrictions_fail.llzk
904 return callOp->emitOpError()
905 << "target \"@" << FUNC_NAME_CONSTRAIN << "\" must have at least one input type";
906 }
907 return checkSelfTypeUnknownTarget(
908 calleeAttr.getRootReference(), inputTypes.front(), callOp, "first input"
909 );
910 }
911 return success();
912 }
913
914 LogicalResult verifyOutputs() override {
915 if (FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) {
916 // Without known target, this can only check that the function returns VarType using the same
917 // struct parameter as the base of the callee (later replaced with the target struct's type).
918 Operation::result_type_range resTypes = callOp->getResultTypes();
919 if (resTypes.size() != 1) {
920 // Tested in function_restrictions_fail.llzk
921 return callOp->emitOpError().append(
922 "target \"@", FUNC_NAME_COMPUTE, "\" must have exactly one return type"
923 );
924 }
925 return checkSelfTypeUnknownTarget(
926 calleeAttr.getRootReference(), resTypes.front(), callOp, "return"
927 );
928 } else if (FunctionKind::StructConstrain == tgtKind) {
929 // Without known target, this can only check that the function has no return
930 if (callOp->getNumResults() != 0) {
931 // Tested in function_restrictions_fail.llzk
932 return callOp->emitOpError()
933 << "target \"@" << FUNC_NAME_CONSTRAIN << "\" must have no return type";
934 }
935 }
936 return success();
937 }
938
939 LogicalResult verifyTemplateParams() override {
940 // Struct function calls cannot contain template parameter instantiations.
941 return verifyNoTemplateInstantiations();
942 }
943
944 LogicalResult verifyAffineMapParams() override {
945 if (FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) {
946 // Without known target, no additional checks can be done.
947 } else if (FunctionKind::StructConstrain == tgtKind) {
948 // Without known target, this can only check that there are no affine map instantiations.
949 return verifyNoAffineMapInstantiations();
950 }
951 return success();
952 }
953
954private:
955 SymbolRefAttr calleeAttr;
956};
957
958} // namespace
959
960LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &tables) {
961 // First, verify symbol resolution in all input and output types.
962 if (failed(verifyTypeResolution(tables, *this, getTypeSignature()))) {
963 return failure(); // verifyTypeResolution() already emits a sufficient error message
964 }
965
966 // Check that the callee attribute was specified.
967 SymbolRefAttr calleeAttr = getCalleeAttr();
968 if (!calleeAttr) {
969 return emitOpError("requires a 'callee' symbol reference attribute");
970 }
971
972 // If the callee references a parameter of the template where this call appears, perform
973 // the subset of checks that can be done even though the target is unknown.
974 if (calleeAttr.getNestedReferences().size() == 1) {
975 if (TemplateOp parent = getParentOfType<TemplateOp>(*this)) {
976 if (parent.hasConstNamed<TemplateParamOp>(calleeAttr.getRootReference())) {
977 FunctionKind tgtKind = fnNameToKind(calleeAttr.getLeafReference().getValue());
978 if (tgtKind != FunctionKind::Free) {
979 return UnknownTargetVerifier(this, tgtKind, calleeAttr).verify();
980 }
981 return this->emitError("expected parameterized callee to target a struct function")
982 .append(
983 " (i.e. \"@", FUNC_NAME_PRODUCT, "\", \"@", FUNC_NAME_COMPUTE, "\", or \"@",
985 );
986 }
987 }
988 }
989
990 // Otherwise, callee must be specified via full path from the root module. Perform the full set of
991 // checks against the known target function.
992 auto tgtOpt = lookupTopLevelSymbol<FuncDefOp>(tables, calleeAttr, *this);
993 if (failed(tgtOpt)) {
994 return this->emitError() << "expected '" << FuncDefOp::getOperationName() << "' named \""
995 << calleeAttr << '"';
996 }
997 return KnownTargetVerifier(this, std::move(*tgtOpt)).verify();
998}
999
1001 return FunctionType::get(getContext(), getArgOperands().getTypes(), getResultTypes());
1002}
1003
1004FailureOr<UnificationMap> CallOp::unifyTypeSignature(FunctionType other) {
1005 UnificationMap unifications;
1006 if (functionTypesUnify(getTypeSignature(), other, {}, &unifications)) {
1007 return unifications;
1008 } else {
1009 return failure();
1010 }
1011}
1012
1013namespace {
1014
1015bool calleeIsStructFunctionImpl(
1016 const char *funcName, SymbolRefAttr callee, llvm::function_ref<StructType()> getType
1017) {
1018 if (callee.getLeafReference() == funcName) {
1019 if (StructType t = getType()) {
1020 // If the name ref within the StructType matches the `callee` prefix (i.e., sans the function
1021 // name itself), then the `callee` target must be within a StructDefOp because validation
1022 // checks elsewhere ensure that every StructType references a StructDefOp (i.e., the `callee`
1023 // function is not simply a free function nested within a ModuleOp)
1024 return t.getNameRef() == getPrefixAsSymbolRefAttr(callee);
1025 }
1026 }
1027 return false;
1028}
1029
1030} // namespace
1031
1033 return calleeIsStructFunctionImpl(FUNC_NAME_COMPUTE, getCallee(), [this]() {
1034 return this->getSingleResultTypeOfCompute();
1035 });
1036}
1037
1039 return calleeIsStructFunctionImpl(FUNC_NAME_CONSTRAIN, getCallee(), [this]() {
1040 return getAtIndex<StructType>(this->getArgOperands().getTypes(), 0);
1041 });
1042}
1043
1045 assert(calleeIsStructCompute());
1046 return getResults().front();
1047}
1048
1050 assert(calleeIsStructConstrain());
1051 return getArgOperands().front();
1052}
1053
1054FailureOr<SymbolLookupResult<FuncDefOp>> CallOp::getCalleeTarget(SymbolTableCollection &tables) {
1055 Operation *thisOp = this->getOperation();
1056 auto root = getRootModule(thisOp);
1057 assert(succeeded(root));
1058 return llzk::lookupSymbolIn<FuncDefOp>(tables, getCallee(), root->getOperation(), thisOp);
1059}
1060
1062 assert(calleeIsCompute() && "violated implementation pre-condition");
1063 return getIfSingleton<StructType>(getResultTypes());
1064}
1065
1067 assert(calleeContainsWitnessGen() && "violated implementation pre-condition");
1068 return getIfSingleton<StructType>(getResultTypes());
1069}
1070
1072CallInterfaceCallable CallOp::getCallableForCallee() { return getCalleeAttr(); }
1073
1075void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1076 setCalleeAttr(llvm::cast<SymbolRefAttr>(callee));
1077}
1078
1079SmallVector<ValueRange> CallOp::toVectorOfValueRange(OperandRangeRange input) {
1080 llvm::SmallVector<ValueRange, 4> output;
1081 output.reserve(input.size());
1082 for (OperandRange r : input) {
1083 output.push_back(r);
1084 }
1085 return output;
1086}
1087
1088Operation *CallOp::resolveCallableInTable(SymbolTableCollection *symbolTable) {
1089 FailureOr<SymbolLookupResult<FuncDefOp>> res =
1090 llzk::resolveCallable<FuncDefOp>(*symbolTable, *this);
1091 if (failed(res) || res->isManaged()) {
1092 // Cannot return pointer to a managed Operation since it would cause memory errors.
1093 return nullptr;
1094 }
1095 return res->get();
1096}
1097
1099 SymbolTableCollection tables;
1100 return resolveCallableInTable(&tables);
1101}
1102
1103} // namespace llzk::function
This file defines methods for symbol lookup across LLZK operations and included files.
static InFlightDiagnosticWrapper createSilent(mlir::MLIRContext *ctx)
Construct a silent diagnostic that does nothing when appended to or reported.
Definition ErrorHelper.h:74
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::SymbolRefAttr callee, ::mlir::ValueRange argOperands={}, ::llvm::ArrayRef<::mlir::Attribute > templateParams={})
bool calleeContainsWitnessGen()
Return true iff the callee function can contain witness generation code (this does not check if the c...
Definition Ops.h.inc:389
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:1038
::mlir::CallInterfaceCallable getCallableForCallee()
Return the callee of this operation.
Definition Ops.cpp:1072
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:1061
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.h.inc:292
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:960
::mlir::Operation * resolveCallableInTable(::mlir::SymbolTableCollection *symbolTable)
Required by CallOpInterface.
Definition Ops.cpp:1088
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:1032
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:470
::mlir::Operation * resolveCallable()
Required by CallOpInterface.
Definition Ops.cpp:1098
void writeProperties(::mlir::DialectBytecodeWriter &writer)
Definition Ops.cpp:462
::mlir::FunctionType getTypeSignature()
Return the FunctionType inferred from the arg operands and result types of this CallOp.
Definition Ops.cpp:1000
bool calleeIsCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE (this does not check if the callee func...
Definition Ops.h.inc:383
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:480
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:266
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:1049
::mlir::ArrayAttr getTemplateParamsAttr()
Definition Ops.h.inc:297
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:270
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
Definition Ops.cpp:1079
::llzk::component::StructType getSingleResultTypeOfWitnessGen()
Assuming the callee contains witness generation code, return the single StructType result.
Definition Ops.cpp:1066
FoldAdaptor::Properties Properties
Definition Ops.h.inc:209
void setCalleeAttr(::mlir::SymbolRefAttr attr)
Definition Ops.h.inc:312
::llvm::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state)
Definition Ops.cpp:423
::mlir::FailureOr< UnificationMap > unifyTypeSignature(::mlir::FunctionType other)
Attempt type unification between the inferred FunctionType from this CallOp (as LHS) and the given Fun...
Definition Ops.cpp:1004
::mlir::LogicalResult verifyTemplateParamsMatchInferred(::llvm::iterator_range<::mlir::Region::op_iterator<::llzk::polymorphic::TemplateParamOp > > targetParamDefs, const UnificationMap &unifications)
Verify that each template parameter value provided in this CallOp is consistent with the value inferr...
Definition Ops.cpp:568
::mlir::LogicalResult verifyTemplateParamCompatibility(::mlir::Attribute paramFromCallOp, ::llzk::polymorphic::TemplateParamOp targetParam)
Check type compatibility of the given template parameter value from this CallOp against the declared ...
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:1044
void setCalleeFromCallable(::mlir::CallInterfaceCallable callee)
Set the callee for this operation.
Definition Ops.cpp:1075
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
Definition Ops.cpp:1054
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
Definition Ops.cpp:218
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:337
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:363
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:984
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={}, ::llvm::ArrayRef<::mlir::DictionaryAttr > argAttrs={})
void setAllowNonNativeFieldOpsAttr(bool newValue=true)
Add (resp. remove) the allow_non_native_field_ops attribute to (resp. from) the function def.
Definition Ops.cpp:226
::mlir::StringAttr getFunctionTypeAttrName()
Definition Ops.h.inc:642
bool hasAllowNonNativeFieldOpsAttr()
Return true iff the function def has the allow_non_native_field_ops attribute.
Definition Ops.h.inc:820
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:382
void print(::mlir::OpAsmPrinter &p)
Definition Ops.cpp:129
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:853
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
Definition Ops.h.inc:812
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:387
::mlir::StringAttr getResAttrsAttrName()
Definition Ops.h.inc:650
::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result)
Definition Ops.cpp:118
void cloneInto(FuncDefOp dest, ::mlir::IRMapping &mapper)
Clone the internal blocks and attributes from this function into dest.
Definition Ops.cpp:138
bool nameIsProduct()
Return true iff the function name is FUNC_NAME_PRODUCT (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:861
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
Definition Ops.h.inc:857
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:867
bool isInStruct()
Return true iff the function is within a StructDefOp.
Definition Ops.h.inc:864
::llvm::ArrayRef<::mlir::Type > getResultTypes()
Required by FunctionOpInterface.
Definition Ops.h.inc:842
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:666
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
Definition Ops.cpp:210
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
Definition Ops.cpp:353
bool hasArgPublicAttr(unsigned index)
Return true iff the argument at the given index has pub attribute.
Definition Ops.cpp:234
::llvm::LogicalResult verify()
Definition Ops.cpp:244
::mlir::Region & getBody()
Definition Ops.h.inc:690
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
Definition Ops.h.inc:804
::mlir::StringAttr getArgAttrsAttrName()
Definition Ops.h.inc:634
::llvm::LogicalResult verify()
Definition Ops.cpp:396
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:968
::std::optional<::mlir::Type > getTypeOpt()
Definition Ops.cpp.inc:1337
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:27
OpClass::Properties & buildInstantiationAttrs(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
OpClass::Properties & buildInstantiationAttrsEmpty(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
bool isInStruct(Operation *op)
Definition Ops.cpp:55
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
Definition Ops.cpp:119
FunctionKind fnNameToKind(mlir::StringRef name)
Given a function name, return the corresponding FunctionKind.
Definition Ops.cpp:51
FunctionKind
Kinds of functions in LLZK.
Definition Ops.h:33
@ StructConstrain
Function within a struct named FUNC_NAME_CONSTRAIN.
Definition Ops.h:37
@ StructProduct
Function within a struct named FUNC_NAME_PRODUCT.
Definition Ops.h:39
@ StructCompute
Function within a struct named FUNC_NAME_COMPUTE.
Definition Ops.h:35
@ Free
Function that is not within a struct.
Definition Ops.h:41
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:16
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:270
FailureOr< ModuleOp > getRootModule(Operation *from)
void getSymbolsUsedIn(mlir::Type t, llvm::SmallDenseSet< mlir::SymbolRefAttr > &symbolsUsed)
Add all symbols used within the given Type to the provided set.
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:186
FailureOr< SmallVector< Attribute > > forceIntAttrTypes(ArrayRef< Attribute > attrList, EmitErrorFn emitError)
bool isNullOrEmpty(mlir::ArrayAttr a)
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:69
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:18
constexpr T checkedCast(U u) noexcept
Definition Compare.h:81
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:274
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool typeParamsUnify(const ArrayRef< Attribute > &lhsParams, const ArrayRef< Attribute > &rhsParams, UnificationMap *unifications)
bool functionTypesUnify(FunctionType lhs, FunctionType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
OpClass delegate_to_build(mlir::Location location, Args &&...args)
bool hasAffineMapAttr(Type type)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
Definition TypeHelper.h:114
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
bool isValidConstReadType(Type type)
void setTemplateParams(const ::mlir::ArrayAttr &propValue)
Definition Ops.h.inc:76