Initial Code Commit

This commit is contained in:
2025-12-12 19:12:36 +00:00
commit e021191eb3
5 changed files with 1036 additions and 0 deletions
+497
View File
@@ -0,0 +1,497 @@
use axum::{
body::Body,
extract::{Path, State},
http::{header, HeaderMap, StatusCode},
middleware::{self, Next},
response::{Html, IntoResponse, Response},
routing::get,
Router,
};
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
net::SocketAddr,
path::PathBuf,
sync::Arc,
time::Duration,
};
use tokio::fs::File;
use tokio::io::{AsyncReadExt, AsyncSeekExt};
use tokio_util::io::ReaderStream;
use tower_http::{
compression::CompressionLayer,
timeout::TimeoutLayer,
trace::TraceLayer,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MediaFile {
name: String,
path: String,
description: Option<String>,
#[serde(default)]
content_type: Option<String>,
#[serde(default)]
size_bytes: Option<u64>,
}
#[derive(Debug, Deserialize)]
struct MediaConfig {
files: HashMap<String, MediaFile>,
}
#[derive(Debug, Clone)]
struct Config {
port: u16,
host: String,
request_timeout: u64,
allowed_ips: Option<Vec<String>>,
chunk_size: usize,
domain: Option<String>,
}
#[derive(Clone)]
struct AppState {
files: Arc<HashMap<String, MediaFile>>,
config: Arc<Config>,
}
#[tokio::main]
async fn main() {
// Initialize tracing for logging
tracing_subscriber::fmt::init();
// Load environment configuration
dotenv::dotenv().ok();
let config = load_config();
// Load media files configuration
let media_config = load_media_config("media_config.json")
.expect("Failed to load media configuration");
let state = AppState {
files: Arc::new(media_config.files),
config: Arc::new(config.clone()),
};
let app = Router::new()
.route("/", get(list_files))
.route("/download/:file_id", get(download_file))
.route("/health", get(health_check))
.layer(middleware::from_fn_with_state(
state.clone(),
security_headers_middleware,
))
.layer(middleware::from_fn_with_state(
state.clone(),
ip_whitelist_middleware,
))
.layer(CompressionLayer::new())
.layer(TimeoutLayer::new(Duration::from_secs(config.request_timeout)))
.layer(TraceLayer::new_for_http())
.with_state(state);
let addr = SocketAddr::from((
config.host.parse::<std::net::IpAddr>().unwrap(),
config.port,
));
println!("📥 Public media download server running on http://{}", addr);
println!("✅ No authentication required - public access enabled");
println!("💾 Optimized for large file streaming (100GB+)");
println!("📋 Configuration loaded from .env and media_config.json");
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}
fn load_config() -> Config {
Config {
port: std::env::var("PORT")
.unwrap_or_else(|_| "3000".to_string())
.parse()
.expect("Invalid PORT"),
host: std::env::var("HOST").unwrap_or_else(|_| "0.0.0.0".to_string()),
request_timeout: std::env::var("REQUEST_TIMEOUT_SECS")
.unwrap_or_else(|_| "3600".to_string()) // 1 hour default for large files
.parse()
.expect("Invalid REQUEST_TIMEOUT_SECS"),
allowed_ips: std::env::var("ALLOWED_IPS")
.ok()
.map(|s| s.split(',').map(|ip| ip.trim().to_string()).collect()),
chunk_size: std::env::var("CHUNK_SIZE_KB")
.unwrap_or_else(|_| "8192".to_string()) // 8MB chunks for large files
.parse::<usize>()
.expect("Invalid CHUNK_SIZE_KB")
* 1024,
domain: std::env::var("DOMAIN").ok().map(|d| d.trim_end_matches('/').to_string()),
}
}
fn load_media_config(path: &str) -> Result<MediaConfig, Box<dyn std::error::Error>> {
let content = std::fs::read_to_string(path)?;
let mut config: MediaConfig = serde_json::from_str(&content)?;
// Validate and get file sizes
for (id, file) in &mut config.files {
let path = PathBuf::from(&file.path);
if !path.exists() {
tracing::warn!("File '{}' at path '{}' does not exist", id, file.path);
} else if !path.is_file() {
tracing::warn!("Path '{}' for file '{}' is not a file", file.path, id);
} else {
// Get file size
if let Ok(metadata) = std::fs::metadata(&path) {
file.size_bytes = Some(metadata.len());
tracing::info!(
"Loaded file '{}': {} ({} bytes)",
id,
file.name,
metadata.len()
);
}
}
}
Ok(config)
}
// IP whitelist middleware (optional)
async fn ip_whitelist_middleware(
State(state): State<AppState>,
headers: HeaderMap,
request: axum::http::Request<Body>,
next: Next,
) -> Result<Response, StatusCode> {
if let Some(allowed_ips) = &state.config.allowed_ips {
// Try to get real IP from X-Forwarded-For or X-Real-IP headers
let client_ip = headers
.get("X-Forwarded-For")
.or_else(|| headers.get("X-Real-IP"))
.and_then(|v| v.to_str().ok())
.map(|s| s.split(',').next().unwrap_or(s).trim());
if let Some(ip) = client_ip {
if !allowed_ips.iter().any(|allowed| allowed == ip || allowed == "*") {
tracing::warn!("Access denied for IP: {}", ip);
return Err(StatusCode::FORBIDDEN);
}
}
}
Ok(next.run(request).await)
}
// Security headers middleware
async fn security_headers_middleware(
State(_state): State<AppState>,
request: axum::http::Request<Body>,
next: Next,
) -> Response {
let mut response = next.run(request).await;
let headers = response.headers_mut();
// Basic security headers
headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
headers.insert("X-Frame-Options", "SAMEORIGIN".parse().unwrap());
headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
headers.insert("Referrer-Policy", "no-referrer".parse().unwrap());
response
}
async fn health_check() -> impl IntoResponse {
(StatusCode::OK, "OK")
}
async fn list_files(State(state): State<AppState>) -> Html<String> {
// Read template from file
let template_path = "templates/index.html";
let template = match tokio::fs::read_to_string(template_path).await {
Ok(t) => t,
Err(e) => {
tracing::error!("Failed to read template {}: {}", template_path, e);
return Html(format!(
"<h1>500 Internal Server Error</h1><p>Failed to load template: {}</p>",
e
));
}
};
let file_count = state.files.len();
let total_size: u64 = state.files.values().filter_map(|f| f.size_bytes).sum();
// Determine base URL to display in commands: prefer DOMAIN env var, else fall back to localhost:PORT
let base_url = state
.config
.domain
.clone()
.unwrap_or_else(|| format!("http://localhost:{}", state.config.port));
let mut file_list_items = String::new();
// Sort files by ID for consistent display order
let mut sorted_files: Vec<_> = state.files.iter().collect();
sorted_files.sort_by_key(|(id, _)| *id);
for (id, file) in sorted_files {
let size_str = file.size_bytes
.map(format_size)
.unwrap_or_else(|| "Unknown size".to_string());
// Build safe element id for per-file UI
let safe_elem_id = safe_html_id(id);
// Build example commands using base_url
let curl_cmd = format!("curl -O {}/download/{}", base_url, id);
file_list_items.push_str(&format!(
r#"
<li class="file-item">
<div class="file-name">{name}</div>
<div class="file-meta">
<span class="file-id">ID: {id}</span>
<span class="file-size">📦 {size}</span>
</div>
<div class="file-desc">{desc}</div>
<a href="/download/{id}" class="download-btn">⬇️ Download</a>
<div style="margin-top:12px; display:flex; gap:8px; align-items:center;">
<button class="show-cmd" data-target="cmd-{elem}">Show Command</button>
<button class="copy-cmd" data-target="cmd-{elem}">Copy</button>
</div>
<pre id="cmd-{elem}" class="usage-pre" style="display:none; margin-top:12px; background:#111827; color:#f9fafb; padding:12px; border-radius:6px; overflow:auto;">{curl}</pre>
</li>
"#,
name = html_escape(&file.name),
id = html_escape(id),
size = size_str,
desc = html_escape(file.description.as_deref().unwrap_or("No description")),
elem = html_escape(&safe_elem_id),
curl = html_escape(&curl_cmd),
));
}
let html = template
.replace("{FILE_COUNT}", &file_count.to_string())
.replace("{TOTAL_SIZE}", &format_size(total_size))
.replace("{FILE_LIST_ITEMS}", &file_list_items)
.replace("{BASE_URL}", &html_escape(&base_url));
Html(html)
}
async fn download_file(
State(state): State<AppState>,
Path(file_id): Path<String>,
headers: HeaderMap,
) -> Result<Response, StatusCode> {
// Sanitize file_id to prevent path traversal
if file_id.contains("..") || file_id.contains('/') || file_id.contains('\\') {
tracing::warn!("Path traversal attempt detected: {}", file_id);
return Err(StatusCode::BAD_REQUEST);
}
let file_info = state
.files
.get(&file_id)
.ok_or(StatusCode::NOT_FOUND)?;
// Canonicalize path to prevent symlink attacks
let file_path = PathBuf::from(&file_info.path);
let canonical_path = file_path
.canonicalize()
.map_err(|_| StatusCode::NOT_FOUND)?;
// Verify file still exists and is a file
if !canonical_path.is_file() {
tracing::error!("File not found or is not a file: {:?}", canonical_path);
return Err(StatusCode::NOT_FOUND);
}
// Get file metadata
let metadata = tokio::fs::metadata(&canonical_path)
.await
.map_err(|e| {
tracing::error!("Failed to get file metadata: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
let file_size = metadata.len();
// Check for Range header to support resume
let range_header = headers.get(header::RANGE);
if let Some(range) = range_header {
// Handle range requests for resume support
return handle_range_request(&canonical_path, file_size, range, file_info).await;
}
// Open file for streaming
let file = File::open(&canonical_path)
.await
.map_err(|e| {
tracing::error!("Failed to open file: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
// Create stream with optimized buffer size for large files
let stream = ReaderStream::with_capacity(file, state.config.chunk_size);
let body = Body::from_stream(stream);
let filename = canonical_path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("download");
// Sanitize filename
let safe_filename = filename.replace(&['/', '\\', '\0'][..], "_");
let content_type = file_info
.content_type
.as_deref()
.unwrap_or("application/octet-stream");
tracing::info!(
"File download started: {} ({}) - {} bytes",
file_id,
safe_filename,
file_size
);
Ok(Response::builder()
.status(StatusCode::OK)
.header(
header::CONTENT_DISPOSITION,
format!("attachment; filename=\"{}\"", safe_filename),
)
.header(header::CONTENT_TYPE, content_type)
.header(header::CONTENT_LENGTH, file_size)
.header(header::ACCEPT_RANGES, "bytes")
.header("X-Content-Type-Options", "nosniff")
.header("Cache-Control", "public, max-age=3600")
.body(body)
.unwrap())
}
async fn handle_range_request(
file_path: &PathBuf,
file_size: u64,
range_header: &header::HeaderValue,
file_info: &MediaFile,
) -> Result<Response, StatusCode> {
let range_str = range_header.to_str().map_err(|_| StatusCode::BAD_REQUEST)?;
// Parse range header (format: "bytes=start-end")
if !range_str.starts_with("bytes=") {
return Err(StatusCode::BAD_REQUEST);
}
let range_spec = &range_str[6..];
let parts: Vec<&str> = range_spec.split('-').collect();
if parts.len() != 2 {
return Err(StatusCode::BAD_REQUEST);
}
let start: u64 = if parts[0].is_empty() {
0
} else {
parts[0].parse().map_err(|_| StatusCode::BAD_REQUEST)?
};
let end: u64 = if parts[1].is_empty() {
file_size - 1
} else {
parts[1].parse::<u64>().map_err(|_| StatusCode::BAD_REQUEST)?
.min(file_size - 1)
};
if start > end || start >= file_size {
return Err(StatusCode::RANGE_NOT_SATISFIABLE);
}
let content_length = end - start + 1;
// Open file and seek to start position
let mut file = File::open(file_path)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
file.seek(std::io::SeekFrom::Start(start))
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// Create limited stream
let limited_file = file.take(content_length);
let stream = ReaderStream::new(limited_file);
let body = Body::from_stream(stream);
let filename = file_path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("download");
let safe_filename = filename.replace(&['/', '\\', '\0'][..], "_");
let content_type = file_info
.content_type
.as_deref()
.unwrap_or("application/octet-stream");
tracing::info!(
"Range request: {} bytes {}-{}/{} ({})",
safe_filename,
start,
end,
file_size,
content_length
);
Ok(Response::builder()
.status(StatusCode::PARTIAL_CONTENT)
.header(
header::CONTENT_DISPOSITION,
format!("attachment; filename=\"{}\"", safe_filename),
)
.header(header::CONTENT_TYPE, content_type)
.header(header::CONTENT_LENGTH, content_length)
.header(
header::CONTENT_RANGE,
format!("bytes {}-{}/{}", start, end, file_size),
)
.header(header::ACCEPT_RANGES, "bytes")
.header("X-Content-Type-Options", "nosniff")
.body(body)
.unwrap())
}
// Format bytes to human-readable size
fn format_size(bytes: u64) -> String {
const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB", "PB"];
if bytes == 0 {
return "0 B".to_string();
}
let bytes_f = bytes as f64;
let i = (bytes_f.log10() / 1024_f64.log10()).floor() as usize;
let i = i.min(UNITS.len() - 1);
let size = bytes_f / 1024_f64.powi(i as i32);
format!("{:.2} {}", size, UNITS[i])
}
// HTML escape to prevent XSS
fn html_escape(s: &str) -> String {
s.replace('&', "&amp;")
.replace('<', "&lt;")
.replace('>', "&gt;")
.replace('"', "&quot;")
.replace('\'', "&#x27;")
}
// Create a safe HTML element id from an arbitrary string
fn safe_html_id(s: &str) -> String {
s.chars()
.map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
.collect()
}