LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
WitnessSelection.cpp
Go to the documentation of this file.
1//===-- WitnessSelection.cpp ------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2026 Project LLZK
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#include "WitnessSelection.h"
11
16
17#include <mlir/IR/Operation.h>
18
19using namespace mlir;
20
21namespace llzk::witgen {
22namespace {
23
25static FailureOr<bool>
26typeContainsSignals(Type type, SymbolTableCollection &tables, Operation *origin);
27
29static FailureOr<bool> structContainsSignals(
30 component::StructDefOp def, SymbolTableCollection &tables, Operation *origin
31) {
32 for (component::MemberDefOp member : def.getMemberDefs()) {
33 if (memberIsSignal(def, member)) {
34 return true;
35 }
36 auto nested = typeContainsSignals(member.getType(), tables, origin);
37 if (failed(nested)) {
38 return failure();
39 }
40 if (*nested) {
41 return true;
42 }
43 }
44 return false;
45}
46
48static FailureOr<bool>
49typeContainsSignals(Type type, SymbolTableCollection &tables, Operation *origin) {
50 if (auto structType = dyn_cast<component::StructType>(type)) {
51 auto defLookup = structType.getDefinition(tables, origin);
52 if (failed(defLookup)) {
53 return failure();
54 }
55 return structContainsSignals(defLookup->get(), tables, origin);
56 }
57 return false;
58}
59
61static LogicalResult appendSignalLeafBindings(
62 Type type, ArrayRef<std::string> prefix, SmallVectorImpl<OutputBinding> &out, Operation *origin
63) {
64 if (isa<felt::FeltType, array::ArrayType>(type)) {
65 out.push_back(
66 OutputBinding {llvm::SmallVector<std::string>(prefix.begin(), prefix.end()), type}
67 );
68 return success();
69 }
70
71 if (auto podType = dyn_cast<pod::PodType>(type)) {
72 for (pod::RecordAttr record : podType.getRecords()) {
73 llvm::SmallVector<std::string> path(prefix.begin(), prefix.end());
74 path.push_back(record.getName().getValue().str());
75 if (failed(appendSignalLeafBindings(record.getType(), path, out, origin))) {
76 return failure();
77 }
78 }
79 return success();
80 }
81
82 origin->emitError("signal members in llzk-witgen must be felts, felt arrays, or PODs of felts");
83 return failure();
84}
85
87static LogicalResult appendStructSignalBindings(
88 component::StructDefOp def, SymbolTableCollection &tables, Operation *origin,
89 SmallVectorImpl<OutputBinding> &out, ArrayRef<std::string> prefix = {}
90) {
91 for (component::MemberDefOp member : def.getMemberDefs()) {
92 llvm::SmallVector<std::string> path(prefix.begin(), prefix.end());
93 path.push_back(member.getSymName().str());
94
95 if (memberIsSignal(def, member)) {
96 if (failed(appendSignalLeafBindings(member.getType(), path, out, origin))) {
97 return failure();
98 }
99 continue;
100 }
101
102 auto nested = typeContainsSignals(member.getType(), tables, origin);
103 if (failed(nested)) {
104 return failure();
105 }
106 if (!*nested) {
107 continue;
108 }
109
110 auto structType = dyn_cast<component::StructType>(member.getType());
111 if (!structType) {
112 member.emitError("non-struct signal container is unsupported in llzk-witgen");
113 return failure();
114 }
115 auto defLookup = structType.getDefinition(tables, origin);
116 if (failed(defLookup)) {
117 return failure();
118 }
119 if (failed(appendStructSignalBindings(defLookup->get(), tables, origin, out, path))) {
120 return failure();
121 }
122 }
123 return success();
124}
125
127static void
128insertLeafJSON(llvm::json::Object &root, ArrayRef<std::string> path, llvm::json::Value value) {
129 if (path.empty()) {
130 return;
131 }
132 if (path.size() == 1) {
133 root[path.front()] = std::move(value);
134 return;
135 }
136
137 llvm::json::Value *slot = &root[path.front()];
138 if (!slot->getAsObject()) {
139 *slot = llvm::json::Object();
140 }
141 insertLeafJSON(*slot->getAsObject(), path.drop_front(), std::move(value));
142}
143
144} // namespace
145
148 return member.getSignal() || (owner.isMainComponent() && member.hasPublicAttr());
149}
150
152llvm::SmallVector<InputBinding> collectInputBindings(function::FuncDefOp computeFunc) {
153 llvm::SmallVector<InputBinding> bindings;
154 bindings.reserve(computeFunc.getNumArguments());
155 for (unsigned i = 0; i < computeFunc.getNumArguments(); ++i) {
156 std::string name;
157 if (std::optional<StringAttr> argName = computeFunc.getArgNameAttr(i)) {
158 name = argName->getValue().str();
159 } else {
160 name = "arg" + std::to_string(i);
161 }
162 bindings.push_back(InputBinding {std::move(name), computeFunc.getArgumentTypes()[i], i});
163 }
164 return bindings;
165}
166
168FailureOr<llvm::SmallVector<OutputBinding>> collectOutputBindings(
169 component::StructDefOp mainDef, SymbolTableCollection &tables, Operation *origin,
170 OutputScope scope
171) {
172 llvm::SmallVector<OutputBinding> bindings;
173 if (scope == OutputScope::Public) {
174 for (component::MemberDefOp member : mainDef.getMemberDefs()) {
175 if (!member.hasPublicAttr()) {
176 continue;
177 }
178 bindings.push_back(OutputBinding {{member.getSymName().str()}, member.getType()});
179 }
180 return bindings;
181 }
182
183 if (failed(appendStructSignalBindings(mainDef, tables, origin, bindings))) {
184 return failure();
185 }
186 return bindings;
187}
188
190llvm::json::Value buildSignalsJSONObject(
191 ArrayRef<OutputBinding> bindings, ArrayRef<llvm::json::Value> serializedLeaves
192) {
193 llvm::json::Object result;
194 for (auto [binding, leaf] : llvm::zip(bindings, serializedLeaves)) {
195 insertLeafJSON(result, binding.path, llvm::json::Value(leaf));
196 }
197 return llvm::json::Value(std::move(result));
198}
199
200} // namespace llzk::witgen
::std::vector< MemberDefOp > getMemberDefs()
Get all MemberDefOp in this structure.
Definition Ops.cpp:458
bool isMainComponent()
Return true iff this struct.def is the main struct. See llzk::MAIN_ATTR_NAME.
Definition Ops.cpp:480
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Required by FunctionOpInterface.
Definition Ops.h.inc:850
::std::optional<::mlir::StringAttr > getArgNameAttr(unsigned index)
Return the function.arg_name attribute for the argument at the given index.
Definition Ops.cpp:297
llvm::SmallVector< InputBinding > collectInputBindings(function::FuncDefOp computeFunc)
Collect stable JSON bindings for the main compute inputs.
OutputScope
Select the JSON scope emitted by llzk-witgen.
FailureOr< llvm::SmallVector< OutputBinding > > collectOutputBindings(component::StructDefOp mainDef, SymbolTableCollection &tables, Operation *origin, OutputScope scope)
Collect the selected output bindings for the requested scope.
llvm::json::Value buildSignalsJSONObject(ArrayRef< OutputBinding > bindings, ArrayRef< llvm::json::Value > serializedLeaves)
Assemble a nested JSON object from selected witness leaves.
bool memberIsSignal(component::StructDefOp owner, component::MemberDefOp member)
Return true iff the member is considered a witness signal.
Describe one JSON-visible main input binding.
Describe one selected witness output leaf.