Skip to content

Commit 4f09461

Browse files
adamlerer
authored and soumith committed
Rename sparse tensor contiguous() to coalesce()
1 parent bafb2e5 commit 4f09461

File tree

13 files changed

+74
-96
lines changed

13 files changed

+74
-96
lines changed

test/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ def assertTensorsEqual(a, b):
150150
self.assertLessEqual(max_err, prec, message)
151151
self.assertEqual(x.is_sparse, y.is_sparse, message)
152152
if x.is_sparse:
153-
x = x.clone().contiguous()
154-
y = y.clone().contiguous()
153+
x = x.clone().coalesce_()
154+
y = y.clone().coalesce_()
155155
assertTensorsEqual(x.indices(), y.indices())
156156
assertTensorsEqual(x.values(), y.values())
157157
else:

test/test_sparse.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def _test_contig(self, is_cuda):
156156
[31, 92, 65, 50, 34, 62, 22, 56, 74, 89],
157157
])
158158
exp_v = ValueTensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7])
159-
x.contiguous()
159+
x.coalesce_()
160160
self.assertEqual(exp_i, x.indices())
161161
self.assertEqual(exp_v, x.values())
162162

@@ -174,7 +174,7 @@ def _test_contig(self, is_cuda):
174174
])
175175
exp_v = ValueTensor([2, 1, 3, 4])
176176

177-
x.contiguous()
177+
x.coalesce_()
178178
self.assertEqual(exp_i, x.indices())
179179
self.assertEqual(exp_v, x.values())
180180

@@ -193,7 +193,7 @@ def _test_contig(self, is_cuda):
193193
])
194194
exp_v = ValueTensor([6, 4])
195195

196-
x.contiguous()
196+
x.coalesce_()
197197
self.assertEqual(exp_i, x.indices())
198198
self.assertEqual(exp_v, x.values())
199199

@@ -224,7 +224,7 @@ def _test_contig_hybrid(self, is_cuda):
224224
[2, 3], [1, 2], [6, 7], [4, 5], [10, 11],
225225
[3, 4], [5, 6], [9, 10], [8, 9], [7, 8],
226226
])
227-
x.contiguous()
227+
x.coalesce_()
228228
self.assertEqual(exp_i, x.indices())
229229
self.assertEqual(exp_v, x.values())
230230

@@ -242,7 +242,7 @@ def _test_contig_hybrid(self, is_cuda):
242242
])
243243
exp_v = ValueTensor([[2, 2, 2], [1, 1, 1], [3, 3, 3], [4, 4, 4]])
244244

245-
x.contiguous()
245+
x.coalesce_()
246246
self.assertEqual(exp_i, x.indices())
247247
self.assertEqual(exp_v, x.values())
248248

@@ -261,7 +261,7 @@ def _test_contig_hybrid(self, is_cuda):
261261
])
262262
exp_v = ValueTensor([[6, 4, 5], [4, 3, 4]])
263263

264-
x.contiguous()
264+
x.coalesce_()
265265
self.assertEqual(exp_i, x.indices())
266266
self.assertEqual(exp_v, x.values())
267267

torch/csrc/generic/methods/SparseTensor.cwrap

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ PyObject * THSPTensor_(size)(PyObject *self, PyObject *args, PyObject *kwargs)
6262
]]
6363

6464
[[
65-
name: isContiguous
65+
name: isCoalesced
6666
sparse: yes
67-
python_name: is_contiguous
67+
python_name: is_coalesced
6868
return: bool
6969
arguments:
7070
- THSTensor* self
@@ -89,7 +89,8 @@ PyObject * THSPTensor_(size)(PyObject *self, PyObject *args, PyObject *kwargs)
8989
]]
9090

