Add allowed_langs filter

This commit is contained in:
2022-01-02 20:15:55 +03:00
parent cbba30f2af
commit 017cc05a19
9 changed files with 170 additions and 53 deletions

View File

@@ -1,11 +1,20 @@
from fastapi import Security, HTTPException, status from typing import Optional
from fastapi import Security, HTTPException, Query, status
from core.auth import default_security from core.auth import default_security
from core.config import env_config from core.config import env_config
async def check_token(api_key: str = Security(default_security)): def check_token(api_key: str = Security(default_security)):
if api_key != env_config.API_KEY: if api_key != env_config.API_KEY:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Wrong api key!" status_code=status.HTTP_403_FORBIDDEN, detail="Wrong api key!"
) )
def get_allowed_langs(allowed_langs: Optional[list[str]] = Query(None)) -> list[str]:
if allowed_langs is not None:
return allowed_langs
return ["ru", "be", "uk"]

View File

@@ -1,10 +1,19 @@
from typing import Optional from typing import Optional
from fastapi.params import Query
def get_book_filter(is_deleted: Optional[bool] = None) -> dict: from app.depends import get_allowed_langs
def get_book_filter(
is_deleted: Optional[bool] = None, allowed_langs: Optional[list[str]] = Query(None)
) -> dict:
result = {} result = {}
if is_deleted is not None: if is_deleted is not None:
result["is_deleted"] = is_deleted result["is_deleted"] = is_deleted
if not (allowed_langs and "__all__" in allowed_langs):
result["lang__in"] = get_allowed_langs(allowed_langs)
return result return result

View File

@@ -17,7 +17,10 @@ SELECT ARRAY(
) as sml, ) as sml,
( (
SELECT count(*) FROM book_authors SELECT count(*) FROM book_authors
LEFT JOIN books ON (books.id = book AND books.is_deleted = 'f') LEFT JOIN books
ON (books.id = book AND
books.is_deleted = 'f' AND
books.lang = ANY(:langs ::text[]))
WHERE author = authors.id WHERE author = authors.id
) as books_count ) as books_count
FROM authors FROM authors
@@ -28,7 +31,10 @@ SELECT ARRAY(
) AND ) AND
EXISTS ( EXISTS (
SELECT * FROM book_authors SELECT * FROM book_authors
LEFT JOIN books ON (books.id = book AND books.is_deleted = 'f') LEFT JOIN books
ON (books.id = book AND
books.is_deleted = 'f' AND
books.lang = ANY(:langs ::text[]))
WHERE author = authors.id WHERE author = authors.id
) )
) )
@@ -45,5 +51,23 @@ class AuthorTGRMSearchService(TRGMSearchService):
GET_OBJECT_IDS_QUERY = GET_OBJECT_IDS_QUERY GET_OBJECT_IDS_QUERY = GET_OBJECT_IDS_QUERY
GET_RANDOM_OBJECT_ID_QUERY = """
WITH filtered_authors AS (
SELECT id FROM authors
WHERE EXISTS (
SELECT * FROM book_authors
LEFT JOIN books
ON (books.id = book AND
books.is_deleted = 'f' AND
books.lang = ANY(:langs ::text[]))
WHERE author = authors.id
)
)
SELECT id FROM filtered_authors
ORDER BY RANDOM() LIMIT 1;
"""
class GetRandomAuthorService(GetRandomService): class GetRandomAuthorService(GetRandomService):
MODEL_CLASS = Author MODEL_CLASS = Author
GET_RANDOM_OBJECT_ID_QUERY = GET_RANDOM_OBJECT_ID_QUERY

View File

