
Commit cf18cf7 (initial commit, 0 parents)


69 files changed: +704532 -0 lines changed

.gitignore

+126
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so
pretrained/
tb_history/
logs/
cls_runs/
slurm_logs/

#experiments/
experiments/
experiments_da/
config[0-9]*
*.png
*.pth

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# input data, saved log, checkpoints
data/
input/
saved/
outputs/
datasets/

# editor, os cache directory
.vscode/
.idea/
__MACOSX/

LICENSE

+21
MIT License

Copyright (c) 2020 Yassine Ouali

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md

+183
## Semi-Supervised Semantic Segmentation with Cross-Consistency Training (CCT)

#### [Paper](https://arxiv.org/abs/2003.09005), [Project Page](https://yassouali.github.io/cct_page/)

This repo contains the official implementation of the CVPR 2020 paper: Semi-Supervised Semantic Segmentation with Cross-Consistency Training, which adapts the traditional consistency training framework of semi-supervised learning to semantic segmentation, with an extension to weakly-supervised learning and learning on multiple domains.

<p align="center"><img src="https://yassouali.github.io/cct_page/files/overview.png" width="450"></p>

### Highlights

**(1) Consistency Training for semantic segmentation.** \
We observe that for semantic segmentation, due to the dense nature of the task, the cluster assumption is more easily enforced over the hidden representations than over the inputs.

**(2) Cross-Consistency Training.** \
We propose CCT (Cross-Consistency Training) for semi-supervised semantic segmentation, where we define a number of novel perturbations and show the effectiveness of enforcing consistency over the encoder's outputs rather than over the inputs.

**(3) Using weak labels and pixel-level labels from multiple domains.** \
The proposed method is quite simple and flexible, and can easily be extended to use image-level labels and pixel-level labels from multiple domains.

### Requirements

This repo was tested with Ubuntu 18.04.3 LTS, Python 3.7, PyTorch 1.1.0, and CUDA 10.0, but it should run with any recent PyTorch version >= 1.1.0.

The required packages are `pytorch` and `torchvision`, together with `PIL` and `opencv` for data preprocessing and `tqdm` for showing the training progress, plus some additional modules such as `dominate` to save the results as HTML files. To set up the necessary modules, simply run:

```bash
pip install -r requirements.txt
```

### Dataset

In this repo we use **Pascal VOC**. To obtain it, first download the [original dataset](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar); after extracting the files you will end up with `VOCtrainval_11-May-2012/VOCdevkit/VOC2012`, containing the image sets, the XML annotations for both object detection and segmentation, and the JPEG images.\
The second step is to augment the dataset with the additional annotations provided by [Semantic Contours from Inverse Detectors](http://home.bharathh.info/pubs/pdfs/BharathICCV2011.pdf). Download the rest of the annotations, [SegmentationClassAug](https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0), and add them to `VOCtrainval_11-May-2012/VOCdevkit/VOC2012`. Now we're set; for training, use the path to `VOCtrainval_11-May-2012`.
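Before training, it can help to sanity-check that the dataset is laid out as described. The snippet below is a minimal sketch and not part of the repo; it assumes the standard VOC2012 folder names and that the SegmentationClassAug zip extracts to a folder of the same name:

```python
# Hedged sanity check: verify the augmented VOC layout described above.
from pathlib import Path

voc_root = Path("VOCtrainval_11-May-2012/VOCdevkit/VOC2012")  # adjust to your path
for folder in ["JPEGImages", "ImageSets", "SegmentationClass", "SegmentationClassAug"]:
    print(f"{folder}: {'ok' if (voc_root / folder).is_dir() else 'MISSING'}")
```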

### Training

To train a model, first download PASCAL VOC as detailed above, then set `data_dir` to the dataset path in the config file `configs/config.json` and set the rest of the parameters, such as the number of GPUs, crop size, data augmentation, etc.; you can also change the CCT hyperparameters if you wish (more details below). Then simply run:

```bash
python train.py --config configs/config.json
```

The log files and the `.pth` checkpoints will be saved in `saved/EXP_NAME`. To monitor the training with TensorBoard, run:

```bash
tensorboard --logdir saved
```

To resume training from a saved `.pth` checkpoint:

```bash
python train.py --config configs/config.json --resume saved/CCT/checkpoint.pth
```

**Results**: The validation results will be saved in `saved` as an HTML file, named after the `experim_name` specified in `configs/config.json`.

### Pseudo-labels

If you want to use image-level labels to train the auxiliary decoders as explained in section 3.3 of the paper, first generate the pseudo-labels using the code in `pseudo_labels`:

```bash
cd pseudo_labels
python run.py --voc12_root DATA_PATH
```

`DATA_PATH` must point to the folder containing `JPEGImages` in the Pascal VOC dataset. The results will be saved in `pseudo_labels/result/pseudo_labels` as PNG files; set the flag `use_weak_labels` to True in the config file, and then train the model as detailed above.

### Inference

For inference, we need a pretrained model, the JPG images we'd like to segment, and the config used for training (to load the correct model and other parameters):

```bash
python inference.py --config config.json --model best_model.pth --images images_folder
```

The predictions will be saved as `.png` images in `outputs/`. For Pascal VOC the default palette is:

<p align="center"><img src="https://raw.githubusercontent.com/yassouali/pytorch_segmentation/master/images/colour_scheme.png" width="550"></p>

Here are the flags available for inference:

```
--images       Folder containing the jpg images to segment.
--model        Path to the trained pth model.
--config       The config file used for training the model.
```

### Citation ✏️ 📄

If you find this repo useful for your research, please consider citing the paper as follows:

```
@inproceedings{ouali2020semi,
  title     = {Semi-Supervised Semantic Segmentation with Cross-Consistency Training},
  author    = {Ouali, Yassine and Hudelot, C{\'e}line and Tami, Myriam},
  booktitle = {Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2020},
  note      = {to appear},
  pubstate  = {published},
  tppubtype = {inproceedings}
}
```

For any questions, please contact Yassine Ouali ([email protected]).

#### Config file details ⚙️

Below we detail the CCT parameters that can be controlled in the config file `configs/config.json`; the rest of the parameters are self-explanatory.

```javascript
{
    "name": "CCT",
    "experim_name": "CCT",          // The name the results will take (html and the folder in /saved)
    "n_gpu": 1,                     // Number of GPUs
    "n_labeled_examples": 1000,     // Number of labeled examples (choices are 60, 100, 200,
                                    // 300, 500, 800, 1000, 1464, and the splits are in dataloaders/voc_splits)
    "diff_lrs": true,
    "ramp_up": 0.1,                 // The unsupervised loss will be slowly scaled up in the first 10% of training time
    "unsupervised_w": 30,           // Weighting of the unsupervised loss
    "ignore_index": 255,
    "lr_scheduler": "Poly",
    "use_weak_labels": false,       // If the pseudo-labels were generated, we can use them to train the aux. decoders
    "weakly_loss_w": 0.4,           // Weighting of the weakly-supervised loss
    "pretrained": true,

    "model": {
        "supervised": true,         // Supervised setting (training only on the labeled examples)
        "semi": false,              // Semi-supervised setting
        "supervised_w": 1,          // Weighting of the supervised loss

        "sup_loss": "CE",           // Supervised loss, choices are ["CE", "ABCE"] (CE and ab-CE)
        "un_loss": "MSE",           // Unsupervised loss, choices are ["MSE", "KL"] (MSE and KL-divergence)

        "softmax_temp": 1,
        "aux_constraint": false,    // Pair-wise loss (sup. mat.)
        "aux_constraint_w": 1,
        "confidence_masking": false,// Confidence masking (sup. mat.)
        "confidence_th": 0.5,

        "drop": 6,                  // Number of DropOut decoders
        "drop_rate": 0.5,           // Dropout probability
        "spatial": true,

        "cutout": 6,                // Number of G-Cutout decoders
        "erase": 0.4,               // We drop 40% of the area

        "vat": 2,                   // Number of I-VAT decoders
        "xi": 1e-6,                 // VAT parameters
        "eps": 2.0,

        "context_masking": 2,       // Number of Con-Msk decoders
        "object_masking": 2,        // Number of Obj-Msk decoders
        "feature_drop": 6,          // Number of F-Drop decoders

        "feature_noise": 6,         // Number of F-Noise decoders
        "uniform_range": 0.3        // The range of the noise
    }
}
```
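
The config can also be edited programmatically before launching training. The sketch below is illustrative only: it assumes the shipped `configs/config.json` is plain JSON (without the `//` comments shown above), and the output name `config_semi.json` is arbitrary.

```python
# Illustrative only: switch the sample config to the semi-supervised setting
# with weak labels, then launch training with the modified copy.
import json

with open("configs/config.json") as f:
    cfg = json.load(f)

cfg["model"]["supervised"] = False   # train on labeled + unlabeled examples
cfg["model"]["semi"] = True
cfg["use_weak_labels"] = True        # only after generating the pseudo-labels

with open("configs/config_semi.json", "w") as f:
    json.dump(cfg, f, indent=4)

# then: python train.py --config configs/config_semi.json
```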

#### Acknowledgements

- Pseudo-labels generation is based on Jiwoon Ahn's implementation [irn](https://github.com/jiwoon-ahn/irn).
- Code structure was based on [Pytorch-Template](https://github.com/victoresque/pytorch-template/blob/master/README.m).
- ResNet backbone was downloaded from [torchcv](https://github.com/donnyyou/torchcv).

base/__init__.py

+6
from .base_dataloader import *
from .base_dataset import *
from .base_model import *
from .base_trainer import *

base/base_dataloader.py

+48
import numpy as np
from copy import deepcopy
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler


class BaseDataLoader(DataLoader):
    def __init__(self, dataset, batch_size, shuffle, num_workers, val_split=0.0):
        self.shuffle = shuffle
        self.dataset = dataset
        self.nbr_examples = len(dataset)
        if val_split:
            self.train_sampler, self.val_sampler = self._split_sampler(val_split)
        else:
            self.train_sampler, self.val_sampler = None, None

        self.init_kwargs = {
            'dataset': self.dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'num_workers': num_workers,
            'pin_memory': True
        }
        super(BaseDataLoader, self).__init__(sampler=self.train_sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        # Hold out a `split` fraction of the examples for validation and return
        # (train_sampler, val_sampler) over the remaining / held-out indices.
        if split == 0.0:
            return None, None

        self.shuffle = False  # a sampler and shuffle=True are mutually exclusive

        split_indx = int(self.nbr_examples * split)
        np.random.seed(0)

        indxs = np.arange(self.nbr_examples)
        np.random.shuffle(indxs)
        train_indxs = indxs[split_indx:]
        val_indxs = indxs[:split_indx]
        self.nbr_examples = len(train_indxs)

        train_sampler = SubsetRandomSampler(train_indxs)
        val_sampler = SubsetRandomSampler(val_indxs)
        return train_sampler, val_sampler

    def get_val_loader(self):
        if self.val_sampler is None:
            return None
        return DataLoader(sampler=self.val_sampler, **self.init_kwargs)
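
For reference, a minimal usage sketch of `BaseDataLoader` (not part of the commit; the toy dataset and import path are illustrative, assuming the package layout above):

```python
# Illustrative usage: wrap a toy dataset and hold out 10% for validation.
import torch
from torch.utils.data import TensorDataset
from base.base_dataloader import BaseDataLoader

toy_ds = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 21, (100,)))
train_loader = BaseDataLoader(toy_ds, batch_size=8, shuffle=True,
                              num_workers=0, val_split=0.1)
val_loader = train_loader.get_val_loader()  # None when val_split == 0.0

for images, labels in train_loader:  # 90 training examples via SubsetRandomSampler
    pass
```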
