2323#endif
2424
2525PyObject *THPTensorClass = NULL ;
26+ THPCopyList THTensor_ (copy_functions);
2627
2728PyObject * THPTensor_ (NewEmpty)()
2829{
@@ -425,7 +426,6 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject
425426#endif
426427
427428
428- template <bool allow_index>
429429static bool THPTensor_ (_indexOnce)(PyObject *index, int &indexed_dim,
430430 THTensorPtr &tresult, THStorage* &sresult, long &storage_offset)
431431{
@@ -478,22 +478,13 @@ static bool THPTensor_(_indexOnce)(PyObject *index, int &indexed_dim,
478478 tresult->stride [indexed_dim] *= step;
479479 tresult->size [indexed_dim] /= step;
480480 indexed_dim++;
481- // Indexing with a LongTensor
482- } else if (THPIndexTensor_Check (index)) {
483- if (!allow_index)
484- throw std::runtime_error (" assignments using LongTensors as index aren't supported yet" );
485- THIndexTensor *index_t = ((THPIndexTensor*)index)->cdata ;
486- THTensorPtr index_result = THTensor_ (new )(LIBRARY_STATE_NOARGS);
487- THTensor_ (indexSelect)(LIBRARY_STATE index_result.get (), tresult.get (), indexed_dim++, index_t );
488- tresult = index_result.release ();
489481 } else {
490482 return false ;
491483 }
492484 return true ;
493485}
494486
495487
496- template <bool allow_index>
497488static bool THPTensor_ (_index)(THPTensor *self, PyObject *index,
498489 THTensorPtr &tresult, THStorage * &sresult, long &storage_offset)
499490{
@@ -531,7 +522,7 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
531522 continue ;
532523 }
533524 PyObject *dimidx = PyTuple_GET_ITEM (index, dim);
534- valid = THPTensor_ (_indexOnce)<allow_index> (dimidx, indexed_dim, tresult, sresult, storage_offset);
525+ valid = THPTensor_ (_indexOnce)(dimidx, indexed_dim, tresult, sresult, storage_offset);
535526 if (!valid) {
536527 tresult = NULL ;
537528 // overwrite this, so the message mentions the incorrect object
@@ -540,78 +531,28 @@ static bool THPTensor_(_index)(THPTensor *self, PyObject *index,
540531 }
541532 }
542533 if (valid) return true ;
543- } else if (index == Py_Ellipsis) return true ;
544- else {
545- if (THPTensor_ (_indexOnce)<allow_index>(index, indexed_dim, tresult, sresult, storage_offset))
534+ } else if (index == Py_Ellipsis) {
535+ return true ;
536+ } else {
537+ if (THPTensor_ (_indexOnce)(index, indexed_dim, tresult, sresult, storage_offset))
546538 return true ;
547539 }
548540
549541 PyErr_Format (PyExc_TypeError, " indexing a tensor with an object of type %s. "
550542 " The only supported types are integers, slices"
551543#ifdef WITH_NUMPY
552- " , numpy scalars"
544+ " , numpy scalars and "
553545#endif
554546#ifndef THC_GENERIC_FILE
555- " , torch.LongTensor and torch.ByteTensor." ,
547+ " torch.LongTensor or torch.ByteTensor as the only argument ." ,
556548#else
557- " , torch.cuda.LongTensor and torch.cuda.ByteTensor." ,
549+ " torch.cuda.LongTensor or torch.cuda.ByteTensor as the only argument ." ,
558550#endif
559551 THPUtils_typename (index));
560552 return false ;
561553}
562554#undef IS_SCALAR
563- #undef THIndexTensor
564- #undef THIndexTensor_
565- #undef THPIndexTensor
566- #undef THPIndexTensor_Check
567-
568- extern THPCopyList THTensor_ (copy_functions);
569- THPCopyList THTensor_ (copy_functions);
570-
571- void THPTensor_ (initCopyMethods)()
572- {
573- auto & h = THTensor_ (copy_functions);
574- // copy from CPU types
575- THPInsertCopyFunction (h, &THTensor_ (copyByte));
576- THPInsertCopyFunction (h, &THTensor_ (copyChar));
577- THPInsertCopyFunction (h, &THTensor_ (copyShort));
578- THPInsertCopyFunction (h, &THTensor_ (copyInt));
579- THPInsertCopyFunction (h, &THTensor_ (copyLong));
580- THPInsertCopyFunction (h, &THTensor_ (copyFloat));
581- THPInsertCopyFunction (h, &THTensor_ (copyDouble));
582- #ifdef THC_GENERIC_FILE
583- // copy from GPU types
584- THPInsertCopyFunction (h, &THTensor_ (copyCudaByte));
585- THPInsertCopyFunction (h, &THTensor_ (copyCudaChar));
586- THPInsertCopyFunction (h, &THTensor_ (copyCudaShort));
587- THPInsertCopyFunction (h, &THTensor_ (copyCudaInt));
588- THPInsertCopyFunction (h, &THTensor_ (copyCudaLong));
589- THPInsertCopyFunction (h, &THTensor_ (copyCudaFloat));
590- THPInsertCopyFunction (h, &THTensor_ (copyCudaDouble));
591- #ifdef CUDA_HALF_TENSOR
592- THPInsertCopyFunction (h, &THTensor_ (copyCudaHalf));
593- #endif
594- #ifndef THC_REAL_IS_HALF
595- THPInsertCopyFunction (h, &THCTensor_ (copyAsyncCPU), true );
596- // add CPU <- GPU copies to base type
597- #define THCpuTensor_ (name ) TH_CONCAT_4(TH, Real, Tensor_, name)
598- extern THPCopyList THCpuTensor_ (copy_functions);
599- auto & b = THCpuTensor_ (copy_functions);
600- THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaByte));
601- THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaChar));
602- THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaShort));
603- THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaInt));
604- THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaLong));
605- THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaFloat));
606- THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaDouble));
607- #ifdef CUDA_HALF_TENSOR
608- THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaHalf));
609- #endif
610- THPInsertCopyFunction (b, &THCpuTensor_ (copyAsyncCuda), true );
611- #undef THCpuTensor_
612- #endif
613- #endif
614- }
555+ #undef UNPACK_SCALAR
615556
616557template <bool force_tensor>
617558static PyObject * THPTensor_ (getValue)(THPTensor *self, PyObject *index)
@@ -629,11 +570,17 @@ static PyObject * THPTensor_(getValue)(THPTensor *self, PyObject *index)
629570 THTensor_ (maskedSelect)(LIBRARY_STATE t.get (), self->cdata , mask->cdata );
630571 return THPTensor_ (New)(t.release ());
631572 }
573+ if (THPIndexTensor_Check (index)) {
574+ THIndexTensor *index_t = ((THPIndexTensor*)index)->cdata ;
575+ THTensorPtr index_result = THTensor_ (new )(LIBRARY_STATE_NOARGS);
576+ THTensor_ (indexSelect)(LIBRARY_STATE index_result.get (), self->cdata , 0 , index_t );
577+ return THPTensor_ (New)(index_result.release ());
578+ }
632579
633580 THTensorPtr tresult;
634581 THStorage *sresult;
635582 long storage_offset;
636- if (!THPTensor_ (_index)< true > (self, index, tresult, sresult, storage_offset))
583+ if (!THPTensor_ (_index)(self, index, tresult, sresult, storage_offset))
637584 return NULL ;
638585 if (tresult)
639586 return THPTensor_ (New)(tresult.release ());
@@ -675,11 +622,25 @@ static int THPTensor_(setValue)(THPTensor *self, PyObject *index, PyObject *valu
675622 }
676623 return 0 ;
677624 }
625+ if (THPIndexTensor_Check (index)) {
626+ THIndexTensor *index_t = ((THPIndexTensor*)index)->cdata ;
627+ if (THPUtils_ (checkReal)(value)) {
628+ real v = THPUtils_ (unpackReal)(value);
629+ THTensor_ (indexFill)(LIBRARY_STATE self->cdata , 0 , index_t , v);
630+ } else if (THPTensor_ (Check)(value)) {
631+ THTensor_ (indexCopy)(LIBRARY_STATE self->cdata , 0 , index_t , ((THPTensor*)value)->cdata );
632+ } else {
633+ THPUtils_setError (" can't assign %s to a " THPTensorStr " using a LongTensor "
634+ " (only " THPTensorStr " or %s are supported)" ,
635+ THPUtils_typename (value), THPUtils_typeTraits<real>::python_type_str);
636+ }
637+ return 0 ;
638+ }
678639
679640 THTensorPtr tresult;
680641 THStorage *sresult;
681642 long storage_offset;
682- if (!THPTensor_ (_index)< false > (self, index, tresult, sresult, storage_offset))
643+ if (!THPTensor_ (_index)(self, index, tresult, sresult, storage_offset))
683644 return -1 ;
684645 if (sresult) {
685646 if (!force_tensor) {
@@ -714,6 +675,10 @@ static int THPTensor_(setValue)(THPTensor *self, PyObject *index, PyObject *valu
714675 return -1 ;
715676 END_HANDLE_TH_ERRORS_RET (-1 )
716677}
678+ #undef THIndexTensor
679+ #undef THIndexTensor_
680+ #undef THPIndexTensor
681+ #undef THPIndexTensor_Check
717682
718683Py_ssize_t THPTensor_ (length)(THPTensor *self)
719684{
@@ -831,6 +796,51 @@ PyTypeObject THPTensorStatelessType = {
831796
832797#include " SparseTensor.cpp"
833798
799+ void THPTensor_ (initCopyMethods)()
800+ {
801+ auto & h = THTensor_ (copy_functions);
802+ // copy from CPU types
803+ THPInsertCopyFunction (h, &THTensor_ (copyByte));
804+ THPInsertCopyFunction (h, &THTensor_ (copyChar));
805+ THPInsertCopyFunction (h, &THTensor_ (copyShort));
806+ THPInsertCopyFunction (h, &THTensor_ (copyInt));
807+ THPInsertCopyFunction (h, &THTensor_ (copyLong));
808+ THPInsertCopyFunction (h, &THTensor_ (copyFloat));
809+ THPInsertCopyFunction (h, &THTensor_ (copyDouble));
810+ #ifdef THC_GENERIC_FILE
811+ // copy from GPU types
812+ THPInsertCopyFunction (h, &THTensor_ (copyCudaByte));
813+ THPInsertCopyFunction (h, &THTensor_ (copyCudaChar));
814+ THPInsertCopyFunction (h, &THTensor_ (copyCudaShort));
815+ THPInsertCopyFunction (h, &THTensor_ (copyCudaInt));
816+ THPInsertCopyFunction (h, &THTensor_ (copyCudaLong));
817+ THPInsertCopyFunction (h, &THTensor_ (copyCudaFloat));
818+ THPInsertCopyFunction (h, &THTensor_ (copyCudaDouble));
819+ #ifdef CUDA_HALF_TENSOR
820+ THPInsertCopyFunction (h, &THTensor_ (copyCudaHalf));
821+ #endif
822+ #ifndef THC_REAL_IS_HALF
823+ THPInsertCopyFunction (h, &THCTensor_ (copyAsyncCPU), true );
824+ // add CPU <- GPU copies to base type
825+ #define THCpuTensor_ (name ) TH_CONCAT_4(TH, Real, Tensor_, name)
826+ extern THPCopyList THCpuTensor_ (copy_functions);
827+ auto & b = THCpuTensor_ (copy_functions);
828+ THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaByte));
829+ THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaChar));
830+ THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaShort));
831+ THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaInt));
832+ THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaLong));
833+ THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaFloat));
834+ THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaDouble));
835+ #ifdef CUDA_HALF_TENSOR
836+ THPInsertCopyFunction (b, &THCpuTensor_ (copyCudaHalf));
837+ #endif
838+ THPInsertCopyFunction (b, &THCpuTensor_ (copyAsyncCuda), true );
839+ #undef THCpuTensor_
840+ #endif
841+ #endif
842+ }
843+
834844bool THPTensor_ (init)(PyObject *module )
835845{
836846#ifndef THC_GENERIC_FILE
0 commit comments