mirror of
https://github.com/SirBlobby/Hoya26.git
synced 2026-02-04 03:34:34 -05:00
Populate DB Chromadb
This commit is contained in:
69
backend/src/chroma/vector_store.py
Normal file
69
backend/src/chroma/vector_store.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import chromadb
|
||||
|
||||
CHROMA_HOST = "http://chroma.sirblob.co"
|
||||
COLLECTION_NAME = "rag_documents"
|
||||
|
||||
_client = None
|
||||
|
||||
def get_chroma_client():
|
||||
global _client
|
||||
if _client is None:
|
||||
_client = chromadb.HttpClient(host=CHROMA_HOST)
|
||||
return _client
|
||||
|
||||
def get_collection(collection_name=COLLECTION_NAME):
|
||||
client = get_chroma_client()
|
||||
return client.get_or_create_collection(name=collection_name)
|
||||
|
||||
def insert_documents(texts, embeddings, collection_name=COLLECTION_NAME, metadata_list=None):
|
||||
collection = get_collection(collection_name)
|
||||
|
||||
ids = [f"doc_{i}_{hash(text)}" for i, text in enumerate(texts)]
|
||||
|
||||
if metadata_list:
|
||||
collection.add(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=texts,
|
||||
metadatas=metadata_list
|
||||
)
|
||||
else:
|
||||
collection.add(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=texts
|
||||
)
|
||||
|
||||
return len(texts)
|
||||
|
||||
def search_documents(query_embedding, collection_name=COLLECTION_NAME, num_results=5, filter_metadata=None):
|
||||
collection = get_collection(collection_name)
|
||||
|
||||
query_params = {
|
||||
"query_embeddings": [query_embedding],
|
||||
"n_results": num_results
|
||||
}
|
||||
|
||||
if filter_metadata:
|
||||
query_params["where"] = filter_metadata
|
||||
|
||||
results = collection.query(**query_params)
|
||||
|
||||
output = []
|
||||
if results and results["documents"]:
|
||||
for i, doc in enumerate(results["documents"][0]):
|
||||
score = results["distances"][0][i] if "distances" in results else None
|
||||
output.append({
|
||||
"text": doc,
|
||||
"score": score
|
||||
})
|
||||
|
||||
return output
|
||||
|
||||
def delete_documents_by_source(source_file, collection_name=COLLECTION_NAME):
|
||||
collection = get_collection(collection_name)
|
||||
results = collection.get(where={"source": source_file})
|
||||
if results["ids"]:
|
||||
collection.delete(ids=results["ids"])
|
||||
return len(results["ids"])
|
||||
return 0
|
||||
Reference in New Issue
Block a user