Files
book_library_server/fastapi_book_server/app/services/common.py
2022-04-02 17:58:29 +03:00

347 lines
10 KiB
Python

import abc
import asyncio
from concurrent.futures import ThreadPoolExecutor
import hashlib
from random import choice
from typing import Optional, Generic, TypeVar, TypedDict, Union

import aioredis
from databases import Database
from fastapi import BackgroundTasks
from fastapi_pagination.api import resolve_params
from fastapi_pagination.bases import AbstractParams, RawParams
import meilisearch
import orjson
from ormar import Model, QuerySet
from sqlalchemy import Table

from app.utils.pagination import Page, CustomPage
from core.config import env_config
# Type variable for the ormar model class a service works with.
MODEL = TypeVar("MODEL", bound=Model)
# Type variable for the query mapping a service accepts (e.g. SearchQuery).
# NOTE(review): TypedDict is accepted as a bound at runtime, but static type
# checkers may reject it as a TypeVar bound — consider Mapping[str, Any].
QUERY = TypeVar("QUERY", bound=TypedDict)
class BaseSearchService(Generic[MODEL, QUERY], abc.ABC):
    """Paginated search over an ormar model, with matching ids cached in Redis.

    Subclasses implement ``_get_object_ids`` to return the ordered list of ids
    matching a query. That full list is cached in Redis (serialized with
    orjson, expiring after ``CACHE_TTL`` seconds) under a key derived from the
    query. Each request then slices its page out of the list and loads those
    model instances from the database, keeping the list's order.

    NOTE: the ``@classmethod`` + ``@property`` stacking used below was
    deprecated in Python 3.11 and no longer works in Python 3.13.
    """

    MODEL_CLASS: Optional[MODEL] = None
    # Relation name(s) passed to ormar's select_related / prefetch_related
    # when a page of objects is loaded.
    SELECT_RELATED: Optional[Union[list[str], str]] = None
    PREFETCH_RELATED: Optional[Union[list[str], str]] = None
    # When set, used instead of the model's table name as the Redis key prefix.
    CUSTOM_CACHE_PREFIX: Optional[str] = None
    # Lifetime of cached id lists, in seconds (6 hours).
    CACHE_TTL = 6 * 60 * 60

    @classmethod
    def get_params(cls) -> AbstractParams:
        """Return the pagination params resolved for the current request."""
        return resolve_params()

    @classmethod
    def get_raw_params(cls) -> RawParams:
        """Return the current pagination params as raw offset/limit values."""
        return resolve_params().to_raw_params()

    @classmethod
    @property
    def model(cls) -> MODEL:
        assert cls.MODEL_CLASS is not None, f"MODEL in {cls.__name__} don't set!"
        return cls.MODEL_CLASS

    @classmethod
    @property
    def table(cls) -> Table:
        return cls.model.Meta.table

    @classmethod
    @property
    def database(cls) -> Database:
        return cls.model.Meta.database

    @classmethod
    @property
    def cache_prefix(cls) -> str:
        return cls.CUSTOM_CACHE_PREFIX or cls.model.Meta.tablename

    @staticmethod
    def _get_query_hash(query: QUERY) -> str:
        """Return a process-independent digest of ``query`` for cache keys.

        Built-in ``hash()`` of strings is salted per interpreter process
        (PYTHONHASHSEED), which would give every worker and every restart its
        own Redis keys. Instead, hash a canonical representation: items sorted
        by key, with set-like values sorted so iteration order does not matter.
        """

        def canonical(value):
            if isinstance(value, (set, frozenset)):
                return sorted(value, key=repr)
            return value

        # Keys of a mapping are unique, so sorting never has to compare values.
        items = sorted((key, canonical(value)) for key, value in query.items())
        return hashlib.sha256(repr(items).encode("utf-8")).hexdigest()

    @classmethod
    async def _get_object_ids(cls, query: QUERY) -> list[int]:
        """Return the ordered ids of all objects matching ``query``.

        Must be overridden. The previous placeholder body returned ``None``,
        which was cached and later crashed when sliced.

        Raises:
            NotImplementedError: if the subclass does not override it.
        """
        raise NotImplementedError(f"{cls.__name__} must implement _get_object_ids")

    @classmethod
    def get_cache_key(cls, query: QUERY) -> str:
        """Return the Redis key under which the id list for ``query`` is stored."""
        return f"{cls.cache_prefix}_{cls._get_query_hash(query)}"

    @classmethod
    async def get_cached_ids(
        cls,
        query: QUERY,
        redis: aioredis.Redis,
    ) -> Optional[list[int]]:
        """Return the cached id list for ``query``, or None on miss or Redis error."""
        try:
            key = cls.get_cache_key(query)
            data = await redis.get(key)
            if data is None:
                return None
            return orjson.loads(data)
        except aioredis.RedisError as e:
            # Best effort: a Redis failure is treated as a cache miss.
            print(e)
            return None

    @classmethod
    async def cache_object_ids(
        cls,
        query: QUERY,
        object_ids: list[int],
        redis: aioredis.Redis,
    ):
        """Store ``object_ids`` for ``query`` with a ``CACHE_TTL`` expiry (best effort)."""
        try:
            key = cls.get_cache_key(query)
            await redis.set(key, orjson.dumps(object_ids), ex=cls.CACHE_TTL)
        except aioredis.RedisError as e:
            print(e)

    @classmethod
    async def _get_objects(cls, query: QUERY, redis: aioredis.Redis) -> list[int]:
        """Return all matching ids, from the cache or computed and then cached."""
        cached_object_ids = await cls.get_cached_ids(query, redis)
        if cached_object_ids is not None:
            return cached_object_ids

        object_ids = await cls._get_object_ids(query)
        await cls.cache_object_ids(query, object_ids, redis)
        return object_ids

    @classmethod
    async def get_limited_objects(
        cls, query: QUERY, redis: aioredis.Redis
    ) -> tuple[int, list[MODEL]]:
        """Return ``(total, page_objects)`` for the current pagination params.

        ``total`` is the length of the full id list. ``page_objects`` are the
        model instances for the current offset/limit window, in id-list order.
        """
        object_ids = await cls._get_objects(query, redis)

        params = cls.get_raw_params()
        page_ids = object_ids[params.offset : params.offset + params.limit]

        queryset: QuerySet[MODEL] = cls.model.objects
        if cls.PREFETCH_RELATED:
            queryset = queryset.prefetch_related(cls.PREFETCH_RELATED)
        if cls.SELECT_RELATED:
            queryset = queryset.select_related(cls.SELECT_RELATED)

        db_objects = await queryset.filter(id__in=page_ids).all()

        # The database returns rows in arbitrary order; restore the id-list
        # order. The first occurrence wins, matching list.index semantics.
        position: dict[int, int] = {}
        for index, object_id in enumerate(page_ids):
            position.setdefault(object_id, index)

        return len(object_ids), sorted(db_objects, key=lambda o: position[o.id])

    @classmethod
    async def get(cls, query: QUERY, redis: aioredis.Redis) -> Page[MODEL]:
        """Return one page of objects matching ``query``."""
        params = cls.get_params()
        total, objects = await cls.get_limited_objects(query, redis)
        return CustomPage.create(items=objects, total=total, params=params)
# Query payload for the text-search services: the search text plus the set of
# languages results may be restricted to. allowed_langs is a frozenset so the
# whole mapping stays hashable for cache-key computation.
SearchQuery = TypedDict(
    "SearchQuery",
    {
        "query": str,
        "allowed_langs": frozenset[str],
    },
)
class TRGMSearchService(Generic[MODEL], BaseSearchService[MODEL, SearchQuery]):
    """Search service that gets matching object ids from a raw SQL query.

    ``GET_OBJECT_IDS_QUERY`` is executed with the ``query`` and ``langs`` bind
    parameters. It must return a single row whose ``array`` column holds the
    matching ids; their order becomes the order of the search results.
    """

    GET_OBJECT_IDS_QUERY: Optional[str] = None

    @classmethod
    @property
    def object_ids_query(cls) -> str:
        assert (
            cls.GET_OBJECT_IDS_QUERY is not None
        ), f"GET_OBJECT_IDS_QUERY in {cls.__name__} don't set!"
        return cls.GET_OBJECT_IDS_QUERY

    @classmethod
    async def _get_object_ids(cls, query: SearchQuery) -> list[int]:
        """Run the configured SQL and return the ids from its ``array`` column.

        Raises:
            ValueError: if the query returns no row at all.
        """
        row = await cls.database.fetch_one(
            cls.object_ids_query,
            {"query": query["query"], "langs": query["allowed_langs"]},
        )
        if row is None:
            raise ValueError("Something is wrong!")
        # An aggregate such as array_agg() yields NULL instead of an empty
        # array when nothing matches (assumption about GET_OBJECT_IDS_QUERY —
        # confirm). A None here would be cached as `null` and break page
        # slicing, so normalize it to an empty list.
        return row["array"] or []
