
Commit 1f6f82d

Fall back to indexing compatible with numpy

1 parent 1f89399 commit 1f6f82d

2 files changed, +87 -77 lines
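In user-facing terms, the commit narrows LongTensor indexing to the numpy-compatible subset: a LongTensor works as the sole index (backed by indexSelect on dimension 0), while mixing it with integer indices is rejected until numpy-style advanced indexing is implemented. A minimal sketch of the resulting behavior, under those assumptions (the tensor names are illustrative, not from the diff):

import torch

reference = torch.randn(5, 5, 5)
idx = torch.LongTensor([2, 4])

# Sole LongTensor index: selects slices along dimension 0, matching the
# assertion kept in test_index below.
assert reference[idx].equal(torch.stack([reference[2], reference[4]]))

# Mixed integer/LongTensor indexing now falls through to the generic
# error path and raises TypeError (see the TODO in the test diff).
try:
    reference[2, idx]
except TypeError:
    print("mixed indexing rejected, as expected")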

test/test_torch.py
Lines changed: 6 additions & 6 deletions
@@ -1892,8 +1892,9 @@ def test_index(self):
         reference = self._consecutive((5, 5, 5))
         idx = torch.LongTensor([2, 4])
         self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]]))
-        self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]]))
-        self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1])
+        # TODO: enable once indexing is implemented like in numpy
+        # self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]]))
+        # self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1])
 
         # None indexing
         self.assertEqual(reference[2, None], reference[2].unsqueeze(0))
@@ -1944,6 +1945,7 @@ def checkPartialAssign(index):
         checkPartialAssign((0, 1))
         checkPartialAssign((1, 2))
         checkPartialAssign((0, 2))
+        checkPartialAssign(torch.LongTensor((0, 2)))
 
         with self.assertRaises(IndexError):
             reference[1, 1, 1, 1] = 1
@@ -1964,10 +1966,8 @@ def checkPartialAssign(index):
         with self.assertRaises(TypeError):
             reference[0.0, :, 0.0] = 1
 
-        # LongTensor assignments are not supported yet
-        with self.assertRaises(RuntimeError):
-            reference[torch.LongTensor([2, 4])] = 1
-        with self.assertRaises(RuntimeError):
+        # LongTensor assignments are not fully supported yet
+        with self.assertRaises(TypeError):
             reference[0, torch.LongTensor([2, 4])] = 1
 
     def test_index_copy(self):
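The assignment side mirrors the read side: with a LongTensor as the only index, a scalar right-hand side maps onto indexFill and a tensor right-hand side onto indexCopy, both along dimension 0, which is why the blanket RuntimeError expectations above are gone. A hedged sketch of what the updated tests permit (shapes chosen for illustration):

import torch

reference = torch.zeros(5, 5, 5)
idx = torch.LongTensor([2, 4])

reference[idx] = 1                    # scalar RHS -> indexFill along dim 0
reference[idx] = torch.ones(2, 5, 5)  # tensor RHS -> indexCopy along dim 0

# A LongTensor mixed with other indices is still rejected, now with
# TypeError instead of RuntimeError:
try:
    reference[0, idx] = 1
except TypeError:
    print("mixed index assignment rejected")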

torch/csrc/generic/Tensor.cpp
Lines changed: 81 additions & 71 deletions
@@ -23,6 +23,7 @@
 #endif
 
 PyObject *THPTensorClass = NULL;
+THPCopyList THTensor_(copy_functions);
 
 PyObject * THPTensor_(NewEmpty)()
 {
@@ -425,7 +426,6 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject *kwargs)
 #endif
 
 
-template<bool allow_index>
 static bool THPTensor_(_indexOnce)(PyObject *index, int &indexed_dim,
     THTensorPtr &tresult, THStorage* &sresult, long &storage_offset)
 {
@@ -478,22 +478,13 @@ static bool THPTensor_(_indexOnce)(PyObject *index, int &indexed_dim,
     tresult->stride[indexed_dim] *= step;
     tresult->size[indexed_dim] /= step;
     indexed_dim++;
-  // Indexing with a LongTensor
-  } else if (THPIndexTensor_Check(index)) {
-    if (!allow_index)
-      throw std::runtime_error("assignments using LongTensors as index aren't supported yet");
-    THIndexTensor *index_t = ((THPIndexTensor*)index)->cdata;
-    THTensorPtr index_result = THTensor_(new)(LIBRARY_STATE_NOARGS);
-    THTensor_(indexSelect)(LIBRARY_STATE index_result.get(), tresult.get(), indexed_dim++, index_t);
-    tresult = index_result.release();
   } else {
     return false;
   }
   return true;
 }
 
 
-template<bool allow_index>
 static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
     THTensorPtr &tresult, THStorage * &sresult, long &storage_offset)
 {
@@ -531,7 +522,7 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
         continue;
       }
       PyObject *dimidx = PyTuple_GET_ITEM(index, dim);
-      valid = THPTensor_(_indexOnce)<allow_index>(dimidx, indexed_dim, tresult, sresult, storage_offset);
+      valid = THPTensor_(_indexOnce)(dimidx, indexed_dim, tresult, sresult, storage_offset);
       if (!valid) {
         tresult = NULL;
         // overwrite this, so the message mentions the incorrect object
@@ -540,78 +531,28 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
       }
     }
     if (valid) return true;
-  } else if (index == Py_Ellipsis) return true;
-  else {
-    if (THPTensor_(_indexOnce)<allow_index>(index, indexed_dim, tresult, sresult, storage_offset))
+  } else if (index == Py_Ellipsis) {
+    return true;
+  } else {
+    if (THPTensor_(_indexOnce)(index, indexed_dim, tresult, sresult, storage_offset))
       return true;
   }
 
   PyErr_Format(PyExc_TypeError, "indexing a tensor with an object of type %s. "
     "The only supported types are integers, slices"
 #ifdef WITH_NUMPY
-    ", numpy scalars"
+    ", numpy scalars and "
 #endif
 #ifndef THC_GENERIC_FILE
-    ", torch.LongTensor and torch.ByteTensor.",
+    "torch.LongTensor or torch.ByteTensor as the only argument.",
 #else
-    ", torch.cuda.LongTensor and torch.cuda.ByteTensor.",
+    "torch.cuda.LongTensor or torch.cuda.ByteTensor as the only argument.",
 #endif
     THPUtils_typename(index));
   return false;
 }
 #undef IS_SCALAR
-#undef THIndexTensor
-#undef THIndexTensor_
-#undef THPIndexTensor
-#undef THPIndexTensor_Check
-
-extern THPCopyList THTensor_(copy_functions);
-THPCopyList THTensor_(copy_functions);
-
-void THPTensor_(initCopyMethods)()
-{
-  auto& h = THTensor_(copy_functions);
-  // copy from CPU types
-  THPInsertCopyFunction(h, &THTensor_(copyByte));
-  THPInsertCopyFunction(h, &THTensor_(copyChar));
-  THPInsertCopyFunction(h, &THTensor_(copyShort));
-  THPInsertCopyFunction(h, &THTensor_(copyInt));
-  THPInsertCopyFunction(h, &THTensor_(copyLong));
-  THPInsertCopyFunction(h, &THTensor_(copyFloat));
-  THPInsertCopyFunction(h, &THTensor_(copyDouble));
-#ifdef THC_GENERIC_FILE
-  // copy from GPU types
-  THPInsertCopyFunction(h, &THTensor_(copyCudaByte));
-  THPInsertCopyFunction(h, &THTensor_(copyCudaChar));
-  THPInsertCopyFunction(h, &THTensor_(copyCudaShort));
-  THPInsertCopyFunction(h, &THTensor_(copyCudaInt));
-  THPInsertCopyFunction(h, &THTensor_(copyCudaLong));
-  THPInsertCopyFunction(h, &THTensor_(copyCudaFloat));
-  THPInsertCopyFunction(h, &THTensor_(copyCudaDouble));
-#ifdef CUDA_HALF_TENSOR
-  THPInsertCopyFunction(h, &THTensor_(copyCudaHalf));
-#endif
-#ifndef THC_REAL_IS_HALF
-  THPInsertCopyFunction(h, &THCTensor_(copyAsyncCPU), true);
-  // add CPU <- GPU copies to base type
-#define THCpuTensor_(name) TH_CONCAT_4(TH, Real, Tensor_, name)
-  extern THPCopyList THCpuTensor_(copy_functions);
-  auto& b = THCpuTensor_(copy_functions);
-  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaByte));
-  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaChar));
-  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaShort));
-  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaInt));
-  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaLong));
-  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaFloat));
-  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaDouble));
-#ifdef CUDA_HALF_TENSOR
-  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaHalf));
-#endif
-  THPInsertCopyFunction(b, &THCpuTensor_(copyAsyncCuda), true);
-#undef THCpuTensor_
-#endif
-#endif
-}
+#undef UNPACK_SCALAR
 
 template<bool force_tensor>
 static PyObject * THPTensor_(getValue)(THPTensor *self, PyObject *index)
@@ -629,11 +570,17 @@ static PyObject * THPTensor_(getValue)(THPTensor *self, PyObject *index)
     THTensor_(maskedSelect)(LIBRARY_STATE t.get(), self->cdata, mask->cdata);
     return THPTensor_(New)(t.release());
   }
+  if (THPIndexTensor_Check(index)) {
+    THIndexTensor *index_t = ((THPIndexTensor*)index)->cdata;
+    THTensorPtr index_result = THTensor_(new)(LIBRARY_STATE_NOARGS);
+    THTensor_(indexSelect)(LIBRARY_STATE index_result.get(), self->cdata, 0, index_t);
+    return THPTensor_(New)(index_result.release());
+  }
 
   THTensorPtr tresult;
   THStorage *sresult;
   long storage_offset;
-  if (!THPTensor_(_index)<true>(self, index, tresult, sresult, storage_offset))
+  if (!THPTensor_(_index)(self, index, tresult, sresult, storage_offset))
     return NULL;
   if (tresult)
     return THPTensor_(New)(tresult.release());
@@ -675,11 +622,25 @@ static int THPTensor_(setValue)(THPTensor *self, PyObject *index, PyObject *value)
     }
     return 0;
   }
+  if (THPIndexTensor_Check(index)) {
+    THIndexTensor *index_t = ((THPIndexTensor*)index)->cdata;
+    if (THPUtils_(checkReal)(value)) {
+      real v = THPUtils_(unpackReal)(value);
+      THTensor_(indexFill)(LIBRARY_STATE self->cdata, 0, index_t, v);
+    } else if (THPTensor_(Check)(value)) {
+      THTensor_(indexCopy)(LIBRARY_STATE self->cdata, 0, index_t, ((THPTensor*)value)->cdata);
+    } else {
+      THPUtils_setError("can't assign %s to a " THPTensorStr " using a LongTensor "
+          "(only " THPTensorStr " or %s are supported)",
+        THPUtils_typename(value), THPUtils_typeTraits<real>::python_type_str);
+    }
+    return 0;
+  }
 
   THTensorPtr tresult;
   THStorage *sresult;
   long storage_offset;
-  if (!THPTensor_(_index)<false>(self, index, tresult, sresult, storage_offset))
+  if (!THPTensor_(_index)(self, index, tresult, sresult, storage_offset))
     return -1;
   if (sresult) {
     if (!force_tensor) {
@@ -714,6 +675,10 @@ static int THPTensor_(setValue)(THPTensor *self, PyObject *index, PyObject *value)
     return -1;
   END_HANDLE_TH_ERRORS_RET(-1)
 }
+#undef THIndexTensor
+#undef THIndexTensor_
+#undef THPIndexTensor
+#undef THPIndexTensor_Check
 
 Py_ssize_t THPTensor_(length)(THPTensor *self)
 {
@@ -831,6 +796,51 @@ PyTypeObject THPTensorStatelessType = {
 
 #include "SparseTensor.cpp"
 
+void THPTensor_(initCopyMethods)()
+{
+  auto& h = THTensor_(copy_functions);
+  // copy from CPU types
+  THPInsertCopyFunction(h, &THTensor_(copyByte));
+  THPInsertCopyFunction(h, &THTensor_(copyChar));
+  THPInsertCopyFunction(h, &THTensor_(copyShort));
+  THPInsertCopyFunction(h, &THTensor_(copyInt));
+  THPInsertCopyFunction(h, &THTensor_(copyLong));
+  THPInsertCopyFunction(h, &THTensor_(copyFloat));
+  THPInsertCopyFunction(h, &THTensor_(copyDouble));
+#ifdef THC_GENERIC_FILE
+  // copy from GPU types
+  THPInsertCopyFunction(h, &THTensor_(copyCudaByte));
+  THPInsertCopyFunction(h, &THTensor_(copyCudaChar));
+  THPInsertCopyFunction(h, &THTensor_(copyCudaShort));
+  THPInsertCopyFunction(h, &THTensor_(copyCudaInt));
+  THPInsertCopyFunction(h, &THTensor_(copyCudaLong));
+  THPInsertCopyFunction(h, &THTensor_(copyCudaFloat));
+  THPInsertCopyFunction(h, &THTensor_(copyCudaDouble));
+#ifdef CUDA_HALF_TENSOR
+  THPInsertCopyFunction(h, &THTensor_(copyCudaHalf));
+#endif
+#ifndef THC_REAL_IS_HALF
+  THPInsertCopyFunction(h, &THCTensor_(copyAsyncCPU), true);
+  // add CPU <- GPU copies to base type
+#define THCpuTensor_(name) TH_CONCAT_4(TH, Real, Tensor_, name)
+  extern THPCopyList THCpuTensor_(copy_functions);
+  auto& b = THCpuTensor_(copy_functions);
+  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaByte));
+  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaChar));
+  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaShort));
+  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaInt));
+  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaLong));
+  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaFloat));
+  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaDouble));
+#ifdef CUDA_HALF_TENSOR
+  THPInsertCopyFunction(b, &THCpuTensor_(copyCudaHalf));
+#endif
+  THPInsertCopyFunction(b, &THCpuTensor_(copyAsyncCuda), true);
+#undef THCpuTensor_
+#endif
+#endif
+}
+
 bool THPTensor_(init)(PyObject *module)
 {
 #ifndef THC_GENERIC_FILE
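Net effect of the Tensor.cpp changes: the template<bool allow_index> machinery is removed from _indexOnce/_index, getValue and setValue special-case a LongTensor passed as the sole index before falling back to the generic path, and initCopyMethods simply moves below the indexing code unchanged. A rough Python rendering of the new dispatch, with generic_index and generic_index_assign as hypothetical stand-ins for the _index fallback:

import torch

def generic_index(tensor, index):
    raise NotImplementedError  # stand-in for the _index path (ints, slices, None, Ellipsis)

def generic_index_assign(tensor, index, value):
    raise NotImplementedError  # stand-in for the _index assignment path

def get_value(tensor, index):
    # getValue: a LongTensor as the whole index maps to indexSelect on dim 0
    if isinstance(index, torch.LongTensor):
        return tensor.index_select(0, index)
    return generic_index(tensor, index)

def set_value(tensor, index, value):
    # setValue: scalar RHS -> indexFill, tensor RHS -> indexCopy, else error
    if isinstance(index, torch.LongTensor):
        if isinstance(value, (int, float)):
            tensor.index_fill_(0, index, value)
        elif torch.is_tensor(value):
            tensor.index_copy_(0, index, value)
        else:
            raise TypeError("can't assign %s using a LongTensor index" % type(value).__name__)
        return
    generic_index_assign(tensor, index, value)

Doing the tensor-wide check up front, rather than per dimension inside _indexOnce, is what restricts LongTensor indexing to the sole-argument case, matching the reworded TypeError message in the diff above.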
