# Source code for archepy.core.temporal

"""
Temporal Multi-Subject Archetypal Analysis (MS-AA).

Direct Python port of ``MultiSubjectAA_T.m`` from the upstream Multisubject
Archetypal Analysis Toolbox.

Original MATLAB authors:
    Jesper L. Hinrich, Sophia E. Bardenfleth, Morten Mørup
    Copyright (C) 2016 Technical University of Denmark.
    Distributed under the terms of the Multisubject Archetypal Analysis
    Toolbox license (see LICENSE).
"""

from __future__ import annotations

import time
from dataclasses import dataclass
from typing import Any

import numpy as np

try:
    import cupy as cp
except ImportError:
    cp = None

from archepy._utils import mgetopt, to_numpy
from archepy.core._s_update import supdate_indi_step
from archepy.init.furthest_sum import furthest_sum


@dataclass
class SubjectT:
    """
    Per-subject input data for temporal MS-AA.

    Instances are plain (non-frozen, non-slotted) dataclasses on purpose:
    :func:`multi_subject_aa_T` attaches its working quantities (``V``,
    ``SST``, ``sigmaSq``, ``S``, ...) to each subject as extra attributes
    while fitting.

    Attributes
    ----------
    X : ndarray, shape (V, T)
        Voxels × time points.
    sX : ndarray, shape (V, sT)
        Voxels × "tilde" time points.
    """

    X: np.ndarray
    sX: np.ndarray
def multi_subject_aa_T(
    subj: list[SubjectT],
    noc: int,
    opts: dict[str, Any] | None = None,
) -> tuple[list[dict[str, Any]], np.ndarray, np.ndarray, float, float]:
    """
    Fit Multi-Subject Archetypal Analysis (temporal variant).

    See :func:`archepy.multi_subject_aa` for the spatial variant and a
    description of the ``opts`` dictionary; the same options apply here.

    Parameters
    ----------
    subj : list of SubjectT
        One entry per subject. Each subject is modified in place: working
        quantities (``V``, ``SST``, ``sigmaSq``, ``S``, ``muS``, ``sXC``,
        ``XCtX``, ``CtXtXC``, ``SSt``, ``XSt``, ``SST_sigmaSq``, ``NLL``) are
        attached as attributes, and when ``use_gpu`` is set ``X`` / ``sX`` are
        replaced by CuPy arrays.
    noc : int
        Number of archetypes (columns of ``C``).
    opts : dict, optional
        Option keys read here, with their defaults: ``conv_crit`` (1e-6),
        ``maxiter`` (100), ``fix_var_iter`` (5), ``use_gpu`` (False),
        ``heteroscedastic`` (True), ``numCstep`` (10), ``numSstep`` (20),
        ``sort_crit`` ("corr"), ``init`` ("FurthestSum"), ``initSstep`` (250),
        ``rngSEED`` (None) and ``noise_threshold`` (None, meaning
        ``1e-3 * SST / (sum(V) * T)``).

    Returns
    -------
    results_subj : list of dict
        Per subject: ``S``, ``sXC``, ``sigmaSq``, ``NLL``, ``SSE``, ``SST``
        and ``SST_sigmaSq``.
    C : ndarray, shape (sT, noc)
        Archetype generator matrix (always returned as a NumPy array).
    cost_fun : ndarray
        Negative log-likelihood recorded at the start of each iteration that
        was actually run.
    varexpl : float
        ``(SST - sum(SSE)) / SST``.
    elapsed : float
        Wall-clock seconds spent in the main loop.

    Raises
    ------
    ImportError
        If ``opts['use_gpu']`` is true but CuPy is not installed.
    """
    if opts is None:
        opts = {}

    conv_crit = mgetopt(opts, "conv_crit", 1e-6)
    maxiter = int(mgetopt(opts, "maxiter", 100))
    fix_var_iter = int(mgetopt(opts, "fix_var_iter", 5))
    runGPU = bool(mgetopt(opts, "use_gpu", False))
    voxelVariance = bool(mgetopt(opts, "heteroscedastic", True))
    numCstep = int(mgetopt(opts, "numCstep", 10))
    numSstep = int(mgetopt(opts, "numSstep", 20))
    sort_crit = mgetopt(opts, "sort_crit", "corr")
    init_type = mgetopt(opts, "init", "FurthestSum")
    initial_S_steps = int(mgetopt(opts, "initSstep", 250))
    rngSEED = mgetopt(opts, "rngSEED", None)

    if runGPU and cp is None:
        raise ImportError(
            "opts['use_gpu']=True requested but CuPy is not installed. "
            "Install with: pip install archepy[gpu]"
        )

    rng = np.random.default_rng(rngSEED)
    T = subj[0].X.shape[1]
    sT = subj[0].sX.shape[1]
    B = len(subj)
    V_list = [s.sX.shape[0] for s in subj]
    xp = cp if runGPU else np

    # ---- Initialize C (sT x noc) ----
    if init_type.lower() == "furthestsum":
        # Stack every subject's sX row-wise so FurthestSum picks the seed
        # columns (time points) jointly across subjects.
        total_V = sum(V_list)
        Xcombined = np.zeros((total_V, sT), dtype=float)
        off = 0
        for s in subj:
            vv = s.sX.shape[0]
            Xcombined[off : off + vv, :] = np.asarray(s.sX, dtype=float)
            off += vv
        seed = rng.integers(low=0, high=sT)
        idx = furthest_sum(
            Xcombined, noc=noc, i=int(seed), exclude=None, treat_as_kernel=False
        )
        C = xp.zeros((sT, noc), dtype=float)
        C[idx, xp.arange(noc)] = 1.0
        muC = xp.array(1.0)
    else:
        # Random non-negative columns, normalised to sum to one.
        C = xp.asarray(rng.random((sT, noc)), dtype=float)
        C /= C.sum(axis=0, keepdims=True) + xp.finfo(float).eps
        muC = xp.array(1.0)

    # ---- Move to GPU & SST ----
    for s in subj:
        if runGPU:
            s.X = cp.asarray(s.X, dtype=float)
            s.sX = cp.asarray(s.sX, dtype=float)
        s.V = s.sX.shape[0]
        s.SST = (s.X * s.X).sum()
    SST = float(sum(float(s.SST.get()) if runGPU else float(s.SST) for s in subj))

    # ---- Initialize per-subject ----
    for s in subj:
        if voxelVariance:
            s.sigmaSq = xp.ones((s.V, 1), dtype=float) * (SST / (sum(V_list) * T))
        else:
            s.sigmaSq = xp.ones((s.V, 1), dtype=float)

        # Columns of S start as normalised -log(uniform) draws.
        U = xp.asarray(rng.random((noc, T)), dtype=float)
        s.S = -xp.log(U + xp.finfo(float).tiny)
        s.S /= s.S.sum(axis=0, keepdims=True) + xp.finfo(float).eps
        s.muS = xp.ones((1,), dtype=float)

        s.sXC = s.sX @ C
        s.XCtX = s.sXC.T @ s.X
        s.CtXtXC = s.sXC.T @ s.sXC
        s.SSt = s.S @ s.S.T
        s.XSt = s.X @ s.S.T

        # Warm-up S update; the step sizes it returns are discarded and muS
        # is reset to ones below.
        S_np, _muS_np, SSt_np = supdate_indi_step(
            to_numpy(s.S, runGPU),
            to_numpy(s.XCtX, runGPU),
            to_numpy(s.CtXtXC, runGPU),
            np.ones(s.S.shape[1], dtype=float),
            int(s.V),
            int(initial_S_steps),
        )
        s.S = xp.asarray(S_np)
        s.muS = xp.ones((s.S.shape[1],), dtype=float)
        s.SSt = xp.asarray(SSt_np)

    # ---- Initial NLL ----
    NLL = 0.0
    for s in subj:
        inv_sqrt_sig = 1.0 / xp.sqrt(s.sigmaSq)
        s.XSt = (s.X @ s.S.T) * inv_sqrt_sig
        s.sXC = (s.sX @ C) * inv_sqrt_sig
        s.CtXtXC = s.sXC.T @ s.sXC
        s.XCtX = ((s.sXC * inv_sqrt_sig).T) @ s.X
        s.SST_sigmaSq = (s.X * (s.X / s.sigmaSq)).sum()
        s.NLL = (
            0.5 * s.SST_sigmaSq
            - (s.sXC * s.XSt).sum()
            + 0.5 * (s.CtXtXC * s.SSt).sum()
            + (T / 2.0) * (s.V * np.log(2.0 * np.pi) + xp.log(s.sigmaSq).sum())
        )
        NLL += float(s.NLL.get()) if runGPU else float(s.NLL)

    t_start = time.perf_counter()
    cost_fun = np.zeros((maxiter,), dtype=float)

    # Lower bound applied to the per-voxel noise variance estimates.
    noise_threshold_opt = mgetopt(opts, "noise_threshold", None)
    if noise_threshold_opt is None:
        var_threshold = (SST / (sum(V_list) * T)) * 1e-3
    else:
        var_threshold = float(noise_threshold_opt)

    # ---- Main loop ----
    iter_ = 0
    dNLL = np.inf
    # Keep iterating until the relative NLL change is below conv_crit, but
    # never stop while iter_ <= fix_var_iter; maxiter is a hard cap.
    while (
        (abs(dNLL) >= conv_crit * abs(NLL)) or (fix_var_iter >= iter_)
    ) and (iter_ < maxiter):
        iter_ += 1
        NLL_old = NLL
        cost_fun[iter_ - 1] = NLL

        C, muC, NLL = _Cupdate_multi_subjects_T(subj, C, muC, NLL, numCstep, runGPU)

        # Refresh the C-dependent caches for the S update.
        for s in subj:
            inv_sqrt_sig = 1.0 / xp.sqrt(s.sigmaSq)
            s.sXC = (s.sX @ C) * inv_sqrt_sig
            s.XCtX = ((s.sXC * inv_sqrt_sig).T) @ s.X
            s.CtXtXC = s.sXC.T @ s.sXC
            s.NLL = (
                0.5 * s.SST_sigmaSq
                - (s.XCtX * s.S).sum()
                + 0.5 * (s.CtXtXC * s.SSt).sum()
                + (T / 2.0) * (s.V * np.log(2.0 * np.pi) + xp.log(s.sigmaSq).sum())
            )

        NLL = 0.0
        for s in subj:
            S_np, muS_np, SSt_np = supdate_indi_step(
                to_numpy(s.S, runGPU),
                to_numpy(s.XCtX, runGPU),
                to_numpy(s.CtXtXC, runGPU),
                to_numpy(s.muS, runGPU),
                int(s.V),
                int(numSstep),
            )
            s.S = xp.asarray(S_np)
            s.muS = xp.asarray(muS_np)
            s.SSt = xp.asarray(SSt_np)

            if voxelVariance and (iter_ > fix_var_iter):
                # Re-estimate the per-voxel noise variance from the residual
                # and rebuild every cache that is scaled by it.
                resid = s.X - (s.sX @ C) @ s.S
                s.sigmaSq = (resid * resid).sum(axis=1, keepdims=True) / float(T)
                s.sigmaSq = xp.maximum(s.sigmaSq, var_threshold)
                inv_sqrt_sig = 1.0 / xp.sqrt(s.sigmaSq)
                s.XSt = (s.X @ s.S.T) * inv_sqrt_sig
                s.sXC = (s.sX @ C) * inv_sqrt_sig
                s.XCtX = ((s.sXC * inv_sqrt_sig).T) @ s.X
                s.CtXtXC = s.sXC.T @ s.sXC
                s.SST_sigmaSq = (s.X * (s.X / s.sigmaSq)).sum()
            else:
                # Variance is fixed: only the S-dependent cache changes.
                inv_sqrt_sig = 1.0 / xp.sqrt(s.sigmaSq)
                s.XSt = (s.X @ s.S.T) * inv_sqrt_sig

            s.NLL = (
                0.5 * s.SST_sigmaSq
                - (s.sXC * s.XSt).sum()
                + 0.5 * (s.CtXtXC * s.SSt).sum()
                + (T / 2.0) * (s.V * np.log(2.0 * np.pi) + xp.log(s.sigmaSq).sum())
            )
            NLL += float(s.NLL.get()) if runGPU else float(s.NLL)

        dNLL = NLL_old - NLL

    elapsed = time.perf_counter() - t_start

    # ---- Explained variance ----
    SSE = []
    for s in subj:
        recon = (s.sX @ C) @ s.S
        if runGPU:
            sse = float(cp.linalg.norm(s.X - recon, ord="fro") ** 2)
        else:
            sse = float(np.linalg.norm(s.X - recon, ord="fro") ** 2)
        SSE.append(sse)
    varexpl = (SST - sum(SSE)) / SST

    # ---- Sort archetypes by mean between-subject correlation ----
    ind = np.arange(noc)
    same_V = sum(V_list) == max(V_list) * B
    # Pairwise correlation needs at least two subjects: with B == 1
    # np.corrcoef returns a 0-d value and indexing it with the (empty)
    # upper-triangle indices raises IndexError, so keep the original order.
    if sort_crit.lower() == "corr" and same_V and B > 1:
        arch = np.zeros((V_list[0], B))
        mean_corr = np.zeros((noc,))
        iu = np.triu_indices(B, k=1)
        for j in range(noc):
            for bi, s in enumerate(subj):
                arch[:, bi] = to_numpy(s.sXC[:, j], runGPU)
            Ccorr = np.corrcoef(arch, rowvar=False)
            mean_corr[j] = float(Ccorr[iu].mean())
        ind = np.argsort(mean_corr)[::-1]

    if not np.array_equal(ind, np.arange(noc)):
        C = C[:, ind]
        for s in subj:
            s.S = s.S[ind, :]
            s.sXC = s.sXC[:, ind]

    # ---- Collect results as NumPy / Python scalars ----
    C_np = to_numpy(C, runGPU)
    results_subj: list[dict[str, Any]] = []
    for bi, s in enumerate(subj):
        out = {
            "S": to_numpy(s.S, runGPU),
            "sXC": to_numpy(s.sXC, runGPU),
            "sigmaSq": to_numpy(s.sigmaSq, runGPU),
            "NLL": float(to_numpy(s.NLL, runGPU)),
            "SSE": float(SSE[bi]),
            "SST": float(to_numpy(s.SST, runGPU)),
            "SST_sigmaSq": float(to_numpy(s.SST_sigmaSq, runGPU)),
        }
        results_subj.append(out)

    return results_subj, C_np, cost_fun[:iter_], float(varexpl), float(elapsed)
# ------------------------- helpers -------------------------
def _Cupdate_multi_subjects_T(
    subj: list[Any],
    C,
    muC,
    NLL: float,
    niter: int,
    runGPU: bool,
):
    """
    Run ``niter`` gradient steps on the shared matrix ``C``.

    Each step subtracts ``muC`` times the gradient (summed over subjects,
    with the per-column component ``sum(g * C)`` removed), clips ``C`` at
    zero and renormalises its columns to sum to one. The step is accepted
    when the summed negative log-likelihood does not increase by more than a
    relative ``1e-9``; otherwise ``muC`` is halved and the step retried. An
    accepted step grows ``muC`` by a factor 1.2.

    Parameters
    ----------
    subj : list
        Subject objects providing ``X``, ``sX``, ``V``, ``sigmaSq``, ``XSt``,
        ``SSt``, ``sXC`` and ``SST_sigmaSq``. ``sXC`` and ``CtXtXC`` are
        overwritten in place for every candidate ``C``.
    C : array, shape (sT, noc)
        Current generator matrix.
    muC : scalar or 0-d array
        Current step size.
    NLL : float
        Negative log-likelihood at the incoming ``C``.
    niter : int
        Number of gradient steps to attempt.
    runGPU : bool
        Use CuPy (``True``) or NumPy (``False``) as the array module.

    Returns
    -------
    C, muC, NLL
        Updated generator matrix, step size and negative log-likelihood.

    Warns
    -----
    RuntimeWarning
        If backtracking shrinks the step size to zero without finding an
        acceptable step (for example when the objective is NaN). The previous
        ``C`` is restored and the update stops early.
    """
    import warnings

    xp = cp if runGPU else np
    sT, noc = C.shape
    T = subj[0].X.shape[1]
    log_2pi = np.log(2.0 * np.pi)
    min_step = np.finfo(float).tiny

    # Quantities that do not depend on C: computed once per call.
    inv_sqrt_sigs = [1.0 / xp.sqrt(s.sigmaSq) for s in subj]
    XtXSt_list = [s.sX.T @ (s.XSt * inv) for s, inv in zip(subj, inv_sqrt_sigs)]
    const_terms = [
        (T / 2.0) * (s.V * log_2pi + xp.log(s.sigmaSq).sum()) for s in subj
    ]
    total_V = sum(int(s.V) for s in subj)

    def _refresh_and_score(C_try) -> float:
        # Update each subject's C-dependent caches and return the summed NLL.
        total = 0.0
        for s, inv, const in zip(subj, inv_sqrt_sigs, const_terms):
            s.sXC = (s.sX @ C_try) * inv
            s.CtXtXC = s.sXC.T @ s.sXC
            term = (
                0.5 * s.SST_sigmaSq
                - (s.sXC * s.XSt).sum()
                + 0.5 * (s.CtXtXC * s.SSt).sum()
                + const
            )
            total += float(term.get()) if runGPU else float(term)
        return total

    for _ in range(niter):
        NLL_old = NLL
        # Relative tolerance must use |NLL_old|: the NLL can be negative, and
        # NLL_old * (1 + 1e-9) would then demand a strict decrease, which
        # makes the backtracking loop spin forever at a stationary point.
        accept_below = NLL_old + 1e-9 * abs(NLL_old)

        g = xp.zeros((sT, noc), dtype=float)
        for s, inv, XtXSt in zip(subj, inv_sqrt_sigs, XtXSt_list):
            g += s.sX.T @ ((s.sXC * inv) @ s.SSt) - XtXSt
        g /= total_V * sT
        # Remove the per-column component sum(g * C) from the gradient.
        g = g - (g * C).sum(axis=0, keepdims=True)

        Cold = C.copy()
        while True:
            C = Cold - muC * g
            C = xp.maximum(C, 0.0)
            C /= C.sum(axis=0, keepdims=True) + xp.finfo(float).eps

            NLL_tmp = _refresh_and_score(C)
            if NLL_tmp <= accept_below:
                muC = muC * 1.2
                NLL = NLL_tmp
                break

            muC = muC / 2.0
            if not (float(muC) >= min_step):
                # Step size vanished (or is NaN): no candidate can be
                # accepted any more. Restore the last good C and give up
                # instead of looping forever.
                warnings.warn(
                    "C update line search failed to find a non-increasing "
                    "step; keeping the previous C.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                C = Cold
                _refresh_and_score(C)
                return C, muC, NLL

    return C, muC, NLL