diff --git a/src/main/java/com/project/milvus/application/impl/MilvusApplicationServiceImpl.java b/src/main/java/com/project/milvus/application/impl/MilvusApplicationServiceImpl.java index f716511..107f578 100644 --- a/src/main/java/com/project/milvus/application/impl/MilvusApplicationServiceImpl.java +++ b/src/main/java/com/project/milvus/application/impl/MilvusApplicationServiceImpl.java @@ -1,6 +1,7 @@ package com.project.milvus.application.impl; import cn.hutool.core.collection.CollectionUtil; +import cn.hutool.core.util.RandomUtil; import com.project.base.config.CustomIdGenerator; import com.project.base.domain.exception.MissingParameterException; import com.project.milvus.application.MilvusApplicationService; @@ -13,12 +14,10 @@ import org.apache.commons.codec.digest.DigestUtils; import org.redisson.api.RLock; import org.redisson.api.RedissonClient; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.stereotype.Service; +import org.springframework.web.client.RestTemplate; -import java.util.Collections; -import java.util.List; -import java.util.Random; +import java.util.*; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -64,7 +63,7 @@ public class MilvusApplicationServiceImpl implements MilvusApplicationService { } //比较相似度 - List> query = milvusDemoService.query(title.getPointIdsList(), title.getTitleVectorList()); + List> query = milvusDemoService.query(title); if (CollectionUtil.isNotEmpty(query) && CollectionUtil.isNotEmpty(query.get(0))) { SearchResp.SearchResult searchResult = query.get(0).get(0); Float score = searchResult.getScore(); @@ -91,7 +90,6 @@ public class MilvusApplicationServiceImpl implements MilvusApplicationService { * 构建锁的key */ private String buildLockKey(List poinIds){ - Collections.sort(poinIds); String collect = poinIds.stream().map(String::valueOf).collect(Collectors.joining(",")); return LOCK_KEY + DigestUtils.md5Hex(collect); } diff --git a/src/main/java/com/project/milvus/domain/dto/TitleVector.java b/src/main/java/com/project/milvus/domain/dto/TitleVector.java index 669eb8b..edd1c04 100644 --- a/src/main/java/com/project/milvus/domain/dto/TitleVector.java +++ b/src/main/java/com/project/milvus/domain/dto/TitleVector.java @@ -3,8 +3,12 @@ package com.project.milvus.domain.dto; import com.google.gson.Gson; import com.google.gson.JsonElement; import lombok.Data; +import org.apache.commons.codec.digest.DigestUtils; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; @Data public class TitleVector { @@ -13,14 +17,23 @@ public class TitleVector { public List pointIds; //题目向量 public List titleVector; + //题目类型 + public String type; - public JsonElement getPointIdsJson() { - Gson gson = new Gson(); - return gson.toJsonTree(pointIds); + public String getPointIdsHash() { + List sorted = new ArrayList<>(pointIds); + Collections.sort(sorted); + String canonical = sorted.stream() + .map(String::valueOf) + .collect(Collectors.joining(",")); + + return DigestUtils.md5Hex(canonical); } public List getPointIdsList() { - return pointIds; + List sorted = new ArrayList<>(pointIds); + Collections.sort(sorted); + return sorted; } public JsonElement getTitleVectorJson() { diff --git a/src/main/java/com/project/milvus/domain/service/MilvusDemoService.java b/src/main/java/com/project/milvus/domain/service/MilvusDemoService.java index 82304cd..ff0b43d 100644 --- a/src/main/java/com/project/milvus/domain/service/MilvusDemoService.java +++ b/src/main/java/com/project/milvus/domain/service/MilvusDemoService.java @@ -27,5 +27,5 @@ public interface MilvusDemoService { /** * 查询关联知识点相似度最大的题目 */ - List> query(List point_id,List titleVector); + List> query(TitleVector title); } diff --git a/src/main/java/com/project/milvus/domain/service/impl/CheckMilvusDomainServiceImpl.java b/src/main/java/com/project/milvus/domain/service/impl/CheckMilvusDomainServiceImpl.java index 5aa3aff..3c06459 100644 --- a/src/main/java/com/project/milvus/domain/service/impl/CheckMilvusDomainServiceImpl.java +++ b/src/main/java/com/project/milvus/domain/service/impl/CheckMilvusDomainServiceImpl.java @@ -1,9 +1,11 @@ package com.project.milvus.domain.service.impl; import cn.hutool.core.collection.CollectionUtil; +import cn.hutool.core.util.EnumUtil; import com.project.base.domain.exception.MissingParameterException; import com.project.milvus.domain.dto.TitleVector; import com.project.milvus.domain.service.CheckMilvusDomainService; +import com.project.task.domain.enums.QuestionTypeEnum; import org.springframework.stereotype.Service; import java.util.List; @@ -29,5 +31,9 @@ public class CheckMilvusDomainServiceImpl implements CheckMilvusDomainService { if(titleVectorList.size() != 1024){ throw new MissingParameterException("向量格式错误"); } + + if(!EnumUtil.contains(QuestionTypeEnum.class,title.getType())){ + throw new MissingParameterException("题目类型不存在"); + } } } diff --git a/src/main/java/com/project/milvus/domain/service/impl/MilvusDemoServiceImpl.java b/src/main/java/com/project/milvus/domain/service/impl/MilvusDemoServiceImpl.java index babdcc3..66f97f3 100644 --- a/src/main/java/com/project/milvus/domain/service/impl/MilvusDemoServiceImpl.java +++ b/src/main/java/com/project/milvus/domain/service/impl/MilvusDemoServiceImpl.java @@ -3,6 +3,7 @@ package com.project.milvus.domain.service.impl; import com.google.gson.JsonObject; import com.project.milvus.domain.dto.TitleVector; import com.project.milvus.domain.service.MilvusDemoService; +import com.project.task.domain.enums.QuestionTypeEnum; import io.milvus.v2.client.MilvusClientV2; import io.milvus.v2.common.DataType; import io.milvus.v2.common.IndexParam; @@ -25,7 +26,7 @@ import java.util.List; @Slf4j public class MilvusDemoServiceImpl implements MilvusDemoService { //集合名称 - private static final String COLLECTION_NAME = "titleCollection"; + private static final String COLLECTION_NAME = "titleCollection_"; private final MilvusClientV2 client ; //创建向量数据库连接 @@ -78,17 +79,19 @@ public class MilvusDemoServiceImpl implements MilvusDemoService { * 往collection中插入一条数据 */ public void insertRecord(TitleVector title) { + QuestionTypeEnum questionTypeEnum = QuestionTypeEnum.valueOf(title.getType()); + JsonObject vector = new JsonObject(); vector.addProperty("id", title.getId()); - vector.add("point_ids", title.getPointIdsJson()); + vector.addProperty("point_ids", title.getPointIdsHash()); vector.add("title_vector", title.getTitleVectorJson()); InsertReq insertReq = InsertReq.builder() - .collectionName(COLLECTION_NAME) + .collectionName(COLLECTION_NAME+questionTypeEnum.getType()) .data(Collections.singletonList(vector)) .build(); - InsertResp resp = client.insert(insertReq); + client.insert(insertReq); } /** @@ -107,25 +110,16 @@ public class MilvusDemoServiceImpl implements MilvusDemoService { * 查询关联知识点相似度最大的题目 */ @Override - public List> query(List pointIds,List titleVector){ - StringBuffer expr = new StringBuffer(); - for (int i = 0; i < pointIds.size(); i++) { - if (pointIds.size() == 1){ - expr.append("array_contains(point_ids, " + pointIds.get(i) + ") " + " && array_length(point_ids) == 1"); - break; - } - if(i == pointIds.size() - 1){ - expr.append("array_contains(point_ids, " + pointIds.get(i) + ")"); - continue; - } - expr.append("array_contains(point_ids, " + pointIds.get(i) + ") and "); - } + public List> query(TitleVector title){ + QuestionTypeEnum questionTypeEnum = QuestionTypeEnum.valueOf(title.getType()); + + String expr = String.format("point_ids == '%s'", title.getPointIdsHash()); SearchResp searchReq = client.search(SearchReq.builder() - .collectionName(COLLECTION_NAME) - .data(Collections.singletonList(new FloatVec(titleVector))) + .collectionName(COLLECTION_NAME+questionTypeEnum.getType()) + .data(Collections.singletonList(new FloatVec(title.getTitleVectorList()))) .topK(1) - .filter(expr.toString()) + .filter(expr) .outputFields(Collections.singletonList("*")) .metricType(IndexParam.MetricType.COSINE) // 余弦相似度 .build()); diff --git a/src/main/java/com/project/task/domain/enums/QuestionTypeEnum.java b/src/main/java/com/project/task/domain/enums/QuestionTypeEnum.java index de56c77..d179289 100644 --- a/src/main/java/com/project/task/domain/enums/QuestionTypeEnum.java +++ b/src/main/java/com/project/task/domain/enums/QuestionTypeEnum.java @@ -8,10 +8,11 @@ import lombok.RequiredArgsConstructor; @Getter @RequiredArgsConstructor public enum QuestionTypeEnum implements HasValueEnum { - SINGLE_CHOICE(1, "单选题"), // 对应黄色背景 - MULTIPLE_CHOICE(2, "多选题"), // 对应黄色背景 - TRUE_FALSE(3, "判断题"); // 对应绿色背景 + SINGLE_CHOICE(1, "单选题","single"), // 对应黄色背景 + MULTIPLE_CHOICE(2, "多选题","multiple"), // 对应黄色背景 + TRUE_FALSE(3, "判断题","judgment"); // 对应绿色背景 private final Integer value; private final String description; + private final String type; }