Merge branch 'master' into features/advanced-reasoning

Debanjum Singh Solanky 2024-10-15 01:27:36 -07:00
commit feb6d65ef8
9 changed files with 693 additions and 515 deletions

File diff suppressed because it is too large

View file

@@ -696,10 +696,12 @@ class AgentAdapters:
files: List[str],
input_tools: List[str],
output_modes: List[str],
slug: Optional[str] = None,
):
chat_model_option = await ChatModelOptions.objects.filter(chat_model=chat_model).afirst()
agent, created = await Agent.objects.filter(name=name, creator=user).aupdate_or_create(
# Slug will be None for new agents, which will trigger a new agent creation with a generated, immutable slug
agent, created = await Agent.objects.filter(slug=slug, creator=user).aupdate_or_create(
defaults={
"name": name,
"creator": user,

View file

@@ -114,6 +114,7 @@ class CrossEncoderModel:
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
response = requests.post(target_url, json=payload, headers=headers)
response.raise_for_status()
return response.json()["scores"]
cross_inp = [[query, hit.additional[key]] for hit in hits]
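
The single added line, response.raise_for_status(), makes a failing remote inference endpoint surface as requests.exceptions.HTTPError rather than a confusing failure on the missing "scores" field, and that is exactly the exception the caller in cross_encoder_score (last hunk below) now catches. A minimal sketch of the resulting contract, reusing the payload shape from this hunk (the function name is invented for illustration):

import requests

def score_via_inference_endpoint(target_url: str, api_key: str, query: str, passages: list) -> list:
    payload = {"inputs": {"query": query, "passages": passages}}
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    response = requests.post(target_url, json=payload, headers=headers)
    response.raise_for_status()  # 4xx/5xx responses now raise requests.exceptions.HTTPError
    return response.json()["scores"]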

View file

@@ -143,7 +143,6 @@ async def read_webpages(
conversation_history: dict,
location: LocationData,
user: KhojUser,
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
uploaded_image_url: str = None,
agent: Agent = None,

View file

@@ -35,6 +35,7 @@ class ModifyAgentBody(BaseModel):
files: Optional[List[str]] = []
input_tools: Optional[List[str]] = []
output_modes: Optional[List[str]] = []
slug: Optional[str] = None
@api_agents.get("", response_class=Response)
@@ -192,6 +193,7 @@ async def create_agent(
body.files,
body.input_tools,
body.output_modes,
body.slug,
)
agents_packet = {
@@ -233,7 +235,7 @@ async def update_agent(
status_code=400,
)
selected_agent = await AgentAdapters.aget_agent_by_name(body.name, user)
selected_agent = await AgentAdapters.aget_agent_by_slug(body.slug, user)
if not selected_agent:
return Response(
@@ -253,6 +255,7 @@ async def update_agent(
body.files,
body.input_tools,
body.output_modes,
body.slug,
)
agents_packet = {
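
Both agent endpoints now thread the optional slug through to the adapter, and update_agent resolves its target with aget_agent_by_slug instead of aget_agent_by_name. The likely motivation (not stated in the diff) is that the name is user-editable: a rename request carries the new name, so a name-based lookup no longer identifies the agent being edited, while the immutable slug remains a stable handle. A standalone sketch of the request body as visible in this hunk (other fields on the real model are omitted):

from typing import List, Optional
from pydantic import BaseModel

class ModifyAgentBody(BaseModel):
    # Fields copied from the hunk above; remaining fields are elided.
    files: Optional[List[str]] = []
    input_tools: Optional[List[str]] = []
    output_modes: Optional[List[str]] = []
    slug: Optional[str] = None  # None triggers creation of a new agent with a generated slug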

View file

@@ -213,7 +213,7 @@ def chat_history(
agent_metadata = None
if conversation.agent:
if conversation.agent.privacy_level == Agent.PrivacyLevel.PRIVATE:
if conversation.agent.privacy_level == Agent.PrivacyLevel.PRIVATE and conversation.agent.creator != user:
conversation.agent = None
else:
agent_metadata = {
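
The added creator check narrows when an agent is scrubbed from the returned chat history: previously any PRIVATE agent was hidden from everyone, now it stays visible to its own creator. The rule, restated as a predicate (illustrative only; Agent is the model referenced in the hunk):

def agent_visible_in_history(agent, user) -> bool:
    # Hide the agent only when it is private AND the requester is not its creator.
    return agent.privacy_level != Agent.PrivacyLevel.PRIVATE or agent.creator == user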
@@ -853,27 +853,36 @@ async def chat(
return
# # Gather Context
# async for result in extract_references_and_questions(
# request,
# meta_log,
# q,
# (n or 7),
# d,
# conversation_id,
# conversation_commands,
# location,
# partial(send_event, ChatEvent.STATUS),
# uploaded_image_url=uploaded_image_url,
# agent=agent,
# ):
# if isinstance(result, dict) and ChatEvent.STATUS in result:
# yield result[ChatEvent.STATUS]
# else:
# compiled_references.extend(result[0])
# inferred_queries.extend(result[1])
# defiltered_query = result[2]
# if not is_none_or_empty(compiled_references):
# # Extract Document References
# try:
# async for result in extract_references_and_questions(
# request,
# meta_log,
# q,
# (n or 7),
# d,
# conversation_id,
# conversation_commands,
# location,
# partial(send_event, ChatEvent.STATUS),
# uploaded_image_url=uploaded_image_url,
# agent=agent,
# ):
# if isinstance(result, dict) and ChatEvent.STATUS in result:
# yield result[ChatEvent.STATUS]
# else:
# compiled_references.extend(result[0])
# inferred_queries.extend(result[1])
# defiltered_query = result[2]
# except Exception as e:
# error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
# logger.warning(error_message)
# async for result in send_event(
# ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
# ):
# yield result
#
# # if not is_none_or_empty(compiled_references):
# try:
# headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
# # Strip only leading # from headings
@@ -910,12 +919,13 @@ async def chat(
yield result[ChatEvent.STATUS]
else:
online_results = result
except ValueError as e:
except Exception as e:
error_message = f"Error searching online: {e}. Attempting to respond without online results"
logger.warning(error_message)
async for result in send_llm_response(error_message):
async for result in send_event(
ChatEvent.STATUS, "Online search failed. I'll try respond without online references"
):
yield result
return
## Gather Webpage References
if ConversationCommand.Webpage in conversation_commands and pending_research:
@@ -925,7 +935,6 @@ async def chat(
meta_log,
location,
user,
subscribed,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
@@ -945,11 +954,15 @@ async def chat(
webpages.append(webpage["link"])
async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
yield result
except ValueError as e:
except Exception as e:
logger.warning(
f"Error directly reading webpages: {e}. Attempting to respond without online results",
f"Error reading webpages: {e}. Attempting to respond without webpage results",
exc_info=True,
)
async for result in send_event(
ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references"
):
yield result
## Gather Code Results
if ConversationCommand.Code in conversation_commands and pending_research:
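
The online-search and webpage hunks above (and the commented-out document-search block) converge on one error-handling pattern: catch the broader Exception, log a warning, and emit a ChatEvent.STATUS event so the response keeps streaming without that source instead of ending the chat. A self-contained toy version of the pattern (logger, ChatEvent, and send_event below are stand-ins for the real objects in this module):

import asyncio
import logging
from enum import Enum

logger = logging.getLogger(__name__)

class ChatEvent(Enum):
    STATUS = "status"

async def send_event(event: ChatEvent, message: str):
    # Stand-in for the real status-event emitter.
    yield {event: message}

async def chat():
    try:
        raise RuntimeError("search backend unreachable")  # simulate a failed optional step
    except Exception as e:
        logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
        async for result in send_event(ChatEvent.STATUS, "Online search failed. I'll try respond without online references"):
            yield result
    yield "final answer, generated without online references"

async def main():
    async for event in chat():
        print(event)

asyncio.run(main())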

View file

@@ -345,13 +345,13 @@ async def aget_relevant_information_sources(
final_response = [ConversationCommand.Default]
else:
final_response = [ConversationCommand.General]
return final_response
except Exception as e:
except Exception:
logger.error(f"Invalid response for determining relevant tools: {response}")
if len(agent_tools) == 0:
final_response = [ConversationCommand.Default]
else:
final_response = agent_tools
return final_response
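
With the trailing return final_response placed after the except block, both the parsed tool selection and the exception fallback exit through a single point, and the unused exception binding is dropped since the log line only reports the raw response. A standalone sketch of the resulting shape (helper names are invented for illustration):

def pick_tools(parse_response, agent_tools: list) -> list:
    try:
        final_response = parse_response()  # may raise on an invalid model response
    except Exception:
        # Fall back to the agent's configured tools, or a default selection.
        final_response = agent_tools if agent_tools else ["default"]
    return final_response  # single exit for both paths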
async def aget_relevant_output_modes(

View file

@@ -227,7 +227,6 @@ async def execute_information_collection(
conversation_history,
location,
user,
subscribed,
send_status_func,
uploaded_image_url=uploaded_image_url,
agent=agent,

View file

@@ -3,6 +3,7 @@ import math
from pathlib import Path
from typing import List, Optional, Tuple, Type, Union
import requests
import torch
from asgiref.sync import sync_to_async
from sentence_transformers import util
@@ -231,8 +232,12 @@ def setup(
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
"""Score all retrieved entries using the cross-encoder"""
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
try:
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
except requests.exceptions.HTTPError as e:
logger.error(f"Failed to rerank documents using the inference endpoint. Error: {e}.", exc_info=True)
cross_scores = [0.0] * len(hits)
# Convert cross-encoder scores to distances and pass in hits for reranking
for idx in range(len(cross_scores)):
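
Wrapping the predict call means an inference-endpoint failure (the HTTPError introduced by raise_for_status above) degrades to neutral scores instead of an unhandled error. Because every fallback score is identical, the subsequent rerank should leave hits in their incoming bi-encoder order, assuming the downstream sort is stable (Python's sorted() is). A compact restatement of the fallback:

import requests

def safe_cross_scores(predict, query: str, hits: list) -> list:
    """Sketch: cross-encoder scores, or neutral scores if the remote endpoint fails."""
    try:
        return predict(query, hits)
    except requests.exceptions.HTTPError:
        return [0.0] * len(hits)  # identical scores: reranking becomes a no-op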