"""
Spatial Multi-Subject Archetypal Analysis (MS-AA).
Direct Python port of ``MultiSubjectAA.m`` from the upstream Multisubject
Archetypal Analysis Toolbox.
Original MATLAB authors:
Jesper L. Hinrich, Sophia E. Bardenfleth, Morten Mørup
Copyright (C) 2016 Technical University of Denmark.
Distributed under the terms of the Multisubject Archetypal Analysis
Toolbox license (see LICENSE).
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import Any
import numpy as np
try:
import cupy as cp
except ImportError:
cp = None
from archepy._utils import mgetopt, to_float, to_numpy
from archepy.core._s_update import supdate_indi_step
from archepy.init.furthest_sum import furthest_sum
[docs]
@dataclass
class Subject:
    """
    Container for one subject's data plus the working state of the fit.

    Only ``X`` and ``sX`` are supplied by the caller; every other attribute
    starts at its default and is filled in by ``multi_subject_aa``.

    Parameters
    ----------
    X : ndarray, shape (T, V)
        Observed data: time points × voxels.
    sX : ndarray, shape (T, sV)
        "Tilde" data used to construct archetypes (often equal to X).

    Attributes
    ----------
    T, SST, sigmaSq, S, muS, sXC, XCtX, CtXtXC, SSt, XSt
        Working quantities; ``None`` until the fit assigns them.
    SST_sigmaSq, NLL : float
        Scalar statistics; ``0.0`` until the fit assigns them.
    """

    # --- caller-provided data -------------------------------------------
    X: np.ndarray
    sX: np.ndarray

    # --- array-valued working state (unset until fitting) ---------------
    T: Any = None
    SST: Any = None
    sigmaSq: Any = None
    S: Any = None
    muS: Any = None
    sXC: Any = None
    XCtX: Any = None
    CtXtXC: Any = None
    SSt: Any = None
    XSt: Any = None

    # --- scalar working state -------------------------------------------
    SST_sigmaSq: float = 0.0
    NLL: float = 0.0
[docs]
def multi_subject_aa(
    subj: list[Subject],
    noc: int,
    opts: dict[str, Any] | None = None,
) -> tuple[list[dict[str, Any]], np.ndarray, np.ndarray, float, float]:
    """
    Fit Multi-Subject Archetypal Analysis (spatial variant).

    Parameters
    ----------
    subj : list of Subject
        One :class:`Subject` per subject. Each provides ``X`` (T × V) and
        ``sX`` (T × sV); they may be the same array. All subjects must share
        ``V`` and ``sV``; ``T`` may differ between subjects.
    noc : int
        Number of archetypes (K).
    opts : dict, optional
        Configuration. Keys (with defaults):

        - ``conv_crit`` (1e-6) — relative-NLL convergence threshold
        - ``maxiter`` (100) — maximum outer iterations
        - ``fix_var_iter`` (5) — keep ``sigmaSq`` fixed for this many iters
        - ``use_gpu`` (False) — run on GPU via CuPy
        - ``heteroscedastic`` (True) — per-voxel noise variance
        - ``numCstep`` (10) — inner C-update steps per outer iter
        - ``numSstep`` (20) — inner S-update steps per outer iter
        - ``sort_crit`` ("corr") — how to order archetypes at the end
        - ``init`` ("FurthestSum") — initialization scheme; any other value
          selects a random column-normalised ``C``
        - ``initSstep`` (250) — extra S-update steps during init
        - ``rngSEED`` (None) — RNG seed
        - ``noise_threshold`` (None) — lower bound for ``sigmaSq``; when
          ``None`` it defaults to 1% of ``SST / (sum(T) * V)``

    Returns
    -------
    results : list of dict
        Per-subject outputs: ``S``, ``sXC``, ``sigmaSq``, ``NLL``, ``SSE``,
        ``SST``, ``SST_sigmaSq``.
    C : ndarray, shape (sV, K)
        Shared archetype generator (columns sum to 1).
    cost_fun : ndarray, shape (n_iter,)
        Negative log-likelihood recorded at the start of each outer
        iteration.
    varexpl : float
        Variance explained, ``(SST - sum(SSE)) / SST``.
    elapsed : float
        Wall-clock time in seconds of the main loop (initialisation is not
        included).

    Raises
    ------
    ValueError
        If ``subj`` is empty, ``noc < 1``, or the subjects' array shapes are
        mutually inconsistent.
    ImportError
        If ``opts['use_gpu']`` is true but CuPy is not installed.

    Notes
    -----
    The ``Subject`` objects are modified in place: their working fields are
    overwritten, and with ``use_gpu`` their ``X`` / ``sX`` are replaced by
    CuPy arrays.

    Archetypes are reordered by mean between-subject correlation only when
    ``sort_crit == "corr"``, all subjects have the same ``T``, and there are
    at least two subjects (the criterion is undefined for a single subject).
    """
    if opts is None:
        opts = {}
    if not subj:
        raise ValueError("subj must contain at least one Subject")
    if noc < 1:
        raise ValueError(f"noc must be a positive integer, got {noc!r}")

    conv_crit = mgetopt(opts, "conv_crit", 1e-6)
    maxiter = int(mgetopt(opts, "maxiter", 100))
    fix_var_iter = int(mgetopt(opts, "fix_var_iter", 5))
    runGPU = bool(mgetopt(opts, "use_gpu", False))
    voxelVariance = bool(mgetopt(opts, "heteroscedastic", True))
    numCstep = int(mgetopt(opts, "numCstep", 10))
    numSstep = int(mgetopt(opts, "numSstep", 20))
    sort_crit = mgetopt(opts, "sort_crit", "corr")
    init_type = mgetopt(opts, "init", "FurthestSum")
    initial_S_steps = int(mgetopt(opts, "initSstep", 250))
    rngSEED = mgetopt(opts, "rngSEED", None)

    if runGPU and cp is None:
        raise ImportError(
            "opts['use_gpu']=True but CuPy is not installed. Install with: pip install archepy[gpu]"
        )

    rng = np.random.default_rng(rngSEED)
    V = subj[0].X.shape[1]
    sV = subj[0].sX.shape[1]
    B = len(subj)

    # Fail early with a readable message instead of a late broadcast error:
    # C is shared (needs a common sV) and S / sigmaSq are sized from V.
    for bi, s in enumerate(subj):
        if s.X.shape[1] != V or s.sX.shape[1] != sV:
            raise ValueError(
                f"subj[{bi}]: X and sX must have {V} and {sV} columns "
                f"(as subj[0]), got {s.X.shape[1]} and {s.sX.shape[1]}"
            )
        if s.X.shape[0] != s.sX.shape[0]:
            raise ValueError(
                f"subj[{bi}]: X and sX must have the same number of rows, "
                f"got {s.X.shape[0]} and {s.sX.shape[0]}"
            )

    T_list = [s.sX.shape[0] for s in subj]
    total_T = sum(T_list)
    xp = cp if runGPU else np

    # ------------------------------------------------------------------ #
    # Initialise C                                                         #
    # ------------------------------------------------------------------ #
    if init_type.lower() == "furthestsum":
        # Stack every subject's sX along the time axis so the seeding sees
        # all subjects at once.
        Xcombined = np.zeros((total_T, sV), dtype=float)
        offset = 0
        for s in subj:
            T = s.sX.shape[0]
            Xcombined[offset : offset + T, :] = np.asarray(s.sX, dtype=float)
            offset += T
        seed = int(rng.integers(low=0, high=sV))
        if runGPU:
            from archepy.init._gpu import furthest_sum_gpu

            print(f"[init] Running furthest_sum_gpu (K={noc}, N={total_T})...")
            idx = furthest_sum_gpu(
                Xcombined,
                noc=noc,
                i=seed,
                exclude=None,
                treat_as_kernel=False,
                one_based=False,
            )
        else:
            print(f"[init] Running furthest_sum CPU (K={noc}, N={total_T})...")
            idx = furthest_sum(
                Xcombined,
                noc=noc,
                i=seed,
                exclude=None,
                treat_as_kernel=False,
            )
        idx = np.asarray(idx, dtype=int)
        # One-hot columns: archetype k starts as the selected column idx[k].
        C = xp.zeros((sV, noc), dtype=float)
        C[xp.asarray(idx), xp.arange(noc)] = 1.0
    else:
        print(f"[init] Random initialisation (K={noc})...")
        C = xp.asarray(rng.random((sV, noc)), dtype=float)
        C /= C.sum(axis=0, keepdims=True) + xp.finfo(float).eps
    muC = xp.array(1.0)

    # ------------------------------------------------------------------ #
    # Move subject arrays to GPU and precompute SST                        #
    # ------------------------------------------------------------------ #
    for s in subj:
        if runGPU:
            s.X = cp.asarray(s.X, dtype=float)
            s.sX = cp.asarray(s.sX, dtype=float)
        s.T = s.sX.shape[0]
        s.SST = (s.X * s.X).sum()
    SST = float(sum(to_float(s.SST, runGPU) for s in subj))
    mean_var = SST / (total_T * V)

    # ------------------------------------------------------------------ #
    # Initialise per-subject quantities                                    #
    # ------------------------------------------------------------------ #
    for s in subj:
        s.sigmaSq = xp.ones((V, 1), dtype=float)
        if voxelVariance:
            s.sigmaSq = s.sigmaSq * mean_var
        # -log(U) draws, normalised per column so each column of S sums to 1.
        U = xp.asarray(rng.random((noc, V)), dtype=float)
        s.S = -xp.log(U + xp.finfo(float).tiny)
        s.S /= s.S.sum(axis=0, keepdims=True) + xp.finfo(float).eps
        s.sXC = s.sX @ C
        s.XCtX = s.sXC.T @ s.X
        s.CtXtXC = s.sXC.T @ s.sXC
        S_np, _, SSt_np = supdate_indi_step(
            to_numpy(s.S, runGPU),
            to_numpy(s.XCtX, runGPU),
            to_numpy(s.CtXtXC, runGPU),
            np.ones(s.S.shape[1], dtype=float),
            int(s.T),
            int(initial_S_steps),
            # reshape(-1) stays 1-D even when V == 1 (squeeze() would give 0-d)
            to_numpy(s.sigmaSq.reshape(-1), runGPU),
        )
        s.S = xp.asarray(S_np)
        # Step sizes are reset to 1 after the warm-up S steps.
        s.muS = xp.ones((s.S.shape[1],), dtype=float)
        s.SSt = xp.asarray(SSt_np)

    # ------------------------------------------------------------------ #
    # Initial NLL                                                          #
    # ------------------------------------------------------------------ #
    NLL = 0.0
    for s in subj:
        s.XSt = s.X @ (s.S / s.sigmaSq.T).T
        s.SSt = s.S @ (s.S / s.sigmaSq.T).T
        s.SST_sigmaSq = to_float((s.X * (s.X / s.sigmaSq.T)).sum(), runGPU)
        s.NLL = _subject_nll(s, V, xp, runGPU)
        NLL += s.NLL

    t_start = time.perf_counter()
    cost_fun = np.zeros((maxiter,), dtype=float)
    noise_threshold_opt = mgetopt(opts, "noise_threshold", None)
    var_threshold = (
        float(noise_threshold_opt) if noise_threshold_opt is not None else mean_var * 1e-2
    )

    # ------------------------------------------------------------------ #
    # Main optimisation loop                                               #
    # ------------------------------------------------------------------ #
    iter_ = 0
    dNLL = np.inf
    # Keep iterating while not converged, and always for at least
    # fix_var_iter + 1 iterations, but never beyond maxiter.
    while (abs(dNLL) >= conv_crit * abs(NLL) or fix_var_iter >= iter_) and iter_ < maxiter:
        iter_ += 1
        NLL_old = NLL
        cost_fun[iter_ - 1] = NLL

        # ---- C update ----
        C, muC, NLL = _Cupdate_multi_subjects(subj, C, muC, NLL, numCstep, runGPU)
        for s in subj:
            s.sXC = s.sX @ C
            s.XCtX = s.sXC.T @ s.X
            s.CtXtXC = s.sXC.T @ s.sXC

        # ---- S update (and, once unfrozen, the noise variances) ----
        NLL = 0.0
        update_var = voxelVariance and iter_ > fix_var_iter
        for s in subj:
            S_np, muS_np, SSt_np = supdate_indi_step(
                to_numpy(s.S, runGPU),
                to_numpy(s.XCtX, runGPU),
                to_numpy(s.CtXtXC, runGPU),
                to_numpy(s.muS, runGPU),
                int(s.T),
                int(numSstep),
                to_numpy(s.sigmaSq.reshape(-1), runGPU),
            )
            s.S = xp.asarray(S_np)
            s.muS = xp.asarray(muS_np)
            s.SSt = xp.asarray(SSt_np)
            if update_var:
                resid = s.X - (s.sXC @ s.S)
                s.sigmaSq = (resid * resid).sum(axis=0, keepdims=True).T / float(s.T)
                s.sigmaSq = xp.maximum(s.sigmaSq, var_threshold)
                # sigmaSq changed, so every sigma-weighted statistic is stale.
                s.XSt = s.X @ (s.S / s.sigmaSq.T).T
                s.SSt = s.S @ (s.S / s.sigmaSq.T).T
                s.SST_sigmaSq = to_float((s.X * (s.X / s.sigmaSq.T)).sum(), runGPU)
            else:
                s.XSt = s.X @ (s.S / s.sigmaSq.T).T
            s.NLL = _subject_nll(s, V, xp, runGPU)
            NLL += s.NLL

        dNLL = NLL_old - NLL
        if iter_ % 5 == 0:
            print(f"  iter {iter_:4d} | NLL {NLL:.4e} | dNLL/NLL {dNLL / abs(NLL):.4e}")

    # ------------------------------------------------------------------ #
    # Wrap-up                                                              #
    # ------------------------------------------------------------------ #
    elapsed = time.perf_counter() - t_start
    SSE = [float(xp.linalg.norm(s.X - s.sXC @ s.S, ord="fro") ** 2) for s in subj]
    varexpl = (SST - sum(SSE)) / SST

    ind = np.arange(noc)
    same_T = total_T == max(T_list) * B
    # B > 1: with one subject np.corrcoef returns a scalar and there are no
    # off-diagonal correlations to average, so the ordering is left as is.
    if sort_crit.lower() == "corr" and same_T and B > 1:
        arch = np.zeros((T_list[0], B))
        mean_corr = np.zeros((noc,))
        iu = np.triu_indices(B, k=1)
        for j in range(noc):
            for bi, s in enumerate(subj):
                arch[:, bi] = to_numpy(s.sXC[:, j], runGPU)
            Ccorr = np.corrcoef(arch, rowvar=False)
            mean_corr[j] = float(Ccorr[iu].mean())
        # Most consistent archetype (highest mean correlation) first.
        ind = np.argsort(mean_corr)[::-1]

    if not np.array_equal(ind, np.arange(noc)):
        perm = xp.asarray(ind)
        C = C[:, perm]
        for s in subj:
            s.S = s.S[perm, :]
            s.sXC = s.sXC[:, perm]

    C_np = to_numpy(C, runGPU)
    results_subj: list[dict[str, Any]] = [
        {
            "S": to_numpy(s.S, runGPU),
            "sXC": to_numpy(s.sXC, runGPU),
            "sigmaSq": to_numpy(s.sigmaSq, runGPU).reshape(-1, 1),
            "NLL": s.NLL,
            "SSE": float(sse),
            "SST": to_float(s.SST, runGPU),
            "SST_sigmaSq": s.SST_sigmaSq,
        }
        for s, sse in zip(subj, SSE)
    ]
    return results_subj, C_np, cost_fun[:iter_], float(varexpl), float(elapsed)


def _subject_nll(s, V, xp, runGPU):
    """Return one subject's negative log-likelihood as a Python float.

    Reads the cached ``SST_sigmaSq``, ``sXC``, ``XSt``, ``CtXtXC``, ``SSt``,
    ``T`` and ``sigmaSq`` of ``s``; the caller must have refreshed them.
    """
    log_det = to_float(xp.log(s.sigmaSq).sum(), runGPU)
    return float(
        0.5 * s.SST_sigmaSq
        - to_float((s.sXC * s.XSt).sum(), runGPU)
        + 0.5 * to_float((s.CtXtXC * s.SSt).sum(), runGPU)
        + (s.T / 2.0) * (V * np.log(2.0 * np.pi) + log_det)
    )
# ------------------------------------------------------------------ #
# C update #
# ------------------------------------------------------------------ #
def _Cupdate_multi_subjects(subj, C, muC, NLL, niter, runGPU):
    """Run ``niter`` projected-gradient steps on the shared generator ``C``.

    Each step takes a gradient step of size ``muC``, clips negatives to zero
    and renormalises every column of ``C`` to sum to one. A backtracking line
    search halves ``muC`` until the objective no longer increases (up to a
    relative tolerance of 1e-9); after 50 failed halvings the step is
    abandoned, ``C`` is restored and ``muC`` is reset to 1e-4.

    Parameters
    ----------
    subj : sequence of Subject-like objects
        Reads ``X``, ``sX``, ``T``, ``sigmaSq``, ``SSt``, ``XSt``,
        ``SST_sigmaSq`` and ``NLL``. On return ``sXC`` and ``CtXtXC`` of every
        subject are refreshed for the final ``C``; ``SST_sigmaSq`` and ``NLL``
        are coerced to Python floats.
    C : array, shape (sV, noc)
        Current generator matrix.
    muC : 0-d array
        Current step size; grown by 1.2× after every accepted step.
    NLL : float
        Objective value at the incoming ``C``.
    niter : int
        Number of gradient steps to attempt.
    runGPU : bool
        Use CuPy (``cp``) instead of NumPy as the array backend.

    Returns
    -------
    C, muC, NLL
        Updated generator, step size and objective value.
    """
    xp = cp if runGPU else np
    sV, noc = C.shape
    V = subj[0].X.shape[1]
    total_T = sum(int(s.T) for s in subj)
    eps = xp.finfo(float).eps
    max_ls_steps = 50

    for s in subj:
        if not isinstance(s.SST_sigmaSq, float):
            s.SST_sigmaSq = to_float(s.SST_sigmaSq, runGPU)
        if not isinstance(s.NLL, float):
            s.NLL = to_float(s.NLL, runGPU)

    # None of these depend on C, so they are computed once up front.
    sXtsX_list = [s.sX.T @ s.sX for s in subj]
    XtXSt_list = [s.sX.T @ s.XSt for s in subj]
    log_var_terms = [
        float(s.T) / 2.0 * (V * np.log(2.0 * np.pi) + to_float(xp.log(s.sigmaSq).sum(), runGPU))
        for s in subj
    ]

    for _ in range(niter):
        NLL_old = NLL
        # Guard against the step size collapsing to (numerically) zero.
        if float(muC) < 1e-12:
            muC = xp.array(1e-4)

        g = xp.zeros((sV, noc), dtype=float)
        for sXtsX, XtXSt, s in zip(sXtsX_list, XtXSt_list, subj):
            g += sXtsX @ (C @ s.SSt) - XtXSt
        g /= total_T * sV
        # Remove the per-column component along C.
        g = g - (g * C).sum(axis=0, keepdims=True)

        # Accept tiny relative increases. Written with abs() so that the
        # tolerance loosens the test for negative objectives too;
        # NLL_old * (1 + 1e-9) would tighten it when NLL_old < 0.
        accept_below = NLL_old + 1e-9 * abs(NLL_old)

        Cold = C.copy()
        ls_steps = 0
        while True:
            C = Cold - muC * g
            xp.maximum(C, 0.0, out=C)
            C /= C.sum(axis=0, keepdims=True) + eps

            # Accumulate on the active backend; convert to float once.
            NLL_gpu = xp.array(0.0)
            for log_var, s in zip(log_var_terms, subj):
                sXC_new = s.sX @ C
                CtXtXC = sXC_new.T @ sXC_new
                NLL_gpu = (
                    NLL_gpu
                    + 0.5 * s.SST_sigmaSq
                    - (sXC_new * s.XSt).sum()
                    + 0.5 * (CtXtXC * s.SSt).sum()
                    + log_var
                )
            NLL_tmp = to_float(NLL_gpu, runGPU)
            ls_steps += 1

            if NLL_tmp <= accept_below:
                muC = muC * 1.2
                NLL = NLL_tmp
                break
            if ls_steps >= max_ls_steps:
                # Give up on this step: restore C and restart the step size.
                C = Cold
                NLL = NLL_old
                muC = xp.array(1e-4)
                break
            muC = muC / 2.0

    for s in subj:
        s.sXC = s.sX @ C
        s.CtXtXC = s.sXC.T @ s.sXC
    return C, muC, NLL