// File: main.cpp
// main.cpp -- streaming wake-word validator with low-latency inference
#include <Arduino.h>
#include <ESP_I2S.h>
#include <Adafruit_NeoPixel.h>
#include <arduinoFFT.h>
#include <math.h>
#include <string.h>
#include "audio_model.h"
 
// Channel selection. With AUTO_PICK_CHANNEL non-zero, readChunkMono() compares
// the energy of the two I2S slots on every chunk and forwards the louder one.
// With it set to 0, FORCE_RIGHT_CHANNEL alone decides (non-zero = right slot).
#define AUTO_PICK_CHANNEL   1
#define FORCE_RIGHT_CHANNEL 1  // used if auto-pick disabled

// Audio capture configuration ------------------------------------------------
constexpr int SAMPLE_RATE_HZ = 16000;
constexpr int CLIP_SECONDS   = 1;
// Number of mono samples in the analysis window (16000 with the values above).
constexpr int CLIP_SAMPLES   = SAMPLE_RATE_HZ * CLIP_SECONDS;

// Feature extraction parameters (must match training script)
constexpr int FRAME_LEN = 400;   // 25 ms
constexpr int HOP       = 160;   // 10 ms
constexpr int FFT_LEN   = 512;   // frames are zero-padded from FRAME_LEN up to this length
constexpr int N_BANDS   = 10;    // band-energy features produced per frame
// Frames that fit in one clip at the given hop (98 with the values above).
constexpr int N_FRAMES  = 1 + (CLIP_SAMPLES - FRAME_LEN) / HOP;
// The flattened feature vector must be exactly as long as the model input.
static_assert(N_FRAMES * N_BANDS == KW_INPUT_DIM, "Feature dimension mismatch");
// Non-redundant FFT bins of a real signal: DC through Nyquist (257 here).
constexpr int N_BINS = (FFT_LEN / 2) + 1;
// Integer division (25 here); computeBandsForFrame() gives the leftover bins
// to the last band.
constexpr int BINS_PER_BAND = N_BINS / N_BANDS;

// Stream chunk configuration --------------------------------------------------
constexpr int   CHUNK_FRAMES         = HOP;      // process every 10 ms hop
constexpr size_t BYTES_PER_FRAME     = 8;        // 32-bit stereo (4B * 2)
// Bytes requested from I2S per chunk (1280 with the values above).
constexpr size_t RAW_CHUNK_BYTES     = CHUNK_FRAMES * BYTES_PER_FRAME;

// Inference smoothing and gating
constexpr float RMS_GATE     = 0.0030f;   // silence gate (tune to mic noise floor)
constexpr float HELLO_THRESH = 0.65f;     // smoother decision threshold
constexpr int   SMOOTH_N     = 3;         // number of recent probabilities averaged in loop()
constexpr int   CHUNKS_PER_INFERENCE = 4; // run NN every ~40 ms
// Minimum time between two LED detections, compared against millis() in loop().
constexpr uint32_t DETECT_COOLDOWN_MS = 1500;

// Status LED wiring: data pin and number of pixels handed to Adafruit_NeoPixel.
constexpr int NEOPIXEL_PIN   = 3;
constexpr int NEOPIXEL_COUNT = 1;
 
// -----------------------------------------------------------------------------
 
// Statistics over the flattened feature vector, filled by buildFeatureVector()
// and printed by loop() for diagnostics. All three fields are 0 when the frame
// history is not yet full.
struct FeatureSummary {
  float min;   // smallest feature value
  float max;   // largest feature value
  float mean;  // arithmetic mean of all feature values
};

// Output of runKeywordModel().
struct ModelResult {
  float prob;       // softmax probability of output 1 (the "on" class)
  float logit_off;  // raw network output 0
  float logit_on;   // raw network output 1
};
 
// Peripheral / library objects.
static I2SClass           g_i2s;    // I2S receiver, configured in setupI2S()
static arduinoFFT         g_fft;    // FFT helper used by computeBandsForFrame()
static Adafruit_NeoPixel  g_pixel(NEOPIXEL_COUNT, NEOPIXEL_PIN, NEO_GRB + NEO_KHZ800);

// Ring buffer to hold the most recent CLIP_SAMPLES samples
static int16_t g_ring[CLIP_SAMPLES];
static size_t  g_ringWrite = 0;  // index the next sample is written to
static size_t  g_ringCount = 0;  // samples stored so far, saturates at CLIP_SAMPLES

