import { BufferAttribute, Matrix3, Vector2, Texture } from 'three';
import { useRef, useEffect, useMemo } from 'react';

import { extend } from '@react-three/fiber';
import { shaderMaterial } from '@react-three/drei';
import { cross2d, straightAlphaBlending } from '../../../../lib/glsl';

export type WarpPoint = [number, number];

/** Index buffer splitting the quad (vertices 0..3) into triangles (0, 1, 3) and (1, 2, 3). */
const QUAD_INDICES = new Uint8Array([0, 1, 3, 1, 2, 3]);

/** Flattens 2D corners into an interleaved xyz position buffer with z = 0. */
function toPositionArray(points: WarpPoint[]): Float32Array {
  return new Float32Array(points.flatMap(([x, y]) => [x, y, 0]));
}

/**
 * Quad geometry whose vertex positions are the given warp corners.
 *
 * @param points - the quad's corners, in the order used by the index buffer
 *   (vertices 0..3). The index buffer only references four vertices, so
 *   exactly four points are expected.
 */
export function WarpGeometry({ points }: { points: WarpPoint[] }) {
  const posRef = useRef<BufferAttribute>(null!);
  // Build the position buffer once per change of `points` instead of on every render.
  const positions = useMemo(() => toPositionArray(points), [points]);

  useEffect(() => {
    posRef.current.array = positions;
    // Bump the attribute version so three.js re-uploads the buffer to the GPU.
    posRef.current.needsUpdate = true;
  }, [positions]);

  return (
    <bufferGeometry>
      <bufferAttribute
        ref={posRef}
        attach="attributes-position"
        array={positions}
        count={points.length}
        itemSize={3}
      />
      <bufferAttribute
        attach="index"
        array={QUAD_INDICES}
        count={QUAD_INDICES.length}
        itemSize={1}
      />
    </bufferGeometry>
  );
}

/**
 * GLSL source defining `vec2 invBilinear(p, a, b, c, d)`.
 *
 * Given a point `p` and quad corners a, b, c, d, it returns the (u, v) for which
 * `a + (b-a)*u + (d-a)*v + (a-b+c-d)*u*v == p`, i.e. the inverse of bilinear
 * interpolation over the quad. It returns `vec2(-1.0)` when the quadratic has
 * no real solution.
 *
 * Adapted from https://iquilezles.org/articles/ibilinear/. The shader that
 * includes this string must declare `cross2d` before it.
 *
 * Change from the article: once `v` is known, `u` is recovered from
 * `h - f*v = u * (e + g*v)` by dividing along whichever axis of `e + g*v` has
 * the larger magnitude. The article always divides by the x component. That
 * blows up (division by ~0) for quads whose a->b edge is vertical, e.g. a quad
 * rotated by 90 degrees.
 */
const invBilinear = /*glsl */ `
  float invBilinearU(vec2 h, vec2 e, vec2 f, vec2 g, float v)
  {
      vec2 num = h - f*v;
      vec2 den = e + g*v;
      return abs(den.x) > abs(den.y) ? num.x/den.x : num.y/den.y;
  }

  vec2 invBilinear(vec2 p, vec2 a, vec2 b, vec2 c, vec2 d)
  {
      vec2 res = vec2(-1.0);

      vec2 e = b-a;
      vec2 f = d-a;
      vec2 g = a-b+c-d;
      vec2 h = p-a;

      float k2 = cross2d( g, f );
      float k1 = cross2d( e, f ) + cross2d( h, g );
      float k0 = cross2d( h, e );

      // if edges are parallel, this is a linear equation
      if( abs(k2)<0.001 ||  length(g) < 0.01)
      {
          float v = -k0/k1;
          res = vec2( invBilinearU(h, e, f, g, v), v );
      }
      // otherwise, it's a quadratic
      else
      {
          float w = k1*k1 - 4.0*k0*k2;
          if( w<0.0 ) return vec2(-1.0);
          w = sqrt( w );

          float ik2 = 0.5/k2;
          float v = (-k1 - w)*ik2;
          float u = invBilinearU(h, e, f, g, v);

          // first root outside the quad: try the other one
          if( u<0.0 || u>1.0 || v<0.0 || v>1.0 )
          {
            v = (-k1 + w)*ik2;
            u = invBilinearU(h, e, f, g, v);
          }
          res = vec2( u, v );
      }

      return res;
  }
`;

/**
 * Shader material that renders `u_map` warped onto the quad described by
 * `u_points`.
 *
 * The vertex shader forwards each vertex's xy position. The fragment shader:
 * 1. inverts the bilinear map of the corners `u_points[0..3]` at that position
 *    to get quad-local (u, v);
 * 2. transforms (u, v) by `u_textureMatrix`;
 * 3. samples `u_map` at the result.
 */
