# Implementation of DTLS 1.2, using pyopenssl
# https://datatracker.ietf.org/doc/html/rfc6347
#
# OpenSSL's APIs for DTLS are extremely awkward and limited, which forces us to jump
# through a *lot* of hoops and implement important chunks of the protocol ourselves.
# Hopefully they fix this before implementing DTLS 1.3, because it's a very different
# protocol, and it's probably impossible to pull tricks like we do here.

from __future__ import annotations

import contextlib
import enum
import errno
import hmac
import os
import struct
import warnings
import weakref
from itertools import count
from typing import (
    TYPE_CHECKING,
    Generic,
    TypeVar,
    Union,
)
from weakref import ReferenceType, WeakValueDictionary

import attrs

import trio

from ._util import NoPublicConstructor, final

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable, Iterable, Iterator
    from types import TracebackType

    # See DTLSEndpoint.__init__ for why this is imported here
    from OpenSSL import SSL  # noqa: TC004
    from typing_extensions import Self, TypeAlias, TypeVarTuple, Unpack

    from trio._socket import AddressFormat
    from trio.socket import SocketType

    PosArgsT = TypeVarTuple("PosArgsT")

# Buffer size passed to sock.recvfrom() in dtls_receive_loop. Presumably this is
# the 65535-byte maximum IP payload minus the 8-byte UDP header -- TODO confirm.
MAX_UDP_PACKET_SIZE = 65527


def packet_header_overhead(sock: SocketType) -> int:
    """Return the per-packet header overhead, in bytes, for datagrams on *sock*.

    AF_INET sockets are charged 28 bytes; any other family is charged 48.
    (Presumably 20-byte IPv4 + 8-byte UDP, vs 40-byte IPv6 + 8-byte UDP.)
    """
    is_ipv4 = sock.family == trio.socket.AF_INET
    return 28 if is_ipv4 else 48


def worst_case_mtu(sock: SocketType) -> int:
    """Return a conservative payload budget for a single datagram on *sock*.

    Starts from a small link MTU (576 for AF_INET, 1280 for anything else) and
    subtracts packet_header_overhead(sock).
    """
    if sock.family != trio.socket.AF_INET:
        link_mtu = 1280  # TODO: test this line
    else:
        link_mtu = 576
    return link_mtu - packet_header_overhead(sock)


def best_guess_mtu(sock: SocketType) -> int:
    """Return a likely payload budget for one datagram on *sock*: a 1500-byte
    link MTU minus packet_header_overhead(sock)."""
    assumed_link_mtu = 1500
    return assumed_link_mtu - packet_header_overhead(sock)


# There are a bunch of different RFCs that define these codes, so for a
# comprehensive collection look here:
# https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml
class ContentType(enum.IntEnum):
    """DTLS record content type: the first byte of every record header."""

    change_cipher_spec = 20
    alert = 21
    handshake = 22
    application_data = 23
    heartbeat = 24


class HandshakeType(enum.IntEnum):
    """Handshake message type: the first byte of every handshake message header.

    Values follow the IANA "TLS HandshakeType" registry. Every value must be
    unique: IntEnum silently turns a member with a repeated value into an alias
    of the earlier member, so the later name would never be produced by
    ``HandshakeType(value)``.
    """

    hello_request = 0
    client_hello = 1
    server_hello = 2
    hello_verify_request = 3
    new_session_ticket = 4
    # IANA value 5. (Using 4 here would make this a mere alias of
    # new_session_ticket, and HandshakeType(5) would raise ValueError.)
    end_of_early_data = 5
    encrypted_extensions = 8
    certificate = 11
    server_key_exchange = 12
    certificate_request = 13
    server_hello_done = 14
    certificate_verify = 15
    client_key_exchange = 16
    finished = 20
    certificate_url = 21
    certificate_status = 22
    supplemental_data = 23
    key_update = 24
    compressed_certificate = 25
    ekt_key = 26
    message_hash = 254


class ProtocolVersion:
    """Two-byte on-the-wire encodings of the DTLS versions this module uses."""

    DTLS10 = b"\xfe\xff"
    DTLS12 = b"\xfe\xfd"


# Selects the 16-bit epoch held in the top two bytes of a combined 64-bit
# epoch+seqno value (see the record layout notes below). A nonzero
# `epoch_seqno & EPOCH_MASK` means the record is not in epoch 0.
EPOCH_MASK = 0xFFFF << (6 * 8)


# Conventions:
# - All functions that handle network data end in _untrusted.
# - All functions ending in _untrusted MUST make sure that bad data from the
#   network can *only* cause BadPacket to be raised. No IndexError or
#   struct.error or whatever.
class BadPacket(Exception):
    """Raised when data that arrived from the network can't be parsed."""


# This checks that the DTLS 'epoch' field is 0, which is true iff we're in the
# initial handshake. It doesn't check the ContentType, because not all
# handshake messages have ContentType==handshake -- for example,
# ChangeCipherSpec is used during the handshake but has its own ContentType.
#
# Cannot fail.
def part_of_handshake_untrusted(packet: bytes) -> bool:
    """Return True when the record header's epoch field (bytes 3-4) is zero.

    Never raises: slicing a too-short packet just yields fewer than two bytes,
    which can't compare equal to the two-byte zero epoch.
    """
    epoch_field = packet[3:5]
    return epoch_field == bytes(2)


# Cannot fail
def is_client_hello_untrusted(packet: bytes) -> bool:
    """Cheaply check whether *packet* begins with a ClientHello record.

    Checks the record content type (byte 0) and the handshake message type
    (byte 13, the first byte after the 13-byte record header). Packets too
    short to have a byte 13 are not valid DTLS records and give False.
    """
    if len(packet) <= 13:
        # Invalid DTLS record
        return False
    if packet[0] != ContentType.handshake:
        return False
    return packet[13] == HandshakeType.client_hello


# DTLS records are:
# - 1 byte content type
# - 2 bytes version
# - 8 bytes epoch+seqno
#    Technically this is 2 bytes epoch then 6 bytes seqno, but we treat it as
#    a single 8-byte integer, where epoch changes are represented as jumping
#    forward by 2**(6*8).
# - 2 bytes payload length (unsigned big-endian)
# - payload
#
# "!" means network byte order with no padding, so the header is exactly
# 1 + 2 + 8 + 2 = 13 bytes (RECORD_HEADER.size); the payload follows it.
RECORD_HEADER = struct.Struct("!B2sQH")


def to_hex(data: bytes) -> str:  # pragma: no cover
    """Render *data* as lowercase hex digits (used for attrs reprs)."""
    return "".join(format(octet, "02x") for octet in data)


@attrs.frozen
class Record:
    """One DTLS record: what records_untrusted yields and encode_record packs."""

    # 1-byte record content type (compared against ContentType values; not
    # validated here)
    content_type: int
    # 2-byte protocol version from the record header
    version: bytes = attrs.field(repr=to_hex)
    # epoch (top 16 bits) and sequence number (low 48 bits) as one integer
    epoch_seqno: int
    # record body; its length becomes the header's length field when encoded
    payload: bytes = attrs.field(repr=to_hex)


