from __future__ import annotations
from logging import getLogger
from pydantic import Field
from pathlib import Path

from Task.TaskContextBase.task_uart_loader_context_base import (
    TaskUartLoaderContextBase, TaskFuncChangeBaudrateContextBase
)
from Task.TaskContextBase.task_reset_support_context_base import (
    TaskResetSupportContextBase
)

from Task.TaskOperation.reset_operation import (
    ResetOperation, UartLoaderBootSource
)
from Task.TaskOperation.xmodem_operation import XmodemOperation
from Task.TaskOperation.uart_loader_operation import UartLoaderOperation
from pydantic_argparse.argparse_helper_funcs import ArgExtra

logger = getLogger(__name__)

class TaskDumpBinFileContext(TaskResetSupportContextBase,
                             TaskUartLoaderContextBase,
                             TaskFuncChangeBaudrateContextBase):
    # Parameters for dumping a region of target storage to a local file via
    # the UART loader. Reset / UART-loader / baudrate settings and their
    # `to_*_context()` converters are inherited from the base contexts.
    # (Documented with comments instead of a class docstring so the model's
    # generated description is left unchanged.)

    # Local path the dumped bytes are written to.
    output: str = Field(
        description="Path to binary file to dump via uart loader",
    )
    # Start address in target storage to read from.
    address: int = Field(
        description="Target storage address for dumping"
    )
    # When True, an existing output file is patched in place (bytes outside
    # the written range are kept); when False, it is replaced.
    merge_exist_file: bool = Field(
        default=False,
        description="Merge existing file"
    )
    # Byte offset inside the output file at which the dumped data is written.
    output_offset: int = Field(
        default=0,
        description="Offset address of output file"
    )
    # Number of bytes to read from target storage.
    size: int = Field(
        description="Size of data"
    )
    # Opt-out flag: the SHA256 verification runs unless this is True.
    skip_check_sha256: bool = Field(
        default=False,
        description="Skip SHA256 check after dumping"
    )
    # Timeout for the dump operation, in seconds.
    dump_timeout_s: int = Field(
        default=300,
        description="Timeout for dump operation in seconds"
    )

    def to_dump_binary_context(self) -> UartLoaderOperation.DumpBinaryContext:
        """Build the UartLoaderOperation dump request from this context.

        ``skip_check_sha256`` is inverted into the operation's positive
        ``check_dump_sha256`` flag.
        """
        return UartLoaderOperation.DumpBinaryContext(
            address=self.address,
            size=self.size,
            check_dump_sha256=not self.skip_check_sha256,
            dump_timeout_s=self.dump_timeout_s,
        )

class TaskDumpBinFile:
    """Task class for dumping binary files via uart loader protocol (Operation-based).

    Sequence: reset phase, XMODEM RAM-mode phase (only when the boot source
    is ``UartLoaderBootSource.RAM``), UART-loader serial init / boot-event
    wait / program info / baudrate change, then ``dump_binary``. The
    returned data is written to ``ctx.output`` at ``ctx.output_offset``.
    """

    def __init__(self, ctx: TaskDumpBinFileContext):
        """Create the reset, XMODEM and UART-loader operations from *ctx*."""
        self.ctx = ctx

        self.reset_operation = ResetOperation(
            ctx.to_reset_operation_init_context()
        )

        # Constructed unconditionally; only used in RAM boot mode.
        self.xmodem_operation = XmodemOperation(
            ctx.to_xmodem_init_context(),
        )

        self.uart_loader_operation = UartLoaderOperation(
            ctx=ctx.to_uart_loader_init_context(),
        )

    def run(self):
        """Main execution function for TaskDumpBinFile task.

        Exceptions from the task body propagate to the caller. Cleanup always
        runs afterwards; ``Exception`` subclasses raised during cleanup are
        logged and suppressed so they do not mask the original error.
        """
        logger.info("=== TaskDumpBinFile Task Started"
                    f" (Mode: {self.ctx.uart_loader_mode.value}) ===")
        try:
            self._run_body()
            logger.info("=== TaskDumpBinFile Task Completed ===")
        finally:
            self._cleanup()

    def _cleanup(self):
        """Best-effort teardown: stop the module, then close the serial port."""
        # Each step is attempted independently so one failure does not skip
        # the other. `except Exception` (not a bare except) lets Ctrl-C and
        # SystemExit propagate instead of being silently swallowed.
        try:
            self.reset_operation.module_stop()
        except Exception:
            logger.warning("Cleanup: module_stop() failed", exc_info=True)
        try:
            self.uart_loader_operation.close_uart_loader_serial()
        except Exception:
            logger.warning("Cleanup: close_uart_loader_serial() failed",
                           exc_info=True)

    def _run_body(self):
        """Core task execution logic without error handling."""
        # Mode-specific initialization phase
        self.reset_operation.execute_reset_phase()
        if self.ctx.uart_loader_mode == UartLoaderBootSource.RAM:
            self.xmodem_operation.execute_ram_mode_phase()

        # Common phases for both modes
        self.uart_loader_operation.init_uart_loader_serial()
        self.uart_loader_operation.wait_for_boot_event()
        self.uart_loader_operation.get_prog_info()
        self.uart_loader_operation.change_baudrate(
            self.ctx.to_change_baudrate_context())

        # Dump the binary and persist it
        data = self.uart_loader_operation.dump_binary(
            self.ctx.to_dump_binary_context())
        self._write_output(data)

    def _write_output(self, data) -> None:
        """Write bytes-like *data* into ``ctx.output`` at ``ctx.output_offset``.

        Without ``merge_exist_file`` any existing file is removed first, so
        the result holds only *data* (the gap before a non-zero offset reads
        as zero bytes). With it, an existing file is patched in place and the
        bytes outside the written range are preserved; a missing file is
        created.

        Raises:
            ValueError: if ``ctx.output_offset`` is negative (from ``seek``).
            OSError: on filesystem errors.
        """
        output = Path(self.ctx.output)
        if not self.ctx.merge_exist_file:
            output.unlink(missing_ok=True)

        # EAFP: patch in place when the file exists, otherwise create it.
        try:
            f = open(output, "r+b")
        except FileNotFoundError:
            f = open(output, "w+b")

        with f:
            f.seek(self.ctx.output_offset)
            f.write(data)
