"""
TaskProgramAtmFile - Program an .atm file via the uart loader protocol

This task is based on TaskProgrammingBinFile and uses reusable
operation classes (ResetOperation, XmodemOperation and UartLoaderOperation)
to reduce code duplication and improve maintainability.

This task supports two modes for binary programming:
1. RAM mode: Reset to bboot mode -> xmodem upload -> uart loader programming
2. RRAM mode: Reset to normal mode -> direct uart loader programming

The task provides comprehensive firmware programming with optional verification:
- XOR verification for RAM and RRAM blocks
- SHA256 verification for programmed firmware
- Flexible reset support (none, jlinkob, rts)
"""

from __future__ import annotations
from pathlib import Path
from datetime import datetime
import time
from logging import getLogger
from pydantic import Field

from Task.TaskContextBase.task_program_context_base import (
    TaskProgramContextBase
)
from Task.TaskContextBase.task_reset_support_context_base import (
    TaskResetSupportContextBase
)
from csv_report.csv_report_helper import CsvReportHelper
from csv_report_struct.program_full_spend_time import ProgramFullSpendTimeReport
from Task.TaskOperation.reset_operation import ResetOperation
from Task.TaskOperation.xmodem_operation import XmodemOperation
from Task.TaskOperation.uart_loader_operation import UartLoaderOperation
from Task.TaskOperation.reset_operation import UartLoaderBootSource
from atm_file_parser.atm_file_parser import AtmFileParser

logger = getLogger(__name__)

class TaskProgramAtmFileContext(TaskResetSupportContextBase,
                                TaskProgramContextBase):
    """Configuration for the .atm programming task.

    Combines the reset-support and program context bases with the
    .atm-specific settings declared below.
    """

    # Location of the .atm file whose erase/program commands are executed.
    atm_path: str = Field(
        description="Path to .atm file to program via fast-load"
    )
    # Forwarded as ``erase_timeout`` on every uart loader erase request.
    erase_timeout: float = Field(
        default=5.0,
        description="Timeout for erase operation"
    )

    def to_program_binary_context(
        self,
        binary_title: str,
        binary_data: bytes,
        target_address: int,
        is_storage_erased: bool,
    ) -> UartLoaderOperation.ProgramBinaryContext:
        """Build the uart loader request for programming one binary block.

        Args:
            binary_title: Label for the block being programmed.
            binary_data: Payload to program.
            target_address: Destination address; also forwarded to
                ``to_program_binary_context_base``.
            is_storage_erased: Whether the destination was erased beforehand.

        Returns:
            A ``UartLoaderOperation.ProgramBinaryContext`` built from the
            block-specific arguments plus every field produced by
            ``to_program_binary_context_base(target_address).model_dump()``.
        """
        shared_fields = self.to_program_binary_context_base(
            target_address
        ).model_dump()
        return UartLoaderOperation.ProgramBinaryContext(
            binary_title=binary_title,
            binary_data=binary_data,
            is_storage_erased=is_storage_erased,
            **shared_fields,
        )


