
Commit 1235348

Add StarGAN implementation (pclubiitk#24)

* Added the StarGAN implementation.
* Write README.md
* Modified README.md
* Rectified spelling mistake.
* Made required changes

1 parent ddbb96f commit 1235348

File tree

15 files changed: +564 −0 lines changed

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@

```
.DS_Store
**/.DS_Store
```
README.md

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@

# PyTorch implementation of StarGAN

## Usage

```bash
> python main.py --arguments
```

The arguments are as follows:

```bash
usage: main.py [-h] [--directory DIRECTORY] [--epochs EPOCHS]
               [--batch_size BATCH_SIZE] [--gen_lr GEN_LR] [--dis_lr DIS_LR]
               [--d_times D_TIMES] [--lam_cls LAM_CLS]
               [--lam_recomb LAM_RECOMB] [--image_dim IMAGE_DIM]
               [--download DOWNLOAD] [--eval_idx EVAL_IDX]
               [--attrs ATTRS [ATTRS ...]]

optional arguments:
  -h, --help            show this help message and exit
  --directory DIRECTORY
                        directory of dataset
  --epochs EPOCHS       total number of epochs you want to run. Default: 20
  --batch_size BATCH_SIZE
                        Batch size for dataset
  --gen_lr GEN_LR       generator learning rate
  --dis_lr DIS_LR       discriminator learning rate
  --d_times D_TIMES     No of times you want D to update before updating G
  --lam_cls LAM_CLS     Value of lambda for domain classification loss
  --lam_recomb LAM_RECOMB
                        Value of lambda for image recombination loss
  --image_dim IMAGE_DIM
                        Image dimension you want to resize to.
  --download DOWNLOAD   Argument to download dataset. Set to True.
  --eval_idx EVAL_IDX   Index of image you want to run evaluation on.
  --attrs ATTRS [ATTRS ...], --list ATTRS [ATTRS ...]
                        selected attributes for the CelebA dataset
```
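For example, a typical training run might look like the following (the flag names are from the usage above; the specific values are illustrative, not taken from the commit):

```bash
python main.py --directory ./data --download True --epochs 20 --batch_size 16 \
               --image_dim 128 --attrs Black_Hair Blond_Hair Brown_Hair Male Young
```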
## Contributed by:
[Som Tambe](https://github.com/SomTambe)

## References
**StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation**, Yunjey Choi, Minje Choi, Munyoung Kim, Jung-Woo Ha, Sunghun Kim, Jaegul Choo

**CVPR 2018** / [ArXiv](https://arxiv.org/abs/1711.09020)
## Summary

## Introduction
StarGAN is a versatile example of how Generative Adversarial Networks (Goodfellow et al.) can learn cross-domain relations and perform image-to-image translation with a single discriminator and a single generator.

## How does it do that?
Let us define the following terms before going further.

**attribute** - A particular feature inherent in an image. Example: hair color, age, gender.

**attribute value** - A value of an **attribute**. Example: if the chosen attribute is hair color, its values can be blonde, black, white, or grey.

**domain** - A set of images sharing the same attribute value. Example: images of women form one domain; images of men form another.

For our experiments, we use the CelebA dataset ([Liu et al.](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)). It contains more than 200K images, each labelled with 40 attributes.
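As a concrete illustration of these terms: each CelebA image carries binary attribute annotations, and a domain label for StarGAN is simply the {0, 1} vector over the attributes we select. The attribute names below are real CelebA labels; the annotation values are made up:

```python
# Each image's domain label is the binary vector over the selected attributes.
selected = ["Black_Hair", "Blond_Hair", "Brown_Hair", "Male", "Young"]

# Hypothetical annotations for one image: a black-haired young man.
annotations = {"Black_Hair": 1, "Blond_Hair": 0, "Brown_Hair": 0, "Male": 1, "Young": 1}

label = [annotations[a] for a in selected]
print(label)  # [1, 0, 0, 1, 1]
```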
The existing models were quite inefficient: to learn mappings among all **K** domains, <sup>K</sup>P<sub>2</sub> generators were required, one for every ordered pair of domains. Moreover, each generator could not make full use of the data, since it only ever learned from 2 of the **K** domains at a time.

StarGAN solves that problem by introducing a single generator which learns mappings between all domains. The generator takes two inputs: the **image** and the **target domain label**.

<p style="text-align: center;"> <b>G(x, c) → y </b></p>

<i>where y is the generated image, x is the original image, and c is the target label.</i>

Here we use an auxiliary classifier as our discriminator, which outputs both the real/fake source prediction **D<sub>src</sub>** and the domain labels of the input image **D<sub>cls</sub>**.

<p style="text-align: center;"><b>D</b> : <b>x</b> → {<b>D<sub>src</sub>(x)</b>, <b>D<sub>cls</sub>(x)</b>}</p>
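A minimal PyTorch sketch of this setup, following the paper's scheme of feeding c to G by concatenating it to the image as constant feature maps; the layer sizes in the toy discriminator are illustrative, not the commit's actual architecture:

```python
import torch
import torch.nn as nn

def condition_on_label(x, c):
    """Tile the target label c over space and concatenate it to the image x.

    x: images, shape (N, 3, H, W); c: float {0, 1} labels, shape (N, n_attrs).
    This depth-wise concatenation is how the StarGAN paper conditions G on c.
    """
    c_maps = c.view(c.size(0), c.size(1), 1, 1).expand(-1, -1, x.size(2), x.size(3))
    return torch.cat([x, c_maps], dim=1)

class TinyDiscriminator(nn.Module):
    """Toy auxiliary-classifier discriminator with a real/fake head (D_src)
    and a domain-classification head (D_cls)."""

    def __init__(self, n_attrs, image_dim=128):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.LeakyReLU(0.01),
            nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.LeakyReLU(0.01),
        )
        self.src = nn.Conv2d(128, 1, 3, padding=1)           # D_src: real/fake logits
        self.cls = nn.Conv2d(128, n_attrs, image_dim // 4)   # D_cls: label logits

    def forward(self, x):
        h = self.backbone(x)
        return self.src(h), self.cls(h).view(x.size(0), -1)
```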
## Loss and Objective Functions

There are three losses.

### Adversarial Loss
![adversarial](assets/adversarial.png)

### Domain Classification Loss
**Real Domain Classification Loss**

![real domain](assets/realdomain.png)

**Fake Domain Classification Loss**

![fake domain](assets/fakedomain.png)

### Image Reconstruction Loss
![reconstruction](assets/reconst.png)

### Final Objective Function
![finalobj](assets/finalobj.png)
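A hedged sketch of how these terms combine into the generator's objective, written with the standard log-likelihood (binary cross-entropy) form of the adversarial term. `lam_cls` and `lam_recomb` correspond to the `--lam_cls` and `--lam_recomb` flags; `G` and `D` are assumed callables (c_org, c_trg are float {0, 1} label tensors), not the commit's modules:

```python
import torch
import torch.nn.functional as F

def generator_objective(G, D, x, c_org, c_trg, lam_cls, lam_recomb):
    """Illustrative StarGAN generator objective (not the commit's code)."""
    x_fake = G(x, c_trg)
    src, cls = D(x_fake)

    # Adversarial term: fool D into predicting "real" on generated images.
    adv = F.binary_cross_entropy_with_logits(src, torch.ones_like(src))

    # Fake domain classification loss: G's output should carry the target labels.
    dom = F.binary_cross_entropy_with_logits(cls, c_trg)

    # Image reconstruction (cycle) loss: translating back with the original
    # labels should recover the input image, measured with the L1 norm.
    rec = torch.mean(torch.abs(x - G(x_fake, c_org)))

    return adv + lam_cls * dom + lam_recomb * rec
```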
## Training
The training procedure is elaborated in the following figure.

![training](assets/training.png)
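In outline, the discriminator is updated `--d_times` times for every generator update. A minimal sketch of that schedule, assuming `loader`, the two optimizers, and a `discriminator_objective` helper analogous to the generator objective above all exist:

```python
for step, (x, c_org) in enumerate(loader):
    # Random target domains: shuffle the real labels within the batch.
    c_trg = c_org[torch.randperm(c_org.size(0))]

    # Discriminator step.
    d_opt.zero_grad()
    discriminator_objective(G, D, x, c_org, c_trg, lam_cls).backward()
    d_opt.step()

    # Generator step, once every d_times discriminator updates.
    if (step + 1) % d_times == 0:
        g_opt.zero_grad()
        generator_objective(G, D, x, c_org, c_trg, lam_cls, lam_recomb).backward()
        g_opt.step()
```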
## Results
I selected a random image from the dataset.

![original](assets/original.png)

[Black Hair, Male]

A single epoch took about 9 hours on a Tesla K80 GPU, so I trained for roughly 1500 of the ~12000 iterations that make up one epoch.

This was the translation to [Brown_Hair, Male]:

![gen](assets/rendered.png)

The generator seems to have recognised the spatial features. Since full training has not been done, we cannot infer much beyond the fact that the generator has been learning features.

## Losses

![losses](assets/losses.png)

Training was continued for 3000 iterations, but the machine crashed, erasing any progress that had been made.
Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
```python
from functools import partial
import torch
import os
from PIL import Image
from torchvision.datasets.vision import VisionDataset
from torchvision.datasets.utils import check_integrity, verify_str_arg, download_file_from_google_drive

# Custom dataset class created to output tensors of selected attributes only.

class CelebA(VisionDataset):
    """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        attributes (list): List of attributes that you want from all 40 attributes.
        split (string): One of {'train', 'valid', 'test', 'all'}.
            Accordingly dataset is selected.
        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
            The targets represent:
                ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
                ``identity`` (int): label for each person (data points with the same identity are the same person)
                ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
                ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
                    righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
            Defaults to ``attr``. If empty, ``None`` will be returned as target.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    base_folder = "celeba"
    # There currently does not appear to be an easy way to extract 7z in Python (without introducing additional
    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
    # right now.
    file_list = [
        # File ID                              MD5 Hash                            Filename
        ("15GLCHkvetqYVbg4d1gWZhD9Pk7RDNa7T", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
        ("16ZFAm82Es_MiQ51E81r69Qbh7KEH8Dfu", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
        ("1LuFPVoCSub0Ewyaf3QzNpmtRTDp9Tml8", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
        ("10u_vSZfCadbWKAhQyNDuyuhF1tsCEr2B", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
        ("1VcOp1jra9oxLDmUHdjTqkifMqMkDnQEx", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
        ("1kiE5zyobrmnw49R-ca6EfHbRNWxVq33K", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
    ]

    def __init__(self, root, attributes, split="train", target_type="attr", transform=None,
                 target_transform=None, download=False):
        import pandas
        super(CelebA, self).__init__(root, transform=transform,
                                     target_transform=target_transform)
        self.split = split
        self.attributes = attributes
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]

        if not self.target_type and self.target_transform is not None:
            raise RuntimeError('target_transform is specified but target_type is empty')

        if download:
            self.download()

        # if not self._check_integrity():
        #     raise RuntimeError('Dataset not found or corrupted.' +
        #                        ' You can use download=True to download it')

        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        split = split_map[verify_str_arg(split.lower(), "split",
                                         ("train", "valid", "test", "all"))]

        fn = partial(os.path.join, self.root, self.base_folder)
        splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
        identity = pandas.read_csv(fn("identity_CelebA.txt"), delim_whitespace=True, header=None, index_col=0)
        bbox = pandas.read_csv(fn("list_bbox_celeba.txt"), delim_whitespace=True, header=1, index_col=0)
        landmarks_align = pandas.read_csv(fn("list_landmarks_align_celeba.txt"), delim_whitespace=True, header=1)
        attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1)
        # Keep only the columns for the user-selected attributes.
        attr = attr[self.attributes]

        mask = slice(None) if split is None else (splits[1] == split)

        self.filename = splits[mask].index.values
        self.identity = torch.as_tensor(identity[mask].values)
        self.bbox = torch.as_tensor(bbox[mask].values)
        self.landmarks_align = torch.as_tensor(landmarks_align[mask].values)
        self.attr = torch.as_tensor(attr[mask].values)
        self.attr = (self.attr + 1) // 2  # map from {-1, 1} to {0, 1}
        self.attr_names = list(attr.columns)

    def _check_integrity(self):
        for (_, md5, filename) in self.file_list:
            fpath = os.path.join(self.root, self.base_folder, filename)
            _, ext = os.path.splitext(filename)
            # Allow original archive to be deleted (zip and 7z)
            # Only need the extracted images
            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
                return False

        # Should check a hash of the images
        return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))

    def download(self):
        import zipfile

        for (file_id, md5, filename) in self.file_list:
            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename)

        with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
            f.extractall(os.path.join(self.root, self.base_folder))

    def __getitem__(self, index):
        X = Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        target = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            elif t == "identity":
                target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                # TODO: refactor with utils.verify_str_arg
                raise ValueError("Target type \"{}\" is not recognized.".format(t))

        if self.transform is not None:
            X = self.transform(X)

        if target:
            target = tuple(target) if len(target) > 1 else target[0]

            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            target = None

        return X, target

    def __len__(self):
        return len(self.attr)

    def extra_repr(self):
        lines = ["Target type: {target_type}", "Split: {split}"]
        return '\n'.join(lines).format(**self.__dict__)
```
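A brief usage sketch of this dataset class; the module path and transform choices are assumptions, not part of the commit:

```python
import torchvision.transforms as T
# from data import CelebA  # hypothetical module name; import from wherever this file lives

transform = T.Compose([T.Resize(128), T.CenterCrop(128), T.ToTensor()])
celeba = CelebA(root="./data", attributes=["Black_Hair", "Brown_Hair", "Male"],
                split="train", transform=transform, download=True)

img, attrs = celeba[0]  # attrs is a {0, 1} tensor over the selected attributes
```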
