// Class to maintain energy components for a compound
// Also tracks the status of the requests for the energies

import { findHbonds } from 'BMapsSrc/interactions';
import { DefaultBindingSiteRadius } from '../utils';
import { getAtomsNearAtoms } from '../util/atom_distance_utils';

export class EnergyInfo {
    /** Error message recorded by setClashing(); isClashing() compares against it. */
    static get clashMessage() {
        return 'This compound is clashing with the protein. You may need to dock it to get a favorable energy score.';
    }

    /**
     * @param {Compound} compound - owner of these energies; only its resSpec is read
     *     in this class (for logging)
     */
    constructor(compound) {
        this.compound = compound;
        this.reset();
    }

    /** Clear all energy entries and request states back to their initial values. */
    reset() {
        this.energies = {}; // {energyType: {type, status, energy, error} }

        // Updated by summarize(), for better or worse
        this.status = EnergyInfo.States.success;
        this.errors = [];
        this.totalEnergy = 0; // Sum of all energy components (not currently visible)
        this.energyScore = 0; // Interaction energy score
        this.internalEnergy = 0; // Used for energy score if no protein

        // Specific state for forcefield and minimization requests
        this.forcefieldParamStatus = EnergyInfo.States.success; // tracked separately from energies
        this.ffError = null;
        this.minimizationStatus = EnergyInfo.States.success;
        this.minimizationError = null;
    }

    // This should be tracked in the status, instead of doing this calculation
    // Exclude "extra types" from the calculation
    /**
     * True when nothing is in progress or failed and no core (EnergyInfo.Types)
     * energy entry exists yet.
     */
    notRequested() {
        return this.status === EnergyInfo.States.success
            && this.getAllEnergies().filter(
                ({ type }) => Object.values(EnergyInfo.Types).includes(type)
            ).length === 0;
    }

    /** Truthy when an entry for `type` exists and its status is success. */
    energyAvailable(type) {
        const info = this.getEnergyEntryByType(type);
        return info && info.status === EnergyInfo.States.success;
    }

    /** True when every type in `types` is available (true for an empty list). */
    energiesAvailable(types) {
        return types.every((type) => this.energyAvailable(type));
    }

    /** True when every component that contributes to the energy score is available. */
    energyScoreAvailable() {
        return Object.keys(EnergyInfo.EnergyScoreFactors).every((enT) => this.energyAvailable(enT));
    }

    // Functions to track the status of forcefield and energy requests (server operations)
    forcefieldRequested() {
        // First remove energy components that will be impacted by new forcefield params.
        // Deleting them here may not actually be necessary. This operation is only called by
        // getEnergies, not by energyMinimize, which guarantees FF params on the server.
        // This leads to a difference in what the user sees if the compound already has energies:
        // getEnergies clears everything and then fills it out, whereas energyMinimize will
        // show the old values until the new values come in.
        delete this.energies[EnergyInfo.Types.vdW];
        delete this.energies[EnergyInfo.Types.hbonds];
        delete this.energies[EnergyInfo.Types.electrostatics];
        delete this.energies[EnergyInfo.Types.stress];
        delete this.energies[EnergyInfo.Types.ddGs];
        this.forcefieldParamStatus = EnergyInfo.States.working;
        this.summarize();
    }

    /**
     * Record the outcome of a forcefield parameterization request.
     * @param {boolean} success
     * @param {*} error - recorded (and reported by summarize) only when !success
     */
    updateForcefieldRequest(success, error) {
        this.forcefieldParamStatus = success ? EnergyInfo.States.success : EnergyInfo.States.error;
        // A success supersedes any earlier failure; otherwise the stale error would
        // stay in `errors` indefinitely.
        this.ffError = success ? null : error;
        this.summarize();
    }

    minimizationRequested() {
        this.minimizationStatus = EnergyInfo.States.working;
        this.summarize();
    }

    /**
     * Record the outcome of a minimization request.
     * @param {boolean} success
     * @param {*} error - recorded (and reported by summarize) only when !success
     */
    updateMinimizationRequest(success, error) {
        this.minimizationStatus = success ? EnergyInfo.States.success : EnergyInfo.States.error;
        // A success clears a previous failure (including a clash), so isClashing()
        // and `errors` reflect the latest result.
        this.minimizationError = success ? null : error;
        this.summarize();
    }

