
Commit a5310a3

Merge remote-tracking branch 'origin/benchmark-fixes-vit_hybrids' into pit_and_vit_update
2 parents 7953e5d + e2e3290 commit a5310a3

15 files changed: +1109 additions, -149 deletions

benchmark.py

Lines changed: 470 additions & 0 deletions
Large diffs are not rendered by default.

timm/data/config.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 _logger = logging.getLogger(__name__)
 
 
-def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=True):
+def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
     new_config = {}
     default_cfg = default_cfg
     if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
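
For context, a hedged usage sketch of the function touched above: resolve_data_config now logs nothing by default (verbose=False), so callers that want the old behaviour pass verbose=True explicitly. The model name and the timm.data import path are assumptions for illustration.

import timm
from timm.data import resolve_data_config  # assumed public import path

model = timm.create_model('vit_base_patch16_224', pretrained=False)  # any timm model with a default_cfg
# verbose now defaults to False; opt back in to the config logging if desired
data_config = resolve_data_config({}, model=model, verbose=True)
print(data_config['input_size'], data_config['interpolation'])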

timm/data/dataset.py

Lines changed: 2 additions & 1 deletion
@@ -73,12 +73,13 @@ def __init__(
             batch_size=None,
             class_map='',
             load_bytes=False,
+            repeats=0,
             transform=None,
     ):
         assert parser is not None
         if isinstance(parser, str):
             self.parser = create_parser(
-                parser, root=root, split=split, is_training=is_training, batch_size=batch_size)
+                parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
         else:
             self.parser = parser
         self.transform = transform

timm/data/dataset_factory.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ def create_dataset(name, root, split='validation', search_split=True, is_trainin
             root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
     else:
         # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
+        kwargs.pop('repeats', 0)  # FIXME currently only Iterable dataset support the repeat multiplier
         if search_split and os.path.isdir(root):
             root = _search_split(root, split)
         ds = ImageDataset(root, parser=name, **kwargs)
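
To illustrate how the new repeats plumbing is meant to be used, a hedged sketch follows; the 'tfds/' dataset name prefix, the root path, and the import path are assumptions not shown in this diff. For an iterable TFDS-backed dataset the multiplier is forwarded to the parser, while the ImageDataset path above pops and ignores it.

from timm.data import create_dataset  # assumed public import path

# Hypothetical TFDS dataset name; repeats=2 asks each pass over the split to yield it twice
ds = create_dataset(
    'tfds/imagenet2012', root='/data/tfds', split='train',
    is_training=True, batch_size=256, repeats=2)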

timm/data/parsers/parser_tfds.py

Lines changed: 35 additions & 15 deletions
@@ -29,6 +29,11 @@
 PREFETCH_SIZE = 4096  # samples to prefetch
 
 
+def even_split_indices(split, n, num_samples):
+    partitions = [round(i * num_samples / n) for i in range(n + 1)]
+    return [f"{split}[{partitions[i]}:{partitions[i+1]}]" for i in range(n)]
+
+
 class ParserTfds(Parser):
     """ Wrap Tensorflow Datasets for use in PyTorch
 
@@ -52,7 +57,7 @@ class ParserTfds(Parser):
     components.
 
     """
-    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None):
+    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0):
         super().__init__()
         self.root = root
         self.split = split
@@ -62,6 +67,8 @@ def __init__(self, root, name, split='train', shuffle=False, is_training=False,
             assert batch_size is not None,\
                 "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
         self.batch_size = batch_size
+        self.repeats = repeats
+        self.subsplit = None
 
         self.builder = tfds.builder(name, data_dir=root)
         # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
@@ -95,6 +102,7 @@ def _lazy_init(self):
         if worker_info is not None:
             self.worker_info = worker_info
             num_workers = worker_info.num_workers
+            global_num_workers = self.dist_num_replicas * num_workers
             worker_id = worker_info.id
 
             # FIXME I need to spend more time figuring out the best way to distribute/split data across
@@ -114,19 +122,31 @@
             #     split = split + '[{}:]'.format(start)
             # else:
             #     split = split + '[{}:{}]'.format(start, start + split_size)
-
-        input_context = tf.distribute.InputContext(
-            num_input_pipelines=self.dist_num_replicas * num_workers,
-            input_pipeline_id=self.dist_rank * num_workers + worker_id,
-            num_replicas_in_sync=self.dist_num_replicas  # FIXME does this have any impact?
-        )
-
-        read_config = tfds.ReadConfig(input_context=input_context)
-        ds = self.builder.as_dataset(split=split, shuffle_files=self.shuffle, read_config=read_config)
+            if not self.is_training and '[' not in self.split:
+                # If not training, and split doesn't define a subsplit, manually split the dataset
+                # for more even samples / worker
+                self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[
+                    self.dist_rank * num_workers + worker_id]
+
+        if self.subsplit is None:
+            input_context = tf.distribute.InputContext(
+                num_input_pipelines=self.dist_num_replicas * num_workers,
+                input_pipeline_id=self.dist_rank * num_workers + worker_id,
+                num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
+            )
+        else:
+            input_context = None
+
+        read_config = tfds.ReadConfig(
+            shuffle_seed=42,
+            shuffle_reshuffle_each_iteration=True,
+            input_context=input_context)
+        ds = self.builder.as_dataset(
+            split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
         # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
         ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
         ds.options().experimental_threading.max_intra_op_parallelism = 1
-        if self.is_training:
+        if self.is_training or self.repeats > 1:
             # to prevent excessive drop_last batch behaviour w/ IterableDatasets
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
@@ -143,7 +163,7 @@ def __iter__(self):
         # This adds extra samples and will slightly alter validation results.
         # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
         # batches are produced (underlying tfds iter wraps around)
-        target_sample_count = math.ceil(self.num_samples / self._num_pipelines)
+        target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines)
         if self.is_training:
             # round up to nearest batch_size per worker-replica
             target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
@@ -160,8 +180,8 @@
         if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
             # Validation batch padding only done for distributed training where results are reduced across nodes.
             # For single process case, it won't matter if workers return different batch sizes.
-            # FIXME this needs more testing, possible for sharding / split api to cause differences of > 1?
-            assert target_sample_count - sample_count == 1  # should only be off by 1 or sharding is not optimal
+            # FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this
+            # approach is not optimal
             yield img, sample['label']  # yield prev sample again
             sample_count += 1
 
@@ -176,7 +196,7 @@ def _num_pipelines(self):
     def __len__(self):
         # this is just an estimate and does not factor in extra samples added to pad batches based on
        # complete worker & replica info (not available until init in dataloader).
-        return math.ceil(self.num_samples / self.dist_num_replicas)
+        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)
 
     def _filename(self, index, basename=False, absolute=False):
         assert False, "Not supported"  # no random access to samples
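
To make the worker-sharding change above concrete, here is a self-contained sketch of what even_split_indices produces; the split name and sample count are made-up example values, everything else is copied from the diff.

def even_split_indices(split, n, num_samples):
    # partition num_samples into n contiguous, near-equal TFDS sub-split strings
    partitions = [round(i * num_samples / n) for i in range(n + 1)]
    return [f"{split}[{partitions[i]}:{partitions[i+1]}]" for i in range(n)]

# e.g. a 50000-image validation set split across 4 global workers (2 replicas x 2 loader workers)
print(even_split_indices('validation', 4, 50000))
# ['validation[0:12500]', 'validation[12500:25000]', 'validation[25000:37500]', 'validation[37500:50000]']
# Each worker then reads only its own sub-split, keeping per-worker sample counts even at validation time.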

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
 from .tresnet import *
 from .vgg import *
 from .vision_transformer import *
+from .vision_transformer_hybrid import *
 from .vovnet import *
 from .xception import *
 from .xception_aligned import *

timm/models/layers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -31,4 +31,4 @@
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
-from .weight_init import trunc_normal_
+from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_

timm/models/layers/weight_init.py

Lines changed: 29 additions & 0 deletions
@@ -2,6 +2,8 @@
 import math
 import warnings
 
+from torch.nn.init import _calculate_fan_in_and_fan_out
+
 
 def _no_grad_trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
@@ -58,3 +60,30 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
         >>> nn.init.trunc_normal_(w)
     """
     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+    if mode == 'fan_in':
+        denom = fan_in
+    elif mode == 'fan_out':
+        denom = fan_out
+    elif mode == 'fan_avg':
+        denom = (fan_in + fan_out) / 2
+
+    variance = scale / denom
+
+    if distribution == "truncated_normal":
+        # constant is stddev of standard normal truncated to (-2, 2)
+        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
+    elif distribution == "normal":
+        tensor.normal_(std=math.sqrt(variance))
+    elif distribution == "uniform":
+        bound = math.sqrt(3 * variance)
+        tensor.uniform_(-bound, bound)
+    else:
+        raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
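
A brief usage sketch for the new initializers; the tensor shape is arbitrary and the timm.models.layers import relies on the export added in layers/__init__.py above. Operating on a plain tensor keeps the in-place ops outside autograd.

import torch
from timm.models.layers import variance_scaling_, lecun_normal_  # assumed import path

w = torch.empty(3072, 768)  # e.g. an MLP weight, fan_in=768, fan_out=3072
lecun_normal_(w)            # fan_in scaling with a truncated-normal distribution
variance_scaling_(w, scale=1.0, mode='fan_avg', distribution='uniform')
print(w.std())              # roughly sqrt(scale / fan_avg) for the uniform variant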

timm/models/resnetv2.py

Lines changed: 5 additions & 2 deletions
@@ -274,7 +274,9 @@ def forward(self, x):
         return x
 
 
-def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, norm_layer=None):
+def create_resnetv2_stem(
+        in_chs, out_chs=64, stem_type='', preact=True,
+        conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
     stem = OrderedDict()
     assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
 
@@ -322,7 +324,8 @@ def __init__(self, layers, channels=(256, 512, 1024, 2048),
 
         self.feature_info = []
         stem_chs = make_div(stem_chs * wf)
-        self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
+        self.stem = create_resnetv2_stem(
+            in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
         stem_feat = ('stem.conv3' if 'deep' in stem_type else 'stem.conv') if preact else 'stem.norm'
         self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))
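
A minimal, hedged sketch of the renamed stem helper with its new defaults (out_chs=64, StdConv2d conv layer, 32-group GroupNormAct norm layer); the import path and the assumption that the helper returns a callable stem module are not shown in this diff.

import torch
from timm.models.resnetv2 import create_resnetv2_stem  # assumed import path

# Only in_chs is required now; conv/norm layers fall back to the new defaults
stem = create_resnetv2_stem(3, stem_type='deep', preact=False)
print(stem)  # inspect the named stem layers (the diff above references stem.conv3 / stem.norm names)
out = stem(torch.randn(1, 3, 224, 224))
print(out.shape)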
