2 changes: 1 addition & 1 deletion benchmark/paddle/image/provider.py
@@ -22,5 +22,5 @@ def initHook(settings, height, width, color, num_class, **kwargs):
 def process(settings, file_list):
     for i in xrange(1024):
         img = np.random.rand(1, settings.data_size).reshape(-1, 1).flatten()
-        lab = random.randint(0, settings.num_class)
+        lab = random.randint(0, settings.num_class - 1)
         yield img.astype('float32'), int(lab)
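
The one-line fix above matters because Python's random.randint(a, b) samples inclusively from [a, b], so the old code could yield lab == num_class, one past the last valid class id. A quick self-contained check:

    import random

    num_class = 1000
    # randint samples from [a, b] INCLUSIVE, unlike range() or numpy.random.randint.
    old_max = max(random.randint(0, num_class) for _ in range(100000))      # can reach 1000
    new_max = max(random.randint(0, num_class - 1) for _ in range(100000))  # capped at 999
    print(old_max, new_max)
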
51 changes: 51 additions & 0 deletions benchmark/paddle/image/run_mkldnn.sh
@@ -0,0 +1,51 @@
set -e

unset OMP_NUM_THREADS MKL_NUM_THREADS
export OMP_DYNAMIC="FALSE"
export KMP_AFFINITY="granularity=fine,compact,0,0"

function train() {
  topology=$1
  bs=$2
  use_mkldnn=$3
  if [ $3 == "True" ]; then
    use_mkldnn=$3
    thread=1
    log="logs/${topology}-mkldnn-${bs}.log"
  elif [ $3 == "False" ]; then
    use_mkldnn=$3
    thread=`nproc`
    log="logs/${topology}-${thread}mklml-${bs}.log"
  else
    echo "Wrong input $3, use True or False."

Contributor: Lines 12 and 16 are redundant. (Those are the two use_mkldnn=$3 reassignments inside the if/elif branches; use_mkldnn is already assigned on line 10.)

  fi
  args="batch_size=${bs}"
  config="${topology}.py"
  paddle train --job=time \
    --config=$config \
    --use_mkldnn=$use_mkldnn \
    --use_gpu=False \
    --trainer_count=$thread \
    --log_period=10 \
    --test_period=100 \
    --config_args=$args \
    2>&1 | tee ${log}
}

if [ ! -d "train.list" ]; then
echo " " > train.list
fi
if [ ! -d "logs" ]; then
mkdir logs
fi

#========= mkldnn =========#
# vgg

Contributor: Line 43 (the "# vgg" comment) is also redundant.

train vgg 64 True
train vgg 128 True
train vgg 256 True

#========== mklml ===========#
train vgg 64 False
train vgg 128 False
train vgg 256 False
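
For readers skimming the benchmark script: the third argument to train picks both the thread count and the log name, so each MKL-DNN run is pinned to a single thread while the MKLML baseline uses every core (nproc). A rough Python paraphrase of that dispatch, purely illustrative:

    import os

    def plan_run(topology, bs, use_mkldnn, nproc=os.cpu_count()):
        # Mirrors run_mkldnn.sh: MKL-DNN pinned to one thread, MKLML on all cores.
        if use_mkldnn:
            return 1, "logs/%s-mkldnn-%d.log" % (topology, bs)
        return nproc, "logs/%s-%dmklml-%d.log" % (topology, nproc, bs)

    print(plan_run("vgg", 64, True))   # (1, 'logs/vgg-mkldnn-64.log')
    print(plan_run("vgg", 64, False))  # e.g. (8, 'logs/vgg-8mklml-64.log')
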
103 changes: 103 additions & 0 deletions benchmark/paddle/image/vgg.py
@@ -0,0 +1,103 @@
#!/usr/bin/env python
from paddle.trainer_config_helpers import *

height = 224
width = 224
num_class = 1000
batch_size = get_config_arg('batch_size', int, 64)
layer_num = get_config_arg('layer_num', int, 19)

args = {'height': height, 'width': width, 'color': True, 'num_class': num_class}
define_py_data_sources2(
    "train.list", None, module="provider", obj="process", args=args)

settings(
    batch_size=batch_size,
    learning_rate=0.01 / batch_size,
    learning_method=MomentumOptimizer(0.9),
    regularization=L2Regularization(0.0005 * batch_size))

img = data_layer(name='image', size=height * width * 3)


def vgg_network(vgg_num=3):
    tmp = img_conv_group(
        input=img,
        num_channels=3,
        conv_padding=1,
        conv_num_filter=[64, 64],
        conv_filter_size=3,
        conv_act=ReluActivation(),
        pool_size=2,
        pool_stride=2,
        pool_type=MaxPooling())

    tmp = img_conv_group(
        input=tmp,
        conv_num_filter=[128, 128],
        conv_padding=1,
        conv_filter_size=3,
        conv_act=ReluActivation(),
        pool_stride=2,
        pool_type=MaxPooling(),
        pool_size=2)

    channels = []
    for i in range(vgg_num):
        channels.append(256)
    tmp = img_conv_group(
        input=tmp,
        conv_num_filter=channels,
        conv_padding=1,
        conv_filter_size=3,
        conv_act=ReluActivation(),
        pool_stride=2,
        pool_type=MaxPooling(),
        pool_size=2)
    channels = []
    for i in range(vgg_num):
        channels.append(512)
    tmp = img_conv_group(
        input=tmp,
        conv_num_filter=channels,
        conv_padding=1,
        conv_filter_size=3,
        conv_act=ReluActivation(),
        pool_stride=2,
        pool_type=MaxPooling(),
        pool_size=2)
    tmp = img_conv_group(
        input=tmp,
        conv_num_filter=channels,
        conv_padding=1,
        conv_filter_size=3,
        conv_act=ReluActivation(),
        pool_stride=2,
        pool_type=MaxPooling(),
        pool_size=2)

    tmp = fc_layer(
        input=tmp,
        size=4096,
        act=ReluActivation(),
        layer_attr=ExtraAttr(drop_rate=0.5))

    tmp = fc_layer(
        input=tmp,
        size=4096,
        act=ReluActivation(),
        layer_attr=ExtraAttr(drop_rate=0.5))

    return fc_layer(input=tmp, size=num_class, act=SoftmaxActivation())


if layer_num == 16:
    vgg = vgg_network(3)
elif layer_num == 19:
    vgg = vgg_network(4)
else:
    print("Wrong layer number.")

lab = data_layer('label', num_class)
loss = cross_entropy(input=vgg, label=lab)
outputs(loss)
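
A quick sanity check on how vgg_num maps to the usual VGG naming: the function stacks two fixed 2-conv groups, three groups of vgg_num convs (256/512/512), and three fully connected layers, all of which carry weights:

    def vgg_depth(vgg_num):
        # Weight layers: 2 + 2 convs in the fixed groups,
        # 3 * vgg_num convs in the remaining groups, plus 3 fc layers.
        return 2 + 2 + 3 * vgg_num + 3

    print(vgg_depth(3))  # 16 -> VGG-16
    print(vgg_depth(4))  # 19 -> VGG-19
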
4 changes: 4 additions & 0 deletions cmake/util.cmake
@@ -97,6 +97,10 @@ function(link_paddle_exe TARGET_NAME)
     target_link_libraries(${TARGET_NAME} log)
   endif(ANDROID)
 
+  if(WITH_MKLDNN AND WITH_MKLML AND MKLDNN_IOMP_DIR)
+    target_link_libraries(${TARGET_NAME} "-L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed")
+  endif()
+
   add_dependencies(${TARGET_NAME} ${external_project_dependencies})
 endfunction()

3 changes: 2 additions & 1 deletion paddle/gserver/activations/MKLDNNActivation.h
@@ -100,6 +100,7 @@ class MKLDNNEltwiseActivation : public MKLDNNActivation {
     if (cnt_ == act.value->getElementCnt()) {
       return;
     }
+    VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
     cnt_ = act.value->getElementCnt();
     stream_.reset(new MKLDNNStream());
     auto eng = CPUEngine::Instance().getEngine();
@@ -110,7 +111,6 @@ class MKLDNNEltwiseActivation : public MKLDNNActivation {
     float alpha = getAlpha();
     float beta = getBeta();
 
-    /// forward
     pipelineFwd_.clear();
     val_ = std::dynamic_pointer_cast<MKLDNNMatrix>(act.value);
     if (val_ == nullptr) {
@@ -152,6 +152,7 @@ class MKLDNNEltwiseActivation : public MKLDNNActivation {
     if (!needResetBwd_) {
       return;
     }
+    VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward";
     needResetBwd_ = false;
     mkldnn::algorithm algo = getAlgo(this->getName());
     float alpha = getBwdAlpha();
31 changes: 20 additions & 11 deletions paddle/gserver/layers/MKLDNNConvLayer.cpp
@@ -64,7 +64,7 @@ bool MKLDNNConvLayer::init(const LayerMap& layerMap,
 
   // create biases
   if (biasParameter_.get() != NULL) {
-    biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_));
+    biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_, 0));
   }
   return true;
 }
@@ -251,22 +251,31 @@ void MKLDNNConvLayer::resetInValue(
   // create buffer and reorder if input value do not match
   cpuInVal_ = nullptr;
   cvtInVal_ = nullptr;
-  if (inputIsOnlyMKLDNN()) {
-    MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
-    CHECK(dnnIn) << "Input should be MKLDNNMatrix";
-    if (dnnIn->getPrimitiveDesc() != in->getPrimitiveDesc()) {
-      CHECK_EQ(dnnIn->getFormat(), format::nc);
+
+  MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
+  CHECK_EQ(inputIsOnlyMKLDNN(), dnnIn != nullptr);
+  if (dnnIn != nullptr && dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) {
+    in = dnnIn;
+    return;
+  }
+  if (dnnIn) {
+    if (dnnIn->getFormat() == format::nc) {
       CHECK(ih_ == 1 && iw_ == 1) << "when input is nc format";
       // create a new one with nchw format and same data
       memory::dims inDims = memory::dims{bs_, ic_, 1, 1};
       dnnIn = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_);
-      CHECK(dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc());
     }
-    in = dnnIn;
+    if (dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) {
+      in = dnnIn;
+      return;
+    }
+    cpuInVal_ = dnnIn;
+    in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc());
+    cvtInVal_ = MKLDNNMatrix::createReorder(cpuInVal_, in);
+    CHECK(cvtInVal_) << "should not be empty";
   } else {
-    const MatrixPtr& cpuIn = getInputValue(0, CPU_DEVICE);
     memory::dims inDims = memory::dims{bs_, ic_, ih_, iw_};
-    cpuInVal_ = MKLDNNMatrix::create(cpuIn, inDims, format::nchw, engine_);
+    cpuInVal_ = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_);
     if (cpuInVal_->getPrimitiveDesc() != in->getPrimitiveDesc()) {
       // create new mkldnn matrix
       in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc());
@@ -535,7 +544,7 @@ void MKLDNNConvLayer::resetWgtValBwdData(
   } else {
     wgtValBwdData_ = wgtVal_;
   }
-  VLOG(MKLDNN_FMTS) << "weight value format for backward data"
+  VLOG(MKLDNN_FMTS) << "weight value format for backward data: "
                     << wgtValBwdData_->getFormat();
 }

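
Diff form makes the reworked resetInValue flow hard to follow, so here is a hedged Python paraphrase of the new branching, using toy types rather than the PaddlePaddle API: reuse the input buffer when its layout already matches what the convolution primitive expects, otherwise create a matching buffer plus a reorder.

    from dataclasses import dataclass
    from typing import Optional, Tuple

    @dataclass
    class Mat:                  # toy stand-in for MKLDNNMatrix
        fmt: str                # memory format tag, e.g. "nchw", "nChw8c", "nc"
        is_dnn: bool = False    # whether it came from an MKL-DNN layer

    def reset_in_value(in_mat: Mat, target_fmt: str,
                       ih: int = 1, iw: int = 1) -> Tuple[Mat, Optional[str]]:
        dnn_in = in_mat if in_mat.is_dnn else None   # dynamic_pointer_cast analogue
        if dnn_in is not None and dnn_in.fmt == target_fmt:
            return dnn_in, None                      # layouts match: reuse directly
        if dnn_in is not None:
            if dnn_in.fmt == "nc":                   # 1x1 spatial: reinterpret as nchw
                assert ih == 1 and iw == 1, "when input is nc format"
                dnn_in = Mat("nchw", is_dnn=True)
            if dnn_in.fmt == target_fmt:
                return dnn_in, None
            return Mat(target_fmt, True), "reorder %s -> %s" % (dnn_in.fmt, target_fmt)
        # plain CPU input: wrap as nchw and reorder if the conv wants another layout
        cpu_in = Mat("nchw")
        if cpu_in.fmt != target_fmt:
            return Mat(target_fmt, True), "reorder nchw -> %s" % target_fmt
        return cpu_in, None

    print(reset_in_value(Mat("nChw8c", is_dnn=True), "nChw8c"))  # reuse, no reorder
    print(reset_in_value(Mat("nchw"), "nChw8c"))                 # needs a reorder
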
11 changes: 9 additions & 2 deletions paddle/gserver/layers/MKLDNNFcLayer.cpp
@@ -49,7 +49,7 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap,
 
   // create biases
   if (biasParameter_.get() != NULL) {
-    biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_));
+    biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_, 0));
   }
   return true;
 }
@@ -161,9 +161,16 @@ void MKLDNNFcLayer::resetInValue(MKLDNNMatrixPtr& in) {
 
 void MKLDNNFcLayer::resetWgtBiasValue(MKLDNNMatrixPtr& wgt,
                                       MKLDNNMatrixPtr& bias) {
+  format wgtFmt = format::oihw;
+  if (inVal_->getFormat() == format::nChw8c) {
+    wgtFmt = format::oIhw8i;
+  } else if (inVal_->getFormat() == format::nChw16c) {
+    wgtFmt = format::oIhw16i;
+  }
   wgt = MKLDNNMatrix::create(
-      weight_->getW(), {oc_, ic_, ih_, iw_}, format::oihw, engine_);
+      weight_->getW(), {oc_, ic_, ih_, iw_}, wgtFmt, engine_);
   wgt->downSpatial();
+  VLOG(MKLDNN_FMTS) << "Weight value format: " << wgt->getFormat();
 
   bias = (biases_ && biases_->getW())
             ? MKLDNNMatrix::create(biases_->getW(), {oc_}, format::x, engine_)
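
The new wgtFmt selection keeps the fully connected weights blocked the same way as the incoming activations: nChw8c and nChw16c are channel-blocked activation layouts (blocks of 8 and 16), and oIhw8i/oIhw16i are the weight layouts blocked to match. A small sketch of the mapping, with the format names taken from the diff:

    # Input activation format -> FC weight format chosen above.
    WGT_FMT_FOR_INPUT = {
        "nChw8c": "oIhw8i",
        "nChw16c": "oIhw16i",
    }

    def pick_wgt_fmt(in_fmt):
        return WGT_FMT_FOR_INPUT.get(in_fmt, "oihw")  # plain oihw otherwise

    assert pick_wgt_fmt("nChw8c") == "oIhw8i"
    assert pick_wgt_fmt("nchw") == "oihw"
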
2 changes: 2 additions & 0 deletions paddle/gserver/layers/MKLDNNLayer.h
@@ -115,6 +115,7 @@ class MKLDNNLayer : public Layer {
     copySeqInfoToOutputs();
     size_t elemenCnt = inputLayers_[0]->getOutput().value->getElementCnt();
     if (inputElemenCnt_ != elemenCnt) {
+      VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
       // reset when input total sizes changed, not only the batchsize
       inputElemenCnt_ = elemenCnt;
       reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_);
@@ -142,6 +143,7 @@ class MKLDNNLayer : public Layer {
 
   void backward(const UpdateCallback& callback) override {
     if (needResetBwd_) {
+      VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward";
       resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_);
       needResetBwd_ = false;
     }
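
Both VLOG lines instrument the same lazy-reset pattern: the MKL-DNN pipeline is rebuilt only when the input's total element count changes (which catches shape changes as well as batch-size changes), while consecutive same-sized batches reuse the cached primitives. A minimal sketch of the pattern, with hypothetical names:

    import numpy as np

    class LazyResetLayer:
        def __init__(self):
            self.input_elem_cnt = 0        # mirrors inputElemenCnt_
            self.pipeline = None

        def forward(self, batch):
            cnt = batch.size               # total elements, not just batch size
            if cnt != self.input_elem_cnt:
                print("reset mkldnn forward")              # the new VLOG's moment
                self.input_elem_cnt = cnt
                self.pipeline = "primitives for %s" % (batch.shape,)
            return self.pipeline

    layer = LazyResetLayer()
    layer.forward(np.zeros((64, 3, 224, 224)))  # builds and logs
    layer.forward(np.zeros((64, 3, 224, 224)))  # cached, silent
    layer.forward(np.zeros((32, 3, 224, 224)))  # size changed -> reset again
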
13 changes: 13 additions & 0 deletions paddle/trainer/tests/CMakeLists.txt
@@ -37,6 +37,19 @@ add_test(NAME test_CompareTwoNets
         --config_file_a=trainer/tests/sample_trainer_config_qb_rnn.conf --config_file_b=trainer/tests/sample_trainer_config_rnn.conf
     WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/)
 
+################ test_CompareMKLDNNandCPU ######################
+if(WITH_MKLDNN)
+  add_unittest_without_exec(test_CompareMKLDNNandCPU
+      test_CompareTwoNets.cpp)
+  add_test(NAME test_CompareMKLDNNandCPU
+    COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python/
+        ${CMAKE_CURRENT_BINARY_DIR}/test_CompareMKLDNNandCPU
+        --config_file_a=trainer/tests/sample_trainer_config_simple_net.conf --use_mkldnn_a=True
+        --config_file_b=trainer/tests/sample_trainer_config_simple_net.conf --use_mkldnn_b=False
+        --use_gpu=False

Contributor: Couldn't the use_gpu on line 49 be removed?
Contributor (Author): It is intentional: we don't want the GPU used when running this comparison at the moment.

+    WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/)
+endif()
+
 ############### test_CompareTwoOpts ###################
 add_unittest_without_exec(test_CompareTwoOpts
     test_CompareTwoOpts.cpp)
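
The new target is an A/B harness: test_CompareTwoNets runs the same simple-net config twice, differing only in the use_mkldnn flag, and checks that the two runs agree. Sketched with a hypothetical train_once helper (the tolerance is illustrative, not the test's actual value):

    def train_once(config, use_mkldnn):
        # Stand-in for one CPU training run; would return final parameters.
        return [0.1, 0.2, 0.3]

    def compare_mkldnn_and_cpu(config="sample_trainer_config_simple_net.conf"):
        out_a = train_once(config, use_mkldnn=True)
        out_b = train_once(config, use_mkldnn=False)
        assert all(abs(a - b) <= 1e-5 for a, b in zip(out_a, out_b))

    compare_mkldnn_and_cpu()
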
63 changes: 63 additions & 0 deletions paddle/trainer/tests/sample_trainer_config_simple_net.conf
@@ -0,0 +1,63 @@
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer_config_helpers import *

################################### Data Configuration ###################################
TrainData(ProtoData(files = "trainer/tests/mnist.list"))
################################### Algorithm Configuration ###################################
settings(batch_size = 1000,
         learning_method = MomentumOptimizer(momentum=0.5, sparse=False))
################################### Network Configuration ###################################
data = data_layer(name ="input", size=784)

tmp = img_conv_layer(input=data,
                     num_channels=1,
                     filter_size=3,
                     num_filters=32,
                     padding=1,
                     shared_biases=True,
                     act=ReluActivation())

tmp = img_pool_layer(input=tmp,
                     pool_size=3,
                     stride=2,
                     padding=1,
                     pool_type=AvgPooling())

tmp = img_conv_layer(input=tmp,
                     filter_size=3,
                     num_filters=64,
                     padding=1,
                     shared_biases=True,
                     act=ReluActivation())

tmp = img_pool_layer(input=tmp,
                     pool_size=3,
                     stride=2,
                     padding=1,
                     pool_type=MaxPooling())

tmp = fc_layer(input=tmp, size=64,
               bias_attr=True,
               act=ReluActivation())

output = fc_layer(input=tmp, size=10,
                  bias_attr=True,
                  act=SoftmaxActivation())

lbl = data_layer(name ="label", size=10)

cost = classification_cost(input=output, label=lbl)
outputs(cost)
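
A sanity check on the simple net's geometry: the 784 inputs are a 28x28 image, each 3x3 conv with padding 1 preserves the spatial size, and each size-3 stride-2 padding-1 pool roughly halves it. With the usual output-size formula (whether the division floors or ceils is framework-dependent, so it is a parameter here rather than an assertion about Paddle):

    import math

    def out_size(in_size, kernel, stride, pad, ceil_mode=False):
        frac = (in_size + 2 * pad - kernel) / stride
        return (math.ceil(frac) if ceil_mode else math.floor(frac)) + 1

    s = 28                                      # 784 = 28 * 28 MNIST input
    s = out_size(s, kernel=3, stride=1, pad=1)  # conv1 -> 28
    s = out_size(s, kernel=3, stride=2, pad=1)  # pool1 -> 14 in floor mode
    s = out_size(s, kernel=3, stride=1, pad=1)  # conv2 -> 14
    s = out_size(s, kernel=3, stride=2, pad=1)  # pool2 -> 7
    print(s)  # under floor mode the first fc_layer sees 64 * 7 * 7 features
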