Merge branch 'master' of github.com:khoj-ai/khoj into features/include-full-file-in-convo-with-filter

This commit is contained in:
sabaimran 2024-11-10 19:24:11 -08:00
commit 8805e731fd
15 changed files with 222 additions and 989 deletions

View file

@ -78,6 +78,7 @@ If your plugin does not need CSS, delete this file.
user-select: text;
color: var(--text-normal);
background-color: var(--active-bg);
word-break: break-word;
}
/* color chat bubble by khoj blue */
.khoj-chat-message-text.khoj {

View file

@ -4,6 +4,7 @@ div.chatMessageContainer {
margin: 12px;
border-radius: 16px;
padding: 8px 16px 0 16px;
word-break: break-word;
}
div.chatMessageWrapper {
@ -170,6 +171,7 @@ div.trainOfThoughtElement {
div.trainOfThoughtElement ol,
div.trainOfThoughtElement ul {
margin: auto;
word-break: break-word;
}
@media screen and (max-width: 768px) {

View file

@ -1,172 +0,0 @@
input.factVerification {
width: 100%;
display: block;
padding: 12px 20px;
margin: 8px 0;
border: none;
box-sizing: border-box;
border-radius: 4px;
text-align: left;
margin: auto;
margin-top: 8px;
margin-bottom: 8px;
font-size: large;
}
div.factCheckerContainer {
width: 75vw;
margin: auto;
}
input.factVerification:focus {
outline: none;
box-shadow: 0 4px 8px 0 rgba(0, 0, 0, 0.2);
}
div.responseText {
margin: 0;
padding: 0;
border-radius: 8px;
}
div.response {
margin-bottom: 12px;
}
a.titleLink {
color: #333;
font-weight: bold;
}
a.subLinks {
color: #333;
text-decoration: none;
font-weight: small;
border-radius: 4px;
font-size: small;
}
div.subLinks {
display: flex;
flex-direction: row;
gap: 8px;
flex-wrap: wrap;
}
div.reference {
padding: 12px;
margin: 8px;
border-radius: 8px;
}
footer.footer {
width: 100%;
background: transparent;
text-align: left;
}
div.reportActions {
display: flex;
flex-direction: row;
gap: 8px;
justify-content: space-between;
margin-top: 8px;
}
button.factCheckButton {
border: none;
cursor: pointer;
width: 100%;
border-radius: 4px;
margin: 8px;
padding-left: 1rem;
padding-right: 1rem;
line-height: 1.25rem;
}
button.factCheckButton:hover {
box-shadow: 0 4px 8px 0 rgba(0, 0, 0, 0.2);
}
div.spinner {
margin: 20px;
width: 40px;
height: 40px;
position: relative;
text-align: center;
-webkit-animation: sk-rotate 2.0s infinite linear;
animation: sk-rotate 2.0s infinite linear;
}
div.inputFields {
width: 100%;
display: grid;
grid-template-columns: 1fr auto;
grid-gap: 8px;
}
/* Loading Animation */
div.dot1,
div.dot2 {
width: 60%;
height: 60%;
display: inline-block;
position: absolute;
top: 0;
border-radius: 100%;
-webkit-animation: sk-bounce 2.0s infinite ease-in-out;
animation: sk-bounce 2.0s infinite ease-in-out;
}
div.dot2 {
top: auto;
bottom: 0;
-webkit-animation-delay: -1.0s;
animation-delay: -1.0s;
}
@media screen and (max-width: 768px) {
div.factCheckerContainer {
width: 95vw;
}
}
@-webkit-keyframes sk-rotate {
100% {
-webkit-transform: rotate(360deg)
}
}
@keyframes sk-rotate {
100% {
transform: rotate(360deg);
-webkit-transform: rotate(360deg)
}
}
@-webkit-keyframes sk-bounce {
0%,
100% {
-webkit-transform: scale(0.0)
}
50% {
-webkit-transform: scale(1.0)
}
}
@keyframes sk-bounce {
0%,
100% {
transform: scale(0.0);
-webkit-transform: scale(0.0);
}
50% {
transform: scale(1.0);
-webkit-transform: scale(1.0);
}
}

View file

@ -1,33 +0,0 @@
import type { Metadata } from "next";
export const metadata: Metadata = {
title: "Khoj AI - Fact Checker",
description:
"Use the Fact Checker with Khoj AI for verifying statements. It can research the internet for you, either refuting or confirming the statement using fresh data.",
icons: {
icon: "/static/assets/icons/khoj_lantern.ico",
apple: "/static/assets/icons/khoj_lantern_256x256.png",
},
openGraph: {
siteName: "Khoj AI",
title: "Khoj AI - Fact Checker",
description: "Your Second Brain.",
url: "https://app.khoj.dev/factchecker",
type: "website",
images: [
{
url: "https://assets.khoj.dev/khoj_lantern_256x256.png",
width: 256,
height: 256,
},
],
},
};
export default function RootLayout({
children,
}: Readonly<{
children: React.ReactNode;
}>) {
return <div>{children}</div>;
}

View file

@ -1,676 +0,0 @@
"use client";
import styles from "./factChecker.module.css";
import { useAuthenticatedData } from "@/app/common/auth";
import { useState, useEffect } from "react";
import ChatMessage, {
CodeContext,
Context,
OnlineContext,
OnlineContextData,
WebPage,
} from "../components/chatMessage/chatMessage";
import { ModelPicker, Model } from "../components/modelPicker/modelPicker";
import ShareLink from "../components/shareLink/shareLink";
import { Input } from "@/components/ui/input";
import { Button } from "@/components/ui/button";
import { Card, CardContent, CardFooter, CardHeader, CardTitle } from "@/components/ui/card";
import Link from "next/link";
import SidePanel from "../components/sidePanel/chatHistorySidePanel";
import { useIsMobileWidth } from "../common/utils";
const chatURL = "/api/chat";
const verificationPrecursor =
"Limit your search to reputable sources. Search the internet for relevant supporting or refuting information. Do not reference my notes. Refuse to answer any queries that are not falsifiable by informing me that you will not answer the question. You're not permitted to ask follow-up questions, so do the best with what you have. Respond with **TRUE** or **FALSE** or **INCONCLUSIVE**, then provide your justification. Fact Check:";
const LoadingSpinner = () => (
<div className={styles.loading}>
<div className={styles.loadingVerification}>
Researching...
<div className={styles.spinner}>
<div className={`${styles.dot1} bg-blue-300`}></div>
<div className={`${styles.dot2} bg-blue-300`}></div>
</div>
</div>
</div>
);
interface SupplementReferences {
additionalLink: string;
response: string;
linkTitle: string;
}
interface ResponseWithReferences {
context?: Context[];
online?: OnlineContext;
code?: CodeContext;
response?: string;
}
function handleCompiledReferences(chunk: string, currentResponse: string) {
const rawReference = chunk.split("### compiled references:")[1];
const rawResponse = chunk.split("### compiled references:")[0];
let references: ResponseWithReferences = {};
// Set the initial response
references.response = currentResponse + rawResponse;
const rawReferenceAsJson = JSON.parse(rawReference);
if (rawReferenceAsJson instanceof Array) {
references.context = rawReferenceAsJson;
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
references.online = rawReferenceAsJson;
}
return references;
}
async function verifyStatement(
message: string,
conversationId: string,
setIsLoading: (loading: boolean) => void,
setInitialResponse: (response: string) => void,
setInitialReferences: (references: ResponseWithReferences) => void,
) {
setIsLoading(true);
// Construct the verification payload
let verificationMessage = `${verificationPrecursor} ${message}`;
const apiURL = `${chatURL}?client=web`;
const requestBody = {
q: verificationMessage,
conversation_id: conversationId,
stream: true,
};
try {
// Send a message to the chat server to verify the fact
const response = await fetch(apiURL, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(requestBody),
});
if (!response.body) throw new Error("No response body found");
const reader = response.body?.getReader();
let decoder = new TextDecoder();
let result = "";
while (true) {
const { done, value } = await reader.read();
if (done) break;
let chunk = decoder.decode(value, { stream: true });
if (chunk.includes("### compiled references:")) {
const references = handleCompiledReferences(chunk, result);
if (references.response) {
result = references.response;
setInitialResponse(references.response);
setInitialReferences(references);
}
} else {
result += chunk;
setInitialResponse(result);
}
}
} catch (error) {
console.error("Error verifying statement: ", error);
} finally {
setIsLoading(false);
}
}
async function spawnNewConversation(setConversationID: (conversationID: string) => void) {
let createURL = `/api/chat/sessions?client=web`;
const response = await fetch(createURL, { method: "POST" });
const data = await response.json();
setConversationID(data.conversation_id);
}
interface ReferenceVerificationProps {
message: string;
additionalLink: string;
conversationId: string;
linkTitle: string;
setChildReferencesCallback: (
additionalLink: string,
response: string,
linkTitle: string,
) => void;
prefilledResponse?: string;
}
function ReferenceVerification(props: ReferenceVerificationProps) {
const [initialResponse, setInitialResponse] = useState("");
const [isLoading, setIsLoading] = useState(true);
const verificationStatement = `${props.message}. Use this link for reference: ${props.additionalLink}`;
const isMobileWidth = useIsMobileWidth();
useEffect(() => {
if (props.prefilledResponse) {
setInitialResponse(props.prefilledResponse);
setIsLoading(false);
} else {
verifyStatement(
verificationStatement,
props.conversationId,
setIsLoading,
setInitialResponse,
() => {},
);
}
}, [verificationStatement, props.conversationId, props.prefilledResponse]);
useEffect(() => {
if (initialResponse === "") return;
if (props.prefilledResponse) return;
if (!isLoading) {
// Only set the child references when it's done loading and if the initial response is not prefilled (i.e. it was fetched from the server)
props.setChildReferencesCallback(
props.additionalLink,
initialResponse,
props.linkTitle,
);
}
}, [initialResponse, isLoading, props]);
return (
<div>
{isLoading && <LoadingSpinner />}
<ChatMessage
chatMessage={{
automationId: "",
by: "AI",
message: initialResponse,
context: [],
created: new Date().toISOString(),
onlineContext: {},
codeContext: {},
conversationId: props.conversationId,
turnId: "",
}}
isMobileWidth={isMobileWidth}
onDeleteMessage={(turnId?: string) => {}}
conversationId={props.conversationId}
/>
</div>
);
}
interface SupplementalReferenceProps {
onlineData?: OnlineContextData;
officialFactToVerify: string;
conversationId: string;
additionalLink: string;
setChildReferencesCallback: (
additionalLink: string,
response: string,
linkTitle: string,
) => void;
prefilledResponse?: string;
linkTitle?: string;
}
function SupplementalReference(props: SupplementalReferenceProps) {
const linkTitle = props.linkTitle || props.onlineData?.organic?.[0]?.title || "Reference";
const linkAsWebpage = { link: props.additionalLink } as WebPage;
return (
<Card className={`mt-2 mb-4`}>
<CardHeader>
<a
className={styles.titleLink}
href={props.additionalLink}
target="_blank"
rel="noreferrer"
>
{linkTitle}
</a>
<WebPageLink {...linkAsWebpage} />
</CardHeader>
<CardContent>
<ReferenceVerification
additionalLink={props.additionalLink}
message={props.officialFactToVerify}
linkTitle={linkTitle}
conversationId={props.conversationId}
setChildReferencesCallback={props.setChildReferencesCallback}
prefilledResponse={props.prefilledResponse}
/>
</CardContent>
</Card>
);
}
const WebPageLink = (webpage: WebPage) => {
const webpageDomain = new URL(webpage.link).hostname;
return (
<div className={styles.subLinks}>
<a
className={`${styles.subLinks} bg-blue-200 px-2`}
href={webpage.link}
target="_blank"
rel="noreferrer"
>
{webpageDomain}
</a>
</div>
);
};
export default function FactChecker() {
const [factToVerify, setFactToVerify] = useState("");
const [officialFactToVerify, setOfficialFactToVerify] = useState("");
const [isLoading, setIsLoading] = useState(false);
const [initialResponse, setInitialResponse] = useState("");
const [clickedVerify, setClickedVerify] = useState(false);
const [initialReferences, setInitialReferences] = useState<ResponseWithReferences>();
const [childReferences, setChildReferences] = useState<SupplementReferences[]>();
const [modelUsed, setModelUsed] = useState<Model>();
const isMobileWidth = useIsMobileWidth();
const [conversationID, setConversationID] = useState("");
const [runId, setRunId] = useState("");
const [loadedFromStorage, setLoadedFromStorage] = useState(false);
const [initialModel, setInitialModel] = useState<Model>();
function setChildReferencesCallback(
additionalLink: string,
response: string,
linkTitle: string,
) {
const newReferences = childReferences || [];
const exists = newReferences.find(
(reference) => reference.additionalLink === additionalLink,
);
if (exists) return;
newReferences.push({ additionalLink, response, linkTitle });
setChildReferences(newReferences);
}
let userData = useAuthenticatedData();
function storeData() {
const data = {
factToVerify,
response: initialResponse,
references: initialReferences,
childReferences,
runId,
modelUsed,
};
fetch(`/api/chat/store/factchecker`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
runId: runId,
storeData: data,
}),
});
}
useEffect(() => {
if (factToVerify) {
document.title = `AI Fact Check: ${factToVerify}`;
} else {
document.title = "AI Fact Checker";
}
}, [factToVerify]);
useEffect(() => {
const storedFact = localStorage.getItem("factToVerify");
if (storedFact) {
setFactToVerify(storedFact);
}
// Get query params from the URL
const urlParams = new URLSearchParams(window.location.search);
const factToVerifyParam = urlParams.get("factToVerify");
if (factToVerifyParam) {
setFactToVerify(factToVerifyParam);
}
const runIdParam = urlParams.get("runId");
if (runIdParam) {
setRunId(runIdParam);
// Define an async function to fetch data
const fetchData = async () => {
const storedDataURL = `/api/chat/store/factchecker?runId=${runIdParam}`;
try {
const response = await fetch(storedDataURL);
if (response.status !== 200) {
throw new Error("Failed to fetch stored data");
}
const storedData = JSON.parse(await response.json());
if (storedData) {
setOfficialFactToVerify(storedData.factToVerify);
setInitialResponse(storedData.response);
setInitialReferences(storedData.references);
setChildReferences(storedData.childReferences);
setInitialModel(storedData.modelUsed);
}
setLoadedFromStorage(true);
} catch (error) {
console.error("Error fetching stored data: ", error);
}
};
// Call the async function
fetchData();
}
}, []);
function onClickVerify() {
if (clickedVerify) return;
// Perform validation checks on the fact to verify
if (!factToVerify) {
alert("Please enter a fact to verify.");
return;
}
setClickedVerify(true);
if (!userData) {
let currentURL = window.location.href;
window.location.href = `/login?next=${currentURL}`;
}
setInitialReferences(undefined);
setInitialResponse("");
spawnNewConversation(setConversationID);
// Set the runId to a random 12-digit alphanumeric string
const newRunId = [...Array(16)].map(() => Math.random().toString(36)[2]).join("");
setRunId(newRunId);
window.history.pushState(
{},
document.title,
window.location.pathname + `?runId=${newRunId}`,
);
setOfficialFactToVerify(factToVerify);
setClickedVerify(false);
}
useEffect(() => {
if (!conversationID) return;
verifyStatement(
officialFactToVerify,
conversationID,
setIsLoading,
setInitialResponse,
setInitialReferences,
);
}, [conversationID, officialFactToVerify]);
// Store factToVerify in localStorage whenever it changes
useEffect(() => {
localStorage.setItem("factToVerify", factToVerify);
}, [factToVerify]);
// Update the meta tags for the description and og:description
useEffect(() => {
let metaTag = document.querySelector('meta[name="description"]');
if (metaTag) {
metaTag.setAttribute("content", initialResponse);
}
let metaOgTag = document.querySelector('meta[property="og:description"]');
if (!metaOgTag) {
metaOgTag = document.createElement("meta");
metaOgTag.setAttribute("property", "og:description");
document.getElementsByTagName("head")[0].appendChild(metaOgTag);
}
metaOgTag.setAttribute("content", initialResponse);
}, [initialResponse]);
const renderReferences = (
conversationId: string,
initialReferences: ResponseWithReferences,
officialFactToVerify: string,
loadedFromStorage: boolean,
childReferences?: SupplementReferences[],
) => {
if (loadedFromStorage && childReferences) {
return renderSupplementalReferences(childReferences);
}
const seenLinks = new Set();
// Any links that are present in webpages should not be searched again
Object.entries(initialReferences.online || {}).map(([key, onlineData], index) => {
const webpages = onlineData?.webpages || [];
// Webpage can be a list or a single object
if (webpages instanceof Array) {
for (let i = 0; i < webpages.length; i++) {
const webpage = webpages[i];
const additionalLink = webpage.link || "";
if (seenLinks.has(additionalLink)) {
return null;
}
seenLinks.add(additionalLink);
}
} else {
let singleWebpage = webpages as WebPage;
const additionalLink = singleWebpage.link || "";
if (seenLinks.has(additionalLink)) {
return null;
}
seenLinks.add(additionalLink);
}
});
return Object.entries(initialReferences.online || {})
.map(([key, onlineData], index) => {
let additionalLink = "";
// Loop through organic links until we find one that hasn't been searched
for (let i = 0; i < onlineData?.organic?.length; i++) {
const webpage = onlineData?.organic?.[i];
additionalLink = webpage.link || "";
if (!seenLinks.has(additionalLink)) {
break;
}
}
seenLinks.add(additionalLink);
if (additionalLink === "") return null;
return (
<SupplementalReference
key={index}
onlineData={onlineData}
officialFactToVerify={officialFactToVerify}
conversationId={conversationId}
additionalLink={additionalLink}
setChildReferencesCallback={setChildReferencesCallback}
/>
);
})
.filter(Boolean);
};
const renderSupplementalReferences = (references: SupplementReferences[]) => {
return references.map((reference, index) => {
return (
<SupplementalReference
key={index}
additionalLink={reference.additionalLink}
officialFactToVerify={officialFactToVerify}
conversationId={conversationID}
linkTitle={reference.linkTitle}
setChildReferencesCallback={setChildReferencesCallback}
prefilledResponse={reference.response}
/>
);
});
};
const renderWebpages = (webpages: WebPage[] | WebPage) => {
if (webpages instanceof Array) {
return webpages.map((webpage, index) => {
return WebPageLink(webpage);
});
} else {
return WebPageLink(webpages);
}
};
function constructShareUrl() {
const url = new URL(window.location.href);
url.searchParams.set("runId", runId);
return url.href;
}
return (
<>
<div className="relative md:fixed h-full">
<SidePanel conversationId={null} uploadedFiles={[]} isMobileWidth={isMobileWidth} />
</div>
<div className={styles.factCheckerContainer}>
<h1
className={`${styles.response} pt-8 md:pt-4 font-large outline-slate-800 dark:outline-slate-200`}
>
AI Fact Checker
</h1>
<footer className={`${styles.footer} mt-4`}>
This is an experimental AI tool. It may make mistakes.
</footer>
{initialResponse && initialReferences && childReferences ? (
<div className={styles.reportActions}>
<Button asChild variant="secondary">
<Link href="/factchecker" target="_blank" rel="noopener noreferrer">
Try Another
</Link>
</Button>
<ShareLink
buttonTitle="Share report"
title="AI Fact Checking Report"
description="Share this fact checking report with others. Anyone who has this link will be able to view the report."
url={constructShareUrl()}
onShare={loadedFromStorage ? () => {} : storeData}
/>
</div>
) : (
<div className={styles.newReportActions}>
<div className={`${styles.inputFields} mt-4`}>
<Input
type="text"
maxLength={200}
placeholder="Enter a falsifiable statement to verify"
disabled={isLoading}
onChange={(e) => setFactToVerify(e.target.value)}
value={factToVerify}
onKeyDown={(e) => {
if (e.key === "Enter") {
onClickVerify();
}
}}
onFocus={(e) => (e.target.placeholder = "")}
onBlur={(e) =>
(e.target.placeholder =
"Enter a falsifiable statement to verify")
}
/>
<Button disabled={clickedVerify} onClick={() => onClickVerify()}>
Verify
</Button>
</div>
<h3 className={`mt-4 mb-4`}>
Try with a particular model. You must be{" "}
<a
href="/settings"
className="font-medium text-blue-600 dark:text-blue-500 hover:underline"
>
subscribed
</a>{" "}
to configure the model.
</h3>
</div>
)}
<ModelPicker
disabled={isLoading || loadedFromStorage}
setModelUsed={setModelUsed}
initialModel={initialModel}
/>
{isLoading && (
<div className={styles.loading}>
<LoadingSpinner />
</div>
)}
{initialResponse && (
<Card className={`mt-4`}>
<CardHeader>
<CardTitle>{officialFactToVerify}</CardTitle>
</CardHeader>
<CardContent>
<div className={styles.responseText}>
<ChatMessage
chatMessage={{
automationId: "",
by: "AI",
message: initialResponse,
context: [],
created: new Date().toISOString(),
onlineContext: {},
codeContext: {},
conversationId: conversationID,
turnId: "",
}}
conversationId={conversationID}
onDeleteMessage={(turnId?: string) => {}}
isMobileWidth={isMobileWidth}
/>
</div>
</CardContent>
<CardFooter>
{initialReferences &&
initialReferences.online &&
Object.keys(initialReferences.online).length > 0 && (
<div className={styles.subLinks}>
{Object.entries(initialReferences.online).map(
([key, onlineData], index) => {
const webpages = onlineData?.webpages || [];
return renderWebpages(webpages);
},
)}
</div>
)}
</CardFooter>
</Card>
)}
{initialReferences && (
<div className={styles.referenceContainer}>
<h2 className="mt-4 mb-4">Supplements</h2>
<div className={styles.references}>
{initialReferences.online !== undefined &&
renderReferences(
conversationID,
initialReferences,
officialFactToVerify,
loadedFromStorage,
childReferences,
)}
</div>
</div>
)}
</div>
</>
);
}

View file

@ -168,12 +168,6 @@ class UserAuthenticationBackend(AuthenticationBackend):
if create_if_not_exists:
user, is_new = await aget_or_create_user_by_phone_number(phone_number)
if user and is_new:
update_telemetry_state(
request=request,
telemetry_type="api",
api="create_user",
metadata={"server_id": str(user.uuid)},
)
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
else:
user = await aget_user_by_phone_number(phone_number)

View file

@ -78,7 +78,12 @@ class KhojUserAdmin(UserAdmin):
search_fields = ("email", "username", "phone_number", "uuid")
filter_horizontal = ("groups", "user_permissions")
fieldsets = (("Personal info", {"fields": ("phone_number", "email_verification_code")}),) + UserAdmin.fieldsets
fieldsets = (
(
"Personal info",
{"fields": ("phone_number", "email_verification_code", "verified_phone_number", "verified_email")},
),
) + UserAdmin.fieldsets
actions = ["get_email_login_url"]

View file

@ -113,6 +113,7 @@ class InformationCollectionIteration:
onlineContext: dict = None,
codeContext: dict = None,
summarizedResult: str = None,
warning: str = None,
):
self.tool = tool
self.query = query
@ -120,6 +121,7 @@ class InformationCollectionIteration:
self.onlineContext = onlineContext
self.codeContext = codeContext
self.summarizedResult = summarizedResult
self.warning = warning
def construct_iteration_history(
@ -350,7 +352,11 @@ def generate_chatml_messages_with_context(
message_context += chat.get("intent").get("inferred-queries")[0]
if not is_none_or_empty(chat.get("context")):
references = "\n\n".join(
{f"# File: {item['file']}\n## {item['compiled']}\n" for item in chat.get("context") or []}
{
f"# File: {item['file']}\n## {item['compiled']}\n"
for item in chat.get("context") or []
if isinstance(item, dict)
}
)
message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"

View file

@ -4,7 +4,7 @@ import logging
import os
import urllib.parse
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import aiohttp
from bs4 import BeautifulSoup
@ -66,6 +66,7 @@ async def search_online(
custom_filters: List[str] = [],
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
query_images: List[str] = None,
previous_subqueries: Set = set(),
agent: Agent = None,
attached_files: str = None,
tracer: dict = {},
@ -77,7 +78,7 @@ async def search_online(
return
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(
new_subqueries = await generate_online_subqueries(
query,
conversation_history,
location,
@ -87,33 +88,42 @@ async def search_online(
tracer=tracer,
attached_files=attached_files,
)
response_dict = {}
subqueries = list(new_subqueries - previous_subqueries)
response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {}
if subqueries:
logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
if send_status_func:
subqueries_str = "\n- " + "\n- ".join(list(subqueries))
async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
yield {ChatEvent.STATUS: event}
if is_none_or_empty(subqueries):
logger.info("No new subqueries to search online")
yield response_dict
return
with timer(f"Internet searches for {list(subqueries)} took", logger):
logger.info(f"🌐 Searching the Internet for {subqueries}")
if send_status_func:
subqueries_str = "\n- " + "\n- ".join(subqueries)
async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
yield {ChatEvent.STATUS: event}
with timer(f"Internet searches for {subqueries} took", logger):
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
search_tasks = [search_func(subquery, location) for subquery in subqueries]
search_results = await asyncio.gather(*search_tasks)
response_dict = {subquery: search_result for subquery, search_result in search_results}
# Gather distinct web pages from organic results for subqueries without an instant answer.
# Content of web pages is directly available when Jina is used for search.
webpages: Dict[str, Dict] = {}
for subquery in response_dict:
if "answerBox" in response_dict[subquery]:
continue
for organic in response_dict[subquery].get("organic", [])[:max_webpages_to_read]:
for idx, organic in enumerate(response_dict[subquery].get("organic", [])):
link = organic.get("link")
if link in webpages:
if link in webpages and idx < max_webpages_to_read:
webpages[link]["queries"].add(subquery)
else:
# Content of web pages is directly available when Jina is used for search.
elif idx < max_webpages_to_read:
webpages[link] = {"queries": {subquery}, "content": organic.get("content")}
# Only keep webpage content for up to max_webpages_to_read organic results.
if idx >= max_webpages_to_read and not is_none_or_empty(organic.get("content")):
organic["content"] = None
response_dict[subquery]["organic"][idx] = organic
# Read, extract relevant info from the retrieved web pages
if webpages:
@ -123,7 +133,9 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
read_webpage_and_extract_content(
data["queries"], link, data.get("content"), user=user, agent=agent, tracer=tracer
)
for link, data in webpages.items()
]
results = await asyncio.gather(*tasks)
@ -371,3 +383,25 @@ async def search_with_jina(query: str, location: LocationData) -> Tuple[str, Dic
for item in response_json["data"]
]
return query, {"organic": parsed_response}
def deduplicate_organic_results(online_results: dict) -> dict:
"""Deduplicate organic search results based on links across all queries."""
# Keep track of seen links to filter out duplicates across queries
seen_links = set()
deduplicated_results = {}
# Process each query's results
for query, results in online_results.items():
# Filter organic results keeping only first occurrence of each link
filtered_organic = []
for result in results.get("organic", []):
link = result.get("link")
if link and link not in seen_links:
seen_links.add(link)
filtered_organic.append(result)
# Update results with deduplicated organic entries
deduplicated_results[query] = {**results, "organic": filtered_organic}
return deduplicated_results

View file

@ -6,7 +6,7 @@ import os
import threading
import time
import uuid
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Set, Union
import cron_descriptor
import pytz
@ -349,6 +349,7 @@ async def extract_references_and_questions(
location_data: LocationData = None,
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
previous_inferred_queries: Set = set(),
agent: Agent = None,
attached_files: str = None,
tracer: dict = {},
@ -482,6 +483,7 @@ async def extract_references_and_questions(
)
# Collate search results as context for GPT
inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
with timer("Searching knowledge base took", logger):
search_results = []
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")

View file

@ -27,7 +27,11 @@ from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log
from khoj.processor.image.generate import text_to_image
from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.online_search import (
deduplicate_organic_results,
read_webpages,
search_online,
)
from khoj.processor.tools.run_code import run_code
from khoj.routers.api import extract_references_and_questions
from khoj.routers.email import send_query_feedback
@ -779,8 +783,13 @@ async def chat(
conversation_commands.append(mode)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
try:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
except HTTPException as e:
async for result in send_llm_response(str(e.detail)):
yield result
return
defiltered_query = defilter_query(q)
@ -815,11 +824,8 @@ async def chat(
yield research_result
# researched_results = await extract_relevant_info(q, researched_results, agent)
logger.info(f"Researched Results: {researched_results}")
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
if state.verbose > 1:
logger.debug(f"Researched Results: {researched_results}")
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else []
@ -1069,12 +1075,13 @@ async def chat(
)
## Send Gathered References
unique_online_results = deduplicate_organic_results(online_results)
async for result in send_event(
ChatEvent.REFERENCES,
{
"inferredQueries": inferred_queries,
"context": compiled_references,
"onlineContext": online_results,
"onlineContext": unique_online_results,
"codeContext": code_results,
},
):

View file

@ -20,6 +20,7 @@ from typing import (
Iterator,
List,
Optional,
Set,
Tuple,
Union,
)
@ -539,7 +540,7 @@ async def generate_online_subqueries(
agent: Agent = None,
attached_files: str = None,
tracer: dict = {},
) -> List[str]:
) -> Set[str]:
"""
Generate subqueries from the given query
"""
@ -575,14 +576,14 @@ async def generate_online_subqueries(
try:
response = clean_json(response)
response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response or len(response) == 0:
response = {q.strip() for q in response["queries"] if q.strip()}
if not isinstance(response, set) or not response or len(response) == 0:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
return [q]
return {q}
return response
except Exception as e:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
return [q]
return {q}
async def schedule_query(
@ -1208,9 +1209,6 @@ def generate_chat_response(
metadata = {}
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
query_to_run = q
if meta_research:
query_to_run = f"AI Research: {meta_research} {q}"
try:
partial_completion = partial(
save_to_conversation_log,
@ -1229,6 +1227,13 @@ def generate_chat_response(
tracer=tracer,
)
query_to_run = q
if meta_research:
query_to_run = f"<query>{q}</query>\n<collected_research>\n{meta_research}\n</collected_research>"
compiled_references = []
online_results = {}
code_results = {}
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
vision_available = conversation_config.vision_enabled
if not vision_available and query_images:
@ -1375,25 +1380,28 @@ class ApiUserRateLimiter:
# Check if the user has exceeded the rate limit
if subscribed and count_requests >= self.subscribed_requests:
logger.info(
f"Rate limit: {count_requests} requests in {self.window} seconds for user: {user}. Limit is {self.subscribed_requests} requests."
)
raise HTTPException(status_code=429, detail="Slow down! Too Many Requests")
if not subscribed and count_requests >= self.requests:
if self.requests >= self.subscribed_requests:
logger.info(
f"Rate limit: {count_requests} requests in {self.window} seconds for user: {user}. Limit is {self.subscribed_requests} requests."
)
raise HTTPException(
status_code=429,
detail="Slow down! Too Many Requests",
)
logger.info(
f"Rate limit: {count_requests} requests in {self.window} seconds for user: {user}. Limit is {self.subscribed_requests} requests."
f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for subscribed user: {user}."
)
raise HTTPException(
status_code=429,
detail="I'm glad you're enjoying interacting with me! But you've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?",
)
if not subscribed and count_requests >= self.requests:
if self.requests >= self.subscribed_requests:
logger.info(
f"Rate limit: {count_requests}/{self.subscribed_requests} requests not allowed in {self.window} seconds for user: {user}."
)
raise HTTPException(
status_code=429,
detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. But let's chat more tomorrow?",
)
logger.info(
f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for user: {user}."
)
raise HTTPException(
status_code=429,
detail="I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for today. You can subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings) or we can continue our conversation tomorrow?",
)
# Add the current request to the cache
@ -1419,6 +1427,7 @@ class ApiImageRateLimiter:
# Check number of images
if len(body.images) > self.max_images:
logger.info(f"Rate limit: {len(body.images)}/{self.max_images} images not allowed per message.")
raise HTTPException(
status_code=429,
detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
@ -1439,6 +1448,7 @@ class ApiImageRateLimiter:
total_size_mb += len(image_bytes) / (1024 * 1024) # Convert bytes to MB
if total_size_mb > self.max_combined_size_mb:
logger.info(f"Data limit: {total_size_mb}MB/{self.max_combined_size_mb}MB size not allowed per message.")
raise HTTPException(
status_code=429,
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
@ -1474,13 +1484,19 @@ class ConversationCommandRateLimiter:
if subscribed and count_requests >= self.subscribed_rate_limit:
logger.info(
f"Rate limit: {count_requests} requests in 24 hours for user: {user}. Limit is {self.subscribed_rate_limit} requests."
f"Rate limit: {count_requests}/{self.subscribed_rate_limit} requests not allowed in 24 hours for subscribed user: {user}."
)
raise HTTPException(status_code=429, detail="Slow down! Too Many Requests")
if not subscribed and count_requests >= self.trial_rate_limit:
raise HTTPException(
status_code=429,
detail=f"We're glad you're enjoying Khoj! You've exceeded your `/{conversation_command.value}` command usage limit for today. Subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
detail=f"I'm glad you're enjoying interacting with me! You've unfortunately exceeded your `/{conversation_command.value}` command usage limit for today. Maybe we can talk about something else for today?",
)
if not subscribed and count_requests >= self.trial_rate_limit:
logger.info(
f"Rate limit: {count_requests}/{self.trial_rate_limit} requests not allowed in 24 hours for user: {user}."
)
raise HTTPException(
status_code=429,
detail=f"I'm glad you're enjoying interacting with me! You've unfortunately exceeded your `/{conversation_command.value}` command usage limit for today. You can subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings) or we can talk about something else for today?",
)
await UserRequests.objects.acreate(user=user, slug=command_slug)
return
@ -1526,16 +1542,28 @@ class ApiIndexedDataLimiter:
logger.info(f"Deleted {num_deleted_entries} entries for user: {user}.")
if subscribed and incoming_data_size_mb >= self.subscribed_num_entries_size:
logger.info(
f"Data limit: {incoming_data_size_mb}MB incoming will exceed {self.subscribed_num_entries_size}MB allowed for subscribed user: {user}."
)
raise HTTPException(status_code=429, detail="Too much data indexed.")
if not subscribed and incoming_data_size_mb >= self.num_entries_size:
logger.info(
f"Data limit: {incoming_data_size_mb}MB incoming will exceed {self.num_entries_size}MB allowed for user: {user}."
)
raise HTTPException(
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
)
user_size_data = EntryAdapters.get_size_of_indexed_data_in_mb(user)
if subscribed and user_size_data + incoming_data_size_mb >= self.subscribed_total_entries_size:
logger.info(
f"Data limit: {incoming_data_size_mb}MB incoming + {user_size_data}MB existing will exceed {self.subscribed_total_entries_size}MB allowed for subscribed user: {user}."
)
raise HTTPException(status_code=429, detail="Too much data indexed.")
if not subscribed and user_size_data + incoming_data_size_mb >= self.total_entries_size_limit:
logger.info(
f"Data limit: {incoming_data_size_mb}MB incoming + {user_size_data}MB existing will exceed {self.subscribed_total_entries_size}MB allowed for non subscribed user: {user}."
)
raise HTTPException(
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
)
@ -1623,6 +1651,11 @@ def scheduled_chat(
# encode the conversation_id to avoid any issues with special characters
query_dict["conversation_id"] = [quote(str(conversation_id))]
# validate that the conversation id exists. If not, delete the automation and exit.
if not ConversationAdapters.get_conversation_by_id(conversation_id):
AutomationAdapters.delete_automation(user, job_id)
return
# Restructure the original query_dict into a valid JSON payload for the chat API
json_payload = {key: values[0] for key, values in query_dict.items()}

View file

@ -42,39 +42,36 @@ async def apick_next_tool(
location: LocationData = None,
user_name: str = None,
agent: Agent = None,
previous_iterations_history: str = None,
previous_iterations: List[InformationCollectionIteration] = [],
max_iterations: int = 5,
send_status_func: Optional[Callable] = None,
tracer: dict = {},
attached_files: str = None,
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer.
"""
"""Given a query, determine which of the available tools the agent should use in order to answer appropriately."""
# Construct tool options for the agent to choose from
tool_options = dict()
tool_options_str = ""
agent_tools = agent.input_tools if agent else []
for tool, description in function_calling_description_for_llm.items():
tool_options[tool.value] = description
if len(agent_tools) == 0 or tool.value in agent_tools:
tool_options_str += f'- "{tool.value}": "{description}"\n'
# Construct chat history with user and iteration history with researcher agent for context
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
if query_images:
query = f"[placeholder for user attached images]\n{query}"
today = datetime.today()
location_data = f"{location}" if location else "Unknown"
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
# Extract Past User Message and Inferred Questions from Conversation Log
today = datetime.today()
location_data = f"{location}" if location else "Unknown"
function_planning_prompt = prompts.plan_function_execution.format(
tools=tool_options_str,
chat_history=chat_history,
@ -87,16 +84,25 @@ async def apick_next_tool(
max_iterations=max_iterations,
)
with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper(
query=query,
context=function_planning_prompt,
response_type="json_object",
user=user,
query_images=query_images,
tracer=tracer,
attached_files=attached_files,
try:
with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper(
query=query,
context=function_planning_prompt,
response_type="json_object",
user=user,
query_images=query_images,
attached_files=attached_files,
tracer=tracer,
)
except Exception as e:
logger.error(f"Failed to infer information sources to refer: {e}", exc_info=True)
yield InformationCollectionIteration(
tool=None,
query=None,
warning="Failed to infer information sources to refer. Skipping iteration. Try again.",
)
return
try:
response = clean_json(response)
@ -104,8 +110,15 @@ async def apick_next_tool(
selected_tool = response.get("tool", None)
generated_query = response.get("query", None)
scratchpad = response.get("scratchpad", None)
warning = None
logger.info(f"Response for determining relevant tools: {response}")
if send_status_func:
# Detect selection of previously used query, tool combination.
previous_tool_query_combinations = {(i.tool, i.query) for i in previous_iterations}
if (selected_tool, generated_query) in previous_tool_query_combinations:
warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different."
# Only send client status updates if we'll execute this iteration
elif send_status_func:
determined_tool_message = "**Determined Tool**: "
determined_tool_message += f"{selected_tool}({generated_query})." if selected_tool else "respond."
determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else ""
@ -115,13 +128,14 @@ async def apick_next_tool(
yield InformationCollectionIteration(
tool=selected_tool,
query=generated_query,
warning=warning,
)
except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
yield InformationCollectionIteration(
tool=None,
query=None,
warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}",
)
@ -149,7 +163,6 @@ async def execute_information_collection(
document_results: List[Dict[str, str]] = []
summarize_files: str = ""
this_iteration = InformationCollectionIteration(tool=None, query=query)
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
async for result in apick_next_tool(
query,
@ -159,7 +172,7 @@ async def execute_information_collection(
location,
user_name,
agent,
previous_iterations_history,
previous_iterations,
MAX_ITERATIONS,
send_status_func,
tracer=tracer,
@ -170,9 +183,16 @@ async def execute_information_collection(
elif isinstance(result, InformationCollectionIteration):
this_iteration = result
if this_iteration.tool == ConversationCommand.Notes:
# Skip running iteration if warning present in iteration
if this_iteration.warning:
logger.warning(f"Research mode: {this_iteration.warning}.")
elif this_iteration.tool == ConversationCommand.Notes:
this_iteration.context = []
document_results = []
previous_inferred_queries = {
c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context
}
async for result in extract_references_and_questions(
request,
construct_tool_chat_history(previous_iterations, ConversationCommand.Notes),
@ -184,6 +204,7 @@ async def execute_information_collection(
location,
send_status_func,
query_images,
previous_inferred_queries=previous_inferred_queries,
agent=agent,
tracer=tracer,
attached_files=attached_files,
@ -208,6 +229,12 @@ async def execute_information_collection(
logger.error(f"Error extracting document references: {e}", exc_info=True)
elif this_iteration.tool == ConversationCommand.Online:
previous_subqueries = {
subquery
for iteration in previous_iterations
if iteration.onlineContext
for subquery in iteration.onlineContext.keys()
}
async for result in search_online(
this_iteration.query,
construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
@ -217,11 +244,16 @@ async def execute_information_collection(
[],
max_webpages_to_read=0,
query_images=query_images,
previous_subqueries=previous_subqueries,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
elif is_none_or_empty(result):
this_iteration.warning = (
"Detected previously run online search queries. Skipping iteration. Try something different."
)
else:
online_results: Dict[str, Dict] = result # type: ignore
this_iteration.onlineContext = online_results
@ -309,16 +341,19 @@ async def execute_information_collection(
current_iteration += 1
if document_results or online_results or code_results or summarize_files:
results_data = f"**Results**:\n"
if document_results or online_results or code_results or summarize_files or this_iteration.warning:
results_data = f"\n<iteration>{current_iteration}\n<tool>{this_iteration.tool}</tool>\n<query>{this_iteration.query}</query>\n<results>"
if document_results:
results_data += f"**Document References**:\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"\n<document_references>\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
if online_results:
results_data += f"**Online Results**:\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"\n<online_results>\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</online_results>"
if code_results:
results_data += f"**Code Results**:\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"\n<code_results>\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
if summarize_files:
results_data += f"**Summarized Files**:\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"\n<summarized_files>\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</summarized_files>"
if this_iteration.warning:
results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
results_data += "\n</results>\n</iteration>"
# intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
this_iteration.summarizedResult = results_data

View file

@ -51,16 +51,6 @@ def chat_page(request: Request):
)
@web_client.get("/factchecker", response_class=FileResponse)
def fact_checker_page(request: Request):
return templates.TemplateResponse(
"factchecker/index.html",
context={
"request": request,
},
)
@web_client.get("/login", response_class=FileResponse)
def login_page(request: Request):
next_url = get_next_url(request)

View file

@ -101,10 +101,10 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dic
return {"decision": "FALSE", "explanation": f"Evaluation failed: {str(e)}"}
def process_batch(batch, counter, results, dataset_length):
for prompt, answer, reasoning_type in batch:
counter += 1
logger.info(f"Processing example: {counter}/{dataset_length}")
def process_batch(batch, batch_start, results, dataset_length):
for idx, (prompt, answer, reasoning_type) in enumerate(batch):
current_index = batch_start + idx
logger.info(f"Processing example: {current_index}/{dataset_length}")
# Trigger research mode if enabled
prompt = f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE else prompt
@ -113,12 +113,16 @@ def process_batch(batch, counter, results, dataset_length):
agent_response = get_agent_response(prompt)
# Evaluate response
evaluation = evaluate_response(prompt, agent_response, answer)
if agent_response is None or agent_response.strip() == "":
evaluation["decision"] = False
evaluation["explanation"] = "Agent response is empty. This maybe due to a service error."
else:
evaluation = evaluate_response(prompt, agent_response, answer)
# Store results
results.append(
{
"index": counter,
"index": current_index,
"prompt": prompt,
"ground_truth": answer,
"agent_response": agent_response,
@ -165,12 +169,13 @@ def main():
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
for i in range(0, dataset_length, BATCH_SIZE):
batch_start = i
batch = zip(
dataset["Prompt"][i : i + BATCH_SIZE],
dataset["Answer"][i : i + BATCH_SIZE],
dataset["reasoning_types"][i : i + BATCH_SIZE],
)
futures.append(executor.submit(process_batch, batch, counter, results, dataset_length))
futures.append(executor.submit(process_batch, batch, batch_start, results, dataset_length))
# Wait for all futures to complete
concurrent.futures.wait(futures)