// FROM https://raw.githubusercontent.com/pmndrs/drei/master/src/web/Html.tsx
// Modified to fix a bug with the occlusion and orthographic camera

import * as React from 'react';
import * as ReactDOM from 'react-dom/client';
import {
  Vector3,
  Group,
  Object3D,
  Matrix4,
  Camera,
  PerspectiveCamera,
  OrthographicCamera,
  DoubleSide,
  Mesh,
  Frustum,
  Box3,
} from 'three';
import { Assign } from 'utility-types';
import {
  ReactThreeFiber,
  useFrame,
  useThree,
  context,
} from '@react-three/fiber';
import { ThemeContext } from 'styled-components';
import {
  UNSAFE_LocationContext,
  UNSAFE_NavigationContext,
} from 'react-router-dom';
import { useContextBridge } from '@react-three/drei';
import { Context as UrqlContext } from 'urql';
import { DEFAULT_Z_INDEX_RANGE } from '../helpers';
import { isDraggingContext, isViewerContext } from '../../lib/utils';
import { useRecoilBridgeAcrossReactRoots_UNSTABLE } from 'recoil';
import { SelectionApiContext } from '../studio/selection/useSelectionApi';

// Module-level scratch objects shared by the helpers below so they do not
// allocate on every call. Each helper overwrites them (e.g. via
// `setFromMatrixPosition` / `setFromPoints`), so any value read from one of
// them must be consumed before the next helper call.
const v1 = new Vector3();
const v2 = new Vector3();
const v3 = new Vector3();
const b3 = new Box3();
const frustum = new Frustum();

/**
 * Projects the world position of `el` through `camera` and converts the
 * resulting normalized device coordinates into pixel coordinates inside a
 * viewport of `size` (origin at the top-left corner, Y growing downward).
 *
 * Writes into the shared scratch vector `v1`; the returned array is new.
 *
 * @param el - object whose world position is projected
 * @param camera - camera used for the projection
 * @param size - viewport size in pixels
 * @returns `[x, y]` in pixels
 */
function defaultCalculatePosition(
  el: Object3D,
  camera: Camera,
  size: { width: number; height: number }
) {
  const ndc = v1.setFromMatrixPosition(el.matrixWorld).project(camera);
  const halfWidth = size.width / 2;
  const halfHeight = size.height / 2;
  const screenX = ndc.x * halfWidth + halfWidth;
  const screenY = halfHeight - ndc.y * halfHeight;
  return [screenX, screenY];
}

/** Signature of the `calculatePosition` prop (same shape as `defaultCalculatePosition`). */
export type CalculatePosition = typeof defaultCalculatePosition;

// Scratch matrix reused by `isInviewport` so the per-frame check does not allocate.
const viewProjectionMatrix = new Matrix4();

/**
 * Returns whether an axis-aligned rectangle of `size`, centered on `el`'s world
 * position and lying in the plane `z = el.z`, intersects the camera frustum.
 *
 * `size` is added directly to world coordinates, so it is interpreted in world
 * units. The check runs from the frame loop, so it reuses the module scratch
 * objects (`v3`, `b3`, `frustum`, `viewProjectionMatrix`) instead of
 * allocating vectors and matrices on every call.
 *
 * @param el - object whose world position is the rectangle center
 * @param camera - camera whose frustum is tested; its `projectionMatrix` and
 *   `matrixWorldInverse` are expected to be up to date
 * @param size - rectangle width and height
 * @returns true when the rectangle is at least partially inside the frustum
 */
const isInviewport = (
  el: Object3D,
  camera: Camera,
  size: { width: number; height: number }
) => {
  const center = v3.setFromMatrixPosition(el.matrixWorld);

  // Equivalent to `Box3.setFromPoints` over the four rectangle corners, written
  // straight into the shared box. `Math.abs` keeps min <= max for negative sizes,
  // exactly like the min/max over the corners would.
  const halfWidth = Math.abs(size.width) / 2;
  const halfHeight = Math.abs(size.height) / 2;
  b3.min.set(center.x - halfWidth, center.y - halfHeight, center.z);
  b3.max.set(center.x + halfWidth, center.y + halfHeight, center.z);

  frustum.setFromProjectionMatrix(
    viewProjectionMatrix.multiplyMatrices(
      camera.projectionMatrix,
      camera.matrixWorldInverse
    )
  );

  return frustum.intersectsBox(b3);
};

/**
 * Maps the depth of `el` relative to `camera` linearly onto `zIndexRange`:
 * a depth of `camera.near` yields `zIndexRange[0]`, a depth of `camera.far`
 * yields `zIndexRange[1]`, and the result is rounded.
 *
 * - `PerspectiveCamera`: depth is the Euclidean distance to the camera.
 * - `OrthographicCamera` (VIZCOM TWEAK): depth is only the distance along the
 *   z axis. Using the distance to the element's center made elements jump as
 *   the camera moved.
 * - Any other camera: returns `undefined`.
 *
 * Writes into the shared scratch vectors `v1` and `v2`.
 */
function objectZIndex(
  el: Object3D,
  camera: Camera,
  zIndexRange: Array<number>
) {
  const elementPosition = v1.setFromMatrixPosition(el.matrixWorld);
  const cameraPosition = v2.setFromMatrixPosition(camera.matrixWorld);

  let depth: number;
  if (camera instanceof PerspectiveCamera) {
    depth = elementPosition.distanceTo(cameraPosition);
  } else if (camera instanceof OrthographicCamera) {
    depth = Math.abs(elementPosition.z - cameraPosition.z);
  } else {
    return undefined;
  }

  // Linear map with slope/offset chosen so near -> zIndexRange[0], far -> zIndexRange[1].
  const slope = (zIndexRange[1] - zIndexRange[0]) / (camera.far - camera.near);
  const offset = zIndexRange[1] - slope * camera.far;
  return Math.round(slope * depth + offset);
}

/**
 * Snaps any value whose magnitude is below 1e-10 to exactly 0, so
 * floating-point noise such as `1e-17` never ends up in generated CSS.
 */
const epsilon = (value: number) => {
  if (Math.abs(value) < 1e-10) {
    return 0;
  }
  return value;
};

/**
 * Builds a CSS `matrix3d(...)` string from a three.js matrix.
 *
 * Each of the 16 entries of `matrix.elements`, in stored order, is multiplied
 * by the multiplier at the same index and passed through `epsilon`.
 *
 * @param matrix - source matrix
 * @param multipliers - 16 per-element factors
 * @param prepend - text placed before `matrix3d(` (e.g. another CSS transform)
 * @returns the CSS transform string
 */
function getCSSMatrix(matrix: Matrix4, multipliers: number[], prepend = '') {
  const cells = Array.from({ length: 16 }, (_, index) =>
    epsilon(multipliers[index] * matrix.elements[index])
  );
  return `${prepend}matrix3d(${cells.join(',')})`;
}

/**
 * Per-element factors for the camera matrix.
 *
 * Every index with `index % 4 === 1` (1, 5, 9, 13) is negated. In three.js's
 * column-major storage these are the second row, i.e. the Y output, which
 * flips Y for CSS, whose Y axis grows downward.
 */
const CAMERA_CSS_MULTIPLIERS = Array.from({ length: 16 }, (_, index) =>
  index % 4 === 1 ? -1 : 1
);

/** Converts a camera matrix (e.g. `matrixWorldInverse`) into a CSS `matrix3d(...)` string. */
const getCameraCSSMatrix = (matrix: Matrix4) =>
  getCSSMatrix(matrix, CAMERA_CSS_MULTIPLIERS);

/**
 * Converts an object's world matrix into a CSS transform.
 *
 * The result is `translate(-50%,-50%)` followed by `matrix3d(...)`. The three
 * basis columns are scaled by `1 / factor`, the Y basis column (elements 4..7)
 * is also negated, and the translation column (12..15) is left unchanged.
 *
 * @param matrix - object world matrix (column-major)
 * @param factor - divisor applied to the basis columns
 */
const getObjectCSSMatrix = (matrix: Matrix4, factor: number) => {
  const scale = 1 / factor;
  const flippedScale = -1 / factor;
  // prettier-ignore
  const multipliers = [
    scale, scale, scale, 1,
    flippedScale, flippedScale, flippedScale, -1,
    scale, scale, scale, 1,
    1, 1, 1, 1,
  ];
  return getCSSMatrix(matrix, multipliers, 'translate(-50%,-50%)');
};

/**
 * Accepted values for the `pointerEvents` prop. In transform mode it becomes
 * the CSS `pointer-events` of the inner wrapper div.
 */
