import { isDefined } from '@gain/utils/common'
import { styled } from '@mui/material/styles'
import { hierarchy, HierarchyRectangularNode, treemap, treemapResquarify } from 'd3'
import { MouseEvent, useCallback, useMemo, useState } from 'react'

import { ChartGroupValueType } from '../../../../common/chart/chart-groups'
import { TreeLeaf, TreeNode, TreeRoot } from './treemap-chart-model'
import TreemapGroup, { TreemapGroupState } from './treemap-group'

/**
 * Type guard telling value-carrying tree entries (nodes and leaves) apart
 * from the root, based solely on the presence of a `value` property.
 *
 * @param node - Any entry of the treemap data structure.
 * @returns `true` when `node` exposes a `value` property (own or inherited).
 */
function hasValue<NodeData, LeafData>(
  node: TreeRoot<NodeData, LeafData> | TreeNode<NodeData, LeafData> | TreeLeaf<LeafData>
): node is TreeNode<NodeData, LeafData> | TreeLeaf<LeafData> {
  // Reflect.has is the function form of the `in` operator.
  const carriesValue = Reflect.has(node, 'value')
  return carriesValue
}

/** Root `<svg>` element of the chart, rendered with rounded corners. */
const StyledSvg = styled('svg')(() => ({ borderRadius: 4 }))

/** Background rectangle of the "No data available" placeholder, filled with the theme's lightest grey. */
const StyledNoDateRect = styled('rect')((props) => {
  const { grey } = props.theme.palette
  return { fill: grey[50] }
})

/**
 * Label of the "No data available" placeholder, centred on its x/y anchor
 * both horizontally (`textAnchor`) and vertically (`dominantBaseline`).
 */
const StyledNoDataText = styled('text')(({ theme }) => ({
  ...theme.typography.subtitle2,
  // SVG text is painted with `fill`; `color` alone does not change the glyph
  // colour, so the theme's primary text colour has to be applied as the fill.
  fill: theme.palette.text.primary,
  dominantBaseline: 'middle',
  textAnchor: 'middle',
}))

/** Props accepted by {@link TreemapChart}. */
export interface TreemapChartProps<NodeData, LeafData> {
  /** Width of the rendered SVG and of the treemap layout. */
  width: number
  /** Height of the rendered SVG and of the treemap layout. */
  height: number
  /** Root of the tree to lay out; its direct children are rendered as groups. */
  data: TreeRoot<NodeData, LeafData>
  /**
   * Key of a group to highlight from outside the chart. When defined, the
   * group with a matching key is rendered as active and all others as inactive.
   */
  highlightGroupKey?: ChartGroupValueType
  /** Formats a leaf value for display; forwarded unchanged to each group. */
  valueFormatter: (value: number | null, item: LeafData) => string
  /** Optional click handler forwarded to each group (presumably fired per leaf — confirm in TreemapGroup). */
  onLeafClick?: (d: LeafData, event: MouseEvent) => void
}

/**
 * Renders `data` as a treemap inside an SVG of `width` x `height`.
 *
 * Every direct child of the root is drawn as a `TreemapGroup`. A group is
 * rendered as active when its key matches `highlightGroupKey` or, if no
 * highlight is supplied, the group currently hovered; the remaining groups
 * are rendered as inactive. When nothing is highlighted or hovered all groups
 * use the default state.
 *
 * When the root has no children, or none of its children has children of its
 * own, a "No data available" placeholder is rendered instead.
 */
export default function TreemapChart<NodeData, LeafData>({
  width,
  height,
  data,
  highlightGroupKey,
  valueFormatter,
  onLeafClick,
}: TreemapChartProps<NodeData, LeafData>) {
  // Key of the top-level group currently under the pointer, or null when none is
  const [hoverGroupKey, setHoverGroupKey] = useState<string | null>(null)

  // Build the d3 hierarchy and aggregate the values; only recomputed when the data changes
  const hierarch = useMemo(() => {
    return (
      hierarchy(data)
        // Entries without a `value` property contribute nothing themselves
        .sum((node) => (hasValue(node) ? node.value : 0))
        // Taller subtrees first; siblings of equal height compare as equal
        .sort(function (nodeA, nodeB) {
          return nodeB.height - nodeA.height
        })
    )
  }, [data])

  // Compute the rectangle of every node for the current chart size
  const root = useMemo(() => {
    const generator = treemap<TreeRoot<NodeData, LeafData>>()
      .size([width, height])
      .paddingInner((node) => {
        // Wider gap between the top-level groups than between the cells inside a group
        if (node.depth === 0) {
          return 5
        }
        return 1
      })
      .tile(treemapResquarify)
      .round(true)
    return generator(hierarch)
  }, [height, hierarch, width])

  // Returns a handler that marks the group with the given key as hovered
  const handleMouseEnterGroup = useCallback(
    (key: string) => () => {
      setHoverGroupKey(key)
    },
    []
  )

  const handleMouseLeaveGroup = useCallback(() => {
    setHoverGroupKey(null)
  }, [])

  // Resolves the visual state of a single group
  const handleGetGroupState = useCallback(
    (groupKey: string) => {
      // An externally supplied highlight takes precedence over the local hover
      // state. The state must be derived per group: returning "active" for
      // every group as soon as anything is hovered would make the hovered
      // group indistinguishable from the others.
      const activeKey = isDefined(highlightGroupKey) ? highlightGroupKey : hoverGroupKey

      if (activeKey === null) {
        return TreemapGroupState.Default
      } else if (activeKey === groupKey) {
        return TreemapGroupState.Active
      } else {
        return TreemapGroupState.Inactive
      }
    },
    [highlightGroupKey, hoverGroupKey]
  )

  // Nothing to draw: no groups at all, or only groups without children
  if (
    !root.children ||
    root.children.length === 0 ||
    root.children.every((child) => child?.data.children.length === 0)
  ) {
    return (
      <StyledSvg
        height={height}
        width={width}>
        <StyledNoDateRect
          height={height}
          width={width}
          x={0}
          y={0}
        />
        <StyledNoDataText
          x={'50%'}
          y={'50%'}>
          No data available
        </StyledNoDataText>
      </StyledSvg>
    )
  }

  return (
    <StyledSvg
      height={height}
      width={width}>
      {root.children.map((item) => {
        // The layout is typed on the root's data type; direct children of the
        // root are treated as tree nodes here.
        const child = item as unknown as HierarchyRectangularNode<TreeNode<NodeData, LeafData>>

        return (
          <TreemapGroup
            key={child.data.key}
            color={child.data.color}
            node={child}
            onLeafClick={onLeafClick}
            onMouseLeave={handleMouseLeaveGroup}
            onMouseOver={handleMouseEnterGroup(child.data.key)}
            state={handleGetGroupState(child.data.key)}
            valueFormatter={valueFormatter}
          />
        )
      })}
    </StyledSvg>
  )
}
