Skip to content

Commit cc2d5ca

Browse files
rohan-varmafacebook-github-bot
authored andcommitted
add enabled API to autograd profiler (pytorch#31380)
Summary: Pull Request resolved: pytorch#31380 For being able to profile async RPCs, we attach a `RecordFunction` object to the future that is created during the RPC to persist it across the lifetime of the RPC (this is implemented in the next PR: ). Since we'd only like to do this when profiling is enabled, this PR adds an enabled API to the autograd profiler. ghstack-source-id: 96053933 Test Plan: Modified unit test. Differential Revision: D19050391 fbshipit-source-id: aa382110e69d06b4a84c83b31d2bec2d8a81ba10
1 parent 7d63027 commit cc2d5ca

File tree

4 files changed

+9
-0
lines changed

4 files changed

+9
-0
lines changed

test/test_autograd.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2523,8 +2523,11 @@ def test_profiler(self):
25232523
x = torch.randn(10, 10)
25242524

25252525
with profile() as p:
2526+
self.assertTrue(torch.autograd._profiler_enabled())
25262527
y = x * 2 + 4
25272528

2529+
self.assertFalse(torch.autograd._profiler_enabled())
2530+
25282531
last_end = 0
25292532
names = ['mul', 'add']
25302533
self.assertEqual(len(p.function_events), len(names))

torch/csrc/autograd/init.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
5050

5151
m.def("_enable_profiler", enableProfiler);
5252
m.def("_disable_profiler", disableProfiler);
53+
m.def("_profiler_enabled", profilerEnabled);
5354

5455
m.def("_push_range", [](std::string name) { pushRange(std::move(name)); });
5556
m.def("_pop_range", []() { popRange(); });

torch/csrc/autograd/profiler.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ void mark(std::string name, bool include_cuda /* = true */) {
5454
}
5555
}
5656

57+
bool profilerEnabled() {
58+
return state != ProfilerState::Disabled;
59+
}
60+
5761
void pushRangeImpl(
5862
const StringView& name,
5963
const char* msg = "",

torch/csrc/autograd/profiler.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ using thread_event_lists = std::vector<std::vector<Event>>;
230230
// there no autograd functions are being executed when these function are used.
231231
TORCH_API void enableProfiler(ProfilerConfig);
232232
TORCH_API thread_event_lists disableProfiler();
233+
TORCH_API bool profilerEnabled();
233234

234235

235236
// Usage:

0 commit comments

Comments
 (0)