import { Image } from 'cornerstone-core';
import cornerstoneTools from 'cornerstone-tools';
import { IMeasurementData, IPoint } from '../../DicomViewerHelper/interface';
import { EventData, samPredictionHelper } from '../../SAMPredictionHelper';
import {
    ClickTypeStrategy,
    PointMeasurementFactory,
    SAMPointPredictionToolConfiguration,
    strategyClickTypeMap,
} from './entities';

// Public cornerstone-tools API members used by the tool.
const { getToolState, external, getModule } = cornerstoneTools;

// Internals resolved by module path through `importInternal`
// rather than through the package's top-level exports.
const BaseAnnotationTool = cornerstoneTools.importInternal('base/BaseAnnotationTool');
const { probeCursor } = cornerstoneTools.importInternal('tools/cursors');

// General-purpose utilities.
const deepmerge = cornerstoneTools.importInternal('util/deepmerge');
const throttle = cornerstoneTools.importInternal('util/throttle');

// Pixel-sampling helpers used when caching per-point statistics.
const calculateSUV = cornerstoneTools.importInternal('util/calculateSUV');
const getRGBPixels = cornerstoneTools.importInternal('util/getRGBPixels');

// Canvas drawing primitives used by renderToolData.
const getNewContext = cornerstoneTools.importInternal('drawing/getNewContext');
const draw = cornerstoneTools.importInternal('drawing/draw');
const drawCircle = cornerstoneTools.importInternal('drawing/drawCircle');

/**
 * Cornerstone annotation tool that turns mouse clicks into point prompts for
 * `samPredictionHelper`.
 *
 * Each click produces a point measurement (built by `PointMeasurementFactory`). The clicked
 * image coordinates are forwarded to `samPredictionHelper.getPrediction`, tagged with the
 * click type that `strategyClickTypeMap` associates with the chosen strategy (`'add'` or
 * `'remove'`). Stored points are drawn as circles by `renderToolData`.
 */
export class SAMPointPredictionTool extends BaseAnnotationTool {
    /**
     * @param props Tool configuration. It is deep-merged over the defaults below, so any field
     *   the caller supplies takes precedence. Defaults: name `'SAMPointPrediction'`, mouse
     *   interaction only, the probe cursor, `renderDashed: false` and `handleRadius: 6`.
     */
    constructor(props: SAMPointPredictionToolConfiguration = {}) {
        super(
            deepmerge(
                {
                    name: 'SAMPointPrediction',
                    supportedInteractionTypes: ['Mouse'],
                    svgCursor: probeCursor,
                    configuration: {
                        renderDashed: false,
                        handleRadius: 6,
                    },
                },
                props
            )
        );

        // Throttled (110 ms wait) variant of updateCachedStats. updateCachedStats does not
        // read `this`, so passing the unbound method is safe.
        this.throttledUpdateCachedStats = throttle(this.updateCachedStats, 110);
    }

    /**
     * Send the clicked image point to the SAM prediction helper.
     *
     * @param eventData Cornerstone event detail. Its `currentPoints.image` coordinates,
     *   `image` and `element` are forwarded to the helper.
     * @param strategy Click strategy. It is translated to the helper's click type through
     *   `strategyClickTypeMap`.
     */
    getPrediction(eventData: EventData, strategy: ClickTypeStrategy) {
        const point: IPoint = {
            ...eventData.currentPoints.image,
            clickType: strategyClickTypeMap[strategy],
        };

        // NOTE(review): the helper's return value (possibly a Promise) is neither awaited nor
        // checked here. Confirm the helper reports its own failures.
        samPredictionHelper.getPrediction({ image: eventData.image, element: eventData.element, point });
    }

    /**
     * Decide how the current click should be interpreted.
     *
     * @returns `'add'` when the helper has neither a `rect` nor any recorded `clicks`.
     *   Otherwise it returns `'remove'` if the tool is configured with `strategy: 'remove'`
     *   or Ctrl is held during the click, and `'add'` in every other case.
     */
    getStrategy(eventData: EventData): ClickTypeStrategy {
        // With no prompt recorded yet, the click is always an addition. Presumably a removal
        // has nothing to act on at that point.
        if (!samPredictionHelper.rect && !samPredictionHelper.clicks?.length) return 'add';
        return this.configuration.strategy === 'remove' || eventData.event.ctrlKey ? 'remove' : 'add';
    }

