File tree Expand file tree Collapse file tree 5 files changed +42
-1
lines changed
Expand file tree Collapse file tree 5 files changed +42
-1
lines changed Original file line number Diff line number Diff line change 1313
1414# XXX: still in prototype
1515class TestSymbolicShapeAnalysis (JitTestCase ):
16+ def setUp (self ):
17+ self .prev_symbolic_shapes_test_enabled = torch ._C ._jit_symbolic_shapes_test_mode_enabled ()
18+ torch ._C ._jit_set_symbolic_shapes_test_mode (True )
19+
20+ def tearDown (self ):
21+ torch ._C ._jit_set_symbolic_shapes_test_mode (self .prev_symbolic_shapes_test_enabled )
22+
1623 def test_shape_analysis (self ):
1724 @torch .jit .script
1825 def foo (x , y ):
Original file line number Diff line number Diff line change @@ -213,6 +213,8 @@ def _jit_nvfuser_enabled() -> _bool: ...
213213def _llvm_enabled() -> _bool: ...
214214def _jit_override_can_fuse_on_cpu(override: _bool): ...
215215def _jit_override_can_fuse_on_gpu(override: _bool): ...
216+ def _jit_set_symbolic_shapes_test_mode(override: _bool): ...
217+ def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ...
216218def _jit_set_texpr_fuser_enabled(enable: _bool): ...
217219def _jit_set_te_must_use_llvm_cpu(use_llvm: _bool): ...
218220def _jit_set_nvfuser_enabled(enable: _bool) -> _bool: ...
Original file line number Diff line number Diff line change @@ -36,9 +36,21 @@ pointwise ops)
3636- Supporting returning partially evaluated shape compute graph
3737*/
3838
39+ static bool symbolic_shape_analysis_test_mode = false ;
40+
3941namespace torch {
4042namespace jit {
4143
44+ bool setSymbolicShapeAnalysisTestMode (bool value) {
45+ bool old_value = symbolic_shape_analysis_test_mode;
46+ symbolic_shape_analysis_test_mode = value;
47+ return old_value;
48+ }
49+
50+ bool symbolicShapeAnalysisTestModeEnabled () {
51+ return symbolic_shape_analysis_test_mode;
52+ }
53+
4254// TODO: better registration mechanism
4355std::mutex lock;
4456std::unordered_map<std::string, std::shared_ptr<Graph>> operator_functions;
@@ -79,7 +91,14 @@ struct SymbolicShapeAnalyzer {
7991 auto type = node_->input (i)->type ();
8092 if (auto tt = type->castRaw <TensorType>()) {
8193 c10::SymbolicShape symbolic_shapes = tt->symbolic_sizes ();
82- if (symbolic_shapes.isComplete ()) {
94+
95+ // for testing, we don't insert complete tensor shapes and rely on our
96+ // partial evaluation pipeline to propagate information.
97+ // this is a good proxy for our ability to propagate non-complete shape
98+ // information.
99+
100+ if (symbolic_shapes.isComplete () &&
101+ !symbolic_shape_analysis_test_mode) {
83102 replaceWithIValue (
84103 graph_->inputs ().at (i), *tt->sizes ().concrete_sizes ());
85104 continue ;
Original file line number Diff line number Diff line change @@ -10,5 +10,12 @@ namespace jit {
1010
1111TORCH_API void PropagateShapesOnGraph (std::shared_ptr<Graph>& graph);
1212
13+ // don't insert complete tensor shapes in shape compute graphs and instead
14+ // rely on our partial evaluation pipeline to propagate information.
15+ // this is a good proxy for our ability to propagate non-complete shape
16+ // information.
17+ TORCH_API bool setSymbolicShapeAnalysisTestMode (bool value);
18+ TORCH_API bool symbolicShapeAnalysisTestModeEnabled ();
19+
1320} // namespace jit
1421} // namespace torch
Original file line number Diff line number Diff line change @@ -179,6 +179,12 @@ void initJITBindings(PyObject* module) {
179179 .def (" _jit_pass_propagate_shapes_on_graph" , PropagateShapesOnGraph)
180180 .def (" _jit_pass_onnx_function_substitution" , ONNXFunctionCallSubstitution)
181181 .def (" _jit_pass_integer_value_refinement" , RefineIntegerValues)
182+ .def (
183+ " _jit_set_symbolic_shapes_test_mode" ,
184+ &setSymbolicShapeAnalysisTestMode)
185+ .def (
186+ " _jit_symbolic_shapes_test_mode_enabled" ,
187+ &symbolicShapeAnalysisTestModeEnabled)
182188 .def (
183189 " _jit_pass_onnx_fold_if" ,
184190 [](std::shared_ptr<Graph>& graph) {
You can’t perform that action at this time.
0 commit comments