def are_coplanar(self, normal1, normal2, point1, point2, tolerance=1e-6):
    """Return whether two planes, each given by a normal and a point, coincide.

    The planes are considered the same when their normals are parallel
    (pointing in the same or in opposite directions) and ``point2`` lies on
    the plane through ``point1``.

    Args:
        normal1: Normal of the first plane (any 3-component array-like). It
            does not need to be unit length; it is normalised here.
        normal2: Normal of the second plane, same conventions as ``normal1``.
        point1: A point on the first plane (array-like).
        point2: A point on the second plane (array-like).
        tolerance: Slack used both for the parallelism test
            (``|n1 . n2| >= 1 - tolerance``) and for the point-to-plane
            distance test (``|(p2 - p1) . n1| < tolerance``).

    Returns:
        A truthy value when the planes coincide within ``tolerance``,
        a falsy value otherwise.
    """
    unit1 = np.asarray(normal1, dtype=float)
    unit2 = np.asarray(normal2, dtype=float)
    length1 = np.linalg.norm(unit1)
    length2 = np.linalg.norm(unit2)
    if length1 == 0 or length2 == 0:
        # A zero-length normal does not define a plane.
        return False

    # Normalise so both tolerance tests mean the same thing regardless of the
    # magnitude of the normals that were passed in.
    unit1 = unit1 / length1
    unit2 = unit2 / length2

    # Check if normals are parallel (same or opposite direction)
    if np.abs(np.dot(unit1, unit2)) < 1 - tolerance:
        return False

    # Check if points lie on the same plane: the offset between them must have
    # no component along the (unit) normal.
    offset = np.asarray(point2, dtype=float) - np.asarray(point1, dtype=float)
    return np.abs(np.dot(offset, unit1)) < tolerance


def merge_coplanar_triangles(self, polydata):
    """Group coplanar triangles of ``polydata`` and emit one polygon per group.

    Cell normals are computed with ``vtkPolyDataNormals``. Cells are then
    grouped greedily: the first cell that is not yet in a group starts a new
    one and collects every later ungrouped cell for which
    ``self.are_coplanar`` holds (comparing the cell normals and the first
    point of each cell).

    Args:
        polydata: Input ``vtkPolyData``. Only the first three point ids of
            each cell are read, so the cells are assumed to be triangles.

    Returns:
        A new ``vtkPolyData`` that uses the point set of the normals filter
        output. A group with a single cell keeps that cell unchanged; a larger
        group becomes one ``vtkPolygon`` over the union of its point ids.
    """
    # Compute cell normals
    normal_filter = vtk.vtkPolyDataNormals()
    normal_filter.SetInputData(polydata)
    normal_filter.ComputePointNormalsOff()
    normal_filter.ComputeCellNormalsOn()
    normal_filter.Update()

    mesh = normal_filter.GetOutput()
    n_cells = mesh.GetNumberOfCells()
    normal_array = mesh.GetCellData().GetNormals()

    # Read each cell's normal and reference point once, instead of once per
    # pair inside the O(n^2) comparison loop below.
    cell_normals = []
    cell_anchors = []
    for cell_id in range(n_cells):
        cell_normals.append(np.array(normal_array.GetTuple(cell_id)))
        cell_anchors.append(
            np.array(mesh.GetCell(cell_id).GetPoints().GetPoint(0))
        )

    # Every cell must end up in exactly one group. Membership is tracked in a
    # separate set: testing against the seed ids alone would let a cell that
    # was already absorbed by an earlier group start a group of its own and be
    # emitted twice.
    assigned = set()
    groups = []
    for seed in range(n_cells):
        if seed in assigned:
            continue

        group = [seed]
        assigned.add(seed)

        for other in range(seed + 1, n_cells):
            if other in assigned:
                continue

            if self.are_coplanar(
                cell_normals[seed],
                cell_normals[other],
                cell_anchors[seed],
                cell_anchors[other],
            ):
                group.append(other)
                assigned.add(other)

        groups.append(group)

    # Create new polygons
    new_polygons = vtk.vtkCellArray()
    for group in groups:
        if len(group) > 1:
            point_ids = set()
            for cell_id in group:
                cell = mesh.GetCell(cell_id)
                for corner in range(3):
                    point_ids.add(cell.GetPointId(corner))

            # NOTE(review): the ids are written in set-iteration order, not in
            # boundary order, and a coplanar group is not required to be
            # connected - confirm that downstream consumers accept such
            # polygons.
            polygon = vtk.vtkPolygon()
            polygon.GetPointIds().SetNumberOfIds(len(point_ids))
            for slot, point_id in enumerate(point_ids):
                polygon.GetPointIds().SetId(slot, point_id)
            new_polygons.InsertNextCell(polygon)
        else:
            new_polygons.InsertNextCell(mesh.GetCell(group[0]))

    # Create new polydata
    new_polydata = vtk.vtkPolyData()
    new_polydata.SetPoints(mesh.GetPoints())
    new_polydata.SetPolys(new_polygons)

    return new_polydata


def create_cube_mesh(self, stl_path="case.stl"):
    """Load an STL surface and add it, plus its feature edges, to the renderer.

    Despite the name, the geometry comes from an STL file rather than from a
    cube source.

    Args:
        stl_path: Path of the STL file to read. Defaults to ``"case.stl"``,
            the file name that was previously hard-coded here.

    Two actors are added to ``self.renderer``: one for the surface read from
    the file and one for the ``vtkFeatureEdges`` output computed from it
    (boundary and feature edges on, manifold and non-manifold edges off).
    """
    reader = vtk.vtkSTLReader()
    reader.SetFileName(stl_path)
    reader.Update()

    feature_edges = vtk.vtkFeatureEdges()
    feature_edges.SetInputConnection(reader.GetOutputPort())
    feature_edges.BoundaryEdgesOn()
    feature_edges.FeatureEdgesOn()
    feature_edges.ManifoldEdgesOff()
    feature_edges.NonManifoldEdgesOff()
    feature_edges.Update()

    # Surface actor
    surface_mapper = vtk.vtkPolyDataMapper()
    surface_mapper.SetInputConnection(reader.GetOutputPort())
    surface_actor = vtk.vtkActor()
    surface_actor.SetMapper(surface_mapper)
    self.renderer.AddActor(surface_actor)

    # Edge actor
    edge_mapper = vtk.vtkPolyDataMapper()
    edge_mapper.SetInputConnection(feature_edges.GetOutputPort())
    edge_actor = vtk.vtkActor()
    edge_actor.SetMapper(edge_mapper)
    self.renderer.AddActor(edge_actor)


def simplify_mesh(self, input_mesh, target_reduction):
    """Decimate ``input_mesh`` with ``vtkDecimatePro`` and return the result.

    Args:
        input_mesh: Data object handed to the filter via ``SetInputData``.
        target_reduction: Reduction factor forwarded to
            ``SetTargetReduction`` (0 to 1, where 1 means maximum reduction).

    Returns:
        The output of the decimation filter after it has been updated.
    """
    decimator = vtk.vtkDecimatePro()

    # Configure the filter: keep the topology intact and aim for the
    # requested reduction.
    decimator.PreserveTopologyOn()
    decimator.SetTargetReduction(target_reduction)

    # Feed the mesh in and run the filter.
    decimator.SetInputData(input_mesh)
    decimator.Update()

    reduced_mesh = decimator.GetOutput()
    return reduced_mesh


def combine_coplanar_faces(self, input_polydata, tolerance=0.001):
    """Clean ``input_polydata`` and attach consistently oriented cell normals.

    The pipeline is ``vtkCleanPolyData`` followed by ``vtkPolyDataNormals``.
    Note that, despite the function name, none of the calls made here merges
    polygons; the result is the cleaned mesh with cell normals.

    Args:
        input_polydata: Data object handed to the cleaning filter.
        tolerance: Value forwarded to ``vtkCleanPolyData.SetTolerance``.

    Returns:
        The output of the normals filter.
    """
    # Stage 1: merge duplicate points.
    cleaner = vtk.vtkCleanPolyData()
    cleaner.SetTolerance(tolerance)
    cleaner.SetInputData(input_polydata)
    cleaner.Update()

    # Stage 2: cell normals only, no splitting at sharp edges, with
    # consistent polygon ordering and automatic orientation.
    normal_filter = vtk.vtkPolyDataNormals()
    normal_filter.SetInputConnection(cleaner.GetOutputPort())
    normal_filter.ComputeCellNormalsOn()
    normal_filter.ComputePointNormalsOff()
    normal_filter.SplittingOff()
    normal_filter.ConsistencyOn()
    normal_filter.AutoOrientNormalsOn()
    normal_filter.Update()

    cleaned_with_normals = normal_filter.GetOutput()
    return cleaned_with_normals


def poisson_reconstruction(self, points):
    """Reconstruct a surface mesh from a set of points.

    The points are wrapped in a ``vtkPolyData``, run through
    ``vtkSurfaceReconstructionFilter``, contoured at value 0.0 and finally
    passed through ``vtkReverseSense`` with both cell and normal reversal on.

    NOTE(review): the function name says "Poisson", but the filter used here
    is ``vtkSurfaceReconstructionFilter`` - confirm the intended algorithm.

    Args:
        points: Object accepted by ``vtkPolyData.SetPoints``.

    Returns:
        The output of the ``vtkReverseSense`` filter.
    """
    # Wrap the raw points so they can enter the pipeline.
    cloud = vtk.vtkPolyData()
    cloud.SetPoints(points)

    # Points -> reconstruction filter output.
    reconstructor = vtk.vtkSurfaceReconstructionFilter()
    reconstructor.SetInputData(cloud)
    reconstructor.Update()

    # Extract the surface as the contour at value 0.0.
    contour = vtk.vtkContourFilter()
    contour.SetInputConnection(reconstructor.GetOutputPort())
    contour.SetValue(0, 0.0)
    contour.Update()

    # Flip cell ordering and normals of the extracted surface.
    flipper = vtk.vtkReverseSense()
    flipper.SetInputConnection(contour.GetOutputPort())
    flipper.ReverseNormalsOn()
    flipper.ReverseCellsOn()
    flipper.Update()

    surface = flipper.GetOutput()
    return surface


def create_simplified_outline(self, polydata):
    """Build a feature-edge filter for ``polydata``.

    Boundary edges and feature edges are extracted; manifold and non-manifold
    edges are switched off. No cleaning or smoothing is applied.

    Args:
        polydata: Data object handed to ``vtkFeatureEdges.SetInputData``.

    Returns:
        The already-updated ``vtkFeatureEdges`` filter itself (not its output
        data), so callers can connect to it through ``GetOutputPort()``.
    """
    feature_edges = vtk.vtkFeatureEdges()
    feature_edges.SetInputData(polydata)
    feature_edges.BoundaryEdgesOn()
    feature_edges.FeatureEdgesOn()
    feature_edges.ManifoldEdgesOff()
    feature_edges.NonManifoldEdgesOff()
    feature_edges.Update()

    return feature_edges


def render_from_points_direct_with_faces(self, vertices, faces):
    """Render an indexed triangle mesh together with its feature edges.

    Args:
        vertices: Sequence of points; each row is passed to
            ``vtkPoints.InsertNextPoint``.
        faces: Sequence of rows whose first three entries are the point ids of
            one triangle.

    Side effects:
        * ``self.cell_normals`` is set to the cell normals of the mesh,
          converted with ``vtk_to_numpy``.
        * A white mesh actor (edges visible) and a red feature-edge actor are
          added to ``self.renderer``.
        * The render window of ``self.vtk_widget`` is re-rendered.
    """
    points = vtk.vtkPoints()
    for vertex in vertices:
        points.InsertNextPoint(vertex)

    # Create a vtkCellArray to store the triangles
    triangles = vtk.vtkCellArray()
    for face in faces:
        triangle = vtk.vtkTriangle()
        corner_ids = triangle.GetPointIds()
        for corner in range(3):
            # int() hands VTK a plain Python integer whatever the dtype of the
            # index array is (e.g. indices loaded as floats).
            corner_ids.SetId(corner, int(face[corner]))
        triangles.InsertNextCell(triangle)

    # Create a polydata object
    polydata = vtk.vtkPolyData()
    polydata.SetPoints(points)
    polydata.SetPolys(triangles)

    # Calculate normals; only the cell normals are kept, on the instance.
    normal_generator = vtk.vtkPolyDataNormals()
    normal_generator.SetInputData(polydata)
    normal_generator.ComputePointNormalsOn()
    normal_generator.ComputeCellNormalsOn()
    normal_generator.Update()

    self.cell_normals = vtk_to_numpy(
        normal_generator.GetOutput().GetCellData().GetNormals()
    )

    # Create a mapper and actor for the mesh itself
    mapper = vtk.vtkPolyDataMapper()
    mapper.SetInputData(polydata)

    actor = vtk.vtkActor()
    actor.SetMapper(mapper)
    actor.GetProperty().SetColor(1, 1, 1)  # Set color (white in this case)
    actor.GetProperty().EdgeVisibilityOn()  # Show edges
    actor.GetProperty().SetLineWidth(2)  # Set line width

    # The helper returns the (already updated) filter, so it is connected
    # through its output port.
    feature_edges = self.create_simplified_outline(polydata)

    # Create a mapper for the feature edges
    edge_mapper = vtk.vtkPolyDataMapper()
    edge_mapper.SetInputConnection(feature_edges.GetOutputPort())

    # Create an actor for the feature edges
    edge_actor = vtk.vtkActor()
    edge_actor.SetMapper(edge_mapper)
    edge_actor.GetProperty().SetColor(1, 0, 0)  # Set color (red in this case)
    edge_actor.GetProperty().SetLineWidth(2)  # Set line width

    # Mesh first, then the edge overlay
    self.renderer.AddActor(actor)
    self.renderer.AddActor(edge_actor)

    # Force an update of the pipeline
    mapper.Update()
    self.vtk_widget.GetRenderWindow().Render()


def render_from_points_direct(self, points):
    """Reconstruct a surface from ``points`` and render it.

    Render method for the SDF mesh (output).

    Args:
        points: Sized iterable of points; each item is passed to
            ``vtkPoints.InsertNextPoint``.

    Side effects:
        * A white actor (edges visible) for the reconstructed mesh is added to
          ``self.renderer``.
        * The render window of ``self.vtk_widget`` is re-rendered.
        * Point and cell counts are printed to stdout.
    """
    # Create a vtkPoints object and store the points in it
    vtk_points = vtk.vtkPoints()
    for point in points:
        vtk_points.InsertNextPoint(point)

    # Surface reconstruction, contouring and normal reversal are shared with
    # poisson_reconstruction instead of being duplicated here.
    reconstructed_mesh = self.poisson_reconstruction(vtk_points)

    # Create a mapper and actor for the reconstructed mesh
    mapper = vtk.vtkPolyDataMapper()
    mapper.SetInputData(reconstructed_mesh)

    actor = vtk.vtkActor()
    actor.SetMapper(mapper)
    actor.GetProperty().SetColor(1, 1, 1)  # Set color (white in this case)
    actor.GetProperty().EdgeVisibilityOn()  # Show edges
    actor.GetProperty().SetLineWidth(2)  # Set line width

    # Add the actor to the renderer
    self.renderer.AddActor(actor)

    self.vtk_widget.GetRenderWindow().Render()

    # Print statistics
    print(f"Original points: {len(points)}")
    print(
        f"Reconstructed mesh: {reconstructed_mesh.GetNumberOfPoints()} points, {reconstructed_mesh.GetNumberOfCells()} cells")