// cmds/ScoringCmds.js
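/**
 * @fileoverview
 * User command for scoring compounds against one or more proteins. Each compound is docked
 * (DiffDock first, then AutoDock at the distinct sites DiffDock finds), the poses are
 * energy-minimized, and the best-scoring pose per protein is kept.
 *
 * @typedef {import('BMapsModel').Compound} Compound
 * @typedef {import('BMapsModel').MapCase} MapCase
 *
 * @typedef ScoringCmds
 * @type {object}
 * @property {UserCmd} ScoreCompounds - user command that runs scoreCompounds
 */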
import { App } from 'BMapsSrc/BMapsApp';
import { UserActions, UserCmd } from 'BMapsCmds/UserCmd';
import { DockingReference } from 'BMapsSrc/data_tools';
import { pointsCentroid, pointDistanceSquared } from 'BMapsUtil/atom_distance_utils';
import { ensureArray } from 'BMapsUtil/js_utils';
import { MoleculeLoadOptions, MolDataSource } from 'BMapsSrc/utils';
import { gatherEnergyErrors } from 'BMapsCmds/cmds_common';

/**
 * Commands exported by this module, keyed by command name.
 * @type {ScoringCmds}
 */
export const ScoringCmds = {
    ScoreCompounds: new UserCmd(
        'ScoreCompounds', // command name
        scoreCompounds, // implementation (defined in this file)
    ),
};

/* IMPLEMENTATION */

/**
 * Score every compound against every protein and keep the best pose per (compound, protein).
 *
 * Steps:
 *  1. Make sure each compound exists on each protein.
 *  2. For each (protein, compound) pair, dock + minimize candidate poses and pick the top pose.
 *     Pairs that end up with no top pose are reported in `errors` and left out of `results`.
 *  3. Non-top poses are collected for removal; they are only removed when
 *     `scoringParams.cleanup` is truthy. Kept poses have their energies fetched with
 *     UserActions.GetEnergies, because solvation is skipped while scoring.
 *
 * @param {{ compounds: Compound[], proteins: MapCase[], scoringParams?: object }} params
 *     `scoringParams` is forwarded to the docking steps (see getScoringParams) and may also
 *     carry `cleanup`.
 * @return {Promise<{
 *   results: Map<Compound, Map<MapCase, {
 *     topPose: Compound, topScore: number, errors: string[],
 *     scoringPoses: Compound[], dockingSites: object[]
 *   }>>,
 *   errors: string[],
 * }>} `results` holds successfully scored pairs; their own non-fatal errors stay on the
 *     per-pair entry. `errors` lists pairs that could not be scored, plus the reasons.
 */
async function scoreCompounds({ compounds, proteins, scoringParams={} }) {
    /** @type {Map<Compound, Map<MapCase, object>>} */
    const results = new Map();
    /** @type {string[]} */
    const errors = [];

    // Ensure all compounds exist on all proteins.
    // Errors are not collected here: getScoringCandidates repeats the same lookup/load and
    // reports any failure for the pair it affects.
    await loopOnProteins(proteins, async (mapCase) => {
        const caseData = App.Workspace.lookupCaseData(mapCase);
        for (const compound of compounds) {
            await ensureCompoundOnProtein(compound, caseData);
        }
    });

    // Score the compounds on the proteins
    await loopOnProteins(proteins, async (mapCase) => {
        for (const compound of compounds) {
            const score = await scoreCompoundForOne({ compound, mapCase, scoringParams });
            // Without a top pose there is nothing to report for this pair; surface why.
            // topScore is deliberately not tested for truthiness: a score of 0 is a valid result.
            if (!score.topPose) {
                errors.push(
                    `Unable to score ${compound.displayName} on ${mapCase.getShortName()}`,
                    ...score.errors
                );
                continue;
            }
            // A pair that scored despite errors keeps them on its result entry (score.errors).
            let perCompoundMap = results.get(compound);
            if (!perCompoundMap) {
                perCompoundMap = new Map();
                results.set(compound, perCompoundMap);
            }
            perCompoundMap.set(mapCase, score);
        }
    });

    // Cleanup
    const { toRemove, toGetEnergies } = partitionScoredPoses(results);
    if (scoringParams.cleanup) {
        for (const pose of toRemove) {
            await UserActions.RemoveCompound(pose);
        }
    }
    // Solvation is not initially calculated, so fetch it for top poses.
    for (const pose of toGetEnergies) {
        await UserActions.GetEnergies(pose);
    }

    return { results, errors };
}

/**
 * Split every candidate pose in `results` into poses to discard and poses to keep.
 *
 * A pose is kept when it is the pair's top pose. When `keepDockingSites` is true (a debugging
 * aid for inspecting docking references), poses that seeded a docking site are kept as well.
 *
 * @param {Map<Compound, Map<MapCase, {topPose: Compound, topScore: number,
 *     scoringPoses: Compound[], dockingSites: {pose: Compound}[]}>>} results
 * @param {boolean} [keepDockingSites=false]
 * @returns {{toRemove: Compound[], toGetEnergies: Compound[]}}
 */
function partitionScoredPoses(results, keepDockingSites = false) {
    const toRemove = [];
    const toGetEnergies = [];
    for (const perCompoundMap of results.values()) {
        for (const score of perCompoundMap.values()) {
            for (const pose of score.scoringPoses) {
                const keepDockingSite = keepDockingSites
                    && score.dockingSites.some((site) => site.pose === pose);
                if (pose !== score.topPose && !keepDockingSite) {
                    console.log(`Removing pose ${pose.resSpec} with score ${scorePose(pose).toFixed(2)}, in favor of ${score.topPose.resSpec} with score ${score.topScore.toFixed(2)}`);
                    toRemove.push(pose);
                } else {
                    toGetEnergies.push(pose);
                }
            }
        }
    }
    return { toRemove, toGetEnergies };
}

/**
 * Run `doFn` once per protein, serializing the calls that share a connection.
 *
 * This unifies calls for both old (>1 connection) and new (1 connection) selectivity.
 * Proteins are grouped by their connector. The groups run concurrently (Promise.all), which
 * exploits multiple connections. Within a group, proteins are awaited one at a time, because
 * a single connection's server eventually crashes when too many requests are in flight.
 *
 * @param {MapCase[]} proteins
 * @param {function(MapCase): Promise<*>} doFn
 * @returns {Promise<void>} resolves once every call has finished; rejects if any call rejects.
 */
async function loopOnProteins(proteins, doFn) {
    /** @type {Map<*, MapCase[]>} */
    const proteinsByConnector = new Map();
    for (const mapCase of proteins) {
        const { connector } = App.getDataParents(mapCase);
        const group = proteinsByConnector.get(connector) ?? [];
        group.push(mapCase);
        proteinsByConnector.set(connector, group);
    }

    const runGroupSequentially = async (group) => {
        for (const mapCase of group) {
            await doFn(mapCase);
        }
    };
    await Promise.all(Array.from(proteinsByConnector.values(), runGroupSequentially));
}

/**
 * Dock `compound` against one protein and pick its best-scoring pose.
 *
 * @param {{
 * compound: Compound,
 * mapCase: MapCase,
 * scoringParams: {}
 * }} param0
 * @returns {Promise<{
 * topPose: Compound|undefined,
 * topScore: number,
 * errors: string[],
 * scoringPoses: Compound[],
 * dockingSites: object[],
 * }>} All candidate poses sorted with scoreSortFn (lowest score first), the first of them as
 *     `topPose`, its scorePose value, the docking sites used, and errors gathered while docking.
 *     `topPose` is undefined when no candidate pose was produced.
 */
async function scoreCompoundForOne({ compound, mapCase, scoringParams }) {
    const candidates = await getScoringCandidates(compound, mapCase, scoringParams);
    const scoringPoses = candidates.compounds;
    const rawSites = candidates.dockingSites;
    const dockingSites = rawSites === undefined ? [] : rawSites;
    const { errors } = candidates;

    // scoreSortFn sorts ascending, so the best (lowest) score lands at index 0.
    scoringPoses.sort(scoreSortFn);
    const [topPose] = scoringPoses;
    const topScore = scorePose(topPose);

    console.log(`Top pose for ${compound.resSpec} against ${mapCase.getNodeName()}: ${topPose?.resSpec}, score ${topScore.toFixed(2)}`);

    return { topPose, topScore, errors, scoringPoses, dockingSites };
}

/**
 * For a compound and a mapCase, return poses that will be scored, to identify the top pose.
 *
 * Flow:
 *  1. Get a version of the compound onto the target protein (ensureCompoundOnProtein).
 *  2. Dock it with DiffDock.
 *  3. Derive distinct docking sites from the DiffDock poses.
 *  4. Run AutoDock at each site.
 * All docked poses come back minimized (see dockAndMinimizeForScoring).
 *
 * If a temporary copy of the compound was created on the target protein, it is removed before
 * returning. This also happens when docking throws, so a failed run does not leave it behind.
 *
 * @param {Compound} compound
 * @param {MapCase} mapCase target protein
 * @param {object} scoringParams see getScoringParams
 * @returns {Promise<{compounds: Compound[], dockingSites: object[], errors: string[]}>}
 *     DiffDock poses first, then AutoDock poses in docking-site order, plus all collected errors.
 */
async function getScoringCandidates(compound, mapCase, scoringParams) {
    const caseDataP = App.Workspace.lookupCaseData(mapCase);

    /** @type {Compound[]} */
    const compounds = [];
    /** @type {string[]} */
    const errors = [];

    const { diffdockPoseCount, autodockPoseCount, autodockSpeed } = getScoringParams(scoringParams);

    // If scoring for a different protein, need to copy the compound to the new protein
    const ensureResults = await ensureCompoundOnProtein(compound, caseDataP);
    const { compound: dockingCompound, errors: ensureErrs } = ensureResults;
    if (ensureErrs) errors.push(...ensureErrs);

    if (!dockingCompound) {
        // Nothing we can do if we can't get the compound onto the target protein
        return { compounds, dockingSites: [], errors };
    }
    // Only a copy made for the target protein is temporary; never remove the caller's compound.
    const deleteDockingCompound = dockingCompound !== compound;

    try {
        // Start with DiffDock to get pose candidates and docking references for autodock
        const diffdockParams = { dockingProgram: 'diffdock', poseCount: diffdockPoseCount };
        const diffdockResults = await dockAndMinimizeForScoring(dockingCompound, diffdockParams);
        const { compounds: diffdockPoses, errors: diffdockErrors } = diffdockResults;
        compounds.push(...diffdockPoses);
        errors.push(...diffdockErrors);

        // Autodock at each distinct site found by DiffDock.
        const dockingSites = getDockingSites(caseDataP, diffdockPoses);
        // NOTE(review): all sites are docked concurrently on the same connection; loopOnProteins
        // warns that too many concurrent requests on one connection can crash the server.
        // allSettled (rather than all) ensures no job is still running when the finally block
        // removes dockingCompound. Results are merged in site order so the output is deterministic.
        const autodockOutcomes = await Promise.allSettled(dockingSites.map(({ dockingReference }) => {
            const autodockParams = {
                dockingProgram: 'autodock',
                speedArg: autodockSpeed,
                poseCount: autodockPoseCount,
                boxParams: { refObj: dockingReference },
            };
            return dockAndMinimizeForScoring(dockingCompound, autodockParams);
        }));
        const failed = autodockOutcomes.find(({ status }) => status === 'rejected');
        if (failed) throw failed.reason;
        for (const { value } of autodockOutcomes) {
            const { compounds: autodockPoses, errors: autodockErrors } = value;
            compounds.push(...autodockPoses);
            errors.push(...autodockErrors);
        }

        return { compounds, dockingSites, errors };
    } finally {
        // Runs on success and on failure, so a temporary copy never lingers in the workspace.
        if (deleteDockingCompound) {
            await UserActions.RemoveCompound(dockingCompound);
        }
    }
}

/**
 * Resolve the docking settings used for scoring.
 *
 * Defaults: 10 DiffDock poses, 10 AutoDock poses, AutoDock speed 'accurate'.
 * `quick` (truthy) switches the defaults to 2 / 2 / 'faster'.
 * Explicit `diffdockPoseCount`, `autodockPoseCount`, and `autodockSpeed` values override
 * either set of defaults, but only when truthy (e.g. a pose count of 0 is ignored).
 *
 * @param {{quick?: boolean, diffdockPoseCount?: number, autodockPoseCount?: number,
 *     autodockSpeed?: string}} [scoringParams]
 * @returns {{diffdockPoseCount: number, autodockPoseCount: number, autodockSpeed: string}}
 */
function getScoringParams(scoringParams={}) {
    const resolved = scoringParams.quick
        ? { diffdockPoseCount: 2, autodockPoseCount: 2, autodockSpeed: 'faster' }
        : { diffdockPoseCount: 10, autodockPoseCount: 10, autodockSpeed: 'accurate' };

    const overridableKeys = ['diffdockPoseCount', 'autodockPoseCount', 'autodockSpeed'];
    for (const key of overridableKeys) {
        if (scoringParams[key]) resolved[key] = scoringParams[key];
    }
    return resolved;
}

/**
 * Return a version of `compound` that lives on the protein of `caseDataP`, loading it if needed.
 *
 * A loaded compound is reused if it has the same unique structure id and its mapCase is
 * `caseDataP.mapCase` (this can be `compound` itself). Otherwise the compound's MOL data is
 * loaded onto `caseDataP` as a new compound.
 *
 * @param {Compound} compound
 * @param {CaseData} caseDataP case data of the target protein
 * @returns {Promise<{compound: Compound} | {errors: string[]}>} the compound on the target
 *     protein, or errors explaining why none is available.
 */
async function ensureCompoundOnProtein(compound, caseDataP) {
    const compoundKey = compound.getUniqueStructId();
    const targetMapCase = caseDataP.mapCase;

    const existing = App.Workspace.getLoadedCompounds().find((pose) => (
        pose.getUniqueStructId() === compoundKey
        && App.getDataParents(pose).mapCase === targetMapCase
    ));
    if (existing) {
        return { compound: existing };
    }

    const molSource = new MolDataSource({
        sourceType: MolDataSource.Types.ScoreCompound,
        molFormat: 'mol',
        molData: compound.getMol2000(),
    });
    // Note: this MolDataSource usage will not copy the properties or heritage from compound.
    // Ideally, we could assign sourceCmpds, which would cause it to get copied over after loading.
    // However, the mechanism in ConnectedDataCmds currently assumes coordinates were preserved,
    // and for scoring we need to put the new compound "outside the protein" to avoid clashing.
    const { compounds, errors } = await UserActions.LoadMolData(
        molSource,
        new MoleculeLoadOptions(),
        caseDataP,
    );
    if (errors?.length > 0) {
        return { errors };
    }

    const loadedCompound = compounds?.[0];
    if (!loadedCompound) {
        // Keep the documented contract: explain the failure instead of returning { compound: undefined }.
        return { errors: [`No compound was loaded for ${compound.displayName}`] };
    }
    return { compound: loadedCompound };
}

/**
 * Pick spatially distinct docking sites from a list of poses (the DiffDock poses).
 *
 * Poses are visited in the given order. A pose starts a new site only when its atom centroid
 * is farther than 8 (squared distance > 64) from the centroid of every site chosen so far.
 *
 * @param {*} caseDataProtein not read by this function; kept for call compatibility
 * @param {Compound[]} diffdockPoses
 * @returns {{pose: Compound, centroid: number[], dockingReference: *}[]} one entry per site,
 *     with a DockingReference built from the site's pose.
 */
function getDockingSites(caseDataProtein, diffdockPoses) {
    const minSeparation = 8;
    const minSeparationSq = minSeparation * minSeparation;
    const sites = [];

    for (const pose of diffdockPoses) {
        const centroid = pointsCentroid(pose.getAtoms());
        const isNewSite = sites.every(
            (existing) => pointDistanceSquared(centroid, existing.centroid) > minSeparationSq
        );
        if (!isNewSite) continue;
        sites.push({
            pose,
            centroid,
            dockingReference: DockingReference.ForCompound(pose),
        });
    }

    const summary = sites
        .map(({ pose, centroid }) => `${pose.resSpec}:${centroid.map((x) => x.toFixed(1))}`)
        .join('; ');
    console.log(`ScoreCompound: docking sites: ${summary}`);
    return sites;
}

/**
 * Ranking score of a pose; lower is better (scoreSortFn sorts ascending and the first pose wins).
 *
 * Returns 100 when `pose` is null/undefined or its energy score is falsy (undefined, NaN, 0).
 * NOTE(review): an exact energy score of 0 is therefore treated as missing — confirm intended.
 *
 * @param {Compound|undefined} pose
 * @returns {number}
 */
function scorePose(pose) {
    const fallbackScore = 100;
    if (pose == null) return fallbackScore;
    return pose.energyInfo.getEnergyScore() || fallbackScore;
}

/**
 * Array#sort comparator that orders poses by ascending scorePose (best first).
 */
function scoreSortFn(a, b) {
    const scoreA = scorePose(a);
    const scoreB = scorePose(b);
    return scoreA - scoreB;
}

/**
 * Dock `cmpds`, energy-minimize every resulting pose, and gather errors from both stages.
 *
 * @param {Compound|Compound[]} cmpds compound(s) to dock; a single compound goes through ensureArray
 * @param {object} dockingParams forwarded to UserActions.SubmitDocking
 * @param {{otherDockOptions?: object, minimizationOptions?: object}} [options]
 *     `otherDockOptions` is passed to SubmitDocking. `minimizationOptions` is passed to
 *     EnergyMinimize and gatherEnergyErrors.
 * @returns {Promise<{compounds: Compound[], groups: *, errors: string[]}>}
 *     - `compounds`: minimized poses, sorted with scoreSortFn (best first).
 *     - `groups`: the groups reported by docking.
 *     - `errors`: docking errors, then one "<name> energy error: <msg>" entry per
 *       minimization error.
 */
async function dockAndMinimize(cmpds, dockingParams, { otherDockOptions, minimizationOptions }={}) {
    const inputs = ensureArray(cmpds);
    const dockResult = await UserActions.SubmitDocking(inputs, dockingParams, otherDockOptions);
    const { compounds: dockedPoses, errors: dockingErrors, groups } = dockResult;

    const minResult = await UserActions.EnergyMinimize(dockedPoses, minimizationOptions);
    const { compounds: minimizedPoses, errors: rawMinimizationErrors, energies } = minResult;
    minimizedPoses.sort(scoreSortFn);

    const errorsByCompound = gatherEnergyErrors(
        dockedPoses, minimizedPoses, energies,
        rawMinimizationErrors, minimizationOptions
    );
    // One message per (compound, error); compounds without errors contribute nothing.
    const minimizationErrors = [...errorsByCompound.entries()].flatMap(
        ([cmpd, cmpdErrs]) => [...cmpdErrs].map(
            (cmpdErr) => `${cmpd.displayName} energy error: ${cmpdErr}`
        )
    );

    console.log(`ScoreCompound: docked and minimized: ${minimizedPoses.map((p) => `${p.resSpec}:${scorePose(p)}`).join('; ')}`);
    return {
        compounds: minimizedPoses,
        groups,
        errors: [...dockingErrors, ...minimizationErrors],
    };
}

/**
 * dockAndMinimize with scoring-oriented defaults: docking skips grouping (`skipGrouping: true`)
 * and minimization skips solvation (`skipSolvation: true`). Caller-supplied options are
 * layered on top and win on conflicts.
 *
 * @param {Compound|Compound[]} cmpds
 * @param {object} dockingParams
 * @param {{otherDockOptions?: object, minimizationOptions?: object}} [options]
 * @returns {Promise<{compounds: Compound[], groups: *, errors: string[]}>} see dockAndMinimize
 */
async function dockAndMinimizeForScoring(cmpds, dockingParams, options={}) {
    const { otherDockOptions: dockOverrides, minimizationOptions: minOverrides } = options;
    const scoringOptions = {
        otherDockOptions: { skipGrouping: true, ...dockOverrides },
        minimizationOptions: { skipSolvation: true, ...minOverrides },
    };
    return dockAndMinimize(cmpds, dockingParams, scoringOptions);
}
