import { LegalEntityShareholder } from '@gain/rpc/app-model'
import { formatPercentage } from '@gain/utils/number'
import generateUtilityClasses from '@mui/material/generateUtilityClasses'
import { useTheme } from '@mui/material/styles'
import { CSSProperties } from '@mui/material/styles/createTypography'
import { active, select } from 'd3'
import { OrgChart as D3OrgChart } from 'd3-org-chart'
import { isEqual } from 'lodash'
import {
  createElement,
  MouseEvent,
  RefObject,
  useCallback,
  useLayoutEffect,
  useMemo,
  useRef,
  useState,
} from 'react'
import { renderToStaticMarkup } from 'react-dom/server'

import { getTextWidth } from '../../../../../features/chart'
import { useTrackActivityCallback } from '../../../../../features/planhat/planhat-hooks'
import ExpandCollapseButton from './ExpandCollapseButton'
import NodeContent from './NodeContent'
import {
  generateNodeId,
  isLegalEntityNode,
  isShareholdersNode,
  LegalEntityNodeData,
  Node,
  TempLegalEntityStructureNode,
  TooltipData,
} from './org-chart-model'

// Constants used to render nodes and to pre-calculate node dimensions so the
// chart can be laid out optimally. Units are presumably CSS pixels (TODO
// confirm against the node styles).

// Upper bound applied to the calculated width of a legal entity node
export const NODE_MAX_WIDTH = 250
// Height passed to the chart for every node
export const NODE_HEIGHT = 52
// Width reserved for an icon when calculating node widths
export const NODE_ICON_SIZE = 24

// Gap between label and icon
export const NODE_ICON_LABEL_GAP = 8

// Gap between items in a node when there is more than 1
export const NODE_ITEM_GAP = 16
// Width of the divider that is accounted for between two items in a node
export const DIVIDER_WIDTH = 1
// Padding added on both the left and right side of a node's contents
export const NODE_HORIZONTAL_PADDING = 16

// Transition duration passed to the chart; also used to estimate when the
// initial render has finished
const CHART_ANIMATION_DURATION_MS = 250
// Container paddings; the top and x values are used to correct tooltip
// positions relative to the container
export const CHART_CONTAINER_PADDING_TOP = 8
export const CHART_CONTAINER_PADDING_BOTTOM = 12
export const CHART_CONTAINER_PADDING_X = 12
// Vertical gap between parent and child nodes including the line
export const CHART_CHILDREN_MARGIN = 60
// Padding that is counted twice (top and bottom) in the total chart height
export const CHART_PADDING = 60
// Total height of the chart svg: the chart padding on both sides plus three
// rows of nodes that are separated by two parent/child margins.
export const CHART_HEIGHT = 2 * CHART_PADDING + 3 * NODE_HEIGHT + 2 * CHART_CHILDREN_MARGIN

/**
 * CSS class names, generated for the 'D3OrgChart' component name, that are
 * used to style specific d3-org-chart elements from the containing styled
 * component. The `label` class is also used by the tooltip handler in this
 * file to recognise label elements.
 */
export const orgChartClasses = generateUtilityClasses('D3OrgChart', [
  'node',
  'itemContainer',
  'label',
  'truncateText',
  'textSecondary',
  'divider',
  'iconContainer',
  'icon',
  'expandButton',
])

/**
 * Builds the share text of a shareholder as it should be displayed in the
 * label: the formatted percentage wrapped in parentheses. An empty string is
 * returned when the shareholder has no (or a zero) exact percentage share.
 *
 * The same text is intended for rendering the HTML and for calculating the
 * eventual node width, so both stay in sync.
 */
export function formatShare(shareholder: LegalEntityShareholder) {
  const share = shareholder.percentageShareExact

  return share ? `(${formatPercentage(share)})` : ''
}

/**
 * Produces the static HTML markup for the contents of a single node by
 * rendering the NodeContent component with the given node as `d` prop.
 */
function renderNodeContent(d) {
  const element = createElement(NodeContent, { d })

  return renderToStaticMarkup(element)
}

/**
 * Calculates the width required to render contents of a shareholders node.
 *
 * The width is the sum of every shareholder item (label text + icon + the gap
 * between them), the separators between neighbouring items and the horizontal
 * node padding on both sides.
 *
 * @param node - Shareholders node whose `data` holds the list of shareholders
 * @param font - Font properties used to measure the label text
 * @returns The total width to reserve for the node
 */
