Merge branch 'features/add-agents-ui' of github.com:khoj-ai/khoj into features/chat-socket-streaming

This commit is contained in:
sabaimran 2024-03-23 11:22:00 +05:30
commit 6b4c4f10b5
14 changed files with 162 additions and 76 deletions

View file

@ -1,7 +1,7 @@
import multiprocessing
bind = "0.0.0.0:42110"
workers = 4
workers = 8
worker_class = "uvicorn.workers.UvicornWorker"
timeout = 120
keep_alive = 60

View file

@ -794,7 +794,6 @@
chatBody.dataset.conversationId = "";
chatBody.dataset.conversationTitle = "";
loadChat();
flashStatusInChatInput("🗑 Cleared previous conversation history");
})
.catch(err => {
flashStatusInChatInput("⛔️ Failed to clear conversation history");
@ -856,28 +855,6 @@
let conversationMenu = document.createElement('div');
conversationMenu.classList.add("conversation-menu");
let deleteButton = document.createElement('button');
deleteButton.innerHTML = "Delete";
deleteButton.classList.add("delete-conversation-button");
deleteButton.classList.add("three-dot-menu-button-item");
deleteButton.addEventListener('click', function() {
let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`;
fetch(`${hostURL}${deleteURL}` , { method: "DELETE", headers })
.then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => {
let chatBody = document.getElementById("chat-body");
chatBody.innerHTML = "";
chatBody.dataset.conversationId = "";
chatBody.dataset.conversationTitle = "";
loadChat();
})
.catch(err => {
return;
});
});
conversationMenu.appendChild(deleteButton);
threeDotMenu.appendChild(conversationMenu);
let editTitleButton = document.createElement('button');
editTitleButton.innerHTML = "Rename";
editTitleButton.classList.add("edit-title-button");
@ -903,12 +880,13 @@
conversationTitleInput.addEventListener('click', function(event) {
event.stopPropagation();
});
conversationTitleInput.addEventListener('keydown', function(event) {
if (event.key === "Enter") {
event.preventDefault();
conversationTitleInputButton.click();
}
});
conversationTitleInputBox.appendChild(conversationTitleInput);
let conversationTitleInputButton = document.createElement('button');
conversationTitleInputButton.innerHTML = "Save";
@ -918,7 +896,7 @@
let newTitle = conversationTitleInput.value;
if (newTitle != null) {
let editURL = `/api/chat/title?client=web&conversation_id=${incomingConversationId}&title=${newTitle}`;
fetch(`${hostURL}${editURL}` , { method: "PATCH" })
fetch(`${hostURL}${editURL}` , { method: "PATCH", headers })
.then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => {
conversationButton.textContent = newTitle;
@ -931,8 +909,35 @@
conversationTitleInputBox.appendChild(conversationTitleInputButton);
conversationMenu.appendChild(conversationTitleInputBox);
});
conversationMenu.appendChild(editTitleButton);
threeDotMenu.appendChild(conversationMenu);
let deleteButton = document.createElement('button');
deleteButton.innerHTML = "Delete";
deleteButton.classList.add("delete-conversation-button");
deleteButton.classList.add("three-dot-menu-button-item");
deleteButton.addEventListener('click', function() {
// Ask for confirmation before deleting chat session
let confirmation = confirm('Are you sure you want to delete this chat session?');
if (!confirmation) return;
let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`;
fetch(`${hostURL}${deleteURL}` , { method: "DELETE", headers })
.then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => {
let chatBody = document.getElementById("chat-body");
chatBody.innerHTML = "";
chatBody.dataset.conversationId = "";
chatBody.dataset.conversationTitle = "";
loadChat();
})
.catch(err => {
return;
});
});
conversationMenu.appendChild(deleteButton);
threeDotMenu.appendChild(conversationMenu);
});
threeDotMenu.appendChild(threeDotMenuButton);
conversationButton.appendChild(threeDotMenu);

View file

@ -363,4 +363,4 @@ def upload_telemetry():
@schedule.repeat(schedule.every(31).minutes)
def delete_old_user_requests():
num_deleted = delete_user_requests()
logger.info(f"🗑️ Deleted {num_deleted[0]} day-old user requests")
logger.debug(f"🗑️ Deleted {num_deleted[0]} day-old user requests")

View file

@ -470,7 +470,7 @@ class ConversationAdapters:
@staticmethod
def get_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None
):
) -> Optional[Conversation]:
if conversation_id:
conversation = (
Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
@ -518,19 +518,21 @@ class ConversationAdapters:
@staticmethod
async def aget_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, slug: str = None
):
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, title: str = None
) -> Optional[Conversation]:
if conversation_id:
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
elif slug:
conversation = Conversation.objects.filter(user=user, client=client_application, slug=slug)
return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).afirst()
elif title:
return await Conversation.objects.filter(user=user, client=client_application, title=title).afirst()
else:
conversation = Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at")
if await conversation.aexists():
return await conversation.prefetch_related("agent").afirst()
return await Conversation.objects.acreate(user=user, client=client_application, slug=slug)
return await (
Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
) or await Conversation.objects.acreate(user=user, client=client_application)
@staticmethod
async def adelete_conversation_by_user(

View file

@ -0,0 +1,17 @@
# Generated by Django 4.2.10 on 2024-03-15 10:04
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0030_conversation_slug_and_title"),
]
operations = [
migrations.AlterField(
model_name="googleuser",
name="locale",
field=models.CharField(blank=True, default=None, max_length=200, null=True),
),
]

View file

@ -0,0 +1,14 @@
# Generated by Django 4.2.10 on 2024-03-22 04:27
from typing import List
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("database", "0031_agent_conversation_agent"),
("database", "0031_alter_googleuser_locale"),
]
operations: List[str] = []

View file

@ -47,7 +47,7 @@ class GoogleUser(models.Model):
given_name = models.CharField(max_length=200, null=True, default=None, blank=True)
family_name = models.CharField(max_length=200, null=True, default=None, blank=True)
picture = models.CharField(max_length=200, null=True, default=None)
locale = models.CharField(max_length=200)
locale = models.CharField(max_length=200, null=True, default=None, blank=True)
def __str__(self):
return self.name

View file

@ -1334,6 +1334,8 @@ To get started, just start typing below. You can also type / to see a list of co
conversationTitleInput.addEventListener('click', function(event) {
event.stopPropagation();
});
conversationTitleInput.addEventListener('keydown', function(event) {
if (event.key === "Enter") {
event.preventDefault();
conversationTitleInputButton.click();
@ -1370,6 +1372,9 @@ To get started, just start typing below. You can also type / to see a list of co
deleteButton.classList.add("delete-conversation-button");
deleteButton.classList.add("three-dot-menu-button-item");
deleteButton.addEventListener('click', function() {
// Ask for confirmation before deleting chat session
let confirmation = confirm('Are you sure you want to delete this chat session?');
if (!confirmation) return;
let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`;
fetch(deleteURL , { method: "DELETE" })
.then(response => response.ok ? response.json() : Promise.reject(response))
@ -1379,7 +1384,6 @@ To get started, just start typing below. You can also type / to see a list of co
chatBody.dataset.conversationId = "";
chatBody.dataset.conversationTitle = "";
loadChat();
flashStatusInChatInput("🗑 Cleared previous conversation history");
})
.catch(err => {
flashStatusInChatInput("⛔️ Failed to clear conversation history");

View file

@ -47,7 +47,8 @@ def extract_questions(
last_new_year = current_new_year.replace(year=today.year - 1)
prompt = prompts.extract_questions.format(
current_date=today.strftime("%A, %Y-%m-%d"),
current_date=today.strftime("%Y-%m-%d"),
day_of_week=today.strftime("%A"),
last_new_year=last_new_year.strftime("%Y"),
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),

View file

@ -253,8 +253,8 @@ You are Khoj, an extremely smart and helpful search assistant with the ability t
- Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
What searches will you need to perform to answer the users question? Respond with search queries as list of strings in a JSON object.
Current Date: {current_date}
What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object.
Current Date: {day_of_week}, {current_date}
User's Location: {location}
Q: How was my trip to Cambodia?
@ -418,7 +418,7 @@ You are Khoj, an advanced google search assistant. You are tasked with construct
- You will receive the conversation history as context.
- Add as much context from the previous questions and answers as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information.
- Use site: and after: google search operators when appropriate
- Use site: google search operators when appropriate
- You have access to the the whole internet to retrieve information.
- Official, up-to-date information about you, Khoj, is available at site:khoj.dev
@ -433,7 +433,7 @@ User: I like to use Hacker News to get my tech news.
AI: Hacker News is an online forum for sharing and discussing the latest tech news. It is a great place to learn about new technologies and startups.
Q: Summarize posts about vector databases on Hacker News since Feb 2024
Khoj: {{"queries": ["site:news.ycombinator.com after:2024/02/01 vector database"]}}
Khoj: {{"queries": ["site:news.ycombinator.com vector database since 1 February 2024"]}}
History:
User: I'm currently living in New York but I'm thinking about moving to San Francisco.

View file

@ -199,19 +199,26 @@ def truncate_messages(
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
)
system_message = messages.pop()
assert type(system_message.content) == str
system_message_tokens = len(encoder.encode(system_message.content))
# Extract system message from messages
system_message = None
for idx, message in enumerate(messages):
if message.role == "system":
system_message = messages.pop(idx)
break
system_message_tokens = (
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
)
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
# Drop older messages until under max supported prompt size by model
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
messages.pop()
assert type(system_message.content) == str
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
# Truncate current message if still over max supported prompt size by model
if (tokens + system_message_tokens) > max_prompt_size:
assert type(system_message.content) == str
current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
original_question = f"\n{original_question}"
@ -223,7 +230,7 @@ def truncate_messages(
)
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
return messages + [system_message]
return messages + [system_message] if system_message else messages
def reciprocal_conversation_to_chatml(message_pair):

View file

@ -369,6 +369,7 @@ async def extract_references_and_questions(
# Collate search results as context for GPT
with timer("Searching knowledge base took", logger):
result_list = []
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
for query in inferred_queries:
n_items = min(n, 3) if using_offline_chat else n
result_list.extend(

View file

@ -455,7 +455,7 @@ async def chat(
n: Optional[int] = 5,
d: Optional[float] = 0.18,
stream: Optional[bool] = False,
slug: Optional[str] = None,
title: Optional[str] = None,
conversation_id: Optional[int] = None,
city: Optional[str] = None,
region: Optional[str] = None,
@ -482,10 +482,14 @@ async def chat(
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
conversation = await ConversationAdapters.aget_conversation_by_user(
user, request.user.client_app, conversation_id, slug
user, request.user.client_app, conversation_id, title
)
meta_log = conversation.conversation_log
if not conversation:
return Response(
content=f"No conversation found with requested id, title", media_type="text/plain", status_code=400
)
else:
meta_log = conversation.conversation_log
if conversation_commands == [ConversationCommand.Default]:
conversation_commands = await aget_relevant_information_sources(q, meta_log)
@ -557,7 +561,7 @@ async def chat(
intent_type=intent_type,
inferred_queries=[improved_image_prompt],
client_application=request.user.client_app,
conversation_id=conversation_id,
conversation_id=conversation.id,
compiled_references=compiled_references,
online_results=online_results,
)
@ -575,7 +579,7 @@ async def chat(
conversation_commands,
user,
request.user.client_app,
conversation_id,
conversation.id,
location,
user_name,
)

View file

@ -19,49 +19,80 @@ class TestTruncateMessage:
encoder = tiktoken.encoding_for_model(model_name)
def test_truncate_message_all_small(self):
chat_messages = ChatMessageFactory.build_batch(500)
# Arrange
chat_history = ChatMessageFactory.build_batch(500)
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert len(chat_messages) < 500
assert len(chat_messages) > 1
assert len(chat_history) < 500
assert len(chat_history) > 1
assert tokens <= self.max_prompt_size
def test_truncate_message_first_large(self):
chat_messages = ChatMessageFactory.build_batch(25)
# Arrange
chat_history = ChatMessageFactory.build_batch(25)
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
copy_big_chat_message = big_chat_message.copy()
chat_messages.insert(0, big_chat_message)
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
chat_history.insert(0, big_chat_message)
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert len(chat_messages) == 1
assert prompt[0] != copy_big_chat_message
assert len(chat_history) == 1
assert truncated_chat_history[0] != copy_big_chat_message
assert tokens <= self.max_prompt_size
def test_truncate_message_last_large(self):
chat_messages = ChatMessageFactory.build_batch(25)
# Arrange
chat_history = ChatMessageFactory.build_batch(25)
chat_history[0].role = "system" # Mark the first message as system message
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
copy_big_chat_message = big_chat_message.copy()
chat_messages.insert(0, big_chat_message)
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
chat_history.insert(0, big_chat_message)
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties.
assert len(prompt) == (
len(chat_messages) + 1
assert len(truncated_chat_history) == (
len(chat_history) + 1
) # Because the system_prompt is popped off from the chat_messages lsit
assert len(prompt) < 26
assert len(prompt) > 1
assert prompt[0] != copy_big_chat_message
assert tokens <= self.max_prompt_size
assert len(truncated_chat_history) < 26
assert len(truncated_chat_history) > 1
assert truncated_chat_history[0] != copy_big_chat_message
assert initial_tokens > self.max_prompt_size
assert final_tokens <= self.max_prompt_size
def test_truncate_single_large_non_system_message(self):
# Arrange
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
big_chat_message.role = "user"
copy_big_chat_message = big_chat_message.copy()
chat_messages = [big_chat_message]
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
# Act
truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert initial_tokens > self.max_prompt_size
assert final_tokens <= self.max_prompt_size
assert len(chat_messages) == 1
assert truncated_chat_history[0] != copy_big_chat_message