import errno
import math
import os

import matplotlib
import numpy

import paddle.v2 as paddle
import paddle.v2.fluid as fluid

# Select the non-interactive Agg backend before pyplot is imported so the
# script can render figures without a display.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

NOISE_SIZE = 100  # dimension of the noise vector fed to the generator
NUM_PASS = 1000  # passes (epochs) over the MNIST training set
NUM_REAL_IMGS_IN_BATCH = 121  # real images per discriminator batch
NUM_TRAIN_TIMES_OF_DG = 3  # generator updates per discriminator update
LEARNING_RATE = 2e-5


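# A minimal GAN on MNIST: D() and G() below are two-layer fully connected
# networks sharing no parameters.  Each training batch first updates the
# discriminator once on a mix of real and generated images, then updates the
# generator NUM_TRAIN_TIMES_OF_DG times through the combined D(G(noise)) loss.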
def D(x):
    """Discriminator: 784 -> 200 (relu) -> 1 logit."""
    hidden = fluid.layers.fc(input=x,
                             size=200,
                             act='relu',
                             param_attr='D.w1',
                             bias_attr='D.b1')
    logits = fluid.layers.fc(input=hidden,
                             size=1,
                             act=None,
                             param_attr='D.w2',
                             bias_attr='D.b2')
    return logits


def G(x):
    """Generator: NOISE_SIZE -> 200 (relu) -> 28*28 image (tanh)."""
    hidden = fluid.layers.fc(input=x,
                             size=200,
                             act='relu',
                             param_attr='G.w1',
                             bias_attr='G.b1')
    img = fluid.layers.fc(input=hidden,
                          size=28 * 28,
                          act='tanh',
                          param_attr='G.w2',
                          bias_attr='G.b2')
    return img


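# G emits 28*28 tanh activations, so generated pixels lie in [-1, 1]; the real
# MNIST images from the paddle.v2 reader are expected to be on a comparable
# scale.  plot() below only arranges a batch of 28x28 samples into a roughly
# square grid for inspection.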
def plot(gen_data):
    # Arrange a batch of generated 28x28 images into an n-by-n grid.
    gen_data.resize(gen_data.shape[0], 28, 28)
    n = int(math.ceil(math.sqrt(gen_data.shape[0])))
    fig = plt.figure(figsize=(n, n))
    gs = gridspec.GridSpec(n, n)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(gen_data):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig


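# main() builds three fluid Programs against a shared startup program:
#   d_program  - discriminator loss on an explicitly labelled image batch,
#   dg_program - D stacked on top of G, trained against an all-ones label,
#   g_program  - a clone of dg_program taken right after G is built, used only
#                to sample images for the discriminator batch and for plotting.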
def main():
    try:
        os.makedirs("./out")
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

    startup_program = fluid.Program()
    d_program = fluid.Program()
    dg_program = fluid.Program()

    with fluid.program_guard(d_program, startup_program):
        img = fluid.layers.data(name='img', shape=[784], dtype='float32')
        d_loss = fluid.layers.sigmoid_cross_entropy_with_logits(
            x=D(img),
            label=fluid.layers.data(
                name='label', shape=[1], dtype='float32'))
        d_loss = fluid.layers.mean(x=d_loss)

    with fluid.program_guard(dg_program, startup_program):
        noise = fluid.layers.data(
            name='noise', shape=[NOISE_SIZE], dtype='float32')
        g_img = G(x=noise)
        # Snapshot the program while it only contains the generator.
        g_program = dg_program.clone()
        dg_loss = fluid.layers.sigmoid_cross_entropy_with_logits(
            x=D(g_img),
            label=fluid.layers.fill_constant_batch_size_like(
                input=noise, dtype='float32', shape=[-1, 1], value=1.0))
        dg_loss = fluid.layers.mean(x=dg_loss)

    opt = fluid.optimizer.Adam(learning_rate=LEARNING_RATE)

    opt.minimize(loss=d_loss, startup_program=startup_program)
    opt.minimize(
        loss=dg_loss,
        startup_program=startup_program,
        parameter_list=[
            p.name for p in g_program.global_block().all_parameters()
        ])
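    # The second minimize call is restricted (via parameter_list) to the
    # parameters recorded in g_program, i.e. G.w*/G.b*, so a dg_program step
    # updates only the generator while the discriminator weights stay fixed;
    # the first call touches only the D.* parameters that live in d_program.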
    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_program)

    num_true = NUM_REAL_IMGS_IN_BATCH
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.train(), buf_size=60000),
        batch_size=num_true)

    for pass_id in range(NUM_PASS):
        for batch_id, data in enumerate(train_reader()):
            num_true = len(data)
            # Sample noise and run the generator-only program to get fake
            # images for this discriminator step.
            n = numpy.random.uniform(
                low=-1.0, high=1.0,
                size=[num_true * NOISE_SIZE]).astype('float32').reshape(
                    [num_true, NOISE_SIZE])
            generated_img = exe.run(g_program,
                                    feed={'noise': n},
                                    fetch_list=[g_img])[0]
            real_data = numpy.array([x[0] for x in data]).astype('float32')
            real_data = real_data.reshape(num_true, 784)
            # Real images are labelled 1, generated images 0.
            total_data = numpy.concatenate([real_data, generated_img])
            total_label = numpy.concatenate([
                numpy.ones(
                    shape=[real_data.shape[0], 1], dtype='float32'),
                numpy.zeros(
                    shape=[real_data.shape[0], 1], dtype='float32')
            ])
            d_loss_np = exe.run(d_program,
                                feed={'img': total_data,
                                      'label': total_label},
                                fetch_list=[d_loss])[0]
            # Train the generator several times per discriminator update,
            # feeding a double-sized noise batch shaped to match the
            # [NOISE_SIZE] noise layer.
            for _ in range(NUM_TRAIN_TIMES_OF_DG):
                n = numpy.random.uniform(
                    low=-1.0, high=1.0,
                    size=[2 * num_true * NOISE_SIZE]).astype('float32').reshape(
                        [2 * num_true, NOISE_SIZE])
                dg_loss_np = exe.run(dg_program,
                                     feed={'noise': n},
                                     fetch_list=[dg_loss])[0]
            print("Pass ID={0}, Batch ID={1}, D-Loss={2}, DG-Loss={3}".format(
                pass_id, batch_id, d_loss_np, dg_loss_np))
            # Save a grid of this batch's generated images; the file name only
            # encodes the pass id, so it is overwritten on every batch.
            fig = plot(generated_img)
            plt.savefig(
                'out/{0}.png'.format(str(pass_id).zfill(3)), bbox_inches='tight')
            plt.close(fig)


if __name__ == '__main__':
    main()
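
# Usage: run this script directly with a Python interpreter whose PaddlePaddle
# build still provides the legacy paddle.v2.fluid API.  Sample grids are
# written to ./out/<zero-padded pass id>.png as training proceeds.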