from __future__ import annotations
import logging
import os
from abc import ABC, abstractmethod
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.prompts import (
    BasePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_experimental.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_experimental.rl_chain.metrics import (
    MetricsTrackerAverage,
    MetricsTrackerRollingWindow,
)
from langchain_experimental.rl_chain.model_repository import ModelRepository
from langchain_experimental.rl_chain.vw_logger import VwLogger
if TYPE_CHECKING:
    import vowpal_wabbit_next as vw
logger = logging.getLogger(__name__)
class _BasedOn:
    def __init__(self, value: Any):
        self.value = value
    def __str__(self) -> str:
        return str(self.value)
    __repr__ = __str__
[docs]def BasedOn(anything: Any) -> _BasedOn:
    return _BasedOn(anything) 
class _ToSelectFrom:
    def __init__(self, value: Any):
        self.value = value
    def __str__(self) -> str:
        return str(self.value)
    __repr__ = __str__
[docs]def ToSelectFrom(anything: Any) -> _ToSelectFrom:
    if not isinstance(anything, list):
        raise ValueError("ToSelectFrom must be a list to select from")
    return _ToSelectFrom(anything) 
class _Embed:
    def __init__(self, value: Any, keep: bool = False):
        self.value = value
        self.keep = keep
    def __str__(self) -> str:
        return str(self.value)
    __repr__ = __str__
[docs]def Embed(anything: Any, keep: bool = False) -> Any:
    if isinstance(anything, _ToSelectFrom):
        return ToSelectFrom(Embed(anything.value, keep=keep))
    elif isinstance(anything, _BasedOn):
        return BasedOn(Embed(anything.value, keep=keep))
    if isinstance(anything, list):
        return [Embed(v, keep=keep) for v in anything]
    elif isinstance(anything, dict):
        return {k: Embed(v, keep=keep) for k, v in anything.items()}
    elif isinstance(anything, _Embed):
        return anything
    return _Embed(anything, keep=keep) 
[docs]def EmbedAndKeep(anything: Any) -> Any:
    return Embed(anything, keep=True) 
# helper functions
[docs]def stringify_embedding(embedding: List) -> str:
    return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)]) 
[docs]def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
    return [parser.parse_line(line) for line in input_str.split("\n")] 
[docs]def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
    to_select_from = {
        k: inputs[k].value
        for k in inputs.keys()
        if isinstance(inputs[k], _ToSelectFrom)
    }
    if not to_select_from:
        raise ValueError(
            "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."  # noqa: E501
        )
    based_on = {
        k: inputs[k].value if isinstance(inputs[k].value, list) else [inputs[k].value]
        for k in inputs.keys()
        if isinstance(inputs[k], _BasedOn)
    }
    return based_on, to_select_from 
# end helper functions
[docs]class Selected(ABC):
    pass 
TSelected = TypeVar("TSelected", bound=Selected)
[docs]class Event(Generic[TSelected], ABC):
    inputs: Dict[str, Any]
    selected: Optional[TSelected]
[docs]    def __init__(self, inputs: Dict[str, Any], selected: Optional[TSelected] = None):
        self.inputs = inputs
        self.selected = selected  
TEvent = TypeVar("TEvent", bound=Event)
[docs]class Policy(Generic[TEvent], ABC):
[docs]    def __init__(self, **kwargs: Any):
        pass 
[docs]    @abstractmethod
    def predict(self, event: TEvent) -> Any:
        ... 
[docs]    @abstractmethod
    def learn(self, event: TEvent) -> None:
        ... 
[docs]    @abstractmethod
    def log(self, event: TEvent) -> None:
        ... 
[docs]    def save(self) -> None:
        pass  
[docs]class VwPolicy(Policy):
[docs]    def __init__(
        self,
        model_repo: ModelRepository,
        vw_cmd: List[str],
        feature_embedder: Embedder,
        vw_logger: VwLogger,
        *args: Any,
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)
        self.model_repo = model_repo
        self.workspace = self.model_repo.load(vw_cmd)
        self.feature_embedder = feature_embedder
        self.vw_logger = vw_logger 
[docs]    def predict(self, event: TEvent) -> Any:
        import vowpal_wabbit_next as vw
        text_parser = vw.TextFormatParser(self.workspace)
        return self.workspace.predict_one(
            parse_lines(text_parser, self.feature_embedder.format(event))
        ) 
[docs]    def learn(self, event: TEvent) -> None:
        import vowpal_wabbit_next as vw
        vw_ex = self.feature_embedder.format(event)
        text_parser = vw.TextFormatParser(self.workspace)
        multi_ex = parse_lines(text_parser, vw_ex)
        self.workspace.learn_one(multi_ex) 
[docs]    def log(self, event: TEvent) -> None:
        if self.vw_logger.logging_enabled():
            vw_ex = self.feature_embedder.format(event)
            self.vw_logger.log(vw_ex) 
[docs]    def save(self) -> None:
        self.model_repo.save(self.workspace)  
[docs]class Embedder(Generic[TEvent], ABC):
[docs]    def __init__(self, *args: Any, **kwargs: Any):
        pass 
 
[docs]class SelectionScorer(Generic[TEvent], ABC, BaseModel):
    """Abstract method to grade the chosen selection or the response of the llm"""
[docs]    @abstractmethod
    def score_response(
        self, inputs: Dict[str, Any], llm_response: str, event: TEvent
    ) -> float:
        ...  
[docs]class AutoSelectionScorer(SelectionScorer[Event], BaseModel):
    llm_chain: LLMChain
    prompt: Union[BasePromptTemplate, None] = None
    scoring_criteria_template_str: Optional[str] = None
[docs]    @staticmethod
    def get_default_system_prompt() -> SystemMessagePromptTemplate:
        return SystemMessagePromptTemplate.from_template(
            "PLEASE RESPOND ONLY WITH A SINGLE FLOAT AND NO OTHER TEXT EXPLANATION\n \
                You are a strict judge that is called on to rank a response based on \
                    given criteria. You must respond with your ranking by providing a \
                        single float within the range [0, 1], 0 being very bad \
                            response and 1 being very good response."
        ) 
[docs]    @staticmethod
    def get_default_prompt() -> ChatPromptTemplate:
        human_template = 'Given this based_on "{rl_chain_selected_based_on}" \
            as the most important attribute, rank how good or bad this text is: \
                "{rl_chain_selected}".'
        human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
        default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
        chat_prompt = ChatPromptTemplate.from_messages(
            [default_system_prompt, human_message_prompt]
        )
        return chat_prompt 
    @root_validator(pre=True)
    def set_prompt_and_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        llm = values.get("llm")
        prompt = values.get("prompt")
        scoring_criteria_template_str = values.get("scoring_criteria_template_str")
        if prompt is None and scoring_criteria_template_str is None:
            prompt = AutoSelectionScorer.get_default_prompt()
        elif prompt is None and scoring_criteria_template_str is not None:
            human_message_prompt = HumanMessagePromptTemplate.from_template(
                scoring_criteria_template_str
            )
            default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
            prompt = ChatPromptTemplate.from_messages(
                [default_system_prompt, human_message_prompt]
            )
        values["prompt"] = prompt
        values["llm_chain"] = LLMChain(llm=llm, prompt=prompt)
        return values
[docs]    def score_response(
        self, inputs: Dict[str, Any], llm_response: str, event: Event
    ) -> float:
        ranking = self.llm_chain.predict(llm_response=llm_response, **inputs)
        ranking = ranking.strip()
        try:
            resp = float(ranking)
            return resp
        except Exception as e:
            raise RuntimeError(
                f"The auto selection scorer did not manage to score the response, there is always the option to try again or tweak the reward prompt. Error: {e}"  # noqa: E501
            )  