type PointerEventsProperties =
  // Generic keywords
  | 'auto'
  | 'none'
  | 'inherit'
  // SVG-oriented keywords
  | 'all'
  | 'fill'
  | 'stroke'
  | 'painted'
  | 'visible'
  | 'visibleFill'
  | 'visibleStroke'
  | 'visiblePainted';

/**
 * Props for `CustomHtml`.
 *
 * Any prop that `CustomHtml` does not consume explicitly is spread onto the
 * anchoring `<group>`.
 */
export interface CustomHtmlProps
  extends Omit<
    Assign<
      React.HTMLAttributes<HTMLDivElement>,
      ReactThreeFiber.Object3DNode<Group, typeof Group>
    >,
    'ref'
  > {
  /** Insert the wrapper element as the first child of the target instead of appending it. */
  prepend?: boolean;
  /** Non-transform mode: translate the content by -50%/-50% so it is centered on the anchor. */
  center?: boolean;
  /** Non-transform mode: size the content to the canvas size and offset it by half of it. */
  fullscreen?: boolean;
  /**
   * Change threshold for screen position, camera zoom and camera x/y. A frame
   * in which none of them moves by more than this (and `contentSize` is
   * unchanged) skips DOM updates.
   */
  eps?: number;
  /** Element to mount the wrapper into; falls back to `events.connected`, then the canvas's parent node. */
  portal?: React.MutableRefObject<HTMLElement>;
  /**
   * Transform mode: build the CSS matrix from the camera's rotation plus the
   * group's position and scale, instead of the group's full world matrix.
   */
  sprite?: boolean;
  /** Position the content with CSS 3D matrices derived from camera and group, instead of a 2D screen-space translate. */
  transform?: boolean;
  /**
   * z-index bounds. `objectZIndex` maps depth `camera.near` to index 0 and
   * `camera.far` to index 1, and the canvas z-index is set to
   * `floor(zIndexRange[0] / 2)`.
   */
  zIndexRange?: Array<number>;
  /** Screen-position function used for change detection and for the non-transform translate. */
  calculatePosition?: CalculatePosition;
  /** Tag name of the wrapper element created with `document.createElement` (only read on first render). */
  as?: string;
  /** `className` applied to the wrapper element. */
  wrapperClass?: string;
  /** Transform mode: CSS `pointer-events` of the inner wrapper div. */
  pointerEvents?: PointerEventsProperties;
  /**
   * Content size, added to the group's world position (so in world units).
   * When set, the wrapper gets `display: none` while that rectangle is
   * outside the camera frustum.
   */
  contentSize?: { width: number; height: number };

  // Occlusion based off work by Jerome Etienne and James Baicoianu
  // https://www.youtube.com/watch?v=ScZcUEDGjJI
  // as well as Joe Pea in CodePen: https://codepen.io/trusktr/pen/RjzKJx
  /**
   * When truthy, an occlusion mesh is rendered in the group and the element's
   * z-index is taken from a range below the canvas z-index. Only truthiness is
   * checked here, so `'blending'` behaves like `true`.
   */
  occlude?: boolean | 'blending';
  /** NOTE(review): accepted but never called by `CustomHtml` in this file. */
  onOcclude?: (visible: boolean) => null;
  /**
   * Scale the occlusion geometry to match the size of the HTML element; the
   * geometry should have a size of 1x1. When false, the mesh scale is set to 1.
   */
  scaleOcclusionGeometry?: boolean;
  /** Material for the occlusion plane (defaults to a transparent shader material). */
  material?: React.ReactNode;
  /** Geometry for the occlusion plane (defaults to `<planeGeometry />`). */
  geometry?: React.ReactNode;
  /** Cast shadow for the occlusion plane. */
  castShadow?: boolean;
  /** Receive shadow for the occlusion plane. */
  receiveShadow?: boolean;
}

/**
 * Fork of drei's `<Html>` that renders React DOM content anchored to a three.js
 * group. The content lives in a separate React root mounted next to the r3f canvas.
 *
 * VIZCOM changes compared to upstream:
 * - the app contexts (Recoil, styled-components theme, react-router location
 *   and navigation, urql, dragging/viewer/selection contexts and the r3f
 *   context) are bridged into the separate root so `children` can use them;
 * - the orthographic z-index only depends on the z distance (see `objectZIndex`);
 * - with `contentSize` set, the element is hidden while outside the frustum;
 * - in `transform` mode the occlusion mesh is scaled to the client size of the
 *   outer wrapper's first child.
 *
 * The forwarded `ref` is attached to the div that receives `className` and `children`.
 */
