import sys
import numpy as np
from PySide6.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget
from PySide6.QtOpenGLWidgets import QOpenGLWidget
from PySide6.QtCore import Qt, QPoint
from OpenGL.GL import *
from OpenGL.GLU import *

##testing

def create_cube(scale=1):
    """Build an axis-aligned cube spanning [0, 2 * scale] on every axis.

    Returns:
        ``(vertices, faces)``: an (8, 3) array of corner positions (multiplied
        by ``scale``) and a (12, 3) integer array of triangle vertex indices,
        two triangles per cube side.
    """
    # Square footprint in the XY plane, walked around its perimeter.
    footprint = [(0, 0), (2, 0), (2, 2), (0, 2)]
    # Corners 0-3 lie on z=0; corners 4-7 sit directly above them on z=2.
    corners = [[x, y, z] for z in (0, 2) for x, y in footprint]
    vertices = np.array(corners) * scale

    # Each side is a quad (a, b, c, d), split along its a-c diagonal into
    # the triangles (a, b, c) and (c, d, a).
    quads = [
        (0, 1, 2, 3),  # z = 0
        (4, 5, 6, 7),  # z = 2
        (0, 1, 5, 4),  # y = 0
        (2, 3, 7, 6),  # y = 2
        (0, 3, 7, 4),  # x = 0
        (1, 2, 6, 5),  # x = 2
    ]
    faces = np.array([tri for a, b, c, d in quads for tri in ([a, b, c], [c, d, a])])

    return vertices, faces


class MainWindow(QMainWindow):
    """Top-level window whose only content is an OpenGLWidget showing a cube."""

    def __init__(self):
        super().__init__()
        self.setWindowTitle("OpenGL Cube Viewer")
        self.setGeometry(100, 100, 800, 600)

        # The GL viewport fills the window through a single vertical layout.
        self.opengl_widget = OpenGLWidget()
        vbox = QVBoxLayout()
        vbox.addWidget(self.opengl_widget)

        container = QWidget()
        container.setLayout(vbox)
        self.setCentralWidget(container)

        # create_cube() already returns the (vertices, faces) pair the viewer expects.
        cube_mesh = create_cube()
        self.opengl_widget.load_interactor_mesh(cube_mesh)


class OpenGLWidget(QOpenGLWidget):
    """Fixed-function OpenGL viewport with an orbit camera and triangle picking.

    Two meshes can be displayed:

    * ``mesh_loaded``: a triangle soup (from ``load_stl`` or
      ``load_mesh_direct``), drawn through a ``gluLookAt`` aimed at
      ``centroid``.
    * ``interactor_loaded``: an indexed ``(vertices, faces)`` mesh (from
      ``load_interactor_mesh``). Right-clicking picks the nearest triangle
      under the cursor, and the picked triangle is overdrawn in green.

    Left-drag rotates the camera about X/Y; the mouse wheel changes ``zoom``.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        # Not read by the methods below; kept for external callers.
        self.vertices = None
        self.faces = None
        # Index into interactor_loaded[1] of the picked triangle, or None.
        self.selected_face = None
        # Uniform scale applied to both meshes when drawing and when picking.
        self.scale_factor = 1
        # Triangle-soup vertices of the background mesh, or None.
        self.mesh_loaded = None
        # (vertices, faces) tuple of the pickable mesh, or None.
        self.interactor_loaded = None
        self.centroid = None
        self.stl_file = "out.stl"  # Replace with your STL file path
        # Mouse / camera state.
        self.lastPos = QPoint()
        self.startPos = None
        self.endPos = None
        self.xRot = 180
        self.yRot = 0
        self.zoom = -2
        self.sketch = []
        # Widget size cached for draw_area; refreshed in resizeGL.
        self.gl_width = self.width()
        self.gl_height = self.height()

    def map_value_to_range(self, value, value_min=0, value_max=1920, range_min=-1, range_max=1):
        """Linearly map ``value`` from [value_min, value_max] onto [range_min, range_max].

        ``value`` is clamped to the input interval first. A degenerate input
        interval (``value_min == value_max``) maps to ``range_min`` instead of
        raising ZeroDivisionError.
        """
        if value_max == value_min:
            return range_min
        value = max(value_min, min(value_max, value))
        fraction = (value - value_min) / (value_max - value_min)
        return fraction * (range_max - range_min) + range_min

    def load_stl(self, filename: str) -> object:
        """Load an STL file as the background mesh.

        On success, ``mesh_loaded`` receives the triangle array
        (``stl_mesh.vectors``) and ``centroid`` is set to the centre of the
        axis-aligned bounding box.

        NOTE(review): this relies on a module-level name ``mesh`` (presumably
        ``from stl import mesh`` from numpy-stl), which this file does not
        import. Without it, the call prints an error and returns the failure
        value.

        Returns:
            ``(triangles, centroid)`` on success, ``(None, (0, 0, 0))`` on failure.
        """
        try:
            stl_mesh = mesh.Mesh.from_file(filename)

            # All triangle corners stacked into one (N, 3) array.
            vertices = np.concatenate([stl_mesh.v0, stl_mesh.v1, stl_mesh.v2])

            # Centroid = centre of the bounding box, not the vertex mean.
            centroid = tuple((vertices.min(axis=0) + vertices.max(axis=0)) / 2.0)

            self.mesh_loaded = stl_mesh.vectors
            self.centroid = centroid
            self.update()
            return stl_mesh.vectors, centroid

        except FileNotFoundError:
            print(f"Error: File {filename} not found.")
        except Exception as e:
            print(f"Error loading {filename}: {e}")

        return None, (0, 0, 0)

    def load_interactor_mesh(self, simp_mesh):
        """Set the pickable ``(vertices, faces)`` mesh and recentre ``centroid`` on it.

        Any previous selection is cleared, because its face index refers to the
        old face array.
        """
        self.interactor_loaded = simp_mesh
        self.selected_face = None
        # Centroid = mean vertex position.
        centroid = np.mean(simp_mesh[0], axis=0)

        self.centroid = tuple(centroid)
        print(f"Centroid: {self.centroid}")

        self.update()

    def load_mesh_direct(self, mesh):
        """Use already-extracted triangle vertices as the background mesh.

        ``mesh`` is array-like vertex data in which each consecutive triple of
        3-D points forms one triangle, as ``draw_mesh_direct`` consumes it.
        Both a flat (N, 3) array and a per-triangle (n, 3, 3) array are
        accepted; the data is flattened to (N, 3), so ``centroid`` (the vertex
        mean) is always a single 3-vector. Errors are printed, not raised.
        """
        try:
            vertices = np.asarray(mesh, dtype=float).reshape(-1, 3)
            if len(vertices) == 0:
                print("Error: mesh has no vertices.")
                return

            centroid = vertices.mean(axis=0)

            self.mesh_loaded = vertices
            self.centroid = tuple(centroid)
            print(f"Centroid: {self.centroid}")
            self.update()
        except Exception as e:
            print(e)

    def clear_mesh(self):
        """Drop the background mesh and schedule a repaint so it disappears."""
        self.mesh_loaded = None
        self.update()

    def initializeGL(self):
        """Set a black clear colour and enable depth testing."""
        glClearColor(0, 0, 0, 1)
        glEnable(GL_DEPTH_TEST)

    def resizeGL(self, width, height):
        """Reset the viewport and a 45-degree perspective projection for the new size."""
        glViewport(0, 0, width, height)
        glMatrixMode(GL_PROJECTION)
        glLoadIdentity()

        # A zero-height widget (e.g. collapsed in a splitter) must not divide by zero.
        aspect = width / float(max(height, 1))

        self.gl_width = self.width()
        self.gl_height = self.height()

        gluPerspective(45.0, aspect, 0.01, 1000.0)
        glMatrixMode(GL_MODELVIEW)

    def unproject(self, x, y, z, modelview, projection, viewport):
        """Map window coordinates (x, y, depth ``z`` in [0, 1]) back to object space.

        NumPy counterpart of ``gluUnProject``; ``y`` uses OpenGL's bottom-left
        window origin. ``modelview`` and ``projection`` are expected in the
        layout ``glGetDoublev`` returns. OpenGL stores matrices column-major,
        so that 4x4 array is the transpose of the mathematical matrix. Both are
        transposed back before composing ``P @ M``.
        """
        projection_m = np.asarray(projection, dtype=float).reshape(4, 4).T
        modelview_m = np.asarray(modelview, dtype=float).reshape(4, 4).T
        mvp_inv = np.linalg.inv(projection_m @ modelview_m)

        ndc = np.array([(x - viewport[0]) / viewport[2] * 2 - 1,
                        (y - viewport[1]) / viewport[3] * 2 - 1,
                        2 * z - 1,
                        1.0])

        world = mvp_inv @ ndc
        return world[:3] / world[3]

    def draw_ray(self, ray_start, ray_end):
        """Draw a red line segment from ``ray_start`` to ``ray_end``."""
        glColor3f(1.0, 0.0, 0.0)
        glBegin(GL_LINES)
        glVertex3f(*ray_start)
        glVertex3f(*ray_end)
        glEnd()

    def mousePressEvent(self, event):
        """Remember the press position for dragging; a right click picks a face."""
        # Without this, the first drag would measure its delta from a stale position.
        self.lastPos = event.position().toPoint()
        # button() is the button that caused this press; buttons() would also
        # include buttons that were already held down.
        if event.button() == Qt.MouseButton.RightButton:
            self.select_face(event)

    def select_face(self, event):
        """Pick the interactor face under the mouse cursor and repaint.

        A ray through the clicked pixel is unprojected from the near (z=0) to
        the far (z=1) depth plane and passed to ``check_intersection``. The
        result (a face index, or None) is stored in ``selected_face``.
        """
        pos = event.position()

        # Qt only guarantees the GL context is current inside
        # initializeGL/resizeGL/paintGL, so bind it before querying state.
        self.makeCurrent()
        try:
            glMatrixMode(GL_MODELVIEW)
            glLoadIdentity()
            # Rebuild the camera matrix rather than relying on whatever paintGL left behind.
            self._apply_camera_transform()
            modelview = glGetDoublev(GL_MODELVIEW_MATRIX)
            projection = glGetDoublev(GL_PROJECTION_MATRIX)
            viewport = glGetIntegerv(GL_VIEWPORT)
        finally:
            self.doneCurrent()

        # Qt reports logical pixels with a top-left origin, while OpenGL window
        # coordinates are framebuffer pixels with a bottom-left origin. Scale
        # by viewport/widget size (handles high-DPI screens) and flip Y.
        scale_x = viewport[2] / max(self.width(), 1)
        scale_y = viewport[3] / max(self.height(), 1)
        win_x = float(viewport[0] + pos.x() * scale_x)
        win_y = float(viewport[1] + viewport[3] - pos.y() * scale_y)

        # The unprojected points are object coordinates for the camera-only
        # modelview above, i.e. before paintGL's per-mesh glScalef.
        ray_start = np.array(gluUnProject(win_x, win_y, 0.0, modelview, projection, viewport))
        ray_end = np.array(gluUnProject(win_x, win_y, 1.0, modelview, projection, viewport))

        ray_direction = ray_end - ray_start
        length = np.linalg.norm(ray_direction)
        if length > 0:
            ray_direction = ray_direction / length

        print(f"Ray start: {ray_start}")
        print(f"Ray end: {ray_end}")
        print(f"Ray direction: {ray_direction}")

        self.selected_face = self.check_intersection(ray_start, ray_end)
        print(f"Selected face: {self.selected_face}")

        self.update()

    def ray_box_intersection(self, ray_origin, ray_direction, box_min, box_max):
        """Slab test: return True if the ray hits the axis-aligned box in front of its origin."""
        direction = np.asarray(ray_direction, dtype=float)
        # Replace (near-)zero components with a signed epsilon so the division
        # stays finite, without biasing the other components the way a blanket
        # "+ 1e-7" would.
        tiny = 1e-7
        safe = np.where(np.abs(direction) < tiny, np.where(direction < 0, -tiny, tiny), direction)
        inv_direction = 1.0 / safe

        t1 = (np.asarray(box_min, dtype=float) - ray_origin) * inv_direction
        t2 = (np.asarray(box_max, dtype=float) - ray_origin) * inv_direction

        t_min = np.max(np.minimum(t1, t2))
        t_max = np.min(np.maximum(t1, t2))

        return bool(t_max >= t_min and t_max > 0)

    def check_intersection(self, ray_start, ray_end):
        """Return the index of the nearest interactor face hit by the ray, or None.

        ``ray_start`` and ``ray_end`` come from ``gluUnProject`` with the
        camera-only modelview, so they already share a frame with the mesh.
        The interactor is drawn scaled by ``scale_factor`` in that frame, so
        the triangle vertices are scaled the same way here. They are not pushed
        through the modelview a second time.
        """
        if self.interactor_loaded is None:
            print("No intersection found")
            return None

        vertices, faces = self.interactor_loaded
        vertices = np.asarray(vertices, dtype=float) * self.scale_factor

        ray_start = np.asarray(ray_start, dtype=float)
        ray_direction = np.asarray(ray_end, dtype=float) - ray_start
        length = np.linalg.norm(ray_direction)
        if length == 0:
            print("No intersection found")
            return None
        ray_direction = ray_direction / length

        print(f"Checking intersection with {len(faces)} faces")
        # A ray through a closed mesh hits several triangles (front and back);
        # keep the one closest to the eye instead of the first in index order.
        best_face = None
        best_distance = np.inf
        for face_idx, face in enumerate(faces):
            v0, v1, v2 = (vertices[i] for i in face)
            hit = self.moller_trumbore(ray_start, ray_direction, v0, v1, v2)
            if hit is None:
                continue
            distance = np.linalg.norm(hit - ray_start)
            if distance < best_distance:
                best_face, best_distance = face_idx, distance

        if best_face is None:
            print("No intersection found")
        else:
            print(f"Intersection found with face {best_face}")
        return best_face

    def moller_trumbore(self, ray_origin, ray_direction, v0, v1, v2, epsilon=1e-6):
        """Moller-Trumbore ray/triangle test.

        Returns the intersection point, or None if the ray misses, is
        (nearly) parallel to the triangle, or only hits it at
        ``t <= epsilon``. Triangles are hit from either side.
        """
        # Vectors for the two edges sharing v0.
        edge1 = v1 - v0
        edge2 = v2 - v0
        pvec = np.cross(ray_direction, edge2)

        det = np.dot(edge1, pvec)
        # Two-sided test: reject only rays parallel to the triangle plane. A
        # one-sided "det < epsilon" would cull back-facing triangles, and
        # create_cube's winding is not consistent. Without any guard, det == 0
        # divides by zero.
        if abs(det) < epsilon:
            return None

        inv_det = 1.0 / det
        tvec = ray_origin - v0
        u = np.dot(tvec, pvec) * inv_det
        if u < 0.0 or u > 1.0:
            return None

        qvec = np.cross(tvec, edge1)
        v = np.dot(ray_direction, qvec) * inv_det
        if v < 0.0 or u + v > 1.0:
            return None

        # Distance along the (normalised) ray to the hit point.
        t = np.dot(edge2, qvec) * inv_det
        if t > epsilon:
            return ray_origin + t * ray_direction

        return None

    def ray_triangle_intersection(self, ray_origin, ray_direction, v0, v1, v2):
        """Verbose variant of ``moller_trumbore`` that prints every intermediate value.

        Uses a looser epsilon (1e-5) and is two-sided. Returns the intersection
        point, or None.
        """
        epsilon = 1e-5
        edge1 = v1 - v0
        edge2 = v2 - v0
        h = np.cross(ray_direction, edge2)
        a = np.dot(edge1, h)

        print(f"Triangle vertices: {v0}, {v1}, {v2}")
        print(f"a: {a}")

        if abs(a) < epsilon:
            print("Ray is parallel to the triangle")
            return None  # Ray is parallel to the triangle

        f = 1.0 / a
        s = ray_origin - v0
        u = f * np.dot(s, h)

        print(f"u: {u}")

        if u < 0.0 or u > 1.0:
            print("u is out of range")
            return None

        q = np.cross(s, edge1)
        v = f * np.dot(ray_direction, q)

        print(f"v: {v}")

        if v < 0.0 or u + v > 1.0:
            print("v is out of range")
            return None

        t = f * np.dot(edge2, q)

        print(f"t: {t}")

        if t > epsilon:
            intersection_point = ray_origin + t * ray_direction
            print(f"Intersection point: {intersection_point}")
            return intersection_point

        print("t is too small")
        return None

    def _apply_camera_transform(self):
        """Multiply the current matrix by the orbit camera: zoom, then X and Y rotation."""
        glTranslatef(0, 0, self.zoom)
        glRotatef(self.xRot, 1.0, 0.0, 0.0)
        glRotatef(self.yRot, 0.0, 1.0, 0.0)

    def paintGL(self):
        """Draw the grid, the background mesh, the interactor and the current selection."""
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)
        glMatrixMode(GL_MODELVIEW)
        glLoadIdentity()
        self._apply_camera_transform()

        glColor3f(0.9, 0.8, 0.8)
        self.draw_area()

        if self.mesh_loaded is not None and self.centroid is not None:
            glPushMatrix()
            glScalef(self.scale_factor, self.scale_factor, self.scale_factor)

            # Look at the mesh centroid from 100 units along +Z.
            cx, cy, cz = self.centroid
            gluLookAt(cx, cy, cz + 100, cx, cy, cz, 0, 1, 0)

            self.draw_mesh_direct(self.mesh_loaded)
            glPopMatrix()

        if self.interactor_loaded is not None:
            glPushMatrix()
            glScalef(self.scale_factor, self.scale_factor, self.scale_factor)
            self.draw_interactor(self.interactor_loaded)
            # Same transform as the interactor, so the highlight covers the picked face exactly.
            if self.selected_face is not None:
                self._draw_selected_face()
            glPopMatrix()

        if hasattr(self, 'ray_start') and hasattr(self, 'ray_end'):
            self.draw_ray(self.ray_start, self.ray_end)

        glFlush()

    def _draw_selected_face(self):
        """Overdraw the selected interactor triangle in flat green."""
        vertices, faces = self.interactor_loaded
        if not 0 <= self.selected_face < len(faces):
            return
        glDisable(GL_LIGHTING)
        # The highlight has exactly the depth of the face beneath it; with the
        # default GL_LESS it would lose the depth test, so let it win ties.
        glDepthFunc(GL_LEQUAL)
        glColor3f(0.0, 1.0, 0.0)  # Green
        glBegin(GL_TRIANGLES)
        for vertex_idx in faces[self.selected_face]:
            glVertex3fv(vertices[vertex_idx])
        glEnd()
        glDepthFunc(GL_LESS)
        glEnable(GL_LIGHTING)

    def draw_stl(self, vertices):
        """Draw an iterable of triangles (each an iterable of 3 vertices) with GL_LIGHT0 lighting.

        Does not call ``update()``. Scheduling a repaint from inside painting
        would make the widget repaint forever.
        """
        glEnable(GL_LIGHTING)
        glEnable(GL_LIGHT0)
        glEnable(GL_DEPTH_TEST)
        glEnable(GL_COLOR_MATERIAL)
        glColorMaterial(GL_FRONT_AND_BACK, GL_AMBIENT_AND_DIFFUSE)

        glLightfv(GL_LIGHT0, GL_POSITION, (0, 1, 1, 0))
        glLightfv(GL_LIGHT0, GL_DIFFUSE, (0.6, 0.6, 0.6, 1.0))

        glBegin(GL_TRIANGLES)
        for triangle in vertices:
            for vertex in triangle:
                glVertex3fv(vertex)
        glEnd()

    def draw_interactor(self, simp_mesh: tuple):
        """Draw an indexed ``(vertices, faces)`` mesh as dark-red faces with green edges.

        Lighting is disabled while drawing and re-enabled afterwards.
        """
        vertices, faces = simp_mesh

        glEnable(GL_LIGHTING)
        glEnable(GL_LIGHT0)
        glEnable(GL_DEPTH_TEST)
        glEnable(GL_COLOR_MATERIAL)
        glColorMaterial(GL_FRONT_AND_BACK, GL_AMBIENT_AND_DIFFUSE)

        glLightfv(GL_LIGHT0, GL_POSITION, (0, 0.6, 0.6, 0))
        glLightfv(GL_LIGHT0, GL_DIFFUSE, (0.4, 0.4, 0.4, 0.6))

        # Faces, unlit, in dark red.
        glDisable(GL_LIGHTING)
        glColor3f(0.2, 0.0, 0.0)

        glBegin(GL_TRIANGLES)
        for face in faces:
            for vertex_index in face:
                glVertex3fv(vertices[vertex_index])
        glEnd()

        # Edges of each triangle, in green.
        glColor3f(0.0, 1.0, 0.0)

        glBegin(GL_LINES)
        for face in faces:
            corner_count = len(face)
            for i in range(corner_count):
                glVertex3fv(vertices[face[i]])
                glVertex3fv(vertices[face[(i + 1) % corner_count]])
        glEnd()

        glEnable(GL_LIGHTING)  # Re-enable lighting for later drawing

    def draw_mesh_direct(self, points):
        """Draw a flat vertex list as triangles (every 3 points) plus black edges.

        Trailing points that do not complete a triangle are ignored by the edge
        pass instead of raising IndexError.
        """
        glEnable(GL_LIGHTING)
        glEnable(GL_LIGHT0)
        glEnable(GL_DEPTH_TEST)
        glEnable(GL_COLOR_MATERIAL)
        glColorMaterial(GL_FRONT_AND_BACK, GL_AMBIENT_AND_DIFFUSE)

        glLightfv(GL_LIGHT0, GL_POSITION, (0, 0.6, 0.6, 0))
        glLightfv(GL_LIGHT0, GL_DIFFUSE, (0.4, 0.4, 0.4, 0.6))

        # Faces use whatever colour is current when this is called.
        glDisable(GL_LIGHTING)
        glBegin(GL_TRIANGLES)
        for vertex in points:
            glVertex3fv(vertex)
        glEnd()

        # Triangle edges in black.
        glColor3f(0.0, 0.0, 0.0)

        usable = len(points) - len(points) % 3
        glBegin(GL_LINES)
        for i in range(0, usable, 3):
            glVertex3fv(points[i])
            glVertex3fv(points[i + 1])

            glVertex3fv(points[i + 1])
            glVertex3fv(points[i + 2])

            glVertex3fv(points[i + 2])
            glVertex3fv(points[i])
        glEnd()

        glEnable(GL_LIGHTING)  # Re-enable lighting for later drawing

    def draw_area(self):
        """Draw a gray grid of lines on the z=0 plane.

        One vertical line is drawn per widget pixel column and one horizontal
        line per pixel row. They are spread over
        [-gl_width, gl_width] x [-gl_height, gl_height] in model units.
        """
        glColor3f(0.5, 0.5, 0.5)  # Gray

        width = self.width()
        height = self.height()

        glBegin(GL_LINES)
        for x in range(0, width, 1):
            x_pos = self.map_value_to_range(x, 0, value_max=width, range_min=-self.gl_width, range_max=self.gl_width)
            glVertex2f(x_pos, -self.gl_height)  # From y = -gl_height
            glVertex2f(x_pos, self.gl_height)   # to y = +gl_height

        for y in range(0, height, 1):
            y_pos = self.map_value_to_range(y, 0, value_max=height, range_min=-self.gl_height, range_max=self.gl_height)
            glVertex2f(-self.gl_width, y_pos)  # From x = -gl_width
            glVertex2f(self.gl_width, y_pos)   # to x = +gl_width
        glEnd()

    def mouseMoveEvent(self, event):
        """Rotate the camera while the left button is dragged (0.5 degree per pixel)."""
        pos = event.position().toPoint()
        dx = pos.x() - self.lastPos.x()
        dy = pos.y() - self.lastPos.y()

        if event.buttons() & Qt.MouseButton.LeftButton:
            self.xRot += 0.5 * dy
            self.yRot += 0.5 * dx
        # Always track the cursor so the next delta is measured from here.
        self.lastPos = pos
        self.update()

    def wheelEvent(self, event):
        """Move the camera along Z by 1/200 of the wheel angle delta."""
        delta = event.angleDelta().y()
        self.zoom += delta / 200
        self.update()

    def aspect_ratio(self):
        """Return width/height divided by |zoom|, guarded against zero height or zoom."""
        return self.width() / max(self.height(), 1) * (1.0 / max(abs(self.zoom), 1e-6))

if __name__ == "__main__":
    # Standard Qt bootstrap: one application, one main window, run the event loop.
    application = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_code = application.exec()
    sys.exit(exit_code)