@@ -289,5 +289,55 @@ def test_reshape(rng, dtype):
289289 np .testing .assert_array_equal (actual , expected )
290290
291291 # DENSE
292- # NOTE: dense reshape is probably broken in MLIR
292+ # NOTE: dense reshape is probably broken in MLIR in 19.x branch
293293 # dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)
294+
295+
296+ @parametrize_dtypes
297+ def test_broadcast_to (dtype ):
298+ # CSR, CSC, COO
299+ for shape , new_shape , dimensions , input_arr , expected_arrs in [
300+ (
301+ (3 , 4 ),
302+ (2 , 3 , 4 ),
303+ [0 ],
304+ np .array ([[0 , 1 , 0 , 3 ], [0 , 0 , 4 , 5 ], [6 , 7 , 0 , 0 ]]),
305+ [
306+ np .array ([0 , 3 , 6 ]),
307+ np .array ([0 , 1 , 2 , 0 , 1 , 2 ]),
308+ np .array ([0 , 2 , 4 , 6 , 8 , 10 , 12 ]),
309+ np .array ([1 , 3 , 2 , 3 , 0 , 1 , 1 , 3 , 2 , 3 , 0 , 1 ]),
310+ np .array ([1.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 , 1.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 ]),
311+ ],
312+ ),
313+ (
314+ (4 , 2 ),
315+ (4 , 2 , 2 ),
316+ [1 ],
317+ np .array ([[0 , 1 ], [0 , 0 ], [2 , 3 ], [4 , 0 ]]),
318+ [
319+ np .array ([0 , 2 , 2 , 4 , 6 ]),
320+ np .array ([0 , 1 , 0 , 1 , 0 , 1 ]),
321+ np .array ([0 , 1 , 2 , 4 , 6 , 7 , 8 ]),
322+ np .array ([1 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ]),
323+ np .array ([1.0 , 1.0 , 2.0 , 3.0 , 2.0 , 3.0 , 4.0 , 4.0 ]),
324+ ],
325+ ),
326+ ]:
327+ for fn_format in [sps .csr_array , sps .csc_array , sps .coo_array ]:
328+ arr = fn_format (input_arr , shape = shape , dtype = dtype )
329+ arr .sum_duplicates ()
330+ tensor = sparse .asarray (arr )
331+ result = sparse .broadcast_to (tensor , new_shape , dimensions = dimensions ).to_scipy_sparse ()
332+
333+ for actual , expected in zip (result , expected_arrs , strict = False ):
334+ np .testing .assert_allclose (actual , expected )
335+
336+ # DENSE
337+ np_arr = np .array ([0 , 0 , 2 , 3 , 0 , 1 ])
338+ arr = np .asarray (np_arr , dtype = dtype )
339+ tensor = sparse .asarray (arr )
340+ result = sparse .broadcast_to (tensor , (3 , 6 ), dimensions = [0 ]).to_scipy_sparse ()
341+
342+ assert result .format == "csr"
343+ np .testing .assert_allclose (result .todense (), np .repeat (np_arr [np .newaxis ], 3 , axis = 0 ))
0 commit comments