from __future__ import annotations

import contextvars
import queue as stdlib_queue
import re
import sys
import threading
import time
import weakref
from functools import partial
from typing import (
    TYPE_CHECKING,
    NoReturn,
    TypeVar,
    Union,
)

import pytest
import sniffio

from .. import (
    CancelScope,
    CapacityLimiter,
    Event,
    _core,
    fail_after,
    move_on_after,
    sleep,
    sleep_forever,
)
from .._core._tests.test_ki import ki_self
from .._core._tests.tutil import slow
from .._threads import (
    active_thread_count,
    current_default_thread_limiter,
    from_thread_check_cancelled,
    from_thread_run,
    from_thread_run_sync,
    to_thread_run_sync,
    wait_all_threads_completed,
)
from ..testing import wait_all_tasks_blocked

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator, Awaitable, Callable

    from outcome import Outcome

    from ..lowlevel import Task

RecordType = list[tuple[str, Union[threading.Thread, type[BaseException]]]]
T = TypeVar("T")


async def test_do_in_trio_thread() -> None:
    trio_thread = threading.current_thread()

    async def check_case(  # type: ignore[explicit-any]
        do_in_trio_thread: Callable[..., threading.Thread],
        fn: Callable[..., T | Awaitable[T]],
        expected: tuple[str, T],
        trio_token: _core.TrioToken | None = None,
    ) -> None:
        record: RecordType = []

        def threadfn() -> None:
            try:
                record.append(("start", threading.current_thread()))
                x = do_in_trio_thread(fn, record, trio_token=trio_token)
                record.append(("got", x))
            except BaseException as exc:
                print(exc)
                record.append(("error", type(exc)))

        child_thread = threading.Thread(target=threadfn, daemon=True)
        child_thread.start()
        while child_thread.is_alive():
            print("yawn")
            await sleep(0.01)
        assert record == [("start", child_thread), ("f", trio_thread), expected]

    token = _core.current_trio_token()

    def f1(record: RecordType) -> int:
        assert not _core.currently_ki_protected()
        record.append(("f", threading.current_thread()))
        return 2

    await check_case(from_thread_run_sync, f1, ("got", 2), trio_token=token)

    def f2(record: RecordType) -> NoReturn:
        assert not _core.currently_ki_protected()
        record.append(("f", threading.current_thread()))
        raise ValueError

    await check_case(from_thread_run_sync, f2, ("error", ValueError), trio_token=token)

    async def f3(record: RecordType) -> int:
        assert not _core.currently_ki_protected()
        await _core.checkpoint()
        record.append(("f", threading.current_thread()))
        return 3

    await check_case(from_thread_run, f3, ("got", 3), trio_token=token)

    async def f4(record: RecordType) -> NoReturn:
        assert not _core.currently_ki_protected()
        await _core.checkpoint()
        record.append(("f", threading.current_thread()))
        raise KeyError

    await check_case(from_thread_run, f4, ("error", KeyError), trio_token=token)


async def test_do_in_trio_thread_from_trio_thread() -> None:
    with pytest.raises(RuntimeError):
        from_thread_run_sync(lambda: None)  # pragma: no branch

    async def foo() -> None:  # pragma: no cover
        pass

    with pytest.raises(RuntimeError):
        from_thread_run(foo)


def test_run_in_trio_thread_ki() -> None:
    # if we get a control-C during a run_in_trio_thread, then it propagates
    # back to the caller (slick!)
    record = set()

    async def check_run_in_trio_thread() -> None:
        token = _core.current_trio_token()

        def trio_thread_fn() -> None:
            print("in Trio thread")
            assert not _core.currently_ki_protected()
            print("ki_self")
            try:
                ki_self()
            finally:
                import sys

                print("finally", sys.exc_info())

        async def trio_thread_afn() -> None:
            trio_thread_fn()

        def external_thread_fn() -> None:
            try:
                print("running")
                from_thread_run_sync(trio_thread_fn, trio_token=token)
            except KeyboardInterrupt:
                print("ok1")
                record.add("ok1")
            try:
                from_thread_run(trio_thread_afn, trio_token=token)
            except KeyboardInterrupt:
                print("ok2")
                record.add("ok2")

        thread = threading.Thread(target=external_thread_fn)
        thread.start()
        print("waiting")
        while thread.is_alive():  # noqa: ASYNC110
            await sleep(0.01)  # Fine to poll in tests.
        print("waited, joining")
        thread.join()
        print("done")

    _core.run(check_run_in_trio_thread)
    assert record == {"ok1", "ok2"}


