diff --git a/zulip_bots/zulip_bots/lib.py b/zulip_bots/zulip_bots/lib.py index 9c4928a..8ceb5e7 100644 --- a/zulip_bots/zulip_bots/lib.py +++ b/zulip_bots/zulip_bots/lib.py @@ -65,11 +65,7 @@ class StateHandler(object): self._client = client self.marshal = lambda obj: json.dumps(obj) self.demarshal = lambda obj: json.loads(obj) - response = self._client.get_storage() - if response['result'] == 'success': - self.state_ = response['storage'] - else: - raise StateHandlerError("Error initializing state: {}".format(str(response))) + self.state_ = dict() # type: Dict[Text, Any] def put(self, key, value): # type: (Text, Any) -> None @@ -80,7 +76,16 @@ class StateHandler(object): def get(self, key): # type: (Text) -> Any - return self.demarshal(self.state_[key]) + if key in self.state_: + return self.demarshal(self.state_[key]) + + response = self._client.get_storage(keys=(key,)) + if response['result'] != 'success': + raise StateHandlerError("Error fetching state: {}".format(str(response))) + + marshalled_value = response['storage'][key] + self.state_[key] = marshalled_value + return self.demarshal(marshalled_value) def contains(self, key): # type: (Text) -> bool diff --git a/zulip_bots/zulip_bots/lib_tests.py b/zulip_bots/zulip_bots/lib_tests.py index 5116c5a..a974845 100644 --- a/zulip_bots/zulip_bots/lib_tests.py +++ b/zulip_bots/zulip_bots/lib_tests.py @@ -71,6 +71,34 @@ class LibTest(TestCase): val = state_handler.get('key') self.assertEqual(val, [1, 2, 3]) + def test_state_handler(self): + client = MagicMock() + + state_handler = StateHandler(client) + client.get_storage.assert_not_called() + + client.update_storage = MagicMock(return_value=dict(result='success')) + state_handler.put('key', [1, 2, 3]) + client.update_storage.assert_called_with(dict(storage=dict(key='[1, 2, 3]'))) + + val = state_handler.get('key') + client.get_storage.assert_not_called() + self.assertEqual(val, [1, 2, 3]) + + # force us to get non-cached values + client.get_storage = MagicMock(return_value=dict( + result='success', + storage=dict(non_cached_key='[5]'))) + val = state_handler.get('non_cached_key') + client.get_storage.assert_called_with(keys=('non_cached_key',)) + self.assertEqual(val, [5]) + + # value must already be cached + client.get_storage = MagicMock() + val = state_handler.get('non_cached_key') + client.get_storage.assert_not_called() + self.assertEqual(val, [5]) + def test_send_reply(self): client = FakeClient() profile = client.get_profile()