import { isNumber } from '@gain/utils/typescript'
import generateUtilityClasses from '@mui/material/generateUtilityClasses'
import { styled, useTheme } from '@mui/material/styles'
import { nanoid } from '@reduxjs/toolkit'
import {
  extent,
  forceCollide,
  forceManyBody,
  forceSimulation,
  forceX,
  forceY,
  ScaleLinear,
  scaleLinear,
  ScaleLogarithmic,
  scaleSqrt,
} from 'd3'
import React, {
  MouseEvent,
  ReactElement,
  ReactNode,
  useCallback,
  useEffect,
  useMemo,
  useRef,
  useState,
} from 'react'

import { useMaxTextWidth } from '../../chart-utils/use-max-text-width'
import Grid from './grid'

// Output range (min/max) of the sqrt size scale that maps bubble values to circle radii.
const BUBBLE_MIN_SIZE = 7
const BUBBLE_MAX_SIZE = 28
// Added to the widest y-axis label width to obtain the chart's left margin.
const Y_LABEL_OFFSET = 16

/** Style-only props for `StyledCircle`; they are not forwarded to the underlying `<circle>`. */
interface StyledCircleProps {
  /** When truthy, the circle gets a pointer cursor. */
  clickable?: boolean
}

// Class names for the chart's slots. The `bubble` class is set on every circle and is
// targeted by the hover/focus opacity rules of `StyledCircleContainer`.
const bubbleChartClasses = generateUtilityClasses('BubbleChart', ['bubble'])

/**
 * Group wrapping all bubbles. Bubbles are dimmed by default and shown at full
 * opacity either when the group is neither hovered nor focused within, or when
 * the individual bubble itself is hovered.
 */
const StyledCircleContainer = styled('g')(({ theme }) => {
  const bubbleSelector = `& .${bubbleChartClasses.bubble}`
  const opacityTransition = theme.transitions.create('opacity', {
    easing: theme.transitions.easing.easeInOut,
    duration: theme.transitions.duration.shorter,
  })

  return {
    // Idle state: nothing hovered/focused inside the group, so nothing is dimmed.
    ':not(:focus-within):not(:hover)': {
      [bubbleSelector]: { opacity: 1 },
    },

    // Active state: every bubble is dimmed except the hovered one.
    [bubbleSelector]: {
      willChange: 'opacity',
      transition: opacityTransition,

      opacity: 0.2,

      '&:hover': { opacity: 1 },
    },
  }
})

/**
 * Circle used for a single bubble. The `clickable` prop only affects styling
 * (pointer cursor) and is filtered out before props reach the DOM element.
 */
const StyledCircle = styled('circle', {
  shouldForwardProp: (propName) => propName !== 'clickable',
})<StyledCircleProps>(({ clickable }) => {
  const baseStyles = { strokeWidth: 0 }

  return clickable ? { ...baseStyles, cursor: 'pointer' } : baseStyles
})

/** Axis label text: secondary text color combined with the theme's overline typography. */
const StyledLabel = styled('text')(({ theme }) => {
  const { palette, typography } = theme

  return {
    fill: palette.text.secondary,
    ...typography.overline,
  }
})

/** Configuration of one chart axis (x or y) for data items of type `D`. */
export interface ScaleConfig<D> {
  // NOTE(review): `label` and `explainer` are not read anywhere in this file;
  // presumably consumed by callers (e.g. axis titles / tooltips) - confirm.
  label: string
  explainer?: string
  /** When provided, these values are used as the axis values instead of the values found in the data. */
  options?: number[]
  /** Extracts this axis' value from a data item; items yielding a falsy value are left out of the chart. */
  getValue: (d: D) => number | null
  /** Formats an axis value into the text shown as its axis label. */
  getLabel: (value: number) => string
  /**
   * Optional custom scale factory. It receives the size of the plot area along
   * this axis and its result replaces the default padded linear scale.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  scaleFn?: (params: { size: number }) => ScaleLinear<any, any> | ScaleLogarithmic<any, any>
  /** Optional explicit tick values; in this file only the x axis reads them (grid lines and labels). */
  ticksFn?: () => number[]
  /** Optional filter deciding which `ticksFn` values get a grid line; only applied to the x axis here. */
  showGridLine?: (value: number) => boolean
}

