14 Commits

Author SHA1 Message Date
45a1af3c52 Updated the README and the mixing to provide better asserts.
All checks were successful
StarFields Django Rest Framework Generics / build (push) Successful in 36s
2025-01-19 07:19:02 +02:00
edf3dc051d Attempting to upgrade pkginfo during publish action to fix a twine error.
All checks were successful
StarFields Django Rest Framework Generics / build (push) Successful in 11s
2025-01-11 00:25:49 +02:00
2cee88413b Bump version number to 0.4.0
Some checks failed
StarFields Django Rest Framework Generics / build (push) Failing after 10s
2025-01-11 00:20:59 +02:00
a380800285 Added a multipart parser that attempts to parse nested strings as json.
Some checks failed
StarFields Django Rest Framework Generics / build (push) Failing after 12s
2025-01-11 00:19:46 +02:00
13022609fc Fixed a missing import.
Some checks failed
StarFields Django Rest Framework Generics / build (push) Failing after 35s
2025-01-06 10:53:08 +02:00
1e0f3c694e Refactored the library a bit and added a .gitignore file. 2025-01-05 10:13:59 +02:00
ac7ea3c88f Bump new version.
All checks were successful
StarFields Django Rest Framework Generics / build (push) Successful in 10s
2024-11-26 08:00:12 +02:00
ab3083eab3 Fixed a bug that was raising FieldError in dev because of failing trigram similarity search.
Some checks failed
StarFields Django Rest Framework Generics / build (push) Failing after 32s
2024-11-26 07:58:35 +02:00
40ff763f96 Fixing the publish workflow.
All checks were successful
StarFields Django Rest Framework Generics / build (push) Successful in 9s
2024-10-08 11:38:57 +03:00
94c3304c1c Fixing the publish workflow.
Some checks failed
StarFields Django Rest Framework Generics / build (push) Failing after 13s
2024-10-08 11:34:10 +03:00
14fb9d804f Fixing the publish workflow.
Some checks failed
StarFields Django Rest Framework Generics / build (push) Failing after 8s
2024-10-08 11:32:03 +03:00
498a6da603 Fixing the publish workflow.
Some checks failed
StarFields Django Rest Framework Generics / build (push) Failing after 10s
2024-10-08 11:19:58 +03:00
353106ee17 Fixed a facet filter bug where redundant code would error out depending on the query.
All checks were successful
StarFields Django Rest Framework Generics / build (push) Successful in 12s
2024-10-08 11:03:20 +03:00
28ac95fd47 Testing the workflow.
All checks were successful
StarFields Django Rest Framework Generics / build (push) Successful in 11s
2024-10-01 03:20:22 +03:00
15 changed files with 1283 additions and 1011 deletions

View File

@@ -4,6 +4,8 @@ on: [push]
env:
GITHUB_WORKFLOW_REF:
TWINE_USERNAME: ${{ secrets.GIT_PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.GIT_PYPI_PASSWORD }}
jobs:
build:
@@ -34,11 +36,16 @@ jobs:
run: pip install build && python -m build
- name: Publish package to Gitea PyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: ${{ secrets.GIT_PYPI_USERNAME }}
password: ${{ secrets.GIT_PYPI_PASSWORD }}
repository-url: https://git.starfieldsweb.com/api/packages/StarFields/pypi
print-hash: true
verbose: true
run: pip install twine && pip install --upgrade pkginfo && python -m twine upload --repository-url https://git.starfieldsweb.com/api/packages/StarFields/pypi ./dist/*
# - name: Publish package to Gitea PyPI
# continue-on-error: false
# uses: pypa/gh-action-pypi-publish@release/v1
# with:
# user: ${{ secrets.GIT_PYPI_USERNAME }}
# password: ${{ secrets.GIT_PYPI_PASSWORD }}
# repository-url: https://git.starfieldsweb.com/api/packages/StarFields/pypi
# print-hash: true
# verbose: true
# TODO make a release section that creates a gitea release

5
.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
# Ignore all the python cache files
**/__pycache__/**/*
# Ignore all the dist files
**/dist/**/*

View File

