from __future__ import annotations
from dataclasses import dataclass
from logging import getLogger
from typing import Protocol
import time

from lib.checksum import xor_bytes
from error.errorcodes import UartLoaderAgentErrorCode
from error.atmosic_error import UartLoaderAgentError
from .uart_loader_agent_types import DeviceInfo, ProgramResult
from gateway.hci.hci_gateway import HciGateway

logger = getLogger(__name__)

# A write packet's plen field tops out at 255 (0xFF); 8 of those bytes are
# packet headers, so the remainder is the usable data payload per packet.
MAX_WRITE_DATA_PER_PACKET = 0xFF - 8

class ProgramBinariesProtocol(Protocol):
    """Structural interface of the agent object used by the recipe."""

    def get_device_info(self) -> DeviceInfo:
        """Return the device information record."""
        ...

    def erase(self, start_addr: int, length: int) -> None:
        """Erase ``length`` bytes beginning at ``start_addr``."""
        ...

@dataclass
class ProgramBinariesRecipeConfig:
    """Configuration for ProgramBinariesRecipe.

    Describes the binary to program, where it goes in storage, how it is
    staged through the device RAM buffer, and which optional verification
    and 0xFF-skipping behaviours are enabled.
    """
    # Binary image to program.
    data: bytes
    # Storage address that receives data[0].
    target_storage_addr: int
    # XOR-verify the RAM buffer after each chunk has been written to it.
    check_ram_xor: bool = True
    # XOR-verify storage after each chunk has been applied from RAM.
    check_rram_xor: bool = True
    # RAM staging buffer. If either value is 0, both are replaced by the
    # values reported by the agent's get_device_info().
    ram_buffer_start: int = 0
    ram_buffer_size: int = 0
    # A write response is requested for every packet whose sequence number
    # satisfies seq_no % count_group_packets == group_index_check_rsp
    # (plus the last packet sent); the index must be < count_group_packets.
    count_group_packets: int = 1
    group_index_check_rsp: int = 0
    # Do not transmit write packets whose payload is entirely 0xFF; the RAM
    # buffer is erased beforehand instead.
    skip_send_ff_packets: bool = True
    # Caller asserts the target storage is already erased. Together with
    # skip_send_ff_packets, whole chunks that are all 0xFF are skipped.
    is_storage_erased: bool = False

