29
29
30
30
# Load ProxTorch Logo as jpg then convert to grayscale numpy array
31
31
proxtorch_logo = plt .imread ("../proxtorch-logo.jpg" )
32
+ # Downsample to 64x64
33
+ proxtorch_logo = proxtorch_logo [::4 , ::4 ]
32
34
proxtorch_logo = 1 - np .mean (proxtorch_logo , axis = 2 )
33
35
# Normalize to [0, 1]
34
36
proxtorch_logo = (proxtorch_logo - np .min (proxtorch_logo )) / (
@@ -41,76 +43,47 @@ def __init__(self, alpha, l1_ratio):
41
43
super ().__init__ ()
42
44
self .restored = torch .nn .Parameter (torch .zeros (proxtorch_logo .shape ))
43
45
self .tvl1_prox = TVL1_2DProx (alpha = alpha , l1_ratio = l1_ratio )
46
+ self .automatic_optimization = False
44
47
45
48
def forward (self , x ):
46
49
return self .restored
47
50
48
51
def training_step (self , batch , _ ):
52
+ opt = self .optimizers ()
53
+ opt .zero_grad ()
49
54
noisy , original = batch
50
55
y_hat = self .restored
51
56
loss = torch .sum ((y_hat - noisy ) ** 2 )
52
- self .log ("fidelity_loss" , loss )
53
57
tv_loss = self .tvl1_prox (self .restored )
54
- self .log ("tvl1_loss" , tv_loss )
55
- return loss
56
-
57
- def configure_optimizers (self ):
58
- return optim .SGD (self .parameters (), lr = 0.01 )
59
-
60
- def on_train_batch_end (self , _ , __ , batch_idx : int ):
58
+ self .manual_backward (loss )
59
+ opt .step ()
61
60
with torch .no_grad ():
62
61
optimizer = self .trainer .optimizers [0 ]
63
62
self .restored .data = self .tvl1_prox .prox (
64
63
self .restored .data , optimizer .param_groups [0 ]["lr" ]
65
64
)
66
65
67
-
68
- class TVRestoration (pl .LightningModule ):
69
- def __init__ (self , alpha ):
70
- super ().__init__ ()
71
- self .restored = torch .nn .Parameter (torch .zeros (proxtorch_logo .shape ))
72
- self .tv_prox = TV_2DProx (alpha = alpha )
73
-
74
- def forward (self , x ):
75
- return self .restored
76
-
77
- def training_step (self , batch , _ ):
78
- noisy , original = batch
79
- y_hat = self .restored
80
- loss = torch .sum ((y_hat - noisy ) ** 2 )
81
- self .log ("fidelity_loss" , loss )
82
- tv_loss = self .tv_prox (self .restored )
83
- self .log ("tv_loss" , tv_loss )
84
- return loss
85
-
86
66
def configure_optimizers (self ):
87
67
return optim .SGD (self .parameters (), lr = 0.01 )
88
68
89
- def on_train_batch_end (self , _ , __ , batch_idx : int ):
90
- with torch .no_grad ():
91
- optimizer = self .trainer .optimizers [0 ]
92
- self .restored .data = self .tv_prox .prox (
93
- self .restored .data , optimizer .param_groups [0 ]["lr" ]
94
- )
95
-
96
69
97
70
# Data Preparation
98
71
noisy_logo = proxtorch_logo + np .random .normal (
99
- loc = 0 , scale = 0.2 , size = proxtorch_logo .shape
72
+ loc = 0 , scale = 0.1 , size = proxtorch_logo .shape
100
73
)
101
74
dataset = TensorDataset (
102
75
torch .tensor (noisy_logo ).unsqueeze (0 ), torch .tensor (proxtorch_logo ).unsqueeze (0 )
103
76
)
104
77
loader = DataLoader (dataset , batch_size = 1 )
105
78
106
79
# Model Initialization
107
- tv_l1_model = TVL1Restoration (alpha = 0.5 , l1_ratio = 0.5 )
108
- tv_model = TVRestoration (alpha = 0.5 )
80
+ tv_l1_model = TVL1Restoration (alpha = 0.2 , l1_ratio = 0.05 )
81
+ tv_model = TVL1Restoration (alpha = 0.2 , l1_ratio = 0.0 )
109
82
110
83
# Training
111
- trainer = pl .Trainer (max_epochs = 200 )
84
+ trainer = pl .Trainer (max_epochs = 50 )
112
85
trainer .fit (tv_model , loader )
113
- trainer = pl .Trainer (max_epochs = 200 )
86
+ trainer = pl .Trainer (max_epochs = 50 )
114
87
trainer .fit (tv_l1_model , loader )
115
88
116
89
0 commit comments