@@ -2,7 +2,17 @@ This repository holds the django library that StarFields uses for the django-res
# Differences with the DRF generic views
It changes the generic lifecycles of all the CRUD operations to fit within them automated caching functionality. Caching and deleting cache keys is handled by the library in a way that the cache keys have no duplicates. The generic views offered include single item CRUD and list-based CRUD.
The generic views of DRF use the serializers in such a way that the model serializers directly integrate the functionality of a specific model (db table) with CRUD. As a result simple operations such as deleting a list of n table rows ends up using n queries. This serves automated CRUD well but control and performance suffer, in particular changing the view and serializer methods to use a single query in order to perform a write operation becomes a very non-uniform experience between request methods. This library changes the DRF generic views to be more uniform in the way they use the serializers, in particular a single query is made for GET requests with elaborate filtering capabilities and calls to serializer .create(), .update() and .destroy() methods for write methods.
In particular:
- Single Create operations create the model instance as expected
- Single Retrive, Update and Destroy work with the .get_object() callable to find and manipulate the instance
- List Retrieve operations work through elaborate filters to get results starting with the .get_queryset() callable.
- List Create, Update and Destroy need to implement ListSerializer .create(), .update() and .destroy() methods for bulk operations. None of those methods use the .get_queryset() callable.
It is easy to notice that the single item apis are useful but limited. List apis allow you to perform much more flexible operations and organize frontends better as well, allowing for example out of step syncing (such as shopping carts). As a result it is recommended to use strictly list based generic views for all apis that are supposed to be flexibly used and restrict yourself to using single item generic views to apis whose access you want to partially restrict to the outside world from scanning (such as users) or are inherently simple.
The generics are also enhanced to fit within them automated caching functionality. Caching and deleting cache keys is handled by the library in a way that the cache keys have no duplicates. The generic views offered include single item CRUD and list-based CRUD.
To manage automated caching this the library replaces (and appends to) the DRF filters. These filters need a get_unique_dict method in order to avoid the duplicate cache keys problem.

