#!/usr/bin/env python3
"""Detect voice onset in a VO take and align it to the cue start.

Strategy (phase 1 — anchored inside the green focal area):
  1. Detect voice onset in the take file. Cascade silencedetect at a few
     thresholds, then fall back to an RMS envelope on decoded PCM when
     silencedetect doesn't find a plausible leading silence.
  2. Target position inside the file = preroll_ms + anchor_offset_ms, so
     voice lands slightly INSIDE cue.start on the master timeline (user
     asked for "somewhat inside" instead of exactly at the edge).
  3. Before shifting, sanity-check that the detected onset is inside the
     plausible green-zone window (preroll_ms - shadow_lead on the early side,
     preroll_ms + cue_duration + shadow_trail on the late side). If it's
     outside that window, detection lied — skip the shift entirely and
     leave the take as recorded. Never clobber the take with a nonsense
     pad/trim that pushes voice outside the green zone.
  4. If detection is plausible, trim (actor came in late) or pad (actor came
     in early) so voice lands at target_in_file_ms.

When a trim or pad is applied, the output is re-encoded (libopus 128k) so edit
points are sample-accurate; copy/skip results are an unmodified copy of the input.
Requires ffmpeg on PATH. Uses numpy for the RMS fallback.
"""
import argparse
import json
import os
import re
import shutil
import subprocess
import sys


# Cascaded silence thresholds: start at -40dB (quiet self-noise), then relax
# toward louder thresholds when the quieter ones don't find a leading silence.
# detect_onset_ms tries them in this order and stops at the first hit.
_SILENCE_THRESHOLDS_DB = (-40.0, -33.0, -28.0)
# Minimum silence run, in seconds, handed to silencedetect's `duration` option
# for every step of the cascade.
_SILENCE_MIN_DURATION_S = 0.05

# Minimum voice-burst duration (seconds) to accept as real voice onset.
# Rejects sub-millisecond click transients (mouse clicks at record-start,
# keyboard taps, physical artifacts captured by the mic) that would otherwise
# hijack the silencedetect cascade and lock in a false positive. Calibration:
#   * Click artifacts: 0.1–1 ms physical phenomena, gap to next silence_start
#     is on the order of 0.2 ms in observed material.
#   * Shortest legitimate voice burst: ~50 ms (short stressed Portuguese
#     vowel, e.g. the "O" in "O que foi?" is ~54 ms).
# 30 ms sits comfortably between the two: well above any click, well below
# any real syllable. See B2 fix (build-notes/sessions/2026-04-30.md).
_MIN_VOICE_BURST_S = 0.030


def _silencedetect_onset_ms(input_path: str, threshold_db: float, min_silence_s: float,
                            min_voice_burst_s: float = _MIN_VOICE_BURST_S) -> int:
    """Run silencedetect once and return voice-onset ms if the file has
    leading silence starting within 10ms of t=0 AND the first non-silent
    burst sustains for at least min_voice_burst_s.

    Sub-burst transients (record-start clicks, mouse clicks captured a few
    ms above threshold) are skipped — the algorithm walks past to the first
    silence_end whose subsequent voice burst is genuinely sustained. Without
    this gate, a sub-millisecond click at e.g. 80 ms gets returned as voice
    onset and the cascade locks in a false positive (B2 from 2026-04-30).

    Args:
        input_path: audio file handed to ffmpeg via ``-i``.
        threshold_db: silencedetect ``noise`` level in dB.
        min_silence_s: silencedetect ``duration`` in seconds.
        min_voice_burst_s: shortest non-silent run accepted as voice.

    Returns:
        Onset in whole milliseconds, or -1 if no silence_end has a sustained
        voice burst (or the file has no leading silence) — caller relaxes the
        threshold (next cascade step) or falls back to the RMS envelope
        detector. A failed ffmpeg run produces no markers and therefore
        also yields -1.
    """
    cmd = [
        "ffmpeg", "-hide_banner", "-nostats",
        "-i", input_path,
        "-af", f"silencedetect=noise={threshold_db}dB:duration={min_silence_s}",
        "-f", "null", "-",
    ]
    # ffmpeg echoes container metadata on stderr; decode leniently so a stray
    # byte in a tag cannot abort detection with UnicodeDecodeError.
    out = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    return _first_sustained_onset_ms(out.stderr, min_voice_burst_s)


