Commit 6ad4bbb

Refactoring changes for flexibility (#1)

* unhashed download
* added notebook version
* added reference to original repo and updated ROG
* removed extra install
* added fix to installs
* add pretrained model downloads and inference

Co-authored-by: jameshskelton <jameshskelton@gmail.com>

1 parent b0257e3 commit 6ad4bbb

File tree

4 files changed: +128 additions, -7 deletions

* README.md
* data2vec.ipynb
* datasets/imagenet/fetch_imagenet.sh
* installations.sh


README.md

Lines changed: 6 additions & 4 deletions
@@ -1,11 +1,12 @@
 # Data2Vec 2.0
 
-Data2Vec is self-supervised highly-efficient general framework to generate representations for vision, speech and text. This repository contains ready-to train [data2vec](https://github.com/facebookresearch/fairseq/tree/main/examples/data2vec) ([arXiv](https://arxiv.org/abs/2202.03555)) implementation containing helper scripts to load, process & train the data.
+[Check out the original repo!](https://github.com/ashutosh1919/data2vec-pytorch)
 
+Data2Vec is self-supervised highly-efficient general framework to generate representations for vision, speech and text. This repository contains ready-to train [data2vec](https://github.com/facebookresearch/fairseq/tree/main/examples/data2vec) ([arXiv](https://arxiv.org/abs/2202.03555)) implementation containing helper scripts to load, process & train the data.
 
 ## Run in a Free GPU powered Gradient Notebook
-[![Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://console.paperspace.com/github/ashutosh1919/data2vec-pytorch?machine=Free-GPU)
 
+[![Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://console.paperspace.com/github/gradient-ai/data2vec-pytorch?machine=Free-GPU)
 
 ## Setup
 
@@ -40,14 +41,14 @@ bash scripts/train_data2vec_multi_speech.sh
 
 Note that you may want to change some of the arguments in these task scripts based on your system. Since we have single GPU, the arg `distributed_training.distributed_world_size=1` for us which you can change based on your requirement.
 
-
 ## Original Code
 
 `data2vec` directory contains the original code taken from [fairseq](https://github.com/facebookresearch/fairseq/tree/main/examples/data2vec) repository. The code present in this directory is exactly same as the original code. We have only made changes in some of the config files corresponding to the tasks.
 
 ## Reference
 
 data2vec: A General Framework for Self-supervised Learning in Speech, Vision and Language -- https://arxiv.org/abs/2202.03555
+
 ```
 @article{DBLP:journals/corr/abs-2202-03555,
 author = {Alexei Baevski and
@@ -65,9 +66,10 @@ data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
 ```
 
 Efficient Self-supervised Learning with Contextualized Target Representations for Vision, Speech and Language -- https://arxiv.org/abs/2212.07525
+
 ```
 @misc{baevski2022efficient,
-title={Efficient Self-supervised Learning with Contextualized Target Representations for Vision, Speech and Language},
+title={Efficient Self-supervised Learning with Contextualized Target Representations for Vision, Speech and Language},
 author={Alexei Baevski and Arun Babu and Wei-Ning Hsu and Michael Auli},
 year={2022},
 eprint={2212.07525},
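
A note on the world-size remark in the README above: fairseq's Hydra-based trainer accepts dotted `key=value` config overrides on the command line, so `distributed_training.distributed_world_size` can be set per machine instead of edited inside the script. Below is a hedged sketch of picking it up automatically; the assumption that the task script forwards extra overrides through to the trainer is ours, not something this commit guarantees:

```python
import subprocess
import torch

# Use however many GPUs are actually visible on this machine; fall back to 1.
world_size = max(torch.cuda.device_count(), 1)

# Hypothetical invocation: assumes scripts/train_data2vec_multi_image.sh
# forwards extra Hydra-style overrides (key=value) to fairseq-hydra-train.
subprocess.run(
    [
        "bash",
        "scripts/train_data2vec_multi_image.sh",
        f"distributed_training.distributed_world_size={world_size}",
    ],
    check=True,
)
```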

data2vec.ipynb

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "source": [
+    "!bash installations.sh\n"
+   ],
+   "outputs": [],
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-03-28T21:17:09.717578Z",
+     "iopub.status.busy": "2023-03-28T21:17:09.716884Z"
+    }
+   }
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "# Get model checkpoints"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "source": [
+    "!mkdir models\n",
+    "%cd models\n",
+    "!wget https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_imagenet_ft.pt ### ViT-B Imagenet-1k finetuned\n",
+    "!wget https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_libri_960h.pt ### Librispeech finetuned 960 hour split\n",
+    "!wget https://dl.fbaipublicfiles.com/fairseq/data2vec2/nlp_base.pt ### Base NLP\n",
+    "%cd ../"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Get data & train\n",
+    "\n",
+    "If you want to train the models yourself, you can run the cell below.\n",
+    "\n",
+    "This will take a long time to run, and requires downloading large datasets."
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "source": [
+    "# Downloads ImageNet and starts training data2vec_multi with it.\n",
+    "!bash scripts/train_data2vec_multi_image.sh\n",
+    "\n",
+    "# Downloads OpenWebText and starts training data2vec_multi with it.\n",
+    "!bash scripts/train_data2vec_multi_text.sh\n",
+    "\n",
+    "# Downloads LibriSpeech and starts training data2vec_multi with it.\n",
+    "!bash scripts/train_data2vec_multi_speech.sh"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "# Checkpoints & Future usage"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "source": [
+    "import torch\n",
+    "from data2vec.models.data2vec2 import D2vModalitiesConfig\n",
+    "from data2vec.models.data2vec2 import Data2VecMultiConfig\n",
+    "from data2vec.models.data2vec2 import Data2VecMultiModel\n",
+    "from PIL import Image\n",
+    "CHECKPOINT_PATH = 'models/base_imagenet_ft.pt'\n",
+    "# Load checkpoint\n",
+    "ckpt = torch.load(CHECKPOINT_PATH)\n",
+    "\n",
+    "# Create config and load model\n",
+    "cfg = Data2VecMultiConfig()\n",
+    "model = Data2VecMultiModel(cfg, modalities=D2vModalitiesConfig.image)\n",
+    "model.load_state_dict(ckpt)\n",
+    "model.eval()\n",
+    "BATCHED_DATA_OBJECT = Image.open('assets/n01440764_tench.JPEG')\n",
+    "# Generating prediction from data\n",
+    "pred = model(BATCHED_DATA_OBJECT)"
+   ],
+   "outputs": [],
+   "metadata": {}
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.16"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
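
One caveat on the inference cell in this notebook: fairseq-style checkpoints are usually a wrapper dict (weights nested under a `model` key, alongside config and optimizer state) rather than a bare state dict, and a vision model expects a normalized tensor batch rather than a raw PIL image. The following is a minimal, hedged sketch of both steps; the `"model"` unwrapping key and the exact preprocessing (standard ImageNet statistics, 224x224 crop) are assumptions, not something this commit pins down:

```python
import torch
from PIL import Image
from torchvision import transforms

CHECKPOINT_PATH = "models/base_imagenet_ft.pt"

# fairseq checkpoints typically nest the weights under a "model" key;
# fall back to the loaded object if this file is already a plain state dict.
ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu")
state_dict = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt

# Assumed preprocessing for the ViT-B ImageNet fine-tune: resize,
# center-crop to 224x224, convert to tensor, normalize with ImageNet stats.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

image = Image.open("assets/n01440764_tench.JPEG").convert("RGB")
batch = preprocess(image).unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)
```

From there, `state_dict` is what `model.load_state_dict(...)` expects and `batch` is what the forward pass expects; if the checkpoint uses different key prefixes, `load_state_dict(state_dict, strict=False)` is the usual escape hatch.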

datasets/imagenet/fetch_imagenet.sh

Lines changed: 2 additions & 2 deletions
@@ -2,8 +2,8 @@
 
 # Run this script to fetch all the dataset
 
-# wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar
-# wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar
+wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar
+wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar
 
 train_tar="${1:-ILSVRC2012_img_train.tar}"
 val_tar="${2:-ILSVRC2012_img_val.tar}"

installations.sh

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1+cu116 torchtext==0.14.1 -f https://download.pytorch.org/whl/torch_stable.html
 
 # Installing lxml
-sudo apt-get install python-lxml
+sudo apt-get install python-lxml -y
 
 # Installing requirements.txt
 pip install -r requirements.txt
