3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -75,12 +75,14 @@ To release a new version, please update the changelog as follows:

- `SpatialTransformer2dAffine` auto `in_channels`
- support TensorFlow 2.0.0-beta1
+- Update model weights property, now returns its copy (PR #1010)

### Dependencies Update

### Deprecated

### Fixed
+- Fix `tl.models.Model._construct_graph` for list of outputs, e.g. the STN case (PR #1010)

### Removed

@@ -89,6 +91,7 @@ To release a new version, please update the changelog as follows:
### Contributors

- @zsdonghao
+- @ChrisWu1997: #1010

## [2.1.0]

3 changes: 1 addition & 2 deletions examples/database/dispatch_tasks.py
@@ -46,6 +46,5 @@

# get the best model
print("all tasks finished")
-sess = tf.InteractiveSession()
-net = db.find_top_model(sess=sess, model_name='mlp', sort=[("test_accuracy", -1)])
+net = db.find_top_model(model_name='mlp', sort=[("test_accuracy", -1)])
print("the best accuracy {} is from model {}".format(net._test_accuracy, net._name))
76 changes: 32 additions & 44 deletions examples/database/task_script.py
@@ -3,69 +3,57 @@
import tensorflow as tf
import tensorlayer as tl

-tf.logging.set_verbosity(tf.logging.DEBUG)
+# tf.logging.set_verbosity(tf.logging.DEBUG)
tl.logging.set_verbosity(tl.logging.DEBUG)

-sess = tf.InteractiveSession()

# connect to database
db = tl.db.TensorHub(ip='localhost', port=27017, dbname='temp', project_name='tutorial')

# load dataset from database
X_train, y_train, X_val, y_val, X_test, y_test = db.find_top_dataset('mnist')

-# define placeholder
-x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
-y_ = tf.placeholder(tf.int64, shape=[None], name='y_')


# define the network
-def mlp(x, is_train=True, reuse=False):
-    with tf.variable_scope("MLP", reuse=reuse):
-        net = tl.layers.InputLayer(x, name='input')
-        net = tl.layers.DropoutLayer(net, keep=0.8, is_fix=True, is_train=is_train, name='drop1')
-        net = tl.layers.DenseLayer(net, n_units=n_units1, act=tf.nn.relu, name='relu1')
-        net = tl.layers.DropoutLayer(net, keep=0.5, is_fix=True, is_train=is_train, name='drop2')
-        net = tl.layers.DenseLayer(net, n_units=n_units2, act=tf.nn.relu, name='relu2')
-        net = tl.layers.DropoutLayer(net, keep=0.5, is_fix=True, is_train=is_train, name='drop3')
-        net = tl.layers.DenseLayer(net, n_units=10, act=None, name='output')
-    return net


-# define inferences
-net_train = mlp(x, is_train=True, reuse=False)
-net_test = mlp(x, is_train=False, reuse=True)

-# cost for training
-y = net_train.outputs
-cost = tl.cost.cross_entropy(y, y_, name='xentropy')
-correct_prediction = tf.equal(tf.argmax(y, 1), y_)
-acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
-
-# cost and accuracy for evaluation
-y2 = net_test.outputs
-cost_test = tl.cost.cross_entropy(y2, y_, name='xentropy2')
-correct_prediction = tf.equal(tf.argmax(y2, 1), y_)
-acc_test = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+def mlp():
+    ni = tl.layers.Input([None, 784], name='input')
+    net = tl.layers.Dropout(keep=0.8, name='drop1')(ni)
+    net = tl.layers.Dense(n_units=n_units1, act=tf.nn.relu, name='relu1')(net)
+    net = tl.layers.Dropout(keep=0.5, name='drop2')(net)
+    net = tl.layers.Dense(n_units=n_units2, act=tf.nn.relu, name='relu2')(net)
+    net = tl.layers.Dropout(keep=0.5, name='drop3')(net)
+    net = tl.layers.Dense(n_units=10, act=None, name='output')(net)
+    M = tl.models.Model(inputs=ni, outputs=net)
+    return M

+network = mlp()

+# cost and accuracy
+cost = tl.cost.cross_entropy
+
+def acc(y, y_):
+    correct_prediction = tf.equal(tf.argmax(y, 1), tf.convert_to_tensor(y_, tf.int64))
+    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# define the optimizer
-train_params = tl.layers.get_variables_with_name('MLP', True, False)
-train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_params)
-
-# initialize all variables in the session
-sess.run(tf.global_variables_initializer())
+train_op = tf.optimizers.Adam(learning_rate=0.0001)

# train the network
+# tl.utils.fit(
+#     network, train_op, cost, X_train, y_train, acc=acc, batch_size=500, n_epoch=20, print_freq=5,
+#     X_val=X_val, y_val=y_val, eval_train=False
+# )

tl.utils.fit(
-    sess, net_train, train_op, cost, X_train, y_train, x, y_, acc=acc, batch_size=500, n_epoch=1, print_freq=5,
-    X_val=X_val, y_val=y_val, eval_train=False
+    network, train_op=tf.optimizers.Adam(learning_rate=0.0001), cost=tl.cost.cross_entropy, X_train=X_train,
+    y_train=y_train, acc=acc, batch_size=256, n_epoch=20, X_val=X_val, y_val=y_val, eval_train=False,
)

# evaluation and save results that match the result_key
-test_accuracy = tl.utils.test(sess, net_test, acc_test, X_test, y_test, x, y_, batch_size=None, cost=cost_test)
+test_accuracy = tl.utils.test(network, acc, X_test, y_test, batch_size=None, cost=cost)
test_accuracy = float(test_accuracy)

# save model into database
-db.save_model(net_train, model_name='mlp', name=str(n_units1) + '-' + str(n_units2), test_accuracy=test_accuracy)
+db.save_model(network, model_name='mlp', name=str(n_units1) + '-' + str(n_units2), test_accuracy=test_accuracy)
# in other script, you can load the model as follow
# net = db.find_model(sess=sess, model_name=str(n_units1)+'-'+str(n_units2))
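The migrated `tl.utils.fit` call above passes the model, a `tf.optimizers` instance, and plain callables for cost and metric, replacing the session and placeholder plumbing. As a self-contained sketch of the metric contract, here is the script's `acc` on toy logits (values illustrative only):

import numpy as np
import tensorflow as tf

def acc(y, y_):  # same definition as in the script above
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.convert_to_tensor(y_, tf.int64))
    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

y = tf.constant([[0.1, 0.9], [0.8, 0.2]])  # fake logits for two samples
y_ = np.array([1, 1], dtype=np.int64)      # both true labels are class 1
print(float(acc(y, y_)))                   # 0.5: only the first prediction's argmax matches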

@@ -66,8 +66,8 @@ def get_model(inputs_shape):

    ## 2. Spatial transformer module (sampler)
    stn = SpatialTransformer2dAffine(out_size=(40, 40), in_channels=20)
-    s = stn((nn, ni))
+    nn = stn((nn, ni))
+    s = nn

    ## 3. Classifier
    nn = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME')(nn)
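This routing fix is the behavioural heart of the PR: previously `s = stn((nn, ni))` left the sampler on a dead-end branch while the classifier kept consuming the raw localisation features in `nn`. Now the classifier receives the transformed image, and the alias `s = nn` still exposes it as a second model output, which is exactly the list-of-outputs case the `_construct_graph` fix below handles.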
2 changes: 1 addition & 1 deletion tensorlayer/db.py
@@ -641,7 +641,7 @@ def run_top_task(self, task_name=None, sort=None, **kwargs):
logging.info("[Database] Start Task: key: {} sort: {} push time: {}".format(task_name, sort, _datetime))
_script = _script.decode('utf-8')
with tf.Graph().as_default(): # # as graph: # clear all TF graphs
exec (_script, globals())
exec(_script, globals())

        # set status to finished
        _ = self.db.Task.find_one_and_update({'_id': _id}, {'$set': {'status': 'finished'}})
3 changes: 0 additions & 3 deletions tensorlayer/layers/spatial_transformer.py
@@ -257,7 +257,6 @@ def __repr__(self):
        return s.format(classname=self.__class__.__name__, **self.__dict__)

    def build(self, inputs_shape):
-        print("inputs_shape ", inputs_shape)
        if self.in_channels is None and len(inputs_shape) != 2:
            raise AssertionError("The dimension of theta layer input must be rank 2, please reshape or flatten it")
        if self.in_channels:
@@ -267,7 +266,6 @@ def build(self, inputs_shape):
            # shape = [inputs_shape[1], 6]
            self.in_channels = inputs_shape[0][-1]  # zsdonghao
            shape = [self.in_channels, 6]
-        print("shape", shape)
        self.W = self._get_weights("weights", shape=tuple(shape), init=tl.initializers.Zeros())
        identity = np.reshape(np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32), newshape=(6, ))
        self.b = self._get_weights("biases", shape=(6, ), init=tl.initializers.Constant(identity))
@@ -282,7 +280,6 @@ def forward(self, inputs):
        n_channels is identical to that of U.
        """
        theta_input, U = inputs
-        print("inputs", inputs)
        theta = tf.nn.tanh(tf.matmul(theta_input, self.W) + self.b)
        outputs = transformer(U, theta, out_size=self.out_size)
        # automatically set batch_size and channels
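For reference, the six units built on these weights parameterise a 2x3 affine matrix: `W` starts at zero and `b` at the flattened identity `[[1, 0, 0], [0, 1, 0]]`, matching the initialisers above. A NumPy-only sketch of the shape logic in `build`/`forward` (toy values, not library code):

import numpy as np

theta_input = np.random.rand(2, 20).astype(np.float32)  # batch of 2 localisation features
W = np.zeros((20, 6), dtype=np.float32)                 # [in_channels, 6], zero-initialised
b = np.array([1, 0, 0, 0, 1, 0], dtype=np.float32)      # flattened 2x3 affine identity
theta = np.tanh(theta_input @ W + b)                    # (2, 6): one affine matrix per sample
print(theta.reshape(2, 2, 3))                           # about 0.76 on the diagonal, i.e. tanh(1)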
8 changes: 5 additions & 3 deletions tensorlayer/models/core.py
@@ -401,7 +401,7 @@ def trainable_weights(self):
            if layer.trainable_weights is not None:
                self._trainable_weights.extend(layer.trainable_weights)

-        return self._trainable_weights
+        return self._trainable_weights.copy()

    @property
    def nontrainable_weights(self):
@@ -415,7 +415,7 @@ def nontrainable_weights(self):
            if layer.nontrainable_weights is not None:
                self._nontrainable_weights.extend(layer.nontrainable_weights)

-        return self._nontrainable_weights
+        return self._nontrainable_weights.copy()

    @property
    def all_weights(self):
@@ -429,7 +429,7 @@ def all_weights(self):
            if layer.all_weights is not None:
                self._all_weights.extend(layer.all_weights)

-        return self._all_weights
+        return self._all_weights.copy()

    @property
    def config(self):
@@ -669,6 +669,8 @@ def _construct_graph(self):

        visited_node_names = set()
        for out_node in output_nodes:
+            if out_node.visited:
+                continue
            queue_node.put(out_node)

        while not queue_node.empty():
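Two fixes land in core.py. The weights properties now return a shallow copy, so a caller who mutates the returned list cannot corrupt the model's internal cache (the weight tensors themselves are still shared). In `_construct_graph`, the new guard skips output nodes that were already visited, which is exactly the STN case above: the second output `s` lies on the path to the logits output, so the first traversal reaches it. A minimal sketch of the new copy semantics (the two-layer model here is illustrative; any `tl.models.Model` behaves the same):

import tensorlayer as tl

ni = tl.layers.Input([None, 784])
nn = tl.layers.Dense(n_units=10)(ni)
model = tl.models.Model(inputs=ni, outputs=nn)  # throwaway static model

weights = model.trainable_weights               # now a shallow copy of the cached list
weights.append("not a weight")                  # mutating the copy...
assert len(model.trainable_weights) == len(weights) - 1  # ...leaves the model intact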
39 changes: 38 additions & 1 deletion tests/layers/test_layernode.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-
import os
import unittest

@@ -193,6 +192,44 @@ def MyModel():
        self.assertEqual(net.all_layers[1].model._nodes_fixed, True)
        self.assertEqual(net.all_layers[1].model.all_layers[0]._nodes_fixed, True)

+    def test_STN(self):
+        print('-' * 20, 'test STN', '-' * 20)
+
+        def get_model(inputs_shape):
+            ni = Input(inputs_shape)
+
+            ## 1. Localisation network
+            # use MLP as the localisation net
+            nn = Flatten()(ni)
+            nn = Dense(n_units=20, act=tf.nn.tanh)(nn)
+            nn = Dropout(keep=0.8)(nn)
+            # you can also use a CNN instead of an MLP as the localisation net
+
+            ## 2. Spatial transformer module (sampler)
+            stn = SpatialTransformer2dAffine(out_size=(40, 40), in_channels=20)
+            # s = stn((nn, ni))
+            nn = stn((nn, ni))
+            s = nn
+
+            ## 3. Classifier
+            nn = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME')(nn)
+            nn = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME')(nn)
+            nn = Flatten()(nn)
+            nn = Dense(n_units=1024, act=tf.nn.relu)(nn)
+            nn = Dense(n_units=10, act=tf.identity)(nn)
+
+            M = Model(inputs=ni, outputs=[nn, s])
+            return M
+
+        net = get_model([None, 40, 40, 1])
+
+        inputs = np.random.randn(2, 40, 40, 1).astype(np.float32)
+        o1, o2 = net(inputs, is_train=True)
+        self.assertEqual(o1.shape, (2, 10))
+        self.assertEqual(o2.shape, (2, 40, 40, 1))
+
+        self.assertEqual(len(net._node_by_depth), 10)
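+        # the ten depth levels: Input, Flatten, Dense, Dropout, STN, Conv2d, Conv2d, Flatten, Dense, Dense (one node each)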


if __name__ == '__main__':

9 changes: 9 additions & 0 deletions tests/models/test_model_core.py
@@ -392,6 +392,15 @@ def test_get_layer(self):
        except Exception as e:
            print(e)

+    def test_model_weights_copy(self):
+        print('-' * 20, 'test_model_weights_copy', '-' * 20)
+        model_basic = basic_static_model()
+        model_weights = model_basic.trainable_weights
+        ori_len = len(model_weights)
+        model_weights.append(np.arange(5))
+        new_len = len(model_weights)
+        self.assertEqual(new_len - 1, ori_len)


if __name__ == '__main__':