def records_untrusted(packet: bytes) -> Iterator[Record]:
    """Yield, in order, each DTLS record packed back-to-back in *packet*.

    Raises BadPacket if a record header is truncated, or if a record's declared
    payload length runs past the end of the packet.
    """
    offset = 0
    while offset < len(packet):
        try:
            header = RECORD_HEADER.unpack_from(packet, offset)
        # no-cover: unreachable at time of writing, because this is only ever
        # called on packets that are trusted or have already passed
        # is_client_hello_untrusted (which rejects short packets)
        except struct.error as exc:  # pragma: no cover
            raise BadPacket("invalid record header") from exc
        content_type, version, epoch_seqno, payload_len = header
        body_start = offset + RECORD_HEADER.size
        offset = body_start + payload_len
        payload = packet[body_start:offset]
        if len(payload) != payload_len:
            raise BadPacket("short record")
        yield Record(content_type, version, epoch_seqno, payload)


def encode_record(record: Record) -> bytes:
    """Serialize *record*: the 13-byte record header, then the payload.

    The header's length field is always ``len(record.payload)``.
    """
    body = record.payload
    header = RECORD_HEADER.pack(
        record.content_type,
        record.version,
        record.epoch_seqno,
        len(body),
    )
    return b"".join((header, body))


# Handshake messages are:
# - 1 byte message type
# - 3 bytes total message length
# - 2 bytes message sequence number
# - 3 bytes fragment offset
# - 3 bytes fragment length
# followed by the fragment bytes. struct has no 24-bit integer format, so the
# 3-byte fields are handled as raw "3s" strings and converted with
# int.from_bytes / int.to_bytes. The header is 12 bytes in total.
HANDSHAKE_MESSAGE_HEADER = struct.Struct("!B3sH3s3s")


@attrs.frozen
class HandshakeFragment:
    """One handshake-message fragment: the 12-byte header's fields plus data."""

    # handshake message type (compared against HandshakeType values)
    msg_type: int
    # length of the whole, reassembled handshake message
    msg_len: int
    # handshake-layer message sequence number (not the record seqno)
    msg_seq: int
    # where this fragment's bytes start within the whole message
    frag_offset: int
    # number of bytes in this fragment; equals len(frag) for decoded fragments
    frag_len: int
    frag: bytes = attrs.field(repr=to_hex)


def decode_handshake_fragment_untrusted(payload: bytes) -> HandshakeFragment:
    """Parse one handshake-message fragment out of a handshake record payload.

    Raises BadPacket if the 12-byte header is truncated, or if the fragment
    length in the header disagrees with the number of bytes that follow it.
    """
    try:
        header_fields = HANDSHAKE_MESSAGE_HEADER.unpack_from(payload)
    except struct.error as exc:  # TODO: test this line
        raise BadPacket("bad handshake message header") from exc
    msg_type, raw_msg_len, msg_seq, raw_frag_offset, raw_frag_len = header_fields
    # The 24-bit fields come back as raw 3-byte strings; decoding them can't fail.
    msg_len, frag_offset, frag_len = (
        int.from_bytes(raw, "big")
        for raw in (raw_msg_len, raw_frag_offset, raw_frag_len)
    )
    frag = payload[HANDSHAKE_MESSAGE_HEADER.size :]
    if frag_len != len(frag):
        raise BadPacket("handshake fragment length doesn't match record length")
    return HandshakeFragment(
        msg_type=msg_type,
        msg_len=msg_len,
        msg_seq=msg_seq,
        frag_offset=frag_offset,
        frag_len=frag_len,
        frag=frag,
    )


def encode_handshake_fragment(hsf: HandshakeFragment) -> bytes:
    """Serialize *hsf*: the 12-byte handshake header followed by its bytes."""

    def u24(value: int) -> bytes:
        # struct can't pack 24-bit ints, so do it by hand.
        return value.to_bytes(3, "big")

    header = HANDSHAKE_MESSAGE_HEADER.pack(
        hsf.msg_type,
        u24(hsf.msg_len),
        hsf.msg_seq,
        u24(hsf.frag_offset),
        u24(hsf.frag_len),
    )
    return header + hsf.frag


def decode_client_hello_untrusted(packet: bytes) -> tuple[int, bytes, bytes]:
    """Extract the cookie from a ClientHello packet, plus the data it must cover.

    Returns ``(epoch_seqno, cookie, bits)``:

    - ``epoch_seqno``: the epoch+seqno field of the first record's header.
    - ``cookie``: the ClientHello's cookie field (may be empty).
    - ``bits``: the ClientHello body with the cookie length byte and the cookie
      removed -- the data that should be hashed into the cookie.

    Like every ``_untrusted`` function, malformed input raises BadPacket and
    nothing else.
    """
    try:
        # ClientHello has to be the first record in the packet. Passing a
        # default to next() makes an empty packet raise BadPacket, instead of
        # letting StopIteration escape from this function.
        record = next(records_untrusted(packet), None)
        # no-cover: current callers only pass packets that passed
        # is_client_hello_untrusted, so there is always at least one record.
        if record is None:  # pragma: no cover
            raise BadPacket("no records in packet")
        # no-cover because at time of writing, this is unreachable:
        # decode_client_hello_untrusted is only called on packets that have passed
        # is_client_hello_untrusted, which confirms the content type.
        if record.content_type != ContentType.handshake:  # pragma: no cover
            raise BadPacket("not a handshake record")
        fragment = decode_handshake_fragment_untrusted(record.payload)
        if fragment.msg_type != HandshakeType.client_hello:
            raise BadPacket("not a ClientHello")
        # ClientHello can't be fragmented, because reassembly requires holding
        # per-connection state, and we refuse to allocate per-connection state
        # until after we get a valid ClientHello.
        if fragment.frag_offset != 0:
            raise BadPacket("fragmented ClientHello")
        if fragment.frag_len != fragment.msg_len:
            raise BadPacket("fragmented ClientHello")

        # As per RFC 6347:
        #
        #   When responding to a HelloVerifyRequest, the client MUST use the
        #   same parameter values (version, random, session_id, cipher_suites,
        #   compression_method) as it did in the original ClientHello.  The
        #   server SHOULD use those values to generate its cookie and verify that
        #   they are correct upon cookie receipt.
        #
        # However, the record-layer framing can and will change (e.g. the
        # second ClientHello will have a new record-layer sequence number). So
        # we need to pull out the handshake message alone, discarding the
        # record-layer stuff, and then we're going to hash all of it *except*
        # the cookie.

        body = fragment.frag
        # ClientHello is:
        #
        # - 2 bytes client_version
        # - 32 bytes random
        # - 1 byte session_id length
        # - session_id
        # - 1 byte cookie length
        # - cookie
        # - everything else
        #
        # So to find the cookie, we need to figure out how long the session_id
        # is and skip past it. Indexing past the end raises IndexError, which
        # is converted to BadPacket below.
        session_id_len = body[2 + 32]
        cookie_len_offset = 2 + 32 + 1 + session_id_len
        cookie_len = body[cookie_len_offset]

        cookie_start = cookie_len_offset + 1
        cookie_end = cookie_start + cookie_len

        before_cookie = body[:cookie_len_offset]
        cookie = body[cookie_start:cookie_end]
        after_cookie = body[cookie_end:]

        # Slicing never raises, so a truncated cookie shows up as a short slice.
        if len(cookie) != cookie_len:
            raise BadPacket("short cookie")
        return (record.epoch_seqno, cookie, before_cookie + after_cookie)

    except (struct.error, IndexError) as exc:
        raise BadPacket("bad ClientHello") from exc


