#! /usr/bin/env python3
"""
Code generation script for class methods
to be exported as public API
"""
from __future__ import annotations

import argparse
import ast
import os
import subprocess
import sys
from pathlib import Path
from textwrap import indent
from typing import TYPE_CHECKING

import attrs

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator

    from typing_extensions import TypeGuard

# keep these imports up to date with conditional imports in test_gen_exports
# isort: split
import astor

# Prefix prepended to a scanned module's file name to name the generated
# module written next to it (e.g. ``_run.py`` -> ``_generated_run.py``).
PREFIX = "_generated"

# Preamble emitted at the top of every generated module; each File's extra
# ``imports`` string is appended directly after it.
HEADER = """# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from __future__ import annotations

import sys

from ._ki import enable_ki_protection
from ._run import GLOBAL_RUN_CONTEXT
"""

# Body of every generated wrapper function. The ``{}`` placeholders are, in
# order: the separator after ``return`` (" await " for async methods, " "
# otherwise), the attribute path on GLOBAL_RUN_CONTEXT (``File.modname``),
# and the method name followed by its pass-through argument list.
# NOTE: the ``try`` also encloses the forwarded call itself, so an
# AttributeError raised from inside the wrapped method is reported as the
# same "must be called from async context" RuntimeError.
TEMPLATE = """try:
    return{}GLOBAL_RUN_CONTEXT.{}.{}
except AttributeError:
    raise RuntimeError("must be called from async context") from None
"""


@attrs.define
class File:
    """A source module to scan for ``@_public`` methods, plus the settings
    used when generating its wrapper module."""

    # Path of the module to scan; the generated module is written to the same
    # directory with PREFIX prepended to the file name.
    path: Path
    # Attribute path on GLOBAL_RUN_CONTEXT that the generated wrappers
    # forward to (e.g. "runner" or "runner.io_manager").
    modname: str
    # Optional ``sys.platform`` value. When non-empty, the generated module
    # gets ``assert not TYPE_CHECKING or sys.platform == ...`` so type
    # checkers only consider it on that platform.
    platform: str = attrs.field(default="", kw_only=True)
    # Extra import lines spliced into the generated module right after HEADER.
    imports: str = attrs.field(default="", kw_only=True)


def is_function(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]:
    """Return True when *node* is a ``def`` or an ``async def`` statement.

    Subclasses of the two AST node types are accepted as well, since the
    check is done with ``isinstance``.
    """
    function_node_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    return isinstance(node, function_node_types)


def is_public(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]:
    """Return True if *node* is a function decorated with a bare ``@_public``.

    Only a plain name decorator matches; attribute forms such as
    ``@mod._public`` or calls such as ``@_public()`` are not recognised.
    """
    if not is_function(node):
        return False
    return any(
        isinstance(decorator, ast.Name) and decorator.id == "_public"
        for decorator in node.decorator_list
    )


def get_public_methods(
    tree: ast.AST,
) -> Iterator[ast.FunctionDef | ast.AsyncFunctionDef]:
    """Yield every function in *tree* that is decorated with ``@_public``.

    The whole tree is traversed with :func:`ast.walk`, so nested definitions
    (e.g. methods inside a class body) are found as well. This is a lazy
    generator, not a list; nodes are produced in ``ast.walk`` order.
    """
    yield from filter(is_public, ast.walk(tree))


