diff --git a/zulip_bots/zulip_bots/lib.py b/zulip_bots/zulip_bots/lib.py index 6996122..d5d802a 100644 --- a/zulip_bots/zulip_bots/lib.py +++ b/zulip_bots/zulip_bots/lib.py @@ -14,7 +14,7 @@ from contextlib import contextmanager if False: from mypy_extensions import NoReturn -from typing import Any, Optional, List, Dict, IO, Text +from typing import Any, Optional, List, Dict, IO, Text, Set from types import ModuleType from zulip import Client @@ -52,16 +52,26 @@ class RateLimit(object): logging.error(self.error_message) sys.exit(1) +class StateHandlerError(Exception): + pass + class StateHandler(object): - def __init__(self): - # type: () -> None - self.state_ = {} # type: Dict[Text, Text] + def __init__(self, client): + # type: (Client) -> None + self._client = client self.marshal = lambda obj: json.dumps(obj) self.demarshal = lambda obj: json.loads(obj) + response = self._client.get_state() + if response['result'] == 'success': + self.state_ = response['state'] + self._modified_entries = set() # type: Set[Text] + else: + raise StateHandlerError("Error initializing state: {}".format(str(response))) def put(self, key, value): # type: (Text, Text) -> None self.state_[key] = self.marshal(value) + self._modified_entries.add(key) def get(self, key): # type: (Text) -> Text @@ -71,6 +81,16 @@ class StateHandler(object): # type: (Text) -> bool return key in self.state_ + def _save(self): + # type: () -> None + state_update = {'state': {key: self.state_[key] for key in self._modified_entries}} + if state_update: + response = self._client.update_state(state_update) + if response['result'] == 'success': + self._modified_entries.clear() + else: + raise StateHandlerError("Error updating state: {}".format(str(response))) + class ExternalBotHandler(object): def __init__(self, client, root_dir): # type: (Client, str) -> None @@ -79,7 +99,7 @@ class ExternalBotHandler(object): self._rate_limit = RateLimit(20, 5) self._client = client self._root_dir = root_dir - self.storage = StateHandler() + self.storage = StateHandler(client) try: self.user_id = user_profile['user_id'] self.full_name = user_profile['full_name'] @@ -218,6 +238,7 @@ def run_message_handler_for_bot(lib_module, quiet, config_file, bot_name): message=message, bot_handler=restricted_client ) + restricted_client.storage._save() signal.signal(signal.SIGINT, exit_gracefully) diff --git a/zulip_bots/zulip_bots/test_lib.py b/zulip_bots/zulip_bots/test_lib.py index f0d6a49..3d76e09 100755 --- a/zulip_bots/zulip_bots/test_lib.py +++ b/zulip_bots/zulip_bots/test_lib.py @@ -45,7 +45,9 @@ class BotTestCaseBase(TestCase): self.patcher = patch('zulip_bots.lib.ExternalBotHandler', autospec=True) self.MockClass = self.patcher.start() self.mock_bot_handler = self.MockClass(None, None) - self.mock_bot_handler.storage = StateHandler() + self.mock_client = MagicMock() + self.mock_client.get_state.return_value = {'result': 'success', 'state': {}} + self.mock_bot_handler.storage = StateHandler(self.mock_client) self.mock_bot_handler.send_message.return_value = {'id': 42} self.mock_bot_handler.send_reply.return_value = {'id': 42} self.message_handler = self.get_bot_message_handler() diff --git a/zulip_botserver/tests/test_server.py b/zulip_botserver/tests/test_server.py index 048303d..b870679 100644 --- a/zulip_botserver/tests/test_server.py +++ b/zulip_botserver/tests/test_server.py @@ -39,8 +39,9 @@ class BotServerTests(BotServerTestCase): check_success=False) @mock.patch('logging.error') - def test_wrong_bot_credentials(self, mock_LoggingError): - # type: (mock.Mock) -> None + @mock.patch('zulip_bots.lib.StateHandler') + def test_wrong_bot_credentials(self, mock_StateHandler, mock_LoggingError): + # type: (mock.Mock, mock.Mock) -> None available_bots = ['nonexistent-bot'] bots_config = { 'nonexistent-bot': {