improve REST API filter parameters and pagination

This commit is contained in:
Nick Sweeting 2024-08-20 01:56:37 -07:00
parent 850448b42c
commit 54acfd9f86
No known key found for this signature in database

View file

@ -1,14 +1,17 @@
__package__ = 'archivebox.api' __package__ = 'archivebox.api'
import math
from uuid import UUID from uuid import UUID
from typing import List, Optional from typing import List, Optional, Union, Any
from datetime import datetime from datetime import datetime
from django.db.models import Q from django.db.models import Q
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.core.exceptions import ValidationError
from django.contrib.auth import get_user_model
from ninja import Router, Schema, FilterSchema, Field, Query from ninja import Router, Schema, FilterSchema, Field, Query
from ninja.pagination import paginate from ninja.pagination import paginate, PaginationBase
from core.models import Snapshot, ArchiveResult, Tag from core.models import Snapshot, ArchiveResult, Tag
from abid_utils.abid import ABID from abid_utils.abid import ABID
@ -17,10 +20,45 @@ router = Router(tags=['Core Models'])
class CustomPagination(PaginationBase):
class Input(Schema):
limit: int = 200
offset: int = 0
page: int = 0
class Output(Schema):
total_items: int
total_pages: int
page: int
limit: int
offset: int
num_items: int
items: List[Any]
def paginate_queryset(self, queryset, pagination: Input, **params):
limit = min(pagination.limit, 500)
offset = pagination.offset or (pagination.page * limit)
total = queryset.count()
total_pages = math.ceil(total / limit)
current_page = math.ceil(offset / (limit + 1))
items = queryset[offset : offset + limit]
return {
'total_items': total,
'total_pages': total_pages,
'page': current_page,
'limit': limit,
'offset': offset,
'num_items': len(items),
'items': items,
}
### ArchiveResult ######################################################################### ### ArchiveResult #########################################################################
class ArchiveResultSchema(Schema): class ArchiveResultSchema(Schema):
TYPE: str = 'core.models.ArchiveResult'
id: UUID id: UUID
old_id: int old_id: int
abid: str abid: str
@ -28,8 +66,10 @@ class ArchiveResultSchema(Schema):
modified: datetime modified: datetime
created: datetime created: datetime
created_by_id: str created_by_id: str
created_by_username: str
snapshot_abid: str snapshot_abid: str
snapshot_timestamp: str
snapshot_url: str snapshot_url: str
snapshot_tags: str snapshot_tags: str
@ -43,6 +83,11 @@ class ArchiveResultSchema(Schema):
@staticmethod @staticmethod
def resolve_created_by_id(obj): def resolve_created_by_id(obj):
return str(obj.created_by_id) return str(obj.created_by_id)
@staticmethod
def resolve_created_by_username(obj):
User = get_user_model()
return User.objects.get(id=obj.created_by_id).username
@staticmethod @staticmethod
def resolve_pk(obj): def resolve_pk(obj):
@ -60,6 +105,10 @@ class ArchiveResultSchema(Schema):
def resolve_created(obj): def resolve_created(obj):
return obj.start_ts return obj.start_ts
@staticmethod
def resolve_snapshot_timestamp(obj):
return obj.snapshot.timestamp
@staticmethod @staticmethod
def resolve_snapshot_url(obj): def resolve_snapshot_url(obj):
return obj.snapshot.url return obj.snapshot.url
@ -74,10 +123,10 @@ class ArchiveResultSchema(Schema):
class ArchiveResultFilterSchema(FilterSchema): class ArchiveResultFilterSchema(FilterSchema):
id: Optional[UUID] = Field(None, q='id') id: Optional[str] = Field(None, q=['id__startswith', 'abid__icontains', 'old_id__startswith', 'snapshot__id__startswith', 'snapshot__abid__icontains', 'snapshot__timestamp__startswith'])
search: Optional[str] = Field(None, q=['snapshot__url__icontains', 'snapshot__title__icontains', 'snapshot__tags__name__icontains', 'extractor', 'output__icontains']) search: Optional[str] = Field(None, q=['snapshot__url__icontains', 'snapshot__title__icontains', 'snapshot__tags__name__icontains', 'extractor', 'output__icontains', 'id__startswith', 'abid__icontains', 'old_id__startswith', 'snapshot__id__startswith', 'snapshot__abid__icontains', 'snapshot__timestamp__startswith'])
snapshot_id: Optional[UUID] = Field(None, q='snapshot_id__icontains') snapshot_id: Optional[str] = Field(None, q=['snapshot__id__startswith', 'snapshot__abid__icontains', 'snapshot__timestamp__startswith'])
snapshot_url: Optional[str] = Field(None, q='snapshot__url__icontains') snapshot_url: Optional[str] = Field(None, q='snapshot__url__icontains')
snapshot_tag: Optional[str] = Field(None, q='snapshot__tags__name__icontains') snapshot_tag: Optional[str] = Field(None, q='snapshot__tags__name__icontains')
@ -94,11 +143,11 @@ class ArchiveResultFilterSchema(FilterSchema):
@router.get("/archiveresults", response=List[ArchiveResultSchema], url_name="get_archiveresult") @router.get("/archiveresults", response=List[ArchiveResultSchema], url_name="get_archiveresult")
@paginate @paginate(CustomPagination)
def get_archiveresults(request, filters: ArchiveResultFilterSchema = Query(...)): def get_archiveresults(request, filters: ArchiveResultFilterSchema = Query(...)):
"""List all ArchiveResult entries matching these filters.""" """List all ArchiveResult entries matching these filters."""
qs = ArchiveResult.objects.all() qs = ArchiveResult.objects.all()
results = filters.filter(qs) results = filters.filter(qs).distinct()
return results return results
@ -137,6 +186,8 @@ def get_archiveresult(request, archiveresult_id: str):
class SnapshotSchema(Schema): class SnapshotSchema(Schema):
TYPE: str = 'core.models.Snapshot'
id: UUID id: UUID
old_id: UUID old_id: UUID
abid: str abid: str
@ -144,6 +195,7 @@ class SnapshotSchema(Schema):
modified: datetime modified: datetime
created: datetime created: datetime
created_by_id: str created_by_id: str
created_by_username: str
url: str url: str
tags: str tags: str
@ -161,6 +213,11 @@ class SnapshotSchema(Schema):
@staticmethod @staticmethod
def resolve_created_by_id(obj): def resolve_created_by_id(obj):
return str(obj.created_by_id) return str(obj.created_by_id)
@staticmethod
def resolve_created_by_username(obj):
User = get_user_model()
return User.objects.get(id=obj.created_by_id).username
@staticmethod @staticmethod
def resolve_pk(obj): def resolve_pk(obj):
@ -190,11 +247,13 @@ class SnapshotSchema(Schema):
class SnapshotFilterSchema(FilterSchema): class SnapshotFilterSchema(FilterSchema):
id: Optional[str] = Field(None, q='id__icontains') id: Optional[str] = Field(None, q=['id__icontains', 'abid__icontains', 'old_id__icontains', 'timestamp__startswith'])
old_id: Optional[str] = Field(None, q='old_id__icontains') old_id: Optional[str] = Field(None, q='old_id__icontains')
abid: Optional[str] = Field(None, q='abid__icontains') abid: Optional[str] = Field(None, q='abid__icontains')
created_by_id: str = Field(None, q='created_by_id__icontains') created_by_id: str = Field(None, q='created_by_id')
created_by_username: str = Field(None, q='created_by__username__icontains')
created__gte: datetime = Field(None, q='created__gte') created__gte: datetime = Field(None, q='created__gte')
created__lt: datetime = Field(None, q='created__lt') created__lt: datetime = Field(None, q='created__lt')
@ -203,7 +262,7 @@ class SnapshotFilterSchema(FilterSchema):
modified__gte: datetime = Field(None, q='modified__gte') modified__gte: datetime = Field(None, q='modified__gte')
modified__lt: datetime = Field(None, q='modified__lt') modified__lt: datetime = Field(None, q='modified__lt')
search: Optional[str] = Field(None, q=['url__icontains', 'title__icontains', 'tags__name__icontains', 'id__icontains', 'abid__icontains', 'old_id__icontains']) search: Optional[str] = Field(None, q=['url__icontains', 'title__icontains', 'tags__name__icontains', 'id__icontains', 'abid__icontains', 'old_id__icontains', 'timestamp__startswith'])
url: Optional[str] = Field(None, q='url') url: Optional[str] = Field(None, q='url')
tag: Optional[str] = Field(None, q='tags__name') tag: Optional[str] = Field(None, q='tags__name')
title: Optional[str] = Field(None, q='title__icontains') title: Optional[str] = Field(None, q='title__icontains')
@ -215,13 +274,13 @@ class SnapshotFilterSchema(FilterSchema):
@router.get("/snapshots", response=List[SnapshotSchema], url_name="get_snapshots") @router.get("/snapshots", response=List[SnapshotSchema], url_name="get_snapshots")
@paginate @paginate(CustomPagination)
def get_snapshots(request, filters: SnapshotFilterSchema = Query(...), with_archiveresults: bool=True): def get_snapshots(request, filters: SnapshotFilterSchema = Query(...), with_archiveresults: bool=False):
"""List all Snapshot entries matching these filters.""" """List all Snapshot entries matching these filters."""
request.with_archiveresults = with_archiveresults request.with_archiveresults = with_archiveresults
qs = Snapshot.objects.all() qs = Snapshot.objects.all()
results = filters.filter(qs) results = filters.filter(qs).distinct()
return results return results
@router.get("/snapshot/{snapshot_id}", response=SnapshotSchema, url_name="get_snapshot") @router.get("/snapshot/{snapshot_id}", response=SnapshotSchema, url_name="get_snapshot")
@ -230,12 +289,7 @@ def get_snapshot(request, snapshot_id: str, with_archiveresults: bool=True):
request.with_archiveresults = with_archiveresults request.with_archiveresults = with_archiveresults
snapshot = None snapshot = None
try: try:
snapshot = Snapshot.objects.get(Q(abid__startswith=snapshot_id) | Q(id__startswith=snapshot_id) | Q(old_id__startswith=snapshot_id)) snapshot = Snapshot.objects.get(Q(abid__startswith=snapshot_id) | Q(id__startswith=snapshot_id) | Q(old_id__startswith=snapshot_id) | Q(timestamp__startswith=snapshot_id))
except Snapshot.DoesNotExist:
pass
try:
snapshot = snapshot or Snapshot.objects.get()
except Snapshot.DoesNotExist: except Snapshot.DoesNotExist:
pass pass
@ -244,6 +298,9 @@ def get_snapshot(request, snapshot_id: str, with_archiveresults: bool=True):
except Snapshot.DoesNotExist: except Snapshot.DoesNotExist:
pass pass
if not snapshot:
raise Snapshot.DoesNotExist
return snapshot return snapshot
@ -274,25 +331,94 @@ def get_snapshot(request, snapshot_id: str, with_archiveresults: bool=True):
class TagSchema(Schema): class TagSchema(Schema):
abid: Optional[UUID] = Field(None, q='abid') TYPE: str = 'core.models.Tag'
uuid: Optional[UUID] = Field(None, q='uuid')
pk: Optional[UUID] = Field(None, q='pk') id: UUID
old_id: str
abid: str
modified: datetime modified: datetime
created: datetime created: datetime
created_by_id: str created_by_id: str
created_by_username: str
name: str name: str
slug: str slug: str
num_snapshots: int
snapshots: List[SnapshotSchema]
@staticmethod
def resolve_old_id(obj):
return str(obj.old_id)
@staticmethod @staticmethod
def resolve_created_by_id(obj): def resolve_created_by_id(obj):
return str(obj.created_by_id) return str(obj.created_by_id)
@staticmethod
def resolve_created_by_username(obj):
User = get_user_model()
return User.objects.get(id=obj.created_by_id).username
@staticmethod
def resolve_num_snapshots(obj, context):
return obj.snapshot_set.all().distinct().count()
@staticmethod
def resolve_snapshots(obj, context):
if context['request'].with_snapshots:
return obj.snapshot_set.all().distinct()
return Snapshot.objects.none()
@router.get("/tags", response=List[TagSchema], url_name="get_tags") @router.get("/tags", response=List[TagSchema], url_name="get_tags")
@paginate(CustomPagination)
def get_tags(request): def get_tags(request):
return Tag.objects.all() request.with_snapshots = False
request.with_archiveresults = False
return Tag.objects.all().distinct()
@router.get("/tag/{tag_id}", response=TagSchema, url_name="get_tag") @router.get("/tag/{tag_id}", response=TagSchema, url_name="get_tag")
def get_tag(request, tag_id: str): def get_tag(request, tag_id: str, with_snapshots: bool=True):
return Tag.objects.get(id=tag_id) request.with_snapshots = with_snapshots
request.with_archiveresults = False
tag = None
try:
tag = tag or Tag.objects.get(old_id__icontains=tag_id)
except (Tag.DoesNotExist, ValidationError, ValueError):
pass
try:
tag = Tag.objects.get(abid__icontains=tag_id)
except (Tag.DoesNotExist, ValidationError):
pass
try:
tag = tag or Tag.objects.get(id__icontains=tag_id)
except (Tag.DoesNotExist, ValidationError):
pass
return tag
@router.get("/any/{abid}", response=Union[SnapshotSchema, ArchiveResultSchema, TagSchema], url_name="get_any")
def get_any(request, abid: str):
request.with_snapshots = False
request.with_archiveresults = False
response = None
try:
response = response or get_snapshot(request, abid)
except Exception:
pass
try:
response = response or get_archiveresult(request, abid)
except Exception:
pass
try:
response = response or get_tag(request, abid)
except Exception:
pass
return response