from __future__ import annotations
import logging
import threading
import time
from queue import Queue, Full, Empty

try:
    import serial  # pyserial
    from serial import SerialException
except Exception as exc:
    raise RuntimeError("pyserial is required: pip install pyserial") from exc


class SerialHandler(logging.Handler):
    """
    Asynchronous logging handler that writes to a serial port using pyserial.

    ``emit()`` formats each record and enqueues it without blocking. A
    background thread named "SerialLogWriter" dequeues messages, encodes them
    and writes them to the port, opening the port lazily on first use.
    Messages that cannot be queued or delivered are counted in
    :attr:`dropped_count`.

    Args:
        port: Device name passed to ``serial.Serial``.
        baudrate: Line speed passed to ``serial.Serial``.
        encoding: Codec used to encode messages; unencodable characters are
            replaced instead of raising.
        newline: Appended to every formatted message that does not already
            end with it. An empty string disables this.
        write_timeout: Passed to ``serial.Serial`` as ``write_timeout``.
        queue_maxsize: Capacity of the message queue (``queue.Queue``
            semantics: ``0`` or less means unbounded).
        drop_policy: Behaviour when the queue is full: ``"drop_new"``
            discards the incoming message, ``"drop_old"`` evicts the oldest
            queued message to make room. Either way one message is counted
            as dropped.
        reconnect: If True, a message whose write fails is kept, the port is
            closed, and that same message is retried after an exponential
            backoff, so output order is preserved. If False, the failed
            message is dropped (and counted) and the next message is tried.
        reconnect_backoff_base: First retry delay, in seconds.
        reconnect_backoff_max: Upper bound for the doubling retry delay.
        daemon_thread: Whether the writer thread is a daemon thread.

    Raises:
        ValueError: If ``drop_policy`` is not ``"drop_new"`` or ``"drop_old"``.
    """

    _DROP_POLICIES = ("drop_new", "drop_old")

    def __init__(
        self,
        port: str,
        baudrate: int = 115200,
        encoding: str = "utf-8",
        newline: str = "\r\n",
        write_timeout: float | None = 1.0,
        queue_maxsize: int = 1000,
        drop_policy: str = "drop_new",  # "drop_new" or "drop_old"
        reconnect: bool = True,
        reconnect_backoff_base: float = 0.5,
        reconnect_backoff_max: float = 5.0,
        daemon_thread: bool = True,
    ) -> None:
        # Validate before Handler.__init__ registers this instance with the
        # logging module, so a bad argument never leaves a half-built handler
        # behind for logging.shutdown() to close.
        if drop_policy not in self._DROP_POLICIES:
            raise ValueError(
                f"drop_policy must be 'drop_new' or 'drop_old', got {drop_policy!r}"
            )
        super().__init__()
        self.port = port
        self.baudrate = baudrate
        self.encoding = encoding
        self.newline = newline
        self.write_timeout = write_timeout
        self.queue: Queue[str] = Queue(maxsize=queue_maxsize)
        self.drop_policy = drop_policy
        self.reconnect = reconnect
        self.reconnect_backoff_base = reconnect_backoff_base
        self.reconnect_backoff_max = reconnect_backoff_max
        self._stop_evt = threading.Event()
        self._thread = threading.Thread(
            target=self._worker, name="SerialLogWriter", daemon=daemon_thread
        )
        self._ser = None  # type: serial.Serial | None
        self._dropped_count = 0
        # _dropped_count is updated both from emit() callers and from the
        # writer thread, so increments go through this lock.
        self._count_lock = threading.Lock()
        # Start last: the worker reads the attributes initialised above.
        self._thread.start()

    @property
    def dropped_count(self) -> int:
        """Number of messages discarded (queue overflow, write failure, shutdown)."""
        return self._dropped_count

    def emit(self, record: logging.LogRecord) -> None:
        """Format *record* and hand it to the writer thread without blocking."""
        try:
            if self._stop_evt.is_set():
                # close() has been called; nothing will dequeue this message.
                self._record_drop()
                return
            msg = self.format(record)
            if self.newline and not msg.endswith(self.newline):
                msg = f"{msg}{self.newline}"
            self._enqueue(msg)
        except Exception:
            self.handleError(record)

    def close(self) -> None:
        """Stop the writer thread (letting it drain the queue) and release the port.

        Waits up to 5 seconds for the writer thread. If it is still running
        after that, it keeps ownership of the port and closes it itself when
        it finishes, so this method never touches the port concurrently with
        an in-flight write.
        """
        try:
            self._stop_evt.set()
            self._thread.join(timeout=5.0)
            if not self._thread.is_alive():
                self._close_quiet()
        finally:
            super().close()

    # -------------------- internals --------------------

    def _record_drop(self, n: int = 1) -> None:
        """Thread-safely add *n* to the dropped-message counter."""
        with self._count_lock:
            self._dropped_count += n

    def _enqueue(self, msg: str) -> None:
        """Put *msg* on the queue, applying ``drop_policy`` when it is full."""
        try:
            self.queue.put_nowait(msg)
            return
        except Full:
            pass
        if self.drop_policy != "drop_old":
            self._record_drop()  # discard the incoming message
            return
        try:
            self.queue.get_nowait()
        except Empty:
            pass  # the writer emptied the queue in the meantime
        else:
            self._record_drop()  # the evicted oldest message
        try:
            self.queue.put_nowait(msg)
        except Full:
            # Another producer refilled the freed slot.
            self._record_drop()

    def _open_port(self) -> None:
        """Open the serial port unless an open handle already exists."""
        if self._ser and self._ser.is_open:
            return
        self._ser = serial.Serial(
            port=self.port,
            baudrate=self.baudrate,
            timeout=0,
            write_timeout=self.write_timeout,
        )

    def _safe_write(self, data: bytes) -> None:
        """Write *data* to the port, (re)opening it first if needed."""
        if not self._ser or not self._ser.is_open:
            self._open_port()
        ser = self._ser
        assert ser is not None  # _open_port() either sets _ser or raises
        ser.write(data)
        # Wait for the buffered output to be pushed to the device.
        # NOTE(review): pyserial's POSIX flush() uses tcdrain(), which is not
        # bounded by write_timeout, so a stalled link can block here.
        ser.flush()

    def _worker(self) -> None:
        """Writer-thread loop: dequeue, write, retry with backoff on I/O errors."""
        backoff = self.reconnect_backoff_base
        # Message whose write failed and awaits a retry. It is retried before
        # anything newer is dequeued, which keeps output in emit() order.
        pending: str | None = None
        while not self._stop_evt.is_set():
            if pending is None:
                try:
                    pending = self.queue.get(timeout=0.2)
                except Empty:
                    continue
            try:
                self._safe_write(pending.encode(self.encoding, errors="replace"))
            except OSError:
                # pyserial's SerialException (and its subclasses such as
                # SerialTimeoutException) derive from IOError, i.e. OSError.
                if not self.reconnect:
                    pending = None
                    self._record_drop()
                    continue
                self._close_quiet()
                # Event.wait (not time.sleep) so close() interrupts the backoff.
                self._stop_evt.wait(backoff)
                backoff = min(backoff * 2.0, self.reconnect_backoff_max)
            except Exception:
                # Unexpected non-I/O error: drop this message but keep running.
                pending = None
                self._record_drop()
            else:
                pending = None
                backoff = self.reconnect_backoff_base

        self._drain(pending)
        self._close_quiet()

    def _drain(self, pending: str | None, budget: float = 1.0) -> None:
        """Best-effort delivery of *pending* and queued messages at shutdown.

        Stops at the first write error or after *budget* seconds; every
        message left undelivered is counted in ``dropped_count``.
        """
        deadline = time.monotonic() + budget
        while time.monotonic() < deadline:
            if pending is None:
                try:
                    pending = self.queue.get_nowait()
                except Empty:
                    return
            try:
                self._safe_write(pending.encode(self.encoding, errors="replace"))
            except Exception:
                break
            pending = None

        leftover = 0 if pending is None else 1
        while True:
            try:
                self.queue.get_nowait()
            except Empty:
                break
            leftover += 1
        if leftover:
            self._record_drop(leftover)

    def _close_quiet(self) -> None:
        """Close the port if one is held, ignoring errors, and forget the handle."""
        ser, self._ser = self._ser, None
        if ser is not None:
            try:
                ser.close()
            except Exception:
                pass  # best effort: the device may already be gone