Skip to content

Commit d8cbba3

Browse files
eellisonfacebook-github-bot
authored andcommitted
[JIT] Disable Complete Shape Inlining For Testing Purposes (pytorch#56966)
Summary: Pull Request resolved: pytorch#56966 This PR adds a toggle to shape analysis which won't inline complete tensor shapes as constants into the shape compute graph, which is a good stress test on the partial evaluation pipeline. Test Plan: Imported from OSS Reviewed By: bdhirsh Differential Revision: D28444664 Pulled By: eellison fbshipit-source-id: a62e424515a8837a4b596546efa93af5e8e61f10
1 parent f66fbb1 commit d8cbba3

File tree

5 files changed

+42
-1
lines changed

5 files changed

+42
-1
lines changed

test/jit/test_symbolic_shape_analysis.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@
1313

1414
# XXX: still in prototype
1515
class TestSymbolicShapeAnalysis(JitTestCase):
16+
def setUp(self):
17+
self.prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
18+
torch._C._jit_set_symbolic_shapes_test_mode(True)
19+
20+
def tearDown(self):
21+
torch._C._jit_set_symbolic_shapes_test_mode(self.prev_symbolic_shapes_test_enabled)
22+
1623
def test_shape_analysis(self):
1724
@torch.jit.script
1825
def foo(x, y):

torch/_C/__init__.pyi.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ def _jit_nvfuser_enabled() -> _bool: ...
213213
def _llvm_enabled() -> _bool: ...
214214
def _jit_override_can_fuse_on_cpu(override: _bool): ...
215215
def _jit_override_can_fuse_on_gpu(override: _bool): ...
216+
def _jit_set_symbolic_shapes_test_mode(override: _bool): ...
217+
def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ...
216218
def _jit_set_texpr_fuser_enabled(enable: _bool): ...
217219
def _jit_set_te_must_use_llvm_cpu(use_llvm: _bool): ...
218220
def _jit_set_nvfuser_enabled(enable: _bool) -> _bool: ...

torch/csrc/jit/passes/symbolic_shape_analysis.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,21 @@ pointwise ops)
3636
- Supporting returning partially evaluated shape compute graph
3737
*/
3838

39+
static bool symbolic_shape_analysis_test_mode = false;
40+
3941
namespace torch {
4042
namespace jit {
4143

44+
bool setSymbolicShapeAnalysisTestMode(bool value) {
45+
bool old_value = symbolic_shape_analysis_test_mode;
46+
symbolic_shape_analysis_test_mode = value;
47+
return old_value;
48+
}
49+
50+
bool symbolicShapeAnalysisTestModeEnabled() {
51+
return symbolic_shape_analysis_test_mode;
52+
}
53+
4254
// TODO: better registration mechanism
4355
std::mutex lock;
4456
std::unordered_map<std::string, std::shared_ptr<Graph>> operator_functions;
@@ -79,7 +91,14 @@ struct SymbolicShapeAnalyzer {
7991
auto type = node_->input(i)->type();
8092
if (auto tt = type->castRaw<TensorType>()) {
8193
c10::SymbolicShape symbolic_shapes = tt->symbolic_sizes();
82-
if (symbolic_shapes.isComplete()) {
94+
95+
// for testing, we don't insert complete tensor shapes and rely on our
96+
// partial evaluation pipeline to propagate information.
97+
// this is a good proxy for our ability to propagate non-complete shape
98+
// information.
99+
100+
if (symbolic_shapes.isComplete() &&
101+
!symbolic_shape_analysis_test_mode) {
83102
replaceWithIValue(
84103
graph_->inputs().at(i), *tt->sizes().concrete_sizes());
85104
continue;

torch/csrc/jit/passes/symbolic_shape_analysis.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,12 @@ namespace jit {
1010

1111
TORCH_API void PropagateShapesOnGraph(std::shared_ptr<Graph>& graph);
1212

13+
// don't insert complete tensor shapes in shape compute graphs and instead
14+
// rely on our partial evaluation pipeline to propagate information.
15+
// this is a good proxy for our ability to propagate non-complete shape
16+
// information.
17+
TORCH_API bool setSymbolicShapeAnalysisTestMode(bool value);
18+
TORCH_API bool symbolicShapeAnalysisTestModeEnabled();
19+
1320
} // namespace jit
1421
} // namespace torch

torch/csrc/jit/python/init.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,12 @@ void initJITBindings(PyObject* module) {
179179
.def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph)
180180
.def("_jit_pass_onnx_function_substitution", ONNXFunctionCallSubstitution)
181181
.def("_jit_pass_integer_value_refinement", RefineIntegerValues)
182+
.def(
183+
"_jit_set_symbolic_shapes_test_mode",
184+
&setSymbolicShapeAnalysisTestMode)
185+
.def(
186+
"_jit_symbolic_shapes_test_mode_enabled",
187+
&symbolicShapeAnalysisTestModeEnabled)
182188
.def(
183189
"_jit_pass_onnx_fold_if",
184190
[](std::shared_ptr<Graph>& graph) {

0 commit comments

Comments
 (0)