import { RefObject } from "react";
import createImageIdsAndCacheMetaData from "./createImageIdsAndCacheMetaData";
import initCornerstone from "./initCornerstone";
import {
  RenderingEngine,
  volumeLoader,
  setVolumesForViewports,
  cache,
  getRenderingEngine,
  imageLoadPoolManager,
  ImageVolume,
} from "@cornerstonejs/core";
import {
  StackScrollMouseWheelTool,
  ToolGroupManager,
  CrosshairsTool,
  Enums as csToolsEnums,
  segmentation as cstSegmentation,
  SegmentationDisplayTool,
  WindowLevelTool,
  PanTool,
  ZoomTool,
  BrushTool,
  utilities,
} from "@cornerstonejs/tools";
import {
  SegmentationIds,
  RenderingEngineIds,
  ViewerIds,
} from "../types/cornerstone";
import {
  PublicViewportInput,
  IRenderingEngine,
} from "@cornerstonejs/core/dist/esm/types";
import { setCtTransferFunctionFromImage } from "./helpers/setCtTransferFunctionForVolumeActor";
import {
  KeyboardBindings,
  MouseBindings,
} from "@cornerstonejs/tools/dist/esm/enums";
import { ToolGroupSpecificRepresentation } from "@cornerstonejs/tools/dist/esm/types";
import getIds from "./helpers/getIds";
import {
  getSegInfo,
  injectBufferIntoSegVolume,
  generateColorLUT,
  SegInfo,
} from "./generateSegmentation";
import viewerState from "stores/viewer";
import { createVOISynchronizer } from "@cornerstonejs/tools/dist/esm/synchronizers";
import dcmjs from "dcmjs";
import { getSynchronizer } from "@cornerstonejs/tools/dist/esm/store/SynchronizerManager";

/**
 * Instance names under which the sphere brush / sphere eraser variants of
 * `BrushTool` are registered on the viewer tool group.
 */
export const brushInstanceNames = {
  SphereBrush: "SphereBrush",
  SphereEraser: "SphereEraser",
} as const;

/**
 * `activeStrategy` configured for each named BrushTool instance, keyed by the
 * instance names above so the two tables cannot drift apart.
 */
const brushStrategies = {
  [brushInstanceNames.SphereEraser]: "ERASE_INSIDE_SPHERE",
  [brushInstanceNames.SphereBrush]: "FILL_INSIDE_SPHERE",
} as const;

/**
 * Initialises the three-viewport volume viewer for the series described by
 * `instanceMetaData`.
 *
 * It performs these steps in order:
 * - sets up the rendering engine and tools;
 * - creates the image volume and binds it to the viewports;
 * - starts queuing its frames from the middle slice outwards, in the background;
 * - loads the DICOM SEG as labelmap segmentations, when
 *   `createImageIdsAndCacheMetaData` provides a SEG pixel-data getter.
 *
 * @param id - viewer id; stored on `viewerState.viewerId` and used to set up tools.
 * @param viewerRefs - refs to the viewport container elements, in the same
 *   order as the viewport input configuration.
 * @param instanceMetaData - naturalized DICOM dataset supplying the Study and
 *   Series Instance UIDs to load.
 * @returns the rendering engine, or `undefined` when
 *   `viewerState.instanceMetadata` is not set.
 */
