mirror of
https://github.com/SirBlobby/Hoya26.git
synced 2026-02-04 03:34:34 -05:00
hello
This commit is contained in:
@@ -1,9 +1,47 @@
|
||||
from dotenv import load_dotenv

load_dotenv()

import os
from flask import Flask, request, jsonify
from flask_cors import CORS

from src import create_app
from src.rag.gemeni import GeminiClient
from src.mongo import get_database

# NOTE(review): the original assigned `app = create_app()` and then immediately
# overwrote it with `Flask(__name__)`, silently discarding whatever create_app()
# configured. The Flask(__name__) app is the one that was actually served, so
# that behavior is kept; confirm whether create_app()'s app should be used instead.
app = Flask(__name__)
CORS(app)  # allow cross-origin requests from the frontend

# Initialize backend components once at startup. Defaulting both to None means
# a failed initialization no longer leaves `brain`/`db` undefined (which made
# /chat crash with a NameError instead of returning a useful error).
brain = None
db = None
try:
    brain = GeminiClient()
    db = get_database()
    print("--- Backend Components Initialized Successfully ---")
except Exception as e:
    print(f"CRITICAL ERROR during initialization: {e}")


@app.route('/')
def health_check():
    """Liveness probe: confirms the API process is up and serving."""
    return {
        "status": "online",
        "message": "The Waiter is ready at the counter!"
    }


@app.route('/chat', methods=['POST'])
def chat():
    """Answer a chat message via the Gemini client.

    Expects JSON {"message": str}. Returns {"status", "reply"} on success,
    400 for a missing message, 503 if initialization failed, 500 on API errors.
    """
    # silent=True tolerates missing or malformed JSON instead of raising.
    data = request.get_json(silent=True) or {}
    user_query = data.get("message")

    if not user_query:
        return jsonify({"error": "You didn't say anything!"}), 400

    if brain is None:
        # Initialization failed at startup; report it instead of crashing.
        return jsonify({
            "status": "error",
            "message": "Backend not initialized"
        }), 503

    try:
        context = ""  # RAG context retrieval not wired up yet
        ai_reply = brain.ask(user_query, context)
        return jsonify({
            "status": "success",
            "reply": ai_reply
        })
    except Exception as e:
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 500


if __name__ == "__main__":
    # The original had a second, unreachable app.run(debug=True, port=5000)
    # after this call; it has been removed.
    app.run(host='0.0.0.0', port=5000)
