Skip to content

Commit 475beac

Browse files
committed
refactor: refactor writers & processor utils
1 parent 881d6a7 commit 475beac

File tree

6 files changed

+58
-48
lines changed

6 files changed

+58
-48
lines changed

DPF/processors/helpers/dataframe_helper.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Optional
1+
from typing import Dict, List, Optional, Any
22

33
import pandas as pd
44
from tqdm.contrib.concurrent import thread_map
@@ -19,15 +19,15 @@ def __init__(
1919
self.filesystem = filesystem
2020
self.config = config
2121

22-
def _save_dataframe(self, df, path, **kwargs) -> Optional[str]:
22+
def _save_dataframe(self, df: pd.DataFrame, path: str, **kwargs) -> Optional[str]: # type: ignore
2323
errname = None
2424
try:
2525
self.filesystem.save_dataframe(df, path, **kwargs)
2626
except Exception as err:
2727
errname = f"Error during saving file {path}: {err}"
2828
return errname
2929

30-
def validate_path_for_delete(self, columns_to_delete: List[str], path: str):
30+
def validate_path_for_delete(self, columns_to_delete: List[str], path: str) -> None:
3131
df = self.filesystem.read_dataframe(path)
3232
for col in columns_to_delete:
3333
assert col in df.columns, f'Dataframe {path} dont have "{col}" column'
@@ -59,7 +59,7 @@ def delete_columns(
5959
)
6060
return [err for err in errors if err is not None]
6161

62-
def validate_path_for_rename(self, column_map: Dict[str, str], path: str):
62+
def validate_path_for_rename(self, column_map: Dict[str, str], path: str) -> None:
6363
df = self.filesystem.read_dataframe(path)
6464
for col_old, col_new in column_map.items():
6565
assert col_old in df.columns, f'Dataframe {path} dont have "{col_old}" column'
@@ -95,33 +95,32 @@ def rename_columns(
9595
def validate_path_for_update(
9696
self,
9797
key_column: str,
98-
df_new: List[dict],
98+
df_new: List[Dict[str, Any]],
9999
path: str
100-
):
100+
) -> None:
101101
df_new = pd.DataFrame(df_new)
102102
df_old = self.filesystem.read_dataframe(path)
103103
assert key_column in df_old.columns, f'Dataframe {path} dont have "{key_column}" column'
104-
assert set(df_old[key_column]) == set(df_new[key_column]), \
105-
f'Dataframe {path} has different values in "{key_column}"'
104+
assert set(df_old[key_column]) == set(df_new[key_column]), f'Dataframe {path} has different values in "{key_column}"' # type: ignore
106105

107106
duplicates = df_old[df_old[key_column].duplicated()][key_column].tolist()
108107
assert len(duplicates) == 0, f'Dataframe {path} has duplicates in "{key_column}" column: {duplicates}'
109108

110-
duplicates = df_new[df_new[key_column].duplicated()][key_column].tolist()
109+
duplicates = df_new[df_new[key_column].duplicated()][key_column].tolist() # type: ignore
111110
assert len(duplicates) == 0, f'New dataframe for {path} has duplicates in "{key_column}" column: {duplicates}'
112111

113112
assert len(df_old) == len(df_new), f'Length of {path} dataframe is changed'
114113

115114
def update_columns_for_path(
116115
self,
117116
key_column: str,
118-
df_new: List[dict],
117+
df_new: List[Dict[str, Any]],
119118
path: str
120119
) -> Optional[str]:
121120
df_new = pd.DataFrame(df_new)
122121
df_old = self.filesystem.read_dataframe(path)
123122

124-
columns_to_add = [i for i in df_new.columns if i != key_column]
123+
columns_to_add = [i for i in df_new.columns if i != key_column] # type: ignore [attr-defined]
125124
columns_intersection = set(df_old.columns).intersection(set(columns_to_add))
126125

127126
if len(columns_intersection) > 0:
@@ -133,7 +132,7 @@ def update_columns_for_path(
133132
def update_columns(
134133
self,
135134
key_column: str,
136-
path2df: Dict[str, List[dict]],
135+
path2df: Dict[str, List[Dict[str, Any]]],
137136
max_threads: int = 16,
138137
pbar: bool = True
139138
) -> List[str]:

DPF/processors/processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def save_to_shards(
393393
filesystem: Optional[FileSystem] = None,
394394
max_files_in_shard: int = 1000,
395395
datafiles_ext: str = "csv",
396-
archives_ext: Optional[str] = "tar",
396+
archives_ext: str = "tar",
397397
filenaming: str = "counter",
398398
columns_to_save: Optional[List[str]] = None,
399399
rename_columns: Optional[Dict[str, str]] = None,
@@ -431,7 +431,7 @@ def save_to_shards(
431431
writer = ShardsWriter(
432432
filesystem,
433433
destination_dir,
434-
keys_mapping=rename_columns,
434+
keys_to_rename=rename_columns,
435435
max_files_in_shard=max_files_in_shard,
436436
datafiles_ext=datafiles_ext,
437437
archives_ext=archives_ext,

DPF/processors/writers/filewriter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import traceback
1+
from types import TracebackType
22
from abc import abstractmethod
3-
from typing import Dict, Optional, Tuple
3+
from typing import Dict, Optional, Tuple, Union
44

55

66
class ABSWriter:
@@ -19,8 +19,8 @@ def __enter__(self) -> "ABSWriter":
1919
@abstractmethod
2020
def __exit__(
2121
self,
22-
exception_type,
23-
exception_value: Optional[Exception],
24-
exception_traceback: traceback,
22+
exception_type: Union[type[BaseException], None],
23+
exception_value: Union[BaseException, None],
24+
exception_traceback: Union[TracebackType, None],
2525
) -> None:
2626
pass

DPF/processors/writers/sharded_files_writer.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import os
22
import traceback
33
import uuid
4-
from typing import Dict, List, Optional, Tuple
4+
from types import TracebackType
5+
from typing import Dict, List, Optional, Tuple, Any, Union
56

67
import pandas as pd
78

@@ -34,9 +35,9 @@ def __init__(
3435
self.filenaming = filenaming
3536
assert self.filenaming in ["counter", "uuid"], "Invalid files naming"
3637

37-
self.df_raw = []
38+
self.df_raw: List[Dict[str, Any]] = []
3839
self.shard_index, self.last_file_index = self._init_writer_from_last_uploaded_file()
39-
self.last_path_to_dir = None
40+
self.last_path_to_dir: str = None # type: ignore
4041

4142
def save_sample(
4243
self,
@@ -66,20 +67,20 @@ def save_sample(
6667
self.df_raw.append(table_data)
6768
self._try_close_batch()
6869

69-
def __enter__(self) -> "FileWriter": # noqa: F821
70+
def __enter__(self) -> "ShardedFilesWriter": # noqa: F821
7071
return self
7172

7273
def __exit__(
7374
self,
74-
exception_type: Optional[type],
75-
exception_value: Optional[Exception],
76-
exception_traceback: traceback,
75+
exception_type: Union[type[BaseException], None],
76+
exception_value: Union[BaseException, None],
77+
exception_traceback: Union[TracebackType, None],
7778
) -> None:
7879
if len(self.df_raw) != 0:
7980
self._flush(self._calculate_current_dirname())
8081
self.last_file_index = 0
8182

82-
def _init_writer_from_last_uploaded_file(self) -> (int, int):
83+
def _init_writer_from_last_uploaded_file(self) -> Tuple[int, int]:
8384
self.filesystem.mkdir(self.destination_dir)
8485
list_dirs = [
8586
int(os.path.basename(filename[: -len(self.datafiles_ext)]))
@@ -109,9 +110,12 @@ def _init_writer_from_last_uploaded_file(self) -> (int, int):
109110
def get_current_filename(self, extension: str) -> str:
110111
extension = extension.lstrip('.')
111112
if self.filenaming == "counter":
112-
return f"{self.last_file_index}.{extension}"
113+
filename = f"{self.last_file_index}.{extension}"
113114
elif self.filenaming == "uuid":
114-
return f"{uuid.uuid4().hex}.{extension}"
115+
filename = f"{uuid.uuid4().hex}.{extension}"
116+
else:
117+
raise ValueError(f"Invalid filenaming type: {self.filenaming}")
118+
return filename
115119

116120
def _calculate_current_dirname(self) -> str:
117121
return str(self.shard_index)

DPF/processors/writers/shards_writer.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import tarfile
44
import traceback
55
import uuid
6-
from typing import Dict, List, Optional, Tuple
6+
from types import TracebackType
7+
from typing import Dict, List, Optional, Tuple, Any, Union
78

89
import pandas as pd
910

@@ -23,24 +24,24 @@ def __init__(
2324
self,
2425
filesystem: FileSystem,
2526
destination_dir: str,
26-
keys_mapping: Optional[Dict[str, str]] = None,
27-
max_files_in_shard: Optional[int] = 1000,
28-
datafiles_ext: Optional[str] = "csv",
29-
archives_ext: Optional[str] = "tar",
27+
keys_to_rename: Optional[Dict[str, str]] = None,
28+
max_files_in_shard: int = 1000,
29+
datafiles_ext: str = "csv",
30+
archives_ext: str = "tar",
3031
filenaming: str = "counter"
3132
) -> None:
3233
self.filesystem = filesystem
3334
self.destination_dir = destination_dir
34-
self.keys_mapping = keys_mapping
35+
self.keys_to_rename = keys_to_rename
3536
self.max_files_in_shard = max_files_in_shard
3637
self.datafiles_ext = "." + datafiles_ext.lstrip(".")
3738
self.archives_ext = "." + archives_ext.lstrip(".")
3839
self.filenaming = filenaming
3940
assert self.filenaming in ["counter", "uuid"], "Invalid files naming"
4041

41-
self.df_raw = []
42+
self.df_raw: List[Dict[str, Any]] = []
4243
self.tar_bytes = io.BytesIO()
43-
self.tar = None
44+
self.tar: tarfile.TarFile = None # type: ignore
4445
self.shard_index, self.last_file_index = self._init_writer_from_last_uploaded_file()
4546

4647
def save_sample(
@@ -61,8 +62,8 @@ def save_sample(
6162
img_tar_info, fp = self._prepare_image_for_tar_format(file_bytes, filename)
6263
self.tar.addfile(img_tar_info, fp)
6364

64-
if self.keys_mapping:
65-
table_data = rename_dict_keys(table_data, self.keys_mapping)
65+
if self.keys_to_rename:
66+
table_data = rename_dict_keys(table_data, self.keys_to_rename)
6667

6768
self.df_raw.append(table_data)
6869
self._try_close_batch()
@@ -76,20 +77,20 @@ def _prepare_image_for_tar_format(
7677
img_tar_info.size = len(fp.getvalue())
7778
return img_tar_info, fp
7879

79-
def __enter__(self) -> "FileWriter": # noqa: F821
80+
def __enter__(self) -> "ShardsWriter": # noqa: F821
8081
return self
8182

8283
def __exit__(
8384
self,
84-
exception_type,
85-
exception_value: Optional[Exception],
86-
exception_traceback: traceback,
85+
exception_type: Union[type[BaseException], None],
86+
exception_value: Union[BaseException, None],
87+
exception_traceback: Union[TracebackType, None],
8788
) -> None:
8889
if len(self.df_raw) != 0:
8990
self._flush(self._calculate_current_tarname())
9091
self.last_file_index = 0
9192

92-
def _init_writer_from_last_uploaded_file(self) -> (int, int):
93+
def _init_writer_from_last_uploaded_file(self) -> Tuple[int, int]:
9394
self.filesystem.mkdir(self.destination_dir)
9495
list_csv = [
9596
int(os.path.basename(filename[: -len(self.datafiles_ext)]))
@@ -124,9 +125,12 @@ def _calculate_current_tarname(self) -> str:
124125
def get_current_filename(self, extension: str) -> str:
125126
extension = extension.lstrip('.')
126127
if self.filenaming == "counter":
127-
return f"{self.last_file_index}.{extension}"
128+
filename = f"{self.last_file_index}.{extension}"
128129
elif self.filenaming == "uuid":
129-
return f"{uuid.uuid4().hex}.{extension}"
130+
filename = f"{uuid.uuid4().hex}.{extension}"
131+
else:
132+
raise ValueError(f"Invalid filenaming type: {self.filenaming}")
133+
return filename
130134

131135
def _try_close_batch(self) -> None:
132136
old_tarname = self._calculate_current_tarname()
@@ -154,7 +158,7 @@ def _flush_and_upload_tar(self, filename: str) -> None:
154158
self.filesystem.save_file(
155159
self.tar_bytes, self.filesystem.join(self.destination_dir, filename), binary=True
156160
)
157-
self.tar = None
161+
self.tar = None # type: ignore
158162
self.tar_bytes = io.BytesIO()
159163

160164
def _flush(self, tarname: str) -> None:

DPF/processors/writers/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
def rename_dict_keys(d: dict, keys_mapping: dict) -> dict:
1+
from typing import Dict, Any
2+
3+
4+
def rename_dict_keys(d: Dict[Any, Any], keys_mapping: Dict[Any, Any]) -> Dict[Any, Any]:
25
for k, v in keys_mapping.items():
36
d[v] = d.pop(k)
47
return d

0 commit comments

Comments
 (0)