From e2245aab9e6382691076763800d95c86155d5c1f Mon Sep 17 00:00:00 2001 From: Bulat Kurbanov Date: Thu, 10 Aug 2023 00:26:27 +0200 Subject: [PATCH] Fix --- src/main.rs | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/src/main.rs b/src/main.rs index 2a49e43..d47d2c7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,5 @@ use std::{net::SocketAddr, time::SystemTime}; -use axum::{Router, routing::{post, get}, extract::Multipart, response::{IntoResponse, AppendHeaders}, http::{StatusCode, header}, body::StreamBody}; +use axum::{Router, routing::{post, get}, extract::Multipart, response::{IntoResponse, AppendHeaders, Response}, http::{StatusCode, header, Request, self}, body::StreamBody, middleware::{Next, self}}; use axum_prometheus::PrometheusMetricLayer; use tokio::{fs::{remove_file, read_dir, remove_dir, File}, io::{AsyncWriteExt, copy}, process::Command}; use tower_http::trace::{TraceLayer, self}; @@ -161,13 +161,40 @@ async fn convert_file( } +async fn auth(req: Request, next: Next) -> Result { + let auth_header = req + .headers() + .get(http::header::AUTHORIZATION) + .and_then(|header| header.to_str().ok()); + + let auth_header = if let Some(auth_header) = auth_header { + auth_header + } else { + return Err(StatusCode::UNAUTHORIZED); + }; + + if auth_header != std::env::var("API_KEY").unwrap_or_else(|_| panic!("Cannot get the API_KEY env variable")) { + return Err(StatusCode::UNAUTHORIZED); + } + + Ok(next.run(req).await) +} + + fn get_router() -> Router { let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair(); - Router::new() + let app_router = Router::new() .route("/", post(convert_file)) - .route("/metrics", get(|| async move { metric_handle.render() })) - .layer(prometheus_layer) + .layer(middleware::from_fn(auth)) + .layer(prometheus_layer); + + let metric_router = + Router::new().route("/metrics", get(|| async move { metric_handle.render() })); + + Router::new() + .nest("/", app_router) + .nest("/", metric_router) .layer( TraceLayer::new_for_http() .make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))