import { flatMap, groupBy, map } from "lodash";
import * as dagre from "dagre";
import {
  Dispatch,
  SetStateAction,
  useCallback,
  useEffect,
  useRef,
  useState,
} from "react";
import {
  Node,
  Edge,
  Position,
  useReactFlow,
  applyNodeChanges,
  NodeChange,
  applyEdgeChanges,
  EdgeChange,
} from "reactflow";

import { isFunc } from "./fn";
import { maybeMap } from "./maybe";

/** Random integer in the inclusive range [0, 400], used as a throwaway coordinate. */
const rand = (): number => {
  const upperBound = 400;
  return Math.round(upperBound * Math.random());
};

/** Options accepted by the layout strategies in this module. */
interface FlowLayoutOptions {
  /** Layout direction; `treeLayout` treats anything other than `"vertical"` as horizontal. */
  orientation?: "vertical" | "horizontal";
}

/**
 * Computes the graph to render from the current nodes and edges.
 *
 * Receives the nodes, the edges and optional layout options, and returns the
 * nodes/edges to display (typically with updated node positions).
 *
 * The `any` default is kept because this file uses `LayoutStrategy` without a
 * type argument in places that rely on it.
 */
type LayoutStrategy<T = any> = (
  graphNodes: Node<T>[],
  graphEdges: Edge<T>[],
  layoutOptions?: FlowLayoutOptions
) => {
  nodes: Node<T>[];
  edges: Edge<T>[];
};

/**
 * Creates a {@link LayoutStrategy} that places every node at a random
 * position with both coordinates in [0, 400].
 *
 * Edges are returned as a new array holding the same edge objects; layout
 * options are ignored.
 */
export const randomLayout = (): LayoutStrategy => {
  const scatter: LayoutStrategy = (nodes, edges) => {
    const scattered = map(nodes, (node) => {
      const position = { x: rand(), y: rand() };
      return { ...node, position };
    });
    const passthroughEdges = map(edges, (edge) => edge);
    return { nodes: scattered, edges: passthroughEdges };
  };
  return scatter;
};

/**
 * Creates a {@link LayoutStrategy} that arranges nodes as a layered graph
 * using dagre.
 *
 * Nodes are grouped by `parentNode` and each group is laid out on its own, so
 * child nodes are positioned only relative to their siblings. Only edges whose
 * source and target both belong to a group influence that group's layout. The
 * edges themselves are returned unchanged.
 *
 * @param sizes - Width/height given to every node, plus optional spacing
 *   multipliers.
 *   - dagre's `ranksep` (gap between ranks) is `10 * vSpace`.
 *   - dagre's `nodesep` (gap between nodes of one rank) is `10 * hSpace`.
 *   - This mapping does not depend on orientation.
 *   - A missing (or `0`) multiplier falls back to 1.
 * @param opts - Layout options. Any orientation other than `"vertical"`
 *   produces a left-to-right layout.
 * @returns A strategy that returns repositioned copies of the nodes, with
 *   handle positions matching the orientation, plus the original edges.
 */
export const treeLayout =
  (
    sizes: { width: number; height: number; hSpace?: number; vSpace?: number },
    opts: FlowLayoutOptions
  ): LayoutStrategy =>
  (nodes, edges) => {
    // These settings are identical for every group, so compute them once.
    const isHorizontal = opts?.orientation !== "vertical";
    const rankdir = isHorizontal ? "LR" : "TB";
    const targetPosition = isHorizontal ? Position.Left : Position.Top;
    const sourcePosition = isHorizontal ? Position.Right : Position.Bottom;
    const ranksep = 10 * (sizes.vSpace || 1);
    const nodesep = 10 * (sizes.hSpace || 1);

    // Nodes without a parent are grouped together under the key "undefined".
    const parentGroup = groupBy(nodes, (n) => n.parentNode);

    const allNodes = flatMap(Object.keys(parentGroup), (p) => {
      const inner = parentGroup[p];
      const innerIds = new Set(inner.map((n) => n.id));

      const dagreGraph = new dagre.graphlib.Graph({
        compound: false,
        directed: true,
        multigraph: true,
      });
      dagreGraph.setDefaultEdgeLabel(() => ({}));
      // Fresh label object per graph: dagre writes layout results into it.
      dagreGraph.setGraph({ rankdir, ranksep, nodesep });

      inner.forEach((node) => {
        dagreGraph.setNode(node.id, {
          width: sizes.width,
          height: sizes.height,
        });
      });

      edges.forEach((edge) => {
        // graphlib's setEdge silently creates missing endpoints. An edge that
        // touches a node outside this group would add a size-less phantom
        // node that skews this group's ranking, so such edges are skipped.
        if (innerIds.has(edge.source) && innerIds.has(edge.target)) {
          dagreGraph.setEdge(edge.source, edge.target);
        }
      });

      dagre.layout(dagreGraph);

      return map(inner, (node) => {
        const { x, y } = dagreGraph.node(node.id);

        return {
          ...node,
          targetPosition,
          sourcePosition,
          // dagre reports the node centre; React Flow anchors nodes at their
          // top-left corner, so shift by half the node size.
          position: {
            x: Math.floor(x - sizes.width / 2),
            y: Math.floor(y - sizes.height / 2),
          },
        };
      });
    });

    return {
      nodes: allNodes,
      edges,
    };
  };

/** The nodes and edges of a React Flow graph, as kept in hook state. */
export type ReactFlowState<T> = {
  nodes: Array<Node<T>>;
  edges: Array<Edge<T>>;
};

/**
 * Holds React Flow nodes and edges in local state, with no automatic layout.
 *
 * `setData` takes a value or an updater, like a `useState` setter.
 * - Falsy `nodes`/`edges` in the new value fall back to the current ones.
 * - Nodes that were selected before the update stay selected.
 * - The viewport is fitted to the content 200 ms after the most recent
 *   `setData` call.
 *
 * Uses `useReactFlow`, so it must run inside React Flow's provider.
 *
 * @returns The current `nodes` and `edges`, `setData`, and
 *   `onNodesChange`/`onEdgesChange` handlers that apply React Flow's
 *   incremental changes.
 */
export function useDefaultState<T>() {
  const { fitView } = useReactFlow();
  const [data, setRawData] = useState<ReactFlowState<T>>({
    nodes: [],
    edges: [],
  });
  // Pending fit-view timer. A burst of updates triggers a single fit, and the
  // timer is cancelled if the component unmounts first.
  const fitTimer = useRef<ReturnType<typeof setTimeout> | undefined>(
    undefined
  );

  useEffect(() => () => clearTimeout(fitTimer.current), []);

  const setData = useCallback<Dispatch<SetStateAction<ReactFlowState<T>>>>(
    (newV) => {
      // State updaters must be side-effect free (React may invoke them twice
      // in StrictMode), so the fitView timer is scheduled outside of it.
      setRawData((prev) => {
        const selected = new Set(
          maybeMap(prev.nodes, (n) => (n.selected ? n.id : undefined))
        );
        const next = isFunc(newV) ? newV(prev) : newV;
        const nodes = next.nodes || prev.nodes;
        const edges = next.edges || prev.edges;

        return {
          nodes: map(nodes, (n) => ({
            ...n,
            selected: n.selected || selected.has(n.id),
          })),
          edges,
        };
      });

      clearTimeout(fitTimer.current);
      fitTimer.current = setTimeout(
        () => fitView({ duration: 100, padding: 0.1, maxZoom: 1 }),
        200
      );
    },
    [fitView]
  );

  const onNodesChange = useCallback(
    (changes: NodeChange[]) =>
      setRawData((d) => ({ ...d, nodes: applyNodeChanges(changes, d.nodes) })),
    []
  );

  const onEdgesChange = useCallback(
    (changes: EdgeChange[]) =>
      setRawData((d) => ({ ...d, edges: applyEdgeChanges(changes, d.edges) })),
    []
  );

  return {
    nodes: data.nodes,
    edges: data.edges,
    setData,
    onNodesChange,
    onEdgesChange,
  };
}

/**
 * Holds React Flow nodes and edges in local state and runs `strategy` on every
 * `setData` call to position the nodes.
 *
 * `setData` takes a value or an updater, like a `useState` setter.
 * - Falsy `nodes`/`edges` in the new value fall back to the current ones
 *   before the strategy runs.
 * - Nodes that were selected before the update stay selected.
 * - The viewport is fitted to the content 200 ms after the most recent
 *   `setData` call.
 *
 * `onNodesChange`/`onEdgesChange` apply React Flow's incremental changes
 * without re-running the strategy.
 *
 * Uses `useReactFlow`, so it must run inside React Flow's provider.
 *
 * @param strategy - Layout applied on each `setData`. NOTE(review): if the
 *   caller creates it during render (e.g. an inline `treeLayout(...)`),
 *   `setData` changes identity every render; consider memoizing it.
 * @param _data - Optional initial graph. It seeds the state as-is and is then
 *   passed through `setData` after mount and whenever its identity changes.
 * @returns The current `nodes` and `edges`, `setData`, and the change handlers.
 */
export function useLayoutStrategy<T>(
  strategy: LayoutStrategy,
  _data?: ReactFlowState<T>
) {
  const { fitView } = useReactFlow();
  const [data, setRawData] = useState<ReactFlowState<T>>(
    () =>
      _data || {
        nodes: [],
        edges: [],
      }
  );
  // Pending fit-view timer. A burst of updates triggers a single fit, and the
  // timer is cancelled if the component unmounts first.
  const fitTimer = useRef<ReturnType<typeof setTimeout> | undefined>(
    undefined
  );

  useEffect(() => () => clearTimeout(fitTimer.current), []);

  const setData = useCallback<Dispatch<SetStateAction<ReactFlowState<T>>>>(
    (newV) => {
      // State updaters must be side-effect free (React may invoke them twice
      // in StrictMode), so only the state computation lives inside it.
      setRawData((prev) => {
        const selected = new Set(
          maybeMap(prev.nodes, (n) => (n.selected ? n.id : undefined))
        );
        const next = isFunc(newV) ? newV(prev) : newV;

        const { nodes, edges } = strategy(
          next.nodes || prev.nodes,
          next.edges || prev.edges
        );

        return {
          nodes: map(nodes, (n) => ({
            ...n,
            selected: n.selected || selected.has(n.id),
          })),
          edges,
        };
      });

      clearTimeout(fitTimer.current);
      fitTimer.current = setTimeout(
        () => fitView({ duration: 100, padding: 0.1, maxZoom: 1 }),
        200
      );
    },
    [fitView, strategy]
  );

  const onNodesChange = useCallback(
    (changes: NodeChange[]) =>
      setRawData((d) => ({ ...d, nodes: applyNodeChanges(changes, d.nodes) })),
    []
  );

  const onEdgesChange = useCallback(
    (changes: EdgeChange[]) =>
      setRawData((d) => ({ ...d, edges: applyEdgeChanges(changes, d.edges) })),
    []
  );

  useEffect(() => {
    if (_data) {
      setData(_data);
    }
    // Intentionally keyed on `_data` only. `setData` changes whenever
    // `strategy` does, and re-running on that would loop (setData -> render ->
    // new strategy -> setData ...) for callers that build the strategy during
    // render.
  }, [_data]);

  return {
    nodes: data.nodes,
    edges: data.edges,
    setData,
    onNodesChange,
    onEdgesChange,
  };
}
