92 lines
2.7 KiB
Python
92 lines
2.7 KiB
Python
|
import time
|
||
|
import json
|
||
|
from pymilvus import Collection, utility, connections, CollectionSchema, FieldSchema, DataType
|
||
|
import os
|
||
|
import sys
|
||
|
sys.path.append('/root/home/BA_QA_HSMA/backendd')
|
||
|
|
||
|
|
||
|
class MilviusHandler:
    """Wrapper around a single Milvus collection used for module search.

    On construction it connects to a Milvus server and either opens (and
    loads) an existing collection or creates it with this schema:

    * ``module_content`` - FLOAT_VECTOR, dim 4096 (the searched embeddings)
    * ``id``             - INT64 primary key, auto-generated by Milvus
    * ``es_id``          - VARCHAR (max 200 chars), returned with search hits

    All constructor arguments default to the values this class always used,
    so ``MilviusHandler()`` connects exactly as before.
    """

    # Name of the vector field that is indexed and searched.
    VECTOR_FIELD = "module_content"
    # Dimension of the embeddings stored in VECTOR_FIELD.
    VECTOR_DIM = 4096

    def __init__(
        self,
        alias: str = "default",
        host: str = "localhost",
        port: str = "19530",
        collection_name: str = "ib",
    ) -> None:
        """Connect to Milvus and open (or create) the collection.

        Args:
            alias: Connection alias registered via ``connections.connect``.
            host: Milvus server host.
            port: Milvus server port.
            collection_name: Collection to open, or to create if missing.
        """
        connections.connect(alias=alias, host=host, port=port)

        if utility.has_collection(collection_name, using=alias):
            # Existing collection: open it and load it so it is searchable.
            self.collection = Collection(collection_name, using=alias)
            self.collection.load()
        else:
            # New collection: it has no index yet, so it is NOT loaded here.
            # NOTE(review): Milvus 2.x refuses to load an unindexed collection;
            # search()/query() call load() themselves once write_data() has
            # built the index.
            self.collection = Collection(
                name=collection_name,
                schema=self._build_schema(),
                using=alias,
            )

    @classmethod
    def _build_schema(cls):
        """Return the CollectionSchema used when the collection is created."""
        id_field = FieldSchema(
            name="id",
            dtype=DataType.INT64,
            is_primary=True,
            auto_id=True,
        )
        es_id_field = FieldSchema(
            name="es_id",
            dtype=DataType.VARCHAR,
            max_length=200,
        )
        content_field = FieldSchema(
            name=cls.VECTOR_FIELD,
            dtype=DataType.FLOAT_VECTOR,
            dim=cls.VECTOR_DIM,
        )
        # Field order is kept exactly as before: column-style data passed to
        # write_data() has to follow the schema order (vectors, then es_id;
        # the auto-generated id is omitted).
        return CollectionSchema(
            fields=[content_field, id_field, es_id_field],
            description="module search",
        )

    def write_data(self, data):
        """Insert ``data`` and make sure the vector index exists.

        Args:
            data: Anything ``Collection.insert`` accepts for this schema.
                Presumably a list of columns ``[vectors, es_ids]`` - confirm
                against callers.
        """
        self.collection.insert(data)

        # Build the IVF_FLAT / L2 index only once. Previously create_index was
        # re-requested after every insert, costing a server round trip each
        # time for an index that already existed.
        if not self.collection.has_index():
            index_params = {
                "metric_type": "L2",
                "index_type": "IVF_FLAT",
                "params": {"nlist": 1024},
            }
            self.collection.create_index(
                field_name=self.VECTOR_FIELD,
                index_params=index_params,
            )

    def search(self, query_emb, limit=3):
        """Return the ``limit`` nearest neighbours (L2) of each query vector.

        Args:
            query_emb: List of query embeddings, e.g. ``[emb]`` for one query.
            limit: Hits returned per query vector (3 by default, as before).

        Returns:
            The ``Collection.search`` result; each hit includes ``es_id``.
        """
        self.collection.load()
        # nprobe = number of IVF clusters (out of nlist=1024) scanned per query.
        search_params = {
            "metric_type": "L2",
            "params": {"nprobe": 100},
        }
        return self.collection.search(
            query_emb,
            anns_field=self.VECTOR_FIELD,
            param=search_params,
            limit=limit,
            output_fields=["es_id"],
        )

    def query(self, offset=0, limit=10):
        """Return stored entities (including their vectors) with ``id > 0``.

        Args:
            offset: Number of matching entities to skip (0 by default).
            limit: Maximum number of entities returned (10 by default).

        Returns:
            The ``Collection.query`` result with ``id``, ``es_id`` and
            ``module_content`` for each entity, read with Strong consistency.
        """
        self.collection.load()
        return self.collection.query(
            expr="id > 0",
            offset=offset,
            limit=limit,
            output_fields=["id", "es_id", self.VECTOR_FIELD],
            consistency_level="Strong",
        )
|
||
|
if __name__ == "__main__":
    # Smoke test: search the collection with the first pre-computed embedding.
    # The path is relative to the current working directory, not this file.
    # JSON is UTF-8 by specification, so read it as such instead of relying
    # on the platform's default encoding.
    with open('../embedded_docs.json', 'r', encoding='utf-8') as f:
        data = json.load(f)

    # assumes the file holds a list of objects carrying a "module_content"
    # embedding of the collection's vector dimension - TODO confirm
    emb = data[0]["module_content"]
    es = MilviusHandler()
    print(es.search([emb]))
|