Rebase with master

This commit is contained in:
sabaimran 2023-12-19 21:02:49 +05:30
commit 6dd2b05bf5
25 changed files with 797 additions and 341 deletions

View file

@ -42,7 +42,7 @@ dependencies = [
"fastapi >= 0.104.1",
"python-multipart >= 0.0.5",
"jinja2 == 3.1.2",
"openai >= 0.27.0, < 1.0.0",
"openai >= 1.0.0",
"tiktoken >= 0.3.2",
"tenacity >= 8.2.2",
"pillow ~= 9.5.0",

View file

@ -179,7 +179,18 @@
return numOnlineReferences;
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
imageMarkdown += "\n\n";
if (inferredQueries) {
const inferredQuery = inferredQueries?.[0];
imageMarkdown += `**Inferred Query**: ${inferredQuery}`;
}
renderMessage(imageMarkdown, by, dt);
return;
}
if (context == null && onlineContext == null) {
renderMessage(message, by, dt);
return;
@ -244,6 +255,17 @@
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
// Customize the rendering of images
md.renderer.rules.image = function(tokens, idx, options, env, self) {
let token = tokens[idx];
// Add class="text-to-image" to images
token.attrPush(['class', 'text-to-image']);
// Use the default renderer to render image markdown format
return self.renderToken(tokens, idx, options);
};
// Render markdown
newHTML = md.render(newHTML);
// Get any elements with a class that starts with "language"
@ -328,109 +350,153 @@
let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled");
// Call specified Khoj API which returns a streamed response of type text/plain
fetch(url, { headers })
.then(response => {
const reader = response.body.getReader();
const decoder = new TextDecoder();
let rawResponse = "";
let references = null;
// Call specified Khoj API
let response = await fetch(url, { headers });
let rawResponse = "";
const contentType = response.headers.get("content-type");
function readStream() {
reader.read().then(({ done, value }) => {
if (done) {
// Append any references after all the data has been streamed
if (references != null) {
newResponseText.appendChild(references);
}
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
return;
if (contentType === "application/json") {
// Handle JSON response
try {
const responseAsJson = await response.json();
if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
rawResponse += "\n\n";
const inferredQueries = responseAsJson.inferredQueries?.[0];
if (inferredQueries) {
rawResponse += `**Inferred Query**: ${inferredQueries}`;
}
}
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
}
} else {
// Handle streamed response of type text/event-stream or text/plain
const reader = response.body.getReader();
const decoder = new TextDecoder();
let references = null;
readStream();
function readStream() {
reader.read().then(({ done, value }) => {
if (done) {
// Append any references after all the data has been streamed
if (references != null) {
newResponseText.appendChild(references);
}
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
return;
}
// Decode message chunk from stream
const chunk = decoder.decode(value, { stream: true });
if (chunk.includes("### compiled references:")) {
const additionalResponse = chunk.split("### compiled references:")[0];
rawResponse += additionalResponse;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
const rawReference = chunk.split("### compiled references:")[1];
const rawReferenceAsJson = JSON.parse(rawReference);
references = document.createElement('div');
references.classList.add("references");
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
let referenceSection = document.createElement('div');
referenceSection.classList.add("reference-section");
referenceSection.classList.add("collapsed");
let numReferences = 0;
// If rawReferenceAsJson is a list, then count the length
if (Array.isArray(rawReferenceAsJson)) {
numReferences = rawReferenceAsJson.length;
rawReferenceAsJson.forEach((reference, index) => {
let polishedReference = generateReference(reference, index);
referenceSection.appendChild(polishedReference);
});
} else {
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
}
// Decode message chunk from stream
const chunk = decoder.decode(value, { stream: true });
references.appendChild(referenceExpandButton);
if (chunk.includes("### compiled references:")) {
const additionalResponse = chunk.split("### compiled references:")[0];
rawResponse += additionalResponse;
referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) {
referenceSection.classList.remove("collapsed");
referenceSection.classList.add("expanded");
} else {
referenceSection.classList.add("collapsed");
referenceSection.classList.remove("expanded");
}
});
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
readStream();
} else {
// Display response from Khoj
if (newResponseText.getElementsByClassName("spinner").length > 0) {
newResponseText.removeChild(loadingSpinner);
}
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
if (chunk.startsWith("{") && chunk.endsWith("}")) {
try {
const responseAsJson = JSON.parse(chunk);
if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
rawResponse += "\n\n";
const inferredQueries = responseAsJson.inferredQueries?.[0];
if (inferredQueries) {
rawResponse += `**Inferred Query**: ${inferredQueries}`;
}
}
if (responseAsJson.detail) {
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
}
} else {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
const rawReference = chunk.split("### compiled references:")[1];
const rawReferenceAsJson = JSON.parse(rawReference);
references = document.createElement('div');
references.classList.add("references");
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
let referenceSection = document.createElement('div');
referenceSection.classList.add("reference-section");
referenceSection.classList.add("collapsed");
let numReferences = 0;
// If rawReferenceAsJson is a list, then count the length
if (Array.isArray(rawReferenceAsJson)) {
numReferences = rawReferenceAsJson.length;
rawReferenceAsJson.forEach((reference, index) => {
let polishedReference = generateReference(reference, index);
referenceSection.appendChild(polishedReference);
});
} else {
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
}
references.appendChild(referenceExpandButton);
referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) {
referenceSection.classList.remove("collapsed");
referenceSection.classList.add("expanded");
} else {
referenceSection.classList.add("collapsed");
referenceSection.classList.remove("expanded");
}
});
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
readStream();
} else {
// Display response from Khoj
if (newResponseText.getElementsByClassName("spinner").length > 0) {
newResponseText.removeChild(loadingSpinner);
}
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
if (chunk.startsWith("{") && chunk.endsWith("}")) {
try {
const responseAsJson = JSON.parse(chunk);
if (responseAsJson.detail) {
newResponseText.innerHTML += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
newResponseText.innerHTML += chunk;
}
} else {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
readStream();
}
}
}
// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
});
}
readStream();
});
// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
});
}
}
}
function incrementalChat(event) {
@ -522,7 +588,7 @@
.then(response => {
// Render conversation history, if any
response.forEach(chat_log => {
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]);
});
})
.catch(err => {
@ -625,9 +691,13 @@
.then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => { chatInput.value += data.text; })
.catch(err => {
err.status == 422
? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
: flashStatusInChatInput("⛔️ Failed to transcribe audio")
if (err.status === 501) {
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
} else if (err.status === 422) {
flashStatusInChatInput("⛔️ Audio file to large to process.")
} else {
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
}
});
};
@ -810,6 +880,9 @@
margin-top: -10px;
transform: rotate(-60deg)
}
img.text-to-image {
max-width: 60%;
}
#chat-footer {
padding: 0;
@ -846,11 +919,12 @@
}
.input-row-button {
background: var(--background-color);
border: none;
border: 1px solid var(--main-text-color);
box-shadow: 0 0 11px #aaa;
border-radius: 5px;
padding: 5px;
font-size: 14px;
font-weight: 300;
padding: 0;
line-height: 1.5em;
cursor: pointer;
transition: background 0.3s ease-in-out;
@ -932,7 +1006,6 @@
color: var(--main-text-color);
border: 1px solid var(--main-text-color);
border-radius: 5px;
padding: 5px;
font-size: 14px;
font-weight: 300;
line-height: 1.5em;
@ -1050,6 +1123,9 @@
margin: 4px;
grid-template-columns: auto;
}
img.text-to-image {
max-width: 100%;
}
}
@media only screen and (min-width: 600px) {
body {

View file

@ -2,6 +2,12 @@ import { App, MarkdownRenderer, Modal, request, requestUrl, setIcon } from 'obsi
import { KhojSetting } from 'src/settings';
import fetch from "node-fetch";
export interface ChatJsonResult {
image?: string;
detail?: string;
}
export class KhojChatModal extends Modal {
result: string;
setting: KhojSetting;
@ -105,15 +111,19 @@ export class KhojChatModal extends Modal {
return referenceButton;
}
renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date) {
renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string) {
if (!message) {
return;
} else if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
this.renderMessage(chatEl, imageMarkdown, sender, dt);
return;
} else if (!context) {
this.renderMessage(chatEl, message, sender, dt);
return
return;
} else if (!!context && context?.length === 0) {
this.renderMessage(chatEl, message, sender, dt);
return
return;
}
let chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
let chatMessageBodyEl = chatMessageEl.getElementsByClassName("khoj-chat-message-text")[0]
@ -225,7 +235,7 @@ export class KhojChatModal extends Modal {
let response = await request({ url: chatUrl, headers: headers });
let chatLogs = JSON.parse(response).response;
chatLogs.forEach((chatLog: any) => {
this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created));
this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created), chatLog.intent?.type);
});
}
@ -266,8 +276,25 @@ export class KhojChatModal extends Modal {
this.result = "";
responseElement.innerHTML = "";
if (response.headers.get("content-type") == "application/json") {
let responseText = ""
try {
const responseAsJson = await response.json() as ChatJsonResult;
if (responseAsJson.image) {
responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`;
} else if (responseAsJson.detail) {
responseText = responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
responseText = response.body.read().toString()
} finally {
this.renderIncrementalMessage(responseElement, responseText);
}
}
for await (const chunk of response.body) {
const responseText = chunk.toString();
let responseText = chunk.toString();
if (responseText.includes("### compiled references:")) {
const additionalResponse = responseText.split("### compiled references:")[0];
this.renderIncrementalMessage(responseElement, additionalResponse);
@ -310,6 +337,12 @@ export class KhojChatModal extends Modal {
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
} else {
if (responseText.startsWith("{") && responseText.endsWith("}")) {
} else {
// If the chunk is not a JSON object, just display it as is
continue;
}
this.renderIncrementalMessage(responseElement, responseText);
}
}
@ -389,10 +422,12 @@ export class KhojChatModal extends Modal {
if (response.status === 200) {
console.log(response);
chatInput.value += response.json.text;
} else if (response.status === 422) {
throw new Error("⛔️ Failed to transcribe audio");
} else {
} else if (response.status === 501) {
throw new Error("⛔️ Configure speech-to-text model on server.");
} else if (response.status === 422) {
throw new Error("⛔️ Audio file to large to process.");
} else {
throw new Error("⛔️ Failed to transcribe audio.");
}
};

View file

@ -217,7 +217,9 @@ button.copy-button:hover {
background: #f5f5f5;
cursor: pointer;
}
img {
max-width: 60%;
}
#khoj-chat-footer {
padding: 0;

View file

@ -33,6 +33,9 @@ ALLOWED_HOSTS = [f".{KHOJ_DOMAIN}", "localhost", "127.0.0.1", "[::1]"]
CSRF_TRUSTED_ORIGINS = [
f"https://*.{KHOJ_DOMAIN}",
f"https://{KHOJ_DOMAIN}",
f"http://*.{KHOJ_DOMAIN}",
f"http://{KHOJ_DOMAIN}",
f"https://app.{KHOJ_DOMAIN}",
]
COOKIE_SAMESITE = "None"
@ -42,6 +45,7 @@ if DEBUG or os.getenv("KHOJ_DOMAIN") == None:
else:
SESSION_COOKIE_DOMAIN = KHOJ_DOMAIN
CSRF_COOKIE_DOMAIN = KHOJ_DOMAIN
SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTOCOL", "https")
SESSION_COOKIE_SECURE = True
CSRF_COOKIE_SECURE = True

View file

@ -7,6 +7,7 @@ import requests
import os
# External Packages
import openai
import schedule
from starlette.middleware.sessions import SessionMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware
@ -22,6 +23,7 @@ from starlette.authentication import (
# Internal Packages
from khoj.database.models import KhojUser, Subscription
from khoj.database.adapters import (
ConversationAdapters,
get_all_users,
get_or_create_search_models,
aget_user_subscription_state,
@ -75,17 +77,18 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst()
)
if user:
if state.billing_enabled:
subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
if not state.billing_enabled:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
# Get bearer token from header
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
@ -97,19 +100,18 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst()
)
if user_with_token:
if state.billing_enabled:
subscription_state = await aget_user_subscription_state(user_with_token.user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(
user_with_token.user
)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
if not state.billing_enabled:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
subscription_state = await aget_user_subscription_state(user_with_token.user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user:
@ -138,6 +140,10 @@ def configure_server(
config = FullConfig()
state.config = config
if ConversationAdapters.has_valid_openai_conversation_config():
openai_config = ConversationAdapters.get_openai_conversation_config()
state.openai_client = openai.OpenAI(api_key=openai_config.api_key)
# Initialize Search Models from Config and initialize content
try:
search_models = get_or_create_search_models()

View file

@ -22,6 +22,7 @@ from khoj.database.models import (
GithubConfig,
GithubRepoConfig,
GoogleUser,
TextToImageModelConfig,
KhojApiUser,
KhojUser,
NotionConfig,
@ -407,7 +408,7 @@ class ConversationAdapters:
)
max_results = 3
all_questions = await sync_to_async(list)(all_questions)
all_questions = await sync_to_async(list)(all_questions) # type: ignore
if len(all_questions) < max_results:
return all_questions
@ -433,6 +434,10 @@ class ConversationAdapters:
else:
raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
@staticmethod
async def aget_text_to_image_model_config():
return await TextToImageModelConfig.objects.filter().afirst()
class EntryAdapters:
word_filer = WordFilter()

View file

@ -1,5 +1,9 @@
import csv
import json
from django.contrib import admin
from django.contrib.auth.admin import UserAdmin
from django.http import HttpResponse
# Register your models here.
@ -13,6 +17,8 @@ from khoj.database.models import (
Subscription,
ReflectiveQuestion,
UserSearchModelConfig,
TextToImageModelConfig,
Conversation,
)
admin.site.register(KhojUser, UserAdmin)
@ -25,3 +31,102 @@ admin.site.register(SearchModelConfig)
admin.site.register(Subscription)
admin.site.register(ReflectiveQuestion)
admin.site.register(UserSearchModelConfig)
admin.site.register(TextToImageModelConfig)
@admin.register(Conversation)
class ConversationAdmin(admin.ModelAdmin):
list_display = (
"id",
"user",
"created_at",
"updated_at",
)
search_fields = ("conversation_id",)
ordering = ("-created_at",)
actions = ["export_selected_objects", "export_selected_minimal_objects"]
def export_selected_objects(self, request, queryset):
response = HttpResponse(content_type="text/csv")
response["Content-Disposition"] = 'attachment; filename="conversations.csv"'
writer = csv.writer(response)
writer.writerow(["id", "user", "created_at", "updated_at", "conversation_log"])
for conversation in queryset:
modified_log = conversation.conversation_log
chat_log = modified_log.get("chat", [])
for idx, log in enumerate(chat_log):
if (
log["by"] == "khoj"
and log["intent"]
and log["intent"]["type"]
and log["intent"]["type"] == "text-to-image"
):
log["message"] = "image redacted for space"
chat_log[idx] = log
modified_log["chat"] = chat_log
writer.writerow(
[
conversation.id,
conversation.user,
conversation.created_at,
conversation.updated_at,
json.dumps(modified_log),
]
)
return response
export_selected_objects.short_description = "Export selected conversations" # type: ignore
def export_selected_minimal_objects(self, request, queryset):
response = HttpResponse(content_type="text/csv")
response["Content-Disposition"] = 'attachment; filename="conversations.csv"'
writer = csv.writer(response)
writer.writerow(["id", "user", "created_at", "updated_at", "conversation_log"])
fields_to_keep = set(["message", "by", "created"])
for conversation in queryset:
return_log = dict()
chat_log = conversation.conversation_log.get("chat", [])
for idx, log in enumerate(chat_log):
updated_log = {}
for key in fields_to_keep:
updated_log[key] = log[key]
if (
log["by"] == "khoj"
and log["intent"]
and log["intent"]["type"]
and log["intent"]["type"] == "text-to-image"
):
updated_log["message"] = "image redacted for space"
chat_log[idx] = updated_log
return_log["chat"] = chat_log
writer.writerow(
[
conversation.id,
conversation.user,
conversation.created_at,
conversation.updated_at,
json.dumps(return_log),
]
)
return response
export_selected_minimal_objects.short_description = "Export selected conversations (minimal)" # type: ignore
def get_actions(self, request):
actions = super().get_actions(request)
if not request.user.is_superuser:
if "export_selected_objects" in actions:
del actions["export_selected_objects"]
if "export_selected_minimal_objects" in actions:
del actions["export_selected_minimal_objects"]
return actions

View file

@ -0,0 +1,25 @@
# Generated by Django 4.2.7 on 2023-12-04 22:17
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0021_speechtotextmodeloptions_and_more"),
]
operations = [
migrations.CreateModel(
name="TextToImageModelConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("model_name", models.CharField(default="dall-e-3", max_length=200)),
("model_type", models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200)),
],
options={
"abstract": False,
},
),
]

View file

@ -112,6 +112,14 @@ class SearchModelConfig(BaseModel):
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
class TextToImageModelConfig(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
model_name = models.CharField(max_length=200, default="dall-e-3")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
class OpenAIProcessorConversationConfig(BaseModel):
api_key = models.CharField(max_length=200)

View file

@ -183,12 +183,23 @@ To get started, just start typing below. You can also type / to see a list of co
referenceSection.appendChild(polishedReference);
}
}
}
}
return numOnlineReferences;
}
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
imageMarkdown += "\n\n";
if (inferredQueries) {
const inferredQuery = inferredQueries?.[0];
imageMarkdown += `**Inferred Query**: ${inferredQuery}`;
}
renderMessage(imageMarkdown, by, dt);
return;
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
if (context == null && onlineContext == null) {
renderMessage(message, by, dt);
return;
@ -253,6 +264,17 @@ To get started, just start typing below. You can also type / to see a list of co
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
// Customize the rendering of images
md.renderer.rules.image = function(tokens, idx, options, env, self) {
let token = tokens[idx];
// Add class="text-to-image" to images
token.attrPush(['class', 'text-to-image']);
// Use the default renderer to render image markdown format
return self.renderToken(tokens, idx, options);
};
// Render markdown
newHTML = md.render(newHTML);
// Get any elements with a class that starts with "language"
@ -292,7 +314,7 @@ To get started, just start typing below. You can also type / to see a list of co
return element
}
function chat() {
async function chat() {
// Extract required fields for search from form
let query = document.getElementById("chat-input").value.trim();
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
@ -333,113 +355,128 @@ To get started, just start typing below. You can also type / to see a list of co
let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled");
// Call specified Khoj API which returns a streamed response of type text/plain
fetch(url)
.then(response => {
const reader = response.body.getReader();
const decoder = new TextDecoder();
let rawResponse = "";
let references = null;
// Call specified Khoj API
let response = await fetch(url);
let rawResponse = "";
const contentType = response.headers.get("content-type");
function readStream() {
reader.read().then(({ done, value }) => {
if (done) {
// Append any references after all the data has been streamed
if (references != null) {
newResponseText.appendChild(references);
}
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
return;
}
// Decode message chunk from stream
const chunk = decoder.decode(value, { stream: true });
if (chunk.includes("### compiled references:")) {
const additionalResponse = chunk.split("### compiled references:")[0];
rawResponse += additionalResponse;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
const rawReference = chunk.split("### compiled references:")[1];
const rawReferenceAsJson = JSON.parse(rawReference);
references = document.createElement('div');
references.classList.add("references");
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
let referenceSection = document.createElement('div');
referenceSection.classList.add("reference-section");
referenceSection.classList.add("collapsed");
let numReferences = 0;
// If rawReferenceAsJson is a list, then count the length
if (Array.isArray(rawReferenceAsJson)) {
numReferences = rawReferenceAsJson.length;
rawReferenceAsJson.forEach((reference, index) => {
let polishedReference = generateReference(reference, index);
referenceSection.appendChild(polishedReference);
});
} else {
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
}
references.appendChild(referenceExpandButton);
referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) {
referenceSection.classList.remove("collapsed");
referenceSection.classList.add("expanded");
} else {
referenceSection.classList.add("collapsed");
referenceSection.classList.remove("expanded");
}
});
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
readStream();
} else {
// Display response from Khoj
if (newResponseText.getElementsByClassName("spinner").length > 0) {
newResponseText.removeChild(loadingSpinner);
}
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
if (chunk.startsWith("{") && chunk.endsWith("}")) {
try {
const responseAsJson = JSON.parse(chunk);
if (responseAsJson.detail) {
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
}
} else {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
readStream();
}
}
// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
});
if (contentType === "application/json") {
// Handle JSON response
try {
const responseAsJson = await response.json();
if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
rawResponse += "\n\n";
const inferredQueries = responseAsJson.inferredQueries?.[0];
if (inferredQueries) {
rawResponse += `**Inferred Query**: ${inferredQueries}`;
}
}
readStream();
});
}
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
}
} else {
// Handle streamed response of type text/event-stream or text/plain
const reader = response.body.getReader();
const decoder = new TextDecoder();
let references = null;
readStream();
function readStream() {
reader.read().then(({ done, value }) => {
if (done) {
// Append any references after all the data has been streamed
if (references != null) {
newResponseText.appendChild(references);
}
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
return;
}
// Decode message chunk from stream
const chunk = decoder.decode(value, { stream: true });
if (chunk.includes("### compiled references:")) {
const additionalResponse = chunk.split("### compiled references:")[0];
rawResponse += additionalResponse;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
const rawReference = chunk.split("### compiled references:")[1];
const rawReferenceAsJson = JSON.parse(rawReference);
references = document.createElement('div');
references.classList.add("references");
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
let referenceSection = document.createElement('div');
referenceSection.classList.add("reference-section");
referenceSection.classList.add("collapsed");
let numReferences = 0;
// If rawReferenceAsJson is a list, then count the length
if (Array.isArray(rawReferenceAsJson)) {
numReferences = rawReferenceAsJson.length;
rawReferenceAsJson.forEach((reference, index) => {
let polishedReference = generateReference(reference, index);
referenceSection.appendChild(polishedReference);
});
} else {
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
}
references.appendChild(referenceExpandButton);
referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) {
referenceSection.classList.remove("collapsed");
referenceSection.classList.add("expanded");
} else {
referenceSection.classList.add("collapsed");
referenceSection.classList.remove("expanded");
}
});
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
readStream();
} else {
// Display response from Khoj
if (newResponseText.getElementsByClassName("spinner").length > 0) {
newResponseText.removeChild(loadingSpinner);
}
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
readStream();
}
});
// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
};
}
};
function incrementalChat(event) {
if (!event.shiftKey && event.key === 'Enter') {
@ -516,7 +553,7 @@ To get started, just start typing below. You can also type / to see a list of co
.then(response => {
// Render conversation history, if any
response.forEach(chat_log => {
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]);
});
})
.catch(err => {
@ -611,9 +648,15 @@ To get started, just start typing below. You can also type / to see a list of co
.then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => { chatInput.value += data.text; })
.catch(err => {
err.status == 422
? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
: flashStatusInChatInput("⛔️ Failed to transcribe audio")
if (err.status === 501) {
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
} else if (err.status === 422) {
flashStatusInChatInput("⛔️ Audio file to large to process.")
} else if (err.status === 429) {
flashStatusInChatInput("⛔️ " + err.statusText);
} else {
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
}
});
};
@ -902,6 +945,9 @@ To get started, just start typing below. You can also type / to see a list of co
margin-top: -10px;
transform: rotate(-60deg)
}
img.text-to-image {
max-width: 60%;
}
#chat-footer {
padding: 0;
@ -916,7 +962,7 @@ To get started, just start typing below. You can also type / to see a list of co
grid-template-columns: auto 32px 32px;
grid-column-gap: 10px;
grid-row-gap: 10px;
background: #f9fafc
background: var(--background-color);
}
.option:hover {
box-shadow: 0 0 11px #aaa;
@ -938,9 +984,10 @@ To get started, just start typing below. You can also type / to see a list of co
}
.input-row-button {
background: var(--background-color);
border: none;
border: 1px solid var(--main-text-color);
box-shadow: 0 0 11px #aaa;
border-radius: 5px;
padding: 5px;
padding: 0px;
font-size: 14px;
font-weight: 300;
line-height: 1.5em;
@ -1029,6 +1076,9 @@ To get started, just start typing below. You can also type / to see a list of co
margin: 4px;
grid-template-columns: auto;
}
img.text-to-image {
max-width: 100%;
}
}
@media only screen and (min-width: 700px) {
body {

View file

@ -47,7 +47,7 @@ def extract_questions_offline(
if use_history:
for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj":
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image":
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"

View file

@ -12,11 +12,12 @@ def download_model(model_name: str):
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
raise e
# Download the chat model
chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
# Decide whether to load model to GPU or CPU
chat_model_config = None
try:
# Download the chat model and its config
chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
# Try load chat model to GPU if:
# 1. Loading chat model to GPU isn't disabled via CLI and
# 2. Machine has GPU
@ -26,6 +27,12 @@ def download_model(model_name: str):
)
except ValueError:
device = "cpu"
except Exception as e:
if chat_model_config is None:
device = "cpu" # Fallback to CPU as can't determine if GPU has enough memory
logger.debug(f"Unable to download model config from gpt4all website: {e}")
else:
raise e
# Now load the downloaded chat model onto appropriate device
chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)

View file

@ -41,7 +41,7 @@ def extract_questions(
[
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:]
if chat["by"] == "khoj"
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image"
]
)
@ -123,8 +123,8 @@ def send_message_to_model(
def converse(
references,
online_results,
user_query,
online_results=[],
conversation_log={},
model: str = "gpt-3.5-turbo",
api_key: Optional[str] = None,

View file

@ -36,11 +36,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
@retry(
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
retry_if_exception_type(openai._exceptions.APITimeoutError)
| retry_if_exception_type(openai._exceptions.APIError)
| retry_if_exception_type(openai._exceptions.APIConnectionError)
| retry_if_exception_type(openai._exceptions.RateLimitError)
| retry_if_exception_type(openai._exceptions.APIStatusError)
),
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3),
@ -57,11 +57,11 @@ def completion_with_backoff(**kwargs):
@retry(
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
retry_if_exception_type(openai._exceptions.APITimeoutError)
| retry_if_exception_type(openai._exceptions.APIError)
| retry_if_exception_type(openai._exceptions.APIConnectionError)
| retry_if_exception_type(openai._exceptions.RateLimitError)
| retry_if_exception_type(openai._exceptions.APIStatusError)
),
wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(3),

View file

@ -3,13 +3,13 @@ from io import BufferedReader
# External Packages
from asgiref.sync import sync_to_async
import openai
from openai import OpenAI
async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
async def transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str:
"""
Transcribe audio file using Whisper model via OpenAI's API
"""
# Send the audio data to the Whisper API
response = await sync_to_async(openai.Audio.translate)(model=model, file=audio_file, api_key=api_key)
return response["text"]
response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file)
return response.text

View file

@ -15,6 +15,7 @@ You were created by Khoj Inc. with the following capabilities:
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
- Users can share information with you using the Khoj app, which is available for download at https://khoj.dev/downloads.
Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev.
Today is {current_date} in UTC.
@ -109,6 +110,18 @@ Question: {query}
""".strip()
)
## Image Generation
## --
image_generation_improve_prompt = PromptTemplate.from_template(
"""
Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation.
Query: {query}
Improved Query:"""
)
## Online Search Conversation
## --
online_search_conversation = PromptTemplate.from_template(
@ -295,10 +308,13 @@ Q:"""
# --
help_message = PromptTemplate.from_template(
"""
**/notes**: Chat using the information in your knowledge base.
**/general**: Chat using just Khoj's general knowledge. This will not search against your notes.
**/default**: Chat using your knowledge base and Khoj's general knowledge for context.
**/help**: Show this help message.
- **/notes**: Chat using the information in your knowledge base.
- **/general**: Chat using just Khoj's general knowledge. This will not search against your notes.
- **/default**: Chat using your knowledge base and Khoj's general knowledge for context.
- **/online**: Chat using the internet as a source of information.
- **/image**: Generate an image based on your message.
- **/help**: Show this help message.
You are using the **{model}** model on the **{device}**.
**version**: {version}

View file

@ -4,6 +4,7 @@ from time import perf_counter
import json
from datetime import datetime
import queue
from typing import Any, Dict, List
import tiktoken
# External packages
@ -11,6 +12,8 @@ from langchain.schema import ChatMessage
from transformers import AutoTokenizer
# Internal Packages
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import KhojUser
from khoj.utils.helpers import merge_dicts
@ -89,6 +92,32 @@ def message_to_log(
return conversation_log
def save_to_conversation_log(
q: str,
chat_response: str,
user: KhojUser,
meta_log: Dict,
user_message_time: str = None,
compiled_references: List[str] = [],
online_results: Dict[str, Any] = {},
inferred_queries: List[str] = [],
intent_type: str = "remember",
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log(
user_message=q,
chat_response=chat_response,
user_message_metadata={"created": user_message_time},
khoj_message_metadata={
"context": compiled_references,
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
"onlineContext": online_results,
},
conversation_log=meta_log.get("chat", []),
)
ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
def generate_chatml_messages_with_context(
user_message,
system_message,

View file

@ -19,7 +19,7 @@ from starlette.authentication import requires
from khoj.configure import configure_server
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_default_search_model
from khoj.database.models import ChatModelOptions
from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
from khoj.database.models import Entry as DbEntry
from khoj.database.models import (
GithubConfig,
@ -35,15 +35,18 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.tools.online_search import search_with_google
from khoj.routers.helpers import (
ApiUserRateLimiter,
CommonQueryParams,
agenerate_chat_response,
get_conversation_command,
text_to_image,
is_ready_to_chat,
update_telemetry_state,
validate_conversation_config,
ConversationCommandRateLimiter,
)
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
@ -65,6 +68,7 @@ from khoj.utils.state import SearchType
# Initialize Router
api = APIRouter()
logger = logging.getLogger(__name__)
conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100)
def map_config_to_object(content_source: str):
@ -603,7 +607,13 @@ async def chat_options(
@api.post("/transcribe")
@requires(["authenticated"])
async def transcribe(request: Request, common: CommonQueryParams, file: UploadFile = File(...)):
async def transcribe(
request: Request,
common: CommonQueryParams,
file: UploadFile = File(...),
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
):
user: KhojUser = request.user.object
audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
user_message: str = None
@ -623,17 +633,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi
# Send the audio data to the Whisper API
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
if not speech_to_text_config:
# If the user has not configured a speech to text model, return an unprocessable entity error
status_code = 422
elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
api_key = openai_chat_config.api_key
# If the user has not configured a speech to text model, return an unsupported on server error
status_code = 501
elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
speech2text_model = speech_to_text_config.model_name
user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key)
elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE:
user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client)
elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
speech2text_model = speech_to_text_config.model_name
user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model)
user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
finally:
# Close and Delete the temporary audio file
audio_file.close()
@ -666,11 +674,13 @@ async def chat(
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
) -> Response:
user = request.user.object
user: KhojUser = request.user.object
await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True)
conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
q = q.replace(f"/{conversation_command.value}", "").strip()
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
@ -704,6 +714,27 @@ async def chat(
media_type="text/event-stream",
status_code=200,
)
elif conversation_command == ConversationCommand.Image:
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_command.value},
**common.__dict__,
)
image, status_code, improved_image_prompt = await text_to_image(q)
if image is None:
content_obj = {
"image": image,
"intentType": "text-to-image",
"detail": "Failed to generate image. Make sure your image generation configuration is set.",
}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
await sync_to_async(save_to_conversation_log)(
q, image, user, meta_log, intent_type="text-to-image", inferred_queries=[improved_image_prompt]
)
content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]} # type: ignore
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
# Get the (streamed) chat response from the LLM of choice.
llm_response, chat_metadata = await agenerate_chat_response(
@ -787,7 +818,6 @@ async def extract_references_and_questions(
conversation_config = await ConversationAdapters.aget_conversation_config(user)
if conversation_config is None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
if (
offline_chat_config
and offline_chat_config.enabled
@ -804,7 +834,7 @@ async def extract_references_and_questions(
inferred_queries = extract_questions_offline(
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
)
elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
openai_chat = await ConversationAdapters.get_openai_chat()
api_key = openai_chat_config.api_key

View file

@ -9,23 +9,23 @@ from functools import partial
from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
# External Packages
from fastapi import Depends, Header, HTTPException, Request, UploadFile
import openai
from starlette.authentication import has_required_scope
from asgiref.sync import sync_to_async
# Internal Packages
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import KhojUser, Subscription
from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log
# Internal Packages
from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry
logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1)
@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.General
elif query.startswith("/online"):
return ConversationCommand.Online
elif query.startswith("/image"):
return ConversationCommand.Image
# If no relevant notes found for the given query
elif not any_references:
return ConversationCommand.General
@ -144,6 +146,20 @@ async def generate_online_subqueries(q: str) -> List[str]:
return [q]
async def generate_better_image_prompt(q: str) -> str:
"""
Generate a better image prompt from the given query
"""
image_prompt = prompts.image_generation_improve_prompt.format(
query=q,
)
response = await send_message_to_model_wrapper(image_prompt)
return response.strip()
async def send_message_to_model_wrapper(
message: str,
):
@ -168,11 +184,13 @@ async def send_message_to_model_wrapper(
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
api_key = openai_chat_config.api_key
chat_model = conversation_config.chat_model
return send_message_to_model(
openai_response = send_message_to_model(
message=message,
api_key=api_key,
model=chat_model,
)
return openai_response.content
else:
raise HTTPException(status_code=500, detail="Invalid conversation config")
@ -186,30 +204,7 @@ def generate_chat_response(
conversation_command: ConversationCommand = ConversationCommand.Default,
user: KhojUser = None,
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
def _save_to_conversation_log(
q: str,
chat_response: str,
user_message_time: str,
compiled_references: List[str],
online_results: Dict[str, Any],
inferred_queries: List[str],
meta_log,
):
updated_conversation = message_to_log(
user_message=q,
chat_response=chat_response,
user_message_metadata={"created": user_message_time},
khoj_message_metadata={
"context": compiled_references,
"intent": {"inferred-queries": inferred_queries},
"onlineContext": online_results,
},
conversation_log=meta_log.get("chat", []),
)
ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
# Initialize Variables
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
chat_response = None
logger.debug(f"Conversation Type: {conversation_command.name}")
@ -217,13 +212,13 @@ def generate_chat_response(
try:
partial_completion = partial(
_save_to_conversation_log,
save_to_conversation_log,
q,
user_message_time=user_message_time,
user=user,
meta_log=meta_log,
compiled_references=compiled_references,
online_results=online_results,
inferred_queries=inferred_queries,
meta_log=meta_log,
)
conversation_config = ConversationAdapters.get_valid_conversation_config(user)
@ -251,9 +246,9 @@ def generate_chat_response(
chat_model = conversation_config.chat_model
chat_response = converse(
compiled_references,
online_results,
q,
meta_log,
online_results=online_results,
conversation_log=meta_log,
model=chat_model,
api_key=api_key,
completion_func=partial_completion,
@ -271,6 +266,29 @@ def generate_chat_response(
return chat_response, metadata
async def text_to_image(message: str) -> Tuple[Optional[str], int, Optional[str]]:
status_code = 200
image = None
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
text2image_model = text_to_image_config.model_name
improved_image_prompt = await generate_better_image_prompt(message)
try:
response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
)
image = response.data[0].b64_json
except openai.OpenAIError as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
status_code = 500
return image, status_code, improved_image_prompt
class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int):
self.requests = requests
@ -298,6 +316,40 @@ class ApiUserRateLimiter:
user_requests.append(time())
class ConversationCommandRateLimiter:
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int):
self.cache: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
self.trial_rate_limit = trial_rate_limit
self.subscribed_rate_limit = subscribed_rate_limit
self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image]
def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
if state.billing_enabled is False:
return
if not request.user.is_authenticated:
return
if conversation_command not in self.restricted_commands:
return
user: KhojUser = request.user.object
user_cache = self.cache[user.uuid]
subscribed = has_required_scope(request, ["premium"])
user_cache[conversation_command].append(time())
# Remove requests outside of the 24-hr time window
cutoff = time() - 60 * 60 * 24
while user_cache[conversation_command] and user_cache[conversation_command][0] < cutoff:
user_cache[conversation_command].pop(0)
if subscribed and len(user_cache[conversation_command]) > self.subscribed_rate_limit:
raise HTTPException(status_code=429, detail="Too Many Requests")
if not subscribed and len(user_cache[conversation_command]) > self.trial_rate_limit:
raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
return
class ApiIndexedDataLimiter:
def __init__(
self,
@ -315,7 +367,7 @@ class ApiIndexedDataLimiter:
if state.billing_enabled is False:
return
subscribed = has_required_scope(request, ["premium"])
incoming_data_size_mb = 0
incoming_data_size_mb = 0.0
deletion_file_names = set()
if not request.user.is_authenticated:

View file

@ -273,6 +273,7 @@ class ConversationCommand(str, Enum):
Notes = "notes"
Help = "help"
Online = "online"
Image = "image"
command_descriptions = {
@ -280,6 +281,7 @@ command_descriptions = {
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
ConversationCommand.Online: "Look up information on the internet.",
ConversationCommand.Image: "Generate images by describing your imagination in words.",
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
}

View file

@ -7,6 +7,7 @@ from khoj.database.models import (
OpenAIProcessorConversationConfig,
ChatModelOptions,
SpeechToTextModelOptions,
TextToImageModelConfig,
)
from khoj.utils.constants import default_offline_chat_model, default_online_chat_model
@ -103,6 +104,15 @@ def initialization():
model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
)
default_text_to_image_model = "dall-e-3"
openai_text_to_image_model = input(
f"Enter the OpenAI text to image model you want to use (default: {default_text_to_image_model}): "
)
openai_speech2text_model = openai_text_to_image_model or default_text_to_image_model
TextToImageModelConfig.objects.create(
model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI
)
if use_offline_model == "y" or use_openai_model == "y":
logger.info("🗣️ Chat model configuration complete")

View file

@ -22,17 +22,9 @@ class BaseEncoder(ABC):
class OpenAI(BaseEncoder):
def __init__(self, model_name, device=None):
def __init__(self, model_name, client: openai.OpenAI, device=None):
self.model_name = model_name
if (
not state.processor_config
or not state.processor_config.conversation
or not state.processor_config.conversation.openai_model
):
raise Exception(
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
)
openai.api_key = state.processor_config.conversation.openai_model.api_key
self.openai_client = client
self.embedding_dimensions = None
def encode(self, entries, device=None, **kwargs):
@ -43,7 +35,7 @@ class OpenAI(BaseEncoder):
processed_entry = entries[index].replace("\n", " ")
try:
response = openai.Embedding.create(input=processed_entry, model=self.model_name)
response = self.openai_client.embeddings.create(input=processed_entry, model=self.model_name)
embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)]
# Use current models embedding dimension, once available
# Else default to embedding dimensions of the text-embedding-ada-002 model

View file

@ -1,15 +1,16 @@
# Standard Packages
from collections import defaultdict
import os
from pathlib import Path
import threading
from typing import List, Dict
from collections import defaultdict
# External Packages
from pathlib import Path
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from openai import OpenAI
from whisper import Whisper
# Internal Packages
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
from khoj.utils.helpers import LRU, get_device
@ -21,6 +22,7 @@ search_models = SearchModels()
embeddings_model: Dict[str, EmbeddingsModel] = None
cross_encoder_model: CrossEncoderModel = None
content_index = ContentIndex()
openai_client: OpenAI = None
gpt4all_processor_config: GPT4AllProcessorModel = None
whisper_model: Whisper = None
config_file: Path = None

View file

@ -68,10 +68,10 @@ def test_chat_with_online_content(chat_client):
response_message = response_message.split("### compiled references")[0]
# Assert
expected_responses = ["http://www.paulgraham.com/greatwork.html"]
expected_responses = ["http://www.paulgraham.com/greatwork.html", "Please set your SERPER_DEV_API_KEY"]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected assistants name, [K|k]hoj, in response but got: " + response_message
"Expected links or serper not setup in response but got: " + response_message
)