#!/usr/bin/env python3
"""
Standalone EEG analysis script — extracted from Django management command.
Reads an EDF file, trialData.json, and recognitionPresses.json, then
generates the same analysis graphs that the Django version produced.

Usage:
    python3 run_analysis.py

All input files are expected in the same directory as this script.
Output PNG files are written to an 'output/' subdirectory.
"""

import os
import sys
import json
import pprint
import traceback

import numpy as np
import scipy.signal
import matplotlib
matplotlib.use("Agg")  # headless backend — no display needed
import matplotlib.pyplot as plt
import pyedflib
from pyedflib import highlevel
from datetime import datetime

# ---------------------------------------------------------------------------
# Configuration — values extracted from the Django Trial #144 admin page
# ---------------------------------------------------------------------------
# Directory containing this script; all inputs are resolved relative to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Input files: the raw EEG recording and the two JSON exports from the web UI.
EDF_FILE = os.path.join(SCRIPT_DIR, "AH_20260228_102752.edf")
TRIAL_DATA_FILE = os.path.join(SCRIPT_DIR, "trialData.json")
RECOGNITION_PRESSES_FILE = os.path.join(SCRIPT_DIR, "recognitionPresses.json")

# From the Django model / admin screenshot
TRIAL_PK = 144  # primary key of the trial being analyzed (used in filenames/titles)
# Wall-clock start of the trial (JS event times are relative to this moment).
TRIAL_TIMESTAMP_STR = "02/28/2026 10:28:16.854"
# Start time of the EDF recording; kept for reference — main() reads the
# authoritative start time from the EDF header itself.
EDF_START_TIME_STR = "02/28/2026 10:28:05.1"
JAVASCRIPT_VERSION = "1.2"  # version of the stimulus-presentation JS that produced the data

# All generated PNGs and HTML reports are written here.
OUTPUT_DIR = os.path.join(SCRIPT_DIR, "output")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def peri_stimulus_average(signals, signal_labels, fs, notch_b, notch_a,
                          signal_name, stimuli_list, before_offset, after_offset):
    """
    Compute the peri-stimulus average for *signal_name* around each time in
    *stimuli_list*.  before_offset and after_offset are in ms.
    Returns (average_waveform, notch-filtered_signal).
    """
    before_samples = int(-before_offset / (1000 / fs))
    after_samples = int(after_offset / (1000 / fs))
    average = np.zeros(after_samples - before_samples)

    sig = signals[signal_labels.index(signal_name)]
    sig_notch = scipy.signal.filtfilt(notch_b, notch_a, sig)

    count = 0
    for value in stimuli_list:
        count += 1
        for i in range(before_samples, after_samples):
            idx = int(value / (1000 / fs)) + i
            if 0 <= idx < len(sig_notch):
                average[i] = average[i] * (count - 1) / count + sig_notch[idx] / count

    return average, sig_notch


