New Feature: Adding File Filtering to Conversations (#788)

* UI update for file filtered conversations
* Interactive file menu #UI to add/remove files on each conversation as references.
* Backend changes implemented to load selected file filters from a conversation into the querying process.
---------

Co-authored-by: sabaimran <narmiabas@gmail.com>
This commit is contained in:
Raghav Tirumale 2024-06-07 01:23:37 -04:00 committed by GitHub
parent 8d701ebe22
commit ba16afd3c2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 399 additions and 7 deletions

View file

@ -362,6 +362,7 @@
button.sync-data:hover { button.sync-data:hover {
background-color: var(--summer-sun); background-color: var(--summer-sun);
box-shadow: 0px 3px 0px var(--background-color); box-shadow: 0px 3px 0px var(--background-color);
cursor: pointer;
} }
.sync-force-toggle { .sync-force-toggle {
align-content: center; align-content: center;

View file

@ -0,0 +1,17 @@
# Generated by Django 4.2.10 on 2024-05-29 19:56
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0043_alter_chatmodeloptions_model_type"),
]
operations = [
migrations.AddField(
model_name="conversation",
name="file_filters",
field=models.JSONField(default=list),
),
]

View file

@ -258,6 +258,7 @@ class Conversation(BaseModel):
slug = models.CharField(max_length=200, default=None, null=True, blank=True) slug = models.CharField(max_length=200, default=None, null=True, blank=True)
title = models.CharField(max_length=200, default=None, null=True, blank=True) title = models.CharField(max_length=200, default=None, null=True, blank=True)
agent = models.ForeignKey(Agent, on_delete=models.SET_NULL, default=None, null=True, blank=True) agent = models.ForeignKey(Agent, on_delete=models.SET_NULL, default=None, null=True, blank=True)
file_filters = models.JSONField(default=list)
class PublicConversation(BaseModel): class PublicConversation(BaseModel):

View file

@ -900,8 +900,15 @@ To get started, just start typing below. You can also type / to see a list of co
} }
// Display indexing success message // Display indexing success message
flashStatusInChatInput("✅ File indexed successfully"); flashStatusInChatInput("✅ File indexed successfully");
renderAllFiles();
//get the selected-files ul first
var selectedFiles = document.getElementsByClassName("selected-files")[0];
const escapedFileName = fileName.replace(/\./g, '\\.');
const newFile = selectedFiles.querySelector(`#${escapedFileName}`);
if(!newFile){
addFileFilterToConversation(fileName);
loadFileFiltersFromConversation();
}
}) })
.catch((error) => { .catch((error) => {
console.log(error); console.log(error);
@ -1135,6 +1142,7 @@ To get started, just start typing below. You can also type / to see a list of co
if (chatBody.dataset.conversationId) { if (chatBody.dataset.conversationId) {
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`; chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
setupWebSocket(); setupWebSocket();
loadFileFiltersFromConversation();
} }
if (window.screen.width < 700) { if (window.screen.width < 700) {
@ -1172,6 +1180,7 @@ To get started, just start typing below. You can also type / to see a list of co
// Render conversation history, if any // Render conversation history, if any
let chatBody = document.getElementById("chat-body"); let chatBody = document.getElementById("chat-body");
chatBody.dataset.conversationId = response.conversation_id; chatBody.dataset.conversationId = response.conversation_id;
loadFileFiltersFromConversation();
setupWebSocket(); setupWebSocket();
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`; chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
@ -1320,7 +1329,6 @@ To get started, just start typing below. You can also type / to see a list of co
document.getElementById("chat-input").value = query_via_url; document.getElementById("chat-input").value = query_via_url;
chat(); chat();
} }
} }
function fetchRemainingChatMessages(chatHistoryUrl) { function fetchRemainingChatMessages(chatHistoryUrl) {
@ -1813,6 +1821,230 @@ To get started, just start typing below. You can also type / to see a list of co
document.getElementById('existing-conversations').classList.toggle('collapsed'); document.getElementById('existing-conversations').classList.toggle('collapsed');
document.getElementById('side-panel-collapse').style.transform = document.getElementById('side-panel').classList.contains('collapsed') ? 'rotate(0deg)' : 'rotate(180deg)'; document.getElementById('side-panel-collapse').style.transform = document.getElementById('side-panel').classList.contains('collapsed') ? 'rotate(0deg)' : 'rotate(180deg)';
} }
var allFiles;
function renderAllFiles() {
fetch('/api/config/data/computer')
.then(response => response.json())
.then(data => {
var indexedFiles = document.getElementsByClassName("indexed-files")[0];
indexedFiles.innerHTML = "";
for (var filename of data) {
var listItem = document.createElement("li");
listItem.className = "fileName";
listItem.id = filename;
listItem.textContent = filename;
listItem.addEventListener('click', function() {
handleFileClick(this.id);
});
indexedFiles.appendChild(listItem);
}
allFiles = data;
var nofilesmessage = document.getElementsByClassName("no-files-message")[0];
if(allFiles.length === 0){
nofilesmessage.innerHTML = "No files found. Visit ";
nofilesmessage.innerHTML += "<a href=\"https://docs.khoj.dev/category/clients/\">Documentation</a>"
}
else{
nofilesmessage.innerHTML = "";
}
})
.catch((error) => {
console.error('Error:', error);
});
}
function renderFilteredFiles(){
var indexedFiles = document.getElementsByClassName("indexed-files")[0];
indexedFiles.innerHTML = "";
var input = document.getElementsByClassName("file-input")[0];
var filter = input.value.toUpperCase();
for (var filename of allFiles) {
if (filename.toUpperCase().indexOf(filter) > -1) {
var listItem = document.createElement("li");
listItem.className = "fileName";
listItem.id = filename;
listItem.textContent = filename;
// Add an event listener for the click event
listItem.addEventListener('click', function() {
handleFileClick(this.id);
});
// Append the list item to the indexed files container
indexedFiles.appendChild(listItem);
}
}
}
function handleFileClick(elementId) {
var element = document.getElementById(elementId);
if (element) {
var selectedFiles = document.getElementsByClassName("selected-files")[0];
var selectedFile = document.getElementById(elementId);
// Check if the element has a background color indicating selection
if (element.style.backgroundColor === "var(--primary-hover)") {
// Remove the file filter from the conversation
removeFileFilterFromConversation(elementId);
// Remove the selected file from the list of selected files
if (selectedFile) {
selectedFiles.removeChild(selectedFile);
}
var selectedFile = document.getElementById(elementId);
selectedFile.style.backgroundColor = "var(--primary)";
selectedFile.style.border = "1px solid var(--primary-hover)";
} else {
// If the element is not selected, select it
element.style.backgroundColor = "var(--primary-hover)"; // Set background color
element.style.border = "3px solid orange"; // Set border
// Add the file filter to the conversation
addFileFilterToConversation(elementId);
// Add the selected file to the list of selected files
var li = document.createElement("li");
li.className = "fileName";
li.id = elementId;
li.style.backgroundColor = "var(--primary-hover)"; // match the style
li.style.border = "3px solid orange"; // match the style
li.innerText = elementId;
selectedFiles.appendChild(li);
}
} else {
console.error('Element with id', elementId, 'not found.');
}
}
function addFileFilterToConversation(filename) {
var conversation_id = document.getElementById("chat-body").dataset.conversationId;
if (!conversation_id) {
console.error("Conversation ID not found on chat-body element.");
return;
}
return fetch(`/api/chat/conversation/file-filters`, {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ filename, conversation_id }) // Pass the filename directly
})
.then(response => {
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
return response.json();
})
.then(data => {
console.log("Response from server:", data);
return data;
})
.catch(error => {
console.error("Error:", error);
throw error;
});
}
function removeFileFilterFromConversation(filename) {
var conversation_id = document.getElementById("chat-body").dataset.conversationId;
if (!conversation_id) {
console.error("Conversation ID not found on chat-body element.");
return;
}
return fetch(`/api/chat/conversation/file-filters`, {
method: 'DELETE',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ filename, conversation_id }) // Pass the filename directly
})
.then(response => {
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
return response.json();
})
.then(data => {
console.log("Response from server:", data);
return data;
})
.catch(error => {
console.error("Error:", error);
throw error;
});
}
function getFileFiltersFromConversation() {
// Get the conversation_id from the data attribute
var conversation_id = document.getElementById("chat-body").dataset.conversationId;
// Make sure conversation_id is not undefined or null
if (!conversation_id) {
console.error("No conversation ID found on chat-body element.");
return Promise.reject("No conversation ID found on chat-body element.");
}
// Perform the fetch request
return fetch(`/api/chat/conversation/file-filters/${conversation_id}`, {
method: 'GET',
headers: {
'Content-Type': 'application/json'
}
})
.then(function(response) {
console.log("Response status:", response.status); // Log the response status
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
return response.json();
})
.then(function(data) {
console.log("Response from server:", data);
return data;
})
.catch(function(error) {
console.error("Error:", error);
throw error; // Rethrow the error to be handled elsewhere if needed
});
}
function loadFileFiltersFromConversation(){
getFileFiltersFromConversation()
.then(filters => {
var selectedFiles = document.getElementsByClassName("selected-files")[0];
selectedFiles.innerHTML = "";
for (var filter of filters) {
var li = document.createElement("li");
li.className = "fileName";
li.id = filter;
li.style.backgroundColor = "var(--primary-hover)"; // set background to orange
li.style.border = "2px solid orange"; // set border to orange
li.innerText = filter;
selectedFiles.appendChild(li);
}
//update indexed files to have checkmark if they are in the filters
var indexedFiles = document.getElementsByClassName("indexed-files")[0];
indexedFiles.innerHTML = "";
for (var filename of allFiles) {
var li = document.createElement("li");
li.className = "fileName";
li.id = filename;
li.innerText = filename;
if (filters.includes(filename)) {
li.style.backgroundColor = "var(--primary-hover)"; // set background to orange
li.style.border = "2px solid orange"; // set border to orange
}
li.setAttribute("onclick", "handleFileClick('" + filename + "')");
indexedFiles.appendChild(li);
}
})
.catch(error => {
// Handle any errors that occur during the fetch operation
console.error("Error loading file filters:", error);
});
}
</script> </script>
<body> <body>
<div id="khoj-empty-container" class="khoj-empty-container"> <div id="khoj-empty-container" class="khoj-empty-container">
@ -1842,6 +2074,64 @@ To get started, just start typing below. You can also type / to see a list of co
<div id="connection-status-icon" style="width: 10px; height: 10px; border-radius: 50%; margin-right: 5px;"></div> <div id="connection-status-icon" style="width: 10px; height: 10px; border-radius: 50%; margin-right: 5px;"></div>
<div id="connection-status-text" style="margin: 5px;"></div> <div id="connection-status-text" style="margin: 5px;"></div>
</div> </div>
<div style="border-top: 1px solid black; ">
<div style="display: flex; align-items: center; justify-content: space-between; margin-bottom: 5px; margin-top: 5px;">
<p style="margin: 0;">Files</p>
<svg class="file-toggle-button" style="width:20px; height:20px; position: relative; top: 2px" viewBox="0 0 40 40" fill="#000000" xmlns="http://www.w3.org/2000/svg">
<path d="M16 0c-8.836 0-16 7.163-16 16s7.163 16 16 16c8.837 0 16-7.163 16-16s-7.163-16-16-16zM16 30.032c-7.72 0-14-6.312-14-14.032s6.28-14 14-14 14 6.28 14 14-6.28 14.032-14 14.032zM23 15h-6v-6c0-0.552-0.448-1-1-1s-1 0.448-1 1v6h-6c-0.552 0-1 0.448-1 1s0.448 1 1 1h6v6c0 0.552 0.448 1 1 1s1-0.448 1-1v-6h6c0.552 0 1-0.448 1-1s-0.448-1-1-1z"></path>
</svg>
</div>
<div class="no-files-message"></div>
<ul class="selected-files" style="margin: 0; padding: 0; margin-bottom: 10px"></ul>
<input class="file-input" style="width:240px; margin-bottom: 5px; color: black; display: none; border-radius: 4px; border: 1px solid black; padding: 4px;" type="text" onkeyup="renderFilteredFiles()" placeholder="Filter">
<ul class="indexed-files" style="margin: 0; padding: 0; height:100px; overflow:hidden; overflow-y:scroll; margin-bottom:5px; display:none;"></ul>
<script>
renderAllFiles();
var fileInputs = document.getElementsByClassName('file-input');
var fileLists = document.getElementsByClassName('indexed-files');
var selectedFileLists = document.getElementsByClassName('selected-files');
var fileToggleButtons = document.getElementsByClassName('file-toggle-button');
var fileInput = fileInputs[0];
var fileList = fileLists[0];
var selectedFileList = selectedFileLists[0];
var fileToggleButton = fileToggleButtons[0];
function toggleFileInput() {
if (fileInput.style.display === 'none' || fileInput.style.display === '') {
fileInput.style.display = 'block';
fileList.style.display = 'block';
selectedFileList.style.display = 'none';
} else {
fileInput.style.display = 'none';
fileList.style.display = 'none';
selectedFileList.style.display = 'block';
}
}
fileToggleButton.addEventListener('click', function(event) {
toggleFileInput();
event.stopPropagation();
});
document.addEventListener('click', function(event) {
if (!fileInput.contains(event.target) && !fileToggleButton.contains(event.target)) {
fileInput.style.display = 'none';
fileList.style.display = 'none';
selectedFileList.style.display = 'block';
}
});
fileInput.addEventListener('click', function(event) {
event.stopPropagation(); // Prevent the document click handler from immediately hiding the input
});
fileList.addEventListener('click', function(event) {
event.stopPropagation(); // Prevent the document click handler from hiding the file list
});
</script>
</div>
<a id="agent-link" class="inline-chat-link" href=""> <a id="agent-link" class="inline-chat-link" href="">
<div id="agent-metadata" style="display: none;"> <div id="agent-metadata" style="display: none;">
Current Agent Current Agent
@ -1962,6 +2252,25 @@ To get started, just start typing below. You can also type / to see a list of co
margin: 10px; margin: 10px;
} }
li.fileName {
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
max-width: 250px;
background-color: var(--primary);
border-radius: 10px;
border: 1px solid var(--primary-hover);
margin-bottom: 4px;
margin-top: 4px;
padding-top: 4px;
padding-bottom: 4px;
}
li.fileName:hover {
background-color: var(--primary-hover);
cursor: pointer;
}
div.expanded.reference-section { div.expanded.reference-section {
display: grid; display: grid;
grid-template-rows: auto; grid-template-rows: auto;
@ -2596,6 +2905,11 @@ To get started, just start typing below. You can also type / to see a list of co
margin-left: 5px; margin-left: 5px;
} }
svg.file-toggle-button:hover {
background: var(--primary-hover);
cursor: pointer;
}
div#new-conversation { div#new-conversation {
display: grid; display: grid;
grid-auto-flow: column; grid-auto-flow: column;

