Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions image_classification/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
TBD
48 changes: 48 additions & 0 deletions image_classification/reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import random
from paddle.v2.image import load_and_transform


def train_reader(train_list):
def reader():
with open(train_list, 'r') as f:
lines = [line.strip() for line in f]
random.shuffle(lines)
for line in lines:
img_path, lab = line.strip().split('\t')
im = load_and_transform(img_path, 256, 224, True)
yield im.flatten().astype('float32'), int(lab)

return reader


def test_reader(test_list):
def reader():
with open(test_list, 'r') as f:
lines = [line.strip() for line in f]
for line in lines:
img_path, lab = line.strip().split('\t')
im = load_and_transform(img_path, 256, 224, False)
yield im.flatten().astype('float32'), int(lab)

return reader


if __name__ == '__main__':
for im in train_reader('train.list'):
print len(im[0])
for im in train_reader('test.list'):
print len(im[0])
81 changes: 81 additions & 0 deletions image_classification/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import gzip

import paddle.v2 as paddle
import reader
import vgg

DATA_DIM = 3 * 224 * 224
CLASS_DIM = 1000
BATCH_SIZE = 128


def main():

# PaddlePaddle init
paddle.init(use_gpu=True, trainer_count=4)

image = paddle.layer.data(
name="image", type=paddle.data_type.dense_vector(DATA_DIM))
lbl = paddle.layer.data(
name="label", type=paddle.data_type.integer_value(CLASS_DIM))
net = vgg.vgg13(image)
out = paddle.layer.fc(
input=net, size=CLASS_DIM, act=paddle.activation.Softmax())
cost = paddle.layer.classification_cost(input=out, label=lbl)

# Create parameters
parameters = paddle.parameters.create(cost)

# Create optimizer
optimizer = paddle.optimizer.Momentum(
momentum=0.9,
regularization=paddle.optimizer.L2Regularization(rate=0.0005 *
BATCH_SIZE),
learning_rate=0.01 / BATCH_SIZE,
learning_rate_decay_a=0.1,
learning_rate_decay_b=128000 * 35,
learning_rate_schedule="discexp", )

train_reader = paddle.batch(
paddle.reader.shuffle(reader.test_reader("train.list"), buf_size=1000),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
reader.train_reader("test.list"), batch_size=BATCH_SIZE)

# End batch and end pass event handler
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 1 == 0:
print "\nPass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
if isinstance(event, paddle.event.EndPass):
with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f:
parameters.to_tar(f)

result = trainer.test(reader=test_reader)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)

# Create trainer
trainer = paddle.trainer.SGD(
cost=cost, parameters=parameters, update_equation=optimizer)

trainer.train(
reader=train_reader, num_passes=200, event_handler=event_handler)


if __name__ == '__main__':
main()
66 changes: 66 additions & 0 deletions image_classification/vgg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2 as paddle

__all__ = ['vgg13', 'vgg16', 'vgg19']


def vgg(input, nums):
def conv_block(input, num_filter, groups, num_channels=None):
return paddle.networks.img_conv_group(
input=input,
num_channels=num_channels,
pool_size=2,
pool_stride=2,
conv_num_filter=[num_filter] * groups,
conv_filter_size=3,
conv_act=paddle.activation.Relu(),
pool_type=paddle.pooling.Max())

assert len(nums) == 5
# the channel of input feature is 3
conv1 = conv_block(input, 64, nums[0], 3)
conv2 = conv_block(conv1, 128, nums[1])
conv3 = conv_block(conv2, 256, nums[2])
conv4 = conv_block(conv3, 512, nums[3])
conv5 = conv_block(conv4, 512, nums[4])

fc_dim = 4096
fc1 = paddle.layer.fc(
input=conv5,
size=fc_dim,
act=paddle.activation.Relu(),
layer_attr=paddle.attr.Extra(drop_rate=0.5))
fc2 = paddle.layer.fc(
input=fc1,
size=fc_dim,
act=paddle.activation.Relu(),
layer_attr=paddle.attr.Extra(drop_rate=0.5))
return fc2


def vgg13(input):
nums = [2, 2, 2, 2, 2]
return vgg(input, nums)


def vgg16(input):
nums = [2, 2, 3, 3, 3]
return vgg(input, nums)


def vgg19(input):
nums = [2, 2, 4, 4, 4]
return vgg(input, nums)