import { Image } from 'cornerstone-core';
import { type InferenceSession, type TensorConstructor } from 'onnxruntime-web';
import { IPoint } from '../DicomViewerHelper/interface';
import { EmbeddingAPI } from './EmbeddingAPI';
import { PredictionState, StateFactory } from './entities';

/** Length, in pixels, that the longest image side is scaled to before prompts are sent to the model. */
const MODEL_INPUT_SIZE = 1024;

/**
 * Runs the ONNX segmentation model located at {@link PredictionAPI.modelUrl}
 * using click and/or box prompts held in {@link PredictionAPI.state}, and
 * stores the model output back into `state.result`.
 *
 * The class lazily loads `onnxruntime-web` and creates the inference session
 * on first use (see {@link PredictionAPI.initModel}).
 */
export abstract class PredictionAPI extends EmbeddingAPI {
    /** ONNX Runtime tensor constructor; assigned by {@link initModel}. */
    Tensor: TensorConstructor;
    /** Inference session for the model at {@link modelUrl}; assigned by {@link initModel}. */
    session: InferenceSession;

    /**
     * Model load currently in flight, if any. Shared between overlapping
     * {@link initModel} calls so that only one session is ever created.
     */
    private pendingModelInit?: Promise<void>;

    /** Current prompts (clicks, rect) and the latest model output. */
    public state: PredictionState = StateFactory();
    // NOTE(review): historyState / nextState are not touched in this class;
    // presumably undo/redo stacks maintained by subclasses — confirm.
    public historyState: Array<PredictionState> = [];
    public nextState: Array<PredictionState> = [];

    /**
     * Data of the `masks` output from the last model run.
     * Evaluates to `undefined` at runtime when no result has been stored yet,
     * even though the declared type is `Float32Array`.
     */
    get predictionResult() {
        return this.state.result?.masks?.data as Float32Array;
    }

    /**
     * Data of the `low_res_masks` output from the last model run, fed back as
     * `mask_input` on click-only predictions. `undefined` at runtime when no
     * result has been stored yet.
     */
    get maskInput() {
        return this.state.result?.low_res_masks?.data as Float32Array;
    }

    /** Click prompts stored in the current state. */
    get clicks() {
        return this.state.clicks;
    }

    /** Box prompt stored in the current state (passed to `transformPoints` as a point list). */
    get rect() {
        return this.state.rect;
    }
    set rect(value) {
        this.state.rect = value;
    }

    /** URL the ONNX model is fetched from. */
    modelUrl = '/onnx/sam.onnx';

    // Not assigned anywhere in this class; presumably provided by a subclass
    // or by EmbeddingAPI before any prediction is requested — confirm.
    protected image: Image;

    /** Height of the current image, as reported by `image.height`. */
    get height() {
        return this.image.height;
    }

    /** Width of the current image, as reported by `image.width`. */
    get width() {
        return this.image.width;
    }

    /** Factor that scales the longest image side to {@link MODEL_INPUT_SIZE}. */
    get modelScale() {
        return MODEL_INPUT_SIZE / Math.max(this.height, this.width);
    }

    /**
     * Runs the model for the prompts currently held in the state.
     *
     * @param embedding Base64-encoded image embedding (see {@link getCommonInput}).
     * @returns The `masks` output data, or `undefined` when there are no
     * prompts to run or when inference fails (the error is logged to the console).
     * Errors raised while loading the model are not caught here.
     */
    async segment(embedding: string) {
        await this.initModel();

        const input = this.getModelInput(embedding);
        if (!input) return undefined;

        try {
            // The session returns a generic name -> value map; it is stored
            // untyped, exactly as the state expects it.
            this.state.result = (await this.session.run(input)) as any;

            return this.predictionResult;
        } catch (error) {
            console.error(error);
            return undefined;
        }
    }

    /**
     * Loads `onnxruntime-web`, points it at {@link wasmPaths} and creates the
     * inference session. Does nothing when a session already exists.
     *
     * Overlapping calls share a single load: without that, two calls made
     * before the first one finished would each fetch the model and create
     * their own session.
     */
    async initModel() {
        if (this.session) return;

        this.pendingModelInit ??= (async () => {
            const { Tensor, InferenceSession, env } = await import('onnxruntime-web');
            this.Tensor = Tensor;
            env.wasm.wasmPaths = this.wasmPaths;
            this.session = await InferenceSession.create(this.modelUrl);
        })().finally(() => {
            // Drop the handle once settled so that a failed load can be
            // retried by a later call instead of replaying the same rejection.
            this.pendingModelInit = undefined;
        });

        await this.pendingModelInit;
    }

    /**
     * Builds the model feeds for the prompts in the current state.
     *
     * Point layout:
     * - with a box: the two box corners (labels 2 = upper-left,
     *   3 = bottom-right) followed by the clicks;
     * - without a box: the clicks followed by one padding point at (0, 0)
     *   with label -1.
     *
     * The previous low-resolution mask is fed back only for click-only
     * prompts, and only when one is available.
     *
     * @param embedding Base64-encoded image embedding.
     * @returns The feeds object, or `undefined` when the state has no click list.
     */
    getModelInput(embedding: string) {
        const clicks = this.clicks;
        // Without a click list there is nothing to build a prompt from.
        if (!clicks) return;

        const clickCoords = this.transformPoints(clicks, this.height, this.width);
        const rect = this.rect;
        const hasBox = Boolean(rect);

        // Box corners go first; the padding point is only needed without a box.
        const boxPointCount = hasBox ? 2 : 0;
        const paddingPointCount = hasBox ? 0 : 1;
        const pointCount = boxPointCount + clicks.length + paddingPointCount;

        const pointCoords = new Float32Array(2 * pointCount);
        const pointLabels = new Float32Array(pointCount);

        if (rect) {
            const boxCoords = this.transformPoints(rect, this.height, this.width);
            // A box contributes exactly two corners, i.e. four coordinates.
            pointCoords.set(boxCoords.slice(0, 4));
            pointLabels[0] = 2.0; // UPPER_LEFT
            pointLabels[1] = 3.0; // BOTTOM_RIGHT
        }

        // Clicks, already scaled to the model frame, follow the box corners.
        clicks.forEach((click, i) => {
            const slot = i + boxPointCount;
            pointCoords[2 * slot] = clickCoords[2 * i];
            pointCoords[2 * slot + 1] = clickCoords[2 * i + 1];
            pointLabels[slot] = click.clickType;
        });

        if (!hasBox) {
            // Padding point (0, 0) with label -1 closes a click-only prompt.
            const slot = pointCount - 1;
            pointCoords[2 * slot] = 0.0;
            pointCoords[2 * slot + 1] = 0.0;
            pointLabels[slot] = -1.0;
        }

        const commonInputs = this.getCommonInput(embedding);

        // Refine the previous prediction only for click-only prompts.
        if (!hasBox && this.maskInput) {
            commonInputs.mask_input = new this.Tensor('float32', this.maskInput, [1, 1, 256, 256]);
            commonInputs.has_mask_input = new this.Tensor('float32', new Float32Array([1]));
        }

        return {
            ...commonInputs,
            point_coords: new this.Tensor('float32', pointCoords, [1, pointCount, 2]),
            point_labels: new this.Tensor('float32', pointLabels, [1, pointCount]),
        };
    }

    /**
     * Builds the feeds that do not depend on the prompt: the image embedding,
     * the original image size, and an empty mask with `has_mask_input` = 0.
     *
     * @param embedding Base64 string whose decoded bytes are reinterpreted as
     * float32 values for a `[1, 256, 64, 64]` tensor.
     * NOTE(review): the reinterpretation uses the platform's byte order —
     * assumes the embedding was serialized accordingly; confirm with the producer.
     */
    getCommonInput(embedding: string) {
        const embeddingBytes = Uint8Array.from(atob(embedding), c => c.charCodeAt(0));

        return {
            image_embeddings: new this.Tensor('float32', new Float32Array(embeddingBytes.buffer), [1, 256, 64, 64]),
            // original image size
            orig_im_size: new this.Tensor('float32', new Float32Array([this.height, this.width])),
            // empty mask
            mask_input: new this.Tensor('float32', new Float32Array(256 * 256), [1, 1, 256, 256]),
            has_mask_input: new this.Tensor('float32', new Float32Array([0])),
        };
    }

    /**
     * Builds the feeds for a box-only prompt.
     *
     * @param embedding Base64-encoded image embedding.
     * @param points The two box corners in image coordinates, labelled
     * 2 and 3 in that order.
     */
    getRectInput(embedding: string, points: Array<IPoint>) {
        return {
            ...this.getCommonInput(embedding),
            point_labels: new this.Tensor('float32', new Float32Array([2, 3]), [1, 2]),
            point_coords: new this.Tensor(
                'float32',
                new Float32Array(this.transformPoints(points, this.height, this.width)),
                [1, 2, 2]
            ),
        };
    }

    /**
     * Scales points from image coordinates into the model frame.
     *
     * @param points Points in image coordinates.
     * @param oldH Image height.
     * @param oldW Image width.
     * @returns Flat list `[x0, y0, x1, y1, ...]` of scaled coordinates.
     */
    transformPoints(points: Array<IPoint>, oldH: number, oldW: number) {
        const [newH, newW] = this.getPreProcessShape(oldH, oldW);
        const scaleX = newW / oldW;
        const scaleY = newH / oldH;

        return points.flatMap(p => [p.x * scaleX, p.y * scaleY]);
    }

    /**
     * Computes the image size after scaling its longest side to
     * {@link MODEL_INPUT_SIZE}, rounding each side to the nearest integer.
     *
     * @param oldH Image height.
     * @param oldW Image width.
     * @returns `[height, width]` of the scaled image.
     */
    getPreProcessShape(oldH: number, oldW: number) {
        const scale = MODEL_INPUT_SIZE / Math.max(oldH, oldW);
        // Adding 0.5 and truncating rounds half up for the non-negative sizes
        // handled here.
        const newh = Math.trunc(oldH * scale + 0.5);
        const neww = Math.trunc(oldW * scale + 0.5);
        return [newh, neww];
    }

    /** URLs of the ONNX Runtime WebAssembly binaries, handed to `env.wasm.wasmPaths`. */
    wasmPaths = {
        'ort-wasm.wasm': '/onnx/ort-wasm.wasm',
        'ort-wasm-simd.wasm': '/onnx/ort-wasm-simd.wasm',
        'ort-wasm-threaded.wasm': '/onnx/ort-wasm-threaded.wasm',
        'ort-wasm-simd-threaded.wasm': '/onnx/ort-wasm-simd-threaded.wasm',
    } as const;
}
