Skip to content

Commit 4c1cdb6

Browse files
colesburyapaszke
authored andcommitted
Refactor Python string utility function
1 parent 775481e commit 4c1cdb6

File tree

9 files changed

+103
-117
lines changed

9 files changed

+103
-117
lines changed

torch/csrc/Module.cpp

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <libshm.h>
88
#include <TH/TH.h>
99

10+
#include "torch/csrc/utils/python_strings.h"
11+
1012
#ifdef WITH_CUDNN
1113
#include "cudnn/Module.h"
1214
#endif
@@ -73,13 +75,9 @@ static PyObject * THPModule_initNames(PyObject *self, PyObject *arg)
7375

7476
THPObjectPtr module_name = PyObject_GetAttrString(obj, "__module__");
7577
if (!module_name) return NULL;
76-
#if PY_MAJOR_VERSION == 2
77-
THPUtils_assert(PyString_Check(module_name.get()), "expected __module__ to be a string");
78-
std::string name = PyString_AS_STRING(module_name.get());
79-
#else
80-
THPUtils_assert(PyUnicode_Check(module_name.get()), "expected __module__ to be a string");
81-
std::string name = PyUnicode_AsUTF8(module_name.get());
82-
#endif
78+
THPUtils_assert(THPUtils_checkString(module_name.get()),
79+
"expected __module__ to be a string");
80+
std::string name = THPUtils_unpackString(module_name.get());
8381
names.push_back(name + "." + type->tp_name);
8482
type->tp_name = names.back().c_str();
8583
}
@@ -89,15 +87,13 @@ static PyObject * THPModule_initNames(PyObject *self, PyObject *arg)
8987
static bool THPModule_assignStateless(PyObject *self)
9088
{
9189
#define INIT_STATELESS(type) \
92-
stateless = PyObject_Call((PyObject*)&TH_CONCAT_2(type, TensorStatelessType), arg, NULL); \
90+
stateless = PyObject_CallFunctionObjArgs((PyObject*)&TH_CONCAT_2(type, TensorStatelessType), NULL); \
9391
if (!stateless) { \
94-
THPUtils_setError("stateless method initialization error"); \
9592
return false; \
9693
} \
9794
if (PyObject_SetAttrString(TH_CONCAT_3(THP,type,TensorClass), THP_STATELESS_ATTRIBUTE_NAME, stateless) == -1) { \
98-
THPUtils_setError("stateless method initialization error (on assignment)");\
95+
return false; \
9996
}
100-
PyObject *arg = PyTuple_New(0);
10197
PyObject *stateless;
10298
INIT_STATELESS(Double);
10399
INIT_STATELESS(Float);
@@ -107,23 +103,25 @@ static bool THPModule_assignStateless(PyObject *self)
107103
INIT_STATELESS(Short);
108104
INIT_STATELESS(Char);
109105
INIT_STATELESS(Byte);
110-
Py_DECREF(arg);
111106
return true;
112107
#undef INIT_STATELESS
113108
}
114109
//
115110
// Callback for python part. Used for additional initialization of python classes
116111
static PyObject * THPModule_initExtension(PyObject *self, PyObject *shm_manager_path)
117112
{
118-
if (!THPUtils_checkBytes(shm_manager_path)) {
113+
HANDLE_TH_ERRORS
114+
if (!THPUtils_checkString(shm_manager_path)) {
119115
THPUtils_setError("initialization error - expected bytes/string object as shm_manager_path!");
120116
return NULL;
121117
}
122-
libshm_init(THPUtils_bytesAsString(shm_manager_path));
118+
std::string path = THPUtils_unpackString(shm_manager_path);
119+
libshm_init(path.c_str());
123120
if (!THPModule_loadClasses(self)) return NULL;
124121
if (!THPModule_assignStateless(self)) return NULL;
125122
if (!THPAutograd_initFunctions(self)) return NULL;
126-
return PyBool_FromLong(true);
123+
Py_RETURN_NONE;
124+
END_HANDLE_TH_ERRORS
127125
}
128126

129127
static PyObject * THPModule_getNumThreads(PyObject *module)
@@ -429,22 +427,6 @@ PyObject *THPModule_safeCall(PyObject *_unused, PyObject *args, PyObject *kwargs
429427
return result;
430428
}
431429

432-
static std::string parseString(PyObject *obj)
433-
{
434-
if (PyBytes_Check(obj)) {
435-
return std::string(PyBytes_AS_STRING(obj));
436-
#if PY_MAJOR_VERSION == 3
437-
} else if (PyUnicode_Check(obj)) {
438-
return std::string(PyUnicode_AsUTF8(obj));
439-
#else
440-
} else if (PyUnicode_Check(obj)) {
441-
THPObjectPtr utf8 = PyUnicode_AsUTF8String(obj);
442-
return std::string(PyBytes_AS_STRING(utf8.get()));
443-
#endif
444-
}
445-
return "<invalid string>";
446-
}
447-
448430
PyObject *THPModule_addDocStr(PyObject *_unused, PyObject *args)
449431
{
450432
// adds a __doc__ string to a function, similar to numpy's arr_add_docstring
@@ -455,8 +437,11 @@ PyObject *THPModule_addDocStr(PyObject *_unused, PyObject *args)
455437
return NULL;
456438
}
457439

458-
all_docs.push_back(parseString(doc_obj));
459-
const char* doc_str = all_docs.back().c_str();
440+
const char* doc_str = "<invalid string>";
441+
if (THPUtils_checkString(doc_obj)) {
442+
all_docs.push_back(THPUtils_unpackString(doc_obj));
443+
doc_str = all_docs.back().c_str();
444+
}
460445

461446
if (Py_TYPE(obj) == &PyCFunction_Type) {
462447
PyCFunctionObject* f = (PyCFunctionObject *)obj;
@@ -499,7 +484,6 @@ extern PyObject * THCPModule_seedAll(PyObject *_unused);
499484
extern PyObject * THCPModule_initialSeed(PyObject *_unused);
500485
extern PyObject * THCPModule_cudaHostAllocator(PyObject *_unused);
501486
extern PyObject * THCPModule_cudaSynchronize(PyObject *_unused);
502-
extern PyObject * THCPModule_getLibPath(PyObject *_unused);
503487
extern PyObject * THCPModule_cudaSleep(PyObject *_unused, PyObject *cycles);
504488
extern PyObject * THCPModule_cudaLockMutex(PyObject *module);
505489
extern PyObject * THCPModule_cudaUnlockMutex(PyObject *module);
@@ -532,7 +516,6 @@ static PyMethodDef TorchMethods[] = {
532516
{"_cuda_initialSeed", (PyCFunction)THCPModule_initialSeed, METH_NOARGS, NULL},
533517
{"_cuda_cudaHostAllocator", (PyCFunction)THCPModule_cudaHostAllocator, METH_NOARGS, NULL},
534518
{"_cuda_synchronize", (PyCFunction)THCPModule_cudaSynchronize, METH_NOARGS, NULL},
535-
{"_cuda_getLibPath", (PyCFunction)THCPModule_getLibPath, METH_NOARGS, NULL},
536519
{"_cuda_sleep", (PyCFunction)THCPModule_cudaSleep, METH_O, NULL},
537520
{"_cuda_sparse_init", (PyCFunction)THCSPModule_initExtension, METH_NOARGS, NULL},
538521
{"_cuda_lock_mutex", (PyCFunction)THCPModule_cudaLockMutex, METH_NOARGS, NULL},

torch/csrc/Size.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "Size.h"
22

33
#include <string>
4+
#include "torch/csrc/utils/python_strings.h"
45
#include "THP.h"
56

67
PyObject* THPSizeClass = NULL;
@@ -47,11 +48,7 @@ static PyObject * THPSize_repr(THPSize *self)
4748
repr += std::to_string(PyLong_AsLong(PyTuple_GET_ITEM(self, i)));
4849
}
4950
repr += "])";
50-
#if PY_MAJOR_VERSION == 2
51-
return PyString_FromString(repr.c_str());
52-
#else
53-
return PyUnicode_FromString(repr.c_str());
54-
#endif
51+
return THPUtils_packString(repr);
5552
}
5653

5754
extern PyTypeObject THPSizeType;

torch/csrc/autograd/python_hook.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "torch/csrc/autograd/python_variable.h"
77
#include "torch/csrc/utils/auto_gil.h"
88
#include "torch/csrc/utils/object_ptr.h"
9+
#include "torch/csrc/utils/python_strings.h"
910
#include "torch/csrc/Exceptions.h"
1011
#include <THPP/THPP.h>
1112

@@ -184,15 +185,8 @@ static void check_single_result(PyObject* _original, PyObject* _result, PyObject
184185

185186
static std::string hook_name(PyObject* hook) {
186187
THPObjectPtr name = PyObject_GetAttrString(hook, "__name__");
187-
#if PY_MAJOR_VERSION == 2
188-
if (name && PyString_Check(name.get())) {
189-
return std::string(PyString_AS_STRING(name.get()));
188+
if (name && THPUtils_checkString(name.get())) {
189+
return THPUtils_unpackString(name.get());
190190
}
191-
#else
192-
if (name && PyUnicode_Check(name.get())) {
193-
THPObjectPtr tmp = PyUnicode_AsASCIIString(name.get());
194-
return std::string(PyBytes_AS_STRING(tmp.get()));
195-
}
196-
#endif
197191
return "<unknown>";
198192
}

torch/csrc/cuda/Module.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -259,19 +259,6 @@ PyObject * THCPModule_cudaUnlockMutex(PyObject *module)
259259
Py_RETURN_NONE;
260260
}
261261

262-
PyObject * THCPModule_getLibPath(PyObject *_unused)
263-
{
264-
#define _STR(x) #x
265-
#define STR(x) _STR(x)
266-
#if PY_MAJOR_VERSION == 2
267-
return PyString_FromString(STR(CUDA_LIB_PATH));
268-
#else
269-
return PyUnicode_FromString(STR(CUDA_LIB_PATH));
270-
#endif
271-
#undef STR
272-
#undef _STR
273-
}
274-
275262
////////////////////////////////////////////////////////////////////////////////
276263
// Cuda module initialization
277264
////////////////////////////////////////////////////////////////////////////////

torch/csrc/distributed/Module.cpp

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -42,43 +42,27 @@ static bool THDPModule_loadClasses(PyObject *module_dict)
4242
static std::unordered_map<PyObject*, THDReduceOp> obj2reduceop;
4343
static std::unordered_map<PyObject*, THDGroup> obj2group;
4444

45-
static THPObjectPtr _ensureBytes(PyObject *obj)
46-
{
47-
#if PY_MAJOR_VERSION == 2
48-
if (PyString_Check(obj)) {
49-
#elif PY_MAJOR_VERSION == 3
50-
if (PyBytes_Check(obj)) {
51-
#endif
52-
Py_INCREF(obj);
53-
return obj;
54-
}
55-
if (PyUnicode_Check(obj)) {
56-
return PyUnicode_AsASCIIString(obj);
57-
}
58-
return NULL;
59-
}
60-
61-
PyObject* THDPModule_initProcessGroup(PyObject *_unused, PyObject *_backend)
45+
PyObject* THDPModule_initProcessGroup(PyObject *_unused, PyObject *backend)
6246
{
6347
HANDLE_TH_ERRORS
64-
THPObjectPtr backend_bytes = _ensureBytes(_backend);
65-
THPUtils_assert(backend_bytes, "backend argument has to be a string/bytes "
66-
"object, but got %s", THPUtils_typename(_backend));
67-
char *backend_name = THPUtils_bytesAsString(backend_bytes.get());
48+
THPUtils_assert(THPUtils_checkString(backend),
49+
"backend argument has to be a string/bytes object, but got %s",
50+
THPUtils_typename(backend));
51+
std::string backend_name = THPUtils_unpackString(backend);
6852
THDChannelType channel_type = name2channel_type.at(backend_name);
6953
THPUtils_assert(THDProcessGroupInit(channel_type), "failed to initialize "
7054
"distributed library (THD)");
7155
Py_RETURN_NONE;
7256
END_HANDLE_TH_ERRORS
7357
}
7458

75-
PyObject* THDPModule_initMasterWorker(PyObject *_unused, PyObject *_backend)
59+
PyObject* THDPModule_initMasterWorker(PyObject *_unused, PyObject *backend)
7660
{
7761
HANDLE_TH_ERRORS
78-
THPObjectPtr backend_bytes = _ensureBytes(_backend);
79-
THPUtils_assert(backend_bytes, "backend argument has to be a string/bytes "
80-
"object, but got %s", THPUtils_typename(_backend));
81-
char *backend_name = THPUtils_bytesAsString(backend_bytes.get());
62+
THPUtils_assert(THPUtils_checkString(backend),
63+
"backend argument has to be a string/bytes object, but got %s",
64+
THPUtils_typename(backend));
65+
std::string backend_name = THPUtils_unpackString(backend);
8266
THDChannelType channel_type = name2channel_type.at(backend_name);
8367
THPUtils_assert(THDMasterWorkerInit(channel_type), "failed to initialize "
8468
"distributed library (THD)");

torch/csrc/generic/StorageSharing.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,10 @@ static PyObject * THPStorage_(shareFilename)(THPStorage *self)
104104
ctx = (libshm_context*)storage->allocatorContext;
105105
}
106106

107-
THPObjectPtr manager_handle = THPUtils_bytesFromString(ctx->manager_handle);
107+
THPObjectPtr manager_handle = PyBytes_FromString(ctx->manager_handle);
108108
if (!manager_handle) return NULL;
109109
THPObjectPtr storage_handle =
110-
THPUtils_bytesFromString(THMapAllocatorContext_filename(ctx->th_context));
110+
PyBytes_FromString(THMapAllocatorContext_filename(ctx->th_context));
111111
if (!storage_handle) return NULL;
112112
THPObjectPtr size = PyLong_FromLong(storage->size);
113113
if (!size) return NULL;
@@ -124,20 +124,21 @@ static PyObject * THPStorage_(shareFilename)(THPStorage *self)
124124
static PyObject * THPStorage_(newSharedFilename)(PyObject *_unused, PyObject *args)
125125
{
126126
HANDLE_TH_ERRORS
127+
THPUtils_assert(PyTuple_GET_SIZE(args) == 3, "tuple of 3 items expected");
127128
PyObject *_manager_handle = PyTuple_GET_ITEM(args, 0);
128129
PyObject *_object_handle = PyTuple_GET_ITEM(args, 1);
129130
PyObject *_size = PyTuple_GET_ITEM(args, 2);
130-
if (!THPUtils_checkBytes(_manager_handle) || !THPUtils_checkBytes(_object_handle) || !THPUtils_checkLong(_size)) {
131+
if (!PyBytes_Check(_manager_handle) || !PyBytes_Check(_object_handle) || !THPUtils_checkLong(_size)) {
131132
THPUtils_invalidArguments(args, NULL, "_new_shared in file system mode", 1,
132133
"a handle (string/bytes) and storage size (int)");
133134
return NULL;
134135
}
135-
const char *manager_handle = THPUtils_bytesAsString(_manager_handle);
136-
const char *object_handle = THPUtils_bytesAsString(_object_handle);
136+
const char *manager_handle = PyBytes_AS_STRING(_manager_handle);
137+
const char *object_handle = PyBytes_AS_STRING(_object_handle);
137138
long size = THPUtils_unpackLong(_size);
138-
139-
libshm_context *ctx = libshm_context_new(manager_handle, object_handle,
140-
TH_ALLOCATOR_MAPPED_SHAREDMEM | TH_ALLOCATOR_MAPPED_NOCREATE);
139+
int flags = TH_ALLOCATOR_MAPPED_SHAREDMEM |
140+
TH_ALLOCATOR_MAPPED_NOCREATE;
141+
libshm_context *ctx = libshm_context_new(manager_handle, object_handle, flags);
141142
return THPStorage_(New)(THStorage_(newWithAllocator)(size,
142143
&THManagedSharedAllocator, (void*)ctx));
143144
END_HANDLE_TH_ERRORS
@@ -199,6 +200,7 @@ static PyObject * THPStorage_(shareFd)(THPStorage *self)
199200
static PyObject * THPStorage_(newSharedFd)(PyObject *_unused, PyObject *args)
200201
{
201202
HANDLE_TH_ERRORS
203+
THPUtils_assert(PyTuple_GET_SIZE(args) == 2, "tuple of 2 items expected");
202204
PyObject *_tmp_fd = PyTuple_GET_ITEM(args, 0);
203205
PyObject *_size = PyTuple_GET_ITEM(args, 1);
204206
if (!THPUtils_checkLong(_tmp_fd) || !THPUtils_checkLong(_size)) {

torch/csrc/utils.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <algorithm>
77
#include <unordered_map>
88
#include "THP.h"
9+
#include "torch/csrc/utils/python_strings.h"
910

1011
#include "generic/utils.cpp"
1112
#include <TH/THGenerateAllTypes.h>
@@ -457,15 +458,6 @@ std::vector<std::string> _tryMatchKwargs(const Option& option,
457458
return unmatched;
458459
}
459460

460-
std::string _parseDictKey(PyObject *key_str) {
461-
#if PY_MAJOR_VERSION == 3
462-
THPObjectPtr ascii = PyUnicode_AsASCIIString(key_str);
463-
return std::string(PyBytes_AS_STRING(ascii.get()));
464-
#else
465-
return std::string(PyString_AS_STRING(key_str));
466-
#endif
467-
}
468-
469461
void THPUtils_invalidArguments(PyObject *given_args, PyObject *given_kwargs,
470462
const char *function_name, size_t num_options, ...) {
471463
std::vector<std::string> option_strings;
@@ -493,7 +485,7 @@ void THPUtils_invalidArguments(PyObject *given_args, PyObject *given_kwargs,
493485
Py_ssize_t pos = 0;
494486

495487
while (PyDict_Next(given_kwargs, &pos, &key, &value)) {
496-
kwargs.emplace(_parseDictKey(key), value);
488+
kwargs.emplace(THPUtils_unpackString(key), value);
497489
}
498490
}
499491

torch/csrc/utils.h

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,6 @@
2525
(throw std::runtime_error("Could not unpack long"), 0))
2626
#endif
2727

28-
29-
#if PY_MAJOR_VERSION == 2
30-
#define THPUtils_bytesFromString(c_string) PyString_FromString(c_string)
31-
#define THPUtils_checkBytes(obj) PyString_Check(obj)
32-
#define THPUtils_bytesAsString(obj) PyString_AS_STRING(obj)
33-
#else
34-
#define THPUtils_bytesFromString(c_string) PyBytes_FromString(c_string)
35-
#define THPUtils_checkBytes(obj) PyBytes_Check(obj)
36-
#define THPUtils_bytesAsString(obj) PyBytes_AS_STRING(obj)
37-
#endif
38-
3928
#if PY_MAJOR_VERSION == 2
4029
#define THPUtils_checkReal_FLOAT(object) \
4130
(PyFloat_Check(object) || PyLong_Check(object) || PyInt_Check(object))

0 commit comments

Comments
 (0)