# Wolf.py
# MIT License Copyright 2024 joshuah rainstar
import math
import torch
from torch.optim.optimizer import Optimizer
class Wolf(Optimizer):
    """Implements the Wolf algorithm."""
    def __init__(self, params, lr=0.25, betas=(0.9, 0.999), eps=1e-8):
        # Define default hyperparameters
        defaults = dict(lr=lr, betas=betas, eps=eps)
        self.lr = lr
        # Initialize the parent Optimizer class first
        super().__init__(params, defaults)
        # Initialize state for each parameter
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['p'] = torch.zeros_like(p)  # exponential moving average of the update

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            the loss.
        """
        etcerta = 0.367879441  # 1/e
        et = 1 - etcerta       # 1 - 1/e
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                grad = p.grad
                exp_avg = state['p']
                # Weight update: blend the running average with the fresh gradient
                update = exp_avg * et + grad * etcerta
                state['p'] = exp_avg * et + update * etcerta
                sign_agreement = torch.sign(update) * torch.sign(grad)
                # The value to use for adaptive_alpha depends on your model.
                # In general, test it and set it as high as you can without the
                # loss exploding. A backoff schedule may work better: start small
                # and grow as the model converges toward the global minimum.
                adaptive_alpha = self.lr
                # Where the signs agree (positive product), apply the normal update
                mask = (sign_agreement > 0)
                p.data = torch.where(mask,
                                     p.data - adaptive_alpha * update,
                                     p.data)
        return loss
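
# A minimal usage sketch for Wolf (the model and batch below are hypothetical,
# shown only for illustration): Wolf is a standard torch.optim.Optimizer, so
# the usual zero_grad / backward / step loop applies.
#
#   import torch.nn as nn
#   model = nn.Linear(10, 1)                        # hypothetical model
#   optimizer = Wolf(model.parameters(), lr=0.25)
#   x, y = torch.randn(32, 10), torch.randn(32, 1)  # hypothetical batch
#   for _ in range(100):
#       optimizer.zero_grad()
#       loss = nn.functional.mse_loss(model(x), y)
#       loss.backward()
#       optimizer.step()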
class TigerOptimizer(Optimizer):
    def __init__(self, model, params, lr=0.01, betas=(0.1, 0.1)):  # betas kept only for the optimizer generator
        self.model = model
        defaults = dict(lr=lr, betas=betas)
        super().__init__(params, defaults)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['p'] = torch.zeros_like(p)

    # Somewhere in your training loop:
    # def closure():
    #     optimizer.zero_grad()
    #     outputs = net(input)
    #     loss = loss_function(outputs, labels)
    #     loss.backward()
    #     return loss
    def step(self, closure):
        etcerta = 0.367879441  # 1/e
        et = 1 - etcerta
        # First compute the initial loss and gradients
        loss = closure()
        init_weights = {}
        init_grads = {}
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    init_weights[p] = p.data.clone()
                    init_grads[p] = p.grad.clone()
        # Take the first trial step
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.data -= (2/3) * p.grad
        # First step evaluation
        loss = closure()
        # Take the second trial step
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.data -= (2/3) * (p.grad + init_grads[p])
        # Second step evaluation: Ralston-style combination of the trial points
        rko_grads = {}
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    rko_grads[p] = init_weights[p] - (0.25 * init_weights[p] + 0.75 * p.data)
        # Process all updates
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                exp_avg = state['p']
                update = exp_avg * et + rko_grads[p] * etcerta
                state['p'] = exp_avg * et + update * etcerta
                # Reset to the initial weights and apply the final update.
                # Note: for some types of optimization problems, ADD, do not
                # subtract, the update.
                p.data = init_weights[p] - update * 0.5
                p.grad.zero_()
        return loss
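
# A minimal usage sketch for TigerOptimizer (hypothetical names): the closure
# argument is required, not optional, because step() calls it twice to take
# the two trial steps before forming the Ralston-style combination.
#
#   net = nn.Linear(10, 1)                                    # hypothetical model
#   optimizer = TigerOptimizer(net, net.parameters(), lr=0.01)
#   def closure():
#       optimizer.zero_grad()
#       loss = nn.functional.mse_loss(net(x), y)              # hypothetical batch x, y
#       loss.backward()
#       return loss
#   loss = optimizer.step(closure)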
class WolfLearnRate(Optimizer):
    """Implements the Wolf algorithm with a built-in learning-rate schedule."""
    def __init__(self, params, lr=0.25, betas=(0.9, 0.999), eps=1e-8):
        # Define default hyperparameters
        defaults = dict(lr=lr, betas=betas, eps=eps)
        self.lr = lr
        self.t = 0
        self.alpha = 256  # schedule constant; not yet tuned
        # Initialize the parent Optimizer class first
        super().__init__(params, defaults)
        # Initialize state for each parameter
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['p'] = torch.zeros_like(p)  # exponential moving average of the update

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            the loss.
        """
        etcerta = 0.367879441  # 1/e
        et = 1 - etcerta
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                grad = p.grad
                exp_avg = state['p']
                # Weight update
                update = exp_avg * et + grad * etcerta
                state['p'] = exp_avg * et + update * etcerta
                sign_agreement = torch.sign(update) * torch.sign(grad)
                # Where the signs agree (positive product), apply the update,
                # scaled by a learning rate that decays with the step count t
                mask = (sign_agreement > 0)
                lr = 2 / (self.t / self.alpha + 2) - math.log(2 / (self.t / self.alpha + 2) + 1)
                p.data = torch.where(mask,
                                     p.data - lr * update,
                                     p.data)
                p.grad.zero_()
        self.t = self.t + 1
        return loss
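
# A quick sketch of the built-in schedule (illustrative only, not part of the
# API): with u = 2 / (t / alpha + 2), the step size is u - log(u + 1). At
# t = 0 this is 1 - log(2) ≈ 0.307, and it decays toward 0 as t grows
# (for small u, u - log(u + 1) ≈ u**2 / 2).
#
#   alpha = 256
#   for t in (0, 256, 1024, 4096):
#       u = 2 / (t / alpha + 2)
#       print(t, u - math.log(u + 1))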