import { shaderMaterial } from '@react-three/drei';
import { extend } from '@react-three/fiber';
import { Matrix3, Texture, DoubleSide } from 'three';

import { cross2d, straightAlphaBlending } from '../../../../lib/glsl';
import { VizcomRenderingOrderEntry } from '../../../utils/threeRenderingOrder';

/**
 * Renders a plane covering the whole drawing. The plane shows `layerImage` with
 * the region selected by `selectionTexture` moved by a transform.
 *
 * The material receives the *inverse* of that transform. For every output pixel,
 * it looks up where that pixel's content came from.
 *
 * @param drawingSize - Drawing width/height. Also used as the plane size.
 * @param selectionTexture - Selection mask texture.
 * @param layerImage - Layer contents. NOTE(review): typed as possibly undefined
 *   but forwarded with a non-null assertion. Confirm that rendering with an unset
 *   map is acceptable, or guard before rendering.
 * @param inverseTransform - Maps output positions back to source positions.
 */
export function MaskedTransformMesh(props: {
  drawingSize: [number, number];
  selectionTexture: Texture;
  layerImage: Texture | undefined;
  inverseTransform: Matrix3;
}) {
  const { drawingSize, selectionTexture, layerImage, inverseTransform } = props;
  const [width, height] = drawingSize;

  // Draw above the base layer (zIndex 1) in the app's custom rendering order.
  const renderingOrder = [{ zIndex: 1 } satisfies VizcomRenderingOrderEntry];

  return (
    <mesh userData={{ vizcomRenderingOrder: renderingOrder }}>
      <planeGeometry args={[width, height]} />

      <maskedTransformMaterial
        u_map={layerImage!}
        u_mask={selectionTexture}
        u_transform={inverseTransform}
        u_drawingSize={drawingSize}
        transparent
        depthTest={false}
        side={DoubleSide}
      />
    </mesh>
  );
}

/**
 * Shader material that draws a transformed selection over the layer it was
 * cut from.
 *
 * - Background: the layer with the selected region removed
 *   (alpha scaled by `1 - mask`).
 * - Foreground: the layer sampled at the inverse-transformed position. It is
 *   kept only where the mask, sampled at that same position, is set, and only
 *   when that position lies inside the drawing.
 *
 * Uniforms:
 * - `u_map`: layer image.
 * - `u_mask`: selection mask; only the red channel is read.
 * - `u_transform`: maps a plane-local position to the position its content
 *   came from (the inverse of the selection's transform).
 * - `u_drawingSize`: drawing width/height in plane-geometry units. Used to turn
 *   positions into UVs; the plane is assumed to be centered on the origin.
 */
export const MaskedTransformMaterialImpl = shaderMaterial(
  {
    u_map: null,
    u_mask: null,
    u_transform: new Matrix3(),
    u_drawingSize: null,
  },
  `
  varying vec2 v_originalPos;
  varying vec2 v_uv;

  void main() {
    v_originalPos = position.xy;
    v_uv = uv;
    gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
  }`,
  `varying vec2 v_originalPos;
  varying vec2 v_uv;

  uniform sampler2D u_map;
  uniform sampler2D u_mask;
  uniform mat3 u_transform;
  uniform vec2 u_drawingSize;

  ${cross2d}
  ${straightAlphaBlending}

  void main() {
    // Background: the layer with the selected pixels cut out.
    vec4 layer = texture2D(u_map, v_uv);
    float mask = texture2D(u_mask, v_uv).r;
    vec4 maskedBackground = vec4(layer.rgb, layer.a * (1.0 - mask));

    // Map this fragment back to where its content came from.
    vec2 transformedPosition = (u_transform * vec3(v_originalPos, 1.0)).xy;
    // The plane is centered on the origin, so shift by half a drawing to land in [0, 1].
    vec2 transformedUV = transformedPosition / u_drawingSize + 0.5;

    // 1.0 when transformedUV lies in [0, 1]^2 (bounds inclusive), else 0.0.
    // Computed with step() instead of an if-branch so that every texture lookup
    // happens in uniform control flow. Implicit derivatives (and therefore
    // mipmap level selection) are undefined inside fragment-dependent branches.
    vec2 aboveMin = step(vec2(0.0), transformedUV);
    vec2 belowMax = step(transformedUV, vec2(1.0));
    float inBounds = aboveMin.x * aboveMin.y * belowMax.x * belowMax.y;

    vec4 transformedLayer = texture2D(u_map, transformedUV);
    float transformedMask = texture2D(u_mask, transformedUV).r * inBounds;
    vec4 maskedTransformedLayer = vec4(transformedLayer.rgb, transformedLayer.a * transformedMask);

    gl_FragColor = straightAlphaBlending(maskedBackground, maskedTransformedLayer);
  }`
);

/** Uniform props exposed by `MaskedTransformMaterialImpl`. The names match the shader uniforms. */
type MaskedTransformUniforms = {
  u_map: Texture;
  u_mask: Texture;
  u_transform: Matrix3;
  u_drawingSize: [number, number];
};

/** JSX props for `<maskedTransformMaterial>`: the standard shader-material props plus the uniforms. */
type MaskedTransformMaterialType = JSX.IntrinsicElements['shaderMaterial'] & MaskedTransformUniforms;

declare global {
  // eslint-disable-next-line @typescript-eslint/no-namespace
  namespace JSX {
    interface IntrinsicElements {
      /** Made available at runtime by the `extend` registration below. */
      maskedTransformMaterial: MaskedTransformMaterialType;
    }
  }
}

// Add the material to react-three-fiber's catalogue so the
// `<maskedTransformMaterial>` element typed above resolves at runtime.
extend({ MaskedTransformMaterial: MaskedTransformMaterialImpl });
