|
| 1 | +// Copyright 2024 The XLS Authors |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include "xls/codegen/fold_vast_constants.h" |
| 16 | + |
| 17 | +#include <cstdint> |
| 18 | +#include <memory> |
| 19 | +#include <optional> |
| 20 | +#include <utility> |
| 21 | +#include <vector> |
| 22 | + |
| 23 | +#include "absl/container/flat_hash_map.h" |
| 24 | +#include "absl/numeric/bits.h" |
| 25 | +#include "absl/status/status.h" |
| 26 | +#include "absl/status/statusor.h" |
| 27 | +#include "absl/strings/str_cat.h" |
| 28 | +#include "absl/types/span.h" |
| 29 | +#include "xls/codegen/vast.h" |
| 30 | +#include "xls/common/status/status_macros.h" |
| 31 | +#include "xls/ir/bits.h" |
| 32 | +#include "xls/ir/format_preference.h" |
| 33 | + |
| 34 | +namespace xls { |
| 35 | +namespace verilog { |
| 36 | +namespace { |
| 37 | + |
| 38 | +// Based upon the util version. |
| 39 | +uint64_t Log2Ceiling(uint64_t n) { |
| 40 | + int floor = (absl::bit_width(n) - 1); |
| 41 | + return floor + (((n & (n - 1)) == 0) ? 0 : 1); |
| 42 | +} |
| 43 | + |
| 44 | +// Helper class used internally by the exposed API to bind a type map and avoid |
| 45 | +// passing it around. |
| 46 | +class ConstantFoldingContext { |
| 47 | + public: |
| 48 | + explicit ConstantFoldingContext( |
| 49 | + const absl::flat_hash_map<Expression*, DataType*>& type_map) |
| 50 | + : type_map_(type_map) {} |
| 51 | + |
| 52 | + absl::StatusOr<Expression*> FoldConstants(Expression* expr) { |
| 53 | + if (expr->IsLiteral()) { |
| 54 | + return expr; |
| 55 | + } |
| 56 | + if (auto* ref = dynamic_cast<ParameterRef*>(expr); |
| 57 | + ref && ref->parameter()->rhs()) { |
| 58 | + return FoldConstants(ref->parameter()->rhs()); |
| 59 | + } |
| 60 | + if (auto* ref = dynamic_cast<EnumMemberRef*>(expr); |
| 61 | + ref && ref->member()->rhs()) { |
| 62 | + return FoldConstants(ref->member()->rhs()); |
| 63 | + } |
| 64 | + if (auto* op = dynamic_cast<BinaryInfix*>(expr); op) { |
| 65 | + return FoldBinaryOp(op); |
| 66 | + } |
| 67 | + if (auto* ternary = dynamic_cast<Ternary*>(expr); ternary) { |
| 68 | + absl::StatusOr<int64_t> test_value = FoldEntireExpr(ternary->test()); |
| 69 | + absl::StatusOr<int64_t> consequent_value = |
| 70 | + FoldEntireExpr(ternary->consequent()); |
| 71 | + absl::StatusOr<int64_t> alternate_value = |
| 72 | + FoldEntireExpr(ternary->alternate()); |
| 73 | + if (test_value.ok() && consequent_value.ok() && alternate_value.ok()) { |
| 74 | + return *test_value ? MakeFoldedConstant(ternary, *consequent_value) |
| 75 | + : MakeFoldedConstant(ternary, *alternate_value); |
| 76 | + } |
| 77 | + } |
| 78 | + if (auto* call = dynamic_cast<SystemFunctionCall*>(expr); call) { |
| 79 | + if (call->name() == "clog2" && call->args().has_value() && |
| 80 | + (*call->args()).size() == 1) { |
| 81 | + absl::StatusOr<int64_t> arg_value = FoldEntireExpr((*call->args())[0]); |
| 82 | + if (arg_value.ok()) { |
| 83 | + return MakeFoldedConstant(call, Log2Ceiling(*arg_value)); |
| 84 | + } |
| 85 | + } |
| 86 | + } |
| 87 | + return expr; |
| 88 | + } |
| 89 | + |
| 90 | + absl::StatusOr<int64_t> FoldEntireExpr(Expression* expr) { |
| 91 | + XLS_ASSIGN_OR_RETURN(Expression * folded, FoldConstants(expr)); |
| 92 | + if (folded->IsLiteral()) { |
| 93 | + return folded->AsLiteralOrDie()->ToInt64(); |
| 94 | + } |
| 95 | + return absl::InvalidArgumentError( |
| 96 | + absl::StrCat("Expression does not entirely fold to a constant: ", |
| 97 | + expr->Emit(nullptr))); |
| 98 | + } |
| 99 | + |
| 100 | + absl::StatusOr<Literal*> MakeFoldedConstant(Expression* original, |
| 101 | + int64_t value) { |
| 102 | + int64_t effective_size = 32; |
| 103 | + bool is_signed = true; |
| 104 | + const auto it = type_map_.find(original); |
| 105 | + if (it != type_map_.end()) { |
| 106 | + XLS_ASSIGN_OR_RETURN(effective_size, it->second->FlatBitCountAsInt64()); |
| 107 | + is_signed = it->second->is_signed(); |
| 108 | + } |
| 109 | + return original->file()->Make<Literal>( |
| 110 | + original->loc(), |
| 111 | + is_signed ? SBits(value, effective_size) |
| 112 | + : UBits(static_cast<uint64_t>(value), effective_size), |
| 113 | + value < 0 ? FormatPreference::kHex : FormatPreference::kUnsignedDecimal, |
| 114 | + /*effective_bit_count=*/effective_size, |
| 115 | + /*emit_bit_count=*/effective_size != 32, |
| 116 | + /*declared_as_signed=*/is_signed); |
| 117 | + } |
| 118 | + |
| 119 | + absl::StatusOr<DataType*> FoldConstants(DataType* data_type) { |
| 120 | + if (data_type->FlatBitCountAsInt64().ok()) { |
| 121 | + return data_type; |
| 122 | + } |
| 123 | + if (auto* bit_vector_type = dynamic_cast<BitVectorType*>(data_type); |
| 124 | + bit_vector_type && !bit_vector_type->size_expr()->IsLiteral()) { |
| 125 | + XLS_ASSIGN_OR_RETURN(int64_t folded_size, |
| 126 | + FoldEntireExpr(bit_vector_type->size_expr())); |
| 127 | + return data_type->file()->Make<BitVectorType>( |
| 128 | + data_type->loc(), |
| 129 | + data_type->file()->PlainLiteral(static_cast<int32_t>(folded_size), |
| 130 | + data_type->loc()), |
| 131 | + data_type->is_signed(), |
| 132 | + /*size_expr_is_max=*/bit_vector_type->max().has_value()); |
| 133 | + } |
| 134 | + if (auto* array_type = dynamic_cast<UnpackedArrayType*>(data_type); |
| 135 | + array_type) { |
| 136 | + XLS_ASSIGN_OR_RETURN(std::vector<int64_t> folded_dims, |
| 137 | + FoldDims(array_type->dims())); |
| 138 | + XLS_ASSIGN_OR_RETURN(DataType * folded_element_type, |
| 139 | + FoldConstants(array_type->element_type())); |
| 140 | + return array_type->file()->template Make<UnpackedArrayType>( |
| 141 | + data_type->loc(), folded_element_type, folded_dims); |
| 142 | + } |
| 143 | + if (auto* array_type = dynamic_cast<PackedArrayType*>(data_type); |
| 144 | + array_type) { |
| 145 | + XLS_ASSIGN_OR_RETURN(std::vector<int64_t> folded_dims, |
| 146 | + FoldDims(array_type->dims())); |
| 147 | + XLS_ASSIGN_OR_RETURN(DataType * folded_element_type, |
| 148 | + FoldConstants(array_type->element_type())); |
| 149 | + return array_type->file()->template Make<PackedArrayType>( |
| 150 | + data_type->loc(), folded_element_type, folded_dims, |
| 151 | + /*dims_are_max=*/array_type->dims_are_max()); |
| 152 | + } |
| 153 | + if (auto* enum_def = dynamic_cast<Enum*>(data_type); enum_def) { |
| 154 | + XLS_ASSIGN_OR_RETURN(DataType * folded_base_type, |
| 155 | + FoldConstants(enum_def->BaseType())); |
| 156 | + return data_type->file()->Make<Enum>(enum_def->loc(), enum_def->kind(), |
| 157 | + folded_base_type, |
| 158 | + enum_def->members()); |
| 159 | + } |
| 160 | + if (auto* struct_def = dynamic_cast<Struct*>(data_type); struct_def) { |
| 161 | + XLS_ASSIGN_OR_RETURN(std::vector<Def*> folded_member_defs, |
| 162 | + FoldTypesOfDefs(struct_def->members())); |
| 163 | + return struct_def->file()->Make<Struct>(struct_def->loc(), |
| 164 | + folded_member_defs); |
| 165 | + } |
| 166 | + if (auto* type_def = dynamic_cast<TypedefType*>(data_type); type_def) { |
| 167 | + return FoldConstants(type_def->BaseType()); |
| 168 | + } |
| 169 | + return absl::InternalError(absl::StrCat("Could not constant-fold type: ", |
| 170 | + data_type->Emit(nullptr))); |
| 171 | + } |
| 172 | + |
| 173 | + private: |
| 174 | + absl::StatusOr<std::vector<int64_t>> FoldDims( |
| 175 | + absl::Span<Expression* const> dims) { |
| 176 | + std::vector<int64_t> result(dims.size()); |
| 177 | + result.reserve(dims.size()); |
| 178 | + int i = 0; |
| 179 | + for (Expression* dim : dims) { |
| 180 | + XLS_ASSIGN_OR_RETURN(result[i++], FoldEntireExpr(dim)); |
| 181 | + } |
| 182 | + return result; |
| 183 | + } |
| 184 | + |
| 185 | + absl::StatusOr<std::vector<Def*>> FoldTypesOfDefs( |
| 186 | + absl::Span<Def* const> defs) { |
| 187 | + std::vector<Def*> result(defs.size()); |
| 188 | + int i = 0; |
| 189 | + for (Def* def : defs) { |
| 190 | + XLS_ASSIGN_OR_RETURN(DataType * folded_type, |
| 191 | + FoldConstants(def->data_type())); |
| 192 | + result[i++] = def->file()->Make<Def>( |
| 193 | + def->loc(), def->GetName(), def->data_kind(), folded_type, |
| 194 | + def->init().has_value() ? *def->init() : nullptr); |
| 195 | + } |
| 196 | + return result; |
| 197 | + } |
| 198 | + |
| 199 | + absl::StatusOr<Expression*> FoldBinaryOp(Operator* op) { |
| 200 | + auto* binop = dynamic_cast<BinaryInfix*>(op); |
| 201 | + XLS_ASSIGN_OR_RETURN(Expression * folded_lhs, FoldConstants(binop->lhs())); |
| 202 | + XLS_ASSIGN_OR_RETURN(Expression * folded_rhs, FoldConstants(binop->rhs())); |
| 203 | + if (!folded_lhs->IsLiteral() || !folded_rhs->IsLiteral()) { |
| 204 | + return op->file()->Make<BinaryInfix>(op->loc(), folded_lhs, folded_rhs, |
| 205 | + op->kind()); |
| 206 | + } |
| 207 | + Literal* lhs_literal = folded_lhs->AsLiteralOrDie(); |
| 208 | + Literal* rhs_literal = folded_rhs->AsLiteralOrDie(); |
| 209 | + XLS_ASSIGN_OR_RETURN(int64_t lhs_value, |
| 210 | + folded_lhs->AsLiteralOrDie()->ToInt64()); |
| 211 | + XLS_ASSIGN_OR_RETURN(int64_t rhs_value, |
| 212 | + folded_rhs->AsLiteralOrDie()->ToInt64()); |
| 213 | + bool signed_input = lhs_literal->is_declared_as_signed() && |
| 214 | + rhs_literal->is_declared_as_signed(); |
| 215 | + std::optional<bool> bool_result; |
| 216 | + std::optional<int64_t> int_result; |
| 217 | + switch (op->kind()) { |
| 218 | + case OperatorKind::kAdd: |
| 219 | + int_result = lhs_value + rhs_value; |
| 220 | + break; |
| 221 | + case OperatorKind::kSub: |
| 222 | + int_result = lhs_value - rhs_value; |
| 223 | + break; |
| 224 | + case OperatorKind::kMul: |
| 225 | + int_result = lhs_value * rhs_value; |
| 226 | + break; |
| 227 | + case OperatorKind::kDiv: |
| 228 | + int_result = lhs_value / rhs_value; |
| 229 | + break; |
| 230 | + case OperatorKind::kMod: |
| 231 | + int_result = lhs_value % rhs_value; |
| 232 | + break; |
| 233 | + case OperatorKind::kEq: |
| 234 | + bool_result = lhs_value == rhs_value; |
| 235 | + break; |
| 236 | + case OperatorKind::kNe: |
| 237 | + bool_result = lhs_value != rhs_value; |
| 238 | + break; |
| 239 | + case OperatorKind::kGe: |
| 240 | + bool_result = signed_input ? lhs_value >= rhs_value |
| 241 | + : static_cast<uint64_t>(lhs_value) >= |
| 242 | + static_cast<uint64_t>(rhs_value); |
| 243 | + break; |
| 244 | + case OperatorKind::kGt: |
| 245 | + bool_result = signed_input ? lhs_value > rhs_value |
| 246 | + : static_cast<uint64_t>(lhs_value) > |
| 247 | + static_cast<uint64_t>(rhs_value); |
| 248 | + break; |
| 249 | + case OperatorKind::kLe: |
| 250 | + bool_result = signed_input ? lhs_value <= rhs_value |
| 251 | + : static_cast<uint64_t>(lhs_value) <= |
| 252 | + static_cast<uint64_t>(rhs_value); |
| 253 | + break; |
| 254 | + case OperatorKind::kLt: |
| 255 | + bool_result = signed_input ? lhs_value < rhs_value |
| 256 | + : static_cast<uint64_t>(lhs_value) < |
| 257 | + static_cast<uint64_t>(rhs_value); |
| 258 | + break; |
| 259 | + default: |
| 260 | + break; |
| 261 | + } |
| 262 | + if (int_result.has_value()) { |
| 263 | + return MakeFoldedConstant(op, *int_result); |
| 264 | + } |
| 265 | + if (bool_result.has_value()) { |
| 266 | + return MakeFoldedConstant(op, static_cast<int64_t>(*bool_result)); |
| 267 | + } |
| 268 | + return op->file()->Make<BinaryInfix>(op->loc(), folded_lhs, folded_rhs, |
| 269 | + op->kind()); |
| 270 | + } |
| 271 | + |
| 272 | + const absl::flat_hash_map<Expression*, DataType*>& type_map_; |
| 273 | +}; |
| 274 | + |
| 275 | +} // namespace |
| 276 | + |
| 277 | +absl::StatusOr<DataType*> FoldVastConstants( |
| 278 | + DataType* data_type, |
| 279 | + const absl::flat_hash_map<Expression*, DataType*>& type_map) { |
| 280 | + auto context = std::make_unique<ConstantFoldingContext>(type_map); |
| 281 | + return context->FoldConstants(data_type); |
| 282 | +} |
| 283 | + |
| 284 | +absl::StatusOr<Expression*> FoldVastConstants( |
| 285 | + Expression* expr, |
| 286 | + const absl::flat_hash_map<Expression*, DataType*>& type_map) { |
| 287 | + auto context = std::make_unique<ConstantFoldingContext>(type_map); |
| 288 | + return context->FoldConstants(expr); |
| 289 | +} |
| 290 | + |
| 291 | +} // namespace verilog |
| 292 | +} // namespace xls |
0 commit comments