import LoadingState from '@components/common/LoadingState';
import { type AnyType } from '@lightdash/common';
import { Box } from '@mantine/core';
import { CaretDown, CaretUp, CaretUpDown } from '@phosphor-icons/react';
import {
    flexRender,
    getCoreRowModel,
    getExpandedRowModel,
    getSortedRowModel,
    useReactTable,
    type ColumnDef,
    type Row,
    type SortingState,
} from '@tanstack/react-table';
import { useVirtualizer } from '@tanstack/react-virtual';
import { useEffect, useRef, useState } from 'react';
import NoDataFound from './NoDataFound';
import TableLoader from './TableLoader';
import { fuzzyFilter } from './utils';

/**
 * Props accepted by {@link VirtualTable}.
 *
 * @typeParam TData - shape of a single row object in `data`.
 */
interface VirtualTableProps<TData> {
    /** Row objects to render; each entry becomes one table row. */
    data: TData[];
    /** TanStack Table column definitions describing how each column is rendered. */
    columns: ColumnDef<TData, AnyType>[];

    /** Class name(s) applied to the container element the row virtualizer scrolls. */
    customClass?: string;
    /** Estimated row height in pixels handed to the virtualizer (falls back to `ROW_HEIGHT`). */
    rowSize?: number;
    /**
     * Estimated column width.
     * NOTE(review): `VirtualTable` does not destructure or read this prop.
     */
    columnSize?: number;
    /** Initial sorting state; only used to seed sorting on the first render. */
    defaultSorting?: SortingState;
    /** Name forwarded to the empty-state message shown when there are no rows. */
    tableName?: string;
    /** Footer text shown below the rows when provided and at least one row exists. */
    paginationText?: string;

    /** When true, a loading state is rendered instead of the table. */
    isLoading?: boolean;
    /** When true together with `isNeedToRefresh`, the body shows a loader instead of rows. */
    isDataFetching?: boolean;
    /** When true together with `isDataFetching`, the body shows a loader instead of rows. */
    isNeedToRefresh?: boolean;

    /** Called with the clicked row; when set, rows also get a pointer cursor and hover style. */
    onRowClick?: (row: Row<TData>) => void;
}
/** Fallback estimated row height in pixels, used when no `rowSize` prop is given. */
const ROW_HEIGHT = 50;
/** Number of extra rows the virtualizer renders outside the visible viewport. */
const OVERSCAN = 5;

/**
 * Maps a column size (from TanStack Table's `column.getSize()`) to a CSS
 * grid track:
 * - size >= 200: at least `size` px, growing to share leftover space;
 * - size <= 100: fixed at exactly `size` px;
 * - otherwise: a flexible track that may shrink down to 0.
 */
const getGridTrack = (size: number): string => {
    if (size >= 200) {
        return `minmax(${size}px, 1fr)`;
    }
    if (size <= 100) {
        return `${size}px`;
    }
    return `minmax(0, 1fr)`;
};

/**
 * Builds a `grid-template-columns` value from column sizes. Header rows and
 * body rows both use this so their column layouts always line up.
 */
const getGridTemplateColumns = (sizes: number[]): string =>
    sizes.map((size) => getGridTrack(size)).join(' ');

/**
 * Returns the sort indicator for a header cell: a caret for the active
 * direction, a neutral up/down caret for sortable-but-unsorted columns,
 * and nothing for columns that cannot be sorted.
 */
const renderSortIcon = (
    sortDirection: false | 'asc' | 'desc',
    canSort: boolean,
) => {
    if (sortDirection === 'asc') {
        return <CaretUp color="rgb(var(--color-blu-800))" size={14} />;
    }
    if (sortDirection === 'desc') {
        return <CaretDown color="rgb(var(--color-blu-800))" size={14} />;
    }
    return canSort ? (
        <CaretUpDown color="rgb(var(--color-gray-500))" size={14} />
    ) : null;
};

/**
 * Row-virtualized table with client-side sorting.
 *
 * Only the rows near the viewport of the container are mounted; the
 * container (`customClass` is applied to it) is the virtualizer's scroll
 * element. Clicking a sortable header toggles its sort, and the view
 * scrolls back to the first row whenever the sort order changes.
 *
 * Render states:
 * - `isLoading`: a loading state replaces the whole table;
 * - `isDataFetching && isNeedToRefresh`: the body shows a loader instead of rows;
 * - no rows: an empty-state message is shown below the (empty) table.
 */
const VirtualTable = <TData,>({
    data,
    columns,
    onRowClick,
    customClass,
    rowSize,
    isLoading,
    paginationText,
    isNeedToRefresh,
    isDataFetching,
    defaultSorting,
    tableName,
}: VirtualTableProps<TData>) => {
    // Scroll container observed by the row virtualizer.
    const parentRef = useRef<HTMLDivElement | null>(null);
    // `defaultSorting` only seeds the initial state; later prop changes are not synced.
    const [sorting, setSorting] = useState<SortingState>(defaultSorting ?? []);

    const table = useReactTable({
        data,
        columns,
        state: {
            sorting,
        },
        onSortingChange: setSorting,
        getCoreRowModel: getCoreRowModel(),
        getExpandedRowModel: getExpandedRowModel(),
        getSortedRowModel: getSortedRowModel(),
        filterFns: {
            fuzzy: fuzzyFilter,
        },
        // Sorting is done client-side by getSortedRowModel. The debug* flags are
        // intentionally not set: they only log row-model timings to the console.
        manualSorting: false,
    });
    const { rows } = table.getRowModel();
    const hasRows = rows.length > 0;

    const rowVirtualizer = useVirtualizer({
        count: rows.length,
        estimateSize: () => rowSize ?? ROW_HEIGHT,
        getScrollElement: () => parentRef.current,
        // Key virtual items by row id so measurements follow rows across re-sorts.
        getItemKey: (index) => rows[index].id,
        overscan: OVERSCAN,
    });

    // Jump back to the first row whenever the sort order changes.
    useEffect(() => {
        rowVirtualizer.scrollToIndex(0);
    }, [sorting, rowVirtualizer]);

    if (isLoading) {
        return <LoadingState title="" />;
    }

    return (
        // Passing `customClass` directly avoids emitting a literal "undefined" class.
        <Box className={customClass} ref={parentRef}>
            <table className="w-full">
                <thead className="bg-white sticky top-0 z-50">
                    {table.getHeaderGroups().map((headerGroup) => (
                        <tr
                            key={headerGroup.id}
                            className="grid w-full bg-gray-50"
                            style={{
                                gridTemplateColumns: getGridTemplateColumns(
                                    headerGroup.headers.map((header) =>
                                        header.column.getSize(),
                                    ),
                                ),
                            }}
                        >
                            {headerGroup.headers.map((header) => {
                                const canSort = header.column.getCanSort();
                                return (
                                    <th
                                        key={header.id}
                                        colSpan={header.colSpan}
                                        className={`px-3.5 py-2.5 text-xs font-normal text-left text-gray-500 uppercase ${
                                            canSort
                                                ? 'cursor-pointer select-none hover:bg-gray-100'
                                                : ''
                                        }`}
                                        onClick={
                                            canSort
                                                ? header.column.getToggleSortingHandler()
                                                : undefined
                                        }
                                    >
                                        {header.isPlaceholder ? null : (
                                            <Box className="flex items-center gap-1">
                                                {flexRender(
                                                    header.column.columnDef
                                                        .header,
                                                    header.getContext(),
                                                )}
                                                {renderSortIcon(
                                                    header.column.getIsSorted(),
                                                    canSort,
                                                )}
                                            </Box>
                                        )}
                                    </th>
                                );
                            })}
                        </tr>
                    ))}
                </thead>
                <tbody
                    className="relative divide-y bg-white"
                    // Full virtual height so the scrollbar reflects every row,
                    // while only the virtual items below are actually mounted.
                    style={{ height: `${rowVirtualizer.getTotalSize()}px` }}
                >
                    {isDataFetching && isNeedToRefresh ? (
                        <TableLoader />
                    ) : (
                        rowVirtualizer.getVirtualItems().map((virtualRow) => {
                            const row = rows[virtualRow.index];
                            return (
                                <tr
                                    key={row.id}
                                    // data-index lets measureElement map the node back to its item.
                                    data-index={virtualRow.index}
                                    ref={(node) =>
                                        rowVirtualizer.measureElement(node)
                                    }
                                    className={`grid absolute w-full ${
                                        onRowClick
                                            ? 'cursor-pointer hover:bg-gray-50'
                                            : ''
                                    }`}
                                    style={{
                                        transform: `translateY(${virtualRow.start}px)`,
                                        gridTemplateColumns:
                                            getGridTemplateColumns(
                                                row
                                                    .getVisibleCells()
                                                    .map((cell) =>
                                                        cell.column.getSize(),
                                                    ),
                                            ),
                                    }}
                                    onClick={
                                        onRowClick
                                            ? () => onRowClick(row)
                                            : undefined
                                    }
                                >
                                    {row.getVisibleCells().map((cell) => (
                                        <td
                                            key={cell.id}
                                            className="px-3.5 py-2.5"
                                        >
                                            {flexRender(
                                                cell.column.columnDef.cell,
                                                cell.getContext(),
                                            )}
                                        </td>
                                    ))}
                                </tr>
                            );
                        })
                    )}
                </tbody>
                {/* Footer rows must live in a row group (<tfoot>), not directly in <table>. */}
                {paginationText && hasRows && (
                    <tfoot>
                        <tr className="sticky bottom-0 w-full bg-gray-50">
                            <td
                                // Span every visible leaf column (correct with column groups too).
                                colSpan={table.getVisibleLeafColumns().length}
                                className="px-3 py-2"
                            >
                                <span className="text-sm text-gray-600">
                                    {paginationText}
                                </span>
                            </td>
                        </tr>
                    </tfoot>
                )}
            </table>
            {!hasRows && <NoDataFound name={tableName} />}
        </Box>
    );
};

export default VirtualTable;