// Working buffers
static int16_t g_pcm[CLIP_SAMPLES];        // ring contents copied out oldest-first (ringCopyToClip)
static float   g_feat[KW_INPUT_DIM];       // flattened features; normalized in place by runKeywordModel()
static int16_t g_chunkMono[CHUNK_FRAMES];  // mono samples of the latest chunk
static int16_t g_chunkL[CHUNK_FRAMES];     // first I2S slot of the latest chunk, top 16 bits
static int16_t g_chunkR[CHUNK_FRAMES];     // second I2S slot of the latest chunk, top 16 bits
static uint8_t g_rawChunk[RAW_CHUNK_BYTES];// raw bytes of the latest I2S read
static int16_t g_frameBuffer[FRAME_LEN];   // newest FRAME_LEN samples, input to the FFT
// Circular history of per-frame band features; g_frameStart is the oldest row
// once the history is full, g_frameCount the number of rows in use.
static float   g_frameFeatures[N_FRAMES][N_BANDS];
static size_t  g_frameStart = 0;
static size_t  g_frameCount = 0;

static float   g_hann[FRAME_LEN];    // window coefficients, filled by initHann()
static double  g_fftReal[FFT_LEN];   // FFT input (real part); holds magnitudes after the transform
static double  g_fftImag[FFT_LEN];   // FFT imaginary part

static bool g_channelChosen = false; // true once chooseChannel() has run at least once
static bool g_useRight      = true;  // current auto-pick result (true = second slot)
static bool g_dumpedRaw     = false; // true once the one-time raw I2S dump was printed
 
// -----------------------------------------------------------------------------
// Fills g_hann with a FRAME_LEN-point Hann window. The symmetric form is used
// (denominator FRAME_LEN - 1), so the taper is ~0 at both ends of the frame.
static void initHann() {
  int n = 0;
  while (n < FRAME_LEN) {
    const auto phase = 2.0f * PI * n / (FRAME_LEN - 1);
    g_hann[n] = 0.5f - 0.5f * cosf(phase);
    ++n;
  }
}
 
// Switches pixel 0 of the status LED off and latches the change.
static void ledOff() {
  const auto black = g_pixel.Color(0, 0, 0);
  g_pixel.setPixelColor(0, black);
  g_pixel.show();
}
 
// Flashes pixel 0 blue for 100 ms, then turns it off again.
// The call blocks for the duration of the flash.
static void ledBlink() {
  const auto blue = g_pixel.Color(0, 0, 255);
  g_pixel.setPixelColor(0, blue);
  g_pixel.show();

  delay(100);

  ledOff();
}
 
// Returns the root-mean-square level of `count` 16-bit samples, scaled so that
// a full-scale value of 32768 maps to 1.0. An empty input yields 0.
static float computeRms(const int16_t* data, size_t count) {
  if (count == 0) {
    return 0.0f;
  }

  double sumSquares = 0.0;
  const int16_t* const end = data + count;
  for (const int16_t* p = data; p != end; ++p) {
    const double s = static_cast<double>(*p);
    sumSquares += s * s;
  }

  const double meanSquare = sumSquares / static_cast<double>(count);
  return static_cast<float>(sqrt(meanSquare)) / 32768.0f;
}
 
// Selects the louder I2S slot for mono extraction and records the choice in
// g_useRight. Ties (within a tiny epsilon) go to the right slot. A log line is
// printed on the first call and whenever the selection changes.
static void chooseChannel(double energyL, double energyR) {
  constexpr double kTieBias = 1e-9;
  const bool rightWins = (energyL <= energyR + kTieBias);

  const bool firstDecision = !g_channelChosen;
  if (firstDecision || rightWins != g_useRight) {
    Serial.print("# Auto channel -> ");
    if (rightWins) {
      Serial.println("RIGHT");
    } else {
      Serial.println("LEFT");
    }
  }

  g_channelChosen = true;
  g_useRight = rightWins;
}
 
