from __future__ import annotations
import os
import re
import shlex
import tempfile
import subprocess
import threading
import time
from pathlib import Path
from typing import Any, Dict, Optional, Callable, List, TextIO
from enum import IntEnum
from contextlib import contextmanager

from logging import getLogger

# Module-level logger named after this module; the classes below log
# OpenOCD start-up, output lines and exit status through it.
logger = getLogger(__name__)

class OpenocdWrapper:
    """
    A thin wrapper to run OpenOCD with Tcl procs, offering three run modes:
      - run_wait(): wait until OpenOCD fully exits.
      - run_until(): wait until a ready pattern appears, then return a handle
        that requires handle.stop() (only this mode requires stop()).
      - run_then_kill(): wait until ready pattern appears, kill OpenOCD,
        and return the result.

    Global defaults can be set on the wrapper; per-call overrides can be given
    via .options(**opts) or directly on run_* methods.
    """

    # Words made only of these characters need no Tcl quoting at all.
    _TCL_BARE_WORD_RE = re.compile(r"[\w@%+=:,./-]+", re.ASCII)
    # Characters that are substituted (or end the word) inside a Tcl
    # double-quoted word and therefore must be backslash-escaped.
    _TCL_SPECIAL_RE = re.compile(r'([\\"$\[\]])')
    # Interval (seconds) at which a ready-wait re-checks for early exit.
    _READY_POLL_S = 0.1
    # Upper bound (seconds) for the reader thread to drain the pipe after
    # the OpenOCD process has exited.
    _READER_JOIN_S = 5.0

    def __init__(self,
                 jlink_sn: str = "",
                 openocd_exe: os.PathLike | str = r".\openocd\openocd.exe",
                 cfg_path: os.PathLike | str = r".\openocd\jlink_io.cfg",
                 search_path: os.PathLike | str = r".\openocd",
                 tcl_script: os.PathLike | str = r".\openocd\scripts.tcl",
                 ready_re: re.Pattern = re.compile(
                     r"waiting\s+for\s+'?([^'\s]+)'?|waiting\s+to\s+be\s+killed",
                     re.IGNORECASE,
                 ),
                 **defaults) -> None:
        """
        Store the OpenOCD invocation settings and the default options.

        Args:
            jlink_sn: Stored on the instance; not otherwise used here.
            openocd_exe: OpenOCD executable to launch.
            cfg_path: Config file passed with the first ``-f``.
            search_path: Directory passed with ``-s``.
            tcl_script: Tcl script passed with the second ``-f``.
            ready_re: Pattern searched in each output line; the first match
                marks the process as "ready".
            **defaults: Default options, converted to environment variables
                by _merge_env() and applied to every command.
        """
        self.jlink_sn = jlink_sn
        self.openocd_exe = str(openocd_exe)
        self.cfg_path = str(cfg_path)
        self.search_path = str(search_path)
        self.tcl_script = str(tcl_script)
        self.ready_re = ready_re

        # Map friendly option names to env keys (milliseconds for timeout).
        self._ENV_MAP: Dict[str, tuple[str, Optional[Callable[[Any], str]]]] = {
            "timeout_ms": ("WAIT_TIMEOUT_MS",
                           lambda v: str(int(v))),
            "signal":     ("SIGNAL_FILE",
                           lambda v: str(v)),
            "speed_khz":  ("SPEED_KHZ",
                           lambda v: str(int(v))),
            # Add more aliases as needed...
        }

        # Base env derived from defaults
        self._base_env: Dict[str, str] = self._merge_env({}, **defaults)

        self.ready_text: str = "waiting for"

    # ---------------------- public command constructor ----------------------

    def cmd(self, name: str, *args: Any) -> _PendingCommand:
        """Create a pending command for a Tcl proc."""
        return _PendingCommand(self, name, list(args))

    # ---------------------- convenience / context overrides -----------------

    @contextmanager
    def use(self, **overrides):
        """
        Temporarily override base options (converted to env via alias map).

        The previous base env is restored when the ``with`` block exits,
        including when it exits with an exception.
        """
        old = self._base_env
        try:
            self._base_env = self._merge_env(self._base_env, **overrides)
            yield self
        finally:
            self._base_env = old

    # ---------------------- internal merge / spawn helpers ------------------

    def _merge_env(self, base: Dict[str, str], **opts) -> Dict[str, str]:
        """
        Return a copy of ``base`` updated with ``opts`` as env variables.

        Known option names are translated through the alias map; unknown
        names are upper-cased and stringified. A value of None removes the
        corresponding env key instead of setting it.
        """
        merged = dict(base)
        for key, value in opts.items():
            env_key, transform = self._ENV_MAP.get(key, (key.upper(), None))
            if value is None:
                merged.pop(env_key, None)
            else:
                merged[env_key] = transform(value) if transform else str(value)
        return merged

    @classmethod
    def _tcl_quote(cls, value: Any) -> str:
        """
        Render ``value`` as exactly one Tcl word.

        Plain words are returned unchanged. Anything else (spaces, Windows
        backslashes, ``$``, brackets, quotes, the empty string) is wrapped
        in double quotes with the substitution characters escaped. Shell
        quoting must not be used here: Tcl does not treat single quotes as
        quoting characters and would split or mangle such arguments.
        """
        text = str(value)
        if cls._TCL_BARE_WORD_RE.fullmatch(text):
            return text
        return '"' + cls._TCL_SPECIAL_RE.sub(r"\\\1", text) + '"'

    def _spawn(self, cmd_name: str, args: List[Any],
               env: Dict[str, str], spawn_type: SpawnType,
               flag_path: Optional[Path] = None) -> tuple[
                   subprocess.Popen,
                   Path, TextIO, threading.Event, '_LogReader'
               ]:
        """
        Start OpenOCD running ``cmd_name args`` and a thread logging its output.

        Args:
            cmd_name: Tcl proc to call.
            args: Arguments for the proc; each is quoted as one Tcl word.
            env: Extra environment variables layered over os.environ.
            spawn_type: Decides what follows the command in the Tcl call.
            flag_path: File the Tcl side waits on; required for UNTIL.

        Returns:
            (process, log file path, open log file, ready event, reader).

        Raises:
            ValueError: If spawn_type is UNTIL and flag_path is None.
            OSError: If the OpenOCD executable cannot be started.
        """
        proc_env = os.environ.copy()
        proc_env.update(env)

        quoted_args = " ".join(self._tcl_quote(a) for a in args)
        tcl_call = f"{cmd_name} {quoted_args};"
        if spawn_type == SpawnType.UNTIL:
            if flag_path is None:
                raise ValueError("flag_path is required for SpawnType.UNTIL")
            tcl_call += f"wait_file_remove {self._tcl_quote(flag_path)};"
        elif spawn_type == SpawnType.WAIT_KILL:
            tcl_call += "wait_be_killed;"
        # For SpawnType.WAIT, just execute the command and exit
        tcl_call += "exit;"

        logger.debug("OpenOCD command: %s %s", cmd_name, quoted_args)
        logger.debug("TCL call: %s", tcl_call)

        proc_args = [
            self.openocd_exe,
            "-f", self.cfg_path,
            "-s", self.search_path,
            "-f", self.tcl_script,
            "-c", tcl_call,
        ]

        log_dir = Path.cwd() / "logs" / "openocd"
        log_dir.mkdir(parents=True, exist_ok=True)
        log_fd, log_path_str = tempfile.mkstemp(prefix="openocd_",
                                                suffix=".log",
                                                dir=log_dir)
        log_path = Path(log_path_str)
        # The file object owns the descriptor, so closing it releases the
        # OS handle and the log file can be deleted afterwards.
        log_fp: TextIO = open(log_fd, "w", buffering=1, encoding="utf-8",
                              errors="replace")

        logger.info("Starting OpenOCD process")
        logger.debug("OpenOCD log file: %s", log_path)
        logger.debug("OpenOCD args: %s", " ".join(proc_args))

        try:
            # Pipe stdout for the reader to detect readiness. Undecodable
            # bytes are replaced so they cannot abort the reader thread.
            proc = subprocess.Popen(
                proc_args,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                env=proc_env,
                text=True,
                errors="replace",
                bufsize=1,
            )
        except Exception:
            # Do not leave an empty log file behind when the launch fails.
            log_fp.close()
            try:
                log_path.unlink(missing_ok=True)
            except OSError:
                pass
            raise

        ready_evt = threading.Event()
        reader = _LogReader(proc=proc, log_fp=log_fp,
                            ready_re=self.ready_re, ready_evt=ready_evt)
        reader.start()

        return proc, log_path, log_fp, ready_evt, reader

    def _wait_ready(self, proc: subprocess.Popen,
                    ready_evt: threading.Event,
                    reader: '_LogReader') -> bool:
        """
        Block until the ready pattern was seen or the process has exited.

        Returns True when the pattern was seen, False when the process ended
        without ever printing it (waiting on the event alone would then hang
        forever).
        """
        while not ready_evt.wait(self._READY_POLL_S):
            if proc.poll() is not None:
                # The ready line may still be in the pipe: let the reader
                # drain it before deciding.
                reader.join(timeout=self._READER_JOIN_S)
                break
        return ready_evt.is_set()

    def _collect_result(self, rc: int, log_path: Path, log_fp: TextIO,
                        reader: '_LogReader') -> Dict[str, Any]:
        """
        Gather the captured output of a finished process and clean up.

        Waits for the reader thread so the tail of the output is not lost,
        closes and reads the log file, deletes it, and returns
        ``{"returncode": rc, "output": <log text>}``. Clean-up failures are
        logged and otherwise ignored.
        """
        reader.join(timeout=self._READER_JOIN_S)
        try:
            log_fp.close()  # close() also flushes pending writes
        except (OSError, ValueError):
            logger.debug("Could not close OpenOCD log file", exc_info=True)
        output = ""
        try:
            # Same encoding the log was written with.
            output = log_path.read_text(encoding="utf-8", errors="replace")
        except OSError:
            logger.warning("Could not read OpenOCD log file %s", log_path,
                           exc_info=True)
        try:
            log_path.unlink(missing_ok=True)
        except OSError:
            logger.debug("Could not delete OpenOCD log file %s", log_path,
                         exc_info=True)
        return {"returncode": rc, "output": output}

    # ---------------------- the three run backends --------------------------

    def _run_wait(self, cmd: str, args: List[Any],
                  env: Dict[str, str]) -> Dict[str, Any]:
        """Run the command, wait for OpenOCD to exit, return rc and output."""
        logger.info("Running OpenOCD command: %s", cmd)
        proc, log_path, log_fp, _ready_evt, reader = \
            self._spawn(cmd, args, env, SpawnType.WAIT)

        logger.debug("Waiting for OpenOCD process to complete")
        rc = proc.wait()
        logger.info("OpenOCD process completed with return code: %s", rc)
        return self._collect_result(rc, log_path, log_fp, reader)

    def _run_until(self, cmd: str, args: List[Any], env: Dict[str, str],
                   on_aborted: Optional[Callable[[Dict[str, Any]], None]]
                   = None) -> _OcdRun:
        """
        Run the command and return a handle once OpenOCD reports ready.

        A flag file is created in the current directory; the Tcl side waits
        for it to be removed. If OpenOCD exits before the ready pattern
        appears, a handle is still returned so the caller can obtain the
        return code and output from it.
        """
        flag_dir = Path.cwd()
        with tempfile.NamedTemporaryFile(prefix="ocd_", suffix=".flag",
                                         delete=False, dir=flag_dir) as tf:
            flag_path = Path(tf.name)

        try:
            proc, log_path, log_fp, ready_evt, reader = \
                self._spawn(cmd, args, env, SpawnType.UNTIL, flag_path)
        except Exception:
            # No process will ever wait on the flag file: remove it.
            try:
                flag_path.unlink(missing_ok=True)
            except OSError:
                pass
            raise

        logger.info("Waiting for OpenOCD to be ready")
        # Wait until the ready marker appears
        if self._wait_ready(proc, ready_evt, reader):
            logger.debug("Ready wording detected")
        else:
            logger.warning("OpenOCD exited before the ready pattern appeared")

        return _OcdRun(proc=proc, flag_path=flag_path, log_path=log_path,
                       log_fp=log_fp, reader=reader, on_aborted=on_aborted)

    def _run_then_kill(self, cmd: str, args: List[Any],
                       env: Dict[str, str]) -> Dict[str, Any]:
        """Run the command, kill OpenOCD once ready, return rc and output."""
        logger.info("Running OpenOCD command (then kill): %s", cmd)
        proc, log_path, log_fp, ready_evt, reader = \
            self._spawn(cmd, args, env, SpawnType.WAIT_KILL)

        logger.debug("Waiting for ready signal before killing")
        if self._wait_ready(proc, ready_evt, reader):
            logger.debug("Ready signal received, killing process")
        else:
            logger.warning("OpenOCD exited before the ready pattern appeared")
        try:
            proc.kill()
        except OSError:
            # The process may already be gone; wait() below still reaps it.
            pass
        rc = proc.wait()
        logger.info("OpenOCD process killed, return code: %s", rc)
        return self._collect_result(rc, log_path, log_fp, reader)


# ----------------------------- helpers below -------------------------------

class SpawnType(IntEnum):
    """Selects what OpenocdWrapper._spawn appends to the Tcl call."""
    # Run the command, then exit.
    WAIT = 1
    # Run the command, then wait for the flag file to be removed, then exit.
    UNTIL = 2
    # Run the command, then call wait_be_killed, then exit.
    WAIT_KILL = 3

class _LogReader(threading.Thread):
    """Continuously copies stdout to a log file and sets ready_evt on match."""

    def __init__(self, proc: subprocess.Popen, log_fp: TextIO,
                 ready_re: re.Pattern, ready_evt: threading.Event):
        """
        Args:
            proc: Process whose ``stdout`` is read line by line.
            log_fp: Open text file every line is copied to.
            ready_re: Pattern searched in each line.
            ready_evt: Event set on the first line matching ``ready_re``.
        """
        super().__init__(daemon=True)
        self._proc = proc
        self._log_fp: TextIO = log_fp
        self._ready_re = ready_re
        self._ready_evt = ready_evt

    def run(self):
        """
        Read stdout until EOF, logging each line and watching for readiness.

        A failing log write never stops the loop: the pipe must keep being
        read and readiness must still be detected. Failures are reported
        through the module logger instead of being dropped silently.
        """
        stream = self._proc.stdout
        if stream is None:
            logger.error("OpenOCD process has no stdout pipe to read from")
            return

        write_failed = False
        try:
            for line in stream:
                try:
                    self._log_fp.write(line)
                except Exception:
                    # Report only the first failure to avoid one warning
                    # per output line (e.g. after the log was closed).
                    if not write_failed:
                        write_failed = True
                        logger.warning(
                            "Failed to write OpenOCD output to log file",
                            exc_info=True)
                logger.debug("[OCD] %s", line.rstrip())
                if not self._ready_evt.is_set() \
                        and self._ready_re.search(line):
                    self._ready_evt.set()
        except Exception:
            logger.exception("OpenOCD log reader stopped unexpectedly")
        finally:
            try:
                self._log_fp.flush()
            except (OSError, ValueError):
                # The log may already have been closed by its owner.
                pass


