#!/usr/bin/env python
'''
@file packet_module.py

@brief Packet module base class

Copyright (C) Atmosic 2025
'''
import time
import threading
from typing import Optional, Callable, TypeVar, Tuple
from logging import getLogger

from deviceio.device_io_module import DeviceIoModule
from packet.base.packet_factory_base import PacketFactoryBase
from packet.base.packet_data_base import PacketDataBase
from deviceio.device_io_read_detector_base import DeviceIoReadDetectorBase
from lib.bytes_utils import bytes_to_hex
from error.atmosic_error import PacketModuleError
from error.errorcodes import PacketModuleErrorCode

logger = getLogger(__name__)

# Generic result type returned by the wait function of _write()/write_and_wait().
T = TypeVar("T")

# Short tags prefixed to debug log lines so each packet-flow stage can be
# grepped for in a trace.
LOG_READ = "PR"               # raw bytes read from the device (hex dump)
LOG_CREATE_PACKET = "PCP"     # packet built by the packet factory
LOG_OUTPUT_PACKET = "POP"     # newly read packet handed back to a caller
LOG_WRITE = "PW"              # packet about to be written to the device
LOG_POP_UNMATCHED = "PPU"     # packet taken out of the unmatched buffer
LOG_APPEND_UNMATCHED = "PAU"  # packet parked in the unmatched buffer

