from dataclasses import dataclass
from typing import List
from pydantic import BaseModel, Field
from pathlib import Path
import re

import atm_file_parser.atm_isp_pb2 as archive_pb2


class RramCommand(BaseModel):
    # One RRAM image-load step, as produced by
    # AtmFileParser.get_rram_commands() from a ``loadRram`` script entry.

    # Raw payload to write.
    image: bytes = Field(
        description="Image data",
    )
    # Populated from ``commonLoadRram.region_start`` by the parser.
    address: int = Field(
        description="Absolute address",
    )
    # Populated from the command's ``extrainfo`` string.
    partition_name: str = Field(
        description="Partition name",
    )

class FlashCommand(BaseModel):
    # One flash image-load step, built by AtmFileParser.get_flash_commands()
    # from either a ``loadFlash`` or a ``loadFlashNvds`` script entry.

    # Raw payload to write.
    image: bytes = Field(
        description="Image data",
    )
    # ``region_start`` plus the platform's external flash offset.
    address: int = Field(
        description="Absolute address",
    )
    # Populated from ``commonLoadFlash.region_size``.
    partition_size: int = Field(
        description="Partition size",
    )
    # Populated from the command's ``extrainfo`` string.
    partition_name: str = Field(
        description="Partition name",
    )

class EraseFlashCommand(BaseModel):
    # One flash-erase step, built by AtmFileParser.get_erase_flash_commands()
    # from an ``eraseFlash`` script entry.

    # Number of bytes to erase (``commonEraseFlash.region_size``).
    region_size: int = Field(
        description="Region size",
    )
    # NOTE(review): get_erase_flash_commands() stores
    # ``commonEraseFlash.address`` plus 0x200000 when the platform family is
    # "atm34", so on that family this value is not purely relative despite
    # the description below — confirm which is intended.
    address: int = Field(
        description="Relative address",
    )
    # Copied verbatim from the command's ``base_address`` field.
    base_address: int = Field(
        description="Base address for storage",
    )

class EraseRramCommand(BaseModel):
    # One RRAM-erase step, built by AtmFileParser.get_erase_rram_commands()
    # from an ``eraseRram`` script entry.

    # Copied from ``commonEraseRram.region_start`` (no offset applied).
    region_start: int = Field(
        description="Relative address",
    )
    # Number of bytes to erase (``commonEraseRram.region_size``).
    region_size: int = Field(
        description="Region size",
    )
    # Copied verbatim from the command's ``base_address`` field.
    base_address: int = Field(
        description="Base address for storage",
    )

# Payload extractors used by AtmFileParser.export_binaries().
# Each key is a ``cmdUnion`` oneof name; each value takes the script entry and
# returns ``(payload, file_name_suffix)``. Suffixes are sanitized by the caller.
EXTRACTORS = {
    # Flash image; suffix encodes the load address and partition info.
    "loadFlash": lambda cu: (
        cu.loadFlash.commonLoadFlash.commonLoad.image,
        f"addr_{cu.loadFlash.address:08X}_{cu.loadFlash.extrainfo}",
    ),
    # RRAM image; suffix encodes the load address and partition info.
    "loadRram": lambda cu: (
        cu.loadRram.commonLoadRram.commonLoad.image,
        f"addr_{cu.loadRram.address:08X}_{cu.loadRram.extrainfo}",
    ),
    # Flash NVDS image; suffix is just the extra-info string.
    "loadFlashNvds": lambda cu: (
        cu.loadFlashNvds.commonLoadFlash.commonLoad.image,
        f"{cu.loadFlashNvds.extrainfo}",
    ),
    # OTP NVDS image; fixed suffix.
    "loadOtpNvds": lambda cu: (
        cu.loadOtpNvds.commonLoad.image,
        "otp_nvds",
    ),
    # Extended command; suffix encodes its type and extra info.
    "cmdExtend": lambda cu: (
        cu.cmdExtend.commonLoad.image,
        f"type_{cu.cmdExtend.type}_{cu.cmdExtend.extrainfo}",
    ),
    # NVDS read/modify/write; payload is the raw NVDS content.
    "nvdsReadModWrite": lambda cu: (
        cu.nvdsReadModWrite.nvdsContent,
        "nvds",
    ),
}

