
Commit b4ef42f

Completes the classification implementation

1 parent 91de686

15 files changed: +113 -278 lines

README.md

Lines changed: 10 additions & 1 deletion
@@ -2,4 +2,13 @@
 Light-weight and Efficient Networks for Mobile Vision Applications
 
 ## :rocket: News
-* Training and evaluation code along with pre-trained models will be released soon. Stay tuned!
+* Training and evaluation code along with pre-trained models will be released soon. Stay tuned!
+
+<hr />
+
+![main figure](images/EdgeNext.jpeg)
+> **Abstract:** *Designing lightweight general purpose networks for edge devices is a challenging task due to the compute constraints. In this domain, CNN-based light-weight architectures are considered the de-facto choice due to their efficiency in terms of parameters and complexity. However, they are based on spatially local operations and exhibit a limited receptive field. While vision transformers alleviate these issues and can learn global representations, they are typically compute intensive and difficult to optimize. Here, we investigate how to effectively encode both local and global information, while being efficient in terms of both parameters and MAdds on vision tasks. To this end, we propose EdgeNeXt, a hybrid CNN-Transformer architecture that strives to jointly optimize parameters and MAdds for efficient inference on edge devices. Within our EdgeNeXt, we introduce split depthwise transpose attention (SDTA) encoder that splits input tensors into multiple channel groups and utilizes depthwise convolution along with self-attention across channel dimensions to implicitly increase the receptive field and encode multi-scale features. Our extensive experiments on classification, detection and segmentation settings, reveal the merits of the proposed approach, outperforming state-of-the-art methods with comparatively lower compute requirements. Our EdgeNeXt model with 1.3M parameters achieves 71.2\% top-1 accuracy on ImageNet-1K, outperforming MobileViT with an absolute gain of 2.2\% with similar parameters and 28\% reduction in MAdds. Further, our EdgeNeXt model with 5.6M parameters achieves 79.4\% top-1 accuracy on ImageNet-1K.*
+<hr />
+
+## Comparison with SOTA ViTs and Hybrid Designs
+![main figure](images/Figure_1.png)

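The abstract added above describes the split depthwise transpose attention (SDTA) encoder: channel groups are mixed locally with depthwise convolutions, then globally with self-attention applied across the channel dimension rather than across spatial positions. The snippet below is a minimal, illustrative sketch of that idea in PyTorch, not the repository's actual SDTA module; the class name, the group/head counts, and the omission of details such as adaptive kernel sizes, positional encodings, and normalization are assumptions made for brevity.

```python
import torch
import torch.nn as nn

class SDTALikeBlock(nn.Module):
    """Illustrative sketch: depthwise convs on channel splits for local mixing,
    then self-attention across channels (transposed attention) for global mixing."""

    def __init__(self, dim, groups=4, num_heads=4):
        super().__init__()
        assert dim % groups == 0 and dim % num_heads == 0
        self.groups, self.num_heads = groups, num_heads
        split = dim // groups
        # One 3x3 depthwise convolution per channel group
        self.dw_convs = nn.ModuleList(
            [nn.Conv2d(split, split, kernel_size=3, padding=1, groups=split)
             for _ in range(groups)]
        )
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):                                    # x: (B, C, H, W)
        B, C, H, W = x.shape
        # Local mixing: depthwise conv applied to each channel split
        splits = torch.split(x, C // self.groups, dim=1)
        x = torch.cat([conv(s) for conv, s in zip(self.dw_convs, splits)], dim=1)

        # Global mixing: attention across the channel dimension.
        # The attention map is (C/heads x C/heads), not (HW x HW).
        tokens = x.flatten(2).transpose(1, 2)                 # (B, N, C), N = H*W
        q, k, v = self.qkv(tokens).chunk(3, dim=-1)

        def to_heads(t):                                      # (B, heads, C/heads, N)
            return t.transpose(1, 2).reshape(B, self.num_heads, C // self.num_heads, -1)

        q, k, v = to_heads(q), to_heads(k), to_heads(v)
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)      # (B, heads, C/h, C/h)
        out = (attn @ v).reshape(B, C, -1).transpose(1, 2)    # (B, N, C)
        return self.proj(out).transpose(1, 2).reshape(B, C, H, W)
```

Because the attention map is computed across channels, its cost grows linearly with spatial resolution, which is what makes the transposed formulation attractive for edge devices. For example, `SDTALikeBlock(dim=64)(torch.randn(2, 64, 56, 56))` returns a tensor of the same `(2, 64, 56, 56)` shape.
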
datasets.py

Lines changed: 3 additions & 11 deletions
@@ -1,11 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
 import os
 from torchvision import datasets, transforms
 
@@ -59,7 +51,7 @@ def build_transform(is_train, args):
     std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
 
     if is_train:
-        # this should always dispatch to transforms_imagenet_train
+        # This should always dispatch to transforms_imagenet_train
         transform = create_transform(
             input_size=args.input_size,
             is_training=True,
@@ -87,7 +79,7 @@ def build_transform(is_train, args):
 
     t = []
     if resize_im:
-        # warping (no cropping) when evaluated at 384 or larger
+        # Warping (no cropping) when evaluated at 384 or larger
         if args.input_size >= 384:
             t.append(
                 transforms.Resize((args.input_size, args.input_size),
@@ -99,7 +91,7 @@ def build_transform(is_train, args):
             args.crop_pct = 224 / 256
         size = int(args.input_size / args.crop_pct)
         t.append(
-            # to maintain same ratio w.r.t. 224 images
+            # To maintain same ratio w.r.t. 224 images
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
         )
         t.append(transforms.CenterCrop(args.input_size))

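The evaluation branch of `build_transform` touched above follows the common DeiT/timm recipe: at 384 px or larger the image is warped directly to the target size, otherwise it is resized by a crop ratio (defaulting to 224 / 256 = 0.875, so a 224 px target resizes the shorter side to 256) and then center-cropped. Below is a minimal standalone sketch of that logic using a hypothetical `build_eval_transform` helper, not the repository's function.

```python
from torchvision import transforms

def build_eval_transform(input_size=224, crop_pct=None):
    # Hypothetical helper mirroring the evaluation branch of build_transform
    t = []
    if input_size >= 384:
        # Warping (no cropping) when evaluated at 384 or larger
        t.append(transforms.Resize((input_size, input_size),
                                   interpolation=transforms.InterpolationMode.BICUBIC))
    else:
        if crop_pct is None:
            crop_pct = 224 / 256
        # To maintain the same ratio w.r.t. 224 images:
        # e.g. input_size=224 -> resize shorter side to int(224 / 0.875) = 256, then crop 224
        size = int(input_size / crop_pct)
        t.append(transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC))
        t.append(transforms.CenterCrop(input_size))
    t.append(transforms.ToTensor())
    return transforms.Compose(t)
```
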
engine.py

Lines changed: 9 additions & 17 deletions
@@ -1,11 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
 import math
 from typing import Iterable, Optional
 import torch
@@ -34,7 +26,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
         step = data_iter_step // update_freq
         if step >= num_training_steps_per_epoch:
             continue
-        it = start_steps + step # global training iteration
+        it = start_steps + step # Global training iteration
         # Update LR & WD for the first acc
         if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0:
             for i, param_group in enumerate(optimizer.param_groups):
@@ -53,18 +45,18 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
             with torch.cuda.amp.autocast():
                 output = model(samples)
                 loss = criterion(output, targets)
-        else: # full precision
+        else: # Full precision
             output = model(samples)
             loss = criterion(output, targets)
 
         loss_value = loss.item()
 
-        if not math.isfinite(loss_value): # this could trigger if using AMP
+        if not math.isfinite(loss_value): # This could trigger if using AMP
             print("Loss is {}, stopping training".format(loss_value))
             assert math.isfinite(loss_value)
 
         if use_amp:
-            # this attribute is added by timm on one optimizer (adahessian)
+            # This attribute is added by timm on one optimizer (adahessian)
             is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
             loss /= update_freq
             grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
@@ -74,7 +66,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                 optimizer.zero_grad()
                 if model_ema is not None:
                     model_ema.update(model)
-        else: # full precision
+        else: # Full precision
             loss /= update_freq
             loss.backward()
             if (data_iter_step + 1) % update_freq == 0:
@@ -129,7 +121,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                 wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False)
             wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it})
 
-    # gather the stats from all processes
+    # Gather the stats from all processes
     metric_logger.synchronize_between_processes()
     print("Averaged stats:", metric_logger)
     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@@ -142,7 +134,7 @@ def evaluate(data_loader, model, device, use_amp=False):
     metric_logger = utils.MetricLogger(delimiter=" ")
     header = 'Test:'
 
-    # switch to evaluation mode
+    # Switch to evaluation mode
     model.eval()
     for batch in metric_logger.log_every(data_loader, 10, header):
         images = batch[0]
@@ -151,7 +143,7 @@ def evaluate(data_loader, model, device, use_amp=False):
         images = images.to(device, non_blocking=True)
         target = target.to(device, non_blocking=True)
 
-        # compute output
+        # Compute output
         if use_amp:
            with torch.cuda.amp.autocast():
                output = model(images)
@@ -166,7 +158,7 @@ def evaluate(data_loader, model, device, use_amp=False):
         metric_logger.update(loss=loss.item())
         metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
         metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
-    # gather the stats from all processes
+    # Gather the stats from all processes
     metric_logger.synchronize_between_processes()
     print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
           .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

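The train_one_epoch hunks above revolve around three things: gradient accumulation over `update_freq` micro-batches, mixed-precision training under `torch.cuda.amp` with a `loss_scaler` wrapper, and an optional EMA copy of the weights. Below is a simplified, self-contained sketch of that pattern; it uses `torch.cuda.amp.GradScaler` directly instead of the repository's `loss_scaler` helper, and the function name and arguments are illustrative only.

```python
import math
import torch

def train_step(model, criterion, samples, targets, optimizer, scaler,
               data_iter_step, update_freq=1, use_amp=True, max_norm=None, model_ema=None):
    """Illustrative sketch of one accumulation-aware training step (not engine.py itself)."""
    # Forward pass, optionally under autocast (mixed precision)
    if use_amp:
        with torch.cuda.amp.autocast():
            loss = criterion(model(samples), targets)
    else:  # Full precision
        loss = criterion(model(samples), targets)

    loss_value = loss.item()
    if not math.isfinite(loss_value):  # This could trigger if using AMP
        raise RuntimeError("Loss is {}, stopping training".format(loss_value))

    loss = loss / update_freq  # Average gradients over the accumulation window
    if use_amp:
        scaler.scale(loss).backward()          # Scaled backward to avoid fp16 underflow
    else:
        loss.backward()

    # Step the optimizer only on the last micro-batch of the accumulation window
    if (data_iter_step + 1) % update_freq == 0:
        if use_amp:
            if max_norm is not None:
                scaler.unscale_(optimizer)     # Unscale before gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()
        if model_ema is not None:              # Exponential moving average of the weights
            model_ema.update(model)
    return loss_value
```

Dividing the loss by `update_freq` before backward makes the accumulated gradient an average over the micro-batches, so the effective batch size grows without changing the learning-rate semantics.
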
images/EdgeNext.jpeg

587 KB

images/Figure_1.png

701 KB

images/table_2.png

77.7 KB
