Support /image slash command to generate images using the chat API

Debanjum Singh Solanky 2023-12-04 17:58:04 -05:00
parent 1d9c1333f2
commit 252b35b2f0
3 changed files with 40 additions and 6 deletions

View file

@@ -35,12 +35,14 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.tools.online_search import search_with_google
from khoj.routers.helpers import (
    ApiUserRateLimiter,
    CommonQueryParams,
    agenerate_chat_response,
    get_conversation_command,
    text_to_image,
    is_ready_to_chat,
    update_telemetry_state,
    validate_conversation_config,
@@ -665,7 +667,7 @@ async def chat(
    rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
    rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
) -> Response:
    user = request.user.object
    user: KhojUser = request.user.object
    await is_ready_to_chat(user)
    conversation_command = get_conversation_command(query=q, any_references=True)
@@ -703,6 +705,11 @@ async def chat(
            media_type="text/event-stream",
            status_code=200,
        )
    elif conversation_command == ConversationCommand.Image:
        image_url, status_code = await text_to_image(q)
        await sync_to_async(save_to_conversation_log)(q, image_url, user, meta_log, intent_type="text-to-image")
        content_obj = {"imageUrl": image_url, "intentType": "text-to-image"}
        return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)

    # Get the (streamed) chat response from the LLM of choice.
    llm_response, chat_metadata = await agenerate_chat_response(
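Taken together, the hunks above mean a query that begins with /image skips the streamed chat path: text_to_image is awaited, the result is written to the conversation log with intent_type="text-to-image", and a plain JSON response is returned with the helper's status code. A minimal client-side sketch of exercising both paths; the /api/chat URL, the default port 42110, and the use of an unauthenticated GET request are assumptions not shown in this diff, while the q parameter, media types, and JSON keys come from the code above:

import requests

# Hypothetical local server URL; a real deployment may also need an API key header.
resp = requests.get(
    "http://localhost:42110/api/chat",
    params={"q": "/image a watercolor painting of a fox at dawn"},
    stream=True,
)

content_type = resp.headers.get("Content-Type", "")
if content_type.startswith("application/json"):
    # Image generation path: a single JSON body, with the status code propagated from text_to_image.
    payload = resp.json()
    print(resp.status_code, payload.get("imageUrl"), payload.get("intentType"))
elif content_type.startswith("text/event-stream"):
    # Regular chat path: tokens arrive as a stream.
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="")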

View file

@@ -9,23 +9,23 @@ from functools import partial
from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
# External Packages
from fastapi import Depends, Header, HTTPException, Request, UploadFile
import openai
from starlette.authentication import has_required_scope
from asgiref.sync import sync_to_async
# Internal Packages
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import KhojUser, Subscription
from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
# Internal Packages
from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry
logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1)
@@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
        return ConversationCommand.General
    elif query.startswith("/online"):
        return ConversationCommand.Online
    elif query.startswith("/image"):
        return ConversationCommand.Image
    # If no relevant notes found for the given query
    elif not any_references:
        return ConversationCommand.General
@@ -248,6 +250,29 @@ def generate_chat_response(
    return chat_response, metadata


async def text_to_image(message: str) -> Tuple[Optional[str], int]:
    status_code = 200
    image_url = None

    # Look up the user's configured text-to-image model
    text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
    openai_chat_config = await ConversationAdapters.get_openai_chat_config()

    if not text_to_image_config:
        # If the user has not configured a text to image model, return an unprocessable entity error
        status_code = 422
    elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
        client = openai.OpenAI(api_key=openai_chat_config.api_key)
        text2image_model = text_to_image_config.model_name
        try:
            response = client.images.generate(prompt=message, model=text2image_model)
            image_url = response.data[0].url
        except openai.OpenAIError as e:
            logger.error(f"Image Generation failed: {e}")
            status_code = 500

    return image_url, status_code
class ApiUserRateLimiter:
    def __init__(self, requests: int, subscribed_requests: int, window: int):
        self.requests = requests
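The two additions in this file work together: get_conversation_command now maps any query prefixed with /image to ConversationCommand.Image, and text_to_image reports its outcome through the returned status code, which the chat endpoint forwards as the HTTP status (200 on success, 422 when no text-to-image model is configured, 500 when the OpenAI call fails). A minimal sketch of the prefix routing, assuming the import paths shown in the first file's diff:

from khoj.routers.helpers import get_conversation_command
from khoj.utils.helpers import ConversationCommand

# Slash-command prefixes are checked before the notes-based fallback.
assert get_conversation_command("/image a red fox curled up in the snow", any_references=True) == ConversationCommand.Image
assert get_conversation_command("/online latest weather in Berlin", any_references=True) == ConversationCommand.Online

# Without a recognized prefix and with no relevant notes, the query falls back to General.
assert get_conversation_command("hello there", any_references=False) == ConversationCommand.General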

View file

@@ -273,6 +273,7 @@ class ConversationCommand(str, Enum):
    Notes = "notes"
    Help = "help"
    Online = "online"
    Image = "image"
command_descriptions = {
@@ -280,6 +281,7 @@ command_descriptions = {
    ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
    ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
    ConversationCommand.Online: "Look up information on the internet.",
    ConversationCommand.Image: "Generate images by describing your imagination in words.",
    ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
}
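Because ConversationCommand subclasses str, the new member compares equal to its string value, so the text after the slash maps directly onto the enum and its help text. A small sketch, assuming ConversationCommand and command_descriptions are both importable from khoj.utils.helpers, as the imports in the second file suggest:

from khoj.utils.helpers import ConversationCommand, command_descriptions

cmd = ConversationCommand("image")  # constructed straight from the slash-command name
assert cmd is ConversationCommand.Image and cmd == "image"
print(command_descriptions[cmd])    # "Generate images by describing your imagination in words."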