Skip to content

Commit 7f1e331

Browse files
suopytorchmergebot
authored andcommitted
Make SymInt constructor explicit
Since we plan to have a bunch of code that is sensitive to whether or not a SymInt contains a symbolic shape or not, it seems like a bad idea to have an implicit constructor. For example, code like: ``` sizes_and_strides_.stride_at_unchecked(dim) = 0; ``` would sail through, and the `0` would get implicitly promoted to a SymInt. This is a tradeoff though: it makes code that handles `SymInt`s more clunky as `int64_t`s and integer literals need to be explicitly wrapped in `SymInt` before being used. Pull Request resolved: pytorch#77666 Approved by: https://github.com/ezyang
1 parent c673696 commit 7f1e331

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

c10/core/SymInt.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class SymbolicIntNode;
3131
// a traced operation to represent it in LTC or Fx graphs.
3232
class C10_API SymInt {
3333
public:
34-
SymInt(int64_t d) : data_(d){};
34+
explicit SymInt(int64_t d) : data_(d){};
3535

3636
int64_t expect_int() const {
3737
TORCH_CHECK(!is_symbolic());
@@ -51,7 +51,7 @@ class C10_API SymInt {
5151
TORCH_CHECK(
5252
!this->is_symbolic() && !sci.is_symbolic(),
5353
"Symbolic Add isn't supported yet");
54-
return data_ + sci.data_;
54+
return SymInt(data_ + sci.data_);
5555
}
5656

5757
std::shared_ptr<SymbolicIntNode> toSymbolicIntNode();

test/cpp/jit/test_misc.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1386,7 +1386,7 @@ TEST(ThreadLocalDebugInfoTest, Basic) {
13861386
TEST(TestSymIntArrayRef, BasicConversion) {
13871387
const size_t X = 2, Y = 4, Z = 5;
13881388
std::vector<int64_t> tgt_size_v{2, 4, 5};
1389-
std::vector<c10::SymInt> tgt_size({X, Y, Z});
1389+
std::vector<c10::SymInt> tgt_size({SymInt(X), SymInt(Y), SymInt(Z)});
13901390
auto a = at::randn({1, 4, 1}, at::kCPU);
13911391
auto b = a.expand(tgt_size);
13921392
auto c = a.expand(tgt_size_v);

torch/csrc/utils/python_arg_parser.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,9 @@ inline int64_t PythonArgs::toInt64(int i) {
637637
}
638638

639639
inline c10::SymInt PythonArgs::toSymInt(int i) {
640-
if (!args[i]) return signature.params[i].default_int;
640+
if (!args[i]) {
641+
return c10::SymInt(signature.params[i].default_int);
642+
}
641643
if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
642644
auto & var = THPVariable_Unpack(args[i]);
643645
jit::tracer::ArgumentStash::stashValue(

0 commit comments

Comments
 (0)