import asyncio
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

from pyscript import display
from js import document


# Number of voter grid cells along each axis; the model places one voter at
# the centre of each cell, i.e. GRID_VOTERS**2 voters in total.
GRID_VOTERS = 100
# Policy-space bounds as (xmin, xmax, ymin, ymax). A tuple, so the shared
# constant cannot be mutated by accident.
DOMAIN = (-1.0, 1.0, -1.0, 1.0)
xmin, xmax, ymin, ymax = DOMAIN


def get_float(id_):
    """Return the ``value`` of the DOM element with id ``id_`` parsed as a float.

    Raises:
        ValueError: if the value is not a valid number. The message names the
            field, so the error shown in the UI tells the user what to fix.
    """
    raw = document.getElementById(id_).value
    try:
        return float(raw)
    except ValueError:
        raise ValueError(f"Field '{id_}' must be a number, got {raw!r}.") from None


def get_int(id_):
    """Return the ``value`` of the DOM element with id ``id_`` as an int.

    The value is parsed as a float first (so inputs like ``"3.0"`` are
    accepted) and then truncated toward zero by ``int``.

    Raises:
        ValueError: if the value is not a finite number. ``int(float("inf"))``
            would otherwise raise ``OverflowError`` and ``"nan"`` an unhelpful
            message; both are reported with the field name instead.
    """
    raw = document.getElementById(id_).value
    try:
        return int(float(raw))
    except (ValueError, OverflowError):
        raise ValueError(
            f"Field '{id_}' must be a finite number, got {raw!r}."
        ) from None


def get_str(id_):
    """Return the ``value`` of the DOM element with id ``id_`` as a ``str``."""
    element = document.getElementById(id_)
    return str(element.value)


def set_status(msg):
    """Replace the text of the ``#status`` element with ``msg``."""
    status_el = document.getElementById("status")
    status_el.innerText = msg


def set_results(text):
    """Replace the text of the ``#results`` element with ``text``."""
    results_el = document.getElementById("results")
    results_el.innerText = text


def run_model():
    """Read the UI parameters, simulate party competition and plot the result.

    Reads the DOM inputs ``tau``, ``m``, ``alpha``, ``n_parties``,
    ``n_steps``, ``dynamics``, ``distribution``, ``t_mc``, ``a_xy`` and
    ``alpha_xy``. Voters sit at the centres of a GRID_VOTERS x GRID_VOTERS
    grid over DOMAIN, weighted by the chosen (normalised) density.

    Returns:
        ``(fig, final_shares, party_pos, polarisation, stop_step)``: the
        matplotlib figure, each party's final vote share, the final (P, 2)
        party positions, the mean Euclidean distance of the parties from the
        origin, and the step at which greedy dynamics converged (``None`` if
        it did not converge or for metropolis dynamics).

    Raises:
        ValueError: for an unknown distribution or dynamics, ``|m| >= 1`` with
            ``"two_peak"``, ``n_parties`` outside ``1..GRID_VOTERS**2``, a
            negative/NaN ``t_mc`` with metropolis dynamics, or a density that
            is non-finite or has no positive mass on the grid.
    """
    tau = get_float("tau")
    m = get_float("m")
    alpha = get_float("alpha")
    n_parties = get_int("n_parties")
    n_steps = get_int("n_steps")
    dynamics = get_str("dynamics")
    distribution = get_str("distribution")
    t_mc = get_float("t_mc")
    a_xy = get_float("a_xy")
    alpha_xy = get_float("alpha_xy")

    W = _make_density(distribution, m, a_xy, alpha_xy)

    # Validate everything up front, before the expensive simulation starts.
    if dynamics not in ("greedy", "metropolis"):
        raise ValueError("dynamics must be 'greedy' or 'metropolis'")
    n_voters = GRID_VOTERS * GRID_VOTERS
    if not 1 <= n_parties <= n_voters:
        raise ValueError(f"Number of parties must be between 1 and {n_voters}.")
    if dynamics == "metropolis" and not t_mc >= 0:
        # A negative temperature would accept every worsening move; NaN is not
        # a temperature at all. t_mc == 0 means "never accept worse moves".
        raise ValueError("Metropolis temperature t_mc must be >= 0.")

    # Voters live at cell centres; the edges are kept for the region map.
    x_edges = np.linspace(xmin, xmax, GRID_VOTERS + 1)
    y_edges = np.linspace(ymin, ymax, GRID_VOTERS + 1)
    xv = 0.5 * (x_edges[:-1] + x_edges[1:])
    yv = 0.5 * (y_edges[:-1] + y_edges[1:])
    XV, YV = np.meshgrid(xv, yv)

    W_grid = np.asarray(W(XV, YV), dtype=float)
    if not np.all(np.isfinite(W_grid)):
        raise ValueError("Voter density is not finite everywhere on the grid.")
    # "covariant" can go negative when |a_xy| is large enough; treat such
    # cells as holding no voters.
    W_grid = np.clip(W_grid, 0.0, None)
    total_mass = W_grid.sum()
    if total_mass <= 0:
        raise ValueError("Voter density has non-positive total mass on the grid.")
    W_grid = W_grid / total_mass

    voter_points = np.stack([XV.ravel(), YV.ravel()], axis=1)
    voter_weights = W_grid.ravel()

    traj, party_pos, stop_step = _simulate(
        voter_points,
        voter_weights,
        n_parties=n_parties,
        n_steps=n_steps,
        dynamics=dynamics,
        tau=tau,
        alpha=alpha,
        T_mc=t_mc,
        rng=np.random.default_rng(),
    )

    final_shares, final_owners = _vote_shares(
        voter_points, voter_weights, party_pos, tau
    )
    polarisation = np.mean(np.sqrt(np.sum(party_pos**2, axis=1)))

    fig = _plot_results(
        x_edges, y_edges, xv, yv, W_grid, final_owners, traj, party_pos
    )
    return fig, final_shares, party_pos, polarisation, stop_step


def _make_density(distribution, m, a_xy, alpha_xy):
    """Build the unnormalised voter density ``W(x, y)`` named by ``distribution``.

    Args:
        distribution: ``"two_peak"``, ``"covariant"`` or ``"flat"``.
        m: for ``"two_peak"``, the peaks are at ``(m, m)`` and ``(-m, -m)``,
            each with isotropic variance ``(1 - m**2) / 9``.
        a_xy: strength of the ``x * y`` term for ``"covariant"``.
        alpha_xy: decay rate of that term for ``"covariant"``.

    Returns:
        A vectorised callable ``W(x, y)`` returning an array shaped like ``x``.

    Raises:
        ValueError: for an unknown distribution, or ``|m| >= 1`` with
            ``"two_peak"`` (the variance would be non-positive).
    """
    if distribution == "two_peak":
        sigma2 = (1 - m**2) / 9.0
        if sigma2 <= 0:
            raise ValueError("For two_peak distribution, need |m| < 1.")
        cov = np.array([[sigma2, 0.0],
                        [0.0,    sigma2]])
        rv1 = multivariate_normal(mean=np.array([m, m]), cov=cov)
        rv2 = multivariate_normal(mean=np.array([-m, -m]), cov=cov)

        def W(x, y):
            pos = np.dstack((x, y))
            return 0.5 * rv1.pdf(pos) + 0.5 * rv2.pdf(pos)

    elif distribution == "covariant":
        def W(x, y):
            q = x**2 + 2 * y**2
            base = (1.0 / (np.sqrt(2.0) * np.pi)) * np.exp(-q)
            modifier = 1.0 + a_xy * x * y * np.exp(-alpha_xy * q)
            return base * modifier

    elif distribution == "flat":
        def W(x, y):
            return np.ones_like(x, dtype=float)

    else:
        raise ValueError("distribution must be 'two_peak', 'covariant' or 'flat'")

    return W


def _vote_shares(voter_points, voter_weights, party_pos, tau):
    """Assign every voter to its nearest party and return turnout-weighted shares.

    Args:
        voter_points: (V, D) voter coordinates.
        voter_weights: (V,) non-negative voter masses.
        party_pos: (P, D) party coordinates.
        tau: turnout exponent; a voter at distance ``d`` from its party
            contributes ``weight * (1 + d) ** (-2 * tau)``.

    Returns:
        ``(shares, owners)``: ``shares`` is (P,) and sums to 1 unless no vote
        was cast (then it is all zeros); ``owners`` is (V,) with each voter's
        nearest-party index (ties go to the lowest index, as ``argmin`` does).
    """
    d2 = np.sum((voter_points[:, None, :] - party_pos[None, :, :]) ** 2, axis=2)
    owners = np.argmin(d2, axis=1)
    turnout = (1.0 + np.sqrt(d2.min(axis=1))) ** (-2 * tau)
    raw_votes = np.bincount(
        owners, weights=voter_weights * turnout, minlength=len(party_pos)
    )
    total = raw_votes.sum()
    shares = raw_votes / total if total > 0 else raw_votes
    return shares, owners


def _centroids(voter_points, voter_weights, party_pos, owners):
    """Return the voter-weighted centroid of each party's region, shape (P, D).

    A party whose region is empty, or carries zero total weight (possible
    where the density was clipped to 0 or underflowed), keeps its own position
    as its "centroid". This avoids dividing by zero and makes the distance
    penalty vanish, just as for an empty region.
    """
    n_parties = len(party_pos)
    mass = np.bincount(owners, weights=voter_weights, minlength=n_parties)
    centroids = np.array(party_pos, dtype=float)  # copy; fallback values
    has_mass = mass > 0
    for k in range(voter_points.shape[1]):
        moment = np.bincount(
            owners, weights=voter_weights * voter_points[:, k], minlength=n_parties
        )
        centroids[has_mass, k] = moment[has_mass] / mass[has_mass]
    return centroids


def _simulate(voter_points, voter_weights, *, n_parties, n_steps, dynamics,
              tau, alpha, T_mc, rng, step_size=0.05):
    """Evolve party positions under ``"greedy"`` or ``"metropolis"`` dynamics.

    Parties start on distinct voter grid points drawn uniformly by ``rng``.
    Each step visits the parties in a fresh random order. A party may move by
    ``step_size`` to one of its 8 neighbouring positions that lie inside
    DOMAIN. Its utility is ``alpha * share - (1 - alpha) * |pos - centroid|**2``.

    * greedy: take the best strictly improving move, if any. The run stops
      after the first step in which no party moved.
    * metropolis: propose one neighbour uniformly at random. Accept it if it
      does not lower the utility, otherwise with probability
      ``exp(delta / T_mc)`` (never when ``T_mc == 0``).

    Shares and centroids are kept in sync with the current positions, so each
    party compares candidates against its actual current utility, including
    the effect of moves made earlier in the same step.

    Returns:
        ``(traj, party_pos, stop_step)``: ``traj`` stacks the positions before
        the first step and after every completed step, shape
        (steps + 1, P, D). ``stop_step`` is the greedy convergence step, or
        ``None``.

    Raises:
        ValueError: if ``dynamics`` is unknown and at least one party is visited.
    """
    moves = step_size * np.array([
        [1, 0], [-1, 0], [0, 1], [0, -1],
        [1, 1], [1, -1], [-1, 1], [-1, -1],
    ], dtype=float)

    def evaluate(positions):
        shares, owners = _vote_shares(voter_points, voter_weights, positions, tau)
        return shares, _centroids(voter_points, voter_weights, positions, owners)

    def utility(i, pos_i, shares, centroids):
        return alpha * shares[i] - (1 - alpha) * np.sum((pos_i - centroids[i]) ** 2)

    def inside(p):
        return xmin <= p[0] <= xmax and ymin <= p[1] <= ymax

    start = rng.choice(len(voter_points), size=n_parties, replace=False)
    party_pos = voter_points[start].copy()
    traj = [party_pos.copy()]
    shares, centroids = evaluate(party_pos)
    stop_step = None

    for step in range(n_steps):
        any_move = False

        for i in rng.permutation(n_parties):
            current_util = utility(i, party_pos[i], shares, centroids)
            candidates = [c for c in party_pos[i] + moves if inside(c)]

            if dynamics == "greedy":
                best = None
                best_util = current_util
                for cand in candidates:
                    trial = party_pos.copy()
                    trial[i] = cand
                    trial_shares, trial_centroids = evaluate(trial)
                    trial_util = utility(i, cand, trial_shares, trial_centroids)
                    if trial_util > best_util:
                        best_util = trial_util
                        best = (cand, trial_shares, trial_centroids)
                if best is None:
                    continue
                # Reuse the state already computed for the winning move.
                cand, shares, centroids = best

            elif dynamics == "metropolis":
                if not candidates:
                    continue
                cand = candidates[rng.integers(len(candidates))]
                trial = party_pos.copy()
                trial[i] = cand
                trial_shares, trial_centroids = evaluate(trial)
                delta = utility(i, cand, trial_shares, trial_centroids) - current_util
                if delta < 0 and (T_mc <= 0 or rng.random() >= np.exp(delta / T_mc)):
                    continue
                shares, centroids = trial_shares, trial_centroids

            else:
                raise ValueError("dynamics must be 'greedy' or 'metropolis'")

            party_pos[i] = cand
            any_move = True

        traj.append(party_pos.copy())

        if dynamics == "greedy" and not any_move:
            stop_step = step
            break

    return np.stack(traj, axis=0), party_pos, stop_step


def _plot_results(x_edges, y_edges, xv, yv, W_grid, owners, traj, party_pos):
    """Draw region map, density contours, trajectories and marginals; return the Figure.

    ``W_grid`` is indexed ``[y, x]`` (meshgrid "xy" layout), and ``owners``
    is its flattened-row-major per-voter party index.
    """
    n_parties = len(party_pos)
    XV, YV = np.meshgrid(xv, yv)
    X_edges, Y_edges = np.meshgrid(x_edges, y_edges)
    owners_grid = owners.reshape(len(yv), len(xv))

    fig = plt.figure(figsize=(7.0, 6.0))
    ax_main = fig.add_axes([0.10, 0.12, 0.62, 0.62])

    # Each grid cell is coloured by the party that owns it.
    ax_main.pcolormesh(
        X_edges,
        Y_edges,
        owners_grid,
        cmap="tab20",
        alpha=0.30,
        shading="flat",
        vmin=0,
        vmax=max(n_parties - 1, 1),
    )

    # Skip contours for a constant density such as "flat".
    if not np.allclose(W_grid, W_grid[0, 0]):
        ax_main.contour(
            XV, YV, W_grid,
            levels=12,
            colors="k",
            linewidths=0.6,
            alpha=0.5,
        )

    colors = plt.cm.tab10(np.linspace(0, 1, n_parties))

    for i in range(n_parties):
        xi = traj[:, i, 0]
        yi = traj[:, i, 1]
        ax_main.plot(xi, yi, "-o", color=colors[i], markersize=1.5, linewidth=1.0)
        ax_main.scatter(
            xi[-1], yi[-1],
            s=80,
            color=colors[i],
            edgecolor="k",
            linewidth=0.8,
            zorder=5,
        )

    ax_main.set_xlim(xmin, xmax)
    ax_main.set_ylim(ymin, ymax)
    ax_main.set_aspect("equal", adjustable="box")
    ax_main.set_xticks([])
    ax_main.set_yticks([])

    # Draw once so the equal-aspect box is applied before its position is
    # read; the marginal axes are laid out against that position.
    fig.canvas.draw()
    pos = ax_main.get_position()

    gap, top_h, right_w = 0.02, 0.10, 0.10
    ax_top = fig.add_axes([pos.x0, pos.y1 + gap, pos.width, top_h])
    ax_right = fig.add_axes([pos.x1 + gap, pos.y0, right_w, pos.height])

    # Marginal densities: sum over y for the x marginal and vice versa.
    ax_top.plot(xv, W_grid.sum(axis=0), color="black", linewidth=1.2)
    ax_right.plot(W_grid.sum(axis=1), yv, color="black", linewidth=1.2)
    for i in range(n_parties):
        ax_top.axvline(
            party_pos[i, 0], color=colors[i], linestyle="--", alpha=0.9, linewidth=0.9
        )
        ax_right.axhline(
            party_pos[i, 1], color=colors[i], linestyle="--", alpha=0.9, linewidth=0.9
        )

    ax_top.set_xlim(xmin, xmax)
    ax_right.set_ylim(ymin, ymax)
    for ax in (ax_top, ax_right):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    return fig


async def on_run(event=None):
    """Run the model and render its figure into ``#plot`` and a summary into ``#results``.

    Any exception from the run is reported on the page (status "Error.") and
    its traceback is printed, rather than propagating out of the handler.
    ``event`` is accepted for use as a DOM event callback and is not used.
    """
    set_status("Running...")
    set_results("Running simulation...")
    # Yield briefly before the synchronous simulation, presumably so the
    # browser can repaint the status text first.
    await asyncio.sleep(0.05)

    fig = None
    try:
        document.getElementById("plot").innerHTML = ""

        fig, final_shares, party_pos, polarisation, stop_step = run_model()

        display(fig, target="plot")

        lines = ["Final vote shares:"]
        lines += [f"  Party {i}: {s:.4f}" for i, s in enumerate(final_shares)]

        lines += ["", "Final party positions:"]
        lines += [
            f"  Party {i}: x = {x: .4f}, y = {y: .4f}"
            for i, (x, y) in enumerate(party_pos)
        ]

        lines += ["", f"Mean separation from centre: {polarisation:.4f}"]

        if stop_step is not None:
            lines += ["", f"Greedy dynamics stopped: no party moved at step {stop_step}"]

        set_results("\n".join(lines))
        set_status("Done.")

    except Exception as e:
        # Top-level UI boundary: show the error to the user, and keep the
        # traceback for debugging instead of discarding it.
        import traceback
        traceback.print_exc()
        set_status("Error.")
        set_results(f"Error: {e}")

    finally:
        # pyplot keeps every figure alive until it is closed; release it once
        # rendered so repeated runs do not accumulate figures.
        if fig is not None:
            plt.close(fig)
