Source code for langchain_community.embeddings.vertexai

from typing import Dict, List

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import root_validator

from langchain_community.llms.vertexai import _VertexAICommon
from langchain_community.utilities.vertexai import raise_vertex_import_error


class VertexAIEmbeddings(_VertexAICommon, Embeddings):
    """Google Cloud VertexAI embedding models."""

    # Name handed to ``TextEmbeddingModel.from_pretrained`` during validation.
    model_name: str = "textembedding-gecko"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validates that the python package exists in environment.

        Imports ``TextEmbeddingModel`` from ``vertexai.language_models`` and
        stores a pretrained model instance under ``values["client"]``.

        Args:
            values: Field values collected by the validator; ``"model_name"``
                is read and ``"client"`` is written.

        Returns:
            The same ``values`` mapping, with ``"client"`` populated.
        """
        # Project helper inherited from _VertexAICommon; presumably sets up the
        # Vertex AI SDK from the supplied values -- confirm against its definition.
        cls._try_init_vertexai(values)
        try:
            from vertexai.language_models import TextEmbeddingModel
        except ImportError:
            # Expected to raise a descriptive ImportError (defined elsewhere).
            raise_vertex_import_error()
        values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
        return values

    def embed_documents(
        self, texts: List[str], batch_size: int = 5
    ) -> List[List[float]]:
        """Embed a list of strings.

        Vertex AI currently sets a max batch size of 5 strings.

        Args:
            texts: The list of strings to embed.
            batch_size: The number of strings sent to the model per request.
                Must be at least 1.

        Returns:
            List of embeddings, one for each text, in input order.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        # A negative step would make range() empty and silently drop every
        # text; a zero step raises an opaque error. Reject both explicitly.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}."
            )

        embeddings: List[List[float]] = []
        for start in range(0, len(texts), batch_size):
            text_batch = texts[start : start + batch_size]
            embeddings_batch = self.client.get_embeddings(text_batch)
            embeddings.extend(el.values for el in embeddings_batch)
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        # The client takes a list of texts, so wrap the single query and
        # unwrap the single result.
        embeddings = self.client.get_embeddings([text])
        return embeddings[0].values