import Konva from "konva";
import { useEffect, useRef, useState } from "react";

/**
 * A 2D point. Within this file it is used for touch coordinates and for the
 * pinch-gesture center.
 */
export interface Position {
  /** Horizontal coordinate. */
  x: number;
  /** Vertical coordinate. */
  y: number;
}

/**
 * Euclidean distance between two points.
 *
 * Uses `Math.hypot` rather than squaring and summing by hand, which avoids
 * overflow/underflow of the intermediate squares for very large or very small
 * coordinate differences.
 *
 * @param p1 - First point.
 * @param p2 - Second point.
 * @returns The straight-line distance between `p1` and `p2` (always >= 0).
 */
function getDistance(p1: Position, p2: Position) {
  return Math.hypot(p2.x - p1.x, p2.y - p1.y);
}

/**
 * Midpoint of the segment joining two points.
 *
 * @param p1 - First point.
 * @param p2 - Second point.
 * @returns A new object whose `x`/`y` are the averages of the inputs' coordinates.
 */
function getCenter(p1: Position, p2: Position) {
  const midX = (p1.x + p2.x) / 2;
  const midY = (p1.y + p2.y) / 2;
  return { x: midX, y: midY };
}

/**
 * React hook implementing two-finger pinch-zoom (with two-finger panning) for
 * a Konva stage.
 *
 * Attach the returned handlers to the stage's touch-move / touch-end events.
 * On each two-finger move the stage is rescaled around the pinch center and
 * repositioned. The scale is clamped to the range
 * `[stage.width() / stageBaseWidth, maxScale]`; if the bounds conflict, the
 * lower bound wins. The position is clamped so that it is never positive and
 * the scaled content's right/bottom edge never ends up inside the stage.
 *
 * @param stage - The Konva stage to scale, or `null` before it is available.
 * @param stageBaseWidth - Unscaled width of the content drawn on the stage.
 * @param options.stageBaseHeight - Unscaled height of the content. Defaults to
 *   `stageBaseWidth`, matching the earlier behaviour that clamped the vertical
 *   position using the content width.
 * @param options.maxScale - Upper bound for the zoom scale. Defaults to 1.5.
 * @returns `onTouchMove` and `onTouchEnd` handlers.
 */
export const useKonvaStageScale = (
  stage: Konva.Stage | null,
  stageBaseWidth: number,
  options: { stageBaseHeight?: number; maxScale?: number } = {},
) => {
  const { stageBaseHeight = stageBaseWidth, maxScale = 1.5 } = options;

  // Gesture state lives in refs rather than React state: several touchmove
  // events can fire before a re-render, and handlers closing over state would
  // read stale values. Refs also avoid a re-render on every touchmove.
  const lastCenterRef = useRef<Position | null>(null);
  const lastDistRef = useRef<number | null>(null);

  useEffect(() => {
    // Global Konva flag: keep hit detection enabled while a drag is in progress.
    Konva.hitOnDragEnabled = true;
  }, []);

  const onTouchMove = (e: Konva.KonvaEventObject<TouchEvent>) => {
    e.evt.preventDefault();

    const touch1 = e.evt.touches[0];
    const touch2 = e.evt.touches[1];

    // Only two-finger gestures are handled.
    if (!touch1 || !touch2 || !stage) {
      return;
    }

    if (stage.isDragging()) {
      // A two-finger gesture takes over from any in-progress stage drag.
      stage.stopDrag();
    }

    // Touch coordinates are viewport-relative, while stage.x()/stage.y() are
    // relative to the stage container; convert into container space.
    const rect = stage.container().getBoundingClientRect();
    const p1 = { x: touch1.clientX - rect.left, y: touch1.clientY - rect.top };
    const p2 = { x: touch2.clientX - rect.left, y: touch2.clientY - rect.top };

    const newCenter = getCenter(p1, p2);
    const dist = getDistance(p1, p2);

    const lastCenter = lastCenterRef.current;
    if (!lastCenter) {
      // First sample of this gesture: record the baseline only.
      lastCenterRef.current = newCenter;
      lastDistRef.current = dist;
      return;
    }

    // A missing or zero baseline distance (coincident touches) would give an
    // Infinity/NaN ratio; treat it as "no zoom change" instead.
    const lastDist = lastDistRef.current;
    const ratio = lastDist ? dist / lastDist : 1;

    // Content-space point currently under the pinch center, so it can be kept
    // under the center after rescaling.
    const pointTo = {
      x: (newCenter.x - stage.x()) / stage.scaleX(),
      y: (newCenter.y - stage.y()) / stage.scaleY(),
    };

    const minScale = stage.width() / stageBaseWidth;
    // Upper bound first, then lower bound, so the lower bound wins on conflict.
    const scale = Math.max(Math.min(stage.scaleX() * ratio, maxScale), minScale);

    stage.scale({ x: scale, y: scale });

    // Pan by however far the pinch center moved since the previous sample.
    const dx = newCenter.x - lastCenter.x;
    const dy = newCenter.y - lastCenter.y;

    const scaledWidth = stageBaseWidth * scale;
    const scaledHeight = stageBaseHeight * scale;

    const rawX = newCenter.x - pointTo.x * scale + dx;
    const rawY = newCenter.y - pointTo.y * scale + dy;

    // Keep the content's right/bottom edge at or beyond the stage's edge, and
    // never let the position go positive.
    stage.position({
      x: Math.min(Math.max(rawX, stage.width() - scaledWidth), 0),
      y: Math.min(Math.max(rawY, stage.height() - scaledHeight), 0),
    });

    lastDistRef.current = dist;
    lastCenterRef.current = newCenter;
  };

  const onTouchEnd = () => {
    // Reset so the next gesture starts from a fresh baseline.
    lastDistRef.current = null;
    lastCenterRef.current = null;
  };

  return {
    onTouchMove,
    onTouchEnd,
  };
};
