…ty of the model training command line
Feature/more model options
…bedding-kit into feature/remove-hopkraft
Co-authored-by: lbeckman314 <15309567+lbeckman314@users.noreply.github.com>
Use package metadata for version and add test coverage
…dding-kit into feature/version+tests
feat: Add version flag
… feature/remove-hopkraft
Feature/more model options
Issues related to HLA2Vec development
…ncies from PR review Co-authored-by: kellrott <113868+kellrott@users.noreply.github.com>
Fix NetVAE serialization, weight-access crash, tautological test, and doc inconsistencies
Claude suggestions for docs
Encoding Classes
There was a problem hiding this comment.
Pull request overview
This PR aggregates a large set of refactors and new features targeting the next release (v0.3), including a new factory/serialization system for models, a datasets→resources re-org, expanded CLI/docs, and substantial new/updated tests.
Changes:
- Introduces `embkit.factory` (registry + `to_dict`/`from_dict` + `save`/`load`) and migrates VAE/FFNN components to it.
- Refactors “datasets” downloaders into `embkit.resources` and updates CLI/docs/tests accordingly; adds `--version` output.
- Adds new utilities (files: CSV/H5/loaders, encoding, pathway SIF mask) plus extensive new test coverage and documentation.
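For readers unfamiliar with the pattern, the registry + `to_dict`/`from_dict` + `save`/`load` combination named in the overview can be sketched roughly as follows. This is a minimal illustration in plain Python and not the `embkit.factory` implementation: `register`, `build`, `ToyLayer`, the `"__class__"` key, and the JSON file format are all hypothetical names chosen for the example.

```python
import json
from typing import Any, Callable, Dict, Type

# Hypothetical global registry mapping a short name to a class.
_REGISTRY: Dict[str, Type] = {}


def register(name: str) -> Callable[[Type], Type]:
    """Class decorator that records the class under `name` (illustrative only)."""
    def decorator(cls: Type) -> Type:
        _REGISTRY[name] = cls
        cls._registry_name = name
        return cls
    return decorator


@register("toy_layer")
class ToyLayer:
    """Toy stand-in for a serializable component."""

    def __init__(self, units: int, activation: str = "relu"):
        self.units = units
        self.activation = activation

    def to_dict(self) -> Dict[str, Any]:
        # Store the registry name so the loader knows which class to rebuild.
        return {
            "__class__": self._registry_name,
            "units": self.units,
            "activation": self.activation,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToyLayer":
        return cls(units=data["units"], activation=data["activation"])


def build(data: Dict[str, Any]) -> Any:
    """Look up the class by its registered name and rebuild the object."""
    cls = _REGISTRY[data["__class__"]]
    return cls.from_dict(data)


def save(obj: Any, path: str) -> None:
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(obj.to_dict(), handle)


def load(path: str) -> Any:
    with open(path, "r", encoding="utf-8") as handle:
        return build(json.load(handle))
```

The point of the design is that the saved description carries the registered name, so loading never needs to know the concrete class in advance.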
Reviewed changes
Copilot reviewed 114 out of 123 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| tests/test_main.py | Adds CLI --version tests |
| tests/resources/test_resource_base.py | New tests for Resource/Downloader |
| tests/resources/test_resource_and_gtex.py | New tests for GTEx resource |
| tests/resources/test_resource.py | Updates tests to Resource rename |
| tests/resources/test_c_bio_portal.py | Updates imports/patch paths |
| tests/resources/__init__.py | Adds tests package init |
| tests/preprocessing/test_align.py | Removes Hopcroft-Karp tests |
| tests/pathway/test_pathway_mask.py | Tests build_sif_mask |
| tests/pathway/__init__.py | Adds tests package init |
| tests/optimize/test_fit.py | Tests tuple-input training |
| tests/modules/test_pariwise_comparison_module.py | Updates PairwiseComparison import |
| tests/modules/test_modules.py | Updates module/factory imports |
| tests/modules/__init__.py | Adds tests package init |
| tests/models/vae_models/test_net_vae.py | Adds NetVAE constraint tests |
| tests/models/vae_models/test_encoder.py | Updates to Layer/LayerList |
| tests/models/vae_models/test_decoder.py | Updates to Layer/LayerList |
| tests/files/test_h5.py | Adds H5 reader/writer tests |
| tests/files/test_csv.py | Adds CsvReader tests |
| tests/files/__init__.py | Adds tests package init |
| tests/factory/test_vae.py | Tests factory save/load for VAE |
| tests/factory/test_mapping_sequential.py | Tests mapping Sequential round-trip |
| tests/factory/test_mapping.py | Tests mapping utilities |
| tests/factory/test_layers_helpers.py | Tests Layer/LayerList helpers |
| tests/factory/test_layer_helpers.py | Tests mask helper functions |
| tests/factory/test_core_build_extended.py | Extends factory core.build tests |
| tests/factory/test_core_build.py | Tests factory core.build |
| tests/factory/test_construct.py | Tests factory constructors/build |
| tests/encoding/test_encoding.py | Tests one-hot encoders |
| tests/data/sample_pathway.sif | Adds sample SIF data |
| src/embkit/version.py | Adds version/git metadata helper |
| src/embkit/utilities/pca.py | Updates import to new files pkg |
| src/embkit/utilities/kmeans.py | Updates import to new files pkg |
| src/embkit/resources/util.py | Adds downloader util helper |
| src/embkit/resources/sif.py | Updates SingleFileDownloader import |
| src/embkit/resources/resource.py | Renames Dataset→Resource, downloader tweaks |
| src/embkit/resources/hugo.py | Updates SingleFileDownloader import |
| src/embkit/resources/gtex.py | Renames base class to Resource |
| src/embkit/resources/c_bio_portal.py | Renames base class + move to shutil |
| src/embkit/resources/__init__.py | Exposes resources API |
| src/embkit/preprocessing/normalize.py | Adds zero-mask helper + typing tweaks |
| src/embkit/preprocessing/dataset.py | Adds iterable batching helper |
| src/embkit/preprocessing/__init__.py | Exposes new preprocessing utilities |
| src/embkit/pathway.py | Adds build_sif_mask |
| src/embkit/modules/tsp.py | Renames local variable (pairs→vpairs) |
| src/embkit/modules/pairwise_comparison_layer.py | Adds PairwiseComparison layer |
| src/embkit/modules/masked_linear.py | Adds dtype passthrough |
| src/embkit/modules/__init__.py | New modules package exports |
| src/embkit/models/vae/vae.py | Migrates VAE to factory serialization |
| src/embkit/models/vae/rna_vae.py | Migrates LayerInfo→Layer |
| src/embkit/models/vae/pytorch.py | Removes legacy script/module |
| src/embkit/models/vae/net_vae.py | Delegates training to fit_net_vae; adds (de)serialization |
| src/embkit/models/vae/encoder.py | Migrates to factory LayerList + adds (de)serialization |
| src/embkit/models/vae/decoder.py | Migrates to factory LayerList + adds (de)serialization |
| src/embkit/models/vae/base_vae.py | Removes legacy save/load; adds factory-oriented base |
| src/embkit/models/protein.py | Removes legacy protein module (moved) |
| src/embkit/models/ffnn.py | Reworks FFNN construction via factory |
| src/embkit/losses/vae_loss.py | Adds return typing annotations |
| src/embkit/layers/layer_info.py | Removes old layers API |
| src/embkit/layers/constraint_info.py | Removes old constraint info |
| src/embkit/layers/__init__.py | Removes old layers package init |
| src/embkit/files/read_csv.py | Adds CsvReader; trims example code |
| src/embkit/files/loaders.py | Adds GCT/GTEx/HUGO loaders |
| src/embkit/files/json.py | Adds JSON encoder helper |
| src/embkit/files/h5.py | Adds H5 reader/writer implementations |
| src/embkit/files/__init__.py | Exposes files API |
| src/embkit/file_readers/__init__.py | Removes deprecated file_readers export |
| src/embkit/factory/registry.py | Adds global class registry + decorators |
| src/embkit/factory/mapping.py | Adds serializable nn.Module mappings |
| src/embkit/factory/layers.py | Adds Layer/LayerList/ConstraintInfo (new) |
| src/embkit/factory/core.py | Adds build/save/load implementation |
| src/embkit/factory/__init__.py | Factory public exports |
| src/embkit/estimator/vae_estimator.py | Adds typing to fit/score |
| src/embkit/encoding/protein.py | Moves ProteinEncoder implementation |
| src/embkit/encoding/__init__.py | Adds OneHot + ProteinOneHot encoders |
| src/embkit/datasets/__init__.py | Replaces dataset exports with dataset utilities |
| src/embkit/constraints/network_constraint.py | Adds (de)serialization |
| src/embkit/commands/resources.py | Renames CLI group datasets→resources |
| src/embkit/commands/protein.py | Adds --pool none + JSON output |
| src/embkit/commands/cbio.py | Updates imports to resources |
| src/embkit/commands/align.py | Removes hopkraft path |
| src/embkit/commands/__init__.py | Wires in resources command group |
| src/embkit/align.py | Removes Hopcroft-Karp implementation |
| src/embkit/main.py | Adds --version; fixes help traversal |
| src/embkit/__init__.py | Adds dtype support + dataset helper |
| src/embkit.egg-info/top_level.txt | Removes generated metadata file |
| src/embkit.egg-info/requires.txt | Removes generated metadata file |
| src/embkit.egg-info/entry_points.txt | Removes generated metadata file |
| src/embkit.egg-info/dependency_links.txt | Removes generated metadata file |
| src/embkit.egg-info/SOURCES.txt | Removes generated metadata file |
| src/embkit.egg-info/PKG-INFO | Removes generated metadata file |
| requirements.txt | Updates deps (fair-esm, faiss-cpu) |
| pyproject.toml | Adds deps (faiss-cpu, h5py) |
| mkdocs.yml | Expands docs nav + reorganizes API |
| docs/training.md | Adds training guide |
| docs/index.md | Rewrites landing page |
| docs/examples/protein.md | Adds protein example |
| docs/examples/netvae.md | Adds NetVAE example |
| docs/examples/hla2vec.md | Removes legacy example |
| docs/examples/gtex.md | Expands GTEx example |
| docs/examples/factory.md | Adds factory decorator examples |
| docs/concepts.md | Adds concepts overview |
| docs/cli.md | Adds CLI reference |
| docs/api/preprocessing.md | Adds preprocessing API page |
| docs/api/pathway.md | Adds pathway API page |
| docs/api/optimize.md | Adds optimize API page |
| docs/api/models/rna_vae.md | Adds RNAVAE API page |
| docs/api/models/net_vae.md | Adds NetVAE API page |
| docs/api/layers/tsp.md | Updates doc path to modules |
| docs/api/layers/pairwise_comparison_layer.md | Updates doc path to modules |
| docs/api/layers/masked_linear.md | Updates doc path to modules |
| docs/api/layers/layers.md | Updates doc path to modules |
| docs/api/layers/layer_info.md | Updates docs to factory Layer/LayerList |
| docs/api/factory/index.md | Adds factory API page |
| docs/api/commands/cbio.md | Updates cbio docs |
| docs/api/bmeg.md | Adds BMEG API page |
| docs/api.md | Removes legacy mkdocstrings index |
| docs/about.md | Rewrites about page |
| .gitignore | Ignores *.egg-info |
| .github/workflows/pr_coverage_check.yml | Lowers coverage thresholds |
| Classes: | ||
| MaskedLinear: Linear layer whose weight is elementwise-multiplied by a mask at forward time. | ||
| LayerInfo: A data structure to hold information about each layer, including the number of units, | ||
| activation function, and whether to use batch normalization. | ||
| PairwiseComparison: A layer that performs pairwise comparisons between inputs. |
There was a problem hiding this comment.
The module docstring still references LayerInfo, but LayerInfo no longer exists/gets exported from embkit.modules (layer configs now live under embkit.factory.layers.Layer). Please update the docstring to avoid misleading API docs.
| def get_activation(name: Optional[str]) -> Optional[nn.Module]: | ||
| """ | ||
| Get a PyTorch activation function module from a string name. | ||
| Args: | ||
| name (Optional[str]): Name of the activation function (e.g., "relu", "tanh", "sigmoid", etc.). If None or empty, returns None. | ||
| Returns: | ||
| Optional[nn.Module]: Corresponding PyTorch activation function module or None if not found. | ||
| """ |
There was a problem hiding this comment.
get_activation() is annotated/documented as returning an nn.Module, but it actually returns a registered class (callers do get_activation(name)() to instantiate it). This mismatch can confuse type checking and readers. Consider changing the return type to Optional[type[nn.Module]] (or Optional[Type[nn.Module]]) and updating the docstring accordingly.
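The annotation change suggested here can be sketched without PyTorch. In the following illustration, `Activation`, `ReLU`, `Tanh`, and `_ACTIVATIONS` are hypothetical stand-ins for `nn.Module` subclasses and the project's registry; the lower-casing of the name is also an assumption. The only point being made is that a lookup returning a class is typed as `Optional[Type[...]]`, and the caller instantiates the result.

```python
from typing import Dict, Optional, Type


class Activation:
    """Stand-in for nn.Module in this sketch."""


class ReLU(Activation):
    pass


class Tanh(Activation):
    pass


# Hypothetical name -> class table; the real registry may look different.
_ACTIVATIONS: Dict[str, Type[Activation]] = {"relu": ReLU, "tanh": Tanh}


def get_activation(name: Optional[str]) -> Optional[Type[Activation]]:
    """Return the activation *class* for `name`, or None if missing/unknown."""
    if not name:
        return None
    return _ACTIVATIONS.get(name.lower())


# Callers instantiate the returned class themselves.
act_cls = get_activation("relu")
layer = act_cls() if act_cls is not None else None
```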
| @@ -87,26 +83,31 @@ | |||
| ) | |||
There was a problem hiding this comment.
latent_dim is now a required int, but the constructor still has checks like if self.latent_dim is None:. These branches are now unreachable and can hide real errors when refactoring. Consider removing the None checks and simplifying the error messages accordingly (or revert latent_dim to Optional[int] if None is still a valid state).
| def to(self, device=None, dtype=None): | ||
| self.encoder.to(device=device, dtype=dtype) | ||
| self.decoder.to(device=device, dtype=dtype) | ||
There was a problem hiding this comment.
BaseVAE.to() overrides nn.Module.to() but (1) doesn’t return self, (2) doesn’t call super().to(...), and (3) will raise if encoder/decoder are None. This can break call sites that expect model = model.to(device) and can leave any parameters defined on subclasses unmoved. Consider delegating to super().to(device=device, dtype=dtype) and returning self, while guarding self.encoder/self.decoder when they’re unset.
| if act is not None: | ||
| layers.append(act()) | ||
| if self.batch_norm: | ||
| layers.append(nn.BatchNorm1d(out_features)) | ||
| return layers |
There was a problem hiding this comment.
When batch_norm is enabled, Layer.gen_layer() adds nn.BatchNorm1d(out_features) without passing through the provided device / dtype. This can lead to mixed-device / mixed-dtype modules at construction time (especially when callers pass device/dtype to build). Consider constructing the BatchNorm with the same device and dtype arguments for consistency.
| eps = torch.randn_like(std) | ||
| z = mu + eps * std | ||
| return mu, logvar, z | ||
| return mu, logvar, h | ||
There was a problem hiding this comment.
In the non-sampling path, the encoder returns (mu, logvar, h) where h is the pre-head hidden tensor, not the latent sample. This makes the decoder operate on a different tensor than the one used to compute the KL term (mu/logvar), which breaks the VAE objective and also makes BaseVAE.encode() return the wrong latent. Consider returning (mu, logvar, mu) when sampling=False (or (mu, logvar, z) if you want reparameterized sampling), so the third return value is the actual latent fed to the decoder.
There was a problem hiding this comment.
This is incorrect. You replicated the mu variable and returned it twice. The third parameter is expected to be the value (h).
There was a problem hiding this comment.
Acknowledged — the revert you made in 6cc4ae5 is correct. Returning h (the pre-head hidden representation) as the third value is the right behavior; mu is only appropriate when you explicitly want a deterministic mean-as-latent encoding, which isn't the intended design here.
| def get_dataset_nonzero_mask(d: Dataset, threshold: float) -> torch.Tensor: | ||
| """ | ||
| Scan a torch Dataset, identify the elements that are non-zero above the threshold | ||
| percentage. This returns a binary mask that selects those columns. | ||
| """ | ||
| s = None | ||
| count = 0 | ||
| for row in d: | ||
| if s is None: | ||
| s = [] | ||
| for f in row: | ||
| s.append( torch.zeros( len(f) ) ) | ||
| for i, f in enumerate(row): | ||
| s[i] = s[i] + (f == 0.0).int() | ||
| count += 1 | ||
| out = [] | ||
| for v in s: | ||
| mask = (v.int() / count) < threshold | ||
| out.append(mask) | ||
| return out |
There was a problem hiding this comment.
get_dataset_nonzero_mask() is annotated as returning torch.Tensor, but it actually returns a list[Tensor] (one mask per tensor in each dataset row). Also, the accumulator tensors are created on the default device/dtype (torch.zeros(len(f))), which will break if the dataset yields tensors on a non-CPU device. Consider updating the return type (and docstring) to list[torch.Tensor] and initializing accumulators with torch.zeros_like(f, dtype=torch.int32) (or torch.zeros(f.shape, device=f.device) as appropriate).
@copilot open a new pull request to apply changes based on the comments in this thread
…ice, encoder return value Co-authored-by: kellrott <113868+kellrott@users.noreply.github.com>
Co-authored-by: kellrott <113868+kellrott@users.noreply.github.com>
…onzero-mask Add unit tests for `get_dataset_nonzero_mask`
Fix type annotations, dead code, VAE objective, and device consistency bugs from PR review
Co-authored-by: kellrott <113868+kellrott@users.noreply.github.com>
…factory-save Switch examples to use `factory.save` instead of `model.save`
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 117 out of 126 changed files in this pull request and generated 9 comments.
| class DummyObj(nn.Module): | ||
| def __init__(self, data): | ||
| self._data = data | ||
There was a problem hiding this comment.
DummyObj subclasses nn.Module but its __init__ doesn't call super().__init__(). PyTorch raises an AttributeError when a parameter or submodule is assigned before nn.Module is initialized, and any other attribute assignment leaves the module half-initialized. Add super().__init__() as the first statement in __init__ to make the test reliable.
| _, _, batch_tokens = self.batch_converter(block) | ||
| if fix_len is not None: | ||
| #if they have defined that the tokenization will be a fixed length | ||
| batch_tokens = self.pad(batch_tokens, fix_len+1) # length plus start token | ||
| if self.device is not None: | ||
| batch_tokens = batch_tokens.to(device=self.device, non_blocking=True) | ||
| if fix_len is None: | ||
| batch_lens = (batch_tokens != self.alphabet.padding_idx).sum(1) | ||
| else: | ||
| batch_lens = [fix_len] * len(block) | ||
| with torch.no_grad(): | ||
| results = self.model(batch_tokens, | ||
| repr_layers=[self.out_layer], | ||
| return_contacts=True) | ||
| token_representations = results["representations"][self.out_layer] | ||
| for i, tokens_len in enumerate(batch_lens): | ||
| if fix_len: | ||
| vec = token_representations[i, 1 : fix_len+2] # include fix_len + start_token | ||
| else: | ||
| vec = token_representations[i, 1 : tokens_len - 1] # remove padding | ||
| if output == "mean-pool": | ||
| yield block[i][0], vec.mean(0).to(device="cpu") | ||
| elif output == "sum-pool": | ||
| yield block[i][0], vec.sum(0).to(device="cpu") | ||
| else: | ||
| yield block[i][0], vec.to(device="cpu") | ||
| def pad(self, tokens, fix_len): | ||
| padded_tokens = F.pad(tokens, (0, fix_len - len(tokens[0]) ), value=self.alphabet.padding_idx) | ||
| return padded_tokens |
There was a problem hiding this comment.
When fix_len is smaller than the tokenized sequence length, F.pad(..., fix_len - len(tokens[0])) becomes negative and will error. Since CLI/docs describe "pad or truncate", implement truncation (slice to fix_len) before padding, and use if fix_len is not None instead of truthiness checks.
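The "truncate, then pad" behaviour requested in this comment can be illustrated on plain Python lists. This is a sketch only: `pad_or_truncate` and `prepare_tokens` are hypothetical helpers, the token ids are made up, and the real code operates on batched tensors rather than a single list.

```python
from typing import List, Optional


def pad_or_truncate(tokens: List[int], fix_len: int, padding_idx: int) -> List[int]:
    """Return exactly `fix_len` tokens: slice first, then right-pad."""
    if fix_len < 0:
        raise ValueError("fix_len must be non-negative")
    clipped = tokens[:fix_len]
    # The pad count can no longer go negative because of the slice above.
    return clipped + [padding_idx] * (fix_len - len(clipped))


def prepare_tokens(tokens: List[int], fix_len: Optional[int], padding_idx: int) -> List[int]:
    """Apply fixed-length handling only when a length was actually given."""
    # `is not None` rather than truthiness, so fix_len=0 is not silently ignored.
    if fix_len is not None:
        return pad_or_truncate(tokens, fix_len, padding_idx)
    return tokens
```

Slicing before padding keeps the pad amount at zero or above, which is the property the comment is asking for.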
| if isinstance(columns, int): | ||
| self.columns = pd.RangeIndex(columns) | ||
| var_group = base_group.create_group("var") | ||
| var_group.create_dataset("_index", | ||
| data=columns, | ||
| dtype=int) | ||
| self.dataset = base_group.create_dataset("X", (len(index),columns), dtype='f') |
There was a problem hiding this comment.
The columns-as-int writer path writes var/_index as a single integer (data=columns) with dtype int, but H5Reader always assumes var/_index contains per-column string entries. This makes the int-path unreadable. Either write a proper range index dataset compatible with the reader, or remove/guard this mode.
There was a problem hiding this comment.
@copilot apply changes based on this feedback and add unit tests to check behavior
| self.assertEqual(result.returncode, 0) | ||
| self.assertIn("embedding-kit version", result.stdout) | ||
| self.assertRegex(result.stdout, r'embedding-kit version \d+\.\d+') | ||
There was a problem hiding this comment.
Same issue as above: the regex for the python -m embkit --version output only matches major.minor and will fail for major.minor.patch versions. Update the assertion to accept the full version string format used by package metadata.
| def test_default_save_path_created(self): | ||
| ds = DummyDataset(name="test", save_path=None, download=False) | ||
| # Should create a .embkit directory in home | ||
| default_dir = Path(Path.home(), ".embkit") | ||
| self.assertTrue(ds.save_path == default_dir) | ||
| self.assertTrue(ds.save_path.is_dir()) |
There was a problem hiding this comment.
This test asserts against the real home directory and requires ~/.embkit to be creatable. To avoid side effects/flakiness in CI, patch Path.home() to a temporary directory (or set HOME env var within the test) and assert against that.
There was a problem hiding this comment.
@copilot apply changes based on this feedback
There was a problem hiding this comment.
Applied in commit 3145fa5 on branch copilot/fix-home-patch-2. test_default_save_path_created and test_str_representation now patch embkit.resources.resource.Path.home inside a TemporaryDirectory context so they no longer touch the real ~/.embkit. All 5 tests in the file pass.
Note: the develop branch is protected and report_progress cannot push to it directly. The change is on branch copilot/fix-home-patch-2 and needs to be merged into develop via a sub-PR.
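The patching approach described in this thread might look roughly like the following. It is a sketch under stated assumptions: `default_save_path` is a hypothetical stand-in for the resource's default-directory logic, and the patch targets `pathlib.Path.home` directly, whereas the reply above patches `embkit.resources.resource.Path.home`.

```python
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch


def default_save_path() -> Path:
    """Hypothetical stand-in: create and return <home>/.embkit."""
    path = Path.home() / ".embkit"
    path.mkdir(parents=True, exist_ok=True)
    return path


class DefaultSavePathTest(unittest.TestCase):
    def test_default_save_path_created(self):
        with tempfile.TemporaryDirectory() as tmp:
            # Redirect Path.home() so the real home directory is never touched.
            with patch.object(Path, "home", return_value=Path(tmp)):
                save_path = default_save_path()
                self.assertEqual(save_path, Path(tmp) / ".embkit")
                self.assertTrue(save_path.is_dir())
```

Because the temporary directory is removed when the context exits, the test leaves nothing behind on the CI machine.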
| for i, n in enumerate(self.classes): | ||
| self.mapping[n] = F.one_hot( torch.tensor(i), self.num_classes ).to(device) | ||
| self.class_idx[n] = i |
There was a problem hiding this comment.
Tensor.to(device) will raise if device is None. Since device defaults to None here, constructing OneHotEncoder(labels) can fail. Only call .to(...) when a non-None device is provided (or use device=self.device in torch.tensor(...) and skip the .to).
| s = None | ||
| count = 0 | ||
| for row in d: | ||
| if s is None: | ||
| s = [] | ||
| for f in row: | ||
| s.append(torch.zeros(f.shape, dtype=torch.int32, device=f.device)) | ||
| for i, f in enumerate(row): | ||
| s[i] = s[i] + (f == 0.0).int() | ||
| count += 1 | ||
| out = [] | ||
| for v in s: | ||
| mask = (v.int() / count) < threshold | ||
| out.append(mask) |
There was a problem hiding this comment.
If the dataset is empty, s stays None and count stays 0, so iterating for v in s will crash (and division by zero would also be an issue). Add an explicit empty-dataset guard (e.g., return [] or raise a clear ValueError) before building masks.
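A plain-Python analogue of the guard suggested here is sketched below. It simplifies the original on purpose: rows are flat sequences of floats instead of tuples of tensors, so a single mask is returned, and `nonzero_mask` is a hypothetical name. Raising `ValueError` is one of the two options the comment mentions; returning an empty list would be the other.

```python
from typing import Iterable, List, Optional, Sequence


def nonzero_mask(rows: Iterable[Sequence[float]], threshold: float) -> List[bool]:
    """Keep columns whose fraction of zero entries is below `threshold`."""
    zero_counts: Optional[List[int]] = None
    count = 0
    for row in rows:
        if zero_counts is None:
            zero_counts = [0] * len(row)
        for i, value in enumerate(row):
            if value == 0.0:
                zero_counts[i] += 1
        count += 1
    # Explicit guard: nothing was iterated, so there is nothing to divide by.
    if zero_counts is None:
        raise ValueError("cannot build a mask from an empty dataset")
    return [(zeros / count) < threshold for zeros in zero_counts]
```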
| self.data = self.hfile[self.group]["X"] | ||
| self.index = pd.Index(i.decode("utf-8") for i in self.hfile[self.group]["obs/_index"]) | ||
| self.columns = pd.Index(i.decode("utf-8") for i in self.hfile[self.group]["var/_index"]) |
There was a problem hiding this comment.
h5py.string_dtype() datasets are often returned as Python str (not bytes) by h5py, so calling .decode('utf-8') can raise AttributeError. Consider handling both bytes and str (decode only when the value is bytes) when building index/columns.
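Handling both value types, as suggested, can be done with a small helper. The sketch below uses plain Python values rather than an open HDF5 file; `as_text` and `decode_index` are hypothetical names, and the resulting list would be passed to `pd.Index(...)` in the reader.

```python
from typing import Iterable, List, Union


def as_text(value: Union[bytes, str]) -> str:
    """Decode bytes as UTF-8 and pass through values that are already str."""
    if isinstance(value, bytes):
        return value.decode("utf-8")
    return str(value)


def decode_index(values: Iterable[Union[bytes, str]]) -> List[str]:
    """Normalize a stored index to a list of Python strings."""
    return [as_text(v) for v in values]
```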
| self.assertIn("embedding-kit version", result.output) | ||
| # Should include version number from package metadata | ||
| self.assertRegex(result.output, r'embedding-kit version \d+\.\d+') | ||
There was a problem hiding this comment.
The version regex only matches major.minor (e.g. 0.3) but package versions are typically major.minor.patch (e.g. 0.1.2). This will make the test fail for normal semantic versions. Consider updating the regex to accept an optional patch segment (or use packaging.version to validate).
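Before changing the assertion, it may help to check how the pattern behaves. `assertRegex` uses `re.search`, so an unanchored `\d+\.\d+` also finds a match inside a three-part version string; a stricter check is needed only if the test should validate the whole line. The sketch below shows both, assuming the output format `embedding-kit version X` seen in the test; `is_version_line` is a hypothetical helper, and pre-release or local version suffixes are deliberately not handled.

```python
import re

# Pattern used by the test under review (unanchored).
LOOSE = r"embedding-kit version \d+\.\d+"

# Stricter variant: major.minor with an optional patch segment, nothing else.
FULL = re.compile(r"embedding-kit version \d+\.\d+(?:\.\d+)?")


def loose_match(output: str) -> bool:
    """Mimic assertRegex, which searches anywhere in the text."""
    return re.search(LOOSE, output) is not None


def is_version_line(output: str) -> bool:
    """Require the whole (stripped) output to be a version line."""
    return FULL.fullmatch(output.strip()) is not None
```

If full validation of arbitrary version strings is wanted, parsing with `packaging.version`, as the comment suggests, is the more robust route.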
- `(mu, logvar, mu)` → `(mu, logvar, h)` (already done by @kellrott in 6cc4ae5)
- Update `tests/resources/test_resource_and_gtex.py` to patch `Path.home()` with a temp directory instead of touching the real home dir