Skip to content

Commit 7409e08

Browse files
lantigasoumith
authored andcommitted
Cuda fixes
1 parent f269d3f commit 7409e08

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

torch/csrc/generic/Tensor.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ static void THPTensor_(setInconsistentDepthError)(std::vector<size_t> &sizes,
147147
THPUtils_setError(error.c_str());
148148
}
149149

150-
#ifdef NUMPY_TYPE_ENUM
150+
#if defined(NUMPY_TYPE_ENUM) || defined(THC_GENERIC_FILE)
151151

152152
#ifndef THC_REAL_IS_HALF
153153
#define load_real real
@@ -190,17 +190,20 @@ THTensor* THPTensor_(fromNumpy)(PyObject *numpy_array) {
190190
}
191191

192192
THTensor *result = NULL;
193+
#ifdef NUMPY_TYPE_ENUM
193194
if (PyArray_TYPE(array) == NUMPY_TYPE_ENUM) {
194195
THStoragePtr storage(THStorage_(newWithDataAndAllocator)(
195-
(real*)PyArray_DATA(array),
196+
LIBRARY_STATE (real*)PyArray_DATA(array),
196197
storage_size,
197198
// See Note [Numpy memory management]
198199
&THNumpyArrayAllocator,
199200
new NumpyArrayAllocator(numpy_array)));
200-
result = THTensor_(newWithStorage)(storage, 0, sizes, strides);
201+
result = THTensor_(newWithStorage)(LIBRARY_STATE storage, 0, sizes, strides);
201202
}
202-
else {
203-
THStoragePtr storage(THStorage_(newWithSize)(storage_size));
203+
else
204+
#endif
205+
{
206+
THStoragePtr storage(THStorage_(newWithSize)(LIBRARY_STATE storage_size));
204207
switch (PyArray_TYPE(array)) {
205208
case NPY_DOUBLE: COPY_FROM_ARRAY(double, array, storage, storage_size); break;
206209
case NPY_FLOAT: COPY_FROM_ARRAY(float, array, storage, storage_size); break;
@@ -209,7 +212,7 @@ THTensor* THPTensor_(fromNumpy)(PyObject *numpy_array) {
209212
case NPY_INT16: COPY_FROM_ARRAY(int16_t, array, storage, storage_size); break;
210213
case NPY_UINT8: COPY_FROM_ARRAY(uint8_t, array, storage, storage_size); break;
211214
}
212-
result = THTensor_(newWithStorage)(storage, 0, sizes, strides);
215+
result = THTensor_(newWithStorage)(LIBRARY_STATE storage, 0, sizes, strides);
213216
}
214217
return result;
215218
} else {
@@ -299,7 +302,7 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject
299302
return (PyObject *)self.release();
300303
}
301304

302-
#ifdef NUMPY_TYPE_ENUM
305+
#if defined(NUMPY_TYPE_ENUM) || defined(THC_GENERIC_FILE)
303306
// torch.Tensor(np.ndarray array)
304307
if (num_args == 1 && PyArray_Check(first_arg)) {
305308
THPObjectPtr numpy_array(

0 commit comments

Comments
 (0)