|
| 1 | +import tensorflow as tf |
| 2 | +import numpy as np |
| 3 | +import os |
| 4 | +import cPickle as pickle |
| 5 | +from os.path import expanduser |
| 6 | +import sys |
| 7 | + |
| 8 | +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..","util"))) |
| 9 | +from tf_utils import fcLayer, createLSTMCell, applyActivation, predictionLayer, compute_cost |
| 10 | +#from predContext import predContext, createHtDict |
| 11 | + |
| 12 | +class model(object): |
| 13 | + |
| 14 | + # Task params |
| 15 | + is_multi_task = True |
| 16 | + secondary_task = "word generation" |
| 17 | + primary_task = "classification" |
| 18 | + |
| 19 | + # Model params |
| 20 | + # 0 -- shared; 1 -- context; 2 -- task |
| 21 | +fc_activation = "tanh" |
| 22 | +context_output_activation = "tanh" |
| 23 | +task_output_activation = "softmax" |
| 24 | +dropout = 0.0 |
| 25 | +body_lstm_size = 128 |
| 26 | +context_lstm_size = 128 |
| 27 | +task_lstm_size = 128 |
| 28 | +body_n_layer = 1 |
| 29 | +context_n_layer = 1 |
| 30 | +task_n_layer = 1 |
| 31 | +context_branch_fc = 512 |
| 32 | +task_branch_fc = 512 |
| 33 | + |
| 34 | +# Data params |
| 35 | +n_classes = 2 |
| 36 | +batch_size = 128 |
| 37 | +max_length = 52 |
| 38 | +feature_length = 300 |
| 39 | + context_dim = 300 |
| 40 | +task_dim = n_classes |
| 41 | + |
| 42 | +# Hyper- params |
| 43 | +lr = 0.0001 |
| 44 | +context_lr = 0.5 * lr |
| 45 | +n_epoch = 500 |
| 46 | +topN = 4 |
| 47 | +keep_prob_val = 1.0 |
| 48 | + |
| 49 | + |
| 50 | +def buildModel(self, x, y_context, y_task, is_train, dropout, scope="multiTask"): |
| 51 | + |
| 52 | + # Assume the input shape is (batch_size, max_length, feature_length) |
| 53 | + |
| 54 | + #TASK = primary task, CONTEXT = secondary task |
| 55 | + |
| 56 | + # Create lstm cell for the shared layer |
| 57 | + body_lstm_cell, _ = createLSTMCell(self.batch_size, self.body_lstm_size, self.body_n_layer, forget_bias=0.0) |
| 58 | + # Create lstm cell for branch 1 |
| 59 | + context_lstm_cell, _ = createLSTMCell(self.batch_size, self.context_lstm_size, self.context_n_layer, forget_bias=0.0) |
| 60 | + # Create lstm cells for branch 2 |
| 61 | + task_lstm_cell, _ = createLSTMCell(self.batch_size, self.task_lstm_size, self.task_n_layer, forget_bias=0.0) |
| 62 | + |
| 63 | + context_cost = tf.constant(0) |
| 64 | + task_cost = tf.constant(0.0, dtype=tf.float32) |
| 65 | + |
| 66 | + if not self.is_multi_task: context_output = tf.constant(0) |
| 67 | + |
| 68 | + with tf.variable_scope("shared_lstm"): |
| 69 | + body_cell_output, last_body_state = tf.nn.dynamic_rnn(cell = body_lstm_cell, dtype=tf.float32, sequence_length=self.length(x), inputs=x) |
| 70 | + |
| 71 | + if self.is_multi_task: |
| 72 | + with tf.variable_scope("context_branch"): |
| 73 | + context_cell_output, last_context_state = tf.nn.dynamic_rnn(cell = context_lstm_cell, dtype=tf.float32, sequence_length=self.length(body_cell_output), inputs=body_cell_output) |
| 74 | + |
| 75 | + # The output from LSTMs will be (batch_size, max_length, out_size) |
| 76 | + |
| 77 | + # Select the last output that is not generated by zero vectors |
| 78 | + if self.secondary_task == "missing word": |
| 79 | + last_context_output = self.last_relevant(context_cell_output, self.length(context_cell_output)) |
| 80 | + # feed the last output to the fc layer and make prediction |
| 81 | + with tf.variable_scope("context_fc"): |
| 82 | + context_fc_out = fcLayer(x=last_context_output, in_shape=self.context_lstm_size, out_shape=self.context_branch_fc, activation=self.fc_activation, dropout=self.dropout, is_train=is_train, scope="fc1") |
| 83 | + with tf.variable_scope("context_pred"): |
| 84 | + context_output, context_logits = predictionLayer(x=context_fc_out, y=y_context, in_shape=self.context_branch_fc, out_shape=y_context.get_shape()[-1].value, activation=self.context_output_activation) |
| 85 | + context_cost = compute_cost(logit=context_logits, y=y_context, out_type="last_only", max_length=self.max_length, batch_size=self.batch_size, embed_dim=self.feature_length, activation=self.context_output_activation) |
| 86 | + |
| 87 | + if self.secondary_task == "word generation": |
| 88 | +context_cell_output = tf.transpose(context_cell_output, [1, 0, 2]) |
| 89 | + context_cell_output = tf.reshape(context_cell_output, [-1, self.context_lstm_size]) |
| 90 | + context_output_list = tf.split(context_cell_output, self.max_length, 0) |
| 91 | + fc_output_list = [] |
| 92 | +with tf.variable_scope("context_fc"): |
| 93 | + for step in range(self.max_length): |
| 94 | + if step > 0: tf.get_variable_scope().reuse_variables() |
| 95 | + fc_out = fcLayer(x=context_output_list[step], in_shape=self.context_lstm_size, out_shape=self.context_branch_fc, activation=self.fc_activation, dropout=self.dropout, is_train=is_train, scope="fc1") |
| 96 | + fc_output_list.append(tf.expand_dims(fc_out, axis=1)) |
| 97 | + print len(fc_output_list) |
| 98 | + print fc_output_list[0].get_shape() |
| 99 | + context_fc_out = tf.concat(fc_output_list, axis=1) |
| 100 | + print "context fc output shape before transpose: ", context_fc_out.get_shape() |
| 101 | + #context_fc_out = tf.transpose(context_fc_out, [1, 0, 2]) |
| 102 | + #print "Context fc output shape: ", context_fc_out.get_shape() |
| 103 | +with tf.variable_scope("context_pred"): |
| 104 | + context_output, context_logits = predictionLayer(x=context_fc_out, y=y_context, in_shape=self.context_branch_fc, out_shape=y_context.get_shape()[-1].value, activation=self.context_output_activation) |
| 105 | + print "Context prediction output shape: ", context_output.get_shape() |
| 106 | + context_cost = compute_cost(logit=context_logits, y=y_context, out_type="sequential", max_length=self.max_length, batch_size=self.batch_size, embed_dim=self.feature_length,activation=self.context_output_activation) |
| 107 | + |
| 108 | + print "Context cost shape: ", context_cost.get_shape() |
| 109 | + |
| 110 | + with tf.variable_scope("task_branch"): |
| 111 | + task_cell_output, last_task_state = tf.nn.dynamic_rnn(cell = task_lstm_cell, dtype=tf.float32, sequence_length=self.length(body_cell_output), inputs=body_cell_output) |
| 112 | + |
| 113 | + with tf.variable_scope("task_fc"): |
| 114 | + # Select the last output that is not generated by zero vectors |
| 115 | + last_task_output = self.last_relevant(task_cell_output, self.length(task_cell_output)) |
| 116 | + # feed the last output to the fc layer and make prediction |
| 117 | + task_fc_out = fcLayer(x=last_task_output, in_shape=self.task_lstm_size, out_shape=self.task_branch_fc, activation=self.fc_activation, dropout=self.dropout, is_train=is_train, scope="fc2") |
| 118 | + task_output, task_logits = predictionLayer(x=task_fc_out, y=y_task, in_shape=self.context_branch_fc, out_shape=y_task.get_shape()[-1].value, activation=self.task_output_activation) |
| 119 | +task_cost = compute_cost(logit=task_logits, y=y_task, out_type="last_only", max_length=self.max_length, batch_size=self.batch_size, embed_dim=self.feature_length, activation=self.task_output_activation) |
| 120 | + |
| 121 | + print "Task cost shape: ", task_cost.get_shape() |
| 122 | + return context_cost, task_cost, task_output, context_output |
| 123 | + |
| 124 | +# Flatten the output tensor to shape features in all examples x output size |
| 125 | +# construct an index into that by creating a tensor with the start indices for each example tf.range(0, batch_size) x max_length |
| 126 | +# and add the individual sequence lengths to it |
| 127 | +# tf.gather() then performs the acutal indexing. |
| 128 | +def last_relevant(self, output, length): |
| 129 | + index = tf.range(0, self.batch_size) * self.max_length + (length - 1) |
| 130 | + out_size = int(output.get_shape()[2]) |
| 131 | + flat = tf.reshape(output, [-1, out_size]) |
| 132 | + relevant = tf.gather(flat, index) |
| 133 | + return relevant |
| 134 | + |
| 135 | +# Assume that the sequences are padded with 0 vectors to have shape (batch_size, max_length, feature_length) |
| 136 | + |
| 137 | + def length(self, sequence): |
| 138 | + used = tf.sign(tf.reduce_max(tf.abs(sequence), reduction_indices=2)) |
| 139 | + length = tf.reduce_sum(used, reduction_indices=1) |
| 140 | + length = tf.cast(length, tf.int32) |
| 141 | + print length.get_shape() |
| 142 | + return length |
| 143 | + |
| 144 | + |
| 145 | + |
| 146 | + |
| 147 | + |
0 commit comments