diff --git a/examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py b/examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py index c3d649a11db..c449a6a1fe9 100644 --- a/examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py +++ b/examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py @@ -20,11 +20,12 @@ class Architect(): """" Architect controls architecture of cell by computing gradients of alphas """ - def __init__(self, model, w_momentum, w_weight_decay): + def __init__(self, model, w_momentum, w_weight_decay, device): self.model = model self.v_model = copy.deepcopy(model) self.w_momentum = w_momentum self.w_weight_decay = w_weight_decay + self.device = device def virtual_step(self, train_x, train_y, xi, w_optim): """ @@ -43,9 +44,10 @@ def virtual_step(self, train_x, train_y, xi, w_optim): # Forward and calculate loss # Loss for train with w. L_train(w) loss = self.model.loss(train_x, train_y) + # Compute gradient gradients = torch.autograd.grad(loss, self.model.getWeights()) - + # Do virtual step (Update gradient) # Below operations do not need gradient tracking with torch.no_grad(): @@ -53,7 +55,10 @@ def virtual_step(self, train_x, train_y, xi, w_optim): # be iterated also. for w, vw, g in zip(self.model.getWeights(), self.v_model.getWeights(), gradients): m = w_optim.state[w].get("momentum_buffer", 0.) * self.w_momentum - vw.copy_(w - torch.FloatTensor(xi) * (m + g + self.w_weight_decay * w)) + if(self.device == 'cuda'): + vw.copy_(w - torch.cuda.FloatTensor(xi) * (m + g + self.w_weight_decay * w)) + elif(self.device == 'cpu'): + vw.copy_(w - torch.FloatTensor(xi) * (m + g + self.w_weight_decay * w)) # Sync alphas for a, va in zip(self.model.getAlphas(), self.v_model.getAlphas()): @@ -71,7 +76,7 @@ def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim): # Calculate unrolled loss # Loss for validation with w'. 
L_valid(w') loss = self.v_model.loss(valid_x, valid_y) - + # Calculate gradient v_alphas = tuple(self.v_model.getAlphas()) v_weights = tuple(self.v_model.getWeights()) @@ -85,7 +90,10 @@ def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim): # Update final gradient = dalpha - xi * hessian with torch.no_grad(): for alpha, da, h in zip(self.model.getAlphas(), dalpha, hessian): - alpha.grad = da - torch.FloatTensor(xi) * h + if(self.device == 'cuda'): + alpha.grad = da - torch.cuda.FloatTensor(xi) * h + elif(self.device == 'cpu'): + alpha.grad = da - torch.FloatTensor(xi) * h def compute_hessian(self, dws, train_x, train_y): """ diff --git a/examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py b/examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py index ceb74dfc5e3..a9836d240cd 100644 --- a/examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py +++ b/examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py @@ -140,7 +140,7 @@ def main(): num_epochs, eta_min=w_lr_min) - architect = Architect(model, w_momentum, w_weight_decay) + architect = Architect(model, w_momentum, w_weight_decay, device) # Start training best_top1 = 0.