import {
  forwardRef,
  useCallback,
  KeyboardEvent,
  ComponentPropsWithRef,
  useReducer,
  useLayoutEffect,
  ReactNode,
  Reducer,
  Ref,
  useEffect,
  useState,
  ReactElement,
} from "react";
import styled from "styled-components";
import update from "immutability-helper";

import { useUniqueId } from "../../../hooks";
import { propagateRef } from "../../../lib";
import {
  ArrowDownIcon,
  ArrowRightIcon,
  CheckboxCheckedIcon,
  CheckboxUncheckedIcon,
} from "../../../icons";
import computeScrollIntoView from "compute-scroll-into-view";
import { Skeleton } from "../../Skeleton";

/**
 * Outer `role="tree"` list. Highlights whichever row's id matches the list's
 * `aria-activedescendant`, and draws a focus ring on it while navigating by
 * keyboard.
 *
 * The id in the attribute selector is quoted. With no active descendant it
 * yields `[id=""]` (a valid selector that matches nothing) instead of the
 * invalid `[id=]`. Quoting also keeps the rule valid when the generated id is
 * not a bare CSS identifier.
 */
const TreeContainer = styled.ul<{ cursorMode: "mouse" | "keyboard" }>`
  list-style: none;
  width: 100%;

  &:focus {
    outline: none;
  }

  [id="${(props) => props["aria-activedescendant"]}"] {
    background: ${(props) => props.theme.SelectOptionHoverBackground};

    ${(props) =>
      props.cursorMode === "keyboard" && `box-shadow: ${props.theme.FocusRing}`}
  }
`;

/**
 * A single `treeitem` row. The `.nesting-spacer` child is widened by one
 * `--spacing-sm` step per level below the root, producing the indentation.
 */
const TreeItem = styled.li<{ "aria-level": number }>`
  display: flex;
  .nesting-spacer {
    width: calc(${({ "aria-level": depth }) => depth - 1} * var(--spacing-sm));
  }
`;

/**
 * Clickable label area of a row: spacer, optional expand/collapse arrow, and
 * the node's label (or its loading/error content). It grows to fill the row.
 */
const TreeItemLabel = styled.div`
  font-size: inherit;
  cursor: default;
  user-select: none;

  flex-grow: 10;

  display: flex;
  align-items: center;

  padding: ${({ theme }) => theme.SecondaryPadding};

  .arrow {
    margin-right: var(--spacing-xs);
    color: ${({ theme }) => theme.ControlIconColor};
  }
`;

/**
 * Wrapper around a row's checkbox icon; its colour reflects `checked`.
 */
const TreeItemSelectBox = styled.div<{ checked: boolean }>`
  padding: ${({ theme }) => theme.SecondaryMediumPadding};

  border-radius: ${({ theme }) => theme.CornerInsetRadius};
  color: ${({ checked, theme }) =>
    checked ? theme.ControlOnBackground : theme.ControlIconColor};
`;

/**
 * A node of the data tree passed to `TreeView` via `roots`.
 */
export interface TreeNode<T> {
  /** Text shown for the node. */
  label: string;
  /** Arbitrary payload handed to `isSelected` / `onSelect`. */
  value: T;
  /**
   * Either the children themselves, or a callback returning a promise of them
   * for lazily loaded subtrees. When absent the node is rendered as a leaf.
   */
  children?: TreeNode<T>[] | (() => Promise<TreeNode<T>[]>);
  /**
   * Number of children. It sizes the "loading" placeholders shown while a
   * lazy `children` callback is pending; a value of `0` marks the node as a
   * leaf even if `children` is set.
   */
  childCount?: number;
}

/**
 * Bookkeeping shared by every row of the flattened tree.
 *  - `htmlId`: DOM id of the row (also used as its identity in state)
 *  - `level`: 1-based depth, rendered as `aria-level`
 *  - `childNodeIds`: ids of the currently exposed direct children
 *  - `siblingCount` / `siblingIndex`: rendered as `aria-setsize` / `aria-posinset`
 */
type ViewNodeBase = {
  htmlId: string;
  level: number;
  childNodeIds: string[];
  siblingCount: number;
  siblingIndex: number;
};

/**
 * One row of the flattened tree. Placeholder rows ("loading"/"error") carry no
 * TreeNode; real rows are "open", "closed", or "end" (leaf).
 */
type ViewNode<T> =
  | (ViewNodeBase & { type: "loading"; treeNode: undefined })
  | (ViewNodeBase & { type: "error"; treeNode: undefined; error: ReactNode })
  | (ViewNodeBase & {
      type: "open" | "closed" | "end";
      treeNode: TreeNode<T>;
    });

/**
 * Reducer state for `TreeView`.
 *  - `cursor`: index into `exposedNodes` of the active row
 *  - `cursorMode`: whether the cursor was last moved by mouse or keyboard
 *  - `exposedNodes`: the visible rows, flattened in display order
 */
interface TreeState<T> {
  cursor: number;
  cursorMode: "mouse" | "keyboard";
  exposedNodes: ViewNode<T>[];
}

/**
 * Dispatch function handed to actions that may need to dispatch follow-ups
 * later (for example, once lazily loaded children arrive).
 */
type TreeDispatch<T> = (a: TreeAction<T>) => void;

/**
 * Actions understood by `reducer`, encoded as labelled tuples whose first
 * element names the action.
 */
type TreeAction<T> =
  // cursor movement
  | [type: "MOVE_CURSOR_TO", position: number]
  | [type: "MOVE_CURSOR_BY", offset: number]
  | [type: "MOVE_CURSOR_TO_END"]
  | [type: "SET_CURSOR_MODE", mode: "mouse" | "keyboard"]
  // opening / closing nodes
  | [type: "TOGGLE_NODE_AT_CURSOR", dispatch: TreeDispatch<T>]
  | [type: "ARROW_RIGHT", dispatch: TreeDispatch<T>]
  | [type: "ARROW_LEFT"]
  // outcomes of a lazy `children()` callback
  | [
      type: "HYDRATE_VIEW_NODES",
      data: { parentNodeId: string; treeNodes: TreeNode<T>[] }
    ]
  | [
      type: "ERROR_LOADING_NODES",
      data: { parentNodeId: string; error: ReactNode }
    ];

/**
 * Place the cursor at `position`. Out-of-range positions leave `state`
 * untouched, so movement simply stops at the first and last rows.
 */
function moveCursorTo<T>(state: TreeState<T>, position: number): TreeState<T> {
  const rowCount = state.exposedNodes.length;
  const outOfRange = position >= rowCount || position < 0;
  if (outOfRange) {
    return state;
  }
  return update(state, { cursor: { $set: position } });
}

/**
 * Record whether the cursor is currently being driven by the mouse or the
 * keyboard (the keyboard mode additionally shows a focus ring).
 */
function setCursorMode<T>(
  state: TreeState<T>,
  mode: "mouse" | "keyboard"
): TreeState<T> {
  const nextState = update(state, { cursorMode: { $set: mode } });
  return nextState;
}

/**
 * Open the "closed" node under the cursor, inserting view nodes for its
 * children directly after it in `exposedNodes`.
 *
 * `TreeNode.children` may be:
 *  a) `TreeNode<T>[]` — the children are known up front and are inserted
 *     immediately. `childCount` is not required in this case.
 *  b) `() => Promise<TreeNode<T>[]>` — "loading" placeholders are inserted
 *     (`childCount` of them, or one if `childCount` is not given), and the
 *     callback is invoked. Its result is delivered later through
 *     HYDRATE_VIEW_NODES or ERROR_LOADING_NODES via `dispatch`.
 *
 * Returns `state` unchanged when the cursor is not on a closed node.
 *
 * NOTE(review): invoking `children()` here makes this reducer branch
 * side-effecting; React may call reducers twice in StrictMode development
 * builds, which would invoke the loader twice.
 */
function openNodeAtCursor<T>(
  state: TreeState<T>,
  dispatch: (t: TreeAction<T>) => void
): TreeState<T> {
  const { cursor, exposedNodes } = state;
  if (cursor < 0 || cursor >= exposedNodes.length) return state;

  const selectedNode = exposedNodes[cursor];
  if (selectedNode.type !== "closed") return state;

  const { children, childCount } = selectedNode.treeNode;
  const childLevel = selectedNode.level + 1;
  const childHtmlId = (position: number) =>
    `${selectedNode.htmlId}_${position}`;

  let childrenNodes: ViewNode<T>[];

  if (Array.isArray(children)) {
    // children are already known: build real view nodes right away
    childrenNodes = children.map(
      (treeNode, index, siblings): ViewNode<T> => ({
        type:
          treeNode.children && treeNode.childCount !== 0 ? "closed" : "end",
        treeNode,
        htmlId: childHtmlId(index + 1),
        level: childLevel,
        siblingIndex: index + 1,
        siblingCount: siblings.length,
        childNodeIds: [],
      })
    );
  } else if (typeof children === "function") {
    // children must be loaded: show placeholders until the promise settles
    const placeholderCount = childCount ?? 1;
    childrenNodes = Array.from(
      { length: placeholderCount },
      (_, index): ViewNode<T> => ({
        type: "loading",
        treeNode: undefined,
        htmlId: childHtmlId(index + 1),
        level: childLevel,
        siblingIndex: index + 1,
        siblingCount: placeholderCount,
        childNodeIds: [],
      })
    );

    children().then(
      (childTreeNodes) => {
        dispatch([
          "HYDRATE_VIEW_NODES",
          {
            parentNodeId: selectedNode.htmlId,
            treeNodes: childTreeNodes,
          },
        ]);
      },
      (error) => {
        dispatch([
          "ERROR_LOADING_NODES",
          {
            parentNodeId: selectedNode.htmlId,
            // the error is rendered as a React child; Error objects are not
            // renderable, so show their message instead
            error: error instanceof Error ? error.message : error,
          },
        ]);
      }
    );
  } else {
    return state;
  }

  return update(state, {
    exposedNodes: {
      [cursor]: {
        type: { $set: "open" },
        childNodeIds: { $set: childrenNodes.map((node) => node.htmlId) },
      },
      $splice: [[cursor + 1, 0, ...childrenNodes]],
    },
  });
}

/**
 * Replace every exposed descendant of the node `parentNodeId` with the view
 * nodes produced by `buildChildren`. This is shared by `hydrateViewNodes` and
 * `errorLoadingNodes`, which both swap out the "loading" placeholders inserted
 * by `openNodeAtCursor`.
 *
 * Returns `state` unchanged when the parent is no longer exposed or no longer
 * open (for example, it was closed while its children were loading).
 */
function replaceChildViewNodes<T>(
  state: TreeState<T>,
  parentNodeId: string,
  buildChildren: (parentNode: ViewNode<T>) => ViewNode<T>[]
): TreeState<T> {
  const { exposedNodes } = state;
  const parentPosition = exposedNodes.findIndex(
    (viewNode) => viewNode.htmlId === parentNodeId
  );
  // re-read the parent from current state; the action only carries its id
  const parentNode: ViewNode<T> | undefined = exposedNodes[parentPosition];

  if (!parentNode) return state;
  if (parentNode.type !== "open") return state;

  /**
   * Count every exposed descendant, not just the direct children. If a child
   * was opened before this result arrived (e.g. a stale request resolving
   * after the parent was closed and re-opened), its own descendants must be
   * removed too, otherwise they would be left behind under the wrong parent.
   */
  let removedCount = 0;
  while (
    parentPosition + 1 + removedCount < exposedNodes.length &&
    exposedNodes[parentPosition + 1 + removedCount].level > parentNode.level
  ) {
    removedCount++;
  }

  const childViewNodes = buildChildren(parentNode);

  /**
   * Keep the cursor valid. Rows after the replaced range shift by the change
   * in length. A cursor inside the replaced range is clamped to the last new
   * child, or to the parent if there are no children.
   */
  const lastRemovedPosition = parentPosition + removedCount;
  let cursor = state.cursor;
  if (cursor > lastRemovedPosition) {
    cursor += childViewNodes.length - removedCount;
  } else if (cursor > parentPosition) {
    cursor = Math.min(cursor, parentPosition + childViewNodes.length);
  }

  return update(state, {
    cursor: { $set: cursor },
    exposedNodes: {
      // record the ids of the new children on the parent
      [parentPosition]: {
        childNodeIds: { $set: childViewNodes.map((node) => node.htmlId) },
      },
      // drop the old descendants and splice in the new children
      $splice: [[parentPosition + 1, removedCount, ...childViewNodes]],
    },
  });
}

/**
 * Fill in the children of `parentNodeId` once its `children()` callback has
 * resolved, replacing the "loading" placeholders.
 */
