# File: starfields-drf-generics/django-drf-generics/filters.py
# (Python; listed as 794 lines / 30 KiB in the original file view.)

import operator
from functools import reduce

from django.contrib.postgres.search import TrigramSimilarity
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from django.db.models import CharField
from django.db.models import Max, Min, Count, Q
from django.db.models.constants import LOOKUP_SEP
from django.db.models.functions import Concat
from django.template import loader
from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _
from django_filters import rest_framework as filters
from rest_framework.compat import coreapi, coreschema
from rest_framework.filters import BaseFilterBackend
from rest_framework.settings import api_settings

from shop.models.product import Product, Facet
# TODO the dev pages are not done
# Filters
class LessThanOrEqualFilter(BaseFilterBackend):
    """
    Keep rows whose ``<less_than_field>`` is <= the ``<less_than_field>_max`` query parameter.

    The view must define ``less_than_field`` (the model field to compare). When the
    ``<less_than_field>_max`` parameter is absent or blank the queryset is returned unchanged.
    """

    def get_less_than_field(self, view, request):
        return getattr(view, 'less_than_field', None)

    def _get_required_field(self, view, request):
        """Return the view's ``less_than_field``, asserting that it is configured."""
        less_than_field = self.get_less_than_field(view, request)
        assert less_than_field is not None, (
            f"{view.__class__.__name__} should include a `less_than_field` attribute"
        )
        return less_than_field

    @staticmethod
    def _get_bound(request, param):
        """Return the raw value of query parameter ``param``, or None when missing or blank."""
        value = request.query_params.get(param)
        # A blank value (e.g. an unfilled form input) means "no bound"; comparing the
        # field against '' would typically raise for numeric fields.
        if value is None or not value.strip():
            return None
        return value

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.

        Maps ``<less_than_field>_max`` to ``[value]`` when a bound is given, otherwise ``[]``.
        """
        param = self._get_required_field(view, request) + '_max'
        value = self._get_bound(request, param)
        return {param: [] if value is None else [value]}

    def filter_queryset(self, request, queryset, view):
        """Return ``queryset`` filtered by ``<less_than_field>__lte=<value>`` when a bound is given."""
        less_than_field = self._get_required_field(view, request)
        value = self._get_bound(request, less_than_field + '_max')
        if value is None:
            return queryset
        return queryset.filter(**{less_than_field + '__lte': value})
class MoreThanOrEqualFilter(BaseFilterBackend):
    """
    Keep rows whose ``<more_than_field>`` is >= the ``<more_than_field>_min`` query parameter.

    The view must define ``more_than_field`` (the model field to compare). When the
    ``<more_than_field>_min`` parameter is absent or blank the queryset is returned unchanged.
    """

    def get_more_than_field(self, view, request):
        return getattr(view, 'more_than_field', None)

    def _get_required_field(self, view, request):
        """Return the view's ``more_than_field``, asserting that it is configured."""
        more_than_field = self.get_more_than_field(view, request)
        assert more_than_field is not None, (
            f"{view.__class__.__name__} should include a `more_than_field` attribute"
        )
        return more_than_field

    @staticmethod
    def _get_bound(request, param):
        """Return the raw value of query parameter ``param``, or None when missing or blank."""
        value = request.query_params.get(param)
        # A blank value (e.g. an unfilled form input) means "no bound"; comparing the
        # field against '' would typically raise for numeric fields.
        if value is None or not value.strip():
            return None
        return value

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.

        Maps ``<more_than_field>_min`` to ``[value]`` when a bound is given, otherwise ``[]``.
        """
        param = self._get_required_field(view, request) + '_min'
        value = self._get_bound(request, param)
        return {param: [] if value is None else [value]}

    def filter_queryset(self, request, queryset, view):
        """Return ``queryset`` filtered by ``<more_than_field>__gte=<value>`` when a bound is given."""
        more_than_field = self._get_required_field(view, request)
        value = self._get_bound(request, more_than_field + '_min')
        if value is None:
            return queryset
        return queryset.filter(**{more_than_field + '__gte': value})
class CategoryFilter(BaseFilterBackend):
    """
    Restrict the queryset to one category (looked up by slug) plus its tree descendants.

    This filter assigns ``view.category`` (the category instance, or None) for later use,
    in particular by filters that depend on this one.

    View configuration:
        category_class -- required; the category model, queried by ``slug``.
        category_field -- optional; both the query parameter name and the queryset
            relation that is filtered on. Defaults to ``'category'``.
    """
    template = './filters/categories.html'
    category_field = 'category'

    def get_category_class(self, view, request):
        return getattr(view, 'category_class', None)

    def get_category_field(self, view, request):
        return getattr(view, 'category_field', self.category_field)

    def _resolve_category_field(self, view, request):
        """Adopt the view's ``category_field`` override, if the view declares one."""
        if hasattr(view, 'category_field'):
            self.category_field = self.get_category_field(view, request)

    def _require_category_class(self, view, request):
        """Return the view's ``category_class``, asserting that it is configured."""
        category_class = self.get_category_class(view, request)
        assert category_class is not None, (
            f"{view.__class__.__name__} should include a `category_class` attribute"
        )
        return category_class

    def assign_view_category(self, request, view):
        """
        Set ``view.category`` from the category query parameter, at most once per view.

        An absent parameter, or a slug that matches no (or several) categories, leaves
        ``view.category`` as None.
        """
        # `view.category` is what this method writes, so it is also the "already done" marker.
        if hasattr(view, 'category'):
            return
        view.category = None
        category_slug = request.query_params.get(self.category_field)
        if category_slug is None:
            return
        category_class = self._require_category_class(view, request)
        try:
            view.category = category_class.objects.get(slug=category_slug.strip())
        except (category_class.DoesNotExist, category_class.MultipleObjectsReturned):
            # An unknown slug means "no category filter" rather than an error response.
            view.category = None

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
        Queries the database for the current category and saves it in view.category for internal
        use and later filters.
        """
        self._resolve_category_field(view, request)
        self._require_category_class(view, request)
        self.assign_view_category(request, view)
        category = view.category
        return {self.category_field: [category.slug] if category else []}

    def filter_queryset(self, request, queryset, view):
        """Keep rows related to ``view.category`` or any of its descendants (``tn_descendants_pks``)."""
        self._resolve_category_field(view, request)
        self.assign_view_category(request, view)
        category = view.category
        if not category:
            return queryset
        condition = Q(**{self.category_field + '__id': category.id})
        if category.tn_descendants_pks:
            # tn_descendants_pks is a comma-separated string of primary keys.
            condition |= Q(**{
                self.category_field + '__id__in': category.tn_descendants_pks.split(','),
            })
        return queryset.filter(condition)

    # Developer Interface methods
    def get_valid_fields(self, queryset, view, context, request):
        """Return ``(slug, label)`` pairs for every category instance (one query)."""
        self._resolve_category_field(view, request)
        category_class = self._require_category_class(view, request)
        return [(item.slug, str(item)) for item in category_class.objects.all()]

    def get_current(self, request, queryset, view):
        """Return the first comma-separated value of the category parameter, or None."""
        self._resolve_category_field(view, request)
        params = request.query_params.get(self.category_field)
        if not params:
            return None
        return params.split(',')[0].strip()

    def get_template_context(self, request, queryset, view):
        # get_current resolves category_field, so compute it before reading the param name.
        current = self.get_current(request, queryset, view)
        context = {
            'request': request,
            'current': current,
            'param': self.category_field,
        }
        context['options'] = [
            (key, '%s' % (label,))
            for key, label in self.get_valid_fields(queryset, view, context, request)
        ]
        return context

    def to_html(self, request, queryset, view):
        template = loader.get_template(self.template)
        context = self.get_template_context(request, queryset, view)
        return template.render(context)
class FacetFilter(BaseFilterBackend):
    """
    Filter the queryset by facet tags given as ``?<facet_slug>=tag1,tag2`` query parameters.

    This filter is meant to run after CategoryFilter. It assigns ``view.facets``, the facets
    applicable to ``view.category`` (the category's own facets plus those of its ancestors,
    or the level-1 facets when there is no category), and ``view.tags``, the requested tag
    slugs per facet.

    View configuration: ``facet_class``, ``facet_tag_class`` and ``facet_tag_field``
    (the queryset relation whose ``slug`` is matched against the requested tags).
    """
    template = './filters/facets.html'
    # Defaults so that the configuration asserts below report a clear message instead of
    # raising AttributeError when the view does not provide these settings.
    facet_class = None
    facet_tag_class = None
    facet_tag_field = None

    def get_facet_class(self, view, request):
        return getattr(view, 'facet_class', None)

    def get_facet_tag_class(self, view, request):
        return getattr(view, 'facet_tag_class', None)

    def get_facet_tag_field(self, view, request):
        return getattr(view, 'facet_tag_field', None)

    def _require_view_setting(self, view, request, attr, getter):
        """Copy ``attr`` from the view (via ``getter``) when present, assert it is set, and return it."""
        if hasattr(view, attr):
            setattr(self, attr, getter(view, request))
        value = getattr(self, attr)
        assert value is not None, (
            f"{view.__class__.__name__} should include a `{attr}` attribute"
        )
        return value

    @staticmethod
    def _parse_tags(raw_value):
        """Split a comma-separated tag list, dropping spaces and empty entries."""
        return [tag for tag in raw_value.replace(' ', '').split(',') if tag]

    def assign_view_facets(self, request, view):
        """Set ``view.facets`` (with ``facet_tags`` prefetched), at most once per view."""
        if hasattr(view, 'facets'):
            return
        facet_class = self.get_facet_class(view, request) or self.facet_class or Facet
        category = getattr(view, 'category', None)
        if category:
            condition = Q(category__id=category.id)
            if category.tn_ancestors_pks:
                # tn_ancestors_pks is a comma-separated string of primary keys.
                condition |= Q(category__id__in=category.tn_ancestors_pks.split(','))
        else:
            condition = Q(category__tn_level=1)
        view.facets = facet_class.objects.filter(condition).prefetch_related('facet_tags')

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.

        Maps every applicable facet slug to the set of requested tag slugs (empty set when
        none), parsed exactly as filter_queryset parses them.
        """
        self._require_view_setting(view, request, 'facet_class', self.get_facet_class)
        self.assign_view_facets(request, view)
        filters_dict = {
            facet.slug: set(self._parse_tags(request.query_params.get(facet.slug, '')))
            for facet in (view.facets or ())
        }
        # Append the tags dict in the view for later reference
        view.tags = filters_dict
        return filters_dict

    def filter_queryset(self, request, queryset, view):
        """Apply one ``<facet_tag_field>__slug__in`` filter per facet that has requested tags."""
        self._require_view_setting(view, request, 'facet_tag_class', self.get_facet_tag_class)
        facet_tag_field = self._require_view_setting(
            view, request, 'facet_tag_field', self.get_facet_tag_field
        )
        self.assign_view_facets(request, view)
        for facet in (view.facets or ()):
            tags = self._parse_tags(request.query_params.get(facet.slug, ''))
            # No tags means the facet is not being filtered (as if all its tags were selected).
            if tags:
                queryset = queryset.filter(**{facet_tag_field + '__slug__in': tags})
        return queryset

    # Developer Interface methods
    def get_template_context(self, request, queryset, view):
        # Does aggressive database querying (one query per facet) to list the facet tags,
        # but this is only for the developer interface.
        self._require_view_setting(view, request, 'facet_class', self.get_facet_class)
        facet_tag_class = self._require_view_setting(
            view, request, 'facet_tag_class', self.get_facet_tag_class
        )
        self.assign_view_facets(request, view)
        facets = view.facets or []
        current = [
            request.query_params.get(facet.slug)
            for facet in facets
            if facet.slug in request.query_params
        ]
        options = {}
        facet_slug_names = {}
        for facet in facets:
            facet_tag_instances = facet_tag_class.objects.filter(facet__slug=facet.slug)
            options[facet.slug] = [(tag.slug, tag.name) for tag in facet_tag_instances]
            facet_slug_names[facet.slug] = facet.name
        return {
            'request': request,
            'current': current,
            'facet_slugs': [facet.slug for facet in facets],
            'facet_slug_names': facet_slug_names,
            'options': options,
        }

    def to_html(self, request, queryset, view):
        template = loader.get_template(self.template)
        context = self.get_template_context(request, queryset, view)
        return template.render(context)
