diff options
author | Richard Marmorstein <richardm@stripe.com> | 2024-02-20 13:48:57 -0800 |
---|---|---|
committer | Richard Marmorstein <richardm@stripe.com> | 2024-02-20 13:48:57 -0800 |
commit | 7a08bdd0df5509bdc2119a1a48f15b09bdb32c1f (patch) | |
tree | c0ab818dea8a21e87f51395997e9caf6ab710895 | |
parent | be5b48a977ff309d44a63d2a38fad3eddca37d74 (diff) |
wip
-rw-r--r-- | stripe/_list_object.py | 79 | ||||
-rw-r--r-- | stripe/_list_object_async.py | 84 | ||||
-rw-r--r-- | stripe/_list_object_base.py | 173 |
3 files changed, 265 insertions, 71 deletions
diff --git a/stripe/_list_object.py b/stripe/_list_object.py index e3de85e7..69b4095a 100644 --- a/stripe/_list_object.py +++ b/stripe/_list_object.py @@ -20,23 +20,18 @@ from stripe._stripe_object import StripeObject from stripe._request_options import RequestOptions, extract_options_from_dict from urllib.parse import quote_plus +from stripe._list_object_base import ListObjectBase T = TypeVar("T", bound=StripeObject) -class ListObject(StripeObject, Generic[T]): - OBJECT_NAME = "list" - data: List[T] - has_more: bool - url: str - - def _get_url_for_list(self) -> str: - url = self.get("url") - if not isinstance(url, str): - raise ValueError( - 'Cannot call .list on a list object without a string "url" property' - ) - return url +class ListObject(ListObjectBase, Generic[T]): + """ + Represents a list response from the Stripe API. Unlike ListObjectAsync, also contains sync versions of request-making methods like `.auto_paging_iter` and `.next_page`. + """ + # Even though ListObjectAsync is the "async version" we cannot omit async methods from + # ListObject and must include both, because this is the class that gets deserialized by default + # when object: 'list_object' def list(self, **params: Mapping[str, Any]) -> Self: return cast( @@ -98,31 +93,6 @@ class ListObject(StripeObject, Generic[T]): ), ) - def __getitem__(self, k: str) -> T: - if isinstance(k, str): # pyright: ignore - return super(ListObject, self).__getitem__(k) - else: - raise KeyError( - "You tried to access the %s index, but ListObject types only " - "support string keys. (HINT: List calls return an object with " - "a 'data' (which is the data array). You likely want to call " - ".data[%s])" % (repr(k), repr(k)) - ) - - # Pyright doesn't like this because ListObject inherits from StripeObject inherits from Dict[str, Any] - # and so it wants the type of __iter__ to agree with __iter__ from Dict[str, Any] - # But we are iterating through "data", which is a List[T]. - def __iter__( # pyright: ignore - self, - ) -> Iterator[T]: - return getattr(self, "data", []).__iter__() - - def __len__(self) -> int: - return getattr(self, "data", []).__len__() - - def __reversed__(self) -> Iterator[T]: # pyright: ignore (see above) - return getattr(self, "data", []).__reversed__() - def auto_paging_iter(self) -> Iterator[T]: page = self @@ -161,39 +131,6 @@ class ListObject(StripeObject, Generic[T]): if page.is_empty: break - @classmethod - def _empty_list( - cls, - **params: Unpack[RequestOptions], - ) -> Self: - return cls._construct_from( - values={"data": []}, - last_response=None, - requestor=_APIRequestor._global_with_options( # pyright: ignore[reportPrivateUsage] - **params, - ), - api_mode="V1", - ) - - @property - def is_empty(self) -> bool: - return not self.data - - def _get_filters_for_next_page( - self, params: RequestOptions - ) -> Mapping[str, Any]: - - last_id = getattr(self.data[-1], "id") - if not last_id: - raise ValueError( - "Unexpected: element in .data of list object had no id" - ) - - params_with_filters = dict(self._retrieve_params) - params_with_filters.update({"starting_after": last_id}) - params_with_filters.update(params) - return params_with_filters - def next_page(self, **params: Unpack[RequestOptions]) -> Self: if not self.has_more: request_options, _ = extract_options_from_dict(params) diff --git a/stripe/_list_object_async.py b/stripe/_list_object_async.py new file mode 100644 index 00000000..1d33b50c --- /dev/null +++ b/stripe/_list_object_async.py @@ -0,0 +1,84 @@ +# pyright: strict, reportUnnecessaryTypeIgnoreComment=false +# reportUnnecessaryTypeIgnoreComment is set to false because some type ignores are required in some +# python versions but not the others +from typing_extensions import Self, Unpack + +from typing import ( + Any, + AsyncIterator, + Iterator, + List, + Generic, + TypeVar, + cast, + Mapping, +) +from stripe._api_requestor import ( + _APIRequestor, # pyright: ignore[reportPrivateUsage] +) +from stripe._stripe_object import StripeObject +from stripe._request_options import RequestOptions, extract_options_from_dict + +from urllib.parse import quote_plus +from stripe._list_object_base import ListObjectBase + +T = TypeVar("T", bound=StripeObject) + + +class ListObjectAsync(ListObjectBase, Generic[T]): + """ + Variant of ListObject that contains *only* async versions of request-making methods. + """ + async def list_async(self, **params: Mapping[str, Any]) -> Self: + return cast( + Self, + await self._request_async( + "get", + self._get_url_for_list(), + params=params, + base_address="api", + api_mode="V1", + ), + ) + + async def auto_paging_iter_async(self) -> AsyncIterator[T]: + page = self + + while True: + if ( + "ending_before" in self._retrieve_params + and "starting_after" not in self._retrieve_params + ): + for item in reversed(page): + yield item + page = await page.previous_page_async() + else: + for item in page: + yield item + page = await page.next_page_async() + + if page.is_empty: + break + + async def next_page_async(self, **params: Unpack[RequestOptions]) -> Self: + if not self.has_more: + request_options, _ = extract_options_from_dict(params) + return self._empty_list( + **request_options, + ) + + return await self.list_async(**self._get_filters_for_next_page(params)) + + async def previous_page_async( + self, **params: Unpack[RequestOptions] + ) -> Self: + if not self.has_more: + request_options, _ = extract_options_from_dict(params) + return self._empty_list( + **request_options, + ) + + result = await self.list_async( + **self._get_filters_for_previous_page(params) + ) + return result diff --git a/stripe/_list_object_base.py b/stripe/_list_object_base.py new file mode 100644 index 00000000..91eb5b4f --- /dev/null +++ b/stripe/_list_object_base.py @@ -0,0 +1,173 @@ +# pyright: strict, reportUnnecessaryTypeIgnoreComment=false +# reportUnnecessaryTypeIgnoreComment is set to false because some type ignores are required in some +# python versions but not the others +from typing_extensions import Self, Unpack + +from typing import ( + Any, + AsyncIterator, + Iterator, + List, + Generic, + TypeVar, + cast, + Mapping, +) +from stripe._api_requestor import ( + _APIRequestor, # pyright: ignore[reportPrivateUsage] +) +from stripe._stripe_object import StripeObject +from stripe._request_options import RequestOptions, extract_options_from_dict + +from urllib.parse import quote_plus + +T = TypeVar("T", bound=StripeObject) + + +class ListObjectBase(StripeObject, Generic[T]): + OBJECT_NAME = "list" + data: List[T] + has_more: bool + url: str + + def _get_url_for_list(self) -> str: + url = self.get("url") + if not isinstance(url, str): + raise ValueError( + 'Cannot call .list on a list object without a string "url" property' + ) + return url + + def list(self, **params: Mapping[str, Any]) -> Self: + return cast( + Self, + self._request( + "get", + self._get_url_for_list(), + params=params, + base_address="api", + api_mode="V1", + ), + ) + + async def list_async(self, **params: Mapping[str, Any]) -> Self: + return cast( + Self, + await self._request_async( + "get", + self._get_url_for_list(), + params=params, + base_address="api", + api_mode="V1", + ), + ) + + def create(self, **params: Mapping[str, Any]) -> T: + url = self.get("url") + if not isinstance(url, str): + raise ValueError( + 'Cannot call .create on a list object for the collection of an object without a string "url" property' + ) + return cast( + T, + self._request( + "post", + url, + params=params, + base_address="api", + api_mode="V1", + ), + ) + + def retrieve(self, id: str, **params: Mapping[str, Any]): + url = self.get("url") + if not isinstance(url, str): + raise ValueError( + 'Cannot call .retrieve on a list object for the collection of an object without a string "url" property' + ) + + url = "%s/%s" % (self.get("url"), quote_plus(id)) + return cast( + T, + self._request( + "get", + url, + params=params, + base_address="api", + api_mode="V1", + ), + ) + + def __getitem__(self, k: str) -> T: + if isinstance(k, str): # pyright: ignore + return super(ListObject, self).__getitem__(k) + else: + raise KeyError( + "You tried to access the %s index, but ListObject types only " + "support string keys. (HINT: List calls return an object with " + "a 'data' (which is the data array). You likely want to call " + ".data[%s])" % (repr(k), repr(k)) + ) + + # Pyright doesn't like this because ListObject inherits from StripeObject inherits from Dict[str, Any] + # and so it wants the type of __iter__ to agree with __iter__ from Dict[str, Any] + # But we are iterating through "data", which is a List[T]. + def __iter__( # pyright: ignore + self, + ) -> Iterator[T]: + return getattr(self, "data", []).__iter__() + + def __len__(self) -> int: + return getattr(self, "data", []).__len__() + + def __reversed__(self) -> Iterator[T]: # pyright: ignore (see above) + return getattr(self, "data", []).__reversed__() + + @classmethod + def _empty_list( + cls, + **params: Unpack[RequestOptions], + ) -> Self: + return cls._construct_from( + values={"data": []}, + last_response=None, + requestor=_APIRequestor._global_with_options( # pyright: ignore[reportPrivateUsage] + **params, + ), + api_mode="V1", + ) + + @property + def is_empty(self) -> bool: + return not self.data + + # Used by child classes for next_page + def _get_filters_for_next_page( + self, params: RequestOptions + ) -> Mapping[str, Any]: + + last_id = getattr(self.data[-1], "id") + if not last_id: + raise ValueError( + "Unexpected: element in .data of list object had no id" + ) + + params_with_filters = dict(self._retrieve_params) + params_with_filters.update({"starting_after": last_id}) + params_with_filters.update(params) + return params_with_filters + + # Used by child classes for previous_page + def _get_filters_for_previous_page( + self, params: RequestOptions + ) -> Mapping[str, Any]: + first_id = getattr(self.data[0], "id") + if not first_id: + raise ValueError( + "Unexpected: element in .data of list object had no id" + ) + + params_with_filters = dict(self._retrieve_params) + params_with_filters.update({"ending_before": first_id}) + params_with_filters.update(params) + return params_with_filters |