function calculateShareholdersNodeWidth(node: Node<LegalEntityShareholder[]>, font: CSSProperties) {
  const shareholders = node.data

  // First calculate the total width of each individual shareholder item
  const itemWidths = shareholders.reduce((acc, current) => {
    let label = `${current.name}`

    // Add share percentage when available
    if (current.percentageShareExact) {
      label += ` ${formatShare(current)}`
    }

    const labelWidth = getTextWidth(label, font.fontFamily, font.fontSize, font.fontWeight)
    return acc + labelWidth + NODE_ICON_LABEL_GAP + NODE_ICON_SIZE
  }, 0)

  // Calculate spacing between the individual items: a divider with a gap on
  // each side. There is one separator less than there are items; clamp at zero
  // so an empty list doesn't subtract a separator from the padding.
  const separatorWidth = DIVIDER_WIDTH + NODE_ITEM_GAP + NODE_ITEM_GAP
  const separatorCount = Math.max(shareholders.length - 1, 0)
  const gaps = separatorCount * separatorWidth

  // Combine all calculated values and add the horizontal node padding
  return itemWidths + NODE_HORIZONTAL_PADDING * 2 + gaps
}

/**
 * Determines how wide a legal entity node must be to fit its contents: the
 * measured name plus the icon, a gap and the horizontal padding on both sides.
 * The result never exceeds NODE_MAX_WIDTH.
 */
function calculateLegalEntityNodeWidth(node: Node<LegalEntityNodeData>, font: CSSProperties) {
  const { fontFamily, fontSize, fontWeight } = font
  const nameWidth = getTextWidth(node.data.name, fontFamily, fontSize, fontWeight)
  const requiredWidth = nameWidth + NODE_ICON_SIZE + NODE_ITEM_GAP + NODE_HORIZONTAL_PADDING * 2

  // Cap the width so long names don't produce overly wide nodes
  return Math.min(requiredWidth, NODE_MAX_WIDTH)
}

/**
 * Resolves the width that should be reserved for a node in the org chart by
 * delegating to the calculation that matches the node type.
 *
 * Throws an error when the node is neither a shareholders node nor a legal
 * entity node.
 */
function calculateNodeWidth(d, font: CSSProperties) {
  const node = d.data

  if (isShareholdersNode(node)) {
    return calculateShareholdersNodeWidth(node, font)
  } else if (isLegalEntityNode(node)) {
    return calculateLegalEntityNodeWidth(node, font)
  }

  throw new Error('Unknown node type')
}

/**
 * Produces the static HTML markup of the expand/collapse button that
 * accompanies nodes with children. The `state` argument is accepted but not
 * used for rendering.
 */
function renderExpandCollapseButtonContent(node, state, alwaysExpandedNodeIds: string[]) {
  const button = createElement(ExpandCollapseButton, { node, alwaysExpandedNodeIds })

  return renderToStaticMarkup(button)
}

/**
 * Initializes a d3-org-chart in the given containerRef and performs an
 * initial render. The chart is re-configured and re-rendered whenever the
 * data, the always-expanded node ids, the typography or the tracking callback
 * change, and it is cleared again on cleanup.
 *
 * @param containerRef - Ref to the element the chart is rendered into
 * @param data - Flat list of nodes where hierarchy is defined via id/parentId
 * @param alwaysExpandedNodeIds - Ids of nodes that can't be collapsed
 * @returns A ref holding the chart instance, which is stable across renders
 */
export function useInitializeChart(
  containerRef: RefObject<HTMLDivElement>,
  data,
  alwaysExpandedNodeIds: string[]
) {
  const trackInteraction = useTrackChartInteraction()
  const theme = useTheme()

  // Create the chart instance lazily so it's constructed only once instead of
  // on every render (an argument passed to useRef is evaluated each render)
  const [initialChart] = useState(() => new D3OrgChart())
  const chartRef = useRef(initialChart)

  useLayoutEffect(() => {
    // Keep track of when the chart has completed the initial render
    let initialized = false

    if (!containerRef.current) {
      return // Should not happen
    }

    const chart = chartRef.current

    chart
      .container(containerRef.current)
      // Set data, it must be an array of objects, where hierarchy is defined via id and parentId props
      .data(data)

      // Set the svg height
      .svgHeight(CHART_HEIGHT)
      .svgWidth('100%')

      // Disable grouping child nodes vertically (compact layout)
      .compact(false)

      // Calculate a single node width based on its type
      .nodeWidth((d) => calculateNodeWidth(d, theme.typography.body2))

      // Horizontal gap between sibling nodes (nodes that are placed next to
      // each other on the same level)
      .siblingsMargin(() => 8)

      // Vertical gap between parent and child nodes including the line
      .childrenMargin(() => CHART_CHILDREN_MARGIN)

      // Minimal horizontal margin between nodes on the same level, you'll
      // notice this when expanding multiple levels of children.
      .neighbourMargin(() => 0)

      // Disable automatically centering nodes on expand/collapse
      .setActiveNodeCentered(false)

      // Function responsible for rendering the actual HTML contents of a node
      .nodeContent(renderNodeContent)

      // Set how the expand/collapse button for a node is rendered
      .buttonContent(({ node, state }) =>
        renderExpandCollapseButtonContent(node, state, alwaysExpandedNodeIds)
      )

      // Fit the chart into view after expand or collapse
      .onExpandOrCollapse(() => {
        chart.fit()
        trackInteraction()
      })

      // Add event handlers to track chart interactions in Planhat
      .onZoomEnd(() => {
        // The chart zooms/pans (fit) after initial render which triggers this
        // event without manual interaction. To avoid this, only track an
        // interaction after the chart has initialized.
        if (initialized) {
          trackInteraction()
        }
      })

      .duration(CHART_ANIMATION_DURATION_MS)

      // Height of a single node in the chart
      .nodeHeight(() => NODE_HEIGHT)

      // Expand the children of the first/top level nodes
      .initialExpandLevel(1)

      // Reset the zoom level to 100%
      .initialZoom(1)

      // Perform initial render
      .render()

      // Ensure the root is centered within the view. We're using fit here
      // because it centers whatever is rendered into the view. Using
      // setCentered would leave a large gap of white space on the top. Also,
      // the chart must have been rendered at least once to make sure any
      // descendants have been initialized properly.
      .fit({
        animate: false, // Animate doesn't work properly so disable it
        scale: false, // Keep current zoom level
      })

    // There is no straightforward way to check if the chart has finished its
    // initial render. So we use a timeout with the animation duration and
    // additional time to be safe to determine when the chart has finished the
    // render + fit.
    const timeoutId = setTimeout(() => {
      initialized = true
    }, CHART_ANIMATION_DURATION_MS + 100)

    // Monkey patch onButtonClick to disable expand/collapse behaviour on
    // nodes that should always be expanded. Keep a reference to the handler
    // that was in place so it can be restored on cleanup.
    const originalOnButtonClick = chart.onButtonClick
    chart.onButtonClick = (event, node) => {
      if (alwaysExpandedNodeIds.includes(node.id)) {
        return
      }

      originalOnButtonClick.call(chart, event, node)
    }

    return () => {
      clearTimeout(timeoutId)

      // Undo the monkey patch. The chart instance outlives this effect, so
      // without restoring the handler every re-run would wrap the previous
      // wrapper and keep checking stale always-expanded ids.
      chart.onButtonClick = originalOnButtonClick

      // Cleanup the chart when the effect re-runs or the component unmounts
      chart.clear()
    }
  }, [alwaysExpandedNodeIds, containerRef, data, theme.typography.body2, trackInteraction])

  return chartRef
}

/**
 * Converts a nested legal entity structure into the flat, d3-org-chart
 * compatible list of nodes. Every node gets a freshly generated id and points
 * to its parent through parentId (null for the top level). A node is always
 * listed before its own descendants.
 */
function flattenRecursiveStructure(
  childNodes: TempLegalEntityStructureNode[],
  parentId: string | null = null
) {
  const flattened = new Array<Node<unknown>>()

  for (const { children, ...data } of childNodes) {
    // The generated id is what the children use to reference this node
    const id = generateNodeId()
    flattened.push({ id, parentId, data })

    if (!children) {
      continue
    }

    // Append all descendants, linked to the current node by its id
    for (const descendant of flattenRecursiveStructure(children, id)) {
      flattened.push(descendant)
    }
  }

  return flattened
}

/**
 * Memoizes the chart input for the given legal entity structure: the
 * flattened list of nodes plus the ids of the nodes that must always stay
 * expanded, which is the first (root) node when it is available.
 */
export function useChartData(structure: TempLegalEntityStructureNode) {
  return useMemo(() => {
    const nodes = flattenRecursiveStructure([structure])

    // The root is the first node in the flattened list
    const rootId = nodes.length ? nodes[0].id : null

    return {
      data: nodes,
      alwaysExpandedNodeIds: rootId ? [rootId] : [],
    }
  }, [structure])
}

/**
 * Returns the number of levels in the legal entity structure, counting the
 * given node itself as the first level. A node without children has depth 1.
 */
export function getDepth(node: TempLegalEntityStructureNode): number {
  // Track the deepest branch below this node; stays 0 without children
  let deepestBranch = 0

  for (const child of node.children || []) {
    deepestBranch = Math.max(deepestBranch, getDepth(child))
  }

  // Add the current node on top of its deepest branch
  return deepestBranch + 1
}

