#!/usr/bin/env python
'''
@file packet_data_base.py

@brief HCI command and event packet definitions (programming and diagnostic commands)

Copyright (C) Atmosic 2025
'''
from __future__ import annotations
from lib.checksum import xor_bytes
from packet.hci.hci_property import HciProperty
from packet.hci.hci_packet_base import HciPacketBase


class HciPacket(HciPacketBase):
    """Common root of the HCI packet classes declared in this module.

    Declares only the one-byte packet-type field at offset 0; command and
    event classes add their own header fields on top of it.
    """
    # Byte 0: packet type, unsigned 8-bit.
    # NOTE(review): expected=0x04 is inherited by HciCommand, which writes 0x01
    # into this field before serialising -- confirm that `expected` is only
    # checked when parsing received events.
    pktype = HciProperty[int](start=0, fmt="B", expected=0x04)

class HciCommand(HciPacket):
    """Base class for command packets.

    The header occupies four bytes: pktype (offset 0, inherited), opcode
    (offsets 1-2) and plen (offset 3). Subclass fields start at offset 4.
    """
    opcode = HciProperty[int](start=1, fmt="<H")  # Little-endian unsigned short
    # Byte 3: unsigned 8-bit length of everything after the 4-byte header.
    plen = HciProperty[int](start=3, fmt="B")

    def update_before_to_bytes(self) -> None:
        """Fill in the packet type and the parameter length.

        plen is derived from the payload length at the moment of this call
        (payload length minus the four header bytes). A subclass that changes
        the payload length after calling this method has to refresh plen
        itself.
        """
        self.pktype = 0x01
        # Writing a placeholder first guarantees the payload covers the whole
        # header, so the subtraction below cannot produce a negative value.
        self.plen = 0 # To expand the payload buffer
        self.plen = len(self.payload) - 4

class HciEvent(HciPacket):
    """Base layout shared by the event packets declared in this module."""
    # Byte 1: event code, unsigned 8-bit.
    # NOTE(review): expected=0x0E is also inherited by HciEventVendorSpecific;
    # confirm vendor-specific events really carry 0x0E in this byte.
    evtcode = HciProperty[int](start=1, fmt="B", expected=0x0E)
    # Byte 2: unsigned 8-bit length field (presumably the parameter length).
    plen = HciProperty[int](start=2, fmt="B")

class HciEventCommandComplete(HciEvent):
    """Event layout used as the base of every ``EventCC`` reply in this module.

    Adds three fields after the event header; reply-specific fields declared
    by subclasses start at offset 7.
    """
    # Byte 3: unsigned 8-bit counter.
    num_hci_command_pkts = HciProperty[int](start=3, fmt="B")
    # Bytes 4-5: opcode of the command this event answers (by field name).
    opcode = HciProperty[int](start=4, fmt="<H")  # Little-endian unsigned short
    # Byte 6: unsigned 8-bit status code.
    status = HciProperty[int](start=6, fmt="B")

class HciEventVendorSpecific(HciEvent):
    """Event layout that adds a one-byte sub-event identifier after the header."""
    # Byte 3: unsigned 8-bit sub-event code.
    sub_event = HciProperty[int](start=3, fmt="B")

class HciBootStatusEvent:
    """Namespace class grouping the boot-status event layout."""
    class EventVS(HciEventVendorSpecific):
        """Vendor-specific event carrying a byte string from offset 4 onwards."""
        # Bytes 4..: byte-string field. b"Atmosic" is passed to HciProperty as
        # the expected value; how it is enforced is defined by HciProperty.
        app_ver = HciProperty[bytes](start=4, fmt="s", expected=b"Atmosic")
 
