Skip to content

Commit 1e6c336

Browse files
committed
Fixed django#18177 -- Cached known related instances.
This was recently fixed for one-to-one relations; this patch adds support for foreign keys. Thanks kaiser.yann for the report and the initial version of the patch.
1 parent 3b2993e commit 1e6c336

File tree

7 files changed

+243
-31
lines changed

7 files changed

+243
-31
lines changed

django/db/models/fields/related.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -237,13 +237,18 @@ def get_query_set(self, **db_hints):
237237
return self.related.model._base_manager.using(db)
238238

239239
def get_prefetch_query_set(self, instances):
240-
vals = set(instance._get_pk_val() for instance in instances)
241-
params = {'%s__pk__in' % self.related.field.name: vals}
242-
return (self.get_query_set(instance=instances[0]).filter(**params),
243-
attrgetter(self.related.field.attname),
244-
lambda obj: obj._get_pk_val(),
245-
True,
246-
self.cache_name)
240+
rel_obj_attr = attrgetter(self.related.field.attname)
241+
instance_attr = lambda obj: obj._get_pk_val()
242+
instances_dict = dict((instance_attr(inst), inst) for inst in instances)
243+
params = {'%s__pk__in' % self.related.field.name: instances_dict.keys()}
244+
qs = self.get_query_set(instance=instances[0]).filter(**params)
245+
# Since we're going to assign directly in the cache,
246+
# we must manage the reverse relation cache manually.
247+
rel_obj_cache_name = self.related.field.get_cache_name()
248+
for rel_obj in qs:
249+
instance = instances_dict[rel_obj_attr(rel_obj)]
250+
setattr(rel_obj, rel_obj_cache_name, instance)
251+
return qs, rel_obj_attr, instance_attr, True, self.cache_name
247252

248253
def __get__(self, instance, instance_type=None):
249254
if instance is None:
@@ -324,17 +329,23 @@ def get_query_set(self, **db_hints):
324329
return QuerySet(self.field.rel.to).using(db)
325330

326331
def get_prefetch_query_set(self, instances):
327-
vals = set(getattr(instance, self.field.attname) for instance in instances)
332+
rel_obj_attr = attrgetter(self.field.rel.field_name)
333+
instance_attr = attrgetter(self.field.attname)
334+
instances_dict = dict((instance_attr(inst), inst) for inst in instances)
328335
other_field = self.field.rel.get_related_field()
329336
if other_field.rel:
330-
params = {'%s__pk__in' % self.field.rel.field_name: vals}
337+
params = {'%s__pk__in' % self.field.rel.field_name: instances_dict.keys()}
331338
else:
332-
params = {'%s__in' % self.field.rel.field_name: vals}
333-
return (self.get_query_set(instance=instances[0]).filter(**params),
334-
attrgetter(self.field.rel.field_name),
335-
attrgetter(self.field.attname),
336-
True,
337-
self.cache_name)
339+
params = {'%s__in' % self.field.rel.field_name: instances_dict.keys()}
340+
qs = self.get_query_set(instance=instances[0]).filter(**params)
341+
# Since we're going to assign directly in the cache,
342+
# we must manage the reverse relation cache manually.
343+
if not self.field.rel.multiple:
344+
rel_obj_cache_name = self.field.related.get_cache_name()
345+
for rel_obj in qs:
346+
instance = instances_dict[rel_obj_attr(rel_obj)]
347+
setattr(rel_obj, rel_obj_cache_name, instance)
348+
return qs, rel_obj_attr, instance_attr, True, self.cache_name
338349

339350
def __get__(self, instance, instance_type=None):
340351
if instance is None:
@@ -467,18 +478,24 @@ def get_query_set(self):
467478
return self.instance._prefetched_objects_cache[rel_field.related_query_name()]
468479
except (AttributeError, KeyError):
469480
db = self._db or router.db_for_read(self.model, instance=self.instance)
470-
return super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
481+
qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
482+
qs._known_related_object = (rel_field.name, self.instance)
483+
return qs
471484

