Skip to content

Commit b8d6bfd

Browse files
authored
PYTHON-4144 Optimize json_util encoding performance using single dispatch table (mongodb#1475)
1 parent b9e1bf7 commit b8d6bfd

File tree

4 files changed

+248
-106
lines changed

4 files changed

+248
-106
lines changed

bson/__init__.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -886,18 +886,11 @@ def _encode_maxkey(name: bytes, dummy0: Any, dummy1: Any, dummy2: Any) -> bytes:
886886
_abc.Mapping: _encode_mapping,
887887
}
888888

889-
890-
_MARKERS = {
891-
5: _encode_binary,
892-
7: _encode_objectid,
893-
11: _encode_regex,
894-
13: _encode_code,
895-
17: _encode_timestamp,
896-
18: _encode_long,
897-
100: _encode_dbref,
898-
127: _encode_maxkey,
899-
255: _encode_minkey,
900-
}
889+
# Map each _type_marker to its encoder for faster lookup.
890+
_MARKERS = {}
891+
for _typ in _ENCODERS:
892+
if hasattr(_typ, "_type_marker"):
893+
_MARKERS[_typ._type_marker] = _ENCODERS[_typ]
901894

902895

903896
_BUILT_IN_TYPES = tuple(t for t in _ENCODERS)

bson/json_util.py

Lines changed: 173 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@
110110
from typing import (
111111
TYPE_CHECKING,
112112
Any,
113+
Callable,
113114
Mapping,
114115
MutableMapping,
115116
Optional,
@@ -835,7 +836,7 @@ def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict:
835836
json_options.datetime_representation == DatetimeRepresentation.ISO8601
836837
and 0 <= int(obj) <= _max_datetime_ms()
837838
):
838-
return default(obj.as_datetime(), json_options)
839+
return _encode_datetime(obj.as_datetime(), json_options)
839840
elif json_options.datetime_representation == DatetimeRepresentation.LEGACY:
840841
return {"$date": str(int(obj))}
841842
return {"$date": {"$numberLong": str(int(obj))}}
@@ -855,100 +856,180 @@ def _encode_int64(obj: Int64, json_options: JSONOptions) -> Any:
855856
return int(obj)
856857

857858

