#!/usr/bin/env python
"""
UART Loader Agent

This module provides high-level business logic for the UART loader protocol.
The agent handles intelligent decision making, data chunking, and complete workflows.
"""
from __future__ import annotations
import time
from logging import getLogger

from gateway.hci.hci_gateway import HciGateway
from lib.checksum import xor_bytes
from packet.hci.hci_packet_implement import HciEvent, HciBootStatusEvent
from .uart_loader_agent_types import DeviceInfo, ProgramResult, DiagnoTputResult
from error.atmosic_error import UartLoaderAgentError, PacketModuleError
from error.errorcodes import UartLoaderAgentErrorCode, PacketModuleErrorCode
from .program_binaries_recipe import (
    ProgramBinariesProtocol, ProgramBinariesRecipe, ProgramBinariesRecipeConfig
)

logger = getLogger(__name__)
# Configuration constants
# Header size of a diagno throughput packet.
DIAGNO_TPUT_HEADER_SIZE = 7
# plen (max: 255) - 1 byte for the XOR checksum.
MAX_DIAGNO_TPUT_DATA_LENGTH = 255 - 1
MAX_DIAGNO_TPUT_PACKET_SIZE = (DIAGNO_TPUT_HEADER_SIZE
                               + MAX_DIAGNO_TPUT_DATA_LENGTH)
# Backward-compatible alias for the original (misspelled) constant name.
MAX_DIAGNO_TPUT_PAKCET_SIZE = MAX_DIAGNO_TPUT_PACKET_SIZE


