#!/usr/bin/env python3
"""Auto-correct cue start / end timecodes by scanning a reference waveform.

Primary algorithm — **hysteresis edge detection** (requires a clean voice-only
stem, e.g. HTDemucs `voice_waveform.json`):

  - Cue START = silence → voice rising edge closest to the claimed start.
  - Cue END   = voice → silence falling edge closest to the claimed end.

Thresholds are derived dynamically from the search window's peak distribution
so the detector auto-scales across quiet dialogue and shouted scenes.

Fallback — **nearest local minimum** (for intra-phrase cue splits where the
voice is continuous across the bound, so no rising/falling edge exists):

  - Smooth the peaks (box filter), enumerate local minima in the window,
    rank by depth + proximity to the claimed bound, return the best one.

v1 used this fallback exclusively, because its waveform was mixed audio and
music made edge detection unreliable. Now that we read voice-only stems, edges
are usable and are tried first; the fallback runs only when no edge is found.

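Waveform structure: { "sampleRate": 100, "peaks": [0..1, ...] } — one peak
every 1/sampleRate seconds (10 ms steps at 100, which is also the value assumed
when "sampleRate" is missing or zero).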
"""
import argparse
import json


# --- primary (edge) params ---
# Asymmetric search windows: biasing each window toward the direction in which
# the script's claimed bound usually drifts catches bigger drifts without
# picking up unrelated edges. Both default windows below search twice as far
# FORWARD as backward, i.e. they assume the claimed start usually comes before
# the real voice onset, and the claimed end usually comes before the real
# (tapering) voice end.
# NOTE(review): confirm the end-drift direction — the END defaults are
# forward-biased, consistent with their inline comments below.
DEFAULT_START_BACK_MS    = 1500  # for cue start: search this far backward from claim
DEFAULT_START_FORWARD_MS = 3000  # for cue start: search this far forward from claim
DEFAULT_END_BACK_MS      = 1500  # for cue end: search backward (voice rarely ends earlier than claim)
DEFAULT_END_FORWARD_MS   = 3000  # for cue end: search forward (taper + script-often-early)
# Rising edges need >=70% of this many ms above voice_thresh after the anchor.
DEFAULT_VOICE_MIN_MS    = 120    # min sustained voice duration to call it a voice region (used for starts)
# Silence runs need >=70% (rising edge) or >=85% (voice end / burst length)
# of this many ms below silence_thresh.
DEFAULT_SILENCE_MIN_MS  = 120    # min sustained silence duration to call it silent
DEFAULT_VOICE_LOOKBACK_MS = 400  # how far back to look for voice when detecting voice-end taper
# voice_thresh is floored at 0.04; silence_thresh is floored at 0.01.
DEFAULT_VOICE_FACTOR    = 0.30   # voice_thresh = factor * p90 of window peaks
DEFAULT_SILENCE_FACTOR  = 0.30   # silence_thresh = factor * voice_thresh

# --- fallback (local-min) params ---
# Box-filter width applied to the peaks before searching for minima.
DEFAULT_SMOOTHING_MS      = 40
# Minima closer together than this are merged (the deeper one is kept).
DEFAULT_MIN_SPACING_MS    = 40
# A minimum qualifies only if its smoothed level <= factor * window maximum.
DEFAULT_DEPTH_FACTOR      = 0.75
# Distance from the claimed bound whose score penalty equals one window maximum.
DEFAULT_DISTANCE_BIAS_MS  = 250


# ============================================================================
# Fallback: local-min detection (kept for continuous-voice intra-phrase splits)
# ============================================================================

def _smoothed(peaks, sr, smoothing_ms):
    sw = max(1, int(round(smoothing_ms / 1000 * sr)))
    half = sw // 2
    N = len(peaks)
    out = [0.0] * N
    window_sum = sum(peaks[0:min(N, half + 1)])
    count = min(N, half + 1)
    for i in range(N):
        out[i] = window_sum / count
        add_i = i + half + 1
        rem_i = i - half
        if add_i < N:
            window_sum += peaks[add_i]; count += 1
        if rem_i >= 0:
            window_sum -= peaks[rem_i]; count -= 1
    return out


def _local_minima(smooth, i0, i1, min_spacing):
    minima = []
    prev = None
    for k in range(max(i0 + 1, 1), min(i1, len(smooth) - 1)):
        if smooth[k] < smooth[k - 1] and smooth[k] <= smooth[k + 1]:
            if prev is not None and (k - prev) < min_spacing:
                if smooth[k] < smooth[prev]:
                    minima[-1] = k; prev = k
                continue
            minima.append(k); prev = k
    return minima


def _find_nearest_min(peaks, sr, target_s, window_s,
                      smoothing_ms, min_spacing_ms,
                      depth_factor, distance_bias_ms):
    target_i = int(round(target_s * sr))
    win = int(round(window_s * sr))
    i0 = max(0, target_i - win)
    i1 = min(len(peaks), target_i + win)
    if i1 <= i0 + 3:
        return None
    smooth = _smoothed(peaks, sr, smoothing_ms)
    window_max = max(smooth[i0:i1])
    if window_max <= 0:
        return None
    min_spacing = max(1, int(round(min_spacing_ms / 1000 * sr)))
    candidates = _local_minima(smooth, i0, i1, min_spacing)
    depth_cap = window_max * depth_factor
    candidates = [k for k in candidates if smooth[k] <= depth_cap]
    if not candidates:
        return None
    bias_samples = max(1, int(round(distance_bias_ms / 1000 * sr)))
    best = None
    best_score = float("inf")
    for k in candidates:
        dist = abs(k - target_i)
        dist_penalty = (dist / bias_samples) * window_max
        score = smooth[k] + dist_penalty
        if score < best_score:
            best_score = score; best = k
    return best


# ============================================================================
# Primary: hysteresis edge detection (silence ↔ voice transitions)
# ============================================================================

# Onomatopoeia demotion — a candidate burst whose duration is less than this
# fraction of the cue's expected duration is penalized in `_find_rising_edge`.
# This is not a hard rejection: the penalty is added to the candidate's
# distance cost, and if every candidate falls below the floor the best one is
# still returned, flagged low-confidence.
# Example: against a 1 s expected line, a 300 ms breath or "ah!" (ratio 0.3)
# is penalized. Genuinely short cues ("Sim.", "Não.") pass because their
# expected duration is also small, so the ratio of their own burst still
# clears the floor.
DEFAULT_DURATION_FLOOR_RATIO = 0.4


def _voice_burst_duration_samples(peaks, sr, onset_i, silence_thresh, silence_min):
    """Walk forward from onset_i; return number of samples until sustained silence
    (≥85% of `silence_min` samples below silence_thresh) appears.

    Used for duration-aware candidate filtering: a 200 ms burst at the snap point
    is almost certainly an onomatopoeia / breath / interjection when the cue's
    expected line duration is 2 seconds.
    """
    N = len(peaks)
    threshold = silence_min * 0.85
    end_limit = N - silence_min
    i = onset_i
    while i < end_limit:
        silent = 0
        for j in range(i, i + silence_min):
            if peaks[j] < silence_thresh:
                silent += 1
        if silent >= threshold:
            return i - onset_i
        i += 1
    return N - onset_i


def _dynamic_thresholds(peaks, i0, i1):
    window = peaks[i0:i1]
    if not window:
        return 0.08, 0.03
    sorted_w = sorted(window)
    p90 = sorted_w[min(len(sorted_w) - 1, int(len(sorted_w) * 0.9))]
    voice_thresh = max(0.04, p90 * DEFAULT_VOICE_FACTOR)
    silence_thresh = max(0.01, voice_thresh * DEFAULT_SILENCE_FACTOR)
    return voice_thresh, silence_thresh


def _find_rising_edge(peaks, sr, target_s, back_s, forward_s,
                      voice_min_ms, silence_min_ms,
                      expected_duration_s=None,
                      duration_floor_ratio=DEFAULT_DURATION_FLOOR_RATIO):
    """Find the silence → voice onset closest to target_s within the
    asymmetric window [target - back_s, target + forward_s].

    Returns an (onset_idx, low_confidence) tuple, or (None, False) when the
    clipped window is shorter than 20 samples or no anchor qualifies.

    Three steps:
      1. Anchor scan — collect ALL positions i where >=70% of the voice_min
         samples starting at i are above the voice threshold AND >=70% of the
         silence_min samples before i are below the silence threshold
         (robust against mid-phrase peaks / HTDemucs bleed).
      2. Candidate selection:
         - Without expected_duration_s: pick the anchor closest to target.
         - With expected_duration_s: score each anchor by its distance from
           the target (seconds) plus an onomatopoeia penalty when its
           voice-burst duration is below `duration_floor_ratio` × the expected
           line length. This demotes grunts/breaths/ad-libs that would
           otherwise win on proximity alone. If even the winner is below the
           floor, it is still returned but flagged low_confidence=True.
      3. Onset snap — the 70% tolerances let the anchor land a few samples
         BEFORE the true onset (silent samples at the head of its voice
         window) or AFTER it (inside a gradual ramp). Walk FORWARD over silent
         samples (<= silence threshold), then BACKWARD over contiguous audible
         samples (> silence threshold). The result is the first audible sample
         of the burst.
    """
    target_i = int(round(target_s * sr))
    i0 = max(0, target_i - int(round(back_s * sr)))
    i1 = min(len(peaks), target_i + int(round(forward_s * sr)))
    if i1 - i0 < 20:
        return None, False

    voice_thresh, silence_thresh = _dynamic_thresholds(peaks, i0, i1)
    voice_min = max(1, int(round(voice_min_ms / 1000 * sr)))
    silence_min = max(1, int(round(silence_min_ms / 1000 * sr)))

    # Pass 1 — collect every qualifying anchor in the window
    candidates = []
    lo = max(i0 + silence_min, silence_min)
    hi = min(i1 - voice_min, len(peaks) - voice_min)
    for i in range(lo, hi):
        # Voice must sustain AFTER i
        voice_count = 0
        for j in range(i, i + voice_min):
            if peaks[j] > voice_thresh:
                voice_count += 1
        if voice_count < voice_min * 0.7:
            continue
        # Silence must sustain BEFORE i
        silence_count = 0
        for j in range(i - silence_min, i):
            if peaks[j] < silence_thresh:
                silence_count += 1
        if silence_count < silence_min * 0.7:
            continue
        candidates.append(i)

    if not candidates:
        return None, False

    # Pass 2 — pick winning candidate
    low_confidence = False
    if expected_duration_s is None or expected_duration_s <= 0:
        # Legacy behavior: closest to target
        best = min(candidates, key=lambda c: abs(c - target_i))
    else:
        expected_samples = expected_duration_s * sr
        best = None
        best_score = float("inf")
        best_ratio = 0.0
        for anchor in candidates:
            burst = _voice_burst_duration_samples(
                peaks, sr, anchor, silence_thresh, silence_min
            )
            distance_s = abs(anchor - target_i) / sr
            ratio = burst / expected_samples if expected_samples > 0 else 1.0
            # Heavy penalty when burst is much shorter than expected line.
            # Penalty is "equivalent seconds" so it's directly comparable to
            # distance_s in the cost.
            if ratio < duration_floor_ratio:
                penalty_s = (duration_floor_ratio - ratio) * expected_duration_s * 3.0
            else:
                penalty_s = 0.0
            score = distance_s + penalty_s
            if score < best_score:
                best_score = score
                best = anchor
                best_ratio = ratio
        # If even the winner has a too-short burst, every candidate failed the
        # duration check — likely the real line is outside the search window.
        # Keep the snap (best-effort) but flag it for review.
        if best_ratio < duration_floor_ratio:
            low_confidence = True

    onset = best
    # (a) Forward over silent samples: the anchor can sit in the tail of the
    # preceding silence, because the voice check only needs 70% of the window.
    # Bounded to the anchor's voice window. That window holds >=70% samples
    # above voice_thresh, which always exceeds silence_thresh, so the walk
    # stops inside it.
    forward_limit = min(len(peaks), best + voice_min)
    while onset < forward_limit and peaks[onset] <= silence_thresh:
        onset += 1
    # (b) Backward through contiguous non-silent samples so gradual ramps
    # ("m…", "hhh…") snap to the true first audible sample instead of the
    # point where the level has already climbed above voice_thresh.
    while onset > 0 and peaks[onset - 1] > silence_thresh:
        onset -= 1
    return onset, low_confidence


def _find_voice_end(peaks, sr, target_s, back_s, forward_s,
                    silence_min_ms, voice_lookback_ms):
    """Find where voice gives way to sustained silence, closest to target_s.

    Mirror of `_find_rising_edge` for the trailing boundary:

      1. Anchor scan — find a position where sustained silence follows recent
         voice activity. Voice endings TAPER (loud → quiet → silent over ~500
         ms), so there's no sharp edge. Anchor detects: (a) sustained silence
         of `silence_min_ms` starting at index i, and (b) voice activity
         somewhere in the preceding `voice_lookback_ms` window. Max-based
         lookback tolerates taper — the phrase can be quiet at i-1 as long as
         it was loud earlier in the lookback.

      2. Offset snap — from the anchor, refine so the returned index is
         exactly "one past the last audible sample". The 85%-silent rule in
         the anchor tolerates a few stray tail peaks inside [i, i+silence_min],
         and the anchor may land a few samples after the true tail because
         the silence needed to sustain. Snap:
           (a) walk FORWARD while peaks[i] > silence_thresh — captures any
               straggler tail peaks the 85% rule allowed past the anchor;
           (b) walk BACKWARD while peaks[i-1] ≤ silence_thresh — pulls back
               past any silent samples inside the anchor region so the result
               lands immediately after the last audible peak.
         The back-walk is bounded to the lookback window so it can't cross a
         real silence gap into a previous voice burst.
    """
    target_i = int(round(target_s * sr))
    i0 = max(0, target_i - int(round(back_s * sr)))
    i1 = min(len(peaks), target_i + int(round(forward_s * sr)))
    if i1 - i0 < 20:
        return None

    voice_thresh, silence_thresh = _dynamic_thresholds(peaks, i0, i1)
    silence_min = max(1, int(round(silence_min_ms / 1000 * sr)))
    voice_lookback = max(1, int(round(voice_lookback_ms / 1000 * sr)))

    best = None
    best_dist = float("inf")
    lo = max(i0, voice_lookback)
    hi = min(i1 - silence_min, len(peaks) - silence_min)
    for i in range(lo, hi):
        # Sustained silence from i onward (stricter — 85% must actually be silent).
        silence_count = 0
        for j in range(i, i + silence_min):
            if peaks[j] < silence_thresh:
                silence_count += 1
        if silence_count < silence_min * 0.85:
            continue
        # Voice activity somewhere in the preceding lookback window.
        # Max-based check tolerates taper: if the phrase was loud N ms ago and
        # has quieted down, that still counts as "voice region ended here".
        prev_max = 0.0
        for j in range(max(0, i - voice_lookback), i):
            if peaks[j] > prev_max:
                prev_max = peaks[j]
        if prev_max < voice_thresh:
            continue
        dist = abs(i - target_i)
        if dist < best_dist:
            best_dist = dist; best = i

    if best is None:
        return None

    # Offset snap — refine anchor to "one past the last audible sample".
    N = len(peaks)
    offset = best
    # (a) Extend forward through any lingering audible samples (tail peaks).
    while offset < N and peaks[offset] > silence_thresh:
        offset += 1
    # (b) Pull back past silent samples so result lands immediately after the
    # last audible peak. Bounded to the lookback so we can't cross a true gap
    # into a previous phrase.
    back_limit = max(0, best - voice_lookback)
    while offset > back_limit and peaks[offset - 1] <= silence_thresh:
        offset -= 1
    return offset


# ============================================================================
# Main entry point
# ============================================================================

def autocorrect_cue(waveform: dict, original_start_s: float, original_end_s: float,
                    start_back_ms: int = DEFAULT_START_BACK_MS,
                    start_forward_ms: int = DEFAULT_START_FORWARD_MS,
                    end_back_ms: int = DEFAULT_END_BACK_MS,
                    end_forward_ms: int = DEFAULT_END_FORWARD_MS,
                    voice_min_ms: int = DEFAULT_VOICE_MIN_MS,
                    silence_min_ms: int = DEFAULT_SILENCE_MIN_MS,
                    voice_lookback_ms: int = DEFAULT_VOICE_LOOKBACK_MS,
                    smoothing_ms: int = DEFAULT_SMOOTHING_MS,
                    min_spacing_ms: int = DEFAULT_MIN_SPACING_MS,
                    depth_factor: float = DEFAULT_DEPTH_FACTOR,
                    distance_bias_ms: int = DEFAULT_DISTANCE_BIAS_MS,
                    gate_threshold: float = 0.0,
                    expected_duration_s: float = None,
                    duration_floor_ratio: float = DEFAULT_DURATION_FLOOR_RATIO) -> dict:
    """Snap one cue's start/end timecodes to voice edges found in `waveform`.

    Args:
        waveform: dict with "peaks" (sequence of levels) and optional
            "sampleRate" (samples per second; missing/0 → 100).
        original_start_s, original_end_s: the claimed cue bounds, in seconds.
        start_back_ms, start_forward_ms: rising-edge search window around the
            claimed start. Their average is also the half-width of the start's
            local-min fallback window.
        end_back_ms, end_forward_ms: voice-end search window around the
            claimed end. Their average is also the half-width of the end's
            local-min fallback window.
        voice_min_ms, silence_min_ms, voice_lookback_ms: edge-detector params.
        smoothing_ms, min_spacing_ms, depth_factor, distance_bias_ms:
            local-min fallback params.
        gate_threshold: if > 0, peaks below it are zeroed for this call only.
        expected_duration_s, duration_floor_ratio: enable duration-aware
            scoring of start candidates.

    Returns:
        dict with original/corrected bounds, deltas in ms, whether each edge
        was detected, a low-confidence flag for the start, per-side method
        labels, and the parameters used. A side whose correction was undone by
        the inversion safety check gets a "_reverted" suffix on its label.
    """
    sr = int(waveform.get("sampleRate") or 100)
    peaks = waveform["peaks"]

    # Per-cue gate plugin: pre-zero peaks below threshold so HTDemucs bleed /
    # low-level noise can't fool the edge detectors. The gated peaks are only
    # used for this one cue's detection — the waveform file on disk is untouched.
    if gate_threshold and gate_threshold > 0:
        peaks = [p if p >= gate_threshold else 0.0 for p in peaks]

    start_back_s    = start_back_ms    / 1000.0
    start_forward_s = start_forward_ms / 1000.0
    end_back_s      = end_back_ms      / 1000.0
    end_forward_s   = end_forward_ms   / 1000.0
    # The local-min fallback uses a symmetric window. Each bound uses the
    # average of ITS OWN asymmetric halves, so end_back_ms / end_forward_ms
    # also govern the end fallback.
    start_fallback_window_s = (start_back_s + start_forward_s) / 2.0
    end_fallback_window_s   = (end_back_s   + end_forward_s)   / 2.0

    # Primary: rising/falling edges with asymmetric windows.
    start_idx, start_low_conf = _find_rising_edge(
        peaks, sr, original_start_s,
        start_back_s, start_forward_s,
        voice_min_ms, silence_min_ms,
        expected_duration_s=expected_duration_s,
        duration_floor_ratio=duration_floor_ratio,
    )
    start_method = "rising_edge" if start_idx is not None else None
    if start_idx is None:
        start_idx = _find_nearest_min(peaks, sr, original_start_s, start_fallback_window_s,
                                      smoothing_ms, min_spacing_ms,
                                      depth_factor, distance_bias_ms)
        start_method = "local_min" if start_idx is not None else "none"

    end_idx = _find_voice_end(peaks, sr, original_end_s,
                              end_back_s, end_forward_s,
                              silence_min_ms, voice_lookback_ms)
    end_method = "voice_end" if end_idx is not None else None
    if end_idx is None:
        end_idx = _find_nearest_min(peaks, sr, original_end_s, end_fallback_window_s,
                                    smoothing_ms, min_spacing_ms,
                                    depth_factor, distance_bias_ms)
        end_method = "local_min" if end_idx is not None else "none"

    corrected_start = (start_idx / sr) if start_idx is not None else None
    corrected_end = (end_idx / sr) if end_idx is not None else None

    final_start = corrected_start if corrected_start is not None else original_start_s
    final_end   = corrected_end   if corrected_end   is not None else original_end_s

    # Safety: handle inversion without losing a valid correction on the other side.
    # HTDemucs bleed can create spurious "voice" peaks far from the real cue, so
    # occasionally one detector returns a match wildly outside the cue region.
    # If that corrupts the cue (end <= start), revert ONLY the correction whose
    # delta is larger — keeping the likely-correct one.
    if final_end <= final_start:
        start_delta = abs(final_start - original_start_s)
        end_delta   = abs(final_end   - original_end_s)
        if start_delta > end_delta:
            final_start = original_start_s
            start_method = f"{start_method}_reverted" if start_method else "none"
        else:
            final_end = original_end_s
            end_method = f"{end_method}_reverted" if end_method else "none"
        # Still inverted after a single revert: revert the other side too, and
        # relabel it so `method` never reports a correction that was discarded.
        if final_end <= final_start:
            if final_start != original_start_s:
                final_start = original_start_s
                start_method = f"{start_method}_reverted"
            if final_end != original_end_s:
                final_end = original_end_s
                end_method = f"{end_method}_reverted"

    return {
        "original_start": original_start_s,
        "original_end":   original_end_s,
        "corrected_start": final_start,
        "corrected_end":   final_end,
        "delta_start_ms": int(round((final_start - original_start_s) * 1000)),
        "delta_end_ms":   int(round((final_end   - original_end_s)   * 1000)),
        "onset_found":    corrected_start is not None,
        "offset_found":   corrected_end   is not None,
        "low_confidence": bool(start_low_conf),
        "method": {"start": start_method, "end": end_method},
        "params": {
            "start_back_ms": start_back_ms,
            "start_forward_ms": start_forward_ms,
            "end_back_ms": end_back_ms,
            "end_forward_ms": end_forward_ms,
            "voice_min_ms": voice_min_ms,
            "silence_min_ms": silence_min_ms,
            "voice_lookback_ms": voice_lookback_ms,
            "smoothing_ms": smoothing_ms,
            "min_spacing_ms": min_spacing_ms,
            "depth_factor": depth_factor,
            "distance_bias_ms": distance_bias_ms,
            "gate_threshold": gate_threshold,
            "expected_duration_s": expected_duration_s,
            "duration_floor_ratio": duration_floor_ratio,
        },
    }


def _cli():
    """Command-line entry point: load a waveform JSON, correct one cue, print the result as JSON."""
    parser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    parser.add_argument("waveform")
    parser.add_argument("--start", type=float, required=True)
    parser.add_argument("--end",   type=float, required=True)

    # (flag, type, default, help) for every tunable forwarded to autocorrect_cue.
    # argparse derives each dest from the flag ("--start-back-ms" → start_back_ms),
    # which matches autocorrect_cue's keyword names.
    tunables = [
        ("--start-back-ms",     int,   DEFAULT_START_BACK_MS,     None),
        ("--start-forward-ms",  int,   DEFAULT_START_FORWARD_MS,  None),
        ("--end-back-ms",       int,   DEFAULT_END_BACK_MS,       None),
        ("--end-forward-ms",    int,   DEFAULT_END_FORWARD_MS,    None),
        ("--voice-min-ms",      int,   DEFAULT_VOICE_MIN_MS,      None),
        ("--silence-min-ms",    int,   DEFAULT_SILENCE_MIN_MS,    None),
        ("--voice-lookback-ms", int,   DEFAULT_VOICE_LOOKBACK_MS, None),
        ("--smoothing-ms",      int,   DEFAULT_SMOOTHING_MS,      None),
        ("--min-spacing-ms",    int,   DEFAULT_MIN_SPACING_MS,    None),
        ("--depth-factor",      float, DEFAULT_DEPTH_FACTOR,      None),
        ("--distance-bias-ms",  int,   DEFAULT_DISTANCE_BIAS_MS,  None),
        ("--gate-threshold",    float, 0.0,
         "pre-zero peaks below this value (0..1); 0 = disabled"),
        ("--expected-duration-s", float, None,
         "cue's expected line duration in seconds; enables duration-aware "
         "candidate scoring that rejects onomatopoeia / breaths"),
        ("--duration-floor-ratio", float, DEFAULT_DURATION_FLOOR_RATIO,
         "reject candidate bursts whose duration is less than this fraction "
         "of expected_duration_s (default 0.4)"),
    ]
    dest_names = []
    for flag, kind, default, help_text in tunables:
        extra = {} if help_text is None else {"help": help_text}
        parser.add_argument(flag, type=kind, default=default, **extra)
        dest_names.append(flag[2:].replace("-", "_"))

    args = parser.parse_args()
    with open(args.waveform) as fh:
        waveform = json.load(fh)
    overrides = {name: getattr(args, name) for name in dest_names}
    result = autocorrect_cue(waveform, args.start, args.end, **overrides)
    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    _cli()
