model specific summarization ()

* model specific summarization

* update guard functions

* patch model picker and key inputs
This commit is contained in:
Timothy Carambat 2024-04-17 14:04:51 -07:00 committed by GitHub
parent 9449fcd737
commit 81fd82e133
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 80 additions and 10 deletions
frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection
server/utils/agents/aibitat

View file

@ -97,7 +97,7 @@ export default function AgentModelSelection({
<option
key={model.id}
value={model.id}
selected={workspace?.chatModel === model.id}
selected={workspace?.agentModel === model.id}
>
{model.name}
</option>

View file

@ -12,6 +12,7 @@ const { Telemetry } = require("../../../models/telemetry.js");
class AIbitat {
emitter = new EventEmitter();
provider = null;
defaultProvider = null;
defaultInterrupt;
maxRounds;
@ -39,6 +40,7 @@ class AIbitat {
provider,
...rest,
};
this.provider = this.defaultProvider.provider;
}
/**

View file

@ -2,6 +2,7 @@ const { Document } = require("../../../../models/documents");
const { safeJsonParse } = require("../../../http");
const { validate } = require("uuid");
const { summarizeContent } = require("../utils/summarize");
const Provider = require("../providers/ai-provider");
const docSummarizer = {
name: "document-summarizer",
@ -95,7 +96,19 @@ const docSummarizer = {
document?.title ?? "a discovered file."
}`
);
if (document?.content?.length < 8000) return content;
if (!document.content || document.content.length === 0) {
throw new Error(
"This document has no readable content that could be found."
);
}
if (
document.content?.length <
Provider.contextLimit(this.super.provider)
) {
return document.content;
}
this.super.introspect(
`${this.caller}: Summarizing ${document?.title ?? ""}...`
@ -109,6 +122,7 @@ const docSummarizer = {
});
return await summarizeContent(
this.super.provider,
this.controller.signal,
document.content
);

View file

@ -1,4 +1,5 @@
const { CollectorApi } = require("../../../collectorApi");
const Provider = require("../providers/ai-provider");
const { summarizeContent } = require("../utils/summarize");
const webScraping = {
@ -61,7 +62,11 @@ const webScraping = {
);
}
if (content?.length <= 8000) {
if (!content || content?.length === 0) {
throw new Error("There was no content to be collected or read.");
}
if (content.length < Provider.contextLimit(this.super.provider)) {
return content;
}
@ -74,7 +79,11 @@ const webScraping = {
);
this.controller.abort();
});
return summarizeContent(this.controller.signal, content);
return summarizeContent(
this.super.provider,
this.controller.signal,
content
);
},
});
},

View file

@ -2,6 +2,9 @@
* A service that provides an AI client to create a completion.
*/
const { ChatOpenAI } = require("langchain/chat_models/openai");
const { ChatAnthropic } = require("langchain/chat_models/anthropic");
class Provider {
_client;
constructor(client) {
@ -14,6 +17,37 @@ class Provider {
get client() {
return this._client;
}
static LangChainChatModel(provider = "openai", config = {}) {
switch (provider) {
case "openai":
return new ChatOpenAI({
openAIApiKey: process.env.OPEN_AI_KEY,
...config,
});
case "anthropic":
return new ChatAnthropic({
anthropicApiKey: process.env.ANTHROPIC_API_KEY,
...config,
});
default:
return new ChatOpenAI({
openAIApiKey: process.env.OPEN_AI_KEY,
...config,
});
}
}
static contextLimit(provider = "openai") {
switch (provider) {
case "openai":
return 8_000;
case "anthropic":
return 100_000;
default:
return 8_000;
}
}
}
module.exports = Provider;

View file

@ -186,7 +186,8 @@ class AnthropicProvider extends Provider {
const completion = response.content.find((msg) => msg.type === "text");
return {
result:
completion?.text ?? "I could not generate a response from this.",
completion?.text ??
"The model failed to complete the task and return back a valid response.",
cost: 0,
};
} catch (error) {

View file

@ -1,7 +1,7 @@
const { loadSummarizationChain } = require("langchain/chains");
const { ChatOpenAI } = require("langchain/chat_models/openai");
const { PromptTemplate } = require("langchain/prompts");
const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter");
const Provider = require("../providers/ai-provider");
/**
 * Summarize content using the configured provider's summarization model.
*
@ -9,11 +9,20 @@ const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter");
* @param content The content to summarize.
* @returns The summarized content.
*/
async function summarizeContent(controllerSignal, content) {
const llm = new ChatOpenAI({
openAIApiKey: process.env.OPEN_AI_KEY,
const SUMMARY_MODEL = {
anthropic: "claude-3-opus-20240229", // 200,000 tokens
openai: "gpt-3.5-turbo-1106", // 16,385 tokens
};
async function summarizeContent(
provider = "openai",
controllerSignal,
content
) {
const llm = Provider.LangChainChatModel(provider, {
temperature: 0,
modelName: "gpt-3.5-turbo-16k-0613",
modelName: SUMMARY_MODEL[provider],
});
const textSplitter = new RecursiveCharacterTextSplitter({
@ -41,6 +50,7 @@ async function summarizeContent(controllerSignal, content) {
combineMapPrompt: mapPromptTemplate,
verbose: process.env.NODE_ENV === "development",
});
const res = await chain.call({
...(controllerSignal ? { signal: controllerSignal } : {}),
input_documents: docs,