Support PreemptionSyncManager in XlaCoordinator #5733
Conversation
LGTM!
```cpp
}

void DistributedRuntime::ActivatePreemptionSyncManager() {
  if (preemption_sync_manager_ == nullptr) {
```
Is there any harm in initializing the PreemptionSyncManager when you initialize the xla::DistributedRuntimeService and Client? In general, I try to avoid cases where you "partially" construct an object and leave room for bugs later (like calling ReachedSyncPoint before ActivatePreemptionSyncManager).
I was hesitant to do that, since it will register a SIGTERM handler, which will cause any intentional SIGTERMs to be ignored. Open to revisiting; let me know which approach you think makes more sense!
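For context, here is a minimal sketch of the lazy-activation pattern under discussion, based on the snippet quoted above. The dist_runtime_client_ member, the GetCoordinationServiceAgent call, and the check macros are assumptions about how the pieces fit together, not a quote of the actual change:

```cpp
// Sketch only: activation is deferred to an explicit call so the SIGTERM
// handler is registered only when the caller opts in.
void DistributedRuntime::ActivatePreemptionSyncManager() {
  if (preemption_sync_manager_ == nullptr) {
    preemption_sync_manager_ = tsl::CreatePreemptionSyncManager();
    // Assumption: the coordination service agent is obtained from the
    // distributed runtime client and handed to Initialize().
    auto agent = dist_runtime_client_->GetCoordinationServiceAgent();
    XLA_CHECK_OK(agent.status());
    XLA_CHECK_OK(preemption_sync_manager_->Initialize(agent.value()));
  }
}
```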
```cpp
// DistributedRuntime serves as the point of entry for all operations which
// require the XLA distributed runtime, such as preemption coordination.
class DistributedRuntime {
```
I really dislike the naming choice for the upstream xla::DistributedRuntime, since it's not actually a distributed runtime. Since this class is becoming more than just a wrapper around xla::DistributedRuntimeService and xla::DistributedRuntimeClient, what do you think of changing the name to something more intuitive? e.g. XlaCoordinator
Totally agree, XlaCoordinator it is! I'll update the pybinds as well.
I left this as-is for now to keep the change minimal; we can revisit in the upcoming refactor.
```cpp
// The PreemptionSyncManager must be activated within the DistributedRuntime.
// Returns true when the input step has been identified as a sync point, and
// false otherwise.
bool ReachedSyncPoint(int step);
```
Do you think it makes more sense to expose the tsl::PreemptionSyncManager directly as we do with the xla::DistributedRuntimeClient? Or do we want to restrict access to the underlying object?
I considered that, but if the PreemptionSyncManager outlives the DistributedRuntimeClient, the program will segfault... 😢 I figured it's better to keep it hidden to avoid that edge case.
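To illustrate the trade-off, the wrapper could delegate to the owned manager instead of handing out a pointer to it. Only the ReachedSyncPoint declaration comes from the diff; the body below is an assumed sketch:

```cpp
// Sketch: callers never hold a raw PreemptionSyncManager pointer, so it
// cannot outlive the DistributedRuntimeClient it was initialized with.
bool DistributedRuntime::ReachedSyncPoint(int step) {
  XLA_CHECK(preemption_sync_manager_ != nullptr)
      << "ActivatePreemptionSyncManager must be called before ReachedSyncPoint";
  return preemption_sync_manager_->ReachedSyncPoint(step);
}
```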
Thanks for the review @vanbasten23 and @will-cromar! I'll update to address the feedback.
Oh, one more thing: could you help check whether we have a test verifying that the distributed runtime service is always torn down? I'd imagine if we comment out the line …
@vanbasten23 @will-cromar I've updated to have the ComputationClient own the XlaCoordinator. Please take a second look when you get a chance!
Overall LGTM.
FYI, our style guide cautions against forward declarations of entities in other projects, even if it saves compile time: https://google.github.io/styleguide/cppguide.html#Forward_Declarations
If you forward-declared the DistributedRuntime classes to unravel e.g. a macro conflict or circular dependency, please leave a comment explaining why.
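For example, a comment along these lines would satisfy the style-guide caveat (the namespace nesting and wording here are hypothetical):

```cpp
// Forward declaration instead of #include: the XlaCoordinator header
// transitively pulls in PJRT headers whose logging macros conflict with ours.
namespace torch_xla {
namespace runtime {
class XlaCoordinator;
}  // namespace runtime
}  // namespace torch_xla
```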
```cpp
virtual void WaitDeviceOps(const std::vector<std::string>& devices) = 0;

// Check whether the XlaCoordinator has been initialized.
virtual bool CoordinatorInitialized() const = 0;
```
Do these need to be virtual? It looks like the implementations below don't depend on the underlying runtime client.
The XlaCoordinator depends on PJRT, so I kept it separate. Though I guess that's not a strong justification...
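A rough sketch of the split being discussed, with the interface declaring the hook and a PJRT-backed client implementing it; the PjRtComputationClient name and coordinator_ member are assumptions for illustration:

```cpp
#include <memory>

// Placeholder for the real class declared elsewhere in the runtime.
class XlaCoordinator {};

// Base interface declares the coordinator hook as pure virtual, as quoted
// in the diff above.
class ComputationClient {
 public:
  virtual ~ComputationClient() = default;
  // Check whether the XlaCoordinator has been initialized.
  virtual bool CoordinatorInitialized() const = 0;
};

// Assumed name for the PJRT-backed implementation that owns the coordinator.
class PjRtComputationClient : public ComputationClient {
 public:
  bool CoordinatorInitialized() const override {
    return coordinator_ != nullptr;
  }

 private:
  std::unique_ptr<XlaCoordinator> coordinator_;
};
```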
```cpp
// Forward declare XlaCoordinator to avoid logging macro redefinition from the
// transitively included PJRT header.
// TODO(jonbolin): We need a way to ensure the right macros are included
```
logging macros are cursed
I guess I had a hard time imagining how this is going to be incorporated into the checkpoint manager. @jonb377 Can you point me to some examples?
@alanwaketan CheckpointManager will initialize the PreemptionSyncManager on construction and call into …
@alanwaketan See 5fdce13 for the intended usage.
I see, that makes sense.
LGTM.
* Support PreemptionSyncManager in DistributedRuntime
* Refactor to be owned by ComputationClient
* Clean up logging macro issue handling
To support autocheckpointing upon preemption, we need to access a PreemptionSyncManager to identify sync points when a preemption has occurred.
This change additionally refactors the DistributedRuntime to be owned by the ComputationClient, since in the GPU case the ComputationClient has a direct dependency on the DistributedRuntimeClient.
This change adds the PreemptionSyncManager to the new XlaCoordinator class. The PreemptionSyncManager has the side effect of registering a SIGTERM handler, so it is not enabled by default.
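Putting the pieces from the review together, the coordinator's public surface would look roughly like the header-style sketch below. The method names come from the diffs quoted above; the constructor signature and member names are assumptions:

```cpp
#include <memory>
#include <string>

namespace tsl {
class PreemptionSyncManager;
}
namespace xla {
class DistributedRuntimeClient;
class DistributedRuntimeService;
}

// Sketch of the coordinator: it wraps the upstream distributed runtime
// service/client and lazily owns a PreemptionSyncManager, which is only
// created on request because activation registers a SIGTERM handler.
class XlaCoordinator {
 public:
  // Constructor parameters are an assumption for illustration.
  XlaCoordinator(int global_rank, int world_size, std::string master_addr,
                 std::string port);
  ~XlaCoordinator();

  // Creates the PreemptionSyncManager and registers the SIGTERM handler.
  void ActivatePreemptionSyncManager();

  // Returns true when `step` has been identified as a sync point; requires
  // ActivatePreemptionSyncManager to have been called first.
  bool ReachedSyncPoint(int step);

 private:
  std::unique_ptr<xla::DistributedRuntimeService> dist_runtime_service_;
  std::shared_ptr<xla::DistributedRuntimeClient> dist_runtime_client_;
  std::unique_ptr<tsl::PreemptionSyncManager> preemption_sync_manager_;
};
```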