import { agnes } from 'ml-hclust';
import { TranscriptModel } from 'store/createPanel/models';
import { getIntersectionSize, getUnionSize } from './setUtils';
import { GeneSetEntry } from './types';
import { GeneSortMethod } from 'generated/types';

/** Ordering callback handed to `Array.prototype.sort` when sorting transcript models. */
type Comparator = (a: TranscriptModel, b: TranscriptModel) => number;

/**
 * Returns the genes ordered according to `mode`.
 *
 * Jaccard-similarity ordering is delegated to `sortGenesByJaccardSimilarity`;
 * every other mode sorts a copy of `genes` with a pairwise comparator, so the
 * input array is left untouched. Modes without a dedicated comparator fall
 * back to descending abundance.
 *
 * @param genes genes to order.
 * @param geneSets gene-set entries; only consulted for the similarity mode.
 * @param mode requested ordering.
 * @returns the ordered genes.
 */
export function sortGenes(
  genes: TranscriptModel[],
  geneSets: GeneSetEntry[],
  mode: GeneSortMethod
) {
  // Similarity ordering comes from clustering, not from a pairwise comparator.
  if (mode === GeneSortMethod.jaccardSimilarity) {
    return sortGenesByJaccardSimilarity(genes, geneSets);
  }

  const byName: Comparator = (left, right) =>
    left.geneName.localeCompare(right.geneName);
  const byAbundanceAscending: Comparator = (left, right) =>
    left.abundance - right.abundance;
  const byAbundanceDescending: Comparator = (left, right) =>
    right.abundance - left.abundance;

  // Descending abundance doubles as the fallback for any unlisted mode.
  let compare = byAbundanceDescending;
  if (mode === GeneSortMethod.alphabetic) {
    compare = byName;
  } else if (mode === GeneSortMethod.geneAbundanceAscending) {
    compare = byAbundanceAscending;
  }

  const ordered = genes.slice();
  ordered.sort(compare);
  return ordered;
}

/**
 * Orders genes so that genes belonging to similar collections of gene sets
 * are placed near each other.
 *
 * The genes are clustered with `agnes`, using as distance the Jaccard distance
 * between the sets-of-gene-sets that contain each gene
 * (`1 - |A ∩ B| / |A ∪ B|`). The result follows the index order reported by
 * the resulting cluster's `indices()`.
 *
 * Gene names are upper-cased before being looked up in `geneSets`.
 * NOTE(review): this presumes gene-set members are stored upper-case — confirm.
 *
 * @param panelGenes genes to order; the array is not mutated.
 * @param geneSets `[name, members]` entries used to measure similarity.
 * @returns a new array holding the genes in clustered order.
 */
function sortGenesByJaccardSimilarity(
  panelGenes: TranscriptModel[],
  geneSets: GeneSetEntry[]
) {
  // Fewer than two genes leaves nothing to cluster. Returning a copy here also
  // keeps empty input away from `agnes`, which may not produce a cluster for it
  // (presumably depends on the ml-hclust version — hence the `?.` further down).
  if (panelGenes.length < 2) {
    return panelGenes.slice();
  }

  const start = performance.now();

  // Memoises, per upper-cased gene name, the names of the gene sets containing
  // it, so each gene scans `geneSets` once rather than once per comparison.
  const setsByGene = new Map<string, Set<string>>();
  const setsContaining = (geneName: string): Set<string> => {
    let sets = setsByGene.get(geneName);
    if (sets === undefined) {
      sets = getSetsWithGene(geneSets, geneName);
      setsByGene.set(geneName, sets);
    }
    return sets;
  };

  const distanceFunction = (geneA: TranscriptModel, geneB: TranscriptModel) => {
    const geneNameA = (geneA?.geneName ?? '').toUpperCase();
    const geneNameB = (geneB?.geneName ?? '').toUpperCase();
    // Same name (including two missing names) means zero distance.
    if (geneNameA === geneNameB) return 0;

    const setsWithGeneA = setsContaining(geneNameA);
    const setsWithGeneB = setsContaining(geneNameB);

    // Neither gene is in any gene set: treat them as maximally distant and
    // avoid dividing by zero.
    const unionSize = getUnionSize(setsWithGeneA, setsWithGeneB);
    if (!unionSize) return 1;

    const intersectionSize = getIntersectionSize(setsWithGeneA, setsWithGeneB);
    return 1 - intersectionSize / unionSize;
  };

  const cluster = agnes(panelGenes, {
    distanceFunction
  });
  const indices = cluster?.indices() ?? [];
  const sorted = indices.map((i) => panelGenes[i]);

  const time = performance.now() - start;
  console.log('sortGenesByJaccardSimilarity', Math.round(time), 'ms');
  return sorted;
}

/**
 * Collects the names of all gene sets whose members include `geneName`.
 *
 * @param geneSets `[name, members]` entries to search.
 * @param geneName gene to look for; an empty name matches nothing.
 * @returns the matching gene-set names, in `geneSets` order.
 */
function getSetsWithGene(geneSets: GeneSetEntry[], geneName: string) {
  const matchingSetNames = new Set<GeneSetEntry[0]>();

  // An empty gene name is never looked up in the sets.
  if (!geneName) {
    return matchingSetNames;
  }

  geneSets.forEach(([setName, members]) => {
    if (members.has(geneName)) {
      matchingSetNames.add(setName);
    }
  });
  return matchingSetNames;
}
