Commit 20c4a4c

Impl scalar switch case op with condition op (#8184)
1 parent e583201 commit 20c4a4c

5 files changed: +171 -10 lines changed

doc/design/switch.md

Lines changed: 1 addition & 2 deletions

@@ -10,8 +10,7 @@ The following example shows the usage of `fluid.switch`.
 a = fluid.Var(10)
 b = fluid.Var(0)
 
-switch = fluid.switch()
-with switch.block():
+with switch() as switch:
     with switch.case(fluid.less_equal(a, 10)):
        fluid.print("Case 1")
    with switch.case(fluid.larger(a, 0)):
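The design-doc pseudo-code now uses the context-manager form instead of `switch.block()`. As a rough illustration of how that maps onto the concrete Python API added in this commit (a minimal sketch; the variable names are illustrative and the pattern mirrors the new unit test at the end of this diff):

    import paddle.v2.fluid.layers as layers

    # x and zero are shape-[1] tensors, so less_than yields a scalar bool condition
    x = layers.fill_constant(shape=[1], dtype='float32', value=0.5)
    zero = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
    result = layers.create_global_var(
        shape=[1], value=-1.0, dtype='float32', persistable=True)

    with layers.Switch() as switch:
        with switch.case(layers.less_than(x, zero)):
            layers.assign(zero, result)   # taken when x < 0
        with switch.default():
            layers.assign(x, result)      # taken when no case matched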

paddle/operators/conditional_block_op.cc

Lines changed: 38 additions & 6 deletions

@@ -41,6 +41,21 @@ class ConditionalOp : public framework::OperatorBase {
     });
     return retv;
   }
+
+  bool ScalarCondition(
+      const std::vector<const framework::LoDTensor *> &ips) const {
+    if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
+      PADDLE_THROW("should have one initialized input as condition");
+    }
+    if (!(ips[0]->type().hash_code() == typeid(bool).hash_code() &&
+          ips[0]->numel() == 1)) {
+      PADDLE_THROW(
+          "condition input's data type should be bool, "
+          "numel should be 1, actual numel is %d",
+          ips[0]->numel());
+    }
+    return ips[0]->data<bool>()[0];
+  }
 };
 
 class ConditionalBlockOp : public ConditionalOp {
@@ -53,9 +68,15 @@ class ConditionalBlockOp : public ConditionalOp {
   void Run(const framework::Scope &scope,
            const platform::Place &dev_place) const override {
     auto xs = InputTensors(scope);
-    bool need_run = std::all_of(
-        xs.begin(), xs.end(),
-        [](const framework::LoDTensor *t) { return t->numel() != 0; });
+
+    bool need_run;
+    if (Attr<bool>("is_scalar_condition")) {
+      need_run = ScalarCondition(xs);
+    } else {
+      need_run = std::all_of(
+          xs.begin(), xs.end(),
+          [](const framework::LoDTensor *t) { return t->numel() != 0; });
+    }
 
     if (need_run) {
       auto *scope_var = scope.FindVar(Output("Scope"));
@@ -88,6 +109,10 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
              "scope is std::vector<Scope*>");
     AddAttr<framework::BlockDesc *>(
         "sub_block", "The step block of conditional block operator");
+    AddAttr<bool>("is_scalar_condition",
+                  "the input X is used as scalar "
+                  "condition")
+        .SetDefault(false);
     AddComment(R"DOC(Conditional block operator
 
 Run the sub-block if X is not empty. Params is the other inputs and Out is the
@@ -106,9 +131,15 @@ class ConditionalBlockGradOp : public ConditionalOp {
   void Run(const framework::Scope &scope,
            const platform::Place &dev_place) const override {
     auto xs = this->InputTensors(scope);
-    bool need_run = std::all_of(
-        xs.begin(), xs.end(),
-        [](const framework::LoDTensor *t) { return t->numel() != 0; });
+
+    bool need_run;
+    if (Attr<bool>("is_scalar_condition")) {
+      need_run = ScalarCondition(xs);
+    } else {
+      need_run = std::all_of(
+          xs.begin(), xs.end(),
+          [](const framework::LoDTensor *t) { return t->numel() != 0; });
+    }
 
     if (need_run) {
       auto *scope_var = scope.FindVar(Input("Scope"));
@@ -182,6 +213,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
     grad_op->SetOutput(framework::GradVarName("Params"),
                        InputGrad("Params", false));
     grad_op->SetBlockAttr("sub_block", *this->grad_block_[0]);
+    grad_op->SetAttr("is_scalar_condition", GetAttr("is_scalar_condition"));
     return std::unique_ptr<framework::OpDesc>(grad_op);
   }
 };
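In plain terms, the forward and gradient conditional_block ops now decide whether to run their sub-block in one of two modes: the new scalar mode reads a single bool element, while the old mode runs only when every input tensor is non-empty. A rough pure-Python restatement of that decision (illustrative only, not Paddle API; numpy arrays stand in for LoDTensors):

    import numpy as np

    def need_run(xs, is_scalar_condition):
        # xs: list of input tensors (numpy arrays standing in for LoDTensors)
        if is_scalar_condition:
            # ScalarCondition: exactly one bool tensor with a single element
            assert len(xs) == 1 and xs[0].dtype == np.bool_ and xs[0].size == 1
            return bool(xs[0].item())
        # legacy behaviour: run only when every input is non-empty
        return all(t.size != 0 for t in xs)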

python/paddle/v2/fluid/layers/control_flow.py

Lines changed: 64 additions & 2 deletions

@@ -18,6 +18,7 @@
 from .. import core
 from ..framework import Program, Variable, Operator
 from ..layer_helper import LayerHelper, unique_name
+from ops import logical_and, logical_not, logical_or
 
 __all__ = [
     'split_lod_tensor',
@@ -27,6 +28,7 @@
     'StaticRNNMemoryLink',
     'WhileGuard',
     'While',
+    'Switch',
     'lod_rank_table',
     'max_sequence_len',
     'topk',
@@ -1063,11 +1065,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 
 
 class ConditionalBlock(object):
-    def __init__(self, inputs, name=None):
+    def __init__(self, inputs, is_scalar_condition=False, name=None):
         for each_input in inputs:
             if not isinstance(each_input, Variable):
                 raise TypeError("Each input should be variable")
         self.inputs = inputs
+        self.is_scalar_condition = is_scalar_condition
         self.helper = LayerHelper('conditional_block', name=name)
 
     def block(self):
@@ -1112,7 +1115,66 @@ def complete(self):
             },
             outputs={'Out': out_list,
                      'Scope': [step_scope]},
-            attrs={'sub_block': inside_block})
+            attrs={
+                'sub_block': inside_block,
+                'is_scalar_condition': self.is_scalar_condition
+            })
+
+
+class Switch(object):
+    def __init__(self, name=None):
+        self.helper = LayerHelper('switch', name=name)
+        self.inside_scope = False
+        self.pre_not_conditions = []
+
+    def case(self, condition):
+        """create a new block for this condition
+        """
+        if not self.inside_scope:
+            raise ValueError("case should be called inside with")
+
+        if len(self.pre_not_conditions) == 0:
+            cond_block = ConditionalBlock([condition], is_scalar_condition=True)
+            not_cond = logical_not(x=condition)
+            self.pre_not_conditions.append(not_cond)
+        else:
+            pre_cond_num = len(self.pre_not_conditions)
+            pre_not_cond = self.pre_not_conditions[pre_cond_num - 1]
+            new_not_cond = logical_and(
+                x=pre_not_cond, y=logical_not(x=condition))
+            self.pre_not_conditions.append(new_not_cond)
+            cond_block = ConditionalBlock(
+                [logical_and(
+                    x=pre_not_cond, y=condition)],
+                is_scalar_condition=True)
+
+        return ConditionalBlockGuard(cond_block)
+
+    def default(self):
+        """create a default case for this switch
+        """
+        pre_cond_num = len(self.pre_not_conditions)
+        if pre_cond_num == 0:
+            raise ValueError("there should be at least one condition")
+        cond_block = ConditionalBlock(
+            [self.pre_not_conditions[pre_cond_num - 1]],
+            is_scalar_condition=True)
+        return ConditionalBlockGuard(cond_block)
+
+    def __enter__(self):
+        """
+        set flag that now is inside switch.block {}
+        :return:
+        """
+        self.inside_scope = True
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.inside_scope = False
+        if exc_type is not None:
+            return False  # re-raise exception
+
+        return True
 
 
 class IfElseBlockGuard(object):
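The chaining in `Switch.case` is what makes the cases mutually exclusive: the k-th case fires only when its own condition holds and every earlier condition was false, and `default` fires only when all recorded conditions were false. A plain-Python illustration of that accumulation (illustrative only, not Paddle API):

    def simulate_switch(conditions, case_values, default_value):
        # conditions / case_values: parallel lists of plain bools and results
        pre_not = True  # "no earlier case matched", like pre_not_conditions[-1]
        for cond, value in zip(conditions, case_values):
            if pre_not and cond:        # guard built by logical_and(pre_not_cond, condition)
                return value
            pre_not = pre_not and (not cond)
        if pre_not:                     # guard used by switch.default()
            return default_value

    simulate_switch([False, True], ["a", "b"], "fallback")  # -> "b"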

python/paddle/v2/fluid/layers/ops.py

Lines changed: 4 additions & 0 deletions

@@ -61,6 +61,10 @@
     'clip_by_norm',
     'softmax',
     'sequence_softmax',
+    'logical_and',
+    'logical_or',
+    'logical_xor',
+    'logical_not',
 ] + __activations__
 
 for _OP in set(__all__):
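Listing the logical ops here is what lets control_flow.py import them as Python layers: the `for _OP in set(__all__)` loop generates a layer function for each registered op name. A small hedged sketch of calling them (variable names are illustrative; the keyword form mirrors how `Switch.case` calls these layers):

    import paddle.v2.fluid.layers as layers

    x = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
    y = layers.fill_constant(shape=[1], dtype='float32', value=2.0)
    a = layers.less_than(x, y)           # scalar bool tensor
    b = layers.less_than(y, x)
    not_a = layers.logical_not(x=a)
    a_and_b = layers.logical_and(x=a, y=b)
    a_or_b = layers.logical_or(x=a, y=b)
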
Lines changed: 64 additions & 0 deletions

@@ -0,0 +1,64 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle.v2.fluid.core as core
+import paddle.v2.fluid.layers as layers
+import paddle.v2.fluid.framework as framework
+from paddle.v2.fluid.executor import Executor
+from paddle.v2.fluid.framework import default_startup_program
+
+
+class TestSwitch(unittest.TestCase):
+    def check_switch(self, value):
+        x = layers.fill_constant(shape=[1], dtype='float32', value=value)
+
+        zero_var = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
+        one_var = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
+        two_var = layers.fill_constant(shape=[1], dtype='float32', value=2.0)
+        three_var = layers.fill_constant(shape=[1], dtype='float32', value=3.0)
+
+        result = layers.create_global_var(
+            shape=[1], value=-1.0, dtype='float32', persistable=True)
+
+        with layers.Switch() as switch:
+            with switch.case(layers.less_than(x, zero_var)):
+                layers.assign(zero_var, result)
+            with switch.case(layers.less_than(x, one_var)):
+                layers.assign(one_var, result)
+            with switch.case(layers.less_than(x, two_var)):
+                layers.assign(two_var, result)
+            with switch.default():
+                layers.assign(three_var, result)
+
+        cpu = core.CPUPlace()
+        exe = Executor(cpu)
+        exe.run(default_startup_program())
+
+        out = exe.run(feed={}, fetch_list=[result])[0][0]
+        return out
+
+    def test_switch(self):
+        test_data = {(-0.1, 0), (0.1, 1), (1.1, 2), (2.1, 3)}
+        for x, expected_result in test_data:
+            main_program = framework.Program()
+            startup_program = framework.Program()
+            with framework.program_guard(main_program, startup_program):
+                result = self.check_switch(x)
+            self.assertEqual(result, expected_result)
+
+
+if __name__ == '__main__':
+    unittest.main()