class TaskProgramAtmFile:
    """Task class for programming .atm files via uart loader protocol (Operation-based).

    Sequence: parse the .atm file with ``AtmFileParser``, reset the target
    (``ResetOperation``), upload the RAM image over xmodem when the uart
    loader boot source is RAM (``XmodemOperation``), then use
    ``UartLoaderOperation`` to erase regions and program RRAM/flash blocks.
    Timing data is collected in a ``ProgramFullSpendTimeReport`` that is
    appended to the CSV report when the sequence succeeds.
    """

    def __init__(self, ctx: TaskProgramAtmFileContext):
        """Initialize TaskProgramAtmFile.

        Args:
            ctx: Task configuration (.atm path, RAM image path, reset and
                uart loader settings).

        Raises:
            OSError: e.g. FileNotFoundError, if ``ctx.ram_image`` or
                ``ctx.atm_path`` cannot be stat'ed. Their sizes are
                recorded in the report.
        """
        self.ctx = ctx
        self.start_timestamp = datetime.now()

        # Initialize full report; the xmodem and uart loader operations
        # below receive this same report object and start timestamp.
        self.full_report = ProgramFullSpendTimeReport()
        self.full_report.cmd_start_timestamp = self.start_timestamp
        self.full_report.serial = self.ctx.serial
        self.full_report.xmodem_ram_file = self.ctx.ram_image
        # NOTE(review): ram_image is stat'ed even in RRAM mode, where no
        # xmodem upload happens -- confirm it is always a valid path.
        self.full_report.xmodem_ram_size = \
            Path(self.ctx.ram_image).stat().st_size
        self.full_report.uart_loader_file = self.ctx.atm_path
        self.full_report.uart_loader_bin_size = \
            Path(self.ctx.atm_path).stat().st_size
        self.full_report.full_programming_time_ms = 0
        self.full_report.full_verification_time_ms = 0

        self.reset_operation = ResetOperation(
            ctx.to_reset_operation_init_context()
        )

        self.xmodem_operation = XmodemOperation(
            ctx.to_xmodem_init_context(),
            start_timestamp=self.start_timestamp,
            full_report=self.full_report
        )

        self.uart_loader_operation = UartLoaderOperation(
            ctx.to_uart_loader_init_context(),
            start_timestamp=self.start_timestamp,
            full_report=self.full_report
        )

        logger.info("TaskProgramminAtmFile initialized:")
        logger.info(f"  Program binary: {self.ctx.atm_path}")

    def run(self):
        """Run the task, always releasing the reset module and serial port.

        Exceptions from the programming sequence propagate to the caller.
        Cleanup failures are logged and suppressed so they do not mask the
        original error.
        """
        logger.info(f"=== TaskProgramminAtmFile Task Started"
                    f" (Mode: {self.ctx.uart_loader_mode.value}) ===")
        try:
            self._run_body()
            logger.info("=== TaskProgramminAtmFile Task Completed ===")
        finally:
            self._cleanup()

    def _cleanup(self):
        """Best-effort stop of the reset module and close of the loader serial."""
        # Catch Exception (not BaseException) so Ctrl-C still propagates,
        # and log instead of silently discarding the failure.
        try:
            self.reset_operation.module_stop()
        except Exception:
            logger.warning("Failed to stop reset module", exc_info=True)
        try:
            self.uart_loader_operation.close_uart_loader_serial()
        except Exception:
            logger.warning("Failed to close uart loader serial",
                           exc_info=True)

    def _run_body(self):
        """Core task execution logic without cleanup handling."""
        logger.info(f"Create ATM parser for file: {self.ctx.atm_path}")
        atm_file_parser = AtmFileParser(self.ctx.atm_path)

        erase_flashs = atm_file_parser.get_erase_flash_commands()
        erase_rrams = atm_file_parser.get_erase_rram_commands()
        load_flashs = atm_file_parser.get_flash_commands()
        load_rrams = atm_file_parser.get_rram_commands()
        # An .atm file may contain only erase commands; skip the device
        # session only when there is nothing at all to do.
        if not (erase_flashs or erase_rrams or load_flashs or load_rrams):
            logger.warning("No commands found in ATM file")
            return

        # Execute mode-specific initialization phase
        self.reset_operation.execute_reset_phase()

        if self.ctx.uart_loader_mode == UartLoaderBootSource.RAM:
            self.xmodem_operation.execute_ram_mode_phase()

        # Common phases for both modes
        self.uart_loader_operation.init_uart_loader_serial()

        # Durations use a monotonic clock so wall-clock adjustments cannot
        # distort (or make negative) the reported times.
        wait_boot_event_stime = time.monotonic()
        self.uart_loader_operation.wait_for_boot_event()
        wait_boot_event_spend_time = self._elapsed_ms(wait_boot_event_stime)

        init_stime = time.monotonic()
        self.uart_loader_operation.change_baudrate(
            self.ctx.to_change_baudrate_context())
        self.uart_loader_operation.get_prog_info()
        init_spend_time = self._elapsed_ms(init_stime)

        self._log_plan(erase_flashs, erase_rrams, load_flashs, load_rrams)
        self._erase_regions(erase_flashs, erase_rrams, load_flashs)
        self._program_blocks(load_flashs, load_rrams)

        self.full_report.uart_loader_init_ms = init_spend_time
        self.full_report.wait_uart_loader_boot_event_time_ms = \
            wait_boot_event_spend_time
        CsvReportHelper.append(self.full_report)

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since a ``time.monotonic()`` stamp."""
        return int((time.monotonic() - start) * 1000)

    @staticmethod
    def _log_plan(erase_flashs, erase_rrams, load_flashs, load_rrams):
        """Log how many erase/program commands the .atm file contains."""
        # NOTE: despite the "Would ..." wording, these operations are
        # actually executed right after this summary.
        if erase_flashs:
            logger.info(f"Would erase {len(erase_flashs)} flash regions from"
                        " ATM file")
        if erase_rrams:
            logger.info(f"Would erase {len(erase_rrams)} rram regions from"
                        " ATM file")
        if load_rrams:
            logger.info(f"Would program {len(load_rrams)} rram blocks from ATM"
                        " file")
        if load_flashs:
            logger.info(f"Would program {len(load_flashs)} flash blocks from"
                        " ATM file")

    def _erase_regions(self, erase_flashs, erase_rrams, load_flashs):
        """Erase explicit RRAM/flash regions, then each flash partition to load.

        The flash partitions erased here are later programmed with
        ``is_storage_erased=True`` by ``_program_blocks``.
        """
        for rram in erase_rrams:
            self._erase(rram.region_start, rram.region_size, "")
        # NOTE(review): flash erase commands are addressed via ``.address``
        # while RRAM erase commands use ``.region_start`` -- confirm this
        # matches the AtmFileParser command types.
        for flash in erase_flashs:
            self._erase(flash.address, flash.region_size, "")
        for flash in load_flashs:
            self._erase(flash.address, flash.partition_size,
                        flash.partition_name)

    def _erase(self, target_address, erase_size, partition_name):
        """Issue one uart loader erase request using the configured timeout."""
        self.uart_loader_operation.erase_data(
            UartLoaderOperation.EraseDataContext(
                target_address=target_address,
                erase_size=erase_size,
                partition_name=partition_name,
                erase_timeout=self.ctx.erase_timeout
            )
        )

    def _program_blocks(self, load_flashs, load_rrams):
        """Program all RRAM blocks, then all (pre-erased) flash blocks."""
        for rram in load_rrams:
            self.uart_loader_operation.program_binary(
                self.ctx.to_program_binary_context(
                    rram.partition_name, rram.image, rram.address,
                    is_storage_erased=False
                )
            )
        for flash in load_flashs:
            self.uart_loader_operation.program_binary(
                self.ctx.to_program_binary_context(
                    flash.partition_name, flash.image, flash.address,
                    is_storage_erased=True
                )
            )