관리-도구
편집 파일: __init__.py
"""Utilities for managing local file storage synchronised with a remote server.

Files are divided into types: signatures, modsecurity bundles,
ip white lists, etc.  Each type is represented by an Index instance.
Index has a local subdirectory and a description (description.json)
that contains its files' metadata used to decide if the update is
necessary.
"""
import asyncio
import datetime as DT
import hashlib
import http.client
import io
import json
import math
import os
import pathlib
import random
import shutil
import socket
import time
import zipfile
import urllib.error
import urllib.request
from collections import defaultdict, namedtuple
from contextlib import ExitStack, suppress, contextmanager
from email.utils import formatdate, parsedate_to_datetime
from gzip import GzipFile
from itertools import chain
from logging import getLogger
from typing import (
    Any,
    BinaryIO,
    Dict,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)
from urllib.parse import urlparse

from defence360agent.contracts import config
from defence360agent.contracts.license import LicenseCLN
from defence360agent.utils import file_hash, retry_on, run_with_umask
from defence360agent.utils.common import rate_limit, HOUR
from defence360agent.utils.threads import to_thread

from .hooks import default_hook

logger = getLogger(__name__)

# static file types
EULA = "eula"
SIGS = "sigs"  # malware signatures
REALTIME_AV_CONF = "realtime-av-conf"

FILES_DIR = pathlib.Path("/var/imunify360/files")
BASE_URL = "https://files.imunify360.com/static/"

# chunk size for network and file operations, in bytes
_BUFSIZE = 32 * 1024

_MAX_TRIES_FOR_DOWNLOAD = 10
_TIMEOUT_MULTIPLICATOR = 0.025
"""
>>> _MAX_TRIES_FOR_DOWNLOAD = 10
>>> _TIMEOUT_MULTIPLICATOR = 0.025
>>> [(1 << i) * _TIMEOUT_MULTIPLICATOR for i in range(1, _MAX_TRIES_FOR_DOWNLOAD)]  # noqa
[0.05, 0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 6.4, 12.8]
"""

#: sentinel: mtime for a missing/never modified file
_NEVER = -math.inf

# https://github.com/python/typing/issues/182
JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]


class IntegrityError(RuntimeError):
    """Raised when on disk content does not match hashes in description.json"""


class UpdateError(RuntimeError):
    """Raised on other errors during files update.

    Possible reasons are:
      * server returns non 200 status;
      * hash mismatched between downloaded content and description.json;
      * urllib errors;
      * JSON decoding errors;
      * errors while writing to disk.
    """


async def _log_failed_update(exc, i):
    """Log failed attempt number *i* and sleep before the next retry.

    Used as the ``on_error`` callback of ``retry_on`` below.  The sleep
    is a random exponential backoff: a random integer in
    ``[0, 2**i)`` multiplied by ``_TIMEOUT_MULTIPLICATOR`` seconds.
    """
    logger.warning("Files update failed with error: %s, try: %s", exc, i)
    # exponential backoff
    await asyncio.sleep(random.randrange(1 << i) * _TIMEOUT_MULTIPLICATOR)


def _open_with_mode(path: os.PathLike, mode: int) -> BinaryIO:
    """Open file at *path* for writing in binary mode and return it.

    The file is created (or truncated) with permission bits *mode*
    exactly: umask is temporarily set to 0 so it cannot mask them.
    """
    with run_with_umask(0):
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
    return os.fdopen(fd, "wb")


def _fetch_json_sync(url, timeout) -> JSONType:
    """Blocking helper: download *url* and decode its body as JSON.

    The body is decoded using the charset from the Content-Type header,
    falling back to utf-8.
    """
    with _fetch_url(url, timeout=timeout) as response:
        return json.load(
            io.TextIOWrapper(
                response["file"],
                encoding=response["headers"].get_content_charset("utf-8"),
            )
        )


@retry_on(
    UpdateError, on_error=_log_failed_update, max_tries=_MAX_TRIES_FOR_DOWNLOAD
)
async def _fetch_json(url: str, timeout) -> JSONType:
    """Download and decode JSON from *url*. Return decoded JSON.

    The blocking download runs in the default executor.

    Raise UpdateError:
      * HTTP response status code is not 200;
      * Unicode or JSON decoding fails;
      * on time outs during HTTP request;
      * on other HTTP errors.
    """
    # get_running_loop() (not get_event_loop()) is the supported way to
    # obtain the loop from inside a coroutine
    loop = asyncio.get_running_loop()
    try:
        return await loop.run_in_executor(None, _fetch_json_sync, url, timeout)
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        raise UpdateError("json decode error [{}] for url {}".format(e, url))
    except socket.timeout:
        raise UpdateError("request to {} timed out".format(url))
    except ConnectionResetError:
        raise UpdateError("request to {} reset".format(url))
    except EOFError as e:
        raise UpdateError(
            f"eof error while updating files, url: {url}, err: {e}"
        )
    except (http.client.HTTPException, urllib.error.URLError) as e:
        raise UpdateError(
            "urllib/http error while updating files, url: {}, err: {}".format(
                url, e
            )
        )
    except OSError as e:
        raise UpdateError(f"Can't fetch {url}, reason: {e}")


