|
|
@ -3,6 +3,7 @@ package com.project.milvus.domain.service.impl; |
|
|
import com.google.gson.JsonObject; |
|
|
import com.google.gson.JsonObject; |
|
|
import com.project.milvus.domain.dto.TitleVector; |
|
|
import com.project.milvus.domain.dto.TitleVector; |
|
|
import com.project.milvus.domain.service.MilvusDemoService; |
|
|
import com.project.milvus.domain.service.MilvusDemoService; |
|
|
|
|
|
import com.project.task.domain.enums.QuestionTypeEnum; |
|
|
import io.milvus.v2.client.MilvusClientV2; |
|
|
import io.milvus.v2.client.MilvusClientV2; |
|
|
import io.milvus.v2.common.DataType; |
|
|
import io.milvus.v2.common.DataType; |
|
|
import io.milvus.v2.common.IndexParam; |
|
|
import io.milvus.v2.common.IndexParam; |
|
|
@ -25,7 +26,7 @@ import java.util.List; |
|
|
@Slf4j |
|
|
@Slf4j |
|
|
public class MilvusDemoServiceImpl implements MilvusDemoService { |
|
|
public class MilvusDemoServiceImpl implements MilvusDemoService { |
|
|
//集合名称
|
|
|
//集合名称
|
|
|
private static final String COLLECTION_NAME = "titleCollection"; |
|
|
private static final String COLLECTION_NAME = "titleCollection_"; |
|
|
|
|
|
|
|
|
private final MilvusClientV2 client ; |
|
|
private final MilvusClientV2 client ; |
|
|
//创建向量数据库连接
|
|
|
//创建向量数据库连接
|
|
|
@ -78,17 +79,19 @@ public class MilvusDemoServiceImpl implements MilvusDemoService { |
|
|
* 往collection中插入一条数据 |
|
|
* 往collection中插入一条数据 |
|
|
*/ |
|
|
*/ |
|
|
public void insertRecord(TitleVector title) { |
|
|
public void insertRecord(TitleVector title) { |
|
|
|
|
|
QuestionTypeEnum questionTypeEnum = QuestionTypeEnum.valueOf(title.getType()); |
|
|
|
|
|
|
|
|
JsonObject vector = new JsonObject(); |
|
|
JsonObject vector = new JsonObject(); |
|
|
vector.addProperty("id", title.getId()); |
|
|
vector.addProperty("id", title.getId()); |
|
|
vector.add("point_ids", title.getPointIdsJson()); |
|
|
vector.addProperty("point_ids", title.getPointIdsHash()); |
|
|
vector.add("title_vector", title.getTitleVectorJson()); |
|
|
vector.add("title_vector", title.getTitleVectorJson()); |
|
|
|
|
|
|
|
|
InsertReq insertReq = InsertReq.builder() |
|
|
InsertReq insertReq = InsertReq.builder() |
|
|
.collectionName(COLLECTION_NAME) |
|
|
.collectionName(COLLECTION_NAME+questionTypeEnum.getType()) |
|
|
.data(Collections.singletonList(vector)) |
|
|
.data(Collections.singletonList(vector)) |
|
|
.build(); |
|
|
.build(); |
|
|
InsertResp resp = client.insert(insertReq); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
client.insert(insertReq); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/** |
|
|
/** |
|
|
@ -107,25 +110,16 @@ public class MilvusDemoServiceImpl implements MilvusDemoService { |
|
|
* 查询关联知识点相似度最大的题目 |
|
|
* 查询关联知识点相似度最大的题目 |
|
|
*/ |
|
|
*/ |
|
|
@Override |
|
|
@Override |
|
|
public List<List<SearchResp.SearchResult>> query(List<Long> pointIds,List<Float> titleVector){ |
|
|
public List<List<SearchResp.SearchResult>> query(TitleVector title){ |
|
|
StringBuffer expr = new StringBuffer(); |
|
|
QuestionTypeEnum questionTypeEnum = QuestionTypeEnum.valueOf(title.getType()); |
|
|
for (int i = 0; i < pointIds.size(); i++) { |
|
|
|
|
|
if (pointIds.size() == 1){ |
|
|
String expr = String.format("point_ids == '%s'", title.getPointIdsHash()); |
|
|
expr.append("array_contains(point_ids, " + pointIds.get(i) + ") " + " && array_length(point_ids) == 1"); |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
if(i == pointIds.size() - 1){ |
|
|
|
|
|
expr.append("array_contains(point_ids, " + pointIds.get(i) + ")"); |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
expr.append("array_contains(point_ids, " + pointIds.get(i) + ") and "); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
SearchResp searchReq = client.search(SearchReq.builder() |
|
|
SearchResp searchReq = client.search(SearchReq.builder() |
|
|
.collectionName(COLLECTION_NAME) |
|
|
.collectionName(COLLECTION_NAME+questionTypeEnum.getType()) |
|
|
.data(Collections.singletonList(new FloatVec(titleVector))) |
|
|
.data(Collections.singletonList(new FloatVec(title.getTitleVectorList()))) |
|
|
.topK(1) |
|
|
.topK(1) |
|
|
.filter(expr.toString()) |
|
|
.filter(expr) |
|
|
.outputFields(Collections.singletonList("*")) |
|
|
.outputFields(Collections.singletonList("*")) |
|
|
.metricType(IndexParam.MetricType.COSINE) // 余弦相似度
|
|
|
.metricType(IndexParam.MetricType.COSINE) // 余弦相似度
|
|
|
.build()); |
|
|
.build()); |
|
|
|