Initial Code Commit
This commit is contained in:
+497
@@ -0,0 +1,497 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Path, State},
|
||||
http::{header, HeaderMap, StatusCode},
|
||||
middleware::{self, Next},
|
||||
response::{Html, IntoResponse, Response},
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
net::SocketAddr,
|
||||
path::PathBuf,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::fs::File;
|
||||
use tokio::io::{AsyncReadExt, AsyncSeekExt};
|
||||
use tokio_util::io::ReaderStream;
|
||||
use tower_http::{
|
||||
compression::CompressionLayer,
|
||||
timeout::TimeoutLayer,
|
||||
trace::TraceLayer,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct MediaFile {
|
||||
name: String,
|
||||
path: String,
|
||||
description: Option<String>,
|
||||
#[serde(default)]
|
||||
content_type: Option<String>,
|
||||
#[serde(default)]
|
||||
size_bytes: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MediaConfig {
|
||||
files: HashMap<String, MediaFile>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Config {
|
||||
port: u16,
|
||||
host: String,
|
||||
request_timeout: u64,
|
||||
allowed_ips: Option<Vec<String>>,
|
||||
chunk_size: usize,
|
||||
domain: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
files: Arc<HashMap<String, MediaFile>>,
|
||||
config: Arc<Config>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// Initialize tracing for logging
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
// Load environment configuration
|
||||
dotenv::dotenv().ok();
|
||||
let config = load_config();
|
||||
|
||||
// Load media files configuration
|
||||
let media_config = load_media_config("media_config.json")
|
||||
.expect("Failed to load media configuration");
|
||||
|
||||
let state = AppState {
|
||||
files: Arc::new(media_config.files),
|
||||
config: Arc::new(config.clone()),
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(list_files))
|
||||
.route("/download/:file_id", get(download_file))
|
||||
.route("/health", get(health_check))
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
security_headers_middleware,
|
||||
))
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
ip_whitelist_middleware,
|
||||
))
|
||||
.layer(CompressionLayer::new())
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(config.request_timeout)))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.with_state(state);
|
||||
|
||||
let addr = SocketAddr::from((
|
||||
config.host.parse::<std::net::IpAddr>().unwrap(),
|
||||
config.port,
|
||||
));
|
||||
|
||||
println!("📥 Public media download server running on http://{}", addr);
|
||||
println!("✅ No authentication required - public access enabled");
|
||||
println!("💾 Optimized for large file streaming (100GB+)");
|
||||
println!("📋 Configuration loaded from .env and media_config.json");
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
}
|
||||
|
||||
fn load_config() -> Config {
|
||||
Config {
|
||||
port: std::env::var("PORT")
|
||||
.unwrap_or_else(|_| "3000".to_string())
|
||||
.parse()
|
||||
.expect("Invalid PORT"),
|
||||
host: std::env::var("HOST").unwrap_or_else(|_| "0.0.0.0".to_string()),
|
||||
request_timeout: std::env::var("REQUEST_TIMEOUT_SECS")
|
||||
.unwrap_or_else(|_| "3600".to_string()) // 1 hour default for large files
|
||||
.parse()
|
||||
.expect("Invalid REQUEST_TIMEOUT_SECS"),
|
||||
allowed_ips: std::env::var("ALLOWED_IPS")
|
||||
.ok()
|
||||
.map(|s| s.split(',').map(|ip| ip.trim().to_string()).collect()),
|
||||
chunk_size: std::env::var("CHUNK_SIZE_KB")
|
||||
.unwrap_or_else(|_| "8192".to_string()) // 8MB chunks for large files
|
||||
.parse::<usize>()
|
||||
.expect("Invalid CHUNK_SIZE_KB")
|
||||
* 1024,
|
||||
domain: std::env::var("DOMAIN").ok().map(|d| d.trim_end_matches('/').to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn load_media_config(path: &str) -> Result<MediaConfig, Box<dyn std::error::Error>> {
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let mut config: MediaConfig = serde_json::from_str(&content)?;
|
||||
|
||||
// Validate and get file sizes
|
||||
for (id, file) in &mut config.files {
|
||||
let path = PathBuf::from(&file.path);
|
||||
if !path.exists() {
|
||||
tracing::warn!("File '{}' at path '{}' does not exist", id, file.path);
|
||||
} else if !path.is_file() {
|
||||
tracing::warn!("Path '{}' for file '{}' is not a file", file.path, id);
|
||||
} else {
|
||||
// Get file size
|
||||
if let Ok(metadata) = std::fs::metadata(&path) {
|
||||
file.size_bytes = Some(metadata.len());
|
||||
tracing::info!(
|
||||
"Loaded file '{}': {} ({} bytes)",
|
||||
id,
|
||||
file.name,
|
||||
metadata.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
// IP whitelist middleware (optional)
|
||||
async fn ip_whitelist_middleware(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
request: axum::http::Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
if let Some(allowed_ips) = &state.config.allowed_ips {
|
||||
// Try to get real IP from X-Forwarded-For or X-Real-IP headers
|
||||
let client_ip = headers
|
||||
.get("X-Forwarded-For")
|
||||
.or_else(|| headers.get("X-Real-IP"))
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.split(',').next().unwrap_or(s).trim());
|
||||
|
||||
if let Some(ip) = client_ip {
|
||||
if !allowed_ips.iter().any(|allowed| allowed == ip || allowed == "*") {
|
||||
tracing::warn!("Access denied for IP: {}", ip);
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
|
||||
// Security headers middleware
|
||||
async fn security_headers_middleware(
|
||||
State(_state): State<AppState>,
|
||||
request: axum::http::Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let mut response = next.run(request).await;
|
||||
let headers = response.headers_mut();
|
||||
|
||||
// Basic security headers
|
||||
headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
|
||||
headers.insert("X-Frame-Options", "SAMEORIGIN".parse().unwrap());
|
||||
headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
|
||||
headers.insert("Referrer-Policy", "no-referrer".parse().unwrap());
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
async fn health_check() -> impl IntoResponse {
|
||||
(StatusCode::OK, "OK")
|
||||
}
|
||||
|
||||
async fn list_files(State(state): State<AppState>) -> Html<String> {
|
||||
// Read template from file
|
||||
let template_path = "templates/index.html";
|
||||
let template = match tokio::fs::read_to_string(template_path).await {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read template {}: {}", template_path, e);
|
||||
return Html(format!(
|
||||
"<h1>500 Internal Server Error</h1><p>Failed to load template: {}</p>",
|
||||
e
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let file_count = state.files.len();
|
||||
let total_size: u64 = state.files.values().filter_map(|f| f.size_bytes).sum();
|
||||
|
||||
// Determine base URL to display in commands: prefer DOMAIN env var, else fall back to localhost:PORT
|
||||
let base_url = state
|
||||
.config
|
||||
.domain
|
||||
.clone()
|
||||
.unwrap_or_else(|| format!("http://localhost:{}", state.config.port));
|
||||
|
||||
let mut file_list_items = String::new();
|
||||
|
||||
// Sort files by ID for consistent display order
|
||||
let mut sorted_files: Vec<_> = state.files.iter().collect();
|
||||
sorted_files.sort_by_key(|(id, _)| *id);
|
||||
|
||||
for (id, file) in sorted_files {
|
||||
let size_str = file.size_bytes
|
||||
.map(format_size)
|
||||
.unwrap_or_else(|| "Unknown size".to_string());
|
||||
// Build safe element id for per-file UI
|
||||
let safe_elem_id = safe_html_id(id);
|
||||
|
||||
// Build example commands using base_url
|
||||
let curl_cmd = format!("curl -O {}/download/{}", base_url, id);
|
||||
|
||||
file_list_items.push_str(&format!(
|
||||
r#"
|
||||
<li class="file-item">
|
||||
<div class="file-name">{name}</div>
|
||||
<div class="file-meta">
|
||||
<span class="file-id">ID: {id}</span>
|
||||
<span class="file-size">📦 {size}</span>
|
||||
</div>
|
||||
<div class="file-desc">{desc}</div>
|
||||
<a href="/download/{id}" class="download-btn">⬇️ Download</a>
|
||||
<div style="margin-top:12px; display:flex; gap:8px; align-items:center;">
|
||||
<button class="show-cmd" data-target="cmd-{elem}">Show Command</button>
|
||||
<button class="copy-cmd" data-target="cmd-{elem}">Copy</button>
|
||||
</div>
|
||||
<pre id="cmd-{elem}" class="usage-pre" style="display:none; margin-top:12px; background:#111827; color:#f9fafb; padding:12px; border-radius:6px; overflow:auto;">{curl}</pre>
|
||||
</li>
|
||||
"#,
|
||||
name = html_escape(&file.name),
|
||||
id = html_escape(id),
|
||||
size = size_str,
|
||||
desc = html_escape(file.description.as_deref().unwrap_or("No description")),
|
||||
elem = html_escape(&safe_elem_id),
|
||||
curl = html_escape(&curl_cmd),
|
||||
));
|
||||
}
|
||||
|
||||
let html = template
|
||||
.replace("{FILE_COUNT}", &file_count.to_string())
|
||||
.replace("{TOTAL_SIZE}", &format_size(total_size))
|
||||
.replace("{FILE_LIST_ITEMS}", &file_list_items)
|
||||
.replace("{BASE_URL}", &html_escape(&base_url));
|
||||
|
||||
Html(html)
|
||||
}
|
||||
|
||||
async fn download_file(
|
||||
State(state): State<AppState>,
|
||||
Path(file_id): Path<String>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Response, StatusCode> {
|
||||
// Sanitize file_id to prevent path traversal
|
||||
if file_id.contains("..") || file_id.contains('/') || file_id.contains('\\') {
|
||||
tracing::warn!("Path traversal attempt detected: {}", file_id);
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
let file_info = state
|
||||
.files
|
||||
.get(&file_id)
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
// Canonicalize path to prevent symlink attacks
|
||||
let file_path = PathBuf::from(&file_info.path);
|
||||
let canonical_path = file_path
|
||||
.canonicalize()
|
||||
.map_err(|_| StatusCode::NOT_FOUND)?;
|
||||
|
||||
// Verify file still exists and is a file
|
||||
if !canonical_path.is_file() {
|
||||
tracing::error!("File not found or is not a file: {:?}", canonical_path);
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
// Get file metadata
|
||||
let metadata = tokio::fs::metadata(&canonical_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to get file metadata: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let file_size = metadata.len();
|
||||
|
||||
// Check for Range header to support resume
|
||||
let range_header = headers.get(header::RANGE);
|
||||
|
||||
if let Some(range) = range_header {
|
||||
// Handle range requests for resume support
|
||||
return handle_range_request(&canonical_path, file_size, range, file_info).await;
|
||||
}
|
||||
|
||||
// Open file for streaming
|
||||
let file = File::open(&canonical_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to open file: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
// Create stream with optimized buffer size for large files
|
||||
let stream = ReaderStream::with_capacity(file, state.config.chunk_size);
|
||||
let body = Body::from_stream(stream);
|
||||
|
||||
let filename = canonical_path
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("download");
|
||||
|
||||
// Sanitize filename
|
||||
let safe_filename = filename.replace(&['/', '\\', '\0'][..], "_");
|
||||
|
||||
let content_type = file_info
|
||||
.content_type
|
||||
.as_deref()
|
||||
.unwrap_or("application/octet-stream");
|
||||
|
||||
tracing::info!(
|
||||
"File download started: {} ({}) - {} bytes",
|
||||
file_id,
|
||||
safe_filename,
|
||||
file_size
|
||||
);
|
||||
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(
|
||||
header::CONTENT_DISPOSITION,
|
||||
format!("attachment; filename=\"{}\"", safe_filename),
|
||||
)
|
||||
.header(header::CONTENT_TYPE, content_type)
|
||||
.header(header::CONTENT_LENGTH, file_size)
|
||||
.header(header::ACCEPT_RANGES, "bytes")
|
||||
.header("X-Content-Type-Options", "nosniff")
|
||||
.header("Cache-Control", "public, max-age=3600")
|
||||
.body(body)
|
||||
.unwrap())
|
||||
}
|
||||
|
||||
async fn handle_range_request(
|
||||
file_path: &PathBuf,
|
||||
file_size: u64,
|
||||
range_header: &header::HeaderValue,
|
||||
file_info: &MediaFile,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let range_str = range_header.to_str().map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
// Parse range header (format: "bytes=start-end")
|
||||
if !range_str.starts_with("bytes=") {
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
let range_spec = &range_str[6..];
|
||||
let parts: Vec<&str> = range_spec.split('-').collect();
|
||||
|
||||
if parts.len() != 2 {
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
let start: u64 = if parts[0].is_empty() {
|
||||
0
|
||||
} else {
|
||||
parts[0].parse().map_err(|_| StatusCode::BAD_REQUEST)?
|
||||
};
|
||||
|
||||
let end: u64 = if parts[1].is_empty() {
|
||||
file_size - 1
|
||||
} else {
|
||||
parts[1].parse::<u64>().map_err(|_| StatusCode::BAD_REQUEST)?
|
||||
.min(file_size - 1)
|
||||
};
|
||||
|
||||
if start > end || start >= file_size {
|
||||
return Err(StatusCode::RANGE_NOT_SATISFIABLE);
|
||||
}
|
||||
|
||||
let content_length = end - start + 1;
|
||||
|
||||
// Open file and seek to start position
|
||||
let mut file = File::open(file_path)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
file.seek(std::io::SeekFrom::Start(start))
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// Create limited stream
|
||||
let limited_file = file.take(content_length);
|
||||
let stream = ReaderStream::new(limited_file);
|
||||
let body = Body::from_stream(stream);
|
||||
|
||||
let filename = file_path
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("download");
|
||||
let safe_filename = filename.replace(&['/', '\\', '\0'][..], "_");
|
||||
|
||||
let content_type = file_info
|
||||
.content_type
|
||||
.as_deref()
|
||||
.unwrap_or("application/octet-stream");
|
||||
|
||||
tracing::info!(
|
||||
"Range request: {} bytes {}-{}/{} ({})",
|
||||
safe_filename,
|
||||
start,
|
||||
end,
|
||||
file_size,
|
||||
content_length
|
||||
);
|
||||
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::PARTIAL_CONTENT)
|
||||
.header(
|
||||
header::CONTENT_DISPOSITION,
|
||||
format!("attachment; filename=\"{}\"", safe_filename),
|
||||
)
|
||||
.header(header::CONTENT_TYPE, content_type)
|
||||
.header(header::CONTENT_LENGTH, content_length)
|
||||
.header(
|
||||
header::CONTENT_RANGE,
|
||||
format!("bytes {}-{}/{}", start, end, file_size),
|
||||
)
|
||||
.header(header::ACCEPT_RANGES, "bytes")
|
||||
.header("X-Content-Type-Options", "nosniff")
|
||||
.body(body)
|
||||
.unwrap())
|
||||
}
|
||||
|
||||
// Format bytes to human-readable size
|
||||
fn format_size(bytes: u64) -> String {
|
||||
const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB", "PB"];
|
||||
|
||||
if bytes == 0 {
|
||||
return "0 B".to_string();
|
||||
}
|
||||
|
||||
let bytes_f = bytes as f64;
|
||||
let i = (bytes_f.log10() / 1024_f64.log10()).floor() as usize;
|
||||
let i = i.min(UNITS.len() - 1);
|
||||
|
||||
let size = bytes_f / 1024_f64.powi(i as i32);
|
||||
|
||||
format!("{:.2} {}", size, UNITS[i])
|
||||
}
|
||||
|
||||
// HTML escape to prevent XSS
|
||||
fn html_escape(s: &str) -> String {
|
||||
s.replace('&', "&")
|
||||
.replace('<', "<")
|
||||
.replace('>', ">")
|
||||
.replace('"', """)
|
||||
.replace('\'', "'")
|
||||
}
|
||||
|
||||
// Create a safe HTML element id from an arbitrary string
|
||||
fn safe_html_id(s: &str) -> String {
|
||||
s.chars()
|
||||
.map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
|
||||
.collect()
|
||||
}
|
||||
Reference in New Issue
Block a user