/**
 * Returns the dimensions of the given element and its position relative to
 * the ancestor.
 *
 * @param child - Element to measure
 * @param ancestor - Element the returned top/left offsets are relative to
 * @returns The top/left offset (including a correction for the container
 * padding) and the width/height of the child
 */
function getRelativeElementRect(child: HTMLElement, ancestor: HTMLElement) {
  // Measure each element only once; every getBoundingClientRect call may
  // force the browser to recalculate layout
  const childRect = child.getBoundingClientRect()
  const ancestorRect = ancestor.getBoundingClientRect()

  return {
    // The distance from edge of the ancestor with a small correction for
    // container padding
    top: childRect.top - ancestorRect.top + CHART_CONTAINER_PADDING_TOP,
    left: childRect.left - ancestorRect.left + CHART_CONTAINER_PADDING_X,
    width: childRect.width,
    height: childRect.height,
  }
}

/**
 * Tells whether the contents of the element are wider than the element
 * itself, i.e. whether it overflows horizontally and is truncated.
 */
function hasOverflowX(element: HTMLElement) {
  const { clientWidth, scrollWidth } = element

  return clientWidth < scrollWidth
}

/**
 * Returns true if the scale (`k`) of the chart's last transform is strictly
 * below the given threshold, false otherwise.
 */
function isChartZoomLevelBelow(chart: D3OrgChart, threshold: number) {
  const scale = chart.getChartState().lastTransform.k

  return scale < threshold
}

/**
 * Checks whether any element inside the chart container currently has an
 * active d3 transition. Returns true when at least one does, false otherwise.
 */
function chartIsAnimating(chart: D3OrgChart) {
  // Collect every element within the chart container that is transitioning
  const transitioning = select(chart.container())
    .selectAll('*')
    .filter(function () {
      // d3.active(this) is non-null while a transition is active on the node
      return active(this) !== null
    })

  // An empty selection means nothing is animating
  return !transitioning.empty()
}

/**
 * Provides chart tooltip state and event handlers.
 *
 * @param chartRef - Ref to the chart instance, used to check the animation
 * state and zoom level
 * @param containerRef - Ref to the chart container, tooltip positions are
 * relative to this element
 * @returns A tuple with the current tooltip data (or null), the mouse move
 * handler for the chart container and a handler that closes the tooltip
 */
export function useChartTooltip(
  chartRef: RefObject<D3OrgChart>,
  containerRef: RefObject<HTMLDivElement>
) {
  // Keep track of the current tooltip state
  const [tooltip, setTooltip] = useState<TooltipData | null>(null)

  // Mouse handler for the chart container to determine if a tooltip should
  // be visible
  const handleMouseMove = useCallback(
    (event: MouseEvent) => {
      const target = event.target
      const chart = chartRef.current
      const container = containerRef.current

      // Initial sanity checks. Both refs can be empty, so make sure the chart
      // and the container are available before using them.
      if (!(target instanceof HTMLElement) || !container || !chart || chartIsAnimating(chart)) {
        setTooltip(null)
        return
      }

      // If it's a label and the text is truncated, show a tooltip with the
      // non-truncated label
      if (
        target.classList.contains(orgChartClasses.label) && // Make sure the target is a label
        // Show the tooltip when a label is overflowing or the zoom level is below a readable level
        (hasOverflowX(target) || isChartZoomLevelBelow(chart, 0.7)) &&
        target.textContent // Sanity check that the label has text
      ) {
        const rect = getRelativeElementRect(target, container)

        const tooltipData = {
          top: rect.top,
          left: rect.left,
          width: rect.width,
          height: rect.height,
          content: target.textContent || '',
        }

        // Keep the previous state object when nothing changed to avoid
        // re-rendering on every mouse move
        setTooltip((prev) => {
          if (isEqual(prev, tooltipData)) {
            return prev
          }

          return tooltipData
        })
      }
    },
    [chartRef, containerRef]
  )

  // Handler to close the tooltip
  const handleCloseTooltip = useCallback(() => {
    setTooltip(null)
  }, [])

  return [tooltip, handleMouseMove, handleCloseTooltip] as const
}

/**
 * Returns a callback that reports a chart interaction. The activity is
 * tracked at most once for as long as the component using this hook stays
 * mounted; we only want to know if the chart was used, not how often.
 */
function useTrackChartInteraction() {
  const trackActivity = useTrackActivityCallback()

  // Remembers whether the interaction has been reported already
  const hasTracked = useRef(false)

  return useCallback(() => {
    if (!hasTracked.current) {
      trackActivity('Legal structure chart interaction')
      hasTracked.current = true
    }
  }, [trackActivity])
}
