import * as PlotL from "@observablehq/plot";
import * as d3 from "d3";
import React, { useEffect, useRef } from "react";
import { formatIsoDateTime } from "../../utils";
import { colorScheme } from "../../utils/colorScheme";
import { xAxisTypes } from "../../utils/inspection/dashboard";
import './plotStyles.css';

/**
 * Summary statistic that a plotted value represents.
 *
 * The string values are what rows carry in `PlotData.series` and are also the
 * keys of the plots' color-scale domain.
 */
export enum ChartSeriesOptions {
    /** Minimum value. */
    MIN = "min",
    /** Maximum value. */
    MAX = "max",
    /** Mean value (drawn with the "average" color). */
    MEAN = "mean",
    /** Median value (drawn with the "average" color). */
    MEDIAN = "median",
}

/**
 * One plotted observation: a single series value of one metric for one `sn`.
 */
export interface PlotData {
    /** Scan identifier (not read by the plotting code in this file). */
    scanId: number;
    /** X-axis category key; presumably a serial number. Also shown in the tooltip. */
    sn: string;
    /** Metric name; rows are faceted (or cross-compared) by this value. */
    metric: string;
    /** Which summary statistic `value` represents; selects the dot color. */
    series: ChartSeriesOptions;
    /** Slice identifier, or null (not read by the plotting code in this file). */
    sliceId: number | null;
    /** Y value. Typed as number, but the plotting code still filters out nulls defensively. */
    value: number;
    /** Batch label shown in the tooltip. */
    batch: string;
    /** Timestamp shown in the tooltip. */
    date: Date;
}

/** Width, in pixels, of each per-metric facet plot. */
const plotWidthPx = 900;
/** Height, in pixels, of each per-metric facet plot. */
const plotHeightPx = 300;

// const getMetricUnit = (metric: string) => {
//     metricDisplayUnit(metric) ?? ""
// };

/**
 * Multi-line tooltip for one plotted row.
 *
 * The lines are: the series and its value (4 significant digits), the sn, the
 * batch, and the formatted date.
 */
const tooltipText = (d: PlotData) => {
    const header = `${d.series}: ${d.value.toPrecision(4)}`;
    const lines = [
        header,
        `${d.sn}`,
        `${d.batch}`,
        `${formatIsoDateTime(new Date(d.date))}`,
    ];
    return lines.join("\n");
};

/**
 * Build the dot plot for a single metric facet.
 *
 * Each `sn` is a band on the x axis. Every non-null value is drawn as a dot
 * colored by its series. Per `sn`, a thin vertical rule spans from the
 * smallest "min" value to the largest "max" value. A pointer-driven tip shows
 * the hovered row's tooltip text.
 *
 * @param xDomain - Ordered x categories shared by every facet, so the x axes line up.
 * @param yDomain - `[lo, hi]` y extent; Plot may widen it because `nice` is enabled.
 * @param facetData - Rows for this facet (one metric).
 * @param category - Y-axis label.
 * @returns The element created by `Plot.plot` (with `figure: false`).
 */
const createPlot = (
    xDomain: string[],
    yDomain: [number, number],
    facetData: PlotData[],
    category: string
) => {
    // Rows without a value cannot be positioned on the y axis. The type says
    // `number`, but the data is filtered defensively.
    const plottable = facetData.filter(d => d.value !== null);

    // One rule per sn, drawn only when both a "min" and a "max" value exist.
    const minMaxRules: { x: string; yMin: number; yMax: number }[] = [];
    d3.group(plottable, d => d.sn).forEach((values, sn) => {
        const yMin = d3.min(values.filter(d => d.series === ChartSeriesOptions.MIN), d => d.value);
        const yMax = d3.max(values.filter(d => d.series === ChartSeriesOptions.MAX), d => d.value);
        if (yMin !== undefined && yMax !== undefined) {
            minMaxRules.push({ x: String(sn), yMin, yMax });
        }
    });

    // A band scale places one tick per domain value, and this formatter relies
    // on Plot passing the tick index as its second argument. Thin the labels by
    // the size of the shared x domain (the ticks actually drawn), not by the sns
    // present in this facet; otherwise sparse facets label every tick.
    const maxLabels = 30;
    const labelStep = Math.max(1, Math.ceil(xDomain.length / maxLabels));
    const tickFormat = (label: string, i: number) => (i % labelStep === 0 ? label : "");

    return PlotL.plot({
        width: plotWidthPx,
        height: plotHeightPx,
        className: "plot-chart-class",
        marginLeft: 100,
        marginBottom: 100,
        x: {
            domain: xDomain,
            type: "band",
            label: null,
            labelArrow: false,
            tickFormat,
            tickRotate: -30,
            tickSize: 0,
            grid: false
        },
        y: {
            domain: yDomain,
            label: category,
            labelAnchor: "center",
            labelOffset: 50,
            nice: true,
            labelArrow: false
        },
        grid: true,
        color: {
            domain: ["min", "max", "mean", "median"],
            range: [colorScheme.min, colorScheme.max, colorScheme.average, colorScheme.average]
        },
        marks: [
            PlotL.ruleX(minMaxRules, { x: "x", y1: "yMin", y2: "yMax", stroke: "white", strokeWidth: 0.5 }),
            PlotL.dot(plottable, {
                x: d => String(d.sn),
                y: "value",
                fill: "series",
            }),
            PlotL.tip(
                plottable,
                PlotL.pointer({
                    x: d => String(d.sn),
                    y: "value",
                    title: d => tooltipText(d),
                    fill: "black"
                })
            )
        ],
        figure: false
    });
};

