import { useEffect, useMemo } from 'react';
import { useThree } from '@react-three/fiber';
import {
  BackSide,
  Color,
  ColorRepresentation,
  GLSL3,
  InstancedBufferGeometry,
  Mesh,
  NearestFilter,
  OrthographicCamera,
  PlaneGeometry,
  RawShaderMaterial,
  WebGLRenderer,
  WebGLRenderTarget,
  Vector2,
} from 'three';
import { cross2d, rotate, sdCircle, sdEllipse } from '../../../../lib/glsl';

/**
 * Edge profile of a brush stamp.
 *
 * The numeric values are interpolated into the stroke fragment shader and
 * uploaded as its `hardness` int uniform, so they must remain 0 and 1.
 */
export enum BrushHardness {
  /** Sharp edge: alpha falls off only over a narrow anti-aliasing band. */
  hard = 0,
  /** Soft edge: alpha fades over the whole radius of the stamp. */
  soft = 1,
}

/**
 * Instanced geometry whose per-instance shape is a single 2x2 plane quad.
 *
 * Only the quad's positions and triangle indices are borrowed; the stroke
 * vertex shader scales and places one copy of the quad per instance.
 */
class StrokeRendererGeometry extends InstancedBufferGeometry {
  constructor() {
    super();
    const quad = new PlaneGeometry(2, 2, 1, 1);
    const quadPositions = quad.getAttribute('position');
    const quadIndices = quad.getIndex();
    this.setIndex(quadIndices);
    this.setAttribute('position', quadPositions);
  }
}

/**
 * GLSL3 raw shader material that draws one brush stamp per geometry instance.
 *
 * Vertex stage: each instance samples the cubic curve (`px`, `py`) at
 * `t = gl_InstanceID / count` and interpolates the pressure linearly. It then
 * sizes the quad from `size` and pressure, and rotates it by `rotation`.
 *
 * Fragment stage: fragments on the negative side of the symmetry axis are
 * discarded unless crossing is allowed. Alpha is then derived from a circle
 * (aspect == 1) or ellipse signed distance, with a soft or hard falloff
 * depending on `hardness`.
 */
class StrokeRendererMaterial extends RawShaderMaterial {
  private static readonly vertexShader = /* glsl */ `
    precision mediump float;

    struct CubicPoly {
      float c0;
      float c1;
      float c2;
      float c3;
    };

    struct Pressure {
      float start;
      float end;
    };

    uniform int count;
    uniform Pressure pressure;
    uniform CubicPoly px;
    uniform CubicPoly py;
    uniform float rotation;
    uniform float size;

    uniform mat4 projectionMatrix;

    in vec3 position;
    out float vRadius;
    out vec2 vUV;
    out vec2 vPosition;

    ${rotate}

    float getCubicPolyAt(CubicPoly p, float t) {
      float t2 = t * t;
      float t3 = t2 * t;
      return p.c0 + p.c1 * t + p.c2 * t2 + p.c3 * t3;
    }

    void main()	{
      float t = float(gl_InstanceID) / float(count);
      vec2 point = vec2(getCubicPolyAt(px, t), getCubicPolyAt(py, t));
      float pressureAtPoint = mix(pressure.start, pressure.end, t);
      vRadius = max((size * 0.5) * pressureAtPoint, 0.5);
      vUV = position.xy * (vRadius + 2.0);
      vPosition = rotate(vUV, -rotation) + point;
      gl_Position = projectionMatrix * vec4(vPosition, 0.0, 1.0);
    }
  `;

  private static readonly fragmentShader = /* glsl */ `
    precision mediump float;

    uniform float aspect;
    uniform int hardness;
    uniform vec3 color;
    uniform bool allowCrossingSymmetryAxis;
    uniform vec2 symmetryOrigin;
    uniform vec2 symmetryDirection;

    in float vRadius;
    in vec2 vUV;
    in vec2 vPosition;
    out vec4 outputColor;

    ${sdCircle}
    ${sdEllipse}
    ${cross2d}

    void main()	{
    //is the fragment left or right of the symmetry axis
    if(!allowCrossingSymmetryAxis && cross2d(vPosition - symmetryOrigin, symmetryDirection) < 0.0){
        discard;
      }

      float dist;
      if (aspect == 1.0) {
        dist = sdCircle(vUV, vRadius);
      } else {
        dist = sdEllipse(vUV, vec2(1.0, aspect) * vRadius);
      }
      float edge = min(vRadius, 2.0);
      float alpha = 1.0 - smoothstep(hardness == ${BrushHardness.soft} ? -(vRadius * aspect) : -edge, edge, dist);
      outputColor = clamp(vec4(color, alpha), 0.0, 1.0);
    }
  `;

