Skip to content

Commit 29d2794

Browse files
committed
Add tests
1 parent 707c16a commit 29d2794

File tree

3 files changed

+109
-25
lines changed

3 files changed

+109
-25
lines changed

xls/dslx/frontend/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ cc_library(
6565
":ast",
6666
":ast_cloner",
6767
":module",
68+
"//xls/common/status:ret_check",
6869
"//xls/common/status:status_macros",
6970
"//xls/ir:bits_ops",
7071
"//xls/ir:format_preference",

xls/dslx/frontend/function_specializer.cc

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "absl/status/statusor.h"
2828
#include "absl/strings/str_format.h"
2929
#include "absl/types/span.h"
30+
#include "xls/common/status/ret_check.h"
3031
#include "xls/common/status/status_macros.h"
3132
#include "xls/dslx/frontend/ast_cloner.h"
3233
#include "xls/dslx/frontend/module.h"
@@ -175,9 +176,11 @@ absl::StatusOr<Number*> CreateLiteralFromValue(Module* module, const Span& span,
175176
absl::StatusOr<Function*> InsertFunctionSpecialization(
176177
Function* source_function, const ParametricEnv& param_env,
177178
std::string_view specialized_name) {
178-
CHECK_NE(source_function, nullptr);
179+
XLS_RET_CHECK_NE(source_function, nullptr)
180+
<< "InsertFunctionSpecialization requires a non-null source function";
179181
Module* module = source_function->owner();
180-
CHECK_NE(module, nullptr);
182+
XLS_RET_CHECK_NE(module, nullptr) << absl::StrFormat(
183+
"Source function %s has no owning module", source_function->identifier());
181184

182185
if (!source_function->IsParametric()) {
183186
return absl::InvalidArgumentError(absl::StrFormat(
@@ -186,7 +189,8 @@ absl::StatusOr<Function*> InsertFunctionSpecialization(
186189

187190
auto binding_values =
188191
std::make_shared<absl::flat_hash_map<const NameDef*, InterpValue>>();
189-
absl::flat_hash_map<const NameDef*, TypeAnnotation*> binding_types;
192+
auto binding_types =
193+
std::make_shared<absl::flat_hash_map<const NameDef*, TypeAnnotation*>>();
190194
for (ParametricBinding* binding : source_function->parametric_bindings()) {
191195
std::optional<InterpValue> value = param_env.GetValue(binding->name_def());
192196
if (!value.has_value()) {
@@ -196,13 +200,13 @@ absl::StatusOr<Function*> InsertFunctionSpecialization(
196200
}
197201
binding_values->emplace(binding->name_def(), *value);
198202
if (binding->type_annotation() != nullptr) {
199-
binding_types.emplace(binding->name_def(), binding->type_annotation());
203+
binding_types->emplace(binding->name_def(), binding->type_annotation());
200204
}
201205
}
202206

203207
auto make_replacer = [binding_values, binding_types](
204-
const absl::flat_hash_map<const NameDef*, NameDef*>*
205-
param_name_replacements) -> CloneReplacer {
208+
const absl::flat_hash_map<const NameDef*, NameDef*>*
209+
param_name_replacements) -> CloneReplacer {
206210
return [binding_values, binding_types, param_name_replacements](
207211
const AstNode* original, Module* target_module,
208212
const absl::flat_hash_map<const AstNode*, AstNode*>& old_to_new)
@@ -217,20 +221,19 @@ absl::StatusOr<Function*> InsertFunctionSpecialization(
217221
auto param_it = param_name_replacements->find(def);
218222
if (param_it != param_name_replacements->end()) {
219223
NameDef* replacement = param_it->second;
220-
return std::optional<AstNode*>(
221-
target_module->Make<NameRef>(name_ref->span(),
222-
name_ref->identifier(),
223-
replacement, name_ref->in_parens()));
224+
return std::optional<AstNode*>(target_module->Make<NameRef>(
225+
name_ref->span(), name_ref->identifier(), replacement,
226+
name_ref->in_parens()));
224227
}
225228
}
226229
auto binding_it = binding_values->find(def);
227230
if (binding_it != binding_values->end()) {
228-
XLS_ASSIGN_OR_RETURN(Number * literal,
229-
CreateLiteralFromValue(target_module,
230-
name_ref->span(),
231-
binding_it->second));
232-
auto type_it = binding_types.find(def);
233-
if (type_it != binding_types.end() && type_it->second != nullptr) {
231+
XLS_ASSIGN_OR_RETURN(
232+
Number * literal,
233+
CreateLiteralFromValue(target_module, name_ref->span(),
234+
binding_it->second));
235+
auto type_it = binding_types->find(def);
236+
if (type_it != binding_types->end() && type_it->second != nullptr) {
234237
XLS_ASSIGN_OR_RETURN(TypeAnnotation * cloned_type,
235238
CloneNode<TypeAnnotation>(type_it->second));
236239
literal->SetTypeAnnotation(cloned_type);
@@ -246,29 +249,31 @@ absl::StatusOr<Function*> InsertFunctionSpecialization(
246249
new_params.reserve(source_function->params().size());
247250
absl::flat_hash_map<const NameDef*, NameDef*> param_name_replacements;
248251
for (Param* param : source_function->params()) {
249-
XLS_ASSIGN_OR_RETURN(Param * cloned_param,
250-
CloneNode<Param>(param,
251-
make_replacer(/*param_name_replacements=*/nullptr)));
252+
XLS_ASSIGN_OR_RETURN(
253+
Param * cloned_param,
254+
CloneNode<Param>(param,
255+
make_replacer(/*param_name_replacements=*/nullptr)));
252256
param_name_replacements.emplace(param->name_def(),
253257
cloned_param->name_def());
254258
new_params.push_back(cloned_param);
255259
}
256260

257261
TypeAnnotation* new_return_type = nullptr;
258262
if (source_function->return_type() != nullptr) {
259-
XLS_ASSIGN_OR_RETURN(new_return_type,
260-
CloneNode<TypeAnnotation>(source_function->return_type(),
261-
make_replacer(&param_name_replacements)));
263+
XLS_ASSIGN_OR_RETURN(
264+
new_return_type,
265+
CloneNode<TypeAnnotation>(source_function->return_type(),
266+
make_replacer(&param_name_replacements)));
262267
}
263268

264269
XLS_ASSIGN_OR_RETURN(
265270
StatementBlock * new_body,
266271
CloneNode<StatementBlock>(source_function->body(),
267272
make_replacer(&param_name_replacements)));
268273

269-
NameDef* new_name_def = module->Make<NameDef>(
270-
Span::Fake(), std::string(specialized_name),
271-
/*definer=*/nullptr);
274+
NameDef* new_name_def =
275+
module->Make<NameDef>(Span::Fake(), std::string(specialized_name),
276+
/*definer=*/nullptr);
272277

273278
SyntheticSpanAllocator span_allocator(module, source_function,
274279
specialized_name);

xls/dslx/frontend/function_specializer_test.cc

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,5 +370,83 @@ fn select_poly<N: u32>(polys: uN[6][N], selector: uN[N]) -> uN[6] {
370370
ASSERT_TRUE(replaced.ok()) << replaced.status();
371371
}
372372

373+
TEST(FunctionSpecializerTest, TypeAnnotationsSubstituteParametricBindings) {
374+
constexpr std::string_view kProgram =
375+
R"(fn slice<M: u32>(x: bits[M]) -> bits[M + 8] {
376+
let y: bits[M + 8] = x ++ u8:0;
377+
y
378+
}
379+
)";
380+
381+
std::unique_ptr<ImportData> import_data = CreateImportDataPtrForTest();
382+
XLS_ASSERT_OK_AND_ASSIGN(TypecheckedModule typechecked,
383+
ParseAndTypecheck(kProgram, "ta_module.x",
384+
"ta_module", import_data.get()));
385+
386+
Module* module = typechecked.module;
387+
ASSERT_NE(module, nullptr);
388+
389+
std::optional<Function*> slice_fn = module->GetFunction("slice");
390+
ASSERT_TRUE(slice_fn.has_value());
391+
392+
const ParametricBinding* binding =
393+
slice_fn.value()->parametric_bindings().front();
394+
InterpValue binding_value = InterpValue::MakeUBits(/*bit_count=*/32, 16);
395+
absl::flat_hash_map<std::string, InterpValue> env_bindings;
396+
env_bindings.emplace(binding->identifier(), binding_value);
397+
ParametricEnv env(env_bindings);
398+
399+
XLS_ASSERT_OK_AND_ASSIGN(
400+
Function * specialized,
401+
InsertFunctionSpecialization(slice_fn.value(), env, "slice_M16"));
402+
403+
ASSERT_EQ(specialized->params().size(), 1);
404+
Param* specialized_param = specialized->params()[0];
405+
auto* param_type =
406+
down_cast<ArrayTypeAnnotation*>(specialized_param->type_annotation());
407+
ASSERT_NE(param_type, nullptr);
408+
auto* param_dim = dynamic_cast<Number*>(param_type->dim());
409+
ASSERT_NE(param_dim, nullptr);
410+
EXPECT_EQ(param_dim->text(), "0x10");
411+
412+
auto* return_type =
413+
down_cast<ArrayTypeAnnotation*>(specialized->return_type());
414+
ASSERT_NE(return_type, nullptr);
415+
auto* return_dim = dynamic_cast<Binop*>(return_type->dim());
416+
ASSERT_NE(return_dim, nullptr);
417+
EXPECT_EQ(return_dim->binop_kind(), BinopKind::kAdd);
418+
auto* return_dim_lhs = dynamic_cast<Number*>(return_dim->lhs());
419+
ASSERT_NE(return_dim_lhs, nullptr);
420+
EXPECT_EQ(return_dim_lhs->text(), "0x10");
421+
auto* return_dim_rhs = dynamic_cast<Number*>(return_dim->rhs());
422+
ASSERT_NE(return_dim_rhs, nullptr);
423+
ASSERT_NE(module->file_table(), nullptr);
424+
const FileTable& file_table = *module->file_table();
425+
XLS_ASSERT_OK_AND_ASSIGN(uint64_t return_dim_rhs_value,
426+
return_dim_rhs->GetAsUint64(file_table));
427+
EXPECT_EQ(return_dim_rhs_value, 8);
428+
429+
StatementBlock* body = specialized->body();
430+
ASSERT_EQ(body->statements().size(), 2);
431+
const Statement::Wrapped& let_wrapped =
432+
body->statements().front()->wrapped();
433+
ASSERT_TRUE(std::holds_alternative<Let*>(let_wrapped));
434+
auto* let_stmt = std::get<Let*>(let_wrapped);
435+
auto* let_type =
436+
down_cast<ArrayTypeAnnotation*>(let_stmt->type_annotation());
437+
ASSERT_NE(let_type, nullptr);
438+
auto* let_dim = dynamic_cast<Binop*>(let_type->dim());
439+
ASSERT_NE(let_dim, nullptr);
440+
EXPECT_EQ(let_dim->binop_kind(), BinopKind::kAdd);
441+
auto* let_dim_lhs = dynamic_cast<Number*>(let_dim->lhs());
442+
ASSERT_NE(let_dim_lhs, nullptr);
443+
EXPECT_EQ(let_dim_lhs->text(), "0x10");
444+
auto* let_dim_rhs = dynamic_cast<Number*>(let_dim->rhs());
445+
ASSERT_NE(let_dim_rhs, nullptr);
446+
XLS_ASSERT_OK_AND_ASSIGN(uint64_t let_dim_rhs_value,
447+
let_dim_rhs->GetAsUint64(file_table));
448+
EXPECT_EQ(let_dim_rhs_value, 8);
449+
}
450+
373451
} // namespace
374452
} // namespace xls::dslx

0 commit comments

Comments
 (0)