from __future__ import annotations

from typing import NoReturn

import attrs
import pytest

from .._highlevel_generic import StapledStream
from ..abc import ReceiveStream, SendStream


@attrs.define(slots=False)
class RecordSendStream(SendStream):
    record: list[str | tuple[str, object]] = attrs.Factory(list)

    async def send_all(self, data: object) -> None:
        self.record.append(("send_all", data))

    async def wait_send_all_might_not_block(self) -> None:
        self.record.append("wait_send_all_might_not_block")

    async def aclose(self) -> None:
        self.record.append("aclose")


@attrs.define(slots=False)
class RecordReceiveStream(ReceiveStream):
    record: list[str | tuple[str, int | None]] = attrs.Factory(list)

    async def receive_some(self, max_bytes: int | None = None) -> bytes:
        self.record.append(("receive_some", max_bytes))
        return b""

    async def aclose(self) -> None:
        self.record.append("aclose")


async def test_StapledStream() -> None:
    send_stream = RecordSendStream()
    receive_stream = RecordReceiveStream()
    stapled = StapledStream(send_stream, receive_stream)

    assert stapled.send_stream is send_stream
    assert stapled.receive_stream is receive_stream

    await stapled.send_all(b"foo")
    await stapled.wait_send_all_might_not_block()
    assert send_stream.record == [
        ("send_all", b"foo"),
        "wait_send_all_might_not_block",
    ]
    send_stream.record.clear()

    await stapled.send_eof()
    assert send_stream.record == ["aclose"]
    send_stream.record.clear()

    async def fake_send_eof() -> None:
        send_stream.record.append("send_eof")

    send_stream.send_eof = fake_send_eof  # type: ignore[attr-defined]
    await stapled.send_eof()
    assert send_stream.record == ["send_eof"]

    send_stream.record.clear()
    assert receive_stream.record == []

    await stapled.receive_some(1234)
    assert receive_stream.record == [("receive_some", 1234)]
    assert send_stream.record == []
    receive_stream.record.clear()

    await stapled.aclose()
    assert receive_stream.record == ["aclose"]
    assert send_stream.record == ["aclose"]


async def test_StapledStream_with_erroring_close() -> None:
    # Make sure that if one of the aclose methods errors out, then the other
    # one still gets called.
    class BrokenSendStream(RecordSendStream):
        async def aclose(self) -> NoReturn:
            await super().aclose()
            raise ValueError("send error")

    class BrokenReceiveStream(RecordReceiveStream):
        async def aclose(self) -> NoReturn:
            await super().aclose()
            raise ValueError("recv error")

    stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())

    with pytest.raises(ValueError, match=r"^(send|recv) error$") as excinfo:
        await stapled.aclose()
    assert isinstance(excinfo.value.__context__, ValueError)

    assert stapled.send_stream.record == ["aclose"]
    assert stapled.receive_stream.record == ["aclose"]
