@@ -151,12 +151,12 @@ class SparseTensorStorageBase {
151151// / each differently annotated sparse tensor, this method provides a convenient
152152// / "one-size-fits-all" solution that simply takes an input tensor and
153153// / annotations to implement all required setup in a general manner.
154- template <typename P, typename I, typename V, typename Ve >
154+ template <typename P, typename I, typename V>
155155class SparseTensorStorage : public SparseTensorStorageBase {
156156public:
157157 // / Constructs sparse tensor storage scheme following the given
158158 // / per-rank dimension dense/sparse annotations.
159- SparseTensorStorage (SparseTensor<Ve > *tensor, uint8_t *sparsity)
159+ SparseTensorStorage (SparseTensor<V > *tensor, uint8_t *sparsity)
160160 : sizes(tensor->getSizes ()), pointers(getRank()), indices(getRank()) {
161161 // Provide hints on capacity.
162162 // TODO: needs fine-tuning based on sparsity
@@ -191,14 +191,23 @@ class SparseTensorStorage : public SparseTensorStorageBase {
191191 }
192192 void getValues (std::vector<V> **out) override { *out = &values; }
193193
194+ // Factory method.
195+ static SparseTensorStorage<P, I, V> *newSparseTensor (SparseTensor<V> *t,
196+ uint8_t *s) {
197+ t->sort (); // sort lexicographically
198+ SparseTensorStorage<P, I, V> *n = new SparseTensorStorage<P, I, V>(t, s);
199+ delete t;
200+ return n;
201+ }
202+
194203private:
195204 // / Initializes sparse tensor storage scheme from a memory-resident
196205 // / representation of an external sparse tensor. This method prepares
197206 // / the pointers and indices arrays under the given per-rank dimension
198207 // / dense/sparse annotations.
199- void traverse (SparseTensor<Ve > *tensor, uint8_t *sparsity, uint64_t lo,
208+ void traverse (SparseTensor<V > *tensor, uint8_t *sparsity, uint64_t lo,
200209 uint64_t hi, uint64_t d) {
201- const std::vector<Element<Ve >> &elements = tensor->getElements ();
210+ const std::vector<Element<V >> &elements = tensor->getElements ();
202211 // Once dimensions are exhausted, insert the numerical values.
203212 if (d == getRank ()) {
204213 values.push_back (lo < hi ? elements[lo].value : 0 );
@@ -321,9 +330,9 @@ static void readExtFROSTTHeader(FILE *file, char *name, uint64_t *idata) {
321330}
322331
323332// / Reads a sparse tensor with the given filename into a memory-resident
324- // / sparse tensor in coordinate scheme. The external formats always store
325- // / the numerical values with the type double.
326- static SparseTensor<double > *openTensor (char *filename, uint64_t *perm) {
333+ // / sparse tensor in coordinate scheme.
334+ template < typename V>
335+ static SparseTensor<V > *openTensor (char *filename, uint64_t *perm) {
327336 // Open the file.
328337 FILE *file = fopen (filename, " r" );
329338 if (!file) {
@@ -347,7 +356,7 @@ static SparseTensor<double> *openTensor(char *filename, uint64_t *perm) {
347356 std::vector<uint64_t > indices (rank);
348357 for (uint64_t r = 0 ; r < rank; r++)
349358 indices[perm[r]] = idata[2 + r];
350- SparseTensor<double > *tensor = new SparseTensor<double >(indices, nnz);
359+ SparseTensor<V > *tensor = new SparseTensor<V >(indices, nnz);
351360 // Read all nonzero elements.
352361 for (uint64_t k = 0 ; k < nnz; k++) {
353362 uint64_t idx = -1 ;
@@ -359,28 +368,17 @@ static SparseTensor<double> *openTensor(char *filename, uint64_t *perm) {
359368 // Add 0-based index.
360369 indices[perm[r]] = idx - 1 ;
361370 }
371+ // The external formats always store the numerical values with the type
372+ // double, but we cast these values to the sparse tensor object type.
362373 double value;
363374 if (fscanf (file, " %lg\n " , &value) != 1 ) {
364375 fprintf (stderr, " Cannot find next value in %s\n " , filename);
365376 exit (1 );
366377 }
367378 tensor->add (indices, value);
368379 }
369- // Close the file and return sorted tensor.
380+ // Close the file and return tensor.
370381 fclose (file);
371- tensor->sort (); // sort lexicographically
372- return tensor;
373- }
374-
375- // / Templated reader.
376- template <typename P, typename I, typename V>
377- void *newSparseTensor (char *filename, uint8_t *sparsity, uint64_t *perm,
378- uint64_t size) {
379- SparseTensor<double > *t = openTensor (filename, perm);
380- assert (size == t->getRank ()); // sparsity array must match rank
381- SparseTensorStorageBase *tensor =
382- new SparseTensorStorage<P, I, V, double >(t, sparsity);
383- delete t;
384382 return tensor;
385383}
386384
@@ -419,8 +417,11 @@ char *getTensorFilename(uint64_t id) {
419417 }
420418
421419#define CASE (p, i, v, P, I, V ) \
422- if (ptrTp == (p) && indTp == (i) && valTp == (v)) \
423- return newSparseTensor<P, I, V>(filename, sparsity, perm, asize)
420+ if (ptrTp == (p) && indTp == (i) && valTp == (v)) { \
421+ SparseTensor<V> *tensor = openTensor<V>(filename, perm); \
422+ assert (asize == tensor->getRank ()); \
423+ return SparseTensorStorage<P, I, V>::newSparseTensor (tensor, sparsity); \
424+ }
424425
425426#define IMPL1 (RET, NAME, TYPE, LIB ) \
426427 RET NAME (void *tensor) { \
0 commit comments