LLM Core(s)
To implement the LLM core(s), we design the BaseLLM class shown below, which defines the shared functionality that every LLM backend relies on.
```python
from abc import ABC, abstractmethod
import json
import re

# SimpleContextManager, LLMKernelLogger, and generator_tool_call_id are
# provided by other AIOS modules.


class BaseLLM(ABC):
    def __init__(
        self,
        llm_name: str,
        max_gpu_memory: dict = None,
        eval_device: str = None,
        max_new_tokens: int = 256,
        log_mode: str = "console",
        use_context_manager: bool = False,
    ):
        self.max_gpu_memory = max_gpu_memory
        self.eval_device = eval_device
        self.max_new_tokens = max_new_tokens
        self.log_mode = log_mode
        self.model_name = llm_name
        self.use_context_manager = use_context_manager
        if use_context_manager:
            self.context_manager = SimpleContextManager()
        self.load_llm_and_tokenizer()
        self.logger = self.setup_logger()
        self.logger.log("AIOS has been successfully initialized.\n", level="info")
    def convert_map(self, original_map: dict) -> dict:
        """Helper utility to convert the keys of a map to int."""
        if original_map:
            new_map = {}
            for k, v in original_map.items():
                new_map[int(k)] = v
            return new_map
        return None
    def check_model_type(self, model_name):
        # TODO add more model types; every model is currently treated as a
        # causal LM
        return "causal_lm"

    def setup_logger(self):
        logger = LLMKernelLogger(self.model_name, self.log_mode)
        return logger
    @abstractmethod
    def load_llm_and_tokenizer(self) -> None:
        """Load the model and tokenizer for each type of LLM."""
        return
    # only used for open-source LLMs that lack native tool-calling support
    def tool_calling_input_format(self, messages: list, tools: list) -> list:
        """Integrate tool information into the messages for open-source LLMs.

        Args:
            messages (list): messages with different roles
            tools (list): tool information
        """
        prefix_prompt = (
            "In and only in current step, you need to call tools. Available tools are: "
        )
        tool_prompt = json.dumps(tools)
        suffix_prompt = "".join(
            [
                "Must call functions that are available. To call a function, respond "
                "immediately and only with a list of JSON object of the following format:"
                '{[{"name":"function_name_value","parameters":{"parameter_name1":"parameter_value1",'
                '"parameter_name2":"parameter_value2"}}]}'
            ]
        )
        # translate tool-call messages for models that don't support tool calling
        for message in messages:
            if "tool_calls" in message:
                message["content"] = json.dumps(message.pop("tool_calls"))
            elif message["role"] == "tool":
                message["role"] = "user"
                tool_call_id = message.pop("tool_call_id")
                content = message.pop("content")
                message["content"] = (
                    f"The result of the execution of function(id :{tool_call_id}) is: {content}. "
                )
        messages[-1]["content"] += prefix_prompt + tool_prompt + suffix_prompt
        return messages
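    # Illustrative example (names hypothetical): given
    #   messages = [{"role": "user", "content": "Weather in Berlin?"}]
    #   tools = [{"name": "get_weather", "parameters": {...}}]
    # the last message's content is extended with the prefix prompt, the
    # JSON-serialized tool list, and the suffix prompt, so a plain chat model
    # can emit tool calls as ordinary JSON text in its reply.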
    def parse_json_format(self, message: str) -> str:
        json_array_pattern = r"\[\s*\{.*?\}\s*\]"
        json_object_pattern = r"\{\s*.*?\s*\}"

        # prefer a JSON array of tool calls if one can be extracted
        match_array = re.search(json_array_pattern, message)
        if match_array:
            json_array_substring = match_array.group(0)
            try:
                json_array_data = json.loads(json_array_substring)
                return json.dumps(json_array_data)
            except json.JSONDecodeError:
                pass

        # fall back to a single JSON object
        match_object = re.search(json_object_pattern, message)
        if match_object:
            json_object_substring = match_object.group(0)
            try:
                json_object_data = json.loads(json_object_substring)
                return json.dumps(json_object_data)
            except json.JSONDecodeError:
                pass
        return "[]"
    def parse_tool_calls(self, message):
        # attach a tool call id and type for models that don't support
        # native tool calls
        tool_calls = json.loads(self.parse_json_format(message))
        for tool_call in tool_calls:
            tool_call["id"] = generator_tool_call_id()
            tool_call["type"] = "function"
        return tool_calls
    @abstractmethod
    def address_syscall(self, llm_syscall, temperature=0.0):
        raise NotImplementedError
```
Different instances are implemented by inheriting from this BaseLLM class and overriding its abstract methods, load_llm_and_tokenizer and address_syscall.
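As a minimal sketch of such an instance, the subclass below backs BaseLLM with a HuggingFace causal LM. It is illustrative only: the class name, the assumption that llm_syscall carries OpenAI-style query.messages and query.tools, and the generation settings are ours, not taken from the AIOS source.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer


class HuggingFaceLLM(BaseLLM):  # hypothetical subclass name
    def load_llm_and_tokenizer(self) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map="auto",
            # convert_map turns JSON string keys into the int device ids
            # expected by the transformers loader
            max_memory=self.convert_map(self.max_gpu_memory),
        )

    def address_syscall(self, llm_syscall, temperature=0.0):
        # assumption: the syscall exposes OpenAI-style messages and tools
        messages = llm_syscall.query.messages
        tools = llm_syscall.query.tools
        if tools:
            messages = self.tool_calling_input_format(messages, tools)

        input_ids = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)

        gen_kwargs = {"max_new_tokens": self.max_new_tokens}
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature)
        output_ids = self.model.generate(input_ids, **gen_kwargs)

        # decode only the newly generated tokens
        response = self.tokenizer.decode(
            output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
        )
        return self.parse_tool_calls(response) if tools else response
```

Routing tools through tool_calling_input_format and parse_tool_calls lets the same syscall-facing interface serve both API models with native tool calling and open-source models without it.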