def test_await_in_trio_thread_while_main_exits() -> None:
    record = []
    ev = Event()

    async def trio_fn() -> None:
        record.append("sleeping")
        ev.set()
        await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)

    def thread_fn(token: _core.TrioToken) -> None:
        try:
            from_thread_run(trio_fn, trio_token=token)
        except _core.Cancelled:
            record.append("cancelled")

    async def main() -> threading.Thread:
        token = _core.current_trio_token()
        thread = threading.Thread(target=thread_fn, args=(token,))
        thread.start()
        await ev.wait()
        assert record == ["sleeping"]
        return thread

    thread = _core.run(main)
    thread.join()
    assert record == ["sleeping", "cancelled"]


async def test_named_thread() -> None:
    ending = " from trio._tests.test_threads.test_named_thread"

    def inner(name: str = "inner" + ending) -> threading.Thread:
        assert threading.current_thread().name == name
        return threading.current_thread()

    def f(name: str) -> Callable[[], threading.Thread]:
        return partial(inner, name)

    # test defaults
    await to_thread_run_sync(inner)
    await to_thread_run_sync(inner, thread_name=None)

    # functools.partial doesn't have __name__, so defaults to None
    await to_thread_run_sync(f("None" + ending))

    # test that you can set a custom name, and that it's reset afterwards
    async def test_thread_name(name: str) -> None:
        thread = await to_thread_run_sync(f(name), thread_name=name)
        assert re.match(r"Trio thread [0-9]*", thread.name)

    await test_thread_name("")
    await test_thread_name("fobiedoo")
    await test_thread_name("name_longer_than_15_characters")

    await test_thread_name("💙")


def _get_thread_name(ident: int | None = None) -> str | None:
    import ctypes
    import ctypes.util

    libpthread_path = ctypes.util.find_library("pthread")
    if not libpthread_path:
        # musl includes pthread functions directly in libc.so
        # (but note that find_library("c") does not work on musl,
        #  see: https://github.com/python/cpython/issues/65821)
        # so try that library instead
        # if it doesn't exist, CDLL() will fail below
        libpthread_path = "libc.so"
    try:
        libpthread = ctypes.CDLL(libpthread_path)
    except Exception:
        print(f"no pthread on {sys.platform}")
        return None

    pthread_getname_np = getattr(libpthread, "pthread_getname_np", None)

    # this should never fail on any platforms afaik
    assert pthread_getname_np

    # thankfully getname signature doesn't differ between platforms
    pthread_getname_np.argtypes = [
        ctypes.c_void_p,
        ctypes.c_char_p,
        ctypes.c_size_t,
    ]
    pthread_getname_np.restype = ctypes.c_int

    name_buffer = ctypes.create_string_buffer(b"", size=16)
    if ident is None:
        ident = threading.get_ident()
    assert pthread_getname_np(ident, name_buffer, 16) == 0
    try:
        return name_buffer.value.decode()
    except UnicodeDecodeError as e:  # pragma: no cover
        # used for debugging when testing via CI
        pytest.fail(f"value: {name_buffer.value!r}, exception: {e}")


# test os thread naming
# this depends on pthread being available, which is the case on 99.9% of linux machines
# and most mac machines. So unless the platform is linux it will just skip
# in case it fails to fetch the os thread name.
async def test_named_thread_os() -> None:
    def inner(name: str) -> threading.Thread:
        os_thread_name = _get_thread_name()
        if os_thread_name is None and sys.platform != "linux":
            pytest.skip(f"no pthread OS support on {sys.platform}")
        else:
            assert os_thread_name == name[:15]

        return threading.current_thread()

    def f(name: str) -> Callable[[], threading.Thread]:
        return partial(inner, name)

    # test defaults
    default = "None from trio._tests.test_threads.test_named_thread"
    await to_thread_run_sync(f(default))
    await to_thread_run_sync(f(default), thread_name=None)

    # test that you can set a custom name, and that it's reset afterwards
    async def test_thread_name(name: str, expected: str | None = None) -> None:
        if expected is None:
            expected = name
        thread = await to_thread_run_sync(f(expected), thread_name=name)

        os_thread_name = _get_thread_name(thread.ident)
        assert os_thread_name is not None, "should skip earlier if this is the case"
        assert re.match(r"Trio thread [0-9]*", os_thread_name)

    await test_thread_name("")
    await test_thread_name("fobiedoo")
    await test_thread_name("name_longer_than_15_characters")

    await test_thread_name("💙", expected="?")