class TrigramSearchFilter(BaseFilterBackend):
    """
    Filter the queryset by PostgreSQL trigram similarity to the ``search`` query parameter.

    The view's ``search_fields`` are concatenated into a ``search_field`` annotation and each
    row is annotated with ``similarity``; rows whose similarity is not greater than
    ``similarity_threshold`` are dropped.
    """
    # The URL query parameter used for the search.
    search_param = 'search'
    # Rows need a trigram similarity strictly greater than this to be kept.
    similarity_threshold = 0.05
    template = 'rest_framework/filters/search.html'
    search_title = _('Search')
    search_description = _('A search string to perform trigram similarity based searching with.')

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
        """
        self.filters_dict = {}
        if self.search_param in request.query_params:
            self.filters_dict[self.search_param] = [request.query_params.get(self.search_param)]
        else:
            self.filters_dict[self.search_param] = []
        return self.filters_dict

    def filter_queryset(self, request, queryset, view):
        search_fields = getattr(view, 'search_fields', None)
        assert search_fields is not None, (
            f"{view.__class__.__name__} should include a `search_fields` attribute"
        )
        query = request.query_params.get(self.search_param, '')
        if not query or not search_fields:
            return queryset
        expressions = list(search_fields)
        if len(expressions) == 1:
            # Concat requires at least two expressions; pad a single field with '' so it is
            # still concatenated (and cast to text) the same way.
            expressions.append(models.Value(''))
        return queryset.annotate(
            search_field=Concat(*expressions, output_field=CharField())
        ).annotate(
            similarity=TrigramSimilarity('search_field', query)
        ).filter(similarity__gt=self.similarity_threshold).distinct()

    def to_html(self, request, queryset, view):
        if not getattr(view, 'search_fields', None):
            return ''
        # The whole search string is the term (not its first character).
        term = request.query_params.get(self.search_param, '')
        context = {
            'param': self.search_param,
            'term': term
        }
        template = loader.get_template(self.template)
        return template.render(context)

    def get_schema_fields(self, view):
        assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
        assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
        return [
            coreapi.Field(
                name=self.search_param,
                required=False,
                location='query',
                schema=coreschema.String(
                    title=force_str(self.search_title),
                    description=force_str(self.search_description)
                )
            )
        ]

    def get_schema_operation_parameters(self, view):
        return [
            {
                'name': self.search_param,
                'required': False,
                'in': 'query',
                'description': force_str(self.search_description),
                'schema': {
                    'type': 'string',
                },
            },
        ]
# TODO
#class FieldFilter(BaseFilterBackend):
# TODO misunderstood the urlconf stuff of the RUD methods, this is probably unnecessary
class SlugSearchFilter(BaseFilterBackend):
    """
    Restrict the queryset to rows whose ``slug_field`` equals the query parameter of the
    same name (``?slug=...`` by default). Without that parameter the queryset is unchanged.
    """
    template = './filters/slug.html'
    slug_title = _('Slug Search')
    slug_description = _("The instance's slug.")
    # Model field and URL query parameter to match; a view's `slug_field` overrides it.
    slug_field = 'slug'

    def get_slug_field(self, view, request):
        return getattr(view, 'slug_field', None)

    def _resolve_slug_field(self, view, request):
        """Adopt the view's ``slug_field`` override (if any), assert it is set, and return it."""
        if hasattr(view, 'slug_field'):
            self.slug_field = self.get_slug_field(view, request)
        assert self.slug_field is not None, (
            f"{view.__class__.__name__} should include a `slug_field` attribute"
        )
        return self.slug_field

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
        """
        slug_field = self._resolve_slug_field(view, request)
        self.filters_dict = {}
        if slug_field in request.query_params:
            self.filters_dict[slug_field] = [request.query_params.get(slug_field)]
        else:
            self.filters_dict[slug_field] = []
        return self.filters_dict

    def filter_queryset(self, request, queryset, view):
        # Filter backends must return a queryset: use filter() (an empty result when the
        # slug is unknown) rather than get(), which returns a single model instance.
        slug_field = self._resolve_slug_field(view, request)
        if slug_field not in request.query_params:
            return queryset
        return queryset.filter(**{slug_field: request.query_params.get(slug_field)})

    def to_html(self, request, queryset, view):
        if not getattr(view, 'slug_field', None):
            return ''
        slug_field = self._resolve_slug_field(view, request)
        context = {
            'param': slug_field,
            'term': request.query_params.get(slug_field, '')
        }
        template = loader.get_template(self.template)
        return template.render(context)
class SearchFilter(BaseFilterBackend):
    """
    Case-insensitive multi-term search over the view's ``search_fields``.

    Every search term must match at least one of the fields. A field name may be prefixed
    with ``^`` (istartswith), ``=`` (iexact), ``@`` (search) or ``$`` (iregex); unprefixed
    fields use icontains.
    """
    # The URL query parameter used for the search.
    search_param = api_settings.SEARCH_PARAM
    template = 'rest_framework/filters/search.html'
    lookup_prefixes = {
        '^': 'istartswith',
        '=': 'iexact',
        '@': 'search',
        '$': 'iregex',
    }
    search_title = _('Search')
    search_description = _('A search term.')

    # TODO to be removed
    def get_search_terms(self, request):
        """
        Search terms are set by a ?search=... query parameter,
        and may be comma and/or whitespace delimited.
        """
        params = request.query_params.get(self.search_param, '')
        params = params.replace('\x00', '')  # strip null characters
        params = params.replace(',', ' ')
        return params.split()

    # TODO to be removed
    def construct_search(self, field_name):
        lookup = self.lookup_prefixes.get(field_name[0])
        if lookup:
            field_name = field_name[1:]
        else:
            lookup = 'icontains'
        return LOOKUP_SEP.join([field_name, lookup])

    # TODO to be removed
    def must_call_distinct(self, queryset, search_fields):
        """
        Return True if 'distinct()' should be used to query the given lookups.
        """
        for search_field in search_fields:
            opts = queryset.model._meta
            if search_field[0] in self.lookup_prefixes:
                search_field = search_field[1:]
            # Annotated fields do not need to be distinct
            if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations:
                continue
            parts = search_field.split(LOOKUP_SEP)
            for part in parts:
                field = opts.get_field(part)
                if hasattr(field, 'get_path_info'):
                    # This field is a relation, update opts to follow the relation
                    path_info = field.get_path_info()
                    opts = path_info[-1].to_opts
                    if any(path.m2m for path in path_info):
                        # This field is a m2m relation so we know we need to call distinct
                        return True
                else:
                    # This field has a custom __ query transform but is not a relational field.
                    break
        return False

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
        """
        self.filters_dict = {}
        if self.search_param in request.query_params:
            self.filters_dict[self.search_param] = [request.query_params.get(self.search_param)]
        else:
            self.filters_dict[self.search_param] = []
        return self.filters_dict

    def filter_queryset(self, request, queryset, view):
        search_fields = getattr(view, 'search_fields', None)
        search_terms = self.get_search_terms(request)
        if not search_fields or not search_terms:
            return queryset
        orm_lookups = [
            self.construct_search(str(search_field))
            for search_field in search_fields
        ]
        conditions = []
        for search_term in search_terms:
            queries = [
                models.Q(**{orm_lookup: search_term})
                for orm_lookup in orm_lookups
            ]
            # A term matches when any of the fields matches it ...
            conditions.append(reduce(operator.or_, queries))
        # ... and every term has to match.
        queryset = queryset.filter(reduce(operator.and_, conditions))
        if self.must_call_distinct(queryset, search_fields):
            # Filtering against a many-to-many field requires us to
            # call queryset.distinct() in order to avoid duplicate items
            # in the resulting queryset.
            # We try to avoid this if possible, for performance reasons.
            queryset = queryset.distinct()
        return queryset

    def to_html(self, request, queryset, view):
        if not getattr(view, 'search_fields', None):
            return ''
        term = self.get_search_terms(request)
        term = term[0] if term else ''
        context = {
            'param': self.search_param,
            'term': term
        }
        template = loader.get_template(self.template)
        return template.render(context)

    def get_schema_fields(self, view):
        assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
        assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
        return [
            coreapi.Field(
                name=self.search_param,
                required=False,
                location='query',
                schema=coreschema.String(
                    title=force_str(self.search_title),
                    description=force_str(self.search_description)
                )
            )
        ]

    def get_schema_operation_parameters(self, view):
        return [
            {
                'name': self.search_param,
                'required': False,
                'in': 'query',
                'description': force_str(self.search_description),
                'schema': {
                    'type': 'string',
                },
            },
        ]
