mirror of
https://github.com/SirBlobby/Hoya26.git
synced 2026-02-04 03:34:34 -05:00
62 lines
1.9 KiB
Python
62 lines
1.9 KiB
Python
import hashlib

import chromadb
|
|
|
|
# Address of the Chroma server handed to chromadb.HttpClient.
CHROMA_HOST = "http://chroma.sirblob.co"

# Collection used whenever a caller does not name one explicitly.
COLLECTION_NAME = "rag_documents"

# Module-wide client cache; stays None until get_chroma_client() fills it.
_client = None
|
|
|
|
def get_chroma_client():
    """Return the module-wide Chroma HTTP client, building it on first use.

    The client is stored in the module-level ``_client`` variable, so every
    later call hands back the same object instead of constructing a new one.
    """
    global _client

    # Fast path: a client was already created by an earlier call.
    if _client is not None:
        return _client

    _client = chromadb.HttpClient(host=CHROMA_HOST)
    return _client
|
|
|
|
def get_collection(collection_name=COLLECTION_NAME):
    """Return the collection called *collection_name* from the shared client.

    Delegates to the client's ``get_or_create_collection``, so the lookup and
    the create-if-missing step are a single call.

    Args:
        collection_name: Name of the collection; defaults to ``COLLECTION_NAME``.
    """
    return get_chroma_client().get_or_create_collection(name=collection_name)
|
|
|
|
def insert_documents(texts, embeddings, collection_name=COLLECTION_NAME, metadata_list=None):
    """Add documents and their embeddings to a Chroma collection.

    Every document is stored under an ID of the form ``doc_<index>_<digest>``,
    where ``<digest>`` is the first 16 hex characters of the SHA-256 hash of
    the text. The built-in ``hash()`` is salted per interpreter process for
    strings, so it cannot be used for IDs that must be reproducible between
    runs; a content digest gives the same ID for the same (index, text) pair
    every time.

    Args:
        texts: Sized sequence of document strings to store.
        embeddings: Embeddings forwarded unchanged to ``collection.add``
            (presumably one per entry in ``texts`` -- confirm with callers).
        collection_name: Target collection; defaults to ``COLLECTION_NAME``.
        metadata_list: Optional metadata forwarded as ``metadatas``. Left out
            of the request entirely when falsy.

    Returns:
        The number of documents submitted, i.e. ``len(texts)``; ``0`` when
        ``texts`` is empty, in which case the server is not contacted.
    """
    # Nothing to store: skip the round-trip to the server.
    if not texts:
        return 0

    collection = get_collection(collection_name)

    # The index keeps IDs distinct when the same text appears twice in a batch.
    ids = [
        "doc_{}_{}".format(
            index,
            hashlib.sha256(text.encode("utf-8", errors="surrogatepass")).hexdigest()[:16],
        )
        for index, text in enumerate(texts)
    ]

    add_kwargs = {
        "ids": ids,
        "embeddings": embeddings,
        "documents": texts,
    }
    # Only send metadata when the caller supplied some.
    if metadata_list:
        add_kwargs["metadatas"] = metadata_list

    collection.add(**add_kwargs)
    return len(texts)
|
|
|
|
def search_documents(query_embedding, collection_name=COLLECTION_NAME, num_results=5, filter_metadata=None):
    """Query a collection with one embedding and return the matching texts.

    Args:
        query_embedding: The embedding to search with; sent as a
            single-element ``query_embeddings`` list.
        collection_name: Collection to search; defaults to ``COLLECTION_NAME``.
        num_results: Passed to the query as ``n_results``.
        filter_metadata: Optional filter passed as ``where``; omitted from the
            query when falsy.

    Returns:
        A list of ``{"text": ..., "score": ...}`` dicts, one per document in
        the first result row. ``score`` is the matching entry of the result's
        ``distances`` row, or ``None`` when the result has no ``distances``
        key. The list is empty when the query yields no documents.
    """
    collection = get_collection(collection_name)

    # The "where" key is only present when a filter was actually given.
    optional_filter = {"where": filter_metadata} if filter_metadata else {}
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=num_results,
        **optional_filter,
    )

    if not results or not results["documents"]:
        return []

    has_distances = "distances" in results
    # Only one query embedding is sent, so only row 0 of each field is read.
    return [
        {
            "text": document,
            "score": results["distances"][0][position] if has_distances else None,
        }
        for position, document in enumerate(results["documents"][0])
    ]
|
|
|
|
def delete_documents_by_source(source_file, collection_name=COLLECTION_NAME):
    """Delete every document whose ``source`` metadata equals *source_file*.

    Args:
        source_file: Value matched against the ``source`` metadata field.
        collection_name: Collection to delete from; defaults to
            ``COLLECTION_NAME``.

    Returns:
        The number of documents deleted, or ``0`` when nothing matched.
    """
    collection = get_collection(collection_name)
    matching_ids = collection.get(where={"source": source_file})["ids"]

    # No matches: skip the delete call.
    if not matching_ids:
        return 0

    collection.delete(ids=matching_ids)
    return len(matching_ids)
|