  /**
   * Creates the initial uniform set with placeholder values.
   * Names and struct shapes mirror the uniform declarations in the shaders above.
   */
  private static createUniforms() {
    const zeroPoly = () => ({ c0: 0, c1: 0, c2: 0, c3: 0 });
    return {
      aspect: { value: 0 },
      color: { value: new Color() },
      count: { value: 0 },
      hardness: { value: 0 },
      pressure: { value: { start: 0, end: 0 } },
      px: { value: zeroPoly() },
      py: { value: zeroPoly() },
      rotation: { value: 0 },
      size: { value: 0 },
      allowCrossingSymmetryAxis: { value: false },
      symmetryOrigin: { value: new Vector2(0, 0) },
      symmetryDirection: { value: new Vector2(0, 1) },
    };
  }

  constructor() {
    super({
      glslVersion: GLSL3,
      // NOTE(review): BackSide is presumably needed because the y-down
      // orthographic camera used for strokes mirrors the quad's winding — confirm.
      side: BackSide,
      transparent: true,
      uniforms: StrokeRendererMaterial.createUniforms(),
      vertexShader: StrokeRendererMaterial.vertexShader,
      fragmentShader: StrokeRendererMaterial.fragmentShader,
    });
  }
}

/**
 * Mesh that pairs the instanced stroke quad geometry with the stroke material.
 * The `declare` fields narrow the inherited `Mesh` field types to the concrete
 * classes without emitting any runtime field initialisers.
 */
class StrokeRendererMesh extends Mesh {
  declare material: StrokeRendererMaterial;
  declare geometry: StrokeRendererGeometry;

  constructor() {
    super(
      new StrokeRendererGeometry(),
      new StrokeRendererMaterial()
    );
  }
}

/**
 * Accumulates brush strokes into an offscreen render target.
 *
 * Strokes are drawn with the renderer's `autoClear` disabled, so successive
 * `render` calls add to whatever the target already contains until `clear`
 * is called. Every method that changes shared renderer state (render target,
 * `autoClear`, clear colour/alpha) restores it afterwards, even if the
 * draw or clear call throws.
 */
class StrokeRenderer {
  // Scratch colour used to read back the renderer's current clear colour.
  private static readonly auxColor = new Color();
  // Colour (used with alpha 0) that wipes the stroke target.
  private static readonly clearColor = new Color(0);

  private readonly gl: WebGLRenderer;
  private readonly camera: OrthographicCamera;
  private readonly mesh: StrokeRendererMesh;
  private readonly target: WebGLRenderTarget;

  /**
   * @param gl - Renderer used for all drawing and clearing.
   * @param size - Target size as `[width, height]` in pixels.
   */
  constructor(gl: WebGLRenderer, size: [number, number]) {
    this.gl = gl;
    // left/right = 0..width and top/bottom = 0..height: one world unit per
    // target pixel, with y growing downward.
    this.camera = new OrthographicCamera(0, size[0], 0, size[1], -1, 1);
    this.mesh = new StrokeRendererMesh();
    this.target = new WebGLRenderTarget(size[0], size[1], {
      minFilter: NearestFilter,
      magFilter: NearestFilter,
    });
  }

  /** Releases the GPU resources held by the geometry, material and target. */
  dispose() {
    const { mesh, target } = this;
    mesh.geometry.dispose();
    mesh.material.dispose();
    target.dispose();
  }

  /** Texture containing the accumulated strokes. */
  getTexture() {
    return this.target.texture;
  }

  /**
   * Resizes the render target and the camera's pixel mapping.
   * @param size - New size as `[width, height]` in pixels.
   */
  resize(size: [number, number]) {
    const { camera, target } = this;
    camera.right = size[0];
    camera.bottom = size[1];
    camera.updateProjectionMatrix();
    target.setSize(size[0], size[1]);
  }

