diff --git a/src/main.rs b/src/main.rs index 29672de..8a0b4f3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ use axum::{ body::Body, - extract::{Path, State}, + extract::{ConnectInfo, Path, State}, http::{header, HeaderMap, StatusCode}, middleware::{self, Next}, response::{Html, IntoResponse, Response}, @@ -293,6 +293,7 @@ async fn list_files(State(state): State) -> Html { async fn download_file( State(state): State, + ConnectInfo(addr): ConnectInfo, Path(file_id): Path, headers: HeaderMap, ) -> Result { @@ -329,12 +330,21 @@ async fn download_file( let file_size = metadata.len(); + // Get client IP (trust X-Forwarded-For if behind proxy, else socket addr) + let client_ip = headers + .get("X-Forwarded-For") + .or_else(|| headers.get("X-Real-IP")) + .and_then(|v| v.to_str().ok()) + .map(|s| s.split(',').next().unwrap_or(s).trim()) + .map(|s| s.to_string()) + .unwrap_or_else(|| addr.ip().to_string()); + // Check for Range header to support resume let range_header = headers.get(header::RANGE); if let Some(range) = range_header { // Handle range requests for resume support - return handle_range_request(&canonical_path, file_size, range, file_info).await; + return handle_range_request(&canonical_path, file_size, range, file_info, &client_ip).await; } // Open file for streaming @@ -363,10 +373,11 @@ async fn download_file( .unwrap_or("application/octet-stream"); tracing::info!( - "File download started: {} ({}) - {} bytes", + "File download started: {} ({}) - {} bytes - IP: {}", file_id, safe_filename, - file_size + file_size, + client_ip ); Ok(Response::builder() @@ -389,6 +400,7 @@ async fn handle_range_request( file_size: u64, range_header: &header::HeaderValue, file_info: &MediaFile, + client_ip: &str, ) -> Result { let range_str = range_header.to_str().map_err(|_| StatusCode::BAD_REQUEST)?; @@ -449,12 +461,13 @@ async fn handle_range_request( .unwrap_or("application/octet-stream"); tracing::info!( - "Range request: {} bytes {}-{}/{} ({})", + "Range request: {} bytes {}-{}/{} ({}) - IP: {}", safe_filename, start, end, file_size, - content_length + content_length, + client_ip ); Ok(Response::builder()