mdpreview/src/main.rs

441 lines
12 KiB
Rust

use askama::Template;
use axum::{
Router,
extract::{Path as AxumPath, State},
http::{HeaderMap, StatusCode, header},
response::{IntoResponse, Redirect, Sse},
routing::get,
};
use futures::StreamExt;
use mime_guess::from_path;
use notify::{Config, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
use rand::seq::SliceRandom;
use std::convert::Infallible;
use std::ffi::OsStr;
use std::net::{SocketAddr, ToSocketAddrs};
use std::path::{Path as StdPath, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tokio::fs;
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;
use syntect::highlighting::ThemeSet;
use syntect::parsing::SyntaxSet;
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use webbrowser;
use clap::Parser;
use std::io;
mod markdown;
use markdown::markdown_to_html;
mod other;
use other::code_to_html;
/// One row of a directory listing rendered by `dir.html`.
#[derive(Clone)]
pub struct FileEntry {
    /// `true` when the entry is a directory (directories are listed first).
    pub is_dir: bool,
    /// Bare file name as read from the directory.
    pub name: String,
    /// Path of the entry relative to the notes root, used to build its URL.
    pub link: String,
}
/// Askama view model for a directory listing (`templates/dir.html`).
#[derive(Template)]
#[template(path = "dir.html")]
pub struct DirectoryTemplate {
    /// Entries to show, directories first, each group sorted by name.
    pub files: Vec<FileEntry>,
    /// Request path of the directory, without a leading `/`; empty for the root.
    pub title_path: String,
}
/// Askama view model for a single rendered file (`templates/file.html`).
#[derive(Template)]
#[template(path = "file.html")]
pub struct NoteTemplate {
    /// Pre-rendered HTML body (Markdown or syntax-highlighted source).
    pub content: String,
    /// File name shown to the user.
    pub filename: String,
    /// URL of the directory containing the file (`/` for top-level files).
    pub back_link: String,
    /// Server-sent-events endpoint for this file (`/events/<path>`).
    pub sse_url: String,
    /// Text of the form `note edit <path>` for copying.
    pub copy_path: String,
}
// Command-line options. Field `///` comments below are rendered by clap as
// `--help` text, so they are user-facing; the struct itself intentionally has
// no doc comment so `about` keeps coming from the crate description.
#[derive(clap::Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Host to listen on
    #[arg(long, default_value_t = String::from("localhost"))]
    host: String,
    /// Port to listen on
    #[arg(short, long, default_value_t = 8080)]
    port: u16,
    /// Open the server URL in the default browser
    #[arg(long, default_value_t = false)]
    browser: bool,
    /// Open the browser on a random note (requires --browser)
    #[arg(long, default_value_t = false, requires = "browser")]
    random: bool,
    /// Root directory of the notes to serve
    #[arg()]
    root: PathBuf,
}
/// Shared state handed to every handler and to the file watcher.
#[derive(Clone)]
struct AppState {
    /// Notes root exactly as given on the command line (not canonicalized).
    root: PathBuf,
    /// Broadcasts paths of changed files from the watcher to SSE subscribers.
    tx: Arc<broadcast::Sender<String>>,
    /// Syntax definitions passed to the Markdown / code renderers.
    syntax_set: Arc<SyntaxSet>,
    /// Highlighting themes passed to the Markdown / code renderers.
    theme_set: Arc<ThemeSet>,
}
/// Entry point: parses CLI arguments, validates the notes root, starts the
/// file watcher and serves the preview app.
///
/// Exits with status 1 (message on stderr) when the root is not a directory,
/// the listen address cannot be resolved or bound, or the server stops with
/// an error.
#[tokio::main]
async fn main() {
    let args = Args::parse();
    if !args.root.is_dir() {
        eprintln!("Root {} is not a directory", args.root.display());
        std::process::exit(1);
    }
    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::new(
            std::env::var("RUST_LOG").unwrap_or_else(|_| "info,tower_http=warn".into()),
        ))
        .with(tracing_subscriber::fmt::layer())
        .init();
    let ss = SyntaxSet::load_defaults_newlines();
    let ts = ThemeSet::load_defaults();
    // `_rx` (unlike `_`) stays alive until `main` returns, so the channel
    // always has at least one receiver.
    let (tx, _rx) = broadcast::channel::<String>(100);
    let state = AppState {
        syntax_set: Arc::new(ss),
        theme_set: Arc::new(ts),
        tx: Arc::new(tx),
        root: args.root.clone(),
    };
    let watcher_state = state.clone();
    tokio::spawn(async move {
        run_file_watcher(watcher_state).await;
    });
    let app = Router::new()
        .route("/", get(root_handler))
        .route("/random", get(random_file))
        .route("/*path", get(serve_file))
        .route("/events/*path", get(sse_handler))
        .with_state(state)
        .layer(TraceLayer::new_for_http());
    // Host/port come from the user: report failures instead of panicking.
    let addr = match resolve_addr(&args.host, args.port) {
        Ok(addr) => addr,
        Err(e) => {
            eprintln!("Failed to resolve address {}:{}: {e}", args.host, args.port);
            std::process::exit(1);
        }
    };
    let listener = match tokio::net::TcpListener::bind(addr).await {
        Ok(listener) => listener,
        Err(e) => {
            eprintln!("Failed to bind {addr}: {e}");
            std::process::exit(1);
        }
    };
    let actual_addr = listener.local_addr().expect("Failed to get local address");
    println!("Server started on http://{actual_addr}");
    if args.browser {
        let mut url = format!("http://{actual_addr}");
        if args.random {
            url = format!("{url}/random");
        }
        // Failing to open a browser is not fatal; the URL is printed above.
        if let Err(e) = webbrowser::open(url.as_str()) {
            eprintln!("Failed to open browser: {e}");
        }
    }
    if let Err(e) = axum::serve(listener, app).await {
        eprintln!("Server error: {e}");
        std::process::exit(1);
    }
}
/// Resolves `host:port` and returns the first socket address found.
///
/// # Errors
/// Propagates resolver errors, and returns `ErrorKind::NotFound` when the
/// lookup succeeds but yields no address.
fn resolve_addr(host: &str, port: u16) -> io::Result<SocketAddr> {
    let mut candidates = format!("{host}:{port}").to_socket_addrs()?;
    match candidates.next() {
        Some(first) => Ok(first),
        None => Err(io::Error::new(
            io::ErrorKind::NotFound,
            "Cannot resolve addr",
        )),
    }
}
/// `GET /`: renders the listing of the notes root directory.
///
/// On failure the status code chosen by `render_directory_index` is returned.
async fn root_handler(State(state): State<AppState>) -> impl IntoResponse {
    // `Result<T, E>` is itself `IntoResponse` when both sides are, which
    // dispatches to the template on `Ok` and to the status code on `Err`.
    render_directory_index(&state.root, "")
        .await
        .into_response()
}
async fn serve_file(
State(state): State<AppState>,
AxumPath(full_path): AxumPath<String>,
) -> impl IntoResponse {
if full_path.is_empty() {
return Err(StatusCode::NOT_FOUND);
}
let mut requested_path = state.root.clone();
requested_path.push(&full_path);
let Ok(safe_path) = fs::canonicalize(&requested_path).await else {
return Err(StatusCode::NOT_FOUND);
};
let Ok(base_dir) = fs::canonicalize(&state.root).await else {
return Err(StatusCode::INTERNAL_SERVER_ERROR);
};
if !safe_path.starts_with(&base_dir) {
eprintln!("Path traversal attempt: {}", safe_path.display());
return Err(StatusCode::FORBIDDEN);
}
let metadata = match fs::metadata(&safe_path).await {
Ok(m) => m,
Err(_) => return Err(StatusCode::NOT_FOUND),
};
if metadata.is_dir() {
return match render_directory_index(&safe_path, &full_path).await {
Ok(t) => Ok(t.into_response()),
Err(e) => Err(e),
};
}
let extension = safe_path
.extension()
.and_then(|ext| ext.to_str())
.unwrap_or("")
.to_lowercase();
let is_image = matches!(
extension.as_str(),
"png" | "jpg" | "jpeg" | "gif" | "svg" | "webp" | "bmp" | "ico"
);
if is_image {
let file_content = match fs::read(&safe_path).await {
Ok(content) => content,
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
};
let mime_type = from_path(&safe_path).first_or_octet_stream();
return Ok(([(header::CONTENT_TYPE, mime_type.as_ref())], file_content).into_response());
}
let content = match fs::read_to_string(&safe_path).await {
Ok(c) => c,
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
};
let html_content = if extension == "md" {
markdown_to_html(&content, &state.syntax_set, &state.theme_set, &full_path)
} else {
code_to_html(
&content,
extension.as_str(),
&state.syntax_set,
&state.theme_set,
)
};
let filename = safe_path
.file_name()
.and_then(OsStr::to_str)
.unwrap_or("Unknown")
.to_string();
let back_link = if let Some(pos) = full_path.rfind('/') {
let parent = &full_path[..pos];
if parent.is_empty() {
"/".to_string()
} else {
format!("/{parent}")
}
} else {
"/".to_string()
};
let sse_url = format!("/events/{full_path}");
let copy_path = format!("note edit {full_path}");
let template = NoteTemplate {
filename,
back_link,
content: html_content,
sse_url,
copy_path,
};
Ok(template.into_response())
}
async fn render_directory_index(
dir_path: &StdPath,
request_path: &str,
) -> Result<DirectoryTemplate, StatusCode> {
let mut entries = match fs::read_dir(dir_path).await {
Ok(list) => list,
Err(_) => return Err(StatusCode::FORBIDDEN),
};
let mut files: Vec<FileEntry> = Vec::new();
while let Some(entry) = entries
.next_entry()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
{
let file_name = entry.file_name().to_string_lossy().to_string();
if file_name.starts_with('.') {
continue;
}
let is_dir = entry.metadata().await.map(|m| m.is_dir()).unwrap_or(false);
let link = if request_path.is_empty() {
file_name.clone()
} else {
format!("{}/{}", request_path, file_name)
};
files.push(FileEntry {
name: file_name,
link,
is_dir,
});
}
files.sort_by(|a, b| match (a.is_dir, b.is_dir) {
(true, false) => std::cmp::Ordering::Less,
(false, true) => std::cmp::Ordering::Greater,
_ => a.name.cmp(&b.name),
});
let title_path = request_path.trim_start_matches('/').to_string();
Ok(DirectoryTemplate { title_path, files })
}
async fn sse_handler(
State(state): State<AppState>,
AxumPath(full_path): AxumPath<String>,
) -> impl IntoResponse {
let mut headers = HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("text/event-stream"),
);
headers.insert(
header::CACHE_CONTROL,
header::HeaderValue::from_static("no-cache"),
);
headers.insert(
header::CONNECTION,
header::HeaderValue::from_static("keep-alive"),
);
let rx = state.tx.subscribe();
let requested_path = full_path.clone();
let stream = BroadcastStream::new(rx).filter_map(move |res| {
let req_path = requested_path.clone();
async move {
match res {
Ok(changed_path) => {
if changed_path.contains(&req_path) || changed_path.ends_with(&req_path) {
Some(Ok::<axum::response::sse::Event, Infallible>(
axum::response::sse::Event::default()
.event("reload")
.data(""),
))
} else {
None
}
}
Err(_) => None,
}
}
});
let sse = Sse::new(stream).keep_alive(
axum::response::sse::KeepAlive::new()
.interval(Duration::from_secs(15))
.text("ping"),
);
(headers, sse)
}
/// `GET /random`: redirects (307) to a randomly chosen file under the root.
///
/// The walk runs depth-first without recursion. It skips entries whose name
/// starts with `.` and does not descend through symlinked directories.
/// Unreadable directories and iteration errors are skipped silently. Each
/// segment of the target path is percent-encoded, so names containing
/// spaces, `#`, `?`, `%`, control or non-ASCII characters yield a valid
/// `Location` that round-trips through the `/*path` route.
///
/// # Errors
/// `404` when no visible file exists; `500` if the chosen path is
/// unexpectedly not under the root.
async fn random_file(State(state): State<AppState>) -> impl IntoResponse {
    /// Percent-encodes every byte except RFC 3986 unreserved characters.
    fn encode_segment(segment: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut out = String::with_capacity(segment.len());
        for byte in segment.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                out.push(char::from(byte));
            } else {
                out.push('%');
                out.push(char::from(HEX[usize::from(byte >> 4)]));
                out.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
        out
    }

    let mut files: Vec<PathBuf> = Vec::new();
    let mut stack = vec![state.root.clone()];
    while let Some(current_dir) = stack.pop() {
        let Ok(mut entries) = fs::read_dir(&current_dir).await else {
            continue;
        };
        while let Ok(Some(entry)) = entries.next_entry().await {
            if entry.file_name().to_string_lossy().starts_with('.') {
                continue;
            }
            let path = entry.path();
            // `DirEntry::metadata` does not follow symlinks, so a symlink
            // cycle cannot make this walk loop forever.
            if entry.metadata().await.map(|m| m.is_dir()).unwrap_or(false) {
                stack.push(path);
            } else {
                files.push(path);
            }
        }
    }
    let Some(random_path) = files.choose(&mut rand::thread_rng()) else {
        return Err(StatusCode::NOT_FOUND);
    };
    let Ok(relative_path) = random_path.strip_prefix(&state.root) else {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };
    // Build the URL from path components (portable across `/` and `\`
    // separators) and escape each one; `/` is re-inserted only between them.
    let url_path = relative_path
        .iter()
        .map(|segment| encode_segment(&segment.to_string_lossy()))
        .collect::<Vec<_>>()
        .join("/");
    Ok(Redirect::temporary(&format!("/{url_path}")))
}
/// Forwards filesystem changes under `state.root` to `state.tx`.
///
/// A `notify` watcher observes the root recursively. For every modify,
/// create or remove event, each affected path that is valid UTF-8 is
/// broadcast as a `String`. If the watcher cannot be created or attached,
/// the error is printed to stderr and the function returns, so the server
/// keeps running without live reload.
async fn run_file_watcher(state: AppState) {
    // notify calls the handler from its own background thread rather than a
    // tokio worker, which is where `blocking_send` is allowed; the bounded
    // channel bridges those events into this task.
    let (tx_fs, mut rx_fs) = tokio::sync::mpsc::channel::<PathBuf>(100);
    let created = RecommendedWatcher::new(
        move |res: Result<notify::Event, notify::Error>| {
            if let Ok(event) = res
                && matches!(
                    event.kind,
                    EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)
                )
            {
                for path in event.paths {
                    let _ = tx_fs.blocking_send(path);
                }
            }
        },
        Config::default(),
    );
    let mut watcher = match created {
        Ok(watcher) => watcher,
        Err(e) => {
            eprintln!("Failed to create watcher: {e}");
            return;
        }
    };
    if let Err(e) = watcher.watch(&state.root, RecursiveMode::Recursive) {
        eprintln!("Failed to set watcher: {e}");
        return;
    }
    // `watcher` stays alive for the whole loop; dropping it would stop events.
    while let Some(path) = rx_fs.recv().await {
        if let Some(path_str) = path.to_str() {
            // `send` only fails when nobody is subscribed; that is fine.
            let _ = state.tx.send(path_str.to_string());
        }
    }
}