import { useCallback, useMemo } from 'react';
import type { MRT_Header, MRT_RowData, MRT_TableInstance } from '../types';

/**
 * Input accepted by `useGroupedColumns`.
 */
type UseGroupedColumnsProps<TData extends MRT_RowData> = {
  /** MRT table instance whose flat headers are read by the hook. */
  table: MRT_TableInstance<TData>;
};

/**
 * Adapts MRT headers for rendering when columns are grouped.
 *
 * MRT adds placeholder headers for columns that have no group header, so a single column can
 * show up in more than one header row. This hook chooses exactly one header per column to
 * render (the first one found in `table.getFlatHeaders()`) and gives placeholder headers a
 * `rowSpan` so they can cover the rows their column occupies.
 *
 * @param props.table - The MRT table instance whose flat headers are inspected.
 * @returns An object with `getAdaptedHeader`, which maps a header to the header that should
 *   be rendered for it, or `null` when that header should be skipped.
 */
export function useGroupedColumns<TData extends MRT_RowData>(props: UseGroupedColumnsProps<TData>) {
  const { table } = props;
  // NOTE(review): the memo below keys on this reference; it only avoids recomputation if
  // getFlatHeaders() returns the same array between renders when headers are unchanged.
  const flatHeaders = table.getFlatHeaders();

  /**
   * Map of column id -> id of the header to render for that column.
   *
   * Because MRT inserts placeholder headers for columns without group headers, a column can
   * have several headers in `flatHeaders`; the first one encountered wins.
   *
   * A `Map` is used rather than a plain object because column ids are arbitrary strings:
   * with `{}`, an id such as `constructor` or `toString` would hit an inherited
   * `Object.prototype` member, never be recorded, and that column would never render.
   */
  const headerIdByColumnId = useMemo(() => {
    const idsByColumn = new Map<string, string>();
    for (const header of flatHeaders) {
      if (!idsByColumn.has(header.column.id)) {
        idsByColumn.set(header.column.id, header.id);
      }
    }
    return idsByColumn;
  }, [flatHeaders]);

  /**
   * Returns the header to render for the given header's column.
   *
   * - Returns `null` when a different header was chosen for this column.
   * - Returns the header unchanged when it is not a placeholder.
   * - Returns a shallow copy with `rowSpan` set when it is a placeholder.
   */
  const getAdaptedHeader = useCallback(
    (header: MRT_Header<TData>) => {
      if (headerIdByColumnId.get(header.column.id) !== header.id) {
        return null;
      }

      if (!header.isPlaceholder) {
        return header;
      }

      // The row span is the number of headers returned by getLeafHeaders().
      // NOTE(review): for a placeholder this presumably counts the placeholder itself plus the
      // headers stacked beneath it (i.e. the rows the column spans) — confirm against the
      // TanStack Table version in use.
      return { ...header, rowSpan: header.getLeafHeaders().length };
    },
    [headerIdByColumnId],
  );

  return { getAdaptedHeader };
}
