Skip to content
This repository was archived by the owner on Oct 19, 2023. It is now read-only.

Commit 4bc72e7

Browse files
committed
update gan example code
1 parent 4eec99c commit 4bc72e7

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

09-경쟁하며_학습하는_GAN/gan.ipynb

+10-4
Original file line numberDiff line numberDiff line change
@@ -72,17 +72,20 @@
7272
"outputs": [],
7373
"source": [
7474
"# Fashion MNIST 데이터셋\n",
75-
"trainset = datasets.FashionMNIST('./.data',\n",
75+
"trainset = datasets.FashionMNIST(\n",
76+
" './.data',\n",
7677
" train=True,\n",
7778
" download=True,\n",
7879
" transform=transforms.Compose([\n",
7980
" transforms.ToTensor(),\n",
8081
" transforms.Normalize((0.5,), (0.5,))\n",
81-
" ]))\n",
82+
" ])\n",
83+
")\n",
8284
"train_loader = torch.utils.data.DataLoader(\n",
8385
" dataset = trainset,\n",
8486
" batch_size = BATCH_SIZE,\n",
85-
" shuffle = True)"
87+
" shuffle = True\n",
88+
")"
8689
]
8790
},
8891
{
@@ -722,9 +725,12 @@
722725
" d_loss_fake = criterion(outputs, fake_labels)\n",
723726
" fake_score = outputs\n",
724727
" \n",
725-
" # 진짜와 가짜 이미지를 갖고 낸 오차를 더해서 판별자의 오차를 계산 후 학습\n",
728+
" # 진짜와 가짜 이미지를 갖고 낸 오차를 더해서 판별자의 오차 계산\n",
726729
" d_loss = d_loss_real + d_loss_fake\n",
730+
"\n",
731+
" # 역전파 알고리즘으로 판별자 모델의 학습을 진행\n",
727732
" d_optimizer.zero_grad()\n",
733+
" g_optimizer.zero_grad()\n",
728734
" d_loss.backward()\n",
729735
" d_optimizer.step()\n",
730736
" \n",

09-경쟁하며_학습하는_GAN/gan.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,20 @@
3030
# 학습에 필요한 데이터셋을 로딩합니다.
3131

3232
# Fashion MNIST 데이터셋
33-
trainset = datasets.FashionMNIST('./.data',
33+
trainset = datasets.FashionMNIST(
34+
'./.data',
3435
train=True,
3536
download=True,
3637
transform=transforms.Compose([
3738
transforms.ToTensor(),
3839
transforms.Normalize((0.5,), (0.5,))
39-
]))
40+
])
41+
)
4042
train_loader = torch.utils.data.DataLoader(
4143
dataset = trainset,
4244
batch_size = BATCH_SIZE,
43-
shuffle = True)
45+
shuffle = True
46+
)
4447

4548

4649
# 생성자는 64차원의 랜덤한 텐서를 입력받아 이에 행렬곱(Linear)과 활성화 함수(ReLU, Tanh) 연산을 실행합니다. 생성자의 결과값은 784차원, 즉 Fashion MNIST 속의 이미지와 같은 차원의 텐서입니다.
@@ -105,9 +108,12 @@
105108
d_loss_fake = criterion(outputs, fake_labels)
106109
fake_score = outputs
107110

108-
# 진짜와 가짜 이미지를 갖고 낸 오차를 더해서 판별자의 오차를 계산 후 학습
111+
# 진짜와 가짜 이미지를 갖고 낸 오차를 더해서 판별자의 오차 계산
109112
d_loss = d_loss_real + d_loss_fake
113+
114+
# 역전파 알고리즘으로 판별자 모델의 학습을 진행
110115
d_optimizer.zero_grad()
116+
g_optimizer.zero_grad()
111117
d_loss.backward()
112118
d_optimizer.step()
113119

0 commit comments

Comments
 (0)