Skip to content

Commit b5e6f15

Browse files
authored
[CodeStyle][SIM117] Combine multiple with statements (part2) (#73659)
1 parent 1893444 commit b5e6f15

File tree

97 files changed

+3993
-3684
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

97 files changed

+3993
-3684
lines changed

test/deprecated/legacy_test/test_program_deprecated.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,12 @@ def test_copy_info_from_error(self):
140140
def build_program():
141141
main_program = paddle.static.Program()
142142
startup_program = paddle.static.Program()
143-
with paddle.utils.unique_name.guard():
144-
with paddle.static.program_guard(main_program, startup_program):
145-
x = paddle.static.data(name='x', shape=[3, 2, 1])
146-
out = paddle.static.nn.fc(x=x, size=1, num_flatten_dims=2)
143+
with (
144+
paddle.utils.unique_name.guard(),
145+
paddle.static.program_guard(main_program, startup_program),
146+
):
147+
x = paddle.static.data(name='x', shape=[3, 2, 1])
148+
out = paddle.static.nn.fc(x=x, size=1, num_flatten_dims=2)
147149
return main_program
148150

149151

@@ -177,10 +179,12 @@ class TestProgramHash(unittest.TestCase):
177179
def build_program(self):
178180
main_program = paddle.static.Program()
179181
startup_program = paddle.static.Program()
180-
with paddle.utils.unique_name.guard():
181-
with paddle.static.program_guard(main_program, startup_program):
182-
x = paddle.static.data(name='x', shape=[3, 2, 1])
183-
out = paddle.static.nn.fc(x=x, size=1, num_flatten_dims=2)
182+
with (
183+
paddle.utils.unique_name.guard(),
184+
paddle.static.program_guard(main_program, startup_program),
185+
):
186+
x = paddle.static.data(name='x', shape=[3, 2, 1])
187+
out = paddle.static.nn.fc(x=x, size=1, num_flatten_dims=2)
184188
return main_program
185189

186190
def test_program_need_update(self):

test/deprecated/legacy_test/test_program_prune_backward_deprecated.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -579,10 +579,12 @@ def program_scope_guard(self):
579579
prog = base.Program()
580580
startup_prog = base.Program()
581581
scope = base.core.Scope()
582-
with base.scope_guard(scope):
583-
with base.program_guard(prog, startup_prog):
584-
with base.unique_name.guard():
585-
yield
582+
with (
583+
base.scope_guard(scope),
584+
base.program_guard(prog, startup_prog),
585+
base.unique_name.guard(),
586+
):
587+
yield
586588

587589

588590
if __name__ == '__main__':

test/deprecated/legacy_test/test_prune_deprecated.py

Lines changed: 387 additions & 379 deletions
Large diffs are not rendered by default.

test/deprecated/legacy_test/test_py_func_op_deprecated.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -176,38 +176,36 @@ def test_main(use_cuda, use_py_func_op):
176176
if use_cuda and not base.core.is_compiled_with_cuda():
177177
return None
178178

179-
with base.program_guard(base.Program(), base.Program()):
180-
with base.scope_guard(base.core.Scope()):
181-
gen = paddle.seed(1)
182-
np.random.seed(1)
183-
img = paddle.static.data(
184-
name='image', shape=[-1, 784], dtype='float32'
185-
)
186-
label = paddle.static.data(
187-
name='label', shape=[-1, 1], dtype='int64'
188-
)
189-
loss = simple_fc_net(img, label, use_py_func_op)
190-
optimizer = paddle.optimizer.SGD(learning_rate=1e-3)
191-
optimizer.minimize(loss)
192-
193-
place = base.CUDAPlace(0) if use_cuda else base.CPUPlace()
194-
feeder = base.DataFeeder(feed_list=[img, label], place=place)
195-
r = paddle.batch(reader, batch_size=10)
196-
197-
exe = base.Executor(place)
198-
exe.run(base.default_startup_program())
199-
200-
train_cp = base.default_main_program()
201-
fetch_list = [loss]
202-
203-
ret = []
204-
for epoch_id in range(2):
205-
for d in r():
206-
(L,) = exe.run(
207-
train_cp, feed=feeder.feed(d), fetch_list=fetch_list
208-
)
209-
ret.append(L)
210-
return np.array(ret)
179+
with (
180+
base.program_guard(base.Program(), base.Program()),
181+
base.scope_guard(base.core.Scope()),
182+
):
183+
gen = paddle.seed(1)
184+
np.random.seed(1)
185+
img = paddle.static.data(name='image', shape=[-1, 784], dtype='float32')
186+
label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64')
187+
loss = simple_fc_net(img, label, use_py_func_op)
188+
optimizer = paddle.optimizer.SGD(learning_rate=1e-3)
189+
optimizer.minimize(loss)
190+
191+
place = base.CUDAPlace(0) if use_cuda else base.CPUPlace()
192+
feeder = base.DataFeeder(feed_list=[img, label], place=place)
193+
r = paddle.batch(reader, batch_size=10)
194+
195+
exe = base.Executor(place)
196+
exe.run(base.default_startup_program())
197+
198+
train_cp = base.default_main_program()
199+
fetch_list = [loss]
200+
201+
ret = []
202+
for epoch_id in range(2):
203+
for d in r():
204+
(L,) = exe.run(
205+
train_cp, feed=feeder.feed(d), fetch_list=fetch_list
206+
)
207+
ret.append(L)
208+
return np.array(ret)
211209

212210

213211
class TestPyFuncOpUseExecutor(unittest.TestCase):

test/deprecated/legacy_test/utils.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -184,14 +184,16 @@ def impl(*args, **kwargs):
184184
original_flag_value = get_flags(pt_flag)[pt_flag]
185185
if os.environ.get('FLAGS_use_stride_kernel', False):
186186
return
187-
with static.scope_guard(static.Scope()):
188-
with static.program_guard(static.Program()):
189-
with EnvironmentVariableGuard(ENV_ENABLE_PIR_WITH_PT, True):
190-
try:
191-
set_flags({pt_flag: True})
192-
ir_outs = fn(*args, **kwargs)
193-
finally:
194-
set_flags({pt_flag: original_flag_value})
187+
with (
188+
static.scope_guard(static.Scope()),
189+
static.program_guard(static.Program()),
190+
EnvironmentVariableGuard(ENV_ENABLE_PIR_WITH_PT, True),
191+
):
192+
try:
193+
set_flags({pt_flag: True})
194+
ir_outs = fn(*args, **kwargs)
195+
finally:
196+
set_flags({pt_flag: original_flag_value})
195197
return ir_outs
196198

197199
return impl

test/deprecated/mkldnn/test_mkldnn_cpu_bfloat16_pass_deprecated.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,25 @@
2828
class TestMKLDNNCpuBfloat16Pass(InferencePassTest):
2929
def setUp(self):
3030
self.init_data()
31-
with paddle.pir_utils.OldIrGuard():
32-
with base.program_guard(self.main_program, self.startup_program):
33-
x = paddle.static.data(
34-
name='x', shape=[-1, *self.shape_x], dtype=self.d_type
35-
)
31+
with (
32+
paddle.pir_utils.OldIrGuard(),
33+
base.program_guard(self.main_program, self.startup_program),
34+
):
35+
x = paddle.static.data(
36+
name='x', shape=[-1, *self.shape_x], dtype=self.d_type
37+
)
3638

37-
out = paddle.transpose(x, perm=[0, 1, 2, 3])
38-
out = paddle.reshape(out, [0, 0, 0, 0])
39+
out = paddle.transpose(x, perm=[0, 1, 2, 3])
40+
out = paddle.reshape(out, [0, 0, 0, 0])
3941

40-
out = paddle.static.nn.fc(out, size=1)
42+
out = paddle.static.nn.fc(out, size=1)
4143

42-
self.feeds = {
43-
"x": np.random.random([self.bs, *self.shape_x]).astype(
44-
self.d_type
45-
)
46-
}
47-
self.fetch_list = [out]
44+
self.feeds = {
45+
"x": np.random.random([self.bs, *self.shape_x]).astype(
46+
self.d_type
47+
)
48+
}
49+
self.fetch_list = [out]
4850

4951
def init_data(self):
5052
self.bs = 8

test/deprecated/mkldnn/test_mkldnn_elt_act_fuse_pass_deprecated.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,24 +33,24 @@ class ElementwiseActivationOneDNNFusePassTest(InferencePassTest):
3333

3434
def setUp(self):
3535
self.set_params()
36-
with paddle.pir_utils.OldIrGuard():
37-
with base.program_guard(self.main_program, self.startup_program):
38-
data_A = paddle.static.data(
39-
name="data_A", shape=[-1, 3, 100, 100], dtype="float32"
40-
)
41-
data_B = paddle.static.data(
42-
name="data_B", shape=[-1, 3, 100, 100], dtype="float32"
43-
)
44-
elt_out = self.operand(data_A, data_B)
45-
if self.act is not None:
46-
if self.act_beta is not None:
47-
elt_out = self.act(
48-
elt_out, self.act_alpha, self.act_beta
49-
)
50-
elif self.act_alpha is not None:
51-
elt_out = self.act(elt_out, self.act_alpha)
52-
else:
53-
elt_out = self.act(elt_out)
36+
with (
37+
paddle.pir_utils.OldIrGuard(),
38+
base.program_guard(self.main_program, self.startup_program),
39+
):
40+
data_A = paddle.static.data(
41+
name="data_A", shape=[-1, 3, 100, 100], dtype="float32"
42+
)
43+
data_B = paddle.static.data(
44+
name="data_B", shape=[-1, 3, 100, 100], dtype="float32"
45+
)
46+
elt_out = self.operand(data_A, data_B)
47+
if self.act is not None:
48+
if self.act_beta is not None:
49+
elt_out = self.act(elt_out, self.act_alpha, self.act_beta)
50+
elif self.act_alpha is not None:
51+
elt_out = self.act(elt_out, self.act_alpha)
52+
else:
53+
elt_out = self.act(elt_out)
5454

5555
self.feeds = {
5656
"data_A": np.random.random((1, 3, 100, 100)).astype("float32"),

test/deprecated/mkldnn/test_mkldnn_matmul_op_output_fuse_pass_deprecated.py

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,21 @@ def init_data(self):
3333
self.enable_mkldnn = True
3434

3535
def make_network(self):
36-
with paddle.pir_utils.OldIrGuard():
37-
with base.program_guard(self.main_program, self.startup_program):
38-
x = paddle.static.data(
39-
name='x', shape=[-1, *self.shape_x], dtype=self.d_type
40-
)
41-
y = paddle.static.data(
42-
name='y', shape=[-1, *self.shape_y], dtype=self.d_type
43-
)
44-
out = paddle.matmul(x, y)
45-
out = paddle.transpose(out, perm=[0, 2, 1, 3])
46-
out = paddle.reshape(
47-
out, [0, 0, self.shape_y[0] * self.shape_y[2]]
48-
)
49-
50-
out = F.relu(out)
36+
with (
37+
paddle.pir_utils.OldIrGuard(),
38+
base.program_guard(self.main_program, self.startup_program),
39+
):
40+
x = paddle.static.data(
41+
name='x', shape=[-1, *self.shape_x], dtype=self.d_type
42+
)
43+
y = paddle.static.data(
44+
name='y', shape=[-1, *self.shape_y], dtype=self.d_type
45+
)
46+
out = paddle.matmul(x, y)
47+
out = paddle.transpose(out, perm=[0, 2, 1, 3])
48+
out = paddle.reshape(out, [0, 0, self.shape_y[0] * self.shape_y[2]])
49+
50+
out = F.relu(out)
5151
return out
5252

5353
def setUp(self):
@@ -78,18 +78,20 @@ def init_data(self):
7878

7979
class TestMKLDNNMatmulOpNotFusedWrongTransposeAxis(TestMKLDNNMatmulFuseOp):
8080
def make_network(self):
81-
with paddle.pir_utils.OldIrGuard():
82-
with base.program_guard(self.main_program, self.startup_program):
83-
x = paddle.static.data(
84-
name='x', shape=[-1, *self.shape_x], dtype=self.d_type
85-
)
86-
y = paddle.static.data(
87-
name='y', shape=[-1, *self.shape_y], dtype=self.d_type
88-
)
89-
out = paddle.matmul(x, y)
90-
out = paddle.transpose(out, perm=[0, 1, 2, 3])
91-
out = paddle.reshape(out, [0, 0, 0, 0])
92-
out = paddle.static.nn.fc(out, size=1)
81+
with (
82+
paddle.pir_utils.OldIrGuard(),
83+
base.program_guard(self.main_program, self.startup_program),
84+
):
85+
x = paddle.static.data(
86+
name='x', shape=[-1, *self.shape_x], dtype=self.d_type
87+
)
88+
y = paddle.static.data(
89+
name='y', shape=[-1, *self.shape_y], dtype=self.d_type
90+
)
91+
out = paddle.matmul(x, y)
92+
out = paddle.transpose(out, perm=[0, 1, 2, 3])
93+
out = paddle.reshape(out, [0, 0, 0, 0])
94+
out = paddle.static.nn.fc(out, size=1)
9395
return out
9496

9597

@@ -102,22 +104,22 @@ def init_data(self):
102104
self.enable_mkldnn = True
103105

104106
def make_network(self):
105-
with paddle.pir_utils.OldIrGuard():
106-
with base.program_guard(self.main_program, self.startup_program):
107-
x = paddle.static.data(
108-
name='x', shape=[-1, *self.shape_x], dtype=self.d_type
109-
)
110-
y = paddle.static.data(
111-
name='y', shape=[-1, *self.shape_y], dtype=self.d_type
112-
)
113-
out = paddle.matmul(x, y)
114-
out = paddle.transpose(out, perm=[0, 2, 1, 3])
115-
out = paddle.transpose(out, perm=[0, 1, 2, 3]) # breaks pattern
116-
out = paddle.reshape(
117-
out, [0, 0, self.shape_y[0] * self.shape_y[2]]
118-
)
119-
120-
out = F.relu(out)
107+
with (
108+
paddle.pir_utils.OldIrGuard(),
109+
base.program_guard(self.main_program, self.startup_program),
110+
):
111+
x = paddle.static.data(
112+
name='x', shape=[-1, *self.shape_x], dtype=self.d_type
113+
)
114+
y = paddle.static.data(
115+
name='y', shape=[-1, *self.shape_y], dtype=self.d_type
116+
)
117+
out = paddle.matmul(x, y)
118+
out = paddle.transpose(out, perm=[0, 2, 1, 3])
119+
out = paddle.transpose(out, perm=[0, 1, 2, 3]) # breaks pattern
120+
out = paddle.reshape(out, [0, 0, self.shape_y[0] * self.shape_y[2]])
121+
122+
out = F.relu(out)
121123
return out
122124

123125

test/deprecated/mkldnn/test_mkldnn_reshape_transpose_matmul_v2_fuse_pass_deprecated.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,26 @@ def setUp(self):
3030
self.set_params()
3131
self.transpose_perm = [0, 2, 1, 3]
3232
self.pass_name = 'reshape_transpose_matmul_onednn_fuse_pass'
33-
with paddle.pir_utils.OldIrGuard():
34-
with base.program_guard(self.main_program, self.startup_program):
35-
data = paddle.static.data(
36-
name="data", shape=self.data_shape, dtype="float32"
37-
)
38-
weight = paddle.create_parameter(
39-
shape=self.weight_shape, dtype="float32"
40-
)
41-
42-
reshape = paddle.reshape(data, shape=self.reshape_shape)
43-
transpose = paddle.transpose(reshape, self.transpose_perm)
44-
45-
matmul = paddle.matmul(
46-
transpose,
47-
weight,
48-
transpose_x=self.transpose_x,
49-
transpose_y=self.transpose_y,
50-
)
33+
with (
34+
paddle.pir_utils.OldIrGuard(),
35+
base.program_guard(self.main_program, self.startup_program),
36+
):
37+
data = paddle.static.data(
38+
name="data", shape=self.data_shape, dtype="float32"
39+
)
40+
weight = paddle.create_parameter(
41+
shape=self.weight_shape, dtype="float32"
42+
)
43+
44+
reshape = paddle.reshape(data, shape=self.reshape_shape)
45+
transpose = paddle.transpose(reshape, self.transpose_perm)
46+
47+
matmul = paddle.matmul(
48+
transpose,
49+
weight,
50+
transpose_x=self.transpose_x,
51+
transpose_y=self.transpose_y,
52+
)
5153

5254
self.fetch_list = [matmul]
5355
self.enable_mkldnn = True

0 commit comments

Comments
 (0)