Java【代码 16】Milvus向量库工具类和SeetaSDK获取人像向量和属性的工具类分享

Milvus向量库和SeetaSDK工具类分享

1.Milvus向量库工具类

Milvus的Maven依赖:

        <!-- Milvus Java SDK: client library used to connect to the Milvus vector database. -->
        <dependency>
            <groupId>io.milvus</groupId>
            <artifactId>milvus-sdk-java</artifactId>
            <version>2.1.0</version>
            <exclusions>
                <!-- Keep the Log4j SLF4J binding out of the classpath; presumably this avoids a
                     clash with the logging backend the project already uses (confirm). -->
                <exclusion>
                    <artifactId>log4j-slf4j-impl</artifactId>
                    <groupId>org.apache.logging.log4j</groupId>
                </exclusion>
            </exclusions>
        </dependency>

向量库的配置类:

/**
 * Milvus connection settings, bound from configuration properties under the
 * {@code milvus-config} prefix.
 */
@Data
@Component
@ConfigurationProperties(MilvusConfiguration.PREFIX)
public class MilvusConfiguration {
    
    
    /** Configuration property prefix this class is bound to. */
    public static final String PREFIX = "milvus-config";

    /** Host of the Milvus server to connect to. */
    public String host;
    /** Port of the Milvus server to connect to. */
    public int port;
    /** Name of the collection that is loaded, inserted into and searched. */
    public String collectionName;

}

工具类主类:

@Slf4j
@Component
public class MilvusUtil {
    
    

    @Resource
    private MilvusConfiguration milvusConfiguration;

    private MilvusServiceClient milvusServiceClient;

    @PostConstruct
    private void connectToServer() {
    
    
        milvusServiceClient = new MilvusServiceClient(
                ConnectParam.newBuilder()
                        .withHost(milvusConfiguration.host)
                        .withPort(milvusConfiguration.port)
                        .build());
        // 加载数据
        LoadCollectionParam faceSearchNewLoad = LoadCollectionParam.newBuilder()
                .withCollectionName(milvusConfiguration.collectionName).build();
        R<RpcStatus> rpcStatusR = milvusServiceClient.loadCollection(faceSearchNewLoad);
        log.info("Milvus LoadCollection [{}]", rpcStatusR.getStatus() == 0 ? "Successful!" : "Failed!");
    }
}

主类里的数据入库方法:

    /**
     * Inserts one row (id, image path, feature vector) into the configured collection.
     *
     * @param id      value for the {@code id} field
     * @param path    value for the {@code image_path} field
     * @param feature value for the {@code image_feature} vector field
     * @return the status code of the insert response
     */
    public int insertDataToMilvus(String id, String path, float[] feature) {
        // Box the primitive vector into a List<Float>.
        List<Float> vector = new ArrayList<>(feature.length);
        for (int i = 0; i < feature.length; i++) {
            vector.add(feature[i]);
        }

        // A single row: every column is a one-element list.
        List<InsertParam.Field> columns = new ArrayList<>();
        columns.add(new InsertParam.Field("id", Collections.singletonList(id)));
        columns.add(new InsertParam.Field("image_path", Collections.singletonList(path)));
        columns.add(new InsertParam.Field("image_feature", Collections.singletonList(vector)));

        InsertParam request = InsertParam.newBuilder()
                .withCollectionName(milvusConfiguration.collectionName)
                .withFields(columns)
                .build();

        return milvusServiceClient.insert(request).getStatus();
    }

主类里的数据查询方法:

  • 这里的topK没有进行参数化。
    /** Number of nearest neighbours returned when the caller does not specify one. */
    private static final int DEFAULT_TOP_K = 10;

    /**
     * Searches the configured collection for the images closest to {@code feature},
     * returning at most {@value #DEFAULT_TOP_K} hits.
     *
     * @param feature query feature vector
     * @return the hits (score and image path); empty when nothing matched
     * @throws IllegalStateException if Milvus reports a failed search
     */
    public List<MilvusRes> searchImageByFeatureVector(float[] feature) {
        return searchImageByFeatureVector(feature, DEFAULT_TOP_K);
    }

    /**
     * Searches the configured collection for the {@code topK} images closest to
     * {@code feature}, using the inner-product metric on the {@code image_feature} field.
     *
     * @param feature query feature vector
     * @param topK    maximum number of hits to return
     * @return the hits (score and image path); empty when nothing matched
     * @throws IllegalStateException if Milvus reports a failed search
     */
    public List<MilvusRes> searchImageByFeatureVector(float[] feature, int topK) {
        List<Float> featureList = new ArrayList<>(feature.length);
        for (float v : feature) {
            featureList.add(v);
        }

        SearchParam faceSearch = SearchParam.newBuilder()
                .withCollectionName(milvusConfiguration.collectionName)
                .withMetricType(MetricType.IP)
                .withVectorFieldName("image_feature")
                .withVectors(Collections.singletonList(featureList))
                .withOutFields(Collections.singletonList("image_path"))
                .withRoundDecimal(3)
                .withTopK(topK)
                .build();

        // Run the search and record how long it took.
        long start = System.currentTimeMillis();
        R<SearchResults> respSearch = milvusServiceClient.search(faceSearch);
        log.info("MilvusServiceClient.search cost [{}]", System.currentTimeMillis() - start);

        // Fail with a meaningful message instead of an NPE on a missing payload.
        if (respSearch.getStatus() != R.Status.Success.getCode() || respSearch.getData() == null) {
            throw new IllegalStateException(
                    "Milvus search failed, status=" + respSearch.getStatus(),
                    respSearch.getException());
        }

        // Parse the result data.
        SearchResultData results = respSearch.getData().getResults();
        int scoresCount = results.getScoresCount();
        List<MilvusRes> milvusResList = new ArrayList<>(scoresCount);
        if (scoresCount == 0) {
            return milvusResList;
        }

        // Only one query vector was sent, so all hits belong to target index 0.
        // Fetch both columns once rather than on every loop iteration.
        SearchResultsWrapper wrapperSearch = new SearchResultsWrapper(results);
        List<SearchResultsWrapper.IDScore> idScores = wrapperSearch.getIDScore(0);
        List<?> imagePaths = wrapperSearch.getFieldData("image_path", 0);
        for (int i = 0; i < scoresCount; i++) {
            float score = idScores.get(i).getScore();
            Object imagePath = imagePaths.get(i);
            milvusResList.add(MilvusRes.builder().score(score).imagePath(imagePath.toString()).build());
        }
        return milvusResList;
    }

2.SeetaSDK工具类

SeetaSDK的Maven依赖:

        <!-- SeetaSDK is not resolved from a Maven repository: it is referenced as a
             system-scoped jar located under the project's lib directory. -->
        <dependency>
            <groupId>com.seeta</groupId>
            <artifactId>sdk</artifactId>
            <version>1.2.1</version>
            <scope>system</scope>
            <systemPath>${project.basedir}/lib/seeta-sdk-platform-1.2.1.jar</systemPath>
        </dependency>
       		<!-- Note: includeSystemScope=true tells spring-boot-maven-plugin to include
       		     system-scoped dependencies (such as the jar above) when packaging. -->
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <configuration>
                    <includeSystemScope>true</includeSystemScope>
                </configuration>
            </plugin>       

该 jar 是从官网下载源码后自行打包得到的:

在这里插入图片描述
工具类主类:

/**
 * Spring component that holds the SeetaSDK proxy objects used for face detection,
 * landmarking, feature extraction and attribute prediction.
 */
@Slf4j
@Component
public class FaceUtil {
    
    

    static {
    
    
        // Load the native libraries once, when the class is initialised.
        LoadNativeCore.LOAD_NATIVE(SeetaDevice.SEETA_DEVICE_AUTO);
    }

    /** Supplies the model base path and the per-model file names. */
    @Resource
    private SeetaModelConfiguration seetaModelConfiguration;

    // One proxy per SeetaSDK capability.
    private FaceDetectorProxy faceDetectorProxy;
    private FaceLandmarkerProxy faceLandmarkerProxy;
    private FaceRecognizerProxy faceRecognizerProxy;
    private AgePredictorProxy agePredictorProxy;
    private GenderPredictorProxy genderPredictorProxy;
    private MaskDetectorProxy maskDetectorProxy;
    private EyeStateDetectorProxy eyeStateDetectorProxy;

}

