153 lines
5.5 KiB
Python
153 lines
5.5 KiB
Python
import glob
|
|
import os
|
|
import pytesseract
|
|
import re
|
|
import torch
|
|
import ultralytics
|
|
from ultralytics import YOLO
|
|
from PIL import Image
|
|
import logging
|
|
import time
|
|
logger = logging.getLogger('lib')
|
|
#ultralytics.checks()
|
|
from xml.etree import ElementTree as ET
|
|
from .consts import BASE_PATH, FRAME_RATE
|
|
from pytesseract import TesseractError
|
|
|
|
# Today's date as YYYY-MM-DD; used to name the daily log file in the
# (currently disabled) logging configuration below.
daily_log_ts = time.strftime("%Y-%m-%d")

# logging.basicConfig(
#     filename=f"{BASE_PATH}/logs/main-yolo-tesseract-{daily_log_ts}.log",
#     level=logging.DEBUG,
#     format='%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s'
# )

# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
def convert_bbox_to_yolo(
    size: tuple[float, float], box: tuple[float, float, float, float]
) -> tuple[float, float, float, float]:
    """Turn an absolute corner box into a YOLO-style normalised centre box.

    :param size: ``(width, height)`` of the image in pixels.
    :param box: ``(xmin, ymin, xmax, ymax)`` corners of the box in pixels.
    :return: ``(x_center, y_center, width, height)``, each divided by the
        matching image dimension.
    :raises ZeroDivisionError: if either image dimension is zero.
    """
    # Reciprocals first (same evaluation order as a multiply-by-scale form),
    # so a zero dimension fails before any box value is touched.
    inv_w = 1.0 / size[0]
    inv_h = 1.0 / size[1]

    xmin, ymin, xmax, ymax = box[0], box[1], box[2], box[3]

    return (
        (xmin + xmax) / 2.0 * inv_w,
        (ymin + ymax) / 2.0 * inv_h,
        (xmax - xmin) * inv_w,
        (ymax - ymin) * inv_h,
    )
|
|
def xml_to_txt(input_xml: str, output_txt: str, class_mapping: dict[str, int]):
    """Convert one Pascal VOC-style XML annotation into a YOLO label file.

    The image size is read from ``<size>/<width>`` and ``<size>/<height>``.
    Each ``<object>`` whose ``<name>`` is absent from *class_mapping* (or maps
    to ``-1``) is skipped; every other object produces one line
    ``"<class_id> <x_center> <y_center> <width> <height>"`` with the box
    normalised by :func:`convert_bbox_to_yolo`.

    :param input_xml: Path of the XML annotation to read.
    :param output_txt: Path of the label file to (over)write.
    :param class_mapping: Maps class names to the class ID written to the output.
    """
    root = ET.parse(input_xml).getroot()
    img_size = (
        int(root.find(".//size/width").text),
        int(root.find(".//size/height").text),
    )

    with open(output_txt, "w", encoding="utf-8") as out:
        for obj in root.iter("object"):
            class_id = class_mapping.get(obj.find("name").text, -1)
            # -1 is the "unknown class" sentinel: such objects get no label line.
            if class_id == -1:
                continue

            bndbox = obj.find("bndbox")
            corners = tuple(
                float(bndbox.find(tag).text)
                for tag in ("xmin", "ymin", "xmax", "ymax")
            )
            yolo_box = convert_bbox_to_yolo(img_size, corners)
            out.write(f"{class_id} {' '.join(map(str, yolo_box))}\n")
|
|
|
|
|
|
def download_kaggle_data():
    """Fetch the Indian number-plate dataset from Kaggle via ``kagglehub``.

    Nothing is downloaded when ``{BASE_PATH}/data/indian-number-plates-dataset``
    already exists. Otherwise the latest dataset version is downloaded and the
    path returned by ``kagglehub.dataset_download`` is printed.
    """
    import kagglehub

    # NOTE(review): the download below does not target this directory (its
    # location is whatever kagglehub returns), so this guard presumably only
    # skips when the data was placed here by hand — confirm.
    local_copy = f"{BASE_PATH}/data/indian-number-plates-dataset"
    if os.path.exists(local_copy):
        return

    # Download latest version
    dataset_dir = kagglehub.dataset_download("dataclusterlabs/indian-number-plates-dataset")
    print("Path to dataset files:", dataset_dir)
|
|
|
|
def convert_to_yolo_format(data_path, class_mapping=None):
    """Write a YOLO ``.txt`` label next to every ``.xml`` annotation in a directory.

    Each ``<stem>.xml`` directly inside *data_path* is converted to
    ``<stem>.txt`` with :func:`xml_to_txt`. The number of XML files found and
    the number of ``.txt`` files present afterwards are printed as a quick
    sanity check.

    :param data_path: Directory holding the XML annotation files.
    :param class_mapping: Class-name -> class-ID mapping passed to
        :func:`xml_to_txt`; defaults to ``{'number_plate': 0}``.
    """
    if class_mapping is None:
        # An int ID matches xml_to_txt's ``dict[str, int]`` contract; the
        # f-string in xml_to_txt writes the label as "0" in the output file.
        class_mapping = {'number_plate': 0}

    xml_files = glob.glob(os.path.join(data_path, '*.xml'))
    print(len(xml_files))

    for xml_fil in xml_files:
        # '*.xml' can also match a directory with that suffix; skip those.
        if os.path.isdir(xml_fil):
            continue
        txt_fil = os.path.splitext(xml_fil)[0] + '.txt'
        xml_to_txt(xml_fil, txt_fil, class_mapping=class_mapping)

    print(len(glob.glob(os.path.join(data_path, '*.txt'))))
|
|
|
|
def train_model(epochs=50, imgsz=1728):
    """Fine-tune the YOLOv8n detector on the plate data and export it to ONNX.

    Steps:
      1. Build a ``YOLODataset`` over ``{BASE_PATH}/data/train/images`` and call
         ``get_labels()`` on it. The dataset object is not used for training
         afterwards (``model.train`` reads ``data.yaml``), so this presumably
         serves as a label sanity check.
      2. Train ``{BASE_PATH}/models/yolov8n.pt`` on ``{BASE_PATH}/data/data.yaml``.
      3. Export the model to ONNX (dynamic input shapes, simplified graph) on
         the module-level ``device``.

    :param epochs: Number of training epochs (default 50).
    :param imgsz: Training image size passed to ``model.train`` (default 1728).
    :return: The object returned by ``model.train``.
    """
    from ultralytics.data.dataset import YOLODataset

    # Single class, named after the number-plate class used elsewhere in this
    # module; only the class count matters for loading the labels here.
    dataset = YOLODataset(
        img_path=f"{BASE_PATH}/data/train/images",
        data={"names": {0: "number_plate"}},
        task="detect",
    )
    dataset.get_labels()

    model = YOLO(f"{BASE_PATH}/models/yolov8n.pt")
    results = model.train(data=f'{BASE_PATH}/data/data.yaml', epochs=epochs, imgsz=imgsz)

    model.export(format='onnx', dynamic=True, simplify=True, device=device)

    return results
|
|
|
|
def infer(filename):
    """Run the trained plate detector on *filename* and save annotated frames.

    Results are enumerated in order, and every result whose index ``i``
    satisfies ``i % FRAME_RATE == 0`` is drawn with its detections and
    saved as ``{BASE_PATH}/images/<vidname>/frame-<i>.jpg``. Here
    ``<vidname>`` is the basename of *filename* up to its first ``.``.

    :param filename: Source passed to ``YOLO.predict``, e.g. a video or image path.
    :return: ``{'frames_found': [...]}`` with the indices of the saved frames.
    """
    model2 = YOLO(f"{BASE_PATH}/models/train25/best.pt")
    # stream=True yields Results lazily instead of keeping every frame's result
    # (including its image) in memory at once, which matters for long videos.
    test_result = model2.predict(source=filename, stream=True)

    vidname = os.path.basename(filename).split('.')[0]
    out_dir = f"{BASE_PATH}/images/{vidname}"
    # makedirs also creates a missing images/ parent and tolerates the
    # directory appearing between a check and the create.
    os.makedirs(out_dir, exist_ok=True)

    number_plates = {'frames_found': []}
    n_results = 0
    for i, res in enumerate(test_result):
        n_results += 1
        # Only every FRAME_RATE-th result is kept, so skip plotting the rest.
        if i % FRAME_RATE != 0:
            continue

        logger.debug(f"Saving result number {i} for file {vidname}")
        number_plates['frames_found'].append(i)

        # Results.plot() returns a BGR array; reverse the channel axis so PIL,
        # which assumes RGB, does not save the frame with red/blue swapped.
        plate_im = Image.fromarray(res.plot()[..., ::-1])
        plate_im.save(f"{out_dir}/frame-{i}.jpg")

        # OCR of the plate text is currently disabled:
        # try:
        #     np_text = pytesseract.image_to_string(plate_im, lang='eng')
        #     plate = str("".join(re.split("[^a-zA-Z0-9]*", np_text)))
        #     number_plates[i] = plate.upper()
        # except TesseractError as TE:
        #     logger.error("Failed to extract numberplate from the detected image")
        #     continue
        # except Exception as e:
        #     logger.error("Failed to extract numberplate from the detected image")
        #     continue

    logger.debug(n_results)
    return number_plates
|
|
|
|
#download_kaggle_data()
|
|
#data_path = '../data/train/labels'
|
|
#convert_to_yolo_format(data_path)
|
|
#data_path = '../data/TEST/labels'
|
|
#convert_to_yolo_format(data_path)
|
|
#train_model()
|