mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Add additional personalization in Chat via Location, Username (#644)
* Add location metadata to chat history * Add support for custom configuration of the user name * Add region, country, city in the desktop app's URL for context in chat * Update prompts to specify user location, rather than just location. * Add location data to Obsidian chat query * Use first word for first name, last word for last name when setting profile name
This commit is contained in:
parent
a3eb17b7d4
commit
32ec54172e
18 changed files with 286 additions and 22 deletions
|
@ -31,6 +31,22 @@
|
|||
});
|
||||
}
|
||||
|
||||
let region = null;
|
||||
let city = null;
|
||||
let countryName = null;
|
||||
|
||||
fetch("https://ipapi.co/json")
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
region = data.region;
|
||||
city = data.city;
|
||||
countryName = data.country_name;
|
||||
})
|
||||
.catch(err => {
|
||||
console.log(err);
|
||||
return;
|
||||
});
|
||||
|
||||
function formatDate(date) {
|
||||
// Format date in HH:MM, DD MMM YYYY format
|
||||
let time_string = date.toLocaleTimeString('en-IN', { hour: '2-digit', minute: '2-digit', hour12: false });
|
||||
|
@ -337,7 +353,7 @@
|
|||
|
||||
|
||||
// Generate backend API URL to execute query
|
||||
let url = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}`;
|
||||
let url = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}®ion=${region}&city=${city}&country=${countryName}`;
|
||||
const khojToken = await window.tokenAPI.getToken();
|
||||
const headers = { 'Authorization': `Bearer ${khojToken}` };
|
||||
|
||||
|
|
|
@ -10,6 +10,9 @@ export interface ChatJsonResult {
|
|||
export class KhojChatModal extends Modal {
|
||||
result: string;
|
||||
setting: KhojSetting;
|
||||
region: string;
|
||||
city: string;
|
||||
countryName: string;
|
||||
|
||||
constructor(app: App, setting: KhojSetting) {
|
||||
super(app);
|
||||
|
@ -17,6 +20,19 @@ export class KhojChatModal extends Modal {
|
|||
|
||||
// Register Modal Keybindings to send user message
|
||||
this.scope.register([], 'Enter', async () => { await this.chat() });
|
||||
|
||||
|
||||
fetch("https://ipapi.co/json")
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
this.region = data.region;
|
||||
this.city = data.city;
|
||||
this.countryName = data.country_name;
|
||||
})
|
||||
.catch(err => {
|
||||
console.log(err);
|
||||
return;
|
||||
});
|
||||
}
|
||||
|
||||
async chat() {
|
||||
|
@ -354,7 +370,7 @@ export class KhojChatModal extends Modal {
|
|||
|
||||
// Get chat response from Khoj backend
|
||||
let encodedQuery = encodeURIComponent(query);
|
||||
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian&stream=true`;
|
||||
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian&stream=true®ion=${this.region}&city=${this.city}&country=${this.countryName}`;
|
||||
let responseElement = this.createKhojResponseDiv();
|
||||
|
||||
// Temporary status message to indicate that Khoj is thinking
|
||||
|
|
|
@ -173,6 +173,24 @@ async def create_user_by_google_token(token: dict) -> KhojUser:
|
|||
return user
|
||||
|
||||
|
||||
def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
|
||||
user.first_name = first_name
|
||||
user.last_name = last_name
|
||||
user.save()
|
||||
return user
|
||||
|
||||
|
||||
def get_user_name(user: KhojUser):
|
||||
full_name = user.get_full_name()
|
||||
if not is_none_or_empty(full_name):
|
||||
return full_name
|
||||
google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
|
||||
if google_profile:
|
||||
return google_profile.given_name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_user_subscription(email: str) -> Optional[Subscription]:
|
||||
return Subscription.objects.filter(user__email=email).first()
|
||||
|
||||
|
@ -291,6 +309,17 @@ def delete_user_requests(window: timedelta = timedelta(days=1)):
|
|||
return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
|
||||
|
||||
|
||||
async def aget_user_name(user: KhojUser):
|
||||
full_name = user.get_full_name()
|
||||
if not is_none_or_empty(full_name):
|
||||
return full_name
|
||||
google_profile: GoogleUser = await GoogleUser.objects.filter(user=user).afirst()
|
||||
if google_profile:
|
||||
return google_profile.given_name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
|
||||
deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None
|
||||
deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None
|
||||
|
|
4
src/khoj/interface/web/assets/icons/user-silhouette.svg
Normal file
4
src/khoj/interface/web/assets/icons/user-silhouette.svg
Normal file
|
@ -0,0 +1,4 @@
|
|||
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||
<svg width="800px" height="800px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M3.95442 10.166C4.04608 9.76202 3.79293 9.36025 3.38898 9.26859C2.98504 9.17693 2.58327 9.43009 2.49161 9.83403L3.95442 10.166ZM5.49981 4.73283C5.19117 5.00907 5.1649 5.48322 5.44115 5.79187C5.71739 6.10051 6.19154 6.12678 6.50019 5.85053L5.49981 4.73283ZM15 14.25C14.5858 14.25 14.25 14.5858 14.25 15C14.25 15.4142 14.5858 15.75 15 15.75L15 14.25ZM17.25 18.7083C17.25 19.1225 17.5858 19.4583 18 19.4583C18.4142 19.4583 18.75 19.1225 18.75 18.7083H17.25ZM5.25 18.7083C5.25 19.1225 5.58579 19.4583 6 19.4583C6.41421 19.4583 6.75 19.1225 6.75 18.7083H5.25ZM9 15L8.99998 15.75H9V15ZM11 15.75C11.4142 15.75 11.75 15.4142 11.75 15C11.75 14.5858 11.4142 14.25 11 14.25V15.75ZM12 3.75C16.5563 3.75 20.25 7.44365 20.25 12H21.75C21.75 6.61522 17.3848 2.25 12 2.25V3.75ZM12 20.25C7.44365 20.25 3.75 16.5563 3.75 12H2.25C2.25 17.3848 6.61522 21.75 12 21.75V20.25ZM20.25 12C20.25 16.5563 16.5563 20.25 12 20.25V21.75C17.3848 21.75 21.75 17.3848 21.75 12H20.25ZM3.75 12C3.75 11.3688 3.82074 10.7551 3.95442 10.166L2.49161 9.83403C2.33338 10.5313 2.25 11.2564 2.25 12H3.75ZM6.50019 5.85053C7.96026 4.54373 9.88655 3.75 12 3.75V2.25C9.50333 2.25 7.22428 3.1894 5.49981 4.73283L6.50019 5.85053ZM14.25 9C14.25 10.2426 13.2426 11.25 12 11.25V12.75C14.0711 12.75 15.75 11.0711 15.75 9H14.25ZM12 11.25C10.7574 11.25 9.75 10.2426 9.75 9H8.25C8.25 11.0711 9.92893 12.75 12 12.75V11.25ZM9.75 9C9.75 7.75736 10.7574 6.75 12 6.75V5.25C9.92893 5.25 8.25 6.92893 8.25 9H9.75ZM12 6.75C13.2426 6.75 14.25 7.75736 14.25 9H15.75C15.75 6.92893 14.0711 5.25 12 5.25V6.75ZM15 15.75C15.6008 15.75 16.1482 16.0891 16.5769 16.6848C17.0089 17.2852 17.25 18.0598 17.25 18.7083H18.75C18.75 17.7371 18.4052 16.6575 17.7944 15.8086C17.1801 14.9551 16.2275 14.25 15 14.25L15 15.75ZM6.75 18.7083C6.75 18.0598 6.99109 17.2852 7.42315 16.6848C7.85183 16.0891 8.39919 15.75 8.99998 15.75L9.00002 14.25C7.77253 14.25 6.81989 14.9551 6.20564 15.8086C5.59477 16.6575 5.25 17.7371 5.25 18.7083H6.75ZM9 15.75H11V14.25H9V15.75Z" fill="#000000"/>
|
||||
</svg>
|
After Width: | Height: | Size: 2.2 KiB |
|
@ -103,7 +103,7 @@
|
|||
|
||||
.section-title {
|
||||
margin: 0;
|
||||
padding: 0 0 16px 0;
|
||||
padding: 12px 0 16px 0;
|
||||
font-size: 32;
|
||||
font-weight: normal;
|
||||
}
|
||||
|
@ -361,6 +361,11 @@
|
|||
margin-right: 8px;
|
||||
}
|
||||
|
||||
input#profile_given_name {
|
||||
width: 100%;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
@media screen and (max-width: 700px) {
|
||||
.section-cards {
|
||||
grid-template-columns: 1fr;
|
||||
|
|
|
@ -43,6 +43,22 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
});
|
||||
}
|
||||
|
||||
let region = null;
|
||||
let city = null;
|
||||
let countryName = null;
|
||||
|
||||
fetch("https://ipapi.co/json")
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
region = data.region;
|
||||
city = data.city;
|
||||
countryName = data.country_name;
|
||||
})
|
||||
.catch(err => {
|
||||
console.log(err);
|
||||
return;
|
||||
});
|
||||
|
||||
function formatDate(date) {
|
||||
// Format date in HH:MM, DD MMM YYYY format
|
||||
let time_string = date.toLocaleTimeString('en-IN', { hour: '2-digit', minute: '2-digit', hour12: false });
|
||||
|
@ -345,7 +361,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
}
|
||||
|
||||
// Generate backend API URL to execute query
|
||||
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}`;
|
||||
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}®ion=${region}&city=${city}&country=${countryName}`;
|
||||
|
||||
let new_response = document.createElement("div");
|
||||
new_response.classList.add("chat-message", "khoj");
|
||||
|
|
|
@ -3,6 +3,25 @@
|
|||
|
||||
<div class="page">
|
||||
<div id="content" class="section">
|
||||
<h2 class="section-title">Profile</h2>
|
||||
<div class="section-cards">
|
||||
<div class="card">
|
||||
<div class="card-title-row">
|
||||
<img class="card-icon" src="/static/assets/icons/user-silhouette.svg" alt="Profile Name">
|
||||
<h3 class="card-title">
|
||||
Name
|
||||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<input type="text" id="profile_given_name" class="form-control" placeholder="Enter your name here" value="{{ given_name }}">
|
||||
</div>
|
||||
<div class="card-action-row">
|
||||
<button id="save-model" class="card-button happy" onclick="saveProfileGivenName()">
|
||||
Save
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<h2 class="section-title">Content</h2>
|
||||
<button id="compute-index-size" class="card-button" onclick="getIndexedDataSize()">
|
||||
Data Usage
|
||||
|
@ -297,6 +316,27 @@
|
|||
</div>
|
||||
<script>
|
||||
|
||||
function saveProfileGivenName() {
|
||||
const givenName = document.getElementById("profile_given_name").value;
|
||||
fetch('/api/config/user/name?name=' + givenName, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.status == "ok") {
|
||||
let notificationBanner = document.getElementById("notification-banner");
|
||||
notificationBanner.innerHTML = "Profile name has been updated!";
|
||||
notificationBanner.style.display = "block";
|
||||
setTimeout(function() {
|
||||
notificationBanner.style.display = "none";
|
||||
}, 5000);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
function updateChatModel() {
|
||||
const chatModel = document.getElementById("chat-models").value;
|
||||
const saveModelButton = document.getElementById("save-model");
|
||||
|
@ -314,6 +354,14 @@
|
|||
if (data.status == "ok") {
|
||||
saveModelButton.innerHTML = "Save";
|
||||
saveModelButton.disabled = false;
|
||||
|
||||
let notificationBanner = document.getElementById("notification-banner");
|
||||
notificationBanner.innerHTML = "Conversation model has been updated!";
|
||||
notificationBanner.style.display = "block";
|
||||
setTimeout(function() {
|
||||
notificationBanner.style.display = "none";
|
||||
}, 5000);
|
||||
|
||||
} else {
|
||||
saveModelButton.innerHTML = "Error";
|
||||
saveModelButton.disabled = false;
|
||||
|
|
|
@ -13,6 +13,7 @@ from khoj.processor.conversation.utils import (
|
|||
from khoj.utils import state
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -24,6 +25,7 @@ def extract_questions_offline(
|
|||
conversation_log={},
|
||||
use_history: bool = True,
|
||||
should_extract_questions: bool = True,
|
||||
location_data: LocationData = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
|
@ -45,6 +47,8 @@ def extract_questions_offline(
|
|||
|
||||
gpt4all_model = loaded_model or GPT4All(model)
|
||||
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = ""
|
||||
|
||||
|
@ -68,6 +72,7 @@ def extract_questions_offline(
|
|||
last_year=last_year,
|
||||
last_christmas_date=last_christmas_date,
|
||||
next_christmas_date=next_christmas_date,
|
||||
location=location,
|
||||
)
|
||||
message = system_prompt + example_questions
|
||||
state.chat_lock.acquire()
|
||||
|
@ -133,6 +138,8 @@ def converse_offline(
|
|||
conversation_commands=[ConversationCommand.Default],
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||
"""
|
||||
Converse with user using Llama
|
||||
|
@ -150,6 +157,15 @@ def converse_offline(
|
|||
|
||||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
||||
|
||||
if location_data:
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
||||
location_prompt = prompts.user_location.format(location=location)
|
||||
conversation_primer = f"{location_prompt}\n{conversation_primer}"
|
||||
|
||||
if user_name:
|
||||
user_name_prompt = prompts.user_name.format(name=user_name)
|
||||
conversation_primer = f"{user_name_prompt}\n{conversation_primer}"
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references_message):
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
@ -13,6 +12,7 @@ from khoj.processor.conversation.openai.utils import (
|
|||
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -24,6 +24,7 @@ def extract_questions(
|
|||
api_key=None,
|
||||
temperature=0,
|
||||
max_tokens=100,
|
||||
location_data: LocationData = None,
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
|
@ -32,6 +33,8 @@ def extract_questions(
|
|||
def _valid_question(question: str):
|
||||
return not is_none_or_empty(question) and question != "[]"
|
||||
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = "".join(
|
||||
[
|
||||
|
@ -56,6 +59,7 @@ def extract_questions(
|
|||
chat_history=chat_history,
|
||||
text=text,
|
||||
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
|
||||
location=location,
|
||||
)
|
||||
messages = [ChatMessage(content=prompt, role="assistant")]
|
||||
|
||||
|
@ -125,6 +129,8 @@ def converse(
|
|||
conversation_commands=[ConversationCommand.Default],
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
):
|
||||
"""
|
||||
Converse with user using OpenAI's ChatGPT
|
||||
|
@ -135,6 +141,15 @@ def converse(
|
|||
|
||||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
||||
|
||||
if location_data:
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
||||
location_prompt = prompts.user_location.format(location=location)
|
||||
conversation_primer = f"{location_prompt}\n{conversation_primer}"
|
||||
|
||||
if user_name:
|
||||
user_name_prompt = prompts.user_name.format(name=user_name)
|
||||
conversation_primer = f"{user_name_prompt}\n{conversation_primer}"
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
|
||||
completion_func(chat_response=prompts.no_notes_found.format())
|
||||
|
|
|
@ -122,6 +122,9 @@ image_generation_improve_prompt = PromptTemplate.from_template(
|
|||
"""
|
||||
You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information from the query. Use the conversation log to inform your response.
|
||||
|
||||
Today's Date: {current_date}
|
||||
User's Location: {location}
|
||||
|
||||
Conversation Log:
|
||||
{chat_history}
|
||||
|
||||
|
@ -181,7 +184,7 @@ Answer (in second person):"""
|
|||
## --
|
||||
extract_questions_gpt4all_sample = PromptTemplate.from_template(
|
||||
"""
|
||||
<s>[INST] <<SYS>>Current Date: {current_date}<</SYS>> [/INST]</s>
|
||||
<s>[INST] <<SYS>>Current Date: {current_date}. User's Location: {location}<</SYS>> [/INST]</s>
|
||||
<s>[INST] How was my trip to Cambodia? [/INST]
|
||||
How was my trip to Cambodia?</s>
|
||||
<s>[INST] Who did I visit the temple with on that trip? [/INST]
|
||||
|
@ -215,6 +218,7 @@ You are Khoj, an extremely smart and helpful search assistant with the ability t
|
|||
What searches, if any, will you need to perform to answer the users question?
|
||||
Provide search queries as a JSON list of strings
|
||||
Current Date: {current_date}
|
||||
User's Location: {location}
|
||||
|
||||
Q: How was my trip to Cambodia?
|
||||
|
||||
|
@ -355,6 +359,7 @@ You are Khoj, an extremely smart and helpful search assistant. You are tasked wi
|
|||
What Google searches, if any, will you need to perform to answer the user's question?
|
||||
Provide search queries as a list of strings
|
||||
Current Date: {current_date}
|
||||
User's Location: {location}
|
||||
|
||||
Here are some examples:
|
||||
History:
|
||||
|
@ -431,3 +436,17 @@ You are using the **{model}** model on the **{device}**.
|
|||
**version**: {version}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
# Personalization to the user
|
||||
# --
|
||||
user_location = PromptTemplate.from_template(
|
||||
"""
|
||||
User's Location: {location}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
user_name = PromptTemplate.from_template(
|
||||
"""
|
||||
User's Name: {name}
|
||||
""".strip()
|
||||
)
|
||||
|
|
|
@ -7,6 +7,7 @@ import requests
|
|||
|
||||
from khoj.routers.helpers import extract_relevant_info, generate_online_subqueries
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -33,7 +34,7 @@ OLOSTEP_QUERY_PARAMS = {
|
|||
}
|
||||
|
||||
|
||||
async def search_with_google(query: str, conversation_history: dict):
|
||||
async def search_with_google(query: str, conversation_history: dict, location: LocationData):
|
||||
def _search_with_google(subquery: str):
|
||||
payload = json.dumps(
|
||||
{
|
||||
|
@ -62,7 +63,7 @@ async def search_with_google(query: str, conversation_history: dict):
|
|||
raise ValueError("SERPER_DEV_API_KEY is not set")
|
||||
|
||||
# Breakdown the query into subqueries to get the correct answer
|
||||
subqueries = await generate_online_subqueries(query, conversation_history)
|
||||
subqueries = await generate_online_subqueries(query, conversation_history, location)
|
||||
|
||||
response_dict = {}
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ from khoj.search_type import image_search, text_search
|
|||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import GPT4AllProcessorModel
|
||||
from khoj.utils.helpers import ConversationCommand, timer
|
||||
from khoj.utils.rawconfig import SearchResponse
|
||||
from khoj.utils.rawconfig import LocationData, SearchResponse
|
||||
from khoj.utils.state import SearchType
|
||||
|
||||
# Initialize Router
|
||||
|
@ -275,6 +275,7 @@ async def extract_references_and_questions(
|
|||
n: int,
|
||||
d: float,
|
||||
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
||||
location_data: LocationData = None,
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
|
||||
|
@ -320,7 +321,11 @@ async def extract_references_and_questions(
|
|||
loaded_model = state.gpt4all_processor_config.loaded_model
|
||||
|
||||
inferred_queries = extract_questions_offline(
|
||||
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
||||
defiltered_query,
|
||||
loaded_model=loaded_model,
|
||||
conversation_log=meta_log,
|
||||
should_extract_questions=False,
|
||||
location_data=location_data,
|
||||
)
|
||||
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||
|
@ -328,7 +333,11 @@ async def extract_references_and_questions(
|
|||
api_key = openai_chat_config.api_key
|
||||
chat_model = default_openai_llm.chat_model
|
||||
inferred_queries = extract_questions(
|
||||
defiltered_query, model=chat_model, api_key=api_key, conversation_log=meta_log
|
||||
defiltered_query,
|
||||
model=chat_model,
|
||||
api_key=api_key,
|
||||
conversation_log=meta_log,
|
||||
location_data=location_data,
|
||||
)
|
||||
|
||||
# Collate search results as context for GPT
|
||||
|
|
|
@ -10,7 +10,7 @@ from fastapi.requests import Request
|
|||
from fastapi.responses import Response, StreamingResponse
|
||||
from starlette.authentication import requires
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||
from khoj.processor.conversation.utils import save_to_conversation_log
|
||||
|
@ -36,6 +36,7 @@ from khoj.utils.helpers import (
|
|||
get_device,
|
||||
is_none_or_empty,
|
||||
)
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
# Initialize Router
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -219,6 +220,9 @@ async def chat(
|
|||
stream: Optional[bool] = False,
|
||||
slug: Optional[str] = None,
|
||||
conversation_id: Optional[int] = None,
|
||||
city: Optional[str] = None,
|
||||
region: Optional[str] = None,
|
||||
country: Optional[str] = None,
|
||||
rate_limiter_per_minute=Depends(
|
||||
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
|
||||
),
|
||||
|
@ -251,8 +255,15 @@ async def chat(
|
|||
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
||||
q = q.replace(f"/{cmd.value}", "").strip()
|
||||
|
||||
location = None
|
||||
|
||||
if city or region or country:
|
||||
location = LocationData(city=city, region=region, country=country)
|
||||
|
||||
user_name = await aget_user_name(user)
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_commands
|
||||
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_commands, location
|
||||
)
|
||||
online_results: Dict = dict()
|
||||
|
||||
|
@ -269,7 +280,7 @@ async def chat(
|
|||
|
||||
if ConversationCommand.Online in conversation_commands:
|
||||
try:
|
||||
online_results = await search_with_google(defiltered_query, meta_log)
|
||||
online_results = await search_with_google(defiltered_query, meta_log, location)
|
||||
except ValueError as e:
|
||||
return StreamingResponse(
|
||||
iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]),
|
||||
|
@ -284,7 +295,7 @@ async def chat(
|
|||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
**common.__dict__,
|
||||
)
|
||||
image, status_code, improved_image_prompt = await text_to_image(q, meta_log)
|
||||
image, status_code, improved_image_prompt = await text_to_image(q, meta_log, location_data=location)
|
||||
if image is None:
|
||||
content_obj = {
|
||||
"image": image,
|
||||
|
@ -316,6 +327,8 @@ async def chat(
|
|||
user,
|
||||
request.user.client_app,
|
||||
conversation_id,
|
||||
location,
|
||||
user_name,
|
||||
)
|
||||
|
||||
chat_metadata.update({"conversation_command": ",".join([cmd.value for cmd in conversation_commands])})
|
||||
|
|
|
@ -290,6 +290,38 @@ async def get_indexed_data_size(request: Request, common: CommonQueryParams):
|
|||
)
|
||||
|
||||
|
||||
@api_config.post("/user/name", status_code=200)
|
||||
@requires(["authenticated"])
|
||||
def set_user_name(
|
||||
request: Request,
|
||||
name: str,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
user = request.user.object
|
||||
|
||||
split_name = name.split(" ")
|
||||
|
||||
if len(split_name) > 2:
|
||||
raise HTTPException(status_code=400, detail="Name must be in the format: Firstname Lastname")
|
||||
|
||||
if len(split_name) == 1:
|
||||
first_name = split_name[0]
|
||||
last_name = ""
|
||||
else:
|
||||
first_name, last_name = split_name[0], split_name[-1]
|
||||
|
||||
adapters.set_user_name(user, first_name, last_name)
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="set_user_name",
|
||||
client=client,
|
||||
)
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@api_config.get("/types", response_model=List[str])
|
||||
@requires(["authenticated"])
|
||||
def get_config_types(
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from functools import partial
|
||||
|
@ -9,6 +8,7 @@ from time import time
|
|||
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import openai
|
||||
import requests
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
from starlette.authentication import has_required_scope
|
||||
|
||||
|
@ -39,6 +39,7 @@ from khoj.utils.helpers import (
|
|||
log_telemetry,
|
||||
tool_descriptions_for_llm,
|
||||
)
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -180,10 +181,11 @@ async def aget_relevant_information_sources(query: str, conversation_history: di
|
|||
return [ConversationCommand.Default]
|
||||
|
||||
|
||||
async def generate_online_subqueries(q: str, conversation_history: dict) -> List[str]:
|
||||
async def generate_online_subqueries(q: str, conversation_history: dict, location_data: LocationData) -> List[str]:
|
||||
"""
|
||||
Generate subqueries from the given query
|
||||
"""
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
|
@ -191,6 +193,7 @@ async def generate_online_subqueries(q: str, conversation_history: dict) -> List
|
|||
current_date=utc_date,
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
location=location,
|
||||
)
|
||||
|
||||
response = await send_message_to_model_wrapper(online_queries_prompt)
|
||||
|
@ -227,14 +230,19 @@ async def extract_relevant_info(q: str, corpus: dict) -> List[str]:
|
|||
return response.strip()
|
||||
|
||||
|
||||
async def generate_better_image_prompt(q: str, conversation_history: str) -> str:
|
||||
async def generate_better_image_prompt(q: str, conversation_history: str, location_data: LocationData) -> str:
|
||||
"""
|
||||
Generate a better image prompt from the given query
|
||||
"""
|
||||
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
||||
|
||||
image_prompt = prompts.image_generation_improve_prompt.format(
|
||||
query=q,
|
||||
chat_history=conversation_history,
|
||||
location=location,
|
||||
current_date=today_date,
|
||||
)
|
||||
|
||||
response = await send_message_to_model_wrapper(image_prompt)
|
||||
|
@ -293,6 +301,8 @@ def generate_chat_response(
|
|||
user: KhojUser = None,
|
||||
client_application: ClientApplication = None,
|
||||
conversation_id: int = None,
|
||||
location_data: LocationData = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||
# Initialize Variables
|
||||
chat_response = None
|
||||
|
@ -330,6 +340,8 @@ def generate_chat_response(
|
|||
model=conversation_config.chat_model,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
tokenizer_name=conversation_config.tokenizer,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == "openai":
|
||||
|
@ -347,6 +359,8 @@ def generate_chat_response(
|
|||
conversation_commands=conversation_commands,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
tokenizer_name=conversation_config.tokenizer,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
)
|
||||
|
||||
metadata.update({"chat_model": conversation_config.chat_model})
|
||||
|
@ -358,7 +372,9 @@ def generate_chat_response(
|
|||
return chat_response, metadata
|
||||
|
||||
|
||||
async def text_to_image(message: str, conversation_log: dict) -> Tuple[Optional[str], int, Optional[str]]:
|
||||
async def text_to_image(
|
||||
message: str, conversation_log: dict, location_data: LocationData
|
||||
) -> Tuple[Optional[str], int, Optional[str]]:
|
||||
status_code = 200
|
||||
image = None
|
||||
|
||||
|
@ -373,7 +389,7 @@ async def text_to_image(message: str, conversation_log: dict) -> Tuple[Optional[
|
|||
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
|
||||
chat_history += f"Q: {chat['intent']['query']}\n"
|
||||
chat_history += f"A: {chat['message']}\n"
|
||||
improved_image_prompt = await generate_better_image_prompt(message, chat_history)
|
||||
improved_image_prompt = await generate_better_image_prompt(message, chat_history, location_data=location_data)
|
||||
try:
|
||||
response = state.openai_client.images.generate(
|
||||
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
|
||||
|
|
|
@ -13,6 +13,7 @@ from khoj.database.adapters import (
|
|||
ConversationAdapters,
|
||||
EntryAdapters,
|
||||
get_user_github_config,
|
||||
get_user_name,
|
||||
get_user_notion_config,
|
||||
get_user_subscription_state,
|
||||
)
|
||||
|
@ -138,6 +139,7 @@ def config_page(request: Request):
|
|||
if user_subscription and user_subscription.renewal_date
|
||||
else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y")
|
||||
)
|
||||
given_name = get_user_name(user)
|
||||
|
||||
enabled_content_source = set(EntryAdapters.get_unique_file_sources(user))
|
||||
successfully_configured = {
|
||||
|
@ -166,6 +168,7 @@ def config_page(request: Request):
|
|||
"current_model_state": successfully_configured,
|
||||
"anonymous_mode": state.anonymous_mode,
|
||||
"username": user.username,
|
||||
"given_name": given_name,
|
||||
"conversation_options": all_conversation_options,
|
||||
"search_model_options": all_search_model_options,
|
||||
"selected_search_model_config": current_search_model_option.id,
|
||||
|
|
|
@ -283,8 +283,8 @@ command_descriptions = {
|
|||
}
|
||||
|
||||
tool_descriptions_for_llm = {
|
||||
ConversationCommand.Default: "Use this if there might be a mix of general and personal knowledge in the question",
|
||||
ConversationCommand.General: "Use this when you can answer the question without needing any additional online or personal information",
|
||||
ConversationCommand.Default: "Use this if there might be a mix of general and personal knowledge in the question, or if you can't make sense of the query",
|
||||
ConversationCommand.General: "Use this when you can answer the question without any outside information or personal knowledge",
|
||||
ConversationCommand.Notes: "Use this when you would like to use the user's personal knowledge base to answer the question",
|
||||
ConversationCommand.Online: "Use this when you would like to look up information on the internet",
|
||||
}
|
||||
|
|
|
@ -21,6 +21,12 @@ class ConfigBase(BaseModel):
|
|||
return setattr(self, key, value)
|
||||
|
||||
|
||||
class LocationData(BaseModel):
|
||||
city: Optional[str]
|
||||
region: Optional[str]
|
||||
country: Optional[str]
|
||||
|
||||
|
||||
class TextConfigBase(ConfigBase):
|
||||
compressed_jsonl: Path
|
||||
embeddings_file: Path
|
||||
|
|
Loading…
Reference in a new issue