Source code for langchain.memory.motorhead_memory

from typing import Any, Dict, List, Optional

import requests
from langchain_core.messages import get_buffer_string

from langchain.memory.chat_memory import BaseChatMemory

# Base URL of the managed Motorhead API. MotorheadMemory defaults its `url` to
# this value and, when `url` equals it, requires and sends the Metal API key
# and client ID headers.
MANAGED_URL = "https://api.getmetal.io/v1/motorhead"
# LOCAL_URL = "http://localhost:8080"


class MotorheadMemory(BaseChatMemory):
    """Chat message memory backed by Motorhead service.

    Every exchange saved through ``save_context`` is POSTed to
    ``{url}/sessions/{session_id}/memory`` and also recorded in the local
    ``chat_memory``. ``init`` pulls messages (and the optional context
    string) already stored for the session back into the local memory.

    When ``url`` equals ``MANAGED_URL`` both ``api_key`` and ``client_id``
    must be set; they are sent as the ``x-metal-api-key`` and
    ``x-metal-client-id`` headers on every request.
    """

    # Base URL of the Motorhead server; defaults to the managed endpoint.
    url: str = MANAGED_URL
    # Passed unchanged as ``timeout=`` to every ``requests`` call.
    # NOTE(review): requests interprets this value in seconds, so the default
    # is 50 minutes -- confirm whether milliseconds were intended.
    timeout: int = 3000
    # Key under which ``load_memory_variables`` exposes the history.
    memory_key: str = "history"
    # Identifier of the Motorhead session, interpolated into the request path.
    session_id: str
    # Context string returned by the server, populated by ``init``.
    context: Optional[str] = None

    # Managed Params
    api_key: Optional[str] = None
    client_id: Optional[str] = None

    def __get_headers(self) -> Dict[str, str]:
        """Build the HTTP headers used for every Motorhead request.

        Returns:
            A dict with the JSON content type and, when talking to the
            managed endpoint, the Metal credential headers.

        Raises:
            ValueError: If ``url`` is the managed endpoint and ``api_key``
                or ``client_id`` is missing.
        """
        headers = {
            "Content-Type": "application/json",
        }

        # Self-hosted servers need no credentials.
        if self.url != MANAGED_URL:
            return headers

        if not (self.api_key and self.client_id):
            raise ValueError(
                """
                You must provide an API key and a client ID to use the managed
                version of Motorhead. Visit https://getmetal.io for more information.
                """
            )

        headers["x-metal-api-key"] = self.api_key
        headers["x-metal-client-id"] = self.client_id
        return headers

    async def init(self) -> None:
        """Load the session's stored messages and context into local memory.

        NOTE(review): although declared ``async``, this performs a blocking
        ``requests.get`` call; it is kept as-is so existing ``await`` callers
        continue to work.

        Raises:
            ValueError: If managed credentials are required but missing.
        """
        res = requests.get(
            f"{self.url}/sessions/{self.session_id}/memory",
            timeout=self.timeout,
            headers=self.__get_headers(),
        )
        res_data = res.json()
        # Handle Managed Version: unwrap the "data" envelope when present,
        # otherwise use the payload as-is.
        res_data = res_data.get("data", res_data)

        messages = res_data.get("messages", [])
        context = res_data.get("context", "NONE")

        # Replay in reverse of the order returned by the server.
        for message in reversed(messages):
            if message["role"] == "AI":
                self.chat_memory.add_ai_message(message["content"])
            else:
                self.chat_memory.add_user_message(message["content"])

        # "NONE" is treated as "no context available".
        if context and context != "NONE":
            self.context = context

    def load_memory_variables(self, values: Dict[str, Any]) -> Dict[str, Any]:
        """Return the chat history under ``memory_key``.

        Args:
            values: Unused; accepted for interface compatibility.

        Returns:
            A single-entry dict whose value is the list of message objects
            when ``return_messages`` is true, otherwise the messages
            rendered to one string by ``get_buffer_string``.
        """
        if self.return_messages:
            return {self.memory_key: self.chat_memory.messages}
        else:
            return {self.memory_key: get_buffer_string(self.chat_memory.messages)}

    @property
    def memory_variables(self) -> List[str]:
        """Names of the variables this memory provides (just ``memory_key``)."""
        return [self.memory_key]

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Send one human/AI exchange to Motorhead, then store it locally.

        Args:
            inputs: Chain inputs containing the human message.
            outputs: Chain outputs containing the AI message.

        Raises:
            ValueError: If managed credentials are required but missing.
        """
        input_str, output_str = self._get_input_output(inputs, outputs)
        requests.post(
            f"{self.url}/sessions/{self.session_id}/memory",
            timeout=self.timeout,
            json={
                "messages": [
                    {"role": "Human", "content": f"{input_str}"},
                    {"role": "AI", "content": f"{output_str}"},
                ]
            },
            headers=self.__get_headers(),
        )
        super().save_context(inputs, outputs)

    def delete_session(self) -> None:
        """Delete a session

        Sends the same headers and timeout as the other requests, so the
        call carries the Metal credentials on the managed endpoint and
        cannot block indefinitely.

        Raises:
            ValueError: If managed credentials are required but missing.
        """
        requests.delete(
            f"{self.url}/sessions/{self.session_id}/memory",
            timeout=self.timeout,
            headers=self.__get_headers(),
        )