@@ -29,37 +29,29 @@ def initialize_model_and_trainer(model_properties, training_properties, datasetl
2929 logger .info ("Model type is %s" , training_properties ["learner" ])
3030 if  training_properties ["learner" ] ==  "text_cnn" :
3131 model  =  TextCnn (model_properties ).to (device )
32-  trainer  =  Trainer .trainer_factory ("single_model_trainer" , training_properties , datasetloader .train_iter ,
33-  datasetloader .val_iter , datasetloader .test_iter , device )
32+  trainer  =  Trainer .trainer_factory ("single_model_trainer" , training_properties , datasetloader , device )
3433 elif  training_properties ["learner" ] ==  "gru" :
3534 model  =  GRU (model_properties ).to (device )
36-  trainer  =  Trainer .trainer_factory ("single_model_trainer" , training_properties , datasetloader .train_iter ,
37-  datasetloader .val_iter , datasetloader .test_iter , device )
35+  trainer  =  Trainer .trainer_factory ("single_model_trainer" , training_properties , datasetloader , device )
3836 elif  training_properties ["learner" ] ==  "lstm" :
3937 model  =  LSTM (model_properties ).to (device )
40-  trainer  =  Trainer .trainer_factory ("single_model_trainer" , training_properties , datasetloader .train_iter ,
41-  datasetloader .val_iter , datasetloader .test_iter , device )
38+  trainer  =  Trainer .trainer_factory ("single_model_trainer" , training_properties , datasetloader , device )
4239 elif  training_properties ["learner" ] ==  "char_cnn" :
4340 model  =  CharCNN (model_properties ).to (device )
44-  trainer  =  Trainer .trainer_factory ("single_model_trainer" , training_properties , datasetloader .train_iter ,
45-  datasetloader .val_iter , datasetloader .test_iter , device )
41+  trainer  =  Trainer .trainer_factory ("single_model_trainer" , training_properties , datasetloader , device )
4642 elif  training_properties ["learner" ] ==  "vdcnn" :
4743 model  =  VDCNN (model_properties ).to (device )
48-  trainer  =  Trainer .trainer_factory ("single_model_trainer" , training_properties , datasetloader .train_iter ,
49-  datasetloader .val_iter , datasetloader .test_iter , device )
44+  trainer  =  Trainer .trainer_factory ("single_model_trainer" , training_properties , datasetloader , device )
5045 elif  training_properties ["learner" ] ==  "conv_deconv_cnn" :
5146 model  =  ConvDeconvCNN (model_properties )
52-  trainer  =  Trainer .trainer_factory ("single_model_trainer" , training_properties , datasetloader .train_iter ,
53-  datasetloader .val_iter , datasetloader .test_iter , device )
47+  trainer  =  Trainer .trainer_factory ("single_model_trainer" , training_properties , datasetloader , device )
5448 elif  training_properties ["learner" ] ==  "transformer_google" :
5549 model  =  TransformerGoogle (model_properties ).model .to (device )
56-  trainer  =  Trainer .trainer_factory ("single_model_trainer" , training_properties , datasetloader .train_iter ,
57-  datasetloader .val_iter , datasetloader .test_iter , device )
50+  trainer  =  Trainer .trainer_factory ("single_model_trainer" , training_properties , datasetloader , device )
5851 elif  training_properties ["learner" ] ==  "lstmcrf" :
5952 assert  training_properties ["task" ] ==  "ner" 
6053 model  =  LSTMCRF (model_properties ).to (device )
61-  trainer  =  Trainer .trainer_factory ("single_model_ner_trainer" , training_properties , datasetloader .train_iter ,
62-  datasetloader .val_iter , datasetloader .test_iter , device )
54+  trainer  =  Trainer .trainer_factory ("single_model_ner_trainer" , training_properties , datasetloader , device )
6355 else :
6456 raise  ValueError ("Model is not defined! Available learner values are : 'text_cnn', 'char_cnn', 'vdcnn', 'gru', " 
6557 "'lstm', 'conv_deconv_cnn' and 'transformer_google'" )
@@ -176,6 +168,7 @@ def initialize_model_and_trainer(model_properties, training_properties, datasetl
176168 if  category_vocab  is  not   None :
177169 model_properties ["common_model_properties" ]["num_class" ] =  len (category_vocab )
178170 if  ner_vocab  is  not   None :
171+  model_properties ["common_model_properties" ]["ner_vocab" ] =  ner_vocab 
179172 model_properties ["common_model_properties" ]["num_tags" ] =  len (ner_vocab )
180173 model_properties ["common_model_properties" ]["start_id" ] =  ner_vocab .stoi ["<start>" ]
181174 model_properties ["common_model_properties" ]["end_id" ] =  ner_vocab .stoi ["<end>" ]
@@ -217,4 +210,4 @@ def initialize_model_and_trainer(model_properties, training_properties, datasetl
217210 category_vocab_path = category_vocab_path ,
218211 preprocessor = preprocessor .preprocess ,
219212 topk = training_properties ["topk" ])
220-  logger .info ("" )
213+  logger .info ("Done! " )
0 commit comments