Performed some general linting. More is needed along with cleanup
Some checks failed
Lint / Lint (push) Failing after 22s

This commit is contained in:
2024-05-13 12:53:57 +03:00
parent 18812423c5
commit 4e2b0ec0c6
5 changed files with 394 additions and 262 deletions

View File

@@ -1,36 +1,45 @@
from starfields_drf_generics.utils import sorted_params_string
# TODO classes below that involve create, update, destroy don't delete the caches properly, they need a regex cache delete
# TODO classes below that involve create, update, destroy don't delete the
# caches properly, they need a regex cache delete
# TODO put more reasonable asserts and feedback
# Mixin classes that provide cache functionalities
class CacheUniqueUrl:
def get_cache_unique_url(self, request):
""" Create the query to be cached in a unique way to avoid duplicates. """
"""
Create the query to be cached in a unique way to avoid duplicates.
"""
if not hasattr(self, 'filters_string'):
# Only assign the attribute if it's not already assigned
filters = {}
if self.extra_filters_dict:
filters.update(self.extra_filters_dict)
# Check if the url parameters have any of the keys of the extra filters and if so assign them
# Check if the url parameters have any of the keys of the extra
# filters and if so assign them
for key in self.extra_filters_dict:
if key in self.request.query_params.keys():
filters[key] = self.request.query_params[key].replace(' ','').split(',')
filters[key] = self.request.query_params[key].replace(
' ', '').split(',')
# Check if they're resolved in the urlconf as well
if key in self.kwargs.keys():
filters[key] = [self.kwargs[key]]
if hasattr(self, 'paged'):
if self.paged:
filters.update({'limit': [self.default_page_size], 'offset': [0]})
filters.update({'limit': [self.default_page_size],
'offset': [0]})
if 'limit' in self.request.query_params.keys():
filters.update({'limit': [self.request.query_params['limit']]})
filters.update({
'limit': [self.request.query_params['limit']]})
if 'offset' in self.request.query_params.keys():
filters.update({'offset': [self.request.query_params['offset']]})
filters.update({
'offset': [self.request.query_params['offset']]})
for backend in list(self.filter_backends):
filters.update(backend().get_filters_dict(request, self))
self.filters_string = sorted_params_string(filters)
class CacheGetMixin(CacheUniqueUrl):
cache_prefix = None
@@ -38,30 +47,34 @@ class CacheGetMixin(CacheUniqueUrl):
cache_timeout_mins = None
default_page_size = 20
extra_filters_dict = None
def get_cache(self, request):
assert self.cache_prefix is not None, (
"'%s' should include a `cache_prefix` attribute"
% self.__class__.__name__
)
self.get_cache_unique_url(request)
# Attempt to get the response from the cache for the whole request
try:
if self.cache_vary_on_user:
cache_attempt = self.cache.get(f"{self.cache_prefix}.{request.user}.{self.filters_string}")
cache_attempt = self.cache.get(
f"{self.cache_prefix}.{request.user}.{self.filters_string}"
)
else:
cache_attempt = self.cache.get(f"{self.cache_prefix}.{self.filters_string}")
except:
self.logger.info(f"Cache get attempt for {self.__class__.__name__} failed.")
cache_attempt = self.cache.get(
f"{self.cache_prefix}.{self.filters_string}")
except Exception:
self.logger.info(f"Cache get attempt for {self.__class__.__name__}"
" failed.")
cache_attempt = None
if cache_attempt:
return cache_attempt
else:
return None
class CacheSetMixin(CacheUniqueUrl):
cache_prefix = None
@@ -69,31 +82,34 @@ class CacheSetMixin(CacheUniqueUrl):
cache_timeout_mins = None
default_page_size = 20
extra_filters_dict = None
def set_cache(self, request, response):
self.get_cache_unique_url(request)
# Create a function that programmatically defines the caching function
def make_caching_function(cls, request, cache):
def caching_function(response):
# Writes the response to the cache
try:
if self.cache_vary_on_user:
self.cache.set(key=f"{self.cache_prefix}.{request.user}.{self.filters_string}",
value = response.data,
timeout=60*self.cache_timeout_mins)
self.cache.set(key=f"{self.cache_prefix}."
f"{request.user}.{self.filters_string}",
value=response.data,
timeout=60*self.cache_timeout_mins)
else:
self.cache.set(key=f"{self.cache_prefix}.{self.filters_string}",
value = response.data,
timeout=60*self.cache_timeout_mins)
except:
self.logger.exception(f"Cache set attempt for {self.__class__.__name__} failed.")
self.cache.set(key=f"{self.cache_prefix}"
f".{self.filters_string}",
value=response.data,
timeout=60*self.cache_timeout_mins)
except Exception:
self.logger.exception("Cache set attempt for "
f"{self.__class__.__name__} failed.")
return caching_function
# Register the post rendering hook to the response
caching_function = make_caching_function(self, request, self.cache)
response.add_post_render_callback(caching_function)
class CacheDeleteMixin(CacheUniqueUrl):
cache_delete = True
@@ -101,22 +117,26 @@ class CacheDeleteMixin(CacheUniqueUrl):
cache_vary_on_user = False
cache_timeout_mins = None
extra_filters_dict = None
def delete_cache(self, request):
# Handle the caching
if self.cache_delete:
# Create the query to be cached in a unique way to avoid duplicates
self.get_cache_unique_url(request)
assert self.cache_prefix is not None, (
f"{self.__class__.__name__} should include a `cache_prefix` attribute"
f"{self.__class__.__name__} should include a `cache_prefix`"
"attribute"
)
# Delete the cache since a new entry has been created
try:
if self.cache_vary_on_user:
self.cache.delete(f"{self.cache_prefix}.{request.user}.{self.filters_string}")
self.cache.delete(f"{self.cache_prefix}.{request.user}"
f".{self.filters_string}")
else:
self.cache.delete(f"{self.cache_prefix}.{self.filters_string}")
except:
self.logger.exception(f"Cache delete attempt for {self.__class__.__name__} failed.")
self.cache.delete(f"{self.cache_prefix}"
f".{self.filters_string}")
except Exception:
self.logger.exception("Cache delete attempt for "
f"{self.__class__.__name__} failed.")

