// steps/combine_embeddings.js

import * as scran from "scran.js";
import * as utils from "./utils/general.js";
import * as rna_pca_module from "./rna_pca.js";
import * as adt_pca_module from "./adt_pca.js";
import * as crispr_pca_module from "./crispr_pca.js";

// Name of this analysis step, exported for use by other modules.
export const step_name = "combine_embeddings";

/**
 * Pick the modalities that should contribute to the combined embedding.
 *
 * Starts from the keys reported by `utils.findValidUpstreamStates` and keeps
 * only those whose entry in `weights` is strictly positive. Keys with a
 * missing weight are dropped as well, since `undefined > 0` is false.
 *
 * @param {object} pca_states - Upstream PCA states, keyed by modality name.
 * @param {object} weights - Weight for each modality, keyed by the same names.
 *
 * @return {Array} Keys of the usable states, in the order reported by
 * `utils.findValidUpstreamStates`.
 */
function find_nonzero_upstream_states(pca_states, weights) {
    const candidates = utils.findValidUpstreamStates(pca_states);
    return [...candidates].filter(modality => weights[modality] > 0);
}

/**
 * This step combines multiple embeddings from different modalities into a single matrix for downstream analysis.
 * It wraps the [`scaleByNeighbors`](https://kanaverse.github.io/scran.js/global.html#scaleByNeighbors) function
 * from [**scran.js**](https://kanaverse.github.io/scran.js).
 *
 * Methods not documented here are not part of the stable API and should not be used by applications.
 * @hideconstructor
 */
export class CombineEmbeddingsState {
    // Upstream PCA states, keyed by modality ("RNA", "ADT", "CRISPR").
    #pca_states;

    // Parameters used by the most recent successful selection in compute().
    #parameters;

    // Holds 'combined_buffer', 'num_cells' and 'total_dims' once computed.
    #cache;

    constructor(pca_states, parameters = null, cache = null) {
        if (!(pca_states.RNA instanceof rna_pca_module.RnaPcaState)) {
            throw new Error("'pca_states.RNA' should be an RnaPcaState object");
        }
        if (!(pca_states.ADT instanceof adt_pca_module.AdtPcaState)) {
            throw new Error("'pca_states.ADT' should be an AdtPcaState object");
        }
        if (!(pca_states.CRISPR instanceof crispr_pca_module.CrisprPcaState)) {
            throw new Error("'pca_states.CRISPR' should be an CrisprPcaState object");
        }
        this.#pca_states = pca_states;

        this.#parameters = (parameters === null ? {} : parameters);
        this.#cache = (cache === null ? {} : cache);
        this.changed = false;
    }

    /**
     * Hands the cached combined buffer (if any) to `utils.freeCache`.
     */
    free() {
        utils.freeCache(this.#cache.combined_buffer);
    }

    /***************************
     ******** Getters **********
     ***************************/

    /**
     * @return {Float64WasmArray} Buffer containing the combined embeddings as a column-major dense matrix,
     * where the rows are the dimensions and the columns are the cells.
     * This is available after running {@linkcode CombineEmbeddingsState#compute compute}.
     */
    fetchCombined() {
        return this.#cache.combined_buffer;
    }

    /**
     * @return {number} Number of cells in {@linkcode CombineEmbeddingsState#fetchCombined fetchCombined},
     * available after running {@linkcode CombineEmbeddingsState#compute compute}.
     */
    fetchNumberOfCells() {
        return this.#cache.num_cells;
    }

    /**
     * @return {number} Number of dimensions in {@linkcode CombineEmbeddingsState#fetchCombined fetchCombined},
     * available after running {@linkcode CombineEmbeddingsState#compute compute}.
     */
    fetchNumberOfDimensions() {
        return this.#cache.total_dims;
    }

    /**
     * @return {object} Object containing the parameters.
     */
    fetchParameters() {
        // Avoid any pass-by-reference activity.
        return { ...this.#parameters };
    }

    /***************************
     ******** Compute **********
     ***************************/

    /**
     * @return {object} Default parameters for this step, in a form that can be passed to
     * {@linkcode CombineEmbeddingsState#compute compute}.
     */
    static defaults() {
        return {
            rna_weight: 1,
            adt_weight: 1,
            crispr_weight: 0,
            approximate: true
        };
    }

    /**
     * Replaces the cached combined buffer with a view on the principal components of a single upstream PCA result,
     * and records the corresponding number of cells and dimensions.
     *
     * @param {object} cache - Cache object to update in place.
     * @param {object} upstream - PCA results providing `principalComponents`, `numberOfCells` and `numberOfPCs`.
     */
    static createPcsView(cache, upstream) {
        // Hand over whatever was cached before so that it is not left dangling.
        utils.freeCache(cache.combined_buffer);
        cache.combined_buffer = upstream.principalComponents({ copy: "view" }).view();
        cache.num_cells = upstream.numberOfCells();
        cache.total_dims = upstream.numberOfPCs();
    }

    /**
     * This method should not be called directly by users, but is instead invoked by {@linkcode runAnalysis}.
     *
     * @param {object} parameters - Parameter object, equivalent to the `combine_embeddings` property of the `parameters` of {@linkcode runAnalysis}.
     * @param {number} parameters.rna_weight - Relative weight of the RNA embeddings.
     * @param {number} parameters.adt_weight - Relative weight of the ADT embeddings.
     * @param {number} parameters.crispr_weight - Relative weight of the CRISPR embeddings.
     * @param {boolean} parameters.approximate - Whether an approximate nearest neighbor search should be used by `scaleByNeighbors`.
     *
     * @return The object is updated with new results.
     * @throws {Error} If no available modality has a positive weight,
     * or if the embeddings to be combined disagree on the number of cells.
     */
    compute(parameters) {
        const { rna_weight, adt_weight, crispr_weight, approximate } = parameters;
        const previous = this.#parameters;

        const upstream_changed = Object.values(this.#pca_states).some(state => state.changed);
        const parameters_changed = (
            approximate !== previous.approximate ||
            rna_weight !== previous.rna_weight ||
            adt_weight !== previous.adt_weight ||
            crispr_weight !== previous.crispr_weight
        );

        this.changed = upstream_changed || parameters_changed;
        if (!this.changed) {
            return;
        }

        const weights = { RNA: rna_weight, ADT: adt_weight, CRISPR: crispr_weight };
        const to_use = find_nonzero_upstream_states(this.#pca_states, weights);
        if (to_use.length === 0) {
            // Without this check, the single-embedding branch below would
            // index the upstream states with 'undefined' and fail obscurely.
            throw new Error("at least one available modality should have a positive weight in 'combine_embeddings'");
        }

        // Only store the parameters once the selection is known to be usable,
        // so that a rejected call is not mistaken for "unchanged" on retry.
        previous.approximate = approximate;
        previous.rna_weight = rna_weight;
        previous.adt_weight = adt_weight;
        previous.crispr_weight = crispr_weight;

        if (to_use.length === 1) {
            // A single embedding has nothing to be scaled against, so we
            // just expose a view on its principal components.
            const pcs = this.#pca_states[to_use[0]].fetchPCs();
            this.constructor.createPcsView(this.#cache, pcs);
            return;
        }

        const weight_arr = to_use.map(k => weights[k]);
        const collected = [];
        let total = 0;
        let ncells = null;

        for (const k of to_use) {
            const curpcs = this.#pca_states[k].fetchPCs();
            collected.push(curpcs.principalComponents({ copy: "view" }));

            const cur_ncells = curpcs.numberOfCells();
            if (ncells === null) {
                ncells = cur_ncells;
            } else if (ncells !== cur_ncells) {
                throw new Error("number of cells should be consistent across all embeddings");
            }

            total += curpcs.numberOfPCs();
        }

        // The output has one column per cell and one row per dimension
        // across all of the combined embeddings.
        const buffer = utils.allocateCachedArray(ncells * total, "Float64Array", this.#cache, "combined_buffer");
        scran.scaleByNeighbors(collected, ncells, { buffer: buffer, weights: weight_arr, approximate: approximate });
        this.#cache.num_cells = ncells;
        this.#cache.total_dims = total;
    }
}