class HciProgInfoRsp:
    """Namespace class pairing the 0xF870 command with its reply layout."""

    class Command(HciCommand):
        """Parameterless command: the packet consists of the header only."""

        def update_before_to_bytes(self) -> None:
            """Fill in the header, set opcode 0xF870 and fix the size at 4 bytes."""
            super().update_before_to_bytes()
            self.opcode = 0xF870
            # pktype (1 byte) + opcode (2 bytes) + plen (1 byte); no parameters.
            header_only_size = 4
            self.adjust_payload_size(header_only_size)

    class EventCC(HciEventCommandComplete):
        """Reply layout; fields follow the command-complete header at offset 7."""
        # Bytes 7-8: little-endian unsigned 16-bit.
        app_ver = HciProperty[int](start=7, fmt="<H")
        # Bytes 9-10: little-endian unsigned 16-bit.
        protocol_ver = HciProperty[int](start=9, fmt="<H")
        # Bytes 11-14: little-endian unsigned 32-bit.
        ram_buffer_start = HciProperty[int](start=11, fmt="<I")
        # Bytes 15-18: little-endian unsigned 32-bit.
        ram_buffer_size = HciProperty[int](start=15, fmt="<I")
        # Byte 19: unsigned 8-bit count.
        num_flash = HciProperty[int](start=19, fmt="B")
        # Bytes 20..: byte string.
        flash_ids = HciProperty[bytes](start=20, fmt="s")

class HciProgBlkClean:
    """Namespace class pairing the 0xF871 command with its reply layout."""

    class Command(HciCommand):
        """Fixed 12-byte command carrying a start address and a length."""
        # Bytes 4-7: little-endian unsigned 32-bit start address.
        start_addr = HciProperty[int](start=4, fmt="<I")
        # Bytes 8-11: little-endian unsigned 32-bit length.
        length = HciProperty[int](start=8, fmt="<I")

        def update_before_to_bytes(self) -> None:
            """Fill in the header, set opcode 0xF871 and fix the size at 12 bytes."""
            super().update_before_to_bytes()
            self.opcode = 0xF871
            # 4 bytes header (pktype, opcode, plen) + 8 bytes parameters.
            self.adjust_payload_size(12)
            # HciCommand derived plen from the payload length *before* the
            # resize above; refresh it so the length byte always matches the
            # payload that is actually emitted (e.g. when a field was left unset).
            self.plen = len(self.payload) - 4

    class EventCC(HciEventCommandComplete):
        """Reply layout; adds no fields to the command-complete header."""

class HciProgBaudrateSet:
    """Namespace class pairing the 0xF872 command with its reply layout."""

    class Command(HciCommand):
        """Fixed 10-byte command carrying a baud rate and a delay."""
        # Bytes 4-7: little-endian unsigned 32-bit baud rate.
        baudrate = HciProperty[int](start=4, fmt="<I")
        # Bytes 8-9: little-endian unsigned 16-bit delay (microseconds, per the name).
        delay_time_us = HciProperty[int](start=8, fmt="<H")

        def update_before_to_bytes(self) -> None:
            """Fill in the header, set opcode 0xF872 and fix the size at 10 bytes."""
            super().update_before_to_bytes()
            self.opcode = 0xF872
            # 4 bytes header (pktype, opcode, plen) + 6 bytes parameters.
            self.adjust_payload_size(10)
            # HciCommand derived plen from the payload length *before* the
            # resize above; refresh it so the length byte always matches the
            # payload that is actually emitted (e.g. when a field was left unset).
            self.plen = len(self.payload) - 4

    class EventCC(HciEventCommandComplete):
        """Reply layout; adds no fields to the command-complete header."""

def on_prog_write_data_cmd_data_change(instance: HciPacketBase) -> None:
    """Keep data_length and the payload size in step with ``data``.

    Used as the ``on_change`` hook of ``HciProgWriteData.Command.data``.
    Does nothing while ``data`` is None. Otherwise it stores ``len(data)``
    in ``data_length`` and resizes the payload so that it holds the 11 bytes
    preceding the data, the data itself and the one-byte XOR checksum that
    follows it. The checksum value itself is not computed here.

    Args:
        instance: Packet whose ``data`` changed; must be an
            ``HciProgWriteData.Command``.
    """
    assert isinstance(instance, HciProgWriteData.Command)
    data = instance.data
    if data is None:
        return
    data_size = len(data)
    # Auto-update data_length
    instance.data_length = data_size
    # Required size: 11 bytes before the data + data + 1 byte xor_checksum
    instance.adjust_payload_size(11 + data_size + 1)

