import numpy as np
from PySide6.QtOpenGLWidgets import QOpenGLWidget
from PySide6.QtCore import Qt, QPoint
from OpenGL.GL import *
from OpenGL.GLU import *
from stl import mesh

class OpenGLWidget(QOpenGLWidget):
    """OpenGL viewport that draws a reference grid and an optional triangle mesh.

    The mesh is kept in ``self.mesh_loaded`` and its reference point in
    ``self.centroid``. Dragging with the left mouse button rotates the scene
    and the mouse wheel moves the camera along the view axis.
    """

    # Widget size in pixels is divided by this to obtain the grid half-extent
    # in GL units. ``__init__`` and ``resizeGL`` must agree on the value.
    _GRID_EXTENT_DIVISOR = 1000

    def __init__(self, parent=None):
        """Set up the default view state; no mesh is loaded yet."""
        super().__init__(parent)
        self.scale_factor = 0.001  # uniform scale applied to the mesh in paintGL
        self.mesh_loaded = None
        self.centroid = None
        self.stl_file = "out.stl"  # Replace with your STL file path
        self.lastPos = QPoint()  # last cursor position seen while dragging
        self.startPos = None
        self.endPos = None
        self.xRot = 180
        self.yRot = 0
        self.zoom = -2  # translation along z applied before the rotations
        self.sketch = []
        self.gl_width = self.width() / self._GRID_EXTENT_DIVISOR
        self.gl_height = self.height() / self._GRID_EXTENT_DIVISOR

    def map_value_to_range(self, value, value_min=0, value_max=1920, range_min=-1, range_max=1):
        """Linearly map ``value`` from one interval onto another.

        ``value`` is clamped to ``[value_min, value_max]`` first, so the result
        always lies between ``range_min`` and ``range_max``.

        Raises:
            ZeroDivisionError: if ``value_min == value_max``.
        """
        clamped = max(value_min, min(value_max, value))
        fraction = (clamped - value_min) / (value_max - value_min)
        return fraction * (range_max - range_min) + range_min

    def load_stl(self, filename: str) -> tuple:
        """Load an STL file and make it the displayed mesh.

        The centroid is the centre of the axis-aligned bounding box of all
        vertices. On success the widget state is updated and a repaint is
        scheduled.

        Args:
            filename: path of the STL file to read.

        Returns:
            ``(vectors, centroid)`` on success, ``(None, (0, 0, 0))`` when the
            file is missing or cannot be read. Errors are printed, not raised.
        """
        try:
            stl_mesh = mesh.Mesh.from_file(filename)

            # All three corners of every triangle, stacked into one array.
            vertices = np.concatenate([stl_mesh.v0, stl_mesh.v1, stl_mesh.v2])

            # Centre of the bounding box.
            lower = vertices.min(axis=0)
            upper = vertices.max(axis=0)
            centroid = tuple((lower + upper) / 2.0)
            vectors = stl_mesh.vectors
        except FileNotFoundError:
            print(f"Error: File {filename} not found.")
            return None, (0, 0, 0)
        except Exception as e:
            print(f"Error loading {filename}: {e}")
            return None, (0, 0, 0)

        self.mesh_loaded = vectors
        self.centroid = centroid
        self.update()
        return self.mesh_loaded, self.centroid

    def load_mesh_direct(self, mesh):
        """Display a mesh given directly as vertex data.

        Args:
            mesh: array-like of 3D vertices, three consecutive vertices per
                triangle. It is flattened to shape ``(N, 3)``.
                NOTE: this parameter shadows the module-level ``mesh`` import
                inside this method; the name is kept for caller compatibility.

        The centroid is the mean of all vertices. For an empty mesh it is left
        as ``None`` so no NaN values reach the camera setup. Errors are
        printed, not raised.
        """
        try:
            vertices = np.array(mesh).reshape(-1, 3)

            self.mesh_loaded = vertices
            if len(vertices) == 0:
                self.centroid = None
            else:
                self.centroid = tuple(np.mean(vertices, axis=0))
            print(f"Centroid: {self.centroid}")
            self.update()
        except Exception as e:
            print(e)

    def clear_mesh(self):
        """Forget the current mesh and repaint the viewport."""
        self.mesh_loaded = None
        self.centroid = None
        self.update()

    def initializeGL(self):
        """Set the clear colour to opaque black and enable depth testing."""
        glClearColor(0, 0, 0, 1)
        glEnable(GL_DEPTH_TEST)

    def resizeGL(self, width, height):
        """Rebuild the viewport and perspective projection for a new size."""
        glViewport(0, 0, width, height)
        glMatrixMode(GL_PROJECTION)
        glLoadIdentity()
        # A widget can report height 0; avoid dividing by zero.
        aspect = width / float(max(height, 1))

        self.gl_width = self.width() / self._GRID_EXTENT_DIVISOR
        self.gl_height = self.height() / self._GRID_EXTENT_DIVISOR

        gluPerspective(45.0, aspect, 0.01, 10.0)
        glMatrixMode(GL_MODELVIEW)

    def paintGL(self):
        """Render one frame: the grid followed by the loaded mesh."""
        if self.mesh_loaded is None:
            # Without a mesh the viewport is left blank (black).
            # NOTE(review): this hides the grid as well, which matches the
            # previous behaviour (grid drawn, then cleared) - confirm intent.
            glClearColor(0.0, 0.0, 0.0, 1.0)
            glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)
            return

        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)
        glLoadIdentity()

        glTranslatef(0, 0, self.zoom)
        glRotatef(self.xRot, 1.0, 0.0, 0.0)
        glRotatef(self.yRot, 0.0, 1.0, 0.0)

        glColor3f(0.9, 0.8, 0.8)

        self.draw_area()

        # Adjust the camera
        if self.centroid is not None:
            glScalef(self.scale_factor, self.scale_factor, self.scale_factor)  # Apply scaling

            cx, cy, cz = self.centroid
            gluLookAt(cx, cy, cz + 100, cx, cy, cz, 0, 1, 0)

        self.draw_mesh_direct(self.mesh_loaded)

    def draw_stl(self, vertices):
        """Draw lit triangles from an iterable of 3-vertex triangles.

        Must be called with a current GL context (i.e. from ``paintGL``). It
        does not schedule a repaint itself: doing so from inside a paint
        would make the widget repaint continuously.
        """
        glEnable(GL_LIGHTING)
        glEnable(GL_LIGHT0)
        glEnable(GL_DEPTH_TEST)
        glEnable(GL_COLOR_MATERIAL)
        glColorMaterial(GL_FRONT_AND_BACK, GL_AMBIENT_AND_DIFFUSE)

        glLightfv(GL_LIGHT0, GL_POSITION, (0, 1, 1, 0))
        glLightfv(GL_LIGHT0, GL_DIFFUSE, (0.6, 0.6, 0.6, 1.0))

        glBegin(GL_TRIANGLES)
        for triangle in vertices:
            for vertex in triangle:
                glVertex3fv(vertex)
        glEnd()

    def draw_mesh_direct(self, points):
        """Draw the mesh as filled triangles with black edge lines on top.

        Args:
            points: vertex data, three consecutive vertices per triangle.
                Both a flat ``(N, 3)`` layout and a per-triangle ``(n, 3, 3)``
                layout are accepted. Trailing vertices that do not complete a
                triangle are ignored.

        Raises:
            ValueError: if the data cannot be viewed as rows of 3 coordinates.
        """
        # Normalise before touching GL state so a bad array cannot fail
        # between glBegin and glEnd.
        vertices = np.asarray(points).reshape(-1, 3)
        usable = len(vertices) - len(vertices) % 3

        glEnable(GL_LIGHTING)
        glEnable(GL_LIGHT0)
        glEnable(GL_DEPTH_TEST)
        glEnable(GL_COLOR_MATERIAL)
        glColorMaterial(GL_FRONT_AND_BACK, GL_AMBIENT_AND_DIFFUSE)

        glLightfv(GL_LIGHT0, GL_POSITION, (0, 0.6, 0.6, 0))
        glLightfv(GL_LIGHT0, GL_DIFFUSE, (0.4, 0.4, 0.4, 0.6))

        # NOTE(review): lighting is configured above but switched off for the
        # actual drawing, so faces and edges are rendered unlit.
        glDisable(GL_LIGHTING)
        glBegin(GL_TRIANGLES)
        for vertex in vertices[:usable]:
            glVertex3fv(vertex)
        glEnd()

        # Draw the lines (edges of the triangles)
        glColor3f(0.0, 0.0, 0.0)  # Set line color to black (or any color you prefer)

        glBegin(GL_LINES)
        for first in range(0, usable, 3):
            a, b, c = vertices[first:first + 3]
            for start, end in ((a, b), (b, c), (c, a)):
                glVertex3fv(start)
                glVertex3fv(end)
        glEnd()

        glEnable(GL_LIGHTING)  # Re-enable lighting if further drawing requires it

    def draw_area(self):
        """Draw a gray grid with one line every 20 widget pixels.

        The grid spans ``-gl_width..gl_width`` horizontally and
        ``-gl_height..gl_height`` vertically in GL units.
        """
        glColor3f(0.5, 0.5, 0.5)  # Gray color

        width = self.width()
        height = self.height()

        glBegin(GL_LINES)
        # Vertical lines, running from the bottom to the top of the grid.
        for x in range(0, width, 20):
            x_gl = self.map_value_to_range(x, 0, value_max=width, range_min=-self.gl_width, range_max=self.gl_width)
            glVertex2f(x_gl, -self.gl_height)
            glVertex2f(x_gl, self.gl_height)

        # Horizontal lines, running from the left to the right of the grid.
        for y in range(0, height, 20):
            y_gl = self.map_value_to_range(y, 0, value_max=height, range_min=-self.gl_height, range_max=self.gl_height)
            glVertex2f(-self.gl_width, y_gl)
            glVertex2f(self.gl_width, y_gl)
        glEnd()

    def mousePressEvent(self, event):
        """Remember where a drag starts so the first move does not jump."""
        self.lastPos = event.position().toPoint()
        super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        """Rotate the scene while the left mouse button is held down."""
        if event.buttons() & Qt.MouseButton.LeftButton:
            pos = event.position().toPoint()
            dx = pos.x() - self.lastPos.x()
            dy = pos.y() - self.lastPos.y()

            # Half a degree of rotation per pixel of cursor movement.
            self.xRot += 0.5 * dy
            self.yRot += 0.5 * dx
            self.lastPos = pos
            self.update()

    def wheelEvent(self, event):
        """Move the camera along the view axis according to the wheel delta."""
        delta = event.angleDelta().y()
        self.zoom += delta / 200
        self.update()

    def aspect_ratio(self):
        """Return width / height of the widget divided by ``abs(self.zoom)``.

        Raises:
            ZeroDivisionError: if the widget height or ``self.zoom`` is 0.
        """
        return self.width() / self.height() * (1.0 / abs(self.zoom))