import pickle
from datetime import datetime, timezone

from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
from django.core.cache.backends.db import Options
from django.db import connections, router
from django.utils.functional import cached_property
from pymongo import ASCENDING, DESCENDING, IndexModel, ReturnDocument
from pymongo.errors import DuplicateKeyError, OperationFailure


class MongoSerializer:
    def __init__(self, protocol=None):
        self.protocol = pickle.HIGHEST_PROTOCOL if protocol is None else protocol

    def dumps(self, obj):
        # For better incr() and decr() atomicity, don't pickle integers.
        # Using type() rather than isinstance() matches only integers and not
        # subclasses like bool.
        if type(obj) is int:  # noqa: E721
            return obj
        return pickle.dumps(obj, self.protocol)

    def loads(self, data):
        try:
            return int(data)
        except (ValueError, TypeError):
            return pickle.loads(data)  # noqa: S301
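# A minimal round-trip sketch of MongoSerializer (illustrative only, not part
# of the backend's API):
#
#   s = MongoSerializer()
#   s.dumps(42)                 # -> 42 (ints are stored unpickled)
#   s.dumps({"a": 1})           # -> pickled bytes
#   s.loads(s.dumps({"a": 1}))  # -> {"a": 1}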


class MongoDBCache(BaseCache):
    pickle_protocol = pickle.HIGHEST_PROTOCOL

    def __init__(self, collection_name, params):
        super().__init__(params)
        self._collection_name = collection_name

        class CacheEntry:
            _meta = Options(collection_name)

        self.cache_model_class = CacheEntry

    def create_indexes(self):
        expires_index = IndexModel("expires_at", expireAfterSeconds=0)
        key_index = IndexModel("key", unique=True)
        self.collection_for_write.create_indexes([expires_index, key_index])

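    # create_indexes() must run once before the cache is used: the TTL index
    # lets MongoDB expire entries automatically, and the unique index on
    # `key` is what makes add() and set() safe. A hedged invocation sketch
    # (the "default" alias is an assumption about the project's settings):
    #
    #   from django.core.cache import caches
    #   caches["default"].create_indexes()
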
    @cached_property
    def serializer(self):
        return MongoSerializer(self.pickle_protocol)

    @property
    def collection_for_read(self):
        db = router.db_for_read(self.cache_model_class)
        return connections[db].get_collection(self._collection_name)

    @property
    def collection_for_write(self):
        db = router.db_for_write(self.cache_model_class)
        return connections[db].get_collection(self._collection_name)

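    # The two properties above consult Django's database routers, so reads
    # and writes may target different connections. A hedged router sketch
    # (the collection and alias names are illustrative):
    #
    #   class CacheRouter:
    #       def db_for_read(self, model, **hints):
    #           if model._meta.db_table == "my_cache_collection":
    #               return "mongodb_replica"
    #           return None
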
    def _filter_expired(self, expired=False):
        """
        Return MQL to exclude expired entries (needed because the MongoDB
        daemon does not remove expired entries precisely when they expire).
        If expired=True, return MQL to include only expired entries.
        """
        op = "$lt" if expired else "$gte"
        return {"expires_at": {op: datetime.now(tz=timezone.utc)}}

    def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
        if timeout is None:
            return datetime.max
        timestamp = super().get_backend_timeout(timeout)
        return datetime.fromtimestamp(timestamp, tz=timezone.utc)

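    # Worked example of the conversion above (illustrative values): with
    # timeout=300 and the clock at 2024-01-01 00:00:00 UTC,
    # BaseCache.get_backend_timeout() returns the Unix timestamp for
    # 00:05:00, which becomes datetime(2024, 1, 1, 0, 5, tzinfo=timezone.utc)
    # in `expires_at`. timeout=None maps to datetime.max, i.e. never expires.
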
    def get(self, key, default=None, version=None):
        return self.get_many([key], version).get(key, default)

    def get_many(self, keys, version=None):
        if not keys:
            return {}
        keys_map = {self.make_and_validate_key(key, version=version): key for key in keys}
        with self.collection_for_read.find(
            {"key": {"$in": tuple(keys_map)}, **self._filter_expired(expired=False)}
        ) as cursor:
            return {keys_map[row["key"]]: self.serializer.loads(row["value"]) for row in cursor}

    def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        key = self.make_and_validate_key(key, version=version)
        num = self.collection_for_write.count_documents({}, hint="_id_")
        if num >= self._max_entries:
            self._cull(num)
        self.collection_for_write.update_one(
            {"key": key},
            {
                "$set": {
                    "key": key,
                    "value": self.serializer.dumps(value),
                    "expires_at": self.get_backend_timeout(timeout),
                }
            },
            upsert=True,
        )

    def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        key = self.make_and_validate_key(key, version=version)
        num = self.collection_for_write.count_documents({}, hint="_id_")
        if num >= self._max_entries:
            self._cull(num)
        try:
            self.collection_for_write.update_one(
                {"key": key, **self._filter_expired(expired=True)},
                {
                    "$set": {
                        "key": key,
                        "value": self.serializer.dumps(value),
                        "expires_at": self.get_backend_timeout(timeout),
                    }
                },
                upsert=True,
            )
        except DuplicateKeyError:
            return False
        return True

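    # add() piggybacks on the unique index on `key`: its filter only matches
    # an *expired* entry, so when a live entry exists the upsert attempts to
    # insert a duplicate key, MongoDB raises DuplicateKeyError, and that is
    # reported as False. A hedged behavior sketch (`cache` is the usual
    # django.core.cache handle):
    #
    #   cache.add("token", "abc")  # -> True (created)
    #   cache.add("token", "xyz")  # -> False (a live entry already exists)
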
    def _cull(self, num):
        if self._cull_frequency == 0:
            self.clear()
        else:
            # The fraction of entries that are culled when MAX_ENTRIES is
            # reached is 1 / CULL_FREQUENCY. For example, in the default case
            # of CULL_FREQUENCY=3, 2/3 of the entries are kept, thus
            # `keep_num` will be 2/3 of the current number of entries.
            keep_num = num - num // self._cull_frequency
            try:
                # Find the first cache entry beyond the retention limit,
                # culling entries that expire the soonest.
                deleted_from = next(
                    self.collection_for_write.aggregate(
                        [
                            {"$sort": {"expires_at": DESCENDING, "key": ASCENDING}},
                            {"$skip": keep_num},
                            {"$limit": 1},
                            {"$project": {"key": 1, "expires_at": 1}},
                        ]
                    )
                )
            except StopIteration:
                # If no entries are found, there is nothing to delete. It may
                # happen if the database removes expired entries between the
                # query to get `num` and the query to get `deleted_from`.
                pass
            else:
                # Cull the cache.
                self.collection_for_write.delete_many(
                    {
                        "$or": [
                            # Delete keys that expire before `deleted_from`...
                            {"expires_at": {"$lt": deleted_from["expires_at"]}},
                            # and the entries that share an expiration with
                            # `deleted_from` but sort after it alphabetically
                            # (per the same sorting used to fetch
                            # `deleted_from`).
                            {
                                "$and": [
                                    {"expires_at": deleted_from["expires_at"]},
                                    {"key": {"$gte": deleted_from["key"]}},
                                ]
                            },
                        ]
                    }
                )

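    # Worked example of the retention math above: with num=300 and the
    # default CULL_FREQUENCY=3, keep_num = 300 - 300 // 3 = 200, so
    # `deleted_from` is the 201st document in (expires_at DESC, key ASC)
    # order and roughly the 100 soonest-expiring entries are deleted.
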
    def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
        key = self.make_and_validate_key(key, version=version)
        res = self.collection_for_write.update_one(
            {"key": key}, {"$set": {"expires_at": self.get_backend_timeout(timeout)}}
        )
        return res.matched_count > 0

    def incr(self, key, delta=1, version=None):
        serialized_key = self.make_and_validate_key(key, version=version)
        try:
            updated = self.collection_for_write.find_one_and_update(
                {"key": serialized_key, **self._filter_expired(expired=False)},
                {"$inc": {"value": delta}},
                return_document=ReturnDocument.AFTER,
            )
        except OperationFailure as exc:
            method_name = "incr" if delta >= 1 else "decr"
            raise TypeError(f"Cannot apply {method_name}() to a non-numeric value.") from exc
        if updated is None:
            raise ValueError(f"Key '{key}' not found.") from None
        return updated["value"]

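    # Because integers are stored unpickled (see MongoSerializer.dumps), $inc
    # updates them server-side in a single round trip. A hedged behavior
    # sketch (`cache` is the usual django.core.cache handle):
    #
    #   cache.set("hits", 0)
    #   cache.incr("hits")      # -> 1, applied atomically via $inc
    #   cache.incr("hits", 10)  # -> 11
    #   cache.incr("missing")   # raises ValueError
    #   cache.decr("hits")      # -> 10 (BaseCache.decr() negates the delta)
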
    def delete(self, key, version=None):
        return self._delete_many([key], version)

    def delete_many(self, keys, version=None):
        self._delete_many(keys, version)

    def _delete_many(self, keys, version=None):
        if not keys:
            return False
        keys = tuple(self.make_and_validate_key(key, version=version) for key in keys)
        return bool(self.collection_for_write.delete_many({"key": {"$in": keys}}).deleted_count)

    def has_key(self, key, version=None):
        key = self.make_and_validate_key(key, version=version)
        num = self.collection_for_read.count_documents(
            {"key": key, **self._filter_expired(expired=False)}
        )
        return num > 0

    def clear(self):
        self.collection_for_write.delete_many({})