class _OcdRun:
    """
    Handle returned by run_until(). Only this mode requires stop().
    Provides:
      - stop(timeout: Optional[float] = None) -> result dict
      - kill() -> result dict
      - is_running() -> bool
      - context manager support

    A result dict has the keys "returncode" and "output". It is collected
    exactly once, either by stop()/kill() or by the background watcher when
    the process exits on its own; later calls return a copy of it.
    """

    # Upper bound (seconds) for the reader thread to drain the remaining
    # output before the log file is closed and read.
    _READER_JOIN_S = 2.0
    # Grace period (seconds) between terminate() and kill() when stop()
    # runs into its timeout.
    _TERMINATE_GRACE_S = 1.0

    def __init__(self, proc: subprocess.Popen, flag_path: Path,
                 log_path: Path, log_fp: TextIO, reader: '_LogReader',
                 on_aborted: Optional[Callable[[Dict[str, Any]], None]]):
        """
        Args:
            proc: The running OpenOCD process.
            flag_path: Signal file; deleting it lets the Tcl side proceed.
            log_path: File holding the captured output.
            log_fp: Open file object the reader writes the output to.
            reader: Thread copying the process output into ``log_fp``.
            on_aborted: Called with the result dict when the process exits
                without stop()/kill() having been requested.
        """
        self._proc = proc
        self._flag_path = flag_path
        self._log_path = log_path
        self._log_fp: TextIO = log_fp
        self._reader: '_LogReader' = reader
        self._on_aborted = on_aborted
        self._lock = threading.RLock()
        self._stopped = False
        # Set as soon as stop()/kill() starts, so the watcher does not treat
        # the resulting exit as an abort.
        self._stop_requested = False
        self._result: Optional[Dict[str, Any]] = None

        # Background watcher: if process exits by itself, collect and callback.
        self._watcher = threading.Thread(target=self._watch, daemon=True)
        self._watcher.start()

    def __enter__(self):
        return self

    def __exit__(self, exc_t, exc, tb):
        self.stop()

    def is_running(self) -> bool:
        """Return True while the process is alive and not yet stopped."""
        with self._lock:
            return self._proc.poll() is None and not self._stopped

    def _watch(self):
        """Wait for the process; on an unrequested exit, collect and notify."""
        rc = self._proc.wait()
        with self._lock:
            if self._stopped or self._stop_requested:
                # stop()/kill() owns collecting the result.
                return
            self._stopped = True
            self._result = self._collect(rc)
            result = dict(self._result)
            cb = self._on_aborted
        if cb:
            try:
                cb(result)
            except Exception:
                logger.exception("on_aborted callback raised")

    def _safe_read_log(self) -> str:
        """Return the log file content, or "" if it cannot be read."""
        try:
            # The log is written as UTF-8; read it back the same way.
            return self._log_path.read_text(encoding="utf-8",
                                            errors="replace")
        except OSError:
            return ""

    def _close_log_fp(self):
        """Flush and close the log file object, ignoring failures."""
        try:
            if self._log_fp and not self._log_fp.closed:
                self._log_fp.flush()
                self._log_fp.close()
        except (OSError, ValueError):
            pass

    def _collect(self, returncode: int) -> Dict[str, Any]:
        """Read the captured output, delete log and flag file, build result."""
        try:
            # Let the reader finish copying the tail of the output first.
            self._reader.join(timeout=self._READER_JOIN_S)
        except RuntimeError:
            # Reader was never started, or this is the reader thread itself.
            pass
        self._close_log_fp()
        output = self._safe_read_log()
        for path in (self._log_path, self._flag_path):
            try:
                path.unlink(missing_ok=True)
            except OSError:
                # Best-effort clean-up; the result is still valid.
                pass
        return {"returncode": returncode, "output": output}

    def _cached_result(self) -> Dict[str, Any]:
        """Return a copy of the stored result (caller holds the lock)."""
        return dict(self._result or
                    {"returncode": self._proc.returncode,
                     "output": self._safe_read_log()})

    def _finalize(self, rc: int) -> Dict[str, Any]:
        """Store the result once and mark the run as stopped."""
        with self._lock:
            if self._result is None:
                self._result = self._collect(rc)
            self._stopped = True
            return dict(self._result)

    def stop(self, timeout: Optional[float] = None) -> Dict[str, Any]:
        """
        Delete the signal file to let Tcl exit proceed, wait for exit,
        and return the result. timeout in seconds; None means wait forever.

        When the timeout expires the process is terminated and, if it still
        does not exit, killed. Calling stop() again returns the same result.
        """
        with self._lock:
            if self._stopped:
                return self._cached_result()
            self._stop_requested = True
            try:
                self._flag_path.unlink(missing_ok=True)
            except OSError:
                logger.warning("Could not delete signal file %s",
                               self._flag_path, exc_info=True)
        rc = self._wait_with_timeout(timeout)
        return self._finalize(rc)

    def kill(self) -> Dict[str, Any]:
        """Force kill and return the result."""
        with self._lock:
            if self._stopped:
                return self._cached_result()
            self._stop_requested = True
            try:
                self._proc.kill()
            except OSError:
                # Already gone; wait() below still yields the return code.
                pass
        rc = self._proc.wait()
        return self._finalize(rc)

    def _wait_with_timeout(self, timeout: Optional[float]) -> int:
        """
        Wait for the process to exit and return its return code.

        With a timeout the wait escalates: wait ``timeout`` seconds, then
        terminate() and wait a short grace period, then kill() and wait
        without limit.
        """
        if timeout is None:
            return self._proc.wait()
        try:
            return self._proc.wait(timeout=max(0.0, timeout))
        except subprocess.TimeoutExpired:
            pass
        try:
            self._proc.terminate()
        except OSError:
            pass
        try:
            return self._proc.wait(timeout=self._TERMINATE_GRACE_S)
        except subprocess.TimeoutExpired:
            pass
        try:
            self._proc.kill()
        except OSError:
            pass
        return self._proc.wait()


