import { ThreeEvent, useThree } from '@react-three/fiber';
import { useDrag } from '@use-gesture/react';
import { useState } from 'react';
import { Group, OrthographicCamera, Vector3 } from 'three';

import { screenPositionToWorld } from '../../../helpers';
import { WarpPoint } from './warp';

// Scratch vector reused by the warp drag handler to receive the resizer
// group's world scale, so no Vector3 is allocated per pointer event.
const v3 = new Vector3(0, 0, 0);

/**
 * Sign pairs for the four corners of a unit square, listed in the order
 * (-1, -1), (1, -1), (1, 1), (-1, 1).
 *
 * NOTE(review): presumably bottom-left, bottom-right, top-right, top-left in a
 * y-up frame — confirm against the code that consumes these directions.
 */
export const initialWarpPointDirections: WarpPoint[] = [
  [-1, -1], [1, -1], [1, 1], [-1, 1],
];

/**
 * Drag-to-warp state for a layer's warp handles.
 *
 * Holds a local copy of the warp points that is updated live while a handle is
 * dragged, and reports the resulting points through `onGestureEnd` once the
 * drag finishes. Taps are ignored (no update, no report).
 *
 * @param baseWarpPoints - Points used to seed local state on first render and
 *   restored by `resetWarp`. Later changes to this argument do not re-seed the
 *   state on their own.
 * @param onGestureEnd - Called at the end of every non-tap drag with the
 *   points computed for that drag.
 * @param resizerGroupRef - Group whose world scale is used to convert
 *   world-space pointer deltas into the warp points' units.
 * @returns `bindWarpHandle` (gesture binder; its argument is the array of
 *   point indices the handle moves), the current `warpPointsTransforms`, its
 *   setter, and `resetWarp`, which restores `baseWarpPoints`.
 */
export const useLayerWarpTransform = (
  baseWarpPoints: WarpPoint[],
  onGestureEnd: (warpPoints: WarpPoint[]) => void,
  resizerGroupRef: React.MutableRefObject<Group>
) => {
  /** Per-drag data carried between handler calls through use-gesture's `memo`. */
  type DragMemo = {
    initialPointerPosition: ReturnType<typeof screenPositionToWorld>;
    initialPoints: WarpPoint[];
    // Points computed for the most recent movement of this drag. Reported at
    // drag end so the callback does not depend on React having re-rendered
    // with the latest state yet.
    latestPoints: WarpPoint[];
  };

  const camera = useThree((s) => s.camera as OrthographicCamera);
  // Lazy initializer: copy baseWarpPoints only once, not on every render.
  const [warpPointsTransforms, setWarpPointsTransforms] = useState<WarpPoint[]>(
    () => [...baseWarpPoints]
  );

  const bindWarpHandle = useDrag<ThreeEvent<PointerEvent>>((gesture) => {
    gesture.event.stopPropagation();

    // A tap is not a warp: leave the points untouched and do not report.
    if (gesture.tap) {
      return;
    }

    const memo = gesture.memo as DragMemo | undefined;

    if (gesture.last) {
      // `warpPointsTransforms` in this closure can lag behind the last
      // setWarpPointsTransforms call, so prefer the points tracked in memo.
      onGestureEnd(memo ? memo.latestPoints : warpPointsTransforms);
      return;
    }

    const [pointsIndices] = gesture.args as [number[]];
    const pointerPosition = screenPositionToWorld(
      [gesture.event.nativeEvent.clientX, gesture.event.nativeEvent.clientY],
      camera
    );

    // First event of the drag: record the starting pointer and points.
    if (!memo) {
      const initialPoints = [...warpPointsTransforms];
      const started: DragMemo = {
        initialPointerPosition: pointerPosition,
        initialPoints,
        latestPoints: initialPoints,
      };
      return started;
    }

    const { initialPointerPosition, initialPoints } = memo;

    // Pointer deltas are in world units; dividing by the group's world scale
    // rescales them to the group's (scaled) units.
    resizerGroupRef.current.getWorldScale(v3);
    if (v3.x === 0 || v3.y === 0) {
      // A collapsed group would produce Infinity/NaN deltas; skip this update.
      return memo;
    }

    const dx = (pointerPosition[0] - initialPointerPosition[0]) / v3.x;
    const dy = (pointerPosition[1] - initialPointerPosition[1]) / v3.y;

    // Offsets are always applied to the drag's starting points, not
    // accumulated frame to frame.
    const latestPoints = [...initialPoints];
    pointsIndices.forEach((index) => {
      latestPoints[index] = [
        latestPoints[index][0] + dx,
        latestPoints[index][1] + dy,
      ];
    });
    setWarpPointsTransforms(latestPoints);

    const updated: DragMemo = {
      initialPointerPosition,
      initialPoints,
      latestPoints,
    };
    return updated;
  });

  return {
    bindWarpHandle,
    warpPointsTransforms,
    setWarpPointsTransforms,
    resetWarp: () => {
      setWarpPointsTransforms([...baseWarpPoints]);
    },
  };
};