class UartLoaderAgent(ProgramBinariesProtocol):
    """UART Loader Agent

    This class provides high-level business logic for the UART loader
    protocol on top of an ``HciGateway``: waiting for the boot event,
    erasing and programming storage, querying device information, reading
    back storage (SHA256 / dump) and running throughput/latency diagnostics.
    """

    def __init__(self, gateway: HciGateway):
        """Initialize agent with HCI gateway.

        Args:
            gateway: Gateway through which every request is sent and every
                event is received.
        """
        self.gateway = gateway

    def wait_boot_event(self, timeout_s: float = 5.0) -> None:
        """Wait for HciBootStatusEvent after device reset.

        Args:
            timeout_s: Timeout in seconds to wait for boot event

        Raises:
            UartLoaderAgentError: ``WAIT_BOOT_EVENT_TIMEOUT`` if no boot
                event is read within ``timeout_s``.
            PacketModuleError: Any other packet-module error is re-raised
                unchanged.
        """
        logger.debug(f"[Agent] Waiting for boot event"
                     f" (timeout: {timeout_s}s)...")

        def check_boot_event(evt: HciEvent) -> bool:
            # Predicate for read_until: stop at the first boot status event.
            logger.debug(f"[Agent] Checking event: {evt}")
            if isinstance(evt, HciBootStatusEvent.EventVS):
                logger.debug(f"[Agent] Boot event received:"
                             f" app_ver={evt.app_ver}")
                return True
            return False
        try:
            self.gateway.hci.read_until(check_boot_event, timeout_s=timeout_s)
        except PacketModuleError as e:
            # Translate only the timeout into an agent-level error.
            if e.error_code == PacketModuleErrorCode.READ_UNTIL_TIMEOUT:
                raise UartLoaderAgentError(
                    f"Boot event timeout after {timeout_s}s.",
                    UartLoaderAgentErrorCode.WAIT_BOOT_EVENT_TIMEOUT) from e
            raise
        logger.debug("[Agent] Boot event successfully received")

    def erase(self, start_addr: int, length: int, timeout_s: float = 5.0
              ) -> None:
        """Erase data from target address.

        Args:
            start_addr: Target storage address
            length: Size of data to erase
            timeout_s: Timeout in seconds, forwarded to the gateway's
                block-clean request.
        """
        self.gateway.prog_blk_clean_request(start_addr, length,
                                            timeout_s=timeout_s)
        logger.debug("[Agent] Block cleaned:"
                     f" 0x{start_addr:08X}+0x{length:08X}")

    def program_binaries(self, data: bytes, target_storage_addr: int,
                         check_ram_xor: bool = True,
                         check_rram_xor: bool = True,
                         ram_buffer_start: int = 0,
                         ram_buffer_size: int = 0,
                         count_group_packets: int = 1,
                         group_index_check_rsp: int = 0,
                         skip_send_ff_packets: bool = True,
                         is_storage_erased: bool = False) -> ProgramResult:
        """Run the complete firmware programming workflow.

        Every argument is forwarded unchanged into a
        ``ProgramBinariesRecipeConfig``; the workflow itself (chunking,
        checks, timeouts) is implemented by ``ProgramBinariesRecipe``, which
        is executed with this agent and its gateway.

        Args:
            data: Binary image to program.
            target_storage_addr: Storage address where ``data`` is written.
            check_ram_xor, check_rram_xor, ram_buffer_start, ram_buffer_size,
            count_group_packets, group_index_check_rsp, skip_send_ff_packets,
            is_storage_erased: Recipe options; see
                ``ProgramBinariesRecipeConfig`` for their semantics.

        Returns:
            The ``ProgramResult`` returned by the recipe.
        """

        logger.debug(f"[Agent] Programming binaries: target_storage_addr="
                     f"0x{target_storage_addr:08X}, data_length={len(data)},"
                     f" check_ram_xor={check_ram_xor}, check_rram_xor="
                     f"{check_rram_xor}, ram_buffer_start={ram_buffer_start},"
                     f" ram_buffer_size={ram_buffer_size},"
                     f" count_group_packets={count_group_packets},"
                     f" group_index_check_rsp={group_index_check_rsp},"
                     f" skip_send_ff_packets={skip_send_ff_packets},"
                     f" is_storage_erased={is_storage_erased}")
        recipe = ProgramBinariesRecipe(ProgramBinariesRecipeConfig(
            data=data,
            target_storage_addr=target_storage_addr,
            check_ram_xor=check_ram_xor,
            check_rram_xor=check_rram_xor,
            ram_buffer_start=ram_buffer_start,
            ram_buffer_size=ram_buffer_size,
            count_group_packets=count_group_packets,
            group_index_check_rsp=group_index_check_rsp,
            skip_send_ff_packets=skip_send_ff_packets,
            is_storage_erased=is_storage_erased
        ))
        return recipe.execute(self, self.gateway)

    def get_device_info(self) -> DeviceInfo:
        """Get device info and return structured device information.

        Returns:
            A ``DeviceInfo`` built from the fields of the prog-info event.
        """
        evt = self.gateway.prog_info_request()

        # Every field must be present in the response.
        # NOTE(review): assertions are stripped under ``python -O``; consider
        # raising an explicit error if this check must hold in production.
        assert evt.ram_buffer_start is not None
        assert evt.ram_buffer_size is not None
        assert evt.app_ver is not None
        assert evt.protocol_ver is not None
        assert evt.num_flash is not None
        assert evt.flash_ids is not None

        device_info = DeviceInfo(
            app_version=evt.app_ver,
            protocol_version=evt.protocol_ver,
            ram_buffer_start=evt.ram_buffer_start,
            ram_buffer_size=evt.ram_buffer_size,
            num_flash=evt.num_flash,
            flash_ids=evt.flash_ids
        )

        logger.debug(f"[Agent] Device info: {device_info}")
        return device_info

    def set_baudrate(self, baudrate: int = 115200, delay_us: int = 1000
                     ) -> None:
        """Set the baudrate used for uart loader communication.

        Args:
            baudrate: Target baudrate
            delay_us: Delay in microseconds after baudrate change
        """
        self.gateway.prog_baudrate_set_request(baudrate, delay_us)
        logger.debug(f"[Agent] Baudrate successfully set to {baudrate} with"
                     f" {delay_us}us delay")

    def cal_sha256_from_device(self, target_storage_addr: int, data_length: int
                               ) -> bytes:
        """Ask the device for the SHA256 of a storage range.

        The hash is computed by the device; this method does not compare it
        against anything — callers are expected to do the verification.

        Args:
            target_storage_addr: Starting address in storage to hash
            data_length: Number of bytes to hash

        Returns:
            The first 16 bytes of the SHA256 digest reported by the device.
        """
        logger.debug("=" * 80)
        logger.debug(f"[Agent] Verifying storage SHA256 integrity at"
                     f" 0x{target_storage_addr:08X} ({data_length} bytes)...")

        evt = self.gateway.prog_sha256_check_request(target_storage_addr,
                                                     data_length)

        assert evt.sha256_first_16_bytes is not None
        logger.debug(f"[Agent] Storage SHA256 check:"
                     f" hash={evt.sha256_first_16_bytes.hex()}")
        logger.debug("=" * 80)
        return evt.sha256_first_16_bytes

    def dump_data(self, absolute_addr: int, data_length: int,
                  timeout_s: float = 15.0) -> bytes:
        """Dump data from target storage address.

        Sends one dump command, then collects dump events until at least
        ``data_length`` bytes were received, verifying each event's XOR.

        Args:
            absolute_addr: Absolute address to start dumping from
            data_length: Number of bytes to dump
            timeout_s: Overall time budget in seconds. It is checked before
                each wait for an event; a single blocking wait is not
                interrupted by this check.

        Returns:
            Concatenated data of all received dump events. It is not
            truncated, so it may be longer than ``data_length`` if the last
            event overshoots.

        Raises:
            UartLoaderAgentError: ``DUMP_TIMEOUT`` when the time budget is
                exceeded, ``DUMP_XOR_MISMATCH`` when an event fails its XOR
                check.
        """
        self.gateway.prog_dump_send_command(absolute_addr, data_length)

        output = bytearray()
        # Monotonic clock: elapsed-time checks must not jump when the wall
        # clock is adjusted.
        deadline = time.monotonic() + timeout_s
        while len(output) < data_length:
            if time.monotonic() > deadline:
                raise UartLoaderAgentError(
                    "Dump operation failed: timeout",
                    UartLoaderAgentErrorCode.DUMP_TIMEOUT
                )
            dump_evt = self.gateway.prog_dump_wait_event()
            assert dump_evt is not None
            # The last payload byte is excluded from the XOR; presumably it is
            # the checksum byte itself (compared against dump_evt.xor).
            xor = xor_bytes(bytearray(dump_evt.payload[:-1]))
            if xor != dump_evt.xor:
                raise UartLoaderAgentError(
                    "Dump operation failed: XOR mismatch",
                    UartLoaderAgentErrorCode.DUMP_XOR_MISMATCH
                )

            assert dump_evt.data is not None
            output += dump_evt.data

        return bytes(output)

    def diagno_tput(self, data_length: int, count_group_packets: int = 1,
                    timeout_s: float = 15.0) -> None:
        """Run a throughput diagnostic by streaming filler packets.

        ``data_length`` is consumed in chunks of at most
        ``MAX_DIAGNO_TPUT_DATA_LENGTH`` bytes. Each chunk is sent as one
        packet whose payload is ``chunk - DIAGNO_TPUT_HEADER_SIZE`` bytes of
        0xA5 (empty if the chunk is smaller than the header). Packets use
        mode 0x01, except every ``count_group_packets``-th packet and the
        final packet. Those use mode 0x02 and are followed by a wait for the
        device's event, so the function only returns after the last packet
        has been answered.

        Args:
            data_length: Total number of bytes to send, counting the
                per-packet header.
            count_group_packets: Number of packets per acknowledged group;
                must be >= 1.
            timeout_s: Timeout in seconds for each event wait.

        Raises:
            ValueError: If ``count_group_packets`` is less than 1.
        """
        if count_group_packets < 1:
            raise ValueError(
                f"count_group_packets must be >= 1, got {count_group_packets}")
        logger.debug(f"[Agent] Diagno throughput: data_length={data_length}")

        remaining = data_length
        packet_number = 0
        while remaining > 0:
            packet_number += 1
            # NOTE(review): a chunk is capped at MAX_DIAGNO_TPUT_DATA_LENGTH
            # *including* the header, so the payload tops out at
            # MAX_DIAGNO_TPUT_DATA_LENGTH - DIAGNO_TPUT_HEADER_SIZE bytes and
            # MAX_DIAGNO_TPUT_PAKCET_SIZE is unused — confirm this framing.
            packet_size = min(remaining, MAX_DIAGNO_TPUT_DATA_LENGTH)
            payload = b"\xA5" * max(packet_size - DIAGNO_TPUT_HEADER_SIZE, 0)
            # The final packet is the one that consumes everything left,
            # including a remainder of exactly MAX_DIAGNO_TPUT_DATA_LENGTH.
            is_last_packet = packet_size == remaining
            ends_group = packet_number % count_group_packets == 0
            mode = 0x02 if ends_group or is_last_packet else 0x01
            self.gateway.diagno_tput_send_command(mode, payload)
            if mode == 0x02:
                self.gateway.diagno_tput_wait_event(timeout_s=timeout_s)
            remaining -= packet_size

    def diagno_latency(self, test_data: int, start_address: int,
                       data_length: int) -> DiagnoTputResult:
        """Run a latency diagnostic on the device.

        Args:
            test_data: Test value forwarded to the latency request
            start_address: Start address forwarded to the latency request
            data_length: Length forwarded to the latency request

        Returns:
            ``DiagnoTputResult`` holding the ``write_time_ram``,
            ``erase_time_storage`` and ``write_time_storage`` values reported
            by the device. Their units are defined by the device protocol and
            are not visible here.
        """
        logger.debug(f"[Agent] Diagno latency: test_data={test_data},"
                     f" start_address=0x{start_address:08X},"
                     f" data_length={data_length}")

        evt = self.gateway.diagno_latency_request(test_data, start_address,
                                                  data_length)

        logger.debug(f"[Agent] Diagno latency:"
                     f" write_time_ram={evt.write_time_ram},"
                     f" erase_time_storage={evt.erase_time_storage},"
                     f" write_time_storage={evt.write_time_storage}")

        assert evt.write_time_ram is not None
        assert evt.erase_time_storage is not None
        assert evt.write_time_storage is not None

        return DiagnoTputResult(
            write_time_ram=evt.write_time_ram,
            erase_time_storage=evt.erase_time_storage,
            write_time_storage=evt.write_time_storage
        )