# ---------------------------------------------------------------------------
# Main analysis
# ---------------------------------------------------------------------------
def main():
    """Run the complete standalone analysis for one trial.

    Steps: load trialData.json and recognitionPresses.json, read the EDF
    recording, convert JS-relative event times to EEG-absolute times,
    compute four peri-stimulus averages (all clear images, keypress-filtered,
    keypress-locked, scrambled images), and write five PNG plots, a
    per-trial HTML report, and an index page into OUTPUT_DIR.
    """
    print("=" * 60)
    print("Standalone EEG Analysis — Trial #%d" % TRIAL_PK)
    print("=" * 60)

    # --- Load input files ---------------------------------------------------
    with open(TRIAL_DATA_FILE) as f:
        trial_data = json.load(f)

    with open(RECOGNITION_PRESSES_FILE) as f:
        recognition_presses = json.load(f)

    print("Loaded %d recognition presses" % len(recognition_presses))

    # --- Read EDF -----------------------------------------------------------
    # highlevel.read_edf gives the samples + headers; a second low-level open
    # is used only to fetch the channel labels.
    signals, signal_headers, header = highlevel.read_edf(EDF_FILE)
    f_edf = pyedflib.EdfReader(EDF_FILE)
    signal_labels = f_edf.getSignalLabels()
    f_edf.close()

    # NOTE(review): older pyedflib releases call this key "sample_rate"
    # rather than "sample_frequency" — confirm against the installed version.
    fs = signal_headers[0]["sample_frequency"]
    print("Sample frequency: %f Hz" % fs)
    print("Signal labels: %s" % signal_labels)
    print("EDF header startdate: %s" % header["startdate"])

    # --- Compute time offset ------------------------------------------------
    # The EEG signal array starts at the EDF header's startdate (sample 0).
    # JS times (trialData values and recognition presses) are relative to
    # the trial timestamp.  We need the offset from EDF start to trial start
    # to convert JS-relative times into EEG sample indices.
    eeg_start_time = header["startdate"]
    trial_timestamp = datetime.strptime(TRIAL_TIMESTAMP_STR, "%m/%d/%Y %H:%M:%S.%f")

    offset_ms = (trial_timestamp - eeg_start_time).total_seconds() * 1000
    print("Offset from EDF start to trial start: %.1f ms (%.3f s)" %
          (offset_ms, offset_ms / 1000))

    # X-axis window for the alignment plot: 0.5 s before trial start to 5 s after.
    trim_start = -500 + offset_ms
    trim_end = 5000 + offset_ms

    # Shift recognition presses from JS-relative to EEG-absolute time
    recognition_presses = [rp + offset_ms for rp in recognition_presses]

    # --- Notch filter (60 Hz) -----------------------------------------------
    f0 = 60.0  # mains frequency to remove
    Q = 30.0   # quality factor -> narrow notch
    notch_b, notch_a = scipy.signal.iirnotch(f0, Q, fs)

    # --- Optical trigger channels -------------------------------------------
    # Difference of the two trigger channels isolates the optical pulse.
    optical1 = "CH21"
    optical2 = "Cz"
    ekg1 = signals[signal_labels.index(optical1)]
    ekg2 = signals[signal_labels.index(optical2)]
    diff = ekg1 - ekg2

    # --- Build image dictionaries -------------------------------------------
    # trialData entries are "time,imageName,isScrambled" strings keyed by a
    # numeric index; non-digit keys are metadata and skipped.
    clear_images = {}      # image onset time (EEG ms) -> image name
    scrambled_images = {}  # same, for scrambled (control) images
    filtered_images = {}   # KEYPRESS time -> image name (see note below)
    prompt_delay = 1000    # ms after image onset in which a press counts

    for key, value in trial_data["data"].items():
        if not key.isdigit():
            continue
        values = value.split(",")
        t = int(round(float(values[0]))) + offset_ms
        if values[2] == "false":
            clear_images[t] = values[1]
            # First press inside (t, t + prompt_delay) marks this image as
            # recognized.  filtered_images is keyed by the PRESS time, not
            # the image time, so the "filtered" PSA below is press-locked.
            # (The original wrapped these comparisons in a try/except that
            # could never fire; removed.)
            for val in recognition_presses:
                if t < val < t + prompt_delay:
                    filtered_images[val] = values[1]
                    break
        else:
            scrambled_images[t] = values[1]

    # --- Associated keypresses (minimum reaction time 200 ms) ---------------
    # Map each clear-image onset to the first press that arrives at least
    # minimum_reaction_time after it but still inside the prompt window.
    minimum_reaction_time = 200
    associated_keypresses = {}
    for key in clear_images.keys():
        for val in recognition_presses:
            if key + minimum_reaction_time < val < key + prompt_delay:
                associated_keypresses[key] = val
                break

    print("Clear images: %d" % len(clear_images))
    print("Scrambled images: %d" % len(scrambled_images))
    print("Filtered images (with keypress in window): %d" % len(filtered_images))
    print("Associated keypresses: %d" % len(associated_keypresses))

    # --- Peri-stimulus averages ---------------------------------------------
    before_offset = 500  # ms
    after_offset = 500   # ms

    # Occipital / posterior-temporal channels of interest.
    channels = ["O1", "O2", "T5", "T6"]

    # Clear-image PSA
    psa = {}
    sig_notch = {}
    for ch in channels:
        psa[ch], sig_notch[ch] = peri_stimulus_average(
            signals, signal_labels, fs, notch_b, notch_a,
            ch, clear_images.keys(), before_offset, after_offset)

    # Filtered PSA (only stimuli followed by a keypress within prompt_delay)
    fpsa = {}
    for ch in channels:
        fpsa[ch], _ = peri_stimulus_average(
            signals, signal_labels, fs, notch_b, notch_a,
            ch, filtered_images.keys(), before_offset, after_offset)

    # Keypress-locked PSA.
    # NOTE(review): this averages around associated_keypresses.keys() — the
    # STIMULUS times — over [0, prompt_delay); the plot's xlabel says "key
    # press".  If true press-locking is intended, use .values() instead.
    kppsa = {}
    for ch in channels:
        kppsa[ch], _ = peri_stimulus_average(
            signals, signal_labels, fs, notch_b, notch_a,
            ch, associated_keypresses.keys(), 0, prompt_delay)

    # Non-stimulus (scrambled) PSA
    pnsa = {}
    for ch in channels:
        pnsa[ch], _ = peri_stimulus_average(
            signals, signal_labels, fs, notch_b, notch_a,
            ch, scrambled_images.keys(), before_offset, after_offset)

    # --- Time axes ----------------------------------------------------------
    time = np.arange(len(diff)) * 1000 / fs                          # full recording, ms
    peri_stimulus_time = np.arange(-before_offset, after_offset, 1000 / fs)
    kp_stimulus_time = np.arange(0, prompt_delay, 1000 / fs)

    # ======================================================================
    # PLOT 1 — Full signal alignment with images & keypresses
    # ======================================================================
    bigplot, axes = plt.subplots(2, 1, figsize=(20, 10))
    bigplot.suptitle(
        "Aligning O1 and Cz with images and key presses in trial #%s" % TRIAL_PK,
        fontsize=16)

    # Top panel: O1 against the optical-trigger difference signal.
    O1_notch = sig_notch["O1"]
    axes[0].plot(time, O1_notch, color="lightGreen", label="O1")
    axes[0].plot(time, diff, color="purple", label="CH21 - Cz")
    axes[0].legend()
    axes[0].set_xlim(trim_start, trim_end)

    # Bottom panel: O1 with event markers (full-height presses, half-height images).
    axes[1].plot(time, O1_notch, color="lightGreen", label="O1")
    axes[1].vlines(x=recognition_presses,
                   ymin=min(O1_notch), ymax=max(O1_notch),
                   color="purple", linestyle="--", linewidth=1,
                   label="recognition presses")
    axes[1].vlines(x=list(clear_images.keys()),
                   ymin=min(O1_notch) / 2, ymax=max(O1_notch) / 2,
                   color="green", linestyle="--", linewidth=2,
                   label="clear images")
    axes[1].vlines(x=list(scrambled_images.keys()),
                   ymin=min(O1_notch) / 2, ymax=max(O1_notch) / 2,
                   color="red", linestyle="--", linewidth=2,
                   label="scrambled images")
    axes[1].legend()
    axes[1].set_xlim(trim_start, trim_end)

    filepath = os.path.join(OUTPUT_DIR, "trial%d-signalAlignment.png" % TRIAL_PK)
    bigplot.savefig(filepath)
    print("Saved: %s" % filepath)
    plt.close(bigplot)

    # ======================================================================
    # PLOT 2 — Peri-stimulus average (all clear images)
    # ======================================================================
    plt.figure(figsize=(20, 3))
    plt.title("Trial #%d, peri-stimulus averages over time, averaged across "
              "%d stimuli and %d non-stimuli" %
              (TRIAL_PK, len(clear_images), len(scrambled_images)))
    for ch in channels:
        plt.plot(peri_stimulus_time, psa[ch], label="%s peri-stimulus average" % ch)
    plt.xlabel("ms before/after stimulus presentation")
    plt.legend(loc="upper right")
    filepath = os.path.join(OUTPUT_DIR, "trial%d-periStimulusAverage.png" % TRIAL_PK)
    plt.savefig(filepath)
    print("Saved: %s" % filepath)
    plt.close()

    # ======================================================================
    # PLOT 3 — Filtered peri-stimulus average (only stimuli with keypress)
    # ======================================================================
    plt.figure(figsize=(20, 3))
    plt.title("Trial #%d, filtered peri-stimulus averages (stimuli with keypress)" % TRIAL_PK)
    for ch in channels:
        plt.plot(peri_stimulus_time, fpsa[ch],
                 label="%s filtered peri-stimulus average" % ch)
    plt.xlabel("ms before/after stimulus presentation")
    plt.legend(loc="upper right")
    filepath = os.path.join(OUTPUT_DIR, "trial%d-filteredPeriStimulusAverage.png" % TRIAL_PK)
    plt.savefig(filepath)
    print("Saved: %s" % filepath)
    plt.close()

    # ======================================================================
    # PLOT 4 — Keypress-locked peri-stimulus average
    # ======================================================================
    plt.figure(figsize=(20, 3))
    # BUG FIX: the title previously reported len(filtered_images), but this
    # plot averages over associated_keypresses — report that count instead.
    plt.title("Trial #%d, filtered peri-stimulus (only with key press) averages, "
              "averaged across %d prompt (within %d ms) presses" %
              (TRIAL_PK, len(associated_keypresses), prompt_delay))
    for ch in channels:
        plt.plot(kp_stimulus_time, kppsa[ch],
                 label="%s, averaged around stimulus w/associated keypress" % ch)
    plt.xlabel("ms before/after key press")
    plt.legend(loc="upper right")
    filepath = os.path.join(OUTPUT_DIR,
                            "trial%d-filteredPeriPromptStimulusAverage.png" % TRIAL_PK)
    plt.savefig(filepath)
    print("Saved: %s" % filepath)
    plt.close()

    # ======================================================================
    # PLOT 5 — Peri-non-stimulus average (scrambled images)
    # ======================================================================
    plt.figure(figsize=(20, 3))
    plt.title("Trial #%d, peri-non-stimulus average over time" % TRIAL_PK)
    for ch in channels:
        plt.plot(peri_stimulus_time, pnsa[ch],
                 label="%s peri-non-stimulus average" % ch)
    plt.xlabel("ms before/after stimulus presentation")
    plt.legend()
    filepath = os.path.join(OUTPUT_DIR, "trial%d-periNonStimulusAverage.png" % TRIAL_PK)
    plt.savefig(filepath)
    print("Saved: %s" % filepath)
    plt.close()

    # ======================================================================
    # HTML report — one page per trial embedding the five PNGs above
    # ======================================================================
    image_files = [
        ("Signal Alignment", "trial%d-signalAlignment.png" % TRIAL_PK),
        ("Peri-Stimulus Average (all clear images)", "trial%d-periStimulusAverage.png" % TRIAL_PK),
        ("Filtered Peri-Stimulus Average (stimuli with keypress)", "trial%d-filteredPeriStimulusAverage.png" % TRIAL_PK),
        ("Keypress-Locked Peri-Stimulus Average", "trial%d-filteredPeriPromptStimulusAverage.png" % TRIAL_PK),
        ("Peri-Non-Stimulus Average (scrambled images)", "trial%d-periNonStimulusAverage.png" % TRIAL_PK),
    ]

    html_parts = [
        "<!DOCTYPE html>",
        "<html><head>",
        "<meta charset='utf-8'>",
        "<title>Trial #%d Analysis</title>" % TRIAL_PK,
        "<style>",
        "  body { font-family: sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }",
        "  h1 { border-bottom: 1px solid #ccc; padding-bottom: 8px; }",
        "  h2 { margin-top: 32px; }",
        "  img { max-width: 100%; border: 1px solid #ddd; }",
        "  .meta { color: #666; margin-bottom: 24px; }",
        "</style>",
        "</head><body>",
        "<h1>Trial #%d &mdash; EEG Analysis</h1>" % TRIAL_PK,
        "<p class='meta'>Subject: AH1 &bull; %s &bull; %d clear images, %d scrambled, %d keypresses</p>"
        % (TRIAL_TIMESTAMP_STR, len(clear_images), len(scrambled_images), len(associated_keypresses)),
    ]

    for title, filename in image_files:
        html_parts.append("<h2>%s</h2>" % title)
        html_parts.append("<img src='%s'>" % filename)

    html_parts.append("</body></html>")

    html_path = os.path.join(OUTPUT_DIR, "trial%d.html" % TRIAL_PK)
    with open(html_path, "w") as f:
        f.write("\n".join(html_parts))
    print("Saved: %s" % html_path)

    # ======================================================================
    # Index page — links to all processed trials found in OUTPUT_DIR
    # ======================================================================
    import glob
    import re

    # Sort numerically by the trial number embedded in each filename.
    trial_files = sorted(glob.glob(os.path.join(OUTPUT_DIR, "trial*.html")),
                         key=lambda p: int(re.search(r"trial(\d+)", p).group(1)))

    index_parts = [
        "<!DOCTYPE html>",
        "<html><head>",
        "<meta charset='utf-8'>",
        "<title>EEG Analysis Index</title>",
        "<style>",
        "  body { font-family: sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }",
        "  h1 { border-bottom: 1px solid #ccc; padding-bottom: 8px; }",
        "  ul { list-style: none; padding: 0; }",
        "  li { padding: 8px 0; border-bottom: 1px solid #eee; }",
        "  a { text-decoration: none; color: #0066cc; font-size: 1.1em; }",
        "  a:hover { text-decoration: underline; }",
        "</style>",
        "</head><body>",
        "<h1>EEG Analysis &mdash; Processed Trials</h1>",
        "<ul>",
    ]

    for tf in trial_files:
        basename = os.path.basename(tf)
        trial_num = re.search(r"trial(\d+)", basename).group(1)
        index_parts.append("<li><a href='%s'>Trial #%s</a></li>" % (basename, trial_num))

    index_parts.append("</ul>")
    index_parts.append("</body></html>")

    index_path = os.path.join(OUTPUT_DIR, "index.html")
    with open(index_path, "w") as f:
        f.write("\n".join(index_parts))
    print("Saved: %s" % index_path)

    print("\n" + "=" * 60)
    print("Analysis complete. All output saved to: %s" % OUTPUT_DIR)
    print("=" * 60)


# Script entry point — run the full analysis when executed directly.
if __name__ == "__main__":
    main()
