from __future__ import annotations

import re
import weakref
from typing import TYPE_CHECKING, Callable, Union

import pytest

from trio.testing import Matcher, RaisesGroup

from .. import _core
from .._core._parking_lot import GLOBAL_PARKING_LOT_BREAKER
from .._sync import *
from .._timeouts import sleep_forever
from ..testing import assert_checkpoints, wait_all_tasks_blocked

if TYPE_CHECKING:
    from typing_extensions import TypeAlias


async def test_Event() -> None:
    e = Event()
    assert not e.is_set()
    assert e.statistics().tasks_waiting == 0

    e.set()
    assert e.is_set()
    with assert_checkpoints():
        await e.wait()

    e = Event()

    record = []

    async def child() -> None:
        record.append("sleeping")
        await e.wait()
        record.append("woken")

    async with _core.open_nursery() as nursery:
        nursery.start_soon(child)
        nursery.start_soon(child)
        await wait_all_tasks_blocked()
        assert record == ["sleeping", "sleeping"]
        assert e.statistics().tasks_waiting == 2
        e.set()
        await wait_all_tasks_blocked()
        assert record == ["sleeping", "sleeping", "woken", "woken"]


async def test_CapacityLimiter() -> None:
    with pytest.raises(TypeError):
        CapacityLimiter(1.0)
    with pytest.raises(ValueError, match=r"^total_tokens must be >= 1$"):
        CapacityLimiter(-1)
    c = CapacityLimiter(2)
    repr(c)  # smoke test
    assert c.total_tokens == 2
    assert c.borrowed_tokens == 0
    assert c.available_tokens == 2
    with pytest.raises(RuntimeError):
        c.release()
    assert c.borrowed_tokens == 0
    c.acquire_nowait()
    assert c.borrowed_tokens == 1
    assert c.available_tokens == 1

    stats = c.statistics()
    assert stats.borrowed_tokens == 1
    assert stats.total_tokens == 2
    assert stats.borrowers == [_core.current_task()]
    assert stats.tasks_waiting == 0

    # Can't re-acquire when we already have it
    with pytest.raises(RuntimeError):
        c.acquire_nowait()
    assert c.borrowed_tokens == 1
    with pytest.raises(RuntimeError):
        await c.acquire()
    assert c.borrowed_tokens == 1

    # We can acquire on behalf of someone else though
    with assert_checkpoints():
        await c.acquire_on_behalf_of("someone")

    # But then we've run out of capacity
    assert c.borrowed_tokens == 2
    with pytest.raises(_core.WouldBlock):
        c.acquire_on_behalf_of_nowait("third party")

    assert set(c.statistics().borrowers) == {_core.current_task(), "someone"}

    # Until we release one
    c.release_on_behalf_of(_core.current_task())
    assert c.statistics().borrowers == ["someone"]

    c.release_on_behalf_of("someone")
    assert c.borrowed_tokens == 0
    with assert_checkpoints():
        async with c:
            assert c.borrowed_tokens == 1

    async with _core.open_nursery() as nursery:
        await c.acquire_on_behalf_of("value 1")
        await c.acquire_on_behalf_of("value 2")
        nursery.start_soon(c.acquire_on_behalf_of, "value 3")
        await wait_all_tasks_blocked()
        assert c.borrowed_tokens == 2
        assert c.statistics().tasks_waiting == 1
        c.release_on_behalf_of("value 2")
        # Fairness:
        assert c.borrowed_tokens == 2
        with pytest.raises(_core.WouldBlock):
            c.acquire_nowait()

    c.release_on_behalf_of("value 3")
    c.release_on_behalf_of("value 1")


async def test_CapacityLimiter_inf() -> None:
    from math import inf

    c = CapacityLimiter(inf)
    repr(c)  # smoke test
    assert c.total_tokens == inf
    assert c.borrowed_tokens == 0
    assert c.available_tokens == inf
    with pytest.raises(RuntimeError):
        c.release()
    assert c.borrowed_tokens == 0
    c.acquire_nowait()
    assert c.borrowed_tokens == 1
    assert c.available_tokens == inf