9191
[[
92-
name: contiguous
92+
name: coalesce
93+
python_name: coalesce_
9394
sparse: yes
9495
return: argument 0
9596
arguments:

torch/lib/THCS/generic/THCSTensor.c

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ static void THCSTensor_(rawInit)(THCState *state, THCSTensor *self)
7171
self->values = THCTensor_(new)(state);
7272
self->nDimensionI = 0;
7373
self->nDimensionV = 0;
74-
self->contiguous = 0;
74+
self->coalesced = 0;
7575
self->nnz = 0;
7676
// self->flag = TH_TENSOR_REFCOUNTED;
7777
self->refcount = 1;
@@ -86,7 +86,7 @@ void THCSTensor_(rawResize)(THCState *state, THCSTensor *self, int nDimI, int nD
8686
}
8787
self->nDimensionI = nDimI;
8888
self->nDimensionV = nDimV;
89-
self->contiguous = 0;
89+
self->coalesced = 0;
9090
}
9191

9292
// directly assign without cloning or retaining (internal method)
@@ -110,7 +110,7 @@ THCSTensor* THCSTensor_(_move)(THCState *state, THCSTensor *self, THCIndexTensor
110110
self->indices = indices;
111111
self->values = values;
112112
self->nnz = empty ? 0 : THCTensor_(size)(state, values, 0);
113-
self->contiguous = 0;
113+
self->coalesced = 0;
114114

115115
return self;
116116
}
@@ -236,13 +236,7 @@ THCSTensor *THCSTensor_(newClone)(THCState *state, THCSTensor *self) {
236236
THCSTensor_(_set)(state, other, self->indices, self->values);
237237

238238
other->nnz = self->nnz;
239-
other->contiguous = self->contiguous;
240-
return other;
241-
}
242-
243-
THCSTensor *THCSTensor_(newContiguous)(THCState *state, THCSTensor *self) {
244-
THCSTensor *other = THCSTensor_(newClone)(state, self);
245-
THCSTensor_(contiguous)(state, other);
239+
other->coalesced = self->coalesced;
246240
return other;
247241
}
248242

@@ -340,11 +334,11 @@ void THCSTensor_(copy)(THCState *state, THCSTensor *self, THCSTensor *src) {
340334
THCSTensor_(rawResize)(state, self, src->nDimensionI, src->nDimensionV, src->size);
341335
THCSTensor_(_set)(state, self, src->indices, src->values);
342336
self->nnz = src->nnz;
343-
self->contiguous = src->contiguous;
337+
self->coalesced = src->coalesced;
344338
}
345339

346-
int THCSTensor_(isContiguous)(THCState *state, const THCSTensor *self) {
347-
return self->contiguous;
340+
int THCSTensor_(isCoalesced)(THCState *state, const THCSTensor *self) {
341+
return self->coalesced;
348342
}
349343

350344
void THCSTensor_(free)(THCState *state, THCSTensor *self)
@@ -365,12 +359,6 @@ void THCSTensor_(retain)(THCState *state, THCSTensor *self)
365359
THAtomicIncrementRef(&self->refcount);
366360
}
367361

368-
void THCSTensor_(contiguous)(THCState *state, THCSTensor *self) {
369-
if (self->contiguous) return;
370-
THCSTensor_(reorder)(state, self);
371-
self->contiguous = 1;
372-
}
373-
374362
int THCSTensor_(checkGPU)(THCState *state, unsigned int nSparseTensors, unsigned int nTensors, ...)
375363
{
376364
/* FIXME: remove this flag after any users stop using it since it is
@@ -446,7 +434,7 @@ void THCTensor_(sparseMask)(THCState *state, THCSTensor *r_, THCTensor *t, THCST
446434
THCTensor *rValues = THCTensor_(new)(state);
447435
THCTensor_(resizeAs)(state, rValues, maskValues);
448436
THCSTensor_(_move)(state, r_, THCIndexTensor_(newClone)(state, maskIndices), rValues);
449-
r_->contiguous = mask->contiguous;
437+
r_->coalesced = mask->coalesced;
450438
r_->nnz = mask->nnz;
451439

452440
THCudaLongTensor *indices = THCudaLongTensor_newWithSize1d(state, mask->nnz);

torch/lib/THCS/generic/THCSTensor.cu

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ THCTensor *THCSTensor_(toDense)(THCState *state, THCSTensor *self) {
2323
THLongStorage *size;
2424
THCTensor *dst;
2525

26-
THCSTensor_(contiguous)(state, self);
26+
THCSTensor_(coalesce)(state, self);
2727

2828
// set up the new tensor
2929
size = THCSTensor_(newSizeOf)(state, self);
@@ -61,7 +61,8 @@ THCTensor *THCSTensor_(toDense)(THCState *state, THCSTensor *self) {
6161
return dst;
6262
}
6363

64-
void THCSTensor_(reorder)(THCState *state, THCSTensor *self) {
64+
void THCSTensor_(coalesce)(THCState *state, THCSTensor *self) {
65+
if (self->coalesced) return;
6566
if (self->nnz < 2) return;
6667
#if CUDA_VERSION >= 7000
6768
THCThrustAllocator thrustAlloc(state);
@@ -80,7 +81,7 @@ void THCSTensor_(reorder)(THCState *state, THCSTensor *self) {
8081
// Multiple values in D1 can map to the same position in D2 if there are duplicate indices
8182
// Values mapping to the same position are added together (which is what matrix multiplication does)
8283
//
83-
// When constructing S, we must make sure that it is contiguous (otherwise this function will call itself when doing the multiplication)
84+
// When constructing S, we must make sure that it is coalesced (otherwise this function will call itself when doing the multiplication)
8485
// To achieve this, we define the indices tensor of S as follows:
8586
// * the second row contains the permutation corresponding to a stable sort of the original indices
8687
// * the first row "maps" those indices to their final location after deduplication
@@ -142,7 +143,7 @@ void THCSTensor_(reorder)(THCState *state, THCSTensor *self) {
142143
// build S
143144
THCSTensor *S = THCSTensor_(newWithSize2d)(state, newNnz, self->nnz);
144145
THCSTensor_(_move)(state, S, sIndices, sValues);
145-
S->contiguous = 1;
146+
S->coalesced = 1;
146147

147148
// build output indices tensor by doing an indexSelect over the sorted list of unique indices
148149
THCIndexTensor *newIndices = THCIndexTensor_(new)(state);
@@ -201,6 +202,8 @@ void THCSTensor_(reorder)(THCState *state, THCSTensor *self) {
201202
THCTensor_(free)(state, newValuesView);
202203
}
203204

205+
self->coalesced = 1;
206+
204207
#undef THRUST_EXEC
205208
}
206209

@@ -220,7 +223,7 @@ void THCSTensor_(transpose)(THCState *state, THCSTensor *self, int d1, int d2) {
220223
long i = self->size[d1];
221224
self->size[d1] = self->size[d2];
222225
self->size[d2] = i;
223-
self->contiguous = 0;
226+
self->coalesced = 0;
224227
THCIndexTensor_(free)(state, indices);
225228
THCIndexTensor_(free)(state, buffer);
226229
THCIndexTensor_(free)(state, slice1);

torch/lib/THCS/generic/THCSTensor.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ typedef struct THCSTensor
1414
THCIndexTensor *indices;
1515
THCTensor *values;
1616
// Math operations can only be performed on ordered sparse tensors
17-
int contiguous;
17+
int coalesced;
1818
int refcount;
1919

2020
} THCSTensor;
@@ -41,7 +41,6 @@ TH_API THCSTensor *THCSTensor_(newWithSize3d)(THCState *state, long size0_, long
4141
TH_API THCSTensor *THCSTensor_(newWithSize4d)(THCState *state, long size0_, long size1_, long size2_, long size3_);
4242

4343
TH_API THCSTensor *THCSTensor_(newClone)(THCState *state, THCSTensor *self);
44-
TH_API THCSTensor *THCSTensor_(newContiguous)(THCState *state, THCSTensor *self);
4544
TH_API THCSTensor *THCSTensor_(newTranspose)(THCState *state, THCSTensor *self, int dimension1_, int dimension2_);
4645

4746
/**** reshaping methods ***/
@@ -58,8 +57,8 @@ TH_API THCTensor *THCSTensor_(toDense)(THCState *state, THCSTensor *self);
5857
TH_API void THCSTensor_(copy)(THCState *state, THCSTensor *self, THCSTensor *src);
5958

6059
TH_API void THCSTensor_(transpose)(THCState *state, THCSTensor *self, int dimension1_, int dimension2_);
61-
TH_API int THCSTensor_(isContiguous)(THCState *state, const THCSTensor *self);
62-
TH_API void THCSTensor_(contiguous)(THCState *state, THCSTensor *self);
60+
TH_API int THCSTensor_(isCoalesced)(THCState *state, const THCSTensor *self);
61+
TH_API void THCSTensor_(coalesce)(THCState *state, THCSTensor *self);
6362

6463
TH_API void THCTensor_(sparseMask)(THCState *state, THCSTensor *r_, THCTensor *t, THCSTensor *mask);
6564

@@ -72,7 +71,6 @@ TH_API int THCSTensor_(checkGPU)(THCState *state, unsigned int nSparseTensors, u
7271

7372
/* internal methods */
7473
TH_API void THCSTensor_(rawResize)(THCState *state, THCSTensor *self, int nDimI, int nDimV, long *size);
75-
TH_API void THCSTensor_(reorder)(THCState *state, THCSTensor *self);
7674
TH_API THCTensor *THCSTensor_(newValuesWithSizeOf)(THCState *state, THCTensor *values, long nnz);
7775
TH_API THCSTensor* THCSTensor_(_move)(THCState *state, THCSTensor *self, THCIndexTensor *indices, THCTensor *values);
7876
TH_API THCSTensor* THCSTensor_(_set)(THCState *state, THCSTensor *self, THCIndexTensor *indices, THCTensor *values);

torch/lib/THCS/generic/THCSTensorMath.cu

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ void THCSTensor_(spaddmm)(THCState *state, THCTensor *r_, real beta, THCTensor *
6767
THArgCheck(THCTensor_(size)(state, dense, 0) == k, 3,
6868
"Expected dim 0 size %d, got %d", k, THCTensor_(size)(state, dense, 0));
6969

70-
THCSTensor_(contiguous)(state, sparse);
70+
THCSTensor_(coalesce)(state, sparse);
7171

7272
long nnz = THCSTensor_(nnz)(state, sparse);
7373
indices = THCSTensor_(newIndices)(state, sparse);
@@ -182,7 +182,7 @@ void THCSTensor_(hspmm)(THCState *state, THCSTensor *r_, real alpha, THCSTensor
182182
long size[2] = {m, n};
183183
THCSTensor_(rawResize)(state, r_, 1, 1, size);
184184

185-
THCSTensor_(contiguous)(state, sparse);
185+
THCSTensor_(coalesce)(state, sparse);
186186

187187
long nnz = THCSTensor_(nnz)(state, sparse);
188188
THCIndexTensor *indices = THCIndexTensor_(newWithSize2d)(state, 1, nnz);
@@ -213,7 +213,7 @@ void THCSTensor_(hspmm)(THCState *state, THCSTensor *r_, real alpha, THCSTensor
213213
void THCSTensor_(spcadd)(THCState *state, THCTensor *r_, THCTensor *dense, real value, THCSTensor *sparse) {
214214
THCAssertSameGPU(THCSTensor_(checkGPU)(state, 1, 3, sparse, r_, dense));
215215
THCTensor_(resizeAs)(state, r_, dense);
216-
THCSTensor_(contiguous)(state, sparse);
216+
THCSTensor_(coalesce)(state, sparse);
217217

218218
THCIndexTensor *indices = THCSTensor_(newIndices)(state, sparse);
219219
THCTensor *values = THCSTensor_(newValues)(state, sparse);
@@ -274,7 +274,7 @@ void THCSTensor_(mul)(THCState *state, THCSTensor *r_, THCSTensor *t, real value
274274
THCIndexTensor_(copy)(state, r_indices_, t_indices_);
275275
THCTensor_(mul)(state, r_values_, t_values_, value);
276276
r_->nnz = t->nnz;
277-
r_->contiguous = t->contiguous;
277+
r_->coalesced = t->coalesced;
278278

279279
THCIndexTensor_(free)(state, r_indices_);
280280
THCTensor_(free)(state, r_values_);
@@ -300,7 +300,7 @@ void THCSTensor_(div)(THCState *state, THCSTensor *r_, THCSTensor *t, real value
300300
THCIndexTensor_(copy)(state, r_indices_, t_indices_);
301301
THCTensor_(div)(state, r_values_, t_values_, value);
302302
r_->nnz = t->nnz;
303-
r_->contiguous = t->contiguous;
303+
r_->coalesced = t->coalesced;
304304

305305
THCIndexTensor_(free)(state, r_indices_);
306306
THCTensor_(free)(state, r_values_);
@@ -314,8 +314,8 @@ void THCSTensor_(cadd)(THCState *state, THCSTensor *r_, THCSTensor *t, real valu
314314
if(!THCSTensor_(isSameSizeAs)(state, t, src)) {
315315
THError("cadd operands have incompatible sizes or dimension types");
316316
}
317-
THCSTensor_(contiguous)(state, t);
318-
THCSTensor_(contiguous)(state, src);
317+
THCSTensor_(coalesce)(state, t);
318+
THCSTensor_(coalesce)(state, src);
319319

320320
if (src->nnz == 0) {
321321
THCSTensor_(copy)(state, r_, t);
@@ -328,7 +328,7 @@ void THCSTensor_(cadd)(THCState *state, THCSTensor *r_, THCSTensor *t, real valu
328328

329329
// We deliberately choose to simply concat the indices and values tensors
330330
// rather than merging them. This removes the need to synchronously fetch nnz
331-
// at the end of the operation, at the cost of having a non-contiguous result.
331+
// at the end of the operation, at the cost of having a non-coalesced result.
332332
// This trade-off is preferable for the common use-case of gradient accumulation.
333333
// TODO have two distinct functions? The other option is commented out below
334334
THCIndexTensor *t_indices_ = THCSTensor_(newIndices)(state, t);
@@ -392,7 +392,7 @@ void THCSTensor_(cadd)(THCState *state, THCSTensor *r_, THCSTensor *t, real valu
392392
// unsigned long nnzOut;
393393
// THCudaCheck(cudaMemcpy(&nnzOut, scratchSpace, sizeof(unsigned long), cudaMemcpyDeviceToHost));
394394
// r_->nnz = nnzOut;
395-
// r_->contiguous = 1;
395+
// r_->coalesced = 1;
396396
// if (freeScratchSpace) {
397397
// THCudaCheck(THCudaFree(state, scratchSpace));
398398
// }
@@ -412,8 +412,8 @@ void THCSTensor_(cmul)(THCState *state, THCSTensor *r_, THCSTensor *t, THCSTenso
412412
if(!THCSTensor_(isSameSizeAs)(state, t, src)) {
413413
THError("cmul operands have incompatible sizes or dimension types");
414414
}
415-
THCSTensor_(contiguous)(state, t);
416-
THCSTensor_(contiguous)(state, src);
415+
THCSTensor_(coalesce)(state, t);
416+
THCSTensor_(coalesce)(state, src);
417417

418418
if (t->nnz == 0 || src->nnz == 0) {
419419
THCSTensor_(zero)(state, r_);
@@ -455,7 +455,7 @@ void THCSTensor_(cmul)(THCState *state, THCSTensor *r_, THCSTensor *t, THCSTenso
455455
THCudaCheck(cudaGetLastError());
456456
r_->nnz = THCudaLongStorage_get(state, resultNnz, 0);
457457
THCudaLongStorage_free(state, resultNnz);
458-
r_->contiguous = 1;
458+
r_->coalesced = 1;
459459

460460
THCIndexTensor_(free)(state, t_indices_);
461461
THCTensor_(free)(state, t_values_);

torch/lib/THPP/tensors/generic/THCSTensor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ auto THCSTensor<real>::clone_shallow() -> THCSTensor* {
4141

4242
template<>
4343
auto THCSTensor<real>::contiguous() const -> std::unique_ptr<Tensor> {
44-
return std::unique_ptr<Tensor>(new THCSTensor(state, THCSTensor_(newContiguous)(state, tensor)));
44+
throw std::runtime_error("THCSTensor::contiguous() not supported");
4545
}
4646

4747
template<>

torch/lib/THPP/tensors/generic/THSTensor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ auto THSTensor<real>::clone_shallow() -> THSTensor* {
4040

4141
template<>
4242
auto THSTensor<real>::contiguous() const -> std::unique_ptr<Tensor> {
43-
return std::unique_ptr<Tensor>(new THSTensor(THSTensor_(newContiguous)(tensor)));
43+
throw std::runtime_error("THSTensor::contiguous() not supported");
4444
}
4545

4646
template<>

0 commit comments

Comments
 (0)