def pos_prog_write_data_cmd_xor(instance: HciPacketBase) -> int:
    """Return the byte offset of ``xor_checksum`` in a write-data command.

    The checksum sits directly behind the data, which begins at offset 11.
    While ``data_length`` is None the offset is 11 itself.
    """
    assert isinstance(instance, HciProgWriteData.Command)
    data_start = 11
    declared_len = instance.data_length
    if declared_len is None:
        return data_start
    return data_start + declared_len

class HciProgWriteData:
    """Namespace class pairing the 0xF873 command with its reply layout."""

    class Command(HciCommand):
        """Variable-length command: fixed fields, data bytes, trailing XOR byte."""
        # Byte 4: unsigned 8-bit mode.
        mode = HciProperty[int](start=4, fmt="B")
        # Bytes 5-8: little-endian unsigned 32-bit address.
        address = HciProperty[int](start=5, fmt="<I")
        # Byte 9: unsigned 8-bit sequence number.
        seq_no = HciProperty[int](start=9, fmt="B")
        # Byte 10: unsigned 8-bit data length, maintained by the on_change hook.
        data_length = HciProperty[int](start=10, fmt="B")
        # Bytes 11..: data bytes; every change runs the hook named below.
        data = HciProperty[bytes](
            start=11,
            fmt="s",
            on_change=on_prog_write_data_cmd_data_change,
        )
        # One byte whose offset is computed per instance by the helper.
        xor_checksum = HciProperty[int](
            start=pos_prog_write_data_cmd_xor,
            fmt="B",
        )

        def update_before_to_bytes(self) -> None:
            """Fill in the header, set opcode 0xF873 and store the XOR checksum.

            The checksum is xor_bytes() over every payload byte except the last.
            """
            super().update_before_to_bytes()
            self.opcode = 0xF873

            if len(self.payload) == 0:
                return
            covered = self.payload[:-1]
            self.xor_checksum = xor_bytes(covered)

    class EventCC(HciEventCommandComplete):
        """Reply layout; adds a sequence number at offset 7."""
        # Byte 7: unsigned 8-bit sequence number.
        seq_no = HciProperty[int](start=7, fmt="B")

class HciProgApplyBlk:
    """Namespace class pairing the 0xF874 command with its reply layout."""

    class Command(HciCommand):
        """Fixed 16-byte command carrying two addresses and a length."""
        # Bytes 4-7: little-endian unsigned 32-bit source address.
        source_address = HciProperty[int](start=4, fmt="<I")
        # Bytes 8-11: little-endian unsigned 32-bit target address.
        target_address = HciProperty[int](start=8, fmt="<I")
        # Bytes 12-15: little-endian unsigned 32-bit length.
        length = HciProperty[int](start=12, fmt="<I")

        def update_before_to_bytes(self) -> None:
            """Fill in the header, set opcode 0xF874 and fix the size at 16 bytes."""
            super().update_before_to_bytes()
            self.opcode = 0xF874
            # 4 bytes header (pktype, opcode, plen) + 12 bytes parameters.
            self.adjust_payload_size(16)
            # HciCommand derived plen from the payload length *before* the
            # resize above; refresh it so the length byte always matches the
            # payload that is actually emitted (e.g. when a field was left unset).
            self.plen = len(self.payload) - 4

    class EventCC(HciEventCommandComplete):
        """Reply layout; adds no fields to the command-complete header."""

class HciProgXorCheck:
    """Namespace class pairing the 0xF875 command with its reply layout."""

    class Command(HciCommand):
        """Fixed 12-byte command carrying an address and a length."""
        # Bytes 4-7: little-endian unsigned 32-bit address.
        address = HciProperty[int](start=4, fmt="<I")
        # Bytes 8-11: little-endian unsigned 32-bit length.
        length = HciProperty[int](start=8, fmt="<I")

        def update_before_to_bytes(self) -> None:
            """Fill in the header, set opcode 0xF875 and fix the size at 12 bytes."""
            super().update_before_to_bytes()
            self.opcode = 0xF875
            # 4 bytes header (pktype, opcode, plen) + 8 bytes parameters.
            self.adjust_payload_size(12)
            # HciCommand derived plen from the payload length *before* the
            # resize above; refresh it so the length byte always matches the
            # payload that is actually emitted (e.g. when a field was left unset).
            self.plen = len(self.payload) - 4

    class EventCC(HciEventCommandComplete):
        """Reply layout; adds a one-byte XOR value at offset 7."""
        # Byte 7: unsigned 8-bit XOR result.
        xor = HciProperty[int](start=7, fmt="B")

