This repository has been archived by the owner on Mar 19, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 333
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Extreme low shot benchmark as done in MSN (https://arxiv.org/pdf/2204.07141.pdf): - limited number of sample per classes (1 image in this case) - logistic regression with strong regularisation after feature extraction - tested on MSN B16 where it reaches 51.3% X-link: fairinternal/ssl_scaling#279 Reviewed By: mannatsingh Differential Revision: D38009053 Pulled By: QuentinDuval fbshipit-source-id: 1556ee78529b8879f56ed12cd1288310fd319e3b
- Loading branch information
1 parent
0d465ce
commit 4ca0ae3
Showing
9 changed files
with
373 additions
and
1 deletion.
There are no files selected for viewing
68 changes: 68 additions & 0 deletions
68
configs/config/benchmark/low_shot_transfer/imagenet/logistic_regression_in1k.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# @package _global_ | ||
config: | ||
DATA: | ||
NUM_DATALOADER_WORKERS: 5 | ||
TRAIN: | ||
DATA_SOURCES: [disk_folder] | ||
LABEL_SOURCES: [disk_folder] | ||
DATASET_NAMES: [imagenet1k_folder] | ||
BATCHSIZE_PER_REPLICA: 32 | ||
DATA_LIMIT: 1000 | ||
DATA_LIMIT_SAMPLING: | ||
IS_BALANCED: True | ||
TRANSFORMS: | ||
- name: Resize | ||
size: 256 | ||
- name: CenterCrop | ||
size: 224 | ||
- name: ToTensor | ||
- name: Normalize | ||
mean: [0.485, 0.456, 0.406] | ||
std: [0.229, 0.224, 0.225] | ||
MMAP_MODE: False | ||
COPY_TO_LOCAL_DISK: False | ||
COPY_DESTINATION_DIR: /tmp/imagenet1k/ | ||
TEST: | ||
DATA_SOURCES: [disk_folder] | ||
LABEL_SOURCES: [disk_folder] | ||
DATASET_NAMES: [imagenet1k_folder] | ||
BATCHSIZE_PER_REPLICA: 32 | ||
TRANSFORMS: | ||
- name: Resize | ||
size: 256 | ||
- name: CenterCrop | ||
size: 224 | ||
- name: ToTensor | ||
- name: Normalize | ||
mean: [0.485, 0.456, 0.406] | ||
std: [0.229, 0.224, 0.225] | ||
MMAP_MODE: False | ||
COPY_TO_LOCAL_DISK: False | ||
COPY_DESTINATION_DIR: /tmp/imagenet1k/ | ||
MODEL: | ||
FEATURE_EVAL_SETTINGS: | ||
EVAL_MODE_ON: True | ||
FREEZE_TRUNK_AND_HEAD: True | ||
EVAL_TRUNK_AND_HEAD: True | ||
TRUNK: | ||
NAME: resnet | ||
RESNETS: | ||
DEPTH: 50 | ||
HEAD: | ||
PARAMS: [ | ||
["identity", {}], | ||
] | ||
WEIGHTS_INIT: | ||
PARAMS_FILE: "specify the model weights" | ||
STATE_DICT_KEY_NAME: classy_state_dict | ||
LOGISTIC_REGRESSION: | ||
LAMBDA: 0.1 | ||
FEATURES: | ||
PATH: '' | ||
DISTRIBUTED: | ||
NUM_NODES: 1 | ||
NUM_PROC_PER_NODE: 8 | ||
MACHINE: | ||
DEVICE: gpu | ||
CHECKPOINT: | ||
DIR: . |
17 changes: 17 additions & 0 deletions
17
configs/config/benchmark/low_shot_transfer/imagenet/models/deit_s16.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# @package _global_ | ||
config: | ||
MODEL: | ||
TRUNK: # S-16 | ||
NAME: vision_transformer | ||
VISION_TRANSFORMERS: | ||
IMAGE_SIZE: 224 | ||
PATCH_SIZE: 16 | ||
HIDDEN_DIM: 384 | ||
NUM_LAYERS: 12 | ||
NUM_HEADS: 6 | ||
MLP_DIM: 1536 | ||
CLASSIFIER: token | ||
DROPOUT_RATE: 0 | ||
ATTENTION_DROPOUT_RATE: 0 | ||
QKV_BIAS: True | ||
DROP_PATH_RATE: 0.0 |
33 changes: 33 additions & 0 deletions
33
configs/config/benchmark/low_shot_transfer/imagenet/models/deit_tiny.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# @package _global_ | ||
config: | ||
MODEL: | ||
FEATURE_EVAL_SETTINGS: | ||
LINEAR_EVAL_FEAT_POOL_OPS_MAP: [ | ||
["concatCLS4", ["Identity", []] ], | ||
["lastCLS", ["Identity", []] ], | ||
["concatCLS4", ["Identity", []] ], | ||
["lastCLS", ["Identity", []] ], | ||
] | ||
TRUNK: # Tiny | ||
NAME: vision_transformer | ||
VISION_TRANSFORMERS: | ||
IMAGE_SIZE: 224 | ||
PATCH_SIZE: 16 | ||
NUM_LAYERS: 12 | ||
NUM_HEADS: 3 | ||
HIDDEN_DIM: 192 | ||
MLP_DIM: 768 | ||
CLASSIFIER: token | ||
DROPOUT_RATE: 0 | ||
ATTENTION_DROPOUT_RATE: 0 | ||
QKV_BIAS: True | ||
DROP_PATH_RATE: 0.0 | ||
HEAD: | ||
PARAMS: [ | ||
["eval_mlp", {"in_channels": 768, "dims": [768, 1000]}], | ||
["eval_mlp", {"in_channels": 192, "dims": [192, 1000]}], | ||
["mlp", {"dims": [768, 1000]}], | ||
["mlp", {"dims": [192, 1000]}], | ||
] | ||
OPTIMIZER: | ||
regularize_bn: True |
17 changes: 17 additions & 0 deletions
17
configs/config/benchmark/low_shot_transfer/imagenet/models/vit_b16.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# @package _global_ | ||
config: | ||
MODEL: | ||
TRUNK: # L-16 | ||
NAME: vision_transformer | ||
VISION_TRANSFORMERS: | ||
IMAGE_SIZE: 224 | ||
PATCH_SIZE: 16 | ||
NUM_LAYERS: 12 | ||
NUM_HEADS: 12 | ||
HIDDEN_DIM: 768 | ||
MLP_DIM: 3072 | ||
DROPOUT_RATE: 0.0 | ||
ATTENTION_DROPOUT_RATE: 0 | ||
CLASSIFIER: token | ||
QKV_BIAS: True | ||
DROP_PATH_RATE: 0.0 |
17 changes: 17 additions & 0 deletions
17
configs/config/benchmark/low_shot_transfer/imagenet/models/vit_l16.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# @package _global_ | ||
config: | ||
MODEL: | ||
TRUNK: # L-16 | ||
NAME: vision_transformer | ||
VISION_TRANSFORMERS: | ||
IMAGE_SIZE: 224 | ||
PATCH_SIZE: 16 | ||
NUM_LAYERS: 24 | ||
NUM_HEADS: 16 | ||
HIDDEN_DIM: 1024 | ||
MLP_DIM: 4096 | ||
DROPOUT_RATE: 0 | ||
ATTENTION_DROPOUT_RATE: 0 | ||
CLASSIFIER: token | ||
QKV_BIAS: True | ||
DROP_PATH_RATE: 0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
|
||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import sys | ||
from argparse import Namespace | ||
from typing import Any, List | ||
|
||
from vissl.config import AttrDict | ||
from vissl.utils.hydra_config import compose_hydra_configuration, convert_to_attrdict | ||
from vissl.utils.low_shot_utils import ( | ||
extract_features_and_low_shot, | ||
extract_features_and_low_shot_on_slurm, | ||
) | ||
from vissl.utils.slurm import is_submitit_available | ||
|
||
|
||
def main(args: Namespace, config: AttrDict): | ||
if config.SLURM.USE_SLURM: | ||
assert ( | ||
is_submitit_available() | ||
), "Please 'pip install submitit' to schedule jobs on SLURM" | ||
extract_features_and_low_shot_on_slurm(config) | ||
else: | ||
extract_features_and_low_shot(args.node_id, config) | ||
|
||
|
||
def hydra_main(overrides: List[Any]): | ||
cfg = compose_hydra_configuration(overrides) | ||
args, config = convert_to_attrdict(cfg) | ||
main(args, config) | ||
|
||
|
||
if __name__ == "__main__": | ||
overrides = sys.argv[1:] | ||
hydra_main(overrides=overrides) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
|
||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import logging | ||
|
||
from vissl.config import AttrDict | ||
from vissl.hooks import default_hook_generator | ||
from vissl.models.model_helpers import get_trunk_output_feature_names | ||
from vissl.utils.checkpoint import get_checkpoint_folder | ||
from vissl.utils.distributed_launcher import ( | ||
create_submitit_executor, | ||
launch_distributed, | ||
) | ||
from vissl.utils.env import set_env_vars | ||
from vissl.utils.extract_features_utils import ExtractedFeaturesLoader | ||
from vissl.utils.hydra_config import print_cfg | ||
from vissl.utils.logger import setup_logging, shutdown_logging | ||
from vissl.utils.svm_utils.svm_trainer import SVMTrainer | ||
|
||
|
||
def load_features(feature_dir: str, layer_name: str): | ||
train_out = ExtractedFeaturesLoader.load_features( | ||
feature_dir, "train", layer_name, flatten_features=True | ||
) | ||
train_features, train_labels = train_out["features"], train_out["targets"] | ||
test_out = ExtractedFeaturesLoader.load_features( | ||
feature_dir, "test", layer_name, flatten_features=True | ||
) | ||
test_features, test_labels = test_out["features"], test_out["targets"] | ||
return train_features, train_labels, test_features, test_labels | ||
|
||
|
||
def run_low_shot_logistic_regression(config: AttrDict, layer_name: str = "heads"): | ||
""" | ||
Run the Nearest Neighbour benchmark at the layer "layer_name" | ||
""" | ||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.preprocessing import StandardScaler | ||
|
||
# -- get train and test features | ||
feature_dir = config.LOW_SHOT_BENCHMARK.FEATURES.PATH | ||
train_features, train_labels, test_features, test_labels = load_features( | ||
feature_dir, layer_name | ||
) | ||
|
||
# -- Scale the features based on training statistics | ||
scaler = StandardScaler() | ||
scaler.fit(train_features) | ||
train_features = scaler.transform(train_features) | ||
test_features = scaler.transform(test_features) | ||
|
||
# -- Fit Logistic Regression Classifier | ||
method_config = config.LOW_SHOT_BENCHMARK.LOGISTIC_REGRESSION | ||
lambd = method_config.LAMBDA | ||
lambd /= len(train_labels) | ||
classifier = LogisticRegression( | ||
penalty="l2", | ||
C=1 / lambd, | ||
multi_class="multinomial", | ||
) | ||
classifier.fit( | ||
train_features, | ||
train_labels, | ||
) | ||
|
||
# -- Evaluate on train and test set | ||
train_score = classifier.score(train_features, train_labels) | ||
print("Train score: ", train_score) | ||
test_score = classifier.score(test_features, test_labels) | ||
print("Test score: ", test_score) | ||
return test_score | ||
|
||
|
||
def run_low_shot_svm(config: AttrDict, layer_name: str = "heads"): | ||
# setup the environment variables | ||
set_env_vars(local_rank=0, node_id=0, cfg=config) | ||
|
||
features_dir = config.LOW_SHOT_BENCHMARK.FEATURES.PATH | ||
output_dir = get_checkpoint_folder(config) | ||
|
||
# Train the svm | ||
logging.info(f"Training SVM for layer: {layer_name}") | ||
trainer = SVMTrainer(config.SVM, layer=layer_name, output_dir=output_dir) | ||
train_data = ExtractedFeaturesLoader.load_features( | ||
features_dir, "train", layer_name, flatten_features=True | ||
) | ||
trainer.train(train_data["features"], train_data["targets"]) | ||
|
||
# Test the svm | ||
test_data = ExtractedFeaturesLoader.load_features( | ||
features_dir, "test", layer_name, flatten_features=True | ||
) | ||
trainer.test(test_data["features"], test_data["targets"]) | ||
logging.info("All Done!") | ||
|
||
|
||
def run_low_shot_all_layers(config: AttrDict): | ||
""" | ||
Get the names of the features that we are extracting. If user doesn't | ||
specify the features to evaluate, we get the full model output and freeze | ||
head/trunk both as caution. | ||
""" | ||
feat_names = get_trunk_output_feature_names(config.MODEL) | ||
if len(feat_names) == 0: | ||
feat_names = ["heads"] | ||
|
||
for layer in feat_names: | ||
if config.LOW_SHOT_BENCHMARK.METHOD == "logistic_regression": | ||
top_1 = run_low_shot_logistic_regression(config, layer_name=layer) | ||
else: | ||
top_1 = run_low_shot_svm(config, layer_name=layer) | ||
logging.info(f"layer: {layer}, Top1: {top_1}") | ||
|
||
|
||
def extract_features_and_low_shot(node_id: int, config: AttrDict): | ||
setup_logging(__name__) | ||
print_cfg(config) | ||
set_env_vars(local_rank=0, node_id=0, cfg=config) | ||
|
||
# Extract the features if no path to the extract features is provided | ||
if not config.LOW_SHOT_BENCHMARK.FEATURES.PATH: | ||
launch_distributed( | ||
config, | ||
node_id, | ||
engine_name="extract_features", | ||
hook_generator=default_hook_generator, | ||
) | ||
config.LOW_SHOT_BENCHMARK.FEATURES.PATH = get_checkpoint_folder(config) | ||
|
||
# Run KNN on all the extract features | ||
run_low_shot_all_layers(config) | ||
|
||
# close the logging streams including the file handlers | ||
shutdown_logging() | ||
|
||
|
||
class _ResumableLowShotSlurmJob: | ||
def __init__(self, config: AttrDict): | ||
self.config = config | ||
|
||
def __call__(self): | ||
import submitit | ||
|
||
environment = submitit.JobEnvironment() | ||
node_id = environment.global_rank | ||
master_ip = environment.hostnames[0] | ||
master_port = self.config.SLURM.PORT_ID | ||
self.config.DISTRIBUTED.INIT_METHOD = "tcp" | ||
self.config.DISTRIBUTED.RUN_ID = f"{master_ip}:{master_port}" | ||
extract_features_and_low_shot(node_id=node_id, config=self.config) | ||
|
||
def checkpoint(self): | ||
import submitit | ||
|
||
trainer = _ResumableLowShotSlurmJob(config=self.config) | ||
return submitit.helpers.DelayedSubmission(trainer) | ||
|
||
|
||
def extract_features_and_low_shot_on_slurm(cfg): | ||
executor = create_submitit_executor(cfg) | ||
trainer = _ResumableLowShotSlurmJob(config=cfg) | ||
job = executor.submit(trainer) | ||
print(f"SUBMITTED: {job.job_id}") | ||
return job |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters