|
18 | 18 | from .. import core |
19 | 19 | from ..framework import Program, Variable, Operator |
20 | 20 | from ..layer_helper import LayerHelper, unique_name |
| 21 | +from ops import logical_and, logical_not, logical_or |
21 | 22 |
|
22 | 23 | __all__ = [ |
23 | 24 | 'split_lod_tensor', |
|
27 | 28 | 'StaticRNNMemoryLink', |
28 | 29 | 'WhileGuard', |
29 | 30 | 'While', |
| 31 | + 'Switch', |
30 | 32 | 'lod_rank_table', |
31 | 33 | 'max_sequence_len', |
32 | 34 | 'topk', |
@@ -1063,11 +1065,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): |
1063 | 1065 |
|
1064 | 1066 |
|
1065 | 1067 | class ConditionalBlock(object): |
1066 | | - def __init__(self, inputs, name=None): |
| 1068 | + def __init__(self, inputs, is_scalar_condition=False, name=None): |
1067 | 1069 | for each_input in inputs: |
1068 | 1070 | if not isinstance(each_input, Variable): |
1069 | 1071 | raise TypeError("Each input should be variable") |
1070 | 1072 | self.inputs = inputs |
| 1073 | + self.is_scalar_condition = is_scalar_condition |
1071 | 1074 | self.helper = LayerHelper('conditional_block', name=name) |
1072 | 1075 |
|
1073 | 1076 | def block(self): |
@@ -1112,7 +1115,66 @@ def complete(self): |
1112 | 1115 | }, |
1113 | 1116 | outputs={'Out': out_list, |
1114 | 1117 | 'Scope': [step_scope]}, |
1115 | | - attrs={'sub_block': inside_block}) |
| 1118 | + attrs={ |
| 1119 | + 'sub_block': inside_block, |
| 1120 | + 'is_scalar_condition': self.is_scalar_condition |
| 1121 | + }) |
| 1122 | + |
| 1123 | + |
| 1124 | +class Switch(object): |
| 1125 | + def __init__(self, name=None): |
| 1126 | + self.helper = LayerHelper('switch', name=name) |
| 1127 | + self.inside_scope = False |
| 1128 | + self.pre_not_conditions = [] |
| 1129 | + |
| 1130 | + def case(self, condition): |
| 1131 | + """create a new block for this condition |
| 1132 | + """ |
| 1133 | + if not self.inside_scope: |
| 1134 | + raise ValueError("case should be called inside with") |
| 1135 | + |
| 1136 | + if len(self.pre_not_conditions) == 0: |
| 1137 | + cond_block = ConditionalBlock([condition], is_scalar_condition=True) |
| 1138 | + not_cond = logical_not(x=condition) |
| 1139 | + self.pre_not_conditions.append(not_cond) |
| 1140 | + else: |
| 1141 | + pre_cond_num = len(self.pre_not_conditions) |
| 1142 | + pre_not_cond = self.pre_not_conditions[pre_cond_num - 1] |
| 1143 | + new_not_cond = logical_and( |
| 1144 | + x=pre_not_cond, y=logical_not(x=condition)) |
| 1145 | + self.pre_not_conditions.append(new_not_cond) |
| 1146 | + cond_block = ConditionalBlock( |
| 1147 | + [logical_and( |
| 1148 | + x=pre_not_cond, y=condition)], |
| 1149 | + is_scalar_condition=True) |
| 1150 | + |
| 1151 | + return ConditionalBlockGuard(cond_block) |
| 1152 | + |
| 1153 | + def default(self): |
| 1154 | + """create a default case for this switch |
| 1155 | + """ |
| 1156 | + pre_cond_num = len(self.pre_not_conditions) |
| 1157 | + if pre_cond_num == 0: |
| 1158 | + raise ValueError("there should be at least one condition") |
| 1159 | + cond_block = ConditionalBlock( |
| 1160 | + [self.pre_not_conditions[pre_cond_num - 1]], |
| 1161 | + is_scalar_condition=True) |
| 1162 | + return ConditionalBlockGuard(cond_block) |
| 1163 | + |
| 1164 | + def __enter__(self): |
| 1165 | + """ |
| 1166 | + set flag that now is inside switch.block {} |
| 1167 | + :return: |
| 1168 | + """ |
| 1169 | + self.inside_scope = True |
| 1170 | + return self |
| 1171 | + |
| 1172 | + def __exit__(self, exc_type, exc_val, exc_tb): |
| 1173 | + self.inside_scope = False |
| 1174 | + if exc_type is not None: |
| 1175 | + return False # re-raise exception |
| 1176 | + |
| 1177 | + return True |
1116 | 1178 |
|
1117 | 1179 |
|
1118 | 1180 | class IfElseBlockGuard(object): |
|
0 commit comments