@attrs.frozen
class HandshakeMessage:
    """A whole handshake message, reassembled from an outgoing volley.

    The instance itself is frozen, but ``body`` is a mutable bytearray that
    decode_volley_trusted fills in fragment by fragment.
    """

    # version field of the record the message arrived in
    record_version: bytes = attrs.field(repr=to_hex)
    msg_type: HandshakeType
    # handshake-layer message sequence number (not the record seqno)
    msg_seq: int
    body: bytearray = attrs.field(repr=to_hex)


# ChangeCipherSpec is part of the handshake, but it's not a "handshake
# message" and can't be fragmented the same way. Sigh.
@attrs.frozen
class PseudoHandshakeMessage:
    """An epoch-0 non-handshake record from a volley (ChangeCipherSpec or alert).

    It is kept whole; RecordEncoder re-emits it as a single record with a fresh
    record sequence number.
    """

    record_version: bytes = attrs.field(repr=to_hex)
    # the original record's content type
    content_type: int
    payload: bytes = attrs.field(repr=to_hex)


# The final record in a handshake is Finished, which is encrypted, can't be fragmented
# (at least by us), and keeps its record number (because it's in a new epoch). So we
# just pass it through unchanged. (Fortunately, the payload is only a single hash value,
# so the largest it will ever be is 64 bytes for a 512-bit hash, which is small enough
# that it never requires fragmenting to fit into a UDP packet.)
@attrs.frozen
class OpaqueHandshakeMessage:
    """A record with a nonzero epoch, re-sent byte-for-byte.

    It keeps its original epoch_seqno; RecordEncoder never renumbers it.
    """

    record: Record


# Anything decode_volley_trusted produces and RecordEncoder.encode_volley
# consumes: a reassembled handshake message, a ChangeCipherSpec/alert record,
# or an encrypted record passed through as-is.
_AnyHandshakeMessage: TypeAlias = Union[
    HandshakeMessage,
    PseudoHandshakeMessage,
    OpaqueHandshakeMessage,
]


# This takes a raw outgoing handshake volley that openssl generated, and
# reconstructs the handshake messages inside it, so that we can repack them
# into records while retransmitting. So the data ought to be well-behaved --
# it's not coming from the network.
def decode_volley_trusted(
    volley: bytes,
) -> list[_AnyHandshakeMessage]:
    """Split an OpenSSL-generated outgoing volley into logical handshake messages.

    The returned list keeps the order in which each message first appears:

    - Records with a nonzero epoch are encrypted and become
      OpaqueHandshakeMessage, passed through untouched.
    - ChangeCipherSpec and alert records become PseudoHandshakeMessage.
    - Handshake records are decoded as fragments and reassembled, keyed by
      message sequence number, into HandshakeMessage objects.

    The input is our own output, not network data, so it is expected to be
    well-formed; violations trip assertions.
    """
    decoded: list[_AnyHandshakeMessage] = []
    by_seq: dict[int, HandshakeMessage] = {}
    for rec in records_untrusted(volley):
        # Neither encrypted records nor ChangeCipherSpec/alert records can be
        # re-fragmented, but they are all small enough that they never need to be.
        if rec.epoch_seqno & EPOCH_MASK:
            decoded.append(OpaqueHandshakeMessage(rec))
            continue
        if rec.content_type in (ContentType.change_cipher_spec, ContentType.alert):
            decoded.append(
                PseudoHandshakeMessage(rec.version, rec.content_type, rec.payload),
            )
            continue

        assert rec.content_type == ContentType.handshake
        frag = decode_handshake_fragment_untrusted(rec.payload)
        kind = HandshakeType(frag.msg_type)
        msg = by_seq.get(frag.msg_seq)
        if msg is None:
            # First fragment seen for this message: allocate a zeroed body of
            # the full message length and remember it for later fragments.
            msg = HandshakeMessage(
                rec.version,
                kind,
                frag.msg_seq,
                bytearray(frag.msg_len),
            )
            decoded.append(msg)
            by_seq[frag.msg_seq] = msg
        assert msg.msg_type == frag.msg_type
        assert msg.msg_seq == frag.msg_seq
        assert len(msg.body) == frag.msg_len

        start = frag.frag_offset
        msg.body[start : start + frag.frag_len] = frag.frag

    return decoded


class RecordEncoder:
    """Packs outgoing handshake messages into DTLS records, then into packets.

    Each record this encoder builds itself takes the next number from an
    internal counter, which starts at 0 or wherever set_first_record_number put
    it. OpaqueHandshakeMessage records are copied through with their original
    epoch_seqno and do not consume a number.
    """

    def __init__(self) -> None:
        self._record_seq = count()

    def set_first_record_number(self, n: int) -> None:
        """Restart the record counter so the next generated record is numbered *n*."""
        self._record_seq = count(n)

    def encode_volley(
        self,
        messages: Iterable[_AnyHandshakeMessage],
        mtu: int,
    ) -> list[bytearray]:
        """Encode *messages* into a list of non-empty packets of at most *mtu* bytes.

        Opaque and pseudo-handshake messages are each emitted as one record and
        are never split. A record that exactly fills the remaining space stays
        in the current packet. Real handshake messages are fragmented across as
        many records and packets as needed; an empty body still produces one
        zero-length fragment.

        Raises ValueError if *mtu* cannot fit a record header, a handshake
        header, and at least one byte of payload in an otherwise empty packet.
        Previously that case looped forever.
        """
        packets: list[bytearray] = []
        packet = bytearray()
        for message in messages:
            if isinstance(message, OpaqueHandshakeMessage):
                encoded = encode_record(message.record)
                # Flush only if there is something to flush and the record
                # doesn't fit in what's left, so we never emit an empty packet.
                if packet and len(packet) + len(encoded) > mtu:  # TODO: test this line
                    packets.append(packet)
                    packet = bytearray()
                packet += encoded
                # Opaque records can't be split, so this fails if a single
                # record is larger than mtu.
                assert len(packet) <= mtu
            elif isinstance(message, PseudoHandshakeMessage):
                record_len = RECORD_HEADER.size + len(message.payload)
                if packet and len(packet) + record_len > mtu:  # TODO: test this line
                    packets.append(packet)
                    packet = bytearray()
                packet += RECORD_HEADER.pack(
                    message.content_type,
                    message.record_version,
                    next(self._record_seq),
                    len(message.payload),
                )
                packet += message.payload
                assert len(packet) <= mtu
            else:
                msg_len_bytes = len(message.body).to_bytes(3, "big")
                frag_offset = 0
                frags_encoded = 0
                # If message.body is empty, then we still want to encode it in one
                # fragment, not zero.
                while frag_offset < len(message.body) or not frags_encoded:
                    # Payload bytes that still fit in this packet after both
                    # the record header and the handshake header.
                    space = (
                        mtu
                        - len(packet)
                        - RECORD_HEADER.size
                        - HANDSHAKE_MESSAGE_HEADER.size
                    )
                    if space <= 0:
                        if not packet:
                            # Even a fresh packet has no room, so flushing and
                            # retrying would spin forever emitting empty packets.
                            raise ValueError(
                                f"mtu {mtu} is too small to carry any handshake data",
                            )
                        packets.append(packet)
                        packet = bytearray()
                        continue
                    frag = message.body[frag_offset : frag_offset + space]
                    frag_offset_bytes = frag_offset.to_bytes(3, "big")
                    frag_len_bytes = len(frag).to_bytes(3, "big")
                    frag_offset += len(frag)

                    packet += RECORD_HEADER.pack(
                        ContentType.handshake,
                        message.record_version,
                        next(self._record_seq),
                        HANDSHAKE_MESSAGE_HEADER.size + len(frag),
                    )

                    packet += HANDSHAKE_MESSAGE_HEADER.pack(
                        message.msg_type,
                        msg_len_bytes,
                        message.msg_seq,
                        frag_offset_bytes,
                        frag_len_bytes,
                    )

                    packet += frag

                    frags_encoded += 1
                    assert len(packet) <= mtu

        if packet:
            packets.append(packet)

        return packets


