Skip to content

Commit 5375cea

Browse files
Krovatkin authored and facebook-github-bot committed
run optimizations on pre-profiled graph (pytorch#31392)
Summary: This is the first stab at running profile-insensitive optimizations on pre-profiled graphs. Running those optimizations has the potential to simplify graphs greatly before GuardElimination, so GuardElimination should be able to remove more guards. Pull Request resolved: pytorch#31392. Differential Revision: D19173639. Pulled By: Krovatkin. fbshipit-source-id: 2485a2a598c10f9b5445efb30b16439ad4551b3f
1 parent 256db1e commit 5375cea

File tree

3 files changed

+92
-86
lines changed

3 files changed

+92
-86
lines changed

test/test_jit.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7421,31 +7421,24 @@ def func():
74217421
self.assertEqual(t1.device, t2.device)
74227422

74237423

7424-
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't have any shapes to propagate")
7424+
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple Executor doesn't have any shapes to propagate")
74257425
def test_tensor_as_tensor_shape_prop(self):
74267426
tensor_template = dedent('''
74277427
def func():
74287428
return torch.{tensor_op}({input})
74297429
''')
74307430
ops = ['tensor', 'as_tensor']
74317431
inputs = ['[1]', '[False]', '[2.5]', '0.5', '1', 'False', '[[1]]']
7432-
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
7433-
expected_shape = ["Long(1)", "Bool(1)", "Double(1)", "Double()", "Long()", "Bool()", "Long(1, 1)"]
7434-
else:
7435-
expected_shape = ["Long(*)", ("Bool(*)"), "Double(*)", "Double()", "Long()", "Bool()", "Long(*, *)"]
7432+
expected_shape = ["Long(*)", ("Bool(*)"), "Double(*)", "Double()", "Long()", "Bool()", "Long(*, *)"]
74367433

74377434
for op in ops:
74387435
for inp, expect in zip(inputs, expected_shape):
74397436
code = tensor_template.format(tensor_op=op, input=inp)
74407437
scope = {}
74417438
exec(code, globals(), scope)
7442-
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
7443-
fn = self.checkScript(code, ())
7444-
FileCheck().check(expect).check("aten::{tensor_op}".format(tensor_op=op)).run(fn.graph_for())
7445-
else:
7446-
cu = torch.jit.CompilationUnit(code)
7447-
torch._C._jit_pass_complete_shape_analysis(cu.func.graph, (), False)
7448-
FileCheck().check(expect).check("aten::{tensor_op}".format(tensor_op=op)).run(cu.func.graph)
7439+
cu = torch.jit.CompilationUnit(code)
7440+
torch._C._jit_pass_complete_shape_analysis(cu.func.graph, (), False)
7441+
FileCheck().check(expect).check("aten::{tensor_op}".format(tensor_op=op)).run(cu.func.graph)
74497442

74507443
@torch.jit.script
74517444
def test_dtype(inp_dtype):

torch/csrc/jit/profiling_graph_executor_impl.cpp

Lines changed: 85 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,18 @@
22
#include <torch/csrc/jit/passes/bailout_graph.h>
33
#include <torch/csrc/jit/passes/canonicalize_ops.h>
44
#include <torch/csrc/jit/passes/clear_undefinedness.h>
5+
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
6+
#include <torch/csrc/jit/passes/constant_pooling.h>
57
#include <torch/csrc/jit/passes/constant_propagation.h>
68
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
79
#include <torch/csrc/jit/passes/dead_code_elimination.h>
810
#include <torch/csrc/jit/passes/graph_fuser.h>
911
#include <torch/csrc/jit/passes/guard_elimination.h>
1012
#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
13+
#include <torch/csrc/jit/passes/inplace_check.h>
1114
#include <torch/csrc/jit/passes/insert_guards.h>
1215
#include <torch/csrc/jit/passes/lower_grad_of.h>
16+
#include <torch/csrc/jit/passes/peephole.h>
1317
#include <torch/csrc/jit/passes/remove_expands.h>
1418
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
1519
#include <torch/csrc/jit/passes/shape_analysis.h>
@@ -53,11 +57,67 @@ static bool needsGradientInProfilingMode(Block* b) {
5357
return false;
5458
}
5559

56-
std::shared_ptr<Graph> ProfilingGraphExecutorImpl::prepareGraph(
57-
const std::shared_ptr<Graph>& graph,
58-
Stack& stack) {
59-
auto g = graph->copy();
60-
return g;
60+
void ProfilingGraphExecutorImpl::runProfilingOptimizations(
61+
std::shared_ptr<Graph>& copy) {
62+
if (!getGraphExecutorOptimize()) {
63+
LowerGradOf(*copy);
64+
runRequiredPasses(copy);
65+
return;
66+
}
67+
68+
InsertGuards(copy);
69+
LowerGradOf(*copy);
70+
EliminateRedundantGuards(copy);
71+
InsertBailOuts(copy);
72+
GRAPH_DUMP("After InsertBailOuts: ", copy);
73+
specializeAutogradZero(*copy);
74+
75+
runRequiredPasses(copy);
76+
ConstantPropagation(copy);
77+
runOptimization(copy);
78+
79+
if (needsGradientInProfilingMode(copy->block())) {
80+
auto diff_nodes = CreateAutodiffSubgraphs(
81+
copy,
82+
getAutodiffSubgraphInlining() ? autodiffSubgraphNodeThreshold : 1);
83+
for (Node* dnode : diff_nodes) {
84+
auto diff_graph = std::move(dnode->g(attr::Subgraph));
85+
Gradient gradient = differentiate(diff_graph);
86+
runOptimization(gradient.f);
87+
// run non diff optimization on the forward graph
88+
runNondiffOptimization(gradient.f);
89+
packGradient(gradient, dnode);
90+
}
91+
InlineAutodiffSubgraphs(
92+
copy,
93+
getAutodiffSubgraphInlining() ? autodiffSubgraphInlineThreshold : 1);
94+
95+
} else {
96+
runNondiffOptimization(copy);
97+
}
98+
EliminateDeadCode(copy);
99+
GRAPH_DUMP("Optimized Graph : ", copy);
100+
}
101+
102+
void ProfilingGraphExecutorImpl::runProfilingInsensitiveOptimizations(
103+
std::shared_ptr<Graph>& copy) {
104+
LowerGradOf(*copy);
105+
GRAPH_DUMP("runProfilingInsensitiveOptimizations", copy);
106+
if (getProfilingMode()) {
107+
ClearUndefinedness(copy);
108+
}
109+
runRequiredPasses(copy);
110+
if (!getGraphExecutorOptimize()) {
111+
return;
112+
}
113+
114+
ConstantPropagation(copy);
115+
EliminateDeadCode(copy);
116+
EliminateCommonSubexpression(copy);
117+
ConstantPooling(copy);
118+
PeepholeOptimize(copy);
119+
EliminateDeadCode(copy);
120+
CheckInplace(copy);
61121
}
62122

63123
ProfilingGraphExecutorImpl::ProfilingGraphExecutorImpl(
@@ -67,89 +127,43 @@ ProfilingGraphExecutorImpl::ProfilingGraphExecutorImpl(
67127
ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(Stack& stack) {
68128
std::lock_guard<std::mutex> lock(compile_mutex);
69129
GRAPH_DEBUG("Running ProfilingGraphExecutorImpl ", this);
130+
70131
if (optimized_plan_) {
71132
return *optimized_plan_;
72133
}
73134

74-
std::shared_ptr<Graph> copy;
75-
if (getProfilingMode()) {
76-
if (!pr_) {
77-
pr_ = ProfilingRecord::instrumentGraph(prepareGraph(graph, stack));
78-
auto copy = pr_->graph()->copy();
79-
LowerGradOf(*copy);
80-
specializeAutogradZero(*copy);
81-
runRequiredPasses(copy);
82-
GRAPH_DUMP("Profiled Graph: ", copy);
83-
profiling_plan_ = ExecutionPlan(copy);
84-
// fall-through
85-
}
86-
87-
if (!pr_->ready()) {
88-
return *profiling_plan_;
89-
}
90-
copy = pr_->graph()->copy();
91-
92-
} else {
93-
copy = graph->copy();
94-
}
95-
96-
if (!getGraphExecutorOptimize()) {
97-
runRequiredPasses(copy);
135+
// simple executor
136+
if (!getProfilingMode()) {
137+
auto copy = graph->copy();
138+
runProfilingInsensitiveOptimizations(copy);
139+
GRAPH_DUMP("Optimized SimpleExecutor Graph : ", copy);
98140
optimized_plan_ = ExecutionPlan(copy);
99141
return *optimized_plan_;
100142
}
101143

102-
InsertGuards(copy);
103-
LowerGradOf(*copy);
104-
if (getProfilingMode()) {
105-
EliminateRedundantGuards(copy);
106-
InsertBailOuts(copy);
107-
GRAPH_DUMP("After InsertBailOuts: ", copy);
144+
// if a profiling graph hasn't been created yet
145+
if (!pr_) {
146+
auto copy = graph->copy();
147+
runProfilingInsensitiveOptimizations(copy);
148+
pr_ = ProfilingRecord::instrumentGraph(copy);
149+
auto pr_copy = pr_->graph()->copy();
150+
GRAPH_DUMP("Profiled Graph: ", pr_copy);
151+
profiling_plan_ = ExecutionPlan(pr_copy);
152+
// fall-through
108153
}
109154

110-
specializeAutogradZero(*copy);
111-
if (!getProfilingMode()) {
112-
ClearUndefinedness(copy);
155+
// profile until a graph is ready
156+
if (!pr_->ready()) {
157+
return *profiling_plan_;
113158
}
114159

115-
runRequiredPasses(copy);
116-
ConstantPropagation(copy);
117-
runOptimization(copy);
118-
119-
// TODO: insert grad propagation
120-
bool needs_gradient = getProfilingMode()
121-
? needsGradientInProfilingMode(copy->block())
122-
: true;
123-
if (needs_gradient) {
124-
// for Simple Executor skip creating autodiff graphs
125-
// and let autograd handle backward for us
126-
if (getProfilingMode()) {
127-
auto diff_nodes = CreateAutodiffSubgraphs(
128-
copy,
129-
getAutodiffSubgraphInlining() ? autodiffSubgraphNodeThreshold : 1);
130-
for (Node *dnode : diff_nodes) {
131-
auto diff_graph = std::move(dnode->g(attr::Subgraph));
132-
Gradient gradient = differentiate(diff_graph);
133-
runOptimization(gradient.f);
134-
// run non diff optimization on the forward graph
135-
runNondiffOptimization(gradient.f);
136-
packGradient(gradient, dnode);
137-
}
138-
InlineAutodiffSubgraphs(copy, getAutodiffSubgraphInlining()
139-
? autodiffSubgraphInlineThreshold
140-
: 1);
141-
}
142-
} else {
143-
runNondiffOptimization(copy);
144-
}
145-
EliminateDeadCode(copy);
146-
GRAPH_DUMP("Optimized Graph : ", copy);
160+
auto copy = pr_->graph()->copy();
161+
runProfilingOptimizations(copy);
147162
// cache
148163
optimized_plan_ = ExecutionPlan(copy);
149164
return *optimized_plan_;
150165
}
151166

152-
153167
GraphExecutorState ProfilingGraphExecutorImpl::getDebugState() {
154168
GraphExecutorState state;
155169
TORCH_INTERNAL_ASSERT(optimized_plan_);

torch/csrc/jit/profiling_graph_executor_impl.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@ struct ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
1212
~ProfilingGraphExecutorImpl() override = default;
1313

1414
private:
15-
std::shared_ptr<Graph> prepareGraph(
16-
const std::shared_ptr<Graph>& graph,
17-
Stack& stack);
15+
void runProfilingInsensitiveOptimizations(std::shared_ptr<Graph>& graph);
16+
void runProfilingOptimizations(std::shared_ptr<Graph>& graph);
1817
std::unique_ptr<ProfilingRecord> pr_;
1918
c10::optional<ExecutionPlan>
2019
profiling_plan_; // plan to run in order to profiling the code

0 commit comments

Comments (0)