[nrf52] native build - download toolchain and sdk in venv (#16388)

Co-authored-by: Jonathan Swoboda <154711427+swoboda1337@users.noreply.github.com>
Co-authored-by: Jonathan Swoboda <swoboda1337@users.noreply.github.com>
This commit is contained in:
tomaszduda23
2026-06-09 13:04:51 +02:00
committed by GitHub
parent 25d656d468
commit 5faed9d5f5
12 changed files with 2624 additions and 565 deletions

View File

@@ -63,6 +63,7 @@ from .const import (
BOOTLOADER_ADAFRUIT_NRF52_SD140_V6, BOOTLOADER_ADAFRUIT_NRF52_SD140_V6,
BOOTLOADER_ADAFRUIT_NRF52_SD140_V7, BOOTLOADER_ADAFRUIT_NRF52_SD140_V7,
) )
from .framework import check_and_install
# force import gpio to register pin schema # force import gpio to register pin schema
from .gpio import nrf52_pin_to_code # noqa: F401 from .gpio import nrf52_pin_to_code # noqa: F401
@@ -562,3 +563,15 @@ def process_stacktrace(config: ConfigType, line: str, backtrace_state: bool) ->
_LOGGER.error("LR: %s", _addr2line(addr2line, elf, lr)) _LOGGER.error("LR: %s", _addr2line(addr2line, elf, lr))
return False return False
def run_compile(args, config: ConfigType) -> bool:
if CORE.using_toolchain_platformio:
return False
if not CORE.using_toolchain_sdk_nrf:
raise EsphomeError(
"Unsupported toolchain for nRF52. "
"Supported toolchains are 'platformio' and 'sdk-nrf'."
)
check_and_install()
raise EsphomeError("Native build for nRF52 is not implemented yet")

View File

@@ -0,0 +1,171 @@
import logging
import os
from pathlib import Path
import platform
import tempfile
from esphome.const import KEY_CORE, KEY_FRAMEWORK_VERSION
from esphome.core import CORE, EsphomeError
from esphome.framework_helpers import (
archive_extract_all,
create_venv,
download_from_mirrors,
get_python_env_executable_path,
rmdir,
run_command_ok,
str_to_lst_of_str,
)
_LOGGER = logging.getLogger(__name__)
_WEST_VERSION = "1.5.0"
_TOOLCHAIN_VERSION = "0.17.4"
SDK_NG_TOOLCHAIN_MIRRORS = str_to_lst_of_str(
os.environ.get(
"ESPHOME_SDK_NG_TOOLCHAIN_MIRRORS",
"https://github.com/zephyrproject-rtos/sdk-ng/releases/download/v{VERSION}/toolchain_{sysname}-{machine}_arm-zephyr-eabi.{extension}",
)
)
def _get_tools_path() -> Path:
return CORE.data_dir / "sdk-nrf"
def _get_python_env_path(version: str) -> Path:
return _get_tools_path() / "penvs" / version
def _get_framework_path(version: str) -> Path:
return _get_tools_path() / "frameworks" / f"{version}"
def _get_toolchain_path(version: str) -> Path:
return _get_tools_path() / "toolchains" / f"{version}"
# onexc/dir_fd were added to shutil.rmtree in 3.12; the 3.11 branch uses onerror.
_SITECUSTOMIZE = """\
import os, stat, shutil, sys
_orig = shutil.rmtree
def _handler(func, path, exc):
os.chmod(path, stat.S_IWRITE); func(path)
if sys.version_info >= (3, 12):
def _rmtree(path, ignore_errors=False, onerror=None, *, onexc=None, dir_fd=None):
if onerror is None and onexc is None:
onexc = _handler
return _orig(path, ignore_errors=ignore_errors, onerror=onerror, onexc=onexc, dir_fd=dir_fd)
else:
def _rmtree(path, ignore_errors=False, onerror=None):
if onerror is None:
onerror = _handler
return _orig(path, ignore_errors=ignore_errors, onerror=onerror)
shutil.rmtree = _rmtree
"""
def _install_sitecustomize(python_env_path: Path) -> None:
"""Patch shutil.rmtree inside the penv to handle read-only files.
west init's shutil.move falls back to copytree+rmtree on Windows, and
rmtree dies on the read-only .idx/.pack files git just wrote into
manifest-tmp. Dropping a sitecustomize.py into the venv applies the
same fix esphome.helpers.rmtree uses, but inside the subprocess.
"""
if os.name != "nt":
return
site_packages = python_env_path / "Lib" / "site-packages"
site_packages.mkdir(parents=True, exist_ok=True)
(site_packages / "sitecustomize.py").write_text(_SITECUSTOMIZE, encoding="utf-8")
def _get_toolchain_platform_info() -> tuple[str, str, str]:
"""Return (sysname, machine, extension) for the current host."""
extension = "tar.xz"
sysname = platform.system().lower()
machine = platform.machine()
if machine == "arm64":
machine = "aarch64"
if sysname == "darwin":
sysname = "macos"
elif sysname == "windows":
machine = "x86_64"
extension = "7z"
return sysname, machine, extension
def check_and_install() -> None:
framework_ver = CORE.data[KEY_CORE][KEY_FRAMEWORK_VERSION]
version = f"v{framework_ver.major}.{framework_ver.minor}.{framework_ver.patch}"
python_env_path = _get_python_env_path(version)
env_python_path = get_python_env_executable_path(python_env_path, "python")
sentinel = python_env_path / ".ready"
install_venv = not sentinel.exists()
if install_venv:
rmdir(python_env_path, msg=f"Clean up {version} Python environment")
create_venv(python_env_path, msg=f"{version}")
_install_sitecustomize(python_env_path)
_LOGGER.info("Installing west %s ...", _WEST_VERSION)
cmd = [str(env_python_path), "-m", "pip", "install", f"west=={_WEST_VERSION}"]
if not run_command_ok(cmd):
raise EsphomeError(f"Install west for {version} Python environment failure")
sentinel.touch()
framework_path = _get_framework_path(version)
sentinel = framework_path / ".ready"
if install_venv or not sentinel.exists():
rmdir(framework_path, msg=f"Clean up {version} framework environment")
_LOGGER.info("Initializing nRF Connect SDK %s ...", version)
cmd = [
str(env_python_path),
"-m",
"west",
"init",
"-m",
"https://github.com/nrfconnect/sdk-nrf",
"--mr",
f"{version}",
str(framework_path),
]
if not run_command_ok(cmd):
raise EsphomeError(f"Can't initialize nRF Connect SDK {version}")
_LOGGER.info("Updating nRF Connect SDK %s (this may take a while) ...", version)
cmd = [
str(env_python_path),
"-m",
"west",
"update",
"--narrow",
"--fetch-opt=--depth=1",
]
if not run_command_ok(cmd, cwd=framework_path):
raise EsphomeError(f"Can't update nRF Connect SDK {version}")
sentinel.touch()
toolchains_dir = _get_toolchain_path(_TOOLCHAIN_VERSION)
sentinel = toolchains_dir / ".ready"
if not sentinel.exists():
rmdir(
toolchains_dir, msg=f"Clean up {_TOOLCHAIN_VERSION} toolchain environment"
)
with tempfile.NamedTemporaryFile() as tmp:
_LOGGER.info("Downloading %s toolchain ...", _TOOLCHAIN_VERSION)
sysname, machine, extension = _get_toolchain_platform_info()
download_from_mirrors(
SDK_NG_TOOLCHAIN_MIRRORS,
{
"VERSION": _TOOLCHAIN_VERSION,
"sysname": sysname,
"machine": machine,
"extension": extension,
},
tmp.file,
)
archive_extract_all(tmp.file, toolchains_dir, progress_header="Extracting")
sentinel.touch()

View File

@@ -20,6 +20,7 @@ class Toolchain(StrEnum):
PLATFORMIO = "platformio" PLATFORMIO = "platformio"
ESP_IDF = "esp-idf" ESP_IDF = "esp-idf"
SDK_NRF = "sdk-nrf"
class Platform(StrEnum): class Platform(StrEnum):

View File

@@ -867,6 +867,10 @@ class EsphomeCore:
def using_toolchain_platformio(self): def using_toolchain_platformio(self):
return self.toolchain == Toolchain.PLATFORMIO return self.toolchain == Toolchain.PLATFORMIO
@property
def using_toolchain_sdk_nrf(self):
return self.toolchain == Toolchain.SDK_NRF
@property @property
def using_zephyr(self): def using_zephyr(self):
return self.target_framework == "zephyr" return self.target_framework == "zephyr"

View File

@@ -1,8 +1,5 @@
"""ESP-IDF framework tools for ESPHome.""" """ESP-IDF framework tools for ESPHome."""
from collections.abc import Iterable
from contextlib import ExitStack
import io
import json import json
import logging import logging
import os import os
@@ -10,39 +7,29 @@ from pathlib import Path
import platform import platform
import re import re
import shutil import shutil
import subprocess
import sys
import tempfile import tempfile
from typing import IO
import requests
from esphome.config_validation import Version from esphome.config_validation import Version
from esphome.core import CORE from esphome.core import CORE
from esphome.helpers import ProgressBar, get_str_env, rmtree, write_file_if_changed from esphome.framework_helpers import (
PathType,
PathType = str | os.PathLike archive_extract_all,
create_venv,
download_from_mirrors,
get_python_env_executable_path,
get_system_python_path,
rmdir,
run_command,
run_command_ok,
str_to_lst_of_str,
)
from esphome.helpers import get_str_env, write_file_if_changed
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_SCRIPTS_DIR = Path(__file__).parent _SCRIPTS_DIR = Path(__file__).parent
def _str_to_lst_of_str(a: str | list[str]) -> list[str]:
"""
Convert a string to a list of string
Args:
a: A string containing semicolon-separated values, or an already-split list
Returns:
list of strings
"""
if isinstance(a, list):
return a
return [f.strip() for f in a.split(";") if f.strip()]
ESPHOME_STAMP_FILE = ".esphome.stamp.json" ESPHOME_STAMP_FILE = ".esphome.stamp.json"
# Cache-buster baked into the stamp file. Bump this whenever a change would # Cache-buster baked into the stamp file. Bump this whenever a change would
@@ -54,23 +41,23 @@ ESPHOME_STAMP_FILE = ".esphome.stamp.json"
# Bumping triggers a full reinstall on every user's next run. # Bumping triggers a full reinstall on every user's next run.
STAMP_SCHEMA_VERSION = "0" STAMP_SCHEMA_VERSION = "0"
ESPHOME_IDF_DEFAULT_TARGETS = _str_to_lst_of_str( ESPHOME_IDF_DEFAULT_TARGETS = str_to_lst_of_str(
os.environ.get("ESPHOME_IDF_DEFAULT_TARGETS", "all") os.environ.get("ESPHOME_IDF_DEFAULT_TARGETS", "all")
) )
ESPHOME_IDF_DEFAULT_TOOLS = _str_to_lst_of_str( ESPHOME_IDF_DEFAULT_TOOLS = str_to_lst_of_str(
os.environ.get("ESPHOME_IDF_DEFAULT_TOOLS", "cmake;ninja") os.environ.get("ESPHOME_IDF_DEFAULT_TOOLS", "cmake;ninja")
) )
ESPHOME_IDF_DEFAULT_TOOLS_FORCE = _str_to_lst_of_str( ESPHOME_IDF_DEFAULT_TOOLS_FORCE = str_to_lst_of_str(
os.environ.get("ESPHOME_IDF_DEFAULT_TOOLS_FORCE", "required") os.environ.get("ESPHOME_IDF_DEFAULT_TOOLS_FORCE", "required")
) )
ESPHOME_IDF_DEFAULT_FEATURES = _str_to_lst_of_str( ESPHOME_IDF_DEFAULT_FEATURES = str_to_lst_of_str(
os.environ.get("ESPHOME_IDF_DEFAULT_FEATURES", "core") os.environ.get("ESPHOME_IDF_DEFAULT_FEATURES", "core")
) )
ESPHOME_IDF_FRAMEWORK_MIRRORS = _str_to_lst_of_str( ESPHOME_IDF_FRAMEWORK_MIRRORS = str_to_lst_of_str(
os.environ.get("ESPHOME_IDF_FRAMEWORK_MIRRORS") os.environ.get("ESPHOME_IDF_FRAMEWORK_MIRRORS")
or [ or [
"https://github.com/esphome-libs/esp-idf/releases/download/v{VERSION}/esp-idf-v{VERSION}.tar.xz", "https://github.com/esphome-libs/esp-idf/releases/download/v{VERSION}/esp-idf-v{VERSION}.tar.xz",
@@ -78,7 +65,7 @@ ESPHOME_IDF_FRAMEWORK_MIRRORS = _str_to_lst_of_str(
] ]
) )
ESP_IDF_CONSTRAINTS_MIRRORS = _str_to_lst_of_str( ESP_IDF_CONSTRAINTS_MIRRORS = str_to_lst_of_str(
os.environ.get( os.environ.get(
"ESP_IDF_CONSTRAINTS_MIRRORS", "ESP_IDF_CONSTRAINTS_MIRRORS",
"https://dl.espressif.com/dl/esp-idf/espidf.constraints.v{VERSION}.txt", "https://dl.espressif.com/dl/esp-idf/espidf.constraints.v{VERSION}.txt",
@@ -124,59 +111,6 @@ def _get_python_env_path(version: str) -> Path:
return _get_idf_tools_path() / "penvs" / f"{version}" return _get_idf_tools_path() / "penvs" / f"{version}"
def rmdir(directory: PathType, msg: str | None = None):
"""
Remove a directory and its contents recursively if it exists.
Args:
directory: Path to the directory to be removed
msg: Optional debug message to log before removal or it an error occurs
Returns:
None
Raises:
RuntimeError: If directory removal fails
"""
if Path(directory).is_dir():
try:
if msg:
_LOGGER.debug(msg)
rmtree(directory)
except OSError as e:
raise RuntimeError(
f"Error during {msg}: can't remove `{directory}`. Please remove it manually!"
) from e
def _get_pythonexe_path() -> str:
"""
Get the path to the Python executable.
Returns:
Path to Python executable as string
"""
# Try to get PYTHONEXEPATH environment variable
# Fallback to sys.executable if not set
return os.environ.get("PYTHONEXEPATH", os.path.normpath(sys.executable))
def _get_python_env_executable_path(root: PathType, binary: str) -> Path:
"""
Get the path to a Python environment executable file.
Args:
root: Root directory of the Python environment
binary: Name of the executable binary
Returns:
Path object pointing to the executable file
"""
if os.name == "nt":
return Path(root) / "Scripts" / f"{binary}.exe"
return Path(root) / "bin" / binary
def _check_stamp(file: PathType, data: dict[str, str]) -> bool: def _check_stamp(file: PathType, data: dict[str, str]) -> bool:
""" """
Check if a stamp file contains the expected data. Check if a stamp file contains the expected data.
@@ -210,84 +144,6 @@ def _write_stamp(file: PathType, data: dict[str, str]):
json.dump(data, fp) json.dump(data, fp)
def _exec(
cmd: list[str],
msg: str | None = None,
env: dict[str, str] | None = None,
stream_output: bool = False,
) -> tuple[bool, str | None, str | None]:
"""
Execute a command and return results.
Args:
cmd: list of command arguments
msg: Optional custom message for logging
env: Optional dictionary of environment variables to set
stream_output: If True, inherit parent stdio so the subprocess prints
directly to the terminal (useful for commands that produce their
own progress output). stdout/stderr are not captured in this mode.
Returns:
tuple of (success: bool, stdout: str or None, stderr: str or None).
When stream_output is True, stdout and stderr are always None.
"""
cmd_str = msg or " ".join(cmd)
try:
_LOGGER.debug("%s - running ...", cmd_str)
run_env = os.environ.copy()
if env:
run_env.update(env)
if stream_output:
result = subprocess.run(cmd, check=False, env=run_env)
stdout = stderr = None
else:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
check=False,
env=run_env,
)
stdout = result.stdout
stderr = result.stderr
if result.returncode != 0:
if stream_output:
_LOGGER.error("%s - failed (returncode=%s)", cmd_str, result.returncode)
else:
tail = (stderr or stdout or "").strip()[-1000:]
_LOGGER.error(
"%s - failed (returncode=%s). Tail:\n%s",
cmd_str,
result.returncode,
tail,
)
return False, stdout, stderr
_LOGGER.debug("%s - executed successfully", cmd_str)
return True, stdout, stderr
except (subprocess.SubprocessError, OSError) as e:
_LOGGER.error("%s - error: %s", cmd_str, str(e))
return False, None, None
def _exec_ok(*args, **kwargs) -> bool:
"""
Execute a command and return only the success status.
Args:
*args: Positional arguments to pass to _exec function
**kwargs: Keyword arguments to pass to _exec function
Returns:
True if command executed successfully, False otherwise
"""
return _exec(*args, **kwargs)[0]
def _get_idf_version( def _get_idf_version(
idf_framework_root: PathType, env: dict[str, str] | None = None idf_framework_root: PathType, env: dict[str, str] | None = None
) -> str: ) -> str:
@@ -306,12 +162,12 @@ def _get_idf_version(
""" """
cmd = [ cmd = [
_get_pythonexe_path(), get_system_python_path(),
str(_SCRIPTS_DIR / "get_idf_version.py"), str(_SCRIPTS_DIR / "get_idf_version.py"),
str(idf_framework_root), str(idf_framework_root),
] ]
success, stdout, stderr = _exec( success, stdout, stderr = run_command(
cmd, cmd,
msg="ESP-IDF version", msg="ESP-IDF version",
env=(env or os.environ) env=(env or os.environ)
@@ -346,12 +202,12 @@ def _get_idf_tool_paths(
""" """
cmd = [ cmd = [
_get_pythonexe_path(), get_system_python_path(),
str(_SCRIPTS_DIR / "get_idf_tool_paths.py"), str(_SCRIPTS_DIR / "get_idf_tool_paths.py"),
str(idf_framework_root), str(idf_framework_root),
] ]
success, stdout, stderr = _exec( success, stdout, stderr = run_command(
cmd, cmd,
msg="ESP-IDF tool paths", msg="ESP-IDF tool paths",
env=(env or os.environ) env=(env or os.environ)
@@ -397,7 +253,7 @@ print(".".join([str(x) for x in sys.version_info]))
""" """
cmd = [python_executable, "-c", script] cmd = [python_executable, "-c", script]
success, stdout, _ = _exec(cmd, msg="Python version", env=env) success, stdout, _ = run_command(cmd, msg="Python version", env=env)
if stdout: if stdout:
stdout = stdout.strip() stdout = stdout.strip()
@@ -406,393 +262,6 @@ print(".".join([str(x) for x in sys.version_info]))
return stdout return stdout
def _create_venv(root: PathType, msg: str | None = None):
"""
Create a Python virtual environment.
Args:
root: Path to the virtual environment directory
msg: Optional message for logging
Returns:
None
Raises:
Exception: If virtual environment creation fails
"""
cmd = [_get_pythonexe_path(), "-m", "venv", "--clear", root]
if not _exec_ok(cmd, msg=f"Create Python virtual environment for {msg}"):
raise RuntimeError(f"Can't create Python virtual environment for {msg}")
def _detect_archive_root(names: Iterable[str]) -> str | None:
"""Detect a single top-level directory shared by all archive entries.
Returns the directory name if every non-empty entry sits under the same
top-level directory, else ``None``. Extraction helpers use this to strip
the wrapper directory commonly found in source archives during extraction
rather than renaming it afterwards — post-extraction renames are
unreliable on Windows because antivirus and the search indexer briefly
hold handles on freshly written files.
"""
root: str | None = None
has_descendant = False
for raw in names:
name = raw.replace("\\", "/").strip("/")
if not name:
continue
first, sep, _ = name.partition("/")
if root is None:
root = first
elif root != first:
return None
if sep:
has_descendant = True
return root if has_descendant else None
def _tar_extract_all(
data: io.BufferedIOBase,
extract_dir: PathType = ".",
progress_header: str | None = None,
):
"""
Extract a TAR archive to the specified directory.
Implementation is inspired by Python 3.12's tarfile data filtering logic.
This can be replaced with the standard library implementation once
support for Python 3.11 is no longer required.
Args:
data: File-like object containing the TAR archive
extract_dir: Directory to extract contents to
progress_header: If set, show a progress bar with this header
"""
import stat
import tarfile
# Tar extraction safety: os.path.realpath / commonpath / normpath have no
# pathlib equivalents and Path.resolve() would follow symlinks unsafely.
# Use os.path for the security-sensitive parts; the simple checks move to
# Path.
extract_dir = os.fspath(extract_dir)
abs_dest = os.path.abspath(extract_dir) # noqa: PTH100
with tarfile.open(fileobj=data, mode="r") as tar_ref:
all_members = tar_ref.getmembers()
# Detect a single common top-level directory and strip it during
# extraction so we don't have to flatten it via a rename afterwards.
strip_root = _detect_archive_root(m.name for m in all_members)
strip_prefix = f"{strip_root}/" if strip_root is not None else None
safe_members = []
for member in all_members:
name = member.name
# 1. Strip leading slashes
name = name.lstrip("/" + os.sep)
# 2. Reject absolute paths (incl. Windows drive)
if Path(name).is_absolute() or (
os.name == "nt" and ":" in name.split(os.sep)[0] # noqa: PTH206
):
continue
# 3. Strip wrapper directory if one was detected
if strip_prefix is not None:
norm = name.replace("\\", "/")
if norm in (strip_root, strip_prefix):
continue
if not norm.startswith(strip_prefix):
continue
name = norm[len(strip_prefix) :]
# 4. Compute final path
target_path = os.path.realpath(os.path.join(abs_dest, name)) # noqa: PTH118
if os.path.commonpath([abs_dest, target_path]) != abs_dest:
continue
# 5. Validate links properly
if member.issym() or member.islnk():
linkname = member.linkname
# Reject absolute link targets
if Path(linkname).is_absolute():
continue
# Strip leading slashes
linkname = os.path.normpath(linkname)
if member.issym():
link_target = os.path.join( # noqa: PTH118
abs_dest,
os.path.dirname(name), # noqa: PTH120
linkname,
)
else:
link_target = os.path.join(abs_dest, linkname) # noqa: PTH118
link_target = os.path.realpath(link_target)
if os.path.commonpath([abs_dest, link_target]) != abs_dest:
continue
# write back normalized linkname
member.linkname = linkname
# 6. Sanitize permissions
mode = member.mode
if mode is not None:
# Strip high bits & group/other write bits
mode &= (
stat.S_IRWXU
| stat.S_IRGRP
| stat.S_IXGRP
| stat.S_IROTH
| stat.S_IXOTH
)
if member.isfile() or member.islnk():
# remove exec bits unless explicitly user-executable
if not (mode & stat.S_IXUSR):
mode &= ~(stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
mode |= stat.S_IRUSR | stat.S_IWUSR
elif not (member.isdir() or member.issym()):
# Block special files. Directories and symlinks keep
# their masked-original mode — passing None here would
# crash tarfile.extract on Python <3.12 (its chmod
# path calls os.chmod unconditionally).
continue
member.mode = mode
# 7. Strip ownership
member.uid = None
member.gid = None
member.uname = None
member.gname = None
# 8. Assign sanitized name back
member.name = name
safe_members.append(member)
total = len(safe_members)
progress = (
ProgressBar(progress_header) if progress_header and total > 0 else None
)
for i, member in enumerate(safe_members, 1):
tar_ref.extract(member, abs_dest)
if progress is not None:
progress.update(i / total)
if progress is not None:
progress.update(1)
def _zip_extract_all(
data: io.BufferedIOBase,
extract_dir: PathType = ".",
progress_header: str | None = None,
):
"""
Extract a ZIP archive to the specified directory.
Args:
data: File-like object containing the ZIP archive
extract_dir: Directory to extract contents to
progress_header: If set, show a progress bar with this header
"""
import zipfile
# See note in archive_extract_all_tar: os.path is used intentionally for
# the security-sensitive abspath/commonpath checks below.
extract_dir = os.path.abspath(extract_dir) # noqa: PTH100
with zipfile.ZipFile(data, "r") as zip_ref:
all_members = zip_ref.infolist()
# Detect a single common top-level directory and strip it during
# extraction so we don't have to flatten it via a rename afterwards.
strip_root = _detect_archive_root(m.filename for m in all_members)
strip_prefix = f"{strip_root}/" if strip_root is not None else None
total = len(all_members)
progress = (
ProgressBar(progress_header) if progress_header and total > 0 else None
)
for i, member in enumerate(all_members, 1):
# 1. Normalize name
name = member.filename.lstrip("/\\")
# 2. Reject absolute paths / Windows drives
if Path(name).is_absolute() or (
os.name == "nt" and ":" in name.split(os.sep)[0] # noqa: PTH206
):
continue
# 3. Strip wrapper directory if one was detected
if strip_prefix is not None:
norm = name.replace("\\", "/")
if norm in (strip_root, strip_prefix):
continue
if not norm.startswith(strip_prefix):
continue
name = norm[len(strip_prefix) :]
# 4. Compute safe target path
target_path = os.path.abspath(os.path.join(extract_dir, name)) # noqa: PTH100, PTH118
if os.path.commonpath([extract_dir, target_path]) != extract_dir:
raise ValueError(f"Unsafe path detected: {member.filename}")
# 5. Assign sanitized name back
member.filename = name
# 6. Extract
zip_ref.extract(member, extract_dir)
if progress is not None:
progress.update(i / total)
if progress is not None:
progress.update(1)
_ARCHIVE_MAGIC_MAP = {
b"\x1f\x8b\x08": _tar_extract_all,
b"\x42\x5a\x68": _tar_extract_all,
b"\xfd\x37\x7a\x58\x5a\x00": _tar_extract_all,
b"\x50\x4b\x03\x04": _zip_extract_all,
}
def archive_extract_all(
archive: PathType | io.RawIOBase | IO[bytes],
extract_dir: PathType = ".",
progress_header: str | None = None,
):
"""
Extract an archive file to the specified directory.
Args:
archive: Path to archive file or file-like object
extract_dir: Directory to extract contents to
progress_header: If set, show a progress bar with this header
Raises:
TypeError: If archive is not a valid type
ValueError: If archive format is unsupported
"""
# 1. Handle different archive input types
with ExitStack() as stack:
archive_ref: io.BufferedIOBase
if isinstance(archive, (str, os.PathLike)):
archive_ref = stack.enter_context(Path(archive).open("rb"))
elif isinstance(archive, (io.BufferedReader, io.BufferedRandom)):
archive_ref = archive
elif isinstance(archive, io.RawIOBase):
archive_ref = io.BufferedReader(archive)
else:
raise TypeError(
f"archive must be str, Path, or file-like object: {type(archive)}"
)
# 2. Detect archive format and select appropriate extraction function
matched_fct = None
magic_len = max(len(k) for k in _ARCHIVE_MAGIC_MAP)
header = archive_ref.peek(magic_len)
for magic, fct in _ARCHIVE_MAGIC_MAP.items():
if header.startswith(magic):
matched_fct = fct
break
if matched_fct is None:
raise ValueError("Unsupported archive format")
matched_fct(archive_ref, extract_dir, progress_header=progress_header)
def download_from_mirrors(
mirrors: list[str],
substitutions: dict[str, str],
target: io.RawIOBase | IO[bytes] | PathType,
timeout: int = 30,
) -> str | None:
"""
Download file from multiple mirrors with substitution support.
Args:
mirrors: list of mirror URLs
substitutions: Dictionary of substitutions to apply to URLs
target: Target file path or file-like object
timeout: Download timeout in seconds
Returns:
The source URL.
Raises:
Exception: If all download attempts fail
"""
# 1. Open target file for writing if path given
with ExitStack() as stack:
if isinstance(target, (str, os.PathLike)):
f = stack.enter_context(Path(target).open("wb"))
elif isinstance(target, (io.RawIOBase, io.IOBase)):
f = target
else:
raise TypeError(
f"target must be str, Path, or file-like object: {type(target)}"
)
# 2. Try each mirror in order
last_exception = None
for mirror in mirrors:
# 3. Apply substitutions to URL
url = mirror.format(**substitutions)
_LOGGER.debug("Trying downloading from %s", url)
try:
# 4. Reset file pointer and download
f.seek(0)
f.truncate(0)
with requests.get(url, stream=True, timeout=timeout) as r:
r.raise_for_status()
total_size = int(r.headers.get("content-length", 0))
downloaded = 0
progress = ProgressBar("Downloading") if total_size > 0 else None
for chunk in r.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
downloaded += len(chunk)
if progress is not None:
progress.update(downloaded / total_size)
if progress is not None:
progress.update(1)
_LOGGER.debug("Downloaded successfully from: %s", url)
# 6. Reset file pointer and return
f.seek(0)
return url
except Exception as e: # noqa: BLE001 # pylint: disable=broad-exception-caught
_LOGGER.debug("Failed to download %s: %s", url, str(e))
last_exception = e
# 7. Raise last exception if all mirrors failed
if last_exception:
raise last_exception
return None
_GITHUB_SHORTHAND_RE = re.compile( _GITHUB_SHORTHAND_RE = re.compile(
r"^github://([a-zA-Z0-9\-]+)/([a-zA-Z0-9\-\._]+?)(?:@([a-zA-Z0-9\-_.\./]+))?$" r"^github://([a-zA-Z0-9\-]+)/([a-zA-Z0-9\-\._]+?)(?:@([a-zA-Z0-9\-_.\./]+))?$"
) )
@@ -1067,12 +536,12 @@ def _check_esphome_idf_framework_install(
if _check_stamp(env_stamp_file, stamp_info): if _check_stamp(env_stamp_file, stamp_info):
_LOGGER.info("Checking ESP-IDF %s framework installation ...", version) _LOGGER.info("Checking ESP-IDF %s framework installation ...", version)
cmd = [ cmd = [
_get_pythonexe_path(), get_system_python_path(),
str(idf_tools_path), str(idf_tools_path),
"--non-interactive", "--non-interactive",
"check", "check",
] ]
if _exec_ok(cmd, msg=f"ESP-IDF {version} check", env=env): if run_command_ok(cmd, msg=f"ESP-IDF {version} check", env=env):
install = False install = False
# 4. Install framework tools if not installed or needs update # 4. Install framework tools if not installed or needs update
@@ -1080,13 +549,13 @@ def _check_esphome_idf_framework_install(
_LOGGER.info("Installing ESP-IDF %s framework ...", version) _LOGGER.info("Installing ESP-IDF %s framework ...", version)
targets_str = ",".join(targets) targets_str = ",".join(targets)
cmd = [ cmd = [
_get_pythonexe_path(), get_system_python_path(),
str(idf_tools_path), str(idf_tools_path),
"--non-interactive", "--non-interactive",
"install", "install",
f"--targets={targets_str}", f"--targets={targets_str}",
] + tools ] + tools
if not _exec_ok( if not run_command_ok(
cmd, cmd,
msg=f"ESP-IDF {version} framework installation", msg=f"ESP-IDF {version} framework installation",
env=env, env=env,
@@ -1128,7 +597,7 @@ def _check_esp_idf_python_env_install(
framework_path = _get_framework_path(version) framework_path = _get_framework_path(version)
python_env_path = _get_python_env_path(version) python_env_path = _get_python_env_path(version)
env_stamp_file = python_env_path / ESPHOME_STAMP_FILE env_stamp_file = python_env_path / ESPHOME_STAMP_FILE
env_python_path = _get_python_env_executable_path(python_env_path, "python") env_python_path = get_python_env_executable_path(python_env_path, "python")
_LOGGER.info("Checking ESP-IDF %s Python environment ...", version) _LOGGER.info("Checking ESP-IDF %s Python environment ...", version)
install = force or not python_env_path.is_dir() or not env_python_path.is_file() install = force or not python_env_path.is_dir() or not env_python_path.is_file()
@@ -1144,7 +613,7 @@ def _check_esp_idf_python_env_install(
if install: if install:
rmdir(python_env_path, msg=f"Clean up ESP-IDF {version} Python environment") rmdir(python_env_path, msg=f"Clean up ESP-IDF {version} Python environment")
_create_venv(python_env_path, msg=f"ESP-IDF {version}") create_venv(python_env_path, msg=f"ESP-IDF {version}")
esp_idf_version = _get_idf_version(framework_path, env=env) esp_idf_version = _get_idf_version(framework_path, env=env)
constraint_file_path = ( constraint_file_path = (
@@ -1174,7 +643,7 @@ def _check_esp_idf_python_env_install(
"pip", "pip",
"setuptools", "setuptools",
] ]
if not _exec_ok( if not run_command_ok(
cmd, cmd,
msg=f"Upgrade ESP-IDF {version} Python environment packages", msg=f"Upgrade ESP-IDF {version} Python environment packages",
env=env, env=env,
@@ -1194,7 +663,7 @@ def _check_esp_idf_python_env_install(
"-r", "-r",
str(requirements_file), str(requirements_file),
] ]
if not _exec_ok( if not run_command_ok(
cmd, cmd,
msg=f"Install ESP-IDF {version} Python dependencies for {feature}", msg=f"Install ESP-IDF {version} Python dependencies for {feature}",
env=env, env=env,
@@ -1296,7 +765,7 @@ def get_framework_env(
# 3. If Python environment path is provided, add it to PATH and set IDF_PYTHON_ENV_PATH # 3. If Python environment path is provided, add it to PATH and set IDF_PYTHON_ENV_PATH
if python_env_path: if python_env_path:
python_path = _get_python_env_executable_path(python_env_path, "python") python_path = get_python_env_executable_path(python_env_path, "python")
path_list.insert(0, str(python_path.parent)) path_list.insert(0, str(python_path.parent))
env["IDF_PYTHON_ENV_PATH"] = str(python_env_path) env["IDF_PYTHON_ENV_PATH"] = str(python_env_path)

View File

@@ -0,0 +1,677 @@
"""Generic toolchain installation helpers shared across framework implementations."""
from collections.abc import Iterable
from contextlib import ExitStack
import io
import logging
import os
from pathlib import Path
import subprocess
import sys
import time
from typing import IO
import requests
from esphome.helpers import ProgressBar, rmtree
PathType = str | os.PathLike
_LOGGER = logging.getLogger(__name__)
def str_to_lst_of_str(a: str | list[str]) -> list[str]:
"""
Convert a string to a list of string
Args:
a: A string containing semicolon-separated values, or an already-split list
Returns:
list of strings
"""
if isinstance(a, list):
return a
return [f.strip() for f in a.split(";") if f.strip()]
def rmdir(directory: PathType, msg: str | None = None):
"""
Remove a directory and its contents recursively if it exists.
Args:
directory: Path to the directory to be removed
msg: Optional debug message to log before removal or it an error occurs
Returns:
None
Raises:
RuntimeError: If directory removal fails
"""
if Path(directory).is_dir():
try:
if msg:
_LOGGER.debug(msg)
rmtree(directory)
except OSError as e:
raise RuntimeError(
f"Error during {msg}: can't remove `{directory}`. Please remove it manually!"
) from e
def get_system_python_path() -> str:
"""
Get the path to the Python executable.
Returns:
Path to Python executable as string
"""
# Try to get PYTHONEXEPATH environment variable
# Fallback to sys.executable if not set
return os.environ.get("PYTHONEXEPATH", os.path.normpath(sys.executable))
def get_python_env_executable_path(root: PathType, binary: str) -> Path:
"""
Get the path to a Python environment executable file.
Args:
root: Root directory of the Python environment
binary: Name of the executable binary
Returns:
Path object pointing to the executable file
"""
if os.name == "nt":
return Path(root) / "Scripts" / f"{binary}.exe"
return Path(root) / "bin" / binary
def run_command(
cmd: list[str],
msg: str | None = None,
env: dict[str, str] | None = None,
stream_output: bool = False,
cwd: PathType | None = None,
) -> tuple[bool, str | None, str | None]:
"""
Execute a command and return results.
Args:
cmd: list of command arguments
msg: Optional custom message for logging
env: Optional dictionary of environment variables to set
stream_output: If True, inherit parent stdio so the subprocess prints
directly to the terminal (useful for commands that produce their
own progress output). stdout/stderr are not captured in this mode.
cwd: Optional working directory for the subprocess.
Returns:
tuple of (success: bool, stdout: str or None, stderr: str or None).
When stream_output is True, stdout and stderr are always None.
"""
cmd_str = msg or " ".join(cmd)
try:
_LOGGER.debug("%s - running ...", cmd_str)
run_env = os.environ.copy()
if env:
run_env.update(env)
if stream_output:
result = subprocess.run(cmd, check=False, env=run_env, cwd=cwd)
stdout = stderr = None
else:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
check=False,
env=run_env,
cwd=cwd,
)
stdout = result.stdout
stderr = result.stderr
if result.returncode != 0:
if stream_output:
_LOGGER.error("%s - failed (returncode=%s)", cmd_str, result.returncode)
else:
tail = (stderr or stdout or "").strip()[-1000:]
_LOGGER.error(
"%s - failed (returncode=%s). Tail:\n%s",
cmd_str,
result.returncode,
tail,
)
return False, stdout, stderr
_LOGGER.debug("%s - executed successfully", cmd_str)
return True, stdout, stderr
except (subprocess.SubprocessError, OSError) as e:
_LOGGER.error("%s - error: %s", cmd_str, str(e))
return False, None, None
def run_command_ok(*args, **kwargs) -> bool:
"""
Execute a command and return only the success status.
Args:
*args: Positional arguments to pass to run_command
**kwargs: Keyword arguments to pass to run_command
Returns:
True if command executed successfully, False otherwise
"""
return run_command(*args, **kwargs)[0]
def create_venv(root: PathType, msg: str | None = None):
"""
Create a Python virtual environment.
Args:
root: Path to the virtual environment directory
msg: Optional message for logging
Returns:
None
Raises:
RuntimeError: If virtual environment creation fails
"""
cmd = [get_system_python_path(), "-m", "venv", "--clear", root]
if not run_command_ok(cmd, msg=f"Create Python virtual environment for {msg}"):
raise RuntimeError(f"Can't create Python virtual environment for {msg}")
def _detect_archive_root(names: Iterable[str]) -> str | None:
"""Detect a single top-level directory shared by all archive entries.
Returns the directory name if every non-empty entry sits under the same
top-level directory, else ``None``. Extraction helpers use this to strip
the wrapper directory commonly found in source archives during extraction
rather than renaming it afterwards — post-extraction renames are
unreliable on Windows because antivirus and the search indexer briefly
hold handles on freshly written files.
"""
root: str | None = None
has_descendant = False
for raw in names:
name = raw.replace("\\", "/").strip("/")
if not name:
continue
first, sep, _ = name.partition("/")
if root is None:
root = first
elif root != first:
return None
if sep:
has_descendant = True
return root if has_descendant else None
def _tar_extract_all(
data: io.BufferedIOBase,
extract_dir: PathType = ".",
progress_header: str | None = None,
):
"""
Extract a TAR archive to the specified directory.
Implementation is inspired by Python 3.12's tarfile data filtering logic.
This can be replaced with the standard library implementation once
support for Python 3.11 is no longer required.
Args:
data: File-like object containing the TAR archive
extract_dir: Directory to extract contents to
progress_header: If set, show a progress bar with this header
"""
import stat
import tarfile
# Tar extraction safety: os.path.realpath / commonpath / normpath have no
# pathlib equivalents and Path.resolve() would follow symlinks unsafely.
# Use os.path for the security-sensitive parts; the simple checks move to
# Path.
extract_dir = os.fspath(extract_dir)
abs_dest = os.path.abspath(extract_dir) # noqa: PTH100
with tarfile.open(fileobj=data, mode="r") as tar_ref:
all_members = tar_ref.getmembers()
# Detect a single common top-level directory and strip it during
# extraction so we don't have to flatten it via a rename afterwards.
strip_root = _detect_archive_root(m.name for m in all_members)
strip_prefix = f"{strip_root}/" if strip_root is not None else None
safe_members = []
for member in all_members:
name = member.name
# 1. Strip leading slashes
name = name.lstrip("/" + os.sep)
# 2. Reject absolute paths (incl. Windows drive)
if Path(name).is_absolute() or (
os.name == "nt" and ":" in name.split(os.sep)[0] # noqa: PTH206
):
continue
# 3. Strip wrapper directory if one was detected
if strip_prefix is not None:
norm = name.replace("\\", "/")
if norm in (strip_root, strip_prefix):
continue
if not norm.startswith(strip_prefix):
continue
name = norm[len(strip_prefix) :]
# 4. Compute final path
target_path = os.path.realpath(os.path.join(abs_dest, name)) # noqa: PTH118
if os.path.commonpath([abs_dest, target_path]) != abs_dest:
continue
# 5. Validate links properly
if member.issym() or member.islnk():
linkname = member.linkname
# Reject absolute link targets
if Path(linkname).is_absolute():
continue
if member.islnk() and strip_prefix is not None:
# Hard-link linknames reference another archive member
# by its archive name. We've stripped the wrapper prefix
# from member.name above (step 3); strip it here too so
# tarfile._find_link_target can resolve the target during
# extraction. Symlink linknames are filesystem-relative
# paths, not archive-member references, so they don't
# need this treatment.
norm_link = linkname.replace("\\", "/")
if norm_link in (strip_root, strip_prefix):
continue
if not norm_link.startswith(strip_prefix):
continue
linkname = norm_link[len(strip_prefix) :]
# Strip leading slashes
linkname = os.path.normpath(linkname)
if member.issym():
link_target = os.path.join( # noqa: PTH118
abs_dest,
os.path.dirname(name), # noqa: PTH120
linkname,
)
else:
link_target = os.path.join(abs_dest, linkname) # noqa: PTH118
link_target = os.path.realpath(link_target)
if os.path.commonpath([abs_dest, link_target]) != abs_dest:
continue
# write back normalized linkname
member.linkname = linkname
# 6. Sanitize permissions
mode = member.mode
if mode is not None:
# Strip high bits & group/other write bits
mode &= (
stat.S_IRWXU
| stat.S_IRGRP
| stat.S_IXGRP
| stat.S_IROTH
| stat.S_IXOTH
)
if member.isfile() or member.islnk():
# remove exec bits unless explicitly user-executable
if not (mode & stat.S_IXUSR):
mode &= ~(stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
mode |= stat.S_IRUSR | stat.S_IWUSR
elif not (member.isdir() or member.issym()):
# Block special files. Directories and symlinks keep
# their masked-original mode — passing None here would
# crash tarfile.extract on Python <3.12 (its chmod
# path calls os.chmod unconditionally).
continue
member.mode = mode
# 7. Strip ownership
member.uid = None
member.gid = None
member.uname = None
member.gname = None
# 8. Assign sanitized name back
member.name = name
safe_members.append(member)
total = len(safe_members)
progress = (
ProgressBar(progress_header) if progress_header and total > 0 else None
)
for i, member in enumerate(safe_members, 1):
tar_ref.extract(member, abs_dest)
if progress is not None:
progress.update(i / total)
if progress is not None:
progress.update(1)
def _zip_extract_all(
data: io.BufferedIOBase,
extract_dir: PathType = ".",
progress_header: str | None = None,
):
"""
Extract a ZIP archive to the specified directory.
Args:
data: File-like object containing the ZIP archive
extract_dir: Directory to extract contents to
progress_header: If set, show a progress bar with this header
"""
import zipfile
# See note in _tar_extract_all: os.path is used intentionally for
# the security-sensitive abspath/commonpath checks below.
extract_dir = os.path.abspath(extract_dir) # noqa: PTH100
with zipfile.ZipFile(data, "r") as zip_ref:
all_members = zip_ref.infolist()
# Detect a single common top-level directory and strip it during
# extraction so we don't have to flatten it via a rename afterwards.
strip_root = _detect_archive_root(m.filename for m in all_members)
strip_prefix = f"{strip_root}/" if strip_root is not None else None
total = len(all_members)
progress = (
ProgressBar(progress_header) if progress_header and total > 0 else None
)
for i, member in enumerate(all_members, 1):
# 1. Normalize name
name = member.filename.lstrip("/\\")
# 2. Reject absolute paths / Windows drives
if Path(name).is_absolute() or (
os.name == "nt" and ":" in name.split(os.sep)[0] # noqa: PTH206
):
continue
# 3. Strip wrapper directory if one was detected
if strip_prefix is not None:
norm = name.replace("\\", "/")
if norm in (strip_root, strip_prefix):
continue
if not norm.startswith(strip_prefix):
continue
name = norm[len(strip_prefix) :]
# 4. Compute safe target path
target_path = os.path.abspath(os.path.join(extract_dir, name)) # noqa: PTH100, PTH118
if os.path.commonpath([extract_dir, target_path]) != extract_dir:
raise ValueError(f"Unsafe path detected: {member.filename}")
# 5. Assign sanitized name back
member.filename = name
# 6. Extract
zip_ref.extract(member, extract_dir)
if progress is not None:
progress.update(i / total)
if progress is not None:
progress.update(1)
def _rename_with_retry(src: Path, dst: Path, attempts: int = 5) -> None:
"""Rename ``src`` to ``dst`` with backoff retries on Windows sharing violations.
Antivirus/indexer handles on freshly-written files can briefly block
``os.rename`` with ERROR_SHARING_VIOLATION / ERROR_ACCESS_DENIED. The
handle is released within tens of ms in practice, so exponential backoff
works.
"""
for i in range(attempts):
try:
src.rename(dst)
return
except PermissionError:
if i == attempts - 1:
raise
time.sleep(0.1 * (2**i))
def _7z_extract_all(
data: io.BufferedIOBase,
extract_dir: PathType = ".",
progress_header: str | None = None,
):
"""
Extract a 7z archive to the specified directory.
py7zr only supports bulk extraction (no per-member rename hook like
tarfile/zipfile), so we extract into a unique staging subdir of
``extract_dir`` and then move children up. This keeps everything on
the same volume and sidesteps wrapper-vs-child name collisions
(e.g. ``arm-zephyr-eabi/`` containing another ``arm-zephyr-eabi/``).
Args:
data: File-like object containing the 7z archive (must be seekable)
extract_dir: Directory to extract contents to
progress_header: If set, show a progress bar with this header
"""
import py7zr
extract_dir = os.path.abspath(extract_dir) # noqa: PTH100
Path(extract_dir).mkdir(parents=True, exist_ok=True)
suffix = 0
while True:
staging = Path(extract_dir) / f".extract_tmp_{suffix}"
if not staging.exists():
break
suffix += 1
staging.mkdir()
try:
with py7zr.SevenZipFile(data, "r") as z:
all_names = z.getnames()
# Detect a single common top-level directory to flatten.
strip_root = _detect_archive_root(all_names)
# Validate names: reject absolute paths, Windows drives, and
# path traversal. Filter via targets= since py7zr can't rename
# per-member.
safe_targets: list[str] = []
for raw in all_names:
name = raw.lstrip("/\\")
if not name:
continue
if Path(name).is_absolute() or (
os.name == "nt" and ":" in name.split(os.sep)[0] # noqa: PTH206
):
continue
target_path = os.path.abspath(os.path.join(staging, name)) # noqa: PTH100, PTH118
if os.path.commonpath([str(staging), target_path]) != str(staging):
continue
safe_targets.append(raw)
progress = (
ProgressBar(progress_header)
if progress_header and safe_targets
else None
)
if len(safe_targets) == len(all_names):
z.extractall(path=staging)
else:
z.extract(path=staging, targets=safe_targets)
if progress is not None:
progress.update(1)
src_root = staging / strip_root if strip_root else staging
for item in src_root.iterdir():
dest = Path(extract_dir) / item.name
if dest.exists():
if dest.is_dir():
rmtree(dest)
else:
dest.unlink()
_rename_with_retry(item, dest)
finally:
# staging is created before the try, so it always exists here; the
# guard is defensive cleanup and its False branch is unreachable.
if staging.exists(): # pragma: no cover
rmtree(staging)
_ARCHIVE_MAGIC_MAP = {
b"\x1f\x8b\x08": _tar_extract_all,
b"\x42\x5a\x68": _tar_extract_all,
b"\xfd\x37\x7a\x58\x5a\x00": _tar_extract_all,
b"\x50\x4b\x03\x04": _zip_extract_all,
b"\x37\x7a\xbc\xaf\x27\x1c": _7z_extract_all,
}
def archive_extract_all(
archive: PathType | io.RawIOBase | IO[bytes],
extract_dir: PathType = ".",
progress_header: str | None = None,
):
"""
Extract an archive file to the specified directory.
Args:
archive: Path to archive file or file-like object
extract_dir: Directory to extract contents to
progress_header: If set, show a progress bar with this header
Raises:
TypeError: If archive is not a valid type
ValueError: If archive format is unsupported
"""
# 1. Handle different archive input types
with ExitStack() as stack:
archive_ref: io.BufferedIOBase
if isinstance(archive, (str, os.PathLike)):
archive_ref = stack.enter_context(Path(archive).open("rb"))
elif isinstance(archive, (io.BufferedReader, io.BufferedRandom)):
archive_ref = archive
elif isinstance(archive, io.RawIOBase):
archive_ref = io.BufferedReader(archive)
else:
raise TypeError(
f"archive must be str, Path, or file-like object: {type(archive)}"
)
# 2. Detect archive format and select appropriate extraction function
matched_fct = None
magic_len = max(len(k) for k in _ARCHIVE_MAGIC_MAP)
header = archive_ref.peek(magic_len)
for magic, fct in _ARCHIVE_MAGIC_MAP.items():
if header.startswith(magic):
matched_fct = fct
break
if matched_fct is None:
raise ValueError("Unsupported archive format")
matched_fct(archive_ref, extract_dir, progress_header=progress_header)
def download_from_mirrors(
mirrors: list[str],
substitutions: dict[str, str],
target: io.RawIOBase | IO[bytes] | PathType,
timeout: int = 30,
) -> str:
"""
Download file from multiple mirrors with substitution support.
Args:
mirrors: list of mirror URLs
substitutions: Dictionary of substitutions to apply to URLs
target: Target file path or file-like object
timeout: Download timeout in seconds
Returns:
The source URL.
Raises:
ValueError: If mirrors list is empty.
Exception: If all download attempts fail.
"""
# 1. Open target file for writing if path given
with ExitStack() as stack:
if isinstance(target, (str, os.PathLike)):
f = stack.enter_context(Path(target).open("wb"))
elif isinstance(target, (io.RawIOBase, io.IOBase)):
f = target
else:
raise TypeError(
f"target must be str, Path, or file-like object: {type(target)}"
)
# 2. Try each mirror in order
last_exception = None
for mirror in mirrors:
# 3. Apply substitutions to URL
url = mirror.format(**substitutions)
_LOGGER.debug("Trying downloading from %s", url)
try:
# 4. Reset file pointer and download
f.seek(0)
f.truncate(0)
with requests.get(url, stream=True, timeout=timeout) as r:
r.raise_for_status()
total_size = int(r.headers.get("content-length", 0))
downloaded = 0
progress = ProgressBar("Downloading") if total_size > 0 else None
for chunk in r.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
downloaded += len(chunk)
if progress is not None:
progress.update(downloaded / total_size)
if progress is not None:
progress.update(1)
_LOGGER.debug("Downloaded successfully from: %s", url)
# 6. Reset file pointer and return
f.seek(0)
return url
except Exception as e: # noqa: BLE001 # pylint: disable=broad-exception-caught
_LOGGER.debug("Failed to download %s: %s", url, str(e))
last_exception = e
# 7. Raise last exception if all mirrors failed
if last_exception:
raise last_exception
raise ValueError("download_from_mirrors called with an empty mirrors list")

View File

@@ -25,6 +25,7 @@ jinja2==3.1.6
bleak==2.1.1 bleak==2.1.1
smpclient==6.0.0 smpclient==6.0.0
requests==2.34.2 requests==2.34.2
py7zr==0.22.0
# esp-idf >= 5.0 requires this # esp-idf >= 5.0 requires this
pyparsing >= 3.3.2 pyparsing >= 3.3.2

View File

@@ -894,6 +894,13 @@ class TestEsphomeCore:
"foo/build/.pioenvs/test-device/bootloader.bin" "foo/build/.pioenvs/test-device/bootloader.bin"
) )
def test_using_toolchain_sdk_nrf(self, target):
"""using_toolchain_sdk_nrf is True only for the SDK_NRF toolchain."""
target.toolchain = const.Toolchain.SDK_NRF
assert target.using_toolchain_sdk_nrf is True
target.toolchain = const.Toolchain.ESP_IDF
assert target.using_toolchain_sdk_nrf is False
def test_add_library__extracts_short_name_from_path(self, target): def test_add_library__extracts_short_name_from_path(self, target):
"""Test add_library extracts short name from library paths like owner/lib.""" """Test add_library extracts short name from library paths like owner/lib."""
target.data[const.KEY_CORE] = { target.data[const.KEY_CORE] = {

View File

@@ -2,12 +2,32 @@
# pylint: disable=protected-access # pylint: disable=protected-access
import io
import json
from pathlib import Path from pathlib import Path
import tarfile
from types import SimpleNamespace
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from esphome.espidf.framework import _clone_idf_with_submodules, _parse_git_source from esphome.espidf.framework import (
_check_stamp,
_clone_idf_with_submodules,
_get_framework_path,
_get_idf_tool_paths,
_get_idf_tools_path,
_get_idf_version,
_get_python_env_path,
_get_python_version,
_parse_git_source,
_patch_tools_json_for_linux_arm64,
_write_idf_version_txt,
_write_stamp,
check_esp_idf_install,
get_framework_env,
)
from esphome.framework_helpers import _tar_extract_all, get_python_env_executable_path
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -154,3 +174,511 @@ def test_clone_idf_with_submodules_raises_when_tree_missing(
"https://github.com/espressif/esp-idf.git", "https://github.com/espressif/esp-idf.git",
None, None,
) )
# ---------------------------------------------------------------------------
# Helpers for _tar_extract_all hard-link prefix-stripping tests
# ---------------------------------------------------------------------------
def _make_tar(
members: list[tarfile.TarInfo], file_contents: dict[str, bytes]
) -> io.BytesIO:
"""Build an in-memory tar archive from a list of TarInfo objects."""
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tf:
for info in members:
if info.isreg() and info.name in file_contents:
data = file_contents[info.name]
info.size = len(data)
tf.addfile(info, io.BytesIO(data))
else:
tf.addfile(info)
buf.seek(0)
return buf
def _regular(name: str) -> tarfile.TarInfo:
info = tarfile.TarInfo(name=name)
info.type = tarfile.REGTYPE
info.size = 0
info.mode = 0o644
return info
def _hardlink(name: str, linkname: str) -> tarfile.TarInfo:
info = tarfile.TarInfo(name=name)
info.type = tarfile.LNKTYPE
info.linkname = linkname
info.size = 0
info.mode = 0o644
return info
class TestTarExtractHardLinkPrefixStripping:
"""
Covers the hard-link prefix-stripping block in _tar_extract_all (L528-541).
Archive layout used by every test:
wrapper/ ← single top-level wrapper dir (stripped)
wrapper/target.txt ← regular file; becomes target.txt in dest
wrapper/link_good ← hard link to wrapper/target.txt (kept, linkname stripped)
wrapper/link_exact_root ← hard link to "wrapper" (skipped equals strip_root)
wrapper/link_exact_prefix ← hard link to "wrapper/" (skipped equals strip_prefix)
wrapper/link_outside ← hard link to "other/target.txt" (skipped not under prefix)
"""
WRAPPER = "wrapper"
def _build_archive(self) -> io.BytesIO:
members = [
_regular(f"{self.WRAPPER}/"),
_regular(f"{self.WRAPPER}/target.txt"),
_hardlink(f"{self.WRAPPER}/link_good", f"{self.WRAPPER}/target.txt"),
_hardlink(f"{self.WRAPPER}/link_exact_root", self.WRAPPER),
_hardlink(f"{self.WRAPPER}/link_exact_prefix", f"{self.WRAPPER}/"),
_hardlink(f"{self.WRAPPER}/link_outside", "other/target.txt"),
]
return _make_tar(members, {f"{self.WRAPPER}/target.txt": b"hello"})
def test_good_hardlink_is_extracted_with_stripped_linkname(
self, tmp_path: Path
) -> None:
"""Hard link whose linkname starts with wrapper/ is extracted and its
linkname has the prefix removed so tarfile can resolve the target."""
_tar_extract_all(self._build_archive(), tmp_path)
link = tmp_path / "link_good"
assert link.exists(), "link_good should have been extracted"
assert link.read_bytes() == b"hello"
def test_hardlink_equal_to_strip_root_is_skipped(self, tmp_path: Path) -> None:
"""Hard link whose linkname equals strip_root exactly must be dropped."""
_tar_extract_all(self._build_archive(), tmp_path)
assert not (tmp_path / "link_exact_root").exists()
def test_hardlink_equal_to_strip_prefix_is_skipped(self, tmp_path: Path) -> None:
"""Hard link whose linkname equals strip_prefix (strip_root + '/') must be dropped."""
_tar_extract_all(self._build_archive(), tmp_path)
assert not (tmp_path / "link_exact_prefix").exists()
def test_hardlink_outside_prefix_is_skipped(self, tmp_path: Path) -> None:
"""Hard link whose linkname does not start with wrapper/ must be dropped."""
_tar_extract_all(self._build_archive(), tmp_path)
assert not (tmp_path / "link_outside").exists()
def test_regular_file_and_no_spurious_files(self, tmp_path: Path) -> None:
"""Sanity check: target.txt is extracted and no unexpected files appear."""
_tar_extract_all(self._build_archive(), tmp_path)
assert (tmp_path / "target.txt").read_bytes() == b"hello"
extracted = {p.name for p in tmp_path.iterdir()}
assert extracted == {"target.txt", "link_good"}
_IDF_VERSION = "5.1.2"
@pytest.fixture
def espidf_mocks(setup_core: Path):
"""Patch the heavy I/O of check_esp_idf_install and pre-create the framework dir."""
# archive_extract_all is mocked, so pre-create the framework dir that the
# extracted-marker touch writes into.
_get_framework_path(_IDF_VERSION).mkdir(parents=True, exist_ok=True)
with (
patch("esphome.espidf.framework.rmdir"),
patch(
"esphome.espidf.framework.download_from_mirrors",
return_value="https://example.com/idf.tar.xz",
) as download,
patch("esphome.espidf.framework.archive_extract_all") as extract,
patch("esphome.espidf.framework.create_venv") as venv,
patch("esphome.espidf.framework.run_command_ok", return_value=True) as run_ok,
patch("esphome.espidf.framework._clone_idf_with_submodules") as clone,
patch("esphome.espidf.framework._write_idf_version_txt"),
patch("esphome.espidf.framework._patch_tools_json_for_linux_arm64"),
patch("esphome.espidf.framework._write_stamp"),
patch("esphome.espidf.framework._check_stamp", return_value=True),
patch("esphome.espidf.framework._get_idf_version", return_value=_IDF_VERSION),
patch("esphome.espidf.framework._get_python_version", return_value="3.11.0"),
patch("esphome.espidf.framework.get_system_python_path", return_value="python"),
):
yield SimpleNamespace(
download=download, extract=extract, venv=venv, run_ok=run_ok, clone=clone
)
def test_check_esp_idf_install_fresh(espidf_mocks: SimpleNamespace) -> None:
"""A forced install drives download/extract, venv creation, and pip installs."""
framework_path, python_env_path = check_esp_idf_install(_IDF_VERSION, force=True)
assert framework_path == _get_framework_path(_IDF_VERSION)
assert python_env_path == _get_python_env_path(_IDF_VERSION)
# framework tarball + python-env constraints file are both downloaded
assert espidf_mocks.download.call_count == 2
espidf_mocks.extract.assert_called_once()
espidf_mocks.venv.assert_called_once()
espidf_mocks.clone.assert_not_called()
def test_check_esp_idf_install_git_source(espidf_mocks: SimpleNamespace) -> None:
"""A git source_url clones instead of downloading; explicit tools skip discovery."""
check_esp_idf_install(
_IDF_VERSION,
force=True,
source_url="https://github.com/espressif/esp-idf.git",
tools=["xtensa-esp-elf"],
)
espidf_mocks.clone.assert_called_once()
# framework is cloned, so only the python-env constraints file is downloaded
assert espidf_mocks.download.call_count == 1
def test_check_esp_idf_install_already_installed(espidf_mocks: SimpleNamespace) -> None:
"""Marker + matching stamps + existing python env → nothing is re-installed."""
framework_path = _get_framework_path(_IDF_VERSION)
(framework_path / ".esphome_extracted").touch()
python_env_path = _get_python_env_path(_IDF_VERSION)
env_python = get_python_env_executable_path(python_env_path, "python")
env_python.parent.mkdir(parents=True, exist_ok=True)
env_python.touch()
check_esp_idf_install(_IDF_VERSION)
espidf_mocks.extract.assert_not_called()
espidf_mocks.venv.assert_not_called()
def test_check_esp_idf_install_framework_failure(espidf_mocks: SimpleNamespace) -> None:
"""A failing idf_tools install raises."""
espidf_mocks.run_ok.side_effect = [False]
with pytest.raises(RuntimeError, match="framework installation failure"):
check_esp_idf_install(_IDF_VERSION, force=True)
def test_check_esp_idf_install_pip_upgrade_failure(
espidf_mocks: SimpleNamespace,
) -> None:
"""A failing pip upgrade in the python env raises (framework install ok)."""
espidf_mocks.run_ok.side_effect = [True, False]
with pytest.raises(RuntimeError, match="Python environment packages failure"):
check_esp_idf_install(_IDF_VERSION, force=True)
def test_check_esp_idf_install_feature_failure(espidf_mocks: SimpleNamespace) -> None:
"""A failing feature requirements install raises."""
espidf_mocks.run_ok.side_effect = [True, True, False]
with pytest.raises(RuntimeError, match="Python dependencies for"):
check_esp_idf_install(_IDF_VERSION, force=True, features=["fb"])
def _mark_installed() -> None:
"""Create the extracted marker and python-env interpreter so the install
check takes the already-installed path rather than force-installing."""
(_get_framework_path(_IDF_VERSION) / ".esphome_extracted").touch()
env_python = get_python_env_executable_path(
_get_python_env_path(_IDF_VERSION), "python"
)
env_python.parent.mkdir(parents=True, exist_ok=True)
env_python.touch()
def test_check_esp_idf_install_stamp_mismatch_reinstalls(
espidf_mocks: SimpleNamespace,
) -> None:
"""A stamp mismatch reinstalls tools (marker present, so no re-extract)."""
_mark_installed()
with patch("esphome.espidf.framework._check_stamp", return_value=False):
check_esp_idf_install(_IDF_VERSION)
espidf_mocks.extract.assert_not_called() # marker present -> no re-extract
espidf_mocks.venv.assert_called_once() # tools reinstall -> venv rebuilt
def test_check_esp_idf_install_check_command_failure_reinstalls(
espidf_mocks: SimpleNamespace,
) -> None:
"""A failing idf_tools check reinstalls tools (marker present, no re-extract)."""
_mark_installed()
# idf_tools check fails -> install stays True; the later installs succeed.
espidf_mocks.run_ok.side_effect = [False, True, True, True]
check_esp_idf_install(_IDF_VERSION, features=["fb"])
espidf_mocks.extract.assert_not_called()
espidf_mocks.venv.assert_called_once()
def test_check_esp_idf_install_unknown_python_version_reinstalls(
espidf_mocks: SimpleNamespace,
) -> None:
"""An undeterminable python version rebuilds the venv (framework stamp still ok)."""
_mark_installed()
with patch("esphome.espidf.framework._get_python_version", return_value=None):
check_esp_idf_install(_IDF_VERSION)
espidf_mocks.extract.assert_not_called() # framework stamp matched
espidf_mocks.venv.assert_called_once() # python env rebuilt
def test_check_esp_idf_install_python_stamp_mismatch_rebuilds_venv(
espidf_mocks: SimpleNamespace,
) -> None:
"""Framework stamp matches but the python-env stamp does not -> venv rebuilt."""
# _check_stamp passes for the framework (no python_version key) and fails
# for the python env (carries python_version), so only the venv rebuilds.
def stamp_ok(_stamp_file, info: dict) -> bool:
return "python_version" not in info
_mark_installed()
with patch("esphome.espidf.framework._check_stamp", side_effect=stamp_ok):
check_esp_idf_install(_IDF_VERSION)
espidf_mocks.extract.assert_not_called()
espidf_mocks.venv.assert_called_once()
def test_check_esp_idf_install_unparseable_version(
espidf_mocks: SimpleNamespace,
) -> None:
"""A non-semver version skips the MAJOR/MINOR substitutions without erroring."""
bad_version = "main"
_get_framework_path(bad_version).mkdir(parents=True, exist_ok=True)
check_esp_idf_install(bad_version, force=True)
espidf_mocks.extract.assert_called_once()
# ---------------------------------------------------------------------------
# _patch_tools_json_for_linux_arm64 (arm64-only ninja backport)
# ---------------------------------------------------------------------------
def _write_tools_json(framework_path: Path, data: dict) -> Path:
tools_dir = framework_path / "tools"
tools_dir.mkdir(parents=True, exist_ok=True)
tools_json = tools_dir / "tools.json"
tools_json.write_text(json.dumps(data), encoding="utf-8")
return tools_json
def test_patch_tools_json_non_aarch64_is_noop(tmp_path: Path) -> None:
tools_json = _write_tools_json(
tmp_path, {"tools": [{"name": "ninja", "versions": [{"name": "1.12.1"}]}]}
)
before = tools_json.read_text(encoding="utf-8")
with patch("esphome.espidf.framework.platform.machine", return_value="x86_64"):
_patch_tools_json_for_linux_arm64(tmp_path)
assert tools_json.read_text(encoding="utf-8") == before
def test_patch_tools_json_missing_file_is_noop(tmp_path: Path) -> None:
with patch("esphome.espidf.framework.platform.machine", return_value="aarch64"):
_patch_tools_json_for_linux_arm64(tmp_path) # no tools/tools.json present
def test_patch_tools_json_corrupt_file_warns_and_skips(tmp_path: Path) -> None:
(tmp_path / "tools").mkdir()
(tmp_path / "tools" / "tools.json").write_text("{ not json", encoding="utf-8")
with patch("esphome.espidf.framework.platform.machine", return_value="aarch64"):
_patch_tools_json_for_linux_arm64(tmp_path) # JSONDecodeError -> skip
def test_patch_tools_json_injects_ninja_arm64(tmp_path: Path) -> None:
tools_json = _write_tools_json(
tmp_path,
{
"tools": [
{"name": "ninja", "versions": [{"name": "1.12.1"}]},
{"name": "cmake", "versions": [{"name": "3.24.0"}]},
]
},
)
with patch("esphome.espidf.framework.platform.machine", return_value="aarch64"):
_patch_tools_json_for_linux_arm64(tmp_path)
data = json.loads(tools_json.read_text(encoding="utf-8"))
ninja = next(t for t in data["tools"] if t["name"] == "ninja")
assert "linux-arm64" in ninja["versions"][0]
assert ninja["versions"][0]["linux-arm64"]["size"] == 121787
def test_patch_tools_json_already_patched_is_noop(tmp_path: Path) -> None:
tools_json = _write_tools_json(
tmp_path,
{
"tools": [
{
"name": "ninja",
"versions": [{"name": "1.12.1", "linux-arm64": {"url": "x"}}],
}
]
},
)
before = tools_json.read_text(encoding="utf-8")
with patch("esphome.espidf.framework.platform.machine", return_value="aarch64"):
_patch_tools_json_for_linux_arm64(tmp_path)
assert tools_json.read_text(encoding="utf-8") == before
# ---------------------------------------------------------------------------
# Subprocess-backed helpers (_exec -> run_command rename) and get_framework_env
# ---------------------------------------------------------------------------
def test_get_idf_version_parses_stdout(tmp_path: Path) -> None:
with patch(
"esphome.espidf.framework.run_command", return_value=(True, "5.1.2\n", "")
):
assert _get_idf_version(tmp_path) == "5.1.2"
def test_get_idf_version_raises_on_failure(tmp_path: Path) -> None:
with (
patch("esphome.espidf.framework.run_command", return_value=(False, "", "boom")),
pytest.raises(RuntimeError, match="Can't get ESP-IDF version"),
):
_get_idf_version(tmp_path)
def test_get_idf_tool_paths_parses_json(tmp_path: Path) -> None:
payload = json.dumps({"paths_to_export": ["/a", "/b"], "export_vars": {"X": "1"}})
with patch(
"esphome.espidf.framework.run_command", return_value=(True, payload, "")
):
paths, export_vars = _get_idf_tool_paths(tmp_path)
assert paths == ["/a", "/b"]
assert export_vars == {"X": "1"}
def test_get_idf_tool_paths_raises_on_bad_json(tmp_path: Path) -> None:
with (
patch(
"esphome.espidf.framework.run_command", return_value=(True, "not json", "")
),
pytest.raises(RuntimeError, match="Can't extract ESP-IDF tool paths"),
):
_get_idf_tool_paths(tmp_path)
def test_get_idf_tool_paths_raises_on_failure(tmp_path: Path) -> None:
with (
patch("esphome.espidf.framework.run_command", return_value=(False, "", "err")),
pytest.raises(RuntimeError, match="Can't get ESP-IDF tool paths"),
):
_get_idf_tool_paths(tmp_path)
def test_get_python_version_parses_stdout(tmp_path: Path) -> None:
with patch(
"esphome.espidf.framework.run_command", return_value=(True, "3.11.0\n", "")
):
assert _get_python_version(tmp_path / "python") == "3.11.0"
def test_get_python_version_returns_falsy_on_failure(tmp_path: Path) -> None:
with patch("esphome.espidf.framework.run_command", return_value=(False, "", "")):
# non-throwing failure returns the (empty) stdout as-is
assert not _get_python_version(tmp_path / "python")
def test_get_python_version_raises_when_requested(tmp_path: Path) -> None:
with (
patch("esphome.espidf.framework.run_command", return_value=(False, "", "")),
pytest.raises(RuntimeError, match="Can't get Python version"),
):
_get_python_version(tmp_path / "python", throw_exception=True)
def test_write_stamp_writes_json(tmp_path: Path) -> None:
stamp = tmp_path / "stamp.json"
_write_stamp(stamp, {"a": "1", "b": "2"})
assert json.loads(stamp.read_text(encoding="utf-8")) == {"a": "1", "b": "2"}
def test_get_framework_env_with_python_env(tmp_path: Path) -> None:
with (
patch(
"esphome.espidf.framework._get_idf_tools_path",
return_value=tmp_path / "tools",
),
patch("esphome.espidf.framework._get_idf_version", return_value="5.1.2"),
patch(
"esphome.espidf.framework._get_idf_tool_paths",
return_value=(["/tool/bin"], {"IDF_X": "1"}),
),
):
env = get_framework_env(
tmp_path / "fw", tmp_path / "penv", {"PATH": "/usr/bin"}
)
assert env["IDF_PATH"] == str(tmp_path / "fw")
assert env["ESP_IDF_VERSION"] == "5.1.2"
assert env["IDF_X"] == "1"
assert env["IDF_PYTHON_ENV_PATH"] == str(tmp_path / "penv")
assert "/tool/bin" in env["PATH"]
def test_get_framework_env_without_python_env_uses_os_path(tmp_path: Path) -> None:
with (
patch(
"esphome.espidf.framework._get_idf_tools_path",
return_value=tmp_path / "tools",
),
patch("esphome.espidf.framework._get_idf_version", return_value="5.1.2"),
patch("esphome.espidf.framework._get_idf_tool_paths", return_value=([], {})),
):
env = get_framework_env(tmp_path / "fw")
assert "IDF_PYTHON_ENV_PATH" not in env
assert env["PATH"] # taken from os.environ
# ---------------------------------------------------------------------------
# _check_stamp / _write_idf_version_txt / _get_idf_tools_path
# ---------------------------------------------------------------------------
def test_check_stamp_matches(tmp_path: Path) -> None:
f = tmp_path / "s.json"
f.write_text(json.dumps({"a": "1"}), encoding="utf-8")
assert _check_stamp(f, {"a": "1"}) is True
def test_check_stamp_mismatch(tmp_path: Path) -> None:
f = tmp_path / "s.json"
f.write_text(json.dumps({"a": "1"}), encoding="utf-8")
assert _check_stamp(f, {"a": "2"}) is False
def test_check_stamp_missing_file(tmp_path: Path) -> None:
assert _check_stamp(tmp_path / "nope.json", {"a": "1"}) is False
def test_check_stamp_corrupt_file(tmp_path: Path) -> None:
f = tmp_path / "s.json"
f.write_text("{ not json", encoding="utf-8")
assert _check_stamp(f, {"a": "1"}) is False
def test_write_idf_version_txt_writes_when_missing(tmp_path: Path) -> None:
_write_idf_version_txt(tmp_path, "5.1.2")
assert (tmp_path / "version.txt").read_text(encoding="utf-8") == "v5.1.2\n"
def test_write_idf_version_txt_skips_when_present(tmp_path: Path) -> None:
(tmp_path / "version.txt").write_text("existing\n", encoding="utf-8")
_write_idf_version_txt(tmp_path, "5.1.2")
assert (tmp_path / "version.txt").read_text(encoding="utf-8") == "existing\n"
def test_get_idf_tools_path_env_override(tmp_path: Path) -> None:
override = str(tmp_path / "custom-idf")
with patch.dict("os.environ", {"ESPHOME_ESP_IDF_PREFIX": override}):
assert _get_idf_tools_path() == Path(override)
def test_write_idf_version_txt_warns_on_write_error(tmp_path: Path) -> None:
with patch("pathlib.Path.write_text", side_effect=OSError("denied")):
# write failure is caught and warned, not raised
_write_idf_version_txt(tmp_path, "5.1.2")

View File

@@ -0,0 +1,954 @@
"""Tests for esphome.framework_helpers."""
# pylint: disable=protected-access
import importlib.util
import io
import logging
import os
from pathlib import Path
import subprocess
import sys
import tarfile
from unittest.mock import MagicMock, Mock, patch
import zipfile
import pytest
import requests as req
from esphome.framework_helpers import (
_7z_extract_all,
_detect_archive_root,
_rename_with_retry,
_tar_extract_all,
_zip_extract_all,
archive_extract_all,
create_venv,
download_from_mirrors,
get_python_env_executable_path,
get_system_python_path,
rmdir,
run_command,
run_command_ok,
str_to_lst_of_str,
)
_HAS_PY7ZR = importlib.util.find_spec("py7zr") is not None
# ---------------------------------------------------------------------------
# str_to_lst_of_str
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
("value", "expected"),
[
("a;b;c", ["a", "b", "c"]),
(" a ; b ", ["a", "b"]),
(";; a ;;", ["a"]),
("single", ["single"]),
("", []),
(["already", "a", "list"], ["already", "a", "list"]),
],
)
def test_str_to_lst_of_str(value: str | list, expected: list) -> None:
assert str_to_lst_of_str(value) == expected
# ---------------------------------------------------------------------------
# rmdir
# ---------------------------------------------------------------------------
def test_rmdir_nonexistent_is_noop(tmp_path: Path) -> None:
rmdir(tmp_path / "missing")
def test_rmdir_removes_existing_directory(tmp_path: Path) -> None:
d = tmp_path / "to_remove"
d.mkdir()
(d / "file.txt").write_text("x")
rmdir(d)
assert not d.exists()
def test_rmdir_logs_debug_with_msg(
tmp_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
d = tmp_path / "logged"
d.mkdir()
with caplog.at_level(logging.DEBUG, logger="esphome.framework_helpers"):
rmdir(d, msg="cleanup message")
assert "cleanup message" in caplog.text
def test_rmdir_raises_runtime_error_on_os_error(tmp_path: Path) -> None:
d = tmp_path / "stubborn"
d.mkdir()
with (
patch("esphome.framework_helpers.rmtree", side_effect=OSError("perm denied")),
pytest.raises(RuntimeError, match="can't remove"),
):
rmdir(d, msg="cleanup step")
# ---------------------------------------------------------------------------
# get_system_python_path
# ---------------------------------------------------------------------------
def test_get_system_python_path_returns_env_var() -> None:
with patch.dict(os.environ, {"PYTHONEXEPATH": "/custom/python"}):
assert get_system_python_path() == "/custom/python"
def test_get_system_python_path_falls_back_to_sys_executable() -> None:
env = {k: v for k, v in os.environ.items() if k != "PYTHONEXEPATH"}
with patch.dict(os.environ, env, clear=True):
assert get_system_python_path() == os.path.normpath(sys.executable)
# ---------------------------------------------------------------------------
# get_python_env_executable_path
# ---------------------------------------------------------------------------
@pytest.mark.skipif(os.name != "posix", reason="PosixPath construction requires POSIX")
def test_get_python_env_executable_path_posix() -> None:
assert get_python_env_executable_path("/env", "python") == Path("/env/bin/python")
@pytest.mark.skipif(os.name != "nt", reason="WindowsPath construction requires Windows")
def test_get_python_env_executable_path_windows() -> None:
assert get_python_env_executable_path("/env", "python") == Path(
"/env/Scripts/python.exe"
)
# ---------------------------------------------------------------------------
# run_command
# ---------------------------------------------------------------------------
def test_run_command_success_returns_stdout(mock_subprocess_run: Mock) -> None:
mock_subprocess_run.return_value = Mock(returncode=0, stdout="out\n", stderr="")
ok, stdout, _stderr = run_command(["echo", "hello"])
assert ok is True
assert stdout == "out\n"
def test_run_command_failure_returns_false(mock_subprocess_run: Mock) -> None:
mock_subprocess_run.return_value = Mock(returncode=1, stdout="", stderr="boom")
ok, _stdout, stderr = run_command(["bad"])
assert ok is False
assert stderr == "boom"
def test_run_command_stream_output_success(mock_subprocess_run: Mock) -> None:
mock_subprocess_run.return_value = Mock(returncode=0)
ok, stdout, stderr = run_command(["cmd"], stream_output=True)
assert ok is True
assert stdout is None
assert stderr is None
def test_run_command_stream_output_failure(mock_subprocess_run: Mock) -> None:
mock_subprocess_run.return_value = Mock(returncode=2)
ok, stdout, _stderr = run_command(["cmd"], stream_output=True)
assert ok is False
assert stdout is None
def test_run_command_subprocess_error_returns_false(mock_subprocess_run: Mock) -> None:
mock_subprocess_run.side_effect = subprocess.SubprocessError("exploded")
ok, stdout, stderr = run_command(["cmd"])
assert ok is False
assert stdout is None
assert stderr is None
def test_run_command_os_error_returns_false(mock_subprocess_run: Mock) -> None:
mock_subprocess_run.side_effect = OSError("not found")
ok, _stdout, _stderr = run_command(["cmd"])
assert ok is False
def test_run_command_passes_env(mock_subprocess_run: Mock) -> None:
mock_subprocess_run.return_value = Mock(returncode=0, stdout="", stderr="")
run_command(["cmd"], env={"MY_VAR": "42"})
assert mock_subprocess_run.call_args[1]["env"]["MY_VAR"] == "42"
def test_run_command_passes_cwd(mock_subprocess_run: Mock, tmp_path: Path) -> None:
mock_subprocess_run.return_value = Mock(returncode=0, stdout="", stderr="")
run_command(["cmd"], cwd=str(tmp_path))
assert mock_subprocess_run.call_args[1]["cwd"] == str(tmp_path)
# ---------------------------------------------------------------------------
# run_command_ok
# ---------------------------------------------------------------------------
def test_run_command_ok_true(mock_subprocess_run: Mock) -> None:
mock_subprocess_run.return_value = Mock(returncode=0, stdout="", stderr="")
assert run_command_ok(["cmd"]) is True
def test_run_command_ok_false(mock_subprocess_run: Mock) -> None:
mock_subprocess_run.return_value = Mock(returncode=1, stdout="", stderr="")
assert run_command_ok(["cmd"]) is False
# ---------------------------------------------------------------------------
# create_venv
# ---------------------------------------------------------------------------
def test_create_venv_calls_run_command_ok(tmp_path: Path) -> None:
with patch(
"esphome.framework_helpers.run_command_ok", return_value=True
) as mock_cmd:
create_venv(tmp_path / "env", msg="test")
mock_cmd.assert_called_once()
def test_create_venv_raises_on_failure(tmp_path: Path) -> None:
with (
patch("esphome.framework_helpers.run_command_ok", return_value=False),
pytest.raises(RuntimeError, match="Can't create Python virtual environment"),
):
create_venv(tmp_path / "env", msg="test")
# ---------------------------------------------------------------------------
# _detect_archive_root
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
("names", "expected"),
[
(["wrapper/", "wrapper/a.txt", "wrapper/sub/b.txt"], "wrapper"),
(["root1/a.txt", "root2/b.txt"], None),
(["wrapper"], None), # no descendant → None
(["", "wrapper/file.txt"], "wrapper"), # empty names skipped
(["wrapper\\file.txt"], "wrapper"), # backslash normalised
(["w/a", "w/b", "w/c"], "w"),
],
)
def test_detect_archive_root(names: list[str], expected: str | None) -> None:
assert _detect_archive_root(names) == expected
# ---------------------------------------------------------------------------
# Tar archive helpers
# ---------------------------------------------------------------------------
def _make_tar(
members: list[tarfile.TarInfo],
file_contents: dict[str, bytes] | None = None,
) -> io.BytesIO:
buf = io.BytesIO()
contents = file_contents or {}
with tarfile.open(fileobj=buf, mode="w") as tf:
for info in members:
if info.isreg() and info.name in contents:
data = contents[info.name]
info.size = len(data)
tf.addfile(info, io.BytesIO(data))
else:
tf.addfile(info)
buf.seek(0)
return buf
def _reg(name: str) -> tarfile.TarInfo:
info = tarfile.TarInfo(name=name)
info.type = tarfile.REGTYPE
info.size = 0
info.mode = 0o644
return info
def _dir(name: str) -> tarfile.TarInfo:
info = tarfile.TarInfo(name=name)
info.type = tarfile.DIRTYPE
info.mode = 0o755
return info
def _sym(name: str, target: str) -> tarfile.TarInfo:
info = tarfile.TarInfo(name=name)
info.type = tarfile.SYMTYPE
info.linkname = target
info.mode = 0o777
return info
def _special(name: str) -> tarfile.TarInfo:
info = tarfile.TarInfo(name=name)
info.type = tarfile.CHRTYPE
info.mode = 0o600
return info
def _hlnk(name: str, target: str) -> tarfile.TarInfo:
info = tarfile.TarInfo(name=name)
info.type = tarfile.LNKTYPE
info.linkname = target
info.mode = 0o644
return info
# ---------------------------------------------------------------------------
# _tar_extract_all — branches not covered by the hard-link prefix-strip tests
# ---------------------------------------------------------------------------
class TestTarExtractAllSecurity:
def test_flat_archive_no_wrapper(self, tmp_path: Path) -> None:
"""Without a single common root files land directly in extract_dir."""
buf = _make_tar(
[_reg("a.txt"), _reg("b.txt")],
{"a.txt": b"aaa", "b.txt": b"bbb"},
)
_tar_extract_all(buf, tmp_path)
assert (tmp_path / "a.txt").read_bytes() == b"aaa"
assert (tmp_path / "b.txt").read_bytes() == b"bbb"
def test_directory_member_extracted(self, tmp_path: Path) -> None:
buf = _make_tar([_dir("subdir/")])
_tar_extract_all(buf, tmp_path)
assert (tmp_path / "subdir").is_dir()
def test_symlink_within_dest_extracted(self, tmp_path: Path) -> None:
buf = _make_tar(
[_reg("target.txt"), _sym("link.txt", "target.txt")],
{"target.txt": b"data"},
)
_tar_extract_all(buf, tmp_path)
assert (tmp_path / "link.txt").exists()
def test_path_traversal_skipped(self, tmp_path: Path) -> None:
"""Member resolving outside extract_dir via .. is silently skipped."""
info = tarfile.TarInfo(name="sub/../../escape.txt")
info.type = tarfile.REGTYPE
info.size = 5
info.mode = 0o644
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tf:
tf.addfile(info, io.BytesIO(b"OOPS!"))
buf.seek(0)
_tar_extract_all(buf, tmp_path)
assert not (tmp_path.parent / "escape.txt").exists()
assert not list(tmp_path.rglob("escape.txt"))
def test_absolute_symlink_target_skipped(self, tmp_path: Path) -> None:
"""Symlink pointing to an absolute path is silently skipped."""
buf = _make_tar(
[_reg("real.txt"), _sym("danger.lnk", "/etc/passwd")],
{"real.txt": b"ok"},
)
_tar_extract_all(buf, tmp_path)
assert not (tmp_path / "danger.lnk").exists()
def test_symlink_escaping_dest_skipped(self, tmp_path: Path) -> None:
"""Symlink whose resolved path exits extract_dir is silently skipped."""
buf = _make_tar([_sym("up.lnk", "../outside.txt")])
_tar_extract_all(buf, tmp_path)
assert not (tmp_path / "up.lnk").exists()
def test_special_file_skipped(self, tmp_path: Path) -> None:
"""Character-device and other special-file members are silently skipped."""
buf = _make_tar([_special("chardev")])
_tar_extract_all(buf, tmp_path)
assert not (tmp_path / "chardev").exists()
@pytest.mark.skipif(
os.name == "nt", reason="Windows has no POSIX executable permission bit"
)
def test_executable_bit_preserved(self, tmp_path: Path) -> None:
"""User-executable bit is kept for explicitly executable files."""
info = _reg("script.sh")
info.mode = 0o755
buf = _make_tar([info], {"script.sh": b"#!/bin/sh"})
_tar_extract_all(buf, tmp_path)
assert (tmp_path / "script.sh").stat().st_mode & 0o100 # S_IXUSR
def test_non_executable_exec_bits_stripped(self, tmp_path: Path) -> None:
"""Exec bits are removed when S_IXUSR is not set."""
info = _reg("data.bin")
info.mode = 0o654 # group/other exec present, user exec absent
buf = _make_tar([info], {"data.bin": b"\x00"})
_tar_extract_all(buf, tmp_path)
mode = (tmp_path / "data.bin").stat().st_mode
assert not (mode & 0o111) # all exec bits cleared
# ---------------------------------------------------------------------------
# ZIP archive helper
# ---------------------------------------------------------------------------
def _make_zip(entries: list[tuple[str, str | bytes]]) -> io.BytesIO:
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
for name, content in entries:
zf.writestr(name, content)
buf.seek(0)
return buf
# ---------------------------------------------------------------------------
# _zip_extract_all
# ---------------------------------------------------------------------------
class TestZipExtractAll:
def test_basic_extraction_strips_wrapper(self, tmp_path: Path) -> None:
buf = _make_zip([("wrapper/file.txt", "hello")])
_zip_extract_all(buf, tmp_path)
assert (tmp_path / "file.txt").read_text() == "hello"
def test_flat_archive_no_wrapper(self, tmp_path: Path) -> None:
buf = _make_zip([("a.txt", "aaa"), ("b.txt", "bbb")])
_zip_extract_all(buf, tmp_path)
assert (tmp_path / "a.txt").read_text() == "aaa"
assert (tmp_path / "b.txt").read_text() == "bbb"
def test_wrapper_root_entry_skipped(self, tmp_path: Path) -> None:
"""The wrapper directory entry itself (step 3a) does not appear in dest."""
buf = _make_zip([("wrapper/", ""), ("wrapper/file.txt", "content")])
_zip_extract_all(buf, tmp_path)
assert (tmp_path / "file.txt").read_text() == "content"
assert not (tmp_path / "wrapper").exists()
def test_path_traversal_raises(self, tmp_path: Path) -> None:
# Two members with different roots so _detect_archive_root returns None
# and strip_prefix is not applied, leaving "../escape.txt" to hit the
# commonpath safety check directly.
buf = _make_zip([("safe.txt", "ok"), ("../escape.txt", "bad")])
with pytest.raises(ValueError, match="Unsafe path"):
_zip_extract_all(buf, tmp_path)
def test_multiple_files_extracted(self, tmp_path: Path) -> None:
entries = [(f"root/{c}.txt", c * 3) for c in "abc"]
buf = _make_zip(entries)
_zip_extract_all(buf, tmp_path)
for c in "abc":
assert (tmp_path / f"{c}.txt").read_text() == c * 3
# ---------------------------------------------------------------------------
# archive_extract_all dispatch
# ---------------------------------------------------------------------------
def _gzip_tar_bytes(entries: dict[str, bytes]) -> bytes:
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
for name, content in entries.items():
info = tarfile.TarInfo(name=name)
info.size = len(content)
info.mode = 0o644
tf.addfile(info, io.BytesIO(content))
return buf.getvalue()
class TestArchiveExtractAll:
def test_path_input_gzip_tar(self, tmp_path: Path) -> None:
archive = tmp_path / "test.tar.gz"
archive.write_bytes(_gzip_tar_bytes({"file.txt": b"hello"}))
dest = tmp_path / "out"
dest.mkdir()
archive_extract_all(archive, dest)
assert (dest / "file.txt").read_bytes() == b"hello"
def test_buffered_reader_input(self, tmp_path: Path) -> None:
archive = tmp_path / "test.tar.gz"
archive.write_bytes(_gzip_tar_bytes({"file.txt": b"data"}))
dest = tmp_path / "out"
dest.mkdir()
with archive.open("rb") as f: # io.BufferedReader
archive_extract_all(f, dest)
assert (dest / "file.txt").read_bytes() == b"data"
def test_rawio_input(self, tmp_path: Path) -> None:
archive = tmp_path / "test.tar.gz"
archive.write_bytes(_gzip_tar_bytes({"file.txt": b"raw"}))
dest = tmp_path / "out"
dest.mkdir()
archive_extract_all(io.FileIO(archive), dest)
assert (dest / "file.txt").read_bytes() == b"raw"
def test_zip_dispatched(self, tmp_path: Path) -> None:
archive = tmp_path / "test.zip"
archive.write_bytes(_make_zip([("file.txt", "hi")]).getvalue())
dest = tmp_path / "out"
dest.mkdir()
archive_extract_all(archive, dest)
assert (dest / "file.txt").read_text() == "hi"
def test_invalid_type_raises_type_error(self) -> None:
with pytest.raises(TypeError, match="archive must be"):
archive_extract_all(42, ".") # type: ignore[arg-type]
def test_unsupported_format_raises_value_error(self, tmp_path: Path) -> None:
bad = tmp_path / "bad.bin"
bad.write_bytes(b"\x00\x01\x02\x03\x04\x05\x06")
with pytest.raises(ValueError, match="Unsupported archive format"):
archive_extract_all(bad, tmp_path)
# ---------------------------------------------------------------------------
# download_from_mirrors
# ---------------------------------------------------------------------------
def _mock_response(content: bytes, ok: bool = True) -> MagicMock:
r = MagicMock()
r.__enter__.return_value = r
r.__exit__.return_value = False
if ok:
r.raise_for_status.return_value = None
else:
r.raise_for_status.side_effect = req.HTTPError("503")
r.headers = {"content-length": "0"} # suppress ProgressBar
r.iter_content.return_value = [content] if content else []
return r
class TestDownloadFromMirrors:
def test_success_returns_url_and_writes_content(self, tmp_path: Path) -> None:
target = tmp_path / "out.bin"
with patch(
"esphome.framework_helpers.requests.get",
return_value=_mock_response(b"filedata"),
):
url = download_from_mirrors(["https://example.com/f"], {}, target)
assert url == "https://example.com/f"
assert target.read_bytes() == b"filedata"
def test_substitutions_applied_to_url(self, tmp_path: Path) -> None:
with patch(
"esphome.framework_helpers.requests.get",
return_value=_mock_response(b"x"),
) as mock_get:
download_from_mirrors(
["https://example.com/{VERSION}.bin"],
{"VERSION": "1.2.3"},
tmp_path / "out.bin",
)
assert mock_get.call_args[0][0] == "https://example.com/1.2.3.bin"
def test_falls_back_to_second_mirror(self, tmp_path: Path) -> None:
with patch(
"esphome.framework_helpers.requests.get",
side_effect=[_mock_response(b"", ok=False), _mock_response(b"second")],
):
url = download_from_mirrors(
["https://mirror1.com/f", "https://mirror2.com/f"],
{},
tmp_path / "out.bin",
)
assert url == "https://mirror2.com/f"
assert (tmp_path / "out.bin").read_bytes() == b"second"
def test_all_mirrors_fail_reraises_last_exception(self, tmp_path: Path) -> None:
with (
patch(
"esphome.framework_helpers.requests.get",
return_value=_mock_response(b"", ok=False),
),
pytest.raises(req.HTTPError),
):
download_from_mirrors(["https://example.com/f"], {}, tmp_path / "out.bin")
def test_empty_mirrors_raises_value_error(self, tmp_path: Path) -> None:
with pytest.raises(ValueError, match="empty mirrors list"):
download_from_mirrors([], {}, tmp_path / "out.bin")
def test_invalid_target_type_raises_type_error(self) -> None:
with pytest.raises(TypeError, match="target must be"):
download_from_mirrors(["https://example.com/f"], {}, 42) # type: ignore[arg-type]
def test_file_like_target_written(self) -> None:
buf = io.BytesIO()
with patch(
"esphome.framework_helpers.requests.get",
return_value=_mock_response(b"bytes"),
):
download_from_mirrors(["https://example.com/f"], {}, buf)
buf.seek(0)
assert buf.read() == b"bytes"
def test_progress_bar_shown_when_content_length_known(self, tmp_path: Path) -> None:
r = _mock_response(b"1234567890")
r.headers = {"content-length": "10"}
with (
patch("esphome.framework_helpers.requests.get", return_value=r),
patch("esphome.framework_helpers.ProgressBar") as mock_pb,
):
download_from_mirrors(["https://example.com/f"], {}, tmp_path / "out.bin")
mock_pb.assert_called_once_with("Downloading")
mock_pb.return_value.update.assert_called()
def test_empty_chunk_not_written(self, tmp_path: Path) -> None:
"""Empty chunks yielded by iter_content are skipped without writing."""
r = MagicMock()
r.__enter__.return_value = r
r.__exit__.return_value = False
r.raise_for_status.return_value = None
r.headers = {"content-length": "0"}
r.iter_content.return_value = [b""] # one empty chunk
target = tmp_path / "out.bin"
with patch("esphome.framework_helpers.requests.get", return_value=r):
download_from_mirrors(["https://example.com/f"], {}, target)
assert target.exists()
assert target.read_bytes() == b""
# ---------------------------------------------------------------------------
# get_python_env_executable_path — Windows branch
# ---------------------------------------------------------------------------
def test_get_python_env_executable_path_nt() -> None:
"""Windows path uses Scripts/ and .exe suffix."""
from pathlib import PurePosixPath
with (
patch.object(os, "name", "nt"),
patch("esphome.framework_helpers.Path", PurePosixPath),
):
result = get_python_env_executable_path("/env", "python")
assert str(result) == "/env/Scripts/python.exe"
# ---------------------------------------------------------------------------
# _tar_extract_all — additional branch coverage
# ---------------------------------------------------------------------------
class TestTarExtractAllBranches:
@pytest.mark.skipif(
sys.version_info < (3, 12),
reason="patching os.name makes pathlib build a WindowsPath, which only "
"instantiates on POSIX in 3.12+",
)
def test_windows_drive_path_skipped(self, tmp_path: Path) -> None:
"""Windows-style drive path (C:/...) is skipped when os.name == 'nt'."""
info = tarfile.TarInfo(name="C:/secret.txt")
info.type = tarfile.REGTYPE
info.size = 0
info.mode = 0o644
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tf:
tf.addfile(info)
buf.seek(0)
with patch.object(os, "name", "nt"):
_tar_extract_all(buf, tmp_path)
assert not list(tmp_path.rglob("*"))
def test_strip_root_exact_match_skipped(self, tmp_path: Path) -> None:
"""Member whose name equals strip_root exactly (no trailing slash) is skipped."""
# "wrapper" (file entry) + "wrapper/file.txt" causes _detect_archive_root
# to return "wrapper"; the bare "wrapper" entry matches strip_root exactly.
buf = _make_tar(
[_reg("wrapper"), _reg("wrapper/file.txt")],
{"wrapper/file.txt": b"content"},
)
_tar_extract_all(buf, tmp_path)
assert not (tmp_path / "wrapper").exists()
assert (tmp_path / "file.txt").read_bytes() == b"content"
def test_member_not_under_strip_prefix_skipped(self, tmp_path: Path) -> None:
"""Member whose name doesn't start with strip_prefix is silently skipped."""
buf = _make_tar([_reg("other/file.txt")], {"other/file.txt": b"data"})
with patch("esphome.framework_helpers._detect_archive_root", return_value="w"):
_tar_extract_all(buf, tmp_path)
assert not list(tmp_path.rglob("*"))
def test_hardlink_prefix_stripped(self, tmp_path: Path) -> None:
"""Hard-link linkname has wrapper prefix stripped along with its entry name."""
buf = _make_tar(
[_reg("wrapper/file.txt"), _hlnk("wrapper/link.txt", "wrapper/file.txt")],
{"wrapper/file.txt": b"data"},
)
_tar_extract_all(buf, tmp_path)
assert (tmp_path / "file.txt").read_bytes() == b"data"
assert (tmp_path / "link.txt").exists()
def test_hardlink_linkname_equals_strip_root_skipped(self, tmp_path: Path) -> None:
"""Hard link whose linkname equals strip_root is silently skipped."""
buf = _make_tar(
[_reg("wrapper/file.txt"), _hlnk("wrapper/link.txt", "wrapper")],
{"wrapper/file.txt": b"data"},
)
_tar_extract_all(buf, tmp_path)
assert not (tmp_path / "link.txt").exists()
def test_hardlink_linkname_outside_prefix_skipped(self, tmp_path: Path) -> None:
"""Hard link whose linkname doesn't start with strip_prefix is skipped."""
buf = _make_tar(
[_reg("wrapper/file.txt"), _hlnk("wrapper/link.txt", "other/file.txt")],
{"wrapper/file.txt": b"data"},
)
_tar_extract_all(buf, tmp_path)
assert not (tmp_path / "link.txt").exists()
def test_member_mode_none_skips_sanitization(self, tmp_path: Path) -> None:
"""Member with mode=None bypasses the sanitization block without error."""
info = _reg("file.txt")
buf = _make_tar([info], {"file.txt": b"data"})
buf.seek(0)
with tarfile.open(fileobj=buf) as tf:
members = tf.getmembers()
for m in members:
m.mode = None
buf.seek(0)
with (
patch("tarfile.TarFile.getmembers", return_value=members),
patch("tarfile.TarFile.extract"),
):
_tar_extract_all(buf, tmp_path)
def test_progress_bar_shown(self, tmp_path: Path) -> None:
"""A non-empty progress_header causes ProgressBar to be created and updated."""
buf = _make_tar([_reg("file.txt")], {"file.txt": b"x"})
with patch("esphome.framework_helpers.ProgressBar") as mock_pb:
_tar_extract_all(buf, tmp_path, progress_header="Extracting")
mock_pb.assert_called_once_with("Extracting")
mock_pb.return_value.update.assert_called()
# ---------------------------------------------------------------------------
# _zip_extract_all — additional branch coverage
# ---------------------------------------------------------------------------
class TestZipExtractAllBranches:
@pytest.mark.skipif(
sys.version_info < (3, 12),
reason="patching os.name makes pathlib build a WindowsPath, which only "
"instantiates on POSIX in 3.12+",
)
def test_windows_drive_path_skipped(self, tmp_path: Path) -> None:
"""Windows-style drive path (C:/...) is skipped when os.name == 'nt'."""
buf = _make_zip([("C:/secret.txt", "bad")])
with patch.object(os, "name", "nt"):
_zip_extract_all(buf, tmp_path)
assert not list(tmp_path.rglob("*"))
def test_member_not_under_strip_prefix_skipped(self, tmp_path: Path) -> None:
"""Member whose name doesn't start with strip_prefix is silently skipped."""
buf = _make_zip([("other/file.txt", "data")])
with patch("esphome.framework_helpers._detect_archive_root", return_value="w"):
_zip_extract_all(buf, tmp_path)
assert not list(tmp_path.rglob("*"))
def test_progress_bar_shown(self, tmp_path: Path) -> None:
"""A non-empty progress_header causes ProgressBar to be created and updated."""
buf = _make_zip([("file.txt", "hello")])
with patch("esphome.framework_helpers.ProgressBar") as mock_pb:
_zip_extract_all(buf, tmp_path, progress_header="Unzipping")
mock_pb.assert_called_once_with("Unzipping")
mock_pb.return_value.update.assert_called()
# ---------------------------------------------------------------------------
# _rename_with_retry
# ---------------------------------------------------------------------------
class TestRenameWithRetry:
def test_success_on_first_attempt(self, tmp_path: Path) -> None:
src = tmp_path / "src.txt"
src.write_text("data")
dst = tmp_path / "dst.txt"
_rename_with_retry(src, dst)
assert dst.read_text() == "data"
assert not src.exists()
def test_retries_on_permission_error_then_succeeds(self, tmp_path: Path) -> None:
src = tmp_path / "src.txt"
src.write_text("data")
dst = tmp_path / "dst.txt"
call_count = 0
original_rename = Path.rename
def flaky_rename(self, target):
nonlocal call_count
call_count += 1
if call_count == 1:
raise PermissionError("locked")
return original_rename(self, target)
with (
patch.object(Path, "rename", flaky_rename),
patch("esphome.framework_helpers.time.sleep"),
):
_rename_with_retry(src, dst, attempts=3)
assert dst.read_text() == "data"
def test_raises_after_all_attempts_fail(self, tmp_path: Path) -> None:
src = tmp_path / "src.txt"
src.write_text("data")
dst = tmp_path / "dst.txt"
with (
patch.object(Path, "rename", side_effect=PermissionError("locked")),
patch("esphome.framework_helpers.time.sleep"),
pytest.raises(PermissionError),
):
_rename_with_retry(src, dst, attempts=3)
def test_attempts_zero_is_noop(self, tmp_path: Path) -> None:
"""Zero attempts means the for-loop body never runs; src is untouched."""
src = tmp_path / "src.txt"
src.write_text("data")
dst = tmp_path / "dst.txt"
_rename_with_retry(src, dst, attempts=0)
assert src.exists()
assert not dst.exists()
# ---------------------------------------------------------------------------
# _7z_extract_all
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not _HAS_PY7ZR, reason="py7zr not installed")
class TestSevenZipExtractAll:
@staticmethod
def _make_7z(entries: dict[str, bytes]) -> io.BytesIO:
import py7zr
buf = io.BytesIO()
with py7zr.SevenZipFile(buf, "w") as sz:
for name, content in entries.items():
sz.writef(io.BytesIO(content), name)
buf.seek(0)
return buf
def test_basic_extraction_no_wrapper(self, tmp_path: Path) -> None:
buf = self._make_7z({"a.txt": b"aaa", "b.txt": b"bbb"})
out = tmp_path / "out"
out.mkdir()
_7z_extract_all(buf, out)
assert (out / "a.txt").exists()
assert (out / "b.txt").exists()
def test_strips_wrapper_directory(self, tmp_path: Path) -> None:
buf = self._make_7z({"wrapper/file.txt": b"data"})
out = tmp_path / "out"
out.mkdir()
_7z_extract_all(buf, out)
assert (out / "file.txt").exists()
assert not (out / "wrapper").exists()
def test_staging_suffix_collision(self, tmp_path: Path) -> None:
"""When .extract_tmp_0 already exists, suffix is incremented to find a free slot."""
out = tmp_path / "out"
out.mkdir()
(out / ".extract_tmp_0").mkdir()
buf = self._make_7z({"file.txt": b"hi"})
_7z_extract_all(buf, out)
assert (out / "file.txt").exists()
# .extract_tmp_1 should be cleaned up after extraction
assert not (out / ".extract_tmp_1").exists()
def test_overwrites_existing_directory(self, tmp_path: Path) -> None:
"""Pre-existing destination directory is replaced."""
out = tmp_path / "out"
out.mkdir()
existing_dir = out / "file.txt"
existing_dir.mkdir()
buf = self._make_7z({"file.txt": b"new"})
_7z_extract_all(buf, out)
assert (out / "file.txt").is_file()
def test_overwrites_existing_file(self, tmp_path: Path) -> None:
"""Pre-existing destination file is replaced."""
out = tmp_path / "out"
out.mkdir()
(out / "file.txt").write_bytes(b"old")
buf = self._make_7z({"file.txt": b"new"})
_7z_extract_all(buf, out)
assert (out / "file.txt").exists()
def test_empty_name_skipped(self, tmp_path: Path) -> None:
"""Archive entries with empty names are silently skipped."""
import py7zr
buf = self._make_7z({"file.txt": b"data"})
out = tmp_path / "out"
out.mkdir()
with patch.object(
py7zr.SevenZipFile, "getnames", return_value=["", "file.txt"]
):
_7z_extract_all(buf, out)
assert (out / "file.txt").exists()
def test_path_traversal_skipped(self, tmp_path: Path) -> None:
"""Entries whose resolved path exits extract_dir are skipped."""
import py7zr
buf = self._make_7z({"file.txt": b"safe"})
out = tmp_path / "out"
out.mkdir()
with patch.object(
py7zr.SevenZipFile, "getnames", return_value=["../escape.txt", "file.txt"]
):
_7z_extract_all(buf, out)
assert not (tmp_path / "escape.txt").exists()
assert (out / "file.txt").exists()
def test_progress_bar_shown(self, tmp_path: Path) -> None:
buf = self._make_7z({"file.txt": b"x"})
out = tmp_path / "out"
out.mkdir()
with patch("esphome.framework_helpers.ProgressBar") as mock_pb:
_7z_extract_all(buf, out, progress_header="Unpacking 7z")
mock_pb.assert_called_once_with("Unpacking 7z")
mock_pb.return_value.update.assert_called()
def test_absolute_path_in_names_skipped(self, tmp_path: Path) -> None:
"""Names that resolve as absolute are silently skipped."""
import py7zr
buf = self._make_7z({"file.txt": b"safe"})
out = tmp_path / "out"
out.mkdir()
original_is_absolute = Path.is_absolute
def patched_is_absolute(self: Path) -> bool:
if str(self).startswith("C:"):
return True
return original_is_absolute(self)
with (
patch.object(
py7zr.SevenZipFile, "getnames", return_value=["C:/evil.txt", "file.txt"]
),
patch.object(Path, "is_absolute", patched_is_absolute),
):
_7z_extract_all(buf, out)
# Avoid `out / "C:"` here: pathlib treats "C:" as a drive (always
# "exists" on Windows). Assert on the actual extracted files instead.
extracted = sorted(p.name for p in out.rglob("*") if p.is_file())
assert extracted == ["file.txt"]
def test_dispatched_via_archive_extract_all(self, tmp_path: Path) -> None:
"""archive_extract_all dispatches 7z archives to _7z_extract_all."""
buf = self._make_7z({"hello.txt": b"world"})
data = buf.read()
assert data[:6] == b"\x37\x7a\xbc\xaf\x27\x1c"
archive = tmp_path / "test.7z"
archive.write_bytes(data)
out = tmp_path / "out"
out.mkdir()
archive_extract_all(archive, out)
assert (out / "hello.txt").exists()

View File

@@ -0,0 +1,219 @@
"""Tests for esphome.components.nrf52.framework helpers."""
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from esphome.components.nrf52.framework import (
_TOOLCHAIN_VERSION,
_get_toolchain_platform_info,
check_and_install,
)
from esphome.config_validation import Version
from esphome.const import KEY_CORE, KEY_FRAMEWORK_VERSION
from esphome.core import CORE, EsphomeError
@pytest.mark.parametrize(
("system", "machine", "expected"),
[
# default — no branch hit
("Linux", "x86_64", ("linux", "x86_64", "tar.xz")),
# arm64 → aarch64 rename
("Linux", "arm64", ("linux", "aarch64", "tar.xz")),
# darwin → macos rename only
("Darwin", "x86_64", ("macos", "x86_64", "tar.xz")),
# both renames apply
("Darwin", "arm64", ("macos", "aarch64", "tar.xz")),
# windows forces x86_64 + 7z; arm64 rename is overwritten
("Windows", "arm64", ("windows", "x86_64", "7z")),
],
)
def test_get_toolchain_platform_info(
system: str, machine: str, expected: tuple[str, str, str]
) -> None:
with (
patch("platform.system", return_value=system),
patch("platform.machine", return_value=machine),
):
assert _get_toolchain_platform_info() == expected
# ---------------------------------------------------------------------------
# Helpers and fixtures for check_and_install tests
# ---------------------------------------------------------------------------
_TEST_SDK_VERSION = "2.9.0"
@pytest.fixture
def nrf52_dirs(setup_core: Path) -> SimpleNamespace:
"""Populate CORE and pre-create SDK directories so sentinel.touch() succeeds."""
CORE.data[KEY_CORE] = {KEY_FRAMEWORK_VERSION: Version.parse(_TEST_SDK_VERSION)}
tools = CORE.data_dir / "sdk-nrf"
python_env = tools / "penvs" / f"v{_TEST_SDK_VERSION}"
framework = tools / "frameworks" / f"v{_TEST_SDK_VERSION}"
toolchain_dir = tools / "toolchains" / _TOOLCHAIN_VERSION
for d in (python_env, framework, toolchain_dir):
d.mkdir(parents=True, exist_ok=True)
return SimpleNamespace(
python_env=python_env,
framework=framework,
toolchain=toolchain_dir,
)
@pytest.fixture
def mock_nrf52_ops():
"""Patch all heavy I/O operations used by check_and_install."""
with (
patch("esphome.components.nrf52.framework.rmdir") as mock_rmdir,
patch("esphome.components.nrf52.framework.create_venv") as mock_create_venv,
patch(
"esphome.components.nrf52.framework.run_command_ok", return_value=True
) as mock_run_cmd,
patch(
"esphome.components.nrf52.framework.download_from_mirrors",
return_value="https://example.com/tc.tar.xz",
) as mock_download,
patch("esphome.components.nrf52.framework.archive_extract_all") as mock_extract,
):
yield SimpleNamespace(
rmdir=mock_rmdir,
create_venv=mock_create_venv,
run_command_ok=mock_run_cmd,
download_from_mirrors=mock_download,
archive_extract_all=mock_extract,
)
# ---------------------------------------------------------------------------
# check_and_install tests
# ---------------------------------------------------------------------------
class TestCheckAndInstall:
def test_all_installed_skips_all_steps(
self,
nrf52_dirs: SimpleNamespace,
mock_nrf52_ops: SimpleNamespace,
) -> None:
"""All three sentinels present → nothing downloaded or compiled."""
(nrf52_dirs.python_env / ".ready").touch()
(nrf52_dirs.framework / ".ready").touch()
(nrf52_dirs.toolchain / ".ready").touch()
check_and_install()
mock_nrf52_ops.create_venv.assert_not_called()
mock_nrf52_ops.run_command_ok.assert_not_called()
mock_nrf52_ops.download_from_mirrors.assert_not_called()
mock_nrf52_ops.archive_extract_all.assert_not_called()
def test_fresh_install_runs_all_steps(
self,
nrf52_dirs: SimpleNamespace,
mock_nrf52_ops: SimpleNamespace,
) -> None:
"""No sentinels → venv created, west installed, SDK init+update, toolchain downloaded."""
check_and_install()
mock_nrf52_ops.create_venv.assert_called_once()
# pip install west, west init, west update
assert mock_nrf52_ops.run_command_ok.call_count == 3
mock_nrf52_ops.download_from_mirrors.assert_called_once()
mock_nrf52_ops.archive_extract_all.assert_called_once()
assert (nrf52_dirs.python_env / ".ready").exists()
assert (nrf52_dirs.framework / ".ready").exists()
assert (nrf52_dirs.toolchain / ".ready").exists()
def test_venv_exists_installs_framework_and_toolchain(
self,
nrf52_dirs: SimpleNamespace,
mock_nrf52_ops: SimpleNamespace,
) -> None:
"""Venv ready but framework missing → skip venv creation, run SDK init+update."""
(nrf52_dirs.python_env / ".ready").touch()
check_and_install()
mock_nrf52_ops.create_venv.assert_not_called()
# west init + west update only (no pip install)
assert mock_nrf52_ops.run_command_ok.call_count == 2
mock_nrf52_ops.download_from_mirrors.assert_called_once()
def test_toolchain_only_missing(
self,
nrf52_dirs: SimpleNamespace,
mock_nrf52_ops: SimpleNamespace,
) -> None:
"""Venv and framework ready → only toolchain downloaded and extracted."""
(nrf52_dirs.python_env / ".ready").touch()
(nrf52_dirs.framework / ".ready").touch()
check_and_install()
mock_nrf52_ops.create_venv.assert_not_called()
mock_nrf52_ops.run_command_ok.assert_not_called()
mock_nrf52_ops.download_from_mirrors.assert_called_once()
mock_nrf52_ops.archive_extract_all.assert_called_once()
def test_west_install_failure_raises(
self,
nrf52_dirs: SimpleNamespace,
mock_nrf52_ops: SimpleNamespace,
) -> None:
"""Failing pip install west raises EsphomeError."""
mock_nrf52_ops.run_command_ok.return_value = False
with pytest.raises(EsphomeError, match="Install west"):
check_and_install()
def test_framework_init_failure_raises(
self,
nrf52_dirs: SimpleNamespace,
mock_nrf52_ops: SimpleNamespace,
) -> None:
"""Failing west init raises EsphomeError."""
(nrf52_dirs.python_env / ".ready").touch()
mock_nrf52_ops.run_command_ok.return_value = False
with pytest.raises(EsphomeError, match="Can't initialize"):
check_and_install()
def test_framework_update_failure_raises(
self,
nrf52_dirs: SimpleNamespace,
mock_nrf52_ops: SimpleNamespace,
) -> None:
"""Failing west update raises EsphomeError."""
(nrf52_dirs.python_env / ".ready").touch()
# init succeeds, update fails
mock_nrf52_ops.run_command_ok.side_effect = [True, False]
with pytest.raises(EsphomeError, match="Can't update"):
check_and_install()
def test_toolchain_download_passes_platform_substitutions(
self,
nrf52_dirs: SimpleNamespace,
mock_nrf52_ops: SimpleNamespace,
) -> None:
"""download_from_mirrors receives VERSION + platform triple from _get_toolchain_platform_info."""
(nrf52_dirs.python_env / ".ready").touch()
(nrf52_dirs.framework / ".ready").touch()
with patch(
"esphome.components.nrf52.framework._get_toolchain_platform_info",
return_value=("linux", "x86_64", "tar.xz"),
):
check_and_install()
args, _ = mock_nrf52_ops.download_from_mirrors.call_args
substitutions = args[1]
assert substitutions["VERSION"] == _TOOLCHAIN_VERSION
assert substitutions["sysname"] == "linux"
assert substitutions["machine"] == "x86_64"
assert substitutions["extension"] == "tar.xz"

View File

@@ -442,6 +442,21 @@ def test_run_compile(setup_core: Path, mock_run_platformio_cli_run: Mock) -> Non
mock_run_platformio_cli_run.assert_called_once_with(config, True, "-j4") mock_run_platformio_cli_run.assert_called_once_with(config, True, "-j4")
def test_run_compile_without_process_limit(
setup_core: Path, mock_run_platformio_cli_run: Mock
) -> None:
"""When no compile_process_limit is set, run_compile passes no -j flag."""
from esphome.const import CONF_ESPHOME
CORE.build_path = str(setup_core / "build" / "test")
config = {CONF_ESPHOME: {}}
mock_run_platformio_cli_run.return_value = 0
toolchain.run_compile(config, verbose=False)
mock_run_platformio_cli_run.assert_called_once_with(config, False)
def test_get_idedata_caches_result( def test_get_idedata_caches_result(
setup_core: Path, mock_run_platformio_cli_run: Mock setup_core: Path, mock_run_platformio_cli_run: Mock
) -> None: ) -> None: