@@ -1,10 +1,9 @@
-from dataclasses import dataclass, field, InitVar
-from typing import List, Optional, Union
 import warnings
+from dataclasses import InitVar, dataclass, field
+from typing import List, Optional, Union
 
 from hip.models.hip_attention.gen3.attention_metadata import ScanStage
 
-
 _DEFAULT_STAGES = [
     ScanStage(
         stage_block_size_q=64,
@@ -35,7 +34,7 @@ class HiPAttentionPerLayerConfig:
     second_stage_k: int = 2048
     sliding_window_size: int = 1024
     sink_token_size: int = 256
-    sa_extend_backend: str = 'streaming'
+    sa_extend_backend: str = "streaming"
     scan_extend_backend: Optional[str] = None
     stages: list[ScanStage] = field(default_factory=lambda: _DEFAULT_STAGES)
 
@@ -44,47 +43,50 @@ class HiPAttentionPerLayerConfig:
     def __post_init__(self, parsed_json: dict | None):
         super().__init__()
         if parsed_json is not None:
-            if 'second_stage_k' in parsed_json:
-                self.second_stage_k = parsed_json['second_stage_k']
-                parsed_json.pop('second_stage_k')
-            if 'sliding_window_size' in parsed_json:
-                self.sliding_window_size = parsed_json['sliding_window_size']
-                parsed_json.pop('sliding_window_size')
-            if 'sink_token_size' in parsed_json:
-                self.sink_token_size = parsed_json['sink_token_size']
-                parsed_json.pop('sink_token_size')
-            if 'sa_extend_backend' in parsed_json:
-                self.sa_extend_backend = parsed_json['sa_extend_backend']
-                parsed_json.pop('sa_extend_backend')
-            if 'scan_extend_backend' in parsed_json:
-                self.scan_extend_backend = parsed_json['scan_extend_backend']
-                parsed_json.pop('scan_extend_backend')
-            if 'stages' in parsed_json:
-                self.stages = [
-                    ScanStage(**stage)
-                    for stage in parsed_json['stages']
-                ]
-                parsed_json.pop('stages')
+            if "second_stage_k" in parsed_json:
+                self.second_stage_k = parsed_json["second_stage_k"]
+                parsed_json.pop("second_stage_k")
+            if "sliding_window_size" in parsed_json:
+                self.sliding_window_size = parsed_json["sliding_window_size"]
+                parsed_json.pop("sliding_window_size")
+            if "sink_token_size" in parsed_json:
+                self.sink_token_size = parsed_json["sink_token_size"]
+                parsed_json.pop("sink_token_size")
+            if "sa_extend_backend" in parsed_json:
+                self.sa_extend_backend = parsed_json["sa_extend_backend"]
+                parsed_json.pop("sa_extend_backend")
+            if "scan_extend_backend" in parsed_json:
+                self.scan_extend_backend = parsed_json["scan_extend_backend"]
+                parsed_json.pop("scan_extend_backend")
+            if "stages" in parsed_json:
+                self.stages = [ScanStage(**stage) for stage in parsed_json["stages"]]
+                parsed_json.pop("stages")
         if parsed_json:
-            raise ValueError(f'Unknown keys in json: {parsed_json.keys()}')
+            raise ValueError(f"Unknown keys in json: {parsed_json.keys()}")
 
 
 @dataclass
 class HiPAttentionConfig:
     dense_layers: list[int] = field(default_factory=lambda: [0, 1, 2])
     block_sparse_block_size_q: int = 64
     metadata_cache_max_batch_size: int = 32
-    mask_refresh_interval: Union[int, List[int]] = field(default_factory=lambda: [32, 16, 8])
+    mask_refresh_interval: Union[int, List[int]] = field(
+        default_factory=lambda: [32, 16, 8]
+    )
     using_extend: bool = True
-    layers: list[HiPAttentionPerLayerConfig] = field(default_factory=lambda: [
-        HiPAttentionPerLayerConfig(parsed_json={"second_stage_k": 4096, "sliding_window_size": 1024, "sink_token_size": 256}),
-        HiPAttentionPerLayerConfig(),
-    ])
-    prefill_layers: list[HiPAttentionPerLayerConfig] = field(default_factory=lambda: [
-        HiPAttentionPerLayerConfig(parsed_json={"second_stage_k": 4096, "sliding_window_size": 1024, "sink_token_size": 256}),
-        HiPAttentionPerLayerConfig(),
-    ])
-
+    layers: list[HiPAttentionPerLayerConfig] = field(
+        default_factory=lambda: [
+            HiPAttentionPerLayerConfig(
+                parsed_json={
+                    "second_stage_k": 4096,
+                    "sliding_window_size": 1024,
+                    "sink_token_size": 256,
+                }
+            ),
+            HiPAttentionPerLayerConfig(),
+        ]
+    )
+
     # deprecated
     apply_v_dot: bool = False
     prefill_always_dense: bool = False
@@ -96,58 +98,64 @@ class HiPAttentionConfig:
 
     def __post_init__(self, parsed_json: dict | None):
         super().__init__()
-
+
         if parsed_json is not None:
-            if 'apply_v_dot' in parsed_json:
-                self.apply_v_dot = parsed_json['apply_v_dot']
-                parsed_json.pop('apply_v_dot')
-            if 'dense_layers' in parsed_json:
-                self.dense_layers = parsed_json['dense_layers']
-                parsed_json.pop('dense_layers')
-            if 'prefill_always_dense' in parsed_json:
-                self.prefill_always_dense = parsed_json['prefill_always_dense']
-                parsed_json.pop('prefill_always_dense')
-            if 'decode_always_dense' in parsed_json:
-                self.decode_always_dense = parsed_json['decode_always_dense']
-                parsed_json.pop('decode_always_dense')
-            if 'force_dense' in parsed_json:
-                self.force_dense = parsed_json['force_dense']
-                parsed_json.pop('force_dense')
-            if 'prefill_dense_threshold' in parsed_json:
-                self.prefill_dense_threshold = parsed_json['prefill_dense_threshold']
-                parsed_json.pop('prefill_dense_threshold')
-            if 'block_sparse_block_size_q' in parsed_json:
-                self.block_sparse_block_size_q = parsed_json['block_sparse_block_size_q']
-                parsed_json.pop('block_sparse_block_size_q')
-            if 'metadata_cache_max_batch_size' in parsed_json:
-                self.metadata_cache_max_batch_size = parsed_json['metadata_cache_max_batch_size']
-                parsed_json.pop('metadata_cache_max_batch_size')
-            if 'mask_refresh_interval' in parsed_json:
-                assert isinstance(parsed_json['mask_refresh_interval'], (int, list))
-                self.mask_refresh_interval = parsed_json['mask_refresh_interval']
-                parsed_json.pop('mask_refresh_interval')
-            if 'using_extend' in parsed_json:
-                self.using_extend = parsed_json['using_extend']
-                parsed_json.pop('using_extend')
-            if 'layers' in parsed_json:
+            if "apply_v_dot" in parsed_json:
+                self.apply_v_dot = parsed_json["apply_v_dot"]
+                parsed_json.pop("apply_v_dot")
+            if "dense_layers" in parsed_json:
+                self.dense_layers = parsed_json["dense_layers"]
+                parsed_json.pop("dense_layers")
+            if "prefill_always_dense" in parsed_json:
+                self.prefill_always_dense = parsed_json["prefill_always_dense"]
+                parsed_json.pop("prefill_always_dense")
+            if "decode_always_dense" in parsed_json:
+                self.decode_always_dense = parsed_json["decode_always_dense"]
+                parsed_json.pop("decode_always_dense")
+            if "force_dense" in parsed_json:
+                self.force_dense = parsed_json["force_dense"]
+                parsed_json.pop("force_dense")
+            if "prefill_dense_threshold" in parsed_json:
+                self.prefill_dense_threshold = parsed_json["prefill_dense_threshold"]
+                parsed_json.pop("prefill_dense_threshold")
+            if "block_sparse_block_size_q" in parsed_json:
+                self.block_sparse_block_size_q = parsed_json[
+                    "block_sparse_block_size_q"
+                ]
+                parsed_json.pop("block_sparse_block_size_q")
+            if "metadata_cache_max_batch_size" in parsed_json:
+                self.metadata_cache_max_batch_size = parsed_json[
+                    "metadata_cache_max_batch_size"
+                ]
+                parsed_json.pop("metadata_cache_max_batch_size")
+            if "mask_refresh_interval" in parsed_json:
+                assert isinstance(parsed_json["mask_refresh_interval"], (int, list))
+                self.mask_refresh_interval = parsed_json["mask_refresh_interval"]
+                parsed_json.pop("mask_refresh_interval")
+            if "using_extend" in parsed_json:
+                self.using_extend = parsed_json["using_extend"]
+                parsed_json.pop("using_extend")
+            if "layers" in parsed_json:
                 self.layers = [
                     HiPAttentionPerLayerConfig(parsed_json=layer)
-                    for layer in parsed_json['layers']
+                    for layer in parsed_json["layers"]
                 ]
                 self.prefill_layers = self.layers
-                parsed_json.pop('layers')
-            if 'prefill_layers' in parsed_json:
+                parsed_json.pop("layers")
+            if "prefill_layers" in parsed_json:
                 self.prefill_layers = [
                     HiPAttentionPerLayerConfig(parsed_json=layer)
-                    for layer in parsed_json['prefill_layers']
+                    for layer in parsed_json["prefill_layers"]
                 ]
-                parsed_json.pop('prefill_layers')
+                parsed_json.pop("prefill_layers")
         if parsed_json:
-            raise Exception(f'Unknown keys in json: {parsed_json.keys()}')
-
+            raise ValueError(f"Unknown keys in json: {parsed_json.keys()}")
+
         num_stages = len(self.layers[0].stages)
         for layer_config in self.layers:
             assert num_stages == len(layer_config.stages)
-
+
         if isinstance(self.mask_refresh_interval, int):
-            self.mask_refresh_interval = [self.mask_refresh_interval, ] * num_stages
+            self.mask_refresh_interval = [
+                self.mask_refresh_interval,
+            ] * num_stages
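
For context, a minimal usage sketch of the two config dataclasses after this change. The import path is hypothetical (the diff only shows that the module imports `ScanStage` from `hip.models.hip_attention.gen3.attention_metadata`, not where the config module itself lives), and it assumes `parsed_json` is an `InitVar` accepted as a keyword, as the per-layer defaults above already use it. The JSON keys are exactly the ones consumed and popped in `__post_init__`.

```python
# Hypothetical import path; adjust to the module that defines these dataclasses.
from hip.models.hip_attention.gen3.attention_config import HiPAttentionConfig

# Keys mirror the ones handled (and popped) in __post_init__.
config = HiPAttentionConfig(
    parsed_json={
        "dense_layers": [0, 1, 2],
        "mask_refresh_interval": 32,  # an int is broadcast to one entry per stage
        "layers": [
            {
                "second_stage_k": 4096,
                "sliding_window_size": 1024,
                "sink_token_size": 256,
            },
            {},  # empty dict keeps the HiPAttentionPerLayerConfig defaults
        ],
    }
)
assert isinstance(config.mask_refresh_interval, list)

# Unknown keys are rejected; after this change HiPAttentionConfig raises
# ValueError instead of a bare Exception.
try:
    HiPAttentionConfig(parsed_json={"unknown_key": 1})
except ValueError as err:
    print(err)
```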