Browse Source

向量数据库相关优化

master
luogw 1 month ago
parent
commit
1d693ba49c
  1. 10
      src/main/java/com/project/milvus/application/impl/MilvusApplicationServiceImpl.java
  2. 21
      src/main/java/com/project/milvus/domain/dto/TitleVector.java
  3. 2
      src/main/java/com/project/milvus/domain/service/MilvusDemoService.java
  4. 6
      src/main/java/com/project/milvus/domain/service/impl/CheckMilvusDomainServiceImpl.java
  5. 34
      src/main/java/com/project/milvus/domain/service/impl/MilvusDemoServiceImpl.java
  6. 7
      src/main/java/com/project/task/domain/enums/QuestionTypeEnum.java

10
src/main/java/com/project/milvus/application/impl/MilvusApplicationServiceImpl.java

@ -1,6 +1,7 @@
package com.project.milvus.application.impl;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.RandomUtil;
import com.project.base.config.CustomIdGenerator;
import com.project.base.domain.exception.MissingParameterException;
import com.project.milvus.application.MilvusApplicationService;
@ -13,12 +14,10 @@ import org.apache.commons.codec.digest.DigestUtils;
import org.redisson.api.RLock;
import org.redisson.api.RedissonClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
@ -64,7 +63,7 @@ public class MilvusApplicationServiceImpl implements MilvusApplicationService {
}
//比较相似度
List<List<SearchResp.SearchResult>> query = milvusDemoService.query(title.getPointIdsList(), title.getTitleVectorList());
List<List<SearchResp.SearchResult>> query = milvusDemoService.query(title);
if (CollectionUtil.isNotEmpty(query) && CollectionUtil.isNotEmpty(query.get(0))) {
SearchResp.SearchResult searchResult = query.get(0).get(0);
Float score = searchResult.getScore();
@ -91,7 +90,6 @@ public class MilvusApplicationServiceImpl implements MilvusApplicationService {
* 构建锁的key
*/
private String buildLockKey(List<Long> poinIds){
    // Builds the distributed-lock key for a set of knowledge-point ids.
    // Sorting happens inside the stream instead of via Collections.sort(poinIds):
    // the key is still independent of the order the ids arrive in, but the caller's
    // list is no longer reordered as a side effect and may be an immutable list.
    String collect = poinIds.stream()
            .sorted()
            .map(String::valueOf)
            .collect(Collectors.joining(","));
    // The MD5 hex digest keeps the key at a fixed length regardless of how many ids are passed.
    return LOCK_KEY + DigestUtils.md5Hex(collect);
}

21
src/main/java/com/project/milvus/domain/dto/TitleVector.java

@ -3,8 +3,12 @@ package com.project.milvus.domain.dto;
import com.google.gson.Gson;
import com.google.gson.JsonElement;
import lombok.Data;
import org.apache.commons.codec.digest.DigestUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
@Data
public class TitleVector {
@ -13,14 +17,23 @@ public class TitleVector {
public List<Long> pointIds;
//题目向量
public List<Float> titleVector;
//题目类型
public String type;
public JsonElement getPointIdsJson() {
Gson gson = new Gson();
return gson.toJsonTree(pointIds);
/**
 * Returns an order-independent fingerprint of the knowledge-point ids.
 *
 * The ids are sorted in natural order (without touching the {@code pointIds} field),
 * joined with commas, and the resulting string is reduced to its MD5 hex digest.
 *
 * @return MD5 hex digest of the sorted, comma-joined ids
 */
public String getPointIdsHash() {
    // Sorting in the stream leaves the backing list in its original order.
    String joinedIds = pointIds.stream()
            .sorted()
            .map(String::valueOf)
            .collect(Collectors.joining(","));
    return DigestUtils.md5Hex(joinedIds);
}
public List<Long> getPointIdsList() {
return pointIds;
List<Long> sorted = new ArrayList<>(pointIds);
Collections.sort(sorted);
return sorted;
}
public JsonElement getTitleVectorJson() {

2
src/main/java/com/project/milvus/domain/service/MilvusDemoService.java

@ -27,5 +27,5 @@ public interface MilvusDemoService {
/**
* 查询关联知识点相似度最大的题目
*/
List<List<SearchResp.SearchResult>> query(List<Long> point_id,List<Float> titleVector);
List<List<SearchResp.SearchResult>> query(TitleVector title);
}

6
src/main/java/com/project/milvus/domain/service/impl/CheckMilvusDomainServiceImpl.java

@ -1,9 +1,11 @@
package com.project.milvus.domain.service.impl;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.EnumUtil;
import com.project.base.domain.exception.MissingParameterException;
import com.project.milvus.domain.dto.TitleVector;
import com.project.milvus.domain.service.CheckMilvusDomainService;
import com.project.task.domain.enums.QuestionTypeEnum;
import org.springframework.stereotype.Service;
import java.util.List;
@ -29,5 +31,9 @@ public class CheckMilvusDomainServiceImpl implements CheckMilvusDomainService {
if(titleVectorList.size() != 1024){
throw new MissingParameterException("向量格式错误");
}
if(!EnumUtil.contains(QuestionTypeEnum.class,title.getType())){
throw new MissingParameterException("题目类型不存在");
}
}
}

34
src/main/java/com/project/milvus/domain/service/impl/MilvusDemoServiceImpl.java

@ -3,6 +3,7 @@ package com.project.milvus.domain.service.impl;
import com.google.gson.JsonObject;
import com.project.milvus.domain.dto.TitleVector;
import com.project.milvus.domain.service.MilvusDemoService;
import com.project.task.domain.enums.QuestionTypeEnum;
import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.common.DataType;
import io.milvus.v2.common.IndexParam;
@ -25,7 +26,7 @@ import java.util.List;
@Slf4j
public class MilvusDemoServiceImpl implements MilvusDemoService {
//集合名称
private static final String COLLECTION_NAME = "titleCollection";
private static final String COLLECTION_NAME = "titleCollection_";
private final MilvusClientV2 client ;
//创建向量数据库连接
@ -78,17 +79,19 @@ public class MilvusDemoServiceImpl implements MilvusDemoService {
* 往collection中插入一条数据
*/
public void insertRecord(TitleVector title) {
QuestionTypeEnum questionTypeEnum = QuestionTypeEnum.valueOf(title.getType());
JsonObject vector = new JsonObject();
vector.addProperty("id", title.getId());
vector.add("point_ids", title.getPointIdsJson());
vector.addProperty("point_ids", title.getPointIdsHash());
vector.add("title_vector", title.getTitleVectorJson());
InsertReq insertReq = InsertReq.builder()
.collectionName(COLLECTION_NAME)
.collectionName(COLLECTION_NAME+questionTypeEnum.getType())
.data(Collections.singletonList(vector))
.build();
InsertResp resp = client.insert(insertReq);
client.insert(insertReq);
}
/**
@ -107,25 +110,16 @@ public class MilvusDemoServiceImpl implements MilvusDemoService {
* 查询关联知识点相似度最大的题目
*/
@Override
public List<List<SearchResp.SearchResult>> query(List<Long> pointIds,List<Float> titleVector){
StringBuffer expr = new StringBuffer();
for (int i = 0; i < pointIds.size(); i++) {
if (pointIds.size() == 1){
expr.append("array_contains(point_ids, " + pointIds.get(i) + ") " + " && array_length(point_ids) == 1");
break;
}
if(i == pointIds.size() - 1){
expr.append("array_contains(point_ids, " + pointIds.get(i) + ")");
continue;
}
expr.append("array_contains(point_ids, " + pointIds.get(i) + ") and ");
}
public List<List<SearchResp.SearchResult>> query(TitleVector title){
QuestionTypeEnum questionTypeEnum = QuestionTypeEnum.valueOf(title.getType());
String expr = String.format("point_ids == '%s'", title.getPointIdsHash());
SearchResp searchReq = client.search(SearchReq.builder()
.collectionName(COLLECTION_NAME)
.data(Collections.singletonList(new FloatVec(titleVector)))
.collectionName(COLLECTION_NAME+questionTypeEnum.getType())
.data(Collections.singletonList(new FloatVec(title.getTitleVectorList())))
.topK(1)
.filter(expr.toString())
.filter(expr)
.outputFields(Collections.singletonList("*"))
.metricType(IndexParam.MetricType.COSINE) // 余弦相似度
.build());

7
src/main/java/com/project/task/domain/enums/QuestionTypeEnum.java

@ -8,10 +8,11 @@ import lombok.RequiredArgsConstructor;
@Getter
@RequiredArgsConstructor
public enum QuestionTypeEnum implements HasValueEnum<Integer> {
SINGLE_CHOICE(1, "单选题"), // 对应黄色背景
MULTIPLE_CHOICE(2, "多选题"), // 对应黄色背景
TRUE_FALSE(3, "判断题"); // 对应绿色背景
SINGLE_CHOICE(1, "单选题","single"), // 对应黄色背景
MULTIPLE_CHOICE(2, "多选题","multiple"), // 对应黄色背景
TRUE_FALSE(3, "判断题","judgment"); // 对应绿色背景
private final Integer value;
private final String description;
private final String type;
}

Loading…
Cancel
Save