import struct
from enum import IntEnum

from types import ModuleType
from typing import Optional

from .KDF import _HKDF_extract, _HKDF_expand
from .DH import key_agreement, import_x25519_public_key, import_x448_public_key
from Crypto.Util.strxor import strxor
from Crypto.PublicKey import ECC
from Crypto.PublicKey.ECC import EccKey
from Crypto.Hash import SHA256, SHA384, SHA512
from Crypto.Cipher import AES, ChaCha20_Poly1305


class MODE(IntEnum):
    """HPKE modes"""
    BASE = 0x00
    PSK = 0x01
    AUTH = 0x02
    AUTH_PSK = 0x03


class AEAD(IntEnum):
    """Authenticated Encryption with Associated Data (AEAD) Functions"""
    AES128_GCM = 0x0001
    AES256_GCM = 0x0002
    CHACHA20_POLY1305 = 0x0003


class DeserializeError(ValueError):
    pass

class MessageLimitReachedError(ValueError):
    pass

# CURVE to (KEM ID, KDF ID, HASH)
_Curve_Config = {
  "NIST P-256": (0x0010, 0x0001, SHA256),
  "NIST P-384": (0x0011, 0x0002, SHA384),
  "NIST P-521": (0x0012, 0x0003, SHA512),
  "Curve25519": (0x0020, 0x0001, SHA256),
  "Curve448":   (0x0021, 0x0003, SHA512),
}


def _labeled_extract(salt: bytes,
                     label: bytes,
                     ikm: bytes,
                     suite_id: bytes,
                     hashmod: ModuleType):
    labeled_ikm = b"HPKE-v1" + suite_id + label + ikm
    return _HKDF_extract(salt, labeled_ikm, hashmod)


def _labeled_expand(prk: bytes,
                    label: bytes,
                    info: bytes,
                    L: int,
                    suite_id: bytes,
                    hashmod: ModuleType):
    labeled_info = struct.pack('>H', L) + b"HPKE-v1" + suite_id + \
                   label + info
    return _HKDF_expand(prk, labeled_info, L, hashmod)


def _extract_and_expand(dh: bytes,
                        kem_context: bytes,
                        suite_id: bytes,
                        hashmod: ModuleType):
    Nsecret = hashmod.digest_size

    eae_prk = _labeled_extract(b"",
                               b"eae_prk",
                               dh,
                               suite_id,
                               hashmod)

    shared_secret = _labeled_expand(eae_prk,
                                    b"shared_secret",
                                    kem_context,
                                    Nsecret,
                                    suite_id,
                                    hashmod)
    return shared_secret


