
Commit 10b74e8

[Open Source Internship] Albert model fine-tuning (#2008)
1 parent 59c6eda commit 10b74e8

File tree

2 files changed, +188 −0 lines changed


llm/finetune/albert/Albert_mind.py (+130 lines)
@@ -0,0 +1,130 @@
import random

import numpy as np
import evaluate
from datasets import load_dataset
from mindspore.dataset import GeneratorDataset
from mindnlp.transformers import AlbertTokenizer, AlbertForSequenceClassification
from mindnlp.engine import Trainer, TrainingArguments

# 1. Load the pretrained model and tokenizer
model_name = "albert-base-v1"
tokenizer = AlbertTokenizer.from_pretrained(model_name)
model = AlbertForSequenceClassification.from_pretrained(
    model_name, num_labels=2)

# 2. Load the IMDb dataset
dataset = load_dataset("stanfordnlp/imdb", trust_remote_code=True)
print("dataset:", dataset)


# 3. Preprocessing: tokenize the review text and carry the label along
def tokenize_function(examples):
    tokenized = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=512
    )
    # Attach the labels to the returned dict
    tokenized["labels"] = examples["label"]
    return tokenized


# Apply preprocessing
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Check the label distribution
print("\n==== Label distribution check ====")

# Training split
train_labels = np.array(tokenized_datasets["train"]["labels"])
print("Train label stats:")
print("- unique values:", np.unique(train_labels))
print("- distribution:", np.bincount(train_labels))

# Test split
test_labels = np.array(tokenized_datasets["test"]["labels"])
print("\nTest label stats:")
print("- unique values:", np.unique(test_labels))
print("- distribution:", np.bincount(test_labels))


# 4. Convert to a MindSpore dataset
def create_dataset(data, batch_size=8):
    # Materialize as a list so the examples can be shuffled
    data_list = list(data)
    random.shuffle(data_list)

    def generator():
        for item in data_list:
            yield (np.array(item["input_ids"], dtype=np.int32),
                   np.array(item["attention_mask"], dtype=np.int32),
                   np.array(item["labels"], dtype=np.int32))

    # Pass the callable (not a generator instance) so the pipeline
    # can re-create the generator on every epoch
    return GeneratorDataset(generator, ["input_ids", "attention_mask", "labels"]).batch(batch_size)


train_dataset = create_dataset(tokenized_datasets["train"])
eval_dataset = create_dataset(tokenized_datasets["test"])

# 5. Load evaluation metrics
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")
precision = evaluate.load("precision")
recall = evaluate.load("recall")

# Sanity-check one batch
sample = next(iter(train_dataset))
print("Input IDs:", sample[0])
print("Attention Mask:", sample[1])
print("Labels:", sample[2])


# Custom metric computation
def compute_metrics(eval_pred):
    logits, labels = eval_pred  # unpack directly into logits and labels
    predictions = np.argmax(logits, axis=-1)

    return {
        "accuracy": accuracy.compute(predictions=predictions, references=labels)["accuracy"],
        "f1": f1.compute(predictions=predictions, references=labels, average="binary")["f1"],
        "precision": precision.compute(predictions=predictions, references=labels, average="binary")["precision"],
        "recall": recall.compute(predictions=predictions, references=labels, average="binary")["recall"]
    }


# 6. Configure training arguments
training_args = TrainingArguments(
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=1e-5,
    weight_decay=0.01,
    output_dir="./results",
    logging_dir="./logs",
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # select the best checkpoint by accuracy
    greater_is_better=True,            # higher accuracy is better
)

# 7. Initialize and run training
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,  # attach the metric function
)

trainer.train()

# 8. Evaluate the model
eval_results = trainer.evaluate(eval_dataset)
print(f"Evaluation results: {eval_results}")
print("\nFinal evaluation results:")
print(f"Accuracy: {eval_results['eval_accuracy']:.4f}")
print(f"F1 Score: {eval_results['eval_f1']:.4f}")
print(f"Precision: {eval_results['eval_precision']:.4f}")
print(f"Recall: {eval_results['eval_recall']:.4f}")
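After training, predictions can be spot-checked with the in-memory `model` and `tokenizer`. A minimal inference sketch, not part of the committed script; it assumes mindnlp tokenizers accept `return_tensors="ms"`, and the review text is a made-up example:

review = "A moving story with terrific performances."  # made-up example input
inputs = tokenizer(review, return_tensors="ms", truncation=True, max_length=512)
logits = model(**inputs).logits
pred = int(logits.argmax(-1)[0])  # IMDb labels: 0 = negative, 1 = positive
print("positive" if pred == 1 else "negative")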
@@ -0,0 +1,58 @@
# Albert mindnlp Stanford IMDb Review Finetune

- Albert model fine-tuning task: [Open Source Internship: albert model fine-tuning · Issue #IAUONP · MindSpore/community - Gitee.com](https://gitee.com/mindspore/community/issues/IAUONP)
- Fine-tunes the Albert-base-v1 base weights on the [Sentiment analysis of IMDb reviews - Stanford University] dataset

- base model: [albert/albert-base-v1 · Hugging Face](https://huggingface.co/albert/albert-base-v1)
- dataset: [stanfordnlp/imdb · Datasets at Hugging Face](https://huggingface.co/datasets/stanfordnlp/imdb)

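Before running the full fine-tuning script, a quick smoke test confirms the base checkpoint and dataset load (a minimal sketch using the IDs linked above; not part of the committed code):

```python
from datasets import load_dataset
from mindnlp.transformers import AlbertTokenizer

# Fetch the tokenizer and peek at one raw IMDb review plus its token IDs.
tokenizer = AlbertTokenizer.from_pretrained("albert-base-v1")
imdb = load_dataset("stanfordnlp/imdb")
sample = imdb["train"][0]
print(sample["label"], sample["text"][:80])
print(tokenizer(sample["text"], truncation=True, max_length=16)["input_ids"])
```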
# Requirements

## PyTorch

- GPU: RTX 4070 Ti 12G
- CUDA: 11.8
- Python version: 3.10
- torch version: 2.5.0
- transformers version: 4.47.0

## MindSpore (OpenI community, Ascend 910B compute)

- Ascend: 910B
- python: 3.11
- mindspore: 2.5.0
- mindnlp: 0.4.1

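Before launching training on the Ascend environment, MindSpore's built-in self-check can confirm the install works (a minimal sketch; `run_check` is MindSpore's documented install verifier):

```python
import mindspore

# Runs a tiny computation on the configured device and reports
# whether the MindSpore installation is functional.
mindspore.run_check()
```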
# Result for finetune

Trained for 3 epochs.

## torch

| Epoch              | eval_loss |
| ------------------ | --------- |
| 1                  | 0.3868    |
| 2                  | 0.2978    |
| 3                  | 0.3293    |
| Evaluation results | 0.2978    |

**Evaluation results**

| Accuracy | Precision | Recall | F1_score |
| -------- | --------- | ------ | -------- |
| 0.9212   | 0.9218    | 0.9284 | 0.9218   |

## mindspore

| Epoch              | eval_loss |
| ------------------ | --------- |
| 1                  | 0.2677    |
| 2                  | 0.2314    |
| 3                  | 0.2332    |
| Evaluation results | 0.2314    |

**Evaluation results**

| Accuracy | Precision | Recall | F1_score |
| -------- | --------- | ------ | -------- |
| 0.9219   | 0.9238    | 0.9218 | 0.9228   |
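As a consistency check on the table above, F1 is the harmonic mean of precision and recall; plugging in the mindspore precision and recall reproduces the reported F1_score (a quick sketch, not part of the repo):

```python
# F1 = 2PR / (P + R), using the mindspore row's precision and recall.
p, r = 0.9238, 0.9218
print(round(2 * p * r / (p + r), 4))  # -> 0.9228, matching the table
```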
