Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import aiohttp
import requests
from aiohttp.client_exceptions import ClientConnectorError
from asgiref.sync import sync_to_async
from requests import PreparedRequest, exceptions as requests_exceptions
from requests.auth import AuthBase, HTTPBasicAuth
from requests.exceptions import JSONDecodeError
Expand Down Expand Up @@ -122,6 +123,10 @@ def __init__(
self.token_timeout_seconds = 10
self.caller = caller

# Cache for lack of an async @cached_property
self._async_databricks_conn: Connection | None = None
self._async_host: str | None = None

def my_after_func(retry_state):
self._log_request_error(retry_state.attempt_number, retry_state.outcome)

Expand All @@ -141,6 +146,16 @@ def my_after_func(retry_state):
def databricks_conn(self) -> Connection:
    """Return the connection looked up via ``get_connection`` for ``databricks_conn_id``.

    NOTE(review): other code in this file reads ``self.databricks_conn`` as a
    plain attribute (e.g. ``self.databricks_conn.login``), so this method is
    presumably wrapped by a property-style decorator just above the visible
    lines -- confirm before calling it directly.
    """
    # The ignore silences the checker's complaint about the declared return type.
    return self.get_connection(self.databricks_conn_id) # type: ignore[return-value]

async def adatabricks_conn(self) -> Connection:
    """Return the Databricks connection, resolving it asynchronously on first use.

    The resolved object is stored on ``self._async_databricks_conn`` so later
    awaits hand back the stored value without another lookup. When the
    instance exposes an ``aget_connection`` coroutine it is awaited directly;
    otherwise the synchronous ``get_connection`` is wrapped with
    ``sync_to_async`` and awaited.
    """
    cached = self._async_databricks_conn
    if cached is not None:
        return cached

    if hasattr(self, "aget_connection"):
        fetched = await self.aget_connection(self.databricks_conn_id)
    else:
        # No native async lookup on this instance: run the sync one via asgiref.
        fetched = await sync_to_async(self.get_connection)(self.databricks_conn_id)

    self._async_databricks_conn = fetched
    return self._async_databricks_conn  # type: ignore[return-value]

def get_conn(self) -> Connection:
    """Return the value of ``self.databricks_conn``."""
    return self.databricks_conn

Expand All @@ -164,12 +179,16 @@ def user_agent_value(self) -> str:

@cached_property
def host(self) -> str:
    """Return the parsed Databricks host for this hook's connection.

    The raw value is taken from the ``host`` key of the connection's
    ``extra_dejson`` when that entry is present and non-empty, otherwise from
    the connection's own ``host`` field, and finally falls back to an empty
    string so ``_parse_host`` is never handed ``None``. This is the same
    resolution order used by the async counterpart ``ahost``.

    :return: the result of ``_parse_host`` applied to the selected raw host
    """
    # Single resolution path: the earlier if/else computation and its trailing
    # ``return host`` were redundant/unreachable next to this expression.
    raw_host = self.databricks_conn.extra_dejson.get("host") or self.databricks_conn.host or ""
    return self._parse_host(raw_host)

async def ahost(self) -> str:
    """Return the parsed Databricks host, resolving the connection asynchronously.

    The parsed value is stored on ``self._async_host`` after the first call
    and returned directly afterwards. The raw host comes from the ``host``
    entry of the connection's ``extra_dejson`` when it is non-empty, then from
    the connection's ``host`` field, and otherwise defaults to an empty string
    before being passed to ``_parse_host``.
    """
    if self._async_host is not None:
        return self._async_host

    connection = await self.adatabricks_conn()
    candidate = connection.extra_dejson.get("host")
    if not candidate:
        # Fall back to the connection's own host; never pass None to the parser.
        candidate = connection.host or ""

    self._async_host = self._parse_host(candidate)
    return self._async_host

async def __aenter__(self):
self._session = aiohttp.ClientSession()
Expand Down Expand Up @@ -232,7 +251,9 @@ def _get_sp_token(self, resource: str) -> str:
with attempt:
resp = requests.post(
resource,
auth=HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password),
auth=HTTPBasicAuth(
self.databricks_conn.login or "", self.databricks_conn.password or ""
),
data="grant_type=client_credentials&scope=all-apis",
headers={
**self.user_agent_header,
Expand Down Expand Up @@ -264,11 +285,12 @@ async def _a_get_sp_token(self, resource: str) -> str:

self.log.info("Existing Service Principal token is expired, or going to expire soon. Refreshing...")
try:
conn = await self.adatabricks_conn()
async for attempt in self._a_get_retry_object():
with attempt:
async with self._session.post(
resource,
auth=aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password),
auth=aiohttp.BasicAuth(conn.login or "", conn.password or ""),
data="grant_type=client_credentials&scope=all-apis",
headers={
**self.user_agent_header,
Expand Down Expand Up @@ -313,8 +335,8 @@ def _get_aad_token(self, resource: str) -> str:
token = ManagedIdentityCredential().get_token(f"{resource}/.default")
else:
credential = ClientSecretCredential(
client_id=self.databricks_conn.login,
client_secret=self.databricks_conn.password,
client_id=self.databricks_conn.login or "",
client_secret=self.databricks_conn.password or "",
tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
)
token = credential.get_token(f"{resource}/.default")
Expand Down Expand Up @@ -354,16 +376,17 @@ async def _a_get_aad_token(self, resource: str) -> str:
ManagedIdentityCredential as AsyncManagedIdentityCredential,
)

conn = await self.adatabricks_conn()
async for attempt in self._a_get_retry_object():
with attempt:
if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
if conn.extra_dejson.get("use_azure_managed_identity", False):
async with AsyncManagedIdentityCredential() as credential:
token = await credential.get_token(f"{resource}/.default")
else:
async with AsyncClientSecretCredential(
client_id=self.databricks_conn.login,
client_secret=self.databricks_conn.password,
tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
client_id=conn.login or "",
client_secret=conn.password or "",
tenant_id=conn.extra_dejson["azure_tenant_id"],
) as credential:
token = await credential.get_token(f"{resource}/.default")
jsn = {
Expand Down Expand Up @@ -493,11 +516,10 @@ async def _a_get_aad_headers(self) -> dict:
:return: dictionary with filled AAD headers
"""
headers = {}
if "azure_resource_id" in self.databricks_conn.extra_dejson:
conn = await self.adatabricks_conn()
if "azure_resource_id" in conn.extra_dejson:
mgmt_token = await self._a_get_aad_token(AZURE_MANAGEMENT_ENDPOINT)
headers["X-Databricks-Azure-Workspace-Resource-Id"] = self.databricks_conn.extra_dejson[
"azure_resource_id"
]
headers["X-Databricks-Azure-Workspace-Resource-Id"] = conn.extra_dejson["azure_resource_id"]
headers["X-Databricks-Azure-SP-Management-Token"] = mgmt_token
return headers

Expand Down Expand Up @@ -585,32 +607,33 @@ def _get_token(self, raise_error: bool = False) -> str | None:
return None

async def _a_get_token(self, raise_error: bool = False) -> str | None:
if "token" in self.databricks_conn.extra_dejson:
conn = await self.adatabricks_conn()
if "token" in conn.extra_dejson:
self.log.info(
"Using token auth. For security reasons, please set token in Password field instead of extra"
)
return self.databricks_conn.extra_dejson["token"]
if not self.databricks_conn.login and self.databricks_conn.password:
return conn.extra_dejson["token"]
if not conn.login and conn.password:
self.log.debug("Using token auth.")
return self.databricks_conn.password
if "azure_tenant_id" in self.databricks_conn.extra_dejson:
if self.databricks_conn.login == "" or self.databricks_conn.password == "":
return conn.password
if "azure_tenant_id" in conn.extra_dejson:
if conn.login == "" or conn.password == "":
raise AirflowException("Azure SPN credentials aren't provided")
self.log.debug("Using AAD Token for SPN.")
return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
if conn.extra_dejson.get("use_azure_managed_identity", False):
self.log.debug("Using AAD Token for managed identity.")
await self._a_check_azure_metadata_service()
return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
if self.databricks_conn.extra_dejson.get(DEFAULT_AZURE_CREDENTIAL_SETTING_KEY, False):
if conn.extra_dejson.get(DEFAULT_AZURE_CREDENTIAL_SETTING_KEY, False):
self.log.debug("Using AzureDefaultCredential for authentication.")

return await self._a_get_aad_token_for_default_az_credential(DEFAULT_DATABRICKS_SCOPE)
if self.databricks_conn.extra_dejson.get("service_principal_oauth", False):
if self.databricks_conn.login == "" or self.databricks_conn.password == "":
if conn.extra_dejson.get("service_principal_oauth", False):
if conn.login == "" or conn.password == "":
raise AirflowException("Service Principal credentials aren't provided")
self.log.debug("Using Service Principal Token.")
return await self._a_get_sp_token(OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host))
return await self._a_get_sp_token(OIDC_TOKEN_SERVICE_URL.format(conn.host))
if raise_error:
raise AirflowException("Token authentication isn't configured")

Expand All @@ -624,6 +647,12 @@ def _endpoint_url(self, endpoint):
schema = self.databricks_conn.schema or "https"
return f"{schema}://{self.host}{port}/{endpoint}"

async def _a_endpoint_url(self, endpoint):
    """Build the full URL for ``endpoint`` using the asynchronously resolved connection.

    The scheme is the connection's ``schema`` or ``"https"`` when that is
    empty, the host comes from ``ahost()``, and a ``:<port>`` suffix is added
    only when the connection defines a port.

    :param endpoint: path appended after the host (and optional port)
    :return: string of the form ``<schema>://<host>[:<port>]/<endpoint>``
    """
    connection = await self.adatabricks_conn()
    scheme = connection.schema if connection.schema else "https"
    port_suffix = f":{connection.port}" if connection.port else ""
    hostname = await self.ahost()
    return f"{scheme}://{hostname}{port_suffix}/{endpoint}"

def _do_api_call(
self,
endpoint_info: tuple[str, str],
Expand Down Expand Up @@ -654,7 +683,7 @@ def _do_api_call(
auth = _TokenAuth(token)
else:
self.log.info("Using basic auth.")
auth = HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password)
auth = HTTPBasicAuth(self.databricks_conn.login or "", self.databricks_conn.password or "")

request_func: Any
if method == "GET":
Expand Down Expand Up @@ -710,7 +739,7 @@ async def _a_do_api_call(self, endpoint_info: tuple[str, str], json: dict[str, A
method, endpoint = endpoint_info

full_endpoint = f"api/{endpoint}"
url = self._endpoint_url(full_endpoint)
url = await self._a_endpoint_url(full_endpoint)

aad_headers = await self._a_get_aad_headers()
headers = {**self.user_agent_header, **aad_headers}
Expand All @@ -721,7 +750,8 @@ async def _a_do_api_call(self, endpoint_info: tuple[str, str], json: dict[str, A
auth = BearerAuth(token)
else:
self.log.info("Using basic auth.")
auth = aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password)
conn = await self.adatabricks_conn()
auth = aiohttp.BasicAuth(conn.login or "", conn.password or "")

request_func: Any
if method == "GET":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
)
from airflow.providers.databricks.utils import databricks as utils

from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

TASK_ID = "databricks-operator"
DEFAULT_CONN_ID = "databricks_default"
NOTEBOOK_TASK = {"notebook_path": "/test"}
Expand Down Expand Up @@ -2226,7 +2228,14 @@ def setup_connections(self, create_connection_without_db):
@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.post")
async def test_get_run_state(self, mock_post, mock_get):
async def test_get_run_state(self, mock_post, mock_get, mock_supervisor_comms):
if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms:
from airflow.sdk.execution_time.comms import ConnectionResult
mock_supervisor_comms.asend.return_value = ConnectionResult(
conn_id=DEFAULT_CONN_ID,
conn_type="databricks",
)

mock_post.return_value.__aenter__.return_value.json = AsyncMock(
return_value=create_sp_token_for_resource()
)
Expand Down
Loading
Loading