class HPKE_Cipher:

    def __init__(self,
                 receiver_key: EccKey,
                 enc: Optional[bytes],
                 sender_key: Optional[EccKey],
                 psk_pair: tuple[bytes, bytes],
                 info: bytes,
                 aead_id: AEAD,
                 mode: MODE):

        self.enc: bytes = b'' if enc is None else enc
        """The encapsulated session key."""

        self._verify_psk_inputs(mode, psk_pair)

        self._curve = receiver_key.curve
        self._aead_id = aead_id
        self._mode = mode

        try:
            self._kem_id, \
             self._kdf_id, \
             self._hashmod = _Curve_Config[self._curve]
        except KeyError as ke:
            raise ValueError("Curve {} is not supported by HPKE".format(self._curve)) from ke

        self._Nk = 16 if self._aead_id == AEAD.AES128_GCM else 32
        self._Nn = 12
        self._Nt = 16
        self._Nh = self._hashmod.digest_size

        self._encrypt = not receiver_key.has_private()

        if self._encrypt:
            # SetupBaseS (encryption)
            if enc is not None:
                raise ValueError("Parameter 'enc' cannot be an input  when sealing")
            shared_secret, self.enc = self._encap(receiver_key,
                                                  self._kem_id,
                                                  self._hashmod,
                                                  sender_key)
        else:
            # SetupBaseR (decryption)
            if enc is None:
                raise ValueError("Parameter 'enc' required when unsealing")
            shared_secret = self._decap(enc,
                                        receiver_key,
                                        self._kem_id,
                                        self._hashmod,
                                        sender_key)

        self._sequence = 0
        self._max_sequence = (1 << (8 * self._Nn)) - 1

        self._key, \
            self._base_nonce, \
            self._export_secret = self._key_schedule(shared_secret,
                                                     info,
                                                     *psk_pair)

    @staticmethod
    def _encap(receiver_key: EccKey,
               kem_id: int,
               hashmod: ModuleType,
               sender_key: Optional[EccKey] = None,
               eph_key: Optional[EccKey] = None):

        assert (sender_key is None) or sender_key.has_private()
        assert (eph_key is None) or eph_key.has_private()

        if eph_key is None:
            eph_key = ECC.generate(curve=receiver_key.curve)
        enc = eph_key.public_key().export_key(format='raw')

        pkRm = receiver_key.public_key().export_key(format='raw')
        kem_context = enc + pkRm
        extra_param = {}
        if sender_key:
            kem_context += sender_key.public_key().export_key(format='raw')
            extra_param = {'static_priv': sender_key}

        suite_id = b"KEM" + struct.pack('>H', kem_id)

        def kdf(dh,
                kem_context=kem_context,
                suite_id=suite_id,
                hashmod=hashmod):
            return _extract_and_expand(dh, kem_context, suite_id, hashmod)

        shared_secret = key_agreement(eph_priv=eph_key,
                                      static_pub=receiver_key,
                                      kdf=kdf,
                                      **extra_param)
        return shared_secret, enc

    @staticmethod
    def _decap(enc: bytes,
               receiver_key: EccKey,
               kem_id: int,
               hashmod: ModuleType,
               sender_key: Optional[EccKey] = None):

        assert receiver_key.has_private()

        try:
            if receiver_key.curve == 'Curve25519':
                pkE = import_x25519_public_key(enc)
            elif receiver_key.curve == 'Curve448':
                pkE = import_x448_public_key(enc)
            else:
                pkE = ECC.import_key(enc, curve_name=receiver_key.curve)
        except ValueError as ve:
            raise DeserializeError("'enc' is not a valid encapsulated HPKE key") from ve

        pkRm = receiver_key.public_key().export_key(format='raw')
        kem_context = enc + pkRm
        extra_param = {}
        if sender_key:
            kem_context += sender_key.public_key().export_key(format='raw')
            extra_param = {'static_pub': sender_key}

        suite_id = b"KEM" + struct.pack('>H', kem_id)

        def kdf(dh,
                kem_context=kem_context,
                suite_id=suite_id,
                hashmod=hashmod):
            return _extract_and_expand(dh, kem_context, suite_id, hashmod)

        shared_secret = key_agreement(eph_pub=pkE,
                                      static_priv=receiver_key,
                                      kdf=kdf,
                                      **extra_param)
        return shared_secret

    @staticmethod
    def _verify_psk_inputs(mode: MODE, psk_pair: tuple[bytes, bytes]):
        psk_id, psk = psk_pair

        if (psk == b'') ^ (psk_id == b''):
            raise ValueError("Inconsistent PSK inputs")

        if (psk == b''):
            if mode in (MODE.PSK, MODE.AUTH_PSK):
                raise ValueError(f"PSK is required with mode {mode.name}")
        else:
            if len(psk) < 32:
                raise ValueError("PSK must be at least 32 byte long")
            if mode in (MODE.BASE, MODE.AUTH):
                raise ValueError("PSK is not compatible with this mode")

    def _key_schedule(self,
                      shared_secret: bytes,
                      info: bytes,
                      psk_id: bytes,
                      psk: bytes):

        suite_id = b"HPKE" + struct.pack('>HHH',
                                         self._kem_id,
                                         self._kdf_id,
                                         self._aead_id)

        psk_id_hash = _labeled_extract(b'',
                                       b'psk_id_hash',
                                       psk_id,
                                       suite_id,
                                       self._hashmod)

        info_hash = _labeled_extract(b'',
                                     b'info_hash',
                                     info,
                                     suite_id,
                                     self._hashmod)

        key_schedule_context = self._mode.to_bytes(1, 'big') + psk_id_hash + info_hash

        secret = _labeled_extract(shared_secret,
                                  b'secret',
                                  psk,
                                  suite_id,
                                  self._hashmod)

        key = _labeled_expand(secret,
                              b'key',
                              key_schedule_context,
                              self._Nk,
                              suite_id,
                              self._hashmod)

        base_nonce = _labeled_expand(secret,
                                     b'base_nonce',
                                     key_schedule_context,
                                     self._Nn,
                                     suite_id,
                                     self._hashmod)

        exporter_secret = _labeled_expand(secret,
                                          b'exp',
                                          key_schedule_context,
                                          self._Nh,
                                          suite_id,
                                          self._hashmod)

        return key, base_nonce, exporter_secret

    def _new_cipher(self):
        nonce = strxor(self._base_nonce, self._sequence.to_bytes(self._Nn, 'big'))
        if self._aead_id in (AEAD.AES128_GCM, AEAD.AES256_GCM):
            cipher = AES.new(self._key, AES.MODE_GCM, nonce=nonce, mac_len=self._Nt)
        elif self._aead_id == AEAD.CHACHA20_POLY1305:
            cipher = ChaCha20_Poly1305.new(key=self._key, nonce=nonce)
        else:
            raise ValueError(f"Unknown AEAD cipher ID {self._aead_id:#x}")
        if self._sequence >= self._max_sequence:
            raise MessageLimitReachedError()
        self._sequence += 1
        return cipher

    def seal(self, plaintext: bytes, auth_data: Optional[bytes] = None):
        """Encrypt and authenticate a message.

        This method can be invoked multiple times
        to seal an ordered sequence of messages.

        Arguments:
          plaintext: bytes
            The message to seal.
          auth_data: bytes
            Optional. Additional Authenticated data (AAD) that is not encrypted
            but that will be also covered by the authentication tag.

        Returns:
           The ciphertext concatenated with the authentication tag.
        """

        if not self._encrypt:
            raise ValueError("This cipher can only be used to seal")
        cipher = self._new_cipher()
        if auth_data:
            cipher.update(auth_data)
        ct, tag = cipher.encrypt_and_digest(plaintext)
        return ct + tag

    def unseal(self, ciphertext: bytes, auth_data: Optional[bytes] = None):
        """Decrypt a message and validate its authenticity.

        This method can be invoked multiple times
        to unseal an ordered sequence of messages.

        Arguments:
          cipertext: bytes
            The message to unseal.
          auth_data: bytes
            Optional. Additional Authenticated data (AAD) that
            was also covered by the authentication tag.

        Returns:
           The original plaintext.

        Raises: ValueError
           If the ciphertext (in combination with the AAD) is not valid.

           But if it is the first time you call ``unseal()`` this
           exception may also mean that any of the parameters or keys
           used to establish the session is wrong or that one is missing.
        """

        if self._encrypt:
            raise ValueError("This cipher can only be used to unseal")
        if len(ciphertext) < self._Nt:
            raise ValueError("Ciphertext is too small")
        cipher = self._new_cipher()
        if auth_data:
            cipher.update(auth_data)

        try:
            pt = cipher.decrypt_and_verify(ciphertext[:-self._Nt],
                                           ciphertext[-self._Nt:])
        except ValueError:
            if self._sequence == 1:
                raise ValueError("Incorrect HPKE keys/parameters or invalid message (wrong MAC tag)")
            raise ValueError("Invalid message (wrong MAC tag)")
        return pt


def new(*, receiver_key: EccKey,
        aead_id: AEAD,
        enc: Optional[bytes] = None,
        sender_key: Optional[EccKey] = None,
        psk: Optional[tuple[bytes, bytes]] = None,
        info: Optional[bytes] = None) -> HPKE_Cipher:
    """Create an HPKE context which can be used:

    - by the sender to seal (encrypt) a message or
    - by the receiver to unseal (decrypt) it.

    As a minimum, the two parties agree on the receiver's asymmetric key
    (of which the sender will only know the public half).

    Additionally, for authentication purposes, they may also agree on:

    * the sender's asymmetric key (of which the receiver will only know the public half)

    * a shared secret (e.g., a symmetric key derived from a password)

    Args:
      receiver_key:
        The ECC key of the receiver.
        It must be on one of the following curves: ``NIST P-256``,
        ``NIST P-384``, ``NIST P-521``, ``X25519`` or ``X448``.

        If this is a **public** key, the HPKE context can only be used to
        **seal** (**encrypt**).

        If this is a **private** key, the HPKE context can only be used to
        **unseal** (**decrypt**).

      aead_id:
        The HPKE identifier of the symmetric cipher.
        The possible values are:

        * ``HPKE.AEAD.AES128_GCM``
        * ``HPKE.AEAD.AES256_GCM``
        * ``HPKE.AEAD.CHACHA20_POLY1305``

      enc:
        The encapsulated session key (i.e., the KEM shared secret).

        The receiver must always specify this parameter.

        The sender must always omit this parameter.

      sender_key:
        The ECC key of the sender.
        It must be on the same curve as the ``receiver_key``.
        If the ``receiver_key`` is a public key, ``sender_key`` must be a
        private key, and vice versa.

      psk:
        A Pre-Shared Key (PSK) as a 2-tuple of non-empty
        byte strings: the identifier and the actual secret value.
        Sender and receiver must use the same PSK (or none).

        The secret value must be at least 32 bytes long,
        but it  must not be a low-entropy password
        (use a KDF like PBKDF2 or scrypt to derive a secret
        from a password).

      info:
        A non-secret parameter that contributes
        to the generation of all session keys.
        Sender and receive must use the same **info** parameter (or none).

    Returns:
        An object that can be used for
        sealing (if ``receiver_key`` is a public key) or
        unsealing (if ``receiver_key`` is a private key).
        In the latter case,
        correctness of all the keys and parameters will only
        be assessed with the first call to ``unseal()``.

    .. _HPKE: https://datatracker.ietf.org/doc/rfc9180/
    """

    if aead_id not in AEAD:
        raise ValueError(f"Unknown AEAD cipher ID {aead_id:#x}")

    curve = receiver_key.curve
    if curve not in ('NIST P-256', 'NIST P-384', 'NIST P-521',
                     'Curve25519', 'Curve448'):
        raise ValueError(f"Unsupported curve {curve}")

    if sender_key:
        count_private_keys = int(receiver_key.has_private()) + \
                             int(sender_key.has_private())
        if count_private_keys != 1:
            raise ValueError("Exactly 1 private key required")
        if sender_key.curve != curve:
            raise ValueError("Sender key uses {} but recipient key {}".
                             format(sender_key.curve, curve))
        mode = MODE.AUTH if psk is None else MODE.AUTH_PSK
    else:
        mode = MODE.BASE if psk is None else MODE.PSK

    if psk is None:
        psk = b'', b''

    if info is None:
        info = b''

    return HPKE_Cipher(receiver_key,
                       enc,
                       sender_key,
                       psk,
                       info,
                       aead_id,
                       mode)
