
tsoj/fmri_attention_connectivity


Computing directed connectivity by predicting fMRI signals using attention-based artificial neural networks

In this work, I explore a method to extract directed connectivity information from resting-state fMRI signals. As an alternative to correlation- and Granger causality-based approaches, the method uses an artificial neural network (ANN) that combines dense fully connected layers with attention mechanisms. The ANN is trained to predict future fMRI signals, and the attention matrix computed during inference is used to describe the connectivity between regions. The resulting attention connectivity matrices perform competitively with Granger causality and with Pearson and partial correlation, both on fingerprinting and on predicting individual behavior.
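The README does not spell out the architecture, so the following is only a minimal NumPy sketch of the general idea under stated assumptions: each region is treated as a token whose features are its own recent samples, a single attention layer with plain linear projections stands in for the dense layers, and the weights are random rather than trained. The names `attention_step` and `mean_attention` are hypothetical; averaging the attention over sliding windows is an assumption suggested by the `mean_attention` matrix name used later in this README. Whether rows or columns correspond to source regions in the repository is not stated here.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the max before exponentiating for numerical stability
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def attention_step(window, w_q, w_k, w_v, w_out):
    """Hypothetical forward pass for one window of shape (w, R): w past samples of R regions.

    Each region is one token whose features are its own w past samples.
    Returns the predicted next sample per region, shape (R,), and the
    (R, R) attention matrix whose rows sum to 1.
    """
    tokens = window.T                                   # (R, w): one token per region
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v  # each (R, d)
    # Row i holds the weights region i puts on every region j
    attn = softmax(q @ k.T / np.sqrt(q.shape[1]), axis=1)
    pred = (attn @ v) @ w_out                           # (R, 1): next value per region
    return pred[:, 0], attn


def mean_attention(ts, window_len, params):
    """Average the attention matrix over all sliding windows of a (T, R) series."""
    mats = [attention_step(ts[t:t + window_len], *params)[1]
            for t in range(ts.shape[0] - window_len)]
    return np.mean(mats, axis=0)


# Example with random, untrained weights (shapes only; the values are meaningless)
rng = np.random.default_rng(0)
T, R, w, d = 200, 10, 8, 16
ts = rng.standard_normal((T, R))
params = tuple(rng.standard_normal(s) for s in [(w, d), (w, d), (w, d), (d, 1)])
conn = mean_attention(ts, w, params)
print(conn.shape)  # (10, 10)
```

In a real model the projections would be trained so that `pred` matches the sample following each window; only then does the averaged attention carry connectivity information.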

Paper

Setup

conda create --name fmri python=3.12
conda activate fmri
pip install torch --index-url https://download.pytorch.org/whl/rocm6.4 # use the index URL matching your hardware (e.g. a CUDA, ROCm, or CPU build)
pip install numpy nibabel nilearn pydantic tqdm matplotlib mne mne-connectivity seaborn

Train

# Single group level model
python src/group_level_training.py config/group_train_config.json

# One model per individual
python src/subject_level_training.py config/subject_train_config.json
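Both training modes fit a model that predicts future signals from past ones. The configuration files are not shown here, so the sketch below only illustrates one plausible way to build training pairs: sliding input windows with the following sample as target, per-region z-scoring, and pooling windows across subjects for a group-level model. Window length, forecast horizon, normalization, and the helper names (`zscore`, `make_windows`, `group_dataset`) are assumptions, not the repository's settings.

```python
import numpy as np


def zscore(ts):
    """Standardize each region (column) of a (T, R) series to zero mean, unit variance."""
    std = ts.std(axis=0, keepdims=True)
    return (ts - ts.mean(axis=0, keepdims=True)) / np.where(std == 0, 1.0, std)


def make_windows(ts, window_len, horizon=1):
    """Split a (T, R) series into inputs (N, window_len, R) and targets (N, R).

    The target for a window is the sample `horizon` steps after its last sample.
    """
    n = ts.shape[0] - window_len - horizon + 1
    if n <= 0:
        raise ValueError("series too short for this window_len and horizon")
    inputs = np.stack([ts[i:i + window_len] for i in range(n)])
    start = window_len + horizon - 1
    targets = ts[start:start + n]
    return inputs, targets


def group_dataset(subject_series, window_len):
    """Pool windows from all subjects (group level); a subject-level model uses one subject."""
    pairs = [make_windows(zscore(ts), window_len) for ts in subject_series]
    return (np.concatenate([x for x, _ in pairs]),
            np.concatenate([y for _, y in pairs]))
```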

Extract connectivity matrices

# For baselines (Granger, Pearson, partial)
python src/evaluate_baselines.py \
       csv/subject_groups.csv \
       /path/to/HCP_Young_Adult_2025/data/ \
       atlas/Schaefer2018_100Parcels_17Networks_order.dlabel.nii \
       results/
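The three baselines have standard textbook definitions. The sketch below assumes one (time × regions) array per subject and computes Pearson correlation, partial correlation from the inverse covariance, and pairwise time-domain Granger causality with ordinary least squares. The lag order, the log-variance-ratio form, and the function names are assumptions; the dependency list includes mne-connectivity, so `src/evaluate_baselines.py` may compute Granger causality differently.

```python
import numpy as np


def pearson_connectivity(ts):
    """Region-by-region Pearson correlation of a (T, R) time series."""
    return np.corrcoef(ts, rowvar=False)


def partial_correlation(ts):
    """Partial correlation from the (pseudo-)inverse of the covariance matrix."""
    prec = np.linalg.pinv(np.cov(ts, rowvar=False))
    scale = np.sqrt(np.diag(prec))
    pcorr = -prec / np.outer(scale, scale)
    np.fill_diagonal(pcorr, 1.0)
    return pcorr


def _lags(x, order):
    """Design matrix with columns x(t-1), ..., x(t-order) for t = order..T-1."""
    T = x.shape[0]
    return np.column_stack([x[order - lag:T - lag] for lag in range(1, order + 1)])


def _residual_var(y, X):
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    return np.var(y - X @ beta)


def granger_causality(ts, order=1):
    """Pairwise time-domain Granger causality; gc[target, source] = log(var_r / var_f)."""
    T, R = ts.shape
    gc = np.zeros((R, R))
    for tgt in range(R):
        y = ts[order:, tgt]
        # Restricted model: intercept + the target's own past
        X_r = np.column_stack([np.ones(T - order), _lags(ts[:, tgt], order)])
        var_r = _residual_var(y, X_r)
        for src in range(R):
            if src != tgt:
                # Full model additionally uses the source's past
                X_f = np.column_stack([X_r, _lags(ts[:, src], order)])
                gc[tgt, src] = np.log(var_r / _residual_var(y, X_f))
    return gc


# Example: x drives y with a one-sample lag, so gc[1, 0] should exceed gc[0, 1]
rng = np.random.default_rng(0)
x = rng.standard_normal(1000)
y = np.zeros(1000)
y[1:] = 0.8 * x[:-1] + 0.2 * rng.standard_normal(999)
demo = np.column_stack([x, y])
gc = granger_causality(demo, order=1)
print(gc[1, 0] > gc[0, 1])  # True
```

Unlike the two correlation measures, the Granger matrix is not symmetric, which is what makes it a directed baseline for the attention matrices.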

# For group-level trained models
python src/evaluate_models.py \
       csv/subject_groups.csv \
       /path/to/HCP_Young_Adult_2025/data/ \
       atlas/Schaefer2018_100Parcels_17Networks_order.dlabel.nii \
       /path/to/group/level/model/ \
       results/ \
       --model-type group-level

# For subject-level trained models
python src/evaluate_models.py \
       csv/subject_groups.csv \
       /path/to/HCP_Young_Adult_2025/data/ \
       atlas/Schaefer2018_100Parcels_17Networks_order.dlabel.nii \
       /path/to/subject/level/model/ \
       results/ \
       --model-type subject-level
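All three commands receive the Schaefer 2018 100-parcel atlas, which suggests that grayordinate-wise HCP signals are reduced to one time series per parcel. A common way to do this is to average all grayordinates that share a label, sketched below under two assumptions: the data and label arrays are already aligned to the same grayordinate order (which must be checked when loading CIFTI files with nibabel), and label 0 marks unassigned grayordinates. `parcellate` is a hypothetical helper, not a function from the repository.

```python
import numpy as np


def parcellate(data, labels, n_parcels):
    """Average grayordinate time series within each atlas label.

    data: (T, V) array, one column per grayordinate.
    labels: (V,) integer array, 0 = unassigned, 1..n_parcels = parcel index.
    Returns a (T, n_parcels) parcel time series; empty parcels stay zero.
    """
    out = np.zeros((data.shape[0], n_parcels))
    for p in range(1, n_parcels + 1):
        mask = labels == p
        if mask.any():
            out[:, p - 1] = data[:, mask].mean(axis=1)
    return out


# Toy example: 2 time points, 6 grayordinates, 2 parcels
data = np.arange(12, dtype=float).reshape(2, 6)
labels = np.array([0, 1, 1, 2, 2, 2])
print(parcellate(data, labels, 2))
```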

Run behavioral benchmark

You need to download the CSV file with the behavioral data of the HCP-YA S1200 release, including all unrestricted columns. In the commands below, this file is named csv/HCP_YA_subjects_2025_09_12_01_57_30.csv.

# connectivity_matrix_pattern can be mean_attention, pearson_correlation, partial_correlation, or granger_causality
python src/behavioral_connectivity.py \
       csv/subject_groups.csv \
       /path/to/connectivity/matrices/ \
       connectivity_matrix_pattern.npy \
       csv/HCP_YA_subjects_2025_09_12_01_57_30.csv \
       csv/behaviour_targets.csv \
       results/
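The README does not describe the prediction model used by `src/behavioral_connectivity.py`. As a hedged illustration of how connectivity matrices can be related to behavior, the sketch below flattens each subject's matrix into a feature vector (keeping all off-diagonal entries for directed matrices such as attention or Granger causality), fits ridge regression with k-fold cross-validation, and scores the out-of-fold predictions with Pearson correlation. The regularization strength, fold count, scoring, and all function names are assumptions, and the data are synthetic.

```python
import numpy as np


def vectorize(mat, directed=True):
    """Flatten a connectivity matrix: all off-diagonal entries if directed, else the upper triangle."""
    n = mat.shape[0]
    if directed:
        mask = ~np.eye(n, dtype=bool)
    else:
        mask = np.triu(np.ones((n, n), dtype=bool), k=1)
    return mat[mask]


def ridge_cv_predict(X, y, alpha=1.0, n_folds=5, seed=0):
    """Out-of-fold ridge predictions of a behavioral target y from features X."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(y))
    pred = np.empty(len(y), dtype=float)
    for test in np.array_split(idx, n_folds):
        train = np.setdiff1d(idx, test)
        mu_x, mu_y = X[train].mean(axis=0), y[train].mean()
        Xc = X[train] - mu_x
        # Closed-form ridge on centered data: w = (X^T X + alpha I)^-1 X^T y
        w = np.linalg.solve(Xc.T @ Xc + alpha * np.eye(X.shape[1]), Xc.T @ (y[train] - mu_y))
        pred[test] = (X[test] - mu_x) @ w + mu_y
    return pred


def prediction_score(y, pred):
    """Pearson correlation between observed and predicted behavior."""
    return np.corrcoef(y, pred)[0, 1]


# Synthetic example: 100 "subjects" with 6 x 6 directed matrices
rng = np.random.default_rng(1)
mats = rng.standard_normal((100, 6, 6))
X = np.stack([vectorize(m) for m in mats])  # (100, 30)
y = X @ rng.standard_normal(X.shape[1]) + 0.1 * rng.standard_normal(100)
print(round(prediction_score(y, ridge_cv_predict(X, y)), 2))
```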

Run fingerprinting benchmark

# connectivity_matrix_pattern can be mean_attention, pearson_correlation, partial_correlation, or granger_causality
# /path/to/test/set/matrices/ and /path/to/retest/set/matrices/ can be the same path
python src/fingerprint_connectivity.py \
       csv/subject_ids_test-retest.csv \
       connectivity_matrix_pattern.npy \
       /path/to/test/set/matrices/ \
       /path/to/retest/set/matrices/ \
       results/ \
       --method pearson --z-normalize --comparison-mode test_to_retest
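Fingerprinting asks whether a subject can be identified among others from their connectivity matrix alone. A minimal sketch, assuming each matrix has already been flattened into a feature vector and rows are ordered by subject in both sessions: for every test-session vector, find the retest-session vector with the highest Pearson correlation and count a hit when it belongs to the same subject. This mirrors `--method pearson` and `--comparison-mode test_to_retest` only loosely; the exact effect of `--z-normalize` is not documented here, so it is left out. The function name and the synthetic data are hypothetical.

```python
import numpy as np


def identification_accuracy(test_feats, retest_feats):
    """Fraction of subjects whose test vector best matches their own retest vector.

    test_feats, retest_feats: (N, F) arrays; row k belongs to subject k in both.
    """
    n = test_feats.shape[0]
    # Pearson correlation between every test row (rows 0..n-1) and every retest row (n..2n-1)
    sim = np.corrcoef(test_feats, retest_feats)[:n, n:]
    hits = np.argmax(sim, axis=1) == np.arange(n)
    return hits.mean()


# Synthetic example: a stable subject-specific pattern plus session noise
rng = np.random.default_rng(2)
patterns = rng.standard_normal((20, 50))
session1 = patterns + 0.3 * rng.standard_normal((20, 50))
session2 = patterns + 0.3 * rng.standard_normal((20, 50))
print(identification_accuracy(session1, session2))
```

Reversing the subject order of one session is a quick sanity check: the accuracy should then drop to chance or below.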

Visualize

Scripts to visualize the results are:

  • src/compare_behavioral_results.py
  • src/compare_fingerprint_results.py
  • src/compare_prediction_r2_distributions.py
  • src/plot_training_curves.py
  • src/visualize_connectivity.py
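The name `src/compare_prediction_r2_distributions.py` suggests that prediction quality is summarized as distributions of R² values. Assuming the standard coefficient of determination computed per region over held-out samples, the sketch below matches scikit-learn's `r2_score(..., multioutput="raw_values")`. It also adds a hypothetical persistence baseline that repeats the last sample of each input window, which can help put the scores in context; neither helper comes from the repository.

```python
import numpy as np


def r2_per_region(y_true, y_pred):
    """Coefficient of determination for each region (column) of (N, R) arrays."""
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    return 1.0 - ss_res / ss_tot


def persistence_baseline(inputs):
    """Naive forecast: repeat the last sample of each (window_len, R) input window."""
    return inputs[:, -1, :]


# Toy check: a perfect prediction scores 1, predicting the column mean scores 0
y_obs = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 7.0]])
print(r2_per_region(y_obs, y_obs))
print(r2_per_region(y_obs, np.tile(y_obs.mean(axis=0), (3, 1))))
```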

About

Using attention-based neural networks to compute functional connectivity from fMRI data
