commit 7c0c99a79b
parent 45a8fca64d

    cleanup

2 changed files with 223 additions and 84 deletions
llux.py (246 changed lines)
@@ -82,6 +82,9 @@ class Llux:
             base_url=tts_url,
             api_key=tts_api_key
         )
+
+        self.awaiting_own_image = False  # Flag to listen for bot's own image
+        self.awaiting_timeout = 0  # Timestamp to stop listening
 
         # Create Matrix client
         self.client = AsyncClient(self.server, self.username)
@@ -243,9 +246,10 @@ class Llux:
             await self.send_message(channel, f"Failed to send audio: {str(e)}")
 
-    async def send_image(self, channel: str, image_path: str) -> None:
+    async def send_image(self, channel: str, image_path: str) -> str:
         """
         Send an image to a Matrix channel by uploading the file and then sending an m.image message.
+        Returns the event ID of the sent message.
         """
         try:
             with open(image_path, "rb") as f:
@@ -256,11 +260,12 @@ class Llux:
             )
             if upload_error:
                 self.logger.error(f"Failed to upload image: {upload_error}")
-                return
+                return None
 
             self.logger.debug(f"Successfully uploaded image, URI: {upload_response.content_uri}")
 
-            await self.client.room_send(
+            # Send the image message and capture the response
+            send_response = await self.client.room_send(
                 room_id=channel,
                 message_type="m.room.message",
                 content={
@@ -275,40 +280,15 @@ class Llux:
                     }
                 }
             )
 
+            event_id = send_response.event_id
+            self.logger.debug(f"Image sent with event ID: {event_id}")
+            return event_id
+
         except Exception as e:
             self.logger.error(f"Error sending image: {e}", exc_info=True)
             await self.send_message(channel, f"Failed to send image: {str(e)}")
 
-    async def generate_image(self, prompt: str) -> str:
-        """
-        Run the FLUX.1-schnell pipeline for the given prompt and saves the result as a JPEG.
-        Returns the path to the file. Each call uses a new random seed for unique results.
-        """
-        self.logger.debug(f"Generating FLUX image for prompt: {prompt}")
-        rand_seed = random.randint(0, 2**32 - 1)
-        generator = torch.Generator("cpu").manual_seed(rand_seed)
-
-        try:
-            result = self.pipe(
-                prompt,
-                output_type="pil",
-                num_inference_steps=self.diffusers_steps,
-                generator=generator
-            )
-            image = result.images[0]  # PIL Image
-
-            # Save as JPEG to minimize file size
-            fd, path = tempfile.mkstemp(suffix=".jpg", dir=self.temp_dir)
-            os.close(fd)  # We only need the path; close the file descriptor
-            image = image.convert("RGB")  # Ensure we're in RGB mode
-            image.save(path, "JPEG", quality=90)
-
-            self.logger.debug(f"Generated image saved to {path}")
-            return path
-
-        except Exception as e:
-            self.logger.error(f"Error generating image with FLUX.1-schnell: {e}", exc_info=True)
-            raise e
-            return None
-
     async def download_image(self, mxc_url: str) -> Optional[str]:
         """
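The generate_image body removed above (and re-added after download_image below) drives self.pipe, which the surrounding code suggests is a preloaded FLUX.1-schnell pipeline. For reference, a minimal setup sketch with diffusers; the model ID, dtype, and offload call are assumptions, not part of this commit:

    import torch
    from diffusers import FluxPipeline

    # Assumed initialization for the self.pipe used by generate_image.
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    )
    pipe.enable_model_cpu_offload()  # trade speed for lower VRAM use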
@@ -342,20 +322,57 @@ class Llux:
             self.logger.error(f"Error downloading/processing image: {e}", exc_info=True)
             return None
 
+    async def generate_image(self, prompt: str) -> str:
+        """
+        Generate an image based on a text prompt using the FLUX pipeline.
+
+        Args:
+            prompt (str): The text description of the image to generate.
+
+        Returns:
+            str: The file path to the generated image.
+        """
+        self.logger.debug(f"Generating FLUX image for prompt: {prompt}")
+        rand_seed = random.randint(0, 2**32 - 1)
+        generator = torch.Generator("cpu").manual_seed(rand_seed)
+
+        try:
+            result = self.pipe(
+                prompt,
+                output_type="pil",
+                num_inference_steps=self.diffusers_steps,
+                generator=generator
+            )
+            image = result.images[0]  # PIL Image
+
+            fd, path = tempfile.mkstemp(suffix=".jpg", dir=self.temp_dir)
+            os.close(fd)
+            image = image.convert("RGB")
+            image.save(path, "JPEG", quality=90)
+
+            self.logger.debug(f"Generated image saved to {path}")
+            return path
+
+        except Exception as e:
+            self.logger.error(f"Error generating image with FLUX: {e}", exc_info=True)
+            raise e
+
     async def generate_tts(self, text: str) -> str:
         """
-        Generate an audio (MP3) file for the given text using an OpenAI-compatible TTS server.
-        Returns the path to the MP3 file.
+        Generate an audio file from text using a TTS service.
+
+        Args:
+            text (str): The text to convert to speech.
+
+        Returns:
+            str: The file path to the generated MP3 audio file.
         """
         self.logger.info(f"Generating TTS for text: '{text}'")
 
         # Create a temporary file to store the result
         fd, path = tempfile.mkstemp(suffix=".mp3", dir=self.temp_dir)
-        os.close(fd)  # We only need the path.
+        os.close(fd)
 
         try:
-            # Write `text` to some temporary input file if needed, or just pass it directly
-            # Here we just pass it directly to the TTS call.
             with self.tts_client.audio.speech.with_streaming_response.create(
                 model=self.tts_model,
                 voice=self.tts_voice,
@@ -363,7 +380,7 @@ class Llux:
                 response_format="mp3"
             ) as response:
                 response.stream_to_file(path)
 
             self.logger.debug(f"TTS audio saved to {path}")
             return path
         except Exception as e:
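The TTS hunk above uses the OpenAI Python SDK's streaming-response helper to write audio straight to disk. A standalone sketch of the same pattern, with placeholder server URL, model, and voice (the real values come from the bot's config):

    from openai import OpenAI

    # Any OpenAI-compatible TTS server works; these values are placeholders.
    client = OpenAI(base_url="http://localhost:8880/v1", api_key="unused")

    with client.audio.speech.with_streaming_response.create(
        model="tts-1",
        voice="alloy",
        input="Hello from llux",
        response_format="mp3",
    ) as response:
        response.stream_to_file("out.mp3")  # stream chunks to a file as they arrive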
@@ -444,29 +461,30 @@ class Llux:
         sender2: Optional[str] = None
     ) -> None:
         """
-        Send conversation messages to Ollama for a response, add that response to the history,
-        and finally send it to the channel. If 'sender2' is provided, it is prepended to the
-        final response text instead of 'sender'.
-
-        Added: check if the model is available before use; if not, log and return an error message.
+        Send conversation messages to Ollama for a response, handle tool calls if present,
+        add the response to the history, and send it to the channel.
         """
         try:
             has_image = any("images" in msg for msg in messages)
             # We decide which model to use
             use_model_key = self.vision_model if has_image else self.default_model
 
             # Double-check that the chosen model key is in self.models
             if use_model_key not in self.models:
                 error_text = (
-                    f"Requested model '{use_model_key}' is not available in config. "
+                    f"Requested model '{use_model_key}' not available. "
                     f"Available models: {', '.join(self.models.keys())}"
                 )
                 self.logger.error(error_text)
                 await self.send_message(channel, error_text)
                 return
 
             model_to_use = self.models[use_model_key]
 
+            # Define available tools
+            available_functions = {
+                "generate_image": self.generate_image,
+                "generate_tts": self.generate_tts,
+            }
+
             log_messages = [
                 {
                     **msg,
@@ -474,7 +492,8 @@ class Llux:
                 } for msg in messages
             ]
             self.logger.debug(f"Sending to Ollama - model: {model_to_use}, messages: {json.dumps(log_messages)}")
 
+            # Initial chat with tools
             response = ollama.chat(
                 model=model_to_use,
                 messages=messages,
@@ -483,21 +502,61 @@ class Llux:
                     "temperature": self.temperature,
                     "repeat_penalty": self.repeat_penalty,
                 },
+                tools=[self.generate_image, self.generate_tts]
             )
-            response_text = response["message"]["content"]
-            self.logger.debug(f"Ollama response (model: {model_to_use}): {response_text}")
 
+            response_message = response["message"]
+
+            # Handle tool calls if present
+            if response_message.get("tool_calls"):
+                for tool in response_message["tool_calls"]:
+                    function_name = tool["function"]["name"]
+                    function_args = tool["function"]["arguments"]
+                    function_to_call = available_functions.get(function_name)
+
+                    if function_to_call:
+                        self.logger.debug(f"Calling tool: {function_name} with args: {function_args}")
+                        result = await function_to_call(**function_args)
+
+                        if function_name == "generate_image":
+                            event_id = await self.send_image(channel, result)
+                            if event_id:
+                                self.temp_images[event_id] = result
+                            await self.add_history(
+                                "assistant", channel, sender, f"Generated image for: {function_args['prompt']}", result
+                            )
+                            # Set flag to listen for own image event
+                            self.awaiting_own_image = True
+                            self.awaiting_timeout = time.time() + 5
+                            await asyncio.sleep(1)  # Allow event to arrive
+                            self.awaiting_own_image = False
+                        elif function_name == "generate_tts":
+                            await self.send_audio(channel, result)
+                            await self.add_history(
+                                "assistant", channel, sender, f"Generated TTS for: {function_args['text']}"
+                            )
+
+                        # Add tool result to messages and get final response
+                        messages.append({
+                            "role": "tool",
+                            "content": f"Tool {function_name} executed successfully",
+                            "name": function_name
+                        })
+                        final_response = ollama.chat(
+                            model=model_to_use,
+                            messages=messages,
+                            options={"top_p": self.top_p, "temperature": self.temperature, "repeat_penalty": self.repeat_penalty}
+                        )
+                        response_text = final_response["message"]["content"]
+                    else:
+                        response_text = f"Tool {function_name} not found"
+            else:
+                response_text = response_message["content"]
 
             await self.add_history("assistant", channel, sender, response_text)
 
             # Inline response format with sender2 fallback
             target_user = sender2 if sender2 else sender
             final_text = f"{target_user} {response_text.strip()}"
 
-            await self.send_message(channel, final_text)
+            try:
+                await self.send_message(channel, final_text)
+            except Exception as e:
+                self.logger.error(f"Error sending message: {e}", exc_info=True)
 
         except Exception as e:
             error_msg = f"Something went wrong: {e}"
             self.logger.error(error_msg, exc_info=True)
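The tool-call branch above follows ollama-python's pattern of passing plain Python functions as tools and reading tool_calls off the reply message, exactly as the bot does. A self-contained round trip under the same assumptions (the model name is a placeholder; any locally pulled tool-capable model works):

    import ollama

    def add_numbers(a: int, b: int) -> int:
        """Add two integers."""
        return a + b

    messages = [{"role": "user", "content": "What is 2 + 3?"}]
    response = ollama.chat(model="llama3.1", messages=messages, tools=[add_numbers])

    for call in response["message"].get("tool_calls") or []:
        if call["function"]["name"] == "add_numbers":
            result = add_numbers(**call["function"]["arguments"])
            # Feed the result back so the model can phrase a final answer.
            messages.append({"role": "tool", "name": "add_numbers", "content": str(result)})

    final = ollama.chat(model="llama3.1", messages=messages)
    print(final["message"]["content"])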
@@ -710,10 +769,7 @@ class Llux:
     ) -> None:
         """
         Generate an image with the configured pipeline for the given prompt and send it to the Matrix room.
-        Then add the generated image to the user's conversation history so subsequent AI calls
-        can leverage the vision model.
-
         Provides the user with an estimated completion time based on warmup duration and number of inference steps.
+        Then add the generated image to the user's conversation history and wait for the event to be processed.
         """
         try:
             # Let user know we're working on it and estimate time based on warmup and steps
@@ -735,15 +791,16 @@ class Llux:
             self.logger.info(f"User requested image for prompt: '{prompt}'")
             path = await self.generate_image(prompt)
 
-            # Upload & send the image to the Matrix room
-            await self.send_image(channel, path)
-            # Store the image path temporarily with a unique key (e.g., timestamp)
-            temp_key = f"pending_{int(time.time()*1000)}"
-            self.temp_images[temp_key] = path
-
-        except Exception as e:
-            err_msg = f"Error generating image: {str(e)}"
-            self.logger.error(err_msg, exc_info=True)
-            await self.send_message(channel, err_msg)
-        else:
-            # Store the generated image in the conversation history for the user
+            # Upload & send the image to the Matrix room, capturing the event ID
+            event_id = await self.send_image(channel, path)
+            if not event_id:
+                raise Exception("Failed to send image, no event ID returned")
+
+            # Store the generated image in conversation history
             await self.add_history(
                 role="assistant",
                 channel=channel,
@@ -751,6 +808,21 @@ class Llux:
                 message=f"Generated image for prompt: {prompt}",
                 image_path=path
             )
 
+            # Set flag to listen for own image event for 5 seconds
+            self.awaiting_own_image = True
+            self.awaiting_timeout = time.time() + 5  # Wait up to 5 seconds
+
+            # Wait briefly to allow the event to arrive
+            await asyncio.sleep(1)
+
+            # Reset the flag after timeout
+            self.awaiting_own_image = False
+
+        except Exception as e:
+            err_msg = f"Error generating image: {str(e)}"
+            self.logger.error(err_msg, exc_info=True)
+            await self.send_message(channel, err_msg)
 
     async def handle_message(
         self,
@@ -771,7 +843,7 @@ class Llux:
         - .tts : generate and send audio
         - .model / .clear : admin commands
         """
-        self.logger.debug(f"Handling message: {message[0]} from {sender_display}")
+        self.logger.debug(f"Handling message: {message[0]} from {sender_display}, event_id: {event.event_id}")
 
         user_commands = {
             ".ai": lambda: self.ai(channel, message, sender, event),
@@ -804,14 +876,20 @@ class Llux:
     async def message_callback(self, room: MatrixRoom, event: Any) -> None:
         """
         Callback to handle messages (text or image) that arrive in the Matrix room.
-        If the message is from someone else (not the bot), it is processed accordingly.
+        Process own images when awaiting_own_image is True.
         """
         message_time = datetime.datetime.fromtimestamp(event.server_timestamp / 1000)
 
-        if message_time > self.join_time and event.sender != self.username:
+        # Check if message is recent and either from another user or the bot when awaiting an image
+        if message_time > self.join_time and (
+            event.sender != self.username or (self.awaiting_own_image and time.time() < self.awaiting_timeout)
+        ):
             try:
                 if isinstance(event, RoomMessageImage):
                     await self.handle_image(room, event)
+                    # Reset the flag after processing an image if it was the bot's own
+                    if event.sender == self.username:
+                        self.awaiting_own_image = False
                 elif isinstance(event, RoomMessageText):
                     message = event.body.split(" ")
                     sender = event.sender
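Taken together, the awaiting_own_image flag armed in the .img and tool paths and checked here forms a short-lived, one-shot window in which the bot processes its own image event, so the generated image lands in its own history. The same pattern in isolation (a sketch, not the commit's code):

    import time

    class OwnEventWindow:
        """One-shot window during which the bot's own events are processed."""

        def __init__(self) -> None:
            self.armed = False
            self.deadline = 0.0

        def arm(self, seconds: float = 5.0) -> None:
            self.armed = True
            self.deadline = time.time() + seconds

        def take(self) -> bool:
            # Returns True exactly once, and only before the deadline passes.
            if self.armed and time.time() < self.deadline:
                self.armed = False
                return True
            return False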
ponytest.py (new file, 61 lines)
@@ -0,0 +1,61 @@
+import gradio as gr
+import torch
+from torch import autocast
+from diffusers import StableDiffusionPipeline
+
+model_id = "AstraliteHeart/pony-diffusion"
+device = "cuda"
+
+
+pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
+pipe = pipe.to(device)
+
+
+block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")
+
+num_samples = 2
+
+def infer(prompt):
+    with autocast("cuda"):
+        images = pipe([prompt] * num_samples, guidance_scale=7.5)["sample"]
+
+    return images
+
+
+with block as demo:
+    gr.Markdown("<h1><center>Pony Diffusion</center></h1>")
+    gr.Markdown(
+        "pony-diffusion is a latent text-to-image diffusion model that has been conditioned on high-quality pony images through fine-tuning."
+    )
+    with gr.Group():
+        with gr.Box():
+            with gr.Row().style(mobile_collapse=False, equal_height=True):
+
+                text = gr.Textbox(
+                    label="Enter your prompt", show_label=False, max_lines=1
+                ).style(
+                    border=(True, False, True, True),
+                    rounded=(True, False, False, True),
+                    container=False,
+                )
+                btn = gr.Button("Run").style(
+                    margin=False,
+                    rounded=(False, True, True, False),
+                )
+
+    gallery = gr.Gallery(label="Generated images", show_label=False).style(
+        grid=[2], height="auto"
+    )
+    text.submit(infer, inputs=[text], outputs=gallery)
+    btn.click(infer, inputs=[text], outputs=gallery)
+
+    gr.Markdown(
+        """___
+        <p style='text-align: center'>
+        Created by https://huggingface.co/hakurei
+        <br/>
+        </p>"""
+    )
+
+
+demo.launch(debug=True)
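Note that ponytest.py targets older library APIs: gradio's pre-4.x .style()/gr.Box layout helpers and diffusers' early dict output, where pipe(...)["sample"] held the image list. On current diffusers the inference step would look roughly like this (a sketch reusing the script's pipe and num_samples; only the output access changes):

    # Modern diffusers returns an output object instead of a dict.
    def infer(prompt):
        with torch.autocast("cuda"):
            result = pipe([prompt] * num_samples, guidance_scale=7.5)
        return result.images  # list of PIL images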