
Commit 24ab586

refactor: refactor dataloaders
1 parent 4beb9c0 commit 24ab586

File tree

3 files changed: +20 -20 lines

Lines changed: 4 additions & 4 deletions
@@ -1,12 +1,12 @@
-from typing import Any, Dict
+from typing import Any, Dict, Union
 
 
 # TODO(review) - the logic is completely unclear: what is this method needed for? Needs an explanation + refactoring (looks like something unnecessary)
 # default identical preprocessing function for FilesDataset and ShardsDataset
-def identical_preprocess_function(column2bytes: Dict[str, bytes], data: Dict[str, str]) -> Any:
-    return column2bytes, data
+def identical_preprocess_function(modality2data: Dict[str, Union[bytes, Any]], data: Dict[str, str]) -> Any:
+    return modality2data, data
 
 
 # identical collate function for pytorch dataloader
-def identical_collate_fn(x):
+def identical_collate_fn(x: Any) -> Any:
     return x
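
For context on the widened signature: the first argument is now typed Dict[str, Union[bytes, Any]], so a preprocess function may receive either raw bytes or already-decoded data per modality. A minimal sketch of a custom replacement for identical_preprocess_function (the "image" modality key and the PIL-based decoding are illustrative assumptions, not part of this commit):

import io
from typing import Any, Dict, Union

from PIL import Image


def decode_image_preprocess(modality2data: Dict[str, Union[bytes, Any]], data: Dict[str, str]) -> Any:
    # Hypothetical example: decode the "image" modality when it arrives as raw bytes,
    # pass already-decoded data and the meta columns through unchanged.
    image = modality2data["image"]
    if isinstance(image, bytes):
        image = Image.open(io.BytesIO(image))
    return image, data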

DPF/dataloaders/files_dataset.py

Lines changed: 6 additions & 6 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import pandas as pd
 from torch.utils.data import Dataset
@@ -8,7 +8,7 @@
 from DPF.filesystems.filesystem import FileSystem
 
 
-class FilesDataset(Dataset):
+class FilesDataset(Dataset[Tuple[bool, Any]]):
     """
     Dataset class to read "raw" files
     """
@@ -19,7 +19,7 @@ def __init__(
         df: pd.DataFrame,
         datatypes: List[Union[ShardedDataType, FileDataType, ColumnDataType]],
         meta_columns: Optional[List[str]] = None,
-        preprocess_function: Callable[[Dict[str, bytes], Dict[str, str]], Any] = identical_preprocess_function,
+        preprocess_function: Callable[[Dict[str, Union[bytes, Any]], Dict[str, str]], Any] = identical_preprocess_function,
         # TODO(review) - on error, an exception should be raised instead of returning None, and that error should then be handled by the application code using this class
         return_none_on_error: bool = False
     ):
@@ -50,7 +50,7 @@ def __init__(
         self.column2modality = {}
         for d in self.datatypes:
             if isinstance(d, ColumnDataType):
-                self.column2modality[d.modality.column] = d.modality.key
+                self.column2modality[d.column_name] = d.modality.key
             elif isinstance(d, (ShardedDataType, FileDataType)):
                 self.path_column2modality[d.modality.path_column] = d.modality.key
             else:
@@ -64,10 +64,10 @@ def __init__(
         self.preprocess_f = preprocess_function
         self.return_none_on_error = return_none_on_error
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.data_to_iterate)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> Tuple[bool, Any]:
         data = {
             self.columns[c]: item for c, item in enumerate(self.data_to_iterate[idx])
         }
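
The Dataset[Tuple[bool, Any]] parameterization and the typed __getitem__ make explicit that every item is an (is_ok, payload) pair. A hypothetical usage sketch with a PyTorch DataLoader, assuming a FilesDataset instance named dataset has already been constructed and identical_collate_fn is imported (the batch size, worker count and process() handler are likewise illustrative):

from torch.utils.data import DataLoader

# identical_collate_fn returns the list of (is_ok, payload) tuples unchanged instead of
# trying to stack payloads of mixed modalities into tensors.
loader = DataLoader(dataset, batch_size=16, num_workers=2, collate_fn=identical_collate_fn)

for batch in loader:
    for is_ok, payload in batch:
        if not is_ok:
            continue  # the underlying file failed to read (return_none_on_error=True)
        process(payload)  # hypothetical downstream handler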

DPF/dataloaders/shards_dataset.py

Lines changed: 10 additions & 10 deletions
@@ -1,7 +1,7 @@
 import itertools
 import os
 import tarfile
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
 
 import pandas as pd
 import torch
@@ -12,7 +12,7 @@
 from DPF.filesystems.filesystem import FileSystem
 
 
-class ShardsDataset(IterableDataset):
+class ShardsDataset(IterableDataset[Tuple[bool, Any]]):
     """
     Dataset class for shards format (files in tar archives)
     """
@@ -24,7 +24,7 @@ def __init__(
         split2archive_path: Dict[str, str],
         datatypes: List[Union[ShardedDataType, ColumnDataType]],
         meta_columns: Optional[List[str]] = None,
-        preprocess_function: Callable[[Dict[str, bytes], Dict[str, str]], Any] = identical_preprocess_function,
+        preprocess_function: Callable[[Dict[str, Union[bytes, Any]], Dict[str, str]], Any] = identical_preprocess_function,
         return_none_on_error: bool = False
     ):
         """
@@ -44,9 +44,9 @@ def __init__(
            Preprocessing function for data. First argument of the preprocess_f is mapping from modality name to bytes
            and the second argument is mapping from meta_column name to its value.
        return_none_on_error: bool = False
-           Whether to return None if error during reading file occures
+           Whether to return None if error during reading file occurs
        """
-        super(ShardsDataset).__init__()
+        super().__init__()
         self.filesystem = filesystem
 
         self.datatypes = datatypes
@@ -57,7 +57,7 @@ def __init__(
         self.column2modality = {}
         for d in self.datatypes:
             if isinstance(d, ColumnDataType):
-                self.column2modality[d.modality.column] = d.modality.key
+                self.column2modality[d.column_name] = d.modality.key
             elif isinstance(d, ShardedDataType):
                 self.path_column2modality[d.modality.path_column] = d.modality.key
             else:
@@ -76,10 +76,10 @@ def __init__(
         self.preprocess_f = preprocess_function
         self.return_none_on_error = return_none_on_error
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.total_samples
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Tuple[bool, Any]]:
         worker_info = torch.utils.data.get_worker_info()
         worker_total_num = worker_info.num_workers if worker_info is not None else None
         worker_id = worker_info.id if worker_info is not None else None
@@ -100,12 +100,12 @@ def __iter__(self):
                 filename = os.path.basename(data[col])
                 if self.return_none_on_error:
                     try:
-                        file_bytes = tar.extractfile(filename).read()
+                        file_bytes = tar.extractfile(filename).read()  # type: ignore
                     except Exception:
                         file_bytes = None
                         is_ok = False
                 else:
-                    file_bytes = tar.extractfile(filename).read()
+                    file_bytes = tar.extractfile(filename).read()  # type: ignore
                 modality2data[modality] = file_bytes
             # read data from columns
             for col in self.column2modality.keys():
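
The added # type: ignore comments are needed because tarfile.TarFile.extractfile() is typed as returning Optional[IO[bytes]]: it returns None for members that are not regular files or links. A small sketch of an explicit check that would satisfy the type checker without the ignore, shown only as an illustration rather than as part of this commit:

import tarfile


def read_member_bytes(tar: tarfile.TarFile, filename: str) -> bytes:
    # Hypothetical helper: read one member of an open tar archive as bytes.
    member = tar.extractfile(filename)
    if member is None:
        # extractfile() returns None for non-regular members (e.g. directories)
        raise FileNotFoundError(f"{filename} is not a readable member of the archive")
    return member.read()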