class _PendingCommand:
    """
    A pending Tcl command with per-call options. Build env by merging
    wrapper base env and local overrides.

    Precedence, lowest to highest: the wrapper's base env, options set via
    options(), options passed directly to a run_* method.
    """

    def __init__(self, ctl: OpenocdWrapper, name: str, args: List[Any]) -> None:
        """
        Args:
            ctl: Wrapper that executes the command.
            name: Name of the Tcl proc to call.
            args: Arguments passed to the proc.
        """
        self._ctl = ctl
        self._name = name
        self._args = args
        self._local_opts: Dict[str, Any] = {}

    def options(self, **opts) -> _PendingCommand:
        """Set per-call options like timeout_ms=10000, speed_khz=4000."""
        self._local_opts.update(opts)
        return self

    def _build_env(self, call_opts: Dict[str, Any]) -> Dict[str, str]:
        """Merge base env, options() values and run_* keyword overrides."""
        # Merge the two option dicts first: expanding both with ** into one
        # call raises TypeError when the same option appears in both.
        merged_opts = {**self._local_opts, **call_opts}
        return self._ctl._merge_env(self._ctl._base_env, **merged_opts)

    # Mode 1: wait until OpenOCD fully exits
    def run_wait(self, **opts) -> Dict[str, Any]:
        """Run the command and wait for OpenOCD to exit; return the result."""
        return self._ctl._run_wait(self._name, self._args,
                                   self._build_env(opts))

    # Mode 2: wait until ready, return a handle; requires handle.stop()
    def run_until(self, on_aborted=None, **opts) -> _OcdRun:
        """
        Run the command and return a handle once OpenOCD is ready.

        on_aborted is forwarded to the handle; it is invoked with the result
        dict if the process exits on its own.
        """
        return self._ctl._run_until(self._name, self._args,
                                    self._build_env(opts),
                                    on_aborted=on_aborted)

    # Mode 3: wait until ready, then kill and return result
    def run_then_kill(self, **opts) -> Dict[str, Any]:
        """Run the command, kill OpenOCD once it is ready; return the result."""
        return self._ctl._run_then_kill(self._name, self._args,
                                        self._build_env(opts))