added the model

This commit is contained in:
samarthjain2023
2025-09-27 12:14:26 -04:00
parent 4c84ab1f67
commit 0df2b0019b
28 changed files with 1633 additions and 1 deletions

30
roadcast/inference.py Normal file
View File

@@ -0,0 +1,30 @@
import os
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from models import create_model
def load_model(path, device=None, in_channels=3, num_classes=10):
device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load(path, map_location=device)
model = create_model(device=device, in_channels=in_channels, num_classes=num_classes)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
class_to_idx = checkpoint.get('class_to_idx')
idx_to_class = {v: k for k, v in class_to_idx.items()} if class_to_idx else None
return model, idx_to_class
def predict_image(model, img_path, device=None):
device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
preprocess = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
img = Image.open(img_path).convert('RGB')
x = preprocess(img).unsqueeze(0).to(device)
with torch.no_grad():
logits = model(x)
probs = F.softmax(logits, dim=1)
conf, idx = torch.max(probs, dim=1)
return int(idx.item()), float(conf.item())