Make Online Search Location Aware (#929)

## Overview
Add user country code as context for doing online search with serper.dev API.
This should find more user relevant results from online searches by Khoj

## Details
### Major
- Default to using system clock to infer user timezone on js clients
- Infer country from timezone when only timezone received by chat API
- Localize online search results to user country when location available

### Minor
- Add `__str__` func to `LocationData` class to deduplicate location string generation
This commit is contained in:
Debanjum 2024-10-03 12:33:47 -07:00 committed by GitHub
commit 4a1cb50da3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 95 additions and 53 deletions

View file

@ -60,7 +60,8 @@
let region = null;
let city = null;
let countryName = null;
let timezone = null;
let countryCode = null;
let timezone = Intl.DateTimeFormat().resolvedOptions().timeZone;
let chatMessageState = {
newResponseTextEl: null,
newResponseEl: null,
@ -76,6 +77,7 @@
region = data.region;
city = data.city;
countryName = data.country_name;
countryCode = data.country_code;
timezone = data.timezone;
})
.catch(err => {
@ -157,6 +159,7 @@
...(!!city && { city: city }),
...(!!region && { region: region }),
...(!!countryName && { country: countryName }),
...(!!countryCode && { country_code: countryCode }),
...(!!timezone && { timezone: timezone }),
};

View file

@ -308,18 +308,19 @@
<script src="./utils.js"></script>
<script src="./chatutils.js"></script>
<script>
let region = null;
let city = null;
let countryName = null;
let timezone = null;
let countryCode = null;
let timezone = Intl.DateTimeFormat().resolvedOptions().timeZone;
fetch("https://ipapi.co/json")
.then(response => response.json())
.then(data => {
region = data.region;
city = data.city;
region = data.region;
countryName = data.country_name;
countryCode = data.country_code;
timezone = data.timezone;
})
.catch(err => {
@ -410,6 +411,7 @@
...(!!city && { city: city }),
...(!!region && { region: region }),
...(!!countryName && { country: countryName }),
...(!!countryCode && { country_code: countryCode }),
...(!!timezone && { timezone: timezone }),
};

View file

@ -33,9 +33,10 @@ interface ChatMessageState {
}
interface Location {
region: string;
city: string;
countryName: string;
region?: string;
city?: string;
countryName?: string;
countryCode?: string;
timezone: string;
}
@ -43,7 +44,7 @@ export class KhojChatView extends KhojPaneView {
result: string;
setting: KhojSetting;
waitingForLocation: boolean;
location: Location;
location: Location = { timezone: Intl.DateTimeFormat().resolvedOptions().timeZone };
keyPressTimeout: NodeJS.Timeout | null = null;
userMessages: string[] = []; // Store user sent messages for input history cycling
currentMessageIndex: number = -1; // Track current message index in userMessages array
@ -70,6 +71,7 @@ export class KhojChatView extends KhojPaneView {
region: data.region,
city: data.city,
countryName: data.country_name,
countryCode: data.country_code,
timezone: data.timezone,
};
})
@ -1056,12 +1058,11 @@ export class KhojChatView extends KhojPaneView {
n: this.setting.resultsCount,
stream: true,
...(!!conversationId && { conversation_id: conversationId }),
...(!!this.location && {
city: this.location.city,
region: this.location.region,
country: this.location.countryName,
timezone: this.location.timezone,
}),
...(!!this.location && this.location.city && { city: this.location.city }),
...(!!this.location && this.location.region && { region: this.location.region }),
...(!!this.location && this.location.countryName && { country: this.location.countryName }),
...(!!this.location && this.location.countryCode && { country_code: this.location.countryCode }),
...(!!this.location && this.location.timezone && { timezone: this.location.timezone }),
};
let newResponseEl = this.createKhojResponseDiv();

View file

@ -518,12 +518,14 @@ function EditCard(props: EditCardProps) {
updateQueryUrl += `&subject=${encodeURIComponent(values.subject)}`;
}
updateQueryUrl += `&crontime=${encodeURIComponent(cronFrequency)}`;
if (props.locationData) {
if (props.locationData && props.locationData.city)
updateQueryUrl += `&city=${encodeURIComponent(props.locationData.city)}`;
if (props.locationData && props.locationData.region)
updateQueryUrl += `&region=${encodeURIComponent(props.locationData.region)}`;
if (props.locationData && props.locationData.country)
updateQueryUrl += `&country=${encodeURIComponent(props.locationData.country)}`;
if (props.locationData && props.locationData.timezone)
updateQueryUrl += `&timezone=${encodeURIComponent(props.locationData.timezone)}`;
}
let method = props.createNew ? "POST" : "PUT";

View file

@ -136,7 +136,9 @@ export default function Chat() {
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
const [image64, setImage64] = useState<string>("");
const locationData = useIPLocationData();
const locationData = useIPLocationData() || {
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
};
const authenticatedData = useAuthenticatedData();
const isMobileWidth = useIsMobileWidth();
@ -241,9 +243,10 @@ export default function Chat() {
conversation_id: conversationId,
stream: true,
...(locationData && {
city: locationData.city,
region: locationData.region,
country: locationData.country,
city: locationData.city,
country_code: locationData.countryCode,
timezone: locationData.timezone,
}),
...(image64 && { image: image64 }),

View file

@ -2,13 +2,10 @@ import { useEffect, useState } from "react";
import useSWR from "swr";
export interface LocationData {
ip: string;
city: string;
region: string;
country: string;
postal: string;
latitude: number;
longitude: number;
city?: string;
region?: string;
country?: string;
countryCode?: string;
timezone: string;
}
@ -50,9 +47,7 @@ export function useIPLocationData() {
{ revalidateOnFocus: false },
);
if (locationDataError) return null;
if (!locationData) return null;
if (locationDataError || !locationData) return;
return locationData;
}

View file

@ -111,7 +111,9 @@ export default function SharedChat() {
const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
const [image64, setImage64] = useState<string>("");
const locationData = useIPLocationData();
const locationData = useIPLocationData() || {
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
};
const authenticatedData = useAuthenticatedData();
const isMobileWidth = useIsMobileWidth();
@ -231,6 +233,7 @@ export default function SharedChat() {
region: locationData.region,
country: locationData.country,
city: locationData.city,
country_code: locationData.countryCode,
timezone: locationData.timezone,
}),
...(image64 && { image: image64 }),

View file

@ -32,7 +32,7 @@ def extract_questions_anthropic(
Infer search queries to retrieve relevant notes to answer user query
"""
# Extract Past User Message and Inferred Questions from Conversation Log
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
@ -158,8 +158,7 @@ def converse_anthropic(
)
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
location_prompt = prompts.user_location.format(location=f"{location_data}")
system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name:

View file

@ -33,7 +33,7 @@ def extract_questions_gemini(
Infer search queries to retrieve relevant notes to answer user query
"""
# Extract Past User Message and Inferred Questions from Conversation Log
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
@ -163,8 +163,7 @@ def converse_gemini(
)
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
location_prompt = prompts.user_location.format(location=f"{location_data}")
system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name:

View file

@ -46,7 +46,7 @@ def extract_questions_offline(
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
@ -171,8 +171,7 @@ def converse_offline(
conversation_primer = prompts.query_prompt.format(query=user_query)
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
location_prompt = prompts.user_location.format(location=f"{location_data}")
system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name:

View file

@ -36,7 +36,7 @@ def extract_questions(
"""
Infer search queries to retrieve relevant notes to answer user query
"""
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
@ -159,8 +159,7 @@ def converse(
)
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
location_prompt = prompts.user_location.format(location=f"{location_data}")
system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name:

View file

@ -7,7 +7,6 @@ from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple, Union
import aiohttp
import requests
from bs4 import BeautifulSoup
from markdownify import markdownify
@ -80,7 +79,7 @@ async def search_online(
with timer(f"Internet searches for {list(subqueries)} took", logger):
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
search_tasks = [search_func(subquery) for subquery in subqueries]
search_tasks = [search_func(subquery, location) for subquery in subqueries]
search_results = await asyncio.gather(*search_tasks)
response_dict = {subquery: search_result for subquery, search_result in search_results}
@ -115,8 +114,9 @@ async def search_online(
yield response_dict
async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
payload = json.dumps({"q": query})
async def search_with_google(query: str, location: LocationData) -> Tuple[str, Dict[str, List[Dict]]]:
country_code = location.country_code.lower() if location and location.country_code else "us"
payload = json.dumps({"q": query, "gl": country_code})
headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
async with aiohttp.ClientSession() as session:
@ -220,7 +220,7 @@ async def read_webpage_with_jina(web_url: str) -> str:
return response_json["data"]["content"]
async def search_with_jina(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
async def search_with_jina(query: str, location: LocationData) -> Tuple[str, Dict[str, List[Dict]]]:
encoded_query = urllib.parse.quote(query)
jina_search_api_url = f"{JINA_SEARCH_API_URL}/{encoded_query}"
headers = {"Accept": "application/json"}

View file

@ -55,6 +55,8 @@ from khoj.utils.helpers import (
ConversationCommand,
command_descriptions,
convert_image_to_webp,
get_country_code_from_timezone,
get_country_name_from_timezone,
get_device,
is_none_or_empty,
)
@ -529,6 +531,7 @@ class ChatRequestBody(BaseModel):
city: Optional[str] = None
region: Optional[str] = None
country: Optional[str] = None
country_code: Optional[str] = None
timezone: Optional[str] = None
image: Optional[str] = None
create_new: Optional[bool] = False
@ -556,7 +559,8 @@ async def chat(
conversation_id = body.conversation_id
city = body.city
region = body.region
country = body.country
country = body.country or get_country_name_from_timezone(body.timezone)
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
timezone = body.timezone
image = body.image
@ -658,8 +662,8 @@ async def chat(
user_name = await aget_user_name(user)
location = None
if city or region or country:
location = LocationData(city=city, region=region, country=country)
if city or region or country or country_code:
location = LocationData(city=city, region=region, country=country, country_code=country_code)
if is_query_empty(q):
async for result in send_llm_response("Please ask your query to get started."):

View file

@ -369,7 +369,7 @@ async def infer_webpage_urls(
"""
Infer webpage links from the given query
"""
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history)
@ -405,7 +405,7 @@ async def generate_online_subqueries(
"""
Generate subqueries from the given query
"""
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history)
@ -535,8 +535,7 @@ async def generate_better_image_prompt(
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
location_prompt = prompts.user_location.format(location=f"{location_data}")
else:
location_prompt = "Unknown"

View file

@ -9,6 +9,7 @@ import random
import uuid
from collections import OrderedDict
from enum import Enum
from functools import lru_cache
from importlib import import_module
from importlib.metadata import version
from itertools import islice
@ -24,6 +25,7 @@ import torch
from asgiref.sync import sync_to_async
from magika import Magika
from PIL import Image
from pytz import country_names, country_timezones
from khoj.utils import constants
@ -431,3 +433,24 @@ def convert_image_to_webp(image_bytes):
webp_image_bytes = webp_image_io.getvalue()
webp_image_io.close()
return webp_image_bytes
@lru_cache
def tz_to_cc_map() -> dict[str, str]:
"""Create a mapping of timezone to country code"""
timezone_country = {}
for countrycode in country_timezones:
timezones = country_timezones[countrycode]
for timezone in timezones:
timezone_country[timezone] = countrycode
return timezone_country
def get_country_code_from_timezone(tz: str) -> str:
"""Get country code from timezone"""
return tz_to_cc_map().get(tz, "US")
def get_country_name_from_timezone(tz: str) -> str:
"""Get country name from timezone"""
return country_names.get(get_country_code_from_timezone(tz), "United States")

View file

@ -25,6 +25,17 @@ class LocationData(BaseModel):
city: Optional[str]
region: Optional[str]
country: Optional[str]
country_code: Optional[str]
def __str__(self):
parts = []
if self.city:
parts.append(self.city)
if self.region:
parts.append(self.region)
if self.country:
parts.append(self.country)
return ", ".join(parts)
class FileFilterRequest(BaseModel):