import operator
from functools import reduce

from django.contrib.postgres.search import TrigramSimilarity
from django.core.exceptions import (
    ImproperlyConfigured,
    MultipleObjectsReturned,
    ObjectDoesNotExist,
)
from django.db import models
from django.db.models import CharField
from django.db.models import Count, Max, Min, Q
from django.db.models.constants import LOOKUP_SEP
from django.db.models.functions import Concat
from django.template import loader
from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _
from django_filters import rest_framework as filters
from rest_framework.filters import BaseFilterBackend
from rest_framework.settings import api_settings

from shop.models.product import Facet, Product

# Optional dependencies, only needed by the coreapi-based `get_schema_fields()`
# hooks below, which assert on them before use.
try:
    import coreapi
except ImportError:
    coreapi = None

try:
    import coreschema
except ImportError:
    coreschema = None

# TODO the dev pages are not done


def _require_view_attribute(value, view, attribute):
    """
    Return ``value`` after asserting that the view supplied configuration ``attribute``.

    Shared by every backend below so the misconfiguration message is identical:
    "<ViewClass> should include a `<attribute>` attribute".
    """
    assert value is not None, (
        f"{view.__class__.__name__} should include a `{attribute}` attribute"
    )
    return value


# Filters


class LessThanOrEqualFilter(BaseFilterBackend):
    """
    Keep rows whose ``<less_than_field>`` is <= the ``<less_than_field>_max`` query param.

    The field name comes from the view's ``less_than_field`` attribute. When the
    query parameter is absent the queryset is returned unchanged.
    """

    def get_less_than_field(self, view, request):
        return getattr(view, 'less_than_field', None)

    def _require_less_than_field(self, view, request):
        """Return the view's ``less_than_field``, asserting that it is configured."""
        return _require_view_attribute(
            self.get_less_than_field(view, request), view, 'less_than_field'
        )

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict.
        For caching purposes.

        Returns ``{'<field>_max': [value]}`` when the parameter is present,
        otherwise ``{'<field>_max': []}``.
        """
        param = self._require_less_than_field(view, request) + '_max'
        if param in request.query_params:
            return {param: [request.query_params.get(param)]}
        return {param: []}

    def filter_queryset(self, request, queryset, view):
        """Return the queryset filtered with ``<field>__lte=<value of <field>_max>``."""
        less_than_field = self._require_less_than_field(view, request)
        param = less_than_field + '_max'
        if param not in request.query_params:
            return queryset
        return queryset.filter(**{less_than_field + '__lte': request.query_params.get(param)})


class MoreThanOrEqualFilter(BaseFilterBackend):
    """
    Keep rows whose ``<more_than_field>`` is >= the ``<more_than_field>_min`` query param.

    The field name comes from the view's ``more_than_field`` attribute. When the
    query parameter is absent the queryset is returned unchanged.
    """

    def get_more_than_field(self, view, request):
        return getattr(view, 'more_than_field', None)

    def _require_more_than_field(self, view, request):
        """Return the view's ``more_than_field``, asserting that it is configured."""
        return _require_view_attribute(
            self.get_more_than_field(view, request), view, 'more_than_field'
        )

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict.
        For caching purposes.

        Returns ``{'<field>_min': [value]}`` when the parameter is present,
        otherwise ``{'<field>_min': []}``.
        """
        param = self._require_more_than_field(view, request) + '_min'
        if param in request.query_params:
            return {param: [request.query_params.get(param)]}
        return {param: []}

    def filter_queryset(self, request, queryset, view):
        """Return the queryset filtered with ``<field>__gte=<value of <field>_min>``."""
        more_than_field = self._require_more_than_field(view, request)
        param = more_than_field + '_min'
        if param not in request.query_params:
            return queryset
        return queryset.filter(**{more_than_field + '__gte': request.query_params.get(param)})


class CategoryFilter(BaseFilterBackend):
    """
    This filter assigns the view.category object for later use, in particular for
    filters that depend on this one.

    The category is looked up by slug, using the query parameter named by
    ``category_field`` (default ``'category'``, overridable on the view) and the
    view's ``category_class`` model. Filtering keeps rows whose category is the
    selected one or any id listed in its comma-separated ``tn_descendants_pks``.
    """
    template = './filters/categories.html'
    category_field = 'category'

    def get_category_class(self, view, request):
        return getattr(view, 'category_class', None)

    def get_category_field(self, view, request):
        """Return the view's ``category_field`` if set, else this filter's default."""
        return getattr(view, 'category_field', self.category_field)

    def _sync_category_field(self, view, request):
        """Store the effective category field on the filter instance and return it."""
        self.category_field = self.get_category_field(view, request)
        return self.category_field

    def assign_view_category(self, request, view):
        """
        Set ``view.category`` to the selected category instance, or ``None``.

        The lookup runs only if ``view.category`` is not already set. A missing
        query parameter, a missing ``category_class``, or a slug that matches no
        (or several) categories all yield ``None``.
        """
        if hasattr(view, 'category'):
            return
        view.category = None
        category_slug = request.query_params.get(self.category_field)
        category_class = self.get_category_class(view, request)
        if category_slug is None or category_class is None:
            return
        try:
            # Append the category object in the view for later reference
            view.category = category_class.objects.get(slug=category_slug.strip())
        except (ObjectDoesNotExist, MultipleObjectsReturned):
            view.category = None

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict.
        For caching purposes.

        Queries the database for the current category and saves it in view.category
        for internal use and later filters.
        """
        self._sync_category_field(view, request)
        _require_view_attribute(self.get_category_class(view, request), view, 'category_class')
        # Create the filters dictionary and find the present category instance
        self.assign_view_category(request, view)
        if view.category:
            return {self.category_field: [view.category.slug]}
        return {self.category_field: []}

    def filter_queryset(self, request, queryset, view):
        """Restrict the queryset to the selected category and its descendants."""
        self._sync_category_field(view, request)
        self.assign_view_category(request, view)
        category = view.category
        if not category:
            return queryset
        condition = Q(**{self.category_field + '__id': category.id})
        if category.tn_descendants_pks:
            descendant_ids = category.tn_descendants_pks.split(',')
            condition |= Q(**{self.category_field + '__id__in': descendant_ids})
        return queryset.filter(condition)

    # Developer Interface methods

    def get_valid_fields(self, queryset, view, context, request):
        """Return ``(slug, str(category))`` pairs for every category (runs a query)."""
        category_class = self.get_category_class(view, request)
        self._sync_category_field(view, request)
        _require_view_attribute(category_class, view, 'category_class')
        return [(item.slug, str(item)) for item in category_class.objects.all()]

    def get_current(self, request, queryset, view):
        """Return the first comma-separated value of the category parameter, or None."""
        params = request.query_params.get(self.category_field)
        if not params:
            return None
        return params.split(',')[0].strip()

    def get_template_context(self, request, queryset, view):
        self._sync_category_field(view, request)
        context = {
            'request': request,
            'current': self.get_current(request, queryset, view),
            'param': self.category_field,
        }
        context['options'] = [
            (key, str(label))
            for key, label in self.get_valid_fields(queryset, view, context, request)
        ]
        return context

    def to_html(self, request, queryset, view):
        template = loader.get_template(self.template)
        context = self.get_template_context(request, queryset, view)
        return template.render(context)


class FacetFilter(BaseFilterBackend):
    """
    This filter requires CategoryFilter to be ran before it. It assigns the view.facets
    which includes all the facets applicable to the current category.

    For every facet whose slug appears as a query parameter, the queryset is
    restricted to rows having at least one of the listed tag slugs through
    ``<facet_tag_field>__slug__in``. Filters for different facets are ANDed.
    """
    template = './filters/facets.html'

    def get_facet_class(self, view, request):
        return getattr(view, 'facet_class', None)

    def get_facet_tag_class(self, view, request):
        return getattr(view, 'facet_tag_class', None)

    def get_facet_tag_field(self, view, request):
        return getattr(view, 'facet_tag_field', None)

    def assign_view_facets(self, request, view):
        """
        Set ``view.facets`` (with ``facet_tags`` prefetched) unless it is already set.

        With a selected ``view.category`` the facets of that category and of its
        ``tn_ancestors_pks`` are used; without one, the facets of level-1
        categories. Uses the view's ``facet_class``, falling back to ``Facet``.
        """
        if hasattr(view, 'facets'):
            return
        facet_class = self.get_facet_class(view, request) or Facet
        category = getattr(view, 'category', None)
        if not category:
            facets = facet_class.objects.filter(category__tn_level=1)
        elif category.tn_ancestors_pks:
            facets = facet_class.objects.filter(
                Q(category__id=category.id)
                | Q(category__id__in=category.tn_ancestors_pks.split(','))
            )
        else:
            facets = facet_class.objects.filter(category__id=category.id)
        view.facets = facets.prefetch_related('facet_tags')

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict.
        For caching purposes.

        Maps every applicable facet slug to the set of requested tag slugs (an
        empty set when the facet is not in the query). The dict is also stored
        on ``view.tags``.
        """
        self.facet_class = _require_view_attribute(
            self.get_facet_class(view, request), view, 'facet_class'
        )
        self.assign_view_facets(request, view)
        filters_dict = {}
        for facet in view.facets:
            if facet.slug in request.query_params:
                filters_dict[facet.slug] = set(request.query_params[facet.slug].split(','))
            else:
                filters_dict[facet.slug] = set()
        # Append the facets object and the tags dict in the view for later reference
        view.tags = filters_dict
        return filters_dict

    def filter_queryset(self, request, queryset, view):
        """Apply one ``__slug__in`` filter per requested facet, then de-duplicate."""
        self.facet_tag_class = _require_view_attribute(
            self.get_facet_tag_class(view, request), view, 'facet_tag_class'
        )
        self.facet_tag_field = _require_view_attribute(
            self.get_facet_tag_field(view, request), view, 'facet_tag_field'
        )
        self.assign_view_facets(request, view)
        lookup = self.facet_tag_field + '__slug__in'
        filtered = False
        for facet in view.facets:
            raw_tags = request.query_params.get(facet.slug, '')
            tag_slugs = [slug for slug in raw_tags.replace(' ', '').split(',') if slug]
            if not tag_slugs:
                # If the tag filterlist is empty then we're not filtering against it,
                # it's like having all the tags of the facet selected
                continue
            queryset = queryset.filter(**{lookup: tag_slugs})
            filtered = True
        if filtered:
            # Each filter() joins the multi-valued tag relation separately, so a row
            # matching several selected tags would otherwise appear several times.
            queryset = queryset.distinct()
        return queryset

    # Developer Interface methods

    def get_template_context(self, request, queryset, view):
        # Does aggressive database querying to get the necessary facets and facettags,
        # but this is only for the developer interface so its fine
        self.facet_class = _require_view_attribute(
            self.get_facet_class(view, request), view, 'facet_class'
        )
        self.facet_tag_class = _require_view_attribute(
            self.get_facet_tag_class(view, request), view, 'facet_tag_class'
        )
        self.assign_view_facets(request, view)
        facets = list(view.facets)
        facet_slugs = [facet.slug for facet in facets]
        # Find the current choices
        current = [
            request.query_params.get(slug)
            for slug in facet_slugs
            if slug in request.query_params
        ]
        context = {
            'request': request,
            'current': current,
            'facet_slugs': facet_slugs,
        }
        context['facet_slug_names'] = {facet.slug: facet.name for facet in facets}
        context['options'] = {
            facet.slug: [
                (facet_tag.slug, facet_tag.name)
                for facet_tag in self.facet_tag_class.objects.filter(facet__slug=facet.slug)
            ]
            for facet in facets
        }
        return context

    def to_html(self, request, queryset, view):
        template = loader.get_template(self.template)
        context = self.get_template_context(request, queryset, view)
        return template.render(context)


class TrigramSearchFilter(BaseFilterBackend):
    """
    Keep rows whose concatenated ``search_fields`` are trigram-similar to the query.

    The view's ``search_fields`` are joined with ``Concat`` into a ``search_field``
    annotation. A ``similarity`` annotation is added, and rows scoring above the
    threshold are kept (``.distinct()`` is applied).
    """
    # The URL query parameter used for the search.
    search_param = 'search'
    template = 'rest_framework/filters/search.html'
    search_title = _('Search')
    search_description = _('A search string to perform trigram similarity based searching with.')
    # Minimum TrigramSimilarity a row must exceed to be kept; a view may override
    # it with its own `trigram_similarity_threshold` attribute.
    trigram_similarity_threshold = 0.05

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict.
        For caching purposes.
        """
        self.filters_dict = {}
        if self.search_param in request.query_params:
            self.filters_dict[self.search_param] = [request.query_params.get(self.search_param)]
        else:
            self.filters_dict[self.search_param] = []
        return self.filters_dict

    def filter_queryset(self, request, queryset, view):
        search_fields = _require_view_attribute(
            getattr(view, 'search_fields', None), view, 'search_fields'
        )
        query = request.query_params.get(self.search_param, '')
        if not query or not search_fields:
            return queryset
        threshold = getattr(
            view, 'trigram_similarity_threshold', self.trigram_similarity_threshold
        )
        if len(search_fields) > 1:
            search_expression = Concat(*search_fields, output_field=CharField())
        else:
            # Concat() requires at least two expressions; use the lone field as-is.
            only_field = search_fields[0]
            search_expression = models.F(only_field) if isinstance(only_field, str) else only_field
        return (
            queryset.annotate(search_field=search_expression)
            .annotate(similarity=TrigramSimilarity('search_field', query))
            .filter(similarity__gt=threshold)
            .distinct()
        )

    def to_html(self, request, queryset, view):
        if not getattr(view, 'search_fields', None):
            return ''
        # The parameter is a plain string: render it whole, not its first character.
        context = {
            'param': self.search_param,
            'term': request.query_params.get(self.search_param, ''),
        }
        template = loader.get_template(self.template)
        return template.render(context)

    def get_schema_fields(self, view):
        assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
        assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
        return [
            coreapi.Field(
                name=self.search_param,
                required=False,
                location='query',
                schema=coreschema.String(
                    title=force_str(self.search_title),
                    description=force_str(self.search_description)
                )
            )
        ]

    def get_schema_operation_parameters(self, view):
        return [
            {
                'name': self.search_param,
                'required': False,
                'in': 'query',
                'description': force_str(self.search_description),
                'schema': {
                    'type': 'string',
                },
            },
        ]


