@roostjs/ai

Class-based AI agents with typed tools, conversation memory, and streaming. Powered exclusively by Cloudflare Workers AI — no API keys required.

Installation

bun add @roostjs/ai @roostjs/schema

Configuration

The AI package uses Cloudflare Workers AI exclusively. There are no API keys and no external AI providers. The AI binding must be declared in wrangler.jsonc:

{
  "ai": { "binding": "AI" }
}

Register the service provider to inject the provider into all agent classes:

import { AiServiceProvider } from '@roostjs/ai';
app.register(AiServiceProvider);

The default model is @cf/meta/llama-3.1-8b-instruct. Override per class using the @Model decorator.

Agent API

Agent is an abstract base class. Extend it and implement instructions(). One agent instance represents one conversation.

abstract instructions(): string

Returns the system prompt sent as the first message in every conversation turn.
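For example, a minimal subclass can look like this (SupportAgent and its prompt text are illustrative; the local Agent mirror stands in for the class exported by @roostjs/ai so the sketch is self-contained):

```typescript
// Local stand-in for the abstract base class exported by '@roostjs/ai'.
abstract class Agent {
  abstract instructions(): string;
}

// One instance represents one conversation.
class SupportAgent extends Agent {
  instructions(): string {
    return 'You are a concise support assistant for this application.';
  }
}

// instructions() supplies the system prompt for every conversation turn.
console.log(new SupportAgent().instructions());
```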

async prompt(input: string, options?: Partial<AgentConfig>): Promise<AgentResponse>

Send a user message. Appends the message to the conversation history, calls the Cloudflare AI model, executes any tool calls returned by the model (up to maxSteps), and returns the final response. Conversation history is retained on the instance for subsequent calls.

const response = await agent.prompt('Hello');
console.log(response.text);       // Model text response
console.log(response.toolCalls);  // Tool calls made (if any)
console.log(response.messages);   // Full message history

async stream(input: string): Promise<ReadableStream<Uint8Array>>

Send a user message and return a Server-Sent Events stream. Each event's data payload is a JSON object with a type field ('text-delta' or 'done').
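A consumer decodes the byte stream and parses each SSE event. This sketch assumes the payload carries a text field on 'text-delta' events (the type values come from the description above; the text key is an assumption):

```typescript
// Parse one SSE frame ("data: {...}\n\n") into its JSON payload.
// Assumed payload shape: { type: 'text-delta' | 'done'; text?: string }.
function parseSSEEvent(frame: string): { type: string; text?: string } | null {
  const line = frame.split('\n').find((l) => l.startsWith('data: '));
  return line ? JSON.parse(line.slice('data: '.length)) : null;
}

const event = parseSSEEvent('data: {"type":"text-delta","text":"Hel"}\n\n');
if (event?.type === 'text-delta') {
  console.log(event.text); // append each delta to the accumulated response
}
```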

static setProvider(provider: AIProvider): void

Set the AI provider for this agent class. Called automatically by AiServiceProvider during boot. Use in tests or custom setups.

static clearProvider(): void

Remove the provider set via setProvider().

static fake(responses?: string[]): void

Enable fake mode. All prompt() calls return responses from the responses array in order (repeating the last response once the array is exhausted) without calling the AI binding.
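The cycling behaviour can be pictured with a small sketch (makeFakeResponder is illustrative, not a package export):

```typescript
// Returns responses in order; the last one repeats once the array is exhausted.
function makeFakeResponder(responses: string[]) {
  let i = 0;
  return (): string => responses[Math.min(i++, responses.length - 1)];
}

const next = makeFakeResponder(['first', 'second']);
console.log(next()); // first
console.log(next()); // second
console.log(next()); // second (last response repeats)
```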

static restore(): void

Disable fake mode and restore normal provider behaviour.

static assertPrompted(textOrFn: string | ((prompt: string) => boolean)): void

Assert that at least one prompt() call matched. Accepts a substring or a predicate function. Throws if no match is found. Only valid after fake().

static assertNeverPrompted(): void

Assert that no prompt() calls were made. Only valid after fake().

HasTools Interface

Implement HasTools on an agent subclass to enable tool use. The agent will pass all registered tools to the model and execute their handlers when the model requests a tool call.

tools(): Tool[]

Return the list of tool instances available to this agent.

HasStructuredOutput Interface

Implement HasStructuredOutput on an agent subclass to define the shape of the model's JSON output.

schema(s: typeof schema): Record<string, SchemaBuilder>

Define the expected JSON output shape using the schema builder.

Tool Interface

Implement the Tool interface to create a callable tool. Tools are passed to the model as function definitions and invoked by the agent runtime when the model emits a tool call.

description(): string

One-sentence description of what the tool does. Sent to the model to aid tool selection.

schema(s: typeof schema): Record<string, SchemaBuilder>

Define the tool's input parameters. Each key in the returned object is a parameter name mapped to a schema builder.

async handle(request: ToolRequest): Promise<string>

Execute the tool. Must return a string. The string is fed back to the model as the tool result.

ToolRequest API

get<T>(key: string): T

Retrieve a typed parameter value from the tool call arguments.
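Putting the pieces together, a tool sketch might look like this. WeatherTool and its city parameter are illustrative, the local ToolRequest stub stands in for the runtime-provided object, and a real tool would also implement schema():

```typescript
// Local stub of the ToolRequest shape so the sketch runs standalone.
interface ToolRequest {
  get<T>(key: string): T;
}

class WeatherTool {
  description(): string {
    return 'Look up the current temperature for a city.';
  }

  // The returned string is fed back to the model as the tool result.
  async handle(request: ToolRequest): Promise<string> {
    const city = request.get<string>('city');
    // A real tool would call an external API here.
    return `It is 18 degrees in ${city}.`;
  }
}

// Toy ToolRequest backed by a plain arguments object:
const args: Record<string, unknown> = { city: 'Lisbon' };
const request: ToolRequest = { get: <T>(key: string) => args[key] as T };
new WeatherTool().handle(request).then(console.log);
```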

CloudflareAIProvider

The only built-in AIProvider implementation. Wraps AIClient from @roostjs/cloudflare, which in turn wraps the Cloudflare Workers Ai binding. Registered automatically by AiServiceProvider.

constructor(client: AIClient)

Construct with an AIClient wrapping the AI binding.

async chat(request: ProviderRequest): Promise<ProviderResponse>

Send a chat request to Cloudflare Workers AI and return the response.

agent() Factory

agent(options: { instructions: string; tools?: Tool[]; provider?: AIProvider }): { prompt: (input: string) => Promise<AgentResponse> }

Create an anonymous agent without defining a class. Returns an object with a single prompt method.

Decorators

Class decorators applied to Agent subclasses. All are optional.

@Model(model: string)

Set the Cloudflare Workers AI model identifier. Default is @cf/meta/llama-3.1-8b-instruct. The value must be a valid model available in your Cloudflare account.

@MaxSteps(maxSteps: number)

Maximum number of tool-call iterations per prompt() call. Defaults to 5. The model stops looping once it emits a response with no tool calls or when this limit is reached.

@Temperature(temperature: number)

Sampling temperature passed to the model. Range: 0 (deterministic) to 1.

@MaxTokens(maxTokens: number)

Maximum number of tokens in the model response.

@Provider(provider: string)

Named provider identifier. Stored in the agent config for custom provider lookup logic.

@Timeout(timeout: number)

Timeout in milliseconds for a single model call.
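All six decorators follow the same pattern: write one value into the agent class's stored config. A toy sketch in the style of @Model, applied as a plain function call so it runs without any compiler decorator settings (the real decorators come from @roostjs/ai):

```typescript
interface AgentConfig {
  model?: string;
  maxSteps?: number;
}

// Toy config-storing decorator in the style of @Model.
function Model(model: string) {
  return (target: { config: AgentConfig }): void => {
    target.config = { ...target.config, model };
  };
}

class ResearchAgent {
  static config: AgentConfig = {};
}

Model('@cf/meta/llama-3.1-8b-instruct')(ResearchAgent);
console.log(ResearchAgent.config.model); // @cf/meta/llama-3.1-8b-instruct
```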

Types

interface AgentConfig {
  provider?: string;
  model?: string;
  maxSteps?: number;
  maxTokens?: number;
  temperature?: number;
  timeout?: number;
  queued?: boolean;
}

interface AgentMessage {
  role: 'system' | 'user' | 'assistant' | 'tool';
  content: string;
  toolCallId?: string;
  toolName?: string;
}

interface AgentResponse {
  text: string;
  messages: AgentMessage[];
  toolCalls: ToolCall[];
  usage?: { promptTokens: number; completionTokens: number };
}

type PromptResult =
  | { queued: false; text: string; messages: AgentMessage[]; toolCalls: ToolCall[]; usage?: { promptTokens: number; completionTokens: number } }
  | { queued: true; taskId: string };

interface ToolCall {
  id: string;
  name: string;
  arguments: Record<string, unknown>;
}

interface ProviderRequest {
  model: string;
  messages: AgentMessage[];
  tools?: ProviderTool[];
  maxTokens?: number;
  temperature?: number;
  queueRequest?: boolean;
}

interface ProviderResponse {
  text: string;
  toolCalls: ToolCall[];
  usage?: { promptTokens: number; completionTokens: number };
  taskId?: string;
}

GatewayAIProvider

Routes inference requests through Cloudflare AI Gateway instead of the direct Workers AI binding. Adds observability, request caching, and automatic fallback.

The gateway URL format used is: https://gateway.ai.cloudflare.com/v1/{accountId}/{gatewayId}/workers-ai/{model}

constructor(config: GatewayConfig, fallback: CloudflareAIProvider)

Construct with an account/gateway pair and a fallback provider. The fallback is called automatically when the gateway returns a non-2xx status or is unreachable.

interface GatewayConfig {
  accountId: string;
  gatewayId: string;
}

async chat(request: ProviderRequest): Promise<ProviderResponse>

Send a chat request through the AI Gateway REST endpoint. If the request contains more than one non-system message, the x-session-affinity: true header is included automatically to enable prefix caching on the gateway. Falls back to the direct CloudflareAIProvider on any error.

Session Affinity and Prefix Caching

GatewayAIProvider automatically adds x-session-affinity: true to requests that contain conversation history (more than one non-system message). This header instructs the AI Gateway to route the request to the same backend replica, enabling KV-based prefix cache hits for repeated conversation prefixes. No manual configuration is required.
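The condition the provider checks can be expressed as a one-line predicate (needsSessionAffinity is illustrative, not an exported function):

```typescript
// True when the request carries conversation history, i.e. more than one
// non-system message; that is when the header is attached.
function needsSessionAffinity(messages: { role: string }[]): boolean {
  return messages.filter((m) => m.role !== 'system').length > 1;
}

console.log(needsSessionAffinity([{ role: 'system' }, { role: 'user' }])); // false: first turn
console.log(needsSessionAffinity([
  { role: 'system' },
  { role: 'user' },
  { role: 'assistant' },
  { role: 'user' },
])); // true: has history
```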

Usage

import { GatewayAIProvider } from '@roostjs/ai';
import { CloudflareAIProvider, AIClient } from '@roostjs/cloudflare';

const directProvider = new CloudflareAIProvider(new AIClient(env.AI));
const gatewayProvider = new GatewayAIProvider(
  { accountId: env.CF_ACCOUNT_ID, gatewayId: env.AI_GATEWAY_ID },
  directProvider,
);

MyAgent.setProvider(gatewayProvider);

Async Inference

Pass queued: true in the options to Agent.prompt() to dispatch the inference request asynchronously. The model call is enqueued on Cloudflare's side and the call returns immediately with a taskId.

async prompt(input: string, options?: Partial<AgentConfig>): Promise<PromptResult>

When options.queued is true, the return value has shape { queued: true; taskId: string }. Otherwise it has shape { queued: false; text: string; messages: AgentMessage[]; toolCalls: ToolCall[] }. Use the queued discriminant before accessing text or taskId.

const result = await agent.prompt('Summarize this document', { queued: true });

if (result.queued) {
  // Stash result.taskId and poll later
} else {
  console.log(result.text);
}

agent.stream() throws if called with queued: true.

AIClient.poll()

Check the status of a queued async inference task via the Cloudflare REST API. Located in @roostjs/cloudflare.

async poll<T = string>(taskId: string, fetcher: typeof fetch, accountId: string): Promise<{ status: 'running' } | { status: 'done'; result: T }>

Polls https://api.cloudflare.com/client/v4/accounts/{accountId}/ai/run/tasks/{taskId}. The fetcher must carry an Authorization: Bearer <CF_API_TOKEN> header. Returns { status: 'running' } while the task is in progress and { status: 'done'; result: T } when complete.

import { AIClient } from '@roostjs/cloudflare';

const client = new AIClient(env.AI);
const poll = await client.poll(taskId, fetch.bind(env), env.CF_ACCOUNT_ID);

if (poll.status === 'done') {
  console.log(poll.result);
}

RAG Pipeline

Import from @roostjs/ai/rag.

Chunkers

Abstract base class:

abstract class Chunker {
  abstract chunk(document: Document): Chunk[];
}

TextChunker

Splits a document into fixed-size token windows with configurable overlap.

new TextChunker(options?: { chunkSize?: number; overlapPercent?: number })
Option          Default  Description
chunkSize       400      Target token count per chunk (estimated as chars / 4).
overlapPercent  0.10     Fraction of the previous chunk's words to repeat at the start of the next chunk.
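The windowing logic can be sketched in a few lines, assuming the chars / 4 token estimate and word-level overlap described above (chunkText is illustrative; the real TextChunker returns Chunk objects, not strings):

```typescript
// Split text into windows of roughly `chunkSize` tokens (estimated as
// chars / 4), repeating `overlapPercent` of each chunk's words at the
// start of the next one.
function chunkText(text: string, chunkSize = 400, overlapPercent = 0.1): string[] {
  const words = text.split(/\s+/).filter(Boolean);
  const chunks: string[] = [];
  let current: string[] = [];
  let chars = 0;
  for (const word of words) {
    current.push(word);
    chars += word.length + 1;
    if (chars / 4 >= chunkSize) {
      chunks.push(current.join(' '));
      const overlap = Math.floor(current.length * overlapPercent);
      current = current.slice(current.length - overlap);
      chars = current.join(' ').length;
    }
  }
  if (current.length > 0) chunks.push(current.join(' '));
  return chunks;
}
```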

SemanticChunker

Splits on Markdown headings and double newlines first, merges segments smaller than 10% of chunkSize, then falls back to TextChunker for segments that still exceed chunkSize.

new SemanticChunker(options?: { chunkSize?: number; overlapPercent?: number })

Accepts the same options as TextChunker.

EmbeddingPipeline

Wraps the AIClient to produce embedding vectors from text.

constructor(client: AIClient, model?: string)

Default model: @cf/baai/bge-base-en-v1.5 (768 dimensions). The embedding model's output dimensionality must match that of the Vectorize index.

async embed(texts: string[]): Promise<number[][]>

Returns one embedding vector per input string. Throws EmbeddingError if the binding returns no data or a mismatched count.
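The error conditions can be sketched as a validation step (validateEmbeddings is illustrative; the real pipeline throws EmbeddingError):

```typescript
// One vector per input text, each matching the index dimensionality
// (768 for @cf/baai/bge-base-en-v1.5).
function validateEmbeddings(texts: string[], vectors: number[][], dims = 768): void {
  if (vectors.length !== texts.length) {
    throw new Error(`expected ${texts.length} embeddings, got ${vectors.length}`);
  }
  for (const v of vectors) {
    if (v.length !== dims) {
      throw new Error(`expected ${dims}-dimensional vectors, got ${v.length}`);
    }
  }
}

validateEmbeddings(['hello'], [new Array(768).fill(0)]); // passes silently
```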

RAGPipeline

Orchestrates chunking, embedding, and vector storage/retrieval.

constructor(store: VectorStore, embeddings: EmbeddingPipeline, chunker: Chunker, config?: RAGPipelineConfig)

async ingest(documents: Document[]): Promise<{ inserted: number }>

Chunk all documents, embed each chunk, and insert the resulting vectors into the VectorStore. Returns the number of vectors inserted.

async query(text: string): Promise<QueryResult[]>

Embed the query text, search the vector store for the top-K nearest chunks, filter by similarityThreshold, and return results sorted by descending score.
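The retrieval step amounts to a similarity search. This in-memory sketch uses cosine similarity with the default topK of 5 and similarityThreshold of 0.75 (the real store is a Vectorize index, not an array):

```typescript
function cosine(a: number[], b: number[]): number {
  let dot = 0;
  let na = 0;
  let nb = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    na += a[i] * a[i];
    nb += b[i] * b[i];
  }
  return dot / (Math.sqrt(na) * Math.sqrt(nb));
}

// Score every stored vector, drop those below the threshold, and return
// the top K by descending score.
function topK(
  query: number[],
  vectors: { id: string; embedding: number[] }[],
  k = 5,
  threshold = 0.75,
): { id: string; score: number }[] {
  return vectors
    .map((v) => ({ id: v.id, score: cosine(query, v.embedding) }))
    .filter((r) => r.score >= threshold)
    .sort((a, b) => b.score - a.score)
    .slice(0, k);
}

const hits = topK([1, 0], [
  { id: 'a', embedding: [1, 0] },
  { id: 'b', embedding: [0.8, 0.6] },
  { id: 'c', embedding: [0, 1] },
]);
console.log(hits.map((h) => h.id)); // 'a' then 'b'; 'c' falls below the threshold
```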

static fake(responses?: QueryResult[][]): void

Enable fake mode. query() returns result sets from the responses array in order.

static restore(): void

Disable fake mode.

static assertIngested(predicate?: (docs: Document[]) => boolean): void

Assert that ingest() was called at least once. Provide a predicate to assert on the specific documents passed.

static assertQueried(predicate?: (text: string) => boolean): void

Assert that query() was called at least once. Provide a predicate to match against the query text.

RAG Types

interface Document {
  id: string;
  text: string;
  metadata?: Record<string, unknown>;
}

interface Chunk {
  id: string;           // `${document.id}:${chunkIndex}`
  documentId: string;
  text: string;
  tokenCount: number;
  metadata?: Record<string, unknown>;
}

interface ChunkVector {
  chunk: Chunk;
  embedding: number[];
}

interface QueryResult {
  chunk: Chunk;
  score: number;
}

interface RAGPipelineConfig {
  /** Default: 400 */
  chunkSize?: number;
  /** Default: 0.10 */
  overlapPercent?: number;
  /** Default: '@cf/baai/bge-base-en-v1.5' */
  embeddingModel?: string;
  /** Default: 5 */
  topK?: number;
  /** Default: 0.75 */
  similarityThreshold?: number;
  /** Vectorize namespace for multi-tenancy */
  namespace?: string;
}

class EmbeddingError extends Error {}