Skip to content

Commit 2f9747f

Browse files
committed
ManagedIdentityClient sends xms_cc and token_sha256_to_refresh to SF
1 parent 30dce4e commit 2f9747f

File tree

2 files changed

+106
-19
lines changed

2 files changed

+106
-19
lines changed

msal/managed_identity.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# All rights reserved.
33
#
44
# This code is licensed under the MIT License.
5+
import hashlib
56
import json
67
import logging
78
import os
@@ -10,7 +11,7 @@
1011
import time
1112
from urllib.parse import urlparse # Python 3+
1213
from collections import UserDict # Python 3+
13-
from typing import Optional, Union # Needed in Python 3.7 & 3.8
14+
from typing import List, Optional, Union # Needed in Python 3.7 & 3.8
1415
from .token_cache import TokenCache
1516
from .individual_cache import _IndividualCache as IndividualCache
1617
from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser
@@ -162,6 +163,7 @@ def __init__(
162163
http_client,
163164
token_cache=None,
164165
http_cache=None,
166+
client_capabilities: Optional[List[str]] = None,
165167
):
166168
"""Create a managed identity client.
167169
@@ -192,6 +194,17 @@ def __init__(
192194
Optional. It has the same characteristics as the
193195
:paramref:`msal.ClientApplication.http_cache`.
194196
197+
:param list[str] client_capabilities: (optional)
198+
Allows configuration of one or more client capabilities, e.g. ["CP1"].
199+
200+
Client capability is meant to inform the Microsoft identity platform
201+
(STS) what this client is capable for,
202+
so STS can decide to turn on certain features.
203+
204+
Implementation details:
205+
Client capability in Managed Identity is relayed as-is
206+
via ``xms_cc`` parameter on the wire.
207+
195208
Recipe 1: Hard code a managed identity for your app::
196209
197210
import msal, requests
@@ -238,6 +251,7 @@ def __init__(
238251
http_cache=http_cache,
239252
)
240253
self._token_cache = token_cache or TokenCache()
254+
self._client_capabilities = client_capabilities
241255

242256
def _get_instance(self):
243257
if self.__instance is None:
@@ -266,8 +280,7 @@ def acquire_token_for_client(
266280
and then a *claims challenge* will be returned by the target resource,
267281
as a `claims_challenge` directive in the `www-authenticate` header,
268282
even if the app developer did not opt in for the "CP1" client capability.
269-
Upon receiving a `claims_challenge`, MSAL will skip a token cache read,
270-
and will attempt to acquire a new token.
283+
Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token.
271284
272285
.. note::
273286
@@ -278,11 +291,13 @@ def acquire_token_for_client(
278291
This is a service-side behavior that cannot be changed by this library.
279292
`Azure VM docs <https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>`_
280293
"""
294+
access_token_to_refresh = None # This could become a public parameter in the future
281295
access_token_from_cache = None
282296
client_id_in_cache = self._managed_identity.get(
283297
ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY")
284298
now = time.time()
285-
if not claims_challenge: # Then attempt token cache search
299+
if True: # Attempt cache search even if receiving claims_challenge,
300+
# because we want to locate the existing token (if any) and refresh it
286301
matches = self._token_cache.find(
287302
self._token_cache.CredentialType.ACCESS_TOKEN,
288303
target=[resource],
@@ -297,6 +312,11 @@ def acquire_token_for_client(
297312
expires_in = int(entry["expires_on"]) - now
298313
if expires_in < 5*60: # Then consider it expired
299314
continue # Removal is not necessary, it will be overwritten
315+
if claims_challenge and not access_token_to_refresh:
316+
# Since caller did not pinpoint the token causing claims challenge,
317+
# we have to assume it is the first token we found in cache.
318+
access_token_to_refresh = entry["secret"]
319+
break
300320
logger.debug("Cache hit an AT")
301321
access_token_from_cache = { # Mimic a real response
302322
"access_token": entry["secret"],
@@ -310,7 +330,13 @@ def acquire_token_for_client(
310330
break # With a fallback in hand, we break here to go refresh
311331
return access_token_from_cache # It is still good as new
312332
try:
313-
result = _obtain_token(self._http_client, self._managed_identity, resource)
333+
result = _obtain_token(
334+
self._http_client, self._managed_identity, resource,
335+
access_token_sha256_to_refresh=hashlib.sha256(
336+
access_token_to_refresh.encode("utf-8")).hexdigest()
337+
if access_token_to_refresh else None,
338+
client_capabilities=self._client_capabilities,
339+
)
314340
if "access_token" in result:
315341
expires_in = result.get("expires_in", 3600)
316342
if "refresh_in" not in result and expires_in >= 7200:
@@ -385,8 +411,12 @@ def get_managed_identity_source():
385411
return DEFAULT_TO_VM
386412

387413

388-
def _obtain_token(http_client, managed_identity, resource):
389-
# A unified low-level API that talks to different Managed Identity
414+
def _obtain_token(
415+
http_client, managed_identity, resource,
416+
*,
417+
access_token_sha256_to_refresh: Optional[str] = None,
418+
client_capabilities: Optional[List[str]] = None,
419+
):
390420
if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ
391421
and "IDENTITY_SERVER_THUMBPRINT" in os.environ
392422
):
@@ -402,6 +432,8 @@ def _obtain_token(http_client, managed_identity, resource):
402432
os.environ["IDENTITY_HEADER"],
403433
os.environ["IDENTITY_SERVER_THUMBPRINT"],
404434
resource,
435+
access_token_sha256_to_refresh=access_token_sha256_to_refresh,
436+
client_capabilities=client_capabilities,
405437
)
406438
if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ:
407439
return _obtain_token_on_app_service(
@@ -553,6 +585,9 @@ def _obtain_token_on_machine_learning(
553585

554586
def _obtain_token_on_service_fabric(
555587
http_client, endpoint, identity_header, server_thumbprint, resource,
588+
*,
589+
access_token_sha256_to_refresh: str = None,
590+
client_capabilities: Optional[List[str]] = None,
556591
):
557592
"""Obtains token for
558593
`Service Fabric <https://learn.microsoft.com/en-us/azure/service-fabric/>`_
@@ -563,7 +598,12 @@ def _obtain_token_on_service_fabric(
563598
logger.debug("Obtaining token via managed identity on Azure Service Fabric")
564599
resp = http_client.get(
565600
endpoint,
566-
params={"api-version": "2019-07-01-preview", "resource": resource},
601+
params={k: v for k, v in {
602+
"api-version": "2019-07-01-preview",
603+
"resource": resource,
604+
"token_sha256_to_refresh": access_token_sha256_to_refresh,
605+
"xms_cc": ",".join(client_capabilities) if client_capabilities else None,
606+
}.items() if v is not None},
567607
headers={"Secret": identity_header},
568608
)
569609
try:
@@ -584,7 +624,7 @@ def _obtain_token_on_service_fabric(
584624
"ArgumentNullOrEmpty": "invalid_scope",
585625
}
586626
return {
587-
"error": error_mapping.get(payload["error"]["code"], "invalid_request"),
627+
"error": error_mapping.get(error.get("code"), "invalid_request"),
588628
"error_description": resp.text,
589629
}
590630
except json.decoder.JSONDecodeError:

tests/test_mi.py

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import hashlib
12
import json
23
import os
34
import sys
45
import time
6+
from typing import List, Optional
57
import unittest
68
try:
79
from unittest.mock import patch, ANY, mock_open, Mock
@@ -52,15 +54,23 @@ def test_helper_class_should_be_interchangable_with_dict_which_could_be_loaded_f
5254
class ClientTestCase(unittest.TestCase):
5355
maxDiff = None
5456

55-
def setUp(self):
56-
self.app = ManagedIdentityClient(
57+
def _build_app(
58+
self,
59+
*,
60+
client_capabilities: Optional[List[str]] = None,
61+
):
62+
return ManagedIdentityClient(
5763
{ # Here we test it with the raw dict form, to test that
5864
# the client has no hard dependency on ManagedIdentity object
5965
"ManagedIdentityIdType": "SystemAssigned", "Id": None,
6066
},
6167
http_client=requests.Session(),
68+
client_capabilities=client_capabilities,
6269
)
6370

71+
def setUp(self):
72+
self.app = self._build_app()
73+
6474
def test_error_out_on_invalid_input(self):
6575
with self.assertRaises(ManagedIdentityError):
6676
ManagedIdentityClient({"foo": "bar"}, http_client=requests.Session())
@@ -79,7 +89,13 @@ def assertCacheStatus(self, app):
7989
"Should have expected client_id")
8090
self.assertEqual("managed_identity", at["realm"], "Should have expected realm")
8191

82-
def _test_happy_path(self, app, mocked_http, expires_in, resource="R"):
92+
def _test_happy_path(
93+
self, app, mocked_http, expires_in, *, resource="R", claims_challenge=None,
94+
):
95+
"""It tests a normal token request that is expected to hit IdP,
96+
a subsequent same token request that is expected to hit cache,
97+
and then a request with claims_challenge that shall hit IdP again.
98+
"""
8399
result = app.acquire_token_for_client(resource=resource)
84100
mocked_http.assert_called()
85101
call_count = mocked_http.call_count
@@ -115,7 +131,8 @@ def _test_happy_path(self, app, mocked_http, expires_in, resource="R"):
115131
expected_refresh_on - 5 < result["refresh_on"] <= expected_refresh_on,
116132
"Should have a refresh_on time around the middle of the token's life")
117133

118-
result = app.acquire_token_for_client(resource=resource, claims_challenge="foo")
134+
result = app.acquire_token_for_client(
135+
resource=resource, claims_challenge=claims_challenge or "placeholder")
119136
self.assertEqual("identity_provider", result["token_source"], "Should miss cache")
120137

121138

@@ -132,6 +149,9 @@ def _test_happy_path(self) -> callable:
132149

133150
def test_happy_path_of_vm(self):
134151
self._test_happy_path().assert_called_with(
152+
# The last call contained claims_challenge
153+
# but since IMDS doesn't support token_sha256_to_refresh,
154+
# the request shall remain the same as before
135155
'http://169.254.169.254/metadata/identity/oauth2/token',
136156
params={'api-version': '2018-02-01', 'resource': 'R'},
137157
headers={'Metadata': 'true'},
@@ -244,19 +264,46 @@ def test_machine_learning_error_should_be_normalized(self):
244264
"IDENTITY_SERVER_THUMBPRINT": "bar",
245265
})
246266
class ServiceFabricTestCase(ClientTestCase):
267+
access_token = "AT"
268+
access_token_sha256 = hashlib.sha256(access_token.encode()).hexdigest()
247269

248-
def _test_happy_path(self, app):
270+
def _test_happy_path(self, app, *, claims_challenge=None) -> callable:
249271
expires_in = 1234
250272
with patch.object(app._http_client, "get", return_value=MinimalResponse(
251273
status_code=200,
252-
text='{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % (
253-
int(time.time()) + expires_in),
274+
text='{"access_token": "%s", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % (
275+
self.access_token, int(time.time()) + expires_in),
254276
)) as mocked_method:
255277
super(ServiceFabricTestCase, self)._test_happy_path(
256-
app, mocked_method, expires_in)
278+
app, mocked_method, expires_in, claims_challenge=claims_challenge)
279+
return mocked_method
257280

258-
def test_happy_path(self):
259-
self._test_happy_path(self.app)
281+
def test_happy_path_with_client_capabilities_should_relay_capabilities(self):
282+
self._test_happy_path(self._build_app(client_capabilities=["foo", "bar"])).assert_called_with(
283+
'http://localhost',
284+
params={
285+
'api-version': '2019-07-01-preview',
286+
'resource': 'R',
287+
'token_sha256_to_refresh': self.access_token_sha256,
288+
"xms_cc": "foo,bar",
289+
},
290+
headers={'Secret': 'foo'},
291+
)
292+
293+
def test_happy_path_with_claim_challenge_should_send_sha256_to_provider(self):
294+
self._test_happy_path(
295+
self._build_app(client_capabilities=[]), # Test empty client_capabilities
296+
claims_challenge='{"access_token": {"nbf": {"essential": true, "value": "1563308371"}}}',
297+
).assert_called_with(
298+
'http://localhost',
299+
params={
300+
'api-version': '2019-07-01-preview',
301+
'resource': 'R',
302+
'token_sha256_to_refresh': self.access_token_sha256,
303+
# There is no xms_cc in this case
304+
},
305+
headers={'Secret': 'foo'},
306+
)
260307

261308
def test_unified_api_service_should_ignore_unnecessary_client_id(self):
262309
self._test_happy_path(ManagedIdentityClient(

0 commit comments

Comments
 (0)