from __future__ import annotations
import json
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Callable,
Iterator,
List,
Literal,
Optional,
Sequence,
Type,
Union,
cast,
)
from pydantic import BaseModel
from imagine.langchain.mixins import BaseLangChainMixin
from imagine.types.chat_completions import ChatMessage as ImagineChatMessage
try:
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolCall,
ToolMessage,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
make_invalid_tool_call,
)
# parse_tool_call,
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
except ImportError:
raise ImportError(
"LangChain dependencies are missing. Please install with 'pip install imagine-sdk[langchain]' to add them."
)
from imagine.types.chat_completions import (
ChatCompletionResponse,
ChatCompletionStreamResponse,
)
def _convert_imagine_message_to_lc_message(
_message: ImagineChatMessage,
) -> BaseMessage:
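    """Convert an Imagine ``ChatMessage`` into the matching LangChain message.

    Assistant tool calls are parsed into ``ToolCall`` objects; any call whose
    arguments cannot be parsed is recorded as an invalid tool call instead.
    """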
role = _message.role
content = _message.content
if role == "user":
return HumanMessage(content=content)
elif role == "assistant":
additional_kwargs: dict = {}
tool_calls = []
invalid_tool_calls = []
if response_tool_calls := _message.tool_calls:
# additional_kwargs["tool_calls"] = response_tool_calls
for response_tool_call in response_tool_calls:
try:
# tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
tool_calls.append(
ToolCall(
name=response_tool_call.function.name,
args=json.loads(response_tool_call.function.arguments),
id=response_tool_call.id,
type="tool_call",
)
)
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(response_tool_call, str(e))
)
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
elif role == "system":
return SystemMessage(content=content)
elif role == "tool":
additional_kwargs = {}
additional_kwargs["name"] = _message.name
return ToolMessage(
content=content,
tool_call_id=_message.tool_call_id,
additional_kwargs=additional_kwargs,
)
else:
return ChatMessage(content=content, role=role)
def _convert_chunk_to_message_chunk(
chunk: ChatCompletionStreamResponse, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
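    """Convert a streaming response chunk into a LangChain message chunk.

    The chunk class is picked from the delta's role, falling back to
    ``default_class``; usage metadata is attached to ``AIMessageChunk``
    instances when the chunk reports token usage.
    """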
_delta = chunk.choices[0].delta
role = _delta.role
content = _delta.content or ""
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
if token_usage := chunk.usage:
usage_metadata = {
"input_tokens": token_usage.prompt_tokens,
"output_tokens": token_usage.completion_tokens,
"total_tokens": token_usage.total_tokens,
}
else:
usage_metadata = None
return AIMessageChunk(
content=content,
usage_metadata=usage_metadata, # type: ignore[arg-type]
)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role) # type: ignore
else:
return default_class(content=content) # type: ignore[call-arg]
def _lc_tool_call_to_imagine_tool_call(tool_call: ToolCall) -> dict:
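    """Convert a LangChain ``ToolCall`` into the OpenAI-style tool call dict."""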
return {
"type": "function",
"id": tool_call["id"],
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
},
}
def _convert_lc_message_to_dict_message(
message: BaseMessage,
) -> dict:
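    """Convert a LangChain message into the role/content dict sent to the API."""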
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict: dict[str, Any] = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = [
_lc_tool_call_to_imagine_tool_call(tc) for tc in message.tool_calls
]
elif "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"content": message.content,
"tool_call_id": message.tool_call_id,
}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
class ImagineChat(BaseChatModel, BaseLangChainMixin):
"""A chat model that use Imagine Inference API"""
# max_concurrent_requests: int = 64 # later
model: str = "Llama-3.1-8B"
temperature: float = 0.0
max_tokens: Optional[int] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
streaming: bool = False
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
repetition_penalty: Optional[float] = None
stop: Optional[List[str]] = None
max_seconds: Optional[int] = None
ignore_eos: Optional[bool] = None
skip_special_tokens: Optional[bool] = None
stop_token_ids: Optional[List[List[int]]] = None
@property
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling the API."""
body = dict(
model=self.model,
stream=self.streaming,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
repetition_penalty=self.repetition_penalty,
stop=self.stop,
max_seconds=self.max_seconds,
ignore_eos=self.ignore_eos,
skip_special_tokens=self.skip_special_tokens,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
temperature=self.temperature,
top_k=self.top_k,
top_p=self.top_p,
)
body = {k: v for k, v in body.items() if v is not None}
return body
@property
def _client_params(self) -> dict[str, Any]:
"""Get the parameters used for the client."""
return self._default_params
@property
def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters."""
return self._default_params
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "imagine-chat"
@property
def lc_secrets(self) -> dict[str, str]:
return {"IMAGINE_API_KEY": "IMAGINE_API_KEY"}
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
@classmethod
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "imagine"]
def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict:
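        """Sum token usage across the per-generation ``llm_output`` dicts."""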
overall_token_usage: dict = {}
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
if token_usage := output.get("token_usage", None):
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
combined = {"token_usage": overall_token_usage, "model_name": self.model}
return combined
def _create_chat_result(self, response: ChatCompletionResponse) -> ChatResult:
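        """Build a ``ChatResult`` from a non-streaming chat completion response."""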
generations = []
for res in response.choices:
finish_reason = str(res.finish_reason)
message = _convert_imagine_message_to_lc_message(res.message)
if response.usage and isinstance(message, AIMessage):
message.usage_metadata = {
"input_tokens": response.usage.prompt_tokens or 0,
"output_tokens": response.usage.completion_tokens or 0,
"total_tokens": response.usage.total_tokens or 0,
}
gen = ChatGeneration(
message=message,
generation_info={"finish_reason": finish_reason},
)
generations.append(gen)
        llm_output = {
            "token_usage": response.usage.model_dump() if response.usage else {},
            "model": self.model,
        }
return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts(
self, messages: list[BaseMessage]
) -> tuple[list[dict], dict[str, Any]]:
params = self._client_params
message_dicts = [_convert_lc_message_to_dict_message(m) for m in messages]
return message_dicts, params
def _stream(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages)
params = {**params, **kwargs}
params.pop("stream", "")
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
for chunk in self.client.chat_stream(messages=message_dicts, **params):
if len(chunk.choices) == 0:
continue
new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
# make future chunks same type as first chunk
default_chunk_class = new_chunk.__class__
gen_chunk = ChatGenerationChunk(message=new_chunk)
if run_manager:
run_manager.on_llm_new_token(
token=cast(str, new_chunk.content), chunk=gen_chunk
)
yield gen_chunk
async def _astream(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages)
params = {**params, **kwargs}
params.pop("stream", "")
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
async for chunk in self.async_client.chat_stream(
messages=message_dicts, **params
):
if len(chunk.choices) == 0:
continue
new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
# make future chunks same type as first chunk
default_chunk_class = new_chunk.__class__
gen_chunk = ChatGenerationChunk(message=new_chunk)
if run_manager:
await run_manager.on_llm_new_token(
token=cast(str, new_chunk.content), chunk=gen_chunk
)
yield gen_chunk
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages)
params = {**params, **kwargs}
params.pop("stream", "")
        if self.streaming:
            stream_iter = self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **params
            )
            return generate_from_stream(stream_iter)
response = self.client.chat(messages=message_dicts, **params)
return self._create_chat_result(response)
async def _agenerate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
stream: bool | None = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages)
params = {**params, **kwargs}
params.pop("stream", "")
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages=messages, stop=stop, run_manager=run_manager, **params
)
return await agenerate_from_stream(stream_iter)
response = await self.async_client.chat(messages=message_dicts, **params)
return self._create_chat_result(response)
def bind_tools(
self,
tools: Sequence[Union[dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "any", "none"], bool]
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
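        """Bind tool definitions to this chat model.

        ``tools`` may be dicts, Pydantic models, callables, or ``BaseTool``
        instances; each is converted to the OpenAI tool schema before being
        bound. Imagine only supports automatic tool selection, so
        ``tool_choice`` values of ``"any"`` or ``"none"`` are rejected.

        Example (illustrative; ``get_weather`` is a hypothetical tool)::

            from langchain_core.tools import tool

            @tool
            def get_weather(city: str) -> str:
                'Return the current weather for a city.'
                return f"It is sunny in {city}."

            llm_with_tools = ImagineChat().bind_tools([get_weather])
            ai_msg = llm_with_tools.invoke("What is the weather in Paris?")
            print(ai_msg.tool_calls)
        """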
if tool_choice == "any" or tool_choice == "none":
raise ValueError("Imagine only supports tool_choice as 'auto'")
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and is_basemodel_subclass(obj)
# TODO: we don't support "json_mode" with structured output
# implement it in tool calling
def with_structured_output(
self,
schema: Optional[Union[dict, Type[BaseModel]]] = None,
*,
method: Literal["function_calling"] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
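        """Return a runnable that emits output matching ``schema``.

        Structured output is implemented via tool calling only
        (``method="function_calling"``); JSON mode is not supported.

        Example (illustrative; ``Joke`` is a hypothetical schema)::

            from pydantic import BaseModel

            class Joke(BaseModel):
                setup: str
                punchline: str

            structured_llm = ImagineChat().with_structured_output(Joke)
            joke = structured_llm.invoke("Tell me a joke about cats")
            print(joke.setup, joke.punchline)
        """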
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
if method == "function_calling":
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema],
first_tool_only=True,
)
else:
output_parser = JsonOutputKeyToolsParser(
key_name=tool_name, first_tool_only=True
)
else:
raise ValueError(
f"Unrecognized method argument. Expected one of 'function_calling' Received: '{method}'"
)
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser