Skip to content

Commit 52c87e0

Browse files
committed
[mlir][sparse] use consistent type for COO object and sparse tensor storage
There was a slightly mismatch between the double COO and actual numerical type in the final sparse tensor storage (due to external formats always using double). This minor revision removes that inconsistency by using a properly typed COO and casting during the "add" method instead. This also prepares alternative ways of initializing the COO object. Reviewed By: gussmith23 Differential Revision: https://reviews.llvm.org/D107310
1 parent 65e9d7e commit 52c87e0

File tree

1 file changed

+25
-24
lines changed

1 file changed

+25
-24
lines changed

mlir/lib/ExecutionEngine/SparseUtils.cpp

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,12 @@ class SparseTensorStorageBase {
151151
/// each differently annotated sparse tensor, this method provides a convenient
152152
/// "one-size-fits-all" solution that simply takes an input tensor and
153153
/// annotations to implement all required setup in a general manner.
154-
template <typename P, typename I, typename V, typename Ve>
154+
template <typename P, typename I, typename V>
155155
class SparseTensorStorage : public SparseTensorStorageBase {
156156
public:
157157
/// Constructs sparse tensor storage scheme following the given
158158
/// per-rank dimension dense/sparse annotations.
159-
SparseTensorStorage(SparseTensor<Ve> *tensor, uint8_t *sparsity)
159+
SparseTensorStorage(SparseTensor<V> *tensor, uint8_t *sparsity)
160160
: sizes(tensor->getSizes()), pointers(getRank()), indices(getRank()) {
161161
// Provide hints on capacity.
162162
// TODO: needs fine-tuning based on sparsity
@@ -191,14 +191,23 @@ class SparseTensorStorage : public SparseTensorStorageBase {
191191
}
192192
void getValues(std::vector<V> **out) override { *out = &values; }
193193

194+
// Factory method.
195+
static SparseTensorStorage<P, I, V> *newSparseTensor(SparseTensor<V> *t,
196+
uint8_t *s) {
197+
t->sort(); // sort lexicographically
198+
SparseTensorStorage<P, I, V> *n = new SparseTensorStorage<P, I, V>(t, s);
199+
delete t;
200+
return n;
201+
}
202+
194203
private:
195204
/// Initializes sparse tensor storage scheme from a memory-resident
196205
/// representation of an external sparse tensor. This method prepares
197206
/// the pointers and indices arrays under the given per-rank dimension
198207
/// dense/sparse annotations.
199-
void traverse(SparseTensor<Ve> *tensor, uint8_t *sparsity, uint64_t lo,
208+
void traverse(SparseTensor<V> *tensor, uint8_t *sparsity, uint64_t lo,
200209
uint64_t hi, uint64_t d) {
201-
const std::vector<Element<Ve>> &elements = tensor->getElements();
210+
const std::vector<Element<V>> &elements = tensor->getElements();
202211
// Once dimensions are exhausted, insert the numerical values.
203212
if (d == getRank()) {
204213
values.push_back(lo < hi ? elements[lo].value : 0);
@@ -321,9 +330,9 @@ static void readExtFROSTTHeader(FILE *file, char *name, uint64_t *idata) {
321330
}
322331

323332
/// Reads a sparse tensor with the given filename into a memory-resident
324-
/// sparse tensor in coordinate scheme. The external formats always store
325-
/// the numerical values with the type double.
326-
static SparseTensor<double> *openTensor(char *filename, uint64_t *perm) {
333+
/// sparse tensor in coordinate scheme.
334+
template <typename V>
335+
static SparseTensor<V> *openTensor(char *filename, uint64_t *perm) {
327336
// Open the file.
328337
FILE *file = fopen(filename, "r");
329338
if (!file) {
@@ -347,7 +356,7 @@ static SparseTensor<double> *openTensor(char *filename, uint64_t *perm) {
347356
std::vector<uint64_t> indices(rank);
348357
for (uint64_t r = 0; r < rank; r++)
349358
indices[perm[r]] = idata[2 + r];
350-
SparseTensor<double> *tensor = new SparseTensor<double>(indices, nnz);
359+
SparseTensor<V> *tensor = new SparseTensor<V>(indices, nnz);
351360
// Read all nonzero elements.
352361
for (uint64_t k = 0; k < nnz; k++) {
353362
uint64_t idx = -1;
@@ -359,28 +368,17 @@ static SparseTensor<double> *openTensor(char *filename, uint64_t *perm) {
359368
// Add 0-based index.
360369
indices[perm[r]] = idx - 1;
361370
}
371+
// The external formats always store the numerical values with the type
372+
// double, but we cast these values to the sparse tensor object type.
362373
double value;
363374
if (fscanf(file, "%lg\n", &value) != 1) {
364375
fprintf(stderr, "Cannot find next value in %s\n", filename);
365376
exit(1);
366377
}
367378
tensor->add(indices, value);
368379
}
369-
// Close the file and return sorted tensor.
380+
// Close the file and return tensor.
370381
fclose(file);
371-
tensor->sort(); // sort lexicographically
372-
return tensor;
373-
}
374-
375-
/// Templated reader.
376-
template <typename P, typename I, typename V>
377-
void *newSparseTensor(char *filename, uint8_t *sparsity, uint64_t *perm,
378-
uint64_t size) {
379-
SparseTensor<double> *t = openTensor(filename, perm);
380-
assert(size == t->getRank()); // sparsity array must match rank
381-
SparseTensorStorageBase *tensor =
382-
new SparseTensorStorage<P, I, V, double>(t, sparsity);
383-
delete t;
384382
return tensor;
385383
}
386384

@@ -419,8 +417,11 @@ char *getTensorFilename(uint64_t id) {
419417
}
420418

421419
#define CASE(p, i, v, P, I, V) \
422-
if (ptrTp == (p) && indTp == (i) && valTp == (v)) \
423-
return newSparseTensor<P, I, V>(filename, sparsity, perm, asize)
420+
if (ptrTp == (p) && indTp == (i) && valTp == (v)) { \
421+
SparseTensor<V> *tensor = openTensor<V>(filename, perm); \
422+
assert(asize == tensor->getRank()); \
423+
return SparseTensorStorage<P, I, V>::newSparseTensor(tensor, sparsity); \
424+
}
424425

425426
#define IMPL1(RET, NAME, TYPE, LIB) \
426427
RET NAME(void *tensor) { \

0 commit comments

Comments
 (0)