'''
Image generation module using StableDiffusion and similar models by way of ComfyUI.

DEPENDS ON:
  LLM module
  COMFYUI_URL, COMFYUI_DIR, COMFYUI_OUTPUT_DIR, TS_SUBNET, TS_ADDRESS, DATA_DIR, IMG_CONFIG_DIR, IMG_DIR, IMG_WORKFLOWS_DIR, LOCAL_HOSTS, API.URL, PHOTOPRISM_USER*, PHOTOPRISM_URL*, PHOTOPRISM_PASS*

*unimplemented.
'''

from fastapi import APIRouter, Request, Response, Query
from starlette.datastructures import Address
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from aiohttp import ClientSession, ClientTimeout
import aiofiles
from PIL import Image
from pathlib import Path
import uuid
import json
import yaml
import ipaddress
import socket
import subprocess
import os, re, io
import random
from io import BytesIO
import base64
import asyncio
import shutil
# from photoprism.Session import Session
# from photoprism.Photo import Photo
# from webdav3.client import Client
from sijapi.routers.llm import query_ollama
from sijapi import API, L, COMFYUI_URL, COMFYUI_OUTPUT_DIR, IMG_CONFIG_PATH, IMG_DIR, IMG_WORKFLOWS_DIR

img = APIRouter()
logger = L.get_module_logger("img")
CLIENT_ID = str(uuid.uuid4())

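# Illustrative usage of the POST endpoint below (a sketch, not part of the module):
# the JSON fields mirror sd_endpoint(); the host and port are placeholders and depend
# on how the sijapi server is deployed.
#
#   curl -X POST http://localhost:8000/img \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "a quiet harbor at dawn", "model": "wallpaper", "size": "1024x768", "earlyurl": true}'
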
@img.post("/img")
|
|
@img.post("/v1/images/generations")
|
|
async def sd_endpoint(request: Request):
|
|
request_data = await request.json()
|
|
prompt = request_data.get("prompt")
|
|
model = request_data.get("model")
|
|
size = request_data.get("size")
|
|
earlyurl = request_data.get("earlyurl", None)
|
|
earlyout = "web" if earlyurl else None
|
|
|
|
image_path = await workflow(prompt=prompt, scene=model, size=size, earlyout=earlyout)
|
|
|
|
if earlyout == "web":
|
|
return JSONResponse({"image_url": image_path})
|
|
# return RedirectResponse(url=image_path, status_code=303)
|
|
else:
|
|
return JSONResponse({"image_url": image_path})
|
|
|
|
@img.get("/img")
|
|
@img.get("/v1/images/generations")
|
|
async def sd_endpoint(
|
|
request: Request,
|
|
prompt: str = Query(..., description="The prompt for image generation"),
|
|
earlyout: str = Query("output", description="specify web for a redirect, or json for a json with the local path")
|
|
):
|
|
image_path = await workflow(prompt=prompt, scene="wallpaper", earlyout=earlyout)
|
|
web_path = get_web_path(image_path)
|
|
|
|
if earlyout == "web":
|
|
return RedirectResponse(url=web_path, status_code=303)
|
|
else:
|
|
return JSONResponse({"image_url": image_path})
|
|
|
|
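# Overview of the generation pipeline implemented by workflow() below:
#   1. Resolve a scene from IMG_CONFIG_PATH (explicit scene name, else best trigger-word match).
#   2. Expand the user prompt into an image concept via the LLM module (query_ollama).
#   3. Load one of the scene's ComfyUI workflow JSON files and substitute prompts,
#      dimensions, and a random seed (update_prompt_and_get_key).
#   4. Queue the workflow with ComfyUI (queue_prompt) and either wait for the result or,
#      when earlyout is set, save it in a background task (generate_and_save_image).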
async def workflow(prompt: str, scene: str = None, size: str = None, earlyout: str = None, destination_path: str = None, downscale_to_fit: bool = False):
    scene_data = get_scene(scene)
    if not scene_data:
        scene_data = get_matching_scene(prompt)

    prompt = scene_data.get('llm_pre_prompt', '') + prompt
    prompt_model = scene_data.get('prompt_model')
    image_concept = await query_ollama(usr=prompt, sys=scene_data.get('llm_sys_msg'), model=prompt_model, max_tokens=100)

    scene_workflow = random.choice(scene_data['workflows'])
    if size:
        logger.debug(f"Specified size: {size}")

    size = size if size else scene_workflow.get('size', '1024x1024')
    width, height = map(int, size.split('x'))
    logger.debug(f"Parsed width: {width}; parsed height: {height}")

    workflow_path = Path(IMG_WORKFLOWS_DIR) / scene_workflow['workflow']
    workflow_data = json.loads(workflow_path.read_text())

    post = {
        "API_PrePrompt": scene_data['API_PrePrompt'] + image_concept + ', '.join(f"; (({trigger}))" for trigger in scene_data['triggers']),
        "API_StylePrompt": scene_data['API_StylePrompt'],
        "API_NegativePrompt": scene_data['API_NegativePrompt'],
        "width": width,
        "height": height
    }

    saved_file_key = await update_prompt_and_get_key(workflow=workflow_data, post=post, positive=image_concept)
    logger.debug(f"Saved file key: {saved_file_key}")

    prompt_id = await queue_prompt(workflow_data)
    logger.debug(f"Prompt ID: {prompt_id}")

    max_size = max(width, height) if downscale_to_fit else None
    destination_path = Path(destination_path).with_suffix(".jpg") if destination_path else IMG_DIR / f"{prompt_id}.jpg"

    if earlyout:
        asyncio.create_task(generate_and_save_image(prompt_id, saved_file_key, max_size, destination_path))
        logger.debug(f"Returning {destination_path}")
        return destination_path
    else:
        await generate_and_save_image(prompt_id, saved_file_key, max_size, destination_path)
        logger.debug(f"Returning {destination_path}")
        return destination_path

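# Worker used both awaited and as a fire-and-forget asyncio task from workflow():
# polls ComfyUI until the job completes, reads the PNG from COMFYUI_OUTPUT_DIR, and
# converts it to a JPG at destination_path.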
async def generate_and_save_image(prompt_id, saved_file_key, max_size, destination_path):
    try:
        status_data = await poll_status(prompt_id)
        image_data = await get_image(status_data, saved_file_key)
        jpg_file_path = await save_as_jpg(image_data, prompt_id, quality=90, max_size=max_size, destination_path=destination_path)

        if Path(jpg_file_path) != Path(destination_path):
            logger.error(f"Mismatch between jpg_file_path, {jpg_file_path}, and destination_path, {destination_path}")

    except Exception as e:
        logger.error(f"Error in generate_and_save_image: {e}")
        return None

def get_web_path(file_path: Path) -> str:
    uri = file_path.relative_to(IMG_DIR)
    web_path = f"{API.URL}/img/{uri}"
    return web_path

async def poll_status(prompt_id):
    """Asynchronously poll the job status until it's complete and return the status data."""
    start_time = asyncio.get_event_loop().time()
    async with ClientSession() as session:
        while True:
            elapsed_time = int(asyncio.get_event_loop().time() - start_time)
            async with session.get(f"{COMFYUI_URL}/history/{prompt_id}") as response:
                if response.status != 200:
                    raise Exception("Failed to get job status")
                status_data = await response.json()
                job_data = status_data.get(prompt_id, {})
                if job_data.get("status", {}).get("completed", False):
                    logger.debug(f"{prompt_id} completed in {elapsed_time} seconds.")
                    return job_data
            await asyncio.sleep(1)

async def get_image(status_data, key):
    """Asynchronously extract the filename and subfolder from the status data and read the file."""
    try:
        outputs = status_data.get("outputs", {})
        images_info = outputs.get(key, {}).get("images", [])
        if not images_info:
            raise Exception("No images found in the job output.")

        image_info = images_info[0]
        filename = image_info.get("filename")
        subfolder = image_info.get("subfolder", "")
        file_path = os.path.join(COMFYUI_OUTPUT_DIR, subfolder, filename)

        async with aiofiles.open(file_path, 'rb') as file:
            return await file.read()
    except Exception as e:
        raise Exception(f"Failed to get image: {e}")

async def save_as_jpg(image_data, prompt_id, max_size=None, quality=100, destination_path: Path = None):
    destination_path_png = (IMG_DIR / prompt_id).with_suffix(".png")
    destination_path_jpg = destination_path.with_suffix(".jpg") if destination_path else (IMG_DIR / prompt_id).with_suffix(".jpg")

    try:
        destination_path_png.parent.mkdir(parents=True, exist_ok=True)
        destination_path_jpg.parent.mkdir(parents=True, exist_ok=True)

        # Save the PNG
        async with aiofiles.open(destination_path_png, 'wb') as f:
            await f.write(image_data)

        # Open, possibly resize, and save as JPG
        with Image.open(destination_path_png) as img:
            if max_size and max(img.size) > max_size:
                ratio = max_size / max(img.size)
                new_size = tuple(int(x * ratio) for x in img.size)
                img = img.resize(new_size, Image.Resampling.LANCZOS)

            img.convert('RGB').save(destination_path_jpg, format='JPEG', quality=quality)

        # Remove the intermediate PNG
        os.remove(destination_path_png)

        return str(destination_path_jpg)

    except Exception as e:
        logger.error(f"Error processing image: {e}")
        return None

def set_presets(workflow_data, preset_values):
    if preset_values:
        preset_node = preset_values.get('node')
        preset_key = preset_values.get('key')
        values = preset_values.get('values')

        if preset_node and preset_key and values:
            preset_value = random.choice(values)
            if 'inputs' in workflow_data.get(preset_node, {}):
                workflow_data[preset_node]['inputs'][preset_key] = preset_value
            else:
                logger.debug("Node not found in workflow_data")
        else:
            logger.debug("Required data missing in preset_values")
    else:
        logger.debug("No preset_values found")

def get_return_path(destination_path):
    sd_dir = Path(IMG_DIR)
    if destination_path.parent.samefile(sd_dir):
        return destination_path.name
    else:
        return str(destination_path)

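# get_scene() and get_matching_scene() below expect IMG_CONFIG_PATH to be a YAML file
# with a top-level `scenes` list. A minimal illustrative entry (field values are
# placeholders; only the keys are taken from the code in this module):
#
#   scenes:
#     - scene: wallpaper
#       name: Wallpaper
#       triggers: ["landscape", "vista"]
#       llm_sys_msg: "You describe scenes for an image model."
#       llm_pre_prompt: "Describe an image of "
#       prompt_model: "llama3"
#       API_PrePrompt: "masterpiece, best quality, "
#       API_StylePrompt: "cinematic lighting"
#       API_NegativePrompt: "lowres, blurry"
#       workflows:
#         - workflow: wallpaper.json
#           size: "1024x576"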
def get_scene(scene):
    with open(IMG_CONFIG_PATH, 'r') as IMG_CONFIG_file:
        IMG_CONFIG = yaml.safe_load(IMG_CONFIG_file)
    for scene_data in IMG_CONFIG['scenes']:
        if scene_data['scene'] == scene:
            logger.debug(f"Found scene for \"{scene}\".")
            return scene_data
    return None

# Returns the scene whose trigger words appear most often in the provided prompt.
# If none match, it returns the first scene in the array, so the first scene
# should be considered the default.
def get_matching_scene(prompt):
    prompt_lower = prompt.lower()
    max_count = 0
    scene_data = None
    with open(IMG_CONFIG_PATH, 'r') as IMG_CONFIG_file:
        IMG_CONFIG = yaml.safe_load(IMG_CONFIG_file)
    for sc in IMG_CONFIG['scenes']:
        count = sum(1 for trigger in sc['triggers'] if trigger in prompt_lower)
        if count > max_count:
            max_count = count
            scene_data = sc
    if scene_data:
        logger.debug(f"Found better-matching scene: the prompt contains {max_count} words that match triggers for {scene_data.get('name')}!")
        return scene_data
    else:
        logger.debug("No matching scene found; falling back to the default (first) scene.")
        return IMG_CONFIG['scenes'][0]

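# Note: ensure_comfy() probes ComfyUI's default local port (8188) and, if it is not
# reachable, launches it in a tmux split using hard-coded paths and a `comfyui` mamba
# environment specific to this deployment.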
async def ensure_comfy(retries: int = 4, timeout: float = 6.0):
    """
    Ensures that ComfyUI is running, starting it if necessary.

    Args:
        retries (int): Number of connection attempts. Defaults to 4.
        timeout (float): Time to wait between attempts in seconds. Defaults to 6.0.

    Raises:
        RuntimeError: If ComfyUI couldn't be started or connected to after all retries.
    """
    for attempt in range(retries):
        try:
            with socket.create_connection(("127.0.0.1", 8188), timeout=2):
                logger.debug("ComfyUI is already running.")
                return
        except (socket.timeout, ConnectionRefusedError):
            if attempt == 0:  # Only try to start ComfyUI on the first failed attempt
                logger.debug("ComfyUI is not running. Starting it now...")
                try:
                    tmux_command = (
                        "tmux split-window -h "
                        "\"source /Users/sij/.zshrc; cd /Users/sij/workshop/ComfyUI; "
                        "mamba activate comfyui && "
                        "python main.py; exec $SHELL\""
                    )
                    subprocess.Popen(tmux_command, shell=True)
                    logger.debug("ComfyUI started in a new tmux session.")
                except Exception as e:
                    raise RuntimeError(f"Error starting ComfyUI: {e}")

            logger.debug(f"Attempt {attempt + 1}/{retries} failed. Waiting {timeout} seconds before retrying...")
            await asyncio.sleep(timeout)

    raise RuntimeError(f"Failed to ensure ComfyUI is running after {retries} attempts with {timeout} second intervals.")

# async def upload_and_get_shareable_link(image_path):
#     try:
#         # Set up the PhotoPrism session
#         pp_session = Session(PHOTOPRISM_USER, PHOTOPRISM_PASS, PHOTOPRISM_URL, use_https=True)
#         pp_session.create()
#
#         # Start import
#         photo = Photo(pp_session)
#         photo.start_import(path=os.path.dirname(image_path))
#
#         # Give PhotoPrism some time to process the upload
#         await asyncio.sleep(5)
#
#         # Search for the uploaded photo
#         photo_name = os.path.basename(image_path)
#         search_results = photo.search(query=f"name:{photo_name}", count=1)
#
#         if search_results['photos']:
#             photo_uuid = search_results['photos'][0]['uuid']
#             shareable_link = f"https://{PHOTOPRISM_URL}/p/{photo_uuid}"
#             return shareable_link
#         else:
#             logger.error("Could not find the uploaded photo details.")
#             return None
#     except Exception as e:
#         logger.error(f"Error in upload_and_get_shareable_link: {e}")
#         return None

@img.get("/image/{prompt_id}")
|
|
async def get_image_status(prompt_id: str):
|
|
status_data = await poll_status(prompt_id)
|
|
save_image_key = None
|
|
for key, value in status_data.get("outputs", {}).items():
|
|
if "images" in value:
|
|
save_image_key = key
|
|
break
|
|
if save_image_key:
|
|
image_data = await get_image(status_data, save_image_key)
|
|
await save_as_jpg(image_data, prompt_id)
|
|
external_url = f"https://api.lone.blue/img/{prompt_id}.jpg"
|
|
return JSONResponse({"image_url": external_url})
|
|
else:
|
|
return JSONResponse(content={"status": "Processing", "details": status_data}, status_code=202)
|
|
|
|
@img.get("/image-status/{prompt_id}")
|
|
async def get_image_processing_status(prompt_id: str):
|
|
try:
|
|
status_data = await poll_status(prompt_id)
|
|
return JSONResponse(content={"status": "Processing", "details": status_data}, status_code=200)
|
|
except Exception as e:
|
|
return JSONResponse(content={"error": str(e)}, status_code=500)
|
|
|
|
|
|
|
|
@img.options("/v1/images/generations", tags=["generations"])
|
|
async def get_generation_options():
|
|
return {
|
|
"model": {
|
|
"description": "The model to use for image generation.",
|
|
"type": "string",
|
|
"example": "stable-diffusion"
|
|
},
|
|
"prompt": {
|
|
"description": "The text prompt for the image generation.",
|
|
"type": "string",
|
|
"required": True,
|
|
"example": "A beautiful sunset over the ocean."
|
|
},
|
|
"n": {
|
|
"description": "The number of images to generate.",
|
|
"type": "integer",
|
|
"default": 1,
|
|
"example": 3
|
|
},
|
|
"size": {
|
|
"description": "The size of the generated images in 'widthxheight' format.",
|
|
"type": "string",
|
|
"default": "1024x1024",
|
|
"example": "512x512"
|
|
},
|
|
"raw": {
|
|
"description": "Whether to return raw image data or not.",
|
|
"type": "boolean",
|
|
"default": False
|
|
},
|
|
"earlyurl": {
|
|
"description": "Whether to return the URL early or wait for the image to be ready.",
|
|
"type": "boolean",
|
|
"default": False
|
|
}
|
|
}
|
|
|
|
|
|
async def load_workflow(workflow_path: str, workflow: str):
    workflow_path = workflow_path if workflow_path else os.path.join(IMG_WORKFLOWS_DIR, f"{workflow}.json" if not workflow.endswith('.json') else workflow)
    with open(workflow_path, 'r') as file:
        return json.load(file)

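# update_prompt_and_get_key() below relies on conventions inside the workflow JSON:
# a SaveImage node whose filename_prefix is "API_" marks the output node, the literal
# strings "API_PrePrompt" / "API_StylePrompt" / "API_NegativePrompt" mark prompt inputs,
# and the sentinel dimensions 1023 / 1025 are replaced by the requested width / height.
# An illustrative fragment (node ids and surrounding structure are placeholders):
#
#   {
#     "9": {"class_type": "SaveImage",
#           "inputs": {"filename_prefix": "API_", "images": ["8", 0]}},
#     "6": {"class_type": "CLIPTextEncode",
#           "inputs": {"text": "API_PrePrompt", "clip": ["4", 1]}},
#     "5": {"class_type": "EmptyLatentImage",
#           "inputs": {"width": 1023, "height": 1025, "batch_size": 1}}
#   }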
async def update_prompt_and_get_key(workflow: dict, post: dict, positive: str):
    '''
    Recurses through the workflow, substituting the dynamic values for API_PrePrompt,
    API_StylePrompt, API_NegativePrompt, width, height, and seed (random integer).
    More importantly, it finds and returns the key of the node where the file is saved,
    which we need in order to decipher the status once generation is complete.
    '''
    found_key = [None]

    def update_recursive(workflow, path=None):
        if path is None:
            path = []

        if isinstance(workflow, dict):
            for key, value in workflow.items():
                current_path = path + [key]

                if isinstance(value, dict):
                    if value.get('class_type') == 'SaveImage' and value.get('inputs', {}).get('filename_prefix') == 'API_':
                        found_key[0] = key
                    update_recursive(value, current_path)
                elif isinstance(value, list):
                    for index, item in enumerate(value):
                        update_recursive(item, current_path + [str(index)])

                if value == "API_PrePrompt":
                    workflow[key] = post.get(value, "") + positive
                elif value in ["API_StylePrompt", "API_NegativePrompt"]:
                    workflow[key] = post.get(value, "")
                elif key in ["seed", "noise_seed"]:
                    workflow[key] = random.randint(1000000000000, 9999999999999)

                elif key in ["width", "max_width", "scaled_width", "height", "max_height", "scaled_height", "side_length", "size", "value", "dimension", "dimensions", "long", "long_side", "short", "short_side", "length"]:
                    logger.debug(f"Got a hit for a dimension: {key} {value}")
                    if value == 1023:
                        workflow[key] = post.get("width", 1024)
                        logger.debug(f"Set {key} to {workflow[key]}.")
                    elif value == 1025:
                        workflow[key] = post.get("height", 1024)
                        logger.debug(f"Set {key} to {workflow[key]}.")

    update_recursive(workflow)
    return found_key[0]

async def queue_prompt(workflow_data):
    await ensure_comfy()
    async with ClientSession() as session:
        async with session.post(f"{COMFYUI_URL}/prompt", json={"prompt": workflow_data}) as response:
            if response.status == 200:
                data = await response.json()
                return data.get('prompt_id')
            else:
                raise Exception(f"Failed to queue prompt. Status code: {response.status}")