diff --git a/semhash/records.py b/semhash/records.py index 0835606..5067cd6 100644 --- a/semhash/records.py +++ b/semhash/records.py @@ -119,7 +119,9 @@ def prepare_records( dict_records_typed: list[dict[str, Any]] = list(records) dict_records = [] for record in dict_records_typed: - coerced: dict[str, Any] = {} + # Start with a copy of the full record to preserve non-embedding fields + coerced: dict[str, Any] = dict(record) + # Then coerce only the embedding columns for column in columns: val = record.get(column) if val is None: diff --git a/semhash/semhash.py b/semhash/semhash.py index e866bc4..5f7e2a3 100644 --- a/semhash/semhash.py +++ b/semhash/semhash.py @@ -312,12 +312,14 @@ def _validate_if_strings(self, records: Sequence[dict[str, Any] | str]) -> list[ dict_records: Sequence[dict[str, Any]] = records # type: ignore[assignment] result: list[dict[str, Any]] = [] - for r in dict_records: - out = {} - for c in self.columns: - if (val := r.get(c)) is None: - raise ValueError(f"Column '{c}' has None value in record {r}") - out[c] = coerce_value(val) + for record in dict_records: + # Start with a copy of the full record to preserve non-embedding fields + out = dict(record) + # Then coerce only the embedding columns + for col in self.columns: + if (val := record.get(col)) is None: + raise ValueError(f"Column '{col}' has None value in record {record}") + out[col] = coerce_value(val) result.append(out) return result diff --git a/semhash/version.py b/semhash/version.py index 9dbcf97..f05245b 100644 --- a/semhash/version.py +++ b/semhash/version.py @@ -1,2 +1,2 @@ -__version_triple__ = (0, 4, 0) # pragma: no cover +__version_triple__ = (0, 4, 1) # pragma: no cover __version__ = ".".join(map(str, __version_triple__)) # pragma: no cover diff --git a/tests/test_semhash.py b/tests/test_semhash.py index 55119fd..6baf030 100644 --- a/tests/test_semhash.py +++ b/tests/test_semhash.py @@ -298,6 +298,41 @@ def test_from_records_edge_cases(model: Encoder) -> None: SemHash.from_records([{"text": "apple"}, {"text": None}], columns=["text"], model=model) +def test_preserve_non_embedding_fields(model: Encoder) -> None: + """Test that fields not specified in columns are preserved in results.""" + records = [ + {"id": 0, "text": "triforce", "metadata": "game1"}, + {"id": 1, "text": "master sword", "metadata": "game2"}, + {"id": 2, "text": "hylian shield", "metadata": "game3"}, + ] + semhash = SemHash.from_records(records, columns=["text"], model=model) + + # Test self_deduplicate preserves non-embedding fields + result = semhash.self_deduplicate(threshold=0.9) + assert len(result.selected) == 3, "All records should be unique" + + # All results should have id and metadata fields preserved + for record in result.selected: + assert "id" in record, "id field should be preserved" + assert "text" in record, "text field should be preserved" + assert "metadata" in record, "metadata field should be preserved" + + # Check specific values are correct + ids = {r["id"] for r in result.selected} + assert ids == {0, 1, 2}, "All id values should be preserved" + + metadatas = {r["metadata"] for r in result.selected} + assert metadatas == {"game1", "game2", "game3"}, "All metadata values should be preserved" + + # Test that cross-dataset deduplication also preserves fields + new_records = [{"id": 10, "text": "triforce", "metadata": "duplicate"}] + dup_result = semhash.deduplicate(new_records, threshold=0.9) + + assert len(dup_result.filtered) == 1, "Should detect duplicate" + assert "id" in dup_result.filtered[0].record, "id should be preserved in filtered records" + assert dup_result.filtered[0].record["id"] == 10, "Correct id value" + + def test_deduplicate_edge_cases(model: Encoder) -> None: """Test deduplicate() edge cases: coercion, None rejection, empty records, type mismatches.""" semhash = SemHash.from_records(["1", "2", "3"], model=model)