
Commit bfc2efb

[Serve][LLM] Add /collective_rpc endpoint for RLHF weight synchronization (#59529)

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
1 parent b036ba2 commit bfc2efb

File tree: 9 files changed, +607 -0 lines changed

python/ray/llm/_internal/serve/core/engine/protocol.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -211,6 +211,29 @@ async def is_sleeping(self) -> bool:
         """
         return False
 
+    async def collective_rpc(
+        self,
+        method: str,
+        timeout: Optional[float] = None,
+        args: tuple = (),
+        kwargs: Optional[dict] = None,
+    ) -> list:
+        """Execute a collective RPC call on all workers.
+
+        This is used for RLHF workflows where a trainer needs to execute
+        methods on all TP/PP workers (e.g., for weight synchronization).
+
+        Args:
+            method: Name of the worker method to execute.
+            timeout: Maximum time in seconds to wait for execution.
+            args: Positional arguments to pass to the worker method.
+            kwargs: Keyword arguments to pass to the worker method.
+
+        Returns:
+            A list containing the results from each worker.
+        """
+        raise NotImplementedError("collective_rpc is not implemented for this engine")
+
     async def pause(self, **kwargs: Any) -> None:
         """Pause the engine.
 
```

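The base protocol deliberately raises `NotImplementedError`, so each engine opts in explicitly. As a minimal sketch of one way an engine could satisfy this method, assuming a hypothetical engine that keeps one Ray actor handle per TP/PP rank (`MyActorEngine` and its `workers` list are illustrative, not part of this commit):

```python
import asyncio
from typing import Optional


class MyActorEngine:
    """Hypothetical engine holding one Ray actor handle per TP/PP worker."""

    def __init__(self, workers: list):
        self.workers = workers  # Ray actor handles, one per worker rank

    async def collective_rpc(
        self,
        method: str,
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> list:
        kwargs = kwargs or {}
        # Fan the named method out to every worker. Ray ObjectRefs are
        # awaitable, so gather returns results in worker-rank order.
        refs = [getattr(w, method).remote(*args, **kwargs) for w in self.workers]
        return await asyncio.wait_for(asyncio.gather(*refs), timeout=timeout)
```
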
python/ray/llm/_internal/serve/core/ingress/dev_ingress.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -13,6 +13,7 @@
     POST /resume: Resume generation after pause
     GET /is_paused: Check if engine is paused
     POST /reset_prefix_cache: Reset the KV prefix cache
+    POST /collective_rpc: Execute collective RPC on all workers
 """
 
 import pprint
@@ -30,6 +31,7 @@
 )
 from ray.llm._internal.serve.core.ingress.mixins import (
     CacheManagerIngressMixin,
+    CollectiveRpcIngressMixin,
     PausableIngressMixin,
     SleepableIngressMixin,
 )
@@ -43,6 +45,7 @@
 # Endpoint map for DevIngress - includes all default endpoints plus control plane
 DEV_ENDPOINTS = {
     **CacheManagerIngressMixin.ENDPOINTS,
+    **CollectiveRpcIngressMixin.ENDPOINTS,
     **PausableIngressMixin.ENDPOINTS,
     **SleepableIngressMixin.ENDPOINTS,
     **DEFAULT_ENDPOINTS,
@@ -54,6 +57,7 @@ class DevIngress(
     SleepableIngressMixin,
     PausableIngressMixin,
     CacheManagerIngressMixin,
+    CollectiveRpcIngressMixin,
 ):
     """OpenAI-compatible ingress with additional control plane endpoints.
 
@@ -62,11 +66,13 @@ class DevIngress(
     - RL training: Put engines to sleep during training, wake up for rollouts
     - Memory management: Free GPU memory between inference workloads
     - Benchmarking: Reset prefix cache between benchmark rounds
+    - RLHF: Execute collective RPC on all workers for weight updates
 
     Control plane endpoints provided by mixins:
     - SleepableIngressMixin: /sleep, /wakeup, /is_sleeping
     - PausableIngressMixin: /pause, /resume, /is_paused
     - CacheManagerIngressMixin: /reset_prefix_cache
+    - CollectiveRpcIngressMixin: /collective_rpc
 
     WARNING: These endpoints are intended for development and trusted
     environments. Consider access control in production deployments.
@@ -83,6 +89,7 @@ def build_dev_openai_app(builder_config: Dict) -> Application:
     - /sleep, /wakeup, /is_sleeping (sleep mode - offloads weights to CPU)
     - /pause, /resume, /is_paused (pause mode - keeps weights in GPU)
     - /reset_prefix_cache (cache management)
+    - /collective_rpc (RLHF - execute RPC on all workers)
 
     Args:
         builder_config: Configuration conforming to LLMServingArgs.
```

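For context, a sketch of standing up the dev app so these routes are live. The config below follows the documented `llm_configs` shape of `LLMServingArgs`; the model ID and source are placeholders, and a real deployment would pass a complete config:

```python
from ray import serve

from ray.llm._internal.serve.core.ingress.dev_ingress import build_dev_openai_app

# Placeholder single-model config (assumed shape; see LLMServingArgs).
app = build_dev_openai_app(
    {
        "llm_configs": [
            {
                "model_loading_config": {
                    "model_id": "my-model",
                    "model_source": "Qwen/Qwen2.5-0.5B-Instruct",
                },
            }
        ]
    }
)
serve.run(app)  # /collective_rpc is now served alongside the OpenAI routes
```
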
python/ray/llm/_internal/serve/core/ingress/mixins/__init__.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -7,6 +7,12 @@
     CacheManagerIngressMixin,
     ResetPrefixCacheRequest,
 )
+from ray.llm._internal.serve.core.ingress.mixins.collective_rpc import (
+    CollectiveRpcIngressMixin,
+    CollectiveRpcRequest,
+    CollectiveRpcResponse,
+    ReplicaResult,
+)
 from ray.llm._internal.serve.core.ingress.mixins.pausable import (
     IsPausedResponse,
     PausableIngressMixin,
@@ -22,8 +28,12 @@
 
 __all__ = [
     "CacheManagerIngressMixin",
+    "CollectiveRpcIngressMixin",
     "PausableIngressMixin",
     "SleepableIngressMixin",
+    "CollectiveRpcRequest",
+    "CollectiveRpcResponse",
+    "ReplicaResult",
     "ResetPrefixCacheRequest",
     "PauseRequest",
     "ResumeRequest",
```

python/ray/llm/_internal/serve/core/ingress/mixins/collective_rpc.py

Lines changed: 100 additions & 0 deletions

```diff
@@ -0,0 +1,100 @@
+"""Collective RPC ingress mixin.
+
+Provides HTTP endpoint for collective RPC operations across all replicas
+and their workers, enabling RLHF workflows where a trainer forms a single
+NCCL process group with all TP/PP workers across all replicas.
+"""
+
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Field
+
+from ray.llm._internal.serve.core.ingress.mixins.broadcastable import (
+    ReplicaBroadcastable,
+)
+from ray.llm._internal.serve.observability.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+# --- Pydantic Models ---
+
+
+class CollectiveRpcRequest(BaseModel):
+    """Request to execute a collective RPC on all replicas."""
+
+    model: str
+    method: str
+    args: List[Any] = Field(default_factory=list)
+    kwargs: Dict[str, Any] = Field(default_factory=dict)
+    timeout: Optional[float] = None
+
+
+class ReplicaResult(BaseModel):
+    """Result from a single replica containing all worker results."""
+
+    replica: int
+    worker_results: List[Any]
+
+
+class CollectiveRpcResponse(BaseModel):
+    """Response containing results from all replicas."""
+
+    results: List[ReplicaResult]
+
+
+# --- Mixin ---
+
+
+class CollectiveRpcIngressMixin(ReplicaBroadcastable):
+    """Ingress mixin for /collective_rpc endpoint.
+
+    Adds control plane endpoint for executing collective RPC calls across
+    all replicas and their workers. This is used for RLHF workflows where
+    a trainer needs to communicate with all TP/PP workers across all replicas.
+    """
+
+    ENDPOINTS = {
+        "collective_rpc": lambda app: app.post("/collective_rpc"),
+    }
+
+    async def collective_rpc(self, body: CollectiveRpcRequest) -> CollectiveRpcResponse:
+        """Execute a collective RPC on all replicas for the specified model.
+
+        This broadcasts the RPC call to all replicas, and each replica
+        executes the call on all its workers (TP/PP ranks).
+
+        Args:
+            body: Request containing the model ID, method name, args, kwargs,
+                and optional timeout.
+
+        Returns:
+            CollectiveRpcResponse with results from all replicas.
+        """
+        logger.info(
+            "Executing collective_rpc '%s' for model %s with args=%s, kwargs=%s",
+            body.method,
+            body.model,
+            body.args,
+            body.kwargs,
+        )
+
+        # Broadcast to all replicas - each replica returns a list of worker results
+        replica_results = await self._broadcast_to_replicas(
+            body.model,
+            "collective_rpc",
+            kwargs={
+                "method": body.method,
+                "args": tuple(body.args),
+                "kwargs": body.kwargs,
+                "timeout": body.timeout,
+            },
+        )
+
+        # Format results with replica index for debugging
+        results = [
+            ReplicaResult(replica=i, worker_results=worker_results or [])
+            for i, worker_results in enumerate(replica_results or [])
+        ]
+
+        return CollectiveRpcResponse(results=results)
```

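To make the request/response contract concrete, a minimal client sketch. It assumes a Serve app listening on localhost:8000 with a model registered as "my-model"; the worker method name `report_device_id` is hypothetical and must actually exist on the workers:

```python
import requests

resp = requests.post(
    "http://localhost:8000/collective_rpc",
    json={
        "model": "my-model",           # model ID, as in CollectiveRpcRequest
        "method": "report_device_id",  # hypothetical worker method
        "args": [],
        "kwargs": {},
        "timeout": 30.0,
    },
)
resp.raise_for_status()
# Response shape: {"results": [{"replica": 0, "worker_results": [...]}, ...]}
for item in resp.json()["results"]:
    print(f"replica {item['replica']}: {item['worker_results']}")
```
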
python/ray/llm/_internal/serve/core/server/llm_server.py

Lines changed: 36 additions & 0 deletions
```diff
@@ -589,6 +589,42 @@ async def stop_profile(self) -> None:
             logger.error("Engine stop profile failed in LLMServer.stop_profile: %s", e)
             raise e
 
+    async def collective_rpc(
+        self,
+        method: str,
+        timeout: Optional[float] = None,
+        args: tuple = (),
+        kwargs: Optional[dict] = None,
+    ) -> list:
+        """Execute a collective RPC call on all workers.
+
+        This is used for RLHF workflows where a trainer needs to execute
+        methods on all TP/PP workers (e.g., for weight synchronization).
+
+        Args:
+            method: Name of the worker method to execute.
+            timeout: Maximum time in seconds to wait for execution.
+            args: Positional arguments to pass to the worker method.
+            kwargs: Keyword arguments to pass to the worker method.
+
+        Returns:
+            A list containing the results from each worker.
+        """
+        if self.engine is None:
+            return []
+        try:
+            return await self.engine.collective_rpc(
+                method=method,
+                timeout=timeout,
+                args=args,
+                kwargs=kwargs,
+            )
+        except Exception as e:
+            logger.error(
+                "Engine collective_rpc failed in LLMServer.collective_rpc: %s", e
+            )
+            raise e
+
     async def llm_config(self) -> Optional[LLMConfig]:
         return self._llm_config
```

python/ray/llm/_internal/serve/engines/vllm/vllm_engine.py

Lines changed: 29 additions & 0 deletions
```diff
@@ -641,3 +641,32 @@ async def start_profile(self) -> None:
     async def stop_profile(self) -> None:
         assert self._engine_client is not None, "engine_client is not initialized"
         await self._engine_client.stop_profile()
+
+    async def collective_rpc(
+        self,
+        method: str,
+        timeout: Optional[float] = None,
+        args: tuple = (),
+        kwargs: Optional[dict] = None,
+    ) -> list:
+        """Execute a collective RPC call on all vLLM workers.
+
+        This is used for RLHF workflows where a trainer needs to execute
+        methods on all TP/PP workers (e.g., for weight synchronization).
+
+        Args:
+            method: Name of the worker method to execute.
+            timeout: Maximum time in seconds to wait for execution.
+            args: Positional arguments to pass to the worker method.
+            kwargs: Keyword arguments to pass to the worker method.
+
+        Returns:
+            A list containing the results from each worker.
+        """
+        assert self._engine_client is not None, "engine_client is not initialized"
+        return await self._engine_client.collective_rpc(
+            method=method,
+            timeout=timeout,
+            args=args,
+            kwargs=kwargs or {},
+        )
```

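Putting the pieces together, a hedged sketch of the RLHF weight-sync flow this endpoint enables. The worker-side methods (`init_weight_update_group`, `update_weight`) follow the pattern in vLLM's RLHF examples and would have to be provided by a custom worker extension; the addresses, world size, and toy trainer model are assumptions:

```python
import requests
import torch

BASE = "http://localhost:8000"  # assumed Serve ingress address


def rpc(method: str, **kwargs) -> list:
    """POST one collective RPC and return the per-replica results."""
    r = requests.post(
        f"{BASE}/collective_rpc",
        json={"model": "my-model", "method": method, "kwargs": kwargs},
    )
    r.raise_for_status()
    return r.json()["results"]


# 1. Every TP/PP worker joins a NCCL group with the trainer at rank 0
#    (hypothetical worker-extension method).
rpc(
    "init_weight_update_group",
    master_address="10.0.0.1",
    master_port=51216,
    rank_offset=1,
    world_size=9,  # e.g., 1 trainer process + 8 inference workers
)

# 2. After a training step, broadcast updated weights tensor by tensor
#    (hypothetical worker-extension method).
trainer_model = torch.nn.Linear(16, 16)  # stand-in for the real policy model
for name, param in trainer_model.named_parameters():
    rpc("update_weight", name=name, dtype=str(param.dtype), shape=list(param.shape))
```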