Mirror of https://github.com/SirBlobby/Hoya26.git (synced 2026-02-03 19:24:34 -05:00)
Populate DB Chromadb
.gitignore (vendored)
@@ -65,4 +65,5 @@ venv.bak/
 
 # Tauri
 
 bun.lock
+dataset/
@@ -1,12 +1,14 @@
 flask
-google-genai
 gunicorn
-pymongo
 ultralytics
 opencv-python-headless
 transformers
 torch
 pandas
 pypdf
+openpyxl
 python-dotenv
 flask-cors
+ollama
+chromadb-client
+pymongo
@@ -1,8 +1,8 @@
 import os
 import sys
+import argparse
 from pathlib import Path
 
-# Add backend directory to path so we can import src
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
 
 from dotenv import load_dotenv

@@ -10,21 +10,23 @@ load_dotenv()
 
 from src.rag.ingest import process_file
 from src.rag.store import ingest_documents
-from src.mongo.vector_store import is_file_processed, log_processed_file
+from src.mongo.metadata import is_file_processed, log_processed_file
 
-def populate_from_dataset(dataset_dir):
+def populate_from_dataset(dataset_dir, category=None):
     dataset_path = Path(dataset_dir)
     if not dataset_path.exists():
         print(f"Dataset directory not found: {dataset_dir}")
         return
 
     print(f"Scanning {dataset_dir}...")
+    if category:
+        print(f"Category: {category}")
 
     total_chunks = 0
     files_processed = 0
 
     for file_path in dataset_path.glob('*'):
-        if file_path.is_file() and file_path.suffix.lower() in ['.csv', '.pdf']:
+        if file_path.is_file() and file_path.suffix.lower() in ['.csv', '.pdf', '.txt', '.xlsx']:
             if is_file_processed(file_path.name):
                 print(f"Skipping {file_path.name} (already processed)")
                 continue

@@ -33,10 +35,10 @@ def populate_from_dataset(dataset_dir):
             try:
                 chunks = process_file(str(file_path))
                 if chunks:
-                    count = ingest_documents(chunks)
+                    count = ingest_documents(chunks, source_file=file_path.name, category=category)
                     print(f" Ingested {count} chunks.")
                     if count > 0:
-                        log_processed_file(file_path.name)
+                        log_processed_file(file_path.name, category=category, chunk_count=count)
                         total_chunks += count
                         files_processed += 1
                 else:

@@ -47,6 +49,14 @@ def populate_from_dataset(dataset_dir):
     print(f"\nFinished! Processed {files_processed} files. Total chunks ingested: {total_chunks}")
 
 if __name__ == "__main__":
-    # Assuming run from backend/
-    dataset_dir = os.path.join(os.path.dirname(__file__), '../dataset')
-    populate_from_dataset(dataset_dir)
+    parser = argparse.ArgumentParser(description="Populate vector database from dataset files")
+    parser.add_argument("--category", "-c", type=str, help="Category to assign to ingested documents")
+    parser.add_argument("--dir", "-d", type=str, default=None, help="Dataset directory path")
+    args = parser.parse_args()
+
+    if args.dir:
+        dataset_dir = args.dir
+    else:
+        dataset_dir = os.path.join(os.path.dirname(__file__), '../dataset')
+
+    populate_from_dataset(dataset_dir, category=args.category)
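For reference, a minimal usage sketch of the updated entry point. The import path is an assumption (this mirror does not show the script's file name); the function signature comes from the diff above.

    # Hypothetical direct call, assuming the script can be imported as a module
    # (its actual path is not shown in this commit view).
    from populate_db import populate_from_dataset

    # Ingest every supported file (.csv, .pdf, .txt, .xlsx) in a directory,
    # tagging the chunks and the MongoDB log entries with a category.
    populate_from_dataset("../dataset", category="lab-reports")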
backend/src/chroma/__init__.py (new file, 0 lines)
backend/src/chroma/vector_store.py (new file, 69 lines)
@@ -0,0 +1,69 @@
+import chromadb
+
+CHROMA_HOST = "http://chroma.sirblob.co"
+COLLECTION_NAME = "rag_documents"
+
+_client = None
+
+def get_chroma_client():
+    global _client
+    if _client is None:
+        _client = chromadb.HttpClient(host=CHROMA_HOST)
+    return _client
+
+def get_collection(collection_name=COLLECTION_NAME):
+    client = get_chroma_client()
+    return client.get_or_create_collection(name=collection_name)
+
+def insert_documents(texts, embeddings, collection_name=COLLECTION_NAME, metadata_list=None):
+    collection = get_collection(collection_name)
+
+    ids = [f"doc_{i}_{hash(text)}" for i, text in enumerate(texts)]
+
+    if metadata_list:
+        collection.add(
+            ids=ids,
+            embeddings=embeddings,
+            documents=texts,
+            metadatas=metadata_list
+        )
+    else:
+        collection.add(
+            ids=ids,
+            embeddings=embeddings,
+            documents=texts
+        )
+
+    return len(texts)
+
+def search_documents(query_embedding, collection_name=COLLECTION_NAME, num_results=5, filter_metadata=None):
+    collection = get_collection(collection_name)
+
+    query_params = {
+        "query_embeddings": [query_embedding],
+        "n_results": num_results
+    }
+
+    if filter_metadata:
+        query_params["where"] = filter_metadata
+
+    results = collection.query(**query_params)
+
+    output = []
+    if results and results["documents"]:
+        for i, doc in enumerate(results["documents"][0]):
+            score = results["distances"][0][i] if "distances" in results else None
+            output.append({
+                "text": doc,
+                "score": score
+            })
+
+    return output
+
+def delete_documents_by_source(source_file, collection_name=COLLECTION_NAME):
+    collection = get_collection(collection_name)
+    results = collection.get(where={"source": source_file})
+    if results["ids"]:
+        collection.delete(ids=results["ids"])
+        return len(results["ids"])
+    return 0
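A rough usage sketch for this module follows. The texts, metadata, and 4-dimensional vectors are placeholders (real embeddings come from the Ollama-backed embeddings module), so treat it as an illustration of the call shapes rather than working data.

    from src.chroma.vector_store import insert_documents, search_documents

    texts = ["Hoyas are epiphytic plants.", "ChromaDB stores embeddings."]
    embeddings = [[0.1, 0.2, 0.3, 0.4], [0.2, 0.1, 0.0, 0.3]]  # dummy vectors
    metadata = [{"source": "notes.txt", "category": "plants"},
                {"source": "notes.txt", "category": "infra"}]

    insert_documents(texts, embeddings, metadata_list=metadata)

    # Query with an (equally fake) embedding, restricted to one category.
    hits = search_documents([0.1, 0.2, 0.3, 0.4], num_results=2,
                            filter_metadata={"category": "plants"})
    for hit in hits:
        print(hit["score"], hit["text"])

Note that the ids are derived from Python's hash(), which is salted per process, so re-ingesting the same text in a later run produces new ids rather than overwriting the earlier entries.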
backend/src/mongo/metadata.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+from .connection import get_mongo_client
+from datetime import datetime
+
+DB_NAME = "hoya_metadata"
+
+def get_datasets_collection():
+    client = get_mongo_client()
+    db = client.get_database(DB_NAME)
+    return db["datasets"]
+
+def get_categories_collection():
+    client = get_mongo_client()
+    db = client.get_database(DB_NAME)
+    return db["categories"]
+
+def is_file_processed(filename):
+    collection = get_datasets_collection()
+    return collection.find_one({"filename": filename}) is not None
+
+def log_processed_file(filename, category=None, chunk_count=0):
+    collection = get_datasets_collection()
+    doc = {
+        "filename": filename,
+        "category": category,
+        "chunk_count": chunk_count,
+        "processed_at": datetime.utcnow(),
+        "status": "processed"
+    }
+    collection.insert_one(doc)
+
+def get_all_datasets():
+    collection = get_datasets_collection()
+    return list(collection.find({}, {"_id": 0}))
+
+def get_datasets_by_category(category):
+    collection = get_datasets_collection()
+    return list(collection.find({"category": category}, {"_id": 0}))
+
+def delete_dataset_record(filename):
+    collection = get_datasets_collection()
+    result = collection.delete_one({"filename": filename})
+    return result.deleted_count > 0
+
+def create_category(name, description=""):
+    collection = get_categories_collection()
+    if collection.find_one({"name": name}):
+        return False
+    collection.insert_one({
+        "name": name,
+        "description": description,
+        "created_at": datetime.utcnow()
+    })
+    return True
+
+def get_all_categories():
+    collection = get_categories_collection()
+    return list(collection.find({}, {"_id": 0}))
+
+def delete_category(name):
+    collection = get_categories_collection()
+    result = collection.delete_one({"name": name})
+    return result.deleted_count > 0
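A short sketch of how these helpers might be used together. The file name and category are invented for illustration, and the MongoDB connection itself comes from src.mongo.connection, which is not part of this commit.

    from src.mongo.metadata import (
        create_category, is_file_processed, log_processed_file, get_datasets_by_category,
    )

    create_category("lab-reports", description="Scanned lab result PDFs")

    if not is_file_processed("results_2024.pdf"):
        # chunk_count would normally come from ingest_documents()
        log_processed_file("results_2024.pdf", category="lab-reports", chunk_count=42)

    print(get_datasets_by_category("lab-reports"))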
@@ -46,4 +46,4 @@ def log_processed_file(filename, log_collection="ingested_files", db_name="vecto
     client = get_mongo_client()
     db = client.get_database(db_name)
     collection = db[log_collection]
-    collection.insert_one({"filename": filename, "processed_at": 1}) # keeping it simple
+    collection.insert_one({"filename": filename, "processed_at": 1})
@@ -1,26 +1,32 @@
-from google import genai
+import ollama
 import os
 
-def get_embedding(text, model="gemini-embedding-001"):
-    api_key = os.environ.get("GOOGLE_API_KEY")
-    if not api_key:
-        raise ValueError("GOOGLE_API_KEY environment variable not set")
-
-    client = genai.Client(api_key=api_key)
-    result = client.models.embed_content(
-        model=model,
-        contents=text
-    )
-    return result.embeddings[0].values
-
-def get_embeddings_batch(texts, model="gemini-embedding-001"):
-    api_key = os.environ.get("GOOGLE_API_KEY")
-    if not api_key:
-        raise ValueError("GOOGLE_API_KEY environment variable not set")
-
-    client = genai.Client(api_key=api_key)
-    result = client.models.embed_content(
-        model=model,
-        contents=texts
-    )
-    return [emb.values for emb in result.embeddings]
+client = ollama.Client(host="https://ollama.sirblob.co")
+DEFAULT_MODEL = "nomic-embed-text:latest"
+
+def get_embedding(text, model=DEFAULT_MODEL):
+    try:
+        response = client.embeddings(model=model, prompt=text)
+        return response["embedding"]
+    except Exception as e:
+        print(f"Error getting embedding from Ollama: {e}")
+        raise e
+
+def get_embeddings_batch(texts, model=DEFAULT_MODEL, batch_size=50):
+    all_embeddings = []
+
+    for i in range(0, len(texts), batch_size):
+        batch = texts[i:i + batch_size]
+        try:
+            response = client.embed(model=model, input=batch)
+
+            if "embeddings" in response:
+                all_embeddings.extend(response["embeddings"])
+            else:
+                raise ValueError("Unexpected response format from client.embed")
+
+        except Exception as e:
+            print(f"Error embedding batch {i}-{i+batch_size}: {e}")
+            raise e
+
+    return all_embeddings
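A minimal sketch of the new embedding calls, assuming the Ollama server at https://ollama.sirblob.co (or a local instance) is reachable and has the nomic-embed-text model available:

    from src.rag.embeddings import get_embedding, get_embeddings_batch

    vec = get_embedding("What potting mix do hoyas prefer?")
    print(len(vec))   # dimensionality of a single embedding

    vecs = get_embeddings_batch(["chunk one", "chunk two", "chunk three"])
    print(len(vecs))  # one embedding per input chunk

The single-text path goes through the client's embeddings() call while the batch path uses embed() with a list input, so a reasonably recent ollama Python package is assumed.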
@@ -3,6 +3,40 @@ from pypdf import PdfReader
 import io
 import os
 
+def chunk_text(text, target_length=2000, overlap=100):
+    if not text:
+        return []
+
+    chunks = []
+
+    paragraphs = text.split('\n\n')
+    current_chunk = ""
+
+    for para in paragraphs:
+        if len(current_chunk) + len(para) > target_length:
+            if current_chunk:
+                chunks.append(current_chunk.strip())
+
+            if len(para) > target_length:
+                start = 0
+                while start < len(para):
+                    end = start + target_length
+                    chunks.append(para[start:end].strip())
+                    start += (target_length - overlap)
+                current_chunk = ""
+            else:
+                current_chunk = para
+        else:
+            if current_chunk:
+                current_chunk += "\n\n" + para
+            else:
+                current_chunk = para
+
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
+
 def load_csv(file_path):
     df = pd.read_csv(file_path)
     return df.apply(lambda x: ' | '.join(x.astype(str)), axis=1).tolist()

@@ -13,14 +47,52 @@ def load_pdf(file_path):
     for page in reader.pages:
         text = page.extract_text()
         if text:
-            text_chunks.append(text)
+            if len(text) > 4000:
+                text_chunks.extend(chunk_text(text))
+            else:
+                text_chunks.append(text)
     return text_chunks
 
+def load_txt(file_path):
+    with open(file_path, 'r', encoding='utf-8') as f:
+        content = f.read()
+
+    return chunk_text(content)
+
+def load_xlsx(file_path):
+    all_rows = []
+    try:
+        sheets = pd.read_excel(file_path, sheet_name=None)
+    except Exception as e:
+        raise ValueError(f"Pandas read_excel failed: {e}")
+
+    for sheet_name, df in sheets.items():
+        if df.empty:
+            continue
+
+        df = df.fillna("")
+
+        for row in df.values:
+            row_items = [str(x) for x in row if str(x).strip() != ""]
+            if row_items:
+                row_str = f"Sheet: {str(sheet_name)} | " + " | ".join(row_items)
+
+                if len(row_str) > 8000:
+                    all_rows.extend(chunk_text(row_str))
+                else:
+                    all_rows.append(row_str)
+
+    return all_rows
+
 def process_file(file_path):
     ext = os.path.splitext(file_path)[1].lower()
     if ext == '.csv':
        return load_csv(file_path)
     elif ext == '.pdf':
         return load_pdf(file_path)
+    elif ext == '.txt':
+        return load_txt(file_path)
+    elif ext == '.xlsx':
+        return load_xlsx(file_path)
     else:
         raise ValueError(f"Unsupported file type: {ext}")
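To illustrate the new chunking behaviour, a small sketch using made-up text and a hypothetical spreadsheet path:

    from src.rag.ingest import chunk_text, process_file

    long_text = ("First paragraph. " * 80) + "\n\n" + ("Second paragraph. " * 80)

    # Paragraphs are packed into roughly 2000-character chunks; any single
    # paragraph longer than target_length is sliced with a 100-character overlap.
    chunks = chunk_text(long_text, target_length=2000, overlap=100)
    print(len(chunks), [len(c) for c in chunks])

    # process_file now dispatches on extension: .csv, .pdf, .txt, .xlsx.
    rows = process_file("dataset/results.xlsx")  # hypothetical path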
@@ -1,18 +1,27 @@
 from .embeddings import get_embeddings_batch, get_embedding
-from ..mongo.vector_store import insert_rag_documents, search_rag_documents
+from ..chroma.vector_store import insert_documents, search_documents
 
-def ingest_documents(text_chunks, collection_name="rag_documents"):
+def ingest_documents(text_chunks, collection_name="rag_documents", source_file=None, category=None):
     embeddings = get_embeddings_batch(text_chunks)
 
-    documents = []
-    for text, embedding in zip(text_chunks, embeddings):
-        documents.append({
-            "text": text,
-            "embedding": embedding
-        })
+    metadata_list = None
+    if source_file or category:
+        metadata_list = []
+        for _ in text_chunks:
+            meta = {}
+            if source_file:
+                meta["source"] = source_file
+            if category:
+                meta["category"] = category
+            metadata_list.append(meta)
 
-    return insert_rag_documents(documents, collection_name=collection_name)
+    return insert_documents(text_chunks, embeddings, collection_name=collection_name, metadata_list=metadata_list)
 
-def vector_search(query_text, collection_name="rag_documents", num_results=5):
+def vector_search(query_text, collection_name="rag_documents", num_results=5, category=None):
     query_embedding = get_embedding(query_text)
-    return search_rag_documents(query_embedding, collection_name=collection_name, num_results=num_results)
+
+    filter_metadata = None
+    if category:
+        filter_metadata = {"category": category}
+
+    return search_documents(query_embedding, collection_name=collection_name, num_results=num_results, filter_metadata=filter_metadata)
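Finally, an end-to-end sketch of the reworked store API: embed some chunks, insert them with source and category metadata, then run a category-filtered search. The file name and category are illustrative, and the calls require both the Ollama and Chroma servers to be reachable.

    from src.rag.store import ingest_documents, vector_search

    chunks = ["Hoya carnosa likes bright indirect light.",
              "Water when the top inch of the mix is dry."]

    ingest_documents(chunks, source_file="care_guide.txt", category="plants")

    for hit in vector_search("How much light does a hoya need?", category="plants"):
        print(hit["score"], hit["text"])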