Rewrite to rust

This commit is contained in:
2023-08-06 21:27:52 +02:00
parent fd4de89515
commit a7a7dd50a2
32 changed files with 3475 additions and 1920 deletions

View File

@@ -1,22 +0,0 @@
from fastapi import HTTPException, Request, Security, status
from redis.asyncio import Redis
from taskiq import Context, TaskiqDepends
from core.auth import default_security
from core.config import env_config
async def check_token(api_key: str = Security(default_security)):
    """FastAPI dependency that validates the API key of a request.

    Raises:
        HTTPException: 403 when `api_key` differs from `env_config.API_KEY`.
    """
    if api_key == env_config.API_KEY:
        return
    raise HTTPException(
        detail="Wrong api key!", status_code=status.HTTP_403_FORBIDDEN
    )
def get_redis(request: Request) -> Redis:
    """Return the Redis client stored on the application state of `request`."""
    app_state = request.app.state
    return app_state.redis
def get_redis_taskiq(context: Context = TaskiqDepends()) -> Redis:
    """Return the Redis client stored on the taskiq worker state."""
    worker_state = context.state
    return worker_state.redis

View File

@@ -1,10 +0,0 @@
from pydantic import BaseModel
from app.services.task_creator import ObjectType
# Request body accepted by the "create archive task" endpoint.
class CreateTaskData(BaseModel):
# Id of the sequence / author / translator whose books are archived.
object_id: int
# Which kind of object `object_id` refers to.
object_type: ObjectType
# File format of the books to put into the archive.
file_format: str
# Languages passed to the library service as the `allowed_langs` filter.
allowed_langs: list[str]

View File

@@ -1,245 +0,0 @@
import asyncio
from base64 import b64decode
from io import BytesIO
import tempfile
from typing import cast
import uuid
import zipfile
import httpx
from minio import Minio
from redis.asyncio import Redis
from taskiq import TaskiqDepends
from taskiq.task import AsyncTaskiqTask
from transliterate import translit
from app.depends import get_redis_taskiq
from app.services.library_client import LibraryClient
from app.services.task_manager import ObjectType, TaskManager, TaskStatusEnum
from core.config import env_config
from core.taskiq_broker import broker, result_backend
def get_minio_client():
    """Create a MinIO client from the environment configuration (plain HTTP)."""
    client = Minio(
        env_config.MINIO_HOST,
        secure=False,
        secret_key=env_config.MINIO_SECRET_KEY,
        access_key=env_config.MINIO_ACCESS_KEY,
    )
    return client
async def _download_to_tmpfile(
    book_id: int, file_type: str, output: tempfile.SpooledTemporaryFile
) -> tuple[str, int] | None:
    """Stream a book from the cache service into `output`.

    Args:
        book_id: Id of the book to download.
        file_type: Requested file format.
        output: Temporary file that receives the body; rewound to the start
            before returning.

    Returns:
        `(filename, size)` on success, where `filename` is decoded from the
        base64 `X-Filename-B64` response header and `size` is the number of
        bytes written. `None` when the service does not answer with 200 or
        the filename header is missing.
    """
    loop = asyncio.get_running_loop()

    async with httpx.AsyncClient() as client:
        request = client.build_request(
            "get",
            f"{env_config.CACHE_URL}/api/v1/download/{book_id}/{file_type}",
            headers={"Authorization": env_config.CACHE_API_KEY},
        )
        response = await client.send(request, stream=True)

        # A streamed response must be closed explicitly on every path.
        try:
            if response.status_code != 200:
                return None

            filename_b64 = response.headers.get("X-Filename-B64")
            if filename_b64 is None:
                # Without a filename the file cannot be stored; treat as failure.
                return None
            filename = b64decode(filename_b64).decode()

            async for chunk in response.aiter_bytes(2048):
                # File writes are blocking, so keep them off the event loop.
                await loop.run_in_executor(None, output.write, chunk)
        finally:
            await response.aclose()

    await loop.run_in_executor(None, output.flush)

    # Seek to the end to learn the size, then rewind for the consumer.
    await loop.run_in_executor(None, output.seek, 0, 2)
    size = await loop.run_in_executor(None, output.tell)
    await loop.run_in_executor(None, output.seek, 0)

    return filename, size
async def download_file_to_file(link: str, output: BytesIO) -> bool:
    """Stream the body of `link` into `output`.

    Args:
        link: URL to download.
        output: Writable binary file object that receives the body.

    Returns:
        `True` when the whole body was written, `False` when the server did
        not answer with 200.
    """
    async with httpx.AsyncClient() as client:
        # NOTE(review): the cache API key is sent to whatever `link` points at
        # (callers pass presigned storage URLs) - confirm this is intended.
        request = client.build_request(
            "get", link, headers={"Authorization": env_config.CACHE_API_KEY}
        )
        response = await client.send(request, stream=True)

        # A streamed response must be closed explicitly on every path.
        try:
            if response.status_code != 200:
                return False

            loop = asyncio.get_running_loop()
            async for chunk in response.aiter_bytes(2048):
                # File writes are blocking, so keep them off the event loop.
                await loop.run_in_executor(None, output.write, chunk)
        finally:
            await response.aclose()

    return True
@broker.task()
async def download(
    task_id: str, book_id: int, file_type: str, prev_task_id: str | None = None
) -> str | None:
    """Download one book and store it in the MinIO bucket.

    When `prev_task_id` is given, waits until that task is ready first.
    `check_subtasks` is always scheduled for `task_id` afterwards, whether
    the download succeeded or not.

    Returns:
        The stored object name, or `None` when the book could not be
        downloaded.
    """
    if prev_task_id:
        previous = AsyncTaskiqTask(prev_task_id, result_backend)
        while True:
            if await previous.is_ready():
                break
            await asyncio.sleep(0.1)

    try:
        with tempfile.SpooledTemporaryFile() as spool:
            downloaded = await _download_to_tmpfile(book_id, file_type, spool)
            if downloaded is None:
                return None
            object_name, object_size = downloaded

            storage = get_minio_client()
            # The MinIO client is blocking, so run the upload in the executor.
            await asyncio.get_event_loop().run_in_executor(
                None,
                storage.put_object,
                env_config.MINIO_BUCKET,
                object_name,
                spool,
                object_size,
            )
            return object_name
    finally:
        await check_subtasks.kiq(task_id)
async def _check_subtasks(subtasks: list[str]) -> bool:
    """Return `True` when every subtask reports `.is_ready()`.

    Stops at the first subtask that is not ready yet.
    """
    handles = [AsyncTaskiqTask(subtask_id, result_backend) for subtask_id in subtasks]

    for handle in handles:
        if not await handle.is_ready():
            return False
    return True
@broker.task()
async def check_subtasks(task_id: str, redis: Redis = TaskiqDepends(get_redis_taskiq)):
    """Schedule `create_archive` once all subtasks of `task_id` are ready.

    Returns `False` when the task is not found in Redis, otherwise `None`.
    """
    parent = await TaskManager.get_task(redis, uuid.UUID(task_id))
    if parent is None:
        return False

    await asyncio.sleep(1)

    if await _check_subtasks(parent.subtasks):
        await create_archive.kiq(task_id)
# Taskiq task: collects the files produced by the download subtasks of
# `task_id` into one zip archive, uploads the archive to the MinIO bucket
# and stores the resulting filename and presigned link on the task.
@broker.task()
async def create_archive(task_id: str, redis: Redis = TaskiqDepends(get_redis_taskiq)):
task = await TaskManager.get_task(redis, uuid.UUID(task_id))
assert task
# Resolve a human-readable name for the archived object.
match task.object_type:
case ObjectType.SEQUENCE:
item = await LibraryClient.get_sequence(task.object_id)
assert item
name = item.name
case ObjectType.AUTHOR | ObjectType.TRANSLATOR:
# Translators are looked up through the authors endpoint as well.
item = await LibraryClient.get_author(task.object_id)
assert item
names = [item.first_name, item.last_name, item.middle_name]
# Empty / missing name parts are skipped.
name = "_".join([i for i in names if i])
# TODO: test with `uk` and `be`
tr_name = translit(name, "ru", reversed=True, strict=True)
archive_filename = f"{item.id}_{tr_name}.zip"
assert item
# Publish the status change before the archiving work starts.
task.status = TaskStatusEnum.ARCHIVING
await TaskManager.save_task(redis, task)
minio_client = get_minio_client()
loop = asyncio.get_running_loop()
with tempfile.SpooledTemporaryFile() as temp_zipfile:
# NOTE(review): allowZip64=False makes zipfile raise for archives that need
# the ZIP64 extensions - confirm large collections are not expected.
zip_file = zipfile.ZipFile(
temp_zipfile,
mode="w",
compression=zipfile.ZIP_DEFLATED,
allowZip64=False,
compresslevel=9,
)
for subtask_id in task.subtasks:
subtask = AsyncTaskiqTask(subtask_id, result_backend)
result = await subtask.get_result()
# Failed subtasks and subtasks without a file are left out of the archive.
if result.is_err:
continue
filename: str | None = result.return_value
if filename is None:
continue
# The MinIO client is blocking, so its calls run in the default executor.
book_file_link = await loop.run_in_executor(
None,
minio_client.get_presigned_url,
"GET",
env_config.MINIO_BUCKET,
filename,
)
# Stream the stored book straight into a new archive member.
with zip_file.open(filename, "w") as internal_zip_file:
await download_file_to_file(
book_file_link, cast(BytesIO, internal_zip_file)
)
# The intermediate per-book object is removed from the bucket.
await loop.run_in_executor(
None, minio_client.remove_object, env_config.MINIO_BUCKET, filename
)
zip_file.close()
await loop.run_in_executor(None, temp_zipfile.flush)
# Seek to the end to learn the archive size, then rewind for the upload.
await loop.run_in_executor(None, temp_zipfile.seek, 0, 2)
size = await loop.run_in_executor(None, temp_zipfile.tell)
await loop.run_in_executor(None, temp_zipfile.seek, 0)
await loop.run_in_executor(
None,
minio_client.put_object,
env_config.MINIO_BUCKET,
archive_filename,
temp_zipfile,
size,
)
task.status = TaskStatusEnum.COMPLETE
task.result_filename = archive_filename
task.result_link = await loop.run_in_executor(
None,
minio_client.get_presigned_url,
"GET",
env_config.MINIO_BUCKET,
archive_filename,
)
await TaskManager.save_task(redis, task)

View File

