mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-30 10:53:02 +01:00
Improve handling of harmful categorized responses by Gemini
Previously Khoj would stop in the middle of response generation when the safety filters got triggered at default thresholds. This was confusing as it felt like a service error, not expected behavior. Going forward Khoj will - Only block responding to high confidence harmful content detected by Gemini's safety filters instead of using the default safety settings - Show an explanatory, conversational response (w/ harm category) when response is terminated due to Gemini's safety filters
This commit is contained in:
parent
ec1f87a896
commit
893ae60a6a
1 changed files with 106 additions and 7 deletions
|
@ -1,7 +1,18 @@
|
|||
import logging
|
||||
import random
|
||||
from threading import Thread
|
||||
|
||||
import google.generativeai as genai
|
||||
from google.generativeai.types.answer_types import FinishReason
|
||||
from google.generativeai.types.generation_types import (
|
||||
GenerateContentResponse,
|
||||
StopCandidateException,
|
||||
)
|
||||
from google.generativeai.types.safety_types import (
|
||||
HarmBlockThreshold,
|
||||
HarmCategory,
|
||||
HarmProbability,
|
||||
)
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
|
@ -32,14 +43,35 @@ def gemini_completion_with_backoff(
|
|||
model_kwargs = model_kwargs or dict()
|
||||
model_kwargs["temperature"] = temperature
|
||||
model_kwargs["max_output_tokens"] = max_tokens
|
||||
model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)
|
||||
model = genai.GenerativeModel(
|
||||
model_name,
|
||||
generation_config=model_kwargs,
|
||||
system_instruction=system_prompt,
|
||||
safety_settings={
|
||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
},
|
||||
)
|
||||
|
||||
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
|
||||
# all messages up to the last are considered to be part of the chat history
|
||||
|
||||
# Start chat session. All messages up to the last are considered to be part of the chat history
|
||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||
# the last message is considered to be the current prompt
|
||||
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
|
||||
return aggregated_response.text
|
||||
|
||||
try:
|
||||
# Generate the response. The last message is considered to be the current prompt
|
||||
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
|
||||
return aggregated_response.text
|
||||
except StopCandidateException as e:
|
||||
response_message, _ = handle_gemini_response(e.args)
|
||||
# Respond with reason for stopping
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {response_message}.\n"
|
||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||
)
|
||||
return response_message
|
||||
|
||||
|
||||
@retry(
|
||||
|
@ -79,15 +111,82 @@ def gemini_llm_thread(
|
|||
model_kwargs["temperature"] = temperature
|
||||
model_kwargs["max_output_tokens"] = max_tokens
|
||||
model_kwargs["stop_sequences"] = ["Notes:\n["]
|
||||
model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)
|
||||
model = genai.GenerativeModel(
|
||||
model_name,
|
||||
generation_config=model_kwargs,
|
||||
system_instruction=system_prompt,
|
||||
safety_settings={
|
||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
},
|
||||
)
|
||||
|
||||
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
|
||||
# all messages up to the last are considered to be part of the chat history
|
||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||
# the last message is considered to be the current prompt
|
||||
for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
|
||||
g.send(chunk.text)
|
||||
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||
message = message or chunk.text
|
||||
g.send(message)
|
||||
if stopped:
|
||||
raise StopCandidateException(message)
|
||||
except StopCandidateException as e:
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
|
||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
|
||||
finally:
|
||||
g.close()
|
||||
|
||||
|
||||
def handle_gemini_response(candidates, prompt_feedback=None):
|
||||
"""Check if Gemini response was blocked and return an explanatory error message."""
|
||||
# Check if the response was blocked due to safety concerns with the prompt
|
||||
if len(candidates) == 0 and prompt_feedback:
|
||||
message = f"\nI'd prefer to not respond to that due to **{prompt_feedback.block_reason.name}** issues with your query."
|
||||
stopped = True
|
||||
# Check if the response was blocked due to safety concerns with the generated content
|
||||
elif candidates[0].finish_reason == FinishReason.SAFETY:
|
||||
message = generate_safety_response(candidates[0].safety_ratings)
|
||||
stopped = True
|
||||
# Check if the response was stopped due to reaching maximum token limit or other reasons
|
||||
elif candidates[0].finish_reason != FinishReason.STOP:
|
||||
message = f"\nI can't talk further about that because of **{candidates[0].finish_reason.name} issue.**"
|
||||
stopped = True
|
||||
# Otherwise, the response is valid and can be used
|
||||
else:
|
||||
message = None
|
||||
stopped = False
|
||||
return message, stopped
|
||||
|
||||
|
||||
def generate_safety_response(safety_ratings):
|
||||
"""Generate a conversational response based on the safety ratings of the response."""
|
||||
# Get the safety rating with the highest probability
|
||||
max_safety_rating = sorted(safety_ratings, key=lambda x: x.probability, reverse=True)[0]
|
||||
# Remove the "HARM_CATEGORY_" prefix and title case the category name
|
||||
max_safety_category = " ".join(max_safety_rating.category.name.split("_")[2:]).title()
|
||||
# Add a bit of variety to the discomfort level based on the safety rating probability
|
||||
discomfort_level = {
|
||||
HarmProbability.HARM_PROBABILITY_UNSPECIFIED: " ",
|
||||
HarmProbability.LOW: "a bit ",
|
||||
HarmProbability.MEDIUM: "moderately ",
|
||||
HarmProbability.HIGH: random.choice(["very ", "quite ", "fairly "]),
|
||||
}[max_safety_rating.probability]
|
||||
# Generate a response using a random response template
|
||||
safety_response_choice = random.choice(
|
||||
[
|
||||
"\nUmm, I'd rather not to respond to that. The conversation has some probability of going into **{category}** territory.",
|
||||
"\nI'd prefer not to talk about **{category}** related topics. It makes me {discomfort_level}uncomfortable.",
|
||||
"\nI feel {discomfort_level}squeamish talking about **{category}** related stuff! Can we talk about something less controversial?",
|
||||
"\nThat sounds {discomfort_level}outside the [Overtone Window](https://en.wikipedia.org/wiki/Overton_window) of acceptable conversation. Should we stick to something less {category} related?",
|
||||
]
|
||||
)
|
||||
return safety_response_choice.format(
|
||||
category=max_safety_category, probability=max_safety_rating.probability.name, discomfort_level=discomfort_level
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue