mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-12-22 12:18:09 +00:00
Merge pull request #1002 from khoj-ai/features/improve-multiple-output-mode-generation
Improve handling of multiple output modes - Use the generated descriptions / inferred queries to supply context to the model about what it's created and give a richer response - Stop sending the generated image in user message. This seemed to be confusing the model more than helping. - Collect generated assets in a structured objects to provide model context. This seems to help it follow instructions and separate responsibility better - Also, rename the open ai converse method to converse_openai to follow patterns with other providers
This commit is contained in:
commit
01d000e570
9 changed files with 111 additions and 87 deletions
src/khoj
processor/conversation
routers
tests
|
@ -157,10 +157,9 @@ def converse_anthropic(
|
||||||
query_images: Optional[list[str]] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_available: bool = False,
|
vision_available: bool = False,
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
generated_images: Optional[list[str]] = None,
|
|
||||||
generated_files: List[FileAttachment] = None,
|
generated_files: List[FileAttachment] = None,
|
||||||
generated_excalidraw_diagram: Optional[str] = None,
|
|
||||||
program_execution_context: Optional[List[str]] = None,
|
program_execution_context: Optional[List[str]] = None,
|
||||||
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -221,9 +220,8 @@ def converse_anthropic(
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=ChatModelOptions.ModelType.ANTHROPIC,
|
model_type=ChatModelOptions.ModelType.ANTHROPIC,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
generated_excalidraw_diagram=generated_excalidraw_diagram,
|
|
||||||
generated_files=generated_files,
|
generated_files=generated_files,
|
||||||
generated_images=generated_images,
|
generated_asset_results=generated_asset_results,
|
||||||
program_execution_context=program_execution_context,
|
program_execution_context=program_execution_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -167,9 +167,8 @@ def converse_gemini(
|
||||||
query_images: Optional[list[str]] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_available: bool = False,
|
vision_available: bool = False,
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
generated_images: Optional[list[str]] = None,
|
|
||||||
generated_files: List[FileAttachment] = None,
|
generated_files: List[FileAttachment] = None,
|
||||||
generated_excalidraw_diagram: Optional[str] = None,
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
program_execution_context: List[str] = None,
|
program_execution_context: List[str] = None,
|
||||||
tracer={},
|
tracer={},
|
||||||
):
|
):
|
||||||
|
@ -232,9 +231,8 @@ def converse_gemini(
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=ChatModelOptions.ModelType.GOOGLE,
|
model_type=ChatModelOptions.ModelType.GOOGLE,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
generated_excalidraw_diagram=generated_excalidraw_diagram,
|
|
||||||
generated_files=generated_files,
|
generated_files=generated_files,
|
||||||
generated_images=generated_images,
|
generated_asset_results=generated_asset_results,
|
||||||
program_execution_context=program_execution_context,
|
program_execution_context=program_execution_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Any, Iterator, List, Optional, Union
|
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||||
|
|
||||||
import pyjson5
|
import pyjson5
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
|
@ -166,6 +166,7 @@ def converse_offline(
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
generated_files: List[FileAttachment] = None,
|
generated_files: List[FileAttachment] = None,
|
||||||
additional_context: List[str] = None,
|
additional_context: List[str] = None,
|
||||||
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||||
"""
|
"""
|
||||||
|
@ -234,6 +235,7 @@ def converse_offline(
|
||||||
model_type=ChatModelOptions.ModelType.OFFLINE,
|
model_type=ChatModelOptions.ModelType.OFFLINE,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
generated_files=generated_files,
|
generated_files=generated_files,
|
||||||
|
generated_asset_results=generated_asset_results,
|
||||||
program_execution_context=additional_context,
|
program_execution_context=additional_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -137,7 +137,7 @@ def send_message_to_model(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def converse(
|
def converse_openai(
|
||||||
references,
|
references,
|
||||||
user_query,
|
user_query,
|
||||||
online_results: Optional[Dict[str, Dict]] = None,
|
online_results: Optional[Dict[str, Dict]] = None,
|
||||||
|
@ -157,9 +157,8 @@ def converse(
|
||||||
query_images: Optional[list[str]] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_available: bool = False,
|
vision_available: bool = False,
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
generated_images: Optional[list[str]] = None,
|
|
||||||
generated_files: List[FileAttachment] = None,
|
generated_files: List[FileAttachment] = None,
|
||||||
generated_excalidraw_diagram: Optional[str] = None,
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
program_execution_context: List[str] = None,
|
program_execution_context: List[str] = None,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
|
@ -223,9 +222,8 @@ def converse(
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
generated_excalidraw_diagram=generated_excalidraw_diagram,
|
|
||||||
generated_files=generated_files,
|
generated_files=generated_files,
|
||||||
generated_images=generated_images,
|
generated_asset_results=generated_asset_results,
|
||||||
program_execution_context=program_execution_context,
|
program_execution_context=program_execution_context,
|
||||||
)
|
)
|
||||||
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
||||||
|
|
|
@ -178,40 +178,41 @@ Improved Prompt:
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
generated_image_attachment = PromptTemplate.from_template(
|
generated_assets_context = PromptTemplate.from_template(
|
||||||
f"""
|
"""
|
||||||
Here is the image you generated based on my query. You can follow-up with a general response to my query. Limit to 1-2 sentences.
|
Assets that you created have already been created to respond to the query. Below, there are references to the descriptions used to create the assets.
|
||||||
|
You can provide a summary of your reasoning from the information below or use it to respond to the original query.
|
||||||
|
|
||||||
|
Generated Assets:
|
||||||
|
{generated_assets}
|
||||||
|
|
||||||
|
Limit your response to 3 sentences max. Be succinct, clear, and informative.
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
generated_diagram_attachment = PromptTemplate.from_template(
|
|
||||||
f"""
|
|
||||||
I've successfully created a diagram based on the user's query. The diagram will automatically be shared with the user. I can follow-up with a general response or summary. Limit to 1-2 sentences.
|
|
||||||
""".strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
## Diagram Generation
|
## Diagram Generation
|
||||||
## --
|
## --
|
||||||
|
|
||||||
improve_diagram_description_prompt = PromptTemplate.from_template(
|
improve_diagram_description_prompt = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
you are an architect working with a novice digital artist using a diagramming software.
|
You are an architect working with a novice digital artist using a diagramming software.
|
||||||
{personality_context}
|
{personality_context}
|
||||||
|
|
||||||
you need to convert the user's query to a description format that the novice artist can use very well. you are allowed to use primitives like
|
You need to convert the user's query to a description format that the novice artist can use very well. you are allowed to use primitives like
|
||||||
- text
|
- Text
|
||||||
- rectangle
|
- Rectangle
|
||||||
- ellipse
|
- Ellipse
|
||||||
- line
|
- Line
|
||||||
- arrow
|
- Arrow
|
||||||
|
|
||||||
use these primitives to describe what sort of diagram the drawer should create. the artist must recreate the diagram every time, so include all relevant prior information in your description.
|
Use these primitives to describe what sort of diagram the drawer should create. The artist must recreate the diagram every time, so include all relevant prior information in your description.
|
||||||
|
|
||||||
- include the full, exact description. the artist does not have much experience, so be precise.
|
- Include the full, exact description. the artist does not have much experience, so be precise.
|
||||||
- describe the layout.
|
- Describe the layout.
|
||||||
- you can only use straight lines.
|
- You can only use straight lines.
|
||||||
- use simple, concise language.
|
- Use simple, concise language.
|
||||||
- keep it simple and easy to understand. the artist is easily distracted.
|
- Keep it simple and easy to understand. the artist is easily distracted.
|
||||||
|
|
||||||
Today's Date: {current_date}
|
Today's Date: {current_date}
|
||||||
User's Location: {location}
|
User's Location: {location}
|
||||||
|
@ -337,6 +338,17 @@ Diagram Description: {query}
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
failed_diagram_generation = PromptTemplate.from_template(
|
||||||
|
"""
|
||||||
|
You attempted to programmatically generate a diagram but failed due to a system issue. You are normally able to generate diagrams, but you encountered a system issue this time.
|
||||||
|
|
||||||
|
You can create an ASCII image of the diagram in response instead.
|
||||||
|
|
||||||
|
This is the diagram you attempted to make:
|
||||||
|
{attempted_diagram}
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
## Online Search Conversation
|
## Online Search Conversation
|
||||||
## --
|
## --
|
||||||
online_search_conversation = PromptTemplate.from_template(
|
online_search_conversation = PromptTemplate.from_template(
|
||||||
|
|
|
@ -40,6 +40,7 @@ from khoj.utils.helpers import (
|
||||||
merge_dicts,
|
merge_dicts,
|
||||||
)
|
)
|
||||||
from khoj.utils.rawconfig import FileAttachment
|
from khoj.utils.rawconfig import FileAttachment
|
||||||
|
from khoj.utils.yaml import yaml_dump
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -381,9 +382,8 @@ def generate_chatml_messages_with_context(
|
||||||
model_type="",
|
model_type="",
|
||||||
context_message="",
|
context_message="",
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
generated_images: Optional[list[str]] = None,
|
|
||||||
generated_files: List[FileAttachment] = None,
|
generated_files: List[FileAttachment] = None,
|
||||||
generated_excalidraw_diagram: str = None,
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
program_execution_context: List[str] = [],
|
program_execution_context: List[str] = [],
|
||||||
):
|
):
|
||||||
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
|
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
|
||||||
|
@ -403,11 +403,15 @@ def generate_chatml_messages_with_context(
|
||||||
message_context = ""
|
message_context = ""
|
||||||
message_attached_files = ""
|
message_attached_files = ""
|
||||||
|
|
||||||
|
generated_assets = {}
|
||||||
|
|
||||||
chat_message = chat.get("message")
|
chat_message = chat.get("message")
|
||||||
role = "user" if chat["by"] == "you" else "assistant"
|
role = "user" if chat["by"] == "you" else "assistant"
|
||||||
|
|
||||||
|
# Legacy code to handle excalidraw diagrams prior to Dec 2024
|
||||||
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
|
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
|
||||||
chat_message = chat["intent"].get("inferred-queries")[0]
|
chat_message = chat["intent"].get("inferred-queries")[0]
|
||||||
|
|
||||||
if not is_none_or_empty(chat.get("context")):
|
if not is_none_or_empty(chat.get("context")):
|
||||||
references = "\n\n".join(
|
references = "\n\n".join(
|
||||||
{
|
{
|
||||||
|
@ -434,15 +438,23 @@ def generate_chatml_messages_with_context(
|
||||||
reconstructed_context_message = ChatMessage(content=message_context, role="user")
|
reconstructed_context_message = ChatMessage(content=message_context, role="user")
|
||||||
chatml_messages.insert(0, reconstructed_context_message)
|
chatml_messages.insert(0, reconstructed_context_message)
|
||||||
|
|
||||||
if chat.get("images") and role == "assistant":
|
if not is_none_or_empty(chat.get("images")) and role == "assistant":
|
||||||
# Issue: the assistant role cannot accept an image as a message content, so send it in a separate user message.
|
generated_assets["image"] = {
|
||||||
file_attachment_message = construct_structured_message(
|
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
|
||||||
message=prompts.generated_image_attachment.format(),
|
}
|
||||||
images=chat.get("images"),
|
|
||||||
model_type=model_type,
|
if not is_none_or_empty(chat.get("excalidrawDiagram")) and role == "assistant":
|
||||||
vision_enabled=vision_enabled,
|
generated_assets["diagram"] = {
|
||||||
|
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
|
||||||
|
}
|
||||||
|
|
||||||
|
if not is_none_or_empty(generated_assets):
|
||||||
|
chatml_messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
content=f"{prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_assets))}\n",
|
||||||
|
role="user",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
chatml_messages.append(ChatMessage(content=file_attachment_message, role="user"))
|
|
||||||
|
|
||||||
message_content = construct_structured_message(
|
message_content = construct_structured_message(
|
||||||
chat_message, chat.get("images") if role == "user" else [], model_type, vision_enabled
|
chat_message, chat.get("images") if role == "user" else [], model_type, vision_enabled
|
||||||
|
@ -465,33 +477,22 @@ def generate_chatml_messages_with_context(
|
||||||
role="user",
|
role="user",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if not is_none_or_empty(context_message):
|
|
||||||
messages.append(ChatMessage(content=context_message, role="user"))
|
|
||||||
|
|
||||||
if generated_images:
|
|
||||||
messages.append(
|
|
||||||
ChatMessage(
|
|
||||||
content=construct_structured_message(
|
|
||||||
prompts.generated_image_attachment.format(), generated_images, model_type, vision_enabled
|
|
||||||
),
|
|
||||||
role="user",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if generated_files:
|
if generated_files:
|
||||||
message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
|
message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
|
||||||
messages.append(ChatMessage(content=message_attached_files, role="assistant"))
|
messages.append(ChatMessage(content=message_attached_files, role="assistant"))
|
||||||
|
|
||||||
if generated_excalidraw_diagram:
|
if not is_none_or_empty(generated_asset_results):
|
||||||
messages.append(ChatMessage(content=prompts.generated_diagram_attachment.format(), role="assistant"))
|
context_message += (
|
||||||
|
f"{prompts.generated_assets_context.format(generated_assets=yaml_dump(generated_asset_results))}\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
if program_execution_context:
|
if program_execution_context:
|
||||||
messages.append(
|
program_context_text = "\n".join(program_execution_context)
|
||||||
ChatMessage(
|
context_message += f"{prompts.additional_program_context.format(context=program_context_text)}\n"
|
||||||
content=prompts.additional_program_context.format(context="\n".join(program_execution_context)),
|
|
||||||
role="assistant",
|
if not is_none_or_empty(context_message):
|
||||||
)
|
messages.append(ChatMessage(content=context_message, role="user"))
|
||||||
)
|
|
||||||
|
|
||||||
if len(chatml_messages) > 0:
|
if len(chatml_messages) > 0:
|
||||||
messages += chatml_messages
|
messages += chatml_messages
|
||||||
|
|
|
@ -23,6 +23,7 @@ from khoj.database.adapters import (
|
||||||
aget_user_name,
|
aget_user_name,
|
||||||
)
|
)
|
||||||
from khoj.database.models import Agent, KhojUser
|
from khoj.database.models import Agent, KhojUser
|
||||||
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||||
from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log
|
from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log
|
||||||
from khoj.processor.image.generate import text_to_image
|
from khoj.processor.image.generate import text_to_image
|
||||||
|
@ -765,6 +766,7 @@ async def chat(
|
||||||
researched_results = ""
|
researched_results = ""
|
||||||
online_results: Dict = dict()
|
online_results: Dict = dict()
|
||||||
code_results: Dict = dict()
|
code_results: Dict = dict()
|
||||||
|
generated_asset_results: Dict = dict()
|
||||||
## Extract Document References
|
## Extract Document References
|
||||||
compiled_references: List[Any] = []
|
compiled_references: List[Any] = []
|
||||||
inferred_queries: List[Any] = []
|
inferred_queries: List[Any] = []
|
||||||
|
@ -1128,6 +1130,10 @@ async def chat(
|
||||||
else:
|
else:
|
||||||
generated_images.append(generated_image)
|
generated_images.append(generated_image)
|
||||||
|
|
||||||
|
generated_asset_results["images"] = {
|
||||||
|
"query": improved_image_prompt,
|
||||||
|
}
|
||||||
|
|
||||||
async for result in send_event(
|
async for result in send_event(
|
||||||
ChatEvent.GENERATED_ASSETS,
|
ChatEvent.GENERATED_ASSETS,
|
||||||
{
|
{
|
||||||
|
@ -1166,6 +1172,10 @@ async def chat(
|
||||||
|
|
||||||
generated_excalidraw_diagram = diagram_description
|
generated_excalidraw_diagram = diagram_description
|
||||||
|
|
||||||
|
generated_asset_results["diagrams"] = {
|
||||||
|
"query": better_diagram_description_prompt,
|
||||||
|
}
|
||||||
|
|
||||||
async for result in send_event(
|
async for result in send_event(
|
||||||
ChatEvent.GENERATED_ASSETS,
|
ChatEvent.GENERATED_ASSETS,
|
||||||
{
|
{
|
||||||
|
@ -1176,7 +1186,9 @@ async def chat(
|
||||||
else:
|
else:
|
||||||
error_message = "Failed to generate diagram. Please try again later."
|
error_message = "Failed to generate diagram. Please try again later."
|
||||||
program_execution_context.append(
|
program_execution_context.append(
|
||||||
f"AI attempted to programmatically generate a diagram but failed due to a program issue. Generally, it is able to do so, but encountered a system issue this time. AI can suggest text description or rendering of the diagram or user can try again with a simpler prompt."
|
prompts.failed_diagram_generation.format(
|
||||||
|
attempted_diagram=better_diagram_description_prompt
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async for result in send_event(ChatEvent.STATUS, error_message):
|
async for result in send_event(ChatEvent.STATUS, error_message):
|
||||||
|
@ -1209,6 +1221,7 @@ async def chat(
|
||||||
generated_files,
|
generated_files,
|
||||||
generated_excalidraw_diagram,
|
generated_excalidraw_diagram,
|
||||||
program_execution_context,
|
program_execution_context,
|
||||||
|
generated_asset_results,
|
||||||
tracer,
|
tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -88,7 +88,10 @@ from khoj.processor.conversation.offline.chat_model import (
|
||||||
converse_offline,
|
converse_offline,
|
||||||
send_message_to_model_offline,
|
send_message_to_model_offline,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
|
from khoj.processor.conversation.openai.gpt import (
|
||||||
|
converse_openai,
|
||||||
|
send_message_to_model,
|
||||||
|
)
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
ChatEvent,
|
ChatEvent,
|
||||||
ThreadedGenerator,
|
ThreadedGenerator,
|
||||||
|
@ -751,7 +754,7 @@ async def generate_excalidraw_diagram(
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error generating Excalidraw diagram for {user.email}: {e}", exc_info=True)
|
logger.error(f"Error generating Excalidraw diagram for {user.email}: {e}", exc_info=True)
|
||||||
yield None, None
|
yield better_diagram_description_prompt, None
|
||||||
return
|
return
|
||||||
|
|
||||||
scratchpad = excalidraw_diagram_description.get("scratchpad")
|
scratchpad = excalidraw_diagram_description.get("scratchpad")
|
||||||
|
@ -1189,6 +1192,7 @@ def generate_chat_response(
|
||||||
raw_generated_files: List[FileAttachment] = [],
|
raw_generated_files: List[FileAttachment] = [],
|
||||||
generated_excalidraw_diagram: str = None,
|
generated_excalidraw_diagram: str = None,
|
||||||
program_execution_context: List[str] = [],
|
program_execution_context: List[str] = [],
|
||||||
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
|
@ -1251,6 +1255,7 @@ def generate_chat_response(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
generated_files=raw_generated_files,
|
generated_files=raw_generated_files,
|
||||||
|
generated_asset_results=generated_asset_results,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1258,7 +1263,7 @@ def generate_chat_response(
|
||||||
openai_chat_config = conversation_config.ai_model_api
|
openai_chat_config = conversation_config.ai_model_api
|
||||||
api_key = openai_chat_config.api_key
|
api_key = openai_chat_config.api_key
|
||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
chat_response = converse(
|
chat_response = converse_openai(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
query_to_run,
|
query_to_run,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
|
@ -1278,8 +1283,7 @@ def generate_chat_response(
|
||||||
vision_available=vision_available,
|
vision_available=vision_available,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
generated_files=raw_generated_files,
|
generated_files=raw_generated_files,
|
||||||
generated_images=generated_images,
|
generated_asset_results=generated_asset_results,
|
||||||
generated_excalidraw_diagram=generated_excalidraw_diagram,
|
|
||||||
program_execution_context=program_execution_context,
|
program_execution_context=program_execution_context,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
@ -1305,8 +1309,7 @@ def generate_chat_response(
|
||||||
vision_available=vision_available,
|
vision_available=vision_available,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
generated_files=raw_generated_files,
|
generated_files=raw_generated_files,
|
||||||
generated_images=generated_images,
|
generated_asset_results=generated_asset_results,
|
||||||
generated_excalidraw_diagram=generated_excalidraw_diagram,
|
|
||||||
program_execution_context=program_execution_context,
|
program_execution_context=program_execution_context,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
@ -1331,8 +1334,7 @@ def generate_chat_response(
|
||||||
vision_available=vision_available,
|
vision_available=vision_available,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
generated_files=raw_generated_files,
|
generated_files=raw_generated_files,
|
||||||
generated_images=generated_images,
|
generated_asset_results=generated_asset_results,
|
||||||
generated_excalidraw_diagram=generated_excalidraw_diagram,
|
|
||||||
program_execution_context=program_execution_context,
|
program_execution_context=program_execution_context,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,7 +4,7 @@ import freezegun
|
||||||
import pytest
|
import pytest
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
|
|
||||||
from khoj.processor.conversation.openai.gpt import converse, extract_questions
|
from khoj.processor.conversation.openai.gpt import converse_openai, extract_questions
|
||||||
from khoj.processor.conversation.utils import message_to_log
|
from khoj.processor.conversation.utils import message_to_log
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
aget_data_sources_and_output_format,
|
aget_data_sources_and_output_format,
|
||||||
|
@ -158,7 +158,7 @@ def test_generate_search_query_using_question_and_answer_from_chat_history():
|
||||||
@pytest.mark.chatquality
|
@pytest.mark.chatquality
|
||||||
def test_chat_with_no_chat_history_or_retrieved_content():
|
def test_chat_with_no_chat_history_or_retrieved_content():
|
||||||
# Act
|
# Act
|
||||||
response_gen = converse(
|
response_gen = converse_openai(
|
||||||
references=[], # Assume no context retrieved from notes for the user_query
|
references=[], # Assume no context retrieved from notes for the user_query
|
||||||
user_query="Hello, my name is Testatron. Who are you?",
|
user_query="Hello, my name is Testatron. Who are you?",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
@ -183,7 +183,7 @@ def test_answer_from_chat_history_and_no_content():
|
||||||
]
|
]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response_gen = converse(
|
response_gen = converse_openai(
|
||||||
references=[], # Assume no context retrieved from notes for the user_query
|
references=[], # Assume no context retrieved from notes for the user_query
|
||||||
user_query="What is my name?",
|
user_query="What is my name?",
|
||||||
conversation_log=populate_chat_history(message_list),
|
conversation_log=populate_chat_history(message_list),
|
||||||
|
@ -214,7 +214,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content():
|
||||||
]
|
]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response_gen = converse(
|
response_gen = converse_openai(
|
||||||
references=[], # Assume no context retrieved from notes for the user_query
|
references=[], # Assume no context retrieved from notes for the user_query
|
||||||
user_query="Where was I born?",
|
user_query="Where was I born?",
|
||||||
conversation_log=populate_chat_history(message_list),
|
conversation_log=populate_chat_history(message_list),
|
||||||
|
@ -239,7 +239,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content():
|
||||||
]
|
]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response_gen = converse(
|
response_gen = converse_openai(
|
||||||
references=[
|
references=[
|
||||||
{"compiled": "Testatron was born on 1st April 1984 in Testville.", "file": "background.md"}
|
{"compiled": "Testatron was born on 1st April 1984 in Testville.", "file": "background.md"}
|
||||||
], # Assume context retrieved from notes for the user_query
|
], # Assume context retrieved from notes for the user_query
|
||||||
|
@ -265,7 +265,7 @@ def test_refuse_answering_unanswerable_question():
|
||||||
]
|
]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response_gen = converse(
|
response_gen = converse_openai(
|
||||||
references=[], # Assume no context retrieved from notes for the user_query
|
references=[], # Assume no context retrieved from notes for the user_query
|
||||||
user_query="Where was I born?",
|
user_query="Where was I born?",
|
||||||
conversation_log=populate_chat_history(message_list),
|
conversation_log=populate_chat_history(message_list),
|
||||||
|
@ -318,7 +318,7 @@ Expenses:Food:Dining 10.00 USD""",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response_gen = converse(
|
response_gen = converse_openai(
|
||||||
references=context, # Assume context retrieved from notes for the user_query
|
references=context, # Assume context retrieved from notes for the user_query
|
||||||
user_query="What did I have for Dinner today?",
|
user_query="What did I have for Dinner today?",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
@ -362,7 +362,7 @@ Expenses:Food:Dining 10.00 USD""",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response_gen = converse(
|
response_gen = converse_openai(
|
||||||
references=context, # Assume context retrieved from notes for the user_query
|
references=context, # Assume context retrieved from notes for the user_query
|
||||||
user_query="How much did I spend on dining this year?",
|
user_query="How much did I spend on dining this year?",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
@ -386,7 +386,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content():
|
||||||
]
|
]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response_gen = converse(
|
response_gen = converse_openai(
|
||||||
references=[], # Assume no context retrieved from notes for the user_query
|
references=[], # Assume no context retrieved from notes for the user_query
|
||||||
user_query="Write a haiku about unit testing in 3 lines. Do not say anything else",
|
user_query="Write a haiku about unit testing in 3 lines. Do not say anything else",
|
||||||
conversation_log=populate_chat_history(message_list),
|
conversation_log=populate_chat_history(message_list),
|
||||||
|
@ -426,7 +426,7 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
|
||||||
]
|
]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response_gen = converse(
|
response_gen = converse_openai(
|
||||||
references=context, # Assume context retrieved from notes for the user_query
|
references=context, # Assume context retrieved from notes for the user_query
|
||||||
user_query="How many kids does my older sister have?",
|
user_query="How many kids does my older sister have?",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
@ -459,13 +459,13 @@ def test_agent_prompt_should_be_used(openai_agent):
|
||||||
expected_responses = ["9.50", "9.5"]
|
expected_responses = ["9.50", "9.5"]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response_gen = converse(
|
response_gen = converse_openai(
|
||||||
references=context, # Assume context retrieved from notes for the user_query
|
references=context, # Assume context retrieved from notes for the user_query
|
||||||
user_query="What did I buy?",
|
user_query="What did I buy?",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
no_agent_response = "".join([response_chunk for response_chunk in response_gen])
|
no_agent_response = "".join([response_chunk for response_chunk in response_gen])
|
||||||
response_gen = converse(
|
response_gen = converse_openai(
|
||||||
references=context, # Assume context retrieved from notes for the user_query
|
references=context, # Assume context retrieved from notes for the user_query
|
||||||
user_query="What did I buy?",
|
user_query="What did I buy?",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
|
Loading…
Reference in a new issue