from __future__ import annotations
from argparse import _SubParsersAction, ArgumentParser, SUPPRESS
from typing import Any, TypeVar, Type, Callable
from pydantic import BaseModel, Field
from pydantic_core import PydanticUndefined
from typing import get_origin, get_args, Union, Any, List
from logging import getLogger
from argparse import Namespace

try:
    from types import UnionType
except ImportError:
    UnionType = None  # Python < 3.10


# Module-level logger used by the helper's debug/info records.
logger = getLogger(name=__name__)
# Type variable for the pydantic context models registered with the helper.
T = TypeVar("T", bound=BaseModel)

class ArgumentSubParserHelper:

    def __init__(self):
        self.sub_parser: _SubParsersAction[ArgumentParser] | None = None
        self.sub_command_name_field: str = ""
        self._task_cls_map: dict[str, Type[Any]] = {}
        self._context_cls_map: dict[str, Type[T]] = {}
        self.parsed_args: Namespace | None = None
        self.subcommand: str = ""

    def set_sub_parser(self, sub_parser: _SubParsersAction[ArgumentParser]) -> None:
        self.sub_parser = sub_parser
        self.sub_command_name_field = self.sub_parser.dest

    def _resolve_concrete_type(self, tp: Any) -> tuple[type, bool]:
        """
        Resolve the final element type and whether it's a list.

        Returns:
            (final_type, is_list)
        Examples:
            int -> (int, False)
            Optional[int] -> (int, False)
            list[int] -> (int, True)
            Optional[list[int]] -> (int, True)
            list[int] | None -> (int, True)
        """

        def _unwrap_optional(t: Any) -> Any:
            """Remove Optional / NoneType / Union[..., None] / PEP604 | None."""
            origin = get_origin(t)
            if origin in (Union, UnionType):
                args = [a for a in get_args(t) if a is not type(None)]
                if len(args) == 1:
                    return args[0]
                # fallback: multiple args, take first non-None
                return next((a for a in args if a is not type(None)), Any)
            return t

        # remove Optional wrapper
        base = _unwrap_optional(tp)
        origin = get_origin(base)
        args = get_args(base)

        # Case 1: list[T] or List[T]
        if origin in (list, List):
            inner = args[0] if args else Any
            inner = _unwrap_optional(inner)
            return inner, True

        # Case 2: not a list
        return base, False

    def _default_arg_name_formatter(name: str) -> str:
        """Format the argument name for the parser."""
        return name.replace('_', '-')
    
    def _add_one_field(
        self, field_name: str, field: Any, parser: ArgumentParser,
        arg_name_formater: Callable[[str], str] = _default_arg_name_formatter,
        is_engineer_mode: bool = False
    ) -> None:
        schema:dict[str, Any] = field.json_schema_extra or {}
        if "hide" in schema:
            return
        arg_help = SUPPRESS\
            if not is_engineer_mode and "engineer_mode" in schema\
            else field.description
        fieldtype, is_list = self._resolve_concrete_type(field.annotation)
        default = (
            field.default if field.default is not PydanticUndefined \
                else None
        )

        if fieldtype is int:
            fieldtype = lambda x: int(x, 0)

        if fieldtype is bool:
            if default is None:
                raise ValueError(
                    f"Bool field {field_name} must have default")
            if field.default:
                parser.add_argument(f'--no-{arg_name_formater(field_name)}',
                                    action='store_false', dest=field_name,
                                    help=arg_help)
            else:
                parser.add_argument(f'--{arg_name_formater(field_name)}',
                                    action='store_true', dest=field_name,
                                    help=arg_help)
        else:
            nargs = None if not is_list else '+'
            if default is not None:
                logger.debug(f"Adding arg: {arg_name_formater(field_name)}"
                            f" default={default}")
                parser.add_argument(f'--{arg_name_formater(field_name)}',
                                    type=fieldtype, help=arg_help,
                                    default=default, nargs=nargs)
            else:
                logger.debug(f"Adding arg: {arg_name_formater(field_name)}"
                            " required")
                parser.add_argument(f'--{arg_name_formater(field_name)}',
                                    type=fieldtype, help=arg_help,
                                    required=True, nargs=nargs)
        

    def add_subcommand_class(
        self, name: str, context_cls: Type[T], task_cls: Type[Any],
        /,
        arg_name_formater: Callable[[str], str] = _default_arg_name_formatter,
        is_engineer_mode: bool = False,
        **kwargs: Any) -> ArgumentParser:
        """Add a subcommand to the parser."""
        assert self.sub_parser is not None
        parser = self.sub_parser.add_parser(name, **kwargs)
        for field_name, field in context_cls.model_fields.items():
            self._add_one_field(field_name, field, parser, arg_name_formater,
                                is_engineer_mode)
        self._task_cls_map[name] = task_cls
        self._context_cls_map[name] = context_cls
        return parser

    def set_parsed_args(self, args: Namespace) -> None:
        """Set the parsed arguments to the helper."""
        self.parsed_args = args
        if self.sub_command_name_field not in args:
            raise ValueError("Subcommand name field"
                             f" {self.sub_command_name_field} not found in"
                             " parsed args")
        self.subcommand = getattr(args, self.sub_command_name_field)
        if self.subcommand not in self._context_cls_map:
            raise ValueError(f"Subcommand {self.subcommand} not found in"
                             " context map")
        logger.info(f"Subcommand: {self.subcommand}")

    def get_subcommand_context(self) -> T:
        """Get the context class from the parsed arguments."""
        if self.parsed_args is None:
            raise ValueError("Parsed args not set, call set_parsed_args first")

        if self.subcommand not in self._context_cls_map:
            raise ValueError(f"Subcommand {self.subcommand} not found in"
                             " context map")
        
        context_type = self._context_cls_map[self.subcommand]
        context = context_type(**vars(self.parsed_args))
        return context

    def get_subcommand_task_type(self) -> Type[Any]:
        """Get the task class type from the parsed arguments."""
        if self.parsed_args is None:
            raise ValueError("Parsed args not set, call set_parsed_args first")

        if self.subcommand not in self._task_cls_map:
            raise ValueError(f"Subcommand {self.subcommand} not found in"
                             " task map")

        task_type = self._task_cls_map[self.subcommand]
        return task_type

# Field helper: attach command-line metadata to a pydantic field.
def ArgExtra(field: Any, /, engineer_mode: bool = False, **kwargs: Any) -> Any:
    """Attach command-line metadata to a field's ``json_schema_extra``.

    Args:
        field: Field object (e.g. the result of pydantic's ``Field(...)``)
            whose ``json_schema_extra`` is extended. ``field`` is updated in
            place and returned, so the call can wrap a field definition
            directly.
        engineer_mode: When True, records ``engineer_mode=True`` (the
            argument helper hides such options' help outside engineer
            mode). When False, any existing ``engineer_mode`` entry is
            removed, because consumers may test for the key's presence.
        **kwargs: Extra entries merged into ``json_schema_extra``
            (e.g. ``hide=True``).

    Returns:
        The same ``field`` object.

    Raises:
        TypeError: If ``json_schema_extra`` is set but is not a dict
            (pydantic also permits a callable there).
    """
    extra = field.json_schema_extra
    if extra is None:
        extra = {}
    elif not isinstance(extra, dict):
        raise TypeError(
            "ArgExtra requires json_schema_extra to be a dict or None,"
            f" got {type(extra).__name__}")
    # Build a new dict so a caller-supplied (possibly shared) dict is not
    # mutated in place.
    merged = {**extra, **kwargs}
    if engineer_mode:
        merged["engineer_mode"] = True
    else:
        merged.pop("engineer_mode", None)
    field.json_schema_extra = merged
    return field