mirror of
https://github.com/Mintplex-Labs/anything-llm.git
synced 2025-03-16 23:22:22 +00:00
model specific summarization (#1119)
* model specific summarization * update guard functions * patch model picker and key inputs
This commit is contained in:
parent
9449fcd737
commit
81fd82e133
7 changed files with 80 additions and 10 deletions
frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection
server/utils/agents/aibitat
|
@ -97,7 +97,7 @@ export default function AgentModelSelection({
|
||||||
<option
|
<option
|
||||||
key={model.id}
|
key={model.id}
|
||||||
value={model.id}
|
value={model.id}
|
||||||
selected={workspace?.chatModel === model.id}
|
selected={workspace?.agentModel === model.id}
|
||||||
>
|
>
|
||||||
{model.name}
|
{model.name}
|
||||||
</option>
|
</option>
|
||||||
|
|
|
@ -12,6 +12,7 @@ const { Telemetry } = require("../../../models/telemetry.js");
|
||||||
class AIbitat {
|
class AIbitat {
|
||||||
emitter = new EventEmitter();
|
emitter = new EventEmitter();
|
||||||
|
|
||||||
|
provider = null;
|
||||||
defaultProvider = null;
|
defaultProvider = null;
|
||||||
defaultInterrupt;
|
defaultInterrupt;
|
||||||
maxRounds;
|
maxRounds;
|
||||||
|
@ -39,6 +40,7 @@ class AIbitat {
|
||||||
provider,
|
provider,
|
||||||
...rest,
|
...rest,
|
||||||
};
|
};
|
||||||
|
this.provider = this.defaultProvider.provider;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -2,6 +2,7 @@ const { Document } = require("../../../../models/documents");
|
||||||
const { safeJsonParse } = require("../../../http");
|
const { safeJsonParse } = require("../../../http");
|
||||||
const { validate } = require("uuid");
|
const { validate } = require("uuid");
|
||||||
const { summarizeContent } = require("../utils/summarize");
|
const { summarizeContent } = require("../utils/summarize");
|
||||||
|
const Provider = require("../providers/ai-provider");
|
||||||
|
|
||||||
const docSummarizer = {
|
const docSummarizer = {
|
||||||
name: "document-summarizer",
|
name: "document-summarizer",
|
||||||
|
@ -95,7 +96,19 @@ const docSummarizer = {
|
||||||
document?.title ?? "a discovered file."
|
document?.title ?? "a discovered file."
|
||||||
}`
|
}`
|
||||||
);
|
);
|
||||||
if (document?.content?.length < 8000) return content;
|
|
||||||
|
if (!document.content || document.content.length === 0) {
|
||||||
|
throw new Error(
|
||||||
|
"This document has no readable content that could be found."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
document.content?.length <
|
||||||
|
Provider.contextLimit(this.super.provider)
|
||||||
|
) {
|
||||||
|
return document.content;
|
||||||
|
}
|
||||||
|
|
||||||
this.super.introspect(
|
this.super.introspect(
|
||||||
`${this.caller}: Summarizing ${document?.title ?? ""}...`
|
`${this.caller}: Summarizing ${document?.title ?? ""}...`
|
||||||
|
@ -109,6 +122,7 @@ const docSummarizer = {
|
||||||
});
|
});
|
||||||
|
|
||||||
return await summarizeContent(
|
return await summarizeContent(
|
||||||
|
this.super.provider,
|
||||||
this.controller.signal,
|
this.controller.signal,
|
||||||
document.content
|
document.content
|
||||||
);
|
);
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
const { CollectorApi } = require("../../../collectorApi");
|
const { CollectorApi } = require("../../../collectorApi");
|
||||||
|
const Provider = require("../providers/ai-provider");
|
||||||
const { summarizeContent } = require("../utils/summarize");
|
const { summarizeContent } = require("../utils/summarize");
|
||||||
|
|
||||||
const webScraping = {
|
const webScraping = {
|
||||||
|
@ -61,7 +62,11 @@ const webScraping = {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (content?.length <= 8000) {
|
if (!content || content?.length === 0) {
|
||||||
|
throw new Error("There was no content to be collected or read.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (content.length < Provider.contextLimit(this.super.provider)) {
|
||||||
return content;
|
return content;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,7 +79,11 @@ const webScraping = {
|
||||||
);
|
);
|
||||||
this.controller.abort();
|
this.controller.abort();
|
||||||
});
|
});
|
||||||
return summarizeContent(this.controller.signal, content);
|
return summarizeContent(
|
||||||
|
this.super.provider,
|
||||||
|
this.controller.signal,
|
||||||
|
content
|
||||||
|
);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
|
|
@ -2,6 +2,9 @@
|
||||||
* A service that provides an AI client to create a completion.
|
* A service that provides an AI client to create a completion.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
const { ChatOpenAI } = require("langchain/chat_models/openai");
|
||||||
|
const { ChatAnthropic } = require("langchain/chat_models/anthropic");
|
||||||
|
|
||||||
class Provider {
|
class Provider {
|
||||||
_client;
|
_client;
|
||||||
constructor(client) {
|
constructor(client) {
|
||||||
|
@ -14,6 +17,37 @@ class Provider {
|
||||||
get client() {
|
get client() {
|
||||||
return this._client;
|
return this._client;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static LangChainChatModel(provider = "openai", config = {}) {
|
||||||
|
switch (provider) {
|
||||||
|
case "openai":
|
||||||
|
return new ChatOpenAI({
|
||||||
|
openAIApiKey: process.env.OPEN_AI_KEY,
|
||||||
|
...config,
|
||||||
|
});
|
||||||
|
case "anthropic":
|
||||||
|
return new ChatAnthropic({
|
||||||
|
anthropicApiKey: process.env.ANTHROPIC_API_KEY,
|
||||||
|
...config,
|
||||||
|
});
|
||||||
|
default:
|
||||||
|
return new ChatOpenAI({
|
||||||
|
openAIApiKey: process.env.OPEN_AI_KEY,
|
||||||
|
...config,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static contextLimit(provider = "openai") {
|
||||||
|
switch (provider) {
|
||||||
|
case "openai":
|
||||||
|
return 8_000;
|
||||||
|
case "anthropic":
|
||||||
|
return 100_000;
|
||||||
|
default:
|
||||||
|
return 8_000;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
module.exports = Provider;
|
module.exports = Provider;
|
||||||
|
|
|
@ -186,7 +186,8 @@ class AnthropicProvider extends Provider {
|
||||||
const completion = response.content.find((msg) => msg.type === "text");
|
const completion = response.content.find((msg) => msg.type === "text");
|
||||||
return {
|
return {
|
||||||
result:
|
result:
|
||||||
completion?.text ?? "I could not generate a response from this.",
|
completion?.text ??
|
||||||
|
"The model failed to complete the task and return back a valid response.",
|
||||||
cost: 0,
|
cost: 0,
|
||||||
};
|
};
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
const { loadSummarizationChain } = require("langchain/chains");
|
const { loadSummarizationChain } = require("langchain/chains");
|
||||||
const { ChatOpenAI } = require("langchain/chat_models/openai");
|
|
||||||
const { PromptTemplate } = require("langchain/prompts");
|
const { PromptTemplate } = require("langchain/prompts");
|
||||||
const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter");
|
const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter");
|
||||||
|
const Provider = require("../providers/ai-provider");
|
||||||
/**
|
/**
|
||||||
* Summarize content using OpenAI's GPT-3.5 model.
|
* Summarize content using OpenAI's GPT-3.5 model.
|
||||||
*
|
*
|
||||||
|
@ -9,11 +9,20 @@ const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter");
|
||||||
* @param content The content to summarize.
|
* @param content The content to summarize.
|
||||||
* @returns The summarized content.
|
* @returns The summarized content.
|
||||||
*/
|
*/
|
||||||
async function summarizeContent(controllerSignal, content) {
|
|
||||||
const llm = new ChatOpenAI({
|
const SUMMARY_MODEL = {
|
||||||
openAIApiKey: process.env.OPEN_AI_KEY,
|
anthropic: "claude-3-opus-20240229", // 200,000 tokens
|
||||||
|
openai: "gpt-3.5-turbo-1106", // 16,385 tokens
|
||||||
|
};
|
||||||
|
|
||||||
|
async function summarizeContent(
|
||||||
|
provider = "openai",
|
||||||
|
controllerSignal,
|
||||||
|
content
|
||||||
|
) {
|
||||||
|
const llm = Provider.LangChainChatModel(provider, {
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
modelName: "gpt-3.5-turbo-16k-0613",
|
modelName: SUMMARY_MODEL[provider],
|
||||||
});
|
});
|
||||||
|
|
||||||
const textSplitter = new RecursiveCharacterTextSplitter({
|
const textSplitter = new RecursiveCharacterTextSplitter({
|
||||||
|
@ -41,6 +50,7 @@ async function summarizeContent(controllerSignal, content) {
|
||||||
combineMapPrompt: mapPromptTemplate,
|
combineMapPrompt: mapPromptTemplate,
|
||||||
verbose: process.env.NODE_ENV === "development",
|
verbose: process.env.NODE_ENV === "development",
|
||||||
});
|
});
|
||||||
|
|
||||||
const res = await chain.call({
|
const res = await chain.call({
|
||||||
...(controllerSignal ? { signal: controllerSignal } : {}),
|
...(controllerSignal ? { signal: controllerSignal } : {}),
|
||||||
input_documents: docs,
|
input_documents: docs,
|
||||||
|
|
Loading…
Add table
Reference in a new issue