import { TreeProps } from 'antd';
import type { DataNode } from 'rc-tree/lib/interface';
import { useCallback, useMemo } from 'react';
import { dataNodeUtils } from '../Forms/FormsEditTree';
import { Tree } from './Tree';

/**
 * Props accepted by `CheckableTree`.
 *
 * Starts from antd's `TreeProps<T>` and removes the checking-related props
 * (`checkStrictly`, `checkable`, `checkedKeys`, `defaultCheckedKeys`,
 * `onCheck`), replacing them with the simplified, string-keyed contract
 * declared below.
 */
export type CheckableTreeProps<T extends DataNode = DataNode> = Omit<
  TreeProps<T>,
  | 'checkStrictly'
  | 'checkable'
  | 'checkedKeys'
  | 'defaultCheckedKeys'
  | 'onCheck'
> & {
  /** Called with the complete list of checked keys computed after a node is toggled. */
  onCheck: (keys: string[]) => any;
  /**
   * Keys that are currently checked. Node keys are compared against these
   * after being converted with `toString()`; every descendant of a listed node
   * is rendered as checked as well.
   */
  checkedKeys: string[];
  /**
   * Not read by the code in this file; it is forwarded to the underlying tree
   * together with the remaining props.
   * NOTE(review): the purpose of the single `'any'` literal is not visible
   * here — confirm against callers / `Tree`.
   */
  mode: 'any';
};

/**
 * Derives what is handed to the underlying tree from the caller's props.
 *
 * - `data`: a copy of `treeData` in which every node has `checkable: true`
 *   (children are copied recursively; the input nodes are not mutated).
 * - `checked`: the keys (as strings) of every node that is listed in
 *   `checkedKeys` or has an ancestor listed there, in post-order (children
 *   before their parent). Keys in `checkedKeys` that match no node are dropped.
 *
 * When `treeData` is missing, `data` is passed through as-is and `checked` is
 * the untouched `checkedKeys` array.
 *
 * The result is memoised on `treeData` and `checkedKeys`.
 */
function useData<T extends DataNode = DataNode>(props: CheckableTreeProps<T>) {
  const { treeData, checkedKeys } = props;

  return useMemo(() => {
    if (!treeData) {
      return { data: treeData, checked: checkedKeys };
    }

    // A Set makes each membership test O(1); together with the inherited
    // `ancestorChecked` flag this keeps the walk linear in the number of
    // nodes instead of re-scanning the ancestor path for every node.
    const explicitKeys = new Set(checkedKeys);
    const checked: string[] = [];

    function map(node: T, ancestorChecked: boolean): T {
      const key = node.key.toString();
      // A node counts as checked when it, or any node above it, is listed.
      const covered = ancestorChecked || explicitKeys.has(key);

      // Children are visited before the node's own key is recorded, which
      // yields the post-order of `checked`.
      const children = node.children?.map((childNode) =>
        map(childNode as T, covered),
      );

      if (covered) {
        checked.push(key);
      }

      return {
        ...node,
        checkable: true,
        children,
      };
    }

    const data = treeData.map((node) => map(node, false));
    return { data, checked };
  }, [treeData, checkedKeys]);
}

/**
 * Builds the `onCheck` handler passed to the underlying tree.
 *
 * When a node is toggled the handler computes the next explicit key list and
 * reports it through `props.onCheck`:
 *
 * 1. If a node on the toggled node's path is explicitly checked, the keys of
 *    that node's flattened subtree (minus the toggled node) are added, so the
 *    selection it implied becomes explicit.
 * 2. The toggled node, everything in its flattened subtree, and every node on
 *    its path are removed.
 * 3. When the node is being checked, its own key is added back, together with
 *    each ancestor whose other children are all explicitly checked (walking
 *    upwards until one ancestor fails that test).
 *
 * Node keys are normalised with `toString()` before being compared with
 * `checkedKeys`, so numeric keys match their string form, and the emitted list
 * contains each key once.
 *
 * NOTE(review): this assumes `dataNodeUtils.findPathInAny` returns the
 * root-to-node path including the node itself (or a falsy value when not
 * found) and `dataNodeUtils.flatten` returns a node together with all of its
 * descendants — confirm against `FormsEditTree`.
 */
function useHandleCheck<T extends DataNode = DataNode>(
  props: CheckableTreeProps<T>,
) {
  const { checkedKeys, onCheck, treeData } = props;

  const handleCheck: TreeProps<T>['onCheck'] = (_, info) => {
    const explicitKeys = new Set(checkedKeys);

    // `checkedKeys` holds strings while node keys may be numbers, so compare
    // on the string form.
    function isExplicitlyChecked(node: DataNode): boolean {
      return explicitKeys.has(node.key.toString());
    }

    // Keys implied by the first explicitly checked node on the path, without
    // the node that is being toggled.
    function findImplicitlyCheckedNodesByParent(path: DataNode[]) {
      const firstCheckedParent = path.find((node) => isExplicitlyChecked(node));

      return firstCheckedParent
        ? dataNodeUtils
            .flatten(firstCheckedParent)
            .filter((x) => x.key !== info.node.key)
            .map((x) => x.key.toString())
        : [];
    }

    // True when every child of `node` other than `knownCheckedChildren` is
    // explicitly checked. A node without children never qualifies.
    function areAllChildrenChecked(
      node: DataNode,
      knownCheckedChildren: DataNode,
    ): boolean {
      if (!node.children || node.children.length === 0) {
        return false;
      }

      return node.children
        .filter((x) => x.key !== knownCheckedChildren.key)
        .every((child) => isExplicitlyChecked(child));
    }

    // Walks from the toggled node towards the root and collects ancestors
    // that become fully checked, stopping at the first one that does not.
    function findImplicitlyCheckedParents(path: DataNode[]): DataNode[] {
      const bottomUp = path.slice().reverse();
      const result: DataNode[] = [];
      let child = bottomUp[0];

      for (const parent of bottomUp.slice(1)) {
        if (!child || !areAllChildrenChecked(parent, child)) {
          break;
        }
        result.push(parent);
        child = parent;
      }

      return result;
    }

    const path = treeData
      ? dataNodeUtils.findPathInAny(treeData, info.node.key)
      : undefined;
    const materialisedKeys = path ? findImplicitlyCheckedNodesByParent(path) : [];
    const implicitlyCheckedParents = path
      ? findImplicitlyCheckedParents(path).map((x) => x.key.toString())
      : [];

    // Everything in the toggled subtree and on the path is dropped first; the
    // toggled node and qualifying ancestors are re-added below when checking.
    const clearedKeys = new Set(
      [...dataNodeUtils.flatten(info.node), ...(path ?? [])].map((n) =>
        n.key.toString(),
      ),
    );
    const keptKeys = [...checkedKeys, ...materialisedKeys].filter(
      (k) => !clearedKeys.has(k),
    );

    const nextKeys = info.checked
      ? [...keptKeys, ...implicitlyCheckedParents, info.node.key.toString()]
      : keptKeys;

    // Materialised keys may repeat keys that were already explicit; emit each
    // key once, keeping first-occurrence order.
    onCheck([...new Set(nextKeys)]);
  };

  return useCallback(handleCheck, [checkedKeys, onCheck, treeData]);
}

/**
 * Tree whose checked state is driven entirely by `checkedKeys` / `onCheck`.
 *
 * The caller's `onCheck`, `treeData` and `checkedKeys` are not forwarded
 * directly: the tree receives the data and checked keys derived by `useData`
 * and the handler built by `useHandleCheck`. Every other prop is passed
 * through unchanged.
 */
export function CheckableTree<T extends DataNode = DataNode>(
  props: CheckableTreeProps<T>,
) {
  // Pulled out only so they are excluded from the pass-through props.
  const {
    onCheck: _callerOnCheck,
    treeData: _callerTreeData,
    checkedKeys: _callerCheckedKeys,
    ...passThroughProps
  } = props;

  const derived = useData(props);
  const handleCheck = useHandleCheck(props);

  return (
    <Tree
      {...passThroughProps}
      treeData={derived.data}
      checkable
      checkedKeys={derived.checked}
      onCheck={handleCheck}
    />
  );
}
