Roboflow 是一个用于注释图像以用于对象检测 AI 的平台。
我将这个平台用于 C2SMR c2smr.fr,我的海上救援计算机视觉协会。
在本文中,我将向您展示如何使用这个平台并使用 python 训练您的模型。
您可以在我的 github 上找到更多示例代码:https://github.com/C2SMR/Detector
要创建数据集,请访问 https://app.roboflow.com/ 并开始注释您的图像,如下图所示。
在此示例中,我绕过所有游泳者来预测他们在未来图像中的位置。
为了获得良好的结果,请裁剪所有游泳者并将边界框放置在对象后面以正确包围它。
您已经可以使用公共 roboflow 数据集进行此检查 https://universe.roboflow.com/
在训练阶段,您可以直接使用 roboflow,但到了第三次您就需要付费,这就是为什么我向您展示如何使用笔记本电脑进行操作。
第一步是导入数据集。为此,您可以导入 Roboflow 库。
pip install roboflow
要创建模型,您需要使用YOLO算法,您可以使用ultralytics库导入该算法。
pip install ultralytics
在我的脚本中,我使用以下命令:
py train.py api-key project-workspace project-name project-version nb-epoch size_model
您必须获得:
最初,脚本会下载 yolov8-obb.pt,这是带有锻炼前数据的默认 yolo 权重,以方便训练。
import sys import os import random from roboflow import Roboflow from ultralytics import YOLO import yaml import time class Main: rf: Roboflow project: object dataset: object model: object results: object model_size: str def __init__(self): self.model_size = sys.argv[6] self.import_dataset() self.train() def import_dataset(self): self.rf = Roboflow(api_key=sys.argv[1]) self.project = self.rf.workspace(sys.argv[2]).project(sys.argv[3]) self.dataset = self.project.version(sys.argv[4]).download("yolov8-obb") with open(f'{self.dataset.location}/data.yaml', 'r') as file: data = yaml.safe_load(file) data['path'] = self.dataset.location with open(f'{self.dataset.location}/data.yaml', 'w') as file: yaml.dump(data, file, sort_keys=False) def train(self): list_of_models = ["n", "s", "m", "l", "x"] if self.model_size != "ALL" and self.model_size in list_of_models: self.model = YOLO(f"yolov8{self.model_size}-obb.pt") self.results = self.model.train(data=f"{self.dataset.location}/" f"yolov8-obb.yaml", epochs=int(sys.argv[5]), imgsz=640) elif self.model_size == "ALL": for model_size in list_of_models: self.model = YOLO(f"yolov8{model_size}.pt") self.results = self.model.train(data=f"{self.dataset.location}" f"/yolov8-obb.yaml", epochs=int(sys.argv[5]), imgsz=640) else: print("Invalid model size") if __name__ == '__main__': Main()
训练模型后,得到文件best.py和last.py,分别对应权重。
通过ultralytics库,您还可以导入YOLO并加载您的体重,然后加载您的测试视频。
在此示例中,我使用跟踪功能来获取每个游泳者的 ID。
import cv2 from ultralytics import YOLO import sys def main(): cap = cv2.VideoCapture(sys.argv[1]) model = YOLO(sys.argv[2]) while True: ret, frame = cap.read() results = model.track(frame, persist=True) res_plotted = results[0].plot() cv2.imshow("frame", res_plotted) if cv2.waitKey(1) == 27: break cap.release() cv2.destroyAllWindows() if __name__ == "__main__": main()
为了分析预测,可以获取模型json如下。
results = model.track(frame, persist=True) results_json = json.loads(results[0].tojson())
免责声明: 提供的所有资源部分来自互联网,如果有侵犯您的版权或其他权益,请说明详细缘由并提供版权或权益证明然后发到邮箱:[email protected] 我们会第一时间内为您处理。
Copyright© 2022 湘ICP备2022001581号-3