 #include <torch/csrc/jit/passes/bailout_graph.h>
 #include <torch/csrc/jit/passes/canonicalize_ops.h>
 #include <torch/csrc/jit/passes/clear_undefinedness.h>
+#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
 #include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/graph_fuser.h>
 #include <torch/csrc/jit/passes/guard_elimination.h>
 #include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
+#include <torch/csrc/jit/passes/inplace_check.h>
 #include <torch/csrc/jit/passes/insert_guards.h>
 #include <torch/csrc/jit/passes/lower_grad_of.h>
+#include <torch/csrc/jit/passes/peephole.h>
 #include <torch/csrc/jit/passes/remove_expands.h>
 #include <torch/csrc/jit/passes/requires_grad_analysis.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
@@ -53,11 +57,67 @@ static bool needsGradientInProfilingMode(Block* b) {
   return false;
 }
 
-std::shared_ptr<Graph> ProfilingGraphExecutorImpl::prepareGraph(
-    const std::shared_ptr<Graph>& graph,
-    Stack& stack) {
-  auto g = graph->copy();
-  return g;
+void ProfilingGraphExecutorImpl::runProfilingOptimizations(
+    std::shared_ptr<Graph>& copy) {
+  if (!getGraphExecutorOptimize()) {
+    LowerGradOf(*copy);
+    runRequiredPasses(copy);
+    return;
+  }
+
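+  // InsertGuards materializes the tensor types recorded by the profiler
+  // as explicit guard nodes; guards already implied by earlier ones are
+  // then removed, and the remaining ones are converted into bailout
+  // points that fall back to the unoptimized graph when a check fails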
+  InsertGuards(copy);
+  LowerGradOf(*copy);
+  EliminateRedundantGuards(copy);
+  InsertBailOuts(copy);
+  GRAPH_DUMP("After InsertBailOuts: ", copy);
+  specializeAutogradZero(*copy);
+
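+  // with the graph specialized, run the required passes plus the
+  // generic optimization pipeline before looking for differentiable
+  // regions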
+  runRequiredPasses(copy);
+  ConstantPropagation(copy);
+  runOptimization(copy);
+
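+  // if any value may require gradients, outline differentiable regions
+  // into autodiff subgraphs, differentiate them, and optimize the
+  // forward and backward graphs separately; otherwise the whole graph
+  // only needs the non-diff optimizations (e.g. fusion)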
+  if (needsGradientInProfilingMode(copy->block())) {
+    auto diff_nodes = CreateAutodiffSubgraphs(
+        copy,
+        getAutodiffSubgraphInlining() ? autodiffSubgraphNodeThreshold : 1);
+    for (Node* dnode : diff_nodes) {
+      auto diff_graph = std::move(dnode->g(attr::Subgraph));
+      Gradient gradient = differentiate(diff_graph);
+      runOptimization(gradient.f);
+      // run the non-diff optimizations on the forward graph
+      runNondiffOptimization(gradient.f);
+      packGradient(gradient, dnode);
+    }
+    InlineAutodiffSubgraphs(
+        copy,
+        getAutodiffSubgraphInlining() ? autodiffSubgraphInlineThreshold : 1);
+  } else {
+    runNondiffOptimization(copy);
+  }
+  EliminateDeadCode(copy);
+  GRAPH_DUMP("Optimized Graph : ", copy);
+}
+
+void ProfilingGraphExecutorImpl::runProfilingInsensitiveOptimizations(
+    std::shared_ptr<Graph>& copy) {
+  LowerGradOf(*copy);
+  GRAPH_DUMP("runProfilingInsensitiveOptimizations", copy);
+  if (getProfilingMode()) {
+    ClearUndefinedness(copy);
+  }
+  runRequiredPasses(copy);
+  if (!getGraphExecutorOptimize()) {
+    return;
+  }
+
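+  // shape-independent cleanup passes that are safe to run before any
+  // profiling information has been collected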
+  ConstantPropagation(copy);
+  EliminateDeadCode(copy);
+  EliminateCommonSubexpression(copy);
+  ConstantPooling(copy);
+  PeepholeOptimize(copy);
+  EliminateDeadCode(copy);
+  CheckInplace(copy);
 }
 
 ProfilingGraphExecutorImpl::ProfilingGraphExecutorImpl(
@@ -67,89 +127,43 @@ ProfilingGraphExecutorImpl::ProfilingGraphExecutorImpl(
 ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(Stack& stack) {
   std::lock_guard<std::mutex> lock(compile_mutex);
   GRAPH_DEBUG("Running ProfilingGraphExecutorImpl ", this);
+
   if (optimized_plan_) {
     return *optimized_plan_;
   }
 
-  std::shared_ptr<Graph> copy;
-  if (getProfilingMode()) {
-    if (!pr_) {
-      pr_ = ProfilingRecord::instrumentGraph(prepareGraph(graph, stack));
-      auto copy = pr_->graph()->copy();
-      LowerGradOf(*copy);
-      specializeAutogradZero(*copy);
-      runRequiredPasses(copy);
-      GRAPH_DUMP("Profiled Graph: ", copy);
-      profiling_plan_ = ExecutionPlan(copy);
-      // fall-through
-    }
-
-    if (!pr_->ready()) {
-      return *profiling_plan_;
-    }
-    copy = pr_->graph()->copy();
-
-  } else {
-    copy = graph->copy();
-  }
-
-  if (!getGraphExecutorOptimize()) {
-    runRequiredPasses(copy);
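+  // three cases follow: the simple executor when profiling is off, an
+  // instrumented graph while profiles are still being collected, and a
+  // guarded, optimized graph once profiling is complete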
+  // simple executor
+  if (!getProfilingMode()) {
+    auto copy = graph->copy();
+    runProfilingInsensitiveOptimizations(copy);
+    GRAPH_DUMP("Optimized SimpleExecutor Graph : ", copy);
     optimized_plan_ = ExecutionPlan(copy);
     return *optimized_plan_;
   }
 
-  InsertGuards(copy);
-  LowerGradOf(*copy);
-  if (getProfilingMode()) {
-    EliminateRedundantGuards(copy);
-    InsertBailOuts(copy);
-    GRAPH_DUMP("After InsertBailOuts: ", copy);
+  // if a profiling graph hasn't been created yet
+  if (!pr_) {
+    auto copy = graph->copy();
+    runProfilingInsensitiveOptimizations(copy);
+    pr_ = ProfilingRecord::instrumentGraph(copy);
+    auto pr_copy = pr_->graph()->copy();
+    GRAPH_DUMP("Profiled Graph: ", pr_copy);
+    profiling_plan_ = ExecutionPlan(pr_copy);
+    // fall-through
   }
 
-  specializeAutogradZero(*copy);
-  if (!getProfilingMode()) {
-    ClearUndefinedness(copy);
+  // profile until a graph is ready
+  if (!pr_->ready()) {
+    return *profiling_plan_;
   }
 
-  runRequiredPasses(copy);
-  ConstantPropagation(copy);
-  runOptimization(copy);
-
-  // TODO: insert grad propagation
-  bool needs_gradient = getProfilingMode()
-      ? needsGradientInProfilingMode(copy->block())
-      : true;
-  if (needs_gradient) {
-    // for Simple Executor skip creating autodiff graphs
-    // and let autograd handle backward for us
-    if (getProfilingMode()) {
-      auto diff_nodes = CreateAutodiffSubgraphs(
-          copy,
-          getAutodiffSubgraphInlining() ? autodiffSubgraphNodeThreshold : 1);
-      for (Node* dnode : diff_nodes) {
-        auto diff_graph = std::move(dnode->g(attr::Subgraph));
-        Gradient gradient = differentiate(diff_graph);
-        runOptimization(gradient.f);
-        // run non diff optimization on the forward graph
-        runNondiffOptimization(gradient.f);
-        packGradient(gradient, dnode);
-      }
-      InlineAutodiffSubgraphs(
-          copy,
-          getAutodiffSubgraphInlining() ? autodiffSubgraphInlineThreshold : 1);
-    }
-  } else {
-    runNondiffOptimization(copy);
-  }
-  EliminateDeadCode(copy);
-  GRAPH_DUMP("Optimized Graph : ", copy);
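+  // a profiled graph is ready; copy it and run the profiling-driven
+  // optimizations on the copy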
+  auto copy = pr_->graph()->copy();
+  runProfilingOptimizations(copy);
   // cache
   optimized_plan_ = ExecutionPlan(copy);
   return *optimized_plan_;
 }
 
-
 GraphExecutorState ProfilingGraphExecutorImpl::getDebugState() {
   GraphExecutorState state;
   TORCH_INTERNAL_ASSERT(optimized_plan_);