@@ -1,132 +0,0 @@
from typing import Generic, TypeVar
import httpx
from pydantic import BaseModel
from core.config import env_config
# Book entry as returned by the "books of a sequence" endpoint.
class SequenceBook(BaseModel):
id: int
# File formats the book can be downloaded in.
available_types: list[str]
# Book entry as returned by the "books of an author" endpoint.
class AuthorBook(BaseModel):
id: int
# File formats the book can be downloaded in.
available_types: list[str]
# Book entry as returned by the "books of a translator" endpoint.
class TranslatorBook(BaseModel):
id: int
# File formats the book can be downloaded in.
available_types: list[str]
# Type of the items carried by a `Page`.
Item = TypeVar("Item", bound=BaseModel)
# One page of a paginated library-service response.
class Page(BaseModel, Generic[Item]):
items: list[Item]
total: int
page: int
size: int
# Total number of pages; used by callers to know when to stop paging.
pages: int
# Book sequence as returned by the library service.
class Sequence(BaseModel):
id: int
name: str
# Author (also used for translators) as returned by the library service.
class Author(BaseModel):
id: int
first_name: str
last_name: str
# Optional; absent values default to None.
middle_name: str | None = None
class LibraryClient:
    """Async client for the library service REST API.

    Every public method returns the parsed response model, or `None` when the
    service answers with a status code other than 200.
    """

    @staticmethod
    async def _get(path: str, params: dict | None = None) -> str | None:
        """GET `path` on the library service; return the body text or `None`."""
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{env_config.LIBRARY_URL}{path}",
                params=params,
                headers={"Authorization": env_config.LIBRARY_API_KEY},
            )

            if response.status_code != 200:
                return None

            return response.text

    @staticmethod
    async def _get_books_page(path: str, page_model, allowed_langs: list[str], page: int):
        """Fetch one page of non-deleted books from `path`, parsed as `page_model`."""
        body = await LibraryClient._get(
            path,
            params={
                "page": page,
                "allowed_langs": allowed_langs,
                "is_deleted": "false",
            },
        )
        if body is None:
            return None
        return page_model.model_validate_json(body)

    @staticmethod
    async def get_sequence_books(
        sequence_id: int, allowed_langs: list[str], page: int = 1
    ) -> Page[SequenceBook] | None:
        """Return one page of the books that belong to a sequence."""
        return await LibraryClient._get_books_page(
            f"/api/v1/sequences/{sequence_id}/books",
            Page[SequenceBook],
            allowed_langs,
            page,
        )

    @staticmethod
    async def get_author_books(
        author_id: int, allowed_langs: list[str], page: int = 1
    ) -> Page[AuthorBook] | None:
        """Return one page of the books written by an author."""
        return await LibraryClient._get_books_page(
            f"/api/v1/authors/{author_id}/books",
            Page[AuthorBook],
            allowed_langs,
            page,
        )

    @staticmethod
    async def get_translator_books(
        translator_id: int, allowed_langs: list[str], page: int = 1
    ) -> Page[TranslatorBook] | None:
        """Return one page of the books translated by a translator."""
        return await LibraryClient._get_books_page(
            f"/api/v1/translators/{translator_id}/books",
            Page[TranslatorBook],
            allowed_langs,
            page,
        )

    @staticmethod
    async def get_sequence(sequence_id: int) -> Sequence | None:
        """Return the sequence with the given id."""
        body = await LibraryClient._get(f"/api/v1/sequences/{sequence_id}")
        if body is None:
            return None
        return Sequence.model_validate_json(body)

    @staticmethod
    async def get_author(author_id: int) -> Author | None:
        """Return the author with the given id."""
        body = await LibraryClient._get(f"/api/v1/authors/{author_id}")
        if body is None:
            return None
        return Author.model_validate_json(body)

View File

@@ -1,105 +0,0 @@
import uuid
from pydantic import BaseModel
from redis.asyncio import Redis
from app.services.downloader import download
from app.services.library_client import LibraryClient, SequenceBook
from app.services.task_manager import ObjectType, Task, TaskManager
# Error value returned (not raised) by TaskCreator when a task cannot be created.
class CreateTaskError(BaseModel):
# Human-readable reason.
message: str
class TaskCreator:
    """Creates archive tasks: lists the books, schedules downloads, saves the task."""

    @classmethod
    async def _get_books(
        cls, object_id: int, object_type: ObjectType, allowed_langs: list[str]
    ) -> list[SequenceBook] | CreateTaskError:
        """Collect every page of books of the object, or an error value."""
        if object_type == ObjectType.SEQUENCE:
            fetch_page = LibraryClient.get_sequence_books
        elif object_type == ObjectType.AUTHOR:
            fetch_page = LibraryClient.get_author_books
        elif object_type == ObjectType.TRANSLATOR:
            fetch_page = LibraryClient.get_translator_books

        collected = []
        page_number, total_pages = 1, 1

        while page_number <= total_pages:
            fetched = await fetch_page(object_id, allowed_langs, page=page_number)
            if fetched is None:
                return CreateTaskError(message="Can't get books!")

            collected.extend(fetched.items)
            # The real page count is only known after the first response.
            total_pages = fetched.pages
            page_number += 1

        if not len(collected):
            return CreateTaskError(message="No books!")

        return collected

    @classmethod
    async def _create_subtasks(
        cls,
        task_id: uuid.UUID,
        object_id: int,
        object_type: ObjectType,
        file_format: str,
        allowed_langs: list[str],
    ) -> list[str] | CreateTaskError:
        """Schedule one chained `download` task per book available in `file_format`."""
        books = await cls._get_books(object_id, object_type, allowed_langs)
        if isinstance(books, CreateTaskError):
            return books

        scheduled_ids: list[str] = []
        # Each download receives the id of the one scheduled before it.
        last_scheduled_id = None

        for book in books:
            if file_format in book.available_types:
                scheduled = await download.kiq(
                    str(task_id), book.id, file_format, prev_task_id=last_scheduled_id
                )
                last_scheduled_id = scheduled.task_id
                scheduled_ids.append(scheduled.task_id)

        if not len(scheduled_ids):
            return CreateTaskError(message="No books to archive!")

        return scheduled_ids

    @classmethod
    async def create_task(
        cls,
        redis: Redis,
        object_id: int,
        object_type: ObjectType,
        file_format: str,
        allowed_langs: list[str],
    ) -> Task | CreateTaskError:
        """Create, schedule and persist a new archive task.

        Returns the saved `Task`, or a `CreateTaskError` describing why the
        task could not be created.
        """
        new_task_id = uuid.uuid4()

        subtask_ids = await cls._create_subtasks(
            new_task_id, object_id, object_type, file_format, allowed_langs
        )
        if isinstance(subtask_ids, CreateTaskError):
            return subtask_ids

        new_task = Task(
            id=new_task_id,
            object_id=object_id,
            object_type=object_type,
            subtasks=subtask_ids,
        )

        if await TaskManager.save_task(redis, new_task):
            return new_task
        return CreateTaskError(message="Save task error")

View File