class HciProgSha256Check:
    """Namespace class pairing the 0xF876 command with its reply layout."""

    class Command(HciCommand):
        """Fixed 12-byte command carrying an address and a length."""
        # Bytes 4-7: little-endian unsigned 32-bit address.
        address = HciProperty[int](start=4, fmt="<I")
        # Bytes 8-11: little-endian unsigned 32-bit length.
        length = HciProperty[int](start=8, fmt="<I")

        def update_before_to_bytes(self) -> None:
            """Fill in the header, set opcode 0xF876 and fix the size at 12 bytes."""
            super().update_before_to_bytes()
            self.opcode = 0xF876
            # 4 bytes header (pktype, opcode, plen) + 8 bytes parameters.
            self.adjust_payload_size(12)
            # HciCommand derived plen from the payload length *before* the
            # resize above; refresh it so the length byte always matches the
            # payload that is actually emitted (e.g. when a field was left unset).
            self.plen = len(self.payload) - 4

    class EventCC(HciEventCommandComplete):
        """Reply layout; adds a 16-byte digest prefix at offset 7."""
        # Bytes 7-22: 16-byte string (first 16 digest bytes, per the field name).
        sha256_first_16_bytes = HciProperty[bytes](start=7, fmt="16s")

def end_pos_prog_dump_data(instance: HciPacketBase) -> int:
    """Return the end offset of ``data`` in a dump reply.

    The data ends where the XOR byte begins, so this delegates to
    pos_prog_dump_evt_xor().
    """
    xor_offset = pos_prog_dump_evt_xor(instance)
    return xor_offset

def pos_prog_dump_evt_xor(instance: HciPacketBase) -> int:
    """Return the byte offset of ``xor`` in a dump reply.

    The XOR byte sits directly behind the data, which begins at offset 12.
    While ``data_length`` is None the offset is 12 itself.
    """
    assert isinstance(instance, HciProgDump.EventCC)
    data_start = 12
    declared_len = instance.data_length
    if declared_len is None:
        return data_start
    return data_start + declared_len

class HciProgDump:
    """Namespace class pairing the 0xF877 command with its reply layout."""

    class Command(HciCommand):
        """Fixed 12-byte command carrying a start address and a length."""
        # Bytes 4-7: little-endian unsigned 32-bit start address.
        start_address = HciProperty[int](start=4, fmt="<I")
        # Bytes 8-11: little-endian unsigned 32-bit length.
        length = HciProperty[int](start=8, fmt="<I")

        def update_before_to_bytes(self) -> None:
            """Fill in the header, set opcode 0xF877 and fix the size at 12 bytes."""
            super().update_before_to_bytes()
            self.opcode = 0xF877
            # 4 bytes header (pktype, opcode, plen) + 8 bytes parameters.
            self.adjust_payload_size(12)
            # HciCommand derived plen from the payload length *before* the
            # resize above; refresh it so the length byte always matches the
            # payload that is actually emitted (e.g. when a field was left unset).
            self.plen = len(self.payload) - 4

    class EventCC(HciEventCommandComplete):
        """Reply layout: start address, data length, data bytes and XOR byte."""
        # Bytes 7-10: little-endian unsigned 32-bit start address.
        start_address = HciProperty[int](start=7, fmt="<I")
        # Byte 11: unsigned 8-bit data length.
        data_length = HciProperty[int](start=11, fmt="B")
        # Bytes 12..end: data; the end offset is computed per instance.
        data = HciProperty[bytes](start=12, end=end_pos_prog_dump_data, fmt="s")
        # One byte located at 12 + data_length (12 while data_length is None).
        xor = HciProperty[int](start=pos_prog_dump_evt_xor, fmt="B")


