From 7009793170fa598916b97f80540ab13a54d1fa43 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Sun, 3 Dec 2023 18:16:00 -0500
Subject: [PATCH 01/30] Migrate to OpenAI Python library >= 1.0

---
 pyproject.toml                               |  2 +-
 .../processor/conversation/openai/utils.py   | 20 +++++++++----------
 .../processor/conversation/openai/whisper.py |  7 ++++---
 src/khoj/utils/models.py                     |  4 ++--
 4 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 519d2d60..5a206cce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,7 @@ dependencies = [
     "fastapi >= 0.104.1",
     "python-multipart >= 0.0.5",
     "jinja2 == 3.1.2",
-    "openai >= 0.27.0, < 1.0.0",
+    "openai >= 1.0.0",
     "tiktoken >= 0.3.2",
     "tenacity >= 8.2.2",
     "pillow ~= 9.5.0",

diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 69fab7e5..a5868cb8 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -36,11 +36,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
 
 @retry(
     retry=(
-        retry_if_exception_type(openai.error.Timeout)
-        | retry_if_exception_type(openai.error.APIError)
-        | retry_if_exception_type(openai.error.APIConnectionError)
-        | retry_if_exception_type(openai.error.RateLimitError)
-        | retry_if_exception_type(openai.error.ServiceUnavailableError)
+        retry_if_exception_type(openai._exceptions.APITimeoutError)
+        | retry_if_exception_type(openai._exceptions.APIError)
+        | retry_if_exception_type(openai._exceptions.APIConnectionError)
+        | retry_if_exception_type(openai._exceptions.RateLimitError)
+        | retry_if_exception_type(openai._exceptions.APIStatusError)
     ),
     wait=wait_random_exponential(min=1, max=10),
     stop=stop_after_attempt(3),
@@ -57,11 +57,11 @@ def completion_with_backoff(**kwargs):
 
 @retry(
     retry=(
-        retry_if_exception_type(openai.error.Timeout)
-        | retry_if_exception_type(openai.error.APIError)
-        | retry_if_exception_type(openai.error.APIConnectionError)
-        | retry_if_exception_type(openai.error.RateLimitError)
-        | retry_if_exception_type(openai.error.ServiceUnavailableError)
+        retry_if_exception_type(openai._exceptions.APITimeoutError)
+        | retry_if_exception_type(openai._exceptions.APIError)
+        | retry_if_exception_type(openai._exceptions.APIConnectionError)
+        | retry_if_exception_type(openai._exceptions.RateLimitError)
+        | retry_if_exception_type(openai._exceptions.APIStatusError)
     ),
     wait=wait_exponential(multiplier=1, min=4, max=10),
     stop=stop_after_attempt(3),

diff --git a/src/khoj/processor/conversation/openai/whisper.py b/src/khoj/processor/conversation/openai/whisper.py
index 72834d92..351319b7 100644
--- a/src/khoj/processor/conversation/openai/whisper.py
+++ b/src/khoj/processor/conversation/openai/whisper.py
@@ -3,7 +3,7 @@ from io import BufferedReader
 
 # External Packages
 from asgiref.sync import sync_to_async
-import openai
+from openai import OpenAI
 
 
 async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
@@ -11,5 +11,6 @@ async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
     Transcribe audio file using Whisper model via OpenAI's API
     """
     # Send the audio data to the Whisper API
-    response = await sync_to_async(openai.Audio.translate)(model=model, file=audio_file, api_key=api_key)
-    return response["text"]
+    client = OpenAI(api_key=api_key)
+    response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file)
+    return response.text
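The v1 SDK replaces the module-global call style used above (openai.api_key, openai.Audio, openai.ChatCompletion) with an explicit client object that returns typed response objects instead of dicts. For orientation, a minimal sketch of the new call pattern; the API key and model name here are placeholders, not part of this patch:

    from openai import OpenAI

    # Replaces setting the openai.api_key module global in v0
    client = OpenAI(api_key="sk-placeholder")

    # v0: openai.ChatCompletion.create(...)  ->  v1: client.chat.completions.create(...)
    chat = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
    )

    # v1 responses are pydantic models: fields are attributes, not dict keys
    print(chat.choices[0].message.content)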
diff --git a/src/khoj/utils/models.py b/src/khoj/utils/models.py
index b5bbe292..37d09418 100644
--- a/src/khoj/utils/models.py
+++ b/src/khoj/utils/models.py
@@ -32,7 +32,7 @@ class OpenAI(BaseEncoder):
             raise Exception(
                 f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
             )
-        openai.api_key = state.processor_config.conversation.openai_model.api_key
+        self.openai_client = openai.OpenAI(api_key=state.processor_config.conversation.openai_model.api_key)
         self.embedding_dimensions = None
 
     def encode(self, entries, device=None, **kwargs):
@@ -43,7 +43,7 @@ class OpenAI(BaseEncoder):
             processed_entry = entries[index].replace("\n", " ")
 
             try:
-                response = openai.Embedding.create(input=processed_entry, model=self.model_name)
+                response = self.openai_client.embeddings.create(input=processed_entry, model=self.model_name)
                 embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)]
                 # Use current models embedding dimension, once available
                 # Else default to embedding dimensions of the text-embedding-ada-002 model

From 2b09caa237f302065fbc0d6ff14638cfc5638f2d Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Sun, 3 Dec 2023 19:13:28 -0500
Subject: [PATCH 02/30] Make online results an optional argument to the gpt
 converse method

---
 src/khoj/processor/conversation/openai/gpt.py | 2 +-
 src/khoj/routers/helpers.py                   | 4 ++--
 tests/test_openai_chat_director.py            | 4 ++--
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index fed110f7..dc708ab7 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -123,8 +123,8 @@ def send_message_to_model(
 
 def converse(
     references,
-    online_results,
     user_query,
+    online_results=[],
     conversation_log={},
     model: str = "gpt-3.5-turbo",
     api_key: Optional[str] = None,

diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 22b0d037..618f89ef 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -251,9 +251,9 @@ def generate_chat_response(
             chat_model = conversation_config.chat_model
             chat_response = converse(
                 compiled_references,
-                online_results,
                 q,
-                meta_log,
+                online_results=online_results,
+                conversation_log=meta_log,
                 model=chat_model,
                 api_key=api_key,
                 completion_func=partial_completion,

diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py
index 7b98d4da..50ad60af 100644
--- a/tests/test_openai_chat_director.py
+++ b/tests/test_openai_chat_director.py
@@ -68,10 +68,10 @@ def test_chat_with_online_content(chat_client):
     response_message = response_message.split("### compiled references")[0]
 
     # Assert
-    expected_responses = ["http://www.paulgraham.com/greatwork.html"]
+    expected_responses = ["http://www.paulgraham.com/greatwork.html", "Please set your SERPER_DEV_API_KEY"]
     assert response.status_code == 200
     assert any([expected_response in response_message for expected_response in expected_responses]), (
-        "Expected assistants name, [K|k]hoj, in response but got: " + response_message
+        "Expected online link or Serper API key not set error in response but got: " + response_message
    )

From 316b7d471a0344cbd47e779aa49f6ef431ed0f4f Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Mon, 4 Dec 2023 13:46:25 -0500
Subject: [PATCH 03/30] Handle offline chat model retrieval when no internet

Offline chat shouldn't fail on retrieve_model when there is no internet, if
the model was previously downloaded and is usable offline

---
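A quick way to sanity check the fallback: with the model file already on disk, download_model should now succeed even when the gpt4all config fetch fails for lack of a network connection, instead of crashing. A smoke-test sketch; the model name is illustrative:

    from khoj.processor.conversation.offline.utils import download_model

    # With no internet, retrieve_model raises. download_model should catch the
    # error, fall back to CPU and load the previously downloaded local model.
    chat_model = download_model("mistral-7b-instruct-v0.1.Q4_0.gguf")
    assert chat_model is not None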
 src/khoj/processor/conversation/offline/utils.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py
index 0b876b26..3a1862f7 100644
--- a/src/khoj/processor/conversation/offline/utils.py
+++ b/src/khoj/processor/conversation/offline/utils.py
@@ -12,11 +12,12 @@ def download_model(model_name: str):
         logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
         raise e
 
-    # Download the chat model
-    chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
-
-    # Decide whether to load model to GPU or CPU
+    chat_model_config = None
     try:
+        # Download the chat model and its config
+        chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
+
+        # Try to load the chat model on GPU if:
         # 1. Loading chat model to GPU isn't disabled via CLI and
         # 2. Machine has GPU
         device = (
        )
     except ValueError:
         device = "cpu"
+    except Exception as e:
+        if chat_model_config is None:
+            device = "cpu"  # Fallback to CPU as we can't determine if the GPU has enough memory
+            logger.debug(f"Unable to download model config from gpt4all website: {e}")
+        else:
+            raise e
 
     # Now load the downloaded chat model onto appropriate device
     chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)

From f0222f6d0828c7c4d6bdfc90b083bfa3596d8878 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Mon, 4 Dec 2023 18:54:22 -0500
Subject: [PATCH 04/30] Make save_to_conversation_log helper function reusable

- Move it out to conversation.utils from the generate_chat_response function
- Log new optional intent_type argument to capture the type of response
  generated by Khoj, e.g. speech, image. Clients can use it to render Khoj's
  responses appropriately
- Make the user_message_time argument optional; set it to the current time by
  default if not passed by the calling function

---
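With the helper extracted, any responder can log a completed exchange directly; a usage sketch, where user and meta_log come from the request context and the argument values are illustrative:

    from khoj.processor.conversation.utils import save_to_conversation_log

    save_to_conversation_log(
        q="paint a sunset over the sea",
        chat_response=generated_image,  # hypothetical variable holding the response payload
        user=user,
        meta_log=meta_log,
        intent_type="text-to-image",  # tells clients how to render this response
    )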
 src/khoj/processor/conversation/utils.py | 29 ++++++++++++++++++++++
 src/khoj/routers/helpers.py              | 31 +++---------------------
 2 files changed, 33 insertions(+), 27 deletions(-)

diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index e2164719..4606efd9 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -4,6 +4,7 @@ from time import perf_counter
 import json
 from datetime import datetime
 import queue
+from typing import Any, Dict, List, Optional
 import tiktoken
 
 # External packages
@@ -11,6 +12,8 @@ from langchain.schema import ChatMessage
 from transformers import AutoTokenizer
 
 # Internal Packages
+from khoj.database.adapters import ConversationAdapters
+from khoj.database.models import KhojUser
 from khoj.utils.helpers import merge_dicts
 
 
@@ -89,6 +92,32 @@ def message_to_log(
     return conversation_log
 
 
+def save_to_conversation_log(
+    q: str,
+    chat_response: str,
+    user: KhojUser,
+    meta_log: Dict,
+    user_message_time: Optional[str] = None,
+    compiled_references: List[str] = [],
+    online_results: Dict[str, Any] = {},
+    inferred_queries: List[str] = [],
+    intent_type: str = "remember",
+):
+    user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    updated_conversation = message_to_log(
+        user_message=q,
+        chat_response=chat_response,
+        user_message_metadata={"created": user_message_time},
+        khoj_message_metadata={
+            "context": compiled_references,
+            "intent": {"inferred-queries": inferred_queries, "type": intent_type},
+            "onlineContext": online_results,
+        },
+        conversation_log=meta_log.get("chat", []),
+    )
+    ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
+
+
 def generate_chatml_messages_with_context(
     user_message,
     system_message,

diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 618f89ef..3e8ed155 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -19,7 +19,7 @@ from khoj.database.models import KhojUser, Subscription
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
 from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
-from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log
+from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
 
 # Internal Packages
 from khoj.utils import state
@@ -186,30 +186,7 @@ def generate_chat_response(
     conversation_command: ConversationCommand = ConversationCommand.Default,
     user: KhojUser = None,
 ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
-    def _save_to_conversation_log(
-        q: str,
-        chat_response: str,
-        user_message_time: str,
-        compiled_references: List[str],
-        online_results: Dict[str, Any],
-        inferred_queries: List[str],
-        meta_log,
-    ):
-        updated_conversation = message_to_log(
-            user_message=q,
-            chat_response=chat_response,
-            user_message_metadata={"created": user_message_time},
-            khoj_message_metadata={
-                "context": compiled_references,
-                "intent": {"inferred-queries": inferred_queries},
-                "onlineContext": online_results,
-            },
-            conversation_log=meta_log.get("chat", []),
-        )
-        ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
-
     # Initialize Variables
-    user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     chat_response = None
     logger.debug(f"Conversation Type: {conversation_command.name}")
 
     try:
         partial_completion = partial(
-            _save_to_conversation_log,
+            save_to_conversation_log,
             q,
-            user_message_time=user_message_time,
+            user=user,
+            meta_log=meta_log,
             compiled_references=compiled_references,
             online_results=online_results,
             inferred_queries=inferred_queries,
-            meta_log=meta_log,
         )
 
         conversation_config = ConversationAdapters.get_valid_conversation_config(user)
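The partial(...) rewiring above is what lets a streamed response log itself once it finishes: everything except the final response text is pre-bound, and the completion callback is invoked with the response when streaming ends. In effect:

    from functools import partial

    # Pre-bind the logging context; only chat_response is left unbound
    partial_completion = partial(save_to_conversation_log, q, user=user, meta_log=meta_log)

    # Later, when the streamed response completes:
    partial_completion("the fully streamed chat response")  # fills chat_response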
From 1d9c1333f2881c849914bb970d4b184d3436fe0d Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Mon, 4 Dec 2023 17:54:30 -0500
Subject: [PATCH 05/30] Configure text to image models available on server

- Currently supports OpenAI text to image model, by default dall-e-3
- Allow setting the text to image model via CLI during server setup

---
 src/khoj/database/adapters/__init__.py        |  5 ++++
 .../migrations/0022_texttoimagemodelconfig.py | 25 +++++++++++++++++++
 src/khoj/database/models/__init__.py          |  8 ++++++
 src/khoj/utils/initialization.py              | 10 ++++++++
 4 files changed, 48 insertions(+)
 create mode 100644 src/khoj/database/migrations/0022_texttoimagemodelconfig.py

diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 12a127e9..f1f4031e 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -22,6 +22,7 @@ from khoj.database.models import (
     GithubConfig,
     GithubRepoConfig,
     GoogleUser,
+    TextToImageModelConfig,
     KhojApiUser,
     KhojUser,
     NotionConfig,
@@ -414,6 +415,10 @@ class ConversationAdapters:
         else:
             raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
 
+    @staticmethod
+    async def aget_text_to_image_model_config():
+        return await TextToImageModelConfig.objects.filter().afirst()
+
 
 class EntryAdapters:
     word_filer = WordFilter()

diff --git a/src/khoj/database/migrations/0022_texttoimagemodelconfig.py b/src/khoj/database/migrations/0022_texttoimagemodelconfig.py
new file mode 100644
index 00000000..7450dc40
--- /dev/null
+++ b/src/khoj/database/migrations/0022_texttoimagemodelconfig.py
@@ -0,0 +1,25 @@
+# Generated by Django 4.2.7 on 2023-12-04 22:17
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0021_speechtotextmodeloptions_and_more"),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name="TextToImageModelConfig",
+            fields=[
+                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+                ("created_at", models.DateTimeField(auto_now_add=True)),
+                ("updated_at", models.DateTimeField(auto_now=True)),
+                ("model_name", models.CharField(default="dall-e-3", max_length=200)),
+                ("model_type", models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200)),
+            ],
+            options={
+                "abstract": False,
+            },
+        ),
+    ]

diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 82348fbe..00700f2f 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -112,6 +112,14 @@ class SearchModelConfig(BaseModel):
     cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
 
 
+class TextToImageModelConfig(BaseModel):
+    class ModelType(models.TextChoices):
+        OPENAI = "openai"
+
+    model_name = models.CharField(max_length=200, default="dall-e-3")
+    model_type = 
models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
+
+
 class OpenAIProcessorConversationConfig(BaseModel):
     api_key = models.CharField(max_length=200)

diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py
index 313b18fc..0bb78dbe 100644
--- a/src/khoj/utils/initialization.py
+++ b/src/khoj/utils/initialization.py
@@ -7,6 +7,7 @@ from khoj.database.models import (
     OpenAIProcessorConversationConfig,
     ChatModelOptions,
     SpeechToTextModelOptions,
+    TextToImageModelConfig,
 )
 from khoj.utils.constants import default_offline_chat_model, default_online_chat_model
 
@@ -103,6 +104,15 @@ def initialization():
             model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
         )
 
+        default_text_to_image_model = "dall-e-3"
+        openai_text_to_image_model = input(
+            f"Enter the OpenAI text to image model you want to use (default: {default_text_to_image_model}): "
+        )
+        openai_text_to_image_model = openai_text_to_image_model or default_text_to_image_model
+        TextToImageModelConfig.objects.create(
+            model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI
+        )
+
     if use_offline_model == "y" or use_openai_model == "y":
         logger.info("🗣️ Chat model configuration complete")

From 252b35b2f01e16cba56f95e73f33e9fb7a99ef48 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Mon, 4 Dec 2023 17:58:04 -0500
Subject: [PATCH 06/30] Support /image slash command to generate images using
 the chat API

---
 src/khoj/routers/api.py     |  9 ++++++++-
 src/khoj/routers/helpers.py | 35 ++++++++++++++++++++++++++++++-----
 src/khoj/utils/helpers.py   |  2 ++
 3 files changed, 40 insertions(+), 6 deletions(-)

diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index ae125980..ae31c260 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -35,12 +35,14 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
 from khoj.processor.conversation.openai.gpt import extract_questions
 from khoj.processor.conversation.openai.whisper import transcribe_audio
 from khoj.processor.conversation.prompts import help_message, no_entries_found
+from khoj.processor.conversation.utils import save_to_conversation_log
 from khoj.processor.tools.online_search import search_with_google
 from khoj.routers.helpers import (
     ApiUserRateLimiter,
     CommonQueryParams,
     agenerate_chat_response,
     get_conversation_command,
+    text_to_image,
     is_ready_to_chat,
     update_telemetry_state,
     validate_conversation_config,
@@ -665,7 +667,7 @@ async def chat(
     rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
     rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
 ) -> Response:
-    user = request.user.object
+    user: KhojUser = request.user.object
     await is_ready_to_chat(user)
 
     conversation_command = get_conversation_command(query=q, any_references=True)
@@ -703,6 +705,11 @@ async def chat(
             media_type="text/event-stream",
             status_code=200,
         )
+    elif conversation_command == ConversationCommand.Image:
+        image_url, status_code = await text_to_image(q)
+        await sync_to_async(save_to_conversation_log)(q, image_url, user, meta_log, intent_type="text-to-image")
+        content_obj = {"imageUrl": image_url, "intentType": "text-to-image"}
+        return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
 
     # Get the (streamed) chat response from the LLM of choice.
     llm_response, chat_metadata = await agenerate_chat_response(
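End to end, the new command turns a chat query into a JSON response instead of a streamed one. A sketch of exercising it against a local server; host, port and authentication are illustrative:

    import requests

    # /image routes the query to the text_to_image helper instead of the chat LLM
    r = requests.get(
        "http://localhost:42110/api/chat",
        params={"q": "/image a watercolor lighthouse at dusk"},
    )
    print(r.json())  # e.g. {"imageUrl": "https://...", "intentType": "text-to-image"}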
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 3e8ed155..4e43289f 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -9,23 +9,23 @@
 from functools import partial
 from time import time
 from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
 
+# External Packages
 from fastapi import Depends, Header, HTTPException, Request, UploadFile
+import openai
 from starlette.authentication import has_required_scope
 
-from asgiref.sync import sync_to_async
-
+# Internal Packages
 from khoj.database.adapters import ConversationAdapters, EntryAdapters
-from khoj.database.models import KhojUser, Subscription
+from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
 from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
 from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
-
-# Internal Packages
 from khoj.utils import state
 from khoj.utils.config import GPT4AllProcessorModel
 from khoj.utils.helpers import ConversationCommand, log_telemetry
 
+
 logger = logging.getLogger(__name__)
 executor = ThreadPoolExecutor(max_workers=1)
@@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
         return ConversationCommand.General
     elif query.startswith("/online"):
         return ConversationCommand.Online
+    elif query.startswith("/image"):
+        return ConversationCommand.Image
     # If no relevant notes found for the given query
     elif not any_references:
         return ConversationCommand.General
@@ -248,6 +250,29 @@ def generate_chat_response(
     return chat_response, metadata
 
 
+async def text_to_image(message: str) -> Tuple[Optional[str], int]:
+    status_code = 200
+    image_url = None
+
+    # Fetch the configured text to image model
+    text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
+    openai_chat_config = await ConversationAdapters.get_openai_chat_config()
+    if not text_to_image_config:
+        # If the user has not configured a text to image model, return an unprocessable entity error
+        status_code = 422
+    elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
+        client = openai.OpenAI(api_key=openai_chat_config.api_key)
+        text2image_model = text_to_image_config.model_name
+        try:
+            response = client.images.generate(prompt=message, model=text2image_model)
+            image_url = response.data[0].url
+        except openai.OpenAIError as e:
+            logger.error(f"Image Generation failed with: {e}")
+            status_code = 500
+
+    return image_url, status_code
+
+
 class ApiUserRateLimiter:
     def __init__(self, requests: int, subscribed_requests: int, window: int):
         self.requests = requests

diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index 42e3835d..21fe7e98 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -273,6 +273,7 @@ class ConversationCommand(str, Enum):
     Notes = "notes"
     Help = "help"
     Online = "online"
+    Image = "image"
 
 
 command_descriptions = {
@@ -280,6 +281,7 @@ command_descriptions = {
     ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
     ConversationCommand.Default: "The default command when no command specified. 
It intelligently auto-switches between general and notes mode.", ConversationCommand.Online: "Look up information on the internet.", + ConversationCommand.Image: "Generate images by describing your imagination in words.", ConversationCommand.Help: "Display a help message with all available commands and other metadata.", } From cc051ceb4be756969a41b0431321b194d5c58490 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 4 Dec 2023 20:40:54 -0500 Subject: [PATCH 07/30] Show generated images in chat interface on Web client --- src/khoj/interface/web/chat.html | 43 +++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index a55d8e29..39cb6e77 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -183,12 +183,18 @@ To get started, just start typing below. You can also type / to see a list of co referenceSection.appendChild(polishedReference); } } - } + } return numOnlineReferences; - } + } + + function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) { + if (intentType === "text-to-image") { + let imageMarkdown = `![](${message})`; + renderMessage(imageMarkdown, by, dt); + return; + } - function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) { if (context == null && onlineContext == null) { renderMessage(message, by, dt); return; @@ -253,6 +259,26 @@ To get started, just start typing below. You can also type / to see a list of co // Remove any text between [INST] and tags. These are spurious instructions for the AI chat model. newHTML = newHTML.replace(/\[INST\].+(<\/s>)?/g, ''); + // Customize the rendering of images + md.renderer.rules.image = function(tokens, idx, options, env, self) { + let token = tokens[idx]; + + // Get image source url. Only render images with src links + let srcIndex = token.attrIndex('src'); + if (srcIndex < 0) { return ''; } + let src = token.attrs[srcIndex][1]; + + // Wrap the image in a link + var aStart = ``; + var aEnd = ''; + + // Add class="text-to-image" to images + token.attrPush(['class', 'text-to-image']); + + // Use the default renderer to render image markdown format + return aStart + self.renderToken(tokens, idx, options) + aEnd; + }; + // Render markdown newHTML = md.render(newHTML); // Get any elements with a class that starts with "language" @@ -414,6 +440,9 @@ To get started, just start typing below. You can also type / to see a list of co if (chunk.startsWith("{") && chunk.endsWith("}")) { try { const responseAsJson = JSON.parse(chunk); + if (responseAsJson.imageUrl) { + rawResponse += `![${query}](${responseAsJson.imageUrl})`; + } if (responseAsJson.detail) { rawResponse += responseAsJson.detail; } @@ -516,7 +545,7 @@ To get started, just start typing below. You can also type / to see a list of co .then(response => { // Render conversation history, if any response.forEach(chat_log => { - renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext); + renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type); }); }) .catch(err => { @@ -902,6 +931,9 @@ To get started, just start typing below. 
You can also type / to see a list of co margin-top: -10px; transform: rotate(-60deg) } + img.text-to-image { + max-width: 60%; + } #chat-footer { padding: 0; @@ -1029,6 +1061,9 @@ To get started, just start typing below. You can also type / to see a list of co margin: 4px; grid-template-columns: auto; } + img.text-to-image { + max-width: 100%; + } } @media only screen and (min-width: 700px) { body { From 8016a57b5e662dbe6f86f0013c7359c1cc1f5cf1 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 4 Dec 2023 20:56:35 -0500 Subject: [PATCH 08/30] Show generated images in chat interface on Desktop client --- src/interface/desktop/chat.html | 47 ++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 8c3fc49d..92a11ebd 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -179,7 +179,13 @@ return numOnlineReferences; } - function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) { + function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) { + if (intentType === "text-to-image") { + let imageMarkdown = `![](${message})`; + renderMessage(imageMarkdown, by, dt); + return; + } + if (context == null && onlineContext == null) { renderMessage(message, by, dt); return; @@ -244,6 +250,26 @@ // Remove any text between [INST] and tags. These are spurious instructions for the AI chat model. newHTML = newHTML.replace(/\[INST\].+(<\/s>)?/g, ''); + // Customize the rendering of images + md.renderer.rules.image = function(tokens, idx, options, env, self) { + let token = tokens[idx]; + + // Get image source url. Only render images with src links + let srcIndex = token.attrIndex('src'); + if (srcIndex < 0) { return ''; } + let src = token.attrs[srcIndex][1]; + + // Wrap the image in a link + var aStart = ``; + var aEnd = ''; + + // Add class="text-to-image" to images + token.attrPush(['class', 'text-to-image']); + + // Use the default renderer to render image markdown format + return aStart + self.renderToken(tokens, idx, options) + aEnd; + }; + // Render markdown newHTML = md.render(newHTML); // Get any elements with a class that starts with "language" @@ -404,16 +430,23 @@ if (newResponseText.getElementsByClassName("spinner").length > 0) { newResponseText.removeChild(loadingSpinner); } + // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. 
if (chunk.startsWith("{") && chunk.endsWith("}")) { try { const responseAsJson = JSON.parse(chunk); + if (responseAsJson.imageUrl) { + rawResponse += `![${query}](${responseAsJson.imageUrl})`; + } if (responseAsJson.detail) { - newResponseText.innerHTML += responseAsJson.detail; + rawResponse += responseAsJson.detail; } } catch (error) { // If the chunk is not a JSON object, just display it as is - newResponseText.innerHTML += chunk; + rawResponse += chunk; + } finally { + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); } } else { // If the chunk is not a JSON object, just display it as is @@ -522,7 +555,7 @@ .then(response => { // Render conversation history, if any response.forEach(chat_log => { - renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext); + renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type); }); }) .catch(err => { @@ -810,6 +843,9 @@ margin-top: -10px; transform: rotate(-60deg) } + img.text-to-image { + max-width: 60%; + } #chat-footer { padding: 0; @@ -1050,6 +1086,9 @@ margin: 4px; grid-template-columns: auto; } + img.text-to-image { + max-width: 100%; + } } @media only screen and (min-width: 600px) { body { From 52c5f4170a9b9cbab37180adadf0e73cd22b7444 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 4 Dec 2023 21:23:31 -0500 Subject: [PATCH 09/30] Show generated images in the chat modal of the Khoj Obsidian plugin --- src/interface/obsidian/src/chat_modal.ts | 31 ++++++++++++++++++++---- src/interface/obsidian/styles.css | 4 ++- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts index e5cff4d2..9786e45a 100644 --- a/src/interface/obsidian/src/chat_modal.ts +++ b/src/interface/obsidian/src/chat_modal.ts @@ -105,15 +105,19 @@ export class KhojChatModal extends Modal { return referenceButton; } - renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date) { + renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string) { if (!message) { return; + } else if (intentType === "text-to-image") { + let imageMarkdown = `![](${message})`; + this.renderMessage(chatEl, imageMarkdown, sender, dt); + return; } else if (!context) { this.renderMessage(chatEl, message, sender, dt); - return + return; } else if (!!context && context?.length === 0) { this.renderMessage(chatEl, message, sender, dt); - return + return; } let chatMessageEl = this.renderMessage(chatEl, message, sender, dt); let chatMessageBodyEl = chatMessageEl.getElementsByClassName("khoj-chat-message-text")[0] @@ -225,7 +229,7 @@ export class KhojChatModal extends Modal { let response = await request({ url: chatUrl, headers: headers }); let chatLogs = JSON.parse(response).response; chatLogs.forEach((chatLog: any) => { - this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created)); + this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created), chatLog.intent?.type); }); } @@ -267,7 +271,7 @@ export class KhojChatModal extends Modal { this.result = ""; responseElement.innerHTML = ""; for await (const chunk of response.body) { - const responseText = chunk.toString(); + let 
responseText = chunk.toString();
             if (responseText.includes("### compiled references:")) {
                 const additionalResponse = responseText.split("### compiled references:")[0];
                 this.renderIncrementalMessage(responseElement, additionalResponse);
@@ -310,6 +314,23 @@ export class KhojChatModal extends Modal {
                 referenceExpandButton.innerHTML = expandButtonText;
                 references.appendChild(referenceSection);
             } else {
+                if (responseText.startsWith("{") && responseText.endsWith("}")) {
+                    try {
+                        const responseAsJson = JSON.parse(responseText);
+                        if (responseAsJson.imageUrl) {
+                            responseText = `![${query}](${responseAsJson.imageUrl})`;
+                        } else if (responseAsJson.detail) {
+                            responseText = responseAsJson.detail;
+                        }
+                    } catch (error) {
+                        // If the chunk is not a valid JSON object, skip it
+                        continue;
+                    }
+                } else {
+                    // If the chunk is not a JSON object, skip rendering it
+                    continue;
+                }
+
                 this.renderIncrementalMessage(responseElement, responseText);
             }
         }

diff --git a/src/interface/obsidian/styles.css b/src/interface/obsidian/styles.css
index 6a29e722..11fc0086 100644
--- a/src/interface/obsidian/styles.css
+++ b/src/interface/obsidian/styles.css
@@ -217,7 +217,9 @@ button.copy-button:hover {
     background: #f5f5f5;
     cursor: pointer;
 }
-
+img {
+    max-width: 60%;
+}
 
 #khoj-chat-footer {
     padding: 0;

From 6e3f66c0f12df8b38763f89b5576b679a1e52191 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Mon, 4 Dec 2023 22:55:22 -0500
Subject: [PATCH 10/30] Use base64 encoded image instead of source URL for
 persistence

The source URL returned by OpenAI would expire soon. This would make the
chat sessions contain non-accessible images/messages if using the OpenAI
image URL

Get base64 encoded image from OpenAI and store directly in conversation
logs. This resolves the image link expiring issue

---
 src/interface/desktop/chat.html          | 17 ++++-------------
 src/interface/obsidian/src/chat_modal.ts |  6 +++---
 src/khoj/interface/web/chat.html         | 17 ++++-------------
 src/khoj/routers/api.py                  |  6 +++---
 src/khoj/routers/helpers.py              |  8 ++++----
 5 files changed, 18 insertions(+), 36 deletions(-)

diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html
index 92a11ebd..e039c6cb 100644
--- a/src/interface/desktop/chat.html
+++ b/src/interface/desktop/chat.html
@@ -181,7 +181,7 @@
         function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
             if (intentType === "text-to-image") {
-                let imageMarkdown = `![](${message})`;
+                let imageMarkdown = `![](data:image/png;base64,${message})`;
                 renderMessage(imageMarkdown, by, dt);
                 return;
             }
@@ -254,20 +254,11 @@
             md.renderer.rules.image = function(tokens, idx, options, env, self) {
                 let token = tokens[idx];
 
-                // Get image source url. 
Only render images with src links - let srcIndex = token.attrIndex('src'); - if (srcIndex < 0) { return ''; } - let src = token.attrs[srcIndex][1]; - - // Wrap the image in a link - var aStart = ``; - var aEnd = ''; - // Add class="text-to-image" to images token.attrPush(['class', 'text-to-image']); // Use the default renderer to render image markdown format - return aStart + self.renderToken(tokens, idx, options) + aEnd; + return self.renderToken(tokens, idx, options); }; // Render markdown @@ -435,8 +426,8 @@ if (chunk.startsWith("{") && chunk.endsWith("}")) { try { const responseAsJson = JSON.parse(chunk); - if (responseAsJson.imageUrl) { - rawResponse += `![${query}](${responseAsJson.imageUrl})`; + if (responseAsJson.image) { + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; } if (responseAsJson.detail) { rawResponse += responseAsJson.detail; diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts index 9786e45a..145bae50 100644 --- a/src/interface/obsidian/src/chat_modal.ts +++ b/src/interface/obsidian/src/chat_modal.ts @@ -109,7 +109,7 @@ export class KhojChatModal extends Modal { if (!message) { return; } else if (intentType === "text-to-image") { - let imageMarkdown = `![](${message})`; + let imageMarkdown = `![](data:image/png;base64,${message})`; this.renderMessage(chatEl, imageMarkdown, sender, dt); return; } else if (!context) { @@ -317,8 +317,8 @@ export class KhojChatModal extends Modal { if (responseText.startsWith("{") && responseText.endsWith("}")) { try { const responseAsJson = JSON.parse(responseText); - if (responseAsJson.imageUrl) { - responseText = `![${query}](${responseAsJson.imageUrl})`; + if (responseAsJson.image) { + responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`; } else if (responseAsJson.detail) { responseText = responseAsJson.detail; } diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 39cb6e77..97fdbebb 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -190,7 +190,7 @@ To get started, just start typing below. You can also type / to see a list of co function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) { if (intentType === "text-to-image") { - let imageMarkdown = `![](${message})`; + let imageMarkdown = `![](data:image/png;base64,${message})`; renderMessage(imageMarkdown, by, dt); return; } @@ -263,20 +263,11 @@ To get started, just start typing below. You can also type / to see a list of co md.renderer.rules.image = function(tokens, idx, options, env, self) { let token = tokens[idx]; - // Get image source url. Only render images with src links - let srcIndex = token.attrIndex('src'); - if (srcIndex < 0) { return ''; } - let src = token.attrs[srcIndex][1]; - - // Wrap the image in a link - var aStart = ``; - var aEnd = ''; - // Add class="text-to-image" to images token.attrPush(['class', 'text-to-image']); // Use the default renderer to render image markdown format - return aStart + self.renderToken(tokens, idx, options) + aEnd; + return self.renderToken(tokens, idx, options); }; // Render markdown @@ -440,8 +431,8 @@ To get started, just start typing below. 
You can also type / to see a list of co
         function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
             if (intentType === "text-to-image") {
-                let imageMarkdown = `![](${message})`;
+                let imageMarkdown = `![](data:image/png;base64,${message})`;
                 renderMessage(imageMarkdown, by, dt);
                 return;
             }
@@ -263,20 +263,11 @@ To get started, just start typing below. You can also type / to see a list of co
             md.renderer.rules.image = function(tokens, idx, options, env, self) {
                 let token = tokens[idx];
 
-                // Get image source url. Only render images with src links
-                let srcIndex = token.attrIndex('src');
-                if (srcIndex < 0) { return ''; }
-                let src = token.attrs[srcIndex][1];
-
-                // Wrap the image in a link
-                var aStart = ``;
-                var aEnd = '';
-
                 // Add class="text-to-image" to images
                 token.attrPush(['class', 'text-to-image']);
 
                 // Use the default renderer to render image markdown format
-                return aStart + self.renderToken(tokens, idx, options) + aEnd;
+                return self.renderToken(tokens, idx, options);
             };
@@ -440,8 +431,8 @@ To get started, just start typing below. You can also type / to see a list of co
                     if (chunk.startsWith("{") && chunk.endsWith("}")) {
                         try {
                             const responseAsJson = JSON.parse(chunk);
-                            if (responseAsJson.imageUrl) {
-                                rawResponse += `![${query}](${responseAsJson.imageUrl})`;
+                            if (responseAsJson.image) {
+                                rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
                             }
                             if (responseAsJson.detail) {
                                 rawResponse += responseAsJson.detail;
                             }

diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index ae31c260..d53f023a 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -706,9 +706,9 @@ async def chat(
             status_code=200,
         )
     elif conversation_command == ConversationCommand.Image:
-        image_url, status_code = await text_to_image(q)
-        await sync_to_async(save_to_conversation_log)(q, image_url, user, meta_log, intent_type="text-to-image")
-        content_obj = {"imageUrl": image_url, "intentType": "text-to-image"}
+        image, status_code = await text_to_image(q)
+        await sync_to_async(save_to_conversation_log)(q, image, user, meta_log, intent_type="text-to-image")
+        content_obj = {"image": image, "intentType": "text-to-image"}
         return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
 
     # Get the (streamed) chat response from the LLM of choice.

diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 4e43289f..f34ae815 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -252,7 +252,7 @@ def generate_chat_response(
     return chat_response, metadata
 
 
 async def text_to_image(message: str) -> Tuple[Optional[str], int]:
     status_code = 200
-    image_url = None
+    image = None
 
     # Fetch the configured text to image model
     text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
@@ -264,13 +264,13 @@ async def text_to_image(message: str) -> Tuple[Optional[str], int]:
         client = openai.OpenAI(api_key=openai_chat_config.api_key)
         text2image_model = text_to_image_config.model_name
         try:
-            response = client.images.generate(prompt=message, model=text2image_model)
-            image_url = response.data[0].url
+            response = client.images.generate(prompt=message, model=text2image_model, response_format="b64_json")
+            image = response.data[0].b64_json
         except openai.OpenAIError as e:
             logger.error(f"Image Generation failed with: {e}")
             status_code = 500
 
-    return image_url, status_code
+    return image, status_code

From d124266923076135909e0055e9eb725fe7d78591 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Tue, 5 Dec 2023 00:41:16 -0500
Subject: [PATCH 11/30] Reduce promise based nesting in chat JS func used in
 desktop, web client

Use async/await to reduce .then() based nesting to improve code readability

---
 src/interface/desktop/chat.html  | 196 +++++++++++++++---------------
 src/khoj/interface/web/chat.html | 200 +++++++++++++++----------------
 2 files changed, 196 insertions(+), 200 deletions(-)

diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html
index e039c6cb..279397bd 100644
--- a/src/interface/desktop/chat.html
+++ b/src/interface/desktop/chat.html
@@ -346,115 +346,113 @@
         chatInput.classList.remove("option-enabled");
 
         // Call specified Khoj API which returns a streamed response of type text/plain
-        fetch(url, { headers })
-            .then(response => {
-                const reader = response.body.getReader();
-                const decoder = new TextDecoder();
-                let rawResponse = "";
-                let references = null;
+        let response = await fetch(url, { headers });
+        const reader = response.body.getReader();
+        const decoder = new 
TextDecoder(); + let rawResponse = ""; + let references = null; - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - if (references != null) { - newResponseText.appendChild(references); - } - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input").removeAttribute("disabled"); - return; + function readStream() { + reader.read().then(({ done, value }) => { + if (done) { + // Append any references after all the data has been streamed + if (references != null) { + newResponseText.appendChild(references); + } + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + document.getElementById("chat-input").removeAttribute("disabled"); + return; + } + + // Decode message chunk from stream + const chunk = decoder.decode(value, { stream: true }); + + if (chunk.includes("### compiled references:")) { + const additionalResponse = chunk.split("### compiled references:")[0]; + rawResponse += additionalResponse; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + + const rawReference = chunk.split("### compiled references:")[1]; + const rawReferenceAsJson = JSON.parse(rawReference); + references = document.createElement('div'); + references.classList.add("references"); + + let referenceExpandButton = document.createElement('button'); + referenceExpandButton.classList.add("reference-expand-button"); + + let referenceSection = document.createElement('div'); + referenceSection.classList.add("reference-section"); + referenceSection.classList.add("collapsed"); + + let numReferences = 0; + + // If rawReferenceAsJson is a list, then count the length + if (Array.isArray(rawReferenceAsJson)) { + numReferences = rawReferenceAsJson.length; + + rawReferenceAsJson.forEach((reference, index) => { + let polishedReference = generateReference(reference, index); + referenceSection.appendChild(polishedReference); + }); + } else { + numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); + } + + references.appendChild(referenceExpandButton); + + referenceExpandButton.addEventListener('click', function() { + if (referenceSection.classList.contains("collapsed")) { + referenceSection.classList.remove("collapsed"); + referenceSection.classList.add("expanded"); + } else { + referenceSection.classList.add("collapsed"); + referenceSection.classList.remove("expanded"); } + }); - // Decode message chunk from stream - const chunk = decoder.decode(value, { stream: true }); + let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; + referenceExpandButton.innerHTML = expandButtonText; + references.appendChild(referenceSection); + readStream(); + } else { + // Display response from Khoj + if (newResponseText.getElementsByClassName("spinner").length > 0) { + newResponseText.removeChild(loadingSpinner); + } - if (chunk.includes("### compiled references:")) { - const additionalResponse = chunk.split("### compiled references:")[0]; - rawResponse += additionalResponse; + // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. 
+ if (chunk.startsWith("{") && chunk.endsWith("}")) { + try { + const responseAsJson = JSON.parse(chunk); + if (responseAsJson.image) { + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + } + if (responseAsJson.detail) { + rawResponse += responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + } finally { newResponseText.innerHTML = ""; newResponseText.appendChild(formatHTMLMessage(rawResponse)); - - const rawReference = chunk.split("### compiled references:")[1]; - const rawReferenceAsJson = JSON.parse(rawReference); - references = document.createElement('div'); - references.classList.add("references"); - - let referenceExpandButton = document.createElement('button'); - referenceExpandButton.classList.add("reference-expand-button"); - - let referenceSection = document.createElement('div'); - referenceSection.classList.add("reference-section"); - referenceSection.classList.add("collapsed"); - - let numReferences = 0; - - // If rawReferenceAsJson is a list, then count the length - if (Array.isArray(rawReferenceAsJson)) { - numReferences = rawReferenceAsJson.length; - - rawReferenceAsJson.forEach((reference, index) => { - let polishedReference = generateReference(reference, index); - referenceSection.appendChild(polishedReference); - }); - } else { - numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); - } - - references.appendChild(referenceExpandButton); - - referenceExpandButton.addEventListener('click', function() { - if (referenceSection.classList.contains("collapsed")) { - referenceSection.classList.remove("collapsed"); - referenceSection.classList.add("expanded"); - } else { - referenceSection.classList.add("collapsed"); - referenceSection.classList.remove("expanded"); - } - }); - - let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; - referenceExpandButton.innerHTML = expandButtonText; - references.appendChild(referenceSection); - readStream(); - } else { - // Display response from Khoj - if (newResponseText.getElementsByClassName("spinner").length > 0) { - newResponseText.removeChild(loadingSpinner); - } - - // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. 
- if (chunk.startsWith("{") && chunk.endsWith("}")) { - try { - const responseAsJson = JSON.parse(chunk); - if (responseAsJson.image) { - rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; - } - if (responseAsJson.detail) { - rawResponse += responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - } - } else { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - - readStream(); - } } + } else { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }); + readStream(); + } } - readStream(); + + // Scroll to bottom of chat window as chat response is streamed + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; }); + } + readStream(); } function incrementalChat(event) { diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 97fdbebb..4a406a6b 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -309,7 +309,7 @@ To get started, just start typing below. You can also type / to see a list of co return element } - function chat() { + async function chat() { // Extract required fields for search from form let query = document.getElementById("chat-input").value.trim(); let resultsCount = localStorage.getItem("khojResultsCount") || 5; @@ -351,115 +351,113 @@ To get started, just start typing below. 
You can also type / to see a list of co chatInput.classList.remove("option-enabled"); // Call specified Khoj API which returns a streamed response of type text/plain - fetch(url) - .then(response => { - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let rawResponse = ""; - let references = null; + let response = await fetch(url); + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let rawResponse = ""; + let references = null; - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - if (references != null) { - newResponseText.appendChild(references); - } - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input").removeAttribute("disabled"); - return; + function readStream() { + reader.read().then(({ done, value }) => { + if (done) { + // Append any references after all the data has been streamed + if (references != null) { + newResponseText.appendChild(references); + } + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + document.getElementById("chat-input").removeAttribute("disabled"); + return; + } + + // Decode message chunk from stream + const chunk = decoder.decode(value, { stream: true }); + + if (chunk.includes("### compiled references:")) { + const additionalResponse = chunk.split("### compiled references:")[0]; + rawResponse += additionalResponse; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + + const rawReference = chunk.split("### compiled references:")[1]; + const rawReferenceAsJson = JSON.parse(rawReference); + references = document.createElement('div'); + references.classList.add("references"); + + let referenceExpandButton = document.createElement('button'); + referenceExpandButton.classList.add("reference-expand-button"); + + let referenceSection = document.createElement('div'); + referenceSection.classList.add("reference-section"); + referenceSection.classList.add("collapsed"); + + let numReferences = 0; + + // If rawReferenceAsJson is a list, then count the length + if (Array.isArray(rawReferenceAsJson)) { + numReferences = rawReferenceAsJson.length; + + rawReferenceAsJson.forEach((reference, index) => { + let polishedReference = generateReference(reference, index); + referenceSection.appendChild(polishedReference); + }); + } else { + numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); + } + + references.appendChild(referenceExpandButton); + + referenceExpandButton.addEventListener('click', function() { + if (referenceSection.classList.contains("collapsed")) { + referenceSection.classList.remove("collapsed"); + referenceSection.classList.add("expanded"); + } else { + referenceSection.classList.add("collapsed"); + referenceSection.classList.remove("expanded"); } + }); - // Decode message chunk from stream - const chunk = decoder.decode(value, { stream: true }); + let expandButtonText = numReferences == 1 ? 
"1 reference" : `${numReferences} references`; + referenceExpandButton.innerHTML = expandButtonText; + references.appendChild(referenceSection); + readStream(); + } else { + // Display response from Khoj + if (newResponseText.getElementsByClassName("spinner").length > 0) { + newResponseText.removeChild(loadingSpinner); + } - if (chunk.includes("### compiled references:")) { - const additionalResponse = chunk.split("### compiled references:")[0]; - rawResponse += additionalResponse; + // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. + if (chunk.startsWith("{") && chunk.endsWith("}")) { + try { + const responseAsJson = JSON.parse(chunk); + if (responseAsJson.image) { + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + } + if (responseAsJson.detail) { + rawResponse += responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + } finally { newResponseText.innerHTML = ""; newResponseText.appendChild(formatHTMLMessage(rawResponse)); - - const rawReference = chunk.split("### compiled references:")[1]; - const rawReferenceAsJson = JSON.parse(rawReference); - references = document.createElement('div'); - references.classList.add("references"); - - let referenceExpandButton = document.createElement('button'); - referenceExpandButton.classList.add("reference-expand-button"); - - let referenceSection = document.createElement('div'); - referenceSection.classList.add("reference-section"); - referenceSection.classList.add("collapsed"); - - let numReferences = 0; - - // If rawReferenceAsJson is a list, then count the length - if (Array.isArray(rawReferenceAsJson)) { - numReferences = rawReferenceAsJson.length; - - rawReferenceAsJson.forEach((reference, index) => { - let polishedReference = generateReference(reference, index); - referenceSection.appendChild(polishedReference); - }); - } else { - numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); - } - - references.appendChild(referenceExpandButton); - - referenceExpandButton.addEventListener('click', function() { - if (referenceSection.classList.contains("collapsed")) { - referenceSection.classList.remove("collapsed"); - referenceSection.classList.add("expanded"); - } else { - referenceSection.classList.add("collapsed"); - referenceSection.classList.remove("expanded"); - } - }); - - let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; - referenceExpandButton.innerHTML = expandButtonText; - references.appendChild(referenceSection); - readStream(); - } else { - // Display response from Khoj - if (newResponseText.getElementsByClassName("spinner").length > 0) { - newResponseText.removeChild(loadingSpinner); - } - - // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. 
- if (chunk.startsWith("{") && chunk.endsWith("}")) { - try { - const responseAsJson = JSON.parse(chunk); - if (responseAsJson.image) { - rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; - } - if (responseAsJson.detail) { - rawResponse += responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - } - } else { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - readStream(); - } } - - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }); + } else { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + readStream(); + } } - readStream(); + + // Scroll to bottom of chat window as chat response is streamed + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; }); - } + } + readStream(); + }; function incrementalChat(event) { if (!event.shiftKey && event.key === 'Enter') { From 8f2f053968623469e34f88d35cb076827b79fe31 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 5 Dec 2023 01:03:52 -0500 Subject: [PATCH 12/30] Fix rendering image on chat response in web, desktop client --- src/interface/desktop/chat.html | 214 +++++++++++++++++-------------- src/khoj/interface/web/chat.html | 171 ++++++++++++------------ 2 files changed, 211 insertions(+), 174 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 279397bd..cad7971f 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -345,114 +345,142 @@ let chatInput = document.getElementById("chat-input"); chatInput.classList.remove("option-enabled"); - // Call specified Khoj API which returns a streamed response of type text/plain + // Call specified Khoj API let response = await fetch(url, { headers }); - const reader = response.body.getReader(); - const decoder = new TextDecoder(); let rawResponse = ""; - let references = null; + const contentType = response.headers.get("content-type"); - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - if (references != null) { - newResponseText.appendChild(references); - } - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input").removeAttribute("disabled"); - return; + if (contentType === "application/json") { + // Handle JSON response + try { + const responseAsJson = await response.json(); + if (responseAsJson.image) { + // If response has image field, response is a generated image. + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; } + if (responseAsJson.detail) { + // If response has detail field, response is an error message. 
+ rawResponse += responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + } finally { + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); - // Decode message chunk from stream - const chunk = decoder.decode(value, { stream: true }); + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + document.getElementById("chat-input").removeAttribute("disabled"); + } + } else { + // Handle streamed response of type text/event-stream or text/plain + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let references = null; - if (chunk.includes("### compiled references:")) { - const additionalResponse = chunk.split("### compiled references:")[0]; - rawResponse += additionalResponse; - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); + readStream(); - const rawReference = chunk.split("### compiled references:")[1]; - const rawReferenceAsJson = JSON.parse(rawReference); - references = document.createElement('div'); - references.classList.add("references"); - - let referenceExpandButton = document.createElement('button'); - referenceExpandButton.classList.add("reference-expand-button"); - - let referenceSection = document.createElement('div'); - referenceSection.classList.add("reference-section"); - referenceSection.classList.add("collapsed"); - - let numReferences = 0; - - // If rawReferenceAsJson is a list, then count the length - if (Array.isArray(rawReferenceAsJson)) { - numReferences = rawReferenceAsJson.length; - - rawReferenceAsJson.forEach((reference, index) => { - let polishedReference = generateReference(reference, index); - referenceSection.appendChild(polishedReference); - }); - } else { - numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); + function readStream() { + reader.read().then(({ done, value }) => { + if (done) { + // Append any references after all the data has been streamed + if (references != null) { + newResponseText.appendChild(references); + } + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + document.getElementById("chat-input").removeAttribute("disabled"); + return; } - references.appendChild(referenceExpandButton); + // Decode message chunk from stream + const chunk = decoder.decode(value, { stream: true }); - referenceExpandButton.addEventListener('click', function() { - if (referenceSection.classList.contains("collapsed")) { - referenceSection.classList.remove("collapsed"); - referenceSection.classList.add("expanded"); - } else { - referenceSection.classList.add("collapsed"); - referenceSection.classList.remove("expanded"); - } - }); - - let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; - referenceExpandButton.innerHTML = expandButtonText; - references.appendChild(referenceSection); - readStream(); - } else { - // Display response from Khoj - if (newResponseText.getElementsByClassName("spinner").length > 0) { - newResponseText.removeChild(loadingSpinner); - } - - // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. 
- if (chunk.startsWith("{") && chunk.endsWith("}")) { - try { - const responseAsJson = JSON.parse(chunk); - if (responseAsJson.image) { - rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; - } - if (responseAsJson.detail) { - rawResponse += responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - } - } else { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; + if (chunk.includes("### compiled references:")) { + const additionalResponse = chunk.split("### compiled references:")[0]; + rawResponse += additionalResponse; newResponseText.innerHTML = ""; newResponseText.appendChild(formatHTMLMessage(rawResponse)); - readStream(); - } - } + const rawReference = chunk.split("### compiled references:")[1]; + const rawReferenceAsJson = JSON.parse(rawReference); + references = document.createElement('div'); + references.classList.add("references"); - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }); + let referenceExpandButton = document.createElement('button'); + referenceExpandButton.classList.add("reference-expand-button"); + + let referenceSection = document.createElement('div'); + referenceSection.classList.add("reference-section"); + referenceSection.classList.add("collapsed"); + + let numReferences = 0; + + // If rawReferenceAsJson is a list, then count the length + if (Array.isArray(rawReferenceAsJson)) { + numReferences = rawReferenceAsJson.length; + + rawReferenceAsJson.forEach((reference, index) => { + let polishedReference = generateReference(reference, index); + referenceSection.appendChild(polishedReference); + }); + } else { + numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); + } + + references.appendChild(referenceExpandButton); + + referenceExpandButton.addEventListener('click', function() { + if (referenceSection.classList.contains("collapsed")) { + referenceSection.classList.remove("collapsed"); + referenceSection.classList.add("expanded"); + } else { + referenceSection.classList.add("collapsed"); + referenceSection.classList.remove("expanded"); + } + }); + + let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; + referenceExpandButton.innerHTML = expandButtonText; + references.appendChild(referenceSection); + readStream(); + } else { + // Display response from Khoj + if (newResponseText.getElementsByClassName("spinner").length > 0) { + newResponseText.removeChild(loadingSpinner); + } + + // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. 
+ if (chunk.startsWith("{") && chunk.endsWith("}")) { + try { + const responseAsJson = JSON.parse(chunk); + if (responseAsJson.image) { + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + } + if (responseAsJson.detail) { + rawResponse += responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + } finally { + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + } + } else { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + + readStream(); + } + } + + // Scroll to bottom of chat window as chat response is streamed + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + }); + } } - readStream(); } function incrementalChat(event) { diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 4a406a6b..c243ed67 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -350,113 +350,122 @@ To get started, just start typing below. You can also type / to see a list of co let chatInput = document.getElementById("chat-input"); chatInput.classList.remove("option-enabled"); - // Call specified Khoj API which returns a streamed response of type text/plain + // Call specified Khoj API let response = await fetch(url); - const reader = response.body.getReader(); - const decoder = new TextDecoder(); let rawResponse = ""; - let references = null; + const contentType = response.headers.get("content-type"); - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - if (references != null) { - newResponseText.appendChild(references); - } - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input").removeAttribute("disabled"); - return; + if (contentType === "application/json") { + // Handle JSON response + try { + const responseAsJson = await response.json(); + if (responseAsJson.image) { + // If response has image field, response is a generated image. + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; } + if (responseAsJson.detail) { + // If response has detail field, response is an error message. 
+ rawResponse += responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + } finally { + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); - // Decode message chunk from stream - const chunk = decoder.decode(value, { stream: true }); + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + document.getElementById("chat-input").removeAttribute("disabled"); + } + } else { + // Handle streamed response of type text/event-stream or text/plain + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let references = null; - if (chunk.includes("### compiled references:")) { - const additionalResponse = chunk.split("### compiled references:")[0]; - rawResponse += additionalResponse; - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); + readStream(); - const rawReference = chunk.split("### compiled references:")[1]; - const rawReferenceAsJson = JSON.parse(rawReference); - references = document.createElement('div'); - references.classList.add("references"); - - let referenceExpandButton = document.createElement('button'); - referenceExpandButton.classList.add("reference-expand-button"); - - let referenceSection = document.createElement('div'); - referenceSection.classList.add("reference-section"); - referenceSection.classList.add("collapsed"); - - let numReferences = 0; - - // If rawReferenceAsJson is a list, then count the length - if (Array.isArray(rawReferenceAsJson)) { - numReferences = rawReferenceAsJson.length; - - rawReferenceAsJson.forEach((reference, index) => { - let polishedReference = generateReference(reference, index); - referenceSection.appendChild(polishedReference); - }); - } else { - numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); + function readStream() { + reader.read().then(({ done, value }) => { + if (done) { + // Append any references after all the data has been streamed + if (references != null) { + newResponseText.appendChild(references); + } + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + document.getElementById("chat-input").removeAttribute("disabled"); + return; } - references.appendChild(referenceExpandButton); + // Decode message chunk from stream + const chunk = decoder.decode(value, { stream: true }); - referenceExpandButton.addEventListener('click', function() { - if (referenceSection.classList.contains("collapsed")) { - referenceSection.classList.remove("collapsed"); - referenceSection.classList.add("expanded"); + if (chunk.includes("### compiled references:")) { + const additionalResponse = chunk.split("### compiled references:")[0]; + rawResponse += additionalResponse; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + + const rawReference = chunk.split("### compiled references:")[1]; + const rawReferenceAsJson = JSON.parse(rawReference); + references = document.createElement('div'); + references.classList.add("references"); + + let referenceExpandButton = document.createElement('button'); + referenceExpandButton.classList.add("reference-expand-button"); + + let referenceSection = document.createElement('div'); + referenceSection.classList.add("reference-section"); + referenceSection.classList.add("collapsed"); + + let numReferences = 0; + + // If rawReferenceAsJson is a list, then count the 
length + if (Array.isArray(rawReferenceAsJson)) { + numReferences = rawReferenceAsJson.length; + + rawReferenceAsJson.forEach((reference, index) => { + let polishedReference = generateReference(reference, index); + referenceSection.appendChild(polishedReference); + }); } else { - referenceSection.classList.add("collapsed"); - referenceSection.classList.remove("expanded"); + numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); } - }); - let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; - referenceExpandButton.innerHTML = expandButtonText; - references.appendChild(referenceSection); - readStream(); - } else { - // Display response from Khoj - if (newResponseText.getElementsByClassName("spinner").length > 0) { - newResponseText.removeChild(loadingSpinner); - } + references.appendChild(referenceExpandButton); - // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. - if (chunk.startsWith("{") && chunk.endsWith("}")) { - try { - const responseAsJson = JSON.parse(chunk); - if (responseAsJson.image) { - rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + referenceExpandButton.addEventListener('click', function() { + if (referenceSection.classList.contains("collapsed")) { + referenceSection.classList.remove("collapsed"); + referenceSection.classList.add("expanded"); + } else { + referenceSection.classList.add("collapsed"); + referenceSection.classList.remove("expanded"); } - if (responseAsJson.detail) { - rawResponse += responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - } + }); + + let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; + referenceExpandButton.innerHTML = expandButtonText; + references.appendChild(referenceSection); + readStream(); } else { + // Display response from Khoj + if (newResponseText.getElementsByClassName("spinner").length > 0) { + newResponseText.removeChild(loadingSpinner); + } + // If the chunk is not a JSON object, just display it as is rawResponse += chunk; newResponseText.innerHTML = ""; newResponseText.appendChild(formatHTMLMessage(rawResponse)); readStream(); } - } + }); // Scroll to bottom of chat window as chat response is streamed document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }); + }; } - readStream(); }; function incrementalChat(event) { From 162b219f2bbd6fabbf90bae5503646e21925010a Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 5 Dec 2023 01:29:36 -0500 Subject: [PATCH 13/30] Throw unsupported error when server not configured for image, speech-to-text --- src/interface/desktop/chat.html | 10 +++++++--- src/interface/obsidian/src/chat_modal.ts | 8 +++++--- src/khoj/interface/web/chat.html | 10 +++++++--- src/khoj/routers/api.py | 4 ++-- src/khoj/routers/helpers.py | 4 ++-- 5 files changed, 23 insertions(+), 13 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index cad7971f..6b7fde07 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -675,9 +675,13 @@ .then(response => response.ok ? response.json() : Promise.reject(response)) .then(data => { chatInput.value += data.text; }) .catch(err => { - err.status == 422 - ? 
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.") - : flashStatusInChatInput("⛔️ Failed to transcribe audio") + if (err.status === 501) { + flashStatusInChatInput("⛔️ Configure speech-to-text model on server.") + } else if (err.status === 422) { + flashStatusInChatInput("⛔️ Audio file to large to process.") + } else { + flashStatusInChatInput("⛔️ Failed to transcribe audio.") + } }); }; diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts index 145bae50..09f7a181 100644 --- a/src/interface/obsidian/src/chat_modal.ts +++ b/src/interface/obsidian/src/chat_modal.ts @@ -410,10 +410,12 @@ export class KhojChatModal extends Modal { if (response.status === 200) { console.log(response); chatInput.value += response.json.text; - } else if (response.status === 422) { - throw new Error("⛔️ Failed to transcribe audio"); - } else { + } else if (response.status === 501) { throw new Error("⛔️ Configure speech-to-text model on server."); + } else if (response.status === 422) { + throw new Error("⛔️ Audio file to large to process."); + } else { + throw new Error("⛔️ Failed to transcribe audio."); } }; diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index c243ed67..e85759fb 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -638,9 +638,13 @@ To get started, just start typing below. You can also type / to see a list of co .then(response => response.ok ? response.json() : Promise.reject(response)) .then(data => { chatInput.value += data.text; }) .catch(err => { - err.status == 422 - ? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.") - : flashStatusInChatInput("⛔️ Failed to transcribe audio") + if (err.status === 501) { + flashStatusInChatInput("⛔️ Configure speech-to-text model on server.") + } else if (err.status === 422) { + flashStatusInChatInput("⛔️ Audio file to large to process.") + } else { + flashStatusInChatInput("⛔️ Failed to transcribe audio.") + } }); }; diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index d53f023a..4d2c80a2 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -626,8 +626,8 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi speech_to_text_config = await ConversationAdapters.get_speech_to_text_config() openai_chat_config = await ConversationAdapters.get_openai_chat_config() if not speech_to_text_config: - # If the user has not configured a speech to text model, return an unprocessable entity error - status_code = 422 + # If the user has not configured a speech to text model, return an unsupported on server error + status_code = 501 elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI: api_key = openai_chat_config.api_key speech2text_model = speech_to_text_config.model_name diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index f34ae815..a780eb20 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -258,8 +258,8 @@ async def text_to_image(message: str) -> Tuple[Optional[str], int]: text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() openai_chat_config = await ConversationAdapters.get_openai_chat_config() if not text_to_image_config: - # If the user has not configured a text to image model, return an unprocessable entity error - status_code = 422 + # If the user has not configured a text to image model, return an unsupported on server error + 
status_code = 501 elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: client = openai.OpenAI(api_key=openai_chat_config.api_key) text2image_model = text_to_image_config.model_name From 408b7413e9ef75fbc1d0a5e2f453119b4b8ea131 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 5 Dec 2023 02:40:28 -0500 Subject: [PATCH 14/30] Use global openai client for transcribe, image --- src/khoj/configure.py | 6 ++++++ .../processor/conversation/offline/chat_model.py | 2 +- src/khoj/processor/conversation/openai/gpt.py | 2 +- src/khoj/processor/conversation/openai/whisper.py | 5 ++--- src/khoj/routers/api.py | 15 ++++++--------- src/khoj/routers/helpers.py | 8 ++++---- src/khoj/utils/models.py | 12 ++---------- src/khoj/utils/state.py | 8 +++++--- 8 files changed, 27 insertions(+), 31 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 19e7d403..717ad859 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -7,6 +7,7 @@ import requests import os # External Packages +import openai import schedule from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -22,6 +23,7 @@ from starlette.authentication import ( # Internal Packages from khoj.database.models import KhojUser, Subscription from khoj.database.adapters import ( + ConversationAdapters, get_all_users, get_or_create_search_model, aget_user_subscription_state, @@ -138,6 +140,10 @@ def configure_server( config = FullConfig() state.config = config + if ConversationAdapters.has_valid_openai_conversation_config(): + openai_config = ConversationAdapters.get_openai_conversation_config() + state.openai_client = openai.OpenAI(api_key=openai_config.api_key) + # Initialize Search Models from Config and initialize content try: state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder) diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 41b1844b..211e3cac 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -47,7 +47,7 @@ def extract_questions_offline( if use_history: for chat in conversation_log.get("chat", [])[-4:]: - if chat["by"] == "khoj": + if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image": chat_history += f"Q: {chat['intent']['query']}\n" chat_history += f"A: {chat['message']}\n" diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index dc708ab7..7bebc26c 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -41,7 +41,7 @@ def extract_questions( [ f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n' for chat in conversation_log.get("chat", [])[-4:] - if chat["by"] == "khoj" + if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image" ] ) diff --git a/src/khoj/processor/conversation/openai/whisper.py b/src/khoj/processor/conversation/openai/whisper.py index 351319b7..bd0e66df 100644 --- a/src/khoj/processor/conversation/openai/whisper.py +++ b/src/khoj/processor/conversation/openai/whisper.py @@ -6,11 +6,10 @@ from asgiref.sync import sync_to_async from openai import OpenAI -async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str: +async def 
transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str: """ Transcribe audio file using Whisper model via OpenAI's API """ # Send the audio data to the Whisper API - client = OpenAI(api_key=api_key) response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file) - return response["text"] + return response.text diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 4d2c80a2..7efd8bfd 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -19,7 +19,7 @@ from starlette.authentication import requires from khoj.configure import configure_server from khoj.database import adapters from khoj.database.adapters import ConversationAdapters, EntryAdapters -from khoj.database.models import ChatModelOptions +from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions from khoj.database.models import Entry as DbEntry from khoj.database.models import ( GithubConfig, @@ -624,17 +624,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi # Send the audio data to the Whisper API speech_to_text_config = await ConversationAdapters.get_speech_to_text_config() - openai_chat_config = await ConversationAdapters.get_openai_chat_config() if not speech_to_text_config: # If the user has not configured a speech to text model, return an unsupported on server error status_code = 501 - elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI: - api_key = openai_chat_config.api_key + elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI: speech2text_model = speech_to_text_config.model_name - user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key) - elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE: + user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client) + elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE: speech2text_model = speech_to_text_config.model_name - user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model) + user_message = await transcribe_audio_offline(audio_filename, speech2text_model) finally: # Close and Delete the temporary audio file audio_file.close() @@ -793,7 +791,6 @@ async def extract_references_and_questions( conversation_config = await ConversationAdapters.aget_conversation_config(user) if conversation_config is None: conversation_config = await ConversationAdapters.aget_default_conversation_config() - openai_chat_config = await ConversationAdapters.aget_openai_conversation_config() if ( offline_chat_config and offline_chat_config.enabled @@ -810,7 +807,7 @@ async def extract_references_and_questions( inferred_queries = extract_questions_offline( defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False ) - elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: + elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: openai_chat_config = await ConversationAdapters.get_openai_chat_config() openai_chat = await ConversationAdapters.get_openai_chat() api_key = openai_chat_config.api_key diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index a780eb20..4e883f35 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -256,15 +256,15 @@ async def 
text_to_image(message: str) -> Tuple[Optional[str], int]: # Send the audio data to the Whisper API text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() - openai_chat_config = await ConversationAdapters.get_openai_chat_config() if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 - elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: - client = openai.OpenAI(api_key=openai_chat_config.api_key) + elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: text2image_model = text_to_image_config.model_name try: - response = client.images.generate(prompt=message, model=text2image_model, response_format="b64_json") + response = state.openai_client.images.generate( + prompt=message, model=text2image_model, response_format="b64_json" + ) image = response.data[0].b64_json except openai.OpenAIError as e: logger.error(f"Image Generation failed with {e.http_status}: {e.error}") diff --git a/src/khoj/utils/models.py b/src/khoj/utils/models.py index 37d09418..e1298b08 100644 --- a/src/khoj/utils/models.py +++ b/src/khoj/utils/models.py @@ -22,17 +22,9 @@ class BaseEncoder(ABC): class OpenAI(BaseEncoder): - def __init__(self, model_name, device=None): + def __init__(self, model_name, client: openai.OpenAI, device=None): self.model_name = model_name - if ( - not state.processor_config - or not state.processor_config.conversation - or not state.processor_config.conversation.openai_model - ): - raise Exception( - f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}" - ) - self.openai_client = openai.OpenAI(api_key=state.processor_config.conversation.openai_model.api_key) + self.openai_client = client self.embedding_dimensions = None def encode(self, entries, device=None, **kwargs): diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index b54cf4b3..d5358868 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -1,15 +1,16 @@ # Standard Packages +from collections import defaultdict import os +from pathlib import Path import threading from typing import List, Dict -from collections import defaultdict # External Packages -from pathlib import Path -from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel +from openai import OpenAI from whisper import Whisper # Internal Packages +from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.utils import config as utils_config from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel from khoj.utils.helpers import LRU, get_device @@ -21,6 +22,7 @@ search_models = SearchModels() embeddings_model: EmbeddingsModel = None cross_encoder_model: CrossEncoderModel = None content_index = ContentIndex() +openai_client: OpenAI = None gpt4all_processor_config: GPT4AllProcessorModel = None whisper_model: Whisper = None config_file: Path = None From 7504669f2bef0a2217e8793f2667d8519ac0ee38 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 5 Dec 2023 03:48:07 -0500 Subject: [PATCH 15/30] Fix rendering image on chat response in obsidian client --- src/interface/obsidian/src/chat_modal.ts | 34 ++++++++++++++++-------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts index 09f7a181..115f4c1f 100644 --- 
a/src/interface/obsidian/src/chat_modal.ts +++ b/src/interface/obsidian/src/chat_modal.ts @@ -2,6 +2,12 @@ import { App, MarkdownRenderer, Modal, request, requestUrl, setIcon } from 'obsi import { KhojSetting } from 'src/settings'; import fetch from "node-fetch"; +export interface ChatJsonResult { + image?: string; + detail?: string; +} + + export class KhojChatModal extends Modal { result: string; setting: KhojSetting; @@ -270,6 +276,23 @@ export class KhojChatModal extends Modal { this.result = ""; responseElement.innerHTML = ""; + if (response.headers.get("content-type") == "application/json") { + let responseText = "" + try { + const responseAsJson = await response.json() as ChatJsonResult; + if (responseAsJson.image) { + responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`; + } else if (responseAsJson.detail) { + responseText = responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + responseText = response.body.read().toString() + } finally { + this.renderIncrementalMessage(responseElement, responseText); + } + } + for await (const chunk of response.body) { let responseText = chunk.toString(); if (responseText.includes("### compiled references:")) { @@ -315,17 +338,6 @@ export class KhojChatModal extends Modal { references.appendChild(referenceSection); } else { if (responseText.startsWith("{") && responseText.endsWith("}")) { - try { - const responseAsJson = JSON.parse(responseText); - if (responseAsJson.image) { - responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`; - } else if (responseAsJson.detail) { - responseText = responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - continue; - } } else { // If the chunk is not a JSON object, just display it as is continue; From 73a107690d0aadc1e5e6242acb24ee4418cb1a23 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 16 Dec 2023 09:03:52 +0530 Subject: [PATCH 16/30] Add a ConversationCommand rate limiter for the chat endpoint --- src/khoj/configure.py | 48 ++++++++++++++++++------------------- src/khoj/routers/api.py | 4 ++++ src/khoj/routers/helpers.py | 34 ++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 24 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 717ad859..4bb23812 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -77,17 +77,18 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user: - if state.billing_enabled: - subscription_state = await aget_user_subscription_state(user) - subscribed = ( - subscription_state == SubscriptionState.SUBSCRIBED.value - or subscription_state == SubscriptionState.TRIAL.value - or subscription_state == SubscriptionState.UNSUBSCRIBED.value - ) - if subscribed: - return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) - return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) + if not state.billing_enabled: + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) + + subscription_state = await aget_user_subscription_state(user) + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + or subscription_state == SubscriptionState.UNSUBSCRIBED.value + ) + if subscribed: + return AuthCredentials(["authenticated", "premium"]), 
AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: # Get bearer token from header bearer_token = request.headers["Authorization"].split("Bearer ")[1] @@ -99,19 +100,18 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user_with_token: - if state.billing_enabled: - subscription_state = await aget_user_subscription_state(user_with_token.user) - subscribed = ( - subscription_state == SubscriptionState.SUBSCRIBED.value - or subscription_state == SubscriptionState.TRIAL.value - or subscription_state == SubscriptionState.UNSUBSCRIBED.value - ) - if subscribed: - return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser( - user_with_token.user - ) - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) - return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) + if not state.billing_enabled: + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) + + subscription_state = await aget_user_subscription_state(user_with_token.user) + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + or subscription_state == SubscriptionState.UNSUBSCRIBED.value + ) + if subscribed: + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) + return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) if state.anonymous_mode: user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() if user: diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 7efd8bfd..65516a9c 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -46,6 +46,7 @@ from khoj.routers.helpers import ( is_ready_to_chat, update_telemetry_state, validate_conversation_config, + ConversationCommandRateLimiter, ) from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter @@ -67,6 +68,7 @@ from khoj.utils.state import SearchType # Initialize Router api = APIRouter() logger = logging.getLogger(__name__) +conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100) def map_config_to_object(content_source: str): @@ -670,6 +672,8 @@ async def chat( await is_ready_to_chat(user) conversation_command = get_conversation_command(query=q, any_references=True) + conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command) + q = q.replace(f"/{conversation_command.value}", "").strip() meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 4e883f35..56fba861 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -300,6 +300,40 @@ class ApiUserRateLimiter: user_requests.append(time()) +class ConversationCommandRateLimiter: + def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int): + self.cache = defaultdict(lambda: defaultdict(list)) + self.trial_rate_limit = trial_rate_limit + self.subscribed_rate_limit = subscribed_rate_limit + self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image] + + def update_and_check_if_valid(self, request: Request, conversation_command: 
ConversationCommand):
+ if state.billing_enabled is False:
+ return
+
+ if not request.user.is_authenticated:
+ return
+
+ if conversation_command not in self.restricted_commands:
+ return
+
+ user: KhojUser = request.user.object
+ user_cache = self.cache[user.uuid]
+ subscribed = has_required_scope(request, ["premium"])
+ user_cache[conversation_command].append(time())
+
+ # Remove requests outside of the 24-hr time window
+ cutoff = time() - 60 * 60 * 24
+ while user_cache[conversation_command] and user_cache[conversation_command][0] < cutoff:
+ user_cache[conversation_command].pop(0)
+
+ if subscribed and len(user_cache[conversation_command]) > self.subscribed_rate_limit:
+ raise HTTPException(status_code=429, detail="Too Many Requests")
+ if not subscribed and len(user_cache[conversation_command]) > self.trial_rate_limit:
+ raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
+ return
+
+
 class ApiIndexedDataLimiter:
 def __init__(
 self,
From 5f6dcf9f2e6c28feeb45d32a2ebc998eb019f05a Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Sat, 16 Dec 2023 09:18:56 +0530
Subject: [PATCH 17/30] Add a rate limiter for the transcribe API endpoint

---
 src/khoj/interface/web/chat.html | 2 ++
 src/khoj/routers/api.py | 8 +++++++-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index e85759fb..aa6bd4b9 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -642,6 +642,8 @@ To get started, just start typing below. You can also type / to see a list of co
 flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
 } else if (err.status === 422) {
 flashStatusInChatInput("⛔️ Audio file too large to process.")
+ } else if (err.status === 429) {
+ flashStatusInChatInput("⛔️ " + err.statusText);
 } else {
 flashStatusInChatInput("⛔️ Failed to transcribe audio.")
 }
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 65516a9c..a4063afa 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -606,7 +606,13 @@ async def chat_options(
 @api.post("/transcribe")
 @requires(["authenticated"])
-async def transcribe(request: Request, common: CommonQueryParams, file: UploadFile = File(...)):
+async def transcribe(
+ request: Request,
+ common: CommonQueryParams,
+ file: UploadFile = File(...),
+ rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60)),
+ rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
+):
 user: KhojUser = request.user.object
 audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
 user_message: str = None
From 3065cea562bc8a0ee7379c958a529052c7d89ae7 Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Sat, 16 Dec 2023 09:24:26 +0530
Subject: [PATCH 18/30] Address mypy typing issues

---
 src/khoj/database/adapters/__init__.py | 2 +-
 src/khoj/routers/helpers.py | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 6e17c15b..31aa1d57 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -401,7 +401,7 @@ class ConversationAdapters:
 )
 max_results = 3
- all_questions = await sync_to_async(list)(all_questions)
+ all_questions = await sync_to_async(list)(all_questions) # type: ignore
 if len(all_questions) < max_results:
 return all_questions
diff --git
a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 56fba861..d2a79e39 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -267,7 +267,7 @@ async def text_to_image(message: str) -> Tuple[Optional[str], int]: ) image = response.data[0].b64_json except openai.OpenAIError as e: - logger.error(f"Image Generation failed with {e.http_status}: {e.error}") + logger.error(f"Image Generation failed with {e}", exc_info=True) status_code = 500 return image, status_code @@ -302,7 +302,7 @@ class ApiUserRateLimiter: class ConversationCommandRateLimiter: def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int): - self.cache = defaultdict(lambda: defaultdict(list)) + self.cache: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list)) self.trial_rate_limit = trial_rate_limit self.subscribed_rate_limit = subscribed_rate_limit self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image] @@ -351,7 +351,7 @@ class ApiIndexedDataLimiter: if state.billing_enabled is False: return subscribed = has_required_scope(request, ["premium"]) - incoming_data_size_mb = 0 + incoming_data_size_mb = 0.0 deletion_file_names = set() if not request.user.is_authenticated: From 61dde8ed89112f115b07062cfc291f801d119324 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 17 Dec 2023 12:54:50 +0530 Subject: [PATCH 19/30] If text to image config isn't set, send back an error message to the client --- src/khoj/routers/api.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index a4063afa..235b6e7c 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -715,6 +715,13 @@ async def chat( ) elif conversation_command == ConversationCommand.Image: image, status_code = await text_to_image(q) + if image is None: + content_obj = { + "image": image, + "intentType": "text-to-image", + "detail": "Failed to generate image. Make sure your image generation configuration is set.", + } + return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) await sync_to_async(save_to_conversation_log)(q, image, user, meta_log, intent_type="text-to-image") content_obj = {"image": image, "intentType": "text-to-image"} return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) From 0459666beb54d8f9a83233600047c4e9addc5689 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 17 Dec 2023 12:55:18 +0530 Subject: [PATCH 20/30] CSRF Cookie not set error in prod. 
Try fixing https forwarding for mitigation --- src/khoj/app/settings.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/khoj/app/settings.py b/src/khoj/app/settings.py index bacbf904..86db4b12 100644 --- a/src/khoj/app/settings.py +++ b/src/khoj/app/settings.py @@ -33,6 +33,9 @@ ALLOWED_HOSTS = [f".{KHOJ_DOMAIN}", "localhost", "127.0.0.1", "[::1]"] CSRF_TRUSTED_ORIGINS = [ f"https://*.{KHOJ_DOMAIN}", f"https://{KHOJ_DOMAIN}", + f"http://*.{KHOJ_DOMAIN}", + f"http://{KHOJ_DOMAIN}", + f"https://app.{KHOJ_DOMAIN}", ] COOKIE_SAMESITE = "None" @@ -42,6 +45,7 @@ if DEBUG or os.getenv("KHOJ_DOMAIN") == None: else: SESSION_COOKIE_DOMAIN = KHOJ_DOMAIN CSRF_COOKIE_DOMAIN = KHOJ_DOMAIN + SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTOCOL", "https") SESSION_COOKIE_SECURE = True CSRF_COOKIE_SECURE = True From 09544dee09335565bcabbc75d80e579c66b5c6c5 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 17 Dec 2023 16:44:19 +0530 Subject: [PATCH 21/30] Add TextToImageModelConfig to the admin page --- src/khoj/database/admin.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 2213fb6e..f64a02df 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -12,6 +12,7 @@ from khoj.database.models import ( SpeechToTextModelOptions, Subscription, ReflectiveQuestion, + TextToImageModelConfig, ) admin.site.register(KhojUser, UserAdmin) @@ -23,3 +24,4 @@ admin.site.register(OfflineChatProcessorConversationConfig) admin.site.register(SearchModelConfig) admin.site.register(Subscription) admin.site.register(ReflectiveQuestion) +admin.site.register(TextToImageModelConfig) From 7cb64cb2f9f801504adb58d7f02c603727b0252e Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 17 Dec 2023 18:25:03 +0530 Subject: [PATCH 22/30] Add telemetry for image generation conversation command --- src/khoj/routers/api.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 235b6e7c..065020d3 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -714,6 +714,13 @@ async def chat( status_code=200, ) elif conversation_command == ConversationCommand.Image: + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + metadata={"conversation_command": conversation_command.value}, + **common.__dict__, + ) image, status_code = await text_to_image(q) if image is None: content_obj = { From 49af2148fee773121c97415c065b406b99e20233 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 17 Dec 2023 20:25:35 +0530 Subject: [PATCH 23/30] Miscellaneous improvements to image generation - Improve the prompt before sending it for image generation - Update the help message to include online, image functionality - Improve styling for the voice, trash buttons --- src/interface/desktop/chat.html | 4 ++-- src/khoj/interface/web/chat.html | 6 +++--- src/khoj/processor/conversation/prompts.py | 23 ++++++++++++++++++---- src/khoj/routers/helpers.py | 22 ++++++++++++++++++--- 4 files changed, 43 insertions(+), 12 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 6b7fde07..db3b2c26 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -903,7 +903,8 @@ } .input-row-button { background: var(--background-color); - border: none; + border: 1px solid var(--main-text-color); + box-shadow: 0 0 11px #aaa; border-radius: 5px; padding: 5px; font-size: 14px; @@ -989,7 +990,6 @@ color: var(--main-text-color); 
border: 1px solid var(--main-text-color);
 border-radius: 5px;
- padding: 5px;
 font-size: 14px;
 font-weight: 300;
 line-height: 1.5em;
diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index aa6bd4b9..4f8031ce 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -952,7 +952,7 @@
 grid-template-columns: auto 32px 32px;
 grid-column-gap: 10px;
 grid-row-gap: 10px;
- background: #f9fafc
+ background: var(--background-color);
 }
 .option:hover {
 box-shadow: 0 0 11px #aaa;
@@ -974,9 +974,9 @@
 }
 .input-row-button {
 background: var(--background-color);
- border: none;
+ border: 1px solid var(--main-text-color);
+ box-shadow: 0 0 11px #aaa;
 border-radius: 5px;
- padding: 5px;
 font-size: 14px;
 font-weight: 300;
 line-height: 1.5em;
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index b0e316da..31cd9160 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -109,6 +109,18 @@ Question: {query}
 """.strip()
 )
+## Image Generation
+## --
+
+image_generation_improve_prompt = PromptTemplate.from_template(
+ """
+Generate a detailed prompt to create an image based on the following description. Improve the query below by adding context that will help generate a better image.
+
+Query: {query}
+
+Improved Query:"""
+)
+
 ## Online Search Conversation
 ## --
 online_search_conversation = PromptTemplate.from_template(
@@ -295,10 +307,13 @@ Q:"""
 # --
 help_message = PromptTemplate.from_template(
 """
-**/notes**: Chat using the information in your knowledge base.
-**/general**: Chat using just Khoj's general knowledge. This will not search against your notes.
-**/default**: Chat using your knowledge base and Khoj's general knowledge for context.
-**/help**: Show this help message.
+- **/notes**: Chat using the information in your knowledge base.
+- **/general**: Chat using just Khoj's general knowledge. This will not search against your notes.
+- **/default**: Chat using your knowledge base and Khoj's general knowledge for context.
+- **/online**: Chat using the internet as a source of information.
+- **/image**: Generate an image based on your message.
+- **/help**: Show this help message.
+
 You are using the **{model}** model on the **{device}**.
**version**: {version} diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index d2a79e39..ed366590 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -146,6 +146,20 @@ async def generate_online_subqueries(q: str) -> List[str]: return [q] +async def generate_better_image_prompt(q: str) -> str: + """ + Generate a better image prompt from the given query + """ + + image_prompt = prompts.image_generation_improve_prompt.format( + query=q, + ) + + response = await send_message_to_model_wrapper(image_prompt) + + return response.strip() + + async def send_message_to_model_wrapper( message: str, ): @@ -170,11 +184,13 @@ async def send_message_to_model_wrapper( openai_chat_config = await ConversationAdapters.aget_openai_conversation_config() api_key = openai_chat_config.api_key chat_model = conversation_config.chat_model - return send_message_to_model( + openai_response = send_message_to_model( message=message, api_key=api_key, model=chat_model, ) + + return openai_response.content else: raise HTTPException(status_code=500, detail="Invalid conversation config") @@ -254,16 +270,16 @@ async def text_to_image(message: str) -> Tuple[Optional[str], int]: status_code = 200 image = None - # Send the audio data to the Whisper API text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: text2image_model = text_to_image_config.model_name + improved_image_prompt = await generate_better_image_prompt(message) try: response = state.openai_client.images.generate( - prompt=message, model=text2image_model, response_format="b64_json" + prompt=improved_image_prompt, model=text2image_model, response_format="b64_json" ) image = response.data[0].b64_json except openai.OpenAIError as e: From 0288804f2eece10b3c42a2fc2ba225c81944f861 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 17 Dec 2023 21:02:55 +0530 Subject: [PATCH 24/30] Render the inferred query along with the image that Khoj returns --- src/interface/desktop/chat.html | 22 +++++++++++++++++++--- src/khoj/interface/web/chat.html | 14 ++++++++++++-- src/khoj/routers/api.py | 8 +++++--- src/khoj/routers/helpers.py | 2 +- 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index db3b2c26..b1d0ce48 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -179,9 +179,14 @@ return numOnlineReferences; } - function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) { + function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) { if (intentType === "text-to-image") { let imageMarkdown = `![](data:image/png;base64,${message})`; + imageMarkdown += "\n\n"; + if (inferredQueries) { + const inferredQuery = inferredQueries?.[0]; + imageMarkdown += `**Inferred Query**: ${inferredQuery}`; + } renderMessage(imageMarkdown, by, dt); return; } @@ -357,6 +362,11 @@ if (responseAsJson.image) { // If response has image field, response is a generated image. 
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + rawResponse += "\n\n"; + const inferredQueries = responseAsJson.inferredQueries?.[0]; + if (inferredQueries) { + rawResponse += `**Inferred Query**: ${inferredQueries}`; + } } if (responseAsJson.detail) { // If response has detail field, response is an error message. @@ -454,7 +464,13 @@ try { const responseAsJson = JSON.parse(chunk); if (responseAsJson.image) { + // If response has image field, response is a generated image. rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + rawResponse += "\n\n"; + const inferredQueries = responseAsJson.inferredQueries?.[0]; + if (inferredQueries) { + rawResponse += `**Inferred Query**: ${inferredQueries}`; + } } if (responseAsJson.detail) { rawResponse += responseAsJson.detail; @@ -572,7 +588,7 @@ .then(response => { // Render conversation history, if any response.forEach(chat_log => { - renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type); + renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]); }); }) .catch(err => { @@ -906,9 +922,9 @@ border: 1px solid var(--main-text-color); box-shadow: 0 0 11px #aaa; border-radius: 5px; - padding: 5px; font-size: 14px; font-weight: 300; + padding: 0; line-height: 1.5em; cursor: pointer; transition: background 0.3s ease-in-out; diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 4f8031ce..2c64d96c 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -188,9 +188,14 @@ To get started, just start typing below. You can also type / to see a list of co return numOnlineReferences; } - function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) { + function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) { if (intentType === "text-to-image") { let imageMarkdown = `![](data:image/png;base64,${message})`; + imageMarkdown += "\n\n"; + if (inferredQueries) { + const inferredQuery = inferredQueries?.[0]; + imageMarkdown += `**Inferred Query**: ${inferredQuery}`; + } renderMessage(imageMarkdown, by, dt); return; } @@ -362,6 +367,11 @@ To get started, just start typing below. You can also type / to see a list of co if (responseAsJson.image) { // If response has image field, response is a generated image. rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + rawResponse += "\n\n"; + const inferredQueries = responseAsJson.inferredQueries?.[0]; + if (inferredQueries) { + rawResponse += `**Inferred Query**: ${inferredQueries}`; + } } if (responseAsJson.detail) { // If response has detail field, response is an error message. @@ -543,7 +553,7 @@ To get started, just start typing below. 
You can also type / to see a list of co .then(response => { // Render conversation history, if any response.forEach(chat_log => { - renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type); + renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]); }); }) .catch(err => { diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 065020d3..d800da94 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -721,7 +721,7 @@ async def chat( metadata={"conversation_command": conversation_command.value}, **common.__dict__, ) - image, status_code = await text_to_image(q) + image, status_code, improved_image_prompt = await text_to_image(q) if image is None: content_obj = { "image": image, @@ -729,8 +729,10 @@ async def chat( "detail": "Failed to generate image. Make sure your image generation configuration is set.", } return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) - await sync_to_async(save_to_conversation_log)(q, image, user, meta_log, intent_type="text-to-image") - content_obj = {"image": image, "intentType": "text-to-image"} + await sync_to_async(save_to_conversation_log)( + q, image, user, meta_log, intent_type="text-to-image", inferred_queries=[improved_image_prompt] + ) + content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]} return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) # Get the (streamed) chat response from the LLM of choice. diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index ed366590..1c6b437e 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -286,7 +286,7 @@ async def text_to_image(message: str) -> Tuple[Optional[str], int]: logger.error(f"Image Generation failed with {e}", exc_info=True) status_code = 500 - return image, status_code + return image, status_code, improved_image_prompt class ApiUserRateLimiter: From 03cb86ee46d9e477fec14348f6baed63b83c2b2f Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 17 Dec 2023 21:28:33 +0530 Subject: [PATCH 25/30] Update typing and object assignment for new text to image method return --- src/khoj/routers/api.py | 2 +- src/khoj/routers/helpers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index d800da94..a243e4dc 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -732,7 +732,7 @@ async def chat( await sync_to_async(save_to_conversation_log)( q, image, user, meta_log, intent_type="text-to-image", inferred_queries=[improved_image_prompt] ) - content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]} + content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]} # type: ignore[assignment] return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) # Get the (streamed) chat response from the LLM of choice. 
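The image responses assembled above are returned in a single JSON envelope with `image`, `intentType`, `inferredQueries` and `detail` fields, which the web, desktop and Obsidian clients detect via the `application/json` content type. Below is a minimal sketch of that envelope construction factored into one helper; the function name `build_image_response` is a hypothetical illustration, not part of these patches:

```python
import json
from typing import Optional

from fastapi import Response


def build_image_response(image: Optional[str], status_code: int, improved_prompt: Optional[str]) -> Response:
    # Package the base64 encoded image, or an error detail when generation
    # failed, in the JSON envelope the chat clients branch on by content type.
    if image is None:
        content = {
            "image": None,
            "intentType": "text-to-image",
            "detail": "Failed to generate image. Make sure your image generation configuration is set.",
        }
    else:
        content = {
            "image": image,
            "intentType": "text-to-image",
            "inferredQueries": [improved_prompt],
        }
    return Response(content=json.dumps(content), media_type="application/json", status_code=status_code)
```

Keeping the envelope construction in one place like this would also avoid the `# type: ignore` needed on the reassigned `content_obj` dict in the hunk above.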
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 1c6b437e..ff063cab 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -266,7 +266,7 @@ def generate_chat_response( return chat_response, metadata -async def text_to_image(message: str) -> Tuple[Optional[str], int]: +async def text_to_image(message: str) -> Tuple[Optional[str], int, Optional[str]]: status_code = 200 image = None From 5b092d59f4b00bc20a040beca6d29ebdaa407ae3 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 17 Dec 2023 22:34:54 +0530 Subject: [PATCH 26/30] Ignore dict assignment typing error --- src/khoj/routers/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index a243e4dc..f3d741e7 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -732,7 +732,7 @@ async def chat( await sync_to_async(save_to_conversation_log)( q, image, user, meta_log, intent_type="text-to-image", inferred_queries=[improved_image_prompt] ) - content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]} # type: ignore[assignment] + content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]} # type: ignore return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) # Get the (streamed) chat response from the LLM of choice. From 903a01745ffaa890783e3dd05448ac7ff96b5b5c Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 18 Dec 2023 16:09:06 +0530 Subject: [PATCH 27/30] Use 0px for padding for input row buttons in web --- src/khoj/interface/web/chat.html | 1 + 1 file changed, 1 insertion(+) diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 2c64d96c..f17e9b11 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -987,6 +987,7 @@ To get started, just start typing below. You can also type / to see a list of co border: 1px solid var(--main-text-color); box-shadow: 0 0 11px #aaa; border-radius: 5px; + padding: 0px; font-size: 14px; font-weight: 300; line-height: 1.5em; From 946305d977a5985c954ff97e9e5fcbfd423d3af9 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Tue, 19 Dec 2023 16:05:20 +0530 Subject: [PATCH 28/30] Add function to export conversations for debugging --- src/khoj/database/admin.py | 103 +++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index f64a02df..f92b670d 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -1,5 +1,9 @@ +import csv +import json + from django.contrib import admin from django.contrib.auth.admin import UserAdmin +from django.http import HttpResponse # Register your models here. 
@@ -13,6 +17,7 @@ from khoj.database.models import ( Subscription, ReflectiveQuestion, TextToImageModelConfig, + Conversation, ) admin.site.register(KhojUser, UserAdmin) @@ -25,3 +30,101 @@ admin.site.register(SearchModelConfig) admin.site.register(Subscription) admin.site.register(ReflectiveQuestion) admin.site.register(TextToImageModelConfig) + + +@admin.register(Conversation) +class ConversationAdmin(admin.ModelAdmin): + list_display = ( + "id", + "user", + "created_at", + "updated_at", + ) + search_fields = ("conversation_id",) + ordering = ("-created_at",) + + actions = ["export_selected_objects", "export_selected_minimal_objects"] + + def export_selected_objects(self, request, queryset): + response = HttpResponse(content_type="text/csv") + response["Content-Disposition"] = 'attachment; filename="conversations.csv"' + + writer = csv.writer(response) + writer.writerow(["id", "user", "created_at", "updated_at", "conversation_log"]) + + for conversation in queryset: + modified_log = conversation.conversation_log + chat_log = modified_log.get("chat", []) + for idx, log in enumerate(chat_log): + if ( + log["by"] == "khoj" + and log["intent"] + and log["intent"]["type"] + and log["intent"]["type"] == "text-to-image" + ): + log["message"] = "image redacted for space" + chat_log[idx] = log + modified_log["chat"] = chat_log + + writer.writerow( + [ + conversation.id, + conversation.user, + conversation.created_at, + conversation.updated_at, + json.dumps(modified_log), + ] + ) + + return response + + export_selected_objects.short_description = "Export selected conversations" + + def export_selected_minimal_objects(self, request, queryset): + response = HttpResponse(content_type="text/csv") + response["Content-Disposition"] = 'attachment; filename="conversations.csv"' + + writer = csv.writer(response) + writer.writerow(["id", "user", "created_at", "updated_at", "conversation_log"]) + + fields_to_keep = set(["message", "by", "created"]) + + for conversation in queryset: + return_log = dict() + chat_log = conversation.conversation_log.get("chat", []) + for idx, log in enumerate(chat_log): + updated_log = {} + for key in fields_to_keep: + updated_log[key] = log[key] + if ( + log["by"] == "khoj" + and log["intent"] + and log["intent"]["type"] + and log["intent"]["type"] == "text-to-image" + ): + updated_log["message"] = "image redacted for space" + chat_log[idx] = updated_log + return_log["chat"] = chat_log + + writer.writerow( + [ + conversation.id, + conversation.user, + conversation.created_at, + conversation.updated_at, + json.dumps(return_log), + ] + ) + + return response + + export_selected_minimal_objects.short_description = "Export selected conversations (minimal)" + + def get_actions(self, request): + actions = super().get_actions(request) + if not request.user.is_superuser: + if "export_selected_objects" in actions: + del actions["export_selected_objects"] + if "export_selected_minimal_objects" in actions: + del actions["export_selected_minimal_objects"] + return actions From 927e477f68a17765a9e9b7f01026a076b746ef77 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Tue, 19 Dec 2023 16:10:58 +0530 Subject: [PATCH 29/30] Ignore typing error in custom action short description --- src/khoj/database/admin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index f92b670d..491eb091 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -78,7 +78,7 @@ class ConversationAdmin(admin.ModelAdmin): return 
response - export_selected_objects.short_description = "Export selected conversations" + export_selected_objects.short_description = "Export selected conversations" # type: ignore def export_selected_minimal_objects(self, request, queryset): response = HttpResponse(content_type="text/csv") @@ -118,7 +118,7 @@ class ConversationAdmin(admin.ModelAdmin): return response - export_selected_minimal_objects.short_description = "Export selected conversations (minimal)" + export_selected_minimal_objects.short_description = "Export selected conversations (minimal)" # type: ignore def get_actions(self, request): actions = super().get_actions(request) From e3557cd8b7738fd6ca78473214395b9b1c555b11 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Tue, 19 Dec 2023 16:42:45 +0530 Subject: [PATCH 30/30] Update the personality prompt to make Khoj aware that users can share data via the desktop app --- src/khoj/processor/conversation/prompts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 31cd9160..d0c4aba4 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -15,6 +15,7 @@ You were created by Khoj Inc. with the following capabilities: - Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question. - Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations. - Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay". +- Users can share information with you using the Khoj app, which is available for download at https://khoj.dev/downloads. Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev. Today is {current_date} in UTC.
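
A caveat on the conversation export actions added in the patches above: both loops index log["intent"] directly, which can raise a KeyError for chat entries that carry no intent field, and the redaction rewrites the in-memory conversation log in place. A defensive variant of the redaction step is sketched below under those assumptions; the redact_image_messages helper is hypothetical, not part of the patch series:

def redact_image_messages(chat_log: list) -> list:
    """Replace base64 image payloads with a short marker, tolerating entries without an intent."""
    redacted = []
    for log in chat_log:
        entry = dict(log)  # shallow copy, so the stored conversation log is not mutated
        intent = entry.get("intent") or {}
        if entry.get("by") == "khoj" and intent.get("type") == "text-to-image":
            entry["message"] = "image redacted for space"
        redacted.append(entry)
    return redacted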