import { sortGeneSetsByJaccardSimilarity } from './sortGeneSetsByJaccardSimilarity';
import { getIntersectionSize } from './setUtils';
import { GeneSetEntry } from './types';
import { GeneSetSortMethod } from 'generated/types';

/**
 * Ordering function over two gene set entries, used as the compare callback
 * for `Array.prototype.sort`: a negative result places `a` before `b`, a
 * positive result places `b` before `a`.
 */
type Comparator = (a: GeneSetEntry, b: GeneSetEntry) => number;

/**
 * Returns the gene sets ordered according to `mode`.
 *
 * Jaccard-similarity ordering is delegated entirely to
 * `sortGeneSetsByJaccardSimilarity`. Every other mode sorts a shallow copy of
 * `geneSets` with a comparator, so the input array is left untouched on that
 * path. Unrecognised modes fall back to proportion-of-representation,
 * descending.
 *
 * @param geneSets - Entries to order.
 * @param panelGenes - Genes each gene set is measured against.
 * @param mode - Which metric and direction to sort by.
 * @returns The ordered gene sets.
 */
export function sortGeneSets(
  geneSets: GeneSetEntry[],
  panelGenes: Set<string>,
  mode: GeneSetSortMethod
) {
  if (mode === GeneSetSortMethod.jaccardSimilarity) {
    return sortGeneSetsByJaccardSimilarity(geneSets, panelGenes);
  }

  // The remaining modes vary along two independent axes: which metric is
  // compared, and whether the (natively descending) comparator is flipped.
  const sortsByMatchCount =
    mode === GeneSetSortMethod.numMatchingAscending ||
    mode === GeneSetSortMethod.numMatchingDescending;
  const isAscending =
    mode === GeneSetSortMethod.numMatchingAscending ||
    mode === GeneSetSortMethod.proportionOfRepresentationAscending;

  const descending: Comparator = sortsByMatchCount
    ? createSelectedGenesComparator(panelGenes)
    : createRepresentationComparator(panelGenes);
  const comparator: Comparator = isAscending
    ? reversedComparator(descending)
    : descending;

  // Sort a copy so the caller's array keeps its original order.
  const sorted = geneSets.slice();
  sorted.sort(comparator);
  return sorted;
}

/**
 * Wraps a comparator so that it yields the opposite ordering by negating
 * every result of the wrapped comparator.
 *
 * @param comparator - The comparator whose ordering should be inverted.
 * @returns A comparator returning `-comparator(a, b)`.
 */
const reversedComparator = (comparator: Comparator): Comparator => {
  const inverted: Comparator = (left, right) => -comparator(left, right);
  return inverted;
};

/**
 * Builds a comparator that orders gene sets by the fraction of their genes
 * that also appear in `panelGenes`, highest fraction first. Ties are broken
 * by gene set name in ascending order; entries with identical fraction and
 * name compare as equal.
 *
 * Intersection sizes are memoized per gene set name for the lifetime of the
 * returned comparator, so each name is intersected with the panel at most
 * once no matter how often the sort compares it.
 *
 * @param panelGenes - Genes each gene set is measured against.
 * @returns A compare function suitable for `Array.prototype.sort`.
 */
function createRepresentationComparator(panelGenes: Set<string>) {
  // A Map rather than a plain object: names such as "constructor" or
  // "toString" would otherwise hit inherited Object.prototype members and
  // never be computed.
  const intersections = new Map<string, number>();

  const representationOf = ([name, genes]: GeneSetEntry): number => {
    let matching = intersections.get(name);
    if (matching === undefined) {
      matching = getIntersectionSize(genes, panelGenes);
      intersections.set(name, matching);
    }
    // An empty gene set has nothing represented; without this guard the
    // division is 0 / 0 = NaN, which makes every comparison result NaN.
    return genes.size === 0 ? 0 : matching / genes.size;
  };

  const comparator = (entryA: GeneSetEntry, entryB: GeneSetEntry): number => {
    const fractionA = representationOf(entryA);
    const fractionB = representationOf(entryB);
    if (fractionA !== fractionB) {
      // Sort by fraction DESC
      return fractionB - fractionA;
    }

    // Sort by alphabet ASC; equal names must yield 0 so the comparator stays
    // consistent regardless of argument order.
    const nameA = entryA[0];
    const nameB = entryB[0];
    if (nameA === nameB) {
      return 0;
    }
    return nameA > nameB ? 1 : -1;
  };
  return comparator;
}

/**
 * Builds a comparator that orders gene sets by how many of their genes also
 * appear in `panelGenes`, largest overlap first. Ties are broken by gene set
 * size ascending (with equal overlap, the smaller set has the higher
 * represented fraction), then by name ascending; entries equal on all three
 * compare as equal.
 *
 * Intersection sizes are memoized per gene set name for the lifetime of the
 * returned comparator, so each name is intersected with the panel at most
 * once no matter how often the sort compares it.
 *
 * @param panelGenes - Genes each gene set is measured against.
 * @returns A compare function suitable for `Array.prototype.sort`.
 */
function createSelectedGenesComparator(panelGenes: Set<string>) {
  // A Map rather than a plain object: names such as "constructor" or
  // "toString" would otherwise hit inherited Object.prototype members and
  // never be computed.
  const intersections = new Map<string, number>();

  const matchCountOf = ([name, genes]: GeneSetEntry): number => {
    let matching = intersections.get(name);
    if (matching === undefined) {
      matching = getIntersectionSize(genes, panelGenes);
      intersections.set(name, matching);
    }
    return matching;
  };

  const comparator = (entryA: GeneSetEntry, entryB: GeneSetEntry): number => {
    const matchingA = matchCountOf(entryA);
    const matchingB = matchCountOf(entryB);
    if (matchingA !== matchingB) {
      // Sort by intersection size DESC
      return matchingB - matchingA;
    }

    const [nameA, genesA] = entryA;
    const [nameB, genesB] = entryB;
    if (genesA.size !== genesB.size) {
      // Sort by set size ASC, i.e. by represented fraction DESC given the
      // equal intersection sizes established above.
      return genesA.size - genesB.size;
    }

    // Sort by alphabet ASC; equal names must yield 0 so the comparator stays
    // consistent regardless of argument order.
    if (nameA === nameB) {
      return 0;
    }
    return nameA > nameB ? 1 : -1;
  };
  return comparator;
}