859+
def _encode_noop(obj: Any, dummy0: Any) -> Any:
860+
return obj
861+
862+
863+
def _encode_regex(obj: Any, json_options: JSONOptions) -> dict:
864+
flags = ""
865+
if obj.flags & re.IGNORECASE:
866+
flags += "i"
867+
if obj.flags & re.LOCALE:
868+
flags += "l"
869+
if obj.flags & re.MULTILINE:
870+
flags += "m"
871+
if obj.flags & re.DOTALL:
872+
flags += "s"
873+
if obj.flags & re.UNICODE:
874+
flags += "u"
875+
if obj.flags & re.VERBOSE:
876+
flags += "x"
877+
if isinstance(obj.pattern, str):
878+
pattern = obj.pattern
879+
else:
880+
pattern = obj.pattern.decode("utf-8")
881+
if json_options.json_mode == JSONMode.LEGACY:
882+
return {"$regex": pattern, "$options": flags}
883+
return {"$regularExpression": {"pattern": pattern, "options": flags}}
884+
885+
886+
def _encode_int(obj: int, json_options: JSONOptions) -> Any:
887+
if json_options.json_mode == JSONMode.CANONICAL:
888+
if -(2**31) <= obj < 2**31:
889+
return {"$numberInt": str(obj)}
890+
return {"$numberLong": str(obj)}
891+
return obj
892+
893+
894+
def _encode_float(obj: float, json_options: JSONOptions) -> Any:
895+
if json_options.json_mode != JSONMode.LEGACY:
896+
if math.isnan(obj):
897+
return {"$numberDouble": "NaN"}
898+
elif math.isinf(obj):
899+
representation = "Infinity" if obj > 0 else "-Infinity"
900+
return {"$numberDouble": representation}
901+
elif json_options.json_mode == JSONMode.CANONICAL:
902+
# repr() will return the shortest string guaranteed to produce the
903+
# original value, when float() is called on it.
904+
return {"$numberDouble": str(repr(obj))}
905+
return obj
906+
907+
908+
def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict:
909+
if json_options.datetime_representation == DatetimeRepresentation.ISO8601:
910+
if not obj.tzinfo:
911+
obj = obj.replace(tzinfo=utc)
912+
assert obj.tzinfo is not None
913+
if obj >= EPOCH_AWARE:
914+
off = obj.tzinfo.utcoffset(obj)
915+
if (off.days, off.seconds, off.microseconds) == (0, 0, 0): # type: ignore
916+
tz_string = "Z"
917+
else:
918+
tz_string = obj.strftime("%z")
919+
millis = int(obj.microsecond / 1000)
920+
fracsecs = ".%03d" % (millis,) if millis else ""
921+
return {
922+
"$date": "{}{}{}".format(obj.strftime("%Y-%m-%dT%H:%M:%S"), fracsecs, tz_string)
923+
}
924+
925+
millis = _datetime_to_millis(obj)
926+
if json_options.datetime_representation == DatetimeRepresentation.LEGACY:
927+
return {"$date": millis}
928+
return {"$date": {"$numberLong": str(millis)}}
929+
930+
931+
def _encode_bytes(obj: bytes, json_options: JSONOptions) -> dict:
932+
return _encode_binary(obj, 0, json_options)
933+
934+
935+
def _encode_binary_obj(obj: Binary, json_options: JSONOptions) -> dict:
936+
return _encode_binary(obj, obj.subtype, json_options)
937+
938+
939+
def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict:
940+
if json_options.strict_uuid:
941+
binval = Binary.from_uuid(obj, uuid_representation=json_options.uuid_representation)
942+
return _encode_binary(binval, binval.subtype, json_options)
943+
else:
944+
return {"$uuid": obj.hex}
945+
946+
947+
def _encode_objectid(obj: ObjectId, dummy0: Any) -> dict:
948+
return {"$oid": str(obj)}
949+
950+
951+
def _encode_timestamp(obj: Timestamp, dummy0: Any) -> dict:
952+
return {"$timestamp": {"t": obj.time, "i": obj.inc}}
953+
954+
955+
def _encode_decimal128(obj: Timestamp, dummy0: Any) -> dict:
956+
return {"$numberDecimal": str(obj)}
957+
958+
959+
def _encode_dbref(obj: DBRef, json_options: JSONOptions) -> dict:
960+
return _json_convert(obj.as_doc(), json_options=json_options)
961+
962+
963+
def _encode_minkey(dummy0: Any, dummy1: Any) -> dict:
964+
return {"$minKey": 1}
965+
966+
967+
def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict:
968+
return {"$maxKey": 1}
969+
970+
858971
# Encoders for BSON types
859-
_encoders = {
860-
5: lambda obj, json_options: _encode_binary(obj, obj.subtype, json_options), # Binary
861-
7: lambda obj, json_options: {"$oid": str(obj)}, # noqa: ARG005 ObjectId
862-
9: _encode_datetimems, # DatetimeMS
863-
13: _encode_code, # Code
864-
17: lambda obj, json_options: {"$timestamp": {"t": obj.time, "i": obj.inc}}, # noqa: ARG005 Timestamp
865-
18: _encode_int64, # Int64
866-
19: lambda obj, json_options: {"$numberDecimal": str(obj)}, # noqa: ARG005 Decimal128
867-
100: lambda obj, json_options: _json_convert(obj.as_doc(), json_options=json_options), # DBRef
868-
127: lambda obj, json_options: {"$maxKey": 1}, # noqa: ARG005 MaxKey
869-
255: lambda obj, json_options: {"$minKey": 1}, # noqa: ARG005 MinKey
972+
# Each encoder function's signature is:
973+
# - obj: a Python data type, e.g. a Python int for _encode_int
974+
# - json_options: a JSONOptions
975+
_ENCODERS: dict[Type, Callable[[Any, JSONOptions], Any]] = {
976+
bool: _encode_noop,
977+
bytes: _encode_bytes,
978+
datetime.datetime: _encode_datetime,
979+
DatetimeMS: _encode_datetimems,
980+
float: _encode_float,
981+
int: _encode_int,
982+
str: _encode_noop,
983+
type(None): _encode_noop,
984+
uuid.UUID: _encode_uuid,
985+
Binary: _encode_binary_obj,
986+
Int64: _encode_int64,
987+
Code: _encode_code,
988+
DBRef: _encode_dbref,
989+
MaxKey: _encode_maxkey,
990+
MinKey: _encode_minkey,
991+
ObjectId: _encode_objectid,
992+
Regex: _encode_regex,
993+
RE_TYPE: _encode_regex,
994+
Timestamp: _encode_timestamp,
995+
Decimal128: _encode_decimal128,
870996
}
871997

998+
# Map each _type_marker to its encoder for faster lookup.
999+
_MARKERS: dict[int, Callable[[Any, JSONOptions], Any]] = {}
1000+
for _typ in _ENCODERS:
1001+
if hasattr(_typ, "_type_marker"):
1002+
_MARKERS[_typ._type_marker] = _ENCODERS[_typ]
1003+
1004+
_BUILT_IN_TYPES = tuple(t for t in _ENCODERS)
1005+
8721006

8731007
def default(obj: Any, json_options: JSONOptions = DEFAULT_JSON_OPTIONS) -> Any:
874-
# We preserve key order when rendering SON, DBRef, etc. as JSON by
875-
# returning a SON for those types instead of a dict.
876-
if isinstance(obj, bool):
877-
return obj
878-
elif isinstance(obj, (RE_TYPE, Regex)):
879-
flags = ""
880-
if obj.flags & re.IGNORECASE:
881-
flags += "i"
882-
if obj.flags & re.LOCALE:
883-
flags += "l"
884-
if obj.flags & re.MULTILINE:
885-
flags += "m"
886-
if obj.flags & re.DOTALL:
887-
flags += "s"
888-
if obj.flags & re.UNICODE:
889-
flags += "u"
890-
if obj.flags & re.VERBOSE:
891-
flags += "x"
892-
if isinstance(obj.pattern, str):
893-
pattern = obj.pattern
894-
else:
895-
pattern = obj.pattern.decode("utf-8")
896-
if json_options.json_mode == JSONMode.LEGACY:
897-
return {"$regex": pattern, "$options": flags}
898-
return {"$regularExpression": {"pattern": pattern, "options": flags}}
899-
elif hasattr(obj, "_type_marker"):
900-
type_marker = obj._type_marker
901-
try:
902-
return _encoders[type_marker](obj, json_options) # type: ignore[no-untyped-call]
903-
except KeyError:
904-
raise TypeError("%r is not JSON serializable" % obj) from None
905-
elif isinstance(obj, int):
906-
if json_options.json_mode == JSONMode.CANONICAL:
907-
if -(2**31) <= obj < 2**31:
908-
return {"$numberInt": str(obj)}
909-
return {"$numberLong": str(obj)}
910-
return obj
911-
elif isinstance(obj, float):
912-
if json_options.json_mode != JSONMode.LEGACY:
913-
if math.isnan(obj):
914-
return {"$numberDouble": "NaN"}
915-
elif math.isinf(obj):
916-
representation = "Infinity" if obj > 0 else "-Infinity"
917-
return {"$numberDouble": representation}
918-
elif json_options.json_mode == JSONMode.CANONICAL:
919-
# repr() will return the shortest string guaranteed to produce the
920-
# original value, when float() is called on it.
921-
return {"$numberDouble": str(repr(obj))}
922-
return obj
923-
elif isinstance(obj, str):
924-
return obj
925-
elif isinstance(obj, datetime.datetime):
926-
if json_options.datetime_representation == DatetimeRepresentation.ISO8601:
927-
if not obj.tzinfo:
928-
obj = obj.replace(tzinfo=utc)
929-
assert obj.tzinfo is not None
930-
if obj >= EPOCH_AWARE:
931-
off = obj.tzinfo.utcoffset(obj)
932-
if (off.days, off.seconds, off.microseconds) == (0, 0, 0): # type: ignore
933-
tz_string = "Z"
934-
else:
935-
tz_string = obj.strftime("%z")
936-
millis = int(obj.microsecond / 1000)
937-
fracsecs = ".%03d" % (millis,) if millis else ""
938-
return {
939-
"$date": "{}{}{}".format(obj.strftime("%Y-%m-%dT%H:%M:%S"), fracsecs, tz_string)
940-
}
941-
942-
millis = _datetime_to_millis(obj)
943-
if json_options.datetime_representation == DatetimeRepresentation.LEGACY:
944-
return {"$date": millis}
945-
return {"$date": {"$numberLong": str(millis)}}
946-
elif isinstance(obj, bytes):
947-
return _encode_binary(obj, 0, json_options)
948-
elif isinstance(obj, uuid.UUID):
949-
if json_options.strict_uuid:
950-
binval = Binary.from_uuid(obj, uuid_representation=json_options.uuid_representation)
951-
return _encode_binary(binval, binval.subtype, json_options)
952-
else:
953-
return {"$uuid": obj.hex}
1008+
# First see if the type is already cached. KeyError will only ever
1009+
# happen once per subtype.
1010+
try:
1011+
return _ENCODERS[type(obj)](obj, json_options)
1012+
except KeyError:
1013+
pass
1014+
1015+
# Second, fall back to trying _type_marker. This has to be done
1016+
# before the loop below since users could subclass one of our
1017+
# custom types that subclasses a python built-in (e.g. Binary)
1018+
if hasattr(obj, "_type_marker"):
1019+
marker = obj._type_marker
1020+
if marker in _MARKERS:
1021+
func = _MARKERS[marker]
1022+
# Cache this type for faster subsequent lookup.
1023+
_ENCODERS[type(obj)] = func
1024+
return func(obj, json_options)
1025+
1026+
# Third, test each base type. This will only happen once for
1027+
# a subtype of a supported base type.
1028+
for base in _BUILT_IN_TYPES:
1029+
if isinstance(obj, base):
1030+
func = _ENCODERS[base]
1031+
# Cache this type for faster subsequent lookup.
1032+
_ENCODERS[type(obj)] = func
1033+
return func(obj, json_options)
1034+
9541035
raise TypeError("%r is not JSON serializable" % obj)

test/test_bson.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
encode,
4949
is_valid,
5050
)
51-
from bson.binary import Binary, UuidRepresentation
51+
from bson.binary import USER_DEFINED_SUBTYPE, Binary, UuidRepresentation
5252
from bson.code import Code
5353
from bson.codec_options import CodecOptions, DatetimeConversion
5454
from bson.datetime_ms import _DATETIME_ERROR_SUGGESTION
@@ -772,6 +772,21 @@ class _myunicode(str):
772772
self.assertEqual(type(value), orig_type)
773773
self.assertEqual(value, orig_type(value))
774774

775+
def test_encode_type_marker(self):
776+
# Assert that a custom subclass can be BSON encoded based on the _type_marker attribute.
777+
class MyMaxKey:
778+
_type_marker = 127
779+
780+
expected_bson = encode({"a": MaxKey()})
781+
self.assertEqual(encode({"a": MyMaxKey()}), expected_bson)
782+
783+
# Test a class that inherits from two built in types
784+
class MyBinary(Binary):
785+
pass
786+
787+
expected_bson = encode({"a": Binary(b"bin", USER_DEFINED_SUBTYPE)})
788+
self.assertEqual(encode({"a": MyBinary(b"bin", USER_DEFINED_SUBTYPE)}), expected_bson)
789+
775790
def test_ordered_dict(self):
776791
d = OrderedDict([("one", 1), ("two", 2), ("three", 3), ("four", 4)])
777792
self.assertEqual(d, decode(encode(d), CodecOptions(document_class=OrderedDict))) # type: ignore[type-var]

0 commit comments

Comments
 (0)