diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py index b43a23c4..312f2535 100644 --- a/src/khoj/processor/conversation/gpt.py +++ b/src/khoj/processor/conversation/gpt.py @@ -236,12 +236,11 @@ A:{ "search-type": "notes" }""" return json.loads(story.strip(empty_escape_sequences)) -def converse(references, user_query, conversation_log={}, api_key=None, temperature=0.2): +def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo", api_key=None, temperature=0.2): """ Converse with user using OpenAI's ChatGPT """ # Initialize Variables - model = "gpt-3.5-turbo" compiled_references = "\n\n".join({f"# {item}" for item in references}) personality_primer = "You are Khoj, a friendly, smart and helpful personal assistant." diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 8ee6de13..6aa72b38 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -208,6 +208,7 @@ def chat(q: Optional[str] = None): # Initialize Variables api_key = state.processor_config.conversation.openai_api_key model = state.processor_config.conversation.model + chat_model = state.processor_config.conversation.chat_model user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") # Load Conversation History @@ -234,7 +235,7 @@ def chat(q: Optional[str] = None): try: with timer("Generating chat response took", logger): - gpt_response = converse(compiled_references, q, meta_log, api_key=api_key) + gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key) status = "ok" except Exception as e: gpt_response = str(e) diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 90691592..76baa14d 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -69,6 +69,7 @@ class ConversationProcessorConfigModel: def __init__(self, processor_config: ConversationProcessorConfig): self.openai_api_key = processor_config.openai_api_key self.model = processor_config.model + self.chat_model = processor_config.chat_model self.conversation_logfile = Path(processor_config.conversation_logfile) self.chat_session = "" self.meta_log: dict = {} diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index 6b87c220..51d64381 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -82,6 +82,7 @@ class ConversationProcessorConfig(ConfigBase): openai_api_key: str conversation_logfile: Path model: Optional[str] = "text-davinci-003" + chat_model: Optional[str] = "gpt-3.5-turbo" class ProcessorConfig(ConfigBase):