import { useState, WheelEvent } from "react";
import { Skeleton } from "@mui/material";
import { useParams } from "react-router";
import {
  Chart as ChartJS,
  CategoryScale,
  LinearScale,
  PointElement,
  LineElement,
  Title,
  Tooltip,
  Filler,
  Legend,
  ChartOptions,
} from "chart.js";
import { Line } from "react-chartjs-2";
import { useSnapshot } from "valtio";

import {
  calculatePercentiles,
  getAgeFromMetadata,
  RegressionModel,
} from "lib/client/normative";
import useMarkers, {
  parseQuantitativeDataFromSR,
} from "shared/hooks/useMarkers";
import normativeGraphState from "stores/normativeGraph";
import useLongitudinal from "shared/hooks/useLongitudinal";
import { useTranslate } from "ra-core";

// Chart.js components used by the <Line> chart in NormativeGraph. They are
// registered once, when this module is first evaluated.
const lineChartRegistrables = [
  CategoryScale,
  LinearScale,
  PointElement,
  LineElement,
  Title,
  Tooltip,
  Filler,
  Legend,
];
ChartJS.register(...lineChartRegistrables);

/**
 * RGB triplets for the percentile curves: entry `i` colours the `i`-th
 * non-median percentile passed to `computeNormativeData` (line and band fill).
 */
export const percentileColors: number[][] = [
  [208, 49, 45], // red
  [236, 151, 135], // salmon
  [231, 209, 38], // yellow
];

/**
 * Point colours of the "Quantification" dataset: `currentStudy` for the study
 * being viewed, `other` for the patient's remaining studies (and the line).
 */
export const quantificationColors = {
  currentStudy: "#005678",
  other: "black",
} as const;

/**
 * "Normative graph" viewer mode for the study in the route (`study_id`).
 *
 * Draws the percentile curves of the regression model selected by
 * `normativeGraphState.rid`, overlaid with the patient's value for that
 * marker in each of their studies (the viewed study highlighted). Scrolling
 * the wheel over the chart widens/narrows the x-axis margin around the
 * patient's ages, in steps of 5 between 5 and 30.
 */
function NormativeGraph() {
  const { study_id } = useParams();
  const snap = useSnapshot(normativeGraphState);
  const [zoom, setZoom] = useState(10);
  const {
    regressionModels,
    isRegressionModelLoading,
    PatientID,
    populatedMarkers,
  } = useMarkers(study_id);
  const { studiesInstanceMetadata, isDone } = useLongitudinal(PatientID);
  const t = useTranslate();

  const regressionModel = regressionModels?.[snap.rid];

  // Every hook is called above this early return, so hook order stays stable.
  if (!regressionModel || isRegressionModelLoading || !isDone)
    return <Skeleton />;

  const markersPercentile = populatedMarkers.find(
    (marker) => marker.code === snap.rid
  )?.percentiles;

  // One entry per study: age at that study, the value reported in "%" for the
  // selected marker, and whether it is the study being viewed.
  // `?? []` is required: an optional chain short-circuits the whole
  // expression, which would otherwise leave `quantification` undefined.
  const quantification = (
    studiesInstanceMetadata?.map((instances) => {
      const firstInstance = instances?.[0];
      return {
        age: Number(getAgeFromMetadata(firstInstance)?.toFixed(2)),
        icv: parseQuantitativeDataFromSR(instances)?.find(
          (quant) => quant[0] === snap.rid && quant[2] === "%"
        )?.[1],
        // `?.` on the instance as well: a study without instances must not throw.
        currentStudy:
          study_id !== undefined &&
          firstInstance?.StudyInstanceUID === study_id,
      };
    }) ?? []
  ).sort((a, b) => a.age - b.age);

  const data = computeNormativeData(
    regressionModel,
    quantification,
    markersPercentile
  );

  // Ages are NaN when getAgeFromMetadata yields nothing, and Math.min/max of an
  // empty list are ±Infinity (a `||` fallback catches neither case), so only
  // finite ages drive the axis; with none, fall back to the 28–90 range.
  const ages = quantification
    .map(({ age }) => age)
    .filter((age) => Number.isFinite(age));
  const hasAges = ages.length > 0;

  const options: ChartOptions<"line"> = {
    responsive: true,
    elements: {
      point: {
        pointStyle: false,
        backgroundColor: quantificationColors.currentStudy,
      },
    },
    scales: {
      xAxis: {
        type: "linear",
        min: (hasAges ? Math.min(...ages) : 28) - zoom,
        max: (hasAges ? Math.max(...ages) : 90) + zoom,
        title: {
          display: true,
          text: t("modes.normative_graph.age"),
        },
      },
      y: {
        title: {
          display: true,
          text: t("modes.normative_graph.volume"),
        },
      },
    },
    plugins: {
      legend: {
        position: "top" as const,
      },
      title: {
        display: true,
        // Each ":"-separated part of the marker id is translated separately.
        text: `${t("modes.normative_graph")} : ${snap.rid
          .split(":")
          .map((rid) => t(rid))
          .join(",")}`,
      },
    },
  };

  /** Wheel down widens the x-axis margin, wheel up narrows it (5..30). */
  function handleWheelEvent(e: WheelEvent<HTMLDivElement>) {
    // Clamp inside the updater so several wheel events handled before a
    // re-render cannot push the zoom outside [5, 30].
    if (e.deltaY >= 0) {
      setZoom((current) => Math.min(current + 5, 30));
    } else {
      setZoom((current) => Math.max(current - 5, 5));
    }
  }

  return (
    <div
      data-testid="viewer-mode-normative-graph"
      className="h-full"
      onWheel={handleWheelEvent}
    >
      <Line options={options} data={data} />
    </div>
  );
}

/**
 * Builds the Chart.js `data` object for the normative graph.
 *
 * @param regressionModel - evaluated with `calculatePercentiles` at every
 *   integer age in [0, 200). Its `min`/`max` (default 0 and 90) bound the
 *   normative age range; line segments starting outside it are drawn dashed
 *   with a translucent fill.
 * @param quantification - patient measurements, plotted as the first
 *   ("Quantification") dataset; entries with `currentStudy` use
 *   `quantificationColors.currentStudy`. When omitted/undefined, no points
 *   are plotted.
 * @param percToDisplay - percentiles to draw as curves, presumably fractions
 *   in [0, 1) (the label is built from `toFixed(2)`); each is looked up as
 *   `String(perc)` in the `calculatePercentiles` result. The median (0.5) is
 *   drawn white and unfilled.
 * @returns `{ datasets }`: the quantification dataset first, then one dataset
 *   per entry of `percToDisplay`, in order.
 */
function computeNormativeData(
  regressionModel: RegressionModel,
  quantification: {
    age: number;
    icv: number | undefined;
    currentStudy: boolean;
  }[] = [],
  percToDisplay: number[] = [0.01, 0.05, 0.25, 0.5]
) {
  const min = 0;
  const max = 200;
  // TODO get the norm min norm max from the api
  const normMin = regressionModel?.min ?? 0;
  const normMax = regressionModel?.max ?? 90;

  /** The part of Chart.js's segment context this function reads. */
  type SegmentContext = { p0: { parsed: { x: number } } };

  // `value` for segments starting outside [normMin, normMax]; undefined
  // (i.e. the dataset's own style) inside it.
  const applyStyle = <T,>(ctx: SegmentContext, value: T): T | undefined =>
    ctx.p0.parsed.x < normMin || ctx.p0.parsed.x > normMax ? value : undefined;

  // Only a few colours are defined: cycle through them instead of crashing on
  // `undefined.join(...)` when more percentiles are requested.
  const rgbFor = (i: number) =>
    percentileColors[i % percentileColors.length].join(",");

  const data: {
    age: number;
    percentiles: ReturnType<typeof calculatePercentiles>;
  }[] = [];
  for (let age = min; age < max; age++) {
    data.push({ age, percentiles: calculatePercentiles(regressionModel, age) });
  }

  const ds = percToDisplay.map((perc, i) => {
    // A numeric fill targets that absolute dataset index. The quantification
    // dataset sits at index 0, so `i` is the previous percentile curve. The
    // first curve fills to the chart edge; the median is not filled.
    let fill: string | number | undefined = i;
    if (perc === 0.5) fill = undefined;
    else if (i === 0 && perc > 0.5) fill = "end";
    else if (i === 0 && perc < 0.5) fill = "start";

    const rgb = rgbFor(i);
    const color = perc === 0.5 ? "white" : `rgb(${rgb})`;

    return {
      label: `P${perc.toFixed(2).substring(2)}`,
      data: data.map(({ age, percentiles }) => ({
        x: age,
        y: percentiles[String(perc)],
      })),
      borderColor: color,
      backgroundColor: color,
      segment: {
        backgroundColor: (ctx: SegmentContext) =>
          applyStyle(ctx, `rgb(${rgb}, 0.1)`),
        borderDash: (ctx: SegmentContext) => applyStyle(ctx, [3, 3]),
      },
      fill,
    };
  });

  const pointColors = quantification.map(({ currentStudy }) =>
    currentStudy ? quantificationColors.currentStudy : quantificationColors.other
  );

  const datasets = [
    {
      label: "Quantification",
      data: quantification.map(({ age, icv }) => ({ x: age, y: icv })),
      pointBackgroundColor: pointColors,
      pointBorderColor: pointColors,
      borderColor: quantificationColors.other,
      backgroundColor: quantificationColors.other,
      pointStyle: "circle",
      pointRadius: 6,
    },
    ...ds,
  ];

  return { datasets };
}

// The normative graph viewer mode component is this module's default export.
export { NormativeGraph as default };
