# pylint: disable=ungrouped-imports

import logging
import sys
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union

import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from requests.exceptions import HTTPError
from tqdm import tqdm

from haystack.document_stores import BaseDocumentStore
from haystack.errors import HaystackError
from haystack.lazy_imports import LazyImport
from haystack.nodes.retriever import DenseRetriever
from haystack.schema import Document, FilterType
from haystack.telemetry import send_event

logger = logging.getLogger(__name__)


with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_and_transformers_import:
    import torch
    from transformers import AutoConfig

    from haystack.modeling.utils import initialize_device_settings  # pylint: disable=ungrouped-imports

# Make the project-level `api` package importable when this module is run from within the package.
sys.path.append("../..")
from api.embeddingsServiceCaller import EmbeddingServiceCaller  # noqa: E402

# Registry of supported embedding encoders. The external Llama embedding service needs no local
# encoder class, so the entry is an empty placeholder (hence Dict[str, Any] rather than Callable).
_EMBEDDING_ENCODERS: Dict[str, Any] = {"llama": {}}


class LlamaRetriever(DenseRetriever):
    def __init__(
        self,
        model_format: str = "llama",
        document_store: Optional[BaseDocumentStore] = None,
        model_version: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 32,
        max_seq_len: int = 512,
        pooling_strategy: str = "reduce_mean",
        emb_extraction_layer: int = -1,
        top_k: int = 10,
        progress_bar: bool = True,
        devices: Optional[List[Union[str, "torch.device"]]] = None,
        use_auth_token: Optional[Union[str, bool]] = None,
        scale_score: bool = True,
        embed_meta_fields: Optional[List[str]] = None,
        api_key: Optional[str] = None,
        azure_api_version: str = "2022-12-01",
        azure_base_url: Optional[str] = None,
        azure_deployment_name: Optional[str] = None,
        api_base: str = "https://api.openai.com/v1",
        openai_organization: Optional[str] = None,
    ):
"""
|
||
|
:param document_store: An instance of DocumentStore from which to retrieve documents.
|
||
|
:param embedding_model: Local path or name of model in Hugging Face's model hub such
|
||
|
as ``'sentence-transformers/all-MiniLM-L6-v2'``. The embedding model could also
|
||
|
potentially be an OpenAI model ["ada", "babbage", "davinci", "curie"] or
|
||
|
a Cohere model ["small", "medium", "large"].
|
||
|
:param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
|
||
|
:param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
|
||
|
:param batch_size: Number of documents to encode at once.
|
||
|
:param max_seq_len: Longest length of each document sequence. Maximum number of tokens for the document text. Longer ones will be cut down.
|
||
|
:param model_format: Name of framework that was used for saving the model or model type. If no model_format is
|
||
|
provided, it will be inferred automatically from the model configuration files.
|
||
|
Options:
|
||
|
|
||
|
- ``'farm'`` (will use `_DefaultEmbeddingEncoder` as embedding encoder)
|
||
|
- ``'transformers'`` (will use `_DefaultEmbeddingEncoder` as embedding encoder)
|
||
|
- ``'sentence_transformers'`` (will use `_SentenceTransformersEmbeddingEncoder` as embedding encoder)
|
||
|
- ``'retribert'`` (will use `_RetribertEmbeddingEncoder` as embedding encoder)
|
||
|
- ``'openai'``: (will use `_OpenAIEmbeddingEncoder` as embedding encoder)
|
||
|
- ``'cohere'``: (will use `_CohereEmbeddingEncoder` as embedding encoder)
|
||
|
:param pooling_strategy: Strategy for combining the embeddings from the model (for farm / transformers models only).
|
||
|
Options:
|
||
|
|
||
|
- ``'cls_token'`` (sentence vector)
|
||
|
- ``'reduce_mean'`` (sentence vector)
|
||
|
- ``'reduce_max'`` (sentence vector)
|
||
|
- ``'per_token'`` (individual token vectors)
|
||
|
:param emb_extraction_layer: Number of layer from which the embeddings shall be extracted (for farm / transformers models only).
|
||
|
Default: -1 (very last layer).
|
||
|
:param top_k: How many documents to return per query.
|
||
|
:param progress_bar: If true displays progress bar during embedding.
|
||
|
:param devices: List of torch devices (e.g. cuda, cpu, mps) to limit inference to specific devices.
|
||
|
A list containing torch device objects and/or strings is supported (For example
|
||
|
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
|
||
|
parameter is not used and a single cpu device is used for inference.
|
||
|
Note: As multi-GPU training is currently not implemented for EmbeddingRetriever,
|
||
|
training will only use the first device provided in this list.
|
||
|
:param use_auth_token: The API token used to download private models from Huggingface.
|
||
|
If this parameter is set to `True`, then the token generated when running
|
||
|
`transformers-cli login` (stored in ~/.huggingface) will be used.
|
||
|
Additional information can be found here
|
||
|
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
|
||
|
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
|
||
|
If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
|
||
|
Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
|
||
|
:param embed_meta_fields: Concatenate the provided meta fields and text passage / table to a text pair that is
|
||
|
then used to create the embedding.
|
||
|
This approach is also used in the TableTextRetriever paper and is likely to improve
|
||
|
performance if your titles contain meaningful information for retrieval
|
||
|
(topic, entities etc.).
|
||
|
If no value is provided, a default empty list will be created.
|
||
|
:param api_key: The OpenAI API key or the Cohere API key. Required if one wants to use OpenAI/Cohere embeddings.
|
||
|
For more details see https://beta.openai.com/account/api-keys and https://dashboard.cohere.ai/api-keys
|
||
|
:param azure_api_version: The version of the Azure OpenAI API to use. The default is `2022-12-01` version.
|
||
|
:param azure_base_url: The base URL for the Azure OpenAI API. If not supplied, Azure OpenAI API will not be used.
|
||
|
This parameter is an OpenAI Azure endpoint, usually in the form `https://<your-endpoint>.openai.azure.com'
|
||
|
:param azure_deployment_name: The name of the Azure OpenAI API deployment. If not supplied, Azure OpenAI API
|
||
|
will not be used.
|
||
|
:param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"`.
|
||
|
:param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI
|
||
|
[documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
|
||
|
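
        __Usage example__ (a minimal sketch; it assumes a document store whose `embedding_dim`
        matches the output of the external embedding service behind `EmbeddingServiceCaller`):

        ```python
        from haystack.document_stores import InMemoryDocumentStore

        document_store = InMemoryDocumentStore(embedding_dim=4096)  # dim must match the service output
        retriever = LlamaRetriever(document_store=document_store, top_k=5)
        document_store.update_embeddings(retriever)
        ```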
"""
|
||
|
torch_and_transformers_import.check()
|
||
|
|
||
|
if embed_meta_fields is None:
|
||
|
embed_meta_fields = []
|
||
|
super().__init__()
|
||
|
|
||
|
self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=True)
|
||
|
|
||
|
if batch_size < len(self.devices):
|
||
|
logger.warning("Batch size is less than the number of devices.All gpus will not be utilized.")
|
||
|
|
||
|
self.document_store = document_store
|
||
|
self.model_version = model_version
|
||
|
self.use_gpu = use_gpu
|
||
|
self.batch_size = batch_size
|
||
|
self.max_seq_len = max_seq_len
|
||
|
self.pooling_strategy = pooling_strategy
|
||
|
self.emb_extraction_layer = emb_extraction_layer
|
||
|
self.top_k = top_k
|
||
|
self.progress_bar = progress_bar
|
||
|
self.use_auth_token = use_auth_token
|
||
|
self.scale_score = scale_score
|
||
|
self.api_key = api_key
|
||
|
self.api_base = api_base
|
||
|
self.api_version = azure_api_version
|
||
|
self.azure_base_url = azure_base_url
|
||
|
self.azure_deployment_name = azure_deployment_name
|
||
|
self.openai_organization = openai_organization
|
||
|
self.model_format= model_format
|
||
|
self.emb_caller= EmbeddingServiceCaller()
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
self.embed_meta_fields = embed_meta_fields
|
||
|
|
||
|

    def retrieve(
        self,
        query: str,
        filters: Optional[FilterType] = None,
        top_k: Optional[int] = None,
        index: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
        scale_score: Optional[bool] = None,
        document_store: Optional[BaseDocumentStore] = None,
    ) -> List[Document]:
"""
|
||
|
Scan through the documents in a DocumentStore and return a small number of documents
|
||
|
that are most relevant to the query.
|
||
|
|
||
|
:param query: The query
|
||
|
:param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain
|
||
|
conditions.
|
||
|
Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical
|
||
|
operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`,
|
||
|
`"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name.
|
||
|
Logical operator keys take a dictionary of metadata field names and/or logical operators as
|
||
|
value. Metadata field names take a dictionary of comparison operators as value. Comparison
|
||
|
operator keys take a single value or (in case of `"$in"`) a list of values as value.
|
||
|
If no logical operator is provided, `"$and"` is used as default operation. If no comparison
|
||
|
operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default
|
||
|
operation.
|
||
|
|
||
|
__Example__:
|
||
|
|
||
|
```python
|
||
|
filters = {
|
||
|
"$and": {
|
||
|
"type": {"$eq": "article"},
|
||
|
"date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
|
||
|
"rating": {"$gte": 3},
|
||
|
"$or": {
|
||
|
"genre": {"$in": ["economy", "politics"]},
|
||
|
"publisher": {"$eq": "nytimes"}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
# or simpler using default operators
|
||
|
filters = {
|
||
|
"type": "article",
|
||
|
"date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
|
||
|
"rating": {"$gte": 3},
|
||
|
"$or": {
|
||
|
"genre": ["economy", "politics"],
|
||
|
"publisher": "nytimes"
|
||
|
}
|
||
|
}
|
||
|
```
|
||
|
|
||
|
To use the same logical operator multiple times on the same level, logical operators take
|
||
|
optionally a list of dictionaries as value.
|
||
|
|
||
|
__Example__:
|
||
|
|
||
|
```python
|
||
|
filters = {
|
||
|
"$or": [
|
||
|
{
|
||
|
"$and": {
|
||
|
"Type": "News Paper",
|
||
|
"Date": {
|
||
|
"$lt": "2019-01-01"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"$and": {
|
||
|
"Type": "Blog Post",
|
||
|
"Date": {
|
||
|
"$gte": "2019-01-01"
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
]
|
||
|
}
|
||
|
```
|
||
|
:param top_k: How many documents to return per query.
|
||
|
:param index: The name of the index in the DocumentStore from which to retrieve documents
|
||
|
:param headers: Custom HTTP headers to pass to document store client if supported (e.g. {'Authorization': 'Basic API_KEY'} for basic authentication)
|
||
|
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
|
||
|
If true similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
|
||
|
Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
|
||
|
:param document_store: the docstore to use for retrieval. If `None`, the one given in the `__init__` is used instead.
|
||
|
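
        __Usage example__ (a minimal sketch; assumes a `retriever` built as in the class-level
        example, with documents already indexed and embedded):

        ```python
        docs = retriever.retrieve(
            query="Which articles cover the 2020 election?",
            filters={"type": "article"},
            top_k=3,
        )
        ```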
"""
|
||
|
document_store = document_store or self.document_store
|
||
|
if document_store is None:
|
||
|
raise ValueError(
|
||
|
"This Retriever was not initialized with a Document Store. Provide one to the retrieve() method."
|
||
|
)
|
||
|
if top_k is None:
|
||
|
top_k = self.top_k
|
||
|
if index is None:
|
||
|
index = document_store.index
|
||
|
if scale_score is None:
|
||
|
scale_score = self.scale_score
|
||
|
query_emb = self.embed_queries(queries=[query])
|
||
|
documents = document_store.query_by_embedding(
|
||
|
query_emb=query_emb, filters=filters, top_k=top_k, index=index, headers=headers, scale_score=scale_score
|
||
|
)
|
||
|
return documents
|
||
|
|
||
|

    def retrieve_batch(
        self,
        queries: List[str],
        filters: Optional[Union[FilterType, List[Optional[FilterType]]]] = None,
        top_k: Optional[int] = None,
        index: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
        batch_size: Optional[int] = None,
        scale_score: Optional[bool] = None,
        document_store: Optional[BaseDocumentStore] = None,
    ) -> List[List[Document]]:
"""
|
||
|
Scan through the documents in a DocumentStore and return a small number of documents
|
||
|
that are most relevant to the supplied queries.
|
||
|
|
||
|
Returns a list of lists of Documents (one per query).
|
||
|
|
||
|
:param queries: List of query strings.
|
||
|
:param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain
|
||
|
conditions. Can be a single filter that will be applied to each query or a list of filters
|
||
|
(one filter per query).
|
||
|
|
||
|
Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical
|
||
|
operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`,
|
||
|
`"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name.
|
||
|
Logical operator keys take a dictionary of metadata field names and/or logical operators as
|
||
|
value. Metadata field names take a dictionary of comparison operators as value. Comparison
|
||
|
operator keys take a single value or (in case of `"$in"`) a list of values as value.
|
||
|
If no logical operator is provided, `"$and"` is used as default operation. If no comparison
|
||
|
operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default
|
||
|
operation.
|
||
|
|
||
|
__Example__:
|
||
|
|
||
|
```python
|
||
|
filters = {
|
||
|
"$and": {
|
||
|
"type": {"$eq": "article"},
|
||
|
"date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
|
||
|
"rating": {"$gte": 3},
|
||
|
"$or": {
|
||
|
"genre": {"$in": ["economy", "politics"]},
|
||
|
"publisher": {"$eq": "nytimes"}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
# or simpler using default operators
|
||
|
filters = {
|
||
|
"type": "article",
|
||
|
"date": {"$gte": "2015-01-01", "$lt": "2021-01-01"},
|
||
|
"rating": {"$gte": 3},
|
||
|
"$or": {
|
||
|
"genre": ["economy", "politics"],
|
||
|
"publisher": "nytimes"
|
||
|
}
|
||
|
}
|
||
|
```
|
||
|
|
||
|
To use the same logical operator multiple times on the same level, logical operators take
|
||
|
optionally a list of dictionaries as value.
|
||
|
|
||
|
__Example__:
|
||
|
|
||
|
```python
|
||
|
filters = {
|
||
|
"$or": [
|
||
|
{
|
||
|
"$and": {
|
||
|
"Type": "News Paper",
|
||
|
"Date": {
|
||
|
"$lt": "2019-01-01"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"$and": {
|
||
|
"Type": "Blog Post",
|
||
|
"Date": {
|
||
|
"$gte": "2019-01-01"
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
]
|
||
|
}
|
||
|
```
|
||
|
:param top_k: How many documents to return per query.
|
||
|
:param index: The name of the index in the DocumentStore from which to retrieve documents
|
||
|
:param headers: Custom HTTP headers to pass to document store client if supported (e.g. {'Authorization': 'Basic API_KEY'} for basic authentication)
|
||
|
:param batch_size: Number of queries to embed at a time.
|
||
|
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
|
||
|
If true similarity scores (e.g. cosine or dot_product) which naturally have a different
|
||
|
value range will be scaled to a range of [0,1], where 1 means extremely relevant.
|
||
|
Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
|
||
|
:param document_store: the docstore to use for retrieval. If `None`, the one given in the `__init__` is used instead.
|
||
|
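
        __Usage example__ (a minimal sketch; assumes a `retriever` built as in the class-level example):

        ```python
        queries = ["Who wrote Faust?", "What is the capital of France?"]
        results = retriever.retrieve_batch(queries=queries, top_k=3)
        for query, docs in zip(queries, results):
            print(query, [doc.id for doc in docs])
        ```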
"""
|
||
|
document_store = document_store or self.document_store
|
||
|
if document_store is None:
|
||
|
raise ValueError(
|
||
|
"This Retriever was not initialized with a Document Store. Provide one to the retrieve_batch() method."
|
||
|
)
|
||
|
if top_k is None:
|
||
|
top_k = self.top_k
|
||
|
|
||
|
if batch_size is None:
|
||
|
batch_size = self.batch_size
|
||
|
|
||
|
if index is None:
|
||
|
index = document_store.index
|
||
|
if scale_score is None:
|
||
|
scale_score = self.scale_score
|
||
|
|
||
|
# embed_queries is already batched within by batch_size, so no need to batch the input here
|
||
|
query_embs: np.ndarray = self.embed_queries(queries=queries)
|
||
|
batched_query_embs: List[np.ndarray] = []
|
||
|
for i in range(0, len(query_embs), batch_size):
|
||
|
batched_query_embs.extend(query_embs[i : i + batch_size])
|
||
|
documents = document_store.query_by_embedding_batch(
|
||
|
query_embs=batched_query_embs,
|
||
|
top_k=top_k,
|
||
|
filters=filters,
|
||
|
index=index,
|
||
|
headers=headers,
|
||
|
scale_score=scale_score,
|
||
|
)
|
||
|
|
||
|
return documents
|
||
|
|
||
|

    def embed_queries(self, queries: List[str]) -> np.ndarray:
        """Create one embedding per query by calling the external embedding service."""
        if isinstance(queries, str):
            queries = [queries]
        assert isinstance(queries, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
        # Embed every query, not only the first one, so that batch retrieval gets one embedding per query.
        embeddings = [self.emb_caller.get_embeddings(query, embedding_type="last_layer") for query in queries]
        return np.array(embeddings)

    def embed_documents(self, documents: List[Document]) -> np.ndarray:
        """Create one embedding per document by calling the external embedding service."""
        documents = self._preprocess_documents(documents)
        embeddings = []
        # tqdm gives feedback during long embedding runs; it honors the progress_bar flag.
        for doc in tqdm(documents, disable=not self.progress_bar, desc="Embedding documents"):
            embeddings.append(self.emb_caller.get_embeddings(doc.content, embedding_type="last_layer"))
        return np.array(embeddings)

    def _preprocess_documents(self, docs: List[Document]) -> List[Document]:
        """
        Turns table documents into text documents by representing the table in CSV format.
        This allows us to use text embedding models for table retrieval.
        It also concatenates the specified meta data fields with the text representations.

        :param docs: List of documents to linearize. If a document is not a table, it is returned as is.
        :return: List of documents with meta data + linearized tables, or the original documents if they are not tables.
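
        __Example__ (illustrative; shows the linearization with `embed_meta_fields=["title"]`):

        ```python
        import pandas as pd
        from haystack.schema import Document

        table_doc = Document(
            content=pd.DataFrame({"country": ["Germany"], "capital": ["Berlin"]}),
            content_type="table",
            meta={"title": "Capitals"},
        )
        # The linearized content becomes the title followed by the CSV table:
        # Capitals
        # country,capital
        # Germany,Berlin
        ```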
"""
|
||
|
linearized_docs = []
|
||
|
for doc in docs:
|
||
|
doc = deepcopy(doc)
|
||
|
if doc.content_type == "table":
|
||
|
if isinstance(doc.content, pd.DataFrame):
|
||
|
doc.content = doc.content.to_csv(index=False)
|
||
|
else:
|
||
|
raise HaystackError("Documents of type 'table' need to have a pd.DataFrame as content field")
|
||
|
# Gather all relevant metadata fields
|
||
|
meta_data_fields = []
|
||
|
for key in self.embed_meta_fields:
|
||
|
if key in doc.meta and doc.meta[key]:
|
||
|
if isinstance(doc.meta[key], list):
|
||
|
meta_data_fields.extend([item for item in doc.meta[key]])
|
||
|
else:
|
||
|
meta_data_fields.append(doc.meta[key])
|
||
|
# Convert to type string (e.g. for ints or floats)
|
||
|
meta_data_fields = [str(field) for field in meta_data_fields]
|
||
|
doc.content = "\n".join(meta_data_fields + [doc.content])
|
||
|
linearized_docs.append(doc)
|
||
|
return linearized_docs
|
||
|
|
||
|

    @staticmethod
    def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[str, bool]]) -> str:
        """Infer the model format ('openai', 'cohere', 'sentence_transformers', 'retribert', or 'farm') from the model name or path."""
        valid_openai_model_name = model_name_or_path in ["ada", "babbage", "davinci", "curie"] or any(
            m in model_name_or_path for m in ["-ada-", "-babbage-", "-davinci-", "-curie-"]
        )
        if valid_openai_model_name:
            return "openai"
        if model_name_or_path in ["small", "medium", "large", "multilingual-22-12", "finance-sentiment"]:
            return "cohere"
        # Check if the model name is a local directory with a sentence transformers config file in it
        if Path(model_name_or_path).exists():
            if Path(f"{model_name_or_path}/config_sentence_transformers.json").exists():
                return "sentence_transformers"
        # Check if the sentence transformers config file exists in the model hub
        else:
            try:
                hf_hub_download(  # type: ignore [call-arg]
                    repo_id=model_name_or_path,
                    filename="config_sentence_transformers.json",
                    use_auth_token=use_auth_token,
                )
                return "sentence_transformers"
            except HTTPError:
                pass

        # Check if it is a retribert model
        config = AutoConfig.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
        if config.model_type == "retribert":
            return "retribert"

        # The model is neither a sentence-transformers nor a retribert model -> use _DefaultEmbeddingEncoder
        return "farm"

    def train(
        self,
        training_data: List[Dict[str, Any]],
        learning_rate: float = 2e-5,
        n_epochs: int = 1,
        num_warmup_steps: Optional[int] = None,
        batch_size: int = 16,
        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
        num_workers: int = 0,
        use_amp: bool = False,
        **kwargs,
    ) -> None:
"""
|
||
|
Trains/adapts the underlying embedding model. We only support the training of sentence-transformer embedding models.
|
||
|
|
||
|
Each training data example is a dictionary with the following keys:
|
||
|
|
||
|
* question: the question string
|
||
|
* pos_doc: the positive document string
|
||
|
* neg_doc: the negative document string
|
||
|
* score: the score margin
|
||
|
|
||
|
:param training_data: The training data in a dictionary format.
|
||
|
:param learning_rate: The learning rate.
|
||
|
:param n_epochs: The number of epochs that you want the train for.
|
||
|
:param num_warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is
|
||
|
increased from 0 up to the maximal learning rate. After these many training steps, the learning rate is
|
||
|
decreased linearly back to zero.
|
||
|
:param batch_size: The batch size to use for the training. The default values is 16.
|
||
|
:param train_loss: The loss to use for training.
|
||
|
If you're using a sentence-transformer embedding_model (which is the only model that training is supported for),
|
||
|
possible values are 'mnrl' (Multiple Negatives Ranking Loss) or 'margin_mse' (MarginMSE).
|
||
|
:param num_workers: The number of subprocesses to use for the Pytorch DataLoader.
|
||
|
:param use_amp: Use Automatic Mixed Precision (AMP).
|
||
|
:param kwargs: Additional training key word arguments to pass to the `SentenceTransformer.fit` function. Please
|
||
|
reference the Sentence-Transformers [documentation](https://www.sbert.net/docs/training/overview.html#sentence_transformers.SentenceTransformer.fit)
|
||
|
for a full list of keyword arguments.
|
||
|
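
        __Example__ (illustrative training data; hypothetical strings, shown only to make the expected format concrete):

        ```python
        training_data = [
            {
                "question": "What is the capital of Germany?",
                "pos_doc": "Berlin is the capital of Germany.",
                "neg_doc": "Paris is the capital of France.",
                "score": 1.0,
            }
        ]
        ```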
"""
|
||
|
send_event(event_name="Training", event_properties={"class": self.__class__.__name__, "function_name": "train"})
|
||
|
|