[core] Allow finding all devices as target that match mac suffix (#13135)

This commit is contained in:
Paulus Schoutsen
2026-04-23 09:43:32 -04:00
committed by GitHub
parent 70ae614abd
commit 9b45b046a8
8 changed files with 912 additions and 70 deletions

View File

@@ -39,6 +39,7 @@ from esphome.const import (
CONF_MDNS,
CONF_MQTT,
CONF_NAME,
CONF_NAME_ADD_MAC_SUFFIX,
CONF_OTA,
CONF_PASSWORD,
CONF_PLATFORM,
@@ -71,6 +72,7 @@ from esphome.util import (
run_external_process,
safe_print,
)
from esphome.zeroconf import discover_mdns_devices
_LOGGER = logging.getLogger(__name__)
@@ -204,6 +206,64 @@ def _resolve_with_cache(address: str, purpose: Purpose) -> list[str]:
return [address]
def _populate_mdns_cache(hosts_to_addresses: dict[str, list[str]]) -> None:
"""Store discovered ``host -> [ips]`` entries in ``CORE.address_cache``.
Ensures ``CORE.address_cache`` exists, then records each mDNS hostname so
the downstream resolution path (``resolve_ip_address``) can skip opening a
second Zeroconf client.
"""
from esphome.address_cache import AddressCache
if CORE.address_cache is None:
CORE.address_cache = AddressCache()
for host, addresses in hosts_to_addresses.items():
if addresses:
_LOGGER.debug("Caching mDNS result %s -> %s", host, addresses)
CORE.address_cache.add_mdns_addresses(host, addresses)
def _discover_mac_suffix_devices() -> list[str] | None:
"""Discover ``<name>-<mac>.local`` devices and cache their IPs.
Returns:
- ``None`` when discovery isn't applicable (``name_add_mac_suffix`` off,
mDNS disabled, or ``CORE.address`` is already an IP). Callers should
then fall back to whatever default OTA address they normally use.
- ``[]`` when discovery ran but found nothing. Callers should NOT fall
back to the base name: with ``name_add_mac_suffix`` enabled, the base
name by definition doesn't exist on the network.
- A non-empty sorted list of ``.local`` hostnames on success.
Populates ``CORE.address_cache`` so downstream resolution (``espota2`` or
``aioesphomeapi`` via :func:`_resolve_network_devices`) reuses the IPs we
already have without opening a second Zeroconf client.
"""
if not (has_name_add_mac_suffix() and has_mdns() and has_non_ip_address()):
return None
_LOGGER.info("Discovering devices...")
if not (discovered := discover_mdns_devices(CORE.name)):
_LOGGER.warning(
"No devices matching '%s-<mac>.local' were discovered.", CORE.name
)
return []
_populate_mdns_cache(discovered)
return list(discovered)
def _ota_hostnames_for_default(purpose: Purpose) -> list[str]:
"""Return OTA hostname(s) for the ``--device OTA`` / default-resolve path.
When ``name_add_mac_suffix`` is enabled, returns discovered
``<name>-<mac>.local`` hostnames (possibly empty — in which case the
caller should not fall back to the base name). Otherwise falls back to
the cache-resolved ``CORE.address``.
"""
if (discovered := _discover_mac_suffix_devices()) is not None:
return discovered
return _resolve_with_cache(CORE.address, purpose)
def choose_upload_log_host(
default: list[str] | str | None,
check_default: str | None,
@@ -242,14 +302,14 @@ def choose_upload_log_host(
resolved.append("MQTT")
if has_api() and has_non_ip_address() and has_resolvable_address():
resolved.extend(_resolve_with_cache(CORE.address, purpose))
resolved.extend(_ota_hostnames_for_default(purpose))
elif purpose == Purpose.UPLOADING:
if has_ota() and has_mqtt_ip_lookup():
resolved.append("MQTTIP")
if has_ota() and has_non_ip_address() and has_resolvable_address():
resolved.extend(_resolve_with_cache(CORE.address, purpose))
resolved.extend(_ota_hostnames_for_default(purpose))
else:
resolved.append(device)
if not resolved:
@@ -281,22 +341,29 @@ def choose_upload_log_host(
elif bootsel.permission_error:
bootsel_permission_error = True
def add_ota_options() -> None:
"""Add OTA options, using mDNS discovery if name_add_mac_suffix is enabled."""
if (discovered := _discover_mac_suffix_devices()) is not None:
# Discovery was applicable. Use whatever we found — on empty,
# intentionally skip the base-name fallback since with
# name_add_mac_suffix on, the base name doesn't exist on the net.
for host in discovered:
options.append((f"Over The Air ({host})", host))
elif has_resolvable_address():
options.append((f"Over The Air ({CORE.address})", CORE.address))
if has_mqtt_ip_lookup():
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
if purpose == Purpose.LOGGING:
if has_mqtt_logging():
mqtt_config = CORE.config[CONF_MQTT]
options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT"))
if has_api():
if has_resolvable_address():
options.append((f"Over The Air ({CORE.address})", CORE.address))
if has_mqtt_ip_lookup():
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
add_ota_options()
elif purpose == Purpose.UPLOADING and has_ota():
if has_resolvable_address():
options.append((f"Over The Air ({CORE.address})", CORE.address))
if has_mqtt_ip_lookup():
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
add_ota_options()
# Show helpful BOOTSEL instructions for RP2040 when no BOOTSEL device is found
if (
@@ -407,7 +474,17 @@ def has_resolvable_address() -> bool:
return not CORE.address.endswith(".local")
def mqtt_get_ip(config: ConfigType, username: str, password: str, client_id: str):
def has_name_add_mac_suffix() -> bool:
"""Check if name_add_mac_suffix is enabled in the config."""
if CORE.config is None:
return False
esphome_config = CORE.config.get(CONF_ESPHOME, {})
return esphome_config.get(CONF_NAME_ADD_MAC_SUFFIX, False)
def mqtt_get_ip(
config: ConfigType, username: str, password: str, client_id: str
) -> list[str]:
from esphome import mqtt
return mqtt.get_esphome_device_ip(config, username, password, client_id)
@@ -420,6 +497,9 @@ def _resolve_network_devices(
This function filters the devices list to:
- Replace MQTT/MQTTIP magic strings with actual IP addresses via MQTT lookup
- Expand hostnames that are already in ``CORE.address_cache`` to their
cached IPs so downstream code (e.g. aioesphomeapi) doesn't open a second
Zeroconf client to resolve them
- Deduplicate addresses while preserving order
- Only resolve MQTT once even if multiple MQTT strings are present
- If MQTT resolution fails, log a warning and continue with other devices
@@ -444,13 +524,29 @@ def _resolve_network_devices(
mqtt_ips = mqtt_get_ip(
config, args.username, args.password, args.client_id
)
network_devices.extend(mqtt_ips)
# pylint can't infer mqtt_get_ip's return through its
# lazy ``from esphome import mqtt`` import, so it flags
# the genexpr below.
network_devices.extend(
addr
for addr in mqtt_ips # pylint: disable=not-an-iterable
if addr not in network_devices
)
except EsphomeError as err:
_LOGGER.warning(
"MQTT IP discovery failed (%s), will try other devices if available",
err,
)
mqtt_resolved = True
continue
# If the hostname is already in the address cache (e.g. populated by
# mDNS discovery), substitute the cached IPs so aioesphomeapi doesn't
# open its own Zeroconf to re-resolve it.
if CORE.address_cache and (cached := CORE.address_cache.get_addresses(device)):
network_devices.extend(
addr for addr in cached if addr not in network_devices
)
elif device not in network_devices:
# Regular network address or IP - add if not already present
network_devices.append(device)

View File

@@ -101,6 +101,17 @@ class AddressCache:
"""Check if any cache entries exist."""
return bool(self.mdns_cache or self.dns_cache)
def add_mdns_addresses(self, hostname: str, addresses: list[str]) -> None:
"""Store resolved mDNS addresses for ``hostname`` in the cache.
Callers that discover ``.local`` hosts (e.g. via mDNS browse) can use
this to avoid a second resolution round-trip during the upload path.
No-op when ``addresses`` is empty.
"""
if not addresses:
return
self.mdns_cache[normalize_hostname(hostname)] = addresses
@classmethod
def from_cli_args(
cls, mdns_args: Iterable[str], dns_args: Iterable[str]

56
esphome/async_thread.py Normal file
View File

@@ -0,0 +1,56 @@
"""Helpers for running an async coroutine from sync code via a daemon thread.
``asyncio.run(coro())`` in the main thread blocks until the loop's cleanup
cycle finishes, which can add hundreds of milliseconds before the caller
receives the result. Running the loop in a daemon thread lets the caller
observe the result as soon as the coroutine completes while cleanup finishes
in the background.
"""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
import threading
from typing import Generic, TypeVar
_T = TypeVar("_T")
class AsyncThreadRunner(threading.Thread, Generic[_T]):
"""Run an async coroutine in a daemon thread and expose its result.
The runner catches all exceptions from the coroutine and stores them in
``exception`` so ``event`` is always set — this prevents callers waiting
on ``event`` from hanging forever when the coroutine crashes.
Typical usage::
runner = AsyncThreadRunner(lambda: my_coro(arg))
runner.start()
if not runner.event.wait(timeout=5.0):
... # timed out
if runner.exception is not None:
raise runner.exception
result = runner.result
"""
def __init__(self, coro_factory: Callable[[], Awaitable[_T]]) -> None:
super().__init__(daemon=True)
self._coro_factory = coro_factory
self.result: _T | None = None
self.exception: BaseException | None = None
self.event = threading.Event()
async def _runner(self) -> None:
try:
self.result = await self._coro_factory()
except Exception as exc: # pylint: disable=broad-except
# Capture all exceptions so ``event`` is always set — otherwise a
# crash would hang the waiter forever.
self.exception = exc
finally:
self.event.set()
def run(self) -> None:
asyncio.run(self._runner())

View File

@@ -2,66 +2,52 @@
from __future__ import annotations
import asyncio
import threading
from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError
import aioesphomeapi.host_resolver as hr
from esphome.async_thread import AsyncThreadRunner
from esphome.core import EsphomeError
RESOLVE_TIMEOUT = 10.0 # seconds
class AsyncResolver(threading.Thread):
class AsyncResolver:
"""Resolver using aioesphomeapi that runs in a thread for faster results.
This resolver uses aioesphomeapi's async_resolve_host to handle DNS resolution,
including proper .local domain fallback. Running in a thread allows us to get
the result immediately without waiting for asyncio.run() to complete its
cleanup cycle, which can take significant time.
This resolver uses aioesphomeapi's async_resolve_host to handle DNS
resolution, including proper .local domain fallback. Running in a thread
(via :class:`AsyncThreadRunner`) allows us to get the result immediately
without waiting for ``asyncio.run()`` to complete its cleanup cycle, which
can take significant time.
"""
def __init__(self, hosts: list[str], port: int) -> None:
"""Initialize the resolver."""
super().__init__(daemon=True)
self.hosts = hosts
self.port = port
self.result: list[hr.AddrInfo] | None = None
self.exception: Exception | None = None
self.event = threading.Event()
async def _resolve(self) -> None:
async def _resolve(self) -> list[hr.AddrInfo]:
"""Resolve hostnames to IP addresses."""
try:
self.result = await hr.async_resolve_host(
self.hosts, self.port, timeout=RESOLVE_TIMEOUT
)
except Exception as e: # pylint: disable=broad-except
# We need to catch all exceptions to ensure the event is set
# Otherwise the thread could hang forever
self.exception = e
finally:
self.event.set()
def run(self) -> None:
"""Run the DNS resolution."""
asyncio.run(self._resolve())
return await hr.async_resolve_host(
self.hosts, self.port, timeout=RESOLVE_TIMEOUT
)
def resolve(self) -> list[hr.AddrInfo]:
"""Start the thread and wait for the result."""
self.start()
runner: AsyncThreadRunner[list[hr.AddrInfo]] = AsyncThreadRunner(self._resolve)
runner.start()
if not self.event.wait(
if not runner.event.wait(
timeout=RESOLVE_TIMEOUT + 1.0
): # Give it 1 second more than the resolver timeout
raise EsphomeError("Timeout resolving IP address")
if exc := self.exception:
if exc := runner.exception:
if isinstance(exc, ResolveTimeoutAPIError):
raise EsphomeError(f"Timeout resolving IP address: {exc}") from exc
if isinstance(exc, ResolveAPIError):
raise EsphomeError(f"Error resolving IP address: {exc}") from exc
raise exc
return self.result
assert runner.result is not None # guaranteed when event set and no exception
return runner.result

View File

@@ -14,8 +14,13 @@ from zeroconf import (
)
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
from esphome.async_thread import AsyncThreadRunner
from esphome.storage_json import StorageJSON, ext_storage_path
# Length of the MAC suffix appended when name_add_mac_suffix is enabled.
MAC_SUFFIX_LEN = 6
_HEX_CHARS = frozenset("0123456789abcdef")
_LOGGER = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 10.0
@@ -188,15 +193,177 @@ class EsphomeZeroconf(Zeroconf):
return None
async def async_resolve_hosts(
zeroconf: Zeroconf, hosts: list[str], timeout: float = DEFAULT_TIMEOUT
) -> dict[str, list[str]]:
"""Resolve ``hosts`` to IPs using a shared ``Zeroconf`` instance.
Tries the cache synchronously first (so hosts already primed by a recent
browse return immediately with no network round-trip), then issues
``async_request`` for the remaining misses in parallel via
``asyncio.gather``. Returns a dict mapping each host to its list of
addresses (empty list when unresolved). Only ``<short>.local`` form is
queried, matching the name scheme the resolvers below expect.
"""
resolvers: dict[str, AddressResolver] = {}
pending: list[str] = []
for host in hosts:
resolver = AddressResolver(f"{host.partition('.')[0]}.local.")
resolvers[host] = resolver
if not resolver.load_from_cache(zeroconf):
pending.append(host)
if pending and timeout:
results = await asyncio.gather(
*(
resolvers[host].async_request(zeroconf, timeout * 1000)
for host in pending
),
return_exceptions=True,
)
for host, result in zip(pending, results):
if isinstance(result, BaseException):
_LOGGER.debug("Failed to resolve %s: %s", host, result)
return {
host: resolver.parsed_scoped_addresses(IPVersion.All)
for host, resolver in resolvers.items()
}
class AsyncEsphomeZeroconf(AsyncZeroconf):
async def async_resolve_host(
self, host: str, timeout: float = DEFAULT_TIMEOUT
) -> list[str] | None:
"""Resolve a host name to an IP address."""
info = AddressResolver(f"{host.partition('.')[0]}.local.")
if (
info.load_from_cache(self.zeroconf)
or (timeout and await info.async_request(self.zeroconf, timeout * 1000))
) and (addresses := info.parsed_scoped_addresses(IPVersion.All)):
return addresses
return None
addresses = (await async_resolve_hosts(self.zeroconf, [host], timeout))[host]
return addresses or None
def _is_mac_suffix_match(device_name: str, prefix: str) -> bool:
"""Return True if ``device_name`` is ``prefix`` followed by a 6-char hex MAC."""
if not device_name.startswith(prefix):
return False
suffix = device_name[len(prefix) :]
return len(suffix) == MAC_SUFFIX_LEN and all(c in _HEX_CHARS for c in suffix)
async def async_discover_mdns_devices(
base_name: str, timeout: float = 5.0
) -> dict[str, list[str]]:
"""Discover ESPHome devices via mDNS that match the base name + MAC suffix.
When ``name_add_mac_suffix`` is enabled, devices advertise as
``<base_name>-<6-hex-mac>.local``. This function uses a single
``AsyncEsphomeZeroconf`` lifecycle to both browse for matching services and
resolve their IP addresses, so callers get resolved addresses without
opening a second Zeroconf client.
Args:
base_name: The base device name (without MAC suffix).
timeout: How long to wait for mDNS responses (default 5 seconds).
Returns:
Mapping of ``<device>.local`` hostnames to their resolved IP addresses
(may be empty for a device if resolution failed within the timeout).
"""
prefix = f"{base_name}-"
# Preserves insertion order for stable output and deduplicates
discovered: dict[str, list[str]] = {}
def on_service_state_change(
zeroconf: Zeroconf,
service_type: str,
name: str,
state_change: ServiceStateChange,
) -> None:
if state_change not in (ServiceStateChange.Added, ServiceStateChange.Updated):
return
device_name = name.partition(".")[0]
if not _is_mac_suffix_match(device_name, prefix):
_LOGGER.debug(
"Ignoring %s (%s): does not match '%s<6-hex>'",
device_name,
state_change.name,
prefix,
)
return
host = f"{device_name}.local"
if host in discovered:
return
discovered[host] = []
_LOGGER.debug("Discovered %s (%s)", host, state_change.name)
_LOGGER.debug(
"Starting mDNS discovery for '%s<mac>.local' (timeout=%.1fs)",
prefix,
timeout,
)
try:
aiozc = AsyncEsphomeZeroconf()
except Exception as err: # pylint: disable=broad-except
# Zeroconf init can raise OSError, NonUniqueNameException, etc.
# Any failure here just means we can't discover — log and move on.
_LOGGER.warning("mDNS discovery failed to initialize: %s", err)
return {}
try:
browser = AsyncServiceBrowser(
aiozc.zeroconf,
ESPHOME_SERVICE_TYPE,
handlers=[on_service_state_change],
)
try:
await asyncio.sleep(timeout)
finally:
await browser.async_cancel()
_LOGGER.debug(
"Browse finished: %d device(s) matched '%s<mac>'",
len(discovered),
prefix,
)
# Resolve each discovered hostname on the SAME Zeroconf instance so
# we don't spin up a second client. ``async_resolve_hosts`` tries the
# cache synchronously (the browse usually primes it) before issuing
# any ``async_request`` in parallel for misses.
resolved = await async_resolve_hosts(aiozc.zeroconf, list(discovered))
for host, addresses in resolved.items():
if addresses:
discovered[host] = addresses
_LOGGER.debug("Resolved %s -> %s", host, addresses)
else:
_LOGGER.debug("No addresses returned for %s", host)
finally:
await aiozc.async_close()
return dict(sorted(discovered.items()))
def _await_discovery(
runner: AsyncThreadRunner[dict[str, list[str]]], timeout: float
) -> dict[str, list[str]]:
"""Wait for ``runner`` to finish and return its discovery result.
Split out of :func:`discover_mdns_devices` so the timeout branch is
testable without patching ``asyncio`` or ``threading`` internals — a test
passes a stub whose ``event.wait`` returns ``False``.
"""
# Give the discovery an extra second over the browse timeout for the
# resolution + cleanup pass.
if not runner.event.wait(timeout=timeout + 2.0):
_LOGGER.warning("mDNS discovery timed out after %.1fs", timeout)
return {}
if runner.exception is not None:
_LOGGER.warning("mDNS discovery failed: %s", runner.exception)
return {}
return runner.result or {}
def discover_mdns_devices(base_name: str, timeout: float = 5.0) -> dict[str, list[str]]:
"""Synchronous wrapper around :func:`async_discover_mdns_devices`."""
runner = AsyncThreadRunner(
lambda: async_discover_mdns_devices(base_name, timeout=timeout)
)
runner.start()
return _await_discovery(runner, timeout)

View File

@@ -121,6 +121,26 @@ def test_get_addresses_auto_detection() -> None:
assert cache.get_addresses("unknown.com") is None
def test_add_mdns_addresses_stores_and_normalizes() -> None:
"""add_mdns_addresses inserts entries under the normalized hostname."""
cache = AddressCache()
cache.add_mdns_addresses("Device.Local.", ["192.168.1.10", "192.168.1.11"])
assert cache.mdns_cache == {
normalize_hostname("Device.Local."): ["192.168.1.10", "192.168.1.11"]
}
# Overwrites on subsequent calls for the same host
cache.add_mdns_addresses("device.local", ["10.0.0.1"])
assert cache.mdns_cache[normalize_hostname("device.local")] == ["10.0.0.1"]
def test_add_mdns_addresses_empty_is_noop() -> None:
"""Passing an empty address list must not create an entry."""
cache = AddressCache()
cache.add_mdns_addresses("device.local", [])
assert cache.mdns_cache == {}
def test_has_cache() -> None:
"""Test checking if cache has entries."""
# Empty cache

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from collections.abc import Generator
from collections.abc import Callable, Generator
from dataclasses import dataclass
import json
import logging
@@ -12,16 +12,18 @@ import re
import sys
import time
from typing import Any
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from pytest import CaptureFixture
from zeroconf import ServiceStateChange
from esphome import platformio_api
from esphome.__main__ import (
Purpose,
_get_configured_xtal_freq,
_make_crystal_freq_callback,
_resolve_network_devices,
choose_upload_log_host,
command_analyze_memory,
command_bundle,
@@ -36,6 +38,7 @@ from esphome.__main__ import (
has_mqtt,
has_mqtt_ip_lookup,
has_mqtt_logging,
has_name_add_mac_suffix,
has_non_ip_address,
has_ota,
has_resolvable_address,
@@ -48,6 +51,7 @@ from esphome.__main__ import (
upload_using_picotool,
upload_using_platformio,
)
from esphome.address_cache import AddressCache
from esphome.bundle import BUNDLE_EXTENSION, BundleFile, BundleResult
from esphome.components.esp32 import KEY_ESP32, KEY_VARIANT, VARIANT_ESP32
from esphome.const import (
@@ -62,6 +66,7 @@ from esphome.const import (
CONF_MDNS,
CONF_MQTT,
CONF_NAME,
CONF_NAME_ADD_MAC_SUFFIX,
CONF_OTA,
CONF_PASSWORD,
CONF_PLATFORM,
@@ -79,6 +84,7 @@ from esphome.const import (
)
from esphome.core import CORE, EsphomeError
from esphome.util import BootselResult
from esphome.zeroconf import _await_discovery, discover_mdns_devices
def strip_ansi_codes(text: str) -> str:
@@ -2218,6 +2224,509 @@ def test_has_resolvable_address() -> None:
assert has_resolvable_address() is False
def test_has_name_add_mac_suffix() -> None:
"""Test has_name_add_mac_suffix function."""
# Test with name_add_mac_suffix enabled
setup_core(config={CONF_ESPHOME: {CONF_NAME_ADD_MAC_SUFFIX: True}})
assert has_name_add_mac_suffix() is True
# Test with name_add_mac_suffix disabled
setup_core(config={CONF_ESPHOME: {CONF_NAME_ADD_MAC_SUFFIX: False}})
assert has_name_add_mac_suffix() is False
# Test with name_add_mac_suffix not set (defaults to False)
setup_core(config={CONF_ESPHOME: {}})
assert has_name_add_mac_suffix() is False
# Test with no esphome config
setup_core(config={})
assert has_name_add_mac_suffix() is False
# Test with no config at all
CORE.config = None
assert has_name_add_mac_suffix() is False
@pytest.fixture
def mock_mdns_discovery() -> Generator[MagicMock]:
"""Fixture to mock the async mDNS discovery infrastructure.
Patches ``AsyncEsphomeZeroconf``, ``AsyncServiceBrowser`` and
``AddressResolver`` in ``esphome.zeroconf`` and exposes hooks for tests to
stage browser events and control resolution results. The default
``AddressResolver`` stub simulates a cache hit returning no addresses, so
matched hosts appear in the discovery output with empty address lists
unless the test overrides ``_resolver_setup``.
"""
with (
patch("esphome.zeroconf.AsyncEsphomeZeroconf") as mock_aiozc_class,
patch("esphome.zeroconf.AsyncServiceBrowser") as mock_browser_class,
patch("esphome.zeroconf.AddressResolver") as mock_resolver_class,
):
mock_aiozc = MagicMock()
mock_aiozc.zeroconf = MagicMock()
mock_aiozc.async_close = AsyncMock(return_value=None)
mock_aiozc_class.return_value = mock_aiozc
mock_browser = MagicMock()
mock_browser.async_cancel = AsyncMock(return_value=None)
# Default: each host gets a fresh resolver that hits the cache and
# returns no addresses. Tests can override via ``_resolver_setup``.
def default_resolver_factory(name: str) -> MagicMock:
resolver = MagicMock()
resolver._name = name
resolver.load_from_cache.return_value = True
resolver.async_request = AsyncMock(return_value=True)
resolver.parsed_scoped_addresses.return_value = []
return resolver
mock_resolver_class.side_effect = default_resolver_factory
# Store references for test access
mock_aiozc._mock_browser_class = mock_browser_class
mock_aiozc._mock_browser = mock_browser
mock_aiozc._mock_class = mock_aiozc_class
mock_aiozc._mock_resolver_class = mock_resolver_class
yield mock_aiozc
@pytest.mark.parametrize(
("discovered_services", "base_name", "expected_hosts"),
[
# Matching devices; different-prefix device is filtered out
(
[
("mydevice-abc123._esphomelib._tcp.local.", ServiceStateChange.Added),
("mydevice-def456._esphomelib._tcp.local.", ServiceStateChange.Added),
(
"otherdevice-abcdef._esphomelib._tcp.local.",
ServiceStateChange.Added,
),
],
"mydevice",
["mydevice-abc123.local", "mydevice-def456.local"],
),
# No matches at all
(
[
(
"otherdevice-abcdef._esphomelib._tcp.local.",
ServiceStateChange.Added,
),
],
"mydevice",
[],
),
# Deduplication (same device Added then Updated)
(
[
("mydevice-abc123._esphomelib._tcp.local.", ServiceStateChange.Added),
("mydevice-abc123._esphomelib._tcp.local.", ServiceStateChange.Updated),
],
"mydevice",
["mydevice-abc123.local"],
),
# Suffix must be exactly 6 hex chars: wrong length and non-hex are rejected
(
[
# too short
("mydevice-abcd._esphomelib._tcp.local.", ServiceStateChange.Added),
# too long
(
"mydevice-abcdef1._esphomelib._tcp.local.",
ServiceStateChange.Added,
),
# non-hex
("mydevice-xyz123._esphomelib._tcp.local.", ServiceStateChange.Added),
# valid
("mydevice-012345._esphomelib._tcp.local.", ServiceStateChange.Added),
],
"mydevice",
["mydevice-012345.local"],
),
# Prefix-collision: base "foo" must not match "foo-bar-abc123"
(
[
("foo-abcdef._esphomelib._tcp.local.", ServiceStateChange.Added),
("foo-bar-abcdef._esphomelib._tcp.local.", ServiceStateChange.Added),
],
"foo",
["foo-abcdef.local"],
),
],
ids=[
"matching_with_filter",
"no_matches",
"deduplication",
"hex_suffix_filter",
"prefix_collision",
],
)
def test_discover_mdns_devices(
mock_mdns_discovery: MagicMock,
discovered_services: list[tuple[str, ServiceStateChange]],
base_name: str,
expected_hosts: list[str],
) -> None:
"""Test discover_mdns_devices filtering and deduplication."""
mock_browser = mock_mdns_discovery._mock_browser
def capture_callback(
zc: MagicMock,
service_type: str,
handlers: list[Callable[..., None]],
) -> MagicMock:
callback = handlers[0]
for service_name, state_change in discovered_services:
callback(
mock_mdns_discovery.zeroconf, service_type, service_name, state_change
)
return mock_browser
mock_mdns_discovery._mock_browser_class.side_effect = capture_callback
# Each discovered host gets a resolver that returns a unique IP string
# derived from its server name so we can assert per-host.
def resolver_factory(name: str) -> MagicMock:
resolver = MagicMock()
resolver._name = name
resolver.load_from_cache.return_value = True
resolver.async_request = AsyncMock(return_value=True)
resolver.parsed_scoped_addresses.return_value = [f"10.0.0.1#{name}"]
return resolver
mock_mdns_discovery._mock_resolver_class.side_effect = resolver_factory
result = discover_mdns_devices(base_name, timeout=0)
assert sorted(result) == expected_hosts
# Resolved addresses should be stored for matched hosts. AddressResolver
# receives the fully-qualified name (``<device>.local.``).
for host in expected_hosts:
short = host.partition(".")[0]
assert result[host] == [f"10.0.0.1#{short}.local."]
mock_browser.async_cancel.assert_awaited_once()
mock_mdns_discovery.async_close.assert_awaited_once()
def test_discover_mdns_devices_init_failure(caplog: pytest.LogCaptureFixture) -> None:
"""If AsyncEsphomeZeroconf fails to init, return empty dict and log warning."""
with (
patch(
"esphome.zeroconf.AsyncEsphomeZeroconf",
side_effect=OSError("no network"),
),
caplog.at_level(logging.WARNING, logger="esphome.zeroconf"),
):
result = discover_mdns_devices("mydevice", timeout=0)
assert result == {}
assert "mDNS discovery failed to initialize" in caplog.text
def test_discover_mdns_devices_resolution_failure(
mock_mdns_discovery: MagicMock,
) -> None:
"""If resolution raises, the host is still listed with an empty address list."""
mock_browser = mock_mdns_discovery._mock_browser
def capture_callback(
zc: MagicMock,
service_type: str,
handlers: list[Callable[..., None]],
) -> MagicMock:
handlers[0](
mock_mdns_discovery.zeroconf,
service_type,
"mydevice-abc123._esphomelib._tcp.local.",
ServiceStateChange.Added,
)
return mock_browser
mock_mdns_discovery._mock_browser_class.side_effect = capture_callback
# Resolver misses the cache, then async_request raises.
def failing_resolver_factory(name: str) -> MagicMock:
resolver = MagicMock()
resolver.load_from_cache.return_value = False
resolver.async_request = AsyncMock(side_effect=OSError("boom"))
resolver.parsed_scoped_addresses.return_value = []
return resolver
mock_mdns_discovery._mock_resolver_class.side_effect = failing_resolver_factory
result = discover_mdns_devices("mydevice", timeout=0)
assert result == {"mydevice-abc123.local": []}
def test_discover_mdns_devices_ignores_removed_state(
mock_mdns_discovery: MagicMock,
) -> None:
"""``Removed`` state changes are ignored and do not appear in the result."""
mock_browser = mock_mdns_discovery._mock_browser
def capture_callback(
zc: MagicMock,
service_type: str,
handlers: list[Callable[..., None]],
) -> MagicMock:
handlers[0](
mock_mdns_discovery.zeroconf,
service_type,
"mydevice-abc123._esphomelib._tcp.local.",
ServiceStateChange.Removed,
)
return mock_browser
mock_mdns_discovery._mock_browser_class.side_effect = capture_callback
result = discover_mdns_devices("mydevice", timeout=0)
assert result == {}
# No AddressResolver should have been constructed since no host matched.
mock_mdns_discovery._mock_resolver_class.assert_not_called()
def test_discover_mdns_devices_empty_resolution(
mock_mdns_discovery: MagicMock,
) -> None:
"""Host is listed with empty addresses when resolver returns no addresses."""
mock_browser = mock_mdns_discovery._mock_browser
def capture_callback(
zc: MagicMock,
service_type: str,
handlers: list[Callable[..., None]],
) -> MagicMock:
handlers[0](
mock_mdns_discovery.zeroconf,
service_type,
"mydevice-abc123._esphomelib._tcp.local.",
ServiceStateChange.Added,
)
return mock_browser
mock_mdns_discovery._mock_browser_class.side_effect = capture_callback
# Default fixture resolver is a cache-hit with no addresses — simulates
# the "browse found it but no A/AAAA records are available" case.
result = discover_mdns_devices("mydevice", timeout=0)
assert result == {"mydevice-abc123.local": []}
def test_resolve_network_devices_expands_cached_mdns_hosts(tmp_path: Path) -> None:
"""Hostnames in ``CORE.address_cache`` are expanded to their cached IPs."""
setup_core(tmp_path=tmp_path)
CORE.address_cache = AddressCache(
mdns_cache={
"device-abc123.local": ["10.0.0.1", "10.0.0.2"],
}
)
result = _resolve_network_devices(
["device-abc123.local", "192.168.1.50", "device-abc123.local"],
CORE.config,
MockArgs(),
)
# Cached hostname is replaced with its IPs (deduplicated across repeats)
# and the literal IP is preserved after.
assert result == ["10.0.0.1", "10.0.0.2", "192.168.1.50"]
def test_resolve_network_devices_keeps_uncached_hosts(tmp_path: Path) -> None:
"""Hostnames not in the cache pass through unchanged."""
setup_core(tmp_path=tmp_path)
CORE.address_cache = AddressCache()
result = _resolve_network_devices(
["unknown.local", "192.168.1.50"],
CORE.config,
MockArgs(),
)
assert result == ["unknown.local", "192.168.1.50"]
def test_await_discovery_timeout_returns_empty(
caplog: pytest.LogCaptureFixture,
) -> None:
"""If the discovery runner never sets its event, return {} and warn."""
stub = MagicMock()
stub.event.wait.return_value = False
stub.exception = None
stub.result = {"should_not_be_read": ["1.2.3.4"]}
with caplog.at_level(logging.WARNING, logger="esphome.zeroconf"):
result = _await_discovery(stub, timeout=0.01)
assert result == {}
assert "mDNS discovery timed out after 0.0s" in caplog.text
stub.event.wait.assert_called_once_with(timeout=pytest.approx(2.01))
def test_await_discovery_propagates_exception_as_empty(
caplog: pytest.LogCaptureFixture,
) -> None:
"""If the coroutine raised, log and return {} rather than re-raise."""
stub = MagicMock()
stub.event.wait.return_value = True
stub.exception = RuntimeError("boom")
stub.result = None
with caplog.at_level(logging.WARNING, logger="esphome.zeroconf"):
result = _await_discovery(stub, timeout=5.0)
assert result == {}
assert "mDNS discovery failed: boom" in caplog.text
@pytest.mark.usefixtures("mock_no_serial_ports")
def test_choose_upload_log_host_discovers_mac_suffix_devices(tmp_path: Path) -> None:
"""Interactive mode discovers MAC-suffixed devices and populates the cache."""
setup_core(
config={
CONF_ESPHOME: {CONF_NAME_ADD_MAC_SUFFIX: True},
CONF_OTA: [{CONF_PLATFORM: CONF_ESPHOME}],
},
address="mydevice.local",
tmp_path=tmp_path,
name="mydevice",
)
CORE.address_cache = None
discovered = {
"mydevice-abc123.local": ["10.0.0.1"],
"mydevice-def456.local": ["10.0.0.2"],
}
with (
patch(
"esphome.__main__.discover_mdns_devices", return_value=discovered
) as mock_discover,
patch(
"esphome.__main__.choose_prompt", return_value="mydevice-abc123.local"
) as mock_prompt,
):
result = choose_upload_log_host(
default=None,
check_default=None,
purpose=Purpose.UPLOADING,
)
assert result == ["mydevice-abc123.local"]
mock_discover.assert_called_once_with("mydevice")
mock_prompt.assert_called_once_with(
[
("Over The Air (mydevice-abc123.local)", "mydevice-abc123.local"),
("Over The Air (mydevice-def456.local)", "mydevice-def456.local"),
],
purpose=Purpose.UPLOADING,
)
# Resolved IPs should be cached so downstream resolution skips a second
# Zeroconf lookup.
assert CORE.address_cache is not None
assert CORE.address_cache.get_mdns_addresses("mydevice-abc123.local") == [
"10.0.0.1"
]
assert CORE.address_cache.get_mdns_addresses("mydevice-def456.local") == [
"10.0.0.2"
]
@pytest.mark.usefixtures("mock_no_serial_ports")
def test_choose_upload_log_host_mac_suffix_no_devices_found(
tmp_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""When discovery finds nothing, no OTA option is offered and a warning logs."""
setup_core(
config={
CONF_ESPHOME: {CONF_NAME_ADD_MAC_SUFFIX: True},
CONF_OTA: [{CONF_PLATFORM: CONF_ESPHOME}],
},
address="mydevice.local",
tmp_path=tmp_path,
name="mydevice",
)
with (
patch("esphome.__main__.discover_mdns_devices", return_value={}),
caplog.at_level(logging.WARNING, logger="esphome.__main__"),
pytest.raises(EsphomeError),
):
choose_upload_log_host(
default=None,
check_default=None,
purpose=Purpose.UPLOADING,
)
assert "No devices matching 'mydevice-<mac>.local'" in caplog.text
def test_choose_upload_log_host_default_ota_discovers_mac_suffix(
tmp_path: Path,
) -> None:
"""``--device OTA`` also runs mDNS discovery when name_add_mac_suffix is on."""
setup_core(
config={
CONF_ESPHOME: {CONF_NAME_ADD_MAC_SUFFIX: True},
CONF_OTA: [{CONF_PLATFORM: CONF_ESPHOME}],
},
address="mydevice.local",
tmp_path=tmp_path,
name="mydevice",
)
CORE.address_cache = None
discovered = {
"mydevice-abc123.local": ["10.0.0.1"],
"mydevice-def456.local": ["10.0.0.2"],
}
with patch(
"esphome.__main__.discover_mdns_devices", return_value=discovered
) as mock_discover:
result = choose_upload_log_host(
default="OTA",
check_default=None,
purpose=Purpose.UPLOADING,
)
# Both discovered hostnames are returned so aioesphomeapi / espota2 can
# try each in turn with the cached IPs.
assert result == ["mydevice-abc123.local", "mydevice-def456.local"]
mock_discover.assert_called_once_with("mydevice")
assert CORE.address_cache is not None
assert CORE.address_cache.get_mdns_addresses("mydevice-abc123.local") == [
"10.0.0.1"
]
def test_choose_upload_log_host_default_ota_no_suffix_discovery(
tmp_path: Path,
) -> None:
"""``--device OTA`` without name_add_mac_suffix uses CORE.address as-is."""
setup_core(
config={CONF_OTA: [{CONF_PLATFORM: CONF_ESPHOME}]},
address="192.168.1.100",
tmp_path=tmp_path,
name="mydevice",
)
with patch("esphome.__main__.discover_mdns_devices") as mock_discover:
result = choose_upload_log_host(
default="OTA",
check_default=None,
purpose=Purpose.UPLOADING,
)
assert result == ["192.168.1.100"]
# Discovery must NOT run when name_add_mac_suffix is disabled.
mock_discover.assert_not_called()
def test_command_wizard(tmp_path: Path) -> None:
"""Test command_wizard function."""
config_file = tmp_path / "test.yaml"

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import re
import socket
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr, IPv6Sockaddr
@@ -115,24 +115,21 @@ def test_async_resolver_generic_exception() -> None:
def test_async_resolver_thread_timeout() -> None:
"""Test timeout when thread doesn't complete in time."""
# Mock the start method to prevent actual thread execution
with (
patch.object(AsyncResolver, "start"),
patch("esphome.resolver.hr.async_resolve_host"),
):
resolver = AsyncResolver(["test.local"], 6053)
# Override event.wait to simulate timeout (return False = timeout occurred)
with (
patch.object(resolver.event, "wait", return_value=False),
pytest.raises(
EsphomeError, match=re.escape("Timeout resolving IP address")
),
):
resolver.resolve()
"""Test timeout when the runner thread doesn't complete in time."""
# Patch AsyncThreadRunner inside esphome.resolver so we never actually
# start a thread and can control the wait return value directly.
fake_runner = MagicMock()
fake_runner.start = MagicMock()
fake_runner.event.wait.return_value = False # simulate timeout
# Verify thread start was called
resolver.start.assert_called_once()
with (
patch("esphome.resolver.AsyncThreadRunner", return_value=fake_runner),
patch("esphome.resolver.hr.async_resolve_host"),
pytest.raises(EsphomeError, match=re.escape("Timeout resolving IP address")),
):
AsyncResolver(["test.local"], 6053).resolve()
fake_runner.start.assert_called_once()
def test_async_resolver_ip_addresses(mock_addr_info_ipv4: AddrInfo) -> None: