Skip to content
34 changes: 32 additions & 2 deletions cacheops/query.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
import copy
import sys
import json
import threading
Expand All @@ -10,7 +12,7 @@
from django.utils.encoding import smart_str, force_text
from django.core.exceptions import ImproperlyConfigured, EmptyResultSet
from django.db import DEFAULT_DB_ALIAS
from django.db.models import Model
from django.db.models import Model, Prefetch
from django.db.models.manager import BaseManager
from django.db.models.query import QuerySet
from django.db.models.signals import pre_save, post_save, post_delete, m2m_changed
Expand All @@ -23,7 +25,8 @@
MAX_GET_RESULTS = None

from .conf import model_profile, settings, ALL_OPS
from .utils import monkey_mix, stamp_fields, func_cache_key, cached_view_fab, family_has_profile
from .utils import monkey_mix, stamp_fields, func_cache_key, cached_view_fab, \
family_has_profile, get_model_from_lookup
from .utils import md5
from .sharding import get_prefix
from .redis import redis_client, handle_connection_failure, load_script
Expand Down Expand Up @@ -251,6 +254,33 @@ def nocache(self):
else:
return self.cache(ops=[])

def cache_prefetch_related(self, *lookups):
"""
Same as prefetch_related but attempts to pull relations from the cache instead

lookups - same as for django's vanilla prefetch_related()
"""

# If relations are already fetched there is no point to continuing
if self._prefetch_done:
return self

prefetches = []

for pf in lookups:
if isinstance(pf, Prefetch):
item = copy.copy(pf)
item.queryset = item.queryset.cache(ops=['fetch'])
prefetches.append(item)

if isinstance(pf, str):
model_class = get_model_from_lookup(self.model, pf)
prefetches.append(
Prefetch(pf, model_class._default_manager.all().cache(ops=['fetch']))
)

return self.prefetch_related(*prefetches)

def cloning(self, cloning=1000):
self._cloning = cloning
return self
Expand Down
25 changes: 25 additions & 0 deletions cacheops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,31 @@ def wrapper(request, *args, **kwargs):
return cached_view


def get_model_from_lookup(base_model, orm_lookup):
"""
Given a base model and an ORM lookup, follow any relations and return
the final model class of the lookup.
"""

result = base_model
for field_name in orm_lookup.split('__'):

if field_name.endswith('_set'):
field_name = field_name.split('_set')[0]

try:
field = result._meta.get_field(field_name)
except models.FieldDoesNotExist:
break

if hasattr(field, 'related_model'):
result = field.related_model
else:
break

return result


### Whitespace handling for template tags

from django.utils.safestring import mark_safe
Expand Down
26 changes: 25 additions & 1 deletion tests/test_extras.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from django.db import transaction
from django.db.models import Prefetch
from django.test import TestCase, override_settings

from cacheops import cached_as, no_invalidation, invalidate_obj, invalidate_model, invalidate_all
from cacheops.conf import settings
from cacheops.signals import cache_read, cache_invalidated
from cacheops.utils import get_model_from_lookup

from .utils import BaseTestCase, make_inc
from .models import Post, Category, Local, DbAgnostic, DbBinded
from .models import Post, Category, Local, DbAgnostic, DbBinded, Brand, Label


class SettingsTests(TestCase):
Expand Down Expand Up @@ -183,3 +185,25 @@ def test_db_agnostic_disabled(self):

with self.assertNumQueries(1, using='slave'):
list(DbBinded.objects.cache().using('slave'))


class CachedPrefetchTest(BaseTestCase):

def test_get_model_from_lookup(self):
assert get_model_from_lookup(Brand, 'labels') is Label

def test_cache_prefetch_related(self):
qs = Brand.objects.all().cache_prefetch_related('labels')

pf = qs._prefetch_related_lookups[0]

assert isinstance(pf, Prefetch)
assert pf.queryset.model is Label
assert pf.queryset._cacheprofile

def test_cache_prefetch_related_with_ops(self):
qs = Brand.objects.all().cache_prefetch_related('labels')

pf = qs._prefetch_related_lookups[0]

self.assertEqual(pf.queryset._cacheprofile['ops'], {'fetch'})