diff --git a/configs/config/benchmark/low_shot_transfer/imagenet/logistic_regression_in1k.yaml b/configs/config/benchmark/low_shot_transfer/imagenet/logistic_regression_in1k.yaml
new file mode 100644
index 000000000..3cc0af819
--- /dev/null
+++ b/configs/config/benchmark/low_shot_transfer/imagenet/logistic_regression_in1k.yaml
@@ -0,0 +1,68 @@
+# @package _global_
+config:
+  DATA:
+    NUM_DATALOADER_WORKERS: 5
+    TRAIN:
+      DATA_SOURCES: [disk_folder]
+      LABEL_SOURCES: [disk_folder]
+      DATASET_NAMES: [imagenet1k_folder]
+      BATCHSIZE_PER_REPLICA: 32
+      DATA_LIMIT: 1000
+      DATA_LIMIT_SAMPLING:
+        IS_BALANCED: True
+      TRANSFORMS:
+        - name: Resize
+          size: 256
+        - name: CenterCrop
+          size: 224
+        - name: ToTensor
+        - name: Normalize
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+      MMAP_MODE: False
+      COPY_TO_LOCAL_DISK: False
+      COPY_DESTINATION_DIR: /tmp/imagenet1k/
+    TEST:
+      DATA_SOURCES: [disk_folder]
+      LABEL_SOURCES: [disk_folder]
+      DATASET_NAMES: [imagenet1k_folder]
+      BATCHSIZE_PER_REPLICA: 32
+      TRANSFORMS:
+        - name: Resize
+          size: 256
+        - name: CenterCrop
+          size: 224
+        - name: ToTensor
+        - name: Normalize
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+      MMAP_MODE: False
+      COPY_TO_LOCAL_DISK: False
+      COPY_DESTINATION_DIR: /tmp/imagenet1k/
+  MODEL:
+    FEATURE_EVAL_SETTINGS:
+      EVAL_MODE_ON: True
+      FREEZE_TRUNK_AND_HEAD: True
+      EVAL_TRUNK_AND_HEAD: True
+    TRUNK:
+      NAME: resnet
+      RESNETS:
+        DEPTH: 50
+    HEAD:
+      PARAMS: [
+        ["identity", {}],
+      ]
+    WEIGHTS_INIT:
+      PARAMS_FILE: "specify the model weights"
+      STATE_DICT_KEY_NAME: classy_state_dict
+  LOGISTIC_REGRESSION:
+    LAMBDA: 0.1
+  FEATURES:
+    PATH: ''
+  DISTRIBUTED:
+    NUM_NODES: 1
+    NUM_PROC_PER_NODE: 8
+  MACHINE:
+    DEVICE: gpu
+  CHECKPOINT:
+    DIR: .
diff --git a/configs/config/benchmark/low_shot_transfer/imagenet/models/deit_s16.yaml b/configs/config/benchmark/low_shot_transfer/imagenet/models/deit_s16.yaml
new file mode 100644
index 000000000..91d643d90
--- /dev/null
+++ b/configs/config/benchmark/low_shot_transfer/imagenet/models/deit_s16.yaml
@@ -0,0 +1,17 @@
+# @package _global_
+config:
+  MODEL:
+    TRUNK: # S-16
+      NAME: vision_transformer
+      VISION_TRANSFORMERS:
+        IMAGE_SIZE: 224
+        PATCH_SIZE: 16
+        HIDDEN_DIM: 384
+        NUM_LAYERS: 12
+        NUM_HEADS: 6
+        MLP_DIM: 1536
+        CLASSIFIER: token
+        DROPOUT_RATE: 0
+        ATTENTION_DROPOUT_RATE: 0
+        QKV_BIAS: True
+        DROP_PATH_RATE: 0.0
diff --git a/configs/config/benchmark/low_shot_transfer/imagenet/models/deit_tiny.yaml b/configs/config/benchmark/low_shot_transfer/imagenet/models/deit_tiny.yaml
new file mode 100644
index 000000000..dfb39b72b
--- /dev/null
+++ b/configs/config/benchmark/low_shot_transfer/imagenet/models/deit_tiny.yaml
@@ -0,0 +1,33 @@
+# @package _global_
+config:
+  MODEL:
+    FEATURE_EVAL_SETTINGS:
+      LINEAR_EVAL_FEAT_POOL_OPS_MAP: [
+        ["concatCLS4", ["Identity", []] ],
+        ["lastCLS", ["Identity", []] ],
+        ["concatCLS4", ["Identity", []] ],
+        ["lastCLS", ["Identity", []] ],
+      ]
+    TRUNK: # Tiny
+      NAME: vision_transformer
+      VISION_TRANSFORMERS:
+        IMAGE_SIZE: 224
+        PATCH_SIZE: 16
+        NUM_LAYERS: 12
+        NUM_HEADS: 3
+        HIDDEN_DIM: 192
+        MLP_DIM: 768
+        CLASSIFIER: token
+        DROPOUT_RATE: 0
+        ATTENTION_DROPOUT_RATE: 0
+        QKV_BIAS: True
+        DROP_PATH_RATE: 0.0
+    HEAD:
+      PARAMS: [
+        ["eval_mlp", {"in_channels": 768, "dims": [768, 1000]}],
+        ["eval_mlp", {"in_channels": 192, "dims": [192, 1000]}],
+        ["mlp", {"dims": [768, 1000]}],
+        ["mlp", {"dims": [192, 1000]}],
+      ]
+  OPTIMIZER:
+    regularize_bn: True
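Note on the DeiT-Tiny config above: the four heads pair up with the four pooled features, alternating "concatCLS4" (class token concatenated over the last four blocks) and "lastCLS" (final class token only). A minimal Python sketch of the dimension arithmetic, assuming that reading of the op names (illustration only, not VISSL code):

    # Relates HIDDEN_DIM: 192 to the head input sizes in the config above.
    hidden_dim = 192                  # DeiT-Tiny HIDDEN_DIM
    concat_cls4_dim = 4 * hidden_dim  # "concatCLS4" -> 768, matches in_channels: 768
    last_cls_dim = hidden_dim         # "lastCLS" -> 192, matches in_channels: 192
    assert (concat_cls4_dim, last_cls_dim) == (768, 192)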
diff --git a/configs/config/benchmark/low_shot_transfer/imagenet/models/vit_b16.yaml b/configs/config/benchmark/low_shot_transfer/imagenet/models/vit_b16.yaml
new file mode 100644
index 000000000..f22ce000a
--- /dev/null
+++ b/configs/config/benchmark/low_shot_transfer/imagenet/models/vit_b16.yaml
@@ -0,0 +1,17 @@
+# @package _global_
+config:
+  MODEL:
+    TRUNK: # B-16
+      NAME: vision_transformer
+      VISION_TRANSFORMERS:
+        IMAGE_SIZE: 224
+        PATCH_SIZE: 16
+        NUM_LAYERS: 12
+        NUM_HEADS: 12
+        HIDDEN_DIM: 768
+        MLP_DIM: 3072
+        DROPOUT_RATE: 0.0
+        ATTENTION_DROPOUT_RATE: 0
+        CLASSIFIER: token
+        QKV_BIAS: True
+        DROP_PATH_RATE: 0.0
diff --git a/configs/config/benchmark/low_shot_transfer/imagenet/models/vit_l16.yaml b/configs/config/benchmark/low_shot_transfer/imagenet/models/vit_l16.yaml
new file mode 100644
index 000000000..91ab9436b
--- /dev/null
+++ b/configs/config/benchmark/low_shot_transfer/imagenet/models/vit_l16.yaml
@@ -0,0 +1,17 @@
+# @package _global_
+config:
+  MODEL:
+    TRUNK: # L-16
+      NAME: vision_transformer
+      VISION_TRANSFORMERS:
+        IMAGE_SIZE: 224
+        PATCH_SIZE: 16
+        NUM_LAYERS: 24
+        NUM_HEADS: 16
+        HIDDEN_DIM: 1024
+        MLP_DIM: 4096
+        DROPOUT_RATE: 0
+        ATTENTION_DROPOUT_RATE: 0
+        CLASSIFIER: token
+        QKV_BIAS: True
+        DROP_PATH_RATE: 0.0
diff --git a/tools/low_shot_test.py b/tools/low_shot_test.py
new file mode 100644
index 000000000..75d655bb5
--- /dev/null
+++ b/tools/low_shot_test.py
@@ -0,0 +1,37 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+from argparse import Namespace
+from typing import Any, List
+
+from vissl.config import AttrDict
+from vissl.utils.hydra_config import compose_hydra_configuration, convert_to_attrdict
+from vissl.utils.low_shot_utils import (
+    extract_features_and_low_shot,
+    extract_features_and_low_shot_on_slurm,
+)
+from vissl.utils.slurm import is_submitit_available
+
+
+def main(args: Namespace, config: AttrDict):
+    if config.SLURM.USE_SLURM:
+        assert (
+            is_submitit_available()
+        ), "Please 'pip install submitit' to schedule jobs on SLURM"
+        extract_features_and_low_shot_on_slurm(config)
+    else:
+        extract_features_and_low_shot(args.node_id, config)
+
+
+def hydra_main(overrides: List[Any]):
+    cfg = compose_hydra_configuration(overrides)
+    args, config = convert_to_attrdict(cfg)
+    main(args, config)
+
+
+if __name__ == "__main__":
+    overrides = sys.argv[1:]
+    hydra_main(overrides=overrides)
diff --git a/vissl/config/defaults.yaml b/vissl/config/defaults.yaml
index 88aa718da..eeab29059 100644
--- a/vissl/config/defaults.yaml
+++ b/vissl/config/defaults.yaml
@@ -1552,6 +1552,23 @@ config:
       # If empty, will run the extract features, if not, use the path to find the features
       PATH: ""
 
+  # ----------------------------------------------------------------------------------- #
+  # LOW SHOT (benchmark)
+  # ----------------------------------------------------------------------------------- #
+  LOW_SHOT_BENCHMARK:
+    # Which evaluation method to use, chosen among:
+    # - logistic_regression
+    # - svm
+    METHOD: 'logistic_regression'
+    # Configuration used if logistic regression is the evaluation method
+    LOGISTIC_REGRESSION:
+      # Regularisation term for the L2 penalty
+      LAMBDA: 0.025
+    # Where to find the extracted features and whether or not we should extract them
+    FEATURES:
+      # If empty, will run the extract features, if not, use the path to find the features
+      PATH: ""
+
   # ----------------------------------------------------------------------------------- #
   # Geo Localization (benchmark)
   # ----------------------------------------------------------------------------------- #
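The LOW_SHOT_BENCHMARK defaults above drive the utilities added below. A rough, self-contained sketch of the behaviour they imply (a plain dict stands in for VISSL's AttrDict; the helper function is hypothetical):

    # Stand-in for config.LOW_SHOT_BENCHMARK, populated with the defaults above.
    low_shot = {
        "METHOD": "logistic_regression",           # or "svm"
        "LOGISTIC_REGRESSION": {"LAMBDA": 0.025},  # L2 penalty weight
        "FEATURES": {"PATH": ""},                  # empty -> extract features first
    }

    def describe_low_shot_run(benchmark: dict) -> str:
        # Hypothetical helper: summarises what the settings would do.
        step = "extract features" if not benchmark["FEATURES"]["PATH"] else "reuse features"
        return f"{step}, then evaluate with {benchmark['METHOD']}"

    print(describe_low_shot_run(low_shot))  # extract features, then evaluate with logistic_regression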
diff --git a/vissl/utils/low_shot_utils.py b/vissl/utils/low_shot_utils.py
new file mode 100644
index 000000000..8bcb424f8
--- /dev/null
+++ b/vissl/utils/low_shot_utils.py
@@ -0,0 +1,166 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+from vissl.config import AttrDict
+from vissl.hooks import default_hook_generator
+from vissl.models.model_helpers import get_trunk_output_feature_names
+from vissl.utils.checkpoint import get_checkpoint_folder
+from vissl.utils.distributed_launcher import (
+    create_submitit_executor,
+    launch_distributed,
+)
+from vissl.utils.env import set_env_vars
+from vissl.utils.extract_features_utils import ExtractedFeaturesLoader
+from vissl.utils.hydra_config import print_cfg
+from vissl.utils.logger import setup_logging, shutdown_logging
+from vissl.utils.svm_utils.svm_trainer import SVMTrainer
+
+
+def load_features(feature_dir: str, layer_name: str):
+    train_out = ExtractedFeaturesLoader.load_features(
+        feature_dir, "train", layer_name, flatten_features=True
+    )
+    train_features, train_labels = train_out["features"], train_out["targets"]
+    test_out = ExtractedFeaturesLoader.load_features(
+        feature_dir, "test", layer_name, flatten_features=True
+    )
+    test_features, test_labels = test_out["features"], test_out["targets"]
+    return train_features, train_labels, test_features, test_labels
+
+
+def run_low_shot_logistic_regression(config: AttrDict, layer_name: str = "heads"):
+    """
+    Run the low-shot logistic regression benchmark at the layer "layer_name"
+    """
+    from sklearn.linear_model import LogisticRegression
+    from sklearn.preprocessing import StandardScaler
+
+    # -- get train and test features
+    feature_dir = config.LOW_SHOT_BENCHMARK.FEATURES.PATH
+    train_features, train_labels, test_features, test_labels = load_features(
+        feature_dir, layer_name
+    )
+
+    # -- Scale the features based on training statistics
+    scaler = StandardScaler()
+    scaler.fit(train_features)
+    train_features = scaler.transform(train_features)
+    test_features = scaler.transform(test_features)
+
+    # -- Fit Logistic Regression Classifier
+    method_config = config.LOW_SHOT_BENCHMARK.LOGISTIC_REGRESSION
+    lambd = method_config.LAMBDA
+    lambd /= len(train_labels)
+    classifier = LogisticRegression(
+        penalty="l2",
+        C=1 / lambd,
+        multi_class="multinomial",
+    )
+    classifier.fit(
+        train_features,
+        train_labels,
+    )
+
+    # -- Evaluate on train and test set
+    train_score = classifier.score(train_features, train_labels)
+    print("Train score: ", train_score)
+    test_score = classifier.score(test_features, test_labels)
+    print("Test score: ", test_score)
+    return test_score
+
+
+def run_low_shot_svm(config: AttrDict, layer_name: str = "heads"):
+    # setup the environment variables
+    set_env_vars(local_rank=0, node_id=0, cfg=config)
+
+    features_dir = config.LOW_SHOT_BENCHMARK.FEATURES.PATH
+    output_dir = get_checkpoint_folder(config)
+
+    # Train the svm
+    logging.info(f"Training SVM for layer: {layer_name}")
+    trainer = SVMTrainer(config.SVM, layer=layer_name, output_dir=output_dir)
+    train_data = ExtractedFeaturesLoader.load_features(
+        features_dir, "train", layer_name, flatten_features=True
+    )
+    trainer.train(train_data["features"], train_data["targets"])
+
+    # Test the svm
+    test_data = ExtractedFeaturesLoader.load_features(
+        features_dir, "test", layer_name, flatten_features=True
+    )
+    trainer.test(test_data["features"], test_data["targets"])
+    logging.info("All Done!")
+
+
+def run_low_shot_all_layers(config: AttrDict):
+    """
+    Get the names of the features that we are extracting. If the user doesn't
+    specify the features to evaluate, we use the full model output and freeze
+    both head and trunk as a caution.
+    """
+    feat_names = get_trunk_output_feature_names(config.MODEL)
+    if len(feat_names) == 0:
+        feat_names = ["heads"]
+
+    for layer in feat_names:
+        if config.LOW_SHOT_BENCHMARK.METHOD == "logistic_regression":
+            top_1 = run_low_shot_logistic_regression(config, layer_name=layer)
+        else:
+            top_1 = run_low_shot_svm(config, layer_name=layer)
+        logging.info(f"layer: {layer}, Top1: {top_1}")
+
+
+def extract_features_and_low_shot(node_id: int, config: AttrDict):
+    setup_logging(__name__)
+    print_cfg(config)
+    set_env_vars(local_rank=0, node_id=0, cfg=config)
+
+    # Extract the features if no path to pre-extracted features is provided
+    if not config.LOW_SHOT_BENCHMARK.FEATURES.PATH:
+        launch_distributed(
+            config,
+            node_id,
+            engine_name="extract_features",
+            hook_generator=default_hook_generator,
+        )
+        config.LOW_SHOT_BENCHMARK.FEATURES.PATH = get_checkpoint_folder(config)
+
+    # Run the low-shot evaluation on all the extracted features
+    run_low_shot_all_layers(config)
+
+    # close the logging streams including the file handlers
+    shutdown_logging()
+
+
+class _ResumableLowShotSlurmJob:
+    def __init__(self, config: AttrDict):
+        self.config = config
+
+    def __call__(self):
+        import submitit
+
+        environment = submitit.JobEnvironment()
+        node_id = environment.global_rank
+        master_ip = environment.hostnames[0]
+        master_port = self.config.SLURM.PORT_ID
+        self.config.DISTRIBUTED.INIT_METHOD = "tcp"
+        self.config.DISTRIBUTED.RUN_ID = f"{master_ip}:{master_port}"
+        extract_features_and_low_shot(node_id=node_id, config=self.config)
+
+    def checkpoint(self):
+        import submitit
+
+        trainer = _ResumableLowShotSlurmJob(config=self.config)
+        return submitit.helpers.DelayedSubmission(trainer)
+
+
+def extract_features_and_low_shot_on_slurm(cfg):
+    executor = create_submitit_executor(cfg)
+    trainer = _ResumableLowShotSlurmJob(config=cfg)
+    job = executor.submit(trainer)
+    print(f"SUBMITTED: {job.job_id}")
+    return job
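A worked example of the regularisation handling in run_low_shot_logistic_regression above: scikit-learn's LogisticRegression weights the L2 penalty by 1/C against a sum (not a mean) of per-sample losses, so the code divides LAMBDA by the number of training samples before inverting it. Assuming the default LAMBDA of 0.025 and a hypothetical 1000-sample low-shot training set:

    # Sketch of the C value passed to sklearn; the sample count is an assumption.
    lambd = 0.025                        # LOW_SHOT_BENCHMARK.LOGISTIC_REGRESSION.LAMBDA
    n_train = 1000                       # e.g. DATA_LIMIT: 1000 in the in1k config
    lambd_per_sample = lambd / n_train   # 2.5e-05
    C = 1 / lambd_per_sample             # 40000.0
    print(C)                             # 40000.0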
diff --git a/vissl/utils/svm_utils/svm_trainer.py b/vissl/utils/svm_utils/svm_trainer.py
index 7ec0fd0d2..b224454a5 100644
--- a/vissl/utils/svm_utils/svm_trainer.py
+++ b/vissl/utils/svm_utils/svm_trainer.py
@@ -96,7 +96,7 @@ def _get_svm_model_filename(self, cls_num, cost):
 
     def get_best_cost_value(self):
         """
-        During the SVM training, we write the cross vaildation
+        During the SVM training, we write the cross validation
         AP value for training at each class and cost value
         combination. We load the AP values and for each class,
         determine the cost value that gives the maximum
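Finally, a hedged usage sketch for the new tools/low_shot_test.py entry point. The config path matches the benchmark file added in this patch; the checkpoint path is a placeholder to adapt:

    # Drives the tool programmatically; equivalent to passing the same
    # Hydra overrides on the command line.
    from tools.low_shot_test import hydra_main

    hydra_main(overrides=[
        "config=benchmark/low_shot_transfer/imagenet/logistic_regression_in1k",
        "config.MODEL.WEIGHTS_INIT.PARAMS_FILE=/path/to/checkpoint.torch",
    ])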