Populate DB Chromadb

2026-01-24 07:52:48 +00:00
parent d145f7e94c
commit 4298368b63
10 changed files with 279 additions and 48 deletions

.gitignore (vendored): 1 addition

@@ -66,3 +66,4 @@ venv.bak/
 # Tauri
 bun.lock
+dataset/


@@ -1,12 +1,14 @@
 flask
-google-genai
 gunicorn
-pymongo
 ultralytics
 opencv-python-headless
 transformers
 torch
 pandas
 pypdf
+openpyxl
 python-dotenv
 flask-cors
+ollama
+chromadb-client
+pymongo


@@ -1,8 +1,8 @@
import os import os
import sys import sys
import argparse
from pathlib import Path from pathlib import Path
# Add backend directory to path so we can import src
sys.path.append(os.path.join(os.path.dirname(__file__), '..')) sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -10,21 +10,23 @@ load_dotenv()
 from src.rag.ingest import process_file
 from src.rag.store import ingest_documents
-from src.mongo.vector_store import is_file_processed, log_processed_file
+from src.mongo.metadata import is_file_processed, log_processed_file


-def populate_from_dataset(dataset_dir):
+def populate_from_dataset(dataset_dir, category=None):
     dataset_path = Path(dataset_dir)
     if not dataset_path.exists():
         print(f"Dataset directory not found: {dataset_dir}")
         return

     print(f"Scanning {dataset_dir}...")
+    if category:
+        print(f"Category: {category}")

     total_chunks = 0
     files_processed = 0

     for file_path in dataset_path.glob('*'):
-        if file_path.is_file() and file_path.suffix.lower() in ['.csv', '.pdf']:
+        if file_path.is_file() and file_path.suffix.lower() in ['.csv', '.pdf', '.txt', '.xlsx']:
             if is_file_processed(file_path.name):
                 print(f"Skipping {file_path.name} (already processed)")
                 continue
@@ -33,10 +35,10 @@ def populate_from_dataset(dataset_dir):
             try:
                 chunks = process_file(str(file_path))
                 if chunks:
-                    count = ingest_documents(chunks)
+                    count = ingest_documents(chunks, source_file=file_path.name, category=category)
                     print(f" Ingested {count} chunks.")
                     if count > 0:
-                        log_processed_file(file_path.name)
+                        log_processed_file(file_path.name, category=category, chunk_count=count)
                         total_chunks += count
                         files_processed += 1
                 else:
@@ -47,6 +49,14 @@ def populate_from_dataset(dataset_dir):
     print(f"\nFinished! Processed {files_processed} files. Total chunks ingested: {total_chunks}")

 if __name__ == "__main__":
-    # Assuming run from backend/
-    dataset_dir = os.path.join(os.path.dirname(__file__), '../dataset')
-    populate_from_dataset(dataset_dir)
+    parser = argparse.ArgumentParser(description="Populate vector database from dataset files")
+    parser.add_argument("--category", "-c", type=str, help="Category to assign to ingested documents")
+    parser.add_argument("--dir", "-d", type=str, default=None, help="Dataset directory path")
+    args = parser.parse_args()
+
+    if args.dir:
+        dataset_dir = args.dir
+    else:
+        dataset_dir = os.path.join(os.path.dirname(__file__), '../dataset')
+
+    populate_from_dataset(dataset_dir, category=args.category)
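With argparse wired in, the populate script can be driven from the command line. A minimal sketch of the intended usage follows; the script's filename is not visible in this diff, so the module name below is hypothetical.

# Hypothetical invocation (script name assumed, not confirmed by the diff):
#   python populate_db.py --dir ./dataset --category nutrition
# Equivalent direct call:
from populate_db import populate_from_dataset  # hypothetical module name

populate_from_dataset("./dataset", category="nutrition")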


@@ -0,0 +1,69 @@
import chromadb

CHROMA_HOST = "http://chroma.sirblob.co"
COLLECTION_NAME = "rag_documents"

_client = None


def get_chroma_client():
    global _client
    if _client is None:
        _client = chromadb.HttpClient(host=CHROMA_HOST)
    return _client


def get_collection(collection_name=COLLECTION_NAME):
    client = get_chroma_client()
    return client.get_or_create_collection(name=collection_name)


def insert_documents(texts, embeddings, collection_name=COLLECTION_NAME, metadata_list=None):
    collection = get_collection(collection_name)
    ids = [f"doc_{i}_{hash(text)}" for i, text in enumerate(texts)]
    if metadata_list:
        collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=metadata_list
        )
    else:
        collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=texts
        )
    return len(texts)


def search_documents(query_embedding, collection_name=COLLECTION_NAME, num_results=5, filter_metadata=None):
    collection = get_collection(collection_name)
    query_params = {
        "query_embeddings": [query_embedding],
        "n_results": num_results
    }
    if filter_metadata:
        query_params["where"] = filter_metadata
    results = collection.query(**query_params)

    output = []
    if results and results["documents"]:
        for i, doc in enumerate(results["documents"][0]):
            score = results["distances"][0][i] if "distances" in results else None
            output.append({
                "text": doc,
                "score": score
            })
    return output


def delete_documents_by_source(source_file, collection_name=COLLECTION_NAME):
    collection = get_collection(collection_name)
    results = collection.get(where={"source": source_file})
    if results["ids"]:
        collection.delete(ids=results["ids"])
        return len(results["ids"])
    return 0
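A minimal round-trip sketch of the new ChromaDB helpers, assuming the Chroma host above is reachable. The dummy 4-dimensional embeddings and the "smoke_test" collection name are invented here so the real "rag_documents" collection is left untouched.

from src.chroma.vector_store import insert_documents, search_documents, delete_documents_by_source

texts = ["chunk one", "chunk two"]
fake_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.2, 0.1, 0.0, 0.3]]  # dummy vectors
metas = [{"source": "demo.txt", "category": "demo"}, {"source": "demo.txt", "category": "demo"}]

insert_documents(texts, fake_embeddings, collection_name="smoke_test", metadata_list=metas)

hits = search_documents([0.1, 0.2, 0.3, 0.4], collection_name="smoke_test",
                        num_results=2, filter_metadata={"category": "demo"})
for hit in hits:
    print(hit["score"], hit["text"])

# Clean up everything ingested from the dummy source file.
deleted = delete_documents_by_source("demo.txt", collection_name="smoke_test")
print(f"Removed {deleted} documents")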


@@ -0,0 +1,62 @@
from .connection import get_mongo_client
from datetime import datetime

DB_NAME = "hoya_metadata"

def get_datasets_collection():
    client = get_mongo_client()
    db = client.get_database(DB_NAME)
    return db["datasets"]

def get_categories_collection():
    client = get_mongo_client()
    db = client.get_database(DB_NAME)
    return db["categories"]

def is_file_processed(filename):
    collection = get_datasets_collection()
    return collection.find_one({"filename": filename}) is not None

def log_processed_file(filename, category=None, chunk_count=0):
    collection = get_datasets_collection()
    doc = {
        "filename": filename,
        "category": category,
        "chunk_count": chunk_count,
        "processed_at": datetime.utcnow(),
        "status": "processed"
    }
    collection.insert_one(doc)

def get_all_datasets():
    collection = get_datasets_collection()
    return list(collection.find({}, {"_id": 0}))

def get_datasets_by_category(category):
    collection = get_datasets_collection()
    return list(collection.find({"category": category}, {"_id": 0}))

def delete_dataset_record(filename):
    collection = get_datasets_collection()
    result = collection.delete_one({"filename": filename})
    return result.deleted_count > 0

def create_category(name, description=""):
    collection = get_categories_collection()
    if collection.find_one({"name": name}):
        return False
    collection.insert_one({
        "name": name,
        "description": description,
        "created_at": datetime.utcnow()
    })
    return True

def get_all_categories():
    collection = get_categories_collection()
    return list(collection.find({}, {"_id": 0}))

def delete_category(name):
    collection = get_categories_collection()
    result = collection.delete_one({"name": name})
    return result.deleted_count > 0
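The metadata module keeps the processed-file bookkeeping in MongoDB while the vectors live in Chroma. A short usage sketch; the filenames and the category are illustrative only.

from src.mongo.metadata import (
    create_category,
    is_file_processed,
    log_processed_file,
    get_datasets_by_category,
)

# Register a category, then record an ingested file against it.
create_category("nutrition", description="Diet and meal-plan documents")

if not is_file_processed("meal_plans.xlsx"):
    log_processed_file("meal_plans.xlsx", category="nutrition", chunk_count=42)

for record in get_datasets_by_category("nutrition"):
    print(record["filename"], record["chunk_count"], record["processed_at"])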


@@ -46,4 +46,4 @@ def log_processed_file(filename, log_collection="ingested_files", db_name="vecto
client = get_mongo_client() client = get_mongo_client()
db = client.get_database(db_name) db = client.get_database(db_name)
collection = db[log_collection] collection = db[log_collection]
collection.insert_one({"filename": filename, "processed_at": 1}) # keeping it simple collection.insert_one({"filename": filename, "processed_at": 1})


@@ -1,26 +1,32 @@
-from google import genai
+import ollama
 import os

-def get_embedding(text, model="gemini-embedding-001"):
-    api_key = os.environ.get("GOOGLE_API_KEY")
-    if not api_key:
-        raise ValueError("GOOGLE_API_KEY environment variable not set")
-    client = genai.Client(api_key=api_key)
-    result = client.models.embed_content(
-        model=model,
-        contents=text
-    )
-    return result.embeddings[0].values
-
-def get_embeddings_batch(texts, model="gemini-embedding-001"):
-    api_key = os.environ.get("GOOGLE_API_KEY")
-    if not api_key:
-        raise ValueError("GOOGLE_API_KEY environment variable not set")
-    client = genai.Client(api_key=api_key)
-    result = client.models.embed_content(
-        model=model,
-        contents=texts
-    )
-    return [emb.values for emb in result.embeddings]
+client = ollama.Client(host="https://ollama.sirblob.co")
+DEFAULT_MODEL = "nomic-embed-text:latest"
+
+def get_embedding(text, model=DEFAULT_MODEL):
+    try:
+        response = client.embeddings(model=model, prompt=text)
+        return response["embedding"]
+    except Exception as e:
+        print(f"Error getting embedding from Ollama: {e}")
+        raise e
+
+def get_embeddings_batch(texts, model=DEFAULT_MODEL, batch_size=50):
+    all_embeddings = []
+
+    for i in range(0, len(texts), batch_size):
+        batch = texts[i:i + batch_size]
+        try:
+            response = client.embed(model=model, input=batch)
+
+            if "embeddings" in response:
+                all_embeddings.extend(response["embeddings"])
+            else:
+                raise ValueError("Unexpected response format from client.embed")
+        except Exception as e:
+            print(f"Error embedding batch {i}-{i+batch_size}: {e}")
+            raise e
+
+    return all_embeddings
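A quick smoke test for the Ollama-backed helpers, assuming the host above is reachable and has nomic-embed-text pulled. That model normally returns 768-dimensional vectors, and the Chroma collection will be locked to whatever dimensionality is ingested first.

from src.rag.embeddings import get_embedding, get_embeddings_batch

single = get_embedding("How much protein is in one egg?")
batch = get_embeddings_batch(["row one", "row two", "row three"])

# Expect matching dimensionality for single and batch results.
print(len(single), len(batch), len(batch[0]))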


@@ -3,6 +3,40 @@ from pypdf import PdfReader
 import io
 import os

+def chunk_text(text, target_length=2000, overlap=100):
+    if not text:
+        return []
+
+    chunks = []
+    paragraphs = text.split('\n\n')
+    current_chunk = ""
+
+    for para in paragraphs:
+        if len(current_chunk) + len(para) > target_length:
+            if current_chunk:
+                chunks.append(current_chunk.strip())
+
+            if len(para) > target_length:
+                start = 0
+                while start < len(para):
+                    end = start + target_length
+                    chunks.append(para[start:end].strip())
+                    start += (target_length - overlap)
+                current_chunk = ""
+            else:
+                current_chunk = para
+        else:
+            if current_chunk:
+                current_chunk += "\n\n" + para
+            else:
+                current_chunk = para
+
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
+
 def load_csv(file_path):
     df = pd.read_csv(file_path)
     return df.apply(lambda x: ' | '.join(x.astype(str)), axis=1).tolist()
@@ -13,14 +47,52 @@ def load_pdf(file_path):
     for page in reader.pages:
         text = page.extract_text()
         if text:
-            text_chunks.append(text)
+            if len(text) > 4000:
+                text_chunks.extend(chunk_text(text))
+            else:
+                text_chunks.append(text)
     return text_chunks

+def load_txt(file_path):
+    with open(file_path, 'r', encoding='utf-8') as f:
+        content = f.read()
+    return chunk_text(content)
+
+def load_xlsx(file_path):
+    all_rows = []
+    try:
+        sheets = pd.read_excel(file_path, sheet_name=None)
+    except Exception as e:
+        raise ValueError(f"Pandas read_excel failed: {e}")
+
+    for sheet_name, df in sheets.items():
+        if df.empty:
+            continue
+        df = df.fillna("")
+        for row in df.values:
+            row_items = [str(x) for x in row if str(x).strip() != ""]
+            if row_items:
+                row_str = f"Sheet: {str(sheet_name)} | " + " | ".join(row_items)
+                if len(row_str) > 8000:
+                    all_rows.extend(chunk_text(row_str))
+                else:
+                    all_rows.append(row_str)
+    return all_rows
+
 def process_file(file_path):
     ext = os.path.splitext(file_path)[1].lower()
     if ext == '.csv':
         return load_csv(file_path)
     elif ext == '.pdf':
         return load_pdf(file_path)
+    elif ext == '.txt':
+        return load_txt(file_path)
+    elif ext == '.xlsx':
+        return load_xlsx(file_path)
     else:
         raise ValueError(f"Unsupported file type: {ext}")


@@ -1,18 +1,27 @@
 from .embeddings import get_embeddings_batch, get_embedding
-from ..mongo.vector_store import insert_rag_documents, search_rag_documents
+from ..chroma.vector_store import insert_documents, search_documents

-def ingest_documents(text_chunks, collection_name="rag_documents"):
+def ingest_documents(text_chunks, collection_name="rag_documents", source_file=None, category=None):
     embeddings = get_embeddings_batch(text_chunks)
-    documents = []
-    for text, embedding in zip(text_chunks, embeddings):
-        documents.append({
-            "text": text,
-            "embedding": embedding
-        })
-    return insert_rag_documents(documents, collection_name=collection_name)
+    metadata_list = None
+    if source_file or category:
+        metadata_list = []
+        for _ in text_chunks:
+            meta = {}
+            if source_file:
+                meta["source"] = source_file
+            if category:
+                meta["category"] = category
+            metadata_list.append(meta)
+    return insert_documents(text_chunks, embeddings, collection_name=collection_name, metadata_list=metadata_list)

-def vector_search(query_text, collection_name="rag_documents", num_results=5):
+def vector_search(query_text, collection_name="rag_documents", num_results=5, category=None):
     query_embedding = get_embedding(query_text)
-    return search_rag_documents(query_embedding, collection_name=collection_name, num_results=num_results)
+    filter_metadata = None
+    if category:
+        filter_metadata = {"category": category}
+    return search_documents(query_embedding, collection_name=collection_name, num_results=num_results, filter_metadata=filter_metadata)
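Putting the pieces together, a hedged end-to-end sketch of ingest plus category-filtered retrieval; the chunks, source filename, and category below are invented for illustration.

from src.rag.store import ingest_documents, vector_search

chunks = [
    "Bench press: 3 sets of 8 reps at a moderate weight.",
    "Deadlifts primarily target the posterior chain.",
]
ingest_documents(chunks, source_file="demo_notes.txt", category="workouts")

# Only documents whose metadata matches the category are considered.
for hit in vector_search("chest exercises", num_results=3, category="workouts"):
    print(hit["score"], hit["text"])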