|
||||
@@ -1,7 +1,7 @@
|
||||
flask
|
||||
gunicorn
|
||||
ultralytics
|
||||
opencv-python-headless
|
||||
opencv-python
|
||||
transformers
|
||||
torch
|
||||
pandas
|
||||
|
||||
@@ -50,8 +50,8 @@ def populate_from_dataset(dataset_dir, category=None):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Populate vector database from dataset files")
|
||||
parser.add_argument("--category", "-c", type=str, help="Category to assign to ingested documents")
|
||||
parser.add_argument("--dir", "-d", type=str, default=None, help="Dataset directory path")
|
||||
parser.add_argument("--category", "-c", type=str)
|
||||
parser.add_argument("--dir", "-d", type=str, default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.dir:
|
||||
|
||||
@@ -17,9 +17,7 @@ def get_collection(collection_name=COLLECTION_NAME):
|
||||
|
||||
def insert_documents(texts, embeddings, collection_name=COLLECTION_NAME, metadata_list=None):
|
||||
collection = get_collection(collection_name)
|
||||
|
||||
ids = [f"doc_{i}_{hash(text)}" for i, text in enumerate(texts)]
|
||||
|
||||
if metadata_list:
|
||||
collection.add(
|
||||
ids=ids,
|
||||
@@ -33,22 +31,17 @@ def insert_documents(texts, embeddings, collection_name=COLLECTION_NAME, metadat
|
||||
embeddings=embeddings,
|
||||
documents=texts
|
||||
)
|
||||
|
||||
return len(texts)
|
||||
|
||||
def search_documents(query_embedding, collection_name=COLLECTION_NAME, num_results=5, filter_metadata=None):
|
||||
collection = get_collection(collection_name)
|
||||
|
||||
query_params = {
|
||||
"query_embeddings": [query_embedding],
|
||||
"n_results": num_results
|
||||
}
|
||||
|
||||
if filter_metadata:
|
||||
query_params["where"] = filter_metadata
|
||||
|
||||
results = collection.query(**query_params)
|
||||
|
||||
output = []
|
||||
if results and results["documents"]:
|
||||
for i, doc in enumerate(results["documents"][0]):
|
||||
@@ -57,7 +50,6 @@ def search_documents(query_embedding, collection_name=COLLECTION_NAME, num_resul
|
||||
"text": doc,
|
||||
"score": score
|
||||
})
|
||||
|
||||
return output
|
||||
|
||||
def delete_documents_by_source(source_file, collection_name=COLLECTION_NAME):
|
||||
|
||||
@@ -0,0 +1,474 @@
|
||||
import cv2
import numpy as np
import os
import json
from pathlib import Path
from typing import List, Dict, Optional, Tuple

# All paths are resolved relative to this file's directory.
CV_DIR = Path(__file__).parent
DATA_DIR = CV_DIR / "data"      # expected location of downloaded datasets
MODELS_DIR = CV_DIR / "models"  # expected location of detection models

# LogoDet-3K super-categories mapped to counts
# (presumably brand counts per category from the dataset's published
# statistics — TODO confirm against the dataset documentation).
SUPER_CATEGORIES = {
    "Food": 932,
    "Clothes": 604,
    "Necessities": 432,
    "Others": 371,
    "Electronic": 224,
    "Transportation": 213,
    "Leisure": 111,
    "Sports": 66,
    "Medical": 47
}

# Well-known brand names; used by LogoDetector._cache_brand_features to
# decide which brands get their ORB features pre-computed.
COMMON_BRANDS = [
    "McDonalds", "Starbucks", "CocaCola", "Pepsi", "KFC", "BurgerKing",
    "Subway", "DunkinDonuts", "PizzaHut", "Dominos", "Nestle", "Heineken",
    "Nike", "Adidas", "Puma", "UnderArmour", "Levis", "HM", "Zara", "Gap",
    "Gucci", "LouisVuitton", "Chanel", "Versace", "Prada", "Armani",
    "Apple", "Samsung", "HP", "Dell", "Intel", "AMD", "Nvidia", "Microsoft",
    "Sony", "LG", "Huawei", "Xiaomi", "Lenovo", "Asus", "Acer",
    "BMW", "Mercedes", "Audi", "Toyota", "Honda", "Ford", "Chevrolet",
    "Volkswagen", "Tesla", "Porsche", "Ferrari", "Lamborghini", "Nissan",
    "Google", "Facebook", "Twitter", "Instagram", "YouTube", "Amazon",
    "Netflix", "Spotify", "Uber", "Airbnb", "PayPal", "Visa", "Mastercard"
]
|
||||
|
||||
class LogoDet3KDataset:
    """Filesystem wrapper around a local copy of the LogoDet-3K dataset.

    The dataset is expected to be laid out as
    <root>/<super-category>/<brand>/<images>. If no explicit path is given,
    a few conventional install locations are probed.
    """

    def __init__(self, dataset_path: Optional[str] = None):
        self.dataset_path = None
        self.categories = {}
        self.brand_templates = {}

        if dataset_path and os.path.exists(dataset_path):
            self.dataset_path = Path(dataset_path)
        else:
            # Probe conventional locations; first existing one wins.
            candidates = (
                DATA_DIR / "LogoDet-3K",
                DATA_DIR / "logodet3k",
                Path.home() / "Downloads" / "LogoDet-3K",
                Path.home() / ".kaggle" / "datasets" / "lyly99" / "logodet3k",
            )
            self.dataset_path = next((p for p in candidates if p.exists()), None)

        if self.dataset_path:
            self._load_categories()
            print(f"LogoDet-3K dataset loaded from: {self.dataset_path}")
            print(f"Found {len(self.categories)} brand categories")
        else:
            print("LogoDet-3K dataset not found locally.")
            print("\nTo download the dataset:")
            print("1. Install kaggle CLI: pip install kaggle")
            print("2. Download: kaggle datasets download -d lyly99/logodet3k")
            print("3. Extract to:", DATA_DIR / "LogoDet-3K")

    def _load_categories(self):
        """Index every <super-category>/<brand> directory into self.categories."""
        if not self.dataset_path:
            return

        for super_cat in self.dataset_path.iterdir():
            # Skip files and hidden directories at the top level.
            if not super_cat.is_dir() or super_cat.name.startswith('.'):
                continue
            for brand_dir in super_cat.iterdir():
                if not brand_dir.is_dir():
                    continue
                self.categories[brand_dir.name] = {
                    "super_category": super_cat.name,
                    "path": brand_dir,
                    "images": list(brand_dir.glob("*.jpg")) + list(brand_dir.glob("*.png")),
                }

    def get_brand_templates(self, brand_name: str, max_templates: int = 5) -> List[np.ndarray]:
        """Load up to *max_templates* readable template images for *brand_name*.

        Unknown brands and unreadable files yield no entries.
        """
        if brand_name not in self.categories:
            return []

        loaded = []
        for img_path in self.categories[brand_name]["images"][:max_templates]:
            image = cv2.imread(str(img_path))
            if image is not None:
                loaded.append(image)

        return loaded

    def get_all_brands(self) -> List[str]:
        """Return the names of every indexed brand."""
        return list(self.categories.keys())

    def get_brands_by_category(self, super_category: str) -> List[str]:
        """Return brands whose super-category matches (case-insensitive)."""
        wanted = super_category.lower()
        return [
            name for name, info in self.categories.items()
            if info["super_category"].lower() == wanted
        ]
|
||||
|
||||
class LogoDetector:
    """Logo detector with three fallback strategies, tried in order by detect():

    1. A cv2.dnn model (ONNX or other readable format), if one was loaded.
    2. ORB feature matching against cached LogoDet-3K brand templates.
    3. A generic contour/corner heuristic flagging "logo-like" regions.
    """

    def __init__(self,
                 model_path: Optional[str] = None,
                 dataset_path: Optional[str] = None,
                 use_gpu: bool = True):
        """Build the detector and eagerly load the model and brand features.

        model_path: optional path to a cv2.dnn-compatible detection model.
        dataset_path: optional path to a local LogoDet-3K copy.
        use_gpu: attempt OpenCL acceleration for the DNN backend.
        """
        self.model_path = model_path
        self.use_gpu = use_gpu
        self.net = None
        self.dataset = LogoDet3KDataset(dataset_path)

        # Default minimum detection confidence and NMS overlap threshold.
        self.conf_threshold = 0.3
        self.nms_threshold = 0.4

        # ORB + brute-force Hamming matcher drive the template-matching path.
        self.orb = cv2.ORB_create(nfeatures=1000)
        self.bf_matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)

        # SIFT/FLANN availability depends on the OpenCV build; note that
        # neither is used by the detection paths below — kept for future use.
        try:
            self.sift = cv2.SIFT_create()
            self.flann_matcher = cv2.FlannBasedMatcher(
                {"algorithm": 1, "trees": 5},
                {"checks": 50}
            )
        except:
            self.sift = None
            self.flann_matcher = None

        # brand name -> list of (keypoints, descriptors) tuples.
        self.brand_features = {}
        self._load_model()
        self._cache_brand_features()

    def _load_model(self):
        """Load the DNN model from self.model_path into self.net, if present.

        On any failure self.net stays None and detect() falls back to
        feature matching / heuristics.
        """
        if not self.model_path or not os.path.exists(self.model_path):
            return

        try:
            print(f"Loading model: {self.model_path}")

            if self.model_path.endswith('.onnx'):
                self.net = cv2.dnn.readNetFromONNX(self.model_path)
            else:
                self.net = cv2.dnn.readNet(self.model_path)

            if self.use_gpu:
                try:
                    # OpenCL target; falls back to CPU if unavailable.
                    self.net.setPreferableBackend(cv2.dnn.DNN_BACKEND_DEFAULT)
                    self.net.setPreferableTarget(cv2.dnn.DNN_TARGET_OPENCL)
                    print("✅ Using OpenCL GPU acceleration")
                except:
                    print("⚠️ GPU not available, using CPU")

            print("Model loaded successfully!")

        except Exception as e:
            print(f"Failed to load model: {e}")
            self.net = None

    def _cache_brand_features(self):
        """Pre-compute ORB features for up to 50 COMMON_BRANDS found in the dataset."""
        if not self.dataset.categories:
            return

        print("Caching brand features (this may take a moment)...")

        # Only cache well-known brands that actually exist in the local dataset.
        brands_to_cache = [b for b in COMMON_BRANDS if b in self.dataset.categories][:50]

        for brand in brands_to_cache:
            templates = self.dataset.get_brand_templates(brand, max_templates=3)
            if templates:
                features = []
                for tmpl in templates:
                    gray = cv2.cvtColor(tmpl, cv2.COLOR_BGR2GRAY)
                    kp, des = self.orb.detectAndCompute(gray, None)
                    # detectAndCompute returns None descriptors on featureless images.
                    if des is not None:
                        features.append((kp, des))

                if features:
                    self.brand_features[brand] = features

        print(f"Cached features for {len(self.brand_features)} brands")

    def detect(self, frame: np.ndarray, conf_threshold: Optional[float] = None) -> List[Dict]:
        """Detect logos in a BGR frame, trying each strategy until one hits.

        Returns a list of dicts with at least "bbox" (x1, y1, x2, y2),
        "label" and "confidence" keys.
        """
        if conf_threshold is None:
            conf_threshold = self.conf_threshold

        detections = []

        # 1) DNN model, if loaded.
        if self.net is not None:
            detections = self._detect_with_model(frame, conf_threshold)

        # 2) ORB template matching against cached brand features.
        if not detections and self.brand_features:
            detections = self._detect_with_features(frame, conf_threshold)

        # 3) Last resort: generic logo-like region heuristic.
        if not detections:
            detections = self._detect_logo_regions(frame)

        return detections

    def _detect_with_model(self, frame: np.ndarray, conf_threshold: float) -> List[Dict]:
        """Run the DNN on a 640x640 blob and NMS-filter the raw detections.

        Assumes YOLO-style rows (cx, cy, w, h, score[, class scores...]) in
        640-space — TODO confirm against the exported model's output layout.
        """
        height, width = frame.shape[:2]

        blob = cv2.dnn.blobFromImage(
            frame,
            scalefactor=1/255.0,
            size=(640, 640),
            swapRB=True,
            crop=False
        )

        self.net.setInput(blob)

        try:
            output_names = self.net.getUnconnectedOutLayersNames()
            outputs = self.net.forward(output_names)
        except:
            # Some backends reject named forward; fall back to default output.
            outputs = [self.net.forward()]

        detections = []
        boxes = []
        confidences = []
        class_ids = []

        for output in outputs:
            if len(output.shape) == 3:
                # Drop the batch dimension.
                output = output[0]

            for detection in output:
                if len(detection) < 5:
                    continue

                # With >5 values, columns 4+ are per-class scores; otherwise
                # column 4 is a single objectness/confidence score.
                scores = detection[4:] if len(detection) > 5 else [detection[4]]
                class_id = np.argmax(scores) if len(scores) > 1 else 0
                confidence = float(scores[class_id]) if len(scores) > 1 else float(scores[0])

                if confidence > conf_threshold:
                    cx, cy, w, h = detection[:4]
                    # Map from the 640x640 network space back to frame pixels.
                    scale_x = width / 640
                    scale_y = height / 640

                    x1 = int((cx - w/2) * scale_x)
                    y1 = int((cy - h/2) * scale_y)
                    x2 = int((cx + w/2) * scale_x)
                    y2 = int((cy + h/2) * scale_y)

                    # NMSBoxes expects (x, y, w, h) boxes.
                    boxes.append([x1, y1, x2-x1, y2-y1])
                    confidences.append(confidence)
                    class_ids.append(class_id)

        if boxes:
            indices = cv2.dnn.NMSBoxes(boxes, confidences, conf_threshold, self.nms_threshold)
            for i in indices:
                # Older OpenCV versions return nested indices ([i]), newer flat ints.
                idx = i[0] if isinstance(i, (list, tuple, np.ndarray)) else i
                x, y, w, h = boxes[idx]
                detections.append({
                    "bbox": (x, y, x + w, y + h),
                    "label": f"Logo-{class_ids[idx]}",
                    "confidence": confidences[idx],
                    "class_id": class_ids[idx]
                })

        return detections

    def _detect_with_features(self, frame: np.ndarray, conf_threshold: float) -> List[Dict]:
        """Match cached brand ORB descriptors against the frame.

        Returns at most the 5 highest-confidence brand matches; confidence is
        derived from the mean Hamming distance of the good matches.
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        kp_frame, des_frame = self.orb.detectAndCompute(gray, None)

        # Too few keypoints means matching would be meaningless.
        if des_frame is None or len(kp_frame) < 10:
            return []

        detections = []
        best_matches = []

        for brand, feature_list in self.brand_features.items():
            for kp_tmpl, des_tmpl in feature_list:
                try:
                    matches = self.bf_matcher.match(des_tmpl, des_frame)
                    matches = sorted(matches, key=lambda x: x.distance)
                    # Keep the 50 closest matches under a distance cutoff.
                    good_matches = [m for m in matches[:50] if m.distance < 60]

                    if len(good_matches) >= 8:
                        # Bounding box = extent of matched frame keypoints.
                        pts = np.float32([kp_frame[m.trainIdx].pt for m in good_matches])
                        if len(pts) > 0:
                            x_min, y_min = pts.min(axis=0).astype(int)
                            x_max, y_max = pts.max(axis=0).astype(int)
                            avg_dist = np.mean([m.distance for m in good_matches])
                            # Lower mean distance -> higher confidence, floored at 0.3.
                            confidence = max(0.3, 1.0 - (avg_dist / 100))

                            if confidence >= conf_threshold:
                                best_matches.append({
                                    "bbox": (x_min, y_min, x_max, y_max),
                                    "label": brand,
                                    "confidence": confidence,
                                    "match_count": len(good_matches)
                                })
                except Exception:
                    # A bad template/descriptor pair should not abort the scan.
                    continue

        if best_matches:
            best_matches.sort(key=lambda x: x["confidence"], reverse=True)
            detections = best_matches[:5]

        return detections

    def _detect_logo_regions(self, frame: np.ndarray) -> List[Dict]:
        """Heuristically flag logo-like regions via edges, contours and corners.

        Filters contours by area (1%-15% of frame), aspect ratio, solidity,
        corner count and edge density; returns the top 3 by a weighted score.
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        edges = cv2.Canny(blurred, 80, 200)

        # Close small gaps in the edge map before contour extraction.
        kernel = np.ones((3, 3), np.uint8)
        edges = cv2.dilate(edges, kernel, iterations=1)
        edges = cv2.erode(edges, kernel, iterations=1)

        contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        detections = []
        height, width = frame.shape[:2]
        min_area = (width * height) * 0.01
        max_area = (width * height) * 0.15

        for contour in contours:
            area = cv2.contourArea(contour)
            if area < min_area or area > max_area:
                continue

            x, y, w, h = cv2.boundingRect(contour)
            aspect_ratio = w / h if h > 0 else 0

            # Logos are assumed roughly square-ish (0.5 to 2.0 aspect).
            if aspect_ratio < 0.5 or aspect_ratio > 2.0:
                continue

            # Solidity: how much of the convex hull the contour fills.
            hull = cv2.convexHull(contour)
            hull_area = cv2.contourArea(hull)
            solidity = area / hull_area if hull_area > 0 else 0

            if solidity < 0.3:
                continue

            roi = gray[y:y+h, x:x+w]
            if roi.size == 0:
                continue

            # Logos tend to have many strong corners (text, shapes).
            corners = cv2.goodFeaturesToTrack(roi, 50, 0.01, 5)
            if corners is None or len(corners) < 15:
                continue

            roi_edges = edges[y:y+h, x:x+w]
            edge_density = np.sum(roi_edges > 0) / (w * h) if (w * h) > 0 else 0

            # Too few edges = flat region; too many = texture/noise.
            if edge_density < 0.05 or edge_density > 0.5:
                continue

            # Weighted combination of the three quality signals.
            corner_score = min(1.0, len(corners) / 40)
            solidity_score = solidity
            aspect_score = 1.0 - abs(1.0 - aspect_ratio) / 2

            confidence = (corner_score * 0.4 + solidity_score * 0.3 + aspect_score * 0.3)

            if confidence >= 0.6:
                detections.append({
                    "bbox": (x, y, x + w, y + h),
                    "label": "Potential Logo",
                    "confidence": confidence,
                    "class_id": -1  # -1 marks heuristic (non-classified) hits
                })

        detections.sort(key=lambda x: x["confidence"], reverse=True)
        return detections[:3]

    def draw_detections(self, frame: np.ndarray, detections: List[Dict]) -> np.ndarray:
        """Return a copy of *frame* with labeled boxes drawn for each detection.

        Box color encodes confidence: green > 0.7, yellow > 0.5, orange otherwise.
        """
        result = frame.copy()

        for det in detections:
            x1, y1, x2, y2 = det["bbox"]
            label = det["label"]
            conf = det["confidence"]

            if conf > 0.7:
                color = (0, 255, 0)
            elif conf > 0.5:
                color = (0, 255, 255)
            else:
                color = (0, 165, 255)

            cv2.rectangle(result, (x1, y1), (x2, y2), color, 2)
            label_text = f"{label}: {conf:.2f}"
            # Filled background behind the label for readability.
            (text_w, text_h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            cv2.rectangle(result, (x1, y1 - text_h - 6), (x1 + text_w + 4, y1), color, -1)
            cv2.putText(result, label_text, (x1 + 2, y1 - 4),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

        return result
|
||||
|
||||
def start_scanner(model_path: Optional[str] = None,
                  dataset_path: Optional[str] = None,
                  use_gui: bool = True):
    """Run the live webcam logo scanner until 'q' is pressed or the feed ends.

    model_path: optional DNN model for LogoDetector.
    dataset_path: optional LogoDet-3K path for template matching.
    use_gui: show an OpenCV window; on headless failure the annotated feed
        is written to output.mp4 instead.
    """
    print("=" * 60)
    print("LogoDet-3K Logo Scanner")
    print("3,000 logo categories | 9 super-categories | 200K+ objects")
    print("=" * 60)

    detector = LogoDetector(
        model_path=model_path,
        dataset_path=dataset_path,
        use_gpu=True
    )

    cap = cv2.VideoCapture(0)  # default webcam
    if not cap.isOpened():
        print("\nError: Could not access camera.")
        return

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # some backends report 0 fps

    writer = None  # created lazily only if GUI display fails
    output_path = CV_DIR / "output.mp4"

    print(f"\n📷 Camera: {width}x{height} @ {fps:.1f}fps")
    print("Press 'q' to quit\n")

    frame_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            detections = detector.detect(frame)
            result_frame = detector.draw_detections(frame, detections)

            # HUD: detection count and frame counter, top-left.
            info_text = f"Logos: {len(detections)} | Frame: {frame_count}"
            cv2.putText(result_frame, info_text, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)

            if use_gui:
                try:
                    cv2.imshow('LogoDet-3K Scanner', result_frame)
                    key = cv2.waitKey(1) & 0xFF
                    if key == ord('q'):
                        break
                    elif key == ord('s'):
                        # 's' saves a screenshot next to this module.
                        cv2.imwrite(str(CV_DIR / f"screenshot_{frame_count}.jpg"), result_frame)
                    except cv2.error:
                    # Headless environment: switch to writing a video file.
                    use_gui = False
                    writer = cv2.VideoWriter(
                        str(output_path),
                        cv2.VideoWriter_fourcc(*'mp4v'),
                        fps,
                        (width, height)
                    )

            if not use_gui and writer:
                writer.write(result_frame)
    except KeyboardInterrupt:
        pass  # Ctrl-C is a normal way to stop the scanner
    finally:
        cap.release()
        if writer:
            writer.release()
        cv2.destroyAllWindows()
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    # CLI entry point for the live scanner.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", "-m", type=str)
    cli.add_argument("--dataset", "-d", type=str)
    cli.add_argument("--no-gui", action="store_true")
    opts = cli.parse_args()

    start_scanner(
        model_path=opts.model,
        dataset_path=opts.dataset,
        use_gui=not opts.no_gui,
    )
|
||||
@@ -5,7 +5,6 @@ def generate_content(prompt, model_name="gemini-2.0-flash-exp"):
|
||||
api_key = os.environ.get("GOOGLE_API_KEY")
|
||||
if not api_key:
|
||||
return "Error: GOOGLE_API_KEY not found."
|
||||
|
||||
try:
|
||||
client = genai.Client(api_key=api_key)
|
||||
response = client.models.generate_content(
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
import os
from pymongo import MongoClient
from dotenv import load_dotenv

# Load MONGO_URI from the .env shared with the rag package
# (../rag/.env relative to this file).
script_dir = os.path.dirname(os.path.abspath(__file__))
env_path = os.path.join(script_dir, '..', 'rag', '.env')
load_dotenv(env_path)
|
||||
|
||||
def get_database():
    """Connect to MongoDB using MONGO_URI and return the "my_rag_app" database.

    Returns None (after printing the error) if the URI is missing or the
    server is unreachable.
    """
    uri = os.getenv("MONGO_URI")
    if not uri:
        # Previously a missing URI made MongoClient default to localhost.
        print("ERROR: Could not connect to MongoDB: MONGO_URI is not set")
        return None
    try:
        client = MongoClient(uri)
        # MongoClient connects lazily, so the original printed SUCCESS even
        # when the server was unreachable; ping forces a real round-trip.
        client.admin.command("ping")
        db = client["my_rag_app"]
        print("SUCCESS: Connected to MongoDB Atlas!")
        return db
    except Exception as e:
        print(f"ERROR: Could not connect to MongoDB: {e}")
        return None
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: attempt a connection and print the outcome.
    get_database()
|
||||
40
backend/src/rag/gemeni.py
Normal file
40
backend/src/rag/gemeni.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import os
from google import genai
from dotenv import load_dotenv

# Load GOOGLE_API_KEY from the .env sitting next to this module.
script_dir = os.path.dirname(os.path.abspath(__file__))
load_dotenv(os.path.join(script_dir, '.env'))
|
||||
|
||||
class GeminiClient:
    """Thin wrapper around the google-genai client for single-turn prompts."""

    def __init__(self):
        # The key is expected to have been loaded from .env at import time.
        self.api_key = os.getenv("GOOGLE_API_KEY")

        if not self.api_key:
            raise ValueError("No GOOGLE_API_KEY found in .env file!")

        self.client = genai.Client(api_key=self.api_key)
        self.model_name = "gemini-2.0-flash"

    def ask(self, prompt, context=""):
        """Send *prompt* (optionally grounded in *context*) and return the reply text.

        Any failure is reported as an error string rather than raised, so
        callers never have to wrap this in try/except themselves.
        """
        try:
            message = (
                f"Use this information to answer: {context}\n\nQuestion: {prompt}"
                if context else prompt
            )

            result = self.client.models.generate_content(
                model=self.model_name,
                contents=message
            )
            return result.text

        except Exception as e:
            return f"Error talking to Gemini: {str(e)}"
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: requires GOOGLE_API_KEY in the .env next to this file.
    try:
        brain = GeminiClient()
        print("--- Testing Class Connection ---")
        print(brain.ask("Hello! Give me a 1-sentence coding tip."))
    except Exception as e:
        print(f"Failed to start Gemini: {e}")
|
||||
@@ -8,7 +8,6 @@ def chunk_text(text, target_length=2000, overlap=100):
|
||||
return []
|
||||
|
||||
chunks = []
|
||||
|
||||
paragraphs = text.split('\n\n')
|
||||
current_chunk = ""
|
||||
|
||||
@@ -56,7 +55,6 @@ def load_pdf(file_path):
|
||||
def load_txt(file_path):
    """Read a UTF-8 text file and return it split into chunks via chunk_text()."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        raw_text = handle.read()

    return chunk_text(raw_text)
|
||||
|
||||
def load_xlsx(file_path):
|
||||
@@ -69,19 +67,15 @@ def load_xlsx(file_path):
|
||||
for sheet_name, df in sheets.items():
|
||||
if df.empty:
|
||||
continue
|
||||
|
||||
df = df.fillna("")
|
||||
|
||||
for row in df.values:
|
||||
row_items = [str(x) for x in row if str(x).strip() != ""]
|
||||
if row_items:
|
||||
row_str = f"Sheet: {str(sheet_name)} | " + " | ".join(row_items)
|
||||
|
||||
if len(row_str) > 8000:
|
||||
all_rows.extend(chunk_text(row_str))
|
||||
else:
|
||||
all_rows.append(row_str)
|
||||
|
||||
return all_rows
|
||||
|
||||
def process_file(file_path):
|
||||
|
||||
Reference in New Issue
Block a user