From 2b8a4863c020a58c3bcfb7018b3d951824b73f2a Mon Sep 17 00:00:00 2001 From: Pelagic Date: Sat, 17 Aug 2024 15:53:35 +0300 Subject: [PATCH] Made the search filter use TrigramStrictWordSimilarity instead of the plain version and made a backup plain sqlite compatible search for when postgres is not avaiable --- starfields_drf_generics/filters.py | 144 +++++++++++++++++++++++++++-- 1 file changed, 137 insertions(+), 7 deletions(-) diff --git a/starfields_drf_generics/filters.py b/starfields_drf_generics/filters.py index e8ddf9a..0864133 100644 --- a/starfields_drf_generics/filters.py +++ b/starfields_drf_generics/filters.py @@ -6,14 +6,32 @@ from rest_framework.filters import BaseFilterBackend from django.template import loader from django.utils.translation import gettext_lazy as _ from django.db import models -from django.db.models import Q, CharField +from django.db.models import Q +from rest_framework.fields import CharField from django.db.models.constants import LOOKUP_SEP from django.db.models.functions import Concat -from django.contrib.postgres.search import TrigramSimilarity # TODO the dev pages are not done +def search_smart_split(search_terms): + """generator that first splits string by spaces, leaving quoted phrases together, + then it splits non-quoted phrases by commas. + """ + split_terms = [] + for term in smart_split(search_terms): + # trim commas to avoid bad matching for quoted phrases + term = term.strip(',') + if term.startswith(('"', "'")) and term[0] == term[-1]: + # quoted phrases are kept together without any other split + split_terms.append(unescape_string_literal(term)) + else: + # non-quoted tokens are split by comma, keeping only non-empty ones + for sub_term in term.split(','): + if sub_term: + split_terms.append(sub_term.strip()) + return split_terms + class LessThanOrEqualFilter(BaseFilterBackend): def get_less_than_field(self, view, request): @@ -403,24 +421,136 @@ class TrigramSearchFilter(BaseFilterBackend): return self.filters_dict + def get_search_fields(self, view, request): + """ + Search fields are obtained from the view, but the request is always + passed to this method. Sub-classes can override this method to + dynamically change the search fields based on request content. + """ + return getattr(view, 'search_fields', None) + + def get_search_query(self, request): + """ + Search terms are set by a ?search=... query parameter, + and may be whitespace delimited. + """ + value = request.query_params.get(self.search_param, '') + field = CharField(trim_whitespace=False, allow_blank=True) + cleaned_value = field.run_validation(value) + return cleaned_value + + def construct_search(self, field_name, queryset): + """ + For the sqlite search + """ + lookup = self.lookup_prefixes.get(field_name[0]) + if lookup: + field_name = field_name[1:] + else: + # Use field_name if it includes a lookup. + opts = queryset.model._meta + lookup_fields = field_name.split(LOOKUP_SEP) + # Go through the fields, following all relations. + prev_field = None + for path_part in lookup_fields: + if path_part == "pk": + path_part = opts.pk.name + try: + field = opts.get_field(path_part) + except FieldDoesNotExist: + # Use valid query lookups. + if prev_field and prev_field.get_lookup(path_part): + return field_name + else: + prev_field = field + if hasattr(field, "path_infos"): + # Update opts to follow the relation. + opts = field.path_infos[-1].to_opts + # django < 4.1 + elif hasattr(field, 'get_path_info'): + # Update opts to follow the relation. + opts = field.get_path_info()[-1].to_opts + # Otherwise, use the field with icontains. + lookup = 'icontains' + return LOOKUP_SEP.join([field_name, lookup]) + + def must_call_distinct(self, queryset, search_fields): + """ + Return True if 'distinct()' should be used to query the given lookups. + """ + for search_field in search_fields: + opts = queryset.model._meta + if search_field[0] in self.lookup_prefixes: + search_field = search_field[1:] + # Annotated fields do not need to be distinct + if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations: + continue + parts = search_field.split(LOOKUP_SEP) + for part in parts: + field = opts.get_field(part) + if hasattr(field, 'get_path_info'): + # This field is a relation, update opts to follow the relation + path_info = field.get_path_info() + opts = path_info[-1].to_opts + if any(path.m2m for path in path_info): + # This field is a m2m relation so we know we need to call distinct + return True + else: + # This field has a custom __ query transform but is not a relational field. + break + return False + def filter_queryset(self, request, queryset, view): - search_fields = getattr(view, 'search_fields', None) + search_fields = self.get_search_fields(view, request) assert search_fields is not None, ( f"{view.__class__.__name__} should include a `search_fields`" "attribute" ) - query = request.query_params.get(self.search_param, '') + query = self.get_search_query(request) - if query: + if not query: + return queryset + + try: + # Attempt postgresql's full text search + from django.contrib.postgres.search import TrigramStrictWordSimilarity queryset = queryset.annotate( search_field=Concat( *search_fields, output_field=CharField() )).annotate( - similarity=TrigramSimilarity('search_field', query) - ).filter(similarity__gt=0.05).distinct() + similarity=TrigramStrictWordSimilarity( + 'search_field', query) + ).filter(similarity__gt=0.05) + + except ImportError: + # Perform very simple sqlite compatible search + search_terms = search_smart_split(query) + + orm_lookups = [ + self.construct_search(str(search_field), queryset) + for search_field in search_fields + ] + + base = queryset + # generator which for each term builds the corresponding search + conditions = ( + reduce( + operator.or_, + (models.Q(**{orm_lookup: term}) for orm_lookup in orm_lookups) + ) for term in search_terms + ) + queryset = queryset.filter(reduce(operator.and_, conditions)) + + # Remove duplicates from results, if necessary + if self.must_call_distinct(queryset, search_fields): + # inspired by django.contrib.admin + # this is more accurate than .distinct form M2M relationship + # also is cross-database + queryset = queryset.filter(pk=models.OuterRef('pk')) + queryset = base.filter(models.Exists(queryset)) return queryset