Commit e8972e6

Create hf_classification_trainer.py
1 parent 3a38c94 commit e8972e6

1 file changed: +71 -0 lines changed

Diff for: examples/transformers/hf_classification_trainer.py

@@ -0,0 +1,71 @@
# from https://huggingface.co/docs/transformers/en/training

from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
import numpy as np
import evaluate
from transformers import TrainingArguments, Trainer
import torch
import os

# print("NVLINK DISABLE ENV?", os.getenv("NCCL_P2P_DISABLE"))
tmpdir = os.getenv("TMPDIR")
dataset = load_dataset("yelp_review_full")  # , cache_dir=tmpdir)

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # , cache_dir=tmpdir)


def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)


tokenized_datasets = dataset.map(tokenize_function, batched=True)
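# A possible refinement (editor's sketch, not in the original commit):
# padding="max_length" pads every example to the model maximum (512 tokens
# for BERT), which wastes compute on short reviews. Dynamic per-batch
# padding is usually faster:
#
#   from transformers import DataCollatorWithPadding
#   data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
#
# then tokenize with truncation only and pass data_collator=data_collator
# to the Trainer below.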

train_dataset = tokenized_datasets["train"].select(range(20000))
eval_dataset = tokenized_datasets["test"].select(range(20000))

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)  # , cache_dir=tmpdir)


def get_optimal_batch_size():
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(0).total_memory
        print("amount of GPU memory", total_memory)
        # Heuristic: divide total GPU memory by an assumed per-batch footprint.
        # The assumed memory per batch may need to be tuned for the model and
        # sequence length.
        memory_per_batch = 0.70 * 1024 * 1024 * 1024  # assume each batch takes 0.70 GiB of memory
        max_batch_size = int(total_memory // memory_per_batch)
        return min(max_batch_size, 60)
    else:
        print("cuda not found")
        # Default batch size if no GPU is available
        return 2
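# Worked example of the heuristic above (editor's addition): on a 24 GiB GPU,
# total_memory is about 24 * 1024**3 bytes, so max_batch_size =
# int(24 / 0.70) = 34, and min(34, 60) yields a per-device batch size of 34.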

batch_size = get_optimal_batch_size()

print("BATCH SIZE", batch_size)

training_args = TrainingArguments(
    output_dir="output_dir",
    dataloader_num_workers=2,
    dataloader_pin_memory=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size * 2,
    eval_strategy="epoch",
    save_strategy="epoch",
)
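# Note (editor's addition): the per_device_* values are per GPU, so with N
# visible GPUs the effective train batch size is batch_size * N. The
# commented-out NCCL_P2P_DISABLE check above suggests this script is meant
# to run on multi-GPU hosts.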

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()
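
A minimal follow-up sketch (editor's addition, not part of the committed file): once training completes, Trainer can report final held-out accuracy and persist the fine-tuned model; the save path below is an arbitrary choice.

# Editor's sketch: final evaluation and save after trainer.train().
results = trainer.evaluate()  # runs compute_metrics on eval_dataset
print("final eval", results)
trainer.save_model("output_dir/final")  # "output_dir/final" is a hypothetical path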
