Source code for requests_cache.models.response

from datetime import datetime, timedelta, timezone
from logging import getLogger
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import attr
from attr import define, field
from requests import PreparedRequest, Response
from requests.cookies import RequestsCookieJar
from requests.structures import CaseInsensitiveDict
from urllib3._collections import HTTPHeaderDict

from ..policy import ExpirationTime, get_expiration_datetime
from . import CachedHTTPResponse, CachedRequest

# strftime pattern for human-readable timestamps (includes the time zone name via %Z)
DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S %Z'  # Format used for __str__ only
# Type alias for a list of (str, str) pairs
# NOTE(review): presumably header (name, value) pairs; not referenced in the visible code
HeaderList = List[Tuple[str, str]]
# Module-level logger, named after this module
logger = getLogger(__name__)

@define(auto_attribs=False, slots=False)
class CachedResponse(Response):
    """A class that emulates :py:class:`requests.Response`, with some additional optimizations
    for serialization.

    Fields are declared explicitly with ``attr.field`` (``auto_attribs=False``), so plain
    annotated class attributes such as ``cache_key`` are not attrs fields.
    """

    _content: bytes = field(default=None)
    _next: Optional[CachedRequest] = field(default=None)
    cache_key: Optional[str] = None  # Not serialized; set by BaseCache.get_response()
    cookies: RequestsCookieJar = field(factory=RequestsCookieJar)
    created_at: datetime = field(factory=datetime.utcnow)
    elapsed: timedelta = field(factory=timedelta)
    encoding: str = field(default=None)
    expires: Optional[datetime] = field(default=None)
    headers: CaseInsensitiveDict = field(factory=CaseInsensitiveDict)
    history: List['CachedResponse'] = field(factory=list)  # type: ignore
    raw: CachedHTTPResponse = field(factory=CachedHTTPResponse, repr=False)
    reason: str = field(default=None)
    request: CachedRequest = field(factory=CachedRequest)  # type: ignore
    status_code: int = field(default=0)
    url: str = field(default=None)

    def __attrs_post_init__(self):
        """Re-initialize raw response body after deserialization"""
        # Refill the raw response body from the stored content if it has none yet
        if self.raw._body is None and self._content is not None:
            self.raw.reset(self._content)
        # Mirror the response headers onto the raw response if it has none of its own
        if not self.raw.headers:
            self.raw.headers = HTTPHeaderDict(self.headers)

    @classmethod
    def from_response(
        cls,
        original_response: Union[Response, 'CachedResponse'],
        expires: Optional[datetime] = None,
        **kwargs,
    ) -> 'CachedResponse':
        """Create a CachedResponse based on an original Response or another CachedResponse object

        Args:
            original_response: Response to copy from
            expires: Expiration datetime to set on the new object
            kwargs: Additional keyword arguments for the constructor; only used when
                ``original_response`` is not already a :py:class:`CachedResponse`

        Returns:
            A new :py:class:`CachedResponse`
        """
        # Already a cached response: just copy it with the new expiration
        if isinstance(original_response, CachedResponse):
            return attr.evolve(original_response, expires=expires)
        obj = cls(expires=expires, **kwargs)

        # Copy basic attributes
        for k in Response.__attrs__:
            setattr(obj, k, getattr(original_response, k, None))

        # Store request, raw response, and next response (if it's a redirect response)
        obj.request = CachedRequest.from_request(original_response.request)
        obj.raw = CachedHTTPResponse.from_response(original_response)
        # NOTE(review): the operand of this expression was lost in the source text; it is
        # reconstructed here as ``original_response.next`` -- confirm against the upstream file
        obj._next = (
            CachedRequest.from_request(original_response.next) if original_response.next else None
        )

        # Store response body, which will have been read & decoded by requests.Response by now
        obj._content = original_response.content

        # Copy redirect history, if any; avoid recursion by not copying redirects of redirects
        obj.history = []
        if not obj.is_redirect:
            for redirect in original_response.history:
                obj.history.append(cls.from_response(redirect))

        return obj

    @property
    def _content_consumed(self) -> bool:
        """For compatibility with requests.Response; will always be True for a cached response"""
        return True

    @_content_consumed.setter
    def _content_consumed(self, value: bool):
        # Assignments are accepted and ignored, so the getter keeps returning True
        pass

    @property
    def from_cache(self) -> bool:
        """Always True for this class"""
        return True

    @property
    def is_expired(self) -> bool:
        """Determine if this cached response is expired"""
        # ``expires`` is compared against a naive UTC "now"; no expiration means never expired
        return self.expires is not None and datetime.utcnow() >= self.expires

    @property
    def ttl(self) -> Optional[int]:
        """Get time to expiration in seconds, or ``None`` if there is no expiration or the
        response is already expired
        """
        if self.expires is None or self.is_expired:
            return None
        delta = self.expires - datetime.utcnow()
        return int(delta.total_seconds())

    @property
    def next(self) -> Optional[PreparedRequest]:
        """Returns a PreparedRequest for the next request in a redirect chain, if there is one."""
        return self._next.prepare() if self._next else None

    def revalidate(self, expire_after: ExpirationTime) -> bool:
        """Set a new expiration for this response, and determine if it is now expired

        Args:
            expire_after: New expiration value, converted with ``get_expiration_datetime()``

        Returns:
            ``True`` if the response is expired under the new expiration
        """
        self.expires = get_expiration_datetime(expire_after)
        return self.is_expired

    @property
    def size(self) -> int:
        """Get the size of the response body in bytes"""
        return len(self.content) if self.content else 0

    def __getstate__(self):
        """Override pickling behavior from ``requests.Response.__getstate__``"""
        return self.__dict__

    def __setstate__(self, state):
        """Override pickling behavior from ``requests.Response.__setstate__``"""
        for name, value in state.items():
            setattr(self, name, value)

    def __str__(self):
        return (
            f'request: {self.request}, response: {self.status_code} '
            f'({format_file_size(self.size)}), created: {format_datetime(self.created_at)}, '
            f'expires: {format_datetime(self.expires)} ({"stale" if self.is_expired else "fresh"})'
        )