def _first_sustained_onset_ms(stderr: str, min_voice_burst_s: float) -> int:
    """Parse silencedetect log text and return the first sustained onset in ms, or -1."""
    # silencedetect timestamps may be negative (a leading silence can be
    # reported as starting slightly before t=0) and older ffmpeg builds print
    # very small values in exponent form (e.g. 2.08333e-05). A digits-and-dots
    # pattern drops the former and truncates the latter to "2.08333", which
    # made the leading-silence check below fail for perfectly good takes.
    number = r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
    silence_starts = [float(m.group(1)) for m in re.finditer(r"silence_start:\s*" + number, stderr)]
    silence_ends = [float(m.group(1)) for m in re.finditer(r"silence_end:\s*" + number, stderr)]
    # Require a leading silence: the first reported silence must begin within
    # 10 ms of t=0 (negative starts qualify), and it must have ended.
    if not (silence_starts and silence_starts[0] < 0.01 and silence_ends):
        return -1
    # Walk silence_ends paired with the next silence_start. The voice burst
    # following silence_ends[i] runs from silence_ends[i] to silence_starts[i+1]
    # (or EOF if no next silence_start). Skip bursts shorter than
    # min_voice_burst_s — they're clicks/transients, not voice. Return the
    # first silence_end whose subsequent burst is genuinely sustained.
    for i, end in enumerate(silence_ends):
        next_start = silence_starts[i + 1] if i + 1 < len(silence_starts) else float("inf")
        if next_start - end >= min_voice_burst_s:
            return int(round(end * 1000))
    return -1