    /**
     * Hit-test hook consulted by the base annotation tool. It always returns `false`, so a
     * stored point is never reported as being near the cursor.
     */
    pointNearTool() {
        return false;
    }

    /**
     * Mouse-move hook. It stops the event from propagating any further, including to other
     * listeners on the same element, and returns nothing.
     *
     * @param evt The event dispatched to the tool. The broad `Event` type is used because only
     *   `stopImmediatePropagation` is needed. The call is optional so that event-like objects
     *   lacking that method are tolerated.
     */
    mouseMoveCallback(evt: Event) {
        // NOTE(review): this replaces the base implementation entirely (no hover handling).
        // Confirm that this is intended.
        evt.stopImmediatePropagation?.();
    }

    /**
     * Sample the pixel under a point's `handles.end` and store the result on `data.cachedStats`.
     *
     * The handle position is rounded to the nearest pixel. If that pixel lies inside the image,
     * the stats get `x`/`y` and `storedPixels`. For colour images `storedPixels` is the RGB
     * sample. For grayscale images the stats also get the stored value `sp`, its rescaled value
     * `mo = sp * slope + intercept`, and `suv` from `calculateSUV`. A point outside the image
     * gets an empty stats object. `data.invalidated` is always reset to `false`.
     */
    updateCachedStats(image: Image, element: HTMLElement, data: any) {
        const x = Math.round(data.handles.end.x);
        const y = Math.round(data.handles.end.y);

        const stats: any = {};

        if (x >= 0 && y >= 0 && x < image.columns && y < image.rows) {
            stats.x = x;
            stats.y = y;

            if (image.color) {
                stats.storedPixels = getRGBPixels(element, x, y, 1, 1);
            } else {
                stats.storedPixels = external.cornerstone.getStoredPixels(element, x, y, 1, 1);
                stats.sp = stats.storedPixels[0];
                stats.mo = stats.sp * image.slope + image.intercept;
                stats.suv = calculateSUV(image, stats.sp);
            }
        }

        data.cachedStats = stats;
        data.invalidated = false;
    }

    /**
     * Base-class hook that creates the measurement for a new click. As a side effect it fires
     * a prediction request for the clicked point.
     *
     * @returns The new point measurement, or `undefined` when the event carries no image
     *   coordinates.
     */
    createNewMeasurement(eventData: EventData): any {
        if (!eventData?.currentPoints?.image) return;

        const strategy = this.getStrategy(eventData);

        this.getPrediction(eventData, strategy);

        return PointMeasurementFactory(eventData.currentPoints.image, strategy);
    }

    /**
     * Draw each visible point stored for this tool on the event's element. Each point is a
     * circle of `configuration.handleRadius` in the point's `color`. The circle is dashed with
     * the global `lineDash` when `configuration.renderDashed` is set.
     */
    renderToolData(evt: CustomEvent<any>) {
        const eventData = evt.detail;
        const { handleRadius, renderDashed } = this.configuration;
        // Obtained before the early return below, so the context is prepared even when
        // there is nothing to draw. This preserves the original call order.
        const context: CanvasRenderingContext2D = getNewContext(eventData.canvasContext.canvas);
        const lineDash = getModule('globalConfiguration').configuration.lineDash;

        const toolData: Array<IMeasurementData<any>> = getToolState(evt.currentTarget, this.name)?.data;

        if (!toolData) return;

        toolData.forEach(data => {
            // Skip hidden points. Also skip records without an end handle: drawCircle has no
            // centre for them and would presumably throw, which would stop the remaining
            // points from being drawn.
            if (data.visible === false || !data.handles?.end) return;

            draw(context, (ctx: CanvasRenderingContext2D) => {
                const handleOptions = {
                    handleRadius,
                    color: data.color,
                    // Dashing is opt-in; without it no lineDash option is passed at all.
                    ...(renderDashed ? { lineDash } : {}),
                };

                drawCircle(ctx, eventData.element, data.handles.end, handleRadius, handleOptions);
            });
        });
    }
}
