Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions jsonrpcserver/async_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import asyncio
import collections.abc
from json import JSONDecodeError
from json import dumps as serialize, loads as deserialize
from typing import Any, Iterable, Optional, Union
from json import dumps as default_serialize, loads as default_deserialize
from typing import Any, Iterable, Optional, Union, Callable

from apply_defaults import apply_config # type: ignore
from jsonschema import ValidationError # type: ignore
Expand Down Expand Up @@ -34,25 +34,34 @@ async def call(method: Method, *args: Any, **kwargs: Any) -> Any:
return await validate_args(method, *args, **kwargs)(*args, **kwargs)


async def safe_call(request: Request, methods: Methods, *, debug: bool) -> Response:
async def safe_call(
request: Request, methods: Methods, *, debug: bool, serialize: Callable
) -> Response:
with handle_exceptions(request, debug) as handler:
result = await call(
lookup(methods, request.method), *request.args, **request.kwargs
)
# Ensure value returned from the method is JSON-serializable. If not,
# handle_exception will set handler.response to an ExceptionResponse
serialize(result)
handler.response = SuccessResponse(result=result, id=request.id)
handler.response = SuccessResponse(
result=result, id=request.id, serialize_func=serialize
)
return handler.response


async def call_requests(
requests: Union[Request, Iterable[Request]], methods: Methods, debug: bool
requests: Union[Request, Iterable[Request]],
methods: Methods,
debug: bool,
serialize: Callable,
) -> Response:
if isinstance(requests, collections.abc.Iterable):
responses = (safe_call(r, methods, debug=debug) for r in requests)
return BatchResponse(await asyncio.gather(*responses))
return await safe_call(requests, methods, debug=debug)
responses = (
safe_call(r, methods, debug=debug, serialize=serialize) for r in requests
)
return BatchResponse(await asyncio.gather(*responses), serialize_func=serialize)
return await safe_call(requests, methods, debug=debug, serialize=serialize)


async def dispatch_pure(
Expand All @@ -61,7 +70,9 @@ async def dispatch_pure(
*,
context: Any,
convert_camel_case: bool,
debug: bool
debug: bool,
serialize: Callable,
deserialize: Callable,
) -> Response:
try:
deserialized = validate(deserialize(request), schema)
Expand All @@ -75,6 +86,7 @@ async def dispatch_pure(
),
methods,
debug=debug,
serialize=serialize,
)


Expand All @@ -88,7 +100,9 @@ async def dispatch(
context: Any = NOCONTEXT,
debug: bool = False,
trim_log_values: bool = False,
**kwargs: Any
serialize: Callable = default_serialize,
deserialize: Callable = default_deserialize,
**kwargs: Any,
) -> Response:
# Use the global methods object if no methods object was passed.
methods = global_methods if methods is None else methods
Expand All @@ -102,6 +116,8 @@ async def dispatch(
debug=debug,
context=context,
convert_camel_case=convert_camel_case,
serialize=serialize,
deserialize=deserialize,
)
log_response(str(response), trim_log_values=trim_log_values)
# Remove the temporary stream handlers
Expand Down
50 changes: 42 additions & 8 deletions jsonrpcserver/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,20 @@
from configparser import ConfigParser
from contextlib import contextmanager
from json import JSONDecodeError
from json import dumps as serialize, loads as deserialize
from json import dumps as default_serialize, loads as default_deserialize
from types import SimpleNamespace
from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Tuple, Union
from typing import (
Any,
Dict,
Generator,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
Callable,
)

from apply_defaults import apply_config # type: ignore
from jsonschema import ValidationError # type: ignore
Expand Down Expand Up @@ -40,7 +51,7 @@
response_logger = logging.getLogger(__name__ + ".response")

# Prepare the jsonschema validator
schema = deserialize(resource_string(__name__, "request-schema.json"))
schema = default_deserialize(resource_string(__name__, "request-schema.json"))
klass = validator_for(schema)
klass.check_schema(schema)
validator = klass(schema)
Expand Down Expand Up @@ -144,14 +155,17 @@ def handle_exceptions(request: Request, debug: bool) -> Generator:
handler.response = NotificationResponse()


