Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
9fbc518
feat(event_handler): Support multiple origins for cors
michaelbrewer Feb 7, 2022
7d8aa07
Merge branch 'awslabs:develop' into feat-1006
michaelbrewer Feb 9, 2022
a1d73df
Merge branch 'awslabs:develop' into feat-1006
michaelbrewer Feb 9, 2022
e23462d
Merge branch 'awslabs:develop' into feat-1006
michaelbrewer Feb 14, 2022
45a4376
Merge branch 'awslabs:develop' into feat-1006
michaelbrewer Feb 15, 2022
672f84b
Merge branch 'awslabs:develop' into feat-1006
michaelbrewer Feb 21, 2022
180111d
Merge branch 'awslabs:develop' into feat-1006
michaelbrewer Feb 25, 2022
db14220
Merge branch 'awslabs:develop' into feat-1006
michaelbrewer Mar 6, 2022
e2c9daa
Merge branch 'awslabs:develop' into feat-1006
michaelbrewer Mar 8, 2022
f7bd782
Merge branch 'awslabs:develop' into feat-1006
michaelbrewer Mar 9, 2022
ef41dbc
Merge branch 'awslabs:develop' into feat-1006
michaelbrewer Mar 10, 2022
587f447
Merge branch 'awslabs:develop' into feat-1006
michaelbrewer Apr 6, 2022
e653753
Merge branch 'develop' into feat-1006
michaelbrewer May 16, 2022
87c1da0
Merge branch 'develop' into feat-1006
michaelbrewer May 18, 2022
026fdd5
Merge branch 'develop' into feat-1006
michaelbrewer May 19, 2022
3bdb928
Merge branch 'develop' into feat-1006
michaelbrewer May 20, 2022
11f7289
Merge branch 'develop' into feat-1006
michaelbrewer May 21, 2022
1adce36
Merge branch 'develop' into feat-1006
michaelbrewer Jun 2, 2022
cb4870a
Merge branch 'develop' into feat-1006
michaelbrewer Jun 7, 2022
8ff8ac3
Merge branch 'develop' into feat-1006
michaelbrewer Jun 13, 2022
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 41 additions & 13 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
_SAFE_URI = "-._~()'!*:@,;" # https://www.ietf.org/rfc/rfc3986.txt
# API GW/ALB decode non-safe URI chars; we must support them too
_UNSAFE_URI = "%<> \[\]{}|^" # noqa: W605
_NAMED_GROUP_BOUNDARY_PATTERN = fr"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)"
_NAMED_GROUP_BOUNDARY_PATTERN = rf"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)"


class ProxyEventType(Enum):
Expand Down Expand Up @@ -92,6 +92,7 @@ def __init__(
expose_headers: Optional[List[str]] = None,
max_age: Optional[int] = None,
allow_credentials: bool = False,
allow_origins: Optional[List[str]] = None,
):
"""
Parameters
Expand All @@ -111,15 +112,16 @@ def __init__(
A boolean value that sets the value of `Access-Control-Allow-Credentials`
"""
self.allow_origin = allow_origin
self.allow_origins = allow_origins
self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or []))
self.expose_headers = expose_headers or []
self.max_age = max_age
self.allow_credentials = allow_credentials

def to_dict(self) -> Dict[str, str]:
def to_dict(self, current_event: Optional[BaseProxyEvent]) -> Dict[str, str]:
"""Builds the configured Access-Control http headers"""
headers = {
"Access-Control-Allow-Origin": self.allow_origin,
"Access-Control-Allow-Origin": self._allow_origin(current_event),
"Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)),
}
if self.expose_headers:
Expand All @@ -130,6 +132,16 @@ def to_dict(self) -> Dict[str, str]:
headers["Access-Control-Allow-Credentials"] = "true"
return headers

def _allow_origin(self, current_event: Optional[BaseProxyEvent]) -> str:
if self.allow_origins is None or current_event is None:
return self.allow_origin

origin = current_event.get_header_value("origin", "")
if origin in self.allow_origins:
return origin

