Skip to content

Commit aabe1db

Browse files
authored
Feature/simple gan for api (#6149)
* Expose sigmoid_cross_entropy_with_logits Also, change the `labels` to `label` for api consistency * Very simple GAN based on pure FC layers
1 parent 813bbf4 commit aabe1db

File tree

1 file changed

+157
-0
lines changed

1 file changed

+157
-0
lines changed
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import errno
2+
import math
3+
import os
4+
5+
import matplotlib
6+
import numpy
7+
8+
import paddle.v2 as paddle
9+
import paddle.v2.fluid as fluid
10+
11+
matplotlib.use('Agg')
12+
import matplotlib.pyplot as plt
13+
import matplotlib.gridspec as gridspec
14+
15+
NOISE_SIZE = 100
16+
NUM_PASS = 1000
17+
NUM_REAL_IMGS_IN_BATCH = 121
18+
NUM_TRAIN_TIMES_OF_DG = 3
19+
LEARNING_RATE = 2e-5
20+
21+
22+
def D(x):
23+
hidden = fluid.layers.fc(input=x,
24+
size=200,
25+
act='relu',
26+
param_attr='D.w1',
27+
bias_attr='D.b1')
28+
logits = fluid.layers.fc(input=hidden,
29+
size=1,
30+
act=None,
31+
param_attr='D.w2',
32+
bias_attr='D.b2')
33+
return logits
34+
35+
36+
def G(x):
37+
hidden = fluid.layers.fc(input=x,
38+
size=200,
39+
act='relu',
40+
param_attr='G.w1',
41+
bias_attr='G.b1')
42+
img = fluid.layers.fc(input=hidden,
43+
size=28 * 28,
44+
act='tanh',
45+
param_attr='G.w2',
46+
bias_attr='G.b2')
47+
return img
48+
49+
50+
def plot(gen_data):
51+
gen_data.resize(gen_data.shape[0], 28, 28)
52+
n = int(math.ceil(math.sqrt(gen_data.shape[0])))
53+
fig = plt.figure(figsize=(n, n))
54+
gs = gridspec.GridSpec(n, n)
55+
gs.update(wspace=0.05, hspace=0.05)
56+
57+
for i, sample in enumerate(gen_data):
58+
ax = plt.subplot(gs[i])
59+
plt.axis('off')
60+
ax.set_xticklabels([])
61+
ax.set_yticklabels([])
62+
ax.set_aspect('equal')
63+
plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
64+
65+
return fig
66+
67+
68+
def main():
69+
try:
70+
os.makedirs("./out")
71+
except OSError as e:
72+
if e.errno != errno.EEXIST:
73+
raise
74+
75+
startup_program = fluid.Program()
76+
d_program = fluid.Program()
77+
dg_program = fluid.Program()
78+
79+
with fluid.program_guard(d_program, startup_program):
80+
img = fluid.layers.data(name='img', shape=[784], dtype='float32')
81+
d_loss = fluid.layers.sigmoid_cross_entropy_with_logits(
82+
x=D(img),
83+
label=fluid.layers.data(
84+
name='label', shape=[1], dtype='float32'))
85+
d_loss = fluid.layers.mean(x=d_loss)
86+
87+
with fluid.program_guard(dg_program, startup_program):
88+
noise = fluid.layers.data(
89+
name='noise', shape=[NOISE_SIZE], dtype='float32')
90+
g_img = G(x=noise)
91+
g_program = dg_program.clone()
92+
dg_loss = fluid.layers.sigmoid_cross_entropy_with_logits(
93+
x=D(g_img),
94+
label=fluid.layers.fill_constant_batch_size_like(
95+
input=noise, dtype='float32', shape=[-1, 1], value=1.0))
96+
dg_loss = fluid.layers.mean(x=dg_loss)
97+
98+
opt = fluid.optimizer.Adam(learning_rate=LEARNING_RATE)
99+
100+
opt.minimize(loss=d_loss, startup_program=startup_program)
101+
opt.minimize(
102+
loss=dg_loss,
103+
startup_program=startup_program,
104+
parameter_list=[
105+
p.name for p in g_program.global_block().all_parameters()
106+
])
107+
exe = fluid.Executor(fluid.CPUPlace())
108+
exe.run(startup_program)
109+
110+
num_true = NUM_REAL_IMGS_IN_BATCH
111+
train_reader = paddle.batch(
112+
paddle.reader.shuffle(
113+
paddle.dataset.mnist.train(), buf_size=60000),
114+
batch_size=num_true)
115+
116+
for pass_id in range(NUM_PASS):
117+
for batch_id, data in enumerate(train_reader()):
118+
num_true = len(data)
119+
n = numpy.random.uniform(
120+
low=-1.0, high=1.0,
121+
size=[num_true * NOISE_SIZE]).astype('float32').reshape(
122+
[num_true, NOISE_SIZE])
123+
generated_img = exe.run(g_program,
124+
feed={'noise': n},
125+
fetch_list={g_img})[0]
126+
real_data = numpy.array(map(lambda x: x[0], data)).astype('float32')
127+
real_data = real_data.reshape(num_true, 784)
128+
total_data = numpy.concatenate([real_data, generated_img])
129+
total_label = numpy.concatenate([
130+
numpy.ones(
131+
shape=[real_data.shape[0], 1], dtype='float32'),
132+
numpy.zeros(
133+
shape=[real_data.shape[0], 1], dtype='float32')
134+
])
135+
d_loss_np = exe.run(d_program,
136+
feed={'img': total_data,
137+
'label': total_label},
138+
fetch_list={d_loss})[0]
139+
for _ in xrange(NUM_TRAIN_TIMES_OF_DG):
140+
n = numpy.random.uniform(
141+
low=-1.0, high=1.0,
142+
size=[2 * num_true * NOISE_SIZE]).astype('float32').reshape(
143+
[2 * num_true, NOISE_SIZE, 1, 1])
144+
dg_loss_np = exe.run(dg_program,
145+
feed={'noise': n},
146+
fetch_list={dg_loss})[0]
147+
print("Pass ID={0}, Batch ID={1}, D-Loss={2}, DG-Loss={3}".format(
148+
pass_id, batch_id, d_loss_np, dg_loss_np))
149+
# generate image each batch
150+
fig = plot(generated_img)
151+
plt.savefig(
152+
'out/{0}.png'.format(str(pass_id).zfill(3)), bbox_inches='tight')
153+
plt.close(fig)
154+
155+
156+
if __name__ == '__main__':
157+
main()

0 commit comments

Comments
 (0)