 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-ACL2016 Multimodal Machine Translation. Please see this websit for more details:
-http://www.statmt.org/wmt16/multimodal-task.html#task1
+ACL2016 Multimodal Machine Translation. Please see this website for more
+details: http://www.statmt.org/wmt16/multimodal-task.html#task1
 
 If you use the dataset created for your task, please cite the following paper:
 Multi30K: Multilingual English-German Image Descriptions.
 UNK_MARK = "<unk>"
 
 
-def __build_dict__(tar_file, dict_size, save_path, lang):
+def __build_dict(tar_file, dict_size, save_path, lang):
     word_dict = defaultdict(int)
     with tarfile.open(tar_file, mode="r") as f:
         for line in f.extractfile("wmt16/train"):
@@ -75,12 +75,12 @@ def __build_dict__(tar_file, dict_size, save_path, lang):
             fout.write("%s\n" % (word[0]))
 
 
-def __load_dict__(tar_file, dict_size, lang, reverse=False):
+def __load_dict(tar_file, dict_size, lang, reverse=False):
     dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
                              "wmt16/%s_%d.dict" % (lang, dict_size))
     if not os.path.exists(dict_path) or (
             len(open(dict_path, "r").readlines()) != dict_size):
-        __build_dict__(tar_file, dict_size, dict_path, lang)
+        __build_dict(tar_file, dict_size, dict_path, lang)
 
     word_dict = {}
     with open(dict_path, "r") as fdict:
@@ -92,7 +92,7 @@ def __load_dict__(tar_file, dict_size, lang, reverse=False):
     return word_dict
 
 
-def __get_dict_size__(src_dict_size, trg_dict_size, src_lang):
+def __get_dict_size(src_dict_size, trg_dict_size, src_lang):
     src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if src_lang == "en" else
                                         TOTAL_DE_WORDS))
     trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if src_lang == "en" else
@@ -102,9 +102,9 @@ def __get_dict_size__(src_dict_size, trg_dict_size, src_lang):
 
 def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang):
     def reader():
-        src_dict = __load_dict__(tar_file, src_dict_size, src_lang)
-        trg_dict = __load_dict__(tar_file, trg_dict_size,
-                                 ("de" if src_lang == "en" else "en"))
+        src_dict = __load_dict(tar_file, src_dict_size, src_lang)
+        trg_dict = __load_dict(tar_file, trg_dict_size,
+                               ("de" if src_lang == "en" else "en"))
 
         # the indice for start mark, end mark, and unk are the same in source
         # language and target language. Here uses the source language
@@ -173,8 +173,8 @@ def train(src_dict_size, trg_dict_size, src_lang="en"):
 
     assert (src_lang in ["en", "de"], ("An error language type. Only support: "
                                        "en (for English); de(for Germany)"))
-    src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
-                                                     trg_dict_size, src_lang)
+    src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
+                                                   src_lang)
 
     return reader_creator(
         tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
@@ -222,8 +222,8 @@ def test(src_dict_size, trg_dict_size, src_lang="en"):
             ("An error language type. "
              "Only support: en (for English); de(for Germany)"))
 
-    src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
-                                                     trg_dict_size, src_lang)
+    src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
+                                                   src_lang)
 
     return reader_creator(
         tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
@@ -269,8 +269,8 @@ def validation(src_dict_size, trg_dict_size, src_lang="en"):
     assert (src_lang in ["en", "de"],
             ("An error language type. "
              "Only support: en (for English); de(for Germany)"))
-    src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
-                                                     trg_dict_size, src_lang)
+    src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
+                                                   src_lang)
 
     return reader_creator(
         tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
@@ -308,7 +308,7 @@ def get_dict(lang, dict_size, reverse=False):
         "Please invoke paddle.dataset.wmt16.train/test/validation "
         "first to build the dictionary.")
     tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16.tar.gz")
-    return __load_dict__(tar_file, dict_size, lang, reverse)
+    return __load_dict(tar_file, dict_size, lang, reverse)
 
 
 def fetch():
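
For context, the snippet below is a minimal usage sketch of the public readers touched by this change (train/test/validation and get_dict). It is not part of the diff: it assumes the module is importable as paddle.v2.dataset.wmt16, that reader_creator returns the inner reader shown above (as is conventional for paddle.v2 dataset readers), and uses illustrative dictionary sizes; the exact layout of each yielded sample is not visible in this diff.

# A minimal usage sketch under the assumptions stated above.
import paddle.v2.dataset.wmt16 as wmt16

SRC_DICT_SIZE = 30000  # illustrative value; clamped internally by __get_dict_size
TRG_DICT_SIZE = 30000  # illustrative value

# Build a training reader with English as the source language; the word
# dictionaries are built on first use and cached under DATA_HOME.
train_reader = wmt16.train(SRC_DICT_SIZE, TRG_DICT_SIZE, src_lang="en")

for sample in train_reader():
    print(sample)  # one parallel sample as word-id sequences
    break

# Once train/test/validation has been invoked, the cached dictionary can be
# loaded directly; reverse=True maps ids back to words.
src_dict = wmt16.get_dict("en", SRC_DICT_SIZE, reverse=False)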