Skip to content

Commit fa0451a

Browse files
committed
Error message improvements.
* Add C file/line to THError and THArgCheck. Add tensor sizeStr() method. * Add THAssert THAssertMsg functions. * Add sizeDesc() and desc() functions, which return tensor descriptor strings for use in error messages. To avoid memory leaks, we have to pass these strings on the stack within structs, so it's a bit unwieldly. * Improved a bunch of error messages. e.g. ``` th> torch.mm(torch.Tensor(4,4), torch.Tensor(5,5)) size mismatch, m1: [4 x 4], m2: [5 x 5] at /home/alerer/git/torch7-2/lib/TH/generic/THTensorMath.c:511 stack traceback: ... th> torch.Tensor(1):size(4) [...]:1: bad argument pytorch#1 to 'size' (dimension 4 out of range of 1D tensor at /home/alerer/git/torch7-2/generic/Tensor.c:16) stack traceback: ... ```
1 parent 186df5a commit fa0451a

File tree

6 files changed

+208
-61
lines changed

6 files changed

+208
-61
lines changed

THGeneral.c

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,33 @@ static void defaultTorchErrorHandlerFunction(const char *msg, void *data)
1313
static __thread void (*torchErrorHandlerFunction)(const char *msg, void *data) = defaultTorchErrorHandlerFunction;
1414
static __thread void *torchErrorHandlerData;
1515

16-
void THError(const char *fmt, ...)
16+
void _THError(const char *file, const int line, const char *fmt, ...)
1717
{
18-
char msg[1024];
18+
char msg[2048];
1919
va_list args;
2020

2121
/* vasprintf not standard */
2222
/* vsnprintf: how to handle if does not exists? */
2323
va_start(args, fmt);
24-
vsnprintf(msg, 1024, fmt, args);
24+
int n = vsnprintf(msg, 2048, fmt, args);
2525
va_end(args);
2626

27+
if(n < 2048) {
28+
snprintf(msg + n, 2048 - n, " at %s:%d", file, line);
29+
}
30+
2731
(*torchErrorHandlerFunction)(msg, torchErrorHandlerData);
2832
}
2933

34+
void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...) {
35+
char msg[1024];
36+
va_list args;
37+
va_start(args, fmt);
38+
vsnprintf(msg, 1024, fmt, args);
39+
va_end(args);
40+
_THError(file, line, "Assertion `%s' failed. %s", exp, msg);
41+
}
42+
3043
void THSetErrorHandler( void (*torchErrorHandlerFunction_)(const char *msg, void *data), void *data )
3144
{
3245
if(torchErrorHandlerFunction_)
@@ -49,10 +62,24 @@ static void defaultTorchArgErrorHandlerFunction(int argNumber, const char *msg,
4962
static __thread void (*torchArgErrorHandlerFunction)(int argNumber, const char *msg, void *data) = defaultTorchArgErrorHandlerFunction;
5063
static __thread void *torchArgErrorHandlerData;
5164

52-
void THArgCheck(int condition, int argNumber, const char *msg)
65+
void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...)
5366
{
54-
if(!condition)
67+
if(!condition) {
68+
char msg[2048];
69+
va_list args;
70+
71+
/* vasprintf not standard */
72+
/* vsnprintf: how to handle if does not exists? */
73+
va_start(args, fmt);
74+
int n = vsnprintf(msg, 2048, fmt, args);
75+
va_end(args);
76+
77+
if(n < 2048) {
78+
snprintf(msg + n, 2048 - n, " at %s:%d", file, line);
79+
}
80+
5581
(*torchArgErrorHandlerFunction)(argNumber, msg, torchArgErrorHandlerData);
82+
}
5683
}
5784

5885
void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction_)(int argNumber, const char *msg, void *data), void *data )

THGeneral.h.in

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,30 @@
4444
#endif
4545

4646
TH_API double THLog1p(const double x);
47-
TH_API void THError(const char *fmt, ...);
47+
TH_API void _THError(const char *file, const int line, const char *fmt, ...);
48+
TH_API void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...);
4849
TH_API void THSetErrorHandler( void (*torchErrorHandlerFunction)(const char *msg, void *data), void *data );
49-
TH_API void THArgCheck(int condition, int argNumber, const char *msg);
50+
TH_API void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...);
5051
TH_API void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction)(int argNumber, const char *msg, void *data), void *data );
5152
TH_API void* THAlloc(long size);
5253
TH_API void* THRealloc(void *ptr, long size);
5354
TH_API void THFree(void *ptr);
5455

