Compare commits
33 Commits
starfields
...
starfields
| Author | SHA1 | Date | |
|---|---|---|---|
| 14dc2cd4af | |||
| 0cde861177 | |||
| 00bbb65d21 | |||
| 75e4a70fce | |||
| 7be7104346 | |||
| cc2528344d | |||
| 621fcace85 | |||
| 7e4596e5b1 | |||
| dad7aa8348 | |||
| dcf3b3990b | |||
| 2b8a4863c0 | |||
| 0706fc5dc8 | |||
| cd79201d4a | |||
| 8a6a73d321 | |||
| 76181a46ac | |||
| bc5324bef6 | |||
| 3c2eddfea0 | |||
| 942b576521 | |||
| 2ec35f434d | |||
| 2f61d197ae | |||
| afa6a6f9c6 | |||
| 9c873d5d8e | |||
| e370098312 | |||
| 6466efa9fc | |||
| 813135ed43 | |||
| 4a45f05f2d | |||
| 0fccdf60bd | |||
| a86f25b230 | |||
| 4e2b0ec0c6 | |||
| 18812423c5 | |||
| d8c56fc9d1 | |||
| 06d612cd70 | |||
| 053c50dda1 |
44
.gitea/workflows/publish.yaml
Normal file
44
.gitea/workflows/publish.yaml
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
name: StarFields Django Rest Framework Generics
|
||||||
|
|
||||||
|
on: [push]
|
||||||
|
|
||||||
|
env:
|
||||||
|
GITHUB_WORKFLOW_REF:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
# Bizzarre bug that needed the below according to https://forum.gitea.com/t/gitea-actions-with-python/7605/3
|
||||||
|
container: catthehacker/ubuntu:act-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
run: |
|
||||||
|
git clone https://${{secrets.GIT_USERNAME}}:${{secrets.GIT_PASSWORD_URLENCODED}}@git.starfieldsweb.com/StarFields/starfields-drf-generics ./
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
# This is the version of the action for setting up Python, not the Python version.
|
||||||
|
uses: https://github.com/actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
# Semantic version range syntax or exact version of a Python version
|
||||||
|
python-version: '3.12.3'
|
||||||
|
|
||||||
|
- name: Display Python version
|
||||||
|
run: python -c "import sys; print(sys.version)"
|
||||||
|
|
||||||
|
# TODO testing
|
||||||
|
# check https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#packaging-workflow-data-as-artifacts
|
||||||
|
|
||||||
|
- name: Build package
|
||||||
|
run: pip install build && python -m build
|
||||||
|
|
||||||
|
- name: Publish package to Gitea PyPI
|
||||||
|
continue-on-error: true
|
||||||
|
uses: pypa/gh-action-pypi-publish@release/v1
|
||||||
|
with:
|
||||||
|
user: ${{ secrets.GIT_PYPI_USERNAME }}
|
||||||
|
password: ${{ secrets.GIT_PYPI_PASSWORD }}
|
||||||
|
repository-url: https://git.starfieldsweb.com/api/packages/StarFields/pypi
|
||||||
|
print-hash: true
|
||||||
|
verbose: true
|
||||||
@@ -7,14 +7,6 @@ It changes the generic lifecycles of all the CRUD operations to fit within them
|
|||||||
To manage automated caching this the library replaces (and appends to) the DRF filters. These filters need a get_unique_dict method in order to avoid the duplicate cache keys problem.
|
To manage automated caching this the library replaces (and appends to) the DRF filters. These filters need a get_unique_dict method in order to avoid the duplicate cache keys problem.
|
||||||
|
|
||||||
# Usage
|
# Usage
|
||||||
### Ensure that the module is in the INSTALLED_APPS in settings.py:
|
|
||||||
```python
|
|
||||||
INSTALLED_APPS = [
|
|
||||||
...
|
|
||||||
'starfields_drf_generics',
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
### Making views in views.py:
|
### Making views in views.py:
|
||||||
```python
|
```python
|
||||||
from starfields_drf_generics import generics
|
from starfields_drf_generics import generics
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "starfields-drf-generics"
|
name = "starfields-drf-generics"
|
||||||
version = "0.1.0"
|
version = "0.3.0"
|
||||||
authors = [
|
authors = [
|
||||||
{ name="Anastasios Svolis", email="support@starfields.gr" },
|
{ name="Anastasios Svolis", email="support@starfields.gr" },
|
||||||
]
|
]
|
||||||
@@ -27,5 +27,5 @@ where = ["starfields_drf_generics"]
|
|||||||
"templates.filters" = ["*.html"]
|
"templates.filters" = ["*.html"]
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
include = ["*.py"]
|
include = ["*.py", "*.html"]
|
||||||
exclude = ["test*"]
|
exclude = ["test*"]
|
||||||
|
|||||||
@@ -1,36 +1,45 @@
|
|||||||
from starfields_drf_generics.utils import sorted_params_string
|
from starfields_drf_generics.utils import sorted_params_string
|
||||||
|
|
||||||
# TODO classes below that involve create, update, destroy don't delete the caches properly, they need a regex cache delete
|
# TODO classes below that involve create, update, destroy don't delete the
|
||||||
|
# caches properly, they need a regex cache delete
|
||||||
# TODO put more reasonable asserts and feedback
|
# TODO put more reasonable asserts and feedback
|
||||||
|
|
||||||
|
|
||||||
# Mixin classes that provide cache functionalities
|
# Mixin classes that provide cache functionalities
|
||||||
class CacheUniqueUrl:
|
class CacheUniqueUrl:
|
||||||
def get_cache_unique_url(self, request):
|
def get_cache_unique_url(self, request):
|
||||||
""" Create the query to be cached in a unique way to avoid duplicates. """
|
"""
|
||||||
|
Create the query to be cached in a unique way to avoid duplicates.
|
||||||
|
"""
|
||||||
if not hasattr(self, 'filters_string'):
|
if not hasattr(self, 'filters_string'):
|
||||||
# Only assign the attribute if it's not already assigned
|
# Only assign the attribute if it's not already assigned
|
||||||
filters = {}
|
filters = {}
|
||||||
if self.extra_filters_dict:
|
if self.extra_filters_dict:
|
||||||
filters.update(self.extra_filters_dict)
|
filters.update(self.extra_filters_dict)
|
||||||
# Check if the url parameters have any of the keys of the extra filters and if so assign them
|
# Check if the url parameters have any of the keys of the extra
|
||||||
|
# filters and if so assign them
|
||||||
for key in self.extra_filters_dict:
|
for key in self.extra_filters_dict:
|
||||||
if key in self.request.query_params.keys():
|
if key in self.request.query_params.keys():
|
||||||
filters[key] = self.request.query_params[key].replace(' ','').split(',')
|
filters[key] = self.request.query_params[key].replace(
|
||||||
|
' ', '').split(',')
|
||||||
# Check if they're resolved in the urlconf as well
|
# Check if they're resolved in the urlconf as well
|
||||||
if key in self.kwargs.keys():
|
if key in self.kwargs.keys():
|
||||||
filters[key] = [self.kwargs[key]]
|
filters[key] = [self.kwargs[key]]
|
||||||
|
|
||||||
if hasattr(self, 'paged'):
|
if hasattr(self, 'paged'):
|
||||||
if self.paged:
|
if self.paged:
|
||||||
filters.update({'limit': [self.default_page_size], 'offset': [0]})
|
filters.update({'limit': [self.default_page_size],
|
||||||
|
'offset': [0]})
|
||||||
if 'limit' in self.request.query_params.keys():
|
if 'limit' in self.request.query_params.keys():
|
||||||
filters.update({'limit': [self.request.query_params['limit']]})
|
filters.update({
|
||||||
|
'limit': [self.request.query_params['limit']]})
|
||||||
if 'offset' in self.request.query_params.keys():
|
if 'offset' in self.request.query_params.keys():
|
||||||
filters.update({'offset': [self.request.query_params['offset']]})
|
filters.update({
|
||||||
|
'offset': [self.request.query_params['offset']]})
|
||||||
for backend in list(self.filter_backends):
|
for backend in list(self.filter_backends):
|
||||||
filters.update(backend().get_filters_dict(request, self))
|
filters.update(backend().get_filters_dict(request, self))
|
||||||
self.filters_string = sorted_params_string(filters)
|
self.filters_string = sorted_params_string(filters)
|
||||||
|
|
||||||
|
|
||||||
class CacheGetMixin(CacheUniqueUrl):
|
class CacheGetMixin(CacheUniqueUrl):
|
||||||
cache_prefix = None
|
cache_prefix = None
|
||||||
@@ -38,30 +47,34 @@ class CacheGetMixin(CacheUniqueUrl):
|
|||||||
cache_timeout_mins = None
|
cache_timeout_mins = None
|
||||||
default_page_size = 20
|
default_page_size = 20
|
||||||
extra_filters_dict = None
|
extra_filters_dict = None
|
||||||
|
|
||||||
def get_cache(self, request):
|
def get_cache(self, request):
|
||||||
assert self.cache_prefix is not None, (
|
assert self.cache_prefix is not None, (
|
||||||
"'%s' should include a `cache_prefix` attribute"
|
"'%s' should include a `cache_prefix` attribute"
|
||||||
% self.__class__.__name__
|
% self.__class__.__name__
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_cache_unique_url(request)
|
self.get_cache_unique_url(request)
|
||||||
|
|
||||||
# Attempt to get the response from the cache for the whole request
|
# Attempt to get the response from the cache for the whole request
|
||||||
try:
|
try:
|
||||||
if self.cache_vary_on_user:
|
if self.cache_vary_on_user:
|
||||||
cache_attempt = self.cache.get(f"{self.cache_prefix}.{request.user}.{self.filters_string}")
|
cache_attempt = self.cache.get(
|
||||||
|
f"{self.cache_prefix}.{request.user}.{self.filters_string}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
cache_attempt = self.cache.get(f"{self.cache_prefix}.{self.filters_string}")
|
cache_attempt = self.cache.get(
|
||||||
except:
|
f"{self.cache_prefix}.{self.filters_string}")
|
||||||
self.logger.info(f"Cache get attempt for {self.__class__.__name__} failed.")
|
except Exception:
|
||||||
|
self.logger.info(f"Cache get attempt for {self.__class__.__name__}"
|
||||||
|
" failed.")
|
||||||
cache_attempt = None
|
cache_attempt = None
|
||||||
|
|
||||||
if cache_attempt:
|
if cache_attempt:
|
||||||
return cache_attempt
|
return cache_attempt
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class CacheSetMixin(CacheUniqueUrl):
|
class CacheSetMixin(CacheUniqueUrl):
|
||||||
cache_prefix = None
|
cache_prefix = None
|
||||||
@@ -69,31 +82,34 @@ class CacheSetMixin(CacheUniqueUrl):
|
|||||||
cache_timeout_mins = None
|
cache_timeout_mins = None
|
||||||
default_page_size = 20
|
default_page_size = 20
|
||||||
extra_filters_dict = None
|
extra_filters_dict = None
|
||||||
|
|
||||||
def set_cache(self, request, response):
|
def set_cache(self, request, response):
|
||||||
self.get_cache_unique_url(request)
|
self.get_cache_unique_url(request)
|
||||||
|
|
||||||
# Create a function that programmatically defines the caching function
|
# Create a function that programmatically defines the caching function
|
||||||
def make_caching_function(cls, request, cache):
|
def make_caching_function(cls, request, cache):
|
||||||
def caching_function(response):
|
def caching_function(response):
|
||||||
# Writes the response to the cache
|
# Writes the response to the cache
|
||||||
try:
|
try:
|
||||||
if self.cache_vary_on_user:
|
if self.cache_vary_on_user:
|
||||||
self.cache.set(key=f"{self.cache_prefix}.{request.user}.{self.filters_string}",
|
self.cache.set(key=f"{self.cache_prefix}."
|
||||||
value = response.data,
|
f"{request.user}.{self.filters_string}",
|
||||||
timeout=60*self.cache_timeout_mins)
|
value=response.data,
|
||||||
|
timeout=60*self.cache_timeout_mins)
|
||||||
else:
|
else:
|
||||||
self.cache.set(key=f"{self.cache_prefix}.{self.filters_string}",
|
self.cache.set(key=f"{self.cache_prefix}"
|
||||||
value = response.data,
|
f".{self.filters_string}",
|
||||||
timeout=60*self.cache_timeout_mins)
|
value=response.data,
|
||||||
except:
|
timeout=60*self.cache_timeout_mins)
|
||||||
self.logger.exception(f"Cache set attempt for {self.__class__.__name__} failed.")
|
except Exception:
|
||||||
|
self.logger.exception("Cache set attempt for "
|
||||||
|
f"{self.__class__.__name__} failed.")
|
||||||
return caching_function
|
return caching_function
|
||||||
|
|
||||||
# Register the post rendering hook to the response
|
# Register the post rendering hook to the response
|
||||||
caching_function = make_caching_function(self, request, self.cache)
|
caching_function = make_caching_function(self, request, self.cache)
|
||||||
response.add_post_render_callback(caching_function)
|
response.add_post_render_callback(caching_function)
|
||||||
|
|
||||||
|
|
||||||
class CacheDeleteMixin(CacheUniqueUrl):
|
class CacheDeleteMixin(CacheUniqueUrl):
|
||||||
cache_delete = True
|
cache_delete = True
|
||||||
@@ -101,22 +117,26 @@ class CacheDeleteMixin(CacheUniqueUrl):
|
|||||||
cache_vary_on_user = False
|
cache_vary_on_user = False
|
||||||
cache_timeout_mins = None
|
cache_timeout_mins = None
|
||||||
extra_filters_dict = None
|
extra_filters_dict = None
|
||||||
|
|
||||||
def delete_cache(self, request):
|
def delete_cache(self, request):
|
||||||
# Handle the caching
|
# Handle the caching
|
||||||
if self.cache_delete:
|
if self.cache_delete:
|
||||||
# Create the query to be cached in a unique way to avoid duplicates
|
# Create the query to be cached in a unique way to avoid duplicates
|
||||||
self.get_cache_unique_url(request)
|
self.get_cache_unique_url(request)
|
||||||
|
|
||||||
assert self.cache_prefix is not None, (
|
assert self.cache_prefix is not None, (
|
||||||
f"{self.__class__.__name__} should include a `cache_prefix` attribute"
|
f"{self.__class__.__name__} should include a `cache_prefix`"
|
||||||
|
"attribute"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Delete the cache since a new entry has been created
|
# Delete the cache since a new entry has been created
|
||||||
try:
|
try:
|
||||||
if self.cache_vary_on_user:
|
if self.cache_vary_on_user:
|
||||||
self.cache.delete(f"{self.cache_prefix}.{request.user}.{self.filters_string}")
|
self.cache.delete(f"{self.cache_prefix}.{request.user}"
|
||||||
|
f".{self.filters_string}")
|
||||||
else:
|
else:
|
||||||
self.cache.delete(f"{self.cache_prefix}.{self.filters_string}")
|
self.cache.delete(f"{self.cache_prefix}"
|
||||||
except:
|
f".{self.filters_string}")
|
||||||
self.logger.exception(f"Cache delete attempt for {self.__class__.__name__} failed.")
|
except Exception:
|
||||||
|
self.logger.exception("Cache delete attempt for "
|
||||||
|
f"{self.__class__.__name__} failed.")
|
||||||
|
|||||||
@@ -1,34 +1,63 @@
|
|||||||
from django_filters import rest_framework as filters
|
|
||||||
from django.contrib.postgres.search import TrigramSimilarity
|
|
||||||
from django.db.models.functions import Concat
|
|
||||||
from django.db.models import CharField
|
|
||||||
from rest_framework.filters import BaseFilterBackend
|
|
||||||
import operator
|
import operator
|
||||||
|
from functools import reduce
|
||||||
|
from rest_framework.settings import api_settings
|
||||||
|
from rest_framework.filters import BaseFilterBackend
|
||||||
from django.template import loader
|
from django.template import loader
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from django.db.models import Max, Min, Count, Q
|
|
||||||
from rest_framework.settings import api_settings
|
|
||||||
from django.db.models.constants import LOOKUP_SEP
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from functools import reduce
|
from django.db.models import Q
|
||||||
|
from rest_framework.fields import CharField
|
||||||
|
from django.db.models.constants import LOOKUP_SEP
|
||||||
|
from django.db.models.functions import Concat
|
||||||
|
|
||||||
|
|
||||||
# TODO the dev pages are not done
|
# TODO the dev pages are not done
|
||||||
|
|
||||||
# Filters
|
def calculate_threshold(query, min_threshold, max_threshold):
|
||||||
|
query_threshold = len(query)/300
|
||||||
|
if query_threshold < min_threshold:
|
||||||
|
return min_threshold
|
||||||
|
if max_threshold < query_threshold:
|
||||||
|
return max_threshold
|
||||||
|
return query_threshold
|
||||||
|
|
||||||
|
|
||||||
|
def search_smart_split(search_terms):
|
||||||
|
"""generator that first splits string by spaces, leaving quoted phrases together,
|
||||||
|
then it splits non-quoted phrases by commas.
|
||||||
|
"""
|
||||||
|
split_terms = []
|
||||||
|
for term in smart_split(search_terms):
|
||||||
|
# trim commas to avoid bad matching for quoted phrases
|
||||||
|
term = term.strip(',')
|
||||||
|
if term.startswith(('"', "'")) and term[0] == term[-1]:
|
||||||
|
# quoted phrases are kept together without any other split
|
||||||
|
split_terms.append(unescape_string_literal(term))
|
||||||
|
else:
|
||||||
|
# non-quoted tokens are split by comma, keeping only non-empty ones
|
||||||
|
for sub_term in term.split(','):
|
||||||
|
if sub_term:
|
||||||
|
split_terms.append(sub_term.strip())
|
||||||
|
return split_terms
|
||||||
|
|
||||||
|
|
||||||
class LessThanOrEqualFilter(BaseFilterBackend):
|
class LessThanOrEqualFilter(BaseFilterBackend):
|
||||||
def get_less_than_field(self, view, request):
|
def get_less_than_field(self, view, request):
|
||||||
return getattr(view, 'less_than_field', None)
|
return getattr(view, 'less_than_field', None)
|
||||||
|
|
||||||
def get_filters_dict(self, request, view):
|
def get_filters_dict(self, request, view):
|
||||||
"""
|
"""
|
||||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
Custom method that returns the filters exclusive to this filter in a
|
||||||
|
dict. For caching purposes.
|
||||||
"""
|
"""
|
||||||
less_than_field = self.get_less_than_field(view, request)
|
less_than_field = self.get_less_than_field(view, request)
|
||||||
|
|
||||||
assert less_than_field is not None, (
|
assert less_than_field is not None, (
|
||||||
f"{view.__class__.__name__} should include a `less_than_field` attribute"
|
f"{view.__class__.__name__} should include a `less_than_field`"
|
||||||
|
"attribute"
|
||||||
)
|
)
|
||||||
|
|
||||||
filters_dict = {}
|
filters_dict = {}
|
||||||
if less_than_field+'_max' in request.query_params.keys():
|
if less_than_field+'_max' in request.query_params.keys():
|
||||||
field_value = request.query_params.get(less_than_field+'_max')
|
field_value = request.query_params.get(less_than_field+'_max')
|
||||||
@@ -36,9 +65,10 @@ class LessThanOrEqualFilter(BaseFilterBackend):
|
|||||||
else:
|
else:
|
||||||
filters_dict[less_than_field+'_max'] = []
|
filters_dict[less_than_field+'_max'] = []
|
||||||
return filters_dict
|
return filters_dict
|
||||||
|
|
||||||
def filter_queryset(self, request, queryset, view):
|
def filter_queryset(self, request, queryset, view):
|
||||||
# Return the correctly filtered queryset but also assign the filter dict to create the unique url for the cache
|
# Return the correctly filtered queryset but also assign the filter
|
||||||
|
# dict to create the unique url for the cache
|
||||||
less_than_field = self.get_less_than_field(view, request)
|
less_than_field = self.get_less_than_field(view, request)
|
||||||
if less_than_field+'_max' in request.query_params.keys():
|
if less_than_field+'_max' in request.query_params.keys():
|
||||||
kwquery = {}
|
kwquery = {}
|
||||||
@@ -52,17 +82,19 @@ class LessThanOrEqualFilter(BaseFilterBackend):
|
|||||||
class MoreThanOrEqualFilter(BaseFilterBackend):
|
class MoreThanOrEqualFilter(BaseFilterBackend):
|
||||||
def get_more_than_field(self, view, request):
|
def get_more_than_field(self, view, request):
|
||||||
return getattr(view, 'more_than_field', None)
|
return getattr(view, 'more_than_field', None)
|
||||||
|
|
||||||
def get_filters_dict(self, request, view):
|
def get_filters_dict(self, request, view):
|
||||||
"""
|
"""
|
||||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
Custom method that returns the filters exclusive to this filter in a
|
||||||
|
dict. For caching purposes.
|
||||||
"""
|
"""
|
||||||
more_than_field = self.get_more_than_field(view, request)
|
more_than_field = self.get_more_than_field(view, request)
|
||||||
|
|
||||||
assert more_than_field is not None, (
|
assert more_than_field is not None, (
|
||||||
f"{view.__class__.__name__} should include a `more_than_field` attribute"
|
f"{view.__class__.__name__} should include a `more_than_field`"
|
||||||
|
"attribute"
|
||||||
)
|
)
|
||||||
|
|
||||||
filters_dict = {}
|
filters_dict = {}
|
||||||
if more_than_field+'_min' in request.query_params.keys():
|
if more_than_field+'_min' in request.query_params.keys():
|
||||||
field_value = request.query_params.get(more_than_field+'_min')
|
field_value = request.query_params.get(more_than_field+'_min')
|
||||||
@@ -70,7 +102,7 @@ class MoreThanOrEqualFilter(BaseFilterBackend):
|
|||||||
else:
|
else:
|
||||||
filters_dict[more_than_field+'_min'] = []
|
filters_dict[more_than_field+'_min'] = []
|
||||||
return filters_dict
|
return filters_dict
|
||||||
|
|
||||||
def filter_queryset(self, request, queryset, view):
|
def filter_queryset(self, request, queryset, view):
|
||||||
"""
|
"""
|
||||||
Return the correctly filtered queryset
|
Return the correctly filtered queryset
|
||||||
@@ -78,7 +110,8 @@ class MoreThanOrEqualFilter(BaseFilterBackend):
|
|||||||
more_than_field = self.get_more_than_field(view, request)
|
more_than_field = self.get_more_than_field(view, request)
|
||||||
if more_than_field+'_min' in request.query_params.keys():
|
if more_than_field+'_min' in request.query_params.keys():
|
||||||
kwquery = {}
|
kwquery = {}
|
||||||
kwquery[more_than_field+'__gte'] = request.query_params.get(more_than_field+'_min')
|
kwquery[more_than_field+'__gte'] = request.query_params.get(
|
||||||
|
more_than_field+'_min')
|
||||||
return queryset.filter(**kwquery)
|
return queryset.filter(**kwquery)
|
||||||
else:
|
else:
|
||||||
return queryset
|
return queryset
|
||||||
@@ -86,82 +119,90 @@ class MoreThanOrEqualFilter(BaseFilterBackend):
|
|||||||
|
|
||||||
class CategoryFilter(BaseFilterBackend):
|
class CategoryFilter(BaseFilterBackend):
|
||||||
"""
|
"""
|
||||||
This filter assigns the view.category object for later use, in particular for filters that depend on this one.
|
This filter assigns the view.category object for later use, in particular
|
||||||
|
for filters that depend on this one.
|
||||||
"""
|
"""
|
||||||
template = 'starfields_drf_generics/templates/filters/categories.html'
|
template = 'filters/categories.html'
|
||||||
category_field = 'category'
|
category_field = 'category'
|
||||||
|
|
||||||
def get_category_class(self, view, request):
|
def get_category_class(self, view, request):
|
||||||
return getattr(view, 'category_class', None)
|
return getattr(view, 'category_class', None)
|
||||||
|
|
||||||
def assign_view_category(self, request, view):
|
def assign_view_category(self, request, view):
|
||||||
if not hasattr(view, self.category_field):
|
if not hasattr(view, self.category_field):
|
||||||
if self.category_field in request.query_params.keys():
|
if self.category_field in request.query_params.keys():
|
||||||
try:
|
try:
|
||||||
category_slug = request.query_params.get(self.category_field).strip()
|
category_slug = request.query_params.get(
|
||||||
category = view.category_class.objects.get(slug=category_slug)
|
self.category_field).strip()
|
||||||
# Append the category object in the view for later reference
|
category = view.category_class.objects.get(
|
||||||
|
slug=category_slug)
|
||||||
|
# Append the category object in the view for later use
|
||||||
view.category = category
|
view.category = category
|
||||||
except:
|
except Exception:
|
||||||
view.category = None
|
view.category = None
|
||||||
else:
|
else:
|
||||||
view.category = None
|
view.category = None
|
||||||
|
|
||||||
|
|
||||||
def get_filters_dict(self, request, view):
|
def get_filters_dict(self, request, view):
|
||||||
"""
|
"""
|
||||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes. Queries the database for the current category and saves it in view.category for internal use and later filters.
|
Custom method that returns the filters exclusive to this filter in a
|
||||||
|
dict. For caching purposes. Queries the database for the current
|
||||||
|
category and saves it in view.category for internal use and later
|
||||||
|
filters.
|
||||||
"""
|
"""
|
||||||
if hasattr(view, 'category_field'):
|
if hasattr(view, 'category_field'):
|
||||||
self.category_field = self.get_category_field(view, request)
|
self.category_field = self.get_category_field(view, request)
|
||||||
|
|
||||||
category_class = self.get_category_class(view, request)
|
category_class = self.get_category_class(view, request)
|
||||||
|
|
||||||
assert category_class is not None, (
|
assert category_class is not None, (
|
||||||
f"{view.__class__.__name__} should include a `category_class` attribute"
|
f"{view.__class__.__name__} should include a `category_class`"
|
||||||
|
"attribute"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the filters dictionary and find the present category instance
|
# Create the filters dictionary and find the present category instance
|
||||||
self.assign_view_category(request, view)
|
self.assign_view_category(request, view)
|
||||||
|
|
||||||
filters_dict = {}
|
filters_dict = {}
|
||||||
if view.category:
|
if view.category:
|
||||||
filters_dict[self.category_field] = [view.category.slug]
|
filters_dict[self.category_field] = [view.category.slug]
|
||||||
else:
|
else:
|
||||||
filters_dict[self.category_field] = []
|
filters_dict[self.category_field] = []
|
||||||
|
|
||||||
return filters_dict
|
return filters_dict
|
||||||
|
|
||||||
def filter_queryset(self, request, queryset, view):
|
def filter_queryset(self, request, queryset, view):
|
||||||
self.assign_view_category(request, view)
|
self.assign_view_category(request, view)
|
||||||
|
|
||||||
# Create the queryset
|
# Create the queryset
|
||||||
if view.category:
|
if view.category:
|
||||||
kwquery_1 = {}
|
kwquery_1 = {}
|
||||||
kwquery_1[self.category_field+'__id'] = view.category.id
|
kwquery_1[self.category_field+'__id'] = view.category.id
|
||||||
if view.category.tn_descendants_pks:
|
if view.category.tn_descendants_pks:
|
||||||
kwquery_2 = {}
|
kwquery_2 = {}
|
||||||
kwquery_2[self.category_field+'__id__in'] = view.category.tn_descendants_pks.split(',')
|
key = self.category_field+'__id__in'
|
||||||
|
kwquery_2[key] = view.category.tn_descendants_pks.split(',')
|
||||||
|
|
||||||
queryset = queryset.filter(Q(**kwquery_1) | Q(**kwquery_2))
|
queryset = queryset.filter(Q(**kwquery_1) | Q(**kwquery_2))
|
||||||
else:
|
else:
|
||||||
queryset = queryset.filter(**kwquery_1)
|
queryset = queryset.filter(**kwquery_1)
|
||||||
|
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
# Developer Interface methods
|
# Developer Interface methods
|
||||||
def get_valid_fields(self, queryset, view, context, request):
|
def get_valid_fields(self, queryset, view, context, request):
|
||||||
# A query is executed here to get the possible categories
|
# A query is executed here to get the possible categories
|
||||||
category_class = self.get_category_class(view, request)
|
category_class = self.get_category_class(view, request)
|
||||||
if hasattr(view, 'category_field'):
|
if hasattr(view, 'category_field'):
|
||||||
self.category_field = self.get_category_field(view, request)
|
self.category_field = self.get_category_field(view, request)
|
||||||
|
|
||||||
assert category_class is not None, (
|
assert category_class is not None, (
|
||||||
f"{view.__class__.__name__} should include a `category_class` attribute"
|
f"{view.__class__.__name__} should include a `category_class`"
|
||||||
|
"attribute"
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_fields = category_class.objects.all()
|
valid_fields = category_class.objects.all()
|
||||||
|
|
||||||
if len(valid_fields):
|
if len(valid_fields):
|
||||||
valid_fields = [
|
valid_fields = [
|
||||||
(item.slug, item.__str__()) for item in valid_fields
|
(item.slug, item.__str__()) for item in valid_fields
|
||||||
@@ -170,7 +211,7 @@ class CategoryFilter(BaseFilterBackend):
|
|||||||
return valid_fields
|
return valid_fields
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_current(self, request, queryset, view):
|
def get_current(self, request, queryset, view):
|
||||||
params = request.query_params.get(self.category_field)
|
params = request.query_params.get(self.category_field)
|
||||||
if params:
|
if params:
|
||||||
@@ -178,7 +219,7 @@ class CategoryFilter(BaseFilterBackend):
|
|||||||
return fields[0]
|
return fields[0]
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_template_context(self, request, queryset, view):
|
def get_template_context(self, request, queryset, view):
|
||||||
current = self.get_current(request, queryset, view)
|
current = self.get_current(request, queryset, view)
|
||||||
options = []
|
options = []
|
||||||
@@ -187,11 +228,12 @@ class CategoryFilter(BaseFilterBackend):
|
|||||||
'current': current,
|
'current': current,
|
||||||
'param': self.category_field,
|
'param': self.category_field,
|
||||||
}
|
}
|
||||||
for key, label in self.get_valid_fields(queryset, view, context, request):
|
valid_fields = self.get_valid_fields(queryset, view, context, request)
|
||||||
|
for key, label in valid_fields:
|
||||||
options.append((key, '%s' % (label)))
|
options.append((key, '%s' % (label)))
|
||||||
context['options'] = options
|
context['options'] = options
|
||||||
return context
|
return context
|
||||||
|
|
||||||
def to_html(self, request, queryset, view):
|
def to_html(self, request, queryset, view):
|
||||||
template = loader.get_template(self.template)
|
template = loader.get_template(self.template)
|
||||||
context = self.get_template_context(request, queryset, view)
|
context = self.get_template_context(request, queryset, view)
|
||||||
@@ -200,112 +242,135 @@ class CategoryFilter(BaseFilterBackend):
|
|||||||
|
|
||||||
class FacetFilter(BaseFilterBackend):
|
class FacetFilter(BaseFilterBackend):
|
||||||
"""
|
"""
|
||||||
This filter requires CategoryFilter to be ran before it. It assigns the view.facets which includes all the facets applicable to the current category.
|
This filter requires CategoryFilter to be ran before it. It assigns the
|
||||||
|
view.facets which includes all the facets applicable to the current
|
||||||
|
category.
|
||||||
"""
|
"""
|
||||||
template = 'starfields_drf_generics/templates/filters/facets.html'
|
template = 'filters/facets.html'
|
||||||
|
|
||||||
def get_facet_class(self, view, request):
|
def get_facet_class(self, view, request):
|
||||||
return getattr(view, 'facet_class', None)
|
return getattr(view, 'facet_class', None)
|
||||||
|
|
||||||
def get_facet_tag_class(self, view, request):
|
def get_facet_tag_class(self, view, request):
|
||||||
return getattr(view, 'facet_tag_class', None)
|
return getattr(view, 'facet_tag_class', None)
|
||||||
|
|
||||||
def get_facet_tag_field(self, view, request):
|
def get_facet_tag_field(self, view, request):
|
||||||
return getattr(view, 'facet_tag_field', None)
|
return getattr(view, 'facet_tag_field', None)
|
||||||
|
|
||||||
def assign_view_facets(self, request, view):
|
def assign_view_facets(self, request, view):
|
||||||
if not hasattr(view, 'facets'):
|
if not hasattr(view, 'facets'):
|
||||||
if hasattr(view, 'facet_class'):
|
if hasattr(view, 'facet_class'):
|
||||||
self.facet_class = self.get_facet_class(view, request)
|
self.facet_class = self.get_facet_class(view, request)
|
||||||
|
|
||||||
assert self.facet_class is not None, (
|
assert self.facet_class is not None, (
|
||||||
f"{view.__class__.__name__} should include a `facet_class` attribute"
|
f"{view.__class__.__name__} should include a `facet_class`"
|
||||||
|
"attribute"
|
||||||
)
|
)
|
||||||
|
|
||||||
if view.category:
|
if view.category:
|
||||||
if view.category.tn_ancestors_pks:
|
if view.category.tn_ancestors_pks:
|
||||||
view.facets = self.facet_class.objects.filter(Q(category__id=view.category.id) | Q(category__id__in=view.category.tn_ancestors_pks.split(','))).prefetch_related('facet_tags')
|
ancestor_ids = view.category.tn_ancestors_pks.split(',')
|
||||||
|
view.facets = self.facet_class.objects.filter(
|
||||||
|
Q(category__id=view.category.id) | Q(
|
||||||
|
category__id__in=ancestor_ids)
|
||||||
|
).prefetch_related('facet_tags')
|
||||||
else:
|
else:
|
||||||
view.facets = self.facet_class.objects.filter(category__id=view.category.id).prefetch_related('facet_tags')
|
view.facets = self.facet_class.objects.filter(
|
||||||
|
category__id=view.category.id).prefetch_related(
|
||||||
|
'facet_tags')
|
||||||
else:
|
else:
|
||||||
view.facets = self.facet_class.objects.filter(category__tn_level=1).prefetch_related('facet_tags')
|
view.facets = self.facet_class.objects.filter(
|
||||||
|
category__tn_level=1).prefetch_related(
|
||||||
|
'facet_tags')
|
||||||
|
|
||||||
def get_filters_dict(self, request, view):
|
def get_filters_dict(self, request, view):
|
||||||
"""
|
"""
|
||||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
Custom method that returns the filters exclusive to this filter in a
|
||||||
|
dict. For caching purposes.
|
||||||
"""
|
"""
|
||||||
if hasattr(view, 'facet_class'):
|
if hasattr(view, 'facet_class'):
|
||||||
self.facet_class = self.get_facet_class(view, request)
|
self.facet_class = self.get_facet_class(view, request)
|
||||||
|
|
||||||
assert self.facet_class is not None, (
|
assert self.facet_class is not None, (
|
||||||
f"{view.__class__.__name__} should include a `facet_class` attribute"
|
f"{view.__class__.__name__} should include a `facet_class`"
|
||||||
|
"attribute"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assign_view_facets(request, view)
|
self.assign_view_facets(request, view)
|
||||||
|
|
||||||
filters_dict = {}
|
filters_dict = {}
|
||||||
if view.facets:
|
if view.facets:
|
||||||
for facet in view.facets:
|
for facet in view.facets:
|
||||||
if facet.slug in request.query_params.keys():
|
if facet.slug in request.query_params.keys():
|
||||||
filters_dict[facet.slug] = set(request.query_params[facet.slug].split(','))
|
filters_dict[facet.slug] = set(
|
||||||
|
request.query_params[facet.slug].split(','))
|
||||||
else:
|
else:
|
||||||
filters_dict[facet.slug] = set({})
|
filters_dict[facet.slug] = set({})
|
||||||
|
|
||||||
# Append the facets object and the tags dict in the view for later reference
|
# Append the facets object and the tags dict in the view for later
|
||||||
|
# reference
|
||||||
view.tags = filters_dict
|
view.tags = filters_dict
|
||||||
|
|
||||||
return filters_dict
|
return filters_dict
|
||||||
|
|
||||||
|
|
||||||
def filter_queryset(self, request, queryset, view):
|
def filter_queryset(self, request, queryset, view):
|
||||||
if hasattr(view, 'facet_tag_class'):
|
if hasattr(view, 'facet_tag_class'):
|
||||||
self.facet_tag_class = self.get_facet_tag_class(view, request)
|
self.facet_tag_class = self.get_facet_tag_class(view, request)
|
||||||
|
|
||||||
assert self.facet_tag_class is not None, (
|
assert self.facet_tag_class is not None, (
|
||||||
f"{view.__class__.__name__} should include a `facet_tag_class` attribute"
|
f"{view.__class__.__name__} should include a `facet_tag_class`"
|
||||||
|
"attribute"
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(view, 'facet_tag_field'):
|
if hasattr(view, 'facet_tag_field'):
|
||||||
self.facet_tag_field = self.get_facet_tag_field(view, request)
|
self.facet_tag_field = self.get_facet_tag_field(view, request)
|
||||||
|
|
||||||
assert self.facet_tag_field is not None, (
|
assert self.facet_tag_field is not None, (
|
||||||
f"{view.__class__.__name__} should include a `facet_tag_field` attribute"
|
f"{view.__class__.__name__} should include a `facet_tag_field`"
|
||||||
|
"attribute"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assign_view_facets(request, view)
|
self.assign_view_facets(request, view)
|
||||||
|
|
||||||
if view.facets:
|
if view.facets:
|
||||||
for facet in view.facets:
|
for facet in view.facets:
|
||||||
if facet.slug in request.query_params.keys() and request.query_params[facet.slug]:
|
request_slug = request.query_params[facet.slug]
|
||||||
|
if facet.slug in request.query_params.keys() and request_slug:
|
||||||
tag_filterlist = request.query_params.get(facet.slug)
|
tag_filterlist = request.query_params.get(facet.slug)
|
||||||
if tag_filterlist == '':
|
if tag_filterlist == '':
|
||||||
# If the tag filterlist is empty then we're not filtering against it, it's like having all the tags of the facet selected
|
# If the tag filterlist is empty then we're not
|
||||||
|
# filtering against it, it's like having all the tags
|
||||||
|
# of the facet selected
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
kwquery = {}
|
kwquery = {}
|
||||||
kwquery[self.facet_tag_field+'__slug__in'] = tag_filterlist.replace(' ','').split(',')
|
key = self.facet_tag_field+'__slug__in'
|
||||||
|
kwquery[key] = tag_filterlist.replace(' ', '').split(
|
||||||
|
',')
|
||||||
queryset = queryset.filter(**kwquery)
|
queryset = queryset.filter(**kwquery)
|
||||||
|
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
|
|
||||||
# Developer Interface methods
|
# Developer Interface methods
|
||||||
def get_template_context(self, request, queryset, view):
|
def get_template_context(self, request, queryset, view):
|
||||||
# Does aggressive database querying to get the necessary facets and facettags, but this is only for the developer interface so its fine
|
# Does aggressive database querying to get the necessary facets and
|
||||||
|
# facettags, but this is only for the developer interface so its fine
|
||||||
if hasattr(view, 'facet_class'):
|
if hasattr(view, 'facet_class'):
|
||||||
self.facet_class = self.get_facet_class(view, request)
|
self.facet_class = self.get_facet_class(view, request)
|
||||||
|
|
||||||
assert self.facet_class is not None, (
|
assert self.facet_class is not None, (
|
||||||
f"{view.__class__.__name__} should include a `facet_class` attribute"
|
f"{view.__class__.__name__} should include a `facet_class`"
|
||||||
|
"attribute"
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(view, 'facet_tag_class'):
|
if hasattr(view, 'facet_tag_class'):
|
||||||
self.facet_tag_class = self.get_facet_tag_class(view, request)
|
self.facet_tag_class = self.get_facet_tag_class(view, request)
|
||||||
|
|
||||||
assert self.facet_tag_class is not None, (
|
assert self.facet_tag_class is not None, (
|
||||||
f"{view.__class__.__name__} should include a `facet_tag_class` attribute"
|
f"{view.__class__.__name__} should include a `facet_tag_class`"
|
||||||
|
"attribute"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Find the current choices
|
# Find the current choices
|
||||||
current = []
|
current = []
|
||||||
facet_slugs = []
|
facet_slugs = []
|
||||||
@@ -314,7 +379,7 @@ class FacetFilter(BaseFilterBackend):
|
|||||||
facet_slugs.append(facet.slug)
|
facet_slugs.append(facet.slug)
|
||||||
if facet.slug in request.query_params.keys():
|
if facet.slug in request.query_params.keys():
|
||||||
current.append(request.query_params.get(facet.slug))
|
current.append(request.query_params.get(facet.slug))
|
||||||
|
|
||||||
facet_slug_names = {}
|
facet_slug_names = {}
|
||||||
options = {}
|
options = {}
|
||||||
context = {
|
context = {
|
||||||
@@ -324,17 +389,19 @@ class FacetFilter(BaseFilterBackend):
|
|||||||
}
|
}
|
||||||
if view.facets:
|
if view.facets:
|
||||||
for facet in view.facets:
|
for facet in view.facets:
|
||||||
facet_tag_instances = self.facet_tag_class.objects.filter(facet__slug=facet.slug)
|
facet_tag_instances = self.facet_tag_class.objects.filter(
|
||||||
options[facet.slug] = [(facet_tag.slug, facet_tag.name) for facet_tag in facet_tag_instances]
|
facet__slug=facet.slug)
|
||||||
|
options[facet.slug] = [(facet_tag.slug, facet_tag.name) for
|
||||||
|
facet_tag in facet_tag_instances]
|
||||||
facet_slug_names[facet.slug] = facet.name
|
facet_slug_names[facet.slug] = facet.name
|
||||||
context['facet_slug_names'] = facet_slug_names
|
context['facet_slug_names'] = facet_slug_names
|
||||||
context['options'] = options
|
context['options'] = options
|
||||||
else:
|
else:
|
||||||
context['facet_slug_names'] = {}
|
context['facet_slug_names'] = {}
|
||||||
context['options'] = {}
|
context['options'] = {}
|
||||||
|
|
||||||
return context
|
return context
|
||||||
|
|
||||||
def to_html(self, request, queryset, view):
|
def to_html(self, request, queryset, view):
|
||||||
template = loader.get_template(self.template)
|
template = loader.get_template(self.template)
|
||||||
context = self.get_template_context(request, queryset, view)
|
context = self.get_template_context(request, queryset, view)
|
||||||
@@ -346,10 +413,13 @@ class TrigramSearchFilter(BaseFilterBackend):
|
|||||||
search_param = 'search'
|
search_param = 'search'
|
||||||
template = 'rest_framework/filters/search.html'
|
template = 'rest_framework/filters/search.html'
|
||||||
search_title = _('Search')
|
search_title = _('Search')
|
||||||
search_description = _('A search string to perform trigram similarity based searching with.')
|
search_description = _('A search string to perform trigram similarity'
|
||||||
|
'based searching with.')
|
||||||
|
|
||||||
def get_filters_dict(self, request, view):
|
def get_filters_dict(self, request, view):
|
||||||
"""
|
"""
|
||||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
Custom method that returns the filters exclusive to this filter in a
|
||||||
|
dict. For caching purposes.
|
||||||
"""
|
"""
|
||||||
self.filters_dict = {}
|
self.filters_dict = {}
|
||||||
if 'search' in request.query_params.keys():
|
if 'search' in request.query_params.keys():
|
||||||
@@ -357,27 +427,141 @@ class TrigramSearchFilter(BaseFilterBackend):
|
|||||||
self.filters_dict['search'] = [slug_term]
|
self.filters_dict['search'] = [slug_term]
|
||||||
else:
|
else:
|
||||||
self.filters_dict['search'] = []
|
self.filters_dict['search'] = []
|
||||||
|
|
||||||
return self.filters_dict
|
return self.filters_dict
|
||||||
|
|
||||||
|
def get_search_fields(self, view, request):
|
||||||
|
"""
|
||||||
|
Search fields are obtained from the view, but the request is always
|
||||||
|
passed to this method. Sub-classes can override this method to
|
||||||
|
dynamically change the search fields based on request content.
|
||||||
|
"""
|
||||||
|
return getattr(view, 'search_fields', None)
|
||||||
|
|
||||||
|
def get_search_query(self, request):
|
||||||
|
"""
|
||||||
|
Search terms are set by a ?search=... query parameter,
|
||||||
|
and may be whitespace delimited.
|
||||||
|
"""
|
||||||
|
value = request.query_params.get(self.search_param, '')
|
||||||
|
field = CharField(trim_whitespace=False, allow_blank=True)
|
||||||
|
cleaned_value = field.run_validation(value)
|
||||||
|
return cleaned_value
|
||||||
|
|
||||||
|
def construct_search(self, field_name, queryset):
|
||||||
|
"""
|
||||||
|
For the sqlite search
|
||||||
|
"""
|
||||||
|
lookup = self.lookup_prefixes.get(field_name[0])
|
||||||
|
if lookup:
|
||||||
|
field_name = field_name[1:]
|
||||||
|
else:
|
||||||
|
# Use field_name if it includes a lookup.
|
||||||
|
opts = queryset.model._meta
|
||||||
|
lookup_fields = field_name.split(LOOKUP_SEP)
|
||||||
|
# Go through the fields, following all relations.
|
||||||
|
prev_field = None
|
||||||
|
for path_part in lookup_fields:
|
||||||
|
if path_part == "pk":
|
||||||
|
path_part = opts.pk.name
|
||||||
|
try:
|
||||||
|
field = opts.get_field(path_part)
|
||||||
|
except FieldDoesNotExist:
|
||||||
|
# Use valid query lookups.
|
||||||
|
if prev_field and prev_field.get_lookup(path_part):
|
||||||
|
return field_name
|
||||||
|
else:
|
||||||
|
prev_field = field
|
||||||
|
if hasattr(field, "path_infos"):
|
||||||
|
# Update opts to follow the relation.
|
||||||
|
opts = field.path_infos[-1].to_opts
|
||||||
|
# django < 4.1
|
||||||
|
elif hasattr(field, 'get_path_info'):
|
||||||
|
# Update opts to follow the relation.
|
||||||
|
opts = field.get_path_info()[-1].to_opts
|
||||||
|
# Otherwise, use the field with icontains.
|
||||||
|
lookup = 'icontains'
|
||||||
|
return LOOKUP_SEP.join([field_name, lookup])
|
||||||
|
|
||||||
|
def must_call_distinct(self, queryset, search_fields):
|
||||||
|
"""
|
||||||
|
Return True if 'distinct()' should be used to query the given lookups.
|
||||||
|
"""
|
||||||
|
for search_field in search_fields:
|
||||||
|
opts = queryset.model._meta
|
||||||
|
if search_field[0] in self.lookup_prefixes:
|
||||||
|
search_field = search_field[1:]
|
||||||
|
# Annotated fields do not need to be distinct
|
||||||
|
if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations:
|
||||||
|
continue
|
||||||
|
parts = search_field.split(LOOKUP_SEP)
|
||||||
|
for part in parts:
|
||||||
|
field = opts.get_field(part)
|
||||||
|
if hasattr(field, 'get_path_info'):
|
||||||
|
# This field is a relation, update opts to follow the relation
|
||||||
|
path_info = field.get_path_info()
|
||||||
|
opts = path_info[-1].to_opts
|
||||||
|
if any(path.m2m for path in path_info):
|
||||||
|
# This field is a m2m relation so we know we need to call distinct
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
# This field has a custom __ query transform but is not a relational field.
|
||||||
|
break
|
||||||
|
return False
|
||||||
|
|
||||||
def filter_queryset(self, request, queryset, view):
|
def filter_queryset(self, request, queryset, view):
|
||||||
search_fields = getattr(view, 'search_fields', None)
|
search_fields = self.get_search_fields(view, request)
|
||||||
|
|
||||||
assert search_fields is not None, (
|
assert search_fields is not None, (
|
||||||
f"{view.__class__.__name__} should include a `search_fields` attribute"
|
f"{view.__class__.__name__} should include a `search_fields`"
|
||||||
|
"attribute"
|
||||||
)
|
)
|
||||||
|
|
||||||
query = request.query_params.get(self.search_param, '')
|
query = self.get_search_query(request)
|
||||||
|
|
||||||
if query:
|
if not query:
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Attempt postgresql's full text search
|
||||||
|
from django.contrib.postgres.search import TrigramStrictWordSimilarity
|
||||||
|
threshold = calculate_threshold(query, 0.02, 0.12)
|
||||||
queryset = queryset.annotate(
|
queryset = queryset.annotate(
|
||||||
search_field=Concat(
|
search_field=Concat(
|
||||||
*search_fields,
|
*search_fields,
|
||||||
output_field=CharField()
|
output_field=CharField()
|
||||||
)).annotate(
|
)).annotate(
|
||||||
similarity=TrigramSimilarity('search_field', query)
|
similarity=TrigramStrictWordSimilarity(
|
||||||
).filter(similarity__gt=0.05).distinct()
|
'search_field', query)
|
||||||
|
).filter(similarity__gt=threshold)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
# Perform very simple sqlite compatible search
|
||||||
|
search_terms = search_smart_split(query)
|
||||||
|
|
||||||
|
orm_lookups = [
|
||||||
|
self.construct_search(str(search_field), queryset)
|
||||||
|
for search_field in search_fields
|
||||||
|
]
|
||||||
|
|
||||||
|
base = queryset
|
||||||
|
# generator which for each term builds the corresponding search
|
||||||
|
conditions = (
|
||||||
|
reduce(
|
||||||
|
operator.or_,
|
||||||
|
(models.Q(**{orm_lookup: term}) for orm_lookup in orm_lookups)
|
||||||
|
) for term in search_terms
|
||||||
|
)
|
||||||
|
queryset = queryset.filter(reduce(operator.and_, conditions))
|
||||||
|
|
||||||
|
# Remove duplicates from results, if necessary
|
||||||
|
if self.must_call_distinct(queryset, search_fields):
|
||||||
|
# inspired by django.contrib.admin
|
||||||
|
# this is more accurate than .distinct form M2M relationship
|
||||||
|
# also is cross-database
|
||||||
|
queryset = queryset.filter(pk=models.OuterRef('pk'))
|
||||||
|
queryset = base.filter(models.Exists(queryset))
|
||||||
|
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
def to_html(self, request, queryset, view):
|
def to_html(self, request, queryset, view):
|
||||||
@@ -429,21 +613,21 @@ class TrigramSearchFilter(BaseFilterBackend):
|
|||||||
# TODO misunderstood the urlconf stuff of the RUD methods, this is probably unnecessary
|
# TODO misunderstood the urlconf stuff of the RUD methods, this is probably unnecessary
|
||||||
class SlugSearchFilter(BaseFilterBackend):
|
class SlugSearchFilter(BaseFilterBackend):
|
||||||
# The URL query parameter used for the search.
|
# The URL query parameter used for the search.
|
||||||
template = 'starfields_drf_generics/templates/filters/slug.html'
|
template = 'filters/slug.html'
|
||||||
slug_title = _('Slug Search')
|
slug_title = _('Slug Search')
|
||||||
slug_description = _("The instance's slug.")
|
slug_description = _("The instance's slug.")
|
||||||
slug_field = 'slug'
|
slug_field = 'slug'
|
||||||
|
|
||||||
def get_slug_field(self, view, request):
|
def get_slug_field(self, view, request):
|
||||||
return getattr(view, 'slug_field', None)
|
return getattr(view, 'slug_field', None)
|
||||||
|
|
||||||
def get_filters_dict(self, request, view):
|
def get_filters_dict(self, request, view):
|
||||||
"""
|
"""
|
||||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
||||||
"""
|
"""
|
||||||
if hasattr(view, 'slug_field'):
|
if hasattr(view, 'slug_field'):
|
||||||
self.slug_field = self.get_slug_field(view, request)
|
self.slug_field = self.get_slug_field(view, request)
|
||||||
|
|
||||||
assert self.slug_field is not None, (
|
assert self.slug_field is not None, (
|
||||||
f"{view.__class__.__name__} should include a `slug_field` attribute"
|
f"{view.__class__.__name__} should include a `slug_field` attribute"
|
||||||
)
|
)
|
||||||
@@ -453,9 +637,9 @@ class SlugSearchFilter(BaseFilterBackend):
|
|||||||
self.filters_dict[self.slug_field] = [slug_term]
|
self.filters_dict[self.slug_field] = [slug_term]
|
||||||
else:
|
else:
|
||||||
self.filters_dict[self.slug_field] = []
|
self.filters_dict[self.slug_field] = []
|
||||||
|
|
||||||
return self.filters_dict
|
return self.filters_dict
|
||||||
|
|
||||||
def filter_queryset(self, request, queryset, view):
|
def filter_queryset(self, request, queryset, view):
|
||||||
# Ensure that the slug field was searched against
|
# Ensure that the slug field was searched against
|
||||||
try:
|
try:
|
||||||
@@ -463,18 +647,18 @@ class SlugSearchFilter(BaseFilterBackend):
|
|||||||
slug_term = request.query_params.get(self.slug_field)
|
slug_term = request.query_params.get(self.slug_field)
|
||||||
query = {}
|
query = {}
|
||||||
query[self.slug_field] = slug_term
|
query[self.slug_field] = slug_term
|
||||||
|
|
||||||
queryset = queryset.get(**query)
|
queryset = queryset.get(**query)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
|
|
||||||
def to_html(self, request, queryset, view):
|
def to_html(self, request, queryset, view):
|
||||||
if not getattr(view, 'slug_field', None):
|
if not getattr(view, 'slug_field', None):
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
slug_term = self.get_slug_term(request)
|
slug_term = self.get_slug_term(request)
|
||||||
context = {
|
context = {
|
||||||
'param': self.slug_field,
|
'param': self.slug_field,
|
||||||
@@ -540,7 +724,7 @@ class SearchFilter(BaseFilterBackend):
|
|||||||
# This field has a custom __ query transform but is not a relational field.
|
# This field has a custom __ query transform but is not a relational field.
|
||||||
break
|
break
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_filters_dict(self, request, view):
|
def get_filters_dict(self, request, view):
|
||||||
"""
|
"""
|
||||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
||||||
@@ -551,7 +735,7 @@ class SearchFilter(BaseFilterBackend):
|
|||||||
self.filters_dict['search'] = [slug_term]
|
self.filters_dict['search'] = [slug_term]
|
||||||
else:
|
else:
|
||||||
self.filters_dict['search'] = []
|
self.filters_dict['search'] = []
|
||||||
|
|
||||||
return self.filters_dict
|
return self.filters_dict
|
||||||
|
|
||||||
def filter_queryset(self, request, queryset, view):
|
def filter_queryset(self, request, queryset, view):
|
||||||
@@ -728,7 +912,7 @@ class OrderingFilter(BaseFilterBackend):
|
|||||||
return term in valid_fields
|
return term in valid_fields
|
||||||
|
|
||||||
return [term for term in fields if term_valid(term)]
|
return [term for term in fields if term_valid(term)]
|
||||||
|
|
||||||
def get_filters_dict(self, request, view):
|
def get_filters_dict(self, request, view):
|
||||||
"""
|
"""
|
||||||
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
Custom method that returns the filters exclusive to this filter in a dict. For caching purposes.
|
||||||
@@ -739,9 +923,9 @@ class OrderingFilter(BaseFilterBackend):
|
|||||||
self.filters_dict['ordering'] = [slug_term]
|
self.filters_dict['ordering'] = [slug_term]
|
||||||
else:
|
else:
|
||||||
self.filters_dict['ordering'] = [view.ordering_fields[0]]
|
self.filters_dict['ordering'] = [view.ordering_fields[0]]
|
||||||
|
|
||||||
return self.filters_dict
|
return self.filters_dict
|
||||||
|
|
||||||
def filter_queryset(self, request, queryset, view):
|
def filter_queryset(self, request, queryset, view):
|
||||||
ordering = self.get_ordering(request, queryset, view)
|
ordering = self.get_ordering(request, queryset, view)
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Generic views that provide commonly needed behaviour.
|
Generic views that provide commonly needed behaviour.
|
||||||
"""
|
"""
|
||||||
from django.core.exceptions import ValidationError
|
|
||||||
from django.db.models.query import QuerySet
|
|
||||||
from django.http import Http404
|
|
||||||
from django.shortcuts import get_object_or_404 as _get_object_or_404
|
|
||||||
|
|
||||||
from rest_framework import views
|
|
||||||
from rest_framework.generics import GenericAPIView
|
from rest_framework.generics import GenericAPIView
|
||||||
from rest_framework.settings import api_settings
|
|
||||||
|
|
||||||
from starfields_drf_generics import mixins
|
from starfields_drf_generics import mixins
|
||||||
|
|
||||||
|
|
||||||
@@ -18,26 +10,32 @@ from starfields_drf_generics import mixins
|
|||||||
|
|
||||||
# Single item CRUD
|
# Single item CRUD
|
||||||
|
|
||||||
class CachedCreateAPIView(mixins.CachedCreateModelMixin,GenericAPIView):
|
class CachedCreateAPIView(mixins.CachedCreateModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for creating a model instance.
|
Concrete view for creating a model instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def post(self, request, *args, **kwargs):
|
def post(self, request, *args, **kwargs):
|
||||||
return self.create(request, *args, **kwargs)
|
return self.create(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CachedRetrieveAPIView(mixins.CachedRetrieveModelMixin,GenericAPIView):
|
class CachedRetrieveAPIView(mixins.CachedRetrieveModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for retrieving a model instance.
|
Concrete view for retrieving a model instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def get(self, request, *args, **kwargs):
|
||||||
return self.retrieve(request, *args, **kwargs)
|
return self.retrieve(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CachedUpdateAPIView(mixins.CachedUpdateModelMixin,GenericAPIView):
|
class CachedUpdateAPIView(mixins.CachedUpdateModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for updating a model instance.
|
Concrete view for updating a model instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def put(self, request, *args, **kwargs):
|
def put(self, request, *args, **kwargs):
|
||||||
return self.update(request, *args, **kwargs)
|
return self.update(request, *args, **kwargs)
|
||||||
|
|
||||||
@@ -45,18 +43,23 @@ class CachedUpdateAPIView(mixins.CachedUpdateModelMixin,GenericAPIView):
|
|||||||
return self.partial_update(request, *args, **kwargs)
|
return self.partial_update(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CachedDestroyAPIView(mixins.CachedDestroyModelMixin,GenericAPIView):
|
class CachedDestroyAPIView(mixins.CachedDestroyModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for deleting a model instance.
|
Concrete view for deleting a model instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def delete(self, request, *args, **kwargs):
|
def delete(self, request, *args, **kwargs):
|
||||||
return self.destroy(request, *args, **kwargs)
|
return self.destroy(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CachedRetrieveUpdateAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedUpdateModelMixin,GenericAPIView):
|
class CachedRetrieveUpdateAPIView(mixins.CachedRetrieveModelMixin,
|
||||||
|
mixins.CachedUpdateModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for retrieving, updating a model instance.
|
Concrete view for retrieving, updating a model instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def get(self, request, *args, **kwargs):
|
||||||
return self.retrieve(request, *args, **kwargs)
|
return self.retrieve(request, *args, **kwargs)
|
||||||
|
|
||||||
@@ -67,10 +70,13 @@ class CachedRetrieveUpdateAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedU
|
|||||||
return self.partial_update(request, *args, **kwargs)
|
return self.partial_update(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CachedRetrieveDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedDestroyModelMixin,GenericAPIView):
|
class CachedRetrieveDestroyAPIView(mixins.CachedRetrieveModelMixin,
|
||||||
|
mixins.CachedDestroyModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for retrieving or deleting a model instance.
|
Concrete view for retrieving or deleting a model instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def get(self, request, *args, **kwargs):
|
||||||
return self.retrieve(request, *args, **kwargs)
|
return self.retrieve(request, *args, **kwargs)
|
||||||
|
|
||||||
@@ -78,10 +84,14 @@ class CachedRetrieveDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins.Cached
|
|||||||
return self.destroy(request, *args, **kwargs)
|
return self.destroy(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CachedRetrieveUpdateDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins.CachedUpdateModelMixin,mixins.CachedDestroyModelMixin,GenericAPIView):
|
class CachedRetrieveUpdateDestroyAPIView(mixins.CachedRetrieveModelMixin,
|
||||||
|
mixins.CachedUpdateModelMixin,
|
||||||
|
mixins.CachedDestroyModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for retrieving, updating or deleting a model instance.
|
Concrete view for retrieving, updating or deleting a model instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def get(self, request, *args, **kwargs):
|
||||||
return self.retrieve(request, *args, **kwargs)
|
return self.retrieve(request, *args, **kwargs)
|
||||||
|
|
||||||
@@ -95,10 +105,16 @@ class CachedRetrieveUpdateDestroyAPIView(mixins.CachedRetrieveModelMixin,mixins.
|
|||||||
return self.destroy(request, *args, **kwargs)
|
return self.destroy(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CachedCreateRetrieveUpdateDestroyAPIView(mixins.CachedCreateModelMixin,mixins.CachedRetrieveModelMixin,mixins.CachedUpdateModelMixin,mixins.CachedDestroyModelMixin,GenericAPIView):
|
class CachedCreateRetrieveUpdateDestroyAPIView(mixins.CachedCreateModelMixin,
|
||||||
|
mixins.CachedRetrieveModelMixin,
|
||||||
|
mixins.CachedUpdateModelMixin,
|
||||||
|
mixins.CachedDestroyModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for creating, retrieving, updating or deleting a model instance.
|
Concrete view for creating, retrieving, updating or deleting a model
|
||||||
|
instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def get(self, request, *args, **kwargs):
|
||||||
return self.retrieve(request, *args, **kwargs)
|
return self.retrieve(request, *args, **kwargs)
|
||||||
|
|
||||||
@@ -117,26 +133,32 @@ class CachedCreateRetrieveUpdateDestroyAPIView(mixins.CachedCreateModelMixin,mix
|
|||||||
|
|
||||||
# List based CRUD
|
# List based CRUD
|
||||||
|
|
||||||
class CachedListRetrieveAPIView(mixins.CachedListRetrieveModelMixin,GenericAPIView):
|
class CachedListRetrieveAPIView(mixins.CachedListRetrieveModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for listing a queryset.
|
Concrete view for listing a queryset.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def get(self, request, *args, **kwargs):
|
||||||
return self.list(request, *args, **kwargs)
|
return self.list(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CachedListCreateAPIView(mixins.CachedListCreateModelMixin,GenericAPIView):
|
class CachedListCreateAPIView(mixins.CachedListCreateModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for creating multiple instances.
|
Concrete view for creating multiple instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def post(self, request, *args, **kwargs):
|
def post(self, request, *args, **kwargs):
|
||||||
return self.list_create(request, *args, **kwargs)
|
return self.list_create(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CachedListUpdateAPIView(mixins.CachedListUpdateModelMixin,GenericAPIView):
|
class CachedListUpdateAPIView(mixins.CachedListUpdateModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for updating multiple instances.
|
Concrete view for updating multiple instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def put(self, request, *args, **kwargs):
|
def put(self, request, *args, **kwargs):
|
||||||
return self.list_update(request, *args, **kwargs)
|
return self.list_update(request, *args, **kwargs)
|
||||||
|
|
||||||
@@ -144,18 +166,23 @@ class CachedListUpdateAPIView(mixins.CachedListUpdateModelMixin,GenericAPIView):
|
|||||||
return self.list_partial_update(request, *args, **kwargs)
|
return self.list_partial_update(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CachedListDestroyAPIView(mixins.CachedListDestroyModelMixin,GenericAPIView):
|
class CachedListDestroyAPIView(mixins.CachedListDestroyModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for deleting multiple instances.
|
Concrete view for deleting multiple instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def delete(self, request, *args, **kwargs):
|
def delete(self, request, *args, **kwargs):
|
||||||
return self.list_destroy(request, *args, **kwargs)
|
return self.list_destroy(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CachedListRetrieveCreateAPIView(mixins.CachedListRetrieveModelMixin,mixins.CachedListCreateModelMixin,GenericAPIView):
|
class CachedListRetrieveCreateAPIView(mixins.CachedListRetrieveModelMixin,
|
||||||
|
mixins.CachedListCreateModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for listing a queryset or creating a model instance.
|
Concrete view for listing a queryset or creating a model instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def get(self, request, *args, **kwargs):
|
||||||
return self.list(request, *args, **kwargs)
|
return self.list(request, *args, **kwargs)
|
||||||
|
|
||||||
@@ -163,10 +190,15 @@ class CachedListRetrieveCreateAPIView(mixins.CachedListRetrieveModelMixin,mixins
|
|||||||
return self.create(request, *args, **kwargs)
|
return self.create(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CachedListCreateRetrieveDestroyAPIView(mixins.CachedListCreateModelMixin,mixins.CachedListRetrieveModelMixin,mixins.CachedListDestroyModelMixin,GenericAPIView):
|
class CachedListCreateRetrieveDestroyAPIView(
|
||||||
|
mixins.CachedListCreateModelMixin,
|
||||||
|
mixins.CachedListRetrieveModelMixin,
|
||||||
|
mixins.CachedListDestroyModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for creating, retrieving or deleting a model instance.
|
Concrete view for creating, retrieving or deleting a model instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def get(self, request, *args, **kwargs):
|
||||||
return self.list(request, *args, **kwargs)
|
return self.list(request, *args, **kwargs)
|
||||||
|
|
||||||
@@ -177,10 +209,16 @@ class CachedListCreateRetrieveDestroyAPIView(mixins.CachedListCreateModelMixin,m
|
|||||||
return self.list_destroy(request, *args, **kwargs)
|
return self.list_destroy(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CachedListCreateRetrieveUpdateAPIView(mixins.CachedListCreateModelMixin,mixins.CachedListRetrieveModelMixin,mixins.CachedListUpdateModelMixin,GenericAPIView):
|
class CachedListCreateRetrieveUpdateAPIView(
|
||||||
|
mixins.CachedListCreateModelMixin,
|
||||||
|
mixins.CachedListRetrieveModelMixin,
|
||||||
|
mixins.CachedListUpdateModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for creating, retrieving, updating or deleting a model instance.
|
Concrete view for creating, retrieving, updating or deleting a model
|
||||||
|
instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def get(self, request, *args, **kwargs):
|
||||||
return self.list(request, *args, **kwargs)
|
return self.list(request, *args, **kwargs)
|
||||||
|
|
||||||
@@ -194,10 +232,17 @@ class CachedListCreateRetrieveUpdateAPIView(mixins.CachedListCreateModelMixin,mi
|
|||||||
return self.list_partial_update(request, *args, **kwargs)
|
return self.list_partial_update(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CachedListCreateRetrieveUpdateDestroyAPIView(mixins.CachedListCreateModelMixin,mixins.CachedListRetrieveModelMixin,mixins.CachedListUpdateModelMixin,mixins.CachedListDestroyModelMixin,GenericAPIView):
|
class CachedListCreateRetrieveUpdateDestroyAPIView(
|
||||||
|
mixins.CachedListCreateModelMixin,
|
||||||
|
mixins.CachedListRetrieveModelMixin,
|
||||||
|
mixins.CachedListUpdateModelMixin,
|
||||||
|
mixins.CachedListDestroyModelMixin,
|
||||||
|
GenericAPIView):
|
||||||
"""
|
"""
|
||||||
Concrete view for creating, retrieving, updating or deleting a model instance.
|
Concrete view for creating, retrieving, updating or deleting a model
|
||||||
|
instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def get(self, request, *args, **kwargs):
|
||||||
return self.list(request, *args, **kwargs)
|
return self.list(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -8,56 +8,70 @@ from rest_framework import status
|
|||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
from rest_framework import mixins
|
from rest_framework import mixins
|
||||||
from starfields_drf_generics.cache_mixins import CacheGetMixin, CacheSetMixin, CacheDeleteMixin
|
from starfields_drf_generics.cache_mixins import (
|
||||||
|
CacheGetMixin, CacheSetMixin, CacheDeleteMixin)
|
||||||
|
|
||||||
|
|
||||||
# Mixin classes to be included in the generic classes
|
# Mixin classes to be included in the generic classes
|
||||||
class CachedCreateModelMixin(CacheDeleteMixin, mixins.CreateModelMixin):
|
class CachedCreateModelMixin(CacheDeleteMixin, mixins.CreateModelMixin):
|
||||||
"""
|
"""
|
||||||
A slightly modified version of rest_framework.mixins.CreateModelMixin that handles cache deletions.
|
A slightly modified version of rest_framework.mixins.CreateModelMixin
|
||||||
|
that handles cache deletions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def create(self, request, *args, **kwargs):
|
def create(self, request, *args, **kwargs):
|
||||||
|
" Creates the entry in the request "
|
||||||
# Go on with the creation as normal
|
# Go on with the creation as normal
|
||||||
serializer = self.get_serializer(data=request.data)
|
serializer = self.get_serializer(data=request.data)
|
||||||
serializer.is_valid(raise_exception=True)
|
serializer.is_valid(raise_exception=True)
|
||||||
self.perform_create(serializer)
|
self.perform_create(serializer)
|
||||||
headers = self.get_success_headers(serializer.data)
|
headers = self.get_success_headers(serializer.data)
|
||||||
|
|
||||||
# Delete the cache
|
# Delete the cache
|
||||||
self.delete_cache(request)
|
self.delete_cache(request)
|
||||||
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
|
return Response(serializer.data, status=status.HTTP_201_CREATED,
|
||||||
|
headers=headers)
|
||||||
|
|
||||||
|
|
||||||
class CachedRetrieveModelMixin(CacheGetMixin, CacheSetMixin):
|
class CachedRetrieveModelMixin(CacheGetMixin, CacheSetMixin):
|
||||||
"""
|
"""
|
||||||
A slightly modified version of rest_framework.mixins.RetrieveModelMixin that handles cache attempts.
|
A slightly modified version of rest_framework.mixins.RetrieveModelMixin
|
||||||
mixins.RetrieveModelMixin only has the retrieve method so it doesn't stand to inherit anything from it.
|
that handles cache attempts.
|
||||||
|
mixins.RetrieveModelMixin only has the retrieve method so it doesn't stand
|
||||||
|
to inherit anything from it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def retrieve(self, request, *args, **kwargs):
|
def retrieve(self, request, *args, **kwargs):
|
||||||
|
" Retrieves the entry in the request "
|
||||||
# Attempt to get the request from the cache
|
# Attempt to get the request from the cache
|
||||||
cache_attempt = self.get_cache(request)
|
cache_attempt = self.get_cache(request)
|
||||||
|
|
||||||
if cache_attempt:
|
if cache_attempt:
|
||||||
return Response(cache_attempt)
|
return Response(cache_attempt)
|
||||||
else:
|
|
||||||
# The cache get attempt failed so we have to get the results from the database
|
# The cache get attempt failed so we have to get the results from
|
||||||
instance = self.get_object()
|
# the database
|
||||||
|
instance = self.get_object()
|
||||||
serializer = self.get_serializer(instance)
|
|
||||||
response = Response(serializer.data)
|
serializer = self.get_serializer(instance)
|
||||||
|
response = Response(serializer.data)
|
||||||
self.set_cache(request, response)
|
|
||||||
return response
|
self.set_cache(request, response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
class CachedUpdateModelMixin(CacheDeleteMixin, mixins.UpdateModelMixin):
|
class CachedUpdateModelMixin(CacheDeleteMixin, mixins.UpdateModelMixin):
|
||||||
"""
|
"""
|
||||||
A slightly modified version of rest_framework.mixins.UpdateModelMixin that handles cache deletes.
|
A slightly modified version of rest_framework.mixins.UpdateModelMixin that
|
||||||
|
handles cache deletes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def update(self, request, *args, **kwargs):
|
def update(self, request, *args, **kwargs):
|
||||||
|
" Updates the entry in the request "
|
||||||
partial = kwargs.pop('partial', False)
|
partial = kwargs.pop('partial', False)
|
||||||
instance = self.get_object()
|
instance = self.get_object()
|
||||||
serializer = self.get_serializer(instance, data=request.data, partial=partial)
|
serializer = self.get_serializer(instance, data=request.data,
|
||||||
|
partial=partial)
|
||||||
serializer.is_valid(raise_exception=True)
|
serializer.is_valid(raise_exception=True)
|
||||||
self.perform_update(serializer)
|
self.perform_update(serializer)
|
||||||
|
|
||||||
@@ -65,7 +79,7 @@ class CachedUpdateModelMixin(CacheDeleteMixin, mixins.UpdateModelMixin):
|
|||||||
# If 'prefetch_related' has been applied to a queryset, we need to
|
# If 'prefetch_related' has been applied to a queryset, we need to
|
||||||
# forcibly invalidate the prefetch cache on the instance.
|
# forcibly invalidate the prefetch cache on the instance.
|
||||||
instance._prefetched_objects_cache = {}
|
instance._prefetched_objects_cache = {}
|
||||||
|
|
||||||
# Delete the related caches
|
# Delete the related caches
|
||||||
self.delete_cache(request)
|
self.delete_cache(request)
|
||||||
|
|
||||||
@@ -74,15 +88,18 @@ class CachedUpdateModelMixin(CacheDeleteMixin, mixins.UpdateModelMixin):
|
|||||||
|
|
||||||
class CachedDestroyModelMixin(CacheDeleteMixin, mixins.DestroyModelMixin):
|
class CachedDestroyModelMixin(CacheDeleteMixin, mixins.DestroyModelMixin):
|
||||||
"""
|
"""
|
||||||
A slightly modified version of rest_framework.mixins.DestroyModelMixin that handles cache deletes.
|
A slightly modified version of rest_framework.mixins.DestroyModelMixin
|
||||||
|
that handles cache deletes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def destroy(self, request, *args, **kwargs):
|
def destroy(self, request, *args, **kwargs):
|
||||||
|
" Deletes the entry in the request "
|
||||||
instance = self.get_object()
|
instance = self.get_object()
|
||||||
self.perform_destroy(instance)
|
self.perform_destroy(instance)
|
||||||
|
|
||||||
# Delete the related caches
|
# Delete the related caches
|
||||||
self.delete_cache(request)
|
self.delete_cache(request)
|
||||||
|
|
||||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||||
|
|
||||||
|
|
||||||
@@ -91,21 +108,26 @@ class CachedListCreateModelMixin(CacheDeleteMixin):
|
|||||||
"""
|
"""
|
||||||
A fully custom mixin that handles mutiple instance cration.
|
A fully custom mixin that handles mutiple instance cration.
|
||||||
"""
|
"""
|
||||||
def list_create(self, request, *args, **kwargs):
|
|
||||||
|
def list_create(self, request):
|
||||||
|
" Creates the list of entries in the request "
|
||||||
# Go on with the creation as normal
|
# Go on with the creation as normal
|
||||||
serializer = self.get_serializer(data=request.data, many=True)
|
serializer = self.get_serializer(data=request.data, many=True)
|
||||||
serializer.is_valid(raise_exception=True)
|
serializer.is_valid(raise_exception=True)
|
||||||
self.perform_create(serializer)
|
self.perform_create(serializer)
|
||||||
headers = self.get_success_headers(serializer.data)
|
headers = self.get_success_headers(serializer.data)
|
||||||
|
|
||||||
# Delete the cache
|
# Delete the cache
|
||||||
self.delete_cache(request)
|
self.delete_cache(request)
|
||||||
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
|
return Response(serializer.data, status=status.HTTP_201_CREATED,
|
||||||
|
headers=headers)
|
||||||
|
|
||||||
def perform_create(self, serializer):
|
def perform_create(self, serializer):
|
||||||
|
" Generic save hook "
|
||||||
serializer.save()
|
serializer.save()
|
||||||
|
|
||||||
def get_success_headers(self, data):
|
def get_success_headers(self, data):
|
||||||
|
" Returns extra success headers "
|
||||||
try:
|
try:
|
||||||
return {'Location': str(data[api_settings.URL_FIELD_NAME])}
|
return {'Location': str(data[api_settings.URL_FIELD_NAME])}
|
||||||
except (TypeError, KeyError):
|
except (TypeError, KeyError):
|
||||||
@@ -114,31 +136,36 @@ class CachedListCreateModelMixin(CacheDeleteMixin):
|
|||||||
|
|
||||||
class CachedListRetrieveModelMixin(CacheGetMixin, CacheSetMixin):
|
class CachedListRetrieveModelMixin(CacheGetMixin, CacheSetMixin):
|
||||||
"""
|
"""
|
||||||
A slightly modified version of rest_framework.mixins.ListModelMixin that handles cache saves.
|
A slightly modified version of rest_framework.mixins.ListModelMixin that
|
||||||
mixins.ListModelMixin only has the list method so it doesn't stand to inherit anything from it.
|
handles cache saves.
|
||||||
|
mixins.ListModelMixin only has the list method so it doesn't stand to
|
||||||
|
inherit anything from it.
|
||||||
"""
|
"""
|
||||||
def list(self, request, *args, **kwargs):
|
|
||||||
|
def list(self, request):
|
||||||
|
" Retrieves the listing of entries "
|
||||||
# Attempt to get the request from the cache
|
# Attempt to get the request from the cache
|
||||||
cache_attempt = self.get_cache(request)
|
cache_attempt = self.get_cache(request)
|
||||||
|
|
||||||
if cache_attempt:
|
if cache_attempt:
|
||||||
return Response(cache_attempt)
|
return Response(cache_attempt)
|
||||||
else:
|
|
||||||
# The cache get attempt failed so we have to get the results from the database
|
# The cache get attempt failed so we have to get the results from
|
||||||
queryset = self.filter_queryset(self.get_queryset())
|
# the database
|
||||||
|
queryset = self.filter_queryset(self.get_queryset())
|
||||||
if self.paged:
|
|
||||||
page = self.paginate_queryset(queryset)
|
if self.paged:
|
||||||
if page is not None:
|
page = self.paginate_queryset(queryset)
|
||||||
serializer = self.get_serializer(page, many=True)
|
if page is not None:
|
||||||
response = self.get_paginated_response(serializer.data)
|
serializer = self.get_serializer(page, many=True)
|
||||||
else:
|
response = self.get_paginated_response(serializer.data)
|
||||||
serializer = self.get_serializer(queryset, many=True)
|
|
||||||
response = Response(serializer.data)
|
|
||||||
else:
|
else:
|
||||||
serializer = self.get_serializer(queryset, many=True)
|
serializer = self.get_serializer(queryset, many=True)
|
||||||
response = Response(serializer.data)
|
response = Response(serializer.data)
|
||||||
|
else:
|
||||||
|
serializer = self.get_serializer(queryset, many=True)
|
||||||
|
response = Response(serializer.data)
|
||||||
|
|
||||||
self.set_cache(request, response)
|
self.set_cache(request, response)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@@ -147,12 +174,15 @@ class CachedListUpdateModelMixin(CacheDeleteMixin):
|
|||||||
"""
|
"""
|
||||||
A fully custom mixin that handles mutiple instance updates.
|
A fully custom mixin that handles mutiple instance updates.
|
||||||
"""
|
"""
|
||||||
def list_update(self, request, *args, **kwargs):
|
|
||||||
|
def list_update(self, request, **kwargs):
|
||||||
|
" Updates the list of entries in the request "
|
||||||
partial = kwargs.pop('partial', False)
|
partial = kwargs.pop('partial', False)
|
||||||
|
|
||||||
queryset = self.filter_queryset(self.get_queryset())
|
queryset = self.filter_queryset(self.get_queryset())
|
||||||
|
|
||||||
serializer = self.get_serializer(queryset, data=request.data, partial=partial, many=True)
|
serializer = self.get_serializer(queryset, data=request.data,
|
||||||
|
partial=partial, many=True)
|
||||||
serializer.is_valid(raise_exception=True)
|
serializer.is_valid(raise_exception=True)
|
||||||
self.perform_update(serializer)
|
self.perform_update(serializer)
|
||||||
|
|
||||||
@@ -162,9 +192,11 @@ class CachedListUpdateModelMixin(CacheDeleteMixin):
|
|||||||
return Response(serializer.data)
|
return Response(serializer.data)
|
||||||
|
|
||||||
def perform_update(self, serializer):
|
def perform_update(self, serializer):
|
||||||
|
" Generic save hook "
|
||||||
serializer.save()
|
serializer.save()
|
||||||
|
|
||||||
def list_partial_update(self, request, *args, **kwargs):
|
def list_partial_update(self, request, *args, **kwargs):
|
||||||
|
" Needs to be called on partial updates "
|
||||||
kwargs['partial'] = True
|
kwargs['partial'] = True
|
||||||
return self.list_update(request, *args, **kwargs)
|
return self.list_update(request, *args, **kwargs)
|
||||||
|
|
||||||
@@ -173,39 +205,22 @@ class CachedListDestroyModelMixin(CacheDeleteMixin):
|
|||||||
"""
|
"""
|
||||||
A fully custom mixin that handles mutiple instance deletions.
|
A fully custom mixin that handles mutiple instance deletions.
|
||||||
"""
|
"""
|
||||||
def list_destroy(self, request, *args, **kwargs):
|
|
||||||
|
def list_destroy(self, request):
|
||||||
|
" Deletes the list of entries in the request "
|
||||||
# Go on with the validation as normal
|
# Go on with the validation as normal
|
||||||
serializer = self.get_serializer(data=request.data, many=True)
|
serializer = self.get_serializer(data=request.data, many=True)
|
||||||
serializer.is_valid(raise_exception=True)
|
serializer.is_valid(raise_exception=True)
|
||||||
validated_data = serializer.validated_data
|
validated_data = serializer.validated_data
|
||||||
|
|
||||||
# TODO does this new stuff work even? need to check on the frontend
|
# TODO does this new stuff work even? need to check on the frontend
|
||||||
serializer.delete(validated_data)
|
serializer.delete(validated_data)
|
||||||
|
|
||||||
# for instance in self.get_objects():
|
# for instance in self.get_objects():
|
||||||
# if instance is not None:
|
# if instance is not None:
|
||||||
# self.perform_destroy(instance)
|
# self.perform_destroy(instance)
|
||||||
|
|
||||||
# Delete the related caches
|
# Delete the related caches
|
||||||
self.delete_cache(request)
|
self.delete_cache(request)
|
||||||
|
|
||||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||||
|
|
||||||
#def perform_destroy(self, instance):
|
|
||||||
# instance.delete()
|
|
||||||
|
|
||||||
#def get_objects(self):
|
|
||||||
# """
|
|
||||||
# The custom list version of get_object that retrieves one instance from the #database. It yields model instances with each call.
|
|
||||||
# """
|
|
||||||
# queryset = self.filter_queryset(self.get_queryset())
|
|
||||||
#
|
|
||||||
# if len(queryset):
|
|
||||||
# for obj in queryset.all():
|
|
||||||
|
|
||||||
# # May raise a permission denied
|
|
||||||
# self.check_object_permissions(self.request, obj)
|
|
||||||
|
|
||||||
# yield obj
|
|
||||||
|
|
||||||
#yield None
|
|
||||||
|
|||||||
@@ -1,22 +1,30 @@
|
|||||||
|
"""
|
||||||
|
Utility functions for the library
|
||||||
|
"""
|
||||||
|
|
||||||
def sorted_params_string(filters_dict):
|
def sorted_params_string(filters_dict):
|
||||||
"""
|
"""
|
||||||
This function takes a dict and returns it in a sorted form for the url filter, it's primarily used for cache purposes.
|
This function takes a dict and returns it in a sorted form for the url
|
||||||
|
filter, it's primarily used for cache purposes.
|
||||||
"""
|
"""
|
||||||
filters_string = ''
|
filters_string = ''
|
||||||
for key in sorted(filters_dict.keys()):
|
for key in sorted(filters_dict.keys()):
|
||||||
if filters_string == '':
|
if filters_string == '':
|
||||||
filters_string = f"{key}={','.join(str(val) for val in sorted(filters_dict[key]))}"
|
key_str = ','.join(str(val) for val in sorted(filters_dict[key]))
|
||||||
|
filters_string = f"{key}={key_str}"
|
||||||
else:
|
else:
|
||||||
filters_string = f"{filters_string}&{key}={','.join(str(val) for val in sorted(filters_dict[key]))}"
|
key_str = ','.join(str(val) for val in sorted(filters_dict[key]))
|
||||||
|
filters_string = f"{filters_string}&{key}={key_str}"
|
||||||
filters_string = filters_string.strip()
|
filters_string = filters_string.strip()
|
||||||
return filters_string
|
return filters_string
|
||||||
|
|
||||||
|
|
||||||
def parse_tags_to_dict(tags):
|
def parse_tags_to_dict(tags):
|
||||||
|
" This function parses a tag string into a dictionary "
|
||||||
tagdict = {}
|
tagdict = {}
|
||||||
if ':' not in tags:
|
if ':' not in tags:
|
||||||
tagdict = {}
|
tagdict = {}
|
||||||
else:
|
else:
|
||||||
for subtag in tags.split('&'):
|
for subtag in tags.split('&'):
|
||||||
tagkey, taglist = subtag.split(':')
|
tagkey, taglist = subtag.split(':')
|
||||||
taglist = taglist.split(',')
|
taglist = taglist.split(',')
|
||||||
|
|||||||
Reference in New Issue
Block a user