import {
    Grid,
    Paper,
    Table,
    TableBody,
    TableCell,
    TableContainer,
    TableHead,
    TableRow,
} from "@mui/material";
import { LANG, MVR_LABELS } from "../../../../constants";
import { useMemo, useState } from "react";

import { ClassificationMatrix } from "./ClassificationMatrix";
import { DataDisplay } from "../../components/DataDisplay/DataDisplay";
import { Histogram } from "../../components/DataDisplay/Histogram";
import { ImageExplorer } from "../../components/DataDisplay/ImageExplorer";
import PropTypes from "prop-types";

export const ClassificationTemplate = ({
    versionData,
    dataset,
    classes,
    modelIdCard,
    versionTimeline,
}) => {
    const [selectedCategoryIndexes, setSelectedCategoryIndexes] = useState([
        null,
        null,
    ]);

    const getClassName = (id) => {
        return classes.find(({ classId }) => classId === id)?.name;
    };

    const data = useMemo(() => {
        return versionData
            ? versionData.ModelVersionFiles.map((image) => {
                  const reference = dataset.find(
                      ({ name }) => name === image.imageName
                  );
                  return {
                      ...image,
                      prediction: image.prediction,
                      groundTruth: reference.annotation,
                  };
              })
            : [];
    }, [dataset, versionData]);

    const selectedImages = useMemo(() => {
        const [x, y] = selectedCategoryIndexes;
        const groundTruth = classes[x]?.classId;
        const prediction = classes[y]?.classId;
        return data
            .filter(
                (image) =>
                    image.groundTruth === groundTruth &&
                    image.prediction === prediction
            )
            .map((image) => {
                const file = dataset.find(
                    ({ name }) => name === image.imageName
                );
                return { ...image, path: file.path };
            });
    }, [classes, data, dataset, selectedCategoryIndexes]);

    const classificationMatrixData = useMemo(
        () =>
            classes.map(({ classId: classIdX }) => {
                const classData = data.filter(
                    (image) => image.prediction === classIdX
                );
                return classes.map(
                    ({ classId: classIdY }) =>
                        classData.filter(
                            (image) => image.groundTruth === classIdY
                        ).length
                );
            }),
        [data, classes]
    );

    // https://react.dev/learn/you-might-not-need-an-effect#adjusting-some-state-when-a-prop-changes
    const [prevData, setPrevData] = useState(data);
    if (data !== prevData) {
        setPrevData(data);
        setSelectedCategoryIndexes([null, null]);
    }

    return (
        <>
            <Grid item container xs={7} sx={{ p: 0.5 }}>
                <Grid item xs={12}>
                    {modelIdCard}
                </Grid>
                <Grid item xs={4}>
                    {versionTimeline}
                </Grid>
                <Grid item xs={8}>
                    <ClassificationMatrix
                        matrixData={classificationMatrixData}
                        classes={classes}
                        handleBoxClick={setSelectedCategoryIndexes}
                        selectedCategoryIndexes={selectedCategoryIndexes}
                    />
                </Grid>
            </Grid>

            <Grid item xs={5} sx={{ p: 0.5 }}>
                <DataDisplay
                    components={[
                        {
                            name: "Explorateur d'images",
                            jsx: (
                                <ImageExplorer
                                    key={selectedImages}
                                    images={selectedImages}
                                />
                            ),
                        },
                        {
                            name: MVR_LABELS.CONFIDENCE_SCORE[LANG],
                            jsx: (
                                <Histogram
                                    data={data.map(
                                        ({ predictionScore }) => predictionScore
                                    )}
                                />
                            ),
                        },
                        {
                            name: "Table de classification",
                            jsx: (
                                <TableContainer component={Paper}>
                                    <Table size="small">
                                        <TableHead>
                                            <TableRow>
                                                <TableCell>Name</TableCell>
                                                <TableCell align="right">
                                                    Prediction
                                                </TableCell>
                                                <TableCell align="right">
                                                    {
                                                        MVR_LABELS
                                                            .CONFIDENCE_SCORE[
                                                            LANG
                                                        ]
                                                    }
                                                </TableCell>
                                            </TableRow>
                                        </TableHead>
                                        <TableBody>
                                            {data.map((row) => (
                                                <TableRow
                                                    key={row.imageName}
                                                    sx={{
                                                        "&:last-child td, &:last-child th":
                                                            { border: 0 },
                                                    }}
                                                >
                                                    <TableCell
                                                        align="left"
                                                        sx={{
                                                            whiteSpace:
                                                                "nowrap",
                                                            overflow: "hidden",
                                                            textOverflow:
                                                                "ellipsis",
                                                            maxWidth: 200,
                                                        }}
                                                    >
                                                        {row.imageName}
                                                    </TableCell>
                                                    <TableCell align="right">
                                                        {getClassName(
                                                            row.groundTruth
                                                        )}
                                                    </TableCell>
                                                    <TableCell align="right">
                                                        {getClassName(
                                                            row.prediction
                                                        )}
                                                    </TableCell>
                                                    <TableCell align="right">
                                                        {row.predictionScore.toFixed(
                                                            2
                                                        )}
                                                    </TableCell>
                                                </TableRow>
                                            ))}
                                        </TableBody>
                                    </Table>
                                </TableContainer>
                            ),
                        },
                    ]}
                />
            </Grid>
        </>
    );
};

// Runtime prop validation (development only). `versionData` and `dataset` are
// optional: the component renders an empty report when `versionData` is falsy.
ClassificationTemplate.propTypes = {
    // When provided, `ModelVersionFiles` is mapped over unconditionally.
    versionData: PropTypes.shape({
        ModelVersionFiles: PropTypes.arrayOf(PropTypes.object).isRequired,
    }),
    // Entries are matched to version files by their `name`.
    dataset: PropTypes.arrayOf(PropTypes.object),
    classes: PropTypes.array.isRequired,
    modelIdCard: PropTypes.node.isRequired,
    versionTimeline: PropTypes.node.isRequired,
};
