diff --git a/Cargo.lock b/Cargo.lock index 9501035..9a85b17 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -56,6 +56,7 @@ dependencies = [ "bitflags", "bytes", "futures-util", + "headers", "http", "http-body", "hyper", @@ -96,6 +97,12 @@ dependencies = [ "tower-service", ] +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.21.0" @@ -371,6 +378,31 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "headers" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3e372db8e5c0d213e0cd0b9be18be2aca3d44cf2fe30a9d46a65581cd454584" +dependencies = [ + "base64 0.13.1", + "bitflags", + "bytes", + "headers-core", + "http", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429" +dependencies = [ + "http", +] + [[package]] name = "heck" version = "0.4.1" @@ -565,6 +597,16 @@ version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" +[[package]] +name = "mime_guess" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -839,7 +881,7 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ - "base64", + "base64 0.21.0", ] [[package]] @@ -931,12 +973,25 @@ dependencies = [ "axum", "dotenvy", "hyper", + "mime_guess", "nanoid", "serde", "sqlx", "tokio", "tokio-stream", - "tokio-util", + "tower", + "tower-http", +] + +[[package]] +name = "sha1" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", ] [[package]] @@ -1253,10 +1308,17 @@ dependencies = [ "http", "http-body", "http-range-header", + "httpdate", + "mime", + "mime_guess", + "percent-encoding", "pin-project-lite", + "tokio", + "tokio-util", "tower", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -1304,6 +1366,15 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +[[package]] +name = "unicase" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6" +dependencies = [ + "version_check", +] + [[package]] name = "unicode-bidi" version = "0.3.10" diff --git a/server/Cargo.toml b/server/Cargo.toml index dd3a388..ce74d89 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -8,12 +8,14 @@ edition = "2021" [dependencies] anyhow = "1.0.69" async-trait = "0.1.65" -axum = { version = "0.6.10", features = ["multipart"] } +axum = { version = "0.6.10", features = ["multipart", "headers"] } dotenvy = "0.15.6" hyper = "0.14.24" +mime_guess = "2.0.4" nanoid = "0.4.0" serde = { version = "1.0.152", features = ["derive"] } sqlx = { version = "0.6.2", features = ["sqlite", "runtime-tokio-rustls"] } tokio = { version = "1.26.0", features = ["full"] } tokio-stream = { version = "0.1.12", features = ["net"] } -tokio-util = { version = "0.7.7", features = ["io"] } +tower = "0.4.13" +tower-http = { version = "0.4.0", features = ["fs"] } diff --git a/server/src/routes/i.rs b/server/src/routes/i.rs index 3de3e1a..d2903d9 100644 --- a/server/src/routes/i.rs +++ b/server/src/routes/i.rs @@ -1,28 +1,38 @@ use crate::state::ReqState; use axum::{ - body::StreamBody, extract::{Path, State}, response::IntoResponse, }; -use hyper::StatusCode; -use tokio_util::io::ReaderStream; +use hyper::{Body, Request, StatusCode, Uri}; +use std::{path::PathBuf, str::FromStr}; +use tower::util::ServiceExt; +use tower_http::services::ServeFile; pub async fn get(State(state): ReqState, Path(id): Path) -> impl IntoResponse { if id.len() < 8 { return Err(StatusCode::NOT_FOUND); } + let extension = PathBuf::from_str(&id).map_err(|_| StatusCode::NOT_FOUND)?; + let extension = extension.extension().and_then(|ext| ext.to_str()); + let id = id[..8].to_string(); let path = state.files_dir.join(id.clone()); if path.exists() { - let file = tokio::fs::File::open(path) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - let stream = ReaderStream::new(file); - let body = StreamBody::new(stream); + let req = Request::builder() + .uri(Uri::from_static("/")) + .body(Body::empty()) + .unwrap(); - Ok(body) + if let Some(ext) = extension { + let mime = mime_guess::from_ext(ext).first_or_octet_stream(); + let serve_file = ServeFile::new_with_mime(path, &mime).oneshot(req); + Ok(serve_file.await.unwrap().map(axum::body::boxed)) + } else { + let serve_file = ServeFile::new(path).oneshot(req); + Ok(serve_file.await.unwrap().map(axum::body::boxed)) + } } else { Err(StatusCode::NOT_FOUND) }