@roostjs/ai
Class-based AI agents with typed tools, conversation memory, and streaming. Powered exclusively by Cloudflare Workers AI — no API keys required.
Installation
bun add @roostjs/ai @roostjs/schema
Configuration
The AI package uses Cloudflare Workers AI exclusively. There are no
API keys and no external AI providers. The AI binding must be declared
in wrangler.jsonc:
{
"ai": { "binding": "AI" }
}
Register the service provider to inject the provider into all agent classes:
import { AiServiceProvider } from '@roostjs/ai';
app.register(AiServiceProvider);
The default model is @cf/meta/llama-3.1-8b-instruct. Override per class
using the @Model decorator.
Agent API
Agent is an abstract base class. Extend it and implement
instructions(). One agent instance represents one conversation.
abstract instructions(): string
Returns the system prompt sent as the first message in every conversation turn.
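The subclassing pattern can be sketched as follows; the stand-in base class below mirrors the abstract class described above (in a real app you would import Agent from @roostjs/ai instead):

```typescript
// Local stand-in mirroring the abstract base class described above;
// a real agent would `import { Agent } from '@roostjs/ai'`.
abstract class Agent {
  abstract instructions(): string;
}

// One instance of this class represents one conversation.
class SupportAgent extends Agent {
  instructions(): string {
    return 'You are a concise support assistant for the Roost framework.';
  }
}

const agent = new SupportAgent();
```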
async prompt(input: string, options?: Partial<AgentConfig>): Promise<AgentResponse>
Send a user message. Appends the message to the conversation history, calls the
Cloudflare AI model, executes any tool calls returned by the model (up to
maxSteps), and returns the final response. Conversation history is
retained on the instance for subsequent calls.
const response = await agent.prompt('Hello');
console.log(response.text); // Model text response
console.log(response.toolCalls); // Tool calls made (if any)
console.log(response.messages); // Full message history
async stream(input: string): Promise<ReadableStream<Uint8Array>>
Send a user message and return a Server-Sent Events stream. Each chunk is a JSON
object with a type field ('text-delta' or 'done').
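Consuming the stream can be sketched as below. The standard `data: <json>` SSE framing and the `text` field on text-delta chunks are assumptions; the doc only specifies the type field.

```typescript
// Sketch: accumulate text deltas from the SSE payload described above.
// Assumes standard `data: <json>` SSE framing and a `text` field on
// text-delta chunks; the exact wire shape may differ.
type StreamChunk =
  | { type: 'text-delta'; text: string }
  | { type: 'done' };

function collectText(sse: string): string {
  let out = '';
  for (const line of sse.split('\n')) {
    if (!line.startsWith('data: ')) continue;
    const chunk = JSON.parse(line.slice(6)) as StreamChunk;
    if (chunk.type === 'text-delta') out += chunk.text;
  }
  return out;
}
```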
static setProvider(provider: AIProvider): void
Set the AI provider for this agent class. Called automatically by
AiServiceProvider during boot. Use in tests or custom setups.
static clearProvider(): void
Remove the provider set via setProvider().
static fake(responses?: string[]): void
Enable fake mode. prompt() calls return entries from the responses
array in order, repeating the last entry once the array is exhausted, without calling the AI binding.
static restore(): void
Disable fake mode and restore normal provider behaviour.
static assertPrompted(textOrFn: string | ((prompt: string) => boolean)): void
Assert that at least one prompt() call matched. Accepts a substring
or a predicate function. Throws if no match is found. Only valid after
fake().
static assertNeverPrompted(): void
Assert that no prompt() calls were made. Only valid after fake().
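The response-dispensing rule of fake mode can be sketched as a small closure; this is an illustration of the behaviour described above, not the package's internals:

```typescript
// Sketch of the fake-mode rule: responses are handed out in order,
// and the last one repeats once the array is exhausted.
function makeFakeResponder(responses: string[]): () => string {
  let i = 0;
  return () => {
    const r = responses[Math.min(i, responses.length - 1)];
    i++;
    return r;
  };
}
```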
HasTools Interface
Implement HasTools on an agent subclass to enable tool use. The agent
will pass all registered tools to the model and execute their handlers when the
model requests a tool call.
tools(): Tool[]
Return the list of tool instances available to this agent.
HasStructuredOutput Interface
Implement HasStructuredOutput on an agent subclass to constrain the model's response to a JSON shape.
schema(s: typeof schema): Record<string, SchemaBuilder>
Define the expected JSON output shape using the schema builder.
Tool Interface
Implement the Tool interface to create a callable tool. Tools are
passed to the model as function definitions and invoked by the agent runtime
when the model emits a tool call.
description(): string
One-sentence description of what the tool does. Sent to the model to aid tool selection.
schema(s: typeof schema): Record<string, SchemaBuilder>
Define the tool's input parameters. Each key in the returned object is a parameter name mapped to a schema builder.
async handle(request: ToolRequest): Promise<string>
Execute the tool. Must return a string. The string is fed back to the model as the tool result.
ToolRequest API
get<T>(key: string): T
Retrieve a typed parameter value from the tool call arguments.
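A tool implementation can be sketched as below, with local stand-ins for the interfaces described above; a real tool would import them from @roostjs/ai and also define schema(s) for its parameters.

```typescript
// Local stand-in mirroring the ToolRequest interface above.
interface ToolRequest {
  get<T>(key: string): T;
}

// Sketch of a tool. description() is sent to the model to aid tool
// selection; handle() must return a string, which is fed back to the
// model as the tool result.
class WeatherTool {
  description(): string {
    return 'Look up the current temperature for a city.';
  }

  async handle(request: ToolRequest): Promise<string> {
    const city = request.get<string>('city');
    // A real tool would call an external API here.
    return `It is 18°C in ${city}.`;
  }
}
```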
CloudflareAIProvider
The only built-in AIProvider implementation. Wraps AIClient
from @roostjs/cloudflare, which in turn wraps the Cloudflare Workers
Ai binding. Registered automatically by AiServiceProvider.
constructor(client: AIClient)
Construct with an AIClient wrapping the AI binding.
async chat(request: ProviderRequest): Promise<ProviderResponse>
Send a chat request to Cloudflare Workers AI and return the response.
agent() Factory
agent(options: { instructions: string; tools?: Tool[]; provider?: AIProvider }): { prompt: (input: string) => Promise<AgentResponse> }
Create an anonymous agent without defining a class. Returns an object with a single
prompt method.
Decorators
Class decorators applied to Agent subclasses. All are optional.
@Model(model: string)
Set the Cloudflare Workers AI model identifier. Default is
@cf/meta/llama-3.1-8b-instruct. The value must be a valid model
available in your Cloudflare account.
@MaxSteps(maxSteps: number)
Maximum number of tool-call iterations per prompt() call. Defaults to
5. The model stops looping once it emits a response with no tool calls
or when this limit is reached.
@Temperature(temperature: number)
Sampling temperature passed to the model. Range: 0 (deterministic) to 1.
@MaxTokens(maxTokens: number)
Maximum number of tokens in the model response.
@Provider(provider: string)
Named provider identifier. Stored in the agent config for custom provider lookup logic.
@Timeout(timeout: number)
Timeout in milliseconds for a single model call.
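Conceptually, each decorator merges one setting into the agent's config (the field names follow the AgentConfig interface in the Types section). A sketch of the effective merge, not the package's internal mechanism:

```typescript
// Sketch: each class decorator overrides one field of the default
// config. Field names follow the AgentConfig interface.
interface AgentConfig {
  model?: string;
  maxSteps?: number;
  maxTokens?: number;
  temperature?: number;
  timeout?: number;
}

const defaults: AgentConfig = {
  model: '@cf/meta/llama-3.1-8b-instruct',
  maxSteps: 5,
};

// Applying @Model('@cf/meta/llama-3.1-70b-instruct') and @MaxSteps(8)
// to a class is roughly equivalent to:
const config: AgentConfig = {
  ...defaults,
  model: '@cf/meta/llama-3.1-70b-instruct',
  maxSteps: 8,
};
```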
Types
interface AgentConfig {
provider?: string;
model?: string;
maxSteps?: number;
maxTokens?: number;
temperature?: number;
timeout?: number;
queued?: boolean;
}
interface AgentMessage {
role: 'system' | 'user' | 'assistant' | 'tool';
content: string;
toolCallId?: string;
toolName?: string;
}
interface AgentResponse {
text: string;
messages: AgentMessage[];
toolCalls: ToolCall[];
usage?: { promptTokens: number; completionTokens: number };
}
type PromptResult =
| { queued: false; text: string; messages: AgentMessage[]; toolCalls: ToolCall[]; usage?: { promptTokens: number; completionTokens: number } }
| { queued: true; taskId: string };
interface ToolCall {
id: string;
name: string;
arguments: Record<string, unknown>;
}
interface ProviderRequest {
model: string;
messages: AgentMessage[];
tools?: ProviderTool[];
maxTokens?: number;
temperature?: number;
queueRequest?: boolean;
}
interface ProviderResponse {
text: string;
toolCalls: ToolCall[];
usage?: { promptTokens: number; completionTokens: number };
taskId?: string;
}
GatewayAIProvider
Routes inference requests through Cloudflare AI Gateway instead of the direct Workers AI binding. Adds observability, request caching, and automatic fallback.
The gateway URL format used is:
https://gateway.ai.cloudflare.com/v1/{accountId}/{gatewayId}/workers-ai/{model}
constructor(config: GatewayConfig, fallback: CloudflareAIProvider)
Construct with an account/gateway pair and a fallback provider. The fallback is called automatically when the gateway returns a non-2xx status or is unreachable.
interface GatewayConfig {
accountId: string;
gatewayId: string;
}
async chat(request: ProviderRequest): Promise<ProviderResponse>
Send a chat request through the AI Gateway REST endpoint. If the request contains more than one non-system message, the x-session-affinity: true header is included automatically to enable prefix caching on the gateway. Falls back to the direct CloudflareAIProvider on any error.
Session Affinity and Prefix Caching
GatewayAIProvider automatically adds x-session-affinity: true to requests that contain conversation history (more than one non-system message). This header instructs the AI Gateway to route the request to the same backend replica, enabling KV-based prefix cache hits for repeated conversation prefixes. No manual configuration is required.
Usage
import { GatewayAIProvider } from '@roostjs/ai';
import { CloudflareAIProvider, AIClient } from '@roostjs/cloudflare';
const directProvider = new CloudflareAIProvider(new AIClient(env.AI));
const gatewayProvider = new GatewayAIProvider(
{ accountId: env.CF_ACCOUNT_ID, gatewayId: env.AI_GATEWAY_ID },
directProvider,
);
MyAgent.setProvider(gatewayProvider);
Async Inference
Pass queued: true in the options to Agent.prompt() to dispatch the inference request asynchronously. The model call is enqueued on Cloudflare's side and the call returns immediately with a taskId.
async prompt(input: string, options?: Partial<AgentConfig>): Promise<PromptResult>
When options.queued is true, the return value has shape { queued: true; taskId: string }. Otherwise it has shape { queued: false; text: string; messages: AgentMessage[]; toolCalls: ToolCall[] }. Use the queued discriminant before accessing text or taskId.
const result = await agent.prompt('Summarize this document', { queued: true });
if (result.queued) {
// Stash result.taskId and poll later
} else {
console.log(result.text);
}
agent.stream() throws if called with queued: true.
AIClient.poll()
Check the status of a queued async inference task via the Cloudflare REST API. Located in @roostjs/cloudflare.
async poll<T = string>(taskId: string, fetcher: typeof fetch, accountId: string): Promise<{ status: 'running' } | { status: 'done'; result: T }>
Polls https://api.cloudflare.com/client/v4/accounts/{accountId}/ai/run/tasks/{taskId}. The fetcher must carry an Authorization: Bearer <CF_API_TOKEN> header. Returns { status: 'running' } while the task is in progress and { status: 'done'; result: T } when complete.
import { AIClient } from '@roostjs/cloudflare';
const client = new AIClient(env.AI);
const poll = await client.poll(taskId, fetch.bind(env), env.CF_ACCOUNT_ID);
if (poll.status === 'done') {
console.log(poll.result);
}
RAG Pipeline
Import from @roostjs/ai/rag.
Chunkers
Abstract base class:
abstract class Chunker {
abstract chunk(document: Document): Chunk[];
}
TextChunker
Splits a document into fixed-size token windows with configurable overlap.
new TextChunker(options?: { chunkSize?: number; overlapPercent?: number })
| Option | Default | Description |
|---|---|---|
| chunkSize | 400 | Target token count per chunk (estimated as chars / 4). |
| overlapPercent | 0.10 | Fraction of the previous chunk's words to repeat at the start of the next chunk. |
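The arithmetic behind the two options can be sketched as below; this illustrates the chars / 4 token estimate and word-level overlap described above, not the package's exact splitting code:

```typescript
// Sketch of the chunking arithmetic: tokens are estimated as chars / 4,
// and each new chunk begins by repeating a fraction of the previous
// chunk's words.
function estimateTokens(text: string): number {
  return Math.ceil(text.length / 4);
}

function chunkWords(words: string[], chunkSize = 400, overlapPercent = 0.1): string[][] {
  const chunks: string[][] = [];
  let current: string[] = [];
  for (const word of words) {
    current.push(word);
    if (estimateTokens(current.join(' ')) >= chunkSize) {
      chunks.push(current);
      // Carry a fraction of the finished chunk's words into the next one.
      const overlap = Math.floor(current.length * overlapPercent);
      current = overlap > 0 ? current.slice(-overlap) : [];
    }
  }
  if (current.length > 0) chunks.push(current);
  return chunks;
}
```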
SemanticChunker
Splits on Markdown headings and double newlines first, merges segments below 10% of chunkSize, then falls back to TextChunker for segments that still exceed chunkSize.
new SemanticChunker(options?: { chunkSize?: number; overlapPercent?: number })
Accepts the same options as TextChunker.
EmbeddingPipeline
Wraps the AIClient to produce embedding vectors from text.
constructor(client: AIClient, model?: string)
Default model: @cf/baai/bge-base-en-v1.5 (768 dimensions). The model must match the dimensionality of the Vectorize index.
async embed(texts: string[]): Promise<number[][]>
Returns one embedding vector per input string. Throws EmbeddingError if the binding returns no data or a mismatched count.
RAGPipeline
Orchestrates chunking, embedding, and vector storage/retrieval.
constructor(store: VectorStore, embeddings: EmbeddingPipeline, chunker: Chunker, config?: RAGPipelineConfig)
async ingest(documents: Document[]): Promise<{ inserted: number }>
Chunk all documents, embed each chunk, and insert the resulting vectors into the VectorStore. Returns the number of vectors inserted.
async query(text: string): Promise<QueryResult[]>
Embed the query text, search the vector store for the top-K nearest chunks, filter by similarityThreshold, and return results sorted by descending score.
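The post-processing steps (score, filter by similarityThreshold, sort descending, take topK) can be sketched as below. Cosine similarity is an assumption for illustration; the actual score comes from whatever metric the Vectorize index uses.

```typescript
// Sketch of the query post-processing described above. The similarity
// metric (cosine) is an assumption for illustration only.
interface Ranked<T> { chunk: T; score: number }

function cosine(a: number[], b: number[]): number {
  let dot = 0, na = 0, nb = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    na += a[i] * a[i];
    nb += b[i] * b[i];
  }
  return dot / (Math.sqrt(na) * Math.sqrt(nb));
}

function rank<T>(
  query: number[],
  candidates: { chunk: T; embedding: number[] }[],
  topK = 5,
  similarityThreshold = 0.75,
): Ranked<T>[] {
  return candidates
    .map((c) => ({ chunk: c.chunk, score: cosine(query, c.embedding) }))
    .filter((r) => r.score >= similarityThreshold)  // drop weak matches
    .sort((a, b) => b.score - a.score)              // best first
    .slice(0, topK);
}
```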
static fake(responses?: QueryResult[][]): void
Enable fake mode. query() returns from the responses array in order.
static restore(): void
Disable fake mode.
static assertIngested(predicate?: (docs: Document[]) => boolean): void
Assert that ingest() was called at least once. Provide a predicate to assert on the specific documents passed.
static assertQueried(predicate?: (text: string) => boolean): void
Assert that query() was called at least once. Provide a predicate to match against the query text.
RAG Types
interface Document {
id: string;
text: string;
metadata?: Record<string, unknown>;
}
interface Chunk {
id: string; // `${document.id}:${chunkIndex}`
documentId: string;
text: string;
tokenCount: number;
metadata?: Record<string, unknown>;
}
interface ChunkVector {
chunk: Chunk;
embedding: number[];
}
interface QueryResult {
chunk: Chunk;
score: number;
}
interface RAGPipelineConfig {
/** Default: 400 */
chunkSize?: number;
/** Default: 0.10 */
overlapPercent?: number;
/** Default: '@cf/baai/bge-base-en-v1.5' */
embeddingModel?: string;
/** Default: 5 */
topK?: number;
/** Default: 0.75 */
similarityThreshold?: number;
/** Vectorize namespace for multi-tenancy */
namespace?: string;
}
class EmbeddingError extends Error {}