diff --git a/Cargo.toml b/Cargo.toml index 848bc3b..1e794bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "jsondb" -version = "0.1.1" +version = "0.2.0" edition = "2021" authors = ["xenofem "] diff --git a/src/lib.rs b/src/lib.rs index af5e749..d226105 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,25 +5,37 @@ //! //! * The saved data includes a schema version number, and will be //! automatically migrated to newer schema versions. +//! * The live data is guarded by a built-in read-write lock which can +//! be used synchronously or from a [tokio] async environment. //! * Data is saved to the backing JSON file, in a hopefully-atomic -//! fashion, every time it's modified. -//! * All I/O operations are async using [tokio]. +//! fashion, every time a write lock is released. //! -//! Data can be represented in pretty much any format you can convince -//! [serde] to go along with, except for two restrictions: +//! Data can be represented in pretty much any form you can convince +//! [serde] to go along with, except for the following restrictions: //! +//! * Your data type must be [`Debug`] + [`Send`] + [`Sync`] + `'static`. //! * Your serialization format shouldn't include a top-level //! `version` field of its own, as this is reserved for our schema //! version tracking. //! * You can't use `#[serde(deny_unknown_fields)]`, as this conflicts //! with our use of `#[serde(flatten)]`. -use std::{cmp::Ordering, ffi::OsString, future::Future, io::ErrorKind, path::PathBuf}; +use std::{ + cmp::Ordering, + ffi::OsString, + fmt::Debug, + future::Future, + io::ErrorKind, + ops::Deref, + path::{Path, PathBuf}, + sync::Arc, +}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokio::{ fs::{rename, File}, io::{AsyncReadExt, AsyncWriteExt}, + sync::{mpsc, oneshot, OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock}, }; /// A JSON-backed “““database”””. @@ -33,8 +45,14 @@ use tokio::{ /// written to disk when it's updated (we attempt to make saves atomic /// using the `rename(2)` function). pub struct JsonDb { - path: PathBuf, - data: T, + channel: mpsc::UnboundedSender>, +} + +#[derive(Debug)] +enum Request { + Read(oneshot::Sender>), + Write(oneshot::Sender>), + Flush(oneshot::Sender<()>), } /// Schema for a JSON-backed database. @@ -46,7 +64,7 @@ pub struct JsonDb { /// from a JSON object containing a `version` field along with the /// other fields of the corresponding schema version; earlier versions /// will be migrated to the current version automatically. -pub trait Schema: DeserializeOwned + Serialize { +pub trait Schema: Send + Sync + Debug + DeserializeOwned + Serialize + 'static { /// Previous schema that can be migrated into the new schema type Prev: Schema + Into; @@ -67,7 +85,7 @@ pub trait Schema: DeserializeOwned + Serialize { /// /// Implementing this will automatically implement [`Schema`], with /// version number `0` and `Self` as the previous version. -pub trait SchemaV0: DeserializeOwned + Serialize { +pub trait SchemaV0: Send + Sync + Debug + DeserializeOwned + Serialize + 'static { /// Set this to false if your version 0 is a pre-`JsonDb` schema /// that does not include a version number. const EXPECT_VERSION_NUMBER: bool = true; @@ -115,23 +133,44 @@ impl JsonDb { /// Load a [`JsonDb`] from a given file, creating it and /// initializing it with the schema's default value if it does not /// exist. - pub async fn load(path: PathBuf) -> Result { + pub async fn load(path: PathBuf) -> Result, Error> { Self::load_or_else(path, || std::future::ready(Ok(T::default()))).await } } +async fn save(data: &T, path: &Path) -> Result<(), Error> { + let mut temp_file_name = OsString::from("."); + temp_file_name.push(path.file_name().unwrap()); + temp_file_name.push(".tmp"); + let temp_file_path = path.parent().unwrap().join(temp_file_name); + { + let mut temp_file = File::create(&temp_file_path).await?; + temp_file + .write_all(&serde_json::to_vec_pretty(&Repr { + version: T::VERSION, + data, + })?) + .await?; + temp_file.sync_all().await?; + } + // Atomically update the actual file + rename(&temp_file_path, &path).await?; + + Ok(()) +} + impl JsonDb { /// Load a [`JsonDb`] from a given file, creating it and /// initializing it with the provided default value if it does not /// exist. - pub async fn load_or(path: PathBuf, default: T) -> Result { + pub async fn load_or(path: PathBuf, default: T) -> Result, Error> { Self::load_or_else(path, || std::future::ready(Ok(default))).await } /// Load a [`JsonDb`] from a given file, creating it and /// initializing it with the provided function if it does not /// exist. - pub async fn load_or_else(path: PathBuf, default: F) -> Result + pub async fn load_or_else(path: PathBuf, default: F) -> Result, Error> where F: FnOnce() -> Fut, Fut: Future>, @@ -151,59 +190,110 @@ impl JsonDb { } } }; - let mut db = JsonDb { path, data }; - // Always save in case we've run migrations - db.save().await?; - Ok(db) + let (request_send, mut request_recv) = mpsc::unbounded_channel::>(); + tokio::spawn(async move { + save(&data, &path).await.expect("Failed to save data"); + let lock = Arc::new(RwLock::new(data)); + while let Some(request) = request_recv.recv().await { + match request { + Request::Read(response) => { + response + .send(lock.clone().read_owned().await) + .expect("Failed to send read guard"); + } + Request::Write(response) => { + response + .send(lock.clone().write_owned().await) + .expect("Failed to send write guard"); + save(lock.read().await.deref(), &path) + .await + .expect("Failed to save data"); + } + Request::Flush(response) => { + // Once we get around to handling this + // request, we've already flushed data from + // any previously-issued write requests + response + .send(()) + .expect("Failed to send flush confirmation"); + } + } + } + }); + Ok(JsonDb { + channel: request_send, + }) } - async fn save(&mut self) -> Result<(), Error> { - let mut temp_file_name = OsString::from("."); - temp_file_name.push(self.path.file_name().unwrap()); - temp_file_name.push(".tmp"); - let temp_file_path = self.path.parent().unwrap().join(temp_file_name); - { - let mut temp_file = File::create(&temp_file_path).await?; - temp_file - .write_all(&serde_json::to_vec_pretty(&Repr { - version: T::VERSION, - data: &self.data, - })?) - .await?; - temp_file.sync_all().await?; - } - // Atomically update the actual file - rename(&temp_file_path, &self.path).await?; - - Ok(()) + fn request_read(&self) -> oneshot::Receiver> { + let (send, recv) = oneshot::channel(); + self.channel + .send(Request::Read(send)) + .expect("Failed to send read lock request"); + recv } - /// Borrow an immutable reference to the wrapped data - pub fn read(&self) -> &T { - &self.data + /// Take a read lock on the wrapped data. + pub async fn read(&self) -> OwnedRwLockReadGuard { + self.request_read() + .await + .expect("Failed to receive read lock") } - /// Modify the wrapped data in-place, atomically writing it back - /// to disk afterwards. - pub async fn write(&mut self, updater: U) -> Result - where - U: FnOnce(&mut T) -> V, - { - let result = updater(&mut self.data); - self.save().await?; - Ok(result) + /// Synchronous version of [`read`][Self::read]. + pub fn blocking_read(&self) -> OwnedRwLockReadGuard { + self.request_read() + .blocking_recv() + .expect("Failed to receive read lock") } - /// Modify the wrapped data in-place using asynchronous code, - /// atomically writing it back to disk afterwards. - pub async fn write_async(&mut self, updater: U) -> Result - where - U: FnOnce(&mut T) -> Fut, - Fut: Future, - { - let result = updater(&mut self.data).await; - self.save().await?; - Ok(result) + fn request_write(&self) -> oneshot::Receiver> { + let (send, recv) = oneshot::channel(); + self.channel + .send(Request::Write(send)) + .expect("Failed to send write lock request"); + recv + } + + /// Take a write lock on the wrapped data. When the write guard is + /// dropped, it triggers an atomic write of the updated data back + /// to disk. + pub async fn write(&self) -> OwnedRwLockWriteGuard { + self.request_write() + .await + .expect("Failed to receive write lock") + } + + /// Synchronous version of [`write`][Self::write]. + pub fn blocking_write(&self) -> OwnedRwLockWriteGuard { + self.request_write() + .blocking_recv() + .expect("Failed to receive write lock") + } + + fn request_flush(&self) -> oneshot::Receiver<()> { + let (send, recv) = oneshot::channel(); + self.channel + .send(Request::Flush(send)) + .expect("Failed to send flush request"); + recv + } + + /// Wait for data to finish flushing to disk. Every call to + /// [`read`][Self::read] or [`write`][Self::write], or their + /// blocking equivalents, also waits for data to be flushed before + /// returning a guard. + pub async fn flush(&self) { + self.request_flush() + .await + .expect("Failed to receive flush confirmation"); + } + + /// Synchronous version of [`flush`][Self::flush]. + pub fn blocking_flush(&self) { + self.request_flush() + .blocking_recv() + .expect("Failed to receive flush confirmation"); } } @@ -352,34 +442,82 @@ mod tests { } #[tokio::test] - async fn load_write_migrate() { + async fn async_load_write_migrate() { let dir = tempdir().unwrap(); let db_file = dir.path().join("test.json"); { - let mut db0: JsonDb = JsonDb::load(db_file.clone()).await.unwrap(); + let db0: JsonDb = JsonDb::load(db_file.clone()).await.unwrap(); + db0.flush().await; let value: Value = serde_json::from_reader(File::open(&db_file).unwrap()).unwrap(); assert_eq!(value["version"], 0); assert_eq!(&value["name"], ""); - db0.write(|ref mut val| { - val.name = String::from("mefonex"); - }) - .await - .unwrap(); + { + let mut writer = db0.write().await; + writer.name = String::from("mefonex"); + } + { + let reader = db0.read().await; + assert_eq!(reader.name, "mefonex"); + } + // Reading also awaits a flush let value: Value = serde_json::from_reader(File::open(&db_file).unwrap()).unwrap(); assert_eq!(&value["name"], "mefonex"); } { - let mut db2: JsonDb = JsonDb::load(db_file.clone()).await.unwrap(); + let db2: JsonDb = JsonDb::load(db_file.clone()).await.unwrap(); + db2.flush().await; let value: Value = serde_json::from_reader(File::open(&db_file).unwrap()).unwrap(); assert_eq!(value["version"], 2); assert_eq!(&value["name"], "mefonex"); assert_eq!(value["gender"], Value::Null); assert_eq!(&value["last_updated"], "1970-01-01T00:00:00Z"); - db2.write(|ref mut val| { - val.last_updated = OffsetDateTime::from_unix_timestamp(1660585638).unwrap(); - }) - .await - .unwrap(); + { + let mut writer = db2.write().await; + writer.last_updated = OffsetDateTime::from_unix_timestamp(1660585638).unwrap(); + } + db2.flush().await; + let value: Value = serde_json::from_reader(File::open(&db_file).unwrap()).unwrap(); + assert_eq!(&value["last_updated"], "2022-08-15T17:47:18Z"); + } + } + + #[test] + fn blocking_load_write_migrate() { + let rt = tokio::runtime::Runtime::new().unwrap(); + + let dir = tempdir().unwrap(); + let db_file = dir.path().join("test.json"); + { + let db0: JsonDb = rt.block_on(JsonDb::load(db_file.clone())).unwrap(); + db0.blocking_flush(); + let value: Value = serde_json::from_reader(File::open(&db_file).unwrap()).unwrap(); + assert_eq!(value["version"], 0); + assert_eq!(&value["name"], ""); + { + let mut writer = db0.blocking_write(); + writer.name = String::from("mefonex"); + } + { + let reader = db0.blocking_read(); + assert_eq!(reader.name, "mefonex"); + } + // Reading also waits for a flush + let value: Value = serde_json::from_reader(File::open(&db_file).unwrap()).unwrap(); + assert_eq!(&value["name"], "mefonex"); + } + { + let db2: JsonDb = rt.block_on(JsonDb::load(db_file.clone())).unwrap(); + db2.blocking_flush(); + let value: Value = serde_json::from_reader(File::open(&db_file).unwrap()).unwrap(); + assert_eq!(value["version"], 2); + assert_eq!(&value["name"], "mefonex"); + assert_eq!(value["gender"], Value::Null); + assert_eq!(&value["last_updated"], "1970-01-01T00:00:00Z"); + { + let mut writer = db2.blocking_write(); + writer.last_updated = OffsetDateTime::from_unix_timestamp(1660585638).unwrap(); + } + db2.blocking_flush(); let value: Value = serde_json::from_reader(File::open(&db_file).unwrap()).unwrap(); assert_eq!(&value["last_updated"], "2022-08-15T17:47:18Z"); }