diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 7eda29d4..358edf2a 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -48,6 +48,7 @@ from khoj.database.models import ( UserConversationConfig, UserRequests, UserSearchModelConfig, + UserTextToImageModelConfig, UserVoiceModelConfig, VoiceModelOption, ) @@ -905,6 +906,34 @@ class ConversationAdapters: async def aget_text_to_image_model_config(): return await TextToImageModelConfig.objects.filter().afirst() + @staticmethod + def get_text_to_image_model_options(): + return TextToImageModelConfig.objects.all() + + @staticmethod + def get_user_text_to_image_model_config(user: KhojUser): + config = UserTextToImageModelConfig.objects.filter(user=user).first() + if not config: + return None + return config.setting + + @staticmethod + async def aget_user_text_to_image_model(user: KhojUser): + config = await UserTextToImageModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() + if not config: + return None + return config.setting + + @staticmethod + async def aset_user_text_to_image_model(user: KhojUser, text_to_image_model_config_id: int): + config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst() + if not config: + return None + new_config, _ = await UserTextToImageModelConfig.objects.aupdate_or_create( + user=user, defaults={"setting": config} + ) + return new_config + class FileObjectAdapters: @staticmethod diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 3bc0f76d..95e3508c 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -96,7 +96,6 @@ admin.site.register(SpeechToTextModelOptions) admin.site.register(SearchModelConfig) admin.site.register(ReflectiveQuestion) admin.site.register(UserSearchModelConfig) -admin.site.register(TextToImageModelConfig) admin.site.register(ClientApplication) 
admin.site.register(GithubConfig) admin.site.register(NotionConfig) @@ -153,6 +152,16 @@ class ChatModelOptionsAdmin(admin.ModelAdmin): search_fields = ("id", "chat_model", "model_type") +@admin.register(TextToImageModelConfig) +class TextToImageModelOptionsAdmin(admin.ModelAdmin): + list_display = ( + "id", + "model_name", + "model_type", + ) + search_fields = ("id", "model_name", "model_type") + + @admin.register(OpenAIProcessorConversationConfig) class OpenAIProcessorConversationConfigAdmin(admin.ModelAdmin): list_display = ( diff --git a/src/khoj/database/migrations/0049_texttoimagemodelconfig_api_key_and_more.py b/src/khoj/database/migrations/0049_texttoimagemodelconfig_api_key_and_more.py new file mode 100644 index 00000000..c7ff9e81 --- /dev/null +++ b/src/khoj/database/migrations/0049_texttoimagemodelconfig_api_key_and_more.py @@ -0,0 +1,58 @@ +# Generated by Django 4.2.11 on 2024-06-26 03:27 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0048_voicemodeloption_uservoicemodelconfig"), + ] + + operations = [ + migrations.AddField( + model_name="texttoimagemodelconfig", + name="api_key", + field=models.CharField(blank=True, default=None, max_length=200, null=True), + ), + migrations.AddField( + model_name="texttoimagemodelconfig", + name="openai_config", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="database.openaiprocessorconversationconfig", + ), + ), + migrations.AlterField( + model_name="texttoimagemodelconfig", + name="model_type", + field=models.CharField( + choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200 + ), + ), + migrations.CreateModel( + name="UserTextToImageModelConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, 
verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "setting", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="database.texttoimagemodelconfig" + ), + ), + ( + "user", + models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 174d5090..62afdd2b 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -235,9 +235,37 @@ class SearchModelConfig(BaseModel): class TextToImageModelConfig(BaseModel): class ModelType(models.TextChoices): OPENAI = "openai" + STABILITYAI = "stability-ai" model_name = models.CharField(max_length=200, default="dall-e-3") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) + api_key = models.CharField(max_length=200, default=None, null=True, blank=True) + openai_config = models.ForeignKey( + OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True + ) + + def clean(self): + # Custom validation logic + error = {} + if self.model_type == self.ModelType.OPENAI: + if self.api_key and self.openai_config: + error[ + "api_key" + ] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them." + error[ + "openai_config" + ] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them." + if self.model_type != self.ModelType.OPENAI: + if not self.api_key: + error["api_key"] = "The API key field must be set for non OpenAI models." + if self.openai_config: + error["openai_config"] = "OpenAI config cannot be set for non OpenAI models." 
+ if error: + raise ValidationError(error) + + def save(self, *args, **kwargs): + self.clean() + super().save(*args, **kwargs) class SpeechToTextModelOptions(BaseModel): @@ -264,6 +292,11 @@ class UserSearchModelConfig(BaseModel): setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE) +class UserTextToImageModelConfig(BaseModel): + user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) + setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE) + + class Conversation(BaseModel): user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) conversation_log = models.JSONField(default=dict) diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html index 7747229e..d4712206 100644 --- a/src/khoj/interface/web/base_config.html +++ b/src/khoj/interface/web/base_config.html @@ -332,6 +332,7 @@ margin: 20px; } + select#paint-models, select#search-models, select#voice-models, select#chat-models { diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 2c3c98db..88725c64 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -192,11 +192,37 @@
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %} - {% else %} - + {% endif %} +
+ +
+
+ Chat +

