#!/usr/bin/env python
'''
@file serial_io_entry.py

@brief SerialIoEntry: thin wrapper around pyserial implementing DeviceIoEntryBase

Copyright (C) Atmosic 2025
'''
from __future__ import annotations

import serial
import threading
from logging import getLogger

from deviceio.device_io_entry_base import DeviceIoEntryBase
from lib.bytes_utils import bytes_to_hex
from error.atmosic_error import SerialIoEntryError
from error.errorcodes import SerialIoEntryErrorCode

# Debug-message tags used by the write path and the read path respectively.
LOG_WRITE = "SW"
LOG_READ = "SR"

# Diagnostics for this module, plus a dedicated "uart" logger that receives
# hex dumps of serial traffic ("[R] ..." / "[W] ..." lines).
logger = getLogger(__name__)
uart_logger = getLogger("uart")


class SerialIoEntry(DeviceIoEntryBase):
    """DeviceIoEntryBase implementation backed by a pyserial ``Serial``.

    Reads and writes are guarded by separate locks (``_com_read_lock`` /
    ``_com_write_lock``), so a reader and a writer do not block each other;
    ``open``/``close`` take both locks while replacing the port object.

    Args:
        port: Port name handed to ``serial.Serial``.
        baudrate: Baud rate used whenever the port is (re)opened.
        rtscts: Enable RTS/CTS hardware flow control.
        stream_type: Conversion applied to ``str`` payloads in ``write``:
            ``"ascii"`` (UTF-8 encode) or ``"hex"`` (``bytes.fromhex``).
        timeout: Default read timeout in seconds. ``None`` opens the port
            with a timeout of 0 (pyserial's non-blocking mode).
    """

    def __init__(self, port, baudrate, rtscts=False, stream_type: str = "ascii",
                 timeout: float | None = None):
        self._com_read_lock = threading.Lock()
        self._com_write_lock = threading.Lock()
        self._port = port
        self._baudrate = baudrate
        self._rtscts = rtscts
        self._timeout = timeout
        self._ser: serial.Serial | None = None  # None while closed
        self._stream_type = stream_type

    def open(self) -> None:
        """(Re)open the port with the current settings.

        Any existing port object is closed and dropped first; pending input
        on the new port is discarded.

        Raises:
            SerialIoEntryError: ``PORT_OPEN_FAILED`` if the port cannot be
                opened or its input buffer cannot be reset.
        """
        with self._com_read_lock, self._com_write_lock:
            logger.debug(f"[{LOG_READ}] _com_read_lock in open")
            if self._ser is not None:
                # Drop the stale reference even if closing it fails.
                try:
                    if self._ser.is_open:
                        self._ser.close()
                finally:
                    self._ser = None
            ser = None
            try:
                ser = serial.Serial(
                    self._port,
                    self._baudrate,
                    timeout=self._timeout if self._timeout is not None else 0,
                    write_timeout=1,
                    rtscts=self._rtscts,
                )
                ser.reset_input_buffer()
            except Exception as exp:
                # Only ordinary errors are wrapped; KeyboardInterrupt and
                # SystemExit propagate unchanged. Close a half-initialised
                # port so it is not leaked.
                if ser is not None:
                    ser.close()
                raise SerialIoEntryError(
                    f"Failed to open serial port: {self._port}, {exp}",
                    SerialIoEntryErrorCode.PORT_OPEN_FAILED) from exp
            # Publish the port only once it is fully initialised.
            self._ser = ser
            logger.debug(f"[{LOG_READ}] _com_read_lock opened")

    def clear_write_buffer(self) -> None:
        """Discard pending output; same as ``flush_output``."""
        self.flush_output()

    def clear_read_buffer(self) -> None:
        """Discard pending input; same as ``flush_input``."""
        self.flush_input()

    def read(self, size: int | None = None, timeout: float | None = None
             ) -> bytes | None:
        """Read bytes from the port.

        Args:
            size: Number of bytes to request from ``serial.Serial.read``.
                ``None`` reads only what is already buffered and returns
                ``None`` immediately if nothing is.
            timeout: Timeout override (seconds) for this call only; the
                port's previous timeout is restored afterwards. ``None``
                keeps the current timeout.

        Returns:
            The bytes read, or ``None`` if the port is not open or no data
            was received.

        Raises:
            SerialIoEntryError: if the underlying read fails.
        """
        with self._com_read_lock:
            ser = self._ser
            if ser is None or not ser.is_open:
                return None
            try:
                # Track the override with a flag instead of comparing the
                # saved value with None: None is itself a valid pyserial
                # timeout (block forever) and must be restored as well.
                override = timeout is not None
                if override:
                    saved_timeout = ser.timeout
                    ser.timeout = timeout
                try:
                    if size is None:
                        # Snapshot in_waiting once so the emptiness check and
                        # the read agree on the byte count.
                        waiting = ser.in_waiting
                        if waiting <= 0:
                            return None
                        data = ser.read(waiting)
                    else:
                        data = ser.read(size)
                finally:
                    if override:
                        ser.timeout = saved_timeout
            except Exception as exp:
                # NOTE(review): no read-specific error code is visible in this
                # file; PORT_OPEN_FAILED is kept so existing callers still
                # match. Consider adding a READ_FAILED code.
                raise SerialIoEntryError(
                    f"Failed to read from serial port: {self._port}, {exp}",
                    SerialIoEntryErrorCode.PORT_OPEN_FAILED) from exp
            if data:
                logger.debug(f"[{LOG_READ}] {bytes_to_hex(data)}, "
                             f"len: {len(data)}")
                uart_logger.info("[R] " + bytes_to_hex(data))
                return data
            logger.debug(f"[{LOG_READ}] None")
            return None

    def write(self, data: str | bytes) -> int:
        """Write ``data`` to the port.

        ``str`` payloads are converted according to ``stream_type``:
        ``"ascii"`` is UTF-8 encoded, ``"hex"`` is parsed with
        ``bytes.fromhex``. Other payloads are written as-is.

        Returns:
            Number of bytes written as reported by pyserial, or 0 if the
            port is not open.

        Raises:
            ValueError: if ``data`` is a ``str`` and ``stream_type`` is not
                ``"ascii"`` or ``"hex"``, or a hex string is malformed.
            SerialIoEntryError: ``WRITE_FAILED`` if the underlying write
                fails.
        """
        if isinstance(data, str):
            if self._stream_type == "ascii":
                bytes_data = data.encode("utf-8")
            elif self._stream_type == "hex":
                bytes_data = bytes.fromhex(data)
            else:
                # Previously this fell through with bytes_data unbound.
                raise ValueError(
                    f"Unsupported stream_type for str data: "
                    f"{self._stream_type!r}")
            logger.debug(f"[{LOG_WRITE}] {bytes_to_hex(bytes_data)}, "
                         f"len: {len(bytes_data)} ({data})")
        else:
            bytes_data = data
            logger.debug(f"[{LOG_WRITE}] {bytes_to_hex(bytes_data)}, "
                         f"len: {len(bytes_data)}")
        # Log traffic for every payload type, mirroring the "[R]" log in read.
        uart_logger.info("[W] " + bytes_to_hex(bytes_data))
        with self._com_write_lock:
            if self._ser is None or not self._ser.is_open:
                return 0
            try:
                ret = self._ser.write(bytes_data)
            except Exception as exp:
                raise SerialIoEntryError(
                    f"Failed to write to serial port: {self._port}, {exp}",
                    SerialIoEntryErrorCode.WRITE_FAILED) from exp
            # Guard against a None return from pyserial's write.
            return ret if ret is not None else 0

    def close(self):
        """Close the port (if any); the reference is dropped even on error."""
        with self._com_read_lock, self._com_write_lock:
            logger.debug(f"[{LOG_READ}] _com_read_lock in close")
            if self._ser is not None:
                try:
                    self._ser.close()
                finally:
                    self._ser = None

    def is_open(self) -> bool:
        """Return True if a port object exists and reports itself open."""
        return self._ser.is_open if self._ser is not None else False

    def reopen_with_baudrate(self, new_baudrate: int) -> None:
        """Close the port, switch to ``new_baudrate`` and open it again.

        Not atomic: ``close`` and ``open`` take the locks separately, so
        another thread may observe the port closed in between.
        """
        self.close()
        self._baudrate = new_baudrate
        self.open()

    def flush_input(self):
        """Discard bytes pending in the input buffer (no-op when closed)."""
        with self._com_read_lock:
            logger.debug(f"[{LOG_READ}] _com_read_lock in flush_input")
            if self._ser is not None and self._ser.is_open:
                self._ser.reset_input_buffer()

    def flush_output(self):
        """Discard bytes pending in the output buffer (no-op when closed)."""
        with self._com_write_lock:
            logger.debug(f"[{LOG_WRITE}] _com_write_lock in flush_output")
            if self._ser is not None and self._ser.is_open:
                self._ser.reset_output_buffer()

    def can_read(self) -> bool:
        """Check if device is ready for reading."""
        return self.is_open()

    def can_write(self) -> bool:
        """Check if device is ready for writing."""
        return self.is_open()

