# Source code for requests_cache.backends.mongodb

"""MongoDB cache backend. For usage details, see :ref:`Backends: MongoDB <mongodb>`.

.. automodsumm:: requests_cache.backends.mongodb
   :classes-only:
   :nosignatures:
"""
from datetime import timedelta
from logging import getLogger
from typing import Iterable, Mapping, Optional, Union

from pymongo import MongoClient
from pymongo.errors import OperationFailure

from .._utils import get_valid_kwargs
from ..policy.expiration import NEVER_EXPIRE, get_expiration_seconds
from ..serializers import SerializerType, bson_document_serializer
from . import BaseCache, BaseStorage

# Module-level logger, named after this module; used for TTL index messages below
logger = getLogger(__name__)


class MongoCache(BaseCache):
    """MongoDB cache backend.
    By default, responses are only partially serialized into a MongoDB-compatible document format.

    Args:
        db_name: Database name
        connection: :py:class:`pymongo.MongoClient` object to reuse instead of creating a new one
        decode_content: Forwarded to the ``responses`` storage only
        serializer: Serializer for the ``responses`` storage. If not provided, the storage's own
            default is used. The ``redirects`` storage is always created with ``serializer=None``.
        kwargs: Additional keyword arguments for :py:class:`pymongo.mongo_client.MongoClient`
    """

    def __init__(
        self,
        db_name: str = 'http_cache',
        connection: Optional[MongoClient] = None,
        decode_content: bool = True,
        serializer: Optional[SerializerType] = None,
        **kwargs,
    ):
        super().__init__(cache_name=db_name, **kwargs)

        # Only pass a serializer through when one was given, so that the storage class's
        # default serializer applies otherwise
        skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs
        self.responses: MongoDict = MongoDict(
            db_name,
            collection_name='responses',
            connection=connection,
            decode_content=decode_content,
            **skwargs,
        )
        # Reuse the client held by the responses storage instead of opening a second one
        self.redirects: MongoDict = MongoDict(
            db_name,
            collection_name='redirects',
            connection=self.responses.connection,
            serializer=None,
            **kwargs,
        )

    def get_ttl(self) -> Optional[int]:
        """Get the currently defined TTL value in seconds, if any"""
        return self.responses.get_ttl()

    def set_ttl(self, ttl: Optional[Union[int, timedelta]], overwrite: bool = False):
        """Create or update a TTL index.

        Notes:
            * This will have no effect if TTL is already set
            * To overwrite an existing TTL index, use ``overwrite=True``
            * This may take some time to complete, depending on the size of your cache
            * Use ``ttl=None, overwrite=True`` to remove the TTL index
        """
        self.responses.set_ttl(ttl, overwrite=overwrite)
class MongoDict(BaseStorage):
    """A dictionary-like interface for a MongoDB collection

    Args:
        db_name: Database name
        collection_name: Collection name
        connection: :py:class:`pymongo.MongoClient` object to reuse instead of creating a new one
        serializer: Serializer passed on to the base storage class
        kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient`
    """

    def __init__(
        self,
        db_name: str,
        collection_name: str = 'http_cache',
        connection: Optional[MongoClient] = None,
        serializer: Optional[SerializerType] = bson_document_serializer,
        **kwargs,
    ):
        super().__init__(serializer=serializer, **kwargs)
        # kwargs may contain options meant for other classes; keep only those that
        # MongoClient accepts
        connection_kwargs = get_valid_kwargs(MongoClient.__init__, kwargs)
        self.connection = connection or MongoClient(**connection_kwargs)
        self.collection = self.connection[db_name][collection_name]

    def get_ttl(self) -> Optional[int]:
        """Get the currently defined TTL value in seconds, if any"""
        idx_info = self.collection.index_information().get('ttl_idx', {})
        return idx_info.get('expireAfterSeconds')

    def set_ttl(self, ttl: Optional[Union[int, timedelta]], overwrite: bool = False):
        """Create or update a TTL index.

        Any :py:class:`pymongo.errors.OperationFailure` raised while updating the index (for
        example, attempting to change an existing TTL index without ``overwrite=True``) is logged
        as a warning instead of being raised.

        Args:
            ttl: Expiration time, as a number of seconds or a ``timedelta``
            overwrite: Drop any existing TTL index before (optionally) creating a new one
        """
        try:
            self._set_ttl(get_expiration_seconds(ttl), overwrite=overwrite)
        except OperationFailure:
            logger.warning('Failed to update TTL index', exc_info=True)

    def _set_ttl(self, ttl: int, overwrite: bool = False):
        """Drop the existing TTL index (if requested and present), then create a new one if
        ``ttl`` is a positive expiration value.
        """
        # Dropping a nonexistent index raises OperationFailure, which would abort this method
        # before the new index is created; so only drop the index if it is actually there.
        if overwrite and 'ttl_idx' in self.collection.index_information():
            self.collection.drop_index('ttl_idx')
            logger.info('Dropped TTL index')

        if ttl and ttl != NEVER_EXPIRE:
            logger.info('Creating TTL index for %s seconds', ttl)
            self.collection.create_index('created_at', name='ttl_idx', expireAfterSeconds=ttl)

    def __getitem__(self, key):
        result = self.collection.find_one({'_id': key})
        if result is None:
            raise KeyError(key)
        # Values stored under a 'data' key are unwrapped; otherwise the document itself is used
        value = result['data'] if 'data' in result else result
        return self.deserialize(key, value)

    def __setitem__(self, key, value):
        """If ``value`` is already a dict, its values will be stored under top-level keys.
        Otherwise, it will be stored under a 'data' key.
        """
        value = self.serialize(value)
        if not isinstance(value, Mapping):
            value = {'data': value}
        self.collection.replace_one({'_id': key}, value, upsert=True)

    def __delitem__(self, key):
        # Project only '_id', since the result is just used to tell whether the key existed
        result = self.collection.find_one_and_delete({'_id': key}, {'_id': True})
        if result is None:
            raise KeyError(key)

    def __len__(self):
        return self.collection.estimated_document_count()

    def __iter__(self):
        # Fetch only the '_id' field, since only keys are yielded
        for d in self.collection.find({}, {'_id': True}):
            yield d['_id']

    def bulk_delete(self, keys: Iterable[str]):
        """Delete multiple keys from the cache. Does not raise errors for missing keys."""
        self.collection.delete_many({'_id': {'$in': list(keys)}})

    def clear(self):
        """Remove all items by dropping the collection"""
        self.collection.drop()

    def close(self):
        """Close the client connection (including one passed in via ``connection``)"""
        self.connection.close()