model specific summarization ()

* model specific summarization

* update guard functions

* patch model picker and key inputs
This commit is contained in:
Timothy Carambat 2024-04-17 14:04:51 -07:00 committed by GitHub
parent 9449fcd737
commit 81fd82e133
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 80 additions and 10 deletions
frontend/src/pages/WorkspaceSettings/AgentConfig/AgentModelSelection
server/utils/agents/aibitat

View file

@ -97,7 +97,7 @@ export default function AgentModelSelection({
<option <option
key={model.id} key={model.id}
value={model.id} value={model.id}
selected={workspace?.chatModel === model.id} selected={workspace?.agentModel === model.id}
> >
{model.name} {model.name}
</option> </option>

View file

@ -12,6 +12,7 @@ const { Telemetry } = require("../../../models/telemetry.js");
class AIbitat { class AIbitat {
emitter = new EventEmitter(); emitter = new EventEmitter();
provider = null;
defaultProvider = null; defaultProvider = null;
defaultInterrupt; defaultInterrupt;
maxRounds; maxRounds;
@ -39,6 +40,7 @@ class AIbitat {
provider, provider,
...rest, ...rest,
}; };
this.provider = this.defaultProvider.provider;
} }
/** /**

View file

@ -2,6 +2,7 @@ const { Document } = require("../../../../models/documents");
const { safeJsonParse } = require("../../../http"); const { safeJsonParse } = require("../../../http");
const { validate } = require("uuid"); const { validate } = require("uuid");
const { summarizeContent } = require("../utils/summarize"); const { summarizeContent } = require("../utils/summarize");
const Provider = require("../providers/ai-provider");
const docSummarizer = { const docSummarizer = {
name: "document-summarizer", name: "document-summarizer",
@ -95,7 +96,19 @@ const docSummarizer = {
document?.title ?? "a discovered file." document?.title ?? "a discovered file."
}` }`
); );
if (document?.content?.length < 8000) return content;
if (!document.content || document.content.length === 0) {
throw new Error(
"This document has no readable content that could be found."
);
}
if (
document.content?.length <
Provider.contextLimit(this.super.provider)
) {
return document.content;
}
this.super.introspect( this.super.introspect(
`${this.caller}: Summarizing ${document?.title ?? ""}...` `${this.caller}: Summarizing ${document?.title ?? ""}...`
@ -109,6 +122,7 @@ const docSummarizer = {
}); });
return await summarizeContent( return await summarizeContent(
this.super.provider,
this.controller.signal, this.controller.signal,
document.content document.content
); );

View file

@ -1,4 +1,5 @@
const { CollectorApi } = require("../../../collectorApi"); const { CollectorApi } = require("../../../collectorApi");
const Provider = require("../providers/ai-provider");
const { summarizeContent } = require("../utils/summarize"); const { summarizeContent } = require("../utils/summarize");
const webScraping = { const webScraping = {
@ -61,7 +62,11 @@ const webScraping = {
); );
} }
if (content?.length <= 8000) { if (!content || content?.length === 0) {
throw new Error("There was no content to be collected or read.");
}
if (content.length < Provider.contextLimit(this.super.provider)) {
return content; return content;
} }
@ -74,7 +79,11 @@ const webScraping = {
); );
this.controller.abort(); this.controller.abort();
}); });
return summarizeContent(this.controller.signal, content); return summarizeContent(
this.super.provider,
this.controller.signal,
content
);
}, },
}); });
}, },

View file

@ -2,6 +2,9 @@
* A service that provides an AI client to create a completion. * A service that provides an AI client to create a completion.
*/ */
const { ChatOpenAI } = require("langchain/chat_models/openai");
const { ChatAnthropic } = require("langchain/chat_models/anthropic");
class Provider { class Provider {
_client; _client;
constructor(client) { constructor(client) {
@ -14,6 +17,37 @@ class Provider {
get client() { get client() {
return this._client; return this._client;
} }
static LangChainChatModel(provider = "openai", config = {}) {
switch (provider) {
case "openai":
return new ChatOpenAI({
openAIApiKey: process.env.OPEN_AI_KEY,
...config,
});
case "anthropic":
return new ChatAnthropic({
anthropicApiKey: process.env.ANTHROPIC_API_KEY,
...config,
});
default:
return new ChatOpenAI({
openAIApiKey: process.env.OPEN_AI_KEY,
...config,
});
}
}
static contextLimit(provider = "openai") {
switch (provider) {
case "openai":
return 8_000;
case "anthropic":
return 100_000;
default:
return 8_000;
}
}
} }
module.exports = Provider; module.exports = Provider;

View file

@ -186,7 +186,8 @@ class AnthropicProvider extends Provider {
const completion = response.content.find((msg) => msg.type === "text"); const completion = response.content.find((msg) => msg.type === "text");
return { return {
result: result:
completion?.text ?? "I could not generate a response from this.", completion?.text ??
"The model failed to complete the task and return back a valid response.",
cost: 0, cost: 0,
}; };
} catch (error) { } catch (error) {

View file

@ -1,7 +1,7 @@
const { loadSummarizationChain } = require("langchain/chains"); const { loadSummarizationChain } = require("langchain/chains");
const { ChatOpenAI } = require("langchain/chat_models/openai");
const { PromptTemplate } = require("langchain/prompts"); const { PromptTemplate } = require("langchain/prompts");
const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter"); const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter");
const Provider = require("../providers/ai-provider");
/** /**
* Summarize content using OpenAI's GPT-3.5 model. * Summarize content using OpenAI's GPT-3.5 model.
* *
@ -9,11 +9,20 @@ const { RecursiveCharacterTextSplitter } = require("langchain/text_splitter");
* @param content The content to summarize. * @param content The content to summarize.
* @returns The summarized content. * @returns The summarized content.
*/ */
async function summarizeContent(controllerSignal, content) {
const llm = new ChatOpenAI({ const SUMMARY_MODEL = {
openAIApiKey: process.env.OPEN_AI_KEY, anthropic: "claude-3-opus-20240229", // 200,000 tokens
openai: "gpt-3.5-turbo-1106", // 16,385 tokens
};
async function summarizeContent(
provider = "openai",
controllerSignal,
content
) {
const llm = Provider.LangChainChatModel(provider, {
temperature: 0, temperature: 0,
modelName: "gpt-3.5-turbo-16k-0613", modelName: SUMMARY_MODEL[provider],
}); });
const textSplitter = new RecursiveCharacterTextSplitter({ const textSplitter = new RecursiveCharacterTextSplitter({
@ -41,6 +50,7 @@ async function summarizeContent(controllerSignal, content) {
combineMapPrompt: mapPromptTemplate, combineMapPrompt: mapPromptTemplate,
verbose: process.env.NODE_ENV === "development", verbose: process.env.NODE_ENV === "development",
}); });
const res = await chain.call({ const res = await chain.call({
...(controllerSignal ? { signal: controllerSignal } : {}), ...(controllerSignal ? { signal: controllerSignal } : {}),
input_documents: docs, input_documents: docs,