diff --git a/pom.xml b/pom.xml index 12d1395..e78337b 100644 --- a/pom.xml +++ b/pom.xml @@ -140,6 +140,11 @@ org.springframework.boot spring-boot-starter-data-redis + + org.redisson + redisson-spring-boot-starter + 3.23.5 + diff --git a/src/main/java/com/project/milvus/application/MilvusApplicationService.java b/src/main/java/com/project/milvus/application/MilvusApplicationService.java new file mode 100644 index 0000000..f792964 --- /dev/null +++ b/src/main/java/com/project/milvus/application/MilvusApplicationService.java @@ -0,0 +1,10 @@ +package com.project.milvus.application; + +import com.project.milvus.domain.dto.TitleVector; + +public interface MilvusApplicationService { + /** + * 题目入向量数据库 + */ + void insertTitle(TitleVector title); +} diff --git a/src/main/java/com/project/milvus/application/impl/MilvusApplicationServiceImpl.java b/src/main/java/com/project/milvus/application/impl/MilvusApplicationServiceImpl.java new file mode 100644 index 0000000..f716511 --- /dev/null +++ b/src/main/java/com/project/milvus/application/impl/MilvusApplicationServiceImpl.java @@ -0,0 +1,98 @@ +package com.project.milvus.application.impl; + +import cn.hutool.core.collection.CollectionUtil; +import com.project.base.config.CustomIdGenerator; +import com.project.base.domain.exception.MissingParameterException; +import com.project.milvus.application.MilvusApplicationService; +import com.project.milvus.domain.dto.TitleVector; +import com.project.milvus.domain.service.CheckMilvusDomainService; +import com.project.milvus.domain.service.MilvusDemoService; +import io.milvus.v2.service.vector.response.SearchResp; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.codec.digest.DigestUtils; +import org.redisson.api.RLock; +import org.redisson.api.RedissonClient; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.redis.core.StringRedisTemplate; +import org.springframework.stereotype.Service; + +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +@Service +@Slf4j +public class MilvusApplicationServiceImpl implements MilvusApplicationService { + @Autowired + private MilvusDemoService milvusDemoService; + @Autowired + private CheckMilvusDomainService checkMilvusDomainService; + @Autowired + private CustomIdGenerator customIdGenerator; + @Autowired + private RedissonClient redissonClient; + + + //相似度阈值 + private static final Float SIMILARITY_THRESHOLD = 0.8f; + private static final String LOCK_KEY = "lock:title:pointIds:"; + + @Override + public void insertTitle(TitleVector title) { + //参数校验 + checkMilvusDomainService.check(title); + + String lockKey = buildLockKey(title.getPointIdsList()); + RLock lock = redissonClient.getLock(lockKey); + + boolean locked = false; + int retry = 3; + try { + while (retry > 0 && !locked) { + // 等待 2s,拿到锁后10s自动释放 + locked = lock.tryLock(2, 10, TimeUnit.SECONDS); + if (!locked) { + //休眠等待重试 + Thread.sleep(100 + new Random().nextInt(100)); + } + } + + if (!locked) { + throw new RuntimeException("当前知识点正在处理,请稍后再试"); + } + + //比较相似度 + List> query = milvusDemoService.query(title.getPointIdsList(), title.getTitleVectorList()); + if (CollectionUtil.isNotEmpty(query) && CollectionUtil.isNotEmpty(query.get(0))) { + SearchResp.SearchResult searchResult = query.get(0).get(0); + Float score = searchResult.getScore(); + if(score.compareTo(SIMILARITY_THRESHOLD) == 1){ + throw new MissingParameterException("题目相似度"+ score +",超过阈值"); + } + } + + //新增数据 + title.setId(customIdGenerator.nextId(title)); + milvusDemoService.insertRecord(title); + + }catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("获取锁被中断", e); + } finally { + if (locked && lock.isHeldByCurrentThread()) { + lock.unlock(); + } + } + } + + /** + * 构建锁的key + */ + private String buildLockKey(List poinIds){ + Collections.sort(poinIds); + String collect = poinIds.stream().map(String::valueOf).collect(Collectors.joining(",")); + return LOCK_KEY + DigestUtils.md5Hex(collect); + } +} diff --git a/src/main/java/com/project/milvus/controller/MilvusController.java b/src/main/java/com/project/milvus/controller/MilvusController.java index 49aa46e..e522700 100644 --- a/src/main/java/com/project/milvus/controller/MilvusController.java +++ b/src/main/java/com/project/milvus/controller/MilvusController.java @@ -1,6 +1,9 @@ package com.project.milvus.controller; -import com.project.milvus.service.MilvusDemoService; +import com.project.base.domain.result.Result; +import com.project.milvus.application.MilvusApplicationService; +import com.project.milvus.domain.dto.TitleVector; +import com.project.milvus.domain.service.MilvusDemoService; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; @@ -9,8 +12,15 @@ import org.springframework.web.bind.annotation.*; @RequestMapping("/milvus") public class MilvusController { @Autowired - private MilvusDemoService milvusDemoService; - + private MilvusApplicationService milvusApplicationService; + /** + * 题目入向量数据库 + */ + @PostMapping("/insert") + public Result insert(@RequestBody TitleVector title) { + milvusApplicationService.insertTitle(title); + return Result.success(null); + } } diff --git a/src/main/java/com/project/milvus/domain/TitleVector.java b/src/main/java/com/project/milvus/domain/dto/TitleVector.java similarity index 94% rename from src/main/java/com/project/milvus/domain/TitleVector.java rename to src/main/java/com/project/milvus/domain/dto/TitleVector.java index c30d7f1..669eb8b 100644 --- a/src/main/java/com/project/milvus/domain/TitleVector.java +++ b/src/main/java/com/project/milvus/domain/dto/TitleVector.java @@ -1,4 +1,4 @@ -package com.project.milvus.domain; +package com.project.milvus.domain.dto; import com.google.gson.Gson; import com.google.gson.JsonElement; diff --git a/src/main/java/com/project/milvus/domain/service/CheckMilvusDomainService.java b/src/main/java/com/project/milvus/domain/service/CheckMilvusDomainService.java new file mode 100644 index 0000000..e1500fe --- /dev/null +++ b/src/main/java/com/project/milvus/domain/service/CheckMilvusDomainService.java @@ -0,0 +1,7 @@ +package com.project.milvus.domain.service; + +import com.project.milvus.domain.dto.TitleVector; + +public interface CheckMilvusDomainService { + void check(TitleVector title); +} diff --git a/src/main/java/com/project/milvus/service/MilvusDemoService.java b/src/main/java/com/project/milvus/domain/service/MilvusDemoService.java similarity index 85% rename from src/main/java/com/project/milvus/service/MilvusDemoService.java rename to src/main/java/com/project/milvus/domain/service/MilvusDemoService.java index ef068d1..82304cd 100644 --- a/src/main/java/com/project/milvus/service/MilvusDemoService.java +++ b/src/main/java/com/project/milvus/domain/service/MilvusDemoService.java @@ -1,6 +1,6 @@ -package com.project.milvus.service; +package com.project.milvus.domain.service; -import com.project.milvus.domain.TitleVector; +import com.project.milvus.domain.dto.TitleVector; import io.milvus.v2.service.vector.response.GetResp; import io.milvus.v2.service.vector.response.SearchResp; diff --git a/src/main/java/com/project/milvus/domain/service/impl/CheckMilvusDomainServiceImpl.java b/src/main/java/com/project/milvus/domain/service/impl/CheckMilvusDomainServiceImpl.java new file mode 100644 index 0000000..5aa3aff --- /dev/null +++ b/src/main/java/com/project/milvus/domain/service/impl/CheckMilvusDomainServiceImpl.java @@ -0,0 +1,33 @@ +package com.project.milvus.domain.service.impl; + +import cn.hutool.core.collection.CollectionUtil; +import com.project.base.domain.exception.MissingParameterException; +import com.project.milvus.domain.dto.TitleVector; +import com.project.milvus.domain.service.CheckMilvusDomainService; +import org.springframework.stereotype.Service; + +import java.util.List; + +/** + * 校验扩展 + */ +@Service +public class CheckMilvusDomainServiceImpl implements CheckMilvusDomainService { + @Override + public void check(TitleVector title) { + if (title == null) { + throw new MissingParameterException("请求参数缺失或格式错误"); + } + + List pointIdsList = title.getPointIdsList(); + List titleVectorList = title.getTitleVectorList(); + + if (CollectionUtil.isEmpty(pointIdsList) || CollectionUtil.isEmpty(titleVectorList)) { + throw new MissingParameterException("请求参数缺失或格式错误"); + } + + if(titleVectorList.size() != 1024){ + throw new MissingParameterException("向量格式错误"); + } + } +} diff --git a/src/main/java/com/project/milvus/service/impl/MilvusDemoServiceImpl.java b/src/main/java/com/project/milvus/domain/service/impl/MilvusDemoServiceImpl.java similarity index 92% rename from src/main/java/com/project/milvus/service/impl/MilvusDemoServiceImpl.java rename to src/main/java/com/project/milvus/domain/service/impl/MilvusDemoServiceImpl.java index 9c6585c..babdcc3 100644 --- a/src/main/java/com/project/milvus/service/impl/MilvusDemoServiceImpl.java +++ b/src/main/java/com/project/milvus/domain/service/impl/MilvusDemoServiceImpl.java @@ -1,8 +1,8 @@ -package com.project.milvus.service.impl; +package com.project.milvus.domain.service.impl; import com.google.gson.JsonObject; -import com.project.milvus.domain.TitleVector; -import com.project.milvus.service.MilvusDemoService; +import com.project.milvus.domain.dto.TitleVector; +import com.project.milvus.domain.service.MilvusDemoService; import io.milvus.v2.client.MilvusClientV2; import io.milvus.v2.common.DataType; import io.milvus.v2.common.IndexParam; @@ -15,7 +15,6 @@ import io.milvus.v2.service.vector.request.data.FloatVec; import io.milvus.v2.service.vector.response.GetResp; import io.milvus.v2.service.vector.response.InsertResp; import io.milvus.v2.service.vector.response.SearchResp; -import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; @@ -111,6 +110,10 @@ public class MilvusDemoServiceImpl implements MilvusDemoService { public List> query(List pointIds,List titleVector){ StringBuffer expr = new StringBuffer(); for (int i = 0; i < pointIds.size(); i++) { + if (pointIds.size() == 1){ + expr.append("array_contains(point_ids, " + pointIds.get(i) + ") " + " && array_length(point_ids) == 1"); + break; + } if(i == pointIds.size() - 1){ expr.append("array_contains(point_ids, " + pointIds.get(i) + ")"); continue;