@ -15,9 +15,11 @@ class BaseChatAdpter:
def match ( self , model_path : str ) :
return True
def get_generate_stream_func ( self ) :
def get_generate_stream_func ( self , model_path : str ) :
""" Return the generate stream handler func """
pass
from pilot . model . inference import generate_stream
return generate_stream
def get_conv_template ( self , model_path : str ) - > Conversation :
return None
@ -105,10 +107,21 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
class VicunaChatAdapter ( BaseChatAdpter ) :
""" Model chat Adapter for vicuna """
def _is_llama2_based ( self , model_path : str ) :
# see https://huggingface.co/lmsys/vicuna-13b-v1.5
return " v1.5 " in model_path . lower ( )
def match ( self , model_path : str ) :
return " vicuna " in model_path
return " vicuna " in model_path . lower ( )
def get_generate_stream_func ( self ) :
def get_conv_template ( self , model_path : str ) - > Conversation :
if self . _is_llama2_based ( model_path ) :
return get_conv_template ( " vicuna_v1.1 " )
return None
def get_generate_stream_func ( self , model_path : str ) :
if self . _is_llama2_based ( model_path ) :
return super ( ) . get_generate_stream_func ( model_path )
return generate_stream
@ -118,7 +131,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):
def match ( self , model_path : str ) :
return " chatglm " in model_path
def get_generate_stream_func ( self ) :
def get_generate_stream_func ( self , model_path : str ) :
from pilot . model . llm_out . chatglm_llm import chatglm_generate_stream
return chatglm_generate_stream
@ -130,7 +143,7 @@ class CodeT5ChatAdapter(BaseChatAdpter):
def match ( self , model_path : str ) :
return " codet5 " in model_path
def get_generate_stream_func ( self ) :
def get_generate_stream_func ( self , model_path : str ) :
# TODO
pass
@ -141,7 +154,7 @@ class CodeGenChatAdapter(BaseChatAdpter):
def match ( self , model_path : str ) :
return " codegen " in model_path
def get_generate_stream_func ( self ) :
def get_generate_stream_func ( self , model_path : str ) :
# TODO
pass
@ -152,7 +165,7 @@ class GuanacoChatAdapter(BaseChatAdpter):
def match ( self , model_path : str ) :
return " guanaco " in model_path
def get_generate_stream_func ( self ) :
def get_generate_stream_func ( self , model_path : str ) :
from pilot . model . llm_out . guanaco_llm import guanaco_generate_stream
return guanaco_generate_stream
@ -164,7 +177,7 @@ class FalconChatAdapter(BaseChatAdpter):
def match ( self , model_path : str ) :
return " falcon " in model_path
def get_generate_stream_func ( self ) :
def get_generate_stream_func ( self , model_path : str ) :
from pilot . model . llm_out . falcon_llm import falcon_generate_output
return falcon_generate_output
@ -174,7 +187,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
def match ( self , model_path : str ) :
return " proxyllm " in model_path
def get_generate_stream_func ( self ) :
def get_generate_stream_func ( self , model_path : str ) :
from pilot . model . llm_out . proxy_llm import proxyllm_generate_stream
return proxyllm_generate_stream
@ -184,7 +197,7 @@ class GorillaChatAdapter(BaseChatAdpter):
def match ( self , model_path : str ) :
return " gorilla " in model_path
def get_generate_stream_func ( self ) :
def get_generate_stream_func ( self , model_path : str ) :
from pilot . model . llm_out . gorilla_llm import generate_stream
return generate_stream
@ -194,7 +207,7 @@ class GPT4AllChatAdapter(BaseChatAdpter):
def match ( self , model_path : str ) :
return " gpt4all " in model_path
def get_generate_stream_func ( self ) :
def get_generate_stream_func ( self , model_path : str ) :
from pilot . model . llm_out . gpt4all_llm import gpt4all_generate_stream
return gpt4all_generate_stream
@ -207,11 +220,6 @@ class Llama2ChatAdapter(BaseChatAdpter):
def get_conv_template ( self , model_path : str ) - > Conversation :
return get_conv_template ( " llama-2 " )
def get_generate_stream_func ( self ) :
from pilot . model . inference import generate_stream
return generate_stream
class BaichuanChatAdapter ( BaseChatAdpter ) :
def match ( self , model_path : str ) :
@ -222,10 +230,13 @@ class BaichuanChatAdapter(BaseChatAdpter):
return get_conv_template ( " baichuan-chat " )
return get_conv_template ( " zero_shot " )
def get_generate_stream_func ( self ) :
from pilot . model . inference import generate_stream
return generate_stream
class WizardLMChatAdapter ( BaseChatAdpter ) :
def match ( self , model_path : str ) :
return " wizardlm " in model_path . lower ( )
def get_conv_template ( self , model_path : str ) - > Conversation :
return get_conv_template ( " vicuna_v1.1 " )
register_llm_model_chat_adapter ( VicunaChatAdapter )
@ -236,6 +247,7 @@ register_llm_model_chat_adapter(GorillaChatAdapter)
register_llm_model_chat_adapter ( GPT4AllChatAdapter )
register_llm_model_chat_adapter ( Llama2ChatAdapter )
register_llm_model_chat_adapter ( BaichuanChatAdapter )
register_llm_model_chat_adapter ( WizardLMChatAdapter )
# Proxy model for test and develop, it's cheap for us now.
register_llm_model_chat_adapter ( ProxyllmChatAdapter )