def format_datetime(value: Optional[datetime]) -> str:
    """Get a formatted datetime string in the local time zone

    Args:
        value: Datetime to format; a naive datetime is treated as UTC

    Returns:
        ``"N/A"`` if ``value`` is falsy (e.g. ``None``); otherwise the datetime converted to the
        local time zone and formatted with ``DATETIME_FORMAT``
    """
    if not value:
        return "N/A"
    # Attach UTC to naive datetimes so the local time zone conversion below is well-defined
    if value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)
    return value.astimezone().strftime(DATETIME_FORMAT)
def format_file_size(n_bytes: int) -> str:
    """Convert a file size in bytes into a human-readable format

    Args:
        n_bytes: Size in bytes; a falsy value (``None`` or ``0``) is treated as 0

    Returns:
        A whole number of ``bytes`` for sizes below 1024, otherwise a value with two decimal
        places in ``KiB``, ``MiB``, or ``GiB``. ``GiB`` is the largest unit used, so larger
        sizes are reported as 1024 GiB or more.
    """
    size = float(n_bytes or 0)
    if size < 1024:
        return f'{int(size)} bytes'

    # Scale down one unit at a time, stopping at the first unit where the value fits
    for unit in ('KiB', 'MiB'):
        size /= 1024
        if size < 1024:
            return f'{size:.2f} {unit}'

    # Anything that did not fit in MiB is reported in GiB, regardless of magnitude
    return f'{size / 1024:.2f} GiB'
def set_response_defaults(
    response: Union[Response, CachedResponse], cache_key: Optional[str] = None
) -> Union[Response, CachedResponse]:
    """Set some default CachedResponse values on a requests.Response object, so they can be
    expected to always be present

    Args:
        response: Response to update in place; a :py:class:`CachedResponse` is left unchanged
        cache_key: Value to set as ``response.cache_key``

    Returns:
        The same ``response`` object that was passed in
    """
    if not isinstance(response, CachedResponse):
        response.cache_key = cache_key  # type: ignore
        response.created_at = None  # type: ignore
        response.expires = None  # type: ignore
        response.from_cache = False  # type: ignore
        response.is_expired = False  # type: ignore
    return response