export const WarpMaterialImpl = shaderMaterial(
  {
    u_map: null,
    // Unit-square placeholder corners. three.js uploads a `vec2[4]` uniform by
    // reading one Vector2 per slot, so an empty-array default would throw if a
    // frame rendered before the real corners were assigned.
    u_points: [
      new Vector2(0, 0),
      new Vector2(1, 0),
      new Vector2(1, 1),
      new Vector2(0, 1),
    ],
    u_textureMatrix: new Matrix3(),
  },
  `
  varying vec2 v_originalPos;

  void main() {
    v_originalPos = position.xy;
    gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
  }`,
  `varying vec2 v_originalPos;

  uniform vec2 u_points[4];
  uniform sampler2D u_map;
  uniform mat3 u_textureMatrix;

  ${cross2d}
  ${invBilinear}

  void main() {
    vec2 mappedUV = invBilinear(v_originalPos, u_points[0], u_points[1], u_points[2], u_points[3]);
    gl_FragColor = texture2D(u_map, (u_textureMatrix * vec3(mappedUV, 1.0)).xy);
  }`
);

/** Uniform props accepted by `<warpMaterial>` on top of the standard shaderMaterial props. */
interface WarpMaterialUniformProps {
  u_map: Texture;
  u_points: Vector2[];
  u_textureMatrix: Matrix3;
}

type WarpMaterialType = JSX.IntrinsicElements['shaderMaterial'] &
  WarpMaterialUniformProps;

declare global {
  // eslint-disable-next-line @typescript-eslint/no-namespace
  namespace JSX {
    interface IntrinsicElements {
      warpMaterial: WarpMaterialType;
    }
  }
}

// Register the material class with @react-three/fiber; the JSX typing above
// declares the matching `<warpMaterial>` element.
extend({ WarpMaterial: WarpMaterialImpl });

/**
 * Shader material that moves the masked part of a layer through a quad warp.
 *
 * The fragment shader combines two layers with the project's
 * `straightAlphaBlending` helper:
 * - background: `u_map` sampled at the geometry's `uv`, with its alpha scaled
 *   by `1 - mask` (mask = red channel of `u_mask` at the same `uv`), so the
 *   masked region is cut out;
 * - foreground: the fragment position minus `u_translation` is inverted
 *   through the quad `u_points[0..3]`, then transformed by `u_textureMatrix`.
 *   `u_map` is sampled there, and its alpha is scaled by `u_mask` at the same
 *   coordinate. That mask is treated as 0 outside [0, 1].
 *
 * NOTE(review): the fragment shader declares `u_time` and `u_drawingSize`, but
 * `main()` never reads them and no defaults are registered for them here.
 */
export const MaskedWarpMaterialImpl = shaderMaterial(
  {
    u_map: null,
    u_mask: null,
    // Unit-square placeholder corners. three.js uploads a `vec2[4]` uniform by
    // reading one Vector2 per slot, so an empty-array default would throw if a
    // frame rendered before the real corners were assigned.
    u_points: [
      new Vector2(0, 0),
      new Vector2(1, 0),
      new Vector2(1, 1),
      new Vector2(0, 1),
    ],
    // Zero translation instead of `null`: three.js's vec2 uniform setter reads
    // `value.x`, which throws on a null value.
    u_translation: [0, 0],
    u_textureMatrix: new Matrix3(),
  },
  `
  varying vec2 v_originalPos;
  varying vec2 v_uv;

  void main() {
    v_originalPos = position.xy;
    v_uv = uv;
    gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
  }`,
  `varying vec2 v_originalPos;
  varying vec2 v_uv;

  uniform vec2 u_points[4];
  uniform sampler2D u_map;
  uniform sampler2D u_mask;
  uniform mat3 u_textureMatrix;
  uniform vec2 u_translation;
  uniform float u_time;
  uniform vec2 u_drawingSize;

  ${cross2d}
  ${invBilinear}
  ${straightAlphaBlending}

  void main() {
    vec4 layer = texture2D(u_map, v_uv);
    float mask = texture2D(u_mask, v_uv).r;
    vec4 maskedBackground = vec4(layer.rgb, layer.a * (1.0 - mask));

    vec2 mappedUV = invBilinear(v_originalPos - u_translation, u_points[0], u_points[1], u_points[2], u_points[3]);
    vec2 transformedUV = (u_textureMatrix * vec3(mappedUV, 1.0)).xy;
    vec4 transformedLayer = texture2D(u_map, transformedUV);
    float transformedMask = 0.0;
    if(transformedUV.x >= 0.0 && transformedUV.x <= 1.0 &&
       transformedUV.y >= 0.0 && transformedUV.y <= 1.0
    ){
        transformedMask = texture2D(u_mask, transformedUV).r;
    }
    vec4 maskedTransformedLayer = vec4(transformedLayer.rgb, transformedLayer.a * transformedMask);

    gl_FragColor = straightAlphaBlending(maskedBackground, maskedTransformedLayer);
  }`
);

/** Uniform props accepted by `<maskedWarpMaterial>` on top of the standard shaderMaterial props. */
interface MaskedWarpMaterialUniformProps {
  u_map: Texture;
  u_mask: Texture;
  u_points: Vector2[];
  u_textureMatrix: Matrix3;
  u_translation: [number, number];
}

type MaskedWarpMaterialType = JSX.IntrinsicElements['shaderMaterial'] &
  MaskedWarpMaterialUniformProps;

declare global {
  // eslint-disable-next-line @typescript-eslint/no-namespace
  namespace JSX {
    interface IntrinsicElements {
      maskedWarpMaterial: MaskedWarpMaterialType;
    }
  }
}

// Register the material class with @react-three/fiber; the JSX typing above
// declares the matching `<maskedWarpMaterial>` element.
extend({ MaskedWarpMaterial: MaskedWarpMaterialImpl });
