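"""
UART loader operations (serial setup, boot/baudrate handshake, erase,
program, dump and SHA256 verification) built on the HCI agent stack.
"""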

from __future__ import annotations
import time
from typing import Optional
from pydantic import BaseModel, Field
from logging import getLogger
from hashlib import sha256
from datetime import datetime

from deviceio.serial_io_module import SerialIoModule
from packet.hci_module import HciModule
from gateway.hci.hci_gateway import HciGateway
from agent.uart_loader_agent import UartLoaderAgent
from pydantic_argparse.argparse_helper_funcs import ArgExtra
from error.atmosic_error import UartLoaderTaskError
from error.errorcodes import UartLoaderTaskErrorCode
from csv_report_struct.uart_loader_spend_time import UartLoaderSpendTimeReport
from csv_report_struct.program_full_spend_time import ProgramFullSpendTimeReport
from csv_report.csv_report_helper import CsvReportHelper
from agent.uart_loader_agent_types import ProgramResult

# Module-level logger, named after this module.
logger = getLogger(name=__name__)

class UartLoaderOperation:
    """Reusable uart loader operations.

    Methods that talk to the device require ``init_uart_loader_serial`` to
    have been called first. ``get_prog_info`` populates the RAM buffer window
    (``ram_buffer_start`` / ``ram_buffer_size``) that ``program_binary``
    passes to the agent.
    """

    # Number of leading hex digits of the locally computed SHA256 that are
    # compared with the device value (32 hex digits = 16 bytes). Presumably
    # the device returns a truncated 16-byte digest - TODO confirm.
    _SHA256_HEX_DIGITS_COMPARED = 32

    class InitContext(BaseModel):
        uart_loader_serial: str = ArgExtra(Field(
            description="Serial port for uart loader communication"
        ), engineer_mode=True)
        uart_loader_init_baudrate: int = ArgExtra(Field(
            default=115200,
            description="Initial baudrate for uart loader communication"
        ), engineer_mode=True)

    def __init__(self, ctx: InitContext,
                 start_timestamp: datetime | None = None,
                 full_report: ProgramFullSpendTimeReport | None = None):
        """Initialize UartLoaderOperation.

        Args:
            ctx: Serial port name and initial baudrate.
            start_timestamp: Copied into every per-binary CSV report row.
            full_report: Optional aggregate report. When given, programmed
                sizes and timings are accumulated into it.
        """
        self.ctx = ctx

        # Runtime state (populated by init_uart_loader_serial / get_prog_info)
        self.uart_loader_serial_module: Optional[SerialIoModule] = None
        self.hci_module: Optional[HciModule] = None
        self.gateway: Optional[HciGateway] = None
        self.agent: Optional[UartLoaderAgent] = None
        self.ram_buffer_start: int = 0
        self.ram_buffer_size: int = 0
        # Baudrate the host serial port is currently configured for.
        self.current_baudrate: int = ctx.uart_loader_init_baudrate

        self.start_timestamp = start_timestamp
        self.full_report = full_report

        logger.info("UartLoaderOperation initialized:")
        logger.info(f"  Uart loader serial: {ctx.uart_loader_serial}")

    def _require_agent(self) -> UartLoaderAgent:
        """Return the agent, or raise if the serial stack was not built."""
        # Explicit raise instead of ``assert`` so the check survives -O.
        if self.agent is None:
            raise ValueError("Agent not initialized - call"
                             " init_uart_loader_serial first")
        return self.agent

    def _require_serial_module(self) -> SerialIoModule:
        """Return the serial module, or raise if it was never opened."""
        if self.uart_loader_serial_module is None:
            raise ValueError("Uart loader serial port not opened - call"
                             " init_uart_loader_serial first")
        return self.uart_loader_serial_module

    @staticmethod
    def _format_rate(size_bytes: int, elapsed_s: float,
                     precision: int = 2) -> str:
        """Format a throughput as ``"<value> kB/s"`` (1 kB = 1000 bytes).

        Returns ``"n/a kB/s"`` when ``elapsed_s`` is not positive. That can
        happen for tiny transfers, and the guard avoids dividing by zero.
        """
        if elapsed_s <= 0:
            return "n/a kB/s"
        return f"{size_bytes / elapsed_s / 1000:.{precision}f} kB/s"

    def init_uart_loader_serial(self):
        """Open the uart loader serial port and build the HCI stack on it.

        Creates and stores ``SerialIoModule`` -> ``HciModule`` ->
        ``HciGateway`` -> ``UartLoaderAgent`` on ``self``.
        """
        logger.info("=== Initialize uart loader serial port ===")

        self.uart_loader_serial_module = SerialIoModule(
            port=self.ctx.uart_loader_serial,
            baudrate=self.ctx.uart_loader_init_baudrate,
            rtscts=False,
            stream_type="ascii",
            timeout=3
        )

        logger.info(f"Opening uart loader serial port:"
                    f" {self.ctx.uart_loader_serial} at"
                    f" {self.ctx.uart_loader_init_baudrate} baudrate")
        self.uart_loader_serial_module.open()
        self.current_baudrate = self.ctx.uart_loader_init_baudrate

        # Create HCI module and agent for boot event handling
        self.hci_module = HciModule(self.uart_loader_serial_module)
        self.gateway = HciGateway(self.hci_module)
        self.agent = UartLoaderAgent(self.gateway)

        logger.info(f"Uart loader serial port {self.ctx.uart_loader_serial}"
                    " opened successfully")

    class WaitBootEventContext(BaseModel):
        timeout: float = Field(
            default=3.0,
            description="Timeout for waiting for boot event in seconds"
        )

    def wait_for_boot_event(self,
                            ctx: WaitBootEventContext | None = None):
        """Block until the device reports its boot event.

        Args:
            ctx: Wait options. ``None`` uses the ``WaitBootEventContext``
                defaults (3.0 s timeout).

        Raises:
            ValueError: If ``init_uart_loader_serial`` was not called.
        """
        # Build the default per call instead of sharing one model instance
        # as a default argument.
        if ctx is None:
            ctx = self.WaitBootEventContext()
        agent = self._require_agent()

        logger.info("Waiting for boot event from device...")

        # This will raise BootEventTimeoutError or HciCommandError if it fails
        agent.wait_boot_event(timeout_s=ctx.timeout)
        logger.info("Boot event received successfully!")

    class ChangeBaudrateContext(BaseModel):
        baudrate: int = Field(
            description="Target baudrate"
        )
        delay_us: int = Field(
            default=1000,
            description="Delay in microseconds after baudrate change"
        )

    def change_baudrate(self, ctx: ChangeBaudrateContext):
        """Switch device and host serial port to ``ctx.baudrate``.

        This is a no-op when the port is already at the target baudrate. The
        comparison uses the current rate, not the initial one, so switching
        back to the initial rate also works.

        Args:
            ctx: Target baudrate and the ``delay_us`` forwarded to the agent.
        """
        logger.info("=== Baudrate change phase ===")
        if ctx.baudrate == self.current_baudrate:
            logger.info("Target baudrate is the same as current baudrate, "
                        "skip change baudrate")
            return
        serial_module = self._require_serial_module()
        agent = self._require_agent()

        logger.info(f"Changing baudrate to {ctx.baudrate}...")
        agent.set_baudrate(baudrate=ctx.baudrate, delay_us=ctx.delay_us)
        # Short pause before reopening the host side; presumably gives the
        # device time to apply the new rate - TODO confirm the minimum.
        time.sleep(0.01)
        serial_module.reopen_with_baudrate(ctx.baudrate)
        self.current_baudrate = ctx.baudrate

    def get_prog_info(self) -> bool:
        """Send a prog info request and store the device RAM buffer window.

        Returns:
            bool: Always True. Errors presumably propagate as exceptions from
            ``get_device_info``.
        """
        agent = self._require_agent()

        logger.info("=== Send prog info request ===")

        info = agent.get_device_info()

        self.ram_buffer_start = info.ram_buffer_start
        self.ram_buffer_size = info.ram_buffer_size

        logger.info("Prog info response received:")
        logger.info(f" App version: 0x{info.app_version or 0:04X}")
        logger.info("  Protocol version:"
                    f" 0x{info.protocol_version or 0:04X}")
        logger.info("  RAM buffer start:"
                    f" 0x{info.ram_buffer_start or 0:08X}")
        logger.info("  RAM buffer size:"
                    f" 0x{info.ram_buffer_size or 0:08X}")

        return True

    def close_uart_loader_serial(self):
        """Close the uart loader serial port if it is open.

        Also drops the HCI module, gateway and agent built on that port, so
        later device calls fail with a clear "not initialized" error instead
        of using a closed port.
        """
        serial_module = self.uart_loader_serial_module
        if serial_module is not None and serial_module.is_open:
            serial_module.close()
            logger.info("Uart loader serial port closed")
            self.uart_loader_serial_module = None
            self.hci_module = None
            self.gateway = None
            self.agent = None

    class DumpBinaryContext(BaseModel):
        address: int = Field(
            description="Target storage address for dumping"
        )
        size: int = Field(
            description="Size of data"
        )
        check_dump_sha256: bool = Field(
            description="Check SHA256 after dumping"
        )
        dump_timeout_s: int = Field(
            description="Timeout for dump operation in seconds"
        )

    def dump_binary(self, ctx: DumpBinaryContext) -> bytes:
        """Read ``ctx.size`` bytes of device storage starting at ``ctx.address``.

        Returns:
            The dumped bytes.

        Raises:
            UartLoaderTaskError: ``DUMP_CHECK_SHA256_MISMATCH`` when
                ``ctx.check_dump_sha256`` is set and the hashes differ.
        """
        agent = self._require_agent()
        logger.info("=== Dumping data ===")
        logger.info(f"  target_address: 0x{ctx.address:08X}")
        logger.info(f"  size: 0x{ctx.size:08X}")
        logger.info(f"  check_sha256: {ctx.check_dump_sha256}")
        logger.info(f"  timeout: {ctx.dump_timeout_s} s")

        dump_start = time.perf_counter()
        binary_data = agent.dump_data(ctx.address, ctx.size,
                                      timeout_s=ctx.dump_timeout_s)
        dump_elapsed_s = time.perf_counter() - dump_start
        # The elapsed time is in seconds; convert it for the ms log.
        logger.info(f"Dump time: {dump_elapsed_s * 1000:.3f} ms"
                    f" ({self._format_rate(ctx.size, dump_elapsed_s, 3)})")
        if ctx.check_dump_sha256:
            self._check_sha256(ctx.address, ctx.size, binary_data)
        logger.info("=== Dump completed ===")
        return binary_data

    def _device_and_local_sha256(self, target_address: int, size: int,
                                 data: bytes) -> tuple[str, str]:
        """Return ``(device_hex, local_hex)`` SHA256 strings.

        The device hashes ``size`` bytes at ``target_address``. ``size`` is
        passed explicitly, so a short ``data`` shows up as a mismatch.
        """
        agent = self._require_agent()
        device_hex = agent.cal_sha256_from_device(target_address, size).hex()
        local_hex = sha256(data).hexdigest()[:self._SHA256_HEX_DIGITS_COMPARED]
        return device_hex, local_hex

    def _check_sha256(self, target_address: int, size: int, binary_data: bytes
                      ) -> None:
        """Compare dumped data with the device's SHA256 of the same region.

        Raises:
            UartLoaderTaskError: ``DUMP_CHECK_SHA256_MISMATCH`` on mismatch.
        """
        sha256_device, sha256_dumped = self._device_and_local_sha256(
            target_address, size, binary_data)
        logger.info(f"SHA256 from device: {sha256_device}")
        logger.info(f"SHA256 from dumped: {sha256_dumped}")
        if sha256_device != sha256_dumped:
            raise UartLoaderTaskError(
                f"SHA256 mismatch: {sha256_device} != {sha256_dumped}",
                UartLoaderTaskErrorCode.DUMP_CHECK_SHA256_MISMATCH
            )

    class EraseDataContext(BaseModel):
        target_address: int = Field(
            description="Target address for erase"
        )
        erase_size: int = Field(
            description="Size of data to erase"
        )
        partition_name: str = Field(
            default="",
            description="Partition name for erase"
        )
        erase_timeout: float = Field(
            default=5.0,
            description="Timeout for erase operation"
        )

    def erase_data(self, ctx: EraseDataContext) -> None:
        """Erase ``ctx.erase_size`` bytes of storage at ``ctx.target_address``.

        A zero ``erase_size`` is a no-op. ``partition_name`` only affects the
        log message.
        """
        if ctx.erase_size == 0:
            logger.debug("Erase size is 0, skip erase")
            return

        agent = self._require_agent()
        if ctx.partition_name == "":
            logger.info(f"=== Erasing data from 0x{ctx.target_address:08X}"
                        f" with size 0x{ctx.erase_size:08X} ===")
        else:
            logger.info(
                f"=== Erasing {ctx.partition_name} from"
                f" 0x{ctx.target_address:08X}"
                f" with size 0x{ctx.erase_size:08X} ===")

        erase_start = time.perf_counter()
        agent.erase(ctx.target_address, ctx.erase_size,
                    timeout_s=ctx.erase_timeout)
        erase_elapsed_s = time.perf_counter() - erase_start
        logger.info(f"Erase time: {erase_elapsed_s * 1000:.0f} ms")
        logger.info("=== Erase completed ===")

    class ProgramBinaryContextBase(BaseModel):
        target_address: int = Field(
            description="Target storage address for programming"
        )
        check_ram_xor: bool = Field(
            default=False,
            description="Enable RAM block XOR verification"
        )
        check_rram_xor: bool = Field(
            default=False,
            description="Enable RRAM block XOR verification"
        )
        check_program_sha256: bool = Field(
            default=True,
            description="Enable SHA256 verification"
        )
        count_group_packets: int = Field(
            default=16,
            description="Count of group packets"
        )
        group_index_check_rsp: int = Field(
            default=15,
            description="Group index for response checking"
        )
        skip_send_ff_packets: bool = Field(
            default=True,
            description="Skip send all 0xFF packets"
        )

    class ProgramBinFileContext(ProgramBinaryContextBase):
        bin_file_path: str = Field(
            description="Path to binary file to program via uart loader",
        )
        erase_size: int = Field(
            default=0,
            description="Size of data to erase before programming"
        )
        erase_timeout: float = Field(
            default=5.0,
            description="Timeout for erase operation"
        )

    def program_binary_file(self, ctx: ProgramBinFileContext) -> None:
        """Read a binary file, optionally erase, then program and verify it.

        Raises:
            UartLoaderTaskError: ``PROG_FILE_NOT_EXIST`` if no path is given.
                ``PROG_ERASE_SIZE_SMALLER_THAN_BINARY_SIZE`` if a non-zero
                ``erase_size`` does not cover the whole file.
        """
        if not ctx.bin_file_path:
            raise UartLoaderTaskError(
                "No bin_file_path specified",
                UartLoaderTaskErrorCode.PROG_FILE_NOT_EXIST)
        logger.info(f"=== Programming binary file: {ctx.bin_file_path} ===")

        # Read binary file
        logger.info(f"Reading firmware binary: {ctx.bin_file_path}")
        with open(ctx.bin_file_path, 'rb') as f:
            binary_data = f.read()

        binary_size = len(binary_data)
        erase_requested = ctx.erase_size != 0

        if erase_requested and ctx.erase_size < binary_size:
            raise UartLoaderTaskError(
                f"Erase size {ctx.erase_size} is smaller than binary size"
                f" {binary_size}",
                UartLoaderTaskErrorCode.PROG_ERASE_SIZE_SMALLER_THAN_BINARY_SIZE
            )

        if erase_requested:
            self.erase_data(self.EraseDataContext(
                target_address=ctx.target_address,
                erase_size=ctx.erase_size,
                erase_timeout=ctx.erase_timeout
            ))

        # Forward only the shared programming options. The file-specific
        # fields (path, erase size/timeout) are not ProgramBinaryContext
        # fields, so this does not rely on pydantic ignoring extras.
        base_options = ctx.model_dump(
            include=set(self.ProgramBinaryContextBase.model_fields))
        self.program_binary(self.ProgramBinaryContext(
            binary_title=ctx.bin_file_path,
            binary_data=binary_data,
            is_storage_erased=ctx.erase_size >= binary_size,
            **base_options,
        ))

    class ProgramBinaryContext(ProgramBinaryContextBase):
        binary_title: str = Field(
            description="Title of binary data"
        )
        binary_data: bytes = Field(
            description="Binary data to program"
        )
        is_storage_erased: bool = Field(
            default=False,
            description="Is storage erased before programming"
        )

    def program_binary(self, ctx: ProgramBinaryContext) -> None:
        """Program ``ctx.binary_data``, verify it, and record timings.

        Appends a ``UartLoaderSpendTimeReport`` row via ``CsvReportHelper``.
        When ``self.full_report`` is set, also accumulates the size and
        timings into it; when it is ``None``, that step is skipped.
        """
        file_size = len(ctx.binary_data)

        # Programming phase
        logger.info(f"=== Programming {ctx.binary_title} start ===")
        prog_start = time.perf_counter()
        result = self._run_programming_phase(ctx)
        prog_elapsed_s = time.perf_counter() - prog_start
        prog_time_ms = int(prog_elapsed_s * 1000)

        skipstr = ""
        if result.skiped_ff_size != 0:
            skipstr = ","
            if result.skiped_ff_chunks != 0:
                skipstr += f" skip {result.skiped_ff_chunks} chunks,"
            if result.skiped_ff_packets != 0:
                skipstr += f" skip {result.skiped_ff_packets} packets,"
            skipstr += f" total skip {result.skiped_ff_size} bytes"

        # The rate uses float seconds, so a sub-millisecond run (prog_time_ms
        # == 0) cannot raise ZeroDivisionError after programming succeeded.
        logger.info(f"Programming {ctx.binary_title} time: {prog_time_ms} ms"
                    f"({self._format_rate(file_size, prog_elapsed_s)}"
                    f"{skipstr})")

        logger.info(f"=== Programming {ctx.binary_title} successfully ===")

        # Verification phase
        logger.info(f"=== Verifying {ctx.binary_title} SHA256 start ===")
        verify_start = time.perf_counter()
        self._run_verification_phase(ctx)
        verify_time_ms = int((time.perf_counter() - verify_start) * 1000)
        logger.info(f"Verification time: {verify_time_ms} ms")
        logger.info(f"=== Verifying {ctx.binary_title} SHA256 successfully ===")

        # Create and populate report
        report = UartLoaderSpendTimeReport()
        report.cmd_start_timestamp = self.start_timestamp
        report.serial = self.ctx.uart_loader_serial
        report.file = ctx.binary_title
        report.size = file_size
        report.programming_time_ms = prog_time_ms
        report.verification_time_ms = verify_time_ms
        CsvReportHelper.append(report)

        # full_report is optional in __init__, so only accumulate when present.
        full_report = self.full_report
        if full_report is not None:
            if full_report.uart_loader_bin_size is not None:
                full_report.uart_loader_bin_size += file_size
            if full_report.full_programming_time_ms is not None:
                full_report.full_programming_time_ms += prog_time_ms
            if full_report.full_verification_time_ms is not None:
                full_report.full_verification_time_ms += verify_time_ms

        logger.info(f"=== Binary file programming completed: {ctx.binary_title}"
                    f"({file_size} bytes) ===")

    def _run_programming_phase(self, ctx: ProgramBinaryContext) -> ProgramResult:
        """Program the binary via the agent (without final SHA256 check).

        Uses the RAM buffer window stored by ``get_prog_info``.
        """
        agent = self._require_agent()

        file_size = len(ctx.binary_data)
        logger.info(f"Firmware binary size: {file_size} bytes"
                    f" (0x{file_size:08X})")

        logger.info("Starting firmware programming...")
        logger.info(f"Target storage address: 0x{ctx.target_address:08X}")

        return agent.program_binaries(
            ctx.binary_data,
            ram_buffer_start=self.ram_buffer_start,
            ram_buffer_size=self.ram_buffer_size,
            target_storage_addr=ctx.target_address,
            check_ram_xor=ctx.check_ram_xor,
            check_rram_xor=ctx.check_rram_xor,
            count_group_packets=ctx.count_group_packets,
            group_index_check_rsp=ctx.group_index_check_rsp,
            skip_send_ff_packets=ctx.skip_send_ff_packets,
            is_storage_erased=ctx.is_storage_erased
        )

    def _run_verification_phase(self, ctx: ProgramBinaryContext) -> None:
        """Compare the device SHA256 of the programmed region with the local one.

        Skipped when ``ctx.check_program_sha256`` is False.

        Raises:
            UartLoaderTaskError: ``PROG_CHECK_SHA256_MISMATCH`` on mismatch.
        """
        if not ctx.check_program_sha256:
            logger.info("SHA256 verification disabled, skipping verification")
            return

        self._require_agent()

        # Perform SHA256 verification
        logger.info("Starting SHA256 verification...")
        device_hash, local_hash = self._device_and_local_sha256(
            ctx.target_address, len(ctx.binary_data), ctx.binary_data)

        logger.info(f"SHA256 from device: {device_hash}")
        logger.info(f"SHA256 from local:  {local_hash}")

        if device_hash != local_hash:
            logger.error("SHA256 verification FAILED!")
            raise UartLoaderTaskError(
                "SHA256 verification failed - hash mismatch",
                UartLoaderTaskErrorCode.PROG_CHECK_SHA256_MISMATCH
            )
        logger.info("SHA256 verification PASSED!")