Source code for langchain.agents.openai_functions_agent.agent_token_buffer_memory

"""Memory used to save agent output AND intermediate steps."""
from typing import Any, Dict, List

from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, get_buffer_string

from langchain.agents.format_scratchpad.openai_functions import (
    format_to_openai_function_messages,
)
from langchain.memory.chat_memory import BaseChatMemory


[docs]class AgentTokenBufferMemory(BaseChatMemory): """Memory used to save agent output AND intermediate steps.""" human_prefix: str = "Human" ai_prefix: str = "AI" llm: BaseLanguageModel memory_key: str = "history" max_token_limit: int = 12000 """The max number of tokens to keep in the buffer. Once the buffer exceeds this many tokens, the oldest messages will be pruned.""" return_messages: bool = True output_key: str = "output" intermediate_steps_key: str = "intermediate_steps" @property def buffer(self) -> List[BaseMessage]: """String buffer of memory.""" return self.chat_memory.messages @property def memory_variables(self) -> List[str]: """Will always return list of memory variables. :meta private: """ return [self.memory_key]
[docs] def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """Return history buffer.""" if self.return_messages: final_buffer: Any = self.buffer else: final_buffer = get_buffer_string( self.buffer, human_prefix=self.human_prefix, ai_prefix=self.ai_prefix, ) return {self.memory_key: final_buffer}
[docs] def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None: """Save context from this conversation to buffer. Pruned.""" input_str, output_str = self._get_input_output(inputs, outputs) self.chat_memory.add_user_message(input_str) steps = format_to_openai_function_messages(outputs[self.intermediate_steps_key]) for msg in steps: self.chat_memory.add_message(msg) self.chat_memory.add_ai_message(output_str) # Prune buffer if it exceeds max token limit buffer = self.chat_memory.messages curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer) if curr_buffer_length > self.max_token_limit: while curr_buffer_length > self.max_token_limit: buffer.pop(0) curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)