Fix tensor devices for DARTS Trial (#2273)
* Update architect.py

Signed-off-by: Chen Pin-Han <72907153+sifa1024@users.noreply.github.com>

* Update run_trial.py

Signed-off-by: Chen Pin-Han <72907153+sifa1024@users.noreply.github.com>

* Update architect.py

Signed-off-by: Chen Pin-Han <72907153+sifa1024@users.noreply.github.com>

---------

Signed-off-by: Chen Pin-Han <72907153+sifa1024@users.noreply.github.com>
sifa1024 authored Mar 10, 2024
1 parent a2f3fca commit 61406a5
Showing 2 changed files with 14 additions and 6 deletions.
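In short, the problem this commit addresses: torch.FloatTensor(...) always allocates on the CPU, so when the DARTS model runs on a GPU the update arithmetic in virtual_step and unrolled_backward mixes CPU and CUDA tensors and PyTorch raises a device-mismatch RuntimeError. The change threads a device string into Architect and selects the matching tensor constructor. A minimal sketch of the before/after pattern (illustrative only; w and xi below are stand-ins, not the commit's code):

import torch

if torch.cuda.is_available():
    w = torch.randn(3, device='cuda')   # stand-in for a model weight on the GPU
    xi = [0.025]                         # stand-in learning rate, as a one-element list

    # Pre-fix pattern: torch.FloatTensor always allocates on the CPU, so combining
    # it with a CUDA weight raises a device-mismatch RuntimeError:
    # w - torch.FloatTensor(xi) * w

    # Post-fix pattern: pick the constructor that matches the model's device.
    v = w - torch.cuda.FloatTensor(xi) * w
    print(v.device)                      # cuda:0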
18 changes: 13 additions & 5 deletions examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py
@@ -20,11 +20,12 @@ class Architect():
     """" Architect controls architecture of cell by computing gradients of alphas
     """

-    def __init__(self, model, w_momentum, w_weight_decay):
+    def __init__(self, model, w_momentum, w_weight_decay, device):
         self.model = model
         self.v_model = copy.deepcopy(model)
         self.w_momentum = w_momentum
         self.w_weight_decay = w_weight_decay
+        self.device = device

     def virtual_step(self, train_x, train_y, xi, w_optim):
         """
@@ -43,17 +44,21 @@ def virtual_step(self, train_x, train_y, xi, w_optim):
         # Forward and calculate loss
         # Loss for train with w. L_train(w)
         loss = self.model.loss(train_x, train_y)

         # Compute gradient
         gradients = torch.autograd.grad(loss, self.model.getWeights())

         # Do virtual step (Update gradient)
         # Below operations do not need gradient tracking
         with torch.no_grad():
             # dict key is not the value, but the pointer. So original network weight have to
             # be iterated also.
             for w, vw, g in zip(self.model.getWeights(), self.v_model.getWeights(), gradients):
                 m = w_optim.state[w].get("momentum_buffer", 0.) * self.w_momentum
-                vw.copy_(w - torch.FloatTensor(xi) * (m + g + self.w_weight_decay * w))
+                if(self.device == 'cuda'):
+                    vw.copy_(w - torch.cuda.FloatTensor(xi) * (m + g + self.w_weight_decay * w))
+                elif(self.device == 'cpu'):
+                    vw.copy_(w - torch.FloatTensor(xi) * (m + g + self.w_weight_decay * w))

         # Sync alphas
         for a, va in zip(self.model.getAlphas(), self.v_model.getAlphas()):
@@ -71,7 +76,7 @@ def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
         # Calculate unrolled loss
         # Loss for validation with w'. L_valid(w')
         loss = self.v_model.loss(valid_x, valid_y)

         # Calculate gradient
         v_alphas = tuple(self.v_model.getAlphas())
         v_weights = tuple(self.v_model.getWeights())
@@ -85,7 +90,10 @@ def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
         # Update final gradient = dalpha - xi * hessian
         with torch.no_grad():
             for alpha, da, h in zip(self.model.getAlphas(), dalpha, hessian):
-                alpha.grad = da - torch.FloatTensor(xi) * h
+                if(self.device == 'cuda'):
+                    alpha.grad = da - torch.cuda.FloatTensor(xi) * h
+                elif(self.device == 'cpu'):
+                    alpha.grad = da - torch.FloatTensor(xi) * h

     def compute_hessian(self, dws, train_x, train_y):
         """
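A side note on the approach (not part of this commit, only a hedged sketch): the same device fix can be written without branching on a device string by building the xi scaling tensor directly on the weight's own device with torch.as_tensor(..., device=w.device). Something along these lines, with all names below being illustrative stand-ins:

import torch

def scale_like(xi, ref):
    # Build the xi scaling tensor with the same dtype and device as the reference tensor.
    return torch.as_tensor(xi, dtype=ref.dtype, device=ref.device)

w = torch.randn(3)             # works unchanged if w lives on 'cuda'
g = torch.randn(3)             # stand-in gradient
m = 0.0                        # stand-in momentum-buffer term
weight_decay = 3e-4
xi = [0.025]

v = w - scale_like(xi, w) * (m + g + weight_decay * w)
print(v.device)                # always matches w's device, no if/elif needed

The committed two-branch version stays closer to the original code, so both styles achieve the same result.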
2 changes: 1 addition & 1 deletion examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py
@@ -140,7 +140,7 @@ def main():
         num_epochs,
         eta_min=w_lr_min)

-    architect = Architect(model, w_momentum, w_weight_decay)
+    architect = Architect(model, w_momentum, w_weight_decay, device)

     # Start training
     best_top1 = 0.
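One thing the hunk above implies but does not show: Architect now compares self.device against the literal strings 'cuda' and 'cpu', so the device value that run_trial.py passes in is presumably that plain string. A hedged guess at how such a value is typically derived (the actual line sits outside this diff):

import torch

# Assumed pattern, not visible in this diff:
device = 'cuda' if torch.cuda.is_available() else 'cpu'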
