@@ -56,17 +56,24 @@ def test_map(self):
5656 def f (x ):
5757 return x ** 2
5858
59- xs = KerasTensor ((None ,))
60- self .assertEqual (core .map (f , xs ).shape , (None ,))
59+ xs = KerasTensor ((None , 5 ))
60+ self .assertEqual (core .map (f , xs ).shape , (None , 5 ))
6161
6262 # Test nested output
6363 def f2 (x ):
6464 return {"a" : x ** 2 , "b" : x * 10 }
6565
66- xs = KerasTensor ((None ,))
66+ xs = KerasTensor ((None , 5 ))
6767 ys = core .map (f2 , xs )
68- self .assertEqual (ys ["a" ].shape , (None ,))
69- self .assertEqual (ys ["b" ].shape , (None ,))
68+ self .assertEqual (ys ["a" ].shape , (None , 5 ))
69+ self .assertEqual (ys ["b" ].shape , (None , 5 ))
70+
71+ # Test nested input
72+ def f3 (x ):
73+ return x [0 ] + x [1 ]
74+
75+ xs = (KerasTensor ((None , 5 )), KerasTensor ((None , 5 )))
76+ self .assertEqual (core .map (f3 , xs ).shape , (None , 5 ))
7077
7178 def test_saturate_cast (self ):
7279 x = KerasTensor ((3 , 5 , None ), dtype = "float32" )
@@ -125,6 +132,29 @@ def fn(x, y):
125132 self .assertEqual (result [0 ].shape , (None ,))
126133 self .assertEqual (result [1 ].shape , (None ,))
127134
135+ def test_vectorized_map (self ):
136+ def f (x ):
137+ return x ** 2
138+
139+ xs = KerasTensor ((None , 5 ))
140+ self .assertEqual (core .vectorized_map (f , xs ).shape , (None , 5 ))
141+
142+ # Test nested output
143+ def f2 (x ):
144+ return {"a" : x ** 2 , "b" : x * 10 }
145+
146+ xs = KerasTensor ((None , 5 ))
147+ ys = core .vectorized_map (f2 , xs )
148+ self .assertEqual (ys ["a" ].shape , (None , 5 ))
149+ self .assertEqual (ys ["b" ].shape , (None , 5 ))
150+
151+ # Test nested input
152+ def f3 (x ):
153+ return x [0 ] + x [1 ]
154+
155+ xs = (KerasTensor ((None , 5 )), KerasTensor ((None , 5 )))
156+ self .assertEqual (core .vectorized_map (f3 , xs ).shape , (None , 5 ))
157+
128158 def test_while_loop (self ):
129159 def cond (args ):
130160 return tree .flatten (args )[0 ] < 10
@@ -203,18 +233,25 @@ def test_map(self):
203233 def f (x ):
204234 return x ** 2
205235
206- xs = KerasTensor ((6 ,))
236+ xs = KerasTensor ((6 , 5 ))
207237 ys = core .map (f , xs )
208- self .assertEqual (ys .shape , (6 ,))
238+ self .assertEqual (ys .shape , (6 , 5 ))
209239
210240 # Test nested output
211241 def f2 (x ):
212242 return {"a" : x ** 2 , "b" : x * 10 }
213243
214- xs = KerasTensor ((6 ,))
244+ xs = KerasTensor ((6 , 5 ))
215245 ys = core .map (f2 , xs )
216- self .assertEqual (ys ["a" ].shape , (6 ,))
217- self .assertEqual (ys ["b" ].shape , (6 ,))
246+ self .assertEqual (ys ["a" ].shape , (6 , 5 ))
247+ self .assertEqual (ys ["b" ].shape , (6 , 5 ))
248+
249+ # Test nested input
250+ def f3 (x ):
251+ return x [0 ] + x [1 ]
252+
253+ xs = (KerasTensor ((6 , 5 )), KerasTensor ((6 , 5 )))
254+ self .assertEqual (core .map (f3 , xs ).shape , (6 , 5 ))
218255
219256 def test_saturate_cast (self ):
220257 x = KerasTensor ((3 , 5 , 7 ), dtype = "float32" )
@@ -307,6 +344,30 @@ def fn(x, y):
307344 self .assertEqual (core .switch (index , [fn ], x , y )[0 ].shape , (5 ,))
308345 self .assertEqual (core .switch (index , [fn ], x , y )[1 ].shape , (2 ,))
309346
347+ def test_vectorized_map (self ):
348+ def f (x ):
349+ return x ** 2
350+
351+ xs = KerasTensor ((6 , 5 ))
352+ ys = core .vectorized_map (f , xs )
353+ self .assertEqual (ys .shape , (6 , 5 ))
354+
355+ # Test nested output
356+ def f2 (x ):
357+ return {"a" : x ** 2 , "b" : x * 10 }
358+
359+ xs = KerasTensor ((6 , 5 ))
360+ ys = core .vectorized_map (f2 , xs )
361+ self .assertEqual (ys ["a" ].shape , (6 , 5 ))
362+ self .assertEqual (ys ["b" ].shape , (6 , 5 ))
363+
364+ # Test nested input
365+ def f3 (x ):
366+ return x [0 ] + x [1 ]
367+
368+ xs = (KerasTensor ((6 , 5 )), KerasTensor ((6 , 5 )))
369+ self .assertEqual (core .vectorized_map (f3 , xs ).shape , (6 , 5 ))
370+
310371 def test_while_loop (self ):
311372 def cond (args ):
312373 return tree .flatten (args )[0 ] < 10
0 commit comments