472485
def get_prefetch_query_set(self, instances):
486+
rel_obj_attr = attrgetter(rel_field.get_attname())
487+
instance_attr = attrgetter(attname)
488+
instances_dict = dict((instance_attr(inst), inst) for inst in instances)
473489
db = self._db or router.db_for_read(self.model, instance=instances[0])
474-
query = {'%s__%s__in' % (rel_field.name, attname):
475-
set(getattr(obj, attname) for obj in instances)}
490+
query = {'%s__%s__in' % (rel_field.name, attname): instances_dict.keys()}
476491
qs = super(RelatedManager, self).get_query_set().using(db).filter(**query)
477-
return (qs,
478-
attrgetter(rel_field.get_attname()),
479-
attrgetter(attname),
480-
False,
481-
rel_field.related_query_name())
492+
# Since we just bypassed this class' get_query_set(), we must manage
493+
# the reverse relation manually.
494+
for rel_obj in qs:
495+
instance = instances_dict[rel_obj_attr(rel_obj)]
496+
setattr(rel_obj, rel_field.name, instance)
497+
cache_name = rel_field.related_query_name()
498+
return qs, rel_obj_attr, instance_attr, False, cache_name
482499

483500
def add(self, *objs):
484501
for obj in objs:

django/db/models/query.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(self, model=None, query=None, using=None):
4141
self._for_write = False
4242
self._prefetch_related_lookups = []
4343
self._prefetch_done = False
44+
self._known_related_object = None # (attname, rel_obj)
4445

4546
########################
4647
# PYTHON MAGIC METHODS #
@@ -282,9 +283,10 @@ def iterator(self):
282283
init_list.append(field.attname)
283284
model_cls = deferred_class_factory(self.model, skip)
284285

