[api] Peel first write iteration, inline socket writes, zero-gap batch encoding (#15063)

This commit is contained in:
J. Nick Koston
2026-04-08 11:05:53 -10:00
committed by GitHub
parent 4a18ef87d7
commit 576d89a82a
10 changed files with 374 additions and 253 deletions

View File

@@ -406,7 +406,7 @@ uint16_t APIConnection::fill_and_encode_entity_info(EntityBase *entity, InfoResp
#ifdef USE_DEVICES
msg.device_id = entity->get_device_id();
#endif
return encode_to_buffer(size_fn(&msg), encode_fn, &msg, conn, remaining_size);
return encode_to_buffer_slow(size_fn(&msg), encode_fn, &msg, conn, remaining_size);
}
uint16_t APIConnection::fill_and_encode_entity_info_with_device_class(EntityBase *entity, InfoResponseProtoMessage &msg,
@@ -2005,48 +2005,12 @@ bool APIConnection::send_message_(uint32_t payload_size, uint8_t message_type, M
encode_fn(msg, buffer PROTO_ENCODE_DEBUG_INIT(&shared_buf));
return this->send_buffer(ProtoWriteBuffer{&shared_buf}, message_type);
}
// Encodes a message to the buffer and returns the total number of bytes used,
// including header and footer overhead. Returns 0 if the message doesn't fit.
uint16_t APIConnection::encode_to_buffer(uint32_t calculated_size, MessageEncodeFn encode_fn, const void *msg,
APIConnection *conn, uint32_t remaining_size) {
#ifdef HAS_PROTO_MESSAGE_DUMP
if (conn->flags_.log_only_mode) {
auto *proto_msg = static_cast<const ProtoMessage *>(msg);
DumpBuffer dump_buf;
conn->log_send_message_(proto_msg->message_name(), proto_msg->dump_to(dump_buf));
return 1;
}
#endif
// Cache frame sizes to avoid repeated virtual calls
const uint8_t header_padding = conn->helper_->frame_header_padding();
const uint8_t footer_size = conn->helper_->frame_footer_size();
// encode_to_buffer is defined inline in api_connection.h (ESPHOME_ALWAYS_INLINE)
// Calculate total size with padding for buffer allocation
size_t total_calculated_size = calculated_size + header_padding + footer_size;
// Check if it fits
if (total_calculated_size > remaining_size)
return 0; // Doesn't fit
auto &shared_buf = conn->parent_->get_shared_buffer_ref();
size_t to_add;
if (conn->flags_.batch_first_message) {
// First message - buffer already prepared by caller, just clear flag
conn->flags_.batch_first_message = false;
to_add = calculated_size;
} else {
// Batch message second or later
// Reserve for full message, resize to include footer gap + header padding + payload
to_add = total_calculated_size;
}
shared_buf.resize(shared_buf.size() + to_add);
ProtoWriteBuffer buffer{&shared_buf, shared_buf.size() - calculated_size};
encode_fn(msg, buffer PROTO_ENCODE_DEBUG_INIT(&shared_buf));
// Return total size (header + payload + footer)
return static_cast<uint16_t>(total_calculated_size);
// Noinline version for cold paths — single shared copy
uint16_t APIConnection::encode_to_buffer_slow(uint32_t calculated_size, MessageEncodeFn encode_fn, const void *msg,
APIConnection *conn, uint32_t remaining_size) {
return encode_to_buffer(calculated_size, encode_fn, msg, conn, remaining_size);
}
bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint8_t message_type) {
const bool is_log_message = (message_type == SubscribeLogsResponse::MESSAGE_TYPE);
@@ -2173,17 +2137,15 @@ void APIConnection::process_batch_multi_(APIBuffer &shared_buf, size_t num_items
"MessageInfo must remain trivially destructible with this placement-new approach");
const size_t messages_to_process = std::min(num_items, MAX_MESSAGES_PER_BATCH);
const uint8_t frame_overhead = header_padding + footer_size;
// Stack-allocated array for message info
alignas(MessageInfo) char message_info_storage[MAX_MESSAGES_PER_BATCH * sizeof(MessageInfo)];
MessageInfo *message_info = reinterpret_cast<MessageInfo *>(message_info_storage);
size_t items_processed = 0;
uint16_t remaining_size = std::numeric_limits<uint16_t>::max();
// Track where each message's header padding begins in the buffer
// For plaintext: this is where the 6-byte header padding starts
// For noise: this is where the 7-byte header padding starts
// The actual message data follows after the header padding
// Track where each message's header begins in the buffer
// First message: offset 0 (max padding, may have unused leading bytes)
// Subsequent messages: offset points to exact header start (no gaps)
uint32_t current_offset = 0;
// Process items and encode directly to buffer (up to our limit)
@@ -2199,13 +2161,14 @@ void APIConnection::process_batch_multi_(APIBuffer &shared_buf, size_t num_items
}
// Message was encoded successfully
// payload_size is header_padding + actual payload size + footer_size
uint16_t proto_payload_size = payload_size - frame_overhead;
// payload_size = header_size + proto_payload_size + footer_size
uint16_t proto_payload_size = payload_size - this->batch_header_size_ - footer_size;
// Use placement new to construct MessageInfo in pre-allocated stack array
// This avoids default-constructing all MAX_MESSAGES_PER_BATCH elements
// Explicit destruction is not needed because MessageInfo is trivially destructible,
// as ensured by the static_assert in its definition.
new (&message_info[items_processed++]) MessageInfo(item.message_type, current_offset, proto_payload_size);
new (&message_info[items_processed++])
MessageInfo(item.message_type, current_offset, proto_payload_size, this->batch_header_size_);
// After first message, set remaining size to MAX_BATCH_PACKET_SIZE to avoid fragmentation
if (items_processed == 1) {
remaining_size = MAX_BATCH_PACKET_SIZE;
@@ -2255,6 +2218,7 @@ void APIConnection::process_batch_multi_(APIBuffer &shared_buf, size_t num_items
uint16_t APIConnection::dispatch_message_(const DeferredBatch::BatchItem &item, uint32_t remaining_size,
bool batch_first) {
this->flags_.batch_first_message = batch_first;
this->batch_message_type_ = item.message_type;
#ifdef USE_EVENT
// Events need aux_data_index to look up event type from entity
if (item.message_type == EventResponse::MESSAGE_TYPE) {

View File

@@ -411,16 +411,59 @@ class APIConnection final : public APIServerConnectionBase {
// Non-template buffer management for send_message
bool send_message_(uint32_t payload_size, uint8_t message_type, MessageEncodeFn encode_fn, const void *msg);
// Non-template buffer management for batch encoding
static uint16_t encode_to_buffer(uint32_t calculated_size, MessageEncodeFn encode_fn, const void *msg,
APIConnection *conn, uint32_t remaining_size);
// Core batch encoding logic. Computes header size, checks fit, resizes buffer, encodes.
// ALWAYS_INLINE so the compiler can devirtualize encode_fn at hot call sites.
static inline uint16_t ESPHOME_ALWAYS_INLINE encode_to_buffer(uint32_t calculated_size, MessageEncodeFn encode_fn,
const void *msg, APIConnection *conn,
uint32_t remaining_size) {
#ifdef HAS_PROTO_MESSAGE_DUMP
if (conn->flags_.log_only_mode) {
auto *proto_msg = static_cast<const ProtoMessage *>(msg);
DumpBuffer dump_buf;
conn->log_send_message_(proto_msg->message_name(), proto_msg->dump_to(dump_buf));
return 1;
}
#endif
const uint8_t footer_size = conn->helper_->frame_footer_size();
// Thin template wrapper — computes size, delegates buffer work to non-template helper
// First message uses max padding (already in buffer), subsequent use exact header size
size_t to_add;
if (conn->flags_.batch_first_message) {
conn->flags_.batch_first_message = false;
conn->batch_header_size_ = conn->helper_->frame_header_padding();
to_add = calculated_size;
} else {
conn->batch_header_size_ = conn->helper_->frame_header_size(calculated_size, conn->batch_message_type_);
to_add = calculated_size + conn->batch_header_size_ + footer_size;
}
// Check if it fits (using actual header size, not max padding)
uint16_t total_calculated_size = calculated_size + conn->batch_header_size_ + footer_size;
if (total_calculated_size > remaining_size)
return 0;
auto &shared_buf = conn->parent_->get_shared_buffer_ref();
shared_buf.resize(shared_buf.size() + to_add);
ProtoWriteBuffer buffer{&shared_buf, shared_buf.size() - calculated_size};
encode_fn(msg, buffer PROTO_ENCODE_DEBUG_INIT(&shared_buf));
return total_calculated_size;
}
// Noinline version of encode_to_buffer for cold paths (entity info, zero-payload messages).
// All cold callers share this single copy instead of each getting an ALWAYS_INLINE expansion.
static uint16_t encode_to_buffer_slow(uint32_t calculated_size, MessageEncodeFn encode_fn, const void *msg,
APIConnection *conn, uint32_t remaining_size);
// Thin template wrapper — uses noinline encode_to_buffer_slow since
// encode_message_to_buffer callers are cold paths (zero-payload control messages).
// Hot paths (state/info) go through fill_and_encode_entity_state/info instead.
// batch_message_type_ is already set by dispatch_message_ before reaching here.
template<typename T> static uint16_t encode_message_to_buffer(T &msg, APIConnection *conn, uint32_t remaining_size) {
if constexpr (T::ESTIMATED_SIZE == 0) {
return encode_to_buffer(0, &encode_msg_noop, &msg, conn, remaining_size);
return encode_to_buffer_slow(0, &encode_msg_noop, &msg, conn, remaining_size);
} else {
return encode_to_buffer(msg.calculate_size(), &proto_encode_msg<T>, &msg, conn, remaining_size);
return encode_to_buffer_slow(msg.calculate_size(), &proto_encode_msg<T>, &msg, conn, remaining_size);
}
}
@@ -735,9 +778,14 @@ class APIConnection final : public APIServerConnectionBase {
// 2-byte types immediately after flags_ (no padding between them)
uint16_t client_api_version_major_{0};
uint16_t client_api_version_minor_{0};
// 1-byte type to fill padding
// 1-byte types to fill remaining space before next 4-byte boundary
ActiveIterator active_iterator_{ActiveIterator::NONE};
// Total: 2 (flags) + 2 + 2 + 1 = 7 bytes, then 1 byte padding to next 4-byte boundary
uint8_t batch_message_type_{0}; // Current message type during batch encoding
// Total: 2 (flags) + 2 + 2 + 1 + 1 = 8 bytes, aligned to 4-byte boundary
// Actual header size used by encode_to_buffer for the current message.
// Read by process_batch_multi_ to pass into MessageInfo.
uint8_t batch_header_size_{0};
uint32_t get_batch_delay_ms_() const { return this->parent_->get_batch_delay(); }
// Message will use 8 more bytes than the minimum size, and typical

View File

@@ -100,10 +100,17 @@ const LogString *api_error_to_logstr(APIError err) {
return LOG_STR("UNKNOWN");
}
#ifdef HELPER_LOG_PACKETS
void APIFrameHelper::log_packet_sending_(const void *data, uint16_t len) {
LOG_PACKET_SENDING(reinterpret_cast<const uint8_t *>(data), len);
}
#endif
APIError APIFrameHelper::drain_overflow_and_handle_errors_() {
if (this->overflow_buf_.try_drain(this->socket_.get()) == -1) {
int err = errno;
if (this->check_socket_write_err_(err) != APIError::WOULD_BLOCK) {
if (err != EWOULDBLOCK && err != EAGAIN) {
this->state_ = State::FAILED;
HELPER_LOG("Socket write failed with errno %d", err);
return APIError::SOCKET_WRITE_FAILED;
}
@@ -111,45 +118,58 @@ APIError APIFrameHelper::drain_overflow_and_handle_errors_() {
return APIError::OK;
}
// Write data to socket, overflow to backlog buffer if LWIP TCP send buffer is full.
// Returns OK if all data was sent or successfully queued.
// Returns SOCKET_WRITE_FAILED on hard error (sets state to FAILED).
APIError APIFrameHelper::write_raw_(const struct iovec *iov, int iovcnt, uint16_t total_write_len) {
// Single-buffer write path: wraps in iovec and delegates.
APIError APIFrameHelper::write_raw_buf_(const void *data, uint16_t len, ssize_t sent) {
struct iovec iov = {const_cast<void *>(data), len};
APIError err = this->write_raw_iov_(&iov, 1, len, sent);
#ifdef HELPER_LOG_PACKETS
for (int i = 0; i < iovcnt; i++) {
LOG_PACKET_SENDING(reinterpret_cast<uint8_t *>(iov[i].iov_base), iov[i].iov_len);
}
// Log after write/enqueue so re-entrant log sends can't corrupt data before it's sent
if (err == APIError::OK)
LOG_PACKET_SENDING(reinterpret_cast<const uint8_t *>(data), len);
#endif
return err;
}
uint16_t skip = 0;
// Drain any existing backlog first
if (!this->overflow_buf_.empty()) [[unlikely]] {
APIError err = this->drain_overflow_and_handle_errors_();
if (err != APIError::OK)
return err;
}
// If backlog is clear, try direct send
if (this->overflow_buf_.empty()) [[likely]] {
ssize_t sent =
(iovcnt == 1) ? this->socket_->write(iov[0].iov_base, iov[0].iov_len) : this->socket_->writev(iov, iovcnt);
if (sent == -1) [[unlikely]] {
// Handles partial writes, errors, and overflow buffering.
// Called when the inline fast path couldn't complete the write,
// or directly from cold paths (handshake, error handling).
APIError APIFrameHelper::write_raw_iov_(const struct iovec *iov, int iovcnt, uint16_t total_write_len, ssize_t sent) {
if (sent <= 0) {
if (sent == WRITE_NOT_ATTEMPTED) {
// Cold path: no write attempted yet, drain overflow and try
if (!this->overflow_buf_.empty()) {
APIError err = this->drain_overflow_and_handle_errors_();
if (err != APIError::OK)
return err;
}
if (this->overflow_buf_.empty()) {
sent = this->write_iov_to_socket_(iov, iovcnt);
if (sent == static_cast<ssize_t>(total_write_len))
return APIError::OK;
// Partial write or -1: fall through to error check / enqueue below
} else {
// Overflow backlog remains after drain; skip socket write, enqueue everything
sent = 0;
}
}
// WRITE_FAILED (-1): fast path or retry write returned -1, check errno
if (sent == WRITE_FAILED) {
int err = errno;
if (this->check_socket_write_err_(err) != APIError::WOULD_BLOCK) {
if (err != EWOULDBLOCK && err != EAGAIN) {
this->state_ = State::FAILED;
HELPER_LOG("Socket write failed with errno %d", err);
return APIError::SOCKET_WRITE_FAILED;
}
} else if (static_cast<uint16_t>(sent) >= total_write_len) [[likely]] {
return APIError::OK;
} else {
skip = static_cast<uint16_t>(sent);
sent = 0; // Treat WOULD_BLOCK as zero bytes sent
}
}
// Full write completed (possible when called directly, not via write_raw_fast_buf_)
if (sent == static_cast<ssize_t>(total_write_len))
return APIError::OK;
// Queue unsent data into overflow buffer
if (!this->overflow_buf_.enqueue_iov(iov, iovcnt, total_write_len, skip)) {
if (!this->overflow_buf_.enqueue_iov(iov, iovcnt, total_write_len, static_cast<uint16_t>(sent))) {
HELPER_LOG("Overflow buffer full, dropping connection");
this->state_ = State::FAILED;
return APIError::SOCKET_WRITE_FAILED;

View File

@@ -49,12 +49,17 @@ struct ReadPacketBuffer {
};
// Packed message info structure to minimize memory usage
// Note: message_type is uint8_t — all current protobuf message types fit in 8 bits.
// The noise wire format encodes types as 16-bit, but the high byte is always 0.
// If message types ever exceed 255, this and encrypt_noise_message_ must be updated.
struct MessageInfo {
uint16_t offset; // Offset in buffer where message starts
uint16_t payload_size; // Size of the message payload
uint8_t message_type; // Message type (0-255)
uint8_t header_size; // Actual header size used (avoids recomputation in write path)
MessageInfo(uint8_t type, uint16_t off, uint16_t size) : offset(off), payload_size(size), message_type(type) {}
MessageInfo(uint8_t type, uint16_t off, uint16_t size, uint8_t hdr)
: offset(off), payload_size(size), message_type(type), header_size(hdr) {}
};
enum class APIError : uint16_t {
@@ -161,20 +166,33 @@ class APIFrameHelper {
this->nodelay_counter_ = 0;
}
}
APIError write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) {
// Resize buffer to include footer space if needed (e.g. Noise MAC)
if (frame_footer_size_)
buffer.get_buffer()->resize(buffer.get_buffer()->size() + frame_footer_size_);
MessageInfo msg{type, 0,
static_cast<uint16_t>(buffer.get_buffer()->size() - frame_header_padding_ - frame_footer_size_)};
return write_protobuf_messages(buffer, std::span<const MessageInfo>(&msg, 1));
}
// Write multiple protobuf messages in a single operation
// messages contains (message_type, offset, length) for each message in the buffer
// The buffer contains all messages with appropriate padding before each
// Write a single protobuf message - the hot path (87-100% of all writes).
// Caller must ensure state is DATA before calling.
virtual APIError write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) = 0;
// Write multiple protobuf messages in a single batched operation.
// Caller must ensure state is DATA and messages is not empty.
// messages contains (message_type, offset, length) for each message in the buffer.
// The buffer contains all messages with appropriate padding before each.
virtual APIError write_protobuf_messages(ProtoWriteBuffer buffer, std::span<const MessageInfo> messages) = 0;
// Get the frame header padding required by this protocol
// Get the maximum frame header padding required by this protocol (worst case)
uint8_t frame_header_padding() const { return frame_header_padding_; }
// Get the actual frame header size for a specific message.
// For noise: always returns frame_header_padding_ (fixed 7-byte header).
// For plaintext: computes actual size from varint lengths (3-6 bytes).
// Distinguishes protocols via frame_footer_size_ (noise always has a non-zero MAC
// footer, plaintext has footer=0). If a protocol with a plaintext footer is ever
// added, this should become a virtual method.
uint8_t frame_header_size(uint16_t payload_size, uint8_t message_type) const {
#if defined(USE_API_NOISE) && defined(USE_API_PLAINTEXT)
return this->frame_footer_size_
? this->frame_header_padding_
: static_cast<uint8_t>(1 + ProtoSize::varint16(payload_size) + ProtoSize::varint8(message_type));
#elif defined(USE_API_NOISE)
return this->frame_header_padding_;
#else // USE_API_PLAINTEXT only
return static_cast<uint8_t>(1 + ProtoSize::varint16(payload_size) + ProtoSize::varint8(message_type));
#endif
}
// Get the frame footer size required by this protocol
uint8_t frame_footer_size() const { return frame_footer_size_; }
// Check if socket has data ready to read
@@ -196,18 +214,41 @@ class APIFrameHelper {
// Returns OK for transient errors (WOULD_BLOCK), SOCKET_WRITE_FAILED for hard errors.
APIError drain_overflow_and_handle_errors_();
// Common implementation for writing raw data to socket
APIError write_raw_(const struct iovec *iov, int iovcnt, uint16_t total_write_len);
// Sentinel values for the sent parameter in write_raw_ methods
static constexpr ssize_t WRITE_FAILED = -1; // Fast path: write()/writev() returned -1
static constexpr ssize_t WRITE_NOT_ATTEMPTED = -2; // Cold path: no write attempted yet
// Check if a socket write errno is a hard error (not WOULD_BLOCK/EAGAIN).
// Returns WOULD_BLOCK for transient errors, SOCKET_WRITE_FAILED for hard errors.
APIError check_socket_write_err_(int err) {
if (err == EWOULDBLOCK || err == EAGAIN)
return APIError::WOULD_BLOCK;
this->state_ = State::FAILED;
return APIError::SOCKET_WRITE_FAILED;
// Dispatch to write() or writev() based on iovec count
inline ssize_t ESPHOME_ALWAYS_INLINE write_iov_to_socket_(const struct iovec *iov, int iovcnt) {
return (iovcnt == 1) ? this->socket_->write(iov[0].iov_base, iov[0].iov_len) : this->socket_->writev(iov, iovcnt);
}
// Inlined write methods — used by hot paths (write_protobuf_packet, write_protobuf_messages)
// These inline the fast path (overflow empty + full write) and tail-call the out-of-line
// slow path only on failure/partial write.
inline APIError ESPHOME_ALWAYS_INLINE write_raw_fast_buf_(const void *data, uint16_t len) {
if (this->overflow_buf_.empty()) [[likely]] {
ssize_t sent = this->socket_->write(data, len);
if (sent == static_cast<ssize_t>(len)) [[likely]] {
#ifdef HELPER_LOG_PACKETS
this->log_packet_sending_(data, len);
#endif
return APIError::OK;
}
// sent is -1 (WRITE_FAILED) or partial write count
return this->write_raw_buf_(data, len, sent);
}
return this->write_raw_buf_(data, len, WRITE_NOT_ATTEMPTED);
}
// Out-of-line write paths: handle partial writes, errors, overflow buffering
// sent: WRITE_NOT_ATTEMPTED (cold path), WRITE_FAILED (fast path write returned -1), or bytes sent (partial write)
APIError write_raw_buf_(const void *data, uint16_t len, ssize_t sent = WRITE_NOT_ATTEMPTED);
APIError write_raw_iov_(const struct iovec *iov, int iovcnt, uint16_t total_write_len,
ssize_t sent = WRITE_NOT_ATTEMPTED);
#ifdef HELPER_LOG_PACKETS
void log_packet_sending_(const void *data, uint16_t len);
#endif
// Socket ownership (4 bytes on 32-bit, 8 bytes on 64-bit)
std::unique_ptr<socket::Socket> socket_;

View File

@@ -47,15 +47,8 @@ static constexpr size_t API_MAX_LOG_BYTES = 168;
format_hex_pretty_to(hex_buf_, (buffer).data(), \
(buffer).size() < API_MAX_LOG_BYTES ? (buffer).size() : API_MAX_LOG_BYTES)); \
} while (0)
#define LOG_PACKET_SENDING(data, len) \
do { \
char hex_buf_[format_hex_pretty_size(API_MAX_LOG_BYTES)]; \
ESP_LOGVV(TAG, "Sending raw: %s", \
format_hex_pretty_to(hex_buf_, data, (len) < API_MAX_LOG_BYTES ? (len) : API_MAX_LOG_BYTES)); \
} while (0)
#else
#define LOG_PACKET_RECEIVED(buffer) ((void) 0)
#define LOG_PACKET_SENDING(data, len) ((void) 0)
#endif
/// Convert a noise error code to a readable error
@@ -464,65 +457,83 @@ APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) {
buffer->type = type;
return APIError::OK;
}
APIError APINoiseFrameHelper::write_protobuf_messages(ProtoWriteBuffer buffer, std::span<const MessageInfo> messages) {
APIError aerr = this->check_data_state_();
// Encrypt a single noise message in place and return the encrypted frame length.
// Returns APIError::OK on success.
APIError APINoiseFrameHelper::encrypt_noise_message_(uint8_t *buf_start, uint16_t payload_size, uint8_t message_type,
uint16_t &encrypted_len_out) {
// Write noise header
buf_start[0] = 0x01; // indicator
// buf_start[1], buf_start[2] to be set after encryption
// Write message header (to be encrypted)
constexpr uint8_t msg_offset = 3;
buf_start[msg_offset] = static_cast<uint8_t>(message_type >> 8); // type high byte
buf_start[msg_offset + 1] = static_cast<uint8_t>(message_type); // type low byte
buf_start[msg_offset + 2] = static_cast<uint8_t>(payload_size >> 8); // data_len high byte
buf_start[msg_offset + 3] = static_cast<uint8_t>(payload_size); // data_len low byte
// payload data is already in the buffer starting at offset + 7
// Encrypt the message in place
NoiseBuffer mbuf;
noise_buffer_init(mbuf);
noise_buffer_set_inout(mbuf, buf_start + msg_offset, 4 + payload_size, 4 + payload_size + this->frame_footer_size_);
int err = noise_cipherstate_encrypt(this->send_cipher_, &mbuf);
APIError aerr =
this->handle_noise_error_(err, LOG_STR("noise_cipherstate_encrypt"), APIError::CIPHERSTATE_ENCRYPT_FAILED);
if (aerr != APIError::OK)
return aerr;
if (messages.empty()) {
return APIError::OK;
}
// Fill in the encrypted size
buf_start[1] = static_cast<uint8_t>(mbuf.size >> 8);
buf_start[2] = static_cast<uint8_t>(mbuf.size);
encrypted_len_out = static_cast<uint16_t>(3 + mbuf.size); // indicator + size + encrypted data
return APIError::OK;
}
APIError APINoiseFrameHelper::write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) {
#ifdef ESPHOME_DEBUG_API
assert(this->state_ == State::DATA);
#endif
// Resize buffer to include footer space for Noise MAC
if (this->frame_footer_size_)
buffer.get_buffer()->resize(buffer.get_buffer()->size() + this->frame_footer_size_);
uint16_t payload_size =
static_cast<uint16_t>(buffer.get_buffer()->size() - HEADER_PADDING - this->frame_footer_size_);
uint8_t *buf_start = buffer.get_buffer()->data();
uint16_t encrypted_len;
APIError aerr = this->encrypt_noise_message_(buf_start, payload_size, type, encrypted_len);
if (aerr != APIError::OK)
return aerr;
return this->write_raw_fast_buf_(buf_start, encrypted_len);
}
APIError APINoiseFrameHelper::write_protobuf_messages(ProtoWriteBuffer buffer, std::span<const MessageInfo> messages) {
#ifdef ESPHOME_DEBUG_API
assert(this->state_ == State::DATA);
assert(!messages.empty());
#endif
// Noise messages are already contiguous in the buffer:
// HEADER_PADDING (7) exactly matches the fixed header size, and
// footer space (16) is consumed by the encryption MAC.
uint8_t *buffer_data = buffer.get_buffer()->data();
// Stack-allocated iovec array - no heap allocation
StaticVector<struct iovec, MAX_MESSAGES_PER_BATCH> iovs;
uint8_t *write_start = buffer_data + messages[0].offset;
uint16_t total_write_len = 0;
// We need to encrypt each message in place
for (const auto &msg : messages) {
// The buffer already has padding at offset
uint8_t *buf_start = buffer_data + msg.offset;
// Write noise header
buf_start[0] = 0x01; // indicator
// buf_start[1], buf_start[2] to be set after encryption
// Write message header (to be encrypted)
constexpr uint8_t msg_offset = 3;
buf_start[msg_offset] = static_cast<uint8_t>(msg.message_type >> 8); // type high byte
buf_start[msg_offset + 1] = static_cast<uint8_t>(msg.message_type); // type low byte
buf_start[msg_offset + 2] = static_cast<uint8_t>(msg.payload_size >> 8); // data_len high byte
buf_start[msg_offset + 3] = static_cast<uint8_t>(msg.payload_size); // data_len low byte
// payload data is already in the buffer starting at offset + 7
// Make sure we have space for MAC
// The buffer should already have been sized appropriately
// Encrypt the message in place
NoiseBuffer mbuf;
noise_buffer_init(mbuf);
noise_buffer_set_inout(mbuf, buf_start + msg_offset, 4 + msg.payload_size,
4 + msg.payload_size + frame_footer_size_);
int err = noise_cipherstate_encrypt(send_cipher_, &mbuf);
APIError aerr =
handle_noise_error_(err, LOG_STR("noise_cipherstate_encrypt"), APIError::CIPHERSTATE_ENCRYPT_FAILED);
uint16_t encrypted_len;
APIError aerr = this->encrypt_noise_message_(buf_start, msg.payload_size, msg.message_type, encrypted_len);
if (aerr != APIError::OK)
return aerr;
// Fill in the encrypted size
buf_start[1] = static_cast<uint8_t>(mbuf.size >> 8);
buf_start[2] = static_cast<uint8_t>(mbuf.size);
// Add iovec for this encrypted message
size_t msg_len = static_cast<size_t>(3 + mbuf.size); // indicator + size + encrypted data
iovs.push_back({buf_start, msg_len});
total_write_len += msg_len;
total_write_len += encrypted_len;
}
// Send all encrypted messages in one writev call
return this->write_raw_(iovs.data(), iovs.size(), total_write_len);
return this->write_raw_fast_buf_(write_start, total_write_len);
}
APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, uint16_t len) {
@@ -531,16 +542,16 @@ APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, uint16_t len) {
header[1] = (uint8_t) (len >> 8);
header[2] = (uint8_t) len;
if (len == 0) {
return this->write_raw_buf_(header, 3);
}
struct iovec iov[2];
iov[0].iov_base = header;
iov[0].iov_len = 3;
if (len == 0) {
return this->write_raw_(iov, 1, 3); // Just header
}
iov[1].iov_base = const_cast<uint8_t *>(data);
iov[1].iov_len = len;
return this->write_raw_(iov, 2, 3 + len); // Header + data
return this->write_raw_iov_(iov, 2, 3 + len);
}
/** Initiate the data structures for the handshake.
@@ -606,7 +617,7 @@ APIError APINoiseFrameHelper::check_handshake_finished_() {
if (aerr != APIError::OK)
return aerr;
frame_footer_size_ = noise_cipherstate_get_mac_length(send_cipher_);
this->frame_footer_size_ = noise_cipherstate_get_mac_length(send_cipher_);
HELPER_LOG("Handshake complete!");
noise_handshakestate_free(handshake_);

View File

@@ -9,19 +9,22 @@ namespace esphome::api {
class APINoiseFrameHelper final : public APIFrameHelper {
public:
// Noise header structure:
// Pos 0: indicator (0x01)
// Pos 1-2: encrypted payload size (16-bit big-endian)
// Pos 3-6: encrypted type (16-bit) + data_len (16-bit)
// Pos 7+: actual payload data
static constexpr uint8_t HEADER_PADDING = 1 + 2 + 2 + 2; // indicator + size + type + data_len
APINoiseFrameHelper(std::unique_ptr<socket::Socket> socket, APINoiseContext &ctx)
: APIFrameHelper(std::move(socket)), ctx_(ctx) {
// Noise header structure:
// Pos 0: indicator (0x01)
// Pos 1-2: encrypted payload size (16-bit big-endian)
// Pos 3-6: encrypted type (16-bit) + data_len (16-bit)
// Pos 7+: actual payload data
frame_header_padding_ = 7;
frame_header_padding_ = HEADER_PADDING;
}
~APINoiseFrameHelper() override;
APIError init() override;
APIError loop() override;
APIError read_packet(ReadPacketBuffer *buffer) override;
APIError write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) override;
APIError write_protobuf_messages(ProtoWriteBuffer buffer, std::span<const MessageInfo> messages) override;
protected:
@@ -33,6 +36,8 @@ class APINoiseFrameHelper final : public APIFrameHelper {
APIError state_action_handshake_write_();
APIError try_read_frame_();
APIError write_frame_(const uint8_t *data, uint16_t len);
APIError encrypt_noise_message_(uint8_t *buf_start, uint16_t payload_size, uint8_t message_type,
uint16_t &encrypted_len_out);
APIError init_handshake_();
APIError check_handshake_finished_();
void send_explicit_handshake_reject_(const LogString *reason);

View File

@@ -39,15 +39,8 @@ static constexpr size_t API_MAX_LOG_BYTES = 168;
format_hex_pretty_to(hex_buf_, (buffer).data(), \
(buffer).size() < API_MAX_LOG_BYTES ? (buffer).size() : API_MAX_LOG_BYTES)); \
} while (0)
#define LOG_PACKET_SENDING(data, len) \
do { \
char hex_buf_[format_hex_pretty_size(API_MAX_LOG_BYTES)]; \
ESP_LOGVV(TAG, "Sending raw: %s", \
format_hex_pretty_to(hex_buf_, data, (len) < API_MAX_LOG_BYTES ? (len) : API_MAX_LOG_BYTES)); \
} while (0)
#else
#define LOG_PACKET_RECEIVED(buffer) ((void) 0)
#define LOG_PACKET_SENDING(data, len) ((void) 0)
#endif
/// Initialize the frame helper, returns OK if successful.
@@ -205,7 +198,6 @@ APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) {
// Make sure to tell the remote that we don't
// understand the indicator byte so it knows
// we do not support it.
struct iovec iov[1];
// The \x00 first byte is the marker for plaintext.
//
// The remote will know how to handle the indicator byte,
@@ -220,14 +212,12 @@ APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) {
"Bad indicator byte";
char msg[INDICATOR_MSG_SIZE];
memcpy_P(msg, MSG_PROGMEM, INDICATOR_MSG_SIZE);
iov[0].iov_base = (void *) msg;
this->write_raw_buf_(msg, INDICATOR_MSG_SIZE);
#else
static const char MSG[] = "\x00"
"Bad indicator byte";
iov[0].iov_base = (void *) MSG;
this->write_raw_buf_(MSG, INDICATOR_MSG_SIZE);
#endif
iov[0].iov_len = INDICATOR_MSG_SIZE;
this->write_raw_(iov, 1, INDICATOR_MSG_SIZE);
}
return aerr;
}
@@ -237,73 +227,101 @@ APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) {
buffer->type = this->rx_header_parsed_type_;
return APIError::OK;
}
// Encode a 16-bit varint (1-3 bytes) using pre-computed length.
ESPHOME_ALWAYS_INLINE static inline void encode_varint_16(uint16_t value, uint8_t varint_len, uint8_t *p) {
if (varint_len >= 2) {
*p++ = static_cast<uint8_t>(value | 0x80);
value >>= 7;
if (varint_len == 3) {
*p++ = static_cast<uint8_t>(value | 0x80);
value >>= 7;
}
}
*p = static_cast<uint8_t>(value);
}
// Encode an 8-bit varint (1-2 bytes) using pre-computed length.
ESPHOME_ALWAYS_INLINE static inline void encode_varint_8(uint8_t value, uint8_t varint_len, uint8_t *p) {
if (varint_len == 2) {
*p++ = static_cast<uint8_t>(value | 0x80);
*p = static_cast<uint8_t>(value >> 7);
} else {
*p = value;
}
}
// Write plaintext header into pre-allocated padding before payload.
// padding_size: bytes reserved before payload (HEADER_PADDING for first/single msg,
// actual header size for contiguous batch messages).
// Returns the total header length (indicator + varints).
ESPHOME_ALWAYS_INLINE static inline uint8_t write_plaintext_header(uint8_t *buf_start, uint16_t payload_size,
uint8_t message_type, uint8_t padding_size) {
uint8_t size_varint_len = ProtoSize::varint16(payload_size);
uint8_t type_varint_len = ProtoSize::varint8(message_type);
uint8_t total_header_len = 1 + size_varint_len + type_varint_len;
// The header is right-justified within the padding so it sits immediately before payload.
//
// Single/first message (padding_size = HEADER_PADDING = 6):
// Example (small, header=3): [0-2] unused | [3] 0x00 | [4] size | [5] type | [6...] payload
// Example (medium, header=4): [0-1] unused | [2] 0x00 | [3-4] size | [5] type | [6...] payload
// Example (large, header=6): [0] 0x00 | [1-3] size | [4-5] type | [6...] payload
//
// Batch messages 2+ (padding_size = actual header size, no unused bytes):
// Example (small, header=3): [0] 0x00 | [1] size | [2] type | [3...] payload
// Example (medium, header=4): [0] 0x00 | [1-2] size | [3] type | [4...] payload
#ifdef ESPHOME_DEBUG_API
assert(padding_size >= total_header_len);
#endif
uint32_t header_offset = padding_size - total_header_len;
// Write the plaintext header
buf_start[header_offset] = 0x00; // indicator
// Encode varints directly into buffer using pre-computed lengths
encode_varint_16(payload_size, size_varint_len, buf_start + header_offset + 1);
encode_varint_8(message_type, type_varint_len, buf_start + header_offset + 1 + size_varint_len);
return total_header_len;
}
APIError APIPlaintextFrameHelper::write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) {
#ifdef ESPHOME_DEBUG_API
assert(this->state_ == State::DATA);
#endif
uint16_t payload_size = static_cast<uint16_t>(buffer.get_buffer()->size() - HEADER_PADDING);
uint8_t *buffer_data = buffer.get_buffer()->data();
uint8_t header_len = write_plaintext_header(buffer_data, payload_size, type, HEADER_PADDING);
return this->write_raw_fast_buf_(buffer_data + HEADER_PADDING - header_len,
static_cast<uint16_t>(header_len + payload_size));
}
APIError APIPlaintextFrameHelper::write_protobuf_messages(ProtoWriteBuffer buffer,
std::span<const MessageInfo> messages) {
APIError aerr = this->check_data_state_();
if (aerr != APIError::OK)
return aerr;
if (messages.empty()) {
return APIError::OK;
}
#ifdef ESPHOME_DEBUG_API
assert(this->state_ == State::DATA);
assert(!messages.empty());
#endif
uint8_t *buffer_data = buffer.get_buffer()->data();
// Stack-allocated iovec array - no heap allocation
StaticVector<struct iovec, MAX_MESSAGES_PER_BATCH> iovs;
uint16_t total_write_len = 0;
// First message has max padding (header_size = HEADER_PADDING), may have unused leading bytes.
// Subsequent messages were encoded with exact header sizes (header_size = actual header len).
// write_plaintext_header right-justifies the header within header_size bytes of padding.
const auto &first = messages[0];
uint8_t *first_start = buffer_data + first.offset;
uint8_t header_len = write_plaintext_header(first_start, first.payload_size, first.message_type, HEADER_PADDING);
uint8_t *write_start = first_start + HEADER_PADDING - header_len;
uint16_t total_len = header_len + first.payload_size;
for (const auto &msg : messages) {
// Calculate varint sizes for header layout using inline ternary to avoid varint_slow call overhead
uint8_t size_varint_len = msg.payload_size < ProtoSize::VARINT_THRESHOLD_1_BYTE
? 1
: (msg.payload_size < ProtoSize::VARINT_THRESHOLD_2_BYTE ? 2 : 3);
uint8_t type_varint_len = msg.message_type < ProtoSize::VARINT_THRESHOLD_1_BYTE ? 1 : 2;
uint8_t total_header_len = 1 + size_varint_len + type_varint_len;
// Calculate where to start writing the header
// The header starts at the latest possible position to minimize unused padding
//
// Example 1 (small values): total_header_len = 3, header_offset = 6 - 3 = 3
// [0-2] - Unused padding
// [3] - 0x00 indicator byte
// [4] - Payload size varint (1 byte, for sizes 0-127)
// [5] - Message type varint (1 byte, for types 0-127)
// [6...] - Actual payload data
//
// Example 2 (medium values): total_header_len = 4, header_offset = 6 - 4 = 2
// [0-1] - Unused padding
// [2] - 0x00 indicator byte
// [3-4] - Payload size varint (2 bytes, for sizes 128-16383)
// [5] - Message type varint (1 byte, for types 0-127)
// [6...] - Actual payload data
//
// Example 3 (large values): total_header_len = 6, header_offset = 6 - 6 = 0
// [0] - 0x00 indicator byte
// [1-3] - Payload size varint (3 bytes, for sizes 16384-65535)
// [4-5] - Message type varint (2 bytes, for types 128-16383)
// [6...] - Actual payload data
//
// The message starts at offset + frame_header_padding_
// So we write the header starting at offset + frame_header_padding_ - total_header_len
uint8_t *buf_start = buffer_data + msg.offset;
uint32_t header_offset = frame_header_padding_ - total_header_len;
// Write the plaintext header
buf_start[header_offset] = 0x00; // indicator
// Encode varints directly into buffer
encode_varint_to_buffer(msg.payload_size, buf_start + header_offset + 1);
encode_varint_to_buffer(msg.message_type, buf_start + header_offset + 1 + size_varint_len);
// Add iovec for this message (header + payload)
size_t msg_len = static_cast<size_t>(total_header_len + msg.payload_size);
iovs.push_back({buf_start + header_offset, msg_len});
total_write_len += msg_len;
for (size_t i = 1; i < messages.size(); i++) {
const auto &msg = messages[i];
header_len = write_plaintext_header(buffer_data + msg.offset, msg.payload_size, msg.message_type, msg.header_size);
total_len += header_len + msg.payload_size;
}
// Send all messages in one writev call
return write_raw_(iovs.data(), iovs.size(), total_write_len);
return this->write_raw_fast_buf_(write_start, total_len);
}
} // namespace esphome::api

View File

@@ -7,18 +7,21 @@ namespace esphome::api {
class APIPlaintextFrameHelper final : public APIFrameHelper {
public:
// Plaintext header structure (worst case):
// Pos 0: indicator (0x00)
// Pos 1-3: payload size varint (up to 3 bytes)
// Pos 4-5: message type varint (up to 2 bytes)
// Pos 6+: actual payload data
static constexpr uint8_t HEADER_PADDING = 1 + 3 + 2; // indicator + size varint + type varint
explicit APIPlaintextFrameHelper(std::unique_ptr<socket::Socket> socket) : APIFrameHelper(std::move(socket)) {
// Plaintext header structure (worst case):
// Pos 0: indicator (0x00)
// Pos 1-3: payload size varint (up to 3 bytes)
// Pos 4-5: message type varint (up to 2 bytes)
// Pos 6+: actual payload data
frame_header_padding_ = 6;
frame_header_padding_ = HEADER_PADDING;
}
~APIPlaintextFrameHelper() override = default;
APIError init() override;
APIError loop() override;
APIError read_packet(ReadPacketBuffer *buffer) override;
APIError write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) override;
APIError write_protobuf_messages(ProtoWriteBuffer buffer, std::span<const MessageInfo> messages) override;
protected:

View File

@@ -645,6 +645,17 @@ class ProtoSize {
static constexpr uint32_t VARINT_THRESHOLD_3_BYTE = 1 << 21; // 2097152
static constexpr uint32_t VARINT_THRESHOLD_4_BYTE = 1 << 28; // 268435456
// Varint encoded length for a 16-bit value (1, 2, or 3 bytes).
// Fully inline — no slow path call for values >= 128.
static constexpr inline uint8_t ESPHOME_ALWAYS_INLINE varint16(uint16_t value) {
return value < VARINT_THRESHOLD_1_BYTE ? 1 : (value < VARINT_THRESHOLD_2_BYTE ? 2 : 3);
}
// Varint encoded length for an 8-bit value (1 or 2 bytes).
static constexpr inline uint8_t ESPHOME_ALWAYS_INLINE varint8(uint8_t value) {
return value < VARINT_THRESHOLD_1_BYTE ? 1 : 2;
}
/**
* @brief Calculates the size in bytes needed to encode a uint32_t value as a varint
*

View File

@@ -75,7 +75,7 @@ static void PlaintextFrame_WriteBatch5(benchmark::State &state) {
for (auto _ : state) {
for (int i = 0; i < kInnerIterations; i++) {
buffer.clear();
MessageInfo messages[5] = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
MessageInfo messages[5] = {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}};
for (int j = 0; j < 5; j++) {
uint16_t offset = buffer.size();
@@ -89,7 +89,7 @@ static void PlaintextFrame_WriteBatch5(benchmark::State &state) {
ProtoWriteBuffer writer(&buffer, offset + padding);
msg.encode(writer);
messages[j] = MessageInfo(SensorStateResponse::MESSAGE_TYPE, offset, size);
messages[j] = MessageInfo(SensorStateResponse::MESSAGE_TYPE, offset, size, padding);
}
helper->write_protobuf_messages(ProtoWriteBuffer(&buffer, 0), std::span<const MessageInfo>(messages, 5));