class AtmFileParser:
    """Load, inspect, extend and re-save an ATM ISP archive file.

    The file is parsed into an ``atm_isp_pb2.Archive`` protobuf message.
    ``archive.meta`` describes the target platform; ``archive.script`` is an
    ordered list of commands, each carrying one member of the ``cmdUnion``
    oneof (``loadFlash``, ``loadRram``, ``eraseFlash``, ...).
    """

    # External flash base. It is added to flash addresses when the platform
    # family is "atm34", and written as eraseFlash.base_address by
    # append_erase_flash().
    _EXTERNAL_FLASH_BASE = 0x200000
    # Region sizes accepted by the append_* methods must be multiples of this.
    _REGION_ALIGNMENT = 0x1000
    # Runs of characters replaced by "_" when building exported file names.
    _UNSAFE_FILENAME_CHARS = re.compile(r"[^A-Za-z0-9_.-]+")

    def __init__(self, path: str):
        """Read and parse the archive stored at *path*.

        Args:
            path: Location of the serialized ``Archive`` message.

        Any error from opening the file, or from ``ParseFromString`` on
        malformed data, propagates to the caller.
        """
        self.path = path
        self.archive = archive_pb2.Archive()
        with open(path, "rb") as f:
            self.archive.ParseFromString(f.read())
        self.meta = self.archive.meta

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _iter_commands(self, *names: str):
        """Yield ``(name, command)`` for script entries whose oneof is in *names*.

        Entries keep their script order. ``command`` is the populated member
        of the ``cmdUnion`` oneof. Entries with no member set never match.
        """
        for cmd_union in self.archive.script:
            cmd_name = cmd_union.WhichOneof("cmdUnion")
            if cmd_name in names:
                yield cmd_name, getattr(cmd_union, cmd_name)

    def _external_flash_offset(self) -> int:
        """Return the offset added to flash addresses on this platform.

        This is ``_EXTERNAL_FLASH_BASE`` for the ``atm34`` family and 0 for
        every other family.
        """
        if self.meta.platform.family == "atm34":
            return self._EXTERNAL_FLASH_BASE
        return 0

    def _check_region_alignment(self, region_size: int) -> None:
        """Validate a region size passed to an append_* method.

        Raises:
            ValueError: if *region_size* is negative or is not a multiple of
                ``_REGION_ALIGNMENT``.
        """
        if region_size < 0:
            raise ValueError(f"Region size {region_size} must not be negative")
        if region_size % self._REGION_ALIGNMENT != 0:
            raise ValueError(f"Region size {region_size}(0x{region_size:X})"
                             " is not aligned to"
                             f" 0x{self._REGION_ALIGNMENT:X}")

    # ------------------------------------------------------------------
    # Typed command views
    # ------------------------------------------------------------------

    def get_rram_commands(self) -> List[RramCommand]:
        """Return every ``loadRram`` entry as an ``RramCommand``, in script order.

        ``address`` is taken from ``commonLoadRram.region_start`` with no
        offset applied.
        """
        return [
            RramCommand(
                image=cmd.commonLoadRram.commonLoad.image,
                address=cmd.commonLoadRram.region_start,
                partition_name=cmd.extrainfo,
            )
            for _, cmd in self._iter_commands("loadRram")
        ]

    def get_erase_flash_commands(self) -> List[EraseFlashCommand]:
        """Return every ``eraseFlash`` entry as an ``EraseFlashCommand``.

        ``address`` is ``commonEraseFlash.address`` plus the platform's
        external flash offset (non-zero only on the atm34 family).
        """
        offset = self._external_flash_offset()
        return [
            EraseFlashCommand(
                region_size=cmd.commonEraseFlash.region_size,
                address=offset + cmd.commonEraseFlash.address,
                base_address=cmd.base_address,
            )
            for _, cmd in self._iter_commands("eraseFlash")
        ]

    def get_flash_commands(self) -> List[FlashCommand]:
        """Return ``loadFlash`` and ``loadFlashNvds`` entries as ``FlashCommand``.

        Both command kinds share the ``commonLoadFlash`` layout and are
        returned together in script order. ``address`` is ``region_start``
        plus the platform's external flash offset.
        """
        offset = self._external_flash_offset()
        return [
            FlashCommand(
                image=cmd.commonLoadFlash.commonLoad.image,
                address=offset + cmd.commonLoadFlash.region_start,
                partition_size=cmd.commonLoadFlash.region_size,
                partition_name=cmd.extrainfo,
            )
            for _, cmd in self._iter_commands("loadFlash", "loadFlashNvds")
        ]

    def get_erase_rram_commands(self) -> List[EraseRramCommand]:
        """Return every ``eraseRram`` entry as an ``EraseRramCommand``."""
        return [
            EraseRramCommand(
                region_start=cmd.commonEraseRram.region_start,
                region_size=cmd.commonEraseRram.region_size,
                base_address=cmd.base_address,
            )
            for _, cmd in self._iter_commands("eraseRram")
        ]

    # ------------------------------------------------------------------
    # Human-readable dump
    # ------------------------------------------------------------------

    def get_all_info(self) -> List[str]:
        """Return a human-readable dump of the archive as a list of lines.

        The dump is the meta section followed by one section per script
        command. Each command section ends with an empty line.
        """
        info = self._describe_meta()
        info.append("=== Script Commands ===")
        for idx, cmd_union in enumerate(self.archive.script):
            cmd_name = cmd_union.WhichOneof("cmdUnion")
            info.append(f"Command {idx}: {cmd_name}")
            info.extend(self._describe_command(cmd_union, cmd_name))
            info.append("")  # Empty line between commands
        return info

    def _describe_meta(self) -> List[str]:
        """Format the archive meta section, including its trailing blank line."""
        meta = self.meta
        platform = meta.platform
        return [
            "=== Archive Meta Information ===",
            f"Signature: {meta.signature}",
            f"Platform: {platform.family}/{platform.name}/{platform.board}"
            f" (rev: {platform.revision})",
            f"MPR Start: 0x{meta.mpr_start:08X}",
            f"MPR Size: 0x{meta.mpr_size:08X}",
            f"MPR Lock Size: 0x{meta.mpr_lock_size:08X}",
            f"OTA: {meta.ota}",
            f"SDK Version: {meta.sdk_ver}",
            f"Secure Debug: {meta.sec_dbg}",
            f"Secure Boot: {meta.sec_boot}",
            "",
        ]

    @staticmethod
    def _describe_command(cmd_union, cmd_name) -> List[str]:
        """Return the indented detail lines for one script entry.

        Commands not listed here, and entries with no oneof member set,
        produce no detail lines.
        """
        if cmd_name == "loadFlash":
            cmd = cmd_union.loadFlash
            common = cmd.commonLoadFlash
            return [
                "  - Flash Load Command",
                f"  - Address: 0x{cmd.address:08X}",
                f"  - Region Start: 0x{common.region_start:08X}",
                f"  - Region Size: 0x{common.region_size:08X}",
                f"  - Image Size: {len(common.commonLoad.image)} bytes",
                f"  - Partition: {cmd.extrainfo}",
            ]
        if cmd_name == "loadRram":
            cmd = cmd_union.loadRram
            common = cmd.commonLoadRram
            return [
                "  - RRAM Load Command",
                f"  - Address: 0x{cmd.address:08X}",
                f"  - Region Start: 0x{common.region_start:08X}",
                f"  - Region Size: 0x{common.region_size:08X}",
                f"  - Image Size: {len(common.commonLoad.image)} bytes",
                f"  - Partition: {cmd.extrainfo}",
            ]
        if cmd_name == "loadFlashNvds":
            cmd = cmd_union.loadFlashNvds
            common = cmd.commonLoadFlash
            return [
                "  - Flash NVDS Load Command",
                f"  - Region Start: 0x{common.region_start:08X}",
                f"  - Region Size: 0x{common.region_size:08X}",
                f"  - Image Size: {len(common.commonLoad.image)} bytes",
                f"  - Extra Info: {cmd.extrainfo}",
            ]
        if cmd_name == "loadOtpNvds":
            cmd = cmd_union.loadOtpNvds
            return [
                "  - OTP NVDS Load Command",
                f"  - Image Size: {len(cmd.commonLoad.image)} bytes",
            ]
        if cmd_name == "eraseRram":
            cmd = cmd_union.eraseRram
            common = cmd.commonEraseRram
            return [
                "  - RRAM Erase Command",
                f"  - Base Address: 0x{cmd.base_address:08X}",
                f"  - Region Start: 0x{common.region_start:08X}",
                f"  - Region Size: 0x{common.region_size:08X}",
            ]
        if cmd_name == "eraseFlash":
            cmd = cmd_union.eraseFlash
            common = cmd.commonEraseFlash
            return [
                "  - Flash Erase Command",
                f"  - Base Address: 0x{cmd.base_address:08X}",
                f"  - Address: 0x{common.address:08X}",
                f"  - Region Size: 0x{common.region_size:08X}",
            ]
        if cmd_name == "nvdsReadModWrite":
            cmd = cmd_union.nvdsReadModWrite
            return [
                "  - NVDS Read/Modify/Write Command",
                f"  - Invert: {cmd.invert}",
                f"  - NVDS Content Size: {len(cmd.nvdsContent)} bytes",
            ]
        if cmd_name == "cmdExtend":
            cmd = cmd_union.cmdExtend
            return [
                "  - Extended Command",
                f"  - Type: {cmd.type}",
                f"  - Image Size: {len(cmd.commonLoad.image)} bytes",
                f"  - Extra Info: {cmd.extrainfo}",
            ]
        return []

    # ------------------------------------------------------------------
    # Export / modification / persistence
    # ------------------------------------------------------------------

    def export_binaries(self, out_dir: str) -> None:
        """Write each extractable script payload to its own ``.bin`` file.

        Files are named ``<index:03d>_<command>_<suffix>.bin`` inside
        *out_dir*, which is created if missing. Commands without an entry in
        ``EXTRACTORS``, and commands whose payload is empty, are skipped.
        """
        out_path = Path(out_dir)
        out_path.mkdir(parents=True, exist_ok=True)
        for i, cu in enumerate(self.archive.script):
            name = cu.WhichOneof("cmdUnion")
            extractor = EXTRACTORS.get(name)
            if extractor is None:
                continue
            data, suffix = extractor(cu)
            if not data:
                continue
            safe = self._UNSAFE_FILENAME_CHARS.sub("_", str(suffix)).strip("_")
            (out_path / f"{i:03d}_{name}_{safe}.bin").write_bytes(bytes(data))

    def append_rram(self, bin_path: str, address: int, extra_info: str) -> None:
        """Append a ``loadRram`` command that writes *bin_path* at *address*.

        The region size is the file's length, and both ``region_start`` and
        ``address`` are set to *address*. No alignment check is done here.
        """
        data = Path(bin_path).read_bytes()
        cmd = archive_pb2.Archive.Command()
        cmd.loadRram.commonLoadRram.commonLoad.image = data
        cmd.loadRram.commonLoadRram.region_start = address
        cmd.loadRram.commonLoadRram.region_size = len(data)
        cmd.loadRram.address = address
        cmd.loadRram.extrainfo = extra_info
        self.archive.script.append(cmd)

    def append_flash(self, bin_path: str, address: int, region_size: int,
                     extra_info: str) -> None:
        """Append a ``loadFlash`` command that writes *bin_path* at *address*.

        Raises:
            ValueError: if *region_size* is smaller than the file, or is not
                a multiple of ``_REGION_ALIGNMENT``.
        """
        data = Path(bin_path).read_bytes()
        data_len = len(data)
        # Checked first so negative sizes report the size mismatch, as before.
        if region_size < data_len:
            raise ValueError(f"Region size {region_size}(0x{region_size:X}) is"
                             " smaller than data size"
                             f" {data_len}(0x{data_len:X})")
        self._check_region_alignment(region_size)
        cmd = archive_pb2.Archive.Command()
        cmd.loadFlash.commonLoadFlash.commonLoad.image = data
        cmd.loadFlash.commonLoadFlash.region_start = address
        cmd.loadFlash.commonLoadFlash.region_size = region_size
        cmd.loadFlash.address = address
        cmd.loadFlash.extrainfo = extra_info
        self.archive.script.append(cmd)

    def append_erase_flash(self, address: int, region_size: int) -> None:
        """Append an ``eraseFlash`` command for *region_size* bytes at *address*.

        Raises:
            ValueError: if *region_size* is negative or misaligned.
        """
        self._check_region_alignment(region_size)
        cmd = archive_pb2.Archive.Command()
        cmd.eraseFlash.commonEraseFlash.address = address
        cmd.eraseFlash.commonEraseFlash.region_size = region_size
        # NOTE(review): written for every platform family, whereas
        # _external_flash_offset() is 0 for non-atm34 families — confirm.
        cmd.eraseFlash.base_address = self._EXTERNAL_FLASH_BASE
        self.archive.script.append(cmd)

    def append_erase_rram(self, address: int, region_size: int) -> None:
        """Append an ``eraseRram`` command for *region_size* bytes at *address*.

        Raises:
            ValueError: if *region_size* is negative or misaligned.
        """
        self._check_region_alignment(region_size)
        cmd = archive_pb2.Archive.Command()
        cmd.eraseRram.commonEraseRram.region_start = address
        cmd.eraseRram.commonEraseRram.region_size = region_size
        cmd.eraseRram.base_address = 0
        self.archive.script.append(cmd)

    def save(self, path: str) -> None:
        """Serialize the (possibly modified) archive to *path*, overwriting it."""
        with open(path, "wb") as f:
            f.write(self.archive.SerializeToString())