function hydrateViewNodes<T>(
  state: TreeState<T>,
  parentNodeId: string,
  treeNodes: TreeNode<T>[]
): TreeState<T> {
  return replaceChildViewNodes(state, parentNodeId, (parentNode) =>
    treeNodes.map(
      (node, index): ViewNode<T> => ({
        type: node.children && node.childCount !== 0 ? "closed" : "end",
        treeNode: node,
        htmlId: `${parentNode.htmlId}_${index + 1}`,
        level: parentNode.level + 1,
        siblingIndex: index + 1,
        siblingCount: treeNodes.length,
        childNodeIds: [],
      })
    )
  );
}

/**
 * Replace the "loading" placeholders under `parentNodeId` with a single
 * "error" row displaying `error`, after its `children()` callback rejected.
 */
function errorLoadingNodes<T>(
  state: TreeState<T>,
  parentNodeId: string,
  error: ReactNode
): TreeState<T> {
  return replaceChildViewNodes(state, parentNodeId, (parentNode) => [
    {
      type: "error",
      error,
      treeNode: undefined,
      htmlId: `${parentNode.htmlId}_1`,
      level: parentNode.level + 1,
      siblingIndex: 1,
      siblingCount: 1,
      childNodeIds: [],
    },
  ]);
}

/**
 * Close the "open" node under the cursor by removing the contiguous run of
 * deeper-level rows that follows it in `exposedNodes`. The node's
 * `childNodeIds` is reset to an empty list.
 */
function closeNodeAtCursor<T>(state: TreeState<T>): TreeState<T> {
  const { exposedNodes, cursor } = state;
  if (cursor >= exposedNodes.length) return state;

  const target = exposedNodes[cursor];
  if (target.type !== "open") return state;

  // descendants sit directly after the node and have a greater level
  const firstDescendant = cursor + 1;
  let firstNonDescendant = firstDescendant;
  while (
    firstNonDescendant < exposedNodes.length &&
    exposedNodes[firstNonDescendant].level > target.level
  ) {
    firstNonDescendant += 1;
  }
  const descendantCount = firstNonDescendant - firstDescendant;

  return update(state, {
    exposedNodes: {
      [cursor]: { type: { $set: "closed" }, childNodeIds: { $set: [] } },
      $splice: [[firstDescendant, descendantCount]],
    },
  });
}

/**
 * ArrowRight per the ARIA tree pattern: open a closed node, move to the first
 * child of an open node, and do nothing otherwise.
 *
 * The cursor may not point at any row (for example, an empty `roots` array),
 * so that case is a no-op instead of a TypeError.
 */
function arrowRight<T>(
  state: TreeState<T>,
  dispatch: (a: TreeAction<T>) => void
): TreeState<T> {
  const currentNode: ViewNode<T> | undefined =
    state.exposedNodes[state.cursor];
  if (!currentNode) return state;

  if (currentNode.type === "closed") {
    // open a closed node
    return openNodeAtCursor(state, dispatch);
  } else if (currentNode.type === "open") {
    // move to the first child of an open node
    return moveCursorTo(state, state.cursor + 1);
  } else {
    return state;
  }
}

/**
 * ArrowLeft per the ARIA tree pattern: close an open node, otherwise move to
 * the parent of a non-root node; do nothing on closed/end root nodes.
 *
 * The cursor may not point at any row (for example, an empty `roots` array),
 * so that case is a no-op instead of a TypeError.
 */
function arrowLeft<T>(state: TreeState<T>): TreeState<T> {
  const currentNode: ViewNode<T> | undefined =
    state.exposedNodes[state.cursor];
  if (!currentNode) return state;

  if (currentNode.type === "open") {
    // close an open node
    return closeNodeAtCursor(state);
  } else if (currentNode.level === 1) {
    // do nothing on root nodes
    return state;
  } else {
    // the parent is the nearest preceding row with a smaller level; the
    // `> 0` bound keeps the scan inside the array
    let parentSearchCursor = state.cursor;
    while (
      parentSearchCursor > 0 &&
      state.exposedNodes[parentSearchCursor].level >= currentNode.level
    ) {
      parentSearchCursor--;
    }

    return moveCursorTo(state, parentSearchCursor);
  }
}

/**
 * Reducer driving `TreeView`. Each action is a labelled tuple whose first
 * element selects the handler below.
 *
 * TOGGLE_NODE_AT_CURSOR tolerates a cursor that points at no row (e.g. an
 * empty tree) by returning `state` unchanged.
 */