@@ -13,6 +13,7 @@ SELECT ARRAY(
WITH filtered_books AS ( WITH filtered_books AS (
SELECT id, similarity(title, :query) as sml FROM books SELECT id, similarity(title, :query) as sml FROM books
WHERE books.title % :query AND books.is_deleted = 'f' WHERE books.title % :query AND books.is_deleted = 'f'
AND books.lang = ANY(:langs ::text[])
) )
SELECT fbooks.id FROM filtered_books as fbooks SELECT fbooks.id FROM filtered_books as fbooks
ORDER BY fbooks.sml DESC, fbooks.id ORDER BY fbooks.sml DESC, fbooks.id
@@ -76,5 +77,16 @@ class BookCreator:
return await cls._create_remote_book(data) return await cls._create_remote_book(data)
GET_RANDOM_OBJECT_ID_QUERY = """
WITH filtered_books AS (
SELECT id FROM books
WHERE books.is_deleted = 'f' AND books.lang = ANY(:langs ::text[])
)
SELECT id FROM filtered_books
ORDER BY RANDOM() LIMIT 1;
"""
class GetRandomBookService(GetRandomService): class GetRandomBookService(GetRandomService):
MODEL_CLASS = BookDB MODEL_CLASS = BookDB
GET_RANDOM_OBJECT_ID_QUERY = GET_RANDOM_OBJECT_ID_QUERY

View File

@@ -54,8 +54,12 @@ class TRGMSearchService(Generic[T]):
return cls.GET_OBJECT_IDS_QUERY return cls.GET_OBJECT_IDS_QUERY
@classmethod @classmethod
async def _get_object_ids(cls, query_data: str) -> list[int]: async def _get_object_ids(
row = await cls.database.fetch_one(cls.object_ids_query, {"query": query_data}) cls, query_data: str, allowed_langs: list[str]
) -> list[int]:
row = await cls.database.fetch_one(
cls.object_ids_query, {"query": query_data, "langs": allowed_langs}
)
if row is None: if row is None:
raise ValueError("Something is wrong!") raise ValueError("Something is wrong!")
@@ -63,16 +67,20 @@ class TRGMSearchService(Generic[T]):
return row["array"] return row["array"]
@classmethod @classmethod
def get_cache_key(cls, query_data: str) -> str: def get_cache_key(cls, query_data: str, allowed_langs: list[str]) -> str:
model_class_name = cls.model.__class__.__name__ model_class_name = cls.model.__class__.__name__
return f"{model_class_name}_{query_data}" allowed_langs_part = ",".join(allowed_langs)
return f"{model_class_name}_{query_data}_{allowed_langs_part}"
@classmethod @classmethod
async def get_cached_ids( async def get_cached_ids(
cls, query_data: str, redis: aioredis.Redis cls,
query_data: str,
allowed_langs: list[str],
redis: aioredis.Redis,
) -> Optional[list[int]]: ) -> Optional[list[int]]:
try: try:
key = cls.get_cache_key(query_data) key = cls.get_cache_key(query_data, allowed_langs)
data = await redis.get(key) data = await redis.get(key)
if data is None: if data is None:
@@ -85,25 +93,32 @@ class TRGMSearchService(Generic[T]):
@classmethod @classmethod
async def cache_object_ids( async def cache_object_ids(
cls, query_data: str, object_ids: list[int], redis: aioredis.Redis cls,
query_data: str,
allowed_langs: list[str],
object_ids: list[int],
redis: aioredis.Redis,
): ):
try: try:
key = cls.get_cache_key(query_data) key = cls.get_cache_key(query_data, allowed_langs)
await redis.set(key, orjson.dumps(object_ids), ex=cls.CACHE_TTL) await redis.set(key, orjson.dumps(object_ids), ex=cls.CACHE_TTL)
except aioredis.RedisError as e: except aioredis.RedisError as e:
print(e) print(e)
@classmethod @classmethod
async def get_objects( async def get_objects(
cls, query_data: str, redis: aioredis.Redis cls,
query_data: str,
redis: aioredis.Redis,
allowed_langs: list[str],
) -> tuple[int, list[T]]: ) -> tuple[int, list[T]]:
params = cls.get_raw_params() params = cls.get_raw_params()
cached_object_ids = await cls.get_cached_ids(query_data, redis) cached_object_ids = await cls.get_cached_ids(query_data, allowed_langs, redis)
if cached_object_ids is None: if cached_object_ids is None:
object_ids = await cls._get_object_ids(query_data) object_ids = await cls._get_object_ids(query_data, allowed_langs)
await cls.cache_object_ids(query_data, object_ids, redis) await cls.cache_object_ids(query_data, allowed_langs, object_ids, redis)
else: else:
object_ids = cached_object_ids object_ids = cached_object_ids
@@ -120,23 +135,19 @@ class TRGMSearchService(Generic[T]):
return len(object_ids), await queryset.filter(id__in=limited_object_ids).all() return len(object_ids), await queryset.filter(id__in=limited_object_ids).all()
@classmethod @classmethod
async def get(cls, query: str, redis: aioredis.Redis) -> Page[T]: async def get(
cls, query: str, redis: aioredis.Redis, allowed_langs: list[str]
) -> Page[T]:
params = cls.get_params() params = cls.get_params()
total, objects = await cls.get_objects(query, redis) total, objects = await cls.get_objects(query, redis, allowed_langs)
return CustomPage.create(items=objects, total=total, params=params) return CustomPage.create(items=objects, total=total, params=params)
GET_RANDOM_OBJECT_ID_QUERY = """
SELECT id FROM {table}
WHERE id >= RANDOM() * (SELECT MAX(id) FROM {table})
ORDER BY id LIMIT 1;
"""
class GetRandomService(Generic[T]): class GetRandomService(Generic[T]):
MODEL_CLASS: Optional[T] = None MODEL_CLASS: Optional[T] = None
GET_RANDOM_OBJECT_ID_QUERY: Optional[str] = None
@classmethod @classmethod
@property @property
@@ -150,7 +161,15 @@ class GetRandomService(Generic[T]):
return cls.model.Meta.database return cls.model.Meta.database
@classmethod @classmethod
async def get_random_id(cls) -> int: @property
table_name = cls.model.Meta.tablename def random_object_id_query(cls) -> str:
query = GET_RANDOM_OBJECT_ID_QUERY.format(table=table_name) assert (
return await cls.database.fetch_val(query) cls.GET_RANDOM_OBJECT_ID_QUERY is not None
), f"GET_OBJECT_IDS_QUERY in {cls.__name__} don't set!"
return cls.GET_RANDOM_OBJECT_ID_QUERY
@classmethod
async def get_random_id(cls, allowed_langs: list[str]) -> int:
return await cls.database.fetch_val(
cls.random_object_id_query, {"langs": allowed_langs}
)

View File

@@ -10,14 +10,20 @@ SELECT ARRAY (
similarity(name, :query) as sml, similarity(name, :query) as sml,
( (
SELECT count(*) FROM book_sequences SELECT count(*) FROM book_sequences
LEFT JOIN books ON (books.id = book AND books.is_deleted = 'f') LEFT JOIN books
ON (books.id = book AND
books.is_deleted = 'f' AND
books.lang = ANY(:langs ::text[]))
WHERE sequence = sequences.id WHERE sequence = sequences.id
) as books_count ) as books_count
FROM sequences FROM sequences
WHERE name % :query AND WHERE name % :query AND
EXISTS ( EXISTS (
SELECT * FROM book_sequences SELECT * FROM book_sequences
LEFT JOIN books ON (books.id = book AND books.is_deleted = 'f') LEFT JOIN books
ON (books.id = book AND
books.is_deleted = 'f' AND
books.lang = ANY(:langs ::text[]))
WHERE sequence = sequences.id WHERE sequence = sequences.id
) )
) )
@@ -34,5 +40,23 @@ class SequenceTGRMSearchService(TRGMSearchService):
GET_OBJECT_IDS_QUERY = GET_OBJECT_IDS_QUERY GET_OBJECT_IDS_QUERY = GET_OBJECT_IDS_QUERY
GET_RANDOM_OBJECT_ID_QUERY = """
WITH filtered_sequences AS (
SELECT id FROM sequences
WHERE EXISTS (
SELECT * FROM book_sequences
LEFT JOIN books
ON (books.id = book AND
books.is_deleted = 'f' AND
books.lang = ANY(:langs ::text[]))
WHERE sequence = sequences.id
)
)
SELECT id FROM filtered_sequences
ORDER BY RANDOM() LIMIT 1;
"""
class GetRandomSequenceService(GetRandomService): class GetRandomSequenceService(GetRandomService):
MODEL_CLASS = Sequence MODEL_CLASS = Sequence
GET_RANDOM_OBJECT_ID_QUERY = GET_RANDOM_OBJECT_ID_QUERY