  /**
   * Draws one stroke segment into the target without clearing it first.
   *
   * @param aspect - Stamp aspect ratio; 1 uses a circle, anything else an ellipse.
   * @param color - Stroke colour.
   * @param count - Number of stamps (instances) drawn along the curve.
   * @param hardness - Edge falloff profile.
   * @param pressure - Pressure at the segment start and end; interpolated linearly.
   * @param px - Cubic polynomial coefficients for x(t), t in [0, 1).
   * @param py - Cubic polynomial coefficients for y(t), t in [0, 1).
   * @param rotation - Stamp rotation in radians (passed to the shader's `rotate`).
   * @param size - Brush diameter in pixels before pressure scaling.
   * @param allowCrossingSymmetryAxis - When false, fragments on the negative
   *   side of the symmetry axis are discarded.
   * @param symmetryOrigin - A point on the symmetry axis.
   * @param symmetryDirection - Direction of the symmetry axis.
   */
  render(
    aspect: number,
    color: ColorRepresentation,
    count: number,
    hardness: BrushHardness,
    pressure: { start: number; end: number },
    px: { c0: number; c1: number; c2: number; c3: number },
    py: { c0: number; c1: number; c2: number; c3: number },
    rotation: number,
    size: number,
    allowCrossingSymmetryAxis: boolean,
    symmetryOrigin: Vector2,
    symmetryDirection: Vector2
  ) {
    const { gl, camera, mesh, target } = this;
    const {
      geometry,
      material: { uniforms },
    } = mesh;
    geometry.instanceCount = count;
    uniforms.aspect.value = aspect;
    uniforms.count.value = count;
    uniforms.color.value.set(color);
    uniforms.hardness.value = hardness;
    uniforms.pressure.value = pressure;
    uniforms.px.value = px;
    uniforms.py.value = py;
    uniforms.rotation.value = rotation;
    uniforms.allowCrossingSymmetryAxis.value = allowCrossingSymmetryAxis;
    uniforms.symmetryOrigin.value = symmetryOrigin;
    uniforms.symmetryDirection.value = symmetryDirection;
    uniforms.size.value = size;

    const previousAutoClear = gl.autoClear;
    const previousRenderTarget = gl.getRenderTarget();
    // Disable auto-clear so this draw accumulates onto earlier strokes.
    gl.autoClear = false;
    try {
      gl.setRenderTarget(target);
      gl.render(mesh, camera);
    } finally {
      // Always hand the shared renderer back in its original state.
      gl.setRenderTarget(previousRenderTarget);
      gl.autoClear = previousAutoClear;
    }
  }

  /** Wipes the target to transparent black, restoring the renderer's clear state afterwards. */
  clear() {
    const { gl, target } = this;
    const previousClearAlpha = gl.getClearAlpha();
    const previousClearColor = gl.getClearColor(StrokeRenderer.auxColor);
    const previousRenderTarget = gl.getRenderTarget();
    try {
      gl.setRenderTarget(target);
      gl.setClearColor(StrokeRenderer.clearColor, 0);
      gl.clear();
    } finally {
      gl.setRenderTarget(previousRenderTarget);
      gl.setClearColor(previousClearColor, previousClearAlpha);
    }
  }
}

/**
 * React hook that owns a StrokeRenderer bound to the current R3F renderer.
 *
 * The StrokeRenderer stores and draws through `gl`. It is therefore re-created
 * when `gl` changes, and each instance is disposed when it is replaced or the
 * component unmounts. Size changes are applied in place via `resize`.
 *
 * @param size - Stroke target size as `[width, height]` in pixels.
 * @returns The StrokeRenderer for the current `gl`.
 */
export const useStrokeRenderer = (size: [number, number]) => {
  const gl = useThree((s) => s.gl);
  const [width, height] = size;

  // Only `gl` forces a new instance; size changes go through resize() below.
  const strokeRenderer = useMemo(
    () => new StrokeRenderer(gl, [width, height]),
    [gl]
  );

  useEffect(() => {
    strokeRenderer.resize([width, height]);
  }, [strokeRenderer, width, height]);

  // Keyed on the instance so a replaced renderer is disposed, not leaked.
  useEffect(() => () => strokeRenderer.dispose(), [strokeRenderer]);

  return strokeRenderer;
};