export function reducer<T>(
  state: Readonly<TreeState<T>>,
  action: Readonly<TreeAction<T>>
): TreeState<T> {
  switch (action[0]) {
    case "MOVE_CURSOR_TO":
      return moveCursorTo(state, action[1]);
    case "MOVE_CURSOR_BY":
      return moveCursorTo(state, state.cursor + action[1]);
    case "MOVE_CURSOR_TO_END":
      return moveCursorTo(state, state.exposedNodes.length - 1);
    case "SET_CURSOR_MODE":
      return setCursorMode(state, action[1]);
    case "TOGGLE_NODE_AT_CURSOR": {
      const currentNode: ViewNode<T> | undefined =
        state.exposedNodes[state.cursor];
      if (currentNode?.type === "closed") {
        return openNodeAtCursor(state, action[1]);
      } else if (currentNode?.type === "open") {
        return closeNodeAtCursor(state);
      } else {
        return state;
      }
    }
    case "HYDRATE_VIEW_NODES":
      return hydrateViewNodes(
        state,
        action[1].parentNodeId,
        action[1].treeNodes
      );
    case "ERROR_LOADING_NODES":
      return errorLoadingNodes(state, action[1].parentNodeId, action[1].error);
    case "ARROW_LEFT":
      return arrowLeft(state);
    case "ARROW_RIGHT":
      return arrowRight(state, action[1]);
  }
}

/**
 * Props that make the tree selectable; both callbacks must be provided
 * together.
 */
interface SelectableProps<T> {
  /**
   * isSelected should return `undefined` when the value is not selectable at all
   */
  isSelected: (value: T) => boolean | undefined;
  /** Called with the value and its requested new selection state. */
  onSelect: (value: T, selected: boolean) => void;
}

/**
 * Props for a tree without selection: neither selection callback may be
 * passed.
 */
interface NonSelectableProps {
  isSelected?: never;
  onSelect?: never;
}

/** Props specific to `TreeView` (beyond those of a plain `<ul>`). */
type TreeViewOwnProps<T> = {
  /** Top-level nodes of the tree. */
  roots: TreeNode<T>[];
  /** When set, leaf rows are selected by clicking them instead of a checkbox. */
  singleSelect?: boolean;
  /** Notified with the id of the active row (or `undefined` if there is none). */
  onActiveDescendantChange?: (activeDescendantId: string | undefined) => void;
};

/**
 * Full props: `<ul>` props (minus the DOM `onSelect`), the tree's own props,
 * and either both selection callbacks or neither.
 */
type TreeViewProps<T> = Omit<ComponentPropsWithRef<"ul">, "onSelect"> &
  TreeViewOwnProps<T> &
  (SelectableProps<T> | NonSelectableProps);

/**
 * ARIA tree widget (`role="tree"`).
 *
 * The tree is rendered as a flat list of `treeitem` rows. The active row is
 * tracked with a cursor and exposed via `aria-activedescendant`, not DOM
 * focus. Passing `isSelected`/`onSelect` makes it selectable: with checkboxes
 * by default, or by clicking leaf rows when `singleSelect` is set.
 */
export const TreeView = forwardRef(function <T>(
  props: TreeViewProps<T>,
  forwardedRef: Ref<HTMLUListElement>
) {
  let {
    roots,
    isSelected,
    singleSelect,
    onSelect,
    onActiveDescendantChange,
    ...ulProps
  } = props;

  /**
   * Keep track of the list element so we can scroll elements into view
   * if it is currently focused
   */
  let [listElement, setListElement] = useState<HTMLUListElement | null>(null);

  /**
   * create a unique id for this TreeView component, to create uniquely
   * distinguishable IDs for it and its children (which are necessary for certain
   * ARIA requirements)
   */
  let treeId = useUniqueId();

  let [state, dispatch] = useReducer<Reducer<TreeState<T>, TreeAction<T>>>(
    reducer,
    {
      /**
       * the cursor represents the index of the currently "focused" node.
       * rather than using DOM focus, we exclusively use this cursor to track
       * which node in the tree is "active" and will be the target of any
       * user interactions.
       *
       * to the DOM, this cursor is represented in the `aria-activedescendant`
       * property of the outermost element of the TreeView, which will take the ID
       * of the node under this cursor.
       */
      cursor: 0,
      cursorMode: "mouse",

      /**
       * "focus"/activedescendant and keyboard navigation will both be easier to manage on
       * a flat list than they would be in a recursive DOM structure. This makes certain
       * other aspects of the component less straightforward, but the keyboard navigation
       * will be simple enough to make up for it.
       *
       * So, we will do our bookkeeping on an internal list of "ViewNodes", which are just
       * structs containing all the information necessary to render a single node in our tree.
       */
      exposedNodes: roots.map((root, index) => ({
        type: root.children && root.childCount !== 0 ? "closed" : "end",
        treeNode: root,
        htmlId: `TreeView_${treeId}-node_${index + 1}`,
        level: 1,
        siblingIndex: index + 1,
        siblingCount: roots.length,
        childNodeIds: [],
      })),
    }
  );

  /**
   * The row under the cursor. This is `undefined` when there are no rows
   * (empty `roots`), so every read below must tolerate that.
   */
  let currentNode: ViewNode<T> | undefined = state.exposedNodes[state.cursor];

  /**
   * Keyboard navigation in a TreeView is specified by ARIA as follows.
   *  Right arrow:
   *   - When focus is on a closed node, opens the node; focus does not move.
   *   - When focus is on a open node, moves focus to the first child node.
   *   - When focus is on an end node, does nothing.
   *  Left arrow:
   *   - When focus is on an open node, closes the node.
   *   - When focus is on a child node that is also either an end node or
   *     a closed node, moves focus to its parent node.
   *   - When focus is on a root node that is also either an end node or
   *     a closed node, does nothing.
   *  Down Arrow: Moves focus to the next node that is focusable without opening
   *              or closing a node.
   *  Up Arrow: Moves focus to the previous node that is focusable without
   *            opening or closing a node.
   *  Home: Moves focus to the first node in the tree without opening or closing
   *        a node.
   *  End: Moves focus to the last node in the tree that is focusable without
   *       opening a node.
   *  Enter: activates a node, i.e., performs its default action.
   *   - For parent nodes, one possible default action is to open or close the
   *     node.
   *   - In single-select trees where selection does not follow focus,
   *     the default action is typically to select the focused node.
   */
  let handleKeyDown = useCallback(
    (evt: KeyboardEvent<HTMLUListElement>) => {
      dispatch(["SET_CURSOR_MODE", "keyboard"]);

      switch (evt.key) {
        case "ArrowRight":
          dispatch(["ARROW_RIGHT", dispatch]);
          break;
        case "ArrowLeft":
          dispatch(["ARROW_LEFT"]);
          break;
        case "ArrowDown":
          dispatch(["MOVE_CURSOR_BY", 1]);
          break;
        case "ArrowUp":
          dispatch(["MOVE_CURSOR_BY", -1]);
          break;
        case "Home":
          dispatch(["MOVE_CURSOR_TO", 0]);
          break;
        case "End":
          dispatch(["MOVE_CURSOR_TO_END"]);
          break;
        case " ": {
          let selected =
            currentNode?.treeNode &&
            isSelected &&
            isSelected(currentNode.treeNode.value);

          if (currentNode?.treeNode && onSelect && selected != null) {
            onSelect(currentNode.treeNode.value, !selected);
          }

          break;
        }
        default:
          /**
           * we're going to preventDefault on the event, since all the keys
           * we're handling also have default behaviors around scrolling. Any
           * unhandled keyDown events should proceed as normal though!
           */
          return;
      }

      evt.preventDefault();
    },
    [onSelect, isSelected, currentNode]
  );

  useLayoutEffect(() => {
    if (
      document.activeElement &&
      document.activeElement === listElement &&
      state.cursorMode === "keyboard" &&
      currentNode?.htmlId
    ) {
      /**
       * Use getElementById rather than querySelector(`#${id}`). The id embeds
       * the value from useUniqueId, and an id that is not a valid CSS
       * identifier would make querySelector throw.
       */
      const element = document.getElementById(currentNode.htmlId);
      if (element) {
        const actions = computeScrollIntoView(element, {
          scrollMode: "if-needed",
          block: "nearest",
          inline: "nearest",
        });
        actions.forEach(({ el, top }) => {
          el.scrollTop = top;
        });
      }
    }
  }, [currentNode?.htmlId, state.cursorMode, listElement]);

  useEffect(() => {
    onActiveDescendantChange && onActiveDescendantChange(currentNode?.htmlId);
  }, [onActiveDescendantChange, currentNode?.htmlId]);

  /**
   * Stable callback ref. An inline arrow would be a new function on every
   * render, so React would detach/re-attach it (calling it with null, then the
   * element) on each commit.
   */
  let handleListRef = useCallback(
    (el: HTMLUListElement | null) => {
      if (forwardedRef) {
        propagateRef(forwardedRef, el);
      }
      setListElement(el);
    },
    [forwardedRef]
  );

  /** `aria-expanded` value for each kind of row; undefined omits the attribute. */
  const ariaExpandedStates = {
    closed: false,
    open: true,
    end: undefined,
    loading: undefined,
    error: undefined,
  };

  return (
    <TreeContainer
      {...ulProps}
      role="tree"
      aria-activedescendant={currentNode?.htmlId}
      cursorMode={state.cursorMode}
      onKeyDown={handleKeyDown}
      // only a checkbox-style tree allows selecting several items at once
      aria-multiselectable={!!onSelect && !singleSelect}
      ref={handleListRef}
    >
      {state.exposedNodes.map((viewNode, index) => {
        let treeNodeIsSelected =
          viewNode.treeNode && isSelected?.(viewNode.treeNode.value);

        const entireTreeItemSelectable =
          singleSelect && viewNode.treeNode && viewNode.type === "end";

        return (
          <TreeItem
            role="treeitem"
            key={viewNode.htmlId}
            id={viewNode.htmlId}
            aria-level={viewNode.level}
            aria-setsize={viewNode.siblingCount}
            aria-posinset={viewNode.siblingIndex}
            aria-expanded={ariaExpandedStates[viewNode.type]}
            aria-selected={viewNode.treeNode && treeNodeIsSelected}
            onClick={() =>
              entireTreeItemSelectable &&
              onSelect?.(viewNode.treeNode.value, !treeNodeIsSelected)
            }
            onMouseMoveCapture={() => {
              dispatch(["SET_CURSOR_MODE", "mouse"]);
              dispatch(["MOVE_CURSOR_TO", index]);
            }}
            onClickCapture={() => {
              dispatch(["SET_CURSOR_MODE", "mouse"]);
              dispatch(["MOVE_CURSOR_TO", index]);
            }}
            aria-busy={viewNode.type === "loading"}
            aria-live={viewNode.type === "loading" ? "polite" : undefined}
            aria-describedby={
              viewNode.type === "loading"
                ? `${viewNode.htmlId}_loading`
                : undefined
            }
          >
            <div role="group" aria-owns={viewNode.childNodeIds.join(" ")} />
            <TreeItemLabel
              onClick={() => {
                dispatch(["TOGGLE_NODE_AT_CURSOR", dispatch]);
              }}
            >
              <div className="nesting-spacer" />
              {viewNode.type === "open" ? (
                <>
                  <ArrowDownIcon role="button" className="arrow" size="small" />
                  {viewNode.treeNode.label}
                </>
              ) : viewNode.type === "closed" ? (
                <>
                  <ArrowRightIcon
                    role="button"
                    className="arrow"
                    size="small"
                  />
                  {viewNode.treeNode.label}
                </>
              ) : viewNode.type === "end" ? (
                viewNode.treeNode.label
              ) : viewNode.type === "error" ? (
                viewNode.error
              ) : (
                /* viewNode.type === "loading" */
                <Skeleton id={`${viewNode.htmlId}_loading`} label="loading" />
              )}
            </TreeItemLabel>

            {onSelect &&
              !singleSelect &&
              viewNode.type !== "loading" &&
              viewNode.type !== "error" &&
              treeNodeIsSelected != null && (
                <TreeItemSelectBox
                  className="treeitem-select-box"
                  checked={treeNodeIsSelected}
                  onClick={() => {
                    onSelect?.(viewNode.treeNode.value, !treeNodeIsSelected);
                  }}
                >
                  {treeNodeIsSelected ? (
                    <CheckboxCheckedIcon aria-hidden={true} size="regular" />
                  ) : (
                    <CheckboxUncheckedIcon aria-hidden={true} size="regular" />
                  )}
                </TreeItemSelectBox>
              )}
          </TreeItem>
        );
      })}
    </TreeContainer>
  );
}) as (<T>(props: TreeViewProps<T>) => ReactElement) & { displayName?: string };

TreeView.displayName = "ARIA.TreeView";
