from __future__ import annotations
from logging import getLogger
from typing import List
from pydantic import Field

from Task.TaskContextBase.task_reset_support_context_base import (
    TaskResetSupportContextBase
)
from Task.TaskContextBase.task_uart_loader_context_base import (
    TaskUartLoaderContextBase
)

from Task.TaskOperation.reset_operation import (
    ResetOperation, UartLoaderBootSource
)
from Task.TaskOperation.xmodem_operation import (
    XmodemOperation
)
from Task.TaskOperation.uart_loader_operation import (
    UartLoaderOperation
)
from pydantic_argparse.argparse_helper_funcs import ArgExtra

# Module logger; the task below reports progress and latency results through it.
logger = getLogger(__name__)

class TaskDiagnoLatencyContext(TaskResetSupportContextBase,
                               TaskUartLoaderContextBase):
    # Context for TaskDiagnoLatency: reset + uart-loader settings inherited
    # from the bases, plus the start address of each storage region to test.
    # Setting an address to 0 skips that region's latency measurements.
    rram_address: int = Field(
        default=0x15000,
        description="Start address for RRAM (0 to skip the RRAM test)"
    )
    flash_address: int = Field(
        default=0x215000,
        description="Start address for flash (0 to skip the flash test)"
    )

class TaskDiagnoLatency:
    """Measure uart-loader write/erase latency on RRAM and flash.

    For each enabled storage region (non-zero start address in the context)
    and each length in ``TEST_LEN_LIST``, the task asks the uart loader agent
    to run ``diagno_latency`` once per pattern in ``TEST_DATA_LIST`` and logs
    the averaged timings together with the derived throughput.
    """

    # Data patterns measured for every (region, length); results are averaged.
    TEST_DATA_LIST = [0x55, 0xAA]
    # Lengths passed to diagno_latency for every enabled region.
    TEST_LEN_LIST = [0x1000, 0x2000, 0x4000]

    def __init__(self, ctx: TaskDiagnoLatencyContext):
        """Build the reset, xmodem and uart-loader operations from *ctx*.

        Args:
            ctx: Task context providing reset / xmodem / uart-loader settings
                and the RRAM and flash start addresses to test.
        """
        self.ctx = ctx

        self.reset_operation = ResetOperation(
            ctx.to_reset_operation_init_context()
        )

        self.xmodem_operation = XmodemOperation(
            ctx=ctx.to_xmodem_init_context(),
        )

        self.uart_loader_operation = UartLoaderOperation(
            ctx=ctx.to_uart_loader_init_context(),
        )

    def run(self) -> None:
        """Run the latency diagnosis, always attempting to close the serial.

        Exceptions raised by the task body propagate to the caller; closing
        the uart loader serial afterwards is best-effort.
        """
        logger.info("=== TaskDiagnoLatency Task Started"
                    f" (Mode: {self.ctx.uart_loader_mode.value}) ===")
        try:
            self._run_body()
            logger.info("=== TaskDiagnoLatency Task Completed ===")
        finally:
            # Best-effort cleanup: the serial may never have been opened if an
            # earlier phase failed, and an error raised here would replace the
            # exception already propagating. Catch Exception rather than using
            # a bare except so KeyboardInterrupt / SystemExit still propagate.
            try:
                self.uart_loader_operation.close_uart_loader_serial()
            except Exception as exc:
                logger.debug("Ignoring error while closing uart loader "
                             "serial: %s", exc)

    def _run_body(self) -> None:
        """Bring up the uart loader and run every enabled latency test.

        No error handling here; exceptions are left to ``run``.
        """
        # Mode-specific initialization: reset, and in RAM boot mode also run
        # the xmodem RAM-mode phase.
        self.reset_operation.execute_reset_phase()

        if self.ctx.uart_loader_mode == UartLoaderBootSource.RAM:
            self.xmodem_operation.execute_ram_mode_phase()

        # Common phases for both modes
        self.uart_loader_operation.init_uart_loader_serial()
        self.uart_loader_operation.wait_for_boot_event()

        # RRAM first, then FLASH; a start address of 0 skips that region.
        regions = (
            ("RRAM", self.ctx.rram_address),
            ("FLASH", self.ctx.flash_address),
        )
        for title, start_address in regions:
            if start_address == 0:
                continue
            for data_len in self.TEST_LEN_LIST:
                self._run_diagno_latency(
                    title, self.TEST_DATA_LIST, start_address, data_len)

    @staticmethod
    def _format_speed(data_len: int, time_ms: float) -> str:
        """Return a ``"(N.NNN kB/s)"`` suffix, or ``""`` if *time_ms* <= 0.

        Assuming *data_len* is a byte count and *time_ms* is in milliseconds,
        bytes per millisecond equals kilobytes (1000 B) per second, so the
        plain ratio is reported as kB/s.
        """
        if time_ms > 0:
            return f"({data_len / time_ms:.3f} kB/s)"
        return ""

    def _run_diagno_latency(self, result_title: str, test_data_list: List[int],
                            start_address: int, data_len: int) -> None:
        """Measure latency for one region/length and log the averages.

        Calls ``agent.diagno_latency`` once per pattern in *test_data_list*,
        logs each raw result at DEBUG level, then logs the averages over all
        patterns (with derived throughput) at INFO level.

        Args:
            result_title: Label used in the summary log line (e.g. "RRAM").
            test_data_list: Data patterns to measure; must not be empty.
            start_address: Start address passed to ``diagno_latency``.
            data_len: Length passed to ``diagno_latency``.

        Raises:
            ValueError: If *test_data_list* is empty (nothing to average).
            RuntimeError: If the uart loader agent is not initialized.
        """
        if not test_data_list:
            raise ValueError("test_data_list must not be empty")
        agent = self.uart_loader_operation.agent
        if agent is None:
            raise RuntimeError("uart loader agent is not initialized")

        total_write_time_ram = 0
        total_erase_time_storage = 0
        total_write_time_storage = 0
        for d in test_data_list:
            diagno_latency_result = agent.diagno_latency(
                d, start_address, data_len)
            logger.debug(f"Diagno latency result ({d:02X}, {data_len:#04X}):")
            logger.debug("  write_time_ram:"
                         f" {diagno_latency_result.write_time_ram} ms")
            logger.debug("  erase_time_storage:"
                         f" {diagno_latency_result.erase_time_storage} ms")
            logger.debug("  write_time_storage:"
                         f" {diagno_latency_result.write_time_storage} ms")
            total_write_time_ram += diagno_latency_result.write_time_ram
            total_erase_time_storage += diagno_latency_result.erase_time_storage
            total_write_time_storage += diagno_latency_result.write_time_storage

        testing_times = len(test_data_list)
        t_ram = total_write_time_ram / testing_times
        t_erase = total_erase_time_storage / testing_times
        t_write = total_write_time_storage / testing_times
        spd_ram = self._format_speed(data_len, t_ram)
        spd_erase = self._format_speed(data_len, t_erase)
        spd_write = self._format_speed(data_len, t_write)

        logger.info(f"Diagno latency AVG ({result_title}, {data_len:#04X}):")
        logger.info(f"  write_time_ram: {t_ram:.1f} ms{spd_ram}")
        logger.info(f"  erase_time_storage: {t_erase:.1f} ms{spd_erase}")
        logger.info(f"  write_time_storage: {t_write:.1f} ms{spd_write}")