Merge pull request #552 from khoj-ai/features/internet-enabled-search

Support internet-enabled, online searching using Serper.dev
This commit is contained in:
sabaimran 2023-11-23 12:34:05 -08:00 committed by GitHub
commit c42ec32a95
42 changed files with 998 additions and 240 deletions
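For orientation, the core of this change is a new online search module that queries Serper.dev and feeds the trimmed results into the chat response. A minimal sketch of the underlying request, mirroring the online_search.py file added later in this diff (the query string and printed fields are illustrative):

import json
import os

import requests

# Assumes the SERPER_DEV_API_KEY environment variable is set, as the new module requires.
api_key = os.getenv("SERPER_DEV_API_KEY")
payload = json.dumps({"q": "next lunar eclipse"})
headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}
response = requests.post("https://google.serper.dev/search", headers=headers, data=payload)
results = response.json()
# The PR keeps only these sections of the response for the chat model and the clients.
trimmed = {key: results.get(key) for key in ("knowledgeGraph", "organic", "answerBox", "peopleAlsoAsk")}
print(trimmed["organic"][:1])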

View file

@ -60,6 +60,52 @@
return referenceButton;
}
function generateOnlineReference(reference, index) {
// Generate HTML for Chat Reference
let title = reference.title;
let link = reference.link;
let snippet = reference.snippet;
let question = reference.question;
if (question) {
question = `<b>Question:</b> ${question}<br><br>`;
} else {
question = "";
}
let linkElement = document.createElement('a');
linkElement.setAttribute('href', link);
linkElement.setAttribute('target', '_blank');
linkElement.setAttribute('rel', 'noopener noreferrer');
linkElement.classList.add("inline-chat-link");
linkElement.classList.add("reference-link");
linkElement.setAttribute('title', title);
linkElement.innerHTML = title;
let referenceButton = document.createElement('button');
referenceButton.innerHTML = linkElement.outerHTML;
referenceButton.id = `ref-${index}`;
referenceButton.classList.add("reference-button");
referenceButton.classList.add("collapsed");
referenceButton.tabIndex = 0;
// Add event listener to toggle full reference on click
referenceButton.addEventListener('click', function() {
console.log(`Toggling ref-${index}`)
if (this.classList.contains("collapsed")) {
this.classList.remove("collapsed");
this.classList.add("expanded");
this.innerHTML = linkElement.outerHTML + `<br><br>${question + snippet}`;
} else {
this.classList.add("collapsed");
this.classList.remove("expanded");
this.innerHTML = linkElement.outerHTML;
}
});
return referenceButton;
}
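The reference objects passed to generateOnlineReference above, and the onlineContext map walked by processOnlineReferences further down, mirror the trimmed Serper response assembled on the server, keyed by generated subquery. A hypothetical sketch of that shape, with illustrative values:

online_context = {
    "weather in new york": {
        "organic": [
            {"title": "New York Weather", "link": "https://example.com/nyc", "snippet": "Cloudy, 48F"}
        ],
        "knowledgeGraph": {},
        "peopleAlsoAsk": [
            {"question": "Will it rain in NYC today?", "title": "NYC forecast", "link": "https://example.com/rain", "snippet": "..."}
        ],
    },
}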
function renderMessage(message, by, dt=null, annotations=null) {
let message_time = formatDate(dt ?? new Date());
let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You";
@ -90,8 +136,48 @@
chatBody.scrollTop = chatBody.scrollHeight;
}
function processOnlineReferences(referenceSection, onlineContext) {
let numOnlineReferences = 0;
for (let subquery in onlineContext) {
let onlineReference = onlineContext[subquery];
if (onlineReference.organic && onlineReference.organic.length > 0) {
numOnlineReferences += onlineReference.organic.length;
for (let index in onlineReference.organic) {
let reference = onlineReference.organic[index];
let polishedReference = generateOnlineReference(reference, index);
referenceSection.appendChild(polishedReference);
}
}
if (onlineReference.knowledgeGraph && onlineReference.knowledgeGraph.length > 0) {
numOnlineReferences += onlineReference.knowledgeGraph.length;
for (let index in onlineReference.knowledgeGraph) {
let reference = onlineReference.knowledgeGraph[index];
let polishedReference = generateOnlineReference(reference, index);
referenceSection.appendChild(polishedReference);
}
}
if (onlineReference.peopleAlsoAsk && onlineReference.peopleAlsoAsk.length > 0) {
numOnlineReferences += onlineReference.peopleAlsoAsk.length;
for (let index in onlineReference.peopleAlsoAsk) {
let reference = onlineReference.peopleAlsoAsk[index];
let polishedReference = generateOnlineReference(reference, index);
referenceSection.appendChild(polishedReference);
}
}
}
return numOnlineReferences;
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
if (context == null && onlineContext == null) {
renderMessage(message, by, dt);
return;
}
if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
renderMessage(message, by, dt);
return;
}
@ -100,8 +186,11 @@
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
let numReferences = 0;
if (context) {
numReferences += context.length;
}
references.appendChild(referenceExpandButton);
@ -127,6 +216,14 @@
referenceSection.appendChild(polishedReference);
}
}
if (onlineContext) {
numReferences += processOnlineReferences(referenceSection, onlineContext);
}
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
renderMessage(message, by, dt, references);
@ -140,6 +237,8 @@
newHTML = newHTML.replace(/__([\s\S]*?)__/g, '<u>$1</u>');
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
// For any text that has single backticks, replace them with <code> tags
newHTML = newHTML.replace(/`([^`]+)`/g, '<code class="chat-response">$1</code>');
return newHTML;
}
@ -221,15 +320,28 @@
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
let referenceSection = document.createElement('div');
referenceSection.classList.add("reference-section");
referenceSection.classList.add("collapsed");
let numReferences = 0;
// If rawReferenceAsJson is a list, then count the length
if (Array.isArray(rawReferenceAsJson)) {
numReferences = rawReferenceAsJson.length;
rawReferenceAsJson.forEach((reference, index) => {
let polishedReference = generateReference(reference, index);
referenceSection.appendChild(polishedReference);
});
} else {
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
}
references.appendChild(referenceExpandButton);
referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) {
referenceSection.classList.remove("collapsed");
@ -240,10 +352,8 @@
}
});
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
readStream();
} else {
@ -276,6 +386,9 @@
let chatInput = document.getElementById("chat-input");
chatInput.value = chatInput.value.trimStart();
let questionStarterSuggestions = document.getElementById("question-starters");
questionStarterSuggestions.style.display = "none";
if (chatInput.value.startsWith("/") && chatInput.value.split(" ").length === 1) {
let chatTooltip = document.getElementById("chat-tooltip");
chatTooltip.style.display = "block";
@ -324,7 +437,7 @@
const khojToken = await window.tokenAPI.getToken();
const headers = { 'Authorization': `Bearer ${khojToken}` };
fetch(`${hostURL}/api/chat/history?client=desktop`, { headers })
.then(response => response.json())
.then(data => {
if (data.detail) {
@ -351,13 +464,38 @@
.then(response => {
// Render conversation history, if any
response.forEach(chat_log => {
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
});
})
.catch(err => {
return;
});
fetch(`${hostURL}/api/chat/starters?client=desktop`, { headers })
.then(response => response.json())
.then(data => {
// Render chat options, if any
if (data) {
let questionStarterSuggestions = document.getElementById("question-starters");
for (let index in data) {
let questionStarter = data[index];
let questionStarterButton = document.createElement('button');
questionStarterButton.innerHTML = questionStarter;
questionStarterButton.classList.add("question-starter");
questionStarterButton.addEventListener('click', function() {
questionStarterSuggestions.style.display = "none";
document.getElementById("chat-input").value = questionStarter;
chat();
});
questionStarterSuggestions.appendChild(questionStarterButton);
}
questionStarterSuggestions.style.display = "grid";
}
})
.catch(err => {
return;
});
fetch(`${hostURL}/api/chat/options`, { headers })
.then(response => response.json())
.then(data => {
@ -397,6 +535,9 @@
<!-- Chat Body -->
<div id="chat-body"></div>
<!-- Chat Suggestions -->
<div id="question-starters" style="display: none;"></div>
<!-- Chat Footer -->
<div id="chat-footer">
<div id="chat-tooltip" style="display: none;"></div>
@ -574,6 +715,38 @@
margin: 10px;
}
div#question-starters {
grid-template-columns: repeat(auto-fit, minmax(100px, 1fr));
grid-column-gap: 8px;
}
button.question-starter {
background: var(--background-color);
color: var(--main-text-color);
border: 1px solid var(--main-text-color);
border-radius: 5px;
padding: 5px;
font-size: 14px;
font-weight: 300;
line-height: 1.5em;
cursor: pointer;
transition: background 0.2s ease-in-out;
text-align: left;
max-height: 75px;
transition: max-height 0.3s ease-in-out;
overflow: hidden;
}
code.chat-response {
background: var(--primary-hover);
color: var(--primary-inverse);
border-radius: 5px;
padding: 5px;
font-size: 14px;
font-weight: 300;
line-height: 1.5em;
}
button.reference-button {
background: var(--background-color);
color: var(--main-text-color);

View file

@ -1,42 +1,43 @@
import math
import random
import secrets
from datetime import date, datetime, timezone
from typing import List, Optional, Type
from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore
from django.db import models
from django.db.models import Q
from django.db.models.manager import BaseManager
from fastapi import HTTPException
from pgvector.django import CosineDistance
from torch import Tensor
from khoj.database.models import (
ChatModelOptions,
Conversation,
Entry,
GithubConfig,
GithubRepoConfig,
GoogleUser,
KhojApiUser,
KhojUser,
NotionConfig,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig,
SearchModelConfig,
Subscription,
UserConversationConfig,
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
ReflectiveQuestion,
)
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import generate_random_name
async def set_notion_config(token: str, user: KhojUser):
@ -339,6 +340,45 @@ class ConversationAdapters:
async def get_openai_chat_config():
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
@staticmethod
async def aget_conversation_starters(user: KhojUser):
all_questions = []
if await ReflectiveQuestion.objects.filter(user=user).aexists():
all_questions = await sync_to_async(ReflectiveQuestion.objects.filter(user=user).values_list)(
"question", flat=True
)
all_questions = await sync_to_async(ReflectiveQuestion.objects.filter(user=None).values_list)(
"question", flat=True
)
max_results = 3
all_questions = await sync_to_async(list)(all_questions)
if len(all_questions) < max_results:
return all_questions
return random.sample(all_questions, max_results)
@staticmethod
def get_valid_conversation_config(user: KhojUser):
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config()
conversation_config = ConversationAdapters.get_conversation_config(user)
if conversation_config is None:
conversation_config = ConversationAdapters.get_default_conversation_config()
if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
return conversation_config
openai_chat_config = ConversationAdapters.get_openai_conversation_config()
if openai_chat_config and conversation_config.model_type == "openai":
return conversation_config
else:
raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
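A hypothetical usage sketch of the new conversation-starter adapter above, assuming Django is already configured (for example inside an async view or test) and using Django 4.2's async ORM helpers:

from khoj.database.adapters import ConversationAdapters
from khoj.database.models import KhojUser, ReflectiveQuestion

async def show_starters(user: KhojUser):
    # Seed a server-wide question, then fetch up to 3 randomly sampled starters.
    await ReflectiveQuestion.objects.acreate(question="What did I learn today?", user=None)
    starters = await ConversationAdapters.aget_conversation_starters(user)
    print(starters)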
class EntryAdapters:
word_filer = WordFilter()

View file

@ -0,0 +1,36 @@
# Generated by Django 4.2.7 on 2023-11-20 01:13
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
("database", "0019_alter_googleuser_family_name_and_more"),
]
operations = [
migrations.CreateModel(
name="ReflectiveQuestion",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("question", models.CharField(max_length=500)),
(
"user",
models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"abstract": False,
},
),
]

View file

@ -141,6 +141,11 @@ class Conversation(BaseModel):
conversation_log = models.JSONField(default=dict)
class ReflectiveQuestion(BaseModel):
question = models.CharField(max_length=500)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
class Entry(BaseModel):
class EntryType(models.TextChoices):
IMAGE = "image"

View file

@ -9,14 +9,15 @@
<link rel="stylesheet" href="/static/assets/khoj.css">
</head>
<script type="text/javascript" src="/static/assets/utils.js"></script>
<script type="text/javascript" src="/static/assets/markdown-it.min.js"></script>
<script>
let welcome_message = `
Hi, I am Khoj, your open, personal AI 👋🏽. I can help:
- 🧠 Answer general knowledge questions
- 💡 Be a sounding board for your ideas
- 📜 Chat with your notes & documents
Download the [🖥️ Desktop app](https://khoj.dev/downloads) to chat with your computer docs.
To get started, just start typing below. You can also type / to see a list of commands.
`.trim()
@ -69,6 +70,52 @@ To get started, just start typing below. You can also type / to see a list of co
return referenceButton;
}
function generateOnlineReference(reference, index) {
// Generate HTML for Chat Reference
let title = reference.title;
let link = reference.link;
let snippet = reference.snippet;
let question = reference.question;
if (question) {
question = `<b>Question:</b> ${question}<br><br>`;
} else {
question = "";
}
let linkElement = document.createElement('a');
linkElement.setAttribute('href', link);
linkElement.setAttribute('target', '_blank');
linkElement.setAttribute('rel', 'noopener noreferrer');
linkElement.classList.add("inline-chat-link");
linkElement.classList.add("reference-link");
linkElement.setAttribute('title', title);
linkElement.innerHTML = title;
let referenceButton = document.createElement('button');
referenceButton.innerHTML = linkElement.outerHTML;
referenceButton.id = `ref-${index}`;
referenceButton.classList.add("reference-button");
referenceButton.classList.add("collapsed");
referenceButton.tabIndex = 0;
// Add event listener to toggle full reference on click
referenceButton.addEventListener('click', function() {
console.log(`Toggling ref-${index}`)
if (this.classList.contains("collapsed")) {
this.classList.remove("collapsed");
this.classList.add("expanded");
this.innerHTML = linkElement.outerHTML + `<br><br>${question + snippet}`;
} else {
this.classList.add("collapsed");
this.classList.remove("expanded");
this.innerHTML = linkElement.outerHTML;
}
});
return referenceButton;
}
function renderMessage(message, by, dt=null, annotations=null) {
let message_time = formatDate(dt ?? new Date());
let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You";
@ -83,8 +130,7 @@ To get started, just start typing below. You can also type / to see a list of co
// Create a new div for the chat message text and append it to the chat message
let chatMessageText = document.createElement('div');
chatMessageText.className = `chat-message-text ${by}`;
chatMessageText.appendChild(formattedMessage);
chatMessage.appendChild(chatMessageText);
// Append annotations div to the chat message
@ -99,8 +145,48 @@ To get started, just start typing below. You can also type / to see a list of co
chatBody.scrollTop = chatBody.scrollHeight;
}
function processOnlineReferences(referenceSection, onlineContext) {
let numOnlineReferences = 0;
for (let subquery in onlineContext) {
let onlineReference = onlineContext[subquery];
if (onlineReference.organic && onlineReference.organic.length > 0) {
numOnlineReferences += onlineReference.organic.length;
for (let index in onlineReference.organic) {
let reference = onlineReference.organic[index];
let polishedReference = generateOnlineReference(reference, index);
referenceSection.appendChild(polishedReference);
}
}
if (onlineReference.knowledgeGraph && onlineReference.knowledgeGraph.length > 0) {
numOnlineReferences += onlineReference.knowledgeGraph.length;
for (let index in onlineReference.knowledgeGraph) {
let reference = onlineReference.knowledgeGraph[index];
let polishedReference = generateOnlineReference(reference, index);
referenceSection.appendChild(polishedReference);
}
}
if (onlineReference.peopleAlsoAsk && onlineReference.peopleAlsoAsk.length > 0) {
numOnlineReferences += onlineReference.peopleAlsoAsk.length;
for (let index in onlineReference.peopleAlsoAsk) {
let reference = onlineReference.peopleAlsoAsk[index];
let polishedReference = generateOnlineReference(reference, index);
referenceSection.appendChild(polishedReference);
}
}
}
return numOnlineReferences;
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
if (context == null && onlineContext == null) {
renderMessage(message, by, dt);
return;
}
if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
renderMessage(message, by, dt);
return;
}
@ -109,8 +195,11 @@ To get started, just start typing below. You can also type / to see a list of co
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
let numReferences = 0;
if (context) {
numReferences += context.length;
}
references.appendChild(referenceExpandButton);
@ -136,20 +225,63 @@ To get started, just start typing below. You can also type / to see a list of co
referenceSection.appendChild(polishedReference);
}
}
if (onlineContext) {
numReferences += processOnlineReferences(referenceSection, onlineContext);
}
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
renderMessage(message, by, dt, references);
}
function formatHTMLMessage(htmlMessage) {
var md = window.markdownit();
let newHTML = htmlMessage;
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
// Render markdown
newHTML = md.render(newHTML);
// Get any elements with a class that starts with "language"
let element = document.createElement('div');
element.innerHTML = newHTML;
let codeBlockElements = element.querySelectorAll('[class^="language-"]');
// For each element, add a parent div with the class "programmatic-output"
codeBlockElements.forEach((codeElement) => {
// Create the parent div
let parentDiv = document.createElement('div');
parentDiv.classList.add("programmatic-output");
// Add the parent div before the code element
codeElement.parentNode.insertBefore(parentDiv, codeElement);
// Move the code element into the parent div
parentDiv.appendChild(codeElement);
// Add a copy button to each element
let copyButton = document.createElement('button');
copyButton.classList.add("copy-button");
copyButton.innerHTML = "Copy";
copyButton.addEventListener('click', copyProgrammaticOutput);
codeElement.prepend(copyButton);
});
// Get all code elements that have no class.
let codeElements = element.querySelectorAll('code:not([class])');
codeElements.forEach((codeElement) => {
// Add the class "chat-response" to each element
codeElement.classList.add("chat-response");
});
let anchorElements = element.querySelectorAll('a');
anchorElements.forEach((anchorElement) => {
// Add the class "inline-chat-link" to each element
anchorElement.classList.add("inline-chat-link");
});
return element
}
function chat() {
@ -205,7 +337,8 @@ To get started, just start typing below. You can also type / to see a list of co
if (done) {
// Evaluate the contents of new_response_text.innerHTML after all the data has been streamed
const currentHTML = newResponseText.innerHTML;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(currentHTML));
if (references != null) {
newResponseText.appendChild(references);
}
@ -226,18 +359,30 @@ To get started, just start typing below. You can also type / to see a list of co
references = document.createElement('div');
references.classList.add("references");
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
let referenceSection = document.createElement('div');
referenceSection.classList.add("reference-section");
referenceSection.classList.add("collapsed");
let numReferences = 0;
// If rawReferenceAsJson is a list, then count the length
if (Array.isArray(rawReferenceAsJson)) {
numReferences = rawReferenceAsJson.length;
rawReferenceAsJson.forEach((reference, index) => {
let polishedReference = generateReference(reference, index);
referenceSection.appendChild(polishedReference);
});
} else {
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
}
references.appendChild(referenceExpandButton);
referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) {
referenceSection.classList.remove("collapsed");
@ -248,10 +393,8 @@ To get started, just start typing below. You can also type / to see a list of co
}
});
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
readStream();
} else {
@ -283,6 +426,9 @@ To get started, just start typing below. You can also type / to see a list of co
let chatInput = document.getElementById("chat-input");
chatInput.value = chatInput.value.trimStart();
let questionStarterSuggestions = document.getElementById("question-starters");
questionStarterSuggestions.style.display = "none";
if (chatInput.value.startsWith("/") && chatInput.value.split(" ").length === 1) {
let chatTooltip = document.getElementById("chat-tooltip");
chatTooltip.style.display = "block";
@ -342,7 +488,7 @@ To get started, just start typing below. You can also type / to see a list of co
.then(response => {
// Render conversation history, if any
response.forEach(chat_log => {
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
});
})
.catch(err => {
@ -362,6 +508,31 @@ To get started, just start typing below. You can also type / to see a list of co
return;
});
fetch('/api/chat/starters')
.then(response => response.json())
.then(data => {
// Render chat options, if any
if (data) {
let questionStarterSuggestions = document.getElementById("question-starters");
for (let index in data) {
let questionStarter = data[index];
let questionStarterButton = document.createElement('button');
questionStarterButton.innerHTML = questionStarter;
questionStarterButton.classList.add("question-starter");
questionStarterButton.addEventListener('click', function() {
questionStarterSuggestions.style.display = "none";
document.getElementById("chat-input").value = questionStarter;
chat();
});
questionStarterSuggestions.appendChild(questionStarterButton);
}
questionStarterSuggestions.style.display = "grid";
}
})
.catch(err => {
return;
});
// Fill query field with value passed in URL query parameters, if any.
var query_via_url = new URLSearchParams(window.location.search).get("q");
if (query_via_url) {
@ -381,6 +552,9 @@ To get started, just start typing below. You can also type / to see a list of co
<!-- Chat Body -->
<div id="chat-body"></div>
<!-- Chat Suggestions -->
<div id="question-starters" style="display: none;"></div>
<!-- Chat Footer -->
<div id="chat-footer">
<div id="chat-tooltip" style="display: none;"></div>
@ -441,6 +615,28 @@ To get started, just start typing below. You can also type / to see a list of co
margin: 10px;
}
div#question-starters {
grid-template-columns: repeat(auto-fit, minmax(100px, 1fr));
grid-column-gap: 8px;
}
button.question-starter {
background: var(--background-color);
color: var(--main-text-color);
border: 1px solid var(--main-text-color);
border-radius: 5px;
padding: 5px;
font-size: 14px;
font-weight: 300;
line-height: 1.5em;
cursor: pointer;
transition: background 0.2s ease-in-out;
text-align: left;
max-height: 75px;
transition: max-height 0.3s ease-in-out;
overflow: hidden;
}
button.reference-button {
background: var(--background-color);
color: var(--main-text-color);
@ -491,6 +687,16 @@ To get started, just start typing below. You can also type / to see a list of co
background: var(--primary-hover);
}
code.chat-response {
background: var(--primary-hover);
color: var(--primary-inverse);
border-radius: 5px;
padding: 5px;
font-size: 14px;
font-weight: 300;
line-height: 1.5em;
}
#chat-body {
font-size: medium;
margin: 0px;
@ -626,6 +832,22 @@ To get started, just start typing below. You can also type / to see a list of co
border-bottom: 1px dotted var(--main-text-color);
}
a.reference-link {
color: var(--main-text-color);
border-bottom: 1px dotted var(--main-text-color);
}
button.copy-button {
display: block;
border-radius: 4px;
background-color: var(--background-color);
}
button.copy-button:hover {
background: #f5f5f5;
cursor: pointer;
}
@media (pointer: coarse), (hover: none) {
abbr[title] {
position: relative;
@ -699,6 +921,10 @@ To get started, just start typing below. You can also type / to see a list of co
padding: 0px;
}
p {
margin: 0;
}
div.programmatic-output {
background-color: #f5f5f5;
border: 1px solid #ddd;

View file

@ -2,19 +2,20 @@
import logging
import time
from datetime import datetime
from typing import Dict, List, Tuple, Union
# External Packages
import requests
from khoj.database.models import Entry as DbEntry
from khoj.database.models import GithubConfig, KhojUser
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.text_to_entries import TextToEntries
# Internal Packages
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
logger = logging.getLogger(__name__)

View file

@ -1,17 +1,19 @@
# Standard Packages
import logging
import re
from pathlib import Path
from typing import List, Tuple
import urllib3
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
# Internal Packages
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry
logger = logging.getLogger(__name__)

View file

@ -1,19 +1,18 @@
# Standard Packages
import logging
from enum import Enum
from typing import Tuple
# External Packages
import requests
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser, NotionConfig
from khoj.processor.content.text_to_entries import TextToEntries
# Internal Packages
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, NotionContentConfig
logger = logging.getLogger(__name__)

View file

@ -3,14 +3,15 @@ import logging
from pathlib import Path
from typing import Iterable, List, Tuple
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
# Internal Packages
from khoj.processor.content.org_mode import orgnode
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils import state
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry
logger = logging.getLogger(__name__)

View file

@ -1,18 +1,19 @@
# Standard Packages
import base64
import logging
import os
from typing import List, Tuple
# External Packages
from langchain.document_loaders import PyMuPDFLoader
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
# Internal Packages
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry
logger = logging.getLogger(__name__)

View file

@ -2,15 +2,16 @@
import logging
from pathlib import Path
from typing import List, Tuple
from bs4 import BeautifulSoup
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
# Internal Packages
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry
logger = logging.getLogger(__name__)

View file

@ -121,6 +121,7 @@ def filter_questions(questions: List[str]):
def converse_offline(
references,
online_results,
user_query,
conversation_log={},
model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
@ -147,6 +148,13 @@ def converse_offline(
# Get Conversation Primer appropriate to Conversation Type
if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references_message):
return iter([prompts.no_notes_found.format()])
elif conversation_command == ConversationCommand.Online and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
elif conversation_command == ConversationCommand.Online:
conversation_primer = prompts.online_search_conversation.format(
query=user_query, online_results=str(online_results)
)
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message):
conversation_primer = user_query
else:
@ -164,20 +172,13 @@ def converse_offline(
tokenizer_name=tokenizer_name,
)
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, gpt4all_model))
t.start()
return g
def llm_thread(g, messages: List[ChatMessage], model: Any):
user_message = messages[-1]
system_message = messages[0]
conversation_history = messages[1:-1]
@ -196,7 +197,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
prompted_message = templated_system_message + chat_history + templated_user_message
state.chat_lock.acquire()
response_iterator = send_message_to_model_offline(prompted_message, loaded_model=model, streaming=True)
try:
for response in response_iterator:
if any(stop_word in response.strip() for stop_word in stop_words):
@ -206,3 +207,18 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
finally:
state.chat_lock.release()
g.close()
def send_message_to_model_offline(
message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False
):
try:
from gpt4all import GPT4All
except ModuleNotFoundError as e:
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
raise e
assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
gpt4all_model = loaded_model or GPT4All(model)
return gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=512, streaming=streaming)
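A hypothetical usage sketch of the new offline helper above, assuming the gpt4all package is installed and the default Mistral model can be loaded (it is downloaded on first use):

from khoj.processor.conversation.gpt4all.chat_model import send_message_to_model_offline

# Non-streaming call returns the generated text; pass streaming=True to iterate over tokens instead.
reply = send_message_to_model_offline("Say hello in one short sentence.")
print(reply)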

View file

@ -100,8 +100,30 @@ def extract_questions(
return questions
def send_message_to_model(
message,
api_key,
model,
):
"""
Send message to model
"""
messages = [ChatMessage(content=message, role="assistant")]
# Get Response from GPT
return completion_with_backoff(
messages=messages,
model_name=model,
temperature=0,
max_tokens=100,
model_kwargs={"stop": ["A: ", "\n"]},
openai_api_key=api_key,
)
def converse(
references,
online_results,
user_query,
conversation_log={},
model: str = "gpt-3.5-turbo",
@ -123,6 +145,13 @@ def converse(
if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
elif conversation_command == ConversationCommand.Online and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
elif conversation_command == ConversationCommand.Online:
conversation_primer = prompts.online_search_conversation.format(
query=user_query, online_results=str(online_results)
)
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references):
conversation_primer = prompts.general_conversation.format(query=user_query)
else:
@ -144,6 +173,7 @@ def converse(
return chat_completion_with_backoff(
messages=messages,
compiled_references=references,
online_results=online_results,
model_name=model,
temperature=temperature,
openai_api_key=api_key,

View file

@ -69,9 +69,16 @@ def completion_with_backoff(**kwargs):
reraise=True,
)
def chat_completion_with_backoff(
messages,
compiled_references,
online_results,
model_name,
temperature,
openai_api_key=None,
completion_func=None,
model_kwargs=None,
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs))
t.start()
return g

View file

@ -10,7 +10,7 @@ You are Khoj, a smart, inquisitive and helpful personal assistant.
Use your general knowledge and the past conversation with the user as context to inform your responses.
You were created by Khoj Inc. with the following capabilities:
- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you. They can share files with you using any Khoj client, including the native Desktop app, the Obsidian or Emacs plugins, or the web app.
- You cannot set reminders.
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
@ -35,6 +35,12 @@ no_notes_found = PromptTemplate.from_template(
""".strip()
)
no_online_results_found = PromptTemplate.from_template(
"""
I'm sorry, I couldn't find any relevant information from the internet to respond to your message.
""".strip()
)
no_entries_found = PromptTemplate.from_template(
"""
It looks like you haven't added any notes yet. No worries, you can fix that by downloading the Khoj app from <a href=https://khoj.dev/downloads>here</a>.
@ -103,6 +109,45 @@ Question: {query}
""".strip()
)
## Online Search Conversation
## --
online_search_conversation = PromptTemplate.from_template(
"""
Use this up-to-date information from the internet to inform your response.
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the online data or past conversations.
Information from the internet: {online_results}
Query: {query}""".strip()
)
online_search_conversation_subqueries = PromptTemplate.from_template(
"""
The user has a question which you can use the internet to respond to. Can you break down the question into subqueries to get the correct answer? Provide search queries as a JSON list of strings
Today's date in UTC: {current_date}
Here are some examples of questions and subqueries:
Q: Posts about vector databases on Hacker News
A: ["site:news.ycombinator.com vector database"]
Q: What is the weather like in New York and San Francisco?
A: ["weather in new york", "weather in san francisco"]
Q: What is the latest news about Google stock?
A: ["google stock news"]
Q: When is the next lunar eclipse?
A: ["next lunar eclipse"]
Q: How many oranges would fit in NASA's Saturn V rocket?
A: ["volume of an orange", "volume of saturn v rocket"]
This is the user's query:
Q: {query}
A: """.strip()
)
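The generate_online_subqueries helper referenced by the new online search module lives in khoj.routers.helpers and is not shown in this diff; purely as an illustration, the prompt above could be filled in and its JSON reply parsed roughly like this (model_reply is a made-up example):

import json
from datetime import datetime

from khoj.processor.conversation import prompts

filled = prompts.online_search_conversation_subqueries.format(
    current_date=datetime.utcnow().strftime("%Y-%m-%d"),
    query="What is the weather like in New York and San Francisco?",
)
# The chat model is expected to reply with a JSON list of strings.
model_reply = '["weather in new york", "weather in san francisco"]'
subqueries = json.loads(model_reply)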
## Summarize Notes
## --

View file

@ -29,9 +29,10 @@ model_to_tokenizer = {
class ThreadedGenerator:
def __init__(self, compiled_references, online_results, completion_func=None):
self.queue = queue.Queue()
self.compiled_references = compiled_references
self.online_results = online_results
self.completion_func = completion_func
self.response = ""
self.start_time = perf_counter()
@ -62,6 +63,8 @@ class ThreadedGenerator:
def close(self):
if self.compiled_references and len(self.compiled_references) > 0:
self.queue.put(f"### compiled references:{json.dumps(self.compiled_references)}")
elif self.online_results and len(self.online_results) > 0:
self.queue.put(f"### compiled references:{json.dumps(self.online_results)}")
self.queue.put(StopIteration)

View file

View file

@ -0,0 +1,52 @@
import requests
import json
import os
import logging
from khoj.routers.helpers import generate_online_subqueries
logger = logging.getLogger(__name__)
SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
url = "https://google.serper.dev/search"
async def search_with_google(query: str):
def _search_with_google(subquery: str):
payload = json.dumps(
{
"q": subquery,
}
)
headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code != 200:
logger.error(response.text)
return {}
json_response = response.json()
sub_response_dict = {}
sub_response_dict["knowledgeGraph"] = json_response.get("knowledgeGraph", {})
sub_response_dict["organic"] = json_response.get("organic", [])
sub_response_dict["answerBox"] = json_response.get("answerBox", [])
sub_response_dict["peopleAlsoAsk"] = json_response.get("peopleAlsoAsk", [])
return sub_response_dict
if SERPER_DEV_API_KEY is None:
raise ValueError("SERPER_DEV_API_KEY is not set")
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(query)
response_dict = {}
for subquery in subqueries:
logger.info(f"Searching with Google for '{subquery}'")
response_dict[subquery] = _search_with_google(subquery)
return response_dict
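A hypothetical usage sketch of search_with_google above, assuming SERPER_DEV_API_KEY is set and a chat model is configured so subquery generation can run:

import asyncio

from khoj.processor.tools.online_search import search_with_google

results = asyncio.run(search_with_google("when is the next lunar eclipse"))
for subquery, result in results.items():
    links = [item.get("link") for item in result.get("organic", [])]
    print(subquery, links[:3])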

View file

@ -1,63 +1,63 @@
# Standard Packages
import concurrent.futures
import json
import logging
import math
import time
from typing import Any, Dict, List, Optional, Union
from asgiref.sync import sync_to_async
# External Packages
from fastapi import APIRouter, Depends, Header, HTTPException, Request
from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires
# Internal Packages
from khoj.configure import configure_server
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import ChatModelOptions
from khoj.database.models import Entry as DbEntry
from khoj.database.models import (
GithubConfig,
KhojUser,
LocalMarkdownConfig,
LocalOrgConfig,
LocalPdfConfig,
LocalPlaintextConfig,
NotionConfig,
)
from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline
from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.tools.online_search import search_with_google
from khoj.routers.helpers import (
ApiUserRateLimiter,
CommonQueryParams,
agenerate_chat_response,
get_conversation_command,
is_ready_to_chat,
update_telemetry_state,
validate_conversation_config,
)
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_type import image_search, text_search
from khoj.utils import constants, state
from khoj.utils.config import GPT4AllProcessorModel, TextSearchModel
from khoj.utils.helpers import (
AsyncIteratorWrapper,
ConversationCommand,
command_descriptions,
get_device,
is_none_or_empty,
timer,
)
from khoj.utils.rawconfig import FullConfig, GithubContentConfig, NotionContentConfig, SearchConfig, SearchResponse
from khoj.utils.state import SearchType
# Initialize Router # Initialize Router
api = APIRouter() api = APIRouter()
@ -512,6 +512,17 @@ def update(
return {"status": "ok", "message": "khoj reloaded"} return {"status": "ok", "message": "khoj reloaded"}
@api.get("/chat/starters", response_class=Response)
@requires(["authenticated"])
async def chat_starters(
request: Request,
common: CommonQueryParams,
) -> Response:
user: KhojUser = request.user.object
starter_questions = await ConversationAdapters.aget_conversation_starters(user)
return Response(content=json.dumps(starter_questions), media_type="application/json", status_code=200)
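Note: a quick way to exercise the new starters endpoint, assuming a locally running server and an authenticated session cookie (both placeholders here).

# Hypothetical request against the new /api/chat/starters endpoint.
import requests

response = requests.get(
    "http://localhost:42110/api/chat/starters",
    cookies={"session": "<your-session-cookie>"},
)
print(response.json())  # list of suggested starter questions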
@api.get("/chat/history") @api.get("/chat/history")
@requires(["authenticated"]) @requires(["authenticated"])
def chat_history( def chat_history(
@ -577,6 +588,7 @@ async def chat(
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
) )
online_results: Dict = dict()
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references): if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
conversation_command = ConversationCommand.General conversation_command = ConversationCommand.General
@ -593,11 +605,22 @@ async def chat(
no_entries_found_format = no_entries_found.format() no_entries_found_format = no_entries_found.format()
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200) return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
    elif conversation_command == ConversationCommand.Online:
        try:
            online_results = await search_with_google(defiltered_query)
        except ValueError as e:
            return StreamingResponse(
                iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]),
                media_type="text/event-stream",
                status_code=200,
            )
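Note: the Serper.dev key appears to be read from the server's environment, so it has to be set before khoj starts; a hypothetical smoke test of the new /online command against a locally running, authenticated server (host, port, and cookie are placeholders).

# Hypothetical smoke test of the /online conversation command.
import requests

response = requests.get(
    "http://localhost:42110/api/chat",
    params={"q": "/online give me the link to paul graham's essay how to do great work", "stream": "true"},
    cookies={"session": "<your-session-cookie>"},
)
print(response.text)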
# Get the (streamed) chat response from the LLM of choice. # Get the (streamed) chat response from the LLM of choice.
llm_response, chat_metadata = await agenerate_chat_response( llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query, defiltered_query,
meta_log, meta_log,
compiled_references, compiled_references,
online_results,
inferred_queries, inferred_queries,
conversation_command, conversation_command,
user, user,
@ -650,7 +673,7 @@ async def extract_references_and_questions(
compiled_references: List[Any] = [] compiled_references: List[Any] = []
inferred_queries: List[str] = [] inferred_queries: List[str] = []
if conversation_type == ConversationCommand.General: if conversation_type == ConversationCommand.General or conversation_type == ConversationCommand.Online:
return compiled_references, inferred_queries, q return compiled_references, inferred_queries, q
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user): if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):

View file

@ -1,26 +1,28 @@
# Standard Packages # Standard Packages
import asyncio import asyncio
import json
import logging
from collections import defaultdict from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
import logging
from time import time from time import time
from typing import Annotated, Iterator, List, Optional, Union, Tuple, Dict from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
# External Packages # External Packages
from fastapi import HTTPException, Header, Request, Depends from fastapi import Depends, Header, HTTPException, Request
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import KhojUser, Subscription
from khoj.processor.conversation import prompts
from khoj.processor.conversation.gpt4all.chat_model import converse_offline, send_message_to_model_offline
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log
# Internal Packages # Internal Packages
from khoj.utils import state from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry from khoj.utils.helpers import ConversationCommand, log_telemetry
from khoj.processor.conversation.openai.gpt import converse
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
from khoj.database.models import KhojUser, Subscription
from khoj.database.adapters import ConversationAdapters
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -96,6 +98,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.Help return ConversationCommand.Help
elif query.startswith("/general"): elif query.startswith("/general"):
return ConversationCommand.General return ConversationCommand.General
elif query.startswith("/online"):
return ConversationCommand.Online
# If no relevant notes found for the given query # If no relevant notes found for the given query
elif not any_references: elif not any_references:
return ConversationCommand.General return ConversationCommand.General
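Note: with the new prefix, command detection works like the other slash commands; a small illustration, assuming the helpers are importable as in the imports above.

# Illustrative check of the new /online prefix.
from khoj.routers.helpers import get_conversation_command
from khoj.utils.helpers import ConversationCommand

assert get_conversation_command("/online latest stable python release", any_references=True) == ConversationCommand.Online
assert get_conversation_command("/general hello there", any_references=True) == ConversationCommand.General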
@ -112,10 +116,70 @@ async def agenerate_chat_response(*args):
return await loop.run_in_executor(executor, generate_chat_response, *args) return await loop.run_in_executor(executor, generate_chat_response, *args)
async def generate_online_subqueries(q: str) -> List[str]:
    """
    Generate subqueries from the given query
    """
    utc_date = datetime.utcnow().strftime("%Y-%m-%d")
    online_queries_prompt = prompts.online_search_conversation_subqueries.format(
        current_date=utc_date,
        query=q,
    )

    response = await send_message_to_model_wrapper(online_queries_prompt)

    # Validate that the response is a non-empty, JSON-serializable list
    try:
        response = response.strip()
        response = json.loads(response)
        response = [q.strip() for q in response if q.strip()]
        if not isinstance(response, list) or not response or len(response) == 0:
            logger.error(f"Invalid response for constructing subqueries: {response}")
            return [q]
        return response
    except Exception as e:
        logger.error(f"Invalid response for constructing subqueries: {response}")
        return [q]
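Note: the happy path expects the model to reply with a JSON array of strings; a sketch of the parsing with a hypothetical model reply, mirroring the validation above.

import json

# Hypothetical model reply; any parsing failure above falls back to the original query, i.e. [q].
model_reply = ' ["how to do great work paul graham essay", "paul graham latest essays"] '
subqueries = [sq.strip() for sq in json.loads(model_reply.strip()) if sq.strip()]
assert subqueries == ["how to do great work paul graham essay", "paul graham latest essays"]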
async def send_message_to_model_wrapper(
    message: str,
):
    conversation_config = await ConversationAdapters.aget_default_conversation_config()

    if conversation_config is None:
        raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")

    if conversation_config.model_type == "offline":
        if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
            state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)

        loaded_model = state.gpt4all_processor_config.loaded_model
        return send_message_to_model_offline(
            message=message,
            loaded_model=loaded_model,
            model=conversation_config.chat_model,
            streaming=False,
        )
    elif conversation_config.model_type == "openai":
        openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
        api_key = openai_chat_config.api_key
        chat_model = conversation_config.chat_model
        return send_message_to_model(
            message=message,
            api_key=api_key,
            model=chat_model,
        )
    else:
        raise HTTPException(status_code=500, detail="Invalid conversation config")
def generate_chat_response( def generate_chat_response(
q: str, q: str,
meta_log: dict, meta_log: dict,
compiled_references: List[str] = [], compiled_references: List[str] = [],
online_results: Dict[str, Any] = {},
inferred_queries: List[str] = [], inferred_queries: List[str] = [],
conversation_command: ConversationCommand = ConversationCommand.Default, conversation_command: ConversationCommand = ConversationCommand.Default,
user: KhojUser = None, user: KhojUser = None,
@ -125,6 +189,7 @@ def generate_chat_response(
chat_response: str, chat_response: str,
user_message_time: str, user_message_time: str,
compiled_references: List[str], compiled_references: List[str],
online_results: Dict[str, Any],
inferred_queries: List[str], inferred_queries: List[str],
meta_log, meta_log,
): ):
@ -132,7 +197,11 @@ def generate_chat_response(
user_message=q, user_message=q,
chat_response=chat_response, chat_response=chat_response,
user_message_metadata={"created": user_message_time}, user_message_metadata={"created": user_message_time},
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}}, khoj_message_metadata={
"context": compiled_references,
"intent": {"inferred-queries": inferred_queries},
"onlineContext": online_results,
},
conversation_log=meta_log.get("chat", []), conversation_log=meta_log.get("chat", []),
) )
ConversationAdapters.save_conversation(user, {"chat": updated_conversation}) ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
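Note: with onlineContext persisted per message, the web client can re-render online references when a conversation is reloaded; roughly, the stored khoj message metadata takes the shape below (illustrative values, structure inferred from the arguments above).

# Illustrative shape of the metadata saved with each khoj response.
khoj_message_metadata = {
    "context": ["compiled note snippet ..."],
    "intent": {"inferred-queries": ["inferred query ..."]},
    "onlineContext": {
        "how to do great work paul graham": {
            "organic": [
                {
                    "title": "How to Do Great Work",
                    "link": "http://www.paulgraham.com/greatwork.html",
                    "snippet": "...",
                }
            ],
            "knowledgeGraph": {},
            "peopleAlsoAsk": [],
        }
    },
}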
@ -150,22 +219,20 @@ def generate_chat_response(
q, q,
user_message_time=user_message_time, user_message_time=user_message_time,
compiled_references=compiled_references, compiled_references=compiled_references,
online_results=online_results,
inferred_queries=inferred_queries, inferred_queries=inferred_queries,
meta_log=meta_log, meta_log=meta_log,
) )
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config() conversation_config = ConversationAdapters.get_valid_conversation_config(user)
conversation_config = ConversationAdapters.get_conversation_config(user) if conversation_config.model_type == "offline":
if conversation_config is None:
conversation_config = ConversationAdapters.get_default_conversation_config()
openai_chat_config = ConversationAdapters.get_openai_conversation_config()
if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None: if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model) state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
loaded_model = state.gpt4all_processor_config.loaded_model loaded_model = state.gpt4all_processor_config.loaded_model
chat_response = converse_offline( chat_response = converse_offline(
references=compiled_references, references=compiled_references,
online_results=online_results,
user_query=q, user_query=q,
loaded_model=loaded_model, loaded_model=loaded_model,
conversation_log=meta_log, conversation_log=meta_log,
@ -176,11 +243,13 @@ def generate_chat_response(
tokenizer_name=conversation_config.tokenizer, tokenizer_name=conversation_config.tokenizer,
) )
elif openai_chat_config and conversation_config.model_type == "openai": elif conversation_config.model_type == "openai":
openai_chat_config = ConversationAdapters.get_openai_conversation_config()
api_key = openai_chat_config.api_key api_key = openai_chat_config.api_key
chat_model = conversation_config.chat_model chat_model = conversation_config.chat_model
chat_response = converse( chat_response = converse(
compiled_references, compiled_references,
online_results,
q, q,
meta_log, meta_log,
model=chat_model, model=chat_model,

View file

@ -7,12 +7,12 @@ from pydantic import BaseModel
from starlette.authentication import requires from starlette.authentication import requires
from khoj.database.models import GithubConfig, KhojUser, NotionConfig from khoj.database.models import GithubConfig, KhojUser, NotionConfig
from khoj.processor.github.github_to_entries import GithubToEntries from khoj.processor.content.github.github_to_entries import GithubToEntries
from khoj.processor.markdown.markdown_to_entries import MarkdownToEntries from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.notion.notion_to_entries import NotionToEntries from khoj.processor.content.notion.notion_to_entries import NotionToEntries
from khoj.processor.org_mode.org_to_entries import OrgToEntries from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.pdf.pdf_to_entries import PdfToEntries from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.routers.helpers import update_telemetry_state from khoj.routers.helpers import update_telemetry_state
from khoj.search_type import image_search, text_search from khoj.search_type import image_search, text_search
from khoj.utils import constants, state from khoj.utils import constants, state

View file

@ -18,7 +18,7 @@ from khoj.utils.models import BaseEncoder
from khoj.utils.state import SearchType from khoj.utils.state import SearchType
from khoj.utils.rawconfig import SearchResponse, Entry from khoj.utils.rawconfig import SearchResponse, Entry
from khoj.utils.jsonl import load_jsonl from khoj.utils.jsonl import load_jsonl
from khoj.processor.text_to_entries import TextToEntries from khoj.processor.content.text_to_entries import TextToEntries
from khoj.database.adapters import EntryAdapters from khoj.database.adapters import EntryAdapters
from khoj.database.models import KhojUser, Entry as DbEntry from khoj.database.models import KhojUser, Entry as DbEntry
@ -141,7 +141,7 @@ def collate_results(hits, dedupe=True):
else: else:
hit_ids.add(hit.corpus_id) hit_ids.add(hit.corpus_id)
yield SearchResponse.parse_obj( yield SearchResponse.model_validate(
{ {
"entry": hit.raw, "entry": hit.raw,
"score": hit.distance, "score": hit.distance,

View file

@ -272,12 +272,14 @@ class ConversationCommand(str, Enum):
General = "general" General = "general"
Notes = "notes" Notes = "notes"
Help = "help" Help = "help"
Online = "online"
command_descriptions = { command_descriptions = {
ConversationCommand.General: "Only talk about information that relies on Khoj's general knowledge, not your personal knowledge base.", ConversationCommand.General: "Only talk about information that relies on Khoj's general knowledge, not your personal knowledge base.",
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.", ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.", ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
ConversationCommand.Online: "Look up information on the internet.",
ConversationCommand.Help: "Display a help message with all available commands and other metadata.", ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
} }

View file

@ -1,48 +1,40 @@
# External Packages # External Packages
import os import os
from fastapi.testclient import TestClient
from pathlib import Path from pathlib import Path
import pytest
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI
import os
from fastapi import FastAPI
import pytest
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.testclient import TestClient
# Internal Packages # Internal Packages
from khoj.configure import configure_routes, configure_search_types, configure_middleware from khoj.configure import configure_middleware, configure_routes, configure_search_types
from khoj.database.models import (
GithubConfig,
GithubRepoConfig,
KhojApiUser,
KhojUser,
LocalMarkdownConfig,
LocalOrgConfig,
LocalPlaintextConfig,
)
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries from khoj.routers.indexer import configure_content
from khoj.search_type import image_search, text_search from khoj.search_type import image_search, text_search
from khoj.utils import fs_syncer, state
from khoj.utils.config import SearchModels from khoj.utils.config import SearchModels
from khoj.utils.constants import web_directory from khoj.utils.constants import web_directory
from khoj.utils.helpers import resolve_absolute_path from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import ( from khoj.utils.rawconfig import ContentConfig, ImageContentConfig, ImageSearchConfig, SearchConfig
ContentConfig,
ImageContentConfig,
SearchConfig,
ImageSearchConfig,
)
from khoj.utils import state, fs_syncer
from khoj.routers.indexer import configure_content
from khoj.processor.org_mode.org_to_entries import OrgToEntries
from khoj.database.models import (
KhojApiUser,
LocalOrgConfig,
LocalMarkdownConfig,
LocalPlaintextConfig,
GithubConfig,
KhojUser,
GithubRepoConfig,
)
from tests.helpers import ( from tests.helpers import (
UserFactory,
ChatModelOptionsFactory, ChatModelOptionsFactory,
OpenAIProcessorConversationConfigFactory,
OfflineChatProcessorConversationConfigFactory, OfflineChatProcessorConversationConfigFactory,
UserConversationProcessorConfigFactory, OpenAIProcessorConversationConfigFactory,
SubscriptionFactory, SubscriptionFactory,
UserConversationProcessorConfigFactory,
UserFactory,
) )

View file

@ -1,23 +1,23 @@
# Standard Modules # Standard Modules
from io import BytesIO from io import BytesIO
from PIL import Image
from urllib.parse import quote from urllib.parse import quote
import pytest import pytest
from fastapi import FastAPI
# External Packages # External Packages
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from fastapi import FastAPI from PIL import Image
import pytest
# Internal Packages # Internal Packages
from khoj.configure import configure_routes, configure_search_types from khoj.configure import configure_routes, configure_search_types
from khoj.utils import state
from khoj.utils.state import search_models, content_index, config
from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_entries import OrgToEntries
from khoj.database.models import KhojUser, KhojApiUser
from khoj.database.adapters import EntryAdapters from khoj.database.adapters import EntryAdapters
from khoj.database.models import KhojApiUser, KhojUser
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.search_type import image_search, text_search
from khoj.utils import state
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.utils.state import config, content_index, search_models
# Test # Test

View file

@ -1,5 +1,6 @@
# Standard Packages # Standard Packages
import urllib.parse import urllib.parse
from urllib.parse import quote
# External Packages # External Packages
import pytest import pytest
@ -54,6 +55,26 @@ def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_c
) )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_online_content(chat_client):
    # Act
    q = "/online give me the link to paul graham's essay how to do great work"
    encoded_q = quote(q, safe="")
    response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true")
    response_message = response.content.decode("utf-8")

    response_message = response_message.split("### compiled references")[0]

    # Assert
    expected_responses = ["http://www.paulgraham.com/greatwork.html"]
    assert response.status_code == 200
    assert any([expected_response in response_message for expected_response in expected_responses]), (
        "Expected link to Paul Graham's essay in response but got: " + response_message
    )
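Note: this test talks to a live chat model and the Serper.dev API, so it only passes with both configured; it can be run in isolation, for example:

# Hypothetical selective run of the new chat-quality test (needs a configured chat model and SERPER_DEV_API_KEY).
import pytest

pytest.main(["-m", "chatquality", "-k", "test_chat_with_online_content"])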
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality @pytest.mark.chatquality
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)

View file

@ -4,7 +4,7 @@ from pathlib import Path
import os import os
# Internal Packages # Internal Packages
from khoj.processor.markdown.markdown_to_entries import MarkdownToEntries from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.utils.fs_syncer import get_markdown_files from khoj.utils.fs_syncer import get_markdown_files
from khoj.utils.rawconfig import TextContentConfig from khoj.utils.rawconfig import TextContentConfig

View file

@ -1,24 +1,14 @@
# Standard Modules # Standard Modules
from io import BytesIO
from PIL import Image
from urllib.parse import quote from urllib.parse import quote
import pytest
# External Packages # External Packages
from fastapi.testclient import TestClient
from fastapi import FastAPI, UploadFile
from io import BytesIO
import pytest import pytest
from khoj.database.models import KhojApiUser, KhojUser
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
# Internal Packages # Internal Packages
from khoj.configure import configure_routes, configure_search_types from khoj.search_type import text_search
from khoj.utils import state
from khoj.utils.state import search_models, content_index, config
from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_entries import OrgToEntries
from khoj.database.models import KhojUser, KhojApiUser
from khoj.database.adapters import EntryAdapters
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------

View file

@ -1,6 +1,7 @@
# Standard Packages # Standard Packages
import os import os
import urllib.parse import urllib.parse
from urllib.parse import quote
# External Packages # External Packages
import pytest import pytest
@ -54,6 +55,26 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
) )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_online_content(chat_client):
    # Act
    q = "/online give me the link to paul graham's essay how to do great work"
    encoded_q = quote(q, safe="")
    response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true")
    response_message = response.content.decode("utf-8")

    response_message = response_message.split("### compiled references")[0]

    # Assert
    expected_responses = ["http://www.paulgraham.com/greatwork.html"]
    assert response.status_code == 200
    assert any([expected_response in response_message for expected_response in expected_responses]), (
        "Expected link to Paul Graham's essay in response but got: " + response_message
    )
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality @pytest.mark.chatquality

View file

@ -3,8 +3,8 @@ import json
import os import os
# Internal Packages # Internal Packages
from khoj.processor.org_mode.org_to_entries import OrgToEntries from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.text_to_entries import TextToEntries from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils.helpers import is_none_or_empty from khoj.utils.helpers import is_none_or_empty
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from khoj.utils.fs_syncer import get_org_files from khoj.utils.fs_syncer import get_org_files

View file

@ -2,7 +2,7 @@
import datetime import datetime
# Internal Packages # Internal Packages
from khoj.processor.org_mode import orgnode from khoj.processor.content.org_mode import orgnode
# Test # Test

View file

@ -3,7 +3,7 @@ import json
import os import os
# Internal Packages # Internal Packages
from khoj.processor.pdf.pdf_to_entries import PdfToEntries from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
from khoj.utils.fs_syncer import get_pdf_files from khoj.utils.fs_syncer import get_pdf_files
from khoj.utils.rawconfig import TextContentConfig from khoj.utils.rawconfig import TextContentConfig

View file

@ -3,11 +3,12 @@ import json
import os import os
from pathlib import Path from pathlib import Path
from khoj.database.models import KhojUser, LocalPlaintextConfig
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
# Internal Packages # Internal Packages
from khoj.utils.fs_syncer import get_plaintext_files from khoj.utils.fs_syncer import get_plaintext_files
from khoj.utils.rawconfig import TextContentConfig from khoj.utils.rawconfig import TextContentConfig
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.database.models import LocalPlaintextConfig, KhojUser
def test_plaintext_file(tmp_path): def test_plaintext_file(tmp_path):

View file

@ -1,19 +1,20 @@
# System Packages # System Packages
import logging
from pathlib import Path
import os
import asyncio import asyncio
import logging
import os
from pathlib import Path
# External Packages # External Packages
import pytest import pytest
from khoj.database.models import Entry, GithubConfig, KhojUser, LocalOrgConfig
from khoj.processor.content.github.github_to_entries import GithubToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
# Internal Packages # Internal Packages
from khoj.search_type import text_search from khoj.search_type import text_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_entries import OrgToEntries
from khoj.processor.github.github_to_entries import GithubToEntries
from khoj.utils.fs_syncer import collect_files, get_org_files from khoj.utils.fs_syncer import collect_files, get_org_files
from khoj.database.models import LocalOrgConfig, KhojUser, Entry, GithubConfig from khoj.utils.rawconfig import ContentConfig, SearchConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)