@@ -97,7 +97,7 @@ def forward(self,
9797 self .global_freqs_cis .index_select (0 , input_positions )
9898 )
9999 hidden_states = self .text_token_embedder (input_token_ids )
100- normalizer = torch .tensor (self .config .hidden_size ** 0.5 , dtype = hidden_states .dtype )
100+ normalizer = torch .tensor (self .config .hidden_size ** 0.5 , dtype = hidden_states .dtype , device = hidden_states . device )
101101 hidden_states = hidden_states * normalizer
102102 if image_patches is not None and self .config .vision_config is not None :
103103 # the input has images
@@ -127,7 +127,7 @@ def forward(self,
127127 embedder_weight = self .text_token_embedder .weight
128128 if self .config .quant :
129129 embedder_weight = (
130- embedder_weight * self .embedder .weight_scaler .unsqueeze (- 1 ))
130+ embedder_weight * self .text_token_embedder .weight_scaler .unsqueeze (- 1 ))
131131
132132 next_tokens , logits = self .sampler (
133133 embedding = embedder_weight ,
@@ -162,7 +162,7 @@ def populate_image_embeddings(self,
162162
163163 def create_attention_mask (self , input_ids : torch .Tensor , sequence_length : int ):
164164 batch_size = input_ids .shape [0 ]
165- causal_mask = torch .tril (torch .ones ((batch_size , 1 , sequence_length , sequence_length ), dtype = torch .bool ))
165+ causal_mask = torch .tril (torch .ones ((batch_size , 1 , sequence_length , sequence_length ), dtype = torch .bool , device = input_ids . device ))
166166 image_token_mask = input_ids == self .tokenizer .image_token_placeholder_id
167167 # Pad the mask to the left with 0. This is to make sure the boundary
168168 # detection works correctly. Boundary (starting index of image patch) is
@@ -202,7 +202,7 @@ def create_attention_mask(self, input_ids: torch.Tensor, sequence_length: int):
202202 # local attention is within the sliding window.
203203 local_mask = torch .logical_and (
204204 attention_mask ,
205- torch .triu (torch .ones ((1 , 1 , sequence_length , sequence_length ), dtype = torch .bool ), diagonal = - (self .config .sliding_window_size - 1 ))
205+ torch .triu (torch .ones ((1 , 1 , sequence_length , sequence_length ), dtype = torch .bool , device = input_ids . device ), diagonal = - (self .config .sliding_window_size - 1 ))
206206 )
207207 return attention_mask , local_mask
208208
@@ -233,8 +233,8 @@ def generate(
233233 if self .config .sliding_window_size is None :
234234 raise ValueError ('gemma 3 model requires sliding_window size' )
235235 boolean_mask , local_boolean_mask = self .create_attention_mask (user_input_token_ids , total_seq_len )
236- mask_tensor = torch .where (boolean_mask , 0 , min_dtype ).contiguous ()
237- local_mask_tensor = torch .where (local_boolean_mask , 0 , min_dtype ).contiguous ()
236+ mask_tensor = torch .where (boolean_mask , 0 , torch . tensor ( min_dtype , dtype = torch . float32 , device = device ) ).contiguous ()
237+ local_mask_tensor = torch .where (local_boolean_mask , 0 , torch . tensor ( min_dtype , dtype = torch . float32 , device = device ) ).contiguous ()
238238
239239 kv_caches = []
240240 for _ in range (self .config .num_hidden_layers ):
@@ -247,25 +247,22 @@ def generate(
247247
248248 input_token_ids_tensor = torch .full ((batch_size , min_prompt_len ),
249249 self .tokenizer .pad_id ,
250- dtype = torch .int64 )
250+ dtype = torch .int64 , device = device )
251251 token_ids_tensor = user_input_token_ids .to (device )
252252 for i in range (batch_size ):
253253 p = user_input_token_ids [i ]
254254 input_token_ids_tensor [i , :min_prompt_len ] = p [:min_prompt_len ]
255255
256- token_ids_tensor = token_ids_tensor .to (device )
257- input_token_ids_tensor = input_token_ids_tensor .to (device )
258- input_positions_tensor = torch .arange (0 , min_prompt_len , dtype = torch .int64 ).to (device )
256+ input_positions_tensor = torch .arange (0 , min_prompt_len , dtype = torch .int64 , device = device )
259257 prompt_mask_tensor = token_ids_tensor != self .tokenizer .pad_id
260258 curr_mask_tensor = mask_tensor .index_select (2 , input_positions_tensor )
261259 curr_local_mask_tensor = local_mask_tensor .index_select (2 , input_positions_tensor )
262- output_positions_tensor = torch .LongTensor ([min_prompt_len - 1 ])
260+ output_positions_tensor = torch .LongTensor ([min_prompt_len - 1 ]). to ( device )
263261 temperatures_tensor = None if not temperature else torch .FloatTensor (
264262 [temperature ] * batch_size ).to (device )
265263 top_ps_tensor = torch .FloatTensor ([top_p ] * batch_size ).to (device )
266264 top_ks_tensor = torch .LongTensor ([top_k ] * batch_size ).to (device )
267- output_index = torch .tensor (min_prompt_len , dtype = torch .int64 ).to (
268- device )
265+ output_index = torch .tensor (min_prompt_len , dtype = torch .int64 , device = device )
269266
270267 # Prefill up to min_prompt_len tokens, then treat other prefill as
271268 # decode and ignore output.
@@ -298,8 +295,7 @@ def generate(
298295 curr_local_mask_tensor = local_mask_tensor .index_select (
299296 2 , input_positions_tensor
300297 ) if local_mask_tensor is not None else None
301- output_positions_tensor = torch .tensor (0 , dtype = torch .int64 ).to (
302- device )
298+ output_positions_tensor = torch .tensor (0 , dtype = torch .int64 , device = device )
303299 output_index = output_index + 1
304300 image_batch = None
305301 image_presence_mask = None
0 commit comments