[cli] Add --ota-platform flag to pick web_server or native API OTA (#16207)

This commit is contained in:
J. Nick Koston
2026-05-05 18:25:53 -05:00
committed by GitHub
parent be82e8faeb
commit f30ad588ea
4 changed files with 1307 additions and 14 deletions

View File

@@ -28,6 +28,7 @@ from esphome.const import (
ALLOWED_NAME_CHARS,
ARGUMENT_HELP_DEVICE,
CONF_API,
CONF_AUTH,
CONF_BAUD_RATE,
CONF_BROKER,
CONF_DEASSERT_RTS_DTR,
@@ -47,6 +48,8 @@ from esphome.const import (
CONF_PORT,
CONF_SUBSTITUTIONS,
CONF_TOPIC,
CONF_USERNAME,
CONF_WEB_SERVER,
ENV_NOGITIGNORE,
KEY_CORE,
KEY_NATIVE_IDF,
@@ -349,6 +352,17 @@ def choose_upload_log_host(
elif bootsel.permission_error:
bootsel_permission_error = True
# Annotate the OTA chooser entry only in the non-default case: when the
# config has web_server OTA but no native API OTA, the upload will fall
# through to the HTTP path and the user benefits from seeing that
# explicitly. The native-API path is the default and gets a plain label
# to avoid noise on the most common scenario. For LOGGING the OTA
# transport doesn't apply, so always leave the label plain.
if purpose == Purpose.UPLOADING and not has_native_ota() and has_web_server_ota():
ota_suffix = " via web_server"
else:
ota_suffix = ""
def add_ota_options() -> None:
"""Add OTA options, using mDNS discovery if name_add_mac_suffix is enabled."""
if (discovered := _discover_mac_suffix_devices()) is not None:
@@ -356,11 +370,11 @@ def choose_upload_log_host(
# intentionally skip the base-name fallback since with
# name_add_mac_suffix on, the base name doesn't exist on the net.
for host in discovered:
options.append((f"Over The Air ({host})", host))
options.append((f"Over The Air{ota_suffix} ({host})", host))
elif has_resolvable_address():
options.append((f"Over The Air ({CORE.address})", CORE.address))
options.append((f"Over The Air{ota_suffix} ({CORE.address})", CORE.address))
if has_mqtt_ip_lookup():
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
options.append((f"Over The Air{ota_suffix} (MQTT IP lookup)", "MQTTIP"))
if purpose == Purpose.LOGGING:
if has_mqtt_logging():
@@ -429,7 +443,19 @@ def has_api() -> bool:
def has_ota() -> bool:
"""Check if OTA upload is available (requires platform: esphome)."""
"""Check if any network OTA upload is available.
True if the config exposes either ``platform: esphome`` (native API
OTA) or ``platform: web_server`` (HTTP OTA). Both reach the device
over the same network stack, so the OTA discovery path treats them
interchangeably; ``upload_program`` picks the actual transport based
on ``--ota-platform`` and what's configured.
"""
return has_native_ota() or has_web_server_ota()
def has_native_ota() -> bool:
"""Check if native API OTA upload is available (``platform: esphome``)."""
if CONF_OTA not in CORE.config:
return False
return any(
@@ -438,6 +464,16 @@ def has_ota() -> bool:
)
def has_web_server_ota() -> bool:
"""Check if web_server OTA upload is available (``platform: web_server``)."""
if CONF_OTA not in CORE.config:
return False
return any(
ota_item.get(CONF_PLATFORM) == CONF_WEB_SERVER
for ota_item in CORE.config[CONF_OTA]
)
def has_mqtt_ip_lookup() -> bool:
"""Check if MQTT is available and IP lookup is supported."""
from esphome.components.mqtt import CONF_DISCOVER_IP
@@ -1115,25 +1151,83 @@ def upload_program(
return exit_code, host if exit_code == 0 else None
ota_conf = {}
requested_platform = getattr(args, "ota_platform", None)
chosen_platform = _choose_ota_platform(config, requested_platform)
# Resolve MQTT magic strings to actual IP addresses
network_devices = _resolve_network_devices(devices, config, args)
if chosen_platform == CONF_WEB_SERVER:
if getattr(args, "partition_table", False):
raise EsphomeError(
"--partition-table is only supported with the esphome OTA platform; "
"the web_server OTA path can only update the firmware image."
)
binary = CORE.firmware_bin
if getattr(args, "file", None) is not None:
binary = Path(args.file)
return _upload_via_web_server(config, network_devices, binary)
return _upload_via_native_api(config, network_devices, args)
def _choose_ota_platform(config: ConfigType, requested: str | None) -> str:
"""Pick the OTA platform to use, optionally honoring ``--ota-platform``.
Default behavior prefers ``esphome`` (native API) when it is configured.
The native API uses challenge-response auth with MD5/SHA256 hashing of a
server-issued nonce, so the password is never sent over the wire; the
``web_server`` path uses HTTP Basic auth which transmits credentials in
cleartext over the LAN. (The native path also supports gzip compression
on ESP8266, where flash space is tight; on ESP32/RP2040/LibreTiny the
backend reports ``supports_compression() == false`` and the firmware is
sent uncompressed regardless of which platform is used.) Falls back to
``web_server`` only when that is the only available platform.
"""
# Use a dict (insertion-ordered) instead of a list so error messages and
# membership checks see one entry per platform even if the user has
# multiple ``ota:`` items of the same platform; the web_server OTA
# platform's final-validate hook merges duplicates anyway.
available: dict[str, None] = {}
for ota_item in config.get(CONF_OTA, []):
if ota_item[CONF_PLATFORM] == CONF_ESPHOME:
platform = ota_item.get(CONF_PLATFORM)
if platform in (CONF_ESPHOME, CONF_WEB_SERVER):
available[platform] = None
if not available:
raise EsphomeError(
f"Cannot upload Over the Air as the {CONF_OTA} configuration is not "
f"present or does not include {CONF_PLATFORM}: {CONF_ESPHOME} or "
f"{CONF_PLATFORM}: {CONF_WEB_SERVER}"
)
if requested is not None:
if requested not in available:
raise EsphomeError(
f"--ota-platform {requested} was requested but the configuration "
f"only provides: {', '.join(available)}"
)
return requested
if CONF_ESPHOME in available:
return CONF_ESPHOME
return CONF_WEB_SERVER
def _upload_via_native_api(
config: ConfigType, network_devices: list[str], args: ArgsProtocol
) -> tuple[int, str | None]:
ota_conf: ConfigType = {}
for ota_item in config.get(CONF_OTA, []):
if ota_item.get(CONF_PLATFORM) == CONF_ESPHOME:
ota_conf = ota_item
break
if not ota_conf:
raise EsphomeError(
f"Cannot upload Over the Air as the {CONF_OTA} configuration is not present or does not include {CONF_PLATFORM}: {CONF_ESPHOME}"
)
from esphome import espota2
remote_port = int(ota_conf[CONF_PORT])
password = ota_conf.get(CONF_PASSWORD)
# Resolve MQTT magic strings to actual IP addresses
network_devices = _resolve_network_devices(devices, config, args)
binary = CORE.firmware_bin
ota_type = espota2.OTA_TYPE_UPDATE_APP
if getattr(args, "partition_table", False):
@@ -1157,6 +1251,28 @@ def upload_program(
return espota2.run_ota(network_devices, remote_port, password, binary, ota_type)
def _upload_via_web_server(
config: ConfigType, network_devices: list[str], binary: Path
) -> tuple[int, str | None]:
web_conf = config.get(CONF_WEB_SERVER)
if not web_conf:
raise EsphomeError(
f"Cannot upload via web_server OTA: the {CONF_WEB_SERVER} component "
f"is not configured."
)
remote_port = int(web_conf[CONF_PORT])
auth = web_conf.get(CONF_AUTH) or {}
username = auth.get(CONF_USERNAME)
password = auth.get(CONF_PASSWORD)
from esphome import web_server_ota
return web_server_ota.run_ota(
network_devices, remote_port, username, password, binary
)
# Layout of esp_partition_info_t on flash. Each entry is 32 bytes, leading with a
# 16-bit little-endian magic. ESP-IDF defines ESP_PARTITION_MAGIC = 0x50AA (stored as
# bytes 0xAA, 0x50) for partition entries and ESP_PARTITION_MAGIC_MD5 = 0xEBEB for the
@@ -1877,6 +1993,17 @@ def parse_args(argv):
"--file",
help="Manually specify the binary file to upload.",
)
parser_upload.add_argument(
"--ota-platform",
choices=[CONF_ESPHOME, CONF_WEB_SERVER],
help=(
"OTA platform to use for network uploads. Defaults to "
f"'{CONF_ESPHOME}' (native API) when configured because it uses "
"challenge-response auth so the password is never sent in "
f"cleartext on the wire. Falls back to '{CONF_WEB_SERVER}' "
"(HTTP Basic auth) when that is the only configured platform."
),
)
parser_upload.add_argument(
"--partition-table",
help="Upload as partition table (OTA).",
@@ -1951,6 +2078,17 @@ def parse_args(argv):
help="Build with native ESP-IDF instead of PlatformIO (ESP32 esp-idf framework only).",
action="store_true",
)
parser_run.add_argument(
"--ota-platform",
choices=[CONF_ESPHOME, CONF_WEB_SERVER],
help=(
"OTA platform to use for network uploads. Defaults to "
f"'{CONF_ESPHOME}' (native API) when configured because it uses "
"challenge-response auth so the password is never sent in "
f"cleartext on the wire. Falls back to '{CONF_WEB_SERVER}' "
"(HTTP Basic auth) when that is the only configured platform."
),
)
parser_clean = subparsers.add_parser(
"clean-mqtt",

202
esphome/web_server_ota.py Normal file
View File

@@ -0,0 +1,202 @@
"""HTTP-based OTA upload via the ``web_server`` component's ``/update`` endpoint.
This is the alternative to ``espota2`` (the native API OTA path). Useful when
a device only has ``platform: web_server`` configured under ``ota:``, or when
the user has lost the native OTA password but still has ``web_server`` basic
auth credentials.
"""
from __future__ import annotations
import io
import logging
from pathlib import Path
import secrets
import socket
from typing import BinaryIO
import requests
from requests.auth import HTTPBasicAuth
from esphome.core import EsphomeError
from esphome.helpers import ProgressBar, resolve_ip_address
_LOGGER = logging.getLogger(__name__)
OTA_PATH = "/update"
FORM_FIELD = "update"
# (connect_timeout, read_timeout). The device reboots after a successful
# upload so the read side must allow for a slow flash + response.
TIMEOUT = (20.0, 120.0)
class WebServerOTAError(EsphomeError):
pass
class _MultipartStreamer:
"""Stream a single-file multipart/form-data body during transmission.
``requests.post(files=...)`` materializes the entire body in memory before
sending, so a progress callback wired into the file-like fires during
encoding instead of during the network send. Pass this via ``data=``
(with ``__len__`` so urllib3 sets ``Content-Length`` instead of using
chunked transfer encoding); urllib3 then calls ``read(blocksize)``
repeatedly during the POST and the progress bar tracks bytes leaving the
host.
"""
def __init__(self, file: BinaryIO, file_size: int, filename: str) -> None:
self.boundary = f"esphomeOTA{secrets.token_hex(16)}"
prefix = (
f"--{self.boundary}\r\n"
f'Content-Disposition: form-data; name="{FORM_FIELD}"; '
f'filename="{filename}"\r\n'
f"Content-Type: application/octet-stream\r\n\r\n"
).encode()
suffix = f"\r\n--{self.boundary}--\r\n".encode()
# Walked in order; ``read()`` advances to the next source on EOF.
self._sources: list[BinaryIO] = [io.BytesIO(prefix), file, io.BytesIO(suffix)]
self._idx = 0
self._total = len(prefix) + file_size + len(suffix)
self._sent = 0
self.progress = ProgressBar()
def __len__(self) -> int:
return self._total
@property
def content_type(self) -> str:
return f"multipart/form-data; boundary={self.boundary}"
def read(self, size: int = -1) -> bytes:
remaining = self._total if size is None or size < 0 else size
out = bytearray()
while remaining > 0 and self._idx < len(self._sources):
chunk = self._sources[self._idx].read(remaining)
if not chunk:
self._idx += 1
continue
out += chunk
remaining -= len(chunk)
if out:
self._sent += len(out)
self.progress.update(self._sent / self._total)
return bytes(out)
def _try_upload(
host: str,
port: int,
username: str | None,
password: str | None,
filename: Path,
) -> tuple[int, str | None]:
from esphome.core import CORE
try:
addr_infos = resolve_ip_address(host, port, address_cache=CORE.address_cache)
except EsphomeError as err:
_LOGGER.error(
"Error resolving IP address of %s. Is it connected to WiFi?", host
)
if not CORE.dashboard:
_LOGGER.error("(If you know the IP, try --device <IP>)")
raise WebServerOTAError(err) from err
if not addr_infos:
_LOGGER.error("Could not resolve %s", host)
return 1, None
file_size = filename.stat().st_size
_LOGGER.info("Uploading %s (%s bytes) via web_server OTA", filename, file_size)
auth = HTTPBasicAuth(username, password) if username and password else None
# Iterate resolved IPs (IPv4 + IPv6 candidates) just like espota2 does.
for af, _socktype, _, _, sa in addr_infos:
ip = sa[0]
# IPv6 literals must be wrapped in brackets in URLs; link-local
# addresses need a percent-encoded zone index per RFC 6874.
if af == socket.AF_INET6:
scope = sa[3] if len(sa) >= 4 else 0
host_part = f"[{ip}%25{scope}]" if scope else f"[{ip}]"
else:
host_part = ip
url = f"http://{host_part}:{port}{OTA_PATH}"
_LOGGER.info("Connecting to %s port %s...", ip, port)
try:
with open(filename, "rb") as fh:
streamer = _MultipartStreamer(fh, file_size, filename.name)
try:
response = requests.post(
url,
data=streamer,
auth=auth,
timeout=TIMEOUT,
headers={
"Content-Type": streamer.content_type,
"Connection": "close",
},
)
finally:
streamer.progress.done()
except requests.RequestException as err:
_LOGGER.error("OTA upload to %s port %s failed: %s", ip, port, err)
continue
if response.status_code == 401:
raise WebServerOTAError(
"Authentication failed (HTTP 401). Check the 'web_server' "
"'auth' username and password."
)
if response.status_code != 200:
detail = response.text.strip() or response.reason or "no response body"
raise WebServerOTAError(
f"Unexpected HTTP {response.status_code} response from device: {detail}"
)
# The endpoint returns HTTP 200 for both success and failure; the
# body is what tells us which (see ota_web_server.cpp handleRequest).
body = response.text.strip()
if "Successful" in body:
_LOGGER.info("Device response: %s", body)
_LOGGER.info("OTA successful")
return 0, ip
raise WebServerOTAError(
f"Device reported OTA failure: {body or 'no response body'}"
)
return 1, None
def run_ota(
remote_hosts: str | list[str],
remote_port: int,
username: str | None,
password: str | None,
filename: Path,
) -> tuple[int, str | None]:
"""Upload ``filename`` to the first reachable host via ``web_server`` OTA.
Mirrors :func:`esphome.espota2.run_ota` so callers can swap between the
two paths with the same return contract: ``(0, host)`` on success or
``(1, None)`` on failure.
"""
hosts = [remote_hosts] if isinstance(remote_hosts, str) else list(remote_hosts)
for host in hosts:
try:
exit_code, used_host = _try_upload(
host, remote_port, username, password, filename
)
except WebServerOTAError as err:
_LOGGER.error("%s", err)
continue
if exit_code == 0:
return 0, used_host
# Reached only when every attempt failed; per-attempt errors were
# already logged. This summary line gives the user an unambiguous
# "stop reading, nothing worked" marker.
_LOGGER.error("OTA upload failed.")
return 1, None

View File

@@ -43,6 +43,7 @@ from esphome.__main__ import (
has_non_ip_address,
has_ota,
has_resolvable_address,
has_web_server_ota,
mqtt_get_ip,
run_esphome,
run_miniterm,
@@ -58,6 +59,7 @@ from esphome.components import esp32
from esphome.components.esp32 import KEY_ESP32, KEY_VARIANT, VARIANT_ESP32
from esphome.const import (
CONF_API,
CONF_AUTH,
CONF_BAUD_RATE,
CONF_BROKER,
CONF_DISABLED,
@@ -76,6 +78,8 @@ from esphome.const import (
CONF_SUBSTITUTIONS,
CONF_TOPIC,
CONF_USE_ADDRESS,
CONF_USERNAME,
CONF_WEB_SERVER,
CONF_WIFI,
KEY_CORE,
KEY_TARGET_PLATFORM,
@@ -213,6 +217,13 @@ def mock_run_ota() -> Generator[Mock]:
yield mock
@pytest.fixture
def mock_run_web_server_ota() -> Generator[Mock]:
"""Mock web_server_ota.run_ota for testing."""
with patch("esphome.web_server_ota.run_ota") as mock:
yield mock
@pytest.fixture
def mock_is_ip_address() -> Generator[Mock]:
"""Mock is_ip_address for testing."""
@@ -1114,6 +1125,7 @@ class MockArgs:
reset: bool = False
list_only: bool = False
output: str | None = None
ota_platform: str | None = None
partition_table: bool = False
@@ -1878,6 +1890,277 @@ def test_upload_program_ota_no_config(
upload_program(config, args, devices)
def test_has_web_server_ota_detects_platform() -> None:
"""has_web_server_ota returns True when web_server OTA platform is configured."""
setup_core(
config={
CONF_OTA: [{CONF_PLATFORM: CONF_WEB_SERVER}],
}
)
assert has_web_server_ota() is True
assert has_ota() is True
def test_has_web_server_ota_returns_false_without_config() -> None:
"""has_web_server_ota returns False when only native OTA is configured."""
setup_core(
config={
CONF_OTA: [{CONF_PLATFORM: CONF_ESPHOME}],
}
)
assert has_web_server_ota() is False
assert has_ota() is True
def test_upload_program_web_server_only_auto_dispatches(
mock_run_web_server_ota: Mock,
mock_run_ota: Mock,
mock_get_port_type: Mock,
tmp_path: Path,
) -> None:
"""When only web_server OTA is configured, upload_program picks it automatically."""
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
mock_get_port_type.return_value = "NETWORK"
mock_run_web_server_ota.return_value = (0, "192.168.1.100")
config = {
CONF_OTA: [{CONF_PLATFORM: CONF_WEB_SERVER}],
CONF_WEB_SERVER: {
CONF_PORT: 80,
CONF_AUTH: {CONF_USERNAME: "admin", CONF_PASSWORD: "pw"},
},
}
args = MockArgs()
devices = ["192.168.1.100"]
exit_code, host = upload_program(config, args, devices)
assert exit_code == 0
assert host == "192.168.1.100"
expected_firmware = (
tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin"
)
mock_run_web_server_ota.assert_called_once_with(
["192.168.1.100"], 80, "admin", "pw", expected_firmware
)
mock_run_ota.assert_not_called()
def test_upload_program_web_server_no_auth(
mock_run_web_server_ota: Mock,
mock_get_port_type: Mock,
tmp_path: Path,
) -> None:
"""web_server OTA works without an auth block (passes None for credentials)."""
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
mock_get_port_type.return_value = "NETWORK"
mock_run_web_server_ota.return_value = (0, "192.168.1.100")
config = {
CONF_OTA: [{CONF_PLATFORM: CONF_WEB_SERVER}],
CONF_WEB_SERVER: {CONF_PORT: 8080},
}
args = MockArgs()
devices = ["192.168.1.100"]
exit_code, host = upload_program(config, args, devices)
assert exit_code == 0
assert host == "192.168.1.100"
expected_firmware = (
tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin"
)
mock_run_web_server_ota.assert_called_once_with(
["192.168.1.100"], 8080, None, None, expected_firmware
)
def test_upload_program_both_platforms_default_prefers_native(
mock_run_ota: Mock,
mock_run_web_server_ota: Mock,
mock_get_port_type: Mock,
tmp_path: Path,
) -> None:
"""When both OTA platforms are configured, default selection is native API."""
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
mock_get_port_type.return_value = "NETWORK"
mock_run_ota.return_value = (0, "192.168.1.100")
config = {
CONF_OTA: [
{
CONF_PLATFORM: CONF_ESPHOME,
CONF_PORT: 3232,
CONF_PASSWORD: "secret",
},
{CONF_PLATFORM: CONF_WEB_SERVER},
],
CONF_WEB_SERVER: {CONF_PORT: 80},
}
args = MockArgs()
devices = ["192.168.1.100"]
exit_code, host = upload_program(config, args, devices)
assert exit_code == 0
assert host == "192.168.1.100"
mock_run_ota.assert_called_once()
mock_run_web_server_ota.assert_not_called()
def test_upload_program_ota_platform_override_to_web_server(
mock_run_ota: Mock,
mock_run_web_server_ota: Mock,
mock_get_port_type: Mock,
tmp_path: Path,
) -> None:
"""--ota-platform web_server forces web_server OTA even when native is configured."""
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
mock_get_port_type.return_value = "NETWORK"
mock_run_web_server_ota.return_value = (0, "192.168.1.100")
config = {
CONF_OTA: [
{
CONF_PLATFORM: CONF_ESPHOME,
CONF_PORT: 3232,
CONF_PASSWORD: "secret",
},
{CONF_PLATFORM: CONF_WEB_SERVER},
],
CONF_WEB_SERVER: {CONF_PORT: 80},
}
args = MockArgs(ota_platform=CONF_WEB_SERVER)
devices = ["192.168.1.100"]
exit_code, host = upload_program(config, args, devices)
assert exit_code == 0
assert host == "192.168.1.100"
mock_run_ota.assert_not_called()
mock_run_web_server_ota.assert_called_once()
def test_upload_program_ota_platform_unavailable(
mock_get_port_type: Mock,
) -> None:
"""--ota-platform must reference a platform that is actually configured."""
setup_core(platform=PLATFORM_ESP32)
mock_get_port_type.return_value = "NETWORK"
config = {
CONF_OTA: [
{
CONF_PLATFORM: CONF_ESPHOME,
CONF_PORT: 3232,
CONF_PASSWORD: "secret",
}
],
}
args = MockArgs(ota_platform=CONF_WEB_SERVER)
devices = ["192.168.1.100"]
with pytest.raises(EsphomeError, match="--ota-platform web_server"):
upload_program(config, args, devices)
def test_upload_program_web_server_missing_component(
mock_get_port_type: Mock,
tmp_path: Path,
) -> None:
"""web_server OTA without a web_server component fails with a clear error."""
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
mock_get_port_type.return_value = "NETWORK"
config = {
CONF_OTA: [{CONF_PLATFORM: CONF_WEB_SERVER}],
# No CONF_WEB_SERVER
}
args = MockArgs()
devices = ["192.168.1.100"]
with pytest.raises(EsphomeError, match="web_server.*not configured"):
upload_program(config, args, devices)
def test_upload_program_unrelated_ota_platform_ignored(
mock_run_ota: Mock,
mock_get_port_type: Mock,
tmp_path: Path,
) -> None:
"""OTA list entries that are neither esphome nor web_server are ignored.
Covers the false branch in _choose_ota_platform's filter loop and the
no-match branch in _upload_via_native_api's lookup loop.
"""
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
mock_get_port_type.return_value = "NETWORK"
mock_run_ota.return_value = (0, "192.168.1.100")
config = {
CONF_OTA: [
{CONF_PLATFORM: "http_request"}, # unrelated platform; ignored
{
CONF_PLATFORM: CONF_ESPHOME,
CONF_PORT: 3232,
CONF_PASSWORD: "secret",
},
],
}
args = MockArgs()
devices = ["192.168.1.100"]
exit_code, host = upload_program(config, args, devices)
assert exit_code == 0
assert host == "192.168.1.100"
mock_run_ota.assert_called_once()
def test_upload_program_duplicate_platform_dedup_in_error(
mock_get_port_type: Mock,
tmp_path: Path,
) -> None:
"""Duplicate same-platform OTA entries don't repeat in --ota-platform errors."""
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
mock_get_port_type.return_value = "NETWORK"
config = {
CONF_OTA: [
{CONF_PLATFORM: CONF_ESPHOME, CONF_PORT: 3232},
{CONF_PLATFORM: CONF_ESPHOME, CONF_PORT: 3233},
],
}
args = MockArgs(ota_platform=CONF_WEB_SERVER)
devices = ["192.168.1.100"]
with pytest.raises(EsphomeError) as excinfo:
upload_program(config, args, devices)
# Error mentions esphome once in the platform list, not "esphome, esphome".
msg = str(excinfo.value)
assert "esphome, esphome" not in msg
assert msg.endswith(": esphome")
def test_upload_program_only_unrelated_ota_platforms(
mock_get_port_type: Mock,
tmp_path: Path,
) -> None:
"""Only unrelated OTA platforms configured -> raises like missing OTA."""
setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path)
mock_get_port_type.return_value = "NETWORK"
config = {
CONF_OTA: [{CONF_PLATFORM: "http_request"}],
}
args = MockArgs()
devices = ["192.168.1.100"]
with pytest.raises(EsphomeError, match="Cannot upload Over the Air"):
upload_program(config, args, devices)
def test_upload_program_ota_with_mqtt_resolution(
mock_mqtt_get_ip: Mock,
mock_is_ip_address: Mock,

View File

@@ -0,0 +1,670 @@
"""Unit tests for esphome.web_server_ota module."""
from __future__ import annotations
import io
import logging
from pathlib import Path
import socket
from unittest.mock import MagicMock, patch
import pytest
import requests
from requests.auth import HTTPBasicAuth
from esphome.core import CORE, EsphomeError
from esphome.helpers import ProgressBar
from esphome.web_server_ota import (
OTA_PATH,
WebServerOTAError,
_MultipartStreamer,
run_ota,
)
@pytest.fixture
def firmware(tmp_path: Path) -> Path:
binary = tmp_path / "firmware.bin"
binary.write_bytes(b"\x00\x01\x02FIRMWARE\xff" * 64)
return binary
def _make_response(status: int, body: str) -> MagicMock:
response = MagicMock(spec=requests.Response)
response.status_code = status
response.text = body
response.reason = ""
return response
def _patch_resolve(
monkeypatch: pytest.MonkeyPatch, hosts: list[tuple[str, int]]
) -> None:
"""Replace resolve_ip_address so tests don't actually do DNS."""
addr_infos = [
(socket.AF_INET, socket.SOCK_STREAM, 0, "", (host, port))
for host, port in hosts
]
monkeypatch.setattr(
"esphome.web_server_ota.resolve_ip_address", lambda *a, **kw: addr_infos
)
# ---------------------------------------------------------------------------
# _MultipartStreamer
# ---------------------------------------------------------------------------
def test_multipart_streamer_emits_full_body() -> None:
"""Streaming the whole body in one call yields prefix + file + suffix."""
data = b"abcdef" * 100
streamer = _MultipartStreamer(io.BytesIO(data), len(data), "fw.bin")
body = streamer.read()
while True:
chunk = streamer.read()
if not chunk:
break
body += chunk
assert body.startswith(f"--{streamer.boundary}\r\n".encode())
assert b'name="update"' in body
assert b'filename="fw.bin"' in body
assert data in body
assert body.endswith(f"\r\n--{streamer.boundary}--\r\n".encode())
def test_multipart_streamer_chunked_read_matches_full_read() -> None:
"""Chunked reads (urllib3 calls read(8192) repeatedly) yield the same body."""
data = b"abcdef" * 1000 # 6000 bytes
full = _MultipartStreamer(io.BytesIO(data), len(data), "fw.bin").read()
streamed = bytearray()
s = _MultipartStreamer(io.BytesIO(data), len(data), "fw.bin")
# Same boundary lengths -> identical total length.
while True:
chunk = s.read(64)
if not chunk:
break
streamed += chunk
# Boundaries are random per instance, so compare lengths and structure.
assert len(streamed) == len(full)
assert streamed.startswith(f"--{s.boundary}\r\n".encode())
assert streamed.endswith(f"\r\n--{s.boundary}--\r\n".encode())
def test_multipart_streamer_len_matches_emitted_bytes() -> None:
"""``__len__`` is what urllib3 uses to set Content-Length, so it must
equal the total bytes emitted by ``read``."""
data = b"x" * 12345
s = _MultipartStreamer(io.BytesIO(data), len(data), "fw.bin")
declared = len(s)
emitted = 0
while True:
chunk = s.read(1024)
if not chunk:
break
emitted += len(chunk)
assert emitted == declared
def test_multipart_streamer_progress_ticks_during_read() -> None:
"""Each read advances the progress bar (this is the whole point of
streaming via ``data=``: progress reflects bytes leaving the host)."""
data = b"x" * 1000
s = _MultipartStreamer(io.BytesIO(data), len(data), "fw.bin")
updates: list[float] = []
s.progress.update = updates.append # type: ignore[method-assign]
while True:
chunk = s.read(128)
if not chunk:
break
assert updates, "progress.update was never called"
# Strictly non-decreasing.
assert updates == sorted(updates)
# Final update reaches (within FP) 1.0 because all bytes were read.
assert updates[-1] == pytest.approx(1.0, abs=1e-9)
def test_multipart_streamer_content_type_includes_boundary() -> None:
s = _MultipartStreamer(io.BytesIO(b""), 0, "fw.bin")
assert s.content_type == f"multipart/form-data; boundary={s.boundary}"
def test_multipart_streamer_zero_size_file() -> None:
"""A zero-byte file still produces a well-formed body and progress is
skipped (avoiding a divide-by-zero on the empty file segment)."""
s = _MultipartStreamer(io.BytesIO(b""), 0, "empty.bin")
body = b""
while True:
chunk = s.read(64)
if not chunk:
break
body += chunk
assert body.startswith(f"--{s.boundary}".encode())
assert body.endswith(f"--{s.boundary}--\r\n".encode())
def test_multipart_streamer_unique_boundary_per_instance() -> None:
a = _MultipartStreamer(io.BytesIO(b""), 0, "a")
b = _MultipartStreamer(io.BytesIO(b""), 0, "a")
assert a.boundary != b.boundary
def test_multipart_streamer_zero_size_read_returns_empty() -> None:
"""``read(0)`` short-circuits without touching state."""
s = _MultipartStreamer(io.BytesIO(b"x" * 10), 10, "fw.bin")
assert s.read(0) == b""
# No bytes consumed.
assert s._sent == 0
# ---------------------------------------------------------------------------
# run_ota
# ---------------------------------------------------------------------------
def test_run_ota_success(monkeypatch: pytest.MonkeyPatch, firmware: Path) -> None:
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
with patch(
"esphome.web_server_ota.requests.post",
return_value=_make_response(200, "Update Successful!"),
) as post:
exit_code, host = run_ota(["device.local"], 80, None, None, firmware)
assert exit_code == 0
assert host == "192.168.1.50"
post.assert_called_once()
args, kwargs = post.call_args
assert args == (f"http://192.168.1.50:80{OTA_PATH}",)
assert kwargs["auth"] is None
# Streaming body, not files=, so progress fires during transmission.
assert "files" not in kwargs
assert isinstance(kwargs["data"], _MultipartStreamer)
assert kwargs["headers"]["Content-Type"] == kwargs["data"].content_type
assert kwargs["headers"]["Connection"] == "close"
def test_run_ota_logs_device_response_body(
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""The device's HTTP response body is surfaced on success."""
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
caplog.set_level(logging.INFO, logger="esphome.web_server_ota")
with patch(
"esphome.web_server_ota.requests.post",
return_value=_make_response(200, "Update Successful!"),
):
run_ota(["192.168.1.50"], 80, None, None, firmware)
assert "Device response: Update Successful!" in caplog.text
assert "OTA successful" in caplog.text
def test_run_ota_log_says_via_web_server(
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""The upload-start log line names the transport explicitly."""
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
caplog.set_level(logging.INFO, logger="esphome.web_server_ota")
with patch(
"esphome.web_server_ota.requests.post",
return_value=_make_response(200, "Update Successful!"),
):
run_ota(["192.168.1.50"], 80, None, None, firmware)
assert "via web_server OTA" in caplog.text
def test_run_ota_sends_basic_auth(
monkeypatch: pytest.MonkeyPatch, firmware: Path
) -> None:
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
with patch(
"esphome.web_server_ota.requests.post",
return_value=_make_response(200, "Update Successful!"),
) as post:
exit_code, _ = run_ota(["192.168.1.50"], 80, "admin", "secret", firmware)
assert exit_code == 0
auth = post.call_args.kwargs["auth"]
assert isinstance(auth, HTTPBasicAuth)
assert auth.username == "admin"
assert auth.password == "secret"
def test_run_ota_skips_auth_when_no_credentials(
monkeypatch: pytest.MonkeyPatch, firmware: Path
) -> None:
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
with patch(
"esphome.web_server_ota.requests.post",
return_value=_make_response(200, "Update Successful!"),
) as post:
run_ota(["192.168.1.50"], 80, None, None, firmware)
assert post.call_args.kwargs["auth"] is None
def test_run_ota_skips_auth_when_only_username(
monkeypatch: pytest.MonkeyPatch, firmware: Path
) -> None:
"""Both username and password are required to send Basic auth."""
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
with patch(
"esphome.web_server_ota.requests.post",
return_value=_make_response(200, "Update Successful!"),
) as post:
run_ota(["192.168.1.50"], 80, "admin", None, firmware)
assert post.call_args.kwargs["auth"] is None
def test_run_ota_uses_update_url(
monkeypatch: pytest.MonkeyPatch, firmware: Path
) -> None:
_patch_resolve(monkeypatch, [("192.168.1.50", 8080)])
with patch(
"esphome.web_server_ota.requests.post",
return_value=_make_response(200, "Update Successful!"),
) as post:
run_ota(["192.168.1.50"], 8080, None, None, firmware)
url = post.call_args.args[0]
assert url == f"http://192.168.1.50:8080{OTA_PATH}"
assert OTA_PATH == "/update"
def test_run_ota_failure_response(
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
) -> None:
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
with patch(
"esphome.web_server_ota.requests.post",
return_value=_make_response(200, "Update Failed!"),
):
exit_code, host = run_ota(["192.168.1.50"], 80, None, None, firmware)
assert exit_code == 1
assert host is None
assert "OTA failure" in caplog.text
def test_run_ota_failure_response_empty_body(
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
) -> None:
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
with patch(
"esphome.web_server_ota.requests.post",
return_value=_make_response(200, ""),
):
exit_code, host = run_ota(["192.168.1.50"], 80, None, None, firmware)
assert exit_code == 1
assert host is None
assert "no response body" in caplog.text
def test_run_ota_auth_failed(
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
) -> None:
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
with patch(
"esphome.web_server_ota.requests.post",
return_value=_make_response(401, "Unauthorized"),
):
exit_code, host = run_ota(["192.168.1.50"], 80, "user", "wrong", firmware)
assert exit_code == 1
assert host is None
assert "Authentication failed" in caplog.text
def test_run_ota_unexpected_status_code(
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
) -> None:
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
with patch(
"esphome.web_server_ota.requests.post",
return_value=_make_response(500, "Internal Error"),
):
exit_code, host = run_ota(["192.168.1.50"], 80, None, None, firmware)
assert exit_code == 1
assert host is None
assert "Unexpected HTTP 500" in caplog.text
def test_run_ota_unexpected_status_empty_body_falls_back(
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Empty response body uses response.reason / a fallback in the error."""
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
response = _make_response(503, "")
response.reason = "Service Unavailable"
with patch(
"esphome.web_server_ota.requests.post",
return_value=response,
):
exit_code, host = run_ota(["192.168.1.50"], 80, None, None, firmware)
assert exit_code == 1
assert host is None
assert "Service Unavailable" in caplog.text
def test_run_ota_unexpected_status_no_body_no_reason(
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Empty body and empty reason still produce a usable error message."""
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
response = _make_response(599, "")
response.reason = ""
with patch(
"esphome.web_server_ota.requests.post",
return_value=response,
):
run_ota(["192.168.1.50"], 80, None, None, firmware)
assert "no response body" in caplog.text
def test_run_ota_connection_error_then_success(
monkeypatch: pytest.MonkeyPatch, firmware: Path
) -> None:
"""First resolved address fails to connect, second succeeds."""
_patch_resolve(
monkeypatch,
[("192.168.1.10", 80), ("192.168.1.50", 80)],
)
with patch(
"esphome.web_server_ota.requests.post",
side_effect=[
requests.ConnectionError("refused"),
_make_response(200, "Update Successful!"),
],
) as post:
exit_code, host = run_ota(["device.local"], 80, None, None, firmware)
assert exit_code == 0
assert host == "192.168.1.50"
assert post.call_count == 2
def test_run_ota_request_exception_falls_through(
monkeypatch: pytest.MonkeyPatch, firmware: Path
) -> None:
"""A non-ConnectionError RequestException (e.g. timeout) falls through too."""
_patch_resolve(
monkeypatch,
[("192.168.1.10", 80), ("192.168.1.50", 80)],
)
with patch(
"esphome.web_server_ota.requests.post",
side_effect=[
requests.Timeout("read timeout"),
_make_response(200, "Update Successful!"),
],
):
exit_code, host = run_ota(["device.local"], 80, None, None, firmware)
assert exit_code == 0
assert host == "192.168.1.50"
def test_run_ota_all_addresses_unreachable(
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""When every resolved address fails to connect, run_ota returns failure."""
_patch_resolve(
monkeypatch,
[("192.168.1.10", 80), ("192.168.1.20", 80)],
)
with patch(
"esphome.web_server_ota.requests.post",
side_effect=requests.ConnectionError("refused"),
):
exit_code, host = run_ota(["device.local"], 80, None, None, firmware)
assert exit_code == 1
assert host is None
# Per-address failure is logged for each attempt; final summary follows.
assert caplog.text.count("OTA upload to ") >= 2
assert "OTA upload failed." in caplog.text
def test_run_ota_no_resolved_addresses(
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""If resolve_ip_address returns no candidates, log and return failure."""
_patch_resolve(monkeypatch, [])
exit_code, host = run_ota(["192.168.1.50"], 80, None, None, firmware)
assert exit_code == 1
assert host is None
assert "Could not resolve 192.168.1.50" in caplog.text
def test_run_ota_resolution_failure(
monkeypatch: pytest.MonkeyPatch, firmware: Path
) -> None:
def _raise(*_args, **_kwargs):
raise EsphomeError("dns failed")
monkeypatch.setattr("esphome.web_server_ota.resolve_ip_address", _raise)
exit_code, host = run_ota(["does.not.exist"], 80, None, None, firmware)
assert exit_code == 1
assert host is None
def test_run_ota_resolution_failure_dashboard_mode(
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Dashboard mode skips the '--device <IP>' tip on resolution failure."""
def _raise(*_args, **_kwargs):
raise EsphomeError("dns failed")
monkeypatch.setattr("esphome.web_server_ota.resolve_ip_address", _raise)
monkeypatch.setattr(CORE, "dashboard", True)
try:
exit_code, host = run_ota(["does.not.exist"], 80, None, None, firmware)
finally:
monkeypatch.setattr(CORE, "dashboard", False)
assert exit_code == 1
assert host is None
assert "--device <IP>" not in caplog.text
def test_run_ota_empty_hosts(firmware: Path) -> None:
exit_code, host = run_ota([], 80, None, None, firmware)
assert exit_code == 1
assert host is None
def test_run_ota_string_host_accepted(
monkeypatch: pytest.MonkeyPatch, firmware: Path
) -> None:
"""A bare string is accepted in addition to a list of hosts."""
_patch_resolve(monkeypatch, [("10.0.0.5", 80)])
with patch(
"esphome.web_server_ota.requests.post",
return_value=_make_response(200, "Update Successful!"),
):
exit_code, host = run_ota("10.0.0.5", 80, None, None, firmware)
assert exit_code == 0
assert host == "10.0.0.5"
def test_run_ota_multiple_hosts_first_fails(
monkeypatch: pytest.MonkeyPatch, firmware: Path
) -> None:
"""Multi-host fallthrough: first host's addresses all fail, second host wins."""
addr_lookup = {
"primary.local": [
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("192.168.1.10", 80)),
],
"secondary.local": [
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("192.168.1.50", 80)),
],
}
def _resolve(host, port, address_cache=None): # noqa: ARG001
return addr_lookup[host]
monkeypatch.setattr("esphome.web_server_ota.resolve_ip_address", _resolve)
with patch(
"esphome.web_server_ota.requests.post",
side_effect=[
requests.ConnectionError("refused"),
_make_response(200, "Update Successful!"),
],
):
exit_code, host = run_ota(
["primary.local", "secondary.local"], 80, None, None, firmware
)
assert exit_code == 0
assert host == "192.168.1.50"
def test_run_ota_all_hosts_return_failure_no_exception(
monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""All hosts resolve to no addresses; run_ota cleanly returns failure."""
addr_lookup = {
"a.local": [],
"b.local": [],
}
def _resolve(host, port, address_cache=None): # noqa: ARG001
return addr_lookup[host]
monkeypatch.setattr("esphome.web_server_ota.resolve_ip_address", _resolve)
exit_code, host = run_ota(["a.local", "b.local"], 80, None, None, firmware)
assert exit_code == 1
assert host is None
# Each host gets its own "Could not resolve" log line + final summary.
assert caplog.text.count("Could not resolve") == 2
assert "OTA upload failed." in caplog.text
def test_web_server_ota_error_is_esphome_error() -> None:
assert issubclass(WebServerOTAError, EsphomeError)
def test_run_ota_finalizes_progress_bar_on_success(
monkeypatch: pytest.MonkeyPatch, firmware: Path
) -> None:
"""progress.done() fires on the success path (finally block)."""
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
done_called: list[bool] = []
with (
patch(
"esphome.web_server_ota.requests.post",
return_value=_make_response(200, "Update Successful!"),
),
patch.object(ProgressBar, "done", lambda self: done_called.append(True)),
):
run_ota(["192.168.1.50"], 80, None, None, firmware)
assert done_called
def test_run_ota_finalizes_progress_bar_on_failure(
monkeypatch: pytest.MonkeyPatch, firmware: Path
) -> None:
"""progress.done() fires when the request itself raises (finally block)."""
_patch_resolve(monkeypatch, [("192.168.1.50", 80)])
done_called: list[bool] = []
with (
patch(
"esphome.web_server_ota.requests.post",
side_effect=requests.ConnectionError("boom"),
),
patch.object(ProgressBar, "done", lambda self: done_called.append(True)),
):
run_ota(["192.168.1.50"], 80, None, None, firmware)
assert done_called
def test_run_ota_ipv6_url_brackets_host(
monkeypatch: pytest.MonkeyPatch, firmware: Path
) -> None:
"""IPv6 candidates are bracketed in the URL so the port parses correctly."""
addr_infos = [
(socket.AF_INET6, socket.SOCK_STREAM, 0, "", ("2001:db8::1", 80, 0, 0)),
]
monkeypatch.setattr(
"esphome.web_server_ota.resolve_ip_address", lambda *a, **kw: addr_infos
)
with patch(
"esphome.web_server_ota.requests.post",
return_value=_make_response(200, "Update Successful!"),
) as post:
exit_code, host = run_ota(["device.local"], 80, None, None, firmware)
assert exit_code == 0
assert host == "2001:db8::1"
url = post.call_args.args[0]
assert url == f"http://[2001:db8::1]:80{OTA_PATH}"
def test_run_ota_ipv6_link_local_includes_scope_id(
monkeypatch: pytest.MonkeyPatch, firmware: Path
) -> None:
"""Link-local IPv6 candidates include the percent-encoded zone index."""
addr_infos = [
(socket.AF_INET6, socket.SOCK_STREAM, 0, "", ("fe80::1", 80, 0, 3)),
]
monkeypatch.setattr(
"esphome.web_server_ota.resolve_ip_address", lambda *a, **kw: addr_infos
)
with patch(
"esphome.web_server_ota.requests.post",
return_value=_make_response(200, "Update Successful!"),
) as post:
exit_code, _ = run_ota(["device.local"], 80, None, None, firmware)
assert exit_code == 0
url = post.call_args.args[0]
assert url == f"http://[fe80::1%253]:80{OTA_PATH}"