import LoadingState from '@components/common/LoadingState';
import { Box } from '@mantine/core';
import {
    flexRender,
    getCoreRowModel,
    getExpandedRowModel,
    useReactTable,
    type ColumnDef,
    type Row,
} from '@tanstack/react-table';
import { useVirtualizer } from '@tanstack/react-virtual';
import { useRef } from 'react';
import { fuzzyFilter } from './utils';

/**
 * Props accepted by `VirtualTable`.
 *
 * @typeParam TData - Shape of one entry in `data`; each entry is rendered as one table row.
 */
interface VirtualTableProps<TData> {
    /** Records to render, one body row per entry. */
    data: TData[];
    /**
     * TanStack Table column definitions for `TData`.
     * NOTE(review): the value type argument is `any`, presumably so columns with
     * different cell value types can share one array — consider `unknown` if the
     * column helpers allow it.
     */
    columns: ColumnDef<TData, any>[];
    /**
     * Called with the clicked TanStack `Row` when a body row is clicked.
     * When provided, body rows also get pointer-cursor and hover styling.
     */
    onRowClick?: (row: Row<TData>) => void;
    /** Extra class names appended to the scrollable (`overflow-auto`) container that wraps the table. */
    customClass?: string;
    /** Estimated height of one body row, in pixels; `VirtualTable` falls back to `ROW_HEIGHT` when omitted. */
    rowSize?: number;
    /**
     * Estimated width of a column.
     * NOTE(review): `VirtualTable` does not currently read this prop.
     */
    columnSize?: number;
    /** When true, a `LoadingState` is rendered instead of the table. */
    isLoading?: boolean;
    /** Text shown in a footer strip below the table, only when the table has at least one row. */
    paginationText?: string;
}
/** Fallback estimate, in pixels, for the height of one body row when no `rowSize` prop is passed. */
const ROW_HEIGHT = 50;
/** Number of extra rows the virtualizer keeps rendered outside the visible viewport. */
const OVERSCAN = 5;

/**
 * Table whose body rows are virtualized with `@tanstack/react-virtual`.
 *
 * Only the rows inside the scroll viewport, plus `OVERSCAN` extra rows, are
 * mounted. Two spacer rows sized from the virtualizer's offsets stand in for
 * the unmounted rows above and below, so the table's layout height matches
 * the whole list. That keeps the scroll range correct from the first render
 * and gives the sticky header the full table height to stick within.
 *
 * The scrolling element is the div that receives `customClass`.
 * NOTE(review): virtualization only pays off when that div has a bounded
 * height — presumably supplied via `customClass`; confirm with callers.
 *
 * @param props - See `VirtualTableProps`. `columnSize` is accepted but unused.
 * @returns `LoadingState` while `isLoading` is true, otherwise the table.
 */
const VirtualTable = <TData,>({
    data,
    columns,
    onRowClick,
    customClass,
    rowSize,
    isLoading,
    paginationText,
}: VirtualTableProps<TData>) => {
    // Scroll container observed by the virtualizer.
    const parentRef = useRef<HTMLDivElement | null>(null);

    const table = useReactTable({
        data,
        columns,
        getCoreRowModel: getCoreRowModel(),
        getExpandedRowModel: getExpandedRowModel(),
        filterFns: {
            fuzzy: fuzzyFilter,
        },
    });

    // The rows that are actually rendered. The virtualizer is sized and keyed
    // from this same array (not from `data`) so `rows[virtualRow.index]` stays
    // in range even if a row model changes the row count.
    const { rows } = table.getRowModel();

    const rowVirtualizer = useVirtualizer({
        count: rows.length,
        estimateSize: () => rowSize ?? ROW_HEIGHT,
        getScrollElement: () => parentRef.current,
        // Same key React uses for the <tr>, so cached sizes follow the row.
        getItemKey: (index) => rows[index]?.id ?? index,
        overscan: OVERSCAN,
    });

    // Hooks above run unconditionally; the early return is only after them.
    if (isLoading) {
        return <LoadingState title="" />;
    }

    const virtualRows = rowVirtualizer.getVirtualItems();
    const totalSize = rowVirtualizer.getTotalSize();
    const firstVirtualRow = virtualRows[0];
    const lastVirtualRow = virtualRows[virtualRows.length - 1];
    // Pixel height occupied by the unmounted rows before the first mounted
    // row and after the last one; rendered as spacer rows below.
    const paddingTop = firstVirtualRow?.start ?? 0;
    const paddingBottom = lastVirtualRow ? totalSize - lastVirtualRow.end : 0;
    const columnCount = table.getVisibleLeafColumns().length;

    return (
        // Outer boxes draw the border and rounded corners around the table.
        <Box className="border border-gray-100 rounded-2xl p-0.5">
            <Box className="overflow-hidden border border-gray-200 rounded-xl">
                <div
                    ref={parentRef}
                    className={`overflow-auto ${customClass ?? ''}`}
                >
                    <table className="w-full">
                        <thead className="bg-white sticky top-0 z-50">
                            {table.getHeaderGroups().map((headerGroup) => (
                                <tr
                                    key={headerGroup.id}
                                    className="w-full bg-shade-4"
                                >
                                    {headerGroup.headers.map((header) => (
                                        <th
                                            key={header.id}
                                            colSpan={header.colSpan}
                                            className="px-3.5 py-2.5 text-xs font-normal text-left text-gray-500 uppercase"
                                        >
                                            {header.isPlaceholder ? null : (
                                                <Box className="w-max">
                                                    {flexRender(
                                                        header.column.columnDef
                                                            .header,
                                                        header.getContext(),
                                                    )}
                                                </Box>
                                            )}
                                        </th>
                                    ))}
                                </tr>
                            ))}
                        </thead>
                        <tbody className="divide-y">
                            {paddingTop > 0 && (
                                <tr aria-hidden="true">
                                    <td
                                        colSpan={columnCount}
                                        style={{
                                            height: `${paddingTop}px`,
                                            padding: 0,
                                        }}
                                    />
                                </tr>
                            )}
                            {virtualRows.map((virtualRow) => {
                                const row = rows[virtualRow.index];
                                if (!row) {
                                    return null;
                                }
                                return (
                                    <tr
                                        key={row.id}
                                        style={{
                                            height: `${virtualRow.size}px`,
                                        }}
                                        onClick={
                                            onRowClick
                                                ? () => onRowClick(row)
                                                : undefined
                                        }
                                        className={
                                            onRowClick
                                                ? 'cursor-pointer hover:bg-gray-50'
                                                : ''
                                        }
                                    >
                                        {row.getVisibleCells().map((cell) => (
                                            <td
                                                key={cell.id}
                                                className="px-3.5 py-2.5"
                                            >
                                                {flexRender(
                                                    cell.column.columnDef.cell,
                                                    cell.getContext(),
                                                )}
                                            </td>
                                        ))}
                                    </tr>
                                );
                            })}
                            {paddingBottom > 0 && (
                                <tr aria-hidden="true">
                                    <td
                                        colSpan={columnCount}
                                        style={{
                                            height: `${paddingBottom}px`,
                                            padding: 0,
                                        }}
                                    />
                                </tr>
                            )}
                        </tbody>
                    </table>
                </div>
                {/* Footer with the pagination text, only when provided and the table has rows. */}
                {paginationText && rows.length > 0 && (
                    <Box className="flex items-center px-3 py-2 border-t border-shade-4">
                        <span className="text-sm text-gray-600">
                            {paginationText}
                        </span>
                    </Box>
                )}
            </Box>
        </Box>
    );
};

export default VirtualTable;