class MeiliSearchService(Generic[MODEL], BaseSearchService[MODEL, SearchQuery]):
    """Search service that resolves object ids through a Meilisearch index.

    The index is queried once per search (from offset 0) for up to
    ``MS_MAX_HITS`` ids. The inherited pagination then slices pages out of
    that cached list.
    """

    MS_INDEX_NAME: Optional[str] = None
    MS_INDEX_LANG_KEY: Optional[str] = None
    # Maximum number of hits requested from Meilisearch per query. This also
    # caps the total reported to the paginator.
    MS_MAX_HITS = 630

    # The meilisearch client call is blocking, so it runs in this pool.
    _executor = ThreadPoolExecutor(2)

    @classmethod
    @property
    def lang_key(cls) -> str:
        assert (
            cls.MS_INDEX_LANG_KEY is not None
        ), f"MS_INDEX_LANG_KEY in {cls.__name__} don't set!"
        return cls.MS_INDEX_LANG_KEY

    @classmethod
    @property
    def index_name(cls) -> str:
        assert (
            cls.MS_INDEX_NAME is not None
        ), f"MS_INDEX_NAME in {cls.__name__} don't set!"
        return cls.MS_INDEX_NAME

    @classmethod
    def get_allowed_langs_filter(cls, allowed_langs: frozenset[str]) -> list[list[str]]:
        """Build a Meilisearch filter matching any of ``allowed_langs``.

        In Meilisearch's array filter syntax, conditions inside an inner list
        are OR-ed together.
        """
        return [[f"{cls.lang_key} = {lang}" for lang in allowed_langs]]

    @classmethod
    def make_request(
        cls, query: str, allowed_langs_filter: list[list[str]], offset: int
    ) -> list[int]:
        """Search the index synchronously and return the hit ids in result order."""
        client = meilisearch.Client(env_config.MEILI_HOST, env_config.MEILI_MASTER_KEY)
        index = client.index(cls.index_name)
        result = index.search(
            query,
            {
                "filter": allowed_langs_filter,
                "offset": offset,
                "limit": cls.MS_MAX_HITS,
                "attributesToRetrieve": ["id"],
            },
        )
        # "hits" never exceeds "limit", so no slicing is needed. This also
        # avoids relying on "nbHits", which newer Meilisearch versions renamed.
        return [hit["id"] for hit in result["hits"]]

    @classmethod
    async def _get_object_ids(cls, query: SearchQuery) -> list[int]:
        """Fetch the ranked ids for ``query`` without blocking the event loop."""
        allowed_langs_filter = cls.get_allowed_langs_filter(query["allowed_langs"])
        loop = asyncio.get_running_loop()
        # Always start at offset 0. The result is cached per query (not per
        # page), and BaseSearchService.get_limited_objects applies the page
        # offset itself, so passing the page offset here applied it twice.
        return await loop.run_in_executor(
            cls._executor,
            cls.make_request,
            query["query"],
            allowed_langs_filter,
            0,
        )
