Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions litellm/llms/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
contains_tag,
)
import httpx
from urllib.parse import urlparse, urlunparse


class BedrockError(Exception):
Expand Down Expand Up @@ -541,6 +542,19 @@ def callback(request, **kwargs):
return callback


def change_url_for_cf(api_base: str):
"""Closure to change the host and path after signing."""
def callback(request, **kwargs):
"""Actual callback function that Boto3 will call."""
# botocore.awsrequest.AWSRequest
api_base_url = urlparse(api_base)
old_url = urlparse(request.url)
new_path = api_base_url.path.lstrip('/') + old_url.path
request.url = urlunparse((api_base_url.scheme, api_base_url.netloc, new_path, old_url.params, old_url.query, old_url.fragment))

return callback


def init_bedrock_client(
region_name=None,
aws_access_key_id: Optional[str] = None,
Expand Down Expand Up @@ -607,6 +621,11 @@ def init_bedrock_client(
else:
endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"

real_endpoint_url = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please can we not have 2 floating variables both with 'endpoint_url' :D

is it possible for us to have cloudflare logic in 'cloudflare.py' and just have that function wrap this?

having the cloudflare logic in here, looks like it complicates this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could enforce types on functions in here, so any wrapper function can always know what it's going to get

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, isn't having a floating variable the easiest way of doing this? :)

We could separate the logic out, but it's very Bedrock + Cloudflare AI Gateway specific. e.g. the code for Azure OpenAI + Cloudflare AI Gateway is totally different.

if "gateway.ai.cloudflare.com" in endpoint_url:
real_endpoint_url = endpoint_url
endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"

import boto3

if isinstance(timeout, float):
Expand Down Expand Up @@ -674,6 +693,11 @@ def init_bedrock_client(
if extra_headers:
client.meta.events.register('before-sign.bedrock-runtime.*', add_custom_header(extra_headers))

if real_endpoint_url:
client.meta.events.register(
"before-send.bedrock-runtime.*", change_url_for_cf(real_endpoint_url)
)

return client


Expand Down
35 changes: 35 additions & 0 deletions litellm/tests/test_bedrock_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,41 @@ def test_completion_bedrock_claude_completion_auth():

# test_completion_bedrock_claude_completion_auth()

def test_completion_bedrock_cloudflare_ai_gateway():
print("calling bedrock with cloudflare ai gateway")
import os

aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
aws_region_name = os.environ["AWS_REGION_NAME"]
aws_bedrock_runtime_endpoint = f"https://gateway.ai.cloudflare.com/v1/0399b10e77ac6668c80404a5ff49eb37/litellm-test/aws-bedrock/bedrock-runtime/{aws_region_name}"

os.environ.pop("AWS_ACCESS_KEY_ID", None)
os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
os.environ.pop("AWS_REGION_NAME", None)

try:
response = completion(
model="bedrock/amazon.titan-text-express-v1",
messages=messages,
max_tokens=10,
temperature=0.1,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint
)
# Add any assertions here to check the response
print(response)

os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
os.environ["AWS_REGION_NAME"] = aws_region_name
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")


def test_completion_bedrock_claude_2_1_completion_auth():
print("calling bedrock claude 2.1 completion params auth")
Expand Down