yolov5+opencv+java:通过DJL在maven项目中使用yolov5的小demo

2年前 (2022) 程序员胖胖胖虎阿
375 0 0

目录

  • 前言
  • 环境
  • 导出yolov5s模型
  • 编写Maven项目
    • 编写pom.xml文件
    • 引入opencv依赖
      • 下载opencv
      • 获取opencv的jar包和动态链接库dll文件
      • 将lib文件夹添加为Library
    • 将yolov5权重文件放到资源文件
    • 编写代码
  • 运行程序
  • 补充
  • MyUtils.mat2Image

前言

这篇博客主要是介绍如何通过djl在java中调用yolov5进行推理,顺便也学习了一下在java上的opencv api。
Deep Java Library是由亚马逊(Amazon)提供的一个深度学习工具包,能够让java开发者在java上调用目前主流的深度学习框架,像pytorch、tensorflow、mxnet、paddlepaddle(飞桨居然也有份😂),也包括onnx格式的模型。

环境

  • idea&pycharm
  • torch1.8.1+cu111
  • java1.8
  • Deep Java Library
  • yolov5 release v5.0
  • opencv 4.5.2

导出yolov5s模型

这次demo就直接使用yolov5s的预训练模型。yolov5项目本身就自带了非常完善的模型导出脚本,yolov5的5.0发行版也比之前的版本完善很多。
yolov5的模型导出脚本是models/export.py文件,
yolov5+opencv+java:通过DJL在maven项目中使用yolov5的小demo
导出之前需要设置一下

  • 权重文件的位置
  • 输入图片的尺寸
  • 是否要输出bbox
  • 模型所在设备
    yolov5+opencv+java:通过DJL在maven项目中使用yolov5的小demo
    上图红色的框按我的进行设置就行了,绿色的框根据自己的情况进行设置。
    设置好以后运行代码就可以在和权重文件相同的位置找到生成的torchscript模型权重。
    yolov5+opencv+java:通过DJL在maven项目中使用yolov5的小demo

编写Maven项目

编写pom.xml文件

djl使用pytorch需要引入相关依赖

  • pytorch-model-zoo
  • pytorch-engine
  • pytorch-native-auto
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>xyz.hyhy</groupId>
    <artifactId>TestAI</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <djl.version>0.11.0</djl.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>${djl.version}</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.8.1</version>
        </dependency>
    </dependencies>
</project>

引入opencv依赖

下载opencv

到官网下载opencv库
yolov5+opencv+java:通过DJL在maven项目中使用yolov5的小demo

获取opencv的jar包和动态链接库dll文件

下载完会得到一个exe文件,实际只是个压缩包,解压后到build文件夹下将jar包和x64或x86文件夹下的dll文件一起复制到项目的lib文件夹下。dll文件根据自己系统是64位还是32位进行选择。
yolov5+opencv+java:通过DJL在maven项目中使用yolov5的小demo

将lib文件夹添加为Library

yolov5+opencv+java:通过DJL在maven项目中使用yolov5的小demo

将yolov5权重文件放到资源文件

将之前导出的yolov5s.torchscript.pt文件放到resources/yolov5s文件夹下。另外还要编写一个coco.names文件,用来说明分类任务的类名。
yolov5+opencv+java:通过DJL在maven项目中使用yolov5的小demo
coco.names

person
bicycle
car
motorbike
aeroplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
sofa
pottedplant
bed
diningtable
toilet
tvmonitor
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush

编写代码

package xyz.hyhy;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.DetectedObjects.DetectedObject;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.YoloV5Translator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import org.opencv.core.*;
import org.opencv.highgui.HighGui;
import org.opencv.imgproc.Imgproc;
import org.opencv.videoio.VideoCapture;
import xyz.hyhy.utils.MyUtils;

import java.io.IOException;

import static org.opencv.videoio.Videoio.CAP_ANY;

public class Main {