class GetRandomService(Generic[MODEL]):
    """Pick a random object id among those allowed for a set of languages.

    Candidate ids come from ``GET_OBJECTS_ID_QUERY`` and are cached in a Redis
    set. A companion ``<key>_active`` flag, expiring after ``CACHE_TTL``
    seconds, marks the set as fresh. When the flag is missing, ids are read
    from the database and the cache is rebuilt in a background task.
    """

    MODEL_CLASS: Optional[MODEL] = None
    GET_OBJECTS_ID_QUERY: Optional[str] = None
    # When set, used instead of the model's table name in the Redis key.
    CUSTOM_CACHE_PREFIX: Optional[str] = None
    # Lifetime of the "active" freshness flag, in seconds (6 hours).
    CACHE_TTL = 6 * 60 * 60

    @classmethod
    @property
    def model(cls) -> MODEL:
        assert cls.MODEL_CLASS is not None, f"MODEL in {cls.__name__} don't set!"
        return cls.MODEL_CLASS

    @classmethod
    @property
    def database(cls) -> Database:
        return cls.model.Meta.database

    @classmethod
    @property
    def cache_prefix(cls) -> str:
        return cls.CUSTOM_CACHE_PREFIX or cls.model.Meta.tablename

    @staticmethod
    def _get_query_hash(query: frozenset[str]) -> str:
        """Return a process-independent digest of the language set.

        Built-in ``hash()`` of strings is salted per process, so every worker
        and restart would otherwise use (and leave behind) its own Redis keys.
        The id sets have no expiry, so those orphans would never be removed.
        """
        return hashlib.sha256(repr(sorted(query)).encode("utf-8")).hexdigest()

    @classmethod
    def get_cache_key(cls, query: frozenset[str]) -> str:
        """Return the Redis key of the id set for the language set ``query``."""
        return f"random_{cls.cache_prefix}_{cls._get_query_hash(query)}"

    @classmethod
    @property
    def objects_id_query(cls) -> str:
        assert (
            cls.GET_OBJECTS_ID_QUERY is not None
        ), f"GET_OBJECTS_ID_QUERY in {cls.__name__} don't set!"
        return cls.GET_OBJECTS_ID_QUERY

    @classmethod
    async def _get_objects_from_db(cls, allowed_langs: frozenset[str]) -> list[int]:
        """Run the configured SQL with ``langs`` bound and return the ``id`` column."""
        rows = await cls.database.fetch_all(
            cls.objects_id_query, {"langs": allowed_langs}
        )
        return [row["id"] for row in rows]

    @classmethod
    async def _get_random_object_from_cache(
        cls, allowed_langs: frozenset[str], redis: aioredis.Redis
    ) -> Optional[int]:
        """Return a random cached id, or None if the cache is stale, empty or failing."""
        try:
            key = cls.get_cache_key(allowed_langs)
            if not await redis.exists(f"{key}_active"):
                return None
            data = await redis.srandmember(key)
        except aioredis.RedisError as e:
            print(e)
            return None
        # SRANDMEMBER returns nil for a missing or empty set, which can happen
        # while the flag is still alive. Previously this crashed on .decode().
        if data is None:
            return None
        # int() accepts both bytes and str replies.
        return int(data)

    @classmethod
    async def _cache_object_ids(
        cls, object_ids: list[int], allowed_langs: frozenset[str], redis: aioredis.Redis
    ) -> bool:
        """Replace the cached id set and mark it fresh. Returns True on success."""
        if not object_ids:
            # SADD needs at least one member. Caching nothing also avoids
            # leaving a fresh flag next to a deleted set.
            return False
        try:
            key = cls.get_cache_key(allowed_langs)
            await redis.delete(key)
            await redis.sadd(key, *object_ids)
            # Set the freshness flag only after the set is populated, so readers
            # never see the flag while the set is missing.
            await redis.set(f"{key}_active", 1, ex=cls.CACHE_TTL)
            return True
        except aioredis.RedisError as e:
            print(e)
            return False

    @classmethod
    async def get_random_id(
        cls,
        allowed_langs: frozenset[str],
        redis: aioredis.Redis,
        background_tasks: BackgroundTasks,
    ) -> int:
        """Return a random object id allowed for ``allowed_langs``.

        Raises:
            IndexError: if the database query returns no ids (from
                ``random.choice`` on an empty list).
        """
        cached_object_id = await cls._get_random_object_from_cache(allowed_langs, redis)
        if cached_object_id is not None:
            return cached_object_id

        object_ids = await cls._get_objects_from_db(allowed_langs)
        # Rebuild the cache after the response is sent. The ids must be passed
        # first, matching _cache_object_ids' signature; they were previously
        # omitted, so the task always failed.
        background_tasks.add_task(cls._cache_object_ids, object_ids, allowed_langs, redis)
        return choice(object_ids)
class BaseFilterService(Generic[MODEL, QUERY], BaseSearchService[MODEL, QUERY]):
    """Search service whose ids come from a plain ORM filter.

    Every key/value pair of the query is passed as a keyword argument to
    ``Model.objects.filter``.
    """

    @classmethod
    async def _get_object_ids(cls, query: QUERY) -> list[int]:
        """Return the flat list of ``id`` values matching the query's filters."""
        id_only_queryset = cls.model.objects.filter(**query).fields("id")
        matching_ids = await id_only_queryset.values_list(flatten=True)
        return matching_ids