diff --git a/cache_mixins.py b/cache_mixins.py new file mode 100644 index 0000000..a9f1395 --- /dev/null +++ b/cache_mixins.py @@ -0,0 +1,122 @@ +from libraries.utils import sorted_params_string + +# TODO classes below that involve create, update, destroy don't delete the caches properly, they need a regex cache delete +# TODO put more reasonable asserts and feedback + +# Mixin classes that provide cache functionalities +class CacheUniqueUrl: + def get_cache_unique_url(self, request): + """ Create the query to be cached in a unique way to avoid duplicates. """ + if not hasattr(self, 'filters_string'): + # Only assign the attribute if it's not already assigned + filters = {} + if self.extra_filters_dict: + filters.update(self.extra_filters_dict) + # Check if the url parameters have any of the keys of the extra filters and if so assign them + for key in self.extra_filters_dict: + if key in self.request.query_params.keys(): + filters[key] = self.request.query_params[key].replace(' ','').split(',') + # Check if they're resolved in the urlconf as well + if key in self.kwargs.keys(): + filters[key] = [self.kwargs[key]] + + if hasattr(self, 'paged'): + if self.paged: + filters.update({'limit': [self.default_page_size], 'offset': [0]}) + if 'limit' in self.request.query_params.keys(): + filters.update({'limit': [self.request.query_params['limit']]}) + if 'offset' in self.request.query_params.keys(): + filters.update({'offset': [self.request.query_params['offset']]}) + for backend in list(self.filter_backends): + filters.update(backend().get_filters_dict(request, self)) + self.filters_string = sorted_params_string(filters) + + +class CacheGetMixin(CacheUniqueUrl): + cache_prefix = None + cache_vary_on_user = False + cache_timeout_mins = None + default_page_size = 20 + extra_filters_dict = None + + def get_cache(self, request): + assert self.cache_prefix is not None, ( + "'%s' should include a `cache_prefix` attribute" + % self.__class__.__name__ + ) + + self.get_cache_unique_url(request) + + # Attempt to get the response from the cache for the whole request + try: + if self.cache_vary_on_user: + cache_attempt = self.cache.get(f"{self.cache_prefix}.{request.user}.{self.filters_string}") + else: + cache_attempt = self.cache.get(f"{self.cache_prefix}.{self.filters_string}") + except: + self.logger.info(f"Cache get attempt for {self.__class__.__name__} failed.") + cache_attempt = None + + if cache_attempt: + return cache_attempt + else: + return None + + +class CacheSetMixin(CacheUniqueUrl): + cache_prefix = None + cache_vary_on_user = False + cache_timeout_mins = None + default_page_size = 20 + extra_filters_dict = None + + def set_cache(self, request, response): + self.get_cache_unique_url(request) + + # Create a function that programmatically defines the caching function + def make_caching_function(cls, request, cache): + def caching_function(response): + # Writes the response to the cache + try: + if cls.cache_vary_on_user: + cls.cache.set(key=f"{cls.cache_prefix}.{request.user}.{self.filters_string}", + value = response.data, + timeout=60*cls.cache_timeout_mins) + else: + cls.cache.set(key=f"{cls.cache_prefix}.{self.filters_string}", + value = response.data, + timeout=60*cls.cache_timeout_mins) + except: + cls.logger.exception(f"Cache set attempt for {cls.__class__.__name__} failed.") + return caching_function + + # Register the post rendering hook to the response + caching_function = make_caching_function(self, request, self.cache) + response.add_post_render_callback(caching_function) + + +class CacheDeleteMixin(CacheUniqueUrl): + cache_delete = True + cache_prefix = None + cache_vary_on_user = False + cache_timeout_mins = None + extra_filters_dict = None + + def delete_cache(self, request): + # Handle the caching + if self.cache_delete: + # Create the query to be cached in a unique way to avoid duplicates + self.get_cache_unique_url(request) + + assert self.cache_prefix is not None, ( + f"{self.__class__.__name__} should include a `cache_prefix` attribute" + ) + + # Delete the cache since a new entry has been created + try: + if self.cache_vary_on_user: + self.cache.delete(f"{self.cache_prefix}.{request.user}.{self.filters_string}") + else: + self.cache.delete(f"{self.cache_prefix}.{self.filters_string}") + except: + self.logger.exception(f"Cache delete attempt for {cls.__class__.__name__} failed.") diff --git a/filters.py b/filters.py new file mode 100644 index 0000000..3f6c308 --- /dev/null +++ b/filters.py @@ -0,0 +1,793 @@ +from django_filters import rest_framework as filters +from django.contrib.postgres.search import TrigramSimilarity +from django.db.models.functions import Concat +from django.db.models import CharField +from shop.models.product import Product, Facet +from rest_framework.filters import BaseFilterBackend +import operator +from django.template import loader +from django.utils.translation import gettext_lazy as _ +from django.db.models import Max, Min, Count, Q +from rest_framework.settings import api_settings +from django.db.models.constants import LOOKUP_SEP +from django.db import models +from functools import reduce + +# TODO the dev pages are not done + +# Filters +class LessThanOrEqualFilter(BaseFilterBackend): + def get_less_than_field(self, view, request): + return getattr(view, 'less_than_field', None) + + def get_filters_dict(self, request, view): + """ + Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. + """ + less_than_field = self.get_less_than_field(view, request) + + assert less_than_field is not None, ( + f"{view.__class__.__name__} should include a `less_than_field` attribute" + ) + + filters_dict = {} + if less_than_field+'_max' in request.query_params.keys(): + field_value = request.query_params.get(less_than_field+'_max') + filters_dict[less_than_field+'_max'] = [field_value] + else: + filters_dict[less_than_field+'_max'] = [] + return filters_dict + + def filter_queryset(self, request, queryset, view): + # Return the correctly filtered queryset but also assign the filter dict to create the unique url for the cache + less_than_field = self.get_less_than_field(view, request) + if less_than_field+'_max' in request.query_params.keys(): + kwquery = {} + field_value = request.query_params.get(less_than_field+'_max') + kwquery[less_than_field+'__lte'] = field_value + return queryset.filter(**kwquery) + else: + return queryset + + +class MoreThanOrEqualFilter(BaseFilterBackend): + def get_more_than_field(self, view, request): + return getattr(view, 'more_than_field', None) + + def get_filters_dict(self, request, view): + """ + Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. + """ + more_than_field = self.get_more_than_field(view, request) + + assert more_than_field is not None, ( + f"{view.__class__.__name__} should include a `more_than_field` attribute" + ) + + filters_dict = {} + if more_than_field+'_min' in request.query_params.keys(): + field_value = request.query_params.get(more_than_field+'_min') + filters_dict[more_than_field+'_min'] = [field_value] + else: + filters_dict[more_than_field+'_min'] = [] + return filters_dict + + def filter_queryset(self, request, queryset, view): + """ + Return the correctly filtered queryset + """ + more_than_field = self.get_more_than_field(view, request) + if more_than_field+'_min' in request.query_params.keys(): + kwquery = {} + kwquery[more_than_field+'__gte'] = request.query_params.get(more_than_field+'_min') + return queryset.filter(**kwquery) + else: + return queryset + + +class CategoryFilter(BaseFilterBackend): + """ + This filter assigns the view.category object for later use, in particular for filters that depend on this one. + """ + template = './filters/categories.html' + category_field = 'category' + + def get_category_class(self, view, request): + return getattr(view, 'category_class', None) + + def assign_view_category(self, request, view): + if not hasattr(view, self.category_field): + if self.category_field in request.query_params.keys(): + try: + category_slug = request.query_params.get(self.category_field).strip() + category = view.category_class.objects.get(slug=category_slug) + # Append the category object in the view for later reference + view.category = category + except: + view.category = None + else: + view.category = None + + + def get_filters_dict(self, request, view): + """ + Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. Queries the database for the current category and saves it in view.category for internal use and later filters. + """ + if hasattr(view, 'category_field'): + self.category_field = self.get_category_field(view, request) + + category_class = self.get_category_class(view, request) + + assert category_class is not None, ( + f"{view.__class__.__name__} should include a `category_class` attribute" + ) + + # Create the filters dictionary and find the present category instance + self.assign_view_category(request, view) + + filters_dict = {} + if view.category: + filters_dict[self.category_field] = [view.category.slug] + else: + filters_dict[self.category_field] = [] + + return filters_dict + + def filter_queryset(self, request, queryset, view): + self.assign_view_category(request, view) + + # Create the queryset + if view.category: + kwquery_1 = {} + kwquery_1[self.category_field+'__id'] = view.category.id + if view.category.tn_descendants_pks: + kwquery_2 = {} + kwquery_2[self.category_field+'__id__in'] = view.category.tn_descendants_pks.split(',') + + queryset = queryset.filter(Q(**kwquery_1) | Q(**kwquery_2)) + else: + queryset = queryset.filter(**kwquery_1) + + return queryset + + # Developer Interface methods + def get_valid_fields(self, queryset, view, context, request): + # A query is executed here to get the possible categories + category_class = self.get_category_class(view, request) + if hasattr(view, 'category_field'): + self.category_field = self.get_category_field(view, request) + + assert category_class is not None, ( + f"{view.__class__.__name__} should include a `category_class` attribute" + ) + + valid_fields = category_class.objects.all() + + if len(valid_fields): + valid_fields = [ + (item.slug, item.__str__()) for item in valid_fields + ] + + return valid_fields + else: + return [] + + def get_current(self, request, queryset, view): + params = request.query_params.get(self.category_field) + if params: + fields = [param.strip() for param in params.split(',')] + return fields[0] + else: + return None + + def get_template_context(self, request, queryset, view): + current = self.get_current(request, queryset, view) + options = [] + context = { + 'request': request, + 'current': current, + 'param': self.category_field, + } + for key, label in self.get_valid_fields(queryset, view, context, request): + options.append((key, '%s' % (label))) + context['options'] = options + return context + + def to_html(self, request, queryset, view): + template = loader.get_template(self.template) + context = self.get_template_context(request, queryset, view) + return template.render(context) + + +class FacetFilter(BaseFilterBackend): + """ + This filter requires CategoryFilter to be ran before it. It assigns the view.facets which includes all the facets applicable to the current category. + """ + template = './filters/facets.html' + + def get_facet_class(self, view, request): + return getattr(view, 'facet_class', None) + + def get_facet_tag_class(self, view, request): + return getattr(view, 'facet_tag_class', None) + + def get_facet_tag_field(self, view, request): + return getattr(view, 'facet_tag_field', None) + + def assign_view_facets(self, request, view): + if not hasattr(view, 'facets'): + if view.category: + kwquery_1 = {} + kwquery_1['category__id'] = view.category.id + if view.category.tn_descendants_pks: + kwquery_2 = {} + kwquery_2['category__id__in'] = view.category.tn_descendants_pks.split(',') + view.facets = Facet.objects.filter(Q(**kwquery_1) | Q(**kwquery_2)) + else: + view.facets = Facet.objects.filter(**kwquery_1) + else: + view.facets = {} + + def get_filters_dict(self, request, view): + """ + Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. + """ + if hasattr(view, 'facet_class'): + self.facet_class = self.get_facet_class(view, request) + + assert self.facet_class is not None, ( + f"{view.__class__.__name__} should include a `facet_class` attribute" + ) + + self.assign_view_facets(request, view) + + filters_dict = {} + if view.facets: + for facet in view.facets: + if facet.slug in request.query_params.keys(): + filters_dict[facet.slug] = set(request.query_params[facet.slug].split(',')) + else: + filters_dict[facet.slug] = [] + + # Append the facets object and the tags dict in the view for later reference + view.tags = filters_dict + + return filters_dict + + + def filter_queryset(self, request, queryset, view): + if hasattr(view, 'facet_tag_class'): + self.facet_tag_class = self.get_facet_tag_class(view, request) + + assert self.facet_tag_class is not None, ( + f"{view.__class__.__name__} should include a `facet_tag_class` attribute" + ) + + if hasattr(view, 'facet_tag_field'): + self.facet_tag_field = self.get_facet_tag_field(view, request) + + assert self.facet_tag_field is not None, ( + f"{view.__class__.__name__} should include a `facet_tag_field` attribute" + ) + + self.assign_view_facets(request, view) + + if view.facets: + for facet in view.facets: + if facet.slug in request.query_params.keys() and request.query_params[facet.slug]: + tag_filterlist = request.query_params.get(facet.slug) + kwquery = {} + kwquery[self.facet_tag_field+'__slug__in'] = tag_filterlist.replace(' ','').split(',') + queryset = queryset.filter(**kwquery) + + return queryset + + + # Developer Interface methods + def get_template_context(self, request, queryset, view): + # Does aggressive database querying to get the necessary facets and facettags, but this is only for the developer interface so its fine + if hasattr(view, 'facet_class'): + self.facet_class = self.get_facet_class(view, request) + + assert self.facet_class is not None, ( + f"{view.__class__.__name__} should include a `facet_class` attribute" + ) + + if hasattr(view, 'facet_tag_class'): + self.facet_tag_class = self.get_facet_tag_class(view, request) + + assert self.facet_tag_class is not None, ( + f"{view.__class__.__name__} should include a `facet_tag_class` attribute" + ) + + # Find the current choices + current = [] + facet_slugs = [] + if view.facets: + for facet in view.facets: + facet_slugs.append(facet.slug) + if facet.slug in request.query_params.keys(): + current.append(request.query_params.get(facet.slug)) + + facet_slug_names = {} + options = {} + context = { + 'request': request, + 'current': current, + 'facet_slugs': facet_slugs, + } + if view.facets: + for facet in view.facets: + facet_tag_instances = self.facet_tag_class.objects.filter(facet__slug=facet.slug) + options[facet.slug] = [(facet_tag.slug, facet_tag.name) for facet_tag in facet_tag_instances] + facet_slug_names[facet.slug] = facet.name + context['facet_slug_names'] = facet_slug_names + context['options'] = options + else: + context['facet_slug_names'] = {} + context['options'] = {} + + return context + + def to_html(self, request, queryset, view): + template = loader.get_template(self.template) + context = self.get_template_context(request, queryset, view) + return template.render(context) + + +class TrigramSearchFilter(BaseFilterBackend): + # The URL query parameter used for the search. + search_param = 'search' + template = 'rest_framework/filters/search.html' + search_title = _('Search') + search_description = _('A search string to perform trigram similarity based searching with.') + def get_filters_dict(self, request, view): + """ + Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. + """ + self.filters_dict = {} + if 'search' in request.query_params.keys(): + slug_term = request.query_params.get('search') + self.filters_dict['search'] = [slug_term] + else: + self.filters_dict['search'] = [] + + return self.filters_dict + + def filter_queryset(self, request, queryset, view): + search_fields = getattr(view, 'search_fields', None) + + assert search_fields is not None, ( + f"{view.__class__.__name__} should include a `search_fields` attribute" + ) + + query = request.query_params.get(self.search_param, '') + + if query: + queryset = queryset.annotate( + search_field=Concat( + *search_fields, + output_field=CharField() + )).annotate( + similarity=TrigramSimilarity('search_field', query) + ).filter(similarity__gt=0.05).distinct() + + return queryset + + def to_html(self, request, queryset, view): + if not getattr(view, 'search_fields', None): + return '' + + term = request.query_params.get(self.search_param, '') + term = term[0] if term else '' + context = { + 'param': self.search_param, + 'term': term + } + template = loader.get_template(self.template) + return template.render(context) + + def get_schema_fields(self, view): + assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' + assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' + return [ + coreapi.Field( + name=self.search_param, + required=False, + location='query', + schema=coreschema.String( + title=force_str(self.search_title), + description=force_str(self.search_description) + ) + ) + ] + + def get_schema_operation_parameters(self, view): + return [ + { + 'name': self.search_param, + 'required': False, + 'in': 'query', + 'description': force_str(self.search_description), + 'schema': { + 'type': 'string', + }, + }, + ] + + +# TODO +#class FieldFilter(BaseFilterBackend): + + +# TODO misunderstood the urlconf stuff of the RUD methods, this is probably unnecessary +class SlugSearchFilter(BaseFilterBackend): + # The URL query parameter used for the search. + template = './filters/slug.html' + slug_title = _('Slug Search') + slug_description = _("The instance's slug.") + slug_field = 'slug' + + def get_slug_field(self, view, request): + return getattr(view, 'slug_field', None) + + def get_filters_dict(self, request, view): + """ + Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. + """ + if hasattr(view, 'slug_field'): + self.slug_field = self.get_slug_field(view, request) + + assert self.slug_field is not None, ( + f"{view.__class__.__name__} should include a `slug_field` attribute" + ) + self.filters_dict = {} + if self.slug_field in request.query_params.keys(): + slug_term = request.query_params.get(self.slug_field) + self.filters_dict[self.slug_field] = [slug_term] + else: + self.filters_dict[self.slug_field] = [] + + return self.filters_dict + + def filter_queryset(self, request, queryset, view): + # Ensure that the slug field was searched against + try: + if self.slug_field in request.query_params.keys(): + slug_term = request.query_params.get(self.slug_field) + query = {} + query[self.slug_field] = slug_term + + queryset = queryset.get(**query) + except Exception as e: + print(e) + + return queryset + + + def to_html(self, request, queryset, view): + if not getattr(view, 'slug_field', None): + return '' + + slug_term = self.get_slug_term(request) + context = { + 'param': self.slug_field, + 'term': slug_term + } + template = loader.get_template(self.template) + return template.render(context) + + +class SearchFilter(BaseFilterBackend): + # The URL query parameter used for the search. + search_param = api_settings.SEARCH_PARAM + template = 'rest_framework/filters/search.html' + lookup_prefixes = { + '^': 'istartswith', + '=': 'iexact', + '@': 'search', + '$': 'iregex', + } + search_title = _('Search') + search_description = _('A search term.') + # TODO to be removed + def get_search_terms(self, request): + """ + Search terms are set by a ?search=... query parameter, + and may be comma and/or whitespace delimited. + """ + params = request.query_params.get(self.search_param, '') + params = params.replace('\x00', '') # strip null characters + params = params.replace(',', ' ') + return params.split() + # TODO to be removed + def construct_search(self, field_name): + lookup = self.lookup_prefixes.get(field_name[0]) + if lookup: + field_name = field_name[1:] + else: + lookup = 'icontains' + return LOOKUP_SEP.join([field_name, lookup]) + # TODO to be removed + def must_call_distinct(self, queryset, search_fields): + """ + Return True if 'distinct()' should be used to query the given lookups. + """ + for search_field in search_fields: + opts = queryset.model._meta + if search_field[0] in self.lookup_prefixes: + search_field = search_field[1:] + # Annotated fields do not need to be distinct + if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations: + continue + parts = search_field.split(LOOKUP_SEP) + for part in parts: + field = opts.get_field(part) + if hasattr(field, 'get_path_info'): + # This field is a relation, update opts to follow the relation + path_info = field.get_path_info() + opts = path_info[-1].to_opts + if any(path.m2m for path in path_info): + # This field is a m2m relation so we know we need to call distinct + return True + else: + # This field has a custom __ query transform but is not a relational field. + break + return False + + def get_filters_dict(self, request, view): + """ + Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. + """ + self.filters_dict = {} + if 'search' in request.query_params.keys(): + slug_term = request.query_params.get('search') + self.filters_dict['search'] = [slug_term] + else: + self.filters_dict['search'] = [] + + return self.filters_dict + + def filter_queryset(self, request, queryset, view): + search_fields = getattr(view, 'search_fields', None) + search_terms = self.get_search_terms(request) + + if not search_fields or not search_terms: + return queryset + + orm_lookups = [ + self.construct_search(str(search_field)) + for search_field in search_fields + ] + + base = queryset + conditions = [] + for search_term in search_terms: + queries = [ + models.Q(**{orm_lookup: search_term}) + for orm_lookup in orm_lookups + ] + conditions.append(reduce(operator.or_, queries)) + queryset = queryset.filter(reduce(operator.and_, conditions)) + + if self.must_call_distinct(queryset, search_fields): + # Filtering against a many-to-many field requires us to + # call queryset.distinct() in order to avoid duplicate items + # in the resulting queryset. + # We try to avoid this if possible, for performance reasons. + queryset = distinct(queryset, base) + return queryset + + def to_html(self, request, queryset, view): + if not getattr(view, 'search_fields', None): + return '' + + term = self.get_search_terms(request) + term = term[0] if term else '' + context = { + 'param': self.search_param, + 'term': term + } + template = loader.get_template(self.template) + return template.render(context) + + def get_schema_fields(self, view): + assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' + assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' + return [ + coreapi.Field( + name=self.search_param, + required=False, + location='query', + schema=coreschema.String( + title=force_str(self.search_title), + description=force_str(self.search_description) + ) + ) + ] + + def get_schema_operation_parameters(self, view): + return [ + { + 'name': self.search_param, + 'required': False, + 'in': 'query', + 'description': force_str(self.search_description), + 'schema': { + 'type': 'string', + }, + }, + ] + + +class OrderingFilter(BaseFilterBackend): + # The URL query parameter used for the ordering. + ordering_param = api_settings.ORDERING_PARAM + ordering_fields = None + ordering_title = _('Ordering') + ordering_description = _('Which field to use when ordering the results.') + template = 'rest_framework/filters/ordering.html' + + def get_ordering(self, request, queryset, view): + """ + Ordering is set by a comma delimited ?ordering=... query parameter. + + The `ordering` query parameter can be overridden by setting + the `ordering_param` value on the OrderingFilter or by + specifying an `ORDERING_PARAM` value in the API settings. + """ + params = request.query_params.get(self.ordering_param) + if params: + fields = [param.strip() for param in params.split(',')] + ordering = self.remove_invalid_fields(queryset, fields, view, request) + if ordering: + return ordering + + # No ordering was included, or all the ordering fields were invalid + return self.get_default_ordering(view) + + def get_default_ordering(self, view): + ordering = getattr(view, 'ordering', None) + if isinstance(ordering, str): + return (ordering,) + return ordering + + def get_default_valid_fields(self, queryset, view, context={}): + # If `ordering_fields` is not specified, then we determine a default + # based on the serializer class, if one exists on the view. + if hasattr(view, 'get_serializer_class'): + try: + serializer_class = view.get_serializer_class() + except AssertionError: + # Raised by the default implementation if + # no serializer_class was found + serializer_class = None + else: + serializer_class = getattr(view, 'serializer_class', None) + + if serializer_class is None: + msg = ( + "Cannot use %s on a view which does not have either a " + "'serializer_class', an overriding 'get_serializer_class' " + "or 'ordering_fields' attribute." + ) + raise ImproperlyConfigured(msg % self.__class__.__name__) + + model_class = queryset.model + model_property_names = [ + # 'pk' is a property added in Django's Model class, however it is valid for ordering. + attr for attr in dir(model_class) if isinstance(getattr(model_class, attr), property) and attr != 'pk' + ] + + return [ + (field.source.replace('.', '__') or field_name, field.label) + for field_name, field in serializer_class(context=context).fields.items() + if ( + not getattr(field, 'write_only', False) and + not field.source == '*' and + field.source not in model_property_names + ) + ] + + def get_valid_fields(self, queryset, view, context={}): + valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) + + if valid_fields is None: + # Default to allowing filtering on serializer fields + return self.get_default_valid_fields(queryset, view, context) + + elif valid_fields == '__all__': + # View explicitly allows filtering on any model field + valid_fields = [ + (field.name, field.verbose_name) for field in queryset.model._meta.fields + ] + valid_fields += [ + (key, key.title().split('__')) + for key in queryset.query.annotations + ] + else: + valid_fields = [ + (item, item) if isinstance(item, str) else item + for item in valid_fields + ] + + return valid_fields + + def remove_invalid_fields(self, queryset, fields, view, request): + valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})] + + def term_valid(term): + if term.startswith("-"): + term = term[1:] + return term in valid_fields + + return [term for term in fields if term_valid(term)] + + def get_filters_dict(self, request, view): + """ + Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. + """ + self.filters_dict = {} + if 'ordering' in request.query_params.keys(): + slug_term = request.query_params.get('ordering') + self.filters_dict['ordering'] = [slug_term] + else: + self.filters_dict['ordering'] = [view.ordering_fields[0]] + + return self.filters_dict + + def filter_queryset(self, request, queryset, view): + ordering = self.get_ordering(request, queryset, view) + + if ordering: + return queryset.order_by(*ordering) + + return queryset + + def get_template_context(self, request, queryset, view): + current = self.get_ordering(request, queryset, view) + current = None if not current else current[0] + options = [] + context = { + 'request': request, + 'current': current, + 'param': self.ordering_param, + } + for key, label in self.get_valid_fields(queryset, view, context): + options.append((key, '%s - %s' % (label, _('ascending')))) + options.append(('-' + key, '%s - %s' % (label, _('descending')))) + context['options'] = options + return context + + def to_html(self, request, queryset, view): + template = loader.get_template(self.template) + context = self.get_template_context(request, queryset, view) + return template.render(context) + + def get_schema_fields(self, view): + assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' + assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' + return [ + coreapi.Field( + name=self.ordering_param, + required=False, + location='query', + schema=coreschema.String( + title=force_str(self.ordering_title), + description=force_str(self.ordering_description) + ) + ) + ] + + def get_schema_operation_parameters(self, view): + return [ + { + 'name': self.ordering_param, + 'required': False, + 'in': 'query', + 'description': force_str(self.ordering_description), + 'schema': { + 'type': 'string', + }, + }, + ] diff --git a/generics.py b/generics.py new file mode 100644 index 0000000..7cf8df7 --- /dev/null +++ b/generics.py @@ -0,0 +1,214 @@ +""" +Generic views that provide commonly needed behaviour. +""" +from django.core.exceptions import ValidationError +from django.db.models.query import QuerySet +from django.http import Http404 +from django.shortcuts import get_object_or_404 as _get_object_or_404 + +from rest_framework import views +from rest_framework.generics import GenericAPIView +from rest_framework.settings import api_settings + +from libraries import mixins + + +# Concrete view classes that provide method handlers +# by composing the mixin classes with the base view. + +# Single item CRUD + +class CachedCreateAPIView(mixins.CachedCreateModelMixin,GenericAPIView): + """ + Concrete view for creating a model instance. + """ + def post(self, request, *args, **kwargs): + return self.create(request, *args, **kwargs) + + +class CachedRetrieveAPIView(mixins.CachedRetrieveModelMixin,GenericAPIView): + """ + Concrete view for retrieving a model instance. + """ + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + +class CachedUpdateAPIView(mixins.CachedUpdateModelMixin,GenericAPIView): + """ + Concrete view for updating a model instance. + """ + def put(self, request, *args, **kwargs): + return self.update(request, *args, **kwargs) + + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + +class CachedDestroyAPIView(mixins.CachedDestroyModelMixin,GenericAPIView): + """ + Concrete view for deleting a model instance. + """ + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) + + +class CachedRetrieveUpdateAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedUpdateModelMixin,GenericAPIView): + """ + Concrete view for retrieving, updating a model instance. + """ + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + def put(self, request, *args, **kwargs): + return self.update(request, *args, **kwargs) + + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + +class CachedRetrieveDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedDestroyModelMixin,GenericAPIView): + """ + Concrete view for retrieving or deleting a model instance. + """ + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) + + +class CachedRetrieveUpdateDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedUpdateModelMixin,mixins.CachedDestroyModelMixin,GenericAPIView): + """ + Concrete view for retrieving, updating or deleting a model instance. + """ + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + def put(self, request, *args, **kwargs): + return self.update(request, *args, **kwargs) + + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) + + +class CachedCreateRetrieveUpdateDestroyAPIView(mixins.CachedCreateModelMixin,mixins.CachedRetrieveModelMixin,mixins.CachedUpdateModelMixin,mixins.CachedDestroyModelMixin,GenericAPIView): + """ + Concrete view for creating, retrieving, updating or deleting a model instance. + """ + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + def post(self, request, *args, **kwargs): + return self.create(request, *args, **kwargs) + + def put(self, request, *args, **kwargs): + return self.update(request, *args, **kwargs) + + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) + + +# List based CRUD + +class CachedListRetrieveAPIView(mixins.CachedListRetrieveModelMixin,GenericAPIView): + """ + Concrete view for listing a queryset. + """ + def get(self, request, *args, **kwargs): + return self.list(request, *args, **kwargs) + + +class CachedListCreateAPIView(mixins.CachedListCreateModelMixin,GenericAPIView): + """ + Concrete view for creating multiple instances. + """ + def post(self, request, *args, **kwargs): + return self.list_create(request, *args, **kwargs) + + +class CachedListUpdateAPIView(mixins.CachedListUpdateModelMixin,GenericAPIView): + """ + Concrete view for updating multiple instances. + """ + def put(self, request, *args, **kwargs): + return self.list_update(request, *args, **kwargs) + + def patch(self, request, *args, **kwargs): + return self.list_partial_update(request, *args, **kwargs) + + +class CachedListDestroyAPIView(mixins.CachedListDestroyModelMixin,GenericAPIView): + """ + Concrete view for deleting multiple instances. + """ + def delete(self, request, *args, **kwargs): + return self.list_destroy(request, *args, **kwargs) + + +class CachedListRetrieveCreateAPIView(mixins.CachedListRetrieveModelMixin,mixins.CachedListCreateModelMixin,GenericAPIView): + """ + Concrete view for listing a queryset or creating a model instance. + """ + def get(self, request, *args, **kwargs): + return self.list(request, *args, **kwargs) + + def post(self, request, *args, **kwargs): + return self.create(request, *args, **kwargs) + + +class CachedListCreateRetrieveDestroyAPIView(mixins.CachedListCreateModelMixin,mixins.CachedListRetrieveModelMixin,mixins.CachedListDestroyModelMixin,GenericAPIView): + """ + Concrete view for creating, retrieving or deleting a model instance. + """ + def get(self, request, *args, **kwargs): + return self.list(request, *args, **kwargs) + + def post(self, request, *args, **kwargs): + return self.list_create(request, *args, **kwargs) + + def delete(self, request, *args, **kwargs): + return self.list_destroy(request, *args, **kwargs) + + +class CachedListCreateRetrieveUpdateAPIView(mixins.CachedListCreateModelMixin,mixins.CachedListRetrieveModelMixin,mixins.CachedListUpdateModelMixin,GenericAPIView): + """ + Concrete view for creating, retrieving, updating or deleting a model instance. + """ + def get(self, request, *args, **kwargs): + return self.list(request, *args, **kwargs) + + def post(self, request, *args, **kwargs): + return self.list_create(request, *args, **kwargs) + + def put(self, request, *args, **kwargs): + return self.list_update(request, *args, **kwargs) + + def patch(self, request, *args, **kwargs): + return self.list_partial_update(request, *args, **kwargs) + + +class CachedListCreateRetrieveUpdateDestroyAPIView(mixins.CachedListCreateModelMixin,mixins.CachedListRetrieveModelMixin,mixins.CachedListUpdateModelMixin,mixins.CachedListDestroyModelMixin,GenericAPIView): + """ + Concrete view for creating, retrieving, updating or deleting a model instance. + """ + def get(self, request, *args, **kwargs): + return self.list(request, *args, **kwargs) + + def post(self, request, *args, **kwargs): + return self.list_create(request, *args, **kwargs) + + def put(self, request, *args, **kwargs): + return self.list_update(request, *args, **kwargs) + + def patch(self, request, *args, **kwargs): + return self.list_partial_update(request, *args, **kwargs) + + def delete(self, request, *args, **kwargs): + return self.list_destroy(request, *args, **kwargs) diff --git a/mixins.py b/mixins.py new file mode 100644 index 0000000..5526536 --- /dev/null +++ b/mixins.py @@ -0,0 +1,211 @@ +""" +Basic building blocks for generic class based views. + +We don't bind behaviour to http method handlers yet, +which allows mixin classes to be composed in interesting ways. +""" +from rest_framework import status +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework import mixins +from libraries.cache_mixins import CacheGetMixin, CacheSetMixin, CacheDeleteMixin + + +# Mixin classes to be included in the generic classes +class CachedCreateModelMixin(CacheDeleteMixin, mixins.CreateModelMixin): + """ + A slightly modified version of rest_framework.mixins.CreateModelMixin that handles cache deletions. + """ + def create(self, request, *args, **kwargs): + # Go on with the creation as normal + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + self.perform_create(serializer) + headers = self.get_success_headers(serializer.data) + + # Delete the cache + self.delete_cache(request) + return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + + +class CachedRetrieveModelMixin(CacheGetMixin, CacheSetMixin): + """ + A slightly modified version of rest_framework.mixins.RetrieveModelMixin that handles cache attempts. + mixins.RetrieveModelMixin only has the retrieve method so it doesn't stand to inherit anything from it. + """ + def retrieve(self, request, *args, **kwargs): + # Attempt to get the request from the cache + cache_attempt = self.get_cache(request) + + if cache_attempt: + return Response(cache_attempt) + else: + # The cache get attempt failed so we have to get the results from the database + instance = self.get_object() + + serializer = self.get_serializer(instance) + response = Response(serializer.data) + + self.set_cache(request, response) + return response + + +class CachedUpdateModelMixin(CacheDeleteMixin, mixins.UpdateModelMixin): + """ + A slightly modified version of rest_framework.mixins.UpdateModelMixin that handles cache deletes. + """ + def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=partial) + serializer.is_valid(raise_exception=True) + self.perform_update(serializer) + + if getattr(instance, '_prefetched_objects_cache', None): + # If 'prefetch_related' has been applied to a queryset, we need to + # forcibly invalidate the prefetch cache on the instance. + instance._prefetched_objects_cache = {} + + # Delete the related caches + self.delete_cache(request) + + return Response(serializer.data) + + +class CachedDestroyModelMixin(CacheDeleteMixin, mixins.DestroyModelMixin): + """ + A slightly modified version of rest_framework.mixins.DestroyModelMixin that handles cache deletes. + """ + def destroy(self, request, *args, **kwargs): + instance = self.get_object() + self.perform_destroy(instance) + + # Delete the related caches + self.delete_cache(request) + + return Response(status=status.HTTP_204_NO_CONTENT) + + +# List mixin classes to be included with list generic classes +class CachedListCreateModelMixin(CacheDeleteMixin): + """ + A fully custom mixin that handles mutiple instance cration. + """ + def list_create(self, request, *args, **kwargs): + # Go on with the creation as normal + serializer = self.get_serializer(data=request.data, many=True) + serializer.is_valid(raise_exception=True) + self.perform_create(serializer) + headers = self.get_success_headers(serializer.data) + + # Delete the cache + self.delete_cache(request) + return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + + def perform_create(self, serializer): + serializer.save() + + def get_success_headers(self, data): + try: + return {'Location': str(data[api_settings.URL_FIELD_NAME])} + except (TypeError, KeyError): + return {} + + +class CachedListRetrieveModelMixin(CacheGetMixin, CacheSetMixin): + """ + A slightly modified version of rest_framework.mixins.ListModelMixin that handles cache saves. + mixins.ListModelMixin only has the list method so it doesn't stand to inherit anything from it. + """ + def list(self, request, *args, **kwargs): + # Attempt to get the request from the cache + cache_attempt = self.get_cache(request) + + if cache_attempt: + return Response(cache_attempt) + else: + # The cache get attempt failed so we have to get the results from the database + queryset = self.filter_queryset(self.get_queryset()) + + if self.paged: + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + response = self.get_paginated_response(serializer.data) + else: + serializer = self.get_serializer(queryset, many=True) + response = Response(serializer.data) + else: + serializer = self.get_serializer(queryset, many=True) + response = Response(serializer.data) + + self.set_cache(request, response) + return response + + +class CachedListUpdateModelMixin(CacheDeleteMixin): + """ + A fully custom mixin that handles mutiple instance updates. + """ + def list_update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + + queryset = self.filter_queryset(self.get_queryset()) + + serializer = self.get_serializer(queryset, data=request.data, partial=partial, many=True) + serializer.is_valid(raise_exception=True) + self.perform_update(serializer) + + # Delete the related caches + self.delete_cache(request) + + return Response(serializer.data) + + def perform_update(self, serializer): + serializer.save() + + def list_partial_update(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.list_update(request, *args, **kwargs) + + +class CachedListDestroyModelMixin(CacheDeleteMixin): + """ + A fully custom mixin that handles mutiple instance deletions. + """ + def list_destroy(self, request, *args, **kwargs): + # Go on with the validation as normal + serializer = self.get_serializer(data=request.data, many=True) + serializer.is_valid(raise_exception=True) + validated_data = serializer.validated_data + + # TODO does this new stuff work even? need to check on the frontend + serializer.delete(validated_data) + +# for instance in self.get_objects(): +# if instance is not None: +# self.perform_destroy(instance) + + # Delete the related caches + self.delete_cache(request) + + return Response(status=status.HTTP_204_NO_CONTENT) + + #def perform_destroy(self, instance): + # instance.delete() + + #def get_objects(self): + # """ + # The custom list version of get_object that retrieves one instance from the #database. It yields model instances with each call. + # """ + # queryset = self.filter_queryset(self.get_queryset()) + # + # if len(queryset): + # for obj in queryset.all(): + + # # May raise a permission denied + # self.check_object_permissions(self.request, obj) + + # yield obj + + #yield None diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..63d261f --- /dev/null +++ b/utils.py @@ -0,0 +1,24 @@ +def sorted_params_string(filters_dict): + """ + This function takes a dict and returns it in a sorted form for the url filter, it's primarily used for cache purposes. + """ + filters_string = '' + for key in sorted(filters_dict.keys()): + if filters_string == '': + filters_string = f"{key}={','.join(str(val) for val in sorted(filters_dict[key]))}" + else: + filters_string = f"{filters_string}&{key}={','.join(str(val) for val in sorted(filters_dict[key]))}" + filters_string = filters_string.strip() + return filters_string + + +def parse_tags_to_dict(tags): + tagdict = {} + if ':' not in tags: + tagdict = {} + else: + for subtag in tags.split('&'): + tagkey, taglist = subtag.split(':') + taglist = taglist.split(',') + tagdict[tagkey] = set(taglist) + return tagdict