forked from zilliztech/VectorDBBench
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path: training_script_ivfflat_L2.py
More file actions
88 lines (65 loc) · 2.48 KB
/
training_script_ivfflat_L2.py
File metadata and controls
88 lines (65 loc) · 2.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import faiss
import numpy as np
from datasets import load_dataset
import itertools
import json
import pickle
# --- IVFFlat index configuration ---
d = 768  # Cohere embedding dimension
nlist = 1024  # number of IVF clusters (coarse centroids)
train_size = 298000  # number of vectors used for training
index_id = "cohere_wiki_ivfflat"  # unique row identifier for MySQL

# Echo the configuration so runs are self-describing in the log.
print(
    "=== IVFFlat Index Configuration ===\n"
    f"Dimension (d): {d}\n"
    f"Number of clusters (nlist): {nlist}\n"
    f"Training size: {train_size}\n"
    f"Index ID: {index_id}"
)
# --- Load training vectors from the accumulated-embeddings pickle ---
# NOTE(review): pickle.load executes arbitrary code on untrusted input;
# this file is assumed to be locally produced — confirm before pointing
# it at data from elsewhere.
try:
    with open("accumulated_cohere_embeddings.pkl", 'rb') as f:
        embedding_data = pickle.load(f)
except (OSError, pickle.UnpicklingError) as e:
    # Fail fast: the original code only printed the error and fell through,
    # which later crashed with a confusing NameError on `train_vectors`.
    raise SystemExit(f"Error loading training vectors from file: {e}")

# Convert to a float32 matrix (same format as the original Cohere loading);
# faiss expects float32 input.
train_vectors = np.stack(embedding_data['embeddings'], axis=0).astype('float32')
print(f"Loaded {len(train_vectors)} training vectors")
print(f"Training vectors shape: {train_vectors.shape}")
print(f"First vector sample: {train_vectors[0][:10]}...")  # Show first 10 dimensions
# Do NOT normalize for L2 — raw vectors are intended for this metric.
print("Using raw vectors for L2 distance")
# ✅ Create quantizer and index for L2
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
print("Training Started")
# Train on subset (100K normalized vectors)
index.train(train_vectors)
# Extract index data for MySQL
print("Extracting index data...")
centroids = index.quantizer.reconstruct_n(0, nlist)
print(f"Coarse centroids shape: {centroids.shape}")
# Sanity check
assert centroids.shape == (nlist, d), f"Centroids shape mismatch: {centroids.shape}"
print("✓ Shape verification passed!")
# Compute norms for debug
centroid_norms = np.linalg.norm(centroids, axis=1)
print(f"Centroid norms - min: {centroid_norms.min():.6f}, max: {centroid_norms.max():.6f}, mean: {centroid_norms.mean():.6f}")
# --- Emit SQL that populates the auxiliary VECTORDB_DATA table ---
# One metadata row, then one row per coarse centroid.
metadata_sql = f"""
INSERT INTO VECTORDB_DATA VALUES (
'{index_id}', 'metadata', 0,
JSON_OBJECT('version', 1, 'nlist', {nlist})
);
"""
# NOTE(review): values are spliced into the SQL text directly; index_id is
# a local constant here, but it must stay free of quotes/SQL metacharacters.
quantizer_sqls = [
    f"INSERT INTO VECTORDB_DATA VALUES ("
    f"'{index_id}', 'quantizer', {i}, '{json.dumps(centroid.tolist())}'"
    f");"
    for i, centroid in enumerate(centroids)  # shape check above guarantees nlist rows
]

# Combine all SQL into one script.
full_sql = (
    metadata_sql + "\n" +
    "\n".join(quantizer_sqls)
)

# Save to file; explicit encoding keeps output stable across platforms.
with open("cohere_wiki_ivfflat_l2.sql", "w", encoding="utf-8") as f:
    f.write(full_sql)
print("SQL for auxiliary table saved to cohere_wiki_ivfflat_l2.sql")
print(f"Total centroids: {nlist}")