forked from 1827133/BA-Chatbot
92 lines
2.7 KiB
Python
Executable File
92 lines
2.7 KiB
Python
Executable File
import time
|
|
import json
|
|
from pymilvus import Collection, utility, connections, CollectionSchema, FieldSchema, DataType
|
|
import os
|
|
import sys
|
|
sys.path.append('/root/home/BA_QA_HSMA/backendd')
|
|
|
|
|
|
class MilviusHandler:
    """Wrapper around a single Milvus collection of embedded modules.

    On construction the handler connects to Milvus and then either creates
    the collection (when it does not exist yet) or opens and loads the
    existing one.  The collection schema has three fields:

    * ``id`` -- auto-generated INT64 primary key,
    * ``es_id`` -- VARCHAR of at most 200 characters
      (NOTE(review): presumably the id of the matching document in another
      store -- confirm against the caller),
    * ``module_content`` -- FLOAT_VECTOR holding the embedding.
    """

    # Name of the vector field that is indexed and searched.
    VECTOR_FIELD = "module_content"
    # Distance metric shared by index creation and search so both agree.
    METRIC_TYPE = "L2"
    # Index type and build parameter used by write_data().
    INDEX_TYPE = "IVF_FLAT"
    INDEX_NLIST = 1024

    def __init__(
        self,
        alias: str = "default",
        host: str = "localhost",
        port: str = "19530",
        collection_name: str = "ib",
        dim: int = 4096,
    ) -> None:
        """Connect to Milvus and create or open the collection.

        Args:
            alias: Connection alias registered with pymilvus.
            host: Milvus server host.
            port: Milvus server port.
            collection_name: Name of the collection to create or open.
            dim: Dimension of the ``module_content`` vectors; only used when
                the collection has to be created.
        """
        connections.connect(alias=alias, host=host, port=port)

        if utility.has_collection(collection_name, using=alias):
            # Open the existing collection and make it ready for reads.
            self.collection = Collection(collection_name, using=alias)
            self.collection.load()
            return

        # The collection does not exist yet: define its schema and create it.
        id_field = FieldSchema(
            name="id",
            dtype=DataType.INT64,
            is_primary=True,
            auto_id=True,
        )
        es_id_field = FieldSchema(
            name="es_id",
            dtype=DataType.VARCHAR,
            max_length=200,
        )
        content_field = FieldSchema(
            name=self.VECTOR_FIELD,
            dtype=DataType.FLOAT_VECTOR,
            dim=dim,
        )
        schema = CollectionSchema(
            fields=[content_field, id_field, es_id_field],
            description="module search",
        )
        # A freshly created collection is empty and has no index yet (the
        # index is built in write_data), so it is not loaded here; search()
        # and query() call load() themselves before reading.
        self.collection = Collection(
            name=collection_name,
            schema=schema,
            using=alias,
        )

    def write_data(self, data) -> None:
        """Insert ``data`` and build the vector index on ``module_content``.

        Args:
            data: Passed unchanged to ``Collection.insert``.
                NOTE(review): the expected layout must match the schema
                order (vector, then ``es_id``; ``id`` is auto-generated) --
                confirm against the caller.
        """
        self.collection.insert(data)
        self.collection.create_index(
            field_name=self.VECTOR_FIELD,
            index_params={
                "metric_type": self.METRIC_TYPE,
                "index_type": self.INDEX_TYPE,
                "params": {"nlist": self.INDEX_NLIST},
            },
        )

    def search(self, query_emb, limit: int = 3, nprobe: int = 100):
        """Run a vector similarity search against ``module_content``.

        Args:
            query_emb: Query vectors, passed unchanged to
                ``Collection.search`` (the script entry point of this file
                passes a list containing one embedding).
            limit: Maximum number of hits to return per query vector.
            nprobe: Value of the ``nprobe`` search parameter.

        Returns:
            Whatever ``Collection.search`` returns; each hit is requested
            with its ``es_id`` output field.
        """
        self.collection.load()
        return self.collection.search(
            query_emb,
            anns_field=self.VECTOR_FIELD,
            param={
                "metric_type": self.METRIC_TYPE,
                "params": {"nprobe": nprobe},
            },
            limit=limit,
            output_fields=["es_id"],
        )

    def query(self, expr: str = "id > 0", offset: int = 0, limit: int = 10):
        """Fetch stored rows matching a boolean expression.

        Args:
            expr: Filter expression passed to ``Collection.query``.
            offset: Number of matching rows to skip.
            limit: Maximum number of rows to return.

        Returns:
            Whatever ``Collection.query`` returns, requested with the
            ``id``, ``es_id`` and ``module_content`` fields and "Strong"
            consistency.
        """
        self.collection.load()
        return self.collection.query(
            expr=expr,
            offset=offset,
            limit=limit,
            output_fields=["id", "es_id", self.VECTOR_FIELD],
            consistency_level="Strong",
        )
|
|
if __name__ == "__main__":
    # Manual smoke test: take the embedding of the first document in the
    # JSON dump and print the search result for it.
    with open('../embedded_docs.json', 'r') as docs_file:
        embedded_docs = json.load(docs_file)
    first_embedding = embedded_docs[0]["module_content"]
    handler = MilviusHandler()
    # search() receives a list of query vectors, hence the wrapping list.
    print(handler.search([first_embedding]))
|