    /** Mark the minimization as failed because the compound clashes with the protein. */
    setClashing() {
        this.updateMinimizationRequest(false, EnergyInfo.clashMessage);
    }

    /** True if the most recent minimization failure was a clash. */
    isClashing() {
        return this.minimizationError === EnergyInfo.clashMessage;
    }

    energyRequested(energyType) { this.updateEnergy(energyType, EnergyInfo.States.working); }
    energyError(energyType, error) {
        this.updateEnergy(energyType, EnergyInfo.States.error, null, error);
    }

    setEnergy(energyType, value) {
        this.updateEnergy(energyType, EnergyInfo.States.success, value);
    }

    /**
     * Create or update the entry for `energyType`, then re-summarize.
     * `val` is stored only on success; `error` is stored only on error.
     */
    updateEnergy(energyType, status, val, error='') {
        console.log(`Updating energy ${this.compound.resSpec} : ${energyType} => ${status} ${val != null ? val.toFixed(2) : error}`);
        // Entry objects should look like: {type, status, energy, error} }
        let entry = this.getEnergyEntryByType(energyType);

        // Create an entry if necessary
        if (!entry) {
            entry = { type: energyType };
            this.energies[energyType] = entry;
        }

        // Update appropriately
        entry.status = status;
        if (status === EnergyInfo.States.success) {
            entry.energy = val;
        } else if (status === EnergyInfo.States.error) {
            entry.error = error;
        }

        this.summarize(); // also publishes to the display
    }

    // Accessor functions
    getEnergyTypes() {
        return Object.keys(this.energies);
    }

    getEnergyEntryByType(energyType) {
        return this.energies[energyType];
    }

    getEnergyValueByType(energyType) {
        const entry = this.getEnergyEntryByType(energyType);
        return entry && entry.energy;
    }

    getEnergyTotal() {
        return this.totalEnergy;
    }

    getEnergyScore() {
        return this.energyScore;
    }

    getInternalEnergy() {
        return this.internalEnergy;
    }

    /** True if a forcefield, minimization, or any individual energy request is in progress. */
    anyEnergiesWorking() {
        return (
            this.minimizationStatus === EnergyInfo.States.working
            || this.forcefieldParamStatus === EnergyInfo.States.working
            || Object.values(this.energies).some((en) => en.status === EnergyInfo.States.working)
        );
    }

    /**
     * Snapshot of the summary values plus each core component's energy value
     * (undefined for components without an entry).
     */
    getEnergyReport() {
        const values = {};
        const types = [
            EnergyInfo.Types.vdW,
            EnergyInfo.Types.hbonds,
            EnergyInfo.Types.ddGs,
            EnergyInfo.Types.stress,
            EnergyInfo.Types.electrostatics,
        ];
        for (const type of types) {
            values[type] = this.getEnergyValueByType(type);
        }

        return {
            energyScore: this.getEnergyScore(),
            totalEnergy: this.getEnergyTotal(),
            internalEnergy: this.getInternalEnergy(),
            ...values,
        };
    }

    // Takes a list of energy types and sums those that are available
    sumEnergiesByTypes(types) {
        let sum = null; // initialized to null so caller can distinguish from 0

        if (this.notRequested() === false) {
            for (const type of types) {
                if (this.energyAvailable(type)) {
                    if (sum == null) sum = 0;
                    sum += this.getEnergyValueByType(type);
                }
            }
        }
        return sum;
    }

    /** All energy entries: known types in display order first, then any others. */
    getAllEnergies() {
        const results = [];
        // First add known energy types in order
        const sortedTypes = [
            EnergyInfo.Types.vdW,
            EnergyInfo.Types.ddGs,
            EnergyInfo.Types.hbonds,
            EnergyInfo.Types.electrostatics,
            EnergyInfo.Types.stress,
        ];
        for (const type of sortedTypes) {
            const energyEntry = this.getEnergyEntryByType(type);
            if (energyEntry) {
                results.push(energyEntry);
            }
        }

        // Now make sure we didn't miss any others that were added
        for (const [energyType, energyEntry] of Object.entries(this.energies)) {
            const existingEntry = results.find((r) => r.type === energyType);
            if (!existingEntry) {
                results.push(energyEntry);
            }
        }

        return results;
    }

    /**
     * Recompute status, errors, totalEnergy, energyScore, and internalEnergy from the
     * request states and the energy entries.
     */
    summarize() {
        this.totalEnergy = 0;
        this.energyScore = 0;
        this.internalEnergy = 0;
        this.errors = [];
        if (this.forcefieldParamStatus === EnergyInfo.States.working
            || this.minimizationStatus === EnergyInfo.States.working) {
            this.status = EnergyInfo.States.working;
        } else if (this.forcefieldParamStatus === EnergyInfo.States.error
            || this.minimizationStatus === EnergyInfo.States.error) {
            this.status = EnergyInfo.States.error;
        } else {
            this.status = EnergyInfo.States.success;
        }

        if (this.ffError) {
            this.errors.push(this.ffError);
        }
        if (this.minimizationError) this.errors.push(this.minimizationError);

        const energyErrors = [];

        for (const [type, entry] of Object.entries(this.energies)) {
            switch (entry.status) {
                case EnergyInfo.States.success: {
                    const energyTotalFn = this.getEnergyTotalFactorFn(type);
                    this.totalEnergy += energyTotalFn(entry.energy);
                    const energyScoreFn = this.getEnergyScoreFactorFn(type);
                    this.energyScore += energyScoreFn(entry.energy);
                    const internalEnergyFn = this.getInternalEnergyFactorFn(type);
                    this.internalEnergy += internalEnergyFn(entry.energy);
                    break;
                }
                case EnergyInfo.States.working:
                    this.status = entry.status;
                    break;
                case EnergyInfo.States.error:
                    if (Object.keys(EnergyInfo.ExtraTypes).includes(type)) {
                        // Don't include "extra energies" in the errors reported for the whole
                        // Extra energies are calculated from different services.
                        continue;
                    }

                    energyErrors.push(type);
                    // if there are both error and working entries, error defers to working
                    if (this.status === EnergyInfo.States.success) {
                        this.status = entry.status;
                    }
                    break;
                default:
                    console.error(`EnergyInfo: unknown energy status: ${entry.status}`);
            }
        }

        if (energyErrors.length > 0) {
            this.errors.push(`Failed energy requests: ${energyErrors.join()}`);
        }
    }

    /** Returns the live energies map (not a copy). */
    captureState() {
        return this.energies;
    }

    /**
     * @description Return a function to calculate the contribution of this
     * energy type to the total energy (0 for types without a factor).
     * @param {string} type - one of EnergyInfo.Types / ExtraTypes
     */
    getEnergyTotalFactorFn(type) {
        return EnergyInfo.EnergyTotalFactors[type] || (() => 0);
    }

    /** Function giving this type's contribution to the energy score (0 if none). */
    getEnergyScoreFactorFn(type) {
        return EnergyInfo.EnergyScoreFactors[type] || (() => 0);
    }

    /** Function giving this type's contribution to the internal energy (0 if none). */
    getInternalEnergyFactorFn(type) {
        return EnergyInfo.InternalEnergyFactors[type] || (() => 0);
    }
}

// Request states used for each energy entry and for the forcefield/minimization steps.
// Each value is identical to its key.
EnergyInfo.States = Object.fromEntries(
    ['working', 'success', 'error'].map((state) => [state, state])
);

// Core energy component types tracked by EnergyInfo. Each value is identical to its key.
EnergyInfo.Types = Object.fromEntries(
    ['vdW', 'hbonds', 'ddGs', 'stress', 'electrostatics', 'coulomb'].map((type) => [type, type])
);

/**
 * Energy-like values that aren't requested from the server.
 * Each value is identical to its key.
 */
EnergyInfo.ExtraTypes = Object.fromEntries(
    ['dockingScore', 'GiFE'].map((type) => [type, type])
);

// How each energy component contributes to the total energy (all unweighted).
// Types missing here contribute nothing.
EnergyInfo.EnergyTotalFactors = Object.fromEntries(
    [
        EnergyInfo.Types.vdW,
        EnergyInfo.Types.hbonds,
        EnergyInfo.Types.ddGs,
        EnergyInfo.Types.stress,
        EnergyInfo.Types.electrostatics,
    ].map((type) => [type, (x) => x])
);

// Components of the interaction energy score (unweighted).
EnergyInfo.EnergyScoreFactors = Object.fromEntries(
    [EnergyInfo.Types.vdW, EnergyInfo.Types.hbonds].map((type) => [type, (x) => x])
);

