This commit is contained in:
2023-08-10 00:26:27 +02:00
parent 6b726da24e
commit e2245aab9e

View File

@@ -1,5 +1,5 @@
use std::{net::SocketAddr, time::SystemTime}; use std::{net::SocketAddr, time::SystemTime};
use axum::{Router, routing::{post, get}, extract::Multipart, response::{IntoResponse, AppendHeaders}, http::{StatusCode, header}, body::StreamBody}; use axum::{Router, routing::{post, get}, extract::Multipart, response::{IntoResponse, AppendHeaders, Response}, http::{StatusCode, header, Request, self}, body::StreamBody, middleware::{Next, self}};
use axum_prometheus::PrometheusMetricLayer; use axum_prometheus::PrometheusMetricLayer;
use tokio::{fs::{remove_file, read_dir, remove_dir, File}, io::{AsyncWriteExt, copy}, process::Command}; use tokio::{fs::{remove_file, read_dir, remove_dir, File}, io::{AsyncWriteExt, copy}, process::Command};
use tower_http::trace::{TraceLayer, self}; use tower_http::trace::{TraceLayer, self};
@@ -161,13 +161,40 @@ async fn convert_file(
} }
async fn auth<B>(req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
let auth_header = req
.headers()
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok());
let auth_header = if let Some(auth_header) = auth_header {
auth_header
} else {
return Err(StatusCode::UNAUTHORIZED);
};
if auth_header != std::env::var("API_KEY").unwrap_or_else(|_| panic!("Cannot get the API_KEY env variable")) {
return Err(StatusCode::UNAUTHORIZED);
}
Ok(next.run(req).await)
}
fn get_router() -> Router { fn get_router() -> Router {
let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair(); let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair();
Router::new() let app_router = Router::new()
.route("/", post(convert_file)) .route("/", post(convert_file))
.route("/metrics", get(|| async move { metric_handle.render() })) .layer(middleware::from_fn(auth))
.layer(prometheus_layer) .layer(prometheus_layer);
let metric_router =
Router::new().route("/metrics", get(|| async move { metric_handle.render() }));
Router::new()
.nest("/", app_router)
.nest("/", metric_router)
.layer( .layer(
TraceLayer::new_for_http() TraceLayer::new_for_http()
.make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO)) .make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))