Skip to content

Commit b1f1125

Browse files
Merge pull request #1496 from xlsynth:cdleary/fix-1473-parametric-fn-slot
PiperOrigin-RevId: 645440237
2 parents 1289e7c + 5415cf6 commit b1f1125

File tree

10 files changed

+252
-30
lines changed

10 files changed

+252
-30
lines changed

xls/dslx/constexpr_evaluator.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,7 @@ absl::StatusOr<absl::flat_hash_map<std::string, InterpValue>> MakeConstexprEnv(
727727
}
728728

729729
// Collect all the freevars that are constexpr.
730-
FreeVariables freevars = GetFreeVariables(node);
730+
FreeVariables freevars = GetFreeVariablesByPos(node);
731731
VLOG(5) << "freevar count for `" << node->ToString()
732732
<< "`: " << freevars.GetFreeVariableCount();
733733
freevars = freevars.DropBuiltinDefs();

xls/dslx/frontend/ast.cc

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,13 @@ absl::StatusOr<TypeDefinition> ToTypeDefinition(AstNode* node) {
339339
node->GetNodeTypeName(), node->ToString()));
340340
}
341341

342+
const Span& FreeVariables::GetFirstNameRefSpan(
343+
std::string_view identifier) const {
344+
std::vector<const NameRef*> name_refs = values_.at(identifier);
345+
CHECK(!name_refs.empty());
346+
return name_refs.at(0)->span();
347+
}
348+
342349
FreeVariables FreeVariables::DropBuiltinDefs() const {
343350
FreeVariables result;
344351
for (const auto& [identifier, name_refs] : values_) {
@@ -401,28 +408,33 @@ absl::flat_hash_set<std::string> FreeVariables::Keys() const {
401408
return result;
402409
}
403410

404-
FreeVariables GetFreeVariables(const AstNode* node, const Pos* start_pos) {
411+
FreeVariables GetFreeVariablesByLambda(
412+
const AstNode* node,
413+
const std::function<bool(const NameRef&)>& consider_free) {
405414
DfsIteratorNoTypes it(node);
406415
FreeVariables freevars;
407416
while (it.HasNext()) {
408417
const AstNode* n = it.Next();
409418
if (const auto* name_ref = dynamic_cast<const NameRef*>(n)) {
410-
// If a start position was given we test whether the name definition
411-
// occurs before that start position. (If none was given we accept all
412-
// name refs.)
413-
if (start_pos == nullptr) {
419+
if (consider_free == nullptr || consider_free(*name_ref)) {
414420
freevars.Add(name_ref->identifier(), name_ref);
415-
} else {
416-
std::optional<Pos> name_def_start = name_ref->GetNameDefStart();
417-
if (!name_def_start.has_value() || *name_def_start < *start_pos) {
418-
freevars.Add(name_ref->identifier(), name_ref);
419-
}
420421
}
421422
}
422423
}
423424
return freevars;
424425
}
425426

427+
FreeVariables GetFreeVariablesByPos(const AstNode* node, const Pos* start_pos) {
428+
std::function<bool(const NameRef&)> consider_free = nullptr;
429+
if (start_pos != nullptr) {
430+
consider_free = [start_pos](const NameRef& name_ref) {
431+
std::optional<Pos> name_def_start = name_ref.GetNameDefStart();
432+
return !name_def_start.has_value() || name_def_start.value() < *start_pos;
433+
};
434+
}
435+
return GetFreeVariablesByLambda(node, consider_free);
436+
}
437+
426438
std::string BuiltinTypeToString(BuiltinType t) {
427439
switch (t) {
428440
#define CASE(__enum, B, __str, ...) \

xls/dslx/frontend/ast.h

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,21 @@ class FreeVariables {
213213
// references, but the number of free variables).
214214
int64_t GetFreeVariableCount() const { return values_.size(); }
215215

216+
// Returns the span of the first `NameRef` that is referring to `identifier`
217+
// in this free variables set.
218+
const Span& GetFirstNameRefSpan(std::string_view identifier) const;
219+
216220
private:
217221
absl::flat_hash_map<std::string, std::vector<const NameRef*>> values_;
218222
};
219223

224+
// Generalized form of `GetFreeVariablesByPos()` below -- takes a lambda that
225+
// helps us determine if a `NameRef` present in the `node` should be considered
226+
// free.
227+
FreeVariables GetFreeVariablesByLambda(
228+
const AstNode* node,
229+
const std::function<bool(const NameRef&)>& consider_free = nullptr);
230+
220231
// Retrieves all the free variables (references to names that are defined
221232
// prior to start_pos) that are transitively in this AST subtree.
222233
//
@@ -225,10 +236,18 @@ class FreeVariables {
225236
// const FOO = u32:42;
226237
// fn main(x: u32) { FOO+x }
227238
//
228-
// And using the starting point of the function as the start_pos, the FOO will
229-
// be flagged as a free variable and returned.
230-
FreeVariables GetFreeVariables(const AstNode* node,
231-
const Pos* start_pos = nullptr);
239+
// And *using the starting point of the function* as the `start_pos`, the `FOO`
240+
// will be flagged as a free variable and returned.
241+
//
242+
// Note: the start_pos given is a way to approximate "free variable with
243+
// respect to this AST construct". i.e. all the references with defs that are
244+
// defined before this start_pos point are considered free. This gives an easy
245+
// way to say "everything defined inside the body we don't need to worry about
246+
// -- only tell me about references to things before this lexical position in
247+
// the file" -- "lexical position in the file" is an approximation for
248+
// "everything defined outside of (before) this AST construct".
249+
FreeVariables GetFreeVariablesByPos(const AstNode* node,
250+
const Pos* start_pos = nullptr);
232251

233252
// Analogous to ToAstNode(), but for Expr base.
234253
template <typename... Types>
@@ -1708,6 +1727,14 @@ class Function : public AstNode {
17081727
void set_proc(Proc* proc) { proc_ = proc; }
17091728
bool IsInProc() const { return proc_.has_value(); }
17101729

1730+
std::optional<Span> GetParametricBindingsSpan() const {
1731+
if (parametric_bindings_.empty()) {
1732+
return std::nullopt;
1733+
}
1734+
return Span(parametric_bindings_.front()->span().start(),
1735+
parametric_bindings_.back()->span().limit());
1736+
}
1737+
17111738
private:
17121739
Span span_;
17131740
NameDef* name_def_;
@@ -2279,6 +2306,14 @@ class StructDef : public AstNode {
22792306
return extern_type_name_;
22802307
}
22812308

2309+
std::optional<Span> GetParametricBindingsSpan() const {
2310+
if (parametric_bindings_.empty()) {
2311+
return std::nullopt;
2312+
}
2313+
return Span(parametric_bindings_.front()->span().start(),
2314+
parametric_bindings_.back()->span().limit());
2315+
}
2316+
22822317
private:
22832318
Span span_;
22842319
NameDef* name_def_;

xls/dslx/frontend/parser_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1527,7 +1527,7 @@ TEST_F(ParserTest, MatchFreevars) {
15271527
y => z,
15281528
})",
15291529
{"x", "y", "z"});
1530-
FreeVariables fv = GetFreeVariables(e, &e->span().start());
1530+
FreeVariables fv = GetFreeVariablesByPos(e, &e->span().start());
15311531
EXPECT_THAT(fv.Keys(), testing::ContainerEq(
15321532
absl::flat_hash_set<std::string>{"x", "y", "z"}));
15331533
}
@@ -1538,7 +1538,7 @@ TEST_F(ParserTest, ForFreevars) {
15381538
new_accum
15391539
}(u32:0))",
15401540
{"range", "j"});
1541-
FreeVariables fv = GetFreeVariables(e, &e->span().start());
1541+
FreeVariables fv = GetFreeVariablesByPos(e, &e->span().start());
15421542
EXPECT_THAT(fv.Keys(), testing::ContainerEq(
15431543
absl::flat_hash_set<std::string>{"j", "range"}));
15441544
}

xls/dslx/ir_convert/function_converter.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,7 +1282,7 @@ absl::Status FunctionConverter::HandleFor(const For* node) {
12821282
// So we suffix free variables for the function body onto the function
12831283
// parameters.
12841284
FreeVariables freevars =
1285-
GetFreeVariables(node->body(), &node->span().start());
1285+
GetFreeVariablesByPos(node->body(), &node->span().start());
12861286
freevars = freevars.DropBuiltinDefs();
12871287
std::vector<const NameDef*> relevant_name_defs;
12881288
for (const auto& any_name_def : freevars.GetNameDefs()) {
@@ -3505,7 +3505,7 @@ absl::StatusOr<Value> InterpValueToValue(const InterpValue& iv) {
35053505
absl::StatusOr<std::vector<ConstantDef*>> GetConstantDepFreevars(
35063506
AstNode* node) {
35073507
Span span = node->GetSpan().value();
3508-
FreeVariables free_variables = GetFreeVariables(node, &span.start());
3508+
FreeVariables free_variables = GetFreeVariablesByPos(node, &span.start());
35093509
std::vector<std::pair<std::string, AnyNameDef>> freevars =
35103510
free_variables.GetNameDefTuples();
35113511
std::vector<ConstantDef*> constant_deps;

xls/dslx/tests/errors/error_modules_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,6 +1181,13 @@ def test_width_slice_of_non_type_size(self):
11811181
'Expected type-reference to refer to a type definition', stderr
11821182
)
11831183

1184+
def test_gh_1473(self):
1185+
stderr = self._run('xls/dslx/tests/errors/gh_1473.x')
1186+
self.assertIn(
1187+
'Parametric expression `umax(MAX_N_M, V)` refered to `V`'
1188+
' which is not present in the parametric environment', stderr
1189+
)
1190+
11841191

11851192
if __name__ == '__main__':
11861193
test_base.main()

xls/dslx/tests/errors/gh_1473.x

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright 2024 The XLS Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import float32;
16+
17+
pub fn umax<N: u32>(x: uN[N], y: uN[N]) -> uN[N] { if x > y { x } else { y } }
18+
19+
pub fn uadd_with_overflow
20+
<V: u32,
21+
N: u32,
22+
M: u32,
23+
MAX_N_M: u32 = {umax(N, M)},
24+
MAX_N_M_V: u32 = {umax(MAX_N_M, V)}>
25+
(x: uN[N], y: uN[M]) -> (bool, uN[V]) {
26+
27+
let x_extended = widening_cast<uN[MAX_N_M_V + u32:1]>(x);
28+
let y_extended = widening_cast<uN[MAX_N_M_V + u32:1]>(y);
29+
30+
let full_result: uN[MAX_N_M_V + u32:1] = x_extended + y_extended;
31+
let narrowed_result = full_result as uN[V];
32+
let overflow_detected = or_reduce(full_result[V as s32:]);
33+
34+
(overflow_detected, narrowed_result)
35+
}
36+
37+
pub fn double_fraction_carry(f: float32::F32) -> (uN[float32::F32_FRACTION_SZ], u1) {
38+
let f = f.fraction as uN[float32::F32_FRACTION_SZ + u32:1];
39+
let (overflow, f_x2) = uadd_with_overflow(f, f);
40+
(f_x2, overflow)
41+
}

xls/dslx/type_system/parametric_instantiator_internal.cc

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,15 @@ absl::Status EagerlyPopulateParametricEnvMap(
144144
absl::Span<const ParametricWithType> typed_parametrics,
145145
const absl::flat_hash_map<std::string, Expr*>& parametric_default_exprs,
146146
absl::flat_hash_map<std::string, InterpValue>& parametric_env_map,
147-
const Span& span, std::string_view kind_name, DeduceCtx* ctx) {
147+
const std::optional<Span>& parametrics_span, const Span& span,
148+
std::string_view kind_name, DeduceCtx* ctx) {
149+
// If there are no parametric bindings being instantiated (in the callee) we
150+
// should have already typechecked that there are no parametrics being
151+
// applied by the caller.
152+
if (!parametrics_span.has_value()) {
153+
XLS_RET_CHECK(typed_parametrics.empty());
154+
}
155+
148156
// Attempt to interpret the parametric "default expressions" in order.
149157
for (const ParametricWithType& typed_parametric : typed_parametrics) {
150158
std::string_view name = typed_parametric.identifier();
@@ -180,6 +188,38 @@ absl::Status EagerlyPopulateParametricEnvMap(
180188
VLOG(5) << absl::StreamFormat("Evaluating expr: `%s` in env: %s",
181189
expr->ToString(), env.ToString());
182190

191+
// If the expression requires variables that were not bound in the
192+
// environment yet, we give an error message.
193+
FreeVariables freevars =
194+
GetFreeVariablesByLambda(expr, [&](const NameRef& name_ref) {
195+
// If the name def is in the parametrics span, we consider it a
196+
// freevar.
197+
AnyNameDef any_name_def = name_ref.name_def();
198+
if (!std::holds_alternative<const NameDef*>(any_name_def)) {
199+
return false;
200+
}
201+
const NameDef* name_def = std::get<const NameDef*>(any_name_def);
202+
Span definition_span = name_def->GetSpan().value();
203+
if (definition_span.filename() != parametrics_span->filename()) {
204+
return false;
205+
}
206+
return parametrics_span->Contains(definition_span);
207+
});
208+
209+
auto keys_unsorted = freevars.Keys();
210+
absl::btree_set<std::string> keys_sorted(keys_unsorted.begin(),
211+
keys_unsorted.end());
212+
for (const std::string& key : keys_sorted) {
213+
if (!parametric_env_map.contains(key)) {
214+
return TypeInferenceErrorStatus(
215+
freevars.GetFirstNameRefSpan(key), nullptr,
216+
absl::StrFormat(
217+
"Parametric expression `%s` refered to `%s` which is not "
218+
"present in the parametric environment; instantiated from %s",
219+
expr->ToString(), key, span.ToString()));
220+
}
221+
}
222+
183223
absl::StatusOr<InterpValue> result = InterpretExpr(ctx, expr, env);
184224

185225
VLOG(5) << "Interpreted expr: " << expr->ToString() << " @ " << expr->span()
@@ -209,18 +249,36 @@ absl::Status EagerlyPopulateParametricEnvMap(
209249
parametric_env_map.insert({std::string{name}, result.value()});
210250
}
211251
}
252+
253+
// TODO(https://github.com/google/xls/issues/1495): 2024-06-18 We would like
254+
// to enable this invariant to tighten up what is accepted by the type
255+
// system, but that requires some investigation into failing samples.
256+
if (false) {
257+
// Check that all parametric bindings are present in the env.
258+
for (const auto& [parametric_binding_name, _] : parametric_default_exprs) {
259+
if (!parametric_env_map.contains(parametric_binding_name)) {
260+
return TypeInferenceErrorStatus(
261+
span, nullptr,
262+
absl::StrFormat("Caller did not supply parametric value for `%s`",
263+
parametric_binding_name));
264+
}
265+
}
266+
}
267+
212268
return absl::OkStatus();
213269
}
214270

215271
} // namespace
216272

217273
ParametricInstantiator::ParametricInstantiator(
218-
Span span, absl::Span<const InstantiateArg> args, DeduceCtx* ctx,
274+
Span span, std::optional<Span> parametrics_span,
275+
absl::Span<const InstantiateArg> args, DeduceCtx* ctx,
219276
absl::Span<const ParametricWithType> typed_parametrics,
220277
const absl::flat_hash_map<std::string, InterpValue>& explicit_parametrics,
221278
absl::Span<absl::Nonnull<const ParametricBinding*> const>
222279
parametric_bindings)
223280
: span_(std::move(span)),
281+
parametrics_span_(std::move(parametrics_span)),
224282
args_(args),
225283
ctx_(ABSL_DIE_IF_NULL(ctx)),
226284
typed_parametrics_(typed_parametrics),
@@ -339,7 +397,7 @@ absl::StatusOr<TypeAndParametricEnv> FunctionInstantiator::Instantiate() {
339397

340398
XLS_RETURN_IF_ERROR(EagerlyPopulateParametricEnvMap(
341399
typed_parametrics(), parametric_default_exprs(), parametric_env_map(),
342-
span(), GetKindName(), &ctx()));
400+
parametrics_span(), span(), GetKindName(), &ctx()));
343401

344402
// Phase 2: resolve and check.
345403
VLOG(10) << "Phase 2: resolve-and-check";
@@ -404,7 +462,7 @@ absl::StatusOr<TypeAndParametricEnv> StructInstantiator::Instantiate() {
404462

405463
XLS_RETURN_IF_ERROR(EagerlyPopulateParametricEnvMap(
406464
typed_parametrics(), parametric_default_exprs(), parametric_env_map(),
407-
span(), GetKindName(), &ctx()));
465+
parametrics_span(), span(), GetKindName(), &ctx()));
408466

409467
// Phase 2: resolve and check.
410468
for (int64_t i = 0; i < member_types_.size(); ++i) {

0 commit comments

Comments
 (0)