def safe_call(request: Request, methods: Methods, *, debug: bool) -> Response:
def safe_call(
request: Request, methods: Methods, *, debug: bool, serialize: Callable
) -> Response:
"""
Call a Request, catching exceptions to ensure we always return a Response.

Args:
request: The Request object.
methods: The list of methods that can be called.
debug: Include more information in error responses.
serialize: Function that is used to serialize data.

Returns:
A Response object.
Expand All @@ -161,12 +175,17 @@ def safe_call(request: Request, methods: Methods, *, debug: bool) -> Response:
# Ensure value returned from the method is JSON-serializable. If not,
# handle_exception will set handler.response to an ExceptionResponse
serialize(result)
handler.response = SuccessResponse(result=result, id=request.id)
handler.response = SuccessResponse(
result=result, id=request.id, serialize_func=serialize
)
return handler.response


def call_requests(
requests: Union[Request, Iterable[Request]], methods: Methods, debug: bool
requests: Union[Request, Iterable[Request]],
methods: Methods,
debug: bool,
serialize: Callable,
) -> Response:
"""
Takes a request or list of Requests and calls them.
Expand All @@ -175,10 +194,14 @@ def call_requests(
requests: Request object, or a collection of them.
methods: The list of methods that can be called.
debug: Include more information in error responses.
serialize: Function that is used to serialize data.
"""
if isinstance(requests, Iterable):
return BatchResponse(safe_call(r, methods, debug=debug) for r in requests)
return safe_call(requests, methods, debug=debug)
return BatchResponse(
[safe_call(r, methods, debug=debug, serialize=serialize) for r in requests],
serialize_func=serialize,
)
return safe_call(requests, methods, debug=debug, serialize=serialize)


def create_requests(
Expand Down Expand Up @@ -211,6 +234,8 @@ def dispatch_pure(
context: Any,
convert_camel_case: bool,
debug: bool,
serialize: Callable,
deserialize: Callable,
) -> Response:
"""
Pure version of dispatch - no logging, no optional parameters.
Expand All @@ -225,6 +250,8 @@ def dispatch_pure(
context: If specified, will be the first positional argument in all requests.
convert_camel_case: Will convert the method name/any named params to snake case.
debug: Include more information in error responses.
serialize: Function that is used to serialize data.
deserialize: Function that is used to deserialize data.
Returns:
A Response.
"""
Expand All @@ -240,6 +267,7 @@ def dispatch_pure(
),
methods,
debug=debug,
serialize=serialize,
)


Expand All @@ -253,6 +281,8 @@ def dispatch(
context: Any = NOCONTEXT,
debug: bool = False,
trim_log_values: bool = False,
serialize: Callable = default_serialize,
deserialize: Callable = default_deserialize,
**kwargs: Any,
) -> Response:
"""
Expand All @@ -270,6 +300,8 @@ def dispatch(
case.
debug: Include more information in error responses.
trim_log_values: Show abbreviated requests and responses in log.
serialize: Function that is used to serialize data.
deserialize: Function that is used to deserialize data.

Returns:
A Response.
Expand All @@ -289,6 +321,8 @@ def dispatch(
debug=debug,
context=context,
convert_camel_case=convert_camel_case,
serialize=serialize,
deserialize=deserialize,
)
log_response(str(response), trim_log_values=trim_log_values)
# Remove the temporary stream handlers
Expand Down
22 changes: 14 additions & 8 deletions jsonrpcserver/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
ExceptionResponse
BatchResponse - a list of DictResponses
"""
import json
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Dict, Iterable, cast
from typing import Any, Dict, Iterable, cast, Callable
from json import dumps as default_serialize

from . import status

Expand All @@ -43,8 +43,11 @@
class Response(ABC):
"""Base class of all responses."""

def __init__(self, http_status: int) -> None:
def __init__(
self, http_status: int, serialize_func: Callable = default_serialize
) -> None:
self.http_status = http_status
self._serialize = serialize_func

@property
@abstractmethod
Expand Down Expand Up @@ -130,7 +133,7 @@ def deserialized(self) -> dict:

def __str__(self) -> str:
"""Use str() to get the JSON-RPC response string."""
return json.dumps(sort_dict_response(self.deserialized()))
return self._serialize(sort_dict_response(self.deserialized()))


class SuccessResponse(DictResponse):
Expand All @@ -150,7 +153,7 @@ def __init__(
The payload from processing the request. If the request was a JSON-RPC
notification (i.e. the request id is `None`), the result must also be
`None` because notifications don't require any data returned.
http_status:
http_status:
"""
super().__init__(http_status=http_status, **kwargs)
self.result = result
Expand Down Expand Up @@ -297,9 +300,12 @@ class BatchResponse(Response):
"""

def __init__(
self, responses: Iterable[Response], http_status: int = status.HTTP_OK
self,
responses: Iterable[Response],
http_status: int = status.HTTP_OK,
**kwargs: Any,
) -> None:
super().__init__(http_status=http_status)
super().__init__(http_status=http_status, **kwargs)
# Remove notifications; these are not allowed in batch responses
self.responses = cast(
Iterable[DictResponse], {r for r in responses if r.wanted}
Expand All @@ -317,4 +323,4 @@ def __str__(self) -> str:
dicts = self.deserialized()
# For an all-notifications response, an empty string should be returned, as per
# spec
return json.dumps(dicts) if len(dicts) else ""
return self._serialize(dicts) if len(dicts) else ""
Loading