def _perform_http_head_sync(
    url: str, timeout: float, *, headers: Optional[Dict[str, str]] = None
):
    """Perform HEAD http request to *url* with *timeout* & *headers*.

    The ``Imunify-Server-Id`` header is always sent; entries of
    *headers* are added on top of it.  Return ``(status code, headers)``.
    """
    req = urllib.request.Request(
        url,
        headers={
            "Imunify-Server-Id": LicenseCLN.get_server_id() or "",
            **(headers or {}),
        },
        method="HEAD",
    )
    with urllib.request.urlopen(req, timeout=timeout) as r:
        return r.code, r.headers


@retry_on(
    UpdateError, on_error=_log_failed_update, max_tries=_MAX_TRIES_FOR_DOWNLOAD
)
async def _need_to_download(
    url: str, current_mtime: float, timeout: float
) -> bool:
    """Check if description.json at *url* needs to be downloaded.

    - if the local file has never been updated (*current_mtime* is
      ``_NEVER``) return True without any request;
    - otherwise perform a HEAD request with ``If-Modified-Since`` set to
      *current_mtime*: return False on 304 (Not Modified), True on 200.

    Raise UpdateError on time outs, other HTTP errors and unexpected
    status codes.
    """
    if current_mtime is _NEVER:  # file has never been updated
        return True  # need to download it

    formatted_mtime = formatdate(current_mtime, usegmt=True)
    try:
        code, headers = await to_thread(
            _perform_http_head_sync,
            url,
            timeout,
            headers={"If-Modified-Since": formatted_mtime},
        )
    except socket.timeout:
        raise UpdateError("request to {} timed out".format(url))
    except ConnectionResetError:
        raise UpdateError("request to {} reset".format(url))
    except (http.client.HTTPException, urllib.error.URLError) as e:
        # urlopen() reports 304 as an HTTPError
        if hasattr(e, "code") and e.code == 304:  # NOSONAR file not modified
            return False  # no need to re-download
        raise UpdateError(
            "urllib/http error while updating files, url: {}, err: {}".format(
                url, e
            )
        )

    if code != 200:
        raise UpdateError(
            f"Unexpected http code {code!r} for {url}"
        )  # pragma: no cover
    # Diagnostics only: a missing/unparsable Last-Modified must not
    # affect the result.
    with suppress(Exception):
        last_mtime = parsedate_to_datetime(
            headers["Last-Modified"]
        ).timestamp()
        if last_mtime <= current_mtime:  # file on the server NOT newer
            logger.warning(
                "Got code %r, but last modification date %s is earlier"
                " than or equal to the date provided in the"
                " If-Modified-Since header, the origin server SHOULD"
                " generate a 304 (Not Modified) response [rfc7232]."
                " Here's curl cmd:\ncurl -s -I -w '%%{http_code}' -H"
                " 'If-Modified-Since: %s' '%s'",
                code,
                headers["Last-Modified"],
                formatted_mtime,
                url,
            )
    return True  # file has been modified since current mtime, re-download


@contextmanager
def _fetch_url(url: str, *, timeout: float, compress=True):
    """Fetch *url* as binary file.

    Yield a dict with keys ``"file"`` (readable binary file object) and
    ``"headers"`` (response headers).

    If *compress* is true, gzip encoding is requested.  A gzip-encoded
    response is transparently decompressed regardless of *compress*.
    """
    parameters = {}
    if timeout is not None:
        # omit the argument for None so urlopen() uses its default timeout
        parameters["timeout"] = timeout
    req_headers = {"Imunify-Server-Id": LicenseCLN.get_server_id() or ""}
    if compress:
        # express preference for gzip but don't forbid identity encoding
        req_headers.update({"Accept-Encoding": "gzip"})
    req = urllib.request.Request(url, headers=req_headers)
    with urllib.request.urlopen(
        req, **parameters
    ) as response, ExitStack() as stack:
        # check whether response is gzipped regardless *compress* arg
        gzipped = response.headers.get("Content-Encoding") == "gzip"
        if (
            compress
            and not gzipped
            and response.headers.get("Content-Type") != "application/zip"
        ):
            logger.info(
                "Requested gzip but got Content-Encoding=%r."
                " Read response as is [identity]. Headers: %s,"
                " as curl cmd:\ncurl -Is -H 'Accept-Encoding: gzip' '%s'",
                response.headers.get("Content-Encoding"),
                response.headers.items(),
                url,
            )
        yield {
            "file": (
                stack.enter_context(GzipFile(fileobj=response))
                if gzipped
                else response
            ),
            "headers": response.headers,
        }