return self.allow_origin


class Response:
"""Response data class that provides greater control over what is returned from the proxy event"""
Expand Down Expand Up @@ -180,13 +192,19 @@ def __init__(
class ResponseBuilder:
"""Internally used Response builder"""

def __init__(self, response: Response, route: Optional[Route] = None):
def __init__(
self,
response: Response,
route: Optional[Route] = None,
current_event: Optional[BaseProxyEvent] = None,
):
self.response = response
self.route = route
self.current_event = current_event

def _add_cors(self, cors: CORSConfig):
"""Update headers to include the configured Access-Control headers"""
self.response.headers.update(cors.to_dict())
self.response.headers.update(cors.to_dict(current_event=self.current_event))

def _add_cache_control(self, cache_control: str):
"""Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used."""
Expand Down Expand Up @@ -590,30 +608,34 @@ def _not_found(self, method: str) -> ResponseBuilder:
headers = {}
if self._cors:
logger.debug("CORS is enabled, updating headers.")
headers.update(self._cors.to_dict())
headers.update(self._cors.to_dict(current_event=self.current_event))

if method == "OPTIONS":
logger.debug("Pre-flight request detected. Returning CORS with null response")
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None))
return ResponseBuilder(
Response(status_code=204, content_type=None, headers=headers, body=None),
current_event=self.current_event,
)

handler = self._lookup_exception_handler(NotFoundError)
if handler:
return ResponseBuilder(handler(NotFoundError()))
return ResponseBuilder(handler(NotFoundError()), current_event=self.current_event)

return ResponseBuilder(
Response(
status_code=HTTPStatus.NOT_FOUND.value,
content_type=content_types.APPLICATION_JSON,
headers=headers,
body=self._json_dump({"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"}),
)
),
current_event=self.current_event,
)

def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
"""Actually call the matching route with any provided keyword arguments."""
try:
return ResponseBuilder(self._to_response(route.func(**args)), route)
return ResponseBuilder(self._to_response(route.func(**args)), route, self.current_event)
except Exception as exc:
response_builder = self._call_exception_handler(exc, route)
if response_builder:
Expand All @@ -630,6 +652,7 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
body="".join(traceback.format_exc()),
),
route,
self.current_event,
)

raise
Expand Down Expand Up @@ -657,18 +680,23 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp
handler = self._lookup_exception_handler(type(exp))
if handler:
try:
return ResponseBuilder(handler(exp), route)
return ResponseBuilder(
handler(exp),
route=route,
current_event=self.current_event,
)
except ServiceError as service_error:
exp = service_error

if isinstance(exp, ServiceError):
return ResponseBuilder(
Response(
response=Response(
status_code=exp.status_code,
content_type=content_types.APPLICATION_JSON,
body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}),
),
route,
route=route,
current_event=self.current_event,
)

return None
Expand Down
24 changes: 24 additions & 0 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,30 @@ def handle_not_found(_) -> Response:
assert result["statusCode"] == 404


def test_allow_origins_no_matching_origin():
# GIVEN
allow_origin = "https://www.foo.com/"
app = APIGatewayRestResolver(cors=CORSConfig(allow_origin=allow_origin, allow_origins=["https://staging.foo.com/"]))

# WHEN
result = app({"path": "/another-one", "httpMethod": "GET", "headers": {}}, None)

# THEN
assert result["headers"]["Access-Control-Allow-Origin"] == allow_origin


def test_allow_origins_match_origin():
# GIVEN
allow_origin = "https://staging.example.com/"
app = APIGatewayRestResolver(cors=CORSConfig(allow_origin="https://www.example.com/", allow_origins=[allow_origin]))

# WHEN
result = app({"path": "/another-one", "httpMethod": "GET", "headers": {"Origin": allow_origin}}, None)

# THEN
assert result["headers"]["Access-Control-Allow-Origin"] == allow_origin


def test_exception_handler_raises_service_error(json_dump):
# GIVEN an exception handler raises a ServiceError (BadRequestError)
app = ApiGatewayResolver()
Expand Down