diff --git a/starfields_drf_generics/cache_mixins.py b/starfields_drf_generics/cache_mixins.py index 8637d73..25c6387 100644 --- a/starfields_drf_generics/cache_mixins.py +++ b/starfields_drf_generics/cache_mixins.py @@ -1,36 +1,45 @@ from starfields_drf_generics.utils import sorted_params_string -# TODO classes below that involve create, update, destroy don't delete the caches properly, they need a regex cache delete +# TODO classes below that involve create, update, destroy don't delete the +# caches properly, they need a regex cache delete # TODO put more reasonable asserts and feedback + # Mixin classes that provide cache functionalities class CacheUniqueUrl: def get_cache_unique_url(self, request): - """ Create the query to be cached in a unique way to avoid duplicates. """ + """ + Create the query to be cached in a unique way to avoid duplicates. + """ if not hasattr(self, 'filters_string'): # Only assign the attribute if it's not already assigned filters = {} if self.extra_filters_dict: filters.update(self.extra_filters_dict) - # Check if the url parameters have any of the keys of the extra filters and if so assign them + # Check if the url parameters have any of the keys of the extra + # filters and if so assign them for key in self.extra_filters_dict: if key in self.request.query_params.keys(): - filters[key] = self.request.query_params[key].replace(' ','').split(',') + filters[key] = self.request.query_params[key].replace( + ' ', '').split(',') # Check if they're resolved in the urlconf as well if key in self.kwargs.keys(): filters[key] = [self.kwargs[key]] - + if hasattr(self, 'paged'): if self.paged: - filters.update({'limit': [self.default_page_size], 'offset': [0]}) + filters.update({'limit': [self.default_page_size], + 'offset': [0]}) if 'limit' in self.request.query_params.keys(): - filters.update({'limit': [self.request.query_params['limit']]}) + filters.update({ + 'limit': [self.request.query_params['limit']]}) if 'offset' in self.request.query_params.keys(): - filters.update({'offset': [self.request.query_params['offset']]}) + filters.update({ + 'offset': [self.request.query_params['offset']]}) for backend in list(self.filter_backends): filters.update(backend().get_filters_dict(request, self)) self.filters_string = sorted_params_string(filters) - + class CacheGetMixin(CacheUniqueUrl): cache_prefix = None @@ -38,30 +47,34 @@ class CacheGetMixin(CacheUniqueUrl): cache_timeout_mins = None default_page_size = 20 extra_filters_dict = None - + def get_cache(self, request): assert self.cache_prefix is not None, ( "'%s' should include a `cache_prefix` attribute" % self.__class__.__name__ ) - + self.get_cache_unique_url(request) - + # Attempt to get the response from the cache for the whole request try: if self.cache_vary_on_user: - cache_attempt = self.cache.get(f"{self.cache_prefix}.{request.user}.{self.filters_string}") + cache_attempt = self.cache.get( + f"{self.cache_prefix}.{request.user}.{self.filters_string}" + ) else: - cache_attempt = self.cache.get(f"{self.cache_prefix}.{self.filters_string}") - except: - self.logger.info(f"Cache get attempt for {self.__class__.__name__} failed.") + cache_attempt = self.cache.get( + f"{self.cache_prefix}.{self.filters_string}") + except Exception: + self.logger.info(f"Cache get attempt for {self.__class__.__name__}" + " failed.") cache_attempt = None - + if cache_attempt: return cache_attempt else: return None - + class CacheSetMixin(CacheUniqueUrl): cache_prefix = None @@ -69,31 +82,34 @@ class CacheSetMixin(CacheUniqueUrl): cache_timeout_mins = None default_page_size = 20 extra_filters_dict = None - + def set_cache(self, request, response): self.get_cache_unique_url(request) - + # Create a function that programmatically defines the caching function def make_caching_function(cls, request, cache): def caching_function(response): # Writes the response to the cache try: if self.cache_vary_on_user: - self.cache.set(key=f"{self.cache_prefix}.{request.user}.{self.filters_string}", - value = response.data, - timeout=60*self.cache_timeout_mins) + self.cache.set(key=f"{self.cache_prefix}." + f"{request.user}.{self.filters_string}", + value=response.data, + timeout=60*self.cache_timeout_mins) else: - self.cache.set(key=f"{self.cache_prefix}.{self.filters_string}", - value = response.data, - timeout=60*self.cache_timeout_mins) - except: - self.logger.exception(f"Cache set attempt for {self.__class__.__name__} failed.") + self.cache.set(key=f"{self.cache_prefix}" + f".{self.filters_string}", + value=response.data, + timeout=60*self.cache_timeout_mins) + except Exception: + self.logger.exception("Cache set attempt for " + f"{self.__class__.__name__} failed.") return caching_function - + # Register the post rendering hook to the response caching_function = make_caching_function(self, request, self.cache) response.add_post_render_callback(caching_function) - + class CacheDeleteMixin(CacheUniqueUrl): cache_delete = True @@ -101,22 +117,26 @@ class CacheDeleteMixin(CacheUniqueUrl): cache_vary_on_user = False cache_timeout_mins = None extra_filters_dict = None - + def delete_cache(self, request): # Handle the caching if self.cache_delete: # Create the query to be cached in a unique way to avoid duplicates self.get_cache_unique_url(request) - + assert self.cache_prefix is not None, ( - f"{self.__class__.__name__} should include a `cache_prefix` attribute" + f"{self.__class__.__name__} should include a `cache_prefix`" + "attribute" ) - + # Delete the cache since a new entry has been created try: if self.cache_vary_on_user: - self.cache.delete(f"{self.cache_prefix}.{request.user}.{self.filters_string}") + self.cache.delete(f"{self.cache_prefix}.{request.user}" + f".{self.filters_string}") else: - self.cache.delete(f"{self.cache_prefix}.{self.filters_string}") - except: - self.logger.exception(f"Cache delete attempt for {self.__class__.__name__} failed.") + self.cache.delete(f"{self.cache_prefix}" + f".{self.filters_string}") + except Exception: + self.logger.exception("Cache delete attempt for " + f"{self.__class__.__name__} failed.") diff --git a/starfields_drf_generics/filters.py b/starfields_drf_generics/filters.py index 99bbde2..e8ddf9a 100644 --- a/starfields_drf_generics/filters.py +++ b/starfields_drf_generics/filters.py @@ -1,34 +1,36 @@ -from django_filters import rest_framework as filters -from django.contrib.postgres.search import TrigramSimilarity -from django.db.models.functions import Concat -from django.db.models import CharField -from rest_framework.filters import BaseFilterBackend + import operator +from functools import reduce +from rest_framework.settings import api_settings +from rest_framework.filters import BaseFilterBackend from django.template import loader from django.utils.translation import gettext_lazy as _ -from django.db.models import Max, Min, Count, Q -from rest_framework.settings import api_settings -from django.db.models.constants import LOOKUP_SEP from django.db import models -from functools import reduce +from django.db.models import Q, CharField +from django.db.models.constants import LOOKUP_SEP +from django.db.models.functions import Concat +from django.contrib.postgres.search import TrigramSimilarity + # TODO the dev pages are not done -# Filters + class LessThanOrEqualFilter(BaseFilterBackend): def get_less_than_field(self, view, request): return getattr(view, 'less_than_field', None) - + def get_filters_dict(self, request, view): """ - Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. + Custom method that returns the filters exclusive to this filter in a + dict. For caching purposes. """ less_than_field = self.get_less_than_field(view, request) - + assert less_than_field is not None, ( - f"{view.__class__.__name__} should include a `less_than_field` attribute" + f"{view.__class__.__name__} should include a `less_than_field`" + "attribute" ) - + filters_dict = {} if less_than_field+'_max' in request.query_params.keys(): field_value = request.query_params.get(less_than_field+'_max') @@ -36,9 +38,10 @@ class LessThanOrEqualFilter(BaseFilterBackend): else: filters_dict[less_than_field+'_max'] = [] return filters_dict - + def filter_queryset(self, request, queryset, view): - # Return the correctly filtered queryset but also assign the filter dict to create the unique url for the cache + # Return the correctly filtered queryset but also assign the filter + # dict to create the unique url for the cache less_than_field = self.get_less_than_field(view, request) if less_than_field+'_max' in request.query_params.keys(): kwquery = {} @@ -52,17 +55,19 @@ class LessThanOrEqualFilter(BaseFilterBackend): class MoreThanOrEqualFilter(BaseFilterBackend): def get_more_than_field(self, view, request): return getattr(view, 'more_than_field', None) - + def get_filters_dict(self, request, view): """ - Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. + Custom method that returns the filters exclusive to this filter in a + dict. For caching purposes. """ more_than_field = self.get_more_than_field(view, request) - + assert more_than_field is not None, ( - f"{view.__class__.__name__} should include a `more_than_field` attribute" + f"{view.__class__.__name__} should include a `more_than_field`" + "attribute" ) - + filters_dict = {} if more_than_field+'_min' in request.query_params.keys(): field_value = request.query_params.get(more_than_field+'_min') @@ -70,7 +75,7 @@ class MoreThanOrEqualFilter(BaseFilterBackend): else: filters_dict[more_than_field+'_min'] = [] return filters_dict - + def filter_queryset(self, request, queryset, view): """ Return the correctly filtered queryset @@ -78,7 +83,8 @@ class MoreThanOrEqualFilter(BaseFilterBackend): more_than_field = self.get_more_than_field(view, request) if more_than_field+'_min' in request.query_params.keys(): kwquery = {} - kwquery[more_than_field+'__gte'] = request.query_params.get(more_than_field+'_min') + kwquery[more_than_field+'__gte'] = request.query_params.get( + more_than_field+'_min') return queryset.filter(**kwquery) else: return queryset @@ -86,82 +92,90 @@ class MoreThanOrEqualFilter(BaseFilterBackend): class CategoryFilter(BaseFilterBackend): """ - This filter assigns the view.category object for later use, in particular for filters that depend on this one. + This filter assigns the view.category object for later use, in particular + for filters that depend on this one. """ template = 'starfields_drf_generics/templates/filters/categories.html' category_field = 'category' - + def get_category_class(self, view, request): return getattr(view, 'category_class', None) - + def assign_view_category(self, request, view): if not hasattr(view, self.category_field): if self.category_field in request.query_params.keys(): try: - category_slug = request.query_params.get(self.category_field).strip() - category = view.category_class.objects.get(slug=category_slug) - # Append the category object in the view for later reference + category_slug = request.query_params.get( + self.category_field).strip() + category = view.category_class.objects.get( + slug=category_slug) + # Append the category object in the view for later use view.category = category - except: + except Exception: view.category = None else: view.category = None - - + def get_filters_dict(self, request, view): """ - Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. Queries the database for the current category and saves it in view.category for internal use and later filters. + Custom method that returns the filters exclusive to this filter in a + dict. For caching purposes. Queries the database for the current + category and saves it in view.category for internal use and later + filters. """ if hasattr(view, 'category_field'): self.category_field = self.get_category_field(view, request) - + category_class = self.get_category_class(view, request) - + assert category_class is not None, ( - f"{view.__class__.__name__} should include a `category_class` attribute" + f"{view.__class__.__name__} should include a `category_class`" + "attribute" ) - + # Create the filters dictionary and find the present category instance self.assign_view_category(request, view) - + filters_dict = {} if view.category: filters_dict[self.category_field] = [view.category.slug] else: filters_dict[self.category_field] = [] - + return filters_dict - + def filter_queryset(self, request, queryset, view): self.assign_view_category(request, view) - + # Create the queryset if view.category: kwquery_1 = {} kwquery_1[self.category_field+'__id'] = view.category.id if view.category.tn_descendants_pks: kwquery_2 = {} - kwquery_2[self.category_field+'__id__in'] = view.category.tn_descendants_pks.split(',') - + key = self.category_field+'__id__in' + kwquery_2[key] = view.category.tn_descendants_pks.split(',') + queryset = queryset.filter(Q(**kwquery_1) | Q(**kwquery_2)) else: queryset = queryset.filter(**kwquery_1) - + return queryset - + # Developer Interface methods def get_valid_fields(self, queryset, view, context, request): # A query is executed here to get the possible categories category_class = self.get_category_class(view, request) if hasattr(view, 'category_field'): self.category_field = self.get_category_field(view, request) - + assert category_class is not None, ( - f"{view.__class__.__name__} should include a `category_class` attribute" + f"{view.__class__.__name__} should include a `category_class`" + "attribute" ) - + valid_fields = category_class.objects.all() - + if len(valid_fields): valid_fields = [ (item.slug, item.__str__()) for item in valid_fields @@ -170,7 +184,7 @@ class CategoryFilter(BaseFilterBackend): return valid_fields else: return [] - + def get_current(self, request, queryset, view): params = request.query_params.get(self.category_field) if params: @@ -178,7 +192,7 @@ class CategoryFilter(BaseFilterBackend): return fields[0] else: return None - + def get_template_context(self, request, queryset, view): current = self.get_current(request, queryset, view) options = [] @@ -187,11 +201,12 @@ class CategoryFilter(BaseFilterBackend): 'current': current, 'param': self.category_field, } - for key, label in self.get_valid_fields(queryset, view, context, request): + valid_fields = self.get_valid_fields(queryset, view, context, request) + for key, label in valid_fields: options.append((key, '%s' % (label))) context['options'] = options return context - + def to_html(self, request, queryset, view): template = loader.get_template(self.template) context = self.get_template_context(request, queryset, view) @@ -200,112 +215,135 @@ class CategoryFilter(BaseFilterBackend): class FacetFilter(BaseFilterBackend): """ - This filter requires CategoryFilter to be ran before it. It assigns the view.facets which includes all the facets applicable to the current category. + This filter requires CategoryFilter to be ran before it. It assigns the + view.facets which includes all the facets applicable to the current + category. """ template = 'starfields_drf_generics/templates/filters/facets.html' - + def get_facet_class(self, view, request): return getattr(view, 'facet_class', None) - + def get_facet_tag_class(self, view, request): return getattr(view, 'facet_tag_class', None) - + def get_facet_tag_field(self, view, request): return getattr(view, 'facet_tag_field', None) - + def assign_view_facets(self, request, view): if not hasattr(view, 'facets'): if hasattr(view, 'facet_class'): self.facet_class = self.get_facet_class(view, request) - + assert self.facet_class is not None, ( - f"{view.__class__.__name__} should include a `facet_class` attribute" + f"{view.__class__.__name__} should include a `facet_class`" + "attribute" ) - + if view.category: if view.category.tn_ancestors_pks: - view.facets = self.facet_class.objects.filter(Q(category__id=view.category.id) | Q(category__id__in=view.category.tn_ancestors_pks.split(','))).prefetch_related('facet_tags') + ancestor_ids = view.category.tn_ancestors_pks.split(',') + view.facets = self.facet_class.objects.filter( + Q(category__id=view.category.id) | Q( + category__id__in=ancestor_ids) + ).prefetch_related('facet_tags') else: - view.facets = self.facet_class.objects.filter(category__id=view.category.id).prefetch_related('facet_tags') + view.facets = self.facet_class.objects.filter( + category__id=view.category.id).prefetch_related( + 'facet_tags') else: - view.facets = self.facet_class.objects.filter(category__tn_level=1).prefetch_related('facet_tags') - + view.facets = self.facet_class.objects.filter( + category__tn_level=1).prefetch_related( + 'facet_tags') + def get_filters_dict(self, request, view): """ - Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. + Custom method that returns the filters exclusive to this filter in a + dict. For caching purposes. """ if hasattr(view, 'facet_class'): self.facet_class = self.get_facet_class(view, request) - + assert self.facet_class is not None, ( - f"{view.__class__.__name__} should include a `facet_class` attribute" + f"{view.__class__.__name__} should include a `facet_class`" + "attribute" ) - + self.assign_view_facets(request, view) - + filters_dict = {} if view.facets: for facet in view.facets: if facet.slug in request.query_params.keys(): - filters_dict[facet.slug] = set(request.query_params[facet.slug].split(',')) + filters_dict[facet.slug] = set( + request.query_params[facet.slug].split(',')) else: filters_dict[facet.slug] = set({}) - - # Append the facets object and the tags dict in the view for later reference + + # Append the facets object and the tags dict in the view for later + # reference view.tags = filters_dict - + return filters_dict - - + def filter_queryset(self, request, queryset, view): if hasattr(view, 'facet_tag_class'): self.facet_tag_class = self.get_facet_tag_class(view, request) - + assert self.facet_tag_class is not None, ( - f"{view.__class__.__name__} should include a `facet_tag_class` attribute" + f"{view.__class__.__name__} should include a `facet_tag_class`" + "attribute" ) - + if hasattr(view, 'facet_tag_field'): self.facet_tag_field = self.get_facet_tag_field(view, request) - + assert self.facet_tag_field is not None, ( - f"{view.__class__.__name__} should include a `facet_tag_field` attribute" + f"{view.__class__.__name__} should include a `facet_tag_field`" + "attribute" ) - + self.assign_view_facets(request, view) - + if view.facets: for facet in view.facets: - if facet.slug in request.query_params.keys() and request.query_params[facet.slug]: + request_slug = request.query_params[facet.slug] + if facet.slug in request.query_params.keys() and request_slug: tag_filterlist = request.query_params.get(facet.slug) if tag_filterlist == '': - # If the tag filterlist is empty then we're not filtering against it, it's like having all the tags of the facet selected + # If the tag filterlist is empty then we're not + # filtering against it, it's like having all the tags + # of the facet selected pass else: kwquery = {} - kwquery[self.facet_tag_field+'__slug__in'] = tag_filterlist.replace(' ','').split(',') + key = self.facet_tag_field+'__slug__in' + kwquery[key] = tag_filterlist.replace(' ', '').split( + ',') queryset = queryset.filter(**kwquery) - + return queryset - - + # Developer Interface methods def get_template_context(self, request, queryset, view): - # Does aggressive database querying to get the necessary facets and facettags, but this is only for the developer interface so its fine + # Does aggressive database querying to get the necessary facets and + # facettags, but this is only for the developer interface so its fine if hasattr(view, 'facet_class'): self.facet_class = self.get_facet_class(view, request) - + assert self.facet_class is not None, ( - f"{view.__class__.__name__} should include a `facet_class` attribute" + f"{view.__class__.__name__} should include a `facet_class`" + "attribute" ) - + if hasattr(view, 'facet_tag_class'): self.facet_tag_class = self.get_facet_tag_class(view, request) - + assert self.facet_tag_class is not None, ( - f"{view.__class__.__name__} should include a `facet_tag_class` attribute" + f"{view.__class__.__name__} should include a `facet_tag_class`" + "attribute" ) - + # Find the current choices current = [] facet_slugs = [] @@ -314,7 +352,7 @@ class FacetFilter(BaseFilterBackend): facet_slugs.append(facet.slug) if facet.slug in request.query_params.keys(): current.append(request.query_params.get(facet.slug)) - + facet_slug_names = {} options = {} context = { @@ -324,17 +362,19 @@ class FacetFilter(BaseFilterBackend): } if view.facets: for facet in view.facets: - facet_tag_instances = self.facet_tag_class.objects.filter(facet__slug=facet.slug) - options[facet.slug] = [(facet_tag.slug, facet_tag.name) for facet_tag in facet_tag_instances] + facet_tag_instances = self.facet_tag_class.objects.filter( + facet__slug=facet.slug) + options[facet.slug] = [(facet_tag.slug, facet_tag.name) for + facet_tag in facet_tag_instances] facet_slug_names[facet.slug] = facet.name context['facet_slug_names'] = facet_slug_names context['options'] = options else: context['facet_slug_names'] = {} context['options'] = {} - + return context - + def to_html(self, request, queryset, view): template = loader.get_template(self.template) context = self.get_template_context(request, queryset, view) @@ -346,10 +386,13 @@ class TrigramSearchFilter(BaseFilterBackend): search_param = 'search' template = 'rest_framework/filters/search.html' search_title = _('Search') - search_description = _('A search string to perform trigram similarity based searching with.') + search_description = _('A search string to perform trigram similarity' + 'based searching with.') + def get_filters_dict(self, request, view): """ - Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. + Custom method that returns the filters exclusive to this filter in a + dict. For caching purposes. """ self.filters_dict = {} if 'search' in request.query_params.keys(): @@ -357,18 +400,19 @@ class TrigramSearchFilter(BaseFilterBackend): self.filters_dict['search'] = [slug_term] else: self.filters_dict['search'] = [] - + return self.filters_dict def filter_queryset(self, request, queryset, view): search_fields = getattr(view, 'search_fields', None) - + assert search_fields is not None, ( - f"{view.__class__.__name__} should include a `search_fields` attribute" + f"{view.__class__.__name__} should include a `search_fields`" + "attribute" ) - + query = request.query_params.get(self.search_param, '') - + if query: queryset = queryset.annotate( search_field=Concat( @@ -377,7 +421,7 @@ class TrigramSearchFilter(BaseFilterBackend): )).annotate( similarity=TrigramSimilarity('search_field', query) ).filter(similarity__gt=0.05).distinct() - + return queryset def to_html(self, request, queryset, view): @@ -436,14 +480,14 @@ class SlugSearchFilter(BaseFilterBackend): def get_slug_field(self, view, request): return getattr(view, 'slug_field', None) - + def get_filters_dict(self, request, view): """ Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. """ if hasattr(view, 'slug_field'): self.slug_field = self.get_slug_field(view, request) - + assert self.slug_field is not None, ( f"{view.__class__.__name__} should include a `slug_field` attribute" ) @@ -453,9 +497,9 @@ class SlugSearchFilter(BaseFilterBackend): self.filters_dict[self.slug_field] = [slug_term] else: self.filters_dict[self.slug_field] = [] - + return self.filters_dict - + def filter_queryset(self, request, queryset, view): # Ensure that the slug field was searched against try: @@ -463,18 +507,18 @@ class SlugSearchFilter(BaseFilterBackend): slug_term = request.query_params.get(self.slug_field) query = {} query[self.slug_field] = slug_term - + queryset = queryset.get(**query) except Exception as e: print(e) - + return queryset - + def to_html(self, request, queryset, view): if not getattr(view, 'slug_field', None): return '' - + slug_term = self.get_slug_term(request) context = { 'param': self.slug_field, @@ -540,7 +584,7 @@ class SearchFilter(BaseFilterBackend): # This field has a custom __ query transform but is not a relational field. break return False - + def get_filters_dict(self, request, view): """ Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. @@ -551,7 +595,7 @@ class SearchFilter(BaseFilterBackend): self.filters_dict['search'] = [slug_term] else: self.filters_dict['search'] = [] - + return self.filters_dict def filter_queryset(self, request, queryset, view): @@ -728,7 +772,7 @@ class OrderingFilter(BaseFilterBackend): return term in valid_fields return [term for term in fields if term_valid(term)] - + def get_filters_dict(self, request, view): """ Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. @@ -739,9 +783,9 @@ class OrderingFilter(BaseFilterBackend): self.filters_dict['ordering'] = [slug_term] else: self.filters_dict['ordering'] = [view.ordering_fields[0]] - + return self.filters_dict - + def filter_queryset(self, request, queryset, view): ordering = self.get_ordering(request, queryset, view) diff --git a/starfields_drf_generics/generics.py b/starfields_drf_generics/generics.py index 569edaf..eaac31f 100644 --- a/starfields_drf_generics/generics.py +++ b/starfields_drf_generics/generics.py @@ -1,15 +1,7 @@ """ Generic views that provide commonly needed behaviour. """ -from django.core.exceptions import ValidationError -from django.db.models.query import QuerySet -from django.http import Http404 -from django.shortcuts import get_object_or_404 as _get_object_or_404 - -from rest_framework import views from rest_framework.generics import GenericAPIView -from rest_framework.settings import api_settings - from starfields_drf_generics import mixins @@ -18,26 +10,32 @@ from starfields_drf_generics import mixins # Single item CRUD -class CachedCreateAPIView(mixins.CachedCreateModelMixin,GenericAPIView): +class CachedCreateAPIView(mixins.CachedCreateModelMixin, + GenericAPIView): """ Concrete view for creating a model instance. """ + def post(self, request, *args, **kwargs): return self.create(request, *args, **kwargs) -class CachedRetrieveAPIView(mixins.CachedRetrieveModelMixin,GenericAPIView): +class CachedRetrieveAPIView(mixins.CachedRetrieveModelMixin, + GenericAPIView): """ Concrete view for retrieving a model instance. """ + def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) -class CachedUpdateAPIView(mixins.CachedUpdateModelMixin,GenericAPIView): +class CachedUpdateAPIView(mixins.CachedUpdateModelMixin, + GenericAPIView): """ Concrete view for updating a model instance. """ + def put(self, request, *args, **kwargs): return self.update(request, *args, **kwargs) @@ -45,18 +43,23 @@ class CachedUpdateAPIView(mixins.CachedUpdateModelMixin,GenericAPIView): return self.partial_update(request, *args, **kwargs) -class CachedDestroyAPIView(mixins.CachedDestroyModelMixin,GenericAPIView): +class CachedDestroyAPIView(mixins.CachedDestroyModelMixin, + GenericAPIView): """ Concrete view for deleting a model instance. """ + def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) -class CachedRetrieveUpdateAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedUpdateModelMixin,GenericAPIView): +class CachedRetrieveUpdateAPIView(mixins.CachedRetrieveModelMixin, + mixins.CachedUpdateModelMixin, + GenericAPIView): """ Concrete view for retrieving, updating a model instance. """ + def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) @@ -67,10 +70,13 @@ class CachedRetrieveUpdateAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedU return self.partial_update(request, *args, **kwargs) -class CachedRetrieveDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedDestroyModelMixin,GenericAPIView): +class CachedRetrieveDestroyAPIView(mixins.CachedRetrieveModelMixin, + mixins.CachedDestroyModelMixin, + GenericAPIView): """ Concrete view for retrieving or deleting a model instance. """ + def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) @@ -78,10 +84,14 @@ class CachedRetrieveDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins.Cached return self.destroy(request, *args, **kwargs) -class CachedRetrieveUpdateDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedUpdateModelMixin,mixins.CachedDestroyModelMixin,GenericAPIView): +class CachedRetrieveUpdateDestroyAPIView(mixins.CachedRetrieveModelMixin, + mixins.CachedUpdateModelMixin, + mixins.CachedDestroyModelMixin, + GenericAPIView): """ Concrete view for retrieving, updating or deleting a model instance. """ + def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) @@ -95,10 +105,16 @@ class CachedRetrieveUpdateDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins. return self.destroy(request, *args, **kwargs) -class CachedCreateRetrieveUpdateDestroyAPIView(mixins.CachedCreateModelMixin,mixins.CachedRetrieveModelMixin,mixins.CachedUpdateModelMixin,mixins.CachedDestroyModelMixin,GenericAPIView): +class CachedCreateRetrieveUpdateDestroyAPIView(mixins.CachedCreateModelMixin, + mixins.CachedRetrieveModelMixin, + mixins.CachedUpdateModelMixin, + mixins.CachedDestroyModelMixin, + GenericAPIView): """ - Concrete view for creating, retrieving, updating or deleting a model instance. + Concrete view for creating, retrieving, updating or deleting a model + instance. """ + def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) @@ -117,26 +133,32 @@ class CachedCreateRetrieveUpdateDestroyAPIView(mixins.CachedCreateModelMixin,mix # List based CRUD -class CachedListRetrieveAPIView(mixins.CachedListRetrieveModelMixin,GenericAPIView): +class CachedListRetrieveAPIView(mixins.CachedListRetrieveModelMixin, + GenericAPIView): """ Concrete view for listing a queryset. """ + def get(self, request, *args, **kwargs): return self.list(request, *args, **kwargs) -class CachedListCreateAPIView(mixins.CachedListCreateModelMixin,GenericAPIView): +class CachedListCreateAPIView(mixins.CachedListCreateModelMixin, + GenericAPIView): """ Concrete view for creating multiple instances. """ + def post(self, request, *args, **kwargs): return self.list_create(request, *args, **kwargs) -class CachedListUpdateAPIView(mixins.CachedListUpdateModelMixin,GenericAPIView): +class CachedListUpdateAPIView(mixins.CachedListUpdateModelMixin, + GenericAPIView): """ Concrete view for updating multiple instances. """ + def put(self, request, *args, **kwargs): return self.list_update(request, *args, **kwargs) @@ -144,18 +166,23 @@ class CachedListUpdateAPIView(mixins.CachedListUpdateModelMixin,GenericAPIView): return self.list_partial_update(request, *args, **kwargs) -class CachedListDestroyAPIView(mixins.CachedListDestroyModelMixin,GenericAPIView): +class CachedListDestroyAPIView(mixins.CachedListDestroyModelMixin, + GenericAPIView): """ Concrete view for deleting multiple instances. """ + def delete(self, request, *args, **kwargs): return self.list_destroy(request, *args, **kwargs) -class CachedListRetrieveCreateAPIView(mixins.CachedListRetrieveModelMixin,mixins.CachedListCreateModelMixin,GenericAPIView): +class CachedListRetrieveCreateAPIView(mixins.CachedListRetrieveModelMixin, + mixins.CachedListCreateModelMixin, + GenericAPIView): """ Concrete view for listing a queryset or creating a model instance. """ + def get(self, request, *args, **kwargs): return self.list(request, *args, **kwargs) @@ -163,10 +190,15 @@ class CachedListRetrieveCreateAPIView(mixins.CachedListRetrieveModelMixin,mixins return self.create(request, *args, **kwargs) -class CachedListCreateRetrieveDestroyAPIView(mixins.CachedListCreateModelMixin,mixins.CachedListRetrieveModelMixin,mixins.CachedListDestroyModelMixin,GenericAPIView): +class CachedListCreateRetrieveDestroyAPIView( + mixins.CachedListCreateModelMixin, + mixins.CachedListRetrieveModelMixin, + mixins.CachedListDestroyModelMixin, + GenericAPIView): """ Concrete view for creating, retrieving or deleting a model instance. """ + def get(self, request, *args, **kwargs): return self.list(request, *args, **kwargs) @@ -177,10 +209,16 @@ class CachedListCreateRetrieveDestroyAPIView(mixins.CachedListCreateModelMixin,m return self.list_destroy(request, *args, **kwargs) -class CachedListCreateRetrieveUpdateAPIView(mixins.CachedListCreateModelMixin,mixins.CachedListRetrieveModelMixin,mixins.CachedListUpdateModelMixin,GenericAPIView): +class CachedListCreateRetrieveUpdateAPIView( + mixins.CachedListCreateModelMixin, + mixins.CachedListRetrieveModelMixin, + mixins.CachedListUpdateModelMixin, + GenericAPIView): """ - Concrete view for creating, retrieving, updating or deleting a model instance. + Concrete view for creating, retrieving, updating or deleting a model + instance. """ + def get(self, request, *args, **kwargs): return self.list(request, *args, **kwargs) @@ -194,10 +232,17 @@ class CachedListCreateRetrieveUpdateAPIView(mixins.CachedListCreateModelMixin,mi return self.list_partial_update(request, *args, **kwargs) -class CachedListCreateRetrieveUpdateDestroyAPIView(mixins.CachedListCreateModelMixin,mixins.CachedListRetrieveModelMixin,mixins.CachedListUpdateModelMixin,mixins.CachedListDestroyModelMixin,GenericAPIView): +class CachedListCreateRetrieveUpdateDestroyAPIView( + mixins.CachedListCreateModelMixin, + mixins.CachedListRetrieveModelMixin, + mixins.CachedListUpdateModelMixin, + mixins.CachedListDestroyModelMixin, + GenericAPIView): """ - Concrete view for creating, retrieving, updating or deleting a model instance. + Concrete view for creating, retrieving, updating or deleting a model + instance. """ + def get(self, request, *args, **kwargs): return self.list(request, *args, **kwargs) diff --git a/starfields_drf_generics/mixins.py b/starfields_drf_generics/mixins.py index 8b26d7f..feba16c 100644 --- a/starfields_drf_generics/mixins.py +++ b/starfields_drf_generics/mixins.py @@ -8,56 +8,70 @@ from rest_framework import status from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework import mixins -from starfields_drf_generics.cache_mixins import CacheGetMixin, CacheSetMixin, CacheDeleteMixin +from starfields_drf_generics.cache_mixins import ( + CacheGetMixin, CacheSetMixin, CacheDeleteMixin) # Mixin classes to be included in the generic classes class CachedCreateModelMixin(CacheDeleteMixin, mixins.CreateModelMixin): """ - A slightly modified version of rest_framework.mixins.CreateModelMixin that handles cache deletions. + A slightly modified version of rest_framework.mixins.CreateModelMixin + that handles cache deletions. """ + def create(self, request, *args, **kwargs): + " Creates the entry in the request " # Go on with the creation as normal serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) self.perform_create(serializer) headers = self.get_success_headers(serializer.data) - + # Delete the cache self.delete_cache(request) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + return Response(serializer.data, status=status.HTTP_201_CREATED, + headers=headers) class CachedRetrieveModelMixin(CacheGetMixin, CacheSetMixin): """ - A slightly modified version of rest_framework.mixins.RetrieveModelMixin that handles cache attempts. - mixins.RetrieveModelMixin only has the retrieve method so it doesn't stand to inherit anything from it. + A slightly modified version of rest_framework.mixins.RetrieveModelMixin + that handles cache attempts. + mixins.RetrieveModelMixin only has the retrieve method so it doesn't stand + to inherit anything from it. """ - def retrieve(self, request, *args, **kwargs): + + def retrieve(self, request): + " Retrieves the entry in the request " # Attempt to get the request from the cache cache_attempt = self.get_cache(request) - + if cache_attempt: return Response(cache_attempt) - else: - # The cache get attempt failed so we have to get the results from the database - instance = self.get_object() - - serializer = self.get_serializer(instance) - response = Response(serializer.data) - - self.set_cache(request, response) - return response + + # The cache get attempt failed so we have to get the results from + # the database + instance = self.get_object() + + serializer = self.get_serializer(instance) + response = Response(serializer.data) + + self.set_cache(request, response) + return response class CachedUpdateModelMixin(CacheDeleteMixin, mixins.UpdateModelMixin): """ - A slightly modified version of rest_framework.mixins.UpdateModelMixin that handles cache deletes. + A slightly modified version of rest_framework.mixins.UpdateModelMixin that + handles cache deletes. """ + def update(self, request, *args, **kwargs): + " Updates the entry in the request " partial = kwargs.pop('partial', False) instance = self.get_object() - serializer = self.get_serializer(instance, data=request.data, partial=partial) + serializer = self.get_serializer(instance, data=request.data, + partial=partial) serializer.is_valid(raise_exception=True) self.perform_update(serializer) @@ -65,7 +79,7 @@ class CachedUpdateModelMixin(CacheDeleteMixin, mixins.UpdateModelMixin): # If 'prefetch_related' has been applied to a queryset, we need to # forcibly invalidate the prefetch cache on the instance. instance._prefetched_objects_cache = {} - + # Delete the related caches self.delete_cache(request) @@ -74,15 +88,18 @@ class CachedUpdateModelMixin(CacheDeleteMixin, mixins.UpdateModelMixin): class CachedDestroyModelMixin(CacheDeleteMixin, mixins.DestroyModelMixin): """ - A slightly modified version of rest_framework.mixins.DestroyModelMixin that handles cache deletes. + A slightly modified version of rest_framework.mixins.DestroyModelMixin + that handles cache deletes. """ + def destroy(self, request, *args, **kwargs): + " Deletes the entry in the request " instance = self.get_object() self.perform_destroy(instance) - + # Delete the related caches self.delete_cache(request) - + return Response(status=status.HTTP_204_NO_CONTENT) @@ -91,21 +108,26 @@ class CachedListCreateModelMixin(CacheDeleteMixin): """ A fully custom mixin that handles mutiple instance cration. """ - def list_create(self, request, *args, **kwargs): + + def list_create(self, request): + " Creates the list of entries in the request " # Go on with the creation as normal serializer = self.get_serializer(data=request.data, many=True) serializer.is_valid(raise_exception=True) self.perform_create(serializer) headers = self.get_success_headers(serializer.data) - + # Delete the cache self.delete_cache(request) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + return Response(serializer.data, status=status.HTTP_201_CREATED, + headers=headers) def perform_create(self, serializer): + " Generic save hook " serializer.save() def get_success_headers(self, data): + " Returns extra success headers " try: return {'Location': str(data[api_settings.URL_FIELD_NAME])} except (TypeError, KeyError): @@ -114,31 +136,36 @@ class CachedListCreateModelMixin(CacheDeleteMixin): class CachedListRetrieveModelMixin(CacheGetMixin, CacheSetMixin): """ - A slightly modified version of rest_framework.mixins.ListModelMixin that handles cache saves. - mixins.ListModelMixin only has the list method so it doesn't stand to inherit anything from it. + A slightly modified version of rest_framework.mixins.ListModelMixin that + handles cache saves. + mixins.ListModelMixin only has the list method so it doesn't stand to + inherit anything from it. """ - def list(self, request, *args, **kwargs): + + def list(self, request): + " Retrieves the listing of entries " # Attempt to get the request from the cache cache_attempt = self.get_cache(request) - + if cache_attempt: return Response(cache_attempt) - else: - # The cache get attempt failed so we have to get the results from the database - queryset = self.filter_queryset(self.get_queryset()) - - if self.paged: - page = self.paginate_queryset(queryset) - if page is not None: - serializer = self.get_serializer(page, many=True) - response = self.get_paginated_response(serializer.data) - else: - serializer = self.get_serializer(queryset, many=True) - response = Response(serializer.data) + + # The cache get attempt failed so we have to get the results from + # the database + queryset = self.filter_queryset(self.get_queryset()) + + if self.paged: + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + response = self.get_paginated_response(serializer.data) else: serializer = self.get_serializer(queryset, many=True) response = Response(serializer.data) - + else: + serializer = self.get_serializer(queryset, many=True) + response = Response(serializer.data) + self.set_cache(request, response) return response @@ -147,12 +174,15 @@ class CachedListUpdateModelMixin(CacheDeleteMixin): """ A fully custom mixin that handles mutiple instance updates. """ - def list_update(self, request, *args, **kwargs): + + def list_update(self, request, **kwargs): + " Updates the list of entries in the request " partial = kwargs.pop('partial', False) - + queryset = self.filter_queryset(self.get_queryset()) - - serializer = self.get_serializer(queryset, data=request.data, partial=partial, many=True) + + serializer = self.get_serializer(queryset, data=request.data, + partial=partial, many=True) serializer.is_valid(raise_exception=True) self.perform_update(serializer) @@ -162,9 +192,11 @@ class CachedListUpdateModelMixin(CacheDeleteMixin): return Response(serializer.data) def perform_update(self, serializer): + " Generic save hook " serializer.save() def list_partial_update(self, request, *args, **kwargs): + " Needs to be called on partial updates " kwargs['partial'] = True return self.list_update(request, *args, **kwargs) @@ -173,39 +205,22 @@ class CachedListDestroyModelMixin(CacheDeleteMixin): """ A fully custom mixin that handles mutiple instance deletions. """ - def list_destroy(self, request, *args, **kwargs): + + def list_destroy(self, request): + " Deletes the list of entries in the request " # Go on with the validation as normal serializer = self.get_serializer(data=request.data, many=True) serializer.is_valid(raise_exception=True) validated_data = serializer.validated_data - + # TODO does this new stuff work even? need to check on the frontend serializer.delete(validated_data) - + # for instance in self.get_objects(): # if instance is not None: # self.perform_destroy(instance) # Delete the related caches self.delete_cache(request) - + return Response(status=status.HTTP_204_NO_CONTENT) - - #def perform_destroy(self, instance): - # instance.delete() - - #def get_objects(self): - # """ - # The custom list version of get_object that retrieves one instance from the #database. It yields model instances with each call. - # """ - # queryset = self.filter_queryset(self.get_queryset()) - # - # if len(queryset): - # for obj in queryset.all(): - - # # May raise a permission denied - # self.check_object_permissions(self.request, obj) - - # yield obj - - #yield None diff --git a/starfields_drf_generics/utils.py b/starfields_drf_generics/utils.py index 63d261f..a158e28 100644 --- a/starfields_drf_generics/utils.py +++ b/starfields_drf_generics/utils.py @@ -1,22 +1,30 @@ +""" + Utility functions for the library +""" + def sorted_params_string(filters_dict): """ - This function takes a dict and returns it in a sorted form for the url filter, it's primarily used for cache purposes. + This function takes a dict and returns it in a sorted form for the url + filter, it's primarily used for cache purposes. """ filters_string = '' for key in sorted(filters_dict.keys()): if filters_string == '': - filters_string = f"{key}={','.join(str(val) for val in sorted(filters_dict[key]))}" + key_str = ','.join(str(val) for val in sorted(filters_dict[key])) + filters_string = f"{key}={key_str}" else: - filters_string = f"{filters_string}&{key}={','.join(str(val) for val in sorted(filters_dict[key]))}" + key_str = ','.join(str(val) for val in sorted(filters_dict[key])) + filters_string = f"{filters_string}&{key}={key_str}" filters_string = filters_string.strip() return filters_string def parse_tags_to_dict(tags): + " This function parses a tag string into a dictionary " tagdict = {} if ':' not in tags: tagdict = {} - else: + else: for subtag in tags.split('&'): tagkey, taglist = subtag.split(':') taglist = taglist.split(',')