// main.cpp -- streaming wake-word validator with low-latency inference #include <Arduino.h> #include <ESP_I2S.h> #include <Adafruit_NeoPixel.h> #include <arduinoFFT.h> #include <math.h> #include "audio_model.h" #define AUTO_PICK_CHANNEL 1 #define FORCE_RIGHT_CHANNEL 1 // used if auto-pick disabled // Audio capture configuration ------------------------------------------------ constexpr int SAMPLE_RATE_HZ = 16000; constexpr int CLIP_SECONDS = 1; constexpr int CLIP_SAMPLES = SAMPLE_RATE_HZ * CLIP_SECONDS; // Feature extraction parameters (must match training script) constexpr int FRAME_LEN = 400; // 25 ms constexpr int HOP = 160; // 10 ms constexpr int FFT_LEN = 512; constexpr int N_BANDS = 10; constexpr int N_FRAMES = 1 + (CLIP_SAMPLES - FRAME_LEN) / HOP; static_assert(N_FRAMES * N_BANDS == KW_INPUT_DIM, "Feature dimension mismatch"); constexpr int N_BINS = (FFT_LEN / 2) + 1; constexpr int BINS_PER_BAND = N_BINS / N_BANDS; // Stream chunk configuration -------------------------------------------------- constexpr int CHUNK_FRAMES = HOP; // process every 10 ms hop constexpr size_t BYTES_PER_FRAME = 8; // 32-bit stereo (4B * 2) constexpr size_t RAW_CHUNK_BYTES = CHUNK_FRAMES * BYTES_PER_FRAME; // Inference smoothing and gating constexpr float RMS_GATE = 0.0030f; // silence gate (tune to mic noise floor) constexpr float HELLO_THRESH = 0.65f; // smoother decision threshold constexpr int SMOOTH_N = 3; constexpr int CHUNKS_PER_INFERENCE = 4; // run NN every ~40 ms constexpr uint32_t DETECT_COOLDOWN_MS = 1500; constexpr int NEOPIXEL_PIN = 3; constexpr int NEOPIXEL_COUNT = 1; // ----------------------------------------------------------------------------- struct FeatureSummary { float min; float max; float mean; }; struct ModelResult { float prob; float logit_off; float logit_on; }; static I2SClass g_i2s; static arduinoFFT g_fft; static Adafruit_NeoPixel g_pixel(NEOPIXEL_COUNT, NEOPIXEL_PIN, NEO_GRB + NEO_KHZ800); // Ring buffer to hold the most recent CLIP_SAMPLES samples static int16_t g_ring[CLIP_SAMPLES]; static size_t g_ringWrite = 0; static size_t g_ringCount = 0; // Working buffers static int16_t g_pcm[CLIP_SAMPLES]; static float g_feat[KW_INPUT_DIM]; static int16_t g_chunkMono[CHUNK_FRAMES]; static int16_t g_chunkL[CHUNK_FRAMES]; static int16_t g_chunkR[CHUNK_FRAMES]; static uint8_t g_rawChunk[RAW_CHUNK_BYTES]; static int16_t g_frameBuffer[FRAME_LEN]; static float g_frameFeatures[N_FRAMES][N_BANDS]; static size_t g_frameStart = 0; static size_t g_frameCount = 0; static float g_hann[FRAME_LEN]; static double g_fftReal[FFT_LEN]; static double g_fftImag[FFT_LEN]; static bool g_channelChosen = false; static bool g_useRight = true; static bool g_dumpedRaw = false; // ----------------------------------------------------------------------------- static void initHann() { for (int i = 0; i < FRAME_LEN; ++i) { g_hann[i] = 0.5f - 0.5f * cosf(2.0f * PI * i / (FRAME_LEN - 1)); } } static void ledOff() { g_pixel.setPixelColor(0, g_pixel.Color(0, 0, 0)); g_pixel.show(); } static void ledBlink() { g_pixel.setPixelColor(0, g_pixel.Color(0, 0, 255)); g_pixel.show(); delay(100); ledOff(); } static float computeRms(const int16_t* data, size_t count) { double acc = 0.0; for (size_t i = 0; i < count; ++i) { double v = static_cast<double>(data[i]); acc += v * v; } acc = (count > 0) ? acc / static_cast<double>(count) : 0.0; return static_cast<float>(sqrt(acc)) / 32768.0f; } static void chooseChannel(double energyL, double energyR) { constexpr double EPS = 1e-9; bool pickRight = (energyR + EPS >= energyL); if (!g_channelChosen || pickRight != g_useRight) { Serial.print("# Auto channel -> "); Serial.println(pickRight ? "RIGHT" : "LEFT"); } g_useRight = pickRight; g_channelChosen = true; } static bool readChunkMono(int16_t* dst, int frames) { const size_t wantBytes = static_cast<size_t>(frames) * BYTES_PER_FRAME; size_t got = g_i2s.readBytes(reinterpret_cast<char*>(g_rawChunk), wantBytes); if (got != wantBytes) { Serial.printf("! readBytes short: got %u need %u\n", static_cast<unsigned>(got), static_cast<unsigned>(wantBytes)); return false; } const int32_t* src = reinterpret_cast<const int32_t*>(g_rawChunk); double energyL = 0.0; double energyR = 0.0; for (int i = 0; i < frames; ++i) { int32_t Lraw = src[2 * i + 0]; int32_t Rraw = src[2 * i + 1]; int16_t L = static_cast<int16_t>(Lraw >> 16); int16_t R = static_cast<int16_t>(Rraw >> 16); g_chunkL[i] = L; g_chunkR[i] = R; energyL += static_cast<double>(L) * L; energyR += static_cast<double>(R) * R; } #if AUTO_PICK_CHANNEL chooseChannel(energyL, energyR); #endif if (!g_dumpedRaw) { g_dumpedRaw = true; Serial.println("# raw32 dump (first 16 frames):"); for (int i = 0; i < min(frames, 16); ++i) { int32_t Lraw = src[2 * i + 0]; int32_t Rraw = src[2 * i + 1]; Serial.printf(" [%02d] L=0x%08lx R=0x%08lx\n", i, static_cast<unsigned long>(Lraw), static_cast<unsigned long>(Rraw)); } Serial.printf("# channel energy L=%.1f R=%.1f\n", energyL, energyR); } const int16_t* chosen = nullptr; #if AUTO_PICK_CHANNEL chosen = g_useRight ? g_chunkR : g_chunkL; #else chosen = FORCE_RIGHT_CHANNEL ? g_chunkR : g_chunkL; #endif for (int i = 0; i < frames; ++i) { dst[i] = chosen[i]; } return true; } static void ringPush(const int16_t* samples, int frames) { for (int i = 0; i < frames; ++i) { g_ring[g_ringWrite] = samples[i]; g_ringWrite = (g_ringWrite + 1) % CLIP_SAMPLES; if (g_ringCount < static_cast<size_t>(CLIP_SAMPLES)) { g_ringCount++; } } } static bool ringCopyToClip(int16_t* dst) { if (g_ringCount < static_cast<size_t>(CLIP_SAMPLES)) { return false; } size_t start = g_ringWrite; for (int i = 0; i < CLIP_SAMPLES; ++i) { dst[i] = g_ring[(start + i) % CLIP_SAMPLES]; } return true; } static void computeBandsForFrame(const int16_t* frame, float* outBands) { for (int i = 0; i < FRAME_LEN; ++i) { float sample = static_cast<float>(frame[i]) / 32768.0f; g_fftReal[i] = static_cast<double>(sample * g_hann[i]); g_fftImag[i] = 0.0; } for (int i = FRAME_LEN; i < FFT_LEN; ++i) { g_fftReal[i] = 0.0; g_fftImag[i] = 0.0; } g_fft.Windowing(g_fftReal, FFT_LEN, FFT_WIN_TYP_RECTANGLE, FFT_FORWARD); g_fft.Compute(g_fftReal, g_fftImag, FFT_LEN, FFT_FORWARD); g_fft.ComplexToMagnitude(g_fftReal, g_fftImag, FFT_LEN); constexpr float EPS = 1e-10f; for (int band = 0; band < N_BANDS; ++band) { int b0 = band * BINS_PER_BAND; int b1 = (band == N_BANDS - 1) ? N_BINS : (band + 1) * BINS_PER_BAND; double acc = 0.0; int count = 0; for (int k = b0; k < b1; ++k) { double mag = g_fftReal[k]; acc += mag * mag; ++count; } float meanPower = (count > 0) ? static_cast<float>(acc / static_cast<double>(count)) : 0.0f; outBands[band] = 10.0f * log10f(meanPower + EPS); } } static bool updateFrameFeatures() { if (g_ringCount < static_cast<size_t>(FRAME_LEN)) { return false; } size_t idx = (g_ringWrite + CLIP_SAMPLES - FRAME_LEN) % CLIP_SAMPLES; for (int i = 0; i < FRAME_LEN; ++i) { g_frameBuffer[i] = g_ring[idx]; idx = (idx + 1) % CLIP_SAMPLES; } float bands[N_BANDS]; computeBandsForFrame(g_frameBuffer, bands); size_t slot; if (g_frameCount < static_cast<size_t>(N_FRAMES)) { slot = (g_frameStart + g_frameCount) % N_FRAMES; g_frameCount++; } else { slot = g_frameStart; g_frameStart = (g_frameStart + 1) % N_FRAMES; } for (int b = 0; b < N_BANDS; ++b) { g_frameFeatures[slot][b] = bands[b]; } return g_frameCount == static_cast<size_t>(N_FRAMES); } static FeatureSummary buildFeatureVector() { FeatureSummary summary; summary.min = 1e9f; summary.max = -1e9f; summary.mean = 0.0f; if (g_frameCount < static_cast<size_t>(N_FRAMES)) { summary.min = summary.max = summary.mean = 0.0f; return summary; } size_t idx = g_frameStart; int featIndex = 0; for (int f = 0; f < N_FRAMES; ++f) { const float* bands = g_frameFeatures[idx]; for (int b = 0; b < N_BANDS; ++b) { float val = bands[b]; g_feat[featIndex++] = val; summary.min = min(summary.min, val); summary.max = max(summary.max, val); summary.mean += val; } idx = (idx + 1) % N_FRAMES; } summary.mean /= static_cast<float>(N_FRAMES * N_BANDS); return summary; } static ModelResult runKeywordModel(float* features /* mutated in place */) { float hidden[KW_HIDDEN_DIM]; float logits[KW_OUTPUT_DIM]; for (int i = 0; i < KW_INPUT_DIM; ++i) { features[i] = (features[i] - kw_mu[i]) * kw_sigma_inv[i]; } for (int j = 0; j < KW_HIDDEN_DIM; ++j) { float sum = kw_b0[j]; for (int i = 0; i < KW_INPUT_DIM; ++i) { sum += kw_W0[j + i * KW_HIDDEN_DIM] * features[i]; } hidden[j] = (sum > 0.0f) ? sum : 0.0f; } for (int j = 0; j < KW_OUTPUT_DIM; ++j) { float sum = kw_b1[j]; for (int i = 0; i < KW_HIDDEN_DIM; ++i) { sum += kw_W1[j + i * KW_OUTPUT_DIM] * hidden[i]; } logits[j] = sum; } float m = max(logits[0], logits[1]); float e0 = expf(logits[0] - m); float e1 = expf(logits[1] - m); ModelResult res; res.logit_off = logits[0]; res.logit_on = logits[1]; res.prob = e1 / (e0 + e1); return res; } // ----------------------------------------------------------------------------- static void setupI2S() { g_i2s.setPins(/*bclk*/18, /*ws*/20, /*dout*/-1, /*din*/19, /*mclk*/-1); bool ok = g_i2s.begin(I2S_MODE_STD, SAMPLE_RATE_HZ, I2S_DATA_BIT_WIDTH_32BIT, I2S_SLOT_MODE_STEREO); if (!ok) Serial.println("ERR: I2S.begin failed"); ok = g_i2s.configureRX(SAMPLE_RATE_HZ, I2S_DATA_BIT_WIDTH_32BIT, I2S_SLOT_MODE_STEREO, I2S_RX_TRANSFORM_NONE); if (!ok) Serial.println("ERR: I2S.configureRX failed"); } void setup() { Serial.begin(115200); delay(2000); g_pixel.begin(); ledOff(); initHann(); setupI2S(); Serial.printf("# Wake word validator (streaming). chunk=%d frames, clip=%d samples\n", CHUNK_FRAMES, CLIP_SAMPLES); } void loop() { if (!readChunkMono(g_chunkMono, CHUNK_FRAMES)) { delay(5); return; } ringPush(g_chunkMono, CHUNK_FRAMES); bool framesReady = updateFrameFeatures(); if (g_ringCount < static_cast<size_t>(CLIP_SAMPLES) || !framesReady) { return; } static int chunkAccumulator = 0; if (++chunkAccumulator < CHUNKS_PER_INFERENCE) { return; } chunkAccumulator = 0; if (!ringCopyToClip(g_pcm)) { return; } static bool dumpedPcm = false; if (!dumpedPcm) { dumpedPcm = true; Serial.print("# pcm[0..15]:"); for (int i = 0; i < 16; ++i) { Serial.printf(" %d", g_pcm[i]); } Serial.println(); } float rms = computeRms(g_pcm, CLIP_SAMPLES); Serial.printf("rms=%.5f\n", rms); if (rms < RMS_GATE) { Serial.println("probHello=0.0000 (gated by RMS)"); return; } FeatureSummary summary = buildFeatureVector(); Serial.printf("feat[min=%.2f max=%.2f mean=%.2f]\n", summary.min, summary.max, summary.mean); static bool printedFeatRaw = false; if (!printedFeatRaw) { printedFeatRaw = true; Serial.print("# featRaw[0..7]:"); for (int i = 0; i < 8; ++i) { Serial.printf(" %.2f", g_feat[i]); } Serial.println(); } ModelResult res = runKeywordModel(g_feat); static bool printedFeatNorm = false; if (!printedFeatNorm) { printedFeatNorm = true; Serial.print("# featNorm[0..7]:"); for (int i = 0; i < 8; ++i) { Serial.printf(" %.2f", g_feat[i]); } Serial.println(); } static float smooth[SMOOTH_N] = {0}; static int smoothIdx = 0; smooth[smoothIdx] = res.prob; smoothIdx = (smoothIdx + 1) % SMOOTH_N; float avg = 0.0f; int valid = 0; for (int i = 0; i < SMOOTH_N; ++i) { avg += smooth[i]; if (smooth[i] > 0.0f) ++valid; } if (valid > 0) { avg /= static_cast<float>(valid); } else { avg = res.prob; } Serial.printf("logits[off=%.3f on=%.3f] prob=%.4f avg=%.4f\n", res.logit_off, res.logit_on, res.prob, avg); static uint32_t lastDetectMs = 0; uint32_t nowMs = millis(); if (avg > HELLO_THRESH && (nowMs - lastDetectMs) > DETECT_COOLDOWN_MS) { ledBlink(); lastDetectMs = nowMs; } }