Mirror of https://github.com/khoj-ai/khoj.git, synced 2025-02-17 16:14:21 +00:00
Merge Improve Khoj Chat PR #183 from debanjum/improve-chat-interface
# Improve Khoj Chat

## Main Changes
- Use the new [API](https://openai.com/blog/introducing-chatgpt-and-whisper-apis) for [ChatGPT](https://openai.com/blog/chatgpt) to improve conversation quality and cost
- Improve the prompt to answer queries using indexed notes
  - Previously Khoj asked GPT to summarize the notes
  - Both the chat and answer APIs use this new prompt
- Support multi-turn conversations
  - Pass previous messages and associated reference notes to ChatGPT for context
- Show the note snippets referenced to generate a response
  - Allows fact-checking and getting details
- Simplify the chat interface by using a single, unified chat type for now

## Miscellaneous
- Replace the summarize API with the answer API. Summarizing via API is not useful for now
- Only pass Khoj search results above a confidence threshold to GPT for context
  - Allows Khoj to say it doesn't know when it can't find an answer to the query in the notes
  - Allows relying on (only) the conversation history to generate responses in multi-turn conversations
- Move the chat API out of beta. Update the README
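The multi-turn flow above boils down to packing prior exchanges and their note snippets into ChatML messages for the new ChatGPT API. A minimal sketch of that call with the `openai` 0.x Python client this PR pins; the notes, queries and replies here are invented for illustration:

```python
import openai

openai.api_key = "OPENAI_API_KEY"  # placeholder

# system primer + alternating user/assistant turns; each user turn carries the
# retrieved note snippets inline, mirroring generate_chatml_messages_with_context below
messages = [
    {"role": "system", "content": "You are a friendly, helpful personal assistant."},
    {"role": "user", "content": "Notes:\n# Went surfing at Muir Beach in June\n\nQuestion: When did I go surfing?"},
    {"role": "assistant", "content": "You went surfing in June, at Muir Beach."},
    {"role": "user", "content": "Notes:\n\nQuestion: Who did I go with?"},
]

response = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages, temperature=0)
print(response["choices"][0]["message"]["content"])
```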
This commit is contained in:
commit 6c0e82b2d6

11 changed files with 330 additions and 271 deletions
Changed files: README.md, pyproject.toml, src/khoj, tests

README.md (75 changes)
@@ -22,9 +22,10 @@
- [Install](#1-Install)
- [Run](#2-Run)
- [Configure](#3-Configure)
- [Install Plugins](#4-install-interface-plugins)
- [Use](#Use)
- [Interfaces](#Interfaces-1)
- [Query Filters](#Query-filters)
- [Khoj Search](#Khoj-search)
- [Khoj Chat](#Khoj-chat)
- [Upgrade](#Upgrade)
- [Khoj Server](#upgrade-khoj-server)
- [Khoj.el](#upgrade-khoj-on-emacs)

@@ -33,12 +34,11 @@
- [Troubleshoot](#Troubleshoot)
- [Advanced Usage](#advanced-usage)
- [Access Khoj on Mobile](#access-khoj-on-mobile)
-- [Chat with Notes](#chat-with-notes)
- [Use OpenAI Models for Search](#use-openai-models-for-search)
- [Search across Different Languages](#search-across-different-languages)
- [Miscellaneous](#Miscellaneous)
- [Setup OpenAI API key in Khoj](#set-your-openai-api-key-in-khoj)
-- [Beta API](#beta-api)
+- [GPT API](#gpt-api)
- [Performance](#Performance)
- [Query Performance](#Query-performance)
- [Indexing Performance](#Indexing-performance)

@@ -131,21 +131,29 @@ Note: To start Khoj automatically in the background use [Task scheduler](https:/
1. Enable content types and point to files to search in the First Run Screen that pops up on app start
2. Click `Configure` and wait. The app will download ML models and index the content for search

-## Use
-### Interfaces
+### 4. Install Interface Plugins
Khoj exposes a web interface by default.<br />
The optional steps below allow using Khoj from within an existing application like Obsidian or Emacs.

- **Khoj Obsidian**:<br />
  [Install](https://github.com/debanjum/khoj/tree/master/src/interface/obsidian#2-Setup-Plugin) the Khoj Obsidian plugin

- **Khoj Emacs**:<br />
  [Install](https://github.com/debanjum/khoj/tree/master/src/interface/emacs#2-Install-Khojel) khoj.el

+## Use
+### Khoj Search
- **Khoj via Obsidian**
  - [Install](https://github.com/debanjum/khoj/tree/master/src/interface/obsidian#2-Setup-Plugin) the Khoj Obsidian plugin
  - Click the *Khoj search* icon 🔎 on the [Ribbon](https://help.obsidian.md/User+interface/Workspace/Ribbon) or search for *Khoj: Search* in the [Command Palette](https://help.obsidian.md/Plugins/Command+palette)
- **Khoj via Emacs**
  - [Install](https://github.com/debanjum/khoj/tree/master/src/interface/emacs#installation) [khoj.el](./src/interface/emacs/khoj.el)
  - Run `M-x khoj <user-query>`
- **Khoj via Web**
  - Open <http://localhost:8000/> via desktop interface or directly
- **Khoj via API**
  - See the Khoj FastAPI [Swagger Docs](http://localhost:8000/docs), [ReDocs](http://localhost:8000/redocs)

### Query Filters
<details><summary>Query Filters</summary>

Use structured query syntax to filter the natural language search results
- **Word Filter**: Get entries that include/exclude a specified term
  - Entries that contain term_to_include: `+"term_to_include"`

@@ -164,6 +172,30 @@ Use structured query syntax to filter the natural language search results
  - excluding words *"big"* and *"brother"*
  - that best match the natural language query *"what is the meaning of life?"*

</details>
### Khoj Chat
#### Overview
- Creates a personal assistant for you to ask questions of and engage with your notes
- Uses [ChatGPT](https://openai.com/blog/chatgpt) and [Khoj search](#khoj-search)
- Supports multi-turn conversations with the relevant notes for context
- Shows the reference notes used to generate a response
- **Note**: *Your query and top notes from Khoj search will be sent to OpenAI for processing*

#### Setup
- [Setup your OpenAI API key in Khoj](#set-your-openai-api-key-in-khoj)

#### Use
1. Open [/chat](http://localhost:8000/chat)[^2]
2. Type your queries and see responses by Khoj from your notes

#### Demo
![](https://github.com/debanjum/khoj/blob/master/docs/khoj_chat_web_interface.png?)

### Details
1. Your query is used to retrieve the most relevant notes, if any, using Khoj search
2. These notes, the last few messages and associated metadata are passed to ChatGPT along with your query to generate a response
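The details above map onto a single GET endpoint. A minimal sketch of exercising it from Python, assuming a local Khoj server on the default port (`requests` is not a Khoj dependency, it is only used here for illustration):

```python
import requests

KHOJ = "http://localhost:8000"  # default Khoj server address

# ask a question: the server retrieves relevant notes, then asks ChatGPT
data = requests.get(f"{KHOJ}/api/chat", params={"q": "When did I go surfing?"}).json()
print(data["response"])  # Khoj's reply
print(data["context"])   # the compiled note snippets passed to ChatGPT as references

# calling the endpoint without a query returns the conversation history instead
history = requests.get(f"{KHOJ}/api/chat").json()["response"]
```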
## Upgrade
### Upgrade Khoj Server
```shell
pip install --upgrade khoj-assistant
```

@@ -223,23 +255,6 @@ pip install --upgrade --pre khoj-assistant

![](https://github.com/debanjum/khoj/blob/master/docs/khoj_pwa_android.png?)

-### Chat with Notes
-#### Overview
-- Provides a chat interface to ask questions of and engage with your notes
-- Chat Types:
-  - **Summarize**: Pulls the most relevant note from your notes and summarizes it
-  - **Chat**: Also does general chat. It guesses whether to give a general response or to search and summarize from your notes.<br />
-    E.g. *"how was your day?"* will give a general response, but *"When did I go surfing?"* should give a response from your notes
-- **Note**: *Your query and top note from the search results will be sent to OpenAI for processing*

-#### Use
-1. [Setup your OpenAI API key in Khoj](#set-your-openai-api-key-in-khoj)
-2. Open [/chat?t=summarize](http://localhost:8000/chat?t=summarize)[^2]
-3. Type your queries, see summarized responses by Khoj from your notes

-#### Demo
-![](https://github.com/debanjum/khoj/blob/master/docs/khoj_chat_web_interface.png?)
### Use OpenAI Models for Search
#### Setup
1. Set `encoder-type`, `encoder` and `model-directory` under `asymmetric` and/or `symmetric` `search-type` in your `khoj.yml`[^1]:

@@ -282,7 +297,7 @@
If you want, Khoj can be configured to use OpenAI for search and chat.<br />
Add your OpenAI API key to Khoj using either of the two options below:
- Open the Khoj desktop GUI, add your [OpenAI API key](https://beta.openai.com/account/api-keys) and click *Configure*
-  Ensure khoj is started without the `--no-gui` flag. Check your system tray to see if Khoj 🦅 is minimized there.
+  Ensure khoj is started **without** the `--no-gui` flag. Check your system tray to see if Khoj 🦅 is minimized there.
- Set the `openai-api-key` field under the `processor.conversation` section in your `khoj.yml`[^1] to your [OpenAI API key](https://beta.openai.com/account/api-keys) and restart khoj:
```diff
processor:

@@ -293,10 +308,10 @@ Add your OpenAI API key to Khoj using either of the two options below:
  conversation-logfile: "~/.khoj/processor/conversation/conversation_logs.json"
```

-**Warning**: *This will enable khoj to send your query and note(s) to OpenAI for processing*
+**Warning**: *This will enable Khoj to send your query and note(s) to OpenAI for processing*

-### Beta API
-- The beta [chat](http://localhost:8000/api/beta/chat), [summarize](http://localhost:8000/api/beta/summarize) and [search](http://localhost:8000/api/beta/search) API endpoints use the [OpenAI API](https://openai.com/api/)
+### GPT API
+- The [chat](http://localhost:8000/api/chat), [answer](http://localhost:8000/api/beta/answer) and [search](http://localhost:8000/api/beta/search) API endpoints use the [OpenAI API](https://openai.com/api/)
- They are disabled by default
- To use them:
  1. [Setup your OpenAI API key in Khoj](#set-your-openai-api-key-in-khoj)
pyproject.toml

@@ -40,7 +40,7 @@ dependencies = [
    "defusedxml == 0.7.1",
    "fastapi == 0.77.1",
    "jinja2 == 3.1.2",
-   "openai == 0.20.0",
+   "openai >= 0.27.0",
    "pillow == 9.3.0",
    "pydantic == 1.9.1",
    "pyqt6 == 6.3.1",
src/khoj/configure.py

@@ -9,6 +9,7 @@ import schedule
from fastapi.staticfiles import StaticFiles

# Internal Packages
+from khoj.processor.conversation.gpt import summarize
from khoj.processor.ledger.beancount_to_jsonl import BeancountToJsonl
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl

@@ -186,3 +187,39 @@ def configure_conversation_processor(conversation_processor_config):
        conversation_processor.chat_session = ""

    return conversation_processor


+@schedule.repeat(schedule.every(15).minutes)
+def save_chat_session():
+    # No need to create empty log file
+    if not (
+        state.processor_config
+        and state.processor_config.conversation
+        and state.processor_config.conversation.meta_log
+        and state.processor_config.conversation.chat_session
+    ):
+        return
+
+    # Summarize Conversation Logs for this Session
+    chat_session = state.processor_config.conversation.chat_session
+    openai_api_key = state.processor_config.conversation.openai_api_key
+    conversation_log = state.processor_config.conversation.meta_log
+    model = state.processor_config.conversation.model
+    session = {
+        "summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key),
+        "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
+        "session-end": len(conversation_log["chat"]),
+    }
+    if "session" in conversation_log:
+        conversation_log["session"].append(session)
+    else:
+        conversation_log["session"] = [session]
+
+    # Save Conversation Metadata Logs to Disk
+    conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile)
+    conversation_logfile.parent.mkdir(parents=True, exist_ok=True)  # create conversation directory if it doesn't exist
+    with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
+        json.dump(conversation_log, logfile)
+
+    state.processor_config.conversation.chat_session = None
+    logger.info("📩 Saved current chat session to conversation logs")
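For orientation, the conversation log that `save_chat_session` persists has roughly the shape below. This is inferred from `message_to_log` (later in this diff) and the session bookkeeping above; the messages and timestamps are invented for illustration:

```python
# illustrative contents of ~/.khoj/processor/conversation/conversation_logs.json
conversation_log = {
    "chat": [
        {"message": "When did I go surfing?", "by": "you", "created": "2023-03-05 10:00:00"},
        {
            "message": "You went surfing in June, at Muir Beach.",
            "by": "khoj",
            "created": "2023-03-05 10:00:04",
            "intent": {"type": "remember", "memory-type": "notes", "query": "When did I go surfing?"},
            "trigger-emotion": "calm",
            "context": "# Went surfing at Muir Beach in June",
        },
    ],
    # one entry per saved session; start/end index into the "chat" list
    "session": [{"summary": "We talked about surfing.", "session-start": 0, "session-end": 2}],
}
```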
src/khoj/interface/web/chat.html

@@ -6,15 +6,9 @@
-    <link rel="icon" href="data:image/svg+xml,<svg xmlns=%22http://www.w3.org/2000/svg%22 viewBox=%220 0 144 144%22><text y=%22.86em%22 font-size=%22144%22>🦅</text></svg>">
+    <link rel="icon" type="image/png" sizes="144x144" href="/static/assets/icons/favicon-144x144.png">
-    <link rel="manifest" href="/static/khoj.webmanifest">
+    <link rel="manifest" href="/static/khoj_chat.webmanifest">
</head>
<script>
-    function setTypeFieldInUrl(type) {
-        let url = new URL(window.location.href);
-        url.searchParams.set("t", type.value);
-        window.history.pushState({}, "", url.href);
-    }
-
    function formatDate(date) {
        // Format date in HH:MM, DD MMM YYYY format
        let time_string = date.toLocaleTimeString('en-IN', { hour: '2-digit', minute: '2-digit', hour12: false });

@@ -22,6 +16,11 @@
        return `${time_string}, ${date_string}`;
    }

+    function generateReference(reference, index) {
+        // Generate HTML for Chat Reference
+        return `<sup><abbr title="${reference}" tabindex="0">${index}</abbr></sup>`;
+    }
+
    function renderMessage(message, by, dt=null) {
        let message_time = formatDate(dt ?? new Date());
        let by_name = by == "khoj" ? "🦅 Khoj" : "🤔 You";

@@ -31,15 +30,26 @@
            <div class="chat-message-text ${by}">${message}</div>
        </div>
        `;
-        // Scroll to bottom of input-body element
+        // Scroll to bottom of chat-body element
        document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
    }

+    function renderMessageWithReference(message, by, context=null, dt=null) {
+        let references = '';
+        if (context) {
+            references = context
+                .split("\n\n# ")
+                .map((reference, index) => generateReference(reference, index))
+                .join("<sup>,</sup>");
+        }
+
+        renderMessage(message + references, by, dt);
+    }
+
    function chat() {
        // Extract required fields for search from form
-        query = document.getElementById("chat-input").value.trim();
-        type_ = document.getElementById("chat-type").value;
-        console.log(`Query: ${query}, Type: ${type_}`);
+        let query = document.getElementById("chat-input").value.trim();
+        console.log(`Query: ${query}`);

        // Short circuit on empty query
        if (query.length === 0)

@@ -50,18 +60,15 @@
        document.getElementById("chat-input").value = "";

        // Generate backend API URL to execute query
-        url = type_ === "chat"
-            ? `/api/beta/chat?q=${encodeURIComponent(query)}`
-            : `/api/beta/summarize?q=${encodeURIComponent(query)}`;
+        let url = `/api/chat?q=${encodeURIComponent(query)}`;

        // Call specified Khoj API
        fetch(url)
            .then(response => response.json())
-            .then(data => data.response)
-            .then(response => {
+            .then(data => {
                // Render message by Khoj to chat body
-                console.log(response);
-                renderMessage(response, "khoj");
+                console.log(data.response);
+                renderMessageWithReference(data.response, "khoj", data.context);
            });
    }

@@ -73,18 +80,13 @@
    }

    window.onload = function () {
-        // Fill type field with value passed in URL query parameters, if any.
-        var type_via_url = new URLSearchParams(window.location.search).get("t");
-        if (type_via_url)
-            document.getElementById("chat-type").value = type_via_url;
-
-        fetch('/api/beta/chat')
+        fetch('/api/chat')
            .then(response => response.json())
            .then(data => data.response)
            .then(chat_logs => {
                // Render conversation history, if any
                chat_logs.forEach(chat_log => {
-                    renderMessage(chat_log.message, chat_log.by, new Date(chat_log.created));
+                    renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created));
                });
            });

@@ -109,12 +111,6 @@
    <!-- Chat Footer -->
    <div id="chat-footer">
        <input type="text" id="chat-input" class="option" onkeyup=incrementalChat(event) autofocus="autofocus" placeholder="What is the meaning of life?">
-
-        <!--Select Chat Type from: Chat, Summarize -->
-        <select id="chat-type" class="option" onchange="setTypeFieldInUrl(this)">
-            <option value="chat">Chat</option>
-            <option value="summarize">Summarize</option>
-        </select>
    </div>
</body>

@@ -217,7 +213,7 @@
    #chat-footer {
        padding: 0;
        display: grid;
-        grid-template-columns: minmax(70px, 85%) auto;
+        grid-template-columns: minmax(70px, 100%);
        grid-column-gap: 10px;
        grid-row-gap: 10px;
    }

@@ -234,6 +230,29 @@
        font-size: medium;
    }

+    @media (pointer: coarse), (hover: none) {
+        abbr[title] {
+            position: relative;
+            padding-left: 4px;  /* space references out to ease tapping */
+        }
+        abbr[title]:focus:after {
+            content: attr(title);
+
+            /* position tooltip */
+            position: absolute;
+            left: 16px;  /* open tooltip to right of ref link, instead of on top of it */
+            width: auto;
+            z-index: 1;  /* show tooltip above chat messages */
+
+            /* style tooltip */
+            background-color: #aaa;
+            color: #f8fafc;
+            border-radius: 2px;
+            box-shadow: 1px 1px 4px 0 rgba(0, 0, 0, 0.4);
+            font-size: 14px;
+            padding: 2px 4px;
+        }
+    }
    @media only screen and (max-width: 600px) {
        body {
            grid-template-columns: 1fr;
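A note on the reference format the client splits on: the `/api/chat` endpoint (further down in this diff) joins the compiled note snippets with `\n\n` and prefixes each with `# `, which is exactly what `renderMessageWithReference` splits back apart. A small sketch of that round trip, with made-up snippets:

```python
# server side: collate note snippets into the context string (mirrors the /api/chat code below)
snippets = ["Went surfing at Muir Beach in June", "Bought a new surfboard"]
context = "\n\n".join(f"# {snippet}" for snippet in snippets)

# client side: split the context back into individual references (mirrors renderMessageWithReference)
references = context.split("\n\n# ")
# references[0] keeps its leading "# "; the web client renders one <abbr> tooltip per reference
```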
src/khoj/interface/web/khoj_chat.webmanifest (new file, 16 lines)

@@ -0,0 +1,16 @@
+{
+    "name": "Khoj Chat",
+    "short_name": "Khoj Chat",
+    "description": "A personal assistant for your notes",
+    "icons": [
+        {
+            "src": "/static/assets/icons/favicon-144x144.png",
+            "sizes": "144x144",
+            "type": "image/png"
+        }
+    ],
+    "theme_color": "#ffffff",
+    "background_color": "#ffffff",
+    "display": "standalone",
+    "start_url": "/chat"
+}
src/khoj/processor/conversation/gpt.py

@@ -1,6 +1,7 @@
# Standard Packages
import os
import json
+import logging
from datetime import datetime

# External Packages

@@ -8,6 +9,38 @@ import openai

# Internal Packages
from khoj.utils.constants import empty_escape_sequences
+from khoj.utils.helpers import merge_dicts

+logger = logging.getLogger(__name__)


+def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=500):
+    """
+    Answer user query using provided text as reference with OpenAI's GPT
+    """
+    # Initialize Variables
+    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
+
+    # Setup Prompt to Answer Query from Notes
+    prompt = f"""
+You are a friendly, helpful personal assistant.
+Using the user's notes below, answer their following question. If the answer is not contained within the notes, say "I don't know."
+
+Notes:
+{text}
+
+Question: {user_query}
+
+Answer (in second person):"""
+    # Get Response from GPT
+    logger.debug(f"Prompt for GPT: {prompt}")
+    response = openai.Completion.create(
+        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, stop='"""'
+    )
+
+    # Extract, Clean Message from GPT's Response
+    story = response["choices"][0]["text"]
+    return str(story).replace("\n\n", "")
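A hedged example of calling the new `answer` function directly; the model name below is an assumption (any OpenAI completions model of that era would do), and the note text is invented:

```python
# sketch: answer a question from a compiled note snippet
reply = answer(
    text="Went surfing at Muir Beach in June",
    user_query="When did I go surfing?",
    model="text-davinci-003",  # assumed completions model; not specified in this diff
    api_key="OPENAI_API_KEY",  # placeholder
)
print(reply)  # e.g. "You went surfing in June, at Muir Beach."
```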
def summarize(text, summary_type, model, user_query=None, api_key=None, temperature=0.5, max_tokens=200):

@@ -34,6 +67,7 @@ Summarize the below notes about {user_query}:
Summarize the notes in second person perspective:"""

    # Get Response from GPT
+    logger.debug(f"Prompt for GPT: {prompt}")
    response = openai.Completion.create(
        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop='"""'
    )

@@ -77,6 +111,7 @@ A:{ "search-type": "notes" }"""
        print(f"Message -> Prompt: {text} -> {prompt}")

    # Get Response from GPT
+    logger.debug(f"Prompt for GPT: {prompt}")
    response = openai.Completion.create(
        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"]
    )

@@ -86,104 +121,68 @@ A:{ "search-type": "notes" }"""
    return json.loads(story.strip(empty_escape_sequences))
-def understand(text, model, api_key=None, temperature=0.5, max_tokens=100, verbose=0):
+def converse(text, user_query, conversation_log=None, api_key=None, temperature=0):
    """
-    Understand user input using OpenAI's GPT
+    Converse with user using OpenAI's ChatGPT
    """
    # Initialize Variables
    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
-    understand_primer = """
-Objective: Extract intent and trigger emotion information as JSON from each chat message
-
-Potential intent types and valid argument values are listed below:
-- intent
-  - remember(memory-type, query);
-    - memory-type=["companion","notes","ledger","image","music"]
-  - search(search-type, query);
-    - search-type=["google"]
-  - generate(activity, query);
-    - activity=["paint","write","chat"]
-- trigger-emotion(emotion)
-  - emotion=["happy","confidence","fear","surprise","sadness","disgust","anger","shy","curiosity","calm"]
-
-Some examples are given below for reference:
-Q: How are you doing?
-A: { "intent": {"type": "generate", "activity": "chat", "query": "How are you doing?"}, "trigger-emotion": "happy" }
-Q: Do you remember what I told you about my brother Antoine when we were at the beach?
-A: { "intent": {"type": "remember", "memory-type": "companion", "query": "Brother Antoine when we were at the beach"}, "trigger-emotion": "curiosity" }
-Q: what was that fantasy story you told me last time?
-A: { "intent": {"type": "remember", "memory-type": "companion", "query": "fantasy story told last time"}, "trigger-emotion": "curiosity" }
-Q: Let's make some drawings about the stars on a clear full moon night!
-A: { "intent": {"type": "generate", "activity": "paint", "query": "stars on a clear full moon night"}, "trigger-emotion: "happy" }
-Q: Do you know anything about Lebanon cuisine in the 18th century?
-A: { "intent": {"type": "search", "search-type": "google", "query": "lebanon cusine in the 18th century"}, "trigger-emotion; "confidence" }
-Q: Tell me a scary story
-A: { "intent": {"type": "generate", "activity": "write", "query": "A scary story"}, "trigger-emotion": "fear" }
-Q: What fiction book was I reading last week about AI starship?
-A: { "intent": {"type": "remember", "memory-type": "notes", "query": "fiction book about AI starship last week"}, "trigger-emotion": "curiosity" }
-Q: How much did I spend at Subway for dinner last time?
-A: { "intent": {"type": "remember", "memory-type": "ledger", "query": "last Subway dinner"}, "trigger-emotion": "calm" }
-Q: I'm feeling sleepy
-A: { "intent": {"type": "generate", "activity": "chat", "query": "I'm feeling sleepy"}, "trigger-emotion": "calm" }
-Q: What was that popular Sri lankan song that Alex had mentioned?
-A: { "intent": {"type": "remember", "memory-type": "music", "query": "popular Sri lankan song mentioned by Alex"}, "trigger-emotion": "curiosity" }
-Q: You're pretty funny!
-A: { "intent": {"type": "generate", "activity": "chat", "query": "You're pretty funny!"}, "trigger-emotion": "shy" }
-Q: Can you recommend a movie to watch from my notes?
-A: { "intent": {"type": "remember", "memory-type": "notes", "query": "recommend movie to watch"}, "trigger-emotion": "curiosity" }
-Q: When did I go surfing last?
-A: { "intent": {"type": "remember", "memory-type": "notes", "query": "When did I go surfing last"}, "trigger-emotion": "calm" }
-Q: Can you dance for me?
-A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance for me?"}, "trigger-emotion": "sad" }"""
-
-    # Setup Prompt with Understand Primer
-    prompt = message_to_prompt(text, understand_primer, start_sequence="\nA:", restart_sequence="\nQ:")
-    if verbose > 1:
-        print(f"Message -> Prompt: {text} -> {prompt}")
-
-    # Get Response from GPT
-    response = openai.Completion.create(
-        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"]
-    )
-
-    # Extract, Clean Message from GPT's Response
-    story = str(response["choices"][0]["text"])
-    return json.loads(story.strip(empty_escape_sequences))
-
-
-def converse(text, model, conversation_history=None, api_key=None, temperature=0.9, max_tokens=150):
-    """
-    Converse with user using OpenAI's GPT
-    """
    # Initialize Variables
-    max_words = 500
+    model = "gpt-3.5-turbo"
    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")

+    personality_primer = "You are a friendly, helpful personal assistant."
    conversation_primer = f"""
-The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and a very friendly companion.
+Using the notes and our chats as context, answer the following question.
Current Date: {datetime.now().strftime("%Y-%m-%d")}

-Human: Hello, who are you?
-AI: Hi, I am an AI conversational companion created by OpenAI. How can I help you today?"""
+Notes:
+{text}
+
+Question: {user_query}"""

    # Setup Prompt with Primer or Conversation History
-    prompt = message_to_prompt(text, conversation_history or conversation_primer)
-    prompt = " ".join(prompt.split()[:max_words])
+    messages = generate_chatml_messages_with_context(
+        conversation_primer,
+        personality_primer,
+        conversation_log,
+    )

    # Get Response from GPT
-    response = openai.Completion.create(
-        prompt=prompt,
+    logger.debug(f"Conversation Context for GPT: {messages}")
+    response = openai.ChatCompletion.create(
+        messages=messages,
        model=model,
        temperature=temperature,
-        max_tokens=max_tokens,
        presence_penalty=0.6,
        stop=["\n", "Human:", "AI:"],
    )

    # Extract, Clean Message from GPT's Response
-    story = str(response["choices"][0]["text"])
+    story = str(response["choices"][0]["message"]["content"])
    return story.strip(empty_escape_sequences)
+def generate_chatml_messages_with_context(user_message, system_message, conversation_log=None):
+    """Generate messages for ChatGPT with context from previous conversation"""
+    # Extract Chat History for Context
+    chat_logs = [f'{chat["message"]}\n\nNotes:\n{chat.get("context", "")}' for chat in conversation_log.get("chat", [])]
+    last_backnforth = reciprocal_conversation_to_chatml(chat_logs[-2:])
+    rest_backnforth = reciprocal_conversation_to_chatml(chat_logs[-4:-2])
+
+    # Format user and system messages to chatml format
+    system_chatml_message = [message_to_chatml(system_message, "system")]
+    user_chatml_message = [message_to_chatml(user_message, "user")]
+
+    return rest_backnforth + system_chatml_message + last_backnforth + user_chatml_message
+
+
+def reciprocal_conversation_to_chatml(message_pair):
+    """Convert a single back and forth between user and assistant to chatml format"""
+    return [message_to_chatml(message, role) for message, role in zip(message_pair, ["user", "assistant"])]
+
+
+def message_to_chatml(message, role="assistant"):
+    """Create chatml message from message and role"""
+    return {"role": role, "content": message}
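To make the ordering concrete: given a four-message history, the helper returns the older back-and-forth first, then the system primer, then the latest exchange, then the new user message. A worked example with invented messages:

```python
# sketch of what generate_chatml_messages_with_context produces
log = {"chat": [
    {"message": "hi", "by": "you"},
    {"message": "Hello! How can I help?", "by": "khoj", "context": ""},
    {"message": "When did I go surfing?", "by": "you"},
    {"message": "In June, at Muir Beach.", "by": "khoj", "context": "# surfing note"},
]}

messages = generate_chatml_messages_with_context(
    "Question: Who with?", "You are a friendly, helpful personal assistant.", log
)
# roles come out as: user, assistant (older turn), system, user, assistant (latest turn), user (new query)
```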
def message_to_prompt(
    user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"
):

@@ -193,22 +192,20 @@ def message_to_prompt(
    return f"{conversation_history}{restart_sequence} {user_message}{start_sequence}{gpt_message}"


-def message_to_log(user_message, gpt_message, user_message_metadata={}, conversation_log=[]):
+def message_to_log(user_message, gpt_message, khoj_message_metadata={}, conversation_log=[]):
    """Create json logs from messages, metadata for conversation log"""
-    default_user_message_metadata = {
+    default_khoj_message_metadata = {
        "intent": {"type": "remember", "memory-type": "notes", "query": user_message},
        "trigger-emotion": "calm",
    }
    current_dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Create json log from Human's message
-    human_log = user_message_metadata or default_user_message_metadata
-    human_log["message"] = user_message
-    human_log["by"] = "you"
-    human_log["created"] = current_dt
+    human_log = {"message": user_message, "by": "you", "created": current_dt}

    # Create json log from GPT's response
-    khoj_log = {"message": gpt_message, "by": "khoj", "created": current_dt}
+    khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata)
+    khoj_log = merge_dicts({"message": gpt_message, "by": "khoj", "created": current_dt}, khoj_log)

    conversation_log.extend([human_log, khoj_log])
    return conversation_log
src/khoj/routers/api.py

@@ -1,7 +1,8 @@
# Standard Packages
+import math
import yaml
import logging
-from typing import List, Optional
+from typing import List, Optional, Union

# External Packages
from fastapi import APIRouter

@@ -9,6 +10,7 @@ from fastapi import HTTPException

# Internal Packages
from khoj.configure import configure_processor, configure_search
+from khoj.processor.conversation.gpt import converse, message_to_log, message_to_prompt
from khoj.search_type import image_search, text_search
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import FullConfig, SearchResponse

@@ -53,7 +55,14 @@ async def set_config_data(updated_config: FullConfig):


@api.get("/search", response_model=List[SearchResponse])
-def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
+def search(
+    q: str,
+    n: Optional[int] = 5,
+    t: Optional[SearchType] = None,
+    r: Optional[bool] = False,
+    score_threshold: Optional[Union[float, None]] = None,
+    dedupe: Optional[bool] = True,
+):
    results: List[SearchResponse] = []
    if q is None or q == "":
        logger.warn(f"No query param (q) passed in API call to initiate search")

@@ -62,9 +71,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
    # initialize variables
    user_query = q.strip()
    results_count = n
+    score_threshold = score_threshold if score_threshold is not None else -math.inf

    # return cached results, if available
-    query_cache_key = f"{user_query}-{n}-{t}-{r}"
+    query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
    if query_cache_key in state.query_cache:
        logger.debug(f"Return response from query cache")
        return state.query_cache[query_cache_key]

@@ -72,7 +82,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
    if (t == SearchType.Org or t == None) and state.model.orgmode_search:
        # query org-mode notes
        with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.orgmode_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
+            )

        # collate and return results
        with timer("Collating results took", logger):

@@ -81,7 +93,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
    elif (t == SearchType.Markdown or t == None) and state.model.markdown_search:
        # query markdown files
        with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
+            )

        # collate and return results
        with timer("Collating results took", logger):

@@ -90,7 +104,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
    elif (t == SearchType.Ledger or t == None) and state.model.ledger_search:
        # query transactions
        with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
+            )

        # collate and return results
        with timer("Collating results took", logger):

@@ -99,7 +115,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
    elif (t == SearchType.Music or t == None) and state.model.music_search:
        # query music library
        with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
+            )

        # collate and return results
        with timer("Collating results took", logger):

@@ -108,7 +126,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
    elif (t == SearchType.Image or t == None) and state.model.image_search:
        # query images
        with timer("Query took", logger):
-            hits = image_search.query(user_query, results_count, state.model.image_search)
+            hits = image_search.query(
+                user_query, results_count, state.model.image_search, score_threshold=score_threshold
+            )
            output_directory = constants.web_directory / "images"

        # collate and return results

@@ -129,6 +149,8 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
            # Get plugin search model for specified search type, or the first one if none specified
            state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
            rank_results=r,
+            score_threshold=score_threshold,
+            dedupe=dedupe,
        )

        # collate and return results

@@ -162,3 +184,40 @@ def update(t: Optional[SearchType] = None, force: Optional[bool] = False):
    logger.info("📬 Processor reconfigured via API")

    return {"status": "ok", "message": "khoj reloaded"}


+@api.get("/chat")
+def chat(q: Optional[str] = None):
+    # Initialize Variables
+    api_key = state.processor_config.conversation.openai_api_key
+
+    # Load Conversation History
+    chat_session = state.processor_config.conversation.chat_session
+    meta_log = state.processor_config.conversation.meta_log
+
+    # If user query is empty, return chat history
+    if not q:
+        if meta_log.get("chat"):
+            return {"status": "ok", "response": meta_log["chat"]}
+        else:
+            return {"status": "ok", "response": []}
+
+    # Collate context for GPT
+    result_list = search(q, n=2, r=True, score_threshold=0, dedupe=False)
+    collated_result = "\n\n".join([f"# {item.additional['compiled']}" for item in result_list])
+    logger.debug(f"Reference Context:\n{collated_result}")
+
+    try:
+        gpt_response = converse(collated_result, q, meta_log, api_key=api_key)
+        status = "ok"
+    except Exception as e:
+        gpt_response = str(e)
+        status = "error"
+
+    # Update Conversation History
+    state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
+    state.processor_config.conversation.meta_log["chat"] = message_to_log(
+        q, gpt_response, khoj_message_metadata={"context": collated_result}, conversation_log=meta_log.get("chat", [])
+    )
+
+    return {"status": status, "response": gpt_response, "context": collated_result}
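The new `score_threshold` and `dedupe` parameters are also exposed on the search endpoint itself. A minimal sketch of calling it over HTTP, assuming a local server; `requests` is only used here for illustration:

```python
import requests

# search notes, keeping only hits scoring at least 0.4 and skipping deduplication
results = requests.get(
    "http://localhost:8000/api/search",
    params={"q": "surfing trips", "n": 5, "r": True, "score_threshold": 0.4, "dedupe": False},
).json()
for result in results:
    print(result["score"], result["entry"])
```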
src/khoj/routers/beta.py

@@ -1,24 +1,18 @@
# Standard Packages
-import json
import logging
from typing import Optional

# External Packages
-import schedule
from fastapi import APIRouter

# Internal Packages
from khoj.routers.api import search
from khoj.processor.conversation.gpt import (
-    converse,
+    answer,
    extract_search_type,
-    message_to_log,
-    message_to_prompt,
-    understand,
-    summarize,
)
from khoj.utils.state import SearchType
-from khoj.utils.helpers import get_from_dict, resolve_absolute_path
+from khoj.utils.helpers import get_from_dict
from khoj.utils import state

@@ -48,116 +42,23 @@ def search_beta(q: str, n: Optional[int] = 1):
    return {"status": "ok", "result": search_results, "type": search_type}


-@api_beta.get("/summarize")
-def summarize_beta(q: str):
+@api_beta.get("/answer")
+def answer_beta(q: str):
    # Initialize Variables
    model = state.processor_config.conversation.model
    api_key = state.processor_config.conversation.openai_api_key

-    # Load Conversation History
-    chat_session = state.processor_config.conversation.chat_session
-    meta_log = state.processor_config.conversation.meta_log
+    # Collate context for GPT
+    result_list = search(q, n=2, r=True, score_threshold=0, dedupe=False)
+    collated_result = "\n\n".join([f"# {item.additional['compiled']}" for item in result_list])
+    logger.debug(f"Reference Context:\n{collated_result}")

-    # Converse with OpenAI GPT
-    result_list = search(q, n=1, r=True)
-    collated_result = "\n".join([item.entry for item in result_list])
-    logger.debug(f"Semantically Similar Notes:\n{collated_result}")
+    # Make GPT respond to user query using provided context
    try:
-        gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key)
+        gpt_response = answer(collated_result, user_query=q, model=model, api_key=api_key)
        status = "ok"
    except Exception as e:
        gpt_response = str(e)
        status = "error"

-    # Update Conversation History
-    state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
-    state.processor_config.conversation.meta_log["chat"] = message_to_log(
-        q, gpt_response, conversation_log=meta_log.get("chat", [])
-    )
-
    return {"status": status, "response": gpt_response}


-@api_beta.get("/chat")
-def chat(q: Optional[str] = None):
-    # Initialize Variables
-    model = state.processor_config.conversation.model
-    api_key = state.processor_config.conversation.openai_api_key
-
-    # Load Conversation History
-    chat_session = state.processor_config.conversation.chat_session
-    meta_log = state.processor_config.conversation.meta_log
-
-    # If user query is empty, return chat history
-    if not q:
-        if meta_log.get("chat"):
-            return {"status": "ok", "response": meta_log["chat"]}
-        else:
-            return {"status": "ok", "response": []}
-
-    # Converse with OpenAI GPT
-    metadata = understand(q, model=model, api_key=api_key, verbose=state.verbose)
-    logger.debug(f'Understood: {get_from_dict(metadata, "intent")}')
-
-    if get_from_dict(metadata, "intent", "memory-type") == "notes":
-        query = get_from_dict(metadata, "intent", "query")
-        result_list = search(query, n=1, t=SearchType.Org, r=True)
-        collated_result = "\n".join([item.entry for item in result_list])
-        logger.debug(f"Semantically Similar Notes:\n{collated_result}")
-        try:
-            gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key)
-            status = "ok"
-        except Exception as e:
-            gpt_response = str(e)
-            status = "error"
-    else:
-        try:
-            gpt_response = converse(q, model, chat_session, api_key=api_key)
-            status = "ok"
-        except Exception as e:
-            gpt_response = str(e)
-            status = "error"
-
-    # Update Conversation History
-    state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
-    state.processor_config.conversation.meta_log["chat"] = message_to_log(
-        q, gpt_response, metadata, meta_log.get("chat", [])
-    )
-
-    return {"status": status, "response": gpt_response}
-
-
-@schedule.repeat(schedule.every(5).minutes)
-def save_chat_session():
-    # No need to create empty log file
-    if not (
-        state.processor_config
-        and state.processor_config.conversation
-        and state.processor_config.conversation.meta_log
-        and state.processor_config.conversation.chat_session
-    ):
-        return
-
-    # Summarize Conversation Logs for this Session
-    chat_session = state.processor_config.conversation.chat_session
-    openai_api_key = state.processor_config.conversation.openai_api_key
-    conversation_log = state.processor_config.conversation.meta_log
-    model = state.processor_config.conversation.model
-    session = {
-        "summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key),
-        "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
-        "session-end": len(conversation_log["chat"]),
-    }
-    if "session" in conversation_log:
-        conversation_log["session"].append(session)
-    else:
-        conversation_log["session"] = [session]
-
-    # Save Conversation Metadata Logs to Disk
-    conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile)
-    conversation_logfile.parent.mkdir(parents=True, exist_ok=True)  # create conversation directory if it doesn't exist
-    with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
-        json.dump(conversation_log, logfile)
-
-    state.processor_config.conversation.chat_session = None
-    logger.info("📩 Saved current chat session to conversation logs")
src/khoj/search_type/image_search.py

@@ -1,5 +1,6 @@
# Standard Packages
import glob
+import math
import pathlib
import copy
import shutil

@@ -142,7 +143,7 @@ def extract_metadata(image_name):
    return image_processed_metadata


-def query(raw_query, count, model: ImageSearchModel):
+def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
    # Set query to image content if query is of form file:/path/to/file.png
    if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
        query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)

@@ -198,6 +199,9 @@ def query(raw_query, count, model: ImageSearchModel):
        for corpus_id, scores in image_hits.items()
    ]

+    # Filter results by score threshold
+    hits = [hit for hit in hits if hit["image_score"] >= score_threshold]
+
    # Sort the images based on their combined metadata, image scores
    return sorted(hits, key=lambda hit: hit["score"], reverse=True)
src/khoj/search_type/text_search.py

@@ -1,5 +1,6 @@
# Standard Packages
import logging
+import math
from pathlib import Path
from typing import List, Tuple, Type

@@ -99,7 +100,13 @@ def compute_embeddings(
    return corpus_embeddings


-def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> Tuple[List[dict], List[Entry]]:
+def query(
+    raw_query: str,
+    model: TextSearchModel,
+    rank_results: bool = False,
+    score_threshold: float = -math.inf,
+    dedupe: bool = True,
+) -> Tuple[List[dict], List[Entry]]:
    "Search for entries that answer the query"
    query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings

@@ -129,11 +136,15 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
    if rank_results:
        hits = cross_encoder_score(model.cross_encoder, query, entries, hits)

+    # Filter results by score threshold
+    hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
+
    # Order results by cross-encoder score followed by bi-encoder score
    hits = sort_results(rank_results, hits)

    # Deduplicate entries by raw entry text before showing to users
-    hits = deduplicate_results(entries, hits)
+    if dedupe:
+        hits = deduplicate_results(entries, hits)

    return hits, entries

@@ -143,7 +154,7 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
        SearchResponse.parse_obj(
            {
                "entry": entries[hit["corpus_id"]].raw,
-                "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}",
+                "score": f"{hit.get('cross-score', hit.get('score')):.3f}",
                "additional": {"file": entries[hit["corpus_id"]].file, "compiled": entries[hit["corpus_id"]].compiled},
            }
        )
tests/test_chatbot.py

@@ -2,7 +2,7 @@
import pytest

# Internal Packages
-from khoj.processor.conversation.gpt import converse, understand, message_to_prompt
+from khoj.processor.conversation.gpt import converse, message_to_prompt

# Initialize variables for tests