sanj 2025-04-09 21:18:53 -07:00
parent 45a8fca64d
commit 7c0c99a79b
2 changed files with 223 additions and 84 deletions

llux.py (246 changed lines)

@@ -82,6 +82,9 @@ class Llux:
base_url=tts_url,
api_key=tts_api_key
)
self.awaiting_own_image = False # Flag to listen for bot's own image
self.awaiting_timeout = 0 # Timestamp to stop listening
# Create Matrix client
self.client = AsyncClient(self.server, self.username)
@@ -243,9 +246,10 @@ class Llux:
await self.send_message(channel, f"Failed to send audio: {str(e)}")
async def send_image(self, channel: str, image_path: str) -> None:
    async def send_image(self, channel: str, image_path: str) -> Optional[str]:
"""
Send an image to a Matrix channel by uploading the file and then sending an m.image message.
        Returns the event ID of the sent message, or None on failure.
"""
try:
with open(image_path, "rb") as f:
@@ -256,11 +260,12 @@
)
if upload_error:
self.logger.error(f"Failed to upload image: {upload_error}")
return
return None
self.logger.debug(f"Successfully uploaded image, URI: {upload_response.content_uri}")
await self.client.room_send(
# Send the image message and capture the response
send_response = await self.client.room_send(
room_id=channel,
message_type="m.room.message",
content={
@@ -275,40 +280,15 @@ class Llux:
}
}
)
event_id = send_response.event_id
self.logger.debug(f"Image sent with event ID: {event_id}")
return event_id
except Exception as e:
self.logger.error(f"Error sending image: {e}", exc_info=True)
await self.send_message(channel, f"Failed to send image: {str(e)}")
async def generate_image(self, prompt: str) -> str:
"""
        Run the FLUX.1-schnell pipeline for the given prompt and save the result as a JPEG.
Returns the path to the file. Each call uses a new random seed for unique results.
"""
self.logger.debug(f"Generating FLUX image for prompt: {prompt}")
rand_seed = random.randint(0, 2**32 - 1)
generator = torch.Generator("cpu").manual_seed(rand_seed)
try:
result = self.pipe(
prompt,
output_type="pil",
num_inference_steps=self.diffusers_steps,
generator=generator
)
image = result.images[0] # PIL Image
# Save as JPEG to minimize file size
fd, path = tempfile.mkstemp(suffix=".jpg", dir=self.temp_dir)
os.close(fd) # We only need the path; close the file descriptor
image = image.convert("RGB") # Ensure we're in RGB mode
image.save(path, "JPEG", quality=90)
self.logger.debug(f"Generated image saved to {path}")
return path
except Exception as e:
self.logger.error(f"Error generating image with FLUX.1-schnell: {e}", exc_info=True)
raise e
return None
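
(A minimal robustness sketch for the new return value, assuming matrix-nio: room_send resolves to RoomSendResponse on success and RoomSendError on failure, and only the former carries event_id. The helper name is hypothetical.)

from typing import Optional
from nio import AsyncClient, RoomSendResponse

async def send_and_get_event_id(client: AsyncClient, room_id: str, content: dict) -> Optional[str]:
    # room_send returns RoomSendResponse (has .event_id) or RoomSendError (does not);
    # checking the type avoids an AttributeError in the failure path.
    resp = await client.room_send(room_id=room_id, message_type="m.room.message", content=content)
    if isinstance(resp, RoomSendResponse):
        return resp.event_id
    return None
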
async def download_image(self, mxc_url: str) -> Optional[str]:
"""
@@ -342,20 +322,57 @@
self.logger.error(f"Error downloading/processing image: {e}", exc_info=True)
return None
async def generate_image(self, prompt: str) -> str:
"""
Generate an image based on a text prompt using the FLUX pipeline.
Args:
prompt (str): The text description of the image to generate.
Returns:
str: The file path to the generated image.
"""
self.logger.debug(f"Generating FLUX image for prompt: {prompt}")
rand_seed = random.randint(0, 2**32 - 1)
generator = torch.Generator("cpu").manual_seed(rand_seed)
try:
result = self.pipe(
prompt,
output_type="pil",
num_inference_steps=self.diffusers_steps,
generator=generator
)
image = result.images[0] # PIL Image
fd, path = tempfile.mkstemp(suffix=".jpg", dir=self.temp_dir)
os.close(fd)
image = image.convert("RGB")
image.save(path, "JPEG", quality=90)
self.logger.debug(f"Generated image saved to {path}")
return path
except Exception as e:
self.logger.error(f"Error generating image with FLUX: {e}", exc_info=True)
            raise  # re-raise with the original traceback; the caller reports the error
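
(For reference, a standalone sketch of the pipeline setup this method assumes; the model name, dtype, and step count are illustrative, and FLUX.1-schnell is typically run with very few inference steps.)

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
generator = torch.Generator("cpu").manual_seed(42)  # fixed seed for a reproducible image
image = pipe("a lighthouse at dusk", num_inference_steps=4, generator=generator, output_type="pil").images[0]
image.convert("RGB").save("flux.jpg", "JPEG", quality=90)
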
async def generate_tts(self, text: str) -> str:
"""
Generate an audio (MP3) file for the given text using an OpenAI-compatible TTS server.
Returns the path to the MP3 file.
Generate an audio file from text using a TTS service.
Args:
text (str): The text to convert to speech.
Returns:
str: The file path to the generated MP3 audio file.
"""
self.logger.info(f"Generating TTS for text: '{text}'")
# Create a temporary file to store the result
fd, path = tempfile.mkstemp(suffix=".mp3", dir=self.temp_dir)
os.close(fd) # We only need the path.
os.close(fd)
try:
            # Pass the text directly to the TTS call; no intermediate input file is needed.
with self.tts_client.audio.speech.with_streaming_response.create(
model=self.tts_model,
voice=self.tts_voice,
@@ -363,7 +380,7 @@ class Llux:
response_format="mp3"
) as response:
response.stream_to_file(path)
self.logger.debug(f"TTS audio saved to {path}")
return path
except Exception as e:
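
(A standalone usage sketch of the same streaming TTS call; the server URL, model, and voice below are placeholders for whatever the config supplies.)

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")
with client.audio.speech.with_streaming_response.create(
    model="tts-1", voice="alloy", input="Hello from llux", response_format="mp3"
) as response:
    response.stream_to_file("hello.mp3")  # stream the MP3 straight to disk
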
@@ -444,29 +461,30 @@ class Llux:
sender2: Optional[str] = None
) -> None:
"""
Send conversation messages to Ollama for a response, add that response to the history,
and finally send it to the channel. If 'sender2' is provided, it is prepended to the
final response text instead of 'sender'.
Added: check if the model is available before use; if not, log and return an error message.
Send conversation messages to Ollama for a response, handle tool calls if present,
add the response to the history, and send it to the channel.
"""
try:
has_image = any("images" in msg for msg in messages)
# We decide which model to use
use_model_key = self.vision_model if has_image else self.default_model
# Double-check that the chosen model key is in self.models
if use_model_key not in self.models:
error_text = (
f"Requested model '{use_model_key}' is not available in config. "
f"Requested model '{use_model_key}' not available. "
f"Available models: {', '.join(self.models.keys())}"
)
self.logger.error(error_text)
await self.send_message(channel, error_text)
return
model_to_use = self.models[use_model_key]
# Define available tools
available_functions = {
"generate_image": self.generate_image,
"generate_tts": self.generate_tts,
}
log_messages = [
{
**msg,
@@ -474,7 +492,8 @@
} for msg in messages
]
self.logger.debug(f"Sending to Ollama - model: {model_to_use}, messages: {json.dumps(log_messages)}")
# Initial chat with tools
response = ollama.chat(
model=model_to_use,
messages=messages,
@@ -483,21 +502,61 @@
"temperature": self.temperature,
"repeat_penalty": self.repeat_penalty,
},
tools=[self.generate_image, self.generate_tts]
)
response_text = response["message"]["content"]
self.logger.debug(f"Ollama response (model: {model_to_use}): {response_text}")
response_message = response["message"]
# Handle tool calls if present
if response_message.get("tool_calls"):
for tool in response_message["tool_calls"]:
function_name = tool["function"]["name"]
function_args = tool["function"]["arguments"]
function_to_call = available_functions.get(function_name)
if function_to_call:
self.logger.debug(f"Calling tool: {function_name} with args: {function_args}")
result = await function_to_call(**function_args)
if function_name == "generate_image":
event_id = await self.send_image(channel, result)
if event_id:
self.temp_images[event_id] = result
await self.add_history(
"assistant", channel, sender, f"Generated image for: {function_args['prompt']}", result
)
# Set flag to listen for own image event
self.awaiting_own_image = True
self.awaiting_timeout = time.time() + 5
await asyncio.sleep(1) # Allow event to arrive
self.awaiting_own_image = False
elif function_name == "generate_tts":
await self.send_audio(channel, result)
await self.add_history(
"assistant", channel, sender, f"Generated TTS for: {function_args['text']}"
)
# Add tool result to messages and get final response
messages.append({
"role": "tool",
"content": f"Tool {function_name} executed successfully",
"name": function_name
})
final_response = ollama.chat(
model=model_to_use,
messages=messages,
options={"top_p": self.top_p, "temperature": self.temperature, "repeat_penalty": self.repeat_penalty}
)
response_text = final_response["message"]["content"]
else:
response_text = f"Tool {function_name} not found"
else:
response_text = response_message["content"]
await self.add_history("assistant", channel, sender, response_text)
            # Prepend the reply with sender2 when provided, otherwise with the original sender
target_user = sender2 if sender2 else sender
final_text = f"{target_user} {response_text.strip()}"
try:
await self.send_message(channel, final_text)
except Exception as e:
self.logger.error(f"Error sending message: {e}", exc_info=True)
await self.send_message(channel, final_text)
except Exception as e:
error_msg = f"Something went wrong: {e}"
self.logger.error(error_msg, exc_info=True)
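
(The tool-call round trip above, reduced to a minimal standalone sketch; assumes ollama-python >= 0.4, a locally pulled tool-capable model, and a toy echo tool. All names here are illustrative.)

import ollama

def echo(text: str) -> str:
    # Toy tool: return the input text unchanged.
    return text

tools = {"echo": echo}
messages = [{"role": "user", "content": "Use the echo tool on 'hi'."}]
response = ollama.chat(model="llama3.1", messages=messages, tools=[echo])
messages.append(response["message"])  # keep the assistant turn (with its tool_calls) in history
for call in response["message"]["tool_calls"] or []:
    fn = tools.get(call["function"]["name"])
    result = fn(**call["function"]["arguments"]) if fn else "tool not found"
    # Feed the tool result back so the model can produce a final answer
    messages.append({"role": "tool", "name": call["function"]["name"], "content": str(result)})
final = ollama.chat(model="llama3.1", messages=messages)
print(final["message"]["content"])
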
@@ -710,10 +769,7 @@ class Llux:
) -> None:
"""
Generate an image with the configured pipeline for the given prompt and send it to the Matrix room.
Then add the generated image to the user's conversation history so subsequent AI calls
can leverage the vision model.
Provides the user with an estimated completion time based on warmup duration and number of inference steps.
Then add the generated image to the user's conversation history and wait for the event to be processed.
"""
try:
# Let user know we're working on it and estimate time based on warmup and steps
@@ -735,15 +791,16 @@
self.logger.info(f"User requested image for prompt: '{prompt}'")
path = await self.generate_image(prompt)
# Upload & send the image to the Matrix room
await self.send_image(channel, path)
# Store the image path temporarily with a unique key (e.g., timestamp)
temp_key = f"pending_{int(time.time()*1000)}"
self.temp_images[temp_key] = path
except Exception as e:
err_msg = f"Error generating image: {str(e)}"
self.logger.error(err_msg, exc_info=True)
await self.send_message(channel, err_msg)
else:
# Store the generated image in the conversation history for the user
# Upload & send the image to the Matrix room, capturing the event ID
event_id = await self.send_image(channel, path)
if not event_id:
raise Exception("Failed to send image, no event ID returned")
# Store the generated image in conversation history
await self.add_history(
role="assistant",
channel=channel,
@@ -751,6 +808,21 @@
message=f"Generated image for prompt: {prompt}",
image_path=path
)
# Set flag to listen for own image event for 5 seconds
self.awaiting_own_image = True
self.awaiting_timeout = time.time() + 5 # Wait up to 5 seconds
# Wait briefly to allow the event to arrive
await asyncio.sleep(1)
# Reset the flag after timeout
self.awaiting_own_image = False
except Exception as e:
err_msg = f"Error generating image: {str(e)}"
self.logger.error(err_msg, exc_info=True)
await self.send_message(channel, err_msg)
async def handle_message(
self,
@@ -771,7 +843,7 @@
- .tts : generate and send audio
- .model / .clear : admin commands
"""
self.logger.debug(f"Handling message: {message[0]} from {sender_display}")
self.logger.debug(f"Handling message: {message[0]} from {sender_display}, event_id: {event.event_id}")
user_commands = {
".ai": lambda: self.ai(channel, message, sender, event),
@@ -804,14 +876,20 @@
async def message_callback(self, room: MatrixRoom, event: Any) -> None:
"""
Callback to handle messages (text or image) that arrive in the Matrix room.
If the message is from someone else (not the bot), it is processed accordingly.
Process own images when awaiting_own_image is True.
"""
message_time = datetime.datetime.fromtimestamp(event.server_timestamp / 1000)
if message_time > self.join_time and event.sender != self.username:
# Check if message is recent and either from another user or the bot when awaiting an image
if message_time > self.join_time and (
event.sender != self.username or (self.awaiting_own_image and time.time() < self.awaiting_timeout)
):
try:
if isinstance(event, RoomMessageImage):
await self.handle_image(room, event)
# Reset the flag after processing an image if it was the bot's own
if event.sender == self.username:
self.awaiting_own_image = False
elif isinstance(event, RoomMessageText):
message = event.body.split(" ")
sender = event.sender
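
(The awaiting_own_image mechanic in one place, as a hedged standalone sketch; the class and method names are illustrative, not part of llux.)

import time

class OwnEchoWindow:
    # A flag plus a deadline: the callback normally ignores the bot's own
    # events, but accepts them for a short window after the bot posts an
    # image, so that image can be folded into conversation history.
    def __init__(self) -> None:
        self.awaiting_own_image = False
        self.awaiting_timeout = 0.0

    def arm(self, seconds: float = 5.0) -> None:
        self.awaiting_own_image = True
        self.awaiting_timeout = time.time() + seconds

    def should_process(self, sender: str, own_user: str) -> bool:
        if sender != own_user:
            return True  # other users' events are always processed
        return self.awaiting_own_image and time.time() < self.awaiting_timeout

One caveat worth noting: a single shared flag can race when two images are generated close together; keying an asyncio.Event on the event ID returned by send_image would disambiguate.
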

ponytest.py (new file, 61 lines)

@@ -0,0 +1,61 @@
import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
model_id = "AstraliteHeart/pony-diffusion"
device = "cuda"
# revision="fp16" selects the half-precision weights branch (an older Hub convention)
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
pipe = pipe.to(device)
block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")
num_samples = 2
def infer(prompt):
    with autocast("cuda"):
        # Early diffusers API: the pipeline returned a dict keyed by "sample";
        # current releases expose an .images attribute instead.
        images = pipe([prompt] * num_samples, guidance_scale=7.5)["sample"]
    return images
with block as demo:
gr.Markdown("<h1><center>Pony Diffusion</center></h1>")
gr.Markdown(
"pony-diffusion is a latent text-to-image diffusion model that has been conditioned on high-quality pony images through fine-tuning."
)
with gr.Group():
with gr.Box():
            # Gradio 3.x styling API (.style() and gr.Box were removed in Gradio 4)
            with gr.Row().style(mobile_collapse=False, equal_height=True):
text = gr.Textbox(
label="Enter your prompt", show_label=False, max_lines=1
).style(
border=(True, False, True, True),
rounded=(True, False, False, True),
container=False,
)
btn = gr.Button("Run").style(
margin=False,
rounded=(False, True, True, False),
)
gallery = gr.Gallery(label="Generated images", show_label=False).style(
grid=[2], height="auto"
)
text.submit(infer, inputs=[text], outputs=gallery)
btn.click(infer, inputs=[text], outputs=gallery)
gr.Markdown(
"""___
<p style='text-align: center'>
Created by https://huggingface.co/hakurei
<br/>
</p>"""
)
demo.launch(debug=True)
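
(ponytest.py targets the 2022-era diffusers and Gradio 3 APIs. A rough modern equivalent of its inference path, assuming a current diffusers release where pipeline output exposes .images rather than the old ["sample"] key; the prompt and filename are placeholders.)

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "AstraliteHeart/pony-diffusion", torch_dtype=torch.float16
).to("cuda")
out = pipe("a pony in a meadow", guidance_scale=7.5, num_images_per_prompt=2)
out.images[0].save("pony.png")  # out.images is a list of PIL images
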