From 1e0f3c694e0293be9affb225e8c19cd6da72c149 Mon Sep 17 00:00:00 2001 From: Anastasios Svolis Date: Sun, 5 Jan 2025 10:13:59 +0200 Subject: [PATCH] Refactored the library a bit and added a .gitignore file. --- .gitignore | 5 + starfields_drf_generics/filters.py | 1249 ++++------------- .../libfilters/__init__.py | 0 .../libfilters/category.py | 126 ++ starfields_drf_generics/libfilters/facet.py | 170 +++ .../libfilters/lessthanorequal.py | 38 + .../libfilters/morethanorequal.py | 39 + .../libfilters/nodetreebranch.py | 68 + .../libfilters/ordering.py | 180 +++ .../libfilters/trigramsearch.py | 247 ++++ starfields_drf_generics/mixins.py | 6 +- 11 files changed, 1135 insertions(+), 993 deletions(-) create mode 100644 .gitignore create mode 100644 starfields_drf_generics/libfilters/__init__.py create mode 100644 starfields_drf_generics/libfilters/category.py create mode 100644 starfields_drf_generics/libfilters/facet.py create mode 100644 starfields_drf_generics/libfilters/lessthanorequal.py create mode 100644 starfields_drf_generics/libfilters/morethanorequal.py create mode 100644 starfields_drf_generics/libfilters/nodetreebranch.py create mode 100644 starfields_drf_generics/libfilters/ordering.py create mode 100644 starfields_drf_generics/libfilters/trigramsearch.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..23c22c4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +# Ignore all the python cache files +**/__pycache__/**/* + +# Ignore all the dist files +**/dist/**/* diff --git a/starfields_drf_generics/filters.py b/starfields_drf_generics/filters.py index 6405db2..2dd3875 100644 --- a/starfields_drf_generics/filters.py +++ b/starfields_drf_generics/filters.py @@ -1,990 +1,259 @@ -from django.utils.text import smart_split -from django.core.exceptions import FieldError -import operator -from functools import reduce -from rest_framework.settings import api_settings -from rest_framework.filters import BaseFilterBackend -from django.template import loader -from django.utils.translation import gettext_lazy as _ -from django.db import models -from django.db.models import Q -from rest_framework.fields import CharField -from django.db.models.constants import LOOKUP_SEP -from django.db.models.functions import Concat - - -# TODO the dev pages are not done - -def calculate_threshold(query, min_threshold, max_threshold): - query_threshold = len(query)/300 - if query_threshold < min_threshold: - return min_threshold - if max_threshold < query_threshold: - return max_threshold - return query_threshold - - -def search_smart_split(search_terms): - """generator that first splits string by spaces, leaving quoted phrases together, - then it splits non-quoted phrases by commas. - """ - split_terms = [] - for term in smart_split(search_terms): - # trim commas to avoid bad matching for quoted phrases - term = term.strip(',') - if term.startswith(('"', "'")) and term[0] == term[-1]: - # quoted phrases are kept together without any other split - split_terms.append(unescape_string_literal(term)) - else: - # non-quoted tokens are split by comma, keeping only non-empty ones - for sub_term in term.split(','): - if sub_term: - split_terms.append(sub_term.strip()) - return split_terms - - -class LessThanOrEqualFilter(BaseFilterBackend): - def get_less_than_field(self, view, request): - return getattr(view, 'less_than_field', None) - - def get_filters_dict(self, request, view): - """ - Custom method that returns the filters exclusive to this filter in a - dict. For caching purposes. - """ - less_than_field = self.get_less_than_field(view, request) - - assert less_than_field is not None, ( - f"{view.__class__.__name__} should include a `less_than_field`" - "attribute" - ) - - filters_dict = {} - if less_than_field+'_max' in request.query_params.keys(): - field_value = request.query_params.get(less_than_field+'_max') - filters_dict[less_than_field+'_max'] = [field_value] - else: - filters_dict[less_than_field+'_max'] = [] - return filters_dict - - def filter_queryset(self, request, queryset, view): - # Return the correctly filtered queryset but also assign the filter - # dict to create the unique url for the cache - less_than_field = self.get_less_than_field(view, request) - if less_than_field+'_max' in request.query_params.keys(): - kwquery = {} - field_value = request.query_params.get(less_than_field+'_max') - kwquery[less_than_field+'__lte'] = field_value - return queryset.filter(**kwquery) - else: - return queryset - - -class MoreThanOrEqualFilter(BaseFilterBackend): - def get_more_than_field(self, view, request): - return getattr(view, 'more_than_field', None) - - def get_filters_dict(self, request, view): - """ - Custom method that returns the filters exclusive to this filter in a - dict. For caching purposes. - """ - more_than_field = self.get_more_than_field(view, request) - - assert more_than_field is not None, ( - f"{view.__class__.__name__} should include a `more_than_field`" - "attribute" - ) - - filters_dict = {} - if more_than_field+'_min' in request.query_params.keys(): - field_value = request.query_params.get(more_than_field+'_min') - filters_dict[more_than_field+'_min'] = [field_value] - else: - filters_dict[more_than_field+'_min'] = [] - return filters_dict - - def filter_queryset(self, request, queryset, view): - """ - Return the correctly filtered queryset - """ - more_than_field = self.get_more_than_field(view, request) - if more_than_field+'_min' in request.query_params.keys(): - kwquery = {} - kwquery[more_than_field+'__gte'] = request.query_params.get( - more_than_field+'_min') - return queryset.filter(**kwquery) - else: - return queryset - - -class CategoryFilter(BaseFilterBackend): - """ - This filter assigns the view.category object for later use, in particular - for filters that depend on this one. - """ - template = 'filters/categories.html' - category_field = 'category' - - def get_category_class(self, view, request): - return getattr(view, 'category_class', None) - - def assign_view_category(self, request, view): - if not hasattr(view, self.category_field): - if self.category_field in request.query_params.keys(): - try: - category_slug = request.query_params.get( - self.category_field).strip() - category = view.category_class.objects.get( - slug=category_slug) - # Append the category object in the view for later use - view.category = category - except Exception: - view.category = None - else: - view.category = None - - def get_filters_dict(self, request, view): - """ - Custom method that returns the filters exclusive to this filter in a - dict. For caching purposes. Queries the database for the current - category and saves it in view.category for internal use and later - filters. - """ - if hasattr(view, 'category_field'): - self.category_field = self.get_category_field(view, request) - - category_class = self.get_category_class(view, request) - - assert category_class is not None, ( - f"{view.__class__.__name__} should include a `category_class`" - "attribute" - ) - - # Create the filters dictionary and find the present category instance - self.assign_view_category(request, view) - - filters_dict = {} - if view.category: - filters_dict[self.category_field] = [view.category.slug] - else: - filters_dict[self.category_field] = [] - - return filters_dict - - def filter_queryset(self, request, queryset, view): - self.assign_view_category(request, view) - - # Create the queryset - if view.category: - kwquery_1 = {} - kwquery_1[self.category_field+'__id'] = view.category.id - if view.category.tn_descendants_pks: - kwquery_2 = {} - key = self.category_field+'__id__in' - kwquery_2[key] = view.category.tn_descendants_pks.split(',') - - queryset = queryset.filter(Q(**kwquery_1) | Q(**kwquery_2)) - else: - queryset = queryset.filter(**kwquery_1) - - return queryset - - # Developer Interface methods - def get_valid_fields(self, queryset, view, context, request): - # A query is executed here to get the possible categories - category_class = self.get_category_class(view, request) - if hasattr(view, 'category_field'): - self.category_field = self.get_category_field(view, request) - - assert category_class is not None, ( - f"{view.__class__.__name__} should include a `category_class`" - "attribute" - ) - - valid_fields = category_class.objects.all() - - if len(valid_fields): - valid_fields = [ - (item.slug, item.__str__()) for item in valid_fields - ] - - return valid_fields - else: - return [] - - def get_current(self, request, queryset, view): - params = request.query_params.get(self.category_field) - if params: - fields = [param.strip() for param in params.split(',')] - return fields[0] - else: - return None - - def get_template_context(self, request, queryset, view): - current = self.get_current(request, queryset, view) - options = [] - context = { - 'request': request, - 'current': current, - 'param': self.category_field, - } - valid_fields = self.get_valid_fields(queryset, view, context, request) - for key, label in valid_fields: - options.append((key, '%s' % (label))) - context['options'] = options - return context - - def to_html(self, request, queryset, view): - template = loader.get_template(self.template) - context = self.get_template_context(request, queryset, view) - return template.render(context) - - -class FacetFilter(BaseFilterBackend): - """ - This filter requires CategoryFilter to be ran before it. It assigns the - view.facets which includes all the facets applicable to the current - category. - """ - template = 'filters/facets.html' - - def get_facet_class(self, view, request): - return getattr(view, 'facet_class', None) - - def get_facet_tag_class(self, view, request): - return getattr(view, 'facet_tag_class', None) - - def get_facet_tag_field(self, view, request): - return getattr(view, 'facet_tag_field', None) - - def assign_view_facets(self, request, view): - if not hasattr(view, 'facets'): - if hasattr(view, 'facet_class'): - self.facet_class = self.get_facet_class(view, request) - - assert self.facet_class is not None, ( - f"{view.__class__.__name__} should include a `facet_class`" - "attribute" - ) - - if view.category: - if view.category.tn_ancestors_pks: - ancestor_ids = view.category.tn_ancestors_pks.split(',') - view.facets = self.facet_class.objects.filter( - Q(category__id=view.category.id) | Q( - category__id__in=ancestor_ids) - ).prefetch_related('facet_tags') - else: - view.facets = self.facet_class.objects.filter( - category__id=view.category.id).prefetch_related( - 'facet_tags') - else: - view.facets = self.facet_class.objects.filter( - category__tn_level=1).prefetch_related( - 'facet_tags') - - def get_filters_dict(self, request, view): - """ - Custom method that returns the filters exclusive to this filter in a - dict. For caching purposes. - """ - if hasattr(view, 'facet_class'): - self.facet_class = self.get_facet_class(view, request) - - assert self.facet_class is not None, ( - f"{view.__class__.__name__} should include a `facet_class`" - "attribute" - ) - - self.assign_view_facets(request, view) - - filters_dict = {} - if view.facets: - for facet in view.facets: - if facet.slug in request.query_params.keys(): - filters_dict[facet.slug] = set( - request.query_params[facet.slug].split(',')) - else: - filters_dict[facet.slug] = set({}) - - # Append the facets object and the tags dict in the view for later - # reference - view.tags = filters_dict - - return filters_dict - - def filter_queryset(self, request, queryset, view): - if hasattr(view, 'facet_tag_class'): - self.facet_tag_class = self.get_facet_tag_class(view, request) - - assert self.facet_tag_class is not None, ( - f"{view.__class__.__name__} should include a `facet_tag_class`" - "attribute" - ) - - if hasattr(view, 'facet_tag_field'): - self.facet_tag_field = self.get_facet_tag_field(view, request) - - assert self.facet_tag_field is not None, ( - f"{view.__class__.__name__} should include a `facet_tag_field`" - "attribute" - ) - - self.assign_view_facets(request, view) - - if view.facets: - for facet in view.facets: - if facet.slug in request.query_params.keys(): - tag_filterlist = request.query_params.get(facet.slug) - if tag_filterlist == '': - # If the tag filterlist is empty then we're not - # filtering against it, it's like having all the tags - # of the facet selected - pass - else: - kwquery = {} - key = self.facet_tag_field+'__slug__in' - kwquery[key] = tag_filterlist.replace(' ', '').split( - ',') - queryset = queryset.filter(**kwquery) - - return queryset - - # Developer Interface methods - def get_template_context(self, request, queryset, view): - # Does aggressive database querying to get the necessary facets and - # facettags, but this is only for the developer interface so its fine - if hasattr(view, 'facet_class'): - self.facet_class = self.get_facet_class(view, request) - - assert self.facet_class is not None, ( - f"{view.__class__.__name__} should include a `facet_class`" - "attribute" - ) - - if hasattr(view, 'facet_tag_class'): - self.facet_tag_class = self.get_facet_tag_class(view, request) - - assert self.facet_tag_class is not None, ( - f"{view.__class__.__name__} should include a `facet_tag_class`" - "attribute" - ) - - # Find the current choices - current = [] - facet_slugs = [] - if view.facets: - for facet in view.facets: - facet_slugs.append(facet.slug) - if facet.slug in request.query_params.keys(): - current.append(request.query_params.get(facet.slug)) - - facet_slug_names = {} - options = {} - context = { - 'request': request, - 'current': current, - 'facet_slugs': facet_slugs, - } - if view.facets: - for facet in view.facets: - facet_tag_instances = self.facet_tag_class.objects.filter( - facet__slug=facet.slug) - options[facet.slug] = [(facet_tag.slug, facet_tag.name) for - facet_tag in facet_tag_instances] - facet_slug_names[facet.slug] = facet.name - context['facet_slug_names'] = facet_slug_names - context['options'] = options - else: - context['facet_slug_names'] = {} - context['options'] = {} - - return context - - def to_html(self, request, queryset, view): - template = loader.get_template(self.template) - context = self.get_template_context(request, queryset, view) - return template.render(context) - - -class TrigramSearchFilter(BaseFilterBackend): - # The URL query parameter used for the search. - search_param = 'search' - template = 'rest_framework/filters/search.html' - search_title = _('Search') - search_description = _('A search string to perform trigram similarity' - 'based searching with.') - lookup_prefixes = { - '^': 'istartswith', - '=': 'iexact', - '@': 'search', - '$': 'iregex', - } - - def get_filters_dict(self, request, view): - """ - Custom method that returns the filters exclusive to this filter in a - dict. For caching purposes. - """ - self.filters_dict = {} - if 'search' in request.query_params.keys(): - slug_term = request.query_params.get('search') - self.filters_dict['search'] = [slug_term] - else: - self.filters_dict['search'] = [] - - return self.filters_dict - - def get_search_fields(self, view, request): - """ - Search fields are obtained from the view, but the request is always - passed to this method. Sub-classes can override this method to - dynamically change the search fields based on request content. - """ - return getattr(view, 'search_fields', None) - - def get_search_query(self, request): - """ - Search terms are set by a ?search=... query parameter, - and may be whitespace delimited. - """ - value = request.query_params.get(self.search_param, '') - field = CharField(trim_whitespace=False, allow_blank=True) - cleaned_value = field.run_validation(value) - return cleaned_value - - def construct_search(self, field_name, queryset): - """ - For the sqlite search - """ - lookup = self.lookup_prefixes.get(field_name[0]) - if lookup: - field_name = field_name[1:] - else: - # Use field_name if it includes a lookup. - opts = queryset.model._meta - lookup_fields = field_name.split(LOOKUP_SEP) - # Go through the fields, following all relations. - prev_field = None - for path_part in lookup_fields: - if path_part == "pk": - path_part = opts.pk.name - try: - field = opts.get_field(path_part) - except FieldDoesNotExist: - # Use valid query lookups. - if prev_field and prev_field.get_lookup(path_part): - return field_name - else: - prev_field = field - if hasattr(field, "path_infos"): - # Update opts to follow the relation. - opts = field.path_infos[-1].to_opts - # django < 4.1 - elif hasattr(field, 'get_path_info'): - # Update opts to follow the relation. - opts = field.get_path_info()[-1].to_opts - # Otherwise, use the field with icontains. - lookup = 'icontains' - return LOOKUP_SEP.join([field_name, lookup]) - - def must_call_distinct(self, queryset, search_fields): - """ - Return True if 'distinct()' should be used to query the given lookups. - """ - for search_field in search_fields: - opts = queryset.model._meta - if search_field[0] in self.lookup_prefixes: - search_field = search_field[1:] - # Annotated fields do not need to be distinct - if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations: - continue - parts = search_field.split(LOOKUP_SEP) - for part in parts: - field = opts.get_field(part) - if hasattr(field, 'get_path_info'): - # This field is a relation, update opts to follow the relation - path_info = field.get_path_info() - opts = path_info[-1].to_opts - if any(path.m2m for path in path_info): - # This field is a m2m relation so we know we need to call distinct - return True - else: - # This field has a custom __ query transform but is not a relational field. - break - return False - - def filter_queryset(self, request, queryset, view): - search_fields = self.get_search_fields(view, request) - - assert search_fields is not None, ( - f"{view.__class__.__name__} should include a `search_fields`" - "attribute" - ) - - query = self.get_search_query(request) - - if not query: - return queryset - - try: - # Attempt postgresql's full text search - from django.contrib.postgres.search import TrigramStrictWordSimilarity - threshold = calculate_threshold(query, 0.02, 0.12) - queryset = queryset.annotate( - search_field=Concat( - *search_fields, - output_field=CharField() - )).annotate( - similarity=TrigramStrictWordSimilarity( - 'search_field', query) - ).filter(similarity__gt=threshold) - - # NOTE a weird FieldError is raised on sqlite - except (ImportError, FieldError): - # Perform very simple sqlite compatible search - search_terms = search_smart_split(query) - - orm_lookups = [ - self.construct_search(str(search_field), queryset) - for search_field in search_fields - ] - - base = queryset - # generator which for each term builds the corresponding search - conditions = ( - reduce( - operator.or_, - (models.Q(**{orm_lookup: term}) for orm_lookup in orm_lookups) - ) for term in search_terms - ) - queryset = queryset.filter(reduce(operator.and_, conditions)) - - # Remove duplicates from results, if necessary - if self.must_call_distinct(queryset, search_fields): - # inspired by django.contrib.admin - # this is more accurate than .distinct form M2M relationship - # also is cross-database - queryset = queryset.filter(pk=models.OuterRef('pk')) - queryset = base.filter(models.Exists(queryset)) - - return queryset - - def to_html(self, request, queryset, view): - if not getattr(view, 'search_fields', None): - return '' - - term = request.query_params.get(self.search_param, '') - term = term[0] if term else '' - context = { - 'param': self.search_param, - 'term': term - } - template = loader.get_template(self.template) - return template.render(context) - - def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' - return [ - coreapi.Field( - name=self.search_param, - required=False, - location='query', - schema=coreschema.String( - title=force_str(self.search_title), - description=force_str(self.search_description) - ) - ) - ] - - def get_schema_operation_parameters(self, view): - return [ - { - 'name': self.search_param, - 'required': False, - 'in': 'query', - 'description': force_str(self.search_description), - 'schema': { - 'type': 'string', - }, - }, - ] - - -# TODO -#class FieldFilter(BaseFilterBackend): - - -# TODO misunderstood the urlconf stuff of the RUD methods, this is probably unnecessary -class SlugSearchFilter(BaseFilterBackend): - # The URL query parameter used for the search. - template = 'filters/slug.html' - slug_title = _('Slug Search') - slug_description = _("The instance's slug.") - slug_field = 'slug' - - def get_slug_field(self, view, request): - return getattr(view, 'slug_field', None) - - def get_filters_dict(self, request, view): - """ - Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. - """ - if hasattr(view, 'slug_field'): - self.slug_field = self.get_slug_field(view, request) - - assert self.slug_field is not None, ( - f"{view.__class__.__name__} should include a `slug_field` attribute" - ) - self.filters_dict = {} - if self.slug_field in request.query_params.keys(): - slug_term = request.query_params.get(self.slug_field) - self.filters_dict[self.slug_field] = [slug_term] - else: - self.filters_dict[self.slug_field] = [] - - return self.filters_dict - - def filter_queryset(self, request, queryset, view): - # Ensure that the slug field was searched against - try: - if self.slug_field in request.query_params.keys(): - slug_term = request.query_params.get(self.slug_field) - query = {} - query[self.slug_field] = slug_term - - queryset = queryset.get(**query) - except Exception as e: - print(e) - - return queryset - - - def to_html(self, request, queryset, view): - if not getattr(view, 'slug_field', None): - return '' - - slug_term = self.get_slug_term(request) - context = { - 'param': self.slug_field, - 'term': slug_term - } - template = loader.get_template(self.template) - return template.render(context) - - -class SearchFilter(BaseFilterBackend): - # The URL query parameter used for the search. - search_param = api_settings.SEARCH_PARAM - template = 'rest_framework/filters/search.html' - lookup_prefixes = { - '^': 'istartswith', - '=': 'iexact', - '@': 'search', - '$': 'iregex', - } - search_title = _('Search') - search_description = _('A search term.') - # TODO to be removed - def get_search_terms(self, request): - """ - Search terms are set by a ?search=... query parameter, - and may be comma and/or whitespace delimited. - """ - params = request.query_params.get(self.search_param, '') - params = params.replace('\x00', '') # strip null characters - params = params.replace(',', ' ') - return params.split() - # TODO to be removed - def construct_search(self, field_name): - lookup = self.lookup_prefixes.get(field_name[0]) - if lookup: - field_name = field_name[1:] - else: - lookup = 'icontains' - return LOOKUP_SEP.join([field_name, lookup]) - # TODO to be removed - def must_call_distinct(self, queryset, search_fields): - """ - Return True if 'distinct()' should be used to query the given lookups. - """ - for search_field in search_fields: - opts = queryset.model._meta - if search_field[0] in self.lookup_prefixes: - search_field = search_field[1:] - # Annotated fields do not need to be distinct - if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations: - continue - parts = search_field.split(LOOKUP_SEP) - for part in parts: - field = opts.get_field(part) - if hasattr(field, 'get_path_info'): - # This field is a relation, update opts to follow the relation - path_info = field.get_path_info() - opts = path_info[-1].to_opts - if any(path.m2m for path in path_info): - # This field is a m2m relation so we know we need to call distinct - return True - else: - # This field has a custom __ query transform but is not a relational field. - break - return False - - def get_filters_dict(self, request, view): - """ - Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. - """ - self.filters_dict = {} - if 'search' in request.query_params.keys(): - slug_term = request.query_params.get('search') - self.filters_dict['search'] = [slug_term] - else: - self.filters_dict['search'] = [] - - return self.filters_dict - - def filter_queryset(self, request, queryset, view): - search_fields = getattr(view, 'search_fields', None) - search_terms = self.get_search_terms(request) - - if not search_fields or not search_terms: - return queryset - - orm_lookups = [ - self.construct_search(str(search_field)) - for search_field in search_fields - ] - - base = queryset - conditions = [] - for search_term in search_terms: - queries = [ - models.Q(**{orm_lookup: search_term}) - for orm_lookup in orm_lookups - ] - conditions.append(reduce(operator.or_, queries)) - queryset = queryset.filter(reduce(operator.and_, conditions)) - - if self.must_call_distinct(queryset, search_fields): - # Filtering against a many-to-many field requires us to - # call queryset.distinct() in order to avoid duplicate items - # in the resulting queryset. - # We try to avoid this if possible, for performance reasons. - queryset = distinct(queryset, base) - return queryset - - def to_html(self, request, queryset, view): - if not getattr(view, 'search_fields', None): - return '' - - term = self.get_search_terms(request) - term = term[0] if term else '' - context = { - 'param': self.search_param, - 'term': term - } - template = loader.get_template(self.template) - return template.render(context) - - def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' - return [ - coreapi.Field( - name=self.search_param, - required=False, - location='query', - schema=coreschema.String( - title=force_str(self.search_title), - description=force_str(self.search_description) - ) - ) - ] - - def get_schema_operation_parameters(self, view): - return [ - { - 'name': self.search_param, - 'required': False, - 'in': 'query', - 'description': force_str(self.search_description), - 'schema': { - 'type': 'string', - }, - }, - ] - - -class OrderingFilter(BaseFilterBackend): - # The URL query parameter used for the ordering. - ordering_param = api_settings.ORDERING_PARAM - ordering_fields = None - ordering_title = _('Ordering') - ordering_description = _('Which field to use when ordering the results.') - template = 'rest_framework/filters/ordering.html' - - def get_ordering(self, request, queryset, view): - """ - Ordering is set by a comma delimited ?ordering=... query parameter. - - The `ordering` query parameter can be overridden by setting - the `ordering_param` value on the OrderingFilter or by - specifying an `ORDERING_PARAM` value in the API settings. - """ - params = request.query_params.get(self.ordering_param) - if params: - fields = [param.strip() for param in params.split(',')] - ordering = self.remove_invalid_fields(queryset, fields, view, request) - if ordering: - return ordering - - # No ordering was included, or all the ordering fields were invalid - return self.get_default_ordering(view) - - def get_default_ordering(self, view): - ordering = getattr(view, 'ordering', None) - if isinstance(ordering, str): - return (ordering,) - return ordering - - def get_default_valid_fields(self, queryset, view, context={}): - # If `ordering_fields` is not specified, then we determine a default - # based on the serializer class, if one exists on the view. - if hasattr(view, 'get_serializer_class'): - try: - serializer_class = view.get_serializer_class() - except AssertionError: - # Raised by the default implementation if - # no serializer_class was found - serializer_class = None - else: - serializer_class = getattr(view, 'serializer_class', None) - - if serializer_class is None: - msg = ( - "Cannot use %s on a view which does not have either a " - "'serializer_class', an overriding 'get_serializer_class' " - "or 'ordering_fields' attribute." - ) - raise ImproperlyConfigured(msg % self.__class__.__name__) - - model_class = queryset.model - model_property_names = [ - # 'pk' is a property added in Django's Model class, however it is valid for ordering. - attr for attr in dir(model_class) if isinstance(getattr(model_class, attr), property) and attr != 'pk' - ] - - return [ - (field.source.replace('.', '__') or field_name, field.label) - for field_name, field in serializer_class(context=context).fields.items() - if ( - not getattr(field, 'write_only', False) and - not field.source == '*' and - field.source not in model_property_names - ) - ] - - def get_valid_fields(self, queryset, view, context={}): - valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) - - if valid_fields is None: - # Default to allowing filtering on serializer fields - return self.get_default_valid_fields(queryset, view, context) - - elif valid_fields == '__all__': - # View explicitly allows filtering on any model field - valid_fields = [ - (field.name, field.verbose_name) for field in queryset.model._meta.fields - ] - valid_fields += [ - (key, key.title().split('__')) - for key in queryset.query.annotations - ] - else: - valid_fields = [ - (item, item) if isinstance(item, str) else item - for item in valid_fields - ] - - return valid_fields - - def remove_invalid_fields(self, queryset, fields, view, request): - valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})] - - def term_valid(term): - if term.startswith("-"): - term = term[1:] - return term in valid_fields - - return [term for term in fields if term_valid(term)] - - def get_filters_dict(self, request, view): - """ - Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. - """ - self.filters_dict = {} - if 'ordering' in request.query_params.keys(): - slug_term = request.query_params.get('ordering') - self.filters_dict['ordering'] = [slug_term] - else: - self.filters_dict['ordering'] = [view.ordering_fields[0]] - - return self.filters_dict - - def filter_queryset(self, request, queryset, view): - ordering = self.get_ordering(request, queryset, view) - - if ordering: - return queryset.order_by(*ordering) - - return queryset - - def get_template_context(self, request, queryset, view): - current = self.get_ordering(request, queryset, view) - current = None if not current else current[0] - options = [] - context = { - 'request': request, - 'current': current, - 'param': self.ordering_param, - } - for key, label in self.get_valid_fields(queryset, view, context): - options.append((key, '%s - %s' % (label, _('ascending')))) - options.append(('-' + key, '%s - %s' % (label, _('descending')))) - context['options'] = options - return context - - def to_html(self, request, queryset, view): - template = loader.get_template(self.template) - context = self.get_template_context(request, queryset, view) - return template.render(context) - - def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' - return [ - coreapi.Field( - name=self.ordering_param, - required=False, - location='query', - schema=coreschema.String( - title=force_str(self.ordering_title), - description=force_str(self.ordering_description) - ) - ) - ] - - def get_schema_operation_parameters(self, view): - return [ - { - 'name': self.ordering_param, - 'required': False, - 'in': 'query', - 'description': force_str(self.ordering_description), - 'schema': { - 'type': 'string', - }, - }, - ] +from .libfilters.category import * +from .libfilters.facet import * +from .libfilters.lessthanorequal import * +from .libfilters.morethanorequal import * +from .libfilters.nodetreebranch import * +from .libfilters.ordering import * +from .libfilters.trigramsearch import * + + +# TODO all the below are here for potential future reference, they should be +# TODO deleted at some point + +# from django.utils.text import smart_split +# from django.core.exceptions import FieldError +# import operator +# from functools import reduce +# from rest_framework.settings import api_settings +# from rest_framework.filters import BaseFilterBackend +# from django.template import loader +# from django.utils.translation import gettext_lazy as _ +# from django.db import models +# from django.db.models import Q +# from rest_framework.fields import CharField +# from django.db.models.constants import LOOKUP_SEP +# from django.db.models.functions import Concat + +# def calculate_threshold(query, min_threshold, max_threshold): +# query_threshold = len(query)/300 +# if query_threshold < min_threshold: +# return min_threshold +# if max_threshold < query_threshold: +# return max_threshold +# return query_threshold +# +# +# def search_smart_split(search_terms): +# """generator that first splits string by spaces, leaving quoted phrases together, +# then it splits non-quoted phrases by commas. +# """ +# split_terms = [] +# for term in smart_split(search_terms): +# # trim commas to avoid bad matching for quoted phrases +# term = term.strip(',') +# if term.startswith(('"', "'")) and term[0] == term[-1]: +# # quoted phrases are kept together without any other split +# split_terms.append(unescape_string_literal(term)) +# else: +# # non-quoted tokens are split by comma, keeping only non-empty ones +# for sub_term in term.split(','): +# if sub_term: +# split_terms.append(sub_term.strip()) +# return split_terms +# +# +# +# +# +# +# +# +# # TODO misunderstood the urlconf stuff of the RUD methods, this is probably unnecessary +# class SlugSearchFilter(BaseFilterBackend): +# # The URL query parameter used for the search. +# template = 'filters/slug.html' +# slug_title = _('Slug Search') +# slug_description = _("The instance's slug.") +# slug_field = 'slug' +# +# def get_slug_field(self, view, request): +# return getattr(view, 'slug_field', None) +# +# def get_filters_dict(self, request, view): +# """ +# Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. +# """ +# if hasattr(view, 'slug_field'): +# self.slug_field = self.get_slug_field(view, request) +# +# assert self.slug_field is not None, ( +# f"{view.__class__.__name__} should include a `slug_field` attribute" +# ) +# self.filters_dict = {} +# if self.slug_field in request.query_params.keys(): +# slug_term = request.query_params.get(self.slug_field) +# self.filters_dict[self.slug_field] = [slug_term] +# else: +# self.filters_dict[self.slug_field] = [] +# +# return self.filters_dict +# +# def filter_queryset(self, request, queryset, view): +# # Ensure that the slug field was searched against +# try: +# if self.slug_field in request.query_params.keys(): +# slug_term = request.query_params.get(self.slug_field) +# query = {} +# query[self.slug_field] = slug_term +# +# queryset = queryset.get(**query) +# except Exception as e: +# print(e) +# +# return queryset +# +# +# def to_html(self, request, queryset, view): +# if not getattr(view, 'slug_field', None): +# return '' +# +# slug_term = self.get_slug_term(request) +# context = { +# 'param': self.slug_field, +# 'term': slug_term +# } +# template = loader.get_template(self.template) +# return template.render(context) +# +# +# # TODO this is here only for reference for the dev pages +# class SearchFilter(BaseFilterBackend): +# # The URL query parameter used for the search. +# search_param = api_settings.SEARCH_PARAM +# template = 'rest_framework/filters/search.html' +# lookup_prefixes = { +# '^': 'istartswith', +# '=': 'iexact', +# '@': 'search', +# '$': 'iregex', +# } +# search_title = _('Search') +# search_description = _('A search term.') +# # TODO to be removed +# def get_search_terms(self, request): +# """ +# Search terms are set by a ?search=... query parameter, +# and may be comma and/or whitespace delimited. +# """ +# params = request.query_params.get(self.search_param, '') +# params = params.replace('\x00', '') # strip null characters +# params = params.replace(',', ' ') +# return params.split() +# # TODO to be removed +# def construct_search(self, field_name): +# lookup = self.lookup_prefixes.get(field_name[0]) +# if lookup: +# field_name = field_name[1:] +# else: +# lookup = 'icontains' +# return LOOKUP_SEP.join([field_name, lookup]) +# # TODO to be removed +# def must_call_distinct(self, queryset, search_fields): +# """ +# Return True if 'distinct()' should be used to query the given lookups. +# """ +# for search_field in search_fields: +# opts = queryset.model._meta +# if search_field[0] in self.lookup_prefixes: +# search_field = search_field[1:] +# # Annotated fields do not need to be distinct +# if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations: +# continue +# parts = search_field.split(LOOKUP_SEP) +# for part in parts: +# field = opts.get_field(part) +# if hasattr(field, 'get_path_info'): +# # This field is a relation, update opts to follow the relation +# path_info = field.get_path_info() +# opts = path_info[-1].to_opts +# if any(path.m2m for path in path_info): +# # This field is a m2m relation so we know we need to call distinct +# return True +# else: +# # This field has a custom __ query transform but is not a relational field. +# break +# return False +# +# def get_filters_dict(self, request, view): +# """ +# Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. +# """ +# self.filters_dict = {} +# if 'search' in request.query_params.keys(): +# slug_term = request.query_params.get('search') +# self.filters_dict['search'] = [slug_term] +# else: +# self.filters_dict['search'] = [] +# +# return self.filters_dict +# +# def filter_queryset(self, request, queryset, view): +# search_fields = getattr(view, 'search_fields', None) +# search_terms = self.get_search_terms(request) +# +# if not search_fields or not search_terms: +# return queryset +# +# orm_lookups = [ +# self.construct_search(str(search_field)) +# for search_field in search_fields +# ] +# +# base = queryset +# conditions = [] +# for search_term in search_terms: +# queries = [ +# models.Q(**{orm_lookup: search_term}) +# for orm_lookup in orm_lookups +# ] +# conditions.append(reduce(operator.or_, queries)) +# queryset = queryset.filter(reduce(operator.and_, conditions)) +# +# if self.must_call_distinct(queryset, search_fields): +# # Filtering against a many-to-many field requires us to +# # call queryset.distinct() in order to avoid duplicate items +# # in the resulting queryset. +# # We try to avoid this if possible, for performance reasons. +# queryset = distinct(queryset, base) +# return queryset +# +# def to_html(self, request, queryset, view): +# if not getattr(view, 'search_fields', None): +# return '' +# +# term = self.get_search_terms(request) +# term = term[0] if term else '' +# context = { +# 'param': self.search_param, +# 'term': term +# } +# template = loader.get_template(self.template) +# return template.render(context) +# +# def get_schema_fields(self, view): +# assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' +# assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' +# return [ +# coreapi.Field( +# name=self.search_param, +# required=False, +# location='query', +# schema=coreschema.String( +# title=force_str(self.search_title), +# description=force_str(self.search_description) +# ) +# ) +# ] +# +# def get_schema_operation_parameters(self, view): +# return [ +# { +# 'name': self.search_param, +# 'required': False, +# 'in': 'query', +# 'description': force_str(self.search_description), +# 'schema': { +# 'type': 'string', +# }, +# }, +# ] diff --git a/starfields_drf_generics/libfilters/__init__.py b/starfields_drf_generics/libfilters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/starfields_drf_generics/libfilters/category.py b/starfields_drf_generics/libfilters/category.py new file mode 100644 index 0000000..f2afd2d --- /dev/null +++ b/starfields_drf_generics/libfilters/category.py @@ -0,0 +1,126 @@ +from rest_framework.filters import BaseFilterBackend +from django.template import loader +from django.db.models import Q + + +class CategoryFilter(BaseFilterBackend): + """ + This filter assigns the view.category object for later use, in particular + for filters that depend on this one. + """ + template = 'filters/categories.html' + category_field = 'category' + + def get_category_class(self, view, request): + return getattr(view, 'category_class', None) + + def assign_view_category(self, request, view): + if not hasattr(view, self.category_field): + if self.category_field in request.query_params.keys(): + try: + category_slug = request.query_params.get( + self.category_field).strip() + category = view.category_class.objects.get( + slug=category_slug) + # Append the category object in the view for later use + view.category = category + except Exception: + view.category = None + else: + view.category = None + + def get_filters_dict(self, request, view): + """ + Custom method that returns the filters exclusive to this filter in a + dict. For caching purposes. Queries the database for the current + category and saves it in view.category for internal use and later + filters. + """ + if hasattr(view, 'category_field'): + self.category_field = self.get_category_field(view, request) + + category_class = self.get_category_class(view, request) + + assert category_class is not None, ( + f"{view.__class__.__name__} should include a `category_class`" + "attribute" + ) + + # Create the filters dictionary and find the present category instance + self.assign_view_category(request, view) + + filters_dict = {} + if view.category: + filters_dict[self.category_field] = [view.category.slug] + else: + filters_dict[self.category_field] = [] + + return filters_dict + + def filter_queryset(self, request, queryset, view): + self.assign_view_category(request, view) + + # Create the queryset + if view.category: + kwquery_1 = {} + kwquery_1[self.category_field+'__id'] = view.category.id + if view.category.tn_descendants_pks: + kwquery_2 = {} + key = self.category_field+'__id__in' + kwquery_2[key] = view.category.tn_descendants_pks.split(',') + + queryset = queryset.filter(Q(**kwquery_1) | Q(**kwquery_2)) + else: + queryset = queryset.filter(**kwquery_1) + + return queryset + + # Developer Interface methods + def get_valid_fields(self, queryset, view, context, request): + # A query is executed here to get the possible categories + category_class = self.get_category_class(view, request) + if hasattr(view, 'category_field'): + self.category_field = self.get_category_field(view, request) + + assert category_class is not None, ( + f"{view.__class__.__name__} should include a `category_class`" + "attribute" + ) + + valid_fields = category_class.objects.all() + + if len(valid_fields): + valid_fields = [ + (item.slug, item.__str__()) for item in valid_fields + ] + + return valid_fields + else: + return [] + + def get_current(self, request, queryset, view): + params = request.query_params.get(self.category_field) + if params: + fields = [param.strip() for param in params.split(',')] + return fields[0] + else: + return None + + def get_template_context(self, request, queryset, view): + current = self.get_current(request, queryset, view) + options = [] + context = { + 'request': request, + 'current': current, + 'param': self.category_field, + } + valid_fields = self.get_valid_fields(queryset, view, context, request) + for key, label in valid_fields: + options.append((key, '%s' % (label))) + context['options'] = options + return context + + def to_html(self, request, queryset, view): + template = loader.get_template(self.template) + context = self.get_template_context(request, queryset, view) + return template.render(context) diff --git a/starfields_drf_generics/libfilters/facet.py b/starfields_drf_generics/libfilters/facet.py new file mode 100644 index 0000000..9e4f744 --- /dev/null +++ b/starfields_drf_generics/libfilters/facet.py @@ -0,0 +1,170 @@ +from rest_framework.filters import BaseFilterBackend +from django.template import loader +from django.db.models import Q + + +class FacetFilter(BaseFilterBackend): + """ + This filter requires CategoryFilter to be ran before it. It assigns the + view.facets which includes all the facets applicable to the current + category. + """ + template = 'filters/facets.html' + + def get_facet_class(self, view, request): + return getattr(view, 'facet_class', None) + + def get_facet_tag_class(self, view, request): + return getattr(view, 'facet_tag_class', None) + + def get_facet_tag_field(self, view, request): + return getattr(view, 'facet_tag_field', None) + + def assign_view_facets(self, request, view): + if not hasattr(view, 'facets'): + if hasattr(view, 'facet_class'): + self.facet_class = self.get_facet_class(view, request) + + assert self.facet_class is not None, ( + f"{view.__class__.__name__} should include a `facet_class`" + "attribute" + ) + + if view.category: + if view.category.tn_ancestors_pks: + ancestor_ids = view.category.tn_ancestors_pks.split(',') + view.facets = self.facet_class.objects.filter( + Q(category__id=view.category.id) | Q( + category__id__in=ancestor_ids) + ).prefetch_related('facet_tags') + else: + view.facets = self.facet_class.objects.filter( + category__id=view.category.id).prefetch_related( + 'facet_tags') + else: + view.facets = self.facet_class.objects.filter( + category__tn_level=1).prefetch_related( + 'facet_tags') + + def get_filters_dict(self, request, view): + """ + Custom method that returns the filters exclusive to this filter in a + dict. For caching purposes. + """ + if hasattr(view, 'facet_class'): + self.facet_class = self.get_facet_class(view, request) + + assert self.facet_class is not None, ( + f"{view.__class__.__name__} should include a `facet_class`" + "attribute" + ) + + self.assign_view_facets(request, view) + + filters_dict = {} + if view.facets: + for facet in view.facets: + if facet.slug in request.query_params.keys(): + filters_dict[facet.slug] = set( + request.query_params[facet.slug].split(',')) + else: + filters_dict[facet.slug] = set({}) + + # Append the facets object and the tags dict in the view for later + # reference + view.tags = filters_dict + + return filters_dict + + def filter_queryset(self, request, queryset, view): + if hasattr(view, 'facet_tag_class'): + self.facet_tag_class = self.get_facet_tag_class(view, request) + + assert self.facet_tag_class is not None, ( + f"{view.__class__.__name__} should include a `facet_tag_class`" + "attribute" + ) + + if hasattr(view, 'facet_tag_field'): + self.facet_tag_field = self.get_facet_tag_field(view, request) + + assert self.facet_tag_field is not None, ( + f"{view.__class__.__name__} should include a `facet_tag_field`" + "attribute" + ) + + self.assign_view_facets(request, view) + + if view.facets: + for facet in view.facets: + if facet.slug in request.query_params.keys(): + tag_filterlist = request.query_params.get(facet.slug) + if tag_filterlist == '': + # If the tag filterlist is empty then we're not + # filtering against it, it's like having all the tags + # of the facet selected + pass + else: + kwquery = {} + key = self.facet_tag_field+'__slug__in' + kwquery[key] = tag_filterlist.replace(' ', '').split( + ',') + queryset = queryset.filter(**kwquery) + + return queryset + + # Developer Interface methods + def get_template_context(self, request, queryset, view): + # Does aggressive database querying to get the necessary facets and + # facettags, but this is only for the developer interface so its fine + if hasattr(view, 'facet_class'): + self.facet_class = self.get_facet_class(view, request) + + assert self.facet_class is not None, ( + f"{view.__class__.__name__} should include a `facet_class`" + "attribute" + ) + + if hasattr(view, 'facet_tag_class'): + self.facet_tag_class = self.get_facet_tag_class(view, request) + + assert self.facet_tag_class is not None, ( + f"{view.__class__.__name__} should include a `facet_tag_class`" + "attribute" + ) + + # Find the current choices + current = [] + facet_slugs = [] + if view.facets: + for facet in view.facets: + facet_slugs.append(facet.slug) + if facet.slug in request.query_params.keys(): + current.append(request.query_params.get(facet.slug)) + + facet_slug_names = {} + options = {} + context = { + 'request': request, + 'current': current, + 'facet_slugs': facet_slugs, + } + if view.facets: + for facet in view.facets: + facet_tag_instances = self.facet_tag_class.objects.filter( + facet__slug=facet.slug) + options[facet.slug] = [(facet_tag.slug, facet_tag.name) for + facet_tag in facet_tag_instances] + facet_slug_names[facet.slug] = facet.name + context['facet_slug_names'] = facet_slug_names + context['options'] = options + else: + context['facet_slug_names'] = {} + context['options'] = {} + + return context + + def to_html(self, request, queryset, view): + template = loader.get_template(self.template) + context = self.get_template_context(request, queryset, view) + return template.render(context) diff --git a/starfields_drf_generics/libfilters/lessthanorequal.py b/starfields_drf_generics/libfilters/lessthanorequal.py new file mode 100644 index 0000000..65520b6 --- /dev/null +++ b/starfields_drf_generics/libfilters/lessthanorequal.py @@ -0,0 +1,38 @@ +from rest_framework.filters import BaseFilterBackend + + +class LessThanOrEqualFilter(BaseFilterBackend): + def get_less_than_field(self, view, request): + return getattr(view, 'less_than_field', None) + + def get_filters_dict(self, request, view): + """ + Custom method that returns the filters exclusive to this filter in a + dict. For caching purposes. + """ + less_than_field = self.get_less_than_field(view, request) + + assert less_than_field is not None, ( + f"{view.__class__.__name__} should include a `less_than_field`" + "attribute" + ) + + filters_dict = {} + if less_than_field+'_max' in request.query_params.keys(): + field_value = request.query_params.get(less_than_field+'_max') + filters_dict[less_than_field+'_max'] = [field_value] + else: + filters_dict[less_than_field+'_max'] = [] + return filters_dict + + def filter_queryset(self, request, queryset, view): + # Return the correctly filtered queryset but also assign the filter + # dict to create the unique url for the cache + less_than_field = self.get_less_than_field(view, request) + if less_than_field+'_max' in request.query_params.keys(): + kwquery = {} + field_value = request.query_params.get(less_than_field+'_max') + kwquery[less_than_field+'__lte'] = field_value + return queryset.filter(**kwquery) + else: + return queryset diff --git a/starfields_drf_generics/libfilters/morethanorequal.py b/starfields_drf_generics/libfilters/morethanorequal.py new file mode 100644 index 0000000..08e1dbb --- /dev/null +++ b/starfields_drf_generics/libfilters/morethanorequal.py @@ -0,0 +1,39 @@ +from rest_framework.filters import BaseFilterBackend + + +class MoreThanOrEqualFilter(BaseFilterBackend): + def get_more_than_field(self, view, request): + return getattr(view, 'more_than_field', None) + + def get_filters_dict(self, request, view): + """ + Custom method that returns the filters exclusive to this filter in a + dict. For caching purposes. + """ + more_than_field = self.get_more_than_field(view, request) + + assert more_than_field is not None, ( + f"{view.__class__.__name__} should include a `more_than_field`" + "attribute" + ) + + filters_dict = {} + if more_than_field+'_min' in request.query_params.keys(): + field_value = request.query_params.get(more_than_field+'_min') + filters_dict[more_than_field+'_min'] = [field_value] + else: + filters_dict[more_than_field+'_min'] = [] + return filters_dict + + def filter_queryset(self, request, queryset, view): + """ + Return the correctly filtered queryset + """ + more_than_field = self.get_more_than_field(view, request) + if more_than_field+'_min' in request.query_params.keys(): + kwquery = {} + kwquery[more_than_field+'__gte'] = request.query_params.get( + more_than_field+'_min') + return queryset.filter(**kwquery) + else: + return queryset diff --git a/starfields_drf_generics/libfilters/nodetreebranch.py b/starfields_drf_generics/libfilters/nodetreebranch.py new file mode 100644 index 0000000..39b9790 --- /dev/null +++ b/starfields_drf_generics/libfilters/nodetreebranch.py @@ -0,0 +1,68 @@ +from rest_framework.settings import api_settings +from rest_framework.filters import BaseFilterBackend +from django.db.models import Q + + +class TreeNodeBranchFilter(BaseFilterBackend): + def get_descendants_of_field(self, view, request): + return getattr(view, 'descendants_of_field', None) + + def get_depth_field(self, view, request): + return getattr(view, 'depth_field', None) + + def get_filters_dict(self, request, view): + """ + Custom method that returns the filters exclusive to this filter in a + dict. For caching purposes. + """ + descendants_of_field = self.get_descendants_of_field(view, request) + + assert descendants_of_field is not None, ( + "{view.__class__.__name__} should include a " + f"`descendants_of_field` attribute" + ) + + depth_field = self.get_depth_field(view, request) + + assert depth_field is not None, ( + "{view.__class__.__name__} should include a " + f"`depth_field` attribute" + ) + + filters_dict = {} + if descendants_of_field in request.query_params.keys(): + field_value = request.query_params.get(descendants_of_field) + filters_dict[descendants_of_field] = [field_value] + else: + filters_dict[descendants_of_field] = [] + + if depth_field in request.query_params.keys(): + field_value = request.query_params.get(depth_field) + filters_dict[depth_field] = [field_value] + else: + filters_dict[depth_field] = [4] + + return filters_dict + + def filter_queryset(self, request, queryset, view): + # Return the correctly filtered queryset but also assign the filter + # dict to create the unique url for the cache + descendants_of_field = self.get_descendants_of_field(view, request) + depth_field = self.get_depth_field(view, request) + + if descendants_of_field in request.query_params.keys(): + field_value = request.query_params.get(descendants_of_field) + # Instead of doing two queries to get the descendants through the + # object a single more complex queryset + queryset = queryset.filter( + Q(tn_ancestors_pks=field_value) | + Q(tn_ancestors_pks__contains=","+field_value) | + Q(tn_ancestors_pks__contains=field_value+",")) + + if depth_field in request.query_params.keys(): + field_value = request.query_params.get(depth_field) + queryset = queryset.filter(tn_level__lte=field_value) + else: + queryset = queryset.filter(tn_level__lte=4) + + return queryset diff --git a/starfields_drf_generics/libfilters/ordering.py b/starfields_drf_generics/libfilters/ordering.py new file mode 100644 index 0000000..cf3ecc8 --- /dev/null +++ b/starfields_drf_generics/libfilters/ordering.py @@ -0,0 +1,180 @@ +from rest_framework.settings import api_settings +from rest_framework.filters import BaseFilterBackend +from django.template import loader +from django.utils.translation import gettext_lazy as _ + + +class OrderingFilter(BaseFilterBackend): + # The URL query parameter used for the ordering. + ordering_param = api_settings.ORDERING_PARAM + ordering_fields = None + ordering_title = _('Ordering') + ordering_description = _('Which field to use when ordering the results.') + template = 'rest_framework/filters/ordering.html' + + def get_ordering(self, request, queryset, view): + """ + Ordering is set by a comma delimited ?ordering=... query parameter. + + The `ordering` query parameter can be overridden by setting + the `ordering_param` value on the OrderingFilter or by + specifying an `ORDERING_PARAM` value in the API settings. + """ + params = request.query_params.get(self.ordering_param) + if params: + fields = [param.strip() for param in params.split(',')] + ordering = self.remove_invalid_fields(queryset, + fields, + view, + request) + if ordering: + return ordering + + # No ordering was included, or all the ordering fields were invalid + return self.get_default_ordering(view) + + def get_default_ordering(self, view): + ordering = getattr(view, 'ordering', None) + if isinstance(ordering, str): + return (ordering,) + return ordering + + def get_default_valid_fields(self, queryset, view, context={}): + # If `ordering_fields` is not specified, then we determine a default + # based on the serializer class, if one exists on the view. + if hasattr(view, 'get_serializer_class'): + try: + serializer_class = view.get_serializer_class() + except AssertionError: + # Raised by the default implementation if + # no serializer_class was found + serializer_class = None + else: + serializer_class = getattr(view, 'serializer_class', None) + + if serializer_class is None: + msg = ( + "Cannot use %s on a view which does not have either a " + "'serializer_class', an overriding 'get_serializer_class' " + "or 'ordering_fields' attribute." + ) + raise ImproperlyConfigured(msg % self.__class__.__name__) + + model_class = queryset.model + model_property_names = [ + # 'pk' is a property added in Django's Model class, however it is valid for ordering. + attr for attr in dir(model_class) if isinstance(getattr(model_class, attr), property) and attr != 'pk' + ] + + return [ + (field.source.replace('.', '__') or field_name, field.label) + for field_name, field in serializer_class(context=context).fields.items() + if ( + not getattr(field, 'write_only', False) and + not field.source == '*' and + field.source not in model_property_names + ) + ] + + def get_valid_fields(self, queryset, view, context={}): + valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) + + if valid_fields is None: + # Default to allowing filtering on serializer fields + return self.get_default_valid_fields(queryset, view, context) + + elif valid_fields == '__all__': + # View explicitly allows filtering on any model field + valid_fields = [ + (field.name, field.verbose_name) for field in queryset.model._meta.fields + ] + valid_fields += [ + (key, key.title().split('__')) + for key in queryset.query.annotations + ] + else: + valid_fields = [ + (item, item) if isinstance(item, str) else item + for item in valid_fields + ] + + return valid_fields + + def remove_invalid_fields(self, queryset, fields, view, request): + valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})] + + def term_valid(term): + if term.startswith("-"): + term = term[1:] + return term in valid_fields + + return [term for term in fields if term_valid(term)] + + def get_filters_dict(self, request, view): + """ + Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. + """ + self.filters_dict = {} + if 'ordering' in request.query_params.keys(): + slug_term = request.query_params.get('ordering') + self.filters_dict['ordering'] = [slug_term] + else: + self.filters_dict['ordering'] = [view.ordering_fields[0]] + + return self.filters_dict + + def filter_queryset(self, request, queryset, view): + ordering = self.get_ordering(request, queryset, view) + + if ordering: + return queryset.order_by(*ordering) + + return queryset + + def get_template_context(self, request, queryset, view): + current = self.get_ordering(request, queryset, view) + current = None if not current else current[0] + options = [] + context = { + 'request': request, + 'current': current, + 'param': self.ordering_param, + } + for key, label in self.get_valid_fields(queryset, view, context): + options.append((key, '%s - %s' % (label, _('ascending')))) + options.append(('-' + key, '%s - %s' % (label, _('descending')))) + context['options'] = options + return context + + def to_html(self, request, queryset, view): + template = loader.get_template(self.template) + context = self.get_template_context(request, queryset, view) + return template.render(context) + + def get_schema_fields(self, view): + assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' + assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' + return [ + coreapi.Field( + name=self.ordering_param, + required=False, + location='query', + schema=coreschema.String( + title=force_str(self.ordering_title), + description=force_str(self.ordering_description) + ) + ) + ] + + def get_schema_operation_parameters(self, view): + return [ + { + 'name': self.ordering_param, + 'required': False, + 'in': 'query', + 'description': force_str(self.ordering_description), + 'schema': { + 'type': 'string', + }, + }, + ] diff --git a/starfields_drf_generics/libfilters/trigramsearch.py b/starfields_drf_generics/libfilters/trigramsearch.py new file mode 100644 index 0000000..85efda0 --- /dev/null +++ b/starfields_drf_generics/libfilters/trigramsearch.py @@ -0,0 +1,247 @@ +from django.utils.text import smart_split +from django.core.exceptions import FieldError +import operator +from functools import reduce +from rest_framework.filters import BaseFilterBackend +from django.template import loader +from django.utils.translation import gettext_lazy as _ +from django.db import models +from django.db.models import Q +from rest_framework.fields import CharField +from django.db.models.constants import LOOKUP_SEP +from django.db.models.functions import Concat + + +def calculate_threshold(query, min_threshold, max_threshold): + query_threshold = len(query)/300 + if query_threshold < min_threshold: + return min_threshold + if max_threshold < query_threshold: + return max_threshold + return query_threshold + + +def search_smart_split(search_terms): + """ + Generator that first splits string by spaces, leaving quoted phrases + together, then it splits non-quoted phrases by commas. + """ + split_terms = [] + for term in smart_split(search_terms): + # trim commas to avoid bad matching for quoted phrases + term = term.strip(',') + if term.startswith(('"', "'")) and term[0] == term[-1]: + # quoted phrases are kept together without any other split + split_terms.append(unescape_string_literal(term)) + else: + # non-quoted tokens are split by comma, keeping only non-empty ones + for sub_term in term.split(','): + if sub_term: + split_terms.append(sub_term.strip()) + return split_terms + + +class TrigramSearchFilter(BaseFilterBackend): + # The URL query parameter used for the search. + search_param = 'search' + template = 'rest_framework/filters/search.html' + search_title = _('Search') + search_description = _('A search string to perform trigram similarity' + 'based searching with.') + lookup_prefixes = { + '^': 'istartswith', + '=': 'iexact', + '@': 'search', + '$': 'iregex', + } + + def get_filters_dict(self, request, view): + """ + Custom method that returns the filters exclusive to this filter in a + dict. For caching purposes. + """ + self.filters_dict = {} + if 'search' in request.query_params.keys(): + slug_term = request.query_params.get('search') + self.filters_dict['search'] = [slug_term] + else: + self.filters_dict['search'] = [] + + return self.filters_dict + + def get_search_fields(self, view, request): + """ + Search fields are obtained from the view, but the request is always + passed to this method. Sub-classes can override this method to + dynamically change the search fields based on request content. + """ + return getattr(view, 'search_fields', None) + + def get_search_query(self, request): + """ + Search terms are set by a ?search=... query parameter, + and may be whitespace delimited. + """ + value = request.query_params.get(self.search_param, '') + field = CharField(trim_whitespace=False, allow_blank=True) + cleaned_value = field.run_validation(value) + return cleaned_value + + def construct_search(self, field_name, queryset): + """ + For the sqlite search + """ + lookup = self.lookup_prefixes.get(field_name[0]) + if lookup: + field_name = field_name[1:] + else: + # Use field_name if it includes a lookup. + opts = queryset.model._meta + lookup_fields = field_name.split(LOOKUP_SEP) + # Go through the fields, following all relations. + prev_field = None + for path_part in lookup_fields: + if path_part == "pk": + path_part = opts.pk.name + try: + field = opts.get_field(path_part) + except FieldDoesNotExist: + # Use valid query lookups. + if prev_field and prev_field.get_lookup(path_part): + return field_name + else: + prev_field = field + if hasattr(field, "path_infos"): + # Update opts to follow the relation. + opts = field.path_infos[-1].to_opts + # django < 4.1 + elif hasattr(field, 'get_path_info'): + # Update opts to follow the relation. + opts = field.get_path_info()[-1].to_opts + # Otherwise, use the field with icontains. + lookup = 'icontains' + return LOOKUP_SEP.join([field_name, lookup]) + + def must_call_distinct(self, queryset, search_fields): + """ + Return True if 'distinct()' should be used to query the given lookups. + """ + for search_field in search_fields: + opts = queryset.model._meta + if search_field[0] in self.lookup_prefixes: + search_field = search_field[1:] + # Annotated fields do not need to be distinct + if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations: + continue + parts = search_field.split(LOOKUP_SEP) + for part in parts: + field = opts.get_field(part) + if hasattr(field, 'get_path_info'): + # This field is a relation, update opts to follow the relation + path_info = field.get_path_info() + opts = path_info[-1].to_opts + if any(path.m2m for path in path_info): + # This field is a m2m relation so we know we need to call distinct + return True + else: + # This field has a custom __ query transform but is not a relational field. + break + return False + + def filter_queryset(self, request, queryset, view): + search_fields = self.get_search_fields(view, request) + + assert search_fields is not None, ( + f"{view.__class__.__name__} should include a `search_fields`" + "attribute" + ) + + query = self.get_search_query(request) + + if not query: + return queryset + + try: + # Attempt postgresql's full text search + from django.contrib.postgres.search import TrigramStrictWordSimilarity + threshold = calculate_threshold(query, 0.02, 0.12) + queryset = queryset.annotate( + search_field=Concat( + *search_fields, + output_field=CharField() + )).annotate( + similarity=TrigramStrictWordSimilarity( + 'search_field', query) + ).filter(similarity__gt=threshold) + + # NOTE a weird FieldError is raised on sqlite + except (ImportError, FieldError): + # Perform very simple sqlite compatible search + search_terms = search_smart_split(query) + + orm_lookups = [ + self.construct_search(str(search_field), queryset) + for search_field in search_fields + ] + + base = queryset + # generator which for each term builds the corresponding search + conditions = ( + reduce( + operator.or_, + (models.Q(**{orm_lookup: term}) for orm_lookup in orm_lookups) + ) for term in search_terms + ) + queryset = queryset.filter(reduce(operator.and_, conditions)) + + # Remove duplicates from results, if necessary + if self.must_call_distinct(queryset, search_fields): + # inspired by django.contrib.admin + # this is more accurate than .distinct form M2M relationship + # also is cross-database + queryset = queryset.filter(pk=models.OuterRef('pk')) + queryset = base.filter(models.Exists(queryset)) + + return queryset + + def to_html(self, request, queryset, view): + if not getattr(view, 'search_fields', None): + return '' + + term = request.query_params.get(self.search_param, '') + term = term[0] if term else '' + context = { + 'param': self.search_param, + 'term': term + } + template = loader.get_template(self.template) + return template.render(context) + + def get_schema_fields(self, view): + assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' + assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' + return [ + coreapi.Field( + name=self.search_param, + required=False, + location='query', + schema=coreschema.String( + title=force_str(self.search_title), + description=force_str(self.search_description) + ) + ) + ] + + def get_schema_operation_parameters(self, view): + return [ + { + 'name': self.search_param, + 'required': False, + 'in': 'query', + 'description': force_str(self.search_description), + 'schema': { + 'type': 'string', + }, + }, + ] + diff --git a/starfields_drf_generics/mixins.py b/starfields_drf_generics/mixins.py index 1ca222f..3c2732d 100644 --- a/starfields_drf_generics/mixins.py +++ b/starfields_drf_generics/mixins.py @@ -109,7 +109,7 @@ class CachedListCreateModelMixin(CacheDeleteMixin): A fully custom mixin that handles mutiple instance cration. """ - def list_create(self, request): + def list_create(self, request, **kwargs): " Creates the list of entries in the request " # Go on with the creation as normal serializer = self.get_serializer(data=request.data, many=True) @@ -142,7 +142,7 @@ class CachedListRetrieveModelMixin(CacheGetMixin, CacheSetMixin): inherit anything from it. """ - def list(self, request): + def list(self, request, **kwargs): " Retrieves the listing of entries " # Attempt to get the request from the cache cache_attempt = self.get_cache(request) @@ -206,7 +206,7 @@ class CachedListDestroyModelMixin(CacheDeleteMixin): A fully custom mixin that handles mutiple instance deletions. """ - def list_destroy(self, request): + def list_destroy(self, request, **kwargs): " Deletes the list of entries in the request " # Go on with the validation as normal serializer = self.get_serializer(data=request.data, many=True)