import numpy as np
from skimage import measure
import multiprocessing
from functools import partial
from multiprocessing.pool import ThreadPool
import itertools
import time


def _cartesian_product(*arrays):
    """Return every combination of the given 1-D sequences as rows.

    The result has one column per input sequence and one row per
    combination; the last sequence varies fastest. The dtype is the common
    NumPy result type of all inputs.
    """
    n_axes = len(arrays)
    common_dtype = np.result_type(*arrays)
    grid_shape = [len(axis) for axis in arrays]
    grid_shape.append(n_axes)
    grid = np.empty(grid_shape, dtype=common_dtype)
    # np.ix_ yields one broadcastable open-grid array per axis; assigning it
    # into a column slice broadcasts the axis values across the whole grid.
    for column, axis_values in enumerate(np.ix_(*arrays)):
        grid[..., column] = axis_values
    return grid.reshape(-1, n_axes)


class VESTA:
    """Mesh a signed-distance function by sampling it on a regular grid.

    The callable ``sdf`` is invoked with an ``(N, 3)`` array of points and
    must return ``N`` values (they are reshaped onto the sampling grid). The
    isosurface at ``threshold`` is extracted with
    ``skimage.measure.marching_cubes``.

    Args:
        sdf: callable evaluated on ``(N, 3)`` point arrays.
        bounds: ``((x0, y0, z0), (x1, y1, z1))`` box to sample, or ``None``
            to estimate it from ``sdf`` on the first ``generate_mesh`` call.
        resolution: number of samples per axis; ``generate_mesh`` requires
            at least 2.
        threshold: isosurface level handed to marching cubes.
        workers: thread-pool size; a falsy value falls back to
            ``multiprocessing.cpu_count()``.
    """

    def __init__(self, sdf, bounds=None, resolution=64, threshold=0.0, workers=None):
        self.sdf = sdf
        self.bounds = bounds
        self.resolution = resolution
        self.threshold = threshold
        self.workers = workers or multiprocessing.cpu_count()

    @staticmethod
    def _empty_mesh():
        """Return an empty ``(verts, faces)`` pair, both shaped ``(0, 3)``."""
        return np.empty((0, 3)), np.empty((0, 3), dtype=int)

    def _estimate_bounds(self):
        """Estimate a bounding box for the surface by iterative refinement.

        Starting from a +/-1e9 box, the SDF is sampled on a 16^3 grid and the
        box is shrunk to the samples whose absolute value is within half a
        grid-cell diagonal. This repeats (at most 32 times) until the cell
        size stops changing.

        Returns:
            ``((x0, y0, z0), (x1, y1, z1))`` corner coordinates.

        Raises:
            ValueError: if no sample ever falls within the threshold, i.e.
                the box could not be refined at all.
        """
        s = 16
        x0 = y0 = z0 = -1e9
        x1 = y1 = z1 = 1e9
        prev = None
        # Tracks whether the box was ever refined; ``prev`` cannot serve as
        # that signal because it is assigned on the very first iteration.
        found = False
        for _ in range(32):
            X = np.linspace(x0, x1, s)
            Y = np.linspace(y0, y1, s)
            Z = np.linspace(z0, z1, s)
            d = np.array([X[1] - X[0], Y[1] - Y[0], Z[1] - Z[0]])
            threshold = np.linalg.norm(d) / 2
            if threshold == prev:
                # Cell size unchanged since the last pass: converged.
                break
            prev = threshold
            P = _cartesian_product(X, Y, Z)
            volume = self.sdf(P).reshape((len(X), len(Y), len(Z)))
            where = np.argwhere(np.abs(volume) <= threshold)
            if where.size == 0:
                continue
            found = True
            lo = np.array([x0, y0, z0])
            # Both corners are computed from the *old* lower corner, so the
            # upper corner must be derived before the lower one is replaced.
            x1, y1, z1 = lo + where.max(axis=0) * d + d / 2
            x0, y0, z0 = lo + where.min(axis=0) * d - d / 2
        if not found:
            raise ValueError("Failed to estimate bounds. No points found within any threshold.")
        return ((x0, y0, z0), (x1, y1, z1))

    def _vesta_worker(self, chunk):
        """Sample the SDF inside one box and extract its isosurface.

        Args:
            chunk: ``(x0, x1, y0, y1, z0, z1)`` extents of the box.

        Returns:
            ``(verts, faces)`` with vertices mapped into the box's
            coordinates, or an empty ``(0, 3)`` pair when the box contains
            no surface at ``self.threshold``.
        """
        x0, x1, y0, y1, z0, z1 = chunk
        n = self.resolution
        X = np.linspace(x0, x1, n)
        Y = np.linspace(y0, y1, n)
        Z = np.linspace(z0, z1, n)
        P = _cartesian_product(X, Y, Z)
        V = self.sdf(P).reshape((n, n, n))

        # scikit-image raises ValueError (not RuntimeError) when the level is
        # outside the sampled value range; treat that as "no surface here".
        if self.threshold < V.min() or self.threshold > V.max():
            return self._empty_mesh()

        try:
            verts, faces, _, _ = measure.marching_cubes(V, self.threshold)
        except RuntimeError:
            # Raised by marching_cubes when no surface is found at the level.
            return self._empty_mesh()

        # marching_cubes reports vertices in grid-index units (0 .. n-1);
        # normalise to [0, 1] and map onto the box extents.
        verts = verts / (n - 1)
        verts[:, 0] = verts[:, 0] * (x1 - x0) + x0
        verts[:, 1] = verts[:, 1] * (y1 - y0) + y0
        verts[:, 2] = verts[:, 2] * (z1 - z0) + z0

        return verts, faces

    def _merge_meshes(self, results):
        """Concatenate per-chunk meshes into one ``(verts, faces)`` pair.

        Face indices of each chunk are shifted by the number of vertices
        that precede it. Chunks with no vertices or no faces are skipped.

        Returns:
            Stacked ``(verts, faces)``, or an empty ``(0, 3)`` pair when
            every chunk was empty.
        """
        all_verts = []
        all_faces = []
        offset = 0
        for verts, faces in results:
            if len(verts) > 0 and len(faces) > 0:
                all_verts.append(verts)
                all_faces.append(faces + offset)
                offset += len(verts)
        if not all_verts:
            return self._empty_mesh()
        return np.vstack(all_verts), np.vstack(all_faces)

    def generate_mesh(self):
        """Generate the mesh for the configured SDF.

        If ``self.bounds`` is ``None`` it is estimated and stored on the
        instance before meshing.

        Returns:
            ``(verts, faces)`` arrays; both are ``(0, 3)`` when no surface
            was found.

        Raises:
            ValueError: if ``self.resolution`` is below 2, or if the bounds
                have to be estimated and estimation fails.
        """
        if self.resolution < 2:
            # A single sample per axis cannot form a grid cell and would
            # divide by zero when rescaling vertices.
            raise ValueError("resolution must be at least 2")

        if self.bounds is None:
            self.bounds = self._estimate_bounds()

        (x0, y0, z0), (x1, y1, z1) = self.bounds
        chunks = [
            (x0, x1, y0, y1, z0, z1),
        ]

        with ThreadPool(self.workers) as pool:
            results = pool.map(self._vesta_worker, chunks)

        return self._merge_meshes(results)


def generate_mesh_from_sdf(sdf, bounds=None, resolution=64, threshold=0.0, workers=None):
    """Convenience wrapper: build a ``VESTA`` mesher and run it once.

    All arguments are forwarded unchanged to ``VESTA``; the return value is
    whatever ``VESTA.generate_mesh`` returns.
    """
    mesher = VESTA(
        sdf,
        bounds=bounds,
        resolution=resolution,
        threshold=threshold,
        workers=workers,
    )
    return mesher.generate_mesh()


# Helper function to save the mesh as an STL file
def save_mesh_as_stl(vertices, faces, filename):
    """Write a triangle mesh to ``filename`` as STL via numpy-stl.

    Args:
        vertices: array-like of vertex coordinates; each row is looked up by
            the indices stored in ``faces``.
        faces: array-like of integer vertex indices, one triangle per row.
            Only the first three columns are used. An empty array produces
            an STL with no triangles.
        filename: destination passed to ``Mesh.save``.
    """
    # Imported here, so numpy-stl is only required when this function runs.
    from stl import mesh

    vertices = np.asarray(vertices)
    faces = np.asarray(faces)

    stl_mesh = mesh.Mesh(np.zeros(faces.shape[0], dtype=mesh.Mesh.dtype))
    if faces.shape[0]:
        # Fancy indexing gathers all three corners of every triangle at once
        # (shape: n_faces x 3 corners x coords), replacing a per-corner loop.
        stl_mesh.vectors[:] = vertices[faces[:, :3]]

    stl_mesh.save(filename)