Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Commit

Permalink
Extreme low shot (#279)
Browse files Browse the repository at this point in the history
Summary:
Extreme low shot benchmark as done in MSN (https://arxiv.org/pdf/2204.07141.pdf):
- limited number of sample per classes (1 image in this case)
- logistic regression with strong regularisation after feature extraction
- tested on MSN B16 where it reaches 51.3%

X-link: fairinternal/ssl_scaling#279

Reviewed By: mannatsingh

Differential Revision: D38009053

Pulled By: QuentinDuval

fbshipit-source-id: 1556ee78529b8879f56ed12cd1288310fd319e3b
  • Loading branch information
QuentinDuval authored and facebook-github-bot committed Jul 27, 2022
1 parent 0d465ce commit 4ca0ae3
Show file tree
Hide file tree
Showing 9 changed files with 373 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# @package _global_
config:
DATA:
NUM_DATALOADER_WORKERS: 5
TRAIN:
DATA_SOURCES: [disk_folder]
LABEL_SOURCES: [disk_folder]
DATASET_NAMES: [imagenet1k_folder]
BATCHSIZE_PER_REPLICA: 32
DATA_LIMIT: 1000
DATA_LIMIT_SAMPLING:
IS_BALANCED: True
TRANSFORMS:
- name: Resize
size: 256
- name: CenterCrop
size: 224
- name: ToTensor
- name: Normalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
MMAP_MODE: False
COPY_TO_LOCAL_DISK: False
COPY_DESTINATION_DIR: /tmp/imagenet1k/
TEST:
DATA_SOURCES: [disk_folder]
LABEL_SOURCES: [disk_folder]
DATASET_NAMES: [imagenet1k_folder]
BATCHSIZE_PER_REPLICA: 32
TRANSFORMS:
- name: Resize
size: 256
- name: CenterCrop
size: 224
- name: ToTensor
- name: Normalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
MMAP_MODE: False
COPY_TO_LOCAL_DISK: False
COPY_DESTINATION_DIR: /tmp/imagenet1k/
MODEL:
FEATURE_EVAL_SETTINGS:
EVAL_MODE_ON: True
FREEZE_TRUNK_AND_HEAD: True
EVAL_TRUNK_AND_HEAD: True
TRUNK:
NAME: resnet
RESNETS:
DEPTH: 50
HEAD:
PARAMS: [
["identity", {}],
]
WEIGHTS_INIT:
PARAMS_FILE: "specify the model weights"
STATE_DICT_KEY_NAME: classy_state_dict
LOGISTIC_REGRESSION:
LAMBDA: 0.1
FEATURES:
PATH: ''
DISTRIBUTED:
NUM_NODES: 1
NUM_PROC_PER_NODE: 8
MACHINE:
DEVICE: gpu
CHECKPOINT:
DIR: .
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# @package _global_
config:
MODEL:
TRUNK: # S-16
NAME: vision_transformer
VISION_TRANSFORMERS:
IMAGE_SIZE: 224
PATCH_SIZE: 16
HIDDEN_DIM: 384
NUM_LAYERS: 12
NUM_HEADS: 6
MLP_DIM: 1536
CLASSIFIER: token
DROPOUT_RATE: 0
ATTENTION_DROPOUT_RATE: 0
QKV_BIAS: True
DROP_PATH_RATE: 0.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# @package _global_
config:
MODEL:
FEATURE_EVAL_SETTINGS:
LINEAR_EVAL_FEAT_POOL_OPS_MAP: [
["concatCLS4", ["Identity", []] ],
["lastCLS", ["Identity", []] ],
["concatCLS4", ["Identity", []] ],
["lastCLS", ["Identity", []] ],
]
TRUNK: # Tiny
NAME: vision_transformer
VISION_TRANSFORMERS:
IMAGE_SIZE: 224
PATCH_SIZE: 16
NUM_LAYERS: 12
NUM_HEADS: 3
HIDDEN_DIM: 192
MLP_DIM: 768
CLASSIFIER: token
DROPOUT_RATE: 0
ATTENTION_DROPOUT_RATE: 0
QKV_BIAS: True
DROP_PATH_RATE: 0.0
HEAD:
PARAMS: [
["eval_mlp", {"in_channels": 768, "dims": [768, 1000]}],
["eval_mlp", {"in_channels": 192, "dims": [192, 1000]}],
["mlp", {"dims": [768, 1000]}],
["mlp", {"dims": [192, 1000]}],
]
OPTIMIZER:
regularize_bn: True
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# @package _global_
config:
MODEL:
TRUNK: # L-16
NAME: vision_transformer
VISION_TRANSFORMERS:
IMAGE_SIZE: 224
PATCH_SIZE: 16
NUM_LAYERS: 12
NUM_HEADS: 12
HIDDEN_DIM: 768
MLP_DIM: 3072
DROPOUT_RATE: 0.0
ATTENTION_DROPOUT_RATE: 0
CLASSIFIER: token
QKV_BIAS: True
DROP_PATH_RATE: 0.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# @package _global_
config:
MODEL:
TRUNK: # L-16
NAME: vision_transformer
VISION_TRANSFORMERS:
IMAGE_SIZE: 224
PATCH_SIZE: 16
NUM_LAYERS: 24
NUM_HEADS: 16
HIDDEN_DIM: 1024
MLP_DIM: 4096
DROPOUT_RATE: 0
ATTENTION_DROPOUT_RATE: 0
CLASSIFIER: token
QKV_BIAS: True
DROP_PATH_RATE: 0.0
37 changes: 37 additions & 0 deletions tools/low_shot_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys
from argparse import Namespace
from typing import Any, List

from vissl.config import AttrDict
from vissl.utils.hydra_config import compose_hydra_configuration, convert_to_attrdict
from vissl.utils.low_shot_utils import (
extract_features_and_low_shot,
extract_features_and_low_shot_on_slurm,
)
from vissl.utils.slurm import is_submitit_available


def main(args: Namespace, config: AttrDict):
if config.SLURM.USE_SLURM:
assert (
is_submitit_available()
), "Please 'pip install submitit' to schedule jobs on SLURM"
extract_features_and_low_shot_on_slurm(config)
else:
extract_features_and_low_shot(args.node_id, config)


def hydra_main(overrides: List[Any]):
cfg = compose_hydra_configuration(overrides)
args, config = convert_to_attrdict(cfg)
main(args, config)


if __name__ == "__main__":
overrides = sys.argv[1:]
hydra_main(overrides=overrides)
17 changes: 17 additions & 0 deletions vissl/config/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1552,6 +1552,23 @@ config:
# If empty, will run the extract features, if not, use the path to find the features
PATH: ""

# ----------------------------------------------------------------------------------- #
# LOGISTIC REGRESSION (benchmark)
# ----------------------------------------------------------------------------------- #
LOW_SHOT_BENCHMARK:
# Which evaluation method to pick to choose among:
# - logistic_regression
# - svm
METHOD: 'logistic_regression'
# Configuration if we use logistic regression for evaluation
LOGISTIC_REGRESSION:
# Regularisation term for L2 penalty
LAMBDA: 0.025
# Where to find the extracted features and whether or not we should extract them
FEATURES:
# If empty, will run the extract features, if not, use the path to find the features
PATH: ""

# ----------------------------------------------------------------------------------- #
# Geo Localization (benchmark)
# ----------------------------------------------------------------------------------- #
Expand Down
166 changes: 166 additions & 0 deletions vissl/utils/low_shot_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

from vissl.config import AttrDict
from vissl.hooks import default_hook_generator
from vissl.models.model_helpers import get_trunk_output_feature_names
from vissl.utils.checkpoint import get_checkpoint_folder
from vissl.utils.distributed_launcher import (
create_submitit_executor,
launch_distributed,
)
from vissl.utils.env import set_env_vars
from vissl.utils.extract_features_utils import ExtractedFeaturesLoader
from vissl.utils.hydra_config import print_cfg
from vissl.utils.logger import setup_logging, shutdown_logging
from vissl.utils.svm_utils.svm_trainer import SVMTrainer


def load_features(feature_dir: str, layer_name: str):
train_out = ExtractedFeaturesLoader.load_features(
feature_dir, "train", layer_name, flatten_features=True
)
train_features, train_labels = train_out["features"], train_out["targets"]
test_out = ExtractedFeaturesLoader.load_features(
feature_dir, "test", layer_name, flatten_features=True
)
test_features, test_labels = test_out["features"], test_out["targets"]
return train_features, train_labels, test_features, test_labels


def run_low_shot_logistic_regression(config: AttrDict, layer_name: str = "heads"):
"""
Run the Nearest Neighbour benchmark at the layer "layer_name"
"""
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# -- get train and test features
feature_dir = config.LOW_SHOT_BENCHMARK.FEATURES.PATH
train_features, train_labels, test_features, test_labels = load_features(
feature_dir, layer_name
)

# -- Scale the features based on training statistics
scaler = StandardScaler()
scaler.fit(train_features)
train_features = scaler.transform(train_features)
test_features = scaler.transform(test_features)

# -- Fit Logistic Regression Classifier
method_config = config.LOW_SHOT_BENCHMARK.LOGISTIC_REGRESSION
lambd = method_config.LAMBDA
lambd /= len(train_labels)
classifier = LogisticRegression(
penalty="l2",
C=1 / lambd,
multi_class="multinomial",
)
classifier.fit(
train_features,
train_labels,
)

# -- Evaluate on train and test set
train_score = classifier.score(train_features, train_labels)
print("Train score: ", train_score)
test_score = classifier.score(test_features, test_labels)
print("Test score: ", test_score)
return test_score


def run_low_shot_svm(config: AttrDict, layer_name: str = "heads"):
# setup the environment variables
set_env_vars(local_rank=0, node_id=0, cfg=config)

features_dir = config.LOW_SHOT_BENCHMARK.FEATURES.PATH
output_dir = get_checkpoint_folder(config)

# Train the svm
logging.info(f"Training SVM for layer: {layer_name}")
trainer = SVMTrainer(config.SVM, layer=layer_name, output_dir=output_dir)
train_data = ExtractedFeaturesLoader.load_features(
features_dir, "train", layer_name, flatten_features=True
)
trainer.train(train_data["features"], train_data["targets"])

# Test the svm
test_data = ExtractedFeaturesLoader.load_features(
features_dir, "test", layer_name, flatten_features=True
)
trainer.test(test_data["features"], test_data["targets"])
logging.info("All Done!")


def run_low_shot_all_layers(config: AttrDict):
"""
Get the names of the features that we are extracting. If user doesn't
specify the features to evaluate, we get the full model output and freeze
head/trunk both as caution.
"""
feat_names = get_trunk_output_feature_names(config.MODEL)
if len(feat_names) == 0:
feat_names = ["heads"]

for layer in feat_names:
if config.LOW_SHOT_BENCHMARK.METHOD == "logistic_regression":
top_1 = run_low_shot_logistic_regression(config, layer_name=layer)
else:
top_1 = run_low_shot_svm(config, layer_name=layer)
logging.info(f"layer: {layer}, Top1: {top_1}")


def extract_features_and_low_shot(node_id: int, config: AttrDict):
setup_logging(__name__)
print_cfg(config)
set_env_vars(local_rank=0, node_id=0, cfg=config)

# Extract the features if no path to the extract features is provided
if not config.LOW_SHOT_BENCHMARK.FEATURES.PATH:
launch_distributed(
config,
node_id,
engine_name="extract_features",
hook_generator=default_hook_generator,
)
config.LOW_SHOT_BENCHMARK.FEATURES.PATH = get_checkpoint_folder(config)

# Run KNN on all the extract features
run_low_shot_all_layers(config)

# close the logging streams including the file handlers
shutdown_logging()


class _ResumableLowShotSlurmJob:
def __init__(self, config: AttrDict):
self.config = config

def __call__(self):
import submitit

environment = submitit.JobEnvironment()
node_id = environment.global_rank
master_ip = environment.hostnames[0]
master_port = self.config.SLURM.PORT_ID
self.config.DISTRIBUTED.INIT_METHOD = "tcp"
self.config.DISTRIBUTED.RUN_ID = f"{master_ip}:{master_port}"
extract_features_and_low_shot(node_id=node_id, config=self.config)

def checkpoint(self):
import submitit

trainer = _ResumableLowShotSlurmJob(config=self.config)
return submitit.helpers.DelayedSubmission(trainer)


def extract_features_and_low_shot_on_slurm(cfg):
executor = create_submitit_executor(cfg)
trainer = _ResumableLowShotSlurmJob(config=cfg)
job = executor.submit(trainer)
print(f"SUBMITTED: {job.job_id}")
return job
2 changes: 1 addition & 1 deletion vissl/utils/svm_utils/svm_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _get_svm_model_filename(self, cls_num, cost):

def get_best_cost_value(self):
"""
During the SVM training, we write the cross vaildation
During the SVM training, we write the cross validation
AP value for training at each class and cost value
combination. We load the AP values and for each
class, determine the cost value that gives the maximum
Expand Down

0 comments on commit 4ca0ae3

Please sign in to comment.