@@ -1,58 +0,0 @@
import enum
import uuid
from pydantic import BaseModel
from redis.asyncio import Redis, RedisError
# Lifecycle states of an archive task.
class TaskStatusEnum(enum.StrEnum):
IN_PROGRESS = "in_progress"
ARCHIVING = "archiving"
COMPLETE = "complete"
# Kinds of library objects an archive can be built for.
class ObjectType(enum.StrEnum):
SEQUENCE = "sequence"
AUTHOR = "author"
TRANSLATOR = "translator"
# Archive task as stored in Redis and returned by the API.
class Task(BaseModel):
id: uuid.UUID
object_id: int
object_type: ObjectType
# Ids of the taskiq download tasks that belong to this task.
subtasks: list[str]
status: TaskStatusEnum = TaskStatusEnum.IN_PROGRESS
# Filled in once the archive has been uploaded.
result_filename: str | None = None
result_link: str | None = None
class TaskManager:
    """Stores `Task` objects in Redis as JSON under `at_<task id>` keys."""

    @classmethod
    def _get_key(cls, task_id: uuid.UUID) -> str:
        """Return the Redis key used for `task_id`."""
        return "at_" + str(task_id)

    @classmethod
    async def save_task(cls, redis: Redis, task: Task) -> bool:
        """Save `task` with a one hour expiry; `False` on a Redis error."""
        redis_key = cls._get_key(task.id)

        try:
            await redis.set(redis_key, task.model_dump_json(), ex=60 * 60)
        except RedisError:
            return False
        return True

    @classmethod
    async def get_task(cls, redis: Redis, task_id: uuid.UUID) -> Task | None:
        """Load a task; `None` when it is missing or Redis fails."""
        try:
            raw = await redis.get(cls._get_key(task_id))
        except RedisError:
            return None

        if raw is None:
            return None
        return Task.model_validate_json(raw)

View File

@@ -1,44 +0,0 @@
from typing import Annotated
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from redis.asyncio import Redis
from app.depends import check_token, get_redis
from app.serializers import CreateTaskData
from app.services.task_creator import CreateTaskError, TaskCreator
from app.services.task_manager import TaskManager
# All routes live under /api and require a valid API key (see check_token).
router = APIRouter(prefix="/api", dependencies=[Depends(check_token)])
@router.post("/")
async def create_archive_task(
    redis: Annotated[Redis, Depends(get_redis)], data: CreateTaskData
):
    """Create an archive task for the object described by `data`.

    Returns:
        The created task.

    Raises:
        HTTPException: 400 with the error message when the task cannot be
            created.
    """
    task = await TaskCreator.create_task(
        redis=redis,
        object_id=data.object_id,
        object_type=data.object_type,
        file_format=data.file_format,
        allowed_langs=data.allowed_langs,
    )

    if isinstance(task, CreateTaskError):
        # `detail` is rendered into a JSON body, so pass the plain message
        # rather than the pydantic model, which is not JSON serializable.
        raise HTTPException(status.HTTP_400_BAD_REQUEST, task.message)

    return task
@router.get("/check_archive/{task_id}")
async def check_archive_task_status(
    redis: Annotated[Redis, Depends(get_redis)], task_id: uuid.UUID
):
    """Return the stored task for `task_id`; respond with 404 when there is none."""
    found = await TaskManager.get_task(redis, task_id)
    if found is not None:
        return found
    raise HTTPException(status.HTTP_404_NOT_FOUND)

48
src/config.rs Normal file
View File

@@ -0,0 +1,48 @@
use once_cell::sync::Lazy;
/// Reads the environment variable `env`.
///
/// # Panics
///
/// Panics with `Cannot get the <name> env variable` when the variable cannot
/// be read.
fn get_env(env: &'static str) -> String {
    match std::env::var(env) {
        Ok(value) => value,
        Err(_) => panic!("Cannot get the {} env variable", env),
    }
}
/// Application settings, each read from the environment variable with the
/// upper-case name of the field (see [`Config::load`]).
pub struct Config {
    /// Key clients must send in the `Authorization` header.
    pub api_key: String,
    pub minio_host: String,
    pub minio_bucket: String,
    pub minio_access_key: String,
    pub minio_secret_key: String,
    /// Credentials and base URL of the library service.
    pub library_api_key: String,
    pub library_url: String,
    /// Credentials and base URL of the cache service.
    pub cache_api_key: String,
    pub cache_url: String,
    // pub sentry_dsn: String
}
impl Config {
    /// Builds a [`Config`] by reading every setting from the environment.
    ///
    /// # Panics
    ///
    /// Panics on the first variable that cannot be read.
    pub fn load() -> Config {
        let api_key = get_env("API_KEY");

        let minio_host = get_env("MINIO_HOST");
        let minio_bucket = get_env("MINIO_BUCKET");
        let minio_access_key = get_env("MINIO_ACCESS_KEY");
        let minio_secret_key = get_env("MINIO_SECRET_KEY");

        let library_api_key = get_env("LIBRARY_API_KEY");
        let library_url = get_env("LIBRARY_URL");

        let cache_api_key = get_env("CACHE_API_KEY");
        let cache_url = get_env("CACHE_URL");

        // sentry_dsn: get_env("SENTRY_DSN")

        Config {
            api_key,
            minio_host,
            minio_bucket,
            minio_access_key,
            minio_secret_key,
            library_api_key,
            library_url,
            cache_api_key,
            cache_url,
        }
    }
}
/// Process-wide configuration, loaded from the environment on first access.
pub static CONFIG: Lazy<Config> = Lazy::new(Config::load);

View File

@@ -1,31 +0,0 @@
from fastapi import FastAPI
from fastapi.responses import ORJSONResponse
from redis.asyncio import Redis
from app.views import router
from core.config import REDIS_URL
from core.taskiq_broker import broker
# Builds the FastAPI application: ORJSON responses, a shared Redis client on
# `app.state.redis`, the API router, and startup/shutdown hooks for the broker.
def start_app() -> FastAPI:
app = FastAPI(default_response_class=ORJSONResponse)
redis = Redis.from_url(REDIS_URL)
# Exposed to request handlers through the `get_redis` dependency.
app.state.redis = redis
app.include_router(router)
@app.on_event("startup")
async def app_startup():
# The broker is only started here when this is not a worker process.
if not broker.is_worker_process:
await broker.startup()
@app.on_event("shutdown")
async def app_shutdown():
if not broker.is_worker_process:
await broker.shutdown()
await redis.close()
return app

