33import tarfile
44import traceback
55import uuid
6- from typing import Dict , List , Optional , Tuple
6+ from types import TracebackType
7+ from typing import Dict , List , Optional , Tuple , Any , Union
78
89import pandas as pd
910
@@ -23,24 +24,24 @@ def __init__(
2324 self ,
2425 filesystem : FileSystem ,
2526 destination_dir : str ,
26- keys_mapping : Optional [Dict [str , str ]] = None ,
27- max_files_in_shard : Optional [ int ] = 1000 ,
28- datafiles_ext : Optional [ str ] = "csv" ,
29- archives_ext : Optional [ str ] = "tar" ,
27+ keys_to_rename : Optional [Dict [str , str ]] = None ,
28+ max_files_in_shard : int = 1000 ,
29+ datafiles_ext : str = "csv" ,
30+ archives_ext : str = "tar" ,
3031 filenaming : str = "counter"
3132 ) -> None :
3233 self .filesystem = filesystem
3334 self .destination_dir = destination_dir
34- self .keys_mapping = keys_mapping
35+ self .keys_to_rename = keys_to_rename
3536 self .max_files_in_shard = max_files_in_shard
3637 self .datafiles_ext = "." + datafiles_ext .lstrip ("." )
3738 self .archives_ext = "." + archives_ext .lstrip ("." )
3839 self .filenaming = filenaming
3940 assert self .filenaming in ["counter" , "uuid" ], "Invalid files naming"
4041
41- self .df_raw = []
42+ self .df_raw : List [ Dict [ str , Any ]] = []
4243 self .tar_bytes = io .BytesIO ()
43- self .tar = None
44+ self .tar : tarfile . TarFile = None # type: ignore
4445 self .shard_index , self .last_file_index = self ._init_writer_from_last_uploaded_file ()
4546
4647 def save_sample (
@@ -61,8 +62,8 @@ def save_sample(
6162 img_tar_info , fp = self ._prepare_image_for_tar_format (file_bytes , filename )
6263 self .tar .addfile (img_tar_info , fp )
6364
64- if self .keys_mapping :
65- table_data = rename_dict_keys (table_data , self .keys_mapping )
65+ if self .keys_to_rename :
66+ table_data = rename_dict_keys (table_data , self .keys_to_rename )
6667
6768 self .df_raw .append (table_data )
6869 self ._try_close_batch ()
@@ -76,20 +77,20 @@ def _prepare_image_for_tar_format(
7677 img_tar_info .size = len (fp .getvalue ())
7778 return img_tar_info , fp
7879
def __enter__(self) -> "ShardsWriter":  # noqa: F821
    """Enter the ``with`` block and hand back this writer unchanged.

    Returns:
        The same instance, so it can be bound with ``as``.
    """
    return self
8182
def __exit__(
    self,
    exception_type: Union[type[BaseException], None],
    exception_value: Union[BaseException, None],
    exception_traceback: Union[TracebackType, None],
) -> None:
    """Leave the ``with`` block, flushing any rows still buffered.

    If ``self.df_raw`` holds rows, they are flushed under the name returned
    by ``self._calculate_current_tarname()``. The method returns ``None``,
    so an exception raised inside the ``with`` body is not suppressed.

    Args:
        exception_type: Class of the exception raised in the block, or None.
        exception_value: The exception instance, or None.
        exception_traceback: Traceback of the exception, or None.
    """
    if self.df_raw:
        self._flush(self._calculate_current_tarname())
    # NOTE(review): the flattened source does not show whether this reset
    # was nested under the `if` above; it is assumed to be unconditional.
    # Confirm against the original file.
    self.last_file_index = 0
9192
92- def _init_writer_from_last_uploaded_file (self ) -> ( int , int ) :
93+ def _init_writer_from_last_uploaded_file (self ) -> Tuple [ int , int ] :
9394 self .filesystem .mkdir (self .destination_dir )
9495 list_csv = [
9596 int (os .path .basename (filename [: - len (self .datafiles_ext )]))
@@ -124,9 +125,12 @@ def _calculate_current_tarname(self) -> str:
def get_current_filename(self, extension: str) -> str:
    """Build the name of the next file for the configured naming mode.

    Args:
        extension: File extension, with or without leading dots.

    Returns:
        ``"<last_file_index>.<extension>"`` when ``self.filenaming`` is
        ``"counter"``, or ``"<uuid4 hex>.<extension>"`` when it is ``"uuid"``.

    Raises:
        ValueError: If ``self.filenaming`` is neither ``"counter"`` nor
            ``"uuid"``.
    """
    # Strip leading dots so that ".csv" and "csv" give the same result.
    extension = extension.lstrip('.')
    if self.filenaming == "counter":
        return f"{self.last_file_index}.{extension}"
    if self.filenaming == "uuid":
        return f"{uuid.uuid4().hex}.{extension}"
    raise ValueError(f"Invalid filenaming type: {self.filenaming}")
130134
131135 def _try_close_batch (self ) -> None :
132136 old_tarname = self ._calculate_current_tarname ()
@@ -154,7 +158,7 @@ def _flush_and_upload_tar(self, filename: str) -> None:
154158 self .filesystem .save_file (
155159 self .tar_bytes , self .filesystem .join (self .destination_dir , filename ), binary = True
156160 )
157- self .tar = None
161+ self .tar = None # type: ignore
158162 self .tar_bytes = io .BytesIO ()
159163
160164 def _flush (self , tarname : str ) -> None :
0 commit comments