Merge pull request #556 from khoj-ai/features/reflective-suggested-questions

Add support for suggesting base questions to users
This commit is contained in:
sabaimran 2023-11-23 11:57:02 -08:00 committed by GitHub
commit e3b32e412c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 185 additions and 5 deletions

View file

@ -386,6 +386,9 @@
let chatInput = document.getElementById("chat-input");
chatInput.value = chatInput.value.trimStart();
let questionStarterSuggestions = document.getElementById("question-starters");
questionStarterSuggestions.style.display = "none";
if (chatInput.value.startsWith("/") && chatInput.value.split(" ").length === 1) {
let chatTooltip = document.getElementById("chat-tooltip");
chatTooltip.style.display = "block";
@ -468,6 +471,31 @@
return;
});
fetch(`${hostURL}/api/chat/starters?client=desktop`, { headers })
.then(response => response.json())
.then(data => {
// Render chat options, if any
if (data) {
let questionStarterSuggestions = document.getElementById("question-starters");
for (let index in data) {
let questionStarter = data[index];
let questionStarterButton = document.createElement('button');
questionStarterButton.innerHTML = questionStarter;
questionStarterButton.classList.add("question-starter");
questionStarterButton.addEventListener('click', function() {
questionStarterSuggestions.style.display = "none";
document.getElementById("chat-input").value = questionStarter;
chat();
});
questionStarterSuggestions.appendChild(questionStarterButton);
}
questionStarterSuggestions.style.display = "grid";
}
})
.catch(err => {
return;
});
fetch(`${hostURL}/api/chat/options`, { headers })
.then(response => response.json())
.then(data => {
@ -507,6 +535,9 @@
<!-- Chat Body -->
<div id="chat-body"></div>
<!-- Chat Suggestions -->
<div id="question-starters" style="display: none;"></div>
<!-- Chat Footer -->
<div id="chat-footer">
<div id="chat-tooltip" style="display: none;"></div>
@ -684,6 +715,28 @@
margin: 10px;
}
div#question-starters {
grid-template-columns: repeat(auto-fit, minmax(100px, 1fr));
grid-column-gap: 8px;
}
button.question-starter {
background: var(--background-color);
color: var(--main-text-color);
border: 1px solid var(--main-text-color);
border-radius: 5px;
padding: 5px;
font-size: 14px;
font-weight: 300;
line-height: 1.5em;
cursor: pointer;
transition: background 0.2s ease-in-out;
text-align: left;
max-height: 75px;
transition: max-height 0.3s ease-in-out;
overflow: hidden;
}
code.chat-response {
background: var(--primary-hover);
color: var(--primary-inverse);

View file

@ -1,9 +1,9 @@
import math
import random
import secrets
from datetime import date, datetime, timezone
from typing import List, Optional, Type
# Import sync_to_async from Django Channels
from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore
from django.db import models
@ -28,6 +28,9 @@ from khoj.database.models import (
SearchModelConfig,
Subscription,
UserConversationConfig,
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
ReflectiveQuestion,
)
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
@ -337,6 +340,25 @@ class ConversationAdapters:
async def get_openai_chat_config():
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
@staticmethod
async def aget_conversation_starters(user: KhojUser):
all_questions = []
if await ReflectiveQuestion.objects.filter(user=user).aexists():
all_questions = await sync_to_async(ReflectiveQuestion.objects.filter(user=user).values_list)(
"question", flat=True
)
all_questions = await sync_to_async(ReflectiveQuestion.objects.filter(user=None).values_list)(
"question", flat=True
)
max_results = 3
all_questions = await sync_to_async(list)(all_questions)
if len(all_questions) < max_results:
return all_questions
return random.sample(all_questions, max_results)
@staticmethod
def get_valid_conversation_config(user: KhojUser):
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config()

View file

@ -0,0 +1,36 @@
# Generated by Django 4.2.7 on 2023-11-20 01:13
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
("database", "0019_alter_googleuser_family_name_and_more"),
]
operations = [
migrations.CreateModel(
name="ReflectiveQuestion",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("question", models.CharField(max_length=500)),
(
"user",
models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"abstract": False,
},
),
]

View file

@ -141,6 +141,11 @@ class Conversation(BaseModel):
conversation_log = models.JSONField(default=dict)
class ReflectiveQuestion(BaseModel):
question = models.CharField(max_length=500)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
class Entry(BaseModel):
class EntryType(models.TextChoices):
IMAGE = "image"

View file

@ -153,8 +153,8 @@ To get started, just start typing below. You can also type / to see a list of co
numOnlineReferences += onlineReference.organic.length;
for (let index in onlineReference.organic) {
let reference = onlineReference.organic[index];
let polishedReference = generateOnlineReference(reference, index);
referenceSection.appendChild(polishedReference);
let polishedReference = generateOnlineReference(reference, index);
referenceSection.appendChild(polishedReference);
}
}
@ -162,8 +162,8 @@ To get started, just start typing below. You can also type / to see a list of co
numOnlineReferences += onlineReference.knowledgeGraph.length;
for (let index in onlineReference.knowledgeGraph) {
let reference = onlineReference.knowledgeGraph[index];
let polishedReference = generateOnlineReference(reference, index);
referenceSection.appendChild(polishedReference);
let polishedReference = generateOnlineReference(reference, index);
referenceSection.appendChild(polishedReference);
}
}
@ -426,6 +426,9 @@ To get started, just start typing below. You can also type / to see a list of co
let chatInput = document.getElementById("chat-input");
chatInput.value = chatInput.value.trimStart();
let questionStarterSuggestions = document.getElementById("question-starters");
questionStarterSuggestions.style.display = "none";
if (chatInput.value.startsWith("/") && chatInput.value.split(" ").length === 1) {
let chatTooltip = document.getElementById("chat-tooltip");
chatTooltip.style.display = "block";
@ -505,6 +508,31 @@ To get started, just start typing below. You can also type / to see a list of co
return;
});
fetch('/api/chat/starters')
.then(response => response.json())
.then(data => {
// Render chat options, if any
if (data) {
let questionStarterSuggestions = document.getElementById("question-starters");
for (let index in data) {
let questionStarter = data[index];
let questionStarterButton = document.createElement('button');
questionStarterButton.innerHTML = questionStarter;
questionStarterButton.classList.add("question-starter");
questionStarterButton.addEventListener('click', function() {
questionStarterSuggestions.style.display = "none";
document.getElementById("chat-input").value = questionStarter;
chat();
});
questionStarterSuggestions.appendChild(questionStarterButton);
}
questionStarterSuggestions.style.display = "grid";
}
})
.catch(err => {
return;
});
// Fill query field with value passed in URL query parameters, if any.
var query_via_url = new URLSearchParams(window.location.search).get("q");
if (query_via_url) {
@ -524,6 +552,9 @@ To get started, just start typing below. You can also type / to see a list of co
<!-- Chat Body -->
<div id="chat-body"></div>
<!-- Chat Suggestions -->
<div id="question-starters" style="display: none;"></div>
<!-- Chat Footer -->
<div id="chat-footer">
<div id="chat-tooltip" style="display: none;"></div>
@ -584,6 +615,28 @@ To get started, just start typing below. You can also type / to see a list of co
margin: 10px;
}
div#question-starters {
grid-template-columns: repeat(auto-fit, minmax(100px, 1fr));
grid-column-gap: 8px;
}
button.question-starter {
background: var(--background-color);
color: var(--main-text-color);
border: 1px solid var(--main-text-color);
border-radius: 5px;
padding: 5px;
font-size: 14px;
font-weight: 300;
line-height: 1.5em;
cursor: pointer;
transition: background 0.2s ease-in-out;
text-align: left;
max-height: 75px;
transition: max-height 0.3s ease-in-out;
overflow: hidden;
}
button.reference-button {
background: var(--background-color);
color: var(--main-text-color);

View file

@ -512,6 +512,17 @@ def update(
return {"status": "ok", "message": "khoj reloaded"}
@api.get("/chat/starters", response_class=Response)
@requires(["authenticated"])
async def chat_starters(
request: Request,
common: CommonQueryParams,
) -> Response:
user: KhojUser = request.user.object
starter_questions = await ConversationAdapters.aget_conversation_starters(user)
return Response(content=json.dumps(starter_questions), media_type="application/json", status_code=200)
@api.get("/chat/history")
@requires(["authenticated"])
def chat_history(