mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Include additional user context in the image generation flow (#660)
* Make major improvements to the image generation flow - Include user context from online references and personal notes for generating images - Dynamically select the modality that the LLM should respond with - Retun the inferred context in the query response for the dekstop, web chat views to read * Add unit tests for retrieving response modes via LLM * Move output mode unit tests to the actor suite, rather than director * Only show the references button if there is at least one available * Rename aget_relevant_modes to aget_relevant_output_modes * Use a shared method for generating reference sections, simplify some of the prompting logic * Make out of space errors in the desktop client more obvious
This commit is contained in:
parent
3cbc5b0d52
commit
e323a6d69b
9 changed files with 336 additions and 102 deletions
|
@ -197,13 +197,18 @@
|
|||
}
|
||||
|
||||
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
|
||||
if (intentType === "text-to-image") {
|
||||
let imageMarkdown = `![](data:image/png;base64,${message})`;
|
||||
const inferredQuery = inferredQueries?.[0];
|
||||
if (inferredQuery) {
|
||||
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
|
||||
if (intentType === "text-to-image") {
|
||||
let imageMarkdown = `![](data:image/png;base64,${message})`;
|
||||
const inferredQuery = inferredQueries?.[0];
|
||||
if (inferredQuery) {
|
||||
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||
}
|
||||
renderMessage(imageMarkdown, by, dt);
|
||||
return;
|
||||
}
|
||||
renderMessage(imageMarkdown, by, dt);
|
||||
|
||||
renderMessage(message, by, dt);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -261,6 +266,16 @@
|
|||
|
||||
references.appendChild(referenceSection);
|
||||
|
||||
if (intentType === "text-to-image") {
|
||||
let imageMarkdown = `![](data:image/png;base64,${message})`;
|
||||
const inferredQuery = inferredQueries?.[0];
|
||||
if (inferredQuery) {
|
||||
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||
}
|
||||
renderMessage(imageMarkdown, by, dt, references);
|
||||
return;
|
||||
}
|
||||
|
||||
renderMessage(message, by, dt, references);
|
||||
}
|
||||
|
||||
|
@ -324,6 +339,46 @@
|
|||
return element
|
||||
}
|
||||
|
||||
function createReferenceSection(references) {
|
||||
let referenceSection = document.createElement('div');
|
||||
referenceSection.classList.add("reference-section");
|
||||
referenceSection.classList.add("collapsed");
|
||||
|
||||
let numReferences = 0;
|
||||
|
||||
if (Array.isArray(references)) {
|
||||
numReferences = references.length;
|
||||
|
||||
references.forEach((reference, index) => {
|
||||
let polishedReference = generateReference(reference, index);
|
||||
referenceSection.appendChild(polishedReference);
|
||||
});
|
||||
} else {
|
||||
numReferences += processOnlineReferences(referenceSection, references);
|
||||
}
|
||||
|
||||
let referenceExpandButton = document.createElement('button');
|
||||
referenceExpandButton.classList.add("reference-expand-button");
|
||||
referenceExpandButton.innerHTML = numReferences == 1 ? "1 reference" : `${numReferences} references`;
|
||||
|
||||
referenceExpandButton.addEventListener('click', function() {
|
||||
if (referenceSection.classList.contains("collapsed")) {
|
||||
referenceSection.classList.remove("collapsed");
|
||||
referenceSection.classList.add("expanded");
|
||||
} else {
|
||||
referenceSection.classList.add("collapsed");
|
||||
referenceSection.classList.remove("expanded");
|
||||
}
|
||||
});
|
||||
|
||||
let referencesDiv = document.createElement('div');
|
||||
referencesDiv.classList.add("references");
|
||||
referencesDiv.appendChild(referenceExpandButton);
|
||||
referencesDiv.appendChild(referenceSection);
|
||||
|
||||
return referencesDiv;
|
||||
}
|
||||
|
||||
async function chat() {
|
||||
// Extract required fields for search from form
|
||||
let query = document.getElementById("chat-input").value.trim();
|
||||
|
@ -382,6 +437,7 @@
|
|||
// Call Khoj chat API
|
||||
let response = await fetch(chatApi, { headers });
|
||||
let rawResponse = "";
|
||||
let references = null;
|
||||
const contentType = response.headers.get("content-type");
|
||||
|
||||
if (contentType === "application/json") {
|
||||
|
@ -396,6 +452,10 @@
|
|||
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`;
|
||||
}
|
||||
}
|
||||
if (responseAsJson.context) {
|
||||
const rawReferenceAsJson = responseAsJson.context;
|
||||
references = createReferenceSection(rawReferenceAsJson);
|
||||
}
|
||||
if (responseAsJson.detail) {
|
||||
// If response has detail field, response is an error message.
|
||||
rawResponse += responseAsJson.detail;
|
||||
|
@ -407,6 +467,10 @@
|
|||
newResponseText.innerHTML = "";
|
||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||
|
||||
if (references != null) {
|
||||
newResponseText.appendChild(references);
|
||||
}
|
||||
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
document.getElementById("chat-input").removeAttribute("disabled");
|
||||
}
|
||||
|
@ -441,45 +505,7 @@
|
|||
|
||||
const rawReference = chunk.split("### compiled references:")[1];
|
||||
const rawReferenceAsJson = JSON.parse(rawReference);
|
||||
references = document.createElement('div');
|
||||
references.classList.add("references");
|
||||
|
||||
let referenceExpandButton = document.createElement('button');
|
||||
referenceExpandButton.classList.add("reference-expand-button");
|
||||
|
||||
let referenceSection = document.createElement('div');
|
||||
referenceSection.classList.add("reference-section");
|
||||
referenceSection.classList.add("collapsed");
|
||||
|
||||
let numReferences = 0;
|
||||
|
||||
// If rawReferenceAsJson is a list, then count the length
|
||||
if (Array.isArray(rawReferenceAsJson)) {
|
||||
numReferences = rawReferenceAsJson.length;
|
||||
|
||||
rawReferenceAsJson.forEach((reference, index) => {
|
||||
let polishedReference = generateReference(reference, index);
|
||||
referenceSection.appendChild(polishedReference);
|
||||
});
|
||||
} else {
|
||||
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
|
||||
}
|
||||
|
||||
references.appendChild(referenceExpandButton);
|
||||
|
||||
referenceExpandButton.addEventListener('click', function() {
|
||||
if (referenceSection.classList.contains("collapsed")) {
|
||||
referenceSection.classList.remove("collapsed");
|
||||
referenceSection.classList.add("expanded");
|
||||
} else {
|
||||
referenceSection.classList.add("collapsed");
|
||||
referenceSection.classList.remove("expanded");
|
||||
}
|
||||
});
|
||||
|
||||
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
|
||||
referenceExpandButton.innerHTML = expandButtonText;
|
||||
references.appendChild(referenceSection);
|
||||
references = createReferenceSection(rawReferenceAsJson);
|
||||
readStream();
|
||||
} else {
|
||||
// Display response from Khoj
|
||||
|
|
|
@ -209,17 +209,17 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
}
|
||||
|
||||
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
|
||||
if (intentType === "text-to-image") {
|
||||
let imageMarkdown = `![](data:image/png;base64,${message})`;
|
||||
const inferredQuery = inferredQueries?.[0];
|
||||
if (inferredQuery) {
|
||||
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
|
||||
if (intentType === "text-to-image") {
|
||||
let imageMarkdown = `![](data:image/png;base64,${message})`;
|
||||
const inferredQuery = inferredQueries?.[0];
|
||||
if (inferredQuery) {
|
||||
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||
}
|
||||
renderMessage(imageMarkdown, by, dt);
|
||||
return;
|
||||
}
|
||||
renderMessage(imageMarkdown, by, dt);
|
||||
return;
|
||||
}
|
||||
|
||||
if (context == null && onlineContext == null) {
|
||||
renderMessage(message, by, dt);
|
||||
return;
|
||||
}
|
||||
|
@ -273,6 +273,16 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
|
||||
references.appendChild(referenceSection);
|
||||
|
||||
if (intentType === "text-to-image") {
|
||||
let imageMarkdown = `![](data:image/png;base64,${message})`;
|
||||
const inferredQuery = inferredQueries?.[0];
|
||||
if (inferredQuery) {
|
||||
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||
}
|
||||
renderMessage(imageMarkdown, by, dt, references);
|
||||
return;
|
||||
}
|
||||
|
||||
renderMessage(message, by, dt, references);
|
||||
}
|
||||
|
||||
|
@ -336,6 +346,46 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
return element
|
||||
}
|
||||
|
||||
function createReferenceSection(references) {
|
||||
let referenceSection = document.createElement('div');
|
||||
referenceSection.classList.add("reference-section");
|
||||
referenceSection.classList.add("collapsed");
|
||||
|
||||
let numReferences = 0;
|
||||
|
||||
if (Array.isArray(references)) {
|
||||
numReferences = references.length;
|
||||
|
||||
references.forEach((reference, index) => {
|
||||
let polishedReference = generateReference(reference, index);
|
||||
referenceSection.appendChild(polishedReference);
|
||||
});
|
||||
} else {
|
||||
numReferences += processOnlineReferences(referenceSection, references);
|
||||
}
|
||||
|
||||
let referenceExpandButton = document.createElement('button');
|
||||
referenceExpandButton.classList.add("reference-expand-button");
|
||||
referenceExpandButton.innerHTML = numReferences == 1 ? "1 reference" : `${numReferences} references`;
|
||||
|
||||
referenceExpandButton.addEventListener('click', function() {
|
||||
if (referenceSection.classList.contains("collapsed")) {
|
||||
referenceSection.classList.remove("collapsed");
|
||||
referenceSection.classList.add("expanded");
|
||||
} else {
|
||||
referenceSection.classList.add("collapsed");
|
||||
referenceSection.classList.remove("expanded");
|
||||
}
|
||||
});
|
||||
|
||||
let referencesDiv = document.createElement('div');
|
||||
referencesDiv.classList.add("references");
|
||||
referencesDiv.appendChild(referenceExpandButton);
|
||||
referencesDiv.appendChild(referenceSection);
|
||||
|
||||
return referencesDiv;
|
||||
}
|
||||
|
||||
async function chat() {
|
||||
// Extract required fields for search from form
|
||||
let query = document.getElementById("chat-input").value.trim();
|
||||
|
@ -390,6 +440,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
// Call specified Khoj API
|
||||
let response = await fetch(url);
|
||||
let rawResponse = "";
|
||||
let references = null;
|
||||
const contentType = response.headers.get("content-type");
|
||||
|
||||
if (contentType === "application/json") {
|
||||
|
@ -404,6 +455,10 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||
}
|
||||
}
|
||||
if (responseAsJson.context && responseAsJson.context.length > 0) {
|
||||
const rawReferenceAsJson = responseAsJson.context;
|
||||
references = createReferenceSection(rawReferenceAsJson);
|
||||
}
|
||||
if (responseAsJson.detail) {
|
||||
// If response has detail field, response is an error message.
|
||||
rawResponse += responseAsJson.detail;
|
||||
|
@ -415,6 +470,10 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
newResponseText.innerHTML = "";
|
||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||
|
||||
if (references != null) {
|
||||
newResponseText.appendChild(references);
|
||||
}
|
||||
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
document.getElementById("chat-input").removeAttribute("disabled");
|
||||
}
|
||||
|
@ -449,45 +508,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
|
||||
const rawReference = chunk.split("### compiled references:")[1];
|
||||
const rawReferenceAsJson = JSON.parse(rawReference);
|
||||
references = document.createElement('div');
|
||||
references.classList.add("references");
|
||||
|
||||
let referenceExpandButton = document.createElement('button');
|
||||
referenceExpandButton.classList.add("reference-expand-button");
|
||||
|
||||
let referenceSection = document.createElement('div');
|
||||
referenceSection.classList.add("reference-section");
|
||||
referenceSection.classList.add("collapsed");
|
||||
|
||||
let numReferences = 0;
|
||||
|
||||
// If rawReferenceAsJson is a list, then count the length
|
||||
if (Array.isArray(rawReferenceAsJson)) {
|
||||
numReferences = rawReferenceAsJson.length;
|
||||
|
||||
rawReferenceAsJson.forEach((reference, index) => {
|
||||
let polishedReference = generateReference(reference, index);
|
||||
referenceSection.appendChild(polishedReference);
|
||||
});
|
||||
} else {
|
||||
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
|
||||
}
|
||||
|
||||
references.appendChild(referenceExpandButton);
|
||||
|
||||
referenceExpandButton.addEventListener('click', function() {
|
||||
if (referenceSection.classList.contains("collapsed")) {
|
||||
referenceSection.classList.remove("collapsed");
|
||||
referenceSection.classList.add("expanded");
|
||||
} else {
|
||||
referenceSection.classList.add("collapsed");
|
||||
referenceSection.classList.remove("expanded");
|
||||
}
|
||||
});
|
||||
|
||||
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
|
||||
referenceExpandButton.innerHTML = expandButtonText;
|
||||
references.appendChild(referenceSection);
|
||||
references = createReferenceSection(rawReferenceAsJson);
|
||||
readStream();
|
||||
} else {
|
||||
// Display response from Khoj
|
||||
|
|
|
@ -120,16 +120,23 @@ User's Notes:
|
|||
|
||||
image_generation_improve_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information from the query. Use the conversation log to inform your response.
|
||||
You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information originally from the query. You are provided with the following information to help you generate the prompt:
|
||||
|
||||
Today's Date: {current_date}
|
||||
User's Location: {location}
|
||||
|
||||
User's Notes:
|
||||
{references}
|
||||
|
||||
Online References:
|
||||
{online_results}
|
||||
|
||||
Conversation Log:
|
||||
{chat_history}
|
||||
|
||||
Query: {query}
|
||||
|
||||
Remember, now you are generating a prompt to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information originally from the query. Use the additional context from the user's notes, online references and conversation log to improve the image generation.
|
||||
Improved Query:"""
|
||||
)
|
||||
|
||||
|
@ -294,6 +301,40 @@ Collate the relevant information from the website to answer the target query.
|
|||
""".strip()
|
||||
)
|
||||
|
||||
pick_relevant_output_mode = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an excellent analyst for selecting the correct way to respond to a user's query. You have access to a limited set of modes for your response. You can only use one of these modes.
|
||||
|
||||
{modes}
|
||||
|
||||
Here are some example responses:
|
||||
|
||||
Example:
|
||||
Chat History:
|
||||
User: I just visited Jerusalem for the first time. Pull up my notes from the trip.
|
||||
AI: You mention visiting Masjid Al-Aqsa and the Western Wall. You also mention trying the local cuisine and visiting the Dead Sea.
|
||||
|
||||
Q: Draw a picture of my trip to Jerusalem.
|
||||
Khoj: image
|
||||
|
||||
Example:
|
||||
Chat History:
|
||||
User: I'm having trouble deciding which laptop to get. I want something with at least 16 GB of RAM and a 1 TB SSD.
|
||||
AI: I can help with that. I see online that there is a new model of the Dell XPS 15 that meets your requirements.
|
||||
|
||||
Q: What are the specs of the new Dell XPS 15?
|
||||
Khoj: default
|
||||
|
||||
Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a string.
|
||||
|
||||
Chat History:
|
||||
{chat_history}
|
||||
|
||||
Q: {query}
|
||||
Khoj:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
pick_relevant_information_collection_tools = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, a smart and helpful personal assistant. You have access to a variety of data sources to help you answer the user's question. You can use the data sources listed below to collect more relevant information. You can use any combination of these data sources to answer the user's question. Tell me which data sources you would like to use to answer the user's question.
|
||||
|
|
|
@ -22,6 +22,7 @@ from khoj.routers.helpers import (
|
|||
ConversationCommandRateLimiter,
|
||||
agenerate_chat_response,
|
||||
aget_relevant_information_sources,
|
||||
aget_relevant_output_modes,
|
||||
get_conversation_command,
|
||||
is_ready_to_chat,
|
||||
text_to_image,
|
||||
|
@ -250,6 +251,9 @@ async def chat(
|
|||
|
||||
if conversation_commands == [ConversationCommand.Default]:
|
||||
conversation_commands = await aget_relevant_information_sources(q, meta_log)
|
||||
mode = await aget_relevant_output_modes(q, meta_log)
|
||||
if mode not in conversation_commands:
|
||||
conversation_commands.append(mode)
|
||||
|
||||
for cmd in conversation_commands:
|
||||
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
||||
|
@ -287,7 +291,8 @@ async def chat(
|
|||
media_type="text/event-stream",
|
||||
status_code=200,
|
||||
)
|
||||
elif conversation_commands == [ConversationCommand.Image]:
|
||||
|
||||
if ConversationCommand.Image in conversation_commands:
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
|
@ -295,7 +300,9 @@ async def chat(
|
|||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
**common.__dict__,
|
||||
)
|
||||
image, status_code, improved_image_prompt = await text_to_image(q, meta_log, location_data=location)
|
||||
image, status_code, improved_image_prompt = await text_to_image(
|
||||
q, meta_log, location_data=location, references=compiled_references, online_results=online_results
|
||||
)
|
||||
if image is None:
|
||||
content_obj = {"image": image, "intentType": "text-to-image", "detail": improved_image_prompt}
|
||||
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
||||
|
@ -308,8 +315,10 @@ async def chat(
|
|||
inferred_queries=[improved_image_prompt],
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
)
|
||||
content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]} # type: ignore
|
||||
content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore
|
||||
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
||||
|
||||
# Get the (streamed) chat response from the LLM of choice.
|
||||
|
|
|
@ -38,6 +38,7 @@ from khoj.utils.helpers import (
|
|||
ConversationCommand,
|
||||
is_none_or_empty,
|
||||
log_telemetry,
|
||||
mode_descriptions_for_llm,
|
||||
tool_descriptions_for_llm,
|
||||
)
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
@ -117,6 +118,9 @@ def construct_chat_history(conversation_history: dict, n: int = 4) -> str:
|
|||
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"Khoj: {chat['message']}\n"
|
||||
elif chat["by"] == "khoj" and chat["intent"].get("type") == "text-to-image":
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"Khoj: [generated image redacted for space]\n"
|
||||
return chat_history
|
||||
|
||||
|
||||
|
@ -185,6 +189,42 @@ async def aget_relevant_information_sources(query: str, conversation_history: di
|
|||
return [ConversationCommand.Default]
|
||||
|
||||
|
||||
async def aget_relevant_output_modes(query: str, conversation_history: dict):
|
||||
"""
|
||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||
"""
|
||||
|
||||
mode_options = dict()
|
||||
|
||||
for mode, description in mode_descriptions_for_llm.items():
|
||||
mode_options[mode.value] = description
|
||||
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
relevant_mode_prompt = prompts.pick_relevant_output_mode.format(
|
||||
query=query,
|
||||
modes=str(mode_options),
|
||||
chat_history=chat_history,
|
||||
)
|
||||
|
||||
response = await send_message_to_model_wrapper(relevant_mode_prompt)
|
||||
|
||||
try:
|
||||
response = response.strip()
|
||||
|
||||
if is_none_or_empty(response):
|
||||
return ConversationCommand.Default
|
||||
|
||||
if response in mode_options.keys():
|
||||
# Check whether the tool exists as a valid ConversationCommand
|
||||
return ConversationCommand(response)
|
||||
|
||||
return ConversationCommand.Default
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid response for determining relevant mode: {response}")
|
||||
return ConversationCommand.Default
|
||||
|
||||
|
||||
async def generate_online_subqueries(q: str, conversation_history: dict, location_data: LocationData) -> List[str]:
|
||||
"""
|
||||
Generate subqueries from the given query
|
||||
|
@ -234,7 +274,13 @@ async def extract_relevant_info(q: str, corpus: dict) -> List[str]:
|
|||
return response.strip()
|
||||
|
||||
|
||||
async def generate_better_image_prompt(q: str, conversation_history: str, location_data: LocationData) -> str:
|
||||
async def generate_better_image_prompt(
|
||||
q: str,
|
||||
conversation_history: str,
|
||||
location_data: LocationData,
|
||||
note_references: List[str],
|
||||
online_results: Optional[dict] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a better image prompt from the given query
|
||||
"""
|
||||
|
@ -242,11 +288,26 @@ async def generate_better_image_prompt(q: str, conversation_history: str, locati
|
|||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
||||
|
||||
location_prompt = prompts.user_location.format(location=location)
|
||||
|
||||
user_references = "\n\n".join([f"# {item}" for item in note_references])
|
||||
|
||||
simplified_online_results = {}
|
||||
|
||||
if online_results:
|
||||
for result in online_results:
|
||||
if online_results[result].get("answerBox"):
|
||||
simplified_online_results[result] = online_results[result]["answerBox"]
|
||||
elif online_results[result].get("extracted_content"):
|
||||
simplified_online_results[result] = online_results[result]["extracted_content"]
|
||||
|
||||
image_prompt = prompts.image_generation_improve_prompt.format(
|
||||
query=q,
|
||||
chat_history=conversation_history,
|
||||
location=location,
|
||||
location=location_prompt,
|
||||
current_date=today_date,
|
||||
references=user_references,
|
||||
online_results=simplified_online_results,
|
||||
)
|
||||
|
||||
response = await send_message_to_model_wrapper(image_prompt)
|
||||
|
@ -377,7 +438,11 @@ def generate_chat_response(
|
|||
|
||||
|
||||
async def text_to_image(
|
||||
message: str, conversation_log: dict, location_data: LocationData
|
||||
message: str,
|
||||
conversation_log: dict,
|
||||
location_data: LocationData,
|
||||
references: List[str],
|
||||
online_results: Dict[str, Any],
|
||||
) -> Tuple[Optional[str], int, Optional[str]]:
|
||||
status_code = 200
|
||||
image = None
|
||||
|
@ -396,7 +461,13 @@ async def text_to_image(
|
|||
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
|
||||
chat_history += f"Q: {chat['intent']['query']}\n"
|
||||
chat_history += f"A: {chat['message']}\n"
|
||||
improved_image_prompt = await generate_better_image_prompt(message, chat_history, location_data=location_data)
|
||||
improved_image_prompt = await generate_better_image_prompt(
|
||||
message,
|
||||
chat_history,
|
||||
location_data=location_data,
|
||||
note_references=references,
|
||||
online_results=online_results,
|
||||
)
|
||||
try:
|
||||
response = state.openai_client.images.generate(
|
||||
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
|
||||
|
|
|
@ -289,6 +289,11 @@ tool_descriptions_for_llm = {
|
|||
ConversationCommand.Online: "Use this when you would like to look up information on the internet",
|
||||
}
|
||||
|
||||
mode_descriptions_for_llm = {
|
||||
ConversationCommand.Image: "Use this if you think the user is requesting an image or visual response to their query.",
|
||||
ConversationCommand.Default: "Use this if the other response modes don't seem to fit the query.",
|
||||
}
|
||||
|
||||
|
||||
def generate_random_name():
|
||||
# List of adjectives and nouns to choose from
|
||||
|
|
|
@ -24,6 +24,7 @@ from khoj.processor.conversation.offline.chat_model import (
|
|||
)
|
||||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import aget_relevant_output_modes
|
||||
|
||||
MODEL_NAME = "mistral-7b-instruct-v0.1.Q4_0.gguf"
|
||||
|
||||
|
@ -497,6 +498,34 @@ def test_filter_questions():
|
|||
assert filtered_questions[0] == "Who is on the basketball team?"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_use_default_response_mode(client_offline_chat):
|
||||
# Arrange
|
||||
user_query = "What's the latest in the Israel/Palestine conflict?"
|
||||
|
||||
# Act
|
||||
mode = await aget_relevant_output_modes(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert mode.value == "default"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_use_image_response_mode(client_offline_chat):
|
||||
# Arrange
|
||||
user_query = "Paint a picture of the scenery in Timbuktu in the winter"
|
||||
|
||||
# Act
|
||||
mode = await aget_relevant_output_modes(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert mode.value == "image"
|
||||
|
||||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def populate_chat_history(message_list):
|
||||
|
|
|
@ -7,6 +7,7 @@ from freezegun import freeze_time
|
|||
|
||||
from khoj.processor.conversation.openai.gpt import converse, extract_questions
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import aget_relevant_output_modes
|
||||
|
||||
# Initialize variables for tests
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
@ -434,6 +435,34 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
|
|||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_use_default_response_mode(chat_client):
|
||||
# Arrange
|
||||
user_query = "What's the latest in the Israel/Palestine conflict?"
|
||||
|
||||
# Act
|
||||
mode = await aget_relevant_output_modes(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert mode.value == "default"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_use_image_response_mode(chat_client):
|
||||
# Arrange
|
||||
user_query = "Paint a picture of the scenery in Timbuktu in the winter"
|
||||
|
||||
# Act
|
||||
mode = await aget_relevant_output_modes(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert mode.value == "image"
|
||||
|
||||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def populate_chat_history(message_list):
|
||||
|
|
|
@ -8,7 +8,10 @@ from freezegun import freeze_time
|
|||
from khoj.database.models import KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import aget_relevant_information_sources
|
||||
from khoj.routers.helpers import (
|
||||
aget_relevant_information_sources,
|
||||
aget_relevant_output_modes,
|
||||
)
|
||||
from tests.helpers import ConversationFactory
|
||||
|
||||
# Initialize variables for tests
|
||||
|
|
Loading…
Reference in a new issue