static bool readChunkMono(int16_t* dst, int frames) {
  const size_t wantBytes = static_cast<size_t>(frames) * BYTES_PER_FRAME;
  size_t got = g_i2s.readBytes(reinterpret_cast<char*>(g_rawChunk), wantBytes);
  if (got != wantBytes) {
    Serial.printf("! readBytes short: got %u need %u\n", static_cast<unsigned>(got), static_cast<unsigned>(wantBytes));
    return false;
  }
 
  const int32_t* src = reinterpret_cast<const int32_t*>(g_rawChunk);
  double energyL = 0.0;
  double energyR = 0.0;
  for (int i = 0; i < frames; ++i) {
    int32_t Lraw = src[2 * i + 0];
    int32_t Rraw = src[2 * i + 1];
    int16_t L = static_cast<int16_t>(Lraw >> 16);
    int16_t R = static_cast<int16_t>(Rraw >> 16);
    g_chunkL[i] = L;
    g_chunkR[i] = R;
    energyL += static_cast<double>(L) * L;
    energyR += static_cast<double>(R) * R;
  }
 
#if AUTO_PICK_CHANNEL
  chooseChannel(energyL, energyR);
#endif
 
  if (!g_dumpedRaw) {
    g_dumpedRaw = true;
    Serial.println("# raw32 dump (first 16 frames):");
    for (int i = 0; i < min(frames, 16); ++i) {
      int32_t Lraw = src[2 * i + 0];
      int32_t Rraw = src[2 * i + 1];
      Serial.printf("  [%02d] L=0x%08lx R=0x%08lx\n",
                    i,
                    static_cast<unsigned long>(Lraw),
                    static_cast<unsigned long>(Rraw));
    }
    Serial.printf("# channel energy L=%.1f R=%.1f\n", energyL, energyR);
  }
 
  const int16_t* chosen = nullptr;
#if AUTO_PICK_CHANNEL
  chosen = g_useRight ? g_chunkR : g_chunkL;
#else
  chosen = FORCE_RIGHT_CHANNEL ? g_chunkR : g_chunkL;
#endif
 
  for (int i = 0; i < frames; ++i) {
    dst[i] = chosen[i];
  }
  return true;
}
 
// Appends `frames` samples to the circular buffer g_ring, overwriting the
// oldest data once the buffer holds CLIP_SAMPLES samples.
static void ringPush(const int16_t* samples, int frames) {
  constexpr size_t kCapacity = static_cast<size_t>(CLIP_SAMPLES);

  int pushed = 0;
  while (pushed < frames) {
    g_ring[g_ringWrite] = samples[pushed];
    g_ringWrite = (g_ringWrite + 1) % kCapacity;
    // The fill level saturates at the buffer capacity.
    if (g_ringCount < kCapacity) {
      ++g_ringCount;
    }
    ++pushed;
  }
}
 
// Copies the full ring buffer into `dst` in chronological order (oldest sample
// first). Returns false without touching `dst` until CLIP_SAMPLES samples have
// been collected.
static bool ringCopyToClip(int16_t* dst) {
  const size_t capacity = static_cast<size_t>(CLIP_SAMPLES);
  if (g_ringCount >= capacity) {
    // When the ring is full, the write cursor points at the oldest sample.
    size_t readPos = g_ringWrite;
    for (size_t n = 0; n < capacity; ++n) {
      dst[n] = g_ring[readPos % capacity];
      ++readPos;
    }
    return true;
  }
  return false;
}
 
static void computeBandsForFrame(const int16_t* frame, float* outBands) {
  for (int i = 0; i < FRAME_LEN; ++i) {
    float sample = static_cast<float>(frame[i]) / 32768.0f;
    g_fftReal[i] = static_cast<double>(sample * g_hann[i]);
    g_fftImag[i] = 0.0;
  }
  for (int i = FRAME_LEN; i < FFT_LEN; ++i) {
    g_fftReal[i] = 0.0;
    g_fftImag[i] = 0.0;
  }
 
  g_fft.Windowing(g_fftReal, FFT_LEN, FFT_WIN_TYP_RECTANGLE, FFT_FORWARD);
  g_fft.Compute(g_fftReal, g_fftImag, FFT_LEN, FFT_FORWARD);
  g_fft.ComplexToMagnitude(g_fftReal, g_fftImag, FFT_LEN);
 
  constexpr float EPS = 1e-10f;
  for (int band = 0; band < N_BANDS; ++band) {
    int b0 = band * BINS_PER_BAND;
    int b1 = (band == N_BANDS - 1) ? N_BINS : (band + 1) * BINS_PER_BAND;
    double acc = 0.0;
    int count = 0;
    for (int k = b0; k < b1; ++k) {
      double mag = g_fftReal[k];
      acc += mag * mag;
      ++count;
    }
    float meanPower = (count > 0) ? static_cast<float>(acc / static_cast<double>(count)) : 0.0f;
    outBands[band] = 10.0f * log10f(meanPower + EPS);
  }
}
 
static bool updateFrameFeatures() {
  if (g_ringCount < static_cast<size_t>(FRAME_LEN)) {
    return false;
  }
 
  size_t idx = (g_ringWrite + CLIP_SAMPLES - FRAME_LEN) % CLIP_SAMPLES;
  for (int i = 0; i < FRAME_LEN; ++i) {
    g_frameBuffer[i] = g_ring[idx];
    idx = (idx + 1) % CLIP_SAMPLES;
  }
 
  float bands[N_BANDS];
  computeBandsForFrame(g_frameBuffer, bands);
 
  size_t slot;
  if (g_frameCount < static_cast<size_t>(N_FRAMES)) {
    slot = (g_frameStart + g_frameCount) % N_FRAMES;
    g_frameCount++;
  } else {
    slot = g_frameStart;
    g_frameStart = (g_frameStart + 1) % N_FRAMES;
  }
 
  for (int b = 0; b < N_BANDS; ++b) {
    g_frameFeatures[slot][b] = bands[b];
  }
  return g_frameCount == static_cast<size_t>(N_FRAMES);
}
 
// Flattens the frame history into g_feat (oldest frame first, N_BANDS values
// per frame) and returns min / max / mean of the copied values.
// If the history does not yet hold N_FRAMES frames, g_feat is left untouched
// and an all-zero summary is returned.
static FeatureSummary buildFeatureVector() {
  FeatureSummary stats;

  if (g_frameCount < static_cast<size_t>(N_FRAMES)) {
    stats.min = 0.0f;
    stats.max = 0.0f;
    stats.mean = 0.0f;
    return stats;
  }

  stats.min = 1e9f;
  stats.max = -1e9f;
  stats.mean = 0.0f;  // used as a running sum until the final division

  float* out = g_feat;
  size_t row = g_frameStart;
  for (int frame = 0; frame < N_FRAMES; ++frame) {
    for (int band = 0; band < N_BANDS; ++band) {
      const float value = g_frameFeatures[row][band];
      *out++ = value;
      stats.min = min(stats.min, value);
      stats.max = max(stats.max, value);
      stats.mean += value;
    }
    row = (row + 1) % N_FRAMES;
  }

  stats.mean /= static_cast<float>(N_FRAMES * N_BANDS);
  return stats;
}
 
// Runs the two-layer keyword classifier on `features` (KW_INPUT_DIM values).
//
// Steps: standardize each input with kw_mu / kw_sigma_inv (this overwrites
// `features`), dense layer + ReLU into KW_HIDDEN_DIM units, dense layer into
// KW_OUTPUT_DIM logits, then a two-class softmax over logits 0 and 1.
// Weight matrices are indexed input-major: W[out + in * n_out].
// Reads logits[0] and logits[1], so KW_OUTPUT_DIM must be at least 2.
static ModelResult runKeywordModel(float* features /* mutated in place */) {
  // Input standardization, in place.
  for (int in = 0; in < KW_INPUT_DIM; ++in) {
    features[in] = (features[in] - kw_mu[in]) * kw_sigma_inv[in];
  }

  // Hidden layer with ReLU activation.
  float hidden[KW_HIDDEN_DIM];
  for (int unit = 0; unit < KW_HIDDEN_DIM; ++unit) {
    float acc = kw_b0[unit];
    for (int in = 0; in < KW_INPUT_DIM; ++in) {
      acc += kw_W0[in * KW_HIDDEN_DIM + unit] * features[in];
    }
    if (acc > 0.0f) {
      hidden[unit] = acc;
    } else {
      hidden[unit] = 0.0f;
    }
  }

  // Output layer (raw logits).
  float logits[KW_OUTPUT_DIM];
  for (int out = 0; out < KW_OUTPUT_DIM; ++out) {
    float acc = kw_b1[out];
    for (int unit = 0; unit < KW_HIDDEN_DIM; ++unit) {
      acc += kw_W1[unit * KW_OUTPUT_DIM + out] * hidden[unit];
    }
    logits[out] = acc;
  }

  // Softmax with the maximum subtracted so expf cannot overflow.
  const float peak = max(logits[0], logits[1]);
  const float expOff = expf(logits[0] - peak);
  const float expOn  = expf(logits[1] - peak);

  ModelResult result;
  result.prob = expOn / (expOff + expOn);
  result.logit_off = logits[0];
  result.logit_on  = logits[1];
  return result;
}
 
