Google models
The source code is located here.
import re
import time
import json
from aios.llm_core.cores.base import BaseLLM
from aios.utils import get_from_env
from cerebrum.llm.communication import Response
class GeminiLLM(BaseLLM):
    """LLM core that serves requests through a Google Gemini model.

    The Gemini client is created in ``load_llm_and_tokenizer`` and requests
    are answered in ``address_syscall``.
    """

    def __init__(
        self,
        llm_name: str,
        max_gpu_memory: dict = None,
        eval_device: str = None,
        max_new_tokens: int = 256,
        log_mode: str = "console",
        use_context_manager: bool = False,
    ):
        """Forward every argument, unchanged and in order, to ``BaseLLM``."""
        super().__init__(
            llm_name,
            max_gpu_memory,
            eval_device,
            max_new_tokens,
            log_mode,
            use_context_manager,
        )

    def load_llm_and_tokenizer(self) -> None:
        """Create the Gemini client and store it on ``self.model``.

        ``google.generativeai`` is imported here rather than at module level
        because the package is only needed for this backend. The API key is
        read via ``get_from_env("GEMINI_API_KEY")``. ``self.tokenizer`` is
        set to ``None``.

        Raises:
            AssertionError: if ``self.model_name`` does not contain "gemini"
                (case-insensitive).
            ImportError: if ``google.generativeai`` cannot be imported.
        """
        assert re.search(r"gemini", self.model_name, re.IGNORECASE)

        # Only the import itself is guarded, so an ImportError raised by the
        # key lookup or client setup is not misreported as a missing package.
        try:
            import google.generativeai as genai
        except ImportError as err:
            raise ImportError(
                "Could not import google.generativeai python package. "
                "Please install it with `pip install google-generativeai`."
            ) from err

        gemini_api_key = get_from_env("GEMINI_API_KEY")
        genai.configure(api_key=gemini_api_key)
        self.model = genai.GenerativeModel(self.model_name)
        self.tokenizer = None

    def convert_messages(self, messages):
        """Map chat messages onto Gemini's role/parts layout.

        Messages whose role is "user" or "system" get the role "user"; every
        other role becomes "model". The message content is placed under
        ``parts["text"]``.

        Args:
            messages: sequence of dicts with "role" and "content" keys.

        Returns:
            A list of converted message dicts, or ``None`` when ``messages``
            is empty or ``None``.
        """
        if not messages:
            return None

        return [
            {
                "role": "user" if m["role"] in ("user", "system") else "model",
                "parts": {"text": m["content"]},
            }
            for m in messages
        ]

    def address_syscall(self, llm_syscall, temperature=0.0) -> Response:
        """Execute one LLM syscall against Gemini and wrap the answer.

        Marks the syscall as executing, records its start time, converts the
        query messages (after tool formatting when tools are present) and
        sends them to the model.

        Args:
            llm_syscall: the syscall; its ``query`` supplies ``messages``,
                ``tools`` and ``message_return_type``.
            temperature: sampling temperature forwarded to the model.

        Returns:
            A finished ``Response`` carrying either the parsed tool calls or
            the response message (JSON-parsed when the query asks for "json").

        Raises:
            IndexError: if no result could be extracted from the model output.
        """
        llm_syscall.set_status("executing")
        llm_syscall.set_start_time(time.time())

        query = llm_syscall.query
        messages = query.messages
        tools = query.tools
        message_return_type = query.message_return_type

        if tools:
            messages = self.tool_calling_input_format(messages, tools)

        # Convert roles to the ones Gemini accepts.
        messages = self.convert_messages(messages=messages)

        self.logger.log(
            f"{llm_syscall.agent_name} is switched to executing.\n", level="executing"
        )

        # NOTE(review): the conversation is sent as one JSON-encoded string
        # prompt rather than as structured contents; kept as-is so the request
        # shape does not change. Confirm whether this is intended.
        outputs = self.model.generate_content(
            json.dumps({"contents": messages}),
            # Previously the caller's temperature was accepted but never used.
            generation_config={"temperature": temperature},
        )

        try:
            result = outputs.candidates[0].content.parts[0].text

            if tools:
                tool_calls = self.parse_tool_calls(result)
                if tool_calls:
                    response = Response(
                        response_message=None,
                        tool_calls=tool_calls,
                        finished=True,
                    )
                else:
                    # No tool call could be parsed: return the raw text.
                    response = Response(
                        response_message=result,
                        finished=True,
                    )
            else:
                if message_return_type == "json":
                    result = self.parse_json_format(result)
                response = Response(
                    response_message=result,
                    finished=True,
                )
        except IndexError as err:
            raise IndexError(
                f"{self.model_name} can not generate a valid result, please try again"
            ) from err

        return response
Last updated