
Commit 5c37bbd

audio_spectrogram_transformer finetune
1 parent 10b74e8 commit 5c37bbd

2 files changed: +160 −0 lines changed
@@ -0,0 +1,39 @@
# audio_spectrogram_transformer fine-tuning report

## Task

- Model: MIT/ast-finetuned-audioset-10-10-0.4593
- Dataset: ashraq/esc50

## Results comparison

**MindNLP + Ascend D910B**

| Epoch | Training Loss | Validation Loss | Accuracy |
|------:|--------------:|----------------:|---------:|
| 1 | 3.0928 | 2.2305 | 0.8150 |
| 2 | 1.4845 | 0.9815 | 0.8950 |
| 3 | 0.5733 | 0.4876 | 0.9250 |
| 4 | 0.2061 | 0.3770 | 0.9125 |
| 5 | 0.0742 | 0.3207 | 0.9300 |
| 6 | 0.0443 | 0.1821 | **0.9600** |
| 7 | 0.0178 | 0.2144 | 0.9575 |
| 8 | 0.0111 | 0.2155 | 0.9575 |
| 9 | 0.0094 | 0.2167 | 0.9575 |
| 10 | 0.0087 | 0.2174 | 0.9575 |

---

**PyTorch + RTX 3090**

| Epoch | Training Loss | Validation Loss | Accuracy |
|------:|--------------:|----------------:|---------:|
| 1 | 1.2550 | 0.3934 | 0.8825 |
| 2 | 0.4477 | 0.3656 | 0.8925 |
| 3 | 0.3289 | 0.2777 | 0.9200 |
| 4 | 0.2200 | 0.3645 | 0.9175 |
| 5 | 0.1679 | 0.2345 | 0.9350 |
| 6 | 0.1140 | 0.1877 | 0.9575 |
| 7 | 0.0925 | 0.1641 | 0.9575 |
| 8 | 0.0648 | 0.1810 | 0.9475 |
| 9 | 0.0593 | 0.1285 | 0.9550 |
| 10 | 0.0269 | 0.1222 | **0.9575** |
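
For reference, a minimal inference sketch with the fine-tuned model is shown below. It reuses the mindnlp `transformers`-style API already used in the training script and assumes the outputs expose `.logits` as in Hugging Face transformers; the checkpoint directory is a placeholder and should point at the best checkpoint saved under `./checkpoint`.

```python
import numpy as np
from datasets import Audio, load_dataset
from mindnlp.transformers import ASTFeatureExtractor, ASTForAudioClassification

pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(pretrained_model)

# Placeholder path: point this at the best checkpoint saved by the Trainer.
model = ASTForAudioClassification.from_pretrained("./checkpoint/best")
model.set_train(False)  # disable dropout for inference (MindSpore Cell API)

# Take one ESC-50 clip, resampled to the extractor's sampling rate.
esc50 = load_dataset("ashraq/esc50", split="train")
esc50 = esc50.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
sample = esc50[0]

inputs = feature_extractor(
    sample["audio"]["array"],
    sampling_rate=feature_extractor.sampling_rate,
    return_tensors="ms",
)
logits = model(**inputs).logits
pred_id = int(np.argmax(logits.asnumpy(), axis=-1)[0])
print("true:", sample["category"], "| predicted:", model.config.id2label[pred_id])
```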
@@ -0,0 +1,121 @@
import mindspore as ms
import numpy as np
from datasets import Audio, ClassLabel, load_dataset
from mindspore.dataset import GeneratorDataset
from sklearn.metrics import accuracy_score
from mindnlp.engine import Trainer, TrainingArguments
from mindnlp.transformers import (ASTConfig, ASTFeatureExtractor,
                                  ASTForAudioClassification)

ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend")

# Load the ESC-50 dataset
esc50 = load_dataset("ashraq/esc50", split="train")

# Recover the class names in label-id order
df = esc50.select_columns(["target", "category"]).to_pandas()
class_names = df.iloc[np.unique(df["target"], return_index=True)[
    1]]["category"].to_list()

esc50 = esc50.cast_column("target", ClassLabel(names=class_names))
esc50 = esc50.cast_column("audio", Audio(sampling_rate=16000))
esc50 = esc50.rename_column("target", "labels")
num_labels = len(np.unique(esc50["labels"]))

# Initialize the AST feature extractor
pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(pretrained_model)
model_input_name = feature_extractor.model_input_names[0]
SAMPLING_RATE = feature_extractor.sampling_rate


# Preprocess audio: turn raw waveforms into AST spectrogram features
def preprocess_audio(batch):
    wavs = [audio["array"] for audio in batch["input_values"]]
    inputs = feature_extractor(
        wavs, sampling_rate=SAMPLING_RATE, return_tensors="ms")
    return {model_input_name: inputs.get(model_input_name), "labels": list(batch["labels"])}


dataset = esc50
label2id = dataset.features["labels"]._str2int

# Build the train and test splits
if "test" not in dataset:
    dataset = dataset.train_test_split(
        test_size=0.2, shuffle=True, seed=0, stratify_by_column="labels")


dataset = dataset.cast_column("audio", Audio(
    sampling_rate=feature_extractor.sampling_rate))
dataset = dataset.rename_column("audio", "input_values")

dataset["train"].set_transform(
    preprocess_audio, output_all_columns=False)
dataset["test"].set_transform(preprocess_audio, output_all_columns=False)

# Load the config and adapt the classification head to ESC-50
config = ASTConfig.from_pretrained(pretrained_model)
config.num_labels = num_labels
config.label2id = label2id
config.id2label = {v: k for k, v in label2id.items()}

model = ASTForAudioClassification.from_pretrained(
    pretrained_model, config=config, ignore_mismatched_sizes=True)


def convert_mindspore_dataset(hf_dataset, batch_size):
    data_list = list(hf_dataset)

    def generator():
        for item in data_list:
            yield item[model_input_name], item["labels"]

    # Wrap the Hugging Face dataset in a MindSpore GeneratorDataset
    ds = GeneratorDataset(
        source=generator,
        column_names=[model_input_name, "labels"],
        shuffle=False
    )
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds


# Initialize training arguments
training_args = TrainingArguments(
    output_dir="./checkpoint",
    logging_dir="./logs",
    learning_rate=5e-5,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    eval_steps=1,
    save_steps=1,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_strategy="epoch",
    logging_steps=20,
)

train_ms_dataset = convert_mindspore_dataset(
    dataset["train"], training_args.per_device_train_batch_size)
eval_ms_dataset = convert_mindspore_dataset(
    dataset["test"], training_args.per_device_train_batch_size)


def compute_metrics(eval_pred):
    logits = eval_pred.predictions
    labels = eval_pred.label_ids
    predictions = np.argmax(logits, axis=1)
    return {"accuracy": accuracy_score(labels, predictions)}


# Initialize the trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ms_dataset,
    eval_dataset=eval_ms_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()
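
# Optional follow-up (sketch): run a final evaluation and keep the best weights.
# This assumes mindnlp's Trainer mirrors the Hugging Face Trainer API
# (`evaluate`, `save_model`); adjust if the mindnlp.engine interface differs.
metrics = trainer.evaluate()
print(metrics)
trainer.save_model("./checkpoint/best")  # hypothetical export directory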

0 commit comments