mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Use shorter prompt generator to prompt SD3 to create better images
This commit is contained in:
parent
eda33e092f
commit
fdd4c02461
2 changed files with 51 additions and 9 deletions
|
@ -121,7 +121,7 @@ User's Notes:
|
|||
## Image Generation
|
||||
## --
|
||||
|
||||
image_generation_improve_prompt = PromptTemplate.from_template(
|
||||
image_generation_improve_prompt_dalle = PromptTemplate.from_template(
|
||||
"""
|
||||
You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information originally from the query. You are provided with the following information to help you generate the prompt:
|
||||
|
||||
|
@ -143,6 +143,35 @@ Remember, now you are generating a prompt to improve the image generation. Add a
|
|||
Improved Query:"""
|
||||
)
|
||||
|
||||
image_generation_improve_prompt_sd = PromptTemplate.from_template(
|
||||
"""
|
||||
You are a talented creator. Write 2-5 sentences with precise image composition, position details to create an image.
|
||||
Use the provided context below to add specific, fine details to the image composition.
|
||||
Retain any important information and follow any instructions from the original prompt.
|
||||
Put any text to be rendered in the image within double quotes in your improved prompt.
|
||||
You are provided with the following context to help enhance the original prompt:
|
||||
|
||||
Today's Date: {current_date}
|
||||
User's Location: {location}
|
||||
|
||||
User's Notes:
|
||||
{references}
|
||||
|
||||
Online References:
|
||||
{online_results}
|
||||
|
||||
Conversation Log:
|
||||
{chat_history}
|
||||
|
||||
Original Prompt: "{query}"
|
||||
|
||||
Now create an improved prompt using the context provided above to generate an image.
|
||||
Retain any important information and follow any instructions from the original prompt.
|
||||
Use the additional context from the user's notes, online references and conversation log to improve the image generation.
|
||||
|
||||
Improved Prompt:"""
|
||||
)
|
||||
|
||||
## Online Search Conversation
|
||||
## --
|
||||
online_search_conversation = PromptTemplate.from_template(
|
||||
|
|
|
@ -453,12 +453,14 @@ async def generate_better_image_prompt(
|
|||
location_data: LocationData,
|
||||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
model_type: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a better image prompt from the given query
|
||||
"""
|
||||
|
||||
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
||||
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
|
||||
|
||||
if location_data:
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
||||
|
@ -477,14 +479,24 @@ async def generate_better_image_prompt(
|
|||
elif online_results[result].get("webpages"):
|
||||
simplified_online_results[result] = online_results[result]["webpages"]
|
||||
|
||||
image_prompt = prompts.image_generation_improve_prompt.format(
|
||||
query=q,
|
||||
chat_history=conversation_history,
|
||||
location=location_prompt,
|
||||
current_date=today_date,
|
||||
references=user_references,
|
||||
online_results=simplified_online_results,
|
||||
)
|
||||
if model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||
image_prompt = prompts.image_generation_improve_prompt_dalle.format(
|
||||
query=q,
|
||||
chat_history=conversation_history,
|
||||
location=location_prompt,
|
||||
current_date=today_date,
|
||||
references=user_references,
|
||||
online_results=simplified_online_results,
|
||||
)
|
||||
elif model_type == TextToImageModelConfig.ModelType.STABILITYAI:
|
||||
image_prompt = prompts.image_generation_improve_prompt_sd.format(
|
||||
query=q,
|
||||
chat_history=conversation_history,
|
||||
location=location_prompt,
|
||||
current_date=today_date,
|
||||
references=user_references,
|
||||
online_results=simplified_online_results,
|
||||
)
|
||||
|
||||
summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
|
||||
|
||||
|
@ -774,6 +786,7 @@ async def text_to_image(
|
|||
location_data=location_data,
|
||||
note_references=references,
|
||||
online_results=online_results,
|
||||
model_type=text_to_image_config.model_type,
|
||||
)
|
||||
if send_status_func:
|
||||
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
|
||||
|
|
Loading…
Add table
Reference in a new issue