import { useEffect, useMemo } from 'react';
import { useThree } from '@react-three/fiber';
import {
  BackSide,
  Color,
  ColorRepresentation,
  DoubleSide,
  GLSL3,
  Matrix3,
  Mesh,
  NearestFilter,
  OrthographicCamera,
  PlaneGeometry,
  RawShaderMaterial,
  Vector2,
  WebGLRenderer,
  WebGLRenderTarget,
} from 'three';
import { Symmetry } from '../../studioState';
import {
  cross2d,
  rotate,
  sdBox,
  sdCircle,
  sdEllipse,
  sdSegment,
} from '../../../../lib/glsl';
import { getSymmetryMatrix } from '../../lib/symmetry';

/**
 * Kind of primitive the shape renderer can draw.
 *
 * Kept as a numeric enum on purpose: the values are interpolated directly
 * into the GLSL source (`type == ${ShapeType.line}`) and compared against the
 * `int type` uniform, so they must stay plain integers.
 */
export enum ShapeType {
  ellipse = 0,
  rectangle,
  line,
}

/**
 * GLSL3 raw shader material that rasterises a single shape (ellipse,
 * rectangle or line segment) onto a quad stretched over the shape's bounds.
 *
 * Uniforms, as consumed by the shaders below:
 * - `start` / `end`: line endpoints, or opposite corners of the shape's box.
 * - `rotation`: box rotation in radians; also the reference angle that line
 *   snapping is measured from.
 * - `size`: line width, or outline thickness when `isRing` is 1.
 * - `snap`: 1 forces ellipses/rectangles to equal radii and lines to 45° steps.
 * - `type`: a `ShapeType` value.
 * - `isRing`: 1 draws only an outline (ellipse/rectangle only).
 * - `color`: output RGB.
 * - `symmetryMatrix`: applied to the vertex positions (identity by default).
 * - `allowCrossingSymmetryAxis` / `symmetryOrigin` / `symmetryDirection`: when
 *   crossing is not allowed, fragments whose cross product against the axis
 *   is <= 0 are discarded.
 */
class ShapeRendererMaterial extends RawShaderMaterial {
  // Vertex stage: computes the shape's bounding quad (rotated for boxes,
  // axis-aligned around the segment for lines) and applies symmetryMatrix.
  private static readonly vertexShader = /* glsl */ `
    precision mediump float;
    precision mediump int;

    #define PI 3.141592653589793

    uniform vec2 start;
    uniform vec2 end;
    uniform mat3 symmetryMatrix;
    uniform float rotation;
    uniform float size;
    uniform int snap;
    uniform int type;

    uniform mat4 projectionMatrix;

    in vec3 position;
    flat out vec2 vStart;
    flat out vec2 vEnd;
    out vec2 vUv;
    out vec2 vPosition;
    flat out vec2 vRadius;

    ${rotate}

    void main()	{

      if (type == ${ShapeType.line}) {
        vec2 s = start;
        vec2 e = end;
        if (snap == 1) {
          const float increment = (PI * 2.0) / 8.0;
          float angle = atan(e.y - s.y, e.x - s.x) - rotation;
          float snappedAngle = round(angle / increment) * increment + rotation;
          e = s + vec2(cos(snappedAngle), sin(snappedAngle)) * distance(s, e);
        }
        vec2 origin = (s + e) * 0.5;
        vec2 radius = abs(s - e) * 0.5 + size * 0.5 + 1.0;
        vStart = s - origin;
        vEnd = e - origin;
        vUv = position.xy * radius;
        vPosition = (symmetryMatrix * vec3(vUv + origin, 1.0)).xy;
        gl_Position = projectionMatrix * vec4(vPosition, 0.0, 1.0);
        return;
      }

      vec2 origin = (start + end) * 0.5;
      vec2 s = rotate(start - origin, -rotation);
      vec2 e = rotate(end - origin, -rotation);
      vec2 radius = abs(s - e) * 0.5;
      if (snap == 1) {
        vec2 r = vec2(min(radius.x, radius.y));
        origin += rotate((radius - r) * sign(s - e), rotation);
        radius = r;
      }
      vUv = position.xy * radius;
      vRadius = radius;


      vPosition = (symmetryMatrix * vec3(rotate(vUv, rotation) + origin, 1.0)).xy;
      gl_Position = projectionMatrix * vec4(vPosition, 0.0, 1.0);
    }
  `;

  // Fragment stage: evaluates the signed distance to the shape and turns it
  // into an antialiased alpha (smoothstep over [-1, 1]); optionally rejects
  // fragments on the wrong side of the symmetry axis.
  private static readonly fragmentShader = /* glsl */ `
    precision mediump float;
    precision mediump int;

    uniform vec3 color;
    uniform float size;
    uniform int type;
    uniform int isRing;
    uniform mat3 symmetryMatrix;
    uniform bool allowCrossingSymmetryAxis;
    uniform vec2 symmetryOrigin;
    uniform vec2 symmetryDirection;

    flat in vec2 vStart;
    flat in vec2 vEnd;
    in vec2 vUv;
    in vec2 vPosition;
    flat in vec2 vRadius;
    out vec4 outputColor;

    ${sdBox}
    ${sdCircle}
    ${sdEllipse}
    ${sdSegment}
    ${cross2d}

    void main()	{
      //is the fragment left or right of the symmetry axis
      float crossingState = cross2d(vPosition - symmetryOrigin, symmetryDirection);
      if(!allowCrossingSymmetryAxis && crossingState <= 0.0){
        discard;
      }
      float dist;
      if (type == ${ShapeType.line}) {
        dist = sdSegment(vUv, vStart, vEnd) - size * 0.5;
      } else {
        if (type == ${ShapeType.rectangle}) {
          dist = sdBox(vUv, vRadius);
        } else if (abs(vRadius.x - vRadius.y) < 0.01) {
          dist = sdCircle(vUv, vRadius.x);
        } else {
          dist = sdEllipse(vUv, vRadius);
        }
        if (isRing == 1) {
          dist = abs(dist + size * 0.5 + 1.0) - size * 0.5;
        }
      }
      float alpha = 1.0 - smoothstep(-1.0, 1.0, dist);
      outputColor = clamp(vec4(color, alpha), 0.0, 1.0);
    }
  `;

  /** Builds a fresh uniform set with neutral defaults (identity symmetry, vertical axis). */
  private static createUniforms() {
    return {
      // Shape geometry.
      start: { value: new Vector2() },
      end: { value: new Vector2() },
      rotation: { value: 0 },
      size: { value: 0 },
      snap: { value: 0 },
      type: { value: 0 },
      isRing: { value: 0 },
      // Appearance.
      color: { value: new Color() },
      // Symmetry handling.
      symmetryMatrix: { value: new Matrix3() },
      allowCrossingSymmetryAxis: { value: false },
      symmetryOrigin: { value: new Vector2(0, 0) },
      symmetryDirection: { value: new Vector2(0, 1) },
    };
  }

  constructor() {
    super({
      glslVersion: GLSL3,
      vertexShader: ShapeRendererMaterial.vertexShader,
      fragmentShader: ShapeRendererMaterial.fragmentShader,
      uniforms: ShapeRendererMaterial.createUniforms(),
      side: DoubleSide,
      transparent: true,
    });
  }
}

/**
 * Quad used to draw one shape: a 2×2 plane (vertices at ±1) that the vertex
 * shader of ShapeRendererMaterial scales and positions over the shape's bounds.
 */
class ShapeRendererMesh extends Mesh {
  declare material: ShapeRendererMaterial;

  constructor() {
    const quad = new PlaneGeometry(2, 2, 1, 1);
    const shapeMaterial = new ShapeRendererMaterial();
    super(quad, shapeMaterial);
  }
}

/**
 * Renders a single shape — and, when symmetry is enabled, its mirrored copy —
 * into an offscreen render target whose texture is exposed via getTexture().
 *
 * Every draw call restores the WebGLRenderer's render target, clear colour,
 * clear alpha and autoClear flag afterwards, even if drawing throws.
 */
class ShapeRenderer {
  private static readonly auxColor = new Color();
  private static readonly clearColor = new Color(0);

  private readonly gl: WebGLRenderer;
  private readonly camera: OrthographicCamera;
  private readonly mesh: ShapeRendererMesh;
  private readonly target: WebGLRenderTarget;

  /**
   * @param gl - Renderer used for every draw.
   * @param size - Render target size as [width, height].
   */
  constructor(gl: WebGLRenderer, size: [number, number]) {
    this.gl = gl;
    // left=0, right=width, top=0, bottom=height: shape coordinates span the
    // target's size, with y = 0 mapped to the top of clip space.
    this.camera = new OrthographicCamera(0, size[0], 0, size[1], -1, 1);
    this.mesh = new ShapeRendererMesh();
    this.target = new WebGLRenderTarget(size[0], size[1], {
      minFilter: NearestFilter,
      magFilter: NearestFilter,
    });
  }

  /** Disposes the quad's geometry and material and the render target. */
  dispose() {
    const { mesh, target } = this;
    mesh.geometry.dispose();
    mesh.material.dispose();
    target.dispose();
  }

  /** Texture of the render target that shapes are drawn into. */
  getTexture() {
    return this.target.texture;
  }

  /**
   * Resizes the render target and updates the camera so it keeps covering it.
   * @param size - New size as [width, height].
   */
  resize(size: [number, number]) {
    const { camera, target } = this;
    camera.right = size[0];
    camera.bottom = size[1];
    camera.updateProjectionMatrix();
    target.setSize(size[0], size[1]);
  }

