import { useResizeObserver } from '@gain/utils/dom'
import Fade from '@mui/material/Fade'
import { styled } from '@mui/material/styles'
import { useCallback, useRef, useState } from 'react'

import ButtonScroll from '../../button-scroll'
import { ChartGroup, ChartGroupValueType } from '../chart-groups'
import ChartLegendChip from './chart-legend-chip'

/**
 * Fraction of the visible scroll container width that a scroll button click
 * moves by when the target would not land on the very start or end of the
 * chips. A value below 1 keeps some of the previously visible chips in view.
 */
const SCROLL_MAX_OFFSET = 0.85

/**
 * Outer wrapper of the legend. It is relatively positioned so the absolutely
 * positioned scroll buttons rendered inside it are placed against its edges.
 */
const StyledRoot = styled('div')({
  position: 'relative',
  // NOTE(review): presumably allows the legend to shrink inside a flex parent
  // so the inner container can overflow and scroll — confirm against callers.
  minWidth: 0,
})

/**
 * Scrollable viewport around the chips. Overflow is scrollable, but the
 * native scrollbars are hidden with the vendor-specific rules below.
 */
const StyledScrollContainer = styled('div')({
  overflow: 'auto',

  // Hide scrollbars
  msOverflowStyle: 'none', // IE and Edge
  scrollbarWidth: 'none', // Firefox
  // WebKit-based browsers
  '&::-webkit-scrollbar': {
    display: 'none',
  },
})

/**
 * Single-row flex container holding the legend chips, centered on both axes
 * and separated by one theme spacing unit.
 */
const StyledChipContainer = styled('div')(({ theme }) => ({
  // Never shrink below the combined width of the chips, so the row overflows
  // its parent instead of squeezing or wrapping the chips.
  minWidth: 'fit-content',
  display: 'flex',
  flexDirection: 'row',
  flexWrap: 'nowrap',
  gap: theme.spacing(1),
  alignItems: 'center',
  justifyContent: 'center',
}))

/** Props accepted by `ChartLegend`. */
interface ChartLegendProps<Data, ValueType extends ChartGroupValueType> {
  /** All groups; one chip is rendered for each of them. */
  groups: ChartGroup<Data, ValueType>[]
  /** The currently enabled groups (matched against `groups` by their `value`). */
  value: ChartGroup<Data, ValueType>[]
  /** Called with the new list of enabled groups after a chip is clicked. */
  onChange: (groups: ChartGroup<Data, ValueType>[]) => void
  /**
   * Called with the group when the pointer enters the chip of an enabled
   * group, and with `null` when the pointer leaves any chip.
   */
  onGroupHover?: (group: ChartGroup<Data, ValueType> | null) => void
  /** Called when the pointer enters any chip, whether its group is enabled or not. */
  onGroupMouseEnter?: (group: ChartGroup<Data, ValueType>) => void
  /** Called when the pointer leaves any chip. */
  onGroupMouseLeave?: () => void
  /** Height applied to the chip row through `sx`; the component defaults it to 48. */
  height?: number
}

/**
 * Computes the selection that results from clicking the chip of `group`.
 *
 * - When every group is currently enabled, only the clicked group stays enabled.
 * - Otherwise the clicked group is toggled in the current selection.
 * - When toggling would leave nothing enabled, all groups are enabled again.
 */
function getNextLegendValue<Data, ValueType extends ChartGroupValueType>(
  groups: ChartGroup<Data, ValueType>[],
  value: ChartGroup<Data, ValueType>[],
  group: ChartGroup<Data, ValueType>
): ChartGroup<Data, ValueType>[] {
  if (value.length === groups.length) {
    return [group]
  }

  const index = value.findIndex((item) => item.value === group.value)
  const nextValue =
    index === -1 ? value.concat(group) : value.slice(0, index).concat(value.slice(index + 1))

  return nextValue.length === 0 ? groups.slice() : nextValue
}

/**
 * Horizontally scrollable chart legend that renders one chip per group.
 *
 * Clicking a chip changes which groups are enabled (reported through
 * `onChange`); hovering a chip is reported through the `onGroup*` callbacks.
 * When the chips do not fit in the available width, left/right scroll buttons
 * are shown and faded out at the respective end of the scroll range.
 *
 * @param groups All groups to render as chips.
 * @param value The currently enabled groups.
 * @param onChange Receives the next list of enabled groups after a click.
 * @param onGroupHover Receives the hovered group if it is enabled, or `null` on leave.
 * @param onGroupMouseEnter Receives the group of any chip the pointer enters.
 * @param onGroupMouseLeave Called when the pointer leaves any chip.
 * @param height Height of the chip row; defaults to 48.
 */
export default function ChartLegend<Data, ValueType extends ChartGroupValueType>({
  groups,
  value,
  onChange,
  onGroupHover,
  onGroupMouseEnter,
  onGroupMouseLeave,
  height = 48,
}: ChartLegendProps<Data, ValueType>) {
  const scrollContainerRef = useRef<HTMLDivElement>(null)
  const chipContainerRef = useRef<HTMLDivElement>(null)
  const [isScrollStart, setIsScrollStart] = useState(false)
  const [isScrollEnd, setIsScrollEnd] = useState(false)
  const [canScroll, setCanScroll] = useState(false)

  const handleClick = useCallback(
    (group: ChartGroup<Data, ValueType>) => (event: React.MouseEvent) => {
      event.preventDefault()
      event.stopPropagation()
      onChange(getNextLegendValue(groups, value, group))
    },
    [value, onChange, groups]
  )

  const handleMouseEnter = useCallback(
    (group: ChartGroup<Data, ValueType>) => () => {
      onGroupMouseEnter?.(group)

      // Only enabled groups are reported as hovered
      if (value.some((item) => item.value === group.value)) {
        onGroupHover?.(group)
      }
    },
    [onGroupHover, onGroupMouseEnter, value]
  )

  const handleMouseLeave = useCallback(() => {
    onGroupMouseLeave?.()
    onGroupHover?.(null)
  }, [onGroupHover, onGroupMouseLeave])

  // Re-evaluate whether scrolling is possible whenever the chip container resizes
  const handleChipContainerResize = useCallback(() => {
    const scrollContainer = scrollContainerRef.current
    const chipContainer = chipContainerRef.current

    if (!scrollContainer || !chipContainer) {
      return
    }

    if (chipContainer.offsetWidth > scrollContainer.offsetWidth) {
      // The chips overflow: jump back to the start and show the scroll buttons
      setCanScroll(scrollContainer.scrollWidth > scrollContainer.clientWidth)
      scrollContainer.scrollLeft = 0
      setIsScrollStart(true)
      setIsScrollEnd(false)
    } else {
      setCanScroll(false)
      setIsScrollStart(false)
      setIsScrollEnd(false)
    }
  }, [])

  useResizeObserver(chipContainerRef, handleChipContainerResize)

  const handleScrollTo = useCallback(
    (direction: 'left' | 'right') => (event: React.MouseEvent) => {
      event.preventDefault()
      event.stopPropagation()
      const scrollContainer = scrollContainerRef.current
      const chipContainer = chipContainerRef.current

      if (!scrollContainer || !chipContainer) {
        return
      }

      const contentWidth = chipContainer.getBoundingClientRect().width
      const viewportWidth = scrollContainer.getBoundingClientRect().width

      // A full "page" in the requested direction
      const step = direction === 'left' ? -viewportWidth : viewportWidth
      let offset = scrollContainer.scrollLeft + step

      // When we're not hitting the beginning or end of the scroll container reduce
      // the amount scrolled. This makes it easier to see the groups on the edges.
      if (offset > 0 && offset < contentWidth - viewportWidth) {
        offset = scrollContainer.scrollLeft + step * SCROLL_MAX_OFFSET
      }

      scrollContainer.scrollTo({
        left: offset,
        behavior: 'smooth',
      })
    },
    []
  )

  const handleScroll = useCallback(() => {
    const scrollContainer = scrollContainerRef.current

    if (!scrollContainer) {
      return
    }

    // Use the getBoundingClientRect width to get decimals and Math.ceil to prevent
    // not detecting the scroll end because of rounding errors
    const viewportWidth = Math.ceil(scrollContainer.getBoundingClientRect().width)
    // Width of the content from the current scroll position to its end
    const remainingWidth = scrollContainer.scrollWidth - scrollContainer.scrollLeft

    setIsScrollStart(scrollContainer.scrollLeft === 0)
    setIsScrollEnd(viewportWidth >= remainingWidth)
  }, [])

  // When every group is enabled no chip is highlighted as "active"
  const allGroupsEnabled = value.length === groups.length

  return (
    <StyledRoot>
      {canScroll && (
        <>
          <Fade in={!isScrollStart}>
            <ButtonScroll
              direction={'left'}
              onClick={handleScrollTo('left')}
              sx={{ top: 8, left: -16, position: 'absolute' }}
            />
          </Fade>
          <Fade in={!isScrollEnd}>
            <ButtonScroll
              direction={'right'}
              onClick={handleScrollTo('right')}
              sx={{ top: 8, right: -16, position: 'absolute' }}
            />
          </Fade>
        </>
      )}

      <StyledScrollContainer
        ref={scrollContainerRef}
        onScroll={handleScroll}>
        <StyledChipContainer
          ref={chipContainerRef}
          sx={{ height }}>
          {groups.map((group) => {
            const isEnabled = value.some((item) => item.value === group.value)

            return (
              <ChartLegendChip
                key={group.value}
                active={isEnabled && !allGroupsEnabled}
                color={group.color}
                disabled={!isEnabled}
                label={group.label}
                onClick={handleClick(group)}
                onMouseEnter={handleMouseEnter(group)}
                onMouseLeave={handleMouseLeave}
                variant={'outlined'}
                clickable
              />
            )
          })}
        </StyledChipContainer>
      </StyledScrollContainer>
    </StyledRoot>
  )
}
