diff --git a/zulip_bots/zulip_bots/lib.py b/zulip_bots/zulip_bots/lib.py index 05025ff..d7f9023 100644 --- a/zulip_bots/zulip_bots/lib.py +++ b/zulip_bots/zulip_bots/lib.py @@ -8,7 +8,8 @@ import time import re -from typing import Any, Optional, List, Dict, IO, Text +from contextlib import contextmanager +from typing import Any, Iterator, Optional, List, Dict, IO, Set, Text from typing_extensions import Protocol from zulip import Client, ZulipError @@ -83,6 +84,50 @@ class BotStorage(Protocol): def contains(self, key: Text) -> bool: ... +class CachedStorage: + def __init__(self, parent_storage: BotStorage, init_data: Dict[str, Any]) -> None: + # CachedStorage is implemented solely for the context manager of any BotHandler. + # It has a parent_storage that is responsible of communicating with the database + # 1. when certain data is not cached; + # 2. when the data need to be flushed to the database. + # It can be initialized with the given data. + self._parent_storage = parent_storage + self._cache = init_data + self._dirty_keys: Set[str] = set() + + def put(self, key: Text, value: Any) -> None: + # In the cached storage, values being put to the storage is not flushed to the parent storage. + # It will be marked dirty until it get flushed. + self._cache[key] = value + self._dirty_keys.add(key) + + def get(self, key: Text) -> Any: + # Unless the key is not found in the cache, the cached storage will not lookup the parent storage. + if key in self._cache: + return self._cache[key] + else: + value = self._parent_storage.get(key) + self._cache[key] = value + return value + + def flush(self) -> None: + # Flush the data to the parent storage. + # Data that are not marked dirty will be omitted. + # This should be manually called when CachedStorage is not used with a context manager. + while len(self._dirty_keys) > 0: + key = self._dirty_keys.pop() + self._parent_storage.put(key, self._cache[key]) + + def flush_one(self, key: Text) -> None: + self._dirty_keys.remove(key) + self._parent_storage.put(key, self._cache[key]) + + def contains(self, key: Text) -> bool: + if key in self._cache: + return True + else: + return self._parent_storage.contains(key) + class StateHandler: def __init__(self, client: Client) -> None: self._client = client @@ -111,6 +156,16 @@ class StateHandler: def contains(self, key: Text) -> bool: return key in self.state_ +@contextmanager +def use_storage(storage: BotStorage, keys: List[Text]) -> Iterator[BotStorage]: + # The context manager for StateHandler that minimizes the number of round-trips to the server. + # It will fetch all the data using the specified keys and store them to + # a CachedStorage that will not communicate with the server until manually + # calling flush or getting some values that are not previously fetched. + data = {key: storage.get(key) for key in keys} + cache = CachedStorage(storage, data) + yield storage + cache.flush() class BotHandler(Protocol):