LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Func and call op implementations --------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8// Adapted from the LLVM Project's lib/Dialect/Func/IR/FuncOps.cpp
9// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
10// See https://llvm.org/LICENSE.txt for license information.
11// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
12//
13//===----------------------------------------------------------------------===//
14
24
25#include <mlir/IR/IRMapping.h>
26#include <mlir/IR/OpImplementation.h>
27#include <mlir/Interfaces/FunctionImplementation.h>
28
29#include <llvm/ADT/MapVector.h>
30
31// TableGen'd implementation files
32#define GET_OP_CLASSES
34
35using namespace mlir;
36using namespace llzk::component;
37using namespace llzk::polymorphic;
38
39namespace llzk::function {
40
41FunctionKind fnNameToKind(mlir::StringRef name) {
42 if (FUNC_NAME_COMPUTE == name) {
44 } else if (FUNC_NAME_CONSTRAIN == name) {
46 } else if (FUNC_NAME_PRODUCT == name) {
48 } else {
49 return FunctionKind::Free;
50 }
51}
52
53namespace {
55inline LogicalResult
56verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, FunctionType funcType) {
58 tables, origin, ArrayRef<ArrayRef<Type>> {funcType.getInputs(), funcType.getResults()}
59 );
60}
61} // namespace
62
63//===----------------------------------------------------------------------===//
64// FuncDefOp
65//===----------------------------------------------------------------------===//
66
68 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs
69) {
70 return delegate_to_build<FuncDefOp>(location, name, type, attrs);
71}
72
74 Location location, StringRef name, FunctionType type, Operation::dialect_attr_range attrs
75) {
76 SmallVector<NamedAttribute, 8> attrRef(attrs);
77 return create(location, name, type, llvm::ArrayRef(attrRef));
78}
79
81 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs,
82 ArrayRef<DictionaryAttr> argAttrs
83) {
84 FuncDefOp func = create(location, name, type, attrs);
85 func.setAllArgAttrs(argAttrs);
86 return func;
87}
88
90 OpBuilder &builder, OperationState &state, StringRef name, FunctionType type,
91 ArrayRef<NamedAttribute> attrs, ArrayRef<DictionaryAttr> argAttrs
92) {
93 state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name));
94 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
95 state.attributes.append(attrs.begin(), attrs.end());
96 state.addRegion();
97
98 if (argAttrs.empty()) {
99 return;
100 }
101 assert(type.getNumInputs() == argAttrs.size());
102 function_interface_impl::addArgAndResultAttrs(
103 builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name),
104 getResAttrsAttrName(state.name)
105 );
106}
107
108ParseResult FuncDefOp::parse(OpAsmParser &parser, OperationState &result) {
109 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
110 function_interface_impl::VariadicFlag,
111 std::string &) { return builder.getFunctionType(argTypes, results); };
112
113 return function_interface_impl::parseFunctionOp(
114 parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType,
115 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)
116 );
117}
118
119void FuncDefOp::print(OpAsmPrinter &p) {
120 function_interface_impl::printFunctionOp(
121 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(),
123 );
124}
125
128void FuncDefOp::cloneInto(FuncDefOp dest, IRMapping &mapper) {
129 // Add the attributes of this function to dest.
130 llvm::MapVector<StringAttr, Attribute> newAttrMap;
131 for (const auto &attr : dest->getAttrs()) {
132 newAttrMap.insert({attr.getName(), attr.getValue()});
133 }
134 for (const auto &attr : (*this)->getAttrs()) {
135 newAttrMap.insert({attr.getName(), attr.getValue()});
136 }
137
138 auto newAttrs =
139 llvm::to_vector(llvm::map_range(newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
140 return NamedAttribute(attrPair.first, attrPair.second);
141 }));
142 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
143
144 // Clone the body.
145 getBody().cloneInto(&dest.getBody(), mapper);
146}
147
153FuncDefOp FuncDefOp::clone(IRMapping &mapper) {
154 // Create the new function.
155 FuncDefOp newFunc = llvm::cast<FuncDefOp>(getOperation()->cloneWithoutRegions());
156
157 // If the function has a body, then the user might be deleting arguments to
158 // the function by specifying them in the mapper. If so, we don't add the
159 // argument to the input type vector.
160 if (!isExternal()) {
161 FunctionType oldType = getFunctionType();
162
163 unsigned oldNumArgs = oldType.getNumInputs();
164 SmallVector<Type, 4> newInputs;
165 newInputs.reserve(oldNumArgs);
166 for (unsigned i = 0; i != oldNumArgs; ++i) {
167 if (!mapper.contains(getArgument(i))) {
168 newInputs.push_back(oldType.getInput(i));
169 }
170 }
171
174 if (newInputs.size() != oldNumArgs) {
175 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, oldType.getResults()));
176
177 if (ArrayAttr argAttrs = getAllArgAttrs()) {
178 SmallVector<Attribute> newArgAttrs;
179 newArgAttrs.reserve(newInputs.size());
180 for (unsigned i = 0; i != oldNumArgs; ++i) {
181 if (!mapper.contains(getArgument(i))) {
182 newArgAttrs.push_back(argAttrs[i]);
183 }
184 }
185 newFunc.setAllArgAttrs(newArgAttrs);
186 }
187 }
188 }
189
191 cloneInto(newFunc, mapper);
192 return newFunc;
193}
194
196 IRMapping mapper;
197 return clone(mapper);
198}
199
201 if (newValue) {
202 getOperation()->setAttr(AllowConstraintAttr::name, UnitAttr::get(getContext()));
203 } else {
204 getOperation()->removeAttr(AllowConstraintAttr::name);
205 }
206}
207
209 if (newValue) {
210 getOperation()->setAttr(AllowWitnessAttr::name, UnitAttr::get(getContext()));
211 } else {
212 getOperation()->removeAttr(AllowWitnessAttr::name);
213 }
214}
215
217 if (newValue) {
218 getOperation()->setAttr(AllowNonNativeFieldOpsAttr::name, UnitAttr::get(getContext()));
219 } else {
220 getOperation()->removeAttr(AllowNonNativeFieldOpsAttr::name);
221 }
222}
223
224bool FuncDefOp::hasArgPublicAttr(unsigned index) {
225 if (index < this->getNumArguments()) {
226 DictionaryAttr res = function_interface_impl::getArgAttrDict(*this, index);
227 return res ? res.contains(PublicAttr::name) : false;
228 } else {
229 // TODO: print error? requested attribute for non-existant argument index
230 return false;
231 }
232}
233
234LogicalResult FuncDefOp::verify() {
235 OwningEmitErrorFn emitErrorFunc = getEmitOpErrFn(this);
236 // Ensure that only valid LLZK types are used for arguments and return. Additionally, the struct
237 // functions may not use AffineMapAttrs in their parameter types. If such a scenario seems to make
238 // sense when generating LLZK IR, it's likely better to introduce a struct parameter to use
239 // instead and instantiate the struct with that AffineMapAttr.
240 FunctionType type = getFunctionType();
241 for (Type t : type.getInputs()) {
242 if (llzk::checkValidType(emitErrorFunc, t).failed()) {
243 return failure();
244 }
245 if (isInStruct() && hasAffineMapAttr(t)) {
246 return emitErrorFunc().append(
247 "\"@", getName(), "\" parameters cannot contain affine map attributes but found ", t
248 );
249 }
250 }
251 for (Type t : type.getResults()) {
252 if (llzk::checkValidType(emitErrorFunc, t).failed()) {
253 return failure();
254 }
255 }
256 // Ensure that the function does not contain nested modules.
257 // Functions also cannot contain nested structs, but this check is handled
258 // via struct.def's requirement of having module as a parent.
259 WalkResult res = this->walk<WalkOrder::PreOrder>([this](ModuleOp nestedMod) {
260 getEmitOpErrFn(nestedMod)().append(
261 "cannot be nested within '", getOperation()->getName(), "' operations"
262 );
263 return WalkResult::interrupt();
264 });
265 if (res.wasInterrupted()) {
266 return failure();
267 }
268
269 return success();
270}
271
272namespace {
273
274LogicalResult
275verifyFuncTypeCompute(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
276 FunctionType funcType = origin.getFunctionType();
277 llvm::ArrayRef<Type> resTypes = funcType.getResults();
278 // Must return type of parent struct
279 if (resTypes.size() != 1) {
280 return origin.emitOpError().append(
281 "\"@", FUNC_NAME_COMPUTE, "\" must have exactly one return type"
282 );
283 }
284 if (failed(checkSelfType(tables, parent, resTypes.front(), origin, "return"))) {
285 return failure();
286 }
287
288 // After the more specific checks (to ensure more specific error messages would be produced if
289 // necessary), do the general check that all symbol references in the types are valid. The return
290 // types were already checked so just check the input types.
291 return llzk::verifyTypeResolution(tables, origin, funcType.getInputs());
292}
293
294LogicalResult
295verifyFuncTypeProduct(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
296 // The signature for @product is the same as the signature for @compute
297 return verifyFuncTypeCompute(origin, tables, parent);
298}
299
300LogicalResult
301verifyFuncTypeConstrain(FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
302 FunctionType funcType = origin.getFunctionType();
303 // Must return '()' type, i.e., have no return types
304 if (funcType.getResults().size() != 0) {
305 return origin.emitOpError() << "\"@" << FUNC_NAME_CONSTRAIN << "\" must have no return type";
306 }
307
308 // Type of the first parameter must match the parent StructDefOp of the current operation.
309 llvm::ArrayRef<Type> inputTypes = funcType.getInputs();
310 if (inputTypes.size() < 1) {
311 return origin.emitOpError() << "\"@" << FUNC_NAME_CONSTRAIN
312 << "\" must have at least one input type";
313 }
314 if (failed(checkSelfType(tables, parent, inputTypes.front(), origin, "first input"))) {
315 return failure();
316 }
317
318 // After the more specific checks (to ensure more specific error messages would be produced if
319 // necessary), do the general check that all symbol references in the types are valid. There are
320 // no return types, just check the remaining input types (the first was already checked via
321 // the checkSelfType() call above).
322 return llzk::verifyTypeResolution(tables, origin, inputTypes.drop_front());
323}
324
325} // namespace
326
327LogicalResult FuncDefOp::verifySymbolUses(SymbolTableCollection &tables) {
328 // Additional checks for the compute/constrain/product functions within a struct
329 FailureOr<StructDefOp> parentStructOpt = getParentOfType<StructDefOp>(*this);
330 if (succeeded(parentStructOpt)) {
331 // Verify return type restrictions for functions within a StructDefOp
332 if (nameIsCompute()) {
333 return verifyFuncTypeCompute(*this, tables, parentStructOpt.value());
334 } else if (nameIsConstrain()) {
335 return verifyFuncTypeConstrain(*this, tables, parentStructOpt.value());
336 } else if (nameIsProduct()) {
337 return verifyFuncTypeProduct(*this, tables, parentStructOpt.value());
338 }
339 }
340 // In the general case, verify symbol resolution in all input and output types.
341 return verifyTypeResolution(tables, *this, getFunctionType());
342}
343
344SymbolRefAttr FuncDefOp::getFullyQualifiedName(bool requireParent) {
345 // If the parent is not present and not required, just return the symbol name
346 if (!requireParent && getOperation()->getParentOp() == nullptr) {
347 return SymbolRefAttr::get(getOperation());
348 }
349 auto res = getPathFromRoot(*this);
350 assert(succeeded(res));
351 return res.value();
352}
353
355 assert(nameIsCompute()); // skip inStruct check to allow dangling functions
356 // Get the single block of the function body
357 Region &body = getBody();
358 assert(!body.empty() && "compute() function body is empty");
359 Block &block = body.back();
360
361 // The terminator should be the return op
362 Operation *terminator = block.getTerminator();
363 assert(terminator && "compute() function has no terminator");
364 auto retOp = llvm::dyn_cast<ReturnOp>(terminator);
365 if (!retOp) {
366 llvm::errs() << "Expected '" << ReturnOp::getOperationName() << "' but found '"
367 << terminator->getName() << "'\n";
368 llvm_unreachable("compute() function must end with ReturnOp");
369 }
370 return retOp.getOperands().front();
371}
372
374 assert(nameIsConstrain()); // skip inStruct check to allow dangling functions
375 return getArguments().front();
376}
377
379 assert(isStructCompute() && "violated implementation pre-condition");
381}
382
383//===----------------------------------------------------------------------===//
384// ReturnOp
385//===----------------------------------------------------------------------===//
386
387LogicalResult ReturnOp::verify() {
388 auto function = getParentOp<FuncDefOp>(); // parent is FuncDefOp per ODS
389
390 // The operand number and types must match the function signature.
391 const auto results = function.getFunctionType().getResults();
392 if (getNumOperands() != results.size()) {
393 return emitOpError("has ") << getNumOperands() << " operands, but enclosing function (@"
394 << function.getName() << ") returns " << results.size();
395 }
396
397 for (unsigned i = 0, e = results.size(); i != e; ++i) {
398 if (!typesUnify(getOperand(i).getType(), results[i])) {
399 return emitError() << "type of return operand " << i << " (" << getOperand(i).getType()
400 << ") doesn't match function result type (" << results[i] << ")"
401 << " in function @" << function.getName();
402 }
403 }
404
405 return success();
406}
407
408//===----------------------------------------------------------------------===//
409// CallOp
410//===----------------------------------------------------------------------===//
411
412void CallOp::build(
413 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
414 ValueRange argOperands
415) {
416 odsState.addTypes(resultTypes);
417 odsState.addOperands(argOperands);
419 odsBuilder, odsState, static_cast<int32_t>(argOperands.size())
420 );
421 props.setCallee(callee);
422}
423
424void CallOp::build(
425 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
426 ArrayRef<ValueRange> mapOperands, DenseI32ArrayAttr numDimsPerMap, ValueRange argOperands
427) {
428 odsState.addTypes(resultTypes);
429 odsState.addOperands(argOperands);
431 odsBuilder, odsState, mapOperands, numDimsPerMap, argOperands.size()
432 );
433 props.setCallee(callee);
434}
435
436namespace {
437
438struct CallOpVerifier {
439 CallOpVerifier(CallOp *c, StringRef tgtName) : callOp(c), tgtKind(fnNameToKind(tgtName)) {}
440 virtual ~CallOpVerifier() = default;
441
442 LogicalResult verify() {
443 // Rather than immediately returning on failure, we check all verifier steps and aggregate to
444 // provide as many errors are possible in a single verifier run.
445 LogicalResult aggregateResult = success();
446 if (failed(verifyTargetAttributes())) {
447 aggregateResult = failure();
448 }
449 if (failed(verifyInputs())) {
450 aggregateResult = failure();
451 }
452 if (failed(verifyOutputs())) {
453 aggregateResult = failure();
454 }
455 if (failed(verifyAffineMapParams())) {
456 aggregateResult = failure();
457 }
458 return aggregateResult;
459 }
460
461protected:
462 CallOp *callOp;
463 FunctionKind tgtKind;
464
465 virtual LogicalResult verifyTargetAttributes() = 0;
466 virtual LogicalResult verifyInputs() = 0;
467 virtual LogicalResult verifyOutputs() = 0;
468 virtual LogicalResult verifyAffineMapParams() = 0;
469
471 LogicalResult verifyTargetAttributesMatch(FuncDefOp target) {
472 LogicalResult aggregateRes = success();
473 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
474 auto emitAttrErr = [&](StringLiteral attrName) {
475 aggregateRes = callOp->emitOpError()
476 << "target '@" << target.getName() << "' has '" << attrName
477 << "' attribute, which is not specified by the caller '@" << caller.getName()
478 << '\'';
479 };
480
481 if (target.hasAllowConstraintAttr() && !caller.hasAllowConstraintAttr()) {
482 emitAttrErr(AllowConstraintAttr::name);
483 }
484 if (target.hasAllowWitnessAttr() && !caller.hasAllowWitnessAttr()) {
485 emitAttrErr(AllowWitnessAttr::name);
486 }
487 if (target.hasAllowNonNativeFieldOpsAttr() && !caller.hasAllowNonNativeFieldOpsAttr()) {
488 emitAttrErr(AllowNonNativeFieldOpsAttr::name);
489 }
490 }
491 return aggregateRes;
492 }
493
494 LogicalResult verifyNoAffineMapInstantiations() {
495 if (!isNullOrEmpty(callOp->getMapOpGroupSizesAttr())) {
496 // Tested in call_with_affinemap_fail.llzk
497 return callOp->emitOpError().append(
498 "can only have affine map instantiations when targeting a \"@", FUNC_NAME_COMPUTE,
499 "\" function"
500 );
501 }
502 // ASSERT: the check above is sufficient due to VerifySizesForMultiAffineOps trait.
503 assert(isNullOrEmpty(callOp->getNumDimsPerMapAttr()));
504 assert(callOp->getMapOperands().empty());
505 return success();
506 }
507};
508
509struct KnownTargetVerifier : public CallOpVerifier {
510 KnownTargetVerifier(CallOp *c, SymbolLookupResult<FuncDefOp> &&tgtRes)
511 : CallOpVerifier(c, tgtRes.get().getSymName()), tgt(*tgtRes), tgtType(tgt.getFunctionType()),
512 includeSymNames(tgtRes.getNamespace()) {}
513
514 LogicalResult verifyTargetAttributes() override {
515 return CallOpVerifier::verifyTargetAttributesMatch(tgt);
516 }
517
518 LogicalResult verifyInputs() override {
519 return verifyTypesMatch(callOp->getArgOperands().getTypes(), tgtType.getInputs(), "operand");
520 }
521
522 LogicalResult verifyOutputs() override {
523 return verifyTypesMatch(callOp->getResultTypes(), tgtType.getResults(), "result");
524 }
525
526 LogicalResult verifyAffineMapParams() override {
527 if ((FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) &&
528 isInStruct(tgt.getOperation())) {
529 // Return type should be a single StructType. If that is not the case here, just bail without
530 // producing an error. The combination of this KnownTargetVerifier resolving the callee to a
531 // specific FuncDefOp and verifyFuncTypeCompute() ensuring all FUNC_NAME_COMPUTE FuncOps have
532 // a single StructType return value will produce a more relevant error message in that case.
533 if (StructType retTy = callOp->getSingleResultTypeOfWitnessGen()) {
534 if (ArrayAttr params = retTy.getParams()) {
535 // Collect the struct parameters that are defined via AffineMapAttr
536 SmallVector<AffineMapAttr> mapAttrs;
537 for (Attribute a : params) {
538 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(a)) {
539 mapAttrs.push_back(m);
540 }
541 }
543 callOp->getMapOperands(), callOp->getNumDimsPerMap(), mapAttrs, *callOp
544 );
545 }
546 }
547 return success();
548 } else {
549 // Global functions and constrain functions cannot have affine map instantiations.
550 return verifyNoAffineMapInstantiations();
551 }
552 }
553
554private:
555 template <typename T>
556 LogicalResult
557 verifyTypesMatch(ValueTypeRange<T> callOpTypes, ArrayRef<Type> tgtTypes, const char *aspect) {
558 if (tgtTypes.size() != callOpTypes.size()) {
559 return callOp->emitOpError()
560 .append("incorrect number of ", aspect, "s for callee, expected ", tgtTypes.size())
561 .attachNote(tgt.getLoc())
562 .append("callee defined here");
563 }
564 for (unsigned i = 0, e = tgtTypes.size(); i != e; ++i) {
565 if (!typesUnify(callOpTypes[i], tgtTypes[i], includeSymNames)) {
566 return callOp->emitOpError().append(
567 aspect, " type mismatch: expected type ", tgtTypes[i], ", but found ", callOpTypes[i],
568 " for ", aspect, " number ", i
569 );
570 }
571 }
572 return success();
573 }
574
575 FuncDefOp tgt;
576 FunctionType tgtType;
577 std::vector<llvm::StringRef> includeSymNames;
578};
579
582LogicalResult checkSelfTypeUnknownTarget(
583 StringAttr expectedParamName, Type actualType, CallOp *origin, const char *aspect
584) {
585 if (!llvm::isa<TypeVarType>(actualType) ||
586 llvm::cast<TypeVarType>(actualType).getRefName() != expectedParamName) {
587 // Tested in function_restrictions_fail.llzk:
588 // Non-tvar for constrain input via "call_target_constrain_without_self_non_struct"
589 // Non-tvar for compute output via "call_target_compute_wrong_type_ret"
590 // Wrong tvar for constrain input via "call_target_constrain_without_self_wrong_tvar_param"
591 // Wrong tvar for compute output via "call_target_compute_wrong_tvar_param_ret"
592 return origin->emitOpError().append(
593 "target \"@", origin->getCallee().getLeafReference().getValue(), "\" expected ", aspect,
594 " type '!", TypeVarType::name, "<@", expectedParamName.getValue(), ">' but found ",
595 actualType
596 );
597 }
598 return success();
599}
600
610struct UnknownTargetVerifier : public CallOpVerifier {
611 UnknownTargetVerifier(CallOp *c, SymbolRefAttr callee)
612 : CallOpVerifier(c, callee.getLeafReference().getValue()), calleeAttr(callee) {}
613
614 LogicalResult verifyTargetAttributes() override {
615 // Based on the precondition of this verifier, the target must be either a
616 // struct compute, constrain, or product function.
617 LogicalResult aggregateRes = success();
618 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
619 auto emitAttrErr = [&](StringLiteral attrName) {
620 aggregateRes = callOp->emitOpError()
621 << "target '" << calleeAttr << "' has '" << attrName
622 << "' attribute, which is not specified by the caller '@" << caller.getName()
623 << '\'';
624 };
625
626 switch (tgtKind) {
628 if (!caller.hasAllowConstraintAttr()) {
629 emitAttrErr(AllowConstraintAttr::name);
630 }
631 break;
633 if (!caller.hasAllowWitnessAttr()) {
634 emitAttrErr(AllowWitnessAttr::name);
635 }
636 break;
638 if (!caller.hasAllowWitnessAttr()) {
639 emitAttrErr(AllowWitnessAttr::name);
640 }
641 if (!caller.hasAllowConstraintAttr()) {
642 emitAttrErr(AllowConstraintAttr::name);
643 }
644 break;
645 default:
646 break;
647 }
648 }
649 return aggregateRes;
650 }
651
652 LogicalResult verifyInputs() override {
653 if (FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) {
654 // Without known target, no additional checks can be done.
655 } else if (FunctionKind::StructConstrain == tgtKind) {
656 // Without known target, this can only check that the first input is VarType using the same
657 // struct parameter as the base of the callee (later replaced with the target struct's type).
658 Operation::operand_type_range inputTypes = callOp->getArgOperands().getTypes();
659 if (inputTypes.size() < 1) {
660 // Tested in function_restrictions_fail.llzk
661 return callOp->emitOpError()
662 << "target \"@" << FUNC_NAME_CONSTRAIN << "\" must have at least one input type";
663 }
664 return checkSelfTypeUnknownTarget(
665 calleeAttr.getRootReference(), inputTypes.front(), callOp, "first input"
666 );
667 }
668 return success();
669 }
670
671 LogicalResult verifyOutputs() override {
672 if (FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) {
673 // Without known target, this can only check that the function returns VarType using the same
674 // struct parameter as the base of the callee (later replaced with the target struct's type).
675 Operation::result_type_range resTypes = callOp->getResultTypes();
676 if (resTypes.size() != 1) {
677 // Tested in function_restrictions_fail.llzk
678 return callOp->emitOpError().append(
679 "target \"@", FUNC_NAME_COMPUTE, "\" must have exactly one return type"
680 );
681 }
682 return checkSelfTypeUnknownTarget(
683 calleeAttr.getRootReference(), resTypes.front(), callOp, "return"
684 );
685 } else if (FunctionKind::StructConstrain == tgtKind) {
686 // Without known target, this can only check that the function has no return
687 if (callOp->getNumResults() != 0) {
688 // Tested in function_restrictions_fail.llzk
689 return callOp->emitOpError()
690 << "target \"@" << FUNC_NAME_CONSTRAIN << "\" must have no return type";
691 }
692 }
693 return success();
694 }
695
696 LogicalResult verifyAffineMapParams() override {
697 if (FunctionKind::StructCompute == tgtKind || FunctionKind::StructProduct == tgtKind) {
698 // Without known target, no additional checks can be done.
699 } else if (FunctionKind::StructConstrain == tgtKind) {
700 // Without known target, this can only check that there are no affine map instantiations.
701 return verifyNoAffineMapInstantiations();
702 }
703 return success();
704 }
705
706private:
707 SymbolRefAttr calleeAttr;
708};
709
710} // namespace
711
712LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &tables) {
713 // First, verify symbol resolution in all input and output types.
714 if (failed(verifyTypeResolution(tables, *this, getCalleeType()))) {
715 return failure(); // verifyTypeResolution() already emits a sufficient error message
716 }
717
718 // Check that the callee attribute was specified.
719 SymbolRefAttr calleeAttr = getCalleeAttr();
720 if (!calleeAttr) {
721 return emitOpError("requires a 'callee' symbol reference attribute");
722 }
723
724 // If the callee references a parameter of the struct where this call appears, perform the subset
725 // of checks that can be done even though the target is unknown.
726 if (calleeAttr.getNestedReferences().size() == 1) {
727 FailureOr<StructDefOp> parent = getParentOfType<StructDefOp>(*this);
728 if (succeeded(parent) && parent->hasParamNamed(calleeAttr.getRootReference())) {
729 return UnknownTargetVerifier(this, calleeAttr).verify();
730 }
731 }
732
733 // Otherwise, callee must be specified via full path from the root module. Perform the full set of
734 // checks against the known target function.
735 auto tgtOpt = lookupTopLevelSymbol<FuncDefOp>(tables, calleeAttr, *this);
736 if (failed(tgtOpt)) {
737 return this->emitError() << "expected '" << FuncDefOp::getOperationName() << "' named \""
738 << calleeAttr << '"';
739 }
740 return KnownTargetVerifier(this, std::move(*tgtOpt)).verify();
741}
742
743FunctionType CallOp::getCalleeType() {
744 return FunctionType::get(getContext(), getArgOperands().getTypes(), getResultTypes());
745}
746
747namespace {
748
749bool calleeIsStructFunctionImpl(
750 const char *funcName, SymbolRefAttr callee, llvm::function_ref<StructType()> getType
751) {
752 if (callee.getLeafReference() == funcName) {
753 if (StructType t = getType()) {
754 // If the name ref within the StructType matches the `callee` prefix (i.e., sans the function
755 // name itself), then the `callee` target must be within a StructDefOp because validation
756 // checks elsewhere ensure that every StructType references a StructDefOp (i.e., the `callee`
757 // function is not simply a global function nested within a ModuleOp)
758 return t.getNameRef() == getPrefixAsSymbolRefAttr(callee);
759 }
760 }
761 return false;
762}
763
764} // namespace
765
767 return calleeIsStructFunctionImpl(FUNC_NAME_COMPUTE, getCallee(), [this]() {
768 return this->getSingleResultTypeOfCompute();
769 });
770}
771
773 return calleeIsStructFunctionImpl(FUNC_NAME_CONSTRAIN, getCallee(), [this]() {
774 return getAtIndex<StructType>(this->getArgOperands().getTypes(), 0);
775 });
776}
777
779 assert(calleeIsStructCompute());
780 return getResults().front();
781}
782
784 assert(calleeIsStructConstrain());
785 return getArgOperands().front();
786}
787
788FailureOr<SymbolLookupResult<FuncDefOp>> CallOp::getCalleeTarget(SymbolTableCollection &tables) {
789 Operation *thisOp = this->getOperation();
790 auto root = getRootModule(thisOp);
791 assert(succeeded(root));
792 return llzk::lookupSymbolIn<FuncDefOp>(tables, getCallee(), root->getOperation(), thisOp);
793}
794
796 assert(calleeIsCompute() && "violated implementation pre-condition");
797 return getIfSingleton<StructType>(getResultTypes());
798}
799
801 assert(calleeContainsWitnessGen() && "violated implementation pre-condition");
802 return getIfSingleton<StructType>(getResultTypes());
803}
804
806CallInterfaceCallable CallOp::getCallableForCallee() { return getCalleeAttr(); }
807
809void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
810 setCalleeAttr(llvm::cast<SymbolRefAttr>(callee));
811}
812
813SmallVector<ValueRange> CallOp::toVectorOfValueRange(OperandRangeRange input) {
814 llvm::SmallVector<ValueRange, 4> output;
815 output.reserve(input.size());
816 for (OperandRange r : input) {
817 output.push_back(r);
818 }
819 return output;
820}
821
822} // namespace llzk::function
This file defines methods for symbol lookup across LLZK operations and included files.
bool calleeContainsWitnessGen()
Return true iff the callee function can contain witness generation code (this does not check if the c...
Definition Ops.h.inc:335
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:772
::mlir::CallInterfaceCallable getCallableForCallee()
Return the callee of this operation.
Definition Ops.cpp:806
::mlir::FunctionType getCalleeType()
Definition Ops.cpp:743
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:795
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.h.inc:267
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:712
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:766
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:467
bool calleeIsCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE (this does not check if the callee func...
Definition Ops.h.inc:329
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::SymbolRefAttr callee, ::mlir::ValueRange argOperands={})
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:472
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:241
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e., the first argument operand) passed to the struct constrain call.
Definition Ops.cpp:783
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:245
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
Definition Ops.cpp:813
::llzk::component::StructType getSingleResultTypeOfWitnessGen()
Assuming the callee contains witness generation code, return the single StructType result.
Definition Ops.cpp:800
FoldAdaptor::Properties Properties
Definition Ops.h.inc:192
void setCalleeAttr(::mlir::SymbolRefAttr attr)
Definition Ops.h.inc:282
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e., the first result) produced by the struct compute call.
Definition Ops.cpp:778
void setCalleeFromCallable(::mlir::CallInterfaceCallable callee)
Set the callee for this operation.
Definition Ops.cpp:809
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
Definition Ops.cpp:788
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
Definition Ops.cpp:208
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:327
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e., the first operand of the terminating return op) of the compute function.
Definition Ops.cpp:354
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:952
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={}, ::llvm::ArrayRef<::mlir::DictionaryAttr > argAttrs={})
void setAllowNonNativeFieldOpsAttr(bool newValue=true)
Add (resp. remove) the allow_non_native_field_ops attribute to (resp. from) the function def.
Definition Ops.cpp:216
::mlir::StringAttr getFunctionTypeAttrName()
Definition Ops.h.inc:559
bool hasAllowNonNativeFieldOpsAttr()
Return true iff the function def has the allow_non_native_field_ops attribute.
Definition Ops.h.inc:737
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:373
void print(::mlir::OpAsmPrinter &p)
Definition Ops.cpp:119
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:778
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
Definition Ops.h.inc:729
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:378
::mlir::StringAttr getResAttrsAttrName()
Definition Ops.h.inc:567
::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result)
Definition Ops.cpp:108
void cloneInto(FuncDefOp dest, ::mlir::IRMapping &mapper)
Clone the internal blocks and attributes from this function into dest.
Definition Ops.cpp:128
bool nameIsProduct()
Return true iff the function name is FUNC_NAME_PRODUCT (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:786
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
Definition Ops.h.inc:782
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:792
bool isInStruct()
Return true iff the function is within a StructDefOp.
Definition Ops.h.inc:789
::llvm::ArrayRef<::mlir::Type > getResultTypes()
Returns the result types of this function.
Definition Ops.h.inc:760
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:583
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
Definition Ops.cpp:200
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
Definition Ops.cpp:344
bool hasArgPublicAttr(unsigned index)
Return true iff the argument at the given index has pub attribute.
Definition Ops.cpp:224
::llvm::LogicalResult verify()
Definition Ops.cpp:234
::mlir::Region & getBody()
Definition Ops.h.inc:607
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
Definition Ops.h.inc:721
::mlir::StringAttr getArgAttrsAttrName()
Definition Ops.h.inc:551
::llvm::LogicalResult verify()
Definition Ops.cpp:387
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:893
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:27
OpClass::Properties & buildInstantiationAttrs(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
OpClass::Properties & buildInstantiationAttrsEmpty(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
bool isInStruct(Operation *op)
Definition Ops.cpp:45
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
Definition Ops.cpp:112
FunctionKind fnNameToKind(mlir::StringRef name)
Given a function name, return the corresponding FunctionKind.
Definition Ops.cpp:41
FunctionKind
Kinds of functions in LLZK.
Definition Ops.h:32
@ StructConstrain
Function within a struct named FUNC_NAME_CONSTRAIN.
Definition Ops.h:36
@ StructProduct
Function within a struct named FUNC_NAME_PRODUCT.
Definition Ops.h:38
@ StructCompute
Function within a struct named FUNC_NAME_COMPUTE.
Definition Ops.h:34
@ Free
Function that is not within a struct.
Definition Ops.h:40
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:16
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:253
FailureOr< ModuleOp > getRootModule(Operation *from)
bool isNullOrEmpty(mlir::ArrayAttr a)
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:18
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:257
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:45
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
OpClass delegate_to_build(mlir::Location location, Args &&...args)
bool hasAffineMapAttr(Type type)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
Definition TypeHelper.h:111
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)