diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml index 1c1f30a6db..0f878be8b7 100644 --- a/.github/workflows/continuous-integration.yml +++ b/.github/workflows/continuous-integration.yml @@ -44,7 +44,7 @@ jobs: - name: Bazel Test xls/... (opt) run: | - bazel test -c opt --noshow_progress --test_output=errors -- //xls/... -//xls/contrib/... -//xls/dev_tools/... + bazel test -c opt --noshow_progress --test_output=errors -- //xls/... -//xls/contrib/... -//xls/dev_tools/... //docs_src/... //xls/tests:documentation_test - name: Bazel Test Including Contrib & DevTools (opt) run: | diff --git a/docs_src/design_docs/dslx_pattern_exhaustiveness.md b/docs_src/design_docs/dslx_pattern_exhaustiveness.md new file mode 100644 index 0000000000..ed17ff6d6b --- /dev/null +++ b/docs_src/design_docs/dslx_pattern_exhaustiveness.md @@ -0,0 +1,169 @@ +# DSLX Pattern Exhaustiveness + +Document contact: [cdleary](https://github.com/cdleary), first written +2025-02-03 + +## Overview + +DSLX supports pattern matching for a limited pattern grammar/syntax. This +document describes the approach we take to checking the exhaustiveness of +pattern matches. + +## Pattern Matching + +The pattern matching syntax supports the following: + +* `_` -- a wildcard pattern. +* `name` -- a name pattern -- this can either bind a name, if it is currently + unbound, or can be used to check for equality if it is bound. + + (Aside: This is admittedly a little scary, but it is identical to Rust, and + the defined-but-did-not-use warning usually flags any misuse.) + +* `range` -- a range pattern. + +* `colon::reference` -- a colon-ref pattern -- equality to some + externally-defined entity. + +* `..` -- usable within a tuple, discards an arbitrary number of elements + +* `()` -- tuples and nesting of tuple patterns within tuple patterns. + +Notably there is not support for structure pattern matching or arrays; these are +desirable to match expressive ability of Rust but are not yet implemented. + +## Exhaustiveness Checking + +Exhaustiveness checking works by projecting types into a sparse N-dimensional +array. + +At the start of exhaustiveness checking we use the type being matched on to make +an N-dimensional array with the extent in each dimension being determined by the +type; e.g. if we're matching on a `u32` we make an interval that represents the +range `[0, 2^32-1]`. + +As we add patterns, we subtract out the space that the pattern covers from the +overall space. + +If we manage to subtract the entire space, then we know that the pattern is +exhaustive. + +If we don't manage to subtract the entire space, then we know that the pattern +is not exhaustive. + +### Monolithic N-dimensional Space + +The exhaustiveness checker works by representing the entire space of values that +can be matched on as a single N-dimensional space. This is conceptually simple +but it flattens all values being matched on into a single N-dimensional sparse +array, which means we discard all hierarchy present in the value being matched +upon. It seems worth noting that this approach is monolithic and flattens the +entire value space down to leaves, razing the tuple structure of the value being +matched upon. + +### Subtracting from N-dimensional Space + +Note that there are two ways you could conceptually approach the idea of +exhaustiveness -- additive and subtractive. In an additive approach you might +unify spaces until you arrived at a space that was equal to the full space. +Instead, our data structure subtracts, as it's easier to distribute subtraction +over disjoint pieces. For example, consider a 2D rectangle where we subtract out +the center of the rectangle, leaving just the left hand side and left hand side +disjointly. We store that in our `NdRegion` as two disjoint contiguous +`NdIntervals`, and then when we go to subtract out a new pattern (which can be +represented as a `NdInterval` as a single pattern always occupies contiguous +space) we can just subtract the new "user added interval" from both pieces of +the space in "brute force" fashion without worrying about merging or which +sub-space might be affected. This definitely seems to simplify things. The +canonical paper on different approaches for this is Luc Maranget's "Warnings for +pattern matching" -- admittedly I didn't do a lot of research for this +implementation, though, the N-dimensional interval space was intuitive enough, +so there may be better approaches we want to try over time. + +### Semantically Load-Bearing + +To be clear, exhaustiveness checking is semantically load-bearing; i.e. it is +not just a lint, but a semantic check. Once exhaustive pattern match is proven +it means we can turn the final arm of a match statement into an "else" +condition, e.g. for a +[priority select](https://google.github.io/xls/ir_semantics/#priority_sel) IR +lowering in the `default` case. + +It was discussed among XLS steering folks whether to make this an opt-in +language feature for purposes of landing, but consensus was that it was useful +enough to enable by default -- this is one of the top reported pain points for +DSLX writing -- just with warnings for redundant patterns default-off for now so +that real-world code bases have time to transition. + +## Structure + +The code is initially structured as follows: + +* `DeduceMatch` -- the main entry point for pattern matching and + exhaustiveness checking from the type system, we call into the + `MatchExhaustivenessChecker` for each pattern encountered in the match + statement. +* `match_exhaustiveness_checker.h` -- the main class for checking + exhaustiveness. + + The deduce rule feeds this object pattern by pattern to check whether a + pattern has led us to the point of exhaustion. This streaming + pattern-at-a-time interface allows us also give a helpful warning when a + pattern is fully redundant with previous patterns, or if we have similarly + added patterns even though we've passed the point of exhaustion. + + DSLX types and values of particular types are translated into intervals and + points at this level to subtract from the `NdRegion` that we maintain to + determine exhaustiveness. + +* `nd_region.h` -- an N-dimensional region type `NdRegion` for representing + the space we're whittling down, with patterns, towards exhaustiveness. Note + that an `NdRegion` is a collection of disjoint `NdInterval`s that we + subtract from in "brute force" fashion as user-written patterns are + introduced to the space, as described above. + + Each pattern can conceptually be translated into an `NdInterval` as it + represents some contiguous space of values in the overall N-dimensional + space. + +* `interp_value_interval.h` -- an interval type `InterpValueInterval` for + representing the range of DSLX values for a given type; provides some basic + facilities for 1D interval arithmetic and queries like `Contains`, + `Intersects`, `Covers`, etc. + +## Notes on Wrinkles + +**Enums** in DSLX are conceptually a set of names that live in some underlying +bit space representation, and may be sparse within that space. That is, there's +nothing wrong with making `enum E : u8 { A = u8:5, B = u8:10 }`. The language +contract is that there can never be an enum value that takes on an +out-of-defined-namespace value. As a result, we project the enum namespace into +a dense unsigned bit space and make intervals over that dense space. i.e. for +the enum `E` above we would require two values to cover the entire space. +Effectively, we don't care what the underlying bit representation is for the +purpose of pattern matching exhaustion. + +Empty enums (i.e. enums with no defined values in its namespace) are similar to +**zero-bit values** in that they have no real representable values. These +bit-space could be defined to be trivially exhaustive, or impossible to match +on -- we choose the latter for now because it's more convenient for +implementation, we can call it a one-bit space and any value with this type +trivially that one value (so we need to have one pattern covering it, but the +pattern matches by definition). + +**Tokens**: Note that there is also a question of tokens which are +zero-bit-like, but I imagine we don't want to re-bind tokens through a pattern +match so we can observe its linear dataflow in a given function. + +**Arrays**: We don't currently support any interesting pattern syntax for +arrays, and they can conceptually create large spaces of values in the +flattening process, so this initial change for exhaustiveness makes them +disallowed in matched expressions until support can be added more +comprehensively. It's not for very serious reasons, however, they could be +flattened in a similar fashion to tuples. + +**Zero-Element Ranges** are possible to write in DSLX, and so in building up +intervals we have to keep a maybe-interval concept until we have resolved fully +that there are nonzero values in the interval. These wind up being N-dimensional +intervals with zero volume so they never subtract out any space in the +exhaustion process. diff --git a/docs_src/dslx_std.md b/docs_src/dslx_std.md index 4547b3e7c8..d3af6fcc53 100644 --- a/docs_src/dslx_std.md +++ b/docs_src/dslx_std.md @@ -636,13 +636,15 @@ enum EnumType : u2 { } fn main(x: EnumType) -> u32 { - match x { - EnumType::FIRST => u32:0, - EnumType::SECOND => u32:1, + if x == EnumType::FIRST { + u32:0 + } else if x == EnumType::SECOND { + u32:1 + } else { // This should not be reachable. // But, if we synthesize hardware, under this condition the function is // well-defined to give back zero. - _ => fail!("unknown_EnumType", u32:0), + fail!("unknown_EnumType", u32:0) } } ``` diff --git a/mkdocs.yml b/mkdocs.yml index f1c234d4c5..44a8234e27 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -68,6 +68,7 @@ nav: - Design Docs: - Legalize Multiple Channel Ops Per Channel: 'design_docs/legalize_multiple_channel_ops_per_channel.md' - Proc-scoped channels: 'design_docs/proc_scoped_channels.md' + - DSLX Pattern Exhaustiveness: 'design_docs/dslx_pattern_exhaustiveness.md' - Releasing: 'releasing.md' - NoC: - Overview: 'noc/xls_noc_readme.md' diff --git a/xls/dslx/BUILD b/xls/dslx/BUILD index cd31adf31f..783dd976cb 100644 --- a/xls/dslx/BUILD +++ b/xls/dslx/BUILD @@ -703,6 +703,7 @@ cc_test( ":warning_kind", "//xls/common:xls_gunit_main", "//xls/common/status:matchers", + "@com_google_absl//absl/status:status_matchers", "@com_google_googletest//:gtest", ], ) diff --git a/xls/dslx/bytecode/bytecode_interpreter_test.cc b/xls/dslx/bytecode/bytecode_interpreter_test.cc index 0b48f5ee27..dd6d672867 100644 --- a/xls/dslx/bytecode/bytecode_interpreter_test.cc +++ b/xls/dslx/bytecode/bytecode_interpreter_test.cc @@ -921,8 +921,9 @@ fn main(x: u32) -> u32 { absl::StatusOr value = Interpret(kProgram, "main", {InterpValue::MakeU32(2)}); - EXPECT_THAT(value.status(), StatusIs(absl::StatusCode::kInternal, - HasSubstr("The value was not matched"))); + EXPECT_THAT(value.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Match pattern is not exhaustive"))); } TEST_F(BytecodeInterpreterTest, RunMatchWithNameRefs) { diff --git a/xls/dslx/create_import_data.cc b/xls/dslx/create_import_data.cc index 9c0d2073bc..14afd1a5c8 100644 --- a/xls/dslx/create_import_data.cc +++ b/xls/dslx/create_import_data.cc @@ -38,14 +38,14 @@ ImportData CreateImportData( return import_data; } -ImportData CreateImportDataForTest( - std::unique_ptr vfs) { +ImportData CreateImportDataForTest(std::unique_ptr vfs, + WarningKindSet warnings) { if (vfs == nullptr) { vfs = std::make_unique(); } absl::Span additional_search_paths = {}; ImportData import_data(xls::kDefaultDslxStdlibPath, additional_search_paths, - kDefaultWarningsSet, std::move(vfs)); + warnings, std::move(vfs)); import_data.SetBytecodeCache(std::make_unique()); return import_data; } diff --git a/xls/dslx/create_import_data.h b/xls/dslx/create_import_data.h index 5ff8dbd7f8..0684d7d0a8 100644 --- a/xls/dslx/create_import_data.h +++ b/xls/dslx/create_import_data.h @@ -25,6 +25,7 @@ #include "xls/dslx/import_data.h" #include "xls/dslx/virtualizable_file_system.h" #include "xls/dslx/warning_kind.h" + namespace xls::dslx { // Creates an ImportData with the given stdlib and search paths and assigns a @@ -37,7 +38,8 @@ ImportData CreateImportData( // Creates an ImportData with reasonable defaults (standard path to the stdlib // and no additional search paths). ImportData CreateImportDataForTest( - std::unique_ptr vfs = nullptr); + std::unique_ptr vfs = nullptr, + WarningKindSet warnings = kAllWarningsSet); std::unique_ptr CreateImportDataPtrForTest(); diff --git a/xls/dslx/exhaustiveness/BUILD b/xls/dslx/exhaustiveness/BUILD new file mode 100644 index 0000000000..d57f4dc7ec --- /dev/null +++ b/xls/dslx/exhaustiveness/BUILD @@ -0,0 +1,107 @@ +# Copyright 2025 The XLS Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_applicable_licenses = ["//:license"], + default_visibility = ["//xls:xls_internal"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "interp_value_interval", + srcs = ["interp_value_interval.cc"], + hdrs = ["interp_value_interval.h"], + deps = [ + "//xls/dslx:interp_value", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_test( + name = "interp_value_interval_test", + srcs = ["interp_value_interval_test.cc"], + deps = [ + ":interp_value_interval", + "//xls/common:xls_gunit_main", + "//xls/common/status:matchers", + "//xls/common/status:status_macros", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "nd_region", + srcs = ["nd_region.cc"], + hdrs = ["nd_region.h"], + deps = [ + ":interp_value_interval", + "//xls/dslx:interp_value", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "nd_region_test", + srcs = ["nd_region_test.cc"], + deps = [ + ":nd_region", + "//xls/common:xls_gunit_main", + "//xls/common/fuzzing:fuzztest", + "//xls/common/status:matchers", + "//xls/common/status:status_macros", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "match_exhaustiveness_checker", + srcs = ["match_exhaustiveness_checker.cc"], + hdrs = ["match_exhaustiveness_checker.h"], + deps = [ + ":interp_value_interval", + ":nd_region", + "//xls/common:visitor", + "//xls/dslx:import_data", + "//xls/dslx:interp_value", + "//xls/dslx/frontend:ast", + "//xls/dslx/frontend:pos", + "//xls/dslx/type_system:type", + "//xls/dslx/type_system:type_info", + "@com_google_absl//absl/log", + ], +) + +cc_test( + name = "exhaustiveness_match_test", + srcs = ["exhaustiveness_match_test.cc"], + deps = [ + ":match_exhaustiveness_checker", + "//xls/common:visitor", + "//xls/common:xls_gunit_main", + "//xls/common/status:matchers", + "//xls/dslx:create_import_data", + "//xls/dslx:parse_and_typecheck", + "//xls/dslx/frontend:ast", + "//xls/dslx/type_system:type", + "//xls/dslx/type_system:type_info", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:status_matchers", + "@com_google_googletest//:gtest", + ], +) diff --git a/xls/dslx/exhaustiveness/exhaustiveness_match_test.cc b/xls/dslx/exhaustiveness/exhaustiveness_match_test.cc new file mode 100644 index 0000000000..e99e844788 --- /dev/null +++ b/xls/dslx/exhaustiveness/exhaustiveness_match_test.cc @@ -0,0 +1,542 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "absl/log/log.h" +#include "absl/status/status_matchers.h" +#include "xls/common/status/matchers.h" +#include "xls/common/visitor.h" +#include "xls/dslx/create_import_data.h" +#include "xls/dslx/exhaustiveness/match_exhaustiveness_checker.h" +#include "xls/dslx/frontend/ast.h" +#include "xls/dslx/parse_and_typecheck.h" +#include "xls/dslx/type_system/type.h" +#include "xls/dslx/type_system/type_info.h" + +namespace xls::dslx { +namespace { + +using ::absl_testing::StatusIs; +using ::testing::HasSubstr; + +std::vector GetPatterns(const Match& match) { + std::vector patterns; + for (const MatchArm* arm : match.arms()) { + for (const NameDefTree* pattern : arm->patterns()) { + patterns.push_back(pattern); + } + } + return patterns; +} + +void CheckExhaustiveOnlyAfterLastPattern(std::string_view program) { + ImportData import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule tm, + ParseAndTypecheck(program, "test.x", "test", &import_data)); + std::optional func = tm.module->GetFunction("main"); + ASSERT_TRUE(func.has_value()); + StatementBlock* body = func.value()->body(); + const Statement& statement = *body->statements().back(); + Match* match = dynamic_cast(std::get(statement.wrapped())); + ASSERT_TRUE(match != nullptr); + + std::optional matched_type = tm.type_info->GetItem(match->matched()); + ASSERT_TRUE(matched_type.has_value()); + ASSERT_NE(matched_type.value(), nullptr); + + MatchExhaustivenessChecker checker(match->matched()->span(), import_data, + *tm.type_info, *matched_type.value()); + + std::vector patterns = GetPatterns(*match); + for (int64_t i = 0; i < patterns.size(); ++i) { + bool now_exhaustive = checker.AddPattern(*patterns[i]); + // We expect it to become exhaustive with the last match arm. + bool expect_now_exhaustive = i + 1 == patterns.size(); + EXPECT_EQ(now_exhaustive, expect_now_exhaustive) + << "Expected match to be " + << (expect_now_exhaustive ? "exhaustive" : "non-exhaustive") + << " after adding pattern `" << patterns[i]->ToString() << "`"; + } +} + +void CheckNonExhaustive(std::string_view program) { + ImportData import_data = CreateImportDataForTest(); + absl::StatusOr tm = + ParseAndTypecheck(program, "test.x", "test", &import_data); + EXPECT_THAT(tm.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Match patterns are not exhaustive"))); +} + +void CheckExhaustiveWithRedundantPattern(std::string_view program) { + WarningKindSet warnings = + DisableWarning(kAllWarningsSet, WarningKind::kUnusedDefinition); + ImportData import_data = CreateImportDataForTest(nullptr, warnings); + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule tm, + ParseAndTypecheck(program, "test.x", "test", &import_data)); + absl::Span collected_warnings = + tm.warnings.warnings(); + ASSERT_FALSE(collected_warnings.empty()); + for (const WarningCollector::Entry& warning : collected_warnings) { + EXPECT_THAT(warning.message, + HasSubstr("Match is already exhaustive before this pattern")); + } +} + +TEST(ExhaustivenessMatchTest, MatchBoolTrueFalse) { + constexpr std::string_view kMatch = R"(fn main(x: bool) -> u32 { + match x { + false => u32:42, + true => u32:64, + } +})"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, MatchBoolJustTrue) { + constexpr std::string_view kMatch = R"(fn main(x: bool) -> u32 { + match x { + true => u32:42, + _ => u32:64, + } +})"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, MatchOneHotsInTuple) { + constexpr std::string_view kMatch = R"(fn main(t: (bool, bool, bool)) -> u32 { + match t { + (true, _, _) => u32:42, + (_, true, _) => u32:64, + (_, _, true) => u32:86, + _ => u32:0, + } + })"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, SammpleFromStdlib) { + constexpr std::string_view kMatch = R"(fn main(x: u32, y: u32) -> u32 { + match (x[-1:], y[-1:]) { + (u1:1, u1:1) => u32:42, + (u1:1, u1:0) => u32:64, + _ => u32:0, + } +})"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, MatchRestOfTuple) { + constexpr std::string_view kMatch = R"(fn main(t: (bool, bool, bool)) -> u32 { + match t { + (true, ..) => u32:42, + (_, true, ..) => u32:64, + (.., true) => u32:128, + (false, false, false) => u32:0, + } + })"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, MatchNestedTuple) { + // Test a nested tuple match pattern. + constexpr std::string_view kMatch = + R"(fn main(x: (bool, (bool, bool))) -> u32 { + match x { + (true, (true, _)) => u32:1, + (false, (_, true)) => u32:2, + (_, (_, _)) => u32:0, + } + })"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, MatchRedundantPattern) { + // Even if one of the arms is redundant, the overall match is only complete + // once a catch-all branch is reached. + constexpr std::string_view kMatch = R"(fn main(t: (bool, bool)) -> u32 { + match t { + (true, _) => u32:1, + (true, false) => u32:2, + _ => u32:0, + } + })"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +// Dense enum values but missing the top value in the underlying type. +TEST(ExhaustivenessMatchTest, MatchOnDenseZeroAlignedEnum) { + constexpr std::string_view kMatch = R"(enum E : u2 { + A = 0, + B = 1, + C = 2, + } + fn main(e: E) -> u32 { + match e { + E::A => u32:42, + E::B => u32:64, + E::C => u32:86, + } + })"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +// Note that the value 1 is a gap in the enum name space. +TEST(ExhaustivenessMatchTest, MatchOnSparseEnum) { + constexpr std::string_view kMatch = R"(enum E : u2 { + A = 0, + B = 2, + C = 3, + } + fn main(e: E) -> u32 { + match e { + E::A => u32:42, + E::B => u32:64, + E::C => u32:86, + } +})"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, MatchWithNestedTuplesAndRestOfTupleSprinkled) { + constexpr std::string_view kMatch = + R"(fn main(t: (u32, (u32, u32, u32), u32, u32)) -> u32 { + match t { + (.., u32:1, a) => a, + (u32:2, (u32:0, _, b), ..) => b, + (u32:3, (.., c), ..) => c, + (u32:4, d, ..) => d.0, + _ => u32:0xdeadbeef, + } + })"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, EmptyRestOfTupleInCenter) { + constexpr std::string_view kMatch = R"(fn main(x: (u32, u33)) -> u32 { + match x { + (y, .., z) => y, + } +})"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, RestOfTupleAsWildcard) { + constexpr std::string_view kMatch = R"(fn main(t: (u32, u32)) -> u32 { + match t { + (u32:1, .., a) => a, + (..) => u32:1 + } +})"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, SignedNegativeAndPositiveValueRanges) { + constexpr std::string_view kMatch = R"(fn main(x: s8) -> u32 { + match x { + s8:0..s8:127 => u32:42, + s8:-128..s8:0 => u32:64, + s8:127 => u32:128, + } +})"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, MatchWithRangeHavingStartGtLimit) { + constexpr std::string_view kMatch = R"(fn main(x: u32) -> u32 { + match x { + u32:7..u32:0 => u32:42, + _ => u32:0, + } +})"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +// Range with a "zero-volume interval" / empty range. +TEST(ExhaustivenessMatchTest, MatchWithRangeHavingStartEqLimit) { + constexpr std::string_view kMatch = R"(fn main(x: u32) -> u32 { + match x { + u32:7..u32:7 => u32:42, + _ => u32:0, + } +})"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, ExhaustiveU2) { + constexpr std::string_view kMatch = R"(fn main(x: u2) -> u32 { + match x { + u2:0 => u32:42, + u2:1 => u32:64, + u2:2 => u32:128, + u2:3 => u32:256, + } +})"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, NonExhaustiveU2) { + constexpr std::string_view kMatch = R"(fn main(x: u2) -> u32 { + match x { + u2:0 => u32:42, + u2:1 => u32:64, + u2:2 => u32:128, + } +})"; + CheckNonExhaustive(kMatch); +} + +TEST(ExhaustivenessMatchTest, NonExhaustiveTuple) { + constexpr std::string_view kMatch = R"(fn main(x: (u2, u2)) -> u32 { + match x { + (u2:0, _) => u32:42, + (u2:1, u2:0) => u32:64, + } +})"; + CheckNonExhaustive(kMatch); +} + +TEST(ExhaustivenessMatchTest, RedundantPattern) { + constexpr std::string_view kMatch = R"(fn main(x: u4) -> u32 { + match x { + _ => u32:0, + u4:5 => u32:1, + } +})"; + CheckExhaustiveWithRedundantPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, ComplexOverlapAndFragmentation) { + constexpr std::string_view kMatch = R"(fn main(x: u4) -> u32 { + match x { + u4:0..u4:2 => u32:0, + u4:2..u4:4 => u32:1, + u4:4..u4:6 => u32:2, + // Intentionally leaving a gap for values 6 and 7. + u4:8..u4:10 => u32:3, + _ => fail!("nonexhaustive_match", u32:0), + } +})"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, NonExhaustiveTupleMixedElementTypes) { + constexpr std::string_view kMatch = R"( +fn main(x: (u2, bool)) -> u32 { + match x { + (u2:0, true) => u32:10, + (u2:1, false) => u32:20, + } +} +)"; + CheckNonExhaustive(kMatch); +} + +TEST(ExhaustivenessMatchTest, MultiDimensionalTuplePattern) { + constexpr std::string_view kMatch = R"( +enum E : u2 { + A = 0, + B = 1, + C = 2, +} +fn main(x: (E, bool)) -> u32 { + match x { + (E::A, false) => u32:10, + (E::A, true) => u32:20, + (E::B, _) => u32:30, + (E::C, _) => u32:40, + } +} +)"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, MixLiteralAndRangePatterns) { + constexpr std::string_view kMatch = R"( +fn main(x: u4) -> u32 { + match x { + u4:0..u4:2 => u32:0, // Matches 0 and 1. + u4:2 => u32:10, // Matches 2. + _ => u32:20, // Covers the rest. + } +} +)"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, MultipleRedundantPatterns) { + constexpr std::string_view kMatch = R"( +fn main(x: u4) -> u32 { + match x { + _ => u32:0, + u4:1 => u32:10, + u4:2 => u32:20, + } +} +)"; + CheckExhaustiveWithRedundantPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, NestedMatchExpression) { + constexpr std::string_view kProgram = R"( +fn main(x: bool, y: bool) -> u32 { + // Outer match only has two arms; one of which returns a nested match. + match x { + true => match y { + true => u32:1, + false => u32:2, + }, + false => u32:3, + } +} +)"; + // Because the top-level statement is a match, we can check its + // exhaustiveness. + CheckExhaustiveOnlyAfterLastPattern(kProgram); +} + +TEST(ExhaustivenessMatchTest, DeeplyNestedTuplePatterns) { + // This test ensures that the exhaustiveness checker recurses into nested + // tuples. + constexpr std::string_view kMatch = + R"(fn main(x: ((bool, bool), bool)) -> u32 { + match x { + ((true, true), true) => u32:1, + ((true, true), false) => u32:2, + ((true, false), true) => u32:3, + ((true, false), false) => u32:4, + ((false, true), true) => u32:5, + ((false, true), false) => u32:6, + ((false, false), true) => u32:7, + ((false, false), false) => u32:8, + } + })"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, RedundantNamedWildcardPattern) { + // Instead of a "_" wildcard, we bind a variable which covers all cases, + // then add a redundant arm. + constexpr std::string_view kMatch = R"(fn main(x: bool) -> u32 { + match x { + val => u32:42, // 'val' binds both true and false + false => u32:64, // redundant arm + } + })"; + CheckExhaustiveWithRedundantPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, EmptyTupleMatch) { + // Matching on the unit type is a good edge case. + constexpr std::string_view kMatch = R"(fn main(x: ()) -> u32 { + match x { + () => u32:99, + } + })"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, OverlappingRangesForSignedIntegers) { + // Here we deliberately use overlapping ranges for a signed type. + // Although the ranges overlap, the catch-all arm ensures exhaustiveness. + constexpr std::string_view kMatch = R"(fn main(x: s4) -> u32 { + match x { + s4:-8..s4:-2 => u32:10, + s4:-4..s4:1 => u32:20, + s4:0 => u32:30, + _ => u32:40, + } + })"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, MatchOnLetBoundExpression) { + // Instead of matching directly on a parameter, match on its transformed + // value. + constexpr std::string_view kMatch = R"(fn main(x: bool) -> u32 { + let y = !x; + match y { + true => u32:5, + false => u32:10, + } + })"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, MiddleCatchAllRedundantSuffix) { + constexpr std::string_view kMatch = R"(fn main(x: u4) -> u32 { + match x { + u4:0 => u32:10, + _ => u32:20, // catch-all in the middle: makes the match exhaustive here + u4:1 => u32:30, // redundant + u4:2 => u32:40, // redundant + u4:3 => u32:50, // redundant + } + })"; + CheckExhaustiveWithRedundantPattern(kMatch); +} + +// Note: the second range has a partial overlap with the first range, but we +// don't flag this as a redundant pattern because we only flag redundant +// patterns once we've reached exhaustiveness. +TEST(ExhaustivenessMatchTest, RedundantRangeOverlap) { + // The second range overlaps with the first one. + constexpr std::string_view kMatch = R"(fn main(x: u4) -> u32 { + match x { + u4:0..u4:3 => u32:10, // covers 0, 1, 2 + u4:2..u4:4 => u32:20, // overlaps: covers 2, 3 + _ => u32:30, // covers the rest (values 4-15) + } + })"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, RestOfTupleAtStart) { + constexpr std::string_view kMatch = R"(fn main(t: (u32, u32, u32)) -> u32 { + match t { + (.., u32:3) => u32:100, + _ => u32:200, + } + })"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, RestOfTupleWithPrefixAndSuffix) { + constexpr std::string_view kMatch = + R"(fn main(t: (u32, u32, u32, u32)) -> u32 { + match t { + (u32:10, .., u32:20) => u32:111, + (u32:30, .., u32:40) => u32:222, + _ => u32:333, + } + })"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +TEST(ExhaustivenessMatchTest, NestedTupleWithRestOfTuple) { + constexpr std::string_view kMatch = + R"(fn main(t: ((u32, u32, u32), u32)) -> u32 { + match t { + ((u32:5, ..), u32:7) => u32:111, + ((u32:8, ..), u32:9) => u32:222, + _ => u32:333, + } + })"; + CheckExhaustiveOnlyAfterLastPattern(kMatch); +} + +} // namespace +} // namespace xls::dslx diff --git a/xls/dslx/exhaustiveness/interp_value_interval.cc b/xls/dslx/exhaustiveness/interp_value_interval.cc new file mode 100644 index 0000000000..2ff258a36b --- /dev/null +++ b/xls/dslx/exhaustiveness/interp_value_interval.cc @@ -0,0 +1,62 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/dslx/exhaustiveness/interp_value_interval.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_format.h" +#include "xls/dslx/interp_value.h" + +namespace xls::dslx { + +/* static */ InterpValueInterval InterpValueInterval::MakeFull( + bool is_signed, int64_t bit_count) { + return InterpValueInterval(InterpValue::MakeMinValue(is_signed, bit_count), + InterpValue::MakeMaxValue(is_signed, bit_count)); +} + +InterpValueInterval::InterpValueInterval(InterpValue min, InterpValue max) + : min_(std::move(min)), max_(std::move(max)) { + absl::StatusOr le = min_.Le(max_); + CHECK_OK(le.status()); + CHECK(le->IsTrue()) << absl::StreamFormat( + "InterpValueInterval; min: %s max: %s need min <= max", min_.ToString(), + max_.ToString()); +} + +bool InterpValueInterval::Contains(InterpValue value) const { + return min_.Le(value)->IsTrue() && max_.Ge(value)->IsTrue(); +} + +std::string InterpValueInterval::ToString(bool show_types) const { + return absl::StrFormat("[%s, %s]", min_.ToString(/*humanize=*/!show_types), + max_.ToString(/*humanize=*/!show_types)); +} + +bool InterpValueInterval::IsSigned() const { + CHECK_EQ(min_.IsSigned(), max_.IsSigned()); + return min_.IsSigned(); +} + +int64_t InterpValueInterval::GetBitCount() const { + CHECK_EQ(min_.GetBitCount(), max_.GetBitCount()); + return min_.GetBitCount().value(); +} + +} // namespace xls::dslx diff --git a/xls/dslx/exhaustiveness/interp_value_interval.h b/xls/dslx/exhaustiveness/interp_value_interval.h new file mode 100644 index 0000000000..3f0bf45884 --- /dev/null +++ b/xls/dslx/exhaustiveness/interp_value_interval.h @@ -0,0 +1,74 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef XLS_DSLX_EXHAUSTIVENESS_INTERP_VALUE_INTERVAL_H_ +#define XLS_DSLX_EXHAUSTIVENESS_INTERP_VALUE_INTERVAL_H_ + +#include +#include + +#include "xls/dslx/interp_value.h" + +namespace xls::dslx { + +// Represents a single contiguous interval of values; e.g. for a run of bits +// values with the same bits type. +// +// This is an inclusive interval, i.e. `[min, max]`, because there's not a lot +// of value in representing a "volumeless" interval via an exclusive limit, we +// can discard those early. +// +// Generally we can think of this as "the concrete representation of a range +// expression" a la `0..3` or similar. +class InterpValueInterval { + public: + static InterpValueInterval MakeFull(bool is_signed, int64_t bit_count); + + InterpValueInterval(InterpValue min, InterpValue max); + + // Returns true if this interval contains the given value. + bool Contains(InterpValue value) const; + + // Returns true if this interval intersects with the given interval. + bool Intersects(const InterpValueInterval& other) const { + return Contains(other.min_) || Contains(other.max_) || + other.Contains(min_) || other.Contains(max_); + } + + bool Covers(const InterpValueInterval& other) const { + return Contains(other.min_) && Contains(other.max_); + } + + bool operator==(const InterpValueInterval& other) const { + return min_ == other.min_ && max_ == other.max_; + } + + bool operator<(const InterpValueInterval& other) const { + return min_ < other.min_ || (min_ == other.min_ && max_ < other.max_); + } + + const InterpValue& min() const { return min_; } + const InterpValue& max() const { return max_; } + + std::string ToString(bool show_types) const; + + private: + bool IsSigned() const; + int64_t GetBitCount() const; + InterpValue min_; + InterpValue max_; +}; + +} // namespace xls::dslx + +#endif // XLS_DSLX_EXHAUSTIVENESS_INTERP_VALUE_INTERVAL_H_ diff --git a/xls/dslx/exhaustiveness/interp_value_interval_test.cc b/xls/dslx/exhaustiveness/interp_value_interval_test.cc new file mode 100644 index 0000000000..0b0be91866 --- /dev/null +++ b/xls/dslx/exhaustiveness/interp_value_interval_test.cc @@ -0,0 +1,33 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/dslx/exhaustiveness/interp_value_interval.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace xls::dslx { + +TEST(InterpValueIntervalTest, IntervalContainsPoints) { + InterpValueInterval iv(InterpValue::MakeUBits(8, 1), + InterpValue::MakeUBits(8, 10)); + EXPECT_FALSE(iv.Contains(InterpValue::MakeUBits(8, 0))); + EXPECT_TRUE(iv.Contains(InterpValue::MakeUBits(8, 1))); + EXPECT_TRUE(iv.Contains(InterpValue::MakeUBits(8, 5))); + EXPECT_TRUE(iv.Contains(InterpValue::MakeUBits(8, 10))); + EXPECT_FALSE(iv.Contains(InterpValue::MakeUBits(8, 0))); + EXPECT_FALSE(iv.Contains(InterpValue::MakeUBits(8, 11))); +} + +} // namespace xls::dslx diff --git a/xls/dslx/exhaustiveness/match_exhaustiveness_checker.cc b/xls/dslx/exhaustiveness/match_exhaustiveness_checker.cc new file mode 100644 index 0000000000..2e7055d331 --- /dev/null +++ b/xls/dslx/exhaustiveness/match_exhaustiveness_checker.cc @@ -0,0 +1,494 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/dslx/exhaustiveness/match_exhaustiveness_checker.h" + +#include + +#include "absl/log/log.h" +#include "xls/common/visitor.h" +#include "xls/dslx/import_data.h" + +namespace xls::dslx { +namespace { + +std::vector GetLeafTypesInternal(const Type& type) { + if (type.IsTuple()) { + std::vector result; + for (const std::unique_ptr& member : type.AsTuple().members()) { + std::vector member_leaf_types = + GetLeafTypesInternal(*member); + result.insert(result.end(), member_leaf_types.begin(), + member_leaf_types.end()); + } + return result; + } + return {&type}; +} + +std::vector GetLeafTypes(const Type& type, const Span& span, + const FileTable& file_table) { + std::vector result = GetLeafTypesInternal(type); + // Validate that all the matched-upon types are either bits or enums. + for (const Type* leaf_type : result) { + CHECK(GetBitsLike(*leaf_type).has_value() || leaf_type->IsEnum()) + << "Non-bits or non-enum type in matched-upon tuple: " + << leaf_type->ToString() << " @ " << span.ToString(file_table); + } + return result; +} + +// Sentinel type to indicate that some wildcard is present for a value. This +// lets us collapse out varieties of wildcards e.g. RestOfTuple and +// WildcardPattern and NameDef. +struct SomeWildcard {}; + +// NameDefTree::Leaf but where RestOfTuple has been resolved. +using PatternLeaf = + std::variant; + +InterpValueInterval MakeFullIntervalForType(const Type& type) { + if (type.IsEnum()) { + return MakeFullIntervalForEnumType(type.AsEnum()); + } + std::optional bits_like = GetBitsLike(type); + CHECK(bits_like.has_value()) + << "MakeFullIntervalForType; got non-bits type: " << type.ToString(); + int64_t bit_count = bits_like->size.GetAsInt64().value(); + bool is_signed = bits_like->is_signed.GetAsBool().value(); + InterpValue min = InterpValue::MakeMinValue(is_signed, bit_count); + InterpValue max = InterpValue::MakeMaxValue(is_signed, bit_count); + InterpValueInterval result(min, max); + VLOG(5) << "MakeFullIntervalForType; type: `" << type.ToString() + << "` result: " << result.ToString(/*show_types=*/false); + return result; +} + +// Returns the "full" intervals that can be used to represent the "no values +// have been exhausted" initial state. +std::vector GetFullIntervals( + absl::Span leaf_types) { + std::vector result; + for (const Type* leaf_type : leaf_types) { + result.push_back(MakeFullIntervalForType(*leaf_type)); + } + return result; +} + +InterpValueInterval MakePointIntervalForType(const Type& type, + const InterpValue& value, + const ImportData& import_data) { + VLOG(5) << "MakePointIntervalForType; type: `" << type.ToString() + << "` value: `" << value.ToString() << "`"; + if (type.IsEnum()) { + return MakePointIntervalForEnumType(type.AsEnum(), value, import_data); + } + std::optional bits_like = GetBitsLike(type); + CHECK(bits_like.has_value()) + << "MakePointIntervalForType; got non-bits type: " << type.ToString(); + return InterpValueInterval(value, value); +} + +InterpValueInterval MakeIntervalForType(const Type& type, + const InterpValue& min, + const InterpValue& max) { + std::optional bits_like = GetBitsLike(type); + CHECK(bits_like.has_value()) + << "MakeIntervalForType; got non-bits type: " << type.ToString(); + return InterpValueInterval(min, max); +} + +std::optional PatternToIntervalInternal( + const PatternLeaf& leaf, const Type& leaf_type, const TypeInfo& type_info, + const ImportData& import_data) { + std::optional result = absl::visit( + Visitor{ + [&](SomeWildcard /*unused*/) -> std::optional { + return MakeFullIntervalForType(leaf_type); + }, + [&](NameRef* name_ref) -> std::optional { + std::optional value = + type_info.GetConstExprOption(name_ref); + if (value.has_value()) { + return MakePointIntervalForType(leaf_type, value.value(), + import_data); + } + return MakeFullIntervalForType(leaf_type); + }, + [&](Range* range) -> std::optional { + std::optional start = + type_info.GetConstExprOption(range->start()); + std::optional limit = + type_info.GetConstExprOption(range->end()); + CHECK(start.has_value()); + CHECK(limit.has_value()); + if (start->Eq(limit.value())) { + return std::nullopt; + } + std::optional max = limit.value().Decrement(); + if (!max.has_value()) { + // Underflow -- that means the range must be empty because the + // limit is exclusive and is known to be representable in the + // type. + return std::nullopt; + } + if (max->Lt(start.value()).value().IsTrue()) { + // max < start, so the range is empty. + return std::nullopt; + } + return MakeIntervalForType(leaf_type, *start, max.value()); + }, + [&](ColonRef* colon_ref) -> std::optional { + std::optional value = + type_info.GetConstExprOption(colon_ref); + CHECK(value.has_value()); + VLOG(5) << "PatternToIntervalInternal; colon_ref: `" + << colon_ref->ToString() << "` value: `" + << value.value().ToString() << "`" << " leaf_type: `" + << leaf_type.ToString() << "`"; + return MakePointIntervalForType(leaf_type, value.value(), + import_data); + }, + [&](Number* number) -> std::optional { + std::optional value = + type_info.GetConstExprOption(number); + CHECK(value.has_value()); + return MakePointIntervalForType(leaf_type, value.value(), + import_data); + }}, + leaf); + VLOG(5) << "PatternToIntervalInternal; leaf_type: `" << leaf_type.ToString() + << "` result: " + << (result.has_value() ? result->ToString(/*show_types=*/false) + : "nullopt"); + return result; +} + +PatternLeaf ToPatternLeaf(const NameDefTree::Leaf& leaf) { + return absl::visit( + Visitor{ + [&](NameDef* name_def) -> PatternLeaf { return SomeWildcard(); }, + [&](NameRef* name_ref) -> PatternLeaf { return name_ref; }, + [&](Range* range) -> PatternLeaf { return range; }, + [&](ColonRef* colon_ref) -> PatternLeaf { return colon_ref; }, + [&](WildcardPattern* wildcard_pattern) -> PatternLeaf { + return SomeWildcard(); + }, + [&](Number* number) -> PatternLeaf { return number; }, + [&](RestOfTuple* rest_of_tuple) -> PatternLeaf { + LOG(FATAL) << "RestOfTuple not valid for conversion to PatternLeaf"; + }}, + leaf); +} + +std::vector ExpandPatternLeaves(const NameDefTree& pattern, + const Type& type, + const FileTable& file_table) { + VLOG(5) << "ExpandPatternLeaves; pattern: `" << pattern.ToString() + << "` type: `" << type.ToString() << "`"; + // For an irrefutable pattern, simply return wildcards for every leaf. + if (pattern.IsIrrefutable()) { + std::vector leaf_types = + GetLeafTypes(type, pattern.span(), file_table); + return std::vector(leaf_types.size(), SomeWildcard()); + } + // If the type is not a tuple then we expect the pattern to be a single leaf. + if (!type.IsTuple()) { + std::vector leaves = pattern.Flatten(); + CHECK_EQ(leaves.size(), 1) + << "Expected a single leaf for non-tuple type, got " << leaves.size(); + return {ToPatternLeaf(leaves.front())}; + } + // Walk through the pattern and expand any RestOfTuple markers into the + // appropriate number of wildcards. + // + // In order to do this we have to recursively call to ExpandPatternLeaves for + // any sub-tuples encountered. + absl::Span> tuple_members = + type.AsTuple().members(); + std::vector> flattened = + pattern.Flatten1(); + + // Note: there can be fewer flatten1'd nodes than tuple elements because of + // RestOfTuple markers. + // + // We need the `+1` here because we can have RestOfTuple markers that map to + // zero elements in the tuple (i.e. useless/redundant ones). + CHECK_LE(flattened.size(), tuple_members.size() + 1); + + // The results correspond to leaf types. + std::vector result; + + // The tuple type index at *this level* of the tuple. + // We bump this as we progress through -- note a single "flattened_index" + // below can advance zero or more type indices. + int64_t types_index = 0; + + for (int64_t flattened_index = 0; flattened_index < flattened.size(); + ++flattened_index) { + VLOG(5) << "ExpandPatternLeaves; flattened_index: " << flattened_index + << " flattened.size(): " << flattened.size() + << " types_index: " << types_index + << " tuple_members.size(): " << tuple_members.size(); + CHECK_LT(flattened_index, flattened.size()) + << "Flattened index out of bounds."; + const auto& node = flattened[flattened_index]; + + if (std::holds_alternative(node)) { + const NameDefTree* sub_pattern = std::get(node); + CHECK_LT(types_index, tuple_members.size()); + const Type& type_at_index = *tuple_members[types_index]; + + std::vector sub_pattern_leaves = + ExpandPatternLeaves(*sub_pattern, type_at_index, file_table); + + result.insert(result.end(), sub_pattern_leaves.begin(), + sub_pattern_leaves.end()); + types_index += 1; + continue; + } + const NameDefTree::Leaf& leaf = std::get(node); + absl::visit( + Visitor{ + [&](const NameRef* n) { + result.push_back(ToPatternLeaf(leaf)); + types_index += 1; + }, + [&](const Range* r) { + result.push_back(ToPatternLeaf(leaf)); + types_index += 1; + }, + [&](const ColonRef* c) { + result.push_back(ToPatternLeaf(leaf)); + types_index += 1; + }, + [&](const Number* n) { + result.push_back(ToPatternLeaf(leaf)); + types_index += 1; + }, + [&](const RestOfTuple* /*unused*/) { + // Instead of using flattened_index here, use types_index (the + // number of tuple elements already matched) to figure out how + // many items we need "in the rest". + int64_t explicit_before = types_index; + int64_t explicit_after = flattened.size() - flattened_index - 1; + int64_t to_push = + tuple_members.size() - (explicit_before + explicit_after); + VLOG(5) << "ExpandPatternLeaves; RestOfTuple at flattened_index: " + << flattened_index << " types_index: " << types_index + << " explicit_after: " << explicit_after + << " to_push: " << to_push; + for (int64_t i = 0; i < to_push; ++i) { + // We have to push wildcard data corresponding to the type. + CHECK_LT(types_index, tuple_members.size()); + const Type& type_at_index = *tuple_members[types_index]; + for (int64_t i = 0; + i < GetLeafTypes(type_at_index, pattern.span(), file_table) + .size(); + ++i) { + result.push_back(SomeWildcard()); + } + types_index += 1; + } + VLOG(5) << "ExpandPatternLeaves; after RestOfTuple at " + "flattened_index: " + << flattened_index << " types_index: " << types_index + << " result.size(): " << result.size(); + }, + [&](const auto* irrefutable_leaf) { + // Push back wildcards of the right size for the type. + CHECK_LT(types_index, tuple_members.size()); + const Type& type_at_index = *tuple_members[types_index]; + for (int64_t i = 0; + i < GetLeafTypes(type_at_index, pattern.span(), file_table) + .size(); + ++i) { + result.push_back(SomeWildcard()); + } + types_index += 1; + }}, + leaf); + } + + // Check that we got a consistent count between the razed tuple types and the + // PatternLeaf vector. + CHECK_EQ(result.size(), GetLeafTypes(type, pattern.span(), file_table).size()) + << "Sub-pattern leaves and tuple type must be the same size."; + return result; +} + +NdIntervalWithEmpty PatternToInterval(const NameDefTree& pattern, + const Type& matched_type, + absl::Span leaf_types, + const TypeInfo& type_info, + const ImportData& import_data) { + std::vector pattern_leaves = + ExpandPatternLeaves(pattern, matched_type, type_info.file_table()); + CHECK_EQ(pattern_leaves.size(), leaf_types.size()) + << "Pattern leaves and leaf types must be the same size."; + + // Each leaf describes some range in its dimension that it matches on -- + // together, they describe an n-dimensional interval. + std::vector> intervals; + for (int64_t i = 0; i < pattern_leaves.size(); ++i) { + intervals.push_back(PatternToIntervalInternal( + pattern_leaves[i], *leaf_types[i], type_info, import_data)); + } + NdIntervalWithEmpty result(intervals); + VLOG(5) << "PatternToInterval; pattern: `" << pattern.ToString() + << "` type: `" << matched_type.ToString() + << "` result: " << result.ToString(/*show_types=*/false); + return result; +} + +NdRegion MakeFullNdRegion(absl::Span leaf_types) { + std::vector intervals = GetFullIntervals(leaf_types); + std::vector dim_extents; + dim_extents.reserve(intervals.size()); + for (const InterpValueInterval& interval : intervals) { + dim_extents.push_back(interval.max()); + } + return NdRegion::MakeFromNdInterval(NdInterval(std::move(intervals)), + std::move(dim_extents)); +} + +} // namespace + +// -- class MatchExhaustivenessChecker + +MatchExhaustivenessChecker::MatchExhaustivenessChecker( + const Span& matched_expr_span, const ImportData& import_data, + const TypeInfo& type_info, const Type& matched_type) + : matched_expr_span_(matched_expr_span), + import_data_(import_data), + type_info_(type_info), + matched_type_(matched_type), + leaf_types_(GetLeafTypes(matched_type, matched_expr_span, file_table())), + remaining_(MakeFullNdRegion(leaf_types_)) {} + +bool MatchExhaustivenessChecker::IsExhaustive() const { + return remaining_.IsEmpty(); +} + +bool MatchExhaustivenessChecker::AddPattern(const NameDefTree& pattern) { + VLOG(5) << "MatchExhaustivenessChecker::AddPattern: `" << pattern.ToString() + << "` matched_type: `" << matched_type_.ToString() << "` @ " + << pattern.span().ToString(file_table()); + + NdIntervalWithEmpty this_pattern_interval = PatternToInterval( + pattern, matched_type_, leaf_types_, type_info_, import_data_); + remaining_ = remaining_.SubtractInterval(this_pattern_interval); + return IsExhaustive(); +} + +std::optional +MatchExhaustivenessChecker::SampleSimplestUncoveredValue() const { + // If there are no uncovered regions, we are fully exhaustive. + if (remaining_.IsEmpty()) { + return std::nullopt; + } + + // For now, just choose the first uncovered region. + const NdInterval& nd_interval = remaining_.disjoint().front(); + std::vector components; + + // For each dimension of the region, grab the lower bound (i.e. the simplest + // value in that interval). + for (int64_t i = 0; i < nd_interval.dims().size(); ++i) { + const Type& type = *leaf_types_[i]; + const InterpValueInterval& interval = nd_interval.dims()[i]; + const InterpValue& min = interval.min(); + if (type.IsEnum()) { + // We have to project back from dense space to enum name space. + const EnumType& enum_type = type.AsEnum(); + const EnumDef& enum_def = enum_type.nominal_type(); + int64_t member_index = min.GetBitValueUnsigned().value(); + CHECK_LT(member_index, enum_def.values().size()) + << "Member index out of bounds: " << member_index + << " for enum: " << enum_type.ToString(); + const EnumMember& member = enum_def.values()[member_index]; + InterpValue member_value = + type_info_.GetConstExpr(member.name_def).value(); + VLOG(5) << "SampleSimplestUncoveredValue; enum_type: " + << enum_type.ToString() << " member_index: " << member_index + << " member: " << member.name_def->ToString() + << " member_value: " << member_value.ToString(); + components.push_back(std::move(member_value)); + } else { + components.push_back(min); + } + } + + // If we have a single component, return it directly; otherwise, return a + // tuple. + if (components.size() == 1) { + return components[0]; + } + return InterpValue::MakeTuple(components); +} + +InterpValueInterval MakeFullIntervalForEnumType(const EnumType& enum_type) { + int64_t bit_count = enum_type.size().GetAsInt64().value(); + const EnumDef& enum_def = enum_type.nominal_type(); + int64_t enum_value_count = enum_def.values().size(); + VLOG(5) << "MakeFullIntervalForEnumType; enum_type: " << enum_type.ToString() + << " enum_value_count: " << enum_value_count; + CHECK_GT(enum_value_count, 0) + << "Cannot make full interval for enum type with no values: " + << enum_type.ToString(); + // Note: regardless of the requested underlying type of the enum we use a + // dense unsigned space to represent the values present in the enum namespace. + InterpValue min = InterpValue::MakeUBits(bit_count, 0); + InterpValue max = InterpValue::MakeUBits(bit_count, enum_value_count - 1); + InterpValueInterval result(min, max); + VLOG(5) << "MakeFullIntervalForEnumType; result: " + << result.ToString(/*show_types=*/false); + return result; +} + +std::optional GetEnumMemberIndex(const EnumType& enum_type, + const InterpValue& value, + const ImportData& import_data) { + const EnumDef& enum_def = enum_type.nominal_type(); + const TypeInfo& type_info = + *import_data.GetRootTypeInfoForNode(&enum_def).value(); + for (int64_t i = 0; i < enum_def.values().size(); ++i) { + const EnumMember& member = enum_def.values()[i]; + InterpValue member_val = type_info.GetConstExpr(member.name_def).value(); + if (member_val == value) { + return i; + } + } + return std::nullopt; +} + +InterpValueInterval MakePointIntervalForEnumType( + const EnumType& enum_type, const InterpValue& value, + const ImportData& import_data) { + CHECK(value.IsEnum()) + << "MakePointIntervalForEnumType; value is not an enum: " + << value.ToString(); + int64_t bit_count = enum_type.size().GetAsInt64().value(); + // The `value` provided is the `i`th value in the dense enum space -- let's + // determine that value `i`. + int64_t member_index = + GetEnumMemberIndex(enum_type, value, import_data).value(); + const InterpValue value_as_bits = + InterpValue::MakeUBits(bit_count, member_index); + VLOG(5) << "MakePointIntervalForEnumType; value_as_bits: " + << value_as_bits.ToString() << " member_index: " << member_index; + return InterpValueInterval(value_as_bits, value_as_bits); +} + +} // namespace xls::dslx diff --git a/xls/dslx/exhaustiveness/match_exhaustiveness_checker.h b/xls/dslx/exhaustiveness/match_exhaustiveness_checker.h new file mode 100644 index 0000000000..ae91d37b99 --- /dev/null +++ b/xls/dslx/exhaustiveness/match_exhaustiveness_checker.h @@ -0,0 +1,85 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef XLS_DSLX_EXHAUSTIVENESS_MATCH_EXHAUSTIVENESS_CHECKER_H_ +#define XLS_DSLX_EXHAUSTIVENESS_MATCH_EXHAUSTIVENESS_CHECKER_H_ + +#include +#include + +#include "xls/dslx/exhaustiveness/interp_value_interval.h" +#include "xls/dslx/exhaustiveness/nd_region.h" +#include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/pos.h" +#include "xls/dslx/import_data.h" +#include "xls/dslx/interp_value.h" +#include "xls/dslx/type_system/type.h" +#include "xls/dslx/type_system/type_info.h" + +namespace xls::dslx { + +// Object that we can incrementally feed match arms/patterns to and ask whether +// we've reached a point where the patterns are exhaustive. This is useful for +// flagging a warning right when we've reached the point that the arms are +// exhaustive. +class MatchExhaustivenessChecker { + public: + MatchExhaustivenessChecker(const Span& matched_expr_span, + const ImportData& import_data, + const TypeInfo& type_info, + const Type& matched_type); + + // Returns whether we've reached a point of exhaustiveness after incorporating + // the given `pattern`. + bool AddPattern(const NameDefTree& pattern); + + // Returns whether, based on already-added patterns, we're exhaustive. + bool IsExhaustive() const; + + // This method returns an optional "sample" value from the uncovered input + // space. It picks (for now) the first uncovered ND region and for each + // dimension, takes the lower bound. If there is only one dimension the value + // is returned directly; otherwise the components are aggregated into a tuple. + // + std::optional SampleSimplestUncoveredValue() const; + + private: + const FileTable& file_table() const { return type_info_.file_table(); } + + const Span matched_expr_span_; + + const ImportData& import_data_; + const TypeInfo& type_info_; + const Type& matched_type_; + + // Flattened version of the pattern tuple, each element of this vector is a + // dimension in the NdRegion below. + std::vector leaf_types_; + + // The remaining region of the value space that we need to test. + NdRegion remaining_; +}; + +// Returns the full interval range we use to represent the contents of an enum +// type -- exposed in the header for purposes of testing. +InterpValueInterval MakeFullIntervalForEnumType(const EnumType& enum_type); + +// Returns the point interval range we use to represent the contents of an enum +// value -- exposed in the header for purposes of testing. +InterpValueInterval MakePointIntervalForEnumType(const EnumType& enum_type, + const InterpValue& value, + const ImportData& import_data); + +} // namespace xls::dslx + +#endif // XLS_DSLX_EXHAUSTIVENESS_MATCH_EXHAUSTIVENESS_CHECKER_H_ diff --git a/xls/dslx/exhaustiveness/nd_region.cc b/xls/dslx/exhaustiveness/nd_region.cc new file mode 100644 index 0000000000..a02bdfc9e2 --- /dev/null +++ b/xls/dslx/exhaustiveness/nd_region.cc @@ -0,0 +1,235 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/dslx/exhaustiveness/nd_region.h" + +#include "absl/log/log.h" +#include "absl/strings/str_join.h" +namespace xls::dslx { +namespace { + +// Returns the tightest lower bound for the two given intervals. +InterpValue TightestLowerBound(const InterpValueInterval& a, + const InterpValueInterval& b) { + return a.min().Max(b.min()).value(); +} + +// Returns the tightest upper bound for the two given intervals. +InterpValue TightestUpperBound(const InterpValueInterval& a, + const InterpValueInterval& b) { + return a.max().Min(b.max()).value(); +} + +} // namespace + +/* static */ NdInterval NdInterval::MakePoint( + absl::Span point) { + std::vector dims; + for (const InterpValue& point_value : point) { + dims.push_back(InterpValueInterval(point_value, point_value)); + } + return NdInterval(dims); +} + +/* static */ NdInterval NdInterval::MakeContiguous( + absl::Span start_point, + absl::Span end_point) { + std::vector dims; + for (int64_t i = 0; i < start_point.size(); ++i) { + dims.push_back(InterpValueInterval(start_point[i], end_point[i])); + } + return NdInterval(dims); +} + +std::optional NdIntervalWithEmpty::ToNonEmpty() const { + std::vector non_empty_dims; + for (const auto& dim : dims_) { + if (dim.has_value()) { + non_empty_dims.push_back(dim.value()); + } else { + return std::nullopt; + } + } + return NdInterval(std::move(non_empty_dims)); +} + +std::string NdIntervalWithEmpty::ToString(bool show_types) const { + std::vector dims_str; + for (const std::optional& dim : dims_) { + if (dim.has_value()) { + dims_str.push_back(dim.value().ToString(show_types)); + } else { + dims_str.push_back(""); + } + } + return absl::StrCat("[", absl::StrJoin(dims_str, ", "), "]"); +} + +NdInterval::NdInterval(std::vector dims) + : dims_(std::move(dims)) {} + +bool NdInterval::Covers(const NdInterval& other) const { + for (int64_t i = 0; i < dims_.size(); ++i) { + if (!dims_[i].Covers(other.dims_[i])) { + return false; + } + } + return true; +} + +bool NdInterval::Intersects(const NdInterval& other) const { + CHECK_EQ(dims_.size(), other.dims_.size()) + << "Cannot intersect intervals with different numbers of dimensions: " + << ToString(/*show_types=*/false) << " and " + << other.ToString(/*show_types=*/false); + for (int64_t i = 0; i < dims_.size(); ++i) { + if (!dims_[i].Intersects(other.dims_[i])) { + return false; + } + } + return true; +} + +std::vector NdInterval::SubtractInterval( + const NdIntervalWithEmpty& other_with_empty) const { + std::optional non_empty = other_with_empty.ToNonEmpty(); + // If there's zero volume, subtraction is trivial. + if (!non_empty.has_value()) { + return {*this}; + } + return SubtractInterval(non_empty.value()); +} + +std::vector NdInterval::SubtractInterval( + const NdInterval& other) const { + CHECK_EQ(dims_.size(), other.dims_.size()) + << "Cannot subtract intervals with different numbers of dimensions: " + << ToString(/*show_types=*/false) << " and " + << other.ToString(/*show_types=*/false); + if (!other.Intersects(*this)) { + return {*this}; + } + if (other.Covers(*this)) { + return {}; + } + + VLOG(5) << "SubtractInterval; subtracting: " + << other.ToString(/*show_types=*/false) + << " from: " << ToString(/*show_types=*/false); + + std::vector pieces; + + // We'll work with a copy of our interval dimensions, + // "remaining" will be gradually shrunk to the (to-be-removed) intersection. + std::vector remaining = dims_; + for (int64_t dim_idx = 0; dim_idx < remaining.size(); ++dim_idx) { + const InterpValueInterval& remaining_interval = remaining[dim_idx]; + + // Compute the intersection in dimension `dim_idx`: + // `lower` is the maximum of our lower bound and the other's lower bound. + // `upper` is the minimum of our upper bound and the other's upper bound. + InterpValue lower = + TightestLowerBound(remaining[dim_idx], other.dims_[dim_idx]); + InterpValue upper = + TightestUpperBound(remaining[dim_idx], other.dims_[dim_idx]); + + // If there is a gap on the lower side, "peel off" that slice. + if (remaining_interval.min() < lower) { + VLOG(5) << absl::StreamFormat( + "remaining lower for dimension %d is %s which is < %s", dim_idx, + remaining_interval.min().ToString(/*show_types=*/true), + lower.ToString(/*show_types=*/true)); + std::vector new_dims = remaining; + new_dims[dim_idx] = InterpValueInterval(remaining_interval.min(), + lower.Decrement().value()); + pieces.push_back(NdInterval(new_dims)); + + // Update the remaining interval so that it now starts at L. + remaining[dim_idx] = InterpValueInterval(lower, remaining_interval.max()); + } + + // If there is a gap on the upper side, "peel off" that slice. + if (upper < remaining_interval.max()) { + std::vector new_dims = remaining; + new_dims[dim_idx] = InterpValueInterval(upper.Increment().value(), + remaining_interval.max()); + pieces.push_back(NdInterval(new_dims)); + + // Update the remaining interval so that it now ends at U. + remaining[dim_idx] = InterpValueInterval(remaining_interval.min(), upper); + } + } + + VLOG(5) << "Resulting pieces: " + << absl::StrJoin(pieces, ", ", + [&](std::string* out, const NdInterval& interval) { + absl::StrAppend( + out, interval.ToString(/*show_types=*/false)); + }); + return pieces; +} + +std::string NdInterval::ToString(bool show_types) const { + std::vector lower_components; + std::vector upper_components; + for (const auto& dim : dims_) { + // Assume that each 'dim' (of type InterpValueInterval) exposes: + // - a method lower() that returns the lower bound (an InterpValue) + // - a method upper() that returns the upper bound (an InterpValue) + // and that InterpValue has a ToString(bool) method. + lower_components.push_back(dim.min().ToString(/*humanize=*/!show_types)); + upper_components.push_back(dim.max().ToString(/*humanize=*/!show_types)); + } + // Build lower and upper corner strings. + std::string lower_str = + absl::StrCat("[", absl::StrJoin(lower_components, ", "), "]"); + std::string upper_str = + absl::StrCat("[", absl::StrJoin(upper_components, ", "), "]"); + // Final string: a two-element list [lower_corner, upper_corner]. + return absl::StrCat("[", lower_str, ", ", upper_str, "]"); +} + +std::string NdRegion::ToString(bool show_types) const { + std::string guts = absl::StrJoin( + disjoint_, ", ", [&](std::string* out, const NdInterval& interval) { + absl::StrAppend(out, interval.ToString(show_types)); + }); + return absl::StrCat("{", guts, "}"); +} + +ABSL_MUST_USE_RESULT NdRegion +NdRegion::SubtractInterval(const NdIntervalWithEmpty& other) const { + std::vector new_disjoint; + for (const NdInterval& old_interval : disjoint_) { + std::vector new_interval = old_interval.SubtractInterval(other); + new_disjoint.insert(new_disjoint.end(), new_interval.begin(), + new_interval.end()); + VLOG(5) << "SubtractInverval; subtracting " + << other.ToString(/*show_types=*/false) << " from old interval " + << old_interval.ToString(/*show_types=*/false) << " produced " + << absl::StrJoin(new_interval, ", ", + [&](std::string* out, const NdInterval& interval) { + absl::StrAppend(out, interval.ToString( + /*show_types=*/false)); + }); + } + NdRegion result(dim_extents_, new_disjoint); + VLOG(5) << "SubtractInterval; region-level result; subtracting " + << other.ToString(/*show_types=*/false) << " from " + << ToString(/*show_types=*/false) << " produced " + << result.ToString(/*show_types=*/false); + return result; +} + +} // namespace xls::dslx diff --git a/xls/dslx/exhaustiveness/nd_region.h b/xls/dslx/exhaustiveness/nd_region.h new file mode 100644 index 0000000000..a88873ee8a --- /dev/null +++ b/xls/dslx/exhaustiveness/nd_region.h @@ -0,0 +1,137 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef XLS_DSLX_EXHAUSTIVENESS_ND_REGION_H_ +#define XLS_DSLX_EXHAUSTIVENESS_ND_REGION_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/types/span.h" +#include "xls/dslx/exhaustiveness/interp_value_interval.h" +#include "xls/dslx/interp_value.h" + +namespace xls::dslx { + +class NdIntervalWithEmpty; + +// Represents a contiguous interval in n-dimensional space. +// +// This is also called a "hyper-rectangle" because it's a contiguous rectangular +// space in >= 1D. +// +// It is lower bounded (inclusive) by the minimum values of each dimension and +// upper bounded (inclusive) by the maximum values of each dimension. We prefer +// inclusive bounds so we can easily represent the maximum value in a given type +// without worrying about whether the exclusive limit requires an extra bit. +class NdInterval { + public: + // Note: we do not provide a `MakeFull` factory on NdInterval because + // we expect that the upper layers will determine the full range for types + // that are sparse within some bit range, like enums. We don't want to provide + // a method that could be used to accidentally create an over-wide range for a + // type such as an enum because we think we need to create a full range for + // its underlying bit space. (i.e. imagine an enum with 3 values in a u2 + // space). + + static NdInterval MakePoint(absl::Span point); + + // Note that end_point is included. + static NdInterval MakeContiguous(absl::Span start_point, + absl::Span end_point); + + explicit NdInterval(std::vector dims); + + bool Covers(const NdInterval& other) const; + + bool Intersects(const NdInterval& other) const; + + // Subtracts `other` from this contiguous n-dimensional interval space. + // This works by "peeling off" slices along each dimension where + // our interval extends beyond the subtracting interval. + // (For example, in 1D subtracting [3,6] from [0,9] produces [0,2] on the left + // and [7,9] on the right.) + std::vector SubtractInterval( + const NdIntervalWithEmpty& other) const; + std::vector SubtractInterval(const NdInterval& other) const; + + std::string ToString(bool show_types) const; + + // Note: not safe to hold across a mutating operation. + absl::Span dims() const { return dims_; } + + private: + std::vector dims_; +}; + +// Variation on NdInterval above that allows any given dimension to have an +// interval with no volume. +// +// This is useful because ranges can be zero volume or enums can have zero +// definitions. +class NdIntervalWithEmpty { + public: + explicit NdIntervalWithEmpty( + std::vector> dims) + : dims_(std::move(dims)) {} + + std::optional ToNonEmpty() const; + + std::string ToString(bool show_types) const; + + private: + std::vector> dims_; +}; + +class NdRegion { + public: + // Note: prefer this function when creating a "Full" region. For more sparse + // types such as enums the full range is not the same as the underlying bit + // space. + static NdRegion MakeFromNdInterval(const NdInterval& interval, + std::vector dim_extents) { + return NdRegion(std::move(dim_extents), {interval}); + } + + static NdRegion MakeEmpty(std::vector dim_extents) { + NdRegion region(std::move(dim_extents), {}); + return region; + } + + ABSL_MUST_USE_RESULT NdRegion + SubtractInterval(const NdIntervalWithEmpty& other) const; + + bool IsEmpty() const { return disjoint_.empty(); } + + std::string ToString(bool show_types) const; + + // Note: not safe to hold across a mutating operation. + absl::Span disjoint() const { return disjoint_; } + + private: + explicit NdRegion(std::vector dim_extents, + std::vector disjoint) + : dim_extents_(std::move(dim_extents)), disjoint_(std::move(disjoint)) {} + + // The extents (i.e. limits) in all the dimensions of the regions. + std::vector dim_extents_; + + // Disjoint intervals that describe what is filled in within the region. + std::vector disjoint_; +}; + +} // namespace xls::dslx + +#endif // XLS_DSLX_EXHAUSTIVENESS_ND_REGION_H_ diff --git a/xls/dslx/exhaustiveness/nd_region_test.cc b/xls/dslx/exhaustiveness/nd_region_test.cc new file mode 100644 index 0000000000..8e1729eff3 --- /dev/null +++ b/xls/dslx/exhaustiveness/nd_region_test.cc @@ -0,0 +1,307 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/dslx/exhaustiveness/nd_region.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace xls::dslx { +namespace { + +// Helper for making a full NdRegion according to a full bit value space, with +// each dimension going [0, dim_extent] (inclusive). +// +// We don't use this in the production code because we don't want to make +// it easy to accidentally misinterpret types like enum which may not populate +// the full underlying bit space but rather some sparse subset of it. +NdInterval MakeFullNdInterval(absl::Span dim_extents) { + std::vector intervals; + for (const InterpValue& dim_extent : dim_extents) { + bool is_signed = dim_extent.IsSigned(); + int64_t bit_count = dim_extent.GetBitCount().value(); + InterpValue min = InterpValue::MakeZeroValue(is_signed, bit_count); + InterpValue max = dim_extent; + intervals.push_back(InterpValueInterval(min, max)); + } + return NdInterval(std::move(intervals)); +} + +} // namespace + +TEST(NdRegionTest, SimpleRegionIntersectsSelf) { + std::vector dim_extents = { + InterpValue::MakeUBits(1, 1), + InterpValue::MakeUBits(1, 1), + }; + const NdInterval u1_full = MakeFullNdInterval(dim_extents); + EXPECT_TRUE(u1_full.Intersects(u1_full)); + EXPECT_EQ(u1_full.ToString(/*show_types=*/false), "[[0, 0], [1, 1]]"); +} + +TEST(NdRegionTest, Rectangle2DSubtractCorner) { + std::vector dim_extents = { + InterpValue::MakeUBits(1, 1), + InterpValue::MakeUBits(1, 1), + }; + const NdInterval u1_full = MakeFullNdInterval(dim_extents); + const NdInterval u1_corner = + NdInterval::MakePoint({dim_extents[0], dim_extents[1]}); + std::vector pieces = u1_full.SubtractInterval(u1_corner); + EXPECT_EQ(pieces.size(), 2); + // original space is: + // [0,0] [0,1] [1,0] [1,1] + // we subtracted the last one, so the resulting space is: + // [0,0] [0,1] [1,0] + // which can be represented as [[0,0]..[0,1]] union [[1,0]..[1,0]] + EXPECT_EQ(pieces[0].ToString(/*show_types=*/false), "[[0, 0], [0, 1]]"); + EXPECT_EQ(pieces[1].ToString(/*show_types=*/false), "[[1, 0], [1, 0]]"); +} + +TEST(NdRegionTest, Rectangle2DSubtractAllWhereXEquals1) { + std::vector dim_extents = { + InterpValue::MakeUBits(1, 1), + InterpValue::MakeUBits(1, 1), + }; + const NdInterval u1_full = MakeFullNdInterval(dim_extents); + EXPECT_EQ(u1_full.ToString(/*show_types=*/false), "[[0, 0], [1, 1]]"); + + // Make a slice that represents all the points where x = 1. + const NdInterval u1_x1 = NdInterval::MakeContiguous( + { + InterpValue::MakeUBits(1, 1), // x = 1 + InterpValue::MakeUBits(1, 0), // y = 0 + }, + dim_extents); + EXPECT_EQ(u1_x1.ToString(/*show_types=*/false), "[[1, 0], [1, 1]]"); + + std::vector pieces = u1_full.SubtractInterval(u1_x1); + EXPECT_EQ(pieces.size(), 1); + EXPECT_EQ(pieces[0].ToString(/*show_types=*/false), "[[0, 0], [0, 1]]"); +} + +TEST(NdRegionTest, Rectangle2DSubtractSelf) { + std::vector dim_extents = { + InterpValue::MakeUBits(1, 1), + InterpValue::MakeUBits(1, 1), + }; + const NdInterval u1_full = MakeFullNdInterval(dim_extents); + std::vector pieces = u1_full.SubtractInterval(u1_full); + EXPECT_TRUE(pieces.empty()); +} + +// From the following rectangle: +// +-------+-------+-------+-------+ +// | (0,0) | (1,0) | (2,0) | (3,0) | +// +-------+-------+-------+-------+ +// | (0,1) | (1,1) | (2,1) | (3,1) | < +// +-------+-------+-------+-------+ < slice these two rows out +// | (0,2) | (1,2) | (2,2) | (3,2) | < +// +-------+-------+-------+-------+ +// | (0,3) | (1,3) | (2,3) | (3,3) | +// +-------+-------+-------+-------+ +TEST(NdRegionTest, Rectangle2DSubtractCenterSliceInLeadingDim) { + // Make a rectangle from point (0, 0) to point (3, 3) inclusive on both ends. + std::vector dim_extents = { + InterpValue::MakeUBits(3, 3), + InterpValue::MakeUBits(3, 3), + }; + const NdInterval full = MakeFullNdInterval(dim_extents); + + const NdInterval center_slice = NdInterval::MakeContiguous( + { + InterpValue::MakeUBits(3, 0), + InterpValue::MakeUBits(3, 1), + }, + { + InterpValue::MakeUBits(3, 3), + InterpValue::MakeUBits(3, 2), + }); + + // The remainder should be two slices: from point (0, 0) to (0, 3) and from + // point (3, 0) to (3, 3). + std::vector pieces = full.SubtractInterval(center_slice); + EXPECT_EQ(pieces.size(), 2); + EXPECT_EQ(pieces[0].ToString(/*show_types=*/false), "[[0, 0], [3, 0]]"); + EXPECT_EQ(pieces[1].ToString(/*show_types=*/false), "[[0, 3], [3, 3]]"); +} + +// From the following rectangle: +// +-------+-------+-------+-------+ +// | (0,0) | (1,0) | (2,0) | (3,0) | +// +-------+-------+-------+-------+ +// | (0,1) | (1,1) | (2,1) | (3,1) | +// +-------+-------+-------+-------+ +// | (0,2) | (1,2) | (2,2) | (3,2) | +// +-------+-------+-------+-------+ +// | (0,3) | (1,3) | (2,3) | (3,3) | +// +-------+-------+-------+-------+ +// ^^^^^^^^^^^^^^^ sub these two columns out +TEST(NdRegionTest, Rectangle2DSubtractCenterSliceInTrailingDim) { + // Make a rectangle from point (0, 0) to point (3, 3) inclusive on both ends. + std::vector dim_extents = { + InterpValue::MakeUBits(3, 3), + InterpValue::MakeUBits(3, 3), + }; + const NdInterval full = MakeFullNdInterval(dim_extents); + + // Slice out the center two values of the rectangle, i.e. (0, 1) to (3, 2) -- + // not we take all the y values and two of the x values. + const NdInterval center_slice = NdInterval::MakeContiguous( + { + InterpValue::MakeUBits(3, 1), + InterpValue::MakeUBits(3, 0), + }, + { + InterpValue::MakeUBits(3, 2), + InterpValue::MakeUBits(3, 3), + }); + + // The remainder should be two slices: from point (0, 0) to (0, 3) and from + // point (3, 0) to (3, 3). + std::vector pieces = full.SubtractInterval(center_slice); + EXPECT_EQ(pieces.size(), 2); + EXPECT_EQ(pieces[0].ToString(/*show_types=*/false), "[[0, 0], [0, 3]]"); + EXPECT_EQ(pieces[1].ToString(/*show_types=*/false), "[[3, 0], [3, 3]]"); +} + +TEST(NdRegionTest, NdIntervalIntersectsProperly) { + // Construct a 2D interval A from (0,0) to (5,10). + std::vector startA = { + InterpValue::MakeU32(0), // x start + InterpValue::MakeU32(0) // y start + }; + std::vector endA = { + InterpValue::MakeU32(5), // x end + InterpValue::MakeU32(10) // y end + }; + NdInterval intervalA = NdInterval::MakeContiguous(startA, endA); + + // Construct a 2D interval B from (3,2) to (7,12) that overlaps with A. + std::vector startB = { + InterpValue::MakeU32(3), // x start (overlaps with A: 3 <= x <= 5) + InterpValue::MakeU32(2) // y start (overlaps with A: 2 <= y <= 10) + }; + std::vector endB = { + InterpValue::MakeU32(7), // x end + InterpValue::MakeU32(12) // y end + }; + NdInterval intervalB = NdInterval::MakeContiguous(startB, endB); + + // These intervals should intersect. + EXPECT_TRUE(intervalA.Intersects(intervalB)); + EXPECT_TRUE(intervalB.Intersects(intervalA)); + + // Construct a 2D interval C from (6,0) to (10,10) that is disjoint + // from A along the x-dimension (since A is [0,5]). + std::vector startC = { + InterpValue::MakeU32(6), // x start (outside A's x range) + InterpValue::MakeU32(0) // y start + }; + std::vector endC = { + InterpValue::MakeU32(10), // x end + InterpValue::MakeU32(10) // y end (overlaps with A in y) + }; + NdInterval intervalC = NdInterval::MakeContiguous(startC, endC); + + // These intervals should not intersect because intervalC's x-dimension is + // disjoint. + EXPECT_FALSE(intervalA.Intersects(intervalC)); + EXPECT_FALSE(intervalC.Intersects(intervalA)); +} + +// From the following rectangle: +// +-------+-------+-------+-------+ +// | (0,0) | (1,0) | (2,0) | (3,0) | +// +-------+-------+-------+-------+ +// | (0,1) |*(1,1)*|*(2,1)*| (3,1) | +// +-------+-------+-------+-------+ +// | (0,2) |*(1,2)*|*(2,2)*| (3,2) | +// +-------+-------+-------+-------+ +// | (0,3) | (1,3) | (2,3) | (3,3) | +// +-------+-------+-------+-------+ +TEST(NdRegionTest, Rectangle2DSubtractCentralRectangle) { + // Create a full rectangle from (0,0) to (3,3) + std::vector dim_extents = {InterpValue::MakeUBits(3, 3), + InterpValue::MakeUBits(3, 3)}; + const NdInterval full = MakeFullNdInterval(dim_extents); + + // Create a central sub-rectangle from (1,1) to (2,2) + NdInterval central = NdInterval::MakeContiguous( + {InterpValue::MakeUBits(3, 1), InterpValue::MakeUBits(3, 1)}, + {InterpValue::MakeUBits(3, 2), InterpValue::MakeUBits(3, 2)}); + + // We should end up with: + // * (0,0) to (0,3) -- left slice in above diagram + // * (1,0) to (2,0) -- upper slice in above diagram + // * (1,3) to (2,3) -- lower slice in above diagram + // * (3,0) to (3,3) -- right slice in above diagram + std::vector pieces = full.SubtractInterval(central); + + std::vector expected = {"[[0, 0], [0, 3]]", "[[1, 0], [2, 0]]", + "[[1, 3], [2, 3]]", "[[3, 0], [3, 3]]"}; + std::vector actual; + for (const NdInterval& interval : pieces) { + actual.push_back(interval.ToString(false)); + } + std::sort(expected.begin(), expected.end()); + std::sort(actual.begin(), actual.end()); + EXPECT_EQ(actual, expected); +} + +TEST(NdRegionTest, OneDimensionalSubtractMiddle) { + // Create a 1D full interval: [0, 9] + std::vector full_extent = {InterpValue::MakeU32(9)}; + NdInterval full = MakeFullNdInterval(full_extent); + + // Define a subtraction interval in the middle: [3, 6] + NdInterval sub = NdInterval::MakeContiguous({InterpValue::MakeU32(3)}, + {InterpValue::MakeU32(6)}); + std::vector pieces = full.SubtractInterval(sub); + + // Expect two pieces: + // left piece: [0] ... [2] (since 3 - 1 = 2) + // right piece: [7] ... [9] (since 6 + 1 = 7) + EXPECT_EQ(pieces.size(), 2); + std::vector expected = {"[[0], [2]]", "[[7], [9]]"}; + std::vector actual; + for (const NdInterval& piece : pieces) { + actual.push_back(piece.ToString(/*show_types=*/false)); + } + std::sort(expected.begin(), expected.end()); + std::sort(actual.begin(), actual.end()); + EXPECT_EQ(actual, expected); +} + +TEST(NdRegionTest, Rectangle2DSubtractNonIntersecting) { + // Create a 2D full region: from (0,0) to (3,3) + std::vector dim_extents = { + InterpValue::MakeUBits(3, 3), + InterpValue::MakeUBits(3, 3), + }; + const NdInterval full = MakeFullNdInterval(dim_extents); + + // Create a subtraction interval (a point) that lies completely outside + // the full region. (Here, the point (4,4) is outside since the full region + // ends at (3,3)). + NdInterval non_intersecting = NdInterval::MakePoint( + {InterpValue::MakeUBits(3, 4), InterpValue::MakeUBits(3, 4)}); + + std::vector pieces = full.SubtractInterval(non_intersecting); + // Since there is no intersection, expect the original region to be returned. + EXPECT_EQ(pieces.size(), 1); + EXPECT_EQ(pieces[0].ToString(/*show_types=*/false), "[[0, 0], [3, 3]]"); +} + +} // namespace xls::dslx diff --git a/xls/dslx/import_data.h b/xls/dslx/import_data.h index d223c96f3f..1c867c860a 100644 --- a/xls/dslx/import_data.h +++ b/xls/dslx/import_data.h @@ -157,6 +157,8 @@ class ImportData { absl::StatusOr GetRootTypeInfoForNode(const AstNode* node); absl::StatusOr GetRootTypeInfoForNode( const AstNode* node) const; + + // As above but gets the type info for a directly-provided `module`. absl::StatusOr GetRootTypeInfo(const Module* module); // The "top level bindings" for a given module are the values that get @@ -235,7 +237,7 @@ class ImportData { WarningKindSet); friend ImportData CreateImportDataForTest( - std::unique_ptr vfs); + std::unique_ptr vfs, WarningKindSet warnings); friend std::unique_ptr CreateImportDataPtrForTest(); ImportData(std::filesystem::path stdlib_path, diff --git a/xls/dslx/interp_value.cc b/xls/dslx/interp_value.cc index 7e047b6ee5..dbbe385a07 100644 --- a/xls/dslx/interp_value.cc +++ b/xls/dslx/interp_value.cc @@ -94,6 +94,14 @@ std::string TagToString(InterpValueTag tag) { Bits(bit_count)}; } +/* static */ InterpValue InterpValue::MakeOneValue(bool is_signed, + int64_t bit_count) { + CHECK_GT(bit_count, 0); + return InterpValue( + is_signed ? InterpValueTag::kSBits : InterpValueTag::kUBits, + Bits(bit_count).UpdateWithSet(0, true)); +} + /* static */ InterpValue InterpValue::MakeMaxValue(bool is_signed, int64_t bit_count) { auto bits = Bits::AllOnes(bit_count); @@ -487,6 +495,40 @@ absl::StatusOr InterpValue::Lt(const InterpValue& other) const { return Compare(*this, other, &bits_ops::ULessThan, &bits_ops::SLessThan); } +std::optional InterpValue::Increment() const { + CHECK(IsBits()); + Bits b = GetBitsOrDie(); + if (*this == MakeMaxValue(IsSigned(), GetBitCount().value())) { + return std::nullopt; // Overflow case. + } + return InterpValue(tag_, bits_ops::Increment(b)); +} + +std::optional InterpValue::Decrement() const { + CHECK(IsBits()); + if (*this == MakeMinValue(IsSigned(), GetBitCount().value())) { + return std::nullopt; // Underflow case. + } + Bits b = GetBitsOrDie(); + return InterpValue(tag_, bits_ops::Decrement(b)); +} + +absl::StatusOr InterpValue::Min(const InterpValue& other) const { + XLS_ASSIGN_OR_RETURN(InterpValue lt, Lt(other)); + if (lt.IsTrue()) { + return *this; + } + return other; +} + +absl::StatusOr InterpValue::Max(const InterpValue& other) const { + XLS_ASSIGN_OR_RETURN(InterpValue lt, Lt(other)); + if (lt.IsTrue()) { + return other; + } + return *this; +} + absl::StatusOr InterpValue::BitwiseNegate() const { XLS_ASSIGN_OR_RETURN(Bits b, GetBits()); return InterpValue(tag_, bits_ops::Not(b)); diff --git a/xls/dslx/interp_value.h b/xls/dslx/interp_value.h index bb69fb66d0..0336fe9edf 100644 --- a/xls/dslx/interp_value.h +++ b/xls/dslx/interp_value.h @@ -101,6 +101,7 @@ class InterpValue { static InterpValue MakeSBits(int64_t bit_count, int64_t value); static InterpValue MakeZeroValue(bool is_signed, int64_t bit_count); + static InterpValue MakeOneValue(bool is_signed, int64_t bit_count); static InterpValue MakeMaxValue(bool is_signed, int64_t bit_count); static InterpValue MakeMinValue(bool is_signed, int64_t bit_count); @@ -256,6 +257,9 @@ class InterpValue { absl::StatusOr Add(const InterpValue& other) const; absl::StatusOr Sub(const InterpValue& other) const; + std::optional Decrement() const; + std::optional Increment() const; + absl::StatusOr Mul(const InterpValue& other) const; absl::StatusOr Shl(const InterpValue& other) const; absl::StatusOr Shrl(const InterpValue& other) const; @@ -284,6 +288,10 @@ class InterpValue { absl::StatusOr Gt(const InterpValue& other) const; absl::StatusOr Ge(const InterpValue& other) const; + absl::StatusOr Min(const InterpValue& other) const; + + absl::StatusOr Max(const InterpValue& other) const; + // Performs the signed comparison defined by "method". // // "method" should be a value in the set {slt, sle, sgt, sge} or an diff --git a/xls/dslx/interpreter_main.cc b/xls/dslx/interpreter_main.cc index 8e4b3fdf0d..865ce597b6 100644 --- a/xls/dslx/interpreter_main.cc +++ b/xls/dslx/interpreter_main.cc @@ -145,13 +145,8 @@ absl::StatusOr RealMain( std::optional xml_output_file, EvaluatorType evaluator) { XLS_ASSIGN_OR_RETURN( WarningKindSet warnings, - WarningKindSetFromDisabledString(absl::GetFlag(FLAGS_disable_warnings))); - - XLS_ASSIGN_OR_RETURN( - const WarningKindSet warnings_to_enable, - WarningKindSetFromString(absl::GetFlag(FLAGS_enable_warnings))); - - warnings |= warnings_to_enable; + GetWarningsSetFromFlags(absl::GetFlag(FLAGS_enable_warnings), + absl::GetFlag(FLAGS_disable_warnings))); RealFilesystem vfs; diff --git a/xls/dslx/interpreter_test.py b/xls/dslx/interpreter_test.py index 6234111112..d5eb8f464e 100644 --- a/xls/dslx/interpreter_test.py +++ b/xls/dslx/interpreter_test.py @@ -141,7 +141,7 @@ def test_fail_incomplete_match(self): program, warnings_as_errors=False, want_error=True ) self.assertIn( - 'The program being interpreted failed! The value was not matched', + 'Match pattern is not exhaustive', stderr, ) diff --git a/xls/dslx/ir_convert/function_converter.cc b/xls/dslx/ir_convert/function_converter.cc index b330fe7e3c..cfde052e2e 100644 --- a/xls/dslx/ir_convert/function_converter.cc +++ b/xls/dslx/ir_convert/function_converter.cc @@ -1123,11 +1123,10 @@ absl::Status FunctionConverter::HandleBuiltinWideningCast( } absl::Status FunctionConverter::HandleMatch(const Match* node) { - if (node->arms().empty() || - !node->arms().back()->patterns()[0]->IsIrrefutable()) { + if (node->arms().empty()) { return IrConversionErrorStatus( node->span(), - "Only matches with trailing irrefutable patterns (i.e. `_ => ...`) " + "Only matches with complete patterns (i.e. a trailing `_ => ...`) " "are currently supported for IR conversion.", file_table()); } diff --git a/xls/dslx/ir_convert/ir_converter_main.cc b/xls/dslx/ir_convert/ir_converter_main.cc index 90070f787f..61eed37aaa 100644 --- a/xls/dslx/ir_convert/ir_converter_main.cc +++ b/xls/dslx/ir_convert/ir_converter_main.cc @@ -93,9 +93,13 @@ absl::Status RealMain(absl::Span paths) { bool verify_ir = ir_converter_options.verify(); bool convert_tests = ir_converter_options.convert_tests(); bool warnings_as_errors = ir_converter_options.warnings_as_errors(); - XLS_ASSIGN_OR_RETURN(WarningKindSet enabled_warnings, - WarningKindSetFromDisabledString( - ir_converter_options.disable_warnings())); + + // Start with the default set, then enable the to-enable and then disable the + // to-disable. + XLS_ASSIGN_OR_RETURN( + WarningKindSet warnings, + GetWarningsSetFromFlags(ir_converter_options.enable_warnings(), + ir_converter_options.disable_warnings())); std::optional default_fifo_config; if (ir_converter_options.has_default_fifo_config()) { XLS_ASSIGN_OR_RETURN( @@ -107,7 +111,7 @@ absl::Status RealMain(absl::Span paths) { .emit_fail_as_assert = emit_fail_as_assert, .verify_ir = verify_ir, .warnings_as_errors = warnings_as_errors, - .enabled_warnings = enabled_warnings, + .enabled_warnings = warnings, .convert_tests = convert_tests, .default_fifo_config = default_fifo_config, }; diff --git a/xls/dslx/ir_convert/ir_converter_options_flags.cc b/xls/dslx/ir_convert/ir_converter_options_flags.cc index ae8fd70a5f..18ebe39d66 100644 --- a/xls/dslx/ir_convert/ir_converter_options_flags.cc +++ b/xls/dslx/ir_convert/ir_converter_options_flags.cc @@ -55,6 +55,9 @@ ABSL_FLAG(bool, verify, true, ABSL_FLAG(std::optional, disable_warnings, std::nullopt, "Comma-delimited list of warnings to disable -- not generally " "recommended, but can be used in exceptional circumstances"); +ABSL_FLAG(std::optional, enable_warnings, std::nullopt, + "Comma-delimited list of warnings to enable -- this is only useful " + "if/when some warnings are disabled in the default warning set"); ABSL_FLAG(bool, warnings_as_errors, true, "Whether to fail early, as an error, if warnings are detected"); ABSL_FLAG(std::optional, interface_proto_file, std::nullopt, diff --git a/xls/dslx/ir_convert/ir_converter_options_flags.proto b/xls/dslx/ir_convert/ir_converter_options_flags.proto index 4d6839d51c..d9dd25140e 100644 --- a/xls/dslx/ir_convert/ir_converter_options_flags.proto +++ b/xls/dslx/ir_convert/ir_converter_options_flags.proto @@ -36,4 +36,5 @@ message IrConverterOptionsFlagsProto { optional string interface_proto_file = 11; optional string interface_textproto_file = 12; optional FifoConfigProto default_fifo_config = 13; + optional string enable_warnings = 14; } diff --git a/xls/dslx/ir_convert/ir_converter_test.cc b/xls/dslx/ir_convert/ir_converter_test.cc index cc23eaad7a..05908ca631 100644 --- a/xls/dslx/ir_convert/ir_converter_test.cc +++ b/xls/dslx/ir_convert/ir_converter_test.cc @@ -380,7 +380,7 @@ TEST(IrConverterTest, MatchTupleOfTuplesRestOfTuple) { })); } -TEST(IrConverterTest, MatchRestOfTupleAsArmFails) { +TEST(IrConverterTest, MatchRestOfTupleAsTrailingArm) { const char* program = R"(fn f() -> u32 { let t = (u32:1, u32:2); @@ -389,12 +389,11 @@ TEST(IrConverterTest, MatchRestOfTupleAsArmFails) { (..) => u32:1 } })"; - auto import_data = CreateImportDataForTest(); - EXPECT_THAT( - ConvertOneFunctionForTest(program, "f", import_data, - ConvertOptions{.emit_positions = false}), - StatusIs(absl::StatusCode::kInternal, - HasSubstr("with trailing irrefutable patterns"))); + XLS_ASSERT_OK_AND_ASSIGN( + std::string converted, + ConvertOneFunctionForTest(program, "f", + ConvertOptions{.emit_positions = false})); + ExpectIr(converted, TestName()); } TEST(IrConverterTest, Struct) { diff --git a/xls/dslx/ir_convert/testdata/ir_converter_test_MatchRestOfTupleAsTrailingArm.ir b/xls/dslx/ir_convert/testdata/ir_converter_test_MatchRestOfTupleAsTrailingArm.ir new file mode 100644 index 0000000000..5052eae5fb --- /dev/null +++ b/xls/dslx/ir_convert/testdata/ir_converter_test_MatchRestOfTupleAsTrailingArm.ir @@ -0,0 +1,21 @@ +package test_module + +file_number 0 "test_module.x" + +top fn __test_module__f() -> bits[32] { + literal.1: bits[32] = literal(value=1, id=1) + literal.2: bits[32] = literal(value=2, id=2) + t: (bits[32], bits[32]) = tuple(literal.1, literal.2, id=3) + literal.6: bits[32] = literal(value=1, id=6) + tuple_index.5: bits[32] = tuple_index(t, index=0, id=5) + literal.4: bits[1] = literal(value=1, id=4) + eq.7: bits[1] = eq(literal.6, tuple_index.5, id=7) + and.8: bits[1] = and(literal.4, eq.7, id=8) + literal.10: bits[1] = literal(value=1, id=10) + and.11: bits[1] = and(and.8, literal.10, id=11) + concat.14: bits[1] = concat(and.11, id=14) + tuple_index.9: bits[32] = tuple_index(t, index=1, id=9) + literal.13: bits[32] = literal(value=1, id=13) + literal.12: bits[1] = literal(value=1, id=12) + ret priority_sel.15: bits[32] = priority_sel(concat.14, cases=[tuple_index.9], default=literal.13, id=15) +} diff --git a/xls/dslx/prove_quickcheck_main.cc b/xls/dslx/prove_quickcheck_main.cc index b40191f94b..e9c60ce4ee 100644 --- a/xls/dslx/prove_quickcheck_main.cc +++ b/xls/dslx/prove_quickcheck_main.cc @@ -47,6 +47,9 @@ ABSL_FLAG(std::string, dslx_stdlib_path, "Path to DSLX standard library directory."); ABSL_FLAG(std::string, test_filter, "", "Regexp that must be a full match of test name(s) to run."); +ABSL_FLAG(std::string, enable_warnings, "", + "Comma-delimited list of warnings to enable -- not generally " + "recommended, but can be used in exceptional circumstances"); ABSL_FLAG(std::string, disable_warnings, "", "Comma-delimited list of warnings to disable -- not generally " "recommended, but can be used in exceptional circumstances"); @@ -68,7 +71,8 @@ absl::StatusOr RealMain( std::optional xml_output_file) { XLS_ASSIGN_OR_RETURN( WarningKindSet warnings, - WarningKindSetFromDisabledString(absl::GetFlag(FLAGS_disable_warnings))); + GetWarningsSetFromFlags(absl::GetFlag(FLAGS_enable_warnings), + absl::GetFlag(FLAGS_disable_warnings))); XLS_ASSIGN_OR_RETURN(std::string program, GetFileContents(entry_module_path)); XLS_ASSIGN_OR_RETURN(std::string module_name, PathToName(entry_module_path)); diff --git a/xls/dslx/run_routines/run_routines_test.cc b/xls/dslx/run_routines/run_routines_test.cc index 59ea6ee89b..6029b36495 100644 --- a/xls/dslx/run_routines/run_routines_test.cc +++ b/xls/dslx/run_routines/run_routines_test.cc @@ -219,6 +219,8 @@ fn qc(x: MyEnum) -> bool { constexpr const char* kFilename = "test.x"; RunComparator jit_comparator(CompareMode::kJit); ParseAndTestOptions options; + options.warnings = + DisableWarning(kAllWarningsSet, WarningKind::kAlreadyExhaustiveMatch); options.run_comparator = &jit_comparator; XLS_ASSERT_OK_AND_ASSIGN( TestResultData result, @@ -351,7 +353,7 @@ fn bfloat16_bits_to_float32_bits_upcast_is_zero_pad(x: bits[BF16_TOTAL_SZ]) -> b // Look at the failure message to make sure the u16 is reported. std::vector failures = result.GetFailureMessages(); ASSERT_EQ(failures.size(), 1); - EXPECT_THAT(failures[0], HasSubstr("tests: [u16:34]")); + EXPECT_THAT(failures[0], HasSubstr("tests: [u16:")); } // An exhaustive quickcheck that fails just for one value in a decently large diff --git a/xls/dslx/stdlib/apfloat.x b/xls/dslx/stdlib/apfloat.x index acae738e94..f7e5b7aa29 100644 --- a/xls/dslx/stdlib/apfloat.x +++ b/xls/dslx/stdlib/apfloat.x @@ -374,7 +374,6 @@ pub fn cast_from_fixed_using_rne zero(false), (false, false) => APFloat { sign: is_negative, bexp, fraction }, - _ => qnan(), } } @@ -480,7 +479,6 @@ pub fn cast_from_fixed_using_rz zero(false), (false, false) => APFloat { sign: is_negative, bexp, fraction }, - _ => qnan(), } } @@ -736,7 +734,6 @@ pub fn upcast fail!("unsupported_kind", qnan()), } } } diff --git a/xls/dslx/tests/compound_eq.x b/xls/dslx/tests/compound_eq.x index 95521eb097..1ca027189c 100644 --- a/xls/dslx/tests/compound_eq.x +++ b/xls/dslx/tests/compound_eq.x @@ -21,14 +21,13 @@ fn main() -> bool { x == x } +fn slice_off_array(x: TestBlob) -> ((u2, u1), bool) { (x.1, x.2) } + // Manually expand a test blob into its leaf components to check equality. fn blob_eq(x: TestBlob, y: TestBlob) -> bool { - let zero = u32:0; - let one = u32:1; - match (x, y) { - ((x_arr, (x_tup1, x_tup2), x_bool), (y_arr, (y_tup1, y_tup2), y_bool)) => - x_arr[zero] == y_arr[zero] && x_arr[one] == y_arr[one] && x_tup1 == y_tup1 && - x_tup2 == y_tup2 && x_bool == y_bool, + match (slice_off_array(x), slice_off_array(y)) { + (((x_tup1, x_tup2), x_bool), ((y_tup1, y_tup2), y_bool)) => + x_tup1 == y_tup1 && x_tup2 == y_tup2 && x_bool == y_bool, } } @@ -44,12 +43,9 @@ fn eq_by_element(x: TestBlob[3], y: TestBlob[3]) -> bool { // Manually expand a test blob into its leaf components to check if any are // not equal. fn blob_neq(x: TestBlob, y: TestBlob) -> bool { - let zero = u32:0; - let one = u32:1; - match (x, y) { - ((x_arr, (x_tup1, x_tup2), x_bool), (y_arr, (y_tup1, y_tup2), y_bool)) => - x_arr[zero] != y_arr[zero] || x_arr[one] != y_arr[one] || x_tup1 != y_tup1 || - x_tup2 != y_tup2 || x_bool != y_bool, + match (slice_off_array(x), slice_off_array(y)) { + (((x_tup1, x_tup2), x_bool), ((y_tup1, y_tup2), y_bool)) => + x_tup1 != y_tup1 || x_tup2 != y_tup2 || x_bool != y_bool, } } diff --git a/xls/dslx/tests/errors/error_modules_test.py b/xls/dslx/tests/errors/error_modules_test.py index f3f08d321b..bdf51a798c 100644 --- a/xls/dslx/tests/errors/error_modules_test.py +++ b/xls/dslx/tests/errors/error_modules_test.py @@ -405,8 +405,8 @@ def test_match_not_exhaustive(self): stderr = self._run( 'xls/dslx/tests/errors/match_not_exhaustive.x' ) - self.assertIn('match_not_exhaustive.x:16:3-19:4', stderr) - self.assertIn('Only matches with trailing irrefutable patterns', stderr) + self.assertIn('match_not_exhaustive.x:16:5-19:6', stderr) + self.assertIn('Match patterns are not exhaustive', stderr) def test_bad_coverpoint_name(self): stderr = self._run( @@ -1286,6 +1286,21 @@ def test_unsized_array_type(self): stderr, ) + def test_already_exhaustive_match_warning(self): + # Note: this flag is disabled by default, for now. + stderr = self._run( + 'xls/dslx/tests/errors/unnecessary_trailing_match_pattern.x', + want_err_retcode=False, + ) + self.assertNotIn('Match is already exhaustive', stderr) + + stderr = self._run( + 'xls/dslx/tests/errors/unnecessary_trailing_match_pattern.x', + enable_warnings={'already_exhaustive_match'}, + want_err_retcode=True, + ) + self.assertIn('Match is already exhaustive', stderr) + if __name__ == '__main__': test_base.main() diff --git a/xls/dslx/tests/errors/match_not_exhaustive.x b/xls/dslx/tests/errors/match_not_exhaustive.x index 343febd152..1833754eac 100644 --- a/xls/dslx/tests/errors/match_not_exhaustive.x +++ b/xls/dslx/tests/errors/match_not_exhaustive.x @@ -13,8 +13,8 @@ // limitations under the License. fn f(x: u32) -> u32 { - match x { - u32:1 => u32:64, - u32:2 => u32:42 - } + match x { + u32:1 => u32:64, + u32:2 => u32:42 + } } diff --git a/xls/dslx/tests/errors/unnecessary_trailing_match_pattern.x b/xls/dslx/tests/errors/unnecessary_trailing_match_pattern.x new file mode 100644 index 0000000000..4d00a22411 --- /dev/null +++ b/xls/dslx/tests/errors/unnecessary_trailing_match_pattern.x @@ -0,0 +1,21 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +fn f(x: bool) -> u32 { + match x { + true => u32:0, + false => u32:1, + _ => u32:2, + } +} diff --git a/xls/dslx/tests/for_enum_ref.x b/xls/dslx/tests/for_enum_ref.x index 6f1a0e799d..23388b4be2 100644 --- a/xls/dslx/tests/for_enum_ref.x +++ b/xls/dslx/tests/for_enum_ref.x @@ -26,7 +26,6 @@ fn main(x: u2) -> MyEnum { u2:1 => MyEnum::B, u2:2 => MyEnum::C, u2:3 => MyEnum::D, - _ => MyEnum::A, } }(MyEnum::A) } diff --git a/xls/dslx/tests/mod_enum_fully_qualified_match_importer.x b/xls/dslx/tests/mod_enum_fully_qualified_match_importer.x index 7554ce9813..951794a735 100644 --- a/xls/dslx/tests/mod_enum_fully_qualified_match_importer.x +++ b/xls/dslx/tests/mod_enum_fully_qualified_match_importer.x @@ -24,7 +24,6 @@ fn main(x: exporter::EnumType) -> u32 { match x { exporter::EnumType::FIRST => u32:0, exporter::EnumType::SECOND => u32:1, - _ => fail!("no_match", u32:1), } } diff --git a/xls/dslx/tests/mod_enum_use_in_for_match_importer.x b/xls/dslx/tests/mod_enum_use_in_for_match_importer.x index d65028b2fc..4dc8e7c3cf 100644 --- a/xls/dslx/tests/mod_enum_use_in_for_match_importer.x +++ b/xls/dslx/tests/mod_enum_use_in_for_match_importer.x @@ -24,7 +24,6 @@ fn main(x: EnumType) -> bool { match x { EnumType::FIRST => false, EnumType::SECOND => true, - _ => false, } }(false) } diff --git a/xls/dslx/tests/mod_simple_enum_alias_importer.x b/xls/dslx/tests/mod_simple_enum_alias_importer.x index 67121cb37b..33ef9d284a 100644 --- a/xls/dslx/tests/mod_simple_enum_alias_importer.x +++ b/xls/dslx/tests/mod_simple_enum_alias_importer.x @@ -18,7 +18,6 @@ fn main(et: mod_simple_enum::EnumTypeAlias) -> u32 { match et { mod_simple_enum::EnumTypeAlias::FIRST => u32:0, mod_simple_enum::EnumTypeAlias::SECOND => u32:1, - _ => u32:2, } } diff --git a/xls/dslx/tests/mod_simple_enum_importer.x b/xls/dslx/tests/mod_simple_enum_importer.x index d911e4af5d..f077dece30 100644 --- a/xls/dslx/tests/mod_simple_enum_importer.x +++ b/xls/dslx/tests/mod_simple_enum_importer.x @@ -18,7 +18,6 @@ fn main(et: mod_simple_enum::EnumType) -> u32 { match et { mod_simple_enum::EnumType::FIRST => u32:0, mod_simple_enum::EnumType::SECOND => u32:1, - _ => u32:2, } } diff --git a/xls/dslx/tests/quickcheck_fn_with_fail.x b/xls/dslx/tests/quickcheck_fn_with_fail.x index c1039d64db..767abdaadc 100644 --- a/xls/dslx/tests/quickcheck_fn_with_fail.x +++ b/xls/dslx/tests/quickcheck_fn_with_fail.x @@ -12,12 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -fn seemingly_can_fail(x: u32) -> u32 { - match x { - y => y, - _ => fail!("should_be_impossible", x), - } -} +fn seemingly_can_fail(x: u32) -> u32 { if x != x { fail!("should_be_impossible", x) } else { x } } #[quickcheck] fn prop_effectively_identity(x: u32) -> bool { seemingly_can_fail(x) == x } diff --git a/xls/dslx/tests/quickcheck_with_enum.x b/xls/dslx/tests/quickcheck_with_enum.x index 90ed62f4f9..00b101d132 100644 --- a/xls/dslx/tests/quickcheck_with_enum.x +++ b/xls/dslx/tests/quickcheck_with_enum.x @@ -18,12 +18,7 @@ enum MyEnum : u2 { C = 2, } -fn is_valid(x: MyEnum) -> bool { - match x { - MyEnum::A | MyEnum::B | MyEnum::C => true, - _ => false, - } -} +fn is_valid(x: MyEnum) -> bool { x == MyEnum::A || x == MyEnum::B || x == MyEnum::C } #[quickcheck(exhaustive)] fn prop_enum_is_valid(x: MyEnum) -> bool { is_valid(x) } diff --git a/xls/dslx/type_system/BUILD b/xls/dslx/type_system/BUILD index 34c8493c19..4674a55e5a 100644 --- a/xls/dslx/type_system/BUILD +++ b/xls/dslx/type_system/BUILD @@ -156,6 +156,7 @@ cc_library( "//xls/dslx/bytecode", "//xls/dslx/bytecode:bytecode_emitter", "//xls/dslx/bytecode:bytecode_interpreter", + "//xls/dslx/exhaustiveness:match_exhaustiveness_checker", "//xls/dslx/frontend:ast", "//xls/dslx/frontend:ast_cloner", "//xls/dslx/frontend:ast_node", @@ -416,6 +417,7 @@ cc_test( "//xls/dslx/frontend:ast_node_visitor_with_default", "//xls/dslx/frontend:pos", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", diff --git a/xls/dslx/type_system/deduce.cc b/xls/dslx/type_system/deduce.cc index 686331dbfa..322eda56d5 100644 --- a/xls/dslx/type_system/deduce.cc +++ b/xls/dslx/type_system/deduce.cc @@ -49,6 +49,7 @@ #include "xls/dslx/channel_direction.h" #include "xls/dslx/constexpr_evaluator.h" #include "xls/dslx/errors.h" +#include "xls/dslx/exhaustiveness/match_exhaustiveness_checker.h" #include "xls/dslx/frontend/ast.h" #include "xls/dslx/frontend/ast_cloner.h" #include "xls/dslx/frontend/ast_node.h" @@ -1339,17 +1340,85 @@ static std::string PatternsToString(MatchArm* arm) { }); } +static absl::Status ValidateMatchable(const Type& type, const Span& span, + const FileTable& file_table) { + class MatchableTypeVisitor : public TypeVisitor { + public: + MatchableTypeVisitor(const Span& span, const FileTable& file_table) + : span_(span), file_table_(file_table) {} + ~MatchableTypeVisitor() override = default; + absl::Status HandleBits(const BitsType& type) override { + return absl::OkStatus(); + } + absl::Status HandleEnum(const EnumType& type) override { + return absl::OkStatus(); + } + absl::Status HandleTuple(const TupleType& type) override { + for (const auto& member : type.members()) { + XLS_RETURN_IF_ERROR(member->Accept(*this)); + } + return absl::OkStatus(); + } + absl::Status HandleArray(const ArrayType& type) override { + std::optional bits_like = GetBitsLike(type); + if (bits_like.has_value()) { + return absl::OkStatus(); + } + return Error(type); + } + // Note: this should not be directly observable outside of the array type + // element. + absl::Status HandleBitsConstructor( + const BitsConstructorType& type) override { + return Error(type); + } + // -- these types are not matchable + absl::Status HandleMeta(const MetaType& type) override { + return Error(type); + } + absl::Status HandleFunction(const FunctionType& type) override { + return Error(type); + } + absl::Status HandleChannel(const ChannelType& type) override { + return Error(type); + } + absl::Status HandleToken(const TokenType& type) override { + return Error(type); + } + absl::Status HandleStruct(const StructType& type) override { + return Error(type); + } + absl::Status HandleProc(const ProcType& type) override { + return Error(type); + } + + private: + absl::Status Error(const Type& type) { + return TypeInferenceErrorStatus( + span_, &type, "Match construct cannot match on this type.", + file_table_); + }; + + const Span& span_; + const FileTable& file_table_; + }; + MatchableTypeVisitor visitor(span, file_table); + return type.Accept(visitor); +} + absl::StatusOr> DeduceMatch(const Match* node, DeduceCtx* ctx) { VLOG(5) << "DeduceMatch: " << node->ToString(); XLS_ASSIGN_OR_RETURN(std::unique_ptr matched, ctx->Deduce(node->matched())); - if (matched->IsMeta() || matched->IsFunction()) { - return TypeInferenceErrorStatus( - node->span(), matched.get(), - "Match construct cannot match on this type.", ctx->file_table()); - } + + // Validate that we can match on the type presented. + // + // The fact the type is matchable is assumed as a precondition in + // exhaustiveness checking. + XLS_RETURN_IF_ERROR( + ValidateMatchable(*matched, node->span(), ctx->file_table())); if (node->arms().empty()) { return TypeInferenceErrorStatus( @@ -1358,11 +1427,18 @@ absl::StatusOr> DeduceMatch(const Match* node, ctx->file_table()); } + MatchExhaustivenessChecker exhaustiveness_checker( + node->matched()->span(), *ctx->import_data(), *ctx->type_info(), + *matched); + absl::flat_hash_set seen_patterns; for (MatchArm* arm : node->arms()) { // We opportunistically identify syntactically identical match arms -- this // is a user error since the first should always match, the latter is // totally redundant. + // + // TODO(cdleary): 2025-01-31 We can get precise info on overlaps beyond + // identical syntax when the exhaustiveness checker is available. std::string patterns_string = PatternsToString(arm); if (auto [it, inserted] = seen_patterns.insert(patterns_string); !inserted) { @@ -1388,9 +1464,33 @@ absl::StatusOr> DeduceMatch(const Match* node, } XLS_RETURN_IF_ERROR(Unify(pattern, *matched, ctx)); + + bool exhaustive_before = exhaustiveness_checker.IsExhaustive(); + exhaustiveness_checker.AddPattern(*pattern); + if (exhaustive_before) { + ctx->warnings()->Add(pattern->span(), + WarningKind::kAlreadyExhaustiveMatch, + "Match is already exhaustive before this pattern"); + } } } + if (!exhaustiveness_checker.IsExhaustive()) { + std::optional sample = + exhaustiveness_checker.SampleSimplestUncoveredValue(); + XLS_RET_CHECK(sample.has_value()); + return TypeInferenceErrorStatus( + node->span(), matched.get(), + absl::StrFormat( + "Match %s not exhaustive; e.g. `%s` not covered; please add " + "remaining " + "patterns to complete the match or a default case " + "via `_ => ...`", + seen_patterns.size() == 1 ? "pattern is" : "patterns are", + sample->ToString()), + ctx->file_table()); + } + std::vector> arm_types; for (MatchArm* arm : node->arms()) { XLS_ASSIGN_OR_RETURN(std::unique_ptr arm_type, diff --git a/xls/dslx/type_system/deduce_colon_ref.cc b/xls/dslx/type_system/deduce_colon_ref.cc index 1ee18e54ca..8a896e1146 100644 --- a/xls/dslx/type_system/deduce_colon_ref.cc +++ b/xls/dslx/type_system/deduce_colon_ref.cc @@ -322,9 +322,23 @@ absl::StatusOr> DeduceColonRef(const ColonRef* node, ctx->file_table()); } XLS_ASSIGN_OR_RETURN( - auto enum_type, DeduceEnumDef(enum_def, subject_ctx.get())); - return UnwrapMetaType(std::move(enum_type), node->span(), - "enum type", ctx->file_table()); + std::unique_ptr enum_type, + DeduceEnumDef(enum_def, subject_ctx.get())); + XLS_ASSIGN_OR_RETURN( + enum_type, + UnwrapMetaType(std::move(enum_type), node->span(), + "enum type", ctx->file_table())); + + // We also want to note the ColonRef's constexpr value as we + // resolved its enum definition. + XLS_ASSIGN_OR_RETURN(Expr * enum_value_expr, + enum_def->GetValue(node->attr())); + XLS_ASSIGN_OR_RETURN( + InterpValue enum_value, + subject_ctx->type_info()->GetConstExpr(enum_value_expr)); + ctx->type_info()->NoteConstExpr(node, enum_value); + + return enum_type; }, [&](BuiltinNameDef* builtin_name_def) -> ReturnT { return DeduceColonRefToBuiltinNameDef(builtin_name_def, node); diff --git a/xls/dslx/type_system/type.h b/xls/dslx/type_system/type.h index b0535afaf4..b2712f0fa4 100644 --- a/xls/dslx/type_system/type.h +++ b/xls/dslx/type_system/type.h @@ -1139,6 +1139,13 @@ inline bool operator==(const BitsLikeProperties& a, return a.is_signed == b.is_signed && a.size == b.size; } +inline std::ostream& operator<<(std::ostream& os, + const BitsLikeProperties& properties) { + return os << absl::StreamFormat("BitsLikeProperties{is_signed: %s, size: %s}", + properties.is_signed.ToString(), + properties.size.ToString()); +} + // Returns ths "bits-like properties" for a given type `t` -- in practice this // means that the type can either be a true `BitsType` or an instantiated // `BitsConstructorType` -- from both of these forms we can retrieve information diff --git a/xls/dslx/type_system/typecheck_module_test.cc b/xls/dslx/type_system/typecheck_module_test.cc index 7793272937..990f53bfce 100644 --- a/xls/dslx/type_system/typecheck_module_test.cc +++ b/xls/dslx/type_system/typecheck_module_test.cc @@ -22,6 +22,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" @@ -2172,6 +2173,31 @@ const X: u32 = p(false); HasSubstr("pattern expects uN[1]"))); } +TEST(TypecheckErrorTest, MatchNonExhaustive) { + absl::StatusOr result = Typecheck(R"( +fn f(x: u32) -> u32 { + match x { + u32:1 => u32:64, + u32:2 => u32:42, + } +} +)"); + EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Match patterns are not exhaustive"))); +} + +TEST(TypecheckErrorTest, MatchWithOneNonExhaustivePattern) { + absl::StatusOr result = Typecheck(R"( +fn f(x: u32) -> u32 { + match x { + u32:1 => u32:64, + } +} +)"); + EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Match pattern is not exhaustive"))); +} + TEST(TypecheckErrorTest, ArrayInconsistency) { EXPECT_THAT(Typecheck(R"( type Foo = (u8, u32); diff --git a/xls/dslx/warning_kind.cc b/xls/dslx/warning_kind.cc index cadd79c401..b741d05205 100644 --- a/xls/dslx/warning_kind.cc +++ b/xls/dslx/warning_kind.cc @@ -52,6 +52,8 @@ absl::StatusOr WarningKindToString(WarningKind kind) { return "member_naming"; case WarningKind::kShouldUseAssert: return "should_use_assert"; + case WarningKind::kAlreadyExhaustiveMatch: + return "already_exhaustive_match"; } return absl::InvalidArgumentError( absl::StrCat("Invalid warning kind: ", static_cast(kind))); @@ -99,4 +101,29 @@ absl::StatusOr WarningKindSetFromString( return enabled; } +std::string WarningKindSetToString(WarningKindSet set) { + std::vector enabled_warnings; + for (WarningKind kind : kAllWarningKinds) { + if (WarningIsEnabled(set, kind)) { + enabled_warnings.push_back(WarningKindToString(kind).value()); + } + } + return absl::StrJoin(enabled_warnings, ","); +} + +absl::StatusOr GetWarningsSetFromFlags( + std::string_view enable_warnings, std::string_view disable_warnings) { + XLS_ASSIGN_OR_RETURN(WarningKindSet enabled, + WarningKindSetFromString(enable_warnings)); + XLS_ASSIGN_OR_RETURN(WarningKindSet disabled, + WarningKindSetFromString(disable_warnings)); + if ((enabled & disabled) != kNoWarningsSet) { + return absl::InvalidArgumentError(absl::StrFormat( + "Cannot both enable and disable the same warning(s); enabled: %s " + "disabled: %s", + WarningKindSetToString(enabled), WarningKindSetToString(disabled))); + } + return (kDefaultWarningsSet | enabled) & Complement(disabled); +} + } // namespace xls::dslx diff --git a/xls/dslx/warning_kind.h b/xls/dslx/warning_kind.h index 917a4a6e17..4170bd7b11 100644 --- a/xls/dslx/warning_kind.h +++ b/xls/dslx/warning_kind.h @@ -42,8 +42,9 @@ enum class WarningKind : WarningKindInt { kConstantNaming = 1 << 9, kMemberNaming = 1 << 10, kShouldUseAssert = 1 << 11, + kAlreadyExhaustiveMatch = 1 << 12, }; -constexpr WarningKindInt kWarningKindCount = 12; +constexpr WarningKindInt kWarningKindCount = 13; inline constexpr std::array kAllWarningKinds = { WarningKind::kConstexprEvalRollover, @@ -58,6 +59,7 @@ inline constexpr std::array kAllWarningKinds = { WarningKind::kConstantNaming, WarningKind::kMemberNaming, WarningKind::kShouldUseAssert, + WarningKind::kAlreadyExhaustiveMatch, }; // Flag set datatype. @@ -70,6 +72,24 @@ inline constexpr WarningKindSet kNoWarningsSet = WarningKindSet{0}; inline constexpr WarningKindSet kAllWarningsSet = WarningKindSet{(WarningKindInt{1} << kWarningKindCount) - 1}; +// Set intersection. +inline WarningKindSet operator&(WarningKindSet a, WarningKindSet b) { + return WarningKindSet{a.value() & b.value()}; +} + +// Set union. +inline WarningKindSet operator|(WarningKindSet a, WarningKindSet b) { + return WarningKindSet{a.value() | b.value()}; +} + +// Returns the complement of a warning set. +// +// Note that we define this instead of operator~ because it has an existing +// overload for STRONG_INT_TYPE. +inline WarningKindSet Complement(WarningKindSet a) { + return WarningKindSet{~a.value() & kAllWarningsSet.value()}; +} + // Disables "warning" out of "set" and returns that updated result. constexpr WarningKindSet DisableWarning(WarningKindSet set, WarningKind warning) { @@ -88,8 +108,11 @@ inline bool WarningIsEnabled(WarningKindSet set, WarningKind warning) { // TODO(leary): 2024-03-15 Enable "should use fail if" by default after some // propagation time. -inline constexpr WarningKindSet kDefaultWarningsSet = - DisableWarning(kAllWarningsSet, WarningKind::kShouldUseAssert); +// TODO(cdleary): 2025-02-03 Enable "already exhaustive match" by default after +// some propagation time. +inline constexpr WarningKindSet kDefaultWarningsSet = DisableWarning( + DisableWarning(kAllWarningsSet, WarningKind::kShouldUseAssert), + WarningKind::kAlreadyExhaustiveMatch); // Converts a string representation of a warnings to its corresponding enum // value. @@ -107,6 +130,14 @@ absl::StatusOr WarningKindSetFromDisabledString( absl::StatusOr WarningKindSetFromString( std::string_view enabled_string); +std::string WarningKindSetToString(WarningKindSet set); + +// Returns the default warning set with the modifications given in flags. +// +// If flags are contradictory, returns an error. +absl::StatusOr GetWarningsSetFromFlags( + std::string_view enable_warnings, std::string_view disable_warnings); + } // namespace xls::dslx #endif // XLS_DSLX_WARNING_KIND_H_ diff --git a/xls/dslx/warning_kind_test.cc b/xls/dslx/warning_kind_test.cc index 9a47a7ef51..363f168ae4 100644 --- a/xls/dslx/warning_kind_test.cc +++ b/xls/dslx/warning_kind_test.cc @@ -17,6 +17,7 @@ #include #include "gtest/gtest.h" +#include "absl/status/status_matchers.h" #include "xls/common/status/matchers.h" namespace xls::dslx { @@ -56,11 +57,59 @@ TEST(WarningKindTest, DefaultSetAnyMissing) { WarningIsEnabled(kDefaultWarningsSet, WarningKind::kShouldUseAssert)); } +TEST(WarningKindTest, Complement) { + EXPECT_EQ(Complement(kAllWarningsSet), kNoWarningsSet); + EXPECT_EQ(Complement(kNoWarningsSet), kAllWarningsSet); +} + +TEST(WarningKindTest, SetIntersection) { + EXPECT_EQ(kAllWarningsSet & kAllWarningsSet, kAllWarningsSet); + EXPECT_EQ(kAllWarningsSet & kNoWarningsSet, kNoWarningsSet); + EXPECT_EQ(kNoWarningsSet & kAllWarningsSet, kNoWarningsSet); + EXPECT_EQ(kNoWarningsSet & kNoWarningsSet, kNoWarningsSet); + EXPECT_EQ(DisableWarning(kAllWarningsSet, WarningKind::kShouldUseAssert) & + EnableWarning(kNoWarningsSet, WarningKind::kShouldUseAssert), + kNoWarningsSet); +} + +TEST(WarningKindTest, SetUnion) { + EXPECT_EQ(kAllWarningsSet | kAllWarningsSet, kAllWarningsSet); + EXPECT_EQ(kAllWarningsSet | kNoWarningsSet, kAllWarningsSet); + EXPECT_EQ(kNoWarningsSet | kAllWarningsSet, kAllWarningsSet); + EXPECT_EQ(kNoWarningsSet | kNoWarningsSet, kNoWarningsSet); + EXPECT_EQ(DisableWarning(kAllWarningsSet, WarningKind::kShouldUseAssert) | + EnableWarning(kNoWarningsSet, WarningKind::kShouldUseAssert), + kAllWarningsSet); +} + TEST(WarningKindTest, WarningKindSetFromString) { XLS_ASSERT_OK_AND_ASSIGN(WarningKindSet set, WarningKindSetFromString("should_use_assert")); ASSERT_TRUE(WarningIsEnabled(set, WarningKind::kShouldUseAssert)); } +TEST(WarningKindTest, GetWarningsSetFromFlagsEmpty) { + XLS_ASSERT_OK_AND_ASSIGN(WarningKindSet set, GetWarningsSetFromFlags("", "")); + EXPECT_EQ(set, kDefaultWarningsSet); +} + +TEST(WarningKindTest, GetWarningsSetFromFlagsEmptyEnable) { + XLS_ASSERT_OK_AND_ASSIGN(WarningKindSet set, + GetWarningsSetFromFlags("", "constant_naming")); + EXPECT_EQ(set, + DisableWarning(kDefaultWarningsSet, WarningKind::kConstantNaming)); +} + +TEST(WarningKindTest, GetWarningsSetFromFlagsContradiction) { + absl::StatusOr set = + GetWarningsSetFromFlags("constant_naming", "constant_naming"); + EXPECT_THAT(set.status(), + absl_testing::StatusIs( + absl::StatusCode::kInvalidArgument, + testing::HasSubstr( + "Cannot both enable and disable the same warning(s); " + "enabled: constant_naming disabled: constant_naming"))); +} + } // namespace } // namespace xls::dslx diff --git a/xls/examples/dslx_module/some_caps_streaming.x b/xls/examples/dslx_module/some_caps_streaming.x index 6f1849d831..1810327d66 100644 --- a/xls/examples/dslx_module/some_caps_streaming.x +++ b/xls/examples/dslx_module/some_caps_streaming.x @@ -29,7 +29,6 @@ pub proc some_caps_streaming { some_caps::Choice::CAPITALIZE => some_caps::Choice::NOTHING, some_caps::Choice::NOTHING => some_caps::Choice::SPONGE, some_caps::Choice::SPONGE => some_caps::Choice::CAPITALIZE, - _ => state, }; let tok = send(tok, bytes_result, some_caps::maybe_capitalize(val, state)); ns diff --git a/xls/examples/gcd.x b/xls/examples/gcd.x index 3b47257d03..50d6a21eb7 100644 --- a/xls/examples/gcd.x +++ b/xls/examples/gcd.x @@ -32,7 +32,6 @@ fn gcd_binary_match(a: uN[N], b: uN[N], d: uN[N]) -> (uN[N], uN[N], uN[N (u1:1, u1:0) => (a, b >> 1, d), (u1:0, u1:0) => (a >> 1, b >> 1, d+uN[N]:1), (u1:1, u1:1) => ((a - b) >> 1, b, d), - _ => fail!("unsupported_case", (uN[N]:0, uN[N]:0, uN[N]:0)), } } diff --git a/xls/examples/hack_cpu.x b/xls/examples/hack_cpu.x index a44ffbe034..c5749935cb 100644 --- a/xls/examples/hack_cpu.x +++ b/xls/examples/hack_cpu.x @@ -258,7 +258,6 @@ fn cpu(pc: u16, rd: u16, ra: u16, ram: u16[32], rom: u16[32]) -> (u16, u16, u16, let (pc', rd', ra', rm', wm) = match ins[15+:u1] { u1:0 => run_a_instruction(pc, ins, rd, ra, rm), u1:1 => run_c_instruction(pc, ins, rd, ra, rm), - _ => fail!("exhaustive_boolean_match", (pc, rd, ra, rm, u1:0)), }; (pc', rd', ra', if wm { update(ram, ra', rm') } else { ram }) } diff --git a/xls/examples/ram.x b/xls/examples/ram.x index 6b283ff747..96609f1731 100644 --- a/xls/examples/ram.x +++ b/xls/examples/ram.x @@ -266,7 +266,6 @@ pub proc RamModel mem[read_req.addr], SimultaneousReadWriteBehavior::WRITE_BEFORE_READ => value_to_write, SimultaneousReadWriteBehavior::ASSERT_NO_CONFLICT => fail!("conflicting_read_and_write", mem[read_req.addr]), - _ => fail!("impossible_case", uN[DATA_WIDTH]:0), } } else { mem[read_req.addr] }; let read_resp_value = ReadResp { diff --git a/xls/fuzzer/ast_generator.cc b/xls/fuzzer/ast_generator.cc index dda68f00c7..9b4e79d2bd 100644 --- a/xls/fuzzer/ast_generator.cc +++ b/xls/fuzzer/ast_generator.cc @@ -649,24 +649,25 @@ absl::StatusOr AstGenerator::GenerateCompareArray(Context* ctx) { .min_stage = std::max(lhs.min_stage, rhs.min_stage)}; } -class FindTokenTypeVisitor : public AstNodeVisitorWithDefault { +class FindTypeVisitor : public AstNodeVisitorWithDefault { public: - FindTokenTypeVisitor() = default; + explicit FindTypeVisitor( + std::function is_target_type) + : is_target_type_(std::move(is_target_type)) {} - bool GetTokenFound() const { return token_found_; } + bool found() const { return found_; } absl::Status HandleBuiltinTypeAnnotation( const BuiltinTypeAnnotation* builtin_type) override { - if (!token_found_) { - token_found_ = builtin_type->builtin_type() == BuiltinType::kToken; - } + found_ = found_ || is_target_type_(builtin_type); return absl::OkStatus(); } absl::Status HandleTupleTypeAnnotation( const TupleTypeAnnotation* tuple_type) override { + found_ = found_ || is_target_type_(tuple_type); for (TypeAnnotation* member_type : tuple_type->members()) { - if (token_found_) { + if (found_) { break; } XLS_RETURN_IF_ERROR(member_type->Accept(this)); @@ -676,11 +677,13 @@ class FindTokenTypeVisitor : public AstNodeVisitorWithDefault { absl::Status HandleArrayTypeAnnotation( const ArrayTypeAnnotation* array_type) override { + found_ = found_ || is_target_type_(array_type); return array_type->element_type()->Accept(this); } absl::Status HandleTypeRefTypeAnnotation( const TypeRefTypeAnnotation* type_ref_type) override { + found_ = found_ || is_target_type_(type_ref_type); return type_ref_type->type_ref()->Accept(this); } @@ -720,6 +723,7 @@ class FindTokenTypeVisitor : public AstNodeVisitorWithDefault { absl::Status HandleTypeVariableTypeAnnotation( const TypeVariableTypeAnnotation* type_variable_type) override { + found_ = found_ || is_target_type_(type_variable_type); return type_variable_type->type_variable()->Accept(this); } @@ -764,7 +768,7 @@ class FindTokenTypeVisitor : public AstNodeVisitorWithDefault { private: absl::Status HandleStructDefBaseInternal(const StructDefBase* struct_def) { for (const StructMemberNode* member : struct_def->members()) { - if (token_found_) { + if (found_) { break; } XLS_RETURN_IF_ERROR(member->type()->Accept(this)); @@ -772,36 +776,32 @@ class FindTokenTypeVisitor : public AstNodeVisitorWithDefault { return absl::OkStatus(); } - bool token_found_ = false; + const std::function is_target_type_; + bool found_ = false; }; /* static */ absl::StatusOr AstGenerator::ContainsToken( const TypeAnnotation* type) { - FindTokenTypeVisitor token_visitor; - XLS_RETURN_IF_ERROR(type->Accept(&token_visitor)); - return token_visitor.GetTokenFound(); + FindTypeVisitor find_type_visitor( + [](const TypeAnnotation* type) { return IsToken(type); }); + XLS_RETURN_IF_ERROR(type->Accept(&find_type_visitor)); + return find_type_visitor.found(); } -/* static */ bool AstGenerator::ContainsTypeRef(const TypeAnnotation* type) { - if (IsTypeRef(type)) { - return true; - } - if (auto tuple_type = dynamic_cast(type)) { - for (TypeAnnotation* member_type : tuple_type->members()) { - if (ContainsTypeRef(member_type)) { - return true; - } - } - return false; - } - if (auto array_type = dynamic_cast(type)) { - return ContainsTypeRef(array_type->element_type()); - } - if (auto channel_type = dynamic_cast(type)) { - return ContainsTypeRef(channel_type->payload()); - } - CHECK_NE(dynamic_cast(type), nullptr); - return false; +/* static */ absl::StatusOr AstGenerator::ContainsArray( + const TypeAnnotation* type) { + FindTypeVisitor find_type_visitor( + [](const TypeAnnotation* type) { return IsArray(type); }); + XLS_RETURN_IF_ERROR(type->Accept(&find_type_visitor)); + return find_type_visitor.found(); +} + +/* static */ absl::StatusOr AstGenerator::ContainsTypeRef( + const TypeAnnotation* type) { + FindTypeVisitor find_type_visitor( + [](const TypeAnnotation* type) { return IsTypeRef(type); }); + XLS_RETURN_IF_ERROR(type->Accept(&find_type_visitor)); + return find_type_visitor.found(); } absl::StatusOr AstGenerator::ChooseEnvValueTupleWithoutToken( @@ -911,6 +911,11 @@ absl::StatusOr AstGenerator::GenerateExprOfType( absl::StatusOr AstGenerator::GenerateMatchArmPattern( Context* ctx, const TypeAnnotation* type) { + XLS_RET_CHECK(!IsTypeRef(type)) << "Matched-on TypeRefs-typed values are not " + "supported by the fuzzer; got: " + << type->ToString(); + XLS_RET_CHECK(!IsArray(type)) + << "Matched-on type cannot be an array; got: " << type->ToString(); if (IsTuple(type)) { auto tuple_type = dynamic_cast(type); // Ten percent of the time, generate a wildcard pattern. @@ -952,33 +957,9 @@ absl::StatusOr AstGenerator::GenerateMatchArmPattern( return module_->Make(fake_span_, tuple_values); } - if (IsArray(type)) { - // For the array type, only name references are supported in the match arm. - // Reference: https://github.com/google/xls/issues/810. - auto array_matches = [&type](const TypedExpr& e) -> bool { - return !ContainsTypeRef(e.type) && e.type->ToString() == type->ToString(); - }; - std::vector array_candidates = - GatherAllValues(&ctx->env, array_matches); - // Twenty percent of the time, generate a wildcard pattern. - if (array_candidates.empty() || RandomBool(0.20)) { - WildcardPattern* wc = module_->Make(fake_span_); - return module_->Make(fake_span_, wc); - } - TypedExpr array = - RandomChoice(absl::MakeConstSpan(array_candidates), bit_gen_); - NameRef* name_ref = dynamic_cast(array.expr); - return module_->Make(fake_span_, name_ref); - } - if (auto* type_ref_type = dynamic_cast(type)) { - TypeRef* type_ref = type_ref_type->type_ref(); - const TypeDefinition& type_definition = type_ref->type_definition(); - CHECK(std::holds_alternative(type_definition)); - TypeAlias* alias = std::get(type_definition); - return GenerateMatchArmPattern(ctx, &alias->type_annotation()); - } - - CHECK(IsBits(type)); + CHECK(IsBits(type)) + << "Expected match arm pattern to be a tuple or bits type; got: " + << type->ToString() << " kind: " << type->GetNodeTypeName(); // Five percent of the time, generate a wildcard pattern. if (RandomBool(0.05)) { @@ -1036,8 +1017,15 @@ absl::StatusOr AstGenerator::GenerateMatchArmPattern( } absl::StatusOr AstGenerator::GenerateMatch(Context* ctx) { - XLS_ASSIGN_OR_RETURN(TypedExpr match, - ChooseEnvValueNotContainingToken(&ctx->env)); + XLS_ASSIGN_OR_RETURN( + TypedExpr match, + ChooseEnvValue(&ctx->env, [](const TypedExpr& te) -> bool { + // TODO(cdleary): 2025-02-03 We should be able to support TypeRef if we + // were able to dereference it to an originating `TypeAnnotation`. + return !ContainsToken(te.type).value() && + !ContainsArray(te.type).value() && + !ContainsTypeRef(te.type).value(); + })); LastDelayingOp last_delaying_op = match.last_delaying_op; int64_t min_stage = match.min_stage; TypeAnnotation* match_return_type = GenerateType(); diff --git a/xls/fuzzer/ast_generator.h b/xls/fuzzer/ast_generator.h index 13268594ce..c8dfe1da66 100644 --- a/xls/fuzzer/ast_generator.h +++ b/xls/fuzzer/ast_generator.h @@ -226,8 +226,16 @@ class AstGenerator { static bool EnvContainsTuple(const Env& e); static bool EnvContainsToken(const Env& e); static bool EnvContainsChannel(const Env& e); + + // Helper that tests whether a given type annotation contains a token type + // (transitively anywhere in the structure of the type). static absl::StatusOr ContainsToken(const TypeAnnotation* type); - static bool ContainsTypeRef(const TypeAnnotation* type); + + // As above but for types containing arrays. + static absl::StatusOr ContainsArray(const TypeAnnotation* type); + + // As above but for types containing TypeRef type annotations. + static absl::StatusOr ContainsTypeRef(const TypeAnnotation* type); // Generates a function with name "name", returning the minimum number of // stages the function can be scheduled in. diff --git a/xls/modules/zstd/dec_demux.x b/xls/modules/zstd/dec_demux.x index 5bcd380f91..1ff0c5dba7 100644 --- a/xls/modules/zstd/dec_demux.x +++ b/xls/modules/zstd/dec_demux.x @@ -141,7 +141,6 @@ pub proc DecoderDemux { }; (false, false, true, new_state) }, - _ => fail!("IDLE_STATE_IMPOSSIBLE", (false, false, false, state)) }; let end_state = if (send_raw || send_rle || send_cmp) { diff --git a/xls/modules/zstd/frame_header.x b/xls/modules/zstd/frame_header.x index 858d64ac53..8159f1e47e 100644 --- a/xls/modules/zstd/frame_header.x +++ b/xls/modules/zstd/frame_header.x @@ -267,7 +267,6 @@ fn parse_dictionary_id(buffer: Buffer, desc: FrameHeade u2:1 => u32:1, u2:2 => u32:2, u2:3 => u32:4, - _ => fail!("not_possible", u32:0) }; let (result, data) = buff::buffer_pop_checked(buffer, bytes * u32:8); @@ -362,7 +361,6 @@ fn parse_frame_content_size(buffer: Buffer, desc: Frame u2:1 => u32:2, u2:2 => u32:4, u2:3 => u32:8, - _ => fail!("not_possible", u32:0) }; let (result, data) = buff::buffer_pop_checked(buffer, bytes * u32:8); diff --git a/xls/modules/zstd/memory/axi_stream_add_empty.x b/xls/modules/zstd/memory/axi_stream_add_empty.x index 7a94d9c1e1..73dabedbee 100644 --- a/xls/modules/zstd/memory/axi_stream_add_empty.x +++ b/xls/modules/zstd/memory/axi_stream_add_empty.x @@ -203,13 +203,6 @@ pub proc AxiStreamAddEmpty< Fsm::ERROR => { state }, - _ => { - assert!(false, "Invalid state"); - State { - fsm: Fsm::ERROR, - ..state - } - } }; next_state diff --git a/xls/modules/zstd/memory/axi_writer.x b/xls/modules/zstd/memory/axi_writer.x index 2f62307731..58ed5dfbf3 100644 --- a/xls/modules/zstd/memory/axi_writer.x +++ b/xls/modules/zstd/memory/axi_writer.x @@ -267,13 +267,6 @@ pub proc AxiWriter< ..state } }, - _ => { - assert!(false, "Invalid state"); - State { - fsm: Fsm::ERROR, - ..state - } - } }; let w_bundle = match(state.fsm) { diff --git a/xls/modules/zstd/memory/mem_reader.x b/xls/modules/zstd/memory/mem_reader.x index ea96264728..c0d6c5709d 100644 --- a/xls/modules/zstd/memory/mem_reader.x +++ b/xls/modules/zstd/memory/mem_reader.x @@ -201,10 +201,6 @@ proc MemReaderInternal< State { fsm: Fsm::REQUEST, ..state } } }, - _ => { - fail!("invalid_state", false); - state - }, }; next_state diff --git a/xls/modules/zstd/zstd_dec.x b/xls/modules/zstd/zstd_dec.x index 0f9fac906e..d3dea6ae59 100644 --- a/xls/modules/zstd/zstd_dec.x +++ b/xls/modules/zstd/zstd_dec.x @@ -85,7 +85,6 @@ fn decode_magic_number(state: ZstdDecoderState) -> (bool, BlockDataPacket, ZstdD ..ZERO_DECODER_STATE }, magic::MagicStatus::NO_ENOUGH_DATA => state, - _ => state, }; trace_fmt!("zstd_dec: decode_magic_number: new_state: {:#x}", new_state); @@ -113,7 +112,6 @@ fn decode_frame_header(state: ZstdDecoderState) -> (bool, BlockDataPacket, ZstdD status: ZstdDecoderStatus::ERROR, ..ZERO_DECODER_STATE }, - _ => state, }; trace_fmt!("zstd_dec: decode_frame_header: new_state: {:#x}", new_state); @@ -163,7 +161,6 @@ fn decode_block_header(state: ZstdDecoderState) -> (bool, BlockDataPacket, ZstdD ..ZERO_DECODER_STATE }, block_header::BlockHeaderStatus::NO_ENOUGH_DATA => state, - _ => state, }; trace_fmt!("zstd_dec: decode_block_header: new_state: {:#x}", new_state);