mirror of
https://github.com/Mintplex-Labs/anything-llm.git
synced 2025-03-15 14:42:23 +00:00
model specific summarization (#1119)
* model specific summarization * update guard functions * patch model picker and key inputs
This commit is contained in:
parent
9449fcd737
commit
81fd82e133
7 changed files with 80 additions and 10 deletions
frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection
server/utils/agents/aibitat
|
@ -97,7 +97,7 @@ export default function AgentModelSelection({
|
|||
<option
|
||||
key={model.id}
|
||||
value={model.id}
|
||||
selected={workspace?.chatModel === model.id}
|
||||
selected={workspace?.agentModel === model.id}
|
||||
>
|
||||
{model.name}
|
||||
</option>
|
||||
|
|
|
@ -12,6 +12,7 @@ const { Telemetry } = require("../../../models/telemetry.js");
|
|||
class AIbitat {
|
||||
emitter = new EventEmitter();
|
||||
|
||||
provider = null;
|
||||
defaultProvider = null;
|
||||
defaultInterrupt;
|
||||
maxRounds;
|
||||
|
@ -39,6 +40,7 @@ class AIbitat {
|
|||
provider,
|
||||
...rest,
|
||||
};
|
||||
this.provider = this.defaultProvider.provider;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -2,6 +2,7 @@ const { Document } = require("../../../../models/documents");
|
|||
const { safeJsonParse } = require("../../../http");
|
||||
const { validate } = require("uuid");
|
||||
const { summarizeContent } = require("../utils/summarize");
|
||||
const Provider = require("../providers/ai-provider");
|
||||
|
||||
const docSummarizer = {
|
||||
name: "document-summarizer",
|
||||
|
@ -95,7 +96,19 @@ const docSummarizer = {
|
|||
document?.title ?? "a discovered file."
|
||||
}`
|
||||
);
|
||||
if (document?.content?.length < 8000) return content;
|
||||
|
||||
if (!document.content || document.content.length === 0) {
|
||||
throw new Error(
|
||||
"This document has no readable content that could be found."
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
document.content?.length <
|
||||
Provider.contextLimit(this.super.provider)
|
||||
) {
|
||||
return document.content;
|
||||
}
|
||||
|
||||
this.super.introspect(
|
||||
`${this.caller}: Summarizing ${document?.title ?? ""}...`
|
||||
|
@ -109,6 +122,7 @@ const docSummarizer = {
|
|||
});
|
||||
|
||||
return await summarizeContent(
|
||||
this.super.provider,
|
||||
this.controller.signal,
|
||||
document.content
|
||||
);
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
const { CollectorApi } = require("../../../collectorApi");
|
||||
const Provider = require("../providers/ai-provider");
|
||||
const { summarizeContent } = require("../utils/summarize");
|
||||
|
||||
const webScraping = {
|
||||
|
@ -61,7 +62,11 @@ const webScraping = {
|
|||
);
|
||||
}
|
||||
|
||||
if (content?.length <= 8000) {
|
||||
if (!content || content?.length === 0) {
|
||||
throw new Error("There was no content to be collected or read.");
|
||||
}
|
||||
|
||||
if (content.length < Provider.contextLimit(this.super.provider)) {
|
||||
return content;
|
||||
}
|
||||
|
||||
|
@ -74,7 +79,11 @@ const webScraping = {
|
|||
);
|
||||
this.controller.abort();
|
||||
});
|
||||
return summarizeContent(this.controller.signal, content);
|
||||
return summarizeContent(
|
||||
this.super.provider,
|
||||
this.controller.signal,
|
||||
content
|
||||
);
|
||||
},
|
||||
});
|
||||
},
|
||||
|
|
|
@ -2,6 +2,9 @@
|
|||
* A service that provides an AI client to create a completion.
|
||||
*/
|
||||
|
||||
const { ChatOpenAI } = require("langchain/chat_models/openai");
|
||||
const { ChatAnthropic } = require("langchain/chat_models/anthropic");
|
||||
|
||||
class Provider {
|
||||
_client;
|
||||
constructor(client) {
|
||||
|
@ -14,6 +17,37 @@ class Provider {
|
|||
get client() {
|
||||
return this._client;
|
||||
}
|
||||
|
||||
static LangChainChatModel(provider = "openai", config = {}) {
|
||||
switch (provider) {
|
||||
case "openai":
|
||||
return new ChatOpenAI({
|
||||
openAIApiKey: process.env.OPEN_AI_KEY,
|
||||
...config,
|
||||
});
|
||||
case "anthropic":
|
||||
return new ChatAnthropic({
|
||||
anthropicApiKey: process.env.ANTHROPIC_API_KEY,
|
||||
...config,
|
||||
});
|
||||
default:
|
||||
return new ChatOpenAI({
|
||||
openAIApiKey: process.env.OPEN_AI_KEY,
|
||||
...config,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static contextLimit(provider = "openai") {
|
||||
switch (provider) {
|
||||
case "openai":
|
||||
return 8_000;
|
||||
case "anthropic":
|
||||
return 100_000;
|
||||
default:
|
||||
return 8_000;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = Provider;
|
||||
|
|
|
@ -186,7 +186,8 @@ class AnthropicProvider extends Provider {
|
|||
const completion = response.content.find((msg) => msg.type === "text");
|
||||
return {
|
||||
result:
|
||||
completion?.text ?? "I could not generate a response from this.",
|
||||
completion?.text ??
|
||||
"The model failed to complete the task and return back a valid response.",
|
||||
cost: 0,
|
||||
};
|
||||
} catch (error) {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
const { loadSummarizationChain } = require("langchain/chains");
|
||||
const { ChatOpenAI } = require("langchain/chat_models/openai");
|
||||
const { PromptTemplate } = require("langchain/prompts");
|
||||
const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter");
|
||||
const Provider = require("../providers/ai-provider");
|
||||
/**
|
||||
* Summarize content using the summarization model mapped to the given provider.
|
||||
*
|
||||
|
@ -9,11 +9,20 @@ const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter");
|
|||
* @param content The content to summarize.
|
||||
* @returns The summarized content.
|
||||
*/
|
||||
async function summarizeContent(controllerSignal, content) {
|
||||
const llm = new ChatOpenAI({
|
||||
openAIApiKey: process.env.OPEN_AI_KEY,
|
||||
|
||||
/**
 * Model name used for summarization, keyed by provider.
 * Frozen so this shared module-level lookup table cannot be mutated by
 * accident at runtime.
 */
const SUMMARY_MODEL = Object.freeze({
  anthropic: "claude-3-opus-20240229", // 200,000 tokens
  openai: "gpt-3.5-turbo-1106", // 16,385 tokens
});
|
||||
|
||||
async function summarizeContent(
|
||||
provider = "openai",
|
||||
controllerSignal,
|
||||
content
|
||||
) {
|
||||
const llm = Provider.LangChainChatModel(provider, {
|
||||
temperature: 0,
|
||||
modelName: "gpt-3.5-turbo-16k-0613",
|
||||
modelName: SUMMARY_MODEL[provider],
|
||||
});
|
||||
|
||||
const textSplitter = new RecursiveCharacterTextSplitter({
|
||||
|
@ -41,6 +50,7 @@ async function summarizeContent(controllerSignal, content) {
|
|||
combineMapPrompt: mapPromptTemplate,
|
||||
verbose: process.env.NODE_ENV === "development",
|
||||
});
|
||||
|
||||
const res = await chain.call({
|
||||
...(controllerSignal ? { signal: controllerSignal } : {}),
|
||||
input_documents: docs,
|
||||
|
|
Loading…
Add table
Reference in a new issue