mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Show better intermediate steps when responding to chat via web socket
- Show internet search, webpage read, image query, image generation steps - Standardize, improve rendering of the intermediate steps on the web app Benefits: 1. Improved transparency, allow users to see what Khoj is doing behind the scenes and modify their query patterns to improve response quality 2. Reduced websocket connection keep alive timeouts for long running steps
This commit is contained in:
parent
fae7900f19
commit
997741119a
3 changed files with 54 additions and 22 deletions
|
@ -3,7 +3,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Tuple, Union
|
||||
from typing import Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
@ -42,7 +42,9 @@ OLOSTEP_QUERY_PARAMS = {
|
|||
MAX_WEBPAGES_TO_READ = 1
|
||||
|
||||
|
||||
async def search_online(query: str, conversation_history: dict, location: LocationData):
|
||||
async def search_online(
|
||||
query: str, conversation_history: dict, location: LocationData, send_status_func: Optional[Callable] = None
|
||||
):
|
||||
if not online_search_enabled():
|
||||
logger.warn("SERPER_DEV_API_KEY is not set")
|
||||
return {}
|
||||
|
@ -52,7 +54,9 @@ async def search_online(query: str, conversation_history: dict, location: Locati
|
|||
response_dict = {}
|
||||
|
||||
for subquery in subqueries:
|
||||
logger.info(f"Searching with Google for '{subquery}'")
|
||||
if send_status_func:
|
||||
await send_status_func(f"**🌐 Searching the Internet for**: {subquery}")
|
||||
logger.info(f"🌐 Searching the Internet for '{subquery}'")
|
||||
response_dict[subquery] = search_with_google(subquery)
|
||||
|
||||
# Gather distinct web pages from organic search results of each subquery without an instant answer
|
||||
|
@ -64,7 +68,10 @@ async def search_online(query: str, conversation_history: dict, location: Locati
|
|||
}
|
||||
|
||||
# Read, extract relevant info from the retrieved web pages
|
||||
logger.info(f"Reading web pages at: {webpage_links.keys()}")
|
||||
if webpage_links:
|
||||
logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}")
|
||||
if send_status_func:
|
||||
await send_status_func(f"**📖 Reading web pages**: {'\n- ' + '\n- '.join(list(webpage_links))}")
|
||||
tasks = [read_webpage_and_extract_content(subquery, link) for link, subquery in webpage_links.items()]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
|
@ -95,12 +102,18 @@ def search_with_google(subquery: str):
|
|||
return extracted_search_result
|
||||
|
||||
|
||||
async def read_webpages(query: str, conversation_history: dict, location: LocationData):
|
||||
async def read_webpages(
|
||||
query: str, conversation_history: dict, location: LocationData, send_status_func: Optional[Callable] = None
|
||||
):
|
||||
"Infer web pages to read from the query and extract relevant information from them"
|
||||
logger.info(f"Inferring web pages to read")
|
||||
if send_status_func:
|
||||
await send_status_func(f"**🧐 Inferring web pages to read**")
|
||||
urls = await infer_webpage_urls(query, conversation_history, location)
|
||||
|
||||
logger.info(f"Reading web pages at: {urls}")
|
||||
if send_status_func:
|
||||
await send_status_func(f"**📖 Reading web pages**: {'\n- ' + '\n- '.join(list(urls))}")
|
||||
tasks = [read_webpage_and_extract_content(query, url) for url in urls]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
|
|
|
@ -343,7 +343,7 @@ async def websocket_endpoint(
|
|||
|
||||
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
||||
|
||||
await send_status_update(f"**Processing query**: {q}")
|
||||
await send_status_update(f"**👀 Understanding Query**: {q}")
|
||||
|
||||
if conversation_commands == [ConversationCommand.Help]:
|
||||
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||
|
@ -358,7 +358,11 @@ async def websocket_endpoint(
|
|||
|
||||
if conversation_commands == [ConversationCommand.Default]:
|
||||
conversation_commands = await aget_relevant_information_sources(q, meta_log)
|
||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||
await send_status_update(f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}")
|
||||
|
||||
mode = await aget_relevant_output_modes(q, meta_log)
|
||||
await send_status_update(f"**🧑🏾💻 Decided Response Mode:** {mode.value}")
|
||||
if mode not in conversation_commands:
|
||||
conversation_commands.append(mode)
|
||||
|
||||
|
@ -366,17 +370,15 @@ async def websocket_endpoint(
|
|||
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
|
||||
q = q.replace(f"/{cmd.value}", "").strip()
|
||||
|
||||
await send_status_update(
|
||||
f"**Using conversation commands:** {', '.join([cmd.value for cmd in conversation_commands])}"
|
||||
)
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||
websocket, None, meta_log, q, 7, 0.18, conversation_commands, location
|
||||
)
|
||||
|
||||
if compiled_references:
|
||||
headings = set([c.split("\n")[0] for c in compiled_references])
|
||||
await send_status_update(f"**Searching references**: {headings}")
|
||||
headings = "\n- " + "\n- ".join(
|
||||
set([" ".join(c.split("Path: ")[1:]).split("\n ")[0] for c in compiled_references])
|
||||
)
|
||||
await send_status_update(f"**📜 Found Relevant Notes**: {headings}")
|
||||
|
||||
online_results: Dict = dict()
|
||||
|
||||
|
@ -395,10 +397,7 @@ async def websocket_endpoint(
|
|||
conversation_commands.append(ConversationCommand.Webpage)
|
||||
else:
|
||||
try:
|
||||
await send_status_update("**Operation**: Searching the web for relevant information...")
|
||||
online_results = await search_online(defiltered_query, meta_log, location)
|
||||
online_searches = ", ".join([f"{query}" for query in online_results.keys()])
|
||||
await send_status_update(f"**Online searches**: {online_searches}")
|
||||
online_results = await search_online(defiltered_query, meta_log, location, send_status_update)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
|
||||
await send_complete_llm_response(
|
||||
|
@ -408,13 +407,12 @@ async def websocket_endpoint(
|
|||
|
||||
if ConversationCommand.Webpage in conversation_commands:
|
||||
try:
|
||||
await send_status_update("**Operation**: Directly searching web pages...")
|
||||
online_results = await read_webpages(defiltered_query, meta_log, location)
|
||||
online_results = await read_webpages(defiltered_query, meta_log, location, send_status_update)
|
||||
webpages = []
|
||||
for query in online_results:
|
||||
for webpage in online_results[query]["webpages"]:
|
||||
webpages.append(webpage["link"])
|
||||
await send_status_update(f"**Web pages read**: {webpages}")
|
||||
await send_status_update(f"**📚 Read web pages**: {webpages}")
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
|
||||
|
@ -427,10 +425,15 @@ async def websocket_endpoint(
|
|||
api="chat",
|
||||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
)
|
||||
await send_status_update("**Operation**: Augmenting your query and generating a superb image...")
|
||||
intent_type = "text-to-image"
|
||||
image, status_code, improved_image_prompt, image_url = await text_to_image(
|
||||
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
|
||||
q,
|
||||
user,
|
||||
meta_log,
|
||||
location_data=location,
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
send_status_func=send_status_update,
|
||||
)
|
||||
if image is None or status_code != 200:
|
||||
content_obj = {
|
||||
|
@ -462,6 +465,7 @@ async def websocket_endpoint(
|
|||
await send_complete_llm_response(json.dumps(content_obj))
|
||||
continue
|
||||
|
||||
await send_status_update(f"**💭 Generating a well-informed response**")
|
||||
llm_response, chat_metadata = await agenerate_chat_response(
|
||||
defiltered_query,
|
||||
meta_log,
|
||||
|
|
|
@ -4,7 +4,17 @@ import logging
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from functools import partial
|
||||
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import openai
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
|
@ -497,6 +507,7 @@ async def text_to_image(
|
|||
location_data: LocationData,
|
||||
references: List[str],
|
||||
online_results: Dict[str, Any],
|
||||
send_status_func: Optional[Callable] = None,
|
||||
) -> Tuple[Optional[str], int, Optional[str], Optional[str]]:
|
||||
status_code = 200
|
||||
image = None
|
||||
|
@ -522,6 +533,8 @@ async def text_to_image(
|
|||
chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n"
|
||||
try:
|
||||
with timer("Improve the original user query", logger):
|
||||
if send_status_func:
|
||||
await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
|
||||
improved_image_prompt = await generate_better_image_prompt(
|
||||
message,
|
||||
chat_history,
|
||||
|
@ -530,6 +543,8 @@ async def text_to_image(
|
|||
online_results=online_results,
|
||||
)
|
||||
with timer("Generate image with OpenAI", logger):
|
||||
if send_status_func:
|
||||
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
|
||||
response = state.openai_client.images.generate(
|
||||
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue