# Autoformer MindNLP Fine-tuning

- Autoformer fine-tuning task link: [【开源实习】autoformer模型微调 · Issue #IAUOTL · MindSpore/community - Gitee.com](https://gitee.com/mindspore/community/issues/IAUOTL)
- Fine-tunes the huggingface/autoformer-tourism-monthly baseline weights on the [monash_tsf/tourism_monthly] dataset (see the loading sketch below)
- base model: [huggingface/autoformer-tourism-monthly · Hugging Face](https://huggingface.co/huggingface/autoformer-tourism-monthly)
- dataset: [Monash-University/monash_tsf · Datasets at Hugging Face](https://huggingface.co/datasets/Monash-University/monash_tsf)

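For orientation, the sketch below shows one way the base checkpoint and dataset could be loaded; it is not the exact fine-tuning script used here. The `mindnlp.transformers.AutoformerForPrediction` import path is an assumption based on mindnlp mirroring the Hugging Face Transformers API (with the PyTorch stack listed below, the same class comes from `transformers`), and the script-based `monash_tsf` loader may need an older `datasets` release or `trust_remote_code=True`.

```python
# Minimal loading sketch; AutoformerForPrediction via mindnlp is an assumption,
# based on mindnlp mirroring the Hugging Face Transformers API.
from datasets import load_dataset
from mindnlp.transformers import AutoformerForPrediction

# Monash tourism_monthly: monthly series with train/validation/test splits
dataset = load_dataset("Monash-University/monash_tsf", "tourism_monthly")

# Start fine-tuning from the pretrained tourism-monthly checkpoint
model = AutoformerForPrediction.from_pretrained("huggingface/autoformer-tourism-monthly")
print(model.config.prediction_length, model.config.context_length)
```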
------

# Requirements
## PyTorch

- GPU: RTX 4070 Ti 12G
- CUDA: 11.8
- Python version: 3.10
- torch version: 2.5.0
- transformers version: 4.47.0
- accelerate: 0.27.0
- gluonts: 0.14.0
- datasets: 2.16.0
- evaluate: 0.4.0
- numpy: 1.26.4
- pandas: 2.1.0
- scipy: 1.11.0

## MindSpore (OpenI community, Ascend 910B compute)
- Ascend: 910B
- python: 3.9
- mindspore: 2.5.0
- mindnlp: 0.4.1
- gluonts: 0.16.0
- datasets: 3.5.0
- evaluate: 0.4.3
- numpy: 1.26.4
- pandas: 2.2.3
- scipy: 1.13.1

---

## Modifications

### Ascend

In the source file **modeling_autoformer.py**, change **padding_mode** = **circular** to **padding_mode** = **replicate**.
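For reference, the two padding modes differ as follows (NumPy equivalents, shown only to illustrate the semantics of the swap): `circular` wraps values around from the opposite end of the sequence, while `replicate` repeats the boundary values.

```python
import numpy as np

x = np.array([1, 2, 3, 4, 5])
print(np.pad(x, 2, mode="wrap"))  # circular-style padding: [4 5 1 2 3 4 5 1 2]
print(np.pad(x, 2, mode="edge"))  # replicate-style padding: [1 1 1 2 3 4 5 5 5]
```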

### CPU/GPU

The `roll` operation at line 922 of **modeling_autoformer.py** is not implemented on GPU and CPU.

Change the statement at line 922 of the source:

```python
value_states_roll_delay = value_states.roll(shifts=-int(top_k_delays_index[i]), dims=1)
```

to

```python
value_states_roll_delay = custom_roll(
    value_states,
    shifts=-int(top_k_delays_index[i].asnumpy()),  # convert to a Python integer
    dim=1
)
```

and add the following custom function above it as a replacement for `roll`:

```python
def custom_roll(tensor, shifts, dim):
    """
    Custom implementation of cyclic shift along specified dimension

    Args:
        tensor: Input tensor to be shifted
        shifts: Number of positions to shift
            (positive = right shift, negative = left shift)
        dim: Dimension index along which to perform shift

    Returns:
        Tensor with elements cyclically shifted along specified dimension
    """
    # Handle cases where shifts exceed dimension length
    dim_size = tensor.shape[dim]
    shifts = shifts % dim_size  # Ensure shifts are within valid range

    if shifts == 0:
        return tensor

    # Split tensor into two parts and swap their order
    if shifts > 0:
        # Right shift: keep last 'shifts' elements and move to front
        part1 = ops.narrow(tensor, dim, 0, dim_size - shifts)
        part2 = ops.narrow(tensor, dim, dim_size - shifts, shifts)
    else:
        # Left shift: keep first '|shifts|' elements and move to end
        shifts = abs(shifts)
        part1 = ops.narrow(tensor, dim, shifts, dim_size - shifts)
        part2 = ops.narrow(tensor, dim, 0, shifts)

    # Concatenate the two parts in swapped order
    return ops.cat((part2, part1), dim)
```
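As a standalone sanity check, `custom_roll` can be compared against `numpy.roll`; the sketch below assumes the function above has been pasted into the same script and that `ops` comes from `mindspore` (inside **modeling_autoformer.py** the module's existing `ops` import is used instead).

```python
import numpy as np
import mindspore as ms
from mindspore import ops  # provides ops.narrow / ops.cat used by custom_roll

x = ms.Tensor(np.arange(12).reshape(3, 4), ms.float32)

# A negative shift along dim=1 should match numpy.roll with the same arguments
out = custom_roll(x, shifts=-2, dim=1)
expected = np.roll(x.asnumpy(), shift=-2, axis=1)
assert np.allclose(out.asnumpy(), expected)
print(out)
```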

----

## Fine-tuning Results

### MindSpore

| Epoch | Loss              |
| ----- | ----------------- |
| 0     | 7.546689510345459 |
| 1     | 7.772482395172119 |
| 2     | 7.14789342880249  |
| 3     | 7.49253511428833  |
| 4     | 7.337801456451416 |
| 5     | 6.960692882537842 |
| 6     | 8.312647819519043 |
| 7     | 6.90599250793457  |
| 8     | 7.212374210357666 |
| 9     | 7.506921291351318 |

------

### PyTorch

| Epoch | Loss               |
| ----- | ------------------ |
| 0     | 7.412668228149414  |
| 1     | 7.8263068199157715 |
| 2     | 7.839258670806885  |
| 3     | 8.043777465820312  |
| 4     | 8.08508586883545   |
| 5     | 7.503101825714111  |
| 6     | 7.824302673339844  |
| 7     | 7.399034023284912  |
| 8     | 7.122222900390625  |
| 9     | 7.612663269042969  |