+ Paint +

+
+
+ +
+
+ {% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %} + + {% else %} + {% endif %} @@ -425,7 +451,7 @@ function updateChatModel() { const chatModel = document.getElementById("chat-models").value; - const saveModelButton = document.getElementById("save-model"); + const saveModelButton = document.getElementById("save-chat-model"); saveModelButton.disabled = true; saveModelButton.innerHTML = "Saving..."; @@ -491,6 +517,38 @@ }) }; + function updatePaintModel() { + const paintModel = document.getElementById("paint-models").value; + const saveModelButton = document.getElementById("save-paint-model"); + saveModelButton.disabled = true; + saveModelButton.innerHTML = "Saving..."; + + fetch('/api/config/data/paint/model?id=' + paintModel, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + } + }) + .then(response => response.json()) + .then(data => { + if (data.status == "ok") { + saveModelButton.innerHTML = "Save"; + saveModelButton.disabled = false; + + let notificationBanner = document.getElementById("notification-banner"); + notificationBanner.innerHTML = "Paint model has been updated!"; + notificationBanner.style.display = "block"; + setTimeout(function() { + notificationBanner.style.display = "none"; + }, 5000); + + } else { + saveModelButton.innerHTML = "Error"; + saveModelButton.disabled = false; + } + }) + }; + function clearContentType(content_source) { fetch('/api/config/data/content-source/' + content_source, { method: 'DELETE', diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index a1c7dff1..fcc9fc63 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -121,7 +121,7 @@ User's Notes: ## Image Generation ## -- -image_generation_improve_prompt = PromptTemplate.from_template( +image_generation_improve_prompt_dalle = PromptTemplate.from_template( """ You are a talented creator. 
Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information originally from the query. You are provided with the following information to help you generate the prompt: @@ -143,6 +143,35 @@ Remember, now you are generating a prompt to improve the image generation. Add a Improved Query:""" ) +image_generation_improve_prompt_sd = PromptTemplate.from_template( + """ +You are a talented creator. Write 2-5 sentences with precise image composition, position details to create an image. +Use the provided context below to add specific, fine details to the image composition. +Retain any important information and follow any instructions from the original prompt. +Put any text to be rendered in the image within double quotes in your improved prompt. +You are provided with the following context to help enhance the original prompt: + +Today's Date: {current_date} +User's Location: {location} + +User's Notes: +{references} + +Online References: +{online_results} + +Conversation Log: +{chat_history} + +Original Prompt: "{query}" + +Now create an improved prompt using the context provided above to generate an image. +Retain any important information and follow any instructions from the original prompt. +Use the additional context from the user's notes, online references and conversation log to improve the image generation. 
+ +Improved Prompt:""" +) + + ## Online Search Conversation ## -- online_search_conversation = PromptTemplate.from_template( diff --git a/src/khoj/routers/api_config.py b/src/khoj/routers/api_config.py index 6f1bed30..10b1044c 100644 --- a/src/khoj/routers/api_config.py +++ b/src/khoj/routers/api_config.py @@ -341,6 +341,35 @@ async def update_search_model( return {"status": "ok"} +@api_config.post("/data/paint/model", status_code=200) +@requires(["authenticated"]) +async def update_paint_model( + request: Request, + id: str, + client: Optional[str] = None, +): + user = request.user.object + subscribed = has_required_scope(request, ["premium"]) + + if not subscribed: + raise HTTPException(status_code=403, detail="User is not subscribed to premium") + + new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id)) + + if new_config is None: + return {"status": "error", "message": "Model not found"} + + update_telemetry_state( + request=request, + telemetry_type="api", + api="set_paint_model", + client=client, + metadata={"paint_model": new_config.setting.model_name}, + ) + + return {"status": "ok"} + + @api_config.get("/index/size", response_model=Dict[str, int]) @requires(["authenticated"]) async def get_indexed_data_size(request: Request, common: CommonQueryParams): diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 90bfd8c6..835bb8c1 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -453,12 +453,14 @@ async def generate_better_image_prompt( location_data: LocationData, note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, + model_type: Optional[str] = None, ) -> str: """ Generate a better image prompt from the given query """ today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") + model_type = model_type or TextToImageModelConfig.ModelType.OPENAI if location_data: location = f"{location_data.city}, {location_data.region}, {location_data.country}" @@ 
-477,21 +479,34 @@ async def generate_better_image_prompt( elif online_results[result].get("webpages"): simplified_online_results[result] = online_results[result]["webpages"] - image_prompt = prompts.image_generation_improve_prompt.format( - query=q, - chat_history=conversation_history, - location=location_prompt, - current_date=today_date, - references=user_references, - online_results=simplified_online_results, - ) + if model_type == TextToImageModelConfig.ModelType.OPENAI: + image_prompt = prompts.image_generation_improve_prompt_dalle.format( + query=q, + chat_history=conversation_history, + location=location_prompt, + current_date=today_date, + references=user_references, + online_results=simplified_online_results, + ) + elif model_type == TextToImageModelConfig.ModelType.STABILITYAI: + image_prompt = prompts.image_generation_improve_prompt_sd.format( + query=q, + chat_history=conversation_history, + location=location_prompt, + current_date=today_date, + references=user_references, + online_results=simplified_online_results, + ) summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config() with timer("Chat actor: Generate contextual image prompt", logger): response = await send_message_to_model_wrapper(image_prompt, chat_model_option=summarizer_model) + response = response.strip() + if response.startswith(('"', "'")) and response.endswith(('"', "'")): + response = response[1:-1] - return response.strip() + return response async def send_message_to_model_wrapper( @@ -747,74 +762,110 @@ async def text_to_image( image_url = None intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 - text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() + text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user) if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 message = "Failed to generate image. 
Setup image generation on the server." return image_url or image, status_code, message, intent_type.value - elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: - logger.info("Generating image with OpenAI") - text2image_model = text_to_image_config.model_name - chat_history = "" - for chat in conversation_log.get("chat", [])[-4:]: - if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]: - chat_history += f"Q: {chat['intent']['query']}\n" - chat_history += f"A: {chat['message']}\n" - elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"): - chat_history += f"Q: Query: {chat['intent']['query']}\n" - chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n" - try: - with timer("Improve the original user query", logger): - if send_status_func: - await send_status_func("**✍🏽 Enhancing the Painting Prompt**") - improved_image_prompt = await generate_better_image_prompt( - message, - chat_history, - location_data=location_data, - note_references=references, - online_results=online_results, - ) - with timer("Generate image with OpenAI", logger): - if send_status_func: - await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}") + + text2image_model = text_to_image_config.model_name + chat_history = "" + for chat in conversation_log.get("chat", [])[-4:]: + if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]: + chat_history += f"Q: {chat['intent']['query']}\n" + chat_history += f"A: {chat['message']}\n" + elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"): + chat_history += f"Q: Query: {chat['intent']['query']}\n" + chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n" + + with timer("Improve the original user query", logger): + if send_status_func: + await send_status_func("**✍🏽 Enhancing the Painting Prompt**") + improved_image_prompt = 
await generate_better_image_prompt( + message, + chat_history, + location_data=location_data, + note_references=references, + online_results=online_results, + model_type=text_to_image_config.model_type, + ) + + if send_status_func: + await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}") + + if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: + with timer("Generate image with OpenAI", logger): + if text_to_image_config.api_key: + api_key = text_to_image_config.api_key + elif text_to_image_config.openai_config: + api_key = text_to_image_config.openai_config.api_key + else: + api_key = state.openai_client.api_key if state.openai_client else None + auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {} + try: response = state.openai_client.images.generate( - prompt=improved_image_prompt, model=text2image_model, response_format="b64_json" + prompt=improved_image_prompt, + model=text2image_model, + response_format="b64_json", + extra_headers=auth_header, ) image = response.data[0].b64_json - - with timer("Convert image to webp", logger): - # Convert png to webp for faster loading decoded_image = base64.b64decode(image) - image_io = io.BytesIO(decoded_image) - png_image = Image.open(image_io) - webp_image_io = io.BytesIO() - png_image.save(webp_image_io, "WEBP") - webp_image_bytes = webp_image_io.getvalue() - webp_image_io.close() - image_io.close() + except (openai.OpenAIError, openai.BadRequestError, openai.APIConnectionError) as e: + if "content_policy_violation" in e.message: + logger.error(f"Image Generation blocked by OpenAI: {e}") + status_code = e.status_code # type: ignore + message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore + return image_url or image, status_code, message, intent_type.value + else: + logger.error(f"Image Generation failed with {e}", exc_info=True) + message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore + status_code = 
e.status_code # type: ignore + return image_url or image, status_code, message, intent_type.value - with timer("Upload image to S3", logger): - image_url = upload_image(webp_image_bytes, user.uuid) - if image_url: - intent_type = ImageIntentType.TEXT_TO_IMAGE2 - else: - intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 - image = base64.b64encode(webp_image_bytes).decode("utf-8") - - return image_url or image, status_code, improved_image_prompt, intent_type.value - except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e: - if "content_policy_violation" in e.message: - logger.error(f"Image Generation blocked by OpenAI: {e}") - status_code = e.status_code # type: ignore - message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore - return image_url or image, status_code, message, intent_type.value - else: + elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI: + with timer("Generate image with Stability AI", logger): + try: + response = requests.post( + f"https://api.stability.ai/v2beta/stable-image/generate/sd3", + headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"}, + files={"none": ""}, + data={ + "prompt": improved_image_prompt, + "model": text2image_model, + "mode": "text-to-image", + "output_format": "png", + "seed": 1032622926, + "aspect_ratio": "1:1", + }, + ) + decoded_image = response.content + except requests.RequestException as e: logger.error(f"Image Generation failed with {e}", exc_info=True) - message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore + message = f"Image generation failed with Stability AI error: {e}" status_code = e.status_code # type: ignore return image_url or image, status_code, message, intent_type.value - return image_url or image, status_code, response, intent_type.value + + with timer("Convert image to webp", logger): + # Convert png to webp for faster loading + image_io = 
io.BytesIO(decoded_image) + png_image = Image.open(image_io) + webp_image_io = io.BytesIO() + png_image.save(webp_image_io, "WEBP") + webp_image_bytes = webp_image_io.getvalue() + webp_image_io.close() + image_io.close() + + with timer("Upload image to S3", logger): + image_url = upload_image(webp_image_bytes, user.uuid) + if image_url: + intent_type = ImageIntentType.TEXT_TO_IMAGE2 + else: + intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 + image = base64.b64encode(webp_image_bytes).decode("utf-8") + + return image_url or image, status_code, improved_image_prompt, intent_type.value class ApiUserRateLimiter: diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index e93d1ea3..b44c374c 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -249,6 +249,12 @@ def config_page(request: Request): current_search_model_option = adapters.get_user_search_model_or_default(user) + selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user) + paint_model_options = ConversationAdapters.get_text_to_image_model_options().all() + all_paint_model_options = list() + for paint_model in paint_model_options: + all_paint_model_options.append({"model_name": paint_model.model_name, "id": paint_model.id}) + notion_oauth_url = get_notion_auth_url(user) eleven_labs_enabled = is_eleven_labs_enabled() @@ -271,10 +277,12 @@ def config_page(request: Request): "anonymous_mode": state.anonymous_mode, "username": user.username, "given_name": given_name, - "conversation_options": all_conversation_options, "search_model_options": all_search_model_options, "selected_search_model_config": current_search_model_option.id, + "conversation_options": all_conversation_options, "selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None, + "paint_model_options": all_paint_model_options, + "selected_paint_model_config": selected_paint_model_config.id if selected_paint_model_config 
else None, "user_photo": user_picture, "billing_enabled": state.billing_enabled, "subscription_state": user_subscription_state,