
Commit 538b8a8: jit transition (#53)

1 parent: 50622cb


44 files changed: +1630 additions, −483 deletions

README.md

Lines changed: 40 additions & 15 deletions
````diff
@@ -20,6 +20,12 @@ work on feature requests or accept external contributions, unless they were
 pre-approved (ask in an issue first). For a well-supported transfer-only
 codebase, see also [vision_transformer](https://github.com/google-research/vision_transformer).
 
+Note that `big_vision` is quite a dynamic codebase and, while we intend to keep
+the core code fully functional at all times, we cannot guarantee timely updates
+of the project-specific code that lives in the `.../proj/...` subfolders.
+However, we provide a [table](#project-specific-commits) with the last known
+commits at which specific projects were known to work.
+
 The following research projects were originally conducted in the `big_vision`
 codebase:
@@ -51,6 +57,13 @@ codebase:
   Kornblith*, Xiaohua Zhai*, Matthias Minderer*, Michael Tschannen*, Ibrahim
   Alabdulmohsin*, Filip Pavetic*\
   Resources: [readme](big_vision/configs/proj/flexivit/README.md), [configs](big_vision/configs/proj/flexivit).
+- [Dual PatchNorm](https://arxiv.org/abs/2302.01327), by Manoj Kumar, Mostafa Dehghani, Neil Houlsby.
+- [Getting ViT in Shape: Scaling Laws for Compute-Optimal Model Design](https://arxiv.org/abs/2305.13035), by
+  Ibrahim Alabdulmohsin*, Xiaohua Zhai*, Alexander Kolesnikov, Lucas Beyer*.
+- (partial) [Scaling Vision Transformers to 22 Billion Parameters](https://arxiv.org/abs/2302.05442), by
+  Mostafa Dehghani*, Josip Djolonga*, Basil Mustafa*, Piotr Padlewski*, Jonathan Heek*, *wow many middle authors*, Neil Houlsby*.
+- (partial) [Finite Scalar Quantization: VQ-VAE Made Simple](https://arxiv.org/abs/2309.15505), by
+  Fabian Mentzer, David Minnen, Eirikur Agustsson, Michael Tschannen.
 
 ### Multimodal research
 
@@ -64,21 +77,28 @@ codebase:
 - [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343), by
   Xiaohua Zhai*, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer*\
   Resources: [colab and models](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/SigLIP_demo.ipynb), code TODO.
+- [A Study of Autoregressive Decoders for Multi-Tasking in Computer Vision](https://arxiv.org/abs/2303.17376), by
+  Lucas Beyer*, Bo Wan*, Gagan Madan*, Filip Pavetic*, Andreas Steiner*, Alexander Kolesnikov, André Susano Pinto, Emanuele Bugliarello, Xiao Wang, Qihang Yu, Liang-Chieh Chen, Xiaohua Zhai*.
+- [Image Captioners Are Scalable Vision Learners Too](https://arxiv.org/abs/2306.07915), by
+  Michael Tschannen*, Manoj Kumar*, Andreas Steiner*, Xiaohua Zhai, Neil Houlsby, Lucas Beyer*.
+- [Three Towers: Flexible Contrastive Learning with Pretrained Image Models](https://arxiv.org/abs/2305.16999), by Jannik Kossen, Mark Collier, Basil Mustafa, Xiao Wang, Xiaohua Zhai, Lucas Beyer, Andreas Steiner, Jesse Berent, Rodolphe Jenatton, Efi Kokiopoulou.
+- (partial) [PaLI: A Jointly-Scaled Multilingual Language-Image Model](https://arxiv.org/abs/2209.06794), by Xi Chen, Xiao Wang, Soravit Changpinyo, *wow so many middle authors*, Anelia Angelova, Xiaohua Zhai, Neil Houlsby, Radu Soricut.
+- (partial) [PaLI-3 Vision Language Models: Smaller, Faster, Stronger](https://arxiv.org/abs/2310.09199), by Xi Chen, Xiao Wang, Lucas Beyer, Alexander Kolesnikov, Jialin Wu, Paul Voigtlaender, Basil Mustafa, Sebastian Goodman, Ibrahim Alabdulmohsin, Piotr Padlewski, Daniel Salz, Xi Xiong, Daniel Vlasic, Filip Pavetic, Keran Rong, Tianli Yu, Daniel Keysers, Xiaohua Zhai, Radu Soricut.
 
-### Knowledge distillation
+### Training
 
 - [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237), by
   Lucas Beyer*, Xiaohua Zhai*, Amélie Royer*, Larisa Markeeva*, Rohan Anil,
   and Alexander Kolesnikov*\
   Resources: [README](big_vision/configs/proj/distill/README.md), [trainer](big_vision/trainers/proj/distill/distill.py), [colab](https://colab.research.google.com/drive/1nMykzUzsfQ_uAxfj3k35DYsATnG_knPl?usp=sharing).
-
-### Training
-
 - [Sharpness-Aware Minimization for Efficiently Improving Generalization](https://arxiv.org/abs/2010.01412), by
   Pierre Foret, Ariel Kleiner, Hossein Mobahi, Behnam Neyshabur
-
 - [Surrogate Gap Minimization Improves Sharpness-Aware Training](https://arxiv.org/abs/2203.08065), by Juntang Zhuang, Boqing Gong, Liangzhe Yuan, Yin Cui, Hartwig Adam, Nicha Dvornek, Sekhar Tatikonda, James Duncan and Ting Liu \
   Resources: [trainer](big_vision/trainers/proj/gsam/gsam.py), [config](big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py) [reproduced results](https://github.com/google-research/big_vision/pull/8#pullrequestreview-1078557411)
+- [Tuning computer vision models with task rewards](https://arxiv.org/abs/2302.08242), by
+  André Susano Pinto*, Alexander Kolesnikov*, Yuge Shi, Lucas Beyer, Xiaohua Zhai.
+- (partial) [VeLO: Training Versatile Learned Optimizers by Scaling Up](https://arxiv.org/abs/2211.09760) by
+  Luke Metz, James Harrison, C. Daniel Freeman, Amil Merchant, Lucas Beyer, James Bradbury, Naman Agrawal, Ben Poole, Igor Mordatch, Adam Roberts, Jascha Sohl-Dickstein.
 
 ### Misc
 
@@ -118,7 +138,7 @@ details, but generally speaking, running on a GPU machine involves calling
 `python -m COMMAND` while running on TPUs, including multi-host, involves
 
 ```
-gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all
+gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all
   --command "bash big_vision/run_tpu.sh COMMAND"
 ```
 
@@ -273,7 +293,7 @@ The following command line will create TPU VMs with 32 cores,
 4 hosts.
 
 ```
-gcloud alpha compute tpus tpu-vm create $NAME --zone $ZONE --accelerator-type v3-32 --version v2-tf-stable
+gcloud compute tpus tpu-vm create $NAME --zone $ZONE --accelerator-type v3-32 --version tpu-ubuntu2204-base
 ```
 
 ## Install `big_vision` on TPU VMs
@@ -283,8 +303,8 @@ dependencies.
 
 ```
 git clone https://github.com/google-research/big_vision
-gcloud alpha compute tpus tpu-vm scp --recurse big_vision/big_vision $NAME: --zone=$ZONE --worker=all
-gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "bash big_vision/run_tpu.sh"
+gcloud compute tpus tpu-vm scp --recurse big_vision/big_vision $NAME: --zone=$ZONE --worker=all
+gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "bash big_vision/run_tpu.sh"
 ```
 
 ## Download and prepare TFDS datasets
@@ -298,13 +318,13 @@ Specifically, the seven TFDS datasets used during evaluations will be generated
 under `~/tensorflow_datasets` on TPU machine with this command:
 
 ```
-gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=0 --command "TFDS_DATA_DIR=~/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.tools.download_tfds_datasets cifar10 cifar100 oxford_iiit_pet oxford_flowers102 cars196 dtd uc_merced"
+gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=0 --command "TFDS_DATA_DIR=~/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.tools.download_tfds_datasets cifar10 cifar100 oxford_iiit_pet oxford_flowers102 cars196 dtd uc_merced"
 ```
 
 You can then copy the datasets to GS bucket, to make them accessible to all TPU workers.
 
 ```
-gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=0 --command "rm -r ~/tensorflow_datasets/downloads && gsutil cp -r ~/tensorflow_datasets gs://$GS_BUCKET_NAME"
+gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=0 --command "rm -r ~/tensorflow_datasets/downloads && gsutil cp -r ~/tensorflow_datasets gs://$GS_BUCKET_NAME"
 ```
 
 If you want to integrate other public or custom datasets, i.e. imagenet2012,
@@ -322,23 +342,28 @@ The following command line fine-tunes a pre-trained `vit-i21k-augreg-b/32` model
 on `cifar10` dataset.
 
 ```
-gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/transfer.py:model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03"
+gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/transfer.py:model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03"
 ```
 
+## Checkpointing on cloud
+
+In the past, we recommended writing checkpoints to a Google Cloud bucket. With
+the latest update, this is very slow because of technical issues with the
+checkpointing format. We are working on a solution, but in the meantime we have
+updated our instructions to write checkpoints to a local folder on the TPU
+machine. Don't forget to copy useful checkpoints elsewhere after training.
+
 ## Run the train script on TPU VMs
 
 To train your own big_vision models on a large dataset,
 e.g. `imagenet2012` ([prepare the TFDS dataset](https://www.tensorflow.org/datasets/catalog/imagenet2012)),
 run the following command line.
 
 ```
-gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/bit_i1k.py --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'`"
+gcloud compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/bit_i1k.py --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'`"
 ```
 
 ## Sometimes useful gcloud commands
 
-- Destroy the TPU machines: `gcloud alpha compute tpus tpu-vm delete $NAME --zone $ZONE`
-- Remove all big_vision-related folders on all hosts: `gcloud alpha compute tpus tpu-vm ssh $NAME --zone $ZONE --worker=all --command 'rm -rf ~/big_vision ~/bv_venv'`
+- Destroy the TPU machines: `gcloud compute tpus tpu-vm delete $NAME --zone $ZONE`
+- Remove all big_vision-related folders on all hosts: `gcloud compute tpus tpu-vm ssh $NAME --zone $ZONE --worker=all --command 'rm -rf ~/big_vision ~/bv_venv'`
 
 # ViT baseline
 
````
big_vision/configs/bit_i1k.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 Big Vision Authors.
+# Copyright 2023 Big Vision Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```

big_vision/configs/bit_i21k.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 Big Vision Authors.
+# Copyright 2023 Big Vision Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```

big_vision/configs/common.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 Big Vision Authors.
+# Copyright 2023 Big Vision Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -130,9 +130,16 @@ def autotype(x):
 
 def pack_arg(**kw):
   """Packs key-word args as a string to be parsed by `parse_arg()`."""
+  for v in kw.values():
+    assert ',' not in f'{v}', f"Can't use `,` in config_arg value: {v}"
   return ','.join([f'{k}={v}' for k, v in kw.items()])
 
 
+def arg(**kw):
+  """Use like `add(**bvcc.arg(res=256, foo=bar), lr=0.1)` to pass config_arg."""
+  return {'config_arg': pack_arg(**kw), **kw}
+
+
 def _get_field_ref(config_dict, field_name):
   path = field_name.split('.')
   for field in path[:-1]:
```
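To make the behavior of these two helpers concrete, here is a self-contained sketch; the function bodies are copied from the diff so it runs standalone, and the example values are illustrative:

```python
# Self-contained sketch of pack_arg/arg from the diff; example values are illustrative.
def pack_arg(**kw):
  """Packs keyword args as a string to be parsed by `parse_arg()`."""
  for v in kw.values():
    assert ',' not in f'{v}', f"Can't use `,` in config_arg value: {v}"
  return ','.join([f'{k}={v}' for k, v in kw.items()])


def arg(**kw):
  """Returns kw unchanged, plus a packed 'config_arg' entry."""
  return {'config_arg': pack_arg(**kw), **kw}


print(pack_arg(res=256, mode='eval'))  # -> res=256,mode=eval
print(arg(res=256))                    # -> {'config_arg': 'res=256', 'res': 256}
# The new assert rejects values that would corrupt the packed format:
# pack_arg(datasets='a,b')  # AssertionError: Can't use `,` in config_arg value
```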

big_vision/configs/common_fewshot.py

Lines changed: 19 additions & 11 deletions
```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 Big Vision Authors.
+# Copyright 2023 Big Vision Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,12 +18,20 @@
 
 
 def get_fewshot_lsr(target_resolution=224, resize_resolution=256,
-                    runlocal=False):
+                    runlocal=False, **kw):
   """Returns a standard-ish fewshot eval configuration."""
-  config = mlc.ConfigDict()
+  kw.setdefault('representation_layer', 'pre_logits')
+  kw.setdefault('shots', (1, 5, 10, 25))
+  kw.setdefault('l2_reg', 2.0 ** 10)
+  kw.setdefault('num_seeds', 3)
+  kw.setdefault('prefix', '')  # No prefix as we already use a/ z/ and zz/
+
+  # Backward-compatible default:
+  if not any(f'log_{x}' in kw for x in ['steps', 'percent', 'examples', 'epochs']):  # pylint: disable=line-too-long
+    kw['log_steps'] = 25_000
+
+  config = mlc.ConfigDict(kw)
   config.type = 'fewshot_lsr'
-  config.representation_layer = 'pre_logits'
-  config.log_steps = 25_000
   config.datasets = {
       'caltech': ('caltech101', 'train', 'test'),  # copybara:srtip
       'cars': ('cars196:2.1.0', 'train', 'test'),
@@ -37,12 +45,12 @@ def get_fewshot_lsr(target_resolution=224, resize_resolution=256,
   } if not runlocal else {
       'pets': ('oxford_iiit_pet', 'train', 'test'),
   }
-  config.pp_train = f'decode|resize({resize_resolution})|central_crop({target_resolution})|value_range(-1,1)|keep("image", "label")'
-  config.pp_eval = f'decode|resize({resize_resolution})|central_crop({target_resolution})|value_range(-1,1)|keep("image", "label")'
-  config.shots = (1, 5, 10, 25)
-  config.l2_reg = 2.0 ** 10
-  config.num_seeds = 3
+  config.pp_train = (f'decode|resize({resize_resolution})|'
+                     f'central_crop({target_resolution})|'
+                     f'value_range(-1,1)|keep("image", "label")')
+  config.pp_eval = (f'decode|resize({resize_resolution})|'
+                    f'central_crop({target_resolution})|'
+                    f'value_range(-1,1)|keep("image", "label")')
   config.display_first = [('imagenet', 10)] if not runlocal else [('pets', 10)]
-  config.prefix = ''  # No prefix as we do already prefix with a/ z/ and zz/
 
   return config
```
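With the `**kw` refactor, every previously hard-coded field becomes overridable at the call site; a hypothetical usage sketch (the override values are illustrative, not from the commit):

```python
# Hypothetical usage of the refactored helper; override values are illustrative.
from big_vision.configs.common_fewshot import get_fewshot_lsr

# Defaults preserved: representation_layer='pre_logits', shots=(1, 5, 10, 25),
# l2_reg=2**10, num_seeds=3, and log_steps=25_000 (backward-compatible).
fewshot = get_fewshot_lsr()

# Quick experiment: fewer shots/seeds, and log every 5% of training instead of
# every 25k steps (passing any log_* key suppresses the log_steps default).
fewshot_fast = get_fewshot_lsr(shots=(1, 10), num_seeds=1, log_percent=0.05)
```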

big_vision/configs/load_and_eval.py

Lines changed: 20 additions & 19 deletions
```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 Big Vision Authors.
+# Copyright 2023 Big Vision Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,25 +36,35 @@
 
 import big_vision.configs.common as bvcc
 from big_vision.configs.common_fewshot import get_fewshot_lsr
-from big_vision.configs.proj.image_text import lit_eval
-import ml_collections as mlc
+# from big_vision.configs.proj.image_text import lit_eval
+
+
+def eval_only(config, batch_size, spec_for_init):
+  """Set a few configs that turn trainer into (almost) eval-only."""
+  config.total_steps = 0
+  config.input = {}
+  config.input.batch_size = batch_size
+  config.input.data = dict(name='bv:dummy', spec=spec_for_init)
+  config.optax_name = 'identity'
+  config.lr = 0.0
+  return config
 
 
 def get_config(arg='name=bit_paper,batch_size=2'):
-  arg = bvcc.parse_arg(arg, name='', batch_size=2)
-  config = mlc.ConfigDict()
-  config.batch_size_eval = arg.batch_size
+  config = bvcc.parse_arg(arg, name='', batch_size=2)
+
+  # Make the config eval-only by setting some dummies.
+  eval_only(config, config.batch_size, spec_for_init=dict(
+      image=dict(shape=(224, 224, 3), dtype='float32'),
+  ))
 
   # Just calls the function with the name given as `config`.
   # Could also be a giant if-block if you're into that kind of thing.
-  globals()[arg.name](config)
+  globals()[config.name](config)
   return config
 
 
 def bit_paper(config):
-  # We could omit init_{shapes,types} if we wanted, as they are the default.
-  config.init_shapes = [(1, 224, 224, 3)]
-  config.init_types = ['float32']
   config.num_classes = 1000
 
   config.model_name = 'bit_paper'
@@ -82,9 +92,6 @@ def get_eval(split, lbl, dataset='imagenet2012_real'):
 
 
 def vit_i1k(config):
-  # We could omit init_{shapes,types} if we wanted, as they are the default.
-  config.init_shapes = [(1, 224, 224, 3)]
-  config.init_types = ['float32']
   config.num_classes = 1000
 
   config.model_name = 'vit'
@@ -104,9 +111,6 @@ def vit_i1k(config):
 
 
 def mlp_mixer_i1k(config):
-  # We could omit init_{shapes,types} if we wanted, as they are the default.
-  config.init_shapes = [(1, 224, 224, 3)]
-  config.init_types = ['float32']
   config.num_classes = 1000
 
   config.model_name = 'mlp_mixer'
@@ -125,9 +129,6 @@ def mlp_mixer_i1k(config):
 
 
 def vit_i21k(config):
-  # We could omit init_{shapes,types} if we wanted, as they are the default.
-  config.init_shapes = [(1, 224, 224, 3)]
-  config.init_types = ['float32']
   config.num_classes = 21843
 
   config.model_name = 'vit'
```
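The key change here is that the result of `bvcc.parse_arg` now serves as the config object itself, and `eval_only` stubs out the training parts; a minimal sketch of the new flow, assuming the post-commit module layout (the argument values are illustrative):

```python
# Minimal sketch of the new eval-only flow; argument values are illustrative.
import big_vision.configs.common as bvcc
from big_vision.configs.load_and_eval import eval_only

# parse_arg's result now *is* the config (no separate mlc.ConfigDict()).
config = bvcc.parse_arg('name=vit_i1k,batch_size=2', name='', batch_size=2)
eval_only(config, config.batch_size, spec_for_init=dict(
    image=dict(shape=(224, 224, 3), dtype='float32')))

assert config.total_steps == 0          # no training steps
assert config.optax_name == 'identity'  # no-op optimizer
assert config.input.batch_size == 2     # dummy 'bv:dummy' input pipeline
```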

big_vision/configs/mlp_mixer_i1k.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 Big Vision Authors.
+# Copyright 2023 Big Vision Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```

big_vision/configs/transfer.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 Big Vision Authors.
+# Copyright 2023 Big Vision Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -148,11 +148,11 @@ def _set_imagenet_variants(config, h_res=448, l_res=384):
 def get_config(arg=None):
   """Config for adaptation."""
   arg = bvcc.parse_arg(arg, model='vit', dataset='cifar10', crop='resmall_crop',
-                       h_res=448, l_res=384, runlocal=False)
+                       h_res=448, l_res=384, batch_size=512, runlocal=False)
   config = mlc.ConfigDict()
 
   config.input = {}
-  config.input.batch_size = 512 if not arg.runlocal else 8
+  config.input.batch_size = arg.batch_size if not arg.runlocal else 8
   config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 100
 
   config.log_training_steps = 10
```
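Since `batch_size` is now a `parse_arg` field rather than a hard-coded 512, it can be overridden from the config string; a hypothetical sketch of how the override flows through (the specific values and the asserts are illustrative assumptions, not from the commit):

```python
# Hypothetical: the new batch_size arrives through the config-string arg.
import big_vision.configs.common as bvcc

arg = bvcc.parse_arg('model=vit-i21k-augreg-b/32,dataset=cifar10,batch_size=256',
                     model='vit', dataset='cifar10', crop='resmall_crop',
                     h_res=448, l_res=384, batch_size=512, runlocal=False)
assert arg.batch_size == 256  # string override beats the 512 default
assert arg.h_res == 448       # unspecified fields keep their defaults
```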

big_vision/configs/vit_i1k.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 Big Vision Authors.
+# Copyright 2023 Big Vision Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,6 +15,9 @@
 # pylint: disable=line-too-long
 r"""Pre-training ViT on ILSVRC-2012 as in https://arxiv.org/abs/2106.10270
 
+This config does NOT include regularization (dropout, stochastic depth), which
+was shown to help with B/32, B/16, L/16 models in the paper (Figure 4).
+
 This configuration makes use of the "arg" to get_config to select which model
 to run, so a few examples are given below:
 
```
big_vision/configs/vit_i21k.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 Big Vision Authors.
+# Copyright 2023 Big Vision Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```

big_vision/configs/vit_s16_i1k.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 Big Vision Authors.
+# Copyright 2023 Big Vision Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```

big_vision/datasets/core.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 Big Vision Authors.
+# Copyright 2023 Big Vision Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@
 class DataSource:
   """The API that any data source should implement."""
 
-  def get_tfdata(self, ordered):
+  def get_tfdata(self, ordered, *, process_split=True):
     """Creates this data object as a tf.data.Dataset.
 
     This will be called separately in each process, and it is up to the dataset
@@ -30,6 +30,8 @@ def get_tfdata(self, ordered):
     Args:
       ordered: if True, the dataset should use deterministic ordering, if False
        it may have undefined ordering. Think of True == val, False == train.
+      process_split: if False then every process receives the entire dataset
+        (e.g. for evaluators running in a single process).
 
     Returns:
       A tf.data.Dataset object.
```
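A hypothetical subclass honoring the new keyword-only flag might look like this; the TFDS-backed source and its split-sharding logic are illustrative, not from the commit:

```python
# Hypothetical DataSource implementation honoring the new process_split flag.
import jax
import tensorflow_datasets as tfds


class TfdsSource:
  """Illustrative TFDS-backed data source (dataset must already be prepared)."""

  def __init__(self, name, split):
    self.builder = tfds.builder(name)
    self.split = split

  def get_tfdata(self, ordered, *, process_split=True):
    if process_split:
      # Shard the split so each process reads a disjoint slice (training,
      # multi-process evaluators).
      split = tfds.even_splits(
          self.split, jax.process_count())[jax.process_index()]
    else:
      # Every process receives the entire dataset (single-process evaluators).
      split = self.split
    return self.builder.as_dataset(split=split, shuffle_files=not ordered)
```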

big_vision/datasets/imagenet/class_names.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 Big Vision Authors.
+# Copyright 2023 Big Vision Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```