View File

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "starfields-drf-generics"
version = "0.3.0"
version = "0.4.1"
authors = [
{ name="Anastasios Svolis", email="support@starfields.gr" },
]
@@ -18,7 +18,7 @@ classifiers = [
]
[project.urls]
"Homepage" = "https://git.vickys-corner.xyz/ace/starfields-drf-generics"
"Homepage" = "https://git.starfieldsweb.com/StarFields/starfields-drf-generics"
[tool.setuptools.packages.find]
where = ["starfields_drf_generics"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,126 @@
from rest_framework.filters import BaseFilterBackend
from django.template import loader
from django.db.models import Q
class CategoryFilter(BaseFilterBackend):
"""
This filter assigns the view.category object for later use, in particular
for filters that depend on this one.
"""
template = 'filters/categories.html'
category_field = 'category'
def get_category_class(self, view, request):
return getattr(view, 'category_class', None)
def assign_view_category(self, request, view):
if not hasattr(view, self.category_field):
if self.category_field in request.query_params.keys():
try:
category_slug = request.query_params.get(
self.category_field).strip()
category = view.category_class.objects.get(
slug=category_slug)
# Append the category object in the view for later use
view.category = category
except Exception:
view.category = None
else:
view.category = None
def get_filters_dict(self, request, view):
"""
Custom method that returns the filters exclusive to this filter in a
dict. For caching purposes. Queries the database for the current
category and saves it in view.category for internal use and later
filters.
"""
if hasattr(view, 'category_field'):
self.category_field = self.get_category_field(view, request)
category_class = self.get_category_class(view, request)
assert category_class is not None, (
f"{view.__class__.__name__} should include a `category_class`"
"attribute"
)
# Create the filters dictionary and find the present category instance
self.assign_view_category(request, view)
filters_dict = {}
if view.category:
filters_dict[self.category_field] = [view.category.slug]
else:
filters_dict[self.category_field] = []
return filters_dict
def filter_queryset(self, request, queryset, view):
self.assign_view_category(request, view)
# Create the queryset
if view.category:
kwquery_1 = {}
kwquery_1[self.category_field+'__id'] = view.category.id
if view.category.tn_descendants_pks:
kwquery_2 = {}
key = self.category_field+'__id__in'
kwquery_2[key] = view.category.tn_descendants_pks.split(',')
queryset = queryset.filter(Q(**kwquery_1) | Q(**kwquery_2))
else:
queryset = queryset.filter(**kwquery_1)
return queryset
# Developer Interface methods
def get_valid_fields(self, queryset, view, context, request):
# A query is executed here to get the possible categories
category_class = self.get_category_class(view, request)
if hasattr(view, 'category_field'):
self.category_field = self.get_category_field(view, request)
assert category_class is not None, (
f"{view.__class__.__name__} should include a `category_class`"
"attribute"
)
valid_fields = category_class.objects.all()
if len(valid_fields):
valid_fields = [
(item.slug, item.__str__()) for item in valid_fields
]
return valid_fields
else:
return []
def get_current(self, request, queryset, view):
params = request.query_params.get(self.category_field)
if params:
fields = [param.strip() for param in params.split(',')]
return fields[0]
else:
return None
def get_template_context(self, request, queryset, view):
current = self.get_current(request, queryset, view)
options = []
context = {
'request': request,
'current': current,
'param': self.category_field,
}
valid_fields = self.get_valid_fields(queryset, view, context, request)
for key, label in valid_fields:
options.append((key, '%s' % (label)))
context['options'] = options
return context
def to_html(self, request, queryset, view):
template = loader.get_template(self.template)
context = self.get_template_context(request, queryset, view)
return template.render(context)

View File

@@ -0,0 +1,170 @@
from rest_framework.filters import BaseFilterBackend
from django.template import loader
from django.db.models import Q
class FacetFilter(BaseFilterBackend):
"""
This filter requires CategoryFilter to be ran before it. It assigns the
view.facets which includes all the facets applicable to the current
category.
"""
template = 'filters/facets.html'
def get_facet_class(self, view, request):
return getattr(view, 'facet_class', None)
def get_facet_tag_class(self, view, request):
return getattr(view, 'facet_tag_class', None)
def get_facet_tag_field(self, view, request):
return getattr(view, 'facet_tag_field', None)
def assign_view_facets(self, request, view):
if not hasattr(view, 'facets'):
if hasattr(view, 'facet_class'):
self.facet_class = self.get_facet_class(view, request)
assert self.facet_class is not None, (
f"{view.__class__.__name__} should include a `facet_class`"
"attribute"
)
if view.category:
if view.category.tn_ancestors_pks:
ancestor_ids = view.category.tn_ancestors_pks.split(',')
view.facets = self.facet_class.objects.filter(
Q(category__id=view.category.id) | Q(
category__id__in=ancestor_ids)
).prefetch_related('facet_tags')
else:
view.facets = self.facet_class.objects.filter(
category__id=view.category.id).prefetch_related(
'facet_tags')
else:
view.facets = self.facet_class.objects.filter(
category__tn_level=1).prefetch_related(
'facet_tags')
def get_filters_dict(self, request, view):
"""
Custom method that returns the filters exclusive to this filter in a
dict. For caching purposes.
"""
if hasattr(view, 'facet_class'):
self.facet_class = self.get_facet_class(view, request)
assert self.facet_class is not None, (
f"{view.__class__.__name__} should include a `facet_class`"
"attribute"
)
self.assign_view_facets(request, view)
filters_dict = {}
if view.facets:
for facet in view.facets:
if facet.slug in request.query_params.keys():
filters_dict[facet.slug] = set(
request.query_params[facet.slug].split(','))
else:
filters_dict[facet.slug] = set({})
# Append the facets object and the tags dict in the view for later
# reference
view.tags = filters_dict
return filters_dict
def filter_queryset(self, request, queryset, view):
if hasattr(view, 'facet_tag_class'):
self.facet_tag_class = self.get_facet_tag_class(view, request)
assert self.facet_tag_class is not None, (
f"{view.__class__.__name__} should include a `facet_tag_class`"
"attribute"
)
if hasattr(view, 'facet_tag_field'):
self.facet_tag_field = self.get_facet_tag_field(view, request)
assert self.facet_tag_field is not None, (
f"{view.__class__.__name__} should include a `facet_tag_field`"
"attribute"
)
self.assign_view_facets(request, view)
if view.facets:
for facet in view.facets:
if facet.slug in request.query_params.keys():
tag_filterlist = request.query_params.get(facet.slug)
if tag_filterlist == '':
# If the tag filterlist is empty then we're not
# filtering against it, it's like having all the tags
# of the facet selected
pass
else:
kwquery = {}
key = self.facet_tag_field+'__slug__in'
kwquery[key] = tag_filterlist.replace(' ', '').split(
',')
queryset = queryset.filter(**kwquery)
return queryset
# Developer Interface methods
def get_template_context(self, request, queryset, view):
# Does aggressive database querying to get the necessary facets and
# facettags, but this is only for the developer interface so its fine
if hasattr(view, 'facet_class'):
self.facet_class = self.get_facet_class(view, request)
assert self.facet_class is not None, (
f"{view.__class__.__name__} should include a `facet_class`"
"attribute"
)
if hasattr(view, 'facet_tag_class'):
self.facet_tag_class = self.get_facet_tag_class(view, request)
assert self.facet_tag_class is not None, (
f"{view.__class__.__name__} should include a `facet_tag_class`"
"attribute"
)
# Find the current choices
current = []
facet_slugs = []
if view.facets:
for facet in view.facets:
facet_slugs.append(facet.slug)
if facet.slug in request.query_params.keys():
current.append(request.query_params.get(facet.slug))
facet_slug_names = {}
options = {}
context = {
'request': request,
'current': current,
'facet_slugs': facet_slugs,
}
if view.facets:
for facet in view.facets:
facet_tag_instances = self.facet_tag_class.objects.filter(
facet__slug=facet.slug)
options[facet.slug] = [(facet_tag.slug, facet_tag.name) for
facet_tag in facet_tag_instances]
facet_slug_names[facet.slug] = facet.name
context['facet_slug_names'] = facet_slug_names
context['options'] = options
else:
context['facet_slug_names'] = {}
context['options'] = {}
return context
def to_html(self, request, queryset, view):
template = loader.get_template(self.template)
context = self.get_template_context(request, queryset, view)
return template.render(context)

View File

@@ -0,0 +1,38 @@
from rest_framework.filters import BaseFilterBackend
class LessThanOrEqualFilter(BaseFilterBackend):
def get_less_than_field(self, view, request):
return getattr(view, 'less_than_field', None)
def get_filters_dict(self, request, view):
"""
Custom method that returns the filters exclusive to this filter in a
dict. For caching purposes.
"""
less_than_field = self.get_less_than_field(view, request)
assert less_than_field is not None, (
f"{view.__class__.__name__} should include a `less_than_field`"
"attribute"
)
filters_dict = {}
if less_than_field+'_max' in request.query_params.keys():
field_value = request.query_params.get(less_than_field+'_max')
filters_dict[less_than_field+'_max'] = [field_value]
else:
filters_dict[less_than_field+'_max'] = []
return filters_dict
def filter_queryset(self, request, queryset, view):
# Return the correctly filtered queryset but also assign the filter
# dict to create the unique url for the cache
less_than_field = self.get_less_than_field(view, request)
if less_than_field+'_max' in request.query_params.keys():
kwquery = {}
field_value = request.query_params.get(less_than_field+'_max')
kwquery[less_than_field+'__lte'] = field_value
return queryset.filter(**kwquery)
else:
return queryset

View File

@@ -0,0 +1,39 @@
from rest_framework.filters import BaseFilterBackend
class MoreThanOrEqualFilter(BaseFilterBackend):
def get_more_than_field(self, view, request):
return getattr(view, 'more_than_field', None)
def get_filters_dict(self, request, view):
"""
Custom method that returns the filters exclusive to this filter in a
dict. For caching purposes.
"""
more_than_field = self.get_more_than_field(view, request)
assert more_than_field is not None, (
f"{view.__class__.__name__} should include a `more_than_field`"
"attribute"
)
filters_dict = {}
if more_than_field+'_min' in request.query_params.keys():
field_value = request.query_params.get(more_than_field+'_min')
filters_dict[more_than_field+'_min'] = [field_value]
else:
filters_dict[more_than_field+'_min'] = []
return filters_dict
def filter_queryset(self, request, queryset, view):
"""
Return the correctly filtered queryset
"""
more_than_field = self.get_more_than_field(view, request)
if more_than_field+'_min' in request.query_params.keys():
kwquery = {}
kwquery[more_than_field+'__gte'] = request.query_params.get(
more_than_field+'_min')
return queryset.filter(**kwquery)
else:
return queryset

View File

@@ -0,0 +1,68 @@
from rest_framework.settings import api_settings
from rest_framework.filters import BaseFilterBackend
from django.db.models import Q
class TreeNodeBranchFilter(BaseFilterBackend):
def get_descendants_of_field(self, view, request):
return getattr(view, 'descendants_of_field', None)
def get_depth_field(self, view, request):
return getattr(view, 'depth_field', None)
def get_filters_dict(self, request, view):
"""
Custom method that returns the filters exclusive to this filter in a
dict. For caching purposes.
"""
descendants_of_field = self.get_descendants_of_field(view, request)
assert descendants_of_field is not None, (
"{view.__class__.__name__} should include a "
f"`descendants_of_field` attribute"
)
depth_field = self.get_depth_field(view, request)
assert depth_field is not None, (
"{view.__class__.__name__} should include a "
f"`depth_field` attribute"
)
filters_dict = {}
if descendants_of_field in request.query_params.keys():
field_value = request.query_params.get(descendants_of_field)
filters_dict[descendants_of_field] = [field_value]
else:
filters_dict[descendants_of_field] = []
if depth_field in request.query_params.keys():
field_value = request.query_params.get(depth_field)
filters_dict[depth_field] = [field_value]
else:
filters_dict[depth_field] = [4]
return filters_dict
def filter_queryset(self, request, queryset, view):
# Return the correctly filtered queryset but also assign the filter
# dict to create the unique url for the cache
descendants_of_field = self.get_descendants_of_field(view, request)
depth_field = self.get_depth_field(view, request)
if descendants_of_field in request.query_params.keys():
field_value = request.query_params.get(descendants_of_field)
# Instead of doing two queries to get the descendants through the
# object a single more complex queryset
queryset = queryset.filter(
Q(tn_ancestors_pks=field_value) |
Q(tn_ancestors_pks__contains=","+field_value) |
Q(tn_ancestors_pks__contains=field_value+","))
if depth_field in request.query_params.keys():
field_value = request.query_params.get(depth_field)
queryset = queryset.filter(tn_level__lte=field_value)
else:
queryset = queryset.filter(tn_level__lte=4)
return queryset

View File

@@ -0,0 +1,180 @@
from rest_framework.settings import api_settings
from rest_framework.filters import BaseFilterBackend
from django.template import loader
from django.utils.translation import gettext_lazy as _
class OrderingFilter(BaseFilterBackend):
# The URL query parameter used for the ordering.
ordering_param = api_settings.ORDERING_PARAM
ordering_fields = None
ordering_title = _('Ordering')
ordering_description = _('Which field to use when ordering the results.')
template = 'rest_framework/filters/ordering.html'
def get_ordering(self, request, queryset, view):
"""
Ordering is set by a comma delimited ?ordering=... query parameter.
The `ordering` query parameter can be overridden by setting
the `ordering_param` value on the OrderingFilter or by
specifying an `ORDERING_PARAM` value in the API settings.
"""
params = request.query_params.get(self.ordering_param)
if params:
fields = [param.strip() for param in params.split(',')]
ordering = self.remove_invalid_fields(queryset,
fields,
view,
request)
if ordering:
return ordering
# No ordering was included, or all the ordering fields were invalid
return self.get_default_ordering(view)
def get_default_ordering(self, view):
ordering = getattr(view, 'ordering', None)
if isinstance(ordering, str):
return (ordering,)
return ordering
def get_default_valid_fields(self, queryset, view, context={}):
# If `ordering_fields` is not specified, then we determine a default
# based on the serializer class, if one exists on the view.
if hasattr(view, 'get_serializer_class'):
try:
serializer_class = view.get_serializer_class()
except AssertionError:
# Raised by the default implementation if
# no serializer_class was found
serializer_class = None
else:
serializer_class = getattr(view, 'serializer_class', None)
if serializer_class is None:
msg = (
"Cannot use %s on a view which does not have either a "
"'serializer_class', an overriding 'get_serializer_class' "
"or 'ordering_fields' attribute."
)
raise ImproperlyConfigured(msg % self.__class__.__name__)
model_class = queryset.model
model_property_names = [
# 'pk' is a property added in Django's Model class, however it is valid for ordering.
attr for attr in dir(model_class) if isinstance(getattr(model_class, attr), property) and attr != 'pk'
]
return [
(field.source.replace('.', '__') or field_name, field.label)
for field_name, field in serializer_class(context=context).fields.items()
if (
not getattr(field, 'write_only', False) and
not field.source == '*' and
field.source not in model_property_names
)
]
def get_valid_fields(self, queryset, view, context={}):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
if valid_fields is None:
# Default to allowing filtering on serializer fields
return self.get_default_valid_fields(queryset, view, context)
elif valid_fields == '__all__':
# View explicitly allows filtering on any model field
valid_fields = [
(field.name, field.verbose_name) for field in queryset.model._meta.fields
]
valid_fields += [
(key, key.title().split('__'))
for key in queryset.query.annotations
]
else:
valid_fields = [
(item, item) if isinstance(item, str) else item
for item in valid_fields
]
return valid_fields
def remove_invalid_fields(self, queryset, fields, view, request):
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]
def term_valid(term):
if term.startswith("-"):
term = term[1:]
return term in valid_fields
return [term for term in fields if term_valid(term)]
def get_filters_dict(self, request, view):
"""
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
"""
self.filters_dict = {}
if 'ordering' in request.query_params.keys():
slug_term = request.query_params.get('ordering')
self.filters_dict['ordering'] = [slug_term]
else:
self.filters_dict['ordering'] = [view.ordering_fields[0]]
return self.filters_dict
def filter_queryset(self, request, queryset, view):
ordering = self.get_ordering(request, queryset, view)
if ordering:
return queryset.order_by(*ordering)
return queryset
def get_template_context(self, request, queryset, view):
current = self.get_ordering(request, queryset, view)
current = None if not current else current[0]
options = []
context = {
'request': request,
'current': current,
'param': self.ordering_param,
}
for key, label in self.get_valid_fields(queryset, view, context):
options.append((key, '%s - %s' % (label, _('ascending'))))
options.append(('-' + key, '%s - %s' % (label, _('descending'))))
context['options'] = options
return context
def to_html(self, request, queryset, view):
template = loader.get_template(self.template)
context = self.get_template_context(request, queryset, view)
return template.render(context)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [
coreapi.Field(
name=self.ordering_param,
required=False,
location='query',
schema=coreschema.String(
title=force_str(self.ordering_title),
description=force_str(self.ordering_description)
)
)
]
def get_schema_operation_parameters(self, view):
return [
{
'name': self.ordering_param,
'required': False,
'in': 'query',
'description': force_str(self.ordering_description),
'schema': {
'type': 'string',
},
},
]

