diff --git a/.gitignore b/.gitignore
index 4e78d44..8ceaf9f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,5 +6,7 @@ dist/
 .mypy_cache/
 scratchpad.ipynb
 .pycache/
+__pycache__/
+*.pyc
 .pytest_cache/
 .coverage
diff --git a/pinecone_datasets/cfg.py b/pinecone_datasets/cfg.py
index 8fc8a39..fcd5d0c 100644
--- a/pinecone_datasets/cfg.py
+++ b/pinecone_datasets/cfg.py
@@ -16,6 +16,9 @@ class Cache:
         "1",
         "yes",
     )
+    max_parallel_downloads: int = int(
+        os.getenv("PINECONE_DATASETS_MAX_PARALLEL_DOWNLOADS", "4")
+    )
 
 
 class Schema:
diff --git a/pinecone_datasets/dataset_fsreader.py b/pinecone_datasets/dataset_fsreader.py
index 2bcf37e..6352747 100644
--- a/pinecone_datasets/dataset_fsreader.py
+++ b/pinecone_datasets/dataset_fsreader.py
@@ -2,12 +2,13 @@
 import logging
 import os
 import warnings
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Literal, Optional
 
 import pandas as pd
 import pyarrow.parquet as pq
 
-from .cfg import Schema
+from .cfg import Cache, Schema
 from .dataset_metadata import DatasetMetadata
 from .fs import CloudOrLocalFS, get_cached_path, is_cloud_path
 from .retry import create_cloud_storage_retry_decorator
@@ -71,6 +72,54 @@ def _does_datatype_exist(
     ) -> bool:
         return fs.exists(os.path.join(dataset_path, data_type))
 