export async function initViewer(
  id: string,
  viewerRefs: RefObject<HTMLDivElement>[],
  instanceMetaData: dcmjs.data.NaturalizedDataset
) {
  if (!viewerState.instanceMetadata) return;
  viewerState.viewerId = id;

  const {
    renderingEngineId,
    viewportInputConf,
    volumeId,
    segVolumeId,
    segmentationId,
    toolGroupId,
    viewportId1,
    viewportId2,
    viewportId3,
  } = getIds();
  const renderingEngine = await setupRenderingEngine(viewerRefs, id, {
    renderingEngineId,
    viewportInputConf,
  });

  // Get imageIds
  const [imageIds, getSEGPixelData] = await createImageIdsAndCacheMetaData({
    StudyInstanceUID: instanceMetaData.StudyInstanceUID!,
    SeriesInstanceUID: instanceMetaData.SeriesInstanceUID!,
  });
  const segInfo = getSEGPixelData?.().then((SEGArrayBuffer) =>
    //@ts-ignore
    getSegInfo(imageIds, SEGArrayBuffer)
  );
  // The SEG promise is only awaited later, inside setupSegmentation, which
  // logs failures. It is not awaited at all when the segmentation is already
  // loaded. Attach a handler now so an early rejection is not reported as an
  // unhandled rejection in the meantime. setupSegmentation still observes the
  // rejection when it awaits the original promise.
  void segInfo?.catch(() => undefined);

  const viewerIds = { volumeId, viewportId1, viewportId2, viewportId3 };
  const volume = await createAndLoadVolume(viewerIds, imageIds, renderingEngine);

  // Queue all frames from the middle outwards in both directions. This is
  // deliberately not awaited, so the viewer is returned immediately. Failures
  // are logged instead of becoming unhandled rejections.
  void loadFramesMiddleOutwards(volume, viewerIds, renderingEngine).catch(
    (error) => {
      console.error("Failed to load volume frames:", error);
    }
  );

  await setupSegmentation(
    {
      volumeId: volume.volumeId,
      segVolumeId,
      segmentationId,
      toolGroupId,
    },
    segInfo
  );
  return renderingEngine;
}

/**
 * Converts raw volume load requests into the argument bundles passed to
 * `imageLoadPoolManager.addRequest`. Every request is tagged as a priority-0
 * "prefetch" for `volumeId`.
 *
 * @param customOrderedRequests - load requests, such as the entries returned
 *   by a volume's `getImageLoadRequests()`, in the order they should be queued.
 * @param volumeId - id recorded in each request's `additionalDetails`.
 * @returns one descriptor per input, in input order. Each `callLoadImage`
 *   takes no arguments and forwards the original `imageId`, `imageIdIndex`
 *   and `options`; nothing is loaded until it is called.
 */
function generateRequests<TOptions>(
  customOrderedRequests: ReadonlyArray<{
    callLoadImage: (
      imageId: string,
      imageIdIndex: number,
      options: TOptions
    ) => unknown;
    imageId: string;
    imageIdIndex: number;
    options: TOptions;
  }>,
  volumeId: string
) {
  const requestType = "prefetch";
  const priority = 0;
  return customOrderedRequests.map(
    ({ callLoadImage, imageId, imageIdIndex, options }) => ({
      // Use a closure rather than `.bind(this, …)`: a module-level function
      // has no meaningful `this`, and referencing it is an implicit-`any`
      // error under `noImplicitThis`. The call receiver is `undefined` in
      // both forms.
      callLoadImage: () => callLoadImage(imageId, imageIdIndex, options),
      requestType,
      additionalDetails: { volumeId },
      priority,
    })
  );
}
/**
 * Returns a promise that resolves with `undefined` after about `ms` milliseconds.
 *
 * @param ms - delay in milliseconds, passed unchanged to `setTimeout`.
 */
function sleep(ms: number): Promise<void> {
  return new Promise<void>((resolve) => {
    setTimeout(resolve, ms);
  });
}

/**
 * Queues one load request on Cornerstone's `imageLoadPoolManager` and retries
 * if queuing it throws.
 *
 * NOTE: the returned promise settles as soon as `addRequest` has returned,
 * i.e. once the request is queued, not once the image has loaded. Errors
 * raised later by the load itself are neither observed nor retried here.
 *
 * @param request - object whose `callLoadImage`, `requestType`,
 *   `additionalDetails` and `priority` are forwarded to `addRequest`.
 * @param retries - maximum number of attempts (default 3).
 * @param delay - milliseconds to wait between attempts (default 1000).
 * @throws the last error once every attempt has failed.
 */
async function imageLoadPoolManagerRequestAsync(
  request,
  retries = 3,
  delay = 1000
) {
  // The executor calls addRequest and resolves immediately. If addRequest
  // throws, the Promise constructor turns that into a rejection.
  const enqueueOnce = () =>
    new Promise<void>((resolve) => {
      imageLoadPoolManager.addRequest(
        request.callLoadImage,
        request.requestType,
        request.additionalDetails,
        request.priority
      );
      resolve();
    });

  let attempt = 1;
  while (attempt <= retries) {
    try {
      await enqueueOnce();
      return; // Queued successfully.
    } catch (error) {
      console.error(`Attempt ${attempt} failed:`, error);
      const isLastAttempt = attempt >= retries;
      if (isLastAttempt) {
        console.error("All attempts failed.");
        throw error;
      }
      console.log(`Retrying in ${delay}ms...`);
      await sleep(delay);
    }
    attempt += 1;
  }
}
/**
 * Queues every frame of `volume` for loading, starting at the middle frame and
 * working outwards in both directions. After each step it asks each of the
 * three viewports to render.
 *
 * Each step queues up to `batchSize` frames below the middle (inclusive) and
 * up to `batchSize` frames above it, so at most `2 * batchSize` frames per
 * step. `imageLoadPoolManagerRequestAsync` resolves once a request is queued,
 * not once it has loaded, so the render calls do not wait for pixel data.
 *
 * @param volume - object exposing `getImageLoadRequests()`. Presumably a
 *   Cornerstone streaming image volume (TODO: confirm and type it).
 * @param viewerIds - id of the volume and of the three viewports to re-render.
 * @param renderingEngine - engine that owns the viewports.
 * @param batchSize - frames queued per side per step (default 5); must be a
 *   positive integer.
 * @throws RangeError when `batchSize` is not a positive integer.
 */
async function loadFramesMiddleOutwards(
  volume,
  viewerIds: ViewerIds,
  renderingEngine: IRenderingEngine,
  batchSize = 5
) {
  if (!Number.isInteger(batchSize) || batchSize < 1) {
    throw new RangeError(
      `batchSize must be a positive integer, got ${batchSize}`
    );
  }
  const { volumeId, viewportId1, viewportId2, viewportId3 } = viewerIds;
  const viewportIds = [viewportId1, viewportId2, viewportId3];

  const loadRequests = await volume.getImageLoadRequests();
  const totalFrames = loadRequests.length;
  const middle = Math.floor(totalFrames / 2);

  // `left` walks down from the middle frame (inclusive). `right` walks up
  // from the frame just after it.
  for (
    let left = middle, right = middle + 1;
    left >= 0 || right < totalFrames;
    left -= batchSize, right += batchSize
  ) {
    // The left window is skipped once `left` is negative, so slice() never
    // receives a negative end index. slice() would treat such an index as an
    // offset from the end of the array.
    const leftBatch =
      left >= 0
        ? loadRequests.slice(Math.max(left - batchSize + 1, 0), left + 1)
        : [];
    const rightBatch = loadRequests.slice(
      right,
      Math.min(right + batchSize, totalFrames)
    );

    // Queue the left batch first, then the right one, in slice order.
    const requests = [
      ...generateRequests(leftBatch, volumeId),
      ...generateRequests(rightBatch, volumeId),
    ];
    await Promise.all(
      requests.map((request) => imageLoadPoolManagerRequestAsync(request))
    );

    for (const viewportId of viewportIds) {
      // The viewport may have been removed (e.g. the viewer unmounted) while
      // frames were being queued.
      renderingEngine.getViewport(viewportId)?.render();
    }
  }
}

/**
 * Ensures the image volume `viewerIds.volumeId` exists in the Cornerstone
 * cache, creating it from `imageIds` if needed, and binds it to the three
 * viewports.
 *
 * The callback applied to the viewports is
 * `setCtTransferFunctionFromImage(...)`, computed from the cornerstone image
 * that the volume returns for the middle slice. Frame loading is started
 * separately by the caller.
 * NOTE(review): frames may not be loaded yet when the middle-slice image is
 * read here — confirm the transfer function does not depend on pixel data.
 *
 * @param viewerIds - id of the volume and of the three target viewports.
 * @param imageIds - image ids making up the volume, in slice order.
 * @param renderingEngine - engine owning the viewports.
 * @returns the cached volume.
 */
async function createAndLoadVolume(
  viewerIds: ViewerIds,
  imageIds: string[],
  renderingEngine: IRenderingEngine
) {
  const { volumeId } = viewerIds;
  const targetViewportIds = [
    viewerIds.viewportId1,
    viewerIds.viewportId2,
    viewerIds.viewportId3,
  ];

  // Only define the volume once; later calls reuse the cached instance.
  const alreadyCached = Boolean(cache.getVolume(volumeId));
  if (!alreadyCached) {
    await volumeLoader.createAndCacheVolume(volumeId, { imageIds });
  }
  const volume = cache.getVolume(volumeId) as ImageVolume;

  const midIndex = Math.floor(imageIds.length / 2);
  const midImage = volume.getCornerstoneImage(imageIds[midIndex], midIndex);
  const transferFunctionCallback = setCtTransferFunctionFromImage(midImage);

  await setVolumesForViewports(
    renderingEngine,
    [{ volumeId, callback: transferFunctionCallback }],
    targetViewportIds
  );

  return volume;
}

/**
 * Returns the rendering engine `renderingEngineId`. If it does not exist yet,
 * Cornerstone is initialised and the engine is created. The engine's viewports
 * are then set to the given DOM elements and the viewer's tools are set up.
 *
 * @param viewerRefs - refs to the viewport elements; `viewerRefs[i]` backs
 *   `viewportInputConf[i]`.
 * @param id - viewer id forwarded to `setupTools`.
 * @param ids - rendering engine id and the per-viewport input configuration.
 * @returns the (possibly newly created) rendering engine.
 * @throws Error when a viewport's ref is missing or not yet mounted.
 */
async function setupRenderingEngine(
  viewerRefs: RefObject<HTMLDivElement>[],
  id: string,
  { renderingEngineId, viewportInputConf }: RenderingEngineIds
) {
  let renderingEngine = getRenderingEngine(renderingEngineId);
  if (!renderingEngine) {
    // TODO: consider moving Cornerstone initialisation to application startup.
    await initCornerstone();
    renderingEngine = new RenderingEngine(renderingEngineId);
  }

  // Resolve every element before touching the engine's viewports, so a
  // missing or unmounted ref fails with a clear message rather than an opaque
  // null dereference inside Cornerstone.
  const viewportInputs: PublicViewportInput[] = viewportInputConf.map(
    (v, i) => {
      const element = viewerRefs[i]?.current;
      if (!element) {
        throw new Error(
          `setupRenderingEngine: viewport element #${i} is not mounted`
        );
      }
      return { ...v, element };
    }
  );
  renderingEngine.setViewports(viewportInputs);

  // Add tools to Cornerstone3D
  setupTools(id);
  return renderingEngine;
}

/**
 * Configures the tool group for viewer `id`. It:
 * - reuses or creates the tool group and attaches the three viewports;
 * - registers the scroll, window-level, pan, zoom, sphere brush/eraser,
 *   crosshairs and segmentation-display tools;
 * - links the viewports through a VOI synchronizer;
 * - activates the default mouse bindings;
 * - applies the global labelmap display settings and the default brush size.
 *
 * Tools are only added when not already registered, and an existing tool
 * group or synchronizer with the same id is reused.
 *
 * @param id - viewer id used to derive all tool/viewport/synchronizer ids.
 */
function setupTools(id: string) {
  const ids = getIds(id);
  const { renderingEngineId, toolGroupId, voiSynchronizerId } = ids;
  const viewportIds = [ids.viewportId1, ids.viewportId2, ids.viewportId3];

  const toolGroup =
    ToolGroupManager.getToolGroup(toolGroupId) ||
    ToolGroupManager.createToolGroup(toolGroupId);

  // Crosshairs currently requires the viewports to be attached before the
  // tool is activated.
  viewportIds.forEach((viewportId) =>
    toolGroup.addViewport(viewportId, renderingEngineId)
  );

  const isRegistered = (toolName: string) =>
    Boolean(toolGroup._toolInstances[toolName]);

  // Navigation / display tools, added in this order when missing.
  [
    StackScrollMouseWheelTool.toolName,
    WindowLevelTool.toolName,
    PanTool.toolName,
    ZoomTool.toolName,
  ].forEach((toolName) => {
    if (!isRegistered(toolName)) toolGroup.addTool(toolName);
  });

  // BrushTool plus one named instance per strategy (sphere fill / sphere erase).
  if (!isRegistered(BrushTool.toolName)) {
    toolGroup.addTool(BrushTool.toolName);
    for (const instanceName of Object.values(brushInstanceNames)) {
      toolGroup.addToolInstance(instanceName, BrushTool.toolName, {
        activeStrategy: brushStrategies[instanceName],
      });
    }
  }

  // Keep window/level in sync across the three viewports.
  const voiSynchronizer =
    getSynchronizer(voiSynchronizerId) ||
    createVOISynchronizer(voiSynchronizerId);
  viewportIds.forEach((viewportId) =>
    voiSynchronizer.add({ renderingEngineId, viewportId })
  );

  // Crosshairs link the three viewports; reference-line behaviour comes from getIds.
  if (!isRegistered(CrosshairsTool.toolName)) {
    const {
      getReferenceLineColor,
      getReferenceLineControllable,
      getReferenceLineDraggableRotatable,
      getReferenceLineSlabThicknessControlsOn,
    } = ids;
    toolGroup.addTool(CrosshairsTool.toolName, {
      getReferenceLineColor,
      getReferenceLineControllable,
      getReferenceLineDraggableRotatable,
      getReferenceLineSlabThicknessControlsOn,
    });
  }

  if (!isRegistered(SegmentationDisplayTool.toolName)) {
    toolGroup.addTool(SegmentationDisplayTool.toolName);
  }

  // Default bindings: left = crosshairs, right = window/level,
  // middle = pan, Ctrl + middle = zoom.
  toolGroup.setToolActive(CrosshairsTool.toolName, {
    bindings: [{ mouseButton: MouseBindings.Primary }],
  });
  toolGroup.setToolActive(WindowLevelTool.toolName, {
    bindings: [{ mouseButton: MouseBindings.Secondary }],
  });
  toolGroup.setToolActive(PanTool.toolName, {
    bindings: [{ mouseButton: MouseBindings.Auxiliary }],
  });
  toolGroup.setToolActive(ZoomTool.toolName, {
    bindings: [
      {
        mouseButton: MouseBindings.Auxiliary,
        modifierKey: KeyboardBindings.Ctrl,
      },
    ],
  });
  // Wheel scrolling uses the tool's mouse-wheel callback, so no button binding is needed.
  toolGroup.setToolActive(StackScrollMouseWheelTool.toolName);

  // Global labelmap look: filled, no outline, fully opaque even when inactive.
  const globalConfig = cstSegmentation.config.getGlobalConfig();
  globalConfig.renderInactiveSegmentations = true;
  const labelmapConfig = globalConfig.representations.LABELMAP;
  labelmapConfig.renderOutline = false;
  labelmapConfig.fillAlpha = 1;
  labelmapConfig.fillAlphaInactive = 1;
  cstSegmentation.config.setGlobalConfig(globalConfig);

  utilities.segmentation.setBrushSizeForToolGroup(toolGroupId, 5);
}

/**
 * Loads a parsed DICOM SEG into Cornerstone and records where each segment lives.
 *
 * Steps:
 * - creates one derived seg volume and one labelmap segmentation per labelmap
 *   buffer;
 * - adds their representations to the tool group and applies colour LUTs;
 * - updates `viewerState.segmentsPerSeries` so each segment of the current
 *   series points at the representation that holds it.
 *
 * The function returns without doing anything when:
 * - a representation whose segmentation id contains `segmentationId` is
 *   already on the tool group;
 * - no SEG promise was passed;
 * - the SEG promise rejects (the error is logged).
 *
 * @param ids - volume, seg-volume, segmentation and tool-group ids. Buffer `i`
 *   becomes volume `${segVolumeId}_${i}` and segmentation `${segmentationId}_${i}`.
 * @param segInfoPromise - parsed SEG; may be undefined when there is no SEG.
 */
