7 changed files with 262 additions and 1 deletions
@ -0,0 +1,23 @@ |
|||
package com.project.milvus.config; |
|||
|
|||
import io.milvus.v2.client.ConnectConfig; |
|||
import io.milvus.v2.client.MilvusClientV2; |
|||
import org.springframework.beans.factory.annotation.Value; |
|||
import org.springframework.context.annotation.Bean; |
|||
import org.springframework.context.annotation.Configuration; |
|||
|
|||
@Configuration |
|||
public class MilvusConfig { |
|||
|
|||
@Value("${milvus.host}") |
|||
private String host; |
|||
@Value("${milvus.port}") |
|||
private String port; |
|||
|
|||
@Bean |
|||
public MilvusClientV2 milvusClient(){ |
|||
String uri = "http://"+host+":"+port; |
|||
ConnectConfig connectConfig = ConnectConfig.builder().uri(uri).build(); |
|||
return new MilvusClientV2(connectConfig); |
|||
} |
|||
} |
|||
@ -0,0 +1,16 @@ |
|||
package com.project.milvus.controller; |
|||
|
|||
import com.project.milvus.service.MilvusDemoService; |
|||
import org.springframework.beans.factory.annotation.Autowired; |
|||
import org.springframework.web.bind.annotation.*; |
|||
|
|||
|
|||
@RestController |
|||
@RequestMapping("/milvus") |
|||
public class MilvusController { |
|||
@Autowired |
|||
private MilvusDemoService milvusDemoService; |
|||
|
|||
|
|||
|
|||
} |
|||
@ -0,0 +1,34 @@ |
|||
package com.project.milvus.domain; |
|||
|
|||
import com.google.gson.Gson; |
|||
import com.google.gson.JsonElement; |
|||
import lombok.Data; |
|||
|
|||
import java.util.List; |
|||
|
|||
@Data |
|||
public class TitleVector { |
|||
public Long id; |
|||
//知识点ID
|
|||
public List<Long> pointIds; |
|||
//题目向量
|
|||
public List<Float> titleVector; |
|||
|
|||
public JsonElement getPointIdsJson() { |
|||
Gson gson = new Gson(); |
|||
return gson.toJsonTree(pointIds); |
|||
} |
|||
|
|||
public List<Long> getPointIdsList() { |
|||
return pointIds; |
|||
} |
|||
|
|||
public JsonElement getTitleVectorJson() { |
|||
Gson gson = new Gson(); |
|||
return gson.toJsonTree(titleVector); |
|||
} |
|||
|
|||
public List<Float> getTitleVectorList() { |
|||
return titleVector; |
|||
} |
|||
} |
|||
@ -0,0 +1,31 @@ |
|||
package com.project.milvus.service; |
|||
|
|||
import com.project.milvus.domain.TitleVector; |
|||
import io.milvus.v2.service.vector.response.GetResp; |
|||
import io.milvus.v2.service.vector.response.SearchResp; |
|||
|
|||
import java.util.List; |
|||
|
|||
public interface MilvusDemoService { |
|||
|
|||
/** |
|||
* 创建一个Collection |
|||
*/ |
|||
void createCollection(); |
|||
|
|||
/** |
|||
* 插入向量 |
|||
*/ |
|||
void insertRecord(TitleVector vector); |
|||
|
|||
/** |
|||
* 查询向量 |
|||
*/ |
|||
GetResp getRecord(String id); |
|||
|
|||
|
|||
/** |
|||
* 查询关联知识点相似度最大的题目 |
|||
*/ |
|||
List<List<SearchResp.SearchResult>> query(List<Long> point_id,List<Float> titleVector); |
|||
} |
|||
@ -0,0 +1,133 @@ |
|||
package com.project.milvus.service.impl; |
|||
|
|||
import com.google.gson.JsonObject; |
|||
import com.project.milvus.domain.TitleVector; |
|||
import com.project.milvus.service.MilvusDemoService; |
|||
import io.milvus.v2.client.MilvusClientV2; |
|||
import io.milvus.v2.common.DataType; |
|||
import io.milvus.v2.common.IndexParam; |
|||
import io.milvus.v2.service.collection.request.AddFieldReq; |
|||
import io.milvus.v2.service.collection.request.CreateCollectionReq; |
|||
import io.milvus.v2.service.vector.request.GetReq; |
|||
import io.milvus.v2.service.vector.request.InsertReq; |
|||
import io.milvus.v2.service.vector.request.SearchReq; |
|||
import io.milvus.v2.service.vector.request.data.FloatVec; |
|||
import io.milvus.v2.service.vector.response.GetResp; |
|||
import io.milvus.v2.service.vector.response.InsertResp; |
|||
import io.milvus.v2.service.vector.response.SearchResp; |
|||
import lombok.SneakyThrows; |
|||
import lombok.extern.slf4j.Slf4j; |
|||
import org.springframework.stereotype.Service; |
|||
|
|||
import java.util.Collections; |
|||
import java.util.List; |
|||
|
|||
@Service |
|||
@Slf4j |
|||
public class MilvusDemoServiceImpl implements MilvusDemoService { |
|||
//集合名称
|
|||
private static final String COLLECTION_NAME = "titleCollection"; |
|||
|
|||
private final MilvusClientV2 client ; |
|||
//创建向量数据库连接
|
|||
public MilvusDemoServiceImpl(MilvusClientV2 client ) { |
|||
this.client = client ; |
|||
} |
|||
|
|||
/** |
|||
* 创建一个Collection |
|||
*/ |
|||
@Override |
|||
public void createCollection() { |
|||
CreateCollectionReq.CollectionSchema schema = client.createSchema(); |
|||
|
|||
schema.addField(AddFieldReq.builder() |
|||
.fieldName("id") |
|||
.dataType(DataType.Int64) |
|||
.isPrimaryKey(true) |
|||
.autoID(false) |
|||
.build()); |
|||
|
|||
schema.addField(AddFieldReq.builder() |
|||
.fieldName("point_ids") |
|||
.dataType(DataType.Array) |
|||
.elementType(DataType.Int64) |
|||
.maxCapacity(100) |
|||
.build()); |
|||
|
|||
schema.addField(AddFieldReq.builder() |
|||
.fieldName("title_vector") |
|||
.dataType(DataType.FloatVector) |
|||
.dimension(1024) |
|||
.build()); |
|||
|
|||
IndexParam indexParam = IndexParam.builder() |
|||
.fieldName("title_vector") |
|||
.metricType(IndexParam.MetricType.COSINE) |
|||
.build(); |
|||
|
|||
CreateCollectionReq createCollectionReq = CreateCollectionReq.builder() |
|||
.collectionName(COLLECTION_NAME) |
|||
.collectionSchema(schema) |
|||
.indexParams(Collections.singletonList(indexParam)) |
|||
.build(); |
|||
|
|||
client.createCollection(createCollectionReq); |
|||
} |
|||
|
|||
/** |
|||
* 往collection中插入一条数据 |
|||
*/ |
|||
public void insertRecord(TitleVector title) { |
|||
JsonObject vector = new JsonObject(); |
|||
vector.addProperty("id", title.getId()); |
|||
vector.add("point_ids", title.getPointIdsJson()); |
|||
vector.add("title_vector", title.getTitleVectorJson()); |
|||
|
|||
InsertReq insertReq = InsertReq.builder() |
|||
.collectionName(COLLECTION_NAME) |
|||
.data(Collections.singletonList(vector)) |
|||
.build(); |
|||
InsertResp resp = client.insert(insertReq); |
|||
|
|||
} |
|||
|
|||
/** |
|||
* 通过ID获取记录 |
|||
*/ |
|||
public GetResp getRecord(String id) { |
|||
GetReq getReq = GetReq.builder() |
|||
.collectionName(COLLECTION_NAME) |
|||
.ids(Collections.singletonList(id)) |
|||
.build(); |
|||
GetResp resp = client.get(getReq); |
|||
return resp; |
|||
} |
|||
|
|||
/** |
|||
* 查询关联知识点相似度最大的题目 |
|||
*/ |
|||
@Override |
|||
public List<List<SearchResp.SearchResult>> query(List<Long> pointIds,List<Float> titleVector){ |
|||
StringBuffer expr = new StringBuffer(); |
|||
for (int i = 0; i < pointIds.size(); i++) { |
|||
if(i == pointIds.size() - 1){ |
|||
expr.append("array_contains(point_ids, " + pointIds.get(i) + ")"); |
|||
continue; |
|||
} |
|||
expr.append("array_contains(point_ids, " + pointIds.get(i) + ") and "); |
|||
} |
|||
|
|||
SearchResp searchReq = client.search(SearchReq.builder() |
|||
.collectionName(COLLECTION_NAME) |
|||
.data(Collections.singletonList(new FloatVec(titleVector))) |
|||
.topK(1) |
|||
.filter(expr.toString()) |
|||
.outputFields(Collections.singletonList("*")) |
|||
.metricType(IndexParam.MetricType.COSINE) // 余弦相似度
|
|||
.build()); |
|||
|
|||
return searchReq.getSearchResults(); |
|||
} |
|||
|
|||
} |
|||
Loading…
Reference in new issue