Skip to content

Commit 3fe300c

Browse files
authored
Remove CoalescePartitions insertion from Joins (apache#15570)
* refactor: Catch PartitionMode:Auto in execute explicitly
* refactor(hash_join): Move coalesce logic into function
* refactor(hash_join): Move coalesce logic out of collect_left_input
* refactor(hash_join): Execute build side earlier
* chore(hash_join): Drop unnecessary clippy hint \o/
* Back out "refactor: Catch PartitionMode:Auto in execute explicitly"
  This backs out commit 4346cb8.
* Remove CoalescePartitions from CollectLeft-HashJoins
* Fix imports
* Fix tests
* Check CollectLeft-HJ distribution in execute
* Fix repeated execute; propagate execute error
* Remove CoalescePartitions from nested_loop_join
* Execute left side earlier in nested_loop_join
* Remove CoalescePartitions from cross_join
* Execute left side earlier in cross_join
* Remove unused import
* Fix docs
* Remove dead comment
* Throw error on unexpected distribution
1 parent 8a6e98d commit 3fe300c

File tree

4 files changed

+66
-68
lines changed

4 files changed

+66
-68
lines changed

datafusion/physical-plan/src/joins/cross_join.rs

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ use super::utils::{
2525
BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut,
2626
StatefulStreamResult,
2727
};
28-
use crate::coalesce_partitions::CoalescePartitionsExec;
2928
use crate::execution_plan::{boundedness_from_children, EmissionType};
3029
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
3130
use crate::projection::{
@@ -189,19 +188,11 @@ impl CrossJoinExec {
189188

190189
/// Asynchronously collect the result of the left child
191190
async fn load_left_input(
192-
left: Arc<dyn ExecutionPlan>,
193-
context: Arc<TaskContext>,
191+
stream: SendableRecordBatchStream,
194192
metrics: BuildProbeJoinMetrics,
195193
reservation: MemoryReservation,
196194
) -> Result<JoinLeftData> {
197-
// merge all left parts into a single stream
198-
let left_schema = left.schema();
199-
let merge = if left.output_partitioning().partition_count() != 1 {
200-
Arc::new(CoalescePartitionsExec::new(left))
201-
} else {
202-
left
203-
};
204-
let stream = merge.execute(0, context)?;
195+
let left_schema = stream.schema();
205196

206197
// Load all batches and count the rows
207198
let (batches, _metrics, reservation) = stream
@@ -291,6 +282,13 @@ impl ExecutionPlan for CrossJoinExec {
291282
partition: usize,
292283
context: Arc<TaskContext>,
293284
) -> Result<SendableRecordBatchStream> {
285+
if self.left.output_partitioning().partition_count() != 1 {
286+
return internal_err!(
287+
"Invalid CrossJoinExec, the output partition count of the left child must be 1,\
288+
consider using CoalescePartitionsExec or the EnforceDistribution rule"
289+
);
290+
}
291+
294292
let stream = self.right.execute(partition, Arc::clone(&context))?;
295293

296294
let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
@@ -303,14 +301,15 @@ impl ExecutionPlan for CrossJoinExec {
303301
let enforce_batch_size_in_joins =
304302
context.session_config().enforce_batch_size_in_joins();
305303

306-
let left_fut = self.left_fut.once(|| {
307-
load_left_input(
308-
Arc::clone(&self.left),
309-
context,
304+
let left_fut = self.left_fut.try_once(|| {
305+
let left_stream = self.left.execute(0, context)?;
306+
307+
Ok(load_left_input(
308+
left_stream,
310309
join_metrics.clone(),
311310
reservation,
312-
)
313-
});
311+
))
312+
})?;
314313

315314
if enforce_batch_size_in_joins {
316315
Ok(Box::pin(CrossJoinStream {

datafusion/physical-plan/src/joins/hash_join.rs

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ use crate::projection::{
4141
use crate::spill::get_record_batch_memory_size;
4242
use crate::ExecutionPlanProperties;
4343
use crate::{
44-
coalesce_partitions::CoalescePartitionsExec,
4544
common::can_project,
4645
handle_state,
4746
hash_utils::create_hashes,
@@ -792,34 +791,42 @@ impl ExecutionPlan for HashJoinExec {
792791
);
793792
}
794793

794+
if self.mode == PartitionMode::CollectLeft && left_partitions != 1 {
795+
return internal_err!(
796+
"Invalid HashJoinExec, the output partition count of the left child must be 1 in CollectLeft mode,\
797+
consider using CoalescePartitionsExec or the EnforceDistribution rule"
798+
);
799+
}
800+
795801
let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
796802
let left_fut = match self.mode {
797-
PartitionMode::CollectLeft => self.left_fut.once(|| {
803+
PartitionMode::CollectLeft => self.left_fut.try_once(|| {
804+
let left_stream = self.left.execute(0, Arc::clone(&context))?;
805+
798806
let reservation =
799807
MemoryConsumer::new("HashJoinInput").register(context.memory_pool());
800-
collect_left_input(
801-
None,
808+
809+
Ok(collect_left_input(
802810
self.random_state.clone(),
803-
Arc::clone(&self.left),
811+
left_stream,
804812
on_left.clone(),
805-
Arc::clone(&context),
806813
join_metrics.clone(),
807814
reservation,
808815
need_produce_result_in_final(self.join_type),
809816
self.right().output_partitioning().partition_count(),
810-
)
811-
}),
817+
))
818+
})?,
812819
PartitionMode::Partitioned => {
820+
let left_stream = self.left.execute(partition, Arc::clone(&context))?;
821+
813822
let reservation =
814823
MemoryConsumer::new(format!("HashJoinInput[{partition}]"))
815824
.register(context.memory_pool());
816825

817826
OnceFut::new(collect_left_input(
818-
Some(partition),
819827
self.random_state.clone(),
820-
Arc::clone(&self.left),
828+
left_stream,
821829
on_left.clone(),
822-
Arc::clone(&context),
823830
join_metrics.clone(),
824831
reservation,
825832
need_produce_result_in_final(self.join_type),
@@ -930,36 +937,22 @@ impl ExecutionPlan for HashJoinExec {
930937

931938
/// Reads the left (build) side of the input, buffering it in memory, to build a
932939
/// hash table (`JoinLeftData`)
933-
#[allow(clippy::too_many_arguments)]
934940
async fn collect_left_input(
935-
partition: Option<usize>,
936941
random_state: RandomState,
937-
left: Arc<dyn ExecutionPlan>,
942+
left_stream: SendableRecordBatchStream,
938943
on_left: Vec<PhysicalExprRef>,
939-
context: Arc<TaskContext>,
940944
metrics: BuildProbeJoinMetrics,
941945
reservation: MemoryReservation,
942946
with_visited_indices_bitmap: bool,
943947
probe_threads_count: usize,
944948
) -> Result<JoinLeftData> {
945-
let schema = left.schema();
946-
947-
let (left_input, left_input_partition) = if let Some(partition) = partition {
948-
(left, partition)
949-
} else if left.output_partitioning().partition_count() != 1 {
950-
(Arc::new(CoalescePartitionsExec::new(left)) as _, 0)
951-
} else {
952-
(left, 0)
953-
};
954-
955-
// Depending on partition argument load single partition or whole left side in memory
956-
let stream = left_input.execute(left_input_partition, Arc::clone(&context))?;
949+
let schema = left_stream.schema();
957950

958951
// This operation performs 2 steps at once:
959952
// 1. creates a [JoinHashMap] of all batches from the stream
960953
// 2. stores the batches in a vector.
961954
let initial = (Vec::new(), 0, metrics, reservation);
962-
let (batches, num_rows, metrics, mut reservation) = stream
955+
let (batches, num_rows, metrics, mut reservation) = left_stream
963956
.try_fold(initial, |mut acc, batch| async {
964957
let batch_size = get_record_batch_memory_size(&batch);
965958
// Reserve memory for incoming batch
@@ -1655,6 +1648,7 @@ impl EmbeddedProjection for HashJoinExec {
16551648
#[cfg(test)]
16561649
mod tests {
16571650
use super::*;
1651+
use crate::coalesce_partitions::CoalescePartitionsExec;
16581652
use crate::test::TestMemoryExec;
16591653
use crate::{
16601654
common, expressions::Column, repartition::RepartitionExec, test::build_table_i32,
@@ -2105,6 +2099,7 @@ mod tests {
21052099
let left =
21062100
TestMemoryExec::try_new_exec(&[vec![batch1], vec![batch2]], schema, None)
21072101
.unwrap();
2102+
let left = Arc::new(CoalescePartitionsExec::new(left));
21082103

21092104
let right = build_table(
21102105
("a1", &vec![1, 2, 3]),
@@ -2177,6 +2172,7 @@ mod tests {
21772172
let left =
21782173
TestMemoryExec::try_new_exec(&[vec![batch1], vec![batch2]], schema, None)
21792174
.unwrap();
2175+
let left = Arc::new(CoalescePartitionsExec::new(left));
21802176
let right = build_table(
21812177
("a2", &vec![20, 30, 10]),
21822178
("b2", &vec![5, 6, 4]),

datafusion/physical-plan/src/joins/nested_loop_join.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ use super::utils::{
2828
need_produce_result_in_final, reorder_output_after_swap, swap_join_projection,
2929
BatchSplitter, BatchTransformer, NoopBatchTransformer, StatefulStreamResult,
3030
};
31-
use crate::coalesce_partitions::CoalescePartitionsExec;
3231
use crate::common::can_project;
3332
use crate::execution_plan::{boundedness_from_children, EmissionType};
3433
use crate::joins::utils::{
@@ -483,23 +482,31 @@ impl ExecutionPlan for NestedLoopJoinExec {
483482
partition: usize,
484483
context: Arc<TaskContext>,
485484
) -> Result<SendableRecordBatchStream> {
485+
if self.left.output_partitioning().partition_count() != 1 {
486+
return internal_err!(
487+
"Invalid NestedLoopJoinExec, the output partition count of the left child must be 1,\
488+
consider using CoalescePartitionsExec or the EnforceDistribution rule"
489+
);
490+
}
491+
486492
let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
487493

488494
// Initialize the memory reservation for loading the inner table
489495
let load_reservation =
490496
MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]"))
491497
.register(context.memory_pool());
492498

493-
let inner_table = self.inner_table.once(|| {
494-
collect_left_input(
495-
Arc::clone(&self.left),
496-
Arc::clone(&context),
499+
let inner_table = self.inner_table.try_once(|| {
500+
let stream = self.left.execute(0, Arc::clone(&context))?;
501+
502+
Ok(collect_left_input(
503+
stream,
497504
join_metrics.clone(),
498505
load_reservation,
499506
need_produce_result_in_final(self.join_type),
500507
self.right().output_partitioning().partition_count(),
501-
)
502-
});
508+
))
509+
})?;
503510

504511
let batch_size = context.session_config().batch_size();
505512
let enforce_batch_size_in_joins =
@@ -610,20 +617,13 @@ impl ExecutionPlan for NestedLoopJoinExec {
610617

611618
/// Asynchronously collects input into a single batch and creates `JoinLeftData` from it
612619
async fn collect_left_input(
613-
input: Arc<dyn ExecutionPlan>,
614-
context: Arc<TaskContext>,
620+
stream: SendableRecordBatchStream,
615621
join_metrics: BuildProbeJoinMetrics,
616622
reservation: MemoryReservation,
617623
with_visited_left_side: bool,
618624
probe_threads_count: usize,
619625
) -> Result<JoinLeftData> {
620-
let schema = input.schema();
621-
let merge = if input.output_partitioning().partition_count() != 1 {
622-
Arc::new(CoalescePartitionsExec::new(input))
623-
} else {
624-
input
625-
};
626-
let stream = merge.execute(0, context)?;
626+
let schema = stream.schema();
627627

628628
// Load all batches and count the rows
629629
let (batches, metrics, mut reservation) = stream

datafusion/physical-plan/src/joins/utils.rs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ pub fn build_join_schema(
324324
}
325325

326326
/// A [`OnceAsync`] runs an `async` closure once, where multiple calls to
327-
/// [`OnceAsync::once`] return a [`OnceFut`] that resolves to the result of the
327+
/// [`OnceAsync::try_once`] return a [`OnceFut`] that resolves to the result of the
328328
/// same computation.
329329
///
330330
/// This is useful for joins where the results of one child are needed to proceed
@@ -337,7 +337,7 @@ pub fn build_join_schema(
337337
///
338338
/// Each output partition waits on the same `OnceAsync` before proceeding.
339339
pub(crate) struct OnceAsync<T> {
340-
fut: Mutex<Option<OnceFut<T>>>,
340+
fut: Mutex<Option<SharedResult<OnceFut<T>>>>,
341341
}
342342

343343
impl<T> Default for OnceAsync<T> {
@@ -356,19 +356,22 @@ impl<T> Debug for OnceAsync<T> {
356356

357357
impl<T: 'static> OnceAsync<T> {
358358
/// If this is the first call to this function on this object, will invoke
359-
/// `f` to obtain a future and return a [`OnceFut`] referring to this
359+
/// `f` to obtain a future and return a [`OnceFut`] referring to it. `f`
360+
/// may fail, in which case its error is returned.
360361
///
361362
/// If this is not the first call, will return a [`OnceFut`] referring
362-
/// to the same future as was returned by the first call
363-
pub(crate) fn once<F, Fut>(&self, f: F) -> OnceFut<T>
363+
/// to the same future as was returned by the first call - or the same
364+
/// error if the initial call to `f` failed.
365+
pub(crate) fn try_once<F, Fut>(&self, f: F) -> Result<OnceFut<T>>
364366
where
365-
F: FnOnce() -> Fut,
367+
F: FnOnce() -> Result<Fut>,
366368
Fut: Future<Output = Result<T>> + Send + 'static,
367369
{
368370
self.fut
369371
.lock()
370-
.get_or_insert_with(|| OnceFut::new(f()))
372+
.get_or_insert_with(|| f().map(OnceFut::new).map_err(Arc::new))
371373
.clone()
374+
.map_err(DataFusionError::Shared)
372375
}
373376
}
374377

0 commit comments

Comments
 (0)