mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Support /image slash command to generate images using the chat API
This commit is contained in:
parent
1d9c1333f2
commit
252b35b2f0
3 changed files with 40 additions and 6 deletions
|
@ -35,12 +35,14 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
|||
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
||||
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||
from khoj.processor.conversation.utils import save_to_conversation_log
|
||||
from khoj.processor.tools.online_search import search_with_google
|
||||
from khoj.routers.helpers import (
|
||||
ApiUserRateLimiter,
|
||||
CommonQueryParams,
|
||||
agenerate_chat_response,
|
||||
get_conversation_command,
|
||||
text_to_image,
|
||||
is_ready_to_chat,
|
||||
update_telemetry_state,
|
||||
validate_conversation_config,
|
||||
|
@ -665,7 +667,7 @@ async def chat(
|
|||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
|
||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
||||
) -> Response:
|
||||
user = request.user.object
|
||||
user: KhojUser = request.user.object
|
||||
|
||||
await is_ready_to_chat(user)
|
||||
conversation_command = get_conversation_command(query=q, any_references=True)
|
||||
|
@ -703,6 +705,11 @@ async def chat(
|
|||
media_type="text/event-stream",
|
||||
status_code=200,
|
||||
)
|
||||
elif conversation_command == ConversationCommand.Image:
|
||||
image_url, status_code = await text_to_image(q)
|
||||
await sync_to_async(save_to_conversation_log)(q, image_url, user, meta_log, intent_type="text-to-image")
|
||||
content_obj = {"imageUrl": image_url, "intentType": "text-to-image"}
|
||||
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
||||
|
||||
# Get the (streamed) chat response from the LLM of choice.
|
||||
llm_response, chat_metadata = await agenerate_chat_response(
|
||||
|
|
|
@ -9,23 +9,23 @@ from functools import partial
|
|||
from time import time
|
||||
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
# External Packages
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
import openai
|
||||
from starlette.authentication import has_required_scope
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
|
||||
# Internal Packages
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
from khoj.database.models import KhojUser, Subscription
|
||||
from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
|
||||
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils import state
|
||||
from khoj.utils.config import GPT4AllProcessorModel
|
||||
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
|
|||
return ConversationCommand.General
|
||||
elif query.startswith("/online"):
|
||||
return ConversationCommand.Online
|
||||
elif query.startswith("/image"):
|
||||
return ConversationCommand.Image
|
||||
# If no relevant notes found for the given query
|
||||
elif not any_references:
|
||||
return ConversationCommand.General
|
||||
|
@ -248,6 +250,29 @@ def generate_chat_response(
|
|||
return chat_response, metadata
|
||||
|
||||
|
||||
async def text_to_image(message: str) -> Tuple[Optional[str], int]:
|
||||
status_code = 200
|
||||
image_url = None
|
||||
|
||||
# Send the audio data to the Whisper API
|
||||
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
|
||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||
if not text_to_image_config:
|
||||
# If the user has not configured a text to image model, return an unprocessable entity error
|
||||
status_code = 422
|
||||
elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||
client = openai.OpenAI(api_key=openai_chat_config.api_key)
|
||||
text2image_model = text_to_image_config.model_name
|
||||
try:
|
||||
response = client.images.generate(prompt=message, model=text2image_model)
|
||||
image_url = response.data[0].url
|
||||
except openai.OpenAIError as e:
|
||||
logger.error(f"Image Generation failed with {e.http_status}: {e.error}")
|
||||
status_code = 500
|
||||
|
||||
return image_url, status_code
|
||||
|
||||
|
||||
class ApiUserRateLimiter:
|
||||
def __init__(self, requests: int, subscribed_requests: int, window: int):
|
||||
self.requests = requests
|
||||
|
|
|
@ -273,6 +273,7 @@ class ConversationCommand(str, Enum):
|
|||
Notes = "notes"
|
||||
Help = "help"
|
||||
Online = "online"
|
||||
Image = "image"
|
||||
|
||||
|
||||
command_descriptions = {
|
||||
|
@ -280,6 +281,7 @@ command_descriptions = {
|
|||
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
|
||||
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
|
||||
ConversationCommand.Online: "Look up information on the internet.",
|
||||
ConversationCommand.Image: "Generate images by describing your imagination in words.",
|
||||
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue