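# 01_fit_msaa_flexible.ipynb
Flexible MS-AA (multi-subject archetypal analysis) fitting notebook: loads the Pieman data for the selected conditions, normalizes each subject's matrix, and fits MS-AA in a spatial or temporal orientation, either within each condition or across all conditions, for the K values in `K_VALUES`, saving fits as compressed `.npz` files in `OUTPUT_DIR`.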
[31]:
# --- Analysis design ---------------------------------------------------------
ANALYSIS_TYPE = "spatial"   # "spatial" or "temporal"
FIT_SCOPE = "across"        # "within" or "across"
K_VALUES = [5, 10, 25, 50, 75, 100, 300, 500, 700]  # candidate K values (passed to the solver as `noc`)
CONDITIONS_TO_USE = ["intact", "word", "rest"]
TRIM_REST_BY = 100          # trailing rows dropped from each "rest" matrix (<= 0 keeps everything)
RNG_SEED = 42

# --- Input / output locations ------------------------------------------------
PIEMAN_MAT_PATH = "data/pieman/raw/pieman_data.mat"
POSTERIOR_MAT_PATH = "data/pieman/raw/pieman_posterior_K700.mat"
HELPERS_DIR = "helpers"
HYPERTOOLS_NORMALIZE_PATH = "hypertools/normalize.py"
OUTPUT_DIR = "msaa_flexible_outputs_npz"

# --- Default solver options handed to multi_subject_aa -----------------------
MSAA_OPTS = {
    "maxiter": 200,
    "conv_crit": 1e-6,
    "fix_var_iter": 5,
    "heteroscedastic": False,
    "use_gpu": True,
    "rngSEED": RNG_SEED,
}

# --- Run control -------------------------------------------------------------
SKIP_EXISTING = True
STOP_ON_ERROR = False
VERBOSE_PROGRESS = True
[32]:
%matplotlib inline
import os, sys, importlib.util
import numpy as np, pandas as pd
from scipy.io import loadmat
HT_NORMALIZE_AVAILABLE = False
try:
spec = importlib.util.spec_from_file_location("normalize", HYPERTOOLS_NORMALIZE_PATH)
ht = importlib.util.module_from_spec(spec)
spec.loader.exec_module(ht)
HT_NORMALIZE_AVAILABLE = True
except Exception:
ht = None
sys.path.append(HELPERS_DIR)
from helpers import MultiSubject_AA
from MultiSubject_AA import Subject, multi_subject_aa
os.makedirs(OUTPUT_DIR, exist_ok=True)
print("Ready.")
Ready.
[33]:
def to_float_array(x):
    """Build a new float64 NumPy array from ``x`` (``np.array`` always copies)."""
    converted = np.array(x, dtype=np.float64)
    return converted
def array_cutter(data_list, cut_length):
    """Drop the last ``cut_length`` rows (axis 0) from every array in ``data_list``.

    Every element is converted with ``to_float_array``, so the caller's
    arrays are never modified.

    Parameters
    ----------
    data_list : iterable of array-like
        Arrays to trim.
    cut_length : int
        Number of trailing rows to remove. Values <= 0 disable trimming and
        return float copies unchanged.

    Returns
    -------
    list of np.ndarray

    Raises
    ------
    ValueError
        If an array has no axis 0 or not more than ``cut_length`` rows.
        Such arrays would otherwise be silently reduced to zero rows.
    """
    # to_float_array is built on np.array, which already returns a fresh copy.
    arrays = [to_float_array(x) for x in data_list]
    if cut_length <= 0:
        return arrays
    for idx, arr in enumerate(arrays):
        if arr.ndim == 0 or arr.shape[0] <= cut_length:
            n_rows = arr.shape[0] if arr.ndim else 0
            raise ValueError(
                f"array_cutter: cannot drop {cut_length} rows from item {idx} "
                f"with only {n_rows} rows along axis 0"
            )
    return [arr[:-cut_length] for arr in arrays]
def coerce_TxV(x):
    """Return ``x`` as a column-centred float32 matrix laid out as (T, V).

    The orientation is a heuristic: when the input has fewer columns than
    rows it is transposed, i.e. the longer axis is assumed to be V.
    NOTE(review): genuine (T, V) input with T > V would therefore be transposed.

    Raises
    ------
    ValueError
        If ``x`` is not two-dimensional.
    """
    mat = np.array(x, dtype=float)
    if mat.ndim != 2:
        raise ValueError(f"Expected 2D array, got {mat.shape}")
    n_rows, n_cols = mat.shape
    oriented = mat.T if n_cols < n_rows else mat
    centred = oriented - oriented.mean(axis=0, keepdims=True)
    return centred.astype(np.float32, copy=False)
def normalize_across_fallback(subject_list, eps=1e-8):
    """Z-score every column with statistics pooled over all subjects.

    Each subject is first passed through ``coerce_TxV``. The per-column mean
    and population (ddof=0) standard deviation are computed on the row-wise
    concatenation of all subjects. ``eps`` is added to the standard deviation
    to avoid division by zero.

    Returns a list of float32 arrays, one per subject, in input order.
    """
    prepared = [coerce_TxV(mat) for mat in subject_list]
    pooled = np.concatenate(prepared, axis=0)
    col_mean = pooled.mean(axis=0, keepdims=True)
    col_std = pooled.std(axis=0, keepdims=True) + eps
    normalized = []
    for mat in prepared:
        normalized.append(((mat - col_mean) / col_std).astype(np.float32, copy=False))
    return normalized
def normalize_subject_list(subject_list):
    """Centre, orient and z-score a list of subject matrices "across" subjects.

    Each matrix is first passed through ``coerce_TxV``. If hypertools'
    ``normalize`` was loaded (``HT_NORMALIZE_AVAILABLE``), it is called with
    ``normalize="across"``. Otherwise, or if that call raises, the local
    ``normalize_across_fallback`` is used, and the hypertools failure is
    reported as a warning rather than swallowed.

    Returns
    -------
    list of np.ndarray
        float32 arrays on either path, so downstream code sees one dtype.
    """
    subj = [coerce_TxV(x) for x in subject_list]
    if HT_NORMALIZE_AVAILABLE:
        try:
            normalized = ht.normalize(subj, normalize="across")
        except Exception as exc:
            import warnings
            warnings.warn(
                f"hypertools normalize failed ({type(exc).__name__}: {exc}); "
                "falling back to normalize_across_fallback.",
                RuntimeWarning,
            )
        else:
            return [np.asarray(x, dtype=np.float32) for x in normalized]
    return normalize_across_fallback(subj)
def load_condition_data():
    """Load, optionally trim, and normalize subject matrices per condition.

    Reads ``PIEMAN_MAT_PATH`` with ``loadmat``. For each name in
    ``CONDITIONS_TO_USE``, one subject is taken per column ``i`` of that
    variable via ``[:, i][0]``. "rest" subjects are shortened with
    ``array_cutter`` when ``TRIM_REST_BY > 0``. Each condition's list then
    goes through ``normalize_subject_list``.

    Returns
    -------
    dict
        condition name -> list of normalized subject arrays, keyed in order of
        first appearance in ``CONDITIONS_TO_USE``.
    """
    mat_contents = loadmat(PIEMAN_MAT_PATH)

    # Collect raw subjects per condition; a condition listed twice in
    # CONDITIONS_TO_USE contributes its subjects twice.
    raw_by_condition = {}
    for cond in CONDITIONS_TO_USE:
        cells = mat_contents[cond]
        subjects = [cells[:, col][0] for col in range(cells.shape[1])]
        raw_by_condition.setdefault(cond, []).extend(subjects)

    condition_data = {}
    for cond, raw_subjects in raw_by_condition.items():
        if cond == "rest" and TRIM_REST_BY > 0:
            raw_subjects = array_cutter(raw_subjects, TRIM_REST_BY)
        condition_data[cond] = normalize_subject_list(raw_subjects)
    return condition_data
# Load every requested condition (trimmed + normalized) and print, per
# condition, the number of subjects and the shape of the first subject's matrix.
condition_data = load_condition_data()
for c in CONDITIONS_TO_USE:
    print(c, len(condition_data[c]), np.asarray(condition_data[c][0]).shape)
# Load the posterior .mat file and extract its `centers` and `widths` fields;
# both globals are written into every .npz by save_fit_npz.
# NOTE(review): the chained [0] indexing unwraps the struct/cell nesting that
# loadmat produced for this particular file — confirm the depth if the file
# is regenerated.
posterior = loadmat(POSTERIOR_MAT_PATH)
centers = to_float_array(posterior['posterior']['centers'][0][0][0][0][0])
# `[:, 0]` takes the first column of the extracted array (for a 2-D array the
# result is 1-D, so the `.T` has no effect); the result is flattened to 1-D.
widths = to_float_array(list(posterior['posterior']['widths'][0][0][0][0][0][:, 0].T)).ravel()
intact 36 (300, 700)
word 36 (300, 700)
rest 36 (300, 700)
[34]:
def orient_subject_matrix(X_raw, analysis_type):
    """Return the subject matrix in the orientation used for the AA fit.

    ``coerce_TxV`` yields a centred float32 (T, V) matrix. "spatial" returns it
    unchanged; any other ``analysis_type`` returns its transpose (V, T).
    """
    time_by_feature = coerce_TxV(X_raw)
    if analysis_type != "spatial":
        return time_by_feature.T
    return time_by_feature
def build_subject_objects(subject_list, analysis_type):
    """Wrap each subject matrix in a ``Subject`` for ``multi_subject_aa``.

    Both ``X`` and ``sX`` hold the matrix returned by
    ``orient_subject_matrix``. That matrix is computed once per subject, and
    ``sX`` receives its own copy, so the two fields never share memory.

    Returns
    -------
    list of Subject
        One per input matrix, in input order.
    """
    subjects = []
    for x in subject_list:
        oriented = orient_subject_matrix(x, analysis_type)
        # Independent buffer in case the solver modifies X or sX in place;
        # order="K" keeps the same memory layout (e.g. for transposed views).
        subjects.append(Subject(X=oriented, sX=oriented.copy(order="K")))
    return subjects
def expected_shape_string(example_X_raw, analysis_type, K):
    """Describe the sXC / S shapes expected from a K-component fit.

    Based on the (T, V) layout produced by ``coerce_TxV``. "spatial" expects
    sXC (T, K) and S (K, V); any other type expects sXC (V, K) and S (K, T).
    """
    n_time, n_feat = coerce_TxV(example_X_raw).shape
    if analysis_type == "spatial":
        rows, cols = n_time, n_feat
    else:
        rows, cols = n_feat, n_time
    return f"sXC ~ ({rows}, {K}), S ~ ({K}, {cols})"
def fit_msaa(subject_list, K, analysis_type, opts=None, verbose=True):
    """Fit multi-subject AA with ``K`` components and collect the outputs.

    Parameters
    ----------
    subject_list : list of array-like
        One matrix per subject, wrapped via ``build_subject_objects``.
    K : int
        Number of components, forwarded to ``multi_subject_aa`` as ``noc``.
    analysis_type : str
        "spatial" or "temporal".
    opts : dict, optional
        Solver options. Defaults to ``MSAA_OPTS``. A shallow copy is always
        what gets passed to the solver.
    verbose : bool
        Print variance explained and subject 0's ``sXC`` / ``S`` shapes next to
        the expected shapes.

    Returns
    -------
    dict
        Keys "results_subj", "C", "cost_fun", "varexpl", "secs", "K".
    """
    solver_opts = dict(opts if opts is not None else MSAA_OPTS)
    subjects = build_subject_objects(subject_list, analysis_type)
    results_subj, C, cost_fun, varexpl, secs = multi_subject_aa(subjects, noc=K, opts=solver_opts)
    if verbose:
        first = results_subj[0]
        print(f"[{analysis_type.upper()} AA] K={K} | VarExpl={varexpl*100:.2f}%")
        print("subj0 sXC shape:", np.asarray(first["sXC"]).shape)
        print("subj0 S shape:", np.asarray(first["S"]).shape)
        print("Expected:", expected_shape_string(subject_list[0], analysis_type, K))
    return dict(results_subj=results_subj, C=C, cost_fun=cost_fun, varexpl=varexpl, secs=secs, K=K)
def save_fit_npz(save_path, fit_res, condition_labels_str, *, analysis_type=None,
                 fit_scope=None, centers_arr=None, widths_arr=None):
    """Write one fit (as returned by ``fit_msaa``) plus metadata to a compressed .npz.

    Parameters
    ----------
    save_path : str or path-like
        Destination for ``np.savez_compressed`` (NumPy appends ".npz" to a
        str/Path that lacks it).
    fit_res : dict
        Must contain "K", "results_subj", "C", "cost_fun", "varexpl", "secs".
    condition_labels_str : sequence of str
        One condition label per subject, in subject order.
    analysis_type, fit_scope : str, optional
        Metadata to store. Default to the notebook globals ``ANALYSIS_TYPE``
        and ``FIT_SCOPE``.
    centers_arr, widths_arr : array-like, optional
        Default to the notebook globals ``centers`` and ``widths``.

    Notes
    -----
    ``condition_codes`` index into ``condition_names``. Both are taken from the
    same ``pd.Categorical``, so they are always aligned. Several entries are
    object arrays, so load with ``np.load(path, allow_pickle=True)``.
    """
    # Fall back to notebook-level globals only when no explicit value is given.
    if analysis_type is None:
        analysis_type = ANALYSIS_TYPE
    if fit_scope is None:
        fit_scope = FIT_SCOPE
    if centers_arr is None:
        centers_arr = centers
    if widths_arr is None:
        widths_arr = widths
    labels = pd.Categorical(list(condition_labels_str))
    np.savez_compressed(
        save_path,
        K=np.array(fit_res["K"]),
        results_subj=np.array(fit_res["results_subj"], dtype=object),
        C=np.array(fit_res["C"], dtype=object),
        cost_fun=np.array(fit_res["cost_fun"], dtype=object),
        varexpl=np.array(fit_res["varexpl"]),
        secs=np.array(fit_res["secs"]),
        condition_labels_str=np.array(condition_labels_str, dtype=object),
        condition_codes=np.array(labels.codes),
        condition_names=np.array(list(labels.categories), dtype=object),
        analysis_type=np.array(analysis_type, dtype=object),
        fit_scope=np.array(fit_scope, dtype=object),
        centers=np.array(centers_arr),
        widths=np.array(widths_arr),
    )
def get_fit_jobs(condition_data, fit_scope):
    """Build the list of fitting jobs for the requested scope.

    Parameters
    ----------
    condition_data : dict
        condition name -> list of subject matrices.
    fit_scope : {"within", "across"}
        "within": one job per condition (subject lists are not copied).
        "across": a single job named "acrossCond" that pools all subjects in
        dict order.

    Returns
    -------
    list of dict
        Each job has "job_name", "subject_list" and "condition_labels_str"
        (one label per subject).

    Raises
    ------
    ValueError
        For any other ``fit_scope``, so a typo cannot silently run "across".
    """
    if fit_scope == "within":
        return [
            {
                "job_name": cond_name,
                "subject_list": subject_list,
                "condition_labels_str": [cond_name] * len(subject_list),
            }
            for cond_name, subject_list in condition_data.items()
        ]
    if fit_scope == "across":
        all_subjects, all_labels = [], []
        for cond_name, subject_list in condition_data.items():
            all_subjects.extend(subject_list)
            all_labels.extend([cond_name] * len(subject_list))
        return [{"job_name": "acrossCond", "subject_list": all_subjects, "condition_labels_str": all_labels}]
    raise ValueError(f"fit_scope must be 'within' or 'across', got {fit_scope!r}")
def job_output_path(output_dir, analysis_type, fit_scope, job_name, K):
    """Return ``<output_dir>/<analysis_type>AA_<fit_scope>_<job_name>_K<K>.npz``."""
    file_name = f"{analysis_type}AA_{fit_scope}_{job_name}_K{K}.npz"
    return os.path.join(output_dir, file_name)
jobs = get_fit_jobs(condition_data, FIT_SCOPE)
# Display a compact per-job summary instead of echoing `jobs` itself, whose
# repr prints every subject matrix and floods the notebook output.
pd.DataFrame(
    [
        {
            "job_name": job["job_name"],
            "n_subjects": len(job["subject_list"]),
            "conditions": list(dict.fromkeys(job["condition_labels_str"])),
            "subject0_shape": np.asarray(job["subject_list"][0]).shape if job["subject_list"] else None,
        }
        for job in jobs
    ]
)
[34]:
[{'job_name': 'acrossCond',
'subject_list': [array([[ 0.7390178 , 0.0482161 , 0.41048738, ..., 2.08967 ,
1.7098577 , -0.51149654],
[ 0.99505675, 0.37382004, -1.9477133 , ..., 1.7763319 ,
3.687591 , -0.91213423],
[ 2.1640973 , 0.33549806, -1.9503074 , ..., 1.2327224 ,
2.0651553 , -0.98489267],
...,
[-0.04969338, 0.3840668 , -0.60217726, ..., 0.8902922 ,
0.5947482 , -1.0532358 ],
[ 0.5482102 , 1.1633033 , -1.3218285 , ..., 1.6571469 ,
-0.40749523, -0.61102223],
[-0.7863177 , 0.03553534, -0.87664074, ..., 1.3443426 ,
-1.1747104 , 0.23313299]], dtype=float32),
array([[-1.3355055 , 0.05815953, -0.6389527 , ..., -2.5327556 ,
0.05553686, 0.09448261],
[-0.11980657, -0.90872073, -0.32386687, ..., -2.6671445 ,
-1.0529251 , 0.4178725 ],
[ 0.7181586 , 0.32694626, 1.5326638 , ..., -1.3597811 ,
-0.42158124, 0.5623699 ],
...,
[ 0.85279834, 0.5834086 , 0.9105071 , ..., 1.2327574 ,
0.74229705, -0.2022038 ],
[ 0.36754313, -0.70938075, -1.0477269 , ..., 0.14077267,
2.324751 , 0.40425405],
[-1.0413787 , -0.5402781 , -1.0965084 , ..., -0.10083916,
1.1456709 , -1.1621627 ]], dtype=float32),
array([[ 0.5603107 , -0.37583745, 1.470591 , ..., 1.6827015 ,
0.2970299 , -2.0640578 ],
[ 0.16190158, 0.23785615, 1.0908443 , ..., 1.5596046 ,
-0.29055473, -1.8549138 ],
[-0.40885293, 0.70138836, 1.7487617 , ..., 1.980824 ,
0.36851057, -0.8602258 ],
...,
[ 0.25354972, -0.9955393 , 0.37233368, ..., 0.28858435,
0.2104122 , 0.49673665],
[ 1.2505999 , -0.11117586, 1.4105264 , ..., 0.5651045 ,
0.13084073, -0.42167124],
[ 1.6660527 , 0.14879838, 1.6305336 , ..., 0.90677327,
0.06170204, 0.7434361 ]], dtype=float32),
array([[-0.5397172 , 0.97607416, 3.806748 , ..., 0.4906643 ,
1.3734721 , 0.36171496],
[-0.08842935, 0.2374603 , 2.7605295 , ..., 0.02885295,
0.3715529 , -0.7415578 ],
[-0.10905854, -1.1123528 , 1.950878 , ..., 0.9852272 ,
-0.28940868, -0.30383426],
...,
[-0.64817375, -0.7614477 , 0.7637328 , ..., 2.5112545 ,
0.776832 , -0.5318073 ],
[-1.1497946 , 0.46783993, 0.40554917, ..., 0.01003621,
-0.10539138, -0.7653684 ],
[-1.6952189 , -1.9822366 , -0.3814318 , ..., -0.5055152 ,
-2.2491138 , -0.4006145 ]], dtype=float32),
array([[ 0.22137973, -0.09299652, 0.0154783 , ..., -0.39757234,
1.3093148 , -0.66603005],
[-1.0152187 , -0.6243001 , -0.06215607, ..., -0.94304854,
0.99060816, -0.27801946],
[-1.4478375 , -0.5275948 , 0.06967478, ..., -1.7573456 ,
0.47942814, 0.73067003],
...,
[-0.2294402 , -1.9840367 , 1.4519899 , ..., 1.8173841 ,
-1.4056413 , 0.82458574],
[-1.6279058 , -2.9360993 , 0.6014083 , ..., 1.3094418 ,
1.0429313 , 1.7742375 ],
[-1.0756115 , -3.2914004 , 0.6567025 , ..., 0.39933988,
0.02091355, 1.9508593 ]], dtype=float32),
array([[ 0.8249232 , -1.5262455 , -0.82812935, ..., -2.2166233 ,
-0.21268173, 1.9235411 ],
[-0.49814698, -1.143478 , -1.4882698 , ..., -2.2795486 ,
0.25961745, 1.9563462 ],
[ 0.20576364, -1.5829862 , -2.079018 , ..., -1.3966264 ,
0.7346948 , 2.1427166 ],
...,
[ 0.39731663, -1.3232337 , 0.8538717 , ..., 0.7204007 ,
0.08321119, 2.1979375 ],
[ 0.3806104 , -0.69028234, 2.5640569 , ..., 0.78318757,
0.3116425 , 1.5481331 ],
[-0.91837835, -1.5158116 , 3.4710395 , ..., -0.18075956,
0.92637897, 1.6446083 ]], dtype=float32),
array([[ 0.1959962 , -1.9253213 , 0.14689504, ..., -0.3254581 ,
1.0550927 , -1.5802543 ],
[ 0.2688305 , -1.2975283 , 0.49883625, ..., -0.96004444,
0.43959823, -1.2910807 ],
[-0.09012878, -1.311608 , 0.77797973, ..., -0.14912403,
0.08101639, -0.67725587],
...,
[ 0.16575298, 0.08517738, 0.5879714 , ..., -0.83741975,
-1.8476994 , -0.60410327],
[-0.25406158, 1.4980084 , 1.3538698 , ..., -0.9621968 ,
-0.17981473, 0.13409987],
[ 0.5734744 , 1.1523459 , 0.17240505, ..., 0.78960365,
0.14684187, 2.0645206 ]], dtype=float32),
array([[-3.2974824e-01, 1.6479158e+00, 1.9745809e+00, ...,
6.5395904e-01, 5.5492800e-01, 4.7292659e-01],
[ 2.7251428e-01, 1.6737143e+00, 1.2628458e-02, ...,
4.5604318e-01, 1.0091408e+00, 7.2221804e-01],
[ 1.8763050e-01, 1.0999972e+00, -8.0890357e-01, ...,
1.0188515e+00, 7.7372134e-01, 4.9914300e-01],
...,
[ 5.6072360e-01, 1.7415364e+00, 3.8309160e-01, ...,
1.4905858e+00, -2.4391592e+00, 1.2517349e-01],
[ 2.7038619e-01, 1.2808168e-01, 6.5044861e-04, ...,
2.3452804e+00, -2.8897557e+00, -9.2954093e-01],
[-1.0215391e+00, 4.3672171e-01, 4.7452784e-01, ...,
2.3772326e+00, -1.5396640e+00, -1.1567736e+00]], dtype=float32),
array([[ 0.9426217 , 1.3241006 , -2.1458738 , ..., 0.3371971 ,
0.94377387, 0.3443159 ],
[-0.08889996, 0.67893165, -1.2334135 , ..., -0.6296816 ,
0.47738472, 0.09305423],
[ 1.0394448 , 0.84588045, 0.31372777, ..., -1.737548 ,
0.7082813 , -0.607765 ],
...,
[ 0.8590699 , -0.33694145, 0.19900821, ..., 0.15512897,
-0.5262231 , 0.28620788],
[ 0.60629606, -0.24034722, -0.54429215, ..., -1.4057236 ,
-0.09988109, 1.0039208 ],
[ 0.20653018, 0.7045066 , 0.26341432, ..., -1.1955622 ,
0.5801692 , 1.1342934 ]], dtype=float32),
array([[ 0.72356176, -0.97991085, 1.265221 , ..., 1.3038725 ,
0.6489574 , -0.6779169 ],
[ 0.3087346 , 0.12071916, 1.1758003 , ..., 0.7869704 ,
-1.1932161 , 0.2702206 ],
[-0.57743835, -0.01206993, 1.7492968 , ..., 1.3501917 ,
0.0260087 , -0.11468963],
...,
[-1.7563982 , -2.0637457 , 0.618641 , ..., -0.15320055,
0.2522176 , -0.04599535],
[-1.3366417 , -2.4983191 , -0.5589511 , ..., -0.9209508 ,
-0.49525255, 0.36325958],
[-0.51714677, -3.1522806 , -1.0774366 , ..., -0.6250707 ,
0.29010874, 0.03947557]], dtype=float32),
array([[ 0.71355605, 3.4872725 , 1.4505602 , ..., 3.3666556 ,
2.535822 , -1.7939836 ],
[-0.17694801, 3.4111888 , 0.6270476 , ..., 2.5612218 ,
-0.5540496 , -0.7768245 ],
[-0.32167655, 3.3512077 , 1.4682922 , ..., 1.3945674 ,
1.1294898 , 0.52584803],
...,
[-0.70531845, -0.05721777, 0.6549329 , ..., -1.5539868 ,
1.0284967 , 1.6068839 ],
[ 0.42767334, -1.7228291 , 0.52416795, ..., -0.4071238 ,
1.8883835 , 0.6083716 ],
[-0.62944514, -0.809434 , 0.73043805, ..., -1.4101099 ,
0.88381916, 0.64465123]], dtype=float32),
array([[ 2.1537447 , -1.1047823 , -0.5573104 , ..., 1.3472735 ,
-0.2832397 , 0.46356604],
[ 0.5911809 , -0.35858065, 0.0110495 , ..., 1.4062219 ,
-0.07830517, 1.2098178 ],
[ 0.7004962 , -0.90092385, 0.53157556, ..., 1.2821952 ,
0.43320224, 1.437641 ],
...,
[ 0.59216875, -1.9632404 , 1.8247517 , ..., 1.3201549 ,
-1.3811388 , 0.0746446 ],
[ 0.01084798, -1.8002512 , 0.3083977 , ..., 1.5526009 ,
-0.9808506 , -0.20339105],
[-1.3442614 , -0.60630804, -0.06537273, ..., 1.2062674 ,
-0.18448943, -1.3778393 ]], dtype=float32),
array([[ 2.0712748 , 1.0019736 , -0.94277173, ..., -3.289525 ,
0.9712073 , -2.6402266 ],
[ 2.6123712 , -0.08934733, 0.7959927 , ..., -3.1996372 ,
0.83905095, -1.8937953 ],
[-0.9433681 , 1.673027 , 2.7181222 , ..., 1.4850706 ,
0.24683155, -0.16215296],
...,
[-0.42187497, -2.3406262 , 0.530007 , ..., 0.83460957,
0.90565324, 1.8890206 ],
[ 0.05284988, -2.6008809 , 0.22312884, ..., 1.0375566 ,
0.41586807, 1.6393757 ],
[-0.23170109, -1.7112055 , 0.29938325, ..., 0.3731526 ,
-0.23144236, -0.3667292 ]], dtype=float32),
array([[ 2.0628638 , 0.45558238, 2.1577866 , ..., 0.70967865,
2.3757966 , 1.1656309 ],
[ 0.53308856, -0.2855954 , -0.1370056 , ..., -0.1522649 ,
0.25150642, 1.0437446 ],
[-1.0345492 , 0.40426236, 0.69511575, ..., 0.8367079 ,
0.54877967, 0.65683675],
...,
[ 2.361055 , -1.8558182 , 0.9670947 , ..., 1.9898274 ,
1.2658651 , 0.64906836],
[-0.03410075, -0.69706315, 1.3726455 , ..., 1.7733822 ,
1.3924761 , -0.4861673 ],
[ 0.45241898, -1.7411535 , 0.46511716, ..., 0.7261859 ,
0.4408361 , -0.6654733 ]], dtype=float32),
array([[ 1.2745525 , 1.2688726 , -0.11326164, ..., 2.1365993 ,
-2.161293 , -0.4463611 ],
[ 1.1945733 , 0.8129434 , -0.53035367, ..., -0.51819927,
1.0728649 , -0.6765254 ],
[-0.8032575 , -1.2273701 , -0.92175025, ..., -0.27992404,
0.44999114, -2.492097 ],
...,
[ 2.0687218 , -2.0830536 , -0.4258065 , ..., -1.244959 ,
-1.8508611 , -3.2494786 ],
[ 2.653184 , -0.59137905, -0.5919486 , ..., 1.2449652 ,
1.4759021 , -3.1056817 ],
[ 0.94585264, 0.4019399 , -0.8981591 , ..., 1.3720764 ,
0.2955397 , -2.2230244 ]], dtype=float32),
array([[-1.1604416 , 0.72118837, 1.8554634 , ..., -0.4014971 ,
0.5364345 , 1.0777668 ],
[ 0.46688578, -0.04295034, 1.2847925 , ..., 0.37356022,
0.12647304, -0.55635566],
[-0.14617062, 0.27739003, -0.23848474, ..., -0.5110284 ,
-0.28221056, -1.110964 ],
...,
[-0.376232 , 0.55635273, 0.8014504 , ..., -0.39052078,
0.01846801, -0.07154845],
[-0.5633393 , -0.5834416 , 0.8803994 , ..., -0.51620716,
-0.9545088 , -0.9274522 ],
[ 0.07191011, -2.2373362 , -1.0339043 , ..., 0.11995698,
-0.02501206, -0.13816458]], dtype=float32),
array([[-0.45546973, -1.5854179 , 1.1205748 , ..., -0.00451433,
0.89653814, -0.22824314],
[-0.74300367, -0.29004988, -0.17170677, ..., -1.0093083 ,
1.248333 , 0.39251718],
[ 0.16385242, -0.40184245, 0.529642 , ..., -0.91984576,
0.5978529 , 0.92794114],
...,
[ 0.49684742, -0.2998419 , 0.99365 , ..., 0.6539206 ,
-0.22766219, -0.22543108],
[-0.8243233 , 0.2595913 , 0.26160765, ..., 0.22837593,
0.02933186, 0.1430671 ],
[-0.1265595 , -1.4897392 , 0.7107429 , ..., 0.30059147,
0.43519753, -0.09355153]], dtype=float32),
array([[-0.8701806 , -1.9689481 , -0.18220916, ..., -1.114435 ,
2.2025938 , 1.2925311 ],
[-0.07955506, -0.29297146, 1.7164247 , ..., 0.03798555,
0.83406067, 1.8329525 ],
[ 0.6001884 , -0.72531056, -0.18037978, ..., 0.89549404,
0.5421409 , 0.13959609],
...,
[-0.23066938, -0.79147714, 1.2364975 , ..., -1.4318113 ,
0.53054893, 0.047457 ],
[-1.0451856 , -1.0494578 , 1.5948019 , ..., -2.1926627 ,
0.49812767, -0.41941637],
[-0.46825638, -2.221168 , -1.2224398 , ..., -3.4631288 ,
-0.61178046, 0.407117 ]], dtype=float32),
array([[ 0.02134061, 0.16707014, 1.227667 , ..., 1.489595 ,
1.6014613 , 0.6455997 ],
[ 1.0775602 , 0.17000704, 0.43884146, ..., 0.11267819,
0.18792479, 0.76982224],
[-0.29798263, -0.1458549 , -0.57961553, ..., 0.0028938 ,
-0.11344934, -0.08889275],
...,
[-1.6674912 , -0.50551045, 0.04305394, ..., 0.07877064,
-1.3944423 , -0.6438531 ],
[-1.6112357 , 0.66108984, -0.71493477, ..., -0.73373955,
-0.9122343 , -1.554692 ],
[-2.7291045 , 0.75746936, -1.1448628 , ..., -1.5878414 ,
-2.1284115 , -1.4379019 ]], dtype=float32),
array([[ 0.72066367, 1.0274323 , 0.02064193, ..., -0.17512669,
1.2962015 , 1.7026414 ],
[-1.7614418 , 0.67388785, 0.36266515, ..., -1.5268493 ,
1.0022483 , -0.1860586 ],
[-2.1208827 , -0.17846894, 0.6277487 , ..., -0.63363373,
-0.61562824, -0.7564957 ],
...,
[ 0.6208703 , -0.30723113, 2.4497004 , ..., 2.7276566 ,
0.55806994, 0.43766817],
[ 0.19728453, -0.31360033, 1.736716 , ..., 1.3196268 ,
-0.11539779, 0.68545204],
[-0.15023628, -0.8817894 , 1.351156 , ..., 1.289527 ,
-0.2604276 , -0.66407233]], dtype=float32),
array([[ 0.6865799 , 1.6024748 , -0.3227221 , ..., -1.0620053 ,
0.4168055 , 0.44255626],
[-0.8469883 , -0.48524648, -0.37598482, ..., -0.72684425,
-0.3876209 , -1.2949187 ],
[-0.93079394, -2.2526596 , 0.08877525, ..., -1.2038736 ,
-0.45271656, -0.7126692 ],
...,
[-0.80575395, -1.8731098 , -1.2769076 , ..., -0.02514255,
-0.3213051 , 0.17715144],
[-1.0792805 , -1.3322552 , -2.1054099 , ..., -0.80541456,
-0.59748083, 0.4782259 ],
[-0.96457225, -0.46701217, -0.04705581, ..., -0.00722997,
0.2971585 , -0.83018607]], dtype=float32),
array([[ 0.15748134, -0.20657188, 1.7093493 , ..., 0.9278269 ,
2.7099638 , 1.40989 ],
[-0.21111639, 0.05630115, 1.5294071 , ..., 0.13152665,
2.5575144 , 0.3527339 ],
[ 0.62383306, 1.4885142 , 1.9194083 , ..., 0.51463544,
1.6023746 , 0.96493125],
...,
[ 1.2043899 , 0.9679857 , 1.6724696 , ..., -0.75637835,
0.96489495, -0.42729077],
[-0.42131764, -0.23450965, 0.66219765, ..., -0.33546034,
-0.38640654, 0.5989126 ],
[ 1.1839164 , 0.8117586 , 1.3316717 , ..., -0.61864287,
-0.09852954, 0.9833038 ]], dtype=float32),
array([[ 1.4106961 , -0.17426789, -1.1368394 , ..., 0.7335291 ,
0.30901495, 1.2884549 ],
[-0.28718582, -0.8063906 , -1.8932943 , ..., 0.26155588,
-0.0514004 , 3.1407907 ],
[-0.6823331 , -0.82661724, -1.457333 , ..., 0.7537143 ,
0.4393797 , 2.674663 ],
...,
[ 0.5477625 , -0.95375305, -2.9289098 , ..., -2.284483 ,
0.06476276, 0.78828 ],
[-0.72383904, -0.5398273 , -1.5337002 , ..., -2.0155895 ,
0.8565057 , 1.2738643 ],
[-0.98318315, 0.15659814, -2.5499592 , ..., -1.4710639 ,
1.4698828 , 2.8153093 ]], dtype=float32),
array([[-2.084588 , -0.70443183, 1.2409015 , ..., 2.1878514 ,
-1.0409099 , 1.1491213 ],
[-1.379073 , -1.7755674 , 1.9950787 , ..., 0.3979115 ,
-0.25634012, -0.00575971],
[ 0.11460473, -1.1517016 , 0.3892071 , ..., -0.10231799,
0.80934364, 0.91973686],
...,
[-1.0602771 , -0.52931315, 1.1217343 , ..., -0.04182733,
0.1043918 , 2.0442915 ],
[-0.711789 , -0.24784696, 1.8995337 , ..., -0.03784478,
0.20417707, 1.2919602 ],
[ 0.04840634, 0.58928955, 2.3900535 , ..., 1.0358518 ,
1.080189 , 2.0557768 ]], dtype=float32),
array([[ 1.5283881 , 1.2171044 , 0.77944297, ..., -0.37551448,
1.0074605 , 0.82438916],
[-0.65975595, 1.4599266 , 1.2471536 , ..., 0.16539964,
-0.48312375, 1.9118901 ],
[-0.91329795, 0.5753065 , 1.5873545 , ..., 0.28416494,
-0.2578854 , 0.46584842],
...,
[-0.52455705, 1.4790775 , 1.7034119 , ..., 0.50963676,
1.5105929 , -0.7225215 ],
[-0.4002906 , 1.1179876 , 1.4207579 , ..., -0.33958337,
1.2961698 , -0.392176 ],
[ 0.28252557, -0.24968177, 1.221817 , ..., -0.9852486 ,
0.4392636 , 0.574967 ]], dtype=float32),
array([[ 1.2635275 , -0.69141835, 0.3048125 , ..., 2.6837792 ,
-1.6813303 , -1.0613207 ],
[ 0.3628993 , -0.38234842, -0.12229529, ..., 1.1154888 ,
-1.5913422 , -0.6725618 ],
[-0.45076603, -1.2245637 , 0.8943058 , ..., -0.32669616,
-1.6570921 , -0.95904535],
...,
[ 1.7279671 , -1.960848 , 0.763937 , ..., 0.90521836,
-1.7330693 , -0.77608335],
[ 1.1203343 , -2.9563468 , 2.9087396 , ..., 0.94913876,
-0.1676931 , -1.5447367 ],
[-0.32077894, -1.0519105 , 2.1991606 , ..., 0.26159668,
-0.6128104 , -1.2385558 ]], dtype=float32),
array([[ 0.30595696, 1.4781595 , -0.16872254, ..., 0.8017177 ,
-0.2026927 , -0.85710067],
[ 1.3227234 , 1.69429 , 0.7899473 , ..., 0.5821254 ,
1.2868924 , 1.5631151 ],
[ 0.58477134, 2.1053483 , 1.586721 , ..., 0.30735308,
1.8717344 , 1.9541849 ],
...,
[ 1.7011565 , 0.51743484, -1.2842584 , ..., 0.18133488,
1.0585245 , -0.8798341 ],
[ 2.1314104 , 0.33944193, -0.96357554, ..., -0.27829576,
0.16422078, -1.7908566 ],
[ 0.6670965 , -0.71942896, -0.9757837 , ..., 0.15224087,
0.05984012, 0.47093052]], dtype=float32),
array([[-1.2497512 , 1.2985702 , -0.47023633, ..., -1.242889 ,
1.1457105 , 1.1366804 ],
[-0.544845 , 0.9919714 , 0.36871746, ..., -1.5401931 ,
-0.34032163, -0.00270908],
[ 0.60931635, 1.9426792 , -0.5261689 , ..., -1.284464 ,
-1.9692975 , 0.5798675 ],
...,
[-0.6916285 , -0.09454055, 0.26357684, ..., -1.2672998 ,
-1.2567781 , 0.6632811 ],
[ 0.76379216, -0.38201442, -0.5204061 , ..., -2.3813434 ,
-2.140348 , 0.3420519 ],
[-1.3222544 , -0.5449134 , 0.8349026 , ..., -1.7710676 ,
-2.6189282 , 0.02737117]], dtype=float32),
array([[-0.11762964, 1.407116 , -0.10160831, ..., 1.4472836 ,
-0.17873684, 0.20925352],
[-0.9503117 , -0.2964832 , 0.5923771 , ..., 0.4398702 ,
0.2688875 , 0.42575473],
[ 0.26793143, -0.16488858, 0.78853846, ..., 0.67135733,
1.1154739 , 0.5297355 ],
...,
[ 1.079226 , 0.2079864 , -0.44802806, ..., -0.39794773,
-0.5000269 , 0.51322913],
[ 0.638239 , 0.6434459 , -1.1394682 , ..., -0.52813023,
0.04171708, 0.54289454],
[ 0.85810417, -0.351167 , -1.6470387 , ..., -0.7366571 ,
-0.38618383, 0.88937706]], dtype=float32),
array([[-1.1370481 , 0.59263456, -1.1249925 , ..., -1.6234448 ,
1.5192574 , -0.03758588],
[-3.0528827 , -0.9149365 , -1.2512625 , ..., -1.4542922 ,
-0.20167708, 0.45574507],
[-2.154367 , -0.8504415 , -1.0926851 , ..., -2.0353997 ,
-0.65952647, -0.24425322],
...,
[-1.0658656 , -1.7634275 , 0.92660457, ..., 3.6695302 ,
0.50997543, 2.0399182 ],
[ 0.8718185 , -0.9024726 , 0.7534066 , ..., 2.4955041 ,
-1.9293121 , 0.56720793],
[-0.9449979 , 0.4301018 , 0.30508024, ..., 1.3451865 ,
-0.84218025, 0.39063266]], dtype=float32),
array([[ 0.8327907 , -0.50663966, 2.665478 , ..., 0.38428715,
2.2855017 , 0.4188913 ],
[ 0.44025013, -1.1382394 , 1.7526109 , ..., 0.46707562,
-0.30671 , -0.5456553 ],
[-1.2209616 , -1.3847637 , 0.904859 , ..., 0.94741327,
-1.1682811 , -0.411897 ],
...,
[ 0.85857946, -0.1994129 , 1.306179 , ..., -0.18761268,
-0.71414775, -1.6635534 ],
[-0.46887898, -0.44386247, 1.8113064 , ..., -0.8135728 ,
-0.95630276, -1.9432522 ],
[-0.38629237, -1.0298389 , 1.9648892 , ..., -0.5609785 ,
0.11482971, -1.2501981 ]], dtype=float32),
array([[ 0.48308533, 0.10822366, 0.2739298 , ..., -1.0738425 ,
4.005451 , -1.1437125 ],
[ 0.8131857 , -0.10646603, -1.0549349 , ..., -3.0052366 ,
3.4262686 , -0.22227426],
[ 1.6954342 , 0.9134475 , 1.5686233 , ..., 0.04991836,
1.3326453 , 1.1480608 ],
...,
[ 0.37830743, 0.17610425, 1.3096962 , ..., -1.1822685 ,
0.9549437 , -1.6331943 ],
[ 0.98817486, 1.1048496 , -0.42009544, ..., -2.462854 ,
0.88211817, -1.0247637 ],
[-0.10035483, 0.7613344 , -0.8174921 , ..., -1.3692095 ,
0.28605032, -0.21560732]], dtype=float32),
array([[ 0.31144357, 0.22346462, 0.06591919, ..., 2.749033 ,
-0.45165962, -0.02155037],
[ 0.14590614, -0.8100267 , 0.87472594, ..., 0.5991242 ,
-0.29243013, 0.11579123],
[ 0.3521812 , 0.7122578 , -0.35082072, ..., -0.86644375,
-0.10187618, -0.25721762],
...,
[ 1.0519408 , -1.3419595 , 0.81393826, ..., 1.0525138 ,
1.1557198 , 0.80662394],
[ 2.046592 , -0.2640682 , 1.8095661 , ..., 0.6639248 ,
0.81787133, 1.2064518 ],
[ 2.1874 , -1.6698933 , 1.4442877 , ..., 1.0089625 ,
0.00492027, 1.0519465 ]], dtype=float32),
array([[-0.7408734 , -0.67146206, 0.9981282 , ..., 0.2466547 ,
0.65118223, -0.7073163 ],
[-0.5799743 , -2.6254988 , 1.5612695 , ..., -0.09032917,
0.9911044 , -0.6800554 ],
[-0.36153692, -2.8939931 , 2.0491238 , ..., -0.9245387 ,
0.32114652, 0.10692298],
...,
[-0.2571319 , 0.3946458 , 0.9344187 , ..., -0.19051087,
-0.8633586 , -1.421279 ],
[-0.3153888 , -1.0466278 , 1.5144703 , ..., 0.21993634,
-1.3214744 , -0.37742913],
[-1.208023 , -0.34730363, 1.5509114 , ..., 0.20233685,
-0.5919301 , -1.676856 ]], dtype=float32),
array([[ 0.72231716, 2.0411544 , 0.64721787, ..., 0.8867602 ,
0.5853474 , 1.1322366 ],
[ 0.5405715 , -0.07949123, 2.5334303 , ..., 0.1900696 ,
1.0122919 , 0.9129506 ],
[-0.76834786, -0.5383611 , 1.9698265 , ..., -0.14156751,
1.4371328 , 0.23550907],
...,
[-1.0462339 , -0.381153 , 0.07865799, ..., -0.02572094,
1.4287553 , 1.4662995 ],
[-0.46847972, -1.6375849 , -0.02890162, ..., -0.4548475 ,
-0.22753158, 1.068198 ],
[-0.7973161 , -0.83293337, 0.5233656 , ..., -1.2205366 ,
-0.50174606, 0.82148594]], dtype=float32),
array([[ 0.7800478 , 3.132656 , 0.23101413, ..., 0.97873306,
0.34633186, 0.92795813],
[-0.59249705, 0.90927625, -0.23435067, ..., -0.33739167,
0.26799542, 0.7210118 ],
[-0.6861393 , 0.8599368 , 0.96881336, ..., -0.05941591,
0.3809339 , 0.33317384],
...,
[ 1.0076424 , -0.73971003, 3.8259573 , ..., -1.5729928 ,
1.0229925 , -0.79756206],
[ 0.6981268 , -1.4118763 , 2.4709187 , ..., -0.19410862,
0.32266784, -0.50798607],
[-0.9421238 , -0.5876456 , 3.4149535 , ..., -1.5113941 ,
0.4883281 , -0.26994693]], dtype=float32),
array([[-1.2025571 , 0.34354088, -0.3703537 , ..., 0.04466444,
0.77926636, 0.83967 ],
[-0.8054848 , -0.5786968 , 0.5821522 , ..., -1.5634562 ,
0.50599295, 0.7625565 ],
[-1.3978726 , -1.3693833 , 0.9689234 , ..., -1.6550643 ,
0.58485764, 0.24693903],
...,
[ 0.11268672, -0.65100044, 1.3458916 , ..., -0.832993 ,
-0.93964607, -0.9120335 ],
[ 0.06555071, 0.00493107, 1.4086773 , ..., 0.6832064 ,
-0.3085196 , -1.0140674 ],
[ 0.37850454, 0.5224717 , 0.7134874 , ..., 1.5460573 ,
1.552326 , -0.92729723]], dtype=float32),
array([[-0.3334624 , -0.44613174, 1.3955767 , ..., 1.424159 ,
1.268598 , -1.6505747 ],
[-0.21450005, 0.15614152, 2.2619536 , ..., 1.5478424 ,
1.8480592 , -2.128445 ],
[-0.03812077, 0.41087192, 1.5510945 , ..., 1.1219571 ,
0.9959495 , -1.0443426 ],
...,
[-0.09439173, -0.04308858, -0.25696915, ..., -1.8404281 ,
-0.32975075, -1.8474727 ],
[ 0.31423777, -0.93375635, -0.3335744 , ..., -0.85323423,
-0.12852225, -1.9732889 ],
[ 0.75520056, -0.40575948, 0.5878918 , ..., -0.44272885,
0.02869681, -1.5171266 ]], dtype=float32),
array([[-0.6835824 , -2.2846184 , 0.05166047, ..., -0.46644962,
1.2017698 , -0.8310663 ],
[-1.0168027 , -1.5395677 , 0.18712924, ..., -1.0953001 ,
2.1544745 , 0.093353 ],
[-0.5135986 , -1.4868544 , -0.08674329, ..., -0.6897917 ,
0.86329895, -1.0741868 ],
...,
[-2.5788543 , -1.5068465 , -0.43237442, ..., -0.18201084,
0.6092662 , -2.6498327 ],
[-2.0022871 , -2.6501906 , -0.7398281 , ..., -1.2659264 ,
-0.4985433 , -1.1235405 ],
[-1.4040592 , -2.4961543 , -1.7184815 , ..., -2.266353 ,
-1.7874087 , -1.2829715 ]], dtype=float32),
array([[ 1.867759 , 0.895639 , 3.2640195 , ..., 2.6375704 ,
3.5634499 , 1.3522242 ],
[-0.02084195, -0.5427502 , 1.1755154 , ..., 2.0233092 ,
2.5631552 , 0.12414528],
[ 0.43276536, -1.1231577 , 1.5457008 , ..., 1.1800566 ,
1.732419 , 1.3155035 ],
...,
[ 0.14938615, 0.10253841, 0.04278079, ..., -0.2993945 ,
-0.6355541 , -0.95653796],
[ 0.58024246, 0.11941084, -0.97583884, ..., -1.241724 ,
-0.6122684 , -0.41972613],
[ 0.8883338 , -1.7578603 , -1.557356 , ..., -1.5131032 ,
-0.19171007, -0.19647896]], dtype=float32),
array([[-0.77811486, -0.9200835 , -0.34352145, ..., -0.6952921 ,
-0.7521597 , 1.9314063 ],
[ 0.5550232 , 1.3381473 , -0.8927673 , ..., -1.7807237 ,
-0.62807924, 1.3457881 ],
[ 0.25304422, 1.1528295 , -0.92933375, ..., -1.3401525 ,
0.4693502 , 0.9559242 ],
...,
[-0.3556631 , -1.9135338 , -1.7992202 , ..., -1.4574018 ,
-1.9729819 , -1.4741563 ],
[ 1.9487926 , -1.8567853 , 0.15767601, ..., -0.34574974,
-1.9013805 , -0.6786529 ],
[ 0.00346977, -0.6174029 , -0.9642277 , ..., -0.3983926 ,
-2.1090798 , -1.0053546 ]], dtype=float32),
array([[-1.0270039 , -1.1038812 , 1.573922 , ..., -0.3502419 ,
1.9552547 , 1.2912518 ],
[ 0.8484791 , -1.7185808 , 1.0225459 , ..., -0.2692755 ,
1.5484561 , 0.5177783 ],
[ 1.3761461 , -0.99521327, 1.2287141 , ..., -0.4733405 ,
1.1246719 , -0.19452287],
...,
[ 1.0273731 , 0.75528914, 0.30736643, ..., 0.20737663,
-0.19872794, -2.4543984 ],
[-0.61155725, 0.6530137 , 0.3197632 , ..., -0.29689303,
0.18331446, -0.81114006],
[ 0.20005654, 0.4084283 , -0.12148466, ..., 0.55234385,
-0.5802144 , 1.9682392 ]], dtype=float32),
array([[-1.448251 , -0.01125452, 0.3385937 , ..., 1.4295045 ,
1.7983508 , -1.8930151 ],
[-0.36644465, -0.5887955 , 0.79340625, ..., 0.7806706 ,
0.45290276, -1.4158783 ],
[ 0.34395742, -1.3531532 , 0.3515645 , ..., 0.20838998,
0.4872913 , 0.46105993],
...,
[ 0.4506459 , -1.2507323 , -0.40148908, ..., 0.7829111 ,
0.06765137, -1.5925446 ],
[ 1.0501455 , -1.0905613 , -0.2356204 , ..., 1.3671579 ,
1.1363105 , -1.2247769 ],
[ 0.47133976, -1.049973 , 0.2726253 , ..., 0.92323196,
0.2102331 , -0.12023517]], dtype=float32),
array([[ 1.0103415 , -0.41344845, 0.9106565 , ..., -0.00840607,
1.2159364 , 2.1356564 ],
[-0.21910436, 0.5501856 , 0.18625495, ..., -0.40180248,
0.975526 , 0.44252244],
[ 0.22051595, 0.43748018, -0.35379934, ..., -1.147709 ,
-0.8094875 , -0.47170576],
...,
[ 0.7426917 , 1.0657325 , -0.71631616, ..., -1.1272583 ,
0.9841104 , -0.4929949 ],
[ 1.1352152 , 0.28163752, -1.742591 , ..., -0.13041277,
0.62934875, 0.21630676],
[ 1.4282644 , -0.17693397, -0.5893638 , ..., -0.80752575,
0.24868004, 0.8326239 ]], dtype=float32),
array([[ 1.4243134 , 1.7835143 , 1.0730317 , ..., 3.663663 ,
2.441308 , 1.3440924 ],
[ 1.7629377 , 1.5440209 , -0.04104272, ..., 2.939125 ,
2.4851542 , 1.7577063 ],
[ 1.5338899 , 1.9862826 , 2.982648 , ..., 2.5974607 ,
2.5012465 , 2.331936 ],
...,
[ 0.32035434, -0.23059362, 0.9933931 , ..., 1.6714025 ,
0.5951284 , -0.12671266],
[-0.26783785, -0.6014577 , -0.8756194 , ..., 0.90547603,
-0.06169757, 0.64561355],
[-0.5274673 , 1.0905155 , 0.18820861, ..., 0.6602218 ,
-0.2695024 , -0.61573285]], dtype=float32),
array([[ 0.07116701, -0.58176565, 1.1343398 , ..., -0.0444817 ,
0.8334979 , 0.80214185],
[ 0.17822617, 0.48124403, 0.27836788, ..., 0.47932208,
0.09408113, 0.19702542],
[-0.11212593, -0.02073994, -0.22789441, ..., -0.69997936,
-0.05510452, -0.64190745],
...,
[ 4.561573 , -1.2352515 , 2.6031296 , ..., 1.0717926 ,
-1.795789 , -1.8988005 ],
[ 1.9491427 , -1.5637897 , 1.1409273 , ..., 0.54163784,
-2.763235 , -2.7814791 ],
[-0.12952316, -2.568283 , -0.5843922 , ..., 1.810246 ,
-1.6573855 , -2.9600196 ]], dtype=float32),
array([[ 0.44363758, 2.0200198 , 0.61477435, ..., 2.2702641 ,
0.4738676 , 0.36180753],
[-1.2031467 , 0.9540959 , 0.6864423 , ..., 1.3365746 ,
-1.5202281 , 1.220517 ],
[ 1.3215561 , -1.6087883 , 0.92437583, ..., -0.5479754 ,
-1.8262938 , 0.257921 ],
...,
[ 0.8724717 , 1.3722891 , 2.995218 , ..., 0.48805082,
0.34916353, -1.6166513 ],
[ 0.62951595, -0.5694731 , 2.141211 , ..., -0.8165051 ,
0.40975472, -0.86069655],
[ 0.7812983 , 0.00586072, 0.7503733 , ..., -1.0734117 ,
-0.8022711 , -0.19326715]], dtype=float32),
array([[ 0.45909262, -0.9288828 , 0.9834809 , ..., -1.8920888 ,
0.89440894, -2.2063057 ],
[-0.66827196, -0.56711906, 0.36493677, ..., -1.8900766 ,
0.02922901, 1.8881873 ],
[-0.6038952 , -0.9929944 , 0.12108629, ..., -1.293284 ,
0.11995904, 2.9038382 ],
...,
[ 0.88740504, -1.8897728 , 0.31835392, ..., 0.326594 ,
-0.1814878 , -0.21144027],
[ 0.37182158, -0.9910129 , 1.8728042 , ..., 0.43937334,
-0.33394113, 0.5549615 ],
[ 0.7015578 , -1.0165309 , 2.625502 , ..., 0.61384165,
0.15924846, 1.1077385 ]], dtype=float32),
array([[-1.445636 , -0.47685033, -0.0359103 , ..., 0.42783335,
-0.18719782, -1.4203061 ],
[-0.1197788 , -1.4826258 , -0.94080526, ..., -0.09894625,
1.2305866 , -0.04428663],
[-0.2927531 , -1.6145829 , 0.1444815 , ..., -0.20200826,
-0.2056315 , -0.9773835 ],
...,
[-0.45386225, -1.6171087 , 0.65803313, ..., 0.27761266,
-1.4213476 , -1.0970969 ],
[ 0.50741845, -2.1491764 , 0.22442092, ..., -0.1389223 ,
-0.7737223 , -0.34079698],
[-0.96836364, -2.639437 , -1.6047003 , ..., 0.26992002,
0.09421141, -0.20358157]], dtype=float32),
array([[ 0.23422147, -0.3389358 , -0.58439875, ..., 1.1142879 ,
1.9510994 , -0.04823534],
[ 0.60265887, -1.4671867 , -0.5701921 , ..., 1.5484467 ,
2.4551904 , 0.10565981],
[-0.23950669, -1.6458807 , -0.59168005, ..., 0.21907525,
2.470652 , -1.6226238 ],
...,
[ 0.4606507 , 1.2052448 , -0.22534084, ..., -0.19750029,
0.51878285, 0.63734955],
[-0.22257969, -0.7032121 , 0.02125443, ..., 0.09797084,
0.31502938, -0.41384602],
[ 1.8017273 , -1.4173303 , 0.68160385, ..., 0.62184507,
0.7271969 , 0.35646453]], dtype=float32),
array([[-1.1859248 , -0.18632086, 1.9547614 , ..., 1.9819579 ,
3.2560213 , 1.9674898 ],
[-0.7281422 , 0.95233846, 2.2494414 , ..., 1.8653177 ,
2.4019938 , 0.12762061],
[-0.2444206 , 0.7124917 , -0.43430492, ..., 1.8107275 ,
1.2571485 , -0.17886066],
...,
[ 0.2733662 , 0.5716106 , 0.7734418 , ..., 0.52637357,
1.0713935 , 0.69215894],
[ 0.47145545, 1.449949 , 0.09331153, ..., 0.45816913,
1.195703 , -0.35216793],
[ 0.04235786, 1.189377 , -0.04827493, ..., 0.28875902,
0.5250594 , -0.5773043 ]], dtype=float32),
array([[-0.16587585, -1.4423214 , 1.4637448 , ..., 0.9226769 ,
1.5176648 , 0.7341115 ],
[-1.5930936 , -1.4661456 , 0.6561149 , ..., -0.00714181,
1.1513389 , 0.2706878 ],
[-0.21893737, -1.3087385 , 0.02765398, ..., 0.08472531,
-0.06792852, -0.7201214 ],
...,
[ 0.6350919 , -2.039617 , -1.3034688 , ..., -0.9554744 ,
-1.599948 , 0.08479304],
[ 0.07492974, -0.5986934 , -1.1050867 , ..., -0.11557823,
-0.28731412, -0.31918246],
[ 0.48471937, -0.06257073, -1.2974975 , ..., -1.0962993 ,
-0.4040572 , 0.21458103]], dtype=float32),
array([[-0.7708095 , -0.32798892, -0.85267 , ..., -1.1566638 ,
-0.99381095, -1.2310058 ],
[-0.86885613, -0.92307526, -1.29594 , ..., -2.0624325 ,
-1.907441 , -2.9969432 ],
[-1.2015699 , -1.9176369 , 0.12706709, ..., 0.51506096,
0.00390473, -2.2483337 ],
...,
[-0.14983118, -2.574254 , -0.1847633 , ..., -0.6045526 ,
-0.77673376, -0.15773112],
[-0.10941257, -2.8653822 , -1.3854533 , ..., -0.80016965,
-0.08330945, -1.2884191 ],
[-0.08558089, -2.2401795 , -1.3387018 , ..., 0.4021217 ,
0.42515507, -1.7090813 ]], dtype=float32),
array([[-1.5491508e+00, -1.3155830e+00, -1.7421651e-01, ...,
-9.3138868e-01, -2.2747356e-01, -2.0232930e+00],
[-7.4754161e-01, -1.2739531e+00, -2.1347232e-01, ...,
2.8841010e-01, -4.3937716e-01, -1.1703664e+00],
[-1.9709258e+00, -6.2096602e-01, 5.0808877e-01, ...,
1.1404451e+00, -3.5375515e-01, -1.1222769e+00],
...,
[ 1.1877130e+00, 1.1287568e+00, 2.9174702e+00, ...,
1.0410763e+00, -1.2480073e-01, -3.0632731e-01],
[ 4.2065775e-01, 4.5900089e-01, 7.2773045e-01, ...,
8.9997667e-01, 2.5268352e-01, -1.8355316e-01],
[-3.1194031e-01, -6.4204484e-01, -6.6993916e-01, ...,
1.2414803e-03, -3.6406729e-01, -1.5995297e+00]], dtype=float32),
array([[ 1.0591333 , -1.0933094 , 0.30739665, ..., 0.71420574,
0.04606887, 1.1111258 ],
[ 0.6190886 , -1.1281767 , 0.0722167 , ..., 1.5855699 ,
0.37566823, 1.5974752 ],
[ 0.47711286, -0.01163641, 0.70723087, ..., 1.0262892 ,
0.55962074, 0.6303274 ],
...,
[ 0.17514986, 1.5385402 , -0.37254462, ..., 0.9020769 ,
0.5915993 , 0.33597535],
[-0.08376978, 0.9680634 , -0.11065453, ..., 0.53146577,
0.37930664, 0.11124872],
[ 0.11484523, 1.0426683 , 0.34150457, ..., 0.5409239 ,
0.7846303 , 0.5361923 ]], dtype=float32),
array([[ 0.32999477, 0.17257422, -0.02576547, ..., 0.68836564,
1.3876315 , 0.381285 ],
[-0.44304574, 0.9351039 , -0.14362532, ..., 0.45813897,
0.6335701 , 1.1349803 ],
[ 0.29613248, -0.67812103, 0.6472177 , ..., 0.70507175,
1.093124 , -0.42788172],
...,
[-0.6115895 , -0.55199385, -0.17171958, ..., -1.7289279 ,
-1.2144066 , -0.26242262],
[-0.5806329 , -1.2928493 , -0.32571778, ..., -1.3317893 ,
-0.5357137 , -0.1335874 ],
[-1.3655157 , -0.7823968 , 0.3385153 , ..., -0.82651436,
-0.42906252, -0.47424763]], dtype=float32),
array([[-3.2878976 , -1.0093194 , 0.22419819, ..., -2.8724759 ,
-0.02020989, -1.97309 ],
[-4.287771 , -2.4031138 , 1.3309505 , ..., -4.794849 ,
-0.64743924, -2.8073277 ],
[-2.8471644 , -2.7788126 , 1.8292081 , ..., -3.0492985 ,
-1.2713096 , -1.8178537 ],
...,
[-0.26133814, -0.6291796 , 1.4717265 , ..., 1.4553648 ,
-0.12694861, -0.97592974],
[ 0.8423908 , -0.5905986 , 0.09132817, ..., 0.01582366,
-2.2775512 , -1.9327923 ],
[-0.04958614, -0.22488853, 0.23246817, ..., -0.169459 ,
-0.6085778 , -1.543017 ]], dtype=float32),
array([[-0.06682424, -0.7450508 , -1.6548545 , ..., -0.45005533,
1.4599041 , 1.4350084 ],
[-0.49487555, 0.5681783 , -1.6785839 , ..., 0.03632116,
1.0407704 , 1.7117581 ],
[-0.27497712, 0.8913527 , -0.09413183, ..., 0.15748712,
0.712124 , 0.8629844 ],
...,
[-2.0785584 , -1.2359517 , 1.0292501 , ..., 1.7519091 ,
-0.1612157 , 1.1500338 ],
[-1.5251975 , -0.32918286, 0.5681826 , ..., 0.87714976,
-1.0111082 , -0.24764286],
[-0.36873695, -0.84597665, 0.89754236, ..., 1.5136204 ,
-0.19637768, -1.0313793 ]], dtype=float32),
array([[ 0.43625307, -1.0653203 , 0.32761878, ..., 0.06171988,
0.5738824 , 1.2062526 ],
[ 0.80085397, -0.7327369 , 1.1362983 , ..., 0.01132748,
0.79442656, 0.12309124],
[-0.88307744, -0.35869998, 0.43704078, ..., 0.6499472 ,
1.1098062 , 1.7723241 ],
...,
[ 1.6927369 , 0.6164818 , -1.0603409 , ..., -2.3887687 ,
-1.9102426 , -1.25551 ],
[ 1.2789361 , 1.5496958 , -1.2726898 , ..., -1.9889568 ,
-0.92565805, -1.1940223 ],
[-1.0755777 , 2.308003 , 0.32872215, ..., -2.3273242 ,
-2.5216153 , -1.3858548 ]], dtype=float32),
array([[ 0.39887545, -1.0930575 , 1.1025922 , ..., -0.945704 ,
0.4691077 , -0.89525473],
[-0.9064785 , -1.1409211 , -0.17285275, ..., -0.62646073,
0.693705 , -0.714968 ],
[-0.7510879 , -0.07583265, 0.7343275 , ..., -1.2140878 ,
-0.1468427 , -0.9280092 ],
...,
[ 0.2387795 , -0.4184364 , 0.40356135, ..., -0.7301674 ,
-1.1834645 , -1.6871489 ],
[-1.3272573 , -1.6782094 , -0.24517457, ..., -1.3664061 ,
0.40816742, -1.7362688 ],
[-0.8881753 , -1.1718726 , -1.5788368 , ..., -1.413439 ,
0.2355162 , -1.1996605 ]], dtype=float32),
array([[-0.78184754, -0.66809845, 1.0991846 , ..., -0.5456283 ,
1.3076684 , -2.0496268 ],
[-0.75256187, -0.49808884, 0.02566541, ..., -0.49812725,
1.5600789 , -1.3404057 ],
[-1.3734721 , 0.8596513 , 0.5717674 , ..., -0.92654634,
-0.64497304, -1.1436597 ],
...,
[-1.1318356 , 0.50034285, 0.7640534 , ..., 0.02538684,
-1.0470244 , 0.73771167],
[-1.0042878 , -1.1880598 , -0.6098851 , ..., 0.01396907,
-1.1241004 , 0.47160292],
[-0.7562368 , -0.82467604, -0.904065 , ..., -0.29229823,
-0.7516627 , 1.3053488 ]], dtype=float32),
array([[ 1.5532098e+00, -1.2609082e-01, 4.4478283e+00, ...,
2.7715037e+00, 2.8488500e+00, 1.9007223e+00],
[ 1.4401282e+00, -2.3541582e-01, 2.9139445e+00, ...,
1.6085638e+00, 2.0108342e+00, 1.1925684e+00],
[ 1.3794278e+00, 2.3805085e-01, 2.0041530e+00, ...,
1.1916746e+00, 1.6723529e+00, -4.5561618e-01],
...,
[-1.4086071e+00, -3.7036579e-02, -1.5715460e-01, ...,
-2.6526126e-01, -7.8287417e-01, -1.3870548e+00],
[ 7.5651580e-01, 9.5143974e-01, 3.1172398e-01, ...,
-3.1133401e-01, -1.5040870e+00, 3.8755059e-01],
[-8.5875058e-01, 1.3504355e-01, 1.6244681e+00, ...,
-1.6376062e-03, -1.0844401e+00, -9.1610610e-01]], dtype=float32),
array([[ 0.79215837, -1.1004183 , -1.1203923 , ..., -0.65233743,
0.8096089 , -1.6573611 ],
[ 0.2347323 , -1.839019 , -0.40912446, ..., -0.841489 ,
0.5148645 , -1.6523585 ],
[-0.60679835, -0.4185905 , 0.00372098, ..., -0.7088542 ,
1.151269 , -1.9017191 ],
...,
[ 1.4979622 , 0.65091604, 0.48702532, ..., 0.7613225 ,
0.9285922 , -0.44892657],
[ 1.2428638 , -0.26166725, 1.935121 , ..., 1.0377291 ,
0.5290513 , -0.2383192 ],
[ 0.9323369 , -0.40171504, 1.933029 , ..., 0.7396882 ,
0.978378 , 1.2135247 ]], dtype=float32),
array([[-0.9959006 , -1.7803189 , -0.36661854, ..., -0.16546819,
0.57648313, -1.438934 ],
[-0.31464407, -0.96728307, -0.56891316, ..., 0.82861334,
0.33146006, -1.3109674 ],
[ 0.42343858, -1.0391936 , -0.2449314 , ..., 1.0854808 ,
0.69901985, -0.48661834],
...,
[-0.32852378, 0.01202542, -1.5627251 , ..., -0.00839812,
-0.00737538, -0.6431352 ],
[-0.4230666 , 0.05414339, -1.4589509 , ..., 0.5303711 ,
0.74116755, 0.00349341],
[ 0.16117789, -0.20650576, -1.1544298 , ..., 0.20297928,
0.9782806 , -0.15210915]], dtype=float32),
array([[-0.5704387 , -1.5939155 , 0.7297422 , ..., -0.959838 ,
-0.15605944, 0.7989917 ],
[-0.14702363, -0.89761394, 0.32259792, ..., -1.6322302 ,
0.6886961 , 0.1893368 ],
[ 1.4781657 , -1.1084502 , 1.3548198 , ..., -1.1684914 ,
-0.19130848, 0.34969342],
...,
[ 1.1371639 , -0.8571445 , 2.95739 , ..., -0.83923125,
0.15402725, 1.702003 ],
[ 1.3653564 , -0.21177517, 2.2497864 , ..., -0.56261504,
1.6077832 , 1.5598679 ],
[ 1.0206444 , -1.5604558 , -0.34618282, ..., -0.55532706,
1.1021457 , 2.061943 ]], dtype=float32),
array([[-0.04189857, -0.45419127, 0.5038306 , ..., -2.1358204 ,
-0.22671466, -1.1676046 ],
[ 0.19857867, 0.05930965, 1.4053423 , ..., -2.256914 ,
-0.51986444, -1.5787822 ],
[-0.18852164, 0.19776241, 1.919334 , ..., -2.3786063 ,
-0.36360353, 0.07009408],
...,
[ 0.5441573 , -1.9301578 , 0.22470291, ..., 1.6169136 ,
-2.542941 , -2.3356535 ],
[ 0.9535341 , -0.6420601 , 0.64187 , ..., -0.12589715,
0.5131392 , -0.58408475],
[-0.35644665, -1.3551066 , -0.34261107, ..., 0.42267907,
0.93034244, 0.08801711]], dtype=float32),
array([[-0.54738235, 0.6647873 , -0.5131834 , ..., 0.4809246 ,
0.06636852, 0.9589999 ],
[-2.0064945 , 0.5921879 , -0.578285 , ..., -0.2334296 ,
-0.33682585, 0.5360114 ],
[-0.68690205, -0.97649264, 0.69919753, ..., -0.82209563,
-0.21980576, 0.7681972 ],
...,
[ 0.48113447, 0.68538195, -2.4631793 , ..., -0.86488116,
0.396683 , -0.9689204 ],
[ 0.6526906 , 0.25136545, -1.6561006 , ..., -0.95207834,
-0.04475881, -0.55650294],
[ 0.5835094 , -0.7925178 , -0.8076192 , ..., 0.35120338,
-0.36139324, -1.1275165 ]], dtype=float32),
array([[-0.29325423, -1.738501 , 2.013726 , ..., -0.15782206,
1.0007862 , -0.08426487],
[-0.31451166, -1.1086086 , -0.6015891 , ..., -1.6538444 ,
1.2664276 , -0.41151482],
[-0.14640315, -1.528364 , -0.27823493, ..., -2.96766 ,
-0.6710702 , -0.26081905],
...,
[ 1.5552136 , -3.0834687 , -1.0444156 , ..., 2.0384505 ,
0.9113593 , 0.01388774],
[ 0.51015544, -0.7104525 , 0.02504521, ..., 1.321757 ,
1.516073 , -0.81254166],
[-0.6420603 , -2.4378457 , -1.8099562 , ..., 1.328097 ,
0.00346263, -0.58896065]], dtype=float32),
array([[-1.0339715 , 0.36404255, -0.14562291, ..., 0.14448814,
-0.23626433, -0.3327841 ],
[-0.4210608 , 0.07954335, -0.2199393 , ..., -0.63992584,
-0.5713273 , -0.2617176 ],
[ 0.06390741, 0.868607 , 0.06235068, ..., 0.3670007 ,
-0.59217685, -0.06872906],
...,
[ 0.75372064, 2.2635887 , 1.0658425 , ..., -0.80361617,
-0.28154993, 1.2821828 ],
[ 0.9725676 , 2.814391 , -0.5106799 , ..., -0.5021383 ,
-0.01826464, 1.3839499 ],
[ 0.05977002, 1.8937707 , -0.02506769, ..., -0.8350715 ,
0.8026098 , 1.854034 ]], dtype=float32),
array([[ 0.14218824, -3.0595288 , 0.13038401, ..., 1.0031396 ,
2.372657 , 0.4960793 ],
[ 0.5235753 , -1.9900875 , 0.7336647 , ..., 0.09692469,
2.0772285 , -0.2218642 ],
[-0.46599957, -0.2268267 , 1.1412253 , ..., 0.28842068,
2.6090643 , 1.4712384 ],
...,
[ 0.58618623, 1.4435111 , -1.6332911 , ..., -1.0940357 ,
-3.6311698 , -0.54702586],
[ 0.03595666, 0.6104246 , -0.39879915, ..., -0.7400897 ,
-1.7987903 , -2.4070709 ],
[ 0.30727518, -2.3425503 , -0.39909616, ..., 0.3447141 ,
-1.6388785 , -0.8594849 ]], dtype=float32),
array([[-1.4274064 , -1.1852942 , -1.502372 , ..., 0.5582581 ,
0.22196048, 0.21363759],
[-0.30535066, 0.40650293, -0.84408784, ..., -1.3867826 ,
-0.23146945, -1.255213 ],
[-1.2363107 , -0.5076576 , -0.5571766 , ..., -1.86623 ,
0.31351528, -0.9488801 ],
...,
[-0.5844473 , -0.52946806, -1.800047 , ..., -1.2958642 ,
-0.7872089 , -0.19694391],
[ 0.46663302, 0.49937984, -2.5587661 , ..., -0.1305863 ,
-2.2433736 , -1.2265661 ],
[ 0.85294825, 1.3413651 , -2.6066713 , ..., -0.7810436 ,
-1.4702088 , -1.0256575 ]], dtype=float32),
array([[-0.40019286, 0.4615244 , -0.97423935, ..., -0.6545409 ,
3.1390095 , 2.8693912 ],
[-0.5539683 , 0.13449182, -0.24831504, ..., 0.723724 ,
1.3750082 , -0.05414042],
[ 0.57483727, 0.43146375, 0.977092 , ..., 1.7639434 ,
1.4801533 , -0.92843616],
...,
[ 0.46529475, 0.20824452, -1.0354674 , ..., 0.7264086 ,
0.2341511 , -0.61342466],
[ 0.5597321 , 0.8675183 , -0.49698803, ..., 1.6020817 ,
0.6812541 , -0.377476 ],
[ 0.32346538, 1.0856463 , 1.7805008 , ..., 0.73943007,
1.2646255 , 1.2305567 ]], dtype=float32),
array([[ 0.7620444 , 0.06814841, 0.28451237, ..., 1.285873 ,
-0.8461693 , 0.3389991 ],
[-0.952944 , 0.15015273, -0.8760661 , ..., 1.6883844 ,
-1.5667839 , -0.22842008],
[-1.3123825 , -1.690485 , 0.07148658, ..., 1.670351 ,
-1.3378259 , -0.595432 ],
...,
[ 0.16657706, -0.06233083, 0.11481489, ..., 0.47250667,
0.17968287, 1.1170496 ],
[-1.3420559 , -0.67395073, 0.45560434, ..., -0.750643 ,
0.2517233 , 1.5413102 ],
[-0.98822725, -1.3972563 , 1.517198 , ..., 0.09105795,
0.84782225, 1.0729043 ]], dtype=float32),
array([[-1.8005044 , 1.2287303 , 0.2520265 , ..., -0.5112129 ,
0.6030869 , 0.7739403 ],
[-1.4606725 , 0.8246631 , 0.19874103, ..., -0.6129102 ,
0.77025235, 0.9394402 ],
[-1.981871 , -0.20805584, 0.27382705, ..., -0.22015984,
-0.3488801 , 1.6250648 ],
...,
[-1.165392 , 0.46012905, -1.1047233 , ..., 0.5639135 ,
-1.5824699 , 0.91740584],
[ 0.67085785, 1.3503156 , 0.7259721 , ..., 0.6526452 ,
-1.6861752 , 0.17239597],
[ 1.4908121 , 2.414607 , -0.7551853 , ..., 0.6892942 ,
-1.7651825 , -0.10332297]], dtype=float32),
array([[-1.7857778 , 0.36721593, -0.02702453, ..., 2.6162112 ,
2.0041494 , -0.02836817],
[-0.41320434, 0.33003983, 0.9490504 , ..., 1.0821718 ,
1.3724556 , 0.19075944],
[ 0.78839755, -0.31269932, -0.3781773 , ..., 1.3988172 ,
0.5166411 , 0.8652227 ],
...,
[-0.13094012, 1.7268811 , 1.591182 , ..., -1.0683663 ,
0.8155291 , 1.1416723 ],
[ 0.7840855 , 0.92385775, 1.2571265 , ..., -0.74688566,
0.22843114, 2.205926 ],
[ 0.6456897 , -0.43328917, 1.335713 , ..., 0.553666 ,
-0.12462845, 1.3498095 ]], dtype=float32),
array([[ 1.4373641 , -0.199083 , -0.21279572, ..., -0.04790843,
-0.11333801, 0.51373476],
[-0.1540929 , -1.3630308 , 0.8088984 , ..., -0.5726548 ,
0.21436054, 0.7036518 ],
[-0.7409385 , -1.0343765 , -0.21932252, ..., -0.3237797 ,
-1.0558759 , -0.6978053 ],
...,
[ 0.97941065, 1.7633791 , 0.62113553, ..., -0.7899649 ,
-0.65325487, 2.0141904 ],
[-0.34951216, -0.95798266, -0.77140963, ..., -0.29763356,
-0.03839348, 0.4844508 ],
[ 0.9056462 , 0.38018695, -1.003405 , ..., -0.07926434,
0.7519348 , -0.07360731]], dtype=float32),
array([[-1.3620564e+00, 1.7073003e+00, -5.9239751e-01, ...,
2.2473450e+00, 6.9259906e-01, 5.3261352e-01],
[-4.4644859e-01, 3.0766159e-01, 5.9008032e-01, ...,
1.3289802e+00, -1.2532604e+00, -6.4198583e-01],
[-1.4084323e-01, -4.0418848e-01, -2.2852639e-04, ...,
6.6605794e-01, -1.2222325e+00, -9.9840099e-01],
...,
[-1.4528941e+00, -4.4937816e-01, 1.6831405e+00, ...,
-1.5571921e+00, -6.2371039e-01, -1.0030826e+00],
[ 1.4180024e-01, 7.9057229e-01, 5.4245996e-01, ...,
7.1623510e-01, -2.8748973e-03, 4.9939498e-01],
[ 7.5857919e-01, 1.1295128e+00, -1.0879661e+00, ...,
5.9378505e-01, 1.7247069e+00, 1.7510939e+00]], dtype=float32),
array([[-0.6009104 , 0.88176304, -0.370863 , ..., 3.2098067 ,
4.311326 , 0.15457271],
[ 0.31479412, 0.5228121 , -0.695277 , ..., 1.3041004 ,
2.3229368 , 0.0157322 ],
[ 0.50901383, -0.00468721, -1.0052313 , ..., -0.49730083,
0.8754164 , 0.7254645 ],
...,
[ 0.4223865 , -0.21710104, -0.580631 , ..., 0.01912983,
-0.63522375, -0.06591247],
[ 1.4888117 , 0.14193808, -0.19013193, ..., 0.5433171 ,
0.26996073, 0.5143779 ],
[ 1.3614957 , 2.0627956 , -0.02327744, ..., 0.7091169 ,
1.0703529 , 1.8377218 ]], dtype=float32),
array([[-0.36216295, -0.7457695 , -0.60783947, ..., -0.28201413,
1.1857277 , -0.74537456],
[-1.8490089 , 0.4505212 , -0.3948908 , ..., 0.07288259,
-0.78744596, -0.61772066],
[-1.7758862 , -1.1223505 , 2.0971038 , ..., 0.01498946,
-1.157456 , 0.8963272 ],
...,
[ 0.85573024, -0.24300739, 0.10429019, ..., -0.48299256,
0.79125917, -0.06332838],
[-0.33790645, -0.65815943, 0.0399457 , ..., -0.40312666,
0.24290636, -1.0424632 ],
[-0.41348627, -1.0610628 , -1.777799 , ..., -0.12458837,
-0.8268474 , -0.7806824 ]], dtype=float32),
array([[-0.71202296, -0.7708182 , 0.81068975, ..., -2.3533902 ,
4.185735 , -1.0887464 ],
[-0.2802166 , -0.4874427 , 2.0640337 , ..., -2.0572314 ,
2.315774 , -1.7450057 ],
[-0.77575344, -0.10643723, 0.98800886, ..., -1.6587806 ,
-0.92819375, -1.2632191 ],
...,
[ 0.84605366, 0.38117072, -0.04029485, ..., 0.16823398,
0.88381886, 0.8750882 ],
[ 1.3953687 , 0.14596726, 0.5367909 , ..., 0.17492154,
1.5820736 , -0.3292208 ],
[ 1.945629 , 1.9713994 , 1.2426908 , ..., 0.81192243,
1.3970611 , 0.23734656]], dtype=float32),
array([[-1.9232132 , 0.498199 , 0.07558472, ..., 0.46851227,
-0.96997905, 0.33870122],
[-0.8948249 , 0.26086694, 0.02259908, ..., 0.7053921 ,
-0.7493197 , 1.3485334 ],
[-0.1262711 , 0.6305783 , 0.17741719, ..., -0.13012694,
0.05185516, 0.4224629 ],
...,
[ 0.7475501 , 0.10586275, -0.9389211 , ..., 0.3674196 ,
-1.7311223 , -0.20573659],
[ 1.8549091 , 0.7366515 , -1.1474788 , ..., 0.6442668 ,
-1.1243709 , -0.34540278],
[ 1.662965 , -0.29895243, -1.3454841 , ..., 0.40709537,
-0.16936733, 0.36849567]], dtype=float32),
array([[-1.5126151e-01, 3.4122273e-01, 4.4267920e-01, ...,
2.2358716e+00, 1.8731557e-01, 2.5265351e-01],
[-4.3891105e-01, 1.7657304e-01, -6.7755717e-01, ...,
4.4318351e-01, -8.2860969e-04, 1.3240874e+00],
[-2.2245589e-01, 1.2124676e+00, -2.0245543e+00, ...,
2.3109566e-01, -3.5134187e-01, 7.3542453e-02],
...,
[-5.0388537e-02, 1.3121417e+00, 9.5730826e-02, ...,
-3.3369026e-01, -5.9811842e-01, -1.4677629e+00],
[ 1.1337873e+00, 9.5711464e-01, 1.7050190e-01, ...,
-7.8357197e-02, -4.5233524e-01, -1.4869859e+00],
[ 2.9254609e-01, 1.7527424e+00, 5.7795465e-01, ...,
6.2730414e-01, 3.4370193e-01, -5.3772265e-01]], dtype=float32),
array([[-0.96034026, 0.55499035, -0.50742984, ..., -0.37423614,
0.9574735 , -0.7179552 ],
[ 0.02273548, 0.5177982 , -0.4796871 , ..., -0.21972264,
0.73399043, 0.36465177],
[-0.1724504 , -0.02398042, -1.1862627 , ..., -0.15734087,
-1.09693 , 1.5157362 ],
...,
[-0.46036008, -2.0399332 , 0.9604358 , ..., -0.62199324,
-2.7096112 , 0.9170393 ],
[ 0.22765781, -1.9718229 , -0.42595106, ..., 0.37672922,
-1.8096159 , 0.8868221 ],
[-0.68851894, -1.3221353 , 1.9621247 , ..., -0.804514 ,
-0.34452176, 0.7434635 ]], dtype=float32),
array([[ 1.5710616 , 0.07801455, 0.41611627, ..., 1.4817048 ,
-0.09099045, 0.12375545],
[ 0.88723445, 1.1664042 , -0.07903641, ..., -0.16773216,
-0.7335596 , 0.26899582],
[ 0.259718 , 0.41769257, -0.28391147, ..., -1.4788716 ,
0.11251095, -0.16705881],
...,
[-0.41228968, -0.46688542, 0.21496277, ..., 0.22336198,
-0.00701263, -1.0781568 ],
[ 0.337662 , -1.8904454 , -0.9687175 , ..., -1.8237107 ,
-0.15105557, 0.7757551 ],
[-1.0823382 , -0.6630194 , 0.1898421 , ..., -0.4315493 ,
-0.32186317, 0.4785384 ]], dtype=float32),
array([[-0.6285118 , -0.11564266, 0.9082326 , ..., 0.39805496,
1.1160431 , -0.67432064],
[-1.5920782 , -0.2177919 , -1.4333123 , ..., -0.9587067 ,
0.9249889 , 0.14278325],
[-0.10731111, 0.15846638, -0.71284664, ..., -1.783296 ,
-0.02258285, 1.3773257 ],
...,
[-0.8691797 , 0.02303608, 0.3024959 , ..., -0.50078595,
0.89177465, -0.3606679 ],
[-1.832116 , 0.00351492, 1.350111 , ..., -0.38694444,
0.61222684, 0.20829777],
[-1.2281251 , -0.00370791, 0.13286756, ..., 0.05842127,
-0.94591165, 1.2091058 ]], dtype=float32),
array([[-0.9655409 , -0.05227269, -1.0196103 , ..., 0.9947163 ,
-0.51921123, 0.82248855],
[-1.5359533 , 0.04440498, -1.7895106 , ..., -0.11922174,
-2.0318766 , 0.744589 ],
[ 1.300935 , 0.58358705, -0.59209013, ..., 1.0257448 ,
-1.8015596 , -1.1979644 ],
...,
[ 1.4917141 , -0.851112 , -0.9688963 , ..., -1.5727042 ,
1.200741 , -2.7299678 ],
[ 0.41484866, -1.6983846 , -0.40206236, ..., -1.0156233 ,
0.6768684 , -2.4099743 ],
[-0.82474107, -0.7620385 , -1.0098392 , ..., -0.10178828,
1.7094315 , -0.5448372 ]], dtype=float32),
array([[ 0.7043974 , -1.3908719 , -0.13791052, ..., 1.6042407 ,
0.3450043 , 0.24852975],
[-1.0914711 , 0.9251265 , -1.2240915 , ..., 0.5604905 ,
0.05854373, 0.56733835],
[ 0.31803527, 0.34829837, 0.07104467, ..., 0.04919111,
0.5522023 , -0.26076892],
...,
[-0.07224645, 0.86409605, -0.23442908, ..., 0.13059261,
0.37195703, -1.4635687 ],
[ 1.4573953 , 0.76064694, -0.9871069 , ..., -0.31339416,
0.6815114 , 0.07735016],
[ 0.60681427, 1.7401016 , 0.36596134, ..., -0.6619155 ,
1.1647462 , -0.83906615]], dtype=float32),
array([[ 1.5894322 , 2.7202218 , 0.1218577 , ..., -0.23556225,
-0.50729257, 0.73113227],
[ 1.3954283 , 1.6451889 , -0.06269962, ..., -0.05320759,
-0.6850412 , 0.6476829 ],
[ 0.2768865 , 0.09243655, -1.0786585 , ..., -1.1308012 ,
-1.5683885 , 0.31116965],
...,
[ 0.6302334 , -0.3064984 , -1.3492873 , ..., -0.24108389,
-0.17565824, 2.1384063 ],
[ 1.1778773 , 0.10948468, -1.4119192 , ..., -1.467874 ,
-0.58983374, 1.1457886 ],
[ 0.6818447 , 0.21290563, -1.5722479 , ..., -1.1884314 ,
-1.1158953 , 1.0943506 ]], dtype=float32),
array([[ 0.7150891 , -0.7743385 , 0.54834557, ..., -1.7832971 ,
1.505931 , -0.9581622 ],
[ 0.56074697, -1.6751571 , -0.44713137, ..., -1.4667529 ,
0.06393412, 0.06538063],
[-0.01548599, 0.04304687, 0.15539359, ..., -0.17411892,
-0.896984 , 0.9283492 ],
...,
[ 0.12855518, 1.6819376 , -1.7276587 , ..., -0.3545797 ,
0.41282293, 0.46087 ],
[-0.15429746, 1.1885437 , -1.8860674 , ..., -0.46937183,
-1.1990105 , -0.53341305],
[ 0.867391 , 0.4530301 , -0.1854979 , ..., 0.5104297 ,
-0.7754563 , -0.2394437 ]], dtype=float32),
array([[-1.1938875 , -0.9428353 , 0.5730548 , ..., 1.4579664 ,
2.415001 , 1.4038366 ],
[-1.629066 , -0.9582286 , -0.12304766, ..., 0.4678896 ,
2.2875278 , 0.05082373],
[-2.5807798 , -1.6165544 , -0.15410781, ..., 0.07524679,
0.47260478, -1.1483669 ],
...,
[ 0.38921326, -1.4176033 , 2.3688397 , ..., 0.1909315 ,
-2.8385193 , 0.8172395 ],
[-0.5406923 , -2.2546632 , 2.2033484 , ..., 0.21527018,
-1.3990532 , 1.3382453 ],
[-0.12835763, -0.38702407, 1.800445 , ..., 0.06740593,
0.7596922 , -0.26278487]], dtype=float32),
array([[-0.9637298 , -1.0319778 , -2.2049983 , ..., 0.24380389,
0.09778386, -0.28984565],
[-0.23240574, -0.16470763, -0.504038 , ..., 0.15055881,
-0.26316205, -0.3848152 ],
[-0.05418091, -0.02968691, -0.9996463 , ..., -0.23420362,
0.02341123, 0.8881736 ],
...,
[ 1.5455061 , 0.76468706, -0.34930238, ..., -0.9101983 ,
0.08369046, 0.03626812],
[ 2.3967474 , -1.6390218 , -1.7088296 , ..., -1.5922892 ,
-0.79856074, -0.6634208 ],
[ 1.4501913 , -1.8390362 , -1.9383495 , ..., -0.6785707 ,
0.13619459, -0.2905436 ]], dtype=float32),
array([[-2.37169981e-01, -4.61918592e-01, 7.12140083e-01, ...,
-3.20887208e-01, 2.00235677e+00, -7.63016999e-01],
[-8.89117301e-01, -2.46725726e+00, -1.15555441e+00, ...,
-1.09757507e+00, 1.33609688e+00, -9.55275655e-01],
[ 2.07814202e-03, 7.20027229e-03, 5.98351955e-01, ...,
-1.15386456e-01, 1.23799169e+00, 9.69858170e-01],
...,
[-1.19975746e+00, -9.44652736e-01, 5.82578301e-01, ...,
7.86782861e-01, 1.18735683e+00, -3.64439413e-02],
[-6.47413194e-01, 4.58770059e-02, -1.29868770e+00, ...,
-5.70921540e-01, 4.99320358e-01, -1.25548100e+00],
[-7.49845803e-01, 8.00440073e-01, -1.08877981e+00, ...,
-4.98052388e-01, 2.63794005e-01, -7.69753754e-01]], dtype=float32),
array([[ 0.86806744, -1.0173975 , 0.6467372 , ..., 0.595801 ,
0.36251476, 0.211552 ],
[ 0.16553123, -1.3126917 , -0.5816441 , ..., 0.3068082 ,
-0.49591255, -0.88967377],
[-0.631537 , -1.0132114 , -0.6392088 , ..., 0.27082324,
-0.81286544, -0.5230837 ],
...,
[ 2.0843801 , -0.84821725, -1.7011327 , ..., 0.76037514,
-1.081201 , 0.09923442],
[ 2.32483 , -0.61033106, -1.3304464 , ..., 1.3402687 ,
-0.36242872, -0.71662194],
[ 2.190384 , -1.1361952 , -0.40771583, ..., 0.65384215,
0.9187543 , 1.2808316 ]], dtype=float32),
array([[-2.6726441e+00, 2.1104591e+00, -3.3242777e-02, ...,
1.1738220e+00, 1.7900871e+00, -1.7605627e-01],
[-1.6688141e+00, 6.3067764e-01, 8.9320809e-02, ...,
8.9328760e-01, -8.0327618e-01, -1.0757042e+00],
[ 7.3963577e-01, -9.3286502e-01, 6.3641392e-02, ...,
-4.6201849e-01, -1.9122918e+00, -2.3211915e+00],
...,
[ 6.8330318e-01, -1.8736920e+00, 4.5698938e-01, ...,
3.3091739e-01, -3.0935517e-01, 1.7119221e-01],
[ 2.0442654e-01, -1.7638675e+00, 4.3907452e-01, ...,
3.5957956e-01, 6.8090051e-01, 1.1949276e+00],
[ 6.7352581e-01, -7.4204814e-01, 1.9562789e-03, ...,
1.5108454e+00, 7.6459564e-02, 2.6628609e+00]], dtype=float32),
array([[ 1.8355731e+00, 6.3721293e-01, 1.4387861e+00, ...,
1.5183390e+00, 1.3768994e+00, 1.1398090e+00],
[-7.6285309e-01, -1.7151172e-01, 7.7904779e-01, ...,
2.0082670e-03, 5.4300445e-01, -1.7661130e-02],
[-7.1544427e-01, -3.3821759e-01, 6.5137371e-02, ...,
-1.2946208e-01, -1.2803964e-01, -7.0695722e-01],
...,
[ 2.6245685e+00, 1.5668703e+00, -6.7021966e-01, ...,
-2.5492144e-01, 1.6117696e-01, 1.0322076e+00],
[ 2.3771856e+00, 4.4969693e-01, -7.0273966e-01, ...,
-5.8892173e-01, -7.7758688e-01, -9.3576096e-02],
[ 1.3128207e+00, 3.1150520e+00, 1.3393317e+00, ...,
-1.0016419e+00, -9.5276421e-01, -1.1641923e+00]], dtype=float32),
array([[ 0.83290744, -0.8993664 , -0.402725 , ..., 0.10582674,
-0.01235053, -0.09738191],
[ 2.1338842 , 0.01603432, -1.2063591 , ..., -0.3820464 ,
0.76676905, 0.27815175],
[ 2.120046 , -0.16025397, -0.08871422, ..., 0.22135866,
0.65567553, 0.42302525],
...,
[-0.13026482, 0.4517881 , -1.4836109 , ..., 1.1219379 ,
-0.9371764 , -1.1314248 ],
[-1.1179477 , 0.16191007, -1.1466225 , ..., 1.2478981 ,
-0.9128549 , -0.7566994 ],
[-0.7213329 , -0.8711958 , -2.0390506 , ..., 0.07419373,
-0.57677644, 0.50420225]], dtype=float32),
array([[ 0.4103352 , -0.52905226, 1.0153542 , ..., 0.4753324 ,
1.0308077 , -2.0255363 ],
[ 0.7359962 , -0.26603514, 1.2708586 , ..., -0.57504964,
-0.07347727, -0.19985658],
[-0.3538252 , 0.06416875, -0.13405213, ..., -0.18260989,
-0.23987597, -1.255189 ],
...,
[-0.74848014, 0.22666904, -2.6267166 , ..., -0.98463064,
1.7221543 , 0.15214834],
[ 1.2221069 , 0.42036545, -2.3450122 , ..., -0.54053473,
0.7596963 , 0.02417171],
[ 0.8020033 , 0.24751437, -0.73956126, ..., 0.21496078,
0.45159924, -1.020427 ]], dtype=float32),
array([[ 0.3566594 , 0.449696 , 0.8803943 , ..., -0.27954066,
-0.7188793 , 0.03940161],
[-0.8832984 , -1.2276958 , 0.7840812 , ..., 0.22020257,
-0.6587119 , 0.07025829],
[-0.451683 , -0.1536548 , 0.693159 , ..., -0.35344216,
-1.0606401 , -0.5222097 ],
...,
[-0.95997137, 0.12089899, -1.1831 , ..., 0.0210605 ,
0.13389786, 0.1197189 ],
[-0.3723317 , -0.858281 , -0.7813147 , ..., -0.3399848 ,
0.34963062, 0.41852584],
[-0.26307055, 0.21286727, -0.14672194, ..., -0.27064312,
0.23617145, -0.04905495]], dtype=float32),
array([[-1.5423086 , -0.41674405, 0.32278934, ..., -0.48561963,
1.3235972 , 1.0947137 ],
[-0.43097323, -0.3729657 , 0.36833474, ..., -0.24885002,
0.02456255, -0.5059626 ],
[-0.28356156, -0.73586303, 0.81975526, ..., -0.33171654,
-0.3055159 , 0.02063482],
...,
[-0.1102392 , -0.23311119, -2.3949356 , ..., 0.2611473 ,
0.6101144 , 0.00523532],
[ 1.3885255 , 0.7925769 , -1.4495907 , ..., 0.30880964,
-0.771638 , 1.2977543 ],
[ 0.0587432 , 0.7445164 , -0.20883049, ..., -0.18030784,
0.02500429, 0.962118 ]], dtype=float32),
array([[ 0.16396885, 2.1365042 , -1.5938615 , ..., 0.53133374,
0.14221856, 0.7198487 ],
[ 0.7685013 , 0.8800407 , 0.07131885, ..., 0.8842533 ,
0.84760785, 0.64061326],
[-1.2296537 , -0.08120761, 0.4653212 , ..., 1.8077071 ,
1.1730205 , 0.07177108],
...,
[ 0.23427312, -0.8024212 , 0.06597315, ..., -0.77624357,
0.95499444, -0.5483867 ],
[ 0.7133692 , -1.44774 , -0.11900216, ..., -1.850611 ,
0.5746179 , -0.48667806],
[ 0.32128957, -0.8625351 , -0.56147313, ..., -2.078413 ,
-0.3623512 , 0.49528953]], dtype=float32),
array([[ 0.79925615, -0.11046945, 0.26140705, ..., -0.30440444,
1.5764267 , -0.9696423 ],
[-2.305426 , -0.7360753 , -0.69595647, ..., -0.19015577,
0.33076623, 0.05932562],
[ 1.6313475 , 0.47425893, -0.8992906 , ..., 0.1963522 ,
2.2216752 , 1.0445418 ],
...,
[ 0.39278013, -0.39754686, 0.6933964 , ..., 1.5115184 ,
0.10667583, -1.1953909 ],
[-0.9841818 , 0.40281126, 1.6860114 , ..., 3.7408965 ,
0.42164692, -2.186154 ],
[ 1.2675418 , -0.02409442, 0.02182645, ..., 1.2144654 ,
1.3072206 , -0.4438752 ]], dtype=float32),
array([[-1.2192026 , -0.93381715, -0.09254263, ..., 0.14003803,
1.2372396 , 0.6415857 ],
[-1.3408351 , -0.13622981, 0.52433634, ..., -0.71559393,
-0.03018649, -0.0616351 ],
[-2.103941 , -0.9411868 , -0.7861624 , ..., -1.0703001 ,
-0.9839708 , -1.0840874 ],
...,
[-0.84618574, -0.4419959 , -0.8643464 , ..., 0.50056076,
0.63066876, 0.9178843 ],
[-0.63578546, -0.2195098 , -1.3800963 , ..., 1.2053883 ,
1.3478142 , 0.34936693],
[-1.0170684 , 1.1674016 , -0.6548495 , ..., -0.1434365 ,
0.81251854, -1.3575886 ]], dtype=float32),
array([[-0.46007478, 1.2514827 , 0.34222215, ..., 1.1633182 ,
-0.4136564 , -0.15204278],
[-0.41819477, 0.49974155, 0.718653 , ..., 1.402601 ,
-0.19331357, -0.05004539],
[ 1.395996 , -0.4265406 , -0.87285775, ..., 0.27291784,
-0.672799 , -0.05960998],
...,
[-0.20204487, 0.38923043, -0.94756687, ..., -1.3003789 ,
-0.57752943, -0.8651056 ],
[-0.02228898, -0.35452652, -0.13872375, ..., -1.5488387 ,
-0.61668235, -0.4156778 ],
[ 1.0096984 , -1.4712228 , -0.3869995 , ..., -1.5390744 ,
0.28503868, -0.59548926]], dtype=float32),
array([[-1.2980112 , -0.10164279, -0.2849671 , ..., 0.6488607 ,
-0.3992511 , 1.7831595 ],
[-1.0330914 , 0.170592 , -0.44438714, ..., 0.4013004 ,
0.01996766, 0.06398434],
[-0.16662055, 0.35228607, -0.31586906, ..., 1.7044468 ,
0.30152568, -0.77163875],
...,
[-1.352562 , -0.6611807 , 0.6812929 , ..., 0.04395775,
-0.3350484 , 0.69912744],
[-0.38834575, -0.01287032, 0.9136661 , ..., 1.0791212 ,
-0.11421538, 0.48214644],
[-0.5513112 , -1.0736217 , -0.26715022, ..., 0.4242161 ,
0.35782793, 0.8137447 ]], dtype=float32),
array([[-0.738883 , 1.8015068 , -1.7405919 , ..., 1.5863024 ,
0.7974057 , 0.10409734],
[-1.1283479 , 1.7300047 , -0.5917312 , ..., 0.446108 ,
1.8985872 , -0.12919854],
[-0.9350154 , 1.1138222 , -0.48919085, ..., -1.5305504 ,
1.2551781 , 0.90535617],
...,
[-1.3595905 , -1.105739 , 1.0650616 , ..., -0.0595351 ,
-1.3688546 , 0.3558271 ],
[ 0.16096482, -0.22663285, 1.7169338 , ..., -0.69847584,
-1.5337858 , -0.10595929],
[-1.6097369 , -0.5571229 , 1.9413296 , ..., 0.18960929,
-0.25289312, -0.20365423]], dtype=float32),
array([[-1.0076029 , 1.5661017 , 0.4450116 , ..., -0.9414317 ,
1.6603913 , 2.0273383 ],
[-0.47319892, -0.22058228, -0.9042234 , ..., -0.52822495,
1.1003162 , 1.344999 ],
[-1.1902295 , 0.25175503, -0.769976 , ..., -0.4125756 ,
0.5289827 , 1.1931387 ],
...,
[ 1.4232621 , -0.22053331, 0.11088138, ..., 1.2593362 ,
1.068194 , 0.29381454],
[ 1.2264346 , 0.6313647 , 0.07780237, ..., 0.9501781 ,
0.667892 , 0.29651532],
[ 0.66898525, 0.9213892 , -0.23006822, ..., 0.87095577,
0.26903483, 0.27359015]], dtype=float32),
array([[ 0.4313007 , 1.7683698 , 0.8859389 , ..., 0.28485548,
1.6380237 , 0.19567503],
[ 0.934723 , 0.9773703 , 0.7196389 , ..., -0.36407807,
0.42946252, 2.3825448 ],
[-1.3218893 , -0.9236471 , -1.4578354 , ..., 0.19237204,
-1.3374197 , 1.7401958 ],
...,
[ 0.27820146, 0.09531964, 0.23583852, ..., -0.8042503 ,
-0.6348798 , 0.1228614 ],
[ 0.17446674, -0.5994683 , -0.96817595, ..., -0.5530261 ,
-1.1614046 , -0.34990266],
[-0.5196843 , 1.0473673 , -1.3887397 , ..., 0.1331169 ,
-0.7248431 , 0.45765704]], dtype=float32),
array([[-3.8100176 , -0.24919337, 0.6972416 , ..., -0.06895131,
1.8092674 , -0.03361126],
[-2.6298678 , -0.92854005, 1.0099338 , ..., -0.41835254,
0.6103745 , -1.128349 ],
[-0.29317513, 1.0677167 , -0.46259597, ..., -1.1267719 ,
0.03547287, -0.41047814],
...,
[-0.42851517, -0.9347528 , -2.7779844 , ..., 0.84280026,
-0.37186107, -0.08587877],
[-0.9309512 , -1.6179509 , -2.1732867 , ..., 1.1171993 ,
-0.7711189 , -0.1582423 ],
[ 1.079849 , -0.56552327, -1.7008709 , ..., 1.8282412 ,
-0.12107114, -0.12713514]], dtype=float32)],
'condition_labels_str': ['intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'intact',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'word',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest',
'rest']}]
[35]:
# Fit MS-AA for every (K, job) pair and write one .npz file per pair.
#   - Outer loop: each K in K_VALUES. Inner loop: each entry of `jobs` (built in an earlier cell).
#   - With SKIP_EXISTING, a file already at the target path is not refit. Its path is still
#     appended to `saved_paths`, so that list means "outputs available on disk", not
#     "outputs written by this run".
#   - Any Exception raised by fitting or saving is recorded in `failed_jobs` as a dict with
#     keys job_name / K / error_type / error_message. It is re-raised only when STOP_ON_ERROR is set.
# NOTE(review): in the recorded run, K=700 finished after 5 iterations with VarExpl=100.00%
#   and an S shape of (700, 700). K matches that second dimension, so the fit looks saturated;
#   confirm that this K is meaningful.
# NOTE(review): numpy "invalid value encountered in divide" warnings (np.corrcoef internals)
#   are printed during most fits. They presumably come from a correlation of a zero-variance
#   series inside fit_msaa; confirm, since those correlations would be NaN.
saved_paths = []
failed_jobs = []
for K in K_VALUES:
    # Same three lines as before: blank line + rule, heading, rule.
    print("\n".join(["\n" + "#" * 90, f"Beginning all jobs for K={K}", "#" * 90]))
    for job in jobs:
        # Unpacked outside the try so that a malformed job dict still raises immediately.
        job_name, subject_list, condition_labels_str = (
            job["job_name"],
            job["subject_list"],
            job["condition_labels_str"],
        )
        save_path = job_output_path(OUTPUT_DIR, ANALYSIS_TYPE, FIT_SCOPE, job_name, K)

        if SKIP_EXISTING and os.path.exists(save_path):
            print(f"Skipping existing: {save_path}")
        else:
            try:
                fit_res = fit_msaa(subject_list, K, ANALYSIS_TYPE, opts=MSAA_OPTS, verbose=VERBOSE_PROGRESS)
                save_fit_npz(save_path, fit_res, condition_labels_str)
            except Exception as err:
                err_type = type(err).__name__
                failed_jobs.append(dict(job_name=job_name, K=K, error_type=err_type, error_message=str(err)))
                print(f"FAILED: job={job_name}, K={K}, error={err_type}: {err}")
                if STOP_ON_ERROR:
                    raise
                continue  # failed pairs are not listed in saved_paths
            print("Saved:", save_path)
        saved_paths.append(save_path)
saved_paths, failed_jobs
##########################################################################################
Beginning all jobs for K=5
##########################################################################################
[init] Running furthest_sum_gpu (K=5, N=32400)...
[furthest_sum_gpu] 1/5 archetypes initialised...
[furthest_sum_gpu] 2/5 archetypes initialised...
[furthest_sum_gpu] 3/5 archetypes initialised...
[furthest_sum_gpu] 4/5 archetypes initialised...
[furthest_sum_gpu] 5/5 archetypes initialised...
[furthest_sum_gpu] done — 5 archetypes selected.
iter 5 | NLL 2.9469e+07 | dNLL/NLL 2.1336e-03
iter 10 | NLL 2.9388e+07 | dNLL/NLL 1.4441e-04
iter 15 | NLL 2.9379e+07 | dNLL/NLL 3.1185e-05
iter 20 | NLL 2.9376e+07 | dNLL/NLL 1.9040e-05
iter 25 | NLL 2.9372e+07 | dNLL/NLL 3.4658e-05
iter 30 | NLL 2.9366e+07 | dNLL/NLL 4.7626e-05
iter 35 | NLL 2.9357e+07 | dNLL/NLL 8.5224e-05
iter 40 | NLL 2.9327e+07 | dNLL/NLL 2.8534e-04
iter 45 | NLL 2.9298e+07 | dNLL/NLL 1.3304e-04
iter 50 | NLL 2.9286e+07 | dNLL/NLL 5.5081e-05
iter 55 | NLL 2.9281e+07 | dNLL/NLL 2.7892e-05
iter 60 | NLL 2.9278e+07 | dNLL/NLL 1.4000e-05
iter 65 | NLL 2.9277e+07 | dNLL/NLL 5.3783e-06
iter 70 | NLL 2.9277e+07 | dNLL/NLL 5.9368e-06
iter 75 | NLL 2.9276e+07 | dNLL/NLL 3.0602e-06
[SPATIAL AA] K=5 | VarExpl=25.62%
subj0 sXC shape: (300, 5)
subj0 S shape: (5, 700)
Expected: sXC ~ (300, 5), S ~ (5, 700)
Saved: msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K5.npz
##########################################################################################
Beginning all jobs for K=10
##########################################################################################
[init] Running furthest_sum_gpu (K=10, N=32400)...
[furthest_sum_gpu] 1/10 archetypes initialised...
[furthest_sum_gpu] 2/10 archetypes initialised...
[furthest_sum_gpu] 3/10 archetypes initialised...
[furthest_sum_gpu] 4/10 archetypes initialised...
[furthest_sum_gpu] 5/10 archetypes initialised...
[furthest_sum_gpu] 6/10 archetypes initialised...
[furthest_sum_gpu] 7/10 archetypes initialised...
[furthest_sum_gpu] 8/10 archetypes initialised...
[furthest_sum_gpu] 9/10 archetypes initialised...
[furthest_sum_gpu] 10/10 archetypes initialised...
[furthest_sum_gpu] done — 10 archetypes selected.
iter 5 | NLL 2.8969e+07 | dNLL/NLL 4.3932e-03
iter 10 | NLL 2.8675e+07 | dNLL/NLL 6.4268e-04
iter 15 | NLL 2.8646e+07 | dNLL/NLL 1.1012e-04
iter 20 | NLL 2.8638e+07 | dNLL/NLL 2.1642e-05
iter 25 | NLL 2.8637e+07 | dNLL/NLL 2.0959e-06
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2897: RuntimeWarning: invalid value encountered in divide
c /= stddev[:, None]
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2898: RuntimeWarning: invalid value encountered in divide
c /= stddev[None, :]
[SPATIAL AA] K=10 | VarExpl=31.25%
subj0 sXC shape: (300, 10)
subj0 S shape: (10, 700)
Expected: sXC ~ (300, 10), S ~ (10, 700)
Saved: msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K10.npz
##########################################################################################
Beginning all jobs for K=25
##########################################################################################
[init] Running furthest_sum_gpu (K=25, N=32400)...
[furthest_sum_gpu] 2/25 archetypes initialised...
[furthest_sum_gpu] 4/25 archetypes initialised...
[furthest_sum_gpu] 6/25 archetypes initialised...
[furthest_sum_gpu] 8/25 archetypes initialised...
[furthest_sum_gpu] 10/25 archetypes initialised...
[furthest_sum_gpu] 12/25 archetypes initialised...
[furthest_sum_gpu] 14/25 archetypes initialised...
[furthest_sum_gpu] 16/25 archetypes initialised...
[furthest_sum_gpu] 18/25 archetypes initialised...
[furthest_sum_gpu] 20/25 archetypes initialised...
[furthest_sum_gpu] 22/25 archetypes initialised...
[furthest_sum_gpu] 24/25 archetypes initialised...
[furthest_sum_gpu] done — 25 archetypes selected.
iter 5 | NLL 2.8153e+07 | dNLL/NLL 1.6879e-03
iter 10 | NLL 2.7947e+07 | dNLL/NLL 1.5402e-03
iter 15 | NLL 2.7889e+07 | dNLL/NLL 1.2981e-04
iter 20 | NLL 2.7859e+07 | dNLL/NLL 2.7902e-04
iter 25 | NLL 2.7847e+07 | dNLL/NLL 5.3729e-05
iter 30 | NLL 2.7841e+07 | dNLL/NLL 4.3764e-05
iter 35 | NLL 2.7836e+07 | dNLL/NLL 2.3926e-05
iter 40 | NLL 2.7835e+07 | dNLL/NLL 8.3912e-06
iter 45 | NLL 2.7833e+07 | dNLL/NLL 8.7382e-06
iter 50 | NLL 2.7832e+07 | dNLL/NLL 1.0781e-05
iter 55 | NLL 2.7830e+07 | dNLL/NLL 1.8747e-05
iter 60 | NLL 2.7826e+07 | dNLL/NLL 3.1956e-05
iter 65 | NLL 2.7824e+07 | dNLL/NLL 5.9930e-06
iter 70 | NLL 2.7823e+07 | dNLL/NLL 2.8629e-06
iter 75 | NLL 2.7823e+07 | dNLL/NLL 1.5817e-06
iter 80 | NLL 2.7823e+07 | dNLL/NLL 1.1584e-06
iter 85 | NLL 2.7822e+07 | dNLL/NLL 9.9395e-07
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2897: RuntimeWarning: invalid value encountered in divide
c /= stddev[:, None]
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2898: RuntimeWarning: invalid value encountered in divide
c /= stddev[None, :]
[SPATIAL AA] K=25 | VarExpl=38.44%
subj0 sXC shape: (300, 25)
subj0 S shape: (25, 700)
Expected: sXC ~ (300, 25), S ~ (25, 700)
Saved: msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K25.npz
##########################################################################################
Beginning all jobs for K=50
##########################################################################################
[init] Running furthest_sum_gpu (K=50, N=32400)...
[furthest_sum_gpu] 5/50 archetypes initialised...
[furthest_sum_gpu] 10/50 archetypes initialised...
[furthest_sum_gpu] 15/50 archetypes initialised...
[furthest_sum_gpu] 20/50 archetypes initialised...
[furthest_sum_gpu] 25/50 archetypes initialised...
[furthest_sum_gpu] 30/50 archetypes initialised...
[furthest_sum_gpu] 35/50 archetypes initialised...
[furthest_sum_gpu] 40/50 archetypes initialised...
[furthest_sum_gpu] 45/50 archetypes initialised...
[furthest_sum_gpu] 50/50 archetypes initialised...
[furthest_sum_gpu] done — 50 archetypes selected.
iter 5 | NLL 2.7183e+07 | dNLL/NLL 1.9710e-03
iter 10 | NLL 2.7041e+07 | dNLL/NLL 2.2821e-04
iter 15 | NLL 2.7024e+07 | dNLL/NLL 8.5199e-05
iter 20 | NLL 2.7018e+07 | dNLL/NLL 3.4123e-05
iter 25 | NLL 2.7003e+07 | dNLL/NLL 2.3001e-04
iter 30 | NLL 2.6982e+07 | dNLL/NLL 1.7824e-05
iter 35 | NLL 2.6981e+07 | dNLL/NLL 3.5132e-06
iter 40 | NLL 2.6981e+07 | dNLL/NLL 1.8575e-06
iter 45 | NLL 2.6981e+07 | dNLL/NLL 1.1028e-06
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2897: RuntimeWarning: invalid value encountered in divide
c /= stddev[:, None]
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2898: RuntimeWarning: invalid value encountered in divide
c /= stddev[None, :]
[SPATIAL AA] K=50 | VarExpl=45.86%
subj0 sXC shape: (300, 50)
subj0 S shape: (50, 700)
Expected: sXC ~ (300, 50), S ~ (50, 700)
Saved: msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K50.npz
##########################################################################################
Beginning all jobs for K=75
##########################################################################################
[init] Running furthest_sum_gpu (K=75, N=32400)...
[furthest_sum_gpu] 7/75 archetypes initialised...
[furthest_sum_gpu] 14/75 archetypes initialised...
[furthest_sum_gpu] 21/75 archetypes initialised...
[furthest_sum_gpu] 28/75 archetypes initialised...
[furthest_sum_gpu] 35/75 archetypes initialised...
[furthest_sum_gpu] 42/75 archetypes initialised...
[furthest_sum_gpu] 49/75 archetypes initialised...
[furthest_sum_gpu] 56/75 archetypes initialised...
[furthest_sum_gpu] 63/75 archetypes initialised...
[furthest_sum_gpu] 70/75 archetypes initialised...
[furthest_sum_gpu] done — 75 archetypes selected.
iter 5 | NLL 2.6581e+07 | dNLL/NLL 1.1162e-03
iter 10 | NLL 2.6489e+07 | dNLL/NLL 5.5280e-04
iter 15 | NLL 2.6396e+07 | dNLL/NLL 1.8310e-04
iter 20 | NLL 2.6386e+07 | dNLL/NLL 5.9075e-05
iter 25 | NLL 2.6382e+07 | dNLL/NLL 2.3319e-05
iter 30 | NLL 2.6380e+07 | dNLL/NLL 4.8399e-06
iter 35 | NLL 2.6380e+07 | dNLL/NLL 2.2926e-06
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2897: RuntimeWarning: invalid value encountered in divide
c /= stddev[:, None]
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2898: RuntimeWarning: invalid value encountered in divide
c /= stddev[None, :]
[SPATIAL AA] K=75 | VarExpl=51.16%
subj0 sXC shape: (300, 75)
subj0 S shape: (75, 700)
Expected: sXC ~ (300, 75), S ~ (75, 700)
Saved: msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K75.npz
##########################################################################################
Beginning all jobs for K=100
##########################################################################################
[init] Running furthest_sum_gpu (K=100, N=32400)...
[furthest_sum_gpu] 10/100 archetypes initialised...
[furthest_sum_gpu] 20/100 archetypes initialised...
[furthest_sum_gpu] 30/100 archetypes initialised...
[furthest_sum_gpu] 40/100 archetypes initialised...
[furthest_sum_gpu] 50/100 archetypes initialised...
[furthest_sum_gpu] 60/100 archetypes initialised...
[furthest_sum_gpu] 70/100 archetypes initialised...
[furthest_sum_gpu] 80/100 archetypes initialised...
[furthest_sum_gpu] 90/100 archetypes initialised...
[furthest_sum_gpu] 100/100 archetypes initialised...
[furthest_sum_gpu] done — 100 archetypes selected.
iter 5 | NLL 2.6073e+07 | dNLL/NLL 5.3281e-04
iter 10 | NLL 2.5991e+07 | dNLL/NLL 6.4110e-04
iter 15 | NLL 2.5969e+07 | dNLL/NLL 7.4884e-05
iter 20 | NLL 2.5958e+07 | dNLL/NLL 9.9424e-05
iter 25 | NLL 2.5940e+07 | dNLL/NLL 1.3199e-04
iter 30 | NLL 2.5929e+07 | dNLL/NLL 1.2131e-05
iter 35 | NLL 2.5929e+07 | dNLL/NLL 3.0853e-06
iter 40 | NLL 2.5928e+07 | dNLL/NLL 2.4819e-06
iter 45 | NLL 2.5928e+07 | dNLL/NLL 2.6737e-06
iter 50 | NLL 2.5927e+07 | dNLL/NLL 4.9904e-06
iter 55 | NLL 2.5926e+07 | dNLL/NLL 1.2397e-05
iter 60 | NLL 2.5924e+07 | dNLL/NLL 2.0714e-05
iter 65 | NLL 2.5921e+07 | dNLL/NLL 2.6400e-05
iter 70 | NLL 2.5917e+07 | dNLL/NLL 3.0896e-05
iter 75 | NLL 2.5915e+07 | dNLL/NLL 7.9878e-06
iter 80 | NLL 2.5914e+07 | dNLL/NLL 8.3688e-06
iter 85 | NLL 2.5912e+07 | dNLL/NLL 1.2625e-05
iter 90 | NLL 2.5911e+07 | dNLL/NLL 1.1413e-05
iter 95 | NLL 2.5910e+07 | dNLL/NLL 5.4563e-06
iter 100 | NLL 2.5909e+07 | dNLL/NLL 2.5105e-06
iter 105 | NLL 2.5909e+07 | dNLL/NLL 2.0023e-06
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2897: RuntimeWarning: invalid value encountered in divide
c /= stddev[:, None]
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2898: RuntimeWarning: invalid value encountered in divide
c /= stddev[None, :]
[SPATIAL AA] K=100 | VarExpl=55.31%
subj0 sXC shape: (300, 100)
subj0 S shape: (100, 700)
Expected: sXC ~ (300, 100), S ~ (100, 700)
Saved: msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K100.npz
##########################################################################################
Beginning all jobs for K=300
##########################################################################################
[init] Running furthest_sum_gpu (K=300, N=32400)...
[furthest_sum_gpu] 30/300 archetypes initialised...
[furthest_sum_gpu] 60/300 archetypes initialised...
[furthest_sum_gpu] 90/300 archetypes initialised...
[furthest_sum_gpu] 120/300 archetypes initialised...
[furthest_sum_gpu] 150/300 archetypes initialised...
[furthest_sum_gpu] 180/300 archetypes initialised...
[furthest_sum_gpu] 210/300 archetypes initialised...
[furthest_sum_gpu] 240/300 archetypes initialised...
[furthest_sum_gpu] 270/300 archetypes initialised...
[furthest_sum_gpu] 300/300 archetypes initialised...
[furthest_sum_gpu] done — 300 archetypes selected.
iter 5 | NLL 2.3512e+07 | dNLL/NLL 1.2102e-05
iter 10 | NLL 2.3512e+07 | dNLL/NLL 9.8619e-07
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2897: RuntimeWarning: invalid value encountered in divide
c /= stddev[:, None]
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2898: RuntimeWarning: invalid value encountered in divide
c /= stddev[None, :]
[SPATIAL AA] K=300 | VarExpl=76.45%
subj0 sXC shape: (300, 300)
subj0 S shape: (300, 700)
Expected: sXC ~ (300, 300), S ~ (300, 700)
Saved: msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K300.npz
##########################################################################################
Beginning all jobs for K=500
##########################################################################################
[init] Running furthest_sum_gpu (K=500, N=32400)...
[furthest_sum_gpu] 50/500 archetypes initialised...
[furthest_sum_gpu] 100/500 archetypes initialised...
[furthest_sum_gpu] 150/500 archetypes initialised...
[furthest_sum_gpu] 200/500 archetypes initialised...
[furthest_sum_gpu] 250/500 archetypes initialised...
[furthest_sum_gpu] 300/500 archetypes initialised...
[furthest_sum_gpu] 350/500 archetypes initialised...
[furthest_sum_gpu] 400/500 archetypes initialised...
[furthest_sum_gpu] 450/500 archetypes initialised...
[furthest_sum_gpu] 500/500 archetypes initialised...
[furthest_sum_gpu] done — 500 archetypes selected.
iter 5 | NLL 2.1964e+07 | dNLL/NLL 2.2120e-06
iter 10 | NLL 2.1964e+07 | dNLL/NLL 1.0518e-06
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2897: RuntimeWarning: invalid value encountered in divide
c /= stddev[:, None]
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2898: RuntimeWarning: invalid value encountered in divide
c /= stddev[None, :]
[SPATIAL AA] K=500 | VarExpl=90.10%
subj0 sXC shape: (300, 500)
subj0 S shape: (500, 700)
Expected: sXC ~ (300, 500), S ~ (500, 700)
Saved: msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K500.npz
##########################################################################################
Beginning all jobs for K=700
##########################################################################################
[init] Running furthest_sum_gpu (K=700, N=32400)...
[furthest_sum_gpu] 70/700 archetypes initialised...
[furthest_sum_gpu] 140/700 archetypes initialised...
[furthest_sum_gpu] 210/700 archetypes initialised...
[furthest_sum_gpu] 280/700 archetypes initialised...
[furthest_sum_gpu] 350/700 archetypes initialised...
[furthest_sum_gpu] 420/700 archetypes initialised...
[furthest_sum_gpu] 490/700 archetypes initialised...
[furthest_sum_gpu] 560/700 archetypes initialised...
[furthest_sum_gpu] 630/700 archetypes initialised...
[furthest_sum_gpu] 700/700 archetypes initialised...
[furthest_sum_gpu] done — 700 archetypes selected.
iter 5 | NLL 2.0842e+07 | dNLL/NLL 3.3284e-08
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2897: RuntimeWarning: invalid value encountered in divide
c /= stddev[:, None]
/home/alex/miniforge3/envs/jupyter/lib/python3.11/site-packages/numpy/lib/function_base.py:2898: RuntimeWarning: invalid value encountered in divide
c /= stddev[None, :]
[SPATIAL AA] K=700 | VarExpl=100.00%
subj0 sXC shape: (300, 700)
subj0 S shape: (700, 700)
Expected: sXC ~ (300, 700), S ~ (700, 700)
Saved: msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K700.npz
[35]:
(['msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K5.npz',
'msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K10.npz',
'msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K25.npz',
'msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K50.npz',
'msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K75.npz',
'msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K100.npz',
'msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K300.npz',
'msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K500.npz',
'msaa_flexible_outputs_npz/spatialAA_across_acrossCond_K700.npz'],
[])
[ ]: