@@ -30,8 +30,8 @@ def trainModel(M):
 
     # Reformat the data according to the secondary task
     # Create class look up table
-    max_length = reformat_data(data_path, M.secondary_task == "missing word")
-    class_look_up(data_path)
+    # max_length = reformat_data(data_path, M.secondary_task == "missing word")
+    # class_look_up(data_path)
 
     n_classes, word2vec_dic, n_test, n_train, missing_word_dic = get_data(data_path)
@@ -52,10 +52,28 @@ def trainModel(M):
     optimizer2 = tf.train.AdamOptimizer(learning_rate=task_lr)
 
     context_cost, task_cost, task_output, context_output = M.buildModel(x, y_context, y_task, is_train, keep_prob)
+
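+    # Partition the trainable variables by name: vars without "context" in the name go to task_vars, vars without "task" go to context_vars, and variables mentioning neither land in both lists.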
+    context_vars = []
+    task_vars = []
+    for var in tf.trainable_variables():
+        if "context" not in var.name: task_vars.append(var)
+        if "task" not in var.name: context_vars.append(var)
+    for var in context_vars:
+        print "Context variable: ", var.name
+    print("\n")
+    for var in task_vars:
+        print "Task variable: ", var.name
+
     if M.is_multi_task:
         train_step1 = optimizer1.minimize(context_cost)
         train_step2 = optimizer2.minimize(task_cost)
 
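+    # Alternative kept disabled below: clip each branch's gradients by global norm and apply them with the matching optimizer, so context and task updates only touch their own variable lists.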
+    # if M.is_multi_task:
+    #     context_grads, _ = tf.clip_by_global_norm(tf.gradients(context_cost, context_vars), 10)
+    #     train_step1 = optimizer1.apply_gradients(zip(context_grads, context_vars))
+    # task_grads, _ = tf.clip_by_global_norm(tf.gradients(task_cost, task_vars), 10)
+    # train_step2 = optimizer2.apply_gradients(zip(task_grads, task_vars))
+
     accuracy_list = np.zeros((M.n_epoch))
     # Start running operations on the graph
     sess = tf.Session()
@@ -72,8 +90,11 @@ def trainModel(M):
         for minibatch in range(n_train_batches):
             encoded_batch, batch_classes, batch_context_encoded, batch_context, batch_identifier, batch_text, batch_length = load_batch(n_classes, word2vec_dic, missing_word_dic, M.feature_length, M.max_length, data_path + "/Train/", 1, train_file, test_file, all_classes, start_idx, M.batch_size, M.secondary_task)
             start_idx += M.batch_size
-
-            feed_dict = {x: encoded_batch, y_context: batch_context_encoded, y_task: batch_classes, is_train: 1, keep_prob: 0.5, context_lr: (1 - epoch * 1.0 / M.n_epoch) * M.lr, task_lr: epoch * 1.0 / M.n_epoch * M.lr}
+
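+            # With multi-task training the context learning rate decays linearly from M.lr towards 0 over the epochs while the task learning rate ramps up from 0 towards M.lr; otherwise only the task branch trains, at a fixed M.lr.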
+            if M.is_multi_task:
+                feed_dict = {x: encoded_batch, y_context: batch_context_encoded, y_task: batch_classes, is_train: 1, keep_prob: 0.5, context_lr: (1 - epoch * 1.0 / M.n_epoch) * M.lr, task_lr: epoch * 1.0 / M.n_epoch * M.lr}
+            else:
+                feed_dict = {x: encoded_batch, y_context: batch_context_encoded, y_task: batch_classes, is_train: 1, keep_prob: 0.5, context_lr: 0.0, task_lr: M.lr}
             if M.is_multi_task:
                 train_step1.run(feed_dict=feed_dict)
                 context_cost_val, _, _ = sess.run(fetches=[context_cost, task_cost, task_output], feed_dict=feed_dict)