Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,7 @@ dist/
.mypy_cache/
scratchpad.ipynb
.pycache/
__pycache__/
*.pyc
.pytest_cache/
.coverage
3 changes: 3 additions & 0 deletions pinecone_datasets/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ class Cache:
"1",
"yes",
)
max_parallel_downloads: int = int(
os.getenv("PINECONE_DATASETS_MAX_PARALLEL_DOWNLOADS", "4")
)


class Schema:
Expand Down
109 changes: 91 additions & 18 deletions pinecone_datasets/dataset_fsreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import logging
import os
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Literal, Optional

import pandas as pd
import pyarrow.parquet as pq

from .cfg import Schema
from .cfg import Cache, Schema
from .dataset_metadata import DatasetMetadata
from .fs import CloudOrLocalFS, get_cached_path, is_cloud_path
from .retry import create_cloud_storage_retry_decorator
Expand Down Expand Up @@ -71,6 +72,49 @@ def _does_datatype_exist(
) -> bool:
return fs.exists(os.path.join(dataset_path, data_type))

@staticmethod
def _download_and_read_parquet(
path: str,
fs: CloudOrLocalFS,
use_cache: bool,
protocol: Optional[str],
) -> tuple[int, pd.DataFrame]:
"""
Download (if needed) and read a single parquet file.

Args:
path: Path to the parquet file
fs: Filesystem object
use_cache: Whether to use caching for this file
protocol: Protocol prefix (gs:// or s3://) if applicable

Returns:
Tuple of (file_index, dataframe) where file_index is from the path
"""
if use_cache and protocol:
# Reconstruct full URL if path doesn't have protocol
if not path.startswith(protocol):
full_path = f"{protocol}{path}"
else:
full_path = path
# Download to cache and read from local path
local_path = get_cached_path(full_path, fs)
piece = pq.read_pandas(local_path)
else:
# Read directly from filesystem
piece = pq.read_pandas(path, filesystem=fs)

df_piece = piece.to_pandas()
# Extract index from path for proper ordering (handles paths like "documents/0000.parquet")
try:
filename = os.path.basename(path)
file_index = int(os.path.splitext(filename)[0])
except (ValueError, AttributeError):
# If we can't extract an index, use hash of path for consistent ordering
file_index = hash(path)

return (file_index, df_piece)

@staticmethod
def _safe_read_from_path(
fs: CloudOrLocalFS,
Expand All @@ -94,23 +138,52 @@ def _safe_read_from_path(
elif dataset_path.startswith("https://s3.amazonaws.com/"):
protocol = "s3://"

# First, collect all the dataframes
dfs = []
for path in tqdm(read_path, desc=f"Loading {data_type}"):
if use_cache_for_dataset and protocol:
# Reconstruct full URL if path doesn't have protocol
if not path.startswith(protocol):
full_path = f"{protocol}{path}"
else:
full_path = path
# Download to cache and read from local path
local_path = get_cached_path(full_path, fs)
piece = pq.read_pandas(local_path)
else:
# Read directly from filesystem
piece = pq.read_pandas(path, filesystem=fs)
df_piece = piece.to_pandas()
dfs.append(df_piece)
# Collect all dataframes using parallel downloads
num_files = len(read_path)
max_workers = (
min(Cache.max_parallel_downloads, num_files) if num_files > 1 else 1
)

dfs_with_index = []

if max_workers == 1:
# Serial processing for single file or when max_workers=1
for path in tqdm(read_path, desc=f"Loading {data_type}"):
file_index, df_piece = DatasetFSReader._download_and_read_parquet(
path, fs, use_cache_for_dataset, protocol
)
dfs_with_index.append((file_index, df_piece))
else:
# Parallel processing for multiple files
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all download tasks
future_to_path = {
executor.submit(
DatasetFSReader._download_and_read_parquet,
path,
fs,
use_cache_for_dataset,
protocol,
): path
for path in read_path
}

# Collect results as they complete with progress bar
with tqdm(total=num_files, desc=f"Loading {data_type}") as pbar:
for future in as_completed(future_to_path):
path = future_to_path[future]
try:
file_index, df_piece = future.result()
dfs_with_index.append((file_index, df_piece))
except Exception as e:
logger.error(f"Failed to load {path}: {e}")
raise
finally:
pbar.update(1)

# Sort by file index to maintain consistent ordering
dfs_with_index.sort(key=lambda x: x[0])
dfs = [df for _, df in dfs_with_index]

if not dfs:
raise ValueError(f"No parquet files found in {read_path_str}")
Expand Down
Loading