  /**
   * Clears the target and draws one shape into it, plus its mirror image when
   * `symmetry.enabled`.
   *
   * @param start - Line start, or one corner of the shape's box.
   * @param end - Line end, or the opposite corner.
   * @param color - Shape colour.
   * @param isRing - Draw only an outline of thickness `size` (ellipse/rectangle).
   * @param rotation - Rotation in radians.
   * @param size - Line width, or outline thickness.
   * @param snap - Force equal radii, or snap lines to 45° steps.
   * @param symmetry - Symmetry settings; `origin` is a fraction of the target size.
   * @param type - Kind of shape to draw.
   */
  render(
    start: [number, number],
    end: [number, number],
    color: ColorRepresentation,
    isRing: boolean,
    rotation: number,
    size: number,
    snap: boolean,
    symmetry: Symmetry,
    type: ShapeType
  ) {
    const { camera, mesh, target } = this;
    const { uniforms } = mesh.material;
    uniforms.start.value.fromArray(start);
    uniforms.end.value.fromArray(end);
    uniforms.symmetryMatrix.value.identity();
    uniforms.color.value.set(color);
    uniforms.isRing.value = isRing ? 1 : 0;
    uniforms.rotation.value = rotation;
    uniforms.size.value = size;
    uniforms.snap.value = snap ? 1 : 0;
    uniforms.type.value = type;

    // Symmetry axis: origin in target pixels plus a unit direction vector.
    const originX = symmetry.origin[0] * target.width;
    const originY = symmetry.origin[1] * target.height;
    let dirX = Math.cos(Math.PI / 2 - symmetry.rotation);
    let dirY = Math.sin(Math.PI / 2 - symmetry.rotation);

    // The fragment shader discards fragments whose 2D cross product against
    // the axis is <= 0, so orient the axis to keep the side `start` lies on.
    if ((start[0] - originX) * dirY - (start[1] - originY) * dirX < 0) {
      dirX = -dirX;
      dirY = -dirY;
    }

    // If symmetry is disabled we don't want to limit the stroke.
    uniforms.allowCrossingSymmetryAxis.value =
      symmetry.allowCrossingAxis || !symmetry.enabled;
    // Update the material's own vectors in place instead of allocating new ones.
    uniforms.symmetryOrigin.value.set(originX, originY);
    uniforms.symmetryDirection.value.set(dirX, dirY);

    this.withTarget((gl) => {
      gl.autoClear = true;
      gl.render(mesh, camera);

      if (!symmetry.enabled) return;

      // Second pass: mirrored geometry, keeping the opposite side of the axis,
      // drawn on top of the first pass without clearing.
      uniforms.symmetryMatrix.value.copy(
        getSymmetryMatrix(symmetry, [target.width, target.height])
      );
      uniforms.symmetryDirection.value.set(-dirX, -dirY);
      gl.autoClear = false;
      gl.render(mesh, camera);
    });
  }

  /** Clears the render target to transparent black. */
  clear() {
    this.withTarget((gl) => gl.clear());
  }

  /**
   * Binds the render target with a transparent-black clear colour, runs
   * `draw`, then restores the renderer's previous render target, clear
   * colour/alpha and autoClear flag. Restoration happens in `finally` so a
   * throwing draw cannot leave the shared renderer pointed at this target.
   */
  private withTarget(draw: (gl: WebGLRenderer) => void) {
    const { gl, target } = this;
    const { auxColor, clearColor } = ShapeRenderer;
    const previousAutoClear = gl.autoClear;
    const previousClearAlpha = gl.getClearAlpha();
    const previousClearColor = gl.getClearColor(auxColor);
    const previousRenderTarget = gl.getRenderTarget();
    gl.setRenderTarget(target);
    gl.setClearColor(clearColor, 0);
    try {
      draw(gl);
    } finally {
      gl.setRenderTarget(previousRenderTarget);
      gl.setClearColor(previousClearColor, previousClearAlpha);
      gl.autoClear = previousAutoClear;
    }
  }
}

/**
 * Returns a ShapeRenderer bound to the react-three-fiber WebGL renderer.
 *
 * A new instance is created only when the renderer (`gl`) changes. The
 * previous instance is then disposed. Size changes are applied in place via
 * `resize`, so callers get a stable instance across resizes.
 *
 * @param size - Render target size as [width, height].
 * @returns The ShapeRenderer owned by the calling component.
 */
export const useShapeRenderer = (size: [number, number]) => {
  const gl = useThree((s) => s.gl);
  // `size` is deliberately not a dependency: it is only the initial size, and
  // later changes go through the resize effect below.
  const shapeRenderer = useMemo(() => new ShapeRenderer(gl, size), [gl]);
  useEffect(() => {
    shapeRenderer.resize(size);
  }, [shapeRenderer, size[0], size[1]]);
  // Keyed on the instance so a replaced renderer disposes the one it replaces.
  useEffect(() => () => shapeRenderer.dispose(), [shapeRenderer]);

  return shapeRenderer;
};
