Skip to content

Commit 0dd75b0

Browse files
Run Gemma 3 on GPUs
1 parent 5120ce1 commit 0dd75b0

File tree

6 files changed

+21
-24
lines changed

6 files changed

+21
-24
lines changed

gemma/gemma3_model.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def forward(self,
9797
self.global_freqs_cis.index_select(0, input_positions)
9898
)
9999
hidden_states = self.text_token_embedder(input_token_ids)
100-
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
100+
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype, device=hidden_states.device)
101101
hidden_states = hidden_states * normalizer
102102
if image_patches is not None and self.config.vision_config is not None:
103103
# the input has images
@@ -127,7 +127,7 @@ def forward(self,
127127
embedder_weight = self.text_token_embedder.weight
128128
if self.config.quant:
129129
embedder_weight = (
130-
embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
130+
embedder_weight * self.text_token_embedder.weight_scaler.unsqueeze(-1))
131131

132132
next_tokens, logits = self.sampler(
133133
embedding=embedder_weight,
@@ -162,7 +162,7 @@ def populate_image_embeddings(self,
162162

163163
def create_attention_mask(self, input_ids: torch.Tensor, sequence_length: int):
164164
batch_size = input_ids.shape[0]
165-
causal_mask = torch.tril(torch.ones((batch_size, 1, sequence_length, sequence_length), dtype=torch.bool))
165+
causal_mask = torch.tril(torch.ones((batch_size, 1, sequence_length, sequence_length), dtype=torch.bool, device=input_ids.device))
166166
image_token_mask = input_ids == self.tokenizer.image_token_placeholder_id
167167
# Pad the mask to the left with 0. This is to make sure the boundary
168168
# detection works correctly. Boundary (starting index of image patch) is
@@ -202,7 +202,7 @@ def create_attention_mask(self, input_ids: torch.Tensor, sequence_length: int):
202202
# local attention is within the sliding window.
203203
local_mask = torch.logical_and(
204204
attention_mask,
205-
torch.triu(torch.ones((1, 1, sequence_length, sequence_length), dtype=torch.bool), diagonal=-(self.config.sliding_window_size-1))
205+
torch.triu(torch.ones((1, 1, sequence_length, sequence_length), dtype=torch.bool, device=input_ids.device), diagonal=-(self.config.sliding_window_size-1))
206206
)
207207
return attention_mask, local_mask
208208

@@ -233,8 +233,8 @@ def generate(
233233
if self.config.sliding_window_size is None:
234234
raise ValueError('gemma 3 model requires sliding_window size')
235235
boolean_mask, local_boolean_mask = self.create_attention_mask(user_input_token_ids, total_seq_len)
236-
mask_tensor = torch.where(boolean_mask, 0, min_dtype).contiguous()
237-
local_mask_tensor = torch.where(local_boolean_mask, 0, min_dtype).contiguous()
236+
mask_tensor = torch.where(boolean_mask, 0, torch.tensor(min_dtype, dtype=torch.float32, device=device)).contiguous()
237+
local_mask_tensor = torch.where(local_boolean_mask, 0, torch.tensor(min_dtype, dtype=torch.float32, device=device)).contiguous()
238238

239239
kv_caches = []
240240
for _ in range(self.config.num_hidden_layers):
@@ -247,25 +247,22 @@ def generate(
247247

248248
input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
249249
self.tokenizer.pad_id,
250-
dtype=torch.int64)
250+
dtype=torch.int64, device=device)
251251
token_ids_tensor = user_input_token_ids.to(device)
252252
for i in range(batch_size):
253253
p = user_input_token_ids[i]
254254
input_token_ids_tensor[i, :min_prompt_len] = p[:min_prompt_len]
255255

256-
token_ids_tensor = token_ids_tensor.to(device)
257-
input_token_ids_tensor = input_token_ids_tensor.to(device)
258-
input_positions_tensor = torch.arange(0, min_prompt_len, dtype=torch.int64).to(device)
256+
input_positions_tensor = torch.arange(0, min_prompt_len, dtype=torch.int64, device=device)
259257
prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id
260258
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
261259
curr_local_mask_tensor = local_mask_tensor.index_select(2, input_positions_tensor)
262-
output_positions_tensor = torch.LongTensor([min_prompt_len - 1])
260+
output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
263261
temperatures_tensor = None if not temperature else torch.FloatTensor(
264262
[temperature] * batch_size).to(device)
265263
top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
266264
top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
267-
output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(
268-
device)
265+
output_index = torch.tensor(min_prompt_len, dtype=torch.int64, device=device)
269266

270267
# Prefill up to min_prompt_len tokens, then treat other prefill as
271268
# decode and ignore output.
@@ -298,8 +295,7 @@ def generate(
298295
curr_local_mask_tensor = local_mask_tensor.index_select(
299296
2, input_positions_tensor
300297
) if local_mask_tensor is not None else None
301-
output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(
302-
device)
298+
output_positions_tensor = torch.tensor(0, dtype=torch.int64, device=device)
303299
output_index = output_index + 1
304300
image_batch = None
305301
image_presence_mask = None

gemma/gemma3_preprocessor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def tokenize_raw_input(
172172
config.vision_config.input_channels,
173173
config.vision_config.image_size,
174174
config.vision_config.image_size,
175-
)
175+
), device=device
176176
)
177177
for _ in range(pad_length)
178178
]
@@ -182,12 +182,12 @@ def tokenize_raw_input(
182182
image_presence_mask.append(presence_mask)
183183

184184
# Convert lists to tensors
185-
user_input_token_ids = torch.tensor(user_input_token_ids, dtype=torch.long).to(device)
185+
user_input_token_ids = torch.tensor(user_input_token_ids, dtype=torch.long, device=device)
186186
if max_num_images > 0:
187187
image_batch = torch.stack([torch.stack(images) for images in image_batch]).to(
188188
device
189189
)
190-
image_presence_mask = torch.tensor(image_presence_mask, dtype=torch.bool).to(device)
190+
image_presence_mask = torch.tensor(image_presence_mask, dtype=torch.bool, device=device)
191191
else:
192192
image_batch = None
193193
image_presence_mask = None

gemma/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ def forward(
594594
# Gemma normalizes the embedding by sqrt(hidden_size).
595595
# Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
596596
# See https://github.com/huggingface/transformers/pull/29402
597-
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
597+
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype, device=hidden_states.device)
598598
hidden_states = hidden_states * normalizer
599599

600600
hidden_states = self.model(
@@ -677,7 +677,7 @@ def generate(
677677
curr_local_mask_tensor = local_mask_tensor.index_select(
678678
2, input_positions_tensor
679679
) if local_mask_tensor is not None else None
680-
output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
680+
output_positions_tensor = torch.LongTensor([min_prompt_len - 1], device=device)
681681
temperatures_tensor = None if not temperature else torch.FloatTensor(
682682
[temperature] * batch_size).to(device)
683683
top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)

gemma/siglip_vision/siglip_vision_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def gelu_tanh(self, x):
104104
* (
105105
1
106106
+ torch.tanh(
107-
torch.sqrt(torch.tensor(2.0 / torch.pi))
107+
torch.sqrt(torch.tensor(2.0 / torch.pi, device=x.device))
108108
* (x + 0.044715 * torch.pow(x, 3))
109109
)
110110
)
@@ -192,7 +192,8 @@ def forward(
192192
# (batch_size,channels,height,width)->(batch_size, height*width, channels)
193193
x = x.flatten(2).transpose(1, 2)
194194

195-
x = x + self.position_embedding(self.position_ids)
195+
position_ids = self.position_ids.to(pixel_values.device)
196+
x = x + self.position_embedding(position_ids)
196197

197198
for block in self.encoder_blocks:
198199
x = block(x) # batch_size, height*width, embedding_dim (1152)

scripts/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _set_default_tensor_type(dtype: torch.dtype):
6868
def main(_):
6969
# Construct the model config.
7070
model_config = config.get_model_config(FLAGS.variant)
71-
model_config.dtype = "float32" if FLAGS.device == "cpu" else "float16"
71+
model_config.dtype = "float32"
7272
model_config.quant = FLAGS.quant
7373

7474
# Seed random.

scripts/run_multimodal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def _set_default_tensor_type(dtype: torch.dtype):
8585
def main(_):
8686
# Construct the model config.
8787
model_config = config.get_model_config(_VARIANT.value)
88-
model_config.dtype = 'float32' if _DEVICE.value == 'cpu' else 'float16'
88+
model_config.dtype = 'float32'
8989
model_config.quant = _QUANT.value
9090
image_paths = {"cow_in_beach": "scripts/images/cow_in_beach.jpg",
9191
"lilly": "scripts/images/lilly.jpg",

0 commit comments

Comments
 (0)