Skip to content

Commit b0648fc

Browse files
committed
Merge commit 'be9ef9283f297997afd3bf8e21147ec6bf09ebbf'
2 parents 9ec7051 + be9ef92 commit b0648fc

File tree

4 files changed

+28
-3
lines changed

4 files changed

+28
-3
lines changed

torch/lib/ATen/function_wrapper.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -391,8 +391,13 @@ def get_arguments(option):
391391
for argument in option['arguments'] if not drop_argument(argument, option)]
392392

393393
def is_actual_return_long(ret):
394-
return ret['type'] == 'long' or (backend_type_env['ScalarName'] == 'Long' and
395-
ret['type'] == 'real' or ret['type'] == 'accreal')
394+
if ret['type'] == 'long':
395+
return True
396+
if ret['type'] == 'real':
397+
return backend_type_env['ScalarName'] == 'Long'
398+
if ret['type'] == 'accreal':
399+
return backend_type_env['AccScalarName'] == 'Long'
400+
return False
396401

397402
def handle_zero_dim(env, option):
398403
if 'zero_dim_dispatch_when_scalar' not in option:

torch/lib/ATen/templates/TensorDerived.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ IntList ${Tensor}::sizes() {
2626
int64_t ${Tensor}::dim() {
2727
if(isScalar())
2828
return 0;
29-
return ${THTensor}_nDimension(${state,}tensor);
29+
int64_t d = ${THTensor}_nDimension(${state,}tensor);
30+
if(d != 0)
31+
return d;
32+
// See Note [Undefined-dim versus 0-dim]
33+
return kUndefinedDimensions;
3034
}
3135

3236
const char * ${Tensor}::typeString() {

torch/lib/ATen/templates/Type.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <memory>
4+
#include <limits>
45

56
#include "ATen/ArrayRef.h"
67
#include "ATen/Half.h"
@@ -45,6 +46,18 @@ constexpr Backend kCUDA = Backend::CUDA;
4546
constexpr Backend kSparseCPU = Backend::SparseCPU;
4647
constexpr Backend kSparseCUDA = Backend::SparseCUDA;
4748

49+
// Note [Undefined-dim versus 0-dim]
50+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
51+
// Unlike Torch, ATen treats zero-dimension tensors as having ONE
52+
// element (that is to say, a zero-dimensional tensor is a scalar!)
53+
// This is in contrast to Torch, where a zero-dimension tensor has
54+
// zero elements.
55+
//
56+
// Because we are backed by Torch tensors, we need to be able to
57+
// represent this state (of numel==0). kUndefinedDimensions represents this
58+
// situation.
59+
constexpr int64_t kUndefinedDimensions = std::numeric_limits<int64_t>::min();
60+
4861
static inline const char * toString(Backend b) {
4962
switch(b) {
5063
case Backend::CPU: return "CPU";

torch/lib/ATen/test/basic.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ static void test(Type & type) {
208208
Tensor b = CPU(kFloat).ones({3,7});
209209
Tensor c = cat({a,b},1);
210210
std::cout << c << std::endl;
211+
212+
Tensor e = CPU(kFloat).rand({});
213+
check(*e.data<float>()== e.sum().toFloat());
211214
}
212215

213216
}

0 commit comments

Comments
 (0)