Skip to content

Commit 0aa416d

Browse files
authored
[PIR]Open uts for sequence_mask (#60704)
1 parent 098cb1f commit 0aa416d

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

test/sequence/test_sequence_mask.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@
2020

2121
import paddle
2222
from paddle.base.framework import (
23-
Program,
2423
convert_np_dtype_to_dtype_,
25-
program_guard,
2624
)
25+
from paddle.pir_utils import test_with_pir_api
2726

2827

2928
def sequence_mask_wraper(x, maxlen_tensor=None, maxlen=-1, mask_dtype='int64'):
@@ -168,15 +167,16 @@ def initParameters(self):
168167

169168

170169
class TestSequenceMaskOpError(unittest.TestCase):
170+
@test_with_pir_api
171171
def test_errors(self):
172-
with program_guard(Program(), Program()):
172+
with paddle.static.program_guard(
173+
paddle.static.Program(), paddle.static.Program()
174+
):
173175
input_data = np.random.uniform(1, 5, [4]).astype("float32")
174176

175177
def test_Variable():
176178
# the input must be Variable
177-
paddle.static.nn.sequence_lod.sequence_mask(
178-
input_data, maxlen=4
179-
)
179+
paddle.nn.functional.sequence_mask(input_data, maxlen=4)
180180

181181
self.assertRaises(TypeError, test_Variable)
182182

0 commit comments

Comments
 (0)