from __future__ import annotations as _annotations

import os
from collections.abc import Mapping
from typing import (
    TYPE_CHECKING,
    Any,
)

from pydantic._internal._utils import deep_update, is_model_class
from pydantic.dataclasses import is_pydantic_dataclass
from pydantic.fields import FieldInfo
from typing_extensions import get_args, get_origin
from typing_inspection.introspection import is_union_origin

from ...utils import _lenient_issubclass
from ..base import PydanticBaseEnvSettingsSource
from ..types import EnvNoneType
from ..utils import (
    _annotation_enum_name_to_val,
    _get_model_fields,
    _union_is_complex,
    parse_env_vars,
)

if TYPE_CHECKING:
    from pydantic_settings.main import BaseSettings


class EnvSettingsSource(PydanticBaseEnvSettingsSource):
    """
    Source class for loading settings values from environment variables.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        case_sensitive: bool | None = None,
        env_prefix: str | None = None,
        env_nested_delimiter: str | None = None,
        env_nested_max_split: int | None = None,
        env_ignore_empty: bool | None = None,
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
    ) -> None:
        super().__init__(
            settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
        )
        self.env_nested_delimiter = (
            env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter')
        )
        self.env_nested_max_split = (
            env_nested_max_split if env_nested_max_split is not None else self.config.get('env_nested_max_split')
        )
        self.maxsplit = (self.env_nested_max_split or 0) - 1
        self.env_prefix_len = len(self.env_prefix)

        self.env_vars = self._load_env_vars()

    def _load_env_vars(self) -> Mapping[str, str | None]:
        return parse_env_vars(os.environ, self.case_sensitive, self.env_ignore_empty, self.env_parse_none_str)

    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        """
        Gets the value for field from environment variables and a flag to determine whether value is complex.

        Args:
            field: The field.
            field_name: The field name.

        Returns:
            A tuple that contains the value (`None` if not found), key, and
                a flag to determine whether value is complex.
        """

        env_val: str | None = None
        for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
            env_val = self.env_vars.get(env_name)
            if env_val is not None:
                break

        return env_val, field_key, value_is_complex

    def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
        """
        Prepare value for the field.

        * Extract value for nested field.
        * Deserialize value to python object for complex field.

        Args:
            field: The field.
            field_name: The field name.

        Returns:
            A tuple contains prepared value for the field.

        Raises:
            ValuesError: When There is an error in deserializing value for complex field.
        """
        is_complex, allow_parse_failure = self._field_is_complex(field)
        if self.env_parse_enums:
            enum_val = _annotation_enum_name_to_val(field.annotation, value)
            value = value if enum_val is None else enum_val

        if is_complex or value_is_complex:
            if isinstance(value, EnvNoneType):
                return value
            elif value is None:
                # field is complex but no value found so far, try explode_env_vars
                env_val_built = self.explode_env_vars(field_name, field, self.env_vars)
                if env_val_built:
                    return env_val_built
            else:
                # field is complex and there's a value, decode that as JSON, then add explode_env_vars
                try:
                    value = self.decode_complex_value(field_name, field, value)
                except ValueError as e:
                    if not allow_parse_failure:
                        raise e

                if isinstance(value, dict):
                    return deep_update(value, self.explode_env_vars(field_name, field, self.env_vars))
                else:
                    return value
        elif value is not None:
            # simplest case, field is not complex, we only need to add the value if it was found
            return value

    def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
        """
        Find out if a field is complex, and if so whether JSON errors should be ignored
        """
        if self.field_is_complex(field):
            allow_parse_failure = False
        elif is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
            allow_parse_failure = True
        else:
            return False, False

        return True, allow_parse_failure

    # Default value of `case_sensitive` is `None`, because we don't want to break existing behavior.
    # We have to change the method to a non-static method and use
    # `self.case_sensitive` instead in V3.
    def next_field(
        self, field: FieldInfo | Any | None, key: str, case_sensitive: bool | None = None
    ) -> FieldInfo | None:
        """
        Find the field in a sub model by key(env name)

        By having the following models:

            ```py
            class SubSubModel(BaseSettings):
                dvals: Dict

            class SubModel(BaseSettings):
                vals: list[str]
                sub_sub_model: SubSubModel

            class Cfg(BaseSettings):
                sub_model: SubModel
            ```

        Then:
            next_field(sub_model, 'vals') Returns the `vals` field of `SubModel` class
            next_field(sub_model, 'sub_sub_model') Returns `sub_sub_model` field of `SubModel` class

        Args:
            field: The field.
            key: The key (env name).
            case_sensitive: Whether to search for key case sensitively.

        Returns:
            Field if it finds the next field otherwise `None`.
        """
        if not field:
            return None

        annotation = field.annotation if isinstance(field, FieldInfo) else field
        for type_ in get_args(annotation):
            type_has_key = self.next_field(type_, key, case_sensitive)
            if type_has_key:
                return type_has_key
        if is_model_class(annotation) or is_pydantic_dataclass(annotation):  # type: ignore[arg-type]
            fields = _get_model_fields(annotation)
            # `case_sensitive is None` is here to be compatible with the old behavior.
            # Has to be removed in V3.
            for field_name, f in fields.items():
                for _, env_name, _ in self._extract_field_info(f, field_name):
                    if case_sensitive is None or case_sensitive:
                        if field_name == key or env_name == key:
                            return f
                    elif field_name.lower() == key.lower() or env_name.lower() == key.lower():
                        return f
        return None

    def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None]) -> dict[str, Any]:
        """
        Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.

        This is applied to a single field, hence filtering by env_var prefix.

        Args:
            field_name: The field name.
            field: The field.
            env_vars: Environment variables.

        Returns:
            A dictionary contains extracted values from nested env values.
        """
        if not self.env_nested_delimiter:
            return {}

        ann = field.annotation
        is_dict = ann is dict or _lenient_issubclass(get_origin(ann), dict)

        prefixes = [
            f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name)
        ]
        result: dict[str, Any] = {}
        for env_name, env_val in env_vars.items():
            try:
                prefix = next(prefix for prefix in prefixes if env_name.startswith(prefix))
            except StopIteration:
                continue
            # we remove the prefix before splitting in case the prefix has characters in common with the delimiter
            env_name_without_prefix = env_name[len(prefix) :]
            *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter, self.maxsplit)
            env_var = result
            target_field: FieldInfo | None = field
            for key in keys:
                target_field = self.next_field(target_field, key, self.case_sensitive)
                if isinstance(env_var, dict):
                    env_var = env_var.setdefault(key, {})

            # get proper field with last_key
            target_field = self.next_field(target_field, last_key, self.case_sensitive)

            # check if env_val maps to a complex field and if so, parse the env_val
            if (target_field or is_dict) and env_val:
                if target_field:
                    is_complex, allow_json_failure = self._field_is_complex(target_field)
                    if self.env_parse_enums:
                        enum_val = _annotation_enum_name_to_val(target_field.annotation, env_val)
                        env_val = env_val if enum_val is None else enum_val
                else:
                    # nested field type is dict
                    is_complex, allow_json_failure = True, True
                if is_complex:
                    try:
                        env_val = self.decode_complex_value(last_key, target_field, env_val)  # type: ignore
                    except ValueError as e:
                        if not allow_json_failure:
                            raise e
            if isinstance(env_var, dict):
                if last_key not in env_var or not isinstance(env_val, EnvNoneType) or env_var[last_key] == {}:
                    env_var[last_key] = env_val

        return result

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}(env_nested_delimiter={self.env_nested_delimiter!r}, '
            f'env_prefix_len={self.env_prefix_len!r})'
        )


__all__ = ['EnvSettingsSource']
