mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Rebase with master
This commit is contained in:
commit
6dd2b05bf5
25 changed files with 797 additions and 341 deletions
|
@ -42,7 +42,7 @@ dependencies = [
|
||||||
"fastapi >= 0.104.1",
|
"fastapi >= 0.104.1",
|
||||||
"python-multipart >= 0.0.5",
|
"python-multipart >= 0.0.5",
|
||||||
"jinja2 == 3.1.2",
|
"jinja2 == 3.1.2",
|
||||||
"openai >= 0.27.0, < 1.0.0",
|
"openai >= 1.0.0",
|
||||||
"tiktoken >= 0.3.2",
|
"tiktoken >= 0.3.2",
|
||||||
"tenacity >= 8.2.2",
|
"tenacity >= 8.2.2",
|
||||||
"pillow ~= 9.5.0",
|
"pillow ~= 9.5.0",
|
||||||
|
|
|
@ -179,7 +179,18 @@
|
||||||
return numOnlineReferences;
|
return numOnlineReferences;
|
||||||
}
|
}
|
||||||
|
|
||||||
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
|
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
|
||||||
|
if (intentType === "text-to-image") {
|
||||||
|
let imageMarkdown = `![](data:image/png;base64,${message})`;
|
||||||
|
imageMarkdown += "\n\n";
|
||||||
|
if (inferredQueries) {
|
||||||
|
const inferredQuery = inferredQueries?.[0];
|
||||||
|
imageMarkdown += `**Inferred Query**: ${inferredQuery}`;
|
||||||
|
}
|
||||||
|
renderMessage(imageMarkdown, by, dt);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (context == null && onlineContext == null) {
|
if (context == null && onlineContext == null) {
|
||||||
renderMessage(message, by, dt);
|
renderMessage(message, by, dt);
|
||||||
return;
|
return;
|
||||||
|
@ -244,6 +255,17 @@
|
||||||
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
|
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
|
||||||
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
|
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
|
||||||
|
|
||||||
|
// Customize the rendering of images
|
||||||
|
md.renderer.rules.image = function(tokens, idx, options, env, self) {
|
||||||
|
let token = tokens[idx];
|
||||||
|
|
||||||
|
// Add class="text-to-image" to images
|
||||||
|
token.attrPush(['class', 'text-to-image']);
|
||||||
|
|
||||||
|
// Use the default renderer to render image markdown format
|
||||||
|
return self.renderToken(tokens, idx, options);
|
||||||
|
};
|
||||||
|
|
||||||
// Render markdown
|
// Render markdown
|
||||||
newHTML = md.render(newHTML);
|
newHTML = md.render(newHTML);
|
||||||
// Get any elements with a class that starts with "language"
|
// Get any elements with a class that starts with "language"
|
||||||
|
@ -328,109 +350,153 @@
|
||||||
let chatInput = document.getElementById("chat-input");
|
let chatInput = document.getElementById("chat-input");
|
||||||
chatInput.classList.remove("option-enabled");
|
chatInput.classList.remove("option-enabled");
|
||||||
|
|
||||||
// Call specified Khoj API which returns a streamed response of type text/plain
|
// Call specified Khoj API
|
||||||
fetch(url, { headers })
|
let response = await fetch(url, { headers });
|
||||||
.then(response => {
|
let rawResponse = "";
|
||||||
const reader = response.body.getReader();
|
const contentType = response.headers.get("content-type");
|
||||||
const decoder = new TextDecoder();
|
|
||||||
let rawResponse = "";
|
|
||||||
let references = null;
|
|
||||||
|
|
||||||
function readStream() {
|
if (contentType === "application/json") {
|
||||||
reader.read().then(({ done, value }) => {
|
// Handle JSON response
|
||||||
if (done) {
|
try {
|
||||||
// Append any references after all the data has been streamed
|
const responseAsJson = await response.json();
|
||||||
if (references != null) {
|
if (responseAsJson.image) {
|
||||||
newResponseText.appendChild(references);
|
// If response has image field, response is a generated image.
|
||||||
}
|
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
rawResponse += "\n\n";
|
||||||
document.getElementById("chat-input").removeAttribute("disabled");
|
const inferredQueries = responseAsJson.inferredQueries?.[0];
|
||||||
return;
|
if (inferredQueries) {
|
||||||
|
rawResponse += `**Inferred Query**: ${inferredQueries}`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (responseAsJson.detail) {
|
||||||
|
// If response has detail field, response is an error message.
|
||||||
|
rawResponse += responseAsJson.detail;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
rawResponse += chunk;
|
||||||
|
} finally {
|
||||||
|
newResponseText.innerHTML = "";
|
||||||
|
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
document.getElementById("chat-input").removeAttribute("disabled");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Handle streamed response of type text/event-stream or text/plain
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
let references = null;
|
||||||
|
|
||||||
|
readStream();
|
||||||
|
|
||||||
|
function readStream() {
|
||||||
|
reader.read().then(({ done, value }) => {
|
||||||
|
if (done) {
|
||||||
|
// Append any references after all the data has been streamed
|
||||||
|
if (references != null) {
|
||||||
|
newResponseText.appendChild(references);
|
||||||
|
}
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
document.getElementById("chat-input").removeAttribute("disabled");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode message chunk from stream
|
||||||
|
const chunk = decoder.decode(value, { stream: true });
|
||||||
|
|
||||||
|
if (chunk.includes("### compiled references:")) {
|
||||||
|
const additionalResponse = chunk.split("### compiled references:")[0];
|
||||||
|
rawResponse += additionalResponse;
|
||||||
|
newResponseText.innerHTML = "";
|
||||||
|
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
|
||||||
|
const rawReference = chunk.split("### compiled references:")[1];
|
||||||
|
const rawReferenceAsJson = JSON.parse(rawReference);
|
||||||
|
references = document.createElement('div');
|
||||||
|
references.classList.add("references");
|
||||||
|
|
||||||
|
let referenceExpandButton = document.createElement('button');
|
||||||
|
referenceExpandButton.classList.add("reference-expand-button");
|
||||||
|
|
||||||
|
let referenceSection = document.createElement('div');
|
||||||
|
referenceSection.classList.add("reference-section");
|
||||||
|
referenceSection.classList.add("collapsed");
|
||||||
|
|
||||||
|
let numReferences = 0;
|
||||||
|
|
||||||
|
// If rawReferenceAsJson is a list, then count the length
|
||||||
|
if (Array.isArray(rawReferenceAsJson)) {
|
||||||
|
numReferences = rawReferenceAsJson.length;
|
||||||
|
|
||||||
|
rawReferenceAsJson.forEach((reference, index) => {
|
||||||
|
let polishedReference = generateReference(reference, index);
|
||||||
|
referenceSection.appendChild(polishedReference);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode message chunk from stream
|
references.appendChild(referenceExpandButton);
|
||||||
const chunk = decoder.decode(value, { stream: true });
|
|
||||||
|
|
||||||
if (chunk.includes("### compiled references:")) {
|
referenceExpandButton.addEventListener('click', function() {
|
||||||
const additionalResponse = chunk.split("### compiled references:")[0];
|
if (referenceSection.classList.contains("collapsed")) {
|
||||||
rawResponse += additionalResponse;
|
referenceSection.classList.remove("collapsed");
|
||||||
|
referenceSection.classList.add("expanded");
|
||||||
|
} else {
|
||||||
|
referenceSection.classList.add("collapsed");
|
||||||
|
referenceSection.classList.remove("expanded");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
|
||||||
|
referenceExpandButton.innerHTML = expandButtonText;
|
||||||
|
references.appendChild(referenceSection);
|
||||||
|
readStream();
|
||||||
|
} else {
|
||||||
|
// Display response from Khoj
|
||||||
|
if (newResponseText.getElementsByClassName("spinner").length > 0) {
|
||||||
|
newResponseText.removeChild(loadingSpinner);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
|
||||||
|
if (chunk.startsWith("{") && chunk.endsWith("}")) {
|
||||||
|
try {
|
||||||
|
const responseAsJson = JSON.parse(chunk);
|
||||||
|
if (responseAsJson.image) {
|
||||||
|
// If response has image field, response is a generated image.
|
||||||
|
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
|
||||||
|
rawResponse += "\n\n";
|
||||||
|
const inferredQueries = responseAsJson.inferredQueries?.[0];
|
||||||
|
if (inferredQueries) {
|
||||||
|
rawResponse += `**Inferred Query**: ${inferredQueries}`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (responseAsJson.detail) {
|
||||||
|
rawResponse += responseAsJson.detail;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
rawResponse += chunk;
|
||||||
|
} finally {
|
||||||
|
newResponseText.innerHTML = "";
|
||||||
|
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
rawResponse += chunk;
|
||||||
newResponseText.innerHTML = "";
|
newResponseText.innerHTML = "";
|
||||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
|
||||||
const rawReference = chunk.split("### compiled references:")[1];
|
|
||||||
const rawReferenceAsJson = JSON.parse(rawReference);
|
|
||||||
references = document.createElement('div');
|
|
||||||
references.classList.add("references");
|
|
||||||
|
|
||||||
let referenceExpandButton = document.createElement('button');
|
|
||||||
referenceExpandButton.classList.add("reference-expand-button");
|
|
||||||
|
|
||||||
let referenceSection = document.createElement('div');
|
|
||||||
referenceSection.classList.add("reference-section");
|
|
||||||
referenceSection.classList.add("collapsed");
|
|
||||||
|
|
||||||
let numReferences = 0;
|
|
||||||
|
|
||||||
// If rawReferenceAsJson is a list, then count the length
|
|
||||||
if (Array.isArray(rawReferenceAsJson)) {
|
|
||||||
numReferences = rawReferenceAsJson.length;
|
|
||||||
|
|
||||||
rawReferenceAsJson.forEach((reference, index) => {
|
|
||||||
let polishedReference = generateReference(reference, index);
|
|
||||||
referenceSection.appendChild(polishedReference);
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
|
|
||||||
}
|
|
||||||
|
|
||||||
references.appendChild(referenceExpandButton);
|
|
||||||
|
|
||||||
referenceExpandButton.addEventListener('click', function() {
|
|
||||||
if (referenceSection.classList.contains("collapsed")) {
|
|
||||||
referenceSection.classList.remove("collapsed");
|
|
||||||
referenceSection.classList.add("expanded");
|
|
||||||
} else {
|
|
||||||
referenceSection.classList.add("collapsed");
|
|
||||||
referenceSection.classList.remove("expanded");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
|
|
||||||
referenceExpandButton.innerHTML = expandButtonText;
|
|
||||||
references.appendChild(referenceSection);
|
|
||||||
readStream();
|
readStream();
|
||||||
} else {
|
|
||||||
// Display response from Khoj
|
|
||||||
if (newResponseText.getElementsByClassName("spinner").length > 0) {
|
|
||||||
newResponseText.removeChild(loadingSpinner);
|
|
||||||
}
|
|
||||||
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
|
|
||||||
if (chunk.startsWith("{") && chunk.endsWith("}")) {
|
|
||||||
try {
|
|
||||||
const responseAsJson = JSON.parse(chunk);
|
|
||||||
if (responseAsJson.detail) {
|
|
||||||
newResponseText.innerHTML += responseAsJson.detail;
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
newResponseText.innerHTML += chunk;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
rawResponse += chunk;
|
|
||||||
newResponseText.innerHTML = "";
|
|
||||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
readStream();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Scroll to bottom of chat window as chat response is streamed
|
// Scroll to bottom of chat window as chat response is streamed
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
readStream();
|
}
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function incrementalChat(event) {
|
function incrementalChat(event) {
|
||||||
|
@ -522,7 +588,7 @@
|
||||||
.then(response => {
|
.then(response => {
|
||||||
// Render conversation history, if any
|
// Render conversation history, if any
|
||||||
response.forEach(chat_log => {
|
response.forEach(chat_log => {
|
||||||
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
|
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]);
|
||||||
});
|
});
|
||||||
})
|
})
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
|
@ -625,9 +691,13 @@
|
||||||
.then(response => response.ok ? response.json() : Promise.reject(response))
|
.then(response => response.ok ? response.json() : Promise.reject(response))
|
||||||
.then(data => { chatInput.value += data.text; })
|
.then(data => { chatInput.value += data.text; })
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
err.status == 422
|
if (err.status === 501) {
|
||||||
? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
|
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
|
||||||
: flashStatusInChatInput("⛔️ Failed to transcribe audio")
|
} else if (err.status === 422) {
|
||||||
|
flashStatusInChatInput("⛔️ Audio file to large to process.")
|
||||||
|
} else {
|
||||||
|
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
|
||||||
|
}
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -810,6 +880,9 @@
|
||||||
margin-top: -10px;
|
margin-top: -10px;
|
||||||
transform: rotate(-60deg)
|
transform: rotate(-60deg)
|
||||||
}
|
}
|
||||||
|
img.text-to-image {
|
||||||
|
max-width: 60%;
|
||||||
|
}
|
||||||
|
|
||||||
#chat-footer {
|
#chat-footer {
|
||||||
padding: 0;
|
padding: 0;
|
||||||
|
@ -846,11 +919,12 @@
|
||||||
}
|
}
|
||||||
.input-row-button {
|
.input-row-button {
|
||||||
background: var(--background-color);
|
background: var(--background-color);
|
||||||
border: none;
|
border: 1px solid var(--main-text-color);
|
||||||
|
box-shadow: 0 0 11px #aaa;
|
||||||
border-radius: 5px;
|
border-radius: 5px;
|
||||||
padding: 5px;
|
|
||||||
font-size: 14px;
|
font-size: 14px;
|
||||||
font-weight: 300;
|
font-weight: 300;
|
||||||
|
padding: 0;
|
||||||
line-height: 1.5em;
|
line-height: 1.5em;
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
transition: background 0.3s ease-in-out;
|
transition: background 0.3s ease-in-out;
|
||||||
|
@ -932,7 +1006,6 @@
|
||||||
color: var(--main-text-color);
|
color: var(--main-text-color);
|
||||||
border: 1px solid var(--main-text-color);
|
border: 1px solid var(--main-text-color);
|
||||||
border-radius: 5px;
|
border-radius: 5px;
|
||||||
padding: 5px;
|
|
||||||
font-size: 14px;
|
font-size: 14px;
|
||||||
font-weight: 300;
|
font-weight: 300;
|
||||||
line-height: 1.5em;
|
line-height: 1.5em;
|
||||||
|
@ -1050,6 +1123,9 @@
|
||||||
margin: 4px;
|
margin: 4px;
|
||||||
grid-template-columns: auto;
|
grid-template-columns: auto;
|
||||||
}
|
}
|
||||||
|
img.text-to-image {
|
||||||
|
max-width: 100%;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@media only screen and (min-width: 600px) {
|
@media only screen and (min-width: 600px) {
|
||||||
body {
|
body {
|
||||||
|
|
|
@ -2,6 +2,12 @@ import { App, MarkdownRenderer, Modal, request, requestUrl, setIcon } from 'obsi
|
||||||
import { KhojSetting } from 'src/settings';
|
import { KhojSetting } from 'src/settings';
|
||||||
import fetch from "node-fetch";
|
import fetch from "node-fetch";
|
||||||
|
|
||||||
|
export interface ChatJsonResult {
|
||||||
|
image?: string;
|
||||||
|
detail?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
export class KhojChatModal extends Modal {
|
export class KhojChatModal extends Modal {
|
||||||
result: string;
|
result: string;
|
||||||
setting: KhojSetting;
|
setting: KhojSetting;
|
||||||
|
@ -105,15 +111,19 @@ export class KhojChatModal extends Modal {
|
||||||
return referenceButton;
|
return referenceButton;
|
||||||
}
|
}
|
||||||
|
|
||||||
renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date) {
|
renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string) {
|
||||||
if (!message) {
|
if (!message) {
|
||||||
return;
|
return;
|
||||||
|
} else if (intentType === "text-to-image") {
|
||||||
|
let imageMarkdown = `![](data:image/png;base64,${message})`;
|
||||||
|
this.renderMessage(chatEl, imageMarkdown, sender, dt);
|
||||||
|
return;
|
||||||
} else if (!context) {
|
} else if (!context) {
|
||||||
this.renderMessage(chatEl, message, sender, dt);
|
this.renderMessage(chatEl, message, sender, dt);
|
||||||
return
|
return;
|
||||||
} else if (!!context && context?.length === 0) {
|
} else if (!!context && context?.length === 0) {
|
||||||
this.renderMessage(chatEl, message, sender, dt);
|
this.renderMessage(chatEl, message, sender, dt);
|
||||||
return
|
return;
|
||||||
}
|
}
|
||||||
let chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
|
let chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
|
||||||
let chatMessageBodyEl = chatMessageEl.getElementsByClassName("khoj-chat-message-text")[0]
|
let chatMessageBodyEl = chatMessageEl.getElementsByClassName("khoj-chat-message-text")[0]
|
||||||
|
@ -225,7 +235,7 @@ export class KhojChatModal extends Modal {
|
||||||
let response = await request({ url: chatUrl, headers: headers });
|
let response = await request({ url: chatUrl, headers: headers });
|
||||||
let chatLogs = JSON.parse(response).response;
|
let chatLogs = JSON.parse(response).response;
|
||||||
chatLogs.forEach((chatLog: any) => {
|
chatLogs.forEach((chatLog: any) => {
|
||||||
this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created));
|
this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created), chatLog.intent?.type);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -266,8 +276,25 @@ export class KhojChatModal extends Modal {
|
||||||
|
|
||||||
this.result = "";
|
this.result = "";
|
||||||
responseElement.innerHTML = "";
|
responseElement.innerHTML = "";
|
||||||
|
if (response.headers.get("content-type") == "application/json") {
|
||||||
|
let responseText = ""
|
||||||
|
try {
|
||||||
|
const responseAsJson = await response.json() as ChatJsonResult;
|
||||||
|
if (responseAsJson.image) {
|
||||||
|
responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`;
|
||||||
|
} else if (responseAsJson.detail) {
|
||||||
|
responseText = responseAsJson.detail;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
responseText = response.body.read().toString()
|
||||||
|
} finally {
|
||||||
|
this.renderIncrementalMessage(responseElement, responseText);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for await (const chunk of response.body) {
|
for await (const chunk of response.body) {
|
||||||
const responseText = chunk.toString();
|
let responseText = chunk.toString();
|
||||||
if (responseText.includes("### compiled references:")) {
|
if (responseText.includes("### compiled references:")) {
|
||||||
const additionalResponse = responseText.split("### compiled references:")[0];
|
const additionalResponse = responseText.split("### compiled references:")[0];
|
||||||
this.renderIncrementalMessage(responseElement, additionalResponse);
|
this.renderIncrementalMessage(responseElement, additionalResponse);
|
||||||
|
@ -310,6 +337,12 @@ export class KhojChatModal extends Modal {
|
||||||
referenceExpandButton.innerHTML = expandButtonText;
|
referenceExpandButton.innerHTML = expandButtonText;
|
||||||
references.appendChild(referenceSection);
|
references.appendChild(referenceSection);
|
||||||
} else {
|
} else {
|
||||||
|
if (responseText.startsWith("{") && responseText.endsWith("}")) {
|
||||||
|
} else {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
this.renderIncrementalMessage(responseElement, responseText);
|
this.renderIncrementalMessage(responseElement, responseText);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -389,10 +422,12 @@ export class KhojChatModal extends Modal {
|
||||||
if (response.status === 200) {
|
if (response.status === 200) {
|
||||||
console.log(response);
|
console.log(response);
|
||||||
chatInput.value += response.json.text;
|
chatInput.value += response.json.text;
|
||||||
} else if (response.status === 422) {
|
} else if (response.status === 501) {
|
||||||
throw new Error("⛔️ Failed to transcribe audio");
|
|
||||||
} else {
|
|
||||||
throw new Error("⛔️ Configure speech-to-text model on server.");
|
throw new Error("⛔️ Configure speech-to-text model on server.");
|
||||||
|
} else if (response.status === 422) {
|
||||||
|
throw new Error("⛔️ Audio file to large to process.");
|
||||||
|
} else {
|
||||||
|
throw new Error("⛔️ Failed to transcribe audio.");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -217,7 +217,9 @@ button.copy-button:hover {
|
||||||
background: #f5f5f5;
|
background: #f5f5f5;
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
}
|
}
|
||||||
|
img {
|
||||||
|
max-width: 60%;
|
||||||
|
}
|
||||||
|
|
||||||
#khoj-chat-footer {
|
#khoj-chat-footer {
|
||||||
padding: 0;
|
padding: 0;
|
||||||
|
|
|
@ -33,6 +33,9 @@ ALLOWED_HOSTS = [f".{KHOJ_DOMAIN}", "localhost", "127.0.0.1", "[::1]"]
|
||||||
CSRF_TRUSTED_ORIGINS = [
|
CSRF_TRUSTED_ORIGINS = [
|
||||||
f"https://*.{KHOJ_DOMAIN}",
|
f"https://*.{KHOJ_DOMAIN}",
|
||||||
f"https://{KHOJ_DOMAIN}",
|
f"https://{KHOJ_DOMAIN}",
|
||||||
|
f"http://*.{KHOJ_DOMAIN}",
|
||||||
|
f"http://{KHOJ_DOMAIN}",
|
||||||
|
f"https://app.{KHOJ_DOMAIN}",
|
||||||
]
|
]
|
||||||
|
|
||||||
COOKIE_SAMESITE = "None"
|
COOKIE_SAMESITE = "None"
|
||||||
|
@ -42,6 +45,7 @@ if DEBUG or os.getenv("KHOJ_DOMAIN") == None:
|
||||||
else:
|
else:
|
||||||
SESSION_COOKIE_DOMAIN = KHOJ_DOMAIN
|
SESSION_COOKIE_DOMAIN = KHOJ_DOMAIN
|
||||||
CSRF_COOKIE_DOMAIN = KHOJ_DOMAIN
|
CSRF_COOKIE_DOMAIN = KHOJ_DOMAIN
|
||||||
|
SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTOCOL", "https")
|
||||||
|
|
||||||
SESSION_COOKIE_SECURE = True
|
SESSION_COOKIE_SECURE = True
|
||||||
CSRF_COOKIE_SECURE = True
|
CSRF_COOKIE_SECURE = True
|
||||||
|
|
|
@ -7,6 +7,7 @@ import requests
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
|
import openai
|
||||||
import schedule
|
import schedule
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||||
|
@ -22,6 +23,7 @@ from starlette.authentication import (
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.database.models import KhojUser, Subscription
|
from khoj.database.models import KhojUser, Subscription
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
|
ConversationAdapters,
|
||||||
get_all_users,
|
get_all_users,
|
||||||
get_or_create_search_models,
|
get_or_create_search_models,
|
||||||
aget_user_subscription_state,
|
aget_user_subscription_state,
|
||||||
|
@ -75,17 +77,18 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
||||||
.afirst()
|
.afirst()
|
||||||
)
|
)
|
||||||
if user:
|
if user:
|
||||||
if state.billing_enabled:
|
if not state.billing_enabled:
|
||||||
subscription_state = await aget_user_subscription_state(user)
|
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
||||||
subscribed = (
|
|
||||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
subscription_state = await aget_user_subscription_state(user)
|
||||||
or subscription_state == SubscriptionState.TRIAL.value
|
subscribed = (
|
||||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||||
)
|
or subscription_state == SubscriptionState.TRIAL.value
|
||||||
if subscribed:
|
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
)
|
||||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
if subscribed:
|
||||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
||||||
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||||
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
|
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
|
||||||
# Get bearer token from header
|
# Get bearer token from header
|
||||||
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
|
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
|
||||||
|
@ -97,19 +100,18 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
||||||
.afirst()
|
.afirst()
|
||||||
)
|
)
|
||||||
if user_with_token:
|
if user_with_token:
|
||||||
if state.billing_enabled:
|
if not state.billing_enabled:
|
||||||
subscription_state = await aget_user_subscription_state(user_with_token.user)
|
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
|
||||||
subscribed = (
|
|
||||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
subscription_state = await aget_user_subscription_state(user_with_token.user)
|
||||||
or subscription_state == SubscriptionState.TRIAL.value
|
subscribed = (
|
||||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||||
)
|
or subscription_state == SubscriptionState.TRIAL.value
|
||||||
if subscribed:
|
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(
|
)
|
||||||
user_with_token.user
|
if subscribed:
|
||||||
)
|
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
|
||||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
||||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
|
|
||||||
if state.anonymous_mode:
|
if state.anonymous_mode:
|
||||||
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
|
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
|
||||||
if user:
|
if user:
|
||||||
|
@ -138,6 +140,10 @@ def configure_server(
|
||||||
config = FullConfig()
|
config = FullConfig()
|
||||||
state.config = config
|
state.config = config
|
||||||
|
|
||||||
|
if ConversationAdapters.has_valid_openai_conversation_config():
|
||||||
|
openai_config = ConversationAdapters.get_openai_conversation_config()
|
||||||
|
state.openai_client = openai.OpenAI(api_key=openai_config.api_key)
|
||||||
|
|
||||||
# Initialize Search Models from Config and initialize content
|
# Initialize Search Models from Config and initialize content
|
||||||
try:
|
try:
|
||||||
search_models = get_or_create_search_models()
|
search_models = get_or_create_search_models()
|
||||||
|
|
|
@ -22,6 +22,7 @@ from khoj.database.models import (
|
||||||
GithubConfig,
|
GithubConfig,
|
||||||
GithubRepoConfig,
|
GithubRepoConfig,
|
||||||
GoogleUser,
|
GoogleUser,
|
||||||
|
TextToImageModelConfig,
|
||||||
KhojApiUser,
|
KhojApiUser,
|
||||||
KhojUser,
|
KhojUser,
|
||||||
NotionConfig,
|
NotionConfig,
|
||||||
|
@ -407,7 +408,7 @@ class ConversationAdapters:
|
||||||
)
|
)
|
||||||
|
|
||||||
max_results = 3
|
max_results = 3
|
||||||
all_questions = await sync_to_async(list)(all_questions)
|
all_questions = await sync_to_async(list)(all_questions) # type: ignore
|
||||||
if len(all_questions) < max_results:
|
if len(all_questions) < max_results:
|
||||||
return all_questions
|
return all_questions
|
||||||
|
|
||||||
|
@ -433,6 +434,10 @@ class ConversationAdapters:
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
|
raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def aget_text_to_image_model_config():
|
||||||
|
return await TextToImageModelConfig.objects.filter().afirst()
|
||||||
|
|
||||||
|
|
||||||
class EntryAdapters:
|
class EntryAdapters:
|
||||||
word_filer = WordFilter()
|
word_filer = WordFilter()
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
|
||||||
from django.contrib import admin
|
from django.contrib import admin
|
||||||
from django.contrib.auth.admin import UserAdmin
|
from django.contrib.auth.admin import UserAdmin
|
||||||
|
from django.http import HttpResponse
|
||||||
|
|
||||||
# Register your models here.
|
# Register your models here.
|
||||||
|
|
||||||
|
@ -13,6 +17,8 @@ from khoj.database.models import (
|
||||||
Subscription,
|
Subscription,
|
||||||
ReflectiveQuestion,
|
ReflectiveQuestion,
|
||||||
UserSearchModelConfig,
|
UserSearchModelConfig,
|
||||||
|
TextToImageModelConfig,
|
||||||
|
Conversation,
|
||||||
)
|
)
|
||||||
|
|
||||||
admin.site.register(KhojUser, UserAdmin)
|
admin.site.register(KhojUser, UserAdmin)
|
||||||
|
@ -25,3 +31,102 @@ admin.site.register(SearchModelConfig)
|
||||||
admin.site.register(Subscription)
|
admin.site.register(Subscription)
|
||||||
admin.site.register(ReflectiveQuestion)
|
admin.site.register(ReflectiveQuestion)
|
||||||
admin.site.register(UserSearchModelConfig)
|
admin.site.register(UserSearchModelConfig)
|
||||||
|
admin.site.register(TextToImageModelConfig)
|
||||||
|
|
||||||
|
|
||||||
|
@admin.register(Conversation)
class ConversationAdmin(admin.ModelAdmin):
    """Admin page for conversations, with CSV export actions for superusers.

    Two bulk actions are offered:

    * ``export_selected_objects`` dumps the whole conversation log.
    * ``export_selected_minimal_objects`` keeps only the ``message``, ``by``
      and ``created`` keys of every chat entry.

    In both exports the message of a Khoj reply whose intent type is
    "text-to-image" is replaced by a short placeholder instead of the image
    payload. The exports never modify the conversation objects themselves.
    """

    list_display = (
        "id",
        "user",
        "created_at",
        "updated_at",
    )
    # NOTE(review): confirm `conversation_id` is an actual field on Conversation;
    # the admin search box resolves these names against the model.
    search_fields = ("conversation_id",)
    ordering = ("-created_at",)

    actions = ["export_selected_objects", "export_selected_minimal_objects"]

    # Action names hidden from non-superusers in get_actions().
    _EXPORT_ACTIONS = ("export_selected_objects", "export_selected_minimal_objects")
    # Keys kept per chat entry by the minimal export. A tuple (not a set) so the
    # key order of the emitted JSON is the same on every run.
    _MINIMAL_FIELDS = ("message", "by", "created")
    _REDACTED_IMAGE_MESSAGE = "image redacted for space"

    @staticmethod
    def _is_generated_image(log):
        """Return True when the chat entry is a Khoj reply of intent type text-to-image."""
        intent = log.get("intent") or {}
        return log.get("by") == "khoj" and intent.get("type") == "text-to-image"

    @classmethod
    def _export_entry(cls, log, fields=None):
        """Return a copy of a chat entry that is safe to export.

        With ``fields`` given, only those keys (when present) are copied;
        otherwise every key is. The stored entry is never modified.
        """
        if fields is None:
            entry = dict(log)
        else:
            entry = {key: log[key] for key in fields if key in log}
        if cls._is_generated_image(log):
            entry["message"] = cls._REDACTED_IMAGE_MESSAGE
        return entry

    @staticmethod
    def _csv_response(queryset, build_log):
        """Write one CSV row per conversation; ``build_log`` maps a stored log to the exported one."""
        response = HttpResponse(content_type="text/csv")
        response["Content-Disposition"] = 'attachment; filename="conversations.csv"'

        writer = csv.writer(response)
        writer.writerow(["id", "user", "created_at", "updated_at", "conversation_log"])

        for conversation in queryset:
            # Treat a missing/empty log as an empty dict rather than failing the whole export.
            stored_log = conversation.conversation_log or {}
            writer.writerow(
                [
                    conversation.id,
                    conversation.user,
                    conversation.created_at,
                    conversation.updated_at,
                    json.dumps(build_log(stored_log)),
                ]
            )

        return response

    @admin.action(description="Export selected conversations")
    def export_selected_objects(self, request, queryset):
        """Export the selected conversations with their full log as a CSV attachment."""

        def full_log(stored_log):
            chat = [self._export_entry(log) for log in stored_log.get("chat", [])]
            # Rebuild the dict instead of assigning into the stored log.
            return {**stored_log, "chat": chat}

        return self._csv_response(queryset, full_log)

    @admin.action(description="Export selected conversations (minimal)")
    def export_selected_minimal_objects(self, request, queryset):
        """Export the selected conversations as CSV, keeping only message, by and created per entry."""

        def minimal_log(stored_log):
            chat = [self._export_entry(log, self._MINIMAL_FIELDS) for log in stored_log.get("chat", [])]
            return {"chat": chat}

        return self._csv_response(queryset, minimal_log)

    def get_actions(self, request):
        """Return the available bulk actions, without the export actions for non-superusers."""
        actions = super().get_actions(request)
        if not request.user.is_superuser:
            for name in self._EXPORT_ACTIONS:
                actions.pop(name, None)
        return actions
|
||||||
|
|
25
src/khoj/database/migrations/0022_texttoimagemodelconfig.py
Normal file
25
src/khoj/database/migrations/0022_texttoimagemodelconfig.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
# Generated by Django 4.2.7 on 2023-12-04 22:17
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
    """Create the TextToImageModelConfig table."""

    # Must run after the speech-to-text model options migration.
    dependencies = [("database", "0021_speechtotextmodeloptions_and_more")]

    operations = [
        migrations.CreateModel(
            name="TextToImageModelConfig",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                (
                    "created_at",
                    models.DateTimeField(auto_now_add=True),
                ),
                (
                    "updated_at",
                    models.DateTimeField(auto_now=True),
                ),
                (
                    "model_name",
                    models.CharField(default="dall-e-3", max_length=200),
                ),
                (
                    "model_type",
                    models.CharField(
                        choices=[("openai", "Openai")],
                        default="openai",
                        max_length=200,
                    ),
                ),
            ],
            options={"abstract": False},
        ),
    ]
|
|
@ -112,6 +112,14 @@ class SearchModelConfig(BaseModel):
|
||||||
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
||||||
|
|
||||||
|
|
||||||
|
class TextToImageModelConfig(BaseModel):
    """Stored configuration naming the model used for text-to-image generation."""

    class ModelType(models.TextChoices):
        # Only one model type is defined so far.
        OPENAI = "openai"

    # Name of the image model; "dall-e-3" unless overridden.
    model_name = models.CharField(
        max_length=200,
        default="dall-e-3",
    )
    # One of ModelType's values; OPENAI unless overridden.
    model_type = models.CharField(
        max_length=200,
        choices=ModelType.choices,
        default=ModelType.OPENAI,
    )
|
||||||
|
|
||||||
|
|
||||||
class OpenAIProcessorConversationConfig(BaseModel):
|
class OpenAIProcessorConversationConfig(BaseModel):
|
||||||
api_key = models.CharField(max_length=200)
|
api_key = models.CharField(max_length=200)
|
||||||
|
|
||||||
|
|
|
@ -183,12 +183,23 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
referenceSection.appendChild(polishedReference);
|
referenceSection.appendChild(polishedReference);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return numOnlineReferences;
|
return numOnlineReferences;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
|
||||||
|
if (intentType === "text-to-image") {
|
||||||
|
let imageMarkdown = `![](data:image/png;base64,${message})`;
|
||||||
|
imageMarkdown += "\n\n";
|
||||||
|
if (inferredQueries) {
|
||||||
|
const inferredQuery = inferredQueries?.[0];
|
||||||
|
imageMarkdown += `**Inferred Query**: ${inferredQuery}`;
|
||||||
|
}
|
||||||
|
renderMessage(imageMarkdown, by, dt);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
|
|
||||||
if (context == null && onlineContext == null) {
|
if (context == null && onlineContext == null) {
|
||||||
renderMessage(message, by, dt);
|
renderMessage(message, by, dt);
|
||||||
return;
|
return;
|
||||||
|
@ -253,6 +264,17 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
|
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
|
||||||
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
|
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
|
||||||
|
|
||||||
|
// Customize the rendering of images
md.renderer.rules.image = function(tokens, idx, options, env, self) {
    const token = tokens[idx];

    // markdown-it keeps the alt text in token.children and leaves the "alt"
    // attribute empty until the image rule fills it in. Overriding the rule
    // skips that step, so populate it here; otherwise every image is emitted
    // with alt="".
    token.attrSet('alt', self.renderInlineAsText(token.children || [], options, env));

    // Add class="text-to-image" to images. attrJoin appends to an existing
    // class attribute instead of emitting a second one.
    token.attrJoin('class', 'text-to-image');

    // Use the default token renderer to emit the <img> tag with its attributes
    return self.renderToken(tokens, idx, options);
};
|
||||||
|
|
||||||
// Render markdown
|
// Render markdown
|
||||||
newHTML = md.render(newHTML);
|
newHTML = md.render(newHTML);
|
||||||
// Get any elements with a class that starts with "language"
|
// Get any elements with a class that starts with "language"
|
||||||
|
@ -292,7 +314,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
return element
|
return element
|
||||||
}
|
}
|
||||||
|
|
||||||
function chat() {
|
async function chat() {
|
||||||
// Extract required fields for search from form
|
// Extract required fields for search from form
|
||||||
let query = document.getElementById("chat-input").value.trim();
|
let query = document.getElementById("chat-input").value.trim();
|
||||||
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
|
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
|
||||||
|
@ -333,113 +355,128 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
let chatInput = document.getElementById("chat-input");
|
let chatInput = document.getElementById("chat-input");
|
||||||
chatInput.classList.remove("option-enabled");
|
chatInput.classList.remove("option-enabled");
|
||||||
|
|
||||||
// Call specified Khoj API which returns a streamed response of type text/plain
|
// Call specified Khoj API
|
||||||
fetch(url)
|
let response = await fetch(url);
|
||||||
.then(response => {
|
let rawResponse = "";
|
||||||
const reader = response.body.getReader();
|
const contentType = response.headers.get("content-type");
|
||||||
const decoder = new TextDecoder();
|
|
||||||
let rawResponse = "";
|
|
||||||
let references = null;
|
|
||||||
|
|
||||||
function readStream() {
|
if (contentType === "application/json") {
|
||||||
reader.read().then(({ done, value }) => {
|
// Handle JSON response
|
||||||
if (done) {
|
try {
|
||||||
// Append any references after all the data has been streamed
|
const responseAsJson = await response.json();
|
||||||
if (references != null) {
|
if (responseAsJson.image) {
|
||||||
newResponseText.appendChild(references);
|
// If response has image field, response is a generated image.
|
||||||
}
|
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
rawResponse += "\n\n";
|
||||||
document.getElementById("chat-input").removeAttribute("disabled");
|
const inferredQueries = responseAsJson.inferredQueries?.[0];
|
||||||
return;
|
if (inferredQueries) {
|
||||||
}
|
rawResponse += `**Inferred Query**: ${inferredQueries}`;
|
||||||
|
}
|
||||||
// Decode message chunk from stream
|
|
||||||
const chunk = decoder.decode(value, { stream: true });
|
|
||||||
|
|
||||||
if (chunk.includes("### compiled references:")) {
|
|
||||||
const additionalResponse = chunk.split("### compiled references:")[0];
|
|
||||||
rawResponse += additionalResponse;
|
|
||||||
newResponseText.innerHTML = "";
|
|
||||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
|
|
||||||
const rawReference = chunk.split("### compiled references:")[1];
|
|
||||||
const rawReferenceAsJson = JSON.parse(rawReference);
|
|
||||||
references = document.createElement('div');
|
|
||||||
references.classList.add("references");
|
|
||||||
|
|
||||||
let referenceExpandButton = document.createElement('button');
|
|
||||||
referenceExpandButton.classList.add("reference-expand-button");
|
|
||||||
|
|
||||||
let referenceSection = document.createElement('div');
|
|
||||||
referenceSection.classList.add("reference-section");
|
|
||||||
referenceSection.classList.add("collapsed");
|
|
||||||
|
|
||||||
let numReferences = 0;
|
|
||||||
|
|
||||||
// If rawReferenceAsJson is a list, then count the length
|
|
||||||
if (Array.isArray(rawReferenceAsJson)) {
|
|
||||||
numReferences = rawReferenceAsJson.length;
|
|
||||||
|
|
||||||
rawReferenceAsJson.forEach((reference, index) => {
|
|
||||||
let polishedReference = generateReference(reference, index);
|
|
||||||
referenceSection.appendChild(polishedReference);
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
|
|
||||||
}
|
|
||||||
|
|
||||||
references.appendChild(referenceExpandButton);
|
|
||||||
|
|
||||||
referenceExpandButton.addEventListener('click', function() {
|
|
||||||
if (referenceSection.classList.contains("collapsed")) {
|
|
||||||
referenceSection.classList.remove("collapsed");
|
|
||||||
referenceSection.classList.add("expanded");
|
|
||||||
} else {
|
|
||||||
referenceSection.classList.add("collapsed");
|
|
||||||
referenceSection.classList.remove("expanded");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
|
|
||||||
referenceExpandButton.innerHTML = expandButtonText;
|
|
||||||
references.appendChild(referenceSection);
|
|
||||||
readStream();
|
|
||||||
} else {
|
|
||||||
// Display response from Khoj
|
|
||||||
if (newResponseText.getElementsByClassName("spinner").length > 0) {
|
|
||||||
newResponseText.removeChild(loadingSpinner);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
|
|
||||||
if (chunk.startsWith("{") && chunk.endsWith("}")) {
|
|
||||||
try {
|
|
||||||
const responseAsJson = JSON.parse(chunk);
|
|
||||||
if (responseAsJson.detail) {
|
|
||||||
rawResponse += responseAsJson.detail;
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
rawResponse += chunk;
|
|
||||||
} finally {
|
|
||||||
newResponseText.innerHTML = "";
|
|
||||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// If the chunk is not a JSON object, just display it as is
|
|
||||||
rawResponse += chunk;
|
|
||||||
newResponseText.innerHTML = "";
|
|
||||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
|
||||||
readStream();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scroll to bottom of chat window as chat response is streamed
|
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
readStream();
|
if (responseAsJson.detail) {
|
||||||
});
|
// If response has detail field, response is an error message.
|
||||||
}
|
rawResponse += responseAsJson.detail;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
rawResponse += chunk;
|
||||||
|
} finally {
|
||||||
|
newResponseText.innerHTML = "";
|
||||||
|
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
document.getElementById("chat-input").removeAttribute("disabled");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Handle streamed response of type text/event-stream or text/plain
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
let references = null;
|
||||||
|
|
||||||
|
readStream();
|
||||||
|
|
||||||
|
function readStream() {
|
||||||
|
reader.read().then(({ done, value }) => {
|
||||||
|
if (done) {
|
||||||
|
// Append any references after all the data has been streamed
|
||||||
|
if (references != null) {
|
||||||
|
newResponseText.appendChild(references);
|
||||||
|
}
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
document.getElementById("chat-input").removeAttribute("disabled");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode message chunk from stream
|
||||||
|
const chunk = decoder.decode(value, { stream: true });
|
||||||
|
|
||||||
|
if (chunk.includes("### compiled references:")) {
|
||||||
|
const additionalResponse = chunk.split("### compiled references:")[0];
|
||||||
|
rawResponse += additionalResponse;
|
||||||
|
newResponseText.innerHTML = "";
|
||||||
|
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
|
||||||
|
const rawReference = chunk.split("### compiled references:")[1];
|
||||||
|
const rawReferenceAsJson = JSON.parse(rawReference);
|
||||||
|
references = document.createElement('div');
|
||||||
|
references.classList.add("references");
|
||||||
|
|
||||||
|
let referenceExpandButton = document.createElement('button');
|
||||||
|
referenceExpandButton.classList.add("reference-expand-button");
|
||||||
|
|
||||||
|
let referenceSection = document.createElement('div');
|
||||||
|
referenceSection.classList.add("reference-section");
|
||||||
|
referenceSection.classList.add("collapsed");
|
||||||
|
|
||||||
|
let numReferences = 0;
|
||||||
|
|
||||||
|
// If rawReferenceAsJson is a list, then count the length
|
||||||
|
if (Array.isArray(rawReferenceAsJson)) {
|
||||||
|
numReferences = rawReferenceAsJson.length;
|
||||||
|
|
||||||
|
rawReferenceAsJson.forEach((reference, index) => {
|
||||||
|
let polishedReference = generateReference(reference, index);
|
||||||
|
referenceSection.appendChild(polishedReference);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
|
||||||
|
}
|
||||||
|
|
||||||
|
references.appendChild(referenceExpandButton);
|
||||||
|
|
||||||
|
referenceExpandButton.addEventListener('click', function() {
|
||||||
|
if (referenceSection.classList.contains("collapsed")) {
|
||||||
|
referenceSection.classList.remove("collapsed");
|
||||||
|
referenceSection.classList.add("expanded");
|
||||||
|
} else {
|
||||||
|
referenceSection.classList.add("collapsed");
|
||||||
|
referenceSection.classList.remove("expanded");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
|
||||||
|
referenceExpandButton.innerHTML = expandButtonText;
|
||||||
|
references.appendChild(referenceSection);
|
||||||
|
readStream();
|
||||||
|
} else {
|
||||||
|
// Display response from Khoj
|
||||||
|
if (newResponseText.getElementsByClassName("spinner").length > 0) {
|
||||||
|
newResponseText.removeChild(loadingSpinner);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the chunk is not a JSON object, just display it as is
|
||||||
|
rawResponse += chunk;
|
||||||
|
newResponseText.innerHTML = "";
|
||||||
|
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||||
|
readStream();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Scroll to bottom of chat window as chat response is streamed
|
||||||
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
function incrementalChat(event) {
|
function incrementalChat(event) {
|
||||||
if (!event.shiftKey && event.key === 'Enter') {
|
if (!event.shiftKey && event.key === 'Enter') {
|
||||||
|
@ -516,7 +553,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
.then(response => {
|
.then(response => {
|
||||||
// Render conversation history, if any
|
// Render conversation history, if any
|
||||||
response.forEach(chat_log => {
|
response.forEach(chat_log => {
|
||||||
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
|
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]);
|
||||||
});
|
});
|
||||||
})
|
})
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
|
@ -611,9 +648,15 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
.then(response => response.ok ? response.json() : Promise.reject(response))
|
.then(response => response.ok ? response.json() : Promise.reject(response))
|
||||||
.then(data => { chatInput.value += data.text; })
|
.then(data => { chatInput.value += data.text; })
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
err.status == 422
|
if (err.status === 501) {
|
||||||
? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
|
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
|
||||||
: flashStatusInChatInput("⛔️ Failed to transcribe audio")
|
} else if (err.status === 422) {
|
||||||
|
flashStatusInChatInput("⛔️ Audio file to large to process.")
|
||||||
|
} else if (err.status === 429) {
|
||||||
|
flashStatusInChatInput("⛔️ " + err.statusText);
|
||||||
|
} else {
|
||||||
|
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
|
||||||
|
}
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -902,6 +945,9 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
margin-top: -10px;
|
margin-top: -10px;
|
||||||
transform: rotate(-60deg)
|
transform: rotate(-60deg)
|
||||||
}
|
}
|
||||||
|
img.text-to-image {
|
||||||
|
max-width: 60%;
|
||||||
|
}
|
||||||
|
|
||||||
#chat-footer {
|
#chat-footer {
|
||||||
padding: 0;
|
padding: 0;
|
||||||
|
@ -916,7 +962,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
grid-template-columns: auto 32px 32px;
|
grid-template-columns: auto 32px 32px;
|
||||||
grid-column-gap: 10px;
|
grid-column-gap: 10px;
|
||||||
grid-row-gap: 10px;
|
grid-row-gap: 10px;
|
||||||
background: #f9fafc
|
background: var(--background-color);
|
||||||
}
|
}
|
||||||
.option:hover {
|
.option:hover {
|
||||||
box-shadow: 0 0 11px #aaa;
|
box-shadow: 0 0 11px #aaa;
|
||||||
|
@ -938,9 +984,10 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
}
|
}
|
||||||
.input-row-button {
|
.input-row-button {
|
||||||
background: var(--background-color);
|
background: var(--background-color);
|
||||||
border: none;
|
border: 1px solid var(--main-text-color);
|
||||||
|
box-shadow: 0 0 11px #aaa;
|
||||||
border-radius: 5px;
|
border-radius: 5px;
|
||||||
padding: 5px;
|
padding: 0px;
|
||||||
font-size: 14px;
|
font-size: 14px;
|
||||||
font-weight: 300;
|
font-weight: 300;
|
||||||
line-height: 1.5em;
|
line-height: 1.5em;
|
||||||
|
@ -1029,6 +1076,9 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
margin: 4px;
|
margin: 4px;
|
||||||
grid-template-columns: auto;
|
grid-template-columns: auto;
|
||||||
}
|
}
|
||||||
|
img.text-to-image {
|
||||||
|
max-width: 100%;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@media only screen and (min-width: 700px) {
|
@media only screen and (min-width: 700px) {
|
||||||
body {
|
body {
|
||||||
|
|
|
@ -47,7 +47,7 @@ def extract_questions_offline(
|
||||||
|
|
||||||
if use_history:
|
if use_history:
|
||||||
for chat in conversation_log.get("chat", [])[-4:]:
|
for chat in conversation_log.get("chat", [])[-4:]:
|
||||||
if chat["by"] == "khoj":
|
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image":
|
||||||
chat_history += f"Q: {chat['intent']['query']}\n"
|
chat_history += f"Q: {chat['intent']['query']}\n"
|
||||||
chat_history += f"A: {chat['message']}\n"
|
chat_history += f"A: {chat['message']}\n"
|
||||||
|
|
||||||
|
|
|
@ -12,11 +12,12 @@ def download_model(model_name: str):
|
||||||
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
|
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
# Download the chat model
|
|
||||||
chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
|
|
||||||
|
|
||||||
# Decide whether to load model to GPU or CPU
|
# Decide whether to load model to GPU or CPU
|
||||||
|
chat_model_config = None
|
||||||
try:
|
try:
|
||||||
|
# Download the chat model and its config
|
||||||
|
chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
|
||||||
|
|
||||||
# Try load chat model to GPU if:
|
# Try load chat model to GPU if:
|
||||||
# 1. Loading chat model to GPU isn't disabled via CLI and
|
# 1. Loading chat model to GPU isn't disabled via CLI and
|
||||||
# 2. Machine has GPU
|
# 2. Machine has GPU
|
||||||
|
@ -26,6 +27,12 @@ def download_model(model_name: str):
|
||||||
)
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
|
except Exception as e:
|
||||||
|
if chat_model_config is None:
|
||||||
|
device = "cpu" # Fallback to CPU as can't determine if GPU has enough memory
|
||||||
|
logger.debug(f"Unable to download model config from gpt4all website: {e}")
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
# Now load the downloaded chat model onto appropriate device
|
# Now load the downloaded chat model onto appropriate device
|
||||||
chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)
|
chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)
|
||||||
|
|
|
@ -41,7 +41,7 @@ def extract_questions(
|
||||||
[
|
[
|
||||||
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
|
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
|
||||||
for chat in conversation_log.get("chat", [])[-4:]
|
for chat in conversation_log.get("chat", [])[-4:]
|
||||||
if chat["by"] == "khoj"
|
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image"
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -123,8 +123,8 @@ def send_message_to_model(
|
||||||
|
|
||||||
def converse(
|
def converse(
|
||||||
references,
|
references,
|
||||||
online_results,
|
|
||||||
user_query,
|
user_query,
|
||||||
|
online_results=[],
|
||||||
conversation_log={},
|
conversation_log={},
|
||||||
model: str = "gpt-3.5-turbo",
|
model: str = "gpt-3.5-turbo",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
|
|
@ -36,11 +36,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
retry=(
|
retry=(
|
||||||
retry_if_exception_type(openai.error.Timeout)
|
retry_if_exception_type(openai._exceptions.APITimeoutError)
|
||||||
| retry_if_exception_type(openai.error.APIError)
|
| retry_if_exception_type(openai._exceptions.APIError)
|
||||||
| retry_if_exception_type(openai.error.APIConnectionError)
|
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
||||||
| retry_if_exception_type(openai.error.RateLimitError)
|
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
||||||
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
||||||
),
|
),
|
||||||
wait=wait_random_exponential(min=1, max=10),
|
wait=wait_random_exponential(min=1, max=10),
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
|
@ -57,11 +57,11 @@ def completion_with_backoff(**kwargs):
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
retry=(
|
retry=(
|
||||||
retry_if_exception_type(openai.error.Timeout)
|
retry_if_exception_type(openai._exceptions.APITimeoutError)
|
||||||
| retry_if_exception_type(openai.error.APIError)
|
| retry_if_exception_type(openai._exceptions.APIError)
|
||||||
| retry_if_exception_type(openai.error.APIConnectionError)
|
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
||||||
| retry_if_exception_type(openai.error.RateLimitError)
|
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
||||||
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
||||||
),
|
),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
|
|
|
@ -3,13 +3,13 @@ from io import BufferedReader
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
import openai
|
from openai import OpenAI
|
||||||
|
|
||||||
|
|
||||||
async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
|
async def transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str:
|
||||||
"""
|
"""
|
||||||
Transcribe audio file using Whisper model via OpenAI's API
|
Transcribe audio file using Whisper model via OpenAI's API
|
||||||
"""
|
"""
|
||||||
# Send the audio data to the Whisper API
|
# Send the audio data to the Whisper API
|
||||||
response = await sync_to_async(openai.Audio.translate)(model=model, file=audio_file, api_key=api_key)
|
response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file)
|
||||||
return response["text"]
|
return response.text
|
||||||
|
|
|
@ -15,6 +15,7 @@ You were created by Khoj Inc. with the following capabilities:
|
||||||
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
|
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
|
||||||
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
|
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
|
||||||
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
|
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
|
||||||
|
- Users can share information with you using the Khoj app, which is available for download at https://khoj.dev/downloads.
|
||||||
|
|
||||||
Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev.
|
Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev.
|
||||||
Today is {current_date} in UTC.
|
Today is {current_date} in UTC.
|
||||||
|
@ -109,6 +110,18 @@ Question: {query}
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## Image Generation
|
||||||
|
## --
|
||||||
|
|
||||||
|
image_generation_improve_prompt = PromptTemplate.from_template(
|
||||||
|
"""
|
||||||
|
Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Improved Query:"""
|
||||||
|
)
|
||||||
|
|
||||||
## Online Search Conversation
|
## Online Search Conversation
|
||||||
## --
|
## --
|
||||||
online_search_conversation = PromptTemplate.from_template(
|
online_search_conversation = PromptTemplate.from_template(
|
||||||
|
@ -295,10 +308,13 @@ Q:"""
|
||||||
# --
|
# --
|
||||||
help_message = PromptTemplate.from_template(
|
help_message = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
**/notes**: Chat using the information in your knowledge base.
|
- **/notes**: Chat using the information in your knowledge base.
|
||||||
**/general**: Chat using just Khoj's general knowledge. This will not search against your notes.
|
- **/general**: Chat using just Khoj's general knowledge. This will not search against your notes.
|
||||||
**/default**: Chat using your knowledge base and Khoj's general knowledge for context.
|
- **/default**: Chat using your knowledge base and Khoj's general knowledge for context.
|
||||||
**/help**: Show this help message.
|
- **/online**: Chat using the internet as a source of information.
|
||||||
|
- **/image**: Generate an image based on your message.
|
||||||
|
- **/help**: Show this help message.
|
||||||
|
|
||||||
|
|
||||||
You are using the **{model}** model on the **{device}**.
|
You are using the **{model}** model on the **{device}**.
|
||||||
**version**: {version}
|
**version**: {version}
|
||||||
|
|
|
@ -4,6 +4,7 @@ from time import perf_counter
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import queue
|
import queue
|
||||||
|
from typing import Any, Dict, List
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
# External packages
|
# External packages
|
||||||
|
@ -11,6 +12,8 @@ from langchain.schema import ChatMessage
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
|
from khoj.database.adapters import ConversationAdapters
|
||||||
|
from khoj.database.models import KhojUser
|
||||||
from khoj.utils.helpers import merge_dicts
|
from khoj.utils.helpers import merge_dicts
|
||||||
|
|
||||||
|
|
||||||
|
@ -89,6 +92,32 @@ def message_to_log(
|
||||||
return conversation_log
|
return conversation_log
|
||||||
|
|
||||||
|
|
||||||
|
def save_to_conversation_log(
    q: str,
    chat_response: str,
    user: KhojUser,
    meta_log: Dict,
    user_message_time: str = None,
    compiled_references: List[str] = None,
    online_results: Dict[str, Any] = None,
    inferred_queries: List[str] = None,
    intent_type: str = "remember",
):
    """
    Append the latest user message and Khoj response to the conversation log and save it.

    Args:
        q: The user's message.
        chat_response: The response to record for this turn.
        user: The user whose conversation is updated.
        meta_log: Existing conversation metadata; prior turns are read from its "chat" key.
        user_message_time: Timestamp string for the user message.
            Defaults to the current time formatted as "%Y-%m-%d %H:%M:%S".
        compiled_references: References stored under the response's "context" key. Defaults to an empty list.
        online_results: Results stored under the response's "onlineContext" key. Defaults to an empty dict.
        inferred_queries: Queries stored under the response's intent "inferred-queries" key. Defaults to an empty list.
        intent_type: Value stored under the response's intent "type" key.
    """
    # Create fresh containers on every call. Mutable default arguments are evaluated once
    # and would otherwise be shared between calls and embedded in every saved conversation.
    compiled_references = [] if compiled_references is None else compiled_references
    online_results = {} if online_results is None else online_results
    inferred_queries = [] if inferred_queries is None else inferred_queries

    user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    updated_conversation = message_to_log(
        user_message=q,
        chat_response=chat_response,
        user_message_metadata={"created": user_message_time},
        khoj_message_metadata={
            "context": compiled_references,
            "intent": {"inferred-queries": inferred_queries, "type": intent_type},
            "onlineContext": online_results,
        },
        conversation_log=meta_log.get("chat", []),
    )
    ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
|
||||||
|
|
||||||
|
|
||||||
def generate_chatml_messages_with_context(
|
def generate_chatml_messages_with_context(
|
||||||
user_message,
|
user_message,
|
||||||
system_message,
|
system_message,
|
||||||
|
|
|
@ -19,7 +19,7 @@ from starlette.authentication import requires
|
||||||
from khoj.configure import configure_server
|
from khoj.configure import configure_server
|
||||||
from khoj.database import adapters
|
from khoj.database import adapters
|
||||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_default_search_model
|
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_default_search_model
|
||||||
from khoj.database.models import ChatModelOptions
|
from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
|
||||||
from khoj.database.models import Entry as DbEntry
|
from khoj.database.models import Entry as DbEntry
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
GithubConfig,
|
GithubConfig,
|
||||||
|
@ -35,15 +35,18 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
||||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||||
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
||||||
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||||
|
from khoj.processor.conversation.utils import save_to_conversation_log
|
||||||
from khoj.processor.tools.online_search import search_with_google
|
from khoj.processor.tools.online_search import search_with_google
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
ApiUserRateLimiter,
|
ApiUserRateLimiter,
|
||||||
CommonQueryParams,
|
CommonQueryParams,
|
||||||
agenerate_chat_response,
|
agenerate_chat_response,
|
||||||
get_conversation_command,
|
get_conversation_command,
|
||||||
|
text_to_image,
|
||||||
is_ready_to_chat,
|
is_ready_to_chat,
|
||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
validate_conversation_config,
|
validate_conversation_config,
|
||||||
|
ConversationCommandRateLimiter,
|
||||||
)
|
)
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
|
@ -65,6 +68,7 @@ from khoj.utils.state import SearchType
|
||||||
# Initialize Router
|
# Initialize Router
|
||||||
api = APIRouter()
|
api = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100)
|
||||||
|
|
||||||
|
|
||||||
def map_config_to_object(content_source: str):
|
def map_config_to_object(content_source: str):
|
||||||
|
@ -603,7 +607,13 @@ async def chat_options(
|
||||||
|
|
||||||
@api.post("/transcribe")
|
@api.post("/transcribe")
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def transcribe(request: Request, common: CommonQueryParams, file: UploadFile = File(...)):
|
async def transcribe(
|
||||||
|
request: Request,
|
||||||
|
common: CommonQueryParams,
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60)),
|
||||||
|
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
||||||
|
):
|
||||||
user: KhojUser = request.user.object
|
user: KhojUser = request.user.object
|
||||||
audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
|
audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
|
||||||
user_message: str = None
|
user_message: str = None
|
||||||
|
@ -623,17 +633,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi
|
||||||
|
|
||||||
# Send the audio data to the Whisper API
|
# Send the audio data to the Whisper API
|
||||||
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
|
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
|
||||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
|
||||||
if not speech_to_text_config:
|
if not speech_to_text_config:
|
||||||
# If the user has not configured a speech to text model, return an unprocessable entity error
|
# If the user has not configured a speech to text model, return an unsupported on server error
|
||||||
status_code = 422
|
status_code = 501
|
||||||
elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
|
||||||
api_key = openai_chat_config.api_key
|
|
||||||
speech2text_model = speech_to_text_config.model_name
|
speech2text_model = speech_to_text_config.model_name
|
||||||
user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key)
|
user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client)
|
||||||
elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
|
||||||
speech2text_model = speech_to_text_config.model_name
|
speech2text_model = speech_to_text_config.model_name
|
||||||
user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model)
|
user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
|
||||||
finally:
|
finally:
|
||||||
# Close and Delete the temporary audio file
|
# Close and Delete the temporary audio file
|
||||||
audio_file.close()
|
audio_file.close()
|
||||||
|
@ -666,11 +674,13 @@ async def chat(
|
||||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
|
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
|
||||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
user = request.user.object
|
user: KhojUser = request.user.object
|
||||||
|
|
||||||
await is_ready_to_chat(user)
|
await is_ready_to_chat(user)
|
||||||
conversation_command = get_conversation_command(query=q, any_references=True)
|
conversation_command = get_conversation_command(query=q, any_references=True)
|
||||||
|
|
||||||
|
conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
|
||||||
|
|
||||||
q = q.replace(f"/{conversation_command.value}", "").strip()
|
q = q.replace(f"/{conversation_command.value}", "").strip()
|
||||||
|
|
||||||
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
||||||
|
@ -704,6 +714,27 @@ async def chat(
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
status_code=200,
|
status_code=200,
|
||||||
)
|
)
|
||||||
|
elif conversation_command == ConversationCommand.Image:
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="chat",
|
||||||
|
metadata={"conversation_command": conversation_command.value},
|
||||||
|
**common.__dict__,
|
||||||
|
)
|
||||||
|
image, status_code, improved_image_prompt = await text_to_image(q)
|
||||||
|
if image is None:
|
||||||
|
content_obj = {
|
||||||
|
"image": image,
|
||||||
|
"intentType": "text-to-image",
|
||||||
|
"detail": "Failed to generate image. Make sure your image generation configuration is set.",
|
||||||
|
}
|
||||||
|
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
||||||
|
await sync_to_async(save_to_conversation_log)(
|
||||||
|
q, image, user, meta_log, intent_type="text-to-image", inferred_queries=[improved_image_prompt]
|
||||||
|
)
|
||||||
|
content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]} # type: ignore
|
||||||
|
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
||||||
|
|
||||||
# Get the (streamed) chat response from the LLM of choice.
|
# Get the (streamed) chat response from the LLM of choice.
|
||||||
llm_response, chat_metadata = await agenerate_chat_response(
|
llm_response, chat_metadata = await agenerate_chat_response(
|
||||||
|
@ -787,7 +818,6 @@ async def extract_references_and_questions(
|
||||||
conversation_config = await ConversationAdapters.aget_conversation_config(user)
|
conversation_config = await ConversationAdapters.aget_conversation_config(user)
|
||||||
if conversation_config is None:
|
if conversation_config is None:
|
||||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||||
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
|
|
||||||
if (
|
if (
|
||||||
offline_chat_config
|
offline_chat_config
|
||||||
and offline_chat_config.enabled
|
and offline_chat_config.enabled
|
||||||
|
@ -804,7 +834,7 @@ async def extract_references_and_questions(
|
||||||
inferred_queries = extract_questions_offline(
|
inferred_queries = extract_questions_offline(
|
||||||
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
||||||
)
|
)
|
||||||
elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||||
openai_chat = await ConversationAdapters.get_openai_chat()
|
openai_chat = await ConversationAdapters.get_openai_chat()
|
||||||
api_key = openai_chat_config.api_key
|
api_key = openai_chat_config.api_key
|
||||||
|
|
|
@ -9,23 +9,23 @@ from functools import partial
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
# External Packages
|
||||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||||
|
import openai
|
||||||
from starlette.authentication import has_required_scope
|
from starlette.authentication import has_required_scope
|
||||||
from asgiref.sync import sync_to_async
|
|
||||||
|
|
||||||
|
|
||||||
|
# Internal Packages
|
||||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||||
from khoj.database.models import KhojUser, Subscription
|
from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
|
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
|
||||||
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
|
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
|
||||||
from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log
|
from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
|
||||||
|
|
||||||
# Internal Packages
|
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.config import GPT4AllProcessorModel
|
from khoj.utils.config import GPT4AllProcessorModel
|
||||||
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
executor = ThreadPoolExecutor(max_workers=1)
|
executor = ThreadPoolExecutor(max_workers=1)
|
||||||
|
@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
|
||||||
return ConversationCommand.General
|
return ConversationCommand.General
|
||||||
elif query.startswith("/online"):
|
elif query.startswith("/online"):
|
||||||
return ConversationCommand.Online
|
return ConversationCommand.Online
|
||||||
|
elif query.startswith("/image"):
|
||||||
|
return ConversationCommand.Image
|
||||||
# If no relevant notes found for the given query
|
# If no relevant notes found for the given query
|
||||||
elif not any_references:
|
elif not any_references:
|
||||||
return ConversationCommand.General
|
return ConversationCommand.General
|
||||||
|
@ -144,6 +146,20 @@ async def generate_online_subqueries(q: str) -> List[str]:
|
||||||
return [q]
|
return [q]
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_better_image_prompt(q: str) -> str:
    """
    Ask the chat model to rewrite the given query into an improved image generation prompt.

    The query is inserted into the `image_generation_improve_prompt` template,
    sent to the chat model, and the model's reply is returned with
    surrounding whitespace removed.
    """
    prompt_for_model = prompts.image_generation_improve_prompt.format(query=q)
    improved_prompt = await send_message_to_model_wrapper(prompt_for_model)
    return improved_prompt.strip()
|
||||||
|
|
||||||
|
|
||||||
async def send_message_to_model_wrapper(
|
async def send_message_to_model_wrapper(
|
||||||
message: str,
|
message: str,
|
||||||
):
|
):
|
||||||
|
@ -168,11 +184,13 @@ async def send_message_to_model_wrapper(
|
||||||
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
|
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
|
||||||
api_key = openai_chat_config.api_key
|
api_key = openai_chat_config.api_key
|
||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
return send_message_to_model(
|
openai_response = send_message_to_model(
|
||||||
message=message,
|
message=message,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return openai_response.content
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||||
|
|
||||||
|
@ -186,30 +204,7 @@ def generate_chat_response(
|
||||||
conversation_command: ConversationCommand = ConversationCommand.Default,
|
conversation_command: ConversationCommand = ConversationCommand.Default,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||||
def _save_to_conversation_log(
|
|
||||||
q: str,
|
|
||||||
chat_response: str,
|
|
||||||
user_message_time: str,
|
|
||||||
compiled_references: List[str],
|
|
||||||
online_results: Dict[str, Any],
|
|
||||||
inferred_queries: List[str],
|
|
||||||
meta_log,
|
|
||||||
):
|
|
||||||
updated_conversation = message_to_log(
|
|
||||||
user_message=q,
|
|
||||||
chat_response=chat_response,
|
|
||||||
user_message_metadata={"created": user_message_time},
|
|
||||||
khoj_message_metadata={
|
|
||||||
"context": compiled_references,
|
|
||||||
"intent": {"inferred-queries": inferred_queries},
|
|
||||||
"onlineContext": online_results,
|
|
||||||
},
|
|
||||||
conversation_log=meta_log.get("chat", []),
|
|
||||||
)
|
|
||||||
ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
|
|
||||||
|
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
chat_response = None
|
chat_response = None
|
||||||
logger.debug(f"Conversation Type: {conversation_command.name}")
|
logger.debug(f"Conversation Type: {conversation_command.name}")
|
||||||
|
|
||||||
|
@ -217,13 +212,13 @@ def generate_chat_response(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
partial_completion = partial(
|
partial_completion = partial(
|
||||||
_save_to_conversation_log,
|
save_to_conversation_log,
|
||||||
q,
|
q,
|
||||||
user_message_time=user_message_time,
|
user=user,
|
||||||
|
meta_log=meta_log,
|
||||||
compiled_references=compiled_references,
|
compiled_references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
inferred_queries=inferred_queries,
|
inferred_queries=inferred_queries,
|
||||||
meta_log=meta_log,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user)
|
conversation_config = ConversationAdapters.get_valid_conversation_config(user)
|
||||||
|
@ -251,9 +246,9 @@ def generate_chat_response(
|
||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
chat_response = converse(
|
chat_response = converse(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
online_results,
|
|
||||||
q,
|
q,
|
||||||
meta_log,
|
online_results=online_results,
|
||||||
|
conversation_log=meta_log,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
|
@ -271,6 +266,29 @@ def generate_chat_response(
|
||||||
return chat_response, metadata
|
return chat_response, metadata
|
||||||
|
|
||||||
|
|
||||||
|
async def text_to_image(message: str) -> Tuple[Optional[str], int, Optional[str]]:
    """
    Generate an image from the given message using the configured text to image model.

    Returns:
        A tuple of (image, status_code, improved_image_prompt):
        - image: Base64 encoded image data, or None if no image was generated.
        - status_code: 200 by default, 501 when no text to image model is configured,
          500 when the OpenAI image generation request fails.
        - improved_image_prompt: The model-improved prompt used to generate the image,
          or None when the generation step was not reached.
    """
    status_code = 200
    image = None
    # Initialize up front so the return below is valid on every code path.
    # Previously this was only bound inside the OpenAI branch, so the 501 path raised UnboundLocalError.
    improved_image_prompt = None

    text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
    if not text_to_image_config:
        # If the user has not configured a text to image model, return an unsupported on server error
        status_code = 501
    elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
        text2image_model = text_to_image_config.model_name
        improved_image_prompt = await generate_better_image_prompt(message)
        try:
            response = state.openai_client.images.generate(
                prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
            )
            image = response.data[0].b64_json
        except openai.OpenAIError as e:
            logger.error(f"Image Generation failed with {e}", exc_info=True)
            status_code = 500
    # NOTE(review): when a config exists but there is no OpenAI client or the model type is not OpenAI,
    # this returns (None, 200, None). Confirm whether callers expect a non-200 status in that case.

    return image, status_code, improved_image_prompt
|
||||||
|
|
||||||
|
|
||||||
class ApiUserRateLimiter:
|
class ApiUserRateLimiter:
|
||||||
def __init__(self, requests: int, subscribed_requests: int, window: int):
|
def __init__(self, requests: int, subscribed_requests: int, window: int):
|
||||||
self.requests = requests
|
self.requests = requests
|
||||||
|
@ -298,6 +316,40 @@ class ApiUserRateLimiter:
|
||||||
user_requests.append(time())
|
user_requests.append(time())
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationCommandRateLimiter:
    """
    Limit how often each user can run restricted conversation commands within a rolling 24 hour window.

    Request timestamps are kept in memory, per user and per command.
    Only the Online and Image commands are restricted.
    """

    def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int):
        self.cache: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
        self.trial_rate_limit = trial_rate_limit
        self.subscribed_rate_limit = subscribed_rate_limit
        self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image]

    def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
        """
        Record this request and raise HTTP 429 if the user exceeded their limit for the command.

        Does nothing when billing is disabled, the user is not authenticated
        or the command is not restricted.
        """
        if (
            state.billing_enabled is False
            or not request.user.is_authenticated
            or conversation_command not in self.restricted_commands
        ):
            return

        current_user: KhojUser = request.user.object
        commands_by_user = self.cache[current_user.uuid]
        subscribed = has_required_scope(request, ["premium"])

        # Record the current request before evaluating the limit
        timestamps = commands_by_user[conversation_command]
        timestamps.append(time())

        # Remove requests outside of the 24-hr time window
        cutoff = time() - 60 * 60 * 24
        stale_count = 0
        for timestamp in timestamps:
            if timestamp >= cutoff:
                break
            stale_count += 1
        del timestamps[:stale_count]

        if subscribed:
            limit, detail = self.subscribed_rate_limit, "Too Many Requests"
        else:
            limit, detail = self.trial_rate_limit, "Too Many Requests. Subscribe to increase your rate limit."

        if len(timestamps) > limit:
            raise HTTPException(status_code=429, detail=detail)
        return
|
||||||
|
|
||||||
|
|
||||||
class ApiIndexedDataLimiter:
|
class ApiIndexedDataLimiter:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -315,7 +367,7 @@ class ApiIndexedDataLimiter:
|
||||||
if state.billing_enabled is False:
|
if state.billing_enabled is False:
|
||||||
return
|
return
|
||||||
subscribed = has_required_scope(request, ["premium"])
|
subscribed = has_required_scope(request, ["premium"])
|
||||||
incoming_data_size_mb = 0
|
incoming_data_size_mb = 0.0
|
||||||
deletion_file_names = set()
|
deletion_file_names = set()
|
||||||
|
|
||||||
if not request.user.is_authenticated:
|
if not request.user.is_authenticated:
|
||||||
|
|
|
@ -273,6 +273,7 @@ class ConversationCommand(str, Enum):
|
||||||
Notes = "notes"
|
Notes = "notes"
|
||||||
Help = "help"
|
Help = "help"
|
||||||
Online = "online"
|
Online = "online"
|
||||||
|
Image = "image"
|
||||||
|
|
||||||
|
|
||||||
command_descriptions = {
|
command_descriptions = {
|
||||||
|
@ -280,6 +281,7 @@ command_descriptions = {
|
||||||
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
|
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
|
||||||
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
|
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
|
||||||
ConversationCommand.Online: "Look up information on the internet.",
|
ConversationCommand.Online: "Look up information on the internet.",
|
||||||
|
ConversationCommand.Image: "Generate images by describing your imagination in words.",
|
||||||
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
|
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ from khoj.database.models import (
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
SpeechToTextModelOptions,
|
SpeechToTextModelOptions,
|
||||||
|
TextToImageModelConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khoj.utils.constants import default_offline_chat_model, default_online_chat_model
|
from khoj.utils.constants import default_offline_chat_model, default_online_chat_model
|
||||||
|
@ -103,6 +104,15 @@ def initialization():
|
||||||
model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
|
model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
|
||||||
)
|
)
|
||||||
|
|
||||||
|
default_text_to_image_model = "dall-e-3"
|
||||||
|
openai_text_to_image_model = input(
|
||||||
|
f"Enter the OpenAI text to image model you want to use (default: {default_text_to_image_model}): "
|
||||||
|
)
|
||||||
|
openai_speech2text_model = openai_text_to_image_model or default_text_to_image_model
|
||||||
|
TextToImageModelConfig.objects.create(
|
||||||
|
model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI
|
||||||
|
)
|
||||||
|
|
||||||
if use_offline_model == "y" or use_openai_model == "y":
|
if use_offline_model == "y" or use_openai_model == "y":
|
||||||
logger.info("🗣️ Chat model configuration complete")
|
logger.info("🗣️ Chat model configuration complete")
|
||||||
|
|
||||||
|
|
|
@ -22,17 +22,9 @@ class BaseEncoder(ABC):
|
||||||
|
|
||||||
|
|
||||||
class OpenAI(BaseEncoder):
|
class OpenAI(BaseEncoder):
|
||||||
def __init__(self, model_name, device=None):
|
def __init__(self, model_name, client: openai.OpenAI, device=None):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
if (
|
self.openai_client = client
|
||||||
not state.processor_config
|
|
||||||
or not state.processor_config.conversation
|
|
||||||
or not state.processor_config.conversation.openai_model
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
|
|
||||||
)
|
|
||||||
openai.api_key = state.processor_config.conversation.openai_model.api_key
|
|
||||||
self.embedding_dimensions = None
|
self.embedding_dimensions = None
|
||||||
|
|
||||||
def encode(self, entries, device=None, **kwargs):
|
def encode(self, entries, device=None, **kwargs):
|
||||||
|
@ -43,7 +35,7 @@ class OpenAI(BaseEncoder):
|
||||||
processed_entry = entries[index].replace("\n", " ")
|
processed_entry = entries[index].replace("\n", " ")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = openai.Embedding.create(input=processed_entry, model=self.model_name)
|
response = self.openai_client.embeddings.create(input=processed_entry, model=self.model_name)
|
||||||
embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)]
|
embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)]
|
||||||
# Use current models embedding dimension, once available
|
# Use current models embedding dimension, once available
|
||||||
# Else default to embedding dimensions of the text-embedding-ada-002 model
|
# Else default to embedding dimensions of the text-embedding-ada-002 model
|
||||||
|
|
|
@ -1,15 +1,16 @@
|
||||||
# Standard Packages
|
# Standard Packages
|
||||||
|
from collections import defaultdict
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
import threading
|
import threading
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from pathlib import Path
|
from openai import OpenAI
|
||||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
|
||||||
from whisper import Whisper
|
from whisper import Whisper
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
|
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||||
from khoj.utils import config as utils_config
|
from khoj.utils import config as utils_config
|
||||||
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
|
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
|
||||||
from khoj.utils.helpers import LRU, get_device
|
from khoj.utils.helpers import LRU, get_device
|
||||||
|
@ -21,6 +22,7 @@ search_models = SearchModels()
|
||||||
embeddings_model: Dict[str, EmbeddingsModel] = None
|
embeddings_model: Dict[str, EmbeddingsModel] = None
|
||||||
cross_encoder_model: CrossEncoderModel = None
|
cross_encoder_model: CrossEncoderModel = None
|
||||||
content_index = ContentIndex()
|
content_index = ContentIndex()
|
||||||
|
openai_client: OpenAI = None
|
||||||
gpt4all_processor_config: GPT4AllProcessorModel = None
|
gpt4all_processor_config: GPT4AllProcessorModel = None
|
||||||
whisper_model: Whisper = None
|
whisper_model: Whisper = None
|
||||||
config_file: Path = None
|
config_file: Path = None
|
||||||
|
|
|
@ -68,10 +68,10 @@ def test_chat_with_online_content(chat_client):
|
||||||
response_message = response_message.split("### compiled references")[0]
|
response_message = response_message.split("### compiled references")[0]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["http://www.paulgraham.com/greatwork.html"]
|
expected_responses = ["http://www.paulgraham.com/greatwork.html", "Please set your SERPER_DEV_API_KEY"]
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||||
"Expected assistants name, [K|k]hoj, in response but got: " + response_message
|
"Expected links or serper not setup in response but got: " + response_message
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue