import os

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.pyplot import *

from tc_python import *

"""
In this example a calibrated double ellipsoidal heat source is used to perform batch steady-state AM calculations
for all the experimental variations of power and scan speed in the single track experiments. 

A printability map is then produced using simple functions for the defects keyholing, balling and lack of fusion.
All defect functions are based on the melt pool dimensions (width, depth and length).
"""

# experimental data for SS316L from Jun Wei et al. Advances in Materials Science and Engineering Volume 2019
# https://doi.org/10.1155/2019/9451406
EXP_FILE = "316L_single_track_experiments.csv"
HEAT_SOURCE_NAME = "Double ellipsoidal - 316L - beam d 15um"  # calibrated heat source, loaded from the library
MATERIAL_NAME = "SS316L"  # material properties, loaded from the library
POWDER_THICKNESS = 10.0  # layer (powder) thickness, in micrometres
# The experiments are single tracks only; this hatch spacing is used solely to evaluate lack of fusion
# between neighbouring tracks in the printability map.
HATCH_DISTANCE = 10.0  # in micrometres
AMBIENT_TEMPERATURE = 353  # in K, as given in the paper; also used as the base plate temperature

# The publication below references several limits for the defects based on melt pool dimensions:
# Sheikh et al. "High-throughput Alloy and Process Design for Metal Additive Manufacturing"
# Condensed Matter Jan 13, 2023, https://arxiv.org/abs/2304.04149
# The values used in this script are the thresholds of the printability map: a (power, scan speed) point is
# flagged with a defect when the corresponding ratio is below the limit.
LACK_OF_FUSION_LIMIT = 1.5  # paper: eq. 15 when melt pool depth / powder thickness < 1.0. Here: D/t < limit
# lack of fusion between tracks: melt pool depth at half the hatch distance / powder thickness < limit
LACK_OF_FUSION_HATCH_LIMIT = 1.1
BALLING_LIMIT = 0.23  # balling: eq. 18 when w/L<0.43 or eq. 19 when L/w<0.26. Here: W/L < limit
# keyholing: eq. 20 when melt pool width / depth < 2.5 (generally a transition at 1.0 and keyhole at 1.5).
# Here: W/D < limit
KEYHOLE_LIMIT = 1.9


def get_depth_at_half_hatch_distance(result: SteadyStateResult, hatch_distance: float, liquidus: float):
    """
    Return the melt pool penetration depth at half the hatch distance from the track centre line.

    The steady-state mesh is cut by a plane of constant y located at ``0.5 * hatch_distance``, and the
    liquidus iso-contour is extracted on that cut. The lowest z of the contour is taken as the depth.

    :param result: steady-state AM result providing the pyvista mesh.
    :param hatch_distance: hatch distance in micrometres (converted to metres here).
    :param liquidus: liquidus temperature, used as the iso-value of the contour.
    :return: absolute depth in the mesh length unit (the caller scales it by 1e6 to micrometres),
             or 0.0 when the liquidus contour does not reach that y position.
    """
    mesh = result.get_pyvista_mesh()
    # Copy into a list: depending on the pyvista version `center` may be an immutable tuple.
    origin = list(mesh.center)
    origin[1] = 0.5 * hatch_distance * 1e-6  # micrometre -> metre
    cross_section = mesh.slice(normal="y", origin=origin)
    if cross_section.n_points == 0:
        # The plane lies outside the mesh, so the liquidus iso-contour cannot reach it.
        return 0.0
    contour = cross_section.contour(isosurfaces=np.array([liquidus]))
    if contour.n_points == 0:
        # The melt pool is narrower than half the hatch distance: no overlap with the next track.
        return 0.0
    # bounds are (xmin, xmax, ymin, ymax, zmin, zmax). zmin is taken as the depth.
    # This assumes the top surface lies at z = 0 with z pointing upwards -- TODO confirm.
    return abs(contour.bounds[4])


def plot_heat_source_parameters(df):
    """
    Plot the library heat source parameters against energy density.

    For every (power, scan speed) pair in ``df`` the library double ellipsoidal heat source is configured,
    and its absorptivity and shape parameters (ar, af, b, c) are read back. Each quantity is then drawn as
    a marker-only series versus energy density (power / scan speed, J/mm).

    :param df: DataFrame with the columns "power (W)" and "speed (mm/s)".
    """
    _, axis = plt.subplots(1, 1)

    # legend label -> accessor on the configured heat source (insertion order == call and plot order)
    readers = {
        "abs": lambda source: source.get_absorptivity(),
        "ar": lambda source: source.get_ar(),
        "af": lambda source: source.get_af(),
        "b": lambda source: source.get_b(),
        "c": lambda source: source.get_c(),
    }
    series = {label: [] for label in readers}
    energy_densities = []

    source = HeatSource.double_ellipsoidal_from_library(HEAT_SOURCE_NAME)
    for power, speed in zip(df["power (W)"], df["speed (mm/s)"]):
        energy_densities.append(power / speed)
        source.set_power(power)
        source.set_scanning_speed(speed / 1e3)  # mm/s -> m/s
        for label, read in readers.items():
            series[label].append(read(source))

    for label, values in series.items():
        axis.plot(energy_densities, values, label=label, marker="o", lw=0)

    axis.set_xlabel("Energy density (J/mm)")
    # NOTE(review): the label assumes ar/af/b/c are returned in micrometres; verify the units of the
    # tc_python getters (the scanning speed above is set in m/s, which suggests SI lengths).
    axis.set_ylabel("Size (\u03bcm) / absorptivity")
    axis.set_ylim(bottom=0.0)
    axis.legend()


def plot_exp_and_calculated_width_depth(df):
    """
    Compare measured and calculated melt pool depth (top panel) and width (bottom panel).

    Both panels are plotted against energy density. The calculated values are always drawn. Experimental
    points are drawn only when the measured column exists in ``df``, with error bars when the matching
    error column exists as well.

    :param df: DataFrame with "energy density (J/mm)", "calc depth (um)" and "calc width (um)", and
               optionally "meltpool depth (um)"/"depth error (um)" and "meltpool width (um)"/"width error (um)".
    """
    _, axs = plt.subplots(2, 1)
    _plot_exp_vs_calc(axs[0], df, "meltpool depth (um)", "depth error (um)", "calc depth (um)", "Depth (\u03bcm)")
    _plot_exp_vs_calc(axs[1], df, "meltpool width (um)", "width error (um)", "calc width (um)", "Width (\u03bcm)")


def _plot_exp_vs_calc(ax, df, exp_column, error_column, calc_column, ylabel):
    """Draw one measured-vs-calculated panel against energy density on ``ax``."""
    energy_density = df["energy density (J/mm)"]
    # Experimental data is optional: the same presence check applies to depth and width.
    if exp_column in df:
        # Error bars are optional too; without them matplotlib draws plain markers.
        yerr = df[error_column] if error_column in df else None
        ax.errorbar(energy_density, df[exp_column], yerr=yerr, marker="o",
                    linestyle="None", capsize=7.0, label='Experiment', zorder=0)
    ax.plot(energy_density, df[calc_column], label="Calculation", marker="<", lw=0)
    ax.set_xlabel("Energy density (J/mm)")
    ax.set_ylabel(ylabel)
    ax.set_ylim(bottom=0.0)
    ax.legend()


def plot_calculated_melt_pool_dimensions(df):
    """
    Contour-plot the calculated melt pool depth, width and length over the scan speed / power plane.

    There is one panel per dimension (depth, width, length from top to bottom). Each panel has its own
    colorbar, and the calculated (speed, power) points are overlaid as black dots.

    :param df: DataFrame with the columns "speed (mm/s)", "power (W)", "calc depth (um)",
               "calc width (um)" and "calc length (um)".
    """
    fig, axs2 = plt.subplots(3, 1)
    speed = df["speed (mm/s)"]
    power = df["power (W)"]
    n_points = len(df.index)

    panels = (
        ("calc depth (um)", "depth"),
        ("calc width (um)", "width"),
        ("calc length (um)", "length"),
    )
    for ax, (column, dimension) in zip(axs2, panels):
        contour = ax.tricontourf(speed, power, df[column], levels=14, cmap="RdBu_r")
        fig.colorbar(contour, ax=ax)
        ax.plot(speed, power, "ko", ms=3)
        ax.set_title("Calculated melt pool %s (%d points)" % (dimension, n_points))
        ax.set_xlabel("Scan speed (mm/s)")
        ax.set_ylabel("Power (W)")
    plt.subplots_adjust(hspace=0.7)


def plot_defects(df):
    """
    Contour-plot the four defect indicator ratios over the scan speed / power plane.

    Layout (2 x 2):
    - top left: keyholing (W/D)
    - top right: balling (W/L)
    - bottom left: lack of fusion (D/t)
    - bottom right: lack of fusion between tracks (D_hatch/t)

    :param df: DataFrame with "speed (mm/s)", "power (W)", "keyholing", "balling",
               "lack of fusion" and "lack of fusion - hatch".
    """
    fig, axs2 = plt.subplots(2, 2)
    speed = df["speed (mm/s)"]
    power = df["power (W)"]

    panels = (
        ((0, 0), "keyholing", "Keyholing (W/D)"),
        ((0, 1), "balling", "Balling (W/L)"),
        ((1, 0), "lack of fusion", "Lack of fusion (D/t)"),
        ((1, 1), "lack of fusion - hatch", "Lack of fusion - hatch (D_hatch/t)"),
    )
    for position, column, title in panels:
        ax = axs2[position]
        contour = ax.tricontourf(speed, power, df[column], levels=14, cmap="RdBu_r")
        fig.colorbar(contour, ax=ax)
        ax.set_title(title)
        ax.set_xlabel("Scan speed (mm/s)")
        ax.set_ylabel("Power (W)")

    plt.subplots_adjust(hspace=0.5, wspace=0.4)


def plot_printability_map(df, df_experiments_only):
    """
    Draw the printability map: shaded defect regions over the scan speed / power plane.

    Each defect indicator column of ``df`` is filled between 0 and its limit constant. Labelled keyholing
    (W/D) iso-lines are overlaid, and every experimental point is annotated with its "text" value drawn
    in its "color".

    :param df: calculated results with "speed (mm/s)", "power (W)" and the indicator columns
               "balling", "lack of fusion", "keyholing" and "lack of fusion - hatch".
    :param df_experiments_only: experimental rows with "speed (mm/s)", "power (W)", "text" and "color".
    """
    _, ax = plt.subplots(1, 1)
    ax.title.set_text("Layer thickness:{}\u03bcm, hatch distance:{}\u03bcm".format(POWDER_THICKNESS,
                                                                                   HATCH_DISTANCE))

    alpha = 0.5
    speed = df["speed (mm/s)"]
    power = df["power (W)"]

    # (indicator column, defect limit, fill colour) in drawing order -- later fills are painted on top
    defect_regions = [
        ("balling", BALLING_LIMIT, "magenta"),
        ("lack of fusion", LACK_OF_FUSION_LIMIT, "blue"),
        ("keyholing", KEYHOLE_LIMIT, "teal"),
        ("lack of fusion - hatch", LACK_OF_FUSION_HATCH_LIMIT, "deepskyblue"),
    ]
    for column, limit, colour in defect_regions:
        # shade the band between 0 and the limit, i.e. where the defect is predicted
        ax.tricontourf(speed, power, df[column], levels=[0.0, limit], colors=(colour,), alpha=alpha)

    keyhole_lines = ax.tricontour(speed, power, df["keyholing"], levels=np.linspace(.0, KEYHOLE_LIMIT, 8))
    ax.clabel(keyhole_lines, inline=True, fontsize=10)

    for _, experiment in df_experiments_only.iterrows():
        ax.text(x=experiment["speed (mm/s)"], y=experiment["power (W)"], s=experiment["text"],
                color=experiment["color"], ha="center", va="center")

    ax.set_xlabel("Scan speed (mm/s)")
    ax.set_ylabel("Power (W)")

    # legend order differs from the drawing order on purpose (keyholing listed first)
    legend_entries = [
        ("teal", "Keyholing (W/D < {})".format(KEYHOLE_LIMIT)),
        ("magenta", "Balling (W/L < {})".format(BALLING_LIMIT)),
        ("blue", "Lack of fusion (D/t < {})".format(LACK_OF_FUSION_LIMIT)),
        ("deepskyblue", "Lack of fusion - hatch (D_hatch/t < {})".format(LACK_OF_FUSION_HATCH_LIMIT)),
    ]
    ax.legend(handles=[mpatches.Patch(color=colour, label=label, alpha=alpha)
                       for colour, label in legend_entries])


def _extend_with_all_combinations(df):
    """
    Append every (power, scan speed) combination of the values present in ``df``.

    The new rows only carry "power (W)" and "speed (mm/s)"; all other columns become NaN. Duplicate pairs
    are dropped keeping the first occurrence, so the original (experimental) rows are retained. The result
    is re-indexed 0..n-1.
    """
    # Unique values in order of first appearance: this avoids an N x N grid full of duplicate pairs and
    # yields exactly the same rows, in the same order, after de-duplication.
    powers = pd.unique(df["power (W)"])
    speeds = pd.unique(df["speed (mm/s)"])
    power_grid, speed_grid = np.meshgrid(powers, speeds, indexing="ij")
    combinations = pd.DataFrame(data={"power (W)": power_grid.ravel(),
                                      "speed (mm/s)": speed_grid.ravel()})
    extended = pd.concat([df, combinations])
    return extended.drop_duplicates(subset=["power (W)", "speed (mm/s)"], keep="first", ignore_index=True)


with TCPython(logging_policy=LoggingPolicy.SCREEN) as start:
    start.set_cache_folder(os.path.basename(__file__) + "cache")

    df = pd.read_csv(EXP_FILE, skipinitialspace=True)
    print(df.to_string())

    # extend the dataset for calculation of the printability map by adding all combinations of the laser power
    # and scan speed
    extend_dataset_all_combinations = True
    if extend_dataset_all_combinations:
        df = _extend_with_all_combinations(df)
        print(df.to_string())

    mp = MaterialProperties.from_library(MATERIAL_NAME)
    liquidus = mp.get_liquidus_temperature()  # loop invariant, iso-value for the hatch-distance depth

    am_calculator = (start.with_additive_manufacturing()
                     .with_steady_state_calculation()
                     .with_numerical_options(NumericalOptions().set_number_of_cores(4))
                     .disable_fluid_flow_marangoni()
                     .with_material_properties(mp)
                     .with_mesh(Mesh().coarse()))
    am_calculator.set_ambient_temperature(AMBIENT_TEMPERATURE)
    am_calculator.set_base_plate_temperature(AMBIENT_TEMPERATURE)

    # The same heat source object is reconfigured for every row of the loop below.
    heat_source = HeatSource.double_ellipsoidal_from_library(HEAT_SOURCE_NAME)
    am_calculator.with_heat_source(heat_source)

    # add the energy density and dataframe columns for calculated melt pool dimensions
    df["energy density (J/mm)"] = df["power (W)"] / df["speed (mm/s)"]
    for column in ("calc width (um)", "calc depth (um)", "calc length (um)", "hatch distance depth (um)"):
        df[column] = np.nan

    n_rows = len(df.index)
    for row_number, (i, row) in enumerate(df.iterrows(), start=1):
        power = row["power (W)"]
        scan_speed = row["speed (mm/s)"]
        print("Row:{}/{}, Power:{}W, scan speed:{}mm/s".format(row_number, n_rows, power, scan_speed))

        try:
            heat_source.set_power(power)
            heat_source.set_scanning_speed(scan_speed / 1e3)  # mm/s -> m/s
            result = am_calculator.calculate()

            length = result.get_meltpool_length()
            width = result.get_meltpool_width()
            depth = result.get_meltpool_depth()

            # Non-finite dimensions (inf or NaN) are treated as a failed calculation: the row keeps NaN
            # and is dropped below.
            if np.isfinite(length) and np.isfinite(width) and np.isfinite(depth):
                df.loc[i, "hatch distance depth (um)"] = (
                        get_depth_at_half_hatch_distance(result, HATCH_DISTANCE, liquidus) * 1e6)
                df.loc[i, "calc width (um)"] = width * 1e6
                df.loc[i, "calc depth (um)"] = depth * 1e6
                df.loc[i, "calc length (um)"] = length * 1e6
        except Exception as e:
            # Best effort: one failed calculation must not abort the batch. Its row keeps NaN and is
            # dropped below, but report which row failed and why.
            print("Calculation failed for power {}W, scan speed {}mm/s: {!r}".format(power, scan_speed, e))

    # drop the rows with failed calculations, if any
    df.dropna(subset=["calc width (um)", "calc depth (um)", "calc length (um)"], inplace=True)

    # filter the rows that include experimental data
    df_experiments_only = df.dropna(subset=["meltpool depth (um)", "meltpool width (um)"], inplace=False)

    # defect indicator ratios, compared against the *_LIMIT constants in the printability map
    df["lack of fusion - hatch"] = df["hatch distance depth (um)"] / POWDER_THICKNESS
    df["lack of fusion"] = df["calc depth (um)"] / POWDER_THICKNESS
    df["balling"] = df["calc width (um)"] / df["calc length (um)"]
    df["keyholing"] = df["calc width (um)"] / df["calc depth (um)"]

    plot_exp_and_calculated_width_depth(df_experiments_only)
    plot_heat_source_parameters(df)
    plot_calculated_melt_pool_dimensions(df)
    plot_defects(df)
    plot_printability_map(df, df_experiments_only)

    print(df.to_string())

    plt.show()
