LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Func and call op implementations --------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// Copyright 2026 Project LLZK
7// SPDX-License-Identifier: Apache-2.0
8//
9// Adapted from the LLVM Project's lib/Dialect/Func/IR/FuncOps.cpp
10// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
11// See https://llvm.org/LICENSE.txt for license information.
12// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
13//
14//===----------------------------------------------------------------------===//
15
17
28#include "llzk/Util/Compare.h"
33
34#include <mlir/IR/IRMapping.h>
35#include <mlir/IR/OpImplementation.h>
36#include <mlir/Interfaces/FunctionImplementation.h>
37
38#include <llvm/ADT/DenseSet.h>
39#include <llvm/ADT/MapVector.h>
40
41// TableGen'd implementation files
42#define GET_OP_CLASSES
44
45using namespace mlir;
46using namespace llzk::felt;
47using namespace llzk::component;
48using namespace llzk::polymorphic;
49
50namespace llzk::function {
51
52FunctionKind fnNameToKind(mlir::StringRef name) {
53 if (FUNC_NAME_COMPUTE == name) {
55 } else if (FUNC_NAME_CONSTRAIN == name) {
57 } else if (FUNC_NAME_PRODUCT == name) {
59 } else {
60 return FunctionKind::Free;
61 }
62}
63
64namespace {
66inline LogicalResult
67verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, FunctionType funcType) {
69 tables, origin, ArrayRef<ArrayRef<Type>> {funcType.getInputs(), funcType.getResults()}
70 );
71}
72
77static LogicalResult verifyArgOrResNameAttrs(
78 ArrayAttr attrs, StringRef ownAttrName, StringRef crossAttrName, StringRef ownLabel,
79 StringRef crossLabel, EmitErrorFn emitFn
80) {
81 if (!attrs) {
82 return success();
83 }
84 llvm::DenseSet<StringAttr> seenNames;
85 for (auto [i, attr] : llvm::enumerate(attrs)) {
86 auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr);
87 if (!dictAttr) {
88 continue;
89 }
90 if (dictAttr.contains(crossAttrName)) {
91 return emitFn().append(
92 '\'', crossAttrName, "' is only valid on function ", crossLabel, "s but found on ",
93 ownLabel, ' ', i
94 );
95 }
96 Attribute nameAttr = dictAttr.get(ownAttrName);
97 if (!nameAttr) {
98 continue;
99 }
100 auto name = llvm::dyn_cast<StringAttr>(nameAttr);
101 if (!name) {
102 return emitFn().append(
103 '\'', ownAttrName, "' on ", ownLabel, ' ', i, " must be a string attribute"
104 );
105 }
106 if (!llvm::isa<NoneType>(name.getType())) {
107 return emitFn().append(
108 '\'', ownAttrName, "' on ", ownLabel, ' ', i, " must not have an explicit type"
109 );
110 }
111 if (name.getValue().empty()) {
112 return emitFn().append('\'', ownAttrName, "' on ", ownLabel, ' ', i, " must not be empty");
113 }
114 if (!seenNames.insert(name).second) {
115 return emitFn().append(
116 "duplicate '", ownAttrName, "' value \"", name.getValue(), "\" on ", ownLabel, ' ', i
117 );
118 }
119 }
120 return success();
121}
122} // namespace
123
124//===----------------------------------------------------------------------===//
125// FuncDefOp
126//===----------------------------------------------------------------------===//
127
129 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs
130) {
131 return delegate_to_build<FuncDefOp>(location, name, type, attrs);
132}
133
135 Location location, StringRef name, FunctionType type, Operation::dialect_attr_range attrs
136) {
137 SmallVector<NamedAttribute, 8> attrRef(attrs);
138 return create(location, name, type, llvm::ArrayRef(attrRef));
139}
140
142 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs,
143 ArrayRef<DictionaryAttr> argAttrs
144) {
145 FuncDefOp func = create(location, name, type, attrs);
146 func.setAllArgAttrs(argAttrs);
147 return func;
148}
149
151 OpBuilder &builder, OperationState &state, StringRef name, FunctionType type,
152 ArrayRef<NamedAttribute> attrs, ArrayRef<DictionaryAttr> argAttrs
153) {
154 state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name));
155 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
156 state.attributes.append(attrs.begin(), attrs.end());
157 state.addRegion();
158
159 if (argAttrs.empty()) {
160 return;
161 }
162 assert(type.getNumInputs() == argAttrs.size());
163 function_interface_impl::addArgAndResultAttrs(
164 builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name),
165 getResAttrsAttrName(state.name)
166 );
167}
168
169ParseResult FuncDefOp::parse(OpAsmParser &parser, OperationState &result) {
170 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
171 function_interface_impl::VariadicFlag,
172 std::string &) { return builder.getFunctionType(argTypes, results); };
173
174 return function_interface_impl::parseFunctionOp(
175 parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType,
176 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)
177 );
178}
179
180void FuncDefOp::print(OpAsmPrinter &p) {
181 function_interface_impl::printFunctionOp(
182 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(),
184 );
185}
186
189void FuncDefOp::cloneInto(FuncDefOp dest, IRMapping &mapper) {
190 // Add the attributes of this function to dest.
191 llvm::MapVector<StringAttr, Attribute> newAttrMap;
192 for (const auto &attr : dest->getAttrs()) {
193 newAttrMap.insert({attr.getName(), attr.getValue()});
194 }
195 for (const auto &attr : (*this)->getAttrs()) {
196 newAttrMap.insert({attr.getName(), attr.getValue()});
197 }
198
199 auto newAttrs =
200 llvm::to_vector(llvm::map_range(newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
201 return NamedAttribute(attrPair.first, attrPair.second);
202 }));
203 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
204
205 // Clone the body.
206 getBody().cloneInto(&dest.getBody(), mapper);
207}
208
214FuncDefOp FuncDefOp::clone(IRMapping &mapper) {
215 // Create the new function.
216 FuncDefOp newFunc = llvm::cast<FuncDefOp>(getOperation()->cloneWithoutRegions());
217
218 // If the function has a body, then the user might be deleting arguments to
219 // the function by specifying them in the mapper. If so, we don't add the
220 // argument to the input type vector.
221 if (!isExternal()) {
222 FunctionType oldType = getFunctionType();
223
224 unsigned oldNumArgs = oldType.getNumInputs();
225 SmallVector<Type, 4> newInputs;
226 newInputs.reserve(oldNumArgs);
227 for (unsigned i = 0; i != oldNumArgs; ++i) {
228 if (!mapper.contains(getArgument(i))) {
229 newInputs.push_back(oldType.getInput(i));
230 }
231 }
232
235 if (newInputs.size() != oldNumArgs) {
236 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, oldType.getResults()));
237
238 if (ArrayAttr argAttrs = getAllArgAttrs()) {
239 SmallVector<Attribute> newArgAttrs;
240 newArgAttrs.reserve(newInputs.size());
241 for (unsigned i = 0; i != oldNumArgs; ++i) {
242 if (!mapper.contains(getArgument(i))) {
243 newArgAttrs.push_back(argAttrs[i]);
244 }
245 }
246 newFunc.setAllArgAttrs(newArgAttrs);
247 }
248 }
249 }
250
252 cloneInto(newFunc, mapper);
253 return newFunc;
254}
255
257 IRMapping mapper;
258 return clone(mapper);
259}
260
262 if (newValue) {
263 getOperation()->setAttr(AllowConstraintAttr::name, UnitAttr::get(getContext()));
264 } else {
265 getOperation()->removeAttr(AllowConstraintAttr::name);
266 }
267}
268
270 if (newValue) {
271 getOperation()->setAttr(AllowWitnessAttr::name, UnitAttr::get(getContext()));
272 } else {
273 getOperation()->removeAttr(AllowWitnessAttr::name);
274 }
275}
276
278 if (newValue) {
279 getOperation()->setAttr(AllowNonNativeFieldOpsAttr::name, UnitAttr::get(getContext()));
280 } else {
281 getOperation()->removeAttr(AllowNonNativeFieldOpsAttr::name);
282 }
283}
284
285bool FuncDefOp::hasArgPublicAttr(unsigned index) {
286 if (index < this->getNumArguments()) {
287 DictionaryAttr res = function_interface_impl::getArgAttrDict(*this, index);
288 return res ? res.contains(PublicAttr::name) : false;
289 } else {
290 // TODO: print error? requested attribute for non-existant argument index
291 return false;
292 }
293}
294
295bool FuncDefOp::hasArgName(unsigned index) { return static_cast<bool>(getArgNameAttr(index)); }
296
297std::optional<StringAttr> FuncDefOp::getArgNameAttr(unsigned index) {
298 if (index >= getNumArguments()) {
299 return std::nullopt;
300 }
301 if (StringAttr attr = getArgAttrOfType<StringAttr>(index, ARG_NAME_ATTR_NAME)) {
302 return attr;
303 }
304 return std::nullopt;
305}
306
307void FuncDefOp::setArgNameAttr(unsigned index, const StringAttr &attr) {
308 assert(index < getNumArguments() && "argument index out of range");
309 setArgAttr(index, ARG_NAME_ATTR_NAME, attr);
310}
311
312void FuncDefOp::setArgName(unsigned index, StringRef name) {
313 setArgNameAttr(index, StringAttr::get(getContext(), name));
314}
315
316LogicalResult FuncDefOp::verify() {
317 OwningEmitErrorFn emitErrorFunc = getEmitOpErrFn(this);
318
319 if ((*this)->hasAttr(ARG_NAME_ATTR_NAME)) {
320 return emitErrorFunc() << '\'' << ARG_NAME_ATTR_NAME << "' is only valid on function arguments";
321 }
322 if ((*this)->hasAttr(RES_NAME_ATTR_NAME)) {
323 return emitErrorFunc() << '\'' << RES_NAME_ATTR_NAME << "' is only valid on function results";
324 }
325
326 if (failed(verifyArgOrResNameAttrs(
327 getAllResultAttrs(), RES_NAME_ATTR_NAME, ARG_NAME_ATTR_NAME, "result", "argument",
328 emitErrorFunc
329 ))) {
330 return failure();
331 }
332
333 if (failed(verifyArgOrResNameAttrs(
334 getAllArgAttrs(), ARG_NAME_ATTR_NAME, RES_NAME_ATTR_NAME, "argument", "result",
335 emitErrorFunc
336 ))) {
337 return failure();
338 }
339
340 // Ensure that only valid LLZK types are used for arguments and return. Additionally, the struct
341 // functions may not use AffineMapAttrs in their parameter types. If such a scenario seems to make
342 // sense when generating LLZK IR, it's likely better to introduce a struct parameter to use
343 // instead and instantiate the struct with that AffineMapAttr.
344 FunctionType type = getFunctionType();
345 for (Type t : type.getInputs()) {
346 if (llzk::checkValidType(emitErrorFunc, t).failed()) {
347 return failure();
348 }
349 if (isInStruct() && hasAffineMapAttr(t)) {
350 return emitErrorFunc().append(
351 "\"@", getName(), "\" parameters cannot contain affine map attributes but found ", t
352 );
353 }
354 }
355 for (Type t : type.getResults()) {
356 if (llzk::checkValidType(emitErrorFunc, t).failed()) {
357 return failure();
358 }
359 }
360 // Ensure that the function does not contain nested modules.
361 // Functions also cannot contain nested structs, but this check is handled
362 // via struct.def's requirement of having module as a parent.
363 WalkResult res = this->walk<WalkOrder::PreOrder>([this](ModuleOp nestedMod) {
364 getEmitOpErrFn(nestedMod)().append(
365 "cannot be nested within '", getOperation()->getName(), "' operations"
366 );
367 return WalkResult::interrupt();
368 });
369 if (res.wasInterrupted()) {
370 return failure();
371 }
372
373 return success();
374}
375
376namespace {
377
378LogicalResult
379verifyFuncTypeCompute(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
380 FunctionType funcType = origin.getFunctionType();
381 llvm::ArrayRef<Type> resTypes = funcType.getResults();
382 // Must return type of parent struct
383 if (resTypes.size() != 1) {
384 return origin.emitOpError().append(
385 "\"@", FUNC_NAME_COMPUTE, "\" must have exactly one return type"
386 );
387 }
388 if (failed(checkSelfType(tables, parent, resTypes.front(), origin, "return"))) {
389 return failure();
390 }
391
392 // After the more specific checks (to ensure more specific error messages would be produced if
393 // necessary), do the general check that all symbol references in the types are valid. The return
394 // types were already checked so just check the input types.
395 return llzk::verifyTypeResolution(tables, origin, funcType.getInputs());
396}
397
398LogicalResult
399verifyFuncTypeProduct(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
400 // The signature for @product is the same as the signature for @compute
401 return verifyFuncTypeCompute(origin, tables, parent);
402}
403
404LogicalResult
405verifyFuncTypeConstrain(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
406 FunctionType funcType = origin.getFunctionType();
407 // Must return '()' type, i.e., have no return types
408 if (funcType.getResults().size() != 0) {
409 return origin.emitOpError() << "\"@" << FUNC_NAME_CONSTRAIN << "\" must have no return type";
410 }
411
412 // Type of the first parameter must match the parent StructDefOp of the current operation.
413 llvm::ArrayRef<Type> inputTypes = funcType.getInputs();
414 if (inputTypes.size() < 1) {
415 return origin.emitOpError() << "\"@" << FUNC_NAME_CONSTRAIN
416 << "\" must have at least one input type";
417 }
418 if (failed(checkSelfType(tables, parent, inputTypes.front(), origin, "first input"))) {
419 return failure();
420 }
421
422 // After the more specific checks (to ensure more specific error messages would be produced if
423 // necessary), do the general check that all symbol references in the types are valid. There are
424 // no return types, just check the remaining input types (the first was already checked via
425 // the checkSelfType() call above).
426 return llzk::verifyTypeResolution(tables, origin, inputTypes.drop_front());
427}
428
429} // namespace
430
431LogicalResult FuncDefOp::verifySymbolUses(SymbolTableCollection &tables) {
432 // Additional checks for the compute/constrain/product functions within a struct
433 if (StructDefOp parentStructOpt = getParentOfType<StructDefOp>(*this)) {
434 // Verify return type restrictions for functions within a StructDefOp
435 if (nameIsCompute()) {
436 return verifyFuncTypeCompute(*this, tables, parentStructOpt);
437 } else if (nameIsConstrain()) {
438 return verifyFuncTypeConstrain(*this, tables, parentStructOpt);
439 } else if (nameIsProduct()) {
440 return verifyFuncTypeProduct(*this, tables, parentStructOpt);
441 }
442 }
443 // In the general case, verify symbol resolution in all input and output types.
444 return verifyTypeResolution(tables, *this, getFunctionType());
445}
446
447SymbolRefAttr FuncDefOp::getFullyQualifiedName(bool requireParent) {
448 // If the parent is not present and not required, just return the symbol name
449 if (!requireParent && getOperation()->getParentOp() == nullptr) {
450 return SymbolRefAttr::get(getOperation());
451 }
452 auto res = getPathFromRoot(*this);
453 assert(succeeded(res));
454 return res.value();
455}
456
458 assert(nameIsCompute()); // skip inStruct check to allow dangling functions
459 // Get the single block of the function body
460 Region &body = getBody();
461 assert(!body.empty() && "compute() function body is empty");
462 Block &block = body.back();
463
464 // The terminator should be the return op
465 Operation *terminator = block.getTerminator();
466 assert(terminator && "compute() function has no terminator");
467 auto retOp = llvm::dyn_cast<ReturnOp>(terminator);
468 if (!retOp) {
469 llvm::errs() << "Expected '" << ReturnOp::getOperationName() << "' but found '"
470 << terminator->getName() << "'\n";
471 llvm_unreachable("compute() function must end with ReturnOp");
472 }
473 return retOp.getOperands().front();
474}
475
477 assert(nameIsConstrain()); // skip inStruct check to allow dangling functions
478 return getArguments().front();
479}
480
482 assert(isStructCompute() && "violated implementation pre-condition");
484}
485
486//===----------------------------------------------------------------------===//
487// ReturnOp
488//===----------------------------------------------------------------------===//
489
490LogicalResult ReturnOp::verify() {
491 auto function = getParentOp<FuncDefOp>(); // parent is FuncDefOp per ODS
492
493 // The operand number and types must match the function signature.
494 const auto results = function.getFunctionType().getResults();
495 if (getNumOperands() != results.size()) {
496 return emitOpError("has ") << getNumOperands() << " operands, but enclosing function (@"
497 << function.getName() << ") returns " << results.size();
498 }
499
500 for (unsigned i = 0, e = results.size(); i != e; ++i) {
501 if (!typesUnify(getOperand(i).getType(), results[i])) {
502 return emitError() << "type of return operand " << i << " (" << getOperand(i).getType()
503 << ") doesn't match function result type (" << results[i] << ")"
504 << " in function @" << function.getName();
505 }
506 }
507
508 return success();
509}
510
511//===----------------------------------------------------------------------===//
512// CallOp
513//===----------------------------------------------------------------------===//
514
515// Custom implementation to deserialize bytecode produced prior to version 2 which added optional
516// `OptionalAttr<ArrayAttr>:$templateParams`.
517LogicalResult CallOp::readProperties(DialectBytecodeReader &reader, OperationState &state) {
518 auto &prop = state.getOrAddProperties<Properties>();
519 if (failed(reader.readAttribute(prop.callee)) ||
520 failed(reader.readAttribute(prop.mapOpGroupSizes)) ||
521 failed(reader.readOptionalAttribute(prop.numDimsPerMap))) {
522 return failure();
523 }
524
525 if (reader.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
526 auto &propStorage = prop.operandSegmentSizes;
527 DenseI32ArrayAttr attr;
528 if (failed(reader.readAttribute(attr))) {
529 return failure();
530 }
531 if (attr.size() > static_cast<int64_t>(sizeof(propStorage) / sizeof(int32_t))) {
532 reader.emitError("size mismatch for operand/result_segment_size");
533 return failure();
534 }
535 llvm::copy(ArrayRef<int32_t>(attr), propStorage.begin());
536 }
537
538 // The `templateParams` is only available in version 2 or later.
539 auto versionOpt = reader.getDialectVersion<FunctionDialect>();
540 if (succeeded(versionOpt)) {
541 const auto &ver = static_cast<const LLZKDialectVersion &>(**versionOpt);
542 if (ver.majorVersion >= 2) {
543 if (failed(reader.readOptionalAttribute(prop.templateParams))) {
544 return failure();
545 }
546 }
547 }
548
549 if (reader.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6) {
550 return reader.readSparseArray(MutableArrayRef(prop.operandSegmentSizes));
551 };
552 return success();
553}
554
555// Same as tablegen would generate to serialize version 2 IR.
556void CallOp::writeProperties(DialectBytecodeWriter &writer) {
557 auto &prop = getProperties();
558 writer.writeAttribute(prop.callee);
559 writer.writeAttribute(prop.mapOpGroupSizes);
560 writer.writeOptionalAttribute(prop.numDimsPerMap);
561
562 if (writer.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
563 auto &propStorage = prop.operandSegmentSizes;
564 writer.writeAttribute(DenseI32ArrayAttr::get(this->getContext(), propStorage));
565 }
566
567 writer.writeOptionalAttribute(prop.templateParams);
568
569 auto &propStorage = prop.operandSegmentSizes;
570 if (writer.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6) {
571 writer.writeSparseArray(ArrayRef(propStorage));
572 }
573}
574
575void CallOp::build(
576 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
577 ValueRange argOperands, ArrayRef<Attribute> templateParams
578) {
579 odsState.addTypes(resultTypes);
580 odsState.addOperands(argOperands);
582 odsBuilder, odsState, llzk::checkedCast<int32_t>(argOperands.size())
583 );
584 props.setCallee(callee);
585 addTemplateParams<CallOp>(odsBuilder, props, templateParams);
586}
587
588void CallOp::build(
589 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
590 ArrayRef<ValueRange> mapOperands, DenseI32ArrayAttr numDimsPerMap, ValueRange argOperands,
591 ArrayRef<Attribute> templateParams
592) {
593 odsState.addTypes(resultTypes);
594 odsState.addOperands(argOperands);
596 odsBuilder, odsState, mapOperands, numDimsPerMap,
597 llzk::checkedCast<int32_t>(argOperands.size())
598 );
599 props.setCallee(callee);
600 addTemplateParams<CallOp>(odsBuilder, props, templateParams);
601}
602
603LogicalResult
604CallOp::verifyTemplateParamCompatibility(Attribute paramFromCallOp, TemplateParamOp targetParam) {
605 // A wildcard `?` (represented as kDynamic) defers inference to a later pass.
606 // It is only valid for parameters with a `!poly.tvar` type restriction.
607 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(paramFromCallOp)) {
608 if (isDynamic(intAttr)) {
609 std::optional<Type> declaredType = targetParam.getTypeOpt();
610 if (!declaredType || !llvm::isa<TypeVarType>(*declaredType)) {
611 auto diag = this->emitOpError().append(
612 "wildcard `?` can only be used for template parameters with `!poly.tvar` "
613 "type restriction, but parameter \"@",
614 targetParam.getName(), "\" has "
615 );
616 if (declaredType) {
617 diag.append("type restriction ", *declaredType);
618 } else {
619 diag.append("no type restriction");
620 }
621 return diag;
622 }
623 return success();
624 }
625 }
626 if (std::optional<Type> declaredType = targetParam.getTypeOpt()) {
627 // Note: `declaredType` is restricted by `isValidConstReadType()`
628 bool compatible = false;
629 if (llvm::isa<TypeVarType>(*declaredType)) {
630 compatible = llvm::isa<TypeAttr>(paramFromCallOp);
631 } else if (llvm::isa<FeltType>(*declaredType)) {
632 compatible = llvm::isa<FeltConstAttr, IntegerAttr>(paramFromCallOp) &&
633 isValidConstReadType(llvm::cast<TypedAttr>(paramFromCallOp).getType());
634 } else if (llvm::isa<IndexType, IntegerType>(*declaredType)) {
635 // Note: Just like struct type instantiation, there is no restriction on passing a
636 // larger value to an `i1`. The flattening pass will treat 0 as false and any other
637 // value as true (but give a warning if it's not 1).
638 compatible = llvm::isa<IntegerAttr>(paramFromCallOp) &&
639 isValidConstReadType(llvm::cast<TypedAttr>(paramFromCallOp).getType());
640 } else {
641 llvm_unreachable("inconsistent with `isValidConstReadType()`");
642 }
643 if (!compatible) {
644 // Tested in call_with_template_params_fail.llzk
645 return this->emitOpError().append(
646 "instantiation value '", paramFromCallOp, "' is not compatible with parameter \"@",
647 targetParam.getName(), "\" type restriction ", *declaredType
648 );
649 }
650 }
651 return success();
652}
653
655 llvm::iterator_range<Region::op_iterator<TemplateParamOp>> targetParamDefs
656) {
657 ArrayAttr callParams = this->getTemplateParamsAttr();
658 assert(!isNullOrEmpty(callParams) && "pre-condition");
659 assert((callParams.size() == llvm::range_size(targetParamDefs)) && "pre-condition");
660
661 for (auto [paramOp, attr] : llvm::zip_equal(targetParamDefs, callParams.getValue())) {
662 if (failed(verifyTemplateParamCompatibility(attr, paramOp))) {
663 return failure();
664 }
665 }
666 return success();
667}
668
670 llvm::iterator_range<Region::op_iterator<TemplateParamOp>> targetParamDefs,
671 const UnificationMap &unifications
672) {
673 ArrayAttr callParams = this->getTemplateParamsAttr();
674 assert(!isNullOrEmpty(callParams) && "pre-condition");
675 assert((callParams.size() == llvm::range_size(targetParamDefs)) && "pre-condition");
676
677 for (auto [paramOp, attr] : llvm::zip_equal(targetParamDefs, callParams.getValue())) {
678 // Skip wildcards (`?` / kDynamic) - their value will be resolved by a later inference pass.
679 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
680 if (isDynamic(intAttr)) {
681 continue;
682 }
683 }
684 auto it = unifications.find({FlatSymbolRefAttr::get(paramOp.getNameAttr()), Side::RHS});
685 if (it != unifications.end() && !typeParamsUnify({attr}, {it->second})) {
686 // Tested in call_with_template_params_fail.llzk
687 return this->emitOpError().append(
688 "template instantiation value '", attr, "' for parameter \"@", paramOp.getName(),
689 "\" conflicts with value '", it->second, "' inferred from function type signature"
690 );
691 }
692 }
693 return success();
694}
695
696namespace {
697
698struct CallOpVerifier {
699 CallOpVerifier(CallOp *c, FunctionKind tgtFuncKind) : callOp(c), tgtKind(tgtFuncKind) {}
700 CallOpVerifier(CallOp *c, StringRef tgtName) : CallOpVerifier(c, fnNameToKind(tgtName)) {}
701 virtual ~CallOpVerifier() = default;
702
703 LogicalResult verify() {
704 // Rather than immediately returning on failure, we check all verifier steps and aggregate to
705 // provide as many errors are possible in a single verifier run.
706 LogicalResult aggregateResult = success();
707 if (failed(verifyTargetAttributes())) {
708 aggregateResult = failure();
709 }
710 if (failed(verifyInputs())) {
711 aggregateResult = failure();
712 }
713 if (failed(verifyOutputs())) {
714 aggregateResult = failure();
715 }
716 if (failed(verifyTemplateParams())) {
717 aggregateResult = failure();
718 }
719 if (failed(verifyAffineMapParams())) {
720 aggregateResult = failure();
721 }
722 return aggregateResult;
723 }
724
725protected:
726 CallOp *callOp;
727 FunctionKind tgtKind;
728
729 virtual LogicalResult verifyTargetAttributes() = 0;
730 virtual LogicalResult verifyInputs() = 0;
731 virtual LogicalResult verifyOutputs() = 0;
732 virtual LogicalResult verifyTemplateParams() = 0;
733 virtual LogicalResult verifyAffineMapParams() = 0;
734
736 LogicalResult verifyTargetAttributesMatch(FuncDefOp target) {
737 LogicalResult aggregateRes = success();
738 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
739 auto emitAttrErr = [&](StringLiteral attrName) {
740 aggregateRes = callOp->emitOpError()
741 << "target '@" << target.getName() << "' has '" << attrName
742 << "' attribute, which is not specified by the caller '@" << caller.getName()
743 << '\'';
744 };
745
746 if (target.hasAllowConstraintAttr() && !caller.hasAllowConstraintAttr()) {
747 emitAttrErr(AllowConstraintAttr::name);
748 }
749 if (target.hasAllowWitnessAttr() && !caller.hasAllowWitnessAttr()) {
750 emitAttrErr(AllowWitnessAttr::name);
751 }
752 if (target.hasAllowNonNativeFieldOpsAttr() && !caller.hasAllowNonNativeFieldOpsAttr()) {
753 emitAttrErr(AllowNonNativeFieldOpsAttr::name);
754 }
755 }
756 return aggregateRes;
757 }
758
759 LogicalResult verifyNoTemplateInstantiations() {
760 if (!isNullOrEmpty(callOp->getTemplateParamsAttr())) {
761 // Tested in call_with_template_params_fail.llzk
762 return callOp->emitOpError().append(
763 "can only have template instantiations when targeting a templated free function"
764 );
765 }
766 return success();
767 }
768
769 LogicalResult verifyNoAffineMapInstantiations() {
770 if (!isNullOrEmpty(callOp->getMapOpGroupSizesAttr())) {
771 // Tested in call_with_affinemap_fail.llzk
772 return callOp->emitOpError().append(
773 "can only have affine map instantiations when targeting a \"@", FUNC_NAME_COMPUTE,
774 "\" function"
775 );
776 }
777 // ASSERT: the check above is sufficient due to VerifySizesForMultiAffineOps trait.
778 assert(isNullOrEmpty(callOp->getNumDimsPerMapAttr()));
779 assert(callOp->getMapOperands().empty());
780 return success();
781 }
782};
783
784struct KnownTargetVerifier : public CallOpVerifier {
785 KnownTargetVerifier(CallOp *c, SymbolLookupResult<FuncDefOp> &&tgtRes)
786 : CallOpVerifier(c, tgtRes.get().getSymName()), tgt(*tgtRes), tgtType(tgt.getFunctionType()),
787 includeSymNames(tgtRes.getNamespace()) {}
788
789 LogicalResult verifyTargetAttributes() override {
790 return CallOpVerifier::verifyTargetAttributesMatch(tgt);
791 }
792
793 LogicalResult verifyInputs() override {
794 return verifyTypesMatch(callOp->getArgOperands().getTypes(), tgtType.getInputs(), "operand");
795 }
796
797 LogicalResult verifyOutputs() override {
798 return verifyTypesMatch(callOp->getResultTypes(), tgtType.getResults(), "result");
799 }
800
801 LogicalResult verifyTemplateParams() override {
802 Operation *tgtOp = tgt.getOperation();
803 if (isInStruct(tgtOp)) {
804 // Struct function calls cannot contain template parameter instantiations.
805 return verifyNoTemplateInstantiations();
806 } else if (TemplateOp tgtOpParent = getParentOfType<TemplateOp>(tgtOp)) {
807 // When the target function is a free function within a TemplateOp, the CallOp may have
808 // template parameter instantiations that must be checked against the template parameters.
809 // - If the function type signature references all template parameters, then the parameter
810 // instantiation list on the CallOp is optional, otherwise it's required.
811 // - If present, the instantiation list must provide a value for every template parameter
812 // and the value must be type-compatible with the parameter's declared type (if any).
813 // - If present, the instantiation list must result in a function type signature that can
814 // be unified with the CallOp's operand and result types.
815 auto realParams = tgtOpParent.getConstOps<TemplateParamOp>();
816 ArrayAttr callParams = callOp->getTemplateParamsAttr();
817
818 // When there is no instantiation list, just ensure that it's not required.
819 if (isNullOrEmpty(callParams)) {
820 llvm::SmallDenseSet<SymbolRefAttr> referencedInSignature;
821 llzk::getSymbolsUsedIn(tgtType.getInputs(), referencedInSignature);
822 llzk::getSymbolsUsedIn(tgtType.getResults(), referencedInSignature);
823
824 bool allParamsReferenced = llvm::all_of(realParams, [&](TemplateParamOp p) {
825 return referencedInSignature.contains(FlatSymbolRefAttr::get(p.getNameAttr()));
826 });
827 if (allParamsReferenced) {
828 return success();
829 }
830 // Tested in call_with_template_params_fail.llzk
831 return callOp->emitOpError().append(
832 "must provide template instantiation parameters when calling \"@", tgt.getSymName(),
833 "\" because not all template parameters of \"@", tgtOpParent.getSymName(),
834 "\" appear in the function type signature"
835 );
836 }
837
838 // Ensure `forceIntAttrTypes()` was successful on the CallOp's template parameters.
839 if (failed(llzk::forceIntAttrTypes(callParams.getValue(), [this] {
840 return llzk::InFlightDiagnosticWrapper(this->callOp->emitOpError());
841 }))) {
842 return failure();
843 }
844
845 // The instantiation list is present. Check it has exactly one entry per template param.
846 size_t numTemplateParams = llvm::range_size(realParams);
847 if (callParams.size() != numTemplateParams) {
848 // Tested in call_with_template_params_fail.llzk
849 return callOp->emitOpError().append(
850 "template instantiation has ", callParams.size(), " parameter(s) but \"@",
851 tgtOpParent.getSymName(), "\" expects ", numTemplateParams, " template parameter(s)"
852 );
853 }
854
855 // Check type compatibility of each provided value with the declared parameter type (if any).
856 if (failed(callOp->verifyTemplateParamCompatibility(realParams))) {
857 return failure();
858 }
859
860 // Check that the provided instantiation values are consistent with what type unification
861 // of the target function types against the call's operand and result types would determine.
862 FailureOr<UnificationMap> unifyResult = callOp->unifyTypeSignature(tgtType);
863 assert(succeeded(unifyResult) && "already checked by `verifyInputs()` and `verifyOutputs()`");
864 return callOp->verifyTemplateParamsMatchInferred(realParams, unifyResult.value());
865 } else {
866 // Non-template functions cannot contain template parameter instantiations.
867 return verifyNoTemplateInstantiations();
868 }
869 }
870
871 LogicalResult verifyAffineMapParams() override {
872 if ((FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) &&
873 isInStruct(tgt.getOperation())) {
874 // Return type should be a single StructType. If that is not the case here, just bail without
875 // producing an error. The combination of this KnownTargetVerifier resolving the callee to a
876 // specific FuncDefOp and verifyFuncTypeCompute() ensuring all FUNC_NAME_COMPUTE FuncOps have
877 // a single StructType return value will produce a more relevant error message in that case.
878 if (StructType retTy = callOp->getSingleResultTypeOfWitnessGen()) {
879 if (ArrayAttr params = retTy.getParams()) {
880 // Collect the struct parameters that are defined via AffineMapAttr
881 SmallVector<AffineMapAttr> mapAttrs;
882 for (Attribute a : params) {
883 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(a)) {
884 mapAttrs.push_back(m);
885 }
886 }
888 callOp->getMapOperands(), callOp->getNumDimsPerMap(), mapAttrs, *callOp
889 );
890 }
891 }
892 return success();
893 } else {
894 // Global functions and constrain functions cannot have affine map instantiations.
895 return verifyNoAffineMapInstantiations();
896 }
897 }
898
899private:
900 template <typename T>
901 LogicalResult
902 verifyTypesMatch(ValueTypeRange<T> callOpTypes, ArrayRef<Type> tgtTypes, const char *aspect) {
903 if (tgtTypes.size() != callOpTypes.size()) {
904 return callOp->emitOpError()
905 .append("incorrect number of ", aspect, "s for callee, expected ", tgtTypes.size())
906 .attachNote(tgt.getLoc())
907 .append("callee defined here");
908 }
909 for (unsigned i = 0, e = tgtTypes.size(); i != e; ++i) {
910 if (!typesUnify(callOpTypes[i], tgtTypes[i], includeSymNames)) {
911 return callOp->emitOpError().append(
912 aspect, " type mismatch: expected type ", tgtTypes[i], ", but found ", callOpTypes[i],
913 " for ", aspect, " number ", i
914 );
915 }
916 }
917 return success();
918 }
919
920 FuncDefOp tgt;
921 FunctionType tgtType;
922 std::vector<llvm::StringRef> includeSymNames;
923};
924
927LogicalResult checkSelfTypeUnknownTarget(
928 StringAttr expectedParamName, Type actualType, CallOp *origin, const char *aspect
929) {
930 if (!llvm::isa<TypeVarType>(actualType) ||
931 llvm::cast<TypeVarType>(actualType).getRefName() != expectedParamName) {
932 // Tested in function_restrictions_fail.llzk:
933 // Non-tvar for constrain input via "call_target_constrain_without_self_non_struct"
934 // Non-tvar for compute output via "call_target_compute_wrong_type_ret"
935 // Wrong tvar for constrain input via "call_target_constrain_without_self_wrong_tvar_param"
936 // Wrong tvar for compute output via "call_target_compute_wrong_tvar_param_ret"
937 return origin->emitOpError().append(
938 "target \"@", origin->getCallee().getLeafReference().getValue(), "\" expected ", aspect,
939 " type '!", TypeVarType::name, "<@", expectedParamName.getValue(), ">' but found ",
940 actualType
941 );
942 }
943 return success();
944}
945
955struct UnknownTargetVerifier : public CallOpVerifier {
956 UnknownTargetVerifier(CallOp *c, FunctionKind tgtFuncKind, SymbolRefAttr callee)
957 : CallOpVerifier(c, tgtFuncKind), calleeAttr(callee) {
958 assert(
959 tgtFuncKind == FunctionKind::StructCompute ||
960 tgtFuncKind == FunctionKind::StructConstrain || tgtFuncKind == FunctionKind::StructProduct
961 ); // pre-condition mentioned above
962 }
963
964 LogicalResult verifyTargetAttributes() override {
965 // Based on the precondition of this verifier, the target must be either a
966 // struct compute, constrain, or product function.
967 LogicalResult aggregateRes = success();
968 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
969 auto emitAttrErr = [&](StringLiteral attrName) {
970 aggregateRes = callOp->emitOpError()
971 << "target '" << calleeAttr << "' has '" << attrName
972 << "' attribute, which is not specified by the caller '@" << caller.getName()
973 << '\'';
974 };
975
976 switch (tgtKind) {
978 if (!caller.hasAllowConstraintAttr()) {
979 emitAttrErr(AllowConstraintAttr::name);
980 }
981 break;
983 if (!caller.hasAllowWitnessAttr()) {
984 emitAttrErr(AllowWitnessAttr::name);
985 }
986 break;
988 if (!caller.hasAllowWitnessAttr()) {
989 emitAttrErr(AllowWitnessAttr::name);
990 }
991 if (!caller.hasAllowConstraintAttr()) {
992 emitAttrErr(AllowConstraintAttr::name);
993 }
994 break;
995 default:
996 break;
997 }
998 }
999 return aggregateRes;
1000 }
1001
1002 LogicalResult verifyInputs() override {
1003 if (FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) {
1004 // Without known target, no additional checks can be done.
1005 } else if (FunctionKind::StructConstrain == tgtKind) {
1006 // Without known target, this can only check that the first input is VarType using the same
1007 // struct parameter as the base of the callee (later replaced with the target struct's type).
1008 Operation::operand_type_range inputTypes = callOp->getArgOperands().getTypes();
1009 if (inputTypes.size() < 1) {
1010 // Tested in function_restrictions_fail.llzk
1011 return callOp->emitOpError()
1012 << "target \"@" << FUNC_NAME_CONSTRAIN << "\" must have at least one input type";
1013 }
1014 return checkSelfTypeUnknownTarget(
1015 calleeAttr.getRootReference(), inputTypes.front(), callOp, "first input"
1016 );
1017 }
1018 return success();
1019 }
1020
1021 LogicalResult verifyOutputs() override {
1022 if (FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) {
1023 // Without known target, this can only check that the function returns VarType using the same
1024 // struct parameter as the base of the callee (later replaced with the target struct's type).
1025 Operation::result_type_range resTypes = callOp->getResultTypes();
1026 if (resTypes.size() != 1) {
1027 // Tested in function_restrictions_fail.llzk
1028 return callOp->emitOpError().append(
1029 "target \"@", FUNC_NAME_COMPUTE, "\" must have exactly one return type"
1030 );
1031 }
1032 return checkSelfTypeUnknownTarget(
1033 calleeAttr.getRootReference(), resTypes.front(), callOp, "return"
1034 );
1035 } else if (FunctionKind::StructConstrain == tgtKind) {
1036 // Without known target, this can only check that the function has no return
1037 if (callOp->getNumResults() != 0) {
1038 // Tested in function_restrictions_fail.llzk
1039 return callOp->emitOpError()
1040 << "target \"@" << FUNC_NAME_CONSTRAIN << "\" must have no return type";
1041 }
1042 }
1043 return success();
1044 }
1045
1046 LogicalResult verifyTemplateParams() override {
1047 // Struct function calls cannot contain template parameter instantiations.
1048 return verifyNoTemplateInstantiations();
1049 }
1050
1051 LogicalResult verifyAffineMapParams() override {
1052 if (FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) {
1053 // Without known target, no additional checks can be done.
1054 } else if (FunctionKind::StructConstrain == tgtKind) {
1055 // Without known target, this can only check that there are no affine map instantiations.
1056 return verifyNoAffineMapInstantiations();
1057 }
1058 return success();
1059 }
1060
1061private:
1062 SymbolRefAttr calleeAttr;
1063};
1064
1065} // namespace
1066
1067LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &tables) {
1068 // First, verify symbol resolution in all input and output types.
1069 if (failed(verifyTypeResolution(tables, *this, getTypeSignature()))) {
1070 return failure(); // verifyTypeResolution() already emits a sufficient error message
1071 }
1072
1073 // Check that the callee attribute was specified.
1074 SymbolRefAttr calleeAttr = getCalleeAttr();
1075 if (!calleeAttr) {
1076 return emitOpError("requires a 'callee' symbol reference attribute");
1077 }
1078
1079 // If the callee references a parameter of the template where this call appears, perform
1080 // the subset of checks that can be done even though the target is unknown.
1081 if (calleeAttr.getNestedReferences().size() == 1) {
1082 if (TemplateOp parent = getParentOfType<TemplateOp>(*this)) {
1083 if (parent.hasConstNamed<TemplateParamOp>(calleeAttr.getRootReference())) {
1084 FunctionKind tgtKind = fnNameToKind(calleeAttr.getLeafReference().getValue());
1085 if (tgtKind != FunctionKind::Free) {
1086 return UnknownTargetVerifier(this, tgtKind, calleeAttr).verify();
1087 }
1088 return this->emitError("expected parameterized callee to target a struct function")
1089 .append(
1090 " (i.e. \"@", FUNC_NAME_PRODUCT, "\", \"@", FUNC_NAME_COMPUTE, "\", or \"@",
1091 FUNC_NAME_CONSTRAIN, "\")"
1092 );
1093 }
1094 }
1095 }
1096
1097 // Otherwise, callee must be specified via full path from the root module. Perform the full set of
1098 // checks against the known target function.
1099 auto tgtOpt = lookupTopLevelSymbol<FuncDefOp>(tables, calleeAttr, *this);
1100 if (failed(tgtOpt)) {
1101 return this->emitError() << "expected '" << FuncDefOp::getOperationName() << "' named \""
1102 << calleeAttr << '"';
1103 }
1104 return KnownTargetVerifier(this, std::move(*tgtOpt)).verify();
1105}
1106
1108 return FunctionType::get(getContext(), getArgOperands().getTypes(), getResultTypes());
1109}
1110
1111FailureOr<UnificationMap> CallOp::unifyTypeSignature(FunctionType other) {
1112 UnificationMap unifications;
1113 if (functionTypesUnify(getTypeSignature(), other, {}, &unifications)) {
1114 return unifications;
1115 } else {
1116 return failure();
1117 }
1118}
1119
1120namespace {
1121
1122bool calleeIsStructFunctionImpl(
1123 const char *funcName, SymbolRefAttr callee, llvm::function_ref<StructType()> getType
1124) {
1125 if (callee.getLeafReference() == funcName) {
1126 if (StructType t = getType()) {
1127 // If the name ref within the StructType matches the `callee` prefix (i.e., sans the function
1128 // name itself), then the `callee` target must be within a StructDefOp because validation
1129 // checks elsewhere ensure that every StructType references a StructDefOp (i.e., the `callee`
1130 // function is not simply a free function nested within a ModuleOp)
1131 return t.getNameRef() == getPrefixAsSymbolRefAttr(callee);
1132 }
1133 }
1134 return false;
1135}
1136
1137} // namespace
1138
1140 return calleeIsStructFunctionImpl(FUNC_NAME_COMPUTE, getCallee(), [this]() {
1141 return this->getSingleResultTypeOfCompute();
1142 });
1143}
1144
1146 return calleeIsStructFunctionImpl(FUNC_NAME_CONSTRAIN, getCallee(), [this]() {
1147 return getAtIndex<StructType>(this->getArgOperands().getTypes(), 0);
1148 });
1149}
1150
1152 assert(calleeIsStructCompute());
1153 return getResults().front();
1154}
1155
1157 assert(calleeIsStructConstrain());
1158 return getArgOperands().front();
1159}
1160
1161FailureOr<SymbolLookupResult<FuncDefOp>> CallOp::getCalleeTarget(SymbolTableCollection &tables) {
1162 Operation *thisOp = this->getOperation();
1163 auto root = getRootModule(thisOp);
1164 assert(succeeded(root));
1165 return llzk::lookupSymbolIn<FuncDefOp>(tables, getCallee(), root->getOperation(), thisOp);
1166}
1167
1169 assert(calleeIsCompute() && "violated implementation pre-condition");
1170 return getIfSingleton<StructType>(getResultTypes());
1171}
1172
1174 assert(calleeContainsWitnessGen() && "violated implementation pre-condition");
1175 return getIfSingleton<StructType>(getResultTypes());
1176}
1177
1179CallInterfaceCallable CallOp::getCallableForCallee() { return getCalleeAttr(); }
1180
1182void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1183 setCalleeAttr(llvm::cast<SymbolRefAttr>(callee));
1184}
1185
1186SmallVector<ValueRange> CallOp::toVectorOfValueRange(OperandRangeRange input) {
1187 llvm::SmallVector<ValueRange, 4> output;
1188 output.reserve(input.size());
1189 for (OperandRange r : input) {
1190 output.push_back(r);
1191 }
1192 return output;
1193}
1194
1195Operation *CallOp::resolveCallableInTable(SymbolTableCollection *symbolTable) {
1196 FailureOr<SymbolLookupResult<FuncDefOp>> res =
1197 llzk::resolveCallable<FuncDefOp>(*symbolTable, *this);
1198 if (failed(res) || res->isManaged()) {
1199 // Cannot return pointer to a managed Operation since it would cause memory errors.
1200 return nullptr;
1201 }
1202 return res->get();
1203}
1204
1206 SymbolTableCollection tables;
1207 return resolveCallableInTable(&tables);
1208}
1209
1210} // namespace llzk::function
This file defines methods for symbol lookup across LLZK operations and included files.
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::SymbolRefAttr callee, ::mlir::ValueRange argOperands={}, ::llvm::ArrayRef<::mlir::Attribute > templateParams={})
bool calleeContainsWitnessGen()
Return true iff the callee function can contain witness generation code (this does not check if the c...
Definition Ops.h.inc:389
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:1145
::mlir::CallInterfaceCallable getCallableForCallee()
Return the callee of this operation.
Definition Ops.cpp:1179
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:1168
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.h.inc:292
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:1067
::mlir::Operation * resolveCallableInTable(::mlir::SymbolTableCollection *symbolTable)
Required by CallOpInterface.
Definition Ops.cpp:1195
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:1139
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:470
::mlir::Operation * resolveCallable()
Required by CallOpInterface.
Definition Ops.cpp:1205
void writeProperties(::mlir::DialectBytecodeWriter &writer)
Definition Ops.cpp:556
::mlir::FunctionType getTypeSignature()
Return the FunctionType inferred from the arg operands and result types of this CallOp.
Definition Ops.cpp:1107
bool calleeIsCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE (this does not check if the callee func...
Definition Ops.h.inc:383
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:480
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:266
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:1156
::mlir::ArrayAttr getTemplateParamsAttr()
Definition Ops.h.inc:297
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:270
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
Definition Ops.cpp:1186
::llzk::component::StructType getSingleResultTypeOfWitnessGen()
Assuming the callee contains witness generation code, return the single StructType result.
Definition Ops.cpp:1173
FoldAdaptor::Properties Properties
Definition Ops.h.inc:209
void setCalleeAttr(::mlir::SymbolRefAttr attr)
Definition Ops.h.inc:312
::llvm::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state)
Definition Ops.cpp:517
::mlir::FailureOr< UnificationMap > unifyTypeSignature(::mlir::FunctionType other)
Attempt type unification between the inferred FunctionType from this CallOp (as LHS) and the given Fun...
Definition Ops.cpp:1111
::mlir::LogicalResult verifyTemplateParamsMatchInferred(::llvm::iterator_range<::mlir::Region::op_iterator<::llzk::polymorphic::TemplateParamOp > > targetParamDefs, const UnificationMap &unifications)
Verify that each template parameter value provided in this CallOp is consistent with the value inferr...
Definition Ops.cpp:669
::mlir::LogicalResult verifyTemplateParamCompatibility(::mlir::Attribute paramFromCallOp, ::llzk::polymorphic::TemplateParamOp targetParam)
Check type compatibility of the given template parameter value from this CallOp against the declared ...
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:1151
void setCalleeFromCallable(::mlir::CallInterfaceCallable callee)
Set the callee for this operation.
Definition Ops.cpp:1182
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
Definition Ops.cpp:1161
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
Definition Ops.cpp:269
void setArgNameAttr(unsigned index, const ::mlir::StringAttr &attr)
Set the function.arg_name attribute for the argument at the given index.
Definition Ops.cpp:307
void setArgName(unsigned index, ::llvm::StringRef name)
Set the function.arg_name attribute for the argument at the given index from a string.
Definition Ops.cpp:312
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:431
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:457
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:984
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={}, ::llvm::ArrayRef<::mlir::DictionaryAttr > argAttrs={})
void setAllowNonNativeFieldOpsAttr(bool newValue=true)
Add (resp. remove) the allow_non_native_field_ops attribute to (resp. from) the function def.
Definition Ops.cpp:277
::mlir::StringAttr getFunctionTypeAttrName()
Definition Ops.h.inc:642
bool hasAllowNonNativeFieldOpsAttr()
Return true iff the function def has the allow_non_native_field_ops attribute.
Definition Ops.h.inc:820
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:476
void print(::mlir::OpAsmPrinter &p)
Definition Ops.cpp:180
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:865
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
Definition Ops.h.inc:812
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:481
::mlir::StringAttr getResAttrsAttrName()
Definition Ops.h.inc:650
::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result)
Definition Ops.cpp:169
bool hasArgName(unsigned index)
Return true iff the argument at the given index has a function.arg_name attribute.
Definition Ops.cpp:295
void cloneInto(FuncDefOp dest, ::mlir::IRMapping &mapper)
Clone the internal blocks and attributes from this function into dest.
Definition Ops.cpp:189
bool nameIsProduct()
Return true iff the function name is FUNC_NAME_PRODUCT (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:873
::std::optional<::mlir::StringAttr > getArgNameAttr(unsigned index)
Return the function.arg_name attribute for the argument at the given index.
Definition Ops.cpp:297
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
Definition Ops.h.inc:869
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:879
bool isInStruct()
Return true iff the function is within a StructDefOp.
Definition Ops.h.inc:876
::llvm::ArrayRef<::mlir::Type > getResultTypes()
Required by FunctionOpInterface.
Definition Ops.h.inc:854
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:666
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
Definition Ops.cpp:261
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
Definition Ops.cpp:447
bool hasArgPublicAttr(unsigned index)
Return true iff the argument at the given index has pub attribute.
Definition Ops.cpp:285
::llvm::LogicalResult verify()
Definition Ops.cpp:316
::mlir::Region & getBody()
Definition Ops.h.inc:690
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
Definition Ops.h.inc:804
::mlir::StringAttr getArgAttrsAttrName()
Definition Ops.h.inc:634
::llvm::LogicalResult verify()
Definition Ops.cpp:490
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:980
::std::optional<::mlir::Type > getTypeOpt()
Definition Ops.cpp.inc:1337
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:27
OpClass::Properties & buildInstantiationAttrs(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
OpClass::Properties & buildInstantiationAttrsEmpty(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
bool isInStruct(Operation *op)
Definition Ops.cpp:57
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
Definition Ops.cpp:121
constexpr char ARG_NAME_ATTR_NAME[]
Attribute name for source-level function argument names.
Definition Ops.h:34
constexpr char RES_NAME_ATTR_NAME[]
Attribute name for source-level function result names.
Definition Ops.h:37
FunctionKind fnNameToKind(mlir::StringRef name)
Given a function name, return the corresponding FunctionKind.
Definition Ops.cpp:52
FunctionKind
Kinds of functions in LLZK.
Definition Ops.h:40
@ StructConstrain
Function within a struct named FUNC_NAME_CONSTRAIN.
Definition Ops.h:44
@ StructProduct
Function within a struct named FUNC_NAME_PRODUCT.
Definition Ops.h:46
@ StructCompute
Function within a struct named FUNC_NAME_COMPUTE.
Definition Ops.h:42
@ Free
Function that is not within a struct.
Definition Ops.h:48
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:16
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:270
FailureOr< ModuleOp > getRootModule(Operation *from)
void getSymbolsUsedIn(mlir::Type t, llvm::SmallDenseSet< mlir::SymbolRefAttr > &symbolsUsed)
Add all symbols used within the given Type to the provided set.
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:186
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
FailureOr< SmallVector< Attribute > > forceIntAttrTypes(ArrayRef< Attribute > attrList, EmitErrorFn emitError)
bool isNullOrEmpty(mlir::ArrayAttr a)
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent/ancestor operation that is of type 'OpClass'.
Definition OpHelpers.h:51
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:18
constexpr T checkedCast(U u) noexcept
Definition Compare.h:81
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:274
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
bool isDynamic(IntegerAttr intAttr)
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool typeParamsUnify(const ArrayRef< Attribute > &lhsParams, const ArrayRef< Attribute > &rhsParams, UnificationMap *unifications)
bool functionTypesUnify(FunctionType lhs, FunctionType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
OpClass delegate_to_build(mlir::Location location, Args &&...args)
bool hasAffineMapAttr(Type type)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
Definition TypeHelper.h:114
void addTemplateParams(mlir::OpBuilder &odsBuilder, typename OpClass::Properties &props, llvm::ArrayRef< mlir::Attribute > templateParams)
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
bool isValidConstReadType(Type type)