import { Dictionary } from "common";
import type { BasicModelStatsData, ConfidenceDistributionData, ConfidenceDistributionGTlessData, ConfusionMatrixData, DatasetClassDistributionData, ExplorePredictionsData, IClassPerformanceItem, ILayoutState, ILiftChartData, IMostConfusedClassesData, IMostConfusedImagesGraphParams, IPerformanceByConfidenceData, IPerformanceByConfidenceGraphParams, IPerformanceByIoUThresholdData, IPerformanceByIoUThresholdGraphParams, IPrecisionRecallCurveData, IROCCurveData, MostConfusedImagesData, ObjectCountsData, PlotTopLossesData, PlotTopMissesData, PopulationDistributionData, PredictionDistributionData, SoftmaxHistogramData, THighlightedMetrics, TrainingVsValidationPerformanceData } from "../../metrics/entities";

/**
 * Aggregate of all comparison metric entries.
 *
 * Every property is an {@link IComparisonBaseData} wrapper. The first type
 * argument is the payload type stored per dictionary key; the optional second
 * type argument is the shape of the graph parameters for that entry.
 * The property names match the string values of `ComparisonChartKey`.
 */
export interface IComparisonMetrics {
    // Entries whose graph parameters use the wrapper's default type.
    highlightedMetrics: IComparisonBaseData<THighlightedMetrics>;
    confusionMatrix: IComparisonBaseData<ConfusionMatrixData>;
    basicModelStats: IComparisonBaseData<BasicModelStatsData>;
    plotTopLosses: IComparisonBaseData<PlotTopLossesData>;
    trainingVsValidationPerformance: IComparisonBaseData<TrainingVsValidationPerformanceData>;
    datasetClassDistribution: IComparisonBaseData<DatasetClassDistributionData>;
    populationDistribution: IComparisonBaseData<PopulationDistributionData>;
    mostConfusedClasses: IComparisonBaseData<IMostConfusedClassesData[]>;
    classPerformance: IComparisonBaseData<IClassPerformanceItem>;
    confidenceDistribution: IComparisonBaseData<ConfidenceDistributionData>;
    rocCurve: IComparisonBaseData<IROCCurveData>;
    predictionDistribution: IComparisonBaseData<PredictionDistributionData>;
    explorePredictions: IComparisonBaseData<ExplorePredictionsData>;
    precisionRecallCurve: IComparisonBaseData<IPrecisionRecallCurveData>;
    softmaxHistogram: IComparisonBaseData<SoftmaxHistogramData>;
    confidenceDistributionGTless: IComparisonBaseData<ConfidenceDistributionGTlessData>;
    plotTopMisses: IComparisonBaseData<PlotTopMissesData>;
    objectCounts: IComparisonBaseData<ObjectCountsData>;
    // Payload type is intentionally left open in the original declaration.
    mismatchedPredictions: IComparisonBaseData<any>;
    liftChart: IComparisonBaseData<ILiftChartData>;

    // Entries that declare a dedicated graph-parameter type.
    performanceByConfidence: IComparisonBaseData<IPerformanceByConfidenceData, IPerformanceByConfidenceGraphParams>;
    mostConfusedImages: IComparisonBaseData<MostConfusedImagesData, IMostConfusedImagesGraphParams>;
    performanceByIoUThreshold: IComparisonBaseData<IPerformanceByIoUThresholdData, IPerformanceByIoUThresholdGraphParams>;
}

/**
 * Builds a fully populated {@link IComparisonMetrics} object.
 *
 * For every metric, only the `data` payload of the supplied entry (if any) is
 * carried over; `layoutState`, `graphParams` and `defaultParams` are always the
 * fresh defaults produced by `ComparisonBaseDataFactory`.
 *
 * When an entry (or its `data`) is missing, the per-metric fallback listed
 * below is used. `highlightedMetrics`, `confusionMatrix` and
 * `performanceByConfidence` have no fallback, so their `data` is `undefined`
 * when nothing is supplied.
 *
 * @param data - Optional partial metrics whose `data` payloads seed the result.
 * @returns A new metrics object with one entry per metric key.
 */
export function ComparisonMetricsFactory(data?: Partial<IComparisonMetrics>): IComparisonMetrics {
    // Every call passes `entry?.data` (the payload), never the entry itself:
    // passing the whole wrapper would nest it inside the new wrapper's `data`.
    return {
        highlightedMetrics: ComparisonBaseDataFactory(data?.highlightedMetrics?.data),
        confusionMatrix: ComparisonBaseDataFactory(data?.confusionMatrix?.data),
        basicModelStats: ComparisonBaseDataFactory(data?.basicModelStats?.data, { Validation: [], Training: [] }),
        plotTopLosses: ComparisonBaseDataFactory(data?.plotTopLosses?.data, []),
        performanceByConfidence: ComparisonBaseDataFactory(data?.performanceByConfidence?.data),
        trainingVsValidationPerformance: ComparisonBaseDataFactory(data?.trainingVsValidationPerformance?.data, {}),
        datasetClassDistribution: ComparisonBaseDataFactory(data?.datasetClassDistribution?.data, { Validation: {}, Training: {} }),
        populationDistribution: ComparisonBaseDataFactory(data?.populationDistribution?.data, { Validation: {}, Training: {} }),
        mostConfusedClasses: ComparisonBaseDataFactory(data?.mostConfusedClasses?.data, []),
        mostConfusedImages: ComparisonBaseDataFactory(data?.mostConfusedImages?.data, []),
        classPerformance: ComparisonBaseDataFactory(data?.classPerformance?.data, {}),
        confidenceDistribution: ComparisonBaseDataFactory(data?.confidenceDistribution?.data, {}),
        rocCurve: ComparisonBaseDataFactory(data?.rocCurve?.data, []),
        predictionDistribution: ComparisonBaseDataFactory(data?.predictionDistribution?.data, []),
        explorePredictions: ComparisonBaseDataFactory(data?.explorePredictions?.data, []),
        precisionRecallCurve: ComparisonBaseDataFactory(data?.precisionRecallCurve?.data, {}),
        softmaxHistogram: ComparisonBaseDataFactory(data?.softmaxHistogram?.data, []),
        confidenceDistributionGTless: ComparisonBaseDataFactory(data?.confidenceDistributionGTless?.data, {}),
        plotTopMisses: ComparisonBaseDataFactory(data?.plotTopMisses?.data, []),
        performanceByIoUThreshold: ComparisonBaseDataFactory(data?.performanceByIoUThreshold?.data, {}),
        objectCounts: ComparisonBaseDataFactory(data?.objectCounts?.data, {}),
        mismatchedPredictions: ComparisonBaseDataFactory(data?.mismatchedPredictions?.data, []),
        liftChart: ComparisonBaseDataFactory(data?.liftChart?.data, []),
    };
}

/**
 * Generic wrapper holding one comparison metric entry.
 *
 * @typeParam T - Type of each value stored in the `data` dictionary.
 * @typeParam P - Type shared by `graphParams` and `defaultParams`.
 */
export interface IComparisonBaseData<T = any, P = any> {
    /** Layout flags for this entry (see `ILayoutState`). */
    layoutState: ILayoutState;
    /** Graph parameters currently associated with this entry. */
    graphParams: P;
    /** Optional default graph parameters, of the same type as `graphParams`. */
    defaultParams?: P;
    /** Payload dictionary of `T` values; key semantics are not defined in this file. */
    data: Dictionary<T>;
}


/**
 * Creates a fresh comparison entry with default layout flags and empty
 * parameter objects.
 *
 * @param data - Payload to store; used as-is unless it is `null` or `undefined`.
 * @param defaultData - Fallback payload used only when `data` is `null` or
 *   `undefined`. Falsy values such as `0`, `""` or `false` do not trigger it.
 * @returns A new entry; `layoutState`, `graphParams` and `defaultParams` are
 *   newly created objects on every call.
 */
export function ComparisonBaseDataFactory(data: any, defaultData?: any): IComparisonBaseData {
    // Explicit nullish test: only null/undefined fall back to the default.
    const payload = data === null || data === undefined ? defaultData : data;
    const layoutState = { isLoading: false, isEmpty: false };

    return { layoutState, graphParams: {}, defaultParams: {}, data: payload };
}
// export type ComparisonChartKey = keyof IComparisonMetrics;

/**
 * String enum naming every comparison chart.
 *
 * Each member's value is identical to the corresponding property name of
 * `IComparisonMetrics`, and the members are listed in the same order as
 * those properties. Adding a metric to `IComparisonMetrics` therefore
 * requires adding a matching member here (and vice versa), since nothing in
 * this file enforces the correspondence automatically.
 */
export enum ComparisonChartKey {
    HighlightedMetrics = "highlightedMetrics",
    ConfusionMatrix = "confusionMatrix",
    BasicModelStats = "basicModelStats",
    PlotTopLosses = "plotTopLosses",
    PerformanceByConfidence = "performanceByConfidence",
    TrainingVsValidationPerformance = "trainingVsValidationPerformance",
    DatasetClassDistribution = "datasetClassDistribution",
    PopulationDistribution = "populationDistribution",
    MostConfusedClasses = "mostConfusedClasses",
    MostConfusedImages = "mostConfusedImages",
    ClassPerformance = "classPerformance",
    ConfidenceDistribution = "confidenceDistribution",
    RocCurve = "rocCurve",
    PredictionDistribution = "predictionDistribution",
    ExplorePredictions = "explorePredictions",
    PrecisionRecallCurve = "precisionRecallCurve",
    SoftmaxHistogram = "softmaxHistogram",
    ConfidenceDistributionGTless = "confidenceDistributionGTless",
    PlotTopMisses = "plotTopMisses",
    PerformanceByIoUThreshold = "performanceByIoUThreshold",
    ObjectCounts = "objectCounts",
    MismatchedPredictions = "mismatchedPredictions",
    LiftChart = "liftChart",
}