use askama::Template; use axum::{ Router, extract::{Path as AxumPath, State}, http::{HeaderMap, StatusCode, header}, response::{IntoResponse, Redirect, Sse}, routing::get, }; use futures::StreamExt; use mime_guess::from_path; use notify::{Config, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; use rand::seq::SliceRandom; use std::convert::Infallible; use std::ffi::OsStr; use std::net::{SocketAddr, ToSocketAddrs}; use std::path::{Path as StdPath, PathBuf}; use std::sync::Arc; use std::time::Duration; use tokio::fs; use tokio::sync::broadcast; use tokio_stream::wrappers::BroadcastStream; use syntect::highlighting::ThemeSet; use syntect::parsing::SyntaxSet; use tower_http::trace::TraceLayer; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use webbrowser; use clap::Parser; use std::io; mod markdown; use markdown::markdown_to_html; mod other; use other::code_to_html; #[derive(Clone)] pub struct FileEntry { pub name: String, pub link: String, pub is_dir: bool, } #[derive(Template)] #[template(path = "dir.html")] pub struct DirectoryTemplate { pub title_path: String, pub files: Vec, } #[derive(Template)] #[template(path = "file.html")] pub struct NoteTemplate { pub filename: String, pub back_link: String, pub content: String, pub sse_url: String, pub copy_path: String, } #[derive(clap::Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { /// Host to listen #[arg(long, default_value_t = String::from("localhost"))] host: String, /// Port to listen #[arg(short, long, default_value_t = 8080)] port: u16, /// Open browser #[arg(long, default_value_t = false)] browser: bool, /// Open browser on random note. Requires flag --broswer #[arg(long, default_value_t = false, requires = "browser")] random: bool, /// Notes root #[arg()] root: PathBuf, } #[derive(Clone)] struct AppState { syntax_set: Arc, theme_set: Arc, tx: Arc>, root: PathBuf, } #[tokio::main] async fn main() { let args = Args::parse(); if !args.root.is_dir() { eprintln!("Root {} is not a directory", args.root.display()); std::process::exit(1); } tracing_subscriber::registry() .with(tracing_subscriber::EnvFilter::new( std::env::var("RUST_LOG").unwrap_or_else(|_| "info,tower_http=warn".into()), )) .with(tracing_subscriber::fmt::layer()) .init(); let ss = SyntaxSet::load_defaults_newlines(); let ts = ThemeSet::load_defaults(); let (tx, _rx) = broadcast::channel::(100); let state = AppState { syntax_set: Arc::new(ss), theme_set: Arc::new(ts), tx: Arc::new(tx), root: args.root.clone(), }; let watcher_state = state.clone(); tokio::spawn(async move { run_file_watcher(watcher_state).await; }); let app = Router::new() .route("/", get(root_handler)) .route("/random", get(random_file)) .route("/*path", get(serve_file)) .route("/events/*path", get(sse_handler)) .with_state(state) .layer(TraceLayer::new_for_http()); let addr = resolve_addr(&args.host, args.port).expect("Failed to resolve address"); let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); let actual_addr = listener.local_addr().expect("Failed to get local address"); println!("Server started on http://{actual_addr}"); if args.browser { let mut url = format!("http://{actual_addr}"); if args.random { url = format!("{url}/random") } let _ = webbrowser::open(url.as_str()); } axum::serve(listener, app).await.unwrap(); } fn resolve_addr(host: &str, port: u16) -> io::Result { let addr_str = format!("{host}:{port}"); addr_str .to_socket_addrs()? .next() .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Cannot resolve addr")) } async fn root_handler(State(state): State) -> impl IntoResponse { match render_directory_index(&state.root, "").await { Ok(template) => template.into_response(), Err(e) => e.into_response(), } } async fn serve_file( State(state): State, AxumPath(full_path): AxumPath, ) -> impl IntoResponse { if full_path.is_empty() { return Err(StatusCode::NOT_FOUND); } let mut requested_path = state.root.clone(); requested_path.push(&full_path); let Ok(safe_path) = fs::canonicalize(&requested_path).await else { return Err(StatusCode::NOT_FOUND); }; let Ok(base_dir) = fs::canonicalize(&state.root).await else { return Err(StatusCode::INTERNAL_SERVER_ERROR); }; if !safe_path.starts_with(&base_dir) { eprintln!("Path traversal attempt: {}", safe_path.display()); return Err(StatusCode::FORBIDDEN); } let metadata = match fs::metadata(&safe_path).await { Ok(m) => m, Err(_) => return Err(StatusCode::NOT_FOUND), }; if metadata.is_dir() { return match render_directory_index(&safe_path, &full_path).await { Ok(t) => Ok(t.into_response()), Err(e) => Err(e), }; } let extension = safe_path .extension() .and_then(|ext| ext.to_str()) .unwrap_or("") .to_lowercase(); let is_image = matches!( extension.as_str(), "png" | "jpg" | "jpeg" | "gif" | "svg" | "webp" | "bmp" | "ico" ); if is_image { let file_content = match fs::read(&safe_path).await { Ok(content) => content, Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR), }; let mime_type = from_path(&safe_path).first_or_octet_stream(); return Ok(([(header::CONTENT_TYPE, mime_type.as_ref())], file_content).into_response()); } let content = match fs::read_to_string(&safe_path).await { Ok(c) => c, Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR), }; let html_content = if extension == "md" { markdown_to_html(&content, &state.syntax_set, &state.theme_set, &full_path) } else { code_to_html( &content, extension.as_str(), &state.syntax_set, &state.theme_set, ) }; let filename = safe_path .file_name() .and_then(OsStr::to_str) .unwrap_or("Unknown") .to_string(); let back_link = if let Some(pos) = full_path.rfind('/') { let parent = &full_path[..pos]; if parent.is_empty() { "/".to_string() } else { format!("/{parent}") } } else { "/".to_string() }; let sse_url = format!("/events/{full_path}"); let copy_path = format!("note edit {full_path}"); let template = NoteTemplate { filename, back_link, content: html_content, sse_url, copy_path, }; Ok(template.into_response()) } async fn render_directory_index( dir_path: &StdPath, request_path: &str, ) -> Result { let mut entries = match fs::read_dir(dir_path).await { Ok(list) => list, Err(_) => return Err(StatusCode::FORBIDDEN), }; let mut files: Vec = Vec::new(); while let Some(entry) = entries .next_entry() .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? { let file_name = entry.file_name().to_string_lossy().to_string(); if file_name.starts_with('.') { continue; } let is_dir = entry.metadata().await.map(|m| m.is_dir()).unwrap_or(false); let link = if request_path.is_empty() { file_name.clone() } else { format!("{}/{}", request_path, file_name) }; files.push(FileEntry { name: file_name, link, is_dir, }); } files.sort_by(|a, b| match (a.is_dir, b.is_dir) { (true, false) => std::cmp::Ordering::Less, (false, true) => std::cmp::Ordering::Greater, _ => a.name.cmp(&b.name), }); let title_path = request_path.trim_start_matches('/').to_string(); Ok(DirectoryTemplate { title_path, files }) } async fn sse_handler( State(state): State, AxumPath(full_path): AxumPath, ) -> impl IntoResponse { let mut headers = HeaderMap::new(); headers.insert( header::CONTENT_TYPE, header::HeaderValue::from_static("text/event-stream"), ); headers.insert( header::CACHE_CONTROL, header::HeaderValue::from_static("no-cache"), ); headers.insert( header::CONNECTION, header::HeaderValue::from_static("keep-alive"), ); let rx = state.tx.subscribe(); let requested_path = full_path.clone(); let stream = BroadcastStream::new(rx).filter_map(move |res| { let req_path = requested_path.clone(); async move { match res { Ok(changed_path) => { if changed_path.contains(&req_path) || changed_path.ends_with(&req_path) { Some(Ok::( axum::response::sse::Event::default() .event("reload") .data(""), )) } else { None } } Err(_) => None, } } }); let sse = Sse::new(stream).keep_alive( axum::response::sse::KeepAlive::new() .interval(Duration::from_secs(15)) .text("ping"), ); (headers, sse) } async fn random_file(State(state): State) -> impl IntoResponse { let mut files: Vec = Vec::new(); let mut stack = vec![state.root.clone()]; while let Some(current_dir) = stack.pop() { let Ok(mut entries) = fs::read_dir(¤t_dir).await else { continue; }; while let Ok(Some(entry)) = entries.next_entry().await { let path = entry.path(); if entry.file_name().to_string_lossy().starts_with('.') { continue; } if entry.metadata().await.map(|m| m.is_dir()).unwrap_or(false) { stack.push(path); } else { files.push(path); } } } if files.is_empty() { return Err(StatusCode::NOT_FOUND); } let mut rng = rand::thread_rng(); let random_path = files.choose(&mut rng).unwrap(); let relative_path = match random_path.strip_prefix(&state.root) { Ok(p) => p, Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR), }; let url_path = relative_path.to_string_lossy().replace('\\', "/"); Ok(Redirect::temporary(&format!("/{url_path}"))) } async fn run_file_watcher(state: AppState) { let (tx_fs, mut rx_fs) = tokio::sync::mpsc::channel::(100); let tx_fs_clone = tx_fs.clone(); let mut watcher = RecommendedWatcher::new( move |res: Result| { if let Ok(event) = res && matches!( event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_) ) { for path in event.paths { let _ = tx_fs_clone.blocking_send(path); } } }, Config::default(), ) .expect("Failed to create watcher"); if let Err(e) = watcher.watch(&state.root, RecursiveMode::Recursive) { eprintln!("Failed set watcher: {e}"); return; } while let Some(path) = rx_fs.recv().await { if let Some(path_str) = path.to_str() { let _ = state.tx.send(path_str.to_string()); } } }