From 267c92fd20a31c9e4728bde9316cc109c9ce7f4c Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Tue, 3 Feb 2026 19:38:15 +0000 Subject: [PATCH 1/4] feat: add parallel downloading for multi-file datasets - Add max_parallel_downloads configuration to Cache class (default: 4) - Extract download logic into _download_and_read_parquet helper function - Implement parallel downloads using ThreadPoolExecutor - Maintain file ordering by extracting index from filenames - Keep serial processing for single files or when max_workers=1 - Update progress bar to track completion of parallel downloads Co-authored-by: jhamon --- pinecone_datasets/cfg.py | 3 + pinecone_datasets/dataset_fsreader.py | 109 +++++++++++++++++++++----- 2 files changed, 93 insertions(+), 19 deletions(-) diff --git a/pinecone_datasets/cfg.py b/pinecone_datasets/cfg.py index 8fc8a39..fcd5d0c 100644 --- a/pinecone_datasets/cfg.py +++ b/pinecone_datasets/cfg.py @@ -16,6 +16,9 @@ class Cache: "1", "yes", ) + max_parallel_downloads: int = int( + os.getenv("PINECONE_DATASETS_MAX_PARALLEL_DOWNLOADS", "4") + ) class Schema: diff --git a/pinecone_datasets/dataset_fsreader.py b/pinecone_datasets/dataset_fsreader.py index 2bcf37e..1927474 100644 --- a/pinecone_datasets/dataset_fsreader.py +++ b/pinecone_datasets/dataset_fsreader.py @@ -2,12 +2,13 @@ import logging import os import warnings -from typing import Literal, Optional +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Literal, Optional, Tuple import pandas as pd import pyarrow.parquet as pq -from .cfg import Schema +from .cfg import Cache, Schema from .dataset_metadata import DatasetMetadata from .fs import CloudOrLocalFS, get_cached_path, is_cloud_path from .retry import create_cloud_storage_retry_decorator @@ -71,6 +72,49 @@ def _does_datatype_exist( ) -> bool: return fs.exists(os.path.join(dataset_path, data_type)) + @staticmethod + def _download_and_read_parquet( + path: str, + fs: CloudOrLocalFS, + use_cache: bool, + protocol: Optional[str], + ) -> Tuple[int, pd.DataFrame]: + """ + Download (if needed) and read a single parquet file. + + Args: + path: Path to the parquet file + fs: Filesystem object + use_cache: Whether to use caching for this file + protocol: Protocol prefix (gs:// or s3://) if applicable + + Returns: + Tuple of (file_index, dataframe) where file_index is from the path + """ + if use_cache and protocol: + # Reconstruct full URL if path doesn't have protocol + if not path.startswith(protocol): + full_path = f"{protocol}{path}" + else: + full_path = path + # Download to cache and read from local path + local_path = get_cached_path(full_path, fs) + piece = pq.read_pandas(local_path) + else: + # Read directly from filesystem + piece = pq.read_pandas(path, filesystem=fs) + + df_piece = piece.to_pandas() + # Extract index from path for proper ordering (handles paths like "documents/0000.parquet") + try: + filename = os.path.basename(path) + file_index = int(os.path.splitext(filename)[0]) + except (ValueError, AttributeError): + # If we can't extract an index, use hash of path for consistent ordering + file_index = hash(path) + + return (file_index, df_piece) + @staticmethod def _safe_read_from_path( fs: CloudOrLocalFS, @@ -94,23 +138,50 @@ def _safe_read_from_path( elif dataset_path.startswith("https://s3.amazonaws.com/"): protocol = "s3://" - # First, collect all the dataframes - dfs = [] - for path in tqdm(read_path, desc=f"Loading {data_type}"): - if use_cache_for_dataset and protocol: - # Reconstruct full URL if path doesn't have protocol - if not path.startswith(protocol): - full_path = f"{protocol}{path}" - else: - full_path = path - # Download to cache and read from local path - local_path = get_cached_path(full_path, fs) - piece = pq.read_pandas(local_path) - else: - # Read directly from filesystem - piece = pq.read_pandas(path, filesystem=fs) - df_piece = piece.to_pandas() - dfs.append(df_piece) + # Collect all dataframes using parallel downloads + num_files = len(read_path) + max_workers = min(Cache.max_parallel_downloads, num_files) if num_files > 1 else 1 + + dfs_with_index = [] + + if max_workers == 1: + # Serial processing for single file or when max_workers=1 + for path in tqdm(read_path, desc=f"Loading {data_type}"): + file_index, df_piece = DatasetFSReader._download_and_read_parquet( + path, fs, use_cache_for_dataset, protocol + ) + dfs_with_index.append((file_index, df_piece)) + else: + # Parallel processing for multiple files + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all download tasks + future_to_path = { + executor.submit( + DatasetFSReader._download_and_read_parquet, + path, + fs, + use_cache_for_dataset, + protocol, + ): path + for path in read_path + } + + # Collect results as they complete with progress bar + with tqdm(total=num_files, desc=f"Loading {data_type}") as pbar: + for future in as_completed(future_to_path): + path = future_to_path[future] + try: + file_index, df_piece = future.result() + dfs_with_index.append((file_index, df_piece)) + except Exception as e: + logger.error(f"Failed to load {path}: {e}") + raise + finally: + pbar.update(1) + + # Sort by file index to maintain consistent ordering + dfs_with_index.sort(key=lambda x: x[0]) + dfs = [df for _, df in dfs_with_index] if not dfs: raise ValueError(f"No parquet files found in {read_path_str}") From c208ac18cb3b3c59061c32121f4a027de10236bd Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Tue, 3 Feb 2026 19:41:53 +0000 Subject: [PATCH 2/4] test: add comprehensive tests for parallel downloads - Add unit tests for parallel download configuration - Test file index extraction from filenames - Verify correct ordering of files after parallel downloads - Test both serial and parallel processing paths - Verify max_workers limits are respected - All 8 new tests pass Co-authored-by: jhamon --- tests/unit/test_parallel_downloads.py | 260 ++++++++++++++++++++++++++ 1 file changed, 260 insertions(+) create mode 100644 tests/unit/test_parallel_downloads.py diff --git a/tests/unit/test_parallel_downloads.py b/tests/unit/test_parallel_downloads.py new file mode 100644 index 0000000..56183ef --- /dev/null +++ b/tests/unit/test_parallel_downloads.py @@ -0,0 +1,260 @@ +""" +Unit tests for parallel download functionality. +""" +import os +import tempfile +from pathlib import Path +from unittest.mock import Mock, patch, call + +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +from pinecone_datasets import cfg +from pinecone_datasets.dataset_fsreader import DatasetFSReader + + +class TestParallelDownloads: + """Test parallel download functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.original_max_workers = cfg.Cache.max_parallel_downloads + + def teardown_method(self): + """Restore original configuration.""" + cfg.Cache.max_parallel_downloads = self.original_max_workers + + def test_parallel_download_configuration(self): + """Test that max_parallel_downloads can be configured.""" + # Test default value + assert cfg.Cache.max_parallel_downloads == 4 + + # Test setting via attribute + cfg.Cache.max_parallel_downloads = 8 + assert cfg.Cache.max_parallel_downloads == 8 + + def test_parallel_download_environment_variable(self): + """Test that max_parallel_downloads can be set via environment variable.""" + with patch.dict(os.environ, {"PINECONE_DATASETS_MAX_PARALLEL_DOWNLOADS": "10"}): + # Need to reload the module for env var to take effect + import importlib + from pinecone_datasets import cfg as cfg_module + + importlib.reload(cfg_module) + assert cfg_module.Cache.max_parallel_downloads == 10 + + def test_download_and_read_parquet_extracts_file_index(self): + """Test that _download_and_read_parquet extracts file index from filename.""" + # Create a temporary parquet file + with tempfile.TemporaryDirectory() as tmpdir: + # Create test data + df = pd.DataFrame( + { + "id": ["1", "2"], + "values": [[0.1, 0.2], [0.3, 0.4]], + "sparse_values": [None, None], + "metadata": [None, None], + "blob": [None, None], + } + ) + + # Write to parquet with numbered filename + path = os.path.join(tmpdir, "0042.parquet") + table = pa.Table.from_pandas(df) + pq.write_table(table, path) + + # Mock filesystem + from fsspec.implementations.local import LocalFileSystem + + fs = LocalFileSystem() + + # Test extraction + file_index, result_df = DatasetFSReader._download_and_read_parquet( + path=path, fs=fs, use_cache=False, protocol=None + ) + + assert file_index == 42 + assert len(result_df) == 2 + + def test_download_and_read_parquet_handles_non_numeric_filenames(self): + """Test that _download_and_read_parquet handles non-numeric filenames.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test data + df = pd.DataFrame( + { + "id": ["1", "2"], + "values": [[0.1, 0.2], [0.3, 0.4]], + "sparse_values": [None, None], + "metadata": [None, None], + "blob": [None, None], + } + ) + + # Write to parquet with non-numeric filename + path = os.path.join(tmpdir, "part-abc.parquet") + table = pa.Table.from_pandas(df) + pq.write_table(table, path) + + from fsspec.implementations.local import LocalFileSystem + + fs = LocalFileSystem() + + # Test extraction - should use hash fallback + file_index, result_df = DatasetFSReader._download_and_read_parquet( + path=path, fs=fs, use_cache=False, protocol=None + ) + + # File index should be consistent for same path + assert isinstance(file_index, int) + assert file_index == hash(path) + + def test_safe_read_from_path_sorts_by_file_index(self): + """Test that files are read in correct order.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create dataset directory structure + docs_dir = Path(tmpdir) / "test-dataset" / "documents" + docs_dir.mkdir(parents=True) + + # Create multiple parquet files with different indices + # Write them in reverse order to test sorting + for i in [2, 0, 1]: + df = pd.DataFrame( + { + "id": [f"id_{i}"], + "values": [[float(i)]], + "sparse_values": [None], + "metadata": [None], + "blob": [None], + } + ) + path = docs_dir / f"{i:04d}.parquet" + table = pa.Table.from_pandas(df) + pq.write_table(table, str(path)) + + from fsspec.implementations.local import LocalFileSystem + + fs = LocalFileSystem() + + # Read with parallel processing disabled + cfg.Cache.max_parallel_downloads = 1 + result_df = DatasetFSReader._safe_read_from_path( + fs=fs, dataset_path=str(Path(tmpdir) / "test-dataset"), data_type="documents" + ) + + # Verify order + assert len(result_df) == 3 + assert result_df.iloc[0]["id"] == "id_0" + assert result_df.iloc[1]["id"] == "id_1" + assert result_df.iloc[2]["id"] == "id_2" + + def test_safe_read_from_path_with_parallel_workers(self): + """Test that parallel workers produce correct results.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create dataset directory structure + docs_dir = Path(tmpdir) / "test-dataset" / "documents" + docs_dir.mkdir(parents=True) + + # Create multiple parquet files + expected_ids = [] + for i in range(5): + df = pd.DataFrame( + { + "id": [f"id_{i}"], + "values": [[float(i)]], + "sparse_values": [None], + "metadata": [None], + "blob": [None], + } + ) + expected_ids.append(f"id_{i}") + path = docs_dir / f"{i:04d}.parquet" + table = pa.Table.from_pandas(df) + pq.write_table(table, str(path)) + + from fsspec.implementations.local import LocalFileSystem + + fs = LocalFileSystem() + + # Read with parallel processing enabled + cfg.Cache.max_parallel_downloads = 3 + result_df = DatasetFSReader._safe_read_from_path( + fs=fs, dataset_path=str(Path(tmpdir) / "test-dataset"), data_type="documents" + ) + + # Verify all data is present and in correct order + assert len(result_df) == 5 + result_ids = result_df["id"].tolist() + assert result_ids == expected_ids + + def test_safe_read_from_path_single_file_uses_serial(self): + """Test that single file datasets use serial processing.""" + with tempfile.TemporaryDirectory() as tmpdir: + docs_dir = Path(tmpdir) / "test-dataset" / "documents" + docs_dir.mkdir(parents=True) + + # Create single parquet file + df = pd.DataFrame( + { + "id": ["id_0"], + "values": [[0.1]], + "sparse_values": [None], + "metadata": [None], + "blob": [None], + } + ) + path = docs_dir / "0000.parquet" + table = pa.Table.from_pandas(df) + pq.write_table(table, str(path)) + + from fsspec.implementations.local import LocalFileSystem + + fs = LocalFileSystem() + + # Even with high max_workers, should use serial for single file + cfg.Cache.max_parallel_downloads = 10 + result_df = DatasetFSReader._safe_read_from_path( + fs=fs, dataset_path=str(Path(tmpdir) / "test-dataset"), data_type="documents" + ) + + assert len(result_df) == 1 + assert result_df.iloc[0]["id"] == "id_0" + + def test_safe_read_from_path_respects_max_workers_limit(self): + """Test that max_workers limit is respected.""" + with tempfile.TemporaryDirectory() as tmpdir: + docs_dir = Path(tmpdir) / "test-dataset" / "documents" + docs_dir.mkdir(parents=True) + + # Create 10 parquet files + for i in range(10): + df = pd.DataFrame( + { + "id": [f"id_{i}"], + "values": [[float(i)]], + "sparse_values": [None], + "metadata": [None], + "blob": [None], + } + ) + path = docs_dir / f"{i:04d}.parquet" + table = pa.Table.from_pandas(df) + pq.write_table(table, str(path)) + + from fsspec.implementations.local import LocalFileSystem + + fs = LocalFileSystem() + + # Set low max_workers + cfg.Cache.max_parallel_downloads = 2 + + result_df = DatasetFSReader._safe_read_from_path( + fs=fs, dataset_path=str(Path(tmpdir) / "test-dataset"), data_type="documents" + ) + + # Verify all files were processed + assert len(result_df) == 10 + # Verify order is maintained + for i in range(10): + assert result_df.iloc[i]["id"] == f"id_{i}" From fa021c1a2d795f6d653aeb4ed2ec08edda11b5ed Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Tue, 3 Feb 2026 19:45:55 +0000 Subject: [PATCH 3/4] style: apply ruff formatting to parallel download files Co-authored-by: jhamon --- pinecone_datasets/dataset_fsreader.py | 20 +++++++++++--------- tests/unit/test_parallel_downloads.py | 17 +++++++++++++---- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/pinecone_datasets/dataset_fsreader.py b/pinecone_datasets/dataset_fsreader.py index 1927474..fd8a8a6 100644 --- a/pinecone_datasets/dataset_fsreader.py +++ b/pinecone_datasets/dataset_fsreader.py @@ -81,13 +81,13 @@ def _download_and_read_parquet( ) -> Tuple[int, pd.DataFrame]: """ Download (if needed) and read a single parquet file. - + Args: path: Path to the parquet file fs: Filesystem object use_cache: Whether to use caching for this file protocol: Protocol prefix (gs:// or s3://) if applicable - + Returns: Tuple of (file_index, dataframe) where file_index is from the path """ @@ -103,7 +103,7 @@ def _download_and_read_parquet( else: # Read directly from filesystem piece = pq.read_pandas(path, filesystem=fs) - + df_piece = piece.to_pandas() # Extract index from path for proper ordering (handles paths like "documents/0000.parquet") try: @@ -112,7 +112,7 @@ def _download_and_read_parquet( except (ValueError, AttributeError): # If we can't extract an index, use hash of path for consistent ordering file_index = hash(path) - + return (file_index, df_piece) @staticmethod @@ -140,10 +140,12 @@ def _safe_read_from_path( # Collect all dataframes using parallel downloads num_files = len(read_path) - max_workers = min(Cache.max_parallel_downloads, num_files) if num_files > 1 else 1 - + max_workers = ( + min(Cache.max_parallel_downloads, num_files) if num_files > 1 else 1 + ) + dfs_with_index = [] - + if max_workers == 1: # Serial processing for single file or when max_workers=1 for path in tqdm(read_path, desc=f"Loading {data_type}"): @@ -165,7 +167,7 @@ def _safe_read_from_path( ): path for path in read_path } - + # Collect results as they complete with progress bar with tqdm(total=num_files, desc=f"Loading {data_type}") as pbar: for future in as_completed(future_to_path): @@ -178,7 +180,7 @@ def _safe_read_from_path( raise finally: pbar.update(1) - + # Sort by file index to maintain consistent ordering dfs_with_index.sort(key=lambda x: x[0]) dfs = [df for _, df in dfs_with_index] diff --git a/tests/unit/test_parallel_downloads.py b/tests/unit/test_parallel_downloads.py index 56183ef..6c53f15 100644 --- a/tests/unit/test_parallel_downloads.py +++ b/tests/unit/test_parallel_downloads.py @@ -1,6 +1,7 @@ """ Unit tests for parallel download functionality. """ + import os import tempfile from pathlib import Path @@ -140,7 +141,9 @@ def test_safe_read_from_path_sorts_by_file_index(self): # Read with parallel processing disabled cfg.Cache.max_parallel_downloads = 1 result_df = DatasetFSReader._safe_read_from_path( - fs=fs, dataset_path=str(Path(tmpdir) / "test-dataset"), data_type="documents" + fs=fs, + dataset_path=str(Path(tmpdir) / "test-dataset"), + data_type="documents", ) # Verify order @@ -180,7 +183,9 @@ def test_safe_read_from_path_with_parallel_workers(self): # Read with parallel processing enabled cfg.Cache.max_parallel_downloads = 3 result_df = DatasetFSReader._safe_read_from_path( - fs=fs, dataset_path=str(Path(tmpdir) / "test-dataset"), data_type="documents" + fs=fs, + dataset_path=str(Path(tmpdir) / "test-dataset"), + data_type="documents", ) # Verify all data is present and in correct order @@ -215,7 +220,9 @@ def test_safe_read_from_path_single_file_uses_serial(self): # Even with high max_workers, should use serial for single file cfg.Cache.max_parallel_downloads = 10 result_df = DatasetFSReader._safe_read_from_path( - fs=fs, dataset_path=str(Path(tmpdir) / "test-dataset"), data_type="documents" + fs=fs, + dataset_path=str(Path(tmpdir) / "test-dataset"), + data_type="documents", ) assert len(result_df) == 1 @@ -250,7 +257,9 @@ def test_safe_read_from_path_respects_max_workers_limit(self): cfg.Cache.max_parallel_downloads = 2 result_df = DatasetFSReader._safe_read_from_path( - fs=fs, dataset_path=str(Path(tmpdir) / "test-dataset"), data_type="documents" + fs=fs, + dataset_path=str(Path(tmpdir) / "test-dataset"), + data_type="documents", ) # Verify all files were processed From 2f65aa5cfb0be56ef588daa72b69d30c5fe1cb13 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Tue, 3 Feb 2026 19:47:59 +0000 Subject: [PATCH 4/4] fix: resolve linting issues - Remove unused imports (Mock, call, pytest) - Fix import ordering - Replace typing.Tuple with builtin tuple for Python 3.10+ compatibility - Update .gitignore to exclude __pycache__ directories Co-authored-by: jhamon --- .gitignore | 2 ++ pinecone_datasets/dataset_fsreader.py | 4 ++-- tests/unit/test_parallel_downloads.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 4e78d44..8ceaf9f 100644 --- a/.gitignore +++ b/.gitignore @@ -6,5 +6,7 @@ dist/ .mypy_cache/ scratchpad.ipynb .pycache/ +__pycache__/ +*.pyc .pytest_cache/ .coverage diff --git a/pinecone_datasets/dataset_fsreader.py b/pinecone_datasets/dataset_fsreader.py index fd8a8a6..6352747 100644 --- a/pinecone_datasets/dataset_fsreader.py +++ b/pinecone_datasets/dataset_fsreader.py @@ -3,7 +3,7 @@ import os import warnings from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Literal, Optional, Tuple +from typing import Literal, Optional import pandas as pd import pyarrow.parquet as pq @@ -78,7 +78,7 @@ def _download_and_read_parquet( fs: CloudOrLocalFS, use_cache: bool, protocol: Optional[str], - ) -> Tuple[int, pd.DataFrame]: + ) -> tuple[int, pd.DataFrame]: """ Download (if needed) and read a single parquet file. diff --git a/tests/unit/test_parallel_downloads.py b/tests/unit/test_parallel_downloads.py index 6c53f15..d891f74 100644 --- a/tests/unit/test_parallel_downloads.py +++ b/tests/unit/test_parallel_downloads.py @@ -5,12 +5,11 @@ import os import tempfile from pathlib import Path -from unittest.mock import Mock, patch, call +from unittest.mock import patch import pandas as pd import pyarrow as pa import pyarrow.parquet as pq -import pytest from pinecone_datasets import cfg from pinecone_datasets.dataset_fsreader import DatasetFSReader @@ -41,6 +40,7 @@ def test_parallel_download_environment_variable(self): with patch.dict(os.environ, {"PINECONE_DATASETS_MAX_PARALLEL_DOWNLOADS": "10"}): # Need to reload the module for env var to take effect import importlib + from pinecone_datasets import cfg as cfg_module importlib.reload(cfg_module)