mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-24 07:55:07 +01:00
Merge branch 'features/add-agents-ui' of github.com:khoj-ai/khoj into features/chat-socket-streaming
This commit is contained in:
commit
6b4c4f10b5
14 changed files with 162 additions and 76 deletions
|
@ -1,7 +1,7 @@
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
bind = "0.0.0.0:42110"
|
bind = "0.0.0.0:42110"
|
||||||
workers = 4
|
workers = 8
|
||||||
worker_class = "uvicorn.workers.UvicornWorker"
|
worker_class = "uvicorn.workers.UvicornWorker"
|
||||||
timeout = 120
|
timeout = 120
|
||||||
keep_alive = 60
|
keep_alive = 60
|
||||||
|
|
|
@ -794,7 +794,6 @@
|
||||||
chatBody.dataset.conversationId = "";
|
chatBody.dataset.conversationId = "";
|
||||||
chatBody.dataset.conversationTitle = "";
|
chatBody.dataset.conversationTitle = "";
|
||||||
loadChat();
|
loadChat();
|
||||||
flashStatusInChatInput("🗑 Cleared previous conversation history");
|
|
||||||
})
|
})
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
flashStatusInChatInput("⛔️ Failed to clear conversation history");
|
flashStatusInChatInput("⛔️ Failed to clear conversation history");
|
||||||
|
@ -856,28 +855,6 @@
|
||||||
let conversationMenu = document.createElement('div');
|
let conversationMenu = document.createElement('div');
|
||||||
conversationMenu.classList.add("conversation-menu");
|
conversationMenu.classList.add("conversation-menu");
|
||||||
|
|
||||||
let deleteButton = document.createElement('button');
|
|
||||||
deleteButton.innerHTML = "Delete";
|
|
||||||
deleteButton.classList.add("delete-conversation-button");
|
|
||||||
deleteButton.classList.add("three-dot-menu-button-item");
|
|
||||||
deleteButton.addEventListener('click', function() {
|
|
||||||
let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`;
|
|
||||||
fetch(`${hostURL}${deleteURL}` , { method: "DELETE", headers })
|
|
||||||
.then(response => response.ok ? response.json() : Promise.reject(response))
|
|
||||||
.then(data => {
|
|
||||||
let chatBody = document.getElementById("chat-body");
|
|
||||||
chatBody.innerHTML = "";
|
|
||||||
chatBody.dataset.conversationId = "";
|
|
||||||
chatBody.dataset.conversationTitle = "";
|
|
||||||
loadChat();
|
|
||||||
})
|
|
||||||
.catch(err => {
|
|
||||||
return;
|
|
||||||
});
|
|
||||||
});
|
|
||||||
conversationMenu.appendChild(deleteButton);
|
|
||||||
threeDotMenu.appendChild(conversationMenu);
|
|
||||||
|
|
||||||
let editTitleButton = document.createElement('button');
|
let editTitleButton = document.createElement('button');
|
||||||
editTitleButton.innerHTML = "Rename";
|
editTitleButton.innerHTML = "Rename";
|
||||||
editTitleButton.classList.add("edit-title-button");
|
editTitleButton.classList.add("edit-title-button");
|
||||||
|
@ -903,12 +880,13 @@
|
||||||
|
|
||||||
conversationTitleInput.addEventListener('click', function(event) {
|
conversationTitleInput.addEventListener('click', function(event) {
|
||||||
event.stopPropagation();
|
event.stopPropagation();
|
||||||
|
});
|
||||||
|
conversationTitleInput.addEventListener('keydown', function(event) {
|
||||||
if (event.key === "Enter") {
|
if (event.key === "Enter") {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
conversationTitleInputButton.click();
|
conversationTitleInputButton.click();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
conversationTitleInputBox.appendChild(conversationTitleInput);
|
conversationTitleInputBox.appendChild(conversationTitleInput);
|
||||||
let conversationTitleInputButton = document.createElement('button');
|
let conversationTitleInputButton = document.createElement('button');
|
||||||
conversationTitleInputButton.innerHTML = "Save";
|
conversationTitleInputButton.innerHTML = "Save";
|
||||||
|
@ -918,7 +896,7 @@
|
||||||
let newTitle = conversationTitleInput.value;
|
let newTitle = conversationTitleInput.value;
|
||||||
if (newTitle != null) {
|
if (newTitle != null) {
|
||||||
let editURL = `/api/chat/title?client=web&conversation_id=${incomingConversationId}&title=${newTitle}`;
|
let editURL = `/api/chat/title?client=web&conversation_id=${incomingConversationId}&title=${newTitle}`;
|
||||||
fetch(`${hostURL}${editURL}` , { method: "PATCH" })
|
fetch(`${hostURL}${editURL}` , { method: "PATCH", headers })
|
||||||
.then(response => response.ok ? response.json() : Promise.reject(response))
|
.then(response => response.ok ? response.json() : Promise.reject(response))
|
||||||
.then(data => {
|
.then(data => {
|
||||||
conversationButton.textContent = newTitle;
|
conversationButton.textContent = newTitle;
|
||||||
|
@ -931,8 +909,35 @@
|
||||||
conversationTitleInputBox.appendChild(conversationTitleInputButton);
|
conversationTitleInputBox.appendChild(conversationTitleInputButton);
|
||||||
conversationMenu.appendChild(conversationTitleInputBox);
|
conversationMenu.appendChild(conversationTitleInputBox);
|
||||||
});
|
});
|
||||||
|
|
||||||
conversationMenu.appendChild(editTitleButton);
|
conversationMenu.appendChild(editTitleButton);
|
||||||
threeDotMenu.appendChild(conversationMenu);
|
threeDotMenu.appendChild(conversationMenu);
|
||||||
|
|
||||||
|
let deleteButton = document.createElement('button');
|
||||||
|
deleteButton.innerHTML = "Delete";
|
||||||
|
deleteButton.classList.add("delete-conversation-button");
|
||||||
|
deleteButton.classList.add("three-dot-menu-button-item");
|
||||||
|
deleteButton.addEventListener('click', function() {
|
||||||
|
// Ask for confirmation before deleting chat session
|
||||||
|
let confirmation = confirm('Are you sure you want to delete this chat session?');
|
||||||
|
if (!confirmation) return;
|
||||||
|
let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`;
|
||||||
|
fetch(`${hostURL}${deleteURL}` , { method: "DELETE", headers })
|
||||||
|
.then(response => response.ok ? response.json() : Promise.reject(response))
|
||||||
|
.then(data => {
|
||||||
|
let chatBody = document.getElementById("chat-body");
|
||||||
|
chatBody.innerHTML = "";
|
||||||
|
chatBody.dataset.conversationId = "";
|
||||||
|
chatBody.dataset.conversationTitle = "";
|
||||||
|
loadChat();
|
||||||
|
})
|
||||||
|
.catch(err => {
|
||||||
|
return;
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
conversationMenu.appendChild(deleteButton);
|
||||||
|
threeDotMenu.appendChild(conversationMenu);
|
||||||
});
|
});
|
||||||
threeDotMenu.appendChild(threeDotMenuButton);
|
threeDotMenu.appendChild(threeDotMenuButton);
|
||||||
conversationButton.appendChild(threeDotMenu);
|
conversationButton.appendChild(threeDotMenu);
|
||||||
|
|
|
@ -363,4 +363,4 @@ def upload_telemetry():
|
||||||
@schedule.repeat(schedule.every(31).minutes)
|
@schedule.repeat(schedule.every(31).minutes)
|
||||||
def delete_old_user_requests():
|
def delete_old_user_requests():
|
||||||
num_deleted = delete_user_requests()
|
num_deleted = delete_user_requests()
|
||||||
logger.info(f"🗑️ Deleted {num_deleted[0]} day-old user requests")
|
logger.debug(f"🗑️ Deleted {num_deleted[0]} day-old user requests")
|
||||||
|
|
|
@ -470,7 +470,7 @@ class ConversationAdapters:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_conversation_by_user(
|
def get_conversation_by_user(
|
||||||
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None
|
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None
|
||||||
):
|
) -> Optional[Conversation]:
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
conversation = (
|
conversation = (
|
||||||
Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
|
Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
|
||||||
|
@ -518,19 +518,21 @@ class ConversationAdapters:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_conversation_by_user(
|
async def aget_conversation_by_user(
|
||||||
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, slug: str = None
|
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, title: str = None
|
||||||
):
|
) -> Optional[Conversation]:
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
|
return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).afirst()
|
||||||
elif slug:
|
elif title:
|
||||||
conversation = Conversation.objects.filter(user=user, client=client_application, slug=slug)
|
return await Conversation.objects.filter(user=user, client=client_application, title=title).afirst()
|
||||||
else:
|
else:
|
||||||
conversation = Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at")
|
conversation = Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at")
|
||||||
|
|
||||||
if await conversation.aexists():
|
if await conversation.aexists():
|
||||||
return await conversation.prefetch_related("agent").afirst()
|
return await conversation.prefetch_related("agent").afirst()
|
||||||
|
|
||||||
return await Conversation.objects.acreate(user=user, client=client_application, slug=slug)
|
return await (
|
||||||
|
Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
|
||||||
|
) or await Conversation.objects.acreate(user=user, client=client_application)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def adelete_conversation_by_user(
|
async def adelete_conversation_by_user(
|
||||||
|
|
17
src/khoj/database/migrations/0031_alter_googleuser_locale.py
Normal file
17
src/khoj/database/migrations/0031_alter_googleuser_locale.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
# Generated by Django 4.2.10 on 2024-03-15 10:04
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0030_conversation_slug_and_title"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="googleuser",
|
||||||
|
name="locale",
|
||||||
|
field=models.CharField(blank=True, default=None, max_length=200, null=True),
|
||||||
|
),
|
||||||
|
]
|
14
src/khoj/database/migrations/0032_merge_20240322_0427.py
Normal file
14
src/khoj/database/migrations/0032_merge_20240322_0427.py
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
# Generated by Django 4.2.10 on 2024-03-22 04:27
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0031_agent_conversation_agent"),
|
||||||
|
("database", "0031_alter_googleuser_locale"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations: List[str] = []
|
|
@ -47,7 +47,7 @@ class GoogleUser(models.Model):
|
||||||
given_name = models.CharField(max_length=200, null=True, default=None, blank=True)
|
given_name = models.CharField(max_length=200, null=True, default=None, blank=True)
|
||||||
family_name = models.CharField(max_length=200, null=True, default=None, blank=True)
|
family_name = models.CharField(max_length=200, null=True, default=None, blank=True)
|
||||||
picture = models.CharField(max_length=200, null=True, default=None)
|
picture = models.CharField(max_length=200, null=True, default=None)
|
||||||
locale = models.CharField(max_length=200)
|
locale = models.CharField(max_length=200, null=True, default=None, blank=True)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
|
@ -1334,6 +1334,8 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
|
|
||||||
conversationTitleInput.addEventListener('click', function(event) {
|
conversationTitleInput.addEventListener('click', function(event) {
|
||||||
event.stopPropagation();
|
event.stopPropagation();
|
||||||
|
});
|
||||||
|
conversationTitleInput.addEventListener('keydown', function(event) {
|
||||||
if (event.key === "Enter") {
|
if (event.key === "Enter") {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
conversationTitleInputButton.click();
|
conversationTitleInputButton.click();
|
||||||
|
@ -1370,6 +1372,9 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
deleteButton.classList.add("delete-conversation-button");
|
deleteButton.classList.add("delete-conversation-button");
|
||||||
deleteButton.classList.add("three-dot-menu-button-item");
|
deleteButton.classList.add("three-dot-menu-button-item");
|
||||||
deleteButton.addEventListener('click', function() {
|
deleteButton.addEventListener('click', function() {
|
||||||
|
// Ask for confirmation before deleting chat session
|
||||||
|
let confirmation = confirm('Are you sure you want to delete this chat session?');
|
||||||
|
if (!confirmation) return;
|
||||||
let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`;
|
let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`;
|
||||||
fetch(deleteURL , { method: "DELETE" })
|
fetch(deleteURL , { method: "DELETE" })
|
||||||
.then(response => response.ok ? response.json() : Promise.reject(response))
|
.then(response => response.ok ? response.json() : Promise.reject(response))
|
||||||
|
@ -1379,7 +1384,6 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
chatBody.dataset.conversationId = "";
|
chatBody.dataset.conversationId = "";
|
||||||
chatBody.dataset.conversationTitle = "";
|
chatBody.dataset.conversationTitle = "";
|
||||||
loadChat();
|
loadChat();
|
||||||
flashStatusInChatInput("🗑 Cleared previous conversation history");
|
|
||||||
})
|
})
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
flashStatusInChatInput("⛔️ Failed to clear conversation history");
|
flashStatusInChatInput("⛔️ Failed to clear conversation history");
|
||||||
|
|
|
@ -47,7 +47,8 @@ def extract_questions(
|
||||||
last_new_year = current_new_year.replace(year=today.year - 1)
|
last_new_year = current_new_year.replace(year=today.year - 1)
|
||||||
|
|
||||||
prompt = prompts.extract_questions.format(
|
prompt = prompts.extract_questions.format(
|
||||||
current_date=today.strftime("%A, %Y-%m-%d"),
|
current_date=today.strftime("%Y-%m-%d"),
|
||||||
|
day_of_week=today.strftime("%A"),
|
||||||
last_new_year=last_new_year.strftime("%Y"),
|
last_new_year=last_new_year.strftime("%Y"),
|
||||||
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
|
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
|
||||||
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
|
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
|
||||||
|
|
|
@ -253,8 +253,8 @@ You are Khoj, an extremely smart and helpful search assistant with the ability t
|
||||||
- Break messages into multiple search queries when required to retrieve the relevant information.
|
- Break messages into multiple search queries when required to retrieve the relevant information.
|
||||||
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
|
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
|
||||||
|
|
||||||
What searches will you need to perform to answer the users question? Respond with search queries as list of strings in a JSON object.
|
What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object.
|
||||||
Current Date: {current_date}
|
Current Date: {day_of_week}, {current_date}
|
||||||
User's Location: {location}
|
User's Location: {location}
|
||||||
|
|
||||||
Q: How was my trip to Cambodia?
|
Q: How was my trip to Cambodia?
|
||||||
|
@ -418,7 +418,7 @@ You are Khoj, an advanced google search assistant. You are tasked with construct
|
||||||
- You will receive the conversation history as context.
|
- You will receive the conversation history as context.
|
||||||
- Add as much context from the previous questions and answers as required into your search queries.
|
- Add as much context from the previous questions and answers as required into your search queries.
|
||||||
- Break messages into multiple search queries when required to retrieve the relevant information.
|
- Break messages into multiple search queries when required to retrieve the relevant information.
|
||||||
- Use site: and after: google search operators when appropriate
|
- Use site: google search operators when appropriate
|
||||||
- You have access to the the whole internet to retrieve information.
|
- You have access to the the whole internet to retrieve information.
|
||||||
- Official, up-to-date information about you, Khoj, is available at site:khoj.dev
|
- Official, up-to-date information about you, Khoj, is available at site:khoj.dev
|
||||||
|
|
||||||
|
@ -433,7 +433,7 @@ User: I like to use Hacker News to get my tech news.
|
||||||
AI: Hacker News is an online forum for sharing and discussing the latest tech news. It is a great place to learn about new technologies and startups.
|
AI: Hacker News is an online forum for sharing and discussing the latest tech news. It is a great place to learn about new technologies and startups.
|
||||||
|
|
||||||
Q: Summarize posts about vector databases on Hacker News since Feb 2024
|
Q: Summarize posts about vector databases on Hacker News since Feb 2024
|
||||||
Khoj: {{"queries": ["site:news.ycombinator.com after:2024/02/01 vector database"]}}
|
Khoj: {{"queries": ["site:news.ycombinator.com vector database since 1 February 2024"]}}
|
||||||
|
|
||||||
History:
|
History:
|
||||||
User: I'm currently living in New York but I'm thinking about moving to San Francisco.
|
User: I'm currently living in New York but I'm thinking about moving to San Francisco.
|
||||||
|
|
|
@ -199,19 +199,26 @@ def truncate_messages(
|
||||||
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
||||||
)
|
)
|
||||||
|
|
||||||
system_message = messages.pop()
|
# Extract system message from messages
|
||||||
assert type(system_message.content) == str
|
system_message = None
|
||||||
system_message_tokens = len(encoder.encode(system_message.content))
|
for idx, message in enumerate(messages):
|
||||||
|
if message.role == "system":
|
||||||
|
system_message = messages.pop(idx)
|
||||||
|
break
|
||||||
|
|
||||||
|
system_message_tokens = (
|
||||||
|
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
|
||||||
|
)
|
||||||
|
|
||||||
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
|
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
|
||||||
|
|
||||||
|
# Drop older messages until under max supported prompt size by model
|
||||||
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
|
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
|
||||||
messages.pop()
|
messages.pop()
|
||||||
assert type(system_message.content) == str
|
|
||||||
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
|
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
|
||||||
|
|
||||||
# Truncate current message if still over max supported prompt size by model
|
# Truncate current message if still over max supported prompt size by model
|
||||||
if (tokens + system_message_tokens) > max_prompt_size:
|
if (tokens + system_message_tokens) > max_prompt_size:
|
||||||
assert type(system_message.content) == str
|
|
||||||
current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
|
current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
|
||||||
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
|
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
|
||||||
original_question = f"\n{original_question}"
|
original_question = f"\n{original_question}"
|
||||||
|
@ -223,7 +230,7 @@ def truncate_messages(
|
||||||
)
|
)
|
||||||
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
|
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
|
||||||
|
|
||||||
return messages + [system_message]
|
return messages + [system_message] if system_message else messages
|
||||||
|
|
||||||
|
|
||||||
def reciprocal_conversation_to_chatml(message_pair):
|
def reciprocal_conversation_to_chatml(message_pair):
|
||||||
|
|
|
@ -369,6 +369,7 @@ async def extract_references_and_questions(
|
||||||
# Collate search results as context for GPT
|
# Collate search results as context for GPT
|
||||||
with timer("Searching knowledge base took", logger):
|
with timer("Searching knowledge base took", logger):
|
||||||
result_list = []
|
result_list = []
|
||||||
|
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
|
||||||
for query in inferred_queries:
|
for query in inferred_queries:
|
||||||
n_items = min(n, 3) if using_offline_chat else n
|
n_items = min(n, 3) if using_offline_chat else n
|
||||||
result_list.extend(
|
result_list.extend(
|
||||||
|
|
|
@ -455,7 +455,7 @@ async def chat(
|
||||||
n: Optional[int] = 5,
|
n: Optional[int] = 5,
|
||||||
d: Optional[float] = 0.18,
|
d: Optional[float] = 0.18,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
slug: Optional[str] = None,
|
title: Optional[str] = None,
|
||||||
conversation_id: Optional[int] = None,
|
conversation_id: Optional[int] = None,
|
||||||
city: Optional[str] = None,
|
city: Optional[str] = None,
|
||||||
region: Optional[str] = None,
|
region: Optional[str] = None,
|
||||||
|
@ -482,9 +482,13 @@ async def chat(
|
||||||
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
|
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
|
||||||
|
|
||||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||||
user, request.user.client_app, conversation_id, slug
|
user, request.user.client_app, conversation_id, title
|
||||||
)
|
)
|
||||||
|
if not conversation:
|
||||||
|
return Response(
|
||||||
|
content=f"No conversation found with requested id, title", media_type="text/plain", status_code=400
|
||||||
|
)
|
||||||
|
else:
|
||||||
meta_log = conversation.conversation_log
|
meta_log = conversation.conversation_log
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Default]:
|
if conversation_commands == [ConversationCommand.Default]:
|
||||||
|
@ -557,7 +561,7 @@ async def chat(
|
||||||
intent_type=intent_type,
|
intent_type=intent_type,
|
||||||
inferred_queries=[improved_image_prompt],
|
inferred_queries=[improved_image_prompt],
|
||||||
client_application=request.user.client_app,
|
client_application=request.user.client_app,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation.id,
|
||||||
compiled_references=compiled_references,
|
compiled_references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
)
|
)
|
||||||
|
@ -575,7 +579,7 @@ async def chat(
|
||||||
conversation_commands,
|
conversation_commands,
|
||||||
user,
|
user,
|
||||||
request.user.client_app,
|
request.user.client_app,
|
||||||
conversation_id,
|
conversation.id,
|
||||||
location,
|
location,
|
||||||
user_name,
|
user_name,
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,49 +19,80 @@ class TestTruncateMessage:
|
||||||
encoder = tiktoken.encoding_for_model(model_name)
|
encoder = tiktoken.encoding_for_model(model_name)
|
||||||
|
|
||||||
def test_truncate_message_all_small(self):
|
def test_truncate_message_all_small(self):
|
||||||
chat_messages = ChatMessageFactory.build_batch(500)
|
# Arrange
|
||||||
|
chat_history = ChatMessageFactory.build_batch(500)
|
||||||
|
|
||||||
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
# Act
|
||||||
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
||||||
|
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
|
||||||
|
|
||||||
|
# Assert
|
||||||
# The original object has been modified. Verify certain properties
|
# The original object has been modified. Verify certain properties
|
||||||
assert len(chat_messages) < 500
|
assert len(chat_history) < 500
|
||||||
assert len(chat_messages) > 1
|
assert len(chat_history) > 1
|
||||||
assert tokens <= self.max_prompt_size
|
assert tokens <= self.max_prompt_size
|
||||||
|
|
||||||
def test_truncate_message_first_large(self):
|
def test_truncate_message_first_large(self):
|
||||||
chat_messages = ChatMessageFactory.build_batch(25)
|
# Arrange
|
||||||
|
chat_history = ChatMessageFactory.build_batch(25)
|
||||||
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
|
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
|
||||||
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
|
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
|
||||||
copy_big_chat_message = big_chat_message.copy()
|
copy_big_chat_message = big_chat_message.copy()
|
||||||
chat_messages.insert(0, big_chat_message)
|
chat_history.insert(0, big_chat_message)
|
||||||
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
|
||||||
|
|
||||||
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
# Act
|
||||||
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
||||||
|
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
|
||||||
|
|
||||||
|
# Assert
|
||||||
# The original object has been modified. Verify certain properties
|
# The original object has been modified. Verify certain properties
|
||||||
assert len(chat_messages) == 1
|
assert len(chat_history) == 1
|
||||||
assert prompt[0] != copy_big_chat_message
|
assert truncated_chat_history[0] != copy_big_chat_message
|
||||||
assert tokens <= self.max_prompt_size
|
assert tokens <= self.max_prompt_size
|
||||||
|
|
||||||
def test_truncate_message_last_large(self):
|
def test_truncate_message_last_large(self):
|
||||||
chat_messages = ChatMessageFactory.build_batch(25)
|
# Arrange
|
||||||
|
chat_history = ChatMessageFactory.build_batch(25)
|
||||||
|
chat_history[0].role = "system" # Mark the first message as system message
|
||||||
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
|
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
|
||||||
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
|
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
|
||||||
copy_big_chat_message = big_chat_message.copy()
|
copy_big_chat_message = big_chat_message.copy()
|
||||||
|
|
||||||
chat_messages.insert(0, big_chat_message)
|
chat_history.insert(0, big_chat_message)
|
||||||
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
|
||||||
|
|
||||||
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
# Act
|
||||||
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
||||||
|
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
|
||||||
|
|
||||||
|
# Assert
|
||||||
# The original object has been modified. Verify certain properties.
|
# The original object has been modified. Verify certain properties.
|
||||||
assert len(prompt) == (
|
assert len(truncated_chat_history) == (
|
||||||
len(chat_messages) + 1
|
len(chat_history) + 1
|
||||||
) # Because the system_prompt is popped off from the chat_messages lsit
|
) # Because the system_prompt is popped off from the chat_messages lsit
|
||||||
assert len(prompt) < 26
|
assert len(truncated_chat_history) < 26
|
||||||
assert len(prompt) > 1
|
assert len(truncated_chat_history) > 1
|
||||||
assert prompt[0] != copy_big_chat_message
|
assert truncated_chat_history[0] != copy_big_chat_message
|
||||||
assert tokens <= self.max_prompt_size
|
assert initial_tokens > self.max_prompt_size
|
||||||
|
assert final_tokens <= self.max_prompt_size
|
||||||
|
|
||||||
|
def test_truncate_single_large_non_system_message(self):
|
||||||
|
# Arrange
|
||||||
|
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
|
||||||
|
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
|
||||||
|
big_chat_message.role = "user"
|
||||||
|
copy_big_chat_message = big_chat_message.copy()
|
||||||
|
chat_messages = [big_chat_message]
|
||||||
|
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
||||||
|
|
||||||
|
# Act
|
||||||
|
truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
||||||
|
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# The original object has been modified. Verify certain properties
|
||||||
|
assert initial_tokens > self.max_prompt_size
|
||||||
|
assert final_tokens <= self.max_prompt_size
|
||||||
|
assert len(chat_messages) == 1
|
||||||
|
assert truncated_chat_history[0] != copy_big_chat_message
|
||||||
|
|
Loading…
Reference in a new issue