View File

@@ -1,4 +0,0 @@
from fastapi.security import APIKeyHeader
# The API key is read from the `Authorization` request header.
default_security = APIKeyHeader(name="Authorization")

View File

@@ -1,30 +0,0 @@
from pydantic_settings import BaseSettings
# Application settings; fields without a default are required.
class Config(BaseSettings):
# Key clients must send in the Authorization header.
API_KEY: str
REDIS_HOST: str
REDIS_PORT: int
REDIS_DB: int
REDIS_PASSWORD: str | None = None
MINIO_HOST: str
MINIO_BUCKET: str
MINIO_ACCESS_KEY: str
MINIO_SECRET_KEY: str
# Library service: book / author / sequence metadata.
LIBRARY_API_KEY: str
LIBRARY_URL: str
# Cache service: book file downloads.
CACHE_API_KEY: str
CACHE_URL: str
# Optional; Sentry is only initialised when set.
SENTRY_DSN: str | None = None
env_config = Config()  # type: ignore

# REDIS_PASSWORD is optional. When it is set it has to be part of the URL,
# otherwise clients created from REDIS_URL connect without authenticating.
_redis_auth = ""
if env_config.REDIS_PASSWORD:
    from urllib.parse import quote

    # Percent-encode so characters such as "@" or "/" cannot break the URL.
    _redis_auth = f":{quote(env_config.REDIS_PASSWORD, safe='')}@"

REDIS_URL = (
    f"redis://{_redis_auth}"
    f"{env_config.REDIS_HOST}:{env_config.REDIS_PORT}/{env_config.REDIS_DB}"
)

View File

@@ -1,20 +0,0 @@
from redis.asyncio import Redis
from taskiq import TaskiqEvents, TaskiqState
from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend
from core.config import REDIS_URL
# Task results are kept in Redis for five minutes.
result_backend = RedisAsyncResultBackend(redis_url=REDIS_URL, result_ex_time=5 * 60)
# Redis list based broker that stores results in `result_backend`.
broker = ListQueueBroker(url=REDIS_URL).with_result_backend(result_backend)
@broker.on_event(TaskiqEvents.WORKER_STARTUP)
async def startup(state: TaskiqState) -> None:
    """Create the worker's Redis client and keep it on the worker state."""
    redis_client = Redis.from_url(REDIS_URL)
    state.redis = redis_client
@broker.on_event(TaskiqEvents.WORKER_SHUTDOWN)
async def shutdown(state: TaskiqState) -> None:
    """Close the Redis client stored on the worker state."""
    redis_client = state.redis
    await redis_client.close()

View File

@@ -1,11 +0,0 @@
import sentry_sdk
from core.app import start_app
from core.config import env_config
# Sentry reporting is enabled only when a DSN is configured.
if env_config.SENTRY_DSN:
sentry_sdk.init(dsn=env_config.SENTRY_DSN)
# Module-level application object.
app = start_app()

29
src/main.rs Normal file
View File

@@ -0,0 +1,29 @@
pub mod views;
pub mod config;
pub mod services;
pub mod structures;
use std::net::SocketAddr;
use tracing::info;
use crate::views::get_router;
/// Entry point: sets up logging and serves the router on `0.0.0.0:8080`.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt().with_target(false).compact().init();

    let app = get_router().await;
    let addr = SocketAddr::from(([0, 0, 0, 0], 8080));

    info!("Start webserver...");
    let server = axum::Server::bind(&addr).serve(app.into_make_service());
    server.await.unwrap();

    info!("Webserver shutdown...")
}

View File

View File

@@ -0,0 +1,125 @@
use serde::{de::DeserializeOwned, Deserialize};
use smallvec::SmallVec;
use smartstring::alias::String as SmartString;
use tracing::log;
use crate::config;
/// Value sent as the `size` query parameter of the paginated book requests.
const PAGE_SIZE: &str = "50";
/// Turns every language into an `("allowed_langs", lang)` query parameter,
/// keeping the input order.
fn get_allowed_langs_params(allowed_langs: SmallVec<[SmartString; 3]>) -> Vec<(&'static str, SmartString)> {
    let mut params = Vec::with_capacity(allowed_langs.len());
    for lang in allowed_langs {
        params.push(("allowed_langs", lang));
    }
    params
}
/// Sends an authorized GET request to the library service and deserializes
/// the JSON body into `T`.
///
/// `url` is appended to the configured library base URL; `params` become the
/// query string.
///
/// # Errors
///
/// Returns an error when the request fails, the status is an HTTP error, or
/// the body cannot be deserialized (the latter is also logged).
async fn _make_request<T>(
    url: &str,
    params: Vec<(&str, SmartString)>,
) -> Result<T, Box<dyn std::error::Error + Send + Sync>>
where
    T: DeserializeOwned,
{
    let full_url = format!("{}{}", &config::CONFIG.library_url, url);

    let request = reqwest::Client::new()
        .get(full_url)
        .query(&params)
        .header("Authorization", &config::CONFIG.library_api_key);

    let response = request.send().await?.error_for_status()?;

    response.json::<T>().await.map_err(|err| {
        log::error!("Failed serialization: url={:?} err={:?}", url, err);
        Box::new(err) as Box<dyn std::error::Error + Send + Sync>
    })
}
/// Book entry of a paginated "books of ..." response.
#[derive(Deserialize, Debug, Clone)]
pub struct Book {
    pub id: u64,
    /// File formats the book can be downloaded in.
    pub available_types: SmallVec<[String; 4]>
}
/// One page of a paginated library-service response.
#[derive(Deserialize, Debug, Clone)]
pub struct Page<T> {
    pub items: Vec<T>,
    pub total: u32,
    pub page: u32,
    pub size: u32,
    /// Total number of pages; callers use it to know when to stop paging.
    pub pages: u32,
}
/// Book sequence as returned by the library service.
#[derive(Deserialize, Debug, Clone)]
pub struct Sequence {
    pub id: u32,
    pub name: String
}
/// Author as returned by the library service.
#[derive(Deserialize, Debug, Clone)]
pub struct Author {
    pub id: u32,
    pub first_name: String,
    pub last_name: String,
    /// Optional part of the name.
    pub middle_name: Option<String>
}
pub async fn get_author_books(
id: u32,
page: u32,
allowed_langs: SmallVec<[SmartString; 3]>,
) -> Result<Page<Book>, Box<dyn std::error::Error + Send + Sync>> {
let mut params = get_allowed_langs_params(allowed_langs);
params.push(("page", page.to_string().into()));
params.push(("size", PAGE_SIZE.to_string().into()));
_make_request(format!("/api/v1/authors/{id}/books").as_str(), params).await
}
pub async fn get_translator_books(
id: u32,
page: u32,
allowed_langs: SmallVec<[SmartString; 3]>,
) -> Result<Page<Book>, Box<dyn std::error::Error + Send + Sync>> {
let mut params = get_allowed_langs_params(allowed_langs);
params.push(("page", page.to_string().into()));
params.push(("size", PAGE_SIZE.to_string().into()));
_make_request(format!("/api/v1/translators/{id}/books").as_str(), params).await
}
pub async fn get_sequence_books(
id: u32,
page: u32,
allowed_langs: SmallVec<[SmartString; 3]>,
) -> Result<Page<Book>, Box<dyn std::error::Error + Send + Sync>> {
let mut params = get_allowed_langs_params(allowed_langs);
params.push(("page", page.to_string().into()));
params.push(("size", PAGE_SIZE.to_string().into()));
_make_request(format!("/api/v1/sequences/{id}/books").as_str(), params).await
}
/// Fetches the author with the given `id` from the library service.
///
/// # Errors
///
/// Returns an error when the request or the deserialization fails.
pub async fn get_author(id: u32) -> Result<Author, Box<dyn std::error::Error + Send + Sync>> {
    let url = format!("/api/v1/authors/{id}");
    _make_request(url.as_str(), Vec::new()).await
}
/// Fetches the sequence with the given `id` from the library service.
///
/// # Errors
///
/// Returns an error when the request or the deserialization fails.
pub async fn get_sequence(id: u32) -> Result<Sequence, Box<dyn std::error::Error + Send + Sync>> {
    let url = format!("/api/v1/sequences/{id}");
    _make_request(url.as_str(), Vec::new()).await
}

4
src/services/mod.rs Normal file
View File

@@ -0,0 +1,4 @@
pub mod task_creator;
pub mod library_client;
pub mod utils;
pub mod downloader;

View File

@@ -0,0 +1,337 @@
use std::{fmt, io::{Seek, Read}};
use base64::{engine::general_purpose, Engine};
use bytes::Bytes;
use minio_rsc::{provider::StaticProvider, Minio, types::args::{ObjectArgs, PresignedArgs}, errors::MinioError};
use reqwest::StatusCode;
use smallvec::SmallVec;
use smartstring::alias::String as SmartString;
use tempfile::SpooledTempFile;
use translit::{Transliterator, gost779b_ru, CharsMapping};
use zip::write::FileOptions;
use async_stream::stream;
use crate::{structures::{CreateTask, Task, ObjectType}, config, views::TASK_RESULTS};
use super::{library_client::{Book, get_sequence_books, get_author_books, get_translator_books, Page, get_sequence, get_author}, utils::response_to_tempfile};
/// Computes the cache key of a task request: the hex MD5 digest of the
/// request serialized as JSON.
///
/// The languages are sorted first, so requests that differ only in the order
/// of `allowed_langs` produce the same key.
///
/// # Panics
///
/// Panics when the request cannot be serialized to JSON.
pub fn get_key(
    input_data: CreateTask
) -> String {
    // The argument is already owned, so it can be normalised in place.
    let mut data = input_data;
    data.allowed_langs.sort();

    let data_string = serde_json::to_string(&data)
        .expect("CreateTask must always be serializable to JSON");

    format!("{:x}", md5::compute(data_string))
}
/// Collects the books of every page returned by `books_getter` for
/// `object_id`.
///
/// The first page is always requested; its `pages` value decides how many
/// further pages are fetched.
///
/// # Errors
///
/// Returns the first error produced by `books_getter`.
pub async fn get_books<Fut>(
    object_id: u32,
    allowed_langs: SmallVec<[SmartString; 3]>,
    books_getter: fn(id: u32, page: u32, allowed_langs: SmallVec<[SmartString; 3]>) -> Fut
) -> Result<Vec<Book>, Box<dyn std::error::Error + Send + Sync>>
where
    Fut: std::future::Future<Output = Result<Page<Book>, Box<dyn std::error::Error + Send + Sync>>>,
{
    let first_page = books_getter(object_id, 1, allowed_langs.clone()).await?;

    let page_count = first_page.pages;
    let mut result: Vec<Book> = first_page.items;

    for page_number in 2..=page_count {
        let page = books_getter(object_id, page_number, allowed_langs.clone()).await?;
        result.extend(page.items);
    }

    Ok(result)
}
/// Error returned when the cache service answers a download request with a
/// status other than `200 OK`.
#[derive(Debug, Clone)]
struct DownloadError {
    /// Status code of the rejected response.
    status_code: StatusCode,
}

impl std::error::Error for DownloadError {}