def _fetch_n_md5sum_url(
    url, dest_file: BinaryIO, timeout, *, compress, md5sum
):
    """Fetch *url* to *dest_file* and return the md5 hex digest of it.

    Raise *urllib.error.ContentTooShortError* if the downloaded
    (non-gzipped) content is shorter/longer than Content-Length.
    Raise UpdateError if *md5sum* is given and does not match.
    """
    md5 = hashlib.md5()
    initial_file_offset = dest_file.tell()
    with _fetch_url(url, timeout=timeout, compress=compress) as response:
        while chunk := response["file"].read(_BUFSIZE):  # NOSONAR
            md5.update(chunk)
            dest_file.write(chunk)
        # For gzip, Content-Length is the compressed size
        # -> no point in comparing it with the uncompressed result.
        if response["headers"].get("Content-Encoding") != "gzip":
            file_length = dest_file.tell() - initial_file_offset
            # make sure the file has been downloaded correctly
            # Content-Length may not be set if exist header
            # Transfer-Encoding: chunked
            content_length_header = response["headers"].get(
                "Content-Length", None
            )
            if content_length_header is not None:
                expected_file_length = int(content_length_header)
                if expected_file_length != file_length:
                    raise urllib.error.ContentTooShortError(
                        message="{got} bytes read, {diff} more expected".format(
                            got=file_length,
                            diff=expected_file_length - file_length,
                        ),
                        content=None,
                    )
    got_md5sum = md5.hexdigest()
    if md5sum is not None and got_md5sum != md5sum:
        raise UpdateError(
            f"content fetched from {url} does not match hash:"
            f" expected={md5sum}, got={got_md5sum}"
        )
    return got_md5sum


@retry_on(
    UpdateError, on_error=_log_failed_update, max_tries=_MAX_TRIES_FOR_DOWNLOAD
)
async def _fetch_and_save(
    url: str,
    dest_path: os.PathLike,
    timeout,
    *,
    dest_mode: int,
    compress=True,
    md5sum=None,
) -> str:
    """Fetch bytes from *url*, save them to *dest_path* (created with
    permissions *dest_mode*), and return md5 checksum of downloaded
    content.

    Raise UpdateError:
      * HTTP response status code is not 200;
      * on time outs during HTTP request;
      * on other HTTP errors;
      * on md5 mismatch with *md5sum* (if given);
      * on local I/O errors.
    """
    try:
        with _open_with_mode(dest_path, dest_mode) as dest_file:
            return await to_thread(
                _fetch_n_md5sum_url,
                url,
                dest_file,
                timeout,
                compress=compress,
                md5sum=md5sum,
            )
    except socket.timeout:
        raise UpdateError("request to {} timed out".format(url))
    except ConnectionResetError:
        raise UpdateError("request to {} reset".format(url))
    except EOFError as e:
        raise UpdateError(
            f"eof error while updating files, url: {url}, err: {e}"
        )
    except (http.client.HTTPException, urllib.error.URLError) as e:
        raise UpdateError(
            "urllib/http error while updating files, url: {}, err: {}".format(
                url, e
            )
        )
    except OSError as e:
        raise UpdateError(f"Can't fetch {url} to {dest_path}, reason: {e}")


_Item = namedtuple("_Item", ["url", "md5sum"])


def _items(data: Any) -> Set[_Item]:
    """Return a set of _Item for easy manipulation.

    *data* is a decoded description.json: ``data["items"]`` is a list of
    mappings with "url" and "md5sum" keys.
    """
    return {_Item(item["url"], item["md5sum"]) for item in data["items"]}


def check_mode_dirs(dirname, dir_perm, file_perm):
    """Check and change file/dir modes recursively.

    Starting at *dirname* (inclusive), change all directory permissions
    to *dir_perm* and file permissions to *file_perm*.  Symlinks are
    left alone; paths that disappear during the walk are skipped;
    permission errors are logged.
    """

    def _os_chmod(file_dir_path, permission):
        try:
            current_mode = os.lstat(file_dir_path).st_mode & 0o777
            if current_mode != permission and not os.path.islink(
                file_dir_path
            ):
                logger.warning(
                    "Fixing wrong permission to file/dir"
                    " %s [%s] expected [%s] (not symlink)",
                    file_dir_path,
                    oct(current_mode),
                    oct(permission),
                )
                os.chmod(file_dir_path, permission)
        except FileNotFoundError:
            # the path vanished after os.walk() listed it (e.g. an old
            # version directory removed concurrently): nothing to fix
            pass
        except PermissionError:
            logger.error(
                "Failed to change permission to file %s", file_dir_path
            )

    _os_chmod(dirname, dir_perm)
    for path, dirs, files in os.walk(dirname):
        for directory in dirs:
            _os_chmod(os.path.join(path, directory), dir_perm)
        for name in files:
            _os_chmod(os.path.join(path, name), file_perm)


def _fix_directory_structure(
    description_path: pathlib.Path, files_path: pathlib.Path = FILES_DIR
) -> None:
    """Try to fix the layout of /var/imunify360/files/ after
    NotADirectoryError.

    NotADirectoryError indicates that some component of the path is a
    regular file, e.g. if /var/imunify360/files/sigs is a file then
    open("/var/imunify360/files/sigs/v1/description.json") fails.

    Walk up from the parent of *description_path* (stopping before
    *files_path*); delete the first component that is a file and
    recreate the parent directory of *description_path*.
    """
    assert files_path in description_path.parents
    _dir = description_path.parent
    topmost_dir = _dir
    while _dir != files_path:
        if _dir.is_file():
            _dir.unlink(missing_ok=True)
            topmost_dir.mkdir(parents=True, exist_ok=True)
            break
        _dir = _dir.parent


class Index:
    """Local copy of one registered file type and its description.json."""

    # one lock is shared via Index and that allows
    # more than one instance of Index to co-exist
    _lock = defaultdict(asyncio.Lock)  # type: Dict[Any, asyncio.Lock]
    _HOOKS = defaultdict(set)  # type: Dict[str, Set[Any]]
    _PATHS = {}  # type: Dict[str, str]
    _PERMS = {}  # type: Dict[str, Dict[str, int]]
    _TYPES = set()  # type: Set[str]
    _ESSENTIAL_TYPES = set()  # type: Set[str]
    _ALL_ZIP_SUPPORT = {}  # type: Dict[str, bool]
    _URL_PATH_PREFIX = "/static"
    _throttled_log_error = rate_limit(period=4 * HOUR)(logger.error)

    def __init__(self, type_, integrity_check=True):
        """Load the description of files of type *type_*.

        :param str type_: a file type registered with :meth:`add_type`
        :param bool integrity_check: check if last update did not break
            anything (by interrupting it in the middle or another
            programmatic error)
        :raise ValueError: if *type_* is not registered
        :raise IntegrityError: if *integrity_check* is true and the
            description file cannot be read, or some described files are
            missing or corrupted
        """
        if type_ not in self._TYPES:
            raise ValueError(
                f"Trying to initiate unregistered file type {type_}. Allowed"
                f" types {self._TYPES}"
            )
        self.type = type_
        self._is_blank = False
        self._json = {"items": []}
        path = self._descriptionfile_path()
        try:
            with open(path) as f:
                self._json = json.load(f)
        except NotADirectoryError as e:
            Index._throttled_log_error("Path %s has a file in parents", path)
            _fix_directory_structure(pathlib.Path(path), FILES_DIR)
            # The description could not be read: this is the same state
            # as a missing description file, not an empty valid one.
            if integrity_check:
                raise IntegrityError(
                    "cannot read description file {}".format(path)
                ) from e
            self._is_blank = True
        except (
            FileNotFoundError,
            UnicodeDecodeError,
            json.JSONDecodeError,
        ) as e:
            if integrity_check:
                raise IntegrityError(
                    "cannot read description file {}".format(path)
                ) from e
            self._is_blank = True

        if integrity_check:
            bad_files = self._corrupted_files()
            if bad_files:
                raise IntegrityError(
                    "some files are missing or corrupted: {}".format(
                        ", ".join(bad_files)
                    )
                )
        if not self._is_blank:
            self.check_mode_dirs()

    def __eq__(self, other):
        return (
            self.__class__ == other.__class__
            and self.type == other.type
            and self._is_blank == other._is_blank
            and self._json == other._json
        )

    def __repr__(self):  # pragma: no cover
        return (
            f"<{self.__class__.__name__}(type_={self.type})"
            f" is_blank={self._is_blank}, "
            f"json={{<{len(self.items())}"
            " item(s)>}>"
        )

    def validate(self, files_path: os.PathLike) -> None:
        """Whether *files_path* dir may be used for this type's file group.

        :raises: IntegrityError
        """
        logger.info("Validating [%s]: %s", self.type, files_path)
        FileGroup = self._make_file_group(
            files_path
        )  # noqa NOSONAR disable python:S117
        FileGroup(self.type, integrity_check=True)

    def _make_file_group(self, files_path: os.PathLike):
        """Return FileGroup class: Index class with local path ==
        *files_path*."""

        class FileGroup(self.__class__):
            @classmethod
            def files_path(cls, type_: str) -> str:
                """Return local base path for given file type."""
                assert type_ == self.type
                return os.fspath(files_path)

        return FileGroup

    def check_mode_dirs(self):
        """Fix permissions under the parent of the live dir of this type.

        Uses the dir/file permissions registered via :meth:`add_type`.
        """
        perms = Index._PERMS[self.type]
        check_mode_dirs(
            os.path.normpath(
                os.path.join(FILES_DIR, Index._PATHS[self.type], os.pardir)
            ),
            perms["dir"],
            perms["file"],
        )

    @classmethod
    def add_type(
        cls,
        type_: str,
        relative_path: str,
        dir_perm: int,
        file_perm: int,
        *,
        all_zip: bool = False,
        essential: bool = True,
    ) -> None:
        """Add a type to known file types.

        * relative_path is a relative path to all files for that type.
        * dir_perm is permission mask used to create directories.
        * file_perm is permission mask used to create files.
        * all_zip is a flag which shows whether that type of files can be
          downloaded in all.zip archive.  all.zip is expected to be on
          the server.
        * essential is whether the agent can start if there are errors
          updating that type.
        """
        cls._TYPES.add(type_)
        if essential:
            cls._ESSENTIAL_TYPES.add(type_)
        cls._PATHS[type_] = relative_path
        cls._PERMS[type_] = {"dir": dir_perm, "file": file_perm}
        cls._ALL_ZIP_SUPPORT[type_] = all_zip

    @classmethod
    async def essential_files_exist(cls) -> bool:
        """Whether essential files exist.

        Note: the files may be corrupted (integrity check is not
        performed).
        """
        # use the existence of the description files as a proxy
        return all(
            not Index(type_, integrity_check=False)._is_blank
            for type_ in cls._ESSENTIAL_TYPES
        )

    @classmethod
    def types(cls) -> Set[str]:
        """Return a set of all known files types."""
        return cls._TYPES.copy()

    @classmethod
    def files_path(cls, type_: str) -> str:
        """Return local base path for given file type."""
        return os.path.join(FILES_DIR, cls._PATHS[type_])

    def _descriptionfile_path(self):
        """Return local path for description.json for current index."""
        return os.path.join(self.files_path(self.type), "description.json")

    def _corrupted_files(self) -> Set[str]:
        """Return a set of local file paths that are missing or whose md5
        differs from description.json."""
        bad_files = set()
        for item in _items(self._json):
            path = self.localfilepath(item.url)
            try:
                actual = file_hash(path, hashlib.md5, _BUFSIZE)
            except FileNotFoundError:
                bad_files.add(path)
                continue
            if actual != item.md5sum:
                bad_files.add(path)
        return bad_files

    @classmethod
    def locked(cls, type_):
        """Return the asyncio.Lock shared by all indexes of *type_*.

        usage example:

        >> async with Index.locked(SIGS):
               ...
        """
        return cls._lock[type_]

    def files(self) -> Iterable[str]:
        """Return iterable over all files in index."""
        return (self.localfilepath(item.url) for item in _items(self._json))

    def items(self):
        """Return 'items' field from JSON description."""
        return self._json["items"]

    def _descriptionfile_mtime(self, default=_NEVER) -> float:
        """Return mtime of description file if it can be stat()-ed,
        otherwise *default* (-math.inf)."""
        try:
            return os.stat(self._descriptionfile_path()).st_mtime
        except OSError:
            return default

    def _is_outdated(self) -> bool:
        """Return True if last update was more than
        ``config.FilesUpdate.PERIOD`` seconds ago."""
        _desc_mtime = self._descriptionfile_mtime()
        if not _desc_mtime:
            return True  # pragma: no cover
        return _desc_mtime + config.FilesUpdate.PERIOD < time.time()

    async def is_update_needed(self, timeout: float) -> bool:
        """Return True if update from server is needed for current index:
        the description is missing/broken, some files are corrupted, or
        it is outdated and the server has a newer description."""
        return (
            self._is_blank
            or len(self._corrupted_files()) > 0
            or (
                self._is_outdated()
                and await _need_to_download(
                    self._descriptionfile_url(self.type),
                    self._descriptionfile_mtime(),
                    timeout,
                )
            )
        )

    def _makedirs(self, dirname, dir_mode, exist_ok=False):
        """Create *dirname* (and parents) with exact mode *dir_mode*.

        :raise UpdateError: on OSError
        """
        try:
            with run_with_umask(0):
                os.makedirs(dirname, mode=dir_mode, exist_ok=exist_ok)
        except OSError as e:
            raise UpdateError(str(e)) from e

    async def _update_files(
        self, files_path: pathlib.Path, to_update: Set[_Item], timeout
    ) -> None:
        """Fetch files from *to_update* set, verify hashes, save to
        *files_path*."""
        FileGroup = self._make_file_group(
            files_path
        )  # noqa NOSONAR disable python:S117
        fg = FileGroup(self.type, integrity_check=False)
        dir_mode = fg._PERMS[fg.type]["dir"]  # NOSONAR disable python:W0212
        file_mode = fg._PERMS[fg.type]["file"]  # NOSONAR disable python:W0212
        for item in to_update:
            filename = fg.localfilepath(item.url)
            dirname = os.path.dirname(filename)
            if not os.path.isdir(dirname):
                self._makedirs(dirname, dir_mode, exist_ok=False)
            await _fetch_and_save(
                item.url,
                filename,
                timeout,
                dest_mode=file_mode,
                md5sum=item.md5sum,
            )

    def _calculate_changes(
        self, remote_items: Set[_Item]
    ) -> Tuple[Set[_Item], Set[str]]:
        """Figure out what should be updated based on current items,
        file system state and remote items.

        Return tuple of files to fetch and files to delete.
        Files to fetch is a set of _Item.
        Files to delete is a set of file paths."""
        local_items = _items(self._json)
        local_files = {self.localfilepath(item.url) for item in local_items}
        remote_files = {self.localfilepath(item.url) for item in remote_items}
        to_remove = local_files - remote_files
        bad_files = self._corrupted_files()
        local_set = {
            item
            for item in local_items
            if self.localfilepath(item.url) not in bad_files
        }
        to_update = remote_items - local_set
        return to_update, to_remove

    @classmethod
    def _descriptionfile_url(cls, type_: str) -> str:
        """Return remote path for description.json"""
        return "{}{}/description.json".format(BASE_URL, cls._PATHS[type_])

    @classmethod
    def _all_zip_url(cls, type_: str) -> str:
        """Return remote path for all.zip"""
        return "{}{}/all.zip".format(BASE_URL, cls._PATHS[type_])

    @staticmethod
    def _all_zip_cleanup(files_path, all_zip_localpath, remove_files=False):
        """Remove the downloaded archive (failures are logged) and, if
        *remove_files*, the whole *files_path* tree (errors ignored)."""
        try:
            os.unlink(all_zip_localpath)
        except OSError as e:
            logger.warning(
                "failed to remove %s: %s", all_zip_localpath, str(e)
            )
        if remove_files:
            logger.info("Removing old path on all.zip update: %s", files_path)
            shutil.rmtree(files_path, ignore_errors=True)

    @staticmethod
    def _generate_new_path(live_path: pathlib.Path) -> pathlib.Path:
        """Generate new base local path for *live_path* files.

        It should be on the same filesystem partition as *live_path* so
        that the rename would be atomic.
        """
        # aware UTC now (utcnow() is deprecated); the format has no %z,
        # so the generated name is unchanged
        new_suffix = DT.datetime.now(DT.timezone.utc).strftime(
            "_%Y-%m-%dT%H%M%S.%fZ"
        )
        return live_path.with_name(live_path.name + new_suffix)

    async def _run_update_all_zip(self, timeout) -> bool:
        """Update current type of files using all.zip archive.

        Directory with current type of files will be cleared and
        replaced with all.zip contents.  all.zip is expected to be on
        the server.

        Return whether updated.

        :param timeout: socket timeout for the download
        :raise UpdateError: if OSError or http error or integrity check
            error (got wrong data from the server)
        """
        live_path = pathlib.Path(self.files_path(self.type))
        new_path = Index._generate_new_path(live_path)
        archive_path = new_path.with_name(new_path.name + "all.zip")
        file_mode = self._PERMS[self.type]["file"]
        dir_mode = self._PERMS[self.type]["dir"]
        all_zip_url = self._all_zip_url(self.type)
        with ExitStack() as rollback_stack:
            # make new download dir
            self._makedirs(new_path, dir_mode, exist_ok=False)
            rollback_stack.callback(
                Index._all_zip_cleanup,
                new_path,
                archive_path,
                remove_files=True,
            )
            # download the archive
            # TODO: DEF-16354 check md5sum for all.zip
            _ = await _fetch_and_save(
                all_zip_url,
                archive_path,
                timeout,
                dest_mode=file_mode,
                compress=False,
            )
            # extract files to new dir with right permissions & verify
            try:
                with zipfile.ZipFile(archive_path, "r") as archive:
                    # NOTE: this also verifies crc-32 checksum for files
                    archive.extractall(new_path)
                # set mode
                for root, directories, filenames in os.walk(new_path):
                    for directory in directories:
                        os.chmod(os.path.join(root, directory), dir_mode)
                    for filename in filenames:
                        os.chmod(os.path.join(root, filename), file_mode)
                # verify against included description.json
                self.validate(new_path)
                # create symlink to new dir, replace *live* with the symlink
                old_path = Index._replace_live_with_new_dir(
                    new_path, live_path
                )
            except (
                EOFError,
                IntegrityError,
                OSError,
                zipfile.BadZipfile,
                zipfile.LargeZipFile,
            ) as e:
                raise UpdateError(str(e)) from e
            # no exception, clear the rollback stack
            rollback_stack.pop_all()
        # cleanup: remove old dir & new all.zip
        Index._all_zip_cleanup(
            old_path, archive_path, remove_files=bool(old_path)
        )
        return True  # updated

    @staticmethod
    def _replace_live_with_new_dir(
        new_path: pathlib.Path, live_path: pathlib.Path
    ) -> pathlib.Path:
        """Replace *live_path* with a symlink to *new_path*.

        Return *old_path*: the previous symlink target if *live_path*
        was a symlink, otherwise None (a real directory at *live_path*
        is moved aside and removed).

        :raises: OSError
        """
        new_live_path = new_path.with_name(new_path.name + "live")
        moved_path = None
        with ExitStack() as rollback_stack:
            new_live_path.symlink_to(new_path, target_is_directory=True)
            rollback_stack.callback(new_live_path.unlink)

            # save the path to old dir for the cleanup
            old_path = (
                live_path.resolve(strict=False)
                if live_path.is_symlink()
                else None
            )

            # switch to the new version
            # NOTE: nothing until this point touched old version;
            # the rename should be atomic
            # (paths are on the same partition)
            for last in range(2):  # pragma: no branch
                try:
                    new_live_path.rename(live_path)
                    break
                except IsADirectoryError:
                    if last:  # give up (keep old)
                        raise  # pragma: no cover
                    # live_path is a directory
                    # (old agent version or tests)
                    # move it so that the rename above could happen
                    if not live_path.is_symlink():  # pragma: no branch
                        # use unique to the current update name
                        moved_path = new_live_path.with_name(
                            new_live_path.name + ".live-moved"
                        )
                        logger.info(
                            "Moving %s [live] to %s,"
                            " to rename %s to it [live]",
                            live_path,
                            moved_path,
                            new_live_path,
                        )
                        live_path.replace(moved_path)
                        # if enabling new_live fails the 2nd time,
                        # try to move back, to restore old dir
                        rollback_stack.callback(moved_path.replace, live_path)
            if moved_path is not None:
                shutil.rmtree(moved_path, ignore_errors=True)
            # no exception, clear the rollback stack
            rollback_stack.pop_all()
        return old_path

    async def _run_update(self, timeout) -> bool:
        """Run file-by-file update, return whether updated.

        :raise UpdateError: if OSError or http error or integrity check
            error (got wrong data from the server)
        """
        url = self._descriptionfile_url(self.type)
        as_json = await _fetch_json(url, timeout=timeout)
        to_update, to_remove = self._calculate_changes(_items(as_json))
        need_update = to_update or to_remove
        if not need_update:
            logger.info("updating %s: nothing to update.", self.type)
            self._touch()  # postpone the next try for FilesUpdate.PERIOD
            return False  # not updated

        # perform atomic update
        live_path = pathlib.Path(self.files_path(self.type))
        # note: it is ok if the symlink changes before .resolve() is called
        old_path = (
            live_path.resolve(strict=False) if live_path.is_symlink() else None
        )
        new_path = Index._generate_new_path(live_path)
        with ExitStack() as rollback_stack:
            # make new download dir
            self._makedirs(
                new_path, self._PERMS[self.type]["dir"], exist_ok=False
            )
            rollback_stack.callback(
                shutil.rmtree, new_path, ignore_errors=True
            )

            # copy all files from *old* dir to *new* dir except those
            # that need updating or removing
            from_path = (
                old_path if old_path and old_path.is_dir() else live_path
            )
            if from_path.is_dir():
                skipped_paths = to_remove.union(
                    self.localfilepath(item.url) for item in to_update
                )
                # localfilepath() paths are rooted at *live_path* while
                # copytree() reports entries rooted at *from_path* (the
                # resolved symlink target): re-root them, otherwise
                # nothing is skipped and removed files leak into new dir
                await Index._copytree(
                    from_path,
                    new_path,
                    {
                        os.path.join(
                            from_path, os.path.relpath(path, live_path)
                        )
                        for path in skipped_paths
                    },
                )

            # download *to_update* files to *new_path*
            await self._update_files(new_path, to_update, timeout=timeout)

            try:
                # write description.json
                with _open_with_mode(
                    new_path / "description.json",
                    self._PERMS[self.type]["file"],
                ) as file:
                    file.write(json.dumps(as_json).encode())
                # verify against included description.json
                self.validate(new_path)
                # create symlink to new dir, replace *live* with the symlink
                old_path = self._replace_live_with_new_dir(new_path, live_path)
            except (IntegrityError, OSError) as e:
                raise UpdateError(str(e)) from e
            # no exception, clear the rollback stack
            rollback_stack.pop_all()

        # cleanup: remove old path on success
        if old_path and old_path.is_dir():
            logger.info(
                "Removing old path on file by file update: %s", old_path
            )
            shutil.rmtree(old_path, ignore_errors=True)
        return True  # updated

    @staticmethod
    async def _copytree(
        from_dir: os.PathLike, to_dir: os.PathLike, ignored_paths: Set[str]
    ) -> None:
        """Copy *from_dir* to *to_dir* except for *ignored_paths*.

        *ignored_paths* must be rooted at *from_dir*.  Symlinks are
        copied as symlinks; the copy runs in a worker thread.
        """

        def ignore_names(path, names):
            """Return names that should not be copied."""
            assert isinstance(os.fspath(path), str)  # no bytes here
            return frozenset(
                name
                for name in names
                if os.path.join(path, name) in ignored_paths
            )

        await to_thread(
            shutil.copytree,
            from_dir,
            to_dir,
            symlinks=True,
            ignore=ignore_names,
            dirs_exist_ok=True,
        )

    def localfilepath(self, url: str) -> str:
        """Return a local file path corresponding to URL."""
        url_relpath = os.path.relpath(
            urlparse(url).path, self._URL_PATH_PREFIX
        )
        type_path = self._PATHS[self.type]
        # NOTE(review): this assert is the only guard keeping server
        # supplied URLs inside the type dir; it is stripped under -O.
        assert (
            pathlib.Path(type_path) in pathlib.Path(url_relpath).parents
        ), "url ({}) does not fit file path ({})".format(url, type_path)
        relative_path = os.path.relpath(url_relpath, type_path)
        return os.path.join(self.files_path(self.type), relative_path)

    def _touch(self) -> None:
        """Update mtime of description.json file so it is fresh."""
        try:
            path = self._descriptionfile_path()
            if os.path.isfile(path):  # pragma: no branch
                os.utime(path)
        except OSError as e:  # pragma: no cover
            logger.warning(str(e))

    async def _run_hooks(self, is_updated) -> None:
        """Await registered hooks of this type, then ``default_hook``.

        Each hook is called as ``hook(self, is_updated)``; an exception
        from a hook is logged and does not stop the remaining hooks.
        """
        for hook in chain(self._HOOKS[self.type], [default_hook]):
            try:
                await hook(self, is_updated)
            except Exception as e:
                logger.exception("hook %s error: %s", hook, e)
        logger.info(
            "%s files update finished%s",
            self.type,
            " (not updated)" * (not is_updated),
        )

    async def update(self, force=False) -> None:
        """Run update for the current `type` of files.

        Normally update is performed when either is true:
          * index is never been fetched (description.json missing or
            broken);
          * last update was performed longer than configured period of
            time ago;
          * some local files are missing or have wrong content (md5 hash
            differs from description.json).

        If force is True then update is performed unconditionally.

        Raises asyncio.TimeoutError, UpdateError (errors are swallowed
        for non-essential types).
        """
        timeout = config.FilesUpdate.TIMEOUT  # total timeout
        if not force and not await self.is_update_needed(timeout):
            logger.info(
                "%s was updated less than %s minutes ago.",
                self.type,
                int(config.FilesUpdate.PERIOD // 60),
            )
            await self._run_hooks(is_updated=False)
            return

        all_zip = self._is_blank and self._ALL_ZIP_SUPPORT[self.type]
        file_by_file = not all_zip
        if all_zip:
            log_str = "all.zip"
            logger.info("Updating %s files via %s", self.type, log_str)
            # Download updates using all.zip in case of empty or
            # corrupted description.json.
            # Initially we try to download updates using all.zip, if
            # error happened - download file by file.
            try:
                updated = await asyncio.wait_for(
                    self._run_update_all_zip(
                        config.FilesUpdate.SOCKET_TIMEOUT
                    ),
                    timeout,
                )
                if updated:
                    logger.info("Updated %s using %s", self.type, log_str)
            except (asyncio.TimeoutError, UpdateError) as e:
                logger.error(
                    "%s update error via %s: %s", self.type, log_str, e
                )
                file_by_file = True

        if file_by_file:
            log_str = "file by file download"
            logger.info("Updating %s files via %s", self.type, log_str)
            try:
                updated = await asyncio.wait_for(
                    self._run_update(config.FilesUpdate.SOCKET_TIMEOUT),
                    timeout,
                )
                if updated:
                    logger.info("Updated %s using %s", self.type, log_str)
            except (asyncio.TimeoutError, UpdateError) as e:
                logger.error(
                    "%s update error via %s: %s", self.type, log_str, e
                )
                await self._run_hooks(is_updated=False)
                # Ignore errors only for non-essential files
                if self.type not in self._ESSENTIAL_TYPES:
                    return
                raise

        await self._run_hooks(is_updated=updated or force)

    @classmethod
    async def update_all(
        cls, only_type: Optional[str] = None, force=False, only_essential=False
    ) -> None:
        """Run update for all registered `types` of files.

        If *only_type* is given only that type is updated; otherwise if
        *only_essential* is true only essential types are updated.
        Each type is updated while holding its shared lock.

        Raises asyncio.TimeoutError, UpdateError.
        """
        if only_type:
            types_to_update = [only_type]
        elif only_essential:
            logger.info("Updating essential files")
            types_to_update = cls._ESSENTIAL_TYPES
        else:
            logger.info("Updating all files")
            types_to_update = cls._TYPES

        for type_ in types_to_update:
            index = cls(type_, integrity_check=False)
            async with cls.locked(type_):
                await index.update(force)

    @classmethod
    def add_hook(cls, type_: str, hook) -> None:
        """Add a hook for type_ to be called after successful update."""
        cls._HOOKS[type_].add(hook)


def configure() -> None:
    """Register required file types."""
    Index.add_type(EULA, "eula/v1", 0o770, 0o660, all_zip=False)
    Index.add_type(SIGS, "sigs/v1", 0o775, 0o644, all_zip=True)
    Index.add_type(
        REALTIME_AV_CONF,
        "realtime-av-conf/v1",
        0o770,
        0o660,
        all_zip=False,
    )


# public aliases
update = Index.update_all
essential_files_exist = Index.essential_files_exist


async def update_and_log_error(
    only_type: Optional[str] = None, force=False
) -> None:
    """Run files.update and log Update/TimeoutErrors."""
    try:
        return await Index.update_all(only_type, force)
    except (asyncio.TimeoutError, UpdateError) as err:
        logger.error(
            "Failed to update files [%s] with error: %s", only_type, err
        )


async def update_all_no_fail_if_files_exist():
    """Update essential files. Don't fail if essential files exist.

    If they don't exist, the error is re-raised (a TimeoutError is
    wrapped into UpdateError).
    """
    try:
        return await Index.update_all(only_essential=True)
    except (asyncio.TimeoutError, UpdateError) as err:
        if await Index.essential_files_exist():
            logger.error(
                "Failed to update files [essential files exist]: %s", err
            )
        else:  # re-raise
            if isinstance(err, asyncio.TimeoutError):
                raise UpdateError from err  # wrap
            raise