diff --git a/pom.xml b/pom.xml
index 12d1395..e78337b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -140,6 +140,11 @@
org.springframework.boot
spring-boot-starter-data-redis
+
+ org.redisson
+ redisson-spring-boot-starter
+ 3.23.5
+
diff --git a/src/main/java/com/project/milvus/application/MilvusApplicationService.java b/src/main/java/com/project/milvus/application/MilvusApplicationService.java
new file mode 100644
index 0000000..f792964
--- /dev/null
+++ b/src/main/java/com/project/milvus/application/MilvusApplicationService.java
@@ -0,0 +1,10 @@
+package com.project.milvus.application;
+
+import com.project.milvus.domain.dto.TitleVector;
+
+public interface MilvusApplicationService {
+ /**
+ * 题目入向量数据库
+ */
+ void insertTitle(TitleVector title);
+}
diff --git a/src/main/java/com/project/milvus/application/impl/MilvusApplicationServiceImpl.java b/src/main/java/com/project/milvus/application/impl/MilvusApplicationServiceImpl.java
new file mode 100644
index 0000000..f716511
--- /dev/null
+++ b/src/main/java/com/project/milvus/application/impl/MilvusApplicationServiceImpl.java
@@ -0,0 +1,98 @@
+package com.project.milvus.application.impl;
+
+import cn.hutool.core.collection.CollectionUtil;
+import com.project.base.config.CustomIdGenerator;
+import com.project.base.domain.exception.MissingParameterException;
+import com.project.milvus.application.MilvusApplicationService;
+import com.project.milvus.domain.dto.TitleVector;
+import com.project.milvus.domain.service.CheckMilvusDomainService;
+import com.project.milvus.domain.service.MilvusDemoService;
+import io.milvus.v2.service.vector.response.SearchResp;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.codec.digest.DigestUtils;
+import org.redisson.api.RLock;
+import org.redisson.api.RedissonClient;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.data.redis.core.StringRedisTemplate;
+import org.springframework.stereotype.Service;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+
+@Service
+@Slf4j
+public class MilvusApplicationServiceImpl implements MilvusApplicationService {
+ @Autowired
+ private MilvusDemoService milvusDemoService;
+ @Autowired
+ private CheckMilvusDomainService checkMilvusDomainService;
+ @Autowired
+ private CustomIdGenerator customIdGenerator;
+ @Autowired
+ private RedissonClient redissonClient;
+
+
+ //相似度阈值
+ private static final Float SIMILARITY_THRESHOLD = 0.8f;
+ private static final String LOCK_KEY = "lock:title:pointIds:";
+
+ @Override
+ public void insertTitle(TitleVector title) {
+ //参数校验
+ checkMilvusDomainService.check(title);
+
+ String lockKey = buildLockKey(title.getPointIdsList());
+ RLock lock = redissonClient.getLock(lockKey);
+
+ boolean locked = false;
+ int retry = 3;
+ try {
+ while (retry > 0 && !locked) {
+ // 等待 2s,拿到锁后10s自动释放
+ locked = lock.tryLock(2, 10, TimeUnit.SECONDS);
+ if (!locked) {
+ //休眠等待重试
+ Thread.sleep(100 + new Random().nextInt(100));
+ }
+ }
+
+ if (!locked) {
+ throw new RuntimeException("当前知识点正在处理,请稍后再试");
+ }
+
+ //比较相似度
+ List> query = milvusDemoService.query(title.getPointIdsList(), title.getTitleVectorList());
+ if (CollectionUtil.isNotEmpty(query) && CollectionUtil.isNotEmpty(query.get(0))) {
+ SearchResp.SearchResult searchResult = query.get(0).get(0);
+ Float score = searchResult.getScore();
+ if(score.compareTo(SIMILARITY_THRESHOLD) == 1){
+ throw new MissingParameterException("题目相似度"+ score +",超过阈值");
+ }
+ }
+
+ //新增数据
+ title.setId(customIdGenerator.nextId(title));
+ milvusDemoService.insertRecord(title);
+
+ }catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException("获取锁被中断", e);
+ } finally {
+ if (locked && lock.isHeldByCurrentThread()) {
+ lock.unlock();
+ }
+ }
+ }
+
+ /**
+ * 构建锁的key
+ */
+ private String buildLockKey(List poinIds){
+ Collections.sort(poinIds);
+ String collect = poinIds.stream().map(String::valueOf).collect(Collectors.joining(","));
+ return LOCK_KEY + DigestUtils.md5Hex(collect);
+ }
+}
diff --git a/src/main/java/com/project/milvus/controller/MilvusController.java b/src/main/java/com/project/milvus/controller/MilvusController.java
index 49aa46e..e522700 100644
--- a/src/main/java/com/project/milvus/controller/MilvusController.java
+++ b/src/main/java/com/project/milvus/controller/MilvusController.java
@@ -1,6 +1,9 @@
package com.project.milvus.controller;
-import com.project.milvus.service.MilvusDemoService;
+import com.project.base.domain.result.Result;
+import com.project.milvus.application.MilvusApplicationService;
+import com.project.milvus.domain.dto.TitleVector;
+import com.project.milvus.domain.service.MilvusDemoService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
@@ -9,8 +12,15 @@ import org.springframework.web.bind.annotation.*;
@RequestMapping("/milvus")
public class MilvusController {
@Autowired
- private MilvusDemoService milvusDemoService;
-
+ private MilvusApplicationService milvusApplicationService;
+ /**
+ * 题目入向量数据库
+ */
+ @PostMapping("/insert")
+ public Result insert(@RequestBody TitleVector title) {
+ milvusApplicationService.insertTitle(title);
+ return Result.success(null);
+ }
}
diff --git a/src/main/java/com/project/milvus/domain/TitleVector.java b/src/main/java/com/project/milvus/domain/dto/TitleVector.java
similarity index 94%
rename from src/main/java/com/project/milvus/domain/TitleVector.java
rename to src/main/java/com/project/milvus/domain/dto/TitleVector.java
index c30d7f1..669eb8b 100644
--- a/src/main/java/com/project/milvus/domain/TitleVector.java
+++ b/src/main/java/com/project/milvus/domain/dto/TitleVector.java
@@ -1,4 +1,4 @@
-package com.project.milvus.domain;
+package com.project.milvus.domain.dto;
import com.google.gson.Gson;
import com.google.gson.JsonElement;
diff --git a/src/main/java/com/project/milvus/domain/service/CheckMilvusDomainService.java b/src/main/java/com/project/milvus/domain/service/CheckMilvusDomainService.java
new file mode 100644
index 0000000..e1500fe
--- /dev/null
+++ b/src/main/java/com/project/milvus/domain/service/CheckMilvusDomainService.java
@@ -0,0 +1,7 @@
+package com.project.milvus.domain.service;
+
+import com.project.milvus.domain.dto.TitleVector;
+
+public interface CheckMilvusDomainService {
+ void check(TitleVector title);
+}
diff --git a/src/main/java/com/project/milvus/service/MilvusDemoService.java b/src/main/java/com/project/milvus/domain/service/MilvusDemoService.java
similarity index 85%
rename from src/main/java/com/project/milvus/service/MilvusDemoService.java
rename to src/main/java/com/project/milvus/domain/service/MilvusDemoService.java
index ef068d1..82304cd 100644
--- a/src/main/java/com/project/milvus/service/MilvusDemoService.java
+++ b/src/main/java/com/project/milvus/domain/service/MilvusDemoService.java
@@ -1,6 +1,6 @@
-package com.project.milvus.service;
+package com.project.milvus.domain.service;
-import com.project.milvus.domain.TitleVector;
+import com.project.milvus.domain.dto.TitleVector;
import io.milvus.v2.service.vector.response.GetResp;
import io.milvus.v2.service.vector.response.SearchResp;
diff --git a/src/main/java/com/project/milvus/domain/service/impl/CheckMilvusDomainServiceImpl.java b/src/main/java/com/project/milvus/domain/service/impl/CheckMilvusDomainServiceImpl.java
new file mode 100644
index 0000000..5aa3aff
--- /dev/null
+++ b/src/main/java/com/project/milvus/domain/service/impl/CheckMilvusDomainServiceImpl.java
@@ -0,0 +1,33 @@
+package com.project.milvus.domain.service.impl;
+
+import cn.hutool.core.collection.CollectionUtil;
+import com.project.base.domain.exception.MissingParameterException;
+import com.project.milvus.domain.dto.TitleVector;
+import com.project.milvus.domain.service.CheckMilvusDomainService;
+import org.springframework.stereotype.Service;
+
+import java.util.List;
+
+/**
+ * 校验扩展
+ */
+@Service
+public class CheckMilvusDomainServiceImpl implements CheckMilvusDomainService {
+ @Override
+ public void check(TitleVector title) {
+ if (title == null) {
+ throw new MissingParameterException("请求参数缺失或格式错误");
+ }
+
+ List pointIdsList = title.getPointIdsList();
+ List titleVectorList = title.getTitleVectorList();
+
+ if (CollectionUtil.isEmpty(pointIdsList) || CollectionUtil.isEmpty(titleVectorList)) {
+ throw new MissingParameterException("请求参数缺失或格式错误");
+ }
+
+ if(titleVectorList.size() != 1024){
+ throw new MissingParameterException("向量格式错误");
+ }
+ }
+}
diff --git a/src/main/java/com/project/milvus/service/impl/MilvusDemoServiceImpl.java b/src/main/java/com/project/milvus/domain/service/impl/MilvusDemoServiceImpl.java
similarity index 92%
rename from src/main/java/com/project/milvus/service/impl/MilvusDemoServiceImpl.java
rename to src/main/java/com/project/milvus/domain/service/impl/MilvusDemoServiceImpl.java
index 9c6585c..babdcc3 100644
--- a/src/main/java/com/project/milvus/service/impl/MilvusDemoServiceImpl.java
+++ b/src/main/java/com/project/milvus/domain/service/impl/MilvusDemoServiceImpl.java
@@ -1,8 +1,8 @@
-package com.project.milvus.service.impl;
+package com.project.milvus.domain.service.impl;
import com.google.gson.JsonObject;
-import com.project.milvus.domain.TitleVector;
-import com.project.milvus.service.MilvusDemoService;
+import com.project.milvus.domain.dto.TitleVector;
+import com.project.milvus.domain.service.MilvusDemoService;
import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.common.DataType;
import io.milvus.v2.common.IndexParam;
@@ -15,7 +15,6 @@ import io.milvus.v2.service.vector.request.data.FloatVec;
import io.milvus.v2.service.vector.response.GetResp;
import io.milvus.v2.service.vector.response.InsertResp;
import io.milvus.v2.service.vector.response.SearchResp;
-import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@@ -111,6 +110,10 @@ public class MilvusDemoServiceImpl implements MilvusDemoService {
public List> query(List pointIds,List titleVector){
StringBuffer expr = new StringBuffer();
for (int i = 0; i < pointIds.size(); i++) {
+ if (pointIds.size() == 1){
+ expr.append("array_contains(point_ids, " + pointIds.get(i) + ") " + " && array_length(point_ids) == 1");
+ break;
+ }
if(i == pointIds.size() - 1){
expr.append("array_contains(point_ids, " + pointIds.get(i) + ")");
continue;