@@ -75,6 +75,15 @@ def sampler_real_floating(size: tuple[int, ...]):
7575 raise NotImplementedError (f"{ dtype = } not yet supported." )
7676
7777
78+ def get_exampe_csf_arrays (dtype : np .dtype ) -> tuple :
79+ pos_1 = np .array ([0 , 1 , 3 ], dtype = np .int64 )
80+ crd_1 = np .array ([1 , 0 , 1 ], dtype = np .int64 )
81+ pos_2 = np .array ([0 , 3 , 5 , 7 ], dtype = np .int64 )
82+ crd_2 = np .array ([0 , 1 , 3 , 0 , 3 , 0 , 1 ], dtype = np .int64 )
83+ data = np .array ([1 , 2 , 3 , 4 , 5 , 6 , 7 ], dtype = dtype )
84+ return pos_1 , crd_1 , pos_2 , crd_2 , data
85+
86+
7887@parametrize_dtypes
7988@pytest .mark .parametrize ("shape" , [(100 ,), (10 , 200 ), (5 , 10 , 20 )])
8089def test_dense_format (dtype , shape ):
@@ -176,11 +185,7 @@ def test_add(rng, dtype):
176185@parametrize_dtypes
177186def test_csf_format (dtype ):
178187 SHAPE = (2 , 2 , 4 )
179- pos_1 = np .array ([0 , 1 , 3 ], dtype = np .int64 )
180- crd_1 = np .array ([1 , 0 , 1 ], dtype = np .int64 )
181- pos_2 = np .array ([0 , 3 , 5 , 7 ], dtype = np .int64 )
182- crd_2 = np .array ([0 , 1 , 3 , 0 , 3 , 0 , 1 ], dtype = np .int64 )
183- data = np .array ([1 , 2 , 3 , 4 , 5 , 6 , 7 ], dtype = dtype )
188+ pos_1 , crd_1 , pos_2 , crd_2 , data = get_exampe_csf_arrays (dtype )
184189 csf = [pos_1 , crd_1 , pos_2 , crd_2 , data ]
185190
186191 csf_tensor = sparse .asarray (csf , shape = SHAPE , dtype = sparse .asdtype (dtype ), format = "csf" )
@@ -192,3 +197,70 @@ def test_csf_format(dtype):
192197 csf_2 = [pos_1 , crd_1 , pos_2 , crd_2 , data * 2 ]
193198 for actual , expected in zip (res_tensor , csf_2 , strict = False ):
194199 np .testing .assert_array_equal (actual , expected )
200+
201+
202+ @parametrize_dtypes
203+ def test_reshape (rng , dtype ):
204+ DENSITY = 0.5
205+ sampler = generate_sampler (dtype , rng )
206+
207+ # CSR, CSC, COO
208+ for shape , new_shape in [((100 , 50 ), (25 , 200 )), ((80 , 1 ), (8 , 10 ))]:
209+ for format in ["csr" , "csc" , "coo" ]:
210+ if format == "coo" :
211+ # NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135
212+ continue
213+ if format == "csc" :
214+ # NOTE: Blocked by https://github.com/llvm/llvm-project/issues/109641
215+ continue
216+
217+ arr = sps .random_array (
218+ shape , density = DENSITY , format = format , dtype = dtype , random_state = rng , data_sampler = sampler
219+ )
220+ if format == "coo" :
221+ arr .sum_duplicates ()
222+
223+ tensor = sparse .asarray (arr )
224+
225+ actual = sparse .reshape (tensor , shape = new_shape ).to_scipy_sparse ()
226+ expected = arr .todense ().reshape (new_shape )
227+
228+ np .testing .assert_array_equal (actual .todense (), expected )
229+
230+ # CSF
231+ csf_shape = (2 , 2 , 4 )
232+ for shape , new_shape , expected_arrs in [
233+ (
234+ csf_shape ,
235+ (4 , 4 , 1 ),
236+ [
237+ np .array ([0 , 0 , 3 , 5 , 7 ]),
238+ np .array ([0 , 1 , 3 , 0 , 3 , 0 , 1 ]),
239+ np .array ([0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ]),
240+ np .array ([0 , 0 , 0 , 0 , 0 , 0 , 0 ]),
241+ np .array ([1 , 2 , 3 , 4 , 5 , 6 , 7 ]),
242+ ],
243+ ),
244+ (
245+ csf_shape ,
246+ (2 , 1 , 8 ),
247+ [
248+ np .array ([0 , 1 , 2 ]),
249+ np .array ([0 , 0 ]),
250+ np .array ([0 , 3 , 7 ]),
251+ np .array ([4 , 5 , 7 , 0 , 3 , 4 , 5 ]),
252+ np .array ([1 , 2 , 3 , 4 , 5 , 6 , 7 ]),
253+ ],
254+ ),
255+ ]:
256+ csf = get_exampe_csf_arrays (dtype )
257+ csf_tensor = sparse .asarray (csf , shape = shape , dtype = sparse .asdtype (dtype ), format = "csf" )
258+
259+ result = sparse .reshape (csf_tensor , shape = new_shape ).to_scipy_sparse ()
260+
261+ for actual , expected in zip (result , expected_arrs , strict = False ):
262+ np .testing .assert_array_equal (actual , expected )
263+
264+ # DENSE
265+ # NOTE: dense reshape is probably broken in MLIR
266+ # dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)
0 commit comments