# This bit requires implementing a bona fide cryptographic protocol, so even though it's
# a simple one let's take a moment to discuss the design.
#
# Our goal is to force new incoming handshakes that claim to be coming from a
# given ip:port to prove that they can also receive packets sent to that
# ip:port. (There's nothing in UDP to stop someone from forging the return
# address, and it's often used for stuff like DoS reflection attacks, where
# an attacker tries to trick us into sending data at some innocent victim.)
# For more details, see:
#
#    https://datatracker.ietf.org/doc/html/rfc6347#section-4.2.1
#
# To do this, when we receive an initial ClientHello, we calculate a magic
# cookie, and send it back as a HelloVerifyRequest. Then the client sends us a
# second ClientHello, this time with the magic cookie included, and after we
# check that this cookie is valid we go ahead and start the handshake proper.
#
# So the magic cookie needs the following properties:
# - No-one can forge it without knowing our secret key
# - It ensures that the ip, port, and ClientHello contents from the response
#   match those in the challenge
# - It expires after a short-ish period (so that if an attacker manages to steal one, it
#   won't be useful for long)
# - It doesn't require storing any peer-specific state on our side
#
# To do that, we take the ip/port/ClientHello data and compute an HMAC of them, using a
# secret key we generate on startup. We also include:
#
# - The current time (using Trio's clock), rounded down to a multiple of 30 seconds
# - A random salt
#
# Then the cookie is the salt and the HMAC digest concatenated together.
#
# When verifying a cookie, we use the salt + new ip/port/ClientHello data to recompute
# the HMAC digest, for both the current time and the current time minus 30 seconds, and
# if either of them match, we consider the cookie good.
#
# Including the rounded-off time like this means that each cookie is good for at least
# 30 seconds, and possibly as much as 60 seconds.
#
# The salt is probably not necessary -- I'm pretty sure that all it does is make it hard
# for an attacker to figure out when our clock ticks over a 30 second boundary. Which is
# probably pretty harmless? But it's easier to add the salt than to convince myself that
# it's *completely* harmless, so, salt it is.

# Length of one cookie "tick". valid_cookie accepts cookies made in the current
# tick or the one before it.
COOKIE_REFRESH_INTERVAL = 30  # seconds
# Length of the random HMAC key generated lazily for each listening endpoint.
KEY_BYTES = 32
# Hash algorithm passed to hmac.digest when computing cookies.
COOKIE_HASH = "sha256"
# Length of the random salt stored at the start of every cookie.
SALT_BYTES = 8
# 32 bytes was the maximum cookie length in DTLS 1.0. DTLS 1.2 raised it to 255. I doubt
# there are any DTLS 1.0 implementations still in the wild, but really 32 bytes is
# plenty, and it also gets rid of a confusing warning in Wireshark output.
#
# We truncate the cookie to 32 bytes, of which 8 bytes is salt, so that leaves 24 bytes
# of truncated HMAC = 192 bit security, which is still massive overkill. (TCP uses 32
# *bits* for this.) HMAC truncation is explicitly noted as safe in RFC 2104:
#   https://datatracker.ietf.org/doc/html/rfc2104#section-5
COOKIE_LENGTH = 32


def _current_cookie_tick() -> int:
    """Return the index of the COOKIE_REFRESH_INTERVAL window that Trio's
    clock is currently in (the time divided by the interval, truncated)."""
    now = trio.current_time()
    return int(now / COOKIE_REFRESH_INTERVAL)


# Simple deterministic and invertible serializer -- i.e., a useful tool for converting
# structured data into something we can cryptographically sign.
def _signable(*fields: bytes) -> bytes:
    out: list[bytes] = []
    for field in fields:
        out.extend((struct.pack("!Q", len(field)), field))
    return b"".join(out)


def _make_cookie(
    key: bytes,
    salt: bytes,
    tick: int,
    address: AddressFormat,
    client_hello_bits: bytes,
) -> bytes:
    """Compute the cookie for one (salt, tick, address, ClientHello) combination.

    The result is *salt* followed by an HMAC, keyed with *key*, over the salt,
    the tick, every part of *address* and *client_hello_bits*. The whole thing
    is truncated to COOKIE_LENGTH bytes.
    """
    assert len(salt) == SALT_BYTES
    assert len(key) == KEY_BYTES

    tick_field = struct.pack("!Q", tick)
    # The address parts are a mix of str and int, and there can be a varying
    # number of them, so encode each part as text and nest them all as one
    # length-prefixed field.
    encoded_parts = [str(part).encode() for part in address]
    address_field = _signable(*encoded_parts)
    message = _signable(salt, tick_field, address_field, client_hello_bits)

    digest = hmac.digest(key, message, COOKIE_HASH)
    full_cookie = salt + digest
    return full_cookie[:COOKIE_LENGTH]


def valid_cookie(
    key: bytes,
    cookie: bytes,
    address: AddressFormat,
    client_hello_bits: bytes,
) -> bool:
    """Return True if *cookie* is one we issued recently for this peer and ClientHello.

    The salt is read from the front of *cookie*. The cookie is then recomputed
    for the current tick and for the previous tick (clamped at 0), and
    accepted if it matches either one.
    """
    if len(cookie) <= SALT_BYTES:
        return False

    salt = cookie[:SALT_BYTES]
    tick = _current_cookie_tick()
    expected_now, expected_before = (
        _make_cookie(key, salt, candidate_tick, address, client_hello_bits)
        for candidate_tick in (tick, max(tick - 1, 0))
    )

    # Both constant-time comparisons always run, and their results are combined
    # with a non-short-circuiting '|' so timing doesn't reveal which one matched.
    matches_now = hmac.compare_digest(cookie, expected_now)
    matches_before = hmac.compare_digest(cookie, expected_before)
    return matches_now | matches_before


def challenge_for(
    key: bytes,
    address: AddressFormat,
    epoch_seqno: int,
    client_hello_bits: bytes,
) -> bytes:
    """Build a complete HelloVerifyRequest packet carrying a freshly-salted cookie.

    The record echoes *epoch_seqno* from the client's ClientHello record.
    """
    cookie = _make_cookie(
        key,
        os.urandom(SALT_BYTES),
        _current_cookie_tick(),
        address,
        client_hello_bits,
    )

    # HelloVerifyRequest body: 2-byte version, then the length-prefixed cookie.
    # DTLS 1.2 says this particular message carries the DTLS 1.0 version. The
    # spec also says the opposite, but that is a mistake, see
    # https://www.rfc-editor.org/errata/eid4103 . Nothing has been negotiated
    # yet, so DTLS 1.0 is used at the record layer too.
    body = b"".join((ProtocolVersion.DTLS10, bytes([len(cookie)]), cookie))

    # The RFC says to copy the client's record number; an erratum says the
    # handshake message number. OpenSSL copies back the record sequence number
    # and always uses message seq 0, and we do the same.
    fragment = HandshakeFragment(
        msg_type=HandshakeType.hello_verify_request,
        msg_len=len(body),
        msg_seq=0,
        frag_offset=0,
        frag_len=len(body),
        frag=body,
    )
    return encode_record(
        Record(
            ContentType.handshake,
            ProtocolVersion.DTLS10,
            epoch_seqno,
            encode_handshake_fragment(fragment),
        ),
    )


_T = TypeVar("_T")


class _Queue(Generic[_T]):
    """Holds both ends of a trio memory channel: ``s`` (send) and ``r`` (receive)."""

    def __init__(self, incoming_packets_buffer: int | float) -> None:  # noqa: PYI041
        # incoming_packets_buffer is passed straight through as the channel's
        # buffer size.
        channel_pair = trio.open_memory_channel[_T](incoming_packets_buffer)
        self.s, self.r = channel_pair


def _read_loop(read_fn: Callable[[int], bytes]) -> bytes:
    """Drain *read_fn* until OpenSSL has nothing more to give.

    Each call asks for up to 2**14 bytes (the maximum TLS record size). The loop
    stops when *read_fn* raises ``SSL.WantReadError``, and returns everything
    read before that, concatenated.
    """
    max_record_size = 2**14
    pieces: list[bytes] = []
    while True:
        try:
            piece = read_fn(max_record_size)
        except SSL.WantReadError:
            return b"".join(pieces)
        pieces.append(piece)


async def handle_client_hello_untrusted(
    endpoint: DTLSEndpoint,
    address: AddressFormat,
    packet: bytes,
) -> None:
    """Process a packet from *address* that looks like a ClientHello.

    - If the endpoint has no listening context, or the packet doesn't decode as
      a ClientHello, the packet is silently dropped.
    - If the cookie is missing or invalid, reply with a HelloVerifyRequest
      challenge (send errors are ignored) and keep no per-peer state.
    - If the cookie is valid, create a DTLSChannel and feed it the ClientHello.
      Register it in ``endpoint._streams``, replacing any existing channel for
      the same address, unless this packet merely duplicates the ClientHello
      that created that channel. Finally, push the channel onto the
      incoming-connections queue.
    """
    # it's trivial to write a simple function that directly calls this to
    # get code coverage, but it should maybe:
    # 1. be removed
    # 2. be asserted
    # 3. Write a complicated test case where this happens "organically"
    if endpoint._listening_context is None:  # pragma: no cover
        return

    try:
        epoch_seqno, cookie, bits = decode_client_hello_untrusted(packet)
    except BadPacket:
        return

    # The cookie HMAC key is generated lazily, on the first ClientHello.
    if endpoint._listening_key is None:
        endpoint._listening_key = os.urandom(KEY_BYTES)

    if not valid_cookie(endpoint._listening_key, cookie, address, bits):
        # Stateless challenge: the peer must echo the cookie back to prove it
        # can receive packets at `address`.
        challenge_packet = challenge_for(
            endpoint._listening_key,
            address,
            epoch_seqno,
            bits,
        )
        try:
            async with endpoint._send_lock:
                await endpoint.socket.sendto(challenge_packet, address)
        except (OSError, trio.ClosedResourceError):
            pass
    else:
        # We got a real, valid ClientHello!
        stream = DTLSChannel._create(endpoint, address, endpoint._listening_context)
        # Our HelloVerifyRequest had some sequence number. We need our future sequence
        # numbers to be larger than it, so our peer knows that our future records aren't
        # stale/duplicates. But, we don't know what this sequence number was. What we do
        # know is:
        # - the HelloVerifyRequest seqno was copied from the initial ClientHello
        # - the new ClientHello has a higher seqno than the initial ClientHello
        # So, if we copy the new ClientHello's seqno into our first real handshake
        # record and increment from there, that should work.
        stream._record_encoder.set_first_record_number(epoch_seqno)
        # Process the ClientHello
        try:
            stream._ssl.bio_write(packet)
            stream._ssl.DTLSv1_listen()
        except SSL.Error:  # pragma: no cover
            # ...OpenSSL didn't like it, so I guess we didn't have a valid ClientHello
            # after all.
            return

        # Check if we have an existing association
        old_stream = endpoint._streams.get(address)
        if old_stream is not None:
            if old_stream._client_hello == (cookie, bits):
                # ...This was just a duplicate of the last ClientHello, so never mind.
                return
            else:
                # Ok, this *really is* a new handshake; the old stream should go away.
                old_stream._set_replaced()
        # Remember which ClientHello created this stream, so duplicates can be
        # recognized later.
        stream._client_hello = (cookie, bits)
        endpoint._streams[address] = stream
        endpoint._incoming_connections_q.s.send_nowait(stream)


async def dtls_receive_loop(
    endpoint_ref: ReferenceType[DTLSEndpoint],
    sock: SocketType,
) -> None:
    """Receive datagrams on *sock* and route each one for the endpoint.

    Between packets only a weak reference to the endpoint is held: the strong
    reference is dropped after each packet. So this loop does not keep the
    endpoint alive. It returns the next time a packet arrives after the
    endpoint has been garbage-collected, or when the socket turns out to be
    closed.

    Routing:

    - ClientHello-looking packets go to handle_client_hello_untrusted.
    - Packets from a known peer go into that channel's queue. If the queue is
      full, the channel's drop counter is incremented instead.
    - Epoch-0 packets that arrive after that channel's handshake completed
      trigger a resend of our final handshake volley.
    - Everything else is dropped.
    """
    try:
        while True:
            try:
                packet, address = await sock.recvfrom(MAX_UDP_PACKET_SIZE)
            except OSError as exc:
                if exc.errno == errno.ECONNRESET:
                    # Windows only: "On a UDP-datagram socket [ECONNRESET]
                    # indicates a previous send operation resulted in an ICMP Port
                    # Unreachable message" -- https://docs.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recvfrom
                    #
                    # This is totally useless -- there's nothing we can do with this
                    # information. So we just ignore it and retry the recv.
                    continue
                else:
                    raise
            # Take a strong reference only for the duration of this packet.
            endpoint = endpoint_ref()
            try:
                if endpoint is None:
                    return
                if is_client_hello_untrusted(packet):
                    await handle_client_hello_untrusted(endpoint, address, packet)
                elif address in endpoint._streams:
                    stream = endpoint._streams[address]
                    if stream._did_handshake and part_of_handshake_untrusted(packet):
                        # The peer just sent us more handshake messages, that aren't a
                        # ClientHello, and we thought the handshake was done. Some of
                        # the packets that we sent to finish the handshake must have
                        # gotten lost. So re-send them. We do this directly here instead
                        # of just putting it into the queue and letting the receiver do
                        # it, because there's no guarantee that anyone is reading from
                        # the queue, because we think the handshake is done!
                        await stream._resend_final_volley()
                    else:
                        try:
                            # mypy for some reason cannot determine type of _q
                            stream._q.s.send_nowait(packet)  # type:ignore[has-type]
                        except trio.WouldBlock:
                            stream._packets_dropped_in_trio += 1
                else:
                    # Drop packet
                    pass
            finally:
                # Release the strong reference so the endpoint can be collected
                # while we wait in recvfrom.
                del endpoint
    except trio.ClosedResourceError:
        # socket was closed
        return
    except OSError as exc:
        if exc.errno in (errno.EBADF, errno.ENOTSOCK):
            # socket was closed
            return
        else:  # pragma: no cover
            # ??? shouldn't happen
            raise


@attrs.frozen
class DTLSChannelStatistics:
    """Currently this has only one attribute:

    - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of
      incoming packets from this peer that Trio successfully received from the
      network, but then got dropped because the internal channel buffer was full. If
      this is non-zero, then you might want to call ``receive`` more often, or use a
      larger ``incoming_packets_buffer``, or just not worry about it because your
      UDP-based protocol should be able to handle the occasional lost packet, right?

    """

    # Packets from this peer dropped because the channel buffer was full
    # (described in detail in the class docstring).
    incoming_packets_dropped_in_trio: int


@final
class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor):
    """A DTLS connection.

    This class has no public constructor – you get instances by calling
    `DTLSEndpoint.serve` or `~DTLSEndpoint.connect`.

    .. attribute:: endpoint

       The `DTLSEndpoint` that this connection is using.

    .. attribute:: peer_address

       The IP/port of the remote peer that this connection is associated with.

    """

    def __init__(
        self,
        endpoint: DTLSEndpoint,
        peer_address: AddressFormat,
        ctx: SSL.Context,
    ) -> None:
        self.endpoint = endpoint
        self.peer_address = peer_address
        # Bumped by the endpoint's receive loop whenever our incoming queue is full
        # and a packet has to be discarded; surfaced through statistics().
        self._packets_dropped_in_trio = 0
        self._client_hello = None
        self._did_handshake = False
        # These are mandatory for all DTLS connections. OP_NO_QUERY_MTU is required to
        # stop openssl from trying to query the memory BIO's MTU and then breaking, and
        # OP_NO_RENEGOTIATION disables renegotiation, which is too complex for us to
        # support and isn't useful anyway -- especially for DTLS where it's equivalent
        # to just performing a new handshake.
        ctx.set_options(
            SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION,  # type: ignore[attr-defined]
        )
        self._ssl = SSL.Connection(ctx)
        self._handshake_mtu = 0
        # This calls self._ssl.set_ciphertext_mtu, which is important, because if you
        # don't call it then openssl doesn't work.
        self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket))
        # Set when another channel takes over our peer address (see
        # DTLSEndpoint.connect); after that, I/O fails with BrokenResourceError.
        self._replaced = False
        self._closed = False
        # Incoming raw (still encrypted) packets from this peer, pushed by the
        # endpoint's receive loop and consumed by do_handshake/receive.
        self._q = _Queue[bytes](endpoint.incoming_packets_buffer)
        self._handshake_lock = trio.Lock()
        self._record_encoder: RecordEncoder = RecordEncoder()

        # The last flight of handshake messages we sent. Kept so it can be re-sent if
        # the peer shows it never received it (see _resend_final_volley).
        self._final_volley: list[_AnyHandshakeMessage] = []

    def _set_replaced(self) -> None:
        """Mark this channel as superseded by a newer one for the same peer."""
        self._replaced = True
        # Any packets we already received could maybe possibly still be processed, but
        # there are no more coming. So we close this on the sender side.
        self._q.s.close()

    def _check_replaced(self) -> None:
        """Raise `trio.BrokenResourceError` if this channel has been replaced."""
        if self._replaced:
            raise trio.BrokenResourceError(
                "peer tore down this connection to start a new one",
            )

    # XX on systems where we can (maybe just Linux?) take advantage of the kernel's PMTU
    # estimate

    # XX should we send close-notify when closing? It seems particularly pointless for
    # DTLS where packets are all independent and can be lost anyway. We do at least need
    # to handle receiving it properly though, which might be easier if we send it...

    def close(self) -> None:
        """Close this connection.

        `DTLSChannel`\\s don't actually own any OS-level resources – the
        socket is owned by the `DTLSEndpoint`, not the individual connections. So
        you don't really *have* to call this. But it will interrupt any other tasks
        calling `receive` with a `ClosedResourceError`, and cause future attempts to use
        this connection to fail.

        You can also use this object as a synchronous or asynchronous context manager.

        """
        if self._closed:
            return
        self._closed = True
        # Only unregister ourselves if we are still the channel registered for this
        # address -- a replacement channel may already have taken the slot.
        if self.endpoint._streams.get(self.peer_address) is self:
            del self.endpoint._streams[self.peer_address]
        # Will wake any tasks waiting on self._q.get with a
        # ClosedResourceError
        self._q.r.close()

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        return self.close()

    async def aclose(self) -> None:
        """Close this connection, but asynchronously.

        This is included to satisfy the `trio.abc.Channel` contract. It's
        identical to `close`, but async.

        """
        self.close()
        await trio.lowlevel.checkpoint()

    async def _send_volley(self, volley_messages: list[_AnyHandshakeMessage]) -> None:
        """Encode *volley_messages* into packets and send them all to the peer.

        The messages are encoded using the current ``_handshake_mtu`` as the size
        limit. The endpoint's send lock is taken per packet, so other channels
        sharing the socket can interleave their own sends between ours.
        """
        packets = self._record_encoder.encode_volley(
            volley_messages,
            self._handshake_mtu,
        )
        for packet in packets:
            async with self.endpoint._send_lock:
                await self.endpoint.socket.sendto(packet, self.peer_address)

    async def _resend_final_volley(self) -> None:
        """Re-send the last handshake flight (used when the peer apparently lost it)."""
        await self._send_volley(self._final_volley)

    async def do_handshake(self, *, initial_retransmit_timeout: float = 1.0) -> None:
        """Perform the handshake.

        Calling this is optional – if you don't, then it will be automatically called
        the first time you call `send` or `receive`. But calling it explicitly can be
        useful in case you want to control the retransmit timeout, use a cancel scope to
        place an overall timeout on the handshake, or catch errors from the handshake
        specifically.

        It's safe to call this multiple times, or call it simultaneously from multiple
        tasks – the first call will perform the handshake, and the rest will be no-ops.

        Args:

          initial_retransmit_timeout (float): Since UDP is an unreliable protocol, it's
            possible that some of the packets we send during the handshake will get
            lost. To handle this, DTLS uses a timer to automatically retransmit
            handshake packets that don't receive a response. This lets you set the
            timeout we use to detect packet loss. Ideally, it should be set to ~1.5
            times the round-trip time to your peer, but 1 second is a reasonable
            default. There's `some useful guidance here
            <https://tlswg.org/dtls13-spec/draft-ietf-tls-dtls13.html#name-timer-values>`__.

            This is the *initial* timeout, because if packets keep being lost then Trio
            will automatically back off to longer values, to avoid overloading the
            network.

        Raises:
          trio.ClosedResourceError: if this channel was closed before the handshake
            completed.
          trio.BrokenResourceError: if the peer replaced this connection with a new
            one.

        """
        async with self._handshake_lock:
            if self._did_handshake:
                return
            # A closed channel must not start talking to the peer: fail up front, the
            # same way send() does, instead of transmitting a whole handshake volley
            # and only then failing on the closed receive queue.
            if self._closed:
                raise trio.ClosedResourceError

            timeout = initial_retransmit_timeout
            volley_messages: list[_AnyHandshakeMessage] = []
            volley_failed_sends = 0

            def read_volley() -> list[_AnyHandshakeMessage]:
                # Drain whatever openssl wants to send and decode it into handshake
                # messages.
                volley_bytes = _read_loop(self._ssl.bio_read)
                new_volley_messages = decode_volley_trusted(volley_bytes)
                if (
                    new_volley_messages
                    and volley_messages
                    and isinstance(new_volley_messages[0], HandshakeMessage)
                    and isinstance(volley_messages[0], HandshakeMessage)
                    and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq
                ):
                    # openssl decided to retransmit; discard because we handle
                    # retransmits ourselves
                    return []
                else:
                    return new_volley_messages

            # If we're a client, we send the initial volley. If we're a server, then
            # the initial ClientHello has already been inserted into self._ssl's
            # read BIO. So either way, we start by generating a new volley.
            with contextlib.suppress(SSL.WantReadError):
                self._ssl.do_handshake()
            volley_messages = read_volley()
            # If we don't have messages to send in our initial volley, then something
            # has gone very wrong. (I'm not sure this can actually happen without an
            # error from OpenSSL, but we check just in case.)
            if not volley_messages:  # pragma: no cover
                raise SSL.Error("something wrong with peer's ClientHello")

            while True:
                # -- at this point, we need to either send or re-send a volley --
                assert volley_messages
                self._check_replaced()
                await self._send_volley(volley_messages)
                # -- then this is where we wait for a reply --
                self.endpoint._ensure_receive_loop()
                with trio.move_on_after(timeout) as cscope:
                    async for packet in self._q.r:
                        self._ssl.bio_write(packet)
                        try:
                            self._ssl.do_handshake()
                        # We ignore generic SSL.Error here, because you can get those
                        # from random invalid packets
                        except (SSL.WantReadError, SSL.Error):
                            pass
                        else:
                            # No exception -> the handshake is done, and we can
                            # switch into data transfer mode.
                            self._did_handshake = True
                            # Might be empty, but that's ok -- we'll just send no
                            # packets.
                            self._final_volley = read_volley()
                            await self._send_volley(self._final_volley)
                            return
                        maybe_volley = read_volley()
                        if maybe_volley:
                            if (
                                isinstance(maybe_volley[0], PseudoHandshakeMessage)
                                and maybe_volley[0].content_type == ContentType.alert
                            ):  # TODO: test this line
                                # we're sending an alert (e.g. due to a corrupted
                                # packet). We want to send it once, but don't save it to
                                # retransmit -- keep the last volley as the current
                                # volley.
                                await self._send_volley(maybe_volley)
                            else:
                                # We managed to get all of the peer's volley and
                                # generate a new one ourselves! break out of the 'for'
                                # loop and restart the timer.
                                volley_messages = maybe_volley
                                # "Implementations SHOULD retain the current timer value
                                # until a transmission without loss occurs, at which
                                # time the value may be reset to the initial value."
                                if volley_failed_sends == 0:
                                    timeout = initial_retransmit_timeout
                                volley_failed_sends = 0
                                break
                    else:
                        # The queue's send side was closed, which only happens in
                        # _set_replaced(); report that to the caller.
                        assert self._replaced
                        self._check_replaced()
                if cscope.cancelled_caught:
                    # Timeout expired. Double timeout for backoff, with a limit of 60
                    # seconds (this matches what openssl does, and also the
                    # recommendation in draft-ietf-tls-dtls13).
                    timeout = min(2 * timeout, 60.0)
                    volley_failed_sends += 1
                    if volley_failed_sends == 2:
                        # We tried sending this twice and they both failed. Maybe our
                        # PMTU estimate is wrong? Let's try dropping it to the minimum
                        # and hope that helps.
                        self._handshake_mtu = min(
                            self._handshake_mtu,
                            worst_case_mtu(self.endpoint.socket),
                        )

    async def send(self, data: bytes) -> None:
        """Send a packet of data, securely."""

        if self._closed:
            raise trio.ClosedResourceError
        if not data:
            raise ValueError("openssl doesn't support sending empty DTLS packets")
        if not self._did_handshake:
            await self.do_handshake()
        self._check_replaced()
        self._ssl.write(data)
        async with self.endpoint._send_lock:
            await self.endpoint.socket.sendto(
                _read_loop(self._ssl.bio_read),
                self.peer_address,
            )

    async def receive(self) -> bytes:
        """Fetch the next packet of data from this connection's peer, waiting if
        necessary.

        This is safe to call from multiple tasks simultaneously, in case you have some
        reason to do that. And more importantly, it's cancellation-safe, meaning that
        cancelling a call to `receive` will never cause a packet to be lost or corrupt
        the underlying connection.

        """
        if not self._did_handshake:
            await self.do_handshake()
        # If the packet isn't really valid, then openssl can decode it to the empty
        # string (e.g. b/c it's a late-arriving handshake packet, or a duplicate copy of
        # a data packet). Skip over these instead of returning them.
        while True:
            try:
                packet = await self._q.r.receive()
            except trio.EndOfChannel:
                # Only _set_replaced() closes the send side of our queue.
                assert self._replaced
                self._check_replaced()
            self._ssl.bio_write(packet)
            cleartext = _read_loop(self._ssl.read)
            if cleartext:
                return cleartext

    def set_ciphertext_mtu(self, new_mtu: int) -> None:
        """Tells Trio the `largest amount of data that can be sent in a single packet to
        this peer <https://en.wikipedia.org/wiki/Maximum_transmission_unit>`__.

        Trio doesn't actually enforce this limit – if you pass a huge packet to `send`,
        then we'll dutifully encrypt it and attempt to send it. But calling this method
        does have two useful effects:

        - If called before the handshake is performed, then Trio will automatically
          fragment handshake messages to fit within the given MTU. It also might
          fragment them even smaller, if it detects signs of packet loss, so setting
          this should never be necessary to make a successful connection. But, the
          packet loss detection only happens after multiple timeouts have expired, so if
          you have reason to believe that a smaller MTU is required, then you can set
          this to skip those timeouts and establish the connection more quickly.

        - It changes the value returned from `get_cleartext_mtu`. So if you have some
          kind of estimate of the network-level MTU, then you can use this to figure out
          how much overhead DTLS will need for hashes/padding/etc., and how much space
          you have left for your application data.

        The MTU here is measuring the largest UDP *payload* you think can be sent, the
        amount of encrypted data that can be handed to the operating system in a single
        call to `send`. It should *not* include IP/UDP headers. Note that OS estimates
        of the MTU often are link-layer MTUs, so you have to subtract off 28 bytes on
        IPv4 and 48 bytes on IPv6 to get the ciphertext MTU.

        By default, Trio assumes an MTU of 1472 bytes on IPv4, and 1452 bytes on IPv6,
        which correspond to the common Ethernet MTU of 1500 bytes after accounting for
        IP/UDP overhead.

        """
        self._handshake_mtu = new_mtu
        self._ssl.set_ciphertext_mtu(new_mtu)

    def get_cleartext_mtu(self) -> int:
        """Returns the largest number of bytes that you can pass in a single call to
        `send` while still fitting within the network-level MTU.

        See `set_ciphertext_mtu` for more details.

        """
        if not self._did_handshake:
            raise trio.NeedHandshakeError
        return self._ssl.get_cleartext_mtu()  # type: ignore[no-any-return]

    def statistics(self) -> DTLSChannelStatistics:
        """Returns a `DTLSChannelStatistics` object with statistics about this connection."""
        return DTLSChannelStatistics(self._packets_dropped_in_trio)


@final
class DTLSEndpoint:
    """A DTLS endpoint.

    A single UDP socket can handle arbitrarily many DTLS connections simultaneously,
    acting as a client or server as needed. A `DTLSEndpoint` object holds a UDP socket
    and manages these connections, which are represented as `DTLSChannel` objects.

    Args:
      socket: (trio.socket.SocketType): A ``SOCK_DGRAM`` socket. If you want to accept
        incoming connections in server mode, then you should probably bind the socket to
        some known port.
      incoming_packets_buffer (int): Each `DTLSChannel` using this socket has its own
        buffer that holds incoming packets until you call `~DTLSChannel.receive` to read
        them. This lets you adjust the size of this buffer. `~DTLSChannel.statistics`
        lets you check if the buffer has overflowed.

    .. attribute:: socket
                   incoming_packets_buffer

       Both constructor arguments are also exposed as attributes, in case you need to
       access them later.

    """

    def __init__(
        self,
        socket: SocketType,
        *,
        incoming_packets_buffer: int = 10,
    ) -> None:
        # For __del__: this must be the very first thing we do, and it only flips to
        # True at the very end. Anything below can raise (the PyOpenSSL import, the
        # socket-type check, current_trio_token() when called outside of Trio), and
        # __del__ must not then touch attributes that were never assigned.
        self._initialized: bool = False

        # We do this lazily on first construction, so only people who actually use DTLS
        # have to install PyOpenSSL.
        global SSL
        from OpenSSL import SSL

        if socket.type != trio.socket.SOCK_DGRAM:
            raise ValueError("DTLS requires a SOCK_DGRAM socket")
        self.socket: SocketType = socket

        self.incoming_packets_buffer = incoming_packets_buffer
        self._token = trio.lowlevel.current_trio_token()
        # We don't need to track handshaking vs non-handshake connections
        # separately. We only keep one connection per remote address; as soon
        # as a peer provides a valid cookie, we can immediately tear down the
        # old connection.
        # {remote address: DTLSChannel}
        self._streams: WeakValueDictionary[AddressFormat, DTLSChannel] = (
            WeakValueDictionary()
        )
        self._listening_context: SSL.Context | None = None
        self._listening_key: bytes | None = None
        self._incoming_connections_q = _Queue[DTLSChannel](float("inf"))
        # Serializes sendto() calls on the shared socket across all channels.
        self._send_lock = trio.Lock()
        self._closed = False
        self._receive_loop_spawned = False
        # Everything __del__ relies on now exists.
        self._initialized = True

    def _ensure_receive_loop(self) -> None:
        """Spawn the background receive loop for this socket, at most once."""
        # We have to spawn this lazily, because on Windows it will immediately error out
        # if the socket isn't already bound -- which for clients might not happen until
        # after we send our first packet.
        if not self._receive_loop_spawned:
            trio.lowlevel.spawn_system_task(
                dtls_receive_loop,
                # A weak reference, so the background task doesn't keep this endpoint
                # alive (and __del__ can still run).
                weakref.ref(self),
                self.socket,
            )
            self._receive_loop_spawned = True

    def __del__(self) -> None:
        # Do nothing if this object was never fully constructed (__init__ raised, or
        # never ran): the attributes used below may not exist.
        if not getattr(self, "_initialized", False):
            return
        # Close the socket in Trio context (if our Trio context still exists), so that
        # the background task gets notified about the closure and can exit.
        if not self._closed:
            with contextlib.suppress(RuntimeError):
                self._token.run_sync_soon(self.close)
            # Do this last, because it might raise an exception
            warnings.warn(
                f"unclosed DTLS endpoint {self!r}",
                ResourceWarning,
                source=self,
                stacklevel=1,
            )

    def close(self) -> None:
        """Close this socket, and all associated DTLS connections.

        This object can also be used as a context manager.

        """
        self._closed = True
        self.socket.close()
        # Copy first: each channel's close() removes itself from self._streams.
        for stream in list(self._streams.values()):
            stream.close()
        self._incoming_connections_q.s.close()

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        return self.close()

    def _check_closed(self) -> None:
        """Raise `trio.ClosedResourceError` if `close` has been called."""
        if self._closed:
            raise trio.ClosedResourceError

    async def serve(
        self,
        ssl_context: SSL.Context,
        async_fn: Callable[[DTLSChannel, Unpack[PosArgsT]], Awaitable[object]],
        *args: Unpack[PosArgsT],
        task_status: trio.TaskStatus[None] = trio.TASK_STATUS_IGNORED,
    ) -> None:
        """Listen for incoming connections, and spawn a handler for each using an
        internal nursery.

        Similar to `~trio.serve_tcp`, this function never returns until cancelled, or
        the `DTLSEndpoint` is closed and all handlers have exited.

        Usage commonly looks like::

            async def handler(dtls_channel):
                ...

            async with trio.open_nursery() as nursery:
                await nursery.start(dtls_endpoint.serve, ssl_context, handler)
                # ... do other things here ...

        The ``dtls_channel`` passed into the handler function has already performed the
        "cookie exchange" part of the DTLS handshake, so the peer address is
        trustworthy. But the actual cryptographic handshake doesn't happen until you
        start using it, giving you a chance for any last minute configuration, and the
        option to catch and handle handshake errors.

        Args:
          ssl_context (OpenSSL.SSL.Context): The PyOpenSSL context object to use for
            incoming connections.
          async_fn: The handler function that will be invoked for each incoming
            connection.
          *args: Additional arguments to pass to the handler function.

        """
        self._check_closed()
        if self._listening_context is not None:
            raise trio.BusyResourceError("another task is already listening")
        try:
            self.socket.getsockname()
        except OSError:  # TODO: test this line
            raise RuntimeError(
                "DTLS socket must be bound before it can serve",
            ) from None
        self._ensure_receive_loop()
        # We do cookie verification ourselves, so tell OpenSSL not to worry about it.
        # (See also _inject_client_hello_untrusted.)
        ssl_context.set_cookie_verify_callback(lambda *_: True)
        try:
            self._listening_context = ssl_context
            task_status.started()

            async def handler_wrapper(stream: DTLSChannel) -> None:
                # Each channel is closed when its handler finishes, however it exits.
                with stream:
                    await async_fn(stream, *args)

            async with trio.open_nursery() as nursery:
                async for stream in self._incoming_connections_q.r:  # pragma: no branch
                    nursery.start_soon(handler_wrapper, stream)
        finally:
            self._listening_context = None

    def connect(
        self,
        address: tuple[str, int],
        ssl_context: SSL.Context,
    ) -> DTLSChannel:
        """Initiate an outgoing DTLS connection.

        Notice that this is a synchronous method. That's because it doesn't actually
        initiate any I/O – it just sets up a `DTLSChannel` object. The actual handshake
        doesn't occur until you start using the `DTLSChannel`. This gives you a chance
        to do further configuration first, like setting MTU etc.

        Args:
          address: The address to connect to. Usually a (host, port) tuple, like
            ``("127.0.0.1", 12345)``.
          ssl_context (OpenSSL.SSL.Context): The PyOpenSSL context object to use for
            this connection.

        Returns:
          DTLSChannel

        """
        # it would be nice if we could detect when 'address' is our own endpoint (a
        # loopback connection), because that can't work
        # but I don't see how to do it reliably
        self._check_closed()
        channel = DTLSChannel._create(self, address, ssl_context)
        channel._ssl.set_connect_state()
        # Only one channel per remote address: any existing one is marked replaced so
        # its users get BrokenResourceError.
        old_channel = self._streams.get(address)
        if old_channel is not None:
            old_channel._set_replaced()
        self._streams[address] = channel
        return channel
