Skip to content

Commit 33ac9cd

Browse files
killeentsoumith
authored andcommitted
add ATen tensor support to pytorch tuple_parser (pytorch#2102)
1 parent 52a9367 commit 33ac9cd

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

torch/csrc/utils/tuple_parser.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ auto TupleParser::parse(std::shared_ptr<thpp::Tensor>& x, const std::string& par
5252
x.reset(torch::createTensor(obj)->clone_shallow());
5353
}
5454

55+
auto TupleParser::parse(at::Tensor& x, const std::string& param_name) -> void {
56+
PyObject* obj = next_arg();
57+
x = torch::createTensorAT(obj);
58+
}
59+
5560
auto TupleParser::parse(std::vector<int>& x, const std::string& param_name) -> void {
5661
PyObject* obj = next_arg();
5762
if (!PyTuple_Check(obj)) {

torch/csrc/utils/tuple_parser.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <memory>
55
#include <vector>
66
#include <THPP/THPP.h>
7+
#include <ATen/ATen.h>
78

89
namespace torch {
910

@@ -15,6 +16,7 @@ struct TupleParser {
1516
void parse(double& x, const std::string& param_name);
1617
void parse(std::unique_ptr<thpp::Tensor>& x, const std::string& param_name);
1718
void parse(std::shared_ptr<thpp::Tensor>& x, const std::string& param_name);
19+
void parse(at::Tensor& x, const std::string& param_name);
1820
void parse(std::vector<int>& x, const std::string& param_name);
1921
void parse(std::string& x, const std::string& param_name);
2022

0 commit comments

Comments
 (0)