Rebase with master

sabaimran 2023-12-19 21:02:49 +05:30
commit 6dd2b05bf5
25 changed files with 797 additions and 341 deletions

View file

@@ -42,7 +42,7 @@ dependencies = [
    "fastapi >= 0.104.1",
    "python-multipart >= 0.0.5",
    "jinja2 == 3.1.2",
-   "openai >= 0.27.0, < 1.0.0",
+   "openai >= 1.0.0",
    "tiktoken >= 0.3.2",
    "tenacity >= 8.2.2",
    "pillow ~= 9.5.0",

View file

@@ -179,7 +179,18 @@
        return numOnlineReferences;
    }

-   function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
+   function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
+       if (intentType === "text-to-image") {
+           let imageMarkdown = `![](data:image/png;base64,${message})`;
+           imageMarkdown += "\n\n";
+           if (inferredQueries) {
+               const inferredQuery = inferredQueries?.[0];
+               imageMarkdown += `**Inferred Query**: ${inferredQuery}`;
+           }
+           renderMessage(imageMarkdown, by, dt);
+           return;
+       }
        if (context == null && onlineContext == null) {
            renderMessage(message, by, dt);
            return;
@@ -244,6 +255,17 @@
        // Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
        newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');

+       // Customize the rendering of images
+       md.renderer.rules.image = function(tokens, idx, options, env, self) {
+           let token = tokens[idx];
+           // Add class="text-to-image" to images
+           token.attrPush(['class', 'text-to-image']);
+           // Use the default renderer to render image markdown format
+           return self.renderToken(tokens, idx, options);
+       };
+
        // Render markdown
        newHTML = md.render(newHTML);
        // Get any elements with a class that starts with "language"
@@ -328,109 +350,153 @@
        let chatInput = document.getElementById("chat-input");
        chatInput.classList.remove("option-enabled");

        // Call specified Khoj API
        let response = await fetch(url, { headers });
        let rawResponse = "";
        const contentType = response.headers.get("content-type");

        if (contentType === "application/json") {
            // Handle JSON response
            try {
                const responseAsJson = await response.json();
                if (responseAsJson.image) {
                    // If response has image field, response is a generated image.
                    rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
                    rawResponse += "\n\n";
                    const inferredQueries = responseAsJson.inferredQueries?.[0];
                    if (inferredQueries) {
                        rawResponse += `**Inferred Query**: ${inferredQueries}`;
                    }
                }
                if (responseAsJson.detail) {
                    // If response has detail field, response is an error message.
                    rawResponse += responseAsJson.detail;
                }
            } catch (error) {
                // If the chunk is not a JSON object, just display it as is
                rawResponse += chunk;
            } finally {
                newResponseText.innerHTML = "";
                newResponseText.appendChild(formatHTMLMessage(rawResponse));
                document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
                document.getElementById("chat-input").removeAttribute("disabled");
            }
        } else {
            // Handle streamed response of type text/event-stream or text/plain
            const reader = response.body.getReader();
            const decoder = new TextDecoder();
            let references = null;
            readStream();

            function readStream() {
                reader.read().then(({ done, value }) => {
                    if (done) {
                        // Append any references after all the data has been streamed
                        if (references != null) {
                            newResponseText.appendChild(references);
                        }
                        document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
                        document.getElementById("chat-input").removeAttribute("disabled");
                        return;
                    }

                    // Decode message chunk from stream
                    const chunk = decoder.decode(value, { stream: true });
                    if (chunk.includes("### compiled references:")) {
                        const additionalResponse = chunk.split("### compiled references:")[0];
                        rawResponse += additionalResponse;
                        newResponseText.innerHTML = "";
                        newResponseText.appendChild(formatHTMLMessage(rawResponse));

                        const rawReference = chunk.split("### compiled references:")[1];
                        const rawReferenceAsJson = JSON.parse(rawReference);
                        references = document.createElement('div');
                        references.classList.add("references");

                        let referenceExpandButton = document.createElement('button');
                        referenceExpandButton.classList.add("reference-expand-button");

                        let referenceSection = document.createElement('div');
                        referenceSection.classList.add("reference-section");
                        referenceSection.classList.add("collapsed");

                        let numReferences = 0;

                        // If rawReferenceAsJson is a list, then count the length
                        if (Array.isArray(rawReferenceAsJson)) {
                            numReferences = rawReferenceAsJson.length;
                            rawReferenceAsJson.forEach((reference, index) => {
                                let polishedReference = generateReference(reference, index);
                                referenceSection.appendChild(polishedReference);
                            });
                        } else {
                            numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
                        }

                        references.appendChild(referenceExpandButton);

                        referenceExpandButton.addEventListener('click', function() {
                            if (referenceSection.classList.contains("collapsed")) {
                                referenceSection.classList.remove("collapsed");
                                referenceSection.classList.add("expanded");
                            } else {
                                referenceSection.classList.add("collapsed");
                                referenceSection.classList.remove("expanded");
                            }
                        });

                        let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
                        referenceExpandButton.innerHTML = expandButtonText;
                        references.appendChild(referenceSection);
                        readStream();
                    } else {
                        // Display response from Khoj
                        if (newResponseText.getElementsByClassName("spinner").length > 0) {
                            newResponseText.removeChild(loadingSpinner);
                        }

                        // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
                        if (chunk.startsWith("{") && chunk.endsWith("}")) {
                            try {
                                const responseAsJson = JSON.parse(chunk);
                                if (responseAsJson.image) {
                                    // If response has image field, response is a generated image.
                                    rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
                                    rawResponse += "\n\n";
                                    const inferredQueries = responseAsJson.inferredQueries?.[0];
                                    if (inferredQueries) {
                                        rawResponse += `**Inferred Query**: ${inferredQueries}`;
                                    }
                                }
                                if (responseAsJson.detail) {
                                    rawResponse += responseAsJson.detail;
                                }
                            } catch (error) {
                                // If the chunk is not a JSON object, just display it as is
                                rawResponse += chunk;
                            } finally {
                                newResponseText.innerHTML = "";
                                newResponseText.appendChild(formatHTMLMessage(rawResponse));
                            }
                        } else {
                            // If the chunk is not a JSON object, just display it as is
                            rawResponse += chunk;
                            newResponseText.innerHTML = "";
                            newResponseText.appendChild(formatHTMLMessage(rawResponse));
                            readStream();
                        }
                    }

                    // Scroll to bottom of chat window as chat response is streamed
                    document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
                });
            }
        }
    }
    function incrementalChat(event) {

@@ -522,7 +588,7 @@
            .then(response => {
                // Render conversation history, if any
                response.forEach(chat_log => {
-                   renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
+                   renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]);
                });
            })
            .catch(err => {
@@ -625,9 +691,13 @@
            .then(response => response.ok ? response.json() : Promise.reject(response))
            .then(data => { chatInput.value += data.text; })
            .catch(err => {
-               err.status == 422
-               ? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
-               : flashStatusInChatInput("⛔️ Failed to transcribe audio")
+               if (err.status === 501) {
+                   flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
+               } else if (err.status === 422) {
+                   flashStatusInChatInput("⛔️ Audio file too large to process.")
+               } else {
+                   flashStatusInChatInput("⛔️ Failed to transcribe audio.")
+               }
            });
        };
@@ -810,6 +880,9 @@
        margin-top: -10px;
        transform: rotate(-60deg)
    }
+   img.text-to-image {
+       max-width: 60%;
+   }

    #chat-footer {
        padding: 0;
@@ -846,11 +919,12 @@
    }
    .input-row-button {
        background: var(--background-color);
-       border: none;
+       border: 1px solid var(--main-text-color);
+       box-shadow: 0 0 11px #aaa;
        border-radius: 5px;
-       padding: 5px;
        font-size: 14px;
        font-weight: 300;
+       padding: 0;
        line-height: 1.5em;
        cursor: pointer;
        transition: background 0.3s ease-in-out;
@@ -932,7 +1006,6 @@
        color: var(--main-text-color);
        border: 1px solid var(--main-text-color);
        border-radius: 5px;
-       padding: 5px;
        font-size: 14px;
        font-weight: 300;
        line-height: 1.5em;
@@ -1050,6 +1123,9 @@
        margin: 4px;
        grid-template-columns: auto;
    }
+   img.text-to-image {
+       max-width: 100%;
+   }
    }

    @media only screen and (min-width: 600px) {
        body {

View file

@@ -2,6 +2,12 @@ import { App, MarkdownRenderer, Modal, request, requestUrl, setIcon } from 'obsi
import { KhojSetting } from 'src/settings';
import fetch from "node-fetch";

+export interface ChatJsonResult {
+    image?: string;
+    detail?: string;
+}
+
export class KhojChatModal extends Modal {
    result: string;
    setting: KhojSetting;
@@ -105,15 +111,19 @@
        return referenceButton;
    }

-   renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date) {
+   renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string) {
        if (!message) {
            return;
+       } else if (intentType === "text-to-image") {
+           let imageMarkdown = `![](data:image/png;base64,${message})`;
+           this.renderMessage(chatEl, imageMarkdown, sender, dt);
+           return;
        } else if (!context) {
            this.renderMessage(chatEl, message, sender, dt);
-           return
+           return;
        } else if (!!context && context?.length === 0) {
            this.renderMessage(chatEl, message, sender, dt);
-           return
+           return;
        }
        let chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
        let chatMessageBodyEl = chatMessageEl.getElementsByClassName("khoj-chat-message-text")[0]
@@ -225,7 +235,7 @@
        let response = await request({ url: chatUrl, headers: headers });
        let chatLogs = JSON.parse(response).response;
        chatLogs.forEach((chatLog: any) => {
-           this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created));
+           this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created), chatLog.intent?.type);
        });
    }
@@ -266,8 +276,25 @@
        this.result = "";
        responseElement.innerHTML = "";

+       if (response.headers.get("content-type") == "application/json") {
+           let responseText = ""
+           try {
+               const responseAsJson = await response.json() as ChatJsonResult;
+               if (responseAsJson.image) {
+                   responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`;
+               } else if (responseAsJson.detail) {
+                   responseText = responseAsJson.detail;
+               }
+           } catch (error) {
+               // If the chunk is not a JSON object, just display it as is
+               responseText = response.body.read().toString()
+           } finally {
+               this.renderIncrementalMessage(responseElement, responseText);
+           }
+       }
+
        for await (const chunk of response.body) {
-           const responseText = chunk.toString();
+           let responseText = chunk.toString();
            if (responseText.includes("### compiled references:")) {
                const additionalResponse = responseText.split("### compiled references:")[0];
                this.renderIncrementalMessage(responseElement, additionalResponse);
@@ -310,6 +337,12 @@
                referenceExpandButton.innerHTML = expandButtonText;
                references.appendChild(referenceSection);
            } else {
+               if (responseText.startsWith("{") && responseText.endsWith("}")) {
+               } else {
+                   // If the chunk is not a JSON object, just display it as is
+                   continue;
+               }
+
                this.renderIncrementalMessage(responseElement, responseText);
            }
        }
@@ -389,10 +422,12 @@
        if (response.status === 200) {
            console.log(response);
            chatInput.value += response.json.text;
-       } else if (response.status === 422) {
-           throw new Error("⛔️ Failed to transcribe audio");
-       } else {
+       } else if (response.status === 501) {
            throw new Error("⛔️ Configure speech-to-text model on server.");
+       } else if (response.status === 422) {
+           throw new Error("⛔️ Audio file too large to process.");
+       } else {
+           throw new Error("⛔️ Failed to transcribe audio.");
        }
    };

View file

@@ -217,7 +217,9 @@ button.copy-button:hover {
    background: #f5f5f5;
    cursor: pointer;
}
+img {
+    max-width: 60%;
+}

#khoj-chat-footer {
    padding: 0;

View file

@@ -33,6 +33,9 @@ ALLOWED_HOSTS = [f".{KHOJ_DOMAIN}", "localhost", "127.0.0.1", "[::1]"]
CSRF_TRUSTED_ORIGINS = [
    f"https://*.{KHOJ_DOMAIN}",
    f"https://{KHOJ_DOMAIN}",
+   f"http://*.{KHOJ_DOMAIN}",
+   f"http://{KHOJ_DOMAIN}",
+   f"https://app.{KHOJ_DOMAIN}",
]

COOKIE_SAMESITE = "None"

@@ -42,6 +45,7 @@ if DEBUG or os.getenv("KHOJ_DOMAIN") == None:
else:
    SESSION_COOKIE_DOMAIN = KHOJ_DOMAIN
    CSRF_COOKIE_DOMAIN = KHOJ_DOMAIN
+   SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTOCOL", "https")
    SESSION_COOKIE_SECURE = True
    CSRF_COOKIE_SECURE = True

View file

@@ -7,6 +7,7 @@ import requests
import os

# External Packages
+import openai
import schedule
from starlette.middleware.sessions import SessionMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware

@@ -22,6 +23,7 @@ from starlette.authentication import (
# Internal Packages
from khoj.database.models import KhojUser, Subscription
from khoj.database.adapters import (
+   ConversationAdapters,
    get_all_users,
    get_or_create_search_models,
    aget_user_subscription_state,
@@ -75,17 +77,18 @@ class UserAuthenticationBackend(AuthenticationBackend):
                .afirst()
            )
            if user:
-               if state.billing_enabled:
-                   subscription_state = await aget_user_subscription_state(user)
-                   subscribed = (
-                       subscription_state == SubscriptionState.SUBSCRIBED.value
-                       or subscription_state == SubscriptionState.TRIAL.value
-                       or subscription_state == SubscriptionState.UNSUBSCRIBED.value
-                   )
-                   if subscribed:
-                       return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
-                   return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
-               return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
+               if not state.billing_enabled:
+                   return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
+
+               subscription_state = await aget_user_subscription_state(user)
+               subscribed = (
+                   subscription_state == SubscriptionState.SUBSCRIBED.value
+                   or subscription_state == SubscriptionState.TRIAL.value
+                   or subscription_state == SubscriptionState.UNSUBSCRIBED.value
+               )
+               if subscribed:
+                   return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
+               return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
            if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
                # Get bearer token from header
                bearer_token = request.headers["Authorization"].split("Bearer ")[1]
@@ -97,19 +100,18 @@ class UserAuthenticationBackend(AuthenticationBackend):
                .afirst()
            )
            if user_with_token:
-               if state.billing_enabled:
-                   subscription_state = await aget_user_subscription_state(user_with_token.user)
-                   subscribed = (
-                       subscription_state == SubscriptionState.SUBSCRIBED.value
-                       or subscription_state == SubscriptionState.TRIAL.value
-                       or subscription_state == SubscriptionState.UNSUBSCRIBED.value
-                   )
-                   if subscribed:
-                       return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(
-                           user_with_token.user
-                       )
-                   return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
-               return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
+               if not state.billing_enabled:
+                   return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
+
+               subscription_state = await aget_user_subscription_state(user_with_token.user)
+               subscribed = (
+                   subscription_state == SubscriptionState.SUBSCRIBED.value
+                   or subscription_state == SubscriptionState.TRIAL.value
+                   or subscription_state == SubscriptionState.UNSUBSCRIBED.value
+               )
+               if subscribed:
+                   return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
+               return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
            if state.anonymous_mode:
                user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
                if user:
@@ -138,6 +140,10 @@ def configure_server(
        config = FullConfig()
        state.config = config

+   if ConversationAdapters.has_valid_openai_conversation_config():
+       openai_config = ConversationAdapters.get_openai_conversation_config()
+       state.openai_client = openai.OpenAI(api_key=openai_config.api_key)
+
    # Initialize Search Models from Config and initialize content
    try:
        search_models = get_or_create_search_models()

View file

@@ -22,6 +22,7 @@ from khoj.database.models import (
    GithubConfig,
    GithubRepoConfig,
    GoogleUser,
+   TextToImageModelConfig,
    KhojApiUser,
    KhojUser,
    NotionConfig,

@@ -407,7 +408,7 @@ class ConversationAdapters:
        )
        max_results = 3

-       all_questions = await sync_to_async(list)(all_questions)
+       all_questions = await sync_to_async(list)(all_questions)  # type: ignore
        if len(all_questions) < max_results:
            return all_questions

@@ -433,6 +434,10 @@
        else:
            raise ValueError("Invalid conversation config - either configure offline chat or openai chat")

+   @staticmethod
+   async def aget_text_to_image_model_config():
+       return await TextToImageModelConfig.objects.filter().afirst()
+

class EntryAdapters:
    word_filer = WordFilter()
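A hedged sketch of how the new adapter method might be consumed, assuming a fallback to the model's dall-e-3 default when no TextToImageModelConfig row exists:

```python
from khoj.database.adapters import ConversationAdapters

async def get_image_model_name() -> str:
    # Illustrative helper (not part of this commit): fall back to dall-e-3
    # when no text-to-image model has been configured on the server.
    config = await ConversationAdapters.aget_text_to_image_model_config()
    return config.model_name if config else "dall-e-3"
```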

View file

@@ -1,5 +1,9 @@
+import csv
+import json
+
from django.contrib import admin
from django.contrib.auth.admin import UserAdmin
+from django.http import HttpResponse

# Register your models here.

@@ -13,6 +17,8 @@ from khoj.database.models import (
    Subscription,
    ReflectiveQuestion,
    UserSearchModelConfig,
+   TextToImageModelConfig,
+   Conversation,
)

admin.site.register(KhojUser, UserAdmin)

@@ -25,3 +31,102 @@ admin.site.register(SearchModelConfig)
admin.site.register(Subscription)
admin.site.register(ReflectiveQuestion)
admin.site.register(UserSearchModelConfig)
+admin.site.register(TextToImageModelConfig)
+
+@admin.register(Conversation)
+class ConversationAdmin(admin.ModelAdmin):
+    list_display = (
+        "id",
+        "user",
+        "created_at",
+        "updated_at",
+    )
+    search_fields = ("conversation_id",)
+    ordering = ("-created_at",)
+
+    actions = ["export_selected_objects", "export_selected_minimal_objects"]
+
+    def export_selected_objects(self, request, queryset):
+        response = HttpResponse(content_type="text/csv")
+        response["Content-Disposition"] = 'attachment; filename="conversations.csv"'
+        writer = csv.writer(response)
+        writer.writerow(["id", "user", "created_at", "updated_at", "conversation_log"])
+
+        for conversation in queryset:
+            modified_log = conversation.conversation_log
+            chat_log = modified_log.get("chat", [])
+            for idx, log in enumerate(chat_log):
+                if (
+                    log["by"] == "khoj"
+                    and log["intent"]
+                    and log["intent"]["type"]
+                    and log["intent"]["type"] == "text-to-image"
+                ):
+                    log["message"] = "image redacted for space"
+                chat_log[idx] = log
+            modified_log["chat"] = chat_log
+
+            writer.writerow(
+                [
+                    conversation.id,
+                    conversation.user,
+                    conversation.created_at,
+                    conversation.updated_at,
+                    json.dumps(modified_log),
+                ]
+            )
+
+        return response
+
+    export_selected_objects.short_description = "Export selected conversations"  # type: ignore
+
+    def export_selected_minimal_objects(self, request, queryset):
+        response = HttpResponse(content_type="text/csv")
+        response["Content-Disposition"] = 'attachment; filename="conversations.csv"'
+        writer = csv.writer(response)
+        writer.writerow(["id", "user", "created_at", "updated_at", "conversation_log"])
+
+        fields_to_keep = set(["message", "by", "created"])
+
+        for conversation in queryset:
+            return_log = dict()
+            chat_log = conversation.conversation_log.get("chat", [])
+            for idx, log in enumerate(chat_log):
+                updated_log = {}
+                for key in fields_to_keep:
+                    updated_log[key] = log[key]
+                if (
+                    log["by"] == "khoj"
+                    and log["intent"]
+                    and log["intent"]["type"]
+                    and log["intent"]["type"] == "text-to-image"
+                ):
+                    updated_log["message"] = "image redacted for space"
+                chat_log[idx] = updated_log
+            return_log["chat"] = chat_log
+
+            writer.writerow(
+                [
+                    conversation.id,
+                    conversation.user,
+                    conversation.created_at,
+                    conversation.updated_at,
+                    json.dumps(return_log),
+                ]
+            )
+
+        return response
+
+    export_selected_minimal_objects.short_description = "Export selected conversations (minimal)"  # type: ignore
+
+    def get_actions(self, request):
+        actions = super().get_actions(request)
+        if not request.user.is_superuser:
+            if "export_selected_objects" in actions:
+                del actions["export_selected_objects"]
+            if "export_selected_minimal_objects" in actions:
+                del actions["export_selected_minimal_objects"]
+        return actions

View file

@@ -0,0 +1,25 @@
+# Generated by Django 4.2.7 on 2023-12-04 22:17
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0021_speechtotextmodeloptions_and_more"),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name="TextToImageModelConfig",
+            fields=[
+                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+                ("created_at", models.DateTimeField(auto_now_add=True)),
+                ("updated_at", models.DateTimeField(auto_now=True)),
+                ("model_name", models.CharField(default="dall-e-3", max_length=200)),
+                ("model_type", models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200)),
+            ],
+            options={
+                "abstract": False,
+            },
+        ),
+    ]

View file

@@ -112,6 +112,14 @@ class SearchModelConfig(BaseModel):
    cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")

+class TextToImageModelConfig(BaseModel):
+    class ModelType(models.TextChoices):
+        OPENAI = "openai"
+
+    model_name = models.CharField(max_length=200, default="dall-e-3")
+    model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
+

class OpenAIProcessorConversationConfig(BaseModel):
    api_key = models.CharField(max_length=200)
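For reference, a hedged sketch of seeding this table from a Django shell so the /image command has a model to look up (values mirror the field defaults above):

```python
from khoj.database.models import TextToImageModelConfig

# Illustrative: create the single config row the image pipeline reads via
# ConversationAdapters.aget_text_to_image_model_config().
TextToImageModelConfig.objects.create(
    model_name="dall-e-3",
    model_type=TextToImageModelConfig.ModelType.OPENAI,
)
```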

View file

@@ -183,12 +183,23 @@ To get started, just start typing below. You can also type / to see a list of co
                    referenceSection.appendChild(polishedReference);
                }
            }
        }
        return numOnlineReferences;
    }

-   function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
+   function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
+       if (intentType === "text-to-image") {
+           let imageMarkdown = `![](data:image/png;base64,${message})`;
+           imageMarkdown += "\n\n";
+           if (inferredQueries) {
+               const inferredQuery = inferredQueries?.[0];
+               imageMarkdown += `**Inferred Query**: ${inferredQuery}`;
+           }
+           renderMessage(imageMarkdown, by, dt);
+           return;
+       }
        if (context == null && onlineContext == null) {
            renderMessage(message, by, dt);
            return;
@@ -253,6 +264,17 @@ To get started, just start typing below. You can also type / to see a list of co
        // Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
        newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');

+       // Customize the rendering of images
+       md.renderer.rules.image = function(tokens, idx, options, env, self) {
+           let token = tokens[idx];
+           // Add class="text-to-image" to images
+           token.attrPush(['class', 'text-to-image']);
+           // Use the default renderer to render image markdown format
+           return self.renderToken(tokens, idx, options);
+       };
+
        // Render markdown
        newHTML = md.render(newHTML);
        // Get any elements with a class that starts with "language"
@@ -292,7 +314,7 @@ To get started, just start typing below. You can also type / to see a list of co
        return element
    }

-   function chat() {
+   async function chat() {
        // Extract required fields for search from form
        let query = document.getElementById("chat-input").value.trim();
        let resultsCount = localStorage.getItem("khojResultsCount") || 5;
@@ -333,113 +355,128 @@
        let chatInput = document.getElementById("chat-input");
        chatInput.classList.remove("option-enabled");

        // Call specified Khoj API
        let response = await fetch(url);
        let rawResponse = "";
        const contentType = response.headers.get("content-type");

        if (contentType === "application/json") {
            // Handle JSON response
            try {
                const responseAsJson = await response.json();
                if (responseAsJson.image) {
                    // If response has image field, response is a generated image.
                    rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
                    rawResponse += "\n\n";
                    const inferredQueries = responseAsJson.inferredQueries?.[0];
                    if (inferredQueries) {
                        rawResponse += `**Inferred Query**: ${inferredQueries}`;
                    }
                }
                if (responseAsJson.detail) {
                    // If response has detail field, response is an error message.
                    rawResponse += responseAsJson.detail;
                }
            } catch (error) {
                // If the chunk is not a JSON object, just display it as is
                rawResponse += chunk;
            } finally {
                newResponseText.innerHTML = "";
                newResponseText.appendChild(formatHTMLMessage(rawResponse));
                document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
                document.getElementById("chat-input").removeAttribute("disabled");
            }
        } else {
            // Handle streamed response of type text/event-stream or text/plain
            const reader = response.body.getReader();
            const decoder = new TextDecoder();
            let references = null;
            readStream();

            function readStream() {
                reader.read().then(({ done, value }) => {
                    if (done) {
                        // Append any references after all the data has been streamed
                        if (references != null) {
                            newResponseText.appendChild(references);
                        }
                        document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
                        document.getElementById("chat-input").removeAttribute("disabled");
                        return;
                    }

                    // Decode message chunk from stream
                    const chunk = decoder.decode(value, { stream: true });
                    if (chunk.includes("### compiled references:")) {
                        const additionalResponse = chunk.split("### compiled references:")[0];
                        rawResponse += additionalResponse;
                        newResponseText.innerHTML = "";
                        newResponseText.appendChild(formatHTMLMessage(rawResponse));

                        const rawReference = chunk.split("### compiled references:")[1];
                        const rawReferenceAsJson = JSON.parse(rawReference);
                        references = document.createElement('div');
                        references.classList.add("references");

                        let referenceExpandButton = document.createElement('button');
                        referenceExpandButton.classList.add("reference-expand-button");

                        let referenceSection = document.createElement('div');
                        referenceSection.classList.add("reference-section");
                        referenceSection.classList.add("collapsed");

                        let numReferences = 0;

                        // If rawReferenceAsJson is a list, then count the length
                        if (Array.isArray(rawReferenceAsJson)) {
                            numReferences = rawReferenceAsJson.length;
                            rawReferenceAsJson.forEach((reference, index) => {
                                let polishedReference = generateReference(reference, index);
                                referenceSection.appendChild(polishedReference);
                            });
                        } else {
                            numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
                        }

                        references.appendChild(referenceExpandButton);

                        referenceExpandButton.addEventListener('click', function() {
                            if (referenceSection.classList.contains("collapsed")) {
                                referenceSection.classList.remove("collapsed");
                                referenceSection.classList.add("expanded");
                            } else {
                                referenceSection.classList.add("collapsed");
                                referenceSection.classList.remove("expanded");
                            }
                        });

                        let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
                        referenceExpandButton.innerHTML = expandButtonText;
                        references.appendChild(referenceSection);
                        readStream();
                    } else {
                        // Display response from Khoj
                        if (newResponseText.getElementsByClassName("spinner").length > 0) {
                            newResponseText.removeChild(loadingSpinner);
                        }

                        // If the chunk is not a JSON object, just display it as is
                        rawResponse += chunk;
                        newResponseText.innerHTML = "";
                        newResponseText.appendChild(formatHTMLMessage(rawResponse));
                        readStream();
                    }
                });

                // Scroll to bottom of chat window as chat response is streamed
                document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
            };
        }
    };
    function incrementalChat(event) {
        if (!event.shiftKey && event.key === 'Enter') {

@@ -516,7 +553,7 @@ To get started, just start typing below. You can also type / to see a list of co
            .then(response => {
                // Render conversation history, if any
                response.forEach(chat_log => {
-                   renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
+                   renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]);
                });
            })
            .catch(err => {
@@ -611,9 +648,15 @@ To get started, just start typing below. You can also type / to see a list of co
            .then(response => response.ok ? response.json() : Promise.reject(response))
            .then(data => { chatInput.value += data.text; })
            .catch(err => {
-               err.status == 422
-               ? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
-               : flashStatusInChatInput("⛔️ Failed to transcribe audio")
+               if (err.status === 501) {
+                   flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
+               } else if (err.status === 422) {
+                   flashStatusInChatInput("⛔️ Audio file too large to process.")
+               } else if (err.status === 429) {
+                   flashStatusInChatInput("⛔️ " + err.statusText);
+               } else {
+                   flashStatusInChatInput("⛔️ Failed to transcribe audio.")
+               }
            });
        };
@@ -902,6 +945,9 @@ To get started, just start typing below. You can also type / to see a list of co
        margin-top: -10px;
        transform: rotate(-60deg)
    }
+   img.text-to-image {
+       max-width: 60%;
+   }

    #chat-footer {
        padding: 0;
@@ -916,7 +962,7 @@ To get started, just start typing below. You can also type / to see a list of co
        grid-template-columns: auto 32px 32px;
        grid-column-gap: 10px;
        grid-row-gap: 10px;
-       background: #f9fafc
+       background: var(--background-color);
    }
    .option:hover {
        box-shadow: 0 0 11px #aaa;
@@ -938,9 +984,10 @@ To get started, just start typing below. You can also type / to see a list of co
    }
    .input-row-button {
        background: var(--background-color);
-       border: none;
+       border: 1px solid var(--main-text-color);
+       box-shadow: 0 0 11px #aaa;
        border-radius: 5px;
-       padding: 5px;
+       padding: 0px;
        font-size: 14px;
        font-weight: 300;
        line-height: 1.5em;
@@ -1029,6 +1076,9 @@ To get started, just start typing below. You can also type / to see a list of co
        margin: 4px;
        grid-template-columns: auto;
    }
+   img.text-to-image {
+       max-width: 100%;
+   }
    }

    @media only screen and (min-width: 700px) {
        body {

View file

@@ -47,7 +47,7 @@ def extract_questions_offline(
    if use_history:
        for chat in conversation_log.get("chat", [])[-4:]:
-           if chat["by"] == "khoj":
+           if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image":
                chat_history += f"Q: {chat['intent']['query']}\n"
                chat_history += f"A: {chat['message']}\n"

View file

@@ -12,11 +12,12 @@ def download_model(model_name: str):
        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
        raise e

-   # Download the chat model
-   chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
-
    # Decide whether to load model to GPU or CPU
+   chat_model_config = None
    try:
+       # Download the chat model and its config
+       chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
+
        # Try load chat model to GPU if:
        # 1. Loading chat model to GPU isn't disabled via CLI and
        # 2. Machine has GPU

@@ -26,6 +27,12 @@ def download_model(model_name: str):
        )
    except ValueError:
        device = "cpu"
+   except Exception as e:
+       if chat_model_config is None:
+           device = "cpu"  # Fallback to CPU as can't determine if GPU has enough memory
+           logger.debug(f"Unable to download model config from gpt4all website: {e}")
+       else:
+           raise e

    # Now load the downloaded chat model onto appropriate device
    chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)

View file

@@ -41,7 +41,7 @@ def extract_questions(
        [
            f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
            for chat in conversation_log.get("chat", [])[-4:]
-           if chat["by"] == "khoj"
+           if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image"
        ]
    )

@@ -123,8 +123,8 @@
def converse(
    references,
-   online_results,
    user_query,
+   online_results=[],
    conversation_log={},
    model: str = "gpt-3.5-turbo",
    api_key: Optional[str] = None,

View file

@@ -36,11 +36,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
@retry(
    retry=(
-       retry_if_exception_type(openai.error.Timeout)
-       | retry_if_exception_type(openai.error.APIError)
-       | retry_if_exception_type(openai.error.APIConnectionError)
-       | retry_if_exception_type(openai.error.RateLimitError)
-       | retry_if_exception_type(openai.error.ServiceUnavailableError)
+       retry_if_exception_type(openai._exceptions.APITimeoutError)
+       | retry_if_exception_type(openai._exceptions.APIError)
+       | retry_if_exception_type(openai._exceptions.APIConnectionError)
+       | retry_if_exception_type(openai._exceptions.RateLimitError)
+       | retry_if_exception_type(openai._exceptions.APIStatusError)
    ),
    wait=wait_random_exponential(min=1, max=10),
    stop=stop_after_attempt(3),

@@ -57,11 +57,11 @@ def completion_with_backoff(**kwargs):
@retry(
    retry=(
-       retry_if_exception_type(openai.error.Timeout)
-       | retry_if_exception_type(openai.error.APIError)
-       | retry_if_exception_type(openai.error.APIConnectionError)
-       | retry_if_exception_type(openai.error.RateLimitError)
-       | retry_if_exception_type(openai.error.ServiceUnavailableError)
+       retry_if_exception_type(openai._exceptions.APITimeoutError)
+       | retry_if_exception_type(openai._exceptions.APIError)
+       | retry_if_exception_type(openai._exceptions.APIConnectionError)
+       | retry_if_exception_type(openai._exceptions.RateLimitError)
+       | retry_if_exception_type(openai._exceptions.APIStatusError)
    ),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    stop=stop_after_attempt(3),
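The 1.x SDK also re-exports these exception classes at the top level of the openai package, so an equivalent way to express the same retry policy without reaching into the private _exceptions module would be (illustrative sketch, not part of the commit):

```python
import openai
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential

@retry(
    retry=(
        retry_if_exception_type(openai.APITimeoutError)
        | retry_if_exception_type(openai.APIError)
        | retry_if_exception_type(openai.APIConnectionError)
        | retry_if_exception_type(openai.RateLimitError)
        | retry_if_exception_type(openai.APIStatusError)
    ),
    wait=wait_random_exponential(min=1, max=10),
    stop=stop_after_attempt(3),
)
def call_openai_with_backoff(client, **kwargs):
    # Illustrative wrapper: retries transient OpenAI failures, mirroring
    # the completion_with_backoff decorator above.
    return client.chat.completions.create(**kwargs)
```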

View file

@@ -3,13 +3,13 @@ from io import BufferedReader
# External Packages
from asgiref.sync import sync_to_async
-import openai
+from openai import OpenAI


-async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
+async def transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str:
    """
    Transcribe audio file using Whisper model via OpenAI's API
    """
    # Send the audio data to the Whisper API
-   response = await sync_to_async(openai.Audio.translate)(model=model, file=audio_file, api_key=api_key)
-   return response["text"]
+   response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file)
+   return response.text

View file

@@ -15,6 +15,7 @@ You were created by Khoj Inc. with the following capabilities:
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
+- Users can share information with you using the Khoj app, which is available for download at https://khoj.dev/downloads.

Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev.
Today is {current_date} in UTC.
@@ -109,6 +110,18 @@
""".strip()
)

+## Image Generation
+## --
+
+image_generation_improve_prompt = PromptTemplate.from_template(
+    """
+Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation.
+
+Query: {query}
+
+Improved Query:"""
+)
+

## Online Search Conversation
## --
online_search_conversation = PromptTemplate.from_template(
@@ -295,10 +308,13 @@
# --
help_message = PromptTemplate.from_template(
    """
-**/notes**: Chat using the information in your knowledge base.
-**/general**: Chat using just Khoj's general knowledge. This will not search against your notes.
-**/default**: Chat using your knowledge base and Khoj's general knowledge for context.
-**/help**: Show this help message.
+- **/notes**: Chat using the information in your knowledge base.
+- **/general**: Chat using just Khoj's general knowledge. This will not search against your notes.
+- **/default**: Chat using your knowledge base and Khoj's general knowledge for context.
+- **/online**: Chat using the internet as a source of information.
+- **/image**: Generate an image based on your message.
+- **/help**: Show this help message.

You are using the **{model}** model on the **{device}**.
**version**: {version}
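Since these prompts are langchain PromptTemplates, the new image prompt is rendered the same way as the existing ones; a small illustrative sketch (the query string is made up):

```python
from khoj.processor.conversation import prompts

# Fill the template before sending it to the chat model for query improvement.
improve_prompt = prompts.image_generation_improve_prompt.format(
    query="paint a watercolor of the Dolomites at sunrise"
)
print(improve_prompt)
```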

View file

@@ -4,6 +4,7 @@ from time import perf_counter
import json
from datetime import datetime
import queue
+from typing import Any, Dict, List
import tiktoken

# External packages
@@ -11,6 +12,8 @@ from langchain.schema import ChatMessage
from transformers import AutoTokenizer

# Internal Packages
+from khoj.database.adapters import ConversationAdapters
+from khoj.database.models import KhojUser
from khoj.utils.helpers import merge_dicts
@@ -89,6 +92,32 @@
    return conversation_log


+def save_to_conversation_log(
+    q: str,
+    chat_response: str,
+    user: KhojUser,
+    meta_log: Dict,
+    user_message_time: str = None,
+    compiled_references: List[str] = [],
+    online_results: Dict[str, Any] = {},
+    inferred_queries: List[str] = [],
+    intent_type: str = "remember",
+):
+    user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    updated_conversation = message_to_log(
+        user_message=q,
+        chat_response=chat_response,
+        user_message_metadata={"created": user_message_time},
+        khoj_message_metadata={
+            "context": compiled_references,
+            "intent": {"inferred-queries": inferred_queries, "type": intent_type},
+            "onlineContext": online_results,
+        },
+        conversation_log=meta_log.get("chat", []),
+    )
+    ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
+
+
def generate_chatml_messages_with_context(
    user_message,
    system_message,
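The /api/chat endpoint below wraps this synchronous helper with sync_to_async before logging generated images; a hedged sketch of that call shape (argument values are illustrative):

```python
from asgiref.sync import sync_to_async

from khoj.database.models import KhojUser
from khoj.processor.conversation.utils import save_to_conversation_log

async def log_generated_image(q: str, image_b64: str, user: KhojUser, meta_log: dict, improved_prompt: str):
    # Illustrative wrapper: mirrors how the chat API records a text-to-image
    # response, storing the improved prompt as the inferred query.
    await sync_to_async(save_to_conversation_log)(
        q,
        image_b64,
        user,
        meta_log,
        intent_type="text-to-image",
        inferred_queries=[improved_prompt],
    )
```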

View file

@@ -19,7 +19,7 @@ from starlette.authentication import requires
from khoj.configure import configure_server
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_default_search_model
-from khoj.database.models import ChatModelOptions
+from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
from khoj.database.models import Entry as DbEntry
from khoj.database.models import (
    GithubConfig,
@@ -35,15 +35,18 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.processor.conversation.prompts import help_message, no_entries_found
+from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.tools.online_search import search_with_google
from khoj.routers.helpers import (
    ApiUserRateLimiter,
    CommonQueryParams,
    agenerate_chat_response,
    get_conversation_command,
+   text_to_image,
    is_ready_to_chat,
    update_telemetry_state,
    validate_conversation_config,
+   ConversationCommandRateLimiter,
)
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
@@ -65,6 +68,7 @@ from khoj.utils.state import SearchType
# Initialize Router
api = APIRouter()
logger = logging.getLogger(__name__)
+conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100)


def map_config_to_object(content_source: str):
@@ -603,7 +607,13 @@ async def chat_options(
@api.post("/transcribe")
@requires(["authenticated"])
-async def transcribe(request: Request, common: CommonQueryParams, file: UploadFile = File(...)):
+async def transcribe(
+    request: Request,
+    common: CommonQueryParams,
+    file: UploadFile = File(...),
+    rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60)),
+    rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
+):
    user: KhojUser = request.user.object
    audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
    user_message: str = None
@ -623,17 +633,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi
# Send the audio data to the Whisper API # Send the audio data to the Whisper API
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config() speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
if not speech_to_text_config: if not speech_to_text_config:
# If the user has not configured a speech to text model, return an unprocessable entity error # If the user has not configured a speech to text model, return a 501 Not Implemented error
status_code = 422 status_code = 501
elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI: elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
api_key = openai_chat_config.api_key
speech2text_model = speech_to_text_config.model_name speech2text_model = speech_to_text_config.model_name
user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key) user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client)
elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE: elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
speech2text_model = speech_to_text_config.model_name speech2text_model = speech_to_text_config.model_name
user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model) user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
finally: finally:
# Close and Delete the temporary audio file # Close and Delete the temporary audio file
audio_file.close() audio_file.close()
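For illustration, a minimal sketch of how the updated transcribe_audio helper might invoke the openai>=1.0 client; the real implementation lives in khoj/processor/conversation/openai/whisper.py and may differ in details such as error handling.

from io import BufferedReader

from asgiref.sync import sync_to_async
from openai import OpenAI


async def transcribe_audio(audio_file: BufferedReader, model: str, client: OpenAI) -> str:
    # openai>=1.0 exposes Whisper through client.audio.transcriptions.create
    transcription = await sync_to_async(client.audio.transcriptions.create)(model=model, file=audio_file)
    return transcription.text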
@@ -666,11 +674,13 @@ async def chat(
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)), rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)), rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
) -> Response: ) -> Response:
user = request.user.object user: KhojUser = request.user.object
await is_ready_to_chat(user) await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True) conversation_command = get_conversation_command(query=q, any_references=True)
conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
q = q.replace(f"/{conversation_command.value}", "").strip() q = q.replace(f"/{conversation_command.value}", "").strip()
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
@@ -704,6 +714,27 @@ async def chat(
media_type="text/event-stream", media_type="text/event-stream",
status_code=200, status_code=200,
) )
elif conversation_command == ConversationCommand.Image:
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_command.value},
**common.__dict__,
)
image, status_code, improved_image_prompt = await text_to_image(q)
if image is None:
content_obj = {
"image": image,
"intentType": "text-to-image",
"detail": "Failed to generate image. Make sure your image generation configuration is set.",
}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
await sync_to_async(save_to_conversation_log)(
q, image, user, meta_log, intent_type="text-to-image", inferred_queries=[improved_image_prompt]
)
content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]} # type: ignore
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
# Get the (streamed) chat response from the LLM of choice. # Get the (streamed) chat response from the LLM of choice.
llm_response, chat_metadata = await agenerate_chat_response( llm_response, chat_metadata = await agenerate_chat_response(
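As a usage sketch (not part of this commit), a client could send an /image query and decode the JSON payload the endpoint now returns. The port below is Khoj's default and authentication is omitted; a real deployment would need to supply session or token credentials.

import base64

import requests


def request_image(prompt: str, base_url: str = "http://localhost:42110") -> None:
    # /image queries are answered with application/json instead of an event stream
    response = requests.get(f"{base_url}/api/chat", params={"q": f"/image {prompt}"})
    payload = response.json()
    if response.status_code != 200 or not payload.get("image"):
        print(payload.get("detail", "Image generation failed"))
        return
    # The image field holds a base64 encoded PNG since the server requests response_format="b64_json"
    with open("khoj-image.png", "wb") as f:
        f.write(base64.b64decode(payload["image"]))
    print("Inferred query:", payload.get("inferredQueries", [None])[0])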
@@ -787,7 +818,6 @@ async def extract_references_and_questions(
conversation_config = await ConversationAdapters.aget_conversation_config(user) conversation_config = await ConversationAdapters.aget_conversation_config(user)
if conversation_config is None: if conversation_config is None:
conversation_config = await ConversationAdapters.aget_default_conversation_config() conversation_config = await ConversationAdapters.aget_default_conversation_config()
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
if ( if (
offline_chat_config offline_chat_config
and offline_chat_config.enabled and offline_chat_config.enabled
@@ -804,7 +834,7 @@ async def extract_references_and_questions(
inferred_queries = extract_questions_offline( inferred_queries = extract_questions_offline(
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
) )
elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = await ConversationAdapters.get_openai_chat_config() openai_chat_config = await ConversationAdapters.get_openai_chat_config()
openai_chat = await ConversationAdapters.get_openai_chat() openai_chat = await ConversationAdapters.get_openai_chat()
api_key = openai_chat_config.api_key api_key = openai_chat_config.api_key
View file
@@ -9,23 +9,23 @@ from functools import partial
from time import time from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
# External Packages
from fastapi import Depends, Header, HTTPException, Request, UploadFile from fastapi import Depends, Header, HTTPException, Request, UploadFile
import openai
from starlette.authentication import has_required_scope from starlette.authentication import has_required_scope
from asgiref.sync import sync_to_async
# Internal Packages
from khoj.database.adapters import ConversationAdapters, EntryAdapters from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import KhojUser, Subscription from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
# Internal Packages
from khoj.utils import state from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry from khoj.utils.helpers import ConversationCommand, log_telemetry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1) executor = ThreadPoolExecutor(max_workers=1)
@@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.General return ConversationCommand.General
elif query.startswith("/online"): elif query.startswith("/online"):
return ConversationCommand.Online return ConversationCommand.Online
elif query.startswith("/image"):
return ConversationCommand.Image
# If no relevant notes found for the given query # If no relevant notes found for the given query
elif not any_references: elif not any_references:
return ConversationCommand.General return ConversationCommand.General
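A quick illustration of the new routing, assuming the khoj package from this commit is importable:

from khoj.routers.helpers import get_conversation_command
from khoj.utils.helpers import ConversationCommand

# "/image ..." now maps to the new Image command; "/online ..." keeps mapping to Online
assert get_conversation_command("/image a fox in watercolor") == ConversationCommand.Image
assert get_conversation_command("/online latest khoj release", any_references=True) == ConversationCommand.Online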
@@ -144,6 +146,20 @@ async def generate_online_subqueries(q: str) -> List[str]:
return [q] return [q]
async def generate_better_image_prompt(q: str) -> str:
"""
Generate a better image prompt from the given query
"""
image_prompt = prompts.image_generation_improve_prompt.format(
query=q,
)
response = await send_message_to_model_wrapper(image_prompt)
return response.strip()
async def send_message_to_model_wrapper( async def send_message_to_model_wrapper(
message: str, message: str,
): ):
@@ -168,11 +184,13 @@ async def send_message_to_model_wrapper(
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config() openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
api_key = openai_chat_config.api_key api_key = openai_chat_config.api_key
chat_model = conversation_config.chat_model chat_model = conversation_config.chat_model
return send_message_to_model( openai_response = send_message_to_model(
message=message, message=message,
api_key=api_key, api_key=api_key,
model=chat_model, model=chat_model,
) )
return openai_response.content
else: else:
raise HTTPException(status_code=500, detail="Invalid conversation config") raise HTTPException(status_code=500, detail="Invalid conversation config")
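The extra .content access reflects the openai>=1.0 response shape; a standalone snippet (the model name is just an example) showing where the generated text lives:

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
completion = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "Summarize this note in one line."}],
)
message = completion.choices[0].message  # a ChatCompletionMessage object, not a plain string
print(message.content)  # the generated text is on the .content attribute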
@@ -186,30 +204,7 @@ def generate_chat_response(
conversation_command: ConversationCommand = ConversationCommand.Default, conversation_command: ConversationCommand = ConversationCommand.Default,
user: KhojUser = None, user: KhojUser = None,
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
def _save_to_conversation_log(
q: str,
chat_response: str,
user_message_time: str,
compiled_references: List[str],
online_results: Dict[str, Any],
inferred_queries: List[str],
meta_log,
):
updated_conversation = message_to_log(
user_message=q,
chat_response=chat_response,
user_message_metadata={"created": user_message_time},
khoj_message_metadata={
"context": compiled_references,
"intent": {"inferred-queries": inferred_queries},
"onlineContext": online_results,
},
conversation_log=meta_log.get("chat", []),
)
ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
# Initialize Variables # Initialize Variables
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
chat_response = None chat_response = None
logger.debug(f"Conversation Type: {conversation_command.name}") logger.debug(f"Conversation Type: {conversation_command.name}")
@@ -217,13 +212,13 @@ def generate_chat_response(
try: try:
partial_completion = partial( partial_completion = partial(
_save_to_conversation_log, save_to_conversation_log,
q, q,
user_message_time=user_message_time, user=user,
meta_log=meta_log,
compiled_references=compiled_references, compiled_references=compiled_references,
online_results=online_results, online_results=online_results,
inferred_queries=inferred_queries, inferred_queries=inferred_queries,
meta_log=meta_log,
) )
conversation_config = ConversationAdapters.get_valid_conversation_config(user) conversation_config = ConversationAdapters.get_valid_conversation_config(user)
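The partial call pre-binds everything except the generated response, so the chat model's completion callback only has to supply the final text. A toy equivalent with a stubbed logger:

from functools import partial


def save_log(q, chat_response, user=None, compiled_references=None,
             online_results=None, inferred_queries=None, meta_log=None):
    # stand-in for save_to_conversation_log
    print(f"saving {q!r} -> {chat_response!r}")


completion_func = partial(
    save_log,
    "what is khoj?",
    user="alice",
    compiled_references=[],
    online_results={},
    inferred_queries=[],
    meta_log={},
)
completion_func(chat_response="Khoj is an open source AI copilot.")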
@@ -251,9 +246,9 @@ def generate_chat_response(
chat_model = conversation_config.chat_model chat_model = conversation_config.chat_model
chat_response = converse( chat_response = converse(
compiled_references, compiled_references,
online_results,
q, q,
meta_log, online_results=online_results,
conversation_log=meta_log,
model=chat_model, model=chat_model,
api_key=api_key, api_key=api_key,
completion_func=partial_completion, completion_func=partial_completion,
@@ -271,6 +266,29 @@ def generate_chat_response(
return chat_response, metadata return chat_response, metadata
async def text_to_image(message: str) -> Tuple[Optional[str], int, Optional[str]]:
status_code = 200
image = None
improved_image_prompt = None
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
if not text_to_image_config:
# If the user has not configured a text to image model, return a 501 Not Implemented error
status_code = 501
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
text2image_model = text_to_image_config.model_name
improved_image_prompt = await generate_better_image_prompt(message)
try:
response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
)
image = response.data[0].b64_json
except openai.OpenAIError as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
status_code = 500
return image, status_code, improved_image_prompt
class ApiUserRateLimiter: class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int): def __init__(self, requests: int, subscribed_requests: int, window: int):
self.requests = requests self.requests = requests
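The images.generate call in text_to_image above uses the openai>=1.0 Images API; a self-contained sketch of the same call outside the server (the model name and output path are arbitrary examples):

import base64

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
response = client.images.generate(
    model="dall-e-3",
    prompt="A watercolor painting of a fox reading in a library",
    response_format="b64_json",  # return base64 image data instead of a hosted URL
)
with open("generated.png", "wb") as f:
    f.write(base64.b64decode(response.data[0].b64_json))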
@@ -298,6 +316,40 @@ class ApiUserRateLimiter:
user_requests.append(time()) user_requests.append(time())
class ConversationCommandRateLimiter:
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int):
self.cache: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
self.trial_rate_limit = trial_rate_limit
self.subscribed_rate_limit = subscribed_rate_limit
self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image]
def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
if state.billing_enabled is False:
return
if not request.user.is_authenticated:
return
if conversation_command not in self.restricted_commands:
return
user: KhojUser = request.user.object
user_cache = self.cache[user.uuid]
subscribed = has_required_scope(request, ["premium"])
user_cache[conversation_command].append(time())
# Remove requests outside of the 24-hr time window
cutoff = time() - 60 * 60 * 24
while user_cache[conversation_command] and user_cache[conversation_command][0] < cutoff:
user_cache[conversation_command].pop(0)
if subscribed and len(user_cache[conversation_command]) > self.subscribed_rate_limit:
raise HTTPException(status_code=429, detail="Too Many Requests")
if not subscribed and len(user_cache[conversation_command]) > self.trial_rate_limit:
raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
return
class ApiIndexedDataLimiter: class ApiIndexedDataLimiter:
def __init__( def __init__(
self, self,
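To see the limiter's behaviour in isolation, a toy exercise, assuming the khoj package from this commit is importable; the fake request only carries the attributes the limiter actually reads:

from types import SimpleNamespace
from uuid import uuid4

from fastapi import HTTPException

from khoj.routers.helpers import ConversationCommandRateLimiter
from khoj.utils import state
from khoj.utils.helpers import ConversationCommand

state.billing_enabled = True  # limits are only enforced when billing is enabled
limiter = ConversationCommandRateLimiter(trial_rate_limit=2, subscribed_rate_limit=100)

fake_request = SimpleNamespace(
    user=SimpleNamespace(is_authenticated=True, object=SimpleNamespace(uuid=str(uuid4()))),
    auth=SimpleNamespace(scopes=[]),  # no "premium" scope, so the trial limit applies
)

for attempt in range(4):
    try:
        limiter.update_and_check_if_valid(fake_request, ConversationCommand.Image)
        print(f"request {attempt + 1}: allowed")
    except HTTPException as exc:
        print(f"request {attempt + 1}: blocked with {exc.status_code}")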
@@ -315,7 +367,7 @@ class ApiIndexedDataLimiter:
if state.billing_enabled is False: if state.billing_enabled is False:
return return
subscribed = has_required_scope(request, ["premium"]) subscribed = has_required_scope(request, ["premium"])
incoming_data_size_mb = 0 incoming_data_size_mb = 0.0
deletion_file_names = set() deletion_file_names = set()
if not request.user.is_authenticated: if not request.user.is_authenticated:
View file
@@ -273,6 +273,7 @@ class ConversationCommand(str, Enum):
Notes = "notes" Notes = "notes"
Help = "help" Help = "help"
Online = "online" Online = "online"
Image = "image"
command_descriptions = { command_descriptions = {
@@ -280,6 +281,7 @@ command_descriptions = {
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.", ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.", ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
ConversationCommand.Online: "Look up information on the internet.", ConversationCommand.Online: "Look up information on the internet.",
ConversationCommand.Image: "Generate images by describing your imagination in words.",
ConversationCommand.Help: "Display a help message with all available commands and other metadata.", ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
} }
View file
@@ -7,6 +7,7 @@ from khoj.database.models import (
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
ChatModelOptions, ChatModelOptions,
SpeechToTextModelOptions, SpeechToTextModelOptions,
TextToImageModelConfig,
) )
from khoj.utils.constants import default_offline_chat_model, default_online_chat_model from khoj.utils.constants import default_offline_chat_model, default_online_chat_model
@@ -103,6 +104,15 @@ def initialization():
model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
) )
default_text_to_image_model = "dall-e-3"
openai_text_to_image_model = input(
f"Enter the OpenAI text to image model you want to use (default: {default_text_to_image_model}): "
)
openai_text_to_image_model = openai_text_to_image_model or default_text_to_image_model
TextToImageModelConfig.objects.create(
model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI
)
if use_offline_model == "y" or use_openai_model == "y": if use_offline_model == "y" or use_openai_model == "y":
logger.info("🗣️ Chat model configuration complete") logger.info("🗣️ Chat model configuration complete")
View file
@@ -22,17 +22,9 @@ class BaseEncoder(ABC):
class OpenAI(BaseEncoder): class OpenAI(BaseEncoder):
def __init__(self, model_name, device=None): def __init__(self, model_name, client: openai.OpenAI, device=None):
self.model_name = model_name self.model_name = model_name
if ( self.openai_client = client
not state.processor_config
or not state.processor_config.conversation
or not state.processor_config.conversation.openai_model
):
raise Exception(
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
)
openai.api_key = state.processor_config.conversation.openai_model.api_key
self.embedding_dimensions = None self.embedding_dimensions = None
def encode(self, entries, device=None, **kwargs): def encode(self, entries, device=None, **kwargs):
@@ -43,7 +35,7 @@ class OpenAI(BaseEncoder):
processed_entry = entries[index].replace("\n", " ") processed_entry = entries[index].replace("\n", " ")
try: try:
response = openai.Embedding.create(input=processed_entry, model=self.model_name) response = self.openai_client.embeddings.create(input=processed_entry, model=self.model_name)
embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)] embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)]
# Use current models embedding dimension, once available # Use current models embedding dimension, once available
# Else default to embedding dimensions of the text-embedding-ada-002 model # Else default to embedding dimensions of the text-embedding-ada-002 model
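Hypothetical usage of the refactored encoder with an injected client; the import path below is a guess since the file name is not visible in this diff, and the embedding model name is just an example:

import openai as openai_sdk

from khoj.utils.models import OpenAI  # assumed location of the encoder class shown above

client = openai_sdk.OpenAI()  # reads OPENAI_API_KEY from the environment
encoder = OpenAI(model_name="text-embedding-ada-002", client=client)
embeddings = encoder.encode(["Khoj indexes notes for semantic search."])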
View file
@@ -1,15 +1,16 @@
# Standard Packages # Standard Packages
from collections import defaultdict
import os import os
from pathlib import Path
import threading import threading
from typing import List, Dict from typing import List, Dict
from collections import defaultdict
# External Packages # External Packages
from pathlib import Path from openai import OpenAI
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from whisper import Whisper from whisper import Whisper
# Internal Packages # Internal Packages
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.utils import config as utils_config from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
from khoj.utils.helpers import LRU, get_device from khoj.utils.helpers import LRU, get_device
@@ -21,6 +22,7 @@ search_models = SearchModels()
embeddings_model: Dict[str, EmbeddingsModel] = None embeddings_model: Dict[str, EmbeddingsModel] = None
cross_encoder_model: CrossEncoderModel = None cross_encoder_model: CrossEncoderModel = None
content_index = ContentIndex() content_index = ContentIndex()
openai_client: OpenAI = None
gpt4all_processor_config: GPT4AllProcessorModel = None gpt4all_processor_config: GPT4AllProcessorModel = None
whisper_model: Whisper = None whisper_model: Whisper = None
config_file: Path = None config_file: Path = None
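The code that populates state.openai_client is not part of this excerpt; a plausible wiring, assuming the key comes from the OpenAI conversation config, might look like this sketch:

from typing import Optional

from openai import OpenAI

from khoj.utils import state


def configure_openai_client(api_key: str, base_url: Optional[str] = None) -> None:
    # One shared client reused for chat, Whisper transcription and image generation
    state.openai_client = OpenAI(api_key=api_key, base_url=base_url)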
View file
@@ -68,10 +68,10 @@ def test_chat_with_online_content(chat_client):
response_message = response_message.split("### compiled references")[0] response_message = response_message.split("### compiled references")[0]
# Assert # Assert
expected_responses = ["http://www.paulgraham.com/greatwork.html"] expected_responses = ["http://www.paulgraham.com/greatwork.html", "Please set your SERPER_DEV_API_KEY"]
assert response.status_code == 200 assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), ( assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected assistants name, [K|k]hoj, in response but got: " + response_message "Expected links or serper not setup in response but got: " + response_message
) )