mirror of
https://github.com/flibusta-apps/book_library_server.git
synced 2025-12-06 15:15:36 +01:00
Add search result cache
This commit is contained in:
@@ -1,15 +1,17 @@
|
||||
from typing import Optional, Generic, TypeVar, Union
|
||||
from itertools import permutations
|
||||
from databases import Database
|
||||
import json
|
||||
|
||||
from fastapi_pagination.api import resolve_params
|
||||
from fastapi_pagination.bases import AbstractParams, RawParams
|
||||
from app.utils.pagination import Page, CustomPage
|
||||
import aioredis
|
||||
import orjson
|
||||
|
||||
from ormar import Model, QuerySet
|
||||
from sqlalchemy import text, func, select, or_, Table, Column, cast, Text
|
||||
from sqlalchemy.orm import Session
|
||||
from databases import Database
|
||||
|
||||
|
||||
def join_fields(fields):
|
||||
@@ -30,6 +32,7 @@ class TRGMSearchService(Generic[T]):
|
||||
SELECT_RELATED: Optional[Union[list[str], str]] = None
|
||||
PREFETCH_RELATED: Optional[Union[list[str], str]] = None
|
||||
FILTERS = []
|
||||
CACHE_TTL = 5 * 60
|
||||
|
||||
@classmethod
|
||||
def get_params(cls) -> AbstractParams:
|
||||
@@ -78,15 +81,13 @@ class TRGMSearchService(Generic[T]):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_objects(cls, query_data: str) -> tuple[int, list[T]]:
|
||||
async def _get_object_ids(cls, query_data: str) -> list[int]:
|
||||
similarity = cls.get_similarity_subquery(query_data)
|
||||
similarity_filter = cls.get_similarity_filter_subquery(query_data)
|
||||
|
||||
params = cls.get_raw_params()
|
||||
|
||||
session = Session(cls.database.connection())
|
||||
|
||||
q1 = session.query(
|
||||
filtered_objects_query = session.query(
|
||||
cls.table.c.id, similarity
|
||||
).order_by(
|
||||
text('sml DESC')
|
||||
@@ -95,23 +96,57 @@ class TRGMSearchService(Generic[T]):
|
||||
*cls.FILTERS
|
||||
).cte('objs')
|
||||
|
||||
sq = session.query(q1.c.id).limit(params.limit).offset(params.offset).subquery()
|
||||
|
||||
q2 = session.query(
|
||||
func.json_build_object(
|
||||
text("'total'"), func.count(q1.c.id),
|
||||
text("'items'"), select(func.array_to_json(func.array_agg(sq.c.id)))
|
||||
)
|
||||
object_ids_query = session.query(
|
||||
func.array_agg(filtered_objects_query.c.id)
|
||||
).cte()
|
||||
|
||||
print(str(q2))
|
||||
|
||||
row = await cls.database.fetch_one(q2)
|
||||
row = await cls.database.fetch_one(object_ids_query)
|
||||
|
||||
if row is None:
|
||||
raise ValueError('Something is wrong!')
|
||||
|
||||
result = json.loads(row['json_build_object_1'])
|
||||
return row['array_agg_1']
|
||||
|
||||
@classmethod
|
||||
def get_cache_key(cls, query_data: str) -> str:
|
||||
model_class_name = cls.model.__class__.__name__
|
||||
return f"{model_class_name}_{query_data}"
|
||||
|
||||
@classmethod
|
||||
async def get_cached_ids(cls, query_data: str, redis: aioredis.Redis) -> Optional[list[int]]:
|
||||
try:
|
||||
key = cls.get_cache_key(query_data)
|
||||
data = await redis.get(key)
|
||||
|
||||
if data is None:
|
||||
return data
|
||||
|
||||
return orjson.loads(data)
|
||||
except aioredis.RedisError as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def cache_object_ids(cls, query_data: str, object_ids: list[int], redis: aioredis.Redis):
|
||||
try:
|
||||
key = cls.get_cache_key(query_data)
|
||||
await redis.set(key, orjson.dumps(object_ids), ex=cls.CACHE_TTL)
|
||||
except aioredis.RedisError as e:
|
||||
print(e)
|
||||
|
||||
@classmethod
|
||||
async def get_objects(cls, query_data: str, redis: aioredis.Redis) -> tuple[int, list[T]]:
|
||||
params = cls.get_raw_params()
|
||||
|
||||
cached_object_ids = await cls.get_cached_ids(query_data, redis)
|
||||
|
||||
if cached_object_ids is None:
|
||||
object_ids = await cls._get_object_ids(query_data)
|
||||
await cls.cache_object_ids(query_data, object_ids, redis)
|
||||
else:
|
||||
object_ids = cached_object_ids
|
||||
|
||||
limited_object_ids = object_ids[params.offset:params.offset + params.limit]
|
||||
|
||||
queryset: QuerySet[T] = cls.model.objects
|
||||
|
||||
@@ -121,14 +156,13 @@ class TRGMSearchService(Generic[T]):
|
||||
if cls.SELECT_RELATED:
|
||||
queryset = queryset.select_related(cls.SELECT_RELATED)
|
||||
|
||||
return result['total'], await queryset.filter(id__in=result['items']).all()
|
||||
|
||||
return len(object_ids), await queryset.filter(id__in=limited_object_ids).all()
|
||||
|
||||
@classmethod
|
||||
async def get(cls, query: str) -> Page[T]:
|
||||
async def get(cls, query: str, redis: aioredis.Redis) -> Page[T]:
|
||||
params = cls.get_params()
|
||||
|
||||
total, objects = await cls.get_objects(query)
|
||||
total, objects = await cls.get_objects(query, redis)
|
||||
|
||||
return CustomPage.create(
|
||||
items=objects,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, Request, HTTPException, status
|
||||
|
||||
from fastapi_pagination import Params
|
||||
from fastapi_pagination.ext.ormar import paginate
|
||||
@@ -81,5 +81,5 @@ async def get_translated_books(id: int):
|
||||
|
||||
|
||||
@author_router.get("/search/{query}", response_model=CustomPage[Author], dependencies=[Depends(Params)])
|
||||
async def search_authors(query: str):
|
||||
return await AuthorTGRMSearchService.get(query)
|
||||
async def search_authors(query: str, request: Request):
|
||||
return await AuthorTGRMSearchService.get(query, request.app.state.redis)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, Request, HTTPException, status
|
||||
|
||||
from fastapi_pagination import Params
|
||||
from fastapi_pagination.ext.ormar import paginate
|
||||
@@ -92,5 +92,5 @@ async def get_book_annotation(id: int):
|
||||
|
||||
|
||||
@book_router.get("/search/{query}", response_model=CustomPage[Book], dependencies=[Depends(Params)])
|
||||
async def search_books(query: str):
|
||||
return await BookTGRMSearchService.get(query)
|
||||
async def search_books(query: str, request: Request):
|
||||
return await BookTGRMSearchService.get(query, request.app.state.redis)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from fastapi_pagination import Params
|
||||
from fastapi_pagination.ext.ormar import paginate
|
||||
@@ -37,5 +37,5 @@ async def create_sequence(data: CreateSequence):
|
||||
|
||||
|
||||
@sequence_router.get("/search/{query}", response_model=CustomPage[Sequence], dependencies=[Depends(Params)])
|
||||
async def search_sequences(query: str):
|
||||
return await SequenceTGRMSearchService.get(query)
|
||||
async def search_sequences(query: str, request: Request):
|
||||
return await SequenceTGRMSearchService.get(query, request.app.state.redis)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from operator import add
|
||||
from fastapi import FastAPI
|
||||
from fastapi_pagination import add_pagination
|
||||
import aioredis
|
||||
|
||||
from core.db import database
|
||||
from core.config import env_config
|
||||
|
||||
from app.views import routers
|
||||
|
||||
|
||||
@@ -11,6 +13,13 @@ def start_app() -> FastAPI:
|
||||
|
||||
app.state.database = database
|
||||
|
||||
app.state.redis = aioredis.Redis(
|
||||
host=env_config.REDIS_HOST,
|
||||
port=env_config.REDIS_PORT,
|
||||
db=env_config.REDIS_DB,
|
||||
password=env_config.REDIS_PASSWORD,
|
||||
)
|
||||
|
||||
for router in routers:
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseSettings
|
||||
|
||||
|
||||
@@ -10,6 +12,11 @@ class EnvConfig(BaseSettings):
|
||||
POSTGRES_PORT: int
|
||||
POSTGRES_DB: str
|
||||
|
||||
REDIS_HOST: str
|
||||
REDIS_PORT: int
|
||||
REDIS_DB: int
|
||||
REDIS_PASSWORD: Optional[str]
|
||||
|
||||
class Config:
|
||||
env_file = '.env'
|
||||
env_file_encoding = 'utf-8'
|
||||
|
||||
Reference in New Issue
Block a user