import math
import re
from copy import copy

from PySide6.QtWidgets import QApplication, QWidget, QMessageBox, QInputDialog
from PySide6.QtGui import QPainter, QPen, QColor
from PySide6.QtCore import Qt, QPoint, QPointF, Signal
from python_solvespace import SolverSystem, ResultFlag

class DrawingTools:
    """Placeholder for sketch drawing tools; currently defines no behavior."""
class Costrains:
    """Placeholder for sketch constraints (name kept as-is for existing callers)."""

class SketchWidget(QWidget):
    """Interactive 2D sketch canvas backed by a python_solvespace ``SolverSystem``.

    UI-side geometry is cached in ``slv_points_main`` (one dict per solver point)
    and ``slv_lines_main`` (one dict per solver line). ``mouse_mode`` selects the
    active tool: a falsy value enables point dragging, while the strings
    ``"line"``, ``"pt_pt"``, ``"pt_line"``, ``"pb_con_mid"``, ``"horiz"``,
    ``"vert"`` and ``"distance"`` select drawing / constraint tools.
    ``constrain_done`` is emitted after most constraint tools finish.
    """

    constrain_done = Signal()

    # Lower bound applied to ``zoom`` by the mouse wheel. Coordinates are divided
    # by ``zoom`` and tick spacing scales with it, so it must stay positive.
    MIN_ZOOM = 0.1

    def __init__(self):
        super().__init__()
        self.line_draw_buffer = [None, None]
        self.drag_buffer = [None, None]
        self.main_buffer = [None, None]
        # Solver handle of the point stored in line_draw_buffer[0]. New lines
        # connect to this exact point rather than to whichever point happens to
        # share its coordinates.
        self._line_start_handle = None

        self.hovered_point = None
        self.selected_line = None

        self.snapping_range = 20  # Range in pixels for snapping
        self.zoom = 1

        self.setMouseTracking(True)
        self.mouse_mode = False
        self.wp = None
        self.solv = SolverSystem()

        self.slv_points_main = []
        self.slv_lines_main = []

    def reset_buffers(self):
        """Discard any half-finished line, drag or constraint selection."""
        self.line_draw_buffer = [None, None]
        self.drag_buffer = [None, None]
        self.main_buffer = [None, None]
        self._line_start_handle = None

    def set_points(self, points: list):
        """Store *points* on the widget (does not trigger a repaint)."""
        self.points = points

    def create_workplane(self):
        """Create a 2D workplane in the solver and remember its handle in ``self.wp``."""
        self.wp = self.solv.create_2d_base()

    def get_handle_nr(self, input_str: str) -> int:
        """Return the integer following ``handle=`` in *input_str*, or 0 if absent.

        Used on ``str(entity)`` of solver entities to obtain their handle number.
        """
        match = re.search(r"handle=(\d+)", input_str)
        if match is None:
            print("Handle number not found.")
            return 0

        handle_number = int(match.group(1))
        print(f"Handle number: {handle_number}")
        return handle_number

    def get_keys(self, d: dict, target: QPoint) -> list:
        """Return every key path in the (possibly nested) dict *d* whose value equals *target*.

        Each path is a list of keys from the outermost dict down to the match.
        """
        print(d)
        print(target)
        result = []
        for k, v in d.items():
            if isinstance(v, dict):
                # Prefix nested matches with the key that leads into the sub-dict.
                result.extend([k, *sub_path] for sub_path in self.get_keys(v, target))
            if v == target:
                result.append([k])
        return result

    def get_handle_from_ui_point(self, ui_point: QPoint):
        """Return the solver handle of the first cached point located at *ui_point*.

        Returns None when *ui_point* is None or no cached point matches.
        """
        if ui_point is None:
            return None
        for point in self.slv_points_main:
            if point['ui_point'] == ui_point:
                return point['solv_handle']
        return None

    def get_line_handle_from_ui_point(self, ui_point: QPoint):
        """Return the solver handle of the first line that *ui_point* lies on, or None."""
        for target_line_con in self.slv_lines_main:
            p1, p2 = target_line_con['ui_points']
            if self.is_point_on_line(ui_point, p1, p2):
                return target_line_con['solv_handle']
        return None

    def get_point_line_handles_from_ui_point(self, ui_point: QPoint) -> tuple:
        """Return the two endpoint handles of the first line *ui_point* lies on, or None."""
        for target_line_con in self.slv_lines_main:
            p1, p2 = target_line_con['ui_points']
            if self.is_point_on_line(ui_point, p1, p2):
                return target_line_con['solv_entity_points']
        return None

    def distance(self, p1, p2):
        """Euclidean distance between two points exposing ``x()`` / ``y()``."""
        return math.hypot(p1.x() - p2.x(), p1.y() - p2.y())

    def calculate_midpoint(self, point1, point2):
        """Integer (floor-divided) midpoint of two points as a QPoint."""
        mx = (point1.x() + point2.x()) // 2
        my = (point1.y() + point2.y()) // 2
        return QPoint(mx, my)

    def is_point_on_line(self, p, p1, p2, tolerance=5):
        """Return True if *p* is within *tolerance* of the segment *p1*-*p2*.

        The perpendicular distance is computed with a cross product, which
        (unlike Heron's formula) cannot hit ``sqrt`` of a slightly negative
        number for collinear points. A zero-length segment never matches.
        """
        seg_dx = p2.x() - p1.x()
        seg_dy = p2.y() - p1.y()
        seg_len_sq = seg_dx * seg_dx + seg_dy * seg_dy
        if seg_len_sq == 0:
            return False

        rel_x = p.x() - p1.x()
        rel_y = p.y() - p1.y()

        # |cross(segment, p - p1)| / |segment| = distance from p to the infinite line.
        height = abs(seg_dx * rel_y - seg_dy * rel_x) / math.sqrt(seg_len_sq)
        if height > tolerance:
            return False

        # Projection parameter along the segment: 0 at p1, 1 at p2.
        t = (rel_x * seg_dx + rel_y * seg_dy) / seg_len_sq
        return 0 <= t <= 1

    def viewport_to_local_coord(self, qt_pos: QPoint) -> QPoint:
        """Convert a widget (viewport) position into centered, zoom-adjusted sketch coordinates."""
        return QPoint(self.to_quadrant_coords(qt_pos))

    def check_all_points(self) -> list:
        """Compare cached UI points with the solver's current 2D point positions.

        Returns a list of ``(index, old_point, new_point)`` tuples, where *index*
        indexes ``slv_points_main``. Cached points and solver 2D points are paired
        in creation order.
        """
        old_points_ui = [old_point_ui['ui_point'] for old_point_ui in self.slv_points_main]
        new_points_ui = []

        for i in range(self.solv.entity_len()):
            # Iterate the full entity list: SolveSpace mixes workplane entities with sketch points.
            entity = self.solv.entity(i)
            if not entity.is_point_2d():
                continue
            params = self.solv.params(entity.params)
            if params:
                x_tbu, y_tbu = params
                # Round instead of truncating so solver noise such as 99.9999
                # maps back to 100 and does not make the point drift by one unit.
                new_points_ui.append(QPoint(round(x_tbu), round(y_tbu)))

        if len(old_points_ui) != len(new_points_ui):
            print(f"Length mismatch {len(old_points_ui)} - {len(new_points_ui)}")

        return [
            (index, old_point, new_point)
            for index, (old_point, new_point) in enumerate(zip(old_points_ui, new_points_ui))
            if old_point != new_point
        ]

    def update_ui_points(self, point_list: list):
        """Apply ``(index, old_point, new_point)`` changes to ``slv_points_main``."""
        print("Change list:", point_list)

        for index, _old_point, new_point in point_list:
            self.slv_points_main[index]['ui_point'] = new_point

    def check_all_lines_and_update(self, changed_points: list):
        """Move line endpoints that sat at a changed point's old position to its new position.

        Each endpoint is mapped once against the pre-update coordinates. Chained
        moves (A moving onto B's old spot) therefore do not cascade, and both
        endpoints of a line are handled.
        """
        def moved(ui_point):
            for _index, old_point, new_point in changed_points:
                if old_point == ui_point:
                    return new_point
            return ui_point

        if not changed_points:
            return

        for line_needs_update in self.slv_lines_main:
            endpoints = line_needs_update['ui_points']
            endpoints[0] = moved(endpoints[0])
            endpoints[1] = moved(endpoints[1])

    def _refresh_ui_from_solver(self):
        """Sync cached UI points and line endpoints with the solver's current parameters."""
        points_need_update = self.check_all_points()
        self.update_ui_points(points_need_update)
        self.check_all_lines_and_update(points_need_update)

    def _solve_and_report(self):
        """Run the solver exactly once, print its status and return the ResultFlag."""
        result = self.solv.solve()
        if result == ResultFlag.OKAY:
            print("Fuck yeah")
        elif result == ResultFlag.DIDNT_CONVERGE:
            print("Solve_failed - Converge")
        elif result == ResultFlag.TOO_MANY_UNKNOWNS:
            print("Solve_failed - Unknowns")
        elif result == ResultFlag.INCONSISTENT:
            print("Solve_failed - Incons")
        return result

    def mouseReleaseEvent(self, event):
        """Finish a drag: move the grabbed solver point to the release position and re-solve."""
        local_event_pos = self.viewport_to_local_coord(event.pos())

        if event.button() != Qt.LeftButton or self.mouse_mode:
            return

        self.drag_buffer[1] = local_event_pos
        print("Le main buffer", self.drag_buffer)

        # drag_buffer[0] is None when the press did not start on a point.
        dragged_entity = self.drag_buffer[0]
        if dragged_entity is not None:
            new_params = local_event_pos.x(), local_event_pos.y()
            self.solv.set_params(dragged_entity.params, new_params)
            self.solv.solve()
            self._refresh_ui_from_solver()
            self.update()

        self.drag_buffer = [None, None]

    def mousePressEvent(self, event):
        """Dispatch a click to the drag start, buffer reset or active tool handler."""
        local_event_pos = self.viewport_to_local_coord(event.pos())
        button = event.button()

        if button == Qt.LeftButton:
            if not self.mouse_mode:
                # Remember the point under the cursor (if any) for mouseReleaseEvent.
                self.drag_buffer[0] = self.get_handle_from_ui_point(self.hovered_point)
            else:
                handler = self._click_handlers().get(self.mouse_mode)
                if handler is not None:
                    handler(local_event_pos)
        elif button == Qt.RightButton and self.mouse_mode:
            self.reset_buffers()

        # Update the main point list with the new elements and draw them
        self._refresh_ui_from_solver()
        self.update()

    def _click_handlers(self) -> dict:
        """Map each tool mode to the function handling a left click in that mode."""
        # Lambdas read self.solv at call time, so a solver replaced by clear_sketch is honored.
        return {
            "line": self._handle_line_click,
            "pt_pt": self._handle_pt_pt_click,
            "pt_line": lambda pos: self._handle_point_line_click(
                pos, lambda p, l, wp: self.solv.coincident(p, l, wp)),
            "pb_con_mid": lambda pos: self._handle_point_line_click(
                pos, lambda p, l, wp: self.solv.midpoint(p, l, wp)),
            "horiz": lambda pos: self._handle_single_line_click(
                pos, lambda l, wp: self.solv.horizontal(l, wp)),
            "vert": lambda pos: self._handle_single_line_click(
                pos, lambda l, wp: self.solv.vertical(l, wp)),
            "distance": self._handle_distance_click,
        }

    def _add_ui_point(self, ui_point: QPoint) -> dict:
        """Create a solver 2D point at *ui_point*, cache it and return its relation dict."""
        point = self.solv.add_point_2d(ui_point.x(), ui_point.y(), self.wp)

        relation_point = {
            'handle_nr': self.get_handle_nr(str(point)),
            'solv_handle': point,
            'ui_point': ui_point,
            'part_of_entity': None,
        }
        self.slv_points_main.append(relation_point)

        print("points", self.slv_points_main)
        print("lines", self.slv_lines_main)
        return relation_point

    def _handle_line_click(self, clicked_pos: QPoint):
        """Line tool: place a point; from the second click on, connect it to the previous one."""
        if self.line_draw_buffer[0] is None:
            # First click of a polyline: only create the start point.
            self.line_draw_buffer[0] = clicked_pos
            self._line_start_handle = self._add_ui_point(clicked_pos)['solv_handle']
            print("Buffer state", self.line_draw_buffer)
            return

        self.line_draw_buffer[1] = clicked_pos
        end_relation = self._add_ui_point(clicked_pos)
        print("Buffer state", self.line_draw_buffer)

        point_slv1 = self._line_start_handle
        if point_slv1 is None:
            # The start was set without a handle; fall back to a coordinate lookup.
            point_slv1 = self.get_handle_from_ui_point(self.line_draw_buffer[0])
        point_slv2 = end_relation['solv_handle']
        print(point_slv1)
        print(point_slv2)

        line = self.solv.add_line_2d(point_slv1, point_slv2, self.wp)
        handle_nr_line = self.get_handle_nr(str(line))

        self.slv_lines_main.append({
            'handle_nr': handle_nr_line,
            'solv_handle': line,
            'solv_entity_points': (point_slv1, point_slv2),
            'ui_points': [self.line_draw_buffer[0], self.line_draw_buffer[1]],
        })

        # Track relationship of point in line
        end_relation['part_of_entity'] = handle_nr_line

        # Continue the polyline from the point just placed.
        self.line_draw_buffer[0] = self.line_draw_buffer[1]
        self.line_draw_buffer[1] = None
        self._line_start_handle = point_slv2

    def _handle_pt_pt_click(self, _clicked_pos: QPoint):
        """Point-point tool: select two hovered points and make them coincident."""
        if self.hovered_point is not None and self.main_buffer[0] is None:
            self.main_buffer[0] = self.get_handle_from_ui_point(self.hovered_point)
        elif self.main_buffer[0] is not None:
            self.main_buffer[1] = self.get_handle_from_ui_point(self.hovered_point)
            if self.main_buffer[1] is self.main_buffer[0]:
                # Same point picked twice: a self-coincidence constraint is meaningless.
                self.main_buffer[1] = None

        if self.main_buffer[0] is not None and self.main_buffer[1] is not None:
            print("buf", self.main_buffer)
            self.solv.coincident(self.main_buffer[0], self.main_buffer[1], self.wp)
            self._solve_and_report()
            self.constrain_done.emit()
            self.main_buffer = [None, None]

    def _handle_point_line_click(self, clicked_pos: QPoint, constrain):
        """Point-line tools: pick a hovered point, then a line, and apply ``constrain(point, line, wp)``."""
        print("ptline")

        if self.hovered_point is not None and self.main_buffer[1] is None:
            self.main_buffer[0] = self.get_handle_from_ui_point(self.hovered_point)
        elif self.main_buffer[0] is not None:
            self.main_buffer[1] = self.get_line_handle_from_ui_point(clicked_pos)

            if self.main_buffer[1] is not None:
                constrain(self.main_buffer[0], self.main_buffer[1], self.wp)
                self._solve_and_report()
                self.constrain_done.emit()
                # Clear the selection after the solve attempt.
                self.main_buffer = [None, None]

    def _handle_single_line_click(self, clicked_pos: QPoint, constrain):
        """Single-line tools (horizontal/vertical): apply ``constrain(line, wp)`` to the clicked line."""
        line_selected = self.get_line_handle_from_ui_point(clicked_pos)
        if line_selected is None:
            return
        constrain(line_selected, self.wp)
        self._solve_and_report()

    def _handle_distance_click(self, clicked_pos: QPoint):
        """Distance tool: point-to-line distance, or the length of a single clicked line."""
        e1 = None
        e2 = None

        if self.hovered_point is not None:
            print("buf point")
            # Buffer the point as a UI coordinate.
            self.main_buffer[0] = self.hovered_point
        elif self.selected_line is not None:
            # Buffer the click position on the line as a UI coordinate.
            self.main_buffer[1] = clicked_pos

        if self.main_buffer[0] is not None and self.main_buffer[1] is not None:
            # Point + line combination.
            e1 = self.get_handle_from_ui_point(self.main_buffer[0])
            e2 = self.get_line_handle_from_ui_point(self.main_buffer[1])
        elif self.main_buffer[0] is None:
            # Only a line: constrain the distance between its two endpoints.
            line_points = self.get_point_line_handles_from_ui_point(clicked_pos)
            if line_points is not None:
                e1, e2 = line_points

        if e1 is None or e2 is None:
            return

        length, ok = QInputDialog.getDouble(self, 'Distance', 'Enter a mm value:', value=100, decimals=2)
        if ok:
            self.solv.distance(e1, e2, length, self.wp)
            self._solve_and_report()
            self.constrain_done.emit()
        # Cancelling the dialog just drops the selection without adding a constraint.
        self.main_buffer = [None, None]

    def mouseMoveEvent(self, event):
        """Update the hovered point and hovered line under the cursor, then repaint."""
        local_event_pos = self.viewport_to_local_coord(event.pos())

        closest_point = None
        min_distance = float('inf')
        threshold = 10  # Manhattan distance threshold (local coordinates) for highlighting

        for point in self.slv_points_main:
            distance = (local_event_pos - point['ui_point']).manhattanLength()
            if distance < threshold and distance < min_distance:
                closest_point = point['ui_point']
                min_distance = distance

        if closest_point != self.hovered_point:
            self.hovered_point = closest_point
            print(self.hovered_point)

        # Reset first so a stale selection cannot survive when no line is present.
        self.selected_line = None
        for dic in self.slv_lines_main:
            p1, p2 = dic['ui_points']
            if self.is_point_on_line(local_event_pos, p1, p2):
                self.selected_line = p1, p2
                break

        self.update()

    def mouseDoubleClickEvent(self, event):
        """Double clicks are intentionally ignored."""
        pass

    def drawBackgroundGrid(self, painter):
        """Draw a background grid."""
        grid_spacing = 50
        pen = QPen(QColor(200, 200, 200), 1, Qt.SolidLine)
        painter.setPen(pen)

        # Draw vertical grid lines
        for x in range(-self.width() // 2, self.width() // 2, grid_spacing):
            painter.drawLine(x, -self.height() // 2, x, self.height() // 2)

        # Draw horizontal grid lines
        for y in range(-self.height() // 2, self.height() // 2, grid_spacing):
            painter.drawLine(-self.width() // 2, y, self.width() // 2, y)

    def drawAxes(self, painter):
        """Draw dashed X/Y axes through the widget center, zoom-scaled ticks and a red origin."""
        painter.setRenderHint(QPainter.Antialiasing)

        # Set up pen for dashed lines
        pen = QPen(Qt.gray, 1, Qt.DashLine)
        painter.setPen(pen)

        middle_x = self.width() // 2
        middle_y = self.height() // 2

        # Draw X axis as dashed line
        painter.drawLine(0, middle_y, self.width(), middle_y)

        # Draw Y axis as dashed line
        painter.drawLine(middle_x, 0, middle_x, self.height())

        # Draw tick marks; spacing must stay >= 1 because range() rejects a zero step.
        tick_length = int(10 * self.zoom)
        tick_spacing = max(1, int(50 * self.zoom))

        pen = QPen(Qt.gray, 1, Qt.SolidLine)
        painter.setPen(pen)

        # Draw tick marks on the X axis to the right and left from the middle point
        for x in range(0, self.width() // 2, tick_spacing):
            painter.drawLine(middle_x + x, middle_y - tick_length // 2, middle_x + x, middle_y + tick_length // 2)
            painter.drawLine(middle_x - x, middle_y - tick_length // 2, middle_x - x, middle_y + tick_length // 2)

        # Draw tick marks on the Y axis upwards and downwards from the middle point
        for y in range(0, self.height() // 2, tick_spacing):
            painter.drawLine(middle_x - tick_length // 2, middle_y + y, middle_x + tick_length // 2, middle_y + y)
            painter.drawLine(middle_x - tick_length // 2, middle_y - y, middle_x + tick_length // 2, middle_y - y)

        # Draw the origin point in red
        painter.setPen(QPen(Qt.red, 4))
        painter.drawPoint(middle_x, middle_y)

    def to_quadrant_coords(self, point):
        """Translate widget coordinates to coordinates centered on the widget, divided by zoom."""
        center_x = self.width() // 2
        center_y = self.height() // 2
        quadrant_x = point.x() - center_x
        quadrant_y = point.y() - center_y
        return QPoint(quadrant_x, quadrant_y) / self.zoom

    def paintEvent(self, event):
        """Paint axes, cached points/lines with length labels, solver points and hover highlights."""
        painter = QPainter(self)

        self.drawAxes(painter)

        # Translate the origin to the center of the widget, then apply the zoom factor.
        painter.translate(QPoint(self.width() // 2, self.height() // 2))
        painter.scale(self.zoom, self.zoom)

        pen = QPen(Qt.gray)
        # setWidthF keeps the fractional width so the stroke stays ~2 px on screen after scaling.
        pen.setWidthF(2 / self.zoom)
        painter.setPen(pen)

        # Draw points
        for point in self.slv_points_main:
            painter.drawEllipse(point['ui_point'], 3 / self.zoom, 3 / self.zoom)

        for dic in self.slv_lines_main:
            p1, p2 = dic['ui_points']
            painter.drawLine(p1, p2)

            dis = self.distance(p1, p2)
            mid = self.calculate_midpoint(p1, p2)
            painter.drawText(mid, str(round(dis, 2)))

        pen = QPen(Qt.green)
        pen.setWidth(2)
        painter.setPen(pen)

        # Solver-side 2D points (the workplane's own entities are skipped by is_point_2d).
        for i in range(self.solv.entity_len()):
            entity = self.solv.entity(i)
            if not entity.is_point_2d():
                continue
            params = self.solv.params(entity.params)
            if params:
                x, y = params
                painter.drawEllipse(QPointF(x, y), 6 / self.zoom, 6 / self.zoom)

        # Highlight point hovered
        if self.hovered_point is not None:
            highlight_pen = QPen(QColor(255, 0, 0))
            highlight_pen.setWidth(2)
            painter.setPen(highlight_pen)
            painter.drawEllipse(self.hovered_point, 5 / self.zoom, 5 / self.zoom)

        # Highlight line hovered
        if self.selected_line is not None and self.hovered_point is None:
            p1, p2 = self.selected_line
            painter.setPen(QPen(Qt.red, 2))
            painter.drawLine(p1, p2)

        painter.end()

    def wheelEvent(self, event):
        """Zoom linearly with the wheel, never going below MIN_ZOOM."""
        delta = event.angleDelta().y()
        self.zoom = max(self.MIN_ZOOM, self.zoom + (delta / 200) * 0.1)
        self.update()

    def aspect_ratio(self):
        """Width/height ratio divided by the absolute zoom (height clamped to >= 1)."""
        return self.width() / max(self.height(), 1) * (1.0 / abs(self.zoom))

    def clear_sketch(self):
        """Remove all geometry and start a fresh solver system.

        Handles from the old solver are invalid in the new one. If a workplane
        existed, it is therefore recreated so ``self.wp`` refers to the new system.
        """
        self.slv_points_main = []
        self.slv_lines_main = []
        self.reset_buffers()
        self.hovered_point = None
        self.selected_line = None

        had_workplane = self.wp is not None
        self.solv = SolverSystem()
        self.wp = None
        if had_workplane:
            self.create_workplane()


# Example usage: running this module directly opens a standalone sketch window.
if __name__ == "__main__":
    import sys

    application = QApplication(sys.argv)

    sketch = SketchWidget()
    sketch.setWindowTitle("Snap Line Widget")
    sketch.resize(800, 600)
    sketch.show()

    exit_code = application.exec()
    sys.exit(exit_code)