View file

@ -283,6 +283,7 @@ async def extract_references_and_questions(
q: str, q: str,
n: int, n: int,
d: float, d: float,
conversation_id: int,
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
location_data: LocationData = None, location_data: LocationData = None,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
@ -308,8 +309,10 @@ async def extract_references_and_questions(
for filter in [DateFilter(), WordFilter(), FileFilter()]: for filter in [DateFilter(), WordFilter(), FileFilter()]:
defiltered_query = filter.defilter(defiltered_query) defiltered_query = filter.defilter(defiltered_query)
filters_in_query = q.replace(defiltered_query, "").strip() filters_in_query = q.replace(defiltered_query, "").strip()
conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id)
filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters])
using_offline_chat = False using_offline_chat = False
print(f"Filters in query: {filters_in_query}")
# Infer search queries from user message # Infer search queries from user message
with timer("Extracting search queries took", logger): with timer("Extracting search queries took", logger):

View file

@ -57,7 +57,7 @@ from khoj.utils.helpers import (
get_device, get_device,
is_none_or_empty, is_none_or_empty,
) )
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import FilterRequest, LocationData
# Initialize Router # Initialize Router
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -73,6 +73,57 @@ from pydantic import BaseModel
from khoj.routers.email import send_query_feedback from khoj.routers.email import send_query_feedback
@api_chat.get("/conversation/file-filters/{conversation_id}", response_class=Response)
@requires(["authenticated"])
def get_file_filter(request: Request, conversation_id: str) -> Response:
conversation = ConversationAdapters.get_conversation_by_user(
request.user.object, conversation_id=int(conversation_id)
)
# get all files from "computer"
file_list = EntryAdapters.get_all_filenames_by_source(request.user.object, "computer")
file_filters = []
for file in conversation.file_filters:
if file in file_list:
file_filters.append(file)
return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
@api_chat.post("/conversation/file-filters", response_class=Response)
@requires(["authenticated"])
def add_file_filter(request: Request, filter: FilterRequest):
try:
conversation = ConversationAdapters.get_conversation_by_user(
request.user.object, conversation_id=int(filter.conversation_id)
)
file_list = EntryAdapters.get_all_filenames_by_source(request.user.object, "computer")
if filter.filename in file_list and filter.filename not in conversation.file_filters:
conversation.file_filters.append(filter.filename)
conversation.save()
# remove files from conversation.file_filters that are not in file_list
conversation.file_filters = [file for file in conversation.file_filters if file in file_list]
conversation.save()
return Response(content=json.dumps(conversation.file_filters), media_type="application/json", status_code=200)
except Exception as e:
logger.error(f"Error adding file filter {filter.filename}: {e}", exc_info=True)
raise HTTPException(status_code=422, detail=str(e))
@api_chat.delete("/conversation/file-filters", response_class=Response)
@requires(["authenticated"])
def remove_file_filter(request: Request, filter: FilterRequest) -> Response:
conversation = ConversationAdapters.get_conversation_by_user(
request.user.object, conversation_id=int(filter.conversation_id)
)
if filter.filename in conversation.file_filters:
conversation.file_filters.remove(filter.filename)
conversation.save()
# remove files from conversation.file_filters that are not in file_list
file_list = EntryAdapters.get_all_filenames_by_source(request.user.object, "computer")
conversation.file_filters = [file for file in conversation.file_filters if file in file_list]
conversation.save()
return Response(content=json.dumps(conversation.file_filters), media_type="application/json", status_code=200)
class FeedbackData(BaseModel): class FeedbackData(BaseModel):
uquery: str uquery: str
kquery: str kquery: str
@ -586,7 +637,7 @@ async def websocket_endpoint(
continue continue
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
websocket, meta_log, q, 7, 0.18, conversation_commands, location, send_status_update websocket, meta_log, q, 7, 0.18, conversation_id, conversation_commands, location, send_status_update
) )
if compiled_references: if compiled_references:
@ -838,7 +889,7 @@ async def chat(
return Response(content=llm_response, media_type="text/plain", status_code=200) return Response(content=llm_response, media_type="text/plain", status_code=200)
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, meta_log, q, (n or 5), (d or math.inf), conversation_commands, location request, meta_log, q, (n or 5), (d or math.inf), conversation_id, conversation_commands, location
) )
online_results: Dict[str, Dict] = {} online_results: Dict[str, Dict] = {}

View file

@ -27,6 +27,11 @@ class LocationData(BaseModel):
country: Optional[str] country: Optional[str]
class FilterRequest(BaseModel):
filename: str
conversation_id: str
class TextConfigBase(ConfigBase): class TextConfigBase(ConfigBase):
compressed_jsonl: Path compressed_jsonl: Path
embeddings_file: Path embeddings_file: Path