Browse Source

初始化向量数据库

master
luogw 1 month ago
parent
commit
df26a8dd98
  1. 20
      pom.xml
  2. 23
      src/main/java/com/project/milvus/config/MilvusConfig.java
  3. 16
      src/main/java/com/project/milvus/controller/MilvusController.java
  4. 34
      src/main/java/com/project/milvus/domain/TitleVector.java
  5. 31
      src/main/java/com/project/milvus/service/MilvusDemoService.java
  6. 133
      src/main/java/com/project/milvus/service/impl/MilvusDemoServiceImpl.java
  7. 4
      src/main/resources/application.yml

20
pom.xml

@ -146,5 +146,25 @@
<groupId>org.apache.commons</groupId> <groupId>org.apache.commons</groupId>
<artifactId>commons-pool2</artifactId> <artifactId>commons-pool2</artifactId>
</dependency> </dependency>
<!-- milvus向量数据库 -->
<dependency>
<groupId>io.milvus</groupId>
<artifactId>milvus-sdk-java</artifactId>
<version>2.5.4</version>
<exclusions>
<exclusion>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>3.24.0</version>
<scope>compile</scope>
</dependency>
</dependencies> </dependencies>
</project> </project>

23
src/main/java/com/project/milvus/config/MilvusConfig.java

@ -0,0 +1,23 @@
package com.project.milvus.config;
import io.milvus.v2.client.ConnectConfig;
import io.milvus.v2.client.MilvusClientV2;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class MilvusConfig {
@Value("${milvus.host}")
private String host;
@Value("${milvus.port}")
private String port;
@Bean
public MilvusClientV2 milvusClient(){
String uri = "http://"+host+":"+port;
ConnectConfig connectConfig = ConnectConfig.builder().uri(uri).build();
return new MilvusClientV2(connectConfig);
}
}

16
src/main/java/com/project/milvus/controller/MilvusController.java

@ -0,0 +1,16 @@
package com.project.milvus.controller;
import com.project.milvus.service.MilvusDemoService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
@RestController
@RequestMapping("/milvus")
public class MilvusController {
@Autowired
private MilvusDemoService milvusDemoService;
}

34
src/main/java/com/project/milvus/domain/TitleVector.java

@ -0,0 +1,34 @@
package com.project.milvus.domain;
import com.google.gson.Gson;
import com.google.gson.JsonElement;
import lombok.Data;
import java.util.List;
@Data
public class TitleVector {
public Long id;
//知识点ID
public List<Long> pointIds;
//题目向量
public List<Float> titleVector;
public JsonElement getPointIdsJson() {
Gson gson = new Gson();
return gson.toJsonTree(pointIds);
}
public List<Long> getPointIdsList() {
return pointIds;
}
public JsonElement getTitleVectorJson() {
Gson gson = new Gson();
return gson.toJsonTree(titleVector);
}
public List<Float> getTitleVectorList() {
return titleVector;
}
}

31
src/main/java/com/project/milvus/service/MilvusDemoService.java

@ -0,0 +1,31 @@
package com.project.milvus.service;
import com.project.milvus.domain.TitleVector;
import io.milvus.v2.service.vector.response.GetResp;
import io.milvus.v2.service.vector.response.SearchResp;
import java.util.List;
public interface MilvusDemoService {
/**
* 创建一个Collection
*/
void createCollection();
/**
* 插入向量
*/
void insertRecord(TitleVector vector);
/**
* 查询向量
*/
GetResp getRecord(String id);
/**
* 查询关联知识点相似度最大的题目
*/
List<List<SearchResp.SearchResult>> query(List<Long> point_id,List<Float> titleVector);
}

133
src/main/java/com/project/milvus/service/impl/MilvusDemoServiceImpl.java

