Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,35 @@ Use the following to install the required libraries in a virtual environment of

`pip install -r requirements.txt`

### Running test.ipynb

Before running `test.ipynb`, ensure you have the required pretrained weights and grid artifacts:

**Required Files:**

1. **Pretrained Model Weights:**
- `vdl_weights.pt` - Pretrained model weights (must be in the root directory)
- **How to obtain:** Train the model using `python train.py --exp-name vdl`
- Note: Training requires the dataset and may take significant time. The weights will be saved as `vdl_weights.pt` in the root directory upon completion.

2. **Grid Matrices** (should be in `grid_matrices/` directory):
- `scatter_to_log_128.pt`
- `forward_from_log_128.pt`
- `scatter_from_log_128.pt`
- `sparse_grid_fracs_euclid_backward.pt`
- **How to obtain:** These files are included in the repository's `grid_matrices/` directory. If they are missing, you can regenerate them by running the creation notebooks:
- `create_log_grid.ipynb` (generates `scatter_to_log_128.pt`)
- `create_log_forward_grid.ipynb` (generates `forward_from_log_128.pt`)
- `create_log_crop_grid.ipynb` (generates `scatter_from_log_128.pt`)
- `create_backward_grid.ipynb` (generates `sparse_grid_fracs_euclid_backward.pt`)

**Quick Start:**
1. Install dependencies: `pip install -r requirements.txt`
2. Ensure all required files are present (the notebook includes a setup cell that checks this automatically)
3. Run `test.ipynb` cells in order

The `test.ipynb` notebook includes a setup cell that automatically checks for all required files and provides clear error messages with instructions if any files are missing.

The implementation can be found in Section 4 if you want to skip directly to it.

Observed galaxy images of the [Galaxy10 DECaLS Dataset](https://astronn.readthedocs.io/en/latest/galaxy10.html) are used in the analyses of this work.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "202d13a2",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import torch.nn.functional as F"
"import torch.nn.functional as F\n",
"import os"
]
},
{
Expand Down Expand Up @@ -42,9 +43,94 @@
"target_shape = 128"
]
},
{
"cell_type": "markdown",
"id": "24e60086",
"metadata": {},
"source": [
"## Setup: Check for Required Files\n",
"\n",
"Before running the test, ensure that all required pretrained weights and grid artifacts are available. The setup cell below will check for all required files and provide clear instructions if any are missing.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19333944",
"metadata": {},
"outputs": [],
"source": [
"# Setup: Check for required files\n",
"GRID_MATRICES_DIR = 'grid_matrices'\n",
"\n",
"# Required files\n",
"REQUIRED_WEIGHTS_FILE = 'vdl_weights.pt'\n",
"REQUIRED_GRID_FILES = [\n",
" 'scatter_to_log_128.pt',\n",
" 'forward_from_log_128.pt',\n",
" 'scatter_from_log_128.pt',\n",
" 'sparse_grid_fracs_euclid_backward.pt'\n",
"]\n",
"\n",
"# Check for grid matrices\n",
"missing_grid_files = []\n",
"for filename in REQUIRED_GRID_FILES:\n",
" filepath = os.path.join(GRID_MATRICES_DIR, filename)\n",
" if not os.path.exists(filepath):\n",
" missing_grid_files.append(filename)\n",
" print(f\"Missing: {filepath}\")\n",
"\n",
"if missing_grid_files:\n",
" print(f\"\\nMissing {len(missing_grid_files)} grid matrix file(s).\")\n",
" print(\"\\nThe grid matrices should be in the 'grid_matrices/' directory.\")\n",
" print(\"These files are included in the repository and should be present.\")\n",
" print(\"\\nIf they are missing, you can regenerate them by running:\")\n",
" print(\" - create_log_grid.ipynb (generates scatter_to_log_128.pt)\")\n",
" print(\" - create_log_forward_grid.ipynb (generates forward_from_log_128.pt)\")\n",
" print(\" - create_log_crop_grid.ipynb (generates scatter_from_log_128.pt)\")\n",
" print(\" - create_backward_grid.ipynb (generates sparse_grid_fracs_euclid_backward.pt)\")\n",
" print(\"\\nSee README.md for more information.\")\n",
" raise FileNotFoundError(f\"Missing required grid matrix files: {', '.join(missing_grid_files)}\")\n",
"else:\n",
" print(\"All grid matrix files found in grid_matrices/ directory\")\n",
"\n",
"# Check for pretrained weights\n",
"if not os.path.exists(REQUIRED_WEIGHTS_FILE):\n",
" print(f\"\\nMissing: {REQUIRED_WEIGHTS_FILE}\")\n",
" print(\"\\nThe pretrained model weights are missing.\")\n",
" print(\"\\nTo obtain the weights:\")\n",
" print(\" 1. Train the model using: python train.py --exp-name vdl\")\n",
" print(\" 2. Or download from repository releases (if available)\")\n",
" print(\"\\nNote: Training the model requires the dataset and may take significant time.\")\n",
" print(\"See README.md for detailed instructions.\")\n",
" raise FileNotFoundError(f\"Missing required pretrained weights: {REQUIRED_WEIGHTS_FILE}\")\n",
"else:\n",
" print(f\"Pretrained weights found: {REQUIRED_WEIGHTS_FILE}\")\n",
"\n",
"print(\"\\nAll required files are available. Proceeding with test setup...\")\n"
]
},
{
"cell_type": "markdown",
"id": "795a0f37",
"metadata": {},
"source": [
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d4f6a5f",
"metadata": {},
"outputs": [],
"source": [
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "b68a0a92",
"metadata": {},
"outputs": [],
Expand All @@ -59,10 +145,11 @@
"model.load_state_dict(torch.load('vdl_weights.pt'))\n",
"lensing_module = DifferentiableLensing(device=device, alpha=None, target_resolution=target_resolution, target_shape=target_shape).to(device)\n",
"\n",
"cross_grid_to_log = torch.load('scatter_to_log_128.pt').to(device)\n",
"cross_grid_forward_from_log = torch.load('forward_from_log_128.pt').to(device)\n",
"cross_grid_from_log = torch.load('scatter_from_log_128.pt').to(device)\n",
"cross_grid_backward = torch.load('sparse_grid_fracs_euclid_backward.pt').to(device)\n",
"# Load grid matrices from grid_matrices/ directory\n",
"cross_grid_to_log = torch.load(os.path.join(GRID_MATRICES_DIR, 'scatter_to_log_128.pt')).to(device)\n",
"cross_grid_forward_from_log = torch.load(os.path.join(GRID_MATRICES_DIR, 'forward_from_log_128.pt')).to(device)\n",
"cross_grid_from_log = torch.load(os.path.join(GRID_MATRICES_DIR, 'scatter_from_log_128.pt')).to(device)\n",
"cross_grid_backward = torch.load(os.path.join(GRID_MATRICES_DIR, 'sparse_grid_fracs_euclid_backward.pt')).to(device)\n",
"\n",
"psf, _, _ = lensing_module.gaussian_kernel(fwhm_arcsec=0.16, pixscale_arcsec=target_resolution)\n",
"psf = torch.tensor(psf, dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(0).to(device)"
Expand Down