// Components of the internal energy (unweighted).
EnergyInfo.InternalEnergyFactors = Object.fromEntries(
    [EnergyInfo.Types.stress].map((type) => [type, (x) => x])
);

/**
 * @description Apply raw binding energy data and/or solvation data to a compound.
 * Either payload may be omitted (falsy), in which case that step is skipped.
 * @param {Compound} compound
 * @param {*} energyData - raw binding energy data, forwarded to applyEnergiesToCompound
 * @param {*} solvData - raw solvation data, forwarded to applySolvationToCompound
 * @param {CaseData} caseData - forwarded to both steps: used to get solute atoms for
 *     finding hydrogen bonds, and to look up solute atoms by uid for solvation details
 * @returns {Array<string>} problems reported by either step (empty if none)
 */
export function applyEnergyStateToCompound(
    compound, energyData, solvData, caseData
) {
    const allProblems = [];
    if (energyData) {
        const { problems } = applyEnergiesToCompound(compound, energyData, caseData);
        if (problems) {
            allProblems.push(...problems);
        }
    }
    if (solvData) {
        const { problems } = applySolvationToCompound(compound, solvData, caseData);
        if (problems) {
            allProblems.push(...problems);
        }
    }
    return allProblems;
}

/**
 * @description Apply raw binding energy data to a compound
 * This was previously part of doReceiveEnergiesForLigand in project_data.js
 *
 * Two layouts are accepted for each of interactionEnergies / internalEnergies:
 *  - an object with a `totals` array (internal energies may also carry
 *    `bond_dihedral`, the per-torsion details), or
 *  - a list whose first element is the totals array (for internal energies the
 *    torsion details are read from internalEnergies[1][3]).
 * @param {Compound} compound
 * @param {*} energyData - raw binding energy data: {interactionEnergies, internalEnergies}
 * @param {CaseData} caseData - Used to get solute atoms to find hydrogen bonds
 * @returns {{vdw, coulomb, hbonds, stress, total, worstTorsion, problems}} where
 *     `coulomb` is the processed electrostatics value and
 *     total = vdw + hbonds + coulomb + stress
 */
export function applyEnergiesToCompound(compound, energyData, caseData) {
    const { interactionEnergies: rawInteraction, internalEnergies: rawInternal } = energyData;
    const problems = [];

    let internalEnergies;
    let torsionList;
    if (rawInternal.totals) {
        internalEnergies = [rawInternal.totals];
        // Read torsion details from the raw object; the wrapped totals array has no
        // bond_dihedral property.
        torsionList = rawInternal.bond_dihedral || [];
    } else {
        internalEnergies = rawInternal;
        torsionList = rawInternal[1][3] || [];
    }
    const interactionEnergies = rawInteraction.totals
        ? [rawInteraction.totals]
        : rawInteraction;

    // Energies list contain: summary list, details lists: bond length, angle, dihedral,
    // improper energies, vdw, coulomb.  Summary has totals for bond length, angle,
    // dihedral, improper energies, vdw, coulomb.  Details has atom uuid's for the atoms
    // involved followed by energies.
    console.log(`internal energies ${internalEnergies[0].map((x) => x.toFixed(2)).join(',')},
                interaction ${interactionEnergies[0].map((x) => x.toFixed(2)).join(',')}`);

    // Save source data on the compound so we can save and restore.
    compound.originalEnergyData = energyData;

    // Sum of the internal energy totals; 0 for an empty totals list.
    const stressEnergySum = internalEnergies[0].reduce((sum, x) => sum + x, 0);
    const vdw = interactionEnergies[0][0] || 0;
    const coulomb = interactionEnergies[0][1] || 0;
    const soluteAtoms = caseData.getSoluteAtoms();
    const [hbondEnergy, newCoulomb] = processElectrostatic(compound, coulomb, soluteAtoms);
    compound.energyInfo.setEnergy(EnergyInfo.Types.stress, stressEnergySum);
    compound.energyInfo.setEnergy(EnergyInfo.Types.vdW, vdw);
    compound.energyInfo.setEnergy(EnergyInfo.Types.coulomb, coulomb);
    compound.energyInfo.setEnergy(EnergyInfo.Types.hbonds, hbondEnergy);
    compound.energyInfo.setEnergy(EnergyInfo.Types.electrostatics, newCoulomb);

    const worstTorsion = getWorstTorsion(torsionList);

    return {
        vdw,
        coulomb: newCoulomb,
        hbonds: hbondEnergy,
        stress: stressEnergySum,
        total: vdw + hbondEnergy + newCoulomb + stressEnergySum,
        worstTorsion,
        problems,
    };
}

/**
 * @description Apply raw solvation data to a compound
 * @param {*} compound
 * @param {*} solvData
 * @param {CaseData} caseData passed to applySolvationAtomDetails
 */
export function applySolvationToCompound(compound, solvData, caseData) {
    // Save source data so we can save and restore
    compound.originalSolvData = solvData;

    const [ddG_PL, ddG_LP, dG_P, dG_L] = solvData.totals;
    compound.ddG_PL = ddG_PL;
    compound.ddG_LP = ddG_LP;
    compound.dG_P = dG_P;
    compound.dG_L = dG_L;
    const ddGs = ddG_PL + ddG_LP;

    const [
        ligPSAunbound, ligPSAbound, ligTSAunbound, ligTSAbound,
        protPSAunbound, protPSAbound, protTSAunbound, protTSAbound,
    ] = solvData.PSA;
    compound.PSA = ligPSAunbound;
    compound.PSAbound = ligPSAbound;
    compound.TSA = ligTSAunbound;
    compound.TSAbound = ligTSAbound;
    compound.proteinPSA = protPSAunbound;
    compound.proteinPSAbound = protPSAbound;
    compound.proteinTSA = protTSAunbound;
    compound.proteinTSAbound = protTSAbound;

    compound.energyInfo.setEnergy(EnergyInfo.Types.ddGs, ddGs);

    const problems = applySolvationAtomDetails(compound, solvData.detail, caseData);

    return { ddGs, problems };
}

/**
 * @description Tweak the electrostatic value from the server to account for dielectric.
 *
 * Algorithm:
 *   1. Take the Coulomb value from the server.
 *   2. Subtract the hbond energy (sum of all hbond energies found between the compound
 *      and solute atoms within DefaultBindingSiteRadius of it).
 *   3. Divide by a dielectric divisor to approximate the dielectric.
 * The hbond energy is returned separately and is not added back in.
 * @param {Compound} compound
 * @param {number} coulombIn - raw Coulomb interaction energy from the server
 * @param {Array} soluteAtoms - all solute atoms; narrowed to those near the compound
 * @returns {[number, number]} [hbondEnergy, processed electrostatics value]
 */
function processElectrostatic(compound, coulombIn, soluteAtoms) {
    const dielectricDivisor = 4;

    const debugInfo = [];
    const debug = (item) => debugInfo.push(item);

    const compoundAtoms = compound.getAtoms();
    const nearbySoluteAtoms = getAtomsNearAtoms(
        compoundAtoms, soluteAtoms, DefaultBindingSiteRadius, { roundup: true }
    );
    const hbondInfo = findHbonds(compoundAtoms, nearbySoluteAtoms);

    // hbondInfo entries are [compoundAtom, partners]; each partner is [partnerAtom, energy].
    // Only the energies are needed here.
    let hbondEnergy = 0;
    for (const [, partners] of hbondInfo) {
        for (const [, energy] of partners) {
            hbondEnergy += energy;
        }
    }

    debug(`hbond energy: ${hbondEnergy.toFixed(1)}`);
    debug(`dielectric divisor: ${dielectricDivisor.toFixed(1)}`);
    debug(`Coulomb: initial ${coulombIn.toFixed(1)}`);

    // Remove hbonds, then apply divisor (hack to approximate dielectric)
    const withoutHbonds = coulombIn - hbondEnergy;
    debug(`-hbonds ${withoutHbonds.toFixed(1)}`);
    const coulomb = withoutHbonds / dielectricDivisor;
    debug(`/dielectric ${coulomb.toFixed(1)}`);
    debug('(No longer adding hbonds)');

    console.log(`Electrostatic processing: ${debugInfo.join('; ')}`);
    return [hbondEnergy, coulomb];
}

/**
 * Return the torsion with the highest energy.
 * Each torsion entry is a 5-array: 4 atom uuids and an energy (index 4).
 * torsionList is not modified; it may be part of the compound's saved originalEnergyData.
 * @param {Array<Array>} torsionList
 * @returns {Array|undefined} the highest-energy entry, or undefined for an empty or
 *     missing list
 */
function getWorstTorsion(torsionList) {
    if (!torsionList || torsionList.length === 0) {
        return undefined;
    }
    // Linear scan instead of sort(): a boolean comparator is not a valid sort comparator
    // (it never returns a negative number), and sort() would reorder the caller's list.
    return torsionList.reduce(
        (worst, torsion) => (torsion[4] > worst[4] ? torsion : worst)
    );
}

/**
 * @description Store individual ddGs components (dGs, dGs_bound) for use by highlighting
 * machinery. ddGs = dGs_bound - dGs
 * Values for compound atoms are stored on the atom itself.
 * Values for solute atoms are stored in a map on the atom residue object, totalling
 * up the atoms for each residue.
 * Re-applying detail for the same compound replaces (rather than adds to) the totals of
 * every residue that appears in the new detail.
 * @param {*} compound
 * @param {*} atomDetail - {compound: [[atomOrAtomName, dGs, dGs_bound], ...],
 *     solute: [[atomOrUniqueId, dGs, dGs_bound], ...]}
 * @param {CaseData} caseData - used to look up solute atoms by atom uid
 * @returns {string[]} messages for atoms that could not be found
 */
function applySolvationAtomDetails(compound, atomDetail, caseData) {
    const { compound: compoundDetail, solute: soluteDetail } = atomDetail;

    const problems = [];

    // Store compound atom details directly on the atom
    for (const [atomOrAtomName, dGs, dGs_bound] of compoundDetail) {
        const atom = typeof atomOrAtomName === 'string'
            ? compound.getAtomByName(atomOrAtomName)
            : atomOrAtomName;

        if (!atom) {
            const msg = `applySolvationAtomDetails: can't find atom ${atomOrAtomName} for compound ${compound.resSpec}.`;
            problems.push(msg);
            console.warn(msg);
            console.log(`Existing atom names: ${compound.getAtoms().map((a) => a.atom).join(', ')}`);
            continue;
        }
        atom.dGs = dGs;
        atom.dGs_bound = dGs_bound;
    }

    // Residues touched by this call. A residue's entry for this compound is replaced
    // the first time the residue is seen, so repeated calls for the same compound
    // don't accumulate onto totals from an earlier call.
    const residueSet = new Set();

    // Sum up solute atom detail for this compound in a residue map
    for (const [atomOrUniqueId, dGs, dGs_bound] of soluteDetail) {
        const atom = typeof atomOrUniqueId === 'number'
            ? caseData.findAtomByUid(atomOrUniqueId)
            : atomOrUniqueId;

        if (!atom) {
            const msg = `applySolvationAtomDetails: can't find atom ${atomOrUniqueId} for solute detail`;
            problems.push(msg);
            console.warn(msg);
            continue;
        }

        const { residue } = atom;
        if (!residue) {
            console.warn(`decodeSolvation: Weird, a solute atom without a residue: ${atom.atom} ${atom.uniqueID}`);
            continue;
        }

        // solvInfoMap maps from compoundSpec to
        // {dGs: <dGsTotal>, dGs_bound: <dGs_boundTotal>, ddGs}
        // The entry contains the sums of solv values for all atoms in the residue
        if (residue.solvInfoMap == null) {
            residue.solvInfoMap = new Map();
        }
        if (!residueSet.has(residue)) {
            residue.solvInfoMap.set(compound.resSpec, { dGs: 0, dGs_bound: 0, ddGs: 0 });
            residueSet.add(residue);
        }

        const solvEntry = residue.solvInfoMap.get(compound.resSpec);
        solvEntry.dGs += dGs;
        solvEntry.dGs_bound += dGs_bound;
    }

    // Calculate the ddGs sum for each residue
    for (const residue of residueSet) {
        const solvEntry = residue.solvInfoMap.get(compound.resSpec);
        // NOTE(review): this scales (dGs_bound - dGs) by the residue's atom count, which
        // differs from the ddGs = dGs_bound - dGs formula above — confirm intent.
        solvEntry.ddGs = residue.atoms.length * (solvEntry.dGs_bound - solvEntry.dGs);
    }

    return problems;
}
