Skip to content
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -263,4 +263,5 @@ that much better:
* Timothé Perez (https://github.com/AchilleAsh)
* oleksandr-l5 (https://github.com/oleksandr-l5)
* Ido Shraga (https://github.com/idoshr)
* Nick Freville (https://github.com/nickfrev)
* Terence Honles (https://github.com/terencehonles)
1 change: 1 addition & 0 deletions mongoengine/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"ComplexBaseField",
"ObjectIdField",
"GeoJsonBaseField",
"SaveableBaseField",
# metaclasses
"DocumentMetaclass",
"TopLevelDocumentMetaclass",
Expand Down
21 changes: 19 additions & 2 deletions mongoengine/base/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mongoengine.common import _import_class
from mongoengine.errors import DeprecatedError, ValidationError

__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")
__all__ = ("BaseField", "SaveableBaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")


class BaseField:
Expand Down Expand Up @@ -259,7 +259,14 @@ def owner_document(self, owner_document):
self._set_owner_document(owner_document)


class ComplexBaseField(BaseField):
class SaveableBaseField(BaseField):
"""A base class that dictates a field has the ability to save.
"""
def save():
pass


class ComplexBaseField(SaveableBaseField):
"""Handles complex fields, such as lists / dictionaries.

Allows for nesting of embedded documents inside complex types.
Expand Down Expand Up @@ -483,6 +490,16 @@ def validate(self, value):
if self.required and not value:
self.error("Field is required and cannot be empty")

def save(self, instance, **kwargs):
Document = _import_class("Document")
value = instance._data.get(self.name)

for ref in value:
if isinstance(ref, SaveableBaseField):
ref.save(self, **kwargs)
elif isinstance(ref, Document):
ref.save(**kwargs)

def prepare_query_value(self, op, value):
return self.to_mongo(value)

Expand Down
154 changes: 86 additions & 68 deletions mongoengine/document.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import re

import pymongo
from bson import SON
from bson.dbref import DBRef
from pymongo.read_preferences import ReadPreference

from mongoengine import signals
from mongoengine.base import (
BaseDict,
BaseDocument,
SaveableBaseField,
BaseList,
DocumentMetaclass,
EmbeddedDocumentList,
Expand Down Expand Up @@ -385,44 +387,34 @@ def save(
the cascade save using cascade_kwargs which overwrites the
existing kwargs with custom values.
"""
signal_kwargs = signal_kwargs or {}

if self._meta.get("abstract"):
raise InvalidDocumentError("Cannot save an abstract document.")

signals.pre_save.send(self.__class__, document=self, **signal_kwargs)

if validate:
self.validate(clean=clean)

if write_concern is None:
write_concern = {}
# Used to avoid saving a document that is already saving (infinite loops)
# this can be caused by the cascade save and circular references
if getattr(self, "_is_saving", False):
return
self._is_saving = True

doc_id = self.to_mongo(fields=[self._meta["id_field"]])
created = "_id" not in doc_id or self._created or force_insert
try:
signal_kwargs = signal_kwargs or {}

signals.pre_save_post_validation.send(
self.__class__, document=self, created=created, **signal_kwargs
)
# it might be refreshed by the pre_save_post_validation hook, e.g., for etag generation
doc = self.to_mongo()
if write_concern is None:
write_concern = {}

if self._meta.get("auto_create_index", True):
self.ensure_indexes()

try:
# Save a new document or update an existing one
if created:
object_id = self._save_create(doc, force_insert, write_concern)
else:
object_id, created = self._save_update(
doc, save_condition, write_concern
)
if self._meta.get("abstract"):
raise InvalidDocumentError("Cannot save an abstract document.")

# Cascade save before validation to avoid child not existing errors
if cascade is None:
cascade = self._meta.get("cascade", False) or cascade_kwargs is not None

has_placeholder_saved = False

if cascade:
# If a cascade will occur save a placeholder version of this document to
# avoid issues with cyclic saves if this doc has not been created yet
if self.id is None:
self._save_place_holder(force_insert, write_concern)
has_placeholder_saved = True

kwargs = {
"force_insert": force_insert,
"validate": validate,
Expand All @@ -434,31 +426,74 @@ def save(
kwargs["_refs"] = _refs
self.cascade_save(**kwargs)

except pymongo.errors.DuplicateKeyError as err:
message = "Tried to save duplicate unique keys (%s)"
raise NotUniqueError(message % err)
except pymongo.errors.OperationFailure as err:
message = "Could not save document (%s)"
if re.match("^E1100[01] duplicate key", str(err)):
# E11000 - duplicate key error index
# E11001 - duplicate key on update
# update force_insert to reflect that we might have already run the insert for
# the placeholder
force_insert = force_insert and not has_placeholder_saved

signals.pre_save.send(self.__class__, document=self, **signal_kwargs)

if validate:
self.validate(clean=clean)

doc_id = self.to_mongo(fields=[self._meta["id_field"]])
created = "_id" not in doc_id or self._created or force_insert

signals.pre_save_post_validation.send(
self.__class__, document=self, created=created, **signal_kwargs
)
# it might be refreshed by the pre_save_post_validation hook, e.g., for etag generation
doc = self.to_mongo()

if self._meta.get("auto_create_index", True):
self.ensure_indexes()

try:
# Save a new document or update an existing one
if created:
object_id = self._save_create(doc, force_insert, write_concern)
else:
object_id, created = self._save_update(
doc, save_condition, write_concern
)
except pymongo.errors.DuplicateKeyError as err:
message = "Tried to save duplicate unique keys (%s)"
raise NotUniqueError(message % err)
raise OperationError(message % err)
except pymongo.errors.OperationFailure as err:
message = "Could not save document (%s)"
if re.match("^E1100[01] duplicate key", str(err)):
# E11000 - duplicate key error index
# E11001 - duplicate key on update
message = "Tried to save duplicate unique keys (%s)"
raise NotUniqueError(message % err)
raise OperationError(message % err)

# Make sure we store the PK on this document now that it's saved
id_field = self._meta["id_field"]
if created or id_field not in self._meta.get("shard_key", []):
self[id_field] = self._fields[id_field].to_python(object_id)

signals.post_save.send(
self.__class__, document=self, created=created, **signal_kwargs
)

# Make sure we store the PK on this document now that it's saved
id_field = self._meta["id_field"]
if created or id_field not in self._meta.get("shard_key", []):
self[id_field] = self._fields[id_field].to_python(object_id)
self._clear_changed_fields()
self._created = False
except Exception as e:
raise e
finally:
self._is_saving = False

signals.post_save.send(
self.__class__, document=self, created=created, **signal_kwargs
)
return self

self._clear_changed_fields()
self._created = False
def _save_place_holder(self, force_insert, write_concern):
"""Save a temp placeholder to the db with nothing but the ID.
"""
data = SON()

return self
object_id = self._save_create(data, force_insert, write_concern)

id_field = self._meta["id_field"]
self[id_field] = self._fields[id_field].to_python(object_id)

def _save_create(self, doc, force_insert, write_concern):
"""Save a new document.
Expand Down Expand Up @@ -556,28 +591,11 @@ def cascade_save(self, **kwargs):
"""Recursively save any references and generic references on the
document.
"""
_refs = kwargs.get("_refs") or []

ReferenceField = _import_class("ReferenceField")
GenericReferenceField = _import_class("GenericReferenceField")

for name, cls in self._fields.items():
if not isinstance(cls, (ReferenceField, GenericReferenceField)):
continue

ref = self._data.get(name)
if not ref or isinstance(ref, DBRef):
if not isinstance(cls, SaveableBaseField):
continue

if not getattr(ref, "_changed_fields", True):
continue

ref_id = f"{ref.__class__.__name__},{str(ref._data)}"
if ref and ref_id not in _refs:
_refs.append(ref_id)
kwargs["_refs"] = _refs
ref.save(**kwargs)
ref._changed_fields = []
cls.save(self, **kwargs)

@property
def _qs(self):
Expand Down
25 changes: 23 additions & 2 deletions mongoengine/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from mongoengine.base import (
BaseDocument,
BaseField,
SaveableBaseField,
ComplexBaseField,
GeoJsonBaseField,
LazyReference,
Expand Down Expand Up @@ -1123,7 +1124,7 @@ def __init__(self, field=None, *args, **kwargs):
super().__init__(field=field, *args, **kwargs)


class ReferenceField(BaseField):
class ReferenceField(SaveableBaseField):
"""A reference to a document that will be automatically dereferenced on
access (lazily).

Expand Down Expand Up @@ -1295,6 +1296,16 @@ def validate(self, value):
"saved to the database"
)

def save(self, instance, **kwargs):
ref = instance._data.get(self.name)
if not ref or isinstance(ref, DBRef):
return

if not getattr(self, "_changed_fields", True):
return

ref.save(**kwargs)

def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)

Expand Down Expand Up @@ -1464,7 +1475,7 @@ def sync_all(self):
self.owner_document.objects(**filter_kwargs).update(**update_kwargs)


class GenericReferenceField(BaseField):
class GenericReferenceField(SaveableBaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily).

Expand Down Expand Up @@ -1546,6 +1557,16 @@ def validate(self, value):
" saved to the database"
)

def save(self, instance, **kwargs):
ref = instance._data.get(self.name)
if not ref or isinstance(ref, DBRef):
return

if not getattr(ref, "_changed_fields", True):
return

ref.save(**kwargs)

def to_mongo(self, document):
if document is None:
return None
Expand Down
Loading