    static {
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    }

    public static void main(String[] args) {
        Translator<Image, DetectedObjects> translator = YoloV5Translator.builder().optSynsetArtifactName("coco.names").build();
        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .setTypes(Image.class, DetectedObjects.class)
                        .optDevice(Device.cpu())
                        .optModelUrls(Main.class.getResource("/yolov5s").getPath())
                        .optModelName("yolov5s.torchscript.pt")
                        .optTranslator(translator)
                        .optEngine("PyTorch")
                        .build();
//        Criteria<Image, DetectedObjects> criteria =
//                Criteria.builder()
//                        .setTypes(Image.class, DetectedObjects.class)
//                        .optDevice(Device.cpu())
//                        .optModelUrls(Main.class.getResource("/yolov5").getPath())
//                        .optModelName("yolov5s.onnx")
//                        .optTranslator(translator)
//                        .optEngine("OnnxRuntime")
//                        .build();
        try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria)) {
            VideoCapture cap = new VideoCapture(CAP_ANY);
            if (!cap.isOpened()) {//isOpened函数用来判断摄像头调用是否成功
                System.out.println("Camera Error");//如果摄像头调用失败,输出错误信息
            } else {
                Mat frame = new Mat();//创建一个输出帧
                boolean flag = cap.read(frame);//read方法读取摄像头的当前帧
                while (flag) {
                    detect(frame, model);
                    HighGui.imshow("yolov5", frame);
                    HighGui.waitKey(20);
                    flag = cap.read(frame);
                }
            }

        } catch (RuntimeException | ModelException | TranslateException | IOException e) {
            e.printStackTrace();
        }
    }

    static Rect rect = new Rect();
    static Scalar color = new Scalar(0, 255, 0);

    static void detect(Mat frame, ZooModel<Image, DetectedObjects> model) throws IOException, ModelNotFoundException, MalformedModelException, TranslateException {
        Image img = MyUtils.mat2Image(frame);
        long startTime = System.currentTimeMillis();
        try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            DetectedObjects results = predictor.predict(img);
//            System.out.println(results);
            for (DetectedObject obj : results.<DetectedObject>items()) {
                BoundingBox bbox = obj.getBoundingBox();
                Rectangle rectangle = bbox.getBounds();
                String showText = String.format("%s: %.2f", obj.getClassName(), obj.getProbability());
                rect.x = (int) rectangle.getX();
                rect.y = (int) rectangle.getY();
                rect.width = (int) rectangle.getWidth();
                rect.height = (int) rectangle.getHeight();
                // 画框

                Imgproc.rectangle(frame, rect, color, 2);
                //画名字
                Imgproc.putText(frame, showText,
                        new Point(rect.x, rect.y),
                        Imgproc.FONT_HERSHEY_COMPLEX,
                        rectangle.getWidth() / 200,
                        color);
            }
        }
        System.out.println(String.format("%.2f", 1000.0 / (System.currentTimeMillis() - startTime)));
    }
}


运行程序

程序启动时,会卡住一段时间,不过不要慌,因为djl需要下载pytorch的动态链接库,下载的位置在%USERPROFILE%\.djl.ai\pytorch目录下。可以看一下加速球的流量消耗或者到对应文件夹下确认是否有在下载。
下载的实际上就是libtorch里面的那些动态链接库。djl会根据你的系统自动选择下载合适的版本(应该)。
yolov5+opencv+java:通过DJL在maven项目中使用yolov5的小demo
效果:
yolov5+opencv+java:通过DJL在maven项目中使用yolov5的小demo

补充

之后测试了onnx的yolov5s模型,onnx的推理速度更快,速度大概是torchscript的3倍。

MyUtils.mat2Image

    public static Image mat2Image(Mat mat) {
        return ImageFactory.getInstance().fromImage(HighGui.toBufferedImage(mat));
    }

相关文章

暂无评论

暂无评论...