from __future__ import annotations
from logging import getLogger
import time
from pydantic import Field

from Task.TaskContextBase.task_reset_support_context_base import (
    TaskResetSupportContextBase
)
from Task.TaskContextBase.task_uart_loader_context_base import (
    TaskUartLoaderContextBase, TaskFuncChangeBaudrateContextBase
)
from Task.TaskOperation.reset_operation import (
    ResetOperation, UartLoaderBootSource
)
from Task.TaskOperation.xmodem_operation import (
    XmodemOperation
)
from Task.TaskOperation.uart_loader_operation import (
    UartLoaderOperation
)
from pydantic_argparse.argparse_helper_funcs import ArgExtra

# Logger for this module, named after the module path.
logger = getLogger(name=__name__)

class TaskDiagnoTputContext(
    TaskResetSupportContextBase,
    TaskUartLoaderContextBase,
    TaskFuncChangeBaudrateContextBase,
):
    """Configuration for the diagno throughput task.

    Inherits the reset, UART-loader and baud-rate-change settings from the
    base contexts and adds the three values the task passes to
    ``agent.diagno_tput``.
    """

    # Required: declared without a default, so it must be supplied.
    data_len: int = Field(..., description="Data length")
    count_group_packets: int = Field(
        16,
        description="Count of group packets",
    )
    timeout: int = Field(
        15,
        description="Timeout for diagno throughput operation in seconds",
    )


class TaskDiagnoTput:
    """Measure diagno throughput through the UART loader protocol.

    The task resets the target, loads the RAM image over XMODEM when the
    boot source is RAM, opens the UART loader serial link, waits for the
    boot event, switches the baud rate and then times ``agent.diagno_tput``.
    """

    def __init__(self, ctx: TaskDiagnoTputContext):
        """Build the operations used by the task.

        Args:
            ctx: Task configuration. It supplies the init contexts for the
                reset, XMODEM and UART loader operations.
        """
        self.ctx = ctx

        self.reset_operation = ResetOperation(
            ctx.to_reset_operation_init_context()
        )

        self.xmodem_operation = XmodemOperation(
            ctx.to_xmodem_init_context(),
        )

        self.uart_load_operation = UartLoaderOperation(
            ctx.to_uart_loader_init_context(),
        )

    def run(self):
        """Run the throughput test and always attempt to close the serial port.

        Exceptions raised by the task body propagate to the caller. A failure
        while closing the serial port is logged and does not mask them.
        """
        logger.info("=== TaskDiagnoTput Task Started (Mode: %s) ===",
                    self.ctx.uart_loader_mode.value)
        try:
            self._run_body()
            logger.info("=== TaskDiagnoTput Task Completed ===")
        finally:
            # Best-effort cleanup. Closing may fail, e.g. if an earlier phase
            # raised before the port was opened. Catching it here keeps a
            # close error from replacing the original exception. Catching
            # ``Exception`` (not a bare ``except``) still lets
            # KeyboardInterrupt/SystemExit propagate.
            try:
                self.uart_load_operation.close_uart_loader_serial()
            except Exception:
                logger.warning("Failed to close UART loader serial port",
                               exc_info=True)

    def _run_body(self):
        """Execute the task phases; errors propagate to ``run``.

        Raises:
            RuntimeError: if the UART loader agent is still unset after the
                baud-rate change.
        """
        # Mode-specific initialization. Reset always runs; the XMODEM image
        # load runs only when booting from RAM.
        self.reset_operation.execute_reset_phase()
        if self.ctx.uart_loader_mode == UartLoaderBootSource.RAM:
            self.xmodem_operation.execute_ram_mode_phase()

        # Phases common to every boot source.
        self.uart_load_operation.init_uart_loader_serial()
        self.uart_load_operation.wait_for_boot_event()
        self.uart_load_operation.change_baudrate(
            self.ctx.to_change_baudrate_context())

        agent = self.uart_load_operation.agent
        if agent is None:
            # Use an explicit check rather than ``assert``, which is
            # stripped under ``python -O``.
            raise RuntimeError("UART loader agent is not initialized")

        # perf_counter is monotonic and high-resolution, so the interval is
        # not skewed by wall-clock adjustments the way time.time() can be.
        start = time.perf_counter()
        agent.diagno_tput(
            self.ctx.data_len, self.ctx.count_group_packets, self.ctx.timeout)
        elapsed_ms = (time.perf_counter() - start) * 1000
        logger.info(
            "Diagno throughput time: %.3f ms (%.3f kB/s)",
            elapsed_ms,
            self._throughput_kb_per_s(self.ctx.data_len, elapsed_ms),
        )

    @staticmethod
    def _throughput_kb_per_s(data_len: int, elapsed_ms: float) -> float:
        """Return ``data_len`` divided by ``elapsed_ms``, reported as kB/s.

        Units per millisecond equal thousands of units per second. Assuming
        ``data_len`` is a byte count (TODO confirm against ``diagno_tput``),
        the result is in decimal kB/s.

        Returns ``inf`` for a non-positive interval instead of raising
        ``ZeroDivisionError``.
        """
        if elapsed_ms <= 0:
            return float("inf")
        return data_len / elapsed_ms
