diff --git a/esphome/components/micro_wake_word/__init__.py b/esphome/components/micro_wake_word/__init__.py index 38926fce99..61d296cbba 100644 --- a/esphome/components/micro_wake_word/__init__.py +++ b/esphome/components/micro_wake_word/__init__.py @@ -7,7 +7,7 @@ from urllib.parse import urljoin from esphome import automation, external_files, git from esphome.automation import register_action, register_condition import esphome.codegen as cg -from esphome.components import esp32, microphone, ota +from esphome.components import esp32, microphone, ota, psram import esphome.config_validation as cv from esphome.const import ( CONF_FILE, @@ -20,6 +20,7 @@ from esphome.const import ( CONF_RAW_DATA_ID, CONF_REF, CONF_REFRESH, + CONF_TASK_STACK_IN_PSRAM, CONF_TYPE, CONF_URL, CONF_USERNAME, @@ -358,6 +359,7 @@ CONFIG_SCHEMA = cv.All( ), cv.Optional(CONF_VAD): _maybe_empty_vad_schema, cv.Optional(CONF_STOP_AFTER_DETECTION, default=True): cv.boolean, + cv.Optional(CONF_TASK_STACK_IN_PSRAM): psram.validate_task_stack_in_psram, cv.Optional(CONF_MODEL): cv.invalid( f"The {CONF_MODEL} parameter has moved to be a list element under the {CONF_MODELS} parameter." ), @@ -451,6 +453,10 @@ async def to_code(config): cg.add_define("USE_MICRO_WAKE_WORD") ota.request_ota_state_listeners() + if config.get(CONF_TASK_STACK_IN_PSRAM): + cg.add(var.set_task_stack_in_psram(True)) + psram.request_external_task_stack() + esp32.add_idf_component(name="espressif/esp-tflite-micro", ref="1.3.3~1") # Pin esp-nn for stable future builds (esp-tflite-micro depends on esp-nn) esp32.add_idf_component(name="espressif/esp-nn", ref="1.1.2") diff --git a/esphome/components/micro_wake_word/micro_wake_word.cpp b/esphome/components/micro_wake_word/micro_wake_word.cpp index 739d64dc28..237d72229d 100644 --- a/esphome/components/micro_wake_word/micro_wake_word.cpp +++ b/esphome/components/micro_wake_word/micro_wake_word.cpp @@ -217,10 +217,7 @@ void MicroWakeWord::inference_task(void *params) { FrontendFreeStateContents(&this_mww->frontend_state_); xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPED); - while (true) { - // Continuously delay until the main loop deletes the task - delay(10); - } + vTaskSuspend(nullptr); // Suspend this task indefinitely until the loop method deletes it } std::vector MicroWakeWord::get_wake_words() { @@ -243,14 +240,14 @@ void MicroWakeWord::add_vad_model(const uint8_t *model_start, uint8_t probabilit #endif void MicroWakeWord::suspend_task_() { - if (this->inference_task_handle_ != nullptr) { - vTaskSuspend(this->inference_task_handle_); + if (this->inference_task_.is_created()) { + vTaskSuspend(this->inference_task_.get_handle()); } } void MicroWakeWord::resume_task_() { - if (this->inference_task_handle_ != nullptr) { - vTaskResume(this->inference_task_handle_); + if (this->inference_task_.is_created()) { + vTaskResume(this->inference_task_.get_handle()); } } @@ -292,8 +289,7 @@ void MicroWakeWord::loop() { if ((event_group_bits & EventGroupBits::TASK_STOPPED)) { ESP_LOGD(TAG, "Inference task is finished, freeing task resources"); - vTaskDelete(this->inference_task_handle_); - this->inference_task_handle_ = nullptr; + this->inference_task_.deallocate(); xEventGroupClearBits(this->event_group_, ALL_BITS); xQueueReset(this->detection_queue_); this->set_state_(State::STOPPED); @@ -311,7 +307,7 @@ void MicroWakeWord::loop() { switch (this->state_) { case State::STARTING: - if ((this->inference_task_handle_ == nullptr) && !this->status_has_error()) { + if (!this->inference_task_.is_created() && !this->status_has_error()) { // Setup preprocesor feature generator. If done in the task, it would lock the task to its initial core, as it // uses floating point operations. if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_, @@ -320,10 +316,8 @@ void MicroWakeWord::loop() { return; } - xTaskCreate(MicroWakeWord::inference_task, "mww", INFERENCE_TASK_STACK_SIZE, (void *) this, - INFERENCE_TASK_PRIORITY, &this->inference_task_handle_); - - if (this->inference_task_handle_ == nullptr) { + if (!this->inference_task_.create(MicroWakeWord::inference_task, "mww", INFERENCE_TASK_STACK_SIZE, + (void *) this, INFERENCE_TASK_PRIORITY, this->task_stack_in_psram_)) { FrontendFreeStateContents(&this->frontend_state_); // Deallocate frontend state this->status_momentary_error("task_start", 1000); } diff --git a/esphome/components/micro_wake_word/micro_wake_word.h b/esphome/components/micro_wake_word/micro_wake_word.h index ef440b5d37..e4c590a423 100644 --- a/esphome/components/micro_wake_word/micro_wake_word.h +++ b/esphome/components/micro_wake_word/micro_wake_word.h @@ -11,6 +11,7 @@ #include "esphome/core/automation.h" #include "esphome/core/component.h" #include "esphome/core/defines.h" +#include "esphome/core/static_task.h" #ifdef USE_OTA_STATE_LISTENER #include "esphome/components/ota/ota_backend.h" @@ -59,6 +60,8 @@ class MicroWakeWord : public Component void set_stop_after_detection(bool stop_after_detection) { this->stop_after_detection_ = stop_after_detection; } + void set_task_stack_in_psram(bool task_stack_in_psram) { this->task_stack_in_psram_ = task_stack_in_psram; } + Trigger *get_wake_word_detected_trigger() { return &this->wake_word_detected_trigger_; } void add_wake_word_model(WakeWordModel *model); @@ -93,6 +96,8 @@ class MicroWakeWord : public Component bool stop_after_detection_; + bool task_stack_in_psram_{false}; + uint8_t features_step_size_; // Audio frontend handles generating spectrogram features @@ -105,8 +110,9 @@ class MicroWakeWord : public Component // Used to send messages about the models' states to the main loop QueueHandle_t detection_queue_; + StaticTask inference_task_; + static void inference_task(void *params); - TaskHandle_t inference_task_handle_{nullptr}; /// @brief Suspends the inference task void suspend_task_(); diff --git a/tests/components/micro_wake_word/common.yaml b/tests/components/micro_wake_word/common.yaml index c051c8dd57..cd060c176e 100644 --- a/tests/components/micro_wake_word/common.yaml +++ b/tests/components/micro_wake_word/common.yaml @@ -1,3 +1,6 @@ +psram: + mode: quad + i2s_audio: i2s_lrclk_pin: GPIO18 i2s_bclk_pin: GPIO19 @@ -12,6 +15,7 @@ microphone: micro_wake_word: microphone: echo_microphone + task_stack_in_psram: true on_wake_word_detected: - logger.log: "Wake word detected" - micro_wake_word.stop: