import { isDefined } from '@gain/utils/common'
import { styled } from '@mui/material/styles'
import { extent, line, scaleLinear, ScaleTime, scaleTime } from 'd3'
import { Delaunay } from 'd3-delaunay'
import { addMonths } from 'date-fns/addMonths'
import { addYears } from 'date-fns/addYears'
import { differenceInCalendarMonths } from 'date-fns/differenceInCalendarMonths'
import { differenceInCalendarYears } from 'date-fns/differenceInCalendarYears'
import { Fragment, useCallback, useMemo, useState } from 'react'

import { useMaxTextWidth } from '../../chart-utils/use-max-text-width'
import Grid from './grid.component'
import MultiLineChartLine from './multi-line-chart-line'
import MultiLineChartPoint from './multi-line-chart-point'

/**
 * Horizontal inset applied on both sides of the plotting area: lines, points and
 * x-axis labels start this far right of the grid, and the x range is shortened by
 * twice this amount.
 */
const GRID_PADDING_X = 28

/** SVG `<text>` used for both axes' tick labels, styled with the theme's overline typography. */
const StyledAxisLabel = styled('text')(({ theme }) => {
  const { overline } = theme.typography
  const { secondary } = theme.palette.text

  return {
    ...overline,
    fill: secondary,
  }
})

/**
 * Describes how one axis of the chart reads and formats values from a datum `D`.
 */
export interface ScaleConfig<D, V> {
  /** Extracts the axis value; returning `null` (or `undefined`) excludes the datum from the chart. */
  getValue: (d: D) => V | null
  /** Formats an axis value (e.g. a tick) as label text. */
  getLabel: (value: V) => string
  /** Optional tooltip text for a datum's point. */
  getTooltip?: (d: D) => string
  /** Optional custom tick generator; used for the x axis, falling back to the scale's own ticks. */
  ticks?: (scale: ScaleTime<number, number, unknown>) => V[]
  /** Optional axis title. NOTE(review): not read by MultiLineChart in this file. */
  label?: string
  /** NOTE(review): not read by MultiLineChart in this file — confirm its purpose with callers. */
  options?: number[]
}

/** A single plotted point, derived from one datum `D` that has both an x and a y value. */
interface Point<D> {
  /** The datum this point was created from. */
  data: D
  /** Identifier returned by `getId` for the datum. */
  key: number
  /** Position on the time axis. */
  x: Date
  /** Position on the value axis. */
  y: number
  /** When truthy, the point is drawn as part of the dashed "estimate" line. */
  estimate?: boolean
}

/** One input series as supplied by the caller: raw data plus display metadata. */
interface DataItem<D> {
  /** Identifier of the series. */
  id: number
  /** Stroke/fill colour used for the series' line and points. */
  color: string
  /** Raw datums; those without an x or y value are skipped when plotting. */
  data: D[]
  /** Optional human-readable series name. */
  label?: string
}

/** A plotted series: the plottable points of a `DataItem` plus its display metadata. */
interface Line<D> {
  id: number
  /**
   * Optional because it is copied verbatim from `DataItem.label`, which is optional;
   * declaring it required only held thanks to a type assertion at the construction site.
   */
  label?: string
  /** Colour used for the solid/dashed lines and the points. */
  color: string
  /** Points in the same order as the source datums, minus any that lacked an x or y value. */
  data: Point<D>[]
}

/** Props of `MultiLineChart`. */
export interface MultiLineChartProps<D> {
  /** Series to draw, one line each. */
  data: DataItem<D>[]
  /** Returns an identifier for a datum; stored as the point's `key`. */
  getId: (d: D) => number
  /** Marks datums drawn on the dashed "estimate" line; defaults to never estimated. */
  getIsEstimated?: (d: D) => boolean
  /**
   * Time axis configuration. `dateVariant` sets the minimum domain span:
   * at least 4 calendar years (`'years'`) or 4 calendar months (`'months'`).
   */
  xScaleConfig: ScaleConfig<D, Date> & { dateVariant: 'months' | 'years' }
  /** Value axis configuration. */
  yScaleConfig: ScaleConfig<D, number>
  /** Width of the rendered `<svg>`. */
  width: number
  /** Height of the rendered `<svg>`. */
  height: number
}

/**
 * Reports whether `points[position]` exists and is flagged as an estimate.
 * Positions outside the array (e.g. the neighbour before the first point) yield `false`.
 */
