From 127d7e9c6724fb89beed98f0a88713391b474f27 Mon Sep 17 00:00:00 2001 From: xenofem Date: Thu, 28 Apr 2022 05:13:14 -0400 Subject: [PATCH] add persistent state --- src/main.rs | 17 +++++--- src/state.rs | 108 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/upload.rs | 23 ++++++++--- src/util.rs | 3 ++ 4 files changed, 139 insertions(+), 12 deletions(-) create mode 100644 src/state.rs create mode 100644 src/util.rs diff --git a/src/main.rs b/src/main.rs index 144d612..c917f2a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,11 @@ mod download; mod file; +mod state; mod upload; +mod util; mod zip; use std::{ - collections::HashMap, path::PathBuf, fs::File, }; @@ -14,6 +15,8 @@ use actix_web::{ get, middleware::Logger, web, App, HttpRequest, HttpServer, Responder, HttpResponse, }; use actix_web_actors::ws; +use serde::{Deserialize, Serialize}; +use state::PersistentState; use time::OffsetDateTime; use tokio::sync::RwLock; @@ -35,15 +38,17 @@ impl UploadedFile { } } -#[derive(Clone)] +#[derive(Clone, Deserialize, Serialize)] pub struct DownloadableFile { name: String, size: u64, + #[serde(with = "state::timestamp")] modtime: OffsetDateTime, + #[serde(skip)] uploader: Option>, } -type AppData = web::Data>>; +type AppData = web::Data>; fn storage_dir() -> PathBuf { PathBuf::from(std::env::var("STORAGE_DIR").unwrap_or_else(|_| String::from("storage"))) @@ -52,11 +57,11 @@ fn storage_dir() -> PathBuf { #[get("/download/{file_code}")] async fn handle_download(req: HttpRequest, path: web::Path, data: AppData) -> actix_web::Result { let file_code = path.into_inner(); - if !file_code.as_bytes().iter().all(|c| c.is_ascii_alphanumeric()) { + if !util::is_ascii_alphanumeric(&file_code) { return Ok(HttpResponse::NotFound().finish()); } let data = data.read().await; - let info = data.get(&file_code); + let info = data.lookup_file(&file_code); if let Some(info) = info { Ok(download::DownloadingFile { file: File::open(storage_dir().join(file_code))?, @@ -76,7 +81,7 @@ async fn handle_upload(req: HttpRequest, stream: web::Payload, data: AppData) -> async fn main() -> std::io::Result<()> { env_logger::init(); - let data: AppData = web::Data::new(RwLock::new(HashMap::new())); + let data: AppData = web::Data::new(RwLock::new(PersistentState::load().await?)); let static_dir = PathBuf::from(std::env::var("STATIC_DIR").unwrap_or_else(|_| String::from("static"))); diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..9f5f8a8 --- /dev/null +++ b/src/state.rs @@ -0,0 +1,108 @@ +use std::{collections::HashMap, io::ErrorKind}; + +use log::{info, error}; +use tokio::{fs::File, io::{AsyncReadExt, AsyncWriteExt}}; + +use crate::{DownloadableFile, storage_dir}; + +const STATE_FILE_NAME: &str = "files.json"; + +pub(crate) mod timestamp { + use core::fmt; + + use serde::{Serializer, Deserializer, de::Visitor}; + use time::OffsetDateTime; + + pub(crate) fn serialize(time: &OffsetDateTime, ser: S) -> Result { + ser.serialize_i64(time.unix_timestamp()) + } + + struct I64Visitor; + + impl<'de> Visitor<'de> for I64Visitor { + type Value = i64; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a signed integer") + } + + fn visit_i64(self, v: i64) -> Result { + Ok(v) + } + } + + pub(crate) fn deserialize<'de, D: Deserializer<'de>>(de: D) -> Result { + Ok(OffsetDateTime::from_unix_timestamp(de.deserialize_i64(I64Visitor)?).unwrap_or_else(|_| OffsetDateTime::now_utc())) + } +} + +pub(crate) struct PersistentState(HashMap); +impl PersistentState { + pub(crate) async fn load() -> std::io::Result { + let open_result = File::open(storage_dir().join(STATE_FILE_NAME)).await; + match open_result { + Ok(mut f) => { + let mut buf = String::new(); + f.read_to_string(&mut buf).await?; + let map: HashMap = serde_json::from_str(&buf)?; + info!("Loaded {} file entries from persistent storage", map.len()); + let mut filtered: HashMap = HashMap::new(); + for (key, info) in map.into_iter() { + if !crate::util::is_ascii_alphanumeric(&key) { + error!("Invalid key in persistent storage: {}", key); + continue; + } + let file = if let Ok(f) = File::open(storage_dir().join(&key)).await { + f + } else { + error!("Unable to open file {} referenced in persistent storage", key); + continue; + }; + let metadata = if let Ok(md) = file.metadata().await { + md + } else { + error!("Unable to get metadata for file {} referenced in persistent storage", key); + continue; + }; + if metadata.len() != info.size { + error!("Mismatched file size for file {} referenced in persistent storage: expected {}, found {}", key, info.size, metadata.len()); + continue + } + filtered.insert(key, info); + } + Ok(Self(filtered)) + } + Err(e) => { + if let ErrorKind::NotFound = e.kind() { + Ok(Self(HashMap::new())) + } else { + Err(e) + } + } + } + } + + async fn save(&mut self) -> std::io::Result<()> { + File::create(storage_dir().join(STATE_FILE_NAME)).await?.write_all(&serde_json::to_vec_pretty(&self.0)?).await + } + + pub(crate) async fn add_file(&mut self, key: String, file: DownloadableFile) -> std::io::Result<()> { + self.0.insert(key, file); + self.save().await + } + + pub(crate) fn lookup_file(&self, key: &str) -> Option { + self.0.get(key).map(|f| f.clone()) + } + + pub(crate) async fn remove_file(&mut self, key: &str) -> std::io::Result<()> { + self.0.remove(key); + self.save().await + } + + pub(crate) fn remove_uploader(&mut self, key: &str) { + if let Some(f) = self.0.get_mut(key) { + f.uploader.take(); + } + } +} diff --git a/src/upload.rs b/src/upload.rs index 90c95ec..53191ff 100644 --- a/src/upload.rs +++ b/src/upload.rs @@ -61,7 +61,7 @@ pub struct Uploader { } impl Uploader { - pub fn new(app_data: super::AppData) -> Self { + pub(crate) fn new(app_data: super::AppData) -> Self { Self { writer: None, storage_filename: String::new(), @@ -110,6 +110,7 @@ impl StreamHandler> for Uploader { Ok(m) => m, Err(e) => { error!("Websocket error: {}", e); + self.cleanup_after_error(ctx); ctx.stop(); return; } @@ -122,6 +123,7 @@ impl StreamHandler> for Uploader { code: e.close_code(), description: Some(e.to_string()), })); + self.cleanup_after_error(ctx); ctx.stop(); } Ok(true) => { @@ -134,9 +136,7 @@ impl StreamHandler> for Uploader { let data = self.app_data.clone(); let filename = self.storage_filename.clone(); ctx.wait(actix::fut::wrap_future(async move { - if let Some(f) = data.write().await.get_mut(&filename) { - f.uploader.take(); - } + data.write().await.remove_uploader(&filename); })); ctx.stop(); } @@ -219,10 +219,10 @@ impl Uploader { data .write() .await - .insert( + .add_file( storage_filename_copy, downloadable_file, - ); + ).await.unwrap(); })); ctx.text(self.storage_filename.as_str()); } @@ -267,4 +267,15 @@ impl Uploader { Err(Error::UnexpectedMessageType) } } + + fn cleanup_after_error(&mut self, ctx: &mut ::Context) { + let data = self.app_data.clone(); + let filename = self.storage_filename.clone(); + ctx.spawn(actix::fut::wrap_future(async move { + data + .write() + .await + .remove_file(&filename).await.unwrap(); + })); + } } diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..bcc5caf --- /dev/null +++ b/src/util.rs @@ -0,0 +1,3 @@ +pub(crate) fn is_ascii_alphanumeric(s: &str) -> bool { + s.as_bytes().iter().all(|c| c.is_ascii_alphanumeric()) +}