Skip to content

Commit

Permalink
Added minor additions to DSLX types, funcs, and ast nodes.
Browse files Browse the repository at this point in the history
  - Additional casts for Types.
  - GetParamByName for Functions.
  - Overrideable default for AstNodeVisitorWithDefault.

PiperOrigin-RevId: 681662121
  • Loading branch information
hongted authored and copybara-github committed Oct 3, 2024
1 parent c86dfd7 commit eeabb44
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 2 deletions.
14 changes: 14 additions & 0 deletions xls/dslx/frontend/ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1632,6 +1632,20 @@ Function::Function(Module* owner, Span span, NameDef* name_def,

Function::~Function() = default;

absl::StatusOr<Param*> Function::GetParamByName(
std::string_view param_name) const {
auto i = std::find_if(params_.begin(), params_.end(), [=](Param* p) -> bool {
return (p != nullptr) && (p->name_def()->identifier() == param_name);
});

if (i == params_.end()) {
return absl::NotFoundError(absl::StrFormat(
"Param '%s' not a parameter of function %s", param_name, ToString()));
}

return *i;
}

std::vector<AstNode*> Function::GetChildren(bool want_types) const {
std::vector<AstNode*> results;
results.push_back(name_def());
Expand Down
8 changes: 7 additions & 1 deletion xls/dslx/frontend/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,18 @@ class AstNodeVisitor {

// Subtype of abstract AstNodeVisitor that returns ok status (does nothing) for
// every node type.
//
// Users can override the default behavior by overriding the DefaultHandler()
// method.
class AstNodeVisitorWithDefault : public AstNodeVisitor {
public:
~AstNodeVisitorWithDefault() override = default;

virtual absl::Status DefaultHandler() { return absl::OkStatus(); }

#define DECLARE_HANDLER(__type) \
absl::Status Handle##__type(const __type* n) override { \
return absl::OkStatus(); \
return DefaultHandler(); \
}
XLS_DSLX_AST_NODE_EACH(DECLARE_HANDLER)
#undef DECLARE_HANDLER
Expand Down Expand Up @@ -1656,6 +1661,7 @@ class Function : public AstNode {
return parametric_bindings_;
}
const std::vector<Param*>& params() const { return params_; }
absl::StatusOr<Param*> GetParamByName(std::string_view param_name) const;

// The body of the function is a block (sequence of statements that yields a
// final expression).
Expand Down
37 changes: 37 additions & 0 deletions xls/dslx/frontend/ast_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,5 +205,42 @@ TEST(AstTest, ToStringCastWithinLtComparison) {
EXPECT_EQ(lt->ToString(), "(x as t) < x");
}

TEST(AstTest, GetFuncParam) {
// Create an empty function
// fn f(p: u32) -> u32 {}

FileTable file_table;
const Span fake_span;
Module m("test", /*fs_path=*/std::nullopt, file_table);

BuiltinNameDef* builtin_name_def = m.GetOrCreateBuiltinNameDef("u32");
BuiltinTypeAnnotation* u32_type_annotation = m.Make<BuiltinTypeAnnotation>(
fake_span, BuiltinType::kU32, builtin_name_def);

NameDef* func_name_def =
m.Make<NameDef>(fake_span, std::string("f"), nullptr);
NameDef* param_name_def =
m.Make<NameDef>(fake_span, std::string("p"), nullptr);

std::vector<Param*> params;
params.push_back(m.Make<Param>(param_name_def, u32_type_annotation));

std::vector<ParametricBinding*> parametric_bindings;

StatementBlock* block =
m.Make<StatementBlock>(fake_span, std::vector<Statement*>{}, true);

Function* f =
m.Make<Function>(fake_span, func_name_def, parametric_bindings, params,
u32_type_annotation, block, FunctionTag::kNormal, false);

XLS_ASSERT_OK_AND_ASSIGN(Param * found_param, f->GetParamByName("p"));
EXPECT_EQ(found_param, params[0]);

EXPECT_THAT(f->GetParamByName("not_a_param"),
StatusIs(absl::StatusCode::kNotFound,
HasSubstr("Param 'not_a_param' not a parameter")));
}

} // namespace
} // namespace xls::dslx
18 changes: 17 additions & 1 deletion xls/dslx/type_system/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,13 @@ bool Type::IsToken() const {
return dynamic_cast<const TokenType*>(this) != nullptr;
}

bool Type::IsTuple() const {
return dynamic_cast<const TupleType*>(this) != nullptr;
}

const EnumType& Type::AsEnum() const {
auto* s = dynamic_cast<const EnumType*>(this);
CHECK(s != nullptr) << "Type is not a enum: " << *this;
CHECK(s != nullptr) << "Type is not an enum: " << *this;
return *s;
}

Expand All @@ -386,12 +390,24 @@ const StructType& Type::AsStruct() const {
return *s;
}

const MetaType& Type::AsMeta() const {
auto* s = dynamic_cast<const MetaType*>(this);
CHECK(s != nullptr) << "Type is not a MetaType: " << *this;
return *s;
}

const ArrayType& Type::AsArray() const {
auto* s = dynamic_cast<const ArrayType*>(this);
CHECK(s != nullptr) << "Type is not an array: " << *this;
return *s;
}

const TupleType& Type::AsTuple() const {
auto* s = dynamic_cast<const TupleType*>(this);
CHECK(s != nullptr) << "Type is not a tuple: " << *this;
return *s;
}

// -- TokenType

TokenType::~TokenType() = default;
Expand Down
3 changes: 3 additions & 0 deletions xls/dslx/type_system/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,13 @@ class Type {
bool IsEnum() const;
bool IsArray() const;
bool IsMeta() const;
bool IsTuple() const;

const StructType& AsStruct() const;
const EnumType& AsEnum() const;
const ArrayType& AsArray() const;
const MetaType& AsMeta() const;
const TupleType& AsTuple() const;

protected:
static std::vector<std::unique_ptr<Type>> CloneSpan(
Expand Down
15 changes: 15 additions & 0 deletions xls/dslx/type_system/type_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ TEST(TypeTest, TestU32) {
EXPECT_EQ(std::vector<TypeDim>{TypeDim::CreateU32(32)}, t.GetAllDims());
EXPECT_EQ(t, *t.ToUBits());
EXPECT_TRUE(IsUBits(t));
EXPECT_FALSE(t.IsTuple());
}

TEST(TypeTest, TestUnit) {
Expand All @@ -193,7 +194,21 @@ TEST(TypeTest, TestUnit) {
EXPECT_EQ("tuple", t.GetDebugTypeName());
EXPECT_EQ(false, t.HasEnum());
EXPECT_TRUE(t.GetAllDims().empty());
EXPECT_TRUE(t.IsTuple());
EXPECT_FALSE(IsUBits(t));

Type* generic_type = &t;
EXPECT_TRUE(generic_type->IsTuple());
EXPECT_EQ(&generic_type->AsTuple(), &t);
}

TEST(TypeTest, TestMetaUnit) {
MetaType meta_t(std::make_unique<BitsType>(false, 32));
EXPECT_TRUE(meta_t.IsMeta());

Type* generic_type = &meta_t;
EXPECT_FALSE(generic_type->IsTuple());
EXPECT_EQ(&generic_type->AsMeta(), &meta_t);
}

TEST(TypeTest, TestTwoTupleOfStruct) {
Expand Down

0 comments on commit eeabb44

Please sign in to comment.