class ProgramBinariesRecipe:
    """Program a binary image into device storage through a RAM buffer.

    The image is split into chunks no larger than the device RAM buffer.
    Each chunk is written to RAM in packets of at most
    ``MAX_WRITE_DATA_PER_PACKET`` bytes, optionally XOR-verified, applied
    from RAM to storage, and optionally XOR-verified again in storage.
    Skip counters are accumulated in ``self.result``.
    """

    def __init__(self, config: ProgramBinariesRecipeConfig):
        self.config = config
        self.result = ProgramResult()

    def execute(self, agent: ProgramBinariesProtocol, gateway: HciGateway
               ) -> ProgramResult:
        """Run the complete firmware programming workflow.

        If ``config.ram_buffer_start`` or ``config.ram_buffer_size`` is 0,
        both are taken from ``agent.get_device_info()`` and written back into
        the config. When ``skip_send_ff_packets`` and ``is_storage_erased``
        are both set, chunks consisting only of 0xFF are not programmed.

        Args:
            agent: Agent used for device info and RAM-buffer erase.
            gateway: HCI gateway used for all loader commands.

        Returns:
            ``self.result`` with the skip counters updated.

        Raises:
            UartLoaderAgentError: If the RAM buffer size is not positive, or
                if writing or verifying any chunk fails.
        """
        data_len = len(self.config.data)
        logger.debug("=" * 80)
        logger.debug("[Agent] FIRMWARE PROGRAMMING STARTED")
        logger.debug(f"[Agent] Binary size: {data_len} bytes"
                     f" (0x{data_len:08X})")
        target_storage_end_addr = \
            self.config.target_storage_addr + data_len - 1
        logger.debug("[Agent] Target storage address:"
                     f" 0x{self.config.target_storage_addr:08X} -"
                     f" 0x{target_storage_end_addr:08X}")
        ram_xor_veri_str = \
            "Enabled" if self.config.check_ram_xor else "Disabled"
        logger.debug(f"[Agent] RAM XOR verification: {ram_xor_veri_str}")
        rram_xor_veri_str = \
            "Enabled" if self.config.check_rram_xor else "Disabled"
        logger.debug(f"[Agent] RRAM XOR verification: {rram_xor_veri_str}")
        logger.debug("=" * 80)

        # A zero start or size means "not provided": query the device for
        # both values and store them back into the config.
        if self.config.ram_buffer_start == 0 or \
                self.config.ram_buffer_size == 0:
            logger.debug("[Agent] ram_buffer_start and ram_buffer_size not"
                         " provided, Getting device info...")
            device_info = agent.get_device_info()
            self.config.ram_buffer_start = device_info.ram_buffer_start
            self.config.ram_buffer_size = device_info.ram_buffer_size
            ram_buffer_end = \
                self.config.ram_buffer_start + self.config.ram_buffer_size - 1
            logger.debug("[Agent] Device RAM buffer:"
                         f" 0x{self.config.ram_buffer_start:08X} -"
                         f" 0x{ram_buffer_end:08X}"
                         f" ({self.config.ram_buffer_size} bytes)")

        # A non-positive size would divide by zero below or make the chunk
        # loop walk backwards forever.
        if self.config.ram_buffer_size <= 0:
            raise UartLoaderAgentError(
                f"Invalid ram_buffer_size: {self.config.ram_buffer_size}",
                UartLoaderAgentErrorCode.PROG_WRITE_DATA_ARG_INVALID
            )

        buffer_size = self.config.ram_buffer_size
        total_chunks = (data_len + buffer_size - 1) // buffer_size
        logger.debug(f"[Agent] Programming {data_len} bytes in"
                     f" {total_chunks} chunks (max"
                     f" {buffer_size} bytes per chunk)...")
        logger.debug("")  # Empty line for better readability

        logger.debug(
            "[program_firmware] index_check_rsp="
            f"{self.config.group_index_check_rsp},"
            f" count_continue_packets={self.config.count_group_packets}")

        skip_erased_ff_chunks = (self.config.skip_send_ff_packets
                                 and self.config.is_storage_erased)
        data_offset = 0
        current_storage_addr = self.config.target_storage_addr
        while data_offset < data_len:
            # Chunk size is limited by the RAM buffer size.
            chunk_size = min(data_len - data_offset, buffer_size)
            chunk_data = self.config.data[data_offset:data_offset + chunk_size]
            chunk_number = data_offset // buffer_size + 1

            logger.debug(f"[Agent] === Chunk {chunk_number}/{total_chunks} ===")
            logger.debug(f"[Agent] Source: Binary offset 0x{data_offset:08X}"
                         f" - 0x{data_offset + chunk_size - 1:08X}"
                         f" ({chunk_size} bytes)")

            # The caller declared storage erased, so an all-0xFF chunk is
            # treated as already in place and not programmed at all.
            if skip_erased_ff_chunks and self._is_all_data_ff(chunk_data):
                self.result.skiped_ff_chunks += 1
                self.result.skiped_ff_size += chunk_size
                data_offset += chunk_size
                current_storage_addr += chunk_size
                logger.info(
                    f"[Agent] Chunk {chunk_number}/{total_chunks} skiped"
                    f" cause all 0xFF Progress: {data_offset}/{data_len}"
                    f" bytes ({data_offset * 100 // data_len}%)")
                continue

            self._program_chunk(agent, gateway, chunk_data, chunk_size,
                                current_storage_addr)

            data_offset += chunk_size
            current_storage_addr += chunk_size

            logger.info(
                f"[Agent] Chunk {chunk_number}/{total_chunks} completed!"
                f" Progress: {data_offset}/{data_len} bytes"
                f" ({data_offset * 100 // data_len}%)")

        logger.debug("=" * 80)
        logger.debug("[Agent] FIRMWARE PROGRAMMING COMPLETED SUCCESSFULLY!")
        logger.debug(f"[Agent] Total programmed: {data_len} bytes"
                     f" (0x{data_len:08X})")
        logger.debug("[Agent] Storage range:"
                     f" 0x{self.config.target_storage_addr:08X} -"
                     f" 0x{target_storage_end_addr:08X}")
        logger.debug(f"[Agent] Total chunks processed: {total_chunks}")
        logger.debug("[Agent] Skiped all 0xFF chunks:"
                     f" {self.result.skiped_ff_chunks}")
        logger.debug("[Agent] Skiped all 0xFF packets:"
                     f" {self.result.skiped_ff_packets}")
        logger.debug("[Agent] Skiped all 0xFF size:"
                     f" {self.result.skiped_ff_size}")
        logger.debug("=" * 80)
        return self.result

    def _program_chunk(self, agent: ProgramBinariesProtocol,
                       gateway: HciGateway, chunk_data: bytes,
                       chunk_size: int, target_storage_addr: int) -> None:
        """Write one chunk to RAM, verify it, and apply it to storage.

        Args:
            agent: Agent used to erase the RAM buffer when 0xFF packets
                are skipped.
            gateway: HCI gateway used for the loader commands.
            chunk_data: Chunk bytes (at most ``ram_buffer_size`` long).
            chunk_size: Length of ``chunk_data``.
            target_storage_addr: Storage address receiving ``chunk_data``.
        """
        logger.debug("[Agent] Target: Storage address"
                     f" 0x{target_storage_addr:08X} -"
                     f" 0x{target_storage_addr + chunk_size - 1:08X}"
                     f" ({chunk_size} bytes)")

        # Program chunk data to RAM
        logger.debug(f"[Agent] Programming {chunk_size} bytes to RAM buffer at"
                     f" 0x{self.config.ram_buffer_start:08X}...")
        self._program_data_to_ram(agent, gateway, self.config.ram_buffer_start,
                                  chunk_data)

        # Verify RAM data integrity (optional)
        if self.config.check_ram_xor:
            logger.debug(
                "[Agent] Verifying RAM data integrity at"
                f" 0x{self.config.ram_buffer_start:08X}"
                f" ({chunk_size} bytes)...")
            self._verify_ram_integrity(
                gateway, self.config.ram_buffer_start, chunk_size, chunk_data)

        # Apply data from RAM to storage
        logger.debug(
            f"[Agent] Applying {chunk_size} bytes from RAM"
            f" 0x{self.config.ram_buffer_start:08X} to storage"
            f" 0x{target_storage_addr:08X}...")
        self._apply_ram_to_storage(
            gateway, self.config.ram_buffer_start, target_storage_addr,
            chunk_size)

        # Verify storage XOR integrity for this chunk (optional)
        if self.config.check_rram_xor:
            logger.debug(
                "[Agent] Verifying storage XOR integrity at"
                f" 0x{target_storage_addr:08X} ({chunk_size} bytes)...")
            self._verify_storage_xor_integrity(
                gateway, target_storage_addr, chunk_size, chunk_data)

    def _is_all_data_ff(self, data: bytes) -> bool:
        """Return True if every byte of ``data`` is 0xFF (True for empty)."""
        # bytes.count runs in C; much faster than a per-byte generator.
        return data.count(0xFF) == len(data)

    def _is_block_has_ff_packet(self, data: bytes) -> bool:
        """Return True if any write packet of ``data`` would be all 0xFF."""
        return any(
            self._is_all_data_ff(
                data[offset:offset + MAX_WRITE_DATA_PER_PACKET])
            for offset in range(0, len(data), MAX_WRITE_DATA_PER_PACKET))

    def _program_data_to_ram(self, agent: ProgramBinariesProtocol,
                             gateway: HciGateway,
                             start_address: int, data: bytes) -> None:
        """Write ``data`` into device RAM at ``start_address`` packet by packet.

        The data is split into packets of at most
        ``MAX_WRITE_DATA_PER_PACKET`` bytes. With ``skip_send_ff_packets``,
        all-0xFF packets are not sent. If at least one exists, the whole RAM
        buffer is erased first. The later RAM XOR check compares against the
        full data, so this presumably relies on erase leaving 0xFF.

        A response is requested for each packet whose sequence number
        satisfies ``seq_no % count_group_packets == group_index_check_rsp``.
        A response is also requested for the last packet actually sent. Each
        requested response is awaited right after its packet is sent.

        Args:
            agent: Agent used to erase the RAM buffer.
            gateway: HCI gateway used to send packets and read responses.
            start_address: RAM start address
            data: Data to program

        Raises:
            UartLoaderAgentError: If ``group_index_check_rsp`` is not less
                than ``count_group_packets``, or a response carries a missing
                or unexpected sequence number.
        """
        logger.debug(
            f"[_program_data_to_ram]"
            f" group_index_check_rsp={self.config.group_index_check_rsp},"
            f" count_group_packets={self.config.count_group_packets}")
        if self.config.group_index_check_rsp >= self.config.count_group_packets:
            raise UartLoaderAgentError(
                f"Invalid group_index_check_rsp: "
                f"{self.config.group_index_check_rsp} >="
                f" {self.config.count_group_packets}",
                UartLoaderAgentErrorCode.PROG_WRITE_DATA_ARG_INVALID
            )

        total_bytes = len(data)
        # (offset, payload) for every write packet of this block.
        packets = [
            (offset, data[offset:offset + MAX_WRITE_DATA_PER_PACKET])
            for offset in range(0, total_bytes, MAX_WRITE_DATA_PER_PACKET)
        ]
        packets_needed = len(packets)

        # skip_flags[i] is True when packet i is all 0xFF and won't be sent.
        skip_flags = [False] * packets_needed
        if self.config.skip_send_ff_packets:
            stime = time.time()
            skip_flags = [self._is_all_data_ff(payload)
                          for _, payload in packets]
            is_has_ff_packet = any(skip_flags)
            logger.debug("[Agent] Check has 0xFF packet time:"
                         f" {(time.time() - stime) * 1000:.3f} ms")
            if is_has_ff_packet:
                stime = time.time()
                agent.erase(self.config.ram_buffer_start,
                            self.config.ram_buffer_size)
                logger.debug("[Agent] Has 0xFF packet, erasing RAM buffer."
                             f" spent time: {(time.time() - stime) * 1000:.3f}"
                             " ms")

        logger.debug(f"[Agent] Programming {total_bytes} bytes in"
                     f" {packets_needed} packets"
                     f" (max {MAX_WRITE_DATA_PER_PACKET} bytes per packet)")

        # The final response must ride on the last packet actually sent;
        # trailing all-0xFF packets are skipped and cannot carry it.
        sent_indices = [idx for idx, skip in enumerate(skip_flags) if not skip]
        last_sent_idx = sent_indices[-1] if sent_indices else -1

        count = self.config.count_group_packets
        check_index = self.config.group_index_check_rsp
        seq_no = -1
        for packet_idx, ((offset, packet_data), skip) in enumerate(
                zip(packets, skip_flags)):
            packet_size = len(packet_data)
            packet_address = start_address + offset

            if skip:
                logger.debug(
                    f"[Agent] Packet {packet_idx}/{packets_needed-1}:"
                    f" address=0x{packet_address:08X}, size={packet_size}"
                    f" - All 0xFF, skiped")
                self.result.skiped_ff_packets += 1
                self.result.skiped_ff_size += packet_size
                continue

            # Sequence numbers count only packets that are actually sent.
            seq_no += 1
            is_last_packet = packet_idx == last_sent_idx
            logger.debug(
                f"[Agent] Packet {packet_idx}/{packets_needed-1}:"
                f" address=0x{packet_address:08X}, size={packet_size}")

            is_need_rsp = seq_no % count == check_index
            gateway.prog_write_data_send_command(
                packet_address, seq_no, packet_data,
                is_need_rsp or is_last_packet)

            # Every requested response is consumed immediately, so none is
            # ever left queued (seq_no 0 included).
            if is_last_packet:
                logger.debug(
                    f"[Agent] Start waiting last packet: {seq_no}")
                self._wait_write_data_rsp(gateway, seq_no)
            elif is_need_rsp:
                logger.debug(f"[Agent] Add next checking packet {seq_no}")
                logger.debug(f"[Agent] Start waiting packet {seq_no}")
                self._wait_write_data_rsp(gateway, seq_no)

        logger.debug(f"[Agent] Successfully programmed {packets_needed}"
                     " packets to RAM")

    def _wait_write_data_rsp(self, gateway: HciGateway,
                             expected_seq_no: int) -> None:
        """Wait for one write-data response and validate its sequence number.

        Raises:
            UartLoaderAgentError: If the response has no sequence number or
                one different from ``expected_seq_no``.
        """
        evt = gateway.prog_write_data_wait_event()
        if evt.seq_no is None:
            raise UartLoaderAgentError(
                "Data packet failed: seq_no is None",
                UartLoaderAgentErrorCode.PROG_WRITE_UNEXCPECTED_SEQ_ID
            )
        if evt.seq_no != expected_seq_no:
            raise UartLoaderAgentError(
                f"Data packet {expected_seq_no}"
                f" failed with wrong seq_no: {evt.seq_no}",
                UartLoaderAgentErrorCode.PROG_WRITE_UNEXCPECTED_SEQ_ID
            )

    def _verify_xor_integrity(self, gateway: HciGateway, address: int,
                              length: int, expected_data: bytes, label: str,
                              error_code: UartLoaderAgentErrorCode) -> None:
        """Compare the device XOR of a region with the locally computed one.

        Args:
            gateway: HCI gateway used for the XOR check request.
            address: Start address of the region to check.
            length: Length of the region.
            expected_data: Data the region is expected to contain.
            label: "RAM" or "Storage"; used in log and error messages.
            error_code: Error code raised on mismatch.

        Raises:
            UartLoaderAgentError: With ``error_code`` on XOR mismatch.
        """
        evt = gateway.prog_xor_check_request(address, length)
        device_xor = evt.xor

        # NOTE(review): the expected XOR also folds in the response payload
        # minus its last byte; confirm against the loader protocol spec.
        expected_xor = xor_bytes(bytearray(
            expected_data + evt.payload[:-1]))

        logger.debug(f"[Agent] {label} XOR check: device=0x{device_xor:02X},"
                     f" expected=0x{expected_xor:02X}")

        if device_xor != expected_xor:
            raise UartLoaderAgentError(
                f"{label} integrity verification failed: XOR mismatch"
                f" (device=0x{device_xor:02X},"
                f" expected=0x{expected_xor:02X})",
                error_code
            )
        logger.debug(f"[Agent] {label} integrity verification passed")

    def _verify_ram_integrity(self, gateway: HciGateway,
                              address: int, length: int,
                              expected_data: bytes) -> None:
        """Verify RAM data integrity using XOR checksum comparison.

        Args:
            address: RAM address to verify
            length: Length of data to verify
            expected_data: Expected data for comparison

        Raises:
            UartLoaderAgentError: ``XOR_RAM_MISMATCH`` on mismatch.
        """
        self._verify_xor_integrity(
            gateway, address, length, expected_data, "RAM",
            UartLoaderAgentErrorCode.XOR_RAM_MISMATCH)

    def _apply_ram_to_storage(self, gateway: HciGateway,
                              source_addr: int, target_addr: int,
                              length: int) -> None:
        """Apply block from RAM to storage.

        Args:
            source_addr: Source RAM address
            target_addr: Target storage address
            length: Length of data to apply
        """
        gateway.prog_apply_blk_request(source_addr, target_addr, length)
        logger.debug(f"[Agent] Block applied: 0x{source_addr:08X} ->"
                     f" 0x{target_addr:08X} ({length} bytes)")

    def _verify_storage_xor_integrity(self, gateway: HciGateway,
                                      address: int, length: int,
                                      expected_data: bytes) -> None:
        """Verify storage data integrity using XOR check only.

        Args:
            address: Storage address to verify
            length: Length of data to verify
            expected_data: Expected data for comparison

        Raises:
            UartLoaderAgentError: ``XOR_STORAGE_MISMATCH`` on mismatch.
        """
        self._verify_xor_integrity(
            gateway, address, length, expected_data, "Storage",
            UartLoaderAgentErrorCode.XOR_STORAGE_MISMATCH)