View File

@@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends, Request, HTTPException, status
from fastapi_pagination import Params from fastapi_pagination import Params
from fastapi_pagination.ext.ormar import paginate from fastapi_pagination.ext.ormar import paginate
from app.depends import check_token from app.depends import check_token, get_allowed_langs
from app.models import Author as AuthorDB from app.models import Author as AuthorDB
from app.models import AuthorAnnotation as AuthorAnnotationDB from app.models import AuthorAnnotation as AuthorAnnotationDB
from app.models import Book as BookDB from app.models import Book as BookDB
@@ -44,8 +44,8 @@ async def create_author(data: CreateAuthor):
@author_router.get("/random", response_model=Author) @author_router.get("/random", response_model=Author)
async def get_random_author(): async def get_random_author(allowed_langs: list[str] = Depends(get_allowed_langs)):
author_id = await GetRandomAuthorService.get_random_id() author_id = await GetRandomAuthorService.get_random_id(allowed_langs)
return await AuthorDB.objects.prefetch_related(PREFETCH_RELATED).get(id=author_id) return await AuthorDB.objects.prefetch_related(PREFETCH_RELATED).get(id=author_id)
@@ -87,19 +87,25 @@ async def get_author_annotation(id: int):
@author_router.get( @author_router.get(
"/{id}/books", response_model=CustomPage[AuthorBook], dependencies=[Depends(Params)] "/{id}/books", response_model=CustomPage[AuthorBook], dependencies=[Depends(Params)]
) )
async def get_author_books(id: int): async def get_author_books(
id: int, allowed_langs: list[str] = Depends(get_allowed_langs)
):
return await paginate( return await paginate(
BookDB.objects.select_related(["source", "annotations", "translators"]) BookDB.objects.select_related(["source", "annotations", "translators"])
.filter(authors__id=id) .filter(authors__id=id, lang__in=allowed_langs, is_deleted=False)
.order_by("title") .order_by("title")
) )
@author_router.get("/{id}/translated_books", response_model=CustomPage[TranslatedBook]) @author_router.get("/{id}/translated_books", response_model=CustomPage[TranslatedBook])
async def get_translated_books(id: int): async def get_translated_books(
id: int, allowed_langs: list[str] = Depends(get_allowed_langs)
):
return await paginate( return await paginate(
BookDB.objects.select_related(["source", "annotations", "translators"]).filter( BookDB.objects.select_related(["source", "annotations", "translators"]).filter(
translations__translator__id=id translations__translator__id=id,
lang__in=allowed_langs,
is_deleted=False,
) )
) )
@@ -107,5 +113,9 @@ async def get_translated_books(id: int):
@author_router.get( @author_router.get(
"/search/{query}", response_model=CustomPage[Author], dependencies=[Depends(Params)] "/search/{query}", response_model=CustomPage[Author], dependencies=[Depends(Params)]
) )
async def search_authors(query: str, request: Request): async def search_authors(
return await AuthorTGRMSearchService.get(query, request.app.state.redis) query: str, request: Request, allowed_langs: list[str] = Depends(get_allowed_langs)
):
return await AuthorTGRMSearchService.get(
query, request.app.state.redis, allowed_langs
)

View File

@@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, Request, HTTPException, status
from fastapi_pagination import Params from fastapi_pagination import Params
from fastapi_pagination.ext.ormar import paginate from fastapi_pagination.ext.ormar import paginate
from app.depends import check_token from app.depends import check_token, get_allowed_langs
from app.filters.book import get_book_filter from app.filters.book import get_book_filter
from app.models import Author as AuthorDB from app.models import Author as AuthorDB
from app.models import Book as BookDB from app.models import Book as BookDB
@@ -50,8 +50,8 @@ async def create_book(data: Union[CreateBook, CreateRemoteBook]):
@book_router.get("/random", response_model=BookDetail) @book_router.get("/random", response_model=BookDetail)
async def get_random_book(): async def get_random_book(allowed_langs: list[str] = Depends(get_allowed_langs)):
book_id = await GetRandomBookService.get_random_id() book_id = await GetRandomBookService.get_random_id(allowed_langs)
return await BookDB.objects.select_related(SELECT_RELATED_FIELDS).get(id=book_id) return await BookDB.objects.select_related(SELECT_RELATED_FIELDS).get(id=book_id)
@@ -114,5 +114,9 @@ async def get_book_annotation(id: int):
@book_router.get( @book_router.get(
"/search/{query}", response_model=CustomPage[Book], dependencies=[Depends(Params)] "/search/{query}", response_model=CustomPage[Book], dependencies=[Depends(Params)]
) )
async def search_books(query: str, request: Request): async def search_books(
return await BookTGRMSearchService.get(query, request.app.state.redis) query: str, request: Request, allowed_langs: list[str] = Depends(get_allowed_langs)
):
return await BookTGRMSearchService.get(
query, request.app.state.redis, allowed_langs
)

