From eda33e092fd66adf33666569f31b2f1a863fe723 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 25 Apr 2024 19:12:27 +0530 Subject: [PATCH 1/6] Enable using Stable Diffusion 3 for Image Generation via API --- ...texttoimagemodelconfig_api_key_and_more.py | 24 ++++++++++++ src/khoj/database/models/__init__.py | 2 + src/khoj/routers/helpers.py | 37 ++++++++++++++----- 3 files changed, 54 insertions(+), 9 deletions(-) create mode 100644 src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py diff --git a/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py b/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py new file mode 100644 index 00000000..069b16ca --- /dev/null +++ b/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py @@ -0,0 +1,24 @@ +# Generated by Django 4.2.11 on 2024-06-20 19:02 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0047_alter_entry_file_type"), + ] + + operations = [ + migrations.AddField( + model_name="texttoimagemodelconfig", + name="api_key", + field=models.CharField(blank=True, default=None, max_length=200, null=True), + ), + migrations.AlterField( + model_name="texttoimagemodelconfig", + name="model_type", + field=models.CharField( + choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200 + ), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 00a594ca..89f9802b 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -234,9 +234,11 @@ class SearchModelConfig(BaseModel): class TextToImageModelConfig(BaseModel): class ModelType(models.TextChoices): OPENAI = "openai" + STABILITYAI = "stability-ai" model_name = models.CharField(max_length=200, default="dall-e-3") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) + api_key = models.CharField(max_length=200, default=None, null=True, blank=True) class SpeechToTextModelOptions(BaseModel): diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 90bfd8c6..8cc40f80 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -753,7 +753,7 @@ async def text_to_image( status_code = 501 message = "Failed to generate image. Setup image generation on the server." return image_url or image, status_code, message, intent_type.value - elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: + elif state.openai_client: logger.info("Generating image with OpenAI") text2image_model = text_to_image_config.model_name chat_history = "" @@ -775,17 +775,36 @@ async def text_to_image( note_references=references, online_results=online_results, ) - with timer("Generate image with OpenAI", logger): - if send_status_func: - await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}") - response = state.openai_client.images.generate( - prompt=improved_image_prompt, model=text2image_model, response_format="b64_json" - ) - image = response.data[0].b64_json + if send_status_func: + await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}") + + if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: + with timer("Generate image with OpenAI", logger): + response = state.openai_client.images.generate( + prompt=improved_image_prompt, model=text2image_model, response_format="b64_json" + ) + image = response.data[0].b64_json + decoded_image = base64.b64decode(image) + + elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI: + with timer("Generate image with Stability AI", logger): + response = requests.post( + f"https://api.stability.ai/v2beta/stable-image/generate/sd3", + headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"}, + files={"none": ""}, + data={ + "prompt": improved_image_prompt, + "model": text2image_model, + "mode": "text-to-image", + "output_format": "png", + "seed": 1032622926, + "aspect_ratio": "1:1", + }, + ) + decoded_image = response.content with timer("Convert image to webp", logger): # Convert png to webp for faster loading - decoded_image = base64.b64decode(image) image_io = io.BytesIO(decoded_image) png_image = Image.open(image_io) webp_image_io = io.BytesIO() From fdd4c0246168c91a2afd44500759ef48fb8cf6ec Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 25 Apr 2024 23:07:01 +0530 Subject: [PATCH 2/6] Use shorter prompt generator to prompt SD3 to create better images --- src/khoj/processor/conversation/prompts.py | 31 +++++++++++++++++++++- src/khoj/routers/helpers.py | 29 ++++++++++++++------ 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index a1c7dff1..fcc9fc63 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -121,7 +121,7 @@ User's Notes: ## Image Generation ## -- -image_generation_improve_prompt = PromptTemplate.from_template( +image_generation_improve_prompt_dalle = PromptTemplate.from_template( """ You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information originally from the query. You are provided with the following information to help you generate the prompt: @@ -143,6 +143,35 @@ Remember, now you are generating a prompt to improve the image generation. Add a Improved Query:""" ) +image_generation_improve_prompt_sd = PromptTemplate.from_template( + """ +You are a talented creator. Write 2-5 sentences with precise image composition, position details to create an image. +Use the provided context below to add specific, fine details to the image composition. +Retain any important information and follow any instructions from the original prompt. +Put any text to be rendered in the image within double quotes in your improved prompt. +You are provided with the following context to help enhance the original prompt: + +Today's Date: {current_date} +User's Location: {location} + +User's Notes: +{references} + +Online References: +{online_results} + +Conversation Log: +{chat_history} + +Original Prompt: "{query}" + +Now create an improved prompt using the context provided above to generate an image. +Retain any important information and follow any instructions from the original prompt. +Use the additional context from the user's notes, online references and conversation log to improve the image generation. + +Improved Prompt:""" +) + ## Online Search Conversation ## -- online_search_conversation = PromptTemplate.from_template( diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 8cc40f80..66807cc3 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -453,12 +453,14 @@ async def generate_better_image_prompt( location_data: LocationData, note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, + model_type: Optional[str] = None, ) -> str: """ Generate a better image prompt from the given query """ today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") + model_type = model_type or TextToImageModelConfig.ModelType.OPENAI if location_data: location = f"{location_data.city}, {location_data.region}, {location_data.country}" @@ -477,14 +479,24 @@ async def generate_better_image_prompt( elif online_results[result].get("webpages"): simplified_online_results[result] = online_results[result]["webpages"] - image_prompt = prompts.image_generation_improve_prompt.format( - query=q, - chat_history=conversation_history, - location=location_prompt, - current_date=today_date, - references=user_references, - online_results=simplified_online_results, - ) + if model_type == TextToImageModelConfig.ModelType.OPENAI: + image_prompt = prompts.image_generation_improve_prompt_dalle.format( + query=q, + chat_history=conversation_history, + location=location_prompt, + current_date=today_date, + references=user_references, + online_results=simplified_online_results, + ) + elif model_type == TextToImageModelConfig.ModelType.STABILITYAI: + image_prompt = prompts.image_generation_improve_prompt_sd.format( + query=q, + chat_history=conversation_history, + location=location_prompt, + current_date=today_date, + references=user_references, + online_results=simplified_online_results, + ) summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config() @@ -774,6 +786,7 @@ async def text_to_image( location_data=location_data, note_references=references, online_results=online_results, + model_type=text_to_image_config.model_type, ) if send_status_func: await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}") From eb09aba7472aea68f3e0b78950b54e100ce6d7bf Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 26 Apr 2024 16:49:11 +0530 Subject: [PATCH 3/6] Remove quotes wrapping the prompt from being passed to image gen model --- src/khoj/routers/helpers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 66807cc3..e7392738 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -502,8 +502,11 @@ async def generate_better_image_prompt( with timer("Chat actor: Generate contextual image prompt", logger): response = await send_message_to_model_wrapper(image_prompt, chat_model_option=summarizer_model) + response = response.strip() + if response.startswith(('"', "'")) and response.endswith(('"', "'")): + response = response[1:-1] - return response.strip() + return response async def send_message_to_model_wrapper( From 2c4bf91a611f547b57d6e4c5147db32fbbf119e5 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 21 Jun 2024 01:16:17 +0530 Subject: [PATCH 4/6] Allow user to set paint model to use from web client config page --- src/khoj/database/adapters/__init__.py | 27 ++++++++ src/khoj/database/admin.py | 11 +++- ...texttoimagemodelconfig_api_key_and_more.py | 25 +++++++- src/khoj/database/models/__init__.py | 5 ++ src/khoj/interface/web/base_config.html | 1 + src/khoj/interface/web/config.html | 64 ++++++++++++++++++- src/khoj/routers/api_config.py | 29 +++++++++ src/khoj/routers/helpers.py | 2 +- src/khoj/routers/web_client.py | 10 ++- 9 files changed, 167 insertions(+), 7 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 00ad75f1..f00fe179 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -45,6 +45,7 @@ from khoj.database.models import ( Subscription, TextToImageModelConfig, UserConversationConfig, + UserPaintModelConfig, UserRequests, UserSearchModelConfig, UserVoiceModelConfig, @@ -891,6 +892,32 @@ class ConversationAdapters: async def aget_text_to_image_model_config(): return await TextToImageModelConfig.objects.filter().afirst() + @staticmethod + def get_text_to_image_model_options(): + return TextToImageModelConfig.objects.all() + + @staticmethod + def get_user_paint_model_config(user: KhojUser): + config = UserPaintModelConfig.objects.filter(user=user).first() + if not config: + return None + return config.setting + + @staticmethod + async def aget_user_paint_model(user: KhojUser): + config = await UserPaintModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() + if not config: + return None + return config.setting + + @staticmethod + async def aset_user_paint_model(user: KhojUser, text_to_image_model_config_id: int): + config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst() + if not config: + return None + new_config, _ = await UserPaintModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) + return new_config + class FileObjectAdapters: @staticmethod diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 3bc0f76d..95e3508c 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -96,7 +96,6 @@ admin.site.register(SpeechToTextModelOptions) admin.site.register(SearchModelConfig) admin.site.register(ReflectiveQuestion) admin.site.register(UserSearchModelConfig) -admin.site.register(TextToImageModelConfig) admin.site.register(ClientApplication) admin.site.register(GithubConfig) admin.site.register(NotionConfig) @@ -153,6 +152,16 @@ class ChatModelOptionsAdmin(admin.ModelAdmin): search_fields = ("id", "chat_model", "model_type") +@admin.register(TextToImageModelConfig) +class TextToImageModelOptionsAdmin(admin.ModelAdmin): + list_display = ( + "id", + "model_name", + "model_type", + ) + search_fields = ("id", "model_name", "model_type") + + @admin.register(OpenAIProcessorConversationConfig) class OpenAIProcessorConversationConfigAdmin(admin.ModelAdmin): list_display = ( diff --git a/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py b/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py index 069b16ca..d5d3e972 100644 --- a/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py +++ b/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py @@ -1,5 +1,7 @@ -# Generated by Django 4.2.11 on 2024-06-20 19:02 +# Generated by Django 4.2.11 on 2024-06-20 19:48 +import django.db.models.deletion +from django.conf import settings from django.db import migrations, models @@ -21,4 +23,25 @@ class Migration(migrations.Migration): choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200 ), ), + migrations.CreateModel( + name="UserPaintModelConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "setting", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="database.texttoimagemodelconfig" + ), + ), + ( + "user", + models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL), + ), + ], + options={ + "abstract": False, + }, + ), ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 89f9802b..6da7c928 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -265,6 +265,11 @@ class UserSearchModelConfig(BaseModel): setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE) +class UserPaintModelConfig(BaseModel): + user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) + setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE) + + class Conversation(BaseModel): user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) conversation_log = models.JSONField(default=dict) diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html index 7747229e..d4712206 100644 --- a/src/khoj/interface/web/base_config.html +++ b/src/khoj/interface/web/base_config.html @@ -332,6 +332,7 @@ margin: 20px; } + select#paint-models, select#search-models, select#voice-models, select#chat-models { diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 2c3c98db..88725c64 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -192,11 +192,37 @@
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %} - {% else %} - + {% endif %} +
+ +
+
+ Chat +

