from __future__ import annotations  # isort: split

import __future__  # Regular import, not special!

import enum
import functools
import importlib
import inspect
import json
import socket as stdlib_socket
import sys
import types
from pathlib import Path, PurePath
from types import ModuleType
from typing import TYPE_CHECKING, Protocol

import attrs
import pytest

import trio
import trio.testing
from trio._tests.pytest_plugin import RUN_SLOW, skip_if_optional_else_raise

from .. import _core, _util
from .._core._tests.tutil import slow

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator

# Module-wide flag: set to True by `_ensure_mypy_cache_updated` after it has
# run mypy once, so that later calls in the same process skip the run.
mypy_cache_updated = False


try:  # If installed, check both versions of this class.
    from typing_extensions import Protocol as Protocol_ext
except ImportError:  # pragma: no cover
    # Fall back to the stdlib class so that `Protocol_ext` is always defined
    # and `test_classes_are_final` can test against both names unconditionally.
    Protocol_ext = Protocol  # type: ignore[assignment]


def _ensure_mypy_cache_updated() -> None:
    """Run mypy on ``import trio`` once so that ``./.mypy_cache`` is populated.

    If mypy cannot be imported, the ``ImportError`` is handed to
    ``skip_if_optional_else_raise``. The import is attempted on every call,
    even when the cache was already refreshed by an earlier call. The mypy
    run itself happens at most once per process and is expected to produce
    no output on either stream.
    """
    # This pollutes the `empty` dir. Should this be changed?
    try:
        from mypy.api import run
    except ImportError as error:
        skip_if_optional_else_raise(error)

    global mypy_cache_updated
    if mypy_cache_updated:
        return

    # mypy cache was *probably* already updated by the other tests,
    # but `pytest -k ...` might run just this test on its own
    mypy_args = [
        "--config-file=",
        "--cache-dir=./.mypy_cache",
        "--no-error-summary",
        "-c",
        "import trio",
    ]
    stdout, stderr, _exit_status = run(mypy_args)
    assert not stderr
    assert not stdout
    mypy_cache_updated = True


def test_core_is_properly_reexported() -> None:
    """Check each public ``_core`` name is re-exported from exactly one place.

    The candidate namespaces are ``trio``, ``trio.lowlevel`` and
    ``trio.testing``; a re-export only counts when the attribute is the very
    same object as the one in ``_core``.
    """
    namespaces = [trio, trio.lowlevel, trio.testing]
    public_symbols = (name for name in dir(_core) if not name.startswith("_"))
    for symbol in public_symbols:
        found = sum(
            1
            for namespace in namespaces
            if symbol in dir(namespace)
            and getattr(namespace, symbol) is getattr(_core, symbol)
        )
        # Printed so a failing run shows which symbol had the wrong count.
        print(symbol, found)
        assert found == 1


def class_is_final(cls: type) -> bool:
    """Return True when attempting to subclass ``cls`` raises ``TypeError``."""
    # Go through types.new_class() rather than type(...): only the former
    # resolves metaclasses properly.
    try:
        types.new_class("SubclassTester", (cls,))
    except TypeError:
        subclassing_rejected = True
    else:
        subclassing_rejected = False
    return subclassing_rejected


def iter_modules(
    module: types.ModuleType,
    only_public: bool,
) -> Iterator[types.ModuleType]:
    """Yield ``module`` and then, depth-first, the submodules found in it.

    A value in ``module.__dict__`` counts as a submodule when it is a module
    object whose ``__name__`` lies underneath ``module.__name__`` in the
    dotted package hierarchy.

    Args:
        module: The module to start from; it is always yielded first.
        only_public: When true, attributes whose name starts with an
            underscore are not descended into.

    Yields:
        ``module`` itself, then the result of recursing into each submodule in
        ``module.__dict__`` order.
    """
    yield module
    # Match on "<name>." rather than on the bare name: a plain prefix test
    # would also accept unrelated modules that merely share leading characters
    # (e.g. ``pkgextra`` when walking ``pkg``).
    submodule_prefix = module.__name__ + "."
    for name, member in module.__dict__.items():
        if name.startswith("_") and only_public:
            continue
        if not isinstance(member, ModuleType):
            continue
        if not member.__name__.startswith(submodule_prefix):  # pragma: no cover
            continue
        if member is module:  # pragma: no cover
            continue
        yield from iter_modules(member, only_public)


