7 changed files with 262 additions and 1 deletions
@ -0,0 +1,23 @@ |
|||||
|
package com.project.milvus.config; |
||||
|
|
||||
|
import io.milvus.v2.client.ConnectConfig; |
||||
|
import io.milvus.v2.client.MilvusClientV2; |
||||
|
import org.springframework.beans.factory.annotation.Value; |
||||
|
import org.springframework.context.annotation.Bean; |
||||
|
import org.springframework.context.annotation.Configuration; |
||||
|
|
||||
|
@Configuration |
||||
|
public class MilvusConfig { |
||||
|
|
||||
|
@Value("${milvus.host}") |
||||
|
private String host; |
||||
|
@Value("${milvus.port}") |
||||
|
private String port; |
||||
|
|
||||
|
@Bean |
||||
|
public MilvusClientV2 milvusClient(){ |
||||
|
String uri = "http://"+host+":"+port; |
||||
|
ConnectConfig connectConfig = ConnectConfig.builder().uri(uri).build(); |
||||
|
return new MilvusClientV2(connectConfig); |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,16 @@ |
|||||
|
package com.project.milvus.controller; |
||||
|
|
||||
|
import com.project.milvus.service.MilvusDemoService; |
||||
|
import org.springframework.beans.factory.annotation.Autowired; |
||||
|
import org.springframework.web.bind.annotation.*; |
||||
|
|
||||
|
|
||||
|
@RestController |
||||
|
@RequestMapping("/milvus") |
||||
|
public class MilvusController { |
||||
|
@Autowired |
||||
|
private MilvusDemoService milvusDemoService; |
||||
|
|
||||
|
|
||||
|
|
||||
|
} |
||||
@ -0,0 +1,34 @@ |
|||||
|
package com.project.milvus.domain; |
||||
|
|
||||
|
import com.google.gson.Gson; |
||||
|
import com.google.gson.JsonElement; |
||||
|
import lombok.Data; |
||||
|
|
||||
|
import java.util.List; |
||||
|
|
||||
|
@Data |
||||
|
public class TitleVector { |
||||
|
public Long id; |
||||
|
//知识点ID
|
||||
|
public List<Long> pointIds; |
||||
|
//题目向量
|
||||
|
public List<Float> titleVector; |
||||
|
|
||||
|
public JsonElement getPointIdsJson() { |
||||
|
Gson gson = new Gson(); |
||||
|
return gson.toJsonTree(pointIds); |
||||
|
} |
||||
|
|
||||
|
public List<Long> getPointIdsList() { |
||||
|
return pointIds; |
||||
|
} |
||||
|
|
||||
|
public JsonElement getTitleVectorJson() { |
||||
|
Gson gson = new Gson(); |
||||
|
return gson.toJsonTree(titleVector); |
||||
|
} |
||||
|
|
||||
|
public List<Float> getTitleVectorList() { |
||||
|
return titleVector; |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,31 @@ |
|||||
|
package com.project.milvus.service; |
||||
|
|
||||
|
import com.project.milvus.domain.TitleVector; |
||||
|
import io.milvus.v2.service.vector.response.GetResp; |
||||
|
import io.milvus.v2.service.vector.response.SearchResp; |
||||
|
|
||||
|
import java.util.List; |
||||
|
|
||||
|
public interface MilvusDemoService { |
||||
|
|
||||
|
/** |
||||
|
* 创建一个Collection |
||||
|
*/ |
||||
|
void createCollection(); |
||||
|
|
||||
|
/** |
||||
|
* 插入向量 |
||||
|
*/ |
||||
|
void insertRecord(TitleVector vector); |
||||
|
|
||||
|
/** |
||||
|
* 查询向量 |
||||
|
*/ |
||||
|
GetResp getRecord(String id); |
||||
|
|
||||
|
|
||||
|
/** |
||||
|
* 查询关联知识点相似度最大的题目 |
||||
|
*/ |
||||
|
List<List<SearchResp.SearchResult>> query(List<Long> point_id,List<Float> titleVector); |
||||
|
} |
||||
@ -0,0 +1,133 @@ |
|||||
|
package com.project.milvus.service.impl; |
||||
|
|
||||
|
import com.google.gson.JsonObject; |
||||
|
import com.project.milvus.domain.TitleVector; |
||||
|
import com.project.milvus.service.MilvusDemoService; |
||||
|
import io.milvus.v2.client.MilvusClientV2; |
||||
|
import io.milvus.v2.common.DataType; |
||||
|
import io.milvus.v2.common.IndexParam; |
||||
|
import io.milvus.v2.service.collection.request.AddFieldReq; |
||||
|
import io.milvus.v2.service.collection.request.CreateCollectionReq; |
||||
|
import io.milvus.v2.service.vector.request.GetReq; |
||||
|
import io.milvus.v2.service.vector.request.InsertReq; |
||||
|
import io.milvus.v2.service.vector.request.SearchReq; |
||||
|
import io.milvus.v2.service.vector.request.data.FloatVec; |
||||
|
import io.milvus.v2.service.vector.response.GetResp; |
||||
|
import io.milvus.v2.service.vector.response.InsertResp; |
||||
|
import io.milvus.v2.service.vector.response.SearchResp; |
||||
|
import lombok.SneakyThrows; |
||||
|
import lombok.extern.slf4j.Slf4j; |
||||
|
import org.springframework.stereotype.Service; |
||||
|
|
||||
|
import java.util.Collections; |
||||
|
import java.util.List; |
||||
|
|
||||
|
@Service |
||||
|
@Slf4j |
||||
|
public class MilvusDemoServiceImpl implements MilvusDemoService { |
||||
|
//集合名称
|
||||
|
private static final String COLLECTION_NAME = "titleCollection"; |
||||
|
|
||||
|
private final MilvusClientV2 client ; |
||||
|
//创建向量数据库连接
|
||||
|
public MilvusDemoServiceImpl(MilvusClientV2 client ) { |
||||
|
this.client = client ; |
||||
|
} |
||||
|
|
||||
|
/** |
||||
|
* 创建一个Collection |
||||
|
*/ |
||||
|
@Override |
||||
|
public void createCollection() { |
||||
|
CreateCollectionReq.CollectionSchema schema = client.createSchema(); |
||||
|
|
||||
|
schema.addField(AddFieldReq.builder() |
||||
|
.fieldName("id") |
||||
|
.dataType(DataType.Int64) |
||||
|
.isPrimaryKey(true) |
||||
|
.autoID(false) |
||||
|
.build()); |
||||
|
|
||||
|
schema.addField(AddFieldReq.builder() |
||||
|
.fieldName("point_ids") |
||||
|
.dataType(DataType.Array) |
||||
|
.elementType(DataType.Int64) |
||||
|
.maxCapacity(100) |
||||
|
.build()); |
||||
|
|
||||
|
schema.addField(AddFieldReq.builder() |
||||
|
.fieldName("title_vector") |
||||
|
.dataType(DataType.FloatVector) |
||||
|
.dimension(1024) |
||||
|
.build()); |
||||
|
|
||||
|
IndexParam indexParam = IndexParam.builder() |
||||
|
.fieldName("title_vector") |
||||
|
.metricType(IndexParam.MetricType.COSINE) |
||||
|
.build(); |
||||
|
|
||||
|
CreateCollectionReq createCollectionReq = CreateCollectionReq.builder() |
||||
|
.collectionName(COLLECTION_NAME) |
||||
|
.collectionSchema(schema) |
||||
|
.indexParams(Collections.singletonList(indexParam)) |
||||
|
.build(); |
||||
|
|
||||
|
client.createCollection(createCollectionReq); |
||||
|
} |
||||
|
|
||||
|
/** |
||||
|
* 往collection中插入一条数据 |
||||
|
*/ |
||||
|
public void insertRecord(TitleVector title) { |
||||
|
JsonObject vector = new JsonObject(); |
||||
|
vector.addProperty("id", title.getId()); |
||||
|
vector.add("point_ids", title.getPointIdsJson()); |
||||
|
vector.add("title_vector", title.getTitleVectorJson()); |
||||
|
|
||||
|
InsertReq insertReq = InsertReq.builder() |
||||
|
.collectionName(COLLECTION_NAME) |
||||
|
.data(Collections.singletonList(vector)) |
||||
|
.build(); |
||||
|
InsertResp resp = client.insert(insertReq); |
||||
|
|
||||
|
} |
||||
|
|
||||
|
/** |
||||
|
* 通过ID获取记录 |
||||
|
*/ |
||||
|
public GetResp getRecord(String id) { |
||||
|
GetReq getReq = GetReq.builder() |
||||
|
.collectionName(COLLECTION_NAME) |
||||
|
.ids(Collections.singletonList(id)) |
||||
|
.build(); |
||||
|
GetResp resp = client.get(getReq); |
||||
|
return resp; |
||||
|
} |
||||
|
|
||||
|
/** |
||||
|
* 查询关联知识点相似度最大的题目 |
||||
|
*/ |
||||
|
@Override |
||||
|
public List<List<SearchResp.SearchResult>> query(List<Long> pointIds,List<Float> titleVector){ |
||||
|
StringBuffer expr = new StringBuffer(); |
||||
|
for (int i = 0; i < pointIds.size(); i++) { |
||||
|
if(i == pointIds.size() - 1){ |
||||
|
expr.append("array_contains(point_ids, " + pointIds.get(i) + ")"); |
||||
|
continue; |
||||
|
} |
||||
|
expr.append("array_contains(point_ids, " + pointIds.get(i) + ") and "); |
||||
|
} |
||||
|
|
||||
|
SearchResp searchReq = client.search(SearchReq.builder() |
||||
|
.collectionName(COLLECTION_NAME) |
||||
|
.data(Collections.singletonList(new FloatVec(titleVector))) |
||||
|
.topK(1) |
||||
|
.filter(expr.toString()) |
||||
|
.outputFields(Collections.singletonList("*")) |
||||
|
.metricType(IndexParam.MetricType.COSINE) // 余弦相似度
|
||||
|
.build()); |
||||
|
|
||||
|
return searchReq.getSearchResults(); |
||||
|
} |
||||
|
|
||||
|
} |
||||
Loading…
Reference in new issue