[docs]class RLChain(Chain, Generic[TEvent]):
    """
    The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
    Attributes:
        - llm_chain (Chain): Represents the underlying Language Model chain.
        - prompt (BasePromptTemplate): The template for the base prompt.
        - selection_scorer (Union[SelectionScorer, None]): Scorer for the selection. Can be set to None.
        - policy (Optional[Policy]): The policy used by the chain to learn to populate a dynamic prompt.
        - auto_embed (bool): Determines if embedding should be automatic. Default is False.
        - metrics (Optional[Union[MetricsTrackerRollingWindow, MetricsTrackerAverage]]): Tracker for metrics, can be set to None.
    Initialization Attributes:
        - feature_embedder (Embedder): Embedder used for the `BasedOn` and `ToSelectFrom` inputs.
        - model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
        - reset_model (bool): If set to True, the model starts training from scratch. Default is False.
        - vw_cmd (List[str], optional): Command line arguments for the VW model.
        - policy (Type[VwPolicy]): Policy used by the chain.
        - vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
        - metrics_step (int): Step for the metrics tracker. Default is -1. If set without metrics_window_size, average metrics will be tracked, otherwise rolling window metrics will be tracked.
        - metrics_window_size (int): Window size for the metrics tracker. Default is -1. If set, rolling window metrics will be tracked.
    Notes:
        The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
    """  # noqa: E501
    class _NoOpPolicy(Policy):
        """Placeholder policy that does nothing"""
        def predict(self, event: TEvent) -> Any:
            return None
        def learn(self, event: TEvent) -> None:
            pass
        def log(self, event: TEvent) -> None:
            pass
    llm_chain: Chain
    output_key: str = "result"  #: :meta private:
    prompt: BasePromptTemplate
    selection_scorer: Union[SelectionScorer, None]
    active_policy: Policy = _NoOpPolicy()
    auto_embed: bool = False
    selection_scorer_activated: bool = True
    selected_input_key = "rl_chain_selected"
    selected_based_on_input_key = "rl_chain_selected_based_on"
    metrics: Optional[Union[MetricsTrackerRollingWindow, MetricsTrackerAverage]] = None
    def __init__(
        self,
        feature_embedder: Embedder,
        model_save_dir: str = "./",
        reset_model: bool = False,
        vw_cmd: Optional[List[str]] = None,
        policy: Type[Policy] = VwPolicy,
        vw_logs: Optional[Union[str, os.PathLike]] = None,
        metrics_step: int = -1,
        metrics_window_size: int = -1,
        *args: Any,
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)
        if self.selection_scorer is None:
            logger.warning(
                "No selection scorer provided, which means that no \
                    reinforcement learning will be done in the RL chain \
                        unless update_with_delayed_score is called."
            )
        if isinstance(self.active_policy, RLChain._NoOpPolicy):
            self.active_policy = policy(
                model_repo=ModelRepository(
                    model_save_dir, with_history=True, reset=reset_model
                ),
                vw_cmd=vw_cmd or [],
                feature_embedder=feature_embedder,
                vw_logger=VwLogger(vw_logs),
            )
        if metrics_window_size > 0:
            self.metrics = MetricsTrackerRollingWindow(
                step=metrics_step, window_size=metrics_window_size
            )
        else:
            self.metrics = MetricsTrackerAverage(step=metrics_step)
    class Config:
        """Configuration for this pydantic object."""
        extra = Extra.forbid
        arbitrary_types_allowed = True
    @property
    def input_keys(self) -> List[str]:
        """Expect input key.
        :meta private:
        """
        return []
    @property
    def output_keys(self) -> List[str]:
        """Expect output key.
        :meta private:
        """
        return [self.output_key]
[docs]    def update_with_delayed_score(
        self, score: float, chain_response: Dict[str, Any], force_score: bool = False
    ) -> None:
        """
        Updates the learned policy with the score provided.
        Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
        """  # noqa: E501
        if self._can_use_selection_scorer() and not force_score:
            raise RuntimeError(
                "The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function."  # noqa: E501
            )
        if self.metrics:
            self.metrics.on_feedback(score)
        event: TEvent = chain_response["selection_metadata"]
        self._call_after_scoring_before_learning(event=event, score=score)
        self.active_policy.learn(event=event)
        self.active_policy.log(event=event) 
[docs]    def deactivate_selection_scorer(self) -> None:
        """
        Deactivates the selection scorer, meaning that the chain will no longer attempt to use the selection scorer to score responses.
        """  # noqa: E501
        self.selection_scorer_activated = False 
[docs]    def activate_selection_scorer(self) -> None:
        """
        Activates the selection scorer, meaning that the chain will attempt to use the selection scorer to score responses.
        """  # noqa: E501
        self.selection_scorer_activated = True 