async function setupSegmentation(
  { volumeId, segVolumeId, segmentationId, toolGroupId }: SegmentationIds,
  segInfoPromise: Promise<SegInfo>
) {
  const loadedSegReps =
    cstSegmentation.state.getSegmentationRepresentations(toolGroupId);

  // NOTE(review): substring match. Any id that merely contains
  // `segmentationId` also counts as "already loaded".
  const isSegAlreadyLoaded = Boolean(
    loadedSegReps?.some((rep: ToolGroupSpecificRepresentation) =>
      rep.segmentationId.includes(segmentationId)
    )
  );

  if (isSegAlreadyLoaded || !segInfoPromise) return;

  let segInfo: SegInfo;
  let labelMapBuffers: SegInfo["labelmapBufferArray"];
  try {
    segInfo = await segInfoPromise;
    labelMapBuffers = segInfo.labelmapBufferArray;
  } catch (error) {
    // Best effort: a broken SEG must not break the viewer. Continuing without
    // buffers would only crash further down, so stop here.
    console.error("Error resolving segInfoPromise:", error);
    return;
  }

  // Create (or reuse) one seg volume per buffer extracted from the DICOM SEG.
  await Promise.all(
    labelMapBuffers.map(async (buffer, i) => {
      const bufferVolumeId = `${segVolumeId}_${i}`;
      if (cache.getVolume(bufferVolumeId)) return;
      const segVolume = await volumeLoader.createAndCacheDerivedVolume(
        volumeId,
        {
          volumeId: bufferVolumeId,
          targetBuffer: {
            sharedArrayBuffer: true,
            type: "Uint8Array",
          },
        }
      );
      injectBufferIntoSegVolume(segVolume as ImageVolume, buffer);
    })
  );

  // Register a segmentation for every seg volume that does not have one yet.
  // Ids are built from each buffer's original index *before* filtering, so the
  // `_i` suffix always matches the buffer/volume it was created from.
  const segmentations = labelMapBuffers
    .map((_buffer, i) => ({
      segmentationId: `${segmentationId}_${i}`,
      representation: {
        type: csToolsEnums.SegmentationRepresentations.Labelmap,
        data: {
          volumeId: `${segVolumeId}_${i}`,
        },
      },
    }))
    .filter(
      (candidate) =>
        !cstSegmentation.state.getSegmentation(candidate.segmentationId)
    );
  if (segmentations.length) {
    cstSegmentation.addSegmentations(segmentations);
  }

  // One labelmap representation per segmentation, in buffer order.
  const segmentationRepresentationUIDs =
    await cstSegmentation.addSegmentationRepresentations(
      toolGroupId,
      labelMapBuffers.map((_buffer, i) => ({
        segmentationId: `${segmentationId}_${i}`,
        type: csToolsEnums.SegmentationRepresentations.Labelmap,
      }))
    );

  generateColorLUT(
    toolGroupId,
    segInfo.segMetadata.data,
    segmentationRepresentationUIDs
  );
  segmentationRepresentationUIDs.forEach((segRepUID) =>
    cstSegmentation.config.color.setColorLUT(toolGroupId, segRepUID, 0)
  );

  // Point each segment of the current series at the representation whose
  // buffer contains it. Segments not found in any buffer fall back to the
  // first representation.
  const { instanceMetadata } = viewerState;
  if (instanceMetadata) {
    const currentSeriesUID = instanceMetadata.SeriesInstanceUID;
    viewerState.segmentsPerSeries = viewerState.segmentsPerSeries.map(
      (segment) => {
        if (segment.seriesUID !== currentSeriesUID) return segment;
        const segmentIndex = segment.SegmentNumber;
        const bufferIndex = segInfo.segmentsOnFrameArray.findIndex((arr) =>
          arr.flat().includes(segmentIndex)
        );
        const representationUID =
          bufferIndex < 0
            ? segmentationRepresentationUIDs[0]
            : segmentationRepresentationUIDs[bufferIndex];
        return { ...segment, representationUID };
      }
    );
  }
}
