From 4298368b63e20f668ccec20f6ba0598a6e82ea30 Mon Sep 17 00:00:00 2001
From: default
Date: Sat, 24 Jan 2026 07:52:48 +0000
Subject: [PATCH] Populate DB with ChromaDB

---
 .gitignore                         |  3 +-
 backend/requirements.txt           |  6 ++-
 backend/scripts/populate_db.py     | 28 +++++++----
 backend/src/chroma/__init__.py     |  0
 backend/src/chroma/vector_store.py | 70 +++++++++++++++++++++++++++++
 backend/src/mongo/metadata.py      | 62 +++++++++++++++++++++++++
 backend/src/mongo/vector_store.py  |  2 +-
 backend/src/rag/embeddings.py      | 52 +++++++++++----------
 backend/src/rag/ingest.py          | 74 +++++++++++++++++++++++++++++-
 backend/src/rag/store.py           | 31 ++++++++-----
 10 files changed, 280 insertions(+), 48 deletions(-)
 create mode 100644 backend/src/chroma/__init__.py
 create mode 100644 backend/src/chroma/vector_store.py
 create mode 100644 backend/src/mongo/metadata.py

diff --git a/.gitignore b/.gitignore
index eb41763..a43d6ac 100644
--- a/.gitignore
+++ b/.gitignore
@@ -65,4 +65,5 @@
 venv.bak/
 
 # Tauri
-bun.lock
\ No newline at end of file
+bun.lock
+dataset/
\ No newline at end of file
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 6f0c0e3..2ece870 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -1,12 +1,14 @@
 flask
-google-genai
 gunicorn
-pymongo
 ultralytics
 opencv-python-headless
 transformers
 torch
 pandas
 pypdf
+openpyxl
 python-dotenv
 flask-cors
+ollama
+chromadb-client
+pymongo
diff --git a/backend/scripts/populate_db.py b/backend/scripts/populate_db.py
index 72d40f2..3865dcd 100644
--- a/backend/scripts/populate_db.py
+++ b/backend/scripts/populate_db.py
@@ -1,8 +1,8 @@
 import os
 import sys
+import argparse
 from pathlib import Path
 
-# Add backend directory to path so we can import src
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
 
 from dotenv import load_dotenv
@@ -10,21 +10,23 @@
 load_dotenv()
 
 from src.rag.ingest import process_file
 from src.rag.store import ingest_documents
-from src.mongo.vector_store import is_file_processed, log_processed_file
+from src.mongo.metadata import is_file_processed, log_processed_file
 
-def populate_from_dataset(dataset_dir):
+def populate_from_dataset(dataset_dir, category=None):
     dataset_path = Path(dataset_dir)
     if not dataset_path.exists():
         print(f"Dataset directory not found: {dataset_dir}")
         return
 
     print(f"Scanning {dataset_dir}...")
+    if category:
+        print(f"Category: {category}")
     total_chunks = 0
     files_processed = 0
 
     for file_path in dataset_path.glob('*'):
-        if file_path.is_file() and file_path.suffix.lower() in ['.csv', '.pdf']:
+        if file_path.is_file() and file_path.suffix.lower() in ['.csv', '.pdf', '.txt', '.xlsx']:
             if is_file_processed(file_path.name):
                 print(f"Skipping {file_path.name} (already processed)")
                 continue
@@ -33,10 +35,10 @@
             try:
                 chunks = process_file(str(file_path))
                 if chunks:
-                    count = ingest_documents(chunks)
+                    count = ingest_documents(chunks, source_file=file_path.name, category=category)
                     print(f"  Ingested {count} chunks.")
                     if count > 0:
-                        log_processed_file(file_path.name)
+                        log_processed_file(file_path.name, category=category, chunk_count=count)
                     total_chunks += count
                     files_processed += 1
                 else:
@@ -47,6 +49,14 @@
     print(f"\nFinished! Processed {files_processed} files. Total chunks ingested: {total_chunks}")
 
 if __name__ == "__main__":
-    # Assuming run from backend/
-    dataset_dir = os.path.join(os.path.dirname(__file__), '../dataset')
-    populate_from_dataset(dataset_dir)
+    parser = argparse.ArgumentParser(description="Populate vector database from dataset files")
+    parser.add_argument("--category", "-c", type=str, help="Category to assign to ingested documents")
+    parser.add_argument("--dir", "-d", type=str, default=None, help="Dataset directory path")
+    args = parser.parse_args()
+
+    if args.dir:
+        dataset_dir = args.dir
+    else:
+        dataset_dir = os.path.join(os.path.dirname(__file__), '../dataset')
+
+    populate_from_dataset(dataset_dir, category=args.category)
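With the argparse flags above, a typical invocation from backend/ is
"python scripts/populate_db.py --category oncology --dir ./dataset" (the
category name here is hypothetical). The same entry point can also be driven
programmatically; a minimal sketch under the same assumptions:

    import sys
    sys.path.append('.')  # run from backend/, mirroring the script's own path tweak

    from scripts.populate_db import populate_from_dataset

    # files already logged in Mongo are skipped, so re-runs are incremental
    populate_from_dataset('./dataset', category='oncology')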
diff --git a/backend/src/chroma/__init__.py b/backend/src/chroma/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/backend/src/chroma/vector_store.py b/backend/src/chroma/vector_store.py
new file mode 100644
index 0000000..bcb259a
--- /dev/null
+++ b/backend/src/chroma/vector_store.py
@@ -0,0 +1,70 @@
+import chromadb
+import hashlib
+
+CHROMA_HOST = "http://chroma.sirblob.co"
+COLLECTION_NAME = "rag_documents"
+
+_client = None
+
+def get_chroma_client():
+    global _client
+    if _client is None:
+        _client = chromadb.HttpClient(host=CHROMA_HOST)
+    return _client
+
+def get_collection(collection_name=COLLECTION_NAME):
+    client = get_chroma_client()
+    return client.get_or_create_collection(name=collection_name)
+
+def insert_documents(texts, embeddings, collection_name=COLLECTION_NAME, metadata_list=None):
+    collection = get_collection(collection_name)
+
+    ids = [f"doc_{i}_{hashlib.sha1(text.encode('utf-8')).hexdigest()[:16]}" for i, text in enumerate(texts)]  # stable IDs; built-in hash() is salted per process
+
+    if metadata_list:
+        collection.add(
+            ids=ids,
+            embeddings=embeddings,
+            documents=texts,
+            metadatas=metadata_list
+        )
+    else:
+        collection.add(
+            ids=ids,
+            embeddings=embeddings,
+            documents=texts
+        )
+
+    return len(texts)
+
+def search_documents(query_embedding, collection_name=COLLECTION_NAME, num_results=5, filter_metadata=None):
+    collection = get_collection(collection_name)
+
+    query_params = {
+        "query_embeddings": [query_embedding],
+        "n_results": num_results
+    }
+
+    if filter_metadata:
+        query_params["where"] = filter_metadata
+
+    results = collection.query(**query_params)
+
+    output = []
+    if results and results["documents"]:
+        for i, doc in enumerate(results["documents"][0]):
+            score = results["distances"][0][i] if results.get("distances") else None
+            output.append({
+                "text": doc,
+                "score": score
+            })
+
+    return output
+
+def delete_documents_by_source(source_file, collection_name=COLLECTION_NAME):
+    collection = get_collection(collection_name)
+    results = collection.get(where={"source": source_file})
+    if results["ids"]:
+        collection.delete(ids=results["ids"])
+        return len(results["ids"])
+    return 0
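The Chroma helpers can be smoke-tested with toy vectors before wiring in real
embeddings (a minimal sketch; it assumes the server at CHROMA_HOST is
reachable, and the scratch collection name is hypothetical):

    from src.chroma.vector_store import insert_documents, search_documents

    texts = ["alpha", "beta"]
    toy_embeddings = [[0.1, 0.2, 0.3], [0.9, 0.8, 0.7]]  # stand-ins for model output
    insert_documents(texts, toy_embeddings, collection_name="scratch_test",
                     metadata_list=[{"source": "demo.txt"}, {"source": "demo.txt"}])

    hits = search_documents([0.1, 0.2, 0.3], collection_name="scratch_test",
                            num_results=1, filter_metadata={"source": "demo.txt"})
    print(hits)  # "score" is a Chroma distance, so lower means closer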
diff --git a/backend/src/mongo/metadata.py b/backend/src/mongo/metadata.py
new file mode 100644
index 0000000..0a04e52
--- /dev/null
+++ b/backend/src/mongo/metadata.py
@@ -0,0 +1,62 @@
+from .connection import get_mongo_client
+from datetime import datetime
+
+DB_NAME = "hoya_metadata"
+
+def get_datasets_collection():
+    client = get_mongo_client()
+    db = client.get_database(DB_NAME)
+    return db["datasets"]
+
+def get_categories_collection():
+    client = get_mongo_client()
+    db = client.get_database(DB_NAME)
+    return db["categories"]
+
+def is_file_processed(filename):
+    collection = get_datasets_collection()
+    return collection.find_one({"filename": filename}) is not None
+
+def log_processed_file(filename, category=None, chunk_count=0):
+    collection = get_datasets_collection()
+    doc = {
+        "filename": filename,
+        "category": category,
+        "chunk_count": chunk_count,
+        "processed_at": datetime.utcnow(),
+        "status": "processed"
+    }
+    collection.insert_one(doc)
+
+def get_all_datasets():
+    collection = get_datasets_collection()
+    return list(collection.find({}, {"_id": 0}))
+
+def get_datasets_by_category(category):
+    collection = get_datasets_collection()
+    return list(collection.find({"category": category}, {"_id": 0}))
+
+def delete_dataset_record(filename):
+    collection = get_datasets_collection()
+    result = collection.delete_one({"filename": filename})
+    return result.deleted_count > 0
+
+def create_category(name, description=""):
+    collection = get_categories_collection()
+    if collection.find_one({"name": name}):
+        return False
+    collection.insert_one({
+        "name": name,
+        "description": description,
+        "created_at": datetime.utcnow()
+    })
+    return True
+
+def get_all_categories():
+    collection = get_categories_collection()
+    return list(collection.find({}, {"_id": 0}))
+
+def delete_category(name):
+    collection = get_categories_collection()
+    result = collection.delete_one({"name": name})
+    return result.deleted_count > 0
diff --git a/backend/src/mongo/vector_store.py b/backend/src/mongo/vector_store.py
index c1162f6..b495a86 100644
--- a/backend/src/mongo/vector_store.py
+++ b/backend/src/mongo/vector_store.py
@@ -46,4 +46,4 @@ def log_processed_file(filename, log_collection="ingested_files", db_name="vecto
     client = get_mongo_client()
     db = client.get_database(db_name)
     collection = db[log_collection]
-    collection.insert_one({"filename": filename, "processed_at": 1})  # keeping it simple
+    collection.insert_one({"filename": filename, "processed_at": 1})
diff --git a/backend/src/rag/embeddings.py b/backend/src/rag/embeddings.py
index 66b350c..63824c4 100644
--- a/backend/src/rag/embeddings.py
+++ b/backend/src/rag/embeddings.py
@@ -1,26 +1,32 @@
-from google import genai
+import ollama
 import os
 
-def get_embedding(text, model="gemini-embedding-001"):
-    api_key = os.environ.get("GOOGLE_API_KEY")
-    if not api_key:
-        raise ValueError("GOOGLE_API_KEY environment variable not set")
-
-    client = genai.Client(api_key=api_key)
-    result = client.models.embed_content(
-        model=model,
-        contents=text
-    )
-    return result.embeddings[0].values
+client = ollama.Client(host="https://ollama.sirblob.co")
+DEFAULT_MODEL = "nomic-embed-text:latest"
 
-def get_embeddings_batch(texts, model="gemini-embedding-001"):
-    api_key = os.environ.get("GOOGLE_API_KEY")
-    if not api_key:
-        raise ValueError("GOOGLE_API_KEY environment variable not set")
-
-    client = genai.Client(api_key=api_key)
-    result = client.models.embed_content(
-        model=model,
-        contents=texts
-    )
-    return [emb.values for emb in result.embeddings]
+def get_embedding(text, model=DEFAULT_MODEL):
+    try:
+        response = client.embeddings(model=model, prompt=text)
+        return response["embedding"]
+    except Exception as e:
+        print(f"Error getting embedding from Ollama: {e}")
+        raise
+
+def get_embeddings_batch(texts, model=DEFAULT_MODEL, batch_size=50):
+    all_embeddings = []
+
+    for i in range(0, len(texts), batch_size):
+        batch = texts[i:i + batch_size]
+        try:
+            response = client.embed(model=model, input=batch)
+
+            if "embeddings" in response:
+                all_embeddings.extend(response["embeddings"])
+            else:
+                raise ValueError("Unexpected response format from client.embed")
+
+        except Exception as e:
+            print(f"Error embedding batch {i}-{i+batch_size}: {e}")
+            raise
+
+    return all_embeddings
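Both embedding helpers return plain Python lists, so they can be exercised in
isolation (a minimal sketch; assumes the Ollama host above is reachable and
has nomic-embed-text pulled, which produces 768-dimensional vectors):

    from src.rag.embeddings import get_embedding, get_embeddings_batch

    vec = get_embedding("hello world")
    print(len(vec))  # expected 768 for nomic-embed-text

    vecs = get_embeddings_batch(["first chunk", "second chunk"])
    print(len(vecs), len(vecs[0]))  # 2 vectors, same dimensionality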
diff --git a/backend/src/rag/ingest.py b/backend/src/rag/ingest.py
index 16de74a..4be32b6 100644
--- a/backend/src/rag/ingest.py
+++ b/backend/src/rag/ingest.py
@@ -3,6 +3,40 @@ from pypdf import PdfReader
 import io
 import os
 
+def chunk_text(text, target_length=2000, overlap=100):
+    if not text:
+        return []
+
+    chunks = []
+
+    paragraphs = text.split('\n\n')
+    current_chunk = ""
+
+    for para in paragraphs:
+        if len(current_chunk) + len(para) > target_length:
+            if current_chunk:
+                chunks.append(current_chunk.strip())
+
+            if len(para) > target_length:
+                start = 0
+                while start < len(para):
+                    end = start + target_length
+                    chunks.append(para[start:end].strip())
+                    start += (target_length - overlap)
+                current_chunk = ""
+            else:
+                current_chunk = para
+        else:
+            if current_chunk:
+                current_chunk += "\n\n" + para
+            else:
+                current_chunk = para
+
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
+
 def load_csv(file_path):
     df = pd.read_csv(file_path)
     return df.apply(lambda x: ' | '.join(x.astype(str)), axis=1).tolist()
@@ -13,14 +47,52 @@ def load_pdf(file_path):
     for page in reader.pages:
         text = page.extract_text()
         if text:
-            text_chunks.append(text)
+            if len(text) > 4000:
+                text_chunks.extend(chunk_text(text))
+            else:
+                text_chunks.append(text)
     return text_chunks
 
+def load_txt(file_path):
+    with open(file_path, 'r', encoding='utf-8') as f:
+        content = f.read()
+
+    return chunk_text(content)
+
+def load_xlsx(file_path):
+    all_rows = []
+    try:
+        sheets = pd.read_excel(file_path, sheet_name=None)
+    except Exception as e:
+        raise ValueError(f"Pandas read_excel failed: {e}")
+
+    for sheet_name, df in sheets.items():
+        if df.empty:
+            continue
+
+        df = df.fillna("")
+
+        for row in df.values:
+            row_items = [str(x) for x in row if str(x).strip() != ""]
+            if row_items:
+                row_str = f"Sheet: {str(sheet_name)} | " + " | ".join(row_items)
+
+                if len(row_str) > 8000:
+                    all_rows.extend(chunk_text(row_str))
+                else:
+                    all_rows.append(row_str)
+
+    return all_rows
+
 def process_file(file_path):
     ext = os.path.splitext(file_path)[1].lower()
     if ext == '.csv':
         return load_csv(file_path)
     elif ext == '.pdf':
         return load_pdf(file_path)
+    elif ext == '.txt':
+        return load_txt(file_path)
+    elif ext == '.xlsx':
+        return load_xlsx(file_path)
     else:
         raise ValueError(f"Unsupported file type: {ext}")
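chunk_text packs whole paragraphs up to target_length and only falls back to
fixed character windows (sharing overlap characters) when a single paragraph
is oversized, which a toy run makes visible (a minimal sketch with a small
target_length; the real defaults are 2000/100):

    from src.rag.ingest import chunk_text

    text = "short para one\n\n" + "x" * 50 + "\n\nshort para two"
    for chunk in chunk_text(text, target_length=20, overlap=5):
        print(repr(chunk))
    # the 50-char paragraph becomes overlapping 20-char windows;
    # both short paragraphs survive as intact chunks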
diff --git a/backend/src/rag/store.py b/backend/src/rag/store.py
index 1ba8628..a71d0ac 100644
--- a/backend/src/rag/store.py
+++ b/backend/src/rag/store.py
@@ -1,18 +1,27 @@
 from .embeddings import get_embeddings_batch, get_embedding
-from ..mongo.vector_store import insert_rag_documents, search_rag_documents
+from ..chroma.vector_store import insert_documents, search_documents
 
-def ingest_documents(text_chunks, collection_name="rag_documents"):
+def ingest_documents(text_chunks, collection_name="rag_documents", source_file=None, category=None):
     embeddings = get_embeddings_batch(text_chunks)
 
-    documents = []
-    for text, embedding in zip(text_chunks, embeddings):
-        documents.append({
-            "text": text,
-            "embedding": embedding
-        })
+    metadata_list = None
+    if source_file or category:
+        metadata_list = []
+        for _ in text_chunks:
+            meta = {}
+            if source_file:
+                meta["source"] = source_file
+            if category:
+                meta["category"] = category
+            metadata_list.append(meta)
 
-    return insert_rag_documents(documents, collection_name=collection_name)
+    return insert_documents(text_chunks, embeddings, collection_name=collection_name, metadata_list=metadata_list)
 
-def vector_search(query_text, collection_name="rag_documents", num_results=5):
+def vector_search(query_text, collection_name="rag_documents", num_results=5, category=None):
     query_embedding = get_embedding(query_text)
-    return search_rag_documents(query_embedding, collection_name=collection_name, num_results=num_results)
+
+    filter_metadata = None
+    if category:
+        filter_metadata = {"category": category}
+
+    return search_documents(query_embedding, collection_name=collection_name, num_results=num_results, filter_metadata=filter_metadata)
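End to end, ingestion and retrieval now flow through Ollama for embeddings
and Chroma for storage (a minimal sketch; the file name and category are
hypothetical, and both backing services must be reachable):

    from src.rag.store import ingest_documents, vector_search

    chunks = ["Aspirin is an NSAID.", "Metformin treats type 2 diabetes."]
    ingest_documents(chunks, source_file="drugs.txt", category="pharmacology")

    results = vector_search("diabetes medication", num_results=2, category="pharmacology")
    for r in results:
        print(r["text"], r["score"])  # score is a distance, lower means closer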