# Module objects reachable from ``trio`` via `iter_modules`: first without
# descending into underscore-prefixed attributes, then including them, and
# finally the dotted names of the first list (used as test parameters).
PUBLIC_MODULES = [*iter_modules(trio, only_public=True)]
ALL_MODULES = [*iter_modules(trio, only_public=False)]
PUBLIC_MODULE_NAMES = [public_module.__name__ for public_module in PUBLIC_MODULES]


# It doesn't make sense for downstream redistributors to run this test, since
# they might be using a newer version of Python with additional symbols which
# won't be reflected in trio.socket, and this shouldn't cause downstream test
# runs to start failing.
@pytest.mark.redistributors_should_skip
# Static analysis tools often have trouble with alpha releases, where Python's
# internals are in flux, grammar may not have settled down, etc.
@pytest.mark.skipif(
    sys.version_info.releaselevel == "alpha",
    reason="skip static introspection tools on Python dev/alpha releases",
)
@pytest.mark.parametrize("modname", PUBLIC_MODULE_NAMES)
@pytest.mark.parametrize("tool", ["pylint", "jedi", "mypy", "pyright_verifytypes"])
@pytest.mark.filterwarnings(
    # https://github.com/pypa/setuptools/issues/3274
    "ignore:module 'sre_constants' is deprecated:DeprecationWarning",
)
def test_static_tool_sees_all_symbols(tool: str, modname: str, tmp_path: Path) -> None:
    """Check that ``tool`` statically sees every public name of ``modname``.

    The public (non-underscore) names found at runtime via ``dir()`` must be a
    subset of the names the static tool reports for the module. Names missing
    from the static view are printed and fail the test.

    Args:
        tool: One of ``"pylint"``, ``"jedi"``, ``"mypy"`` or
            ``"pyright_verifytypes"``.
        modname: Dotted name of the public module to inspect.
        tmp_path: Requested pytest fixture; not used by the body.
    """
    module = importlib.import_module(modname)

    def no_underscores(symbols: Iterable[str]) -> set[str]:
        return {symbol for symbol in symbols if not symbol.startswith("_")}

    runtime_names = no_underscores(dir(module))

    # ignore deprecated module `tests` being invisible
    if modname == "trio":
        runtime_names.discard("tests")

    # Ignore any __future__ feature objects, if imported under that name.
    for name in __future__.all_feature_names:
        if getattr(module, name, None) is getattr(__future__, name):
            runtime_names.remove(name)

    if tool == "pylint":
        try:
            from pylint.lint import PyLinter
        except ImportError as error:
            skip_if_optional_else_raise(error)

        linter = PyLinter()
        assert module.__file__ is not None
        ast = linter.get_ast(module.__file__, modname)
        static_names = no_underscores(ast)  # type: ignore[arg-type]
    elif tool == "jedi":
        if sys.implementation.name != "cpython":
            pytest.skip("jedi does not support pypy")

        try:
            import jedi
        except ImportError as error:
            skip_if_optional_else_raise(error)

        # Simulate typing "import trio; trio.<TAB>"
        script = jedi.Script(f"import {modname}; {modname}.")
        completions = script.complete()
        static_names = no_underscores(c.name for c in completions)
    elif tool == "mypy":
        if not RUN_SLOW:  # pragma: no cover
            pytest.skip("use --run-slow to check against mypy")

        cache = Path.cwd() / ".mypy_cache"

        _ensure_mypy_cache_updated()

        trio_cache = next(cache.glob("*/trio"))
        # Name of the module relative to the ``trio`` package ("" for ``trio``
        # itself). This is deliberately a separate variable: ``modname`` must
        # keep the full dotted name for the failure report at the end.
        submodule = modname.partition(".")[2]
        mod_cache = trio_cache / submodule if submodule else trio_cache
        if mod_cache.is_dir():  # pragma: no coverage
            mod_cache = mod_cache / "__init__.data.json"
        else:
            mod_cache = trio_cache / (submodule + ".data.json")

        assert mod_cache.exists()
        assert mod_cache.is_file()
        with mod_cache.open() as cache_file:
            cache_json = json.loads(cache_file.read())
            static_names = no_underscores(
                key
                for key, value in cache_json["names"].items()
                if not key.startswith(".") and value["kind"] == "Gdef"
            )
    elif tool == "pyright_verifytypes":
        if not RUN_SLOW:  # pragma: no cover
            pytest.skip("use --run-slow to check against pyright")

        try:
            import pyright  # noqa: F401
        except ImportError as error:
            skip_if_optional_else_raise(error)
        import subprocess

        res = subprocess.run(
            ["pyright", f"--verifytypes={modname}", "--outputjson"],
            capture_output=True,
        )
        current_result = json.loads(res.stdout)

        # Strip the "<modname>." prefix from each reported symbol name.
        static_names = {
            x["name"][len(modname) + 1 :]
            for x in current_result["typeCompleteness"]["symbols"]
            if x["name"].startswith(modname)
        }
    else:  # pragma: no cover
        raise AssertionError()

    # It's expected that the static set will contain more names than the
    # runtime set:
    # - static tools are sometimes sloppy and include deleted names
    # - some symbols are platform-specific at runtime, but always show up in
    #   static analysis (e.g. in trio.socket or trio.lowlevel)
    # So we check that the runtime names are a subset of the static names.
    missing_names = runtime_names - static_names

    # ignore warnings about deprecated module tests
    missing_names -= {"tests"}

    if missing_names:  # pragma: no cover
        print(f"{tool} can't see the following names in {modname}:")
        print()
        for name in sorted(missing_names):
            print(f"    {name}")
        raise AssertionError()