# TODO #class FieldFilter(BaseFilterBackend):


# TODO misunderstood the urlconf stuff of the RUD methods, this is probably unnecessary
class SlugSearchFilter(BaseFilterBackend):
    """
    Restrict the queryset to rows whose ``slug_field`` equals the query parameter.

    The query parameter has the same name as the field (default ``'slug'``,
    overridable through the view's ``slug_field`` attribute). Without the
    parameter the queryset is returned unchanged.
    """
    template = './filters/slug.html'
    slug_title = _('Slug Search')
    slug_description = _("The instance's slug.")
    # The model field, and URL query parameter, used for the lookup.
    slug_field = 'slug'

    def get_slug_field(self, view, request):
        return getattr(view, 'slug_field', None)

    def _sync_slug_field(self, view, request):
        """Adopt the view's ``slug_field`` if it defines one, assert it is set, return it."""
        if hasattr(view, 'slug_field'):
            self.slug_field = self.get_slug_field(view, request)
        return _require_view_attribute(self.slug_field, view, 'slug_field')

    def get_slug_term(self, request):
        """Return the requested slug, or ``''`` when the parameter is absent."""
        return request.query_params.get(self.slug_field, '')

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict.
        For caching purposes.
        """
        slug_field = self._sync_slug_field(view, request)
        self.filters_dict = {}
        if slug_field in request.query_params:
            self.filters_dict[slug_field] = [request.query_params.get(slug_field)]
        else:
            self.filters_dict[slug_field] = []
        return self.filters_dict

    def filter_queryset(self, request, queryset, view):
        """Return a queryset (never a model instance) filtered on the exact slug."""
        slug_field = self._sync_slug_field(view, request)
        if slug_field not in request.query_params:
            return queryset
        return queryset.filter(**{slug_field: request.query_params.get(slug_field)})

    def to_html(self, request, queryset, view):
        if not getattr(view, 'slug_field', None):
            return ''
        self._sync_slug_field(view, request)
        context = {
            'param': self.slug_field,
            'term': self.get_slug_term(request),
        }
        template = loader.get_template(self.template)
        return template.render(context)


class SearchFilter(BaseFilterBackend):
    """
    Keep rows where every search term matches at least one of the view's ``search_fields``.

    Field name prefixes select the lookup: ``^`` istartswith, ``=`` iexact,
    ``@`` search, ``$`` iregex; unprefixed fields use ``icontains``.
    """
    # The URL query parameter used for the search.
    search_param = api_settings.SEARCH_PARAM
    template = 'rest_framework/filters/search.html'
    lookup_prefixes = {
        '^': 'istartswith',
        '=': 'iexact',
        '@': 'search',
        '$': 'iregex',
    }
    search_title = _('Search')
    search_description = _('A search term.')

    # TODO to be removed
    def get_search_terms(self, request):
        """
        Search terms are set by a ?search=... query parameter,
        and may be comma and/or whitespace delimited.
        """
        params = request.query_params.get(self.search_param, '')
        params = params.replace('\x00', '')  # strip null characters
        params = params.replace(',', ' ')
        return params.split()

    # TODO to be removed
    def construct_search(self, field_name):
        lookup = self.lookup_prefixes.get(field_name[0])
        if lookup:
            field_name = field_name[1:]
        else:
            lookup = 'icontains'
        return LOOKUP_SEP.join([field_name, lookup])

    # TODO to be removed
    def must_call_distinct(self, queryset, search_fields):
        """
        Return True if 'distinct()' should be used to query the given lookups.
        """
        for search_field in search_fields:
            opts = queryset.model._meta
            if search_field[0] in self.lookup_prefixes:
                search_field = search_field[1:]
            # Annotated fields do not need to be distinct
            if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations:
                continue
            parts = search_field.split(LOOKUP_SEP)
            for part in parts:
                field = opts.get_field(part)
                if hasattr(field, 'get_path_info'):
                    # This field is a relation, update opts to follow the relation
                    path_info = field.get_path_info()
                    opts = path_info[-1].to_opts
                    if any(path.m2m for path in path_info):
                        # This field is a m2m relation so we know we need to call distinct
                        return True
                else:
                    # This field has a custom __ query transform but is not a relational field.
                    break
        return False

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict.
        For caching purposes.
        """
        self.filters_dict = {}
        if self.search_param in request.query_params:
            self.filters_dict[self.search_param] = [request.query_params.get(self.search_param)]
        else:
            self.filters_dict[self.search_param] = []
        return self.filters_dict

    def filter_queryset(self, request, queryset, view):
        search_fields = getattr(view, 'search_fields', None)
        search_terms = self.get_search_terms(request)

        if not search_fields or not search_terms:
            return queryset

        orm_lookups = [
            self.construct_search(str(search_field))
            for search_field in search_fields
        ]

        conditions = []
        for search_term in search_terms:
            queries = [
                models.Q(**{orm_lookup: search_term})
                for orm_lookup in orm_lookups
            ]
            conditions.append(reduce(operator.or_, queries))
        queryset = queryset.filter(reduce(operator.and_, conditions))

        if self.must_call_distinct(queryset, search_fields):
            # Filtering against a many-to-many field requires us to
            # call queryset.distinct() in order to avoid duplicate items
            # in the resulting queryset.
            # We try to avoid this if possible, for performance reasons.
            queryset = queryset.distinct()
        return queryset

    def to_html(self, request, queryset, view):
        if not getattr(view, 'search_fields', None):
            return ''
        # Show every parsed term, not just the first one.
        context = {
            'param': self.search_param,
            'term': ' '.join(self.get_search_terms(request)),
        }
        template = loader.get_template(self.template)
        return template.render(context)

    def get_schema_fields(self, view):
        assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
        assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
        return [
            coreapi.Field(
                name=self.search_param,
                required=False,
                location='query',
                schema=coreschema.String(
                    title=force_str(self.search_title),
                    description=force_str(self.search_description)
                )
            )
        ]

    def get_schema_operation_parameters(self, view):
        return [
            {
                'name': self.search_param,
                'required': False,
                'in': 'query',
                'description': force_str(self.search_description),
                'schema': {
                    'type': 'string',
                },
            },
        ]


