Made the search filter use TrigramStrictWordSimilarity instead of the plain version and made a backup plain sqlite compatible search for when postgres is not avaiable
All checks were successful
StarFields Django Rest Framework Generics / build (push) Successful in 16s
All checks were successful
StarFields Django Rest Framework Generics / build (push) Successful in 16s
This commit is contained in:
@@ -6,14 +6,32 @@ from rest_framework.filters import BaseFilterBackend
|
|||||||
from django.template import loader
|
from django.template import loader
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.db.models import Q, CharField
|
from django.db.models import Q
|
||||||
|
from rest_framework.fields import CharField
|
||||||
from django.db.models.constants import LOOKUP_SEP
|
from django.db.models.constants import LOOKUP_SEP
|
||||||
from django.db.models.functions import Concat
|
from django.db.models.functions import Concat
|
||||||
from django.contrib.postgres.search import TrigramSimilarity
|
|
||||||
|
|
||||||
|
|
||||||
# TODO the dev pages are not done
|
# TODO the dev pages are not done
|
||||||
|
|
||||||
|
def search_smart_split(search_terms):
|
||||||
|
"""generator that first splits string by spaces, leaving quoted phrases together,
|
||||||
|
then it splits non-quoted phrases by commas.
|
||||||
|
"""
|
||||||
|
split_terms = []
|
||||||
|
for term in smart_split(search_terms):
|
||||||
|
# trim commas to avoid bad matching for quoted phrases
|
||||||
|
term = term.strip(',')
|
||||||
|
if term.startswith(('"', "'")) and term[0] == term[-1]:
|
||||||
|
# quoted phrases are kept together without any other split
|
||||||
|
split_terms.append(unescape_string_literal(term))
|
||||||
|
else:
|
||||||
|
# non-quoted tokens are split by comma, keeping only non-empty ones
|
||||||
|
for sub_term in term.split(','):
|
||||||
|
if sub_term:
|
||||||
|
split_terms.append(sub_term.strip())
|
||||||
|
return split_terms
|
||||||
|
|
||||||
|
|
||||||
class LessThanOrEqualFilter(BaseFilterBackend):
|
class LessThanOrEqualFilter(BaseFilterBackend):
|
||||||
def get_less_than_field(self, view, request):
|
def get_less_than_field(self, view, request):
|
||||||
@@ -403,24 +421,136 @@ class TrigramSearchFilter(BaseFilterBackend):
|
|||||||
|
|
||||||
return self.filters_dict
|
return self.filters_dict
|
||||||
|
|
||||||
|
def get_search_fields(self, view, request):
|
||||||
|
"""
|
||||||
|
Search fields are obtained from the view, but the request is always
|
||||||
|
passed to this method. Sub-classes can override this method to
|
||||||
|
dynamically change the search fields based on request content.
|
||||||
|
"""
|
||||||
|
return getattr(view, 'search_fields', None)
|
||||||
|
|
||||||
|
def get_search_query(self, request):
|
||||||
|
"""
|
||||||
|
Search terms are set by a ?search=... query parameter,
|
||||||
|
and may be whitespace delimited.
|
||||||
|
"""
|
||||||
|
value = request.query_params.get(self.search_param, '')
|
||||||
|
field = CharField(trim_whitespace=False, allow_blank=True)
|
||||||
|
cleaned_value = field.run_validation(value)
|
||||||
|
return cleaned_value
|
||||||
|
|
||||||
|
def construct_search(self, field_name, queryset):
|
||||||
|
"""
|
||||||
|
For the sqlite search
|
||||||
|
"""
|
||||||
|
lookup = self.lookup_prefixes.get(field_name[0])
|
||||||
|
if lookup:
|
||||||
|
field_name = field_name[1:]
|
||||||
|
else:
|
||||||
|
# Use field_name if it includes a lookup.
|
||||||
|
opts = queryset.model._meta
|
||||||
|
lookup_fields = field_name.split(LOOKUP_SEP)
|
||||||
|
# Go through the fields, following all relations.
|
||||||
|
prev_field = None
|
||||||
|
for path_part in lookup_fields:
|
||||||
|
if path_part == "pk":
|
||||||
|
path_part = opts.pk.name
|
||||||
|
try:
|
||||||
|
field = opts.get_field(path_part)
|
||||||
|
except FieldDoesNotExist:
|
||||||
|
# Use valid query lookups.
|
||||||
|
if prev_field and prev_field.get_lookup(path_part):
|
||||||
|
return field_name
|
||||||
|
else:
|
||||||
|
prev_field = field
|
||||||
|
if hasattr(field, "path_infos"):
|
||||||
|
# Update opts to follow the relation.
|
||||||
|
opts = field.path_infos[-1].to_opts
|
||||||
|
# django < 4.1
|
||||||
|
elif hasattr(field, 'get_path_info'):
|
||||||
|
# Update opts to follow the relation.
|
||||||
|
opts = field.get_path_info()[-1].to_opts
|
||||||
|
# Otherwise, use the field with icontains.
|
||||||
|
lookup = 'icontains'
|
||||||
|
return LOOKUP_SEP.join([field_name, lookup])
|
||||||
|
|
||||||
|
def must_call_distinct(self, queryset, search_fields):
|
||||||
|
"""
|
||||||
|
Return True if 'distinct()' should be used to query the given lookups.
|
||||||
|
"""
|
||||||
|
for search_field in search_fields:
|
||||||
|
opts = queryset.model._meta
|
||||||
|
if search_field[0] in self.lookup_prefixes:
|
||||||
|
search_field = search_field[1:]
|
||||||
|
# Annotated fields do not need to be distinct
|
||||||
|
if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations:
|
||||||
|
continue
|
||||||
|
parts = search_field.split(LOOKUP_SEP)
|
||||||
|
for part in parts:
|
||||||
|
field = opts.get_field(part)
|
||||||
|
if hasattr(field, 'get_path_info'):
|
||||||
|
# This field is a relation, update opts to follow the relation
|
||||||
|
path_info = field.get_path_info()
|
||||||
|
opts = path_info[-1].to_opts
|
||||||
|
if any(path.m2m for path in path_info):
|
||||||
|
# This field is a m2m relation so we know we need to call distinct
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
# This field has a custom __ query transform but is not a relational field.
|
||||||
|
break
|
||||||
|
return False
|
||||||
|
|
||||||
def filter_queryset(self, request, queryset, view):
|
def filter_queryset(self, request, queryset, view):
|
||||||
search_fields = getattr(view, 'search_fields', None)
|
search_fields = self.get_search_fields(view, request)
|
||||||
|
|
||||||
assert search_fields is not None, (
|
assert search_fields is not None, (
|
||||||
f"{view.__class__.__name__} should include a `search_fields`"
|
f"{view.__class__.__name__} should include a `search_fields`"
|
||||||
"attribute"
|
"attribute"
|
||||||
)
|
)
|
||||||
|
|
||||||
query = request.query_params.get(self.search_param, '')
|
query = self.get_search_query(request)
|
||||||
|
|
||||||
if query:
|
if not query:
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Attempt postgresql's full text search
|
||||||
|
from django.contrib.postgres.search import TrigramStrictWordSimilarity
|
||||||
queryset = queryset.annotate(
|
queryset = queryset.annotate(
|
||||||
search_field=Concat(
|
search_field=Concat(
|
||||||
*search_fields,
|
*search_fields,
|
||||||
output_field=CharField()
|
output_field=CharField()
|
||||||
)).annotate(
|
)).annotate(
|
||||||
similarity=TrigramSimilarity('search_field', query)
|
similarity=TrigramStrictWordSimilarity(
|
||||||
).filter(similarity__gt=0.05).distinct()
|
'search_field', query)
|
||||||
|
).filter(similarity__gt=0.05)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
# Perform very simple sqlite compatible search
|
||||||
|
search_terms = search_smart_split(query)
|
||||||
|
|
||||||
|
orm_lookups = [
|
||||||
|
self.construct_search(str(search_field), queryset)
|
||||||
|
for search_field in search_fields
|
||||||
|
]
|
||||||
|
|
||||||
|
base = queryset
|
||||||
|
# generator which for each term builds the corresponding search
|
||||||
|
conditions = (
|
||||||
|
reduce(
|
||||||
|
operator.or_,
|
||||||
|
(models.Q(**{orm_lookup: term}) for orm_lookup in orm_lookups)
|
||||||
|
) for term in search_terms
|
||||||
|
)
|
||||||
|
queryset = queryset.filter(reduce(operator.and_, conditions))
|
||||||
|
|
||||||
|
# Remove duplicates from results, if necessary
|
||||||
|
if self.must_call_distinct(queryset, search_fields):
|
||||||
|
# inspired by django.contrib.admin
|
||||||
|
# this is more accurate than .distinct form M2M relationship
|
||||||
|
# also is cross-database
|
||||||
|
queryset = queryset.filter(pk=models.OuterRef('pk'))
|
||||||
|
queryset = base.filter(models.Exists(queryset))
|
||||||
|
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user