import serial
import time
import os
import struct
import sys
from datetime import datetime

# --- Serial link ------------------------------------------------------------
PORT = "/dev/cu.usbmodem2101"  # <= set to your Sparrow's port
BAUD = 115200

# --- Output layout: one sub-directory of BASE_DIR per recording class -------
BASE_DIR = "data"
HELLO_DIR, OTHER_DIR = (os.path.join(BASE_DIR, sub) for sub in ("hello", "other"))

# Create both output directories at import time (no error if they already
# exist) so the save step can write into them directly.
for _out_dir in (HELLO_DIR, OTHER_DIR):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir

def get_timestamp():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS_ffffff``.

    The microsecond field keeps names unique when several files are saved
    within the same second.
    """
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S_%f}"

def read_line(ser, timeout=None):
    """Read one ``\\n``-terminated line from *ser* and return it as text.

    Bytes are read one at a time until a newline arrives. The newline is not
    included in the result. The collected bytes are decoded as UTF-8, with
    undecodable bytes dropped, and then stripped of surrounding whitespace.
    Stripping also removes the ``\\r`` of a CRLF line ending.

    Args:
        ser: An open serial port, or any object whose ``read(1)`` returns a
            ``bytes`` object. An empty read (port timeout with no data) is
            simply retried.
        timeout: Optional overall limit in seconds for receiving the whole
            line. ``None`` (the default) waits indefinitely, as before.

    Returns:
        The decoded, stripped line (possibly ``""`` for a blank line).

    Raises:
        TimeoutError: If *timeout* is given and no newline arrives in time.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    buf = bytearray()  # amortised O(1) appends instead of quadratic bytes +=
    while True:
        ch = ser.read(1)
        if ch == b"\n":
            break
        if ch:
            buf += ch
        # Checked every iteration so a sender that streams bytes but never a
        # newline cannot keep us here past the deadline either.
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"no newline within {timeout} s (received {bytes(buf)!r})"
            )
    return buf.decode("utf-8", errors="ignore").strip()

def request_and_save(label, port=None):
    """Ask the board for one recording tagged *label* and save it to disk.

    Exchange with the firmware, as implemented below:

    1. send ``"<label>\\n"``;
    2. read a text header ``"CHUNK <label> <size>"``;
    3. read exactly *size* raw bytes (saved verbatim as the ``.wav`` file);
    4. read a newline and then an ``"END"`` line.

    If the label echoed in the header is ``"hello"``, the data goes to
    ``HELLO_DIR/hello_<timestamp>.wav``. Any other label goes to
    ``OTHER_DIR/other_<timestamp>.wav``.

    Args:
        label: Command string sent to the firmware (e.g. ``"hello"``).
        port: Open serial port to use. It needs ``write``, ``flush``, ``read``
            and ``reset_input_buffer`` (e.g. ``serial.Serial``). Defaults to
            the module-level ``ser``, so existing ``request_and_save(label)``
            calls behave as before.

    Returns:
        None. A malformed header is reported on stdout and the request is
        abandoned. A missing ``END`` only produces a warning.

    NOTE(review): reading the header and payload waits indefinitely if the
    board stops sending. Ctrl-C is currently the only way out.
    """
    if port is None:
        port = ser  # module-level port opened in __main__

    # Discard stale input first (debug prints, or the unread payload of an
    # earlier request that failed on its header). Otherwise the next line read
    # would not be the reply to this command.
    port.reset_input_buffer()

    port.write(label.encode("utf-8") + b"\n")
    port.flush()

    # Header line: "CHUNK <label> <size>"
    header = read_line(port)
    parts = header.split()
    if not parts or parts[0] != "CHUNK":
        print("Bad header:", header)
        return
    if len(parts) != 3:
        print("Bad header format:", header)
        return
    lbl, size_text = parts[1], parts[2]
    try:
        size = int(size_text)
    except ValueError:
        size = -1
    if size < 0:
        print("Bad header size:", header)
        return
    if lbl != label:
        print(f"Warning: requested {label!r} but board sent {lbl!r}")

    # Read exactly `size` bytes; with a port timeout, read() may return fewer
    # bytes (or none), so keep asking for the remainder.
    payload = bytearray()
    while len(payload) < size:
        chunk = port.read(size - len(payload))
        if chunk:
            payload += chunk

    # After the binary blob the firmware sends a newline and then "END".
    read_line(port)  # consume the newline that terminates the blob
    endline = read_line(port)
    if endline != "END":
        print("Warning: expected END, got", endline)

    prefix, out_dir = ("hello", HELLO_DIR) if lbl == "hello" else ("other", OTHER_DIR)
    out_path = os.path.join(out_dir, f"{prefix}_{get_timestamp()}.wav")

    with open(out_path, "wb") as f:
        f.write(payload)

    print("Saved", out_path, f"({len(payload)} bytes)")

if __name__ == "__main__":
    # Single-key commands -> label requested from the firmware.
    commands = {"h": "hello", "o": "other"}

    # request_and_save() reads the port from the module-level name ``ser``,
    # so it is bound here. The context manager closes the port on every exit
    # path, including interruptions during the start-up wait.
    with serial.Serial(PORT, BAUD, timeout=0.1) as ser:
        try:
            time.sleep(2.0)  # let board reset
            ser.reset_input_buffer()  # drop anything received during the wait

            print("Ready.")
            print("Type h for 'hello sparrow' samples, o for background/other, q to quit.")

            while True:
                try:
                    key = input("> ").strip().lower()
                except EOFError:  # Ctrl-D / closed stdin: quit like "q"
                    print()
                    break
                if key == "q":
                    break
                label = commands.get(key)
                if label is None:
                    print("Use h / o / q")
                else:
                    request_and_save(label)
        except KeyboardInterrupt:
            # Ctrl-C is also the only way out of a read that is stuck waiting
            # for the board. Exit quietly instead of dumping a traceback.
            print("\nInterrupted.")
