 from pathlib import Path
 import struct
 import sys
-from typing import Any, AnyStr, BinaryIO, Dict, List, Optional, Sequence, Tuple, Union
+from typing import (
+    Any,
+    AnyStr,
+    BinaryIO,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 import warnings

 from dateutil.relativedelta import relativedelta
 from pandas.core.indexes.base import Index
 from pandas.core.series import Series

-from pandas.io.common import get_filepath_or_buffer, stringify_path
+from pandas.io.common import (
+    get_compression_method,
+    get_filepath_or_buffer,
+    get_handle,
+    infer_compression,
+    stringify_path,
+)

 _version_error = (
     "Version of given Stata file is {version}. pandas supports importing "
@@ -1854,13 +1871,18 @@ def read_stata(
     return data


-def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]:
+def _open_file_binary_write(
+    fname: FilePathOrBuffer, compression: Union[str, Mapping[str, str], None],
+) -> Tuple[BinaryIO, bool, Optional[Union[str, Mapping[str, str]]]]:
     """
     Open a binary file or no-op if file-like.

     Parameters
     ----------
     fname : string path, path object or buffer
+        The file name or buffer.
+    compression : {str, dict, None}
+        The compression method to use.

     Returns
     -------
@@ -1871,9 +1893,21 @@ def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]:
     """
     if hasattr(fname, "write"):
         # See https://github.com/python/mypy/issues/1424 for hasattr challenges
-        return fname, False  # type: ignore
+        return fname, False, None  # type: ignore
     elif isinstance(fname, (str, Path)):
-        return open(fname, "wb"), True
+        # Extract compression mode as given, if dict
+        compression_typ, compression_args = get_compression_method(compression)
+        compression_typ = infer_compression(fname, compression_typ)
+        path_or_buf, _, compression_typ, _ = get_filepath_or_buffer(
+            fname, compression=compression_typ
+        )
+        if compression_typ is not None:
+            compression = compression_args
+            compression["method"] = compression_typ
+        else:
+            compression = None
+        f, _ = get_handle(path_or_buf, "wb", compression=compression, is_text=False)
+        return f, True, compression
     else:
         raise TypeError("fname must be a binary file, buffer or path-like.")

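For orientation, here is a minimal sketch of the compression resolution that the new path-handling branch performs, written against the same pandas.io.common helpers this diff imports (internal APIs that may change between releases); the './data_file.zip' path and the 'archive_name' entry are purely illustrative:

    # Sketch only: mirrors the resolution logic added to _open_file_binary_write above.
    from pandas.io.common import get_compression_method, infer_compression

    compression = {"method": "infer", "archive_name": "data_file.dta"}  # hypothetical argument
    # Split the str/dict argument into a method string plus extra options.
    method, extra_args = get_compression_method(compression)
    # "infer" is resolved from the target's extension, here to "zip".
    method = infer_compression("./data_file.zip", method)
    # Rebuild the dict that eventually reaches get_handle(..., compression=...).
    resolved = dict(extra_args, method=method) if method is not None else None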
@@ -2050,6 +2084,17 @@ class StataWriter(StataParser):
     variable_labels : dict
         Dictionary containing columns as keys and variable labels as values.
         Each label must be 80 characters or smaller.
+    compression : str or dict, default 'infer'
+        For on-the-fly compression of the output dta. If string, specifies
+        compression mode. If dict, value at key 'method' specifies compression
+        mode. Compression mode must be one of {'infer', 'gzip', 'bz2', 'zip',
+        'xz', None}. If compression mode is 'infer' and `fname` is path-like,
+        then detect compression from the following extensions: '.gz', '.bz2',
+        '.zip', or '.xz' (otherwise no compression). If dict and compression
+        mode is one of {'zip', 'gzip', 'bz2'}, or inferred as one of the above,
+        other entries passed as additional compression options.
+
+        .. versionadded:: 1.1.0

     Returns
     -------
@@ -2074,7 +2119,12 @@ class StataWriter(StataParser):
     >>> writer = StataWriter('./data_file.dta', data)
     >>> writer.write_file()

-    Or with dates
+    Directly write a zip file
+    >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
+    >>> writer = StataWriter('./data_file.zip', data, compression=compression)
+    >>> writer.write_file()
+
+    Save a DataFrame with dates
     >>> from datetime import datetime
     >>> data = pd.DataFrame([[datetime(2000,1,1)]], columns=['date'])
     >>> writer = StataWriter('./date_data_file.dta', data, {'date' : 'tw'})
@@ -2094,6 +2144,7 @@ def __init__(
         time_stamp: Optional[datetime.datetime] = None,
         data_label: Optional[str] = None,
         variable_labels: Optional[Dict[Label, str]] = None,
+        compression: Union[str, Mapping[str, str], None] = "infer",
     ):
         super().__init__()
         self._convert_dates = {} if convert_dates is None else convert_dates
@@ -2102,6 +2153,8 @@ def __init__(
         self._data_label = data_label
         self._variable_labels = variable_labels
         self._own_file = True
+        self._compression = compression
+        self._output_file: Optional[BinaryIO] = None
         # attach nobs, nvars, data, varlist, typlist
         self._prepare_pandas(data)

@@ -2389,7 +2442,12 @@ def _encode_strings(self) -> None:
                 self.data[col] = encoded

     def write_file(self) -> None:
-        self._file, self._own_file = _open_file_binary_write(self._fname)
+        self._file, self._own_file, compression = _open_file_binary_write(
+            self._fname, self._compression
+        )
+        if compression is not None:
+            self._output_file = self._file
+            self._file = BytesIO()
         try:
             self._write_header(data_label=self._data_label, time_stamp=self._time_stamp)
             self._write_map()
@@ -2434,6 +2492,12 @@ def _close(self) -> None:
         """
         # Some file-like objects might not support flush
         assert self._file is not None
+        if self._output_file is not None:
+            assert isinstance(self._file, BytesIO)
+            bio = self._file
+            bio.seek(0)
+            self._file = self._output_file
+            self._file.write(bio.read())
         try:
             self._file.flush()
         except AttributeError:
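Taken together with the write_file change above, the pattern is: when compression is active, all writes go to an in-memory BytesIO, and only on close is the finished payload copied into the compressed handle. A hedged, self-contained sketch of that buffer-then-flush idea using generic file objects (not the StataWriter internals); the file name is illustrative:

    import gzip
    from io import BytesIO

    # Stand-in names; the real writer uses self._file / self._output_file.
    buffer = BytesIO()                    # intermediate writes and seeks happen here
    buffer.write(b"fake dta payload")

    with gzip.open("data_file.dta.gz", "wb") as compressed:
        buffer.seek(0)
        compressed.write(buffer.read())   # one final pass into the compressed stream

The buffering matters because the dta writers seek back to patch earlier sections (for example the offset map in the 117+ formats), which a streaming compressed handle generally cannot support.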
@@ -2898,6 +2962,17 @@ class StataWriter117(StataWriter):
         Smaller columns can be converted by including the column name. Using
         StrLs can reduce output file size when strings are longer than 8
         characters, and either frequently repeated or sparse.
+    compression : str or dict, default 'infer'
+        For on-the-fly compression of the output dta. If string, specifies
+        compression mode. If dict, value at key 'method' specifies compression
+        mode. Compression mode must be one of {'infer', 'gzip', 'bz2', 'zip',
+        'xz', None}. If compression mode is 'infer' and `fname` is path-like,
+        then detect compression from the following extensions: '.gz', '.bz2',
+        '.zip', or '.xz' (otherwise no compression). If dict and compression
+        mode is one of {'zip', 'gzip', 'bz2'}, or inferred as one of the above,
+        other entries passed as additional compression options.
+
+        .. versionadded:: 1.1.0

     Returns
     -------
@@ -2923,8 +2998,12 @@ class StataWriter117(StataWriter):
     >>> writer = StataWriter117('./data_file.dta', data)
     >>> writer.write_file()

-    Or with long strings stored in strl format
+    Directly write a zip file
+    >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
+    >>> writer = StataWriter117('./data_file.zip', data, compression=compression)
+    >>> writer.write_file()

+    Or with long strings stored in strl format
     >>> data = pd.DataFrame([['A relatively long string'], [''], ['']],
     ...                     columns=['strls'])
     >>> writer = StataWriter117('./data_file_with_long_strings.dta', data,
@@ -2946,6 +3025,7 @@ def __init__(
         data_label: Optional[str] = None,
         variable_labels: Optional[Dict[Label, str]] = None,
         convert_strl: Optional[Sequence[Label]] = None,
+        compression: Union[str, Mapping[str, str], None] = "infer",
     ):
         # Copy to new list since convert_strl might be modified later
         self._convert_strl: List[Label] = []
@@ -2961,6 +3041,7 @@ def __init__(
             time_stamp=time_stamp,
             data_label=data_label,
             variable_labels=variable_labels,
+            compression=compression,
         )
         self._map: Dict[str, int] = {}
         self._strl_blob = b""
@@ -3281,6 +3362,17 @@ class StataWriterUTF8(StataWriter117):
         The dta version to use. By default, uses the size of data to determine
         the version. 118 is used if data.shape[1] <= 32767, and 119 is used
         for storing larger DataFrames.
+    compression : str or dict, default 'infer'
+        For on-the-fly compression of the output dta. If string, specifies
+        compression mode. If dict, value at key 'method' specifies compression
+        mode. Compression mode must be one of {'infer', 'gzip', 'bz2', 'zip',
+        'xz', None}. If compression mode is 'infer' and `fname` is path-like,
+        then detect compression from the following extensions: '.gz', '.bz2',
+        '.zip', or '.xz' (otherwise no compression). If dict and compression
+        mode is one of {'zip', 'gzip', 'bz2'}, or inferred as one of the above,
+        other entries passed as additional compression options.
+
+        .. versionadded:: 1.1.0

     Returns
     -------
@@ -3308,6 +3400,11 @@ class StataWriterUTF8(StataWriter117):
     >>> writer = StataWriterUTF8('./data_file.dta', data)
     >>> writer.write_file()

+    Directly write a zip file
+    >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
+    >>> writer = StataWriterUTF8('./data_file.zip', data, compression=compression)
+    >>> writer.write_file()
+
     Or with long strings stored in strl format

     >>> data = pd.DataFrame([['ᴀ relatively long ŝtring'], [''], ['']],
@@ -3331,6 +3428,7 @@ def __init__(
         variable_labels: Optional[Dict[Label, str]] = None,
         convert_strl: Optional[Sequence[Label]] = None,
         version: Optional[int] = None,
+        compression: Union[str, Mapping[str, str], None] = "infer",
     ):
         if version is None:
             version = 118 if data.shape[1] <= 32767 else 119
@@ -3352,6 +3450,7 @@ def __init__(
             data_label=data_label,
             variable_labels=variable_labels,
             convert_strl=convert_strl,
+            compression=compression,
         )
         # Override version set in StataWriter117 init
         self._dta_version = version