import cv from "@techstark/opencv-js";
import * as ort from "onnxruntime-web";

/**
 * Result of running {@link algorithm} on one image.
 *
 * Every box is in pixel coordinates of the analysed image: `x`/`y` are the
 * top-left corner of a detected contour's bounding box and `width`/`height`
 * its extent, computed as `max - min` of the contour's coordinates.
 */
export type Analysis = {
  /** Size of the input element, taken from its `width`/`height` properties. */
  imageData: {
    width: number;
    height: number;
  };
  /** Candidate regions, grouped by the class the ONNX classifier assigned. */
  instanceData: {
    /** Regions classified as background; they do not contribute to `count`. */
    backgrounds: { x: number; y: number; width: number; height: number }[];
    /** Regions classified as egg clusters. */
    clusters: {
      x: number;
      y: number;
      width: number;
      height: number;
      /** Estimated number of eggs in the cluster, derived from its thresholded pixel count. */
      estimatedEggCount: number;
    }[];
    /** Regions classified as a single egg; each adds one to `count`. */
    individualEggs: { x: number; y: number; width: number; height: number }[];
  };
  /** Total eggs: one per individual egg plus every cluster's `estimatedEggCount`. */
  count: number;
};

/** Axis-aligned box as `[x, y, width, height]` in image pixels. */
type Box = [number, number, number, number];

/** Per-channel mean used to normalise model input, in R, G, B order (ImageNet statistics). */
const CHANNEL_MEAN = [0.485, 0.456, 0.406];
/** Per-channel standard deviation used to normalise model input, in R, G, B order. */
const CHANNEL_STD = [0.229, 0.224, 0.225];

/**
 * Computes the bounding box of a contour stored as interleaved `x, y`
 * coordinates (the layout of an OpenCV.js contour's `data32S`).
 *
 * Width and height are `max - min` (not `max - min + 1`), which is how this
 * module reports boxes. The loop replaces `Math.min(...xs)` so that very long
 * contours cannot exceed the engine's argument-count limit.
 *
 * @returns the box, or `null` when the contour has fewer than two points.
 */
function contourBoundingBox(xy: ArrayLike<number>): Box | null {
  if (xy.length < 4) {
    return null;
  }
  let minX = Infinity;
  let minY = Infinity;
  let maxX = -Infinity;
  let maxY = -Infinity;
  for (let i = 0; i + 1 < xy.length; i += 2) {
    const px = xy[i];
    const py = xy[i + 1];
    if (px < minX) minX = px;
    if (px > maxX) maxX = px;
    if (py < minY) minY = py;
    if (py > maxY) maxY = py;
  }
  return [minX, minY, maxX - minX, maxY - minY];
}

/**
 * Size filter for candidate boxes. A box is rejected when it is too small
 * (BOTH sides below `minSide`) or too large (EITHER side above `maxSide`).
 */
function isCandidateSize(box: Box, minSide: number, maxSide: number): boolean {
  const [, , w, h] = box;
  const tooSmall = w < minSide && h < minSide;
  const tooLarge = w > maxSide || h > maxSide;
  return !tooSmall && !tooLarge;
}

/**
 * Grows `box` by `border` pixels on every side, then clamps the result to a
 * `cols` x `rows` image. Both ends of each axis are clamped independently, so
 * the margin stays symmetric except where the image edge cuts it off.
 */
function expandedCropRect(
  box: Box,
  border: number,
  cols: number,
  rows: number
): Box {
  const [x, y, w, h] = box;
  const left = Math.max(x - border, 0);
  const top = Math.max(y - border, 0);
  const right = Math.min(x + w + border, cols);
  const bottom = Math.min(y + h + border, rows);
  return [left, top, right - left, bottom - top];
}

/**
 * Converts interleaved 8-bit pixels (row-major, `channels` values per pixel,
 * the first three being R, G, B) into a normalised planar buffer. The output
 * holds all R values, then all G values, then all B values, each plane in
 * row-major order. Each value is `(px / 255 - mean) / std`. Channels after the
 * third (e.g. alpha) are ignored.
 */
function toNormalizedChw(
  data: ArrayLike<number>,
  rows: number,
  cols: number,
  channels: number
): Float32Array {
  const plane = rows * cols;
  const out = new Float32Array(3 * plane);
  for (let p = 0; p < plane; p++) {
    const base = p * channels;
    for (let c = 0; c < 3; c++) {
      out[c * plane + p] = (data[base + c] / 255 - CHANNEL_MEAN[c]) / CHANNEL_STD[c];
    }
  }
  return out;
}

/**
 * Index of the largest value; the first index wins on ties. The search starts
 * from -Infinity, so all-negative scores (e.g. raw logits) are handled.
 * Returns 0 for empty input or when nothing exceeds -Infinity (e.g. all NaN).
 */
function argmax(values: ArrayLike<number>): number {
  let bestIndex = 0;
  let bestValue = -Infinity;
  for (let i = 0; i < values.length; i++) {
    if (values[i] > bestValue) {
      bestValue = values[i];
      bestIndex = i;
    }
  }
  return bestIndex;
}

/** Number of entries exactly equal to 0. */
function countZeros(values: ArrayLike<number>): number {
  let zeros = 0;
  for (let i = 0; i < values.length; i++) {
    if (values[i] === 0) zeros++;
  }
  return zeros;
}

/**
 * Detects eggs in an image and counts them.
 *
 * Pipeline:
 * 1. Convert to grayscale, threshold, and take the bounding box of every
 *    contour. Boxes outside the CROP_SIZE limits are discarded.
 * 2. Crop each box with an EXPAND_BORDER margin, pad it with black to a
 *    MODEL_INPUT_SIZE square, and classify it with the ONNX model:
 *    0 = background, 1 = cluster, 2 = individual egg. Other indices are ignored.
 * 3. An individual egg counts as one. A cluster counts as its number of dark
 *    pixels (gray level <= BINARY_THRESHOLD) divided by EGG_SIZE, rounded.
 *
 * @param imgElement image or canvas to analyse; it is read with `cv.imread`.
 * @returns boxes grouped by class, plus the total egg count.
 * @throws rejects with whatever onnxruntime or OpenCV throw (e.g. model download failure).
 */
async function algorithm(
  imgElement: HTMLImageElement | HTMLCanvasElement
): Promise<Analysis> {
  // Contour boxes are dropped when both sides are < CROP_SIZE[0] or either side is > CROP_SIZE[1].
  const CROP_SIZE = [5, 40];
  // Margin (px) added around each box before it is classified.
  const EXPAND_BORDER = 5;
  // Pixels per egg used to estimate how many eggs a cluster contains.
  const EGG_SIZE = 46;
  // Gray levels above this become 255, the rest 0 (cv.THRESH_BINARY).
  const BINARY_THRESHOLD = 50;
  // Side of the square patch fed to the model. Crops are padded up to it, which
  // requires CROP_SIZE[1] + 2 * EXPAND_BORDER <= MODEL_INPUT_SIZE.
  const MODEL_INPUT_SIZE = 50;

  // Alternative model location: window.location.origin + "/models/onnx_model.onnx"
  // NOTE(review): a fresh session (download + backend init) is created on every
  // call; cache it if this runs repeatedly, but mind concurrent `run` calls.
  const sess = await ort.InferenceSession.create(
    "https://assets.apf.cloud/models/onnx_model.onnx",
    { executionProviders: ["webgl", "wasm"] }
  );

  /** Finds candidate boxes in `src` (an RGBA Mat). `src` is not modified. */
  function predictBoxes(src: cv.Mat): Box[] {
    const binary = new cv.Mat();
    const contours = new cv.MatVector();
    const hierarchy = new cv.Mat();
    try {
      cv.cvtColor(src, binary, cv.COLOR_RGBA2GRAY);
      cv.threshold(binary, binary, BINARY_THRESHOLD, 255, cv.THRESH_BINARY);
      cv.findContours(
        binary,
        contours,
        hierarchy,
        cv.RETR_LIST,
        cv.CHAIN_APPROX_NONE
      );
      const boxes: Box[] = [];
      // Cast kept from the original code: the opencv-js typings apparently do not type size() as number.
      const contourCount = contours.size() as unknown as number;
      for (let i = 0; i < contourCount; i++) {
        const contour = contours.get(i);
        const box = contourBoundingBox(contour.data32S);
        contour.delete(); // each get() hands back a Mat that must be freed
        if (box !== null && isCandidateSize(box, CROP_SIZE[0], CROP_SIZE[1])) {
          boxes.push(box);
        }
      }
      return boxes;
    } finally {
      binary.delete();
      contours.delete();
      hierarchy.delete();
    }
  }

  /**
   * Crops `box` plus margin out of `img` and pads it with black to a
   * MODEL_INPUT_SIZE square. The caller owns (and must delete) the returned Mat.
   */
  function preprocess(box: Box, img: cv.Mat): cv.Mat {
    const [left, top, cropWidth, cropHeight] = expandedCropRect(
      box,
      EXPAND_BORDER,
      img.cols,
      img.rows
    );
    // Clone the ROI first. When given a ROI, copyMakeBorder may take border
    // pixels from the surrounding parent image instead of the constant colour.
    const roi = img.roi(new cv.Rect(left, top, cropWidth, cropHeight));
    const crop = roi.clone();
    roi.delete();
    const padRows = MODEL_INPUT_SIZE - crop.rows;
    const padCols = MODEL_INPUT_SIZE - crop.cols;
    const padded = new cv.Mat();
    try {
      cv.copyMakeBorder(
        crop,
        padded,
        Math.ceil(padRows / 2),
        Math.floor(padRows / 2),
        Math.ceil(padCols / 2),
        Math.floor(padCols / 2),
        cv.BORDER_CONSTANT,
        new cv.Scalar(0, 0, 0, 255)
      );
    } catch (err) {
      padded.delete();
      throw err;
    } finally {
      crop.delete();
    }
    return padded;
  }

  /** Runs the classifier on one normalised CHW patch and returns the winning class index. */
  async function classify(
    input: Float32Array,
    rows: number,
    cols: number
  ): Promise<number> {
    const tensor = new ort.Tensor(input, [1, 3, rows, cols]);
    // "input.1" and "30" are the model graph's input and output names.
    const outputMap = await sess.run({ "input.1": tensor });
    const scores = outputMap["30"].data as Float32Array;
    return argmax(scores);
  }

  /** Counts pixels inside `box` whose gray level is <= BINARY_THRESHOLD. */
  function countDarkPixels(img: cv.Mat, box: Box): number {
    const [x, y, w, h] = box;
    if (w <= 0 || h <= 0) {
      // A degenerate box (e.g. from a 1-px line) has no pixels, and cvtColor asserts on empty input.
      return 0;
    }
    const roi = img.roi(new cv.Rect(x, y, w, h));
    const gray = new cv.Mat();
    try {
      cv.cvtColor(roi, gray, cv.COLOR_RGBA2GRAY);
      cv.threshold(gray, gray, BINARY_THRESHOLD, 255, cv.THRESH_BINARY);
      return countZeros(gray.data);
    } finally {
      gray.delete();
      roi.delete();
    }
  }

  const analysis: Analysis = {
    imageData: {
      width: imgElement.width,
      height: imgElement.height,
    },
    instanceData: {
      backgrounds: [],
      clusters: [],
      individualEggs: [],
    },
    count: 0,
  };

  const original = cv.imread(imgElement);
  try {
    for (const box of predictBoxes(original)) {
      const [x, y, width, height] = box;
      const patch = preprocess(box, original);
      const rows = patch.rows;
      const cols = patch.cols;
      const input = toNormalizedChw(patch.data, rows, cols, patch.channels());
      patch.delete();

      // Classes are scored one at a time: runs on a single session are kept sequential.
      switch (await classify(input, rows, cols)) {
        case 0:
          analysis.instanceData.backgrounds.push({ x, y, width, height });
          break;
        case 1: {
          const estimatedEggCount = Math.round(
            countDarkPixels(original, box) / EGG_SIZE
          );
          analysis.instanceData.clusters.push({
            x,
            y,
            width,
            height,
            estimatedEggCount,
          });
          analysis.count += estimatedEggCount;
          break;
        }
        case 2:
          analysis.instanceData.individualEggs.push({ x, y, width, height });
          analysis.count++;
          break;
      }
    }
  } finally {
    original.delete();
  }
  console.log("algorithm concluded");
  return analysis;
}

export default algorithm;
