From 680275bd001196f033f8640a01fc7834f94de463 Mon Sep 17 00:00:00 2001 From: Bulat Kurbanov Date: Tue, 25 Jul 2023 20:37:28 +0200 Subject: [PATCH] Refactor --- src/main.rs | 51 +------------------------ src/views.rs | 104 ++++++++++++++++++++++++++++++++++----------------- 2 files changed, 71 insertions(+), 84 deletions(-) diff --git a/src/main.rs b/src/main.rs index 91f0cca..f64e163 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,58 +3,9 @@ pub mod views; pub mod services; use std::net::SocketAddr; -use axum::{Router, routing::get, middleware::{self, Next}, http::{self, Request}, response::Response}; -use axum_prometheus::PrometheusMetricLayer; -use config::CONFIG; -use reqwest::StatusCode; -use views::{download, get_filename}; use tracing::info; -use tower_http::trace::{TraceLayer, self}; -use tracing::Level; - -async fn auth(req: Request, next: Next) -> Result { - let auth_header = req.headers() - .get(http::header::AUTHORIZATION) - .and_then(|header| header.to_str().ok()); - - let auth_header = if let Some(auth_header) = auth_header { - auth_header - } else { - return Err(StatusCode::UNAUTHORIZED); - }; - - if auth_header != CONFIG.api_key { - return Err(StatusCode::UNAUTHORIZED); - } - - Ok(next.run(req).await) -} - - -async fn get_router() -> Router { - let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair(); - - let app_router = Router::new() - .route("/download/:source_id/:remote_id/:file_type", get(download)) - .route("/filename/:book_id/:file_type", get(get_filename)) - .layer(middleware::from_fn(auth)) - .layer(prometheus_layer); - - let metric_router = Router::new() - .route("/metrics", get(|| async move { metric_handle.render() })); - - Router::new() - .nest("/", app_router) - .nest("/", metric_router) - .layer( - TraceLayer::new_for_http() - .make_span_with(trace::DefaultMakeSpan::new() - .level(Level::INFO)) - .on_response(trace::DefaultOnResponse::new() - .level(Level::INFO)), - ) -} +use crate::views::get_router; #[tokio::main] diff --git a/src/views.rs b/src/views.rs index 58df68e..9a1b308 100644 --- a/src/views.rs +++ b/src/views.rs @@ -1,40 +1,36 @@ use axum::{ body::StreamBody, extract::Path, - http::{header, HeaderMap, StatusCode, header::AUTHORIZATION}, - response::{IntoResponse, AppendHeaders}, + http::{header, StatusCode}, + response::{AppendHeaders, IntoResponse}, }; +use axum::{ + http::{self, Request}, + middleware::{self, Next}, + response::Response, + routing::get, + Router, +}; +use axum_prometheus::PrometheusMetricLayer; use base64::{engine::general_purpose, Engine}; use tokio_util::io::ReaderStream; +use tower_http::trace::{self, TraceLayer}; +use tracing::Level; -use crate::{config, services::{book_library::get_book, filename_getter::get_filename_by_book, downloader::book_download}}; +use crate::{ + config::CONFIG, + services::{ + book_library::get_book, downloader::book_download, filename_getter::get_filename_by_book, + }, +}; -fn check_authorization(headers: HeaderMap) -> Result<(), (StatusCode, String)> { - let config_api_key = config::CONFIG.api_key.clone(); - - let api_key = match headers.get(AUTHORIZATION) { - Some(v) => v, - None => return Err((StatusCode::FORBIDDEN, "No api-key!".to_string())), - }; - - if config_api_key != api_key.to_str().unwrap() { - return Err((StatusCode::FORBIDDEN, "Wrong api-key!".to_string())) - } - - Ok(()) -} pub async fn download( Path((source_id, remote_id, file_type)): Path<(u32, u32, String)>, - headers: HeaderMap ) -> impl IntoResponse { - check_authorization(headers)?; - let download_result = match book_download(source_id, remote_id, file_type.as_str()).await { Ok(v) => v, - Err(_) => { - return Err((StatusCode::NO_CONTENT, "Can't download!".to_string())) - }, + Err(_) => return Err((StatusCode::NO_CONTENT, "Can't download!".to_string())), }; let data = match download_result { @@ -52,25 +48,65 @@ pub async fn download( let encoder = general_purpose::STANDARD; let headers = AppendHeaders([ - (header::CONTENT_DISPOSITION, format!("attachment; filename={filename_ascii}")), - (header::HeaderName::from_static("x-filename-b64"), encoder.encode(filename)) + ( + header::CONTENT_DISPOSITION, + format!("attachment; filename={filename_ascii}"), + ), + ( + header::HeaderName::from_static("x-filename-b64"), + encoder.encode(filename), + ), ]); Ok((headers, body)) } -pub async fn get_filename( - Path((book_id, file_type)): Path<(u32, String)>, - headers: HeaderMap -) -> (StatusCode, String){ - if let Err(v) = check_authorization(headers) { - return v; - } - +pub async fn get_filename(Path((book_id, file_type)): Path<(u32, String)>) -> (StatusCode, String) { let filename = match get_book(book_id).await { Ok(book) => get_filename_by_book(&book, file_type.as_str(), false, false), Err(_) => return (StatusCode::BAD_REQUEST, "Book not found!".to_string()), }; (StatusCode::OK, filename) -} \ No newline at end of file +} + +async fn auth(req: Request, next: Next) -> Result { + let auth_header = req + .headers() + .get(http::header::AUTHORIZATION) + .and_then(|header| header.to_str().ok()); + + let auth_header = if let Some(auth_header) = auth_header { + auth_header + } else { + return Err(StatusCode::UNAUTHORIZED); + }; + + if auth_header != CONFIG.api_key { + return Err(StatusCode::UNAUTHORIZED); + } + + Ok(next.run(req).await) +} + +pub async fn get_router() -> Router { + let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair(); + + let app_router = Router::new() + .route("/download/:source_id/:remote_id/:file_type", get(download)) + .route("/filename/:book_id/:file_type", get(get_filename)) + .layer(middleware::from_fn(auth)) + .layer(prometheus_layer); + + let metric_router = + Router::new().route("/metrics", get(|| async move { metric_handle.render() })); + + Router::new() + .nest("/", app_router) + .nest("/", metric_router) + .layer( + TraceLayer::new_for_http() + .make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO)) + .on_response(trace::DefaultOnResponse::new().level(Level::INFO)), + ) +}