mirror of
https://github.com/esphome/esphome.git
synced 2026-06-24 14:37:04 +00:00
[mapping] Implement default value (#15861)
This commit is contained in:
@@ -1,18 +1,27 @@
|
||||
from collections.abc import Callable
|
||||
import difflib
|
||||
|
||||
import esphome.codegen as cg
|
||||
from esphome.components.const import KEY_METADATA
|
||||
import esphome.config_validation as cv
|
||||
from esphome.const import CONF_FROM, CONF_ID, CONF_TO
|
||||
from esphome.core import CORE
|
||||
from esphome.cpp_generator import MockObj, VariableDeclarationExpression, add_global
|
||||
from esphome.core import CORE, ID
|
||||
from esphome.cpp_generator import (
|
||||
MockObj,
|
||||
MockObjClass,
|
||||
VariableDeclarationExpression,
|
||||
add_global,
|
||||
)
|
||||
from esphome.loader import get_component
|
||||
|
||||
CODEOWNERS = ["@clydebarrow"]
|
||||
MULTI_CONF = True
|
||||
DOMAIN = "mapping"
|
||||
|
||||
mapping_ns = cg.esphome_ns.namespace("mapping")
|
||||
mapping_class = mapping_ns.class_("Mapping")
|
||||
|
||||
CONF_DEFAULT_VALUE = "default_value"
|
||||
CONF_ENTRIES = "entries"
|
||||
CONF_CLASS = "class"
|
||||
|
||||
@@ -22,11 +31,18 @@ class IndexType:
|
||||
Represents a type of index in a map.
|
||||
"""
|
||||
|
||||
def __init__(self, validator, data_type, conversion):
|
||||
def __init__(
|
||||
self, validator: Callable, data_type: MockObj, conversion: Callable = None
|
||||
) -> None:
|
||||
self.validator = validator
|
||||
self.data_type = data_type
|
||||
self.conversion = conversion
|
||||
|
||||
async def convert_value(self, value):
|
||||
if self.conversion:
|
||||
return self.conversion(value)
|
||||
return await cg.get_variable(value)
|
||||
|
||||
|
||||
INDEX_TYPES = {
|
||||
"int": IndexType(cv.int_, cg.int_, int),
|
||||
@@ -38,6 +54,12 @@ INDEX_TYPES = {
|
||||
}
|
||||
|
||||
|
||||
class MappingMetaData:
|
||||
def __init__(self, from_: IndexType, to_: IndexType) -> None:
|
||||
self.from_ = from_
|
||||
self.to_ = to_
|
||||
|
||||
|
||||
def to_schema(value):
|
||||
"""
|
||||
Generate a schema for the 'to' field of a map. This can be either one of the index types or a class name.
|
||||
@@ -60,7 +82,7 @@ BASE_SCHEMA = cv.Schema(
|
||||
)
|
||||
|
||||
|
||||
def get_object_type(to_):
|
||||
def get_object_type(to_) -> MockObjClass | None:
|
||||
"""
|
||||
Get the object type from a string. Possible formats:
|
||||
xxx The name of a component which defines INSTANCE_TYPE
|
||||
@@ -81,25 +103,60 @@ def get_object_type(to_):
|
||||
return None
|
||||
|
||||
|
||||
def get_all_mapping_metadata() -> dict[str, MappingMetaData]:
|
||||
"""Get all mapping metadata."""
|
||||
return CORE.data.setdefault(DOMAIN, {}).setdefault(KEY_METADATA, {})
|
||||
|
||||
|
||||
def get_mapping_metadata(mapping_id: str) -> MappingMetaData:
|
||||
"""Get mapping metadata by ID for use by other components."""
|
||||
return get_all_mapping_metadata()[mapping_id]
|
||||
|
||||
|
||||
def add_metadata(
|
||||
mapping_id: ID,
|
||||
from_: IndexType,
|
||||
to_: IndexType,
|
||||
) -> None:
|
||||
get_all_mapping_metadata()[mapping_id.id] = MappingMetaData(from_, to_)
|
||||
|
||||
|
||||
def map_schema(config):
|
||||
config = BASE_SCHEMA(config)
|
||||
if CONF_ENTRIES not in config or not isinstance(config[CONF_ENTRIES], dict):
|
||||
raise cv.Invalid("an entries list is required for a map")
|
||||
raise cv.Invalid("an entries dictionary is required for a mapping")
|
||||
entries = config[CONF_ENTRIES]
|
||||
if len(entries) == 0:
|
||||
raise cv.Invalid("Map must have at least one entry")
|
||||
raise cv.Invalid("A mapping must have at least one entry")
|
||||
to_ = config[CONF_TO]
|
||||
if to_ in INDEX_TYPES:
|
||||
value_type = INDEX_TYPES[to_].validator
|
||||
value_type = INDEX_TYPES[to_]
|
||||
else:
|
||||
value_type = get_object_type(to_)
|
||||
if value_type is None:
|
||||
object_type = get_object_type(to_)
|
||||
if object_type is None:
|
||||
matches = difflib.get_close_matches(to_, CORE.id_classes)
|
||||
raise cv.Invalid(
|
||||
f"No known mappable class name matches '{to_}'; did you mean one of {', '.join(matches)}?"
|
||||
)
|
||||
value_type = cv.use_id(value_type)
|
||||
config[CONF_ENTRIES] = {k: value_type(v) for k, v in entries.items()}
|
||||
validator = cv.use_id(object_type)
|
||||
value_type = IndexType(validator, object_type)
|
||||
config[CONF_ENTRIES] = {k: value_type.validator(v) for k, v in entries.items()}
|
||||
if (default_value := config.get(CONF_DEFAULT_VALUE)) is not None:
|
||||
config[CONF_DEFAULT_VALUE] = value_type.validator(default_value)
|
||||
unexpected_keys = config.keys() - {
|
||||
CONF_ENTRIES,
|
||||
CONF_TO,
|
||||
CONF_FROM,
|
||||
CONF_ID,
|
||||
CONF_DEFAULT_VALUE,
|
||||
}
|
||||
if unexpected_keys:
|
||||
errors = [
|
||||
cv.Invalid(f"Unexpected key '{k}'", path=[k]) for k in unexpected_keys
|
||||
]
|
||||
raise cv.MultipleInvalid(errors)
|
||||
|
||||
add_metadata(config[CONF_ID], INDEX_TYPES[config[CONF_FROM]], value_type)
|
||||
return config
|
||||
|
||||
|
||||
@@ -107,29 +164,19 @@ CONFIG_SCHEMA = map_schema
|
||||
|
||||
|
||||
async def to_code(config):
|
||||
entries = config[CONF_ENTRIES]
|
||||
from_ = config[CONF_FROM]
|
||||
to_ = config[CONF_TO]
|
||||
index_conversion = INDEX_TYPES[from_].conversion
|
||||
index_type = INDEX_TYPES[from_].data_type
|
||||
if to_ in INDEX_TYPES:
|
||||
value_conversion = INDEX_TYPES[to_].conversion
|
||||
value_type = INDEX_TYPES[to_].data_type
|
||||
entries = {
|
||||
index_conversion(key): value_conversion(value)
|
||||
for key, value in entries.items()
|
||||
}
|
||||
else:
|
||||
entries = {
|
||||
index_conversion(key): await cg.get_variable(value)
|
||||
for key, value in entries.items()
|
||||
}
|
||||
value_type = get_object_type(to_)
|
||||
if list(entries.values())[0].op != ".":
|
||||
value_type = value_type.operator("ptr")
|
||||
varid = config[CONF_ID]
|
||||
metadata = get_mapping_metadata(varid.id)
|
||||
entries = {
|
||||
metadata.from_.conversion(key): await metadata.to_.convert_value(value)
|
||||
for key, value in config[CONF_ENTRIES].items()
|
||||
}
|
||||
value_type = metadata.to_.data_type
|
||||
# entries guaranteed to be non-empty here.
|
||||
value_0 = list(entries.values())[0]
|
||||
if isinstance(value_0, MockObj) and value_0.op != ".":
|
||||
value_type = value_type.operator("ptr")
|
||||
varid.type = mapping_class.template(
|
||||
index_type,
|
||||
metadata.from_.data_type,
|
||||
value_type,
|
||||
)
|
||||
var = MockObj(varid, ".")
|
||||
@@ -139,4 +186,6 @@ async def to_code(config):
|
||||
|
||||
for key, value in entries.items():
|
||||
cg.add(var.set(key, value))
|
||||
if (default_value := config.get(CONF_DEFAULT_VALUE)) is not None:
|
||||
cg.add(var.set_default_value(await metadata.to_.convert_value(default_value)))
|
||||
return var
|
||||
|
||||
@@ -40,6 +40,9 @@ template<typename K, typename V> class Mapping {
|
||||
if (it != this->map_.end()) {
|
||||
return V{it->second};
|
||||
}
|
||||
if (this->default_value_.has_value()) {
|
||||
return this->default_value_.value();
|
||||
}
|
||||
if constexpr (std::is_pointer_v<K>) {
|
||||
esph_log_e(TAG, "Key '%p' not found in mapping", key);
|
||||
} else if constexpr (std::is_same_v<K, std::string>) {
|
||||
@@ -69,11 +72,17 @@ template<typename K, typename V> class Mapping {
|
||||
if (it != this->map_.end()) {
|
||||
return it->second.c_str(); // safe since value remains in map
|
||||
}
|
||||
if (this->default_value_.has_value()) {
|
||||
return this->default_value_.value();
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
void set_default_value(const V &default_value) { this->default_value_ = default_value; }
|
||||
|
||||
protected:
|
||||
std::map<key_t, value_t, std::less<key_t>, RAMAllocator<std::pair<key_t, value_t>>> map_;
|
||||
std::optional<V> default_value_{};
|
||||
};
|
||||
|
||||
} // namespace esphome::mapping
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import Any
|
||||
from esphome.const import SOURCE_FILE_EXTENSIONS
|
||||
from esphome.core import CORE
|
||||
import esphome.core.config
|
||||
from esphome.cpp_generator import MockObjClass
|
||||
from esphome.types import ConfigType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -93,7 +94,7 @@ class ComponentManifest:
|
||||
return getattr(self.module, "CODEOWNERS", [])
|
||||
|
||||
@property
|
||||
def instance_type(self) -> list[str]:
|
||||
def instance_type(self) -> MockObjClass | None:
|
||||
return getattr(self.module, "INSTANCE_TYPE", None)
|
||||
|
||||
@property
|
||||
|
||||
@@ -21,6 +21,7 @@ mapping:
|
||||
entries:
|
||||
clear-night: image_1
|
||||
sunny: image_2
|
||||
default_value: image_1
|
||||
- id: weather_map_2
|
||||
from: string
|
||||
to: image
|
||||
@@ -35,6 +36,7 @@ mapping:
|
||||
2: "two"
|
||||
3: "three"
|
||||
77: "seventy-seven"
|
||||
default_value: unknown
|
||||
- id: string_map
|
||||
from: string
|
||||
to: int
|
||||
|
||||
@@ -4,7 +4,7 @@ packages:
|
||||
|
||||
display:
|
||||
spi_id: spi_bus
|
||||
platform: ili9xxx
|
||||
platform: mipi_spi
|
||||
id: main_lcd
|
||||
model: ili9342
|
||||
cs_pin: 12
|
||||
|
||||
@@ -4,7 +4,7 @@ packages:
|
||||
|
||||
display:
|
||||
spi_id: spi_bus
|
||||
platform: ili9xxx
|
||||
platform: mipi_spi
|
||||
id: main_lcd
|
||||
model: ili9342
|
||||
cs_pin: 5
|
||||
|
||||
@@ -4,7 +4,7 @@ packages:
|
||||
|
||||
display:
|
||||
spi_id: spi_bus
|
||||
platform: ili9xxx
|
||||
platform: mipi_spi
|
||||
id: main_lcd
|
||||
model: ili9342
|
||||
data_rate: 31.25MHz
|
||||
|
||||
Reference in New Issue
Block a user