mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Merge branch 'master' into add-prompt-tracer-for-observability
This commit is contained in:
commit
adca6cbe9d
12 changed files with 104 additions and 80 deletions
|
@ -367,6 +367,11 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
|
||||||
<PopoverContent
|
<PopoverContent
|
||||||
onOpenAutoFocus={(e) => e.preventDefault()}
|
onOpenAutoFocus={(e) => e.preventDefault()}
|
||||||
className={`${props.isMobileWidth ? "w-[100vw]" : "w-full"} rounded-md`}
|
className={`${props.isMobileWidth ? "w-[100vw]" : "w-full"} rounded-md`}
|
||||||
|
side="top"
|
||||||
|
align="center"
|
||||||
|
/* Offset below text area on home page (i.e where conversationId is unset) */
|
||||||
|
sideOffset={props.conversationId ? 0 : 80}
|
||||||
|
alignOffset={0}
|
||||||
>
|
>
|
||||||
<Command className="max-w-full">
|
<Command className="max-w-full">
|
||||||
<CommandInput
|
<CommandInput
|
||||||
|
|
|
@ -361,7 +361,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||||
className={`${selectedAgent === agents[index].slug ? convertColorToBorderClass(agents[index].color) : "border-muted text-muted-foreground"} hover:cursor-pointer`}
|
className={`${selectedAgent === agents[index].slug ? convertColorToBorderClass(agents[index].color) : "border-muted text-muted-foreground"} hover:cursor-pointer`}
|
||||||
>
|
>
|
||||||
<CardTitle
|
<CardTitle
|
||||||
className="text-center text-xs font-medium flex justify-center items-center px-1.5 py-1"
|
className="text-center text-xs font-medium flex justify-center items-center whitespace-nowrap px-1.5 py-1"
|
||||||
onDoubleClick={() =>
|
onDoubleClick={() =>
|
||||||
openAgentEditCard(agents[index].slug)
|
openAgentEditCard(agents[index].slug)
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,8 @@ from khoj.processor.embeddings import EmbeddingsModel
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
BATCH_SIZE = 1000 # Define an appropriate batch size
|
||||||
|
|
||||||
|
|
||||||
class Command(BaseCommand):
|
class Command(BaseCommand):
|
||||||
help = "Convert all existing Entry objects to use a new default Search model."
|
help = "Convert all existing Entry objects to use a new default Search model."
|
||||||
|
@ -42,22 +44,24 @@ class Command(BaseCommand):
|
||||||
def handle(self, *args, **options):
|
def handle(self, *args, **options):
|
||||||
@transaction.atomic
|
@transaction.atomic
|
||||||
def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
|
def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
|
||||||
entries = Entry.objects.filter(entry_filter).all()
|
total_entries = Entry.objects.filter(entry_filter).count()
|
||||||
compiled_entries = [entry.compiled for entry in entries]
|
for start in tqdm(range(0, total_entries, BATCH_SIZE)):
|
||||||
updated_entries: List[Entry] = []
|
end = start + BATCH_SIZE
|
||||||
try:
|
entries = Entry.objects.filter(entry_filter)[start:end]
|
||||||
embeddings = embeddings_model.embed_documents(compiled_entries)
|
compiled_entries = [entry.compiled for entry in entries]
|
||||||
|
updated_entries: List[Entry] = []
|
||||||
|
try:
|
||||||
|
embeddings = embeddings_model.embed_documents(compiled_entries)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error embedding documents: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
except Exception as e:
|
for i, entry in enumerate(entries):
|
||||||
logger.error(f"Error embedding documents: {e}")
|
entry.embeddings = embeddings[i]
|
||||||
return
|
entry.search_model_id = search_model.id
|
||||||
|
updated_entries.append(entry)
|
||||||
|
|
||||||
for i, entry in enumerate(tqdm(entries)):
|
Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
|
||||||
entry.embeddings = embeddings[i]
|
|
||||||
entry.search_model_id = search_model.id
|
|
||||||
updated_entries.append(entry)
|
|
||||||
|
|
||||||
Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
|
|
||||||
|
|
||||||
search_model_config_id = options.get("search_model_id")
|
search_model_config_id = options.get("search_model_id")
|
||||||
apply = options.get("apply")
|
apply = options.get("apply")
|
||||||
|
|
|
@ -151,9 +151,7 @@ def converse_anthropic(
|
||||||
"""
|
"""
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
current_date = datetime.now()
|
current_date = datetime.now()
|
||||||
compiled_references = "\n\n".join({f"# {item}" for item in references})
|
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
|
||||||
|
|
||||||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
|
||||||
|
|
||||||
if agent and agent.personality:
|
if agent and agent.personality:
|
||||||
system_prompt = prompts.custom_personality.format(
|
system_prompt = prompts.custom_personality.format(
|
||||||
|
@ -184,16 +182,16 @@ def converse_anthropic(
|
||||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||||
return iter([prompts.no_online_results_found.format()])
|
return iter([prompts.no_online_results_found.format()])
|
||||||
|
|
||||||
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
context_message = ""
|
||||||
conversation_primer = (
|
|
||||||
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
|
|
||||||
)
|
|
||||||
if not is_none_or_empty(compiled_references):
|
if not is_none_or_empty(compiled_references):
|
||||||
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
|
context_message = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n"
|
||||||
|
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
||||||
|
context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"
|
||||||
|
|
||||||
# Setup Prompt with Primer or Conversation History
|
# Setup Prompt with Primer or Conversation History
|
||||||
messages = generate_chatml_messages_with_context(
|
messages = generate_chatml_messages_with_context(
|
||||||
conversation_primer,
|
user_query,
|
||||||
|
context_message=context_message,
|
||||||
conversation_log=conversation_log,
|
conversation_log=conversation_log,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
|
|
|
@ -156,9 +156,7 @@ def converse_gemini(
|
||||||
"""
|
"""
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
current_date = datetime.now()
|
current_date = datetime.now()
|
||||||
compiled_references = "\n\n".join({f"# {item}" for item in references})
|
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
|
||||||
|
|
||||||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
|
||||||
|
|
||||||
if agent and agent.personality:
|
if agent and agent.personality:
|
||||||
system_prompt = prompts.custom_personality.format(
|
system_prompt = prompts.custom_personality.format(
|
||||||
|
@ -190,16 +188,16 @@ def converse_gemini(
|
||||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||||
return iter([prompts.no_online_results_found.format()])
|
return iter([prompts.no_online_results_found.format()])
|
||||||
|
|
||||||
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
context_message = ""
|
||||||
conversation_primer = (
|
|
||||||
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
|
|
||||||
)
|
|
||||||
if not is_none_or_empty(compiled_references):
|
if not is_none_or_empty(compiled_references):
|
||||||
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
|
context_message = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n"
|
||||||
|
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
||||||
|
context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"
|
||||||
|
|
||||||
# Setup Prompt with Primer or Conversation History
|
# Setup Prompt with Primer or Conversation History
|
||||||
messages = generate_chatml_messages_with_context(
|
messages = generate_chatml_messages_with_context(
|
||||||
conversation_primer,
|
user_query,
|
||||||
|
context_message=context_message,
|
||||||
conversation_log=conversation_log,
|
conversation_log=conversation_log,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
|
|
|
@ -157,9 +157,9 @@ def converse_offline(
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||||
compiled_references_message = "\n\n".join({f"{item['compiled']}" for item in references})
|
|
||||||
tracer["chat_model"] = model
|
tracer["chat_model"] = model
|
||||||
|
|
||||||
|
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
|
||||||
current_date = datetime.now()
|
current_date = datetime.now()
|
||||||
|
|
||||||
if agent and agent.personality:
|
if agent and agent.personality:
|
||||||
|
@ -175,8 +175,6 @@ def converse_offline(
|
||||||
day_of_week=current_date.strftime("%A"),
|
day_of_week=current_date.strftime("%A"),
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
|
||||||
|
|
||||||
if location_data:
|
if location_data:
|
||||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||||
|
@ -186,27 +184,31 @@ def converse_offline(
|
||||||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||||
|
|
||||||
# Get Conversation Primer appropriate to Conversation Type
|
# Get Conversation Primer appropriate to Conversation Type
|
||||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references_message):
|
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
|
||||||
return iter([prompts.no_notes_found.format()])
|
return iter([prompts.no_notes_found.format()])
|
||||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||||
return iter([prompts.no_online_results_found.format()])
|
return iter([prompts.no_online_results_found.format()])
|
||||||
|
|
||||||
if ConversationCommand.Online in conversation_commands:
|
context_message = ""
|
||||||
|
if not is_none_or_empty(compiled_references):
|
||||||
|
context_message += f"{prompts.notes_conversation_offline.format(references=compiled_references)}\n\n"
|
||||||
|
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
||||||
simplified_online_results = online_results.copy()
|
simplified_online_results = online_results.copy()
|
||||||
for result in online_results:
|
for result in online_results:
|
||||||
if online_results[result].get("webpages"):
|
if online_results[result].get("webpages"):
|
||||||
simplified_online_results[result] = online_results[result]["webpages"]
|
simplified_online_results[result] = online_results[result]["webpages"]
|
||||||
|
|
||||||
conversation_primer = f"{prompts.online_search_conversation_offline.format(online_results=str(simplified_online_results))}\n{conversation_primer}"
|
context_message += (
|
||||||
if not is_none_or_empty(compiled_references_message):
|
f"{prompts.online_search_conversation_offline.format(online_results=str(simplified_online_results))}"
|
||||||
conversation_primer = f"{prompts.notes_conversation_offline.format(references=compiled_references_message)}\n\n{conversation_primer}"
|
)
|
||||||
|
|
||||||
# Setup Prompt with Primer or Conversation History
|
# Setup Prompt with Primer or Conversation History
|
||||||
messages = generate_chatml_messages_with_context(
|
messages = generate_chatml_messages_with_context(
|
||||||
conversation_primer,
|
user_query,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
conversation_log,
|
conversation_log,
|
||||||
|
context_message=context_message,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
loaded_model=offline_chat_model,
|
loaded_model=offline_chat_model,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
|
|
|
@ -154,9 +154,7 @@ def converse(
|
||||||
"""
|
"""
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
current_date = datetime.now()
|
current_date = datetime.now()
|
||||||
compiled_references = "\n\n".join({f"# {item['compiled']}" for item in references})
|
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
|
||||||
|
|
||||||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
|
||||||
|
|
||||||
if agent and agent.personality:
|
if agent and agent.personality:
|
||||||
system_prompt = prompts.custom_personality.format(
|
system_prompt = prompts.custom_personality.format(
|
||||||
|
@ -187,18 +185,18 @@ def converse(
|
||||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||||
return iter([prompts.no_online_results_found.format()])
|
return iter([prompts.no_online_results_found.format()])
|
||||||
|
|
||||||
if not is_none_or_empty(online_results):
|
context_message = ""
|
||||||
conversation_primer = (
|
|
||||||
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
|
|
||||||
)
|
|
||||||
if not is_none_or_empty(compiled_references):
|
if not is_none_or_empty(compiled_references):
|
||||||
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
|
context_message = f"{prompts.notes_conversation.format(references=compiled_references)}\n\n"
|
||||||
|
if not is_none_or_empty(online_results):
|
||||||
|
context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"
|
||||||
|
|
||||||
# Setup Prompt with Primer or Conversation History
|
# Setup Prompt with Primer or Conversation History
|
||||||
messages = generate_chatml_messages_with_context(
|
messages = generate_chatml_messages_with_context(
|
||||||
conversation_primer,
|
user_query,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
conversation_log,
|
conversation_log,
|
||||||
|
context_message=context_message,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
tokenizer_name=tokenizer_name,
|
tokenizer_name=tokenizer_name,
|
||||||
|
|
|
@ -118,6 +118,7 @@ Use my personal notes and our past conversations to inform your response.
|
||||||
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided notes or past conversations.
|
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided notes or past conversations.
|
||||||
|
|
||||||
User's Notes:
|
User's Notes:
|
||||||
|
-----
|
||||||
{references}
|
{references}
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
@ -127,6 +128,7 @@ notes_conversation_offline = PromptTemplate.from_template(
|
||||||
Use my personal notes and our past conversations to inform your response.
|
Use my personal notes and our past conversations to inform your response.
|
||||||
|
|
||||||
User's Notes:
|
User's Notes:
|
||||||
|
-----
|
||||||
{references}
|
{references}
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
@ -328,6 +330,7 @@ Use this up-to-date information from the internet to inform your response.
|
||||||
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the online data or past conversations.
|
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the online data or past conversations.
|
||||||
|
|
||||||
Information from the internet:
|
Information from the internet:
|
||||||
|
-----
|
||||||
{online_results}
|
{online_results}
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
@ -337,6 +340,7 @@ online_search_conversation_offline = PromptTemplate.from_template(
|
||||||
Use this up-to-date information from the internet to inform your response.
|
Use this up-to-date information from the internet to inform your response.
|
||||||
|
|
||||||
Information from the internet:
|
Information from the internet:
|
||||||
|
-----
|
||||||
{online_results}
|
{online_results}
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,6 +21,7 @@ from transformers import AutoTokenizer
|
||||||
|
|
||||||
from khoj.database.adapters import ConversationAdapters
|
from khoj.database.adapters import ConversationAdapters
|
||||||
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
|
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
|
||||||
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
|
from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
|
||||||
|
@ -188,8 +189,9 @@ def generate_chatml_messages_with_context(
|
||||||
query_images=None,
|
query_images=None,
|
||||||
vision_enabled=False,
|
vision_enabled=False,
|
||||||
model_type="",
|
model_type="",
|
||||||
|
context_message="",
|
||||||
):
|
):
|
||||||
"""Generate messages for ChatGPT with context from previous conversation"""
|
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
|
||||||
# Set max prompt size from user config or based on pre-configured for model and machine specs
|
# Set max prompt size from user config or based on pre-configured for model and machine specs
|
||||||
if not max_prompt_size:
|
if not max_prompt_size:
|
||||||
if loaded_model:
|
if loaded_model:
|
||||||
|
@ -203,21 +205,27 @@ def generate_chatml_messages_with_context(
|
||||||
# Extract Chat History for Context
|
# Extract Chat History for Context
|
||||||
chatml_messages: List[ChatMessage] = []
|
chatml_messages: List[ChatMessage] = []
|
||||||
for chat in conversation_log.get("chat", []):
|
for chat in conversation_log.get("chat", []):
|
||||||
message_notes = f'\n\n Notes:\n{chat.get("context")}' if chat.get("context") else "\n"
|
message_context = ""
|
||||||
|
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
|
||||||
|
message_context += chat.get("intent").get("inferred-queries")[0]
|
||||||
|
if not is_none_or_empty(chat.get("context")):
|
||||||
|
references = "\n\n".join(
|
||||||
|
{f"# File: {item['file']}\n## {item['compiled']}\n" for item in chat.get("context") or []}
|
||||||
|
)
|
||||||
|
message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"
|
||||||
|
if not is_none_or_empty(chat.get("onlineContext")):
|
||||||
|
message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
|
||||||
|
if not is_none_or_empty(message_context):
|
||||||
|
reconstructed_context_message = ChatMessage(content=message_context, role="user")
|
||||||
|
chatml_messages.insert(0, reconstructed_context_message)
|
||||||
|
|
||||||
role = "user" if chat["by"] == "you" else "assistant"
|
role = "user" if chat["by"] == "you" else "assistant"
|
||||||
|
message_content = construct_structured_message(chat["message"], chat.get("images"), model_type, vision_enabled)
|
||||||
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type"):
|
|
||||||
message_content = chat.get("intent").get("inferred-queries")[0] + message_notes
|
|
||||||
else:
|
|
||||||
message_content = chat["message"] + message_notes
|
|
||||||
|
|
||||||
message_content = construct_structured_message(message_content, chat.get("images"), model_type, vision_enabled)
|
|
||||||
|
|
||||||
reconstructed_message = ChatMessage(content=message_content, role=role)
|
reconstructed_message = ChatMessage(content=message_content, role=role)
|
||||||
|
|
||||||
chatml_messages.insert(0, reconstructed_message)
|
chatml_messages.insert(0, reconstructed_message)
|
||||||
|
|
||||||
if len(chatml_messages) >= 2 * lookback_turns:
|
if len(chatml_messages) >= 3 * lookback_turns:
|
||||||
break
|
break
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
|
@ -228,6 +236,8 @@ def generate_chatml_messages_with_context(
|
||||||
role="user",
|
role="user",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if not is_none_or_empty(context_message):
|
||||||
|
messages.append(ChatMessage(content=context_message, role="user"))
|
||||||
if len(chatml_messages) > 0:
|
if len(chatml_messages) > 0:
|
||||||
messages += chatml_messages
|
messages += chatml_messages
|
||||||
if not is_none_or_empty(system_message):
|
if not is_none_or_empty(system_message):
|
||||||
|
|
|
@ -113,7 +113,7 @@ def add_files_filter(request: Request, filter: FilesFilterRequest):
|
||||||
file_filters = ConversationAdapters.add_files_to_filter(request.user.object, conversation_id, files_filter)
|
file_filters = ConversationAdapters.add_files_to_filter(request.user.object, conversation_id, files_filter)
|
||||||
return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
|
return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error adding file filter {filter.filename}: {e}", exc_info=True)
|
logger.error(f"Error adding file filter {filter.filenames}: {e}", exc_info=True)
|
||||||
raise HTTPException(status_code=422, detail=str(e))
|
raise HTTPException(status_code=422, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -214,7 +214,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content():
|
||||||
(
|
(
|
||||||
"When was I born?",
|
"When was I born?",
|
||||||
"You were born on 1st April 1984.",
|
"You were born on 1st April 1984.",
|
||||||
["Testatron was born on 1st April 1984 in Testville."],
|
[{"compiled": "Testatron was born on 1st April 1984 in Testville.", "file": "birth.org"}],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -415,15 +415,18 @@ def test_ask_for_clarification_if_not_enough_context_in_question():
|
||||||
context = [
|
context = [
|
||||||
{
|
{
|
||||||
"compiled": f"""# Ramya
|
"compiled": f"""# Ramya
|
||||||
My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani."""
|
My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""",
|
||||||
|
"file": "Family.md",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"compiled": f"""# Fang
|
"compiled": f"""# Fang
|
||||||
My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li."""
|
My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""",
|
||||||
|
"file": "Family.md",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"compiled": f"""# Aiyla
|
"compiled": f"""# Aiyla
|
||||||
My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet."""
|
My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""",
|
||||||
|
"file": "Family.md",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -608,9 +611,11 @@ async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client, defa
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_infer_task_scheduling_request(chat_client, user_query, expected_crontime, expected_qs, unexpected_qs):
|
async def test_infer_task_scheduling_request(
|
||||||
|
chat_client, user_query, expected_crontime, expected_qs, unexpected_qs, default_user2
|
||||||
|
):
|
||||||
# Act
|
# Act
|
||||||
crontime, inferred_query, _ = await schedule_query(user_query, {})
|
crontime, inferred_query, _ = await schedule_query(user_query, {}, default_user2)
|
||||||
inferred_query = inferred_query.lower()
|
inferred_query = inferred_query.lower()
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -630,7 +635,7 @@ async def test_infer_task_scheduling_request(chat_client, user_query, expected_c
|
||||||
"scheduling_query, executing_query, generated_response, expected_should_notify",
|
"scheduling_query, executing_query, generated_response, expected_should_notify",
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"Notify me if it is going to rain tomorrow?",
|
"Notify me only if it is going to rain tomorrow?",
|
||||||
"What's the weather forecast for tomorrow?",
|
"What's the weather forecast for tomorrow?",
|
||||||
"It is sunny and warm tomorrow.",
|
"It is sunny and warm tomorrow.",
|
||||||
False,
|
False,
|
||||||
|
@ -656,10 +661,10 @@ async def test_infer_task_scheduling_request(chat_client, user_query, expected_c
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_decision_on_when_to_notify_scheduled_task_results(
|
def test_decision_on_when_to_notify_scheduled_task_results(
|
||||||
chat_client, scheduling_query, executing_query, generated_response, expected_should_notify
|
chat_client, default_user2, scheduling_query, executing_query, generated_response, expected_should_notify
|
||||||
):
|
):
|
||||||
# Act
|
# Act
|
||||||
generated_should_notify = should_notify(scheduling_query, executing_query, generated_response)
|
generated_should_notify = should_notify(scheduling_query, executing_query, generated_response, default_user2)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert generated_should_notify == expected_should_notify
|
assert generated_should_notify == expected_should_notify
|
||||||
|
|
|
@ -307,7 +307,7 @@ def test_summarize_one_file(chat_client, default_user2: KhojUser):
|
||||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||||
)
|
)
|
||||||
query = "/summarize"
|
query = "/summarize"
|
||||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
|
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||||
response_message = response.json()["response"]
|
response_message = response.json()["response"]
|
||||||
# Assert
|
# Assert
|
||||||
assert response_message != ""
|
assert response_message != ""
|
||||||
|
@ -339,7 +339,7 @@ def test_summarize_extra_text(chat_client, default_user2: KhojUser):
|
||||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||||
)
|
)
|
||||||
query = "/summarize tell me about Xiu"
|
query = "/summarize tell me about Xiu"
|
||||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
|
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||||
response_message = response.json()["response"]
|
response_message = response.json()["response"]
|
||||||
# Assert
|
# Assert
|
||||||
assert response_message != ""
|
assert response_message != ""
|
||||||
|
@ -367,7 +367,7 @@ def test_summarize_multiple_files(chat_client, default_user2: KhojUser):
|
||||||
)
|
)
|
||||||
|
|
||||||
query = "/summarize"
|
query = "/summarize"
|
||||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
|
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||||
response_message = response.json()["response"]
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -383,7 +383,7 @@ def test_summarize_no_files(chat_client, default_user2: KhojUser):
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
query = "/summarize"
|
query = "/summarize"
|
||||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
|
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||||
response_message = response.json()["response"]
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -418,11 +418,11 @@ def test_summarize_different_conversation(chat_client, default_user2: KhojUser):
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
query = "/summarize"
|
query = "/summarize"
|
||||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation2.id})
|
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation2.id)})
|
||||||
response_message_conv2 = response.json()["response"]
|
response_message_conv2 = response.json()["response"]
|
||||||
|
|
||||||
# now make sure that the file filter is still in conversation 1
|
# now make sure that the file filter is still in conversation 1
|
||||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation1.id})
|
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation1.id)})
|
||||||
response_message_conv1 = response.json()["response"]
|
response_message_conv1 = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -449,7 +449,7 @@ def test_summarize_nonexistant_file(chat_client, default_user2: KhojUser):
|
||||||
json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
|
json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
|
||||||
)
|
)
|
||||||
query = urllib.parse.quote("/summarize")
|
query = urllib.parse.quote("/summarize")
|
||||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
|
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||||
response_message = response.json()["response"]
|
response_message = response.json()["response"]
|
||||||
# Assert
|
# Assert
|
||||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||||
|
@ -481,7 +481,7 @@ def test_summarize_diff_user_file(chat_client, default_user: KhojUser, pdf_confi
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
query = "/summarize"
|
query = "/summarize"
|
||||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
|
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||||
response_message = response.json()["response"]
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
|
Loading…
Reference in a new issue