/**
 * Build a scatter plot of one metric against another.
 *
 * Each point pairs an `xMetric` row with the `yMetric` row that has the same
 * `sn` AND the same series. For example, the "max" value of one metric for an
 * `sn` is plotted against that `sn`'s "max" value of the other metric.
 * - Rows with a null value are ignored.
 * - An `xMetric` row with no matching `yMetric` row produces no point.
 * - When several `yMetric` rows share an (`sn`, series) pair, the first one is used.
 *
 * @param xMetric - Metric plotted on the x axis.
 * @param yMetric - Metric plotted on the y axis.
 * @param data - Rows for all metrics; filtered here by metric.
 * @param categoryX - X-axis label.
 * @param categoryY - Y-axis label.
 * @returns The element created by `Plot.plot`.
 */
const createComparisonPlot = (
    xMetric: string,
    yMetric: string,
    data: PlotData[],
    categoryX: string,
    categoryY: string
) => {
    // Key rows by (sn, series). JSON encoding keeps the two parts unambiguous.
    const pairKey = (d: PlotData) => JSON.stringify([d.sn, d.series]);

    // Index the y metric once, instead of scanning it for every x row.
    // Matching on series as well as sn keeps e.g. x "max" from pairing with y "min".
    const yByKey = new Map<string, number>();
    for (const d of data) {
        if (d.metric !== yMetric || d.value === null) continue;
        const key = pairKey(d);
        if (!yByKey.has(key)) yByKey.set(key, d.value);
    }

    const mergedData: { x: number; y: number; series: ChartSeriesOptions }[] = [];
    for (const d of data) {
        if (d.metric !== xMetric || d.value === null) continue;
        const y = yByKey.get(pairKey(d));
        if (y !== undefined) mergedData.push({ x: d.value, y, series: d.series });
    }

    const pointsFor = (series: ChartSeriesOptions) => mergedData.filter(d => d.series === series);

    return PlotL.plot({
        width: 350,
        height: 350,
        marginLeft: 60,
        marginBottom: 40,
        className: "plot-chart-class",

        x: { label: categoryX },
        y: { label: categoryY },
        color: {
            domain: ["min", "max", "mean", "median"],
            range: [colorScheme.min, colorScheme.max, colorScheme.average, colorScheme.average]
        },
        marks: [
            PlotL.dot(pointsFor(ChartSeriesOptions.MIN), { x: "x", y: "y", fill: colorScheme.min }),
            PlotL.dot(pointsFor(ChartSeriesOptions.MAX), { x: "x", y: "y", fill: colorScheme.max }),
            PlotL.dot(pointsFor(ChartSeriesOptions.MEAN), { x: "x", y: "y", fill: colorScheme.average }),
            PlotL.dot(pointsFor(ChartSeriesOptions.MEDIAN), { x: "x", y: "y", fill: colorScheme.average })
        ],
        grid: true
    });
};


/**
 * Build the plot elements to display, each wrapped in its own `<div>`.
 *
 * @param data - Rows for every metric.
 * @param compareMetrics - When true, build one scatter plot per unordered pair
 *   of distinct metrics. Otherwise build one dot plot per metric, all sharing
 *   the same x domain (every `sn` in `data`, in first-seen order).
 * @returns Detached wrapper divs in display order.
 */
const generatePlots = (data: PlotData[], compareMetrics: boolean) => {
    const rowsByMetric = d3.group(data, d => d.metric);
    const wrappers: HTMLDivElement[] = [];

    if (!compareMetrics) {
        // Shared by all facets so every metric's x axis lines up.
        const sharedXDomain = Array.from(new Set(data.map(d => String(d.sn))));
        rowsByMetric.forEach((metricRows, metricName) => {
            const yExtent = d3.extent(metricRows, d => d.value) as [number, number];
            const chart = createPlot(sharedXDomain, yExtent, metricRows, metricName);
            const wrapper = document.createElement("div");
            wrapper.style.marginBottom = "20px";
            wrapper.appendChild(chart);
            wrappers.push(wrapper);
        });
        return wrappers;
    }

    // Every unordered pair (a, b) with a listed before b.
    const metricNames = Array.from(rowsByMetric.keys());
    metricNames.forEach((xMetric, index) => {
        metricNames.slice(index + 1).forEach(yMetric => {
            const chart = createComparisonPlot(xMetric, yMetric, data, xMetric, yMetric);
            const wrapper = document.createElement("div");
            wrapper.style.margin = "10px";
            wrapper.style.display = "inline-block";
            wrapper.appendChild(chart);
            wrappers.push(wrapper);
        });
    });
    return wrappers;
};

/** Props for the `FacetedPlot` component. */
export interface FacetPlotProps {
    /** Rows to plot, across all metrics. */
    data: PlotData[];
    /** Layout selector; `xAxisTypes.CROSS_COMPARE` switches to metric-vs-metric scatter plots. */
    xAxisType: xAxisTypes;
}

/**
 * Renders the metric plots for `data` into a container div.
 *
 * In `xAxisTypes.CROSS_COMPARE` mode it shows one scatter plot per pair of
 * metrics; otherwise it shows one dot plot per metric. The plots are built
 * imperatively by `generatePlots` and inserted into the container, so React
 * only manages the wrapper `<div>`.
 */
const FacetedPlot: React.FC<FacetPlotProps> = ({ data, xAxisType }) => {
    const containerRef = useRef<HTMLDivElement>(null);

    useEffect(() => {
        const container = containerRef.current;
        if (!container) return;
        const plotElements = generatePlots(data, xAxisType === xAxisTypes.CROSS_COMPARE);
        // Drop the previous render's plots before inserting the new ones.
        container.innerHTML = "";
        plotElements.forEach(element => container.appendChild(element));
        // `xAxisType` selects the layout, so changing it without new `data`
        // must also rebuild the plots.
    }, [data, xAxisType]);

    return <div ref={containerRef} />;
};

export default FacetedPlot;