class OrderingFilter(BaseFilterBackend):
    """
    Order the queryset by the comma-delimited ``?ordering=`` query parameter.

    Only fields allowed by the view's ``ordering_fields`` (or, by default, the serializer's
    readable fields) are honoured; otherwise the view's ``ordering`` is used.
    """
    # The URL query parameter used for the ordering.
    ordering_param = api_settings.ORDERING_PARAM
    ordering_fields = None
    ordering_title = _('Ordering')
    ordering_description = _('Which field to use when ordering the results.')
    template = 'rest_framework/filters/ordering.html'

    def get_ordering(self, request, queryset, view):
        """
        Ordering is set by a comma delimited ?ordering=... query parameter.
        The `ordering` query parameter can be overridden by setting
        the `ordering_param` value on the OrderingFilter or by
        specifying an `ORDERING_PARAM` value in the API settings.
        """
        params = request.query_params.get(self.ordering_param)
        if params:
            fields = [param.strip() for param in params.split(',')]
            ordering = self.remove_invalid_fields(queryset, fields, view, request)
            if ordering:
                return ordering
        # No ordering was included, or all the ordering fields were invalid
        return self.get_default_ordering(view)

    def get_default_ordering(self, view):
        ordering = getattr(view, 'ordering', None)
        if isinstance(ordering, str):
            return (ordering,)
        return ordering

    def get_default_valid_fields(self, queryset, view, context=None):
        # A fresh dict per call instead of a shared mutable default argument.
        if context is None:
            context = {}
        # If `ordering_fields` is not specified, then we determine a default
        # based on the serializer class, if one exists on the view.
        if hasattr(view, 'get_serializer_class'):
            try:
                serializer_class = view.get_serializer_class()
            except AssertionError:
                # Raised by the default implementation if
                # no serializer_class was found
                serializer_class = None
        else:
            serializer_class = getattr(view, 'serializer_class', None)
        if serializer_class is None:
            msg = (
                "Cannot use %s on a view which does not have either a "
                "'serializer_class', an overriding 'get_serializer_class' "
                "or 'ordering_fields' attribute."
            )
            raise ImproperlyConfigured(msg % self.__class__.__name__)
        model_class = queryset.model
        model_property_names = [
            # 'pk' is a property added in Django's Model class, however it is valid for ordering.
            attr for attr in dir(model_class) if isinstance(getattr(model_class, attr), property) and attr != 'pk'
        ]
        return [
            (field.source.replace('.', '__') or field_name, field.label)
            for field_name, field in serializer_class(context=context).fields.items()
            if (
                not getattr(field, 'write_only', False) and
                not field.source == '*' and
                field.source not in model_property_names
            )
        ]

    def get_valid_fields(self, queryset, view, context=None):
        """Return ``(field, label)`` pairs that may be used for ordering."""
        if context is None:
            context = {}
        valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
        if valid_fields is None:
            # Default to allowing filtering on serializer fields
            return self.get_default_valid_fields(queryset, view, context)
        elif valid_fields == '__all__':
            # View explicitly allows filtering on any model field
            valid_fields = [
                (field.name, field.verbose_name) for field in queryset.model._meta.fields
            ]
            # Labels must be strings: 'foo__bar' -> 'Foo Bar'.
            valid_fields += [
                (key, ' '.join(key.title().split('__')))
                for key in queryset.query.annotations
            ]
        else:
            valid_fields = [
                (item, item) if isinstance(item, str) else item
                for item in valid_fields
            ]
        return valid_fields

    def remove_invalid_fields(self, queryset, fields, view, request):
        valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]

        def term_valid(term):
            # A leading '-' requests descending order; validate the bare field name.
            if term.startswith("-"):
                term = term[1:]
            return term in valid_fields

        return [term for term in fields if term_valid(term)]

    def get_filters_dict(self, request, view):
        """
        Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.

        Uses the requested ordering when given; otherwise the first entry of the view's
        ``ordering_fields`` list, or ``[]`` when no such list is configured.
        """
        param = self.ordering_param
        self.filters_dict = {}
        ordering_fields = getattr(view, 'ordering_fields', self.ordering_fields)
        if param in request.query_params:
            self.filters_dict[param] = [request.query_params.get(param)]
        elif ordering_fields and not isinstance(ordering_fields, str):
            self.filters_dict[param] = [ordering_fields[0]]
        else:
            # None or '__all__': there is no meaningful "first" ordering field.
            self.filters_dict[param] = []
        return self.filters_dict

    def filter_queryset(self, request, queryset, view):
        ordering = self.get_ordering(request, queryset, view)
        if ordering:
            return queryset.order_by(*ordering)
        return queryset

    def get_template_context(self, request, queryset, view):
        current = self.get_ordering(request, queryset, view)
        current = None if not current else current[0]
        options = []
        context = {
            'request': request,
            'current': current,
            'param': self.ordering_param,
        }
        for key, label in self.get_valid_fields(queryset, view, context):
            options.append((key, '%s - %s' % (label, _('ascending'))))
            options.append(('-' + key, '%s - %s' % (label, _('descending'))))
        context['options'] = options
        return context

    def to_html(self, request, queryset, view):
        template = loader.get_template(self.template)
        context = self.get_template_context(request, queryset, view)
        return template.render(context)

    def get_schema_fields(self, view):
        assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
        assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
        return [
            coreapi.Field(
                name=self.ordering_param,
                required=False,
                location='query',
                schema=coreschema.String(
                    title=force_str(self.ordering_title),
                    description=force_str(self.ordering_description)
                )
            )
        ]

    def get_schema_operation_parameters(self, view):
        return [
            {
                'name': self.ordering_param,
                'required': False,
                'in': 'query',
                'description': force_str(self.ordering_description),
                'schema': {
                    'type': 'string',
                },
            },
        ]