Commit 9a8bc74

colizz authored and hqucms committed

Incorporate paper revision: include the fine-tuning of ParticleNet

1 parent 83ec97a

7 files changed (+76 -6 lines)

README.md (+4 -4)

@@ -70,22 +70,22 @@ Additional arguments will be passed directly to the `weaver` command, such as `-
 - using PyTorch's DistributedDataParallel:
 
 ```
-NGPUS=8 ./train_JetClass.sh ParT full --batch-size [batch_size_per_gpu] ...
+DDP_NGPUS=4 ./train_JetClass.sh ParT full --batch-size [batch_size_per_gpu] ...
 ```
 
 **To run the training on the QuarkGluon dataset:**
 
 ```
-./train_QuarkGluon.sh [ParT|ParT-FineTune|PN|PFN|PCNN] [kin|kinpid|kinpidplus] ...
+./train_QuarkGluon.sh [ParT|ParT-FineTune|PN|PN-FineTune|PFN|PCNN] [kin|kinpid|kinpidplus] ...
 ```
 
 **To run the training on the TopLandscape dataset:**
 
 ```
-./train_TopLandscape.sh [ParT|ParT-FineTune|PN|PFN|PCNN] [kin] ...
+./train_TopLandscape.sh [ParT|ParT-FineTune|PN|PN-FineTune|PFN|PCNN] [kin] ...
 ```
 
-The argument `ParT-FineTune` will run the fine-tuning using [models pre-trained on the JetClass dataset](models/).
+The argument `ParT-FineTune` or `PN-FineTune` will run the fine-tuning using [models pre-trained on the JetClass dataset](models/).
 
 ## Citations
models/ParticleNet_kin.pt (1.44 MB)
Binary file not shown.

models/ParticleNet_kinpid.pt (1.45 MB)
Binary file not shown.
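
These two checkpoints are the JetClass pre-trained ParticleNet weights that the fine-tuning runs below restore via `--load-model-weights`. Because the wrapper replaces the classification head, only parameters whose names and shapes still match can be restored; the fresh head stays randomly initialized. A minimal sketch of that pattern with toy stand-in modules (illustrative only; not weaver's actual `--load-model-weights` implementation):

```python
import torch.nn as nn

# Toy stand-ins: `Pretrained` mimics the checkpointed network,
# `FineTuned` a wrapper whose classifier was replaced by a fresh `fc_out`.
class Pretrained(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Linear(8, 4)
        self.fc = nn.Linear(4, 10)      # original 10-class JetClass head

class FineTuned(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Linear(8, 4)     # same backbone, restored from the checkpoint
        self.fc_out = nn.Linear(4, 2)   # new head, left randomly initialized

state = Pretrained().state_dict()
missing, unexpected = FineTuned().load_state_dict(state, strict=False)
print(missing)     # ['fc_out.weight', 'fc_out.bias'] -- trained from scratch
print(unexpected)  # ['fc.weight', 'fc.bias'] -- old head, ignored
```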
networks/example_ParticleNet_finetune.py (+61)

@@ -0,0 +1,61 @@
+import torch
+import torch.nn as nn
+from weaver.nn.model.ParticleNet import ParticleNet
+
+
+class ParticleNetWrapper(nn.Module):
+    def __init__(self, **kwargs) -> None:
+        super().__init__()
+
+        in_dim = kwargs['fc_params'][-1][0]
+        num_classes = kwargs['num_classes']
+        self.for_inference = kwargs['for_inference']
+
+        # finetune the last FC layer
+        self.fc_out = nn.Linear(in_dim, num_classes)
+
+        kwargs['for_inference'] = False
+        self.mod = ParticleNet(**kwargs)
+        self.mod.fc = self.mod.fc[:-1]
+
+    def forward(self, points, features, lorentz_vectors, mask):
+        x_cls = self.mod(points, features, mask)
+        output = self.fc_out(x_cls)
+        if self.for_inference:
+            output = torch.softmax(output, dim=1)
+        return output
+
+
+def get_model(data_config, **kwargs):
+    conv_params = [
+        (16, (64, 64, 64)),
+        (16, (128, 128, 128)),
+        (16, (256, 256, 256)),
+    ]
+    fc_params = [(256, 0.1)]
+
+    pf_features_dims = len(data_config.input_dicts['pf_features'])
+    num_classes = len(data_config.label_value)
+    model = ParticleNetWrapper(
+        input_dims=pf_features_dims,
+        num_classes=num_classes,
+        conv_params=kwargs.get('conv_params', conv_params),
+        fc_params=kwargs.get('fc_params', fc_params),
+        use_fusion=kwargs.get('use_fusion', False),
+        use_fts_bn=kwargs.get('use_fts_bn', True),
+        use_counts=kwargs.get('use_counts', True),
+        for_inference=kwargs.get('for_inference', False)
+    )
+
+    model_info = {
+        'input_names': list(data_config.input_names),
+        'input_shapes': {k: ((1,) + s[1:]) for k, s in data_config.input_shapes.items()},
+        'output_names': ['softmax'],
+        'dynamic_axes': {**{k: {0: 'N', 2: 'n_' + k.split('_')[0]} for k in data_config.input_names}, **{'softmax': {0: 'N'}}},
+    }
+
+    return model, model_info
+
+
+def get_loss(data_config, **kwargs):
+    return torch.nn.CrossEntropyLoss()
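
The wrapper reuses the pre-trained ParticleNet body and swaps out only its classifier: slicing the final `nn.Sequential` (`self.mod.fc = self.mod.fc[:-1]`) drops the old JetClass head, and a fresh `fc_out` maps the remaining 256-dimensional features to the downstream classes. A standalone sketch of that slicing pattern, with placeholder layers standing in for the real network:

```python
import torch
import torch.nn as nn

# Placeholder for ParticleNet's final MLP: a Sequential whose last
# module is the original classification layer from pre-training.
fc = nn.Sequential(
    nn.Linear(256, 256), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(256, 10),        # pre-trained 10-class head
)

fc = fc[:-1]                   # slicing a Sequential drops the old head
fc_out = nn.Linear(256, 2)     # fresh head for the downstream task

x = torch.randn(5, 256)        # batch of 5 feature vectors
print(fc_out(fc(x)).shape)     # torch.Size([5, 2])
```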

networks/example_ParticleTransformer_finetune.py (-2)

@@ -39,8 +39,6 @@ def forward(self, points, features, lorentz_vectors, mask):
         output = torch.softmax(output, dim=1)
         return output
 
-        # return self.mod(features, v=lorentz_vectors, mask=mask)[:, [-2, 0]]
-
 
 def get_model(data_config, **kwargs):
 
train_QuarkGluon.sh (+7)

@@ -26,6 +26,10 @@ elif [[ "$model" == "ParT-FineTune" ]]; then
 elif [[ "$model" == "PN" ]]; then
     modelopts="networks/example_ParticleNet.py"
     lr="1e-2"
+elif [[ "$model" == "PN-FineTune" ]]; then
+    modelopts="networks/example_ParticleNet_finetune.py"
+    lr="1e-3"
+    extraopts="--optimizer-option lr_mult (\"fc_out.*\",50) --lr-scheduler none"
 elif [[ "$model" == "PFN" ]]; then
     modelopts="networks/example_PFN.py"
     lr="2e-2"
@@ -55,6 +59,9 @@ fi
 if [[ "$model" == "ParT-FineTune" ]]; then
     modelopts+=" --load-model-weights models/ParT_${pretrain_type}.pt"
 fi
+if [[ "$model" == "PN-FineTune" ]]; then
+    modelopts+=" --load-model-weights models/ParticleNet_${pretrain_type}.pt"
+fi
 
 weaver \
     --data-train "${DATADIR}/train_file_*.parquet" \
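
In the `PN-FineTune` branch above, `--optimizer-option lr_mult ("fc_out.*",50)` asks weaver to scale the learning rate of every parameter whose name matches the regex, so the newly initialized head trains at 5e-2 while the pre-trained backbone keeps the base 1e-3. A minimal sketch of the same effect using plain PyTorch parameter groups (illustrative; not weaver's actual option parsing, and `Toy` is a made-up module):

```python
import re
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 8)   # stands in for the pre-trained body
        self.fc_out = nn.Linear(8, 2)      # stands in for the new head

model = Toy()
base_lr, mult = 1e-3, 50
pattern = re.compile(r"fc_out.*")

# Split parameters by name: regex matches go into the fast group.
head, body = [], []
for name, param in model.named_parameters():
    (head if pattern.match(name) else body).append(param)

optimizer = torch.optim.Adam([
    {"params": body, "lr": base_lr},         # backbone: 1e-3
    {"params": head, "lr": base_lr * mult},  # head: 5e-2
])
print([g["lr"] for g in optimizer.param_groups])  # [0.001, 0.05]
```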

train_TopLandscape.sh (+4)

@@ -25,6 +25,10 @@ elif [[ "$model" == "ParT-FineTune" ]]; then
 elif [[ "$model" == "PN" ]]; then
     modelopts="networks/example_ParticleNet.py"
     lr="1e-2"
+elif [[ "$model" == "PN-FineTune" ]]; then
+    modelopts="networks/example_ParticleNet_finetune.py"
+    lr="1e-3"
+    extraopts="--optimizer-option lr_mult (\"fc_out.*\",50) --lr-scheduler none --load-model-weights models/ParticleNet_kin.pt"
 elif [[ "$model" == "PFN" ]]; then
     modelopts="networks/example_PFN.py"
     lr="2e-2"