async def test_CapacityLimiter_change_total_tokens() -> None:
    c = CapacityLimiter(2)

    with pytest.raises(TypeError):
        c.total_tokens = 1.0

    with pytest.raises(ValueError, match=r"^total_tokens must be >= 1$"):
        c.total_tokens = 0

    with pytest.raises(ValueError, match=r"^total_tokens must be >= 1$"):
        c.total_tokens = -10

    assert c.total_tokens == 2

    async with _core.open_nursery() as nursery:
        for i in range(5):
            nursery.start_soon(c.acquire_on_behalf_of, i)
            await wait_all_tasks_blocked()
        assert set(c.statistics().borrowers) == {0, 1}
        assert c.statistics().tasks_waiting == 3
        c.total_tokens += 2
        assert set(c.statistics().borrowers) == {0, 1, 2, 3}
        assert c.statistics().tasks_waiting == 1
        c.total_tokens -= 3
        assert c.borrowed_tokens == 4
        assert c.total_tokens == 1
        c.release_on_behalf_of(0)
        c.release_on_behalf_of(1)
        c.release_on_behalf_of(2)
        assert set(c.statistics().borrowers) == {3}
        assert c.statistics().tasks_waiting == 1
        c.release_on_behalf_of(3)
        assert set(c.statistics().borrowers) == {4}
        assert c.statistics().tasks_waiting == 0


# regression test for issue #548
async def test_CapacityLimiter_memleak_548() -> None:
    limiter = CapacityLimiter(total_tokens=1)
    await limiter.acquire()

    async with _core.open_nursery() as n:
        n.start_soon(limiter.acquire)
        await wait_all_tasks_blocked()  # give it a chance to run the task
        n.cancel_scope.cancel()

    # if this is 1, the acquire call (despite being killed) is still there in the task, and will
    # leak memory all the while the limiter is active
    assert len(limiter._pending_borrowers) == 0


async def test_Semaphore() -> None:
    with pytest.raises(TypeError):
        Semaphore(1.0)  # type: ignore[arg-type]
    with pytest.raises(ValueError, match=r"^initial value must be >= 0$"):
        Semaphore(-1)
    s = Semaphore(1)
    repr(s)  # smoke test
    assert s.value == 1
    assert s.max_value is None
    s.release()
    assert s.value == 2
    assert s.statistics().tasks_waiting == 0
    s.acquire_nowait()
    assert s.value == 1
    with assert_checkpoints():
        await s.acquire()
    assert s.value == 0
    with pytest.raises(_core.WouldBlock):
        s.acquire_nowait()

    s.release()
    assert s.value == 1
    with assert_checkpoints():
        async with s:
            assert s.value == 0
    assert s.value == 1
    s.acquire_nowait()

    record = []

    async def do_acquire(s: Semaphore) -> None:
        record.append("started")
        await s.acquire()
        record.append("finished")

    async with _core.open_nursery() as nursery:
        nursery.start_soon(do_acquire, s)
        await wait_all_tasks_blocked()
        assert record == ["started"]
        assert s.value == 0
        s.release()
        # Fairness:
        assert s.value == 0
        with pytest.raises(_core.WouldBlock):
            s.acquire_nowait()
    assert record == ["started", "finished"]


def test_Semaphore_bounded() -> None:
    with pytest.raises(TypeError):
        Semaphore(1, max_value=1.0)  # type: ignore[arg-type]
    with pytest.raises(ValueError, match=r"^max_values must be >= initial_value$"):
        Semaphore(2, max_value=1)
    bs = Semaphore(1, max_value=1)
    assert bs.max_value == 1
    repr(bs)  # smoke test
    with pytest.raises(ValueError, match=r"^semaphore released too many times$"):
        bs.release()
    assert bs.value == 1
    bs.acquire_nowait()
    assert bs.value == 0
    bs.release()
    assert bs.value == 1


@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__)
async def test_Lock_and_StrictFIFOLock(
    lockcls: type[Lock | StrictFIFOLock],
) -> None:
    l = lockcls()  # noqa
    assert not l.locked()

    # make sure locks can be weakref'ed (gh-331)
    r = weakref.ref(l)
    assert r() is l

    repr(l)  # smoke test
    # make sure repr uses the right name for subclasses
    assert lockcls.__name__ in repr(l)
    with assert_checkpoints():
        async with l:
            assert l.locked()
            repr(l)  # smoke test (repr branches on locked/unlocked)
    assert not l.locked()
    l.acquire_nowait()
    assert l.locked()
    l.release()
    assert not l.locked()
    with assert_checkpoints():
        await l.acquire()
    assert l.locked()
    l.release()
    assert not l.locked()

    l.acquire_nowait()
    with pytest.raises(RuntimeError):
        # Error out if we already own the lock
        l.acquire_nowait()
    l.release()
    with pytest.raises(RuntimeError):
        # Error out if we don't own the lock
        l.release()

    holder_task = None

    async def holder() -> None:
        nonlocal holder_task
        holder_task = _core.current_task()
        async with l:
            await sleep_forever()

    async with _core.open_nursery() as nursery:
        assert not l.locked()
        nursery.start_soon(holder)
        await wait_all_tasks_blocked()
        assert l.locked()
        # WouldBlock if someone else holds the lock
        with pytest.raises(_core.WouldBlock):
            l.acquire_nowait()
        # Can't release a lock someone else holds
        with pytest.raises(RuntimeError):
            l.release()

        statistics = l.statistics()
        print(statistics)
        assert statistics.locked
        assert statistics.owner is holder_task
        assert statistics.tasks_waiting == 0

        nursery.start_soon(holder)
        await wait_all_tasks_blocked()
        statistics = l.statistics()
        print(statistics)
        assert statistics.tasks_waiting == 1

        nursery.cancel_scope.cancel()

    statistics = l.statistics()
    assert not statistics.locked
    assert statistics.owner is None
    assert statistics.tasks_waiting == 0


async def test_Condition() -> None:
    with pytest.raises(TypeError):
        Condition(Semaphore(1))  # type: ignore[arg-type]
    with pytest.raises(TypeError):
        Condition(StrictFIFOLock)  # type: ignore[arg-type]
    l = Lock()  # noqa
    c = Condition(l)
    assert not l.locked()
    assert not c.locked()
    with assert_checkpoints():
        await c.acquire()
    assert l.locked()
    assert c.locked()

    c = Condition()
    assert not c.locked()
    c.acquire_nowait()
    assert c.locked()
    with pytest.raises(RuntimeError):
        c.acquire_nowait()
    c.release()

    with pytest.raises(RuntimeError):
        # Can't wait without holding the lock
        await c.wait()
    with pytest.raises(RuntimeError):
        # Can't notify without holding the lock
        c.notify()
    with pytest.raises(RuntimeError):
        # Can't notify without holding the lock
        c.notify_all()

    finished_waiters = set()

    async def waiter(i: int) -> None:
        async with c:
            await c.wait()
        finished_waiters.add(i)

    async with _core.open_nursery() as nursery:
        for i in range(3):
            nursery.start_soon(waiter, i)
            await wait_all_tasks_blocked()
        async with c:
            c.notify()
        assert c.locked()
        await wait_all_tasks_blocked()
        assert finished_waiters == {0}
        async with c:
            c.notify_all()
        await wait_all_tasks_blocked()
        assert finished_waiters == {0, 1, 2}

    finished_waiters = set()
    async with _core.open_nursery() as nursery:
        for i in range(3):
            nursery.start_soon(waiter, i)
            await wait_all_tasks_blocked()
        async with c:
            c.notify(2)
            statistics = c.statistics()
            print(statistics)
            assert statistics.tasks_waiting == 1
            assert statistics.lock_statistics.tasks_waiting == 2
        # exiting the context manager hands off the lock to the first task
        assert c.statistics().lock_statistics.tasks_waiting == 1

        await wait_all_tasks_blocked()
        assert finished_waiters == {0, 1}

        async with c:
            c.notify_all()

    # After being cancelled still hold the lock (!)
    # (Note that c.__aexit__ checks that we hold the lock as well)
    with _core.CancelScope() as scope:
        async with c:
            scope.cancel()
            try:
                await c.wait()
            finally:
                assert c.locked()