impl fmt::Display for DownloadError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("Status code is {0}", self.status_code))
    }
}
/// Downloads a book from the cache service into a temporary file.
///
/// Returns `Ok(Some((file, filename)))` on success, where `filename` is
/// decoded from the base64 `x-filename-b64` response header, and `Ok(None)`
/// when the body could not be saved to the temporary file.
///
/// # Errors
///
/// Returns an error when the request fails, the status is not `200 OK`, or
/// the filename header is missing, not valid base64, or not valid UTF-8.
pub async fn download(
    book_id: u64,
    file_type: String,
) -> Result<Option<(SpooledTempFile, String)>, Box<dyn std::error::Error + Send + Sync>> {
    let mut response = reqwest::Client::new()
        .get(format!(
            "{}/api/v1/download/{book_id}/{file_type}",
            &config::CONFIG.cache_url
        ))
        .header("Authorization", &config::CONFIG.cache_api_key)
        .send()
        .await?
        .error_for_status()?;

    // `error_for_status` only rejects 4xx/5xx; anything but 200 is unusable here.
    if response.status() != StatusCode::OK {
        return Err(Box::new(DownloadError {
            status_code: response.status(),
        }));
    };

    // The header comes from another service: report problems as errors
    // instead of panicking inside a spawned task.
    let encoded_filename = response
        .headers()
        .get("x-filename-b64")
        .ok_or("Response has no x-filename-b64 header")?;
    let filename = String::from_utf8(general_purpose::STANDARD.decode(encoded_filename)?)?;

    let output_file = match response_to_tempfile(&mut response).await {
        Some(v) => v.0,
        None => return Ok(None),
    };

    Ok(Some((output_file, filename)))
}
/// Wraps a reader into a stream of byte chunks of at most 2048 bytes.
///
/// The stream ends at end of input and also, without yielding an error, on
/// the first read error.
fn get_stream(mut temp_file: Box<dyn Read + Send>) -> impl futures_core::Stream<Item = Result<Bytes, MinioError>> {
    stream! {
        let mut buf = [0u8; 2048];

        while let Ok(count) = temp_file.read(&mut buf) {
            if count == 0 {
                break;
            }

            yield Ok(Bytes::copy_from_slice(&buf[..count]));
        }
    }
}
pub async fn create_archive_task(key: String, data: CreateTask) {
let books = match data.object_type {
ObjectType::Sequence => get_books(data.object_id, data.allowed_langs, get_sequence_books).await,
ObjectType::Author => get_books(data.object_id, data.allowed_langs, get_author_books).await,
ObjectType::Translator => get_books(data.object_id, data.allowed_langs, get_translator_books).await,
};
let books = match books {
Ok(v) => v,
Err(err) => {
return; // log error and task error
},
};
let books: Vec<_> = books
.iter()
.filter(|book| book.available_types.contains(&data.file_format))
.collect();
if books.is_empty() {
return; // log error and task error
}
let output_file = tempfile::spooled_tempfile(5 * 1024 * 1024);
let mut archive = zip::ZipWriter::new(output_file);
let options = FileOptions::default()
.compression_level(Some(9))
.compression_method(zip::CompressionMethod::Deflated)
.unix_permissions(0o755);
for book in books {
let (mut tmp_file, filename) = match download(book.id, data.file_format.clone()).await {
Ok(v) => {
match v {
Some(v) => v,
None => {
return; // log error and task error
},
}
},
Err(err) => {
return; // log error and task error
},
};
match archive.start_file(filename, options) {
Ok(_) => (),
Err(_) => return, // log error and task error
};
match std::io::copy(&mut tmp_file, &mut archive) {
Ok(_) => (),
Err(_) => return, // log error and task error
};
}
let mut archive_result = match archive.finish() {
Ok(v) => v,
Err(err) => return, // log error and task error
};
archive_result.rewind().unwrap();
let result_filename = match data.object_type {
ObjectType::Sequence => {
match get_sequence(data.object_id).await {
Ok(v) => v.name,
Err(err) => {
println!("{}", err);
return; // log error and task error
},
}
},
ObjectType::Author | ObjectType::Translator => {
match get_author(data.object_id).await {
Ok(v) => {
vec![v.first_name, v.last_name, v.middle_name.unwrap_or("".to_string())]
.into_iter()
.filter(|v| !v.is_empty())
.collect::<Vec<String>>()
.join("_")
},
Err(err) => {
println!("{}", err);
return; // log error and task error
},
}
},
};
let final_filename = {
let transliterator = Transliterator::new(gost779b_ru());
let mut filename_without_type = transliterator.convert(&result_filename, false);
"(),….!\"?»«':".get(..).into_iter().for_each(|char| {
filename_without_type = filename_without_type.replace(char, "");
});
let replace_char_map: CharsMapping = [
("", "-"),
("/", "_"),
("", "N"),
(" ", "_"),
("", "-"),
("á", "a"),
(" ", "_"),
("'", ""),
("`", ""),
("[", ""),
("]", ""),
("\"", ""),
].to_vec();
let replace_transliterator = Transliterator::new(replace_char_map);
let normal_filename = replace_transliterator.convert(&filename_without_type, false);
let normal_filename = normal_filename.replace(|c: char| !c.is_ascii(), "");
let right_part = format!(".zip");
let normal_filename_slice = std::cmp::min(64 - right_part.len() - 1, normal_filename.len() - 1);
let left_part = if normal_filename_slice == normal_filename.len() - 1 {
&normal_filename
} else {
normal_filename.get(..normal_filename_slice).unwrap_or_else(|| panic!("Can't slice left part: {:?} {:?}", normal_filename, normal_filename_slice))
};
format!("{left_part}{right_part}")
};
let provider = StaticProvider::new(
&config::CONFIG.minio_access_key,
&config::CONFIG.minio_secret_key,
None
);
let minio = Minio::builder()
.host(&config::CONFIG.minio_host)
.provider(provider)
.secure(false)
.build()
.unwrap();
let is_bucket_exist = match minio.bucket_exists(&config::CONFIG.minio_bucket).await {
Ok(v) => v,
Err(err) => {
println!("{}", err);
return; // log error and task error
}, // log error and task error
};
if !is_bucket_exist {
minio.make_bucket(&config::CONFIG.minio_bucket, false).await;
}
let data_stream = get_stream(Box::new(archive_result));
if let Err(err) = minio.put_object_stream(
ObjectArgs::new(&config::CONFIG.minio_bucket, final_filename.clone()),
Box::pin(data_stream)
).await {
println!("{}", err);
return; // log error and task error
}
let link = match minio.presigned_get_object(
PresignedArgs::new(&config::CONFIG.minio_bucket, final_filename)
).await {
Ok(v) => v,
Err(err) => {
println!("{}", err);
return; // log error and task error
}, // log error and task error
};
println!("{}", link);
}
/// Registers a new in-progress task for `data` in `TASK_RESULTS`, starts the
/// archive job in the background and returns the registered task.
pub async fn create_task(
    data: CreateTask
) -> Task {
    let key = get_key(data.clone());

    let task = Task {
        result_link: None,
        result_filename: None,
        status: crate::structures::TaskStatus::InProgress,
        id: key.clone(),
    };

    TASK_RESULTS.insert(key.clone(), task.clone()).await;

    // The spawned job is detached: its outcome is not awaited here.
    tokio::spawn(create_archive_task(key, data));

    task
}

