Refactored the library a bit and added a .gitignore file.
This commit is contained in:
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Ignore all the python cache files
|
||||||
|
**/__pycache__/**/*
|
||||||
|
|
||||||
|
# Ignore all the dist files
|
||||||
|
**/dist/**/*
|
||||||
File diff suppressed because it is too large
Load Diff
0
starfields_drf_generics/libfilters/__init__.py
Normal file
0
starfields_drf_generics/libfilters/__init__.py
Normal file
126
starfields_drf_generics/libfilters/category.py
Normal file
126
starfields_drf_generics/libfilters/category.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
from rest_framework.filters import BaseFilterBackend
|
||||||
|
from django.template import loader
|
||||||
|
from django.db.models import Q
|
||||||
|
|
||||||
|
|
||||||
|
class CategoryFilter(BaseFilterBackend):
|
||||||
|
"""
|
||||||
|
This filter assigns the view.category object for later use, in particular
|
||||||
|
for filters that depend on this one.
|
||||||
|
"""
|
||||||
|
template = 'filters/categories.html'
|
||||||
|
category_field = 'category'
|
||||||
|
|
||||||
|
def get_category_class(self, view, request):
|
||||||
|
return getattr(view, 'category_class', None)
|
||||||
|
|
||||||
|
def assign_view_category(self, request, view):
|
||||||
|
if not hasattr(view, self.category_field):
|
||||||
|
if self.category_field in request.query_params.keys():
|
||||||
|
try:
|
||||||
|
category_slug = request.query_params.get(
|
||||||
|
self.category_field).strip()
|
||||||
|
category = view.category_class.objects.get(
|
||||||
|
slug=category_slug)
|
||||||
|
# Append the category object in the view for later use
|
||||||
|
view.category = category
|
||||||
|
except Exception:
|
||||||
|
view.category = None
|
||||||
|
else:
|
||||||
|
view.category = None
|
||||||
|
|
||||||
|
def get_filters_dict(self, request, view):
|
||||||
|
"""
|
||||||
|
Custom method that returns the filters exclusive to this filter in a
|
||||||
|
dict. For caching purposes. Queries the database for the current
|
||||||
|
category and saves it in view.category for internal use and later
|
||||||
|
filters.
|
||||||
|
"""
|
||||||
|
if hasattr(view, 'category_field'):
|
||||||
|
self.category_field = self.get_category_field(view, request)
|
||||||
|
|
||||||
|
category_class = self.get_category_class(view, request)
|
||||||
|
|
||||||
|
assert category_class is not None, (
|
||||||
|
f"{view.__class__.__name__} should include a `category_class`"
|
||||||
|
"attribute"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the filters dictionary and find the present category instance
|
||||||
|
self.assign_view_category(request, view)
|
||||||
|
|
||||||
|
filters_dict = {}
|
||||||
|
if view.category:
|
||||||
|
filters_dict[self.category_field] = [view.category.slug]
|
||||||
|
else:
|
||||||
|
filters_dict[self.category_field] = []
|
||||||
|
|
||||||
|
return filters_dict
|
||||||
|
|
||||||
|
def filter_queryset(self, request, queryset, view):
|
||||||
|
self.assign_view_category(request, view)
|
||||||
|
|
||||||
|
# Create the queryset
|
||||||
|
if view.category:
|
||||||
|
kwquery_1 = {}
|
||||||
|
kwquery_1[self.category_field+'__id'] = view.category.id
|
||||||
|
if view.category.tn_descendants_pks:
|
||||||
|
kwquery_2 = {}
|
||||||
|
key = self.category_field+'__id__in'
|
||||||
|
kwquery_2[key] = view.category.tn_descendants_pks.split(',')
|
||||||
|
|
||||||
|
queryset = queryset.filter(Q(**kwquery_1) | Q(**kwquery_2))
|
||||||
|
else:
|
||||||
|
queryset = queryset.filter(**kwquery_1)
|
||||||
|
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
# Developer Interface methods
|
||||||
|
def get_valid_fields(self, queryset, view, context, request):
|
||||||
|
# A query is executed here to get the possible categories
|
||||||
|
category_class = self.get_category_class(view, request)
|
||||||
|
if hasattr(view, 'category_field'):
|
||||||
|
self.category_field = self.get_category_field(view, request)
|
||||||
|
|
||||||
|
assert category_class is not None, (
|
||||||
|
f"{view.__class__.__name__} should include a `category_class`"
|
||||||
|
"attribute"
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_fields = category_class.objects.all()
|
||||||
|
|
||||||
|
if len(valid_fields):
|
||||||
|
valid_fields = [
|
||||||
|
(item.slug, item.__str__()) for item in valid_fields
|
||||||
|
]
|
||||||
|
|
||||||
|
return valid_fields
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_current(self, request, queryset, view):
|
||||||
|
params = request.query_params.get(self.category_field)
|
||||||
|
if params:
|
||||||
|
fields = [param.strip() for param in params.split(',')]
|
||||||
|
return fields[0]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_template_context(self, request, queryset, view):
|
||||||
|
current = self.get_current(request, queryset, view)
|
||||||
|
options = []
|
||||||
|
context = {
|
||||||
|
'request': request,
|
||||||
|
'current': current,
|
||||||
|
'param': self.category_field,
|
||||||
|
}
|
||||||
|
valid_fields = self.get_valid_fields(queryset, view, context, request)
|
||||||
|
for key, label in valid_fields:
|
||||||
|
options.append((key, '%s' % (label)))
|
||||||
|
context['options'] = options
|
||||||
|
return context
|
||||||
|
|
||||||
|
def to_html(self, request, queryset, view):
|
||||||
|
template = loader.get_template(self.template)
|
||||||
|
context = self.get_template_context(request, queryset, view)
|
||||||
|
return template.render(context)
|
||||||
170
starfields_drf_generics/libfilters/facet.py
Normal file
170
starfields_drf_generics/libfilters/facet.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
from rest_framework.filters import BaseFilterBackend
|
||||||
|
from django.template import loader
|
||||||
|
from django.db.models import Q
|
||||||
|
|
||||||
|
|
||||||
|
class FacetFilter(BaseFilterBackend):
|
||||||
|
"""
|
||||||
|
This filter requires CategoryFilter to be ran before it. It assigns the
|
||||||
|
view.facets which includes all the facets applicable to the current
|
||||||
|
category.
|
||||||
|
"""
|
||||||
|
template = 'filters/facets.html'
|
||||||
|
|
||||||
|
def get_facet_class(self, view, request):
|
||||||
|
return getattr(view, 'facet_class', None)
|
||||||
|
|
||||||
|
def get_facet_tag_class(self, view, request):
|
||||||
|
return getattr(view, 'facet_tag_class', None)
|
||||||
|
|
||||||
|
def get_facet_tag_field(self, view, request):
|
||||||
|
return getattr(view, 'facet_tag_field', None)
|
||||||
|
|
||||||
|
def assign_view_facets(self, request, view):
|
||||||
|
if not hasattr(view, 'facets'):
|
||||||
|
if hasattr(view, 'facet_class'):
|
||||||
|
self.facet_class = self.get_facet_class(view, request)
|
||||||
|
|
||||||
|
assert self.facet_class is not None, (
|
||||||
|
f"{view.__class__.__name__} should include a `facet_class`"
|
||||||
|
"attribute"
|
||||||
|
)
|
||||||
|
|
||||||
|
if view.category:
|
||||||
|
if view.category.tn_ancestors_pks:
|
||||||
|
ancestor_ids = view.category.tn_ancestors_pks.split(',')
|
||||||
|
view.facets = self.facet_class.objects.filter(
|
||||||
|
Q(category__id=view.category.id) | Q(
|
||||||
|
category__id__in=ancestor_ids)
|
||||||
|
).prefetch_related('facet_tags')
|
||||||
|
else:
|
||||||
|
view.facets = self.facet_class.objects.filter(
|
||||||
|
category__id=view.category.id).prefetch_related(
|
||||||
|
'facet_tags')
|
||||||
|
else:
|
||||||
|
view.facets = self.facet_class.objects.filter(
|
||||||
|
category__tn_level=1).prefetch_related(
|
||||||
|
'facet_tags')
|
||||||
|
|
||||||
|
def get_filters_dict(self, request, view):
|
||||||
|
"""
|
||||||
|
Custom method that returns the filters exclusive to this filter in a
|
||||||
|
dict. For caching purposes.
|
||||||
|
"""
|
||||||
|
if hasattr(view, 'facet_class'):
|
||||||
|
self.facet_class = self.get_facet_class(view, request)
|
||||||
|
|
||||||
|
assert self.facet_class is not None, (
|
||||||
|
f"{view.__class__.__name__} should include a `facet_class`"
|
||||||
|
"attribute"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assign_view_facets(request, view)
|
||||||
|
|
||||||
|
filters_dict = {}
|
||||||
|
if view.facets:
|
||||||
|
for facet in view.facets:
|
||||||
|
if facet.slug in request.query_params.keys():
|
||||||
|
filters_dict[facet.slug] = set(
|
||||||
|
request.query_params[facet.slug].split(','))
|
||||||
|
else:
|
||||||
|
filters_dict[facet.slug] = set({})
|
||||||
|
|
||||||
|
# Append the facets object and the tags dict in the view for later
|
||||||
|
# reference
|
||||||
|
view.tags = filters_dict
|
||||||
|
|
||||||
|
return filters_dict
|
||||||
|
|
||||||
|
def filter_queryset(self, request, queryset, view):
|
||||||
|
if hasattr(view, 'facet_tag_class'):
|
||||||
|
self.facet_tag_class = self.get_facet_tag_class(view, request)
|
||||||
|
|
||||||
|
assert self.facet_tag_class is not None, (
|
||||||
|
f"{view.__class__.__name__} should include a `facet_tag_class`"
|
||||||
|
"attribute"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(view, 'facet_tag_field'):
|
||||||
|
self.facet_tag_field = self.get_facet_tag_field(view, request)
|
||||||
|
|
||||||
|
assert self.facet_tag_field is not None, (
|
||||||
|
f"{view.__class__.__name__} should include a `facet_tag_field`"
|
||||||
|
"attribute"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assign_view_facets(request, view)
|
||||||
|
|
||||||
|
if view.facets:
|
||||||
|
for facet in view.facets:
|
||||||
|
if facet.slug in request.query_params.keys():
|
||||||
|
tag_filterlist = request.query_params.get(facet.slug)
|
||||||
|
if tag_filterlist == '':
|
||||||
|
# If the tag filterlist is empty then we're not
|
||||||
|
# filtering against it, it's like having all the tags
|
||||||
|
# of the facet selected
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
kwquery = {}
|
||||||
|
key = self.facet_tag_field+'__slug__in'
|
||||||
|
kwquery[key] = tag_filterlist.replace(' ', '').split(
|
||||||
|
',')
|
||||||
|
queryset = queryset.filter(**kwquery)
|
||||||
|
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
# Developer Interface methods
|
||||||
|
def get_template_context(self, request, queryset, view):
|
||||||
|
# Does aggressive database querying to get the necessary facets and
|
||||||
|
# facettags, but this is only for the developer interface so its fine
|
||||||
|
if hasattr(view, 'facet_class'):
|
||||||
|
self.facet_class = self.get_facet_class(view, request)
|
||||||
|
|
||||||
|
assert self.facet_class is not None, (
|
||||||
|
f"{view.__class__.__name__} should include a `facet_class`"
|
||||||
|
"attribute"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(view, 'facet_tag_class'):
|
||||||
|
self.facet_tag_class = self.get_facet_tag_class(view, request)
|
||||||
|
|
||||||
|
assert self.facet_tag_class is not None, (
|
||||||
|
f"{view.__class__.__name__} should include a `facet_tag_class`"
|
||||||
|
"attribute"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find the current choices
|
||||||
|
current = []
|
||||||
|
facet_slugs = []
|
||||||
|
if view.facets:
|
||||||
|
for facet in view.facets:
|
||||||
|
facet_slugs.append(facet.slug)
|
||||||
|
if facet.slug in request.query_params.keys():
|
||||||
|
current.append(request.query_params.get(facet.slug))
|
||||||
|
|
||||||
|
facet_slug_names = {}
|
||||||
|
options = {}
|
||||||
|
context = {
|
||||||
|
'request': request,
|
||||||
|
'current': current,
|
||||||
|
'facet_slugs': facet_slugs,
|
||||||
|
}
|
||||||
|
if view.facets:
|
||||||
|
for facet in view.facets:
|
||||||
|
facet_tag_instances = self.facet_tag_class.objects.filter(
|
||||||
|
facet__slug=facet.slug)
|
||||||
|
options[facet.slug] = [(facet_tag.slug, facet_tag.name) for
|
||||||
|
facet_tag in facet_tag_instances]
|
||||||
|
facet_slug_names[facet.slug] = facet.name
|
||||||
|
context['facet_slug_names'] = facet_slug_names
|
||||||
|
context['options'] = options
|
||||||
|
else:
|
||||||
|
context['facet_slug_names'] = {}
|
||||||
|
context['options'] = {}
|
||||||
|
|
||||||
|
return context
|
||||||
|
|
||||||
|
def to_html(self, request, queryset, view):
|
||||||
|
template = loader.get_template(self.template)
|
||||||
|
context = self.get_template_context(request, queryset, view)
|
||||||
|
return template.render(context)
|
||||||
38
starfields_drf_generics/libfilters/lessthanorequal.py
Normal file
38
starfields_drf_generics/libfilters/lessthanorequal.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from rest_framework.filters import BaseFilterBackend
|
||||||
|
|
||||||
|
|
||||||
|
class LessThanOrEqualFilter(BaseFilterBackend):
|
||||||
|
def get_less_than_field(self, view, request):
|
||||||
|
return getattr(view, 'less_than_field', None)
|
||||||
|
|
||||||
|
def get_filters_dict(self, request, view):
|
||||||
|
"""
|
||||||
|
Custom method that returns the filters exclusive to this filter in a
|
||||||
|
dict. For caching purposes.
|
||||||
|
"""
|
||||||
|
less_than_field = self.get_less_than_field(view, request)
|
||||||
|
|
||||||
|
assert less_than_field is not None, (
|
||||||
|
f"{view.__class__.__name__} should include a `less_than_field`"
|
||||||
|
"attribute"
|
||||||
|
)
|
||||||
|
|
||||||
|
filters_dict = {}
|
||||||
|
if less_than_field+'_max' in request.query_params.keys():
|
||||||
|
field_value = request.query_params.get(less_than_field+'_max')
|
||||||
|
filters_dict[less_than_field+'_max'] = [field_value]
|
||||||
|
else:
|
||||||
|
filters_dict[less_than_field+'_max'] = []
|
||||||
|
return filters_dict
|
||||||
|
|
||||||
|
def filter_queryset(self, request, queryset, view):
|
||||||
|
# Return the correctly filtered queryset but also assign the filter
|
||||||
|
# dict to create the unique url for the cache
|
||||||
|
less_than_field = self.get_less_than_field(view, request)
|
||||||
|
if less_than_field+'_max' in request.query_params.keys():
|
||||||
|
kwquery = {}
|
||||||
|
field_value = request.query_params.get(less_than_field+'_max')
|
||||||
|
kwquery[less_than_field+'__lte'] = field_value
|
||||||
|
return queryset.filter(**kwquery)
|
||||||
|
else:
|
||||||
|
return queryset
|
||||||
39
starfields_drf_generics/libfilters/morethanorequal.py
Normal file
39
starfields_drf_generics/libfilters/morethanorequal.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from rest_framework.filters import BaseFilterBackend
|
||||||
|
|
||||||
|
|
||||||
|
class MoreThanOrEqualFilter(BaseFilterBackend):
|
||||||
|
def get_more_than_field(self, view, request):
|
||||||
|
return getattr(view, 'more_than_field', None)
|
||||||
|
|
||||||
|
def get_filters_dict(self, request, view):
|
||||||
|
"""
|
||||||
|
Custom method that returns the filters exclusive to this filter in a
|
||||||
|
dict. For caching purposes.
|
||||||
|
"""
|
||||||
|
more_than_field = self.get_more_than_field(view, request)
|
||||||
|
|
||||||
|
assert more_than_field is not None, (
|
||||||
|
f"{view.__class__.__name__} should include a `more_than_field`"
|
||||||
|
"attribute"
|
||||||
|
)
|
||||||
|
|
||||||
|
filters_dict = {}
|
||||||
|
if more_than_field+'_min' in request.query_params.keys():
|
||||||
|
field_value = request.query_params.get(more_than_field+'_min')
|
||||||
|
filters_dict[more_than_field+'_min'] = [field_value]
|
||||||
|
else:
|
||||||
|
filters_dict[more_than_field+'_min'] = []
|
||||||
|
return filters_dict
|
||||||
|
|
||||||
|
def filter_queryset(self, request, queryset, view):
|
||||||
|
"""
|
||||||
|
Return the correctly filtered queryset
|
||||||
|
"""
|
||||||
|
more_than_field = self.get_more_than_field(view, request)
|
||||||
|
if more_than_field+'_min' in request.query_params.keys():
|
||||||
|
kwquery = {}
|
||||||
|
kwquery[more_than_field+'__gte'] = request.query_params.get(
|
||||||
|
more_than_field+'_min')
|
||||||
|
return queryset.filter(**kwquery)
|
||||||
|
else:
|
||||||
|
return queryset
|
||||||
68
starfields_drf_generics/libfilters/nodetreebranch.py
Normal file
68
starfields_drf_generics/libfilters/nodetreebranch.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
from rest_framework.settings import api_settings
|
||||||
|
from rest_framework.filters import BaseFilterBackend
|
||||||
|
from django.db.models import Q
|
||||||
|
|
||||||
|
|
||||||
|
class TreeNodeBranchFilter(BaseFilterBackend):
|
||||||
|
def get_descendants_of_field(self, view, request):
|
||||||
|
return getattr(view, 'descendants_of_field', None)
|
||||||
|
|
||||||
|
def get_depth_field(self, view, request):
|
||||||
|
return getattr(view, 'depth_field', None)
|
||||||
|
|
||||||
|
def get_filters_dict(self, request, view):
|
||||||
|
"""
|
||||||
|
Custom method that returns the filters exclusive to this filter in a
|
||||||
|
dict. For caching purposes.
|
||||||
|
"""
|
||||||
|
descendants_of_field = self.get_descendants_of_field(view, request)
|
||||||
|
|
||||||
|
assert descendants_of_field is not None, (
|
||||||
|
"{view.__class__.__name__} should include a "
|
||||||
|
f"`descendants_of_field` attribute"
|
||||||
|
)
|
||||||
|
|
||||||
|
depth_field = self.get_depth_field(view, request)
|
||||||
|
|
||||||
|
assert depth_field is not None, (
|
||||||
|
"{view.__class__.__name__} should include a "
|
||||||
|
f"`depth_field` attribute"
|
||||||
|
)
|
||||||
|
|
||||||
|
filters_dict = {}
|
||||||
|
if descendants_of_field in request.query_params.keys():
|
||||||
|
field_value = request.query_params.get(descendants_of_field)
|
||||||
|
filters_dict[descendants_of_field] = [field_value]
|
||||||
|
else:
|
||||||
|
filters_dict[descendants_of_field] = []
|
||||||
|
|
||||||
|
if depth_field in request.query_params.keys():
|
||||||
|
field_value = request.query_params.get(depth_field)
|
||||||
|
filters_dict[depth_field] = [field_value]
|
||||||
|
else:
|
||||||
|
filters_dict[depth_field] = [4]
|
||||||
|
|
||||||
|
return filters_dict
|
||||||
|
|
||||||
|
def filter_queryset(self, request, queryset, view):
|
||||||
|
# Return the correctly filtered queryset but also assign the filter
|
||||||
|
# dict to create the unique url for the cache
|
||||||
|
descendants_of_field = self.get_descendants_of_field(view, request)
|
||||||
|
depth_field = self.get_depth_field(view, request)
|
||||||
|
|
||||||
|
if descendants_of_field in request.query_params.keys():
|
||||||
|
field_value = request.query_params.get(descendants_of_field)
|
||||||
|
# Instead of doing two queries to get the descendants through the
|
||||||
|
# object a single more complex queryset
|
||||||
|
queryset = queryset.filter(
|
||||||
|
Q(tn_ancestors_pks=field_value) |
|
||||||
|
Q(tn_ancestors_pks__contains=","+field_value) |
|
||||||
|
Q(tn_ancestors_pks__contains=field_value+","))
|
||||||
|
|
||||||
|
if depth_field in request.query_params.keys():
|
||||||
|
field_value = request.query_params.get(depth_field)
|
||||||
|
queryset = queryset.filter(tn_level__lte=field_value)
|
||||||
|
else:
|
||||||
|
queryset = queryset.filter(tn_level__lte=4)
|
||||||
|
|
||||||
|
return queryset
|
||||||
180
starfields_drf_generics/libfilters/ordering.py
Normal file
180
starfields_drf_generics/libfilters/ordering.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
from rest_framework.settings import api_settings
|
||||||
|
from rest_framework.filters import BaseFilterBackend
|
||||||
|
from django.template import loader
|
||||||
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
|
|
||||||
|
class OrderingFilter(BaseFilterBackend):
|
||||||
|
# The URL query parameter used for the ordering.
|
||||||
|
ordering_param = api_settings.ORDERING_PARAM
|
||||||
|
ordering_fields = None
|
||||||
|
ordering_title = _('Ordering')
|
||||||
|
ordering_description = _('Which field to use when ordering the results.')
|
||||||
|
template = 'rest_framework/filters/ordering.html'
|
||||||
|
|
||||||
|
def get_ordering(self, request, queryset, view):
|
||||||
|
"""
|
||||||
|
Ordering is set by a comma delimited ?ordering=... query parameter.
|
||||||
|
|
||||||
|
The `ordering` query parameter can be overridden by setting
|
||||||
|
the `ordering_param` value on the OrderingFilter or by
|
||||||
|
specifying an `ORDERING_PARAM` value in the API settings.
|
||||||
|
"""
|
||||||
|
params = request.query_params.get(self.ordering_param)
|
||||||
|
if params:
|
||||||
|
fields = [param.strip() for param in params.split(',')]
|
||||||
|
ordering = self.remove_invalid_fields(queryset,
|
||||||
|
fields,
|
||||||
|
view,
|
||||||
|
request)
|
||||||
|
if ordering:
|
||||||
|
return ordering
|
||||||
|
|
||||||
|
# No ordering was included, or all the ordering fields were invalid
|
||||||
|
return self.get_default_ordering(view)
|
||||||
|
|
||||||
|
def get_default_ordering(self, view):
|
||||||
|
ordering = getattr(view, 'ordering', None)
|
||||||
|
if isinstance(ordering, str):
|
||||||
|
return (ordering,)
|
||||||
|
return ordering
|
||||||
|
|
||||||
|
def get_default_valid_fields(self, queryset, view, context={}):
|
||||||
|
# If `ordering_fields` is not specified, then we determine a default
|
||||||
|
# based on the serializer class, if one exists on the view.
|
||||||
|
if hasattr(view, 'get_serializer_class'):
|
||||||
|
try:
|
||||||
|
serializer_class = view.get_serializer_class()
|
||||||
|
except AssertionError:
|
||||||
|
# Raised by the default implementation if
|
||||||
|
# no serializer_class was found
|
||||||
|
serializer_class = None
|
||||||
|
else:
|
||||||
|
serializer_class = getattr(view, 'serializer_class', None)
|
||||||
|
|
||||||
|
if serializer_class is None:
|
||||||
|
msg = (
|
||||||
|
"Cannot use %s on a view which does not have either a "
|
||||||
|
"'serializer_class', an overriding 'get_serializer_class' "
|
||||||
|
"or 'ordering_fields' attribute."
|
||||||
|
)
|
||||||
|
raise ImproperlyConfigured(msg % self.__class__.__name__)
|
||||||
|
|
||||||
|
model_class = queryset.model
|
||||||
|
model_property_names = [
|
||||||
|
# 'pk' is a property added in Django's Model class, however it is valid for ordering.
|
||||||
|
attr for attr in dir(model_class) if isinstance(getattr(model_class, attr), property) and attr != 'pk'
|
||||||
|
]
|
||||||
|
|
||||||
|
return [
|
||||||
|
(field.source.replace('.', '__') or field_name, field.label)
|
||||||
|
for field_name, field in serializer_class(context=context).fields.items()
|
||||||
|
if (
|
||||||
|
not getattr(field, 'write_only', False) and
|
||||||
|
not field.source == '*' and
|
||||||
|
field.source not in model_property_names
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_valid_fields(self, queryset, view, context={}):
|
||||||
|
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
|
||||||
|
|
||||||
|
if valid_fields is None:
|
||||||
|
# Default to allowing filtering on serializer fields
|
||||||
|
return self.get_default_valid_fields(queryset, view, context)
|
||||||
|
|
||||||
|
elif valid_fields == '__all__':
|
||||||
|
# View explicitly allows filtering on any model field
|
||||||
|
valid_fields = [
|
||||||
|
(field.name, field.verbose_name) for field in queryset.model._meta.fields
|
||||||
|
]
|
||||||
|
valid_fields += [
|
||||||
|
(key, key.title().split('__'))
|
||||||
|
for key in queryset.query.annotations
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
valid_fields = [
|
||||||
|
(item, item) if isinstance(item, str) else item
|
||||||
|
for item in valid_fields
|
||||||
|
]
|
||||||
|
|
||||||
|
return valid_fields
|
||||||
|
|
||||||
|
def remove_invalid_fields(self, queryset, fields, view, request):
|
||||||
|
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]
|
||||||
|
|
||||||
|
def term_valid(term):
|
||||||
|
if term.startswith("-"):
|
||||||
|
term = term[1:]
|
||||||
|
return term in valid_fields
|
||||||
|
|
||||||
|
return [term for term in fields if term_valid(term)]
|
||||||
|
|
||||||
|
def get_filters_dict(self, request, view):
|
||||||
|
"""
|
||||||
|
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
||||||
|
"""
|
||||||
|
self.filters_dict = {}
|
||||||
|
if 'ordering' in request.query_params.keys():
|
||||||
|
slug_term = request.query_params.get('ordering')
|
||||||
|
self.filters_dict['ordering'] = [slug_term]
|
||||||
|
else:
|
||||||
|
self.filters_dict['ordering'] = [view.ordering_fields[0]]
|
||||||
|
|
||||||
|
return self.filters_dict
|
||||||
|
|
||||||
|
def filter_queryset(self, request, queryset, view):
|
||||||
|
ordering = self.get_ordering(request, queryset, view)
|
||||||
|
|
||||||
|
if ordering:
|
||||||
|
return queryset.order_by(*ordering)
|
||||||
|
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
def get_template_context(self, request, queryset, view):
|
||||||
|
current = self.get_ordering(request, queryset, view)
|
||||||
|
current = None if not current else current[0]
|
||||||
|
options = []
|
||||||
|
context = {
|
||||||
|
'request': request,
|
||||||
|
'current': current,
|
||||||
|
'param': self.ordering_param,
|
||||||
|
}
|
||||||
|
for key, label in self.get_valid_fields(queryset, view, context):
|
||||||
|
options.append((key, '%s - %s' % (label, _('ascending'))))
|
||||||
|
options.append(('-' + key, '%s - %s' % (label, _('descending'))))
|
||||||
|
context['options'] = options
|
||||||
|
return context
|
||||||
|
|
||||||
|
def to_html(self, request, queryset, view):
|
||||||
|
template = loader.get_template(self.template)
|
||||||
|
context = self.get_template_context(request, queryset, view)
|
||||||
|
return template.render(context)
|
||||||
|
|
||||||
|
def get_schema_fields(self, view):
|
||||||
|
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||||
|
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||||
|
return [
|
||||||
|
coreapi.Field(
|
||||||
|
name=self.ordering_param,
|
||||||
|
required=False,
|
||||||
|
location='query',
|
||||||
|
schema=coreschema.String(
|
||||||
|
title=force_str(self.ordering_title),
|
||||||
|
description=force_str(self.ordering_description)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_schema_operation_parameters(self, view):
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
'name': self.ordering_param,
|
||||||
|
'required': False,
|
||||||
|
'in': 'query',
|
||||||
|
'description': force_str(self.ordering_description),
|
||||||
|
'schema': {
|
||||||
|
'type': 'string',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
247
starfields_drf_generics/libfilters/trigramsearch.py
Normal file
247
starfields_drf_generics/libfilters/trigramsearch.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
from django.utils.text import smart_split
|
||||||
|
from django.core.exceptions import FieldError
|
||||||
|
import operator
|
||||||
|
from functools import reduce
|
||||||
|
from rest_framework.filters import BaseFilterBackend
|
||||||
|
from django.template import loader
|
||||||
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
from django.db import models
|
||||||
|
from django.db.models import Q
|
||||||
|
from rest_framework.fields import CharField
|
||||||
|
from django.db.models.constants import LOOKUP_SEP
|
||||||
|
from django.db.models.functions import Concat
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_threshold(query, min_threshold, max_threshold):
|
||||||
|
query_threshold = len(query)/300
|
||||||
|
if query_threshold < min_threshold:
|
||||||
|
return min_threshold
|
||||||
|
if max_threshold < query_threshold:
|
||||||
|
return max_threshold
|
||||||
|
return query_threshold
|
||||||
|
|
||||||
|
|
||||||
|
def search_smart_split(search_terms):
|
||||||
|
"""
|
||||||
|
Generator that first splits string by spaces, leaving quoted phrases
|
||||||
|
together, then it splits non-quoted phrases by commas.
|
||||||
|
"""
|
||||||
|
split_terms = []
|
||||||
|
for term in smart_split(search_terms):
|
||||||
|
# trim commas to avoid bad matching for quoted phrases
|
||||||
|
term = term.strip(',')
|
||||||
|
if term.startswith(('"', "'")) and term[0] == term[-1]:
|
||||||
|
# quoted phrases are kept together without any other split
|
||||||
|
split_terms.append(unescape_string_literal(term))
|
||||||
|
else:
|
||||||
|
# non-quoted tokens are split by comma, keeping only non-empty ones
|
||||||
|
for sub_term in term.split(','):
|
||||||
|
if sub_term:
|
||||||
|
split_terms.append(sub_term.strip())
|
||||||
|
return split_terms
|
||||||
|
|
||||||
|
|
||||||
|
class TrigramSearchFilter(BaseFilterBackend):
|
||||||
|
# The URL query parameter used for the search.
|
||||||
|
search_param = 'search'
|
||||||
|
template = 'rest_framework/filters/search.html'
|
||||||
|
search_title = _('Search')
|
||||||
|
search_description = _('A search string to perform trigram similarity'
|
||||||
|
'based searching with.')
|
||||||
|
lookup_prefixes = {
|
||||||
|
'^': 'istartswith',
|
||||||
|
'=': 'iexact',
|
||||||
|
'@': 'search',
|
||||||
|
'$': 'iregex',
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_filters_dict(self, request, view):
|
||||||
|
"""
|
||||||
|
Custom method that returns the filters exclusive to this filter in a
|
||||||
|
dict. For caching purposes.
|
||||||
|
"""
|
||||||
|
self.filters_dict = {}
|
||||||
|
if 'search' in request.query_params.keys():
|
||||||
|
slug_term = request.query_params.get('search')
|
||||||
|
self.filters_dict['search'] = [slug_term]
|
||||||
|
else:
|
||||||
|
self.filters_dict['search'] = []
|
||||||
|
|
||||||
|
return self.filters_dict
|
||||||
|
|
||||||
|
def get_search_fields(self, view, request):
|
||||||
|
"""
|
||||||
|
Search fields are obtained from the view, but the request is always
|
||||||
|
passed to this method. Sub-classes can override this method to
|
||||||
|
dynamically change the search fields based on request content.
|
||||||
|
"""
|
||||||
|
return getattr(view, 'search_fields', None)
|
||||||
|
|
||||||
|
def get_search_query(self, request):
|
||||||
|
"""
|
||||||
|
Search terms are set by a ?search=... query parameter,
|
||||||
|
and may be whitespace delimited.
|
||||||
|
"""
|
||||||
|
value = request.query_params.get(self.search_param, '')
|
||||||
|
field = CharField(trim_whitespace=False, allow_blank=True)
|
||||||
|
cleaned_value = field.run_validation(value)
|
||||||
|
return cleaned_value
|
||||||
|
|
||||||
|
def construct_search(self, field_name, queryset):
|
||||||
|
"""
|
||||||
|
For the sqlite search
|
||||||
|
"""
|
||||||
|
lookup = self.lookup_prefixes.get(field_name[0])
|
||||||
|
if lookup:
|
||||||
|
field_name = field_name[1:]
|
||||||
|
else:
|
||||||
|
# Use field_name if it includes a lookup.
|
||||||
|
opts = queryset.model._meta
|
||||||
|
lookup_fields = field_name.split(LOOKUP_SEP)
|
||||||
|
# Go through the fields, following all relations.
|
||||||
|
prev_field = None
|
||||||
|
for path_part in lookup_fields:
|
||||||
|
if path_part == "pk":
|
||||||
|
path_part = opts.pk.name
|
||||||
|
try:
|
||||||
|
field = opts.get_field(path_part)
|
||||||
|
except FieldDoesNotExist:
|
||||||
|
# Use valid query lookups.
|
||||||
|
if prev_field and prev_field.get_lookup(path_part):
|
||||||
|
return field_name
|
||||||
|
else:
|
||||||
|
prev_field = field
|
||||||
|
if hasattr(field, "path_infos"):
|
||||||
|
# Update opts to follow the relation.
|
||||||
|
opts = field.path_infos[-1].to_opts
|
||||||
|
# django < 4.1
|
||||||
|
elif hasattr(field, 'get_path_info'):
|
||||||
|
# Update opts to follow the relation.
|
||||||
|
opts = field.get_path_info()[-1].to_opts
|
||||||
|
# Otherwise, use the field with icontains.
|
||||||
|
lookup = 'icontains'
|
||||||
|
return LOOKUP_SEP.join([field_name, lookup])
|
||||||
|
|
||||||
|
def must_call_distinct(self, queryset, search_fields):
|
||||||
|
"""
|
||||||
|
Return True if 'distinct()' should be used to query the given lookups.
|
||||||
|
"""
|
||||||
|
for search_field in search_fields:
|
||||||
|
opts = queryset.model._meta
|
||||||
|
if search_field[0] in self.lookup_prefixes:
|
||||||
|
search_field = search_field[1:]
|
||||||
|
# Annotated fields do not need to be distinct
|
||||||
|
if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations:
|
||||||
|
continue
|
||||||
|
parts = search_field.split(LOOKUP_SEP)
|
||||||
|
for part in parts:
|
||||||
|
field = opts.get_field(part)
|
||||||
|
if hasattr(field, 'get_path_info'):
|
||||||
|
# This field is a relation, update opts to follow the relation
|
||||||
|
path_info = field.get_path_info()
|
||||||
|
opts = path_info[-1].to_opts
|
||||||
|
if any(path.m2m for path in path_info):
|
||||||
|
# This field is a m2m relation so we know we need to call distinct
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
# This field has a custom __ query transform but is not a relational field.
|
||||||
|
break
|
||||||
|
return False
|
||||||
|
|
||||||
|
def filter_queryset(self, request, queryset, view):
|
||||||
|
search_fields = self.get_search_fields(view, request)
|
||||||
|
|
||||||
|
assert search_fields is not None, (
|
||||||
|
f"{view.__class__.__name__} should include a `search_fields`"
|
||||||
|
"attribute"
|
||||||
|
)
|
||||||
|
|
||||||
|
query = self.get_search_query(request)
|
||||||
|
|
||||||
|
if not query:
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Attempt postgresql's full text search
|
||||||
|
from django.contrib.postgres.search import TrigramStrictWordSimilarity
|
||||||
|
threshold = calculate_threshold(query, 0.02, 0.12)
|
||||||
|
queryset = queryset.annotate(
|
||||||
|
search_field=Concat(
|
||||||
|
*search_fields,
|
||||||
|
output_field=CharField()
|
||||||
|
)).annotate(
|
||||||
|
similarity=TrigramStrictWordSimilarity(
|
||||||
|
'search_field', query)
|
||||||
|
).filter(similarity__gt=threshold)
|
||||||
|
|
||||||
|
# NOTE a weird FieldError is raised on sqlite
|
||||||
|
except (ImportError, FieldError):
|
||||||
|
# Perform very simple sqlite compatible search
|
||||||
|
search_terms = search_smart_split(query)
|
||||||
|
|
||||||
|
orm_lookups = [
|
||||||
|
self.construct_search(str(search_field), queryset)
|
||||||
|
for search_field in search_fields
|
||||||
|
]
|
||||||
|
|
||||||
|
base = queryset
|
||||||
|
# generator which for each term builds the corresponding search
|
||||||
|
conditions = (
|
||||||
|
reduce(
|
||||||
|
operator.or_,
|
||||||
|
(models.Q(**{orm_lookup: term}) for orm_lookup in orm_lookups)
|
||||||
|
) for term in search_terms
|
||||||
|
)
|
||||||
|
queryset = queryset.filter(reduce(operator.and_, conditions))
|
||||||
|
|
||||||
|
# Remove duplicates from results, if necessary
|
||||||
|
if self.must_call_distinct(queryset, search_fields):
|
||||||
|
# inspired by django.contrib.admin
|
||||||
|
# this is more accurate than .distinct form M2M relationship
|
||||||
|
# also is cross-database
|
||||||
|
queryset = queryset.filter(pk=models.OuterRef('pk'))
|
||||||
|
queryset = base.filter(models.Exists(queryset))
|
||||||
|
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
def to_html(self, request, queryset, view):
|
||||||
|
if not getattr(view, 'search_fields', None):
|
||||||
|
return ''
|
||||||
|
|
||||||
|
term = request.query_params.get(self.search_param, '')
|
||||||
|
term = term[0] if term else ''
|
||||||
|
context = {
|
||||||
|
'param': self.search_param,
|
||||||
|
'term': term
|
||||||
|
}
|
||||||
|
template = loader.get_template(self.template)
|
||||||
|
return template.render(context)
|
||||||
|
|
||||||
|
def get_schema_fields(self, view):
|
||||||
|
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||||
|
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
|
||||||
|
return [
|
||||||
|
coreapi.Field(
|
||||||
|
name=self.search_param,
|
||||||
|
required=False,
|
||||||
|
location='query',
|
||||||
|
schema=coreschema.String(
|
||||||
|
title=force_str(self.search_title),
|
||||||
|
description=force_str(self.search_description)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_schema_operation_parameters(self, view):
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
'name': self.search_param,
|
||||||
|
'required': False,
|
||||||
|
'in': 'query',
|
||||||
|
'description': force_str(self.search_description),
|
||||||
|
'schema': {
|
||||||
|
'type': 'string',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
@@ -109,7 +109,7 @@ class CachedListCreateModelMixin(CacheDeleteMixin):
|
|||||||
A fully custom mixin that handles mutiple instance cration.
|
A fully custom mixin that handles mutiple instance cration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def list_create(self, request):
|
def list_create(self, request, **kwargs):
|
||||||
" Creates the list of entries in the request "
|
" Creates the list of entries in the request "
|
||||||
# Go on with the creation as normal
|
# Go on with the creation as normal
|
||||||
serializer = self.get_serializer(data=request.data, many=True)
|
serializer = self.get_serializer(data=request.data, many=True)
|
||||||
@@ -142,7 +142,7 @@ class CachedListRetrieveModelMixin(CacheGetMixin, CacheSetMixin):
|
|||||||
inherit anything from it.
|
inherit anything from it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def list(self, request):
|
def list(self, request, **kwargs):
|
||||||
" Retrieves the listing of entries "
|
" Retrieves the listing of entries "
|
||||||
# Attempt to get the request from the cache
|
# Attempt to get the request from the cache
|
||||||
cache_attempt = self.get_cache(request)
|
cache_attempt = self.get_cache(request)
|
||||||
@@ -206,7 +206,7 @@ class CachedListDestroyModelMixin(CacheDeleteMixin):
|
|||||||
A fully custom mixin that handles mutiple instance deletions.
|
A fully custom mixin that handles mutiple instance deletions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def list_destroy(self, request):
|
def list_destroy(self, request, **kwargs):
|
||||||
" Deletes the list of entries in the request "
|
" Deletes the list of entries in the request "
|
||||||
# Go on with the validation as normal
|
# Go on with the validation as normal
|
||||||
serializer = self.get_serializer(data=request.data, many=True)
|
serializer = self.get_serializer(data=request.data, many=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user