export const CustomHtml = React.forwardRef(
  (
    {
      children,
      eps = 0.001,
      style,
      className,
      prepend,
      center,
      fullscreen,
      portal,
      sprite = false,
      transform = false,
      occlude,
      // Unused below, but destructured so it is not forwarded to <group> via `props`.
      onOcclude,
      castShadow,
      receiveShadow,
      material,
      geometry,
      contentSize,
      zIndexRange = DEFAULT_Z_INDEX_RANGE,
      calculatePosition = defaultCalculatePosition,
      as = 'div',
      wrapperClass,
      pointerEvents = 'auto',
      scaleOcclusionGeometry = true,
      ...props
    }: CustomHtmlProps,
    ref: React.Ref<HTMLDivElement>
  ) => {
    const { gl, camera, scene, size, events, viewport } = useThree();

    // VIZCOM MODIFIED: inject these contexts into the HTML element's separate React root
    const RecoilBridge = useRecoilBridgeAcrossReactRoots_UNSTABLE();
    const ContextBridge = useContextBridge(
      ThemeContext,
      UNSAFE_LocationContext,
      UNSAFE_NavigationContext,
      isDraggingContext,
      isViewerContext,
      UrqlContext,
      SelectionApiContext,
      context // allow using useThree inside the html component
    );

    // Wrapper element hosting the separate React root; `as` is only read on first render.
    const [el] = React.useState(() => document.createElement(as));
    const root = React.useRef<ReactDOM.Root>();
    const group = React.useRef<Group>(null!);
    // Values last reflected in the DOM. Frames where none of them changed by
    // more than `eps` skip DOM writes.
    const oldZoom = React.useRef(0);
    const oldPosition = React.useRef([0, 0]);
    const oldCameraPosition = React.useRef([0, 0]);
    const oldSize = React.useRef<
      | {
          width: number;
          height: number;
        }
      | undefined
    >();
    const transformOuterRef = React.useRef<HTMLDivElement>(null!);
    const transformInnerRef = React.useRef<HTMLDivElement>(null!);
    // Becomes true once both transform wrappers are mounted; transform mode
    // waits for it before writing styles.
    const isRendered = React.useRef(false);

    // Append to the connected element, which makes HTML work with views
    const target = (portal?.current ||
      events.connected ||
      gl.domElement.parentNode) as HTMLElement;

    // Only attached when `occlude` is truthy; otherwise it stays null.
    const occlusionMeshRef = React.useRef<Mesh>(null!);
    const isMeshSizeSet = React.useRef<boolean>(false);

    // The canvas z-index is half of zIndexRange[0]. Occluded elements use a
    // range below it and non-occluded elements a range above it (see zRange
    // in useFrame). Depends on the value so a changed range is applied.
    const canvasZIndex = Math.floor(zIndexRange[0] / 2);
    React.useLayoutEffect(() => {
      const canvas = gl.domElement as HTMLCanvasElement;

      canvas.style.zIndex = `${canvasZIndex}`;
      canvas.style.position = 'absolute';
      canvas.style.pointerEvents = 'none';
    }, [gl, canvasZIndex]);

    // Create the React root, give the wrapper its initial placement and attach it to the target.
    React.useLayoutEffect(() => {
      if (group.current) {
        const currentRoot = (root.current = ReactDOM.createRoot(el));
        scene.updateMatrixWorld();
        if (transform) {
          el.style.cssText = `position:absolute;top:0;left:0;pointer-events:none;overflow:hidden;`;
        } else {
          const vec = calculatePosition(group.current, camera, size);
          el.style.cssText = `position:absolute;top:0;left:0;transform:translate3d(${vec[0]}px,${vec[1]}px,0);transform-origin:0 0;`;
        }
        if (target) {
          if (prepend) target.prepend(el);
          else target.appendChild(el);
        }
        return () => {
          if (target) target.removeChild(el);
          currentRoot.unmount();
        };
      }
    }, [target, transform]);

    React.useLayoutEffect(() => {
      if (wrapperClass) el.className = wrapperClass;
    }, [wrapperClass]);

    const styles: React.CSSProperties = React.useMemo(() => {
      if (transform) {
        return {
          position: 'absolute',
          top: 0,
          left: 0,
          width: size.width,
          height: size.height,
          transformStyle: 'preserve-3d',
          pointerEvents: 'none',
        };
      } else {
        return {
          position: 'absolute',
          transform: center ? 'translate3d(-50%,-50%,0)' : 'none',
          ...(fullscreen && {
            top: -size.height / 2,
            left: -size.width / 2,
            width: size.width,
            height: size.height,
          }),
          ...style,
        };
      }
    }, [style, center, fullscreen, size, transform]);

    const transformInnerStyles: React.CSSProperties = React.useMemo(
      () => ({ position: 'absolute', pointerEvents }),
      [pointerEvents]
    );

    // Re-render the children into the separate root on every render; the content
    // size may change, so the occlusion mesh is re-measured.
    React.useLayoutEffect(() => {
      isMeshSizeSet.current = false;

      if (transform) {
        root.current?.render(
          <RecoilBridge>
            <ContextBridge>
              <div ref={transformOuterRef} style={styles}>
                <div ref={transformInnerRef} style={transformInnerStyles}>
                  <div
                    ref={ref}
                    className={className}
                    style={style}
                    children={children}
                  />
                </div>
              </div>
            </ContextBridge>
          </RecoilBridge>
        );
      } else {
        root.current?.render(
          <RecoilBridge>
            <ContextBridge>
              <div
                ref={ref}
                style={styles}
                className={className}
                children={children}
              />
            </ContextBridge>
          </RecoilBridge>
        );
      }
    });

    // Records the values the DOM now reflects so later frames can detect changes.
    const rememberFrameState = (vec: number[]) => {
      oldPosition.current = vec;
      oldZoom.current = camera.zoom;
      oldCameraPosition.current = [camera.position.x, camera.position.y];
      oldSize.current = contentSize;
    };

    useFrame((state) => {
      if (group.current) {
        camera.updateMatrixWorld();
        group.current.updateWorldMatrix(true, false);
        const vec = calculatePosition(group.current, camera, size);

        const hasChanged =
          Math.abs(oldZoom.current - camera.zoom) > eps ||
          Math.abs(oldPosition.current[0] - vec[0]) > eps ||
          Math.abs(oldPosition.current[1] - vec[1]) > eps ||
          Math.abs(oldCameraPosition.current[0] - camera.position.x) > eps ||
          Math.abs(oldCameraPosition.current[1] - camera.position.y) > eps ||
          oldSize.current?.width !== contentSize?.width ||
          oldSize.current?.height !== contentSize?.height;
        // In transform mode, wait until the wrapper divs are mounted.
        const isReady = transform ? isRendered.current : true;

        if (hasChanged && isReady) {
          if (
            contentSize &&
            !isInviewport(group.current, camera, contentSize)
          ) {
            // Out of view: hide it and skip positioning and occlusion sizing this frame.
            el.style.display = 'none';
            rememberFrameState(vec);
            return;
          }
          el.style.display = 'block';

          const halfRange = Math.floor(zIndexRange[0] / 2);
          // when object should be occluded, it will be placed below the canvas zIndex
          // then the canvas will be transparent where the element should be displayed and the element will show below it
          // if the element is not occluded, it will be placed above the canvas zIndex
          const zRange = occlude
            ? [halfRange - 1, 0]
            : [zIndexRange[0], halfRange - 1];
          el.style.zIndex = `${objectZIndex(group.current, camera, zRange)}`;

          if (transform) {
            const [widthHalf, heightHalf] = [size.width / 2, size.height / 2];
            const fov = camera.projectionMatrix.elements[5] * heightHalf;
            const { isOrthographicCamera, top, left, bottom, right } =
              camera as OrthographicCamera;
            const cameraMatrix = getCameraCSSMatrix(camera.matrixWorldInverse);
            const cameraTransform = isOrthographicCamera
              ? `scale(${fov})translate(${epsilon(
                  -(right + left) / 2
                )}px,${epsilon((top + bottom) / 2)}px)`
              : `translateZ(${fov}px)`;
            let matrix = group.current.matrixWorld;
            if (sprite) {
              // Use the camera's rotation with the group's position/scale.
              matrix = camera.matrixWorldInverse
                .clone()
                .transpose()
                .copyPosition(matrix)
                .scale(group.current.scale);
              matrix.elements[3] = matrix.elements[7] = matrix.elements[11] = 0;
              matrix.elements[15] = 1;
            }
            el.style.width = size.width + 'px';
            el.style.height = size.height + 'px';
            el.style.perspective = isOrthographicCamera ? '' : `${fov}px`;

            if (transformOuterRef.current && transformInnerRef.current) {
              transformOuterRef.current.style.transform = `${cameraTransform}${cameraMatrix}translate(${widthHalf}px,${heightHalf}px)`;
              transformInnerRef.current.style.transform = getObjectCSSMatrix(
                matrix,
                1
              );
            }
          } else {
            el.style.transform = `translate3d(${vec[0]}px,${vec[1]}px,0)`;
          }
          rememberFrameState(vec);
        }

        if (
          !isRendered.current &&
          transformOuterRef.current &&
          transformInnerRef.current
        ) {
          isRendered.current = true;
        }
      }

      // No occlusion mesh is mounted unless `occlude` is truthy; there is nothing to size.
      const occlusionMesh = occlusionMeshRef.current;
      if (!occlusionMesh) return;

      if (!scaleOcclusionGeometry) {
        occlusionMesh.scale.set(1, 1, 1);
      } else if (!isMeshSizeSet.current) {
        if (transform) {
          const wrapper = transformOuterRef.current?.children[0];

          if (wrapper?.clientWidth && wrapper?.clientHeight) {
            // NOTE: Guillaume: This was manually modified to simplify this case and fix a problem with orthographic camera
            occlusionMesh.scale.set(
              wrapper.clientWidth,
              wrapper.clientHeight,
              1
            );

            isMeshSizeSet.current = true;
          }
        } else {
          const content = el.children[0];

          if (content?.clientWidth && content?.clientHeight) {
            const ratio = 1 / viewport.factor;
            const w = content.clientWidth * ratio;
            const h = content.clientHeight * ratio;

            occlusionMesh.scale.set(w, h, 1);

            isMeshSizeSet.current = true;
          }

          occlusionMesh.lookAt(state.camera.position);
        }
      }
    });

    const shaders = React.useMemo(
      () => ({
        vertexShader: !transform
          ? /* glsl */ `
          /*
            This shader is from the THREE's SpriteMaterial.
            We need to turn the backing plane into a Sprite
            (make it always face the camera) if "transfrom"
            is false.
          */
          #include <common>

          void main() {
            vec2 center = vec2(0., 1.);
            float rotation = 0.0;

            // This is somewhat arbitrary, but it seems to work well
            // Need to figure out how to derive this dynamically if it even matters
            float size = 0.03;

            vec4 mvPosition = modelViewMatrix * vec4( 0.0, 0.0, 0.0, 1.0 );
            vec2 scale;
            scale.x = length( vec3( modelMatrix[ 0 ].x, modelMatrix[ 0 ].y, modelMatrix[ 0 ].z ) );
            scale.y = length( vec3( modelMatrix[ 1 ].x, modelMatrix[ 1 ].y, modelMatrix[ 1 ].z ) );

            bool isPerspective = isPerspectiveMatrix( projectionMatrix );
            if ( isPerspective ) scale *= - mvPosition.z;

            vec2 alignedPosition = ( position.xy - ( center - vec2( 0.5 ) ) ) * scale * size;
            vec2 rotatedPosition;
            rotatedPosition.x = cos( rotation ) * alignedPosition.x - sin( rotation ) * alignedPosition.y;
            rotatedPosition.y = sin( rotation ) * alignedPosition.x + cos( rotation ) * alignedPosition.y;
            mvPosition.xy += rotatedPosition;

            gl_Position = projectionMatrix * mvPosition;
          }
      `
          : undefined,
        fragmentShader: /* glsl */ `
        void main() {
          gl_FragColor = vec4(0.0, 0.0, 0.0, 0.0);
        }
      `,
      }),
      [transform]
    );

    return (
      <group {...props} ref={group}>
        {occlude && (
          <mesh
            castShadow={castShadow}
            receiveShadow={receiveShadow}
            ref={occlusionMeshRef}
          >
            {geometry || <planeGeometry />}
            {material || (
              <shaderMaterial
                side={DoubleSide}
                vertexShader={shaders.vertexShader}
                fragmentShader={shaders.fragmentShader}
              />
            )}
          </mesh>
        )}
      </group>
    );
  }
);