class PacketModule:
    """Packet-level reader/writer on top of a DeviceIoModule.

    Incoming bytes are read from the device with a read detector and turned
    into packet objects by a packet factory. Packets read by ``read_until``
    that do not satisfy its check function are parked in an internal
    unmatched buffer and are served first by later ``read_one`` /
    ``read_until`` calls.

    Writes go through an internal lock; reads do not take that lock.
    """

    def __init__(self, device_io_module: DeviceIoModule,
                 read_detector: DeviceIoReadDetectorBase,
                 packet_factory: PacketFactoryBase):
        """Initialize PacketModule.

        Args:
            device_io_module: Device I/O module for reading/writing
            read_detector: DeviceIoReadDetectorBase to determine packet
                boundaries and drop invalid data
            packet_factory: Factory to create packets from bytes. Must be a
                PacketFactoryBase instance (cannot be None).
        """
        # NOTE: these checks use assert and are therefore skipped when Python
        # runs with -O.
        assert device_io_module is not None, "device_io_module cannot be None"
        assert read_detector is not None, "read_detector cannot be None"
        assert packet_factory is not None, "packet_factory cannot be None"
        self._device_io_module = device_io_module
        self._read_detector = read_detector
        self._packet_factory = packet_factory
        # Packets read by read_until() that did not match, oldest first.
        self._unmatched_packets: list[PacketDataBase] = []
        # Serializes _write(), including any wait_function it runs.
        self._write_lock = threading.Lock()

    def _read_new(self, timeout: Optional[float] = 0.0) -> PacketDataBase:
        """Read raw bytes from the device and convert them into a packet.

        The unmatched-packet buffer is neither consulted nor modified here.

        Args:
            timeout: Passed straight through to ``DeviceIoModule.read``; what
                0.0 or None means is defined by the device I/O module.

        Returns:
            PacketDataBase: Packet created by the packet factory from the
            bytes returned by the device read.
        """
        raw_data = self._device_io_module.read(
            read_detector=self._read_detector, timeout=timeout
        )
        logger.debug("[%s] %s", LOG_READ, bytes_to_hex(raw_data))
        packet = self._packet_factory.create_packet(raw_data)
        logger.debug("[%s] %s", LOG_CREATE_PACKET, packet)
        return packet

    def read_one(self, timeout: float = 0.0) -> PacketDataBase:
        """Return the oldest unmatched packet, or read a new one.

        Args:
            timeout: Passed to the device read when the unmatched buffer is
                empty; ignored otherwise.

        Returns:
            PacketDataBase: First available packet.
        """
        if self._unmatched_packets:
            packet = self._unmatched_packets.pop(0)
            logger.debug("[%s] %s", LOG_POP_UNMATCHED, packet)
            return packet

        packet = self._read_new(timeout)
        logger.debug("[%s] %s", LOG_OUTPUT_PACKET, packet)
        return packet

    def read_until(self, check_function: Callable[[PacketDataBase], bool],
                   timeout: Optional[float] = 0.0) -> PacketDataBase:
        """Return the first packet for which ``check_function`` is True.

        Buffered unmatched packets are searched first, oldest first. If none
        matches, new packets are read from the device until one matches.
        Non-matching new packets are appended to the unmatched buffer so that
        later ``read_one`` / ``read_until`` calls can still return them.

        Args:
            check_function: Predicate applied to each candidate packet.
            timeout: Overall time budget in seconds for reading new packets,
                or None to keep reading without a deadline. The deadline is
                checked before every device read, so with 0.0 (the default)
                no device read is attempted and only buffered packets are
                checked.

        Returns:
            PacketDataBase: The first matching packet. If it came from the
            unmatched buffer it is removed from there.

        Raises:
            PacketModuleError: With ``READ_UNTIL_TIMEOUT`` when the deadline
                passes before a matching packet is found.
        """
        # Monotonic clock: the deadline must not move if the wall clock is
        # adjusted (NTP sync, manual change) while we are waiting.
        start_time = time.monotonic()

        for index, packet in enumerate(self._unmatched_packets):
            if check_function(packet):
                # Mutating during iteration is safe: we return immediately.
                del self._unmatched_packets[index]
                logger.debug("[%s] %s", LOG_POP_UNMATCHED, packet)
                return packet

        remaining_timeout = None
        while True:
            if timeout is not None:
                elapsed = time.monotonic() - start_time
                if elapsed >= timeout:
                    raise PacketModuleError(
                        "Timeout waiting for packet",
                        PacketModuleErrorCode.READ_UNTIL_TIMEOUT)
                remaining_timeout = timeout - elapsed

            new_packet = self._read_new(remaining_timeout)

            if check_function(new_packet):
                logger.debug("[%s] %s", LOG_OUTPUT_PACKET, new_packet)
                return new_packet

            # Keep it for a later reader instead of dropping it.
            self._unmatched_packets.append(new_packet)
            logger.debug("[%s] %s", LOG_APPEND_UNMATCHED, new_packet)

    def _write(self, packet: PacketDataBase,
               is_clean_read_buffer: bool = True,
               wait_function: Optional[Callable[[], T]] = None
              ) -> Optional[T]:
        """Write ``packet`` to the device, optionally waiting for a response.

        Args:
            packet: Packet to serialize (via ``to_bytes()``) and send.
            is_clean_read_buffer: If True, clear the device read buffer and
                drop all buffered unmatched packets before writing (presumably
                so stale data is not mistaken for the response).
            wait_function: Optional callable invoked after a successful
                write; its return value is returned.

        Returns:
            The result of ``wait_function``, or None when it is not given.

        Raises:
            PacketModuleError: With ``WRITE_NOT_ENTIRELY`` when the device
                reports a different number of bytes written than prepared.
        """
        logger.debug("[%s] %s", LOG_WRITE, packet)
        # Serialization needs no device access, so it happens outside the lock.
        prepared_bytes = packet.to_bytes()
        prepared_bytes_len = len(prepared_bytes)

        # The lock is also held while wait_function runs, so another thread's
        # _write() cannot interleave between this write and its wait.
        with self._write_lock:
            if is_clean_read_buffer:
                self._device_io_module.clear_read_buffer()
                self._unmatched_packets.clear()

            written_bytes_len = self._device_io_module.write(prepared_bytes)
            if written_bytes_len != prepared_bytes_len:
                raise PacketModuleError(
                    f"Partial write: {written_bytes_len}/{prepared_bytes_len}"
                    f" bytes, packet: {packet}, data:"
                    f" {bytes_to_hex(prepared_bytes)}",
                    PacketModuleErrorCode.WRITE_NOT_ENTIRELY)

            if wait_function is None:
                return None
            return wait_function()

    def write(self, packet: PacketDataBase,
              is_clean_read_buffer: bool = True) -> None:
        """Write packet to device without waiting for a response.

        Args:
            packet: Packet to send.
            is_clean_read_buffer: If True, clear the device read buffer and
                the unmatched-packet buffer before writing.

        Raises:
            PacketModuleError: On a partial write.
        """
        self._write(packet, is_clean_read_buffer)

    def write_and_wait(self, packet: PacketDataBase,
                       wait_function: Callable[[], T],
                       is_clean_read_buffer: bool = True) -> T:
        """Write packet to device, then run ``wait_function`` and return its result.

        ``wait_function`` runs while the internal write lock is still held.

        Args:
            packet: Packet to send.
            wait_function: Callable run after the write (for example a
                closure around ``read_until``); must not return None.
            is_clean_read_buffer: If True, clear the device read buffer and
                the unmatched-packet buffer before writing.

        Returns:
            The value returned by ``wait_function``.

        Raises:
            PacketModuleError: On a partial write.
            AssertionError: If ``wait_function`` returns None (check skipped
                under ``python -O``).
        """
        wait_result = self._write(packet, is_clean_read_buffer, wait_function)
        assert wait_result is not None, "Wait function must return a value"
        return wait_result
