Include additional user context in the image generation flow (#660)

* Make major improvements to the image generation flow

- Include user context from online references and personal notes when generating images
- Dynamically select the modality that the LLM should respond with
- Return the inferred context in the query response for the desktop and web chat views to read

* Add unit tests for retrieving response modes via LLM

* Move output mode unit tests to the actor suite rather than the director suite

* Only show the references button if there is at least one available

* Rename aget_relevant_modes to aget_relevant_output_modes

* Use a shared method for generating reference sections and simplify some of the prompting logic

* Make out-of-space errors in the desktop client more obvious
sabaimran 2024-03-06 13:48:41 +05:30 committed by GitHub
parent 3cbc5b0d52
commit e323a6d69b
9 changed files with 336 additions and 102 deletions

@ -197,13 +197,18 @@
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
renderMessage(imageMarkdown, by, dt);
return;
}
renderMessage(imageMarkdown, by, dt);
renderMessage(message, by, dt);
return;
}
@ -261,6 +266,16 @@
references.appendChild(referenceSection);
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
renderMessage(imageMarkdown, by, dt, references);
return;
}
renderMessage(message, by, dt, references);
}
@ -324,6 +339,46 @@
return element
}
function createReferenceSection(references) {
let referenceSection = document.createElement('div');
referenceSection.classList.add("reference-section");
referenceSection.classList.add("collapsed");
let numReferences = 0;
if (Array.isArray(references)) {
numReferences = references.length;
references.forEach((reference, index) => {
let polishedReference = generateReference(reference, index);
referenceSection.appendChild(polishedReference);
});
} else {
numReferences += processOnlineReferences(referenceSection, references);
}
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
referenceExpandButton.innerHTML = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) {
referenceSection.classList.remove("collapsed");
referenceSection.classList.add("expanded");
} else {
referenceSection.classList.add("collapsed");
referenceSection.classList.remove("expanded");
}
});
let referencesDiv = document.createElement('div');
referencesDiv.classList.add("references");
referencesDiv.appendChild(referenceExpandButton);
referencesDiv.appendChild(referenceSection);
return referencesDiv;
}
async function chat() {
// Extract required fields for search from form
let query = document.getElementById("chat-input").value.trim();
@ -382,6 +437,7 @@
// Call Khoj chat API
let response = await fetch(chatApi, { headers });
let rawResponse = "";
let references = null;
const contentType = response.headers.get("content-type");
if (contentType === "application/json") {
@ -396,6 +452,10 @@
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`;
}
}
if (responseAsJson.context) {
const rawReferenceAsJson = responseAsJson.context;
references = createReferenceSection(rawReferenceAsJson);
}
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
rawResponse += responseAsJson.detail;
@ -407,6 +467,10 @@
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
if (references != null) {
newResponseText.appendChild(references);
}
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
}
@ -441,45 +505,7 @@
const rawReference = chunk.split("### compiled references:")[1];
const rawReferenceAsJson = JSON.parse(rawReference);
references = document.createElement('div');
references.classList.add("references");
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
let referenceSection = document.createElement('div');
referenceSection.classList.add("reference-section");
referenceSection.classList.add("collapsed");
let numReferences = 0;
// If rawReferenceAsJson is a list, then count the length
if (Array.isArray(rawReferenceAsJson)) {
numReferences = rawReferenceAsJson.length;
rawReferenceAsJson.forEach((reference, index) => {
let polishedReference = generateReference(reference, index);
referenceSection.appendChild(polishedReference);
});
} else {
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
}
references.appendChild(referenceExpandButton);
referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) {
referenceSection.classList.remove("collapsed");
referenceSection.classList.add("expanded");
} else {
referenceSection.classList.add("collapsed");
referenceSection.classList.remove("expanded");
}
});
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
references = createReferenceSection(rawReferenceAsJson);
readStream();
} else {
// Display response from Khoj

@ -209,17 +209,17 @@ To get started, just start typing below. You can also type / to see a list of co
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
renderMessage(imageMarkdown, by, dt);
return;
}
renderMessage(imageMarkdown, by, dt);
return;
}
if (context == null && onlineContext == null) {
renderMessage(message, by, dt);
return;
}
@ -273,6 +273,16 @@ To get started, just start typing below. You can also type / to see a list of co
references.appendChild(referenceSection);
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
renderMessage(imageMarkdown, by, dt, references);
return;
}
renderMessage(message, by, dt, references);
}
@ -336,6 +346,46 @@ To get started, just start typing below. You can also type / to see a list of co
return element
}
function createReferenceSection(references) {
let referenceSection = document.createElement('div');
referenceSection.classList.add("reference-section");
referenceSection.classList.add("collapsed");
let numReferences = 0;
if (Array.isArray(references)) {
numReferences = references.length;
references.forEach((reference, index) => {
let polishedReference = generateReference(reference, index);
referenceSection.appendChild(polishedReference);
});
} else {
numReferences += processOnlineReferences(referenceSection, references);
}
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
referenceExpandButton.innerHTML = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) {
referenceSection.classList.remove("collapsed");
referenceSection.classList.add("expanded");
} else {
referenceSection.classList.add("collapsed");
referenceSection.classList.remove("expanded");
}
});
let referencesDiv = document.createElement('div');
referencesDiv.classList.add("references");
referencesDiv.appendChild(referenceExpandButton);
referencesDiv.appendChild(referenceSection);
return referencesDiv;
}
async function chat() {
// Extract required fields for search from form
let query = document.getElementById("chat-input").value.trim();
@ -390,6 +440,7 @@ To get started, just start typing below. You can also type / to see a list of co
// Call specified Khoj API
let response = await fetch(url);
let rawResponse = "";
let references = null;
const contentType = response.headers.get("content-type");
if (contentType === "application/json") {
@ -404,6 +455,10 @@ To get started, just start typing below. You can also type / to see a list of co
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
}
if (responseAsJson.context && responseAsJson.context.length > 0) {
const rawReferenceAsJson = responseAsJson.context;
references = createReferenceSection(rawReferenceAsJson);
}
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
rawResponse += responseAsJson.detail;
@ -415,6 +470,10 @@ To get started, just start typing below. You can also type / to see a list of co
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
if (references != null) {
newResponseText.appendChild(references);
}
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
}
@ -449,45 +508,7 @@ To get started, just start typing below. You can also type / to see a list of co
const rawReference = chunk.split("### compiled references:")[1];
const rawReferenceAsJson = JSON.parse(rawReference);
references = document.createElement('div');
references.classList.add("references");
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
let referenceSection = document.createElement('div');
referenceSection.classList.add("reference-section");
referenceSection.classList.add("collapsed");
let numReferences = 0;
// If rawReferenceAsJson is a list, then count the length
if (Array.isArray(rawReferenceAsJson)) {
numReferences = rawReferenceAsJson.length;
rawReferenceAsJson.forEach((reference, index) => {
let polishedReference = generateReference(reference, index);
referenceSection.appendChild(polishedReference);
});
} else {
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
}
references.appendChild(referenceExpandButton);
referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) {
referenceSection.classList.remove("collapsed");
referenceSection.classList.add("expanded");
} else {
referenceSection.classList.add("collapsed");
referenceSection.classList.remove("expanded");
}
});
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
references = createReferenceSection(rawReferenceAsJson);
readStream();
} else {
// Display response from Khoj

@ -120,16 +120,23 @@ User's Notes:
image_generation_improve_prompt = PromptTemplate.from_template(
"""
You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information from the query. Use the conversation log to inform your response.
You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information originally from the query. You are provided with the following information to help you generate the prompt:
Today's Date: {current_date}
User's Location: {location}
User's Notes:
{references}
Online References:
{online_results}
Conversation Log:
{chat_history}
Query: {query}
Remember, now you are generating a prompt to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information originally from the query. Use the additional context from the user's notes, online references and conversation log to improve the image generation.
Improved Query:"""
)
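For orientation, a minimal sketch of how this template gets filled in, mirroring the generate_better_image_prompt changes further down in this commit; all values are placeholders, and the location string is assumed to come from the existing user_location prompt.

# Illustrative only: placeholder values, keyword arguments taken from
# generate_better_image_prompt later in this commit.
image_prompt = image_generation_improve_prompt.format(
    query="Paint a picture of the scenery in Timbuktu in the winter",
    chat_history="User: I am planning a trip through the Sahel.\nKhoj: ...\n",
    location=user_location.format(location="Bamako, Koulikoro, Mali"),
    current_date="2024-03-06",
    references="# Packing list and notes from my last desert trip ...",
    online_results={"timbuktu in winter": "Warm days, cool nights in December"},
)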
@ -294,6 +301,40 @@ Collate the relevant information from the website to answer the target query.
""".strip()
)
pick_relevant_output_mode = PromptTemplate.from_template(
"""
You are Khoj, an excellent analyst for selecting the correct way to respond to a user's query. You have access to a limited set of modes for your response. You can only use one of these modes.
{modes}
Here are some example responses:
Example:
Chat History:
User: I just visited Jerusalem for the first time. Pull up my notes from the trip.
AI: You mention visiting Masjid Al-Aqsa and the Western Wall. You also mention trying the local cuisine and visiting the Dead Sea.
Q: Draw a picture of my trip to Jerusalem.
Khoj: image
Example:
Chat History:
User: I'm having trouble deciding which laptop to get. I want something with at least 16 GB of RAM and a 1 TB SSD.
AI: I can help with that. I see online that there is a new model of the Dell XPS 15 that meets your requirements.
Q: What are the specs of the new Dell XPS 15?
Khoj: default
Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a string.
Chat History:
{chat_history}
Q: {query}
Khoj:
""".strip()
)
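A similar sketch for the new mode-selection template, matching how aget_relevant_output_modes formats it in the helpers change below; the modes dict mirrors mode_descriptions_for_llm and the chat history is illustrative.

# Illustrative only: mode_options is keyed by ConversationCommand value,
# exactly as aget_relevant_output_modes builds it from mode_descriptions_for_llm.
mode_options = {
    "image": "Use this if you think the user is requesting an image or visual response to their query.",
    "default": "Use this if the other response modes don't seem to fit the query.",
}
relevant_mode_prompt = pick_relevant_output_mode.format(
    query="Draw a picture of my trip to Jerusalem.",
    modes=str(mode_options),
    chat_history="User: I just visited Jerusalem for the first time.\nKhoj: ...\n",
)
# The model is expected to answer with a bare mode string, e.g. "image" or "default".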
pick_relevant_information_collection_tools = PromptTemplate.from_template(
"""
You are Khoj, a smart and helpful personal assistant. You have access to a variety of data sources to help you answer the user's question. You can use the data sources listed below to collect more relevant information. You can use any combination of these data sources to answer the user's question. Tell me which data sources you would like to use to answer the user's question.

@ -22,6 +22,7 @@ from khoj.routers.helpers import (
ConversationCommandRateLimiter,
agenerate_chat_response,
aget_relevant_information_sources,
aget_relevant_output_modes,
get_conversation_command,
is_ready_to_chat,
text_to_image,
@ -250,6 +251,9 @@ async def chat(
if conversation_commands == [ConversationCommand.Default]:
conversation_commands = await aget_relevant_information_sources(q, meta_log)
mode = await aget_relevant_output_modes(q, meta_log)
if mode not in conversation_commands:
conversation_commands.append(mode)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
@ -287,7 +291,8 @@ async def chat(
media_type="text/event-stream",
status_code=200,
)
elif conversation_commands == [ConversationCommand.Image]:
if ConversationCommand.Image in conversation_commands:
update_telemetry_state(
request=request,
telemetry_type="api",
@ -295,7 +300,9 @@ async def chat(
metadata={"conversation_command": conversation_commands[0].value},
**common.__dict__,
)
image, status_code, improved_image_prompt = await text_to_image(q, meta_log, location_data=location)
image, status_code, improved_image_prompt = await text_to_image(
q, meta_log, location_data=location, references=compiled_references, online_results=online_results
)
if image is None:
content_obj = {"image": image, "intentType": "text-to-image", "detail": improved_image_prompt}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
@ -308,8 +315,10 @@ async def chat(
inferred_queries=[improved_image_prompt],
client_application=request.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
)
content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]} # type: ignore
content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
# Get the (streamed) chat response from the LLM of choice.
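For reference, a rough sketch of the JSON payload this image branch now returns, which the desktop and web chat views above read; the keys match the content_obj constructed above, the values are placeholders.

# Placeholder values; keys match the content_obj built in this diff.
content_obj = {
    "image": "<base64-encoded image>",
    "intentType": "text-to-image",
    "inferredQueries": ["A watercolor view of the Old City of Jerusalem at dusk"],
    "context": ["Compiled note snippet used as a reference ..."],
    "online_results": {"old city of jerusalem": {"answerBox": "..."}},
}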

@ -38,6 +38,7 @@ from khoj.utils.helpers import (
ConversationCommand,
is_none_or_empty,
log_telemetry,
mode_descriptions_for_llm,
tool_descriptions_for_llm,
)
from khoj.utils.rawconfig import LocationData
@ -117,6 +118,9 @@ def construct_chat_history(conversation_history: dict, n: int = 4) -> str:
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"Khoj: {chat['message']}\n"
elif chat["by"] == "khoj" and chat["intent"].get("type") == "text-to-image":
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"Khoj: [generated image redacted for space]\n"
return chat_history
@ -185,6 +189,42 @@ async def aget_relevant_information_sources(query: str, conversation_history: di
return [ConversationCommand.Default]
async def aget_relevant_output_modes(query: str, conversation_history: dict):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
"""
mode_options = dict()
for mode, description in mode_descriptions_for_llm.items():
mode_options[mode.value] = description
chat_history = construct_chat_history(conversation_history)
relevant_mode_prompt = prompts.pick_relevant_output_mode.format(
query=query,
modes=str(mode_options),
chat_history=chat_history,
)
response = await send_message_to_model_wrapper(relevant_mode_prompt)
try:
response = response.strip()
if is_none_or_empty(response):
return ConversationCommand.Default
if response in mode_options.keys():
# Check whether the tool exists as a valid ConversationCommand
return ConversationCommand(response)
return ConversationCommand.Default
except Exception as e:
logger.error(f"Invalid response for determining relevant mode: {response}")
return ConversationCommand.Default
async def generate_online_subqueries(q: str, conversation_history: dict, location_data: LocationData) -> List[str]:
"""
Generate subqueries from the given query
@ -234,7 +274,13 @@ async def extract_relevant_info(q: str, corpus: dict) -> List[str]:
return response.strip()
async def generate_better_image_prompt(q: str, conversation_history: str, location_data: LocationData) -> str:
async def generate_better_image_prompt(
q: str,
conversation_history: str,
location_data: LocationData,
note_references: List[str],
online_results: Optional[dict] = None,
) -> str:
"""
Generate a better image prompt from the given query
"""
@ -242,11 +288,26 @@ async def generate_better_image_prompt(q: str, conversation_history: str, locati
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
location_prompt = prompts.user_location.format(location=location)
user_references = "\n\n".join([f"# {item}" for item in note_references])
simplified_online_results = {}
if online_results:
for result in online_results:
if online_results[result].get("answerBox"):
simplified_online_results[result] = online_results[result]["answerBox"]
elif online_results[result].get("extracted_content"):
simplified_online_results[result] = online_results[result]["extracted_content"]
image_prompt = prompts.image_generation_improve_prompt.format(
query=q,
chat_history=conversation_history,
location=location,
location=location_prompt,
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
)
response = await send_message_to_model_wrapper(image_prompt)
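To make the simplification above concrete, a small illustrative example; the exact shape of each online result depends on the search provider, so treat these values as placeholders.

# Illustrative input: only answerBox or extracted_content survives the loop above.
online_results = {
    "timbuktu landmarks": {"extracted_content": "Sankore Madrasah, Djinguereber Mosque, ..."},
    "timbuktu winter weather": {"answerBox": "Warm days, cool nights in December"},
}
# simplified_online_results would then be:
# {
#     "timbuktu landmarks": "Sankore Madrasah, Djinguereber Mosque, ...",
#     "timbuktu winter weather": "Warm days, cool nights in December",
# }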
@ -377,7 +438,11 @@ def generate_chat_response(
async def text_to_image(
message: str, conversation_log: dict, location_data: LocationData
message: str,
conversation_log: dict,
location_data: LocationData,
references: List[str],
online_results: Dict[str, Any],
) -> Tuple[Optional[str], int, Optional[str]]:
status_code = 200
image = None
@ -396,7 +461,13 @@ async def text_to_image(
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"
improved_image_prompt = await generate_better_image_prompt(message, chat_history, location_data=location_data)
improved_image_prompt = await generate_better_image_prompt(
message,
chat_history,
location_data=location_data,
note_references=references,
online_results=online_results,
)
try:
response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
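As a hedged note on how this response is consumed downstream, assuming the standard OpenAI images response shape for response_format="b64_json": the generated image comes back base64-encoded, which is what the chat views embed as a data URI.

# Assumption: standard OpenAI images response shape for response_format="b64_json".
image = response.data[0].b64_json
# The web and desktop chat views above then render it as
# `![](data:image/png;base64,${message})`.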

@ -289,6 +289,11 @@ tool_descriptions_for_llm = {
ConversationCommand.Online: "Use this when you would like to look up information on the internet",
}
mode_descriptions_for_llm = {
ConversationCommand.Image: "Use this if you think the user is requesting an image or visual response to their query.",
ConversationCommand.Default: "Use this if the other response modes don't seem to fit the query.",
}
def generate_random_name():
# List of adjectives and nouns to choose from

@ -24,6 +24,7 @@ from khoj.processor.conversation.offline.chat_model import (
)
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_output_modes
MODEL_NAME = "mistral-7b-instruct-v0.1.Q4_0.gguf"
@ -497,6 +498,34 @@ def test_filter_questions():
assert filtered_questions[0] == "Who is on the basketball team?"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_use_default_response_mode(client_offline_chat):
# Arrange
user_query = "What's the latest in the Israel/Palestine conflict?"
# Act
mode = await aget_relevant_output_modes(user_query, {})
# Assert
assert mode.value == "default"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_use_image_response_mode(client_offline_chat):
# Arrange
user_query = "Paint a picture of the scenery in Timbuktu in the winter"
# Act
mode = await aget_relevant_output_modes(user_query, {})
# Assert
assert mode.value == "image"
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):

@ -7,6 +7,7 @@ from freezegun import freeze_time
from khoj.processor.conversation.openai.gpt import converse, extract_questions
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_output_modes
# Initialize variables for tests
api_key = os.getenv("OPENAI_API_KEY")
@ -434,6 +435,34 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_use_default_response_mode(chat_client):
# Arrange
user_query = "What's the latest in the Israel/Palestine conflict?"
# Act
mode = await aget_relevant_output_modes(user_query, {})
# Assert
assert mode.value == "default"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_use_image_response_mode(chat_client):
# Arrange
user_query = "Paint a picture of the scenery in Timbuktu in the winter"
# Act
mode = await aget_relevant_output_modes(user_query, {})
# Assert
assert mode.value == "image"
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):

@ -8,7 +8,10 @@ from freezegun import freeze_time
from khoj.database.models import KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_information_sources
from khoj.routers.helpers import (
aget_relevant_information_sources,
aget_relevant_output_modes,
)
from tests.helpers import ConversationFactory
# Initialize variables for tests