11import itertools
22import os
33import tarfile
4- from typing import Any , Callable , Dict , List , Optional , Union
4+ from typing import Any , Callable , Dict , Iterator , List , Optional , Tuple , Union
55
66import pandas as pd
77import torch
1212from DPF .filesystems .filesystem import FileSystem
1313
1414
15- class ShardsDataset (IterableDataset ):
15+ class ShardsDataset (IterableDataset [ Tuple [ bool , Any ]] ):
1616 """
1717 Dataset class for shards format (files in tar archives)
1818 """
@@ -24,7 +24,7 @@ def __init__(
2424 split2archive_path : Dict [str , str ],
2525 datatypes : List [Union [ShardedDataType , ColumnDataType ]],
2626 meta_columns : Optional [List [str ]] = None ,
27- preprocess_function : Callable [[Dict [str , bytes ], Dict [str , str ]], Any ] = identical_preprocess_function ,
27+ preprocess_function : Callable [[Dict [str , Union [ bytes , Any ] ], Dict [str , str ]], Any ] = identical_preprocess_function ,
2828 return_none_on_error : bool = False
2929 ):
3030 """
@@ -44,9 +44,9 @@ def __init__(
4444 Preprocessing function for data. The first argument of preprocess_f is a mapping from modality name to bytes,
4545 and the second argument is a mapping from meta_column name to its value.
4646 return_none_on_error: bool = False
47- Whether to return None if error during reading file occures
47+ Whether to return None if error during reading file occurs
4848 """
49- super (ShardsDataset ).__init__ ()
49+ super ().__init__ ()
5050 self .filesystem = filesystem
5151
5252 self .datatypes = datatypes
@@ -57,7 +57,7 @@ def __init__(
5757 self .column2modality = {}
5858 for d in self .datatypes :
5959 if isinstance (d , ColumnDataType ):
60- self .column2modality [d .modality . column ] = d .modality .key
60+ self .column2modality [d .column_name ] = d .modality .key
6161 elif isinstance (d , ShardedDataType ):
6262 self .path_column2modality [d .modality .path_column ] = d .modality .key
6363 else :
@@ -76,10 +76,10 @@ def __init__(
7676 self .preprocess_f = preprocess_function
7777 self .return_none_on_error = return_none_on_error
7878
79- def __len__ (self ):
79+ def __len__ (self ) -> int :
8080 return self .total_samples
8181
82- def __iter__ (self ):
82+ def __iter__ (self ) -> Iterator [ Tuple [ bool , Any ]] :
8383 worker_info = torch .utils .data .get_worker_info ()
8484 worker_total_num = worker_info .num_workers if worker_info is not None else None
8585 worker_id = worker_info .id if worker_info is not None else None
@@ -100,12 +100,12 @@ def __iter__(self):
100100 filename = os .path .basename (data [col ])
101101 if self .return_none_on_error :
102102 try :
103- file_bytes = tar .extractfile (filename ).read ()
103+ file_bytes = tar .extractfile (filename ).read () # type: ignore
104104 except Exception :
105105 file_bytes = None
106106 is_ok = False
107107 else :
108- file_bytes = tar .extractfile (filename ).read ()
108+ file_bytes = tar .extractfile (filename ).read () # type: ignore
109109 modality2data [modality ] = file_bytes
110110 # read data from columns
111111 for col in self .column2modality .keys ():
0 commit comments