"""Adapter for <engine_name> vector database."""
import asyncio
from typing import List, Optional, Dict, Any
from concurrent.futures import ThreadPoolExecutor
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.shared.logging_utils import get_logger
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
from ..utils import normalize_distances
from ..vector_db_interface import VectorDBInterface
from ..exceptions import CollectionNotFoundError
# Module-level logger used by the adapter methods in this file.
logger = get_logger()
class IndexSchema(DataPoint):
    """
    Define a schema for indexing data points with a text field.

    This class inherits from the DataPoint class and specifies the structure of a single
    data point that includes a text attribute. The adapter's index_data_points method
    wraps every incoming data point in an IndexSchema (keeping its id) before storing
    it in the index collection.
    """

    # Embeddable text taken from the original data point
    # (index_data_points fills it with DataPoint.get_embeddable_data(data_point)).
    text: str
    # Names the field(s) whose content is embedded for this entry; presumably
    # consumed by DataPoint.get_embeddable_data — confirm against DataPoint.
    # NOTE(review): mutable class-level default; this is only safe if DataPoint
    # is a pydantic model (pydantic copies defaults per instance) — confirm.
    metadata: dict = {"index_fields": ["text"]}
class <EngineName>Adapter(VectorDBInterface):
"""Adapter for <engine_name> vector database operations."""
name = "<EngineName>"
def __init__(
self,
url: str,
api_key: Optional[str] = None,
embedding_engine: EmbeddingEngine = None,
):
self.url = url
self.api_key = api_key
self.embedding_engine = embedding_engine
self.client = None
self.executor = ThreadPoolExecutor()
self._initialize_connection()
def _initialize_connection(self) -> None:
"""Establish connection to <engine_name>."""
try:
# Initialize your vector database client here
# Example: self.client = YourVectorDBClient(url=self.url, api_key=self.api_key)
logger.debug(f"Successfully connected to <engine_name> at {self.url}")
except Exception as e:
logger.error(f"Failed to initialize <engine_name> connection: {e}")
raise
async def has_collection(self, collection_name: str) -> bool:
"""Check if a specified collection exists."""
try:
# Implement collection existence check
# Example: return await self.client.collection_exists(collection_name)
return False
except Exception as e:
logger.error(f"Error checking collection existence: {e}")
return False
async def create_collection(
self,
collection_name: str,
payload_schema: Optional[object] = None,
):
"""Create a new collection with an optional payload schema."""
if await self.has_collection(collection_name):
return
try:
vector_size = self.embedding_engine.get_vector_size()
# Implement collection creation logic
# Example:
# await self.client.create_collection(
# name=collection_name,
# vector_size=vector_size,
# distance_metric="cosine"
# )
logger.debug(f"Created collection: {collection_name}")
except Exception as e:
logger.error(f"Failed to create collection {collection_name}: {e}")
raise
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
"""Insert new data points into the specified collection."""
if not await self.has_collection(collection_name):
await self.create_collection(collection_name)
# Generate embeddings for data points
embeddings = await self.embedding_engine.embed_text(
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
)
try:
# Implement data point insertion logic
# Example:
# formatted_points = [
# {
# "id": str(data_point.id),
# "vector": embeddings[i],
# "payload": data_point.model_dump()
# }
# for i, data_point in enumerate(data_points)
# ]
# await self.client.upsert(collection_name, formatted_points)
logger.debug(f"Inserted {len(data_points)} data points into {collection_name}")
except Exception as e:
logger.error(f"Failed to insert data points: {e}")
raise
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
"""Retrieve data points from a collection using their IDs."""
try:
# Implement data point retrieval logic
# Example:
# results = await self.client.retrieve(collection_name, data_point_ids)
# return [
# ScoredResult(
# id=parse_id(result["id"]),
# payload=result["payload"],
# score=0
# )
# for result in results
# ]
return []
except Exception as e:
logger.error(f"Failed to retrieve data points: {e}")
return []
async def search(
self,
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 15,
with_vector: bool = False,
) -> List[ScoredResult]:
"""Perform a search in the specified collection."""
if query_text is None and query_vector is None:
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
if not await self.has_collection(collection_name):
logger.warning(f"Collection '{collection_name}' not found; returning [].")
return []
if query_vector is None:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
try:
# Implement search logic
# Example:
# results = await self.client.search(
# collection_name=collection_name,
# query_vector=query_vector,
# limit=limit,
# include_vector=with_vector
# )
#
# return [
# ScoredResult(
# id=parse_id(result["id"]),
# payload=result["payload"],
# score=result["score"]
# )
# for result in results
# ]
return []
except Exception as e:
logger.error(f"Error searching collection '{collection_name}': {e}")
return []
async def batch_search(
self,
collection_name: str,
query_texts: List[str],
limit: int,
with_vectors: bool = False,
):
"""Perform a batch search using multiple text queries."""
query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
*[
self.search(
collection_name=collection_name,
query_vector=query_vector,
limit=limit,
with_vector=with_vectors,
)
for query_vector in query_vectors
]
)
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]):
"""Delete specified data points from a collection."""
try:
# Implement deletion logic
# Example:
# result = await self.client.delete(collection_name, data_point_ids)
# return result
logger.debug(f"Deleted {len(data_point_ids)} data points from {collection_name}")
except Exception as e:
logger.error(f"Failed to delete data points: {e}")
raise
async def create_vector_index(self, index_name: str, index_property_name: str):
"""Create a vector index based on an index name and property name."""
return await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: List[DataPoint]
):
"""Index a list of data points by creating an associated vector index collection."""
formatted_data_points = [
IndexSchema(
id=data_point.id,
text=DataPoint.get_embeddable_data(data_point),
)
for data_point in data_points
]
return await self.create_data_points(
f"{index_name}_{index_property_name}",
formatted_data_points,
)
async def prune(self):
"""Remove all data from the vector database."""
try:
# Implement pruning logic - delete all collections or data
# Example:
# collections = await self.client.list_collections()
# for collection in collections:
# await self.client.delete_collection(collection.name)
logger.debug("Pruned all data from <engine_name>")
except Exception as e:
logger.error(f"Failed to prune data: {e}")
raise