Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions .github/workflows/update-notebooks.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Regenerates the Colab notebooks whenever a numbered example script or the
# notebook generator itself changes on main, then commits the result back.
name: Update Colab notebooks

on:
  push:
    branches: [main]
    paths:
      - "examples/[0-9][0-9]_*.py"
      - "docs/make_notebooks.py"

jobs:
  update-notebooks:
    runs-on: ubuntu-latest
    # contents: write is required so this job can push the regenerated
    # notebooks back to the repository.
    permissions:
      contents: write

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: "3.12"

      - name: Install jupytext
        run: pip install jupytext

      - name: Regenerate notebooks
        run: python docs/make_notebooks.py

      - name: Commit updated notebooks
        run: |
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"
          git add docs/auto_examples/*.ipynb
          # Commit only when the regenerated notebooks actually changed;
          # [skip ci] keeps this push from re-triggering workflows.
          git diff --staged --quiet || git commit -m "auto: regenerate Colab notebooks from .py examples [skip ci]"
          git push
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Local pre-commit hook that keeps the Colab notebooks in sync with the
# numbered example scripts (examples/NN_*.py).
repos:
  - repo: local
    hooks:
      - id: update-colab-notebooks
        name: Regenerate Colab notebooks
        language: python
        # jupytext converts the .py example scripts to .ipynb notebooks.
        additional_dependencies: [jupytext]
        entry: python docs/make_notebooks.py --stage --examples
        files: ^examples/\d{2}_.*\.py$
        # Forward the changed example files so only the matching
        # notebooks are regenerated.
        pass_filenames: true
23 changes: 23 additions & 0 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Read the Docs build configuration (schema version 2).
version: 2

build:
  os: ubuntu-22.04
  tools:
    python: "3.12"
  # RTD does not execute sphinx-gallery examples (GPU-dependent, slow).
  # Set BRAINDEC_BUILD_GALLERY=1 locally to build with executed outputs.
  jobs:
    post_install:
      # Sanity check: the package must be importable after installation.
      - python -c "from braindec._version import __version__; print(__version__)"

sphinx:
  configuration: docs/conf.py
  fail_on_warning: false

python:
  install:
    - method: pip
      path: .
      # Extras required to build the docs and render example plots.
      extra_requirements:
        - doc
        - plotting
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ The trained baseline models used in the paper can be downloaded from the OSF repo

Alternatively, you can use the pre-trained model provided in the `./results/pubmed` directory in https://osf.io/dsj56/.

### Download published assets

The package now includes an OSF downloader for the assets documented in this README. It can download individual files, predefined bundles, or whole published folders while recreating the OSF directory layout under a destination root.

```bash
# List built-in assets and bundles
python -m braindec.fetcher --list

# Download the example prediction bundle into the current repository
python -m braindec.fetcher --bundle example_prediction --destination_root .

# Download the published pretrained results and baseline folders
python -m braindec.fetcher --bundle paper_results --destination_root .

# Download a specific published folder from the OSF project
python -m braindec.fetcher --folder data/cognitive_atlas --destination_root .
```

### Predictions

To perform predictions using the trained model, you can use the [predict.py](./braindec/predict.py) script.
Expand Down
15 changes: 12 additions & 3 deletions braindec/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
"""Braindec: Brain image decoder."""

from . import dataset, embedding, loss, model, plot, train, utils # predict
from importlib import import_module

__all__ = [
"model",
"dataset",
"loss",
"embedding",
"fetcher",
"loss",
"model",
"plot",
"train",
"utils",
]


def __getattr__(name):
if name in __all__:
module = import_module(f".{name}", __name__)
globals()[name] = module
return module
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
73 changes: 55 additions & 18 deletions braindec/cogatlas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import os.path as op
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import pandas as pd
import requests
from nimare import extract

COGATLAS_URLS = {
"task": "https://www.cognitiveatlas.org/api/v-alpha/task",
Expand Down Expand Up @@ -56,6 +57,43 @@ def _get_concepts_to_tasks(relationships_df, concept_to_task=None):
return concepts_to_tasks_df


def _fetch_full_task_concepts(task_ids, cache_fn, concept_to_task=None, max_workers=16):
if cache_fn is not None and op.exists(cache_fn):
concepts_to_tasks_df = pd.read_csv(cache_fn)
else:
base_url = "https://www.cognitiveatlas.org/api/v-alpha/task"

def _fetch_one(task_id):
response = requests.get(base_url, params={"id": task_id}, timeout=30)
response.raise_for_status()
task_json = response.json()
rows = []
for concept in task_json.get("concepts", []):
concept_id = concept.get("concept_id")
if concept_id:
rows.append({"id": concept_id, "measuredBy": task_id})
return rows

rows = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
for task_rows in executor.map(_fetch_one, task_ids):
rows.extend(task_rows)
concepts_to_tasks_df = pd.DataFrame(rows).drop_duplicates()
if cache_fn is not None:
concepts_to_tasks_df.to_csv(cache_fn, index=False)

if concept_to_task is not None:
extra_df = pd.DataFrame(
{
"id": list(concept_to_task.keys()),
"measuredBy": list(concept_to_task.values()),
}
)
concepts_to_tasks_df = pd.concat([concepts_to_tasks_df, extra_df], ignore_index=True)
concepts_to_tasks_df = concepts_to_tasks_df.drop_duplicates()
return concepts_to_tasks_df


class CognitiveAtlas:
def __init__(
self,
Expand Down Expand Up @@ -131,11 +169,12 @@ def __init__(
if reduced_tasks is not None:
concepts_to_tasks_df = self._get_concepts_to_tasks_red(reduced_tasks)
else:
cogatlas = extract.download_cognitive_atlas(data_dir=data_dir, overwrite=False)
relationships_df = pd.read_csv(cogatlas["relationships"])

concepts_to_tasks_df = _get_concepts_to_tasks(
relationships_df,
cache_fn = None
if data_dir is not None:
cache_fn = op.join(data_dir, "cognitive_atlas", "full_task_concepts.csv")
concepts_to_tasks_df = _fetch_full_task_concepts(
self.task_df["id"].tolist(),
cache_fn=cache_fn,
concept_to_task=concept_to_task,
)

Expand All @@ -148,14 +187,14 @@ def __init__(
continue

sel_tasks = sel_df["measuredBy"].values
indices = np.where(np.in1d(self.task_df["id"].values, sel_tasks))[0]
indices = np.where(np.isin(self.task_df["id"].values, sel_tasks))[0]

self.concept_to_task_idxs.append(indices)

self.process_to_concept_idxs = []
for process in self.process_names:
sel_df = self.concept_df.loc[self.concept_df["cognitive_process"] == process]
indices = np.where(np.in1d(self.concept_df["id"].values, sel_df["id"].values))[0]
indices = np.where(np.isin(self.concept_df["id"].values, sel_df["id"].values))[0]

self.process_to_concept_idxs.append(indices)

Expand All @@ -167,7 +206,7 @@ def __init__(
continue

sel_concepts = sel_df["id"].values
indices = np.where(np.in1d(self.concept_df["id"].values, sel_concepts))[0]
indices = np.where(np.isin(self.concept_df["id"].values, sel_concepts))[0]

self.task_to_concept_idxs.append(indices)

Expand All @@ -176,11 +215,9 @@ def __init__(
sel_concepts = concepts_to_tasks_df.loc[concepts_to_tasks_df["measuredBy"] == task][
"id"
].values
if task == "trm_550b54a8b30f4":
print(sel_concepts)

sel_df = self.concept_df.loc[self.concept_df["id"].isin(sel_concepts)]
indices = np.where(np.in1d(self.process_ids, sel_df["id_concept_class"].values))[0]
indices = np.where(np.isin(self.process_ids, sel_df["id_concept_class"].values))[0]

self.task_to_process_idxs.append(indices)

Expand All @@ -198,24 +235,24 @@ def get_task_id_from_name(self, names):

def get_task_idx_from_names(self, names):
    """Map a task name (or list of names) to its index in ``self.task_names``."""

    def _locate(task_name):
        # First position where the name matches; IndexError if it is absent.
        return np.where(np.isin(self.task_names, task_name))[0][0]

    if isinstance(names, str):
        return _locate(names)
    return [_locate(task_name) for task_name in names]

def get_concept_idx_from_names(self, names):
    """Map a concept name (or list of names) to its index in ``self.concept_names``."""

    def _locate(concept_name):
        # First position where the name matches; IndexError if it is absent.
        return np.where(np.isin(self.concept_names, concept_name))[0][0]

    if isinstance(names, str):
        return _locate(names)
    return [_locate(concept_name) for concept_name in names]

def get_process_idx_from_names(self, names):
    """Map a process name (or list of names) to its index in ``self.process_names``."""

    def _locate(process_name):
        # First position where the name matches; IndexError if it is absent.
        return np.where(np.isin(self.process_names, process_name))[0][0]

    if isinstance(names, str):
        return _locate(names)
    return [_locate(process_name) for process_name in names]

def get_task_names_from_idx(self, task_idx):
Expand Down
Loading