Skip to content

Commit f7de7ba

Browse files
committed
Merge commit 'fd97d92479e32e550866adfd1f0465e4cfa5e581'
2 parents f3aa97f + fd97d92 commit f7de7ba

File tree

8 files changed

+16
-12
lines changed

8 files changed

+16
-12
lines changed

torch/lib/ATen/TensorImpl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct TensorImpl {
2020
virtual int64_t dim() = 0;
2121
virtual Scalar localScalar() = 0;
2222
virtual void assign_(Scalar s) = 0;
23-
virtual void * unsafeGetTH() = 0;
23+
virtual void * unsafeGetTH(bool retain) = 0;
2424
void retain() {
2525
++refcount;
2626
}

torch/lib/ATen/templates/Tensor.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ struct Tensor {
2929

3030
Tensor()
3131
: pImpl(nullptr){}
32-
explicit Tensor(TensorImpl * self, bool retain = true)
32+
explicit Tensor(TensorImpl * self, bool retain)
3333
: pImpl(self) {
3434
if(pImpl != nullptr && retain)
3535
pImpl->retain();
@@ -67,7 +67,7 @@ struct Tensor {
6767
Tensor().swap(*this);
6868
}
6969
void reset(TensorImpl * rhs) {
70-
Tensor(rhs).swap(*this);
70+
Tensor(rhs,true).swap(*this);
7171
}
7272
void reset(TensorImpl * rhs, bool retain) {
7373
Tensor(rhs, retain).swap(*this );
@@ -126,8 +126,8 @@ struct Tensor {
126126
template<typename T>
127127
T * data() const;
128128

129-
void * unsafeGetTH() {
130-
return pImpl->unsafeGetTH();
129+
void * unsafeGetTH(bool retain) {
130+
return pImpl->unsafeGetTH(retain);
131131
}
132132

133133
//toLongData(), toFloatData() etc.

torch/lib/ATen/templates/TensorDerived.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ int64_t ${Tensor}::dim() {
3636
const char * ${Tensor}::typeString() {
3737
return "${Type}";
3838
}
39-
void * ${Tensor}::unsafeGetTH() {
39+
void * ${Tensor}::unsafeGetTH(bool retain) {
40+
if (retain)
41+
${THTensor}_retain(${state,} tensor);
4042
return tensor;
4143
}
4244

torch/lib/ATen/templates/TensorDerived.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct ${Tensor} : public TensorImpl {
2020
virtual int64_t dim() override;
2121
virtual Scalar localScalar() override;
2222
virtual void assign_(Scalar s) override;
23-
virtual void * unsafeGetTH() override;
23+
virtual void * unsafeGetTH(bool retain) override;
2424
static const char * typeString();
2525

2626
//TODO(zach): sort of friend permissions later so this

torch/lib/ATen/templates/Type.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ struct Type {
109109
virtual std::unique_ptr<Storage> storage(size_t size) = 0;
110110
virtual std::unique_ptr<Storage> storageFromBlob(void * data, int64_t size) = 0;
111111
virtual std::unique_ptr<Generator> generator() = 0;
112-
virtual Tensor unsafeTensorFromTH(void * th_pointer) = 0;
112+
virtual Tensor unsafeTensorFromTH(void * th_pointer, bool retain) = 0;
113113
virtual const char * toString() const = 0;
114114
Type & toBackend(Backend b);
115115
Type & toScalarType(ScalarType s);

torch/lib/ATen/templates/TypeDerived.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ std::unique_ptr<Storage> ${Type}::storageFromBlob(void * data, int64_t size) {
3333
return std::unique_ptr<Storage>(
3434
new ${Storage}(context,data,size));
3535
}
36-
Tensor ${Type}::unsafeTensorFromTH(void * th_pointer) {
37-
return Tensor(new ${Tensor}(context,(${THTensor}*)(th_pointer)));
36+
Tensor ${Type}::unsafeTensorFromTH(void * th_pointer,bool retain) {
37+
if (retain)
38+
${THTensor}_retain(${state,} (${THTensor}*) th_pointer);
39+
return Tensor(new ${Tensor}(context,(${THTensor}*)(th_pointer)),false);
3840
}
3941
std::unique_ptr<Generator> ${Type}::generator() {
4042
return std::unique_ptr<Generator>(new ${Generator}(context));

torch/lib/ATen/templates/TypeDerived.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct ${Type} : public Type {
2020
virtual const char * toString() const override;
2121
virtual TypeID ID() const override;
2222
static const char * typeString();
23-
Tensor unsafeTensorFromTH(void * th_pointer) override;
23+
Tensor unsafeTensorFromTH(void * th_pointer, bool retain) override;
2424

2525
// example
2626
// virtual Tensor * add(Tensor & a, Tensor & b) override;

torch/lib/ATen/test/basic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ static void test(Type & type) {
200200
int a = 4;
201201
THFloatTensor *t = THFloatTensor_newWithSize2d(a, a);
202202
THFloatTensor_fill(t, a);
203-
Tensor tt = CPU(kFloat).unsafeTensorFromTH(t);
203+
Tensor tt = CPU(kFloat).unsafeTensorFromTH(t,false);
204204
std::cout << tt << std::endl;
205205
}
206206
{

0 commit comments

Comments
 (0)