BA-Chatbot/backend/database/milvius.py

92 lines
2.7 KiB
Python
Raw Normal View History

2023-11-15 14:28:48 +01:00
import time
import json
from pymilvus import Collection, utility, connections, CollectionSchema, FieldSchema, DataType
import os
import sys
# Make the backend package importable when this module is run directly.
# NOTE(review): hard-coded absolute path; nothing visible in this module imports
# from it, and "backendd" looks like a possible typo for "backend" — confirm.
sys.path.append('/root/home/BA_QA_HSMA/backendd')
class MilviusHandler:
    """Thin wrapper around one Milvus collection of module embeddings.

    Each entity has an auto-generated int64 primary key ``id``, a VARCHAR
    ``es_id`` (presumably the id of the matching Elasticsearch document —
    verify against callers) and a float vector ``module_content``.
    """

    # Index creation and search must use the same metric, so it lives in one place.
    _METRIC_TYPE = "L2"
    _VECTOR_FIELD = "module_content"

    def __init__(
        self,
        collection_name: str = "ib",
        alias: str = "default",
        host: str = "localhost",
        port: str = "19530",
        dim: int = 4096,
    ) -> None:
        """Connect to Milvus and open the collection, creating it if missing.

        Args:
            collection_name: Name of the collection to use.
            alias: Connection alias passed to ``connections.connect``.
            host: Milvus server host.
            port: Milvus server port.
            dim: Embedding dimension; only used when the collection is created.
        """
        connections.connect(alias=alias, host=host, port=port)
        if utility.has_collection(collection_name, using=alias):
            self.collection = Collection(collection_name, using=alias)
            # load() needs an index. A collection created on an earlier run but
            # never written to has none, and loading it would make every
            # construction (and therefore every write_data) fail.
            if self.collection.has_index():
                self.collection.load()
        else:
            self.collection = Collection(
                name=collection_name,
                schema=self._build_schema(dim),
                using=alias,
            )

    @staticmethod
    def _build_schema(dim):
        """Return the collection schema: vector, auto-id primary key, es_id."""
        # Field order is kept as originally defined: column-ordered inserts
        # presumably rely on it.
        content_field = FieldSchema(
            name="module_content",
            dtype=DataType.FLOAT_VECTOR,
            dim=dim,
        )
        id_field = FieldSchema(
            name="id",
            dtype=DataType.INT64,
            is_primary=True,
            auto_id=True,
        )
        es_id_field = FieldSchema(
            name="es_id",
            dtype=DataType.VARCHAR,
            max_length=200,
        )
        return CollectionSchema(
            fields=[content_field, id_field, es_id_field],
            description="module search",
        )

    def write_data(self, data):
        """Insert ``data`` and make sure the vector field has an index.

        Args:
            data: Entities in a form accepted by ``Collection.insert`` (for
                column lists, presumably ``module_content`` then ``es_id``,
                since ``id`` is auto-generated — TODO confirm with caller).
        """
        self.collection.insert(data)
        # Request the IVF index only once. Re-issuing create_index after
        # every insert is redundant, and it fails if the params ever differ.
        if not self.collection.has_index():
            self.collection.create_index(
                field_name=self._VECTOR_FIELD,
                index_params={
                    "metric_type": self._METRIC_TYPE,
                    "index_type": "IVF_FLAT",
                    "params": {"nlist": 1024},
                },
            )
        return

    def search(self, query_emb, limit: int = 3):
        """Return the ``limit`` nearest neighbours for each query vector.

        Args:
            query_emb: List of query vectors.
            limit: Number of hits per query vector (default 3).

        Returns:
            The pymilvus search result; hits include the ``es_id`` field.
        """
        self.collection.load()
        search_params = {
            "metric_type": self._METRIC_TYPE,
            "params": {"nprobe": 100},
        }
        return self.collection.search(
            query_emb,
            anns_field=self._VECTOR_FIELD,
            param=search_params,
            limit=limit,
            output_fields=["es_id"],
        )

    def query(self, limit: int = 10, offset: int = 0):
        """Return up to ``limit`` entities (including vectors) from ``offset``.

        Uses the Strong consistency level and filters on ``id > 0``.
        """
        self.collection.load()
        return self.collection.query(
            expr="id > 0",
            offset=offset,
            limit=limit,
            output_fields=["id", "es_id", "module_content"],
            consistency_level="Strong",
        )
if __name__ == "__main__":
    # Manual smoke test: search with the first embedding from the local JSON dump.
    # NOTE(review): the path is relative to the current working directory.
    with open('../embedded_docs.json', 'r', encoding='utf-8') as f:
        docs = json.load(f)
    first_embedding = docs[0]["module_content"]
    handler = MilviusHandler()
    print(handler.search([first_embedding]))