Skip to content

Commit

Permalink
Merge pull request #1496 from xlsynth:cdleary/fix-1473-parametric-fn-…
Browse files Browse the repository at this point in the history
…slot

PiperOrigin-RevId: 645440237
  • Loading branch information
copybara-github committed Jun 21, 2024
2 parents 1289e7c + 5415cf6 commit b1f1125
Show file tree
Hide file tree
Showing 10 changed files with 252 additions and 30 deletions.
2 changes: 1 addition & 1 deletion xls/dslx/constexpr_evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ absl::StatusOr<absl::flat_hash_map<std::string, InterpValue>> MakeConstexprEnv(
}

// Collect all the freevars that are constexpr.
FreeVariables freevars = GetFreeVariables(node);
FreeVariables freevars = GetFreeVariablesByPos(node);
VLOG(5) << "freevar count for `" << node->ToString()
<< "`: " << freevars.GetFreeVariableCount();
freevars = freevars.DropBuiltinDefs();
Expand Down
32 changes: 22 additions & 10 deletions xls/dslx/frontend/ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,13 @@ absl::StatusOr<TypeDefinition> ToTypeDefinition(AstNode* node) {
node->GetNodeTypeName(), node->ToString()));
}

const Span& FreeVariables::GetFirstNameRefSpan(
std::string_view identifier) const {
std::vector<const NameRef*> name_refs = values_.at(identifier);
CHECK(!name_refs.empty());
return name_refs.at(0)->span();
}

FreeVariables FreeVariables::DropBuiltinDefs() const {
FreeVariables result;
for (const auto& [identifier, name_refs] : values_) {
Expand Down Expand Up @@ -401,28 +408,33 @@ absl::flat_hash_set<std::string> FreeVariables::Keys() const {
return result;
}

FreeVariables GetFreeVariables(const AstNode* node, const Pos* start_pos) {
FreeVariables GetFreeVariablesByLambda(
const AstNode* node,
const std::function<bool(const NameRef&)>& consider_free) {
DfsIteratorNoTypes it(node);
FreeVariables freevars;
while (it.HasNext()) {
const AstNode* n = it.Next();
if (const auto* name_ref = dynamic_cast<const NameRef*>(n)) {
// If a start position was given we test whether the name definition
// occurs before that start position. (If none was given we accept all
// name refs.)
if (start_pos == nullptr) {
if (consider_free == nullptr || consider_free(*name_ref)) {
freevars.Add(name_ref->identifier(), name_ref);
} else {
std::optional<Pos> name_def_start = name_ref->GetNameDefStart();
if (!name_def_start.has_value() || *name_def_start < *start_pos) {
freevars.Add(name_ref->identifier(), name_ref);
}
}
}
}
return freevars;
}

FreeVariables GetFreeVariablesByPos(const AstNode* node, const Pos* start_pos) {
std::function<bool(const NameRef&)> consider_free = nullptr;
if (start_pos != nullptr) {
consider_free = [start_pos](const NameRef& name_ref) {
std::optional<Pos> name_def_start = name_ref.GetNameDefStart();
return !name_def_start.has_value() || name_def_start.value() < *start_pos;
};
}
return GetFreeVariablesByLambda(node, consider_free);
}

std::string BuiltinTypeToString(BuiltinType t) {
switch (t) {
#define CASE(__enum, B, __str, ...) \
Expand Down
43 changes: 39 additions & 4 deletions xls/dslx/frontend/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,21 @@ class FreeVariables {
// references, but the number of free variables).
int64_t GetFreeVariableCount() const { return values_.size(); }

// Returns the span of the first `NameRef` that is referring to `identifier`
// in this free variables set.
const Span& GetFirstNameRefSpan(std::string_view identifier) const;

private:
absl::flat_hash_map<std::string, std::vector<const NameRef*>> values_;
};

// Generalized form of `GetFreeVariablesByPos()` below -- takes a lambda that
// helps us determine if a `NameRef` present in the `node` should be considered
// free.
FreeVariables GetFreeVariablesByLambda(
const AstNode* node,
const std::function<bool(const NameRef&)>& consider_free = nullptr);

// Retrieves all the free variables (references to names that are defined
// prior to start_pos) that are transitively in this AST subtree.
//
Expand All @@ -225,10 +236,18 @@ class FreeVariables {
// const FOO = u32:42;
// fn main(x: u32) { FOO+x }
//
// And using the starting point of the function as the start_pos, the FOO will
// be flagged as a free variable and returned.
FreeVariables GetFreeVariables(const AstNode* node,
const Pos* start_pos = nullptr);
// And *using the starting point of the function* as the `start_pos`, the `FOO`
// will be flagged as a free variable and returned.
//
// Note: the start_pos given is a way to approximate "free variable with
// respect to this AST construct". i.e. all the references with defs that are
// defined before this start_pos point are considered free. This gives an easy
// way to say "everything defined inside the body we don't need to worry about
// -- only tell me about references to things before this lexical position in
// the file" -- "lexical position in the file" is an approximation for
// "everything defined outside of (before) this AST construct".
FreeVariables GetFreeVariablesByPos(const AstNode* node,
const Pos* start_pos = nullptr);

// Analogous to ToAstNode(), but for Expr base.
template <typename... Types>
Expand Down Expand Up @@ -1708,6 +1727,14 @@ class Function : public AstNode {
void set_proc(Proc* proc) { proc_ = proc; }
bool IsInProc() const { return proc_.has_value(); }

std::optional<Span> GetParametricBindingsSpan() const {
if (parametric_bindings_.empty()) {
return std::nullopt;
}
return Span(parametric_bindings_.front()->span().start(),
parametric_bindings_.back()->span().limit());
}

private:
Span span_;
NameDef* name_def_;
Expand Down Expand Up @@ -2279,6 +2306,14 @@ class StructDef : public AstNode {
return extern_type_name_;
}

std::optional<Span> GetParametricBindingsSpan() const {
if (parametric_bindings_.empty()) {
return std::nullopt;
}
return Span(parametric_bindings_.front()->span().start(),
parametric_bindings_.back()->span().limit());
}

private:
Span span_;
NameDef* name_def_;
Expand Down
4 changes: 2 additions & 2 deletions xls/dslx/frontend/parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1527,7 +1527,7 @@ TEST_F(ParserTest, MatchFreevars) {
y => z,
})",
{"x", "y", "z"});
FreeVariables fv = GetFreeVariables(e, &e->span().start());
FreeVariables fv = GetFreeVariablesByPos(e, &e->span().start());
EXPECT_THAT(fv.Keys(), testing::ContainerEq(
absl::flat_hash_set<std::string>{"x", "y", "z"}));
}
Expand All @@ -1538,7 +1538,7 @@ TEST_F(ParserTest, ForFreevars) {
new_accum
}(u32:0))",
{"range", "j"});
FreeVariables fv = GetFreeVariables(e, &e->span().start());
FreeVariables fv = GetFreeVariablesByPos(e, &e->span().start());
EXPECT_THAT(fv.Keys(), testing::ContainerEq(
absl::flat_hash_set<std::string>{"j", "range"}));
}
Expand Down
4 changes: 2 additions & 2 deletions xls/dslx/ir_convert/function_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,7 @@ absl::Status FunctionConverter::HandleFor(const For* node) {
// So we suffix free variables for the function body onto the function
// parameters.
FreeVariables freevars =
GetFreeVariables(node->body(), &node->span().start());
GetFreeVariablesByPos(node->body(), &node->span().start());
freevars = freevars.DropBuiltinDefs();
std::vector<const NameDef*> relevant_name_defs;
for (const auto& any_name_def : freevars.GetNameDefs()) {
Expand Down Expand Up @@ -3505,7 +3505,7 @@ absl::StatusOr<Value> InterpValueToValue(const InterpValue& iv) {
absl::StatusOr<std::vector<ConstantDef*>> GetConstantDepFreevars(
AstNode* node) {
Span span = node->GetSpan().value();
FreeVariables free_variables = GetFreeVariables(node, &span.start());
FreeVariables free_variables = GetFreeVariablesByPos(node, &span.start());
std::vector<std::pair<std::string, AnyNameDef>> freevars =
free_variables.GetNameDefTuples();
std::vector<ConstantDef*> constant_deps;
Expand Down
7 changes: 7 additions & 0 deletions xls/dslx/tests/errors/error_modules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,13 @@ def test_width_slice_of_non_type_size(self):
'Expected type-reference to refer to a type definition', stderr
)

def test_gh_1473(self):
stderr = self._run('xls/dslx/tests/errors/gh_1473.x')
self.assertIn(
'Parametric expression `umax(MAX_N_M, V)` refered to `V`'
' which is not present in the parametric environment', stderr
)


if __name__ == '__main__':
test_base.main()
41 changes: 41 additions & 0 deletions xls/dslx/tests/errors/gh_1473.x
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright 2024 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import float32;

pub fn umax<N: u32>(x: uN[N], y: uN[N]) -> uN[N] { if x > y { x } else { y } }

pub fn uadd_with_overflow
<V: u32,
N: u32,
M: u32,
MAX_N_M: u32 = {umax(N, M)},
MAX_N_M_V: u32 = {umax(MAX_N_M, V)}>
(x: uN[N], y: uN[M]) -> (bool, uN[V]) {

let x_extended = widening_cast<uN[MAX_N_M_V + u32:1]>(x);
let y_extended = widening_cast<uN[MAX_N_M_V + u32:1]>(y);

let full_result: uN[MAX_N_M_V + u32:1] = x_extended + y_extended;
let narrowed_result = full_result as uN[V];
let overflow_detected = or_reduce(full_result[V as s32:]);

(overflow_detected, narrowed_result)
}

pub fn double_fraction_carry(f: float32::F32) -> (uN[float32::F32_FRACTION_SZ], u1) {
let f = f.fraction as uN[float32::F32_FRACTION_SZ + u32:1];
let (overflow, f_x2) = uadd_with_overflow(f, f);
(f_x2, overflow)
}
66 changes: 62 additions & 4 deletions xls/dslx/type_system/parametric_instantiator_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,15 @@ absl::Status EagerlyPopulateParametricEnvMap(
absl::Span<const ParametricWithType> typed_parametrics,
const absl::flat_hash_map<std::string, Expr*>& parametric_default_exprs,
absl::flat_hash_map<std::string, InterpValue>& parametric_env_map,
const Span& span, std::string_view kind_name, DeduceCtx* ctx) {
const std::optional<Span>& parametrics_span, const Span& span,
std::string_view kind_name, DeduceCtx* ctx) {
// If there are no parametric bindings being instantiated (in the callee) we
// should have already typechecked that there are no parametrics being
// applied by the caller.
if (!parametrics_span.has_value()) {
XLS_RET_CHECK(typed_parametrics.empty());
}

// Attempt to interpret the parametric "default expressions" in order.
for (const ParametricWithType& typed_parametric : typed_parametrics) {
std::string_view name = typed_parametric.identifier();
Expand Down Expand Up @@ -180,6 +188,38 @@ absl::Status EagerlyPopulateParametricEnvMap(
VLOG(5) << absl::StreamFormat("Evaluating expr: `%s` in env: %s",
expr->ToString(), env.ToString());

// If the expression requires variables that were not bound in the
// environment yet, we give an error message.
FreeVariables freevars =
GetFreeVariablesByLambda(expr, [&](const NameRef& name_ref) {
// If the name def is in the parametrics span, we consider it a
// freevar.
AnyNameDef any_name_def = name_ref.name_def();
if (!std::holds_alternative<const NameDef*>(any_name_def)) {
return false;
}
const NameDef* name_def = std::get<const NameDef*>(any_name_def);
Span definition_span = name_def->GetSpan().value();
if (definition_span.filename() != parametrics_span->filename()) {
return false;
}
return parametrics_span->Contains(definition_span);
});

auto keys_unsorted = freevars.Keys();
absl::btree_set<std::string> keys_sorted(keys_unsorted.begin(),
keys_unsorted.end());
for (const std::string& key : keys_sorted) {
if (!parametric_env_map.contains(key)) {
return TypeInferenceErrorStatus(
freevars.GetFirstNameRefSpan(key), nullptr,
absl::StrFormat(
"Parametric expression `%s` refered to `%s` which is not "
"present in the parametric environment; instantiated from %s",
expr->ToString(), key, span.ToString()));
}
}

absl::StatusOr<InterpValue> result = InterpretExpr(ctx, expr, env);

VLOG(5) << "Interpreted expr: " << expr->ToString() << " @ " << expr->span()
Expand Down Expand Up @@ -209,18 +249,36 @@ absl::Status EagerlyPopulateParametricEnvMap(
parametric_env_map.insert({std::string{name}, result.value()});
}
}

// TODO(https://github.com/google/xls/issues/1495): 2024-06-18 We would like
// to enable this invariant to tighten up what is accepted by the type
// system, but that requires some investigation into failing samples.
if (false) {
// Check that all parametric bindings are present in the env.
for (const auto& [parametric_binding_name, _] : parametric_default_exprs) {
if (!parametric_env_map.contains(parametric_binding_name)) {
return TypeInferenceErrorStatus(
span, nullptr,
absl::StrFormat("Caller did not supply parametric value for `%s`",
parametric_binding_name));
}
}
}

return absl::OkStatus();
}

} // namespace

ParametricInstantiator::ParametricInstantiator(
Span span, absl::Span<const InstantiateArg> args, DeduceCtx* ctx,
Span span, std::optional<Span> parametrics_span,
absl::Span<const InstantiateArg> args, DeduceCtx* ctx,
absl::Span<const ParametricWithType> typed_parametrics,
const absl::flat_hash_map<std::string, InterpValue>& explicit_parametrics,
absl::Span<absl::Nonnull<const ParametricBinding*> const>
parametric_bindings)
: span_(std::move(span)),
parametrics_span_(std::move(parametrics_span)),
args_(args),
ctx_(ABSL_DIE_IF_NULL(ctx)),
typed_parametrics_(typed_parametrics),
Expand Down Expand Up @@ -339,7 +397,7 @@ absl::StatusOr<TypeAndParametricEnv> FunctionInstantiator::Instantiate() {

XLS_RETURN_IF_ERROR(EagerlyPopulateParametricEnvMap(
typed_parametrics(), parametric_default_exprs(), parametric_env_map(),
span(), GetKindName(), &ctx()));
parametrics_span(), span(), GetKindName(), &ctx()));

// Phase 2: resolve and check.
VLOG(10) << "Phase 2: resolve-and-check";
Expand Down Expand Up @@ -404,7 +462,7 @@ absl::StatusOr<TypeAndParametricEnv> StructInstantiator::Instantiate() {

XLS_RETURN_IF_ERROR(EagerlyPopulateParametricEnvMap(
typed_parametrics(), parametric_default_exprs(), parametric_env_map(),
span(), GetKindName(), &ctx()));
parametrics_span(), span(), GetKindName(), &ctx()));

// Phase 2: resolve and check.
for (int64_t i = 0; i < member_types_.size(); ++i) {
Expand Down
Loading

0 comments on commit b1f1125

Please sign in to comment.