@ -0,0 +1,133 @@
package com.project.milvus.service.impl;
import com.google.gson.JsonObject;
import com.project.milvus.domain.TitleVector;
import com.project.milvus.service.MilvusDemoService;
import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.common.DataType;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.AddFieldReq;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.vector.request.GetReq;
import io.milvus.v2.service.vector.request.InsertReq;
import io.milvus.v2.service.vector.request.SearchReq;
import io.milvus.v2.service.vector.request.data.FloatVec;
import io.milvus.v2.service.vector.response.GetResp;
import io.milvus.v2.service.vector.response.InsertResp;
import io.milvus.v2.service.vector.response.SearchResp;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.Collections;
import java.util.List;
@Service
@Slf4j
public class MilvusDemoServiceImpl implements MilvusDemoService {
//集合名称
private static final String COLLECTION_NAME = "titleCollection";
private final MilvusClientV2 client ;
//创建向量数据库连接
public MilvusDemoServiceImpl(MilvusClientV2 client ) {
this.client = client ;
}
/**
* 创建一个Collection
*/
@Override
public void createCollection() {
CreateCollectionReq.CollectionSchema schema = client.createSchema();
schema.addField(AddFieldReq.builder()
.fieldName("id")
.dataType(DataType.Int64)
.isPrimaryKey(true)
.autoID(false)
.build());
schema.addField(AddFieldReq.builder()
.fieldName("point_ids")
.dataType(DataType.Array)
.elementType(DataType.Int64)
.maxCapacity(100)
.build());
schema.addField(AddFieldReq.builder()
.fieldName("title_vector")
.dataType(DataType.FloatVector)
.dimension(1024)
.build());
IndexParam indexParam = IndexParam.builder()
.fieldName("title_vector")
.metricType(IndexParam.MetricType.COSINE)
.build();
CreateCollectionReq createCollectionReq = CreateCollectionReq.builder()
.collectionName(COLLECTION_NAME)
.collectionSchema(schema)
.indexParams(Collections.singletonList(indexParam))
.build();
client.createCollection(createCollectionReq);
}
/**
* 往collection中插入一条数据
*/
public void insertRecord(TitleVector title) {
JsonObject vector = new JsonObject();
vector.addProperty("id", title.getId());
vector.add("point_ids", title.getPointIdsJson());
vector.add("title_vector", title.getTitleVectorJson());
InsertReq insertReq = InsertReq.builder()
.collectionName(COLLECTION_NAME)
.data(Collections.singletonList(vector))
.build();
InsertResp resp = client.insert(insertReq);
}
/**
* 通过ID获取记录
*/
public GetResp getRecord(String id) {
GetReq getReq = GetReq.builder()
.collectionName(COLLECTION_NAME)
.ids(Collections.singletonList(id))
.build();
GetResp resp = client.get(getReq);
return resp;
}
/**
* 查询关联知识点相似度最大的题目
*/
@Override
public List<List<SearchResp.SearchResult>> query(List<Long> pointIds,List<Float> titleVector){
StringBuffer expr = new StringBuffer();
for (int i = 0; i < pointIds.size(); i++) {
if(i == pointIds.size() - 1){
expr.append("array_contains(point_ids, " + pointIds.get(i) + ")");
continue;
}
expr.append("array_contains(point_ids, " + pointIds.get(i) + ") and ");
}
SearchResp searchReq = client.search(SearchReq.builder()
.collectionName(COLLECTION_NAME)
.data(Collections.singletonList(new FloatVec(titleVector)))
.topK(1)
.filter(expr.toString())
.outputFields(Collections.singletonList("*"))
.metricType(IndexParam.MetricType.COSINE) // 余弦相似度
.build());
return searchReq.getSearchResults();
}
}

4
src/main/resources/application.yml

@ -37,3 +37,7 @@ mybatis-plus:
map-underscore-to-camel-case: true map-underscore-to-camel-case: true
mapper-locations: classpath*:mapper/**/*.xml mapper-locations: classpath*:mapper/**/*.xml
type-aliases-package: com.proposal.**.domain.entity type-aliases-package: com.proposal.**.domain.entity
milvus:
host: 172.16.204.50
port: 19530
Loading…
Cancel
Save