def test_has_pthread_setname_np() -> None:
    from trio._core._thread_cache import get_os_thread_name_func

    k = get_os_thread_name_func()
    if k is None:
        assert sys.platform != "linux"
        pytest.skip(f"no pthread_setname_np on {sys.platform}")


async def test_run_in_worker_thread() -> None:
    trio_thread = threading.current_thread()

    def f(x: T) -> tuple[T, threading.Thread]:
        return (x, threading.current_thread())

    x, child_thread = await to_thread_run_sync(f, 1)
    assert x == 1
    assert child_thread != trio_thread

    def g() -> NoReturn:
        raise ValueError(threading.current_thread())

    with pytest.raises(
        ValueError,
        match=r"^<Thread\(Trio thread \d+, started daemon \d+\)>$",
    ) as excinfo:
        await to_thread_run_sync(g)
    print(excinfo.value.args)
    assert excinfo.value.args[0] != trio_thread


async def test_run_in_worker_thread_cancellation() -> None:
    register: list[str | None] = [None]

    def f(q: stdlib_queue.Queue[None]) -> None:
        # Make the thread block for a controlled amount of time
        register[0] = "blocking"
        q.get()
        register[0] = "finished"

    async def child(q: stdlib_queue.Queue[None], abandon_on_cancel: bool) -> None:
        record.append("start")
        try:
            return await to_thread_run_sync(f, q, abandon_on_cancel=abandon_on_cancel)
        finally:
            record.append("exit")

    record: list[str] = []
    q: stdlib_queue.Queue[None] = stdlib_queue.Queue()
    async with _core.open_nursery() as nursery:
        nursery.start_soon(child, q, True)
        # Give it a chance to get started. (This is important because
        # to_thread_run_sync does a checkpoint_if_cancelled before
        # blocking on the thread, and we don't want to trigger this.)
        await wait_all_tasks_blocked()
        assert record == ["start"]
        # Then cancel it.
        nursery.cancel_scope.cancel()
    # The task exited, but the thread didn't:
    assert register[0] != "finished"
    # Put the thread out of its misery:
    q.put(None)
    while register[0] != "finished":
        time.sleep(0.01)  # noqa: ASYNC251  # Need to wait for OS thread

    # This one can't be cancelled
    record = []
    register[0] = None
    async with _core.open_nursery() as nursery:
        nursery.start_soon(child, q, False)
        await wait_all_tasks_blocked()
        nursery.cancel_scope.cancel()
        with _core.CancelScope(shield=True):
            for _ in range(10):
                await _core.checkpoint()
        # It's still running
        assert record == ["start"]
        q.put(None)
        # Now it exits

    # But if we cancel *before* it enters, the entry is itself a cancellation
    # point
    with _core.CancelScope() as scope:
        scope.cancel()
        await child(q, False)
    assert scope.cancelled_caught


# Make sure that if trio.run exits, and then the thread finishes, then that's
# handled gracefully. (Requires that the thread result machinery be prepared
# for call_soon to raise RunFinishedError.)
def test_run_in_worker_thread_abandoned(
    capfd: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01)

    q1: stdlib_queue.Queue[None] = stdlib_queue.Queue()
    q2: stdlib_queue.Queue[threading.Thread] = stdlib_queue.Queue()

    def thread_fn() -> None:
        q1.get()
        q2.put(threading.current_thread())

    async def main() -> None:
        async def child() -> None:
            await to_thread_run_sync(thread_fn, abandon_on_cancel=True)

        async with _core.open_nursery() as nursery:
            nursery.start_soon(child)
            await wait_all_tasks_blocked()
            nursery.cancel_scope.cancel()

    _core.run(main)

    q1.put(None)
    # This makes sure:
    # - the thread actually ran
    # - that thread has finished before we check for its output
    thread = q2.get()
    while thread.is_alive():
        time.sleep(0.01)  # pragma: no cover

    # Make sure we don't have a "Exception in thread ..." dump to the console:
    out, err = capfd.readouterr()
    assert "Exception in thread" not in out
    assert "Exception in thread" not in err