// -----------------------------------------------------------------------------
static void setupI2S() {
  g_i2s.setPins(/*bclk*/18, /*ws*/20, /*dout*/-1, /*din*/19, /*mclk*/-1);
  bool ok = g_i2s.begin(I2S_MODE_STD, SAMPLE_RATE_HZ, I2S_DATA_BIT_WIDTH_32BIT, I2S_SLOT_MODE_STEREO);
  if (!ok) Serial.println("ERR: I2S.begin failed");
 
  ok = g_i2s.configureRX(SAMPLE_RATE_HZ,
                         I2S_DATA_BIT_WIDTH_32BIT,
                         I2S_SLOT_MODE_STEREO,
                         I2S_RX_TRANSFORM_NONE);
  if (!ok) Serial.println("ERR: I2S.configureRX failed");
}
 
void setup() {
  Serial.begin(115200);
  delay(2000);
 
  g_pixel.begin();
  ledOff();
  initHann();
  setupI2S();
 
  Serial.printf("# Wake word validator (streaming). chunk=%d frames, clip=%d samples\n",
                CHUNK_FRAMES, CLIP_SAMPLES);
}
 
void loop() {
  if (!readChunkMono(g_chunkMono, CHUNK_FRAMES)) {
    delay(5);
    return;
  }
 
  ringPush(g_chunkMono, CHUNK_FRAMES);
  bool framesReady = updateFrameFeatures();
 
  if (g_ringCount < static_cast<size_t>(CLIP_SAMPLES) || !framesReady) {
    return;
  }
 
  static int chunkAccumulator = 0;
  if (++chunkAccumulator < CHUNKS_PER_INFERENCE) {
    return;
  }
  chunkAccumulator = 0;
 
  if (!ringCopyToClip(g_pcm)) {
    return;
  }
 
  static bool dumpedPcm = false;
  if (!dumpedPcm) {
    dumpedPcm = true;
    Serial.print("# pcm[0..15]:");
    for (int i = 0; i < 16; ++i) {
      Serial.printf(" %d", g_pcm[i]);
    }
    Serial.println();
  }
 
  float rms = computeRms(g_pcm, CLIP_SAMPLES);
  Serial.printf("rms=%.5f\n", rms);
  if (rms < RMS_GATE) {
    Serial.println("probHello=0.0000 (gated by RMS)");
    return;
  }
 
  FeatureSummary summary = buildFeatureVector();
  Serial.printf("feat[min=%.2f max=%.2f mean=%.2f]\n", summary.min, summary.max, summary.mean);
 
  static bool printedFeatRaw = false;
  if (!printedFeatRaw) {
    printedFeatRaw = true;
    Serial.print("# featRaw[0..7]:");
    for (int i = 0; i < 8; ++i) {
      Serial.printf(" %.2f", g_feat[i]);
    }
    Serial.println();
  }
 
  ModelResult res = runKeywordModel(g_feat);
 
  static bool printedFeatNorm = false;
  if (!printedFeatNorm) {
    printedFeatNorm = true;
    Serial.print("# featNorm[0..7]:");
    for (int i = 0; i < 8; ++i) {
      Serial.printf(" %.2f", g_feat[i]);
    }
    Serial.println();
  }
 
  static float smooth[SMOOTH_N] = {0};
  static int smoothIdx = 0;
  smooth[smoothIdx] = res.prob;
  smoothIdx = (smoothIdx + 1) % SMOOTH_N;
 
  float avg = 0.0f;
  int valid = 0;
  for (int i = 0; i < SMOOTH_N; ++i) {
    avg += smooth[i];
    if (smooth[i] > 0.0f) ++valid;
  }
  if (valid > 0) {
    avg /= static_cast<float>(valid);
  } else {
    avg = res.prob;
  }
 
  Serial.printf("logits[off=%.3f on=%.3f] prob=%.4f avg=%.4f\n",
                res.logit_off, res.logit_on, res.prob, avg);
 
  static uint32_t lastDetectMs = 0;
  uint32_t nowMs = millis();
  if (avg > HELLO_THRESH && (nowMs - lastDetectMs) > DETECT_COOLDOWN_MS) {
    ledBlink();
    lastDetectMs = nowMs;
  }
}
// Source: iothings/laboratoare/2025_code/lab6_7.txt · Last modified: 2025/11/01 21:14 by dan.tudose
// License: CC Attribution-Share Alike 3.0 Unported
// Wiki page footer: www.chimeric.de Valid CSS Driven by DokuWiki do yourself a favour and use a real browser - get firefox!! Recent changes RSS feed Valid XHTML 1.0