Skip to content

Commit fc30bb1

Browse files
First commit
1 parent 9a5d100 commit fc30bb1

File tree

1 file changed

+12
-39
lines changed

1 file changed

+12
-39
lines changed

docs/source/examples/plot_image_restoration.py

+12-39
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
# Load ProxTorch Logo as jpg then convert to grayscale numpy array
3131
proxtorch_logo = plt.imread("../proxtorch-logo.jpg")
32+
# Downsample to 64x64
33+
proxtorch_logo = proxtorch_logo[::4, ::4]
3234
proxtorch_logo = 1 - np.mean(proxtorch_logo, axis=2)
3335
# Normalize to [0, 1]
3436
proxtorch_logo = (proxtorch_logo - np.min(proxtorch_logo)) / (
@@ -41,76 +43,47 @@ def __init__(self, alpha, l1_ratio):
4143
super().__init__()
4244
self.restored = torch.nn.Parameter(torch.zeros(proxtorch_logo.shape))
4345
self.tvl1_prox = TVL1_2DProx(alpha=alpha, l1_ratio=l1_ratio)
46+
self.automatic_optimization = False
4447

4548
def forward(self, x):
4649
return self.restored
4750

4851
def training_step(self, batch, _):
52+
opt = self.optimizers()
53+
opt.zero_grad()
4954
noisy, original = batch
5055
y_hat = self.restored
5156
loss = torch.sum((y_hat - noisy) ** 2)
52-
self.log("fidelity_loss", loss)
5357
tv_loss = self.tvl1_prox(self.restored)
54-
self.log("tvl1_loss", tv_loss)
55-
return loss
56-
57-
def configure_optimizers(self):
58-
return optim.SGD(self.parameters(), lr=0.01)
59-
60-
def on_train_batch_end(self, _, __, batch_idx: int):
58+
self.manual_backward(loss)
59+
opt.step()
6160
with torch.no_grad():
6261
optimizer = self.trainer.optimizers[0]
6362
self.restored.data = self.tvl1_prox.prox(
6463
self.restored.data, optimizer.param_groups[0]["lr"]
6564
)
6665

67-
68-
class TVRestoration(pl.LightningModule):
69-
def __init__(self, alpha):
70-
super().__init__()
71-
self.restored = torch.nn.Parameter(torch.zeros(proxtorch_logo.shape))
72-
self.tv_prox = TV_2DProx(alpha=alpha)
73-
74-
def forward(self, x):
75-
return self.restored
76-
77-
def training_step(self, batch, _):
78-
noisy, original = batch
79-
y_hat = self.restored
80-
loss = torch.sum((y_hat - noisy) ** 2)
81-
self.log("fidelity_loss", loss)
82-
tv_loss = self.tv_prox(self.restored)
83-
self.log("tv_loss", tv_loss)
84-
return loss
85-
8666
def configure_optimizers(self):
8767
return optim.SGD(self.parameters(), lr=0.01)
8868

89-
def on_train_batch_end(self, _, __, batch_idx: int):
90-
with torch.no_grad():
91-
optimizer = self.trainer.optimizers[0]
92-
self.restored.data = self.tv_prox.prox(
93-
self.restored.data, optimizer.param_groups[0]["lr"]
94-
)
95-
9669

9770
# Data Preparation
9871
noisy_logo = proxtorch_logo + np.random.normal(
99-
loc=0, scale=0.2, size=proxtorch_logo.shape
72+
loc=0, scale=0.1, size=proxtorch_logo.shape
10073
)
10174
dataset = TensorDataset(
10275
torch.tensor(noisy_logo).unsqueeze(0), torch.tensor(proxtorch_logo).unsqueeze(0)
10376
)
10477
loader = DataLoader(dataset, batch_size=1)
10578

10679
# Model Initialization
107-
tv_l1_model = TVL1Restoration(alpha=0.5, l1_ratio=0.5)
108-
tv_model = TVRestoration(alpha=0.5)
80+
tv_l1_model = TVL1Restoration(alpha=0.2, l1_ratio=0.05)
81+
tv_model = TVL1Restoration(alpha=0.2, l1_ratio=0.0)
10982

11083
# Training
111-
trainer = pl.Trainer(max_epochs=200)
84+
trainer = pl.Trainer(max_epochs=50)
11285
trainer.fit(tv_model, loader)
113-
trainer = pl.Trainer(max_epochs=200)
86+
trainer = pl.Trainer(max_epochs=50)
11487
trainer.fit(tv_l1_model, loader)
11588

11689

0 commit comments

Comments
 (0)