Commit e7ef57f

[data] Update streaming_repartition and map_batches_fusion (#59476)
Analysis of the two operator patterns:

## Streaming_repartition → map_batches

|               | Number of `map_batches` tasks |
|---------------|-------------------------------|
| **Fused**     | `num_input_blocks` (which is ≤ number of output blocks of StreamingRepartition) |
| **Not fused** | number of output blocks of StreamingRepartition |

When fused, the number of tasks equals the number of input blocks, which is ≤ the number of output blocks of StreamingRepartition. If StreamingRepartition is supposed to break down blocks to increase parallelism, that won't happen when fused. So we don't fuse.

---

## Map_batches → streaming_repartition

When `batch_size % target_num_rows == 0`:

|               | Number of `map_batches` tasks |
|---------------|-------------------------------|
| **Fused**     | == total_rows / batch_size    |
| **Not fused** | == total_rows / batch_size    |

So the fusion doesn't affect the parallelism.

---

Thus, we currently disable the `Streaming_repartition → map_batches` fusion, and enable the `Map_batches → streaming_repartition` fusion when `batch_size % target_num_rows == 0`.

---------

Signed-off-by: Xinyuan <43737116+xinyuangui2@users.noreply.github.com>
Signed-off-by: xgui <xgui@anyscale.com>
Signed-off-by: Alexey Kudinkin <alexey.kudinkin@gmail.com>
Co-authored-by: Alexey Kudinkin <alexey.kudinkin@gmail.com>
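A small worked example of the second table (the numbers below are made up for illustration; they are not taken from the PR): when `batch_size` is a multiple of `target_num_rows`, the `map_batches` task count is the same with or without fusion.

```python
# Illustrative numbers only; not from the PR.
total_rows = 1_000
batch_size = 40        # MapBatches batch_size
target_num_rows = 20   # StreamingRepartition target_num_rows_per_block

# The condition this PR requires before fusing MapBatches -> StreamingRepartition.
assert batch_size % target_num_rows == 0

# The number of map_batches tasks is driven by batch_size either way,
# so fusing does not reduce parallelism.
tasks_not_fused = total_rows // batch_size  # 25
tasks_fused = total_rows // batch_size      # 25
assert tasks_fused == tasks_not_fused
```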
1 parent 196c678 commit e7ef57f

File tree: 3 files changed, +49 -9 lines changed


python/ray/data/_internal/logical/rules/operator_fusion.py

Lines changed: 37 additions & 2 deletions
@@ -90,6 +90,30 @@ def _fuse_streaming_repartition_operators_in_dag(
 
         This will ensure the map_batch's function receive the correct number of rows.
         We also ensure the output rows is `batch_size`.
+
+        Why don't we fuse `StreamingRepartition -> MapBatches`?
+
+        ----------------------------------------------------------------------------------------------------
+        |                      | Number of `map_batches` tasks                                             |
+        |----------------------|---------------------------------------------------------------------------|
+        | Fused                | num_input_blocks (which is <= num output blocks of StreamingRepartition) |
+        | Not fused            | num output blocks of StreamingRepartition                                |
+        ----------------------------------------------------------------------------------------------------
+
+        When fused, the number of tasks equals the number of input blocks, which is
+        <= the number of output blocks of StreamingRepartition. If StreamingRepartition
+        is supposed to break down blocks to increase parallelism, that won't happen
+        when fused. So we don't fuse.
+
+        Why do we fuse `MapBatches -> StreamingRepartition` (when `batch_size % target_num_rows == 0`)?
+        ----------------------------------------------------------
+        |                      | Number of `map_batches` tasks |
+        |----------------------|-------------------------------|
+        | Fused                | total_rows / batch_size       |
+        | Not fused            | total_rows / batch_size       |
+        ----------------------------------------------------------
+
+        Parallelism is unchanged, so we fuse to avoid intermediate materialization.
         """
         upstream_ops = dag.input_dependencies
         while (
@@ -252,8 +276,17 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
         if isinstance(down_logical_op, StreamingRepartition):
             return (
                 isinstance(up_logical_op, MapBatches)
+                and up_logical_op._batch_size is not None
+                and down_logical_op.target_num_rows_per_block is not None
+                and down_logical_op.target_num_rows_per_block > 0
+                # When batch_size is a multiple of target_num_rows_per_block, fusing still produces an identical sequence of blocks.
+                # See the `_fuse_streaming_repartition_operators_in_dag` docstring for details.
+                # TODO: when StreamingRepartition supports none_strict_mode, we can fuse
+                # `MapBatches -> StreamingRepartition` no matter what `batch_size` and `target_num_rows` are.
+                # https://anyscale1.atlassian.net/browse/DATA-1731
                 and up_logical_op._batch_size
-                == down_logical_op.target_num_rows_per_block
+                % down_logical_op.target_num_rows_per_block
+                == 0
             )
         # Other operators cannot fuse with StreamingRepartition.
         if isinstance(up_logical_op, StreamingRepartition):
@@ -276,7 +309,9 @@ def _get_fused_streaming_repartition_operator(
         up_logical_op = self._op_map.pop(up_op)
         assert isinstance(up_logical_op, MapBatches)
         assert isinstance(down_logical_op, StreamingRepartition)
-        assert up_logical_op._batch_size == down_logical_op.target_num_rows_per_block
+        assert (
+            up_logical_op._batch_size % down_logical_op.target_num_rows_per_block == 0
+        )
         batch_size = up_logical_op._batch_size
 
         compute = self._fuse_compute_strategy(
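For context, a minimal sketch of the pipeline shape this rule targets, using the same public Ray Data calls as the e2e test below (the identity function and the sizes are illustrative, not taken from the PR):

```python
import ray

# batch_size (40) is a multiple of target_num_rows_per_block (20), so the optimizer
# may fuse MapBatches with the downstream StreamingRepartition instead of
# materializing intermediate blocks between the two operators.
ds = ray.data.range(1_000)
ds = ds.map_batches(lambda batch: batch, batch_size=40)
ds = ds.repartition(target_num_rows_per_block=20)
```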

python/ray/data/tests/test_operator_fusion.py

Lines changed: 2 additions & 1 deletion
@@ -745,9 +745,10 @@ def test_zero_copy_fusion_eliminate_build_output_blocks(
 @pytest.mark.parametrize(
     "order,target_num_rows,batch_size,should_fuse",
     [
-        # map_batches -> streaming_repartition: fuse when batch_size == target_num_rows
+        # map_batches -> streaming_repartition: fuse when batch_size is a multiple of target_num_rows
         ("map_then_sr", 20, 20, True),
         ("map_then_sr", 20, 10, False),
+        ("map_then_sr", 20, 40, True),
         ("map_then_sr", 20, None, False),
         # streaming_repartition -> map_batches: not fused
         ("sr_then_map", 20, 20, False),

python/ray/data/tests/test_repartition_e2e.py

Lines changed: 10 additions & 6 deletions
@@ -233,22 +233,25 @@ def test_repartition_empty_datasets(ray_start_regular_shared_2_cpus, shuffle):
 
 
 @pytest.mark.parametrize("streaming_repartition_first", [True, False])
+@pytest.mark.parametrize("n_target_num_rows", [1, 5])
 def test_streaming_repartition_write_with_operator_fusion(
     ray_start_regular_shared_2_cpus,
     tmp_path,
     disable_fallback_to_object_extension,
     streaming_repartition_first,
+    n_target_num_rows,
 ):
     """Test that write with streaming repartition produces exact partitions
     with operator fusion.
     This test verifies:
     * StreamingRepartition and MapBatches operators are fused, with both orders
     """
+    target_num_rows = 20
 
     def fn(batch):
         # Get number of rows from the first column (batch is a dict of column_name -> array)
         num_rows = len(batch["id"])
-        assert num_rows == 20, f"Expected batch size 20, got {num_rows}"
+        assert num_rows == b_s, f"Expected batch size {b_s}, got {num_rows}"
         return batch
 
     # Configure shuffle strategy
@@ -270,12 +273,13 @@ def fn(batch):
     ds = ds.repartition(target_num_rows_per_block=30)
 
     # Verify fusion of StreamingRepartition and MapBatches operators
+    b_s = target_num_rows * n_target_num_rows
     if streaming_repartition_first:
-        ds = ds.repartition(target_num_rows_per_block=20)
-        ds = ds.map_batches(fn, batch_size=20)
+        ds = ds.repartition(target_num_rows_per_block=target_num_rows)
+        ds = ds.map_batches(fn, batch_size=b_s)
     else:
-        ds = ds.map_batches(fn, batch_size=20)
-        ds = ds.repartition(target_num_rows_per_block=20)
+        ds = ds.map_batches(fn, batch_size=b_s)
+        ds = ds.repartition(target_num_rows_per_block=target_num_rows)
     planner = create_planner()
     physical_plan = planner.plan(ds._logical_plan)
     physical_plan = PhysicalOptimizer().optimize(physical_plan)
@@ -286,7 +290,7 @@ def fn(batch):
     else:
         assert (
             physical_op.name
-            == "MapBatches(fn)->StreamingRepartition[num_rows_per_block=20]"
+            == f"MapBatches(fn)->StreamingRepartition[num_rows_per_block={target_num_rows}]"
         )
 
     # Write output to local Parquet files partitioned by key
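Spelling out the arithmetic the parametrized e2e test relies on (values taken directly from the test above):

```python
target_num_rows = 20
for n_target_num_rows in (1, 5):                # the new parametrize values
    b_s = target_num_rows * n_target_num_rows  # batch sizes 20 and 100
    # Both batch sizes are multiples of target_num_rows, so the
    # MapBatches(fn) -> StreamingRepartition pair stays eligible for fusion
    # and fn still receives full batches of b_s rows.
    assert b_s % target_num_rows == 0
```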
