import { ThreeEvent, useThree } from '@react-three/fiber';
import { useDrag } from '@use-gesture/react';
import { Group, OrthographicCamera, Vector3 } from 'three';
import { useEffect, useState } from 'react';
import { screenPositionToWorld } from '../../../helpers';
import { useStableCallback } from '@vizcom/shared-ui-components';
import { WarpPoint } from './warp';

/** Shared scratch vector: receives the resizer group's world scale so drag events don't allocate a new Vector3. */
const v3 = new Vector3(0, 0, 0);

/**
 * Sign pairs for the four corners of an origin-centred rectangle, in cyclic
 * order starting from (-1, -1). `resetWarp` multiplies them by half the layer
 * width/height to rebuild the unwarped corner positions.
 */
export const initialWarpPointDirections: WarpPoint[] = [
  [-1, -1], [1, -1],
  [1, 1], [-1, 1],
];

/** Per-drag state carried between use-gesture events through `gesture.memo`. */
type WarpDragMemo = {
  initialPointerPosition: [number, number];
  initialPoints: WarpPoint[];
  /** Points produced by the most recent move event of this drag, if any. */
  latestPoints?: WarpPoint[];
};

/**
 * Returns a copy of `currentPoints` in which every index listed in `indices`
 * is replaced by the matching point of `initialPoints` offset by (`dx`, `dy`).
 * Indices not listed are copied from `currentPoints` unchanged. Neither input
 * array is mutated.
 */
function offsetWarpPoints(
  currentPoints: WarpPoint[],
  initialPoints: WarpPoint[],
  indices: number[],
  dx: number,
  dy: number
): WarpPoint[] {
  const next = [...currentPoints];
  for (const index of indices) {
    const point = initialPoints[index];
    next[index] = [point[0] + dx, point[1] + dy];
  }
  return next;
}

/**
 * Local, drag-editable copy of a layer's warp points.
 *
 * @param baseWarpPoints - Points the local state starts from; the local state
 *   is reset to this value whenever its reference changes.
 * @param onGestureEnd - Called once when a non-tap drag ends, with the points
 *   the drag produced.
 * @param resizerGroupRef - Group whose world scale divides the world-space
 *   pointer delta before it is applied to the points.
 * @param layerWidth - Used by `resetWarp` to rebuild the unwarped corners.
 * @param layerHeight - Used by `resetWarp` to rebuild the unwarped corners.
 * @returns `bindWarpHandle` (pass the indices of the points a handle moves as
 *   its first argument), the current `warpPoints`, the raw `setWarpPoints`
 *   setter, and `resetWarp`, which restores the rectangle defined by
 *   `initialWarpPointDirections` scaled to the layer size.
 */
export const useLayerWarpTransform = (
  baseWarpPoints: WarpPoint[],
  onGestureEnd: (warpPoints: WarpPoint[]) => void,
  resizerGroupRef: React.MutableRefObject<Group>,
  layerWidth: number,
  layerHeight: number
) => {
  // NOTE(review): the scene camera is assumed to be orthographic — confirm.
  const camera = useThree((s) => s.camera as OrthographicCamera);
  const [warpPoints, setWarpPoints] = useState<WarpPoint[]>(baseWarpPoints);

  // Re-sync the local preview whenever the caller supplies new points.
  useEffect(() => {
    setWarpPoints(baseWarpPoints);
  }, [baseWarpPoints]);

  const bindWarpHandle = useDrag<ThreeEvent<PointerEvent>>((gesture) => {
    gesture.event.stopPropagation();

    const memo = gesture.memo as WarpDragMemo | undefined;

    if (gesture.tap) {
      // Taps are not warps. Move events before a short release may already
      // have updated the preview, though, and nothing will commit them, so
      // roll the preview back to where this gesture started.
      if (memo?.latestPoints) {
        setWarpPoints(memo.initialPoints);
      }
      return;
    }
    if (gesture.last) {
      // Commit what this gesture last computed instead of the closure's
      // `warpPoints`, which may predate the final move's pending setState.
      onGestureEnd(memo?.latestPoints ?? warpPoints);
      return;
    }

    const [pointsIndices] = gesture.args as [number[]];
    const pointerPosition = screenPositionToWorld(
      [gesture.event.nativeEvent.clientX, gesture.event.nativeEvent.clientY],
      camera
    );

    if (!memo) {
      // First event of the drag: snapshot the starting pointer and points.
      return {
        initialPointerPosition: pointerPosition,
        initialPoints: [...warpPoints],
      };
    }

    const { initialPointerPosition, initialPoints } = memo;

    // Convert the world-space pointer delta into the resizer group's scale.
    resizerGroupRef.current.getWorldScale(v3);
    const dx = (pointerPosition[0] - initialPointerPosition[0]) / v3.x;
    const dy = (pointerPosition[1] - initialPointerPosition[1]) / v3.y;

    const newPoints = offsetWarpPoints(
      warpPoints,
      initialPoints,
      pointsIndices,
      dx,
      dy
    );
    setWarpPoints(newPoints);

    return {
      initialPointerPosition,
      initialPoints,
      latestPoints: newPoints,
    };
  });

  return {
    bindWarpHandle,
    warpPoints,
    setWarpPoints,
    resetWarp: () => {
      setWarpPoints(
        initialWarpPointDirections.map((p) => [
          (p[0] * layerWidth) / 2,
          (p[1] * layerHeight) / 2,
        ])
      );
    },
  };
};
