-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild_index.py
More file actions
146 lines (123 loc) · 5.17 KB
/
Copy pathbuild_index.py
File metadata and controls
146 lines (123 loc) · 5.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""Build the Chroma schema index for one (or all) registered databases.
Live tool — calls Mistral `mistral-embed` for vectors. Idempotent: re-runs
upsert chunks under the same `chunk_id` (db::table), so vectors get refreshed
in place and stale chunks for renamed tables are NOT auto-pruned (run with
``--reset`` to clear the collection first if you have schema deletions).
The default ``--sample-size`` is imported from ``PipelineConfig.primary_sample_size``
so the index is built with the same density runtime expects. Pass an explicit
value only if you want to rebuild for a non-default runtime configuration.
Usage:
uv run python scripts/build_index.py --db chinook
uv run python scripts/build_index.py --db all --persist chroma_data
uv run python scripts/build_index.py --db chinook --reset
"""
from __future__ import annotations
import argparse
import sys
from dataclasses import fields
from pathlib import Path
import chromadb
from nl_sql.agent.graph import PipelineConfig
from nl_sql.config import get_settings
from nl_sql.db.registry import get_default_registry
from nl_sql.llm.cache import CachingEmbeddingProvider
from nl_sql.llm.providers.base import EmbeddingProvider
from nl_sql.llm.providers.mistral import MistralProvider
from nl_sql.schema_index.chunker import to_chunks
from nl_sql.schema_index.indexer import SCHEMA_COLLECTION, SchemaIndex
from nl_sql.schema_index.introspector import introspect
def _runtime_sample_size_default() -> int:
"""Read `PipelineConfig.primary_sample_size` default without constructing
the dataclass (it requires live providers/registry we don't have here)."""
for field_ in fields(PipelineConfig):
if field_.name == "primary_sample_size":
default = field_.default
if isinstance(default, int):
return default
raise RuntimeError("PipelineConfig.primary_sample_size default missing")
# Source of truth for the sample density baked into Chroma chunks.
# Runtime expects this to equal `PipelineConfig.primary_sample_size`; the
# sample-mixture appendix breaks if the index is built with more samples
# than runtime advertises, so the value is read from the dataclass default
# rather than hard-coded here.
DEFAULT_SAMPLE_SIZE: int = _runtime_sample_size_default()
def build_for_db(idx: SchemaIndex, db_id: str, *, sample_size: int = DEFAULT_SAMPLE_SIZE) -> int:
    """Introspect one registered database, chunk its schema, and index it.

    Args:
        idx: Schema index that embeds and upserts the chunks.
        db_id: Id of the database in the default registry.
        sample_size: Top-K sample values per column passed to ``introspect``.

    Returns:
        The chunk count reported by ``SchemaIndex.index_schema``.
    """
    registry = get_default_registry()
    spec = registry.get(db_id)
    print(f"[introspect] {db_id} ({spec.url})")

    engine = spec.make_engine()
    try:
        tables = introspect(engine, sample_size=sample_size)
    finally:
        # The engine is only needed for introspection; release it before the
        # (slow) embedding step instead of keeping it open until process exit
        # (with ``--db all`` one engine per database would otherwise pile up).
        # NOTE(review): assumes a SQLAlchemy-style engine exposing
        # ``dispose()`` — confirm; objects without it are left untouched.
        dispose = getattr(engine, "dispose", None)
        if callable(dispose):
            dispose()

    print(f"[chunk] {len(tables)} tables → chunks")
    chunks = to_chunks(tables, db_id=db_id)
    print(f"[index] embedding + upserting {len(chunks)} chunks")
    n = idx.index_schema(chunks)
    print(f"[done] {db_id}: {n} chunks indexed")
    return n
def _non_negative_int(text: str) -> int:
    """Parse an argparse option value as an integer that must be >= 0.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not an integer or is
            negative; argparse reports this as a usage error (exit status 2).
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 0:
        raise argparse.ArgumentTypeError(f"must be >= 0, got {value}")
    return value


def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script.

    The module docstring is used as the description.
    ``RawDescriptionHelpFormatter`` preserves its line breaks, so ``--help``
    does not re-wrap the "Usage:" examples into a single paragraph.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--db",
        required=True,
        help="Database id (e.g. 'chinook', 'bird_california_schools') or 'all'.",
    )
    parser.add_argument(
        "--persist",
        default="chroma_data",
        help="Chroma persist directory (default: chroma_data/)",
    )
    parser.add_argument(
        "--sample-size",
        # Negative "top-K" values are meaningless; reject them at parse time.
        type=_non_negative_int,
        default=DEFAULT_SAMPLE_SIZE,
        help=(
            "Top-K sample values per column to bake into each chunk "
            f"(default: {DEFAULT_SAMPLE_SIZE} = PipelineConfig.primary_sample_size). "
            "Keep aligned with runtime or the sample-mixture appendix breaks."
        ),
    )
    parser.add_argument(
        "--reset",
        action="store_true",
        help="Drop the schema_chunks collection before indexing.",
    )
    parser.add_argument(
        "--no-cache",
        action="store_true",
        help="Disable diskcache wrapper around the embedding provider.",
    )
    return parser
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: optionally reset the collection, then index targets.

    All non-destructive setup runs first: the ``--db`` targets are resolved
    and the embedder is built. Only then is the Chroma collection touched, so
    a bad database id or a failing provider setup cannot leave the index
    dropped by ``--reset`` but not rebuilt.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        0 once every target has been indexed (errors propagate as exceptions).
    """
    args = build_parser().parse_args(argv)
    settings = get_settings()

    registry = get_default_registry()
    targets = registry.ids() if args.db == "all" else [args.db]
    # Same lookup build_for_db performs; doing it up front makes an unknown
    # id fail before ``--reset`` deletes anything.
    # NOTE(review): assumes registry.get is a cheap, side-effect-free lookup.
    for db_id in targets:
        registry.get(db_id)

    raw_embedder = MistralProvider(
        api_key=settings.mistral_api_key,
        gen_model=settings.mistral_gen_model,
        embed_model=settings.mistral_embed_model,
        base_url=settings.mistral_base_url,
    )
    embedder: EmbeddingProvider = (
        raw_embedder
        if args.no_cache
        else CachingEmbeddingProvider(
            raw_embedder,
            cache_dir=settings.llm_cache_dir,
            size_limit_gb=settings.llm_cache_size_limit_gb,
        )
    )

    persist = Path(args.persist)
    persist.mkdir(parents=True, exist_ok=True)
    client = chromadb.PersistentClient(path=str(persist))

    if args.reset:
        # --reset drops the whole collection (every database's chunks), not
        # only the chunks of the --db target(s) rebuilt below.
        try:
            client.delete_collection(SCHEMA_COLLECTION)
            print(f"[reset] dropped {SCHEMA_COLLECTION}")
        except Exception as exc:
            # Best-effort: deleting a collection that does not exist raises,
            # and the exception type varies across chromadb versions.
            print(f"[reset] no existing {SCHEMA_COLLECTION} to drop ({exc})")

    idx = SchemaIndex(persist_dir=persist, embedder=embedder, client=client)

    total = 0
    for db_id in targets:
        total += build_for_db(idx, db_id, sample_size=args.sample_size)
    print(f"[summary] indexed {total} chunks across {len(targets)} db(s)")
    return 0
if __name__ == "__main__":
    # Forward main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)