def create_passthrough_args(funcdef: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
    """Given a function definition, create a string that represents taking all
    the arguments from the function, and passing them through to another
    invocation of the same function.

    Positional-only and ordinary positional parameters are forwarded
    positionally, ``*args`` / ``**kwargs`` are re-splatted, and keyword-only
    parameters are forwarded as ``name=name``.

    Example input: ast.parse("def f(a, *, b): ...")
    Example output: "(a, b=b)"

    Example input: ast.parse("def f(a, /, b, *c, d, **e): ...")
    Example output: "(a, b, *c, d=d, **e)"
    """
    signature = funcdef.args
    # Parameters before "/" live in ``posonlyargs``, not ``args``; they must be
    # forwarded too (and first), or the generated call silently drops them.
    call_args = [arg.arg for arg in (*signature.posonlyargs, *signature.args)]
    if signature.vararg:
        call_args.append("*" + signature.vararg.arg)
    call_args.extend(f"{arg.arg}={arg.arg}" for arg in signature.kwonlyargs)
    if signature.kwarg:
        call_args.append("**" + signature.kwarg.arg)
    return "({})".format(", ".join(call_args))


def run_black(file: File, source: str) -> tuple[bool, str]:
    """Format *source* with black, treating it as the contents of ``file.path``.

    Returns:
      A ``(success, text)`` pair: on success ``text`` is the formatted
      source, otherwise it is an error message including black's stderr.
      ex.:
        (False, "Failed to run black!\nerror: cannot format ...")
        (True, "<formatted source>")

    Raises:
      ImportError: If black is not installed.
    """
    # Imported only so that a missing black fails fast with ImportError here,
    # instead of as an opaque failure of the subprocess below.
    import black  # noqa: F401

    # black's in-process API is undocumented and does not easily combine
    # reading pyproject.toml with string input/output
    # (https://github.com/psf/black/issues/779), so run it as a subprocess.
    # A filename of "-" makes black read stdin and write stdout;
    # --stdin-filename tells it which path that input belongs to.
    command = [sys.executable, "-m", "black", "--stdin-filename", file.path, "-"]
    completed = subprocess.run(
        command,
        input=source,
        capture_output=True,
        encoding="utf8",
    )

    if completed.returncode == 0:
        return True, completed.stdout
    return False, f"Failed to run black!\n{completed.stderr}"


def run_ruff(file: File, source: str) -> tuple[bool, str]:
    """Apply ``ruff check --fix --unsafe-fixes`` to *source* as ``file.path``.

    Returns:
      A ``(success, text)`` pair: on success ``text`` is the fixed source,
      otherwise it is an error message including ruff's stderr.
      ex.:
        (False, "Failed to run ruff!\nerror: Failed to parse ...")
        (True, "<formatted source>")

    Raises:
      ImportError: If ruff is not installed.
    """
    # Imported only so that a missing ruff fails fast with ImportError here,
    # instead of as an opaque failure of the subprocess below.
    import ruff  # noqa: F401

    # A filename of "-" makes ruff read stdin and write the result to stdout;
    # --stdin-filename tells it which path that input belongs to.
    ruff_flags = ("check", "--fix", "--unsafe-fixes", "--stdin-filename")
    command = [sys.executable, "-m", "ruff", *ruff_flags, file.path, "-"]
    completed = subprocess.run(
        command,
        input=source,
        capture_output=True,
        encoding="utf8",
    )

    if completed.returncode == 0:
        return True, completed.stdout
    return False, f"Failed to run ruff!\n{completed.stderr}"


def run_linters(file: File, source: str) -> str:
    """Pipe *source* through black and then ruff, in that order.

    Each tool receives the previous tool's output.

    Returns:
      Formatted source code.

    Raises:
      ImportError: If either tool is not installed.
      SystemExit: If either tool fails; its error message is printed first.
    """
    formatted = source
    for linter in (run_black, run_ruff):
        ok, output = linter(file, formatted)
        if not ok:
            # ``output`` holds the error message on failure.
            print(output)
            sys.exit(1)
        formatted = output
    return formatted


def gen_public_wrappers_source(file: File) -> str:
    """Scan the given .py file for @_public decorators, and generate wrapper
    functions.

    For every function in ``file.path`` decorated with a bare ``@_public``,
    emit a module-level function with the same signature (minus ``self``) and
    docstring, decorated with ``enable_ki_protection``, whose body forwards
    the call to ``GLOBAL_RUN_CONTEXT.<file.modname>.<name>(...)`` (see
    TEMPLATE).

    Args:
      file: The module to scan, together with the GLOBAL_RUN_CONTEXT
        attribute path, extra imports and optional platform restriction to
        use in the generated code.

    Returns:
      The not-yet-formatted source of the generated module: HEADER plus the
      extra imports, a sorted ``__all__``, then one wrapper per public
      method, joined with blank lines.
    """
    header = [HEADER]
    header.append(file.imports)
    if file.platform:
        # Simple checks to avoid repeating imports. If this messes up, type checkers/tests will
        # just give errors.
        if "TYPE_CHECKING" not in file.imports:
            header.append("from typing import TYPE_CHECKING\n")
        if "import sys" not in file.imports:  # pragma: no cover
            header.append("import sys\n")
        # At runtime TYPE_CHECKING is False so this never fires; type checkers
        # use it to restrict the generated module to ``file.platform``.
        header.append(
            f'\nassert not TYPE_CHECKING or sys.platform=="{file.platform}"\n',
        )

    generated = ["".join(header)]

    source = astor.code_to_ast.parse_file(file.path)
    method_names = []
    for method in get_public_methods(source):
        # Remove self from arguments
        # (the wrapper is a plain function; the assert checks ``self`` is the
        # first ordinary positional parameter).
        assert method.args.args[0].arg == "self"
        del method.args.args[0]
        method_names.append(method.name)

        # Detect a bare ``@contextmanager`` decorator so the return annotation
        # can be rewritten further down.
        for dec in method.decorator_list:  # pragma: no cover
            if isinstance(dec, ast.Name) and dec.id == "contextmanager":
                is_cm = True
                break
        else:
            is_cm = False

        # Remove decorators
        # (all of them, including @_public; the wrapper only gets KI protection).
        method.decorator_list = [ast.Name("enable_ki_protection")]

        # Create pass through arguments
        new_args = create_passthrough_args(method)

        # Remove method body without the docstring
        if ast.get_docstring(method) is None:
            del method.body[:]
        else:
            # The first entry is always the docstring
            del method.body[1:]

        # Create the function definition including the body
        # (the body is now empty or just the docstring; TEMPLATE is appended below).
        func = astor.to_source(method, indent_with=" " * 4)

        if is_cm:  # pragma: no cover
            # The decorated generator is annotated as returning an Iterator, but
            # callers of the wrapper get a context manager. This plain text
            # replace assumes astor renders the annotation as "->Iterator" with
            # no spaces.
            func = func.replace("->Iterator", "->AbstractContextManager")

        # Create export function body
        template = TEMPLATE.format(
            " await " if isinstance(method, ast.AsyncFunctionDef) else " ",
            file.modname,
            method.name + new_args,
        )

        # Assemble function definition arguments and body
        snippet = func + indent(template, " " * 4)

        # Append the snippet to the corresponding module
        generated.append(snippet)

    method_names.sort()
    # Insert after the header, before function definitions
    generated.insert(1, f"__all__ = {method_names!r}")
    return "\n\n".join(generated)


def matches_disk_files(new_files: dict[str, str]) -> bool:
    """Report whether every generated file already exists on disk unchanged.

    Args:
      new_files: Mapping of output path to freshly generated source.

    Returns:
      True if each path exists and its UTF-8 text equals the new source (an
      empty mapping trivially matches); False as soon as one path is missing
      or differs.
    """
    return all(
        os.path.exists(path) and Path(path).read_text(encoding="utf-8") == expected
        for path, expected in new_files.items()
    )


def process(files: Iterable[File], *, do_test: bool) -> None:
    """Generate (and lint) the wrapper module for each of *files*.

    With ``do_test`` the results are only compared against what is on disk;
    the process exits with status 1 if anything is outdated. Otherwise every
    generated module is written out. The process still exits with status 1
    if anything changed, so pre-commit reports the edit.
    """
    generated: dict[str, str] = {}
    for file in files:
        print("Scanning:", file.path)
        formatted = run_linters(file, gen_public_wrappers_source(file))
        directory, filename = os.path.split(file.path)
        generated[os.path.join(directory, PREFIX + filename)] = formatted

    up_to_date = matches_disk_files(generated)

    if do_test:
        if up_to_date:
            print("Generated sources are up to date.")
            return
        print("Generated sources are outdated. Please regenerate.")
        sys.exit(1)

    for target, text in generated.items():
        # Force "\n" line endings so output is identical on every platform.
        with open(target, "w", encoding="utf-8", newline="\n") as handle:
            handle.write(text)
    print("Regenerated sources successfully.")
    if not up_to_date:  # TODO: test this branch
        # With pre-commit integration, show that we edited files.
        sys.exit(1)


# This is in fact run in CI, but only in the formatting check job, which
# doesn't collect coverage.
def main() -> None:  # pragma: no cover
    """Command-line entry point.

    Regenerates the ``_generated*.py`` wrapper modules next to the scanned
    modules in ``src/trio/_core``, or with ``--test`` only checks that they
    are up to date. Must be run from the repository root.
    """
    parser = argparse.ArgumentParser(
        description="Generate python code for public api wrappers",
    )
    parser.add_argument(
        "--test",
        "-t",
        action="store_true",
        help="test if code is still up to date",
    )
    parsed_args = parser.parse_args()

    source_root = Path.cwd()
    # Double-check we found the right directory. This is an explicit check
    # rather than an ``assert``, so it still runs under ``python -O`` and
    # reports a readable usage error instead of a bare AssertionError.
    if not (source_root / "LICENSE").exists():
        parser.error(
            f"must be run from the repository root (no LICENSE file in {source_root})",
        )
    core = source_root / "src/trio/_core"
    to_wrap = [
        File(core / "_run.py", "runner", imports=IMPORTS_RUN),
        File(
            core / "_instrumentation.py",
            "runner.instruments",
            imports=IMPORTS_INSTRUMENT,
        ),
        File(
            core / "_io_windows.py",
            "runner.io_manager",
            platform="win32",
            imports=IMPORTS_WINDOWS,
        ),
        File(
            core / "_io_epoll.py",
            "runner.io_manager",
            platform="linux",
            imports=IMPORTS_EPOLL,
        ),
        File(
            core / "_io_kqueue.py",
            "runner.io_manager",
            platform="darwin",
            imports=IMPORTS_KQUEUE,
        ),
    ]

    process(to_wrap, do_test=parsed_args.test)


# Extra imports for each generated module, passed as File(imports=...) in
# main() and spliced in right after HEADER. They must make available the
# names used in the wrapper signatures copied from the matching source module.
# NOTE(review): names used only in annotations can sit under TYPE_CHECKING
# because HEADER enables postponed annotations; anything evaluated at runtime
# in a signature (e.g. default values) presumably needs a real import.

# For wrappers generated from _run.py.
IMPORTS_RUN = """\
from collections.abc import Awaitable, Callable
from typing import Any, TYPE_CHECKING

from outcome import Outcome
import contextvars

from ._run import _NO_SEND, RunStatistics, Task
from ._entry_queue import TrioToken
from .._abc import Clock

if TYPE_CHECKING:
    from typing_extensions import Unpack
    from ._run import PosArgT
"""
# For wrappers generated from _instrumentation.py.
IMPORTS_INSTRUMENT = """\
from ._instrumentation import Instrument
"""

# For wrappers generated from _io_epoll.py (platform "linux").
IMPORTS_EPOLL = """\
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .._file_io import _HasFileNo
"""

# For wrappers generated from _io_kqueue.py (platform "darwin").
# AbstractContextManager is needed because @contextmanager wrappers have
# their "->Iterator" annotation rewritten to "->AbstractContextManager".
IMPORTS_KQUEUE = """\
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import select
    from collections.abc import Callable
    from contextlib import AbstractContextManager

    from .. import _core
    from .._file_io import _HasFileNo
    from ._traps import Abort, RaiseCancelT
"""

# For wrappers generated from _io_windows.py (platform "win32").
IMPORTS_WINDOWS = """\
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from contextlib import AbstractContextManager

    from typing_extensions import Buffer

    from .._file_io import _HasFileNo
    from ._unbounded_queue import UnboundedQueue
    from ._windows_cffi import Handle, CData
"""


# Only run the generator when executed as a script, not when this module is
# imported (e.g. by test_gen_exports).
if __name__ == "__main__":  # pragma: no cover
    main()
