Mirror of https://github.com/SirBlobby/Hoya26.git (synced 2026-02-03 19:24:34 -05:00)

Commit: Populate DB Chromadb
.gitignore (vendored): 1 change
@@ -66,3 +66,4 @@ venv.bak/
# Tauri

bun.lock
dataset/
@@ -1,12 +1,14 @@
flask
google-genai
gunicorn
pymongo
ultralytics
opencv-python-headless
transformers
torch
pandas
pypdf
openpyxl
python-dotenv
flask-cors
ollama
chromadb-client
pymongo
@@ -1,8 +1,8 @@
import os
import sys
import argparse
from pathlib import Path

# Add backend directory to path so we can import src
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from dotenv import load_dotenv
@@ -10,21 +10,23 @@ load_dotenv()
from src.rag.ingest import process_file
from src.rag.store import ingest_documents
from src.mongo.vector_store import is_file_processed, log_processed_file
from src.mongo.metadata import is_file_processed, log_processed_file

def populate_from_dataset(dataset_dir):
def populate_from_dataset(dataset_dir, category=None):
    dataset_path = Path(dataset_dir)
    if not dataset_path.exists():
        print(f"Dataset directory not found: {dataset_dir}")
        return

    print(f"Scanning {dataset_dir}...")
    if category:
        print(f"Category: {category}")

    total_chunks = 0
    files_processed = 0

    for file_path in dataset_path.glob('*'):
        if file_path.is_file() and file_path.suffix.lower() in ['.csv', '.pdf']:
        if file_path.is_file() and file_path.suffix.lower() in ['.csv', '.pdf', '.txt', '.xlsx']:
            if is_file_processed(file_path.name):
                print(f"Skipping {file_path.name} (already processed)")
                continue
@@ -33,10 +35,10 @@ def populate_from_dataset(dataset_dir):
            try:
                chunks = process_file(str(file_path))
                if chunks:
                    count = ingest_documents(chunks)
                    count = ingest_documents(chunks, source_file=file_path.name, category=category)
                    print(f" Ingested {count} chunks.")
                    if count > 0:
                        log_processed_file(file_path.name)
                        log_processed_file(file_path.name, category=category, chunk_count=count)
                    total_chunks += count
                    files_processed += 1
                else:
@@ -47,6 +49,14 @@ def populate_from_dataset(dataset_dir):
    print(f"\nFinished! Processed {files_processed} files. Total chunks ingested: {total_chunks}")

if __name__ == "__main__":
    # Assuming run from backend/
    parser = argparse.ArgumentParser(description="Populate vector database from dataset files")
    parser.add_argument("--category", "-c", type=str, help="Category to assign to ingested documents")
    parser.add_argument("--dir", "-d", type=str, default=None, help="Dataset directory path")
    args = parser.parse_args()

    if args.dir:
        dataset_dir = args.dir
    else:
        dataset_dir = os.path.join(os.path.dirname(__file__), '../dataset')
    populate_from_dataset(dataset_dir)

    populate_from_dataset(dataset_dir, category=args.category)
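For orientation, the new flags can be exercised from the command line, e.g. python populate_db.py -c reports -d ../dataset (script filename assumed; the diff does not show it), or the routine can be called directly. A minimal sketch:

# Sketch only: "populate_db" is a hypothetical module/script name (not shown in this
# diff); the Chroma and Mongo services used downstream must be reachable.
from populate_db import populate_from_dataset

populate_from_dataset("../dataset", category="reports")  # tags every ingested chunk with "reports"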
backend/src/chroma/__init__.py (new file, empty)
backend/src/chroma/vector_store.py (new file, 69 lines)
@@ -0,0 +1,69 @@
import chromadb

CHROMA_HOST = "http://chroma.sirblob.co"
COLLECTION_NAME = "rag_documents"

_client = None

def get_chroma_client():
    global _client
    if _client is None:
        _client = chromadb.HttpClient(host=CHROMA_HOST)
    return _client

def get_collection(collection_name=COLLECTION_NAME):
    client = get_chroma_client()
    return client.get_or_create_collection(name=collection_name)

def insert_documents(texts, embeddings, collection_name=COLLECTION_NAME, metadata_list=None):
    collection = get_collection(collection_name)

    ids = [f"doc_{i}_{hash(text)}" for i, text in enumerate(texts)]

    if metadata_list:
        collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=metadata_list
        )
    else:
        collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=texts
        )

    return len(texts)

def search_documents(query_embedding, collection_name=COLLECTION_NAME, num_results=5, filter_metadata=None):
    collection = get_collection(collection_name)

    query_params = {
        "query_embeddings": [query_embedding],
        "n_results": num_results
    }

    if filter_metadata:
        query_params["where"] = filter_metadata

    results = collection.query(**query_params)

    output = []
    if results and results["documents"]:
        for i, doc in enumerate(results["documents"][0]):
            score = results["distances"][0][i] if "distances" in results else None
            output.append({
                "text": doc,
                "score": score
            })

    return output

def delete_documents_by_source(source_file, collection_name=COLLECTION_NAME):
    collection = get_collection(collection_name)
    results = collection.get(where={"source": source_file})
    if results["ids"]:
        collection.delete(ids=results["ids"])
        return len(results["ids"])
    return 0
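A minimal usage sketch for these helpers (assumes the module is importable as src.chroma.vector_store; the embedding vectors are placeholders, in this commit they come from the Ollama-backed src/rag/embeddings module):

# Sketch only: placeholder 3-dimensional embeddings; real vectors come from the
# embedding helpers elsewhere in this commit.
from src.chroma.vector_store import insert_documents, search_documents, delete_documents_by_source

texts = ["first example chunk", "second example chunk"]
embeddings = [[0.1, 0.2, 0.3], [0.2, 0.1, 0.4]]
metadata = [{"source": "example.txt", "category": "demo"}] * len(texts)

insert_documents(texts, embeddings, metadata_list=metadata)

hits = search_documents([0.1, 0.2, 0.3], num_results=2, filter_metadata={"category": "demo"})
for hit in hits:
    print(hit["score"], hit["text"])

delete_documents_by_source("example.txt")  # removes both chunks via their "source" metadata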
backend/src/mongo/metadata.py (new file, 62 lines)
@@ -0,0 +1,62 @@
from .connection import get_mongo_client
from datetime import datetime

DB_NAME = "hoya_metadata"

def get_datasets_collection():
    client = get_mongo_client()
    db = client.get_database(DB_NAME)
    return db["datasets"]

def get_categories_collection():
    client = get_mongo_client()
    db = client.get_database(DB_NAME)
    return db["categories"]

def is_file_processed(filename):
    collection = get_datasets_collection()
    return collection.find_one({"filename": filename}) is not None

def log_processed_file(filename, category=None, chunk_count=0):
    collection = get_datasets_collection()
    doc = {
        "filename": filename,
        "category": category,
        "chunk_count": chunk_count,
        "processed_at": datetime.utcnow(),
        "status": "processed"
    }
    collection.insert_one(doc)

def get_all_datasets():
    collection = get_datasets_collection()
    return list(collection.find({}, {"_id": 0}))

def get_datasets_by_category(category):
    collection = get_datasets_collection()
    return list(collection.find({"category": category}, {"_id": 0}))

def delete_dataset_record(filename):
    collection = get_datasets_collection()
    result = collection.delete_one({"filename": filename})
    return result.deleted_count > 0

def create_category(name, description=""):
    collection = get_categories_collection()
    if collection.find_one({"name": name}):
        return False
    collection.insert_one({
        "name": name,
        "description": description,
        "created_at": datetime.utcnow()
    })
    return True

def get_all_categories():
    collection = get_categories_collection()
    return list(collection.find({}, {"_id": 0}))

def delete_category(name):
    collection = get_categories_collection()
    result = collection.delete_one({"name": name})
    return result.deleted_count > 0
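A quick sketch of how the populate script is expected to exercise these helpers (values are illustrative; requires the MongoDB connection provided by get_mongo_client):

# Sketch only: illustrative filenames and category.
from src.mongo.metadata import create_category, is_file_processed, log_processed_file, get_datasets_by_category

create_category("reports", description="Quarterly report PDFs")

if not is_file_processed("q3_report.pdf"):
    # chunk_count would normally be the value returned by ingest_documents()
    log_processed_file("q3_report.pdf", category="reports", chunk_count=42)

print(get_datasets_by_category("reports"))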
@@ -46,4 +46,4 @@ def log_processed_file(filename, log_collection="ingested_files", db_name="vecto
    client = get_mongo_client()
    db = client.get_database(db_name)
    collection = db[log_collection]
    collection.insert_one({"filename": filename, "processed_at": 1}) # keeping it simple
    collection.insert_one({"filename": filename, "processed_at": 1})
@@ -1,26 +1,32 @@
from google import genai
import ollama
import os

def get_embedding(text, model="gemini-embedding-001"):
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("GOOGLE_API_KEY environment variable not set")
client = ollama.Client(host="https://ollama.sirblob.co")
DEFAULT_MODEL = "nomic-embed-text:latest"

    client = genai.Client(api_key=api_key)
    result = client.models.embed_content(
        model=model,
        contents=text
    )
    return result.embeddings[0].values
def get_embedding(text, model=DEFAULT_MODEL):
    try:
        response = client.embeddings(model=model, prompt=text)
        return response["embedding"]
    except Exception as e:
        print(f"Error getting embedding from Ollama: {e}")
        raise e

def get_embeddings_batch(texts, model="gemini-embedding-001"):
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("GOOGLE_API_KEY environment variable not set")
def get_embeddings_batch(texts, model=DEFAULT_MODEL, batch_size=50):
    all_embeddings = []

    client = genai.Client(api_key=api_key)
    result = client.models.embed_content(
        model=model,
        contents=texts
    )
    return [emb.values for emb in result.embeddings]
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        try:
            response = client.embed(model=model, input=batch)

            if "embeddings" in response:
                all_embeddings.extend(response["embeddings"])
            else:
                raise ValueError("Unexpected response format from client.embed")

        except Exception as e:
            print(f"Error embedding batch {i}-{i+batch_size}: {e}")
            raise e

    return all_embeddings
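A minimal usage sketch for the Ollama-backed helpers above (assumes the module is importable as src.rag.embeddings, the Ollama host is reachable, and nomic-embed-text is pulled on that server):

# Sketch only: requires network access to the Ollama host configured above.
from src.rag.embeddings import get_embedding, get_embeddings_batch

vec = get_embedding("What are the dining hall hours?")
print(len(vec))   # embedding dimensionality

vecs = get_embeddings_batch(["chunk one", "chunk two", "chunk three"], batch_size=2)
print(len(vecs))  # 3 vectors, computed in two batches here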
@@ -3,6 +3,40 @@ from pypdf import PdfReader
import io
import os

def chunk_text(text, target_length=2000, overlap=100):
    if not text:
        return []

    chunks = []

    paragraphs = text.split('\n\n')
    current_chunk = ""

    for para in paragraphs:
        if len(current_chunk) + len(para) > target_length:
            if current_chunk:
                chunks.append(current_chunk.strip())

            if len(para) > target_length:
                start = 0
                while start < len(para):
                    end = start + target_length
                    chunks.append(para[start:end].strip())
                    start += (target_length - overlap)
                current_chunk = ""
            else:
                current_chunk = para
        else:
            if current_chunk:
                current_chunk += "\n\n" + para
            else:
                current_chunk = para

    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks

def load_csv(file_path):
    df = pd.read_csv(file_path)
    return df.apply(lambda x: ' | '.join(x.astype(str)), axis=1).tolist()
@@ -13,14 +47,52 @@ def load_pdf(file_path):
    for page in reader.pages:
        text = page.extract_text()
        if text:
            if len(text) > 4000:
                text_chunks.extend(chunk_text(text))
            else:
                text_chunks.append(text)
    return text_chunks

def load_txt(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    return chunk_text(content)

def load_xlsx(file_path):
    all_rows = []
    try:
        sheets = pd.read_excel(file_path, sheet_name=None)
    except Exception as e:
        raise ValueError(f"Pandas read_excel failed: {e}")

    for sheet_name, df in sheets.items():
        if df.empty:
            continue

        df = df.fillna("")

        for row in df.values:
            row_items = [str(x) for x in row if str(x).strip() != ""]
            if row_items:
                row_str = f"Sheet: {str(sheet_name)} | " + " | ".join(row_items)

                if len(row_str) > 8000:
                    all_rows.extend(chunk_text(row_str))
                else:
                    all_rows.append(row_str)

    return all_rows

def process_file(file_path):
    ext = os.path.splitext(file_path)[1].lower()
    if ext == '.csv':
        return load_csv(file_path)
    elif ext == '.pdf':
        return load_pdf(file_path)
    elif ext == '.txt':
        return load_txt(file_path)
    elif ext == '.xlsx':
        return load_xlsx(file_path)
    else:
        raise ValueError(f"Unsupported file type: {ext}")
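A short sketch of the new loaders and the paragraph-based chunker (file paths are illustrative):

# Sketch only: file paths are illustrative.
from src.rag.ingest import process_file, chunk_text

chunks = process_file("dataset/course_catalog.xlsx")   # dispatches to load_xlsx
print(len(chunks), "row chunks from the spreadsheet")

# chunk_text splits on blank-line paragraphs and falls back to a sliding window
# (target_length with overlap) when a single paragraph is too long.
long_text = "word " * 1200
print(len(chunk_text(long_text, target_length=2000, overlap=100)))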
@@ -1,18 +1,27 @@
from .embeddings import get_embeddings_batch, get_embedding
from ..mongo.vector_store import insert_rag_documents, search_rag_documents
from ..chroma.vector_store import insert_documents, search_documents

def ingest_documents(text_chunks, collection_name="rag_documents"):
def ingest_documents(text_chunks, collection_name="rag_documents", source_file=None, category=None):
    embeddings = get_embeddings_batch(text_chunks)

    documents = []
    for text, embedding in zip(text_chunks, embeddings):
        documents.append({
            "text": text,
            "embedding": embedding
        })
    metadata_list = None
    if source_file or category:
        metadata_list = []
        for _ in text_chunks:
            meta = {}
            if source_file:
                meta["source"] = source_file
            if category:
                meta["category"] = category
            metadata_list.append(meta)

    return insert_rag_documents(documents, collection_name=collection_name)
    return insert_documents(text_chunks, embeddings, collection_name=collection_name, metadata_list=metadata_list)

def vector_search(query_text, collection_name="rag_documents", num_results=5):
def vector_search(query_text, collection_name="rag_documents", num_results=5, category=None):
    query_embedding = get_embedding(query_text)
    return search_rag_documents(query_embedding, collection_name=collection_name, num_results=num_results)

    filter_metadata = None
    if category:
        filter_metadata = {"category": category}

    return search_documents(query_embedding, collection_name=collection_name, num_results=num_results, filter_metadata=filter_metadata)
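An end-to-end sketch tying ingest, embeddings, Chroma storage, and search together (path and category are illustrative):

# Sketch only: illustrative path and category; requires the Ollama and Chroma
# services configured elsewhere in this commit.
from src.rag.ingest import process_file
from src.rag.store import ingest_documents, vector_search

chunks = process_file("dataset/campus_guide.pdf")
ingest_documents(chunks, source_file="campus_guide.pdf", category="campus")

for hit in vector_search("Where is the main library?", num_results=3, category="campus"):
    print(hit["score"], hit["text"][:80])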