Support OpenAI's new O1 Model Series (#912)

- Major
  - The new O1 series doesn't currently seem to support streaming, response_format enforcement, stop words or temperature (see the sketch after this list for what that means for callers).
  - Remove any markdown json codeblock in chat actors expecting json responses

- Minor
  - Override block display styling of links by Katex in chat messages
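A minimal sketch of what those beta limitations mean for a caller, assuming the official openai v1 Python SDK; this is illustrative only and not code from this commit, and the API key, model and prompt are placeholders.

# Sketch: calling an o1-series model under the current beta limitations,
# using the openai v1 Python SDK. Key, model and prompt are placeholders.
import openai

client = openai.OpenAI(api_key="sk-...")  # placeholder API key

chat = client.chat.completions.create(
    model="o1-preview",
    # o1 models reject the "system" role right now, so instructions go in a user turn
    messages=[{"role": "user", "content": "You are Khoj. Summarize my notes on gardening."}],
    temperature=1,  # only the default temperature is accepted
    stream=False,   # streaming is not supported yet
    # stop words and response_format are also unsupported, so they are omitted
)

print(chat.choices[0].message.content)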
commit 26ca3df605
Author: Debanjum
Date:   2024-09-12 18:42:51 -07:00 (committed via GitHub)
5 changed files with 63 additions and 21 deletions


@@ -18,6 +18,11 @@ div.chatMessageWrapper p:not(:last-child) {
     margin-bottom: 16px;
 }
 
+/* Override some link styling by Katex to improve rendering */
+div.chatMessageWrapper a span {
+    display: revert !important;
+}
+
 div.khojfullHistory {
     border-width: 1px;
     padding-left: 4px;


@@ -14,6 +14,7 @@ from khoj.processor.conversation.openai.utils import (
 from khoj.processor.conversation.utils import (
     construct_structured_message,
     generate_chatml_messages_with_context,
+    remove_json_codeblock,
 )
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
@@ -85,6 +86,7 @@ def extract_questions(
     # Extract, Clean Message from GPT's Response
     try:
         response = response.strip()
+        response = remove_json_codeblock(response)
         response = json.loads(response)
         response = [q.strip() for q in response["queries"] if q.strip()]
         if not isinstance(response, list) or not response:


@@ -45,15 +45,28 @@ def completion_with_backoff(
         openai_clients[client_key] = client
 
     formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
 
+    stream = True
+
+    # Update request parameters for compatibility with o1 model series
+    # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
+    if model.startswith("o1"):
+        stream = False
+        temperature = 1
+        model_kwargs.pop("stop", None)
+        model_kwargs.pop("response_format", None)
+
     chat = client.chat.completions.create(
-        stream=True,
+        stream=stream,
         messages=formatted_messages,  # type: ignore
         model=model,  # type: ignore
         temperature=temperature,
         timeout=20,
         **(model_kwargs or dict()),
     )
 
+    if not stream:
+        return chat.choices[0].message.content
+
     aggregated_response = ""
     for chunk in chat:
         if len(chunk.choices) == 0:
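For context (not part of the commit): with the openai v1 Python SDK, the return type of chat.completions.create depends on the stream flag, which is why the non-streaming branch above can return the full reply directly. A small sketch, reusing the client, model and formatted_messages names from the function above:

# stream=False returns a single ChatCompletion object with the whole reply
response = client.chat.completions.create(model=model, messages=formatted_messages, stream=False)
full_text = response.choices[0].message.content

# stream=True returns an iterable of ChatCompletionChunk objects instead
for chunk in client.chat.completions.create(model=model, messages=formatted_messages, stream=True):
    if chunk.choices and chunk.choices[0].delta.content:
        piece = chunk.choices[0].delta.content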
@@ -112,9 +125,18 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
             client: openai.OpenAI = openai_clients[client_key]
 
         formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
 
+        stream = True
+
+        # Update request parameters for compatibility with o1 model series
+        # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
+        if model_name.startswith("o1"):
+            stream = False
+            temperature = 1
+            model_kwargs.pop("stop", None)
+            model_kwargs.pop("response_format", None)
+
         chat = client.chat.completions.create(
-            stream=True,
+            stream=stream,
             messages=formatted_messages,
             model=model_name,  # type: ignore
             temperature=temperature,
@@ -122,14 +144,17 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
             **(model_kwargs or dict()),
         )
 
-        for chunk in chat:
-            if len(chunk.choices) == 0:
-                continue
-            delta_chunk = chunk.choices[0].delta
-            if isinstance(delta_chunk, str):
-                g.send(delta_chunk)
-            elif delta_chunk.content:
-                g.send(delta_chunk.content)
+        if not stream:
+            g.send(chat.choices[0].message.content)
+        else:
+            for chunk in chat:
+                if len(chunk.choices) == 0:
+                    continue
+                delta_chunk = chunk.choices[0].delta
+                if isinstance(delta_chunk, str):
+                    g.send(delta_chunk)
+                elif delta_chunk.content:
+                    g.send(delta_chunk.content)
     except Exception as e:
         logger.error(f"Error in llm_thread: {e}", exc_info=True)
     finally:
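ThreadedGenerator itself comes from khoj.processor.conversation.utils and its implementation is not part of this diff. To make the g.send(...) calls above easier to follow, here is a hedged sketch of the queue-backed pattern such a generator typically follows; the class name and details are illustrative, not the project's actual class.

# Hedged sketch of a queue-backed generator: a worker thread pushes chunks with
# send(...) while a consumer iterates over the object to stream them out.
import queue
import threading


class QueueBackedGenerator:
    def __init__(self):
        self._queue: queue.Queue = queue.Queue()

    def send(self, data):
        # Producer thread pushes chunks as they arrive from the LLM
        if data:
            self._queue.put(data)

    def close(self):
        # Sentinel tells the consumer that the stream is finished
        self._queue.put(StopIteration)

    def __iter__(self):
        return self

    def __next__(self):
        item = self._queue.get()
        if item is StopIteration:
            raise StopIteration
        return item


# Usage: a producer thread sends chunks; the caller iterates to stream them.
g = QueueBackedGenerator()
threading.Thread(target=lambda: ([g.send(c) for c in ("Hello", ", ", "world")], g.close())).start()
print("".join(g))  # -> "Hello, world"

This is why llm_thread can push either one complete reply (non-streaming o1 path) or many small deltas (streaming path) through the same g.send interface.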


@@ -1,4 +1,3 @@
-import json
 import logging
 import math
 import queue
@@ -24,6 +23,8 @@ model_to_prompt_size = {
     "gpt-4-0125-preview": 20000,
     "gpt-4-turbo-preview": 20000,
     "gpt-4o-mini": 20000,
+    "o1-preview": 20000,
+    "o1-mini": 20000,
     "TheBloke/Mistral-7B-Instruct-v0.2-GGUF": 3500,
     "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF": 3500,
     "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
@@ -220,8 +221,9 @@ def truncate_messages(
     try:
         if loaded_model:
            encoder = loaded_model.tokenizer()
-        elif model_name.startswith("gpt-"):
-            encoder = tiktoken.encoding_for_model(model_name)
+        elif model_name.startswith("gpt-") or model_name.startswith("o1"):
+            # as tiktoken doesn't recognize o1 model series yet
+            encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
         elif tokenizer_name:
             if tokenizer_name in state.pretrained_tokenizers:
                 encoder = state.pretrained_tokenizers[tokenizer_name]
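A quick illustration of the encoder fallback above, assuming a tiktoken release recent enough to know "gpt-4o"; the o1 models themselves are not in tiktoken's model registry at the time of this commit, hence the substitution. The prompt string is made up.

import tiktoken

model_name = "o1-mini"
encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
token_count = len(encoder.encode("How many tokens is this prompt?"))
print(encoder.name, token_count)  # e.g. "o200k_base" and a small token count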
@@ -278,10 +280,19 @@ def truncate_messages(
     )
 
     if system_message:
-        system_message.role = "user" if "gemma-2" in model_name else "system"
+        # Default system message role is system.
+        # Fallback to system message role of user for models that do not support this role like gemma-2 and openai's o1 model series.
+        system_message.role = "user" if "gemma-2" in model_name or model_name.startswith("o1") else "system"
     return messages + [system_message] if system_message else messages
 
 
 def reciprocal_conversation_to_chatml(message_pair):
     """Convert a single back and forth between user and assistant to chatml format"""
     return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
+
+
+def remove_json_codeblock(response):
+    """Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
+    if response.startswith("```json") and response.endswith("```"):
+        response = response[7:-3]
+    return response
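An illustrative round trip through the new helper, using a made-up model response; the import mirrors how the chat actors below pull the helper in. Models without schema enforcement (gemma-2, the o1 series) often wrap JSON in a markdown fence, which the helper strips so json.loads succeeds.

import json

from khoj.processor.conversation.utils import remove_json_codeblock

raw_response = '```json\n{"queries": ["notes on gardening", "tomato planting schedule"]}```'
cleaned = remove_json_codeblock(raw_response.strip())
queries = json.loads(cleaned)["queries"]
print(queries)  # ['notes on gardening', 'tomato planting schedule']

# Responses without a fence pass through unchanged
assert remove_json_codeblock('{"queries": []}') == '{"queries": []}'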


@@ -88,6 +88,7 @@ from khoj.processor.conversation.openai.gpt import converse, send_message_to_mod
 from khoj.processor.conversation.utils import (
     ThreadedGenerator,
     generate_chatml_messages_with_context,
+    remove_json_codeblock,
     save_to_conversation_log,
 )
 from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
@@ -298,9 +299,7 @@ async def aget_relevant_information_sources(
     try:
         response = response.strip()
-        # Remove any markdown json codeblock formatting if present (useful for gemma-2)
-        if response.startswith("```json"):
-            response = response[7:-3]
+        response = remove_json_codeblock(response)
         response = json.loads(response)
         response = [q.strip() for q in response["source"] if q.strip()]
         if not isinstance(response, list) or not response or len(response) == 0:
@@ -353,7 +352,9 @@ async def aget_relevant_output_modes(
     response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object")
 
     try:
-        response = json.loads(response.strip())
+        response = response.strip()
+        response = remove_json_codeblock(response)
+        response = json.loads(response)
 
         if is_none_or_empty(response):
             return ConversationCommand.Text
@@ -433,9 +434,7 @@ async def generate_online_subqueries(
     # Validate that the response is a non-empty, JSON-serializable list
     try:
         response = response.strip()
-        # Remove any markdown json codeblock formatting if present (useful for gemma-2)
-        if response.startswith("```json") and response.endswith("```"):
-            response = response[7:-3]
+        response = remove_json_codeblock(response)
         response = json.loads(response)
         response = [q.strip() for q in response["queries"] if q.strip()]
         if not isinstance(response, list) or not response or len(response) == 0: