From 5faed9d5f5284b9182d5e12af578ba91c38a395a Mon Sep 17 00:00:00 2001 From: tomaszduda23 Date: Tue, 9 Jun 2026 13:04:51 +0200 Subject: [PATCH] [nrf52] native build - download toolchain and sdk in venv (#16388) Co-authored-by: Jonathan Swoboda <154711427+swoboda1337@users.noreply.github.com> Co-authored-by: Jonathan Swoboda --- esphome/components/nrf52/__init__.py | 13 + esphome/components/nrf52/framework.py | 171 ++++ esphome/const.py | 1 + esphome/core/__init__.py | 4 + esphome/espidf/framework.py | 597 +---------- esphome/framework_helpers.py | 677 +++++++++++++ requirements.txt | 1 + tests/unit_tests/test_core.py | 7 + tests/unit_tests/test_espidf_framework.py | 530 +++++++++- tests/unit_tests/test_framework_helpers.py | 954 ++++++++++++++++++ tests/unit_tests/test_nrf52_framework.py | 219 ++++ tests/unit_tests/test_platformio_toolchain.py | 15 + 12 files changed, 2624 insertions(+), 565 deletions(-) create mode 100644 esphome/components/nrf52/framework.py create mode 100644 esphome/framework_helpers.py create mode 100644 tests/unit_tests/test_framework_helpers.py create mode 100644 tests/unit_tests/test_nrf52_framework.py diff --git a/esphome/components/nrf52/__init__.py b/esphome/components/nrf52/__init__.py index 48b67e1ef9..56367d0b26 100644 --- a/esphome/components/nrf52/__init__.py +++ b/esphome/components/nrf52/__init__.py @@ -63,6 +63,7 @@ from .const import ( BOOTLOADER_ADAFRUIT_NRF52_SD140_V6, BOOTLOADER_ADAFRUIT_NRF52_SD140_V7, ) +from .framework import check_and_install # force import gpio to register pin schema from .gpio import nrf52_pin_to_code # noqa: F401 @@ -562,3 +563,15 @@ def process_stacktrace(config: ConfigType, line: str, backtrace_state: bool) -> _LOGGER.error("LR: %s", _addr2line(addr2line, elf, lr)) return False + + +def run_compile(args, config: ConfigType) -> bool: + if CORE.using_toolchain_platformio: + return False + if not CORE.using_toolchain_sdk_nrf: + raise EsphomeError( + "Unsupported toolchain for nRF52. " + "Supported toolchains are 'platformio' and 'sdk-nrf'." + ) + check_and_install() + raise EsphomeError("Native build for nRF52 is not implemented yet") diff --git a/esphome/components/nrf52/framework.py b/esphome/components/nrf52/framework.py new file mode 100644 index 0000000000..607ad0c7ed --- /dev/null +++ b/esphome/components/nrf52/framework.py @@ -0,0 +1,171 @@ +import logging +import os +from pathlib import Path +import platform +import tempfile + +from esphome.const import KEY_CORE, KEY_FRAMEWORK_VERSION +from esphome.core import CORE, EsphomeError +from esphome.framework_helpers import ( + archive_extract_all, + create_venv, + download_from_mirrors, + get_python_env_executable_path, + rmdir, + run_command_ok, + str_to_lst_of_str, +) + +_LOGGER = logging.getLogger(__name__) + +_WEST_VERSION = "1.5.0" +_TOOLCHAIN_VERSION = "0.17.4" + +SDK_NG_TOOLCHAIN_MIRRORS = str_to_lst_of_str( + os.environ.get( + "ESPHOME_SDK_NG_TOOLCHAIN_MIRRORS", + "https://github.com/zephyrproject-rtos/sdk-ng/releases/download/v{VERSION}/toolchain_{sysname}-{machine}_arm-zephyr-eabi.{extension}", + ) +) + + +def _get_tools_path() -> Path: + return CORE.data_dir / "sdk-nrf" + + +def _get_python_env_path(version: str) -> Path: + return _get_tools_path() / "penvs" / version + + +def _get_framework_path(version: str) -> Path: + return _get_tools_path() / "frameworks" / f"{version}" + + +def _get_toolchain_path(version: str) -> Path: + return _get_tools_path() / "toolchains" / f"{version}" + + +# onexc/dir_fd were added to shutil.rmtree in 3.12; the 3.11 branch uses onerror. +_SITECUSTOMIZE = """\ +import os, stat, shutil, sys +_orig = shutil.rmtree +def _handler(func, path, exc): + os.chmod(path, stat.S_IWRITE); func(path) +if sys.version_info >= (3, 12): + def _rmtree(path, ignore_errors=False, onerror=None, *, onexc=None, dir_fd=None): + if onerror is None and onexc is None: + onexc = _handler + return _orig(path, ignore_errors=ignore_errors, onerror=onerror, onexc=onexc, dir_fd=dir_fd) +else: + def _rmtree(path, ignore_errors=False, onerror=None): + if onerror is None: + onerror = _handler + return _orig(path, ignore_errors=ignore_errors, onerror=onerror) +shutil.rmtree = _rmtree +""" + + +def _install_sitecustomize(python_env_path: Path) -> None: + """Patch shutil.rmtree inside the penv to handle read-only files. + + west init's shutil.move falls back to copytree+rmtree on Windows, and + rmtree dies on the read-only .idx/.pack files git just wrote into + manifest-tmp. Dropping a sitecustomize.py into the venv applies the + same fix esphome.helpers.rmtree uses, but inside the subprocess. + """ + if os.name != "nt": + return + site_packages = python_env_path / "Lib" / "site-packages" + site_packages.mkdir(parents=True, exist_ok=True) + (site_packages / "sitecustomize.py").write_text(_SITECUSTOMIZE, encoding="utf-8") + + +def _get_toolchain_platform_info() -> tuple[str, str, str]: + """Return (sysname, machine, extension) for the current host.""" + extension = "tar.xz" + sysname = platform.system().lower() + machine = platform.machine() + if machine == "arm64": + machine = "aarch64" + if sysname == "darwin": + sysname = "macos" + elif sysname == "windows": + machine = "x86_64" + extension = "7z" + return sysname, machine, extension + + +def check_and_install() -> None: + framework_ver = CORE.data[KEY_CORE][KEY_FRAMEWORK_VERSION] + version = f"v{framework_ver.major}.{framework_ver.minor}.{framework_ver.patch}" + python_env_path = _get_python_env_path(version) + env_python_path = get_python_env_executable_path(python_env_path, "python") + sentinel = python_env_path / ".ready" + install_venv = not sentinel.exists() + if install_venv: + rmdir(python_env_path, msg=f"Clean up {version} Python environment") + + create_venv(python_env_path, msg=f"{version}") + + _install_sitecustomize(python_env_path) + + _LOGGER.info("Installing west %s ...", _WEST_VERSION) + cmd = [str(env_python_path), "-m", "pip", "install", f"west=={_WEST_VERSION}"] + if not run_command_ok(cmd): + raise EsphomeError(f"Install west for {version} Python environment failure") + sentinel.touch() + + framework_path = _get_framework_path(version) + sentinel = framework_path / ".ready" + if install_venv or not sentinel.exists(): + rmdir(framework_path, msg=f"Clean up {version} framework environment") + _LOGGER.info("Initializing nRF Connect SDK %s ...", version) + cmd = [ + str(env_python_path), + "-m", + "west", + "init", + "-m", + "https://github.com/nrfconnect/sdk-nrf", + "--mr", + f"{version}", + str(framework_path), + ] + if not run_command_ok(cmd): + raise EsphomeError(f"Can't initialize nRF Connect SDK {version}") + _LOGGER.info("Updating nRF Connect SDK %s (this may take a while) ...", version) + cmd = [ + str(env_python_path), + "-m", + "west", + "update", + "--narrow", + "--fetch-opt=--depth=1", + ] + if not run_command_ok(cmd, cwd=framework_path): + raise EsphomeError(f"Can't update nRF Connect SDK {version}") + sentinel.touch() + + toolchains_dir = _get_toolchain_path(_TOOLCHAIN_VERSION) + sentinel = toolchains_dir / ".ready" + if not sentinel.exists(): + rmdir( + toolchains_dir, msg=f"Clean up {_TOOLCHAIN_VERSION} toolchain environment" + ) + with tempfile.NamedTemporaryFile() as tmp: + _LOGGER.info("Downloading %s toolchain ...", _TOOLCHAIN_VERSION) + + sysname, machine, extension = _get_toolchain_platform_info() + + download_from_mirrors( + SDK_NG_TOOLCHAIN_MIRRORS, + { + "VERSION": _TOOLCHAIN_VERSION, + "sysname": sysname, + "machine": machine, + "extension": extension, + }, + tmp.file, + ) + archive_extract_all(tmp.file, toolchains_dir, progress_header="Extracting") + sentinel.touch() diff --git a/esphome/const.py b/esphome/const.py index 07f6bad771..22351244bd 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -20,6 +20,7 @@ class Toolchain(StrEnum): PLATFORMIO = "platformio" ESP_IDF = "esp-idf" + SDK_NRF = "sdk-nrf" class Platform(StrEnum): diff --git a/esphome/core/__init__.py b/esphome/core/__init__.py index df8fd0a756..90c162fedd 100644 --- a/esphome/core/__init__.py +++ b/esphome/core/__init__.py @@ -867,6 +867,10 @@ class EsphomeCore: def using_toolchain_platformio(self): return self.toolchain == Toolchain.PLATFORMIO + @property + def using_toolchain_sdk_nrf(self): + return self.toolchain == Toolchain.SDK_NRF + @property def using_zephyr(self): return self.target_framework == "zephyr" diff --git a/esphome/espidf/framework.py b/esphome/espidf/framework.py index 2c520d0d2c..1bc79cc412 100644 --- a/esphome/espidf/framework.py +++ b/esphome/espidf/framework.py @@ -1,8 +1,5 @@ """ESP-IDF framework tools for ESPHome.""" -from collections.abc import Iterable -from contextlib import ExitStack -import io import json import logging import os @@ -10,39 +7,29 @@ from pathlib import Path import platform import re import shutil -import subprocess -import sys import tempfile -from typing import IO - -import requests from esphome.config_validation import Version from esphome.core import CORE -from esphome.helpers import ProgressBar, get_str_env, rmtree, write_file_if_changed - -PathType = str | os.PathLike +from esphome.framework_helpers import ( + PathType, + archive_extract_all, + create_venv, + download_from_mirrors, + get_python_env_executable_path, + get_system_python_path, + rmdir, + run_command, + run_command_ok, + str_to_lst_of_str, +) +from esphome.helpers import get_str_env, write_file_if_changed _LOGGER = logging.getLogger(__name__) _SCRIPTS_DIR = Path(__file__).parent -def _str_to_lst_of_str(a: str | list[str]) -> list[str]: - """ - Convert a string to a list of string - - Args: - a: A string containing semicolon-separated values, or an already-split list - - Returns: - list of strings - """ - if isinstance(a, list): - return a - return [f.strip() for f in a.split(";") if f.strip()] - - ESPHOME_STAMP_FILE = ".esphome.stamp.json" # Cache-buster baked into the stamp file. Bump this whenever a change would @@ -54,23 +41,23 @@ ESPHOME_STAMP_FILE = ".esphome.stamp.json" # Bumping triggers a full reinstall on every user's next run. STAMP_SCHEMA_VERSION = "0" -ESPHOME_IDF_DEFAULT_TARGETS = _str_to_lst_of_str( +ESPHOME_IDF_DEFAULT_TARGETS = str_to_lst_of_str( os.environ.get("ESPHOME_IDF_DEFAULT_TARGETS", "all") ) -ESPHOME_IDF_DEFAULT_TOOLS = _str_to_lst_of_str( +ESPHOME_IDF_DEFAULT_TOOLS = str_to_lst_of_str( os.environ.get("ESPHOME_IDF_DEFAULT_TOOLS", "cmake;ninja") ) -ESPHOME_IDF_DEFAULT_TOOLS_FORCE = _str_to_lst_of_str( +ESPHOME_IDF_DEFAULT_TOOLS_FORCE = str_to_lst_of_str( os.environ.get("ESPHOME_IDF_DEFAULT_TOOLS_FORCE", "required") ) -ESPHOME_IDF_DEFAULT_FEATURES = _str_to_lst_of_str( +ESPHOME_IDF_DEFAULT_FEATURES = str_to_lst_of_str( os.environ.get("ESPHOME_IDF_DEFAULT_FEATURES", "core") ) -ESPHOME_IDF_FRAMEWORK_MIRRORS = _str_to_lst_of_str( +ESPHOME_IDF_FRAMEWORK_MIRRORS = str_to_lst_of_str( os.environ.get("ESPHOME_IDF_FRAMEWORK_MIRRORS") or [ "https://github.com/esphome-libs/esp-idf/releases/download/v{VERSION}/esp-idf-v{VERSION}.tar.xz", @@ -78,7 +65,7 @@ ESPHOME_IDF_FRAMEWORK_MIRRORS = _str_to_lst_of_str( ] ) -ESP_IDF_CONSTRAINTS_MIRRORS = _str_to_lst_of_str( +ESP_IDF_CONSTRAINTS_MIRRORS = str_to_lst_of_str( os.environ.get( "ESP_IDF_CONSTRAINTS_MIRRORS", "https://dl.espressif.com/dl/esp-idf/espidf.constraints.v{VERSION}.txt", @@ -124,59 +111,6 @@ def _get_python_env_path(version: str) -> Path: return _get_idf_tools_path() / "penvs" / f"{version}" -def rmdir(directory: PathType, msg: str | None = None): - """ - Remove a directory and its contents recursively if it exists. - - Args: - directory: Path to the directory to be removed - msg: Optional debug message to log before removal or it an error occurs - - Returns: - None - - Raises: - RuntimeError: If directory removal fails - """ - if Path(directory).is_dir(): - try: - if msg: - _LOGGER.debug(msg) - rmtree(directory) - except OSError as e: - raise RuntimeError( - f"Error during {msg}: can't remove `{directory}`. Please remove it manually!" - ) from e - - -def _get_pythonexe_path() -> str: - """ - Get the path to the Python executable. - - Returns: - Path to Python executable as string - """ - # Try to get PYTHONEXEPATH environment variable - # Fallback to sys.executable if not set - return os.environ.get("PYTHONEXEPATH", os.path.normpath(sys.executable)) - - -def _get_python_env_executable_path(root: PathType, binary: str) -> Path: - """ - Get the path to a Python environment executable file. - - Args: - root: Root directory of the Python environment - binary: Name of the executable binary - - Returns: - Path object pointing to the executable file - """ - if os.name == "nt": - return Path(root) / "Scripts" / f"{binary}.exe" - return Path(root) / "bin" / binary - - def _check_stamp(file: PathType, data: dict[str, str]) -> bool: """ Check if a stamp file contains the expected data. @@ -210,84 +144,6 @@ def _write_stamp(file: PathType, data: dict[str, str]): json.dump(data, fp) -def _exec( - cmd: list[str], - msg: str | None = None, - env: dict[str, str] | None = None, - stream_output: bool = False, -) -> tuple[bool, str | None, str | None]: - """ - Execute a command and return results. - - Args: - cmd: list of command arguments - msg: Optional custom message for logging - env: Optional dictionary of environment variables to set - stream_output: If True, inherit parent stdio so the subprocess prints - directly to the terminal (useful for commands that produce their - own progress output). stdout/stderr are not captured in this mode. - - Returns: - tuple of (success: bool, stdout: str or None, stderr: str or None). - When stream_output is True, stdout and stderr are always None. - """ - cmd_str = msg or " ".join(cmd) - try: - _LOGGER.debug("%s - running ...", cmd_str) - - run_env = os.environ.copy() - if env: - run_env.update(env) - - if stream_output: - result = subprocess.run(cmd, check=False, env=run_env) - stdout = stderr = None - else: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - check=False, - env=run_env, - ) - stdout = result.stdout - stderr = result.stderr - - if result.returncode != 0: - if stream_output: - _LOGGER.error("%s - failed (returncode=%s)", cmd_str, result.returncode) - else: - tail = (stderr or stdout or "").strip()[-1000:] - _LOGGER.error( - "%s - failed (returncode=%s). Tail:\n%s", - cmd_str, - result.returncode, - tail, - ) - return False, stdout, stderr - - _LOGGER.debug("%s - executed successfully", cmd_str) - return True, stdout, stderr - - except (subprocess.SubprocessError, OSError) as e: - _LOGGER.error("%s - error: %s", cmd_str, str(e)) - return False, None, None - - -def _exec_ok(*args, **kwargs) -> bool: - """ - Execute a command and return only the success status. - - Args: - *args: Positional arguments to pass to _exec function - **kwargs: Keyword arguments to pass to _exec function - - Returns: - True if command executed successfully, False otherwise - """ - return _exec(*args, **kwargs)[0] - - def _get_idf_version( idf_framework_root: PathType, env: dict[str, str] | None = None ) -> str: @@ -306,12 +162,12 @@ def _get_idf_version( """ cmd = [ - _get_pythonexe_path(), + get_system_python_path(), str(_SCRIPTS_DIR / "get_idf_version.py"), str(idf_framework_root), ] - success, stdout, stderr = _exec( + success, stdout, stderr = run_command( cmd, msg="ESP-IDF version", env=(env or os.environ) @@ -346,12 +202,12 @@ def _get_idf_tool_paths( """ cmd = [ - _get_pythonexe_path(), + get_system_python_path(), str(_SCRIPTS_DIR / "get_idf_tool_paths.py"), str(idf_framework_root), ] - success, stdout, stderr = _exec( + success, stdout, stderr = run_command( cmd, msg="ESP-IDF tool paths", env=(env or os.environ) @@ -397,7 +253,7 @@ print(".".join([str(x) for x in sys.version_info])) """ cmd = [python_executable, "-c", script] - success, stdout, _ = _exec(cmd, msg="Python version", env=env) + success, stdout, _ = run_command(cmd, msg="Python version", env=env) if stdout: stdout = stdout.strip() @@ -406,393 +262,6 @@ print(".".join([str(x) for x in sys.version_info])) return stdout -def _create_venv(root: PathType, msg: str | None = None): - """ - Create a Python virtual environment. - - Args: - root: Path to the virtual environment directory - msg: Optional message for logging - - Returns: - None - - Raises: - Exception: If virtual environment creation fails - """ - cmd = [_get_pythonexe_path(), "-m", "venv", "--clear", root] - if not _exec_ok(cmd, msg=f"Create Python virtual environment for {msg}"): - raise RuntimeError(f"Can't create Python virtual environment for {msg}") - - -def _detect_archive_root(names: Iterable[str]) -> str | None: - """Detect a single top-level directory shared by all archive entries. - - Returns the directory name if every non-empty entry sits under the same - top-level directory, else ``None``. Extraction helpers use this to strip - the wrapper directory commonly found in source archives during extraction - rather than renaming it afterwards — post-extraction renames are - unreliable on Windows because antivirus and the search indexer briefly - hold handles on freshly written files. - """ - root: str | None = None - has_descendant = False - for raw in names: - name = raw.replace("\\", "/").strip("/") - if not name: - continue - first, sep, _ = name.partition("/") - if root is None: - root = first - elif root != first: - return None - if sep: - has_descendant = True - return root if has_descendant else None - - -def _tar_extract_all( - data: io.BufferedIOBase, - extract_dir: PathType = ".", - progress_header: str | None = None, -): - """ - Extract a TAR archive to the specified directory. - - Implementation is inspired by Python 3.12's tarfile data filtering logic. - This can be replaced with the standard library implementation once - support for Python 3.11 is no longer required. - - Args: - data: File-like object containing the TAR archive - extract_dir: Directory to extract contents to - progress_header: If set, show a progress bar with this header - """ - import stat - import tarfile - - # Tar extraction safety: os.path.realpath / commonpath / normpath have no - # pathlib equivalents and Path.resolve() would follow symlinks unsafely. - # Use os.path for the security-sensitive parts; the simple checks move to - # Path. - extract_dir = os.fspath(extract_dir) - abs_dest = os.path.abspath(extract_dir) # noqa: PTH100 - - with tarfile.open(fileobj=data, mode="r") as tar_ref: - all_members = tar_ref.getmembers() - - # Detect a single common top-level directory and strip it during - # extraction so we don't have to flatten it via a rename afterwards. - strip_root = _detect_archive_root(m.name for m in all_members) - strip_prefix = f"{strip_root}/" if strip_root is not None else None - - safe_members = [] - - for member in all_members: - name = member.name - - # 1. Strip leading slashes - name = name.lstrip("/" + os.sep) - - # 2. Reject absolute paths (incl. Windows drive) - if Path(name).is_absolute() or ( - os.name == "nt" and ":" in name.split(os.sep)[0] # noqa: PTH206 - ): - continue - - # 3. Strip wrapper directory if one was detected - if strip_prefix is not None: - norm = name.replace("\\", "/") - if norm in (strip_root, strip_prefix): - continue - if not norm.startswith(strip_prefix): - continue - name = norm[len(strip_prefix) :] - - # 4. Compute final path - target_path = os.path.realpath(os.path.join(abs_dest, name)) # noqa: PTH118 - if os.path.commonpath([abs_dest, target_path]) != abs_dest: - continue - - # 5. Validate links properly - if member.issym() or member.islnk(): - linkname = member.linkname - - # Reject absolute link targets - if Path(linkname).is_absolute(): - continue - - # Strip leading slashes - linkname = os.path.normpath(linkname) - - if member.issym(): - link_target = os.path.join( # noqa: PTH118 - abs_dest, - os.path.dirname(name), # noqa: PTH120 - linkname, - ) - else: - link_target = os.path.join(abs_dest, linkname) # noqa: PTH118 - link_target = os.path.realpath(link_target) - - if os.path.commonpath([abs_dest, link_target]) != abs_dest: - continue - - # write back normalized linkname - member.linkname = linkname - - # 6. Sanitize permissions - mode = member.mode - if mode is not None: - # Strip high bits & group/other write bits - mode &= ( - stat.S_IRWXU - | stat.S_IRGRP - | stat.S_IXGRP - | stat.S_IROTH - | stat.S_IXOTH - ) - if member.isfile() or member.islnk(): - # remove exec bits unless explicitly user-executable - if not (mode & stat.S_IXUSR): - mode &= ~(stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) - mode |= stat.S_IRUSR | stat.S_IWUSR - elif not (member.isdir() or member.issym()): - # Block special files. Directories and symlinks keep - # their masked-original mode — passing None here would - # crash tarfile.extract on Python <3.12 (its chmod - # path calls os.chmod unconditionally). - continue - - member.mode = mode - - # 7. Strip ownership - member.uid = None - member.gid = None - member.uname = None - member.gname = None - - # 8. Assign sanitized name back - member.name = name - - safe_members.append(member) - - total = len(safe_members) - progress = ( - ProgressBar(progress_header) if progress_header and total > 0 else None - ) - for i, member in enumerate(safe_members, 1): - tar_ref.extract(member, abs_dest) - if progress is not None: - progress.update(i / total) - if progress is not None: - progress.update(1) - - -def _zip_extract_all( - data: io.BufferedIOBase, - extract_dir: PathType = ".", - progress_header: str | None = None, -): - """ - Extract a ZIP archive to the specified directory. - - Args: - data: File-like object containing the ZIP archive - extract_dir: Directory to extract contents to - progress_header: If set, show a progress bar with this header - """ - import zipfile - - # See note in archive_extract_all_tar: os.path is used intentionally for - # the security-sensitive abspath/commonpath checks below. - extract_dir = os.path.abspath(extract_dir) # noqa: PTH100 - - with zipfile.ZipFile(data, "r") as zip_ref: - all_members = zip_ref.infolist() - - # Detect a single common top-level directory and strip it during - # extraction so we don't have to flatten it via a rename afterwards. - strip_root = _detect_archive_root(m.filename for m in all_members) - strip_prefix = f"{strip_root}/" if strip_root is not None else None - - total = len(all_members) - progress = ( - ProgressBar(progress_header) if progress_header and total > 0 else None - ) - - for i, member in enumerate(all_members, 1): - # 1. Normalize name - name = member.filename.lstrip("/\\") - - # 2. Reject absolute paths / Windows drives - if Path(name).is_absolute() or ( - os.name == "nt" and ":" in name.split(os.sep)[0] # noqa: PTH206 - ): - continue - - # 3. Strip wrapper directory if one was detected - if strip_prefix is not None: - norm = name.replace("\\", "/") - if norm in (strip_root, strip_prefix): - continue - if not norm.startswith(strip_prefix): - continue - name = norm[len(strip_prefix) :] - - # 4. Compute safe target path - target_path = os.path.abspath(os.path.join(extract_dir, name)) # noqa: PTH100, PTH118 - - if os.path.commonpath([extract_dir, target_path]) != extract_dir: - raise ValueError(f"Unsafe path detected: {member.filename}") - - # 5. Assign sanitized name back - member.filename = name - - # 6. Extract - zip_ref.extract(member, extract_dir) - - if progress is not None: - progress.update(i / total) - if progress is not None: - progress.update(1) - - -_ARCHIVE_MAGIC_MAP = { - b"\x1f\x8b\x08": _tar_extract_all, - b"\x42\x5a\x68": _tar_extract_all, - b"\xfd\x37\x7a\x58\x5a\x00": _tar_extract_all, - b"\x50\x4b\x03\x04": _zip_extract_all, -} - - -def archive_extract_all( - archive: PathType | io.RawIOBase | IO[bytes], - extract_dir: PathType = ".", - progress_header: str | None = None, -): - """ - Extract an archive file to the specified directory. - - Args: - archive: Path to archive file or file-like object - extract_dir: Directory to extract contents to - progress_header: If set, show a progress bar with this header - - Raises: - TypeError: If archive is not a valid type - ValueError: If archive format is unsupported - """ - - # 1. Handle different archive input types - with ExitStack() as stack: - archive_ref: io.BufferedIOBase - if isinstance(archive, (str, os.PathLike)): - archive_ref = stack.enter_context(Path(archive).open("rb")) - elif isinstance(archive, (io.BufferedReader, io.BufferedRandom)): - archive_ref = archive - elif isinstance(archive, io.RawIOBase): - archive_ref = io.BufferedReader(archive) - else: - raise TypeError( - f"archive must be str, Path, or file-like object: {type(archive)}" - ) - - # 2. Detect archive format and select appropriate extraction function - matched_fct = None - magic_len = max(len(k) for k in _ARCHIVE_MAGIC_MAP) - header = archive_ref.peek(magic_len) - for magic, fct in _ARCHIVE_MAGIC_MAP.items(): - if header.startswith(magic): - matched_fct = fct - break - if matched_fct is None: - raise ValueError("Unsupported archive format") - matched_fct(archive_ref, extract_dir, progress_header=progress_header) - - -def download_from_mirrors( - mirrors: list[str], - substitutions: dict[str, str], - target: io.RawIOBase | IO[bytes] | PathType, - timeout: int = 30, -) -> str | None: - """ - Download file from multiple mirrors with substitution support. - - Args: - mirrors: list of mirror URLs - substitutions: Dictionary of substitutions to apply to URLs - target: Target file path or file-like object - timeout: Download timeout in seconds - - Returns: - The source URL. - - Raises: - Exception: If all download attempts fail - """ - # 1. Open target file for writing if path given - with ExitStack() as stack: - if isinstance(target, (str, os.PathLike)): - f = stack.enter_context(Path(target).open("wb")) - elif isinstance(target, (io.RawIOBase, io.IOBase)): - f = target - else: - raise TypeError( - f"target must be str, Path, or file-like object: {type(target)}" - ) - - # 2. Try each mirror in order - last_exception = None - - for mirror in mirrors: - # 3. Apply substitutions to URL - url = mirror.format(**substitutions) - - _LOGGER.debug("Trying downloading from %s", url) - - try: - # 4. Reset file pointer and download - f.seek(0) - f.truncate(0) - - with requests.get(url, stream=True, timeout=timeout) as r: - r.raise_for_status() - - total_size = int(r.headers.get("content-length", 0)) - downloaded = 0 - - progress = ProgressBar("Downloading") if total_size > 0 else None - - for chunk in r.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - - downloaded += len(chunk) - - if progress is not None: - progress.update(downloaded / total_size) - - if progress is not None: - progress.update(1) - - _LOGGER.debug("Downloaded successfully from: %s", url) - - # 6. Reset file pointer and return - f.seek(0) - return url - - except Exception as e: # noqa: BLE001 # pylint: disable=broad-exception-caught - _LOGGER.debug("Failed to download %s: %s", url, str(e)) - last_exception = e - - # 7. Raise last exception if all mirrors failed - if last_exception: - raise last_exception - return None - - _GITHUB_SHORTHAND_RE = re.compile( r"^github://([a-zA-Z0-9\-]+)/([a-zA-Z0-9\-\._]+?)(?:@([a-zA-Z0-9\-_.\./]+))?$" ) @@ -1067,12 +536,12 @@ def _check_esphome_idf_framework_install( if _check_stamp(env_stamp_file, stamp_info): _LOGGER.info("Checking ESP-IDF %s framework installation ...", version) cmd = [ - _get_pythonexe_path(), + get_system_python_path(), str(idf_tools_path), "--non-interactive", "check", ] - if _exec_ok(cmd, msg=f"ESP-IDF {version} check", env=env): + if run_command_ok(cmd, msg=f"ESP-IDF {version} check", env=env): install = False # 4. Install framework tools if not installed or needs update @@ -1080,13 +549,13 @@ def _check_esphome_idf_framework_install( _LOGGER.info("Installing ESP-IDF %s framework ...", version) targets_str = ",".join(targets) cmd = [ - _get_pythonexe_path(), + get_system_python_path(), str(idf_tools_path), "--non-interactive", "install", f"--targets={targets_str}", ] + tools - if not _exec_ok( + if not run_command_ok( cmd, msg=f"ESP-IDF {version} framework installation", env=env, @@ -1128,7 +597,7 @@ def _check_esp_idf_python_env_install( framework_path = _get_framework_path(version) python_env_path = _get_python_env_path(version) env_stamp_file = python_env_path / ESPHOME_STAMP_FILE - env_python_path = _get_python_env_executable_path(python_env_path, "python") + env_python_path = get_python_env_executable_path(python_env_path, "python") _LOGGER.info("Checking ESP-IDF %s Python environment ...", version) install = force or not python_env_path.is_dir() or not env_python_path.is_file() @@ -1144,7 +613,7 @@ def _check_esp_idf_python_env_install( if install: rmdir(python_env_path, msg=f"Clean up ESP-IDF {version} Python environment") - _create_venv(python_env_path, msg=f"ESP-IDF {version}") + create_venv(python_env_path, msg=f"ESP-IDF {version}") esp_idf_version = _get_idf_version(framework_path, env=env) constraint_file_path = ( @@ -1174,7 +643,7 @@ def _check_esp_idf_python_env_install( "pip", "setuptools", ] - if not _exec_ok( + if not run_command_ok( cmd, msg=f"Upgrade ESP-IDF {version} Python environment packages", env=env, @@ -1194,7 +663,7 @@ def _check_esp_idf_python_env_install( "-r", str(requirements_file), ] - if not _exec_ok( + if not run_command_ok( cmd, msg=f"Install ESP-IDF {version} Python dependencies for {feature}", env=env, @@ -1296,7 +765,7 @@ def get_framework_env( # 3. If Python environment path is provided, add it to PATH and set IDF_PYTHON_ENV_PATH if python_env_path: - python_path = _get_python_env_executable_path(python_env_path, "python") + python_path = get_python_env_executable_path(python_env_path, "python") path_list.insert(0, str(python_path.parent)) env["IDF_PYTHON_ENV_PATH"] = str(python_env_path) diff --git a/esphome/framework_helpers.py b/esphome/framework_helpers.py new file mode 100644 index 0000000000..276dfbbf1c --- /dev/null +++ b/esphome/framework_helpers.py @@ -0,0 +1,677 @@ +"""Generic toolchain installation helpers shared across framework implementations.""" + +from collections.abc import Iterable +from contextlib import ExitStack +import io +import logging +import os +from pathlib import Path +import subprocess +import sys +import time +from typing import IO + +import requests + +from esphome.helpers import ProgressBar, rmtree + +PathType = str | os.PathLike + +_LOGGER = logging.getLogger(__name__) + + +def str_to_lst_of_str(a: str | list[str]) -> list[str]: + """ + Convert a string to a list of string + + Args: + a: A string containing semicolon-separated values, or an already-split list + + Returns: + list of strings + """ + if isinstance(a, list): + return a + return [f.strip() for f in a.split(";") if f.strip()] + + +def rmdir(directory: PathType, msg: str | None = None): + """ + Remove a directory and its contents recursively if it exists. + + Args: + directory: Path to the directory to be removed + msg: Optional debug message to log before removal or it an error occurs + + Returns: + None + + Raises: + RuntimeError: If directory removal fails + """ + if Path(directory).is_dir(): + try: + if msg: + _LOGGER.debug(msg) + rmtree(directory) + except OSError as e: + raise RuntimeError( + f"Error during {msg}: can't remove `{directory}`. Please remove it manually!" + ) from e + + +def get_system_python_path() -> str: + """ + Get the path to the Python executable. + + Returns: + Path to Python executable as string + """ + # Try to get PYTHONEXEPATH environment variable + # Fallback to sys.executable if not set + return os.environ.get("PYTHONEXEPATH", os.path.normpath(sys.executable)) + + +def get_python_env_executable_path(root: PathType, binary: str) -> Path: + """ + Get the path to a Python environment executable file. + + Args: + root: Root directory of the Python environment + binary: Name of the executable binary + + Returns: + Path object pointing to the executable file + """ + if os.name == "nt": + return Path(root) / "Scripts" / f"{binary}.exe" + return Path(root) / "bin" / binary + + +def run_command( + cmd: list[str], + msg: str | None = None, + env: dict[str, str] | None = None, + stream_output: bool = False, + cwd: PathType | None = None, +) -> tuple[bool, str | None, str | None]: + """ + Execute a command and return results. + + Args: + cmd: list of command arguments + msg: Optional custom message for logging + env: Optional dictionary of environment variables to set + stream_output: If True, inherit parent stdio so the subprocess prints + directly to the terminal (useful for commands that produce their + own progress output). stdout/stderr are not captured in this mode. + cwd: Optional working directory for the subprocess. + + Returns: + tuple of (success: bool, stdout: str or None, stderr: str or None). + When stream_output is True, stdout and stderr are always None. + """ + cmd_str = msg or " ".join(cmd) + try: + _LOGGER.debug("%s - running ...", cmd_str) + + run_env = os.environ.copy() + if env: + run_env.update(env) + + if stream_output: + result = subprocess.run(cmd, check=False, env=run_env, cwd=cwd) + stdout = stderr = None + else: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=False, + env=run_env, + cwd=cwd, + ) + stdout = result.stdout + stderr = result.stderr + + if result.returncode != 0: + if stream_output: + _LOGGER.error("%s - failed (returncode=%s)", cmd_str, result.returncode) + else: + tail = (stderr or stdout or "").strip()[-1000:] + _LOGGER.error( + "%s - failed (returncode=%s). Tail:\n%s", + cmd_str, + result.returncode, + tail, + ) + return False, stdout, stderr + + _LOGGER.debug("%s - executed successfully", cmd_str) + return True, stdout, stderr + + except (subprocess.SubprocessError, OSError) as e: + _LOGGER.error("%s - error: %s", cmd_str, str(e)) + return False, None, None + + +def run_command_ok(*args, **kwargs) -> bool: + """ + Execute a command and return only the success status. + + Args: + *args: Positional arguments to pass to run_command + **kwargs: Keyword arguments to pass to run_command + + Returns: + True if command executed successfully, False otherwise + """ + return run_command(*args, **kwargs)[0] + + +def create_venv(root: PathType, msg: str | None = None): + """ + Create a Python virtual environment. + + Args: + root: Path to the virtual environment directory + msg: Optional message for logging + + Returns: + None + + Raises: + RuntimeError: If virtual environment creation fails + """ + cmd = [get_system_python_path(), "-m", "venv", "--clear", root] + if not run_command_ok(cmd, msg=f"Create Python virtual environment for {msg}"): + raise RuntimeError(f"Can't create Python virtual environment for {msg}") + + +def _detect_archive_root(names: Iterable[str]) -> str | None: + """Detect a single top-level directory shared by all archive entries. + + Returns the directory name if every non-empty entry sits under the same + top-level directory, else ``None``. Extraction helpers use this to strip + the wrapper directory commonly found in source archives during extraction + rather than renaming it afterwards — post-extraction renames are + unreliable on Windows because antivirus and the search indexer briefly + hold handles on freshly written files. + """ + root: str | None = None + has_descendant = False + for raw in names: + name = raw.replace("\\", "/").strip("/") + if not name: + continue + first, sep, _ = name.partition("/") + if root is None: + root = first + elif root != first: + return None + if sep: + has_descendant = True + return root if has_descendant else None + + +def _tar_extract_all( + data: io.BufferedIOBase, + extract_dir: PathType = ".", + progress_header: str | None = None, +): + """ + Extract a TAR archive to the specified directory. + + Implementation is inspired by Python 3.12's tarfile data filtering logic. + This can be replaced with the standard library implementation once + support for Python 3.11 is no longer required. + + Args: + data: File-like object containing the TAR archive + extract_dir: Directory to extract contents to + progress_header: If set, show a progress bar with this header + """ + import stat + import tarfile + + # Tar extraction safety: os.path.realpath / commonpath / normpath have no + # pathlib equivalents and Path.resolve() would follow symlinks unsafely. + # Use os.path for the security-sensitive parts; the simple checks move to + # Path. + extract_dir = os.fspath(extract_dir) + abs_dest = os.path.abspath(extract_dir) # noqa: PTH100 + + with tarfile.open(fileobj=data, mode="r") as tar_ref: + all_members = tar_ref.getmembers() + + # Detect a single common top-level directory and strip it during + # extraction so we don't have to flatten it via a rename afterwards. + strip_root = _detect_archive_root(m.name for m in all_members) + strip_prefix = f"{strip_root}/" if strip_root is not None else None + + safe_members = [] + + for member in all_members: + name = member.name + + # 1. Strip leading slashes + name = name.lstrip("/" + os.sep) + + # 2. Reject absolute paths (incl. Windows drive) + if Path(name).is_absolute() or ( + os.name == "nt" and ":" in name.split(os.sep)[0] # noqa: PTH206 + ): + continue + + # 3. Strip wrapper directory if one was detected + if strip_prefix is not None: + norm = name.replace("\\", "/") + if norm in (strip_root, strip_prefix): + continue + if not norm.startswith(strip_prefix): + continue + name = norm[len(strip_prefix) :] + + # 4. Compute final path + target_path = os.path.realpath(os.path.join(abs_dest, name)) # noqa: PTH118 + if os.path.commonpath([abs_dest, target_path]) != abs_dest: + continue + + # 5. Validate links properly + if member.issym() or member.islnk(): + linkname = member.linkname + + # Reject absolute link targets + if Path(linkname).is_absolute(): + continue + + if member.islnk() and strip_prefix is not None: + # Hard-link linknames reference another archive member + # by its archive name. We've stripped the wrapper prefix + # from member.name above (step 3); strip it here too so + # tarfile._find_link_target can resolve the target during + # extraction. Symlink linknames are filesystem-relative + # paths, not archive-member references, so they don't + # need this treatment. + norm_link = linkname.replace("\\", "/") + if norm_link in (strip_root, strip_prefix): + continue + if not norm_link.startswith(strip_prefix): + continue + linkname = norm_link[len(strip_prefix) :] + + # Strip leading slashes + linkname = os.path.normpath(linkname) + + if member.issym(): + link_target = os.path.join( # noqa: PTH118 + abs_dest, + os.path.dirname(name), # noqa: PTH120 + linkname, + ) + else: + link_target = os.path.join(abs_dest, linkname) # noqa: PTH118 + link_target = os.path.realpath(link_target) + + if os.path.commonpath([abs_dest, link_target]) != abs_dest: + continue + + # write back normalized linkname + member.linkname = linkname + + # 6. Sanitize permissions + mode = member.mode + if mode is not None: + # Strip high bits & group/other write bits + mode &= ( + stat.S_IRWXU + | stat.S_IRGRP + | stat.S_IXGRP + | stat.S_IROTH + | stat.S_IXOTH + ) + if member.isfile() or member.islnk(): + # remove exec bits unless explicitly user-executable + if not (mode & stat.S_IXUSR): + mode &= ~(stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) + mode |= stat.S_IRUSR | stat.S_IWUSR + elif not (member.isdir() or member.issym()): + # Block special files. Directories and symlinks keep + # their masked-original mode — passing None here would + # crash tarfile.extract on Python <3.12 (its chmod + # path calls os.chmod unconditionally). + continue + + member.mode = mode + + # 7. Strip ownership + member.uid = None + member.gid = None + member.uname = None + member.gname = None + + # 8. Assign sanitized name back + member.name = name + + safe_members.append(member) + + total = len(safe_members) + progress = ( + ProgressBar(progress_header) if progress_header and total > 0 else None + ) + for i, member in enumerate(safe_members, 1): + tar_ref.extract(member, abs_dest) + if progress is not None: + progress.update(i / total) + if progress is not None: + progress.update(1) + + +def _zip_extract_all( + data: io.BufferedIOBase, + extract_dir: PathType = ".", + progress_header: str | None = None, +): + """ + Extract a ZIP archive to the specified directory. + + Args: + data: File-like object containing the ZIP archive + extract_dir: Directory to extract contents to + progress_header: If set, show a progress bar with this header + """ + import zipfile + + # See note in _tar_extract_all: os.path is used intentionally for + # the security-sensitive abspath/commonpath checks below. + extract_dir = os.path.abspath(extract_dir) # noqa: PTH100 + + with zipfile.ZipFile(data, "r") as zip_ref: + all_members = zip_ref.infolist() + + # Detect a single common top-level directory and strip it during + # extraction so we don't have to flatten it via a rename afterwards. + strip_root = _detect_archive_root(m.filename for m in all_members) + strip_prefix = f"{strip_root}/" if strip_root is not None else None + + total = len(all_members) + progress = ( + ProgressBar(progress_header) if progress_header and total > 0 else None + ) + + for i, member in enumerate(all_members, 1): + # 1. Normalize name + name = member.filename.lstrip("/\\") + + # 2. Reject absolute paths / Windows drives + if Path(name).is_absolute() or ( + os.name == "nt" and ":" in name.split(os.sep)[0] # noqa: PTH206 + ): + continue + + # 3. Strip wrapper directory if one was detected + if strip_prefix is not None: + norm = name.replace("\\", "/") + if norm in (strip_root, strip_prefix): + continue + if not norm.startswith(strip_prefix): + continue + name = norm[len(strip_prefix) :] + + # 4. Compute safe target path + target_path = os.path.abspath(os.path.join(extract_dir, name)) # noqa: PTH100, PTH118 + + if os.path.commonpath([extract_dir, target_path]) != extract_dir: + raise ValueError(f"Unsafe path detected: {member.filename}") + + # 5. Assign sanitized name back + member.filename = name + + # 6. Extract + zip_ref.extract(member, extract_dir) + + if progress is not None: + progress.update(i / total) + if progress is not None: + progress.update(1) + + +def _rename_with_retry(src: Path, dst: Path, attempts: int = 5) -> None: + """Rename ``src`` to ``dst`` with backoff retries on Windows sharing violations. + + Antivirus/indexer handles on freshly-written files can briefly block + ``os.rename`` with ERROR_SHARING_VIOLATION / ERROR_ACCESS_DENIED. The + handle is released within tens of ms in practice, so exponential backoff + works. + """ + for i in range(attempts): + try: + src.rename(dst) + return + except PermissionError: + if i == attempts - 1: + raise + time.sleep(0.1 * (2**i)) + + +def _7z_extract_all( + data: io.BufferedIOBase, + extract_dir: PathType = ".", + progress_header: str | None = None, +): + """ + Extract a 7z archive to the specified directory. + + py7zr only supports bulk extraction (no per-member rename hook like + tarfile/zipfile), so we extract into a unique staging subdir of + ``extract_dir`` and then move children up. This keeps everything on + the same volume and sidesteps wrapper-vs-child name collisions + (e.g. ``arm-zephyr-eabi/`` containing another ``arm-zephyr-eabi/``). + + Args: + data: File-like object containing the 7z archive (must be seekable) + extract_dir: Directory to extract contents to + progress_header: If set, show a progress bar with this header + """ + import py7zr + + extract_dir = os.path.abspath(extract_dir) # noqa: PTH100 + Path(extract_dir).mkdir(parents=True, exist_ok=True) + + suffix = 0 + while True: + staging = Path(extract_dir) / f".extract_tmp_{suffix}" + if not staging.exists(): + break + suffix += 1 + staging.mkdir() + + try: + with py7zr.SevenZipFile(data, "r") as z: + all_names = z.getnames() + + # Detect a single common top-level directory to flatten. + strip_root = _detect_archive_root(all_names) + + # Validate names: reject absolute paths, Windows drives, and + # path traversal. Filter via targets= since py7zr can't rename + # per-member. + safe_targets: list[str] = [] + for raw in all_names: + name = raw.lstrip("/\\") + if not name: + continue + if Path(name).is_absolute() or ( + os.name == "nt" and ":" in name.split(os.sep)[0] # noqa: PTH206 + ): + continue + target_path = os.path.abspath(os.path.join(staging, name)) # noqa: PTH100, PTH118 + if os.path.commonpath([str(staging), target_path]) != str(staging): + continue + safe_targets.append(raw) + + progress = ( + ProgressBar(progress_header) + if progress_header and safe_targets + else None + ) + + if len(safe_targets) == len(all_names): + z.extractall(path=staging) + else: + z.extract(path=staging, targets=safe_targets) + + if progress is not None: + progress.update(1) + + src_root = staging / strip_root if strip_root else staging + for item in src_root.iterdir(): + dest = Path(extract_dir) / item.name + if dest.exists(): + if dest.is_dir(): + rmtree(dest) + else: + dest.unlink() + _rename_with_retry(item, dest) + finally: + # staging is created before the try, so it always exists here; the + # guard is defensive cleanup and its False branch is unreachable. + if staging.exists(): # pragma: no cover + rmtree(staging) + + +_ARCHIVE_MAGIC_MAP = { + b"\x1f\x8b\x08": _tar_extract_all, + b"\x42\x5a\x68": _tar_extract_all, + b"\xfd\x37\x7a\x58\x5a\x00": _tar_extract_all, + b"\x50\x4b\x03\x04": _zip_extract_all, + b"\x37\x7a\xbc\xaf\x27\x1c": _7z_extract_all, +} + + +def archive_extract_all( + archive: PathType | io.RawIOBase | IO[bytes], + extract_dir: PathType = ".", + progress_header: str | None = None, +): + """ + Extract an archive file to the specified directory. + + Args: + archive: Path to archive file or file-like object + extract_dir: Directory to extract contents to + progress_header: If set, show a progress bar with this header + + Raises: + TypeError: If archive is not a valid type + ValueError: If archive format is unsupported + """ + + # 1. Handle different archive input types + with ExitStack() as stack: + archive_ref: io.BufferedIOBase + if isinstance(archive, (str, os.PathLike)): + archive_ref = stack.enter_context(Path(archive).open("rb")) + elif isinstance(archive, (io.BufferedReader, io.BufferedRandom)): + archive_ref = archive + elif isinstance(archive, io.RawIOBase): + archive_ref = io.BufferedReader(archive) + else: + raise TypeError( + f"archive must be str, Path, or file-like object: {type(archive)}" + ) + + # 2. Detect archive format and select appropriate extraction function + matched_fct = None + magic_len = max(len(k) for k in _ARCHIVE_MAGIC_MAP) + header = archive_ref.peek(magic_len) + for magic, fct in _ARCHIVE_MAGIC_MAP.items(): + if header.startswith(magic): + matched_fct = fct + break + if matched_fct is None: + raise ValueError("Unsupported archive format") + matched_fct(archive_ref, extract_dir, progress_header=progress_header) + + +def download_from_mirrors( + mirrors: list[str], + substitutions: dict[str, str], + target: io.RawIOBase | IO[bytes] | PathType, + timeout: int = 30, +) -> str: + """ + Download file from multiple mirrors with substitution support. + + Args: + mirrors: list of mirror URLs + substitutions: Dictionary of substitutions to apply to URLs + target: Target file path or file-like object + timeout: Download timeout in seconds + + Returns: + The source URL. + + Raises: + ValueError: If mirrors list is empty. + Exception: If all download attempts fail. + """ + # 1. Open target file for writing if path given + with ExitStack() as stack: + if isinstance(target, (str, os.PathLike)): + f = stack.enter_context(Path(target).open("wb")) + elif isinstance(target, (io.RawIOBase, io.IOBase)): + f = target + else: + raise TypeError( + f"target must be str, Path, or file-like object: {type(target)}" + ) + + # 2. Try each mirror in order + last_exception = None + + for mirror in mirrors: + # 3. Apply substitutions to URL + url = mirror.format(**substitutions) + + _LOGGER.debug("Trying downloading from %s", url) + + try: + # 4. Reset file pointer and download + f.seek(0) + f.truncate(0) + + with requests.get(url, stream=True, timeout=timeout) as r: + r.raise_for_status() + + total_size = int(r.headers.get("content-length", 0)) + downloaded = 0 + + progress = ProgressBar("Downloading") if total_size > 0 else None + + for chunk in r.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + + downloaded += len(chunk) + + if progress is not None: + progress.update(downloaded / total_size) + + if progress is not None: + progress.update(1) + + _LOGGER.debug("Downloaded successfully from: %s", url) + + # 6. Reset file pointer and return + f.seek(0) + return url + + except Exception as e: # noqa: BLE001 # pylint: disable=broad-exception-caught + _LOGGER.debug("Failed to download %s: %s", url, str(e)) + last_exception = e + + # 7. Raise last exception if all mirrors failed + if last_exception: + raise last_exception + raise ValueError("download_from_mirrors called with an empty mirrors list") diff --git a/requirements.txt b/requirements.txt index 8202a2bb44..ed7f2c2941 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,7 @@ jinja2==3.1.6 bleak==2.1.1 smpclient==6.0.0 requests==2.34.2 +py7zr==0.22.0 # esp-idf >= 5.0 requires this pyparsing >= 3.3.2 diff --git a/tests/unit_tests/test_core.py b/tests/unit_tests/test_core.py index 2322fdd014..cc371ee1f9 100644 --- a/tests/unit_tests/test_core.py +++ b/tests/unit_tests/test_core.py @@ -894,6 +894,13 @@ class TestEsphomeCore: "foo/build/.pioenvs/test-device/bootloader.bin" ) + def test_using_toolchain_sdk_nrf(self, target): + """using_toolchain_sdk_nrf is True only for the SDK_NRF toolchain.""" + target.toolchain = const.Toolchain.SDK_NRF + assert target.using_toolchain_sdk_nrf is True + target.toolchain = const.Toolchain.ESP_IDF + assert target.using_toolchain_sdk_nrf is False + def test_add_library__extracts_short_name_from_path(self, target): """Test add_library extracts short name from library paths like owner/lib.""" target.data[const.KEY_CORE] = { diff --git a/tests/unit_tests/test_espidf_framework.py b/tests/unit_tests/test_espidf_framework.py index 9f4e4fcca8..036c7c0454 100644 --- a/tests/unit_tests/test_espidf_framework.py +++ b/tests/unit_tests/test_espidf_framework.py @@ -2,12 +2,32 @@ # pylint: disable=protected-access +import io +import json from pathlib import Path +import tarfile +from types import SimpleNamespace from unittest.mock import patch import pytest -from esphome.espidf.framework import _clone_idf_with_submodules, _parse_git_source +from esphome.espidf.framework import ( + _check_stamp, + _clone_idf_with_submodules, + _get_framework_path, + _get_idf_tool_paths, + _get_idf_tools_path, + _get_idf_version, + _get_python_env_path, + _get_python_version, + _parse_git_source, + _patch_tools_json_for_linux_arm64, + _write_idf_version_txt, + _write_stamp, + check_esp_idf_install, + get_framework_env, +) +from esphome.framework_helpers import _tar_extract_all, get_python_env_executable_path @pytest.mark.parametrize( @@ -154,3 +174,511 @@ def test_clone_idf_with_submodules_raises_when_tree_missing( "https://github.com/espressif/esp-idf.git", None, ) + + +# --------------------------------------------------------------------------- +# Helpers for _tar_extract_all hard-link prefix-stripping tests +# --------------------------------------------------------------------------- + + +def _make_tar( + members: list[tarfile.TarInfo], file_contents: dict[str, bytes] +) -> io.BytesIO: + """Build an in-memory tar archive from a list of TarInfo objects.""" + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tf: + for info in members: + if info.isreg() and info.name in file_contents: + data = file_contents[info.name] + info.size = len(data) + tf.addfile(info, io.BytesIO(data)) + else: + tf.addfile(info) + buf.seek(0) + return buf + + +def _regular(name: str) -> tarfile.TarInfo: + info = tarfile.TarInfo(name=name) + info.type = tarfile.REGTYPE + info.size = 0 + info.mode = 0o644 + return info + + +def _hardlink(name: str, linkname: str) -> tarfile.TarInfo: + info = tarfile.TarInfo(name=name) + info.type = tarfile.LNKTYPE + info.linkname = linkname + info.size = 0 + info.mode = 0o644 + return info + + +class TestTarExtractHardLinkPrefixStripping: + """ + Covers the hard-link prefix-stripping block in _tar_extract_all (L528-541). + + Archive layout used by every test: + + wrapper/ ← single top-level wrapper dir (stripped) + wrapper/target.txt ← regular file; becomes target.txt in dest + wrapper/link_good ← hard link to wrapper/target.txt (kept, linkname stripped) + wrapper/link_exact_root ← hard link to "wrapper" (skipped – equals strip_root) + wrapper/link_exact_prefix ← hard link to "wrapper/" (skipped – equals strip_prefix) + wrapper/link_outside ← hard link to "other/target.txt" (skipped – not under prefix) + """ + + WRAPPER = "wrapper" + + def _build_archive(self) -> io.BytesIO: + members = [ + _regular(f"{self.WRAPPER}/"), + _regular(f"{self.WRAPPER}/target.txt"), + _hardlink(f"{self.WRAPPER}/link_good", f"{self.WRAPPER}/target.txt"), + _hardlink(f"{self.WRAPPER}/link_exact_root", self.WRAPPER), + _hardlink(f"{self.WRAPPER}/link_exact_prefix", f"{self.WRAPPER}/"), + _hardlink(f"{self.WRAPPER}/link_outside", "other/target.txt"), + ] + return _make_tar(members, {f"{self.WRAPPER}/target.txt": b"hello"}) + + def test_good_hardlink_is_extracted_with_stripped_linkname( + self, tmp_path: Path + ) -> None: + """Hard link whose linkname starts with wrapper/ is extracted and its + linkname has the prefix removed so tarfile can resolve the target.""" + _tar_extract_all(self._build_archive(), tmp_path) + link = tmp_path / "link_good" + assert link.exists(), "link_good should have been extracted" + assert link.read_bytes() == b"hello" + + def test_hardlink_equal_to_strip_root_is_skipped(self, tmp_path: Path) -> None: + """Hard link whose linkname equals strip_root exactly must be dropped.""" + _tar_extract_all(self._build_archive(), tmp_path) + assert not (tmp_path / "link_exact_root").exists() + + def test_hardlink_equal_to_strip_prefix_is_skipped(self, tmp_path: Path) -> None: + """Hard link whose linkname equals strip_prefix (strip_root + '/') must be dropped.""" + _tar_extract_all(self._build_archive(), tmp_path) + assert not (tmp_path / "link_exact_prefix").exists() + + def test_hardlink_outside_prefix_is_skipped(self, tmp_path: Path) -> None: + """Hard link whose linkname does not start with wrapper/ must be dropped.""" + _tar_extract_all(self._build_archive(), tmp_path) + assert not (tmp_path / "link_outside").exists() + + def test_regular_file_and_no_spurious_files(self, tmp_path: Path) -> None: + """Sanity check: target.txt is extracted and no unexpected files appear.""" + _tar_extract_all(self._build_archive(), tmp_path) + assert (tmp_path / "target.txt").read_bytes() == b"hello" + extracted = {p.name for p in tmp_path.iterdir()} + assert extracted == {"target.txt", "link_good"} + + +_IDF_VERSION = "5.1.2" + + +@pytest.fixture +def espidf_mocks(setup_core: Path): + """Patch the heavy I/O of check_esp_idf_install and pre-create the framework dir.""" + # archive_extract_all is mocked, so pre-create the framework dir that the + # extracted-marker touch writes into. + _get_framework_path(_IDF_VERSION).mkdir(parents=True, exist_ok=True) + with ( + patch("esphome.espidf.framework.rmdir"), + patch( + "esphome.espidf.framework.download_from_mirrors", + return_value="https://example.com/idf.tar.xz", + ) as download, + patch("esphome.espidf.framework.archive_extract_all") as extract, + patch("esphome.espidf.framework.create_venv") as venv, + patch("esphome.espidf.framework.run_command_ok", return_value=True) as run_ok, + patch("esphome.espidf.framework._clone_idf_with_submodules") as clone, + patch("esphome.espidf.framework._write_idf_version_txt"), + patch("esphome.espidf.framework._patch_tools_json_for_linux_arm64"), + patch("esphome.espidf.framework._write_stamp"), + patch("esphome.espidf.framework._check_stamp", return_value=True), + patch("esphome.espidf.framework._get_idf_version", return_value=_IDF_VERSION), + patch("esphome.espidf.framework._get_python_version", return_value="3.11.0"), + patch("esphome.espidf.framework.get_system_python_path", return_value="python"), + ): + yield SimpleNamespace( + download=download, extract=extract, venv=venv, run_ok=run_ok, clone=clone + ) + + +def test_check_esp_idf_install_fresh(espidf_mocks: SimpleNamespace) -> None: + """A forced install drives download/extract, venv creation, and pip installs.""" + framework_path, python_env_path = check_esp_idf_install(_IDF_VERSION, force=True) + + assert framework_path == _get_framework_path(_IDF_VERSION) + assert python_env_path == _get_python_env_path(_IDF_VERSION) + # framework tarball + python-env constraints file are both downloaded + assert espidf_mocks.download.call_count == 2 + espidf_mocks.extract.assert_called_once() + espidf_mocks.venv.assert_called_once() + espidf_mocks.clone.assert_not_called() + + +def test_check_esp_idf_install_git_source(espidf_mocks: SimpleNamespace) -> None: + """A git source_url clones instead of downloading; explicit tools skip discovery.""" + check_esp_idf_install( + _IDF_VERSION, + force=True, + source_url="https://github.com/espressif/esp-idf.git", + tools=["xtensa-esp-elf"], + ) + + espidf_mocks.clone.assert_called_once() + # framework is cloned, so only the python-env constraints file is downloaded + assert espidf_mocks.download.call_count == 1 + + +def test_check_esp_idf_install_already_installed(espidf_mocks: SimpleNamespace) -> None: + """Marker + matching stamps + existing python env → nothing is re-installed.""" + framework_path = _get_framework_path(_IDF_VERSION) + (framework_path / ".esphome_extracted").touch() + python_env_path = _get_python_env_path(_IDF_VERSION) + env_python = get_python_env_executable_path(python_env_path, "python") + env_python.parent.mkdir(parents=True, exist_ok=True) + env_python.touch() + + check_esp_idf_install(_IDF_VERSION) + + espidf_mocks.extract.assert_not_called() + espidf_mocks.venv.assert_not_called() + + +def test_check_esp_idf_install_framework_failure(espidf_mocks: SimpleNamespace) -> None: + """A failing idf_tools install raises.""" + espidf_mocks.run_ok.side_effect = [False] + with pytest.raises(RuntimeError, match="framework installation failure"): + check_esp_idf_install(_IDF_VERSION, force=True) + + +def test_check_esp_idf_install_pip_upgrade_failure( + espidf_mocks: SimpleNamespace, +) -> None: + """A failing pip upgrade in the python env raises (framework install ok).""" + espidf_mocks.run_ok.side_effect = [True, False] + with pytest.raises(RuntimeError, match="Python environment packages failure"): + check_esp_idf_install(_IDF_VERSION, force=True) + + +def test_check_esp_idf_install_feature_failure(espidf_mocks: SimpleNamespace) -> None: + """A failing feature requirements install raises.""" + espidf_mocks.run_ok.side_effect = [True, True, False] + with pytest.raises(RuntimeError, match="Python dependencies for"): + check_esp_idf_install(_IDF_VERSION, force=True, features=["fb"]) + + +def _mark_installed() -> None: + """Create the extracted marker and python-env interpreter so the install + check takes the already-installed path rather than force-installing.""" + (_get_framework_path(_IDF_VERSION) / ".esphome_extracted").touch() + env_python = get_python_env_executable_path( + _get_python_env_path(_IDF_VERSION), "python" + ) + env_python.parent.mkdir(parents=True, exist_ok=True) + env_python.touch() + + +def test_check_esp_idf_install_stamp_mismatch_reinstalls( + espidf_mocks: SimpleNamespace, +) -> None: + """A stamp mismatch reinstalls tools (marker present, so no re-extract).""" + _mark_installed() + with patch("esphome.espidf.framework._check_stamp", return_value=False): + check_esp_idf_install(_IDF_VERSION) + + espidf_mocks.extract.assert_not_called() # marker present -> no re-extract + espidf_mocks.venv.assert_called_once() # tools reinstall -> venv rebuilt + + +def test_check_esp_idf_install_check_command_failure_reinstalls( + espidf_mocks: SimpleNamespace, +) -> None: + """A failing idf_tools check reinstalls tools (marker present, no re-extract).""" + _mark_installed() + # idf_tools check fails -> install stays True; the later installs succeed. + espidf_mocks.run_ok.side_effect = [False, True, True, True] + check_esp_idf_install(_IDF_VERSION, features=["fb"]) + + espidf_mocks.extract.assert_not_called() + espidf_mocks.venv.assert_called_once() + + +def test_check_esp_idf_install_unknown_python_version_reinstalls( + espidf_mocks: SimpleNamespace, +) -> None: + """An undeterminable python version rebuilds the venv (framework stamp still ok).""" + _mark_installed() + with patch("esphome.espidf.framework._get_python_version", return_value=None): + check_esp_idf_install(_IDF_VERSION) + + espidf_mocks.extract.assert_not_called() # framework stamp matched + espidf_mocks.venv.assert_called_once() # python env rebuilt + + +def test_check_esp_idf_install_python_stamp_mismatch_rebuilds_venv( + espidf_mocks: SimpleNamespace, +) -> None: + """Framework stamp matches but the python-env stamp does not -> venv rebuilt.""" + + # _check_stamp passes for the framework (no python_version key) and fails + # for the python env (carries python_version), so only the venv rebuilds. + def stamp_ok(_stamp_file, info: dict) -> bool: + return "python_version" not in info + + _mark_installed() + with patch("esphome.espidf.framework._check_stamp", side_effect=stamp_ok): + check_esp_idf_install(_IDF_VERSION) + + espidf_mocks.extract.assert_not_called() + espidf_mocks.venv.assert_called_once() + + +def test_check_esp_idf_install_unparseable_version( + espidf_mocks: SimpleNamespace, +) -> None: + """A non-semver version skips the MAJOR/MINOR substitutions without erroring.""" + bad_version = "main" + _get_framework_path(bad_version).mkdir(parents=True, exist_ok=True) + check_esp_idf_install(bad_version, force=True) + + espidf_mocks.extract.assert_called_once() + + +# --------------------------------------------------------------------------- +# _patch_tools_json_for_linux_arm64 (arm64-only ninja backport) +# --------------------------------------------------------------------------- + + +def _write_tools_json(framework_path: Path, data: dict) -> Path: + tools_dir = framework_path / "tools" + tools_dir.mkdir(parents=True, exist_ok=True) + tools_json = tools_dir / "tools.json" + tools_json.write_text(json.dumps(data), encoding="utf-8") + return tools_json + + +def test_patch_tools_json_non_aarch64_is_noop(tmp_path: Path) -> None: + tools_json = _write_tools_json( + tmp_path, {"tools": [{"name": "ninja", "versions": [{"name": "1.12.1"}]}]} + ) + before = tools_json.read_text(encoding="utf-8") + with patch("esphome.espidf.framework.platform.machine", return_value="x86_64"): + _patch_tools_json_for_linux_arm64(tmp_path) + assert tools_json.read_text(encoding="utf-8") == before + + +def test_patch_tools_json_missing_file_is_noop(tmp_path: Path) -> None: + with patch("esphome.espidf.framework.platform.machine", return_value="aarch64"): + _patch_tools_json_for_linux_arm64(tmp_path) # no tools/tools.json present + + +def test_patch_tools_json_corrupt_file_warns_and_skips(tmp_path: Path) -> None: + (tmp_path / "tools").mkdir() + (tmp_path / "tools" / "tools.json").write_text("{ not json", encoding="utf-8") + with patch("esphome.espidf.framework.platform.machine", return_value="aarch64"): + _patch_tools_json_for_linux_arm64(tmp_path) # JSONDecodeError -> skip + + +def test_patch_tools_json_injects_ninja_arm64(tmp_path: Path) -> None: + tools_json = _write_tools_json( + tmp_path, + { + "tools": [ + {"name": "ninja", "versions": [{"name": "1.12.1"}]}, + {"name": "cmake", "versions": [{"name": "3.24.0"}]}, + ] + }, + ) + with patch("esphome.espidf.framework.platform.machine", return_value="aarch64"): + _patch_tools_json_for_linux_arm64(tmp_path) + + data = json.loads(tools_json.read_text(encoding="utf-8")) + ninja = next(t for t in data["tools"] if t["name"] == "ninja") + assert "linux-arm64" in ninja["versions"][0] + assert ninja["versions"][0]["linux-arm64"]["size"] == 121787 + + +def test_patch_tools_json_already_patched_is_noop(tmp_path: Path) -> None: + tools_json = _write_tools_json( + tmp_path, + { + "tools": [ + { + "name": "ninja", + "versions": [{"name": "1.12.1", "linux-arm64": {"url": "x"}}], + } + ] + }, + ) + before = tools_json.read_text(encoding="utf-8") + with patch("esphome.espidf.framework.platform.machine", return_value="aarch64"): + _patch_tools_json_for_linux_arm64(tmp_path) + assert tools_json.read_text(encoding="utf-8") == before + + +# --------------------------------------------------------------------------- +# Subprocess-backed helpers (_exec -> run_command rename) and get_framework_env +# --------------------------------------------------------------------------- + + +def test_get_idf_version_parses_stdout(tmp_path: Path) -> None: + with patch( + "esphome.espidf.framework.run_command", return_value=(True, "5.1.2\n", "") + ): + assert _get_idf_version(tmp_path) == "5.1.2" + + +def test_get_idf_version_raises_on_failure(tmp_path: Path) -> None: + with ( + patch("esphome.espidf.framework.run_command", return_value=(False, "", "boom")), + pytest.raises(RuntimeError, match="Can't get ESP-IDF version"), + ): + _get_idf_version(tmp_path) + + +def test_get_idf_tool_paths_parses_json(tmp_path: Path) -> None: + payload = json.dumps({"paths_to_export": ["/a", "/b"], "export_vars": {"X": "1"}}) + with patch( + "esphome.espidf.framework.run_command", return_value=(True, payload, "") + ): + paths, export_vars = _get_idf_tool_paths(tmp_path) + assert paths == ["/a", "/b"] + assert export_vars == {"X": "1"} + + +def test_get_idf_tool_paths_raises_on_bad_json(tmp_path: Path) -> None: + with ( + patch( + "esphome.espidf.framework.run_command", return_value=(True, "not json", "") + ), + pytest.raises(RuntimeError, match="Can't extract ESP-IDF tool paths"), + ): + _get_idf_tool_paths(tmp_path) + + +def test_get_idf_tool_paths_raises_on_failure(tmp_path: Path) -> None: + with ( + patch("esphome.espidf.framework.run_command", return_value=(False, "", "err")), + pytest.raises(RuntimeError, match="Can't get ESP-IDF tool paths"), + ): + _get_idf_tool_paths(tmp_path) + + +def test_get_python_version_parses_stdout(tmp_path: Path) -> None: + with patch( + "esphome.espidf.framework.run_command", return_value=(True, "3.11.0\n", "") + ): + assert _get_python_version(tmp_path / "python") == "3.11.0" + + +def test_get_python_version_returns_falsy_on_failure(tmp_path: Path) -> None: + with patch("esphome.espidf.framework.run_command", return_value=(False, "", "")): + # non-throwing failure returns the (empty) stdout as-is + assert not _get_python_version(tmp_path / "python") + + +def test_get_python_version_raises_when_requested(tmp_path: Path) -> None: + with ( + patch("esphome.espidf.framework.run_command", return_value=(False, "", "")), + pytest.raises(RuntimeError, match="Can't get Python version"), + ): + _get_python_version(tmp_path / "python", throw_exception=True) + + +def test_write_stamp_writes_json(tmp_path: Path) -> None: + stamp = tmp_path / "stamp.json" + _write_stamp(stamp, {"a": "1", "b": "2"}) + assert json.loads(stamp.read_text(encoding="utf-8")) == {"a": "1", "b": "2"} + + +def test_get_framework_env_with_python_env(tmp_path: Path) -> None: + with ( + patch( + "esphome.espidf.framework._get_idf_tools_path", + return_value=tmp_path / "tools", + ), + patch("esphome.espidf.framework._get_idf_version", return_value="5.1.2"), + patch( + "esphome.espidf.framework._get_idf_tool_paths", + return_value=(["/tool/bin"], {"IDF_X": "1"}), + ), + ): + env = get_framework_env( + tmp_path / "fw", tmp_path / "penv", {"PATH": "/usr/bin"} + ) + + assert env["IDF_PATH"] == str(tmp_path / "fw") + assert env["ESP_IDF_VERSION"] == "5.1.2" + assert env["IDF_X"] == "1" + assert env["IDF_PYTHON_ENV_PATH"] == str(tmp_path / "penv") + assert "/tool/bin" in env["PATH"] + + +def test_get_framework_env_without_python_env_uses_os_path(tmp_path: Path) -> None: + with ( + patch( + "esphome.espidf.framework._get_idf_tools_path", + return_value=tmp_path / "tools", + ), + patch("esphome.espidf.framework._get_idf_version", return_value="5.1.2"), + patch("esphome.espidf.framework._get_idf_tool_paths", return_value=([], {})), + ): + env = get_framework_env(tmp_path / "fw") + + assert "IDF_PYTHON_ENV_PATH" not in env + assert env["PATH"] # taken from os.environ + + +# --------------------------------------------------------------------------- +# _check_stamp / _write_idf_version_txt / _get_idf_tools_path +# --------------------------------------------------------------------------- + + +def test_check_stamp_matches(tmp_path: Path) -> None: + f = tmp_path / "s.json" + f.write_text(json.dumps({"a": "1"}), encoding="utf-8") + assert _check_stamp(f, {"a": "1"}) is True + + +def test_check_stamp_mismatch(tmp_path: Path) -> None: + f = tmp_path / "s.json" + f.write_text(json.dumps({"a": "1"}), encoding="utf-8") + assert _check_stamp(f, {"a": "2"}) is False + + +def test_check_stamp_missing_file(tmp_path: Path) -> None: + assert _check_stamp(tmp_path / "nope.json", {"a": "1"}) is False + + +def test_check_stamp_corrupt_file(tmp_path: Path) -> None: + f = tmp_path / "s.json" + f.write_text("{ not json", encoding="utf-8") + assert _check_stamp(f, {"a": "1"}) is False + + +def test_write_idf_version_txt_writes_when_missing(tmp_path: Path) -> None: + _write_idf_version_txt(tmp_path, "5.1.2") + assert (tmp_path / "version.txt").read_text(encoding="utf-8") == "v5.1.2\n" + + +def test_write_idf_version_txt_skips_when_present(tmp_path: Path) -> None: + (tmp_path / "version.txt").write_text("existing\n", encoding="utf-8") + _write_idf_version_txt(tmp_path, "5.1.2") + assert (tmp_path / "version.txt").read_text(encoding="utf-8") == "existing\n" + + +def test_get_idf_tools_path_env_override(tmp_path: Path) -> None: + override = str(tmp_path / "custom-idf") + with patch.dict("os.environ", {"ESPHOME_ESP_IDF_PREFIX": override}): + assert _get_idf_tools_path() == Path(override) + + +def test_write_idf_version_txt_warns_on_write_error(tmp_path: Path) -> None: + with patch("pathlib.Path.write_text", side_effect=OSError("denied")): + # write failure is caught and warned, not raised + _write_idf_version_txt(tmp_path, "5.1.2") diff --git a/tests/unit_tests/test_framework_helpers.py b/tests/unit_tests/test_framework_helpers.py new file mode 100644 index 0000000000..a8533608c0 --- /dev/null +++ b/tests/unit_tests/test_framework_helpers.py @@ -0,0 +1,954 @@ +"""Tests for esphome.framework_helpers.""" + +# pylint: disable=protected-access + +import importlib.util +import io +import logging +import os +from pathlib import Path +import subprocess +import sys +import tarfile +from unittest.mock import MagicMock, Mock, patch +import zipfile + +import pytest +import requests as req + +from esphome.framework_helpers import ( + _7z_extract_all, + _detect_archive_root, + _rename_with_retry, + _tar_extract_all, + _zip_extract_all, + archive_extract_all, + create_venv, + download_from_mirrors, + get_python_env_executable_path, + get_system_python_path, + rmdir, + run_command, + run_command_ok, + str_to_lst_of_str, +) + +_HAS_PY7ZR = importlib.util.find_spec("py7zr") is not None + +# --------------------------------------------------------------------------- +# str_to_lst_of_str +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + ("a;b;c", ["a", "b", "c"]), + (" a ; b ", ["a", "b"]), + (";; a ;;", ["a"]), + ("single", ["single"]), + ("", []), + (["already", "a", "list"], ["already", "a", "list"]), + ], +) +def test_str_to_lst_of_str(value: str | list, expected: list) -> None: + assert str_to_lst_of_str(value) == expected + + +# --------------------------------------------------------------------------- +# rmdir +# --------------------------------------------------------------------------- + + +def test_rmdir_nonexistent_is_noop(tmp_path: Path) -> None: + rmdir(tmp_path / "missing") + + +def test_rmdir_removes_existing_directory(tmp_path: Path) -> None: + d = tmp_path / "to_remove" + d.mkdir() + (d / "file.txt").write_text("x") + rmdir(d) + assert not d.exists() + + +def test_rmdir_logs_debug_with_msg( + tmp_path: Path, caplog: pytest.LogCaptureFixture +) -> None: + d = tmp_path / "logged" + d.mkdir() + with caplog.at_level(logging.DEBUG, logger="esphome.framework_helpers"): + rmdir(d, msg="cleanup message") + assert "cleanup message" in caplog.text + + +def test_rmdir_raises_runtime_error_on_os_error(tmp_path: Path) -> None: + d = tmp_path / "stubborn" + d.mkdir() + with ( + patch("esphome.framework_helpers.rmtree", side_effect=OSError("perm denied")), + pytest.raises(RuntimeError, match="can't remove"), + ): + rmdir(d, msg="cleanup step") + + +# --------------------------------------------------------------------------- +# get_system_python_path +# --------------------------------------------------------------------------- + + +def test_get_system_python_path_returns_env_var() -> None: + with patch.dict(os.environ, {"PYTHONEXEPATH": "/custom/python"}): + assert get_system_python_path() == "/custom/python" + + +def test_get_system_python_path_falls_back_to_sys_executable() -> None: + env = {k: v for k, v in os.environ.items() if k != "PYTHONEXEPATH"} + with patch.dict(os.environ, env, clear=True): + assert get_system_python_path() == os.path.normpath(sys.executable) + + +# --------------------------------------------------------------------------- +# get_python_env_executable_path +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(os.name != "posix", reason="PosixPath construction requires POSIX") +def test_get_python_env_executable_path_posix() -> None: + assert get_python_env_executable_path("/env", "python") == Path("/env/bin/python") + + +@pytest.mark.skipif(os.name != "nt", reason="WindowsPath construction requires Windows") +def test_get_python_env_executable_path_windows() -> None: + assert get_python_env_executable_path("/env", "python") == Path( + "/env/Scripts/python.exe" + ) + + +# --------------------------------------------------------------------------- +# run_command +# --------------------------------------------------------------------------- + + +def test_run_command_success_returns_stdout(mock_subprocess_run: Mock) -> None: + mock_subprocess_run.return_value = Mock(returncode=0, stdout="out\n", stderr="") + ok, stdout, _stderr = run_command(["echo", "hello"]) + assert ok is True + assert stdout == "out\n" + + +def test_run_command_failure_returns_false(mock_subprocess_run: Mock) -> None: + mock_subprocess_run.return_value = Mock(returncode=1, stdout="", stderr="boom") + ok, _stdout, stderr = run_command(["bad"]) + assert ok is False + assert stderr == "boom" + + +def test_run_command_stream_output_success(mock_subprocess_run: Mock) -> None: + mock_subprocess_run.return_value = Mock(returncode=0) + ok, stdout, stderr = run_command(["cmd"], stream_output=True) + assert ok is True + assert stdout is None + assert stderr is None + + +def test_run_command_stream_output_failure(mock_subprocess_run: Mock) -> None: + mock_subprocess_run.return_value = Mock(returncode=2) + ok, stdout, _stderr = run_command(["cmd"], stream_output=True) + assert ok is False + assert stdout is None + + +def test_run_command_subprocess_error_returns_false(mock_subprocess_run: Mock) -> None: + mock_subprocess_run.side_effect = subprocess.SubprocessError("exploded") + ok, stdout, stderr = run_command(["cmd"]) + assert ok is False + assert stdout is None + assert stderr is None + + +def test_run_command_os_error_returns_false(mock_subprocess_run: Mock) -> None: + mock_subprocess_run.side_effect = OSError("not found") + ok, _stdout, _stderr = run_command(["cmd"]) + assert ok is False + + +def test_run_command_passes_env(mock_subprocess_run: Mock) -> None: + mock_subprocess_run.return_value = Mock(returncode=0, stdout="", stderr="") + run_command(["cmd"], env={"MY_VAR": "42"}) + assert mock_subprocess_run.call_args[1]["env"]["MY_VAR"] == "42" + + +def test_run_command_passes_cwd(mock_subprocess_run: Mock, tmp_path: Path) -> None: + mock_subprocess_run.return_value = Mock(returncode=0, stdout="", stderr="") + run_command(["cmd"], cwd=str(tmp_path)) + assert mock_subprocess_run.call_args[1]["cwd"] == str(tmp_path) + + +# --------------------------------------------------------------------------- +# run_command_ok +# --------------------------------------------------------------------------- + + +def test_run_command_ok_true(mock_subprocess_run: Mock) -> None: + mock_subprocess_run.return_value = Mock(returncode=0, stdout="", stderr="") + assert run_command_ok(["cmd"]) is True + + +def test_run_command_ok_false(mock_subprocess_run: Mock) -> None: + mock_subprocess_run.return_value = Mock(returncode=1, stdout="", stderr="") + assert run_command_ok(["cmd"]) is False + + +# --------------------------------------------------------------------------- +# create_venv +# --------------------------------------------------------------------------- + + +def test_create_venv_calls_run_command_ok(tmp_path: Path) -> None: + with patch( + "esphome.framework_helpers.run_command_ok", return_value=True + ) as mock_cmd: + create_venv(tmp_path / "env", msg="test") + mock_cmd.assert_called_once() + + +def test_create_venv_raises_on_failure(tmp_path: Path) -> None: + with ( + patch("esphome.framework_helpers.run_command_ok", return_value=False), + pytest.raises(RuntimeError, match="Can't create Python virtual environment"), + ): + create_venv(tmp_path / "env", msg="test") + + +# --------------------------------------------------------------------------- +# _detect_archive_root +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("names", "expected"), + [ + (["wrapper/", "wrapper/a.txt", "wrapper/sub/b.txt"], "wrapper"), + (["root1/a.txt", "root2/b.txt"], None), + (["wrapper"], None), # no descendant → None + (["", "wrapper/file.txt"], "wrapper"), # empty names skipped + (["wrapper\\file.txt"], "wrapper"), # backslash normalised + (["w/a", "w/b", "w/c"], "w"), + ], +) +def test_detect_archive_root(names: list[str], expected: str | None) -> None: + assert _detect_archive_root(names) == expected + + +# --------------------------------------------------------------------------- +# Tar archive helpers +# --------------------------------------------------------------------------- + + +def _make_tar( + members: list[tarfile.TarInfo], + file_contents: dict[str, bytes] | None = None, +) -> io.BytesIO: + buf = io.BytesIO() + contents = file_contents or {} + with tarfile.open(fileobj=buf, mode="w") as tf: + for info in members: + if info.isreg() and info.name in contents: + data = contents[info.name] + info.size = len(data) + tf.addfile(info, io.BytesIO(data)) + else: + tf.addfile(info) + buf.seek(0) + return buf + + +def _reg(name: str) -> tarfile.TarInfo: + info = tarfile.TarInfo(name=name) + info.type = tarfile.REGTYPE + info.size = 0 + info.mode = 0o644 + return info + + +def _dir(name: str) -> tarfile.TarInfo: + info = tarfile.TarInfo(name=name) + info.type = tarfile.DIRTYPE + info.mode = 0o755 + return info + + +def _sym(name: str, target: str) -> tarfile.TarInfo: + info = tarfile.TarInfo(name=name) + info.type = tarfile.SYMTYPE + info.linkname = target + info.mode = 0o777 + return info + + +def _special(name: str) -> tarfile.TarInfo: + info = tarfile.TarInfo(name=name) + info.type = tarfile.CHRTYPE + info.mode = 0o600 + return info + + +def _hlnk(name: str, target: str) -> tarfile.TarInfo: + info = tarfile.TarInfo(name=name) + info.type = tarfile.LNKTYPE + info.linkname = target + info.mode = 0o644 + return info + + +# --------------------------------------------------------------------------- +# _tar_extract_all — branches not covered by the hard-link prefix-strip tests +# --------------------------------------------------------------------------- + + +class TestTarExtractAllSecurity: + def test_flat_archive_no_wrapper(self, tmp_path: Path) -> None: + """Without a single common root files land directly in extract_dir.""" + buf = _make_tar( + [_reg("a.txt"), _reg("b.txt")], + {"a.txt": b"aaa", "b.txt": b"bbb"}, + ) + _tar_extract_all(buf, tmp_path) + assert (tmp_path / "a.txt").read_bytes() == b"aaa" + assert (tmp_path / "b.txt").read_bytes() == b"bbb" + + def test_directory_member_extracted(self, tmp_path: Path) -> None: + buf = _make_tar([_dir("subdir/")]) + _tar_extract_all(buf, tmp_path) + assert (tmp_path / "subdir").is_dir() + + def test_symlink_within_dest_extracted(self, tmp_path: Path) -> None: + buf = _make_tar( + [_reg("target.txt"), _sym("link.txt", "target.txt")], + {"target.txt": b"data"}, + ) + _tar_extract_all(buf, tmp_path) + assert (tmp_path / "link.txt").exists() + + def test_path_traversal_skipped(self, tmp_path: Path) -> None: + """Member resolving outside extract_dir via .. is silently skipped.""" + info = tarfile.TarInfo(name="sub/../../escape.txt") + info.type = tarfile.REGTYPE + info.size = 5 + info.mode = 0o644 + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tf: + tf.addfile(info, io.BytesIO(b"OOPS!")) + buf.seek(0) + _tar_extract_all(buf, tmp_path) + assert not (tmp_path.parent / "escape.txt").exists() + assert not list(tmp_path.rglob("escape.txt")) + + def test_absolute_symlink_target_skipped(self, tmp_path: Path) -> None: + """Symlink pointing to an absolute path is silently skipped.""" + buf = _make_tar( + [_reg("real.txt"), _sym("danger.lnk", "/etc/passwd")], + {"real.txt": b"ok"}, + ) + _tar_extract_all(buf, tmp_path) + assert not (tmp_path / "danger.lnk").exists() + + def test_symlink_escaping_dest_skipped(self, tmp_path: Path) -> None: + """Symlink whose resolved path exits extract_dir is silently skipped.""" + buf = _make_tar([_sym("up.lnk", "../outside.txt")]) + _tar_extract_all(buf, tmp_path) + assert not (tmp_path / "up.lnk").exists() + + def test_special_file_skipped(self, tmp_path: Path) -> None: + """Character-device and other special-file members are silently skipped.""" + buf = _make_tar([_special("chardev")]) + _tar_extract_all(buf, tmp_path) + assert not (tmp_path / "chardev").exists() + + @pytest.mark.skipif( + os.name == "nt", reason="Windows has no POSIX executable permission bit" + ) + def test_executable_bit_preserved(self, tmp_path: Path) -> None: + """User-executable bit is kept for explicitly executable files.""" + info = _reg("script.sh") + info.mode = 0o755 + buf = _make_tar([info], {"script.sh": b"#!/bin/sh"}) + _tar_extract_all(buf, tmp_path) + assert (tmp_path / "script.sh").stat().st_mode & 0o100 # S_IXUSR + + def test_non_executable_exec_bits_stripped(self, tmp_path: Path) -> None: + """Exec bits are removed when S_IXUSR is not set.""" + info = _reg("data.bin") + info.mode = 0o654 # group/other exec present, user exec absent + buf = _make_tar([info], {"data.bin": b"\x00"}) + _tar_extract_all(buf, tmp_path) + mode = (tmp_path / "data.bin").stat().st_mode + assert not (mode & 0o111) # all exec bits cleared + + +# --------------------------------------------------------------------------- +# ZIP archive helper +# --------------------------------------------------------------------------- + + +def _make_zip(entries: list[tuple[str, str | bytes]]) -> io.BytesIO: + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + for name, content in entries: + zf.writestr(name, content) + buf.seek(0) + return buf + + +# --------------------------------------------------------------------------- +# _zip_extract_all +# --------------------------------------------------------------------------- + + +class TestZipExtractAll: + def test_basic_extraction_strips_wrapper(self, tmp_path: Path) -> None: + buf = _make_zip([("wrapper/file.txt", "hello")]) + _zip_extract_all(buf, tmp_path) + assert (tmp_path / "file.txt").read_text() == "hello" + + def test_flat_archive_no_wrapper(self, tmp_path: Path) -> None: + buf = _make_zip([("a.txt", "aaa"), ("b.txt", "bbb")]) + _zip_extract_all(buf, tmp_path) + assert (tmp_path / "a.txt").read_text() == "aaa" + assert (tmp_path / "b.txt").read_text() == "bbb" + + def test_wrapper_root_entry_skipped(self, tmp_path: Path) -> None: + """The wrapper directory entry itself (step 3a) does not appear in dest.""" + buf = _make_zip([("wrapper/", ""), ("wrapper/file.txt", "content")]) + _zip_extract_all(buf, tmp_path) + assert (tmp_path / "file.txt").read_text() == "content" + assert not (tmp_path / "wrapper").exists() + + def test_path_traversal_raises(self, tmp_path: Path) -> None: + # Two members with different roots so _detect_archive_root returns None + # and strip_prefix is not applied, leaving "../escape.txt" to hit the + # commonpath safety check directly. + buf = _make_zip([("safe.txt", "ok"), ("../escape.txt", "bad")]) + with pytest.raises(ValueError, match="Unsafe path"): + _zip_extract_all(buf, tmp_path) + + def test_multiple_files_extracted(self, tmp_path: Path) -> None: + entries = [(f"root/{c}.txt", c * 3) for c in "abc"] + buf = _make_zip(entries) + _zip_extract_all(buf, tmp_path) + for c in "abc": + assert (tmp_path / f"{c}.txt").read_text() == c * 3 + + +# --------------------------------------------------------------------------- +# archive_extract_all dispatch +# --------------------------------------------------------------------------- + + +def _gzip_tar_bytes(entries: dict[str, bytes]) -> bytes: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w:gz") as tf: + for name, content in entries.items(): + info = tarfile.TarInfo(name=name) + info.size = len(content) + info.mode = 0o644 + tf.addfile(info, io.BytesIO(content)) + return buf.getvalue() + + +class TestArchiveExtractAll: + def test_path_input_gzip_tar(self, tmp_path: Path) -> None: + archive = tmp_path / "test.tar.gz" + archive.write_bytes(_gzip_tar_bytes({"file.txt": b"hello"})) + dest = tmp_path / "out" + dest.mkdir() + archive_extract_all(archive, dest) + assert (dest / "file.txt").read_bytes() == b"hello" + + def test_buffered_reader_input(self, tmp_path: Path) -> None: + archive = tmp_path / "test.tar.gz" + archive.write_bytes(_gzip_tar_bytes({"file.txt": b"data"})) + dest = tmp_path / "out" + dest.mkdir() + with archive.open("rb") as f: # io.BufferedReader + archive_extract_all(f, dest) + assert (dest / "file.txt").read_bytes() == b"data" + + def test_rawio_input(self, tmp_path: Path) -> None: + archive = tmp_path / "test.tar.gz" + archive.write_bytes(_gzip_tar_bytes({"file.txt": b"raw"})) + dest = tmp_path / "out" + dest.mkdir() + archive_extract_all(io.FileIO(archive), dest) + assert (dest / "file.txt").read_bytes() == b"raw" + + def test_zip_dispatched(self, tmp_path: Path) -> None: + archive = tmp_path / "test.zip" + archive.write_bytes(_make_zip([("file.txt", "hi")]).getvalue()) + dest = tmp_path / "out" + dest.mkdir() + archive_extract_all(archive, dest) + assert (dest / "file.txt").read_text() == "hi" + + def test_invalid_type_raises_type_error(self) -> None: + with pytest.raises(TypeError, match="archive must be"): + archive_extract_all(42, ".") # type: ignore[arg-type] + + def test_unsupported_format_raises_value_error(self, tmp_path: Path) -> None: + bad = tmp_path / "bad.bin" + bad.write_bytes(b"\x00\x01\x02\x03\x04\x05\x06") + with pytest.raises(ValueError, match="Unsupported archive format"): + archive_extract_all(bad, tmp_path) + + +# --------------------------------------------------------------------------- +# download_from_mirrors +# --------------------------------------------------------------------------- + + +def _mock_response(content: bytes, ok: bool = True) -> MagicMock: + r = MagicMock() + r.__enter__.return_value = r + r.__exit__.return_value = False + if ok: + r.raise_for_status.return_value = None + else: + r.raise_for_status.side_effect = req.HTTPError("503") + r.headers = {"content-length": "0"} # suppress ProgressBar + r.iter_content.return_value = [content] if content else [] + return r + + +class TestDownloadFromMirrors: + def test_success_returns_url_and_writes_content(self, tmp_path: Path) -> None: + target = tmp_path / "out.bin" + with patch( + "esphome.framework_helpers.requests.get", + return_value=_mock_response(b"filedata"), + ): + url = download_from_mirrors(["https://example.com/f"], {}, target) + assert url == "https://example.com/f" + assert target.read_bytes() == b"filedata" + + def test_substitutions_applied_to_url(self, tmp_path: Path) -> None: + with patch( + "esphome.framework_helpers.requests.get", + return_value=_mock_response(b"x"), + ) as mock_get: + download_from_mirrors( + ["https://example.com/{VERSION}.bin"], + {"VERSION": "1.2.3"}, + tmp_path / "out.bin", + ) + assert mock_get.call_args[0][0] == "https://example.com/1.2.3.bin" + + def test_falls_back_to_second_mirror(self, tmp_path: Path) -> None: + with patch( + "esphome.framework_helpers.requests.get", + side_effect=[_mock_response(b"", ok=False), _mock_response(b"second")], + ): + url = download_from_mirrors( + ["https://mirror1.com/f", "https://mirror2.com/f"], + {}, + tmp_path / "out.bin", + ) + assert url == "https://mirror2.com/f" + assert (tmp_path / "out.bin").read_bytes() == b"second" + + def test_all_mirrors_fail_reraises_last_exception(self, tmp_path: Path) -> None: + with ( + patch( + "esphome.framework_helpers.requests.get", + return_value=_mock_response(b"", ok=False), + ), + pytest.raises(req.HTTPError), + ): + download_from_mirrors(["https://example.com/f"], {}, tmp_path / "out.bin") + + def test_empty_mirrors_raises_value_error(self, tmp_path: Path) -> None: + with pytest.raises(ValueError, match="empty mirrors list"): + download_from_mirrors([], {}, tmp_path / "out.bin") + + def test_invalid_target_type_raises_type_error(self) -> None: + with pytest.raises(TypeError, match="target must be"): + download_from_mirrors(["https://example.com/f"], {}, 42) # type: ignore[arg-type] + + def test_file_like_target_written(self) -> None: + buf = io.BytesIO() + with patch( + "esphome.framework_helpers.requests.get", + return_value=_mock_response(b"bytes"), + ): + download_from_mirrors(["https://example.com/f"], {}, buf) + buf.seek(0) + assert buf.read() == b"bytes" + + def test_progress_bar_shown_when_content_length_known(self, tmp_path: Path) -> None: + r = _mock_response(b"1234567890") + r.headers = {"content-length": "10"} + with ( + patch("esphome.framework_helpers.requests.get", return_value=r), + patch("esphome.framework_helpers.ProgressBar") as mock_pb, + ): + download_from_mirrors(["https://example.com/f"], {}, tmp_path / "out.bin") + mock_pb.assert_called_once_with("Downloading") + mock_pb.return_value.update.assert_called() + + def test_empty_chunk_not_written(self, tmp_path: Path) -> None: + """Empty chunks yielded by iter_content are skipped without writing.""" + r = MagicMock() + r.__enter__.return_value = r + r.__exit__.return_value = False + r.raise_for_status.return_value = None + r.headers = {"content-length": "0"} + r.iter_content.return_value = [b""] # one empty chunk + target = tmp_path / "out.bin" + with patch("esphome.framework_helpers.requests.get", return_value=r): + download_from_mirrors(["https://example.com/f"], {}, target) + assert target.exists() + assert target.read_bytes() == b"" + + +# --------------------------------------------------------------------------- +# get_python_env_executable_path — Windows branch +# --------------------------------------------------------------------------- + + +def test_get_python_env_executable_path_nt() -> None: + """Windows path uses Scripts/ and .exe suffix.""" + from pathlib import PurePosixPath + + with ( + patch.object(os, "name", "nt"), + patch("esphome.framework_helpers.Path", PurePosixPath), + ): + result = get_python_env_executable_path("/env", "python") + assert str(result) == "/env/Scripts/python.exe" + + +# --------------------------------------------------------------------------- +# _tar_extract_all — additional branch coverage +# --------------------------------------------------------------------------- + + +class TestTarExtractAllBranches: + @pytest.mark.skipif( + sys.version_info < (3, 12), + reason="patching os.name makes pathlib build a WindowsPath, which only " + "instantiates on POSIX in 3.12+", + ) + def test_windows_drive_path_skipped(self, tmp_path: Path) -> None: + """Windows-style drive path (C:/...) is skipped when os.name == 'nt'.""" + info = tarfile.TarInfo(name="C:/secret.txt") + info.type = tarfile.REGTYPE + info.size = 0 + info.mode = 0o644 + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tf: + tf.addfile(info) + buf.seek(0) + with patch.object(os, "name", "nt"): + _tar_extract_all(buf, tmp_path) + assert not list(tmp_path.rglob("*")) + + def test_strip_root_exact_match_skipped(self, tmp_path: Path) -> None: + """Member whose name equals strip_root exactly (no trailing slash) is skipped.""" + # "wrapper" (file entry) + "wrapper/file.txt" causes _detect_archive_root + # to return "wrapper"; the bare "wrapper" entry matches strip_root exactly. + buf = _make_tar( + [_reg("wrapper"), _reg("wrapper/file.txt")], + {"wrapper/file.txt": b"content"}, + ) + _tar_extract_all(buf, tmp_path) + assert not (tmp_path / "wrapper").exists() + assert (tmp_path / "file.txt").read_bytes() == b"content" + + def test_member_not_under_strip_prefix_skipped(self, tmp_path: Path) -> None: + """Member whose name doesn't start with strip_prefix is silently skipped.""" + buf = _make_tar([_reg("other/file.txt")], {"other/file.txt": b"data"}) + with patch("esphome.framework_helpers._detect_archive_root", return_value="w"): + _tar_extract_all(buf, tmp_path) + assert not list(tmp_path.rglob("*")) + + def test_hardlink_prefix_stripped(self, tmp_path: Path) -> None: + """Hard-link linkname has wrapper prefix stripped along with its entry name.""" + buf = _make_tar( + [_reg("wrapper/file.txt"), _hlnk("wrapper/link.txt", "wrapper/file.txt")], + {"wrapper/file.txt": b"data"}, + ) + _tar_extract_all(buf, tmp_path) + assert (tmp_path / "file.txt").read_bytes() == b"data" + assert (tmp_path / "link.txt").exists() + + def test_hardlink_linkname_equals_strip_root_skipped(self, tmp_path: Path) -> None: + """Hard link whose linkname equals strip_root is silently skipped.""" + buf = _make_tar( + [_reg("wrapper/file.txt"), _hlnk("wrapper/link.txt", "wrapper")], + {"wrapper/file.txt": b"data"}, + ) + _tar_extract_all(buf, tmp_path) + assert not (tmp_path / "link.txt").exists() + + def test_hardlink_linkname_outside_prefix_skipped(self, tmp_path: Path) -> None: + """Hard link whose linkname doesn't start with strip_prefix is skipped.""" + buf = _make_tar( + [_reg("wrapper/file.txt"), _hlnk("wrapper/link.txt", "other/file.txt")], + {"wrapper/file.txt": b"data"}, + ) + _tar_extract_all(buf, tmp_path) + assert not (tmp_path / "link.txt").exists() + + def test_member_mode_none_skips_sanitization(self, tmp_path: Path) -> None: + """Member with mode=None bypasses the sanitization block without error.""" + info = _reg("file.txt") + buf = _make_tar([info], {"file.txt": b"data"}) + buf.seek(0) + with tarfile.open(fileobj=buf) as tf: + members = tf.getmembers() + for m in members: + m.mode = None + buf.seek(0) + with ( + patch("tarfile.TarFile.getmembers", return_value=members), + patch("tarfile.TarFile.extract"), + ): + _tar_extract_all(buf, tmp_path) + + def test_progress_bar_shown(self, tmp_path: Path) -> None: + """A non-empty progress_header causes ProgressBar to be created and updated.""" + buf = _make_tar([_reg("file.txt")], {"file.txt": b"x"}) + with patch("esphome.framework_helpers.ProgressBar") as mock_pb: + _tar_extract_all(buf, tmp_path, progress_header="Extracting") + mock_pb.assert_called_once_with("Extracting") + mock_pb.return_value.update.assert_called() + + +# --------------------------------------------------------------------------- +# _zip_extract_all — additional branch coverage +# --------------------------------------------------------------------------- + + +class TestZipExtractAllBranches: + @pytest.mark.skipif( + sys.version_info < (3, 12), + reason="patching os.name makes pathlib build a WindowsPath, which only " + "instantiates on POSIX in 3.12+", + ) + def test_windows_drive_path_skipped(self, tmp_path: Path) -> None: + """Windows-style drive path (C:/...) is skipped when os.name == 'nt'.""" + buf = _make_zip([("C:/secret.txt", "bad")]) + with patch.object(os, "name", "nt"): + _zip_extract_all(buf, tmp_path) + assert not list(tmp_path.rglob("*")) + + def test_member_not_under_strip_prefix_skipped(self, tmp_path: Path) -> None: + """Member whose name doesn't start with strip_prefix is silently skipped.""" + buf = _make_zip([("other/file.txt", "data")]) + with patch("esphome.framework_helpers._detect_archive_root", return_value="w"): + _zip_extract_all(buf, tmp_path) + assert not list(tmp_path.rglob("*")) + + def test_progress_bar_shown(self, tmp_path: Path) -> None: + """A non-empty progress_header causes ProgressBar to be created and updated.""" + buf = _make_zip([("file.txt", "hello")]) + with patch("esphome.framework_helpers.ProgressBar") as mock_pb: + _zip_extract_all(buf, tmp_path, progress_header="Unzipping") + mock_pb.assert_called_once_with("Unzipping") + mock_pb.return_value.update.assert_called() + + +# --------------------------------------------------------------------------- +# _rename_with_retry +# --------------------------------------------------------------------------- + + +class TestRenameWithRetry: + def test_success_on_first_attempt(self, tmp_path: Path) -> None: + src = tmp_path / "src.txt" + src.write_text("data") + dst = tmp_path / "dst.txt" + _rename_with_retry(src, dst) + assert dst.read_text() == "data" + assert not src.exists() + + def test_retries_on_permission_error_then_succeeds(self, tmp_path: Path) -> None: + src = tmp_path / "src.txt" + src.write_text("data") + dst = tmp_path / "dst.txt" + call_count = 0 + original_rename = Path.rename + + def flaky_rename(self, target): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise PermissionError("locked") + return original_rename(self, target) + + with ( + patch.object(Path, "rename", flaky_rename), + patch("esphome.framework_helpers.time.sleep"), + ): + _rename_with_retry(src, dst, attempts=3) + assert dst.read_text() == "data" + + def test_raises_after_all_attempts_fail(self, tmp_path: Path) -> None: + src = tmp_path / "src.txt" + src.write_text("data") + dst = tmp_path / "dst.txt" + with ( + patch.object(Path, "rename", side_effect=PermissionError("locked")), + patch("esphome.framework_helpers.time.sleep"), + pytest.raises(PermissionError), + ): + _rename_with_retry(src, dst, attempts=3) + + def test_attempts_zero_is_noop(self, tmp_path: Path) -> None: + """Zero attempts means the for-loop body never runs; src is untouched.""" + src = tmp_path / "src.txt" + src.write_text("data") + dst = tmp_path / "dst.txt" + _rename_with_retry(src, dst, attempts=0) + assert src.exists() + assert not dst.exists() + + +# --------------------------------------------------------------------------- +# _7z_extract_all +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not _HAS_PY7ZR, reason="py7zr not installed") +class TestSevenZipExtractAll: + @staticmethod + def _make_7z(entries: dict[str, bytes]) -> io.BytesIO: + import py7zr + + buf = io.BytesIO() + with py7zr.SevenZipFile(buf, "w") as sz: + for name, content in entries.items(): + sz.writef(io.BytesIO(content), name) + buf.seek(0) + return buf + + def test_basic_extraction_no_wrapper(self, tmp_path: Path) -> None: + buf = self._make_7z({"a.txt": b"aaa", "b.txt": b"bbb"}) + out = tmp_path / "out" + out.mkdir() + _7z_extract_all(buf, out) + assert (out / "a.txt").exists() + assert (out / "b.txt").exists() + + def test_strips_wrapper_directory(self, tmp_path: Path) -> None: + buf = self._make_7z({"wrapper/file.txt": b"data"}) + out = tmp_path / "out" + out.mkdir() + _7z_extract_all(buf, out) + assert (out / "file.txt").exists() + assert not (out / "wrapper").exists() + + def test_staging_suffix_collision(self, tmp_path: Path) -> None: + """When .extract_tmp_0 already exists, suffix is incremented to find a free slot.""" + out = tmp_path / "out" + out.mkdir() + (out / ".extract_tmp_0").mkdir() + buf = self._make_7z({"file.txt": b"hi"}) + _7z_extract_all(buf, out) + assert (out / "file.txt").exists() + # .extract_tmp_1 should be cleaned up after extraction + assert not (out / ".extract_tmp_1").exists() + + def test_overwrites_existing_directory(self, tmp_path: Path) -> None: + """Pre-existing destination directory is replaced.""" + out = tmp_path / "out" + out.mkdir() + existing_dir = out / "file.txt" + existing_dir.mkdir() + buf = self._make_7z({"file.txt": b"new"}) + _7z_extract_all(buf, out) + assert (out / "file.txt").is_file() + + def test_overwrites_existing_file(self, tmp_path: Path) -> None: + """Pre-existing destination file is replaced.""" + out = tmp_path / "out" + out.mkdir() + (out / "file.txt").write_bytes(b"old") + buf = self._make_7z({"file.txt": b"new"}) + _7z_extract_all(buf, out) + assert (out / "file.txt").exists() + + def test_empty_name_skipped(self, tmp_path: Path) -> None: + """Archive entries with empty names are silently skipped.""" + import py7zr + + buf = self._make_7z({"file.txt": b"data"}) + out = tmp_path / "out" + out.mkdir() + with patch.object( + py7zr.SevenZipFile, "getnames", return_value=["", "file.txt"] + ): + _7z_extract_all(buf, out) + assert (out / "file.txt").exists() + + def test_path_traversal_skipped(self, tmp_path: Path) -> None: + """Entries whose resolved path exits extract_dir are skipped.""" + import py7zr + + buf = self._make_7z({"file.txt": b"safe"}) + out = tmp_path / "out" + out.mkdir() + with patch.object( + py7zr.SevenZipFile, "getnames", return_value=["../escape.txt", "file.txt"] + ): + _7z_extract_all(buf, out) + assert not (tmp_path / "escape.txt").exists() + assert (out / "file.txt").exists() + + def test_progress_bar_shown(self, tmp_path: Path) -> None: + buf = self._make_7z({"file.txt": b"x"}) + out = tmp_path / "out" + out.mkdir() + with patch("esphome.framework_helpers.ProgressBar") as mock_pb: + _7z_extract_all(buf, out, progress_header="Unpacking 7z") + mock_pb.assert_called_once_with("Unpacking 7z") + mock_pb.return_value.update.assert_called() + + def test_absolute_path_in_names_skipped(self, tmp_path: Path) -> None: + """Names that resolve as absolute are silently skipped.""" + import py7zr + + buf = self._make_7z({"file.txt": b"safe"}) + out = tmp_path / "out" + out.mkdir() + + original_is_absolute = Path.is_absolute + + def patched_is_absolute(self: Path) -> bool: + if str(self).startswith("C:"): + return True + return original_is_absolute(self) + + with ( + patch.object( + py7zr.SevenZipFile, "getnames", return_value=["C:/evil.txt", "file.txt"] + ), + patch.object(Path, "is_absolute", patched_is_absolute), + ): + _7z_extract_all(buf, out) + # Avoid `out / "C:"` here: pathlib treats "C:" as a drive (always + # "exists" on Windows). Assert on the actual extracted files instead. + extracted = sorted(p.name for p in out.rglob("*") if p.is_file()) + assert extracted == ["file.txt"] + + def test_dispatched_via_archive_extract_all(self, tmp_path: Path) -> None: + """archive_extract_all dispatches 7z archives to _7z_extract_all.""" + buf = self._make_7z({"hello.txt": b"world"}) + data = buf.read() + assert data[:6] == b"\x37\x7a\xbc\xaf\x27\x1c" + archive = tmp_path / "test.7z" + archive.write_bytes(data) + out = tmp_path / "out" + out.mkdir() + archive_extract_all(archive, out) + assert (out / "hello.txt").exists() diff --git a/tests/unit_tests/test_nrf52_framework.py b/tests/unit_tests/test_nrf52_framework.py new file mode 100644 index 0000000000..9652ad08eb --- /dev/null +++ b/tests/unit_tests/test_nrf52_framework.py @@ -0,0 +1,219 @@ +"""Tests for esphome.components.nrf52.framework helpers.""" + +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from esphome.components.nrf52.framework import ( + _TOOLCHAIN_VERSION, + _get_toolchain_platform_info, + check_and_install, +) +from esphome.config_validation import Version +from esphome.const import KEY_CORE, KEY_FRAMEWORK_VERSION +from esphome.core import CORE, EsphomeError + + +@pytest.mark.parametrize( + ("system", "machine", "expected"), + [ + # default — no branch hit + ("Linux", "x86_64", ("linux", "x86_64", "tar.xz")), + # arm64 → aarch64 rename + ("Linux", "arm64", ("linux", "aarch64", "tar.xz")), + # darwin → macos rename only + ("Darwin", "x86_64", ("macos", "x86_64", "tar.xz")), + # both renames apply + ("Darwin", "arm64", ("macos", "aarch64", "tar.xz")), + # windows forces x86_64 + 7z; arm64 rename is overwritten + ("Windows", "arm64", ("windows", "x86_64", "7z")), + ], +) +def test_get_toolchain_platform_info( + system: str, machine: str, expected: tuple[str, str, str] +) -> None: + with ( + patch("platform.system", return_value=system), + patch("platform.machine", return_value=machine), + ): + assert _get_toolchain_platform_info() == expected + + +# --------------------------------------------------------------------------- +# Helpers and fixtures for check_and_install tests +# --------------------------------------------------------------------------- + +_TEST_SDK_VERSION = "2.9.0" + + +@pytest.fixture +def nrf52_dirs(setup_core: Path) -> SimpleNamespace: + """Populate CORE and pre-create SDK directories so sentinel.touch() succeeds.""" + CORE.data[KEY_CORE] = {KEY_FRAMEWORK_VERSION: Version.parse(_TEST_SDK_VERSION)} + tools = CORE.data_dir / "sdk-nrf" + python_env = tools / "penvs" / f"v{_TEST_SDK_VERSION}" + framework = tools / "frameworks" / f"v{_TEST_SDK_VERSION}" + toolchain_dir = tools / "toolchains" / _TOOLCHAIN_VERSION + for d in (python_env, framework, toolchain_dir): + d.mkdir(parents=True, exist_ok=True) + return SimpleNamespace( + python_env=python_env, + framework=framework, + toolchain=toolchain_dir, + ) + + +@pytest.fixture +def mock_nrf52_ops(): + """Patch all heavy I/O operations used by check_and_install.""" + with ( + patch("esphome.components.nrf52.framework.rmdir") as mock_rmdir, + patch("esphome.components.nrf52.framework.create_venv") as mock_create_venv, + patch( + "esphome.components.nrf52.framework.run_command_ok", return_value=True + ) as mock_run_cmd, + patch( + "esphome.components.nrf52.framework.download_from_mirrors", + return_value="https://example.com/tc.tar.xz", + ) as mock_download, + patch("esphome.components.nrf52.framework.archive_extract_all") as mock_extract, + ): + yield SimpleNamespace( + rmdir=mock_rmdir, + create_venv=mock_create_venv, + run_command_ok=mock_run_cmd, + download_from_mirrors=mock_download, + archive_extract_all=mock_extract, + ) + + +# --------------------------------------------------------------------------- +# check_and_install tests +# --------------------------------------------------------------------------- + + +class TestCheckAndInstall: + def test_all_installed_skips_all_steps( + self, + nrf52_dirs: SimpleNamespace, + mock_nrf52_ops: SimpleNamespace, + ) -> None: + """All three sentinels present → nothing downloaded or compiled.""" + (nrf52_dirs.python_env / ".ready").touch() + (nrf52_dirs.framework / ".ready").touch() + (nrf52_dirs.toolchain / ".ready").touch() + + check_and_install() + + mock_nrf52_ops.create_venv.assert_not_called() + mock_nrf52_ops.run_command_ok.assert_not_called() + mock_nrf52_ops.download_from_mirrors.assert_not_called() + mock_nrf52_ops.archive_extract_all.assert_not_called() + + def test_fresh_install_runs_all_steps( + self, + nrf52_dirs: SimpleNamespace, + mock_nrf52_ops: SimpleNamespace, + ) -> None: + """No sentinels → venv created, west installed, SDK init+update, toolchain downloaded.""" + check_and_install() + + mock_nrf52_ops.create_venv.assert_called_once() + # pip install west, west init, west update + assert mock_nrf52_ops.run_command_ok.call_count == 3 + mock_nrf52_ops.download_from_mirrors.assert_called_once() + mock_nrf52_ops.archive_extract_all.assert_called_once() + assert (nrf52_dirs.python_env / ".ready").exists() + assert (nrf52_dirs.framework / ".ready").exists() + assert (nrf52_dirs.toolchain / ".ready").exists() + + def test_venv_exists_installs_framework_and_toolchain( + self, + nrf52_dirs: SimpleNamespace, + mock_nrf52_ops: SimpleNamespace, + ) -> None: + """Venv ready but framework missing → skip venv creation, run SDK init+update.""" + (nrf52_dirs.python_env / ".ready").touch() + + check_and_install() + + mock_nrf52_ops.create_venv.assert_not_called() + # west init + west update only (no pip install) + assert mock_nrf52_ops.run_command_ok.call_count == 2 + mock_nrf52_ops.download_from_mirrors.assert_called_once() + + def test_toolchain_only_missing( + self, + nrf52_dirs: SimpleNamespace, + mock_nrf52_ops: SimpleNamespace, + ) -> None: + """Venv and framework ready → only toolchain downloaded and extracted.""" + (nrf52_dirs.python_env / ".ready").touch() + (nrf52_dirs.framework / ".ready").touch() + + check_and_install() + + mock_nrf52_ops.create_venv.assert_not_called() + mock_nrf52_ops.run_command_ok.assert_not_called() + mock_nrf52_ops.download_from_mirrors.assert_called_once() + mock_nrf52_ops.archive_extract_all.assert_called_once() + + def test_west_install_failure_raises( + self, + nrf52_dirs: SimpleNamespace, + mock_nrf52_ops: SimpleNamespace, + ) -> None: + """Failing pip install west raises EsphomeError.""" + mock_nrf52_ops.run_command_ok.return_value = False + + with pytest.raises(EsphomeError, match="Install west"): + check_and_install() + + def test_framework_init_failure_raises( + self, + nrf52_dirs: SimpleNamespace, + mock_nrf52_ops: SimpleNamespace, + ) -> None: + """Failing west init raises EsphomeError.""" + (nrf52_dirs.python_env / ".ready").touch() + mock_nrf52_ops.run_command_ok.return_value = False + + with pytest.raises(EsphomeError, match="Can't initialize"): + check_and_install() + + def test_framework_update_failure_raises( + self, + nrf52_dirs: SimpleNamespace, + mock_nrf52_ops: SimpleNamespace, + ) -> None: + """Failing west update raises EsphomeError.""" + (nrf52_dirs.python_env / ".ready").touch() + # init succeeds, update fails + mock_nrf52_ops.run_command_ok.side_effect = [True, False] + + with pytest.raises(EsphomeError, match="Can't update"): + check_and_install() + + def test_toolchain_download_passes_platform_substitutions( + self, + nrf52_dirs: SimpleNamespace, + mock_nrf52_ops: SimpleNamespace, + ) -> None: + """download_from_mirrors receives VERSION + platform triple from _get_toolchain_platform_info.""" + (nrf52_dirs.python_env / ".ready").touch() + (nrf52_dirs.framework / ".ready").touch() + + with patch( + "esphome.components.nrf52.framework._get_toolchain_platform_info", + return_value=("linux", "x86_64", "tar.xz"), + ): + check_and_install() + + args, _ = mock_nrf52_ops.download_from_mirrors.call_args + substitutions = args[1] + assert substitutions["VERSION"] == _TOOLCHAIN_VERSION + assert substitutions["sysname"] == "linux" + assert substitutions["machine"] == "x86_64" + assert substitutions["extension"] == "tar.xz" diff --git a/tests/unit_tests/test_platformio_toolchain.py b/tests/unit_tests/test_platformio_toolchain.py index c1d16530cb..a37b19f584 100644 --- a/tests/unit_tests/test_platformio_toolchain.py +++ b/tests/unit_tests/test_platformio_toolchain.py @@ -442,6 +442,21 @@ def test_run_compile(setup_core: Path, mock_run_platformio_cli_run: Mock) -> Non mock_run_platformio_cli_run.assert_called_once_with(config, True, "-j4") +def test_run_compile_without_process_limit( + setup_core: Path, mock_run_platformio_cli_run: Mock +) -> None: + """When no compile_process_limit is set, run_compile passes no -j flag.""" + from esphome.const import CONF_ESPHOME + + CORE.build_path = str(setup_core / "build" / "test") + config = {CONF_ESPHOME: {}} + mock_run_platformio_cli_run.return_value = 0 + + toolchain.run_compile(config, verbose=False) + + mock_run_platformio_cli_run.assert_called_once_with(config, False) + + def test_get_idedata_caches_result( setup_core: Path, mock_run_platformio_cli_run: Mock ) -> None: