@@ -707,8 +707,7 @@ def __init__(self, netlist: _nir.Netlist, design: Design, *, all_undef_to_ff=Fal
707
707
self .drivers = _ast .SignalDict ()
708
708
self .io_ports : dict [_ast .IOPort , int ] = {}
709
709
self .rhs_cache : dict [int , tuple [_nir .Value , bool , _ast .Value ]] = {}
710
- self .matches_cache = {}
711
- self .priority_match_cache = {}
710
+ self .match_cache = {}
712
711
self .fragment_module_idx : dict [Fragment , int ] = {}
713
712
714
713
# Collected for driver conflict diagnostics only.
@@ -787,24 +786,14 @@ def emit_operator(self, module_idx: int, operator: str, *inputs: _nir.Value, src
787
786
op = _nir .Operator (module_idx , operator = operator , inputs = inputs , src_loc = src_loc )
788
787
return self .netlist .add_value_cell (op .width , op )
789
788
790
- def emit_matches (self , module_idx : int , value : _nir .Value , patterns , * , src_loc ):
791
- key = module_idx , value , patterns , src_loc
789
+ def emit_match (self , module_idx : int , en : _nir . Net , value : _nir .Value , patterns , * , src_loc ):
790
+ key = module_idx , en , value , patterns , src_loc
792
791
try :
793
- return self .matches_cache [key ]
792
+ return self .match_cache [key ]
794
793
except KeyError :
795
- cell = _nir .Matches (module_idx , value = value , patterns = patterns , src_loc = src_loc )
796
- net , = self .netlist .add_value_cell (1 , cell )
797
- self .matches_cache [key ] = net
798
- return net
799
-
800
- def emit_priority_match (self , module_idx : int , en : _nir .Net , inputs : _nir .Value , * , src_loc ):
801
- key = module_idx , en , inputs , src_loc
802
- try :
803
- return self .priority_match_cache [key ]
804
- except KeyError :
805
- cell = _nir .PriorityMatch (module_idx , en = en , inputs = inputs , src_loc = src_loc )
806
- res = self .netlist .add_value_cell (len (inputs ), cell )
807
- self .priority_match_cache [key ] = res
794
+ cell = _nir .Match (module_idx , en = en , value = value , patterns = patterns , src_loc = src_loc )
795
+ res = self .netlist .add_value_cell (len (patterns ), cell )
796
+ self .match_cache [key ] = res
808
797
return res
809
798
810
799
def unify_shapes_bitwise (self ,
@@ -956,17 +945,16 @@ def emit_rhs(self, module_idx: int, value: _ast.Value) -> tuple[_nir.Value, bool
956
945
result = self .emit_operator (module_idx , 'm' , test , operand_a , operand_b ,
957
946
src_loc = value .src_loc )
958
947
else :
959
- conds = []
960
948
elems = []
961
- for patterns , elem , in value . cases :
962
- if patterns is not None :
963
- net = self . emit_matches ( module_idx , test , patterns , src_loc = value . src_loc )
964
- conds .append (net )
949
+ patterns = []
950
+ for pattern_list , elem , in value . cases :
951
+ if pattern_list is not None :
952
+ patterns .append (pattern_list )
965
953
else :
966
- conds .append (_nir . Net . from_const ( 1 ))
954
+ patterns .append (( "-" * len ( test ), ))
967
955
elems .append (self .emit_rhs (module_idx , elem ))
968
- conds = self .emit_priority_match (module_idx , _nir .Net .from_const (1 ),
969
- _nir . Value ( conds ), src_loc = value .src_loc )
956
+ conds = self .emit_match (module_idx , _nir .Net .from_const (1 ), test , tuple ( patterns ),
957
+ src_loc = value .src_loc )
970
958
shape = _ast .Shape ._unify (
971
959
_ast .Shape (len (value ), signed )
972
960
for value , signed in elems
@@ -1056,14 +1044,10 @@ def emit_assign(self, module_idx: int, cd: "_cd.ClockDomain | None", lhs: _ast.V
1056
1044
offset , _signed = self .emit_rhs (module_idx , lhs .offset )
1057
1045
width = len (lhs .value )
1058
1046
num_cases = min ((width + lhs .stride - 1 ) // lhs .stride , 1 << len (offset ))
1059
- conds = []
1047
+ patterns = []
1060
1048
for case_index in range (num_cases ):
1061
- subcond = self .emit_matches (module_idx , offset ,
1062
- (to_binary (case_index , len (offset )),),
1063
- src_loc = lhs .src_loc )
1064
- conds .append (subcond )
1065
- conds = self .emit_priority_match (module_idx , cond , _nir .Value (conds ),
1066
- src_loc = lhs .src_loc )
1049
+ patterns .append ((to_binary (case_index , len (offset )),))
1050
+ conds = self .emit_match (module_idx , cond , offset , tuple (patterns ), src_loc = lhs .src_loc )
1067
1051
for idx , subcond in enumerate (conds ):
1068
1052
start = lhs_start + idx * lhs .stride
1069
1053
if start >= width :
@@ -1075,17 +1059,15 @@ def emit_assign(self, module_idx: int, cd: "_cd.ClockDomain | None", lhs: _ast.V
1075
1059
self .emit_assign (module_idx , cd , lhs .value , start , subrhs , subcond , src_loc = src_loc )
1076
1060
elif isinstance (lhs , _ast .SwitchValue ):
1077
1061
test , _signed = self .emit_rhs (module_idx , lhs .test )
1078
- conds = []
1062
+ patterns = []
1079
1063
elems = []
1080
- for patterns , elem in lhs .cases :
1081
- if patterns is not None :
1082
- net = self .emit_matches (module_idx , test , patterns , src_loc = lhs .src_loc )
1083
- conds .append (net )
1064
+ for pattern_list , elem in lhs .cases :
1065
+ if pattern_list is not None :
1066
+ patterns .append (pattern_list )
1084
1067
else :
1085
- conds .append (_nir . Net . from_const ( 1 ))
1068
+ patterns .append (( "-" * len ( test ), ))
1086
1069
elems .append (elem )
1087
- conds = self .emit_priority_match (module_idx , cond , _nir .Value (conds ),
1088
- src_loc = lhs .src_loc )
1070
+ conds = self .emit_match (module_idx , cond , test , tuple (patterns ), src_loc = lhs .src_loc )
1089
1071
for subcond , val in zip (conds , elems ):
1090
1072
self .emit_assign (module_idx , cd , val , lhs_start , rhs [:len (val )], subcond , src_loc = src_loc )
1091
1073
elif isinstance (lhs , _ast .Operator ):
@@ -1166,17 +1148,15 @@ def emit_stmt(self, module_idx: int, fragment: _ir.Fragment, domain: str,
1166
1148
self .netlist .add_cell (cell )
1167
1149
elif isinstance (stmt , _ast .Switch ):
1168
1150
test , _signed = self .emit_rhs (module_idx , stmt .test )
1169
- conds = []
1151
+ patterns = []
1170
1152
case_stmts = []
1171
- for patterns , stmts , case_src_loc in stmt .cases :
1172
- if patterns is not None :
1173
- net = self .emit_matches (module_idx , test , patterns , src_loc = case_src_loc )
1174
- conds .append (net )
1153
+ for pattern_list , stmts , case_src_loc in stmt .cases :
1154
+ if pattern_list is not None :
1155
+ patterns .append (pattern_list )
1175
1156
else :
1176
- conds .append (_nir . Net . from_const ( 1 ))
1157
+ patterns .append (( "-" * len ( test ), ))
1177
1158
case_stmts .append (stmts )
1178
- conds = self .emit_priority_match (module_idx , cond , _nir .Value (conds ),
1179
- src_loc = stmt .src_loc )
1159
+ conds = self .emit_match (module_idx , cond , test , tuple (patterns ), src_loc = stmt .src_loc )
1180
1160
for subcond , substmts in zip (conds , case_stmts ):
1181
1161
for substmt in substmts :
1182
1162
self .emit_stmt (module_idx , fragment , domain , substmt , subcond )
@@ -1430,13 +1410,10 @@ def emit_drivers(self):
1430
1410
driver .domain .rst is not None and
1431
1411
not driver .domain .async_reset and
1432
1412
not driver .signal .reset_less ):
1433
- cond = self .emit_matches (driver .module_idx ,
1413
+ cond , = self .emit_match (driver .module_idx , _nir . Net . from_const ( 1 ) ,
1434
1414
self .emit_signal (driver .domain .rst ),
1435
- ("1" ,),
1415
+ (( "1" ,) ,),
1436
1416
src_loc = driver .domain .rst .src_loc )
1437
- cond , = self .emit_priority_match (driver .module_idx , _nir .Net .from_const (1 ),
1438
- _nir .Value (cond ),
1439
- src_loc = driver .domain .rst .src_loc )
1440
1417
init = _nir .Value .from_const (driver .signal .init , len (driver .signal ))
1441
1418
driver .assignments .append (_nir .Assignment (cond = cond , start = 0 ,
1442
1419
value = init , src_loc = driver .signal .src_loc ))
0 commit comments