Skip to content

Commit f290333

Browse files
adamlerer authored and soumith committed
Make coalesce() out of place
1 parent 9643be7 commit f290333

File tree

10 files changed

+141
-149
lines changed

10 files changed

+141
-149
lines changed

test/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ def assertTensorsEqual(a, b):
150150
self.assertLessEqual(max_err, prec, message)
151151
self.assertEqual(x.is_sparse, y.is_sparse, message)
152152
if x.is_sparse:
153-
x = x.clone().coalesce_()
154-
y = y.clone().coalesce_()
153+
x = x.coalesce()
154+
y = y.coalesce()
155155
assertTensorsEqual(x.indices(), y.indices())
156156
assertTensorsEqual(x.values(), y.values())
157157
else:

test/test_sparse.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def _test_contig(self, is_cuda):
156156
[31, 92, 65, 50, 34, 62, 22, 56, 74, 89],
157157
])
158158
exp_v = ValueTensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7])
159-
x.coalesce_()
159+
x = x.coalesce()
160160
self.assertEqual(exp_i, x.indices())
161161
self.assertEqual(exp_v, x.values())
162162

@@ -174,7 +174,7 @@ def _test_contig(self, is_cuda):
174174
])
175175
exp_v = ValueTensor([2, 1, 3, 4])
176176

177-
x.coalesce_()
177+
x = x.coalesce()
178178
self.assertEqual(exp_i, x.indices())
179179
self.assertEqual(exp_v, x.values())
180180

@@ -193,7 +193,7 @@ def _test_contig(self, is_cuda):
193193
])
194194
exp_v = ValueTensor([6, 4])
195195

196-
x.coalesce_()
196+
x = x.coalesce()
197197
self.assertEqual(exp_i, x.indices())
198198
self.assertEqual(exp_v, x.values())
199199

@@ -224,7 +224,7 @@ def _test_contig_hybrid(self, is_cuda):
224224
[2, 3], [1, 2], [6, 7], [4, 5], [10, 11],
225225
[3, 4], [5, 6], [9, 10], [8, 9], [7, 8],
226226
])
227-
x.coalesce_()
227+
x = x.coalesce()
228228
self.assertEqual(exp_i, x.indices())
229229
self.assertEqual(exp_v, x.values())
230230

@@ -242,7 +242,7 @@ def _test_contig_hybrid(self, is_cuda):
242242
])
243243
exp_v = ValueTensor([[2, 2, 2], [1, 1, 1], [3, 3, 3], [4, 4, 4]])
244244

245-
x.coalesce_()
245+
x = x.coalesce()
246246
self.assertEqual(exp_i, x.indices())
247247
self.assertEqual(exp_v, x.values())
248248

@@ -261,7 +261,7 @@ def _test_contig_hybrid(self, is_cuda):
261261
])
262262
exp_v = ValueTensor([[6, 4, 5], [4, 3, 4]])
263263

264-
x.coalesce_()
264+
x = x.coalesce()
265265
self.assertEqual(exp_i, x.indices())
266266
self.assertEqual(exp_v, x.values())
267267

@@ -490,6 +490,16 @@ def _test_basic_ops_shape(self, is_cuda, shape_i, shape_v=None):
490490
expected = torch.zeros(x1.size())
491491
self.assertEqual(y.to_dense(), expected)
492492

493+
self.assertFalse(x1.is_coalesced())
494+
y = x1.coalesce()
495+
z = x1.coalesce()
496+
self.assertFalse(x1.is_coalesced())
497+
self.assertTrue(y.is_coalesced())
498+
self.assertEqual(x1, y)
499+
# check that coalesce is out of place
500+
y.values().add_(1)
501+
self.assertEqual(z.values() + 1, y.values())
502+
493503
def _test_basic_ops(self, is_cuda):
494504
self._test_basic_ops_shape(is_cuda, [5, 6])
495505
self._test_basic_ops_shape(is_cuda, [10, 10, 10])

torch/csrc/generic/methods/SparseTensor.cwrap

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ PyObject * THSPTensor_(size)(PyObject *self, PyObject *args, PyObject *kwargs)
9090

9191
[[
9292
name: coalesce
93-
python_name: coalesce_
93+
cname: newCoalesce
9494
sparse: yes
95-
return: argument 0
95+
return: THSTensor*
9696
arguments:
9797
- THSTensor* self
9898
]]

torch/lib/THCS/generic/THCSTensor.cu

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,16 @@ THCTensor *THCSTensor_(toDense)(THCState *state, THCSTensor *self) {
3535
return dst;
3636
}
3737

