###############################################################################
#
# Shape - A class for to represent Excel XLSX shape objects.
#
# SPDX-License-Identifier: BSD-2-Clause
#
# Copyright (c) 2013-2025, John McNamara, jmcnamara@cpan.org
#
import copy
from warnings import warn


class Shape:
    """
    A class for to represent Excel XLSX shape objects.


    """

    ###########################################################################
    #
    # Public API.
    #
    ###########################################################################

    def __init__(self, shape_type, name, options):
        """
        Constructor.

        """
        super().__init__()
        self.name = name
        self.shape_type = shape_type
        self.connect = 0
        self.drawing = 0
        self.edit_as = ""
        self.id = 0
        self.text = ""
        self.textlink = ""
        self.stencil = 1
        self.element = -1
        self.start = None
        self.start_index = None
        self.end = None
        self.end_index = None
        self.adjustments = []
        self.start_side = ""
        self.end_side = ""
        self.flip_h = 0
        self.flip_v = 0
        self.rotation = 0
        self.text_rotation = 0
        self.textbox = False

        self.align = None
        self.fill = None
        self.font = None
        self.format = None
        self.line = None

        self._set_options(options)

    ###########################################################################
    #
    # Private API.
    #
    ###########################################################################

    def _set_options(self, options):
        self.align = self._get_align_properties(options.get("align"))
        self.fill = self._get_fill_properties(options.get("fill"))
        self.font = self._get_font_properties(options.get("font"))
        self.gradient = self._get_gradient_properties(options.get("gradient"))
        self.line = self._get_line_properties(options.get("line"))

        self.text_rotation = options.get("text_rotation", 0)

        self.textlink = options.get("textlink", "")
        if self.textlink.startswith("="):
            self.textlink = self.textlink.lstrip("=")

        if options.get("border"):
            self.line = self._get_line_properties(options["border"])

        # Gradient fill overrides solid fill.
        if self.gradient:
            self.fill = None

    ###########################################################################
    #
    # Static methods for processing chart/shape style properties.
    #
    ###########################################################################

    @staticmethod
    def _get_line_properties(line):
        # Convert user line properties to the structure required internally.

        if not line:
            return {"defined": False}

        # Copy the user defined properties since they will be modified.
        line = copy.deepcopy(line)

        dash_types = {
            "solid": "solid",
            "round_dot": "sysDot",
            "square_dot": "sysDash",
            "dash": "dash",
            "dash_dot": "dashDot",
            "long_dash": "lgDash",
            "long_dash_dot": "lgDashDot",
            "long_dash_dot_dot": "lgDashDotDot",
            "dot": "dot",
            "system_dash_dot": "sysDashDot",
            "system_dash_dot_dot": "sysDashDotDot",
        }

        # Check the dash type.
        dash_type = line.get("dash_type")

        if dash_type is not None:
            if dash_type in dash_types:
                line["dash_type"] = dash_types[dash_type]
            else:
                warn(f"Unknown dash type '{dash_type}'")
                return {}

        line["defined"] = True

        return line

    @staticmethod
    def _get_fill_properties(fill):
        # Convert user fill properties to the structure required internally.

        if not fill:
            return {"defined": False}

        # Copy the user defined properties since they will be modified.
        fill = copy.deepcopy(fill)

        fill["defined"] = True

        return fill

    @staticmethod
    def _get_pattern_properties(pattern):
        # Convert user defined pattern to the structure required internally.

        if not pattern:
            return {}

        # Copy the user defined properties since they will be modified.
        pattern = copy.deepcopy(pattern)

        if not pattern.get("pattern"):
            warn("Pattern must include 'pattern'")
            return {}

        if not pattern.get("fg_color"):
            warn("Pattern must include 'fg_color'")
            return {}

        types = {
            "percent_5": "pct5",
            "percent_10": "pct10",
            "percent_20": "pct20",
            "percent_25": "pct25",
            "percent_30": "pct30",
            "percent_40": "pct40",
            "percent_50": "pct50",
            "percent_60": "pct60",
            "percent_70": "pct70",
            "percent_75": "pct75",
            "percent_80": "pct80",
            "percent_90": "pct90",
            "light_downward_diagonal": "ltDnDiag",
            "light_upward_diagonal": "ltUpDiag",
            "dark_downward_diagonal": "dkDnDiag",
            "dark_upward_diagonal": "dkUpDiag",
            "wide_downward_diagonal": "wdDnDiag",
            "wide_upward_diagonal": "wdUpDiag",
            "light_vertical": "ltVert",
            "light_horizontal": "ltHorz",
            "narrow_vertical": "narVert",
            "narrow_horizontal": "narHorz",
            "dark_vertical": "dkVert",
            "dark_horizontal": "dkHorz",
            "dashed_downward_diagonal": "dashDnDiag",
            "dashed_upward_diagonal": "dashUpDiag",
            "dashed_horizontal": "dashHorz",
            "dashed_vertical": "dashVert",
            "small_confetti": "smConfetti",
            "large_confetti": "lgConfetti",
            "zigzag": "zigZag",
            "wave": "wave",
            "diagonal_brick": "diagBrick",
            "horizontal_brick": "horzBrick",
            "weave": "weave",
            "plaid": "plaid",
            "divot": "divot",
            "dotted_grid": "dotGrid",
            "dotted_diamond": "dotDmnd",
            "shingle": "shingle",
            "trellis": "trellis",
            "sphere": "sphere",
            "small_grid": "smGrid",
            "large_grid": "lgGrid",
            "small_check": "smCheck",
            "large_check": "lgCheck",
            "outlined_diamond": "openDmnd",
            "solid_diamond": "solidDmnd",
        }

        # Check for valid types.
        if pattern["pattern"] not in types:
            warn(f"unknown pattern type '{pattern['pattern']}'")
            return {}

        pattern["pattern"] = types[pattern["pattern"]]

        # Specify a default background color.
        pattern["bg_color"] = pattern.get("bg_color", "#FFFFFF")

        return pattern

    @staticmethod
    def _get_gradient_properties(gradient):
        # pylint: disable=too-many-return-statements
        # Convert user defined gradient to the structure required internally.

        if not gradient:
            return {}

        # Copy the user defined properties since they will be modified.
        gradient = copy.deepcopy(gradient)

        types = {
            "linear": "linear",
            "radial": "circle",
            "rectangular": "rect",
            "path": "shape",
        }

        # Check the colors array exists and is valid.
        if "colors" not in gradient or not isinstance(gradient["colors"], list):
            warn("Gradient must include colors list")
            return {}

        # Check the colors array has the required number of entries.
        if not 2 <= len(gradient["colors"]) <= 10:
            warn("Gradient colors list must at least 2 values and not more than 10")
            return {}

        if "positions" in gradient:
            # Check the positions array has the right number of entries.
            if len(gradient["positions"]) != len(gradient["colors"]):
                warn("Gradient positions not equal to number of colors")
                return {}

            # Check the positions are in the correct range.
            for pos in gradient["positions"]:
                if not 0 <= pos <= 100:
                    warn("Gradient position must be in the range 0 <= position <= 100")
                    return {}
        else:
            # Use the default gradient positions.
            if len(gradient["colors"]) == 2:
                gradient["positions"] = [0, 100]

            elif len(gradient["colors"]) == 3:
                gradient["positions"] = [0, 50, 100]

            elif len(gradient["colors"]) == 4:
                gradient["positions"] = [0, 33, 66, 100]

            else:
                warn("Must specify gradient positions")
                return {}

        angle = gradient.get("angle")
        if angle:
            if not 0 <= angle < 360:
                warn("Gradient angle must be in the range 0 <= angle < 360")
                return {}
        else:
            gradient["angle"] = 90

        # Check for valid types.
        gradient_type = gradient.get("type")

        if gradient_type is not None:
            if gradient_type in types:
                gradient["type"] = types[gradient_type]
            else:
                warn(f"Unknown gradient type '{gradient_type}")
                return {}
        else:
            gradient["type"] = "linear"

        return gradient

    @staticmethod
    def _get_font_properties(options):
        # Convert user defined font values into private dict values.
        if options is None:
            options = {}

        font = {
            "name": options.get("name"),
            "color": options.get("color"),
            "size": options.get("size", 11),
            "bold": options.get("bold"),
            "italic": options.get("italic"),
            "underline": options.get("underline"),
            "pitch_family": options.get("pitch_family"),
            "charset": options.get("charset"),
            "baseline": options.get("baseline", -1),
            "lang": options.get("lang", "en-US"),
        }

        # Convert font size units.
        if font["size"]:
            font["size"] = int(font["size"] * 100)

        return font

    @staticmethod
    def _get_font_style_attributes(font):
        # _get_font_style_attributes.
        attributes = []

        if not font:
            return attributes

        if font.get("size"):
            attributes.append(("sz", font["size"]))

        if font.get("bold") is not None:
            attributes.append(("b", 0 + font["bold"]))

        if font.get("italic") is not None:
            attributes.append(("i", 0 + font["italic"]))

        if font.get("underline") is not None:
            attributes.append(("u", "sng"))

        if font.get("baseline") != -1:
            attributes.append(("baseline", font["baseline"]))

        return attributes

    @staticmethod
    def _get_font_latin_attributes(font):
        # _get_font_latin_attributes.
        attributes = []

        if not font:
            return attributes

        if font["name"] is not None:
            attributes.append(("typeface", font["name"]))

        if font["pitch_family"] is not None:
            attributes.append(("pitchFamily", font["pitch_family"]))

        if font["charset"] is not None:
            attributes.append(("charset", font["charset"]))

        return attributes

    @staticmethod
    def _get_align_properties(align):
        # Convert user defined align to the structure required internally.
        if not align:
            return {"defined": False}

        # Copy the user defined properties since they will be modified.
        align = copy.deepcopy(align)

        if "vertical" in align:
            align_type = align["vertical"]

            align_types = {
                "top": "top",
                "middle": "middle",
                "bottom": "bottom",
            }

            if align_type in align_types:
                align["vertical"] = align_types[align_type]
            else:
                warn(f"Unknown alignment type '{align_type}'")
                return {"defined": False}

        if "horizontal" in align:
            align_type = align["horizontal"]

            align_types = {
                "left": "left",
                "center": "center",
                "right": "right",
            }

            if align_type in align_types:
                align["horizontal"] = align_types[align_type]
            else:
                warn(f"Unknown alignment type '{align_type}'")
                return {"defined": False}

        align["defined"] = True

        return align
