use std::{collections::HashMap, io::ErrorKind, path::PathBuf, str::FromStr}; use log::{debug, error, info, warn}; use rand::{ distributions::{Alphanumeric, DistString}, thread_rng, Rng, }; use serde::{Deserialize, Serialize}; use time::OffsetDateTime; use tokio::{ fs::File, io::{AsyncReadExt, AsyncWriteExt}, }; const STATE_FILE_NAME: &str = "files.json"; const DEFAULT_STORAGE_DIR: &str = "storage"; const DEFAULT_MAX_LIFETIME: u32 = 30; const GIGA: u64 = 1024 * 1024 * 1024; const DEFAULT_MAX_UPLOAD_SIZE: u64 = 16 * GIGA; const DEFAULT_MAX_STORAGE_SIZE: u64 = 64 * GIGA; pub fn gen_storage_code() -> String { if std::env::var("TRANSBEAM_MNEMONIC_CODES").as_deref() == Ok("false") { Alphanumeric.sample_string(&mut thread_rng(), 8) } else { mnemonic::to_string(thread_rng().gen::<[u8; 4]>()) } } pub fn is_valid_storage_code(s: &str) -> bool { s.as_bytes() .iter() .all(|c| c.is_ascii_alphanumeric() || c == &b'-') } pub(crate) fn storage_dir() -> PathBuf { PathBuf::from( std::env::var("TRANSBEAM_STORAGE_DIR") .unwrap_or_else(|_| String::from(DEFAULT_STORAGE_DIR)), ) } fn parse_env_var(var: &str, default: T) -> T { std::env::var(var) .ok() .and_then(|val| val.parse::().ok()) .unwrap_or(default) } pub(crate) fn max_lifetime() -> u32 { parse_env_var("TRANSBEAM_MAX_LIFETIME", DEFAULT_MAX_LIFETIME) } pub(crate) fn max_single_size() -> u64 { parse_env_var("TRANSBEAM_MAX_UPLOAD_SIZE", DEFAULT_MAX_UPLOAD_SIZE) } pub(crate) fn max_total_size() -> u64 { parse_env_var("TRANSBEAM_MAX_STORAGE_SIZE", DEFAULT_MAX_STORAGE_SIZE) } #[derive(Clone, Deserialize, Serialize)] pub struct StoredFile { pub name: String, pub size: u64, #[serde(with = "timestamp")] pub modtime: OffsetDateTime, #[serde(with = "timestamp")] pub expiry: OffsetDateTime, } pub(crate) mod timestamp { use core::fmt; use serde::{de::Visitor, Deserializer, Serializer}; use time::OffsetDateTime; pub(crate) fn serialize( time: &OffsetDateTime, ser: S, ) -> Result { ser.serialize_i64(time.unix_timestamp()) } struct I64Visitor; impl<'de> Visitor<'de> for I64Visitor { type Value = i64; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { write!(formatter, "an integer") } fn visit_i64(self, v: i64) -> Result { Ok(v) } fn visit_u64(self, v: u64) -> Result { Ok(v as i64) } } pub(crate) fn deserialize<'de, D: Deserializer<'de>>( de: D, ) -> Result { Ok( OffsetDateTime::from_unix_timestamp(de.deserialize_i64(I64Visitor)?) .unwrap_or_else(|_| OffsetDateTime::now_utc()), ) } } async fn is_valid_entry(key: &str, info: &StoredFile) -> bool { if info.expiry < OffsetDateTime::now_utc() { info!("File {} has expired", key); return false; } let file = if let Ok(f) = File::open(storage_dir().join(&key)).await { f } else { error!( "Unable to open file {} referenced in persistent storage", key ); return false; }; let metadata = if let Ok(md) = file.metadata().await { md } else { error!( "Unable to get metadata for file {} referenced in persistent storage", key ); return false; }; if metadata.len() != info.size { error!("Mismatched file size for file {} referenced in persistent storage: expected {}, found {}", key, info.size, metadata.len()); return false; } true } pub(crate) struct FileStore(HashMap); impl FileStore { pub(crate) async fn load() -> std::io::Result { let open_result = File::open(storage_dir().join(STATE_FILE_NAME)).await; match open_result { Ok(mut f) => { let mut buf = String::new(); f.read_to_string(&mut buf).await?; let map: HashMap = serde_json::from_str(&buf)?; info!("Loaded {} file entries from persistent storage", map.len()); let mut filtered: HashMap = HashMap::new(); for (key, info) in map.into_iter() { // Handle this case separately, because we don't // want to try to delete it if it's not the sort // of path we're expecting if !is_valid_storage_code(&key) { error!("Invalid key in persistent storage: {}", key); continue; } if is_valid_entry(&key, &info).await { filtered.insert(key, info); } else { info!("Deleting file {}", key); if let Err(e) = tokio::fs::remove_file(storage_dir().join(&key)).await { warn!("Failed to delete file {}: {}", key, e); } } } let mut loaded = Self(filtered); loaded.save().await?; Ok(loaded) } Err(e) => { if let ErrorKind::NotFound = e.kind() { Ok(Self(HashMap::new())) } else { Err(e) } } } } fn total_size(&self) -> u64 { self.0.iter().fold(0, |acc, (_, f)| acc + f.size) } async fn save(&mut self) -> std::io::Result<()> { info!("saving updated state: {} entries", self.0.len()); File::create(storage_dir().join(STATE_FILE_NAME)) .await? .write_all(&serde_json::to_vec_pretty(&self.0)?) .await } /// Attempts to add a file to the store. Returns an I/O error if /// something's broken, or a u64 of the maximum allowed file size /// if the file was too big, or a unit if everything worked. pub(crate) async fn add_file( &mut self, key: String, file: StoredFile, ) -> std::io::Result> { let remaining_size = max_total_size().saturating_sub(self.total_size()); let allowed_size = std::cmp::min(remaining_size, max_single_size()); if file.size > allowed_size { return Ok(Err(allowed_size)); } self.0.insert(key, file); self.save().await.map(Ok) } pub(crate) fn lookup_file(&self, key: &str) -> Option { self.0.get(key).cloned() } pub(crate) async fn remove_file(&mut self, key: &str) -> std::io::Result<()> { debug!("removing entry {} from state", key); self.0.remove(key); self.save().await } pub(crate) async fn remove_expired_files(&mut self) -> std::io::Result<()> { info!("Checking for expired files"); let now = OffsetDateTime::now_utc(); for (key, file) in std::mem::take(&mut self.0).into_iter() { if file.expiry > now { self.0.insert(key, file); } else { info!("Deleting expired file {}", key); if let Err(e) = tokio::fs::remove_file(storage_dir().join(&key)).await { warn!("Failed to delete expired file {}: {}", key, e); } } } self.save().await } }