from .._channel import open_memory_channel
from .._sync import AsyncContextManagerMixin

# Three ways of implementing a Lock in terms of a channel. Used to let us put
# the channel through the generic lock tests.


class ChannelLock1(AsyncContextManagerMixin):
    def __init__(self, capacity: int) -> None:
        self.s, self.r = open_memory_channel[None](capacity)
        for _ in range(capacity - 1):
            self.s.send_nowait(None)

    def acquire_nowait(self) -> None:
        self.s.send_nowait(None)

    async def acquire(self) -> None:
        await self.s.send(None)

    def release(self) -> None:
        self.r.receive_nowait()


class ChannelLock2(AsyncContextManagerMixin):
    def __init__(self) -> None:
        self.s, self.r = open_memory_channel[None](10)
        self.s.send_nowait(None)

    def acquire_nowait(self) -> None:
        self.r.receive_nowait()

    async def acquire(self) -> None:
        await self.r.receive()

    def release(self) -> None:
        self.s.send_nowait(None)


class ChannelLock3(AsyncContextManagerMixin):
    def __init__(self) -> None:
        self.s, self.r = open_memory_channel[None](0)
        # self.acquired is true when one task acquires the lock and
        # only becomes false when it's released and no tasks are
        # waiting to acquire.
        self.acquired = False

    def acquire_nowait(self) -> None:
        assert not self.acquired
        self.acquired = True

    async def acquire(self) -> None:
        if self.acquired:
            await self.s.send(None)
        else:
            self.acquired = True
            await _core.checkpoint()

    def release(self) -> None:
        try:
            self.r.receive_nowait()
        except _core.WouldBlock:
            assert self.acquired
            self.acquired = False


lock_factories = [
    lambda: CapacityLimiter(1),
    lambda: Semaphore(1),
    Lock,
    StrictFIFOLock,
    lambda: ChannelLock1(10),
    lambda: ChannelLock1(1),
    ChannelLock2,
    ChannelLock3,
]
lock_factory_names = [
    "CapacityLimiter(1)",
    "Semaphore(1)",
    "Lock",
    "StrictFIFOLock",
    "ChannelLock1(10)",
    "ChannelLock1(1)",
    "ChannelLock2",
    "ChannelLock3",
]

generic_lock_test = pytest.mark.parametrize(
    "lock_factory",
    lock_factories,
    ids=lock_factory_names,
)

LockLike: TypeAlias = Union[
    CapacityLimiter,
    Semaphore,
    Lock,
    StrictFIFOLock,
    ChannelLock1,
    ChannelLock2,
    ChannelLock3,
]
LockFactory: TypeAlias = Callable[[], LockLike]


# Spawn a bunch of workers that take a lock and then yield; make sure that
# only one worker is ever in the critical section at a time.
@generic_lock_test
async def test_generic_lock_exclusion(lock_factory: LockFactory) -> None:
    LOOPS = 10
    WORKERS = 5
    in_critical_section = False
    acquires = 0

    async def worker(lock_like: LockLike) -> None:
        nonlocal in_critical_section, acquires
        for _ in range(LOOPS):
            async with lock_like:
                acquires += 1
                assert not in_critical_section
                in_critical_section = True
                await _core.checkpoint()
                await _core.checkpoint()
                assert in_critical_section
                in_critical_section = False

    async with _core.open_nursery() as nursery:
        lock_like = lock_factory()
        for _ in range(WORKERS):
            nursery.start_soon(worker, lock_like)
    assert not in_critical_section
    assert acquires == LOOPS * WORKERS


