Misc. chat and application improvements (#652)

* Document original query when subqueries can't be generated
* Only add messages to the chat message log if they are non-empty
* When changing the search model, alert the user that all underlying data will be deleted
* Add more clarification to the prompt inputs for user name and location
* Check if has_more is in the notion results before getting next_cursor
* Update prompt template for user name/location, update confirmation message when changing search model
This commit is contained in:
sabaimran 2024-02-22 19:09:22 -08:00 committed by GitHub
parent f8ec6b4464
commit b4902090e7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 36 additions and 13 deletions

View file

@ -378,6 +378,13 @@ async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
return new_config
async def aget_user_search_model(user: KhojUser):
    """Return the user's configured search model setting, or None if none is set."""
    user_config = (
        await UserSearchModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
    )
    return user_config.setting if user_config else None
class ClientApplicationAdapters:
@staticmethod
async def aget_client_application_by_id(client_id: str, client_secret: str):
@ -639,6 +646,12 @@ class EntryAdapters:
deleted_count, _ = Entry.objects.filter(user=user, file_source=file_source).delete()
return deleted_count
@staticmethod
async def adelete_all_entries(user: KhojUser, file_source: str = None):
    """Delete the user's entries, optionally restricted to a single file source.

    Returns the result of Django's adelete() on the matching queryset.
    """
    entries_to_delete = Entry.objects.filter(user=user)
    if file_source is not None:
        entries_to_delete = entries_to_delete.filter(file_source=file_source)
    return await entries_to_delete.adelete()
@staticmethod
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
    """Return the hashed values of the user's entries indexed from the given file path."""
    matching_entries = Entry.objects.filter(user=user, file_path=file_path)
    return matching_entries.values_list("hashed_value", flat=True)
@ -674,10 +687,6 @@ class EntryAdapters:
.values_list("file_path", flat=True)
)
@staticmethod
async def adelete_all_entries(user: KhojUser):
    """Delete every indexed entry owned by the given user."""
    user_entries = Entry.objects.filter(user=user)
    return await user_entries.adelete()
@staticmethod
def get_size_of_indexed_data_in_mb(user: KhojUser):
entries = Entry.objects.filter(user=user).iterator()

View file

@ -376,6 +376,11 @@
};
function updateSearchModel() {
let confirmation = window.confirm("All your existing data will be deleted, and you will have to regenerate it. Are you sure you want to continue?");
if (!confirmation) {
return;
}
const searchModel = document.getElementById("search-models").value;
const saveSearchModelButton = document.getElementById("save-search-model");
saveSearchModelButton.disabled = true;
@ -398,7 +403,7 @@
}
let notificationBanner = document.getElementById("notification-banner");
notificationBanner.innerHTML = "When updating the language model, be sure to delete all your saved content and re-initialize.";
notificationBanner.innerHTML = "Khoj can now better understand the language of your content! Manually sync your data from one of the Khoj clients to update your knowledge base.";
notificationBanner.style.display = "block";
setTimeout(function() {
notificationBanner.style.display = "none";

View file

@ -93,7 +93,7 @@ class NotionToEntries(TextToEntries):
json=self.body_params,
).json()
responses.append(result)
if result["has_more"] == False:
if result.get("has_more", False) == False:
break
else:
self.body_params.update({"start_cursor": result["next_cursor"]})

View file

@ -441,12 +441,14 @@ You are using the **{model}** model on the **{device}**.
# --
# Prompt fragment that surfaces the user's location to the model when relevant.
user_location = PromptTemplate.from_template(
    "Mention the user's location only if it's relevant to the conversation.\n"
    "User's Location: {location}"
)
# Prompt fragment that surfaces the user's name to the model when relevant.
user_name = PromptTemplate.from_template(
    "Mention the user's name only if it's relevant to the conversation.\n"
    "User's Name: {name}"
)

View file

@ -11,7 +11,7 @@ from transformers import AutoTokenizer
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ClientApplication, KhojUser
from khoj.utils.helpers import merge_dicts
from khoj.utils.helpers import is_none_or_empty, merge_dicts
logger = logging.getLogger(__name__)
model_to_prompt_size = {
@ -159,10 +159,13 @@ def generate_chatml_messages_with_context(
rest_backnforths += reciprocal_conversation_to_chatml([user_msg, assistant_msg])[::-1]
# Format user and system messages to chatml format
system_chatml_message = [ChatMessage(content=system_message, role="system")]
user_chatml_message = [ChatMessage(content=user_message, role="user")]
messages = user_chatml_message + rest_backnforths + system_chatml_message
messages = []
if not is_none_or_empty(user_message):
messages.append(ChatMessage(content=user_message, role="user"))
if len(rest_backnforths) > 0:
messages += rest_backnforths
if not is_none_or_empty(system_message):
messages.append(ChatMessage(content=system_message, role="system"))
# Truncate oldest messages from conversation history until under max supported prompt size by model
messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)

View file

@ -262,8 +262,12 @@ async def update_search_model(
):
user = request.user.object
prev_config = await adapters.aget_user_search_model(user)
new_config = await adapters.aset_user_search_model(user, int(id))
if int(id) != prev_config.id:
await EntryAdapters.adelete_all_entries(user)
if new_config is None:
return {"status": "error", "message": "Model not found"}
else:

View file

@ -208,11 +208,11 @@ async def generate_online_subqueries(q: str, conversation_history: dict, locatio
response = json.loads(response)
response = [q.strip() for q in response if q.strip()]
if not isinstance(response, list) or not response or len(response) == 0:
logger.error(f"Invalid response for constructing subqueries: {response}")
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
return [q]
return response
except Exception as e:
logger.error(f"Invalid response for constructing subqueries: {response}")
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
return [q]