@pytest.mark.parametrize("MAX", [3, 5, 10])
@pytest.mark.parametrize("cancel", [False, True])
@pytest.mark.parametrize("use_default_limiter", [False, True])
async def test_run_in_worker_thread_limiter(
    MAX: int,
    cancel: bool,
    use_default_limiter: bool,
) -> None:
    # This test is a bit tricky. The goal is to make sure that if we set
    # limiter=CapacityLimiter(MAX), then in fact only MAX threads are ever
    # running at a time, even if there are more concurrent calls to
    # to_thread_run_sync, and even if some of those are cancelled. And
    # also to make sure that the default limiter actually limits.
    COUNT = 2 * MAX
    gate = threading.Event()
    lock = threading.Lock()
    if use_default_limiter:
        c = current_default_thread_limiter()
        orig_total_tokens = c.total_tokens
        c.total_tokens = MAX
        limiter_arg = None
    else:
        c = CapacityLimiter(MAX)
        orig_total_tokens = MAX
        limiter_arg = c
    try:
        # We used to use regular variables and 'nonlocal' here, but it turns
        # out that it's not safe to assign to closed-over variables that are
        # visible in multiple threads, at least as of CPython 3.10 and PyPy
        # 7.3:
        #
        #   https://bugs.python.org/issue30744
        #   https://bitbucket.org/pypy/pypy/issues/2591/
        #
        # Mutating them in-place is OK though (as long as you use proper
        # locking etc.).
        class state:
            ran: int
            high_water: int
            running: int
            parked: int

        state.ran = 0
        state.high_water = 0
        state.running = 0
        state.parked = 0

        token = _core.current_trio_token()

        def thread_fn(cancel_scope: CancelScope) -> None:
            print("thread_fn start")
            from_thread_run_sync(cancel_scope.cancel, trio_token=token)
            with lock:
                state.ran += 1
                state.running += 1
                state.high_water = max(state.high_water, state.running)
                # The Trio thread below watches this value and uses it as a
                # signal that all the stats calculations have finished.
                state.parked += 1
            gate.wait()
            with lock:
                state.parked -= 1
                state.running -= 1
            print("thread_fn exiting")

        async def run_thread(event: Event) -> None:
            with _core.CancelScope() as cancel_scope:
                await to_thread_run_sync(
                    thread_fn,
                    cancel_scope,
                    abandon_on_cancel=cancel,
                    limiter=limiter_arg,
                )
            print("run_thread finished, cancelled:", cancel_scope.cancelled_caught)
            event.set()

        async with _core.open_nursery() as nursery:
            print("spawning")
            events = []
            for _ in range(COUNT):
                events.append(Event())
                nursery.start_soon(run_thread, events[-1])
                await wait_all_tasks_blocked()
            # In the cancel case, we in particular want to make sure that the
            # cancelled tasks don't release the semaphore. So let's wait until
            # at least one of them has exited, and that everything has had a
            # chance to settle down from this, before we check that everyone
            # who's supposed to be waiting is waiting:
            if cancel:
                print("waiting for first cancellation to clear")
                await events[0].wait()
                await wait_all_tasks_blocked()
            # Then wait until the first MAX threads are parked in gate.wait(),
            # and the next MAX threads are parked on the semaphore, to make
            # sure no-one is sneaking past, and to make sure the high_water
            # check below won't fail due to scheduling issues. (It could still
            # fail if too many threads are let through here.)
            while (  # noqa: ASYNC110
                state.parked != MAX or c.statistics().tasks_waiting != MAX
            ):
                await sleep(0.01)  # pragma: no cover
            # Then release the threads
            gate.set()

        assert state.high_water == MAX

        if cancel:
            # Some threads might still be running; need to wait to them to
            # finish before checking that all threads ran. We can do this
            # using the CapacityLimiter.
            while c.borrowed_tokens > 0:  # noqa: ASYNC110
                await sleep(0.01)  # pragma: no cover

        assert state.ran == COUNT
        assert state.running == 0
    finally:
        c.total_tokens = orig_total_tokens


async def test_run_in_worker_thread_custom_limiter() -> None:
    # Basically just checking that we only call acquire_on_behalf_of and
    # release_on_behalf_of, since that's part of our documented API.
    record = []

    class CustomLimiter:
        async def acquire_on_behalf_of(self, borrower: Task) -> None:
            record.append("acquire")
            self._borrower = borrower

        def release_on_behalf_of(self, borrower: Task) -> None:
            record.append("release")
            assert borrower == self._borrower

    # TODO: should CapacityLimiter have an abc or protocol so users can modify it?
    # because currently it's `final` so writing code like this is not allowed.
    await to_thread_run_sync(lambda: None, limiter=CustomLimiter())  # type: ignore[arg-type]
    assert record == ["acquire", "release"]


