# This file is licensed under the BSD 2-Clause License.
# See https://opensource.org/licenses/BSD-2-Clause for details.

import threading

from Crypto.Util.number import bytes_to_long, long_to_bytes
from Crypto.Util._raw_api import (VoidPointer, null_pointer,
                                  SmartPointer, c_size_t, c_uint8_ptr,
                                  c_ulonglong)
from Crypto.Math.Numbers import Integer
from Crypto.Random.random import getrandbits


class CurveID(object):
    P192 = 1
    P224 = 2
    P256 = 3
    P384 = 4
    P521 = 5
    ED25519 = 6
    ED448 = 7
    CURVE25519 = 8
    CURVE448 = 9


class _Curves(object):

    curves = {}
    curves_lock = threading.RLock()

    p192_names = ["p192", "NIST P-192", "P-192", "prime192v1", "secp192r1",
                  "nistp192"]
    p224_names = ["p224", "NIST P-224", "P-224", "prime224v1", "secp224r1",
                  "nistp224"]
    p256_names = ["p256", "NIST P-256", "P-256", "prime256v1", "secp256r1",
                  "nistp256"]
    p384_names = ["p384", "NIST P-384", "P-384", "prime384v1", "secp384r1",
                  "nistp384"]
    p521_names = ["p521", "NIST P-521", "P-521", "prime521v1", "secp521r1",
                  "nistp521"]
    ed25519_names = ["ed25519", "Ed25519"]
    ed448_names = ["ed448", "Ed448"]
    curve25519_names = ["curve25519", "Curve25519", "X25519"]
    curve448_names = ["curve448", "Curve448", "X448"]

    all_names = p192_names + p224_names + p256_names + p384_names + p521_names + \
        ed25519_names + ed448_names + curve25519_names + curve448_names

    def __contains__(self, item):
        return item in self.all_names

    def __dir__(self):
        return self.all_names

    def load(self, name):
        if name in self.p192_names:
            from . import _nist_ecc
            p192 = _nist_ecc.p192_curve()
            p192.id = CurveID.P192
            self.curves.update(dict.fromkeys(self.p192_names, p192))
        elif name in self.p224_names:
            from . import _nist_ecc
            p224 = _nist_ecc.p224_curve()
            p224.id = CurveID.P224
            self.curves.update(dict.fromkeys(self.p224_names, p224))
        elif name in self.p256_names:
            from . import _nist_ecc
            p256 = _nist_ecc.p256_curve()
            p256.id = CurveID.P256
            self.curves.update(dict.fromkeys(self.p256_names, p256))
        elif name in self.p384_names:
            from . import _nist_ecc
            p384 = _nist_ecc.p384_curve()
            p384.id = CurveID.P384
            self.curves.update(dict.fromkeys(self.p384_names, p384))
        elif name in self.p521_names:
            from . import _nist_ecc
            p521 = _nist_ecc.p521_curve()
            p521.id = CurveID.P521
            self.curves.update(dict.fromkeys(self.p521_names, p521))
        elif name in self.ed25519_names:
            from . import _edwards
            ed25519 = _edwards.ed25519_curve()
            ed25519.id = CurveID.ED25519
            self.curves.update(dict.fromkeys(self.ed25519_names, ed25519))
        elif name in self.ed448_names:
            from . import _edwards
            ed448 = _edwards.ed448_curve()
            ed448.id = CurveID.ED448
            self.curves.update(dict.fromkeys(self.ed448_names, ed448))
        elif name in self.curve25519_names:
            from . import _montgomery
            curve25519 = _montgomery.curve25519_curve()
            curve25519.id = CurveID.CURVE25519
            self.curves.update(dict.fromkeys(self.curve25519_names, curve25519))
        elif name in self.curve448_names:
            from . import _montgomery
            curve448 = _montgomery.curve448_curve()
            curve448.id = CurveID.CURVE448
            self.curves.update(dict.fromkeys(self.curve448_names, curve448))
        else:
            raise ValueError("Unsupported curve '%s'" % name)
        return self.curves[name]

    def __getitem__(self, name):
        with self.curves_lock:
            curve = self.curves.get(name)
            if curve is None:
                curve = self.load(name)
                if name in self.curve25519_names or name in self.curve448_names:
                    curve.G = EccXPoint(curve.Gx, name)
                else:
                    curve.G = EccPoint(curve.Gx, curve.Gy, name)
                curve.is_edwards = curve.id in (CurveID.ED25519, CurveID.ED448)
                curve.is_montgomery = curve.id in (CurveID.CURVE25519,
                                                   CurveID.CURVE448)
                curve.is_weierstrass = not (curve.is_edwards or
                                            curve.is_montgomery)
        return curve

    def items(self):
        # Load all curves
        for name in self.all_names:
            _ = self[name]
        return self.curves.items()


_curves = _Curves()


class EccPoint(object):
    """A class to model a point on an Elliptic Curve.

    The class supports operators for:

    * Adding two points: ``R = S + T``
    * In-place addition: ``S += T``
    * Negating a point: ``R = -T``
    * Comparing two points: ``if S == T: ...`` or ``if S != T: ...``
    * Multiplying a point by a scalar: ``R = S*k``
    * In-place multiplication by a scalar: ``T *= k``

    :ivar curve: The **canonical** name of the curve as defined in the `ECC table`_.
    :vartype curve: string

    :ivar x: The affine X-coordinate of the ECC point
    :vartype x: integer

    :ivar y: The affine Y-coordinate of the ECC point
    :vartype y: integer

    :ivar xy: The tuple with affine X- and Y- coordinates
    """

    def __init__(self, x, y, curve="p256"):

        try:
            self._curve = _curves[curve]
        except KeyError:
            raise ValueError("Unknown curve name %s" % str(curve))
        self.curve = self._curve.canonical

        if self._curve.id == CurveID.CURVE25519:
            raise ValueError("EccPoint cannot be created for Curve25519")

        modulus_bytes = self.size_in_bytes()

        xb = long_to_bytes(x, modulus_bytes)
        yb = long_to_bytes(y, modulus_bytes)
        if len(xb) != modulus_bytes or len(yb) != modulus_bytes:
            raise ValueError("Incorrect coordinate length")

        new_point = self._curve.rawlib.new_point
        free_func = self._curve.rawlib.free_point

        self._point = VoidPointer()
        try:
            context = self._curve.context.get()
        except AttributeError:
            context = null_pointer
        result = new_point(self._point.address_of(),
                           c_uint8_ptr(xb),
                           c_uint8_ptr(yb),
                           c_size_t(modulus_bytes),
                           context)

        if result:
            if result == 15:
                raise ValueError("The EC point does not belong to the curve")
            raise ValueError("Error %d while instantiating an EC point" % result)

        # Ensure that object disposal of this Python object will (eventually)
        # free the memory allocated by the raw library for the EC point
        self._point = SmartPointer(self._point.get(), free_func)

    def set(self, point):
        clone = self._curve.rawlib.clone
        free_func = self._curve.rawlib.free_point

        self._point = VoidPointer()
        result = clone(self._point.address_of(),
                       point._point.get())

        if result:
            raise ValueError("Error %d while cloning an EC point" % result)

        self._point = SmartPointer(self._point.get(), free_func)
        return self

    def __eq__(self, point):
        if not isinstance(point, EccPoint):
            return False

        cmp_func = self._curve.rawlib.cmp
        return 0 == cmp_func(self._point.get(), point._point.get())

    # Only needed for Python 2
    def __ne__(self, point):
        return not self == point

    def __neg__(self):
        neg_func = self._curve.rawlib.neg
        np = self.copy()
        result = neg_func(np._point.get())
        if result:
            raise ValueError("Error %d while inverting an EC point" % result)
        return np

    def copy(self):
        """Return a copy of this point."""
        x, y = self.xy
        np = EccPoint(x, y, self.curve)
        return np

    def is_point_at_infinity(self):
        """``True`` if this is the *point-at-infinity*."""

        if self._curve.is_edwards:
            return self.x == 0
        else:
            return self.xy == (0, 0)

    def point_at_infinity(self):
        """Return the *point-at-infinity* for the curve."""

        if self._curve.is_edwards:
            return EccPoint(0, 1, self.curve)
        else:
            return EccPoint(0, 0, self.curve)

    @property
    def x(self):
        return self.xy[0]

    @property
    def y(self):
        return self.xy[1]

    @property
    def xy(self):
        modulus_bytes = self.size_in_bytes()
        xb = bytearray(modulus_bytes)
        yb = bytearray(modulus_bytes)
        get_xy = self._curve.rawlib.get_xy
        result = get_xy(c_uint8_ptr(xb),
                        c_uint8_ptr(yb),
                        c_size_t(modulus_bytes),
                        self._point.get())
        if result:
            raise ValueError("Error %d while encoding an EC point" % result)

        return (Integer(bytes_to_long(xb)), Integer(bytes_to_long(yb)))

    def size_in_bytes(self):
        """Size of each coordinate, in bytes."""
        return (self.size_in_bits() + 7) // 8

    def size_in_bits(self):
        """Size of each coordinate, in bits."""
        return self._curve.modulus_bits

    def double(self):
        """Double this point (in-place operation).

        Returns:
            This same object (to enable chaining).
        """

        double_func = self._curve.rawlib.double
        result = double_func(self._point.get())
        if result:
            raise ValueError("Error %d while doubling an EC point" % result)
        return self

    def __iadd__(self, point):
        """Add a second point to this one"""

        add_func = self._curve.rawlib.add
        result = add_func(self._point.get(), point._point.get())
        if result:
            if result == 16:
                raise ValueError("EC points are not on the same curve")
            raise ValueError("Error %d while adding two EC points" % result)
        return self

    def __add__(self, point):
        """Return a new point, the addition of this one and another"""

        np = self.copy()
        np += point
        return np

    def __imul__(self, scalar):
        """Multiply this point by a scalar"""

        scalar_func = self._curve.rawlib.scalar
        if scalar < 0:
            raise ValueError("Scalar multiplication is only defined for non-negative integers")
        sb = long_to_bytes(scalar)
        result = scalar_func(self._point.get(),
                             c_uint8_ptr(sb),
                             c_size_t(len(sb)),
                             c_ulonglong(getrandbits(64)))
        if result:
            raise ValueError("Error %d during scalar multiplication" % result)
        return self

    def __mul__(self, scalar):
        """Return a new point, the scalar product of this one"""

        np = self.copy()
        np *= scalar
        return np

    def __rmul__(self, left_hand):
        return self.__mul__(left_hand)


class EccXPoint(object):
    """A class to model a point on an Elliptic Curve,
    where only the X-coordinate is exposed.

    The class supports operators for:

    * Multiplying a point by a scalar: ``R = S*k``
    * In-place multiplication by a scalar: ``T *= k``

    :ivar curve: The **canonical** name of the curve as defined in the `ECC table`_.
    :vartype curve: string

    :ivar x: The affine X-coordinate of the ECC point
    :vartype x: integer
    """

    def __init__(self, x, curve):
        # Once encoded, x must not exceed the length of the modulus,
        # but its value may match or exceed the modulus itself
        # (i.e., non-canonical value)

        try:
            self._curve = _curves[curve]
        except KeyError:
            raise ValueError("Unknown curve name %s" % str(curve))
        self.curve = self._curve.canonical

        if self._curve.id not in (CurveID.CURVE25519, CurveID.CURVE448):
            raise ValueError("EccXPoint can only be created for Curve25519/Curve448")

        new_point = self._curve.rawlib.new_point
        free_func = self._curve.rawlib.free_point

        self._point = VoidPointer()
        try:
            context = self._curve.context.get()
        except AttributeError:
            context = null_pointer

        modulus_bytes = self.size_in_bytes()

        if x is None:
            xb = null_pointer
        else:
            xb = c_uint8_ptr(long_to_bytes(x, modulus_bytes))
            if len(xb) != modulus_bytes:
                raise ValueError("Incorrect coordinate length")

        self._point = VoidPointer()
        result = new_point(self._point.address_of(),
                           xb,
                           c_size_t(modulus_bytes),
                           context)

        if result == 15:
            raise ValueError("The EC point does not belong to the curve")
        if result:
            raise ValueError("Error %d while instantiating an EC point" % result)

        # Ensure that object disposal of this Python object will (eventually)
        # free the memory allocated by the raw library for the EC point
        self._point = SmartPointer(self._point.get(), free_func)

    def set(self, point):
        clone = self._curve.rawlib.clone
        free_func = self._curve.rawlib.free_point

        self._point = VoidPointer()
        result = clone(self._point.address_of(),
                       point._point.get())
        if result:
            raise ValueError("Error %d while cloning an EC point" % result)

        self._point = SmartPointer(self._point.get(), free_func)
        return self

    def __eq__(self, point):
        if not isinstance(point, EccXPoint):
            return False

        cmp_func = self._curve.rawlib.cmp
        p1 = self._point.get()
        p2 = point._point.get()
        res = cmp_func(p1, p2)
        return 0 == res

    def copy(self):
        """Return a copy of this point."""

        try:
            x = self.x
        except ValueError:
            return self.point_at_infinity()
        return EccXPoint(x, self.curve)

    def is_point_at_infinity(self):
        """``True`` if this is the *point-at-infinity*."""

        try:
            _ = self.x
        except ValueError:
            return True
        return False

    def point_at_infinity(self):
        """Return the *point-at-infinity* for the curve."""

        return EccXPoint(None, self.curve)

    @property
    def x(self):
        modulus_bytes = self.size_in_bytes()
        xb = bytearray(modulus_bytes)
        get_x = self._curve.rawlib.get_x
        result = get_x(c_uint8_ptr(xb),
                       c_size_t(modulus_bytes),
                       self._point.get())
        if result == 19:    # ERR_ECC_PAI
            raise ValueError("No X coordinate for the point at infinity")
        if result:
            raise ValueError("Error %d while getting X of an EC point" % result)
        return Integer(bytes_to_long(xb))

    def size_in_bytes(self):
        """Size of each coordinate, in bytes."""
        return (self.size_in_bits() + 7) // 8

    def size_in_bits(self):
        """Size of each coordinate, in bits."""
        return self._curve.modulus_bits

    def __imul__(self, scalar):
        """Multiply this point by a scalar"""

        scalar_func = self._curve.rawlib.scalar
        if scalar < 0:
            raise ValueError("Scalar multiplication is only defined for non-negative integers")
        sb = long_to_bytes(scalar)
        result = scalar_func(self._point.get(),
                             c_uint8_ptr(sb),
                             c_size_t(len(sb)),
                             c_ulonglong(getrandbits(64)))
        if result:
            raise ValueError("Error %d during scalar multiplication" % result)
        return self

    def __mul__(self, scalar):
        """Return a new point, the scalar product of this one"""

        np = self.copy()
        np *= scalar
        return np

    def __rmul__(self, left_hand):
        return self.__mul__(left_hand)
