Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import TYPE_CHECKING

from airflow.exceptions import AirflowException
from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_PLUS
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
Expand All @@ -40,7 +40,7 @@


def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
if AIRFLOW_V_3_1_PLUS:
if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator
else:
Expand Down
48 changes: 48 additions & 0 deletions providers/standard/tests/unit/standard/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,54 @@ def test_xcom_push_skipped_tasks(self):
"skipped": ["empty_task"]
}

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 2 implementation is different")
def test_short_circuit_operator_skips_sensors(self):
"""Test that ShortCircuitOperator properly skips sensors in Airflow 3.x."""
from airflow.sdk.bases.sensor import BaseSensorOperator

# Create a sensor similar to S3FileSensor to reproduce the issue
class CustomS3Sensor(BaseSensorOperator):
def __init__(self, bucket_name: str, object_key: str, **kwargs):
super().__init__(**kwargs)
self.bucket_name = bucket_name
self.object_key = object_key
self.timeout = 0
self.poke_interval = 0

def poke(self, context):
# Simulate sensor logic
return True

with self.dag_maker(self.dag_id):
# ShortCircuit that evaluates to False (should skip all downstream)
short_circuit = ShortCircuitOperator(
task_id="check_dis_is_mon_to_fri_not_holiday",
python_callable=lambda: False, # This causes skipping
)

sensor_task = CustomS3Sensor(
task_id="wait_for_ticker_to_secid_lookup_s3_file",
bucket_name="test-bucket",
object_key="ticker_to_secid_lookup.csv",
)

short_circuit >> sensor_task

dr = self.dag_maker.create_dagrun()

self.dag_maker.run_ti("check_dis_is_mon_to_fri_not_holiday", dr)

# Verify the sensor is included in the skip list by checking XCom
# (this was the bug - sensors were not being included in skip list)
tis = dr.get_task_instances()
xcom_data = tis[0].xcom_pull(task_ids="check_dis_is_mon_to_fri_not_holiday", key="skipmixin_key")

assert xcom_data is not None, "XCom data should exist"
skipped_task_ids = set(xcom_data.get("skipped", []))
assert "wait_for_ticker_to_secid_lookup_s3_file" in skipped_task_ids, (
"Sensor should be skipped by ShortCircuitOperator"
)


virtualenv_string_args: list[str] = []

Expand Down
81 changes: 81 additions & 0 deletions providers/standard/tests/unit/standard/utils/test_skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,84 @@ def test_raise_exception_on_not_valid_branch_task_ids(self, dag_maker, branch_ta
error_message = r"'branch_task_ids' must contain only valid task_ids. Invalid tasks found: .*"
with pytest.raises(AirflowException, match=error_message):
SkipMixin().skip_all_except(ti=ti1, branch_task_ids=branch_task_ids)

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Issue only exists in Airflow 3.x")
def test_ensure_tasks_includes_sensors_airflow_3x(self, dag_maker):
"""Test that sensors (inheriting from airflow.sdk.BaseOperator) are properly handled by _ensure_tasks."""
from airflow.providers.standard.utils.skipmixin import _ensure_tasks
from airflow.sdk import BaseOperator as SDKBaseOperator
from airflow.sdk.bases.sensor import BaseSensorOperator

class DummySensor(BaseSensorOperator):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.timeout = 0
self.poke_interval = 0

def poke(self, context):
return True

with dag_maker("dag_test_sensor_skipping") as dag:
regular_task = EmptyOperator(task_id="regular_task")
sensor_task = DummySensor(task_id="sensor_task")
downstream_task = EmptyOperator(task_id="downstream_task")

regular_task >> [sensor_task, downstream_task]

dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID)

downstream_nodes = dag.get_task("regular_task").downstream_list
task_list = _ensure_tasks(downstream_nodes)

# Verify both the regular operator and sensor are included
task_ids = [t.task_id for t in task_list]
assert "sensor_task" in task_ids, "Sensor should be included in task list"
assert "downstream_task" in task_ids, "Regular task should be included in task list"
assert len(task_list) == 2, "Both tasks should be included"

# Also verify that the sensor is actually an instance of the correct BaseOperator
sensor_in_list = next((t for t in task_list if t.task_id == "sensor_task"), None)
assert sensor_in_list is not None, "Sensor task should be found in list"
assert isinstance(sensor_in_list, SDKBaseOperator), "Sensor should be instance of SDK BaseOperator"

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Integration test for Airflow 3.x sensor skipping")
def test_skip_sensor_in_branching_scenario(self, dag_maker):
"""Integration test: verify sensors are properly skipped by branching operators in Airflow 3.x."""
from airflow.sdk.bases.sensor import BaseSensorOperator

# Create a dummy sensor for testing
class DummySensor(BaseSensorOperator):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.timeout = 0
self.poke_interval = 0

def poke(self, context):
return True

with dag_maker("dag_test_branch_sensor_skipping"):
branch_task = EmptyOperator(task_id="branch_task")
regular_task = EmptyOperator(task_id="regular_task")
sensor_task = DummySensor(task_id="sensor_task")
branch_task >> [regular_task, sensor_task]

dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID)

dag_version = DagVersion.get_latest_version(branch_task.dag_id)
ti_branch = TI(branch_task, run_id=DEFAULT_DAG_RUN_ID, dag_version_id=dag_version.id)

# Test skipping the sensor (follow regular_task branch)
with pytest.raises(DownstreamTasksSkipped) as exc_info:
SkipMixin().skip_all_except(ti=ti_branch, branch_task_ids="regular_task")

# Verify that the sensor task is properly marked for skipping
skipped_tasks = set(exc_info.value.tasks)
assert ("sensor_task", -1) in skipped_tasks, "Sensor task should be marked for skipping"

# Test skipping the regular task (follow sensor_task branch)
with pytest.raises(DownstreamTasksSkipped) as exc_info:
SkipMixin().skip_all_except(ti=ti_branch, branch_task_ids="sensor_task")

# Verify that the regular task is properly marked for skipping
skipped_tasks = set(exc_info.value.tasks)
assert ("regular_task", -1) in skipped_tasks, "Regular task should be marked for skipping"