import mindspore as ms
import numpy as np
from datasets import Audio, ClassLabel, load_dataset
from mindspore.dataset import GeneratorDataset
from sklearn.metrics import accuracy_score
from mindnlp.engine import Trainer, TrainingArguments
from mindnlp.transformers import (ASTConfig, ASTFeatureExtractor,
                                  ASTForAudioClassification)

ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend")

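# ESC-50: 2,000 five-second environmental audio clips covering 50 classes.
# The clips are resampled to 16 kHz below to match the AST feature extractor.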
# Load the ESC-50 dataset
esc50 = load_dataset("ashraq/esc50", split="train")

# Recover the category names ordered by their integer target id
df = esc50.select_columns(["target", "category"]).to_pandas()
class_names = df.iloc[
    np.unique(df["target"], return_index=True)[1]
]["category"].to_list()

esc50 = esc50.cast_column("target", ClassLabel(names=class_names))
esc50 = esc50.cast_column("audio", Audio(sampling_rate=16000))
esc50 = esc50.rename_column("target", "labels")
num_labels = len(np.unique(esc50["labels"]))
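# Expected to be 50 for the full ESC-50 set (40 clips per class), but derived
# from the data so the script also works on a subset of classes.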

# Initialize the AST feature extractor from the pretrained checkpoint
pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(pretrained_model)
model_input_name = feature_extractor.model_input_names[0]
SAMPLING_RATE = feature_extractor.sampling_rate
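# The AST feature extractor turns raw waveforms into log-mel filterbank
# features; model_input_name is typically "input_values" and this checkpoint
# expects 16 kHz audio.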


# Preprocess audio: extract features from the raw waveforms in a batch
def preprocess_audio(batch):
    wavs = [audio["array"] for audio in batch["input_values"]]
    inputs = feature_extractor(
        wavs, sampling_rate=SAMPLING_RATE, return_tensors="ms")
    return {model_input_name: inputs.get(model_input_name),
            "labels": list(batch["labels"])}

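# Optional sanity check of the transform (a sketch; the exact feature shape
# depends on the extractor's defaults, typically 1024 frames x 128 mel bins):
# sample = preprocess_audio({"input_values": [esc50[0]["audio"]],
#                            "labels": [esc50[0]["labels"]]})
# print(sample[model_input_name].shape, sample["labels"])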

dataset = esc50
# Map class names to their integer ids from the ClassLabel feature
label2id = {name: i for i, name in enumerate(dataset.features["labels"].names)}

# Build train and test splits
if "test" not in dataset:
    dataset = dataset.train_test_split(
        test_size=0.2, shuffle=True, seed=0, stratify_by_column="labels")
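# The 80/20 split is stratified by label, so the per-class balance of ESC-50 is
# preserved in both splits; seed=0 makes the split reproducible.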


dataset = dataset.cast_column(
    "audio", Audio(sampling_rate=feature_extractor.sampling_rate))
dataset = dataset.rename_column("audio", "input_values")

dataset["train"].set_transform(preprocess_audio, output_all_columns=False)
dataset["test"].set_transform(preprocess_audio, output_all_columns=False)
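# set_transform applies preprocess_audio lazily: features are computed when
# rows are accessed rather than materialized up front, and only the transform's
# outputs are returned (output_all_columns=False).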

# Load the config and adapt the classification head to ESC-50
config = ASTConfig.from_pretrained(pretrained_model)
config.num_labels = num_labels
config.label2id = label2id
config.id2label = {v: k for k, v in label2id.items()}

model = ASTForAudioClassification.from_pretrained(
    pretrained_model, config=config, ignore_mismatched_sizes=True)
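# The AudioSet-finetuned checkpoint ships with a 527-way classifier head;
# ignore_mismatched_sizes=True lets from_pretrained drop it and randomly
# initialize a new head sized by config.num_labels.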


def convert_mindspore_dataset(hf_dataset, batch_size):
    # Materialize the Hugging Face split so the transform runs once per sample
    data_list = list(hf_dataset)

    def generator():
        for item in data_list:
            yield item[model_input_name], item["labels"]

    # Wrap the generator in a MindSpore GeneratorDataset
    ds = GeneratorDataset(
        source=generator,
        column_names=[model_input_name, "labels"],
        shuffle=False
    )
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
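# Note: list(hf_dataset) runs the on-the-fly transform for every clip and keeps
# the extracted features in memory. That is fine for ESC-50's 2,000 clips, but
# a streaming generator would be preferable for larger corpora.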


# Initialize training arguments
training_args = TrainingArguments(
    output_dir="./checkpoint",
    logging_dir="./logs",
    learning_rate=5e-5,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    eval_steps=1,
    save_steps=1,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_strategy="epoch",
    logging_steps=20,
)
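# Evaluation, checkpointing, and logging all run once per epoch; with
# load_best_model_at_end=True the checkpoint with the best eval accuracy is
# restored after training. The *_steps values only matter for "steps"
# strategies and are effectively unused here.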

train_ms_dataset = convert_mindspore_dataset(
    dataset["train"], training_args.per_device_train_batch_size)
eval_ms_dataset = convert_mindspore_dataset(
    dataset["test"], training_args.per_device_train_batch_size)
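# Both loaders reuse the training batch size; if a different eval batch size is
# wanted, the per_device_eval_batch_size argument of TrainingArguments could be
# passed for the test split instead (assuming mindnlp exposes it as in the
# Hugging Face API).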


def compute_metrics(eval_pred):
    logits = eval_pred.predictions
    labels = eval_pred.label_ids
    predictions = np.argmax(logits, axis=1)
    return {"accuracy": accuracy_score(labels, predictions)}
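# With 50 balanced classes, chance-level accuracy is 2%, which gives a sense of
# scale for the reported numbers.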


# Initialize the trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ms_dataset,
    eval_dataset=eval_ms_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()
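# After training, a final pass over the held-out split could be run explicitly,
# e.g. (assuming mindnlp's Trainer mirrors the Hugging Face evaluate() API):
# metrics = trainer.evaluate(eval_ms_dataset)
# print(metrics)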