Skip to content

Commit 4f12333

Browse files
author
wbw520
committed
update read me
1 parent 76584e5 commit 4f12333

File tree

3 files changed

+25
-0
lines changed

3 files changed

+25
-0
lines changed

README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,22 @@ Then see the generated concepts by:
4141
python vis_retri.py --num_classes 50 --num_cpt 20 --base_model resnet18 --index 300 --top_sample 20 --dataset CUB200
4242
```
4343

44+
#### Demo weight
45+
We provide some trained weight for demo. Download and set them to folder "saved_model/"
46+
![MNIST](https://drive.google.com/file/d/1wQtsi2jTEoG1k877XNG9cB4njWztlkzn/view?usp=sharing)
47+
![CUB 50 classes](https://drive.google.com/file/d/1XIcTPCCb3uXFOOb_PrNXjJHfMrh--wsy/view?usp=sharing)
48+
![ImageNet 50 classes](https://drive.google.com/file/d/1VSAlC6QftQDUzIAE8WJ6D0tNEaa1niLR/view?usp=sharing)
49+
50+
Run them as following:
51+
```
52+
MNIST
53+
python vis_recon.py --num_classes 10 --num_cpt 20 --index 0 --top_sample 20 --top_sample 20 --deactivate -1
54+
55+
CUB
56+
python process.py
57+
python vis_retri.py --num_classes 50 --num_cpt 20 --base_model resnet18 --index 300 --top_sample 20 --dataset CUB200
58+
59+
ImageNet
60+
python process.py
61+
python vis_retri.py --num_classes 50 --num_cpt 20 --base_model resnet18 --index 300 --top_sample 20 --dataset CUB200
62+
```

main_recon.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
import os
1010

1111

12+
os.makedirs('saved_model/', exist_ok=True)
13+
14+
1215
def main():
1316
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
1417
trainset = datasets.MNIST('../data', train=True, download=True, transform=transform)

main_retri.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from utils.tools import fix_parameter, print_param
99

1010

11+
os.makedirs('saved_model/', exist_ok=True)
12+
13+
1114
def main():
1215
model = MainModel(args)
1316
device = torch.device(args.device)

0 commit comments

Comments
 (0)