mirror of
https://github.com/esphome/esphome.git
synced 2026-06-24 14:19:03 +00:00
[cli] Add --ota-platform flag to pick web_server or native API OTA (#16207)
This commit is contained in:
@@ -28,6 +28,7 @@ from esphome.const import (
|
||||
ALLOWED_NAME_CHARS,
|
||||
ARGUMENT_HELP_DEVICE,
|
||||
CONF_API,
|
||||
CONF_AUTH,
|
||||
CONF_BAUD_RATE,
|
||||
CONF_BROKER,
|
||||
CONF_DEASSERT_RTS_DTR,
|
||||
@@ -47,6 +48,8 @@ from esphome.const import (
|
||||
CONF_PORT,
|
||||
CONF_SUBSTITUTIONS,
|
||||
CONF_TOPIC,
|
||||
CONF_USERNAME,
|
||||
CONF_WEB_SERVER,
|
||||
ENV_NOGITIGNORE,
|
||||
KEY_CORE,
|
||||
KEY_NATIVE_IDF,
|
||||
@@ -349,6 +352,17 @@ def choose_upload_log_host(
|
||||
elif bootsel.permission_error:
|
||||
bootsel_permission_error = True
|
||||
|
||||
# Annotate the OTA chooser entry only in the non-default case: when the
|
||||
# config has web_server OTA but no native API OTA, the upload will fall
|
||||
# through to the HTTP path and the user benefits from seeing that
|
||||
# explicitly. The native-API path is the default and gets a plain label
|
||||
# to avoid noise on the most common scenario. For LOGGING the OTA
|
||||
# transport doesn't apply, so always leave the label plain.
|
||||
if purpose == Purpose.UPLOADING and not has_native_ota() and has_web_server_ota():
|
||||
ota_suffix = " via web_server"
|
||||
else:
|
||||
ota_suffix = ""
|
||||
|
||||
def add_ota_options() -> None:
|
||||
"""Add OTA options, using mDNS discovery if name_add_mac_suffix is enabled."""
|
||||
if (discovered := _discover_mac_suffix_devices()) is not None:
|
||||
@@ -356,11 +370,11 @@ def choose_upload_log_host(
|
||||
# intentionally skip the base-name fallback since with
|
||||
# name_add_mac_suffix on, the base name doesn't exist on the net.
|
||||
for host in discovered:
|
||||
options.append((f"Over The Air ({host})", host))
|
||||
options.append((f"Over The Air{ota_suffix} ({host})", host))
|
||||
elif has_resolvable_address():
|
||||
options.append((f"Over The Air ({CORE.address})", CORE.address))
|
||||
options.append((f"Over The Air{ota_suffix} ({CORE.address})", CORE.address))
|
||||
if has_mqtt_ip_lookup():
|
||||
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
|
||||
options.append((f"Over The Air{ota_suffix} (MQTT IP lookup)", "MQTTIP"))
|
||||
|
||||
if purpose == Purpose.LOGGING:
|
||||
if has_mqtt_logging():
|
||||
@@ -429,7 +443,19 @@ def has_api() -> bool:
|
||||
|
||||
|
||||
def has_ota() -> bool:
|
||||
"""Check if OTA upload is available (requires platform: esphome)."""
|
||||
"""Check if any network OTA upload is available.
|
||||
|
||||
True if the config exposes either ``platform: esphome`` (native API
|
||||
OTA) or ``platform: web_server`` (HTTP OTA). Both reach the device
|
||||
over the same network stack, so the OTA discovery path treats them
|
||||
interchangeably; ``upload_program`` picks the actual transport based
|
||||
on ``--ota-platform`` and what's configured.
|
||||
"""
|
||||
return has_native_ota() or has_web_server_ota()
|
||||
|
||||
|
||||
def has_native_ota() -> bool:
|
||||
"""Check if native API OTA upload is available (``platform: esphome``)."""
|
||||
if CONF_OTA not in CORE.config:
|
||||
return False
|
||||
return any(
|
||||
@@ -438,6 +464,16 @@ def has_ota() -> bool:
|
||||
)
|
||||
|
||||
|
||||
def has_web_server_ota() -> bool:
|
||||
"""Check if web_server OTA upload is available (``platform: web_server``)."""
|
||||
if CONF_OTA not in CORE.config:
|
||||
return False
|
||||
return any(
|
||||
ota_item.get(CONF_PLATFORM) == CONF_WEB_SERVER
|
||||
for ota_item in CORE.config[CONF_OTA]
|
||||
)
|
||||
|
||||
|
||||
def has_mqtt_ip_lookup() -> bool:
|
||||
"""Check if MQTT is available and IP lookup is supported."""
|
||||
from esphome.components.mqtt import CONF_DISCOVER_IP
|
||||
@@ -1115,25 +1151,83 @@ def upload_program(
|
||||
|
||||
return exit_code, host if exit_code == 0 else None
|
||||
|
||||
ota_conf = {}
|
||||
requested_platform = getattr(args, "ota_platform", None)
|
||||
chosen_platform = _choose_ota_platform(config, requested_platform)
|
||||
|
||||
# Resolve MQTT magic strings to actual IP addresses
|
||||
network_devices = _resolve_network_devices(devices, config, args)
|
||||
|
||||
if chosen_platform == CONF_WEB_SERVER:
|
||||
if getattr(args, "partition_table", False):
|
||||
raise EsphomeError(
|
||||
"--partition-table is only supported with the esphome OTA platform; "
|
||||
"the web_server OTA path can only update the firmware image."
|
||||
)
|
||||
binary = CORE.firmware_bin
|
||||
if getattr(args, "file", None) is not None:
|
||||
binary = Path(args.file)
|
||||
return _upload_via_web_server(config, network_devices, binary)
|
||||
|
||||
return _upload_via_native_api(config, network_devices, args)
|
||||
|
||||
|
||||
def _choose_ota_platform(config: ConfigType, requested: str | None) -> str:
|
||||
"""Pick the OTA platform to use, optionally honoring ``--ota-platform``.
|
||||
|
||||
Default behavior prefers ``esphome`` (native API) when it is configured.
|
||||
The native API uses challenge-response auth with MD5/SHA256 hashing of a
|
||||
server-issued nonce, so the password is never sent over the wire; the
|
||||
``web_server`` path uses HTTP Basic auth which transmits credentials in
|
||||
cleartext over the LAN. (The native path also supports gzip compression
|
||||
on ESP8266, where flash space is tight; on ESP32/RP2040/LibreTiny the
|
||||
backend reports ``supports_compression() == false`` and the firmware is
|
||||
sent uncompressed regardless of which platform is used.) Falls back to
|
||||
``web_server`` only when that is the only available platform.
|
||||
"""
|
||||
# Use a dict (insertion-ordered) instead of a list so error messages and
|
||||
# membership checks see one entry per platform even if the user has
|
||||
# multiple ``ota:`` items of the same platform; the web_server OTA
|
||||
# platform's final-validate hook merges duplicates anyway.
|
||||
available: dict[str, None] = {}
|
||||
for ota_item in config.get(CONF_OTA, []):
|
||||
if ota_item[CONF_PLATFORM] == CONF_ESPHOME:
|
||||
platform = ota_item.get(CONF_PLATFORM)
|
||||
if platform in (CONF_ESPHOME, CONF_WEB_SERVER):
|
||||
available[platform] = None
|
||||
|
||||
if not available:
|
||||
raise EsphomeError(
|
||||
f"Cannot upload Over the Air as the {CONF_OTA} configuration is not "
|
||||
f"present or does not include {CONF_PLATFORM}: {CONF_ESPHOME} or "
|
||||
f"{CONF_PLATFORM}: {CONF_WEB_SERVER}"
|
||||
)
|
||||
|
||||
if requested is not None:
|
||||
if requested not in available:
|
||||
raise EsphomeError(
|
||||
f"--ota-platform {requested} was requested but the configuration "
|
||||
f"only provides: {', '.join(available)}"
|
||||
)
|
||||
return requested
|
||||
|
||||
if CONF_ESPHOME in available:
|
||||
return CONF_ESPHOME
|
||||
return CONF_WEB_SERVER
|
||||
|
||||
|
||||
def _upload_via_native_api(
|
||||
config: ConfigType, network_devices: list[str], args: ArgsProtocol
|
||||
) -> tuple[int, str | None]:
|
||||
ota_conf: ConfigType = {}
|
||||
for ota_item in config.get(CONF_OTA, []):
|
||||
if ota_item.get(CONF_PLATFORM) == CONF_ESPHOME:
|
||||
ota_conf = ota_item
|
||||
break
|
||||
|
||||
if not ota_conf:
|
||||
raise EsphomeError(
|
||||
f"Cannot upload Over the Air as the {CONF_OTA} configuration is not present or does not include {CONF_PLATFORM}: {CONF_ESPHOME}"
|
||||
)
|
||||
|
||||
from esphome import espota2
|
||||
|
||||
remote_port = int(ota_conf[CONF_PORT])
|
||||
password = ota_conf.get(CONF_PASSWORD)
|
||||
|
||||
# Resolve MQTT magic strings to actual IP addresses
|
||||
network_devices = _resolve_network_devices(devices, config, args)
|
||||
|
||||
binary = CORE.firmware_bin
|
||||
ota_type = espota2.OTA_TYPE_UPDATE_APP
|
||||
if getattr(args, "partition_table", False):
|
||||
@@ -1157,6 +1251,28 @@ def upload_program(
|
||||
return espota2.run_ota(network_devices, remote_port, password, binary, ota_type)
|
||||
|
||||
|
||||
def _upload_via_web_server(
|
||||
config: ConfigType, network_devices: list[str], binary: Path
|
||||
) -> tuple[int, str | None]:
|
||||
web_conf = config.get(CONF_WEB_SERVER)
|
||||
if not web_conf:
|
||||
raise EsphomeError(
|
||||
f"Cannot upload via web_server OTA: the {CONF_WEB_SERVER} component "
|
||||
f"is not configured."
|
||||
)
|
||||
|
||||
remote_port = int(web_conf[CONF_PORT])
|
||||
auth = web_conf.get(CONF_AUTH) or {}
|
||||
username = auth.get(CONF_USERNAME)
|
||||
password = auth.get(CONF_PASSWORD)
|
||||
|
||||
from esphome import web_server_ota
|
||||
|
||||
return web_server_ota.run_ota(
|
||||
network_devices, remote_port, username, password, binary
|
||||
)
|
||||
|
||||
|
||||
# Layout of esp_partition_info_t on flash. Each entry is 32 bytes, leading with a
|
||||
# 16-bit little-endian magic. ESP-IDF defines ESP_PARTITION_MAGIC = 0x50AA (stored as
|
||||
# bytes 0xAA, 0x50) for partition entries and ESP_PARTITION_MAGIC_MD5 = 0xEBEB for the
|
||||
@@ -1877,6 +1993,17 @@ def parse_args(argv):
|
||||
"--file",
|
||||
help="Manually specify the binary file to upload.",
|
||||
)
|
||||
parser_upload.add_argument(
|
||||
"--ota-platform",
|
||||
choices=[CONF_ESPHOME, CONF_WEB_SERVER],
|
||||
help=(
|
||||
"OTA platform to use for network uploads. Defaults to "
|
||||
f"'{CONF_ESPHOME}' (native API) when configured because it uses "
|
||||
"challenge-response auth so the password is never sent in "
|
||||
f"cleartext on the wire. Falls back to '{CONF_WEB_SERVER}' "
|
||||
"(HTTP Basic auth) when that is the only configured platform."
|
||||
),
|
||||
)
|
||||
parser_upload.add_argument(
|
||||
"--partition-table",
|
||||
help="Upload as partition table (OTA).",
|
||||
@@ -1951,6 +2078,17 @@ def parse_args(argv):
|
||||
help="Build with native ESP-IDF instead of PlatformIO (ESP32 esp-idf framework only).",
|
||||
action="store_true",
|
||||
)
|
||||
parser_run.add_argument(
|
||||
"--ota-platform",
|
||||
choices=[CONF_ESPHOME, CONF_WEB_SERVER],
|
||||
help=(
|
||||
"OTA platform to use for network uploads. Defaults to "
|
||||
f"'{CONF_ESPHOME}' (native API) when configured because it uses "
|
||||
"challenge-response auth so the password is never sent in "
|
||||
f"cleartext on the wire. Falls back to '{CONF_WEB_SERVER}' "
|
||||
"(HTTP Basic auth) when that is the only configured platform."
|
||||
),
|
||||
)
|
||||
|
||||
parser_clean = subparsers.add_parser(
|
||||
"clean-mqtt",
|
||||
|
||||
202
esphome/web_server_ota.py
Normal file
202
esphome/web_server_ota.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""HTTP-based OTA upload via the ``web_server`` component's ``/update`` endpoint.
|
||||
|
||||
This is the alternative to ``espota2`` (the native API OTA path). Useful when
|
||||
a device only has ``platform: web_server`` configured under ``ota:``, or when
|
||||
the user has lost the native OTA password but still has ``web_server`` basic
|
||||
auth credentials.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import secrets
|
||||
import socket
|
||||
from typing import BinaryIO
|
||||
|
||||
import requests
|
||||
from requests.auth import HTTPBasicAuth
|
||||
|
||||
from esphome.core import EsphomeError
|
||||
from esphome.helpers import ProgressBar, resolve_ip_address
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
OTA_PATH = "/update"
|
||||
FORM_FIELD = "update"
|
||||
# (connect_timeout, read_timeout). The device reboots after a successful
|
||||
# upload so the read side must allow for a slow flash + response.
|
||||
TIMEOUT = (20.0, 120.0)
|
||||
|
||||
|
||||
class WebServerOTAError(EsphomeError):
|
||||
pass
|
||||
|
||||
|
||||
class _MultipartStreamer:
|
||||
"""Stream a single-file multipart/form-data body during transmission.
|
||||
|
||||
``requests.post(files=...)`` materializes the entire body in memory before
|
||||
sending, so a progress callback wired into the file-like fires during
|
||||
encoding instead of during the network send. Pass this via ``data=``
|
||||
(with ``__len__`` so urllib3 sets ``Content-Length`` instead of using
|
||||
chunked transfer encoding); urllib3 then calls ``read(blocksize)``
|
||||
repeatedly during the POST and the progress bar tracks bytes leaving the
|
||||
host.
|
||||
"""
|
||||
|
||||
def __init__(self, file: BinaryIO, file_size: int, filename: str) -> None:
|
||||
self.boundary = f"esphomeOTA{secrets.token_hex(16)}"
|
||||
prefix = (
|
||||
f"--{self.boundary}\r\n"
|
||||
f'Content-Disposition: form-data; name="{FORM_FIELD}"; '
|
||||
f'filename="{filename}"\r\n'
|
||||
f"Content-Type: application/octet-stream\r\n\r\n"
|
||||
).encode()
|
||||
suffix = f"\r\n--{self.boundary}--\r\n".encode()
|
||||
# Walked in order; ``read()`` advances to the next source on EOF.
|
||||
self._sources: list[BinaryIO] = [io.BytesIO(prefix), file, io.BytesIO(suffix)]
|
||||
self._idx = 0
|
||||
self._total = len(prefix) + file_size + len(suffix)
|
||||
self._sent = 0
|
||||
self.progress = ProgressBar()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._total
|
||||
|
||||
@property
|
||||
def content_type(self) -> str:
|
||||
return f"multipart/form-data; boundary={self.boundary}"
|
||||
|
||||
def read(self, size: int = -1) -> bytes:
|
||||
remaining = self._total if size is None or size < 0 else size
|
||||
out = bytearray()
|
||||
while remaining > 0 and self._idx < len(self._sources):
|
||||
chunk = self._sources[self._idx].read(remaining)
|
||||
if not chunk:
|
||||
self._idx += 1
|
||||
continue
|
||||
out += chunk
|
||||
remaining -= len(chunk)
|
||||
if out:
|
||||
self._sent += len(out)
|
||||
self.progress.update(self._sent / self._total)
|
||||
return bytes(out)
|
||||
|
||||
|
||||
def _try_upload(
|
||||
host: str,
|
||||
port: int,
|
||||
username: str | None,
|
||||
password: str | None,
|
||||
filename: Path,
|
||||
) -> tuple[int, str | None]:
|
||||
from esphome.core import CORE
|
||||
|
||||
try:
|
||||
addr_infos = resolve_ip_address(host, port, address_cache=CORE.address_cache)
|
||||
except EsphomeError as err:
|
||||
_LOGGER.error(
|
||||
"Error resolving IP address of %s. Is it connected to WiFi?", host
|
||||
)
|
||||
if not CORE.dashboard:
|
||||
_LOGGER.error("(If you know the IP, try --device <IP>)")
|
||||
raise WebServerOTAError(err) from err
|
||||
|
||||
if not addr_infos:
|
||||
_LOGGER.error("Could not resolve %s", host)
|
||||
return 1, None
|
||||
|
||||
file_size = filename.stat().st_size
|
||||
_LOGGER.info("Uploading %s (%s bytes) via web_server OTA", filename, file_size)
|
||||
auth = HTTPBasicAuth(username, password) if username and password else None
|
||||
|
||||
# Iterate resolved IPs (IPv4 + IPv6 candidates) just like espota2 does.
|
||||
for af, _socktype, _, _, sa in addr_infos:
|
||||
ip = sa[0]
|
||||
# IPv6 literals must be wrapped in brackets in URLs; link-local
|
||||
# addresses need a percent-encoded zone index per RFC 6874.
|
||||
if af == socket.AF_INET6:
|
||||
scope = sa[3] if len(sa) >= 4 else 0
|
||||
host_part = f"[{ip}%25{scope}]" if scope else f"[{ip}]"
|
||||
else:
|
||||
host_part = ip
|
||||
url = f"http://{host_part}:{port}{OTA_PATH}"
|
||||
_LOGGER.info("Connecting to %s port %s...", ip, port)
|
||||
|
||||
try:
|
||||
with open(filename, "rb") as fh:
|
||||
streamer = _MultipartStreamer(fh, file_size, filename.name)
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
data=streamer,
|
||||
auth=auth,
|
||||
timeout=TIMEOUT,
|
||||
headers={
|
||||
"Content-Type": streamer.content_type,
|
||||
"Connection": "close",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
streamer.progress.done()
|
||||
except requests.RequestException as err:
|
||||
_LOGGER.error("OTA upload to %s port %s failed: %s", ip, port, err)
|
||||
continue
|
||||
|
||||
if response.status_code == 401:
|
||||
raise WebServerOTAError(
|
||||
"Authentication failed (HTTP 401). Check the 'web_server' "
|
||||
"'auth' username and password."
|
||||
)
|
||||
if response.status_code != 200:
|
||||
detail = response.text.strip() or response.reason or "no response body"
|
||||
raise WebServerOTAError(
|
||||
f"Unexpected HTTP {response.status_code} response from device: {detail}"
|
||||
)
|
||||
|
||||
# The endpoint returns HTTP 200 for both success and failure; the
|
||||
# body is what tells us which (see ota_web_server.cpp handleRequest).
|
||||
body = response.text.strip()
|
||||
if "Successful" in body:
|
||||
_LOGGER.info("Device response: %s", body)
|
||||
_LOGGER.info("OTA successful")
|
||||
return 0, ip
|
||||
|
||||
raise WebServerOTAError(
|
||||
f"Device reported OTA failure: {body or 'no response body'}"
|
||||
)
|
||||
|
||||
return 1, None
|
||||
|
||||
|
||||
def run_ota(
|
||||
remote_hosts: str | list[str],
|
||||
remote_port: int,
|
||||
username: str | None,
|
||||
password: str | None,
|
||||
filename: Path,
|
||||
) -> tuple[int, str | None]:
|
||||
"""Upload ``filename`` to the first reachable host via ``web_server`` OTA.
|
||||
|
||||
Mirrors :func:`esphome.espota2.run_ota` so callers can swap between the
|
||||
two paths with the same return contract: ``(0, host)`` on success or
|
||||
``(1, None)`` on failure.
|
||||
"""
|
||||
hosts = [remote_hosts] if isinstance(remote_hosts, str) else list(remote_hosts)
|
||||
for host in hosts:
|
||||
try:
|
||||
exit_code, used_host = _try_upload(
|
||||
host, remote_port, username, password, filename
|
||||
)
|
||||
except WebServerOTAError as err:
|
||||
_LOGGER.error("%s", err)
|
||||
continue
|
||||
if exit_code == 0:
|
||||
return 0, used_host
|
||||
# Reached only when every attempt failed; per-attempt errors were
|
||||
# already logged. This summary line gives the user an unambiguous
|
||||
# "stop reading, nothing worked" marker.
|
||||
_LOGGER.error("OTA upload failed.")
|
||||
return 1, None
|
||||
@@ -43,6 +43,7 @@ from esphome.__main__ import (
|
||||
has_non_ip_address,
|
||||
has_ota,
|
||||
has_resolvable_address,
|
||||
has_web_server_ota,
|
||||
mqtt_get_ip,
|
||||
run_esphome,
|
||||
run_miniterm,
|
||||
@@ -58,6 +59,7 @@ from esphome.components import esp32
|
||||
from esphome.components.esp32 import KEY_ESP32, KEY_VARIANT, VARIANT_ESP32
|
||||
from esphome.const import (
|
||||
CONF_API,
|
||||
CONF_AUTH,
|
||||
CONF_BAUD_RATE,
|
||||
CONF_BROKER,
|
||||
CONF_DISABLED,
|
||||
@@ -76,6 +78,8 @@ from esphome.const import (
|
||||
CONF_SUBSTITUTIONS,
|
||||
CONF_TOPIC,
|
||||
CONF_USE_ADDRESS,
|
||||
CONF_USERNAME,
|
||||
CONF_WEB_SERVER,
|
||||
CONF_WIFI,
|
||||
KEY_CORE,
|
||||
KEY_TARGET_PLATFORM,
|
||||
@@ -213,6 +217,13 @@ def mock_run_ota() -> Generator[Mock]:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_run_web_server_ota() -> Generator[Mock]:
|
||||
"""Mock web_server_ota.run_ota for testing."""
|
||||
with patch("esphome.web_server_ota.run_ota") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_is_ip_address() -> Generator[Mock]:
|
||||
"""Mock is_ip_address for testing."""
|
||||
@@ -1114,6 +1125,7 @@ class MockArgs:
|
||||
reset: bool = False
|
||||
list_only: bool = False
|
||||
output: str | None = None
|
||||
ota_platform: str | None = None
|
||||
partition_table: bool = False
|
||||
|
||||
|
||||
@@ -1878,6 +1890,277 @@ def test_upload_program_ota_no_config(
|
||||
upload_program(config, args, devices)
|
||||
|
||||
|
||||
def test_has_web_server_ota_detects_platform() -> None:
|
||||
"""has_web_server_ota returns True when web_server OTA platform is configured."""
|
||||
setup_core(
|
||||
config={
|
||||
CONF_OTA: [{CONF_PLATFORM: CONF_WEB_SERVER}],
|
||||
}
|
||||
)
|
||||
assert has_web_server_ota() is True
|
||||
assert has_ota() is True
|
||||
|
||||
|
||||
def test_has_web_server_ota_returns_false_without_config() -> None:
|
||||
"""has_web_server_ota returns False when only native OTA is configured."""
|
||||
setup_core(
|
||||
config={
|
||||
CONF_OTA: [{CONF_PLATFORM: CONF_ESPHOME}],
|
||||
}
|
||||
)
|
||||
assert has_web_server_ota() is False
|
||||
assert has_ota() is True
|
||||
|
||||
|
||||
def test_upload_program_web_server_only_auto_dispatches(
|
||||
mock_run_web_server_ota: Mock,
|
||||
mock_run_ota: Mock,
|
||||
mock_get_port_type: Mock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""When only web_server OTA is configured, upload_program picks it automatically."""
|
||||
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
|
||||
mock_get_port_type.return_value = "NETWORK"
|
||||
mock_run_web_server_ota.return_value = (0, "192.168.1.100")
|
||||
|
||||
config = {
|
||||
CONF_OTA: [{CONF_PLATFORM: CONF_WEB_SERVER}],
|
||||
CONF_WEB_SERVER: {
|
||||
CONF_PORT: 80,
|
||||
CONF_AUTH: {CONF_USERNAME: "admin", CONF_PASSWORD: "pw"},
|
||||
},
|
||||
}
|
||||
args = MockArgs()
|
||||
devices = ["192.168.1.100"]
|
||||
|
||||
exit_code, host = upload_program(config, args, devices)
|
||||
|
||||
assert exit_code == 0
|
||||
assert host == "192.168.1.100"
|
||||
expected_firmware = (
|
||||
tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin"
|
||||
)
|
||||
mock_run_web_server_ota.assert_called_once_with(
|
||||
["192.168.1.100"], 80, "admin", "pw", expected_firmware
|
||||
)
|
||||
mock_run_ota.assert_not_called()
|
||||
|
||||
|
||||
def test_upload_program_web_server_no_auth(
|
||||
mock_run_web_server_ota: Mock,
|
||||
mock_get_port_type: Mock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""web_server OTA works without an auth block (passes None for credentials)."""
|
||||
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
|
||||
mock_get_port_type.return_value = "NETWORK"
|
||||
mock_run_web_server_ota.return_value = (0, "192.168.1.100")
|
||||
|
||||
config = {
|
||||
CONF_OTA: [{CONF_PLATFORM: CONF_WEB_SERVER}],
|
||||
CONF_WEB_SERVER: {CONF_PORT: 8080},
|
||||
}
|
||||
args = MockArgs()
|
||||
devices = ["192.168.1.100"]
|
||||
|
||||
exit_code, host = upload_program(config, args, devices)
|
||||
|
||||
assert exit_code == 0
|
||||
assert host == "192.168.1.100"
|
||||
expected_firmware = (
|
||||
tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin"
|
||||
)
|
||||
mock_run_web_server_ota.assert_called_once_with(
|
||||
["192.168.1.100"], 8080, None, None, expected_firmware
|
||||
)
|
||||
|
||||
|
||||
def test_upload_program_both_platforms_default_prefers_native(
|
||||
mock_run_ota: Mock,
|
||||
mock_run_web_server_ota: Mock,
|
||||
mock_get_port_type: Mock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""When both OTA platforms are configured, default selection is native API."""
|
||||
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
|
||||
mock_get_port_type.return_value = "NETWORK"
|
||||
mock_run_ota.return_value = (0, "192.168.1.100")
|
||||
|
||||
config = {
|
||||
CONF_OTA: [
|
||||
{
|
||||
CONF_PLATFORM: CONF_ESPHOME,
|
||||
CONF_PORT: 3232,
|
||||
CONF_PASSWORD: "secret",
|
||||
},
|
||||
{CONF_PLATFORM: CONF_WEB_SERVER},
|
||||
],
|
||||
CONF_WEB_SERVER: {CONF_PORT: 80},
|
||||
}
|
||||
args = MockArgs()
|
||||
devices = ["192.168.1.100"]
|
||||
|
||||
exit_code, host = upload_program(config, args, devices)
|
||||
|
||||
assert exit_code == 0
|
||||
assert host == "192.168.1.100"
|
||||
mock_run_ota.assert_called_once()
|
||||
mock_run_web_server_ota.assert_not_called()
|
||||
|
||||
|
||||
def test_upload_program_ota_platform_override_to_web_server(
|
||||
mock_run_ota: Mock,
|
||||
mock_run_web_server_ota: Mock,
|
||||
mock_get_port_type: Mock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""--ota-platform web_server forces web_server OTA even when native is configured."""
|
||||
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
|
||||
mock_get_port_type.return_value = "NETWORK"
|
||||
mock_run_web_server_ota.return_value = (0, "192.168.1.100")
|
||||
|
||||
config = {
|
||||
CONF_OTA: [
|
||||
{
|
||||
CONF_PLATFORM: CONF_ESPHOME,
|
||||
CONF_PORT: 3232,
|
||||
CONF_PASSWORD: "secret",
|
||||
},
|
||||
{CONF_PLATFORM: CONF_WEB_SERVER},
|
||||
],
|
||||
CONF_WEB_SERVER: {CONF_PORT: 80},
|
||||
}
|
||||
args = MockArgs(ota_platform=CONF_WEB_SERVER)
|
||||
devices = ["192.168.1.100"]
|
||||
|
||||
exit_code, host = upload_program(config, args, devices)
|
||||
|
||||
assert exit_code == 0
|
||||
assert host == "192.168.1.100"
|
||||
mock_run_ota.assert_not_called()
|
||||
mock_run_web_server_ota.assert_called_once()
|
||||
|
||||
|
||||
def test_upload_program_ota_platform_unavailable(
|
||||
mock_get_port_type: Mock,
|
||||
) -> None:
|
||||
"""--ota-platform must reference a platform that is actually configured."""
|
||||
setup_core(platform=PLATFORM_ESP32)
|
||||
mock_get_port_type.return_value = "NETWORK"
|
||||
|
||||
config = {
|
||||
CONF_OTA: [
|
||||
{
|
||||
CONF_PLATFORM: CONF_ESPHOME,
|
||||
CONF_PORT: 3232,
|
||||
CONF_PASSWORD: "secret",
|
||||
}
|
||||
],
|
||||
}
|
||||
args = MockArgs(ota_platform=CONF_WEB_SERVER)
|
||||
devices = ["192.168.1.100"]
|
||||
|
||||
with pytest.raises(EsphomeError, match="--ota-platform web_server"):
|
||||
upload_program(config, args, devices)
|
||||
|
||||
|
||||
def test_upload_program_web_server_missing_component(
|
||||
mock_get_port_type: Mock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""web_server OTA without a web_server component fails with a clear error."""
|
||||
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
|
||||
mock_get_port_type.return_value = "NETWORK"
|
||||
|
||||
config = {
|
||||
CONF_OTA: [{CONF_PLATFORM: CONF_WEB_SERVER}],
|
||||
# No CONF_WEB_SERVER
|
||||
}
|
||||
args = MockArgs()
|
||||
devices = ["192.168.1.100"]
|
||||
|
||||
with pytest.raises(EsphomeError, match="web_server.*not configured"):
|
||||
upload_program(config, args, devices)
|
||||
|
||||
|
||||
def test_upload_program_unrelated_ota_platform_ignored(
|
||||
mock_run_ota: Mock,
|
||||
mock_get_port_type: Mock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""OTA list entries that are neither esphome nor web_server are ignored.
|
||||
|
||||
Covers the false branch in _choose_ota_platform's filter loop and the
|
||||
no-match branch in _upload_via_native_api's lookup loop.
|
||||
"""
|
||||
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
|
||||
mock_get_port_type.return_value = "NETWORK"
|
||||
mock_run_ota.return_value = (0, "192.168.1.100")
|
||||
|
||||
config = {
|
||||
CONF_OTA: [
|
||||
{CONF_PLATFORM: "http_request"}, # unrelated platform; ignored
|
||||
{
|
||||
CONF_PLATFORM: CONF_ESPHOME,
|
||||
CONF_PORT: 3232,
|
||||
CONF_PASSWORD: "secret",
|
||||
},
|
||||
],
|
||||
}
|
||||
args = MockArgs()
|
||||
devices = ["192.168.1.100"]
|
||||
|
||||
exit_code, host = upload_program(config, args, devices)
|
||||
|
||||
assert exit_code == 0
|
||||
assert host == "192.168.1.100"
|
||||
mock_run_ota.assert_called_once()
|
||||
|
||||
|
||||
def test_upload_program_duplicate_platform_dedup_in_error(
|
||||
mock_get_port_type: Mock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Duplicate same-platform OTA entries don't repeat in --ota-platform errors."""
|
||||
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
|
||||
mock_get_port_type.return_value = "NETWORK"
|
||||
|
||||
config = {
|
||||
CONF_OTA: [
|
||||
{CONF_PLATFORM: CONF_ESPHOME, CONF_PORT: 3232},
|
||||
{CONF_PLATFORM: CONF_ESPHOME, CONF_PORT: 3233},
|
||||
],
|
||||
}
|
||||
args = MockArgs(ota_platform=CONF_WEB_SERVER)
|
||||
devices = ["192.168.1.100"]
|
||||
|
||||
with pytest.raises(EsphomeError) as excinfo:
|
||||
upload_program(config, args, devices)
|
||||
|
||||
# Error mentions esphome once in the platform list, not "esphome, esphome".
|
||||
msg = str(excinfo.value)
|
||||
assert "esphome, esphome" not in msg
|
||||
assert msg.endswith(": esphome")
|
||||
|
||||
|
||||
def test_upload_program_only_unrelated_ota_platforms(
|
||||
mock_get_port_type: Mock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Only unrelated OTA platforms configured -> raises like missing OTA."""
|
||||
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
|
||||
mock_get_port_type.return_value = "NETWORK"
|
||||
|
||||
config = {
|
||||
CONF_OTA: [{CONF_PLATFORM: "http_request"}],
|
||||
}
|
||||
args = MockArgs()
|
||||
devices = ["192.168.1.100"]
|
||||
|
||||
with pytest.raises(EsphomeError, match="Cannot upload Over the Air"):
|
||||
upload_program(config, args, devices)
|
||||
|
||||
|
||||
def test_upload_program_ota_with_mqtt_resolution(
|
||||
mock_mqtt_get_ip: Mock,
|
||||
mock_is_ip_address: Mock,
|
||||
|
||||
670
tests/unit_tests/test_web_server_ota.py
Normal file
670
tests/unit_tests/test_web_server_ota.py
Normal file
@@ -0,0 +1,670 @@
|
||||
"""Unit tests for esphome.web_server_ota module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import socket
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from requests.auth import HTTPBasicAuth
|
||||
|
||||
from esphome.core import CORE, EsphomeError
|
||||
from esphome.helpers import ProgressBar
|
||||
from esphome.web_server_ota import (
|
||||
OTA_PATH,
|
||||
WebServerOTAError,
|
||||
_MultipartStreamer,
|
||||
run_ota,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def firmware(tmp_path: Path) -> Path:
|
||||
binary = tmp_path / "firmware.bin"
|
||||
binary.write_bytes(b"\x00\x01\x02FIRMWARE\xff" * 64)
|
||||
return binary
|
||||
|
||||
|
||||
def _make_response(status: int, body: str) -> MagicMock:
|
||||
response = MagicMock(spec=requests.Response)
|
||||
response.status_code = status
|
||||
response.text = body
|
||||
response.reason = ""
|
||||
return response
|
||||
|
||||
|
||||
def _patch_resolve(
|
||||
monkeypatch: pytest.MonkeyPatch, hosts: list[tuple[str, int]]
|
||||
) -> None:
|
||||
"""Replace resolve_ip_address so tests don't actually do DNS."""
|
||||
addr_infos = [
|
||||
(socket.AF_INET, socket.SOCK_STREAM, 0, "", (host, port))
|
||||
for host, port in hosts
|
||||
]
|
||||
monkeypatch.setattr(
|
||||
"esphome.web_server_ota.resolve_ip_address", lambda *a, **kw: addr_infos
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _MultipartStreamer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_multipart_streamer_emits_full_body() -> None:
|
||||
"""Streaming the whole body in one call yields prefix + file + suffix."""
|
||||
data = b"abcdef" * 100
|
||||
streamer = _MultipartStreamer(io.BytesIO(data), len(data), "fw.bin")
|
||||
|
||||
body = streamer.read()
|
||||
while True:
|
||||
chunk = streamer.read()
|
||||
if not chunk:
|
||||
break
|
||||
body += chunk
|
||||
|
||||
assert body.startswith(f"--{streamer.boundary}\r\n".encode())
|
||||
assert b'name="update"' in body
|
||||
assert b'filename="fw.bin"' in body
|
||||
assert data in body
|
||||
assert body.endswith(f"\r\n--{streamer.boundary}--\r\n".encode())
|
||||
|
||||
|
||||
def test_multipart_streamer_chunked_read_matches_full_read() -> None:
|
||||
"""Chunked reads (urllib3 calls read(8192) repeatedly) yield the same body."""
|
||||
data = b"abcdef" * 1000 # 6000 bytes
|
||||
full = _MultipartStreamer(io.BytesIO(data), len(data), "fw.bin").read()
|
||||
|
||||
streamed = bytearray()
|
||||
s = _MultipartStreamer(io.BytesIO(data), len(data), "fw.bin")
|
||||
# Same boundary lengths -> identical total length.
|
||||
while True:
|
||||
chunk = s.read(64)
|
||||
if not chunk:
|
||||
break
|
||||
streamed += chunk
|
||||
# Boundaries are random per instance, so compare lengths and structure.
|
||||
assert len(streamed) == len(full)
|
||||
assert streamed.startswith(f"--{s.boundary}\r\n".encode())
|
||||
assert streamed.endswith(f"\r\n--{s.boundary}--\r\n".encode())
|
||||
|
||||
|
||||
def test_multipart_streamer_len_matches_emitted_bytes() -> None:
|
||||
"""``__len__`` is what urllib3 uses to set Content-Length, so it must
|
||||
equal the total bytes emitted by ``read``."""
|
||||
data = b"x" * 12345
|
||||
s = _MultipartStreamer(io.BytesIO(data), len(data), "fw.bin")
|
||||
declared = len(s)
|
||||
|
||||
emitted = 0
|
||||
while True:
|
||||
chunk = s.read(1024)
|
||||
if not chunk:
|
||||
break
|
||||
emitted += len(chunk)
|
||||
|
||||
assert emitted == declared
|
||||
|
||||
|
||||
def test_multipart_streamer_progress_ticks_during_read() -> None:
|
||||
"""Each read advances the progress bar (this is the whole point of
|
||||
streaming via ``data=``: progress reflects bytes leaving the host)."""
|
||||
data = b"x" * 1000
|
||||
s = _MultipartStreamer(io.BytesIO(data), len(data), "fw.bin")
|
||||
|
||||
updates: list[float] = []
|
||||
s.progress.update = updates.append # type: ignore[method-assign]
|
||||
|
||||
while True:
|
||||
chunk = s.read(128)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
assert updates, "progress.update was never called"
|
||||
# Strictly non-decreasing.
|
||||
assert updates == sorted(updates)
|
||||
# Final update reaches (within FP) 1.0 because all bytes were read.
|
||||
assert updates[-1] == pytest.approx(1.0, abs=1e-9)
|
||||
|
||||
|
||||
def test_multipart_streamer_content_type_includes_boundary() -> None:
|
||||
s = _MultipartStreamer(io.BytesIO(b""), 0, "fw.bin")
|
||||
assert s.content_type == f"multipart/form-data; boundary={s.boundary}"
|
||||
|
||||
|
||||
def test_multipart_streamer_zero_size_file() -> None:
|
||||
"""A zero-byte file still produces a well-formed body and progress is
|
||||
skipped (avoiding a divide-by-zero on the empty file segment)."""
|
||||
s = _MultipartStreamer(io.BytesIO(b""), 0, "empty.bin")
|
||||
body = b""
|
||||
while True:
|
||||
chunk = s.read(64)
|
||||
if not chunk:
|
||||
break
|
||||
body += chunk
|
||||
assert body.startswith(f"--{s.boundary}".encode())
|
||||
assert body.endswith(f"--{s.boundary}--\r\n".encode())
|
||||
|
||||
|
||||
def test_multipart_streamer_unique_boundary_per_instance() -> None:
|
||||
a = _MultipartStreamer(io.BytesIO(b""), 0, "a")
|
||||
b = _MultipartStreamer(io.BytesIO(b""), 0, "a")
|
||||
assert a.boundary != b.boundary
|
||||
|
||||
|
||||
def test_multipart_streamer_zero_size_read_returns_empty() -> None:
|
||||
"""``read(0)`` short-circuits without touching state."""
|
||||
s = _MultipartStreamer(io.BytesIO(b"x" * 10), 10, "fw.bin")
|
||||
assert s.read(0) == b""
|
||||
# No bytes consumed.
|
||||
assert s._sent == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_ota
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_run_ota_success(monkeypatch: pytest.MonkeyPatch, firmware: Path) -> None:
|
||||
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=_make_response(200, "Update Successful!"),
|
||||
) as post:
|
||||
exit_code, host = run_ota(["device.local"], 80, None, None, firmware)
|
||||
|
||||
assert exit_code == 0
|
||||
assert host == "192.168.1.50"
|
||||
post.assert_called_once()
|
||||
args, kwargs = post.call_args
|
||||
assert args == (f"http://192.168.1.50:80{OTA_PATH}",)
|
||||
assert kwargs["auth"] is None
|
||||
# Streaming body, not files=, so progress fires during transmission.
|
||||
assert "files" not in kwargs
|
||||
assert isinstance(kwargs["data"], _MultipartStreamer)
|
||||
assert kwargs["headers"]["Content-Type"] == kwargs["data"].content_type
|
||||
assert kwargs["headers"]["Connection"] == "close"
|
||||
|
||||
|
||||
def test_run_ota_logs_device_response_body(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""The device's HTTP response body is surfaced on success."""
|
||||
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
|
||||
caplog.set_level(logging.INFO, logger="esphome.web_server_ota")
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=_make_response(200, "Update Successful!"),
|
||||
):
|
||||
run_ota(["192.168.1.50"], 80, None, None, firmware)
|
||||
|
||||
assert "Device response: Update Successful!" in caplog.text
|
||||
assert "OTA successful" in caplog.text
|
||||
|
||||
|
||||
def test_run_ota_log_says_via_web_server(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""The upload-start log line names the transport explicitly."""
|
||||
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
|
||||
caplog.set_level(logging.INFO, logger="esphome.web_server_ota")
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=_make_response(200, "Update Successful!"),
|
||||
):
|
||||
run_ota(["192.168.1.50"], 80, None, None, firmware)
|
||||
|
||||
assert "via web_server OTA" in caplog.text
|
||||
|
||||
|
||||
def test_run_ota_sends_basic_auth(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path
|
||||
) -> None:
|
||||
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=_make_response(200, "Update Successful!"),
|
||||
) as post:
|
||||
exit_code, _ = run_ota(["192.168.1.50"], 80, "admin", "secret", firmware)
|
||||
|
||||
assert exit_code == 0
|
||||
auth = post.call_args.kwargs["auth"]
|
||||
assert isinstance(auth, HTTPBasicAuth)
|
||||
assert auth.username == "admin"
|
||||
assert auth.password == "secret"
|
||||
|
||||
|
||||
def test_run_ota_skips_auth_when_no_credentials(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path
|
||||
) -> None:
|
||||
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=_make_response(200, "Update Successful!"),
|
||||
) as post:
|
||||
run_ota(["192.168.1.50"], 80, None, None, firmware)
|
||||
|
||||
assert post.call_args.kwargs["auth"] is None
|
||||
|
||||
|
||||
def test_run_ota_skips_auth_when_only_username(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path
|
||||
) -> None:
|
||||
"""Both username and password are required to send Basic auth."""
|
||||
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=_make_response(200, "Update Successful!"),
|
||||
) as post:
|
||||
run_ota(["192.168.1.50"], 80, "admin", None, firmware)
|
||||
|
||||
assert post.call_args.kwargs["auth"] is None
|
||||
|
||||
|
||||
def test_run_ota_uses_update_url(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path
|
||||
) -> None:
|
||||
_patch_resolve(monkeypatch, [("192.168.1.50", 8080)])
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=_make_response(200, "Update Successful!"),
|
||||
) as post:
|
||||
run_ota(["192.168.1.50"], 8080, None, None, firmware)
|
||||
|
||||
url = post.call_args.args[0]
|
||||
assert url == f"http://192.168.1.50:8080{OTA_PATH}"
|
||||
assert OTA_PATH == "/update"
|
||||
|
||||
|
||||
def test_run_ota_failure_response(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=_make_response(200, "Update Failed!"),
|
||||
):
|
||||
exit_code, host = run_ota(["192.168.1.50"], 80, None, None, firmware)
|
||||
|
||||
assert exit_code == 1
|
||||
assert host is None
|
||||
assert "OTA failure" in caplog.text
|
||||
|
||||
|
||||
def test_run_ota_failure_response_empty_body(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=_make_response(200, ""),
|
||||
):
|
||||
exit_code, host = run_ota(["192.168.1.50"], 80, None, None, firmware)
|
||||
|
||||
assert exit_code == 1
|
||||
assert host is None
|
||||
assert "no response body" in caplog.text
|
||||
|
||||
|
||||
def test_run_ota_auth_failed(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=_make_response(401, "Unauthorized"),
|
||||
):
|
||||
exit_code, host = run_ota(["192.168.1.50"], 80, "user", "wrong", firmware)
|
||||
|
||||
assert exit_code == 1
|
||||
assert host is None
|
||||
assert "Authentication failed" in caplog.text
|
||||
|
||||
|
||||
def test_run_ota_unexpected_status_code(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=_make_response(500, "Internal Error"),
|
||||
):
|
||||
exit_code, host = run_ota(["192.168.1.50"], 80, None, None, firmware)
|
||||
|
||||
assert exit_code == 1
|
||||
assert host is None
|
||||
assert "Unexpected HTTP 500" in caplog.text
|
||||
|
||||
|
||||
def test_run_ota_unexpected_status_empty_body_falls_back(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Empty response body uses response.reason / a fallback in the error."""
|
||||
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
|
||||
|
||||
response = _make_response(503, "")
|
||||
response.reason = "Service Unavailable"
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=response,
|
||||
):
|
||||
exit_code, host = run_ota(["192.168.1.50"], 80, None, None, firmware)
|
||||
|
||||
assert exit_code == 1
|
||||
assert host is None
|
||||
assert "Service Unavailable" in caplog.text
|
||||
|
||||
|
||||
def test_run_ota_unexpected_status_no_body_no_reason(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Empty body and empty reason still produce a usable error message."""
|
||||
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
|
||||
|
||||
response = _make_response(599, "")
|
||||
response.reason = ""
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=response,
|
||||
):
|
||||
run_ota(["192.168.1.50"], 80, None, None, firmware)
|
||||
|
||||
assert "no response body" in caplog.text
|
||||
|
||||
|
||||
def test_run_ota_connection_error_then_success(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path
|
||||
) -> None:
|
||||
"""First resolved address fails to connect, second succeeds."""
|
||||
_patch_resolve(
|
||||
monkeypatch,
|
||||
[("192.168.1.10", 80), ("192.168.1.50", 80)],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
side_effect=[
|
||||
requests.ConnectionError("refused"),
|
||||
_make_response(200, "Update Successful!"),
|
||||
],
|
||||
) as post:
|
||||
exit_code, host = run_ota(["device.local"], 80, None, None, firmware)
|
||||
|
||||
assert exit_code == 0
|
||||
assert host == "192.168.1.50"
|
||||
assert post.call_count == 2
|
||||
|
||||
|
||||
def test_run_ota_request_exception_falls_through(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path
|
||||
) -> None:
|
||||
"""A non-ConnectionError RequestException (e.g. timeout) falls through too."""
|
||||
_patch_resolve(
|
||||
monkeypatch,
|
||||
[("192.168.1.10", 80), ("192.168.1.50", 80)],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
side_effect=[
|
||||
requests.Timeout("read timeout"),
|
||||
_make_response(200, "Update Successful!"),
|
||||
],
|
||||
):
|
||||
exit_code, host = run_ota(["device.local"], 80, None, None, firmware)
|
||||
|
||||
assert exit_code == 0
|
||||
assert host == "192.168.1.50"
|
||||
|
||||
|
||||
def test_run_ota_all_addresses_unreachable(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""When every resolved address fails to connect, run_ota returns failure."""
|
||||
_patch_resolve(
|
||||
monkeypatch,
|
||||
[("192.168.1.10", 80), ("192.168.1.20", 80)],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
side_effect=requests.ConnectionError("refused"),
|
||||
):
|
||||
exit_code, host = run_ota(["device.local"], 80, None, None, firmware)
|
||||
|
||||
assert exit_code == 1
|
||||
assert host is None
|
||||
# Per-address failure is logged for each attempt; final summary follows.
|
||||
assert caplog.text.count("OTA upload to ") >= 2
|
||||
assert "OTA upload failed." in caplog.text
|
||||
|
||||
|
||||
def test_run_ota_no_resolved_addresses(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""If resolve_ip_address returns no candidates, log and return failure."""
|
||||
_patch_resolve(monkeypatch, [])
|
||||
|
||||
exit_code, host = run_ota(["192.168.1.50"], 80, None, None, firmware)
|
||||
|
||||
assert exit_code == 1
|
||||
assert host is None
|
||||
assert "Could not resolve 192.168.1.50" in caplog.text
|
||||
|
||||
|
||||
def test_run_ota_resolution_failure(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path
|
||||
) -> None:
|
||||
def _raise(*_args, **_kwargs):
|
||||
raise EsphomeError("dns failed")
|
||||
|
||||
monkeypatch.setattr("esphome.web_server_ota.resolve_ip_address", _raise)
|
||||
|
||||
exit_code, host = run_ota(["does.not.exist"], 80, None, None, firmware)
|
||||
|
||||
assert exit_code == 1
|
||||
assert host is None
|
||||
|
||||
|
||||
def test_run_ota_resolution_failure_dashboard_mode(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Dashboard mode skips the '--device <IP>' tip on resolution failure."""
|
||||
|
||||
def _raise(*_args, **_kwargs):
|
||||
raise EsphomeError("dns failed")
|
||||
|
||||
monkeypatch.setattr("esphome.web_server_ota.resolve_ip_address", _raise)
|
||||
monkeypatch.setattr(CORE, "dashboard", True)
|
||||
try:
|
||||
exit_code, host = run_ota(["does.not.exist"], 80, None, None, firmware)
|
||||
finally:
|
||||
monkeypatch.setattr(CORE, "dashboard", False)
|
||||
|
||||
assert exit_code == 1
|
||||
assert host is None
|
||||
assert "--device <IP>" not in caplog.text
|
||||
|
||||
|
||||
def test_run_ota_empty_hosts(firmware: Path) -> None:
|
||||
exit_code, host = run_ota([], 80, None, None, firmware)
|
||||
assert exit_code == 1
|
||||
assert host is None
|
||||
|
||||
|
||||
def test_run_ota_string_host_accepted(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path
|
||||
) -> None:
|
||||
"""A bare string is accepted in addition to a list of hosts."""
|
||||
_patch_resolve(monkeypatch, [("10.0.0.5", 80)])
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=_make_response(200, "Update Successful!"),
|
||||
):
|
||||
exit_code, host = run_ota("10.0.0.5", 80, None, None, firmware)
|
||||
|
||||
assert exit_code == 0
|
||||
assert host == "10.0.0.5"
|
||||
|
||||
|
||||
def test_run_ota_multiple_hosts_first_fails(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path
|
||||
) -> None:
|
||||
"""Multi-host fallthrough: first host's addresses all fail, second host wins."""
|
||||
addr_lookup = {
|
||||
"primary.local": [
|
||||
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("192.168.1.10", 80)),
|
||||
],
|
||||
"secondary.local": [
|
||||
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("192.168.1.50", 80)),
|
||||
],
|
||||
}
|
||||
|
||||
def _resolve(host, port, address_cache=None): # noqa: ARG001
|
||||
return addr_lookup[host]
|
||||
|
||||
monkeypatch.setattr("esphome.web_server_ota.resolve_ip_address", _resolve)
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
side_effect=[
|
||||
requests.ConnectionError("refused"),
|
||||
_make_response(200, "Update Successful!"),
|
||||
],
|
||||
):
|
||||
exit_code, host = run_ota(
|
||||
["primary.local", "secondary.local"], 80, None, None, firmware
|
||||
)
|
||||
|
||||
assert exit_code == 0
|
||||
assert host == "192.168.1.50"
|
||||
|
||||
|
||||
def test_run_ota_all_hosts_return_failure_no_exception(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""All hosts resolve to no addresses; run_ota cleanly returns failure."""
|
||||
addr_lookup = {
|
||||
"a.local": [],
|
||||
"b.local": [],
|
||||
}
|
||||
|
||||
def _resolve(host, port, address_cache=None): # noqa: ARG001
|
||||
return addr_lookup[host]
|
||||
|
||||
monkeypatch.setattr("esphome.web_server_ota.resolve_ip_address", _resolve)
|
||||
|
||||
exit_code, host = run_ota(["a.local", "b.local"], 80, None, None, firmware)
|
||||
|
||||
assert exit_code == 1
|
||||
assert host is None
|
||||
# Each host gets its own "Could not resolve" log line + final summary.
|
||||
assert caplog.text.count("Could not resolve") == 2
|
||||
assert "OTA upload failed." in caplog.text
|
||||
|
||||
|
||||
def test_web_server_ota_error_is_esphome_error() -> None:
|
||||
assert issubclass(WebServerOTAError, EsphomeError)
|
||||
|
||||
|
||||
def test_run_ota_finalizes_progress_bar_on_success(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path
|
||||
) -> None:
|
||||
"""progress.done() fires on the success path (finally block)."""
|
||||
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
|
||||
|
||||
done_called: list[bool] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=_make_response(200, "Update Successful!"),
|
||||
),
|
||||
patch.object(ProgressBar, "done", lambda self: done_called.append(True)),
|
||||
):
|
||||
run_ota(["192.168.1.50"], 80, None, None, firmware)
|
||||
|
||||
assert done_called
|
||||
|
||||
|
||||
def test_run_ota_finalizes_progress_bar_on_failure(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path
|
||||
) -> None:
|
||||
"""progress.done() fires when the request itself raises (finally block)."""
|
||||
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
|
||||
|
||||
done_called: list[bool] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
side_effect=requests.ConnectionError("boom"),
|
||||
),
|
||||
patch.object(ProgressBar, "done", lambda self: done_called.append(True)),
|
||||
):
|
||||
run_ota(["192.168.1.50"], 80, None, None, firmware)
|
||||
|
||||
assert done_called
|
||||
|
||||
|
||||
def test_run_ota_ipv6_url_brackets_host(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path
|
||||
) -> None:
|
||||
"""IPv6 candidates are bracketed in the URL so the port parses correctly."""
|
||||
addr_infos = [
|
||||
(socket.AF_INET6, socket.SOCK_STREAM, 0, "", ("2001:db8::1", 80, 0, 0)),
|
||||
]
|
||||
monkeypatch.setattr(
|
||||
"esphome.web_server_ota.resolve_ip_address", lambda *a, **kw: addr_infos
|
||||
)
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=_make_response(200, "Update Successful!"),
|
||||
) as post:
|
||||
exit_code, host = run_ota(["device.local"], 80, None, None, firmware)
|
||||
|
||||
assert exit_code == 0
|
||||
assert host == "2001:db8::1"
|
||||
url = post.call_args.args[0]
|
||||
assert url == f"http://[2001:db8::1]:80{OTA_PATH}"
|
||||
|
||||
|
||||
def test_run_ota_ipv6_link_local_includes_scope_id(
|
||||
monkeypatch: pytest.MonkeyPatch, firmware: Path
|
||||
) -> None:
|
||||
"""Link-local IPv6 candidates include the percent-encoded zone index."""
|
||||
addr_infos = [
|
||||
(socket.AF_INET6, socket.SOCK_STREAM, 0, "", ("fe80::1", 80, 0, 3)),
|
||||
]
|
||||
monkeypatch.setattr(
|
||||
"esphome.web_server_ota.resolve_ip_address", lambda *a, **kw: addr_infos
|
||||
)
|
||||
|
||||
with patch(
|
||||
"esphome.web_server_ota.requests.post",
|
||||
return_value=_make_response(200, "Update Successful!"),
|
||||
) as post:
|
||||
exit_code, _ = run_ota(["device.local"], 80, None, None, firmware)
|
||||
|
||||
assert exit_code == 0
|
||||
url = post.call_args.args[0]
|
||||
assert url == f"http://[fe80::1%253]:80{OTA_PATH}"
|
||||
Reference in New Issue
Block a user