[psram] Make schema extractable with per-variant options (#16949)

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Jesse Hills
2026-06-15 19:55:21 +12:00
parent 94b248527d
commit 32ab3abd7c
8 changed files with 158 additions and 15 deletions

View File

@@ -1,3 +1,4 @@
from collections.abc import Callable, Iterable
import contextlib
from dataclasses import dataclass
import itertools
@@ -6,6 +7,7 @@ import os
from pathlib import Path
import re
import subprocess
from typing import Any
from esphome import yaml_util
import esphome.codegen as cg
@@ -52,6 +54,7 @@ from esphome.coroutine import CoroPriority, coroutine_with_priority
from esphome.espidf.component import generate_idf_components
import esphome.final_validate as fv
from esphome.helpers import copy_file_if_changed, rmtree, write_file_if_changed
from esphome.schema_extractors import SCHEMA_EXTRACT, schema_extractor
from esphome.types import ConfigType
from esphome.writer import clean_build, clean_cmake_cache
@@ -496,6 +499,32 @@ def get_esp32_variant(core_obj=None):
return (core_obj or CORE).data[KEY_ESP32][KEY_VARIANT]
def variant_filtered_enum(
by_variant: dict[str, Iterable[Any]], **kwargs: Any
) -> Callable[[Any], Any]:
"""Build a ``one_of`` validator whose valid set depends on the active variant.
``by_variant`` maps each ESP32 variant constant to the iterable of values that
are valid on that variant. At validation time the value is checked against the
set allowed for the current target variant. For schema extraction the inverted
``{value: [variants, ...]}`` map is returned instead, so the language-schema
dump can tag every option with the variants that accept it and frontends can
filter to the user's selected variant.
"""
by_value: dict[str, list[str]] = {}
for variant, values in by_variant.items():
for value in values:
by_value.setdefault(str(value), []).append(variant)
@schema_extractor("variant_enum")
def validator(value: Any) -> Any:
if value is SCHEMA_EXTRACT:
return by_value
return cv.one_of(*by_variant.get(get_esp32_variant(), ()), **kwargs)(value)
return validator
def get_board(core_obj=None):
return (core_obj or CORE).data[KEY_ESP32][KEY_BOARD]

View File

@@ -16,6 +16,7 @@ from esphome.components.esp32 import (
add_idf_sdkconfig_option,
get_esp32_variant,
idf_version,
variant_filtered_enum,
)
import esphome.config_validation as cv
from esphome.const import (
@@ -29,6 +30,7 @@ from esphome.const import (
)
from esphome.core import CORE
import esphome.final_validate as fv
from esphome.types import ConfigType
CODEOWNERS = ["@esphome/core"]
DOMAIN = "psram"
@@ -70,6 +72,11 @@ SPIRAM_SPEEDS = {
VARIANT_ESP32P4: (20, 100, 200),
}
SPIRAM_SPEEDS_MHZ = {
variant: tuple(f"{speed}MHZ" for speed in speeds)
for variant, speeds in SPIRAM_SPEEDS.items()
}
def supported() -> bool:
if not CORE.is_esp32:
@@ -145,15 +152,23 @@ def validate_psram_mode(config):
return config
def get_config_schema(config):
def _set_variant_defaults(config: ConfigType) -> ConfigType:
"""Resolve variant-dependent defaults before the static schema validates.
The set of valid ``mode``/``speed`` values is variant-specific (enforced by
``variant_filtered_enum`` in the schema below); this only supplies the default
when the user omits the option. ``mode`` has no single default on chips that
support more than one mode, so selection is required there.
"""
variant = get_esp32_variant()
speeds = [f"{s}MHZ" for s in SPIRAM_SPEEDS.get(variant, [])]
if not speeds:
modes = SPIRAM_MODES.get(variant)
speeds = SPIRAM_SPEEDS.get(variant)
if not modes or not speeds:
raise cv.Invalid("PSRAM is not supported on this chip")
modes = SPIRAM_MODES[variant]
if CONF_MODE not in config and len(modes) != 1:
raise (
cv.Invalid(
config = config.copy()
if CONF_MODE not in config:
if len(modes) != 1:
raise cv.Invalid(
textwrap.dedent(
f"""
{variant} requires PSRAM mode selection; one of {", ".join(modes)}
@@ -161,20 +176,27 @@ def get_config_schema(config):
"""
)
)
)
return cv.Schema(
config[CONF_MODE] = modes[0]
if CONF_SPEED not in config:
config[CONF_SPEED] = f"{speeds[0]}MHZ"
return config
CONFIG_SCHEMA = cv.All(
_set_variant_defaults,
cv.Schema(
{
cv.GenerateID(): cv.declare_id(PsramComponent),
cv.Optional(CONF_MODE, default=modes[0]): cv.one_of(*modes, lower=True),
cv.Optional(CONF_MODE): variant_filtered_enum(SPIRAM_MODES, lower=True),
cv.Optional(CONF_ENABLE_ECC, default=False): cv.boolean,
cv.Optional(CONF_SPEED, default=speeds[0]): cv.one_of(*speeds, upper=True),
cv.Optional(CONF_SPEED): variant_filtered_enum(
SPIRAM_SPEEDS_MHZ, upper=True
),
cv.Optional(CONF_DISABLED, default=False): cv.boolean,
cv.Optional(CONF_IGNORE_NOT_FOUND, default=True): cv.boolean,
}
)(config)
CONFIG_SCHEMA = get_config_schema
),
)
def _store_psram_guaranteed(config):

View File

@@ -951,6 +951,15 @@ def convert(schema, config_var, path):
elif schema_type == "enum":
config_var[S_TYPE] = "enum"
config_var["values"] = dict.fromkeys(list(data.keys()))
elif schema_type == "variant_enum":
# Per-variant enum (e.g. psram mode/speed): each value carries the
# list of variants that accept it so clients can filter to the
# user's selected variant. Additive to the plain enum format —
# consumers that ignore the metadata still see every option.
config_var[S_TYPE] = "enum"
config_var["values"] = {
value: {"variants": variants} for value, variants in data.items()
}
elif schema_type == "maybe":
# maybe_simple_value: either a scalar shorthand (mapped to the key in
# data[1]) or the full wrapped schema. The wrapped schema is usually a

View File

@@ -97,6 +97,54 @@ def test_psram_configuration_valid_supported_variants(
FINAL_VALIDATE_SCHEMA(config)
def test_psram_applies_single_mode_default(
set_core_config: SetCoreConfigCallable,
) -> None:
"""On a single-mode variant the omitted mode/speed fall back to defaults."""
set_core_config(
PlatformFramework.ESP32_IDF,
platform_data={KEY_VARIANT: VARIANT_ESP32},
full_config={CONF_ESPHOME: {}},
)
from esphome.components.psram import CONFIG_SCHEMA
config = CONFIG_SCHEMA({})
assert config["mode"] == "quad"
assert config["speed"] == "40MHZ"
assert config["disabled"] is False
assert config["ignore_not_found"] is True
def test_psram_requires_mode_on_multi_mode_variant(
set_core_config: SetCoreConfigCallable,
) -> None:
"""A variant with multiple modes requires an explicit mode selection."""
set_core_config(
PlatformFramework.ESP32_IDF,
platform_data={KEY_VARIANT: VARIANT_ESP32S3},
full_config={CONF_ESPHOME: {}},
)
from esphome.components.psram import CONFIG_SCHEMA
with pytest.raises(cv.Invalid, match=r"requires PSRAM mode selection"):
CONFIG_SCHEMA({})
def test_psram_rejects_mode_invalid_for_variant(
set_core_config: SetCoreConfigCallable,
) -> None:
"""A mode not supported by the active variant is rejected by the schema."""
set_core_config(
PlatformFramework.ESP32_IDF,
platform_data={KEY_VARIANT: VARIANT_ESP32},
full_config={CONF_ESPHOME: {}},
)
from esphome.components.psram import CONFIG_SCHEMA
with pytest.raises(cv.Invalid, match=r"Unknown value 'octal'"):
CONFIG_SCHEMA({"mode": "octal"})
def _setup_psram_final_validation_test(
esp32_config: dict,
set_core_config: SetCoreConfigCallable,

View File

@@ -0,0 +1,5 @@
# Config-only: the ESP32-S3 supports both quad and octal. The compile test uses
# octal; this exercises the other branch of the per-variant mode enum (quad) and
# lets speed fall back to its 40MHz default.
psram:
mode: quad

View File

@@ -0,0 +1,4 @@
# Config-only: with no options the single-mode ESP32 resolves mode -> quad and
# speed -> 40MHz from the per-variant defaults. Compiling adds no signal here,
# so this only runs through `esphome config`.
psram:

View File

@@ -0,0 +1,4 @@
# Config-only: the ESP32-P4 has a distinct value set (hex mode, 20/100/200MHz).
# With no options it resolves mode -> hex and speed -> 20MHz, exercising the
# P4-specific default branch of the per-variant enums.
psram:

View File

@@ -139,6 +139,28 @@ def test_convert_walks_callable_schema_extractor() -> None:
assert "foo" in config_var["schema"]["config_vars"]
def test_convert_emits_variant_enum() -> None:
"""A per-variant enum is dumped with each value tagged by its variants."""
from esphome.components.esp32 import (
VARIANT_ESP32,
VARIANT_ESP32S3,
variant_filtered_enum,
)
validator = variant_filtered_enum(
{VARIANT_ESP32: ("quad",), VARIANT_ESP32S3: ("quad", "octal")},
lower=True,
)
config_var: dict = {}
_bls.convert(validator, config_var, "/test")
assert config_var["type"] == "enum"
assert config_var["values"] == {
"quad": {"variants": [VARIANT_ESP32, VARIANT_ESP32S3]},
"octal": {"variants": [VARIANT_ESP32S3]},
}
def test_convert_keys_emits_heuristic_sensitive_marker() -> None:
converted: dict = {}
_bls.convert_keys(converted, {cv.Optional("password"): cv.string}, "/root")