import sys

import numpy as np
import pyvista as pv
from pyvista.plotting.opts import ElementType
from pyvistaqt import QtInteractor
from PySide6.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget


class PyVistaWidget(QWidget):
    """Qt widget hosting an embedded PyVista plotter with face picking.

    The plotter is available as ``self.plotter`` (a ``QtInteractor``) and is
    placed in a vertical layout. Left-clicking a face calls
    :meth:`on_cell_pick` with the picked element.
    """

    def __init__(self, parent=None):
        super().__init__(parent)

        # Embedded PyVista render window.
        self.plotter = QtInteractor(self)
        self.plotter.background_color = "darkgray"

        layout = QVBoxLayout()
        layout.addWidget(self.plotter.interactor)
        self.setLayout(layout)

        # Pick single faces with a left click; show=True highlights the pick.
        self.plotter.enable_element_picking(
            callback=self.on_cell_pick,
            show=True,
            mode="face",
            left_clicking=True,
        )

    def on_cell_pick(self, element):
        """Print the current mesh and the element selected by the face picker.

        Args:
            element: The picked element as passed by
                ``enable_element_picking``; ``None`` when nothing was picked.
        """
        if element is None:
            print("No face was picked or the picked element is not a face.")
            return

        # NOTE(review): confirm ``Plotter.mesh`` exists in the installed PyVista
        # version; ``element`` itself may already carry the picked face geometry.
        mesh = self.plotter.mesh
        print(mesh)
        print(element)
        # TODO: report the picked face's normal and vertex coordinates. The
        # earlier draft (a no-op string literal) referenced an undefined
        # ``face_id`` and never ran.

    def create_simplified_outline(self, mesh, camera):
        """Overlay the mesh's sharp feature edges as a black line outline.

        Args:
            mesh: PolyData whose feature edges (angle threshold 90 degrees,
                boundary and non-manifold edges excluded) are drawn.
            camera: Currently unused; kept for interface compatibility.

        NOTE(review): ``self.plotter.map_to_2d`` is not part of the PyVista
        ``Plotter`` API as far as I can tell. Confirm it is provided elsewhere,
        and confirm that its output is accepted by ``add_lines``, which expects
        3-D points.
        """
        edges = mesh.extract_feature_edges(
            feature_angle=90, boundary_edges=False, non_manifold_edges=False
        )

        # add_lines() draws disjoint segments from consecutive point *pairs*,
        # but edges.points is not stored in pair order. Use the flat line
        # connectivity [2, i, j, 2, k, l, ...] (vtkFeatureEdges emits
        # two-point line cells) to collect each segment's two endpoints.
        segment_ids = edges.lines.reshape(-1, 3)[:, 1:].ravel()
        if segment_ids.size == 0:
            return  # no feature edges -> nothing to outline

        segment_points = edges.points[segment_ids]
        edge_points_2d = self.plotter.map_to_2d(segment_points)

        self.plotter.add_lines(edge_points_2d, color='black', width=2)
        self.plotter.render()

    def mesh_from_points(self, points):
        """Build a triangle mesh from a "triangle soup" and display it.

        Each consecutive group of three points forms one triangle. The steps are:

        1. Merge coincident vertices and compute cell normals.
        2. Clear the plotter and show the mesh with PBR shading.
        3. Overlay the feature edges (30 degrees) in red.
        4. Reset the camera.
        5. Print simple statistics.

        Args:
            points: Array-like of shape ``(N, 3)`` with ``N`` a positive
                multiple of 3.

        Raises:
            ValueError: If ``points`` is not ``(N, 3)``, is empty, or ``N`` is
                not a multiple of 3.
        """
        points = np.array(points, dtype=float)
        if points.ndim != 2 or points.shape[1] != 3:
            raise ValueError(f"points must have shape (N, 3), got {points.shape}")
        if len(points) == 0 or len(points) % 3:
            raise ValueError(
                "points must contain a positive multiple of 3 vertices, "
                f"got {len(points)}"
            )

        num_triangles = len(points) // 3
        # Triangle i uses vertices 3i, 3i+1, 3i+2. VTK's flat face layout
        # prefixes every cell with its vertex count: [3, a, b, c, 3, d, e, f, ...].
        vertex_ids = np.arange(len(points)).reshape(num_triangles, 3)
        faces = np.column_stack((np.full(num_triangles, 3), vertex_ids)).ravel()

        mesh = pv.PolyData(points, faces)
        mesh = mesh.clean()  # merge duplicate (shared) vertices
        mesh = mesh.compute_normals(
            point_normals=False, cell_normals=True, consistent_normals=True
        )
        edges = mesh.extract_feature_edges(30, non_manifold_edges=False)

        # Replace whatever was shown before.
        self.plotter.clear()
        self.plotter.add_mesh(
            mesh,
            pickable=True,
            color='white',
            show_edges=True,
            line_width=2,
            pbr=True,
            metallic=0.8,
            roughness=0.1,
            diffuse=1,
        )
        self.plotter.add_mesh(edges, color="red", line_width=10)

        self.plotter.reset_camera()
        self.plotter.update()

        print(f"Original points: {len(points)}")
        print(f"Number of triangles: {num_triangles}")
        print(f"Final number of points: {mesh.n_points}")
        print(f"Final number of cells: {mesh.n_cells}")


class MainWindow(QMainWindow):
    def __init__(self):
        super().__init__()
        self.setWindowTitle("PyVista in PySide6")
        self.setGeometry(100, 100, 800, 600)