# this could be sped up by only invoking mypy once per module, or even once for all
# modules, instead of once per class.
@slow
# see comment on test_static_tool_sees_all_symbols
@pytest.mark.redistributors_should_skip
# Static analysis tools often have trouble with alpha releases, where Python's
# internals are in flux, grammar may not have settled down, etc.
@pytest.mark.skipif(
    sys.version_info.releaselevel == "alpha",
    reason="skip static introspection tools on Python dev/alpha releases",
)
@pytest.mark.parametrize("module_name", PUBLIC_MODULE_NAMES)
@pytest.mark.parametrize("tool", ["jedi", "mypy"])
def test_static_tool_sees_class_members(
    tool: str,
    module_name: str,
    tmp_path: Path,
) -> None:
    """Check that ``tool`` sees the same members on each class as the runtime.

    For every class object in the module's ``__dict__`` (with a few skipped,
    see below) the runtime member names from ``inspect.getmembers`` are
    compared with the names the static tool reports: jedi completions on
    ``ClassName.``, or the ``names`` recorded in mypy's JSON cache for the
    class and the rest of its MRO. A list of known, version- and tool-specific
    discrepancies is then subtracted. Whatever is still missing or extra is
    collected per class, printed, and fails the test.

    Args:
        tool: ``"jedi"`` or ``"mypy"``.
        module_name: Dotted name of the public module whose classes are checked.
        tmp_path: Requested pytest fixture; not used by the body.
    """
    module = PUBLIC_MODULES[PUBLIC_MODULE_NAMES.index(module_name)]

    # ignore hidden, but not dunder, symbols
    def no_hidden(symbols: Iterable[str]) -> set[str]:
        return {
            symbol
            for symbol in symbols
            if (not symbol.startswith("_")) or symbol.startswith("__")
        }

    if tool == "jedi" and sys.implementation.name != "cpython":
        pytest.skip("jedi does not support pypy")

    if tool == "mypy":
        cache = Path.cwd() / ".mypy_cache"

        _ensure_mypy_cache_updated()

        # Locate the cache file for this module: a package is stored as
        # `<dir>/__init__.data.json`, a plain module as `<name>.data.json`.
        trio_cache = next(cache.glob("*/trio"))
        modname = module_name
        _, modname = (modname + ".").split(".", 1)
        modname = modname[:-1]
        mod_cache = trio_cache / modname if modname else trio_cache
        if mod_cache.is_dir():
            mod_cache = mod_cache / "__init__.data.json"
        else:
            mod_cache = trio_cache / (modname + ".data.json")

        assert mod_cache.exists()
        assert mod_cache.is_file()
        with mod_cache.open() as cache_file:
            cache_json = json.loads(cache_file.read())

        # Resolve a fully qualified symbol ("pkg.mod.Name") to its entry in the
        # "names" table of the cache file belonging to its module.
        # Memoized to skip a bunch of file-system activity (probably can
        # un-memoize?)
        @functools.lru_cache
        def lookup_symbol(symbol: str) -> dict[str, str]:
            topname, *modname, name = symbol.split(".")
            version = next(cache.glob("3.*/"))
            mod_cache = version / topname
            if not mod_cache.is_dir():
                mod_cache = version / (topname + ".data.json")

            if modname:
                for piece in modname[:-1]:
                    mod_cache /= piece
                next_cache = mod_cache / modname[-1]
                if next_cache.is_dir():  # pragma: no coverage
                    mod_cache = next_cache / "__init__.data.json"
                else:
                    mod_cache = mod_cache / (modname[-1] + ".data.json")
            elif mod_cache.is_dir():
                mod_cache /= "__init__.data.json"
            with mod_cache.open() as f:
                return json.loads(f.read())["names"][name]  # type: ignore[no-any-return]

    # Maps "module.Class" to its leftover {"missing": ..., "extra": ...} sets.
    errors: dict[str, object] = {}
    for class_name, class_ in module.__dict__.items():
        if not isinstance(class_, type):
            continue
        # Skip names in trio.socket that also exist in the stdlib socket module.
        if module_name == "trio.socket" and class_name in dir(stdlib_socket):
            continue

        # Ignore classes that don't use attrs, they only define their members once
        # __init__ is called (and the reason they don't use attrs is because
        # they're going to be reimplemented in pytest).
        # Not 100% that's the case, and it works locally, so whatever /shrug
        if class_ is trio.testing.RaisesGroup or class_ is trio.testing.Matcher:
            continue

        # dir() and inspect.getmembers don't display properties from the metaclass
        # also ignore some dunder methods that tend to differ but are of no consequence
        ignore_names = set(dir(type(class_))) | {
            "__annotations__",
            "__attrs_attrs__",
            "__attrs_own_setattr__",
            "__callable_proto_members_only__",
            "__class_getitem__",
            "__final__",
            "__getstate__",
            "__match_args__",
            "__order__",
            "__orig_bases__",
            "__parameters__",
            "__protocol_attrs__",
            "__setstate__",
            "__slots__",
            "__weakref__",
            # ignore errors about dunders inherited from stdlib that tools might
            # not see
            "__copy__",
            "__deepcopy__",
        }

        if type(class_) is type:
            # C extension classes don't have these dunders, but Python classes do
            ignore_names.add("__firstlineno__")
            ignore_names.add("__static_attributes__")

        # pypy seems to have some additional dunders that differ
        if sys.implementation.name == "pypy":
            ignore_names |= {
                "__basicsize__",
                "__dictoffset__",
                "__itemsize__",
                "__sizeof__",
                "__weakrefoffset__",
                "__unicode__",
            }

        # inspect.getmembers sees `name` and `value` in Enums, otherwise
        # it behaves the same way as `dir`
        # runtime_names = no_underscores(dir(class_))
        runtime_names = (
            no_hidden(x[0] for x in inspect.getmembers(class_)) - ignore_names
        )

        if tool == "jedi":
            try:
                import jedi
            except ImportError as error:
                skip_if_optional_else_raise(error)

            # Simulate typing "from <module> import <Class>; <Class>.<TAB>"
            script = jedi.Script(
                f"from {module_name} import {class_name}; {class_name}.",
            )
            completions = script.complete()
            static_names = no_hidden(c.name for c in completions) - ignore_names

        elif tool == "mypy":
            # load the cached type information
            cached_type_info = cache_json["names"][class_name]
            assert (
                "node" not in cached_type_info
            ), "previously this was an 'if' but it seems it's no longer possible for this cache to contain 'node', if this assert raises for you please let us know!"
            cached_type_info = lookup_symbol(cached_type_info["cross_ref"])

            assert "node" in cached_type_info
            node = cached_type_info["node"]
            static_names = no_hidden(
                k for k in node.get("names", ()) if not k.startswith(".")
            )
            # Add the names of every other class in the MRO; the first entry is
            # skipped because its names were collected just above.
            for symbol in node["mro"][1:]:
                node = lookup_symbol(symbol)["node"]
                static_names |= no_hidden(
                    k for k in node.get("names", ()) if not k.startswith(".")
                )
            static_names -= ignore_names

        else:  # pragma: no cover
            raise AssertionError("unknown tool")

        missing = runtime_names - static_names
        extra = static_names - runtime_names

        # The known discrepancies below are dropped with .remove() rather than
        # .discard(), so that we get an error (KeyError) if one of them stops
        # being a discrepancy and the workaround can be deleted.

        if (
            tool == "jedi"
            and BaseException in class_.__mro__
            and sys.version_info >= (3, 11)
        ):
            missing.remove("add_note")

        if (
            tool == "mypy"
            and BaseException in class_.__mro__
            and sys.version_info >= (3, 11)
        ):
            extra.remove("__notes__")

        if tool == "mypy" and attrs.has(class_):
            # e.g. __trio__core__run_CancelScope_AttrsAttributes__
            # Exactly one such name is expected per attrs class.
            before = len(extra)
            extra = {e for e in extra if not e.endswith("AttrsAttributes__")}
            assert len(extra) == before - 1

        # mypy does not see these attributes in Enum subclasses
        if (
            tool == "mypy"
            and enum.Enum in class_.__mro__
            and sys.version_info >= (3, 12)
        ):
            # Another attribute, in 3.12+ only.
            extra.remove("__signature__")

        # TODO: this *should* be visible via `dir`!!
        if tool == "mypy" and class_ == trio.Nursery:
            extra.remove("cancel_scope")

        # These are (mostly? solely?) *runtime* attributes, often set in
        # __init__, which doesn't show up with dir() or inspect.getmembers,
        # but we get them in the way we query mypy & jedi
        EXTRAS = {
            trio.DTLSChannel: {"peer_address", "endpoint"},
            trio.DTLSEndpoint: {"socket", "incoming_packets_buffer"},
            trio.Process: {"args", "pid", "stderr", "stdin", "stdio", "stdout"},
            trio.SSLListener: {"transport_listener"},
            trio.SSLStream: {"transport_stream"},
            trio.SocketListener: {"socket"},
            trio.SocketStream: {"socket"},
            trio.testing.MemoryReceiveStream: {"close_hook", "receive_some_hook"},
            trio.testing.MemorySendStream: {
                "close_hook",
                "send_all_hook",
                "wait_send_all_might_not_block_hook",
            },
        }
        if tool == "mypy" and class_ in EXTRAS:
            # The length check ensures every listed name really was in `extra`.
            before = len(extra)
            extra -= EXTRAS[class_]
            assert len(extra) == before - len(EXTRAS[class_])

        # TODO: why is this? Is it a problem?
        # see https://github.com/python-trio/trio/pull/2631#discussion_r1185615916
        if class_ == trio.StapledStream:
            extra.remove("receive_stream")
            extra.remove("send_stream")

        # I have not researched why these are missing, should maybe create an issue
        # upstream with jedi
        if tool == "jedi" and sys.version_info >= (3, 11):
            if class_ in (
                trio.DTLSChannel,
                trio.MemoryReceiveChannel,
                trio.MemorySendChannel,
                trio.SSLListener,
                trio.SocketListener,
            ):
                missing.remove("__aenter__")
                missing.remove("__aexit__")
            if class_ in (trio.DTLSChannel, trio.MemoryReceiveChannel):
                missing.remove("__aiter__")
                missing.remove("__anext__")

        if class_ in (trio.Path, trio.WindowsPath, trio.PosixPath):
            # These are names defined on PurePath.
            missing -= PurePath.__dict__.keys()
            # These are unix-only.
            if tool == "mypy" and sys.platform == "win32":
                missing -= {"owner", "is_mount", "group"}
            if tool == "jedi" and sys.platform == "win32":
                extra -= {"owner", "is_mount", "group"}

        # not sure why jedi in particular ignores this (static?) method in 3.13
        # (especially given the method is from 3.12....)
        if (
            tool == "jedi"
            and sys.version_info >= (3, 13)
            and class_ in (trio.Path, trio.WindowsPath, trio.PosixPath)
        ):
            missing.remove("with_segments")

        if sys.version_info >= (3, 13) and attrs.has(class_):
            missing.remove("__replace__")

        if missing or extra:  # pragma: no cover
            errors[f"{module_name}.{class_name}"] = {
                "missing": missing,
                "extra": extra,
            }

    # `assert not errors` will not print the full content of errors, even with
    # `--verbose`, so we manually print it
    if errors:  # pragma: no cover
        from pprint import pprint

        print(f"\n{tool} can't see the following symbols in {module_name}:")
        pprint(errors)
    assert not errors


def test_nopublic_is_final() -> None:
    """Verify that every class built by ``NoPublicConstructor`` is final too."""
    # The metaclass itself must not be subclassable either.
    assert class_is_final(_util.NoPublicConstructor)  # This is itself final.

    for module in ALL_MODULES:
        no_public_classes = (
            member
            for member in module.__dict__.values()
            if isinstance(member, _util.NoPublicConstructor)
        )
        for class_ in no_public_classes:
            assert class_is_final(class_)


def test_classes_are_final() -> None:
    """Verify that public classes reject subclassing unless explicitly exempt."""
    # Sanity checks.
    assert not class_is_final(object)
    assert class_is_final(bool)

    for module in PUBLIC_MODULES:
        for name, class_ in module.__dict__.items():
            if not isinstance(class_, type):
                continue

            # Evaluated in order; the first matching exemption wins.
            may_be_subclassed = (
                # Deprecated classes are exported with a leading underscore
                name.startswith("_")
                # Abstract classes can be subclassed, because that's the whole
                # point of ABCs
                or inspect.isabstract(class_)
                # Same with protocols, but only direct children.
                or Protocol in class_.__bases__
                or Protocol_ext in class_.__bases__
                # Exceptions are allowed to be subclassed, because exception
                # subclassing isn't used to inherit behavior.
                or issubclass(class_, BaseException)
                # These are classes that are conceptually abstract, but
                # inspect.isabstract returns False for boring reasons.
                or class_ is trio.abc.Instrument
                or class_ is trio.socket.SocketType
                # ... insert other special cases here ...
                # The `Path` class needs to support inheritance to allow
                # `WindowsPath` and `PosixPath`.
                or class_ is trio.Path
                # don't care about the *Statistics classes
                or name.endswith("Statistics")
            )
            if not may_be_subclassed:
                assert class_is_final(class_)


# Plugin might not be running, especially if running from an installed version.
@pytest.mark.skipif(
    not hasattr(attrs.field, "trio_modded"),
    reason="Pytest plugin not installed.",
)
def test_pyright_recognizes_init_attributes() -> None:
    """Check that underscore-prefixed init attributes declare an explicit `alias`.

    Attrs fills in `alias` on every field by itself, so a pytest plugin
    monkeypatches `field()` to record in the field metadata whether an alias
    was actually passed. See `_trio_check_attrs_aliases`.
    """
    for module in PUBLIC_MODULES:
        for class_ in module.__dict__.values():
            if not attrs.has(class_):
                continue
            if isinstance(class_, _util.NoPublicConstructor):
                continue

            # Init fields whose alias is neither the attribute's own name nor
            # the alias that was recorded by the plugin.
            unaliased = []
            for attr in attrs.fields(class_):
                if not attr.init:
                    continue
                # trio_original_args may not be present in autoattribs
                recorded_alias = attr.metadata.get("trio_original_args", {}).get(
                    "alias",
                )
                if attr.alias not in (attr.name, recorded_alias):
                    unaliased.append(attr)

            assert unaliased == [], class_