40
src/services/utils.rs Normal file
View File

@@ -0,0 +1,40 @@
use reqwest::Response;
use tempfile::SpooledTempFile;
use bytes::Buf;
use std::io::{Seek, SeekFrom, Write};
/// Streams the body of `res` into a spooled temporary file.
///
/// Returns the file, rewound to its start, together with the number of bytes
/// written, or `None` when reading the response, writing the file or
/// rewinding fails.
pub async fn response_to_tempfile(res: &mut Response) -> Option<(SpooledTempFile, usize)> {
    let mut tmp_file = tempfile::spooled_tempfile(5 * 1024 * 1024);

    let mut data_size: usize = 0;

    while let Some(data) = res.chunk().await.ok()? {
        data_size += data.len();

        // `write` may store only part of the buffer; `write_all` guarantees
        // the whole chunk ends up in the file.
        tmp_file.write_all(data.chunk()).ok()?;
    }

    tmp_file.seek(SeekFrom::Start(0)).ok()?;

    Some((tmp_file, data_size))
}

36
src/structures.rs Normal file
View File

@@ -0,0 +1,36 @@
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use smartstring::alias::String as SmartString;
/// Lifecycle state of an archive task; serialized in `snake_case`.
#[derive(Serialize, Clone)]
#[serde(rename_all = "snake_case")]
pub enum TaskStatus {
    InProgress,
    Archiving,
    Complete
}
/// Kind of library object an archive is built for; (de)serialized in
/// `snake_case`.
#[derive(Serialize, Deserialize, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ObjectType {
    Sequence,
    Author,
    Translator
}
/// Request body of the "create archive task" endpoint.
#[derive(Serialize, Deserialize, Clone)]
pub struct CreateTask{
    /// Id of the sequence / author / translator whose books are archived.
    pub object_id: u32,
    pub object_type: ObjectType,
    /// Only books available in this format are put into the archive.
    pub file_format: String,
    /// Languages passed to the library service as the `allowed_langs` filter.
    pub allowed_langs: SmallVec<[SmartString; 3]>
}
/// Archive task as returned by the API.
#[derive(Serialize, Clone)]
pub struct Task {
    pub id: String,
    pub status: TaskStatus,
    /// `None` while the task is still in progress.
    pub result_filename: Option<String>,
    pub result_link: Option<String>
}

86
src/views.rs Normal file
View File

@@ -0,0 +1,86 @@
use std::time::Duration;
use axum::{Router, routing::{get, post}, middleware::{self, Next}, http::{Request, StatusCode, self}, response::{Response, IntoResponse}, extract::{Path, self}, Json};
use axum_prometheus::PrometheusMetricLayer;
use moka::future::Cache;
use once_cell::sync::Lazy;
use tower_http::trace::{TraceLayer, self};
use tracing::Level;
use crate::{config::CONFIG, structures::{Task, CreateTask}, services::task_creator::{get_key, create_task}};
/// In-memory cache of tasks by task id: at most 2048 entries, each dropped
/// after 24 hours without access.
pub static TASK_RESULTS: Lazy<Cache<String, Task>> = Lazy::new(|| {
    Cache::builder()
        .time_to_idle(Duration::from_secs(24 * 60 * 60))
        .max_capacity(2048)
        .build()
});
/// `POST /` handler: returns the cached task for the request, creating and
/// starting a new one when none is cached.
async fn create_archive_task(
    extract::Json(data): extract::Json<CreateTask>
) -> impl IntoResponse {
    let key = get_key(data.clone());

    let task = if let Some(cached) = TASK_RESULTS.get(&key) {
        cached
    } else {
        create_task(data).await
    };

    Json(task).into_response()
}
/// `GET /check_archive/:task_id` handler: returns the cached task as JSON, or
/// `404 Not Found` when the id is unknown.
async fn check_archive_task_status(
    Path(task_id): Path<String>
) -> impl IntoResponse {
    TASK_RESULTS.get(&task_id).map_or_else(
        || StatusCode::NOT_FOUND.into_response(),
        |task| Json(task).into_response(),
    )
}
/// Middleware that only lets a request through when its `Authorization`
/// header equals the configured API key.
///
/// # Errors
///
/// Returns `401 Unauthorized` when the header is missing, is not valid text,
/// or does not match.
async fn auth<B>(req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
    let is_authorized = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .map_or(false, |token| token == CONFIG.api_key);

    if !is_authorized {
        return Err(StatusCode::UNAUTHORIZED);
    }

    Ok(next.run(req).await)
}
/// Builds the application router.
///
/// The task routes are wrapped in the API key check and the Prometheus
/// metrics layer; `/metrics` is served without the API key check. HTTP
/// tracing at `INFO` level covers everything.
pub async fn get_router() -> Router {
    let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair();

    let trace_layer = TraceLayer::new_for_http()
        .make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))
        .on_response(trace::DefaultOnResponse::new().level(Level::INFO));

    let app_router = Router::new()
        .route("/", post(create_archive_task))
        .route("/check_archive/:task_id", get(check_archive_task_status))
        .layer(middleware::from_fn(auth))
        .layer(prometheus_layer);

    let metric_router =
        Router::new().route("/metrics", get(|| async move { metric_handle.render() }));

    Router::new()
        .nest("/", app_router)
        .nest("/", metric_router)
        .layer(trace_layer)
}