from logging import getLogger
import time
import collections
import django
from django.db import models
from django.db.models import query
try:
from django.db.models.fields.related import ReverseSingleRelatedObjectDescriptor as ForwardManyToOneDescriptor
except ImportError:
from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor
# Version string of this module.
__version__ = '1.1.0'
# Module-level logger named after this module; used for timing/debug output
# and for reporting prefetch failures.
logger = getLogger(__name__)
class PrefetchManagerMixin(models.Manager):
    """
    Manager base that hands out ``PrefetchQuerySet`` instances.

    Subclasses (or instances) provide ``prefetch_definitions``: a mapping of
    prefetch name to either a ``Prefetcher`` instance or a callable (usually
    a ``Prefetcher`` subclass) that builds one.
    """
    use_for_related_fields = True
    # Mapping of prefetch name -> prefetcher; meant to be replaced by
    # subclasses/instances, never mutated in place.
    prefetch_definitions = {}

    @classmethod
    def get_queryset_class(cls):
        """Return the queryset class this manager instantiates."""
        return PrefetchQuerySet

    def __init__(self):
        """
        Initialise the manager and validate every prefetch definition.

        Raises:
            InvalidPrefetch: if a definition is neither a plain
                ``Prefetcher`` instance nor a callable.
        """
        super(PrefetchManagerMixin, self).__init__()
        for label, definition in self.prefetch_definitions.items():
            if not (definition.__class__ is Prefetcher or callable(definition)):
                raise InvalidPrefetch("Invalid prefetch definition %s. This prefetcher needs to be a class not an instance." % label)

    def get_queryset(self):
        """
        Build the base queryset, carrying this manager's prefetch
        definitions and, when set, its ``_db`` alias.
        """
        queryset_class = self.get_queryset_class()
        queryset = queryset_class(
            self.model, prefetch_definitions=self.prefetch_definitions
        )
        database = getattr(self, '_db', None)
        if database is None:
            return queryset
        return queryset.using(database)

    def get_query_set(self):
        """
        Alias of ``get_queryset`` kept for Django <1.6, which used this
        spelling.
        """
        return self.get_queryset()

    def prefetch(self, *args):
        """Shortcut for ``self.get_queryset().prefetch(*args)``."""
        base = self.get_queryset()
        return base.prefetch(*args)
class PrefetchManager(PrefetchManagerMixin):
    """
    Manager whose prefetch definitions are supplied as keyword arguments,
    one ``name=prefetcher`` pair per definition.
    """

    def __init__(self, **prefetchers):
        # Assign before calling the parent initialiser: it iterates over
        # ``self.prefetch_definitions`` to validate them.
        self.prefetch_definitions = prefetchers
        super(PrefetchManager, self).__init__()
class InvalidPrefetch(Exception):
    """Raised for a bad prefetch definition or an invalid ``prefetch()`` call."""
class PrefetchOption(object):
    """
    A prefetch name bundled with the positional and keyword arguments that
    should be passed to the prefetcher when it is instantiated.
    """

    def __init__(self, name, *args, **kwargs):
        self.name, self.args, self.kwargs = name, args, kwargs


# Short alias for call sites.
P = PrefetchOption
class PrefetchQuerySet(query.QuerySet):
    """
    QuerySet that can run registered prefetchers over its results.

    ``prefetch_definitions`` maps prefetch names to prefetchers (see
    ``PrefetchManagerMixin``); ``_prefetch`` maps each requested lookup name
    to a ``(forwarders, prefetcher_instance)`` pair, where ``forwarders`` is
    the list of forward-relation field names to walk before prefetching.
    """

    def __init__(self, model=None, query=None, using=None,
                 prefetch_definitions=None, **kwargs):
        if using is None:  # this is to support Django 1.1
            super(PrefetchQuerySet, self).__init__(model, query, **kwargs)
        else:
            super(PrefetchQuerySet, self).__init__(model, query, using, **kwargs)
        self._prefetch = {}
        self.prefetch_definitions = prefetch_definitions

    def _clone(self, **kwargs):
        """
        Clone the queryset, carrying the prefetch state over.

        The state is assigned after cloning instead of being passed as
        keyword arguments, so this does not depend on the parent ``_clone``
        accepting arbitrary kwargs. The ``_prefetch`` dict is copied so that
        registering a prefetch on a clone never leaks back into the queryset
        it was cloned from.
        """
        clone = super(PrefetchQuerySet, self)._clone(**kwargs)
        clone._prefetch = dict(self._prefetch)
        clone.prefetch_definitions = self.prefetch_definitions
        return clone

    def prefetch(self, *names):
        """
        Return a clone that will run the named prefetchers when evaluated.

        Each name is either a string or a ``PrefetchOption``. A name has the
        form ``[fk__[fk__...]]prefetcher``: zero or more forward relations
        followed by exactly one prefetch definition, looked up on the manager
        (``model.objects``) of the model reached through those relations.

        Raises:
            InvalidPrefetch: if a part is neither a prefetch definition nor a
                forward relation, if anything follows the prefetcher, if no
                prefetcher is named, if a related model's default manager is
                not a ``PrefetchManagerMixin``, or if options are given for a
                plain ``Prefetcher`` instance.
        """
        obj = self._clone()

        for opt in names:
            if isinstance(opt, PrefetchOption):
                name = opt.name
            else:
                name = opt
                opt = None

            parts = name.split('__')
            forwarders = []
            prefetcher = None
            model = self.model
            prefetch_definitions = self.prefetch_definitions

            for what in parts:
                if not prefetcher:
                    if what in prefetch_definitions:
                        prefetcher = prefetch_definitions[what]
                        continue
                    descriptor = getattr(model, what, None)
                    if isinstance(descriptor, ForwardManyToOneDescriptor):
                        field = descriptor.field
                        forwarders.append(field.name)
                        # Newer Django exposes the target model through
                        # ``remote_field``; older releases use ``rel.to``.
                        if hasattr(field, 'remote_field'):
                            model = field.remote_field.model
                        else:
                            model = field.rel.to
                        manager = model.objects
                        if not isinstance(manager, PrefetchManagerMixin):
                            raise InvalidPrefetch('Manager for %s is not a PrefetchManagerMixin instance.' % model)
                        prefetch_definitions = manager.prefetch_definitions
                    else:
                        raise InvalidPrefetch("Invalid part %s in prefetch call for %s on model %s. "
                                              "The name is not a prefetcher nor a forward relation (fk)." % (
                                                  what, name, self.model))
                else:
                    raise InvalidPrefetch("Invalid part %s in prefetch call for %s on model %s. "
                                          "You cannot have any more relations after the prefetcher." % (
                                              what, name, self.model))

            if not prefetcher:
                raise InvalidPrefetch("Invalid prefetch call with %s for on model %s. "
                                      "The last part isn't a prefetch definition." % (name, self.model))

            if opt:
                # Options can only be applied when the definition is a
                # class/factory that still has to be instantiated.
                if prefetcher.__class__ is Prefetcher:
                    raise InvalidPrefetch("Invalid prefetch call with %s for on model %s. "
                                          "This prefetcher (%s) needs to be a subclass of Prefetcher." % (
                                              name, self.model, prefetcher))
                obj._prefetch[name] = forwarders, prefetcher(*opt.args, **opt.kwargs)
            else:
                instance = prefetcher if prefetcher.__class__ is Prefetcher else prefetcher()
                obj._prefetch[name] = forwarders, instance

        # Make sure the forward relations are fetched together with the main
        # query. Iterate over a snapshot because ``obj`` may be rebound to a
        # new clone inside the loop.
        for forwarders, prefetcher in list(obj._prefetch.values()):
            if forwarders:
                if django.VERSION < (1, 7) and obj.query.select_related:
                    if not obj.query.max_depth:
                        # NOTE(review): ``add_select_related`` is a private
                        # Django API that presumably expects an iterable of
                        # lookups; confirm that a bare string is intended.
                        obj.query.add_select_related('__'.join(forwarders))
                else:
                    obj = obj.select_related('__'.join(forwarders))
        return obj

    def iterator(self, *args, **kwargs):
        """
        Evaluate the query, run every registered prefetcher over the full
        result list, and return an iterator over that list.

        Positional and keyword arguments are forwarded unchanged to the
        parent ``iterator()``.

        NOTE(review): newer Django versions may fill the result cache without
        going through ``iterator()``; confirm prefetching still triggers on
        the Django versions being targeted.
        """
        data = list(super(PrefetchQuerySet, self).iterator(*args, **kwargs))
        for name, (forwarders, prefetcher) in self._prefetch.items():
            prefetcher.fetch(data, name, self.model, forwarders,
                             getattr(self, '_db', None))
        return iter(data)
class Prefetcher(object):
    """
    Prefetch definition. For convenience you can either subclass this and
    define the methods on the subclass or just pass the functions to the
    constructor.

    Eg, subclassing::

        class GroupPrefetcher(Prefetcher):

            @staticmethod
            def filter(ids):
                return User.groups.through.objects.filter(user__in=ids).select_related('group')

            @staticmethod
            def reverse_mapper(user_group_association):
                return [user_group_association.user_id]

            @staticmethod
            def decorator(user, user_group_associations=()):
                setattr(user, 'prefetched_groups', [i.group for i in user_group_associations])

    Or with constructor::

        Prefetcher(
            filter = lambda ids: User.groups.through.objects.filter(user__in=ids).select_related('group'),
            reverse_mapper = lambda user_group_association: [user_group_association.user_id],
            decorator = lambda user, user_group_associations=(): setattr(user, 'prefetched_groups', [
                i.group for i in user_group_associations
            ])
        )

    Glossary:

    * filter(list_of_ids):

        A function that returns a queryset containing all the related data for a given list of keys.
        Takes a list of ids as argument.

    * reverse_mapper(related_object):

        A function that takes the related object as argument and returns a list
        of keys that maps that related object to the objects in the queryset.

    * mapper(object):

        Optional (defaults to ``lambda obj: obj.id``).

        A function that returns the key for a given object in your query set.

    * decorator(object, list_of_related_objects):

        A function that will save the related data on each of your objects in
        your queryset. Takes the object and a list of related objects as
        arguments. Note that you should not override existing attributes on the
        model instance here.

        It is also called with the object alone, so the second argument
        needs a default (as in the examples above).

    * collect:

        Optional flag (defaults to ``False``). When true, or whenever the
        prefetch goes through forward relations, every object sharing a key
        is decorated instead of only the last one seen for that key.
    """
    collect = False

    def __init__(self, filter=None, reverse_mapper=None, decorator=None, mapper=None, collect=None):
        """
        Store the supplied callables, falling back to ones defined on the
        class.

        Raises:
            RuntimeError: if ``filter``, ``reverse_mapper`` or ``decorator``
                is neither passed in nor defined on the class.
        """
        if filter:
            self.filter = filter
        elif not hasattr(self, 'filter'):
            raise RuntimeError("You must define a filter function")

        if reverse_mapper:
            self.reverse_mapper = reverse_mapper
        elif not hasattr(self, 'reverse_mapper'):
            raise RuntimeError("You must define a reverse_mapper function")

        if decorator:
            self.decorator = decorator
        elif not hasattr(self, 'decorator'):
            raise RuntimeError("You must define a decorator function")

        if mapper:
            self.mapper = mapper

        if collect is not None:
            self.collect = collect

    @staticmethod
    def mapper(obj):
        """Default key function: the object's ``id``."""
        return obj.id

    def fetch(self, dataset, name, model, forwarders, db):
        """
        Load the related data for ``dataset`` and attach it via ``decorator``.

        Args:
            dataset: the already-evaluated objects of the main query.
            name: the prefetch name; only used in log messages.
            model: the model of the main query; only its ``__name__`` is
                used, in log messages.
            forwarders: attribute names to follow from each object to reach
                the object that is actually decorated (may be empty).
            db: database alias applied to the related queryset with
                ``.using(db)``; ``None`` leaves it untouched.

        Returns:
            ``dataset`` itself; the objects are decorated in place.

        Any exception is logged with ``logger.exception`` and re-raised.
        """
        collect = self.collect or forwarders

        try:
            # Phase 1: index the target objects by key.
            data_mapping = collections.defaultdict(list)
            t1 = time.time()
            for obj in dataset:
                for field in forwarders:
                    obj = getattr(obj, field, None)

                if not obj:
                    continue

                if collect:
                    data_mapping[self.mapper(obj)].append(obj)
                else:
                    data_mapping[self.mapper(obj)] = obj

                # Give every object its "no related data" default first;
                # phase 3 overrides it for objects that do have relations.
                self.decorator(obj)

            t2 = time.time()
            logger.debug("Creating data_mapping for %s query took %.3f secs for the %s prefetcher.",
                         model.__name__, t2-t1, name)

            # Phase 2: fetch the related objects. ``filter`` is documented to
            # receive a list, so hand it a real list rather than a dict view.
            t1 = time.time()
            related_data = self.filter(list(data_mapping))
            if db is not None:
                related_data = related_data.using(db)
            related_data_len = len(related_data)
            t2 = time.time()
            logger.debug("Filtering for %s related objects for %s query took %.3f secs for the %s prefetcher.",
                         related_data_len, model.__name__, t2-t1, name)

            # Phase 3: group the related objects by key and decorate.
            relation_mapping = collections.defaultdict(list)
            t1 = time.time()
            for obj in related_data:
                for id_ in self.reverse_mapper(obj):
                    if id_:
                        relation_mapping[id_].append(obj)

            for id_, related_items in relation_mapping.items():
                if id_ in data_mapping:
                    if collect:
                        for item in data_mapping[id_]:
                            self.decorator(item, related_items)
                    else:
                        self.decorator(data_mapping[id_], related_items)

            t2 = time.time()
            logger.debug("Adding the related objects on the %s query took %.3f secs for the %s prefetcher.",
                         model.__name__, t2-t1, name)
            return dataset
        except Exception:
            logger.exception("Prefetch failed for %s prefetch on the %s model:", name, model.__name__)
            raise