From 7009793170fa598916b97f80540ab13a54d1fa43 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Sun, 3 Dec 2023 18:16:00 -0500
Subject: [PATCH 01/15] Migrate to OpenAI Python library >= 1.0

---
 pyproject.toml                                |  2 +-
 .../processor/conversation/openai/utils.py    | 20 +++++++++----------
 .../processor/conversation/openai/whisper.py  |  5 +++--
 src/khoj/utils/models.py                      |  4 ++--
 4 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 519d2d60..5a206cce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,7 @@ dependencies = [
     "fastapi >= 0.104.1",
     "python-multipart >= 0.0.5",
     "jinja2 == 3.1.2",
-    "openai >= 0.27.0, < 1.0.0",
+    "openai >= 1.0.0",
     "tiktoken >= 0.3.2",
     "tenacity >= 8.2.2",
     "pillow ~= 9.5.0",

diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 69fab7e5..a5868cb8 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -36,11 +36,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):

 @retry(
     retry=(
-        retry_if_exception_type(openai.error.Timeout)
-        | retry_if_exception_type(openai.error.APIError)
-        | retry_if_exception_type(openai.error.APIConnectionError)
-        | retry_if_exception_type(openai.error.RateLimitError)
-        | retry_if_exception_type(openai.error.ServiceUnavailableError)
+        retry_if_exception_type(openai._exceptions.APITimeoutError)
+        | retry_if_exception_type(openai._exceptions.APIError)
+        | retry_if_exception_type(openai._exceptions.APIConnectionError)
+        | retry_if_exception_type(openai._exceptions.RateLimitError)
+        | retry_if_exception_type(openai._exceptions.APIStatusError)
     ),
     wait=wait_random_exponential(min=1, max=10),
     stop=stop_after_attempt(3),
@@ -57,11 +57,11 @@ def completion_with_backoff(**kwargs):

 @retry(
     retry=(
-        retry_if_exception_type(openai.error.Timeout)
-        | retry_if_exception_type(openai.error.APIError)
-        | retry_if_exception_type(openai.error.APIConnectionError)
-        | retry_if_exception_type(openai.error.RateLimitError)
-        | retry_if_exception_type(openai.error.ServiceUnavailableError)
+        retry_if_exception_type(openai._exceptions.APITimeoutError)
+        | retry_if_exception_type(openai._exceptions.APIError)
+        | retry_if_exception_type(openai._exceptions.APIConnectionError)
+        | retry_if_exception_type(openai._exceptions.RateLimitError)
+        | retry_if_exception_type(openai._exceptions.APIStatusError)
     ),
     wait=wait_exponential(multiplier=1, min=4, max=10),
     stop=stop_after_attempt(3),

diff --git a/src/khoj/processor/conversation/openai/whisper.py b/src/khoj/processor/conversation/openai/whisper.py
index 72834d92..351319b7 100644
--- a/src/khoj/processor/conversation/openai/whisper.py
+++ b/src/khoj/processor/conversation/openai/whisper.py
@@ -3,7 +3,7 @@ from io import BufferedReader

 # External Packages
 from asgiref.sync import sync_to_async
-import openai
+from openai import OpenAI


 async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
@@ -11,5 +11,6 @@ async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
     Transcribe audio file using Whisper model via OpenAI's API
     """
     # Send the audio data to the Whisper API
-    response = await sync_to_async(openai.Audio.translate)(model=model, file=audio_file, api_key=api_key)
-    return response["text"]
+    client = OpenAI(api_key=api_key)
+    response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file)
+    return response.text

diff --git a/src/khoj/utils/models.py 
b/src/khoj/utils/models.py index b5bbe292..37d09418 100644 --- a/src/khoj/utils/models.py +++ b/src/khoj/utils/models.py @@ -32,7 +32,7 @@ class OpenAI(BaseEncoder): raise Exception( f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}" ) - openai.api_key = state.processor_config.conversation.openai_model.api_key + self.openai_client = openai.OpenAI(api_key=state.processor_config.conversation.openai_model.api_key) self.embedding_dimensions = None def encode(self, entries, device=None, **kwargs): @@ -43,7 +43,7 @@ class OpenAI(BaseEncoder): processed_entry = entries[index].replace("\n", " ") try: - response = openai.Embedding.create(input=processed_entry, model=self.model_name) + response = self.openai_client.embeddings.create(input=processed_entry, model=self.model_name) embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)] # Use current models embedding dimension, once available # Else default to embedding dimensions of the text-embedding-ada-002 model From 2b09caa237f302065fbc0d6ff14638cfc5638f2d Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 3 Dec 2023 19:13:28 -0500 Subject: [PATCH 02/15] Make online results an optional argument to the gpt converse method --- src/khoj/processor/conversation/openai/gpt.py | 2 +- src/khoj/routers/helpers.py | 4 ++-- tests/test_openai_chat_director.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index fed110f7..dc708ab7 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -123,8 +123,8 @@ def send_message_to_model( def converse( references, - online_results, user_query, + online_results=[], conversation_log={}, model: str = "gpt-3.5-turbo", api_key: Optional[str] = None, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 22b0d037..618f89ef 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -251,9 +251,9 @@ def generate_chat_response( chat_model = conversation_config.chat_model chat_response = converse( compiled_references, - online_results, q, - meta_log, + online_results=online_results, + conversation_log=meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion, diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py index 7b98d4da..50ad60af 100644 --- a/tests/test_openai_chat_director.py +++ b/tests/test_openai_chat_director.py @@ -68,10 +68,10 @@ def test_chat_with_online_content(chat_client): response_message = response_message.split("### compiled references")[0] # Assert - expected_responses = ["http://www.paulgraham.com/greatwork.html"] + expected_responses = ["http://www.paulgraham.com/greatwork.html", "Please set your SERPER_DEV_API_KEY"] assert response.status_code == 200 assert any([expected_response in response_message for expected_response in expected_responses]), ( - "Expected assistants name, [K|k]hoj, in response but got: " + response_message + "Expected links or serper not setup in response but got: " + response_message ) From 316b7d471a0344cbd47e779aa49f6ef431ed0f4f Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 4 Dec 2023 13:46:25 -0500 Subject: [PATCH 03/15] Handle offline chat model retrieval when no internet Offline chat shouldn't fail on retrieve_model when no internet, if model was previously downloaded and usable offline --- 
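Note: gpt4all.GPT4All.retrieve_model(..., allow_download=True) fetches the remote
model config, so with no internet it raises even when the model binary is already
cached locally. A rough sketch of the fallback flow this patch adds (same names
and calls as the diff below):

    try:
        chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
    except Exception as e:
        # Couldn't fetch the remote model config. Assume CPU and try the local copy.
        device = "cpu"
        logger.debug(f"Unable to download model config from gpt4all website: {e}")
    chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)
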
 src/khoj/processor/conversation/offline/utils.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py
index 0b876b26..3a1862f7 100644
--- a/src/khoj/processor/conversation/offline/utils.py
+++ b/src/khoj/processor/conversation/offline/utils.py
@@ -12,11 +12,12 @@ def download_model(model_name: str):
         logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
         raise e

-    # Download the chat model
-    chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
-
-    # Decide whether to load model to GPU or CPU
+    chat_model_config = None
     try:
+        # Download the chat model and its config
+        chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
+
+        # Try load chat model to GPU if:
         # 1. Loading chat model to GPU isn't disabled via CLI and
         # 2. Machine has GPU
         device = (
         )
     except ValueError:
         device = "cpu"
+    except Exception as e:
+        if chat_model_config is None:
+            device = "cpu"  # Fallback to CPU as can't determine if GPU has enough memory
+            logger.debug(f"Unable to download model config from gpt4all website: {e}")
+        else:
+            raise e

     # Now load the downloaded chat model onto appropriate device
     chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)

From f0222f6d0828c7c4d6bdfc90b083bfa3596d8878 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Mon, 4 Dec 2023 18:54:22 -0500
Subject: [PATCH 04/15] Make save_to_conversation_log helper function reusable

- Move it out to conversation.utils from the generate_chat_response function
- Log new optional intent_type argument to capture the type of response
  expected by Khoj, e.g. speech or image. 
It can be used to render responses by Khoj appropriately on clients - Make user_message_time argument optional, set the time to now by default if not passed by calling function --- src/khoj/processor/conversation/utils.py | 29 ++++++++++++++++++++++ src/khoj/routers/helpers.py | 31 +++--------------------- 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index e2164719..4606efd9 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -4,6 +4,7 @@ from time import perf_counter import json from datetime import datetime import queue +from typing import Any, Dict, List import tiktoken # External packages @@ -11,6 +12,8 @@ from langchain.schema import ChatMessage from transformers import AutoTokenizer # Internal Packages +from khoj.database.adapters import ConversationAdapters +from khoj.database.models import KhojUser from khoj.utils.helpers import merge_dicts @@ -89,6 +92,32 @@ def message_to_log( return conversation_log +def save_to_conversation_log( + q: str, + chat_response: str, + user: KhojUser, + meta_log: Dict, + user_message_time: str = None, + compiled_references: List[str] = [], + online_results: Dict[str, Any] = {}, + inferred_queries: List[str] = [], + intent_type: str = "remember", +): + user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") + updated_conversation = message_to_log( + user_message=q, + chat_response=chat_response, + user_message_metadata={"created": user_message_time}, + khoj_message_metadata={ + "context": compiled_references, + "intent": {"inferred-queries": inferred_queries, "type": intent_type}, + "onlineContext": online_results, + }, + conversation_log=meta_log.get("chat", []), + ) + ConversationAdapters.save_conversation(user, {"chat": updated_conversation}) + + def generate_chatml_messages_with_context( user_message, system_message, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 618f89ef..3e8ed155 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -19,7 +19,7 @@ from khoj.database.models import KhojUser, Subscription from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline from khoj.processor.conversation.openai.gpt import converse, send_message_to_model -from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log +from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log # Internal Packages from khoj.utils import state @@ -186,30 +186,7 @@ def generate_chat_response( conversation_command: ConversationCommand = ConversationCommand.Default, user: KhojUser = None, ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: - def _save_to_conversation_log( - q: str, - chat_response: str, - user_message_time: str, - compiled_references: List[str], - online_results: Dict[str, Any], - inferred_queries: List[str], - meta_log, - ): - updated_conversation = message_to_log( - user_message=q, - chat_response=chat_response, - user_message_metadata={"created": user_message_time}, - khoj_message_metadata={ - "context": compiled_references, - "intent": {"inferred-queries": inferred_queries}, - "onlineContext": online_results, - }, - conversation_log=meta_log.get("chat", []), - ) - ConversationAdapters.save_conversation(user, {"chat": updated_conversation}) - # Initialize Variables - 
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") chat_response = None logger.debug(f"Conversation Type: {conversation_command.name}") @@ -217,13 +194,13 @@ def generate_chat_response( try: partial_completion = partial( - _save_to_conversation_log, + save_to_conversation_log, q, - user_message_time=user_message_time, + user=user, + meta_log=meta_log, compiled_references=compiled_references, online_results=online_results, inferred_queries=inferred_queries, - meta_log=meta_log, ) conversation_config = ConversationAdapters.get_valid_conversation_config(user) From 1d9c1333f2881c849914bb970d4b184d3436fe0d Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 4 Dec 2023 17:54:30 -0500 Subject: [PATCH 05/15] Configure text to image models available on server - Currently supports OpenAI text to image model, by default dall-e-3 - Allow setting the text to image model via CLI during server setup --- src/khoj/database/adapters/__init__.py | 5 ++++ .../migrations/0022_texttoimagemodelconfig.py | 25 +++++++++++++++++++ src/khoj/database/models/__init__.py | 8 ++++++ src/khoj/utils/initialization.py | 10 ++++++++ 4 files changed, 48 insertions(+) create mode 100644 src/khoj/database/migrations/0022_texttoimagemodelconfig.py diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 12a127e9..f1f4031e 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -22,6 +22,7 @@ from khoj.database.models import ( GithubConfig, GithubRepoConfig, GoogleUser, + TextToImageModelConfig, KhojApiUser, KhojUser, NotionConfig, @@ -414,6 +415,10 @@ class ConversationAdapters: else: raise ValueError("Invalid conversation config - either configure offline chat or openai chat") + @staticmethod + async def aget_text_to_image_model_config(): + return await TextToImageModelConfig.objects.filter().afirst() + class EntryAdapters: word_filer = WordFilter() diff --git a/src/khoj/database/migrations/0022_texttoimagemodelconfig.py b/src/khoj/database/migrations/0022_texttoimagemodelconfig.py new file mode 100644 index 00000000..7450dc40 --- /dev/null +++ b/src/khoj/database/migrations/0022_texttoimagemodelconfig.py @@ -0,0 +1,25 @@ +# Generated by Django 4.2.7 on 2023-12-04 22:17 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0021_speechtotextmodeloptions_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="TextToImageModelConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("model_name", models.CharField(default="dall-e-3", max_length=200)), + ("model_type", models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200)), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 82348fbe..00700f2f 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -112,6 +112,14 @@ class SearchModelConfig(BaseModel): cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2") +class TextToImageModelConfig(BaseModel): + class ModelType(models.TextChoices): + OPENAI = "openai" + + model_name = models.CharField(max_length=200, default="dall-e-3") + model_type = 
models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
+
+
 class OpenAIProcessorConversationConfig(BaseModel):
     api_key = models.CharField(max_length=200)

diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py
index 313b18fc..0bb78dbe 100644
--- a/src/khoj/utils/initialization.py
+++ b/src/khoj/utils/initialization.py
@@ -7,6 +7,7 @@ from khoj.database.models import (
     OpenAIProcessorConversationConfig,
     ChatModelOptions,
     SpeechToTextModelOptions,
+    TextToImageModelConfig,
 )

 from khoj.utils.constants import default_offline_chat_model, default_online_chat_model
@@ -103,6 +104,15 @@ def initialization():
                 model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
             )

+            default_text_to_image_model = "dall-e-3"
+            openai_text_to_image_model = input(
+                f"Enter the OpenAI text to image model you want to use (default: {default_text_to_image_model}): "
+            )
+            openai_text_to_image_model = openai_text_to_image_model or default_text_to_image_model
+            TextToImageModelConfig.objects.create(
+                model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI
+            )
+
         if use_offline_model == "y" or use_openai_model == "y":
             logger.info("🗣️ Chat model configuration complete")

From 252b35b2f01e16cba56f95e73f33e9fb7a99ef48 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Mon, 4 Dec 2023 17:58:04 -0500
Subject: [PATCH 06/15] Support /image slash command to generate images using
 the chat API

---
 src/khoj/routers/api.py     |  9 ++++++++-
 src/khoj/routers/helpers.py | 35 ++++++++++++++++++++++++++++++-----
 src/khoj/utils/helpers.py   |  2 ++
 3 files changed, 40 insertions(+), 6 deletions(-)

diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index ae125980..ae31c260 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -35,12 +35,14 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
 from khoj.processor.conversation.openai.gpt import extract_questions
 from khoj.processor.conversation.openai.whisper import transcribe_audio
 from khoj.processor.conversation.prompts import help_message, no_entries_found
+from khoj.processor.conversation.utils import save_to_conversation_log
 from khoj.processor.tools.online_search import search_with_google
 from khoj.routers.helpers import (
     ApiUserRateLimiter,
     CommonQueryParams,
     agenerate_chat_response,
     get_conversation_command,
+    text_to_image,
     is_ready_to_chat,
     update_telemetry_state,
     validate_conversation_config,
@@ -665,7 +667,7 @@ async def chat(
     rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
     rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
 ) -> Response:
-    user = request.user.object
+    user: KhojUser = request.user.object
     await is_ready_to_chat(user)
     conversation_command = get_conversation_command(query=q, any_references=True)
@@ -703,6 +705,11 @@ async def chat(
             media_type="text/event-stream",
             status_code=200,
         )
+    elif conversation_command == ConversationCommand.Image:
+        image_url, status_code = await text_to_image(q)
+        await sync_to_async(save_to_conversation_log)(q, image_url, user, meta_log, intent_type="text-to-image")
+        content_obj = {"imageUrl": image_url, "intentType": "text-to-image"}
+        return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)

     # Get the (streamed) chat response from the LLM of choice. 
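    # (For /image commands the endpoint short-circuits above with a single JSON
    # body, e.g. {"imageUrl": "...", "intentType": "text-to-image"}; every other
    # command falls through to the streamed chat response built below.)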
llm_response, chat_metadata = await agenerate_chat_response(

diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 3e8ed155..4e43289f 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -9,23 +9,23 @@ from functools import partial
 from time import time
 from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union

+# External Packages
 from fastapi import Depends, Header, HTTPException, Request, UploadFile
+import openai
 from starlette.authentication import has_required_scope
-from asgiref.sync import sync_to_async
-
+# Internal Packages
 from khoj.database.adapters import ConversationAdapters, EntryAdapters
-from khoj.database.models import KhojUser, Subscription
+from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
 from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
 from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
 from khoj.utils import state
 from khoj.utils.config import GPT4AllProcessorModel
 from khoj.utils.helpers import ConversationCommand, log_telemetry

+
 logger = logging.getLogger(__name__)

 executor = ThreadPoolExecutor(max_workers=1)
@@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
         return ConversationCommand.General
     elif query.startswith("/online"):
         return ConversationCommand.Online
+    elif query.startswith("/image"):
+        return ConversationCommand.Image
     # If no relevant notes found for the given query
     elif not any_references:
         return ConversationCommand.General
@@ -248,6 +250,29 @@ def generate_chat_response(
     return chat_response, metadata


+async def text_to_image(message: str) -> Tuple[Optional[str], int]:
+    status_code = 200
+    image_url = None
+
+    # Fetch the text to image model config
+    text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
+    openai_chat_config = await ConversationAdapters.get_openai_chat_config()
+    if not text_to_image_config:
+        # If the user has not configured a text to image model, return an unprocessable entity error
+        status_code = 422
+    elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
+        client = openai.OpenAI(api_key=openai_chat_config.api_key)
+        text2image_model = text_to_image_config.model_name
+        try:
+            response = client.images.generate(prompt=message, model=text2image_model)
+            image_url = response.data[0].url
+        except openai.OpenAIError as e:
+            logger.error(f"Image Generation failed: {e}")
+            status_code = 500
+
+    return image_url, status_code
+
+
 class ApiUserRateLimiter:
     def __init__(self, requests: int, subscribed_requests: int, window: int):
         self.requests = requests

diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index 42e3835d..21fe7e98 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -273,6 +273,7 @@ class ConversationCommand(str, Enum):
     Notes = "notes"
     Help = "help"
     Online = "online"
+    Image = "image"


 command_descriptions = {
     ConversationCommand.General: "Only talk about information that relies on Khoj's general knowledge, not your personal knowledge base.",
     ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
     ConversationCommand.Default: "The default command when no command specified. 
It intelligently auto-switches between general and notes mode.", ConversationCommand.Online: "Look up information on the internet.", + ConversationCommand.Image: "Generate images by describing your imagination in words.", ConversationCommand.Help: "Display a help message with all available commands and other metadata.", } From cc051ceb4be756969a41b0431321b194d5c58490 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 4 Dec 2023 20:40:54 -0500 Subject: [PATCH 07/15] Show generated images in chat interface on Web client --- src/khoj/interface/web/chat.html | 43 +++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index a55d8e29..39cb6e77 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -183,12 +183,18 @@ To get started, just start typing below. You can also type / to see a list of co referenceSection.appendChild(polishedReference); } } - } + } return numOnlineReferences; - } + } + + function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) { + if (intentType === "text-to-image") { + let imageMarkdown = `![](${message})`; + renderMessage(imageMarkdown, by, dt); + return; + } - function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) { if (context == null && onlineContext == null) { renderMessage(message, by, dt); return; @@ -253,6 +259,26 @@ To get started, just start typing below. You can also type / to see a list of co // Remove any text between [INST] and tags. These are spurious instructions for the AI chat model. newHTML = newHTML.replace(/\[INST\].+(<\/s>)?/g, ''); + // Customize the rendering of images + md.renderer.rules.image = function(tokens, idx, options, env, self) { + let token = tokens[idx]; + + // Get image source url. Only render images with src links + let srcIndex = token.attrIndex('src'); + if (srcIndex < 0) { return ''; } + let src = token.attrs[srcIndex][1]; + + // Wrap the image in a link + var aStart = ``; + var aEnd = ''; + + // Add class="text-to-image" to images + token.attrPush(['class', 'text-to-image']); + + // Use the default renderer to render image markdown format + return aStart + self.renderToken(tokens, idx, options) + aEnd; + }; + // Render markdown newHTML = md.render(newHTML); // Get any elements with a class that starts with "language" @@ -414,6 +440,9 @@ To get started, just start typing below. You can also type / to see a list of co if (chunk.startsWith("{") && chunk.endsWith("}")) { try { const responseAsJson = JSON.parse(chunk); + if (responseAsJson.imageUrl) { + rawResponse += `![${query}](${responseAsJson.imageUrl})`; + } if (responseAsJson.detail) { rawResponse += responseAsJson.detail; } @@ -516,7 +545,7 @@ To get started, just start typing below. You can also type / to see a list of co .then(response => { // Render conversation history, if any response.forEach(chat_log => { - renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext); + renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type); }); }) .catch(err => { @@ -902,6 +931,9 @@ To get started, just start typing below. 
You can also type / to see a list of co margin-top: -10px; transform: rotate(-60deg) } + img.text-to-image { + max-width: 60%; + } #chat-footer { padding: 0; @@ -1029,6 +1061,9 @@ To get started, just start typing below. You can also type / to see a list of co margin: 4px; grid-template-columns: auto; } + img.text-to-image { + max-width: 100%; + } } @media only screen and (min-width: 700px) { body { From 8016a57b5e662dbe6f86f0013c7359c1cc1f5cf1 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 4 Dec 2023 20:56:35 -0500 Subject: [PATCH 08/15] Show generated images in chat interface on Desktop client --- src/interface/desktop/chat.html | 47 ++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 8c3fc49d..92a11ebd 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -179,7 +179,13 @@ return numOnlineReferences; } - function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) { + function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) { + if (intentType === "text-to-image") { + let imageMarkdown = `![](${message})`; + renderMessage(imageMarkdown, by, dt); + return; + } + if (context == null && onlineContext == null) { renderMessage(message, by, dt); return; @@ -244,6 +250,26 @@ // Remove any text between [INST] and tags. These are spurious instructions for the AI chat model. newHTML = newHTML.replace(/\[INST\].+(<\/s>)?/g, ''); + // Customize the rendering of images + md.renderer.rules.image = function(tokens, idx, options, env, self) { + let token = tokens[idx]; + + // Get image source url. Only render images with src links + let srcIndex = token.attrIndex('src'); + if (srcIndex < 0) { return ''; } + let src = token.attrs[srcIndex][1]; + + // Wrap the image in a link + var aStart = ``; + var aEnd = ''; + + // Add class="text-to-image" to images + token.attrPush(['class', 'text-to-image']); + + // Use the default renderer to render image markdown format + return aStart + self.renderToken(tokens, idx, options) + aEnd; + }; + // Render markdown newHTML = md.render(newHTML); // Get any elements with a class that starts with "language" @@ -404,16 +430,23 @@ if (newResponseText.getElementsByClassName("spinner").length > 0) { newResponseText.removeChild(loadingSpinner); } + // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. 
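+                        // e.g. {"detail": "<error message>"} on failure, or
+                        // {"imageUrl": "...", "intentType": "text-to-image"} for a generated image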
if (chunk.startsWith("{") && chunk.endsWith("}")) { try { const responseAsJson = JSON.parse(chunk); + if (responseAsJson.imageUrl) { + rawResponse += `![${query}](${responseAsJson.imageUrl})`; + } if (responseAsJson.detail) { - newResponseText.innerHTML += responseAsJson.detail; + rawResponse += responseAsJson.detail; } } catch (error) { // If the chunk is not a JSON object, just display it as is - newResponseText.innerHTML += chunk; + rawResponse += chunk; + } finally { + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); } } else { // If the chunk is not a JSON object, just display it as is @@ -522,7 +555,7 @@ .then(response => { // Render conversation history, if any response.forEach(chat_log => { - renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext); + renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type); }); }) .catch(err => { @@ -810,6 +843,9 @@ margin-top: -10px; transform: rotate(-60deg) } + img.text-to-image { + max-width: 60%; + } #chat-footer { padding: 0; @@ -1050,6 +1086,9 @@ margin: 4px; grid-template-columns: auto; } + img.text-to-image { + max-width: 100%; + } } @media only screen and (min-width: 600px) { body { From 52c5f4170a9b9cbab37180adadf0e73cd22b7444 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 4 Dec 2023 21:23:31 -0500 Subject: [PATCH 09/15] Show generated images in the chat modal of the Khoj Obsidian plugin --- src/interface/obsidian/src/chat_modal.ts | 31 ++++++++++++++++++++---- src/interface/obsidian/styles.css | 4 ++- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts index e5cff4d2..9786e45a 100644 --- a/src/interface/obsidian/src/chat_modal.ts +++ b/src/interface/obsidian/src/chat_modal.ts @@ -105,15 +105,19 @@ export class KhojChatModal extends Modal { return referenceButton; } - renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date) { + renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string) { if (!message) { return; + } else if (intentType === "text-to-image") { + let imageMarkdown = `![](${message})`; + this.renderMessage(chatEl, imageMarkdown, sender, dt); + return; } else if (!context) { this.renderMessage(chatEl, message, sender, dt); - return + return; } else if (!!context && context?.length === 0) { this.renderMessage(chatEl, message, sender, dt); - return + return; } let chatMessageEl = this.renderMessage(chatEl, message, sender, dt); let chatMessageBodyEl = chatMessageEl.getElementsByClassName("khoj-chat-message-text")[0] @@ -225,7 +229,7 @@ export class KhojChatModal extends Modal { let response = await request({ url: chatUrl, headers: headers }); let chatLogs = JSON.parse(response).response; chatLogs.forEach((chatLog: any) => { - this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created)); + this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created), chatLog.intent?.type); }); } @@ -267,7 +271,7 @@ export class KhojChatModal extends Modal { this.result = ""; responseElement.innerHTML = ""; for await (const chunk of response.body) { - const responseText = chunk.toString(); + let 
responseText = chunk.toString(); if (responseText.includes("### compiled references:")) { const additionalResponse = responseText.split("### compiled references:")[0]; this.renderIncrementalMessage(responseElement, additionalResponse); @@ -310,6 +314,23 @@ export class KhojChatModal extends Modal { referenceExpandButton.innerHTML = expandButtonText; references.appendChild(referenceSection); } else { + if (responseText.startsWith("{") && responseText.endsWith("}")) { + try { + const responseAsJson = JSON.parse(responseText); + if (responseAsJson.imageUrl) { + responseText = `![${query}](${responseAsJson.imageUrl})`; + } else if (responseAsJson.detail) { + responseText = responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + continue; + } + } else { + // If the chunk is not a JSON object, just display it as is + continue; + } + this.renderIncrementalMessage(responseElement, responseText); } } diff --git a/src/interface/obsidian/styles.css b/src/interface/obsidian/styles.css index 6a29e722..11fc0086 100644 --- a/src/interface/obsidian/styles.css +++ b/src/interface/obsidian/styles.css @@ -217,7 +217,9 @@ button.copy-button:hover { background: #f5f5f5; cursor: pointer; } - +img { + max-width: 60%; +} #khoj-chat-footer { padding: 0; From 6e3f66c0f12df8b38763f89b5576b679a1e52191 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 4 Dec 2023 22:55:22 -0500 Subject: [PATCH 10/15] Use base64 encoded image instead of source URL for persistence The source URL returned by OpenAI would expire soon. This would make the chat sessions contain non-accessible images/messages if using OpenaI image URL Get base64 encoded image from OpenAI and store directly in conversation logs. This resolves the image link expiring issue --- src/interface/desktop/chat.html | 17 ++++------------- src/interface/obsidian/src/chat_modal.ts | 6 +++--- src/khoj/interface/web/chat.html | 17 ++++------------- src/khoj/routers/api.py | 6 +++--- src/khoj/routers/helpers.py | 8 ++++---- 5 files changed, 18 insertions(+), 36 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 92a11ebd..e039c6cb 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -181,7 +181,7 @@ function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) { if (intentType === "text-to-image") { - let imageMarkdown = `![](${message})`; + let imageMarkdown = `![](data:image/png;base64,${message})`; renderMessage(imageMarkdown, by, dt); return; } @@ -254,20 +254,11 @@ md.renderer.rules.image = function(tokens, idx, options, env, self) { let token = tokens[idx]; - // Get image source url. 
Only render images with src links - let srcIndex = token.attrIndex('src'); - if (srcIndex < 0) { return ''; } - let src = token.attrs[srcIndex][1]; - - // Wrap the image in a link - var aStart = ``; - var aEnd = ''; - // Add class="text-to-image" to images token.attrPush(['class', 'text-to-image']); // Use the default renderer to render image markdown format - return aStart + self.renderToken(tokens, idx, options) + aEnd; + return self.renderToken(tokens, idx, options); }; // Render markdown @@ -435,8 +426,8 @@ if (chunk.startsWith("{") && chunk.endsWith("}")) { try { const responseAsJson = JSON.parse(chunk); - if (responseAsJson.imageUrl) { - rawResponse += `![${query}](${responseAsJson.imageUrl})`; + if (responseAsJson.image) { + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; } if (responseAsJson.detail) { rawResponse += responseAsJson.detail; diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts index 9786e45a..145bae50 100644 --- a/src/interface/obsidian/src/chat_modal.ts +++ b/src/interface/obsidian/src/chat_modal.ts @@ -109,7 +109,7 @@ export class KhojChatModal extends Modal { if (!message) { return; } else if (intentType === "text-to-image") { - let imageMarkdown = `![](${message})`; + let imageMarkdown = `![](data:image/png;base64,${message})`; this.renderMessage(chatEl, imageMarkdown, sender, dt); return; } else if (!context) { @@ -317,8 +317,8 @@ export class KhojChatModal extends Modal { if (responseText.startsWith("{") && responseText.endsWith("}")) { try { const responseAsJson = JSON.parse(responseText); - if (responseAsJson.imageUrl) { - responseText = `![${query}](${responseAsJson.imageUrl})`; + if (responseAsJson.image) { + responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`; } else if (responseAsJson.detail) { responseText = responseAsJson.detail; } diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 39cb6e77..97fdbebb 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -190,7 +190,7 @@ To get started, just start typing below. You can also type / to see a list of co function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) { if (intentType === "text-to-image") { - let imageMarkdown = `![](${message})`; + let imageMarkdown = `![](data:image/png;base64,${message})`; renderMessage(imageMarkdown, by, dt); return; } @@ -263,20 +263,11 @@ To get started, just start typing below. You can also type / to see a list of co md.renderer.rules.image = function(tokens, idx, options, env, self) { let token = tokens[idx]; - // Get image source url. Only render images with src links - let srcIndex = token.attrIndex('src'); - if (srcIndex < 0) { return ''; } - let src = token.attrs[srcIndex][1]; - - // Wrap the image in a link - var aStart = ``; - var aEnd = ''; - // Add class="text-to-image" to images token.attrPush(['class', 'text-to-image']); // Use the default renderer to render image markdown format - return aStart + self.renderToken(tokens, idx, options) + aEnd; + return self.renderToken(tokens, idx, options); }; // Render markdown @@ -440,8 +431,8 @@ To get started, just start typing below. 
You can also type / to see a list of co
                     if (chunk.startsWith("{") && chunk.endsWith("}")) {
                         try {
                             const responseAsJson = JSON.parse(chunk);
-                            if (responseAsJson.imageUrl) {
-                                rawResponse += `![${query}](${responseAsJson.imageUrl})`;
+                            if (responseAsJson.image) {
+                                rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
                             }
                             if (responseAsJson.detail) {
                                 rawResponse += responseAsJson.detail;

diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index ae31c260..d53f023a 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -706,9 +706,9 @@ async def chat(
             status_code=200,
         )
     elif conversation_command == ConversationCommand.Image:
-        image_url, status_code = await text_to_image(q)
-        await sync_to_async(save_to_conversation_log)(q, image_url, user, meta_log, intent_type="text-to-image")
-        content_obj = {"imageUrl": image_url, "intentType": "text-to-image"}
+        image, status_code = await text_to_image(q)
+        await sync_to_async(save_to_conversation_log)(q, image, user, meta_log, intent_type="text-to-image")
+        content_obj = {"image": image, "intentType": "text-to-image"}
         return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)

     # Get the (streamed) chat response from the LLM of choice.

diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 4e43289f..f34ae815 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -252,7 +252,7 @@ def generate_chat_response(

 async def text_to_image(message: str) -> Tuple[Optional[str], int]:
     status_code = 200
-    image_url = None
+    image = None

     # Fetch the text to image model config
     text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
@@ -264,13 +264,13 @@ async def text_to_image(message: str) -> Tuple[Optional[str], int]:
         client = openai.OpenAI(api_key=openai_chat_config.api_key)
         text2image_model = text_to_image_config.model_name
         try:
-            response = client.images.generate(prompt=message, model=text2image_model)
-            image_url = response.data[0].url
+            response = client.images.generate(prompt=message, model=text2image_model, response_format="b64_json")
+            image = response.data[0].b64_json
         except openai.OpenAIError as e:
             logger.error(f"Image Generation failed: {e}")
             status_code = 500

-    return image_url, status_code
+    return image, status_code

From d124266923076135909e0055e9eb725fe7d78591 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Tue, 5 Dec 2023 00:41:16 -0500
Subject: [PATCH 11/15] Reduce promise based nesting in chat JS func used in
 desktop, web client

Use async/await to reduce .then() based nesting to improve code readability

---
 src/interface/desktop/chat.html  | 196 +++++++++++++++---------------
 src/khoj/interface/web/chat.html | 200 +++++++++++++++----------------
 2 files changed, 196 insertions(+), 200 deletions(-)

diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html
index e039c6cb..279397bd 100644
--- a/src/interface/desktop/chat.html
+++ b/src/interface/desktop/chat.html
@@ -346,115 +346,113 @@
             chatInput.classList.remove("option-enabled");

             // Call specified Khoj API which returns a streamed response of type text/plain
-            fetch(url, { headers })
-            .then(response => {
-                const reader = response.body.getReader();
-                const decoder = new TextDecoder();
-                let rawResponse = "";
-                let references = null;
+            let response = await fetch(url, { headers });
+            const reader = response.body.getReader();
+            const decoder = new 
TextDecoder(); + let rawResponse = ""; + let references = null; - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - if (references != null) { - newResponseText.appendChild(references); - } - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input").removeAttribute("disabled"); - return; + function readStream() { + reader.read().then(({ done, value }) => { + if (done) { + // Append any references after all the data has been streamed + if (references != null) { + newResponseText.appendChild(references); + } + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + document.getElementById("chat-input").removeAttribute("disabled"); + return; + } + + // Decode message chunk from stream + const chunk = decoder.decode(value, { stream: true }); + + if (chunk.includes("### compiled references:")) { + const additionalResponse = chunk.split("### compiled references:")[0]; + rawResponse += additionalResponse; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + + const rawReference = chunk.split("### compiled references:")[1]; + const rawReferenceAsJson = JSON.parse(rawReference); + references = document.createElement('div'); + references.classList.add("references"); + + let referenceExpandButton = document.createElement('button'); + referenceExpandButton.classList.add("reference-expand-button"); + + let referenceSection = document.createElement('div'); + referenceSection.classList.add("reference-section"); + referenceSection.classList.add("collapsed"); + + let numReferences = 0; + + // If rawReferenceAsJson is a list, then count the length + if (Array.isArray(rawReferenceAsJson)) { + numReferences = rawReferenceAsJson.length; + + rawReferenceAsJson.forEach((reference, index) => { + let polishedReference = generateReference(reference, index); + referenceSection.appendChild(polishedReference); + }); + } else { + numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); + } + + references.appendChild(referenceExpandButton); + + referenceExpandButton.addEventListener('click', function() { + if (referenceSection.classList.contains("collapsed")) { + referenceSection.classList.remove("collapsed"); + referenceSection.classList.add("expanded"); + } else { + referenceSection.classList.add("collapsed"); + referenceSection.classList.remove("expanded"); } + }); - // Decode message chunk from stream - const chunk = decoder.decode(value, { stream: true }); + let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; + referenceExpandButton.innerHTML = expandButtonText; + references.appendChild(referenceSection); + readStream(); + } else { + // Display response from Khoj + if (newResponseText.getElementsByClassName("spinner").length > 0) { + newResponseText.removeChild(loadingSpinner); + } - if (chunk.includes("### compiled references:")) { - const additionalResponse = chunk.split("### compiled references:")[0]; - rawResponse += additionalResponse; + // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. 
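+                            // Note: only image responses and error messages arrive as single JSON
+                            // chunks; regular chat replies stream in as plain text.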
+ if (chunk.startsWith("{") && chunk.endsWith("}")) { + try { + const responseAsJson = JSON.parse(chunk); + if (responseAsJson.image) { + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + } + if (responseAsJson.detail) { + rawResponse += responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + } finally { newResponseText.innerHTML = ""; newResponseText.appendChild(formatHTMLMessage(rawResponse)); - - const rawReference = chunk.split("### compiled references:")[1]; - const rawReferenceAsJson = JSON.parse(rawReference); - references = document.createElement('div'); - references.classList.add("references"); - - let referenceExpandButton = document.createElement('button'); - referenceExpandButton.classList.add("reference-expand-button"); - - let referenceSection = document.createElement('div'); - referenceSection.classList.add("reference-section"); - referenceSection.classList.add("collapsed"); - - let numReferences = 0; - - // If rawReferenceAsJson is a list, then count the length - if (Array.isArray(rawReferenceAsJson)) { - numReferences = rawReferenceAsJson.length; - - rawReferenceAsJson.forEach((reference, index) => { - let polishedReference = generateReference(reference, index); - referenceSection.appendChild(polishedReference); - }); - } else { - numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); - } - - references.appendChild(referenceExpandButton); - - referenceExpandButton.addEventListener('click', function() { - if (referenceSection.classList.contains("collapsed")) { - referenceSection.classList.remove("collapsed"); - referenceSection.classList.add("expanded"); - } else { - referenceSection.classList.add("collapsed"); - referenceSection.classList.remove("expanded"); - } - }); - - let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; - referenceExpandButton.innerHTML = expandButtonText; - references.appendChild(referenceSection); - readStream(); - } else { - // Display response from Khoj - if (newResponseText.getElementsByClassName("spinner").length > 0) { - newResponseText.removeChild(loadingSpinner); - } - - // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. 
- if (chunk.startsWith("{") && chunk.endsWith("}")) { - try { - const responseAsJson = JSON.parse(chunk); - if (responseAsJson.image) { - rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; - } - if (responseAsJson.detail) { - rawResponse += responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - } - } else { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - - readStream(); - } } + } else { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }); + readStream(); + } } - readStream(); + + // Scroll to bottom of chat window as chat response is streamed + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; }); + } + readStream(); } function incrementalChat(event) { diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 97fdbebb..4a406a6b 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -309,7 +309,7 @@ To get started, just start typing below. You can also type / to see a list of co return element } - function chat() { + async function chat() { // Extract required fields for search from form let query = document.getElementById("chat-input").value.trim(); let resultsCount = localStorage.getItem("khojResultsCount") || 5; @@ -351,115 +351,113 @@ To get started, just start typing below. 
You can also type / to see a list of co chatInput.classList.remove("option-enabled"); // Call specified Khoj API which returns a streamed response of type text/plain - fetch(url) - .then(response => { - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let rawResponse = ""; - let references = null; + let response = await fetch(url); + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let rawResponse = ""; + let references = null; - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - if (references != null) { - newResponseText.appendChild(references); - } - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input").removeAttribute("disabled"); - return; + function readStream() { + reader.read().then(({ done, value }) => { + if (done) { + // Append any references after all the data has been streamed + if (references != null) { + newResponseText.appendChild(references); + } + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + document.getElementById("chat-input").removeAttribute("disabled"); + return; + } + + // Decode message chunk from stream + const chunk = decoder.decode(value, { stream: true }); + + if (chunk.includes("### compiled references:")) { + const additionalResponse = chunk.split("### compiled references:")[0]; + rawResponse += additionalResponse; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + + const rawReference = chunk.split("### compiled references:")[1]; + const rawReferenceAsJson = JSON.parse(rawReference); + references = document.createElement('div'); + references.classList.add("references"); + + let referenceExpandButton = document.createElement('button'); + referenceExpandButton.classList.add("reference-expand-button"); + + let referenceSection = document.createElement('div'); + referenceSection.classList.add("reference-section"); + referenceSection.classList.add("collapsed"); + + let numReferences = 0; + + // If rawReferenceAsJson is a list, then count the length + if (Array.isArray(rawReferenceAsJson)) { + numReferences = rawReferenceAsJson.length; + + rawReferenceAsJson.forEach((reference, index) => { + let polishedReference = generateReference(reference, index); + referenceSection.appendChild(polishedReference); + }); + } else { + numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); + } + + references.appendChild(referenceExpandButton); + + referenceExpandButton.addEventListener('click', function() { + if (referenceSection.classList.contains("collapsed")) { + referenceSection.classList.remove("collapsed"); + referenceSection.classList.add("expanded"); + } else { + referenceSection.classList.add("collapsed"); + referenceSection.classList.remove("expanded"); } + }); - // Decode message chunk from stream - const chunk = decoder.decode(value, { stream: true }); + let expandButtonText = numReferences == 1 ? 
"1 reference" : `${numReferences} references`; + referenceExpandButton.innerHTML = expandButtonText; + references.appendChild(referenceSection); + readStream(); + } else { + // Display response from Khoj + if (newResponseText.getElementsByClassName("spinner").length > 0) { + newResponseText.removeChild(loadingSpinner); + } - if (chunk.includes("### compiled references:")) { - const additionalResponse = chunk.split("### compiled references:")[0]; - rawResponse += additionalResponse; + // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. + if (chunk.startsWith("{") && chunk.endsWith("}")) { + try { + const responseAsJson = JSON.parse(chunk); + if (responseAsJson.image) { + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + } + if (responseAsJson.detail) { + rawResponse += responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + } finally { newResponseText.innerHTML = ""; newResponseText.appendChild(formatHTMLMessage(rawResponse)); - - const rawReference = chunk.split("### compiled references:")[1]; - const rawReferenceAsJson = JSON.parse(rawReference); - references = document.createElement('div'); - references.classList.add("references"); - - let referenceExpandButton = document.createElement('button'); - referenceExpandButton.classList.add("reference-expand-button"); - - let referenceSection = document.createElement('div'); - referenceSection.classList.add("reference-section"); - referenceSection.classList.add("collapsed"); - - let numReferences = 0; - - // If rawReferenceAsJson is a list, then count the length - if (Array.isArray(rawReferenceAsJson)) { - numReferences = rawReferenceAsJson.length; - - rawReferenceAsJson.forEach((reference, index) => { - let polishedReference = generateReference(reference, index); - referenceSection.appendChild(polishedReference); - }); - } else { - numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); - } - - references.appendChild(referenceExpandButton); - - referenceExpandButton.addEventListener('click', function() { - if (referenceSection.classList.contains("collapsed")) { - referenceSection.classList.remove("collapsed"); - referenceSection.classList.add("expanded"); - } else { - referenceSection.classList.add("collapsed"); - referenceSection.classList.remove("expanded"); - } - }); - - let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; - referenceExpandButton.innerHTML = expandButtonText; - references.appendChild(referenceSection); - readStream(); - } else { - // Display response from Khoj - if (newResponseText.getElementsByClassName("spinner").length > 0) { - newResponseText.removeChild(loadingSpinner); - } - - // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. 
- if (chunk.startsWith("{") && chunk.endsWith("}")) { - try { - const responseAsJson = JSON.parse(chunk); - if (responseAsJson.image) { - rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; - } - if (responseAsJson.detail) { - rawResponse += responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - } - } else { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - readStream(); - } } - - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }); + } else { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + readStream(); + } } - readStream(); + + // Scroll to bottom of chat window as chat response is streamed + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; }); - } + } + readStream(); + }; function incrementalChat(event) { if (!event.shiftKey && event.key === 'Enter') { From 8f2f053968623469e34f88d35cb076827b79fe31 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 5 Dec 2023 01:03:52 -0500 Subject: [PATCH 12/15] Fix rendering image on chat response in web, desktop client --- src/interface/desktop/chat.html | 214 +++++++++++++++++-------------- src/khoj/interface/web/chat.html | 171 ++++++++++++------------ 2 files changed, 211 insertions(+), 174 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 279397bd..cad7971f 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -345,114 +345,142 @@ let chatInput = document.getElementById("chat-input"); chatInput.classList.remove("option-enabled"); - // Call specified Khoj API which returns a streamed response of type text/plain + // Call specified Khoj API let response = await fetch(url, { headers }); - const reader = response.body.getReader(); - const decoder = new TextDecoder(); let rawResponse = ""; - let references = null; + const contentType = response.headers.get("content-type"); - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - if (references != null) { - newResponseText.appendChild(references); - } - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input").removeAttribute("disabled"); - return; + if (contentType === "application/json") { + // Handle JSON response + try { + const responseAsJson = await response.json(); + if (responseAsJson.image) { + // If response has image field, response is a generated image. + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; } + if (responseAsJson.detail) { + // If response has detail field, response is an error message. 
+                            rawResponse += responseAsJson.detail;
+                        }
+                    } catch (error) {
+                        // If the response couldn't be parsed as JSON, surface the parse error
+                        rawResponse += `Error parsing response: ${error}`;
+                    } finally {
+                        newResponseText.innerHTML = "";
+                        newResponseText.appendChild(formatHTMLMessage(rawResponse));

-                    // Decode message chunk from stream
-                    const chunk = decoder.decode(value, { stream: true });
+                        document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+                        document.getElementById("chat-input").removeAttribute("disabled");
+                    }
+                } else {
+                    // Handle streamed response of type text/event-stream or text/plain
+                    const reader = response.body.getReader();
+                    const decoder = new TextDecoder();
+                    let references = null;

-                    if (chunk.includes("### compiled references:")) {
-                        const additionalResponse = chunk.split("### compiled references:")[0];
-                        rawResponse += additionalResponse;
-                        newResponseText.innerHTML = "";
-                        newResponseText.appendChild(formatHTMLMessage(rawResponse));
+                    readStream();

-                        const rawReference = chunk.split("### compiled references:")[1];
-                        const rawReferenceAsJson = JSON.parse(rawReference);
-                        references = document.createElement('div');
-                        references.classList.add("references");
-
-                        let referenceExpandButton = document.createElement('button');
-                        referenceExpandButton.classList.add("reference-expand-button");
-
-                        let referenceSection = document.createElement('div');
-                        referenceSection.classList.add("reference-section");
-                        referenceSection.classList.add("collapsed");
-
-                        let numReferences = 0;
-
-                        // If rawReferenceAsJson is a list, then count the length
-                        if (Array.isArray(rawReferenceAsJson)) {
-                            numReferences = rawReferenceAsJson.length;
-
-                            rawReferenceAsJson.forEach((reference, index) => {
-                                let polishedReference = generateReference(reference, index);
-                                referenceSection.appendChild(polishedReference);
-                            });
-                        } else {
-                            numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
+                    function readStream() {
+                        reader.read().then(({ done, value }) => {
+                            if (done) {
+                                // Append any references after all the data has been streamed
+                                if (references != null) {
+                                    newResponseText.appendChild(references);
+                                }
+                                document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+                                document.getElementById("chat-input").removeAttribute("disabled");
+                                return;
                         }

-                        references.appendChild(referenceExpandButton);
+                            // Decode message chunk from stream
+                            const chunk = decoder.decode(value, { stream: true });

-                        referenceExpandButton.addEventListener('click', function() {
-                            if (referenceSection.classList.contains("collapsed")) {
-                                referenceSection.classList.remove("collapsed");
-                                referenceSection.classList.add("expanded");
+                            if (chunk.includes("### compiled references:")) {
+                                const additionalResponse = chunk.split("### compiled references:")[0];
+                                rawResponse += additionalResponse;
+                                newResponseText.innerHTML = "";
+                                newResponseText.appendChild(formatHTMLMessage(rawResponse));
+
+                                const rawReference = chunk.split("### compiled references:")[1];
+                                const rawReferenceAsJson = JSON.parse(rawReference);
+                                references = document.createElement('div');
+                                references.classList.add("references");
+
+                                let referenceExpandButton = document.createElement('button');
+                                referenceExpandButton.classList.add("reference-expand-button");
+
+                                let referenceSection = document.createElement('div');
+                                referenceSection.classList.add("reference-section");
+                                referenceSection.classList.add("collapsed");
+
+                                let numReferences = 0;
+
+                                // If rawReferenceAsJson is a list, then count the length
+                                if (Array.isArray(rawReferenceAsJson)) {
+                                    numReferences = rawReferenceAsJson.length;
+
+                                    rawReferenceAsJson.forEach((reference, index) => {
+                                        let polishedReference = generateReference(reference, index);
+                                        referenceSection.appendChild(polishedReference);
+                                    });
                             } else {
-                                referenceSection.classList.add("collapsed");
-                                referenceSection.classList.remove("expanded");
+                                numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
                             }
-                        });
-
-                        let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
-                        referenceExpandButton.innerHTML = expandButtonText;
-                        references.appendChild(referenceSection);
-                        readStream();
-                    } else {
-                        // Display response from Khoj
-                        if (newResponseText.getElementsByClassName("spinner").length > 0) {
-                            newResponseText.removeChild(loadingSpinner);
-                        }
-
-                        // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. 
- if (chunk.startsWith("{") && chunk.endsWith("}")) { - try { - const responseAsJson = JSON.parse(chunk); - if (responseAsJson.image) { - rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; - } - if (responseAsJson.detail) { - rawResponse += responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - } - } else { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; + if (chunk.includes("### compiled references:")) { + const additionalResponse = chunk.split("### compiled references:")[0]; + rawResponse += additionalResponse; newResponseText.innerHTML = ""; newResponseText.appendChild(formatHTMLMessage(rawResponse)); - readStream(); - } - } + const rawReference = chunk.split("### compiled references:")[1]; + const rawReferenceAsJson = JSON.parse(rawReference); + references = document.createElement('div'); + references.classList.add("references"); - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }); + let referenceExpandButton = document.createElement('button'); + referenceExpandButton.classList.add("reference-expand-button"); + + let referenceSection = document.createElement('div'); + referenceSection.classList.add("reference-section"); + referenceSection.classList.add("collapsed"); + + let numReferences = 0; + + // If rawReferenceAsJson is a list, then count the length + if (Array.isArray(rawReferenceAsJson)) { + numReferences = rawReferenceAsJson.length; + + rawReferenceAsJson.forEach((reference, index) => { + let polishedReference = generateReference(reference, index); + referenceSection.appendChild(polishedReference); + }); + } else { + numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); + } + + references.appendChild(referenceExpandButton); + + referenceExpandButton.addEventListener('click', function() { + if (referenceSection.classList.contains("collapsed")) { + referenceSection.classList.remove("collapsed"); + referenceSection.classList.add("expanded"); + } else { + referenceSection.classList.add("collapsed"); + referenceSection.classList.remove("expanded"); + } + }); + + let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; + referenceExpandButton.innerHTML = expandButtonText; + references.appendChild(referenceSection); + readStream(); + } else { + // Display response from Khoj + if (newResponseText.getElementsByClassName("spinner").length > 0) { + newResponseText.removeChild(loadingSpinner); + } + + // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. 
+ if (chunk.startsWith("{") && chunk.endsWith("}")) { + try { + const responseAsJson = JSON.parse(chunk); + if (responseAsJson.image) { + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + } + if (responseAsJson.detail) { + rawResponse += responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + } finally { + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + } + } else { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + + readStream(); + } + } + + // Scroll to bottom of chat window as chat response is streamed + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + }); + } } - readStream(); } function incrementalChat(event) { diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 4a406a6b..c243ed67 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -350,113 +350,122 @@ To get started, just start typing below. You can also type / to see a list of co let chatInput = document.getElementById("chat-input"); chatInput.classList.remove("option-enabled"); - // Call specified Khoj API which returns a streamed response of type text/plain + // Call specified Khoj API let response = await fetch(url); - const reader = response.body.getReader(); - const decoder = new TextDecoder(); let rawResponse = ""; - let references = null; + const contentType = response.headers.get("content-type"); - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - if (references != null) { - newResponseText.appendChild(references); - } - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input").removeAttribute("disabled"); - return; + if (contentType === "application/json") { + // Handle JSON response + try { + const responseAsJson = await response.json(); + if (responseAsJson.image) { + // If response has image field, response is a generated image. + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; } + if (responseAsJson.detail) { + // If response has detail field, response is an error message. 
+                        rawResponse += responseAsJson.detail;
+                    }
+                } catch (error) {
+                    // If the response couldn't be parsed as JSON, surface the parse error
+                    rawResponse += `Error parsing response: ${error}`;
+                } finally {
+                    newResponseText.innerHTML = "";
+                    newResponseText.appendChild(formatHTMLMessage(rawResponse));

-                // Decode message chunk from stream
-                const chunk = decoder.decode(value, { stream: true });
+                    document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+                    document.getElementById("chat-input").removeAttribute("disabled");
+                }
+            } else {
+                // Handle streamed response of type text/event-stream or text/plain
+                const reader = response.body.getReader();
+                const decoder = new TextDecoder();
+                let references = null;

-                if (chunk.includes("### compiled references:")) {
-                    const additionalResponse = chunk.split("### compiled references:")[0];
-                    rawResponse += additionalResponse;
-                    newResponseText.innerHTML = "";
-                    newResponseText.appendChild(formatHTMLMessage(rawResponse));
+                readStream();

-                    const rawReference = chunk.split("### compiled references:")[1];
-                    const rawReferenceAsJson = JSON.parse(rawReference);
-                    references = document.createElement('div');
-                    references.classList.add("references");
-
-                    let referenceExpandButton = document.createElement('button');
-                    referenceExpandButton.classList.add("reference-expand-button");
-
-                    let referenceSection = document.createElement('div');
-                    referenceSection.classList.add("reference-section");
-                    referenceSection.classList.add("collapsed");
-
-                    let numReferences = 0;
-
-                    // If rawReferenceAsJson is a list, then count the length
-                    if (Array.isArray(rawReferenceAsJson)) {
-                        numReferences = rawReferenceAsJson.length;
-
-                        rawReferenceAsJson.forEach((reference, index) => {
-                            let polishedReference = generateReference(reference, index);
-                            referenceSection.appendChild(polishedReference);
-                        });
-                    } else {
-                        numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
+                function readStream() {
+                    reader.read().then(({ done, value }) => {
+                        if (done) {
+                            // Append any references after all the data has been streamed
+                            if (references != null) {
+                                newResponseText.appendChild(references);
+                            }
+                            document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+                            document.getElementById("chat-input").removeAttribute("disabled");
+                            return;
                    }

-                    references.appendChild(referenceExpandButton);
+                        // Decode message chunk from stream
+                        const chunk = decoder.decode(value, { stream: true });

-                    referenceExpandButton.addEventListener('click', function() {
-                        if (referenceSection.classList.contains("collapsed")) {
-                            referenceSection.classList.remove("collapsed");
-                            referenceSection.classList.add("expanded");
+                        if (chunk.includes("### compiled references:")) {
+                            const additionalResponse = chunk.split("### compiled references:")[0];
+                            rawResponse += additionalResponse;
+                            newResponseText.innerHTML = "";
+                            newResponseText.appendChild(formatHTMLMessage(rawResponse));
+
+                            const rawReference = chunk.split("### compiled references:")[1];
+                            const rawReferenceAsJson = JSON.parse(rawReference);
+                            references = document.createElement('div');
+                            references.classList.add("references");
+
+                            let referenceExpandButton = document.createElement('button');
+                            referenceExpandButton.classList.add("reference-expand-button");
+
+                            let referenceSection = document.createElement('div');
+                            referenceSection.classList.add("reference-section");
+                            referenceSection.classList.add("collapsed");
+
+                            let numReferences = 0;
+
+                            // If rawReferenceAsJson is a list, then count the 
length + if (Array.isArray(rawReferenceAsJson)) { + numReferences = rawReferenceAsJson.length; + + rawReferenceAsJson.forEach((reference, index) => { + let polishedReference = generateReference(reference, index); + referenceSection.appendChild(polishedReference); + }); } else { - referenceSection.classList.add("collapsed"); - referenceSection.classList.remove("expanded"); + numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); } - }); - let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; - referenceExpandButton.innerHTML = expandButtonText; - references.appendChild(referenceSection); - readStream(); - } else { - // Display response from Khoj - if (newResponseText.getElementsByClassName("spinner").length > 0) { - newResponseText.removeChild(loadingSpinner); - } + references.appendChild(referenceExpandButton); - // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. - if (chunk.startsWith("{") && chunk.endsWith("}")) { - try { - const responseAsJson = JSON.parse(chunk); - if (responseAsJson.image) { - rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + referenceExpandButton.addEventListener('click', function() { + if (referenceSection.classList.contains("collapsed")) { + referenceSection.classList.remove("collapsed"); + referenceSection.classList.add("expanded"); + } else { + referenceSection.classList.add("collapsed"); + referenceSection.classList.remove("expanded"); } - if (responseAsJson.detail) { - rawResponse += responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - } + }); + + let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; + referenceExpandButton.innerHTML = expandButtonText; + references.appendChild(referenceSection); + readStream(); } else { + // Display response from Khoj + if (newResponseText.getElementsByClassName("spinner").length > 0) { + newResponseText.removeChild(loadingSpinner); + } + // If the chunk is not a JSON object, just display it as is rawResponse += chunk; newResponseText.innerHTML = ""; newResponseText.appendChild(formatHTMLMessage(rawResponse)); readStream(); } - } + }); // Scroll to bottom of chat window as chat response is streamed document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }); + }; } - readStream(); }; function incrementalChat(event) { From 162b219f2bbd6fabbf90bae5503646e21925010a Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 5 Dec 2023 01:29:36 -0500 Subject: [PATCH 13/15] Throw unsupported error when server not configured for image, speech-to-text --- src/interface/desktop/chat.html | 10 +++++++--- src/interface/obsidian/src/chat_modal.ts | 8 +++++--- src/khoj/interface/web/chat.html | 10 +++++++--- src/khoj/routers/api.py | 4 ++-- src/khoj/routers/helpers.py | 4 ++-- 5 files changed, 23 insertions(+), 13 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index cad7971f..6b7fde07 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -675,9 +675,13 @@ .then(response => response.ok ? response.json() : Promise.reject(response)) .then(data => { chatInput.value += data.text; }) .catch(err => { - err.status == 422 - ? 
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
-                : flashStatusInChatInput("⛔️ Failed to transcribe audio")
+                if (err.status === 501) {
+                    flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
+                } else if (err.status === 422) {
+                    flashStatusInChatInput("⛔️ Audio file too large to process.")
+                } else {
+                    flashStatusInChatInput("⛔️ Failed to transcribe audio.")
+                }
             });
     };

diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts
index 145bae50..09f7a181 100644
--- a/src/interface/obsidian/src/chat_modal.ts
+++ b/src/interface/obsidian/src/chat_modal.ts
@@ -410,10 +410,12 @@ export class KhojChatModal extends Modal {
         if (response.status === 200) {
             console.log(response);
             chatInput.value += response.json.text;
-        } else if (response.status === 422) {
-            throw new Error("⛔️ Failed to transcribe audio");
-        } else {
+        } else if (response.status === 501) {
             throw new Error("⛔️ Configure speech-to-text model on server.");
+        } else if (response.status === 422) {
+            throw new Error("⛔️ Audio file too large to process.");
+        } else {
+            throw new Error("⛔️ Failed to transcribe audio.");
         }
     };

diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index c243ed67..e85759fb 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -638,9 +638,13 @@ To get started, just start typing below. You can also type / to see a list of co
             .then(response => response.ok ? response.json() : Promise.reject(response))
             .then(data => { chatInput.value += data.text; })
             .catch(err => {
-                err.status == 422
-                ? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
-                : flashStatusInChatInput("⛔️ Failed to transcribe audio")
+                if (err.status === 501) {
+                    flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
+                } else if (err.status === 422) {
+                    flashStatusInChatInput("⛔️ Audio file too large to process.")
+                } else {
+                    flashStatusInChatInput("⛔️ Failed to transcribe audio.")
+                }
             });
     };

diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index d53f023a..4d2c80a2 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -626,8 +626,8 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi
     speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
     openai_chat_config = await ConversationAdapters.get_openai_chat_config()
     if not speech_to_text_config:
-        # If the user has not configured a speech to text model, return an unprocessable entity error
-        status_code = 422
+        # If the user has not configured a speech to text model, return a 501 Not Implemented error
+        status_code = 501
     elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
         api_key = openai_chat_config.api_key
         speech2text_model = speech_to_text_config.model_name
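The status-code convention these changes settle on: 501 Not Implemented when the server lacks the model configuration for a feature, and 422 Unprocessable Entity when the request itself cannot be processed. A minimal FastAPI sketch of that convention, for illustration only (the /api/transcribe route shape, MAX_AUDIO_BYTES limit, and SPEECH_TO_TEXT_CONFIGURED flag are hypothetical stand-ins, not Khoj's actual code):

    # Sketch: 501 for missing server-side configuration, 422 for an unprocessable request
    from fastapi import FastAPI, Response, UploadFile

    app = FastAPI()
    MAX_AUDIO_BYTES = 10 * 1024 * 1024   # hypothetical size limit
    SPEECH_TO_TEXT_CONFIGURED = False    # stand-in for the server's model config lookup

    @app.post("/api/transcribe")
    async def transcribe(file: UploadFile, response: Response):
        if not SPEECH_TO_TEXT_CONFIGURED:
            # Feature not set up on this server: 501 Not Implemented
            response.status_code = 501
            return {"detail": "Configure speech-to-text model on server."}
        audio = await file.read()
        if len(audio) > MAX_AUDIO_BYTES:
            # Request understood but not processable: 422 Unprocessable Entity
            response.status_code = 422
            return {"detail": "Audio file too large to process."}
        return {"text": "..."}  # transcription result would go here

Distinguishing the two lets clients show "configure the server" guidance separately from bad-input errors, as the web, desktop, and Obsidian handlers above now do.

diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index f34ae815..a780eb20 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -258,8 +258,8 @@ async def text_to_image(message: str) -> Tuple[Optional[str], int]:
     text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
     openai_chat_config = await ConversationAdapters.get_openai_chat_config()
     if not text_to_image_config:
-        # If the user has not configured a text to image model, return an unprocessable entity error
-        status_code = 422
+        # If the user has not configured a text to image model, return a 501 Not Implemented error
+        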
status_code = 501
     elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
         client = openai.OpenAI(api_key=openai_chat_config.api_key)
         text2image_model = text_to_image_config.model_name

From 408b7413e9ef75fbc1d0a5e2f453119b4b8ea131 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky <debanjum@gmail.com>
Date: Tue, 5 Dec 2023 02:40:28 -0500
Subject: [PATCH 14/15] Use global openai client for transcribe, image

---
 src/khoj/configure.py                              |  6 ++++++
 .../processor/conversation/offline/chat_model.py   |  2 +-
 src/khoj/processor/conversation/openai/gpt.py      |  2 +-
 src/khoj/processor/conversation/openai/whisper.py  |  5 ++---
 src/khoj/routers/api.py                            | 15 ++++++---------
 src/khoj/routers/helpers.py                        |  8 ++++----
 src/khoj/utils/models.py                           | 12 ++----------
 src/khoj/utils/state.py                            |  8 +++++---
 8 files changed, 27 insertions(+), 31 deletions(-)

diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index 19e7d403..717ad859 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -7,6 +7,7 @@ import requests
 import os
 
 # External Packages
+import openai
 import schedule
 from starlette.middleware.sessions import SessionMiddleware
 from starlette.middleware.authentication import AuthenticationMiddleware
@@ -22,6 +23,7 @@ from starlette.authentication import (
 # Internal Packages
 from khoj.database.models import KhojUser, Subscription
 from khoj.database.adapters import (
+    ConversationAdapters,
     get_all_users,
     get_or_create_search_model,
     aget_user_subscription_state,
@@ -138,6 +140,10 @@ def configure_server(
         config = FullConfig()
     state.config = config
 
+    if ConversationAdapters.has_valid_openai_conversation_config():
+        openai_config = ConversationAdapters.get_openai_conversation_config()
+        state.openai_client = openai.OpenAI(api_key=openai_config.api_key)
+
     # Initialize Search Models from Config and initialize content
     try:
         state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index 41b1844b..211e3cac 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -47,7 +47,7 @@ def extract_questions_offline(
 
     if use_history:
         for chat in conversation_log.get("chat", [])[-4:]:
-            if chat["by"] == "khoj":
+            if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image":
                 chat_history += f"Q: {chat['intent']['query']}\n"
                 chat_history += f"A: {chat['message']}\n"
 
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index dc708ab7..7bebc26c 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -41,7 +41,7 @@ def extract_questions(
         [
             f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
             for chat in conversation_log.get("chat", [])[-4:]
-            if chat["by"] == "khoj"
+            if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image"
         ]
     )
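The pattern this patch adopts, constructing one openai.OpenAI client at startup and sharing it across speech-to-text and image generation, can be sketched in isolation as below. Treat it as an illustrative sketch under stated assumptions: init_openai_client and translate_speech are hypothetical names and "whisper-1" an assumed default, while openai.OpenAI(api_key=...) and client.audio.translations.create(...) mirror the calls in the surrounding diffs.

    # Sketch of a module-level OpenAI client shared across features (like khoj.utils.state)
    from typing import Optional

    import openai
    from asgiref.sync import sync_to_async

    openai_client: Optional[openai.OpenAI] = None

    def init_openai_client(api_key: str) -> None:
        # Called once at server startup, mirroring configure_server above
        global openai_client
        openai_client = openai.OpenAI(api_key=api_key)

    async def translate_speech(audio_file, model: str = "whisper-1") -> str:
        # Callers reuse the shared client and its connection pool,
        # instead of constructing a new client on every request
        assert openai_client is not None, "init_openai_client() must run first"
        response = await sync_to_async(openai_client.audio.translations.create)(model=model, file=audio_file)
        return response.text

Reusing one client keeps connection pooling and credential handling in one place; the diffs below swap per-call client construction for this shared instance.

diff --git a/src/khoj/processor/conversation/openai/whisper.py b/src/khoj/processor/conversation/openai/whisper.py
index 351319b7..bd0e66df 100644
--- a/src/khoj/processor/conversation/openai/whisper.py
+++ b/src/khoj/processor/conversation/openai/whisper.py
@@ -6,11 +6,10 @@ from asgiref.sync import sync_to_async
 from openai import OpenAI
 
 
-async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
+async def 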
transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str: """ Transcribe audio file using Whisper model via OpenAI's API """ # Send the audio data to the Whisper API - client = OpenAI(api_key=api_key) response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file) - return response["text"] + return response.text diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 4d2c80a2..7efd8bfd 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -19,7 +19,7 @@ from starlette.authentication import requires from khoj.configure import configure_server from khoj.database import adapters from khoj.database.adapters import ConversationAdapters, EntryAdapters -from khoj.database.models import ChatModelOptions +from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions from khoj.database.models import Entry as DbEntry from khoj.database.models import ( GithubConfig, @@ -624,17 +624,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi # Send the audio data to the Whisper API speech_to_text_config = await ConversationAdapters.get_speech_to_text_config() - openai_chat_config = await ConversationAdapters.get_openai_chat_config() if not speech_to_text_config: # If the user has not configured a speech to text model, return an unsupported on server error status_code = 501 - elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI: - api_key = openai_chat_config.api_key + elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI: speech2text_model = speech_to_text_config.model_name - user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key) - elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE: + user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client) + elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE: speech2text_model = speech_to_text_config.model_name - user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model) + user_message = await transcribe_audio_offline(audio_filename, speech2text_model) finally: # Close and Delete the temporary audio file audio_file.close() @@ -793,7 +791,6 @@ async def extract_references_and_questions( conversation_config = await ConversationAdapters.aget_conversation_config(user) if conversation_config is None: conversation_config = await ConversationAdapters.aget_default_conversation_config() - openai_chat_config = await ConversationAdapters.aget_openai_conversation_config() if ( offline_chat_config and offline_chat_config.enabled @@ -810,7 +807,7 @@ async def extract_references_and_questions( inferred_queries = extract_questions_offline( defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False ) - elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: + elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: openai_chat_config = await ConversationAdapters.get_openai_chat_config() openai_chat = await ConversationAdapters.get_openai_chat() api_key = openai_chat_config.api_key diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index a780eb20..4e883f35 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -256,15 +256,15 @@ async def 
text_to_image(message: str) -> Tuple[Optional[str], int]: # Send the audio data to the Whisper API text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() - openai_chat_config = await ConversationAdapters.get_openai_chat_config() if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 - elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: - client = openai.OpenAI(api_key=openai_chat_config.api_key) + elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: text2image_model = text_to_image_config.model_name try: - response = client.images.generate(prompt=message, model=text2image_model, response_format="b64_json") + response = state.openai_client.images.generate( + prompt=message, model=text2image_model, response_format="b64_json" + ) image = response.data[0].b64_json except openai.OpenAIError as e: logger.error(f"Image Generation failed with {e.http_status}: {e.error}") diff --git a/src/khoj/utils/models.py b/src/khoj/utils/models.py index 37d09418..e1298b08 100644 --- a/src/khoj/utils/models.py +++ b/src/khoj/utils/models.py @@ -22,17 +22,9 @@ class BaseEncoder(ABC): class OpenAI(BaseEncoder): - def __init__(self, model_name, device=None): + def __init__(self, model_name, client: openai.OpenAI, device=None): self.model_name = model_name - if ( - not state.processor_config - or not state.processor_config.conversation - or not state.processor_config.conversation.openai_model - ): - raise Exception( - f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}" - ) - self.openai_client = openai.OpenAI(api_key=state.processor_config.conversation.openai_model.api_key) + self.openai_client = client self.embedding_dimensions = None def encode(self, entries, device=None, **kwargs): diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index b54cf4b3..d5358868 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -1,15 +1,16 @@ # Standard Packages +from collections import defaultdict import os +from pathlib import Path import threading from typing import List, Dict -from collections import defaultdict # External Packages -from pathlib import Path -from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel +from openai import OpenAI from whisper import Whisper # Internal Packages +from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.utils import config as utils_config from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel from khoj.utils.helpers import LRU, get_device @@ -21,6 +22,7 @@ search_models = SearchModels() embeddings_model: EmbeddingsModel = None cross_encoder_model: CrossEncoderModel = None content_index = ContentIndex() +openai_client: OpenAI = None gpt4all_processor_config: GPT4AllProcessorModel = None whisper_model: Whisper = None config_file: Path = None From 7504669f2bef0a2217e8793f2667d8519ac0ee38 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 5 Dec 2023 03:48:07 -0500 Subject: [PATCH 15/15] Fix rendering image on chat response in obsidian client --- src/interface/obsidian/src/chat_modal.ts | 34 ++++++++++++++++-------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts index 09f7a181..115f4c1f 100644 --- 
a/src/interface/obsidian/src/chat_modal.ts
+++ b/src/interface/obsidian/src/chat_modal.ts
@@ -2,6 +2,12 @@ import { App, MarkdownRenderer, Modal, request, requestUrl, setIcon } from 'obsi
 import { KhojSetting } from 'src/settings';
 import fetch from "node-fetch";
 
+export interface ChatJsonResult {
+    image?: string;
+    detail?: string;
+}
+
+
 export class KhojChatModal extends Modal {
     result: string;
     setting: KhojSetting;
@@ -270,6 +276,23 @@ export class KhojChatModal extends Modal {
         this.result = "";
         responseElement.innerHTML = "";
 
+        if (response.headers.get("content-type") == "application/json") {
+            let responseText = "";
+            try {
+                const responseAsJson = await response.json() as ChatJsonResult;
+                if (responseAsJson.image) {
+                    responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`;
+                } else if (responseAsJson.detail) {
+                    responseText = responseAsJson.detail;
+                }
+            } catch (error) {
+                // If the response couldn't be parsed as JSON, surface the parse error
+                responseText = `Error parsing response: ${error}`;
+            } finally {
+                this.renderIncrementalMessage(responseElement, responseText);
+            }
+        }
+
         for await (const chunk of response.body) {
             let responseText = chunk.toString();
             if (responseText.includes("### compiled references:")) {
@@ -315,17 +338,6 @@ export class KhojChatModal extends Modal {
             references.appendChild(referenceSection);
         } else {
             if (responseText.startsWith("{") && responseText.endsWith("}")) {
-                try {
-                    const responseAsJson = JSON.parse(responseText);
-                    if (responseAsJson.image) {
-                        responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`;
-                    } else if (responseAsJson.detail) {
-                        responseText = responseAsJson.detail;
-                    }
-                } catch (error) {
-                    // If the chunk is not a JSON object, just display it as is
-                    continue;
-                }
             } else {
                 // If the chunk is not a JSON object, just display it as is
                 continue;