class OrderingFilter(BaseFilterBackend):
    """
    Order the queryset by the comma-separated fields in the ordering query parameter.

    Only fields reported by ``get_valid_fields`` are honoured (a leading ``-``
    means descending). Otherwise the view's ``ordering`` attribute is used.
    """
    # The URL query parameter used for the ordering.
    ordering_param = api_settings.ORDERING_PARAM
    ordering_fields = None
    ordering_title = _('Ordering')
    ordering_description = _('Which field to use when ordering the results.')
    template = 'rest_framework/filters/ordering.html'

    def get_ordering(self, request, queryset, view):
        """
        Ordering is set by a comma delimited ?ordering=... query parameter.

        The `ordering` query parameter can be overridden by setting
        the `ordering_param` value on the OrderingFilter or by
        specifying an `ORDERING_PARAM` value in the API settings.
        """
        params = request.query_params.get(self.ordering_param)
        if params:
            fields = [param.strip() for param in params.split(',')]
            ordering = self.remove_invalid_fields(queryset, fields, view, request)
            if ordering:
                return ordering

        # No ordering was included, or all the ordering fields were invalid
        return self.get_default_ordering(view)

    def get_default_ordering(self, view):
        ordering = getattr(view, 'ordering', None)
        if isinstance(ordering, str):
            return (ordering,)
        return ordering

    def get_default_valid_fields(self, queryset, view, context=None):
        # If `ordering_fields` is not specified, then we determine a default
        # based on the serializer class, if one exists on the view.
        context = {} if context is None else context
        if hasattr(view, 'get_serializer_class'):
            try:
                serializer_class = view.get_serializer_class()
            except AssertionError:
                # Raised by the default implementation if
                # no serializer_class was found
                serializer_class = None
        else:
            serializer_class = getattr(view, 'serializer_class', None)

        if serializer_class is None:
            msg = (
                "Cannot use %s on a view which does not have either a "
                "'serializer_class', an overriding 'get_serializer_class' "
                "or 'ordering_fields' attribute."
            )
            raise ImproperlyConfigured(msg % self.__class__.__name__)

        model_class = queryset.model
        model_property_names = [
            # 'pk' is a property added in Django's Model class, however it is valid for ordering.
            attr for attr in dir(model_class) if isinstance(getattr(model_class, attr), property) and attr != 'pk'
        ]

        return [
            (field.source.replace('.', '__') or field_name, field.label)
            for field_name, field in serializer_class(context=context).fields.items()
            if (
                not getattr(field, 'write_only', False) and
                not field.source == '*' and
                field.source not in model_property_names
            )
        ]

    def get_valid_fields(self, queryset, view, context=None):
        context = {} if context is None else context
        valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)

        if valid_fields is None:
            # Default to allowing filtering on serializer fields
            return self.get_default_valid_fields(queryset, view, context)

        elif valid_fields == '__all__':
            # View explicitly allows filtering on any model field
            valid_fields = [
                (field.name, field.verbose_name) for field in queryset.model._meta.fields
            ]
            valid_fields += [
                (key, key.title().split('__'))
                for key in queryset.query.annotations
            ]
        else:
            valid_fields = [
                (item, item) if isinstance(item, str) else item
                for item in valid_fields
            ]

        return valid_fields

    def remove_invalid_fields(self, queryset, fields, view, request):
        valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]

        def term_valid(term):
            if term.startswith("-"):
                term = term[1:]
            return term in valid_fields

        return [term for term in fields if term_valid(term)]

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict.
        For caching purposes.

        Without an ordering parameter the first entry of a list/tuple
        ``ordering_fields`` is used as the key value; otherwise (None, ``'__all__'``
        or unset) the value is an empty list.
        """
        self.filters_dict = {}
        if self.ordering_param in request.query_params:
            self.filters_dict[self.ordering_param] = [request.query_params.get(self.ordering_param)]
        else:
            ordering_fields = getattr(view, 'ordering_fields', self.ordering_fields)
            if isinstance(ordering_fields, (list, tuple)) and ordering_fields:
                self.filters_dict[self.ordering_param] = [ordering_fields[0]]
            else:
                self.filters_dict[self.ordering_param] = []
        return self.filters_dict

    def filter_queryset(self, request, queryset, view):
        ordering = self.get_ordering(request, queryset, view)

        if ordering:
            return queryset.order_by(*ordering)

        return queryset

    def get_template_context(self, request, queryset, view):
        current = self.get_ordering(request, queryset, view)
        current = None if not current else current[0]
        options = []
        context = {
            'request': request,
            'current': current,
            'param': self.ordering_param,
        }
        for key, label in self.get_valid_fields(queryset, view, context):
            options.append((key, '%s - %s' % (label, _('ascending'))))
            options.append(('-' + key, '%s - %s' % (label, _('descending'))))
        context['options'] = options
        return context

    def to_html(self, request, queryset, view):
        template = loader.get_template(self.template)
        context = self.get_template_context(request, queryset, view)
        return template.render(context)

    def get_schema_fields(self, view):
        assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
        assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
        return [
            coreapi.Field(
                name=self.ordering_param,
                required=False,
                location='query',
                schema=coreschema.String(
                    title=force_str(self.ordering_title),
                    description=force_str(self.ordering_description)
                )
            )
        ]

    def get_schema_operation_parameters(self, view):
        return [
            {
                'name': self.ordering_param,
                'required': False,
                'in': 'query',
                'description': force_str(self.ordering_description),
                'schema': {
                    'type': 'string',
                },
            },
        ]