Compare commits
1 Commits
starfields
...
starfields
| Author | SHA1 | Date | |
|---|---|---|---|
| c1fd01a6d8 |
122
starfields_drf_generics/cache_mixins.py
Normal file
122
starfields_drf_generics/cache_mixins.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from libraries.utils import sorted_params_string
|
||||
|
||||
# TODO classes below that involve create, update, destroy don't delete the caches properly, they need a regex cache delete
|
||||
# TODO put more reasonable asserts and feedback
|
||||
|
||||
# Mixin classes that provide cache functionalities
|
||||
class CacheUniqueUrl:
|
||||
def get_cache_unique_url(self, request):
|
||||
""" Create the query to be cached in a unique way to avoid duplicates. """
|
||||
if not hasattr(self, 'filters_string'):
|
||||
# Only assign the attribute if it's not already assigned
|
||||
filters = {}
|
||||
if self.extra_filters_dict:
|
||||
filters.update(self.extra_filters_dict)
|
||||
# Check if the url parameters have any of the keys of the extra filters and if so assign them
|
||||
for key in self.extra_filters_dict:
|
||||
if key in self.request.query_params.keys():
|
||||
filters[key] = self.request.query_params[key].replace(' ','').split(',')
|
||||
# Check if they're resolved in the urlconf as well
|
||||
if key in self.kwargs.keys():
|
||||
filters[key] = [self.kwargs[key]]
|
||||
|
||||
if hasattr(self, 'paged'):
|
||||
if self.paged:
|
||||
filters.update({'limit': [self.default_page_size], 'offset': [0]})
|
||||
if 'limit' in self.request.query_params.keys():
|
||||
filters.update({'limit': [self.request.query_params['limit']]})
|
||||
if 'offset' in self.request.query_params.keys():
|
||||
filters.update({'offset': [self.request.query_params['offset']]})
|
||||
for backend in list(self.filter_backends):
|
||||
filters.update(backend().get_filters_dict(request, self))
|
||||
self.filters_string = sorted_params_string(filters)
|
||||
|
||||
|
||||
class CacheGetMixin(CacheUniqueUrl):
|
||||
cache_prefix = None
|
||||
cache_vary_on_user = False
|
||||
cache_timeout_mins = None
|
||||
default_page_size = 20
|
||||
extra_filters_dict = None
|
||||
|
||||
def get_cache(self, request):
|
||||
assert self.cache_prefix is not None, (
|
||||
"'%s' should include a `cache_prefix` attribute"
|
||||
% self.__class__.__name__
|
||||
)
|
||||
|
||||
self.get_cache_unique_url(request)
|
||||
|
||||
# Attempt to get the response from the cache for the whole request
|
||||
try:
|
||||
if self.cache_vary_on_user:
|
||||
cache_attempt = self.cache.get(f"{self.cache_prefix}.{request.user}.{self.filters_string}")
|
||||
else:
|
||||
cache_attempt = self.cache.get(f"{self.cache_prefix}.{self.filters_string}")
|
||||
except:
|
||||
self.logger.info(f"Cache get attempt for {self.__class__.__name__} failed.")
|
||||
cache_attempt = None
|
||||
|
||||
if cache_attempt:
|
||||
return cache_attempt
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class CacheSetMixin(CacheUniqueUrl):
|
||||
cache_prefix = None
|
||||
cache_vary_on_user = False
|
||||
cache_timeout_mins = None
|
||||
default_page_size = 20
|
||||
extra_filters_dict = None
|
||||
|
||||
def set_cache(self, request, response):
|
||||
self.get_cache_unique_url(request)
|
||||
|
||||
# Create a function that programmatically defines the caching function
|
||||
def make_caching_function(cls, request, cache):
|
||||
def caching_function(response):
|
||||
# Writes the response to the cache
|
||||
try:
|
||||
if self.cache_vary_on_user:
|
||||
self.cache.set(key=f"{self.cache_prefix}.{request.user}.{self.filters_string}",
|
||||
value = response.data,
|
||||
timeout=60*self.cache_timeout_mins)
|
||||
else:
|
||||
self.cache.set(key=f"{self.cache_prefix}.{self.filters_string}",
|
||||
value = response.data,
|
||||
timeout=60*self.cache_timeout_mins)
|
||||
except:
|
||||
self.logger.exception(f"Cache set attempt for {self.__class__.__name__} failed.")
|
||||
return caching_function
|
||||
|
||||
# Register the post rendering hook to the response
|
||||
caching_function = make_caching_function(self, request, self.cache)
|
||||
response.add_post_render_callback(caching_function)
|
||||
|
||||
|
||||
class CacheDeleteMixin(CacheUniqueUrl):
|
||||
cache_delete = True
|
||||
cache_prefix = None
|
||||
cache_vary_on_user = False
|
||||
cache_timeout_mins = None
|
||||
extra_filters_dict = None
|
||||
|
||||
def delete_cache(self, request):
|
||||
# Handle the caching
|
||||
if self.cache_delete:
|
||||
# Create the query to be cached in a unique way to avoid duplicates
|
||||
self.get_cache_unique_url(request)
|
||||
|
||||
assert self.cache_prefix is not None, (
|
||||
f"{self.__class__.__name__} should include a `cache_prefix` attribute"
|
||||
)
|
||||
|
||||
# Delete the cache since a new entry has been created
|
||||
try:
|
||||
if self.cache_vary_on_user:
|
||||
self.cache.delete(f"{self.cache_prefix}.{request.user}.{self.filters_string}")
|
||||
else:
|
||||
self.cache.delete(f"{self.cache_prefix}.{self.filters_string}")
|
||||
except:
|
||||
self.logger.exception(f"Cache delete attempt for {self.__class__.__name__} failed.")
|
||||
793
starfields_drf_generics/filters.py
Normal file
793
starfields_drf_generics/filters.py
Normal file
@@ -0,0 +1,793 @@
|
||||
from django_filters import rest_framework as filters
|
||||
from django.contrib.postgres.search import TrigramSimilarity
|
||||
from django.db.models.functions import Concat
|
||||
from django.db.models import CharField
|
||||
from shop.models.product import Product, Facet
|
||||
from rest_framework.filters import BaseFilterBackend
|
||||
import operator
|
||||
from django.template import loader
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.db.models import Max, Min, Count, Q
|
||||
from rest_framework.settings import api_settings
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db import models
|
||||
from functools import reduce
|
||||
|
||||
# TODO the dev pages are not done
|
||||
|
||||
# Filters
|
||||
class LessThanOrEqualFilter(BaseFilterBackend):
|
||||
def get_less_than_field(self, view, request):
|
||||
return getattr(view, 'less_than_field', None)
|
||||
|
||||
def get_filters_dict(self, request, view):
|
||||
"""
|
||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
||||
"""
|
||||
less_than_field = self.get_less_than_field(view, request)
|
||||
|
||||
assert less_than_field is not None, (
|
||||
f"{view.__class__.__name__} should include a `less_than_field` attribute"
|
||||
)
|
||||
|
||||
filters_dict = {}
|
||||
if less_than_field+'_max' in request.query_params.keys():
|
||||
field_value = request.query_params.get(less_than_field+'_max')
|
||||
filters_dict[less_than_field+'_max'] = [field_value]
|
||||
else:
|
||||
filters_dict[less_than_field+'_max'] = []
|
||||
return filters_dict
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
# Return the correctly filtered queryset but also assign the filter dict to create the unique url for the cache
|
||||
less_than_field = self.get_less_than_field(view, request)
|
||||
if less_than_field+'_max' in request.query_params.keys():
|
||||
kwquery = {}
|
||||
field_value = request.query_params.get(less_than_field+'_max')
|
||||
kwquery[less_than_field+'__lte'] = field_value
|
||||
return queryset.filter(**kwquery)
|
||||
else:
|
||||
return queryset
|
||||
|
||||
|
||||
class MoreThanOrEqualFilter(BaseFilterBackend):
|
||||
def get_more_than_field(self, view, request):
|
||||
return getattr(view, 'more_than_field', None)
|
||||
|
||||
def get_filters_dict(self, request, view):
|
||||
"""
|
||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
||||
"""
|
||||
more_than_field = self.get_more_than_field(view, request)
|
||||
|
||||
assert more_than_field is not None, (
|
||||
f"{view.__class__.__name__} should include a `more_than_field` attribute"
|
||||
)
|
||||
|
||||
filters_dict = {}
|
||||
if more_than_field+'_min' in request.query_params.keys():
|
||||
field_value = request.query_params.get(more_than_field+'_min')
|
||||
filters_dict[more_than_field+'_min'] = [field_value]
|
||||
else:
|
||||
filters_dict[more_than_field+'_min'] = []
|
||||
return filters_dict
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
"""
|
||||
Return the correctly filtered queryset
|
||||
"""
|
||||
more_than_field = self.get_more_than_field(view, request)
|
||||
if more_than_field+'_min' in request.query_params.keys():
|
||||
kwquery = {}
|
||||
kwquery[more_than_field+'__gte'] = request.query_params.get(more_than_field+'_min')
|
||||
return queryset.filter(**kwquery)
|
||||
else:
|
||||
return queryset
|
||||
|
||||
|
||||
class CategoryFilter(BaseFilterBackend):
|
||||
"""
|
||||
This filter assigns the view.category object for later use, in particular for filters that depend on this one.
|
||||
"""
|
||||
template = './filters/categories.html'
|
||||
category_field = 'category'
|
||||
|
||||
def get_category_class(self, view, request):
|
||||
return getattr(view, 'category_class', None)
|
||||
|
||||
def assign_view_category(self, request, view):
|
||||
if not hasattr(view, self.category_field):
|
||||
if self.category_field in request.query_params.keys():
|
||||
try:
|
||||
category_slug = request.query_params.get(self.category_field).strip()
|
||||
category = view.category_class.objects.get(slug=category_slug)
|
||||
# Append the category object in the view for later reference
|
||||
view.category = category
|
||||
except:
|
||||
view.category = None
|
||||
else:
|
||||
view.category = None
|
||||
|
||||
|
||||
def get_filters_dict(self, request, view):
|
||||
"""
|
||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. Queries the database for the current category and saves it in view.category for internal use and later filters.
|
||||
"""
|
||||
if hasattr(view, 'category_field'):
|
||||
self.category_field = self.get_category_field(view, request)
|
||||
|
||||
category_class = self.get_category_class(view, request)
|
||||
|
||||
assert category_class is not None, (
|
||||
f"{view.__class__.__name__} should include a `category_class` attribute"
|
||||
)
|
||||
|
||||
# Create the filters dictionary and find the present category instance
|
||||
self.assign_view_category(request, view)
|
||||
|
||||
filters_dict = {}
|
||||
if view.category:
|
||||
filters_dict[self.category_field] = [view.category.slug]
|
||||
else:
|
||||
filters_dict[self.category_field] = []
|
||||
|
||||
return filters_dict
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
self.assign_view_category(request, view)
|
||||
|
||||
# Create the queryset
|
||||
if view.category:
|
||||
kwquery_1 = {}
|
||||
kwquery_1[self.category_field+'__id'] = view.category.id
|
||||
if view.category.tn_descendants_pks:
|
||||
kwquery_2 = {}
|
||||
kwquery_2[self.category_field+'__id__in'] = view.category.tn_descendants_pks.split(',')
|
||||
|
||||
queryset = queryset.filter(Q(**kwquery_1) | Q(**kwquery_2))
|
||||
else:
|
||||
queryset = queryset.filter(**kwquery_1)
|
||||
|
||||
return queryset
|
||||
|
||||
# Developer Interface methods
|
||||
def get_valid_fields(self, queryset, view, context, request):
|
||||
# A query is executed here to get the possible categories
|
||||
category_class = self.get_category_class(view, request)
|
||||
if hasattr(view, 'category_field'):
|
||||
self.category_field = self.get_category_field(view, request)
|
||||
|
||||
assert category_class is not None, (
|
||||
f"{view.__class__.__name__} should include a `category_class` attribute"
|
||||
)
|
||||
|
||||
valid_fields = category_class.objects.all()
|
||||
|
||||
if len(valid_fields):
|
||||
valid_fields = [
|
||||
(item.slug, item.__str__()) for item in valid_fields
|
||||
]
|
||||
|
||||
return valid_fields
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_current(self, request, queryset, view):
|
||||
params = request.query_params.get(self.category_field)
|
||||
if params:
|
||||
fields = [param.strip() for param in params.split(',')]
|
||||
return fields[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_template_context(self, request, queryset, view):
|
||||
current = self.get_current(request, queryset, view)
|
||||
options = []
|
||||
context = {
|
||||
'request': request,
|
||||
'current': current,
|
||||
'param': self.category_field,
|
||||
}
|
||||
for key, label in self.get_valid_fields(queryset, view, context, request):
|
||||
options.append((key, '%s' % (label)))
|
||||
context['options'] = options
|
||||
return context
|
||||
|
||||
def to_html(self, request, queryset, view):
|
||||
template = loader.get_template(self.template)
|
||||
context = self.get_template_context(request, queryset, view)
|
||||
return template.render(context)
|
||||
|
||||
|
||||
class FacetFilter(BaseFilterBackend):
|
||||
"""
|
||||
This filter requires CategoryFilter to be ran before it. It assigns the view.facets which includes all the facets applicable to the current category.
|
||||
"""
|
||||
template = './filters/facets.html'
|
||||
|
||||
def get_facet_class(self, view, request):
|
||||
return getattr(view, 'facet_class', None)
|
||||
|
||||
def get_facet_tag_class(self, view, request):
|
||||
return getattr(view, 'facet_tag_class', None)
|
||||
|
||||
def get_facet_tag_field(self, view, request):
|
||||
return getattr(view, 'facet_tag_field', None)
|
||||
|
||||
def assign_view_facets(self, request, view):
|
||||
if not hasattr(view, 'facets'):
|
||||
if view.category:
|
||||
if view.category.tn_ancestors_pks:
|
||||
view.facets = Facet.objects.filter(Q(category__id=view.category.id) | Q(category__id__in=view.category.tn_ancestors_pks.split(','))).prefetch_related('facet_tags')
|
||||
else:
|
||||
view.facets = Facet.objects.filter(category__id=view.category.id).prefetch_related('facet_tags')
|
||||
else:
|
||||
view.facets = Facet.objects.filter(category__tn_level=1).prefetch_related('facet_tags')
|
||||
|
||||
def get_filters_dict(self, request, view):
|
||||
"""
|
||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
||||
"""
|
||||
if hasattr(view, 'facet_class'):
|
||||
self.facet_class = self.get_facet_class(view, request)
|
||||
|
||||
assert self.facet_class is not None, (
|
||||
f"{view.__class__.__name__} should include a `facet_class` attribute"
|
||||
)
|
||||
|
||||
self.assign_view_facets(request, view)
|
||||
|
||||
filters_dict = {}
|
||||
if view.facets:
|
||||
for facet in view.facets:
|
||||
if facet.slug in request.query_params.keys():
|
||||
filters_dict[facet.slug] = set(request.query_params[facet.slug].split(','))
|
||||
else:
|
||||
filters_dict[facet.slug] = set({})
|
||||
|
||||
# Append the facets object and the tags dict in the view for later reference
|
||||
view.tags = filters_dict
|
||||
|
||||
return filters_dict
|
||||
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
if hasattr(view, 'facet_tag_class'):
|
||||
self.facet_tag_class = self.get_facet_tag_class(view, request)
|
||||
|
||||
assert self.facet_tag_class is not None, (
|
||||
f"{view.__class__.__name__} should include a `facet_tag_class` attribute"
|
||||
)
|
||||
|
||||
if hasattr(view, 'facet_tag_field'):
|
||||
self.facet_tag_field = self.get_facet_tag_field(view, request)
|
||||
|
||||
assert self.facet_tag_field is not None, (
|
||||
f"{view.__class__.__name__} should include a `facet_tag_field` attribute"
|
||||
)
|
||||
|
||||
self.assign_view_facets(request, view)
|
||||
|
||||
if view.facets:
|
||||
for facet in view.facets:
|
||||
if facet.slug in request.query_params.keys() and request.query_params[facet.slug]:
|
||||
tag_filterlist = request.query_params.get(facet.slug)
|
||||
if tag_filterlist == '':
|
||||
# If the tag filterlist is empty then we're not filtering against it, it's like having all the tags of the facet selected
|
||||
pass
|
||||
else:
|
||||
kwquery = {}
|
||||
kwquery[self.facet_tag_field+'__slug__in'] = tag_filterlist.replace(' ','').split(',')
|
||||
queryset = queryset.filter(**kwquery)
|
||||
|
||||
return queryset
|
||||
|
||||
|
||||
# Developer Interface methods
|
||||
def get_template_context(self, request, queryset, view):
|
||||
# Does aggressive database querying to get the necessary facets and facettags, but this is only for the developer interface so its fine
|
||||
if hasattr(view, 'facet_class'):
|
||||
self.facet_class = self.get_facet_class(view, request)
|
||||
|
||||
assert self.facet_class is not None, (
|
||||
f"{view.__class__.__name__} should include a `facet_class` attribute"
|
||||
)
|
||||
|
||||
if hasattr(view, 'facet_tag_class'):
|
||||
self.facet_tag_class = self.get_facet_tag_class(view, request)
|
||||
|
||||
assert self.facet_tag_class is not None, (
|
||||
f"{view.__class__.__name__} should include a `facet_tag_class` attribute"
|
||||
)
|
||||
|
||||
# Find the current choices
|
||||
current = []
|
||||
facet_slugs = []
|
||||
if view.facets:
|
||||
for facet in view.facets:
|
||||
facet_slugs.append(facet.slug)
|
||||
if facet.slug in request.query_params.keys():
|
||||
current.append(request.query_params.get(facet.slug))
|
||||
|
||||
facet_slug_names = {}
|
||||
options = {}
|
||||
context = {
|
||||
'request': request,
|
||||
'current': current,
|
||||
'facet_slugs': facet_slugs,
|
||||
}
|
||||
if view.facets:
|
||||
for facet in view.facets:
|
||||
facet_tag_instances = self.facet_tag_class.objects.filter(facet__slug=facet.slug)
|
||||
options[facet.slug] = [(facet_tag.slug, facet_tag.name) for facet_tag in facet_tag_instances]
|
||||
facet_slug_names[facet.slug] = facet.name
|
||||
context['facet_slug_names'] = facet_slug_names
|
||||
context['options'] = options
|
||||
else:
|
||||
context['facet_slug_names'] = {}
|
||||
context['options'] = {}
|
||||
|
||||
return context
|
||||
|
||||
def to_html(self, request, queryset, view):
|
||||
template = loader.get_template(self.template)
|
||||
context = self.get_template_context(request, queryset, view)
|
||||
return template.render(context)
|
||||
|
||||
|
||||
class TrigramSearchFilter(BaseFilterBackend):
|
||||
# The URL query parameter used for the search.
|
||||
search_param = 'search'
|
||||
template = 'rest_framework/filters/search.html'
|
||||
search_title = _('Search')
|
||||
search_description = _('A search string to perform trigram similarity based searching with.')
|
||||
def get_filters_dict(self, request, view):
|
||||
"""
|
||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
||||
"""
|
||||
self.filters_dict = {}
|
||||
if 'search' in request.query_params.keys():
|
||||
slug_term = request.query_params.get('search')
|
||||
self.filters_dict['search'] = [slug_term]
|
||||
else:
|
||||
self.filters_dict['search'] = []
|
||||
|
||||
return self.filters_dict
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
search_fields = getattr(view, 'search_fields', None)
|
||||
|
||||
assert search_fields is not None, (
|
||||
f"{view.__class__.__name__} should include a `search_fields` attribute"
|
||||
)
|
||||
|
||||
query = request.query_params.get(self.search_param, '')
|
||||
|
||||
if query:
|
||||
queryset = queryset.annotate(
|
||||
search_field=Concat(
|
||||
*search_fields,
|
||||
output_field=CharField()
|
||||
)).annotate(
|
||||
similarity=TrigramSimilarity('search_field', query)
|
||||
).filter(similarity__gt=0.05).distinct()
|
||||
|
||||
return queryset
|
||||
|
||||
def to_html(self, request, queryset, view):
|
||||
if not getattr(view, 'search_fields', None):
|
||||
return ''
|
||||
|
||||
term = request.query_params.get(self.search_param, '')
|
||||
term = term[0] if term else ''
|
||||
context = {
|
||||
'param': self.search_param,
|
||||
'term': term
|
||||
}
|
||||
template = loader.get_template(self.template)
|
||||
return template.render(context)
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
return [
|
||||
coreapi.Field(
|
||||
name=self.search_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.String(
|
||||
title=force_str(self.search_title),
|
||||
description=force_str(self.search_description)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return [
|
||||
{
|
||||
'name': self.search_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_str(self.search_description),
|
||||
'schema': {
|
||||
'type': 'string',
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# TODO
|
||||
#class FieldFilter(BaseFilterBackend):
|
||||
|
||||
|
||||
# TODO misunderstood the urlconf stuff of the RUD methods, this is probably unnecessary
|
||||
class SlugSearchFilter(BaseFilterBackend):
|
||||
# The URL query parameter used for the search.
|
||||
template = './filters/slug.html'
|
||||
slug_title = _('Slug Search')
|
||||
slug_description = _("The instance's slug.")
|
||||
slug_field = 'slug'
|
||||
|
||||
def get_slug_field(self, view, request):
|
||||
return getattr(view, 'slug_field', None)
|
||||
|
||||
def get_filters_dict(self, request, view):
|
||||
"""
|
||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
||||
"""
|
||||
if hasattr(view, 'slug_field'):
|
||||
self.slug_field = self.get_slug_field(view, request)
|
||||
|
||||
assert self.slug_field is not None, (
|
||||
f"{view.__class__.__name__} should include a `slug_field` attribute"
|
||||
)
|
||||
self.filters_dict = {}
|
||||
if self.slug_field in request.query_params.keys():
|
||||
slug_term = request.query_params.get(self.slug_field)
|
||||
self.filters_dict[self.slug_field] = [slug_term]
|
||||
else:
|
||||
self.filters_dict[self.slug_field] = []
|
||||
|
||||
return self.filters_dict
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
# Ensure that the slug field was searched against
|
||||
try:
|
||||
if self.slug_field in request.query_params.keys():
|
||||
slug_term = request.query_params.get(self.slug_field)
|
||||
query = {}
|
||||
query[self.slug_field] = slug_term
|
||||
|
||||
queryset = queryset.get(**query)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
return queryset
|
||||
|
||||
|
||||
def to_html(self, request, queryset, view):
|
||||
if not getattr(view, 'slug_field', None):
|
||||
return ''
|
||||
|
||||
slug_term = self.get_slug_term(request)
|
||||
context = {
|
||||
'param': self.slug_field,
|
||||
'term': slug_term
|
||||
}
|
||||
template = loader.get_template(self.template)
|
||||
return template.render(context)
|
||||
|
||||
|
||||
class SearchFilter(BaseFilterBackend):
|
||||
# The URL query parameter used for the search.
|
||||
search_param = api_settings.SEARCH_PARAM
|
||||
template = 'rest_framework/filters/search.html'
|
||||
lookup_prefixes = {
|
||||
'^': 'istartswith',
|
||||
'=': 'iexact',
|
||||
'@': 'search',
|
||||
'$': 'iregex',
|
||||
}
|
||||
search_title = _('Search')
|
||||
search_description = _('A search term.')
|
||||
# TODO to be removed
|
||||
def get_search_terms(self, request):
|
||||
"""
|
||||
Search terms are set by a ?search=... query parameter,
|
||||
and may be comma and/or whitespace delimited.
|
||||
"""
|
||||
params = request.query_params.get(self.search_param, '')
|
||||
params = params.replace('\x00', '') # strip null characters
|
||||
params = params.replace(',', ' ')
|
||||
return params.split()
|
||||
# TODO to be removed
|
||||
def construct_search(self, field_name):
|
||||
lookup = self.lookup_prefixes.get(field_name[0])
|
||||
if lookup:
|
||||
field_name = field_name[1:]
|
||||
else:
|
||||
lookup = 'icontains'
|
||||
return LOOKUP_SEP.join([field_name, lookup])
|
||||
# TODO to be removed
|
||||
def must_call_distinct(self, queryset, search_fields):
|
||||
"""
|
||||
Return True if 'distinct()' should be used to query the given lookups.
|
||||
"""
|
||||
for search_field in search_fields:
|
||||
opts = queryset.model._meta
|
||||
if search_field[0] in self.lookup_prefixes:
|
||||
search_field = search_field[1:]
|
||||
# Annotated fields do not need to be distinct
|
||||
if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations:
|
||||
continue
|
||||
parts = search_field.split(LOOKUP_SEP)
|
||||
for part in parts:
|
||||
field = opts.get_field(part)
|
||||
if hasattr(field, 'get_path_info'):
|
||||
# This field is a relation, update opts to follow the relation
|
||||
path_info = field.get_path_info()
|
||||
opts = path_info[-1].to_opts
|
||||
if any(path.m2m for path in path_info):
|
||||
# This field is a m2m relation so we know we need to call distinct
|
||||
return True
|
||||
else:
|
||||
# This field has a custom __ query transform but is not a relational field.
|
||||
break
|
||||
return False
|
||||
|
||||
def get_filters_dict(self, request, view):
|
||||
"""
|
||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
||||
"""
|
||||
self.filters_dict = {}
|
||||
if 'search' in request.query_params.keys():
|
||||
slug_term = request.query_params.get('search')
|
||||
self.filters_dict['search'] = [slug_term]
|
||||
else:
|
||||
self.filters_dict['search'] = []
|
||||
|
||||
return self.filters_dict
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
search_fields = getattr(view, 'search_fields', None)
|
||||
search_terms = self.get_search_terms(request)
|
||||
|
||||
if not search_fields or not search_terms:
|
||||
return queryset
|
||||
|
||||
orm_lookups = [
|
||||
self.construct_search(str(search_field))
|
||||
for search_field in search_fields
|
||||
]
|
||||
|
||||
base = queryset
|
||||
conditions = []
|
||||
for search_term in search_terms:
|
||||
queries = [
|
||||
models.Q(**{orm_lookup: search_term})
|
||||
for orm_lookup in orm_lookups
|
||||
]
|
||||
conditions.append(reduce(operator.or_, queries))
|
||||
queryset = queryset.filter(reduce(operator.and_, conditions))
|
||||
|
||||
if self.must_call_distinct(queryset, search_fields):
|
||||
# Filtering against a many-to-many field requires us to
|
||||
# call queryset.distinct() in order to avoid duplicate items
|
||||
# in the resulting queryset.
|
||||
# We try to avoid this if possible, for performance reasons.
|
||||
queryset = distinct(queryset, base)
|
||||
return queryset
|
||||
|
||||
def to_html(self, request, queryset, view):
|
||||
if not getattr(view, 'search_fields', None):
|
||||
return ''
|
||||
|
||||
term = self.get_search_terms(request)
|
||||
term = term[0] if term else ''
|
||||
context = {
|
||||
'param': self.search_param,
|
||||
'term': term
|
||||
}
|
||||
template = loader.get_template(self.template)
|
||||
return template.render(context)
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
return [
|
||||
coreapi.Field(
|
||||
name=self.search_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.String(
|
||||
title=force_str(self.search_title),
|
||||
description=force_str(self.search_description)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return [
|
||||
{
|
||||
'name': self.search_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_str(self.search_description),
|
||||
'schema': {
|
||||
'type': 'string',
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class OrderingFilter(BaseFilterBackend):
|
||||
# The URL query parameter used for the ordering.
|
||||
ordering_param = api_settings.ORDERING_PARAM
|
||||
ordering_fields = None
|
||||
ordering_title = _('Ordering')
|
||||
ordering_description = _('Which field to use when ordering the results.')
|
||||
template = 'rest_framework/filters/ordering.html'
|
||||
|
||||
def get_ordering(self, request, queryset, view):
|
||||
"""
|
||||
Ordering is set by a comma delimited ?ordering=... query parameter.
|
||||
|
||||
The `ordering` query parameter can be overridden by setting
|
||||
the `ordering_param` value on the OrderingFilter or by
|
||||
specifying an `ORDERING_PARAM` value in the API settings.
|
||||
"""
|
||||
params = request.query_params.get(self.ordering_param)
|
||||
if params:
|
||||
fields = [param.strip() for param in params.split(',')]
|
||||
ordering = self.remove_invalid_fields(queryset, fields, view, request)
|
||||
if ordering:
|
||||
return ordering
|
||||
|
||||
# No ordering was included, or all the ordering fields were invalid
|
||||
return self.get_default_ordering(view)
|
||||
|
||||
def get_default_ordering(self, view):
|
||||
ordering = getattr(view, 'ordering', None)
|
||||
if isinstance(ordering, str):
|
||||
return (ordering,)
|
||||
return ordering
|
||||
|
||||
def get_default_valid_fields(self, queryset, view, context={}):
|
||||
# If `ordering_fields` is not specified, then we determine a default
|
||||
# based on the serializer class, if one exists on the view.
|
||||
if hasattr(view, 'get_serializer_class'):
|
||||
try:
|
||||
serializer_class = view.get_serializer_class()
|
||||
except AssertionError:
|
||||
# Raised by the default implementation if
|
||||
# no serializer_class was found
|
||||
serializer_class = None
|
||||
else:
|
||||
serializer_class = getattr(view, 'serializer_class', None)
|
||||
|
||||
if serializer_class is None:
|
||||
msg = (
|
||||
"Cannot use %s on a view which does not have either a "
|
||||
"'serializer_class', an overriding 'get_serializer_class' "
|
||||
"or 'ordering_fields' attribute."
|
||||
)
|
||||
raise ImproperlyConfigured(msg % self.__class__.__name__)
|
||||
|
||||
model_class = queryset.model
|
||||
model_property_names = [
|
||||
# 'pk' is a property added in Django's Model class, however it is valid for ordering.
|
||||
attr for attr in dir(model_class) if isinstance(getattr(model_class, attr), property) and attr != 'pk'
|
||||
]
|
||||
|
||||
return [
|
||||
(field.source.replace('.', '__') or field_name, field.label)
|
||||
for field_name, field in serializer_class(context=context).fields.items()
|
||||
if (
|
||||
not getattr(field, 'write_only', False) and
|
||||
not field.source == '*' and
|
||||
field.source not in model_property_names
|
||||
)
|
||||
]
|
||||
|
||||
def get_valid_fields(self, queryset, view, context={}):
|
||||
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
|
||||
|
||||
if valid_fields is None:
|
||||
# Default to allowing filtering on serializer fields
|
||||
return self.get_default_valid_fields(queryset, view, context)
|
||||
|
||||
elif valid_fields == '__all__':
|
||||
# View explicitly allows filtering on any model field
|
||||
valid_fields = [
|
||||
(field.name, field.verbose_name) for field in queryset.model._meta.fields
|
||||
]
|
||||
valid_fields += [
|
||||
(key, key.title().split('__'))
|
||||
for key in queryset.query.annotations
|
||||
]
|
||||
else:
|
||||
valid_fields = [
|
||||
(item, item) if isinstance(item, str) else item
|
||||
for item in valid_fields
|
||||
]
|
||||
|
||||
return valid_fields
|
||||
|
||||
def remove_invalid_fields(self, queryset, fields, view, request):
|
||||
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]
|
||||
|
||||
def term_valid(term):
|
||||
if term.startswith("-"):
|
||||
term = term[1:]
|
||||
return term in valid_fields
|
||||
|
||||
return [term for term in fields if term_valid(term)]
|
||||
|
||||
def get_filters_dict(self, request, view):
|
||||
"""
|
||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
||||
"""
|
||||
self.filters_dict = {}
|
||||
if 'ordering' in request.query_params.keys():
|
||||
slug_term = request.query_params.get('ordering')
|
||||
self.filters_dict['ordering'] = [slug_term]
|
||||
else:
|
||||
self.filters_dict['ordering'] = [view.ordering_fields[0]]
|
||||
|
||||
return self.filters_dict
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
ordering = self.get_ordering(request, queryset, view)
|
||||
|
||||
if ordering:
|
||||
return queryset.order_by(*ordering)
|
||||
|
||||
return queryset
|
||||
|
||||
def get_template_context(self, request, queryset, view):
|
||||
current = self.get_ordering(request, queryset, view)
|
||||
current = None if not current else current[0]
|
||||
options = []
|
||||
context = {
|
||||
'request': request,
|
||||
'current': current,
|
||||
'param': self.ordering_param,
|
||||
}
|
||||
for key, label in self.get_valid_fields(queryset, view, context):
|
||||
options.append((key, '%s - %s' % (label, _('ascending'))))
|
||||
options.append(('-' + key, '%s - %s' % (label, _('descending'))))
|
||||
context['options'] = options
|
||||
return context
|
||||
|
||||
def to_html(self, request, queryset, view):
|
||||
template = loader.get_template(self.template)
|
||||
context = self.get_template_context(request, queryset, view)
|
||||
return template.render(context)
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||
return [
|
||||
coreapi.Field(
|
||||
name=self.ordering_param,
|
||||
required=False,
|
||||
location='query',
|
||||
schema=coreschema.String(
|
||||
title=force_str(self.ordering_title),
|
||||
description=force_str(self.ordering_description)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
def get_schema_operation_parameters(self, view):
|
||||
return [
|
||||
{
|
||||
'name': self.ordering_param,
|
||||
'required': False,
|
||||
'in': 'query',
|
||||
'description': force_str(self.ordering_description),
|
||||
'schema': {
|
||||
'type': 'string',
|
||||
},
|
||||
},
|
||||
]
|
||||
214
starfields_drf_generics/generics.py
Normal file
214
starfields_drf_generics/generics.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
Generic views that provide commonly needed behaviour.
|
||||
"""
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db.models.query import QuerySet
|
||||
from django.http import Http404
|
||||
from django.shortcuts import get_object_or_404 as _get_object_or_404
|
||||
|
||||
from rest_framework import views
|
||||
from rest_framework.generics import GenericAPIView
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
from libraries import mixins
|
||||
|
||||
|
||||
# Concrete view classes that provide method handlers
|
||||
# by composing the mixin classes with the base view.
|
||||
|
||||
# Single item CRUD
|
||||
|
||||
class CachedCreateAPIView(mixins.CachedCreateModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for creating a model instance.
|
||||
"""
|
||||
def post(self, request, *args, **kwargs):
|
||||
return self.create(request, *args, **kwargs)
|
||||
|
||||
|
||||
class CachedRetrieveAPIView(mixins.CachedRetrieveModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for retrieving a model instance.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.retrieve(request, *args, **kwargs)
|
||||
|
||||
|
||||
class CachedUpdateAPIView(mixins.CachedUpdateModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for updating a model instance.
|
||||
"""
|
||||
def put(self, request, *args, **kwargs):
|
||||
return self.update(request, *args, **kwargs)
|
||||
|
||||
def patch(self, request, *args, **kwargs):
|
||||
return self.partial_update(request, *args, **kwargs)
|
||||
|
||||
|
||||
class CachedDestroyAPIView(mixins.CachedDestroyModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for deleting a model instance.
|
||||
"""
|
||||
def delete(self, request, *args, **kwargs):
|
||||
return self.destroy(request, *args, **kwargs)
|
||||
|
||||
|
||||
class CachedRetrieveUpdateAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedUpdateModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for retrieving, updating a model instance.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.retrieve(request, *args, **kwargs)
|
||||
|
||||
def put(self, request, *args, **kwargs):
|
||||
return self.update(request, *args, **kwargs)
|
||||
|
||||
def patch(self, request, *args, **kwargs):
|
||||
return self.partial_update(request, *args, **kwargs)
|
||||
|
||||
|
||||
class CachedRetrieveDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedDestroyModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for retrieving or deleting a model instance.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.retrieve(request, *args, **kwargs)
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
return self.destroy(request, *args, **kwargs)
|
||||
|
||||
|
||||
class CachedRetrieveUpdateDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedUpdateModelMixin,mixins.CachedDestroyModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for retrieving, updating or deleting a model instance.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.retrieve(request, *args, **kwargs)
|
||||
|
||||
def put(self, request, *args, **kwargs):
|
||||
return self.update(request, *args, **kwargs)
|
||||
|
||||
def patch(self, request, *args, **kwargs):
|
||||
return self.partial_update(request, *args, **kwargs)
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
return self.destroy(request, *args, **kwargs)
|
||||
|
||||
|
||||
class CachedCreateRetrieveUpdateDestroyAPIView(mixins.CachedCreateModelMixin,mixins.CachedRetrieveModelMixin,mixins.CachedUpdateModelMixin,mixins.CachedDestroyModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for creating, retrieving, updating or deleting a model instance.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.retrieve(request, *args, **kwargs)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
return self.create(request, *args, **kwargs)
|
||||
|
||||
def put(self, request, *args, **kwargs):
|
||||
return self.update(request, *args, **kwargs)
|
||||
|
||||
def patch(self, request, *args, **kwargs):
|
||||
return self.partial_update(request, *args, **kwargs)
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
return self.destroy(request, *args, **kwargs)
|
||||
|
||||
|
||||
# List based CRUD
|
||||
|
||||
class CachedListRetrieveAPIView(mixins.CachedListRetrieveModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for listing a queryset.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.list(request, *args, **kwargs)
|
||||
|
||||
|
||||
class CachedListCreateAPIView(mixins.CachedListCreateModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for creating multiple instances.
|
||||
"""
|
||||
def post(self, request, *args, **kwargs):
|
||||
return self.list_create(request, *args, **kwargs)
|
||||
|
||||
|
||||
class CachedListUpdateAPIView(mixins.CachedListUpdateModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for updating multiple instances.
|
||||
"""
|
||||
def put(self, request, *args, **kwargs):
|
||||
return self.list_update(request, *args, **kwargs)
|
||||
|
||||
def patch(self, request, *args, **kwargs):
|
||||
return self.list_partial_update(request, *args, **kwargs)
|
||||
|
||||
|
||||
class CachedListDestroyAPIView(mixins.CachedListDestroyModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for deleting multiple instances.
|
||||
"""
|
||||
def delete(self, request, *args, **kwargs):
|
||||
return self.list_destroy(request, *args, **kwargs)
|
||||
|
||||
|
||||
class CachedListRetrieveCreateAPIView(mixins.CachedListRetrieveModelMixin,mixins.CachedListCreateModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for listing a queryset or creating a model instance.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.list(request, *args, **kwargs)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
return self.create(request, *args, **kwargs)
|
||||
|
||||
|
||||
class CachedListCreateRetrieveDestroyAPIView(mixins.CachedListCreateModelMixin,mixins.CachedListRetrieveModelMixin,mixins.CachedListDestroyModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for creating, retrieving or deleting a model instance.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.list(request, *args, **kwargs)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
return self.list_create(request, *args, **kwargs)
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
return self.list_destroy(request, *args, **kwargs)
|
||||
|
||||
|
||||
class CachedListCreateRetrieveUpdateAPIView(mixins.CachedListCreateModelMixin,mixins.CachedListRetrieveModelMixin,mixins.CachedListUpdateModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for creating, retrieving, updating or deleting a model instance.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.list(request, *args, **kwargs)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
return self.list_create(request, *args, **kwargs)
|
||||
|
||||
def put(self, request, *args, **kwargs):
|
||||
return self.list_update(request, *args, **kwargs)
|
||||
|
||||
def patch(self, request, *args, **kwargs):
|
||||
return self.list_partial_update(request, *args, **kwargs)
|
||||
|
||||
|
||||
class CachedListCreateRetrieveUpdateDestroyAPIView(mixins.CachedListCreateModelMixin,mixins.CachedListRetrieveModelMixin,mixins.CachedListUpdateModelMixin,mixins.CachedListDestroyModelMixin,GenericAPIView):
|
||||
"""
|
||||
Concrete view for creating, retrieving, updating or deleting a model instance.
|
||||
"""
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.list(request, *args, **kwargs)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
return self.list_create(request, *args, **kwargs)
|
||||
|
||||
def put(self, request, *args, **kwargs):
|
||||
return self.list_update(request, *args, **kwargs)
|
||||
|
||||
def patch(self, request, *args, **kwargs):
|
||||
return self.list_partial_update(request, *args, **kwargs)
|
||||
|
||||
def delete(self, request, *args, **kwargs):
|
||||
return self.list_destroy(request, *args, **kwargs)
|
||||
211
starfields_drf_generics/mixins.py
Normal file
211
starfields_drf_generics/mixins.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Basic building blocks for generic class based views.
|
||||
|
||||
We don't bind behaviour to http method handlers yet,
|
||||
which allows mixin classes to be composed in interesting ways.
|
||||
"""
|
||||
from rest_framework import status
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework import mixins
|
||||
from libraries.cache_mixins import CacheGetMixin, CacheSetMixin, CacheDeleteMixin
|
||||
|
||||
|
||||
# Mixin classes to be included in the generic classes
|
||||
class CachedCreateModelMixin(CacheDeleteMixin, mixins.CreateModelMixin):
|
||||
"""
|
||||
A slightly modified version of rest_framework.mixins.CreateModelMixin that handles cache deletions.
|
||||
"""
|
||||
def create(self, request, *args, **kwargs):
|
||||
# Go on with the creation as normal
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_create(serializer)
|
||||
headers = self.get_success_headers(serializer.data)
|
||||
|
||||
# Delete the cache
|
||||
self.delete_cache(request)
|
||||
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
|
||||
|
||||
|
||||
class CachedRetrieveModelMixin(CacheGetMixin, CacheSetMixin):
|
||||
"""
|
||||
A slightly modified version of rest_framework.mixins.RetrieveModelMixin that handles cache attempts.
|
||||
mixins.RetrieveModelMixin only has the retrieve method so it doesn't stand to inherit anything from it.
|
||||
"""
|
||||
def retrieve(self, request, *args, **kwargs):
|
||||
# Attempt to get the request from the cache
|
||||
cache_attempt = self.get_cache(request)
|
||||
|
||||
if cache_attempt:
|
||||
return Response(cache_attempt)
|
||||
else:
|
||||
# The cache get attempt failed so we have to get the results from the database
|
||||
instance = self.get_object()
|
||||
|
||||
serializer = self.get_serializer(instance)
|
||||
response = Response(serializer.data)
|
||||
|
||||
self.set_cache(request, response)
|
||||
return response
|
||||
|
||||
|
||||
class CachedUpdateModelMixin(CacheDeleteMixin, mixins.UpdateModelMixin):
|
||||
"""
|
||||
A slightly modified version of rest_framework.mixins.UpdateModelMixin that handles cache deletes.
|
||||
"""
|
||||
def update(self, request, *args, **kwargs):
|
||||
partial = kwargs.pop('partial', False)
|
||||
instance = self.get_object()
|
||||
serializer = self.get_serializer(instance, data=request.data, partial=partial)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_update(serializer)
|
||||
|
||||
if getattr(instance, '_prefetched_objects_cache', None):
|
||||
# If 'prefetch_related' has been applied to a queryset, we need to
|
||||
# forcibly invalidate the prefetch cache on the instance.
|
||||
instance._prefetched_objects_cache = {}
|
||||
|
||||
# Delete the related caches
|
||||
self.delete_cache(request)
|
||||
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
class CachedDestroyModelMixin(CacheDeleteMixin, mixins.DestroyModelMixin):
|
||||
"""
|
||||
A slightly modified version of rest_framework.mixins.DestroyModelMixin that handles cache deletes.
|
||||
"""
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
instance = self.get_object()
|
||||
self.perform_destroy(instance)
|
||||
|
||||
# Delete the related caches
|
||||
self.delete_cache(request)
|
||||
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
# List mixin classes to be included with list generic classes
|
||||
class CachedListCreateModelMixin(CacheDeleteMixin):
|
||||
"""
|
||||
A fully custom mixin that handles mutiple instance cration.
|
||||
"""
|
||||
def list_create(self, request, *args, **kwargs):
|
||||
# Go on with the creation as normal
|
||||
serializer = self.get_serializer(data=request.data, many=True)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_create(serializer)
|
||||
headers = self.get_success_headers(serializer.data)
|
||||
|
||||
# Delete the cache
|
||||
self.delete_cache(request)
|
||||
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
|
||||
|
||||
def perform_create(self, serializer):
|
||||
serializer.save()
|
||||
|
||||
def get_success_headers(self, data):
|
||||
try:
|
||||
return {'Location': str(data[api_settings.URL_FIELD_NAME])}
|
||||
except (TypeError, KeyError):
|
||||
return {}
|
||||
|
||||
|
||||
class CachedListRetrieveModelMixin(CacheGetMixin, CacheSetMixin):
|
||||
"""
|
||||
A slightly modified version of rest_framework.mixins.ListModelMixin that handles cache saves.
|
||||
mixins.ListModelMixin only has the list method so it doesn't stand to inherit anything from it.
|
||||
"""
|
||||
def list(self, request, *args, **kwargs):
|
||||
# Attempt to get the request from the cache
|
||||
cache_attempt = self.get_cache(request)
|
||||
|
||||
if cache_attempt:
|
||||
return Response(cache_attempt)
|
||||
else:
|
||||
# The cache get attempt failed so we have to get the results from the database
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
|
||||
if self.paged:
|
||||
page = self.paginate_queryset(queryset)
|
||||
if page is not None:
|
||||
serializer = self.get_serializer(page, many=True)
|
||||
response = self.get_paginated_response(serializer.data)
|
||||
else:
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
response = Response(serializer.data)
|
||||
else:
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
response = Response(serializer.data)
|
||||
|
||||
self.set_cache(request, response)
|
||||
return response
|
||||
|
||||
|
||||
class CachedListUpdateModelMixin(CacheDeleteMixin):
|
||||
"""
|
||||
A fully custom mixin that handles mutiple instance updates.
|
||||
"""
|
||||
def list_update(self, request, *args, **kwargs):
|
||||
partial = kwargs.pop('partial', False)
|
||||
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
|
||||
serializer = self.get_serializer(queryset, data=request.data, partial=partial, many=True)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_update(serializer)
|
||||
|
||||
# Delete the related caches
|
||||
self.delete_cache(request)
|
||||
|
||||
return Response(serializer.data)
|
||||
|
||||
def perform_update(self, serializer):
|
||||
serializer.save()
|
||||
|
||||
def list_partial_update(self, request, *args, **kwargs):
|
||||
kwargs['partial'] = True
|
||||
return self.list_update(request, *args, **kwargs)
|
||||
|
||||
|
||||
class CachedListDestroyModelMixin(CacheDeleteMixin):
|
||||
"""
|
||||
A fully custom mixin that handles mutiple instance deletions.
|
||||
"""
|
||||
def list_destroy(self, request, *args, **kwargs):
|
||||
# Go on with the validation as normal
|
||||
serializer = self.get_serializer(data=request.data, many=True)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
validated_data = serializer.validated_data
|
||||
|
||||
# TODO does this new stuff work even? need to check on the frontend
|
||||
serializer.delete(validated_data)
|
||||
|
||||
# for instance in self.get_objects():
|
||||
# if instance is not None:
|
||||
# self.perform_destroy(instance)
|
||||
|
||||
# Delete the related caches
|
||||
self.delete_cache(request)
|
||||
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
#def perform_destroy(self, instance):
|
||||
# instance.delete()
|
||||
|
||||
#def get_objects(self):
|
||||
# """
|
||||
# The custom list version of get_object that retrieves one instance from the #database. It yields model instances with each call.
|
||||
# """
|
||||
# queryset = self.filter_queryset(self.get_queryset())
|
||||
#
|
||||
# if len(queryset):
|
||||
# for obj in queryset.all():
|
||||
|
||||
# # May raise a permission denied
|
||||
# self.check_object_permissions(self.request, obj)
|
||||
|
||||
# yield obj
|
||||
|
||||
#yield None
|
||||
24
starfields_drf_generics/utils.py
Normal file
24
starfields_drf_generics/utils.py
Normal file
@@ -0,0 +1,24 @@
|
||||
def sorted_params_string(filters_dict):
|
||||
"""
|
||||
This function takes a dict and returns it in a sorted form for the url filter, it's primarily used for cache purposes.
|
||||
"""
|
||||
filters_string = ''
|
||||
for key in sorted(filters_dict.keys()):
|
||||
if filters_string == '':
|
||||
filters_string = f"{key}={','.join(str(val) for val in sorted(filters_dict[key]))}"
|
||||
else:
|
||||
filters_string = f"{filters_string}&{key}={','.join(str(val) for val in sorted(filters_dict[key]))}"
|
||||
filters_string = filters_string.strip()
|
||||
return filters_string
|
||||
|
||||
|
||||
def parse_tags_to_dict(tags):
|
||||
tagdict = {}
|
||||
if ':' not in tags:
|
||||
tagdict = {}
|
||||
else:
|
||||
for subtag in tags.split('&'):
|
||||
tagkey, taglist = subtag.split(':')
|
||||
taglist = taglist.split(',')
|
||||
tagdict[tagkey] = set(taglist)
|
||||
return tagdict
|
||||
Reference in New Issue
Block a user