Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-27 17:35:07 +01:00)
# Migrate to Llama.cpp for Offline Chat (#680)

## Benefits
- Support all GGUF format chat models
- Support more GPUs like AMD, Nvidia, Mac, Vulkan (previously just Vulkan, Mac)
- Support more capabilities like larger context window, schema enforcement, speculative decoding etc.

## Changes
### Major
- Use llama.cpp for offline chat models
- Support a larger context window
- Automatically apply the appropriate chat template, so offline chat models that don't use the llama2 prompt format are now supported
- Use a better default offline chat model, NousResearch/Hermes-2-Pro-Mistral-7B
- Enable the extract queries actor to improve notes search with offline chat
- Update documentation to use llama.cpp for offline chat in Khoj

### Minor
- Migrate to NousResearch's Hermes-2-Pro 7B as the default offline chat model in khoj.yml
- Rename GPT4AllChatProcessor Config, Model to OfflineChatProcessor Config, Model
- Only add location to the image prompt generator when the location is known
Commit 3c3e48b18c: 23 changed files with 365 additions and 320 deletions.
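The heart of this migration is llama-cpp-python's OpenAI-style chat completion API. The sketch below is illustrative only, not Khoj's exact wiring; the repo id and quantization filename are the defaults this change recommends, and should be adjusted for your hardware.

```python
# Illustrative sketch of the llama.cpp path this PR switches to.
# Assumes llama-cpp-python is installed; repo id and filename follow the
# defaults referenced in this change.
from llama_cpp import Llama

llm = Llama.from_pretrained(
    repo_id="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
    filename="*Q4_K_M.gguf",  # quantized GGUF weights matching this glob
    n_ctx=0,                  # 0 = use the context window declared by the model
    n_gpu_layers=-1,          # offload all layers to GPU when one is available
)

# create_chat_completion applies the model's own chat template automatically.
response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "What is Khoj?"}]
)
print(response["choices"][0]["message"]["content"])
```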
@@ -14,16 +14,16 @@ You can configure Khoj to chat with you about anything. When relevant, it'll use

 ### Setup (Self-Hosting)
 #### Offline Chat
-Offline chat stays completely private and works without internet using open-source models.
+Offline chat stays completely private and can work without internet using open-source models.

 > **System Requirements**:
 > - Minimum 8 GB RAM. Recommend **16Gb VRAM**
 > - Minimum **5 GB of Disk** available
 > - A CPU supporting [AVX or AVX2 instructions](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) is required
-> - A Mac M1+ or [Vulcan supported GPU](https://vulkan.gpuinfo.org/) should significantly speed up chat response times
+> - An Nvidia, AMD GPU or a Mac M1+ machine would significantly speed up chat response times

 1. Open your [Khoj offline settings](http://localhost:42110/server/admin/database/offlinechatprocessorconversationconfig/) and click *Enable* on the Offline Chat configuration.
-2. Open your [Chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/) and add a new option for the offline chat model you want to use. Make sure to use `Offline` as its type. We currently only support offline models that use the [Llama chat prompt](https://replicate.com/blog/how-to-prompt-llama#wrap-user-input-with-inst-inst-tags) format. We recommend using `mistral-7b-instruct-v0.1.Q4_0.gguf`.
+2. Open your [Chat model options settings](http://localhost:42110/server/admin/database/chatmodeloptions/) and add any [GGUF chat model](https://huggingface.co/models?library=gguf) to use for offline chat. Make sure to use `Offline` as its type. For a balanced chat model that runs well on standard consumer hardware we recommend using [Hermes-2-Pro-Mistral-7B by NousResearch](https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF) by default.


 :::tip[Note]
@@ -101,6 +101,7 @@ sudo -u postgres createdb khoj --password

 ##### Local Server Setup
 - *Make sure [python](https://realpython.com/installing-python/) and [pip](https://pip.pypa.io/en/stable/installation/) are installed on your machine*
+- Check [llama-cpp-python setup](https://python.langchain.com/docs/integrations/llms/llamacpp#installation) if you hit any llama-cpp issues with the installation

 Run the following command in your terminal to install the Khoj backend.

@@ -108,17 +109,36 @@ Run the following command in your terminal to install the Khoj backend.
 <Tabs groupId="operating-systems">
 <TabItem value="macos" label="MacOS">
 ```shell
+# ARM/M1+ Machines
+MAKE_ARGS="-DLLAMA_METAL=on" python -m pip install khoj-assistant
+
+# Intel Machines
 python -m pip install khoj-assistant
 ```
 </TabItem>
 <TabItem value="win" label="Windows">
 ```shell
+# 1. (Optional) To use NVIDIA (CUDA) GPU
+$env:CMAKE_ARGS = "-DLLAMA_OPENBLAS=on"
+# 1. (Optional) To use AMD (ROCm) GPU
+CMAKE_ARGS="-DLLAMA_HIPBLAS=on"
+# 1. (Optional) To use VULCAN GPU
+CMAKE_ARGS="-DLLAMA_VULKAN=on"
+
+# 2. Install Khoj
 py -m pip install khoj-assistant
 ```
 </TabItem>
 <TabItem value="unix" label="Linux">
 ```shell
+# CPU
 python -m pip install khoj-assistant
+# NVIDIA (CUDA) GPU
+CMAKE_ARGS="DLLAMA_CUBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
+# AMD (ROCm) GPU
+CMAKE_ARGS="-DLLAMA_HIPBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
+# VULCAN GPU
+CMAKE_ARGS="-DLLAMA_VULKAN=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
 ```
 </TabItem>
 </Tabs>
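As a quick sanity check after any of the installs above (a suggestion, not part of the upstream docs), you can confirm that the llama-cpp-python backend imported cleanly:

```python
# Post-install sanity check (assumes khoj-assistant pulled in llama-cpp-python).
import llama_cpp

print("llama-cpp-python version:", llama_cpp.__version__)
```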
@@ -179,13 +199,13 @@ If you're using a custom domain, you must use an SSL certificate. You can use [L
 1. Go to http://localhost:42110/server/admin and login with your admin credentials.
 1. Go to [OpenAI settings](http://localhost:42110/server/admin/database/openaiprocessorconversationconfig/) in the server admin settings to add an OpenAI processor conversation config. This is where you set your API key. Alternatively, you can go to the [offline chat settings](http://localhost:42110/server/admin/database/offlinechatprocessorconversationconfig/) and simply create a new setting with `Enabled` set to `True`.
 2. Go to the ChatModelOptions if you want to add additional models for chat.
-   - Set the `chat-model` field to a supported chat model[^1] of your choice. For example, you can specify `gpt-4-turbo-preview` if you're using OpenAI or `mistral-7b-instruct-v0.1.Q4_0.gguf` if you're using offline chat.
+   - Set the `chat-model` field to a supported chat model[^1] of your choice. For example, you can specify `gpt-4-turbo-preview` if you're using OpenAI or `NousResearch/Hermes-2-Pro-Mistral-7B-GGUF` if you're using offline chat.
   - Make sure to set the `model-type` field to `OpenAI` or `Offline` respectively.
   - The `tokenizer` and `max-prompt-size` fields are optional. Set them only when using a non-standard model (i.e not mistral, gpt or llama2 model).
 1. Select files and folders to index [using the desktop client](/get-started/setup#2-download-the-desktop-client). When you click 'Save', the files will be sent to your server for indexing.
   - Select Notion workspaces and Github repositories to index using the web interface.

-[^1]: Khoj, by default, can use [OpenAI GPT3.5+ chat models](https://platform.openai.com/docs/models/overview) or [GPT4All chat models that follow Llama2 Prompt Template](https://github.com/nomic-ai/gpt4all/blob/main/gpt4all-chat/metadata/models2.json). See [this section](/miscellaneous/advanced#use-openai-compatible-llm-api-server-self-hosting) to use non-standard chat models
+[^1]: Khoj, by default, can use [OpenAI GPT3.5+ chat models](https://platform.openai.com/docs/models/overview) or [GGUF chat models](https://huggingface.co/models?library=gguf). See [this section](/miscellaneous/advanced#use-openai-compatible-llm-api-server-self-hosting) to use non-standard chat models

 :::tip[Note]
 Using Safari on Mac? You might not be able to login to the admin panel. Try using Chrome or Firefox instead.
@@ -10,4 +10,4 @@ Many Open Source projects are used to power Khoj. Here's a few of them:
 - Charles Cave for [OrgNode Parser](http://members.optusnet.com.au/~charles57/GTD/orgnode.html)
 - [Org.js](https://mooz.github.io/org-js/) to render Org-mode results on the Web interface
 - [Markdown-it](https://github.com/markdown-it/markdown-it) to render Markdown results on the Web interface
-- [GPT4All](https://github.com/nomic-ai/gpt4all) to chat with local LLM
+- [Llama.cpp](https://github.com/ggerganov/llama.cpp) to chat with local LLM
@@ -62,8 +62,7 @@ dependencies = [
     "pymupdf >= 1.23.5",
     "django == 4.2.10",
     "authlib == 1.2.1",
-    "gpt4all == 2.1.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
-    "gpt4all == 2.1.0; platform_system == 'Windows' or platform_system == 'Darwin'",
+    "llama-cpp-python == 0.2.56",
     "itsdangerous == 2.1.2",
     "httpx == 0.25.0",
     "pgvector == 0.2.4",
@@ -43,7 +43,7 @@ from khoj.search_filter.date_filter import DateFilter
 from khoj.search_filter.file_filter import FileFilter
 from khoj.search_filter.word_filter import WordFilter
 from khoj.utils import state
-from khoj.utils.config import GPT4AllProcessorModel
+from khoj.utils.config import OfflineChatProcessorModel
 from khoj.utils.helpers import generate_random_name, is_none_or_empty


@@ -705,8 +705,8 @@ class ConversationAdapters:
         conversation_config = ConversationAdapters.get_default_conversation_config()

         if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
-            if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
-                state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
+            if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
+                state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model)

         return conversation_config


@@ -80,7 +80,7 @@ class ChatModelOptions(BaseModel):

    max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
    tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
-    chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf")
+    chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF")
    model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)


src/khoj/migrations/migrate_offline_chat_default_model_2.py (new file, 71 lines)
@@ -0,0 +1,71 @@
+"""
+Current format of khoj.yml
+---
+app:
+    ...
+content-type:
+    ...
+processor:
+  conversation:
+    offline-chat:
+        enable-offline-chat: false
+        chat-model: mistral-7b-instruct-v0.1.Q4_0.gguf
+  ...
+search-type:
+    ...
+
+New format of khoj.yml
+---
+app:
+    ...
+content-type:
+    ...
+processor:
+  conversation:
+    offline-chat:
+        enable-offline-chat: false
+        chat-model: NousResearch/Hermes-2-Pro-Mistral-7B-GGUF
+  ...
+search-type:
+    ...
+"""
+import logging
+
+from packaging import version
+
+from khoj.utils.yaml import load_config_from_file, save_config_to_file
+
+logger = logging.getLogger(__name__)
+
+
+def migrate_offline_chat_default_model(args):
+    schema_version = "1.7.0"
+    raw_config = load_config_from_file(args.config_file)
+    previous_version = raw_config.get("version")
+
+    if "processor" not in raw_config:
+        return args
+    if raw_config["processor"] is None:
+        return args
+    if "conversation" not in raw_config["processor"]:
+        return args
+    if "offline-chat" not in raw_config["processor"]["conversation"]:
+        return args
+    if "chat-model" not in raw_config["processor"]["conversation"]["offline-chat"]:
+        return args
+
+    if previous_version is None or version.parse(previous_version) < version.parse(schema_version):
+        logger.info(
+            f"Upgrading config schema to {schema_version} from {previous_version} to change default (offline) chat model to mistral GGUF"
+        )
+        raw_config["version"] = schema_version
+
+        # Update offline chat model to use Nous Research's Hermes-2-Pro GGUF in path format suitable for llama-cpp
+        offline_chat_model = raw_config["processor"]["conversation"]["offline-chat"]["chat-model"]
+        if offline_chat_model == "mistral-7b-instruct-v0.1.Q4_0.gguf":
+            raw_config["processor"]["conversation"]["offline-chat"][
+                "chat-model"
+            ] = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
+
+        save_config_to_file(raw_config, args.config_file)
+    return args
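For context, a hypothetical way to exercise this migration by hand; the `args` object and config path below are illustrative placeholders, since in practice Khoj's own startup/migration flow supplies them.

```python
# Hypothetical manual invocation of the migration above. The args object and
# config file path are placeholders, not Khoj's actual wiring.
from types import SimpleNamespace

from khoj.migrations.migrate_offline_chat_default_model_2 import (
    migrate_offline_chat_default_model,
)

args = SimpleNamespace(config_file="/home/user/.khoj/khoj.yml")  # placeholder path
migrate_offline_chat_default_model(args)  # rewrites chat-model to the Hermes-2-Pro GGUF default
```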
@@ -1,13 +1,15 @@
+import json
 import logging
-from collections import deque
-from datetime import datetime
+from datetime import datetime, timedelta
 from threading import Thread
 from typing import Any, Iterator, List, Union

 from langchain.schema import ChatMessage
+from llama_cpp import Llama

 from khoj.database.models import Agent
 from khoj.processor.conversation import prompts
+from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
     ThreadedGenerator,
     generate_chatml_messages_with_context,
@@ -22,7 +24,7 @@ logger = logging.getLogger(__name__)

 def extract_questions_offline(
     text: str,
-    model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
+    model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
     loaded_model: Union[Any, None] = None,
     conversation_log={},
     use_history: bool = True,
@@ -32,22 +34,14 @@ def extract_questions_offline(
     """
     Infer search queries to retrieve relevant notes to answer user query
     """
-    try:
-        from gpt4all import GPT4All
-    except ModuleNotFoundError as e:
-        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
-        raise e
-
-    # Assert that loaded_model is either None or of type GPT4All
-    assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
-
     all_questions = text.split("? ")
     all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]

     if not should_extract_questions:
         return all_questions

-    gpt4all_model = loaded_model or GPT4All(model)
+    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
+    offline_chat_model = loaded_model or download_model(model)

     location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"

@@ -56,37 +50,36 @@ def extract_questions_offline(

     if use_history:
         for chat in conversation_log.get("chat", [])[-4:]:
-            if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image":
+            if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type"):
                 chat_history += f"Q: {chat['intent']['query']}\n"
-                chat_history += f"A: {chat['message']}\n"
+                chat_history += f"Khoj: {chat['message']}\n\n"

-    current_date = datetime.now().strftime("%Y-%m-%d")
-    last_year = datetime.now().year - 1
-    last_christmas_date = f"{last_year}-12-25"
-    next_christmas_date = f"{datetime.now().year}-12-25"
-    system_prompt = prompts.system_prompt_extract_questions_gpt4all.format(
-        message=(prompts.system_prompt_message_extract_questions_gpt4all)
-    )
-    example_questions = prompts.extract_questions_gpt4all_sample.format(
+    today = datetime.today()
+    yesterday = (today - timedelta(days=1)).strftime("%Y-%m-%d")
+    last_year = today.year - 1
+    example_questions = prompts.extract_questions_offline.format(
         query=text,
         chat_history=chat_history,
-        current_date=current_date,
+        current_date=today.strftime("%Y-%m-%d"),
+        yesterday_date=yesterday,
         last_year=last_year,
-        last_christmas_date=last_christmas_date,
-        next_christmas_date=next_christmas_date,
+        this_year=today.year,
         location=location,
     )
-    message = system_prompt + example_questions
+    messages = generate_chatml_messages_with_context(
+        example_questions, model_name=model, loaded_model=offline_chat_model
+    )
+
     state.chat_lock.acquire()
     try:
-        response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=512)
+        response = send_message_to_model_offline(messages, loaded_model=offline_chat_model)
     finally:
         state.chat_lock.release()

     # Extract, Clean Message from GPT's Response
     try:
         # This will expect to be a list with a single string with a list of questions
-        questions = (
+        questions_str = (
             str(response)
             .strip(empty_escape_sequences)
             .replace("['", '["')
@@ -94,11 +87,8 @@ def extract_questions_offline(
             .replace("</s>", "")
             .replace("']", '"]')
             .replace("', '", '", "')
-            .replace('["', "")
-            .replace('"]', "")
-            .split("? ")
         )
-        questions = [q + "?" for q in questions[:-1]] + [questions[-1]]
+        questions: List[str] = json.loads(questions_str)
         questions = filter_questions(questions)
     except:
         logger.warning(f"Llama returned invalid JSON. Falling back to using user message as search query.\n{response}")
@@ -121,12 +111,12 @@ def filter_questions(questions: List[str]):
         "do not know",
         "do not understand",
     ]
-    filtered_questions = []
+    filtered_questions = set()
     for q in questions:
         if not any([word in q.lower() for word in hint_words]) and not is_none_or_empty(q):
-            filtered_questions.append(q)
+            filtered_questions.add(q)

-    return filtered_questions
+    return list(filtered_questions)


 def converse_offline(
@@ -134,7 +124,7 @@ def converse_offline(
     references=[],
     online_results=[],
     conversation_log={},
-    model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
+    model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
     loaded_model: Union[Any, None] = None,
     completion_func=None,
     conversation_commands=[ConversationCommand.Default],
@@ -147,25 +137,19 @@ def converse_offline(
     """
     Converse with user using Llama
     """
-    try:
-        from gpt4all import GPT4All
-    except ModuleNotFoundError as e:
-        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
-        raise e
-
-    assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
-    gpt4all_model = loaded_model or GPT4All(model)
     # Initialize Variables
+    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
+    offline_chat_model = loaded_model or download_model(model)
     compiled_references_message = "\n\n".join({f"{item}" for item in references})

     current_date = datetime.now().strftime("%Y-%m-%d")

     if agent and agent.personality:
-        system_prompt = prompts.custom_system_prompt_message_gpt4all.format(
+        system_prompt = prompts.custom_system_prompt_offline_chat.format(
             name=agent.name, bio=agent.personality, current_date=current_date
         )
     else:
-        system_prompt = prompts.system_prompt_message_gpt4all.format(current_date=current_date)
+        system_prompt = prompts.system_prompt_offline_chat.format(current_date=current_date)

     conversation_primer = prompts.query_prompt.format(query=user_query)

@@ -193,7 +177,7 @@ def converse_offline(

         conversation_primer = f"{prompts.online_search_conversation.format(online_results=str(simplified_online_results))}\n{conversation_primer}"
     if not is_none_or_empty(compiled_references_message):
-        conversation_primer = f"{prompts.notes_conversation_gpt4all.format(references=compiled_references_message)}\n{conversation_primer}"
+        conversation_primer = f"{prompts.notes_conversation_offline.format(references=compiled_references_message)}\n{conversation_primer}"

     # Setup Prompt with Primer or Conversation History
     messages = generate_chatml_messages_with_context(
@@ -201,72 +185,44 @@ def converse_offline(
         system_prompt,
         conversation_log,
         model_name=model,
+        loaded_model=offline_chat_model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
     )

     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, gpt4all_model))
+    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model))
     t.start()
     return g


 def llm_thread(g, messages: List[ChatMessage], model: Any):
-    user_message = messages[-1]
-    system_message = messages[0]
-    conversation_history = messages[1:-1]
-
-    formatted_messages = [
-        prompts.khoj_message_gpt4all.format(message=message.content)
-        if message.role == "assistant"
-        else prompts.user_message_gpt4all.format(message=message.content)
-        for message in conversation_history
-    ]
-
     stop_phrases = ["<s>", "INST]", "Notes:"]
-    chat_history = "".join(formatted_messages)
-    templated_system_message = prompts.system_prompt_gpt4all.format(message=system_message.content)
-    templated_user_message = prompts.user_message_gpt4all.format(message=user_message.content)
-    prompted_message = templated_system_message + chat_history + templated_user_message
-    response_queue: deque[str] = deque(maxlen=3)  # Create a response queue with a maximum length of 3
-    hit_stop_phrase = False

     state.chat_lock.acquire()
-    response_iterator = send_message_to_model_offline(prompted_message, loaded_model=model, streaming=True)
     try:
+        response_iterator = send_message_to_model_offline(
+            messages, loaded_model=model, stop=stop_phrases, streaming=True
+        )
         for response in response_iterator:
-            response_queue.append(response)
-            hit_stop_phrase = any(stop_phrase in "".join(response_queue) for stop_phrase in stop_phrases)
-            if hit_stop_phrase:
-                logger.debug(f"Stop response as hit stop phrase: {''.join(response_queue)}")
-                break
-            # Start streaming the response at a lag once the queue is full
-            # This allows stop word testing before sending the response
-            if len(response_queue) == response_queue.maxlen:
-                g.send(response_queue[0])
+            g.send(response["choices"][0]["delta"].get("content", ""))
     finally:
-        if not hit_stop_phrase:
-            if len(response_queue) == response_queue.maxlen:
-                # remove already sent reponse chunk
-                response_queue.popleft()
-            # send the remaining response
-            g.send("".join(response_queue))
         state.chat_lock.release()
     g.close()


 def send_message_to_model_offline(
-    message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False, system_message=""
-) -> str:
-    try:
-        from gpt4all import GPT4All
-    except ModuleNotFoundError as e:
-        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
-        raise e
-
-    assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
-    gpt4all_model = loaded_model or GPT4All(model)
-
-    return gpt4all_model.generate(
-        system_message + message, max_tokens=200, top_k=2, temp=0, n_batch=512, streaming=streaming
-    )
+    messages: List[ChatMessage],
+    loaded_model=None,
+    model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
+    streaming=False,
+    stop=[],
+):
+    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
+    offline_chat_model = loaded_model or download_model(model)
+    messages_dict = [{"role": message.role, "content": message.content} for message in messages]
+    response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming)
+    if streaming:
+        return response
+    else:
+        return response["choices"][0]["message"].get("content", "")
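A note on the streaming branch above: with `stream=True`, llama-cpp-python's `create_chat_completion` yields OpenAI-style chunks whose incremental text sits under `delta`, which is why `llm_thread` reads `response["choices"][0]["delta"].get("content", "")`. A minimal, self-contained sketch of that consumption pattern (the model and message used here are placeholders for illustration):

```python
# Minimal sketch of consuming llama-cpp-python's streamed chat completion chunks.
from llama_cpp import Llama

llm = Llama.from_pretrained(
    repo_id="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", filename="*Q4_K_M.gguf", n_ctx=0
)
messages = [{"role": "user", "content": "Say hello in one sentence."}]

for chunk in llm.create_chat_completion(messages, stream=True, stop=["<s>"]):
    # Each streamed chunk carries an incremental piece of text under "delta".
    print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
```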
@@ -1,43 +1,54 @@
+import glob
 import logging
+import os

+from huggingface_hub.constants import HF_HUB_CACHE
+
 from khoj.utils import state

 logger = logging.getLogger(__name__)


-def download_model(model_name: str):
-    try:
-        import gpt4all
-    except ModuleNotFoundError as e:
-        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
-        raise e
+def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
+    from llama_cpp.llama import Llama
+
+    # Initialize Model Parameters. Use n_ctx=0 to get context size from the model
+    kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False}

     # Decide whether to load model to GPU or CPU
-    chat_model_config = None
+    device = "gpu" if state.chat_on_gpu and state.device != "cpu" else "cpu"
+    kwargs["n_gpu_layers"] = -1 if device == "gpu" else 0
+
+    # Check if the model is already downloaded
+    model_path = load_model_from_cache(repo_id, filename)
+    chat_model = None
     try:
-        # Download the chat model and its config
-        chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
-
-        # Try load chat model to GPU if:
-        # 1. Loading chat model to GPU isn't disabled via CLI and
-        # 2. Machine has GPU
-        # 3. GPU has enough free memory to load the chat model with max context length of 4096
-        device = (
-            "gpu"
-            if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"], 4096)
-            else "cpu"
-        )
-    except ValueError:
-        device = "cpu"
-    except Exception as e:
-        if chat_model_config is None:
-            device = "cpu"  # Fallback to CPU as can't determine if GPU has enough memory
-            logger.debug(f"Unable to download model config from gpt4all website: {e}")
+        if model_path:
+            chat_model = Llama(model_path, **kwargs)
         else:
-            raise e
+            Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
+    except:
+        # Load model on CPU if GPU is not available
+        kwargs["n_gpu_layers"], device = 0, "cpu"
+        if model_path:
+            chat_model = Llama(model_path, **kwargs)
+        else:
+            chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)

-    # Now load the downloaded chat model onto appropriate device
-    chat_model = gpt4all.GPT4All(model_name=model_name, n_ctx=4096, device=device, allow_download=False)
-    logger.debug(f"Loaded chat model to {device.upper()}.")
+    logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}")

     return chat_model
+
+
+def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
+    # Construct the path to the model file in the cache directory
+    repo_org, repo_name = repo_id.split("/")
+    object_id = "--".join([repo_type, repo_org, repo_name])
+    model_path = os.path.sep.join([HF_HUB_CACHE, object_id, "snapshots", "**", filename])
+
+    # Check if the model file exists
+    paths = glob.glob(model_path)
+    if paths:
+        return paths[0]
+    else:
+        return None
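Taken together, the two helpers above let a caller resolve a GGUF model with one call. A hedged usage sketch (the repo id is the new default; return behaviour as described in the diff above):

```python
# Hedged usage sketch of the new helper: it reuses the Hugging Face cache when
# the GGUF file is already present, otherwise downloads it, and returns a
# llama.cpp Llama handle loaded onto GPU when available (falling back to CPU).
from khoj.processor.conversation.offline.utils import download_model

chat_model = download_model("NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", filename="*Q4_K_M.gguf")
if chat_model is not None:
    print("Context window:", chat_model.n_ctx())
```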
@@ -101,8 +101,3 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_
         chat(messages=messages)

     g.close()
-
-
-def extract_summaries(metadata):
-    """Extract summaries from metadata"""
-    return "".join([f'\n{session["summary"]}' for session in metadata])
@@ -65,9 +65,9 @@ no_entries_found = PromptTemplate.from_template(
 """.strip()
 )

-## Conversation Prompts for GPT4All Models
+## Conversation Prompts for Offline Chat Models
 ## --
-system_prompt_message_gpt4all = PromptTemplate.from_template(
+system_prompt_offline_chat = PromptTemplate.from_template(
     """
 You are Khoj, a smart, inquisitive and helpful personal assistant.
 - Use your general knowledge and past conversation with the user as context to inform your responses.
@@ -79,7 +79,7 @@ Today is {current_date} in UTC.
 """.strip()
 )

-custom_system_prompt_message_gpt4all = PromptTemplate.from_template(
+custom_system_prompt_offline_chat = PromptTemplate.from_template(
     """
 You are {name}, a personal agent on Khoj.
 - Use your general knowledge and past conversation with the user as context to inform your responses.
@@ -93,40 +93,6 @@ Instructions:\n{bio}
 """.strip()
 )

-system_prompt_message_extract_questions_gpt4all = f"""You are Khoj, a kind and intelligent personal assistant. When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
-- Write the question as if you can search for the answer on the user's personal notes.
-- Try to be as specific as possible. Instead of saying "they" or "it" or "he", use the name of the person or thing you are referring to. For example, instead of saying "Which store did they go to?", say "Which store did Alice and Bob go to?".
-- Add as much context from the previous questions and notes as required into your search queries.
-- Provide search queries as a list of questions
-What follow-up questions, if any, will you need to ask to answer the user's question?
-"""
-
-system_prompt_gpt4all = PromptTemplate.from_template(
-    """
-<s>[INST] <<SYS>>
-{message}
-<</SYS>>Hi there! [/INST] Hello! How can I help you today? </s>"""
-)
-
-system_prompt_extract_questions_gpt4all = PromptTemplate.from_template(
-    """
-<s>[INST] <<SYS>>
-{message}
-<</SYS>>[/INST]</s>"""
-)
-
-user_message_gpt4all = PromptTemplate.from_template(
-    """
-<s>[INST] {message} [/INST]
-""".strip()
-)
-
-khoj_message_gpt4all = PromptTemplate.from_template(
-    """
-{message}</s>
-""".strip()
-)
-
 ## Notes Conversation
 ## --
 notes_conversation = PromptTemplate.from_template(
@@ -139,7 +105,7 @@ Notes:
 """.strip()
 )

-notes_conversation_gpt4all = PromptTemplate.from_template(
+notes_conversation_offline = PromptTemplate.from_template(
     """
 User's Notes:
 {references}
@@ -191,58 +157,50 @@ Query: {query}""".strip()
 )


-## Summarize Notes
-## --
-summarize_notes = PromptTemplate.from_template(
-    """
-Summarize the below notes about {user_query}:
-
-{text}
-
-Summarize the notes in second person perspective:"""
-)
-
-
-## Answer
-## --
-answer = PromptTemplate.from_template(
-    """
-You are a friendly, helpful personal assistant.
-Using the users notes below, answer their following question. If the answer is not contained within the notes, say "I don't know."
-
-Notes:
-{text}
-
-Question: {user_query}
-
-Answer (in second person):"""
-)
-
-
 ## Extract Questions
 ## --
-extract_questions_gpt4all_sample = PromptTemplate.from_template(
+extract_questions_offline = PromptTemplate.from_template(
     """
-<s>[INST] <<SYS>>Current Date: {current_date}. User's Location: {location}<</SYS>> [/INST]</s>
-<s>[INST] How was my trip to Cambodia? [/INST]
-How was my trip to Cambodia?</s>
-<s>[INST] Who did I visit the temple with on that trip? [/INST]
-Who did I visit the temple with in Cambodia?</s>
-<s>[INST] How should I take care of my plants? [/INST]
-What kind of plants do I have? What issues do my plants have?</s>
-<s>[INST] How many tennis balls fit in the back of a 2002 Honda Civic? [/INST]
-What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?</s>
-<s>[INST] What did I do for Christmas last year? [/INST]
-What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'</s>
-<s>[INST] How are you feeling today? [/INST]</s>
-<s>[INST] Is Alice older than Bob? [/INST]
-When was Alice born? What is Bob's age?</s>
-<s>[INST] <<SYS>>
-Use these notes from the user's previous conversations to provide a response:
+You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the user's notes. Construct search queries to retrieve relevant information to answer the user's question.
+- You will be provided past questions(Q) and answers(A) for context.
+- Try to be as specific as possible. Instead of saying "they" or "it" or "he", use proper nouns like name of the person or thing you are referring to.
+- Add as much context from the previous questions and answers as required into your search queries.
+- Break messages into multiple search queries when required to retrieve the relevant information.
+- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
+
+Current Date: {current_date}
+User's Location: {location}
+
+Examples:
+Q: How was my trip to Cambodia?
+Khoj: ["How was my trip to Cambodia?"]

+Q: Who did I visit the temple with on that trip?
+Khoj: ["Who did I visit the temple with in Cambodia?"]
+
+Q: Which of them is older?
+Khoj: ["When was Alice born?", "What is Bob's age?"]
+
+Q: Where did John say he was? He mentioned it in our call last week.
+Khoj: ["Where is John? dt>='{last_year}-12-25' dt<'{last_year}-12-26'", "John's location in call notes"]
+
+Q: How can you help me?
+Khoj: ["Social relationships", "Physical and mental health", "Education and career", "Personal life goals and habits"]
+
+Q: What did I do for Christmas last year?
+Khoj: ["What did I do for Christmas {last_year} dt>='{last_year}-12-25' dt<'{last_year}-12-26'"]
+
+Q: How should I take care of my plants?
+Khoj: ["What kind of plants do I have?", "What issues do my plants have?"]
+
+Q: Who all did I meet here yesterday?
+Khoj: ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]
+
+Chat History:
 {chat_history}
-<</SYS>> [/INST]</s>
-<s>[INST] {query} [/INST]
-"""
+What searches will you perform to answer the following question, using the chat history as reference? Respond with relevant search queries as list of strings.
+Q: {query}
+""".strip()
 )


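The new prompt above asks the model to answer with a plain JSON list of query strings, which `extract_questions_offline` then parses with `json.loads`. A small illustration of that output contract (the model reply shown here is a made-up example):

```python
# Illustration of the output contract behind the new extract-questions prompt:
# the model is expected to reply with a JSON list of search query strings.
import json

raw_model_reply = '["What kind of plants do I have?", "What issues do my plants have?"]'  # made-up reply
queries = json.loads(raw_model_reply)
assert isinstance(queries, list) and all(isinstance(q, str) for q in queries)
```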
@@ -260,7 +218,7 @@ User's Location: {location}

 Q: How was my trip to Cambodia?
 Khoj: {{"queries": ["How was my trip to Cambodia?"]}}
-A: The trip was amazing. I went to the Angkor Wat temple and it was beautiful.
+A: The trip was amazing. You went to the Angkor Wat temple and it was beautiful.

 Q: Who did i visit that temple with?
 Khoj: {{"queries": ["Who did I visit the Angkor Wat Temple in Cambodia with?"]}}
@@ -286,8 +244,8 @@ Q: What is their age difference?
 Khoj: {{"queries": ["What is Bob's age?", "What is Tom's age?"]}}
 A: Bob is {bob_tom_age_difference} years older than Tom. As Bob is {bob_age} years old and Tom is 30 years old.

-Q: What does yesterday's note say?
-Khoj: {{"queries": ["Note from {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}}
+Q: Who all did I meet here yesterday?
+Khoj: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}}
 A: Yesterday's note mentions your visit to your local beach with Ram and Shyam.

 {chat_history}
@@ -543,7 +501,6 @@ help_message = PromptTemplate.from_template(
 - **/image**: Generate an image based on your message.
 - **/help**: Show this help message.

-
 You are using the **{model}** model on the **{device}**.
 **version**: {version}
 """.strip()
@@ -3,29 +3,28 @@ import logging
 import queue
 from datetime import datetime
 from time import perf_counter
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 import tiktoken
 from langchain.schema import ChatMessage
+from llama_cpp.llama import Llama
 from transformers import AutoTokenizer

 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import ClientApplication, KhojUser
+from khoj.processor.conversation.offline.utils import download_model
 from khoj.utils.helpers import is_none_or_empty, merge_dicts

 logger = logging.getLogger(__name__)
 model_to_prompt_size = {
-    "gpt-3.5-turbo": 3000,
-    "gpt-3.5-turbo-0125": 3000,
-    "gpt-4-0125-preview": 7000,
-    "gpt-4-turbo-preview": 7000,
-    "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
-    "mistral-7b-instruct-v0.1.Q4_0.gguf": 1548,
-}
-model_to_tokenizer = {
-    "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
-    "mistral-7b-instruct-v0.1.Q4_0.gguf": "mistralai/Mistral-7B-Instruct-v0.1",
+    "gpt-3.5-turbo": 12000,
+    "gpt-3.5-turbo-0125": 12000,
+    "gpt-4-0125-preview": 20000,
+    "gpt-4-turbo-preview": 20000,
+    "TheBloke/Mistral-7B-Instruct-v0.2-GGUF": 3500,
+    "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF": 3500,
 }
+model_to_tokenizer: Dict[str, str] = {}


 class ThreadedGenerator:
@@ -134,9 +133,10 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}

 def generate_chatml_messages_with_context(
     user_message,
-    system_message,
+    system_message=None,
     conversation_log={},
     model_name="gpt-3.5-turbo",
+    loaded_model: Optional[Llama] = None,
     max_prompt_size=None,
     tokenizer_name=None,
 ):
@@ -159,7 +159,7 @@ def generate_chatml_messages_with_context(
         chat_notes = f'\n\n Notes:\n{chat.get("context")}' if chat.get("context") else "\n"
         chat_logs += [chat["message"] + chat_notes]

-    rest_backnforths = []
+    rest_backnforths: List[ChatMessage] = []
     # Extract in reverse chronological order
     for user_msg, assistant_msg in zip(chat_logs[-2::-2], chat_logs[::-2]):
         if len(rest_backnforths) >= 2 * lookback_turns:
@@ -176,21 +176,30 @@ def generate_chatml_messages_with_context(
         messages.append(ChatMessage(content=system_message, role="system"))

     # Truncate oldest messages from conversation history until under max supported prompt size by model
-    messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
+    messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name)

     # Return message in chronological order
     return messages[::-1]


 def truncate_messages(
-    messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None
+    messages: list[ChatMessage],
+    max_prompt_size,
+    model_name: str,
+    loaded_model: Optional[Llama] = None,
+    tokenizer_name=None,
 ) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""

     try:
-        if model_name.startswith("gpt-"):
+        if loaded_model:
+            encoder = loaded_model.tokenizer()
+        elif model_name.startswith("gpt-"):
             encoder = tiktoken.encoding_for_model(model_name)
         else:
+            try:
+                encoder = download_model(model_name).tokenizer()
+            except:
                 encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
     except:
         default_tokenizer = "hf-internal-testing/llama-tokenizer"
|
||||||
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
|
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
|
||||||
original_question = f"\n{original_question}"
|
original_question = f"\n{original_question}"
|
||||||
original_question_tokens = len(encoder.encode(original_question))
|
original_question_tokens = len(encoder.encode(original_question))
|
||||||
remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens
|
remaining_tokens = max_prompt_size - system_message_tokens
|
||||||
|
if remaining_tokens > original_question_tokens:
|
||||||
|
remaining_tokens -= original_question_tokens
|
||||||
truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
|
truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
|
||||||
|
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
|
||||||
|
else:
|
||||||
|
truncated_message = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
|
||||||
|
messages = [ChatMessage(content=truncated_message, role=messages[0].role)]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
|
f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
|
||||||
)
|
)
|
||||||
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
|
|
||||||
|
|
||||||
return messages + [system_message] if system_message else messages
|
return messages + [system_message] if system_message else messages
|
||||||
|
|
||||||
|
|
|
@@ -35,7 +35,7 @@ from khoj.search_filter.file_filter import FileFilter
 from khoj.search_filter.word_filter import WordFilter
 from khoj.search_type import image_search, text_search
 from khoj.utils import constants, state
-from khoj.utils.config import GPT4AllProcessorModel
+from khoj.utils.config import OfflineChatProcessorModel
 from khoj.utils.helpers import ConversationCommand, timer
 from khoj.utils.rawconfig import LocationData, SearchResponse
 from khoj.utils.state import SearchType
@ -318,16 +318,16 @@ async def extract_references_and_questions(
using_offline_chat = True
default_offline_llm = await ConversationAdapters.get_default_offline_llm()
chat_model = default_offline_llm.chat_model
if state.gpt4all_processor_config is None:
if state.offline_chat_processor_config is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)

loaded_model = state.gpt4all_processor_config.loaded_model
loaded_model = state.offline_chat_processor_config.loaded_model

inferred_queries = extract_questions_offline(
defiltered_query,
loaded_model=loaded_model,
conversation_log=meta_log,
should_extract_questions=False,
should_extract_questions=True,
location_data=location_data,
)
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
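Several hunks in this commit repeat the same lazy-initialization pattern: the offline chat model is only constructed (its weights downloaded and loaded) the first time it is needed, then reused from shared state. A rough, hedged sketch of that pattern with stand-in names (`AppState`, `load_weights`, `OfflineChatProcessor`), not the actual Khoj classes:

```python
# Hedged sketch of the load-once pattern used for state.offline_chat_processor_config.
# AppState, load_weights and OfflineChatProcessor are stand-ins, not Khoj's API.
from dataclasses import dataclass
from typing import Any, Optional

def load_weights(chat_model: str) -> Any:
    # Placeholder for the expensive download + llama.cpp model load
    return f"<loaded {chat_model}>"

class OfflineChatProcessor:
    def __init__(self, chat_model: str):
        self.chat_model = chat_model
        self.loaded_model = load_weights(chat_model)

@dataclass
class AppState:
    offline_chat_processor_config: Optional[OfflineChatProcessor] = None

state = AppState()

def get_offline_model(chat_model: str) -> Any:
    # Construct the processor only once per process, then reuse the loaded model
    if state.offline_chat_processor_config is None:
        state.offline_chat_processor_config = OfflineChatProcessor(chat_model)
    return state.offline_chat_processor_config.loaded_model

print(get_offline_model("NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"))
```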
@ -33,7 +33,7 @@ from khoj.processor.conversation.utils import (
)
from khoj.routers.storage import upload_image
from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import (
ConversationCommand,
is_none_or_empty,
@ -69,9 +69,9 @@ async def is_ready_to_chat(user: KhojUser):

if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline":
chat_model = user_conversation_config.chat_model
if state.gpt4all_processor_config is None:
if state.offline_chat_processor_config is None:
logger.info("Loading Offline Chat Model...")
state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
return True

ready = has_openai_config or has_offline_config
@ -327,10 +327,13 @@ async def generate_better_image_prompt(
Generate a better image prompt from the given query
"""

location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")

if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
else:
location_prompt = "Unknown"

user_references = "\n\n".join([f"# {item}" for item in note_references])
@ -368,27 +371,31 @@ async def send_message_to_model_wrapper(
if conversation_config is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")

truncated_messages = generate_chatml_messages_with_context(
chat_model = conversation_config.chat_model
user_message=message, system_message=system_message, model_name=conversation_config.chat_model
)

if conversation_config.model_type == "offline":
if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model)

loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
user_message=message, system_message=system_message, model_name=chat_model, loaded_model=loaded_model
)

loaded_model = state.gpt4all_processor_config.loaded_model
return send_message_to_model_offline(
message=truncated_messages[-1].content,
messages=truncated_messages,
loaded_model=loaded_model,
model=conversation_config.chat_model,
model=chat_model,
streaming=False,
system_message=truncated_messages[0].content,
)

elif conversation_config.model_type == "openai":
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
api_key = openai_chat_config.api_key
chat_model = conversation_config.chat_model
truncated_messages = generate_chatml_messages_with_context(
user_message=message, system_message=system_message, model_name=chat_model
)

openai_response = send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
)
@ -434,10 +441,10 @@ def generate_chat_response(

conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
if conversation_config.model_type == "offline":
if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model)

loaded_model = state.gpt4all_processor_config.loaded_model
loaded_model = state.offline_chat_processor_config.loaded_model
chat_response = converse_offline(
references=compiled_references,
online_results=online_results,
@ -70,15 +70,12 @@ class SearchModels:

@dataclass
class GPT4AllProcessorConfig:
class OfflineChatProcessorConfig:
loaded_model: Union[Any, None] = None

class GPT4AllProcessorModel:
class OfflineChatProcessorModel:
def __init__(
def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"):
self,
chat_model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
):
self.chat_model = chat_model
self.loaded_model = None
try:
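The class rename above reflects the switch from the gpt4all bindings to llama.cpp. A hedged sketch of what loading a GGUF chat model looks like with llama-cpp-python; the quantization filename, context size, and chat call below are illustrative choices, not the exact loader Khoj ships:

```python
# Hedged sketch: load a GGUF chat model with llama-cpp-python and run one chat turn.
# Assumes llama-cpp-python with huggingface_hub installed; parameters are illustrative.
from llama_cpp import Llama

llm = Llama.from_pretrained(
    repo_id="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",  # new default offline chat model
    filename="*Q4_K_M.gguf",  # pick a quantization available in the repo
    n_ctx=4096,               # larger context window than the old gpt4all default
    verbose=False,
)

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello, who are you?"}],
)
print(response["choices"][0]["message"]["content"])
```

Because llama.cpp ships the model's own chat template in the GGUF metadata, `create_chat_completion` can format prompts for non-Llama2 chat models automatically, which is what lets this commit drop the hard requirement on the Llama chat prompt format.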
@ -6,7 +6,7 @@ empty_escape_sequences = "\n|\r|\t| "
app_env_filepath = "~/.khoj/env"
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
content_directory = "~/.khoj/content/"
default_offline_chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf"
default_offline_chat_model = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
default_online_chat_model = "gpt-4-turbo-preview"

empty_config = {
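Note the shape of the default changed: the old value was a bare gpt4all filename, while the new one is a HuggingFace repo id whose GGUF file still has to be resolved and downloaded. A hedged sketch of that resolution step using huggingface_hub; the exact quantization filename is an assumption, check the repo's file list:

```python
# Hedged sketch: resolve the repo-id style default to a local GGUF file path.
# The filename below is an assumed quantization from that repo, not pinned by Khoj.
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
    filename="Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf",
)
print(model_path)
```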
@ -32,17 +32,13 @@ def initialization():
)

try:
# Note: gpt4all package is not available on all devices.
# So ensure gpt4all package is installed before continuing this step.
import gpt4all

use_offline_model = input("Use offline chat model? (y/n): ")
if use_offline_model == "y":
logger.info("🗣️ Setting up offline chat model")
OfflineChatProcessorConversationConfig.objects.create(enabled=True)

offline_chat_model = input(
f"Enter the offline chat model you want to use, See GPT4All for supported models (default: {default_offline_chat_model}): "
f"Enter the offline chat model you want to use. See HuggingFace for available GGUF models (default: {default_offline_chat_model}): "
)
if offline_chat_model == "":
ChatModelOptions.objects.create(
@ -91,7 +91,7 @@ class OpenAIProcessorConfig(ConfigBase):

class OfflineChatProcessorConfig(ConfigBase):
enable_offline_chat: Optional[bool] = False
chat_model: Optional[str] = "mistral-7b-instruct-v0.1.Q4_0.gguf"
chat_model: Optional[str] = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"

class ConversationProcessorConfig(ConfigBase):
@ -9,7 +9,7 @@ from whisper import Whisper

from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, GPT4AllProcessorModel, SearchModels
from khoj.utils.config import ContentIndex, OfflineChatProcessorModel, SearchModels
from khoj.utils.helpers import LRU, get_device
from khoj.utils.rawconfig import FullConfig
@ -20,7 +20,7 @@ embeddings_model: Dict[str, EmbeddingsModel] = None
cross_encoder_model: Dict[str, CrossEncoderModel] = None
content_index = ContentIndex()
openai_client: OpenAI = None
gpt4all_processor_config: GPT4AllProcessorModel = None
offline_chat_processor_config: OfflineChatProcessorModel = None
whisper_model: Whisper = None
config_file: Path = None
verbose: int = 0
@ -40,9 +40,9 @@ class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
class Meta:
model = ChatModelOptions

max_prompt_size = 2000
max_prompt_size = 3500
tokenizer = None
chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf"
chat_model = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
model_type = "offline"
@ -96,3 +96,23 @@ class TestTruncateMessage:
assert final_tokens <= self.max_prompt_size
assert len(chat_messages) == 1
assert truncated_chat_history[0] != copy_big_chat_message

def test_truncate_single_large_question(self):
# Arrange
big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
big_chat_message = ChatMessageFactory.build(content=big_chat_message_content)
big_chat_message.role = "user"
copy_big_chat_message = big_chat_message.copy()
chat_messages = [big_chat_message]
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])

# Act
truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])

# Assert
# The original object has been modified. Verify certain properties
assert initial_tokens > self.max_prompt_size
assert final_tokens <= self.max_prompt_size
assert len(chat_messages) == 1
assert truncated_chat_history[0] != copy_big_chat_message
@ -5,18 +5,12 @@ import pytest
SKIP_TESTS = True
pytestmark = pytest.mark.skipif(
SKIP_TESTS,
reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.",
reason="Disable in CI to avoid long test runs.",
)

import freezegun
from freezegun import freeze_time

try:
from gpt4all import GPT4All
except ModuleNotFoundError as e:
print("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")

from khoj.processor.conversation.offline.chat_model import (
converse_offline,
extract_questions_offline,
|
@ -25,14 +19,12 @@ from khoj.processor.conversation.offline.chat_model import (
|
||||||
from khoj.processor.conversation.offline.utils import download_model
|
from khoj.processor.conversation.offline.utils import download_model
|
||||||
from khoj.processor.conversation.utils import message_to_log
|
from khoj.processor.conversation.utils import message_to_log
|
||||||
from khoj.routers.helpers import aget_relevant_output_modes
|
from khoj.routers.helpers import aget_relevant_output_modes
|
||||||
|
from khoj.utils.constants import default_offline_chat_model
|
||||||
MODEL_NAME = "mistral-7b-instruct-v0.1.Q4_0.gguf"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def loaded_model():
|
def loaded_model():
|
||||||
download_model(MODEL_NAME)
|
return download_model(default_offline_chat_model)
|
||||||
return GPT4All(MODEL_NAME)
|
|
||||||
|
|
||||||
|
|
||||||
freezegun.configure(extend_ignore_list=["transformers"])
|
freezegun.configure(extend_ignore_list=["transformers"])
|
||||||
|
@ -40,7 +32,6 @@ freezegun.configure(extend_ignore_list=["transformers"])

# Test
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Search actor isn't very date aware nor capable of formatting")
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
@ -149,20 +140,22 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
message_list = [
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
]
query = "Does he have any sons?"

# Act
response = extract_questions_offline(
"Does he have any sons?",
query,
conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model,
use_history=True,
)

all_expected_in_response = [
any_expected_with_barbara = [
"Anderson",
"sibling",
"brother",
]

any_expected_in_response = [
any_expected_with_anderson = [
"son",
"sons",
"children",

@ -170,11 +163,20 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):

# Assert
assert len(response) >= 1
assert all([expected_response in response[0] for expected_response in all_expected_in_response]), (
assert response[-1] == query, "Expected last question to be the user query, but got: " + response[-1]
"Expected chat actor to ask for clarification in response, but got: " + response[0]
# Ensure the remaining generated search queries use proper nouns and chat history context
for question in response[:-1]:
if "Barbara" in question:
assert any([expected_relation in question for expected_relation in any_expected_with_barbara]), (
"Expected search queries using proper nouns and chat history for context, but got: " + question
)
assert any([expected_response in response[0] for expected_response in any_expected_in_response]), (
elif "Anderson" in question:
"Expected chat actor to ask for clarification in response, but got: " + response[0]
assert any([expected_response in question for expected_response in any_expected_with_anderson]), (
"Expected search queries using proper nouns and chat history for context, but got: " + question
)
else:
assert False, (
"Expected search queries using proper nouns and chat history for context, but got: " + question
)
@ -312,6 +314,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):

# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor lies when it doesn't know the answer")
@pytest.mark.chatquality
def test_refuse_answering_unanswerable_question(loaded_model):
"Chat actor should not try make up answers to unanswerable questions."
@ -436,7 +439,6 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded

# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor doesn't ask clarifying questions when context is insufficient")
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
"Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
@ -15,7 +15,7 @@ from tests.helpers import ConversationFactory
SKIP_TESTS = True
pytestmark = pytest.mark.skipif(
SKIP_TESTS,
reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.",
reason="Disable in CI to avoid long test runs.",
)

fake = Faker()
@ -48,7 +48,7 @@ def create_conversation(message_list, user, agent=None):
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_chat):
def test_offline_chat_with_no_chat_history_or_retrieved_content(client_offline_chat):
# Act
response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
response_message = response.content.decode("utf-8")
@ -339,7 +339,7 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(client_off

# Assert
assert response.status_code == 200
assert "23" in response_message
assert "26" in response_message

# ----------------------------------------------------------------------------------------------------