async def test_run_in_worker_thread_limiter_error() -> None:
    record = []

    class BadCapacityLimiter:
        async def acquire_on_behalf_of(self, borrower: Task) -> None:
            record.append("acquire")

        def release_on_behalf_of(self, borrower: Task) -> NoReturn:
            record.append("release")
            raise ValueError("release on behalf")

    bs = BadCapacityLimiter()

    with pytest.raises(ValueError, match=r"^release on behalf$") as excinfo:
        await to_thread_run_sync(lambda: None, limiter=bs)  # type: ignore[arg-type]
    assert excinfo.value.__context__ is None
    assert record == ["acquire", "release"]
    record = []

    # If the original function raised an error, then the semaphore error
    # chains with it
    d: dict[str, object] = {}
    with pytest.raises(ValueError, match=r"^release on behalf$") as excinfo:
        await to_thread_run_sync(lambda: d["x"], limiter=bs)  # type: ignore[arg-type]
    assert isinstance(excinfo.value.__context__, KeyError)
    assert record == ["acquire", "release"]


async def test_run_in_worker_thread_fail_to_spawn(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Test the unlikely but possible case where trying to spawn a thread fails
    def bad_start(self: object, *args: object) -> NoReturn:
        raise RuntimeError("the engines canna take it captain")

    monkeypatch.setattr(_core._thread_cache.ThreadCache, "start_thread_soon", bad_start)

    limiter = current_default_thread_limiter()
    assert limiter.borrowed_tokens == 0

    # We get an appropriate error, and the limiter is cleanly released
    with pytest.raises(RuntimeError) as excinfo:
        await to_thread_run_sync(lambda: None)  # pragma: no cover
    assert "engines" in str(excinfo.value)

    assert limiter.borrowed_tokens == 0


async def test_trio_to_thread_run_sync_token() -> None:
    # Test that to_thread_run_sync automatically injects the current trio token
    # into a spawned thread
    def thread_fn() -> _core.TrioToken:
        callee_token = from_thread_run_sync(_core.current_trio_token)
        return callee_token

    caller_token = _core.current_trio_token()
    callee_token = await to_thread_run_sync(thread_fn)
    assert callee_token == caller_token


async def test_trio_to_thread_run_sync_expected_error() -> None:
    # Test correct error when passed async function
    async def async_fn() -> None:  # pragma: no cover
        pass

    with pytest.raises(TypeError, match="expected a sync function"):
        await to_thread_run_sync(async_fn)  # type: ignore[unused-coroutine]


trio_test_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar(
    "trio_test_contextvar",
)


async def test_trio_to_thread_run_sync_contextvars() -> None:
    trio_thread = threading.current_thread()
    trio_test_contextvar.set("main")

    def f() -> tuple[str, threading.Thread]:
        value = trio_test_contextvar.get()
        with pytest.raises(sniffio.AsyncLibraryNotFoundError):
            sniffio.current_async_library()
        return (value, threading.current_thread())

    value, child_thread = await to_thread_run_sync(f)
    assert value == "main"
    assert child_thread != trio_thread

    def g() -> tuple[str, str, threading.Thread]:
        parent_value = trio_test_contextvar.get()
        trio_test_contextvar.set("worker")
        inner_value = trio_test_contextvar.get()
        with pytest.raises(sniffio.AsyncLibraryNotFoundError):
            sniffio.current_async_library()
        return (
            parent_value,
            inner_value,
            threading.current_thread(),
        )

    parent_value, inner_value, child_thread = await to_thread_run_sync(g)
    current_value = trio_test_contextvar.get()
    assert parent_value == "main"
    assert inner_value == "worker"
    assert current_value == "main", (
        "The contextvar value set on the worker would not propagate back to the main"
        " thread"
    )
    assert sniffio.current_async_library() == "trio"


async def test_trio_from_thread_run_sync() -> None:
    # Test that to_thread_run_sync correctly "hands off" the trio token to
    # trio.from_thread.run_sync()
    def thread_fn_1() -> float:
        trio_time = from_thread_run_sync(_core.current_time)
        return trio_time

    trio_time = await to_thread_run_sync(thread_fn_1)
    assert isinstance(trio_time, float)

    # Test correct error when passed async function
    async def async_fn() -> None:  # pragma: no cover
        pass

    def thread_fn_2() -> None:
        from_thread_run_sync(async_fn)  # type: ignore[unused-coroutine]

    with pytest.raises(TypeError, match="expected a synchronous function"):
        await to_thread_run_sync(thread_fn_2)


async def test_trio_from_thread_run() -> None:
    # Test that to_thread_run_sync correctly "hands off" the trio token to
    # trio.from_thread.run()
    record = []

    async def back_in_trio_fn() -> None:
        _core.current_time()  # implicitly checks that we're in trio
        record.append("back in trio")

    def thread_fn() -> None:
        record.append("in thread")
        from_thread_run(back_in_trio_fn)

    await to_thread_run_sync(thread_fn)
    assert record == ["in thread", "back in trio"]

    # Test correct error when passed sync function
    def sync_fn() -> None:  # pragma: no cover
        pass

    with pytest.raises(TypeError, match="appears to be synchronous"):
        await to_thread_run_sync(from_thread_run, sync_fn)  # type: ignore[arg-type]


async def test_trio_from_thread_token() -> None:
    # Test that to_thread_run_sync and spawned trio.from_thread.run_sync()
    # share the same Trio token
    def thread_fn() -> _core.TrioToken:
        callee_token = from_thread_run_sync(_core.current_trio_token)
        return callee_token

    caller_token = _core.current_trio_token()
    callee_token = await to_thread_run_sync(thread_fn)
    assert callee_token == caller_token


async def test_trio_from_thread_token_kwarg() -> None:
    # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() can
    # use an explicitly defined token
    def thread_fn(token: _core.TrioToken) -> _core.TrioToken:
        callee_token = from_thread_run_sync(_core.current_trio_token, trio_token=token)
        return callee_token

    caller_token = _core.current_trio_token()
    callee_token = await to_thread_run_sync(thread_fn, caller_token)
    assert callee_token == caller_token


async def test_from_thread_no_token() -> None:
    # Test that a "raw call" to trio.from_thread.run() fails because no token
    # has been provided

    with pytest.raises(RuntimeError):
        from_thread_run_sync(_core.current_time)


async def test_trio_from_thread_run_sync_contextvars() -> None:
    trio_test_contextvar.set("main")

    def thread_fn() -> tuple[str, str, str, str, str]:
        thread_parent_value = trio_test_contextvar.get()
        trio_test_contextvar.set("worker")
        thread_current_value = trio_test_contextvar.get()
        with pytest.raises(sniffio.AsyncLibraryNotFoundError):
            sniffio.current_async_library()

        def back_in_main() -> tuple[str, str]:
            back_parent_value = trio_test_contextvar.get()
            trio_test_contextvar.set("back_in_main")
            back_current_value = trio_test_contextvar.get()
            assert sniffio.current_async_library() == "trio"
            return back_parent_value, back_current_value

        back_parent_value, back_current_value = from_thread_run_sync(back_in_main)
        thread_after_value = trio_test_contextvar.get()
        with pytest.raises(sniffio.AsyncLibraryNotFoundError):
            sniffio.current_async_library()
        return (
            thread_parent_value,
            thread_current_value,
            thread_after_value,
            back_parent_value,
            back_current_value,
        )

    (
        thread_parent_value,
        thread_current_value,
        thread_after_value,
        back_parent_value,
        back_current_value,
    ) = await to_thread_run_sync(thread_fn)
    current_value = trio_test_contextvar.get()
    assert current_value == thread_parent_value == "main"
    assert thread_current_value == back_parent_value == thread_after_value == "worker"
    assert sniffio.current_async_library() == "trio"
    assert back_current_value == "back_in_main"


async def test_trio_from_thread_run_contextvars() -> None:
    trio_test_contextvar.set("main")

    def thread_fn() -> tuple[str, str, str, str, str]:
        thread_parent_value = trio_test_contextvar.get()
        trio_test_contextvar.set("worker")
        thread_current_value = trio_test_contextvar.get()
        with pytest.raises(sniffio.AsyncLibraryNotFoundError):
            sniffio.current_async_library()

        async def async_back_in_main() -> tuple[str, str]:
            back_parent_value = trio_test_contextvar.get()
            trio_test_contextvar.set("back_in_main")
            back_current_value = trio_test_contextvar.get()
            assert sniffio.current_async_library() == "trio"
            return back_parent_value, back_current_value

        back_parent_value, back_current_value = from_thread_run(async_back_in_main)
        thread_after_value = trio_test_contextvar.get()
        with pytest.raises(sniffio.AsyncLibraryNotFoundError):
            sniffio.current_async_library()
        return (
            thread_parent_value,
            thread_current_value,
            thread_after_value,
            back_parent_value,
            back_current_value,
        )

    (
        thread_parent_value,
        thread_current_value,
        thread_after_value,
        back_parent_value,
        back_current_value,
    ) = await to_thread_run_sync(thread_fn)
    current_value = trio_test_contextvar.get()
    assert current_value == thread_parent_value == "main"
    assert thread_current_value == back_parent_value == thread_after_value == "worker"
    assert back_current_value == "back_in_main"
    assert sniffio.current_async_library() == "trio"


def test_run_fn_as_system_task_caught_badly_typed_token() -> None:
    with pytest.raises(RuntimeError):
        from_thread_run_sync(
            _core.current_time,
            trio_token="Not TrioTokentype",  # type: ignore[arg-type]
        )


async def test_from_thread_inside_trio_thread() -> None:
    def not_called() -> None:  # pragma: no cover
        raise AssertionError()

    trio_token = _core.current_trio_token()
    with pytest.raises(RuntimeError):
        from_thread_run_sync(not_called, trio_token=trio_token)


def test_from_thread_run_during_shutdown() -> None:
    save = []
    record = []

    async def agen(token: _core.TrioToken | None) -> AsyncGenerator[None, None]:
        try:
            yield
        finally:
            with _core.CancelScope(shield=True):
                try:
                    await to_thread_run_sync(
                        partial(from_thread_run, sleep, 0, trio_token=token),
                    )
                except _core.RunFinishedError:
                    record.append("finished")
                else:
                    record.append("clean")

    async def main(use_system_task: bool) -> None:
        save.append(agen(_core.current_trio_token() if use_system_task else None))
        await save[-1].asend(None)

    _core.run(main, True)  # System nursery will be closed and raise RunFinishedError
    _core.run(main, False)  # host task will be rescheduled as normal
    assert record == ["finished", "clean"]


async def test_trio_token_weak_referenceable() -> None:
    token = _core.current_trio_token()
    assert isinstance(token, _core.TrioToken)
    weak_reference = weakref.ref(token)
    assert token is weak_reference()


async def test_unsafe_abandon_on_cancel_kwarg() -> None:
    # This is a stand in for a numpy ndarray or other objects
    # that (maybe surprisingly) lack a notion of truthiness
    class BadBool:
        def __bool__(self) -> bool:
            raise NotImplementedError

    with pytest.raises(NotImplementedError):
        await to_thread_run_sync(int, abandon_on_cancel=BadBool())  # type: ignore[arg-type]


async def test_from_thread_reuses_task() -> None:
    task = _core.current_task()

    async def async_current_task() -> _core.Task:
        return _core.current_task()

    assert task is await to_thread_run_sync(from_thread_run_sync, _core.current_task)
    assert task is await to_thread_run_sync(from_thread_run, async_current_task)


async def test_recursive_to_thread() -> None:
    tid = None

    def get_tid_then_reenter() -> int:
        nonlocal tid
        tid = threading.get_ident()
        return from_thread_run(to_thread_run_sync, threading.get_ident)

    assert tid != await to_thread_run_sync(get_tid_then_reenter)


async def test_from_thread_host_cancelled() -> None:
    queue: stdlib_queue.Queue[bool] = stdlib_queue.Queue()

    def sync_check() -> None:
        from_thread_run_sync(cancel_scope.cancel)
        try:
            from_thread_run_sync(bool)
        except _core.Cancelled:  # pragma: no cover
            queue.put(True)  # sync functions don't raise Cancelled
        else:
            queue.put(False)

    with _core.CancelScope() as cancel_scope:
        await to_thread_run_sync(sync_check)

    assert not cancel_scope.cancelled_caught
    assert not queue.get_nowait()

    with _core.CancelScope() as cancel_scope:
        await to_thread_run_sync(sync_check, abandon_on_cancel=True)

    assert cancel_scope.cancelled_caught
    assert not await to_thread_run_sync(partial(queue.get, timeout=1))

    async def no_checkpoint() -> bool:
        return True

    def async_check() -> None:
        from_thread_run_sync(cancel_scope.cancel)
        try:
            assert from_thread_run(no_checkpoint)
        except _core.Cancelled:  # pragma: no cover
            queue.put(True)  # async functions raise Cancelled at checkpoints
        else:
            queue.put(False)

    with _core.CancelScope() as cancel_scope:
        await to_thread_run_sync(async_check)

    assert not cancel_scope.cancelled_caught
    assert not queue.get_nowait()

    with _core.CancelScope() as cancel_scope:
        await to_thread_run_sync(async_check, abandon_on_cancel=True)

    assert cancel_scope.cancelled_caught
    assert not await to_thread_run_sync(partial(queue.get, timeout=1))

    async def async_time_bomb() -> None:
        cancel_scope.cancel()
        with fail_after(10):
            await sleep_forever()

    with _core.CancelScope() as cancel_scope:
        await to_thread_run_sync(from_thread_run, async_time_bomb)

    assert cancel_scope.cancelled_caught


async def test_from_thread_check_cancelled() -> None:
    q: stdlib_queue.Queue[str] = stdlib_queue.Queue()

    async def child(abandon_on_cancel: bool, scope: CancelScope) -> None:
        with scope:
            record.append("start")
            try:
                return await to_thread_run_sync(f, abandon_on_cancel=abandon_on_cancel)
            except _core.Cancelled:
                record.append("cancel")
                raise
            finally:
                record.append("exit")

    def f() -> None:
        try:
            from_thread_check_cancelled()
        except _core.Cancelled:  # pragma: no cover, test failure path
            q.put("Cancelled")
        else:
            q.put("Not Cancelled")
        ev.wait()
        return from_thread_check_cancelled()

    # Base case: nothing cancelled so we shouldn't see cancels anywhere
    record: list[str] = []
    ev = threading.Event()
    async with _core.open_nursery() as nursery:
        nursery.start_soon(child, False, _core.CancelScope())
        await wait_all_tasks_blocked()
        assert record[0] == "start"
        assert q.get(timeout=1) == "Not Cancelled"
        ev.set()
    # implicit assertion, Cancelled not raised via nursery
    assert record[1] == "exit"

    # abandon_on_cancel=False case: a cancel will pop out but be handled by
    # the appropriate cancel scope
    record = []
    ev = threading.Event()
    scope = _core.CancelScope()  # Nursery cancel scope gives false positives
    async with _core.open_nursery() as nursery:
        nursery.start_soon(child, False, scope)
        await wait_all_tasks_blocked()
        assert record[0] == "start"
        assert q.get(timeout=1) == "Not Cancelled"
        scope.cancel()
        ev.set()
    assert scope.cancelled_caught
    assert "cancel" in record
    assert record[-1] == "exit"

    # abandon_on_cancel=True case: slightly different thread behavior needed
    # check thread is cancelled "soon" after abandonment
    def f() -> None:  # type: ignore[no-redef] # noqa: F811
        ev.wait()
        try:
            from_thread_check_cancelled()
        except _core.Cancelled:
            q.put("Cancelled")
        else:  # pragma: no cover, test failure path
            q.put("Not Cancelled")

    record = []
    ev = threading.Event()
    scope = _core.CancelScope()
    async with _core.open_nursery() as nursery:
        nursery.start_soon(child, True, scope)
        await wait_all_tasks_blocked()
        assert record[0] == "start"
        scope.cancel()
        ev.set()
    assert scope.cancelled_caught
    assert "cancel" in record
    assert record[-1] == "exit"
    assert q.get(timeout=1) == "Cancelled"


def test_from_thread_check_cancelled_raises_in_foreign_threads() -> None:
    with pytest.raises(RuntimeError):
        from_thread_check_cancelled()
    q: stdlib_queue.Queue[Outcome[object]] = stdlib_queue.Queue()
    _core.start_thread_soon(from_thread_check_cancelled, lambda _: q.put(_))
    with pytest.raises(RuntimeError):
        q.get(timeout=1).unwrap()


@slow
async def test_reentry_doesnt_deadlock() -> None:
    # Regression test for issue noticed in GH-2827
    # The failure mode is to hang the whole test suite, unfortunately.
    # XXX consider running this in a subprocess with a timeout, if it comes up again!

    async def child() -> None:
        while True:
            await to_thread_run_sync(from_thread_run, sleep, 0, abandon_on_cancel=False)

    with move_on_after(2):
        async with _core.open_nursery() as nursery:
            for _ in range(4):
                nursery.start_soon(child)


async def test_wait_all_threads_completed() -> None:
    no_threads_left = False
    e1 = Event()
    e2 = Event()

    e1_exited = Event()
    e2_exited = Event()

    async def wait_event(e: Event, e_exit: Event) -> None:
        def thread() -> None:
            from_thread_run(e.wait)

        await to_thread_run_sync(thread)
        e_exit.set()

    async def wait_no_threads_left() -> None:
        nonlocal no_threads_left
        await wait_all_threads_completed()
        no_threads_left = True

    async with _core.open_nursery() as nursery:
        nursery.start_soon(wait_event, e1, e1_exited)
        nursery.start_soon(wait_event, e2, e2_exited)
        await wait_all_tasks_blocked()
        nursery.start_soon(wait_no_threads_left)
        await wait_all_tasks_blocked()
        assert not no_threads_left
        assert active_thread_count() == 2

        e1.set()
        await e1_exited.wait()
        await wait_all_tasks_blocked()
        assert not no_threads_left
        assert active_thread_count() == 1

        e2.set()
        await e2_exited.wait()
        await wait_all_tasks_blocked()
        assert no_threads_left
        assert active_thread_count() == 0


async def test_wait_all_threads_completed_no_threads() -> None:
    await wait_all_threads_completed()
    assert active_thread_count() == 0
