Skip to content

Commit 58479bc

Browse files
committed
resolved refactoring conflict
1 parent 5e760c5 commit 58479bc

File tree

5 files changed

+61
-13
lines changed

5 files changed

+61
-13
lines changed

README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,26 @@ Hyperparameter | Value
4444
It seems that it can be trained more.
4545
However, if train more, the reconstruction loss increase.
4646
Based on this weird behavior, I recommand 15 epochs to train (max 20).
47+
48+
## Dependency
49+
50+
Refer [requirements](https://github.com/ktaebum/AttentionedDeepPaint/tree/master/requirements.txt)
51+
52+
Install
53+
1. Pytorch (>= 0.4.1)
54+
2. Torchvision (>= 0.2.1)
55+
based on your python version, os, cuda version etc...
56+
57+
## Usage
58+
59+
### Download dataset
60+
61+
1. Create **data** folder
62+
2. go to [link](https://www.kaggle.com/ktaebum/animesketchcolorpair) and download
63+
3. unzip in **data** folder
64+
65+
### Train
66+
67+
`
68+
$ ./train.sh
69+
`

requirements.txt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
certifi==2018.10.15
2+
cloudpickle==0.6.1
3+
cycler==0.10.0
4+
dask==1.0.0
5+
decorator==4.3.0
6+
kiwisolver==1.0.1
7+
matplotlib==3.0.2
8+
networkx==2.2
9+
numpy==1.15.4
10+
Pillow==5.3.0
11+
pyparsing==2.3.0
12+
python-dateutil==2.7.5
13+
PyWavelets==1.0.1
14+
scikit-image==0.14.1
15+
scipy==1.1.0
16+
six==1.11.0
17+
toolz==0.9.0
18+
torch==0.4.1
19+
torchvision==0.2.1

train.sh

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
1-
python train.py --learning-rate 0.0002 --beta1 0.5 --verbose \
2-
--batch-size 4 --save-every 5 --lambd 100 --model deepunet \
3-
--sample 4 --no-mse --norm batch --num-epochs 20 --print-every 300 \
4-
--train
1+
python train.py --verbose --save-every 5 --model deepunet \
2+
--sample 4 --no-mse --num-epochs 20 --print-every 300 \

trainer/deepunet.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@ def __init__(self, *args):
3030
# log file
3131
if self.args.train:
3232
ctime = time.ctime().split()
33+
3334
log_path = './log'
35+
if not os.path.exists(log_path):
36+
os.mkdir(log_path)
37+
3438
log_dir = os.path.join(
3539
log_path,
3640
'%s_%s_%s_%s' % (ctime[-1], ctime[1], ctime[2], ctime[3]))
@@ -39,8 +43,11 @@ def __init__(self, *args):
3943
f.write(str(args))
4044
self.log_file = open(os.path.join(log_dir, 'loss.txt'), 'w')
4145

46+
self.save_path = './data/result'
47+
if not os.path.exists(self.save_path):
48+
os.mkdir(self.save_path)
49+
4250
# build model
43-
self.resolution = self.args.resolution
4451
self.generator = DeepUNetPaintGenerator().to(self.device)
4552
self.discriminator = PatchGAN(sigmoid=self.args.no_mse).to(self.device)
4653

@@ -215,19 +222,19 @@ def validate(self, dataset, epoch, samples=3):
215222
(0, 0))
216223
color_result.paste(
217224
color2.crop((0, 0, self.resolution, self.resolution // 4)),
218-
(0, 512 // 4))
225+
(0, self.resolution // 4))
219226
color_result.paste(
220227
color3.crop((0, 0, self.resolution, self.resolution // 4)),
221-
(0, 512 // 4 * 2))
228+
(0, self.resolution // 4 * 2))
222229
color_result.paste(
223230
color4.crop((0, 0, self.resolution, self.resolution // 4)),
224-
(0, 512 // 4 * 3))
231+
(0, self.resolution // 4 * 3))
225232

226233
sub_result.paste(imageA, (0, 0))
227-
sub_result.paste(styleB, (512, 0))
228-
sub_result.paste(fakeB, (2 * 512, 0))
229-
sub_result.paste(imageB, (3 * 512, 0))
230-
sub_result.paste(color_result, (4 * 512, 0))
234+
sub_result.paste(styleB, (self.resolution, 0))
235+
sub_result.paste(fakeB, (2 * self.resolution, 0))
236+
sub_result.paste(imageB, (3 * self.resolution, 0))
237+
sub_result.paste(color_result, (4 * self.resolution, 0))
231238

232239
result.paste(sub_result, (0, 0 + self.resolution * i))
233240

@@ -243,7 +250,7 @@ def validate(self, dataset, epoch, samples=3):
243250
save_image(
244251
result,
245252
'deepunetpaint_%03d_%02d' % (epoch, j),
246-
'./data/pair_niko/result',
253+
self.save_path,
247254
)
248255

249256
def test(self):

trainer/trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(self, args, data_loader, device):
1616
self.args = args
1717
self.data_loader = data_loader
1818
self.device = device
19+
self.resolution = 512
1920

2021
def train(self):
2122
raise NotImplementedError

0 commit comments

Comments
 (0)