+    @staticmethod
+    def _download_and_read_parquet(
+        path: str,
+        fs: CloudOrLocalFS,
+        use_cache: bool,
+        protocol: Optional[str],
+    ) -> tuple[int, pd.DataFrame]:
+        """
+        Download (if needed) and read a single parquet file.
+
+        Args:
+            path: Path to the parquet file
+            fs: Filesystem object
+            use_cache: Whether to use caching for this file
+            protocol: Protocol prefix (gs:// or s3://) if applicable
+
+        Returns:
+            Tuple of (file_index, dataframe). file_index is the integer
+            parsed from the file stem (e.g. 42 for "0042.parquet"); for
+            non-numeric names it is a process-independent integer derived
+            from the path bytes.
+        """
+        if use_cache and protocol:
+            # Reconstruct full URL if path doesn't have protocol
+            if not path.startswith(protocol):
+                full_path = f"{protocol}{path}"
+            else:
+                full_path = path
+            # Download to cache and read from local path
+            local_path = get_cached_path(full_path, fs)
+            piece = pq.read_pandas(local_path)
+        else:
+            # Read directly from filesystem
+            piece = pq.read_pandas(path, filesystem=fs)
+
+        df_piece = piece.to_pandas()
+        # Extract index from path for proper ordering (handles paths like "documents/0000.parquet")
+        try:
+            filename = os.path.basename(path)
+            file_index = int(os.path.splitext(filename)[0])
+        except (ValueError, AttributeError):
+            # hash() of a str is salted per process (PYTHONHASHSEED), so it
+            # cannot give a reproducible order. Derive the key from the
+            # path bytes instead: unique per path and stable across runs.
+            file_index = int.from_bytes(str(path).encode("utf-8"), "big")
+
+        return (file_index, df_piece)
+
     @staticmethod
     def _safe_read_from_path(
         fs: CloudOrLocalFS,
@@ -94,23 +143,56 @@ def _safe_read_from_path(
             elif dataset_path.startswith("https://s3.amazonaws.com/"):
                 protocol = "s3://"
 
-        # First, collect all the dataframes
-        dfs = []
-        for path in tqdm(read_path, desc=f"Loading {data_type}"):
-            if use_cache_for_dataset and protocol:
-                # Reconstruct full URL if path doesn't have protocol
-                if not path.startswith(protocol):
-                    full_path = f"{protocol}{path}"
-                else:
-                    full_path = path
-                # Download to cache and read from local path
-                local_path = get_cached_path(full_path, fs)
-                piece = pq.read_pandas(local_path)
-            else:
-                # Read directly from filesystem
-                piece = pq.read_pandas(path, filesystem=fs)
-            df_piece = piece.to_pandas()
-            dfs.append(df_piece)
+        # Collect all dataframes using parallel downloads
+        num_files = len(read_path)
+        # Clamp to at least one worker: a zero/negative setting (or an empty
+        # file list) would otherwise make ThreadPoolExecutor raise ValueError.
+        max_workers = max(1, min(Cache.max_parallel_downloads, num_files))
+
+        dfs_with_index = []
+
+        if max_workers == 1:
+            # Serial processing for single file or when max_workers=1
+            for path in tqdm(read_path, desc=f"Loading {data_type}"):
+                file_index, df_piece = DatasetFSReader._download_and_read_parquet(
+                    path, fs, use_cache_for_dataset, protocol
+                )
+                dfs_with_index.append((file_index, df_piece))
+        else:
+            # Parallel processing for multiple files
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                # Submit all download tasks
+                future_to_path = {
+                    executor.submit(
+                        DatasetFSReader._download_and_read_parquet,
+                        path,
+                        fs,
+                        use_cache_for_dataset,
+                        protocol,
+                    ): path
+                    for path in read_path
+                }
+
+                # Collect results as they complete with progress bar
+                with tqdm(total=num_files, desc=f"Loading {data_type}") as pbar:
+                    for future in as_completed(future_to_path):
+                        path = future_to_path[future]
+                        try:
+                            file_index, df_piece = future.result()
+                            dfs_with_index.append((file_index, df_piece))
+                        except Exception as e:
+                            logger.error(f"Failed to load {path}: {e}")
+                            # Drop queued downloads so the failure surfaces
+                            # without waiting for every remaining file.
+                            for pending in future_to_path:
+                                pending.cancel()
+                            raise
+                        finally:
+                            pbar.update(1)
+
+        # Sort by file index to maintain consistent ordering
+        dfs_with_index.sort(key=lambda x: x[0])
+        dfs = [df for _, df in dfs_with_index]
 
         if not dfs:
             raise ValueError(f"No parquet files found in {read_path_str}")
diff --git a/tests/unit/test_parallel_downloads.py b/tests/unit/test_parallel_downloads.py
new file mode 100644
index 0000000..d891f74
--- /dev/null
+++ b/tests/unit/test_parallel_downloads.py
@@ -0,0 +1,312 @@
+"""
+Unit tests for parallel download functionality.
+"""
+
+import importlib
+import os
+import tempfile
+from pathlib import Path
+from unittest.mock import patch
+
+import pandas as pd
+import pyarrow as pa
+import pyarrow.parquet as pq
+
+from pinecone_datasets import cfg
+from pinecone_datasets.dataset_fsreader import DatasetFSReader
+
+
+class TestParallelDownloads:
+    """Test parallel download functionality."""
+
+    def setup_method(self):
+        """Set up test fixtures."""
+        self.original_max_workers = cfg.Cache.max_parallel_downloads
+
+    def teardown_method(self):
+        """Restore original configuration."""
+        cfg.Cache.max_parallel_downloads = self.original_max_workers
+
+    def test_parallel_download_configuration(self):
+        """Test that max_parallel_downloads can be configured."""
+        # Test default value
+        assert cfg.Cache.max_parallel_downloads == 4
+
+        # Test setting via attribute
+        cfg.Cache.max_parallel_downloads = 8
+        assert cfg.Cache.max_parallel_downloads == 8
+
+    def test_parallel_download_environment_variable(self):
+        """Test that max_parallel_downloads can be set via environment variable."""
+        # Reloading rebinds every name in cfg (including Cache), while
+        # dataset_fsreader keeps the Cache class it imported. Snapshot the
+        # module namespace and restore it so later tests still configure
+        # the class the reader actually uses.
+        saved_namespace = dict(vars(cfg))
+        try:
+            with patch.dict(
+                os.environ, {"PINECONE_DATASETS_MAX_PARALLEL_DOWNLOADS": "10"}
+            ):
+                # Need to reload the module for env var to take effect
+                importlib.reload(cfg)
+                assert cfg.Cache.max_parallel_downloads == 10
+        finally:
+            vars(cfg).update(saved_namespace)
+
+    def test_download_and_read_parquet_extracts_file_index(self):
+        """Test that _download_and_read_parquet extracts file index from filename."""
+        # Create a temporary parquet file
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Create test data
+            df = pd.DataFrame(
+                {
+                    "id": ["1", "2"],
+                    "values": [[0.1, 0.2], [0.3, 0.4]],
+                    "sparse_values": [None, None],
+                    "metadata": [None, None],
+                    "blob": [None, None],
+                }
+            )
+
+            # Write to parquet with numbered filename
+            path = os.path.join(tmpdir, "0042.parquet")
+            table = pa.Table.from_pandas(df)
+            pq.write_table(table, path)
+
+            # Mock filesystem
+            from fsspec.implementations.local import LocalFileSystem
+
+            fs = LocalFileSystem()
+
+            # Test extraction
+            file_index, result_df = DatasetFSReader._download_and_read_parquet(
+                path=path, fs=fs, use_cache=False, protocol=None
+            )
+
+            assert file_index == 42
+            assert len(result_df) == 2
+
+    def test_download_and_read_parquet_handles_non_numeric_filenames(self):
+        """Test that _download_and_read_parquet handles non-numeric filenames."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Create test data
+            df = pd.DataFrame(
+                {
+                    "id": ["1", "2"],
+                    "values": [[0.1, 0.2], [0.3, 0.4]],
+                    "sparse_values": [None, None],
+                    "metadata": [None, None],
+                    "blob": [None, None],
+                }
+            )
+
+            # Write to parquet with non-numeric filename
+            path = os.path.join(tmpdir, "part-abc.parquet")
+            table = pa.Table.from_pandas(df)
+            pq.write_table(table, path)
+
+            from fsspec.implementations.local import LocalFileSystem
+
+            fs = LocalFileSystem()
+
+            # Test extraction - should use the path-derived fallback
+            file_index, result_df = DatasetFSReader._download_and_read_parquet(
+                path=path, fs=fs, use_cache=False, protocol=None
+            )
+
+            # The fallback must not depend on the per-process hash() salt
+            assert isinstance(file_index, int)
+            assert file_index == int.from_bytes(path.encode("utf-8"), "big")
+            assert len(result_df) == 2
+
+    def test_safe_read_from_path_sorts_by_file_index(self):
+        """Test that files are read in correct order."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Create dataset directory structure
+            docs_dir = Path(tmpdir) / "test-dataset" / "documents"
+            docs_dir.mkdir(parents=True)
+
+            # Create multiple parquet files with different indices
+            # Write them in reverse order to test sorting
+            for i in [2, 0, 1]:
+                df = pd.DataFrame(
+                    {
+                        "id": [f"id_{i}"],
+                        "values": [[float(i)]],
+                        "sparse_values": [None],
+                        "metadata": [None],
+                        "blob": [None],
+                    }
+                )
+                path = docs_dir / f"{i:04d}.parquet"
+                table = pa.Table.from_pandas(df)
+                pq.write_table(table, str(path))
+
+            from fsspec.implementations.local import LocalFileSystem
+
+            fs = LocalFileSystem()
+
+            # Read with parallel processing disabled
+            cfg.Cache.max_parallel_downloads = 1
+            result_df = DatasetFSReader._safe_read_from_path(
+                fs=fs,
+                dataset_path=str(Path(tmpdir) / "test-dataset"),
+                data_type="documents",
+            )
+
+            # Verify order
+            assert len(result_df) == 3
+            assert result_df.iloc[0]["id"] == "id_0"
+            assert result_df.iloc[1]["id"] == "id_1"
+            assert result_df.iloc[2]["id"] == "id_2"
+
+    def test_safe_read_from_path_with_parallel_workers(self):
+        """Test that parallel workers produce correct results."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Create dataset directory structure
+            docs_dir = Path(tmpdir) / "test-dataset" / "documents"
+            docs_dir.mkdir(parents=True)
+
+            # Create multiple parquet files
+            expected_ids = []
+            for i in range(5):
+                df = pd.DataFrame(
+                    {
+                        "id": [f"id_{i}"],
+                        "values": [[float(i)]],
+                        "sparse_values": [None],
+                        "metadata": [None],
+                        "blob": [None],
+                    }
+                )
+                expected_ids.append(f"id_{i}")
+                path = docs_dir / f"{i:04d}.parquet"
+                table = pa.Table.from_pandas(df)
+                pq.write_table(table, str(path))
+
+            from fsspec.implementations.local import LocalFileSystem
+
+            fs = LocalFileSystem()
+
+            # Read with parallel processing enabled
+            cfg.Cache.max_parallel_downloads = 3
+            result_df = DatasetFSReader._safe_read_from_path(
+                fs=fs,
+                dataset_path=str(Path(tmpdir) / "test-dataset"),
+                data_type="documents",
+            )
+
+            # Verify all data is present and in correct order
+            assert len(result_df) == 5
+            result_ids = result_df["id"].tolist()
+            assert result_ids == expected_ids
+
+    def test_safe_read_from_path_single_file_uses_serial(self):
+        """Test that single file datasets use serial processing."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            docs_dir = Path(tmpdir) / "test-dataset" / "documents"
+            docs_dir.mkdir(parents=True)
+
+            # Create single parquet file
+            df = pd.DataFrame(
+                {
+                    "id": ["id_0"],
+                    "values": [[0.1]],
+                    "sparse_values": [None],
+                    "metadata": [None],
+                    "blob": [None],
+                }
+            )
+            path = docs_dir / "0000.parquet"
+            table = pa.Table.from_pandas(df)
+            pq.write_table(table, str(path))
+
+            from fsspec.implementations.local import LocalFileSystem
+
+            fs = LocalFileSystem()
+
+            # Even with high max_workers, should use serial for single file
+            cfg.Cache.max_parallel_downloads = 10
+            result_df = DatasetFSReader._safe_read_from_path(
+                fs=fs,
+                dataset_path=str(Path(tmpdir) / "test-dataset"),
+                data_type="documents",
+            )
+
+            assert len(result_df) == 1
+            assert result_df.iloc[0]["id"] == "id_0"
+
+    def test_safe_read_from_path_respects_max_workers_limit(self):
+        """Test that max_workers limit is respected."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            docs_dir = Path(tmpdir) / "test-dataset" / "documents"
+            docs_dir.mkdir(parents=True)
+
+            # Create 10 parquet files
+            for i in range(10):
+                df = pd.DataFrame(
+                    {
+                        "id": [f"id_{i}"],
+                        "values": [[float(i)]],
+                        "sparse_values": [None],
+                        "metadata": [None],
+                        "blob": [None],
+                    }
+                )
+                path = docs_dir / f"{i:04d}.parquet"
+                table = pa.Table.from_pandas(df)
+                pq.write_table(table, str(path))
+
+            from fsspec.implementations.local import LocalFileSystem
+
+            fs = LocalFileSystem()
+
+            # Set low max_workers
+            cfg.Cache.max_parallel_downloads = 2
+
+            result_df = DatasetFSReader._safe_read_from_path(
+                fs=fs,
+                dataset_path=str(Path(tmpdir) / "test-dataset"),
+                data_type="documents",
+            )
+
+            # Verify all files were processed
+            assert len(result_df) == 10
+            # Verify order is maintained
+            for i in range(10):
+                assert result_df.iloc[i]["id"] == f"id_{i}"
+
+    def test_safe_read_from_path_handles_non_positive_max_workers(self):
+        """Test that a zero or negative worker setting falls back to serial."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            docs_dir = Path(tmpdir) / "test-dataset" / "documents"
+            docs_dir.mkdir(parents=True)
+
+            # Create 3 parquet files
+            for i in range(3):
+                df = pd.DataFrame(
+                    {
+                        "id": [f"id_{i}"],
+                        "values": [[float(i)]],
+                        "sparse_values": [None],
+                        "metadata": [None],
+                        "blob": [None],
+                    }
+                )
+                path = docs_dir / f"{i:04d}.parquet"
+                table = pa.Table.from_pandas(df)
+                pq.write_table(table, str(path))
+
+            from fsspec.implementations.local import LocalFileSystem
+
+            fs = LocalFileSystem()
+
+            for workers in (0, -1):
+                cfg.Cache.max_parallel_downloads = workers
+                result_df = DatasetFSReader._safe_read_from_path(
+                    fs=fs,
+                    dataset_path=str(Path(tmpdir) / "test-dataset"),
+                    data_type="documents",
+                )
+
+                assert result_df["id"].tolist() == ["id_0", "id_1", "id_2"]