+ Paint +

+
+
+ +
+
+ {% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %} + + {% else %} + {% endif %} @@ -425,7 +451,7 @@ function updateChatModel() { const chatModel = document.getElementById("chat-models").value; - const saveModelButton = document.getElementById("save-model"); + const saveModelButton = document.getElementById("save-chat-model"); saveModelButton.disabled = true; saveModelButton.innerHTML = "Saving..."; @@ -491,6 +517,38 @@ }) }; + function updatePaintModel() { + const paintModel = document.getElementById("paint-models").value; + const saveModelButton = document.getElementById("save-paint-model"); + saveModelButton.disabled = true; + saveModelButton.innerHTML = "Saving..."; + + fetch('/api/config/data/paint/model?id=' + paintModel, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + } + }) + .then(response => response.json()) + .then(data => { + if (data.status == "ok") { + saveModelButton.innerHTML = "Save"; + saveModelButton.disabled = false; + + let notificationBanner = document.getElementById("notification-banner"); + notificationBanner.innerHTML = "Paint model has been updated!"; + notificationBanner.style.display = "block"; + setTimeout(function() { + notificationBanner.style.display = "none"; + }, 5000); + + } else { + saveModelButton.innerHTML = "Error"; + saveModelButton.disabled = false; + } + }) + }; + function clearContentType(content_source) { fetch('/api/config/data/content-source/' + content_source, { method: 'DELETE', diff --git a/src/khoj/routers/api_config.py b/src/khoj/routers/api_config.py index 68757de6..f25ecacd 100644 --- a/src/khoj/routers/api_config.py +++ b/src/khoj/routers/api_config.py @@ -315,6 +315,35 @@ async def update_search_model( return {"status": "ok"} +@api_config.post("/data/paint/model", status_code=200) +@requires(["authenticated"]) +async def update_paint_model( + request: Request, + id: str, + client: Optional[str] = None, +): + user = request.user.object + subscribed = has_required_scope(request, ["premium"]) + + if not subscribed: + raise HTTPException(status_code=403, detail="User is not subscribed to premium") + + new_config = await ConversationAdapters.aset_user_paint_model(user, int(id)) + + update_telemetry_state( + request=request, + telemetry_type="api", + api="set_paint_model", + client=client, + metadata={"paint_model": new_config.setting.model_name}, + ) + + if new_config is None: + return {"status": "error", "message": "Model not found"} + + return {"status": "ok"} + + @api_config.get("/index/size", response_model=Dict[str, int]) @requires(["authenticated"]) async def get_indexed_data_size(request: Request, common: CommonQueryParams): diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index e7392738..976caef1 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -762,7 +762,7 @@ async def text_to_image( image_url = None intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 - text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() + text_to_image_config = await ConversationAdapters.aget_user_paint_model(user) if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 1dc6a2f4..87075b49 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -262,6 +262,12 @@ def config_page(request: Request): current_search_model_option = adapters.get_user_search_model_or_default(user) + selected_paint_model_config = ConversationAdapters.get_user_paint_model_config(user) + paint_model_options = ConversationAdapters.get_text_to_image_model_options().all() + all_paint_model_options = list() + for paint_model in paint_model_options: + all_paint_model_options.append({"model_name": paint_model.model_name, "id": paint_model.id}) + notion_oauth_url = get_notion_auth_url(user) eleven_labs_enabled = is_eleven_labs_enabled() @@ -284,10 +290,12 @@ def config_page(request: Request): "anonymous_mode": state.anonymous_mode, "username": user.username, "given_name": given_name, - "conversation_options": all_conversation_options, "search_model_options": all_search_model_options, "selected_search_model_config": current_search_model_option.id, + "conversation_options": all_conversation_options, "selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None, + "paint_model_options": all_paint_model_options, + "selected_paint_model_config": selected_paint_model_config.id if selected_paint_model_config else None, "user_photo": user_picture, "billing_enabled": state.billing_enabled, "subscription_state": user_subscription_state, From 1acf969c6ee1d1db9ba03ee5412bfd84cabf81ac Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 21 Jun 2024 02:08:02 +0530 Subject: [PATCH 5/6] Do not require OpenAI to generate image as local chat + sd3 works now Previously the text_to_image helper would only trigger the image generation flow if OpenAI client was setup. This is not required anymore as offline chat model + sd3 API works. So remove that check --- src/khoj/routers/helpers.py | 160 +++++++++++++++++++----------------- 1 file changed, 83 insertions(+), 77 deletions(-) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 976caef1..79bb2365 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -768,88 +768,94 @@ async def text_to_image( status_code = 501 message = "Failed to generate image. Setup image generation on the server." return image_url or image, status_code, message, intent_type.value - elif state.openai_client: - logger.info("Generating image with OpenAI") - text2image_model = text_to_image_config.model_name - chat_history = "" - for chat in conversation_log.get("chat", [])[-4:]: - if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]: - chat_history += f"Q: {chat['intent']['query']}\n" - chat_history += f"A: {chat['message']}\n" - elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"): - chat_history += f"Q: Query: {chat['intent']['query']}\n" - chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n" - try: - with timer("Improve the original user query", logger): - if send_status_func: - await send_status_func("**✍🏽 Enhancing the Painting Prompt**") - improved_image_prompt = await generate_better_image_prompt( - message, - chat_history, - location_data=location_data, - note_references=references, - online_results=online_results, - model_type=text_to_image_config.model_type, + + text2image_model = text_to_image_config.model_name + chat_history = "" + for chat in conversation_log.get("chat", [])[-4:]: + if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]: + chat_history += f"Q: {chat['intent']['query']}\n" + chat_history += f"A: {chat['message']}\n" + elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"): + chat_history += f"Q: Query: {chat['intent']['query']}\n" + chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n" + + with timer("Improve the original user query", logger): + if send_status_func: + await send_status_func("**✍🏽 Enhancing the Painting Prompt**") + improved_image_prompt = await generate_better_image_prompt( + message, + chat_history, + location_data=location_data, + note_references=references, + online_results=online_results, + model_type=text_to_image_config.model_type, + ) + + if send_status_func: + await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}") + + if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: + with timer("Generate image with OpenAI", logger): + try: + response = state.openai_client.images.generate( + prompt=improved_image_prompt, model=text2image_model, response_format="b64_json" ) - if send_status_func: - await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}") + image = response.data[0].b64_json + decoded_image = base64.b64decode(image) + except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e: + if "content_policy_violation" in e.message: + logger.error(f"Image Generation blocked by OpenAI: {e}") + status_code = e.status_code # type: ignore + message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore + return image_url or image, status_code, message, intent_type.value + else: + logger.error(f"Image Generation failed with {e}", exc_info=True) + message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore + status_code = e.status_code # type: ignore + return image_url or image, status_code, message, intent_type.value - if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: - with timer("Generate image with OpenAI", logger): - response = state.openai_client.images.generate( - prompt=improved_image_prompt, model=text2image_model, response_format="b64_json" - ) - image = response.data[0].b64_json - decoded_image = base64.b64decode(image) - - elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI: - with timer("Generate image with Stability AI", logger): - response = requests.post( - f"https://api.stability.ai/v2beta/stable-image/generate/sd3", - headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"}, - files={"none": ""}, - data={ - "prompt": improved_image_prompt, - "model": text2image_model, - "mode": "text-to-image", - "output_format": "png", - "seed": 1032622926, - "aspect_ratio": "1:1", - }, - ) - decoded_image = response.content - - with timer("Convert image to webp", logger): - # Convert png to webp for faster loading - image_io = io.BytesIO(decoded_image) - png_image = Image.open(image_io) - webp_image_io = io.BytesIO() - png_image.save(webp_image_io, "WEBP") - webp_image_bytes = webp_image_io.getvalue() - webp_image_io.close() - image_io.close() - - with timer("Upload image to S3", logger): - image_url = upload_image(webp_image_bytes, user.uuid) - if image_url: - intent_type = ImageIntentType.TEXT_TO_IMAGE2 - else: - intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 - image = base64.b64encode(webp_image_bytes).decode("utf-8") - - return image_url or image, status_code, improved_image_prompt, intent_type.value - except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e: - if "content_policy_violation" in e.message: - logger.error(f"Image Generation blocked by OpenAI: {e}") - status_code = e.status_code # type: ignore - message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore - return image_url or image, status_code, message, intent_type.value - else: + elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI: + with timer("Generate image with Stability AI", logger): + try: + response = requests.post( + f"https://api.stability.ai/v2beta/stable-image/generate/sd3", + headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"}, + files={"none": ""}, + data={ + "prompt": improved_image_prompt, + "model": text2image_model, + "mode": "text-to-image", + "output_format": "png", + "seed": 1032622926, + "aspect_ratio": "1:1", + }, + ) + decoded_image = response.content + except requests.RequestException as e: logger.error(f"Image Generation failed with {e}", exc_info=True) - message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore + message = f"Image generation failed with Stability AI error: {e}" status_code = e.status_code # type: ignore return image_url or image, status_code, message, intent_type.value - return image_url or image, status_code, response, intent_type.value + + with timer("Convert image to webp", logger): + # Convert png to webp for faster loading + image_io = io.BytesIO(decoded_image) + png_image = Image.open(image_io) + webp_image_io = io.BytesIO() + png_image.save(webp_image_io, "WEBP") + webp_image_bytes = webp_image_io.getvalue() + webp_image_io.close() + image_io.close() + + with timer("Upload image to S3", logger): + image_url = upload_image(webp_image_bytes, user.uuid) + if image_url: + intent_type = ImageIntentType.TEXT_TO_IMAGE2 + else: + intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 + image = base64.b64encode(webp_image_bytes).decode("utf-8") + + return image_url or image, status_code, improved_image_prompt, intent_type.value class ApiUserRateLimiter: From c793d8a69ed7c9a1453214ed72d7d5e25332475c Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 26 Jun 2024 09:51:06 +0530 Subject: [PATCH 6/6] Add Validation logic to save PaintModel. Use API key from Paint Model Rename Paint Model, Adapters to TextToImage for consistency --- src/khoj/database/adapters/__init__.py | 16 ++++++----- ...exttoimagemodelconfig_api_key_and_more.py} | 17 +++++++++-- src/khoj/database/models/__init__.py | 28 ++++++++++++++++++- src/khoj/routers/api_config.py | 2 +- src/khoj/routers/helpers.py | 14 ++++++++-- src/khoj/routers/web_client.py | 2 +- 6 files changed, 64 insertions(+), 15 deletions(-) rename src/khoj/database/migrations/{0048_texttoimagemodelconfig_api_key_and_more.py => 0049_texttoimagemodelconfig_api_key_and_more.py} (73%) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index f00fe179..1e43887a 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -45,9 +45,9 @@ from khoj.database.models import ( Subscription, TextToImageModelConfig, UserConversationConfig, - UserPaintModelConfig, UserRequests, UserSearchModelConfig, + UserTextToImageModelConfig, UserVoiceModelConfig, VoiceModelOption, ) @@ -897,25 +897,27 @@ class ConversationAdapters: return TextToImageModelConfig.objects.all() @staticmethod - def get_user_paint_model_config(user: KhojUser): - config = UserPaintModelConfig.objects.filter(user=user).first() + def get_user_text_to_image_model_config(user: KhojUser): + config = UserTextToImageModelConfig.objects.filter(user=user).first() if not config: return None return config.setting @staticmethod - async def aget_user_paint_model(user: KhojUser): - config = await UserPaintModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() + async def aget_user_text_to_image_model(user: KhojUser): + config = await UserTextToImageModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() if not config: return None return config.setting @staticmethod - async def aset_user_paint_model(user: KhojUser, text_to_image_model_config_id: int): + async def aset_user_text_to_image_model(user: KhojUser, text_to_image_model_config_id: int): config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst() if not config: return None - new_config, _ = await UserPaintModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) + new_config, _ = await UserTextToImageModelConfig.objects.aupdate_or_create( + user=user, defaults={"setting": config} + ) return new_config diff --git a/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py b/src/khoj/database/migrations/0049_texttoimagemodelconfig_api_key_and_more.py similarity index 73% rename from src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py rename to src/khoj/database/migrations/0049_texttoimagemodelconfig_api_key_and_more.py index d5d3e972..c7ff9e81 100644 --- a/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py +++ b/src/khoj/database/migrations/0049_texttoimagemodelconfig_api_key_and_more.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.11 on 2024-06-20 19:48 +# Generated by Django 4.2.11 on 2024-06-26 03:27 import django.db.models.deletion from django.conf import settings @@ -7,7 +7,7 @@ from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ("database", "0047_alter_entry_file_type"), + ("database", "0048_voicemodeloption_uservoicemodelconfig"), ] operations = [ @@ -16,6 +16,17 @@ class Migration(migrations.Migration): name="api_key", field=models.CharField(blank=True, default=None, max_length=200, null=True), ), + migrations.AddField( + model_name="texttoimagemodelconfig", + name="openai_config", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="database.openaiprocessorconversationconfig", + ), + ), migrations.AlterField( model_name="texttoimagemodelconfig", name="model_type", @@ -24,7 +35,7 @@ class Migration(migrations.Migration): ), ), migrations.CreateModel( - name="UserPaintModelConfig", + name="UserTextToImageModelConfig", fields=[ ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ("created_at", models.DateTimeField(auto_now_add=True)), diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 6da7c928..80fd3b7b 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -239,6 +239,32 @@ class TextToImageModelConfig(BaseModel): model_name = models.CharField(max_length=200, default="dall-e-3") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) api_key = models.CharField(max_length=200, default=None, null=True, blank=True) + openai_config = models.ForeignKey( + OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True + ) + + def clean(self): + # Custom validation logic + error = {} + if self.model_type == self.ModelType.OPENAI: + if self.api_key and self.openai_config: + error[ + "api_key" + ] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them." + error[ + "openai_config" + ] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them." + if self.model_type != self.ModelType.OPENAI: + if not self.api_key: + error["api_key"] = "The API key field must be set for non OpenAI models." + if self.openai_config: + error["openai_config"] = "OpenAI config cannot be set for non OpenAI models." + if error: + raise ValidationError(error) + + def save(self, *args, **kwargs): + self.clean() + super().save(*args, **kwargs) class SpeechToTextModelOptions(BaseModel): @@ -265,7 +291,7 @@ class UserSearchModelConfig(BaseModel): setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE) -class UserPaintModelConfig(BaseModel): +class UserTextToImageModelConfig(BaseModel): user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE) diff --git a/src/khoj/routers/api_config.py b/src/khoj/routers/api_config.py index f25ecacd..65faf09c 100644 --- a/src/khoj/routers/api_config.py +++ b/src/khoj/routers/api_config.py @@ -328,7 +328,7 @@ async def update_paint_model( if not subscribed: raise HTTPException(status_code=403, detail="User is not subscribed to premium") - new_config = await ConversationAdapters.aset_user_paint_model(user, int(id)) + new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id)) update_telemetry_state( request=request, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 79bb2365..835bb8c1 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -762,7 +762,7 @@ async def text_to_image( image_url = None intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 - text_to_image_config = await ConversationAdapters.aget_user_paint_model(user) + text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user) if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 @@ -796,9 +796,19 @@ async def text_to_image( if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: with timer("Generate image with OpenAI", logger): + if text_to_image_config.api_key: + api_key = text_to_image_config.api_key + elif text_to_image_config.openai_config: + api_key = text_to_image_config.openai_config.api_key + elif state.openai_client: + api_key = state.openai_client.api_key + auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: response = state.openai_client.images.generate( - prompt=improved_image_prompt, model=text2image_model, response_format="b64_json" + prompt=improved_image_prompt, + model=text2image_model, + response_format="b64_json", + extra_headers=auth_header, ) image = response.data[0].b64_json decoded_image = base64.b64decode(image) diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 87075b49..dc52a41a 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -262,7 +262,7 @@ def config_page(request: Request): current_search_model_option = adapters.get_user_search_model_or_default(user) - selected_paint_model_config = ConversationAdapters.get_user_paint_model_config(user) + selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user) paint_model_options = ConversationAdapters.get_text_to_image_model_options().all() all_paint_model_options = list() for paint_model in paint_model_options: