Merge pull request #1002 from khoj-ai/features/improve-multiple-output-mode-generation

Improve handling of multiple output modes

- Use the generated descriptions / inferred queries to supply context to the model about what it's created and give a richer response
- Stop sending the generated image in user message. This seemed to be confusing the model more than helping.
- Collect generated assets in a structured objects to provide model context. This seems to help it follow instructions and separate responsibility better
- Also, rename the open ai converse method to converse_openai to follow patterns with other providers
This commit is contained in:
sabaimran 2024-12-10 17:06:19 -08:00 committed by GitHub
commit 01d000e570
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 111 additions and 87 deletions

View file

@ -157,10 +157,9 @@ def converse_anthropic(
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
generated_images: Optional[list[str]] = None,
generated_files: List[FileAttachment] = None,
generated_excalidraw_diagram: Optional[str] = None,
program_execution_context: Optional[List[str]] = None,
generated_asset_results: Dict[str, Dict] = {},
tracer: dict = {},
):
"""
@ -221,9 +220,8 @@ def converse_anthropic(
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
query_files=query_files,
generated_excalidraw_diagram=generated_excalidraw_diagram,
generated_files=generated_files,
generated_images=generated_images,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
)

View file

@ -167,9 +167,8 @@ def converse_gemini(
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
generated_images: Optional[list[str]] = None,
generated_files: List[FileAttachment] = None,
generated_excalidraw_diagram: Optional[str] = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = None,
tracer={},
):
@ -232,9 +231,8 @@ def converse_gemini(
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.GOOGLE,
query_files=query_files,
generated_excalidraw_diagram=generated_excalidraw_diagram,
generated_files=generated_files,
generated_images=generated_images,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
)

View file

@ -3,7 +3,7 @@ import logging
import os
from datetime import datetime, timedelta
from threading import Thread
from typing import Any, Iterator, List, Optional, Union
from typing import Any, Dict, Iterator, List, Optional, Union
import pyjson5
from langchain.schema import ChatMessage
@ -166,6 +166,7 @@ def converse_offline(
query_files: str = None,
generated_files: List[FileAttachment] = None,
additional_context: List[str] = None,
generated_asset_results: Dict[str, Dict] = {},
tracer: dict = {},
) -> Union[ThreadedGenerator, Iterator[str]]:
"""
@ -234,6 +235,7 @@ def converse_offline(
model_type=ChatModelOptions.ModelType.OFFLINE,
query_files=query_files,
generated_files=generated_files,
generated_asset_results=generated_asset_results,
program_execution_context=additional_context,
)

View file

@ -137,7 +137,7 @@ def send_message_to_model(
)
def converse(
def converse_openai(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
@ -157,9 +157,8 @@ def converse(
query_images: Optional[list[str]] = None,
vision_available: bool = False,
query_files: str = None,
generated_images: Optional[list[str]] = None,
generated_files: List[FileAttachment] = None,
generated_excalidraw_diagram: Optional[str] = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = None,
tracer: dict = {},
):
@ -223,9 +222,8 @@ def converse(
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.OPENAI,
query_files=query_files,
generated_excalidraw_diagram=generated_excalidraw_diagram,
generated_files=generated_files,
generated_images=generated_images,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
)
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")

View file

@ -178,40 +178,41 @@ Improved Prompt:
""".strip()
)
generated_image_attachment = PromptTemplate.from_template(
f"""
Here is the image you generated based on my query. You can follow-up with a general response to my query. Limit to 1-2 sentences.
generated_assets_context = PromptTemplate.from_template(
"""
Assets that you created have already been created to respond to the query. Below, there are references to the descriptions used to create the assets.
You can provide a summary of your reasoning from the information below or use it to respond to the original query.
Generated Assets:
{generated_assets}
Limit your response to 3 sentences max. Be succinct, clear, and informative.
""".strip()
)
generated_diagram_attachment = PromptTemplate.from_template(
f"""
I've successfully created a diagram based on the user's query. The diagram will automatically be shared with the user. I can follow-up with a general response or summary. Limit to 1-2 sentences.
""".strip()
)
## Diagram Generation
## --
improve_diagram_description_prompt = PromptTemplate.from_template(
"""
you are an architect working with a novice digital artist using a diagramming software.
You are an architect working with a novice digital artist using a diagramming software.
{personality_context}
you need to convert the user's query to a description format that the novice artist can use very well. you are allowed to use primitives like
- text
- rectangle
- ellipse
- line
- arrow
You need to convert the user's query to a description format that the novice artist can use very well. you are allowed to use primitives like
- Text
- Rectangle
- Ellipse
- Line
- Arrow
use these primitives to describe what sort of diagram the drawer should create. the artist must recreate the diagram every time, so include all relevant prior information in your description.
Use these primitives to describe what sort of diagram the drawer should create. The artist must recreate the diagram every time, so include all relevant prior information in your description.
- include the full, exact description. the artist does not have much experience, so be precise.
- describe the layout.
- you can only use straight lines.
- use simple, concise language.
- keep it simple and easy to understand. the artist is easily distracted.
- Include the full, exact description. the artist does not have much experience, so be precise.
- Describe the layout.
- You can only use straight lines.
- Use simple, concise language.
- Keep it simple and easy to understand. the artist is easily distracted.
Today's Date: {current_date}
User's Location: {location}
@ -337,6 +338,17 @@ Diagram Description: {query}
""".strip()
)
failed_diagram_generation = PromptTemplate.from_template(
"""
You attempted to programmatically generate a diagram but failed due to a system issue. You are normally able to generate diagrams, but you encountered a system issue this time.
You can create an ASCII image of the diagram in response instead.
This is the diagram you attempted to make:
{attempted_diagram}
""".strip()
)
## Online Search Conversation
## --
online_search_conversation = PromptTemplate.from_template(

View file

@ -40,6 +40,7 @@ from khoj.utils.helpers import (
merge_dicts,
)
from khoj.utils.rawconfig import FileAttachment
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@ -381,9 +382,8 @@ def generate_chatml_messages_with_context(
model_type="",
context_message="",
query_files: str = None,
generated_images: Optional[list[str]] = None,
generated_files: List[FileAttachment] = None,
generated_excalidraw_diagram: str = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = [],
):
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
@ -403,11 +403,15 @@ def generate_chatml_messages_with_context(
message_context = ""
message_attached_files = ""
generated_assets = {}
chat_message = chat.get("message")
role = "user" if chat["by"] == "you" else "assistant"
# Legacy code to handle excalidraw diagrams prior to Dec 2024
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
chat_message = chat["intent"].get("inferred-queries")[0]
if not is_none_or_empty(chat.get("context")):
references = "\n\n".join(
{
@ -434,15 +438,23 @@ def generate_chatml_messages_with_context(
reconstructed_context_message = ChatMessage(content=message_context, role="user")
chatml_messages.insert(0, reconstructed_context_message)
if chat.get("images") and role == "assistant":
# Issue: the assistant role cannot accept an image as a message content, so send it in a separate user message.
file_attachment_message = construct_structured_message(
message=prompts.generated_image_attachment.format(),
images=chat.get("images"),
model_type=model_type,
vision_enabled=vision_enabled,
if not is_none_or_empty(chat.get("images")) and role == "assistant":
generated_assets["image"] = {
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
}
if not is_none_or_empty(chat.get("excalidrawDiagram")) and role == "assistant":
generated_assets["diagram"] = {
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
}
if not is_none_or_empty(generated_assets):
chatml_messages.append(
ChatMessage(
content=f"{prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_assets))}\n",
role="user",
)
)
chatml_messages.append(ChatMessage(content=file_attachment_message, role="user"))
message_content = construct_structured_message(
chat_message, chat.get("images") if role == "user" else [], model_type, vision_enabled
@ -465,33 +477,22 @@ def generate_chatml_messages_with_context(
role="user",
)
)
if not is_none_or_empty(context_message):
messages.append(ChatMessage(content=context_message, role="user"))
if generated_images:
messages.append(
ChatMessage(
content=construct_structured_message(
prompts.generated_image_attachment.format(), generated_images, model_type, vision_enabled
),
role="user",
)
)
if generated_files:
message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
messages.append(ChatMessage(content=message_attached_files, role="assistant"))
if generated_excalidraw_diagram:
messages.append(ChatMessage(content=prompts.generated_diagram_attachment.format(), role="assistant"))
if not is_none_or_empty(generated_asset_results):
context_message += (
f"{prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results))}\n\n"
)
if program_execution_context:
messages.append(
ChatMessage(
content=prompts.additional_program_context.format(context="\n".join(program_execution_context)),
role="assistant",
)
)
program_context_text = "\n".join(program_execution_context)
context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"
if not is_none_or_empty(context_message):
messages.append(ChatMessage(content=context_message, role="user"))
if len(chatml_messages) > 0:
messages += chatml_messages

View file

@ -23,6 +23,7 @@ from khoj.database.adapters import (
aget_user_name,
)
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log
from khoj.processor.image.generate import text_to_image
@ -765,6 +766,7 @@ async def chat(
researched_results = ""
online_results: Dict = dict()
code_results: Dict = dict()
generated_asset_results: Dict = dict()
## Extract Document References
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
@ -1128,6 +1130,10 @@ async def chat(
else:
generated_images.append(generated_image)
generated_asset_results["images"] = {
"query": improved_image_prompt,
}
async for result in send_event(
ChatEvent.GENERATED_ASSETS,
{
@ -1166,6 +1172,10 @@ async def chat(
generated_excalidraw_diagram = diagram_description
generated_asset_results["diagrams"] = {
"query": better_diagram_description_prompt,
}
async for result in send_event(
ChatEvent.GENERATED_ASSETS,
{
@ -1176,7 +1186,9 @@ async def chat(
else:
error_message = "Failed to generate diagram. Please try again later."
program_execution_context.append(
f"AI attempted to programmatically generate a diagram but failed due to a program issue. Generally, it is able to do so, but encountered a system issue this time. AI can suggest text description or rendering of the diagram or user can try again with a simpler prompt."
prompts.failed_diagram_generation.format(
attempted_diagram=better_diagram_description_prompt
)
)
async for result in send_event(ChatEvent.STATUS, error_message):
@ -1209,6 +1221,7 @@ async def chat(
generated_files,
generated_excalidraw_diagram,
program_execution_context,
generated_asset_results,
tracer,
)

View file

@ -88,7 +88,10 @@ from khoj.processor.conversation.offline.chat_model import (
converse_offline,
send_message_to_model_offline,
)
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
from khoj.processor.conversation.openai.gpt import (
converse_openai,
send_message_to_model,
)
from khoj.processor.conversation.utils import (
ChatEvent,
ThreadedGenerator,
@ -751,7 +754,7 @@ async def generate_excalidraw_diagram(
)
except Exception as e:
logger.error(f"Error generating Excalidraw diagram for {user.email}: {e}", exc_info=True)
yield None, None
yield better_diagram_description_prompt, None
return
scratchpad = excalidraw_diagram_description.get("scratchpad")
@ -1189,6 +1192,7 @@ def generate_chat_response(
raw_generated_files: List[FileAttachment] = [],
generated_excalidraw_diagram: str = None,
program_execution_context: List[str] = [],
generated_asset_results: Dict[str, Dict] = {},
tracer: dict = {},
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
@ -1251,6 +1255,7 @@ def generate_chat_response(
agent=agent,
query_files=query_files,
generated_files=raw_generated_files,
generated_asset_results=generated_asset_results,
tracer=tracer,
)
@ -1258,7 +1263,7 @@ def generate_chat_response(
openai_chat_config = conversation_config.ai_model_api
api_key = openai_chat_config.api_key
chat_model = conversation_config.chat_model
chat_response = converse(
chat_response = converse_openai(
compiled_references,
query_to_run,
query_images=query_images,
@ -1278,8 +1283,7 @@ def generate_chat_response(
vision_available=vision_available,
query_files=query_files,
generated_files=raw_generated_files,
generated_images=generated_images,
generated_excalidraw_diagram=generated_excalidraw_diagram,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
tracer=tracer,
)
@ -1305,8 +1309,7 @@ def generate_chat_response(
vision_available=vision_available,
query_files=query_files,
generated_files=raw_generated_files,
generated_images=generated_images,
generated_excalidraw_diagram=generated_excalidraw_diagram,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
tracer=tracer,
)
@ -1331,8 +1334,7 @@ def generate_chat_response(
vision_available=vision_available,
query_files=query_files,
generated_files=raw_generated_files,
generated_images=generated_images,
generated_excalidraw_diagram=generated_excalidraw_diagram,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
tracer=tracer,
)

View file

@ -4,7 +4,7 @@ import freezegun
import pytest
from freezegun import freeze_time
from khoj.processor.conversation.openai.gpt import converse, extract_questions
from khoj.processor.conversation.openai.gpt import converse_openai, extract_questions
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import (
aget_data_sources_and_output_format,
@ -158,7 +158,7 @@ def test_generate_search_query_using_question_and_answer_from_chat_history():
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content():
# Act
response_gen = converse(
response_gen = converse_openai(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Hello, my name is Testatron. Who are you?",
api_key=api_key,
@ -183,7 +183,7 @@ def test_answer_from_chat_history_and_no_content():
]
# Act
response_gen = converse(
response_gen = converse_openai(
references=[], # Assume no context retrieved from notes for the user_query
user_query="What is my name?",
conversation_log=populate_chat_history(message_list),
@ -214,7 +214,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content():
]
# Act
response_gen = converse(
response_gen = converse_openai(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=populate_chat_history(message_list),
@ -239,7 +239,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content():
]
# Act
response_gen = converse(
response_gen = converse_openai(
references=[
{"compiled": "Testatron was born on 1st April 1984 in Testville.", "file": "background.md"}
], # Assume context retrieved from notes for the user_query
@ -265,7 +265,7 @@ def test_refuse_answering_unanswerable_question():
]
# Act
response_gen = converse(
response_gen = converse_openai(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=populate_chat_history(message_list),
@ -318,7 +318,7 @@ Expenses:Food:Dining 10.00 USD""",
]
# Act
response_gen = converse(
response_gen = converse_openai(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I have for Dinner today?",
api_key=api_key,
@ -362,7 +362,7 @@ Expenses:Food:Dining 10.00 USD""",
]
# Act
response_gen = converse(
response_gen = converse_openai(
references=context, # Assume context retrieved from notes for the user_query
user_query="How much did I spend on dining this year?",
api_key=api_key,
@ -386,7 +386,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content():
]
# Act
response_gen = converse(
response_gen = converse_openai(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Write a haiku about unit testing in 3 lines. Do not say anything else",
conversation_log=populate_chat_history(message_list),
@ -426,7 +426,7 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
]
# Act
response_gen = converse(
response_gen = converse_openai(
references=context, # Assume context retrieved from notes for the user_query
user_query="How many kids does my older sister have?",
api_key=api_key,
@ -459,13 +459,13 @@ def test_agent_prompt_should_be_used(openai_agent):
expected_responses = ["9.50", "9.5"]
# Act
response_gen = converse(
response_gen = converse_openai(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I buy?",
api_key=api_key,
)
no_agent_response = "".join([response_chunk for response_chunk in response_gen])
response_gen = converse(
response_gen = converse_openai(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I buy?",
api_key=api_key,