mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
⚡️ Stream Responses by Khoj Chat on Web, Obsidian
- What - Stream chat responses from OpenAI API to Web, Obsidian clients - Implement using a callback function which manages a queue where new tokens can be placed as they come on. As the thread is read from, tokens are removed. - When the final token has been processed, add the `compiled_references` to the queue to be rendered by the `chat` client - When the thread has been closed, save the accumulated conversation log in the user's history using a `partial func` - Incrementally decode tokens on the front end and add them as they appear from the streamed response - Why This significantly reduces perceived latency and OpenAI API request timeouts for Chat Closes https://github.com/khoj-ai/khoj/issues/257
This commit is contained in:
commit
6c2a8a5bce
9 changed files with 393 additions and 55 deletions
|
@ -8,7 +8,12 @@
|
|||
"build": "tsc -noEmit -skipLibCheck && node esbuild.config.mjs production",
|
||||
"version": "node version-bump.mjs && git add manifest.json versions.json"
|
||||
},
|
||||
"keywords": ["search", "chat", "AI", "assistant"],
|
||||
"keywords": [
|
||||
"search",
|
||||
"chat",
|
||||
"AI",
|
||||
"assistant"
|
||||
],
|
||||
"author": "Debanjum Singh Solanky",
|
||||
"license": "GPL-3.0-or-later",
|
||||
"devDependencies": {
|
||||
|
@ -20,5 +25,9 @@
|
|||
"obsidian": "latest",
|
||||
"tslib": "2.4.0",
|
||||
"typescript": "4.7.4"
|
||||
},
|
||||
"dependencies": {
|
||||
"@types/node-fetch": "^2.6.4",
|
||||
"node-fetch": "3.0.0"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import { App, Modal, request, Setting } from 'obsidian';
|
||||
import { KhojSetting } from 'src/settings';
|
||||
|
||||
import fetch from "node-fetch";
|
||||
|
||||
export class KhojChatModal extends Modal {
|
||||
result: string;
|
||||
|
@ -34,13 +34,8 @@ export class KhojChatModal extends Modal {
|
|||
// Create area for chat logs
|
||||
contentEl.createDiv({ attr: { id: "khoj-chat-body", class: "khoj-chat-body" } });
|
||||
|
||||
// Get conversation history from Khoj backend
|
||||
let chatUrl = `${this.setting.khojUrl}/api/chat?client=obsidian`;
|
||||
let response = await request(chatUrl);
|
||||
let chatLogs = JSON.parse(response).response;
|
||||
chatLogs.forEach((chatLog: any) => {
|
||||
this.renderMessageWithReferences(chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created));
|
||||
});
|
||||
// Get chat history from Khoj backend
|
||||
await this.getChatHistory();
|
||||
|
||||
// Add chat input field
|
||||
contentEl.createEl("input",
|
||||
|
@ -107,6 +102,35 @@ export class KhojChatModal extends Modal {
|
|||
return chat_message_el
|
||||
}
|
||||
|
||||
createKhojResponseDiv(dt?: Date): HTMLDivElement {
|
||||
let message_time = this.formatDate(dt ?? new Date());
|
||||
|
||||
// Append message to conversation history HTML element.
|
||||
// The chat logs should display above the message input box to follow standard UI semantics
|
||||
let chat_body_el = this.contentEl.getElementsByClassName("khoj-chat-body")[0];
|
||||
let chat_message_el = chat_body_el.createDiv({
|
||||
attr: {
|
||||
"data-meta": `🏮 Khoj at ${message_time}`,
|
||||
class: `khoj-chat-message khoj`
|
||||
},
|
||||
}).createDiv({
|
||||
attr: {
|
||||
class: `khoj-chat-message-text khoj`
|
||||
},
|
||||
})
|
||||
|
||||
// Scroll to bottom after inserting chat messages
|
||||
this.modalEl.scrollTop = this.modalEl.scrollHeight;
|
||||
|
||||
return chat_message_el
|
||||
}
|
||||
|
||||
renderIncrementalMessage(htmlElement: HTMLDivElement, additionalMessage: string) {
|
||||
htmlElement.innerHTML += additionalMessage;
|
||||
// Scroll to bottom of modal, till the send message input box
|
||||
this.modalEl.scrollTop = this.modalEl.scrollHeight;
|
||||
}
|
||||
|
||||
formatDate(date: Date): string {
|
||||
// Format date in HH:MM, DD MMM YYYY format
|
||||
let time_string = date.toLocaleTimeString('en-IN', { hour: '2-digit', minute: '2-digit', hour12: false });
|
||||
|
@ -114,6 +138,17 @@ export class KhojChatModal extends Modal {
|
|||
return `${time_string}, ${date_string}`;
|
||||
}
|
||||
|
||||
|
||||
async getChatHistory(): Promise<void> {
|
||||
// Get chat history from Khoj backend
|
||||
let chatUrl = `${this.setting.khojUrl}/api/chat/init?client=obsidian`;
|
||||
let response = await request(chatUrl);
|
||||
let chatLogs = JSON.parse(response).response;
|
||||
chatLogs.forEach((chatLog: any) => {
|
||||
this.renderMessageWithReferences(chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created));
|
||||
});
|
||||
}
|
||||
|
||||
async getChatResponse(query: string | undefined | null): Promise<void> {
|
||||
// Exit if query is empty
|
||||
if (!query || query === "") return;
|
||||
|
@ -124,10 +159,37 @@ export class KhojChatModal extends Modal {
|
|||
// Get chat response from Khoj backend
|
||||
let encodedQuery = encodeURIComponent(query);
|
||||
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&client=obsidian`;
|
||||
let response = await request(chatUrl);
|
||||
let data = JSON.parse(response);
|
||||
let responseElement = this.createKhojResponseDiv();
|
||||
|
||||
// Render Khoj response as chat message
|
||||
this.renderMessage(data.response, "khoj");
|
||||
// Temporary status message to indicate that Khoj is thinking
|
||||
this.renderIncrementalMessage(responseElement, "🤔");
|
||||
|
||||
let response = await fetch(chatUrl, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Content-Type": "text/event-stream"
|
||||
},
|
||||
})
|
||||
|
||||
try {
|
||||
if (response.body == null) {
|
||||
throw new Error("Response body is null");
|
||||
}
|
||||
// Clear thinking status message
|
||||
if (responseElement.innerHTML === "🤔") {
|
||||
responseElement.innerHTML = "";
|
||||
}
|
||||
|
||||
for await (const chunk of response.body) {
|
||||
const responseText = chunk.toString();
|
||||
if (responseText.startsWith("### compiled references:")) {
|
||||
return;
|
||||
}
|
||||
this.renderIncrementalMessage(responseElement, responseText);
|
||||
}
|
||||
} catch (err) {
|
||||
this.renderIncrementalMessage(responseElement, "Sorry, unable to get response from Khoj backend ❤️🩹. Contact developer for help at team@khoj.dev or <a href='https://discord.gg/BDgyabRM6e'>in Discord</a>")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,6 +40,19 @@
|
|||
resolved "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.11.tgz"
|
||||
integrity sha512-wOuvG1SN4Us4rez+tylwwwCV1psiNVOkJeM3AUWUNWg/jDQY2+HE/444y5gc+jBmRqASOm2Oeh5c1axHobwRKQ==
|
||||
|
||||
"@types/node-fetch@^2.6.4":
|
||||
version "2.6.4"
|
||||
resolved "https://registry.yarnpkg.com/@types/node-fetch/-/node-fetch-2.6.4.tgz#1bc3a26de814f6bf466b25aeb1473fa1afe6a660"
|
||||
integrity sha512-1ZX9fcN4Rvkvgv4E6PAY5WXUFWFcRWxZa3EW83UjycOB9ljJCedb2CupIP4RZMEwF/M3eTcCihbBRgwtGbg5Rg==
|
||||
dependencies:
|
||||
"@types/node" "*"
|
||||
form-data "^3.0.0"
|
||||
|
||||
"@types/node@*":
|
||||
version "20.3.3"
|
||||
resolved "https://registry.yarnpkg.com/@types/node/-/node-20.3.3.tgz#329842940042d2b280897150e023e604d11657d6"
|
||||
integrity sha512-wheIYdr4NYML61AjC8MKj/2jrR/kDQri/CIpVoZwldwhnIrD/j9jIU5bJ8yBKuB2VhpFV7Ab6G2XkBjv9r9Zzw==
|
||||
|
||||
"@types/node@^16.11.6":
|
||||
version "16.18.12"
|
||||
resolved "https://registry.npmjs.org/@types/node/-/node-16.18.12.tgz"
|
||||
|
@ -137,6 +150,11 @@ array-union@^2.1.0:
|
|||
resolved "https://registry.npmjs.org/array-union/-/array-union-2.1.0.tgz"
|
||||
integrity sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw==
|
||||
|
||||
asynckit@^0.4.0:
|
||||
version "0.4.0"
|
||||
resolved "https://registry.yarnpkg.com/asynckit/-/asynckit-0.4.0.tgz#c79ed97f7f34cb8f2ba1bc9790bcc366474b4b79"
|
||||
integrity sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==
|
||||
|
||||
braces@^3.0.2:
|
||||
version "3.0.2"
|
||||
resolved "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz"
|
||||
|
@ -149,6 +167,18 @@ builtin-modules@3.3.0:
|
|||
resolved "https://registry.npmjs.org/builtin-modules/-/builtin-modules-3.3.0.tgz"
|
||||
integrity sha512-zhaCDicdLuWN5UbN5IMnFqNMhNfo919sH85y2/ea+5Yg9TsTkeZxpL+JLbp6cgYFS4sRLp3YV4S6yDuqVWHYOw==
|
||||
|
||||
combined-stream@^1.0.8:
|
||||
version "1.0.8"
|
||||
resolved "https://registry.yarnpkg.com/combined-stream/-/combined-stream-1.0.8.tgz#c3d45a8b34fd730631a110a8a2520682b31d5a7f"
|
||||
integrity sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==
|
||||
dependencies:
|
||||
delayed-stream "~1.0.0"
|
||||
|
||||
data-uri-to-buffer@^3.0.1:
|
||||
version "3.0.1"
|
||||
resolved "https://registry.yarnpkg.com/data-uri-to-buffer/-/data-uri-to-buffer-3.0.1.tgz#594b8973938c5bc2c33046535785341abc4f3636"
|
||||
integrity sha512-WboRycPNsVw3B3TL559F7kuBUM4d8CgMEvk6xEJlOp7OBPjt6G7z8WMWlD2rOFZLk6OYfFIUGsCOWzcQH9K2og==
|
||||
|
||||
debug@^4.3.4:
|
||||
version "4.3.4"
|
||||
resolved "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz"
|
||||
|
@ -156,6 +186,11 @@ debug@^4.3.4:
|
|||
dependencies:
|
||||
ms "2.1.2"
|
||||
|
||||
delayed-stream@~1.0.0:
|
||||
version "1.0.0"
|
||||
resolved "https://registry.yarnpkg.com/delayed-stream/-/delayed-stream-1.0.0.tgz#df3ae199acadfb7d440aaae0b29e2272b24ec619"
|
||||
integrity sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==
|
||||
|
||||
dir-glob@^3.0.1:
|
||||
version "3.0.1"
|
||||
resolved "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz"
|
||||
|
@ -349,6 +384,14 @@ fastq@^1.6.0:
|
|||
dependencies:
|
||||
reusify "^1.0.4"
|
||||
|
||||
fetch-blob@^3.1.2:
|
||||
version "3.2.0"
|
||||
resolved "https://registry.yarnpkg.com/fetch-blob/-/fetch-blob-3.2.0.tgz#f09b8d4bbd45adc6f0c20b7e787e793e309dcce9"
|
||||
integrity sha512-7yAQpD2UMJzLi1Dqv7qFYnPbaPx7ZfFK6PiIxQ4PfkGPyNyl2Ugx+a/umUonmKqjhM4DnfbMvdX6otXq83soQQ==
|
||||
dependencies:
|
||||
node-domexception "^1.0.0"
|
||||
web-streams-polyfill "^3.0.3"
|
||||
|
||||
fill-range@^7.0.1:
|
||||
version "7.0.1"
|
||||
resolved "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz"
|
||||
|
@ -356,6 +399,15 @@ fill-range@^7.0.1:
|
|||
dependencies:
|
||||
to-regex-range "^5.0.1"
|
||||
|
||||
form-data@^3.0.0:
|
||||
version "3.0.1"
|
||||
resolved "https://registry.yarnpkg.com/form-data/-/form-data-3.0.1.tgz#ebd53791b78356a99af9a300d4282c4d5eb9755f"
|
||||
integrity sha512-RHkBKtLWUVwd7SqRIvCZMEvAMoGUp0XU+seQiZejj0COz3RI3hWP4sCv3gZWWLjJTd7rGwcsF5eKZGii0r/hbg==
|
||||
dependencies:
|
||||
asynckit "^0.4.0"
|
||||
combined-stream "^1.0.8"
|
||||
mime-types "^2.1.12"
|
||||
|
||||
functional-red-black-tree@^1.0.1:
|
||||
version "1.0.1"
|
||||
resolved "https://registry.npmjs.org/functional-red-black-tree/-/functional-red-black-tree-1.0.1.tgz"
|
||||
|
@ -422,6 +474,18 @@ micromatch@^4.0.4:
|
|||
braces "^3.0.2"
|
||||
picomatch "^2.3.1"
|
||||
|
||||
mime-db@1.52.0:
|
||||
version "1.52.0"
|
||||
resolved "https://registry.yarnpkg.com/mime-db/-/mime-db-1.52.0.tgz#bbabcdc02859f4987301c856e3387ce5ec43bf70"
|
||||
integrity sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==
|
||||
|
||||
mime-types@^2.1.12:
|
||||
version "2.1.35"
|
||||
resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.35.tgz#381a871b62a734450660ae3deee44813f70d959a"
|
||||
integrity sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==
|
||||
dependencies:
|
||||
mime-db "1.52.0"
|
||||
|
||||
moment@2.29.4:
|
||||
version "2.29.4"
|
||||
resolved "https://registry.npmjs.org/moment/-/moment-2.29.4.tgz"
|
||||
|
@ -432,6 +496,19 @@ ms@2.1.2:
|
|||
resolved "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz"
|
||||
integrity sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==
|
||||
|
||||
node-domexception@^1.0.0:
|
||||
version "1.0.0"
|
||||
resolved "https://registry.yarnpkg.com/node-domexception/-/node-domexception-1.0.0.tgz#6888db46a1f71c0b76b3f7555016b63fe64766e5"
|
||||
integrity sha512-/jKZoMpw0F8GRwl4/eLROPA3cfcXtLApP0QzLmUT/HuPCZWyB7IY9ZrMeKw2O/nFIqPQB3PVM9aYm0F312AXDQ==
|
||||
|
||||
node-fetch@3.0.0:
|
||||
version "3.0.0"
|
||||
resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-3.0.0.tgz#79da7146a520036f2c5f644e4a26095f17e411ea"
|
||||
integrity sha512-bKMI+C7/T/SPU1lKnbQbwxptpCrG9ashG+VkytmXCPZyuM9jB6VU+hY0oi4lC8LxTtAeWdckNCTa3nrGsAdA3Q==
|
||||
dependencies:
|
||||
data-uri-to-buffer "^3.0.1"
|
||||
fetch-blob "^3.1.2"
|
||||
|
||||
obsidian@latest:
|
||||
version "1.1.1"
|
||||
resolved "https://registry.npmjs.org/obsidian/-/obsidian-1.1.1.tgz"
|
||||
|
@ -513,6 +590,11 @@ typescript@4.7.4:
|
|||
resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.7.4.tgz#1a88596d1cf47d59507a1bcdfb5b9dfe4d488235"
|
||||
integrity sha512-C0WQT0gezHuw6AdY1M2jxUO83Rjf0HP7Sk1DtXj6j1EwkQNZrHAg2XPWlq62oqEhYvONq5pkC2Y9oPljWToLmQ==
|
||||
|
||||
web-streams-polyfill@^3.0.3:
|
||||
version "3.2.1"
|
||||
resolved "https://registry.yarnpkg.com/web-streams-polyfill/-/web-streams-polyfill-3.2.1.tgz#71c2718c52b45fd49dbeee88634b3a60ceab42a6"
|
||||
integrity sha512-e0MO3wdXWKrLbL0DgGnUV7WHVuw9OUvL4hjgnPkIeEvESk74gAITi5G606JtZPp39cd8HA9VQzCIvA49LpPN5Q==
|
||||
|
||||
yallist@^4.0.0:
|
||||
version "4.0.0"
|
||||
resolved "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz"
|
||||
|
|
|
@ -253,6 +253,13 @@ def upload_telemetry():
|
|||
|
||||
try:
|
||||
logger.debug(f"📡 Upload usage telemetry to {constants.telemetry_server}:\n{state.telemetry}")
|
||||
for log in state.telemetry:
|
||||
for field in log:
|
||||
# Check if the value for the field is JSON serializable
|
||||
try:
|
||||
json.dumps(log[field])
|
||||
except TypeError:
|
||||
log[field] = str(log[field])
|
||||
requests.post(constants.telemetry_server, json=state.telemetry)
|
||||
except Exception as e:
|
||||
logger.error(f"📡 Error uploading telemetry: {e}")
|
||||
|
|
|
@ -64,13 +64,57 @@
|
|||
// Generate backend API URL to execute query
|
||||
let url = `/api/chat?q=${encodeURIComponent(query)}&client=web`;
|
||||
|
||||
// Call specified Khoj API
|
||||
let chat_body = document.getElementById("chat-body");
|
||||
let new_response = document.createElement("div");
|
||||
new_response.classList.add("chat-message", "khoj");
|
||||
new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
|
||||
chat_body.appendChild(new_response);
|
||||
|
||||
let new_response_text = document.createElement("div");
|
||||
new_response_text.classList.add("chat-message-text", "khoj");
|
||||
new_response.appendChild(new_response_text);
|
||||
|
||||
// Temporary status message to indicate that Khoj is thinking
|
||||
new_response_text.innerHTML = "🤔";
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
|
||||
// Call specified Khoj API which returns a streamed response of type text/plain
|
||||
fetch(url)
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
// Render message by Khoj to chat body
|
||||
console.log(data.response);
|
||||
renderMessageWithReference(data.response, "khoj", data.context);
|
||||
.then(response => {
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
function readStream() {
|
||||
reader.read().then(({ done, value }) => {
|
||||
if (done) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Decode message chunk from stream
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
if (chunk.startsWith("### compiled references:")) {
|
||||
// Display references used to generate response
|
||||
const rawReferences = chunk.split("### compiled references:")[1];
|
||||
const rawReferencesAsJson = JSON.parse(rawReferences);
|
||||
let polishedReferences = rawReferencesAsJson
|
||||
.map((reference, index) => generateReference(reference, index))
|
||||
.join("<sup>,</sup>");
|
||||
new_response_text.innerHTML += polishedReferences;
|
||||
} else {
|
||||
// Display response from Khoj
|
||||
if (new_response_text.innerHTML === "🤔") {
|
||||
// Clear temporary status message
|
||||
new_response_text.innerHTML = "";
|
||||
}
|
||||
new_response_text.innerHTML += chunk;
|
||||
readStream();
|
||||
}
|
||||
|
||||
// Scroll to bottom of chat window as chat response is streamed
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
});
|
||||
}
|
||||
readStream();
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -82,7 +126,7 @@
|
|||
}
|
||||
|
||||
window.onload = function () {
|
||||
fetch('/api/chat?client=web')
|
||||
fetch('/api/chat/init?client=web')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.detail) {
|
||||
|
@ -384,7 +428,7 @@
|
|||
<script>
|
||||
var khojBannerSubmit = document.getElementById("khoj-banner-submit");
|
||||
|
||||
khojBannerSubmit.addEventListener("click", function(event) {
|
||||
khojBannerSubmit?.addEventListener("click", function(event) {
|
||||
event.preventDefault();
|
||||
var email = document.getElementById("khoj-banner-email").value;
|
||||
fetch("https://lantern.khoj.dev/beta/users/", {
|
||||
|
|
|
@ -144,7 +144,15 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
|
|||
return json.loads(response.strip(empty_escape_sequences))
|
||||
|
||||
|
||||
def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo", api_key=None, temperature=0.2):
|
||||
def converse(
|
||||
references,
|
||||
user_query,
|
||||
conversation_log={},
|
||||
model="gpt-3.5-turbo",
|
||||
api_key=None,
|
||||
temperature=0.2,
|
||||
completion_func=None,
|
||||
):
|
||||
"""
|
||||
Converse with user using OpenAI's ChatGPT
|
||||
"""
|
||||
|
@ -167,15 +175,15 @@ def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo",
|
|||
conversation_log,
|
||||
model,
|
||||
)
|
||||
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
|
||||
logger.debug(f"Conversation Context for GPT: {truncated_messages}")
|
||||
|
||||
# Get Response from GPT
|
||||
logger.debug(f"Conversation Context for GPT: {messages}")
|
||||
response = chat_completion_with_backoff(
|
||||
return chat_completion_with_backoff(
|
||||
messages=messages,
|
||||
compiled_references=references,
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
openai_api_key=api_key,
|
||||
completion_func=completion_func,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
return response.strip(empty_escape_sequences)
|
||||
|
|
|
@ -2,11 +2,17 @@
|
|||
import os
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from time import perf_counter
|
||||
from typing import Any
|
||||
from threading import Thread
|
||||
import json
|
||||
|
||||
# External Packages
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.llms import OpenAI
|
||||
from langchain.schema import ChatMessage
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
import openai
|
||||
import tiktoken
|
||||
from tenacity import (
|
||||
|
@ -20,12 +26,55 @@ from tenacity import (
|
|||
|
||||
# Internal Packages
|
||||
from khoj.utils.helpers import merge_dicts
|
||||
import queue
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
|
||||
|
||||
|
||||
class ThreadedGenerator:
|
||||
def __init__(self, compiled_references, completion_func=None):
|
||||
self.queue = queue.Queue()
|
||||
self.compiled_references = compiled_references
|
||||
self.completion_func = completion_func
|
||||
self.response = ""
|
||||
self.start_time = perf_counter()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
item = self.queue.get()
|
||||
if item is StopIteration:
|
||||
time_to_response = perf_counter() - self.start_time
|
||||
logger.info(f"Chat streaming took: {time_to_response:.3f} seconds")
|
||||
if self.completion_func:
|
||||
# The completion func effective acts as a callback.
|
||||
# It adds the aggregated response to the conversation history. It's constructed in api.py.
|
||||
self.completion_func(gpt_response=self.response)
|
||||
raise StopIteration
|
||||
return item
|
||||
|
||||
def send(self, data):
|
||||
self.response += data
|
||||
self.queue.put(data)
|
||||
|
||||
def close(self):
|
||||
if self.compiled_references and len(self.compiled_references) > 0:
|
||||
self.queue.put(f"### compiled references:{json.dumps(self.compiled_references)}")
|
||||
self.queue.put(StopIteration)
|
||||
|
||||
|
||||
class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
def __init__(self, gen: ThreadedGenerator):
|
||||
super().__init__()
|
||||
self.gen = gen
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs) -> Any:
|
||||
self.gen.send(token)
|
||||
|
||||
|
||||
@retry(
|
||||
retry=(
|
||||
retry_if_exception_type(openai.error.Timeout)
|
||||
|
@ -62,15 +111,31 @@ def completion_with_backoff(**kwargs):
|
|||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
def chat_completion_with_backoff(messages, model_name, temperature, openai_api_key=None):
|
||||
def chat_completion_with_backoff(
|
||||
messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, completion_func=completion_func)
|
||||
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
|
||||
callback_handler = StreamingChatCallbackHandler(g)
|
||||
chat = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
callback_manager=BaseCallbackManager([callback_handler]),
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
||||
request_timeout=20,
|
||||
max_retries=1,
|
||||
)
|
||||
return chat(messages).content
|
||||
|
||||
chat(messages=messages)
|
||||
|
||||
g.close()
|
||||
|
||||
|
||||
def generate_chatml_messages_with_context(
|
||||
|
|
|
@ -6,6 +6,7 @@ import yaml
|
|||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Union
|
||||
from functools import partial
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter, HTTPException, Header, Request
|
||||
|
@ -34,6 +35,7 @@ from khoj.utils.rawconfig import (
|
|||
from khoj.utils.state import SearchType
|
||||
from khoj.utils import state, constants
|
||||
from khoj.utils.yaml import save_config_to_file_updated_state
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
# Initialize Router
|
||||
api = APIRouter()
|
||||
|
@ -393,10 +395,9 @@ def update(
|
|||
return {"status": "ok", "message": "khoj reloaded"}
|
||||
|
||||
|
||||
@api.get("/chat")
|
||||
async def chat(
|
||||
@api.get("/chat/init")
|
||||
def chat_init(
|
||||
request: Request,
|
||||
q: Optional[str] = None,
|
||||
client: Optional[str] = None,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
|
@ -411,13 +412,68 @@ async def chat(
|
|||
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
|
||||
)
|
||||
|
||||
# Load Conversation History
|
||||
meta_log = state.processor_config.conversation.meta_log
|
||||
|
||||
user_state = {
|
||||
"client_host": request.client.host,
|
||||
"user_agent": user_agent or "unknown",
|
||||
"referer": referer or "unknown",
|
||||
"host": host or "unknown",
|
||||
}
|
||||
|
||||
state.telemetry += [
|
||||
log_telemetry(
|
||||
telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state
|
||||
)
|
||||
]
|
||||
|
||||
return {"status": "ok", "response": meta_log.get("chat", [])}
|
||||
|
||||
|
||||
@api.get("/chat", response_class=StreamingResponse)
|
||||
async def chat(
|
||||
request: Request,
|
||||
q: Optional[str] = None,
|
||||
client: Optional[str] = None,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
) -> StreamingResponse:
|
||||
def _save_to_conversation_log(
|
||||
q: str,
|
||||
gpt_response: str,
|
||||
user_message_time: str,
|
||||
compiled_references: List[str],
|
||||
inferred_queries: List[str],
|
||||
chat_session: str,
|
||||
meta_log,
|
||||
):
|
||||
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
|
||||
state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
||||
q,
|
||||
gpt_response,
|
||||
user_message_metadata={"created": user_message_time},
|
||||
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
|
||||
conversation_log=meta_log.get("chat", []),
|
||||
)
|
||||
|
||||
if (
|
||||
state.processor_config is None
|
||||
or state.processor_config.conversation is None
|
||||
or state.processor_config.conversation.openai_api_key is None
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
|
||||
)
|
||||
|
||||
# Load Conversation History
|
||||
chat_session = state.processor_config.conversation.chat_session
|
||||
meta_log = state.processor_config.conversation.meta_log
|
||||
|
||||
# If user query is empty, return chat history
|
||||
# If user query is empty, return nothing
|
||||
if not q:
|
||||
return {"status": "ok", "response": meta_log.get("chat", [])}
|
||||
return StreamingResponse(None)
|
||||
|
||||
# Initialize Variables
|
||||
api_key = state.processor_config.conversation.openai_api_key
|
||||
|
@ -438,7 +494,7 @@ async def chat(
|
|||
result_list = []
|
||||
for query in inferred_queries:
|
||||
result_list.extend(
|
||||
await search(query, request=request, n=5, r=True, score_threshold=-5.0, dedupe=False)
|
||||
await search(query, request=request, n=5, r=False, score_threshold=-5.0, dedupe=False)
|
||||
)
|
||||
compiled_references = [item.additional["compiled"] for item in result_list]
|
||||
|
||||
|
@ -446,24 +502,6 @@ async def chat(
|
|||
conversation_type = "notes" if compiled_references else "general"
|
||||
logger.debug(f"Conversation Type: {conversation_type}")
|
||||
|
||||
try:
|
||||
with timer("Generating chat response took", logger):
|
||||
gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key)
|
||||
status = "ok"
|
||||
except Exception as e:
|
||||
gpt_response = str(e)
|
||||
status = "error"
|
||||
|
||||
# Update Conversation History
|
||||
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
|
||||
state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
||||
q,
|
||||
gpt_response,
|
||||
user_message_metadata={"created": user_message_time},
|
||||
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
|
||||
conversation_log=meta_log.get("chat", []),
|
||||
)
|
||||
|
||||
user_state = {
|
||||
"client_host": request.client.host,
|
||||
"user_agent": user_agent or "unknown",
|
||||
|
@ -477,4 +515,23 @@ async def chat(
|
|||
)
|
||||
]
|
||||
|
||||
return {"status": status, "response": gpt_response, "context": compiled_references}
|
||||
try:
|
||||
with timer("Generating chat response took", logger):
|
||||
partial_completion = partial(
|
||||
_save_to_conversation_log,
|
||||
q,
|
||||
user_message_time=user_message_time,
|
||||
compiled_references=compiled_references,
|
||||
inferred_queries=inferred_queries,
|
||||
chat_session=chat_session,
|
||||
meta_log=meta_log,
|
||||
)
|
||||
|
||||
gpt_response = converse(
|
||||
compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
gpt_response = str(e)
|
||||
|
||||
return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
|
||||
|
|
|
@ -175,7 +175,11 @@ def get_server_id():
|
|||
|
||||
|
||||
def log_telemetry(
|
||||
telemetry_type: str, api: str = None, client: str = None, app_config: AppConfig = None, properties: dict = None
|
||||
telemetry_type: str,
|
||||
api: str = None,
|
||||
client: Optional[str] = None,
|
||||
app_config: Optional[AppConfig] = None,
|
||||
properties: dict = None,
|
||||
):
|
||||
"""Log basic app usage telemetry like client, os, api called"""
|
||||
# Do not log usage telemetry, if telemetry is disabled via app config
|
||||
|
|
Loading…
Reference in a new issue