From 127d7e9c6724fb89beed98f0a88713391b474f27 Mon Sep 17 00:00:00 2001 From: xenofem Date: Thu, 28 Apr 2022 05:13:14 -0400 Subject: [PATCH 1/2] add persistent state --- src/main.rs | 17 +++++--- src/state.rs | 108 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/upload.rs | 23 ++++++++--- src/util.rs | 3 ++ 4 files changed, 139 insertions(+), 12 deletions(-) create mode 100644 src/state.rs create mode 100644 src/util.rs diff --git a/src/main.rs b/src/main.rs index 144d612..c917f2a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,11 @@ mod download; mod file; +mod state; mod upload; +mod util; mod zip; use std::{ - collections::HashMap, path::PathBuf, fs::File, }; @@ -14,6 +15,8 @@ use actix_web::{ get, middleware::Logger, web, App, HttpRequest, HttpServer, Responder, HttpResponse, }; use actix_web_actors::ws; +use serde::{Deserialize, Serialize}; +use state::PersistentState; use time::OffsetDateTime; use tokio::sync::RwLock; @@ -35,15 +38,17 @@ impl UploadedFile { } } -#[derive(Clone)] +#[derive(Clone, Deserialize, Serialize)] pub struct DownloadableFile { name: String, size: u64, + #[serde(with = "state::timestamp")] modtime: OffsetDateTime, + #[serde(skip)] uploader: Option>, } -type AppData = web::Data>>; +type AppData = web::Data>; fn storage_dir() -> PathBuf { PathBuf::from(std::env::var("STORAGE_DIR").unwrap_or_else(|_| String::from("storage"))) @@ -52,11 +57,11 @@ fn storage_dir() -> PathBuf { #[get("/download/{file_code}")] async fn handle_download(req: HttpRequest, path: web::Path, data: AppData) -> actix_web::Result { let file_code = path.into_inner(); - if !file_code.as_bytes().iter().all(|c| c.is_ascii_alphanumeric()) { + if !util::is_ascii_alphanumeric(&file_code) { return Ok(HttpResponse::NotFound().finish()); } let data = data.read().await; - let info = data.get(&file_code); + let info = data.lookup_file(&file_code); if let Some(info) = info { Ok(download::DownloadingFile { file: File::open(storage_dir().join(file_code))?, @@ -76,7 +81,7 @@ async fn handle_upload(req: HttpRequest, stream: web::Payload, data: AppData) -> async fn main() -> std::io::Result<()> { env_logger::init(); - let data: AppData = web::Data::new(RwLock::new(HashMap::new())); + let data: AppData = web::Data::new(RwLock::new(PersistentState::load().await?)); let static_dir = PathBuf::from(std::env::var("STATIC_DIR").unwrap_or_else(|_| String::from("static"))); diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..9f5f8a8 --- /dev/null +++ b/src/state.rs @@ -0,0 +1,108 @@ +use std::{collections::HashMap, io::ErrorKind}; + +use log::{info, error}; +use tokio::{fs::File, io::{AsyncReadExt, AsyncWriteExt}}; + +use crate::{DownloadableFile, storage_dir}; + +const STATE_FILE_NAME: &str = "files.json"; + +pub(crate) mod timestamp { + use core::fmt; + + use serde::{Serializer, Deserializer, de::Visitor}; + use time::OffsetDateTime; + + pub(crate) fn serialize(time: &OffsetDateTime, ser: S) -> Result { + ser.serialize_i64(time.unix_timestamp()) + } + + struct I64Visitor; + + impl<'de> Visitor<'de> for I64Visitor { + type Value = i64; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a signed integer") + } + + fn visit_i64(self, v: i64) -> Result { + Ok(v) + } + } + + pub(crate) fn deserialize<'de, D: Deserializer<'de>>(de: D) -> Result { + Ok(OffsetDateTime::from_unix_timestamp(de.deserialize_i64(I64Visitor)?).unwrap_or_else(|_| OffsetDateTime::now_utc())) + } +} + +pub(crate) struct PersistentState(HashMap); +impl PersistentState { + pub(crate) async fn load() -> std::io::Result { + let open_result = File::open(storage_dir().join(STATE_FILE_NAME)).await; + match open_result { + Ok(mut f) => { + let mut buf = String::new(); + f.read_to_string(&mut buf).await?; + let map: HashMap = serde_json::from_str(&buf)?; + info!("Loaded {} file entries from persistent storage", map.len()); + let mut filtered: HashMap = HashMap::new(); + for (key, info) in map.into_iter() { + if !crate::util::is_ascii_alphanumeric(&key) { + error!("Invalid key in persistent storage: {}", key); + continue; + } + let file = if let Ok(f) = File::open(storage_dir().join(&key)).await { + f + } else { + error!("Unable to open file {} referenced in persistent storage", key); + continue; + }; + let metadata = if let Ok(md) = file.metadata().await { + md + } else { + error!("Unable to get metadata for file {} referenced in persistent storage", key); + continue; + }; + if metadata.len() != info.size { + error!("Mismatched file size for file {} referenced in persistent storage: expected {}, found {}", key, info.size, metadata.len()); + continue + } + filtered.insert(key, info); + } + Ok(Self(filtered)) + } + Err(e) => { + if let ErrorKind::NotFound = e.kind() { + Ok(Self(HashMap::new())) + } else { + Err(e) + } + } + } + } + + async fn save(&mut self) -> std::io::Result<()> { + File::create(storage_dir().join(STATE_FILE_NAME)).await?.write_all(&serde_json::to_vec_pretty(&self.0)?).await + } + + pub(crate) async fn add_file(&mut self, key: String, file: DownloadableFile) -> std::io::Result<()> { + self.0.insert(key, file); + self.save().await + } + + pub(crate) fn lookup_file(&self, key: &str) -> Option { + self.0.get(key).map(|f| f.clone()) + } + + pub(crate) async fn remove_file(&mut self, key: &str) -> std::io::Result<()> { + self.0.remove(key); + self.save().await + } + + pub(crate) fn remove_uploader(&mut self, key: &str) { + if let Some(f) = self.0.get_mut(key) { + f.uploader.take(); + } + } +} diff --git a/src/upload.rs b/src/upload.rs index 90c95ec..53191ff 100644 --- a/src/upload.rs +++ b/src/upload.rs @@ -61,7 +61,7 @@ pub struct Uploader { } impl Uploader { - pub fn new(app_data: super::AppData) -> Self { + pub(crate) fn new(app_data: super::AppData) -> Self { Self { writer: None, storage_filename: String::new(), @@ -110,6 +110,7 @@ impl StreamHandler> for Uploader { Ok(m) => m, Err(e) => { error!("Websocket error: {}", e); + self.cleanup_after_error(ctx); ctx.stop(); return; } @@ -122,6 +123,7 @@ impl StreamHandler> for Uploader { code: e.close_code(), description: Some(e.to_string()), })); + self.cleanup_after_error(ctx); ctx.stop(); } Ok(true) => { @@ -134,9 +136,7 @@ impl StreamHandler> for Uploader { let data = self.app_data.clone(); let filename = self.storage_filename.clone(); ctx.wait(actix::fut::wrap_future(async move { - if let Some(f) = data.write().await.get_mut(&filename) { - f.uploader.take(); - } + data.write().await.remove_uploader(&filename); })); ctx.stop(); } @@ -219,10 +219,10 @@ impl Uploader { data .write() .await - .insert( + .add_file( storage_filename_copy, downloadable_file, - ); + ).await.unwrap(); })); ctx.text(self.storage_filename.as_str()); } @@ -267,4 +267,15 @@ impl Uploader { Err(Error::UnexpectedMessageType) } } + + fn cleanup_after_error(&mut self, ctx: &mut ::Context) { + let data = self.app_data.clone(); + let filename = self.storage_filename.clone(); + ctx.spawn(actix::fut::wrap_future(async move { + data + .write() + .await + .remove_file(&filename).await.unwrap(); + })); + } } diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..bcc5caf --- /dev/null +++ b/src/util.rs @@ -0,0 +1,3 @@ +pub(crate) fn is_ascii_alphanumeric(s: &str) -> bool { + s.as_bytes().iter().all(|c| c.is_ascii_alphanumeric()) +} From bda6da33e8494f8d5c024c46b6f1c8b8e0e47d1f Mon Sep 17 00:00:00 2001 From: xenofem Date: Thu, 28 Apr 2022 05:18:35 -0400 Subject: [PATCH 2/2] cargo clippy and fmt --- src/download.rs | 17 ++++++++++----- src/file.rs | 47 ++++++++++++++++++++++++++++++---------- src/main.rs | 20 +++++++++-------- src/state.rs | 50 ++++++++++++++++++++++++++++++++----------- src/upload.rs | 57 ++++++++++++++++++++++++++----------------------- src/zip.rs | 6 +++--- 6 files changed, 129 insertions(+), 68 deletions(-) diff --git a/src/download.rs b/src/download.rs index 830e9fb..9183b91 100644 --- a/src/download.rs +++ b/src/download.rs @@ -4,8 +4,8 @@ use actix_web::{ body::{self, BoxBody, SizedStream}, http::{ header::{ - self, ContentDisposition, DispositionParam, - DispositionType, HeaderValue, HttpDate, EntityTag, IfUnmodifiedSince, IfMatch, IfNoneMatch, IfModifiedSince, + self, ContentDisposition, DispositionParam, DispositionType, EntityTag, HeaderValue, + HttpDate, IfMatch, IfModifiedSince, IfNoneMatch, IfUnmodifiedSince, }, StatusCode, }, @@ -52,7 +52,7 @@ impl DownloadingFile { ContentDisposition { disposition: DispositionType::Attachment, parameters: vec![DispositionParam::Filename(self.info.name)], - } + }, )); res.insert_header((header::LAST_MODIFIED, last_modified)); res.insert_header((header::ETAG, etag)); @@ -76,7 +76,12 @@ impl DownloadingFile { res.insert_header(( header::CONTENT_RANGE, - format!("bytes {}-{}/{}", offset, offset + length - 1, self.info.size), + format!( + "bytes {}-{}/{}", + offset, + offset + length - 1, + self.info.size + ), )); } else { res.insert_header((header::CONTENT_RANGE, format!("bytes */{}", length))); @@ -122,7 +127,9 @@ fn precondition_failed(req: &HttpRequest, etag: &EntityTag, last_modified: &Http fn not_modified(req: &HttpRequest, etag: &EntityTag, last_modified: &HttpDate) -> bool { match req.get_header::() { - Some(IfNoneMatch::Any) => { return true; } + Some(IfNoneMatch::Any) => { + return true; + } Some(IfNoneMatch::Items(ref items)) => { return items.iter().any(|item| item.weak_eq(etag)); } diff --git a/src/file.rs b/src/file.rs index 50e7330..0ba6bf8 100644 --- a/src/file.rs +++ b/src/file.rs @@ -1,4 +1,12 @@ -use std::{cmp, fs::File, future::Future, io::{self, Write}, path::PathBuf, pin::Pin, task::{Context, Poll, Waker}}; +use std::{ + cmp, + fs::File, + future::Future, + io::{self, Write}, + path::PathBuf, + pin::Pin, + task::{Context, Poll, Waker}, +}; use actix::Addr; use actix_web::error::{Error, ErrorInternalServerError}; @@ -53,7 +61,6 @@ impl Write for LiveFileWriter { } } - // This implementation of a file responder is copied pretty directly // from actix-files with some tweaks @@ -104,13 +111,19 @@ async fn live_file_reader_callback( use io::{Read as _, Seek as _}; let res = actix_web::web::block(move || { - trace!("reading up to {} bytes of file starting at {}", max_bytes, offset); + trace!( + "reading up to {} bytes of file starting at {}", + max_bytes, + offset + ); let mut buf = Vec::with_capacity(max_bytes); file.seek(io::SeekFrom::Start(offset))?; - let n_bytes = std::io::Read::by_ref(&mut file).take(max_bytes as u64).read_to_end(&mut buf)?; + let n_bytes = std::io::Read::by_ref(&mut file) + .take(max_bytes as u64) + .read_to_end(&mut buf)?; trace!("got {} bytes from file", n_bytes); if n_bytes == 0 { Err(io::Error::from(io::ErrorKind::UnexpectedEof)) @@ -141,12 +154,14 @@ where if size == counter { Poll::Ready(None) } else { - let inner_file = file - .take() - .expect("LiveFileReader polled after completion"); + let inner_file = file.take().expect("LiveFileReader polled after completion"); if offset >= *this.available_file_size { - trace!("offset {} has reached available file size {}, updating metadata", offset, this.available_file_size); + trace!( + "offset {} has reached available file size {}, updating metadata", + offset, + this.available_file_size + ); // If we've hit the end of what was available // last time we checked, check again *this.available_file_size = match inner_file.metadata() { @@ -165,15 +180,25 @@ where if let Ok(()) = addr.try_send(WakerMessage(cx.waker().clone())) { return Poll::Pending; } else { - return Poll::Ready(Some(Err(ErrorInternalServerError("Failed to contact file upload actor")))); + return Poll::Ready(Some(Err(ErrorInternalServerError( + "Failed to contact file upload actor", + )))); } } else { - return Poll::Ready(Some(Err(ErrorInternalServerError("File upload was not completed")))); + return Poll::Ready(Some(Err(ErrorInternalServerError( + "File upload was not completed", + )))); } } } - let max_bytes = cmp::min(65_536, cmp::min(size.saturating_sub(counter), this.available_file_size.saturating_sub(offset))) as usize; + let max_bytes = cmp::min( + 65_536, + cmp::min( + size.saturating_sub(counter), + this.available_file_size.saturating_sub(offset), + ), + ) as usize; let fut = (this.callback)(inner_file, offset, max_bytes); diff --git a/src/main.rs b/src/main.rs index c917f2a..87dcdca 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,14 +5,11 @@ mod upload; mod util; mod zip; -use std::{ - path::PathBuf, - fs::File, -}; +use std::{fs::File, path::PathBuf}; use actix::Addr; use actix_web::{ - get, middleware::Logger, web, App, HttpRequest, HttpServer, Responder, HttpResponse, + get, middleware::Logger, web, App, HttpRequest, HttpResponse, HttpServer, Responder, }; use actix_web_actors::ws; use serde::{Deserialize, Serialize}; @@ -20,7 +17,7 @@ use state::PersistentState; use time::OffsetDateTime; use tokio::sync::RwLock; -const APP_NAME: &'static str = "transbeam"; +const APP_NAME: &str = "transbeam"; pub struct UploadedFile { name: String, @@ -55,7 +52,11 @@ fn storage_dir() -> PathBuf { } #[get("/download/{file_code}")] -async fn handle_download(req: HttpRequest, path: web::Path, data: AppData) -> actix_web::Result { +async fn handle_download( + req: HttpRequest, + path: web::Path, + data: AppData, +) -> actix_web::Result { let file_code = path.into_inner(); if !util::is_ascii_alphanumeric(&file_code) { return Ok(HttpResponse::NotFound().finish()); @@ -65,8 +66,9 @@ async fn handle_download(req: HttpRequest, path: web::Path, data: AppDat if let Some(info) = info { Ok(download::DownloadingFile { file: File::open(storage_dir().join(file_code))?, - info: info.clone(), - }.into_response(&req)) + info, + } + .into_response(&req)) } else { Ok(HttpResponse::NotFound().finish()) } diff --git a/src/state.rs b/src/state.rs index 9f5f8a8..f4b0588 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,19 +1,25 @@ use std::{collections::HashMap, io::ErrorKind}; -use log::{info, error}; -use tokio::{fs::File, io::{AsyncReadExt, AsyncWriteExt}}; +use log::{error, info}; +use tokio::{ + fs::File, + io::{AsyncReadExt, AsyncWriteExt}, +}; -use crate::{DownloadableFile, storage_dir}; +use crate::{storage_dir, DownloadableFile}; const STATE_FILE_NAME: &str = "files.json"; pub(crate) mod timestamp { use core::fmt; - use serde::{Serializer, Deserializer, de::Visitor}; + use serde::{de::Visitor, Deserializer, Serializer}; use time::OffsetDateTime; - pub(crate) fn serialize(time: &OffsetDateTime, ser: S) -> Result { + pub(crate) fn serialize( + time: &OffsetDateTime, + ser: S, + ) -> Result { ser.serialize_i64(time.unix_timestamp()) } @@ -31,8 +37,13 @@ pub(crate) mod timestamp { } } - pub(crate) fn deserialize<'de, D: Deserializer<'de>>(de: D) -> Result { - Ok(OffsetDateTime::from_unix_timestamp(de.deserialize_i64(I64Visitor)?).unwrap_or_else(|_| OffsetDateTime::now_utc())) + pub(crate) fn deserialize<'de, D: Deserializer<'de>>( + de: D, + ) -> Result { + Ok( + OffsetDateTime::from_unix_timestamp(de.deserialize_i64(I64Visitor)?) + .unwrap_or_else(|_| OffsetDateTime::now_utc()), + ) } } @@ -55,18 +66,24 @@ impl PersistentState { let file = if let Ok(f) = File::open(storage_dir().join(&key)).await { f } else { - error!("Unable to open file {} referenced in persistent storage", key); + error!( + "Unable to open file {} referenced in persistent storage", + key + ); continue; }; let metadata = if let Ok(md) = file.metadata().await { md } else { - error!("Unable to get metadata for file {} referenced in persistent storage", key); + error!( + "Unable to get metadata for file {} referenced in persistent storage", + key + ); continue; }; if metadata.len() != info.size { error!("Mismatched file size for file {} referenced in persistent storage: expected {}, found {}", key, info.size, metadata.len()); - continue + continue; } filtered.insert(key, info); } @@ -83,16 +100,23 @@ impl PersistentState { } async fn save(&mut self) -> std::io::Result<()> { - File::create(storage_dir().join(STATE_FILE_NAME)).await?.write_all(&serde_json::to_vec_pretty(&self.0)?).await + File::create(storage_dir().join(STATE_FILE_NAME)) + .await? + .write_all(&serde_json::to_vec_pretty(&self.0)?) + .await } - pub(crate) async fn add_file(&mut self, key: String, file: DownloadableFile) -> std::io::Result<()> { + pub(crate) async fn add_file( + &mut self, + key: String, + file: DownloadableFile, + ) -> std::io::Result<()> { self.0.insert(key, file); self.save().await } pub(crate) fn lookup_file(&self, key: &str) -> Option { - self.0.get(key).map(|f| f.clone()) + self.0.get(key).cloned() } pub(crate) async fn remove_file(&mut self, key: &str) -> std::io::Result<()> { diff --git a/src/upload.rs b/src/upload.rs index 53191ff..a37efde 100644 --- a/src/upload.rs +++ b/src/upload.rs @@ -1,6 +1,6 @@ use std::{collections::HashSet, io::Write, task::Waker}; -use actix::{Actor, ActorContext, AsyncContext, StreamHandler, Message, Handler}; +use actix::{Actor, ActorContext, AsyncContext, Handler, Message, StreamHandler}; use actix_http::ws::{CloseReason, Item}; use actix_web_actors::ws::{self, CloseCode}; use bytes::Bytes; @@ -82,7 +82,11 @@ pub(crate) struct WakerMessage(pub Waker); impl Handler for Uploader { type Result = (); fn handle(&mut self, msg: WakerMessage, _: &mut Self::Context) { - self.writer.as_mut().map(|w| w.add_waker(msg.0)); + if let Some(w) = self.writer.as_mut() { + w.add_waker(msg.0); + } else { + error!("Got a wakeup request before creating a file"); + } } } @@ -198,36 +202,38 @@ impl Uploader { let size = zip_writer.total_size(); let download_filename = super::APP_NAME.to_owned() + &now.format(FILENAME_DATE_FORMAT)? + ".zip"; - (Box::new(zip_writer), DownloadableFile { - name: download_filename, - size, - modtime: now, - uploader: addr, - }) + ( + Box::new(zip_writer), + DownloadableFile { + name: download_filename, + size, + modtime: now, + uploader: addr, + }, + ) } else { - (Box::new(writer), DownloadableFile { - name: files[0].name.clone(), - size: files[0].size, - modtime: files[0].modtime, - uploader: addr, - }) + ( + Box::new(writer), + DownloadableFile { + name: files[0].name.clone(), + size: files[0].size, + modtime: files[0].modtime, + uploader: addr, + }, + ) }; self.writer = Some(writer); let data = self.app_data.clone(); - let storage_filename_copy = storage_filename.clone(); ctx.spawn(actix::fut::wrap_future(async move { - data - .write() + data.write() .await - .add_file( - storage_filename_copy, - downloadable_file, - ).await.unwrap(); + .add_file(storage_filename, downloadable_file) + .await + .unwrap(); })); ctx.text(self.storage_filename.as_str()); } - ws::Message::Binary(data) - | ws::Message::Continuation(Item::Last(data)) => { + ws::Message::Binary(data) | ws::Message::Continuation(Item::Last(data)) => { let result = self.handle_data(data)?; ack(ctx); return Ok(result); @@ -272,10 +278,7 @@ impl Uploader { let data = self.app_data.clone(); let filename = self.storage_filename.clone(); ctx.spawn(actix::fut::wrap_future(async move { - data - .write() - .await - .remove_file(&filename).await.unwrap(); + data.write().await.remove_file(&filename).await.unwrap(); })); } } diff --git a/src/zip.rs b/src/zip.rs index e20a45f..2ff9ecf 100644 --- a/src/zip.rs +++ b/src/zip.rs @@ -87,10 +87,10 @@ impl UploadedFile { /// through "Extra field length". fn shared_header_fields(&self, hash: Option) -> Vec { let mut fields = vec![ - 45, 0, // Minimum version required to extract: 4.5 for ZIP64 + 45, 0, // Minimum version required to extract: 4.5 for ZIP64 0b00001000, // General purpose bit flag: bit 3 - size and CRC-32 in data descriptor 0b00001000, // General purpose bit flag: bit 11 - UTF-8 filenames - 0, 0, // Compression method: none + 0, 0, // Compression method: none ]; append_value(&mut fields, fat_timestamp(self.modtime) as u64, 4); // Use 0s as a placeholder if the CRC-32 hash isn't known yet @@ -138,7 +138,7 @@ impl UploadedFile { fn central_directory_header(&self, local_header_offset: u64, hash: u32) -> Vec { let mut header = vec![ 0x50, 0x4b, 0x01, 0x02, // Central directory file header signature - 45, 3, // Made by a Unix system supporting version 4.5 + 45, 3, // Made by a Unix system supporting version 4.5 ]; header.append(&mut self.shared_header_fields(Some(hash))); append_0(&mut header, 8); // Comment length, disk number, internal attributes, DOS external attributes