Try to respond even if document search via inference endpoint fails

The huggingface endpoint can be flaky. Khoj shouldn't refuse to
respond to the user if document search fails.
It should transparently mention that document lookup failed,
but try to respond as best it can without the document references.

This change provides graceful failover when inference endpoint
requests fail, either when encoding the query or reranking retrieved docs.
This commit is contained in:
Debanjum Singh Solanky 2024-10-14 17:39:44 -07:00
parent 9affeb9e85
commit 1b04b801c6
3 changed files with 35 additions and 22 deletions

View file

@ -114,6 +114,7 @@ class CrossEncoderModel:
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
response = requests.post(target_url, json=payload, headers=headers)
response.raise_for_status()
return response.json()["scores"]
cross_inp = [[query, hit.additional[key]] for hit in hits]

View file

@ -3,7 +3,6 @@ import base64
import json
import logging
import time
import warnings
from datetime import datetime
from functools import partial
from typing import Dict, Optional
@ -839,25 +838,33 @@ async def chat(
# Gather Context
## Extract Document References
compiled_references, inferred_queries, defiltered_query = [], [], None
async for result in extract_references_and_questions(
request,
meta_log,
q,
(n or 7),
d,
conversation_id,
conversation_commands,
location,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
compiled_references.extend(result[0])
inferred_queries.extend(result[1])
defiltered_query = result[2]
try:
async for result in extract_references_and_questions(
request,
meta_log,
q,
(n or 7),
d,
conversation_id,
conversation_commands,
location,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
compiled_references.extend(result[0])
inferred_queries.extend(result[1])
defiltered_query = result[2]
except Exception as e:
error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
logger.warning(error_message)
async for result in send_event(
ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
):
yield result
if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))

View file

@ -3,6 +3,7 @@ import math
from pathlib import Path
from typing import List, Optional, Tuple, Type, Union
import requests
import torch
from asgiref.sync import sync_to_async
from sentence_transformers import util
@ -231,8 +232,12 @@ def setup(
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
"""Score all retrieved entries using the cross-encoder"""
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
try:
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
except requests.exceptions.HTTPError as e:
logger.error(f"Failed to rerank documents using the inference endpoint. Error: {e}.", exc_info=True)
cross_scores = [0.0] * len(hits)
# Convert cross-encoder scores to distances and pass in hits for reranking
for idx in range(len(cross_scores)):