38-
void THCSTensor_(coalesce)(THCState *state, THCSTensor *self) {
39-
if (self->coalesced) return;
40-
int nnz = self->nnz;
41-
if (nnz < 2) return;
38+
THCSTensor *THCSTensor_(newCoalesce)(THCState *state, THCSTensor *self) {
39+
ptrdiff_t nnz = self->nnz;
40+
if (nnz < 2) {
41+
self->coalesced = 1;
42+
}
43+
if (self->coalesced) {
44+
THCSTensor_(retain)(state, self);
45+
return self;
46+
}
47+
4248
#if CUDA_VERSION >= 7000
4349
THCThrustAllocator thrustAlloc(state);
4450
#define THRUST_EXEC(fn, ...) fn(thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)), ##__VA_ARGS__)
@@ -93,10 +99,10 @@ void THCSTensor_(coalesce)(THCState *state, THCSTensor *self) {
9399
long newNnz = newEnd.first - indicesIter;
94100

95101
THCIndexTensor_(resize2d)(state, indices1D, 1, newNnz);
96-
THLongStorage *newValuesSize = THCTensor_(newSizeOf)(state, values);
97-
newValuesSize->data[0] = newNnz;
98-
THCTensor *newValues = THCTensor_(newWithSize)(state, newValuesSize, NULL);
99-
THLongStorage_free(newValuesSize);
102+
THCTensor *newValues = THCTensor_(new)(state);
103+
THCTensor_(resizeNd)(state, newValues, values->nDimension, values->size, NULL);
104+
newValues->size[0] = newNnz;
105+
100106

101107
dim3 grid(THCCeilDiv(newNnz, (long) 4), THCCeilDiv(stride, (long) 128));
102108
dim3 block(32, 4);
@@ -152,16 +158,16 @@ void THCSTensor_(coalesce)(THCState *state, THCSTensor *self) {
152158
THCIndexTensor_(free)(state, indicesSlice);
153159
}
154160
////////////////////////////////////////////////////////////
155-
self->nnz = newNnz;
156-
THCIndexTensor_(free)(state, self->indices);
157-
self->indices = newIndices;
161+
THLongStorage *size = THCSTensor_(newSizeOf)(state, self);
162+
THCSTensor *dst = THCSTensor_(newWithTensorAndSize)(state, newIndices, newValues, size);
163+
THLongStorage_free(size);
158164

165+
THCIndexTensor_(free)(state, indices);
159166
THCTensor_(free)(state, values);
160-
THCTensor_(free)(state, self->values);
161-
self->values = newValues;
162167

163-
self->coalesced = 1;
168+
dst->coalesced = 1;
164169
THCudaCheck(cudaGetLastError());
170+
return dst;
165171
#undef THRUST_EXEC
166172
}
167173

torch/lib/THCS/generic/THCSTensor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ TH_API void THCSTensor_(copy)(THCState *state, THCSTensor *self, THCSTensor *src
5858

5959
TH_API void THCSTensor_(transpose)(THCState *state, THCSTensor *self, int dimension1_, int dimension2_);
6060
TH_API int THCSTensor_(isCoalesced)(THCState *state, const THCSTensor *self);
61-
TH_API void THCSTensor_(coalesce)(THCState *state, THCSTensor *self);
61+
TH_API THCSTensor *THCSTensor_(newCoalesce)(THCState *state, THCSTensor *self);
6262

6363
TH_API void THCTensor_(sparseMask)(THCState *state, THCSTensor *r_, THCTensor *t, THCSTensor *mask);
6464

torch/lib/THCS/generic/THCSTensorMath.cu

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,22 @@ void THCTensor_(spaddcdiv)(THCState *state, THCTensor *r_, THCTensor *t, real va
4040
THError("WARNING: Sparse Cuda Tensor op spaddcdiv is not implemented");
4141
}
4242

43-
void THCSTensor_(spaddmm)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real alpha, THCSTensor *sparse, THCTensor *dense) {
43+
void THCSTensor_(spaddmm)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real alpha, THCSTensor *sparse_, THCTensor *dense) {
4444
#if defined(THCS_REAL_IS_FLOAT) || defined(THCS_REAL_IS_DOUBLE)
45-
THCAssertSameGPU(THCSTensor_(checkGPU)(state, 1, 4, sparse, r_, t, dense));
45+
THCAssertSameGPU(THCSTensor_(checkGPU)(state, 1, 4, sparse_, r_, t, dense));
4646
THCudaIntTensor *csr;
4747
THCIndexTensor *indices;
4848
THCTensor *values, *r__, *dense_;
4949

50-
THArgCheck(sparse->nDimensionI == 2, 2,
51-
"matrices expected, got %dD tensor", sparse->nDimensionI);
52-
THArgCheck(sparse->nDimensionV == 0, 2,
53-
"scalar values expected, got %dD values", sparse->nDimensionV);
50+
THArgCheck(sparse_->nDimensionI == 2, 2,
51+
"matrices expected, got %dD tensor", sparse_->nDimensionI);
52+
THArgCheck(sparse_->nDimensionV == 0, 2,
53+
"scalar values expected, got %dD values", sparse_->nDimensionV);
5454
THArgCheck(dense->nDimension == 2, 2,
5555
"matrices expected, got %dD tensor", dense->nDimension);
5656

57-
long m = THCSTensor_(size)(state, sparse, 0);
58-
long k = THCSTensor_(size)(state, sparse, 1);
57+
long m = THCSTensor_(size)(state, sparse_, 0);
58+
long k = THCSTensor_(size)(state, sparse_, 1);
5959
long n = THCTensor_(size)(state, dense, 1);
6060

6161
THCTensor_(resize2d)(state, r_, m, n);
@@ -67,7 +67,7 @@ void THCSTensor_(spaddmm)(THCState *state, THCTensor *r_, real beta, THCTensor *
6767
THArgCheck(THCTensor_(size)(state, dense, 0) == k, 3,
6868
"Expected dim 0 size %d, got %d", k, THCTensor_(size)(state, dense, 0));
6969

70-
THCSTensor_(coalesce)(state, sparse);
70+
THCSTensor *sparse = THCSTensor_(newCoalesce)(state, sparse_);
7171

7272
long nnz = THCSTensor_(nnz)(state, sparse);
7373
indices = THCSTensor_(newIndices)(state, sparse);
@@ -146,6 +146,7 @@ void THCSTensor_(spaddmm)(THCState *state, THCTensor *r_, real beta, THCTensor *
146146
THCIndexTensor_(free)(state, rowIndices);
147147
THCIndexTensor_(free)(state, colIndices);
148148
THCTensor_(free)(state, values);
149+
THCSTensor_(free)(state, sparse);
149150
#else
150151
THError("unimplemented data type");
151152
#endif
@@ -156,40 +157,42 @@ void THCSTensor_(sspaddmm)(THCState *state, THCSTensor *r_, real beta, THCSTenso
156157
// TODO Write some kernels
157158
}
158159

159-
void THCSTensor_(hspmm)(THCState *state, THCSTensor *r_, real alpha, THCSTensor *sparse, THCTensor *dense) {
160+
void THCSTensor_(hspmm)(THCState *state, THCSTensor *r_, real alpha, THCSTensor *sparse_, THCTensor *dense) {
160161
#if CUDA_VERSION >= 7000
161162
THCThrustAllocator thrustAlloc(state);
162163
#define THRUST_EXEC(fn, ...) fn(thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)), ##__VA_ARGS__)
163164
#else
164165
#define THRUST_EXEC(fn, ...) fn(##__VA_ARGS__)
165166
#endif
166167

167-
THCAssertSameGPU(THCSTensor_(checkGPU)(state, 2, 3, r_, sparse, dense));
168+
THCAssertSameGPU(THCSTensor_(checkGPU)(state, 2, 3, r_, sparse_, dense));
168169

169-
THArgCheck(sparse->nDimensionI == 2, 3,
170-
"matrices expected, got %dD tensor", sparse->nDimensionI);
171-
THArgCheck(sparse->nDimensionV == 0, 3,
172-
"scalar values expected, got %dD values", sparse->nDimensionV);
170+
THArgCheck(sparse_->nDimensionI == 2, 3,
171+
"matrices expected, got %dD tensor", sparse_->nDimensionI);
172+
THArgCheck(sparse_->nDimensionV == 0, 3,
173+
"scalar values expected, got %dD values", sparse_->nDimensionV);
173174
THArgCheck(dense->nDimension == 2, 4,
174175
"matrices expected, got %dD tensor", dense->nDimension);
175176

176-
long m = THCSTensor_(size)(state, sparse, 0);
177-
long k = THCSTensor_(size)(state, sparse, 1);
177+
long m = THCSTensor_(size)(state, sparse_, 0);
178+
long k = THCSTensor_(size)(state, sparse_, 1);
178179
long n = THCTensor_(size)(state, dense, 1);
179180

180181
THArgCheck(THCTensor_(size)(state, dense, 0) == k, 4,
181182
"Expected dim 0 size %d, got %d", k, THCTensor_(size)(state, dense, 0));
182183
long size[2] = {m, n};
183184
THCSTensor_(rawResize)(state, r_, 1, 1, size);
184185

185-
THCSTensor_(coalesce)(state, sparse);
186+
THCSTensor *sparse = THCSTensor_(newCoalesce)(state, sparse_);
186187

187188
long nnz = THCSTensor_(nnz)(state, sparse);
188189
THCIndexTensor *indices = THCIndexTensor_(newWithSize2d)(state, 1, nnz);
189190
// create values in column-major format to avoid copying in spaddmm
190191
THCTensor *values = THCTensor_(newWithSize2d)(state, n, nnz);
191192
THCTensor_(transpose)(state, values, NULL, 0, 1);
192193

194+
// why does sparse need to be cloned? If this is really necessary maybe we
195+
// need to fuse this with newCoalesce
193196
THCSTensor *newSparse = THCSTensor_(newClone)(state, sparse);
194197
THCIndexTensor *spIndices = THCSTensor_(newIndices)(state, newSparse);
195198
THCIndexTensor *dstIndices = THCIndexTensor_(newSelect)(state, spIndices, 0, 0);
@@ -206,6 +209,7 @@ void THCSTensor_(hspmm)(THCState *state, THCSTensor *r_, real alpha, THCSTensor
206209
THCSTensor_(free)(state, newSparse);
207210
THCIndexTensor_(free)(state, spIndices);
208211
THCIndexTensor_(free)(state, dstIndices);
212+
THCSTensor_(free)(state, sparse);
209213

210214
#undef THRUST_EXEC
211215
}
@@ -348,8 +352,6 @@ void THCSTensor_(cadd)(THCState *state, THCSTensor *r_, THCSTensor *t, real valu
348352
if(!THCSTensor_(isSameSizeAs)(state, t, src)) {
349353
THError("cadd operands have incompatible sizes or dimension types");
350354
}
351-
THCSTensor_(coalesce)(state, t);
352-
THCSTensor_(coalesce)(state, src);
353355

354356
if (src->nnz == 0) {
355357
THCSTensor_(copy)(state, r_, t);
@@ -399,13 +401,13 @@ void THCSTensor_(csub)(THCState *state, THCSTensor *r_, THCSTensor *t, real valu
399401
THCSTensor_(cadd)(state, r_, t, ScalarNegate<real>::to(value), src);
400402
}
401403

402-
void THCSTensor_(cmul)(THCState *state, THCSTensor *r_, THCSTensor *t, THCSTensor *src) {
403-
THCAssertSameGPU(THCSTensor_(checkGPU)(state, 3, 3, r_, t, src));
404-
if(!THCSTensor_(isSameSizeAs)(state, t, src)) {
404+
void THCSTensor_(cmul)(THCState *state, THCSTensor *r_, THCSTensor *t_, THCSTensor *src_) {
405+
THCAssertSameGPU(THCSTensor_(checkGPU)(state, 3, 3, r_, t_, src_));
406+
if(!THCSTensor_(isSameSizeAs)(state, t_, src_)) {
405407
THError("cmul operands have incompatible sizes or dimension types");
406408
}
407-
THCSTensor_(coalesce)(state, t);
408-
THCSTensor_(coalesce)(state, src);
409+
THCSTensor *t = THCSTensor_(newCoalesce)(state, t_);
410+
THCSTensor *src = THCSTensor_(newCoalesce)(state, src_);
409411

410412
if (t->nnz == 0 || src->nnz == 0) {
411413
THCSTensor_(zero)(state, r_);
@@ -453,6 +455,8 @@ void THCSTensor_(cmul)(THCState *state, THCSTensor *r_, THCSTensor *t, THCSTenso
453455
THCTensor_(free)(state, t_values_);
454456
THCIndexTensor_(free)(state, s_indices_);
455457
THCTensor_(free)(state, s_values_);
458+
THCSTensor_(free)(state, t);
459+
THCSTensor_(free)(state, src);
456460
}
457461

458462
#if defined(THCS_REAL_IS_FLOAT) || defined(THCS_REAL_IS_DOUBLE)

0 commit comments

Comments (0)