`timm` is a deep-learning library created by Ross Wightman. It is a collection of state-of-the-art (SOTA) computer vision models, layers, utilities, optimizers, schedulers, data loaders, and augmentations. It also includes training and validation scripts that can reproduce ImageNet training results.

Install

pip install timm

Or, for an editable install:

git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .

How to use

Create a model

import timm
import torch

model = timm.create_model('resnet34')
x = torch.randn(1, 3, 224, 224)
model(x).shape
torch.Size([1, 1000])

That is all it takes to create a model with timm. `create_model` is a factory function: given a model name, it builds the corresponding architecture from the hundreds of models that are part of the timm library (the exact number depends on the installed version).

To create a pretrained model, simply pass in pretrained=True.

pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True) 
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to /home/tmabraham/.cache/torch/hub/checkpoints/resnet34-43635321.pth 
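The download log shows the weights being cached in a torch.hub directory. The sketch below approximates how that default location is chosen, assuming torch.hub's documented lookup order (`$TORCH_HOME/hub`, else `$XDG_CACHE_HOME/torch/hub`, else `~/.cache/torch/hub`, with checkpoints in a `checkpoints` subfolder). It ignores overrides made with `torch.hub.set_dir()`, and the function name `default_checkpoint_dir` is hypothetical; in a real session, `torch.hub.get_dir()` returns the hub directory directly.

```python
import os


def default_checkpoint_dir() -> str:
    """Approximate torch.hub's default checkpoint directory.

    Assumption: mirrors torch.hub's documented defaults
    ($TORCH_HOME, else $XDG_CACHE_HOME/torch, else ~/.cache/torch)
    and ignores any override made with torch.hub.set_dir().
    """
    torch_home = os.path.expanduser(
        os.getenv(
            "TORCH_HOME",
            os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"),
        )
    )
    return os.path.join(torch_home, "hub", "checkpoints")


print(default_checkpoint_dir())
```

Setting `TORCH_HOME` before the first download is therefore one way to move the cache, for example onto a larger disk. This applies to weights fetched through torch.hub, as in the log above.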

To create a model with a custom number of classes, simply pass in num_classes=<number_of_classes>.

import timm
import torch

model = timm.create_model('resnet34', num_classes=10)
x = torch.randn(1, 3, 224, 224)
model(x).shape
torch.Size([1, 10])
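Changing `num_classes` changes the size of the final classification layer, which is why the output shape becomes `[1, 10]`. As a rough worked example, assuming ResNet-34's classifier is a single fully connected layer over 512 pooled features (as in the standard ResNet-34), the head has `in_features × num_classes` weights plus one bias per class:

```python
def linear_head_params(in_features: int, num_classes: int) -> int:
    """Weights (in_features x num_classes) plus one bias per class."""
    return in_features * num_classes + num_classes


RESNET34_FEATURES = 512  # assumption: pooled feature width of ResNet-34

print(linear_head_params(RESNET34_FEATURES, 1000))  # 513000
print(linear_head_params(RESNET34_FEATURES, 10))    # 5130
```

The backbone is unaffected; only this last layer changes shape. When `pretrained=True` is combined with a different `num_classes`, the pretrained 1000-class head cannot be reused for the new layer, so the new head starts untrained and needs fine-tuning.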

List Models with Pretrained Weights

`timm.list_models()` returns the complete list of models available in timm. To list only the models that have pretrained weights, pass `pretrained=True` to `list_models`.

avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models), avail_pretrained_models[:5]
(592, ['adv_inception_v3', 'bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_224_in22k', 'beit_base_patch16_384'])

In this run, 592 models with pretrained weights are available in timm. The count changes between timm releases.

Search for model architectures by wildcard

It is also possible to search for model architectures using a wildcard pattern:

all_densenet_models = timm.list_models('*densenet*')
all_densenet_models
['densenet121', 'densenet121d', 'densenet161', 'densenet169', 'densenet201', 'densenet264', 'densenet264d_iabn', 'densenetblur121d', 'tv_densenet121']
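The pattern uses shell-style wildcards: `*` matches any run of characters. The sketch below assumes those semantics as provided by Python's `fnmatch` module and applies them to the names returned above, so you can try patterns without importing timm. The helper `filter_names` is hypothetical; with timm installed, the same pattern strings can be passed to `timm.list_models`.

```python
from fnmatch import fnmatchcase

# Names copied from the timm output above.
densenet_models = [
    'densenet121', 'densenet121d', 'densenet161', 'densenet169',
    'densenet201', 'densenet264', 'densenet264d_iabn',
    'densenetblur121d', 'tv_densenet121',
]


def filter_names(names, pattern):
    """Return the names matching a shell-style wildcard pattern, sorted."""
    return sorted(n for n in names if fnmatchcase(n, pattern))


# Only names that start with 'densenet2'.
print(filter_names(densenet_models, 'densenet2*'))
# '121' anywhere in the name, including prefixed variants like 'tv_'.
print(filter_names(densenet_models, '*121*'))
```

Note that a pattern without a leading `*` is anchored at the start of the name, so `'densenet*'` would not match `tv_densenet121`.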

Fine-tune a timm model in fastai

The fastai library has support for fine-tuning models from timm:

from fastai.vision.all import *

path = untar_data(URLs.PETS)/'images'
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2,
    label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))

# if a string is passed into the model argument, it will now use timm (if it is installed)
learn = vision_learner(dls, 'vit_tiny_patch16_224', metrics=error_rate)
learn.fine_tune(1)
epoch train_loss valid_loss error_rate time
0 0.201583 0.024980 0.006766 00:08
epoch train_loss valid_loss error_rate time
0 0.040622 0.024036 0.005413 00:10
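Two details of this example are easy to miss. First, `fine_tune(1)` trains in two phases: one epoch with the pretrained body frozen (only the new head is trained), then one epoch with the whole model unfrozen, which is why two result tables are printed. Second, `label_func` relies on a naming convention in the Oxford-IIIT Pet dataset: cat breed file names start with an uppercase letter and dog breed file names with a lowercase one, and `from_name_func` applies the function to the file name rather than the full path. The plain-Python check below uses hypothetical file names that follow that convention; `is_cat` is an illustrative name, not part of fastai.

```python
from pathlib import Path


def is_cat(name: str) -> bool:
    """Same rule as the lambda above: capitalized file names are cat breeds."""
    return name[0].isupper()


# Hypothetical file names following the Pets naming convention.
files = [
    Path("images/Abyssinian_1.jpg"), Path("images/beagle_12.jpg"),
    Path("images/Persian_7.jpg"), Path("images/pug_3.jpg"),
]

# Label by file name (Path.name), as from_name_func does.
labels = {f.name: is_cat(f.name) for f in files}
print(labels)

# Applied to the full path instead, the first character is the
# directory's 'i', so every image would wrongly be labeled a dog.
print(is_cat(str(files[0])))
```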