scaleByNeighbors.js

import * as wasm from "./wasm.js";
import * as utils from "./utils.js";
import { buildNeighborSearchIndex, BuildNeighborSearchIndexResults } from "./findNearestNeighbors.js"; 

/**
 * Scale embeddings based on the variation between neighboring cells.
 * This aims to equalize the noise across embeddings for the same population of cells across different data modalities,
 * allowing them to be combined into a single embedding for coordinated downstream analyses.
 * Check out the [**mumosa**](https://github.com/libscran/mumosa) documentation for more details.
 *
 * @param {Array} embeddings - Array of Float64WasmArrays containing column-major matrices where rows are dimensions and columns are cells.
 * All entries of this array should contain data for the same number and ordering of cells.
 * @param {number} numberOfCells - Number of cells in all embeddings.
 * @param {object} [options={}] - Optional parameters.
 * @param {number} [options.neighbors=20] - Number of neighbors to use for quantifying variation.
 * Larger values provide a more stable calculation but assume larger subpopulations.
 * @param {?Array} [options.indices=null] - Array of {@linkplain BuildNeighborSearchIndexResults} objects, 
 * where each entry is constructed from the corresponding entry of `embeddings` (see {@linkcode buildNeighborSearchIndex}).
 * This can be used to avoid redundant calculation of indices if they are already available.
 * @param {boolean} [options.asTypedArray=true] - Whether to return a Float64Array.
 * If `false`, a Float64WasmArray is returned instead.
 * @param {?Float64WasmArray} [options.buffer=null] - Array in which to store the combined embedding.
 * This should have length equal to the product of `numberOfCells` and the sum of dimensions of all embeddings.
 * @param {boolean} [options.approximate=true] - Should we construct an approximate search index if `indices` is not supplied?
 * @param {?(Array|TypedArray|Float64WasmArray)} [options.weights=null] - Array of length equal to the number of embeddings, containing a non-negative relative weight for each embedding.
 * This is used to scale each embedding if non-equal noise is desired in the combined embedding.
 * If `null`, all embeddings receive the same weight.
 * @param {?number} [options.numberOfThreads=null] - Number of threads to use.
 * If `null`, defaults to {@linkcode maximumThreads}.
 *
 * @return {Float64Array|Float64WasmArray} Array containing the combined embeddings in column-major format, i.e., dimensions in rows and cells in columns.
 * If `buffer` is supplied, the function returns `buffer` if `asTypedArray = false`, or a view on `buffer` if `asTypedArray = true`.
 */
export function scaleByNeighbors(embeddings, numberOfCells, options = {}) {
    const { neighbors = 20, indices = null, asTypedArray = true, buffer = null, approximate = true, weights = null, numberOfThreads = null, ...others } = options;
    utils.checkOtherOptions(others);

    // Temporaries owned by this call; everything assigned here is released in 'finally'.
    let embed_ptrs;
    let index_ptrs;
    let holding_weights;
    let local_buffer = null;

    // Search indices constructed by this call (as opposed to those supplied via 'options.indices').
    // They are tracked separately so that they can be released once the Wasm call is done,
    // without ever touching indices that belong to the caller.
    const owned_indices = [];

    // Destination of the combined embedding: either the caller's buffer or one allocated below.
    let output = buffer;

    const nthreads = utils.chooseNumberOfThreads(numberOfThreads);

    try {
        // Collect the Wasm heap offset of each embedding into an array that is passed by offset.
        const nembed = embeddings.length;
        embed_ptrs = utils.createBigUint64WasmArray(nembed);
        const embed_arr = embed_ptrs.array();
        for (let i = 0; i < nembed; i++) {
            embed_arr[i] = BigInt(embeddings[i].offset);
        }

        // Derive the number of dimensions of each embedding from its length.
        const ndims = [];
        let total_ndim = 0;
        for (let i = 0; i < nembed; i++) {
            const n = embeddings[i].length;
            const ND = Math.floor(n / numberOfCells);
            if (numberOfCells * ND !== n) {
                throw new Error("length of arrays in 'embeddings' should be a multiple of 'numberOfCells'");
            }
            ndims.push(ND);
            total_ndim += ND;
        }

        let used_indices = indices;
        if (used_indices === null) {
            // Push each index as soon as it is built, so that a failure part-way through
            // still leaves the already-built indices reachable for cleanup.
            for (let i = 0; i < nembed; i++) {
                owned_indices.push(buildNeighborSearchIndex(embeddings[i], { numberOfDims: ndims[i], numberOfCells: numberOfCells, approximate: approximate }));
            }
            used_indices = owned_indices;
        } else {
            if (nembed !== used_indices.length) {
                throw new Error("'indices' and 'embeddings' should have the same length");
            }
            // Loose comparisons are retained here as the return types of the index accessors are not visible in this file.
            for (let i = 0; i < nembed; i++) {
                const index = used_indices[i];
                if (numberOfCells != index.numberOfCells()) {
                    throw new Error("each element of 'indices' should have the same number of cells as 'numberOfCells'");
                }
                if (ndims[i] != index.numberOfDims()) {
                    throw new Error("each element of 'indices' should have the same number of dimensions as its embedding in 'embeddings'");
                }
            }
        }

        // An offset of zero is passed when no weights are supplied; 'use_weights' tells the callee whether to read it.
        let weight_offset = 0;
        let use_weights = false;
        if (weights !== null) {
            use_weights = true;
            holding_weights = utils.wasmifyArray(weights, "Float64WasmArray");
            if (holding_weights.length !== nembed) {
                throw new Error("length of 'weights' should be equal to the number of embeddings");
            }
            weight_offset = holding_weights.offset;
        }

        // Collect the raw pointer of each search index, in the same order as the embeddings.
        index_ptrs = utils.createBigUint64WasmArray(nembed);
        const index_arr = index_ptrs.array();
        for (let i = 0; i < nembed; i++) {
            index_arr[i] = BigInt(used_indices[i].index.$$.ptr);
        }

        const total_len = total_ndim * numberOfCells;
        if (output === null) {
            local_buffer = utils.createFloat64WasmArray(total_len);
            output = local_buffer;
        } else if (total_len !== output.length) {
            throw new Error("length of 'buffer' should be equal to the product of 'numberOfCells' and the total number of dimensions");
        }

        wasm.call(module => module.scale_by_neighbors(
            numberOfCells,
            nembed,
            embed_ptrs.offset,
            index_ptrs.offset,
            output.offset,
            neighbors,
            use_weights,
            weight_offset,
            nthreads
        ));

    } catch (e) {
        // The output buffer is only released on failure; on success it backs the return value.
        utils.free(local_buffer);
        throw e;

    } finally {
        utils.free(embed_ptrs);
        utils.free(index_ptrs);
        utils.free(holding_weights);

        // Release the indices built by this call; previously these were never freed.
        // NOTE(review): assumes the objects returned by buildNeighborSearchIndex() provide
        // the free() method that utils.free() relies on - confirm against findNearestNeighbors.js.
        for (const built of owned_indices) {
            utils.free(built);
        }
    }

    return utils.toTypedArray(output, local_buffer == null, asTypedArray);
}