# Several workers queue on the same lock; make sure they each get it, in
# order.
@generic_lock_test
async def test_generic_lock_fifo_fairness(lock_factory: LockFactory) -> None:
    initial_order = []
    record = []
    LOOPS = 5

    async def loopy(name: int, lock_like: LockLike) -> None:
        # Record the order each task was initially scheduled in
        initial_order.append(name)
        for _ in range(LOOPS):
            async with lock_like:
                record.append(name)

    lock_like = lock_factory()
    async with _core.open_nursery() as nursery:
        nursery.start_soon(loopy, 1, lock_like)
        nursery.start_soon(loopy, 2, lock_like)
        nursery.start_soon(loopy, 3, lock_like)
    # The first three could be in any order due to scheduling randomness,
    # but after that they should repeat in the same order
    for i in range(LOOPS):
        assert record[3 * i : 3 * (i + 1)] == initial_order


@generic_lock_test
async def test_generic_lock_acquire_nowait_blocks_acquire(
    lock_factory: LockFactory,
) -> None:
    lock_like = lock_factory()

    record = []

    async def lock_taker() -> None:
        record.append("started")
        async with lock_like:
            pass
        record.append("finished")

    async with _core.open_nursery() as nursery:
        lock_like.acquire_nowait()
        nursery.start_soon(lock_taker)
        await wait_all_tasks_blocked()
        assert record == ["started"]
        lock_like.release()


async def test_lock_acquire_unowned_lock() -> None:
    """Test that trying to acquire a lock whose owner has exited raises an error.
    see https://github.com/python-trio/trio/issues/3035
    """
    assert not GLOBAL_PARKING_LOT_BREAKER
    lock = trio.Lock()
    async with trio.open_nursery() as nursery:
        nursery.start_soon(lock.acquire)
    owner_str = re.escape(str(lock._lot.broken_by[0]))
    with pytest.raises(
        trio.BrokenResourceError,
        match=f"^Owner of this lock exited without releasing: {owner_str}$",
    ):
        await lock.acquire()
    assert not GLOBAL_PARKING_LOT_BREAKER


async def test_lock_multiple_acquire() -> None:
    """Test for error if awaiting on a lock whose owner exits without releasing.
    see https://github.com/python-trio/trio/issues/3035"""
    assert not GLOBAL_PARKING_LOT_BREAKER
    lock = trio.Lock()
    with RaisesGroup(
        Matcher(
            trio.BrokenResourceError,
            match="^Owner of this lock exited without releasing: ",
        ),
    ):
        async with trio.open_nursery() as nursery:
            nursery.start_soon(lock.acquire)
            nursery.start_soon(lock.acquire)
    assert not GLOBAL_PARKING_LOT_BREAKER


async def test_lock_handover() -> None:
    assert not GLOBAL_PARKING_LOT_BREAKER
    child_task: Task | None = None
    lock = trio.Lock()

    # this task acquires the lock
    lock.acquire_nowait()
    assert {
        _core.current_task(): [
            lock._lot,
        ],
    } == GLOBAL_PARKING_LOT_BREAKER

    async with trio.open_nursery() as nursery:
        nursery.start_soon(lock.acquire)
        await wait_all_tasks_blocked()

        # hand over the lock to the child task
        lock.release()

        # check values, and get the identifier out of the dict for later check
        assert len(GLOBAL_PARKING_LOT_BREAKER) == 1
        child_task = next(iter(GLOBAL_PARKING_LOT_BREAKER))
        assert GLOBAL_PARKING_LOT_BREAKER[child_task] == [lock._lot]

    assert lock._lot.broken_by == [child_task]
    assert not GLOBAL_PARKING_LOT_BREAKER
