import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d

import pylbm
import numpy as np


class Arrow3D(FancyArrowPatch):
    """Arrow patch whose two end points are given in 3D data coordinates.

    The end points are projected onto the 2D canvas with the projection
    matrix of the 3D axes the arrow has been added to, and the arrow is then
    drawn as an ordinary ``FancyArrowPatch``.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        """Store the 3D end points; everything else goes to ``FancyArrowPatch``.

        Parameters
        ----------
        xs, ys, zs : sequence of two numbers
            Coordinates of the tail (index 0) and of the head (index 1).
        *args, **kwargs
            Forwarded unchanged to ``FancyArrowPatch``.
        """
        # The 2D positions are placeholders: they are overwritten every time
        # the arrow is projected.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def do_3d_projection(self, renderer=None):
        """Project the end points to 2D and return the smallest projected depth.

        Newer Matplotlib releases call this hook on every patch that lives in
        a 3D axes (with or without a renderer argument, depending on the
        release) and use the returned value to order the artists.  The
        projection matrix is read from the axes because the renderer no
        longer carries it in those releases.
        """
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return min(zs)

    def draw(self, renderer):
        """Refresh the 2D end points, then draw the arrow."""
        # Project here as well, so the arrow is placed correctly when the
        # hosting axes does not call do_3d_projection() itself.
        self.do_3d_projection(renderer)
        FancyArrowPatch.draw(self, renderer)


def Velocities_1D(n):
    """Draw the numbering of the ``2 * n + 1`` first velocities in 1D.

    Every velocity ``pylbm.stencil.Velocity(dim=1, num=k)`` is represented by
    its number, written at abscissa ``vx`` on a dotted horizontal line, in
    the matplotlib figure number 1.

    Parameters
    ----------
    n : int
        ``2 * n + 1`` velocities are drawn; must be non-negative.

    Raises
    ------
    ValueError
        If ``n`` is negative (there would be no velocity to draw).
    """
    if n < 0:
        raise ValueError("n must be non-negative, got {0!r}".format(n))

    dim = 1
    e = 0.2  # margin used to place the title, the axis arrow and the limits
    ymin, ymax = -1, 1
    # Loop-invariant grey levels, defined once so that the drawing code after
    # the loop never depends on the loop body having been executed.
    text_color = [0.0] * 3
    line_color = [0.5] * 3

    fig = plt.figure(dim, figsize=(8, 4), facecolor="white")
    fig.clf()

    velocities = [
        pylbm.stencil.Velocity(dim=dim, num=k) for k in range((2 * n + 1) ** dim)
    ]
    abscissas = [v.vx for v in velocities]
    xmin, xmax = min(abscissas), max(abscissas)

    # One label per velocity, centred on its position.
    for v in velocities:
        plt.text(
            v.vx,
            0.0,
            str(v.num),
            color=text_color,
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=15,
        )

    # Dotted line joining the extreme velocities.
    plt.plot([xmin, xmax], [0, 0], ":", color=line_color)

    # Title above the drawing.
    plt.text(
        0.0,
        ymax + 2 * e,
        "Velocities numbering {0:1d}D".format(dim),
        fontsize=20,
        verticalalignment="center",
        horizontalalignment="center",
        color="b",
    )

    # Unit arrow and label showing the x direction, in the lower-left corner.
    plt.arrow(
        xmin - e, ymin - e, 1, 0, head_width=0.05 * dim, head_length=0.1, fc="b", ec="b"
    )
    plt.text(
        xmin - e + 0.5,
        ymin - 1.5 * e,
        "x",
        color="b",
        verticalalignment="center",
        horizontalalignment="center",
    )

    plt.axis("off")
    plt.xlim(xmin - 2 * e, xmax + 2 * e)
    plt.ylim(ymin - 2 * e, ymax + 2 * e)
    plt.draw()


def Velocities_2D(n):
    """Draw the numbering of the ``(2 * n + 1) ** 2`` first velocities in 2D.

    Every velocity ``pylbm.stencil.Velocity(dim=2, num=k)`` is represented by
    its number, written at position ``(vx, vy)`` on a dotted grid, in the
    matplotlib figure number 2.

    Parameters
    ----------
    n : int
        ``(2 * n + 1) ** 2`` velocities are drawn.
    """
    dim = 2
    margin = 0.5
    label_grey = [0.0] * 3
    grid_grey = [0.5] * 3
    blue = "b"

    figure = plt.figure(dim, figsize=(8, 8), facecolor="white")
    figure.clf()

    stencil = [
        pylbm.stencil.Velocity(dim=dim, num=num) for num in range((2 * n + 1) ** dim)
    ]

    # Bounding box of the velocities, seeded with the +/-1000 sentinels.
    all_vx = [vel.vx for vel in stencil]
    all_vy = [vel.vy for vel in stencil]
    xmin = min([1000] + all_vx)
    xmax = max([-1000] + all_vx)
    ymin = min([1000] + all_vy)
    ymax = max([-1000] + all_vy)

    # One label per velocity, centred on its position.
    for vel in stencil:
        plt.text(
            vel.vx,
            vel.vy,
            str(vel.num),
            color=label_grey,
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=15,
        )

    # Dotted grid: vertical lines first, then horizontal ones.
    for col in range(xmin, xmax + 1):
        plt.plot([col, col], [ymin, ymax], ":", color=grid_grey)
    for row in range(ymin, ymax + 1):
        plt.plot([xmin, xmax], [row, row], ":", color=grid_grey)

    # Title above the grid.
    plt.text(
        0.0,
        ymax + 2 * margin,
        "Velocities numbering {0:1d}D".format(dim),
        fontsize=20,
        verticalalignment="center",
        horizontalalignment="center",
        color=blue,
    )

    # Unit arrows for the x and y directions, from the lower-left corner.
    corner_x = xmin - margin
    corner_y = ymin - margin
    for dx, dy in ((1, 0), (0, 1)):
        plt.arrow(
            corner_x,
            corner_y,
            dx,
            dy,
            head_width=0.05 * dim,
            head_length=0.1,
            fc=blue,
            ec=blue,
        )

    # Names of the two directions, next to their arrows.
    axis_labels = (
        (xmin - margin + 0.5, ymin - 1.5 * margin, "x"),
        (xmin - 1.5 * margin, ymin - margin + 0.5, "y"),
    )
    for label_x, label_y, label in axis_labels:
        plt.text(
            label_x,
            label_y,
            label,
            color=blue,
            verticalalignment="center",
            horizontalalignment="center",
        )

    plt.axis("off")
    plt.xlim(xmin - 2 * margin, xmax + 2 * margin)
    plt.ylim(ymin - 2 * margin, ymax + 2 * margin)
    plt.draw()


def Velocities_3D(n):
    """Draw the numbering of the ``(2 * n + 1) ** 3`` first velocities in 3D.

    Every velocity ``pylbm.stencil.Velocity(dim=3, num=k)`` is represented by
    its number, written at position ``(vx, vy, vz)`` on a dotted grid, in the
    matplotlib figure number 3.  The labels, and one translucent plane per
    integer value of x, are coloured according to the x component.

    Parameters
    ----------
    n : int
        ``(2 * n + 1) ** 3`` velocities are drawn; must be non-negative.

    Raises
    ------
    ValueError
        If ``n`` is negative (there would be no velocity to draw).
    """
    if n < 0:
        raise ValueError("n must be non-negative, got {0!r}".format(n))

    dim = 3
    e = 0.5  # margin used to place the title, the axis arrows and the limits
    frame_color = "k"
    grid_color = [0.5] * 3

    fig = plt.figure(dim, figsize=(8, 8), facecolor="white")
    fig.clf()
    ax = fig.add_subplot(111, projection="3d")

    velocities = [
        pylbm.stencil.Velocity(dim=dim, num=k) for k in range((2 * n + 1) ** dim)
    ]
    xs = [v.vx for v in velocities]
    ys = [v.vy for v in velocities]
    zs = [v.vz for v in velocities]
    xmin, xmax = min(xs), max(xs)
    ymin, ymax = min(ys), max(ys)
    zmin, zmax = min(zs), max(zs)

    # The colour depends on x / x_extent rather than on x itself, so that the
    # RGB components stay inside [0, 1] whatever the size of the stencil.
    # When the x components span [-1, 1] this is 0.5 +/- 0.5 * x.
    x_extent = float(max(abs(xmin), abs(xmax), 1))

    def color_of(x):
        t = x / x_extent
        return [0.5 + 0.5 * t, 0.0, 0.5 - 0.5 * t]

    # One label per velocity, centred on its position.
    for v in velocities:
        ax.text(
            v.vx,
            v.vy,
            v.vz,
            str(v.num),
            color=color_of(v.vx),
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=15,
        )

    # Dotted grid: lines along z, then along y, then along x.
    for x in range(xmin, xmax + 1):
        for y in range(ymin, ymax + 1):
            ax.plot([x, x], [y, y], [zmin, zmax], ":", color=grid_color)
    for x in range(xmin, xmax + 1):
        for z in range(zmin, zmax + 1):
            ax.plot([x, x], [ymin, ymax], [z, z], ":", color=grid_color)
    for z in range(zmin, zmax + 1):
        for y in range(ymin, ymax + 1):
            ax.plot([xmin, xmax], [y, y], [z, z], ":", color=grid_color)

    # One translucent plane x = constant per integer x of the bounding box,
    # covering the whole (y, z) extent of the velocities.
    plane_y, plane_z = np.meshgrid([ymin, ymax], [zmin, zmax])
    plane_x = np.zeros(plane_y.shape)
    for x in range(xmin, xmax + 1):
        ax.plot_surface(
            plane_x + x,
            plane_y,
            plane_z,
            rstride=1,
            cstride=1,
            color=color_of(x),
            shade=False,
            alpha=0.2,
            antialiased=False,
            linewidth=0,
        )

    # Title above the grid.
    ax.text(
        0.0,
        0.0,
        zmax + 2 * e,
        "Velocities numbering {0:1d}D".format(dim),
        fontsize=20,
        verticalalignment="center",
        horizontalalignment="center",
        color=frame_color,
    )

    # Unit arrows for the x, y and z directions, sharing the same origin.
    origin = (xmax + e, ymin - e, zmin - e)
    for axis in range(dim):
        head = list(origin)
        head[axis] += 1
        # zip() pairs tail and head coordinate by coordinate:
        # (x0, x1), (y0, y1), (z0, z1).
        arrow = Arrow3D(
            *zip(origin, head),
            mutation_scale=20,
            lw=1,
            arrowstyle="-|>",
            color=frame_color
        )
        ax.add_artist(arrow)

    # Names of the three directions, next to their arrows.
    axis_labels = (
        (xmax + e + 0.8, ymin - 0.8 * e, zmin - e, "x"),
        (xmax + e, ymin - e + 0.8, zmin - 1.2 * e, "y"),
        (xmax + e, ymin - 1.2 * e, zmin - e + 0.8, "z"),
    )
    for label_x, label_y, label_z, label in axis_labels:
        ax.text(
            label_x,
            label_y,
            label_z,
            label,
            color=frame_color,
            verticalalignment="center",
            horizontalalignment="center",
        )

    ax.set_xlim(xmin - e, xmax + e)
    ax.set_ylim(ymin - e, ymax + e)
    ax.set_zlim(zmin - e, zmax + e)
    ax.azim = 34
    ax.elev = 20
    plt.axis("off")
    plt.draw()


def Velocities(dim, n):
    """Draw the velocity numbering for dimension ``dim`` and show the figure.

    Parameters
    ----------
    dim : int
        Spatial dimension; 1, 2 and 3 are handled, any other value only
        prints an error message.
    n : int
        Size parameter forwarded to the drawing function of that dimension.
    """
    drawers = (
        (1, Velocities_1D),
        (2, Velocities_2D),
        (3, Velocities_3D),
    )
    for handled_dim, draw in drawers:
        if dim == handled_dim:
            draw(n)
            break
    else:
        # No drawing function matched the requested dimension.
        print("error of dimension")
    plt.show()


# Render the velocity numbering for each (dimension, n) pair in turn.
for dim, n in ((1, 3), (2, 2), (3, 1)):
    Velocities(dim, n)