View File

@@ -0,0 +1,247 @@
from django.utils.text import smart_split
from django.core.exceptions import FieldError, FieldDoesNotExist
import operator
from functools import reduce
from rest_framework.filters import BaseFilterBackend
from django.template import loader
from django.utils.translation import gettext_lazy as _
from django.db import models
from django.db.models import Q
from rest_framework.fields import CharField
from django.db.models.constants import LOOKUP_SEP
from django.db.models.functions import Concat
def calculate_threshold(query, min_threshold, max_threshold):
query_threshold = len(query)/300
if query_threshold < min_threshold:
return min_threshold
if max_threshold < query_threshold:
return max_threshold
return query_threshold
def search_smart_split(search_terms):
"""
Generator that first splits string by spaces, leaving quoted phrases
together, then it splits non-quoted phrases by commas.
"""
split_terms = []
for term in smart_split(search_terms):
# trim commas to avoid bad matching for quoted phrases
term = term.strip(',')
if term.startswith(('"', "'")) and term[0] == term[-1]:
# quoted phrases are kept together without any other split
split_terms.append(unescape_string_literal(term))
else:
# non-quoted tokens are split by comma, keeping only non-empty ones
for sub_term in term.split(','):
if sub_term:
split_terms.append(sub_term.strip())
return split_terms
class TrigramSearchFilter(BaseFilterBackend):
# The URL query parameter used for the search.
search_param = 'search'
template = 'rest_framework/filters/search.html'
search_title = _('Search')
search_description = _('A search string to perform trigram similarity'
'based searching with.')
lookup_prefixes = {
'^': 'istartswith',
'=': 'iexact',
'@': 'search',
'$': 'iregex',
}
def get_filters_dict(self, request, view):
"""
Custom method that returns the filters exclusive to this filter in a
dict. For caching purposes.
"""
self.filters_dict = {}
if 'search' in request.query_params.keys():
slug_term = request.query_params.get('search')
self.filters_dict['search'] = [slug_term]
else:
self.filters_dict['search'] = []
return self.filters_dict
def get_search_fields(self, view, request):
"""
Search fields are obtained from the view, but the request is always
passed to this method. Sub-classes can override this method to
dynamically change the search fields based on request content.
"""
return getattr(view, 'search_fields', None)
def get_search_query(self, request):
"""
Search terms are set by a ?search=... query parameter,
and may be whitespace delimited.
"""
value = request.query_params.get(self.search_param, '')
field = CharField(trim_whitespace=False, allow_blank=True)
cleaned_value = field.run_validation(value)
return cleaned_value
def construct_search(self, field_name, queryset):
"""
For the sqlite search
"""
lookup = self.lookup_prefixes.get(field_name[0])
if lookup:
field_name = field_name[1:]
else:
# Use field_name if it includes a lookup.
opts = queryset.model._meta
lookup_fields = field_name.split(LOOKUP_SEP)
# Go through the fields, following all relations.
prev_field = None
for path_part in lookup_fields:
if path_part == "pk":
path_part = opts.pk.name
try:
field = opts.get_field(path_part)
except FieldDoesNotExist:
# Use valid query lookups.
if prev_field and prev_field.get_lookup(path_part):
return field_name
else:
prev_field = field
if hasattr(field, "path_infos"):
# Update opts to follow the relation.
opts = field.path_infos[-1].to_opts
# django < 4.1
elif hasattr(field, 'get_path_info'):
# Update opts to follow the relation.
opts = field.get_path_info()[-1].to_opts
# Otherwise, use the field with icontains.
lookup = 'icontains'
return LOOKUP_SEP.join([field_name, lookup])
def must_call_distinct(self, queryset, search_fields):
"""
Return True if 'distinct()' should be used to query the given lookups.
"""
for search_field in search_fields:
opts = queryset.model._meta
if search_field[0] in self.lookup_prefixes:
search_field = search_field[1:]
# Annotated fields do not need to be distinct
if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations:
continue
parts = search_field.split(LOOKUP_SEP)
for part in parts:
field = opts.get_field(part)
if hasattr(field, 'get_path_info'):
# This field is a relation, update opts to follow the relation
path_info = field.get_path_info()
opts = path_info[-1].to_opts
if any(path.m2m for path in path_info):
# This field is a m2m relation so we know we need to call distinct
return True
else:
# This field has a custom __ query transform but is not a relational field.
break
return False
def filter_queryset(self, request, queryset, view):
search_fields = self.get_search_fields(view, request)
assert search_fields is not None, (
f"{view.__class__.__name__} should include a `search_fields`"
"attribute"
)
query = self.get_search_query(request)
if not query:
return queryset
try:
# Attempt postgresql's full text search
from django.contrib.postgres.search import TrigramStrictWordSimilarity
threshold = calculate_threshold(query, 0.02, 0.12)
queryset = queryset.annotate(
search_field=Concat(
*search_fields,
output_field=CharField()
)).annotate(
similarity=TrigramStrictWordSimilarity(
'search_field', query)
).filter(similarity__gt=threshold)
# NOTE a weird FieldError is raised on sqlite
except (ImportError, FieldError):
# Perform very simple sqlite compatible search
search_terms = search_smart_split(query)
orm_lookups = [
self.construct_search(str(search_field), queryset)
for search_field in search_fields
]
base = queryset
# generator which for each term builds the corresponding search
conditions = (
reduce(
operator.or_,
(models.Q(**{orm_lookup: term}) for orm_lookup in orm_lookups)
) for term in search_terms
)
queryset = queryset.filter(reduce(operator.and_, conditions))
# Remove duplicates from results, if necessary
if self.must_call_distinct(queryset, search_fields):
# inspired by django.contrib.admin
# this is more accurate than .distinct form M2M relationship
# also is cross-database
queryset = queryset.filter(pk=models.OuterRef('pk'))
queryset = base.filter(models.Exists(queryset))
return queryset
def to_html(self, request, queryset, view):
if not getattr(view, 'search_fields', None):
return ''
term = request.query_params.get(self.search_param, '')
term = term[0] if term else ''
context = {
'param': self.search_param,
'term': term
}
template = loader.get_template(self.template)
return template.render(context)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [
coreapi.Field(
name=self.search_param,
required=False,
location='query',
schema=coreschema.String(
title=force_str(self.search_title),
description=force_str(self.search_description)
)
)
]
def get_schema_operation_parameters(self, view):
return [
{
'name': self.search_param,
'required': False,
'in': 'query',
'description': force_str(self.search_description),
'schema': {
'type': 'string',
},
},
]

View File

@@ -109,7 +109,7 @@ class CachedListCreateModelMixin(CacheDeleteMixin):
A fully custom mixin that handles mutiple instance cration.
"""
def list_create(self, request):
def list_create(self, request, **kwargs):
" Creates the list of entries in the request "
# Go on with the creation as normal
serializer = self.get_serializer(data=request.data, many=True)
@@ -123,8 +123,24 @@ class CachedListCreateModelMixin(CacheDeleteMixin):
headers=headers)
def perform_create(self, serializer):
" Generic save hook "
serializer.save()
"""
Uses serializer.create instead of serializer.save to avoid making a
query. We save the returned instance list to the serializer in order to
be used as serializer.data during rendering
"""
assert hasattr(serializer, 'create'), (
f'Cannot call .create() on serializer {serializer.__class__} as'
' no such attribute exists.'
)
validated_data = serializer.validated_data
instance_list = serializer.create(validated_data)
# Check whatever you can
assert hasattr(instance_list, '__iter__'), (
'Method .create() on serializer on serializer '
f'{serializer.__class__} should return a list of serializable'
' model instances.'
)
serializer.instance = instance_list
def get_success_headers(self, data):
" Returns extra success headers "
@@ -142,7 +158,7 @@ class CachedListRetrieveModelMixin(CacheGetMixin, CacheSetMixin):
inherit anything from it.
"""
def list(self, request):
def list(self, request, **kwargs):
" Retrieves the listing of entries "
# Attempt to get the request from the cache
cache_attempt = self.get_cache(request)
@@ -179,9 +195,7 @@ class CachedListUpdateModelMixin(CacheDeleteMixin):
" Updates the list of entries in the request "
partial = kwargs.pop('partial', False)
queryset = self.filter_queryset(self.get_queryset())
serializer = self.get_serializer(queryset, data=request.data,
serializer = self.get_serializer(data=request.data,
partial=partial, many=True)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
@@ -192,8 +206,24 @@ class CachedListUpdateModelMixin(CacheDeleteMixin):
return Response(serializer.data)
def perform_update(self, serializer):
" Generic save hook "
serializer.save()
"""
Uses serializer.update instead of serializer.save to avoid making a
query. We save the returned instance list to the serializer in order to
be used as serializer.data during rendering
"""
assert hasattr(serializer, 'update'), (
f'Cannot call .update() on serializer {serializer.__class__} as'
' no such attribute exists.'
)
validated_data = serializer.validated_data
instance_list = serializer.update(None, validated_data)
# Check whatever you can
assert hasattr(instance_list, '__iter__'), (
'Method .update() on serializer on serializer '
f'{serializer.__class__} should return a list of serializable '
' model instances.'
)
serializer.instance = instance_list
def list_partial_update(self, request, *args, **kwargs):
" Needs to be called on partial updates "
@@ -206,21 +236,24 @@ class CachedListDestroyModelMixin(CacheDeleteMixin):
A fully custom mixin that handles mutiple instance deletions.
"""
def list_destroy(self, request):
def list_destroy(self, request, **kwargs):
" Deletes the list of entries in the request "
# Go on with the validation as normal
serializer = self.get_serializer(data=request.data, many=True)
serializer.is_valid(raise_exception=True)
validated_data = serializer.validated_data
# TODO does this new stuff work even? need to check on the frontend
serializer.delete(validated_data)
# for instance in self.get_objects():
# if instance is not None:
# self.perform_destroy(instance)
self.perform_destroy(serializer)
# Delete the related caches
self.delete_cache(request)
return Response(status=status.HTTP_204_NO_CONTENT)
def perform_destroy(self, serializer):
" Custom generic destroy hook "
assert hasattr(serializer, 'destroy'), (
f'Cannot call .destroy() on serializer {serializer.__class__} as'
' no such attribute exists.'
)
validated_data = serializer.validated_data
serializer.destroy(validated_data)

View File

@@ -0,0 +1,73 @@
from rest_framework.parsers import BaseParser, DataAndFiles
from django.http.multipartparser import MultiPartParserError
from rest_framework.exceptions import ParseError
from django.http.multipartparser import \
MultiPartParser as DjangoMultiPartParser
from django.conf import settings
import json
class NestedJsonMultiPartParser(BaseParser):
"""
Parser for multipart form data, which may include file data.
"""
media_type = 'multipart/form-data'
def parse(self, stream, media_type=None, parser_context=None):
"""
Parses the incoming bytestream as a multipart encoded form,
and returns a DataAndFiles object.
The main difference with the parser from rest_framework.parsers
is that it attempts to parse nested strings as json to fit with
better with json payloads. I also ensure that a single file is
passed per field instead of a list which was erroring out
FieldFile.to_internal_value.
`.data` will be a dict containing all the form parameters.
`.files` will be a dict containing all the form files.
"""
parser_context = parser_context or {}
request = parser_context['request']
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
meta = request.META.copy()
meta['CONTENT_TYPE'] = media_type
upload_handlers = request.upload_handlers
try:
parser = DjangoMultiPartParser(meta,
stream,
upload_handlers,
encoding)
data, files = parser.parse()
# Attempt to parse the multipart fields as json, this is not
# demanding since multiparts exist exclusively for file uploads
# which is much more demanding
data_dict = {}
for key in data.keys():
values = data[key]
try:
data_dict[key] = json.loads(values)
except Exception as e:
data_dict[key] = values
# Make sure the filenames become file names
for filename in files.keys():
uploaded_file = files[filename]
hasfilename = hasattr(uploaded_file, '_name')
hasname = hasattr(uploaded_file, 'name')
if hasfilename and not hasname:
uploaded_file.name = uploaded_file._name
# Turn the monstrous MultiValueDict into a normal dict so that
# there is only a single uploaded file per field passed
files_dict = {}
for key in files.keys():
uploaded_file = files[key]
if isinstance(uploaded_file, list):
uploaded_file = uploaded_file[0]
files_dict[key] = uploaded_file
return DataAndFiles(data_dict, files_dict)
except MultiPartParserError as exc:
raise ParseError('Multipart form parse error - %s' % str(exc))