/**
 * Converts raw data into plottable chart items, memoized on its inputs.
 *
 * @param data - Source items.
 * @param getBubbleValue - Extracts the value that determines the bubble size.
 * @param xScaleConfig - Axis config whose `getValue` yields the x value.
 * @param yScaleConfig - Axis config whose `getValue` yields the y value.
 * @returns One `Item` per source item for which x, y and bubble value are all truthy.
 */
function useBubbleChartItems<D>(
  data: D[],
  getBubbleValue: (d: D) => number | null,
  xScaleConfig: ScaleConfig<D>,
  yScaleConfig: ScaleConfig<D>
) {
  return useMemo(() => {
    // Build the list in place; concatenating inside a reduce copies the
    // accumulator for every kept item, which is quadratic in the data size.
    const items: Item<D>[] = []

    for (const d of data) {
      const x = xScaleConfig.getValue(d)
      const y = yScaleConfig.getValue(d)
      const r = getBubbleValue(d)

      // NOTE(review): this falsy check drops `null`, but also `0` and `NaN`.
      // Dropping 0 keeps logarithmic scales safe, yet hides legitimate zero
      // values on linear axes - confirm this is the intended behavior.
      if (!x || !y || !r) {
        continue
      }

      items.push({ x, y, r, data: d })
    }

    return items
  }, [data, getBubbleValue, xScaleConfig, yScaleConfig])
}

/**
 * Resolves the values shown along the y axis: the configured `options` when
 * present, otherwise the y value of every chart item. Memoized on its inputs.
 */
function useYValues<D>(chartItems: Item<D>[], yScaleConfig: ScaleConfig<D>) {
  return useMemo(() => {
    const configuredValues = yScaleConfig?.options
    if (configuredValues) {
      return configuredValues
    }

    return chartItems.map(({ y }) => y)
  }, [chartItems, yScaleConfig])
}

/** Props of the `BubbleChart` component for data items of type `D`. */
export interface BubblePlotProps<D> {
  /** Width of the rendered `<svg>`. */
  width: number
  /** Height of the rendered `<svg>`. */
  height: number
  /** Items to plot; items without a usable x, y or bubble value are not drawn. */
  data: D[]
  /** Returns the identifier used as React key of an item's bubble. */
  getId: (d: D) => number
  /** Returns the fill color of an item's bubble. */
  getColor: (d: D) => string
  /** Returns the value that determines the bubble's radius (via a sqrt scale). */
  getBubbleValue: (d: D) => number | null
  /** Configuration of the horizontal axis. */
  xScaleConfig: ScaleConfig<D>
  /** Configuration of the vertical axis. */
  yScaleConfig: ScaleConfig<D>
  /** Optional hook to wrap or replace the default circle element of an item. */
  renderBubble?: (item: D, node: ReactElement) => ReactNode
  /** Called when a bubble is clicked; when set, bubbles show a pointer cursor. */
  onBubbleClick?: (item: D, event: MouseEvent) => void
  /** When true, no x ticks are passed to the grid. */
  disableXAxisLines?: boolean
}

/**
 * A plottable bubble. Before the force simulation `x`, `y` and `r` hold the raw
 * data values; the simulated items stored in state hold scaled positions and radius.
 */
interface Item<D> {
  x: number
  y: number
  r: number
  /** The source item this bubble was created from. */
  data: D
}

/**
 * Builds the default linear scale for an axis: the domain spans the extent of
 * `values`, padded on both sides by half of the average step between values.
 *
 * Falls back to the domain [0, 1] when there are no values, so the scale never
 * ends up with a non-finite domain.
 */
function createPaddedLinearScale(values: number[], range: [number, number]) {
  // `extent` yields [undefined, undefined] for empty input; defaults only apply
  // then, so a legitimate minimum/maximum of 0 is preserved.
  const [min = 0, max = 1] = extent(values)

  const step = values.length > 0 ? (max - min) / values.length : 0
  const padding = step * 0.5

  return scaleLinear()
    .domain([min - padding, max + padding])
    .range(range)
}

/**
 * Bubble chart rendered as SVG.
 *
 * Every data item with a usable x, y and bubble value becomes a circle. Circles
 * are positioned with the x/y scales and then relaxed with a d3 force
 * simulation so that they do not overlap. Axis labels are drawn outside the
 * clipped plot area; grid lines are delegated to `Grid`.
 */
export default function BubbleChart<D>({
  width,
  height,
  data,
  getId,
  getColor,
  getBubbleValue,
  xScaleConfig,
  yScaleConfig,
  renderBubble,
  onBubbleClick,
  disableXAxisLines,
}: BubblePlotProps<D>) {
  // Lazy initializer: the id is generated once instead of on every render.
  const [clipId] = useState(() => nanoid())

  const axesRef = useRef(null)

  const [simulatedItems, setSimulatedItems] = useState<Item<D>[]>([])

  const chartItems = useBubbleChartItems(data, getBubbleValue, xScaleConfig, yScaleConfig)
  const yValues = useYValues(chartItems, yScaleConfig)

  const maxYLabelWidth = useMaxTextWidth(yValues.map((value) => yScaleConfig.getLabel(value)))

  const MARGIN = {
    top: 1,
    right: 0,
    bottom: 24,
    left: maxYLabelWidth + Y_LABEL_OFFSET,
  }

  // Layout. The svg size is set by the given props.
  // The bounds (= area inside the axes) are calculated by subtracting the margins.
  // Clamped at 0 so a small or not-yet-measured container never produces
  // negative SVG dimensions.
  const boundsWidth = Math.max(0, width - MARGIN.right - MARGIN.left)
  const boundsHeight = Math.max(0, height - MARGIN.top - MARGIN.bottom)

  const xValues = useMemo(
    () => (xScaleConfig?.options ? xScaleConfig.options : chartItems.map((item) => item.x)),
    [chartItems, xScaleConfig]
  )

  // Scales
  const yScale = useMemo(() => {
    if (yScaleConfig.scaleFn) {
      return yScaleConfig.scaleFn({ size: boundsHeight })
    }

    // SVG y grows downwards, so the range is inverted.
    return createPaddedLinearScale(yValues, [boundsHeight, 0])
  }, [yValues, yScaleConfig, boundsHeight])

  const xScale = useMemo(() => {
    if (xScaleConfig.scaleFn) {
      return xScaleConfig.scaleFn({ size: boundsWidth })
    }

    return createPaddedLinearScale(xValues, [0, boundsWidth])
  }, [xValues, xScaleConfig, boundsWidth])

  const sizeScale = useMemo(() => {
    const [, max] = extent(chartItems.map((item) => item.r).filter(isNumber)) as [number, number]

    return scaleSqrt()
      .domain([0, max || 1])
      .range([BUBBLE_MIN_SIZE, BUBBLE_MAX_SIZE])
      .clamp(true)
  }, [chartItems])

  useEffect(() => {
    if (chartItems.length === 0) {
      // Remove bubbles of a previous data set; keep the same array when it is
      // already empty to avoid a needless re-render.
      setSimulatedItems((previous) => (previous.length === 0 ? previous : []))
      return
    }

    const circleSimulation = forceSimulation(
      chartItems.map((item) => ({
        x: xScale(item.x),
        y: yScale(item.y),
        r: sizeScale(item.r),
        data: item.data,
      }))
    )
      .force('charge', forceManyBody().strength(0.5))
      .force(
        'x',
        forceX<Item<D>>().x((d) => d.x)
      )
      .force(
        'y',
        forceY<Item<D>>().y((d) => d.y)
      )
      .force(
        'collision',
        forceCollide<Item<D>>().radius((d) => d.r + 1.25)
      )

    // Advance the simulation synchronously by the number of ticks it needs to
    // cool down below `alphaMin`, so only the final layout is ever rendered.
    for (
      let i = 0,
        n = Math.ceil(
          Math.log(circleSimulation.alphaMin()) / Math.log(1 - circleSimulation.alphaDecay())
        );
      i < n;
      ++i
    ) {
      circleSimulation.tick()
    }

    // Manual ticks dispatch no events; the simulation's own timer emits `end`
    // on its next step because alpha is already below `alphaMin`.
    circleSimulation.on('end', () => {
      setSimulatedItems(circleSimulation.nodes())
    })

    return () => {
      circleSimulation.stop()
    }
  }, [chartItems, sizeScale, xScale, yScale])

  // Horizontal grid lines sit halfway between consecutive y values.
  const yTicks = useMemo(() => {
    return yValues.slice(0, -1).map((value, index) => {
      const tickValue = (yValues[index + 1] - value) * 0.5 + value
      return {
        value: tickValue,
        y: yScale(tickValue),
      }
    })
  }, [yScale, yValues])

  // Vertical grid lines: none when disabled, the (optionally filtered) custom
  // ticks when configured, otherwise halfway between consecutive x values.
  const xTicks = useMemo(() => {
    if (disableXAxisLines) {
      return []
    }

    if (xScaleConfig.ticksFn) {
      return xScaleConfig
        .ticksFn()
        .filter((value) => {
          if (xScaleConfig.showGridLine) {
            return xScaleConfig.showGridLine(value)
          }

          return true
        })
        .map((value) => {
          return {
            value: value,
            x: xScale(value),
          }
        })
    }

    return xValues.slice(0, -1).map((value, index) => {
      const tickValue = (xValues[index + 1] - value) * 0.5 + value
      return {
        value: tickValue,
        x: xScale(tickValue),
      }
    })
  }, [disableXAxisLines, xScaleConfig, xValues, xScale])

  const xLabels = useMemo(() => {
    if (xScaleConfig.ticksFn) {
      return xScaleConfig.ticksFn()
    }

    return xValues
  }, [xScaleConfig, xValues])

  const handleRenderBubble = useCallback(
    (item: D, node: ReactElement) => {
      if (renderBubble) {
        return renderBubble(item, node)
      }

      return node
    },
    [renderBubble]
  )

  const handleBubbleClick = useCallback(
    (item: D) => (event: MouseEvent) => {
      event.preventDefault()
      event.stopPropagation()
      if (onBubbleClick) {
        onBubbleClick(item, event)
      }
    },
    [onBubbleClick]
  )

  const theme = useTheme()

  return (
    <div>
      <svg
        height={height}
        width={width}>
        <g transform={`translate(${[MARGIN.left, MARGIN.top].join(',')})`}>
          <rect
            fill={theme.palette.background.paper}
            height={Math.max(0, boundsHeight - 1)}
            rx={4}
            shapeRendering={'geometricPrecision'}
            stroke={theme.palette.divider}
            strokeWidth={1}
            width={Math.max(0, boundsWidth - 1)}
            x={0}
            y={0}
          />
        </g>
        <clipPath
          // clipPathUnits={'userSpaceOnUse'}
          id={clipId}>
          <rect
            height={boundsHeight}
            rx={4}
            width={boundsWidth}
            x={0}
            y={0}
          />
        </clipPath>
        <g
          clipPath={`url(#${clipId})`}
          height={boundsHeight}
          transform={`translate(${[MARGIN.left, MARGIN.top].join(',')})`}
          width={boundsWidth}>
          <Grid
            x1={0}
            x2={boundsWidth}
            xTicks={xTicks}
            y1={0}
            y2={boundsHeight}
            yTicks={yTicks}
          />
          <StyledCircleContainer>
            {simulatedItems.map((item) => (
              // The keyed fragment guarantees a key on the list child even when
              // `renderBubble` wraps the circle in an element of its own.
              <React.Fragment key={getId(item.data)}>
                {handleRenderBubble(
                  item.data,
                  <StyledCircle
                    key={getId(item.data)}
                    className={bubbleChartClasses.bubble}
                    clickable={!!onBubbleClick}
                    cx={0}
                    cy={0}
                    fill={getColor(item.data)}
                    onClick={handleBubbleClick(item.data)}
                    r={item.r}
                    style={{
                      transform: `translate(${item.x}px, ${item.y}px)`,
                    }}
                  />
                )}
              </React.Fragment>
            ))}
          </StyledCircleContainer>
        </g>
        {/* Second is for the axes */}
        <g
          ref={axesRef}
          height={boundsHeight}
          transform={`translate(${[MARGIN.left, MARGIN.top].join(',')})`}
          width={boundsWidth}>
          {yValues.map((value) => (
            <StyledLabel
              key={value}
              alignmentBaseline={'middle'}
              dx={-16}
              dy={yScale(value)}
              textAnchor={'end'}>
              {yScaleConfig.getLabel(value)}
            </StyledLabel>
          ))}
          {xLabels.map((value) => (
            <StyledLabel
              key={value}
              dx={xScale(value)}
              dy={boundsHeight + 16}
              textAnchor={'middle'}>
              {xScaleConfig.getLabel(value)}
            </StyledLabel>
          ))}
        </g>
      </svg>
    </div>
  )
}