def _rms_envelope_onset_ms(input_path: str, max_duration_s: float = 6.0, window_ms: int = 20) -> int:
    """Estimate voice onset from a short-time RMS envelope of the decoded audio.

    The first `max_duration_s` seconds are decoded by ffmpeg to 16 kHz mono
    signed 16-bit PCM, cut into consecutive `window_ms` windows, and each
    window's RMS is compared against a gate derived from the leading noise.

    Returns the start of the first qualifying window in ms, or -1 when numpy
    is unavailable, ffmpeg exits non-zero, fewer than four full windows were
    decoded, or no window exceeds the gate.
    """
    try:
        import numpy as np
    except ImportError:
        # numpy is optional; without it this detector reports "no result".
        return -1

    sample_rate = 16000
    decode_cmd = [
        "ffmpeg", "-hide_banner", "-loglevel", "error",
        "-i", input_path,
        "-t", f"{max_duration_s:.2f}",
        "-ac", "1", "-ar", str(sample_rate),
        "-f", "s16le", "-",
    ]
    try:
        decoded = subprocess.run(decode_cmd, capture_output=True, check=True)
    except subprocess.CalledProcessError:
        return -1

    # Scale int16 samples into [-1.0, 1.0) floats.
    samples = np.frombuffer(decoded.stdout, dtype=np.int16).astype(np.float32) / 32768.0
    if samples.size == 0:
        return -1
    samples_per_window = max(1, int(sample_rate * window_ms / 1000))
    window_count = samples.size // samples_per_window
    if window_count < 4:
        return -1

    # One row per window; any trailing partial window is dropped.
    windows = samples[:window_count * samples_per_window].reshape(window_count, samples_per_window)
    envelope = np.sqrt(np.square(windows).mean(axis=1))

    # Noise floor: median RMS over roughly the first 200 ms, which is assumed
    # to be silent or near-silent. The epsilon keeps the floor strictly positive.
    lead_windows = max(1, 200 // window_ms)
    noise_floor = float(np.median(envelope[:lead_windows])) + 1e-6
    # Gate: 4x the floor (about +12 dB) but never below an absolute 0.02.
    gate = max(noise_floor * 4.0, 0.02)

    loud = envelope > gate
    if not loud.any():
        return -1

    # Prefer the first window that is loud together with its successor, so a
    # single-window transient such as a click does not win.
    loud_pairs = loud[:-1] & loud[1:]
    if loud_pairs.any():
        return int(loud_pairs.argmax()) * window_ms
    # No adjacent pair exists: fall back to the first loud window.
    return int(loud.argmax()) * window_ms


def detect_onset_ms(input_path: str, threshold_db: float = -40.0, min_silence_s: float = 0.1) -> int:
    """Locate the start of voice inside the file, in milliseconds.

    Detection order:
      1. silencedetect at each level in _SILENCE_THRESHOLDS_DB, in order,
         accepting the first strictly positive onset.
      2. The RMS envelope detector, accepting any non-negative onset.
      3. Otherwise 0, i.e. the file is treated as starting hot.

    NOTE: `threshold_db` and `min_silence_s` are part of the signature but are
    not consulted here — the cascade always uses _SILENCE_THRESHOLDS_DB and
    _SILENCE_MIN_DURATION_S.

    Returns an int that is always >= 0.
    """
    cascade = (
        _silencedetect_onset_ms(input_path, level_db, _SILENCE_MIN_DURATION_S)
        for level_db in _SILENCE_THRESHOLDS_DB
    )
    # The generator is lazy, so later thresholds only run if earlier ones miss.
    first_hit = next((ms for ms in cascade if ms > 0), None)
    if first_hit is not None:
        return first_hit

    # The envelope detector reports -1 for "nothing found"; clamp that to 0.
    return max(_rms_envelope_onset_ms(input_path), 0)


def _ffmpeg(args, description=""):
    """Run an ffmpeg command line, converting a non-zero exit into RuntimeError.

    args: full argv list passed to subprocess.run.
    description: short label embedded in the error message.

    Raises RuntimeError (chained from the CalledProcessError) whose message
    carries at most the first 400 characters of the stripped stderr.
    """
    try:
        subprocess.run(args, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as failure:
        detail = failure.stderr.strip()[:400]
        raise RuntimeError(f"ffmpeg failed ({description}): {detail}") from failure


def trim_from_start(input_path: str, output_path: str, trim_ms: int):
    """Write output_path as input_path with the first trim_ms milliseconds cut off.

    The seek position is given to ffmpeg in seconds (three decimals) ahead of
    the input, and the audio is re-encoded with libopus at 128k. An existing
    output file is overwritten (-y).

    Raises RuntimeError if ffmpeg exits non-zero.
    """
    seek_seconds = format(trim_ms / 1000, ".3f")
    command = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y"]
    command += ["-ss", seek_seconds, "-i", input_path]
    command += ["-c:a", "libopus", "-b:a", "128k", output_path]
    _ffmpeg(command, description="trim")


def pad_silence_at_start(input_path: str, output_path: str, pad_ms: int):
    """Write output_path as input_path delayed by pad_ms milliseconds of silence.

    The delay is applied with ffmpeg's adelay filter and the audio is
    re-encoded with libopus at 128k. An existing output file is overwritten (-y).

    Raises RuntimeError if ffmpeg exits non-zero.
    """
    # Two delay slots with the same value cover both mono and stereo input.
    delay_filter = "adelay=" + "|".join([str(pad_ms), str(pad_ms)])
    command = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y"]
    command += ["-i", input_path, "-af", delay_filter]
    command += ["-c:a", "libopus", "-b:a", "128k", output_path]
    _ffmpeg(command, description="pad")


def shift_take(input_path: str, output_path: str, shift_ms: int, tolerance_ms: int = 5) -> dict:
    """Apply a timing shift of shift_ms to input_path and write output_path.

    The shift is absolute (relative to input_path as it is):
      * |shift_ms| <= tolerance_ms -> the file is copied unchanged ("copy")
      * shift_ms > 0 -> shift_ms is trimmed from the start, so voice plays
        earlier on the timeline ("trim")
      * shift_ms < 0 -> |shift_ms| of silence is padded at the start, so
        voice plays later ("pad")

    Returns a dict with the absolute input/output paths, the requested
    shift_ms, the action taken, and tolerance_ms.
    """
    magnitude = abs(shift_ms)
    if magnitude <= tolerance_ms:
        # Too small to be worth a re-encode.
        action = "copy"
        shutil.copyfile(input_path, output_path)
    elif shift_ms < 0:
        action = "pad"
        pad_silence_at_start(input_path, output_path, magnitude)
    else:
        action = "trim"
        trim_from_start(input_path, output_path, shift_ms)

    return dict(
        input=os.path.abspath(input_path),
        output=os.path.abspath(output_path),
        shift_ms=shift_ms,
        action=action,
        tolerance_ms=tolerance_ms,
    )


def align_take_to_cue(
    input_path: str,
    output_path: str,
    preroll_ms: int,
    *,
    cue_duration_ms: int = 0,
    anchor_offset_ms: int = 80,
    shadow_lead_ms: int = 300,
    shadow_trail_ms: int = 500,
    threshold_db: float = -40.0,
    min_silence_s: float = 0.1,
    tolerance_ms: int = 5,
) -> dict:
    """Detect voice onset and shift the take so it lands at preroll_ms + anchor_offset_ms.

    The detected onset is always sanity-checked against a plausibility window
    inside the take file. If it falls outside, no shift is applied: the input
    is copied to output_path and a "skip" report is returned.

    cue_duration_ms: width of the green focal area, used for the late edge of
        the window. When it is 0 or negative a 4000 ms width is assumed.
    anchor_offset_ms: how far past cue.start to target voice onset. A small
        positive value puts voice "somewhat inside" the cue rather than
        exactly at the edge.
    shadow_lead_ms / shadow_trail_ms: slack added before and after the cue
        window for the sanity check.
    threshold_db / min_silence_s: forwarded to detect_onset_ms and echoed in
        the report.
    tolerance_ms: forwarded to shift_take; shifts no larger than this become
        a plain copy.

    Returns the report dict. On the normal path it is shift_take's result
    extended with the detection parameters; on the skip path it also carries
    "plausible_range_ms" and "reason".
    """
    onset_ms = detect_onset_ms(input_path, threshold_db, min_silence_s)

    # Plausibility window for the onset, measured inside the take file:
    #   early edge = preroll - shadow_lead, clamped at 0   (actor came in early)
    #   late edge  = preroll + cue width + shadow_trail    (actor came in late / long)
    # An unknown cue width (<= 0) is replaced with a generous 4000 ms.
    cue_width_ms = cue_duration_ms if cue_duration_ms > 0 else 4000
    earliest_ms = max(0, preroll_ms - shadow_lead_ms)
    latest_ms = preroll_ms + cue_width_ms + shadow_trail_ms

    if earliest_ms <= onset_ms <= latest_ms:
        # Positive shift means voice is late (trim); negative means early (pad).
        shift_ms = onset_ms - (preroll_ms + anchor_offset_ms)
        report = shift_take(input_path, output_path, shift_ms, tolerance_ms)
        report.update(
            detected_onset_ms=onset_ms,
            preroll_ms=preroll_ms,
            anchor_offset_ms=anchor_offset_ms,
            cue_duration_ms=cue_duration_ms,
            threshold_db=threshold_db,
            min_silence_s=min_silence_s,
        )
        return report

    # The onset is outside the green focal area, so the detection is not
    # trusted. Leaving the take as recorded beats moving voice to a
    # nonsense position.
    shutil.copyfile(input_path, output_path)
    return dict(
        input=os.path.abspath(input_path),
        output=os.path.abspath(output_path),
        shift_ms=0,
        action="skip",
        tolerance_ms=tolerance_ms,
        detected_onset_ms=onset_ms,
        preroll_ms=preroll_ms,
        anchor_offset_ms=anchor_offset_ms,
        cue_duration_ms=cue_duration_ms,
        plausible_range_ms=[earliest_ms, latest_ms],
        reason="onset_out_of_green_zone",
        threshold_db=threshold_db,
        min_silence_s=min_silence_s,
    )


def _cli():
    """Parse command-line arguments, align the take, and print the report as JSON.

    The parser description is the first line of the module docstring. Every
    option maps one-to-one onto an align_take_to_cue parameter; the resulting
    dict is printed to stdout with two-space indentation.
    """
    parser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    add = parser.add_argument

    # Positional paths.
    add("input", help="Input audio file (webm/wav/etc.)")
    add("output", help="Output webm file")

    # Cue geometry.
    add("--preroll", type=int, required=True,
        help="Expected voice-onset position in ms (the cue's preroll_ms)")
    add("--cue-duration", type=int, default=0,
        help="Cue duration ms (for green-zone sanity check); 0 = unknown, use wide window")
    add("--anchor-offset", type=int, default=80,
        help="Target voice this many ms past cue.start (default 80)")
    add("--shadow-lead", type=int, default=300, help="Early-side shadow tolerance ms")
    add("--shadow-trail", type=int, default=500, help="Late-side shadow tolerance ms")

    # Detection and shift tuning.
    add("--threshold", type=float, default=-40.0, help="Silence threshold dB (default -40)")
    add("--min-silence", type=float, default=0.1, help="Min silence duration s (default 0.1)")
    add("--tolerance", type=int, default=5, help="No-op if |shift| ≤ this many ms (default 5)")

    opts = parser.parse_args()
    report = align_take_to_cue(
        opts.input,
        opts.output,
        opts.preroll,
        cue_duration_ms=opts.cue_duration,
        anchor_offset_ms=opts.anchor_offset,
        shadow_lead_ms=opts.shadow_lead,
        shadow_trail_ms=opts.shadow_trail,
        threshold_db=opts.threshold,
        min_silence_s=opts.min_silence,
        tolerance_ms=opts.tolerance,
    )
    print(json.dumps(report, indent=2))


# Script entry point: run the command-line interface only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    _cli()
