Skip to content

Commit eeabb44

Browse files
hongtedcopybara-github
authored andcommitted
Added minor additions to DSLX types, funcs, and ast nodes.
- Additional casts for Types. - GetParamByName for Functions. - Overrideable default for AstNodeVisitorWithDefault. PiperOrigin-RevId: 681662121
1 parent c86dfd7 commit eeabb44

File tree

6 files changed

+93
-2
lines changed

6 files changed

+93
-2
lines changed

xls/dslx/frontend/ast.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,6 +1632,20 @@ Function::Function(Module* owner, Span span, NameDef* name_def,
16321632

16331633
Function::~Function() = default;
16341634

1635+
absl::StatusOr<Param*> Function::GetParamByName(
1636+
std::string_view param_name) const {
1637+
auto i = std::find_if(params_.begin(), params_.end(), [=](Param* p) -> bool {
1638+
return (p != nullptr) && (p->name_def()->identifier() == param_name);
1639+
});
1640+
1641+
if (i == params_.end()) {
1642+
return absl::NotFoundError(absl::StrFormat(
1643+
"Param '%s' not a parameter of function %s", param_name, ToString()));
1644+
}
1645+
1646+
return *i;
1647+
}
1648+
16351649
std::vector<AstNode*> Function::GetChildren(bool want_types) const {
16361650
std::vector<AstNode*> results;
16371651
results.push_back(name_def());

xls/dslx/frontend/ast.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,18 @@ class AstNodeVisitor {
136136

137137
// Subtype of abstract AstNodeVisitor that returns ok status (does nothing) for
138138
// every node type.
139+
//
140+
// Users can override the default behavior by overriding the DefaultHandler()
141+
// method.
139142
class AstNodeVisitorWithDefault : public AstNodeVisitor {
140143
public:
141144
~AstNodeVisitorWithDefault() override = default;
142145

146+
virtual absl::Status DefaultHandler() { return absl::OkStatus(); }
147+
143148
#define DECLARE_HANDLER(__type) \
144149
absl::Status Handle##__type(const __type* n) override { \
145-
return absl::OkStatus(); \
150+
return DefaultHandler(); \
146151
}
147152
XLS_DSLX_AST_NODE_EACH(DECLARE_HANDLER)
148153
#undef DECLARE_HANDLER
@@ -1656,6 +1661,7 @@ class Function : public AstNode {
16561661
return parametric_bindings_;
16571662
}
16581663
const std::vector<Param*>& params() const { return params_; }
1664+
absl::StatusOr<Param*> GetParamByName(std::string_view param_name) const;
16591665

16601666
// The body of the function is a block (sequence of statements that yields a
16611667
// final expression).

xls/dslx/frontend/ast_test.cc

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,5 +205,42 @@ TEST(AstTest, ToStringCastWithinLtComparison) {
205205
EXPECT_EQ(lt->ToString(), "(x as t) < x");
206206
}
207207

208+
TEST(AstTest, GetFuncParam) {
209+
// Create an empty function
210+
// fn f(p: u32) -> u32 {}
211+
212+
FileTable file_table;
213+
const Span fake_span;
214+
Module m("test", /*fs_path=*/std::nullopt, file_table);
215+
216+
BuiltinNameDef* builtin_name_def = m.GetOrCreateBuiltinNameDef("u32");
217+
BuiltinTypeAnnotation* u32_type_annotation = m.Make<BuiltinTypeAnnotation>(
218+
fake_span, BuiltinType::kU32, builtin_name_def);
219+
220+
NameDef* func_name_def =
221+
m.Make<NameDef>(fake_span, std::string("f"), nullptr);
222+
NameDef* param_name_def =
223+
m.Make<NameDef>(fake_span, std::string("p"), nullptr);
224+
225+
std::vector<Param*> params;
226+
params.push_back(m.Make<Param>(param_name_def, u32_type_annotation));
227+
228+
std::vector<ParametricBinding*> parametric_bindings;
229+
230+
StatementBlock* block =
231+
m.Make<StatementBlock>(fake_span, std::vector<Statement*>{}, true);
232+
233+
Function* f =
234+
m.Make<Function>(fake_span, func_name_def, parametric_bindings, params,
235+
u32_type_annotation, block, FunctionTag::kNormal, false);
236+
237+
XLS_ASSERT_OK_AND_ASSIGN(Param * found_param, f->GetParamByName("p"));
238+
EXPECT_EQ(found_param, params[0]);
239+
240+
EXPECT_THAT(f->GetParamByName("not_a_param"),
241+
StatusIs(absl::StatusCode::kNotFound,
242+
HasSubstr("Param 'not_a_param' not a parameter")));
243+
}
244+
208245
} // namespace
209246
} // namespace xls::dslx

xls/dslx/type_system/type.cc

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,13 @@ bool Type::IsToken() const {
374374
return dynamic_cast<const TokenType*>(this) != nullptr;
375375
}
376376

377+
bool Type::IsTuple() const {
378+
return dynamic_cast<const TupleType*>(this) != nullptr;
379+
}
380+
377381
const EnumType& Type::AsEnum() const {
378382
auto* s = dynamic_cast<const EnumType*>(this);
379-
CHECK(s != nullptr) << "Type is not a enum: " << *this;
383+
CHECK(s != nullptr) << "Type is not an enum: " << *this;
380384
return *s;
381385
}
382386

@@ -386,12 +390,24 @@ const StructType& Type::AsStruct() const {
386390
return *s;
387391
}
388392

393+
const MetaType& Type::AsMeta() const {
394+
auto* s = dynamic_cast<const MetaType*>(this);
395+
CHECK(s != nullptr) << "Type is not a MetaType: " << *this;
396+
return *s;
397+
}
398+
389399
const ArrayType& Type::AsArray() const {
390400
auto* s = dynamic_cast<const ArrayType*>(this);
391401
CHECK(s != nullptr) << "Type is not an array: " << *this;
392402
return *s;
393403
}
394404

405+
const TupleType& Type::AsTuple() const {
406+
auto* s = dynamic_cast<const TupleType*>(this);
407+
CHECK(s != nullptr) << "Type is not a tuple: " << *this;
408+
return *s;
409+
}
410+
395411
// -- TokenType
396412

397413
TokenType::~TokenType() = default;

xls/dslx/type_system/type.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,10 +388,13 @@ class Type {
388388
bool IsEnum() const;
389389
bool IsArray() const;
390390
bool IsMeta() const;
391+
bool IsTuple() const;
391392

392393
const StructType& AsStruct() const;
393394
const EnumType& AsEnum() const;
394395
const ArrayType& AsArray() const;
396+
const MetaType& AsMeta() const;
397+
const TupleType& AsTuple() const;
395398

396399
protected:
397400
static std::vector<std::unique_ptr<Type>> CloneSpan(

xls/dslx/type_system/type_test.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ TEST(TypeTest, TestU32) {
184184
EXPECT_EQ(std::vector<TypeDim>{TypeDim::CreateU32(32)}, t.GetAllDims());
185185
EXPECT_EQ(t, *t.ToUBits());
186186
EXPECT_TRUE(IsUBits(t));
187+
EXPECT_FALSE(t.IsTuple());
187188
}
188189

189190
TEST(TypeTest, TestUnit) {
@@ -193,7 +194,21 @@ TEST(TypeTest, TestUnit) {
193194
EXPECT_EQ("tuple", t.GetDebugTypeName());
194195
EXPECT_EQ(false, t.HasEnum());
195196
EXPECT_TRUE(t.GetAllDims().empty());
197+
EXPECT_TRUE(t.IsTuple());
196198
EXPECT_FALSE(IsUBits(t));
199+
200+
Type* generic_type = &t;
201+
EXPECT_TRUE(generic_type->IsTuple());
202+
EXPECT_EQ(&generic_type->AsTuple(), &t);
203+
}
204+
205+
TEST(TypeTest, TestMetaUnit) {
206+
MetaType meta_t(std::make_unique<BitsType>(false, 32));
207+
EXPECT_TRUE(meta_t.IsMeta());
208+
209+
Type* generic_type = &meta_t;
210+
EXPECT_FALSE(generic_type->IsTuple());
211+
EXPECT_EQ(&generic_type->AsMeta(), &meta_t);
197212
}
198213

199214
TEST(TypeTest, TestTwoTupleOfStruct) {

0 commit comments

Comments
 (0)