56+
#define THError(...) _THError(__FILE__, __LINE__, __VA_ARGS__)
57+
#define THArgCheck(...) _THArgCheck(__FILE__, __LINE__, __VA_ARGS__)
58+
#define THAssert(exp) \
59+
do { \
60+
if (!(exp)) { \
61+
_THAssertionFailed(__FILE__, __LINE__, #exp, ""); \
62+
} \
63+
} while(0)
64+
#define THAssertMsg(exp, ...) \
65+
do { \
66+
if (!(exp)) { \
67+
_THAssertionFailed(__FILE__, __LINE__, #exp, __VA_ARGS__); \
68+
} \
69+
} while(0)
70+
5571
#define TH_CONCAT_STRING_2(x,y) TH_CONCAT_STRING_2_EXPAND(x,y)
5672
#define TH_CONCAT_STRING_2_EXPAND(x,y) #x #y
5773

THTensor.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
#define THTensor TH_CONCAT_3(TH,Real,Tensor)
88
#define THTensor_(NAME) TH_CONCAT_4(TH,Real,Tensor_,NAME)
99

10+
#define TH_DESC_BUFF_LEN 64
11+
typedef struct {
12+
char str[TH_DESC_BUFF_LEN];
13+
} THDescBuff;
14+
1015
/* basics */
1116
#include "generic/THTensor.h"
1217
#include "THGenerateAllTypes.h"

generic/THTensor.c

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ int THTensor_(nDimension)(const THTensor *self)
2020

2121
long THTensor_(size)(const THTensor *self, int dim)
2222
{
23-
THArgCheck((dim >= 0) && (dim < self->nDimension), 2, "out of range");
23+
THArgCheck((dim >= 0) && (dim < self->nDimension), 2, "dimension %d out of range of %dD tensor",
24+
dim+1, THTensor_(nDimension)(self));
2425
return self->size[dim];
2526
}
2627

2728
long THTensor_(stride)(const THTensor *self, int dim)
2829
{
29-
THArgCheck((dim >= 0) && (dim < self->nDimension), 2, "out of range");
30+
THArgCheck((dim >= 0) && (dim < self->nDimension), 2, "dimension %d out of range of %dD tensor", dim+1,
31+
THTensor_(nDimension)(self));
3032
return self->stride[dim];
3133
}
3234

@@ -751,4 +753,48 @@ real THTensor_(get4d)(const THTensor *tensor, long x0, long x1, long x2, long x3
751753
return THStorage_(get)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2]+x3*tensor->stride[3]);
752754
}
753755

756+
THDescBuff THTensor_(desc)(const THTensor *tensor) {
757+
const int L = TH_DESC_BUFF_LEN;
758+
THDescBuff buf;
759+
char *str = buf.str;
760+
int n = 0;
761+
#define _stringify(x) #x
762+
n += snprintf(str, L-n, "torch." _stringify(x) "Tensor of size ");
763+
#undef _stringify
764+
int i;
765+
for(i = 0; i < tensor->nDimension; i++) {
766+
if(n >= L) break;
767+
n += snprintf(str+n, L-n, "%ld", tensor->size[i]);
768+
if(i < tensor->nDimension-1) {
769+
n += snprintf(str+n, L-n, "x");
770+
}
771+
}
772+
if(n >= L) {
773+
snprintf(str+L-4+n, 4, "...");
774+
}
775+
return buf;
776+
}
777+
778+
THDescBuff THTensor_(sizeDesc)(const THTensor *tensor) {
779+
const int L = TH_DESC_BUFF_LEN;
780+
THDescBuff buf;
781+
char *str = buf.str;
782+
int n = 0;
783+
n += snprintf(str, L-n, "[");
784+
int i;
785+
for(i = 0; i < tensor->nDimension; i++) {
786+
if(n >= L) break;
787+
n += snprintf(str+n, L-n, "%ld", tensor->size[i]);
788+
if(i < tensor->nDimension-1) {
789+
n += snprintf(str+n, L-n, " x ");
790+
}
791+
}
792+
if(n < L - 2) {
793+
snprintf(str+n, L-n, "]");
794+
} else {
795+
snprintf(str+L-5, 5, "...]");
796+
}
797+
return buf;
798+
}
799+
754800
#endif

generic/THTensor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,8 @@ TH_API real THTensor_(get2d)(const THTensor *tensor, long x0, long x1);
122122
TH_API real THTensor_(get3d)(const THTensor *tensor, long x0, long x1, long x2);
123123
TH_API real THTensor_(get4d)(const THTensor *tensor, long x0, long x1, long x2, long x3);
124124

125+
/* Debug methods */
126+
TH_API THDescBuff THTensor_(desc)(const THTensor *tensor);
127+
TH_API THDescBuff THTensor_(sizeDesc)(const THTensor *tensor);
128+
125129
#endif

0 commit comments

Comments
 (0)