import { agnes } from 'ml-hclust';
import { getIntersection, getIntersectionSize, getUnionSize } from './setUtils';
import { GeneSetEntry } from './types';

/**
 * Orders gene sets so that sets sharing panel genes end up next to each other.
 *
 * Gene sets that share no gene with `panelGenes` come first, in their input
 * order. The remaining sets follow in the order reported by `indices()` of the
 * `agnes` hierarchical clustering, built from a Jaccard-style pairwise
 * distance.
 *
 * The elapsed time is written to the console on every call.
 *
 * @param geneSets - Entries whose element `[0]` is the set name and element
 *   `[1]` is the collection of genes in that set.
 * @param panelGenes - Genes of the panel used to decide which sets take part
 *   in the clustering.
 * @returns A new array holding every input entry; the input is not mutated.
 */
export function sortGeneSetsByJaccardSimilarity(
  geneSets: GeneSetEntry[],
  panelGenes: Set<string>
): GeneSetEntry[] {
  const start = performance.now();

  // Split the input: only sets that overlap the panel are worth clustering.
  const geneSetsWithIntersection: GeneSetEntry[] = [];
  const geneSetsWithoutIntersection: GeneSetEntry[] = [];
  for (const geneSet of geneSets) {
    if (getIntersectionSize(geneSet[1], panelGenes)) {
      geneSetsWithIntersection.push(geneSet);
    } else {
      geneSetsWithoutIntersection.push(geneSet);
    }
  }

  // The panel-restricted genes of an entry do not depend on the entry it is
  // compared with, so compute them once per entry instead of once per pair.
  const panelSubsets = new Map<GeneSetEntry, Set<string>>();
  const getPanelSubset = (entry: GeneSetEntry): Set<string> => {
    let subset = panelSubsets.get(entry);
    if (subset === undefined) {
      subset = new Set(getIntersection(entry[1], panelGenes));
      panelSubsets.set(entry, subset);
    }
    return subset;
  };

  // Distance in [0, 1]: 0 for entries with the same name, 1 when the entries
  // share no panel gene.
  // NOTE(review): the numerator only counts shared genes that are in the
  // panel, while the denominator is the union of the full sets — confirm this
  // asymmetry between numerator and denominator is intended.
  const geneSetsJaccardDistance = (
    entryA: GeneSetEntry,
    entryB: GeneSetEntry
  ) => {
    if (entryA[0] === entryB[0]) return 0;
    const panelSubsetA = getPanelSubset(entryA);
    if (!panelSubsetA.size) return 1;
    const intersectionSize = getIntersectionSize(panelSubsetA, entryB[1]);
    if (!intersectionSize) return 1;
    const unionSize = getUnionSize(entryA[1], entryB[1]);
    return 1 - intersectionSize / unionSize;
  };

  let sorted: GeneSetEntry[];
  if (geneSetsWithIntersection.length < 2) {
    // Nothing to order with zero or one candidate. Skipping the clustering
    // also avoids handing an empty data set to `agnes`, which is not
    // guaranteed to accept one.
    sorted = geneSetsWithIntersection;
  } else {
    const cluster = agnes(geneSetsWithIntersection, {
      distanceFunction: geneSetsJaccardDistance
    });
    const indices = cluster?.indices() ?? [];
    sorted = indices.map((i) => geneSetsWithIntersection[i]);
  }

  const result = geneSetsWithoutIntersection.concat(sorted);

  const time = performance.now() - start;
  console.log('sortGeneSetsByJaccardSimilarity', Math.round(time), 'ms');

  return result;
}