285-
# Cache db and model outside the loop
286+
# Cache db, model and known_related_object outside the loop
286287
db = self.db
287288
model = self.model
289+
kro_attname, kro_instance = self._known_related_object or (None, None)
288290
compiler = self.query.get_compiler(using=db)
289291
if fill_cache:
290292
klass_info = get_klass_info(model, max_depth=max_depth,
@@ -294,12 +296,12 @@ def iterator(self):
294296
obj, _ = get_cached_row(row, index_start, db, klass_info,
295297
offset=len(aggregate_select))
296298
else:
299+
# Omit aggregates in object creation.
300+
row_data = row[index_start:aggregate_start]
297301
if skip:
298-
row_data = row[index_start:aggregate_start]
299302
obj = model_cls(**dict(zip(init_list, row_data)))
300303
else:
301-
# Omit aggregates in object creation.
302-
obj = model(*row[index_start:aggregate_start])
304+
obj = model(*row_data)
303305

304306
# Store the source database of the object
305307
obj._state.db = db
@@ -313,7 +315,11 @@ def iterator(self):
313315
# Add the aggregates to the model
314316
if aggregate_select:
315317
for i, aggregate in enumerate(aggregate_select):
316-
setattr(obj, aggregate, row[i+aggregate_start])
318+
setattr(obj, aggregate, row[i + aggregate_start])
319+
320+
# Add the known related object to the model, if there is one
321+
if kro_instance:
322+
setattr(obj, kro_attname, kro_instance)
317323

318324
yield obj
319325

@@ -864,6 +870,7 @@ def _clone(self, klass=None, setup=False, **kwargs):
864870
c = klass(model=self.model, query=query, using=self._db)
865871
c._for_write = self._for_write
866872
c._prefetch_related_lookups = self._prefetch_related_lookups[:]
873+
c._known_related_object = self._known_related_object
867874
c.__dict__.update(kwargs)
868875
if setup and hasattr(c, '_setup_query'):
869876
c._setup_query()
@@ -1781,9 +1788,7 @@ def prefetch_one_level(instances, prefetcher, attname):
17811788
rel_obj_cache = {}
17821789
for rel_obj in all_related_objects:
17831790
rel_attr_val = rel_obj_attr(rel_obj)
1784-
if rel_attr_val not in rel_obj_cache:
1785-
rel_obj_cache[rel_attr_val] = []
1786-
rel_obj_cache[rel_attr_val].append(rel_obj)
1791+
rel_obj_cache.setdefault(rel_attr_val, []).append(rel_obj)
17871792

17881793
for obj in instances:
17891794
instance_attr_val = instance_attr(obj)

docs/releases/1.5.txt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,24 @@ reasons or when trying to avoid overwriting concurrent changes.
4444
See the :meth:`Model.save() <django.db.models.Model.save()>` documentation for
4545
more details.
4646

47+
Caching of related model instances
48+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
49+
50+
When traversing relations, the ORM will avoid re-fetching objects that were
51+
previously loaded. For example, with the tutorial's models::
52+
53+
>>> first_poll = Poll.objects.all()[0]
54+
>>> first_choice = first_poll.choice_set.all()[0]
55+
>>> first_choice.poll is first_poll
56+
True
57+
58+
In Django 1.5, the third line no longer triggers a new SQL query to fetch
59+
``first_choice.poll``; it was set when by the second line.
60+
61+
For one-to-one relationships, both sides can be cached. For many-to-one
62+
relationships, only the single side of the relationship can be cached. This
63+
is particularly helpful in combination with ``prefetch_related``.
64+
4765
Minor features
4866
~~~~~~~~~~~~~~
4967

tests/modeltests/known_related_objects/__init__.py

Whitespace-only changes.
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
[
2+
{
3+
"pk": 1,
4+
"model": "known_related_objects.tournament",
5+
"fields": {
6+
"name": "Tourney 1"
7+
}
8+
},
9+
{
10+
"pk": 2,
11+
"model": "known_related_objects.tournament",
12+
"fields": {
13+
"name": "Tourney 2"
14+
}
15+
},
16+
{
17+
"pk": 1,
18+
"model": "known_related_objects.pool",
19+
"fields": {
20+
"tournament": 1,
21+
"name": "T1 Pool 1"
22+
}
23+
},
24+
{
25+
"pk": 2,
26+
"model": "known_related_objects.pool",
27+
"fields": {
28+
"tournament": 1,
29+
"name": "T1 Pool 2"
30+
}
31+
},
32+
{
33+
"pk": 3,
34+
"model": "known_related_objects.pool",
35+
"fields": {
36+
"tournament": 2,
37+
"name": "T2 Pool 1"
38+
}
39+
},
40+
{
41+
"pk": 4,
42+
"model": "known_related_objects.pool",
43+
"fields": {
44+
"tournament": 2,
45+
"name": "T2 Pool 2"
46+
}
47+
},
48+
{
49+
"pk": 1,
50+
"model": "known_related_objects.poolstyle",
51+
"fields": {
52+
"name": "T1 Pool 2 Style",
53+
"pool": 2
54+
}
55+
},
56+
{
57+
"pk": 2,
58+
"model": "known_related_objects.poolstyle",
59+
"fields": {
60+
"name": "T2 Pool 1 Style",
61+
"pool": 3
62+
}
63+
}
64+
]
65+
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""
2+
Existing related object instance caching.
3+
4+
Test that queries are not redone when going back through known relations.
5+
"""
6+
7+
from django.db import models
8+
9+
class Tournament(models.Model):
10+
name = models.CharField(max_length=30)
11+
12+
class Pool(models.Model):
13+
name = models.CharField(max_length=30)
14+
tournament = models.ForeignKey(Tournament)
15+
16+
class PoolStyle(models.Model):
17+
name = models.CharField(max_length=30)
18+
pool = models.OneToOneField(Pool)
19+
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from __future__ import absolute_import
2+
3+
from django.test import TestCase
4+
5+
from .models import Tournament, Pool, PoolStyle
6+
7+
class ExistingRelatedInstancesTests(TestCase):
8+
fixtures = ['tournament.json']
9+
10+
def test_foreign_key(self):
11+
with self.assertNumQueries(2):
12+
tournament = Tournament.objects.get(pk=1)
13+
pool = tournament.pool_set.all()[0]
14+
self.assertIs(tournament, pool.tournament)
15+
16+
def test_foreign_key_prefetch_related(self):
17+
with self.assertNumQueries(2):
18+
tournament = (Tournament.objects.prefetch_related('pool_set').get(pk=1))
19+
pool = tournament.pool_set.all()[0]
20+
self.assertIs(tournament, pool.tournament)
21+
22+
def test_foreign_key_multiple_prefetch(self):
23+
with self.assertNumQueries(2):
24+
tournaments = list(Tournament.objects.prefetch_related('pool_set'))
25+
pool1 = tournaments[0].pool_set.all()[0]
26+
self.assertIs(tournaments[0], pool1.tournament)
27+
pool2 = tournaments[1].pool_set.all()[0]
28+
self.assertIs(tournaments[1], pool2.tournament)
29+
30+
def test_one_to_one(self):
31+
with self.assertNumQueries(2):
32+
style = PoolStyle.objects.get(pk=1)
33+
pool = style.pool
34+
self.assertIs(style, pool.poolstyle)
35+
36+
def test_one_to_one_select_related(self):
37+
with self.assertNumQueries(1):
38+
style = PoolStyle.objects.select_related('pool').get(pk=1)
39+
pool = style.pool
40+
self.assertIs(style, pool.poolstyle)
41+
42+
def test_one_to_one_multi_select_related(self):
43+
with self.assertNumQueries(1):
44+
poolstyles = list(PoolStyle.objects.select_related('pool'))
45+
self.assertIs(poolstyles[0], poolstyles[0].pool.poolstyle)
46+
self.assertIs(poolstyles[1], poolstyles[1].pool.poolstyle)
47+
48+
def test_one_to_one_prefetch_related(self):
49+
with self.assertNumQueries(2):
50+
style = PoolStyle.objects.prefetch_related('pool').get(pk=1)
51+
pool = style.pool
52+
self.assertIs(style, pool.poolstyle)
53+
54+
def test_one_to_one_multi_prefetch_related(self):
55+
with self.assertNumQueries(2):
56+
poolstyles = list(PoolStyle.objects.prefetch_related('pool'))
57+
self.assertIs(poolstyles[0], poolstyles[0].pool.poolstyle)
58+
self.assertIs(poolstyles[1], poolstyles[1].pool.poolstyle)
59+
60+
def test_reverse_one_to_one(self):
61+
with self.assertNumQueries(2):
62+
pool = Pool.objects.get(pk=2)
63+
style = pool.poolstyle
64+
self.assertIs(pool, style.pool)
65+
66+
def test_reverse_one_to_one_select_related(self):
67+
with self.assertNumQueries(1):
68+
pool = Pool.objects.select_related('poolstyle').get(pk=2)
69+
style = pool.poolstyle
70+
self.assertIs(pool, style.pool)
71+
72+
def test_reverse_one_to_one_prefetch_related(self):
73+
with self.assertNumQueries(2):
74+
pool = Pool.objects.prefetch_related('poolstyle').get(pk=2)
75+
style = pool.poolstyle
76+
self.assertIs(pool, style.pool)
77+
78+
def test_reverse_one_to_one_multi_select_related(self):
79+
with self.assertNumQueries(1):
80+
pools = list(Pool.objects.select_related('poolstyle'))
81+
self.assertIs(pools[1], pools[1].poolstyle.pool)
82+
self.assertIs(pools[2], pools[2].poolstyle.pool)
83+
84+
def test_reverse_one_to_one_multi_prefetch_related(self):
85+
with self.assertNumQueries(2):
86+
pools = list(Pool.objects.prefetch_related('poolstyle'))
87+
self.assertIs(pools[1], pools[1].poolstyle.pool)
88+
self.assertIs(pools[2], pools[2].poolstyle.pool)

0 commit comments

Comments
 (0)