function isPositionAnEstimate<D>(position: number, points: Point<D>[]) {
  const inBounds = position >= 0 && position < points.length
  return inBounds ? !!points[position].estimate : false
}

/**
 * Space reserved around the plotting area. The effective left margin is this
 * value plus the width of the widest y-axis label.
 */
const MARGIN = {
  top: 20,
  right: 0,
  bottom: 30,
  left: 8,
}

/**
 * Invisible hover target for one Voronoi cell: `fill: none` hides it, while
 * `pointer-events: all` keeps the whole cell responsive to the mouse.
 */
const StyledVoronoiPath = styled('path')({
  pointerEvents: 'all',
  fill: 'none',
})

/**
 * Stable default for `getIsEstimated`. It is declared at module level so its identity
 * never changes between renders. An inline `() => false` default would be a new function
 * on every render and would invalidate every memo that depends on it.
 */
function neverEstimated(): boolean {
  return false
}

/**
 * Renders one or more series as lines over a time-based x axis and a linear y axis.
 *
 * - Datums whose x or y value is null/undefined are skipped.
 * - Points flagged by `getIsEstimated` are left out of the solid line. They are drawn
 *   with a dashed line that also connects them to their direct neighbours.
 * - The y domain always includes zero.
 * - Hover is detected with a Voronoi diagram over all points. A point's tooltip is only
 *   shown when exactly one line is rendered.
 */
export default function MultiLineChart<D>({
  width,
  height,
  data,
  getId,
  getIsEstimated = neverEstimated,
  xScaleConfig,
  yScaleConfig,
}: MultiLineChartProps<D>) {
  // Index of the hovered point within the flattened point list. It is only ever set
  // while a single line is rendered, so it then equals the index within that line.
  const [activePointIndex, setActivePointIndex] = useState<number | null>(null)

  const lines = useMemo(
    () =>
      data.map((item) => {
        // Convert raw datums into plottable points, dropping any without an x or y value
        const points = item.data.flatMap((d): Point<D>[] => {
          const x = xScaleConfig.getValue(d)
          const y = yScaleConfig.getValue(d)

          if (!isDefined(x) || !isDefined(y)) {
            return []
          }

          return [{ key: getId(d), x, y, data: d, estimate: getIsEstimated(d) }]
        })

        return {
          id: item.id,
          label: item.label,
          color: item.color,
          data: points,
        } as Line<D>
      }),
    [data, getId, getIsEstimated, xScaleConfig, yScaleConfig]
  )

  const yValues = useMemo(
    () => lines.flatMap((yLines) => yLines.data.map((point) => point.y)),
    [lines]
  )

  const xValues = useMemo(
    () => lines.flatMap((xLines) => xLines.data.map((point) => point.x)),
    [lines]
  )

  const boundsHeight = height - MARGIN.top - MARGIN.bottom

  const yScale = useMemo(() => {
    const [dataMin = 0, dataMax = 0] = extent(yValues)
    // Always include zero so the baseline is visible; only extend below it for negative data
    const min = Math.min(0, dataMin)
    const max = Math.max(0, dataMax)
    // Fall back to a unit range when the domain would be empty (no data, or all zeros)
    const top = max === min ? min + 1 : max
    // Domain is [top, min] because SVG y coordinates grow downwards
    return scaleLinear().domain([top, min]).range([0, boundsHeight]).nice(5)
  }, [boundsHeight, yValues])

  const yTicks = useMemo(
    () => yScale.ticks(5).map((value) => ({ value, y: yScale(value) })),
    [yScale]
  )

  const maxYLabelWidth = useMaxTextWidth(yTicks.map(({ value }) => yScaleConfig.getLabel(value)))
  const MARGIN_LEFT = MARGIN.left + maxYLabelWidth
  const boundsWidth = width - MARGIN.right - MARGIN_LEFT

  const xScale = useMemo(() => {
    const xExtent = extent(xValues)

    let min = xExtent[0] || 0
    const max = xExtent[1] || 1

    // Make sure the domain spans at least 4 calendar years/months (5 date points,
    // counting both ends) by moving the start back from the latest date
    if (xScaleConfig.dateVariant === 'years' && differenceInCalendarYears(max, min) < 4) {
      min = addYears(max, -4)
    } else if (xScaleConfig.dateVariant === 'months' && differenceInCalendarMonths(max, min) < 4) {
      min = addMonths(max, -4)
    }

    return scaleTime()
      .domain([min, max])
      .range([0, boundsWidth - GRID_PADDING_X * 2])
  }, [boundsWidth, xScaleConfig.dateVariant, xValues])

  // Solid line: breaks wherever a point is an estimate
  const generateLine = useMemo(
    () =>
      line<Point<D>>(
        (d) => xScale(d.x),
        (d) => yScale(d.y)
      ).defined((d) => !d.estimate),
    [xScale, yScale]
  )

  const generateEstimatesLine = useMemo(
    () =>
      line<Point<D>>(
        (d) => xScale(d.x),
        (d) => yScale(d.y)
      ).defined(
        // Keep non-estimate points adjacent to an estimate so the dashed segment
        // connects to the solid line on both sides
        (d, index, points) =>
          Boolean(d.estimate) ||
          isPositionAnEstimate(index - 1, points) ||
          isPositionAnEstimate(index + 1, points)
      ),
    [xScale, yScale]
  )

  // Triangulation of all points in absolute SVG coordinates (margins and padding applied)
  const delaunay = useMemo(() => {
    return Delaunay.from<Point<D>>(
      lines.flatMap((dataLine) => dataLine.data),
      (point) => xScale(point.x) + MARGIN_LEFT + GRID_PADDING_X,
      (point) => yScale(point.y) + MARGIN.top
    )
  }, [MARGIN_LEFT, lines, xScale, yScale])

  // Voronoi cells clipped to the SVG, used as hover targets. They are memoized so a hover
  // (which changes state and re-renders) does not recompute the diagram.
  const polygons = useMemo(
    () => Array.from(delaunay.voronoi([0, 0, width, height]).cellPolygons()),
    [delaunay, width, height]
  )

  const handleMouseEnter = useCallback(
    (index: number) => () => {
      // With several lines the flattened index does not map to a per-line point index
      if (lines.length === 1) {
        setActivePointIndex(index)
      }
    },
    [lines]
  )

  const handleMouseLeave = useCallback(() => {
    setActivePointIndex(null)
  }, [])

  return (
    <svg
      height={height}
      width={width}>
      <g transform={`translate(${[MARGIN_LEFT, MARGIN.top].join(',')})`}>
        <Grid
          height={boundsHeight}
          ticks={yTicks}
          x1={0}
          x2={boundsWidth}
        />
      </g>
      <g
        height={boundsHeight}
        transform={`translate(${[MARGIN_LEFT + GRID_PADDING_X, MARGIN.top].join(',')})`}
        width={boundsWidth}>
        {lines.map((item) => (
          <Fragment key={item.id}>
            <MultiLineChartLine
              d={generateLine(item.data) as string}
              stroke={item.color}
            />
            <MultiLineChartLine
              d={generateEstimatesLine(item.data) as string}
              stroke={item.color}
              strokeDasharray={4}
            />
            {item.data.map((point, pointIndex) => (
              <MultiLineChartPoint
                key={pointIndex}
                color={item.color}
                estimate={point.estimate}
                tooltip={xScaleConfig.getTooltip?.(point.data)}
                tooltipVisible={pointIndex === activePointIndex}
                x={xScale(point.x)}
                y={yScale(point.y)}
              />
            ))}
          </Fragment>
        ))}
        {(xScaleConfig.ticks?.(xScale) ?? xScale.ticks()).map((value, index) => (
          <StyledAxisLabel
            key={index}
            dx={xScale(value)}
            dy={boundsHeight + 16}
            textAnchor={'middle'}>
            {xScaleConfig.getLabel(value)}
          </StyledAxisLabel>
        ))}
      </g>
      {/* Axes labels */}
      <g
        height={boundsHeight}
        transform={`translate(${[MARGIN_LEFT, MARGIN.top].join(',')})`}
        width={boundsWidth}>
        {yTicks.map(({ value }) => (
          <StyledAxisLabel
            key={value}
            alignmentBaseline={'middle'}
            dx={-8}
            dy={yScale(value)}
            textAnchor={'end'}>
            {yScaleConfig.getLabel(value)}
          </StyledAxisLabel>
        ))}
      </g>

      {polygons.map((polygon, i) => (
        <StyledVoronoiPath
          key={i}
          d={`M${polygon.join('L')}Z`}
          onMouseLeave={handleMouseLeave}
          onMouseOver={handleMouseEnter(polygon.index)}
        />
      ))}
    </svg>
  )
}