[docs]    def save_progress(self) -> None:
        """
        This function should be called to save the state of the learned policy model.
        """  # noqa: E501
        self.active_policy.save() 
    def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
        super()._validate_inputs(inputs)
        if (
            self.selected_input_key in inputs.keys()
            or self.selected_based_on_input_key in inputs.keys()
        ):
            raise ValueError(
                f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward."  # noqa: E501
            )
    def _can_use_selection_scorer(self) -> bool:
        """
        Returns whether the chain can use the selection scorer to score responses or not.
        """  # noqa: E501
        return self.selection_scorer is not None and self.selection_scorer_activated
    @abstractmethod
    def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
        ...
    @abstractmethod
    def _call_after_predict_before_llm(
        self, inputs: Dict[str, Any], event: TEvent, prediction: Any
    ) -> Tuple[Dict[str, Any], TEvent]:
        ...
    @abstractmethod
    def _call_after_llm_before_scoring(
        self, llm_response: str, event: TEvent
    ) -> Tuple[Dict[str, Any], TEvent]:
        ...
    @abstractmethod
    def _call_after_scoring_before_learning(
        self, event: TEvent, score: Optional[float]
    ) -> TEvent:
        ...
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        event: TEvent = self._call_before_predict(inputs=inputs)
        prediction = self.active_policy.predict(event=event)
        if self.metrics:
            self.metrics.on_decision()
        next_chain_inputs, event = self._call_after_predict_before_llm(
            inputs=inputs, event=event, prediction=prediction
        )
        t = self.llm_chain.run(**next_chain_inputs, callbacks=_run_manager.get_child())
        _run_manager.on_text(t, color="green", verbose=self.verbose)
        t = t.strip()
        if self.verbose:
            _run_manager.on_text("\nCode: ", verbose=self.verbose)
        output = t
        _run_manager.on_text("\nAnswer: ", verbose=self.verbose)
        _run_manager.on_text(output, color="yellow", verbose=self.verbose)
        next_chain_inputs, event = self._call_after_llm_before_scoring(
            llm_response=output, event=event
        )
        score = None
        try:
            if self._can_use_selection_scorer():
                score = self.selection_scorer.score_response(  # type: ignore
                    inputs=next_chain_inputs, llm_response=output, event=event
                )
        except Exception as e:
            logger.info(
                f"The selection scorer was not able to score, \
                and the chain was not able to adjust to this response, error: {e}"
            )
        if self.metrics and score is not None:
            self.metrics.on_feedback(score)
        event = self._call_after_scoring_before_learning(score=score, event=event)
        self.active_policy.learn(event=event)
        self.active_policy.log(event=event)
        return {self.output_key: {"response": output, "selection_metadata": event}}
    @property
    def _chain_type(self) -> str:
        return "llm_personalizer_chain" 
[docs]def is_stringtype_instance(item: Any) -> bool:
    """Helper function to check if an item is a string."""
    return isinstance(item, str) or (
        isinstance(item, _Embed) and isinstance(item.value, str)
    ) 
[docs]def embed_string_type(
    item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
) -> Dict[str, Union[str, List[str]]]:
    """Helper function to embed a string or an _Embed object."""
    keep_str = ""
    if isinstance(item, _Embed):
        encoded = stringify_embedding(model.encode(item.value))
        if item.keep:
            keep_str = item.value.replace(" ", "_") + " "
    elif isinstance(item, str):
        encoded = item.replace(" ", "_")
    else:
        raise ValueError(f"Unsupported type {type(item)} for embedding")
    if namespace is None:
        raise ValueError(
            "The default namespace must be provided when embedding a string or _Embed object."  # noqa: E501
        )
    return {namespace: keep_str + encoded} 
[docs]def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
    """Helper function to embed a dictionary item."""
    inner_dict: Dict = {}
    for ns, embed_item in item.items():
        if isinstance(embed_item, list):
            inner_dict[ns] = []
            for embed_list_item in embed_item:
                embedded = embed_string_type(embed_list_item, model, ns)
                inner_dict[ns].append(embedded[ns])
        else:
            inner_dict.update(embed_string_type(embed_item, model, ns))
    return inner_dict 
[docs]def embed_list_type(
    item: list, model: Any, namespace: Optional[str] = None
) -> List[Dict[str, Union[str, List[str]]]]:
    ret_list: List = []
    for embed_item in item:
        if isinstance(embed_item, dict):
            ret_list.append(embed_dict_type(embed_item, model))
        elif isinstance(embed_item, list):
            item_embedding = embed_list_type(embed_item, model, namespace)
            # Get the first key from the first dictionary
            first_key = next(iter(item_embedding[0]))
            # Group the values under that key
            grouping = {first_key: [item[first_key] for item in item_embedding]}
            ret_list.append(grouping)
        else:
            ret_list.append(embed_string_type(embed_item, model, namespace))
    return ret_list 
[docs]def embed(
    to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
    model: Any,
    namespace: Optional[str] = None,
) -> List[Dict[str, Union[str, List[str]]]]:
    """
    Embeds the actions or context using the SentenceTransformer model (or a model that has an `encode` function)
    Attributes:
        to_embed: (Union[Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]], required) The text to be embedded, either a string, a list of strings or a dictionary or a list of dictionaries.
        namespace: (str, optional) The default namespace to use when dictionary or list of dictionaries not provided.
        model: (Any, required) The model to use for embedding
    Returns:
        List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value
    """  # noqa: E501
    if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance(
        to_embed, str
    ):
        return [embed_string_type(to_embed, model, namespace)]
    elif isinstance(to_embed, dict):
        return [embed_dict_type(to_embed, model)]
    elif isinstance(to_embed, list):
        return embed_list_type(to_embed, model, namespace)
    else:
        raise ValueError("Invalid input format for embedding")