mnnCorrect.js

import * as utils from "./utils.js";
import { RunPcaResults } from "./runPca.js";
import * as wasm from "./wasm.js";

/**
 * Perform mutual nearest neighbor (MNN) correction on a low-dimensional representation.
 * This is primarily used to remove batch effects.
 *
 * @param {(RunPcaResults|TypedArray|Array|Float64WasmArray)} x - A matrix of low-dimensional results where rows are dimensions and columns are cells.
 * If this is a {@linkplain RunPcaResults} object, the PCs are automatically extracted.
 * Otherwise, the matrix should be provided as an array in column-major form, with specification of `numberOfDims` and `numberOfCells`.
 * @param {(Int32WasmArray|Array|TypedArray)} block - Array containing the block assignment for each cell.
 * This should have length equal to the number of cells and contain all values from 0 to `n - 1` at least once, where `n` is the number of blocks.
 * Each block is treated as a separate batch to be corrected.
 * @param {object} [options={}] - Further optional parameters.
 * @param {?Float64WasmArray} [options.buffer=null] - Buffer of length equal to the product of the number of cells and dimensions,
 * to be used to store the corrected coordinates for each cell.
 * If `null`, this is allocated and returned by the function.
 * @param {?number} [options.numberOfDims=null] - Number of dimensions in `x`.
 * This should be specified if an array-like object is provided, otherwise it is ignored.
 * @param {?number} [options.numberOfCells=null] - Number of cells in `x`.
 * This should be specified if an array-like object is provided, otherwise it is ignored.
 * @param {number} [options.k=15] - Number of neighbors to use in the MNN search.
 * @param {number} [options.numberOfMADs=3] - Number of MADs to use to define the threshold on the distances to the neighbors,
 * see comments [here](https://ltla.github.io/CppMnnCorrect).
 * @param {number} [options.robustIterations=2] - Number of robustness iterations to use for computing the center of mass,
 * see comments [here](https://ltla.github.io/CppMnnCorrect).
 * @param {number} [options.robustTrim=0.25] - Proportion of furthest observations to remove during robustness iterations,
 * see comments [here](https://ltla.github.io/CppMnnCorrect).
 * @param {string} [options.referencePolicy="max-rss"] - What policy to use to choose the first reference batch.
 * This can be the largest batch (`"max-size"`), the most variable batch (`"max-variance"`), the batch with the highest RSS (`"max-rss"`) or batch 0 in `block` (`"input"`).
 * @param {boolean} [options.approximate=true] - Whether to perform an approximate nearest neighbor search.
 * @param {?number} [options.numberOfThreads=null] - Number of threads to use.
 * If `null`, defaults to {@linkcode maximumThreads}.
 *
 * @return {Float64WasmArray} Array of length equal to `x`, containing the batch-corrected low-dimensional coordinates for all cells.
 * Values are organized using the column-major layout.
 * This is equal to `buffer` if provided.
 *
 * @throws {Error} If the length of `x` (for array-like inputs), `buffer` or `block` is inconsistent with the number of dimensions and cells.
 */
export function mnnCorrect(x, block, {
    buffer = null,
    numberOfDims = null,
    numberOfCells = null,
    k = 15,
    numberOfMADs = 3,
    robustIterations = 2,
    robustTrim = 0.25,
    referencePolicy = "max-rss",
    approximate = true,
    numberOfThreads = null
} = {}) {

    // Temporaries owned by this call; each is released below regardless of outcome.
    let local_buffer;
    let x_data;
    let block_data;
    const nthreads = utils.chooseNumberOfThreads(numberOfThreads);

    // Destination for the corrected coordinates: the caller's buffer, or one allocated here.
    let output = buffer;

    try {
        let coords;
        if (x instanceof RunPcaResults) {
            // Dimensions come from the PCA results; any user-supplied values are ignored.
            numberOfDims = x.numberOfPCs();
            numberOfCells = x.numberOfCells();
            coords = x.principalComponents({ copy: "view" });
        } else {
            if (numberOfDims === null || numberOfCells === null || numberOfDims * numberOfCells !== x.length) {
                throw new Error("length of 'x' must be equal to the product of 'numberOfDims' and 'numberOfCells'");
            }
            x_data = utils.wasmifyArray(x, "Float64WasmArray");
            coords = x_data;
        }

        const expectedLength = numberOfDims * numberOfCells;
        if (output == null) {
            local_buffer = utils.createFloat64WasmArray(expectedLength);
            output = local_buffer;
        } else if (output.length !== expectedLength) {
            throw new Error("length of 'buffer' must be equal to the product of the number of dimensions and cells");
        }

        block_data = utils.wasmifyArray(block, "Int32WasmArray");
        if (block_data.length !== numberOfCells) {
            throw new Error("'block' must be of length equal to the number of cells in 'x'");
        }

        wasm.call(module => module.mnn_correct(
            numberOfDims,
            numberOfCells,
            coords.offset,
            block_data.offset,
            output.offset,
            k,
            numberOfMADs,
            robustIterations,
            robustTrim,
            referencePolicy,
            approximate,
            nthreads
        ));

    } catch (e) {
        // Only release the output if we allocated it; a caller-supplied buffer remains theirs.
        utils.free(local_buffer);
        throw e;

    } finally {
        // Both wasmified inputs are temporaries and must be released on every path;
        // previously 'block_data' was never freed.
        utils.free(x_data);
        utils.free(block_data);
    }

    return output;
}