def on_diagno_tput_cmd_data_change(instance: HciPacketBase) -> None:
    """Keep data_len and the payload size in step with ``data``.

    Used as the ``on_change`` hook of ``HciDiagnoTput.Command.data``.
    Does nothing while ``data`` is None. Otherwise it stores ``len(data)``
    in ``data_len`` and resizes the payload so that it holds the 6 bytes
    preceding the data, the data itself and the one-byte XOR value that
    follows it. The XOR value itself is not computed here.

    Args:
        instance: Packet whose ``data`` changed; must be an
            ``HciDiagnoTput.Command``.
    """
    assert isinstance(instance, HciDiagnoTput.Command)
    data = instance.data
    if data is None:
        return
    data_size = len(data)
    # Auto-update data_len
    instance.data_len = data_size
    # Required size: 6 bytes before the data + data + 1 byte xor
    instance.adjust_payload_size(6 + data_size + 1)

def pos_diagno_tput_cmd_xor(instance: HciPacketBase) -> int:
    """Return the byte offset of ``xor`` in a throughput command.

    The XOR byte sits directly behind the data, which begins at offset 6.
    While ``data_len`` is None the offset is 6 itself.
    """
    assert isinstance(instance, HciDiagnoTput.Command)
    data_start = 6
    declared_len = instance.data_len
    if declared_len is None:
        return data_start
    return data_start + declared_len

class HciDiagnoTput:
    """Namespace class pairing the 0xF880 command with its reply layout."""

    class Command(HciCommand):
        """Variable-length command: mode, data length, data bytes, trailing XOR byte."""
        # Byte 4: unsigned 8-bit mode.
        mode = HciProperty[int](start=4, fmt="B")
        # Byte 5: unsigned 8-bit data length, maintained by the on_change hook.
        data_len = HciProperty[int](start=5, fmt="B")
        # Bytes 6..: data bytes; every change runs the hook named below.
        data = HciProperty[bytes](
            start=6,
            fmt="s",
            on_change=on_diagno_tput_cmd_data_change,
        )
        # One byte whose offset is computed per instance by the helper.
        xor = HciProperty[int](start=pos_diagno_tput_cmd_xor, fmt="B")

        def update_before_to_bytes(self) -> None:
            """Fill in the header, set opcode 0xF880 and store the XOR byte.

            The XOR value is xor_bytes() over every payload byte except the last.
            """
            super().update_before_to_bytes()
            self.opcode = 0xF880

            if len(self.payload) == 0:
                return
            covered = self.payload[:-1]
            self.xor = xor_bytes(covered)

    class EventCC(HciEventCommandComplete):
        """Reply layout; adds no fields to the command-complete header."""

class HciDiagnoLatency:
    """Namespace class pairing the 0xF881 command with its reply layout."""

    class Command(HciCommand):
        """Fixed 13-byte command carrying a test byte, an address and a length."""
        # Byte 4: unsigned 8-bit test data value.
        test_data = HciProperty[int](start=4, fmt="B")
        # Bytes 5-8: little-endian unsigned 32-bit start address.
        start_address = HciProperty[int](start=5, fmt="<I")
        # Bytes 9-12: little-endian unsigned 32-bit data length.
        data_length = HciProperty[int](start=9, fmt="<I")

        def update_before_to_bytes(self) -> None:
            """Fill in the header, set opcode 0xF881 and fix the size at 13 bytes."""
            super().update_before_to_bytes()
            self.opcode = 0xF881
            # Same sizing step as the other fixed-layout commands in this
            # module: 4 bytes header (pktype, opcode, plen) + 9 bytes parameters.
            self.adjust_payload_size(13)
            # HciCommand derived plen from the payload length *before* the
            # resize above; refresh it so the length byte always matches the
            # payload that is actually emitted (e.g. when a field was left unset).
            self.plen = len(self.payload) - 4

    class EventCC(HciEventCommandComplete):
        """Reply layout; adds three 32-bit timing values from offset 7."""
        # Bytes 7-10: little-endian unsigned 32-bit.
        write_time_ram = HciProperty[int](start=7, fmt="<I")
        # Bytes 11-14: little-endian unsigned 32-bit.
        erase_time_storage = HciProperty[int](start=11, fmt="<I")
        # Bytes 15-18: little-endian unsigned 32-bit.
        write_time_storage = HciProperty[int](start=15, fmt="<I")