主类里的初始方法:

    /**
     * Creates every SeetaSDK proxy from the model files configured in
     * {@code seetaModelConfiguration}; each model path is {@code basePath + fileName}.
     *
     * <p>Initialisation is best-effort: a failure is logged and the proxies that were not
     * yet created stay {@code null}.
     */
    @PostConstruct
    private void inti() {
        String basePath = seetaModelConfiguration.basePath;
        try {
            // Face detector pool configuration.
            SeetaConfSetting detectorPoolSetting = new SeetaConfSetting(
                    new SeetaModelSetting(0,
                            new String[]{basePath + seetaModelConfiguration.faceDetectorModelFileName},
                            SeetaDevice.SEETA_DEVICE_AUTO));
            faceDetectorProxy = new FaceDetectorProxy(detectorPoolSetting);

            // Landmark locator (5-point by default; can be switched to 68-point via configuration).
            SeetaConfSetting faceLandmarkerPoolSetting = new SeetaConfSetting(
                    new SeetaModelSetting(1,
                            new String[]{basePath + seetaModelConfiguration.faceLandmarkerModelFileName},
                            SeetaDevice.SEETA_DEVICE_AUTO));
            faceLandmarkerProxy = new FaceLandmarkerProxy(faceLandmarkerPoolSetting);

            // Face feature-vector extractor and comparator.
            SeetaConfSetting faceRecognizerPoolSetting = new SeetaConfSetting(
                    new SeetaModelSetting(2,
                            new String[]{basePath + seetaModelConfiguration.faceRecognizerModelFileName},
                            SeetaDevice.SEETA_DEVICE_AUTO));
            faceRecognizerProxy = new FaceRecognizerProxy(faceRecognizerPoolSetting);

            // Age predictor.
            SeetaConfSetting agePredictorPoolSetting = new SeetaConfSetting(
                    new SeetaModelSetting(3,
                            new String[]{basePath + seetaModelConfiguration.agePredictorModelFileName},
                            SeetaDevice.SEETA_DEVICE_AUTO));
            agePredictorProxy = new AgePredictorProxy(agePredictorPoolSetting);

            // Gender predictor.
            SeetaConfSetting genderPredictorPoolSetting = new SeetaConfSetting(
                    new SeetaModelSetting(4,
                            new String[]{basePath + seetaModelConfiguration.genderPredictorModelFileName},
                            SeetaDevice.SEETA_DEVICE_AUTO));
            genderPredictorProxy = new GenderPredictorProxy(genderPredictorPoolSetting);

            // Mask detector.
            SeetaConfSetting maskDetectorPoolSetting = new SeetaConfSetting(
                    new SeetaModelSetting(5,
                            new String[]{basePath + seetaModelConfiguration.maskDetectorModelFileName},
                            SeetaDevice.SEETA_DEVICE_AUTO));
            maskDetectorProxy = new MaskDetectorProxy(maskDetectorPoolSetting);

            // Eye-state detector.
            // NOTE(review): the first SeetaModelSetting argument is 5 here as well as for the
            // mask detector, while the others use 0..4 - confirm what this argument means
            // and whether the repeated value is intended.
            SeetaConfSetting eyeStaterPoolSetting = new SeetaConfSetting(
                    new SeetaModelSetting(5,
                            new String[]{basePath + seetaModelConfiguration.eyeStateModelFileName},
                            SeetaDevice.SEETA_DEVICE_AUTO));
            eyeStateDetectorProxy = new EyeStateDetectorProxy(eyeStaterPoolSetting);
        } catch (Exception e) {
            // Report through the class logger (with the stack trace) instead of stderr.
            log.error("Failed to initialise SeetaSDK proxies from base path [{}]", basePath, e);
        }
    }

主类里的根据图片路径获取脸部特征向量方法:

    /**
     * Extracts the face feature vector of the image at the given path.
     *
     * <p>If several faces are detected only the first one is used (a driver or
     * front-passenger image is expected to contain a single face).
     *
     * @param imagePath path of the image
     * @return the face feature vector, or {@code null} when no face is detected or
     *         processing fails (the failure is logged)
     */
    public float[] getFaceFeaturesByPath(String imagePath) {
        try {
            // Detect faces in the image.
            SeetaImageData image = SeetafaceUtil.toSeetaImageData(imagePath);
            SeetaRect[] detects = faceDetectorProxy.detect(image);
            if (detects == null || detects.length == 0) {
                return null;
            }
            // Locate the landmarks of the first face, then extract its feature vector.
            SeetaPointF[] pointFace = faceLandmarkerProxy.mark(image, detects[0]);
            return faceRecognizerProxy.extract(image, pointFace);
        } catch (Exception e) {
            // Report through the class logger (with the stack trace) instead of stderr.
            log.error("getFaceFeaturesByPath [{}] failed", imagePath, e);
            return null;
        }
    }

主类里的根据人像图片的路径获取其属性【年龄、性别、是否戴口罩、眼睛状态】方法:

    /**
     * Predicts the attributes (age, gender, mask, eye state) of the first face found in
     * the image at the given path.
     *
     * @param imagePath path of the image
     * @return a map with the keys {@code age}, {@code gender}, {@code mask} and
     *         {@code eye}; empty when no face is detected, and possibly partially filled
     *         when processing fails part-way (the failure is logged)
     */
    public Map<String, Object> getAttributeByPath(String imagePath) {
        long start = System.currentTimeMillis();
        Map<String, Object> attributeMap = new HashMap<>(4);
        try {
            // Detect faces in the image.
            SeetaImageData image = SeetafaceUtil.toSeetaImageData(imagePath);
            SeetaRect[] detects = faceDetectorProxy.detect(image);
            if (detects.length > 0) {
                SeetaPointF[] pointFace = faceLandmarkerProxy.mark(image, detects[0]);
                // Age.
                int age = agePredictorProxy.predictAgeWithCrop(image, pointFace);
                attributeMap.put("age", age);
                // Gender.
                GenderPredictor.GENDER gender = genderPredictorProxy.predictGenderWithCrop(image, pointFace).getGender();
                attributeMap.put("gender", gender);
                // Mask.
                boolean mask = maskDetectorProxy.detect(image, detects[0]).getMask();
                attributeMap.put("mask", mask);
                // Eye state.
                EyeStateDetector.EYE_STATE[] eyeStates = eyeStateDetectorProxy.detect(image, pointFace);
                attributeMap.put("eye", Arrays.toString(eyeStates));
                log.info("getAttributeByPath [{}] cost [{}]", imagePath, System.currentTimeMillis() - start);
            }
        } catch (Exception e) {
            // Report through the class logger (with the stack trace) instead of stderr;
            // whatever was collected before the failure is still returned.
            log.error("getAttributeByPath [{}] failed", imagePath, e);
        }
        return attributeMap;
    }

猜你喜欢

转载自blog.csdn.net/weixin_39168541/article/details/131601949