Show better intermediate steps when responding to chat via web socket

- Show internet search, webpage read, image query, image generation steps
- Standardize, improve rendering of the intermediate steps on the web app

Benefits:
1. Improved transparency, allowing users to see what Khoj is doing behind
   the scenes and modify their query patterns to improve response quality
2. Reduced websocket connection keep-alive timeouts for long-running steps
This commit is contained in:
Debanjum Singh Solanky 2024-04-11 17:57:34 +05:30
parent fae7900f19
commit 997741119a
3 changed files with 54 additions and 22 deletions

View file

@@ -3,7 +3,7 @@ import json
import logging
import os
from collections import defaultdict
from typing import Dict, Tuple, Union
from typing import Callable, Dict, Optional, Tuple, Union
import aiohttp
import requests
@@ -42,7 +42,9 @@ OLOSTEP_QUERY_PARAMS = {
MAX_WEBPAGES_TO_READ = 1
async def search_online(query: str, conversation_history: dict, location: LocationData):
async def search_online(
query: str, conversation_history: dict, location: LocationData, send_status_func: Optional[Callable] = None
):
if not online_search_enabled():
logger.warn("SERPER_DEV_API_KEY is not set")
return {}
@@ -52,7 +54,9 @@ async def search_online(query: str, conversation_history: dict, location: Locati
response_dict = {}
for subquery in subqueries:
logger.info(f"Searching with Google for '{subquery}'")
if send_status_func:
await send_status_func(f"**🌐 Searching the Internet for**: {subquery}")
logger.info(f"🌐 Searching the Internet for '{subquery}'")
response_dict[subquery] = search_with_google(subquery)
# Gather distinct web pages from organic search results of each subquery without an instant answer
@@ -64,7 +68,10 @@ async def search_online(query: str, conversation_history: dict, location: Locati
}
# Read, extract relevant info from the retrieved web pages
logger.info(f"Reading web pages at: {webpage_links.keys()}")
if webpage_links:
logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}")
if send_status_func:
await send_status_func(f"**📖 Reading web pages**: {'\n- ' + '\n- '.join(list(webpage_links))}")
tasks = [read_webpage_and_extract_content(subquery, link) for link, subquery in webpage_links.items()]
results = await asyncio.gather(*tasks)
@@ -95,12 +102,18 @@ def search_with_google(subquery: str):
return extracted_search_result
async def read_webpages(query: str, conversation_history: dict, location: LocationData):
async def read_webpages(
query: str, conversation_history: dict, location: LocationData, send_status_func: Optional[Callable] = None
):
"Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read")
if send_status_func:
await send_status_func(f"**🧐 Inferring web pages to read**")
urls = await infer_webpage_urls(query, conversation_history, location)
logger.info(f"Reading web pages at: {urls}")
if send_status_func:
await send_status_func(f"**📖 Reading web pages**: {'\n- ' + '\n- '.join(list(urls))}")
tasks = [read_webpage_and_extract_content(query, url) for url in urls]
results = await asyncio.gather(*tasks)

View file

@@ -343,7 +343,7 @@ async def websocket_endpoint(
conversation_commands = [get_conversation_command(query=q, any_references=True)]
await send_status_update(f"**Processing query**: {q}")
await send_status_update(f"**👀 Understanding Query**: {q}")
if conversation_commands == [ConversationCommand.Help]:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
@@ -358,7 +358,11 @@ async def websocket_endpoint(
if conversation_commands == [ConversationCommand.Default]:
conversation_commands = await aget_relevant_information_sources(q, meta_log)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
await send_status_update(f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}")
mode = await aget_relevant_output_modes(q, meta_log)
await send_status_update(f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}")
if mode not in conversation_commands:
conversation_commands.append(mode)
@@ -366,17 +370,15 @@ async def websocket_endpoint(
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
await send_status_update(
f"**Using conversation commands:** {', '.join([cmd.value for cmd in conversation_commands])}"
)
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
websocket, None, meta_log, q, 7, 0.18, conversation_commands, location
)
if compiled_references:
headings = set([c.split("\n")[0] for c in compiled_references])
await send_status_update(f"**Searching references**: {headings}")
headings = "\n- " + "\n- ".join(
set([" ".join(c.split("Path: ")[1:]).split("\n ")[0] for c in compiled_references])
)
await send_status_update(f"**📜 Found Relevant Notes**: {headings}")
online_results: Dict = dict()
@@ -395,10 +397,7 @@ async def websocket_endpoint(
conversation_commands.append(ConversationCommand.Webpage)
else:
try:
await send_status_update("**Operation**: Searching the web for relevant information...")
online_results = await search_online(defiltered_query, meta_log, location)
online_searches = ", ".join([f"{query}" for query in online_results.keys()])
await send_status_update(f"**Online searches**: {online_searches}")
online_results = await search_online(defiltered_query, meta_log, location, send_status_update)
except ValueError as e:
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
await send_complete_llm_response(
@@ -408,13 +407,12 @@ async def websocket_endpoint(
if ConversationCommand.Webpage in conversation_commands:
try:
await send_status_update("**Operation**: Directly searching web pages...")
online_results = await read_webpages(defiltered_query, meta_log, location)
online_results = await read_webpages(defiltered_query, meta_log, location, send_status_update)
webpages = []
for query in online_results:
for webpage in online_results[query]["webpages"]:
webpages.append(webpage["link"])
await send_status_update(f"**Web pages read**: {webpages}")
await send_status_update(f"**📚 Read web pages**: {webpages}")
except ValueError as e:
logger.warning(
f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
@@ -427,10 +425,15 @@ async def websocket_endpoint(
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
await send_status_update("**Operation**: Augmenting your query and generating a superb image...")
intent_type = "text-to-image"
image, status_code, improved_image_prompt, image_url = await text_to_image(
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
q,
user,
meta_log,
location_data=location,
references=compiled_references,
online_results=online_results,
send_status_func=send_status_update,
)
if image is None or status_code != 200:
content_obj = {
@@ -462,6 +465,7 @@ async def websocket_endpoint(
await send_complete_llm_response(json.dumps(content_obj))
continue
await send_status_update(f"**💭 Generating a well-informed response**")
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,

View file

@@ -4,7 +4,17 @@ import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
from functools import partial
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import (
Annotated,
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
Union,
)
import openai
from fastapi import Depends, Header, HTTPException, Request, UploadFile
@@ -497,6 +507,7 @@ async def text_to_image(
location_data: LocationData,
references: List[str],
online_results: Dict[str, Any],
send_status_func: Optional[Callable] = None,
) -> Tuple[Optional[str], int, Optional[str], Optional[str]]:
status_code = 200
image = None
@@ -522,6 +533,8 @@ async def text_to_image(
chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n"
try:
with timer("Improve the original user query", logger):
if send_status_func:
await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
improved_image_prompt = await generate_better_image_prompt(
message,
chat_history,
@@ -530,6 +543,8 @@ async def text_to_image(
online_results=online_results,
)
with timer("Generate image with OpenAI", logger):
if send_status_func:
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
)