#!/usr/bin/env python3
# scripts/train_kw_exact_norm.py
import os, glob, wave
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

# Input folders of WAV clips, one per class: clips under HELLO_DIR are
# labelled 1 (keyword) and clips under OTHER_DIR are labelled 0.
HELLO_DIR = "data/hello"
OTHER_DIR = "data/other"
# Generated C header that receives the normalization statistics and weights.
OUT_HDR   = "include/audio_model.h"

# Every clip must be recorded at SAMPLE_RATE and is zero-padded or truncated
# to exactly CLIP_SAMPLES samples before feature extraction.
SAMPLE_RATE = 16000
CLIP_SECONDS = 1.0
CLIP_SAMPLES = int(SAMPLE_RATE * CLIP_SECONDS)

# Framing / spectral feature parameters (all lengths are in samples).
FRAME_LEN = 400    # analysis frame length: 25 ms at 16 kHz
HOP       = 160    # distance between frame starts: 10 ms at 16 kHz
FFT_LEN   = 512    # rfft size; each frame is zero-padded up to this length
N_BANDS   = 10     # number of averaged-power bands kept per frame

def read_wav_any(fn):
    """Load a 16- or 32-bit PCM WAV file as a fixed-length mono float32 clip.

    Channel handling:
      * mono   -> used as is
      * stereo -> the channel with the larger RMS is kept (the original
                  author notes this matches the firmware)
      * >2 ch  -> channels are averaged

    32-bit samples are reduced to their top 16 bits, so both widths end up
    on the same [-1, 1) scale after dividing by 32768.

    Args:
        fn: path of the WAV file.

    Returns:
        A new 1-D float32 array of exactly CLIP_SAMPLES samples; shorter
        recordings are zero-padded at the end, longer ones are truncated.

    Raises:
        AssertionError: if the sample rate is not SAMPLE_RATE or the sample
            width is neither 16 nor 32 bits.
    """
    with wave.open(fn, "rb") as wf:
        sr = wf.getframerate()
        ch = wf.getnchannels()
        sw = wf.getsampwidth()
        raw = wf.readframes(wf.getnframes())

    # Raised explicitly instead of using `assert`, which is removed under
    # `python -O` and would let a wrong-rate file through silently. The
    # exception type is kept so existing callers see the same error.
    if sr != SAMPLE_RATE:
        raise AssertionError(f"{fn}: expected {SAMPLE_RATE} Hz, got {sr}")

    if sw == 2:    # PCM16
        pcm = np.frombuffer(raw, dtype="<i2")
    elif sw == 4:  # PCM32 -> keep the 16 most significant bits
        pcm = np.frombuffer(raw, dtype="<i4") >> 16
    else:
        raise AssertionError(f"{fn}: unsupported sample width {sw*8} bits")

    # De-interleave into [n_frames, ch] and scale to [-1, 1).
    chans = pcm.reshape(-1, ch).astype(np.float32) / 32768.0

    if ch == 1 or chans.shape[0] == 0:
        # An empty recording has no loudness to compare; it simply becomes
        # an all-zero clip below instead of producing NaN RMS values.
        mono = chans[:, 0]
    elif ch == 2:
        rms_left = np.sqrt(np.mean(chans[:, 0] ** 2))
        rms_right = np.sqrt(np.mean(chans[:, 1] ** 2))
        mono = chans[:, 1] if rms_right > rms_left else chans[:, 0]
    else:
        mono = chans.mean(axis=1)

    # Copy into a fresh fixed-size buffer: pads with zeros or truncates, and
    # guarantees a contiguous, writable float32 result.
    clip = np.zeros(CLIP_SAMPLES, dtype=np.float32)
    n_keep = min(mono.shape[0], CLIP_SAMPLES)
    clip[:n_keep] = mono[:n_keep]
    return clip

def stft_frames(x):
    """Cut a clip into overlapping, Hann-windowed analysis frames.

    Frame starts are spaced HOP samples apart over the first CLIP_SAMPLES
    samples of ``x``; only frames that fit completely are produced.

    Args:
        x: 1-D sample array holding at least CLIP_SAMPLES samples.

    Returns:
        Array of shape [T, FRAME_LEN], one windowed frame per row.
    """
    window = np.hanning(FRAME_LEN).astype(np.float32)
    last_start = CLIP_SAMPLES - FRAME_LEN
    windowed = [
        x[offset:offset + FRAME_LEN] * window
        for offset in range(0, last_start + 1, HOP)
    ]
    return np.stack(windowed, axis=0)  # [T, FRAME_LEN]

def band_energy_feats(x):
    """Compute log band-energy features for one clip.

    Each windowed frame is transformed with an FFT_LEN-point real FFT, the
    power spectrum is split into N_BANDS contiguous groups of bins, and the
    mean power of every group is converted to decibels.

    Args:
        x: 1-D sample array, as accepted by ``stft_frames``.

    Returns:
        1-D float32 vector of length T * N_BANDS, laid out frame by frame
        (all bands of frame 0, then all bands of frame 1, ...).
    """
    windowed = stft_frames(x)                                   # [T, FRAME_LEN]
    spectrum = np.fft.rfft(windowed, n=FFT_LEN, axis=1)         # [T, FFT_LEN//2 + 1]
    power = (np.abs(spectrum) ** 2).astype(np.float32)

    total_bins = power.shape[1]
    width = total_bins // N_BANDS
    # Equal-width bands; the last one also absorbs the leftover bins.
    edges = [k * width for k in range(N_BANDS)] + [total_bins]

    log_bands = [
        # The 1e-10 floor keeps log10 finite for silent (all-zero) frames.
        10.0 * np.log10(np.mean(power[:, lo:hi], axis=1) + 1e-10)
        for lo, hi in zip(edges[:-1], edges[1:])
    ]
    feats = np.stack(log_bands, axis=1).astype(np.float32)      # [T, N_BANDS]
    return feats.flatten()