View File

@@ -1,34 +1,36 @@
from django_filters import rest_framework as filters
from django.contrib.postgres.search import TrigramSimilarity
from django.db.models.functions import Concat
from django.db.models import CharField
from rest_framework.filters import BaseFilterBackend
import operator
from functools import reduce
from rest_framework.settings import api_settings
from rest_framework.filters import BaseFilterBackend
from django.template import loader
from django.utils.translation import gettext_lazy as _
from django.db.models import Max, Min, Count, Q
from rest_framework.settings import api_settings
from django.db.models.constants import LOOKUP_SEP
from django.db import models
from functools import reduce
from django.db.models import Q, CharField
from django.db.models.constants import LOOKUP_SEP
from django.db.models.functions import Concat
from django.contrib.postgres.search import TrigramSimilarity
# TODO the dev pages are not done
# Filters
class LessThanOrEqualFilter(BaseFilterBackend):
def get_less_than_field(self, view, request):
return getattr(view, 'less_than_field', None)
def get_filters_dict(self, request, view):
"""
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
Custom method that returns the filters exclusive to this filter in a
dict. For caching purposes.
"""
less_than_field = self.get_less_than_field(view, request)
assert less_than_field is not None, (
f"{view.__class__.__name__} should include a `less_than_field` attribute"
f"{view.__class__.__name__} should include a `less_than_field`"
"attribute"
)
filters_dict = {}
if less_than_field+'_max' in request.query_params.keys():
field_value = request.query_params.get(less_than_field+'_max')
@@ -36,9 +38,10 @@ class LessThanOrEqualFilter(BaseFilterBackend):
else:
filters_dict[less_than_field+'_max'] = []
return filters_dict
def filter_queryset(self, request, queryset, view):
# Return the correctly filtered queryset but also assign the filter dict to create the unique url for the cache
# Return the correctly filtered queryset but also assign the filter
# dict to create the unique url for the cache
less_than_field = self.get_less_than_field(view, request)
if less_than_field+'_max' in request.query_params.keys():
kwquery = {}
@@ -52,17 +55,19 @@ class LessThanOrEqualFilter(BaseFilterBackend):
class MoreThanOrEqualFilter(BaseFilterBackend):
def get_more_than_field(self, view, request):
return getattr(view, 'more_than_field', None)
def get_filters_dict(self, request, view):
"""
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
Custom method that returns the filters exclusive to this filter in a
dict. For caching purposes.
"""
more_than_field = self.get_more_than_field(view, request)
assert more_than_field is not None, (
f"{view.__class__.__name__} should include a `more_than_field` attribute"
f"{view.__class__.__name__} should include a `more_than_field`"
"attribute"
)
filters_dict = {}
if more_than_field+'_min' in request.query_params.keys():
field_value = request.query_params.get(more_than_field+'_min')
@@ -70,7 +75,7 @@ class MoreThanOrEqualFilter(BaseFilterBackend):
else:
filters_dict[more_than_field+'_min'] = []
return filters_dict
def filter_queryset(self, request, queryset, view):
"""
Return the correctly filtered queryset
@@ -78,7 +83,8 @@ class MoreThanOrEqualFilter(BaseFilterBackend):
more_than_field = self.get_more_than_field(view, request)
if more_than_field+'_min' in request.query_params.keys():
kwquery = {}
kwquery[more_than_field+'__gte'] = request.query_params.get(more_than_field+'_min')
kwquery[more_than_field+'__gte'] = request.query_params.get(
more_than_field+'_min')
return queryset.filter(**kwquery)
else:
return queryset
@@ -86,82 +92,90 @@ class MoreThanOrEqualFilter(BaseFilterBackend):
class CategoryFilter(BaseFilterBackend):
"""
This filter assigns the view.category object for later use, in particular for filters that depend on this one.
This filter assigns the view.category object for later use, in particular
for filters that depend on this one.
"""
template = 'starfields_drf_generics/templates/filters/categories.html'
category_field = 'category'
def get_category_class(self, view, request):
return getattr(view, 'category_class', None)
def assign_view_category(self, request, view):
if not hasattr(view, self.category_field):
if self.category_field in request.query_params.keys():
try:
category_slug = request.query_params.get(self.category_field).strip()
category = view.category_class.objects.get(slug=category_slug)
# Append the category object in the view for later reference
category_slug = request.query_params.get(
self.category_field).strip()
category = view.category_class.objects.get(
slug=category_slug)
# Append the category object in the view for later use
view.category = category
except:
except Exception:
view.category = None
else:
view.category = None
def get_filters_dict(self, request, view):
"""
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. Queries the database for the current category and saves it in view.category for internal use and later filters.
Custom method that returns the filters exclusive to this filter in a
dict. For caching purposes. Queries the database for the current
category and saves it in view.category for internal use and later
filters.
"""
if hasattr(view, 'category_field'):
self.category_field = self.get_category_field(view, request)
category_class = self.get_category_class(view, request)
assert category_class is not None, (
f"{view.__class__.__name__} should include a `category_class` attribute"
f"{view.__class__.__name__} should include a `category_class`"
"attribute"
)
# Create the filters dictionary and find the present category instance
self.assign_view_category(request, view)
filters_dict = {}
if view.category:
filters_dict[self.category_field] = [view.category.slug]
else:
filters_dict[self.category_field] = []
return filters_dict
def filter_queryset(self, request, queryset, view):
self.assign_view_category(request, view)
# Create the queryset
if view.category:
kwquery_1 = {}
kwquery_1[self.category_field+'__id'] = view.category.id
if view.category.tn_descendants_pks:
kwquery_2 = {}
kwquery_2[self.category_field+'__id__in'] = view.category.tn_descendants_pks.split(',')
key = self.category_field+'__id__in'
kwquery_2[key] = view.category.tn_descendants_pks.split(',')
queryset = queryset.filter(Q(**kwquery_1) | Q(**kwquery_2))
else:
queryset = queryset.filter(**kwquery_1)
return queryset
# Developer Interface methods
def get_valid_fields(self, queryset, view, context, request):
# A query is executed here to get the possible categories
category_class = self.get_category_class(view, request)
if hasattr(view, 'category_field'):
self.category_field = self.get_category_field(view, request)
assert category_class is not None, (
f"{view.__class__.__name__} should include a `category_class` attribute"
f"{view.__class__.__name__} should include a `category_class`"
"attribute"
)
valid_fields = category_class.objects.all()
if len(valid_fields):
valid_fields = [
(item.slug, item.__str__()) for item in valid_fields
@@ -170,7 +184,7 @@ class CategoryFilter(BaseFilterBackend):
return valid_fields
else:
return []
def get_current(self, request, queryset, view):
params = request.query_params.get(self.category_field)
if params:
@@ -178,7 +192,7 @@ class CategoryFilter(BaseFilterBackend):
return fields[0]
else:
return None
def get_template_context(self, request, queryset, view):
current = self.get_current(request, queryset, view)
options = []
@@ -187,11 +201,12 @@ class CategoryFilter(BaseFilterBackend):
'current': current,
'param': self.category_field,
}
for key, label in self.get_valid_fields(queryset, view, context, request):
valid_fields = self.get_valid_fields(queryset, view, context, request)
for key, label in valid_fields:
options.append((key, '%s' % (label)))
context['options'] = options
return context
def to_html(self, request, queryset, view):
template = loader.get_template(self.template)
context = self.get_template_context(request, queryset, view)
@@ -200,112 +215,135 @@ class CategoryFilter(BaseFilterBackend):
class FacetFilter(BaseFilterBackend):
"""
This filter requires CategoryFilter to be ran before it. It assigns the view.facets which includes all the facets applicable to the current category.
This filter requires CategoryFilter to be ran before it. It assigns the
view.facets which includes all the facets applicable to the current
category.
"""
template = 'starfields_drf_generics/templates/filters/facets.html'
def get_facet_class(self, view, request):
return getattr(view, 'facet_class', None)
def get_facet_tag_class(self, view, request):
return getattr(view, 'facet_tag_class', None)
def get_facet_tag_field(self, view, request):
return getattr(view, 'facet_tag_field', None)
def assign_view_facets(self, request, view):
if not hasattr(view, 'facets'):
if hasattr(view, 'facet_class'):
self.facet_class = self.get_facet_class(view, request)
assert self.facet_class is not None, (
f"{view.__class__.__name__} should include a `facet_class` attribute"
f"{view.__class__.__name__} should include a `facet_class`"
"attribute"
)
if view.category:
if view.category.tn_ancestors_pks:
view.facets = self.facet_class.objects.filter(Q(category__id=view.category.id) | Q(category__id__in=view.category.tn_ancestors_pks.split(','))).prefetch_related('facet_tags')
ancestor_ids = view.category.tn_ancestors_pks.split(',')
view.facets = self.facet_class.objects.filter(
Q(category__id=view.category.id) | Q(
category__id__in=ancestor_ids)
).prefetch_related('facet_tags')
else:
view.facets = self.facet_class.objects.filter(category__id=view.category.id).prefetch_related('facet_tags')
view.facets = self.facet_class.objects.filter(
category__id=view.category.id).prefetch_related(
'facet_tags')
else:
view.facets = self.facet_class.objects.filter(category__tn_level=1).prefetch_related('facet_tags')
view.facets = self.facet_class.objects.filter(
category__tn_level=1).prefetch_related(
'facet_tags')
def get_filters_dict(self, request, view):
"""
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
Custom method that returns the filters exclusive to this filter in a
dict. For caching purposes.
"""
if hasattr(view, 'facet_class'):
self.facet_class = self.get_facet_class(view, request)
assert self.facet_class is not None, (
f"{view.__class__.__name__} should include a `facet_class` attribute"
f"{view.__class__.__name__} should include a `facet_class`"
"attribute"
)
self.assign_view_facets(request, view)
filters_dict = {}
if view.facets:
for facet in view.facets:
if facet.slug in request.query_params.keys():
filters_dict[facet.slug] = set(request.query_params[facet.slug].split(','))
filters_dict[facet.slug] = set(
request.query_params[facet.slug].split(','))
else:
filters_dict[facet.slug] = set({})
# Append the facets object and the tags dict in the view for later reference
# Append the facets object and the tags dict in the view for later
# reference
view.tags = filters_dict
return filters_dict
def filter_queryset(self, request, queryset, view):
if hasattr(view, 'facet_tag_class'):
self.facet_tag_class = self.get_facet_tag_class(view, request)
assert self.facet_tag_class is not None, (
f"{view.__class__.__name__} should include a `facet_tag_class` attribute"
f"{view.__class__.__name__} should include a `facet_tag_class`"
"attribute"
)
if hasattr(view, 'facet_tag_field'):
self.facet_tag_field = self.get_facet_tag_field(view, request)
assert self.facet_tag_field is not None, (
f"{view.__class__.__name__} should include a `facet_tag_field` attribute"
f"{view.__class__.__name__} should include a `facet_tag_field`"
"attribute"
)
self.assign_view_facets(request, view)
if view.facets:
for facet in view.facets:
if facet.slug in request.query_params.keys() and request.query_params[facet.slug]:
request_slug = request.query_params[facet.slug]
if facet.slug in request.query_params.keys() and request_slug:
tag_filterlist = request.query_params.get(facet.slug)
if tag_filterlist == '':
# If the tag filterlist is empty then we're not filtering against it, it's like having all the tags of the facet selected
# If the tag filterlist is empty then we're not
# filtering against it, it's like having all the tags
# of the facet selected
pass
else:
kwquery = {}
kwquery[self.facet_tag_field+'__slug__in'] = tag_filterlist.replace(' ','').split(',')
key = self.facet_tag_field+'__slug__in'
kwquery[key] = tag_filterlist.replace(' ', '').split(
',')
queryset = queryset.filter(**kwquery)
return queryset
# Developer Interface methods
def get_template_context(self, request, queryset, view):
# Does aggressive database querying to get the necessary facets and facettags, but this is only for the developer interface so its fine
# Does aggressive database querying to get the necessary facets and
# facettags, but this is only for the developer interface so its fine
if hasattr(view, 'facet_class'):
self.facet_class = self.get_facet_class(view, request)
assert self.facet_class is not None, (
f"{view.__class__.__name__} should include a `facet_class` attribute"
f"{view.__class__.__name__} should include a `facet_class`"
"attribute"
)
if hasattr(view, 'facet_tag_class'):
self.facet_tag_class = self.get_facet_tag_class(view, request)
assert self.facet_tag_class is not None, (
f"{view.__class__.__name__} should include a `facet_tag_class` attribute"
f"{view.__class__.__name__} should include a `facet_tag_class`"
"attribute"
)
# Find the current choices
current = []
facet_slugs = []
@@ -314,7 +352,7 @@ class FacetFilter(BaseFilterBackend):
facet_slugs.append(facet.slug)
if facet.slug in request.query_params.keys():
current.append(request.query_params.get(facet.slug))
facet_slug_names = {}
options = {}
context = {
@@ -324,17 +362,19 @@ class FacetFilter(BaseFilterBackend):
}
if view.facets:
for facet in view.facets:
facet_tag_instances = self.facet_tag_class.objects.filter(facet__slug=facet.slug)
options[facet.slug] = [(facet_tag.slug, facet_tag.name) for facet_tag in facet_tag_instances]
facet_tag_instances = self.facet_tag_class.objects.filter(
facet__slug=facet.slug)
options[facet.slug] = [(facet_tag.slug, facet_tag.name) for
facet_tag in facet_tag_instances]
facet_slug_names[facet.slug] = facet.name
context['facet_slug_names'] = facet_slug_names
context['options'] = options
else:
context['facet_slug_names'] = {}
context['options'] = {}
return context
def to_html(self, request, queryset, view):
template = loader.get_template(self.template)
context = self.get_template_context(request, queryset, view)
@@ -346,10 +386,13 @@ class TrigramSearchFilter(BaseFilterBackend):
search_param = 'search'
template = 'rest_framework/filters/search.html'
search_title = _('Search')
search_description = _('A search string to perform trigram similarity based searching with.')
search_description = _('A search string to perform trigram similarity'
'based searching with.')
def get_filters_dict(self, request, view):
"""
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
Custom method that returns the filters exclusive to this filter in a
dict. For caching purposes.
"""
self.filters_dict = {}
if 'search' in request.query_params.keys():
@@ -357,18 +400,19 @@ class TrigramSearchFilter(BaseFilterBackend):
self.filters_dict['search'] = [slug_term]
else:
self.filters_dict['search'] = []
return self.filters_dict
def filter_queryset(self, request, queryset, view):
search_fields = getattr(view, 'search_fields', None)
assert search_fields is not None, (
f"{view.__class__.__name__} should include a `search_fields` attribute"
f"{view.__class__.__name__} should include a `search_fields`"
"attribute"
)
query = request.query_params.get(self.search_param, '')
if query:
queryset = queryset.annotate(
search_field=Concat(
@@ -377,7 +421,7 @@ class TrigramSearchFilter(BaseFilterBackend):
)).annotate(
similarity=TrigramSimilarity('search_field', query)
).filter(similarity__gt=0.05).distinct()
return queryset
def to_html(self, request, queryset, view):
@@ -436,14 +480,14 @@ class SlugSearchFilter(BaseFilterBackend):
def get_slug_field(self, view, request):
return getattr(view, 'slug_field', None)
def get_filters_dict(self, request, view):
"""
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
"""
if hasattr(view, 'slug_field'):
self.slug_field = self.get_slug_field(view, request)
assert self.slug_field is not None, (
f"{view.__class__.__name__} should include a `slug_field` attribute"
)
@@ -453,9 +497,9 @@ class SlugSearchFilter(BaseFilterBackend):
self.filters_dict[self.slug_field] = [slug_term]
else:
self.filters_dict[self.slug_field] = []
return self.filters_dict
def filter_queryset(self, request, queryset, view):
# Ensure that the slug field was searched against
try:
@@ -463,18 +507,18 @@ class SlugSearchFilter(BaseFilterBackend):
slug_term = request.query_params.get(self.slug_field)
query = {}
query[self.slug_field] = slug_term
queryset = queryset.get(**query)
except Exception as e:
print(e)
return queryset
def to_html(self, request, queryset, view):
if not getattr(view, 'slug_field', None):
return ''
slug_term = self.get_slug_term(request)
context = {
'param': self.slug_field,
@@ -540,7 +584,7 @@ class SearchFilter(BaseFilterBackend):
# This field has a custom __ query transform but is not a relational field.
break
return False
def get_filters_dict(self, request, view):
"""
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
@@ -551,7 +595,7 @@ class SearchFilter(BaseFilterBackend):
self.filters_dict['search'] = [slug_term]
else:
self.filters_dict['search'] = []
return self.filters_dict
def filter_queryset(self, request, queryset, view):
@@ -728,7 +772,7 @@ class OrderingFilter(BaseFilterBackend):
return term in valid_fields
return [term for term in fields if term_valid(term)]
def get_filters_dict(self, request, view):
"""
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
@@ -739,9 +783,9 @@ class OrderingFilter(BaseFilterBackend):
self.filters_dict['ordering'] = [slug_term]
else:
self.filters_dict['ordering'] = [view.ordering_fields[0]]
return self.filters_dict
def filter_queryset(self, request, queryset, view):
ordering = self.get_ordering(request, queryset, view)

View File

@@ -1,15 +1,7 @@
"""
Generic views that provide commonly needed behaviour.
"""
from django.core.exceptions import ValidationError
from django.db.models.query import QuerySet
from django.http import Http404
from django.shortcuts import get_object_or_404 as _get_object_or_404
from rest_framework import views
from rest_framework.generics import GenericAPIView
from rest_framework.settings import api_settings
from starfields_drf_generics import mixins
@@ -18,26 +10,32 @@ from starfields_drf_generics import mixins
# Single item CRUD
class CachedCreateAPIView(mixins.CachedCreateModelMixin,GenericAPIView):
class CachedCreateAPIView(mixins.CachedCreateModelMixin,
GenericAPIView):
"""
Concrete view for creating a model instance.
"""
def post(self, request, *args, **kwargs):
return self.create(request, *args, **kwargs)
class CachedRetrieveAPIView(mixins.CachedRetrieveModelMixin,GenericAPIView):
class CachedRetrieveAPIView(mixins.CachedRetrieveModelMixin,
GenericAPIView):
"""
Concrete view for retrieving a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
class CachedUpdateAPIView(mixins.CachedUpdateModelMixin,GenericAPIView):
class CachedUpdateAPIView(mixins.CachedUpdateModelMixin,
GenericAPIView):
"""
Concrete view for updating a model instance.
"""
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
@@ -45,18 +43,23 @@ class CachedUpdateAPIView(mixins.CachedUpdateModelMixin,GenericAPIView):
return self.partial_update(request, *args, **kwargs)
class CachedDestroyAPIView(mixins.CachedDestroyModelMixin,GenericAPIView):
class CachedDestroyAPIView(mixins.CachedDestroyModelMixin,
GenericAPIView):
"""
Concrete view for deleting a model instance.
"""
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
class CachedRetrieveUpdateAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedUpdateModelMixin,GenericAPIView):
class CachedRetrieveUpdateAPIView(mixins.CachedRetrieveModelMixin,
mixins.CachedUpdateModelMixin,
GenericAPIView):
"""
Concrete view for retrieving, updating a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
@@ -67,10 +70,13 @@ class CachedRetrieveUpdateAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedU
return self.partial_update(request, *args, **kwargs)
class CachedRetrieveDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedDestroyModelMixin,GenericAPIView):
class CachedRetrieveDestroyAPIView(mixins.CachedRetrieveModelMixin,
mixins.CachedDestroyModelMixin,
GenericAPIView):
"""
Concrete view for retrieving or deleting a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
@@ -78,10 +84,14 @@ class CachedRetrieveDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins.Cached
return self.destroy(request, *args, **kwargs)
class CachedRetrieveUpdateDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedUpdateModelMixin,mixins.CachedDestroyModelMixin,GenericAPIView):
class CachedRetrieveUpdateDestroyAPIView(mixins.CachedRetrieveModelMixin,
mixins.CachedUpdateModelMixin,
mixins.CachedDestroyModelMixin,
GenericAPIView):
"""
Concrete view for retrieving, updating or deleting a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
@@ -95,10 +105,16 @@ class CachedRetrieveUpdateDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins.
return self.destroy(request, *args, **kwargs)
class CachedCreateRetrieveUpdateDestroyAPIView(mixins.CachedCreateModelMixin,mixins.CachedRetrieveModelMixin,mixins.CachedUpdateModelMixin,mixins.CachedDestroyModelMixin,GenericAPIView):
class CachedCreateRetrieveUpdateDestroyAPIView(mixins.CachedCreateModelMixin,
mixins.CachedRetrieveModelMixin,
mixins.CachedUpdateModelMixin,
mixins.CachedDestroyModelMixin,
GenericAPIView):
"""
Concrete view for creating, retrieving, updating or deleting a model instance.
Concrete view for creating, retrieving, updating or deleting a model
instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
@@ -117,26 +133,32 @@ class CachedCreateRetrieveUpdateDestroyAPIView(mixins.CachedCreateModelMixin,mix
# List based CRUD
class CachedListRetrieveAPIView(mixins.CachedListRetrieveModelMixin,GenericAPIView):
class CachedListRetrieveAPIView(mixins.CachedListRetrieveModelMixin,
GenericAPIView):
"""
Concrete view for listing a queryset.
"""
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
class CachedListCreateAPIView(mixins.CachedListCreateModelMixin,GenericAPIView):
class CachedListCreateAPIView(mixins.CachedListCreateModelMixin,
GenericAPIView):
"""
Concrete view for creating multiple instances.
"""
def post(self, request, *args, **kwargs):
return self.list_create(request, *args, **kwargs)
class CachedListUpdateAPIView(mixins.CachedListUpdateModelMixin,GenericAPIView):
class CachedListUpdateAPIView(mixins.CachedListUpdateModelMixin,
GenericAPIView):
"""
Concrete view for updating multiple instances.
"""
def put(self, request, *args, **kwargs):
return self.list_update(request, *args, **kwargs)
@@ -144,18 +166,23 @@ class CachedListUpdateAPIView(mixins.CachedListUpdateModelMixin,GenericAPIView):
return self.list_partial_update(request, *args, **kwargs)
class CachedListDestroyAPIView(mixins.CachedListDestroyModelMixin,GenericAPIView):
class CachedListDestroyAPIView(mixins.CachedListDestroyModelMixin,
GenericAPIView):
"""
Concrete view for deleting multiple instances.
"""
def delete(self, request, *args, **kwargs):
return self.list_destroy(request, *args, **kwargs)
class CachedListRetrieveCreateAPIView(mixins.CachedListRetrieveModelMixin,mixins.CachedListCreateModelMixin,GenericAPIView):
class CachedListRetrieveCreateAPIView(mixins.CachedListRetrieveModelMixin,
mixins.CachedListCreateModelMixin,
GenericAPIView):
"""
Concrete view for listing a queryset or creating a model instance.
"""
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
@@ -163,10 +190,15 @@ class CachedListRetrieveCreateAPIView(mixins.CachedListRetrieveModelMixin,mixins
return self.create(request, *args, **kwargs)
class CachedListCreateRetrieveDestroyAPIView(mixins.CachedListCreateModelMixin,mixins.CachedListRetrieveModelMixin,mixins.CachedListDestroyModelMixin,GenericAPIView):
class CachedListCreateRetrieveDestroyAPIView(
mixins.CachedListCreateModelMixin,
mixins.CachedListRetrieveModelMixin,
mixins.CachedListDestroyModelMixin,
GenericAPIView):
"""
Concrete view for creating, retrieving or deleting a model instance.
"""
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
@@ -177,10 +209,16 @@ class CachedListCreateRetrieveDestroyAPIView(mixins.CachedListCreateModelMixin,m
return self.list_destroy(request, *args, **kwargs)
class CachedListCreateRetrieveUpdateAPIView(mixins.CachedListCreateModelMixin,mixins.CachedListRetrieveModelMixin,mixins.CachedListUpdateModelMixin,GenericAPIView):
class CachedListCreateRetrieveUpdateAPIView(
mixins.CachedListCreateModelMixin,
mixins.CachedListRetrieveModelMixin,
mixins.CachedListUpdateModelMixin,
GenericAPIView):
"""
Concrete view for creating, retrieving, updating or deleting a model instance.
Concrete view for creating, retrieving, updating or deleting a model
instance.
"""
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
@@ -194,10 +232,17 @@ class CachedListCreateRetrieveUpdateAPIView(mixins.CachedListCreateModelMixin,mi
return self.list_partial_update(request, *args, **kwargs)
class CachedListCreateRetrieveUpdateDestroyAPIView(mixins.CachedListCreateModelMixin,mixins.CachedListRetrieveModelMixin,mixins.CachedListUpdateModelMixin,mixins.CachedListDestroyModelMixin,GenericAPIView):
class CachedListCreateRetrieveUpdateDestroyAPIView(
mixins.CachedListCreateModelMixin,
mixins.CachedListRetrieveModelMixin,
mixins.CachedListUpdateModelMixin,
mixins.CachedListDestroyModelMixin,
GenericAPIView):
"""
Concrete view for creating, retrieving, updating or deleting a model instance.
Concrete view for creating, retrieving, updating or deleting a model
instance.
"""
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)

View File

@@ -8,56 +8,70 @@ from rest_framework import status
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework import mixins
from starfields_drf_generics.cache_mixins import CacheGetMixin, CacheSetMixin, CacheDeleteMixin
from starfields_drf_generics.cache_mixins import (
CacheGetMixin, CacheSetMixin, CacheDeleteMixin)
# Mixin classes to be included in the generic classes
class CachedCreateModelMixin(CacheDeleteMixin, mixins.CreateModelMixin):
"""
A slightly modified version of rest_framework.mixins.CreateModelMixin that handles cache deletions.
A slightly modified version of rest_framework.mixins.CreateModelMixin
that handles cache deletions.
"""
def create(self, request, *args, **kwargs):
" Creates the entry in the request "
# Go on with the creation as normal
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
# Delete the cache
self.delete_cache(request)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
return Response(serializer.data, status=status.HTTP_201_CREATED,
headers=headers)
class CachedRetrieveModelMixin(CacheGetMixin, CacheSetMixin):
"""
A slightly modified version of rest_framework.mixins.RetrieveModelMixin that handles cache attempts.
mixins.RetrieveModelMixin only has the retrieve method so it doesn't stand to inherit anything from it.
A slightly modified version of rest_framework.mixins.RetrieveModelMixin
that handles cache attempts.
mixins.RetrieveModelMixin only has the retrieve method so it doesn't stand
to inherit anything from it.
"""
def retrieve(self, request, *args, **kwargs):
def retrieve(self, request):
" Retrieves the entry in the request "
# Attempt to get the request from the cache
cache_attempt = self.get_cache(request)
if cache_attempt:
return Response(cache_attempt)
else:
# The cache get attempt failed so we have to get the results from the database
instance = self.get_object()
serializer = self.get_serializer(instance)
response = Response(serializer.data)
self.set_cache(request, response)
return response
# The cache get attempt failed so we have to get the results from
# the database
instance = self.get_object()
serializer = self.get_serializer(instance)
response = Response(serializer.data)
self.set_cache(request, response)
return response
class CachedUpdateModelMixin(CacheDeleteMixin, mixins.UpdateModelMixin):
"""
A slightly modified version of rest_framework.mixins.UpdateModelMixin that handles cache deletes.
A slightly modified version of rest_framework.mixins.UpdateModelMixin that
handles cache deletes.
"""
def update(self, request, *args, **kwargs):
" Updates the entry in the request "
partial = kwargs.pop('partial', False)
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer = self.get_serializer(instance, data=request.data,
partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
@@ -65,7 +79,7 @@ class CachedUpdateModelMixin(CacheDeleteMixin, mixins.UpdateModelMixin):
# If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance.
instance._prefetched_objects_cache = {}
# Delete the related caches
self.delete_cache(request)
@@ -74,15 +88,18 @@ class CachedUpdateModelMixin(CacheDeleteMixin, mixins.UpdateModelMixin):
class CachedDestroyModelMixin(CacheDeleteMixin, mixins.DestroyModelMixin):
"""
A slightly modified version of rest_framework.mixins.DestroyModelMixin that handles cache deletes.
A slightly modified version of rest_framework.mixins.DestroyModelMixin
that handles cache deletes.
"""
def destroy(self, request, *args, **kwargs):
" Deletes the entry in the request "
instance = self.get_object()
self.perform_destroy(instance)
# Delete the related caches
self.delete_cache(request)
return Response(status=status.HTTP_204_NO_CONTENT)
@@ -91,21 +108,26 @@ class CachedListCreateModelMixin(CacheDeleteMixin):
"""
A fully custom mixin that handles mutiple instance cration.
"""
def list_create(self, request, *args, **kwargs):
def list_create(self, request):
" Creates the list of entries in the request "
# Go on with the creation as normal
serializer = self.get_serializer(data=request.data, many=True)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
# Delete the cache
self.delete_cache(request)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
return Response(serializer.data, status=status.HTTP_201_CREATED,
headers=headers)
def perform_create(self, serializer):
" Generic save hook "
serializer.save()
def get_success_headers(self, data):
" Returns extra success headers "
try:
return {'Location': str(data[api_settings.URL_FIELD_NAME])}
except (TypeError, KeyError):
@@ -114,31 +136,36 @@ class CachedListCreateModelMixin(CacheDeleteMixin):
class CachedListRetrieveModelMixin(CacheGetMixin, CacheSetMixin):
"""
A slightly modified version of rest_framework.mixins.ListModelMixin that handles cache saves.
mixins.ListModelMixin only has the list method so it doesn't stand to inherit anything from it.
A slightly modified version of rest_framework.mixins.ListModelMixin that
handles cache saves.
mixins.ListModelMixin only has the list method so it doesn't stand to
inherit anything from it.
"""
def list(self, request, *args, **kwargs):
def list(self, request):
" Retrieves the listing of entries "
# Attempt to get the request from the cache
cache_attempt = self.get_cache(request)
if cache_attempt:
return Response(cache_attempt)
else:
# The cache get attempt failed so we have to get the results from the database
queryset = self.filter_queryset(self.get_queryset())
if self.paged:
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True)
response = self.get_paginated_response(serializer.data)
else:
serializer = self.get_serializer(queryset, many=True)
response = Response(serializer.data)
# The cache get attempt failed so we have to get the results from
# the database
queryset = self.filter_queryset(self.get_queryset())
if self.paged:
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True)
response = self.get_paginated_response(serializer.data)
else:
serializer = self.get_serializer(queryset, many=True)
response = Response(serializer.data)
else:
serializer = self.get_serializer(queryset, many=True)
response = Response(serializer.data)
self.set_cache(request, response)
return response
@@ -147,12 +174,15 @@ class CachedListUpdateModelMixin(CacheDeleteMixin):
"""
A fully custom mixin that handles mutiple instance updates.
"""
def list_update(self, request, *args, **kwargs):
def list_update(self, request, **kwargs):
" Updates the list of entries in the request "
partial = kwargs.pop('partial', False)
queryset = self.filter_queryset(self.get_queryset())
serializer = self.get_serializer(queryset, data=request.data, partial=partial, many=True)
serializer = self.get_serializer(queryset, data=request.data,
partial=partial, many=True)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
@@ -162,9 +192,11 @@ class CachedListUpdateModelMixin(CacheDeleteMixin):
return Response(serializer.data)
def perform_update(self, serializer):
" Generic save hook "
serializer.save()
def list_partial_update(self, request, *args, **kwargs):
" Needs to be called on partial updates "
kwargs['partial'] = True
return self.list_update(request, *args, **kwargs)
@@ -173,39 +205,22 @@ class CachedListDestroyModelMixin(CacheDeleteMixin):
"""
A fully custom mixin that handles mutiple instance deletions.
"""
def list_destroy(self, request, *args, **kwargs):
def list_destroy(self, request):
" Deletes the list of entries in the request "
# Go on with the validation as normal
serializer = self.get_serializer(data=request.data, many=True)
serializer.is_valid(raise_exception=True)
validated_data = serializer.validated_data
# TODO does this new stuff work even? need to check on the frontend
serializer.delete(validated_data)
# for instance in self.get_objects():
# if instance is not None:
# self.perform_destroy(instance)
# Delete the related caches
self.delete_cache(request)
return Response(status=status.HTTP_204_NO_CONTENT)
#def perform_destroy(self, instance):
# instance.delete()
#def get_objects(self):
# """
# The custom list version of get_object that retrieves one instance from the #database. It yields model instances with each call.
# """
# queryset = self.filter_queryset(self.get_queryset())
#
# if len(queryset):
# for obj in queryset.all():
# # May raise a permission denied
# self.check_object_permissions(self.request, obj)
# yield obj
#yield None

View File

@@ -1,22 +1,30 @@
"""
Utility functions for the library
"""
def sorted_params_string(filters_dict):
"""
This function takes a dict and returns it in a sorted form for the url filter, it's primarily used for cache purposes.
This function takes a dict and returns it in a sorted form for the url
filter, it's primarily used for cache purposes.
"""
filters_string = ''
for key in sorted(filters_dict.keys()):
if filters_string == '':
filters_string = f"{key}={','.join(str(val) for val in sorted(filters_dict[key]))}"
key_str = ','.join(str(val) for val in sorted(filters_dict[key]))
filters_string = f"{key}={key_str}"
else:
filters_string = f"{filters_string}&{key}={','.join(str(val) for val in sorted(filters_dict[key]))}"
key_str = ','.join(str(val) for val in sorted(filters_dict[key]))
filters_string = f"{filters_string}&{key}={key_str}"
filters_string = filters_string.strip()
return filters_string
def parse_tags_to_dict(tags):
" This function parses a tag string into a dictionary "
tagdict = {}
if ':' not in tags:
tagdict = {}
else:
else:
for subtag in tags.split('&'):
tagkey, taglist = subtag.split(':')
taglist = taglist.split(',')