[Distributed] Make xm.all_gather a single graph in Dynamo #4922
Conversation
```python
g_xrt_world_size = None
def xrt_world_size(defval=1):
```
@wconstab This is the Python function that I want to use in `allow_in_graph`.
Hmm, if you are going to manually cache the value of this anyway, then I think just using allow_in_graph without the caching is the same thing.
The issue with allow_in_graph is that if you expect the value to be updated on later iterations, allow_in_graph will prevent that from working. But if you expect the value to be a constant for the whole execution, then allow_in_graph will capture the value during compile and reuse it later (i.e., cache it).
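For reference, a minimal sketch of the manual-caching pattern under discussion, pieced together from the diff excerpts in this thread (the `_query_world_size` helper is hypothetical, standing in for the real runtime query):

```python
# Module-level cache: the first call computes the value; every later call
# returns the same plain Python int, which Dynamo can treat as a constant.
g_xrt_world_size = None

def xrt_world_size(defval=1):
  global g_xrt_world_size
  if g_xrt_world_size is None:
    g_xrt_world_size = _query_world_size(defval)  # hypothetical runtime query
  return g_xrt_world_size
```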
I tried to use allow_in_graph. However, it looks like the function I pass into allow_in_graph needs to return a tensor type? If the function returns a bool or an int, is there a workaround?
Here is how I use allow_in_graph:
```diff
ptxla@t1v-n-307ffe96-w-0:/workspaces/work/pytorch/xla$ git diff
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 6ff4a5a5..a07ff472 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -6,6 +6,7 @@ import time
 from typing import List, Optional
 import torch
 import torch.distributed._functional_collectives
+from torch._dynamo import allow_in_graph
 import torch.nn.functional as F
 import torch_xla
 from torch_xla.experimental import pjrt
@@ -1088,3 +1089,6 @@ def optimization_barrier_(tensors):
     tensors (List[torch.Tensor]): List of `torch.Tensor` to add barrier to.
   """
   torch_xla._XLAC._xla_optimization_barrier_(tensors)
+
+
+allow_in_graph(xrt_world_size)
```

And here is the error:
```
root@t1v-n-307ffe96-w-0:/workspaces/work/pytorch/xla# PJRT_DEVICE=TPU python test/test_mp_all_gather.py
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 92, in wrapper
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 245, in _run_thread_per_device
    replica_results = list(
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/usr/local/lib/python3.8/concurrent/futures/thread.py", line 57, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 238, in _thread_fn
    return fn()
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 341, in __call__
    self.fn(global_ordinal(), *self.args, **self.kwargs)
  File "/workspaces/work/pytorch/xla/test/test_mp_all_gather.py", line 32, in _mp_fn
    result = compiled_all_gather(ordinal_tensor, dim=0)
  File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 252, in _fn
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 405, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 122, in _fn
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 331, in _convert_frame_assert
    return _compile(
  File "/workspaces/work/pytorch/torch/_dynamo/utils.py", line 169, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 401, in _compile
    out_code = transform_code_object(code, transform)
  File "/workspaces/work/pytorch/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
    transformations(instructions, code_options)
  File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 386, in transform
    tracer.run()
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1972, in run
    super().run()
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 670, in run
    and self.step()
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 630, in step
    getattr(self, inst.opname)(inst)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 355, in wrapper
    return inner_fn(self, inst)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1138, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 521, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 269, in call_function
    return super().call_function(tx, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 102, in call_function
    return tx.inline_user_function_return(
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 557, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2077, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2155, in inline_call_
    tracer.run()
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 670, in run
    and self.step()
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 630, in step
    getattr(self, inst.opname)(inst)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 355, in wrapper
    return inner_fn(self, inst)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1138, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 521, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 269, in call_function
    return super().call_function(tx, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 102, in call_function
    return tx.inline_user_function_return(
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 557, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2077, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2155, in inline_call_
    tracer.run()
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 670, in run
    and self.step()
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 630, in step
    getattr(self, inst.opname)(inst)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 355, in wrapper
    return inner_fn(self, inst)
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1086, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 521, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspaces/work/pytorch/torch/_dynamo/variables/torch.py", line 603, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/workspaces/work/pytorch/torch/_dynamo/variables/builder.py", line 923, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/workspaces/work/pytorch/torch/_dynamo/variables/builder.py", line 1098, in wrap_fx_proxy_cls
    unimplemented(
  File "/workspaces/work/pytorch/torch/_dynamo/exc.py", line 107, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function <function xrt_world_size at 0x7fb184e94ca0>

from user code:
  File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 550, in all_gather
    return _all_gather_using_all_reduce(
  File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 511, in _all_gather_using_all_reduce
    left, right = ordinal, xrt_world_size() - 1 - ordinal

Set torch._dynamo.config.verbose=True or TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "test/test_mp_all_gather.py", line 66, in <module>
    xmp.spawn(_mp_fn, args=())
  File "/workspaces/work/pytorch/xla/torch_xla/distributed/xla_multiprocessing.py", line 367, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 365, in spawn
    _run_multiprocess(spawn_fn, start_method=start_method)
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 92, in wrapper
    return fn(*args, **kwargs)
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 322, in _run_multiprocess
    replica_results = list(
  File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 323, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function <function xrt_world_size at 0x7fb184e94ca0>

from user code:
  File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 550, in all_gather
    return _all_gather_using_all_reduce(
  File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 511, in _all_gather_using_all_reduce
    left, right = ordinal, xrt_world_size() - 1 - ordinal

Set torch._dynamo.config.verbose=True or TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
root@t1v-n-307ffe96-w-0:/workspaces/work/pytorch/xla#
```

```python
  return g_xrt_world_size

g_ordinal = None
def get_ordinal(defval=0):
```
@wconstab This is the Python function that I want to use in `allow_in_graph`.
| """ | ||
| if pjrt.using_pjrt(): | ||
| return pjrt.global_ordinal() | ||
| global g_ordinal |
I think this will break PJRT + v3 cases; the implementation we had checks the devices in:

```cpp
m.def("_xla_get_default_device_ordinal", []() {
  std::string device_str = GetCurrentThreadDevice();
  torch::lazy::BackendDevice device = bridge::AtenDeviceToXlaDevice(device_str);
  return device.ordinal();
});
```
Hmm, this is confusing. That call is in the C++ layer, so allow_in_graph won't work here.
But we can work around it by caching a map...
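A rough sketch of the map-based workaround being floated here, assuming the `_xla_get_default_device` / `_xla_get_default_device_ordinal` bindings quoted above; this is an illustration, not the code that landed:

```python
import torch_xla

# Hypothetical per-device cache: instead of a single global int, key the
# cached ordinal by the current thread's device so each device still sees
# a stable value on multi-device-per-process hardware.
g_ordinal_by_device = {}

def get_ordinal_cached():
  device = torch_xla._XLAC._xla_get_default_device()
  if device not in g_ordinal_by_device:
    g_ordinal_by_device[device] = (
        torch_xla._XLAC._xla_get_default_device_ordinal())
  return g_ordinal_by_device[device]
```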
I am not sure, actually; effectively this function won't return a constant in the v3 cases, because there are two devices per process. This is a bit tricky.
I wonder if we can bypass the v3 cases for now. What's going to happen if you add a condition here to skip this cached value if we are on v3 + PJRT?
It will introduce graph breaks in Dynamo.
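To make the cost of a graph break concrete, here is a tiny standalone demo; `torch._dynamo.graph_break()` stands in for any call Dynamo cannot trace, such as an uncached runtime query:

```python
import torch
import torch._dynamo

def fn(x):
  x = x + 1
  torch._dynamo.graph_break()  # stand-in for an untraceable runtime query
  return x * 2

# fullgraph=True demands a single graph, so the break raises Unsupported;
# without it, Dynamo would silently split fn into two graphs instead.
compiled = torch.compile(fn, fullgraph=True)
compiled(torch.randn(3))
```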
Should this use thread local storage instead?
That's cool. I was not aware Python has this feature. Let me work on it.
Dynamo doesn't seem to compile in the same thread as the user code, so threading.local doesn't work here.
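A quick generic illustration of the failure mode (plain threading, not Dynamo itself):

```python
import threading

local = threading.local()
local.ordinal = 0  # set on the user thread

def reader():
  # A different thread -- e.g. wherever compilation happens -- gets a
  # fresh threading.local namespace, so the attribute is not there.
  print(getattr(local, 'ordinal', 'unset'))  # prints 'unset'

t = threading.Thread(target=reader)
t.start()
t.join()
```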
```python
    participating replicas.
  """
  if pin_layout and xla_device_hw(
      value.device) in ('TPU', 'GPU', 'XPU') and output == None:
```
I think we had it because CPU was not supported at some point. Do you need to remove it because it will break Dynamo?
Yea.
Thanks!
Thanks Jack for approving.
Summary:
This pull request makes xm.all_gather (the _all_gather_using_all_reduce path) a single graph in Dynamo. To do that, it caches the results of xrt_world_size() and get_ordinal() in module-level variables so Dynamo can trace them as constants, and removes the hardware-type check in the pin_layout branch of all_gather that otherwise caused a graph break.
Test Plan:
```
PJRT_DEVICE=TPU python test/test_mp_all_gather.py
```
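For context, the test exercises roughly this shape; this is a paraphrase, not a copy of test/test_mp_all_gather.py, and the backend string in particular is an assumption:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  device = xm.xla_device()
  ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
  # Backend name is assumed; the Dynamo/XLA backend of this era was
  # e.g. 'torchxla_trace_once'.
  compiled_all_gather = torch.compile(
      xm.all_gather, backend='torchxla_trace_once')
  result = compiled_all_gather(ordinal_tensor, dim=0)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
```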