def build_dataset():
    """Build the feature matrix and label vector from the two WAV folders.

    Every ``*.wav`` file under HELLO_DIR becomes a positive example (label 1)
    and every one under OTHER_DIR a negative example (label 0). Files are
    visited in sorted order so that the sample order, and therefore any
    seeded train/test split made from it, is reproducible across machines
    and filesystems.

    Returns:
        (X, y): X is a float32 array of shape [n_clips, n_features] and y an
        int64 array of shape [n_clips].

    Raises:
        ValueError: if neither folder contains a WAV file.
    """
    X, y = [], []
    for label, directory in ((1, HELLO_DIR), (0, OTHER_DIR)):
        # glob.glob returns names in arbitrary order; sort for determinism.
        for fn in sorted(glob.glob(os.path.join(directory, "*.wav"))):
            X.append(band_energy_feats(read_wav_any(fn)))
            y.append(label)

    if not X:
        raise ValueError(
            f"no .wav files found in {HELLO_DIR!r} or {OTHER_DIR!r}")

    X = np.stack(X, axis=0).astype(np.float32)
    y = np.array(y, dtype=np.int64)
    # Frames are cut from CLIP_SAMPLES samples, so the frame count must be
    # derived from the clip length rather than from the sample rate.
    expected = ((CLIP_SAMPLES - FRAME_LEN)//HOP + 1) * N_BANDS
    print("X shape:", X.shape, "expected input:", expected)
    return X, y

def train_and_export(X, y):
    """Train the keyword MLP and write it, with its normalization, to OUT_HDR.

    Steps:
      1. Stratified 80/20 train/test split (fixed seed).
      2. Per-feature z-score statistics fitted on the training split only,
         so the reported test accuracy is not influenced by test data.
      3. One-hidden-layer (20 units, ReLU) MLP trained on the normalized
         training split; accuracy printed for the held-out split.
      4. Mean, inverse standard deviation, and layer weights written as
         ``static const float`` arrays into the C header OUT_HDR.

    For a two-class problem scikit-learn's MLPClassifier has a single
    logistic output unit (``coefs_[1]`` is [hidden, 1]). That unit is
    expanded to two logits ``[0, z]`` so the exported layer really has the
    [hidden, 2] row-major layout announced by KW_OUTPUT_DIM; softmax over
    ``[0, z]`` gives exactly ``[1 - sigmoid(z), sigmoid(z)]`` and argmax
    selects index 1 (label 1) iff ``z > 0``.

    Args:
        X: float array of shape [n_clips, n_features].
        y: integer label array of shape [n_clips] (0 = other, 1 = keyword).

    Side effects:
        Prints accuracy and array checksums; creates the directory of
        OUT_HDR if needed and overwrites OUT_HDR.
    """
    # Split first: the normalization statistics must come from training
    # data only, otherwise the test score is optimistically biased.
    Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.2, stratify=y, random_state=0)

    # Standardize features (z-score) -> export mean & inv std
    mu = Xtr.mean(axis=0).astype(np.float32)
    sigma = Xtr.std(axis=0).astype(np.float32) + 1e-6  # avoid division by zero
    sigma_inv = (1.0 / sigma).astype(np.float32)

    clf = MLPClassifier(hidden_layer_sizes=(20,), activation='relu',
                        solver='adam', max_iter=800, random_state=0)
    clf.fit((Xtr - mu) / sigma, ytr)
    acc = clf.score((Xte - mu) / sigma, yte)
    print("Test accuracy:", acc)

    W0 = clf.coefs_[0].astype(np.float32)      # [in, hidden]
    b0 = clf.intercepts_[0].astype(np.float32) # [hidden]
    W1 = clf.coefs_[1].astype(np.float32)      # [hidden, n_out]
    b1 = clf.intercepts_[1].astype(np.float32) # [n_out]
    if W1.shape[1] == 1:
        # Binary case: turn the single logistic unit into two logits [0, z]
        # so the exported arrays match the two-output layout.
        W1 = np.concatenate([np.zeros_like(W1), W1], axis=1)  # [hidden, 2]
        b1 = np.concatenate([np.zeros_like(b1), b1])          # [2]

    # checksums for quick sanity
    print("checksums:",
          "mu", float(np.sum(mu)), "sig^-1", float(np.sum(sigma_inv)),
          "W0", float(np.sum(np.abs(W0))), "b0", float(np.sum(np.abs(b0))),
          "W1", float(np.sum(np.abs(W1))), "b1", float(np.sum(np.abs(b1))))

    # Create whatever directory OUT_HDR actually points into.
    out_dir = os.path.dirname(OUT_HDR)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(OUT_HDR, "w") as f:
        f.write("// Auto-generated wake-word model with normalization\n#pragma once\n\n")
        f.write("#define KW_INPUT_DIM %d\n" % X.shape[1])
        # Dimensions are taken from the arrays themselves so the defines can
        # never disagree with the emitted data.
        f.write("#define KW_HIDDEN_DIM %d\n#define KW_OUTPUT_DIM %d\n\n"
                % (W0.shape[1], W1.shape[1]))

        def dump(name, arr):
            """Write ``arr`` (flattened row-major) as a C float array."""
            flat = arr.flatten()
            f.write("static const float %s[%d] = {\n" % (name, flat.shape[0]))
            for row_start in range(0, flat.shape[0], 8):
                row = flat[row_start:row_start + 8]
                # %.8e yields 9 significant digits, enough to round-trip a
                # float32 exactly; fixed-point %.8f would truncate small
                # values. The exponent form is always a valid C literal.
                f.write("  " + " ".join("%.8ef," % v for v in row) + "\n")
            f.write("};\n\n")

        dump("kw_mu", mu)
        dump("kw_sigma_inv", sigma_inv)
        dump("kw_W0", W0)
        dump("kw_b0", b0)
        dump("kw_W1", W1)
        dump("kw_b1", b1)

    print("Wrote", OUT_HDR)

if __name__ == "__main__":
    # Script entry point: extract features from the WAV folders, then train
    # the classifier and write the C header.
    features, labels = build_dataset()
    train_and_export(features, labels)
