mirror of
https://github.com/esphome/esphome.git
synced 2026-06-24 12:17:23 +00:00
[api] Split ProtoVarInt::parse into 32-bit and 64-bit phases (#14039)
This commit is contained in:
@@ -2,7 +2,6 @@
|
|||||||
// See script/api_protobuf/api_protobuf.py
|
// See script/api_protobuf/api_protobuf.py
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "esphome/core/defines.h"
|
|
||||||
#include "esphome/core/string_ref.h"
|
#include "esphome/core/string_ref.h"
|
||||||
|
|
||||||
#include "proto.h"
|
#include "proto.h"
|
||||||
|
|||||||
12
esphome/components/api/api_pb2_defines.h
Normal file
12
esphome/components/api/api_pb2_defines.h
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
// This file was automatically generated with a tool.
|
||||||
|
// See script/api_protobuf/api_protobuf.py
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "esphome/core/defines.h"
|
||||||
|
#ifdef USE_BLUETOOTH_PROXY
|
||||||
|
#ifndef USE_API_VARINT64
|
||||||
|
#define USE_API_VARINT64
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace esphome::api {} // namespace esphome::api
|
||||||
@@ -7,6 +7,23 @@ namespace esphome::api {
|
|||||||
|
|
||||||
static const char *const TAG = "api.proto";
|
static const char *const TAG = "api.proto";
|
||||||
|
|
||||||
|
#ifdef USE_API_VARINT64
|
||||||
|
/// Finish decoding a varint whose first four bytes did not terminate it.
///
/// @param buffer    Start of the varint (byte 0, not byte 4).
/// @param len       Number of readable bytes in @p buffer.
/// @param consumed  Set to the total number of varint bytes on success; left untouched on failure.
/// @param result32  Value accumulated by the caller; decoding resumes at byte index 4 (shift 28).
/// @return The decoded value, or an empty optional if no terminating byte is found
///         within the first min(len, 10) bytes.
optional<ProtoVarInt> ProtoVarInt::parse_wide(const uint8_t *buffer, uint32_t len, uint32_t *consumed,
                                              uint32_t result32) {
  // Never look past byte index 9, which keeps every shift below 64.
  const uint32_t end = std::min(len, uint32_t(10));
  uint64_t accumulated = result32;
  uint32_t shift = 4 * 7;  // byte index 4 contributes bits starting at position 28
  uint32_t index = 4;
  while (index < end) {
    const uint8_t byte = buffer[index];
    ++index;
    accumulated |= static_cast<uint64_t>(byte & 0x7F) << shift;
    shift += 7;
    if ((byte & 0x80) != 0) {
      continue;  // continuation bit set: more bytes follow
    }
    *consumed = index;  // index already points one past the final byte
    return ProtoVarInt(accumulated);
  }
  return {};
}
|
||||||
|
#endif
|
||||||
|
|
||||||
uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id) {
|
uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id) {
|
||||||
uint32_t count = 0;
|
uint32_t count = 0;
|
||||||
const uint8_t *ptr = buffer;
|
const uint8_t *ptr = buffer;
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "api_pb2_defines.h"
|
||||||
#include "esphome/core/component.h"
|
#include "esphome/core/component.h"
|
||||||
#include "esphome/core/helpers.h"
|
#include "esphome/core/helpers.h"
|
||||||
#include "esphome/core/log.h"
|
#include "esphome/core/log.h"
|
||||||
@@ -110,59 +111,78 @@ class ProtoVarInt {
|
|||||||
#endif
|
#endif
|
||||||
if (len == 0)
|
if (len == 0)
|
||||||
return {};
|
return {};
|
||||||
|
// Fast path: single-byte varints (0-127) are the most common case
|
||||||
// Most common case: single-byte varint (values 0-127)
|
// (booleans, small enums, field tags). Avoid loop overhead entirely.
|
||||||
if ((buffer[0] & 0x80) == 0) {
|
if ((buffer[0] & 0x80) == 0) {
|
||||||
*consumed = 1;
|
*consumed = 1;
|
||||||
return ProtoVarInt(buffer[0]);
|
return ProtoVarInt(buffer[0]);
|
||||||
}
|
}
|
||||||
|
// 32-bit phase: process remaining bytes with native 32-bit shifts.
|
||||||
// General case for multi-byte varints
|
// Without USE_API_VARINT64: cover bytes 1-4 (shifts 7, 14, 21, 28) — the uint32_t
|
||||||
// Since we know buffer[0]'s high bit is set, initialize with its value
|
// shift at byte 4 (shift by 28) may lose bits 32-34, but those are always zero for valid uint32 values.
|
||||||
uint64_t result = buffer[0] & 0x7F;
|
// With USE_API_VARINT64: cover bytes 1-3 (shifts 7, 14, 21) so parse_wide handles
|
||||||
uint8_t bitpos = 7;
|
// byte 4+ with full 64-bit arithmetic (avoids truncating values > UINT32_MAX).
|
||||||
|
uint32_t result32 = buffer[0] & 0x7F;
|
||||||
// A 64-bit varint is at most 10 bytes (ceil(64/7)). Reject overlong encodings
|
#ifdef USE_API_VARINT64
|
||||||
// to avoid undefined behavior from shifting uint64_t by >= 64 bits.
|
uint32_t limit = std::min(len, uint32_t(4));
|
||||||
uint32_t max_len = std::min(len, uint32_t(10));
|
#else
|
||||||
|
uint32_t limit = std::min(len, uint32_t(5));
|
||||||
// Start from the second byte since we've already processed the first
|
#endif
|
||||||
for (uint32_t i = 1; i < max_len; i++) {
|
for (uint32_t i = 1; i < limit; i++) {
|
||||||
uint8_t val = buffer[i];
|
uint8_t val = buffer[i];
|
||||||
result |= uint64_t(val & 0x7F) << uint64_t(bitpos);
|
result32 |= uint32_t(val & 0x7F) << (i * 7);
|
||||||
bitpos += 7;
|
|
||||||
if ((val & 0x80) == 0) {
|
if ((val & 0x80) == 0) {
|
||||||
*consumed = i + 1;
|
*consumed = i + 1;
|
||||||
return ProtoVarInt(result);
|
return ProtoVarInt(result32);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// 64-bit phase for remaining bytes (BLE addresses etc.)
|
||||||
return {}; // Incomplete or invalid varint
|
#ifdef USE_API_VARINT64
|
||||||
|
return parse_wide(buffer, len, consumed, result32);
|
||||||
|
#else
|
||||||
|
return {};
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef USE_API_VARINT64
|
||||||
|
protected:
|
||||||
|
/// Continue parsing varint bytes 4-9 with 64-bit arithmetic.
|
||||||
|
/// Separated to keep 64-bit shift code (__ashldi3 on 32-bit platforms) out of the common path.
|
||||||
|
static optional<ProtoVarInt> parse_wide(const uint8_t *buffer, uint32_t len, uint32_t *consumed, uint32_t result32)
|
||||||
|
__attribute__((noinline));
|
||||||
|
|
||||||
|
public:
|
||||||
|
#endif
|
||||||
|
|
||||||
constexpr uint16_t as_uint16() const { return this->value_; }
|
constexpr uint16_t as_uint16() const { return this->value_; }
|
||||||
constexpr uint32_t as_uint32() const { return this->value_; }
|
constexpr uint32_t as_uint32() const { return this->value_; }
|
||||||
constexpr uint64_t as_uint64() const { return this->value_; }
|
|
||||||
constexpr bool as_bool() const { return this->value_; }
|
constexpr bool as_bool() const { return this->value_; }
|
||||||
constexpr int32_t as_int32() const {
|
constexpr int32_t as_int32() const {
|
||||||
// Not ZigZag encoded
|
// Not ZigZag encoded
|
||||||
return static_cast<int32_t>(this->as_int64());
|
return static_cast<int32_t>(this->value_);
|
||||||
}
|
|
||||||
constexpr int64_t as_int64() const {
|
|
||||||
// Not ZigZag encoded
|
|
||||||
return static_cast<int64_t>(this->value_);
|
|
||||||
}
|
}
|
||||||
constexpr int32_t as_sint32() const {
|
constexpr int32_t as_sint32() const {
|
||||||
// with ZigZag encoding
|
// with ZigZag encoding
|
||||||
return decode_zigzag32(static_cast<uint32_t>(this->value_));
|
return decode_zigzag32(static_cast<uint32_t>(this->value_));
|
||||||
}
|
}
|
||||||
|
#ifdef USE_API_VARINT64
|
||||||
|
constexpr uint64_t as_uint64() const { return this->value_; }
|
||||||
|
constexpr int64_t as_int64() const {
|
||||||
|
// Not ZigZag encoded
|
||||||
|
return static_cast<int64_t>(this->value_);
|
||||||
|
}
|
||||||
constexpr int64_t as_sint64() const {
|
constexpr int64_t as_sint64() const {
|
||||||
// with ZigZag encoding
|
// with ZigZag encoding
|
||||||
return decode_zigzag64(this->value_);
|
return decode_zigzag64(this->value_);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
#ifdef USE_API_VARINT64
|
||||||
uint64_t value_;
|
uint64_t value_;
|
||||||
|
#else
|
||||||
|
uint32_t value_;
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
// Forward declarations for decode_to_message, encode_message and encode_packed_sint32
|
// Forward declarations for decode_to_message, encode_message and encode_packed_sint32
|
||||||
|
|||||||
@@ -144,6 +144,7 @@
|
|||||||
#define USE_API_HOMEASSISTANT_SERVICES
|
#define USE_API_HOMEASSISTANT_SERVICES
|
||||||
#define USE_API_HOMEASSISTANT_STATES
|
#define USE_API_HOMEASSISTANT_STATES
|
||||||
#define USE_API_NOISE
|
#define USE_API_NOISE
|
||||||
|
#define USE_API_VARINT64
|
||||||
#define USE_API_PLAINTEXT
|
#define USE_API_PLAINTEXT
|
||||||
#define USE_API_USER_DEFINED_ACTIONS
|
#define USE_API_USER_DEFINED_ACTIONS
|
||||||
#define USE_API_CUSTOM_SERVICES
|
#define USE_API_CUSTOM_SERVICES
|
||||||
|
|||||||
@@ -1913,6 +1913,37 @@ def build_type_usage_map(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_varint64_ifdef(
    file_desc: descriptor.FileDescriptorProto,
    message_ifdef_map: dict[str, str | None],
) -> tuple[bool, str | None]:
    """Check if 64-bit varint fields exist and get their common ifdef guard.

    Only non-deprecated fields of non-deprecated top-level messages in
    ``file_desc.message_type`` are considered. A field counts as a 64-bit
    varint when its type is int64, uint64 or sint64.

    Args:
        file_desc: Parsed proto file descriptor to scan.
        message_ifdef_map: Maps message name to its ifdef guard (None or a
            missing entry means the message is unconditional).

    Returns:
        (has_varint64, ifdef_guard) - has_varint64 is True if any such field
        exists. ifdef_guard is the single guard shared by every message that
        owns one, or None when the define must be unconditional: either some
        owning message has no guard, or the owning messages use more than one
        distinct guard.
    """
    varint64_types = {
        FieldDescriptorProto.TYPE_INT64,
        FieldDescriptorProto.TYPE_UINT64,
        FieldDescriptorProto.TYPE_SINT64,
    }
    # One entry per distinct guard among messages owning a 64-bit varint field.
    ifdefs: set[str | None] = {
        message_ifdef_map.get(msg.name)
        for msg in file_desc.message_type
        if not msg.options.deprecated
        for field in msg.field
        if not field.options.deprecated and field.type in varint64_types
    }
    if not ifdefs:
        return False, None
    if None in ifdefs or len(ifdefs) != 1:
        # Either a 64-bit varint field is unconditional, or the fields sit behind
        # several different guards; in both cases fall back to no guard at all.
        return True, None
    return True, next(iter(ifdefs))
|
||||||
|
|
||||||
|
|
||||||
def build_enum_type(desc, enum_ifdef_map) -> tuple[str, str, str]:
|
def build_enum_type(desc, enum_ifdef_map) -> tuple[str, str, str]:
|
||||||
"""Builds the enum type.
|
"""Builds the enum type.
|
||||||
|
|
||||||
@@ -2567,11 +2598,38 @@ def main() -> None:
|
|||||||
|
|
||||||
file = d.file[0]
|
file = d.file[0]
|
||||||
|
|
||||||
|
# Build dynamic ifdef mappings early so we can emit USE_API_VARINT64 before includes
|
||||||
|
enum_ifdef_map, message_ifdef_map, message_source_map, used_messages = (
|
||||||
|
build_type_usage_map(file)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find the ifdef guard for 64-bit varint fields (int64/uint64/sint64).
|
||||||
|
# Generated into api_pb2_defines.h so proto.h can include it, ensuring
|
||||||
|
# consistent ProtoVarInt layout across all translation units.
|
||||||
|
has_varint64, varint64_guard = get_varint64_ifdef(file, message_ifdef_map)
|
||||||
|
|
||||||
|
# Generate api_pb2_defines.h — included by proto.h to ensure all translation
|
||||||
|
# units see USE_API_VARINT64 consistently (avoids ODR violations in ProtoVarInt).
|
||||||
|
defines_content = FILE_HEADER
|
||||||
|
defines_content += "#pragma once\n\n"
|
||||||
|
defines_content += '#include "esphome/core/defines.h"\n'
|
||||||
|
if has_varint64:
|
||||||
|
lines = [
|
||||||
|
"#ifndef USE_API_VARINT64",
|
||||||
|
"#define USE_API_VARINT64",
|
||||||
|
"#endif",
|
||||||
|
]
|
||||||
|
defines_content += "\n".join(wrap_with_ifdef(lines, varint64_guard))
|
||||||
|
defines_content += "\n"
|
||||||
|
defines_content += "\nnamespace esphome::api {} // namespace esphome::api\n"
|
||||||
|
|
||||||
|
with open(root / "api_pb2_defines.h", "w", encoding="utf-8") as f:
|
||||||
|
f.write(defines_content)
|
||||||
|
|
||||||
content = FILE_HEADER
|
content = FILE_HEADER
|
||||||
content += """\
|
content += """\
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "esphome/core/defines.h"
|
|
||||||
#include "esphome/core/string_ref.h"
|
#include "esphome/core/string_ref.h"
|
||||||
|
|
||||||
#include "proto.h"
|
#include "proto.h"
|
||||||
@@ -2702,11 +2760,6 @@ static void dump_bytes_field(DumpBuffer &out, const char *field_name, const uint
|
|||||||
|
|
||||||
content += "namespace enums {\n\n"
|
content += "namespace enums {\n\n"
|
||||||
|
|
||||||
# Build dynamic ifdef mappings for both enums and messages
|
|
||||||
enum_ifdef_map, message_ifdef_map, message_source_map, used_messages = (
|
|
||||||
build_type_usage_map(file)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Simple grouping of enums by ifdef
|
# Simple grouping of enums by ifdef
|
||||||
current_ifdef = None
|
current_ifdef = None
|
||||||
|
|
||||||
|
|||||||
47
tests/integration/fixtures/varint_five_byte_device_id.yaml
Normal file
47
tests/integration/fixtures/varint_five_byte_device_id.yaml
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
esphome:
|
||||||
|
name: varint-5byte-test
|
||||||
|
# Define areas and devices - device_ids will be FNV hashes > 2^28,
|
||||||
|
# requiring 5-byte varint encoding that exercises the 32-bit parse boundary.
|
||||||
|
areas:
|
||||||
|
- id: test_area
|
||||||
|
name: Test Area
|
||||||
|
devices:
|
||||||
|
- id: sub_device_one
|
||||||
|
name: Sub Device One
|
||||||
|
area_id: test_area
|
||||||
|
- id: sub_device_two
|
||||||
|
name: Sub Device Two
|
||||||
|
area_id: test_area
|
||||||
|
|
||||||
|
host:
|
||||||
|
api:
|
||||||
|
logger:
|
||||||
|
|
||||||
|
# Switches on sub-devices so we can send commands with large device_id varints
|
||||||
|
switch:
|
||||||
|
- platform: template
|
||||||
|
name: Device Switch
|
||||||
|
device_id: sub_device_one
|
||||||
|
id: device_switch_one
|
||||||
|
optimistic: true
|
||||||
|
turn_on_action:
|
||||||
|
- logger.log: "Switch one on"
|
||||||
|
turn_off_action:
|
||||||
|
- logger.log: "Switch one off"
|
||||||
|
|
||||||
|
- platform: template
|
||||||
|
name: Device Switch
|
||||||
|
device_id: sub_device_two
|
||||||
|
id: device_switch_two
|
||||||
|
optimistic: true
|
||||||
|
turn_on_action:
|
||||||
|
- logger.log: "Switch two on"
|
||||||
|
turn_off_action:
|
||||||
|
- logger.log: "Switch two off"
|
||||||
|
|
||||||
|
sensor:
|
||||||
|
- platform: template
|
||||||
|
name: Device Sensor
|
||||||
|
device_id: sub_device_one
|
||||||
|
lambda: return 42.0;
|
||||||
|
update_interval: 0.1s
|
||||||
120
tests/integration/test_varint_five_byte_device_id.py
Normal file
120
tests/integration/test_varint_five_byte_device_id.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
"""Integration test for 5-byte varint parsing of device_id fields.
|
||||||
|
|
||||||
|
Device IDs are FNV hashes (uint32) that frequently exceed 2^28 (268435456),
|
||||||
|
requiring 5 varint bytes. This test verifies that:
|
||||||
|
1. The firmware correctly decodes 5-byte varint device_id in incoming commands
|
||||||
|
2. The firmware correctly encodes large device_id values in state responses
|
||||||
|
3. Switch commands with large device_id reach the correct entity
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from aioesphomeapi import EntityState, SwitchInfo, SwitchState
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from .types import APIClientConnectedFactory, RunCompiledFunction
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_varint_five_byte_device_id(
    yaml_config: str,
    run_compiled: RunCompiledFunction,
    api_client_connected: APIClientConnectedFactory,
) -> None:
    """Exercise device_id values that need a 5-byte varint in both directions.

    Asserts that at least one reported device_id is >= 2^28, that state updates
    arrive for every reported device_id, and that switch commands addressed by
    device_id produce the matching ON and OFF state updates.
    """
    five_byte_threshold = 1 << 28  # first value that no longer fits in 4 varint bytes

    async with run_compiled(yaml_config), api_client_connected() as client:
        info = await client.device_info()
        sub_devices = info.devices
        assert len(sub_devices) >= 2, (
            f"Expected at least 2 devices, got {len(sub_devices)}"
        )

        # The test is only meaningful if some id actually crosses the boundary.
        assert any(dev.device_id >= five_byte_threshold for dev in sub_devices), (
            "Expected at least one device_id >= 2^28 to exercise 5-byte varint path. "
            f"Got device_ids: {[dev.device_id for dev in sub_devices]}"
        )

        # Both template switches share the name "Device Switch"; only device_id tells them apart.
        entities, _ = await client.list_entities_services()
        switches = [
            ent
            for ent in entities
            if isinstance(ent, SwitchInfo) and ent.name == "Device Switch"
        ]
        assert len(switches) == 2, (
            f"Expected 2 'Device Switch' entities, got {len(switches)}"
        )
        assert len({sw.device_id for sw in switches}) == 2, (
            "Switches should have different device_ids"
        )

        loop = asyncio.get_running_loop()
        latest: dict[tuple[int, int], EntityState] = {}
        pending: dict[tuple[int, int], asyncio.Future[EntityState]] = {}
        ready: asyncio.Future[bool] = loop.create_future()

        def on_state(state: EntityState) -> None:
            key = (state.device_id, state.key)
            latest[key] = state

            # Initial phase: wait until three distinct (device_id, key) pairs were seen.
            if not ready.done():
                if len(latest) < 3:
                    return
                ready.set_result(True)

            # Command phase: resolve the waiter registered for this entity, if any.
            waiter = pending.get(key)
            if waiter is None or waiter.done():
                return
            if isinstance(state, SwitchState):
                waiter.set_result(state)

        client.subscribe_states(on_state)

        try:
            await asyncio.wait_for(ready, timeout=10.0)
        except TimeoutError:
            pytest.fail(
                f"Timed out waiting for initial states. Got {len(latest)} states"
            )

        # Every reported device must have produced at least one state carrying its device_id.
        for dev in sub_devices:
            assert any(did == dev.device_id for did, _ in latest), (
                f"No states received for device '{dev.name}' "
                f"(device_id={dev.device_id})"
            )

        # Send ON then OFF to each switch; the client encodes device_id into the command.
        for sw in switches:
            state_key = (sw.device_id, sw.key)
            on_timeout = (
                f"Timed out waiting for switch ON state "
                f"(device_id={sw.device_id}, key={sw.key}). "
                f"This likely means the firmware failed to decode the "
                f"5-byte varint device_id in SwitchCommandRequest."
            )
            off_timeout = (
                f"Timed out waiting for switch OFF state "
                f"(device_id={sw.device_id}, key={sw.key})"
            )
            for wanted, timeout_message in ((True, on_timeout), (False, off_timeout)):
                pending[state_key] = loop.create_future()
                client.switch_command(sw.key, wanted, device_id=sw.device_id)
                try:
                    await asyncio.wait_for(pending[state_key], timeout=2.0)
                except TimeoutError:
                    pytest.fail(timeout_message)
                assert latest[state_key].state is wanted
|
||||||
Reference in New Issue
Block a user