View File

@@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends, Request
from fastapi_pagination import Params from fastapi_pagination import Params
from fastapi_pagination.ext.ormar import paginate from fastapi_pagination.ext.ormar import paginate
from app.depends import check_token from app.depends import check_token, get_allowed_langs
from app.models import Book as BookDB from app.models import Book as BookDB
from app.models import Sequence as SequenceDB from app.models import Sequence as SequenceDB
from app.serializers.sequence import Book as SequenceBook from app.serializers.sequence import Book as SequenceBook
@@ -27,8 +27,8 @@ async def get_sequences():
@sequence_router.get("/random", response_model=Sequence) @sequence_router.get("/random", response_model=Sequence)
async def get_random_sequence(): async def get_random_sequence(allowed_langs: list[str] = Depends(get_allowed_langs)):
sequence_id = await GetRandomSequenceService.get_random_id() sequence_id = await GetRandomSequenceService.get_random_id(allowed_langs)
return await SequenceDB.objects.get(id=sequence_id) return await SequenceDB.objects.get(id=sequence_id)
@@ -43,12 +43,14 @@ async def get_sequence(id: int):
response_model=CustomPage[SequenceBook], response_model=CustomPage[SequenceBook],
dependencies=[Depends(Params)], dependencies=[Depends(Params)],
) )
async def get_sequence_books(id: int): async def get_sequence_books(
id: int, allowed_langs: list[str] = Depends(get_allowed_langs)
):
return await paginate( return await paginate(
BookDB.objects.select_related( BookDB.objects.select_related(
["source", "annotations", "authors", "translators"] ["source", "annotations", "authors", "translators"]
) )
.filter(sequences__id=id) .filter(sequences__id=id, lang__in=allowed_langs, is_deleted=False)
.order_by("sequences__booksequences__position") .order_by("sequences__booksequences__position")
) )
@@ -63,5 +65,9 @@ async def create_sequence(data: CreateSequence):
response_model=CustomPage[Sequence], response_model=CustomPage[Sequence],
dependencies=[Depends(Params)], dependencies=[Depends(Params)],
) )
async def search_sequences(query: str, request: Request): async def search_sequences(
return await SequenceTGRMSearchService.get(query, request.app.state.redis) query: str, request: Request, allowed_langs: list[str] = Depends(get_allowed_langs)
):
return await SequenceTGRMSearchService.get(
query, request.app.state.redis, allowed_langs
)