import { ThreeEvent, useThree } from '@react-three/fiber';
import * as Sentry from '@sentry/react';
import { isEqual } from 'lodash';
import { useEffect, useRef, useState } from 'react';
import { DoubleSide, OrthographicCamera, Texture } from 'three';
import { segmentImage } from '@vizcom/shared/data-access/graphql';
import { sleep } from '@vizcom/shared/js-utils';
import {
  ToastIndicator,
  arrayBufferSha256,
  base64ToArrayBuffer,
  getGlobalCacheKeyArrayBuffer,
  imageDataToBlob,
  setGlobalCacheKeyArrayBuffer,
  useLastValue,
  useStableCallback,
  useStateWithRef,
} from '@vizcom/shared-ui-components';

import { screenPositionToLocal } from '../../../helpers';
import { LoadingPlaceholder } from '../../../utils/LoadingPlaceholder';
import { ActiveMask, MaskDisplayMode } from '../../selection/ActiveMask';
import { useSelectionApiStore } from '../../selection/useSelectionApi';
import { MaskOperation } from '../../selection/utils';
import { useWorkbenchStudioState } from '../../studioState';
import { BrushCursorPreview } from '../BrushEngine/BrushCursorPreview';
import {
  warmUpAutoSelectWorker,
  setAutoselectWorkerEmbedding,
  initiateAutoSelectWorker,
  autoSelect,
} from './auto-selection';
import { AutoSelectInputPointType } from './auto-selection/types';

/**
 * Length, in pixels, of the longest side SAM (Segment Anything) expects its
 * input image to be resized to. Used to derive the `samScale` handed to the
 * auto-select worker alongside the embedding.
 */
const SAM_LONGEST_SIDE_LENGTH = 1_024;

/**
 * Global cache namespace under which auto-selection embeddings are stored,
 * keyed by the SHA-256 hash of the composited image data.
 */
export const AutoSelectionCacheKey = 'vizcom:autoselection';

/**
 * Auto-selection tool overlay for the workbench drawing.
 *
 * When mounted (and whenever a composited-image getter becomes available) it
 * obtains an embedding of the composited drawing: from the global cache when
 * the image hash matches, otherwise via `segmentImage`. It then hands that
 * embedding to the auto-select worker. A loading placeholder is rendered until
 * this completes. Once initialized:
 * - hovering shows a preview of the mask the worker returns for the pointer
 *   position;
 * - releasing the pointer writes the mask into the selection canvas according
 *   to the current selection operation (Replace clears the canvas to black
 *   first, Remove draws with `destination-out`, anything else draws on top).
 *
 * @param drawingSize - `[width, height]` of the drawing; also the dimensions of
 *   the masks returned by the worker.
 * @param getCompositedImage - returns the composited drawing to segment; while
 *   it is `undefined` no embedding is computed.
 */
export const AutoSelection = ({
  drawingSize,
  getCompositedImage,
}: {
  drawingSize: [number, number];
  getCompositedImage: (() => ImageData) | undefined;
}) => {
  const selectionApiStore = useSelectionApiStore();
  const [error, setError] = useState<null | string>(null);
  const [initialized, setInitialized] = useState<boolean>();
  const lastAutoselectionHoverPosition = useRef<[number, number] | undefined>(
    undefined
  );

  const lastGetCompositedImage = useLastValue(getCompositedImage);
  // `lastGetCompositedImage` is a ref (presumably stable across renders), so on
  // its own it would never re-trigger the effect below. Depending on this flag
  // lets a getter that was `undefined` on mount start the embedding once it
  // appears, instead of leaving the loading placeholder up forever.
  const hasCompositedImage = getCompositedImage !== undefined;

  useEffect(() => {
    // Flipped by the cleanup so a run that outlives this effect (unmount,
    // StrictMode double-invoke, getter availability change) does not push a
    // stale embedding to the worker or update state.
    let cancelled = false;
    (async () => {
      if (!lastGetCompositedImage.current) {
        return;
      }
      // This async IIFE runs synchronously up to its first `await`, so this
      // call happens before `initiateAutoSelectWorker()` below.
      // NOTE(review): confirm that ordering is intended.
      warmUpAutoSelectWorker(); // do this out-of-band while waiting for segmentation request to return
      // Yield once (sleep(0)) so the portal for the previously active layer can
      // mount before the composited image is read. Without it, the composited
      // image won't contain that layer: it has already been unmounted, but its
      // portal isn't mounted yet. A tunnel-based workaround (tunnel-rat) caused
      // other race conditions. The real fix would be for the react-three-fiber
      // portal implementation to mount in `useLayoutEffect` rather than
      // `useEffect`.
      await sleep(0);
      const getImage = lastGetCompositedImage.current;
      if (cancelled || !getImage) {
        return;
      }
      const compositedImage = getImage();
      const hash = await arrayBufferSha256(compositedImage.data);
      let embedding: ArrayBuffer | null = null;
      if (hash) {
        embedding = await getGlobalCacheKeyArrayBuffer(
          AutoSelectionCacheKey,
          hash
        );
      }
      if (!embedding) {
        const blob = await imageDataToBlob(compositedImage);
        const jobOutput = await segmentImage(blob);
        embedding = await base64ToArrayBuffer(jobOutput.embedding);
        if (hash) {
          // Cache even if cancelled meanwhile: the next mount can reuse it.
          setGlobalCacheKeyArrayBuffer(AutoSelectionCacheKey, hash, embedding);
        }
      }

      if (cancelled) {
        return;
      }
      await setAutoselectWorkerEmbedding(embedding, {
        height: drawingSize[1],
        width: drawingSize[0],
        samScale: SAM_LONGEST_SIDE_LENGTH / Math.max(...drawingSize),
      });

      if (cancelled) {
        return;
      }
      setInitialized(true);
    })().catch((e) => {
      console.error('Error while generating auto selection embedding', e);
      Sentry.captureException(e);
      if (!cancelled) {
        setError('Failed to process image, please try again in a few minutes');
      }
    });
    initiateAutoSelectWorker();
    return () => {
      cancelled = true;
    };
  }, [lastGetCompositedImage, hasCompositedImage]);

  // On unmount, clear the hover position so any in-flight hover request that
  // resolves afterwards is discarded by the check in `refreshHoverMask`.
  useEffect(
    () => () => {
      lastAutoselectionHoverPosition.current = undefined;
    },
    []
  );

  const camera = useThree((state) => state.camera as OrthographicCamera);

  const loadingRef = useRef(false);
  const [lastHoverMask, setLastHoverMask, lastHoverMaskRef] = useStateWithRef<
    | undefined
    | {
        maskTexture: Texture;
        imageData: ImageData;
        position: [number, number];
      }
  >(undefined);

  // Each hover mask owns a three.js Texture. Release the previous one whenever
  // the hover mask is replaced or cleared, and the last one on unmount.
  useEffect(() => {
    const texture = lastHoverMask?.maskTexture;
    return () => {
      texture?.dispose();
    };
  }, [lastHoverMask]);

  /**
   * Computes the preview mask for the latest hover position, unless a request
   * is already in flight or the position matches the displayed mask. After a
   * successful update it re-invokes itself to catch positions that changed
   * while the worker was busy. Never rejects: failures are logged.
   */
  const refreshHoverMask = useStableCallback(async () => {
    if (!initialized || loadingRef.current) {
      return;
    }
    const hoverPosition = lastAutoselectionHoverPosition.current;
    if (isEqual(lastHoverMaskRef.current?.position, hoverPosition)) {
      return;
    }
    if (!hoverPosition) {
      setLastHoverMask(undefined);
      return;
    }

    loadingRef.current = true;
    let mask: ImageData | undefined;
    try {
      const rawMask = await autoSelect([
        {
          type: AutoSelectInputPointType.INCLUDE,
          x: hoverPosition[0],
          y: hoverPosition[1],
        },
      ]);
      mask = new ImageData(rawMask, drawingSize[0], drawingSize[1]);
    } catch (e) {
      // Callers fire and forget; log here instead of rejecting. Not retried
      // immediately, so a persistent failure doesn't spin in a loop.
      console.error('Error while computing auto selection hover mask', e);
    } finally {
      // Always release the in-flight flag, otherwise one failure would block
      // every future hover refresh.
      loadingRef.current = false;
    }

    if (!mask || !lastAutoselectionHoverPosition.current) {
      // Failed, or a pointerLeave arrived while loading: don't show the mask.
      return;
    }
    // Create the texture only once we know it will be displayed, so nothing is
    // allocated for a discarded result.
    const maskTexture = new Texture(mask);
    maskTexture.needsUpdate = true;
    setLastHoverMask({
      maskTexture,
      position: hoverPosition,
      imageData: mask,
    });
    void refreshHoverMask(); // the pointer may have moved while the worker was computing
  });

  const computingMaskOnClick = useRef(false);
  /**
   * Applies the auto-selection mask at the clicked point to the selection
   * canvas. Reuses the hover mask when it was computed within 5 px of the click.
   * Ignores clicks while a previous one is still being processed.
   */
  const onClick = async (event: ThreeEvent<MouseEvent>) => {
    if (!initialized || computingMaskOnClick.current) {
      return;
    }
    computingMaskOnClick.current = true;

    try {
      const localPosition = screenPositionToLocal(
        [event.x, event.y],
        camera,
        event.eventObject
      );
      const layerPosition = [
        (localPosition[0] + 0.5) * drawingSize[0],
        (-localPosition[1] + 0.5) * drawingSize[1],
      ];

      let maskImageData: ImageData;
      if (
        lastHoverMaskRef.current &&
        Math.hypot(
          lastHoverMaskRef.current.position[0] - layerPosition[0],
          lastHoverMaskRef.current.position[1] - layerPosition[1]
        ) < 5
      ) {
        // we already computed the mask for this point, just apply it instead of recomputing it again
        maskImageData = lastHoverMaskRef.current.imageData;
      } else {
        const rawMask = await autoSelect([
          {
            type: AutoSelectInputPointType.INCLUDE,
            x: layerPosition[0],
            y: layerPosition[1],
          },
        ]);
        maskImageData = new ImageData(rawMask, drawingSize[0], drawingSize[1]);
      }

      const maskImage = await createImageBitmap(maskImageData);
      try {
        selectionApiStore.getState().editSelectionCanvas((ctx) => {
          const operation =
            useWorkbenchStudioState.getState().selectionSettings.operation;
          if (operation === MaskOperation.Replace) {
            ctx.fillStyle = '#000000';
            ctx.fillRect(0, 0, ctx.canvas.width, ctx.canvas.height);
          }
          ctx.globalCompositeOperation = 'source-over';
          if (operation === MaskOperation.Remove) {
            ctx.globalCompositeOperation = 'destination-out';
          }
          ctx.drawImage(maskImage, 0, 0);
        });
      } finally {
        // Release the bitmap even if editing the canvas throws.
        maskImage.close();
      }
    } catch (e) {
      // Nothing awaits this handler's promise; report instead of leaving an
      // unhandled rejection.
      console.error('Error while applying auto selection mask', e);
      Sentry.captureException(e);
    } finally {
      computingMaskOnClick.current = false;
    }
  };

  const onPointerMove = (event: ThreeEvent<MouseEvent>) => {
    const [x, y] = screenPositionToLocal(
      [event.x, event.y],
      camera,
      event.eventObject
    );
    // Convert plane-local coordinates (centered, y up) to drawing pixels (top-left origin, y down).
    lastAutoselectionHoverPosition.current = [
      (x + 0.5) * drawingSize[0],
      (-y + 0.5) * drawingSize[1],
    ];
    void refreshHoverMask();
  };
  const onPointerLeave = () => {
    lastAutoselectionHoverPosition.current = undefined;
    setLastHoverMask(undefined);
  };

  if (!initialized && !error) {
    return (
      <LoadingPlaceholder
        backdropOpacity={0.3}
        backdropColor={'black'}
        scale={[drawingSize[0], drawingSize[1]]}
      />
    );
  }

  return (
    <>
      <BrushCursorPreview
        drawingSize={drawingSize}
        toolSize={0}
        toolAspect={1}
        toolAngle={0}
        color={'#000000'}
      />
      {/* Invisible plane covering the drawing that captures pointer events. */}
      <mesh
        scale={[drawingSize[0], drawingSize[1], 1]}
        onPointerUp={onClick}
        onPointerMove={onPointerMove}
        onPointerLeave={onPointerLeave}
      >
        <planeGeometry args={[1, 1, 1, 1]} />
        <meshBasicMaterial transparent opacity={0} side={DoubleSide} />
      </mesh>

      {lastHoverMask && (
        <ActiveMask
          drawingSize={drawingSize}
          maskTexture={lastHoverMask.maskTexture}
          mode={MaskDisplayMode.FILL}
        />
      )}

      {error && <ToastIndicator variant="warning" text={error} />}
    </>
  );
};
