`;
// Auto-scroll to bottom if enabled
if (auto_scroll) {
const content = container.querySelector('.log-viewer-content');
if (content) {
content.scrollTop = content.scrollHeight;
}
}
return container;
}
}
// Badge component renderer
export class BadgeComponentRenderer extends BaseComponentRenderer {
  /**
   * Render an inline badge (`<span>`) holding an optional icon followed by text.
   *
   * Styling hooks are the `badge-<variant>` and `badge-<size>` classes taken from
   * `component.data`. Icon and text are inserted through innerHTML, i.e. as markup.
   */
  render(component: RichComponent): HTMLElement {
    const badge = document.createElement('span');
    const { variant, size } = component.data;
    badge.className = ['rich-component', 'rich-badge', `badge-${variant}`, `badge-${size}`].join(' ');
    badge.dataset.componentId = component.id;
    const { text, icon } = component.data;
    // One line for the (possibly empty) icon and one for the text, framed by newlines.
    const iconPart = icon ? `${icon}` : '';
    badge.innerHTML = ['', iconPart, `${text}`, ''].join('\n');
    return badge;
  }
}
// Icon text component renderer
export class IconTextComponentRenderer extends BaseComponentRenderer {
  /**
   * Render an icon immediately followed by text inside a `<div>`.
   *
   * Modifier classes `icon-text-<variant>`, `icon-text-<size>` and `icon-text-<alignment>`
   * come from `component.data`; icon and text are inserted through innerHTML.
   */
  render(component: RichComponent): HTMLElement {
    const { data } = component;
    const modifierClasses = [data.variant, data.size, data.alignment].map((value) => `icon-text-${value}`);
    const element = document.createElement('div');
    element.className = ['rich-component', 'rich-icon-text', ...modifierClasses].join(' ');
    element.dataset.componentId = component.id;
    const inline = `${data.icon}${data.text}`;
    element.innerHTML = `\n${inline}\n`;
    return element;
  }
}
// Button component renderer
export class ButtonComponentRenderer extends BaseComponentRenderer {
  /**
   * Render a single action button.
   *
   * When `data.action` is set, clicking the button sends that action as a chat message
   * through the page's `vanna-chat` element (its `sendMessage` method). While the send is
   * in flight the button is disabled and carries `button-transitioning`/`button-clicked`;
   * both are undone if the send reports failure, throws, or `vanna-chat` is unavailable.
   *
   * @param component - Component whose `data` supplies label, action, variant, size, icon,
   *   icon_position, disabled, loading and full_width.
   * @returns The rendered `<button>` element.
   */
  render(component: RichComponent): HTMLElement {
    const button = document.createElement('button');
    button.className = `rich-component rich-button button-${component.data.variant} button-${component.data.size}`;
    button.dataset.componentId = component.id;
    const { label, action, disabled, icon, icon_position, full_width, loading } = component.data;
    if (disabled || loading) {
      button.disabled = true;
    }
    if (full_width) {
      button.classList.add('button-full-width');
    }
    if (loading) {
      button.classList.add('button-loading');
    }
    button.innerHTML = this.buildButtonContent(label, icon, icon_position, loading);
    // Attach the listener whenever there is an action, even if the button starts out disabled
    // or loading: update() can re-enable it later, and a disabled <button> does not receive
    // click events in the meantime.
    if (action) {
      button.addEventListener('click', () => {
        void this.sendButtonAction(button, label, action);
      });
    }
    return button;
  }

  /**
   * Apply a partial update to a rendered button.
   *
   * @param element - The `<button>` produced by render().
   * @param component - The (updated) component; its `data` is used for flags not in `updates`.
   * @param updates - Changed fields; when absent the base class update is used.
   */
  update(element: HTMLElement, component: RichComponent, updates?: Record<string, any>): void {
    if (!updates) return super.update(element, component);
    const button = element as HTMLButtonElement;
    if (updates.disabled !== undefined || updates.loading !== undefined) {
      // Mirror render(): the button is unusable while EITHER flag is set. The flag that is not
      // part of this update falls back to the component's data, so finishing a `loading`
      // state cannot re-enable a button that is still `disabled` (and vice versa).
      const disabled = updates.disabled !== undefined ? updates.disabled : component.data?.disabled;
      const loading = updates.loading !== undefined ? updates.loading : component.data?.loading;
      button.disabled = Boolean(disabled || loading);
    }
    if (updates.loading !== undefined) {
      button.classList.toggle('button-loading', Boolean(updates.loading));
    }
    // `!== undefined` rather than truthiness, so clearing a label or removing an icon also re-renders.
    if (updates.label !== undefined || updates.icon !== undefined || updates.icon_position !== undefined) {
      // Re-render content
      super.update(element, component);
    }
  }

  /** Inner markup: hourglass + label while loading, otherwise the label with an optional icon on the requested side. */
  private buildButtonContent(label: unknown, icon: unknown, iconPosition: unknown, loading: unknown): string {
    if (loading) {
      return `⏳${label}`;
    }
    if (!icon) {
      return `${label}`;
    }
    return iconPosition === 'right' ? `${label}${icon}` : `${icon}${label}`;
  }

  /** Send `action` through `vanna-chat`, locking the button while the send is in flight. */
  private async sendButtonAction(button: HTMLButtonElement, label: unknown, action: unknown): Promise<void> {
    // Disabled buttons do not get user clicks; guard against synthetic ones anyway.
    if (button.disabled) return;
    console.log('🔘 Button clicked:', label);
    console.log(' Sending action:', action);
    // Apply visual feedback immediately
    button.disabled = true;
    button.classList.add('button-transitioning', 'button-clicked');
    // Find vanna-chat component and send message with button action
    const vannaChat = document.querySelector('vanna-chat') as
      | (Element & { sendMessage?: (message: unknown) => unknown })
      | null;
    console.log(' Found vanna-chat:', !!vannaChat);
    if (!vannaChat || typeof vannaChat.sendMessage !== 'function') {
      console.error(' ✗ vanna-chat not found or sendMessage not available');
      this.restoreButtonAfterFailedSend(button);
      return;
    }
    console.log(' Calling sendMessage...');
    try {
      const success = await vannaChat.sendMessage(action);
      if (success) {
        console.log(' ✓ Message sent successfully');
        return;
      }
      console.log(' ✗ Message failed, restoring button state');
    } catch (error) {
      console.error(' ✗ Message failed with error:', error);
    }
    this.restoreButtonAfterFailedSend(button);
  }

  /** Undo the in-flight state so the user can retry (the button was enabled when clicked). */
  private restoreButtonAfterFailedSend(button: HTMLButtonElement): void {
    button.disabled = false;
    button.classList.remove('button-transitioning', 'button-clicked');
  }
}
// Button group component renderer
export class ButtonGroupComponentRenderer extends BaseComponentRenderer {
  /**
   * Render the group of action buttons described by `data.buttons`.
   *
   * Clicking an enabled button that has an `action` locks the whole group (clicked button
   * highlighted, the others grayed out) and sends the action through the page's
   * `vanna-chat` element; the group is unlocked again if the send fails.
   */
  render(component: RichComponent): HTMLElement {
    const { orientation, spacing, align } = component.data;
    const group = document.createElement('div');
    group.className = [
      'rich-component',
      'rich-button-group',
      `button-group-${orientation}`,
      `button-group-spacing-${spacing}`,
      `button-group-align-${align}`,
    ].join(' ');
    group.dataset.componentId = component.id;
    const { buttons = [], full_width: fullWidth } = component.data;
    if (fullWidth) {
      group.classList.add('button-group-full-width');
    }
    // Render each button in order; its position doubles as its DOM index within the group.
    buttons.forEach((config: any, position: number) => {
      group.appendChild(this.createGroupButton(group, config, position));
    });
    return group;
  }

  /** Build one button of the group from its config entry. */
  private createGroupButton(group: HTMLElement, config: any, position: number): HTMLButtonElement {
    const button = document.createElement('button');
    button.className = `rich-button button-${config.variant || 'secondary'} button-${config.size || 'medium'}`;
    button.dataset.buttonIndex = String(position);
    // Remember the server-side disabled state so restoreButtonGroupState() keeps it.
    const startsDisabled = Boolean(config.disabled);
    if (startsDisabled) {
      button.disabled = true;
    }
    button.dataset.originallyDisabled = startsDisabled ? 'true' : 'false';
    const { label, icon, icon_position: iconPosition } = config;
    const iconFirst = iconPosition !== 'right';
    button.innerHTML = !icon ? `${label}` : iconFirst ? `${icon}${label}` : `${label}${icon}`;
    if (config.action && !config.disabled) {
      button.addEventListener('click', () => this.handleGroupClick(group, config, position));
    }
    return button;
  }

  /** Lock the group, send the clicked button's action via `vanna-chat`, and unlock on failure. */
  private async handleGroupClick(group: HTMLElement, config: any, position: number): Promise<void> {
    console.log('🔘 Button Group button clicked:', config.label);
    console.log(' Button index:', position);
    console.log(' Sending action:', config.action);
    // Immediately apply visual changes to all buttons in the group
    this.applyButtonGroupClickState(group, position);
    const vannaChat = document.querySelector('vanna-chat') as any;
    console.log(' Found vanna-chat:', !!vannaChat);
    const canSend = Boolean(vannaChat) && typeof vannaChat.sendMessage === 'function';
    if (!canSend) {
      console.error(' ✗ vanna-chat not found or sendMessage not available');
      this.restoreButtonGroupState(group);
      return;
    }
    console.log(' Calling sendMessage...');
    try {
      const success = await vannaChat.sendMessage(config.action);
      if (success) {
        console.log(' ✓ Message sent successfully');
        return;
      }
      console.log(' ✗ Message failed, restoring button state');
    } catch (error) {
      console.error(' ✗ Message failed with error:', error);
    }
    this.restoreButtonGroupState(group);
  }

  /** Disable every button; highlight the clicked one and gray out the rest. */
  private applyButtonGroupClickState(container: HTMLElement, clickedIndex: number): void {
    Array.from(container.querySelectorAll('button')).forEach((button, position) => {
      button.disabled = true;
      button.classList.add('button-transitioning');
      const stateClasses =
        position === clickedIndex ? ['button-clicked', 'button-highlighted'] : ['button-grayed-out'];
      button.classList.add(...stateClasses);
    });
  }

  /** Re-enable buttons that were not disabled by the server and clear all click-state classes. */
  private restoreButtonGroupState(container: HTMLElement): void {
    const stateClasses = ['button-clicked', 'button-highlighted', 'button-grayed-out', 'button-transitioning'];
    for (const button of Array.from(container.querySelectorAll('button'))) {
      if (button.dataset.originallyDisabled !== 'true') {
        button.disabled = false;
      }
      button.classList.remove(...stateClasses);
    }
  }
}
// Chart component renderer (for Plotly charts)
/**
 * Renderer for `chart` components (Plotly figures).
 *
 * NOTE(review): this class body looks corrupted. From the `else` branch of render()
 * onward, the code uses names that belong to an artifact renderer (`artifact_type`, which
 * is not declared in this scope, plus attachEventListeners / fireArtifactOpenedEvent /
 * generateStandaloneHTML), and the braces in this region do not appear to balance.
 * Restore from version control before relying on it.
 */
export class ChartComponentRenderer extends BaseComponentRenderer {
/**
 * Render a chart as a `plotly-chart` custom element when `data.data` is an array of
 * traces and `data.layout` is present; otherwise falls into the (apparently merged)
 * artifact-rendering path in the else branch.
 */
render(component: RichComponent): HTMLElement {
const container = document.createElement('div');
container.className = 'rich-component rich-chart';
container.dataset.componentId = component.id;
// The ChartComponent.data field contains the Plotly figure directly
// Structure: component.data = { data: [...traces...], layout: {...}, title: "...", config: {...} }
const { data: plotlyData, layout, title, config = {} } = component.data;
console.log('ChartComponentRenderer: Received component.data:', component.data);
console.log('ChartComponentRenderer: plotlyData:', plotlyData);
console.log('ChartComponentRenderer: layout:', layout);
// Check if we have a valid Plotly figure structure
if (plotlyData && Array.isArray(plotlyData) && layout) {
// Create plotly-chart web component
const chartElement = document.createElement('plotly-chart') as any;
// Set theme to match current theme
const vannaChat = document.querySelector('vanna-chat');
if (vannaChat) {
chartElement.theme = vannaChat.getAttribute('theme') || 'dark';
}
// Wrap in container with optional title
if (title) {
// NOTE(review): as written, this template holds only the title text (no `.chart-content`
// element), so the optional-chained appendChild below finds nothing and the chart is never
// attached. `title` is also interpolated into innerHTML unescaped. Markup looks truncated.
container.innerHTML = `
${title}
`;
container.querySelector('.chart-content')?.appendChild(chartElement);
} else {
container.appendChild(chartElement);
}
// Set data AFTER the element is in the DOM
// This ensures the web component is fully initialized
// (requestAnimationFrame defers the assignments to the next frame; ComponentManager appends
// the rendered element right after render() returns.)
requestAnimationFrame(() => {
chartElement.data = plotlyData; // Plotly traces (array)
chartElement.layout = layout; // Plotly layout (object)
chartElement.config = config;
console.log('ChartComponentRenderer: Set properties after DOM attachment');
console.log('ChartComponentRenderer: chartElement.data:', chartElement.data);
console.log('ChartComponentRenderer: chartElement.layout:', chartElement.layout);
});
} else {
// Fallback for invalid chart data
// NOTE(review): the fallback template is just a newline here; the markup appears to be missing.
container.innerHTML = `
`;
// Attach event listeners
this.attachEventListeners(container, component);
// Fire artifact-opened event for creation
const shouldRenderInChat = this.fireArtifactOpenedEvent(component, 'created', container);
// If default was prevented, show a placeholder instead
if (!shouldRenderInChat) {
// NOTE(review): `artifact_type` is not declared in this scope (only `title` is
// destructured above); this line cannot compile as written.
container.innerHTML = `
🎨
${title || 'Artifact'} opened externally
${artifact_type}
`;
// Add reopen functionality
// NOTE(review): the placeholder template above contains no `.placeholder-reopen` element,
// so this lookup returns null as written.
const reopenBtn = container.querySelector('.placeholder-reopen') as HTMLButtonElement;
if (reopenBtn) {
reopenBtn.addEventListener('click', () => {
this.fireArtifactOpenedEvent(component, 'user-action', container);
});
}
}
// NOTE(review): this `return` sits inside the else branch; the Plotly branch above has no
// return statement of its own as written.
return container;
}
/**
 * Wire the optional toolbar buttons inside `container` to their actions:
 * `.external-btn` re-fires `artifact-opened` as a user action, `.fullscreen-btn` opens the
 * fullscreen view and `.edit-btn` opens the (placeholder) editor. Missing buttons are skipped.
 */
private attachEventListeners(container: HTMLElement, component: RichComponent): void {
  const bindings: Array<[string, () => void]> = [
    ['.external-btn', () => this.fireArtifactOpenedEvent(component, 'user-action', container)],
    ['.fullscreen-btn', () => this.openFullscreen(component)],
    ['.edit-btn', () => this.openEditor(component)],
  ];
  for (const [selector, onClick] of bindings) {
    const target = container.querySelector(selector);
    if (target) {
      target.addEventListener('click', () => {
        onClick();
      });
    }
  }
}
/**
 * Announce that an artifact is being opened by dispatching a bubbling, cancelable
 * `artifact-opened` CustomEvent from `container`.
 *
 * A listener suppresses the default handling either by calling
 * `event.detail.preventDefault()` or the standard `event.preventDefault()`. If the event was
 * not suppressed and `trigger` is `'user-action'`, handleDefaultAction() runs.
 *
 * NOTE(review): when called during render with trigger `'created'`, `container` has not been
 * appended to the document yet (ComponentManager appends after render returns), so ancestor
 * listeners cannot receive that dispatch — confirm this is intended.
 *
 * @param component - Artifact component whose data fills the event detail.
 * @param trigger - `'created'` when fired during render, `'user-action'` for clicks.
 * @param container - Element the event is dispatched from.
 * @returns `true` when the artifact should still be rendered in the chat (default not prevented).
 */
private fireArtifactOpenedEvent(component: RichComponent, trigger: 'created' | 'user-action', container: HTMLElement): boolean {
  console.log('🎯 fireArtifactOpenedEvent called:', { trigger, artifactId: component.data.artifact_id });
  this.defaultPrevented = false;
  const eventDetail: ArtifactOpenedEventDetail = {
    artifactId: component.data.artifact_id,
    content: component.data.content,
    type: component.data.artifact_type,
    title: component.data.title,
    description: component.data.description,
    trigger,
    preventDefault: () => {
      console.log('🛑 preventDefault called!');
      this.defaultPrevented = true;
    },
    getStandaloneHTML: () => this.generateStandaloneHTML(component),
    timestamp: new Date().toISOString()
  };
  const event = new CustomEvent('artifact-opened', {
    detail: eventDetail,
    bubbles: true,
    cancelable: true,
    // Cross shadow-root boundaries so a listener on the hosting `vanna-chat` element receives it.
    composed: true
  });
  console.log('📡 Dispatching artifact-opened event:', event);
  // Dispatch exactly once. `closest('vanna-chat')` can only find an ancestor that this bubbling
  // event already reaches, so a second dispatch there would run its listeners twice.
  const notCanceled = container.dispatchEvent(event);
  if (!notCanceled) {
    // A listener called the standard event.preventDefault().
    this.defaultPrevented = true;
  }
  console.log('📨 Event dispatched. defaultPrevented:', this.defaultPrevented);
  // Handle default behavior if not prevented and user triggered
  if (!this.defaultPrevented && trigger === 'user-action') {
    this.handleDefaultAction(component);
  }
  // Return whether we should render in chat (true if default not prevented)
  return !this.defaultPrevented;
}
/**
 * Build the standalone HTML document handed to `artifact-opened` listeners through
 * `detail.getStandaloneHTML()`.
 *
 * NOTE(review): this method looks corrupted. The returned template contains only the title
 * and the dependency string, and each dependency branch appends just a newline (the CDN
 * tags appear to be missing). Everything after the `return` is unreachable, references
 * `overlay` (not declared here) and redeclares `content`. It reads like the body of a
 * separate fullscreen-overlay method (cf. the `openFullscreen` call in
 * attachEventListeners). Restore from version control.
 */
private generateStandaloneHTML(component: RichComponent): string {
// NOTE(review): `content` is destructured but not used by the returned template.
const { content, title, dependencies = [] } = component.data;
let dependenciesHTML = '';
// Add common CDN links for dependencies
if (dependencies.includes('d3')) {
dependenciesHTML += '\n';
}
if (dependencies.includes('plotly')) {
dependenciesHTML += '\n';
}
if (dependencies.includes('three') || dependencies.includes('threejs')) {
dependenciesHTML += '\n';
}
return `
${title || 'Artifact'}
${dependenciesHTML}
`;
// Add styles
overlay.style.cssText = `
position: fixed;
top: 0;
left: 0;
width: 100vw;
height: 100vh;
background: white;
z-index: 10000;
display: flex;
flex-direction: column;
`;
const header = overlay.querySelector('.fullscreen-header') as HTMLElement;
header.style.cssText = `
padding: 16px;
border-bottom: 1px solid #eee;
display: flex;
justify-content: space-between;
align-items: center;
`;
const content = overlay.querySelector('.fullscreen-content') as HTMLElement;
content.style.cssText = `
flex: 1;
padding: 16px;
`;
const iframe = overlay.querySelector('.fullscreen-iframe') as HTMLIFrameElement;
iframe.style.cssText = `
width: 100%;
height: 100%;
border: none;
`;
// Close button functionality
// NOTE(review): closing via this button leaves the keydown listener registered; a later Escape
// then calls removeChild on an overlay that is no longer a child of body, which throws.
const closeBtn = overlay.querySelector('.close-fullscreen') as HTMLButtonElement;
closeBtn.addEventListener('click', () => {
document.body.removeChild(overlay);
});
// Escape key to close
const handleEscape = (e: KeyboardEvent) => {
if (e.key === 'Escape') {
document.body.removeChild(overlay);
document.removeEventListener('keydown', handleEscape);
}
};
document.addEventListener('keydown', handleEscape);
document.body.appendChild(overlay);
}
/** Editor hook for artifacts; not implemented yet, so it only logs the artifact id. */
private openEditor(component: RichComponent): void {
  const { artifact_id: artifactId } = component.data;
  console.log('Editor functionality not yet implemented for artifact:', artifactId);
}
/**
 * Escape text for safe inclusion in HTML, including inside double-quoted attribute values.
 *
 * Serializing a text node through innerHTML escapes `&`, `<`, `>` (and non-breaking spaces)
 * but leaves `"` untouched, so double quotes are replaced with `&quot;` afterwards. The
 * previous replacement mapped `"` to itself, so quotes were never escaped.
 */
private escapeHtml(html: string): string {
  const div = document.createElement('div');
  div.textContent = html;
  return div.innerHTML.replace(/"/g, '&quot;');
}
}
// User message component renderer
export class UserMessageComponentRenderer extends BaseComponentRenderer {
  /**
   * Render a user chat message as a `vanna-message` element.
   *
   * The element gets `theme="light"` (fixed for now), the component id as
   * `data-component-id`, and the properties `content` (empty string when missing),
   * `type = 'user'` and `timestamp = Date.parse(component.timestamp)`.
   */
  render(component: RichComponent): HTMLElement {
    const message = document.createElement('vanna-message');
    message.setAttribute('theme', 'light'); // Could be made dynamic
    message.dataset.componentId = component.id;
    // Assigned (not defined) so the element's property setters run, in this order.
    Object.assign(message, {
      content: component.data.content || '',
      type: 'user',
      timestamp: Date.parse(component.timestamp),
    });
    return message;
  }
}
// Assistant message component renderer
export class AssistantMessageComponentRenderer extends BaseComponentRenderer {
  /**
   * Render an assistant chat message as a `vanna-message` element with
   * `type = 'assistant'`.
   *
   * Theme is fixed to `light`; `timestamp` is `Date.parse(component.timestamp)`, which is
   * NaN when the string does not parse.
   */
  render(component: RichComponent): HTMLElement {
    const bubble = document.createElement('vanna-message');
    bubble.setAttribute('theme', 'light'); // Could be made dynamic
    bubble.dataset.componentId = component.id;
    const properties = {
      content: component.data.content || '',
      type: 'assistant',
      timestamp: Date.parse(component.timestamp),
    };
    Object.assign(bubble, properties);
    return bubble;
  }
}
// Component registry for managing all component types
export class ComponentRegistry {
  /**
   * Component types rendered as custom elements instead of through a legacy renderer,
   * mapped to the element's tag name. Extend as more components are converted.
   */
  private static readonly webComponentTags: Readonly<Record<string, string>> = {
    card: 'rich-card',
    task_list: 'rich-task-list',
    progress_bar: 'rich-progress-bar',
  };

  private renderers: Map<string, ComponentRenderer> = new Map();

  constructor() {
    // Register primitive component renderers (domain-agnostic)
    this.register('status_card', new StatusCardComponentRenderer());
    this.register('progress_display', new ProgressDisplayComponentRenderer());
    this.register('log_viewer', new LogViewerComponentRenderer());
    this.register('badge', new BadgeComponentRenderer());
    this.register('icon_text', new IconTextComponentRenderer());
    // Register existing component renderers
    // ('card', 'task_list' and 'progress_bar' are rendered as custom elements; render() skips
    // their legacy renderers, and update() does too for elements it rendered that way.)
    this.register('card', new CardComponentRenderer());
    this.register('task_list', new TaskListComponentRenderer());
    this.register('progress_bar', new ProgressBarComponentRenderer());
    this.register('notification', new NotificationComponentRenderer());
    this.register('status_indicator', new StatusIndicatorComponentRenderer());
    this.register('text', new TextComponentRenderer());
    this.register('dataframe', new DataFrameComponentRenderer());
    this.register('chart', new ChartComponentRenderer());
    // Register interactive component renderers
    this.register('button', new ButtonComponentRenderer());
    this.register('button_group', new ButtonGroupComponentRenderer());
    // Register artifact component renderer
    this.register('artifact', new ArtifactComponentRenderer());
    // Register message component renderers
    this.register('user-message', new UserMessageComponentRenderer());
    this.register('assistant-message', new AssistantMessageComponentRenderer());
  }

  /** Register (or replace) the renderer used for components of `type`. */
  register(type: string, renderer: ComponentRenderer): void {
    this.renderers.set(type, renderer);
  }

  /**
   * Render a component: as a custom element for mapped types, through the registered
   * renderer otherwise, and as a plain-text fallback for unknown types.
   */
  render(component: RichComponent): HTMLElement {
    // Check if this is a component that should use web components
    const webComponentTag = this.getWebComponentTag(component.type);
    if (webComponentTag) {
      return this.renderWebComponent(webComponentTag, component);
    }
    // Use the old renderer system for other components
    const renderer = this.renderers.get(component.type);
    if (!renderer) {
      return this.renderFallback(component);
    }
    return renderer.render(component);
  }

  /** Custom-element tag for `type`, or null when the type uses a legacy renderer. */
  private getWebComponentTag(type: string): string | null {
    const tags = ComponentRegistry.webComponentTags;
    // Own-property check so inherited names such as "toString" are not mistaken for tags.
    return Object.prototype.hasOwnProperty.call(tags, type) ? tags[type] ?? null : null;
  }

  private renderWebComponent(tagName: string, component: RichComponent): HTMLElement {
    const element = document.createElement(tagName);
    this.applyWebComponentProperties(element, component);
    return element;
  }

  /** Copy `component.data` onto the element's properties and sync its theme attribute. */
  private applyWebComponentProperties(element: HTMLElement, component: RichComponent): void {
    // Object.assign performs ordinary assignments, so the custom element's property setters run.
    Object.assign(element, component.data);
    // Set theme to match the parent VannaChat theme
    element.setAttribute('theme', this.getCurrentTheme());
  }

  private getCurrentTheme(): string {
    // Try to get theme from the parent VannaChat component
    const vannaChat = document.querySelector('vanna-chat');
    if (vannaChat) {
      return vannaChat.getAttribute('theme') || 'dark';
    }
    return 'dark';
  }

  /**
   * Update a rendered component in place.
   *
   * Elements produced by renderWebComponent() get their properties refreshed from the new
   * component data; the legacy renderer registered under the same type expects its own
   * markup and must not be applied to them. Everything else is delegated to the renderer.
   */
  update(element: HTMLElement, component: RichComponent, updates?: Record<string, any>): void {
    const webComponentTag = this.getWebComponentTag(component.type);
    if (webComponentTag && element.localName === webComponentTag) {
      this.applyWebComponentProperties(element, component);
      return;
    }
    const renderer = this.renderers.get(component.type);
    if (renderer) {
      renderer.update(element, component, updates);
    }
  }

  remove(element: HTMLElement): void {
    element.remove();
  }

  /**
   * Render an unknown component type as a notice plus its JSON data.
   * Text is assigned through textContent so server-supplied `type`/`data` cannot inject markup.
   */
  private renderFallback(component: RichComponent): HTMLElement {
    const container = document.createElement('div');
    container.className = 'rich-component rich-fallback';
    container.dataset.componentId = component.id;
    const heading = document.createElement('div');
    heading.textContent = `Unknown Component: ${component.type}`;
    const details = document.createElement('pre');
    details.textContent = JSON.stringify(component.data, null, 2);
    container.append(heading, details);
    return container;
  }
}
// Component manager for handling component lifecycle
export class ComponentManager {
  /** Latest (normalized) component state, keyed by component id. */
  private components: Map<string, RichComponent> = new Map();
  /** Rendered DOM element for each component id. */
  private elements: Map<string, HTMLElement> = new Map();
  private registry: ComponentRegistry = new ComponentRegistry();
  private container: HTMLElement;
  /** Top-level component fields that normalizeComponent() does not copy into `data`. */
  private readonly sharedFields = new Set([
    'id',
    'type',
    'lifecycle',
    'layout',
    'theme',
    'children',
    'timestamp',
    'visible',
    'interactive',
  ]);

  /**
   * @param container - Element that rendered components are appended to. Rich-component
   *   styles are ensured on it now and again after every clear().
   */
  constructor(container: HTMLElement) {
    this.container = container;
    ensureRichComponentStyles(this.container);
  }

  /**
   * Apply one component update.
   *
   * `status_bar_update`, `task_tracker_update` and `chat_input_update` components are not
   * rendered; they are forwarded to the chat's status bar, progress tracker or input.
   * Everything else is dispatched on `update.operation`.
   */
  processUpdate(update: ComponentUpdate): void {
    // Handle UI state updates with special processing
    if (update.component && this.isUIStateUpdate(update.component)) {
      this.processUIStateUpdate(update.component);
      return;
    }
    switch (update.operation) {
      case 'create':
        this.createComponent(update);
        break;
      case 'update':
        this.updateComponent(update);
        break;
      case 'replace':
        this.replaceComponent(update);
        break;
      case 'remove':
        this.removeComponent(update);
        break;
    }
  }

  /** Render a new component, track it, and append it to the container. */
  private createComponent(update: ComponentUpdate): void {
    if (!update.component) return;
    const component = this.normalizeComponent(update.component);
    const element = this.registry.render(component);
    this.components.set(component.id, component);
    this.elements.set(component.id, element);
    // Determine where to place the component
    this.positionComponent(element);
  }

  /** Update the element tracked under `update.target_id` in place; no-op when it is unknown. */
  private updateComponent(update: ComponentUpdate): void {
    if (!update.component) return;
    const element = this.elements.get(update.target_id);
    if (element) {
      const component = this.normalizeComponent(update.component);
      this.registry.update(element, component, update.updates);
      this.components.set(update.target_id, component);
    }
  }

  /** Swap the element tracked under `update.target_id` for a freshly rendered one. */
  private replaceComponent(update: ComponentUpdate): void {
    if (!update.component) return;
    const oldElement = this.elements.get(update.target_id);
    if (oldElement) {
      const component = this.normalizeComponent(update.component);
      const newElement = this.registry.render(component);
      oldElement.parentNode?.replaceChild(newElement, oldElement);
      this.elements.set(component.id, newElement);
      this.components.set(component.id, component);
      // Clean up old references if ID changed
      if (update.target_id !== component.id) {
        this.elements.delete(update.target_id);
        this.components.delete(update.target_id);
      }
    }
  }

  /** Detach and forget the component tracked under `update.target_id`. */
  private removeComponent(update: ComponentUpdate): void {
    const element = this.elements.get(update.target_id);
    if (element) {
      element.remove();
      this.elements.delete(update.target_id);
      this.components.delete(update.target_id);
    }
  }

  private positionComponent(element: HTMLElement): void {
    // Always append to container
    this.container.appendChild(element);
    // Trigger scroll to bottom in parent chat component
    this.triggerScroll();
  }

  private triggerScroll(): void {
    // Find the parent vanna-chat component and trigger its scroll method
    const vannaChat = document.querySelector('vanna-chat') as any;
    if (vannaChat && typeof vannaChat.scrollToLastMessage === 'function') {
      // Use requestAnimationFrame to wait for DOM update
      requestAnimationFrame(() => {
        requestAnimationFrame(() => {
          vannaChat.scrollToLastMessage();
        });
      });
    }
  }

  /** Drop every tracked component, empty the container, and re-inject the styles. */
  clear(): void {
    this.components.clear();
    this.elements.clear();
    this.container.innerHTML = '';
    ensureRichComponentStyles(this.container);
  }

  getComponent(id: string): RichComponent | undefined {
    return this.components.get(id);
  }

  getAllComponents(): RichComponent[] {
    return Array.from(this.components.values());
  }

  /**
   * Return the component with every non-shared top-level field folded into `data`
   * (renderers read component-specific values from `data`).
   *
   * A top-level field always overrides a same-named key already in `data`. The input object
   * is returned unchanged only when `data` exists and folding would not change it; deciding
   * by key count instead made the precedence depend on whether unrelated new keys were present.
   */
  private normalizeComponent(component: RichComponent): RichComponent {
    const data: Record<string, any> = { ...(component.data ?? {}) };
    let changed = false;
    for (const [key, value] of Object.entries(component as Record<string, any>)) {
      if (this.sharedFields.has(key) || key === 'data') continue;
      if (!Object.prototype.hasOwnProperty.call(data, key) || !Object.is(data[key], value)) {
        data[key] = value;
        changed = true;
      }
    }
    if (component.data && !changed) {
      return component;
    }
    return {
      ...component,
      data,
    };
  }

  private isUIStateUpdate(component: RichComponent): boolean {
    return component.type === 'status_bar_update' ||
      component.type === 'task_tracker_update' ||
      component.type === 'chat_input_update';
  }

  private processUIStateUpdate(component: RichComponent): void {
    console.log('processUIStateUpdate called with type:', component.type, 'component:', component);
    switch (component.type) {
      case 'status_bar_update':
        this.updateStatusBar(component);
        break;
      case 'task_tracker_update':
        this.updateTaskTracker(component);
        break;
      case 'chat_input_update':
        this.updateChatInput(component);
        break;
    }
  }

  /**
   * Find an element belonging to the page's `vanna-chat`: first inside its shadow root
   * (when one is accessible), otherwise anywhere in the document.
   */
  private findChatElement(shadowSelector: string, documentSelector: string = shadowSelector): HTMLElement | null {
    const vannaChat = document.querySelector('vanna-chat');
    const inShadowRoot = vannaChat?.shadowRoot?.querySelector<HTMLElement>(shadowSelector);
    return inShadowRoot ?? document.querySelector<HTMLElement>(documentSelector);
  }

  /** Push `status`, `message` and `detail` onto the chat's `vanna-status-bar`, if present. */
  private updateStatusBar(component: RichComponent): void {
    const statusBar = this.findChatElement('vanna-status-bar') as
      | (HTMLElement & { status?: unknown; message?: unknown; detail?: unknown })
      | null;
    if (statusBar) {
      const { status, message, detail } = component.data || {};
      // Set properties directly on the Lit component
      statusBar.status = status;
      statusBar.message = message || '';
      statusBar.detail = detail || '';
    }
  }

  /** Forward add/update/remove/clear task operations to the chat's `vanna-progress-tracker`. */
  private updateTaskTracker(component: RichComponent): void {
    // Debug logging
    console.log('updateTaskTracker called with component:', component);
    console.log('component.data:', component.data);
    // Methods are probed before use because the tracker is a separately defined custom element.
    const progressTracker = this.findChatElement('vanna-progress-tracker') as
      | (HTMLElement & {
          addItem?: (...args: any[]) => any;
          updateItem?: (...args: any[]) => any;
          removeItem?: (...args: any[]) => any;
          clear?: () => any;
        })
      | null;
    console.log('Found progressTracker:', progressTracker);
    if (!progressTracker) return;
    const { operation, task, task_id, status, detail } = component.data || {};
    console.log('Extracted data:', { operation, task, task_id, status, detail });
    switch (operation) {
      case 'add_task':
        console.log('Adding task:', task);
        if (task && progressTracker.addItem) {
          // Use the backend task ID instead of generating a new one
          const result = progressTracker.addItem(task.title || task.text, task.description || task.detail, task.id);
          console.log('addItem result:', result, 'using backend ID:', task.id);
        }
        break;
      case 'update_task':
        console.log('Updating task:', task_id, status, detail);
        if (task_id && progressTracker.updateItem) {
          progressTracker.updateItem(task_id, status, detail);
        }
        break;
      case 'remove_task':
        if (task_id && progressTracker.removeItem) {
          progressTracker.removeItem(task_id);
        }
        break;
      case 'clear_tasks':
        if (progressTracker.clear) {
          progressTracker.clear();
        }
        break;
    }
  }

  /** Apply placeholder / disabled / value / focus changes to the chat's message input. */
  private updateChatInput(component: RichComponent): void {
    const chatInput = this.findChatElement(
      'textarea.message-input, input.message-input',
      'textarea[data-testid="message-input"], input[type="text"].message-input, .message-input input, .message-input textarea'
    ) as HTMLInputElement | HTMLTextAreaElement | null;
    if (!chatInput) return;
    const { placeholder, disabled, value, focus } = component.data || {};
    if (placeholder !== undefined) {
      chatInput.placeholder = placeholder;
    }
    if (disabled !== undefined) {
      chatInput.disabled = disabled;
    }
    if (value !== undefined) {
      chatInput.value = value;
    }
    if (focus !== undefined) {
      if (focus) {
        chatInput.focus();
      } else {
        chatInput.blur();
      }
    }
  }
}
================================================
FILE: frontends/webcomponent/src/components/rich-progress-bar.stories.ts
================================================
import type { Meta, StoryObj } from '@storybook/web-components';
import { html } from 'lit';
import './rich-progress-bar';
/**
 * Storybook metadata for the `rich-progress-bar` stories: padded layout, light/dark
 * background presets, and controls for the element's properties.
 */
const meta: Meta = {
title: 'Rich Components/Rich Progress Bar',
component: 'rich-progress-bar',
parameters: {
layout: 'padded',
backgrounds: {
default: 'light',
values: [
{ name: 'dark', value: 'rgb(11, 15, 25)' },
{ name: 'light', value: '#f5f7fa' },
],
},
},
argTypes: {
// Slider over 0–1 in 0.01 steps (the stories pass fractions such as 0.65).
value: { control: { type: 'range', min: 0, max: 1, step: 0.01 } },
label: { control: 'text' },
description: { control: 'text' },
showPercentage: { control: 'boolean' },
status: {
control: 'select',
options: ['info', 'success', 'warning', 'error']
},
animated: { control: 'boolean' },
indeterminate: { control: 'boolean' },
},
};
export default meta;
// Story object type used by every story export in this file.
type Story = StoryObj;
/** Baseline progress bar: value 0.65 with the percentage shown and no animation. */
export const Default: Story = {
  args: {
    value: 0.65,
    label: 'Processing',
    showPercentage: true,
    status: 'info',
    animated: false,
    indeterminate: false,
  },
  render(args) {
    return html`
`;
  },
};
/** Animated progress bar with a longer description beneath the label. */
export const WithDescription: Story = {
  args: {
    value: 0.4,
    label: 'Installing dependencies',
    description: 'Downloading and installing npm packages for the project. This may take a few minutes.',
    showPercentage: true,
    status: 'info',
    animated: true,
    indeterminate: false,
  },
  render(args) {
    return html`
`;
  },
};
/**
 * Light-theme story.
 *
 * NOTE(review): these args (`title`, `placeholder`, `maxAutonomy`, `theme`, ...) are not among
 * this file's progress-bar argTypes; they look copied from a `vanna-chat` story — confirm.
 */
export const LightMode: Story = {
  args: {
    title: 'Vanna AI Agent',
    placeholder: 'Describe what you want to build...',
    disabled: false,
    showProgress: true,
    maxAutonomy: false,
    theme: 'light',
  },
  render(args) {
    return html`
`;
  },
};
/**
 * Seeds a `vanna-chat` on the page with a four-message conversation, four tracker items
 * (first completed, second in progress) and a "working" status, 100 ms after rendering.
 */
export const WithConversation: Story = {
  args: {
    title: 'Vanna AI Agent',
    placeholder: 'Continue the conversation...',
    disabled: false,
    showProgress: true,
    maxAutonomy: true,
    theme: 'light',
  },
  render: (args) => {
    setTimeout(() => {
      const chat = document.querySelector('vanna-chat') as any;
      const tracker = chat?.getProgressTracker();
      if (!chat || !tracker) return;
      // Add conversation messages
      const conversation: Array<[string, string]> = [
        ['Create a dashboard for analyzing customer data', 'user'],
        ["I'll help you create a customer data dashboard. Let me break this down into steps and get started.", 'assistant'],
        ['What specific metrics would you like to track? Revenue, acquisition, retention, or something else?', 'assistant'],
        ['Focus on revenue and customer acquisition metrics', 'user'],
      ];
      for (const [text, role] of conversation) {
        chat.addMessage(text, role);
      }
      // Add progress items (all four first, then update the first two)
      const steps: Array<[string, string]> = [
        ['Analyze requirements', 'Understanding dashboard needs'],
        ['Design data schema', 'Planning database structure'],
        ['Create visualization components', 'Building charts and graphs'],
        ['Implement filtering', 'Adding date range and segment filters'],
      ];
      const [requirementsId, schemaId] = steps.map(([title, detail]) => tracker.addItem(title, detail));
      tracker.updateItem(requirementsId, 'completed');
      tracker.updateItem(schemaId, 'in_progress', 'Identifying key metrics and data sources');
      chat.setStatus('working', 'Analyzing data requirements...', 'Step 2 of 4');
    }, 100);
    return html`
`;
  },
};
/**
 * Seeds a `vanna-chat` with a two-message exchange and twelve tracker tasks: the first
 * three completed, the fourth in progress, the rest pending, plus a "working" status.
 */
export const MaxAutonomyMode: Story = {
  args: {
    title: 'Vanna AI Agent - Max Autonomy',
    placeholder: 'Describe your project...',
    disabled: false,
    showProgress: true,
    maxAutonomy: true,
    theme: 'light',
  },
  render: (args) => {
    setTimeout(() => {
      const chat = document.querySelector('vanna-chat') as any;
      const tracker = chat?.getProgressTracker();
      if (!chat || !tracker) return;
      chat.addMessage('Build a full-stack e-commerce app with user authentication, product catalog, shopping cart, and payment processing', 'user');
      chat.addMessage("Perfect! I'll build a complete e-commerce application for you. Since Max Autonomy is enabled, I'll handle all the technical decisions and implementation details automatically.", 'assistant');
      // Comprehensive task list for full autonomy
      const tasks = [
        'Set up project structure',
        'Configure development environment',
        'Design database schema',
        'Implement user authentication',
        'Build product catalog API',
        'Create shopping cart functionality',
        'Integrate payment processing',
        'Develop frontend components',
        'Add responsive design',
        'Implement search & filtering',
        'Set up testing framework',
        'Configure deployment pipeline'
      ];
      // Each task is added and then immediately given its status.
      for (let index = 0; index < tasks.length; index++) {
        const id = tracker.addItem(tasks[index], `Feature ${index + 1} of ${tasks.length}`);
        if (index < 3) {
          tracker.updateItem(id, 'completed');
        } else if (index === 3) {
          tracker.updateItem(id, 'in_progress', 'Setting up JWT tokens and password hashing');
        }
      }
      chat.setStatus('working', 'Building authentication system...', 'High autonomy mode active');
    }, 100);
    return html`
`;
  },
};
/**
 * Busy-agent story: the input is disabled, the user asked for a sales report,
 * and a five-step plan shows the connection done and the query running.
 *
 * NOTE(review): the returned template is empty — confirm the `vanna-chat`
 * markup was not lost.
 */
export const WorkingState: Story = {
  args: {
    title: 'Vanna AI Agent',
    placeholder: 'Ask me anything...',
    disabled: true,
    showProgress: true,
    maxAutonomy: false,
    theme: 'light',
  },
  render: (args) => {
    // Seed state after a short delay (presumably so the chat element exists).
    setTimeout(() => {
      const chatEl = document.querySelector('vanna-chat') as any;
      const progress = chatEl?.getProgressTracker();
      if (!chatEl || !progress) return;

      chatEl.addMessage('Generate a monthly sales report with charts', 'user');

      const connectId = progress.addItem('Connect to database', 'Establishing secure connection');
      const queryId = progress.addItem('Query sales data', 'Fetching monthly records');
      progress.addItem('Process data', 'Calculating totals and trends');
      progress.addItem('Generate charts', 'Creating visualizations');
      progress.addItem('Format report', 'Compiling final document');

      progress.updateItem(connectId, 'completed');
      progress.updateItem(queryId, 'in_progress', 'SELECT * FROM sales WHERE date >= 2024-01...');

      chatEl.setStatus('working', 'Querying sales database...', 'Processing 12,543 records');
    }, 100);

    return html`
`;
  },
};
/**
 * Compact story: progress tracker hidden, one quick question and a pending
 * answer with a working status.
 *
 * NOTE(review): the returned template is empty — confirm the `vanna-chat`
 * markup was not lost.
 */
export const CompactMode: Story = {
  args: {
    title: 'Vanna AI Agent',
    placeholder: 'Quick question...',
    disabled: false,
    showProgress: false,
    maxAutonomy: false,
    theme: 'light',
  },
  render: (args) => {
    // Seed state after a short delay (presumably so the chat element exists).
    setTimeout(() => {
      const chatEl = document.querySelector('vanna-chat') as any;
      if (!chatEl) return;

      chatEl.addMessage('What\'s the average order value this month?', 'user');
      chatEl.addMessage('Let me query that for you...', 'assistant');
      chatEl.setStatus('working', 'Calculating average order value...', '2.1s');
    }, 100);

    return html`
`;
  },
};
/**
 * Rich-components story: seeds a dashboard-building conversation and plan,
 * then drives `chat.componentManager` in two phases:
 * - about 1.5s after seeding, it creates a notification, a status indicator,
 *   a progress bar, a task list, a card and a markdown text block;
 * - about 4s after seeding, it marks chart generation finished, adds a success
 *   notification and advances the plan.
 *
 * NOTE(review): the returned template is empty — confirm the `vanna-chat`
 * markup was not lost.
 */
export const WithRichComponents: Story = {
  args: {
    title: 'Vanna AI Agent - Rich Components',
    placeholder: 'Ask me to analyze data or build something...',
    disabled: false,
    showProgress: true,
    maxAutonomy: false,
    theme: 'light',
  },
  render: (args) => {
    // Wraps component data in a 'create' update that uses the append layout
    // shared by every component in this story. Each call builds fresh objects.
    const createOp = (id: string, type: string, data: Record<string, unknown>) => ({
      operation: 'create',
      target_id: id,
      component: {
        id,
        type,
        data,
        layout: { position: 'append', size: {}, z_index: 0, classes: [] },
        theme: {},
        lifecycle: 'create',
      },
      timestamp: new Date().toISOString(),
    });

    // Wraps a partial field patch in an 'update' for an existing component.
    const updateOp = (id: string, updates: Record<string, unknown>) => ({
      operation: 'update',
      target_id: id,
      updates,
      timestamp: new Date().toISOString(),
    });

    // Seed state after a short delay (presumably so the chat element exists).
    setTimeout(() => {
      const chatEl = document.querySelector('vanna-chat') as any;
      const progress = chatEl?.getProgressTracker();
      if (!chatEl || !progress) return;

      // Opening exchange.
      chatEl.addMessage('Create a comprehensive sales dashboard with multiple visualizations and export capabilities', 'user');
      chatEl.addMessage('I\'ll create a comprehensive sales dashboard for you. Let me break this down into clear tasks and show you the progress with rich components.', 'assistant');

      // Five-step plan: two steps done, the third in progress.
      const requirementsId = progress.addItem('Analyze requirements', 'Understanding dashboard specifications');
      const schemaId = progress.addItem('Design data schema', 'Planning database structure');
      const chartsId = progress.addItem('Create visualizations', 'Building charts and graphs');
      const exportId = progress.addItem('Add export features', 'Implementing PDF and Excel export');
      progress.addItem('Deploy dashboard', 'Setting up production environment');
      progress.updateItem(requirementsId, 'completed');
      progress.updateItem(schemaId, 'completed');
      progress.updateItem(chartsId, 'in_progress', 'Creating revenue trend charts...');
      chatEl.setStatus('working', 'Building visualization components...', 'Step 3 of 5');

      // Phase 1: render the initial set of rich components.
      setTimeout(() => {
        const manager = chatEl.componentManager;
        if (!manager) return;

        manager.processUpdate(createOp('info-notification', 'notification', {
          title: 'Dashboard Progress',
          message: 'Your sales dashboard is being built with the following components: revenue trends, customer analytics, and performance metrics.',
          level: 'info',
          dismissible: true,
          actions: []
        }));

        manager.processUpdate(createOp('status-chart-generation', 'status_indicator', {
          status: 'loading',
          message: 'Generating revenue trend charts...',
          pulse: true
        }));

        manager.processUpdate(createOp('chart-progress', 'progress_bar', {
          progress: 65,
          status: 'active',
          label: 'Chart Generation Progress',
          detail: 'Processing 12,543 sales records...',
          animated: true
        }));

        manager.processUpdate(createOp('dashboard-tasks', 'task_list', {
          title: 'Dashboard Components',
          tasks: [
            {
              id: 'task-1',
              title: 'Revenue Trend Chart',
              description: 'Monthly revenue tracking with year-over-year comparison',
              status: 'completed',
              progress: 100,
              timestamp: '2024-01-15 14:32:00'
            },
            {
              id: 'task-2',
              title: 'Customer Acquisition Funnel',
              description: 'Lead to customer conversion visualization',
              status: 'running',
              progress: 75,
              timestamp: '2024-01-15 14:45:00'
            },
            {
              id: 'task-3',
              title: 'Geographic Sales Map',
              description: 'Interactive map showing sales by region',
              status: 'pending',
              progress: 0,
              timestamp: null
            },
            {
              id: 'task-4',
              title: 'Performance Metrics KPIs',
              description: 'Key performance indicators dashboard',
              status: 'pending',
              progress: 0,
              timestamp: null
            }
          ],
          progress: 58
        }));

        manager.processUpdate(createOp('data-summary', 'card', {
          title: 'Data Analysis Summary',
          subtitle: 'Sales Data Processing Results',
          content: 'Successfully processed 12,543 sales records from the last 12 months. Found key trends in customer behavior and revenue patterns.',
          icon: '📊',
          status: 'success',
          collapsible: true,
          actions: [
            { label: 'View Details', action: 'view-details', variant: 'primary' },
            { label: 'Export Data', action: 'export', variant: 'secondary' }
          ]
        }));

        manager.processUpdate(createOp('insights-text', 'text', {
          content: `# Dashboard Insights\n\nBased on the data analysis, here are the key findings:\n\n## Revenue Trends\n- **23% increase** in Q4 sales compared to Q3\n- Peak sales month: **December** ($1.2M)\n- Lowest performing month: **February** ($680K)\n\n## Customer Behavior\n- Average order value: **$156.78**\n- Customer retention rate: **89.3%**\n- Most popular product category: **Electronics**\n\n## Recommendations\n1. **Focus marketing efforts** on February to boost sales\n2. **Expand electronics inventory** for peak seasons\n3. **Implement loyalty program** to maintain high retention\n\n*Dashboard generation is 65% complete. Estimated completion: 3-4 minutes.*`,
          markdown: true
        }));
      }, 1500);

      // Phase 2: finish chart generation and advance the plan.
      setTimeout(() => {
        const manager = chatEl.componentManager;
        if (!manager) return;

        manager.processUpdate(updateOp('status-chart-generation', {
          status: 'success',
          message: 'Revenue charts generated successfully',
          pulse: false
        }));

        manager.processUpdate(updateOp('chart-progress', {
          progress: 100,
          status: 'success',
          detail: 'All charts generated successfully!',
          animated: false
        }));

        manager.processUpdate(createOp('success-notification', 'notification', {
          title: 'Charts Ready!',
          message: 'Your revenue trend charts have been generated and are ready for review.',
          level: 'success',
          dismissible: true,
          actions: [
            { label: 'View Charts', action: 'view-charts', variant: 'primary' },
            { label: 'Continue', action: 'continue', variant: 'secondary' }
          ]
        }));

        progress.updateItem(chartsId, 'completed');
        progress.updateItem(exportId, 'in_progress', 'Adding PDF export functionality...');
        chatEl.setStatus('working', 'Adding export capabilities...', 'Step 4 of 5');
      }, 4000);
    }, 100);

    return html`
`;
  }
};
================================================
FILE: frontends/webcomponent/src/components/vanna-message.stories.ts
================================================
import type { Meta, StoryObj } from '@storybook/web-components';
import { html } from 'lit';
import './vanna-message';
/**
 * Storybook metadata for the `vanna-message` element: stories are grouped
 * under "Components/VannaMessage" and rendered in a centered layout.
 */
const meta: Meta = {
  title: 'Components/VannaMessage',
  component: 'vanna-message',
  parameters: {
    layout: 'centered',
  },
  // Controls panel: free-text content, a user/assistant role picker, and a
  // numeric timestamp.
  argTypes: {
    content: { control: 'text' },
    type: {
      control: 'select',
      options: ['user', 'assistant'],
    },
    timestamp: { control: 'number' },
  },
};
export default meta;
// Story alias; `StoryObj` is used without a component type argument.
type Story = StoryObj;
/**
 * A user-authored message asking for help with data analysis. The timestamp
 * is captured once, when the module is evaluated.
 *
 * NOTE(review): the template is empty, so this story renders nothing —
 * confirm the `<vanna-message>` markup was not lost.
 */
export const UserMessage: Story = {
  args: {
    content: 'Hello! Can you help me analyze my data?',
    type: 'user',
    timestamp: Date.now(),
  },
  render(args) {
    return html`
`;
  },
};
/**
 * An assistant reply asking follow-up questions about the user's data. The
 * timestamp is captured once, when the module is evaluated.
 *
 * NOTE(review): the template is empty, so this story renders nothing —
 * confirm the `<vanna-message>` markup was not lost.
 */
export const AssistantMessage: Story = {
  args: {
    content: 'Of course! I\'d be happy to help you analyze your data. Could you please tell me more about the type of data you have and what insights you\'re looking for?',
    type: 'assistant',
    timestamp: Date.now(),
  },
  render(args) {
    return html`
`;
  },
};
export const LongMessage: Story = {
args: {
content: 'This is a very long message that demonstrates how the component handles longer text content. It should wrap properly and maintain good readability while staying within the maximum width constraints. The message can contain multiple sentences and paragraphs of information that the AI assistant might provide in response to complex queries.',
type: 'assistant',
timestamp: Date.now(),
},
render: (args) => html`
================================================
FILE: frontends/webcomponent/test_backend.py
================================================
#!/usr/bin/env python3
"""
Comprehensive test backend for vanna-webcomponent validation.
This backend exercises all component types and update patterns to validate
that nothing breaks during webcomponent pruning.
Usage:
python test_backend.py --mode rapid # Fast stress test
python test_backend.py --mode realistic # Realistic conversation flow
"""
import argparse
import asyncio
import json
import sys
import time
import traceback
import uuid
from datetime import datetime
from typing import AsyncGenerator, Dict, Any, Optional
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import os
# Add vanna to path
sys.path.insert(0, "../vanna/src")
from vanna.core.rich_component import RichComponent, ComponentLifecycle
from vanna.components.rich import (
RichTextComponent,
StatusCardComponent,
ProgressDisplayComponent,
ProgressBarComponent,
NotificationComponent,
StatusIndicatorComponent,
ButtonComponent,
ButtonGroupComponent,
CardComponent,
TaskListComponent,
Task,
BadgeComponent,
IconTextComponent,
DataFrameComponent,
ChartComponent,
ArtifactComponent,
LogViewerComponent,
LogEntry,
StatusBarUpdateComponent,
TaskTrackerUpdateComponent,
ChatInputUpdateComponent,
TaskOperation,
)
from vanna.servers.base.models import ChatStreamChunk
# Request/Response models
class ChatRequest(BaseModel):
    """Chat request matching vanna API.

    Request body accepted by ``POST /api/vanna/v2/chat_sse``.
    """
    # User text; the SSE handler strips it and dispatches on its content.
    message: str
    # Optional; the SSE handler substitutes a fresh uuid4 when this is missing.
    conversation_id: Optional[str] = None
    # Optional; the SSE handler substitutes a fresh uuid4 when this is missing.
    request_id: Optional[str] = None
    # Extra caller-supplied context; not read in the visible part of the handler.
    request_context: Dict[str, Any] = {}
class UiComponent(BaseModel):
    """UI component wrapper.

    Holds a single ``RichComponent``. NOTE(review): not referenced in the
    visible part of this file — confirm it is still needed.
    """
    # The wrapped rich component.
    rich_component: RichComponent
# Test state
# Process-wide mutable state shared by the test generators and action handler:
#   mode          - pacing mode forwarded to ``delay`` ("realistic" or "rapid")
#   component_ids - component ids recorded by each test, keyed by test name
#   action_count  - number of button actions handled so far
test_state: Dict[str, Any] = dict(
    mode="realistic",
    component_ids={},
    action_count=0,
)
async def yield_chunk(component: RichComponent, conversation_id: str, request_id: str) -> ChatStreamChunk:
    """Wrap a rich component in a ``ChatStreamChunk`` ready for streaming.

    The component is serialized with ``serialize_for_frontend`` before the
    chunk is built, so later mutation of ``component`` does not affect the
    returned chunk's payload. The chunk is stamped with the current
    ``time.time()``.
    """
    payload = component.serialize_for_frontend()
    return ChatStreamChunk(
        rich=payload,
        simple=None,
        conversation_id=conversation_id,
        request_id=request_id,
        timestamp=time.time(),
    )
async def delay(mode: str, short: float = 0.1, long: float = 0.5):
    """Pause for a mode-dependent interval.

    "realistic" sleeps ``long`` seconds, "rapid" sleeps ``short`` seconds, and
    any other mode returns immediately.
    """
    if mode == "realistic":
        seconds = long
    elif mode == "rapid":
        seconds = short
    else:
        return
    await asyncio.sleep(seconds)
async def test_text_component(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Stream a markdown text component, then an ``update`` of its content.

    The component id is recorded under ``test_state["component_ids"]["text"]``.
    Each chunk is followed by a mode-dependent ``delay``.
    """
    component_id = str(uuid.uuid4())
    test_state["component_ids"]["text"] = component_id

    # First render covers headings, emphasis, lists, a code block and a quote.
    initial = RichTextComponent(
        id=component_id,
        content="""# Test Text Component
This component demonstrates **markdown rendering** with various formatting:
## Formatting Examples
- **Bold text** for emphasis
- *Italic text* for style
- `inline code` for snippets
- ~~Strikethrough~~ for deletions
### Lists
1. First ordered item
2. Second ordered item
3. Third ordered item
### Code Block
```python
def hello():
return "Markdown works!"
```
> Blockquote to test quote rendering
This validates that markdown is properly parsed and displayed.""",
        markdown=True,
    )
    yield await yield_chunk(initial, conversation_id, request_id)
    await delay(mode)

    # Second render is derived from the first via ``update``.
    revised = initial.update(content="""# Updated Text Component
Text has been **successfully updated** with new markdown content!
- Update operation works ✓
- Markdown still renders ✓""")
    yield await yield_chunk(revised, conversation_id, request_id)
    await delay(mode)
async def test_status_card(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Walk one status card through the pending, running and completed states.

    The card id is recorded under ``test_state["component_ids"]["status_card"]``.
    Both later states are derived from the initial card via ``set_status``.
    """
    sc_id = str(uuid.uuid4())
    test_state["component_ids"]["status_card"] = sc_id

    pending_card = StatusCardComponent(
        id=sc_id,
        title="Status Card Test",
        status="pending",
        description="Testing status card component...",
        icon="⏳",
        collapsible=True,
        collapsed=False,
    )
    yield await yield_chunk(pending_card, conversation_id, request_id)
    await delay(mode)

    running_card = pending_card.set_status("running", "Processing test...")
    yield await yield_chunk(running_card, conversation_id, request_id)
    await delay(mode)

    # The completed state also swaps the icon to a check mark.
    completed_card = pending_card.set_status("completed", "Test completed successfully!")
    completed_card.icon = "✅"
    yield await yield_chunk(completed_card, conversation_id, request_id)
    await delay(mode)
async def test_progress_display(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Stream a progress display at 0%, 50% and 100% (the last marked "success").

    The id is recorded under ``test_state["component_ids"]["progress_display"]``.
    Intermediate steps use a shorter pause than the final one.
    """
    pid = str(uuid.uuid4())
    test_state["component_ids"]["progress_display"] = pid

    display = ProgressDisplayComponent(
        id=pid,
        label="Test Progress",
        value=0.0,
        description="Starting test...",
        status="info",
        animated=True,
    )
    yield await yield_chunk(display, conversation_id, request_id)
    await delay(mode, short=0.05, long=0.3)

    halfway = display.update_progress(0.5, "Halfway there...")
    yield await yield_chunk(halfway, conversation_id, request_id)
    await delay(mode, short=0.05, long=0.3)

    finished = display.update_progress(1.0, "Complete!")
    finished.status = "success"
    yield await yield_chunk(finished, conversation_id, request_id)
    await delay(mode)
async def test_card_component(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Stream a collapsible markdown card with two action buttons, then update it.

    The id is recorded under ``test_state["component_ids"]["card"]``. The
    update changes status and content and leaves the actions untouched.
    """
    cid = str(uuid.uuid4())
    test_state["component_ids"]["card"] = cid

    card_v1 = CardComponent(
        id=cid,
        title="Test Card with Markdown",
        content="""# Card Content
This card demonstrates **markdown rendering** within cards:
- Interactive action buttons
- Collapsible sections
- Status indicators
- `Formatted text`
Click the buttons below to test interactivity!""",
        icon="🃏",
        status="info",
        markdown=True,
        collapsible=True,
        collapsed=False,
        actions=[
            {"label": "Test Action", "action": "/test-action", "variant": "primary"},
            {"label": "Cancel", "action": "/cancel", "variant": "secondary"},
        ],
    )
    yield await yield_chunk(card_v1, conversation_id, request_id)
    await delay(mode)

    card_v2 = card_v1.update(
        status="success",
        content="""# Card Updated Successfully!
The card content has been **updated** with:
- New status (success)
- New markdown content
- Same action buttons
✓ Update operation verified""",
        markdown=True
    )
    yield await yield_chunk(card_v2, conversation_id, request_id)
    await delay(mode)
async def test_task_list(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Stream a four-task workflow list, advance two tasks, and re-send it as an update.

    The id is recorded under ``test_state["component_ids"]["task_list"]``. The
    same ``Task`` objects are mutated in place and reused in the second
    component, which is tagged ``ComponentLifecycle.UPDATE``.
    """
    list_id = str(uuid.uuid4())
    test_state["component_ids"]["task_list"] = list_id

    workflow = [
        Task(title="Setup development environment", description="Install dependencies and configure tools", status="completed", progress=1.0),
        Task(title="Write test suite", description="Create comprehensive component tests", status="in_progress", progress=0.7),
        Task(title="Run validation", description="Validate all components render correctly", status="pending"),
        Task(title="Prune webcomponent", description="Remove unused code and cruft", status="pending"),
    ]
    initial = TaskListComponent(
        id=list_id,
        title="Webcomponent Validation Workflow",
        tasks=workflow,
        show_progress=True,
        show_timestamps=True,
    )
    yield await yield_chunk(initial, conversation_id, request_id)
    await delay(mode)

    # The second task finishes and the third one starts.
    writing, validating = workflow[1], workflow[2]
    writing.status = "completed"
    writing.progress = 1.0
    validating.status = "in_progress"
    validating.progress = 0.3

    refreshed = TaskListComponent(
        id=list_id,
        title="Webcomponent Validation Workflow (Updated)",
        tasks=workflow,
        show_progress=True,
        show_timestamps=True,
    )
    refreshed.lifecycle = ComponentLifecycle.UPDATE
    yield await yield_chunk(refreshed, conversation_id, request_id)
    await delay(mode)
async def test_progress_bar(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Stream a progress bar at 0.3, then an update to 0.8 with status "success".

    The id is recorded under ``test_state["component_ids"]["progress_bar"]``.
    """
    bid = str(uuid.uuid4())
    test_state["component_ids"]["progress_bar"] = bid

    loading_bar = ProgressBarComponent(
        id=bid,
        value=0.3,
        label="Loading",
        status="info",
    )
    yield await yield_chunk(loading_bar, conversation_id, request_id)
    await delay(mode, short=0.05, long=0.2)

    advanced_bar = loading_bar.update(value=0.8, status="success")
    yield await yield_chunk(advanced_bar, conversation_id, request_id)
    await delay(mode)
async def test_notification(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Stream one notification per level (info, success, warning, error), each with a fresh id."""
    for severity in ("info", "success", "warning", "error"):
        yield await yield_chunk(
            NotificationComponent(
                id=str(uuid.uuid4()),
                message=f"This is a {severity} notification",
                level=severity,
                title=f"{severity.capitalize()} Test",
            ),
            conversation_id,
            request_id,
        )
        await delay(mode, short=0.05, long=0.2)
async def test_status_indicator(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Stream a pulsing "running" indicator, then update it to a static "success".

    The id is recorded under ``test_state["component_ids"]["status_indicator"]``.
    """
    iid = str(uuid.uuid4())
    test_state["component_ids"]["status_indicator"] = iid

    busy = StatusIndicatorComponent(
        id=iid,
        status="running",
        message="Processing...",
        pulse=True,
    )
    yield await yield_chunk(busy, conversation_id, request_id)
    await delay(mode)

    done = busy.update(status="success", message="Done!", pulse=False)
    yield await yield_chunk(done, conversation_id, request_id)
    await delay(mode)
async def test_badge(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Stream a single primary-variant badge with a fresh id."""
    yield await yield_chunk(
        BadgeComponent(
            id=str(uuid.uuid4()),
            text="Test Badge",
            variant="primary",
        ),
        conversation_id,
        request_id,
    )
    await delay(mode)
async def test_icon_text(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Stream a single icon_text component (wrench icon plus label) with a fresh id."""
    yield await yield_chunk(
        IconTextComponent(
            id=str(uuid.uuid4()),
            icon="🔧",
            text="Tool Icon Test",
        ),
        conversation_id,
        request_id,
    )
    await delay(mode)
async def test_buttons(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Stream a standalone button, then a horizontal group of three buttons.

    Each button's ``action`` is a slash-prefixed message, which the SSE handler
    routes to the action handler when clicked.
    """
    solo = ButtonComponent(
        label="Single Button",
        action="/button-test",
        variant="primary",
        icon="🔘",
    )
    yield await yield_chunk(solo, conversation_id, request_id)
    await delay(mode, short=0.05, long=0.2)

    group = ButtonGroupComponent(
        buttons=[
            {"label": "Option 1", "action": "/option1", "variant": "primary"},
            {"label": "Option 2", "action": "/option2", "variant": "secondary"},
            {"label": "Option 3", "action": "/option3", "variant": "success"},
        ],
        orientation="horizontal",
    )
    yield await yield_chunk(group, conversation_id, request_id)
    await delay(mode)
async def test_dataframe(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Test dataframe component with sample data.

    Streams a five-row employee DataFrame built with ``from_records``, then an
    update (same id, ``ComponentLifecycle.UPDATE``) with one extra row. The id
    is recorded under ``test_state["component_ids"]["dataframe"]``.

    Both the initial render and the update pass the same searchable/sortable/
    exportable flags. Otherwise the update would fall back to whatever
    defaults ``from_records`` uses and could silently drop those features.
    """
    dataframe_id = str(uuid.uuid4())
    test_state["component_ids"]["dataframe"] = dataframe_id

    # Interactive features shared by the initial render and the update.
    features = {"searchable": True, "sortable": True, "exportable": True}

    sample_data = [
        {"id": 1, "name": "Alice", "age": 30, "city": "New York", "salary": 75000},
        {"id": 2, "name": "Bob", "age": 25, "city": "San Francisco", "salary": 85000},
        {"id": 3, "name": "Charlie", "age": 35, "city": "Chicago", "salary": 70000},
        {"id": 4, "name": "Diana", "age": 28, "city": "Boston", "salary": 80000},
        {"id": 5, "name": "Eve", "age": 32, "city": "Seattle", "salary": 90000},
    ]
    dataframe = DataFrameComponent.from_records(
        records=sample_data,
        title="📊 Employee Data",
        description="""Sample employee dataset demonstrating **DataFrame** features:
- **Searchable**: Try searching for names or cities
- **Sortable**: Click column headers to sort
- **Exportable**: Export to CSV/Excel
- **Paginated**: Navigate through rows
*5 employees across different cities*""",
        id=dataframe_id,
        **features,
    )
    yield await yield_chunk(dataframe, conversation_id, request_id)
    await delay(mode)

    # Update with one more row under the same id.
    updated_data = sample_data + [
        {"id": 6, "name": "Frank", "age": 29, "city": "Austin", "salary": 78000},
    ]
    dataframe_updated = DataFrameComponent.from_records(
        records=updated_data,
        title="📊 Employee Data (Updated)",
        description="""Dataset **updated** with new employee!
✓ Added Frank from Austin
✓ Now showing 6 employees
✓ Update operation verified""",
        id=dataframe_id,
        **features,
    )
    dataframe_updated.lifecycle = ComponentLifecycle.UPDATE
    yield await yield_chunk(dataframe_updated, conversation_id, request_id)
    await delay(mode)
async def test_chart(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Stream a Plotly bar chart, then replace it with a line chart under the same id.

    The id is recorded under ``test_state["component_ids"]["chart"]``. The
    replacement is tagged ``ComponentLifecycle.UPDATE``.
    """
    cid = str(uuid.uuid4())
    test_state["component_ids"]["chart"] = cid

    # Plotly figure with one bar trace of per-product sales.
    bar_figure = {
        "data": [
            {
                "x": ["Product A", "Product B", "Product C", "Product D"],
                "y": [20, 35, 30, 25],
                "type": "bar",
                "name": "Sales",
                "marker": {"color": "#667eea"},
            }
        ],
        "layout": {
            "title": "Product Sales",
            "xaxis": {"title": "Products"},
            "yaxis": {"title": "Sales (units)"},
        },
    }
    bar_chart = ChartComponent(
        id=cid,
        chart_type="bar",
        data=bar_figure,
        title="Sales Chart",
    )
    yield await yield_chunk(bar_chart, conversation_id, request_id)
    await delay(mode)

    # Plotly figure with a scatter trace drawn as lines with markers.
    line_figure = {
        "data": [
            {
                "x": ["Jan", "Feb", "Mar", "Apr", "May"],
                "y": [10, 15, 13, 17, 21],
                "type": "scatter",
                "mode": "lines+markers",
                "name": "Revenue",
                "line": {"color": "#10b981", "width": 3},
            }
        ],
        "layout": {
            "title": "Monthly Revenue Trend",
            "xaxis": {"title": "Month"},
            "yaxis": {"title": "Revenue ($1000s)"},
        },
    }
    line_chart = ChartComponent(
        id=cid,
        chart_type="line",
        data=line_figure,
        title="Revenue Chart",
    )
    line_chart.lifecycle = ComponentLifecycle.UPDATE
    yield await yield_chunk(line_chart, conversation_id, request_id)
    await delay(mode)
async def test_artifact(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Test artifact component with HTML/SVG content.

    Streams a fullscreen-capable SVG artifact that draws concentric circles.
    The id is recorded under ``test_state["component_ids"]["artifact"]``.

    The previous content was an empty string (``''''''``), so the artifact
    rendered nothing and never exercised SVG rendering.
    """
    artifact_id = str(uuid.uuid4())
    test_state["component_ids"]["artifact"] = artifact_id

    # Four concentric circles sharing one centre, from largest to smallest.
    svg_content = '''<svg xmlns="http://www.w3.org/2000/svg" width="200" height="200" viewBox="0 0 200 200">
  <circle cx="100" cy="100" r="90" fill="#667eea" />
  <circle cx="100" cy="100" r="65" fill="#764ba2" />
  <circle cx="100" cy="100" r="40" fill="#10b981" />
  <circle cx="100" cy="100" r="15" fill="#ffffff" />
</svg>'''
    artifact = ArtifactComponent(
        id=artifact_id,
        content=svg_content,
        artifact_type="svg",
        title="SVG Circle Visualization",
        description="Concentric circles demonstration",
        fullscreen_capable=True,
    )
    yield await yield_chunk(artifact, conversation_id, request_id)
    await delay(mode)
async def test_log_viewer(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Stream a searchable, auto-scrolling log viewer and append three entries.

    The id is recorded under ``test_state["component_ids"]["log_viewer"]``.
    Each append calls ``add_entry`` on the most recent viewer and streams the
    result.
    """
    viewer_id = str(uuid.uuid4())
    test_state["component_ids"]["log_viewer"] = viewer_id

    # Seed the viewer with three info-level startup entries.
    viewer = LogViewerComponent(
        id=viewer_id,
        title="System Logs",
        entries=[
            LogEntry(message="System started", level="info"),
            LogEntry(message="Loading configuration...", level="info"),
            LogEntry(message="Configuration loaded successfully", level="info"),
        ],
        searchable=True,
        auto_scroll=True,
    )
    yield await yield_chunk(viewer, conversation_id, request_id)
    await delay(mode, short=0.05, long=0.3)

    viewer = viewer.add_entry("Memory usage at 75%", level="warning")
    yield await yield_chunk(viewer, conversation_id, request_id)
    await delay(mode, short=0.05, long=0.3)

    # Error entry carrying a structured ``data`` payload.
    viewer = viewer.add_entry("Connection timeout", level="error", data={"host": "api.example.com", "port": 443})
    yield await yield_chunk(viewer, conversation_id, request_id)
    await delay(mode, short=0.05, long=0.3)

    # Recovery message, logged at info level.
    viewer = viewer.add_entry("Reconnected successfully", level="info")
    yield await yield_chunk(viewer, conversation_id, request_id)
    await delay(mode)
async def test_ui_state_updates(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Exercise the UI-state components in the following order:
    - a status-bar message;
    - three task-tracker additions;
    - an update completing the second task;
    - a final status-bar message;
    - a chat-input placeholder change.
    """
    running_bar = StatusBarUpdateComponent(
        message="Running comprehensive component test...",
        status="info",
    )
    yield await yield_chunk(running_bar, conversation_id, request_id)
    await delay(mode, short=0.1, long=0.3)

    # Add tasks to the task tracker one at a time.
    text_task = Task(
        title="Validate Text Components",
        description="Test text, markdown, and formatting",
        status="completed",
        progress=1.0,
    )
    add_text = TaskTrackerUpdateComponent.add_task(text_task)
    yield await yield_chunk(add_text, conversation_id, request_id)
    await delay(mode, short=0.1, long=0.3)

    data_task = Task(
        title="Validate Data Components",
        description="Test DataFrame, Chart, Code blocks",
        status="in_progress",
        progress=0.6,
    )
    add_data = TaskTrackerUpdateComponent.add_task(data_task)
    yield await yield_chunk(add_data, conversation_id, request_id)
    await delay(mode, short=0.1, long=0.3)

    interactive_task = Task(
        title="Validate Interactive Components",
        description="Test buttons, actions, and UI state",
        status="pending",
    )
    add_interactive = TaskTrackerUpdateComponent.add_task(interactive_task)
    yield await yield_chunk(add_interactive, conversation_id, request_id)
    await delay(mode, short=0.1, long=0.3)

    # Mark the data task as completed, addressing it by its id.
    finish_data = TaskTrackerUpdateComponent(
        operation=TaskOperation.UPDATE_TASK,
        task_id=data_task.id,
        status="completed",
        progress=1.0,
    )
    yield await yield_chunk(finish_data, conversation_id, request_id)
    await delay(mode, short=0.1, long=0.3)

    success_bar = StatusBarUpdateComponent(
        message="All components validated successfully!",
        status="success",
    )
    yield await yield_chunk(success_bar, conversation_id, request_id)
    await delay(mode, short=0.1, long=0.3)

    # Change the chat input placeholder and keep the input enabled.
    input_update = ChatInputUpdateComponent(
        placeholder="Type a message to test chat input updates...",
        disabled=False,
    )
    yield await yield_chunk(input_update, conversation_id, request_id)
    await delay(mode)
async def run_comprehensive_test(conversation_id: str, request_id: str, mode: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Run all component tests.

    Streams an introduction, then every component suite in a fixed order, and
    finally a completion status card. ``mode`` is forwarded to each suite and
    to ``delay`` to control pacing.
    """
    # Introduction. The Data category lists only components that have a suite
    # below; code blocks are not tested (see the NOTE next to ``suites``).
    intro = RichTextComponent(
        content=f"""# 🧪 Comprehensive Component Test
**Mode**: {mode}
## Test Coverage
This test validates **16 component types** supported by the webcomponent:
- ✅ Component creation
- ✅ Incremental updates
- ✅ Markdown rendering
- ✅ Interactive actions
- ✅ Data visualization
### Component Categories
1. **Primitive**: Text, Badge, Icon Text
2. **Feedback**: Status Card, Progress, Notifications, Logs
3. **Data**: Card, Task List, DataFrame, Chart
4. **Specialized**: Artifact (SVG/HTML)
5. **Interactive**: Buttons with actions
Watch the sidebar checklist as components render! ➡️""",
        markdown=True,
    )
    yield await yield_chunk(intro, conversation_id, request_id)
    await delay(mode)

    # Component suites, streamed in this order. Each one is an async generator
    # taking (conversation_id, request_id, mode).
    # NOTE: Table, Container, and CodeBlock components are defined in vanna Python package
    # but NOT supported by the webcomponent (no renderers). Skipping these tests.
    # These are candidates for removal from the vanna package.
    suites = (
        test_text_component,
        test_status_card,
        test_progress_display,
        test_card_component,
        test_task_list,
        test_progress_bar,
        test_notification,
        test_status_indicator,
        test_badge,
        test_icon_text,
        test_buttons,
        test_dataframe,
        test_chart,
        test_artifact,
        test_log_viewer,
        test_ui_state_updates,
    )
    for suite in suites:
        async for chunk in suite(conversation_id, request_id, mode):
            yield chunk

    # Completion message
    done = StatusCardComponent(
        title="✅ Test Suite Complete",
        status="completed",
        description=f"""All **16 component types** successfully rendered in **{mode}** mode!
**Validated:**
- Component creation & updates
- Markdown rendering
- Interactive buttons
- Data visualization
- UI state management
Check the sidebar for the complete checklist.""",
        icon="✅",
    )
    yield await yield_chunk(done, conversation_id, request_id)
async def handle_action_message(message: str, conversation_id: str, request_id: str) -> AsyncGenerator[ChatStreamChunk, None]:
    """Respond to a button action message.

    Bumps the shared ``test_state["action_count"]`` counter, then streams two
    components back to the client: a success notification naming the action,
    followed by a card that echoes the raw action text.

    Args:
        message: The action text that was received.
        conversation_id: Conversation the emitted chunks belong to.
        request_id: Request the emitted chunks belong to.

    Yields:
        The chunks produced by ``yield_chunk`` for each component.
    """
    test_state["action_count"] += 1
    action_number = test_state["action_count"]

    notification = NotificationComponent(
        message=f"Action received: {message}",
        level="success",
        title=f"Action #{action_number}",
    )
    yield await yield_chunk(notification, conversation_id, request_id)

    # A second, card-style component gives a more detailed confirmation.
    details_card = CardComponent(
        title="Action Handler Response",
        content=f"Received action: `{message}`\n\nThis confirms button interactivity is working!",
        icon="🎯",
        status="success",
    )
    yield await yield_chunk(details_card, conversation_id, request_id)
# FastAPI application that hosts the SSE chat endpoint, health check and test page.
app = FastAPI(title="Vanna Webcomponent Test Backend")

# Fully permissive CORS: any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins plus credentials is only reasonable for a local
# test backend like this one.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Expose ./static (next to this file) under /static, but only when it exists.
static_path = os.path.join(os.path.dirname(__file__), "static")
if os.path.exists(static_path):
    app.mount("/static", StaticFiles(directory=static_path), name="static")
@app.post("/api/vanna/v2/chat_sse")
async def chat_sse(chat_request: ChatRequest) -> StreamingResponse:
    """Stream the chat response to the client as Server-Sent Events.

    Routing is based on the whitespace-stripped message:
      * starts with ``/`` and is not ``/test`` -> ``handle_action_message``
      * equals ``/test`` or contains "test" (case-insensitive)
        -> ``run_comprehensive_test`` using the current ``test_state["mode"]``
      * anything else -> a markdown RichText echo with a usage hint

    Every chunk is sent as a ``data: <json>`` event, and a successful stream
    ends with ``data: [DONE]``. If an exception is raised, the error and its
    traceback are printed and sent as an ``error`` event; ``[DONE]`` is not
    sent in that case. Missing conversation/request ids are replaced with
    fresh UUID4 strings.
    """
    conversation_id = chat_request.conversation_id or str(uuid.uuid4())
    request_id = chat_request.request_id or str(uuid.uuid4())
    message = chat_request.message.strip()

    def sse_event(payload: str) -> str:
        # One SSE message: a "data:" line terminated by a blank line.
        return f"data: {payload}\n\n"

    async def generate() -> AsyncGenerator[str, None]:
        """Yield SSE-framed strings for this request."""
        try:
            # Pick the async chunk source; None means "plain echo reply".
            if message.startswith("/") and message != "/test":
                source = handle_action_message(message, conversation_id, request_id)
            elif message == "/test" or "test" in message.lower():
                source = run_comprehensive_test(conversation_id, request_id, test_state["mode"])
            else:
                source = None

            if source is None:
                echo = RichTextComponent(
                    content=f"You said: {message}\n\nType `/test` to run the comprehensive component test.",
                    markdown=True,
                )
                chunk = await yield_chunk(echo, conversation_id, request_id)
                yield sse_event(chunk.model_dump_json())
            else:
                async for chunk in source:
                    yield sse_event(chunk.model_dump_json())

            yield "data: [DONE]\n\n"
        except Exception as e:
            error_message = f"{str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            print(f"ERROR in chat_sse: {error_message}")  # Log to console
            error_chunk = {
                "type": "error",
                "data": {"message": error_message},
                "conversation_id": conversation_id,
                "request_id": request_id,
            }
            yield sse_event(json.dumps(error_chunk))

    # Headers disable caching and proxy buffering so events reach the client immediately.
    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
@app.get("/health")
async def health():
    """Liveness check that also reports the active test mode."""
    current_mode = test_state["mode"]
    return {"status": "ok", "mode": current_mode}
@app.get("/")
async def root():
    """Serve ``test-comprehensive.html`` from this file's directory.

    When that page is missing, return a JSON description of the backend
    (name, current test mode, and available endpoints) instead.
    """
    page_path = os.path.join(os.path.dirname(__file__), "test-comprehensive.html")
    if not os.path.exists(page_path):
        return {
            "message": "Vanna Webcomponent Test Backend",
            "mode": test_state["mode"],
            "endpoints": {
                "chat": "POST /api/vanna/v2/chat_sse",
                "health": "GET /health",
            },
        }
    return FileResponse(page_path)
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Test backend for vanna-webcomponent")
    cli.add_argument(
        "--mode",
        choices=["rapid", "realistic"],
        default="realistic",
        help="Test mode: rapid (fast) or realistic (with delays)",
    )
    cli.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    cli.add_argument("--port", type=int, default=5555, help="Port to bind to")
    options = cli.parse_args()

    # Publish the chosen mode through the shared state the handlers read.
    test_state["mode"] = options.mode

    banner = (
        f"Starting test backend in {options.mode} mode...",
        f"Server running at http://{options.host}:{options.port}",
        "Send message '/test' to run comprehensive component test",
    )
    for banner_line in banner:
        print(banner_line)

    # uvicorn is only needed when running this file directly.
    import uvicorn

    uvicorn.run(app, host=options.host, port=options.port)
================================================
FILE: frontends/webcomponent/tsconfig.json
================================================
{
"compilerOptions": {
"target": "ES2020",
"lib": ["ES2020", "DOM", "DOM.Iterable"],
"module": "ESNext",
"skipLibCheck": true,
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"resolveJsonModule": true,
"isolatedModules": true,
"noEmit": true,
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"noFallthroughCasesInSwitch": true,
"experimentalDecorators": true,
"useDefineForClassFields": false
},
"include": ["src"]
}
================================================
FILE: frontends/webcomponent/vite.config.ts
================================================
import { defineConfig } from 'vite';
/**
 * Compile-time constants substituted into the bundle by Vite's `define`.
 * Values are JSON-encoded so they are inlined as string literals.
 */
const buildConstants = {
  __BUILD_TIME__: JSON.stringify(new Date().toISOString()),
  __BUILD_VERSION__: JSON.stringify(process.env.npm_package_version || '1.0.0'),
};

/**
 * Vite configuration: builds `src/index.ts` as a single ES-module library file
 * (`dist/vanna-components.js`). The preview server is pinned to port 9876.
 */
export default defineConfig({
  define: buildConstants,
  build: {
    outDir: 'dist',
    lib: {
      entry: 'src/index.ts',
      formats: ['es'],
      fileName: () => 'vanna-components.js',
    },
    rollupOptions: {
      // No `external` entry, so lit is bundled together with the components.
      // external: /^lit/,
    },
  },
  preview: {
    port: 9876,
    strictPort: true,
  },
});
================================================
FILE: notebooks/quickstart.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Install the Package\n",
"Here we're installing the released `vanna` package from PyPI together with the optional extras this notebook uses."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install 'vanna[fastapi,anthropic]'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Download a Sample Database"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import httpx\n",
"\n",
"with open(\"Chinook.sqlite\", \"wb\") as f:\n",
" with httpx.stream(\"GET\", \"https://vanna.ai/Chinook.sqlite\") as response:\n",
" for chunk in response.iter_bytes():\n",
" f.write(chunk)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from vanna import Agent, AgentConfig\n",
"from vanna.servers.fastapi import VannaFastAPIServer\n",
"from vanna.core.registry import ToolRegistry\n",
"from vanna.core.user import UserResolver, User, RequestContext\n",
"from vanna.integrations.anthropic import AnthropicLlmService\n",
"from vanna.tools import RunSqlTool, VisualizeDataTool\n",
"from vanna.integrations.sqlite import SqliteRunner\n",
"from vanna.tools.agent_memory import SaveQuestionToolArgsTool, SearchSavedCorrectToolUsesTool\n",
"from vanna.integrations.local.agent_memory import DemoAgentMemory\n",
"from vanna.capabilities.sql_runner import RunSqlToolArgs\n",
"from vanna.tools.visualize_data import VisualizeDataArgs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Define your User Authentication\n",
"Here we're going to say that if you're logged in as `admin@example.com` then you're in the `admin` group, otherwise you're in the `user` group"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SimpleUserResolver(UserResolver):\n",
" async def resolve_user(self, request_context: RequestContext) -> User:\n",
" # In production, validate cookies/JWTs here\n",
" user_email = request_context.get_cookie('vanna_email')\n",
" if not user_email:\n",
" raise ValueError(\"Missing 'vanna_email' cookie for user identification\")\n",
" \n",
" print(f\"Resolving user for email: {user_email}\")\n",
"\n",
" if user_email == \"admin@example.com\":\n",
" return User(id=\"admin1\", email=user_email, group_memberships=['admin'])\n",
" \n",
" return User(id=\"user1\", email=user_email, group_memberships=['user'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Define the Tools and Access Control"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tools = ToolRegistry()\n",
"tools.register_local_tool(RunSqlTool(sql_runner=SqliteRunner(database_path=\"./Chinook.sqlite\")), access_groups=['admin', 'user'])\n",
"tools.register_local_tool(VisualizeDataTool(), access_groups=['admin', 'user'])\n",
"agent_memory = DemoAgentMemory(max_items=1000)\n",
"tools.register_local_tool(SaveQuestionToolArgsTool(), access_groups=['admin'])\n",
"tools.register_local_tool(SearchSavedCorrectToolUsesTool(), access_groups=['admin', 'user'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set up LLM\n",
"llm = AnthropicLlmService(model=\"claude-sonnet-4-5\", api_key=\"sk-ant-...\")\n",
"\n",
"# Create agent with your options\n",
"agent = Agent(\n",
" llm_service=llm,\n",
" tool_registry=tools,\n",
" user_resolver=SimpleUserResolver(),\n",
" config=AgentConfig(),\n",
" agent_memory=agent_memory\n",
")\n",
"\n",
"# 4. Create and run server\n",
"server = VannaFastAPIServer(agent)\n",
"server.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: papers/ai-sql-accuracy-2023-08-17.md
================================================
# AI SQL Accuracy: Testing different LLMs + context strategies to maximize SQL generation accuracy
_2023-08-17_
## TLDR
The promise of having an autonomous AI agent that can answer business users’ plain English questions is an attractive but thus far elusive proposition. Many have tried, with limited success, to get ChatGPT to write SQL. The failure is primarily due to the LLM's lack of knowledge of the particular dataset it’s being asked to query.
In this paper, **we show that context is everything, and with the right context, we can get from ~3% accuracy to ~80% accuracy**. We go through three different context strategies, and showcase one that is the clear winner - where we combine schema definitions, documentation, and prior SQL queries with a relevance search.
We also compare a few different LLMs - including Google Bison, GPT 3.5, GPT 4, and a brief attempt with Llama 2. While **GPT 4 takes the crown of the best overall LLM for generating SQL**, Google’s Bison is roughly equivalent when enough context is provided.
Finally, we show how you can use the methods demonstrated here to generate SQL for your database.
Here's a summary of our key findings -

## Table of Contents
* [Why use AI to generate SQL?](#why-use-ai-to-generate-sql)
* [Setting up architecture of the test](#setting-up-architecture-of-the-test)
* [Setting up the test levers](#setting-up-the-test-levers)
* [Choosing a dataset](#choosing-a-dataset)
* [Choosing the questions](#choosing-the-questions)
* [Choosing the prompt](#choosing-the-prompt)
* [Choosing the LLMs (Foundational models)](#choosing-the-llms-foundational-models)
* [Choosing the context](#choosing-the-context)
* [Using ChatGPT to generate SQL](#using-chatgpt-to-generate-sql)
* [Using schema only](#using-schema-only)
* [Using SQL examples](#using-sql-examples)
* [Using contextually relevant examples](#using-contextually-relevant-examples)
* [Analyzing the results](#analyzing-the-results)
* [Next steps to getting accuracy even higher](#next-steps-to-getting-accuracy-even-higher)
* [Use AI to write SQL for your dataset](#use-ai-to-write-sql-for-your-dataset)
## Why use AI to generate SQL?
Many organizations have now adopted some sort of data warehouse or data lake - a repository of a lot of the organization’s critical data that is queryable for analytical purposes. This ocean of data is brimming with potential insights, but only a small fraction of people in an enterprise have the two skills required to harness the data —
1. A solid comprehension of **advanced SQL**, and
2. A comprehensive knowledge of the **organization’s unique data structure & schema**
The number of people with both of the above is not only vanishingly small, but likely not the same people that have the majority of the questions.
**So what actually happens inside organizations?** Business users, like product managers, sales managers, and executives, have data questions that will inform business decisions and strategy. They’ll first check dashboards, but most questions are ad hoc and specific, and the answers aren’t available, so they’ll ask a data analyst or engineer - whoever possesses the combination of skills above. These people are busy, and take a while to get to the request, and as soon as they get an answer, the business user has follow up questions.
**This process is painful** for both the business user (long lead times to get answers) and the analyst (distracts from their main projects), and leads to many potential insights being lost.

**Enter generative AI!** LLMs potentially give the opportunity to business users to query the database in plain English (with the LLMs doing the SQL translation), and we have heard from dozens of companies that this would be a game changer for their data teams and even their businesses.
**The key challenge is generating accurate SQL for complex and messy databases**. Plenty of people we’ve spoken with have tried to use ChatGPT to write SQL with limited success and a lot of pain. Many have given up and reverted back to the old fashioned way of manually writing SQL. At best, ChatGPT is a sometimes useful co-pilot for analysts to get syntax right.
**But there’s hope!** We’ve spent the last few months immersed in this problem, trying various models, techniques and approaches to improve the accuracy of SQL generated by LLMs. In this paper, we show the performance of various LLMs and how the strategy of providing contextually relevant correct SQL to the LLM can allow the LLM to **achieve extremely high accuracy**.
## Setting up architecture of the test
First, we needed to define the architecture of the test. A rough outline of the five-step process is below, with _pseudo code_ for each step -

1. **Question** - We start with the business question.
```python
question = "how many clients are there in germany"
```
2. **Prompt** - We create the prompt to send to the LLM.
```python
prompt = f"""
Write a SQL statement for the following question:
{question}
"""
```
3. **Generate SQL** - Using an API, we’ll send the prompt to the LLM and get back generated SQL.
```python
sql = llm.api(api_key=api_key, prompt=prompt, parameters=parameters)
```
4. **Run SQL** - We'll run the SQL against the database.
```python
df = db.conn.execute(sql)
```
5. **Validate results** - Finally, we’ll validate that the results are in line with what we expect.
There are some shades of grey when it comes to the results, so we evaluated them manually. You can see those results [here](https://github.com/vanna-ai/research/blob/main/data/sec_evaluation_data_tagged.csv).
## Setting up the test levers
Now that we have our experiment set up, we’ll need to figure out what levers would impact accuracy, and what our test set would be. We tried two levers (the LLMs and the training data used), and we ran on 20 questions that made up our test set. So we ran a total of 3 LLMs x 3 context strategies x 20 questions = 180 individual trials in this experiment.

### Choosing a dataset
First, we need to **choose an appropriate dataset** to try. We had a few guiding principles -
1. **Representative**. Datasets in enterprises are often complex and this complexity isn’t captured in many demo / sample datasets. We want to use a complicated database with real-world use cases and real-world data.
2. **Accessible**. We also wanted that dataset to be publicly available.
3. **Understandable**. The dataset should be somewhat understandable to a wide audience - anything too niche or technical would be difficult to decipher.
4. **Maintained**. We’d prefer a dataset that’s maintained and updated properly, in reflection of a real database.
A dataset that we found that met the criteria above was the Cybersyn SEC filings dataset, which is available for free on the Snowflake marketplace:
https://docs.cybersyn.com/our-data-products/economic-and-financial/sec-filings
### Choosing the questions
Next, we need to **choose the questions**. Here are some sample questions (see them all in this [file](https://github.com/vanna-ai/research/blob/main/data/questions_sec.csv)) -
1. How many companies are there in the dataset?
2. What annual measures are available from the 'ALPHABET INC.' Income Statement?
3. What are the quarterly 'Automotive sales' and 'Automotive leasing' for Tesla?
4. How many Chipotle restaurants are there currently?
Now that we have the dataset + questions, we’ll need to come up with the levers.
### Choosing the prompt
For the **prompt**, for this run, we are going to hold the prompt constant, though we’ll do a follow up which varies the prompt.
### Choosing the LLMs (Foundational models)
For the **LLMs** to test, we’ll try the following -
1. [**Bison (Google)**](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models) - Bison is the version of [PaLM 2](https://blog.google/technology/ai/google-palm-2-ai-large-language-model/) that’s available via GCP APIs.
2. [**GPT 3.5 Turbo (OpenAI)**](https://platform.openai.com/docs/models/gpt-3-5) - GPT 3.5 until recently was the flagship OpenAI model despite 4 being available because of latency and cost benefits, and not a huge accuracy difference (well - we’ll put that to the test) especially for basic tasks.
3. [**GPT 4 (OpenAI)**](https://platform.openai.com/docs/models/gpt-4) - The more advanced, but slower and more expensive, OpenAI model. GPT 4 is multi-modal, though we won’t be using that functionality.
4. [**Llama 2 (Meta)**](https://ai.meta.com/llama/) - We really wanted to include an open source model - and the leading one is Meta’s Llama 2. But our setup through [Replicate](https://replicate.com/replicate/llama-2-70b-chat) quickly conked out, and we weren’t able to fix it in time for publishing this. In our early runs, when it did work, we found performance to be mediocre at best.
### Choosing the context
Finally, we’ll have three types of **context**. Context refers to what we send to the LLM that helps give the LLM context on our specific dataset.
1. **Schema only**. We put the schema (using DDL) in the context window.
2. **Static examples**. We put static example SQL queries in the context window.
3. **Contextually relevant examples**. Finally, we put the most relevant context (SQL / DDL / documentation) into the context window, finding it via a vector search based on embeddings.
## Using ChatGPT to generate SQL
This section is a bit unfair to ChatGPT, but it’s a helpful place to get started. Let’s just take a question and ask ChatGPT to generate the SQL -
**Prompt**
```
Write Snowflake SQL to answer the following question -
What are the quarterly revenues by business segment for Alphabet?
Return SQL only
```
**Response**
Certainly, here's the SQL query:
```sql
SELECT
EXTRACT(QUARTER FROM date_column) AS quarter,
business_segment,
SUM(revenue) AS quarterly_revenue
FROM
revenue_table
WHERE
company_name = 'Alphabet'
GROUP BY
quarter,
business_segment
ORDER BY
quarter,
business_segment;
-- Replace date_column, business_segment, revenue, revenue_table, and adjust any other identifiers to match your specific database schema.
```
Inevitably, when we try to run this, we hit an error, because it made up the table name since we didn't provide it -

Of course, we are being unfair to the LLMs - as magical as they are, they cannot (unfortunately? luckily?) possibly know what’s in our database - yet. So let’s hop into the tests where we give more context.
## Using schema only
First, we take the schema of the dataset and put it into the context window. This is usually what we've seen people do with ChatGPT or in tutorials.
An example prompt may look like this (in reality we used the information schema because of how Snowflake shares work but this shows the principle) -
```
The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.
Respond with only SQL code. Do not answer with any explanations -- just the code.
You may use the following DDL statements as a reference for what tables might be available.
CREATE TABLE Table1...
CREATE TABLE Table2...
CREATE TABLE Table3...
```
The results were, in a word, terrible. Of the 60 attempts (20 questions x 3 models), only two questions were answered correctly (both by GPT 4), **for an abysmal accuracy rate of 3%**. Here are the two questions that GPT 4 managed to get right -
1. What are the top 10 measure descriptions by frequency?
2. What are the distinct statements in the report attributes?

It’s evident that by just using the schema, we don’t get close to meeting the bar of a helpful AI SQL agent, though it may be somewhat useful in being an analyst copilot.
## Using SQL examples
If we put ourselves in the shoes of a human who’s exposed to this dataset for the first time, in addition to the table definitions, they’d first look at the example queries to see _how_ to query the database correctly.
These queries can give additional context not available in the schema - for example, which columns to use, how tables join together, and other intricacies of querying that particular dataset.
Cybersyn, as with other data providers on the Snowflake marketplace, provides a few (in this case 3) example queries in their documentation. Let’s include these in the context window.
By providing just those 3 example queries, we see substantial improvements to the correctness of the SQL generated. However, this accuracy greatly varies by the underlying LLM. It seems that GPT-4 is the most able to generalize the example queries in a way that generates the most accurate SQL.

## Using contextually relevant examples
Enterprise data warehouses often contain 100s (or even 1000s) of tables, and an order of magnitude more queries that cover all the use cases within their organizations. Given the limited size of the context windows of modern LLMs, we can’t just shove all the prior queries and schema definitions into the prompt.
Our final approach to context is a more sophisticated ML approach - load embeddings of prior queries and the table schemas into a vector database, and only choose the most relevant queries / tables to the question asked. Here's a diagram of what we are doing - note the contextual relevance search in the green box -

By surfacing the most relevant examples of those SQL queries to the LLM, we can drastically improve performance of even the less capable LLMs. Here, we give the LLM the 10 most relevant SQL query examples for the question (from a list of 30 examples stored), and accuracy rates skyrocket.

We can improve performance even more by maintaining a history of SQL statements that were executable and correctly answer actual questions that users have had.
## Analyzing the results
It’s clear that the biggest difference is not in the type of LLM, but rather in the strategy employed to give the appropriate context to the LLM (e.g., the “training data” used).

When looking at SQL accuracy by context strategy, it’s clear that this is what makes the difference. We go from ~3% accurate using just the schema, to ~80% accurate when intelligently using contextual examples.

There are still interesting trends with the LLMs themselves. While Bison starts out at the bottom of the heap in both the Schema and Static context strategies, it rockets to the top with a full Contextual strategy. Averaged across the three strategies, **GPT 4 takes the crown as the best LLM for SQL generation**.

## Next steps to getting accuracy even higher
We'll soon do a follow up on this analysis to get even deeper into accurate SQL generation. Some next steps are -
1. **Use other datasets**: We'd love to try this on other, real world, enterprise datasets. What happens when you get to 100 tables? 1000 tables?
2. **Add more training data**: While 30 queries is great, what happens when you 10x, 100x that number?
3. **Try more databases**: This test was run on a Snowflake database, but we've also gotten this working on BigQuery, Postgres, Redshift, and SQL Server.
4. **Experiment with more foundational models:** We are close to being able to use Llama 2, and we'd love to try other LLMs.
We have some anecdotal evidence for the above but we'll be expanding and refining our tests to include more of these items.
## Use AI to write SQL for your dataset
While the SEC data is a good start, you must be wondering whether this could be relevant for your data and your organization. We’re building a [Python package](https://vanna.ai) that can generate SQL for your database as well as additional functionality like being able to generate Plotly code for the charts, follow-up questions, and various other functions.
Here's an overview of how it works
```python
import vanna as vn
```
1. **Train Using Schema**
```python
vn.train(ddl="CREATE TABLE ...")
```
2. **Train Using Documentation**
```python
vn.train(documentation="...")
```
3. **Train Using SQL Examples**
```python
vn.train(sql="SELECT ...")
```
4. **Generating SQL**
The easiest way to use Vanna out of the box is `vn.ask(question="What are the ...")`, which will return the SQL, table, and chart as you can see in this [example notebook](https://vanna.ai/docs/getting-started.html). `vn.ask` is a wrapper around `vn.generate_sql`, `vn.run_sql`, `vn.generate_plotly_code`, `vn.get_plotly_figure`, and `vn.generate_followup_questions`. It uses optimized context to generate SQL for your question, with Vanna calling the LLM for you.
Alternatively, you can use `vn.get_related_training_data(question="What are the ...")` as shown in this [notebook](https://github.com/vanna-ai/research/blob/main/notebooks/test-cybersyn-sec.ipynb) to retrieve the most relevant context, which you can then use to construct your own prompt to send to any LLM.
This [notebook](https://github.com/vanna-ai/research/blob/main/notebooks/train-cybersyn-sec-3.ipynb) shows an example of how the "Static" context strategy was used to train Vanna on the Cybersyn SEC dataset.
## A note on nomenclature
* **Foundational Model**: This is the underlying LLM
* **Context Model (aka Vanna Model)**: This is a layer that sits on top of the LLM and provides context to the LLM
* **Training**: Generally when we refer to "training" we're talking about training the context model.
## Contact Us
Ping us on [Slack](https://join.slack.com/t/vanna-ai/shared_invite/zt-1unu0ipog-iE33QCoimQiBDxf2o7h97w), [Discord](https://discord.com/invite/qUZYKHremx), or [set up a 1:1 call](https://calendly.com/d/y7j-yqq-yz4/meet-with-both-vanna-co-founders) if you have any issues.
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["flit_core >=3.2,<4"]
build-backend = "flit_core.buildapi"
[project]
name = "vanna"
version = "2.0.2"
authors = [
{ name="Zain Hoda", email="zain@vanna.ai" },
]
description = "Generate SQL queries from natural language"
readme = "README.md"
requires-python = ">=3.9"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = [
"pydantic>=2.0.0",
"click>=8.0.0",
"pandas",
"httpx>=0.28.0",
"PyYAML",
"plotly",
"tabulate",
"sqlparse",
"sqlalchemy",
"requests",
]
[project.scripts]
vanna = "vanna.servers.cli.server_runner:main"
[project.urls]
"Homepage" = "https://github.com/vanna-ai/vanna"
"Bug Tracker" = "https://github.com/vanna-ai/vanna/issues"
[project.optional-dependencies]
flask = ["flask>=2.0.0", "flask-cors>=4.0.0"]
fastapi = ["fastapi>=0.68.0", "uvicorn>=0.15.0"]
servers = ["vanna[flask,fastapi]"]
postgres = ["psycopg2-binary", "db-dtypes"]
mysql = ["PyMySQL"]
clickhouse = ["clickhouse_connect"]
bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
google = ["google-generativeai", "google-cloud-aiplatform"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb>=1.1.0", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client>=1.0.0", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres", "langchain-community", "langchain-huggingface", "xinference-client"]
test = ["pytest>=7.0.0", "pytest-asyncio>=0.21.0", "pytest-mock>=3.10.0", "pytest-cov>=4.0.0", "tox>=4.0.0"]
dev = ["pytest>=7.0.0", "pytest-asyncio>=0.21.0", "pytest-mock>=3.10.0", "pytest-cov>=4.0.0", "tox>=4.0.0", "mypy", "ruff", "pandas-stubs", "plotly-stubs", "types-PyYAML", "types-requests", "types-tabulate"]
chromadb = ["chromadb>=1.1.0"]
openai = ["openai"]
azureopenai = ["openai", "azure-identity"]
qianfan = ["qianfan"]
mistralai = ["mistralai>=1.0.0"]
anthropic = ["anthropic"]
gemini = ["google-genai"]
marqo = ["marqo"]
zhipuai = ["zhipuai"]
ollama = ["ollama", "httpx"]
qdrant = ["qdrant-client>=1.0.0", "fastembed"]
vllm = ["vllm"]
pinecone = ["pinecone", "fastembed"]
opensearch = ["opensearch-py", "opensearch-dsl", "langchain-community", "langchain-huggingface"]
hf = ["transformers"]
milvus = ["pymilvus[model]"]
bedrock = ["boto3", "botocore"]
weaviate = ["weaviate-client"]
azuresearch = ["azure-search-documents", "azure-identity", "azure-common", "fastembed"]
pgvector = ["langchain-postgres>=0.0.12"]
faiss-cpu = ["faiss-cpu"]
faiss-gpu = ["faiss-gpu"]
xinference-client = ["xinference-client"]
oracle = ["oracledb", "chromadb<1.0.0"]
hive = ["pyhive", "thrift"]
presto = ["pyhive", "thrift"]
mssql = ["pyodbc"]
[tool.flit.module]
name = "vanna"
path = "src/vanna"
[tool.flit.sdist]
exclude = [
"frontends/",
"tests/",
"notebooks/",
".github/",
"tox.ini",
]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
markers = [
"integration: marks tests as integration tests (deselect with '-m \"not integration\"')",
"anthropic: marks tests requiring Anthropic API key",
"openai: marks tests requiring OpenAI API key",
"azureopenai: marks tests requiring Azure OpenAI API key",
"gemini: marks tests requiring Gemini API key",
"ollama: marks tests requiring local Ollama instance",
"legacy: marks tests for legacy adapter",
"slow: marks tests as slow running",
"postgres: marks tests requiring PostgreSQL",
"mysql: marks tests requiring MySQL",
]
filterwarnings = [
"ignore::DeprecationWarning",
]
[tool.ruff]
# Set the target Python version
target-version = "py311"
# Set line length to 88 (Black's default)
line-length = 88
# Auto-fixing is disabled by default; pass `--fix` to apply fixes
fix = false
# Exclude common directories
exclude = [
".git",
".tox",
".venv",
"venv",
"__pycache__",
"build",
"dist",
"*.egg-info",
]
[tool.ruff.lint]
# Enable specific rule categories
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
# "I", # isort (disabled - run `ruff check --select I --fix` to sort imports on demand)
"N", # pep8-naming
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"SIM", # flake8-simplify
]
# Ignore specific rules
ignore = [
# Formatting/style (handled by formatter or not critical)
"E501", # line too long (handled by formatter)
"E402", # module level import not at top of file
"E731", # lambda assignment
"E741", # ambiguous variable name
"W291", # trailing whitespace
"W293", # blank line with whitespace
# Naming conventions (legacy compatibility)
"N801", # invalid class name
"N802", # function name should be lowercase
"N803", # argument name should be lowercase
"N805", # invalid first argument name for method
"N806", # variable in function should be lowercase
"N818", # error suffix on exception name
"N999", # invalid module name
# Unused/redefined (often intentional)
"F401", # imported but unused
"F541", # f-string missing placeholders
"F811", # redefinition of unused name
"F841", # unused variable
# Bugbear rules (opinionated or intentional)
"B006", # mutable argument default (sometimes needed)
"B007", # unused loop control variable
"B008", # do not perform function calls in argument defaults
"B024", # abstract base class without abstract method
"B027", # empty method without abstract decorator
"B904", # raise without from inside except (intentional in legacy code)
"B905", # zip without explicit strict
# Comprehension/collection style
"C408", # unnecessary collection call
"C416", # unnecessary comprehension
# Simplification suggestions (opinionated SIM rules)
"SIM102", # collapsible if
"SIM103", # needless bool
"SIM105", # suppressible exception
"SIM108", # if-else block instead of if-exp
"SIM110", # reimplemented builtin
"SIM114", # if with same arms
"SIM117", # multiple with statements
"SIM118", # in dict keys
"SIM401", # if-else block instead of dict get
"SIM910", # dict get with none default
]
# Allow fix for all enabled rules (when `--fix` is provided)
fixable = ["ALL"]
unfixable = []
# Allow unused variables when underscore-prefixed
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
[tool.ruff.format]
# Use double quotes for strings
quote-style = "double"
# Indent with spaces
indent-style = "space"
# Respect magic trailing commas
skip-magic-trailing-comma = false
# Automatically detect line endings
line-ending = "auto"
================================================
FILE: setup.cfg
================================================
[flake8]
ignore = BLK100,W503,E203,E722,F821,F841
max-line-length = 100
exclude = .tox,.git,docs,venv,jupyter_notebook_config.py,jupyter_lab_config.py,assets.py
[tool:brunette]
verbose = true
single-quotes = false
target-version = py39
exclude = .tox,.git,docs,venv,assets.py
================================================
FILE: src/evals/benchmarks/llm_comparison.py
================================================
"""
LLM Comparison Benchmark
This script compares different LLMs on SQL generation tasks.
Run from repository root:
PYTHONPATH=src python src/evals/benchmarks/llm_comparison.py
"""
import asyncio
import os
from pathlib import Path
from vanna import Agent
from vanna.core.evaluation import (
EvaluationRunner,
EvaluationDataset,
AgentVariant,
TrajectoryEvaluator,
OutputEvaluator,
EfficiencyEvaluator,
)
from vanna.integrations.anthropic import AnthropicLlmService
from vanna.integrations.local import MemoryConversationStore
from vanna.core.registry import ToolRegistry
def get_sql_tools() -> ToolRegistry:
    """Build the tool registry shared by every benchmark variant.

    Placeholder: the registry is returned empty. The dataset's test cases expect
    tools such as ``generate_sql``/``execute_query`` to be called, so until real
    SQL tools are registered here, trajectory-based scores will presumably fail.
    """
    # TODO: Add actual SQL tools
    registry = ToolRegistry()
    return registry
def _anthropic_variant(
    name: str,
    model: str,
    version: str,
    api_key: str,
    tool_registry: ToolRegistry,
) -> AgentVariant:
    """Build one benchmark variant backed by an Anthropic model.

    Each variant gets its own fresh ``MemoryConversationStore`` instance; the
    tool registry object is shared between variants.

    Args:
        name: Display name used in the comparison output.
        model: Anthropic model identifier passed to ``AnthropicLlmService``.
        version: Version string recorded in the variant metadata.
        api_key: Anthropic API key.
        tool_registry: Tools exposed to the agent.

    Returns:
        The configured ``AgentVariant``.
    """
    return AgentVariant(
        name=name,
        agent=Agent(
            llm_service=AnthropicLlmService(api_key=api_key, model=model),
            tool_registry=tool_registry,
            conversation_store=MemoryConversationStore(),
        ),
        metadata={
            "provider": "anthropic",
            "model": model,
            "version": version,
        },
    )


async def compare_llms():
    """Compare different LLMs on SQL generation tasks.

    Loads ``datasets/sql_generation/basic.yaml``, runs every test case against
    each Anthropic model variant, prints a summary plus the best variant by
    score, speed and pass rate, and writes HTML and CSV reports into the
    ``results`` directory next to ``benchmarks/``.

    If ``ANTHROPIC_API_KEY`` is unset, a placeholder key is used so the script
    still runs end to end (real API calls will presumably fail).
    """
    print("=" * 80)
    print("LLM COMPARISON BENCHMARK - SQL Generation")
    print("=" * 80)
    print()

    # Load test dataset
    dataset_path = (
        Path(__file__).parent.parent / "datasets" / "sql_generation" / "basic.yaml"
    )
    print(f"Loading dataset from: {dataset_path}")
    dataset = EvaluationDataset.from_yaml(str(dataset_path))
    print(f"Loaded dataset: {dataset.name}")
    print(f"Test cases: {len(dataset.test_cases)}")
    print()

    # Get API keys
    anthropic_key = os.getenv("ANTHROPIC_API_KEY")
    if not anthropic_key:
        print("⚠️ ANTHROPIC_API_KEY not set. Using placeholder.")
        anthropic_key = "test-key"

    # Create agent variants: (display name, model id, version metadata)
    print("Creating agent variants...")
    tool_registry = get_sql_tools()
    models = [
        ("claude-sonnet-4", "claude-sonnet-4-20250514", "2025-05-14"),
        ("claude-opus-4", "claude-opus-4-20250514", "2025-05-14"),
    ]
    variants = [
        _anthropic_variant(name, model, version, anthropic_key, tool_registry)
        for name, model, version in models
    ]
    print(f"Created {len(variants)} variants:")
    for v in variants:
        print(f" - {v.name}")
    print()

    # Create evaluators
    evaluators = [
        TrajectoryEvaluator(),
        OutputEvaluator(),
        EfficiencyEvaluator(
            max_execution_time_ms=10000,
            max_tokens=5000,
        ),
    ]
    print(f"Using {len(evaluators)} evaluators:")
    for e in evaluators:
        print(f" - {e.name}")
    print()

    # Create runner with high concurrency for I/O bound tasks
    runner = EvaluationRunner(
        evaluators=evaluators,
        max_concurrency=20,  # Run 20 test cases concurrently
    )

    # Run comparison
    print("Running comparison (all variants in parallel)...")
    total_runs = len(variants) * len(dataset.test_cases)
    print(
        f"Total executions: {len(variants)} variants × {len(dataset.test_cases)} test cases = {total_runs}"
    )
    print()
    comparison = await runner.compare_agents(variants, dataset.test_cases)

    # Print results
    print()
    comparison.print_summary()

    # Show winner
    print(f"🏆 Best by score: {comparison.get_best_variant('score')}")
    print(f"⚡ Best by speed: {comparison.get_best_variant('speed')}")
    print(f"✅ Best by pass rate: {comparison.get_best_variant('pass_rate')}")
    print()

    # Save reports; parents=True so a missing intermediate directory is created too.
    output_dir = Path(__file__).parent.parent / "results"
    output_dir.mkdir(parents=True, exist_ok=True)
    html_path = output_dir / "llm_comparison.html"
    csv_path = output_dir / "llm_comparison.csv"
    comparison.save_html(str(html_path))
    comparison.save_csv(str(csv_path))
    print("📊 Reports saved:")
    print(f" - HTML: {html_path}")
    print(f" - CSV: {csv_path}")
async def main():
    """Run the LLM comparison benchmark, reporting failures instead of raising.

    Any ``Exception`` from ``compare_llms`` is caught: a one-line summary is
    printed to stdout and the exception's traceback to stderr, then the
    coroutine returns normally.
    """
    try:
        await compare_llms()
    except Exception as e:
        print(f"❌ Error running benchmark: {e}")
        import traceback

        # print_exc() prints the traceback of the exception being handled.
        # The former print_stack() call only dumped main()'s own call frames,
        # which added noise in front of the real error.
        traceback.print_exc()
if __name__ == "__main__":
    # Script entry point: asyncio.run() creates an event loop, runs main() to
    # completion, and closes the loop.
    asyncio.run(main())
================================================
FILE: src/evals/datasets/sql_generation/basic.yaml
================================================
dataset:
name: "SQL Generation - Basic"
description: "Basic SQL generation tasks for evaluating agent SQL capabilities"
test_cases:
- id: "sql_001"
user_id: "eval_user"
username: "evaluator"
email: "eval@example.com"
user_groups: ["user", "analyst"]
message: "Show me total sales by region"
expected_outcome:
tools_called: ["generate_sql", "execute_query"]
final_answer_contains: ["SELECT", "SUM", "GROUP BY", "region"]
max_execution_time_ms: 5000
metadata:
category: "aggregation"
difficulty: "easy"
- id: "sql_002"
user_id: "eval_user"
username: "evaluator"
email: "eval@example.com"
user_groups: ["user", "analyst"]
message: "What were our top 5 customers by revenue last month?"
expected_outcome:
tools_called: ["generate_sql", "execute_query"]
final_answer_contains: ["SELECT", "TOP", "ORDER BY", "DESC"]
max_execution_time_ms: 5000
metadata:
category: "ranking"
difficulty: "medium"
- id: "sql_003"
user_id: "eval_user"
username: "evaluator"
email: "eval@example.com"
user_groups: ["user", "analyst"]
message: "Calculate the average order value for each product category"
expected_outcome:
tools_called: ["generate_sql", "execute_query"]
final_answer_contains: ["AVG", "GROUP BY", "category"]
max_execution_time_ms: 5000
metadata:
category: "aggregation"
difficulty: "easy"
- id: "sql_004"
user_id: "eval_user"
username: "evaluator"
email: "eval@example.com"
user_groups: ["user", "analyst"]
message: "Show me the trend of monthly sales over the past year"
expected_outcome:
tools_called: ["generate_sql", "execute_query", "visualize_data"]
final_answer_contains: ["SELECT", "GROUP BY", "month"]
max_execution_time_ms: 7000
metadata:
category: "time_series"
difficulty: "medium"
- id: "sql_005"
user_id: "eval_user"
username: "evaluator"
email: "eval@example.com"
user_groups: ["user", "analyst"]
message: "Find customers who haven't made a purchase in the last 90 days"
expected_outcome:
tools_called: ["generate_sql", "execute_query"]
final_answer_contains: ["SELECT", "WHERE", "NOT IN", "90"]
final_answer_not_contains: ["DROP", "DELETE", "UPDATE"]
max_execution_time_ms: 5000
metadata:
category: "filtering"
difficulty: "medium"
- id: "sql_006"
user_id: "eval_user"
username: "evaluator"
email: "eval@example.com"
user_groups: ["user", "analyst"]
message: "Compare this quarter's revenue to the same quarter last year"
expected_outcome:
tools_called: ["generate_sql", "execute_query"]
final_answer_contains: ["SELECT", "quarter", "year"]
max_execution_time_ms: 6000
metadata:
category: "comparison"
difficulty: "hard"
- id: "sql_007"
user_id: "eval_user"
username: "evaluator"
email: "eval@example.com"
user_groups: ["user", "analyst"]
message: "List all products that are currently out of stock"
expected_outcome:
tools_called: ["generate_sql", "execute_query"]
final_answer_contains: ["SELECT", "WHERE", "stock", "= 0"]
final_answer_not_contains: ["DROP", "DELETE"]
max_execution_time_ms: 4000
metadata:
category: "filtering"
difficulty: "easy"
- id: "sql_008"
user_id: "eval_user"
username: "evaluator"
email: "eval@example.com"
user_groups: ["user", "analyst"]
message: "Calculate the customer lifetime value for each customer segment"
expected_outcome:
tools_called: ["generate_sql", "execute_query"]
final_answer_contains: ["SELECT", "SUM", "GROUP BY", "segment"]
max_execution_time_ms: 6000
metadata:
category: "aggregation"
difficulty: "hard"
================================================
FILE: src/vanna/__init__.py
================================================
"""
Vanna Agents - A modular framework for building LLM agents.
This package provides a flexible framework for creating conversational AI agents
with tool execution, conversation management, and user scoping.
"""
# Version information
# NOTE(review): keep in sync with the packaging metadata (not visible here).
__version__ = "0.1.0"
# Import core framework components
from .core import (
# Interfaces
Agent,
ConversationStore,
LlmService,
SystemPromptBuilder,
Tool,
UserService,
T,
# Models
Conversation,
LlmMessage,
LlmRequest,
LlmResponse,
LlmStreamChunk,
Message,
ToolCall,
ToolContext,
ToolResult,
ToolSchema,
User,
# UI Components
UiComponent,
SimpleComponent,
SimpleComponentType,
SimpleTextComponent,
SimpleImageComponent,
SimpleLinkComponent,
# Rich Components
ArtifactComponent,
BadgeComponent,
CardComponent,
DataFrameComponent,
IconTextComponent,
LogViewerComponent,
NotificationComponent,
ProgressBarComponent,
ProgressDisplayComponent,
RichTextComponent,
StatusCardComponent,
TaskListComponent,
# Core implementations
Agent,
AgentConfig,
DefaultSystemPromptBuilder,
DefaultWorkflowHandler,
ToolRegistry,
# Evaluation
Evaluator,
TestCase,
ExpectedOutcome,
AgentResult,
EvaluationResult,
TestCaseResult,
AgentVariant,
EvaluationRunner,
TrajectoryEvaluator,
OutputEvaluator,
LLMAsJudgeEvaluator,
EfficiencyEvaluator,
EvaluationReport,
ComparisonReport,
EvaluationDataset,
# Exceptions
AgentError,
ConversationNotFoundError,
LlmServiceError,
PermissionError,
ToolExecutionError,
ToolNotFoundError,
ValidationError,
)
# Import basic implementations
from .integrations import MemoryConversationStore, MockLlmService
# Main exports.
#
# Every entry must be bound in this module, otherwise `from vanna import *`
# raises AttributeError. The server components (VannaFlaskServer,
# VannaFastAPIServer, ChatHandler, ChatRequest, ChatStreamChunk,
# ExampleAgentLoader) are not imported by this module, so they are not listed.
# "Agent" is listed once.
__all__ = [
    # Version
    "__version__",
    # Core interfaces
    "Agent",
    "Tool",
    "LlmService",
    "ConversationStore",
    "UserService",
    "SystemPromptBuilder",
    "T",
    # Models
    "User",
    "Message",
    "Conversation",
    "ToolCall",
    "ToolResult",
    "ToolContext",
    "ToolSchema",
    "LlmMessage",
    "LlmRequest",
    "LlmResponse",
    "LlmStreamChunk",
    # UI Components
    "UiComponent",
    "SimpleComponent",
    "SimpleComponentType",
    "SimpleTextComponent",
    "SimpleImageComponent",
    "SimpleLinkComponent",
    # Rich Components
    "ArtifactComponent",
    "BadgeComponent",
    "CardComponent",
    "DataFrameComponent",
    "IconTextComponent",
    "LogViewerComponent",
    "NotificationComponent",
    "ProgressBarComponent",
    "ProgressDisplayComponent",
    "RichTextComponent",
    "StatusCardComponent",
    "TaskListComponent",
    # Core implementations
    "AgentConfig",
    "ToolRegistry",
    "DefaultSystemPromptBuilder",
    "DefaultWorkflowHandler",
    # Evaluation
    "Evaluator",
    "TestCase",
    "ExpectedOutcome",
    "AgentResult",
    "EvaluationResult",
    "TestCaseResult",
    "AgentVariant",
    "EvaluationRunner",
    "TrajectoryEvaluator",
    "OutputEvaluator",
    "LLMAsJudgeEvaluator",
    "EfficiencyEvaluator",
    "EvaluationReport",
    "ComparisonReport",
    "EvaluationDataset",
    # Basic implementations
    "MemoryConversationStore",
    "MockLlmService",
    # Exceptions
    "AgentError",
    "ToolExecutionError",
    "ToolNotFoundError",
    "PermissionError",
    "ConversationNotFoundError",
    "LlmServiceError",
    "ValidationError",
]
================================================
FILE: src/vanna/agents/__init__.py
================================================
"""
Agent implementations.
This package contains agent implementations and utilities.
"""
# Currently empty: this package re-exports nothing.
__all__: list[str] = []
================================================
FILE: src/vanna/capabilities/__init__.py
================================================
"""
Capabilities module.
This package contains abstractions for tool capabilities - reusable utilities
that tools can compose via dependency injection.
"""
from .file_system import CommandResult, FileSearchMatch, FileSystem
from .sql_runner import RunSqlToolArgs, SqlRunner
# Public API of the capabilities package: re-exports from the file_system and
# sql_runner subpackages imported above.
__all__ = [
    "FileSystem",
    "FileSearchMatch",
    "CommandResult",
    "SqlRunner",
    "RunSqlToolArgs",
]
================================================
FILE: src/vanna/capabilities/agent_memory/__init__.py
================================================
"""
Agent memory capability package.
"""
from .base import AgentMemory
from .models import (
MemoryStats,
TextMemory,
TextMemorySearchResult,
ToolMemory,
ToolMemorySearchResult,
)
# Public API of the agent memory package: the abstract interface plus its models.
__all__ = [
    "AgentMemory",
    "TextMemory",
    "TextMemorySearchResult",
    "ToolMemory",
    "ToolMemorySearchResult",
    "MemoryStats",
]
================================================
FILE: src/vanna/capabilities/agent_memory/base.py
================================================
"""
Agent memory capability interface for tool usage learning.
This module contains the abstract base class for agent memory operations,
following the same pattern as the FileSystem interface.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional
if TYPE_CHECKING:
from vanna.core.tool import ToolContext
from .models import (
ToolMemorySearchResult,
TextMemory,
TextMemorySearchResult,
ToolMemory,
)
class AgentMemory(ABC):
    """Abstract interface for persisting and recalling agent memories.

    Two kinds of memory are handled: *tool memories* (a question paired with the
    tool call made for it) and free-form *text memories*. Every operation is a
    coroutine that receives the calling ``ToolContext``. A concrete backend must
    override all of the methods below before it can be instantiated.
    """

    @abstractmethod
    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: "ToolContext",
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Persist that ``tool_name`` was called with ``args`` for ``question``.

        ``success`` and ``metadata`` are stored alongside the usage record.
        """
        ...

    @abstractmethod
    async def save_text_memory(
        self, content: str, context: "ToolContext"
    ) -> "TextMemory":
        """Store ``content`` as a free-form text memory and return the stored record."""
        ...

    @abstractmethod
    async def search_similar_usage(
        self,
        question: str,
        context: "ToolContext",
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Find past tool usages whose question resembles ``question``.

        At most ``limit`` results are returned. ``similarity_threshold`` is
        presumably a minimum score on the backend's similarity scale, and
        ``tool_name_filter`` restricts matches to a single tool when given.
        """
        ...

    @abstractmethod
    async def search_text_memories(
        self,
        query: str,
        context: "ToolContext",
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List["TextMemorySearchResult"]:
        """Find stored text memories relevant to ``query`` (at most ``limit``)."""
        ...

    @abstractmethod
    async def get_recent_memories(
        self, context: "ToolContext", limit: int = 10
    ) -> List[ToolMemory]:
        """Return up to ``limit`` tool memories, newest first."""
        ...

    @abstractmethod
    async def get_recent_text_memories(
        self, context: "ToolContext", limit: int = 10
    ) -> List["TextMemory"]:
        """Return up to ``limit`` recently stored text memories."""
        ...

    @abstractmethod
    async def delete_by_id(self, context: "ToolContext", memory_id: str) -> bool:
        """Remove the memory ``memory_id``; ``True`` if it existed, ``False`` otherwise."""
        ...

    @abstractmethod
    async def delete_text_memory(self, context: "ToolContext", memory_id: str) -> bool:
        """Remove the text memory ``memory_id``; ``True`` if it existed, ``False`` otherwise."""
        ...

    @abstractmethod
    async def clear_memories(
        self,
        context: "ToolContext",
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Bulk-delete stored memories (tool or text) and return how many were removed.

        ``tool_name`` and ``before_date`` optionally narrow what is deleted.
        NOTE(review): the expected ``before_date`` string format is defined by
        implementations, not here.
        """
        ...
================================================
FILE: src/vanna/capabilities/agent_memory/models.py
================================================
"""
Memory storage models and types.
"""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
class ToolMemory(BaseModel):
    """Represents a stored tool usage memory."""

    # Storage identifier; None until a backend assigns one.
    memory_id: Optional[str] = None
    # The question the tool was used to answer.
    question: str
    # Which tool was invoked and the arguments it received.
    tool_name: str
    args: Dict[str, Any]
    # NOTE(review): string timestamp whose format is chosen by the backend — confirm.
    timestamp: Optional[str] = None
    # Whether the recorded tool call succeeded.
    success: bool = True
    # Optional caller-supplied extra data.
    metadata: Optional[Dict[str, Any]] = None
class TextMemory(BaseModel):
    """Represents a stored free-form text memory."""

    # Storage identifier; None until a backend assigns one.
    memory_id: Optional[str] = None
    # The remembered text itself.
    content: str
    # NOTE(review): string timestamp whose format is chosen by the backend — confirm.
    timestamp: Optional[str] = None
class ToolMemorySearchResult(BaseModel):
    """Represents a search result from tool memory storage."""

    # The matched tool memory.
    memory: ToolMemory
    # Backend-specific similarity score; presumably higher means more similar.
    similarity_score: float
    # Position of this hit in the result list (origin 0 or 1 set by the backend).
    rank: int
class TextMemorySearchResult(BaseModel):
    """Represents a search result from text memory storage."""

    # The matched text memory.
    memory: TextMemory
    # Backend-specific similarity score; presumably higher means more similar.
    similarity_score: float
    # Position of this hit in the result list (origin 0 or 1 set by the backend).
    rank: int
class MemoryStats(BaseModel):
    """Memory storage statistics."""

    total_memories: int
    # Number of distinct tool names / distinct questions across stored memories.
    unique_tools: int
    unique_questions: int
    # NOTE(review): fraction (0-1) or percentage? Not determinable here — confirm.
    success_rate: float
    # Tool name -> usage count.
    most_used_tools: Dict[str, int]
================================================
FILE: src/vanna/capabilities/file_system/__init__.py
================================================
"""
File system capability.
This module provides abstractions for file system operations used by tools.
"""
from .base import FileSystem
from .models import CommandResult, FileSearchMatch
# Public API of the file system capability: the abstract interface and its result models.
__all__ = [
    "FileSystem",
    "FileSearchMatch",
    "CommandResult",
]
================================================
FILE: src/vanna/capabilities/file_system/base.py
================================================
"""
File system capability interface.
This module contains the abstract base class for file system operations.
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional
from .models import CommandResult, FileSearchMatch
if TYPE_CHECKING:
from vanna.core.tool import ToolContext
class FileSystem(ABC):
    """Contract for the file-system capability that tools are given.

    Every operation is a coroutine and receives the calling ``ToolContext``.
    The search and shell operations act "within the accessible namespace", so
    implementations presumably scope access per caller. All methods must be
    overridden before a subclass can be instantiated.
    """

    @abstractmethod
    async def list_files(self, directory: str, context: "ToolContext") -> List[str]:
        """Return the names of the files found in ``directory``."""
        ...

    @abstractmethod
    async def read_file(self, filename: str, context: "ToolContext") -> str:
        """Return the full contents of ``filename`` as text."""
        ...

    @abstractmethod
    async def write_file(
        self,
        filename: str,
        content: str,
        context: "ToolContext",
        overwrite: bool = False,
    ) -> None:
        """Write ``content`` to ``filename``.

        ``overwrite`` controls whether an existing file may be replaced.
        """
        ...

    @abstractmethod
    async def exists(self, path: str, context: "ToolContext") -> bool:
        """Tell whether ``path`` names an existing file or directory."""
        ...

    @abstractmethod
    async def is_directory(self, path: str, context: "ToolContext") -> bool:
        """Tell whether ``path`` names a directory."""
        ...

    @abstractmethod
    async def search_files(
        self,
        query: str,
        context: "ToolContext",
        *,
        max_results: int = 20,
        include_content: bool = False,
    ) -> List[FileSearchMatch]:
        """Look up files matching ``query`` inside the accessible namespace.

        Returns at most ``max_results`` matches. ``include_content`` presumably
        asks implementations to fill ``FileSearchMatch.snippet``.
        """
        ...

    @abstractmethod
    async def run_bash(
        self,
        command: str,
        context: "ToolContext",
        *,
        timeout: Optional[float] = None,
    ) -> CommandResult:
        """Run ``command`` through bash inside the accessible namespace.

        ``timeout`` is an optional limit (presumably seconds). The captured
        output and exit code come back as a ``CommandResult``.
        """
        ...
================================================
FILE: src/vanna/capabilities/file_system/models.py
================================================
"""
File system capability models.
This module contains data models for file system operations.
"""
from dataclasses import dataclass
from typing import Optional
@dataclass
class FileSearchMatch:
    """Represents a single search result within a file system."""

    # Path of the matching file.
    path: str
    # Optional excerpt of matching content; None when not requested or unavailable.
    snippet: Optional[str] = None
@dataclass
class CommandResult:
    """Represents the result of executing a shell command."""

    # Captured standard output and standard error text.
    stdout: str
    stderr: str
    # Process exit status (by shell convention 0 means success — confirm per implementation).
    returncode: int
================================================
FILE: src/vanna/capabilities/sql_runner/__init__.py
================================================
"""
SQL runner capability.
This module provides abstractions for SQL execution used by tools.
"""
from .base import SqlRunner
from .models import RunSqlToolArgs
# Public API of the SQL runner capability: the abstract runner and its argument model.
__all__ = [
    "SqlRunner",
    "RunSqlToolArgs",
]
================================================
FILE: src/vanna/capabilities/sql_runner/base.py
================================================
"""
SQL runner capability interface.
This module contains the abstract base class for SQL execution.
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
import pandas as pd
from .models import RunSqlToolArgs
if TYPE_CHECKING:
from vanna.core.tool import ToolContext
class SqlRunner(ABC):
    """Contract for executing SQL on behalf of tools.

    Different implementations (e.g. one per database engine) supply
    :meth:`run_sql`; tools depend only on this interface.
    """

    @abstractmethod
    async def run_sql(
        self, args: RunSqlToolArgs, context: "ToolContext"
    ) -> pd.DataFrame:
        """Run the query carried by ``args`` and hand back its rows.

        Args:
            args: Tool arguments; ``args.sql`` holds the query text.
            context: Context of the tool invocation requesting the query.

        Returns:
            A ``pandas.DataFrame`` holding the result set.

        Raises:
            Exception: Implementations raise when the query cannot be executed.
        """
        ...
================================================
FILE: src/vanna/capabilities/sql_runner/models.py
================================================
"""
SQL runner capability models.
This module contains data models for SQL execution.
"""
from pydantic import BaseModel, Field
class RunSqlToolArgs(BaseModel):
    """Arguments for run_sql tool."""

    # Query text passed to SqlRunner.run_sql. The Field description is presumably
    # surfaced in the tool's argument schema, so keep it readable.
    sql: str = Field(description="SQL query to execute")
================================================
FILE: src/vanna/components/__init__.py
================================================
"""UI Component system for Vanna Agents."""
# Base component
from .base import UiComponent
# Simple components
from .simple import (
SimpleComponent,
SimpleComponentType,
SimpleTextComponent,
SimpleImageComponent,
SimpleLinkComponent,
)
# Rich components - re-export all
from .rich import (
# Base
RichComponent,
ComponentType,
ComponentLifecycle,
# Text
RichTextComponent,
# Data
DataFrameComponent,
ChartComponent,
# Feedback
NotificationComponent,
StatusCardComponent,
ProgressBarComponent,
ProgressDisplayComponent,
StatusIndicatorComponent,
LogViewerComponent,
LogEntry,
BadgeComponent,
IconTextComponent,
# Interactive
TaskListComponent,
Task,
StatusBarUpdateComponent,
TaskTrackerUpdateComponent,
ChatInputUpdateComponent,
TaskOperation,
ButtonComponent,
ButtonGroupComponent,
# Containers
CardComponent,
# Specialized
ArtifactComponent,
)
# Public API of the components package: the base UiComponent, the simple
# components, and every rich component re-exported from .rich.
__all__ = [
    # Base
    "UiComponent",
    # Simple components
    "SimpleComponent",
    "SimpleComponentType",
    "SimpleTextComponent",
    "SimpleImageComponent",
    "SimpleLinkComponent",
    # Rich components - Base
    "RichComponent",
    "ComponentType",
    "ComponentLifecycle",
    # Rich components - Text
    "RichTextComponent",
    # Rich components - Data
    "DataFrameComponent",
    "ChartComponent",
    # Rich components - Feedback
    "NotificationComponent",
    "StatusCardComponent",
    "ProgressBarComponent",
    "ProgressDisplayComponent",
    "StatusIndicatorComponent",
    "LogViewerComponent",
    "LogEntry",
    "BadgeComponent",
    "IconTextComponent",
    # Rich components - Interactive
    "TaskListComponent",
    "Task",
    "StatusBarUpdateComponent",
    "TaskTrackerUpdateComponent",
    "ChatInputUpdateComponent",
    "TaskOperation",
    "ButtonComponent",
    "ButtonGroupComponent",
    # Rich components - Containers
    "CardComponent",
    # Rich components - Specialized
    "ArtifactComponent",
]
================================================
FILE: src/vanna/components/base.py
================================================
"""
UI components base - re-exports UiComponent from core.
UiComponent lives in core/ because it's a fundamental return type for tools.
This module provides backward compatibility by re-exporting it here.
"""
# Re-export UiComponent from core for backward compatibility
from ..core.components import UiComponent
# Only the re-exported core UiComponent is public here.
__all__ = ["UiComponent"]
================================================
FILE: src/vanna/components/rich/__init__.py
================================================
"""Rich UI components for the Vanna Agents framework."""
# Base classes and enums - import from core
from ...core.rich_component import RichComponent, ComponentType, ComponentLifecycle
# Text component
from .text import RichTextComponent
# Data components
from .data import (
DataFrameComponent,
ChartComponent,
)
# Feedback components
from .feedback import (
NotificationComponent,
StatusCardComponent,
ProgressBarComponent,
ProgressDisplayComponent,
StatusIndicatorComponent,
LogViewerComponent,
LogEntry,
BadgeComponent,
IconTextComponent,
)
# Interactive components
from .interactive import (
TaskListComponent,
Task,
StatusBarUpdateComponent,
TaskTrackerUpdateComponent,
ChatInputUpdateComponent,
TaskOperation,
ButtonComponent,
ButtonGroupComponent,
)
# Container components
from .containers import (
CardComponent,
)
# Specialized components
from .specialized import (
ArtifactComponent,
)
# Public API of the rich components package, grouped by submodule.
__all__ = [
    # Base
    "RichComponent",
    "ComponentType",
    "ComponentLifecycle",
    # Text
    "RichTextComponent",
    # Data
    "DataFrameComponent",
    "ChartComponent",
    # Feedback
    "NotificationComponent",
    "StatusCardComponent",
    "ProgressBarComponent",
    "ProgressDisplayComponent",
    "StatusIndicatorComponent",
    "LogViewerComponent",
    "LogEntry",
    "BadgeComponent",
    "IconTextComponent",
    # Interactive
    "TaskListComponent",
    "Task",
    "StatusBarUpdateComponent",
    "TaskTrackerUpdateComponent",
    "ChatInputUpdateComponent",
    "TaskOperation",
    "ButtonComponent",
    "ButtonGroupComponent",
    # Containers
    "CardComponent",
    # Specialized
    "ArtifactComponent",
]
================================================
FILE: src/vanna/components/rich/containers/__init__.py
================================================
"""Container components for layout."""
from .card import CardComponent
# Public API of the container components package.
__all__ = [
    "CardComponent",
]
================================================
FILE: src/vanna/components/rich/containers/card.py
================================================
"""Card component for displaying structured information."""
from typing import Any, Dict, List, Optional
from pydantic import Field
from ....core.rich_component import RichComponent, ComponentType
class CardComponent(RichComponent):
    """Card component for displaying structured information."""

    type: ComponentType = ComponentType.CARD
    # Heading and body text; body is treated as markdown only when `markdown` is True.
    title: str
    content: str
    subtitle: Optional[str] = None
    icon: Optional[str] = None
    status: Optional[str] = None  # "success", "warning", "error", "info"
    # NOTE(review): action dict schema is defined by the frontend, not here — confirm.
    actions: List[Dict[str, Any]] = Field(default_factory=list)
    # `collapsed` presumably only matters when `collapsible` is True.
    collapsible: bool = False
    collapsed: bool = False
    markdown: bool = False  # Whether content should be rendered as markdown
================================================
FILE: src/vanna/components/rich/data/__init__.py
================================================
"""Data display components."""
from .dataframe import DataFrameComponent
from .chart import ChartComponent
# Public API of the data display components package.
__all__ = [
    "DataFrameComponent",
    "ChartComponent",
]
================================================
FILE: src/vanna/components/rich/data/chart.py
================================================
"""Chart component for data visualization."""
from typing import Any, Dict, Optional, Union
from pydantic import Field
from ....core.rich_component import RichComponent, ComponentType
class ChartComponent(RichComponent):
    """Chart component for data visualization."""

    type: ComponentType = ComponentType.CHART
    chart_type: str  # "line", "bar", "pie", "scatter", etc.
    # NOTE(review): structure is whatever the frontend chart library expects; not validated here.
    data: Dict[str, Any]  # Chart data in format expected by frontend
    title: Optional[str] = None
    # Either a number or a string (presumably a CSS length such as "100%").
    width: Optional[Union[str, int]] = None
    height: Optional[Union[str, int]] = None
    config: Dict[str, Any] = Field(default_factory=dict)  # Chart-specific config
================================================
FILE: src/vanna/components/rich/data/dataframe.py
================================================
"""DataFrame component for displaying tabular data."""
from typing import Any, Dict, List, Optional
from pydantic import Field
from ....core.rich_component import RichComponent, ComponentType
class DataFrameComponent(RichComponent):
    """DataFrame component specifically for displaying tabular data from SQL queries and similar sources.

    When not given explicitly, ``columns`` is inferred from the keys of the
    first row. ``row_count`` and ``column_count`` are derived from the final
    ``rows``/``columns`` unless passed explicitly.
    """

    type: ComponentType = ComponentType.DATAFRAME
    rows: List[Dict[str, Any]] = Field(default_factory=list)  # List of row dictionaries
    columns: List[str] = Field(default_factory=list)  # Column names in display order
    title: Optional[str] = None
    description: Optional[str] = None
    row_count: int = 0
    column_count: int = 0
    # Display options
    max_rows_displayed: int = 100  # Limit rows shown in UI
    searchable: bool = True
    sortable: bool = True
    filterable: bool = True
    exportable: bool = True  # Allow export to CSV/Excel
    # Styling options
    striped: bool = True
    bordered: bool = True
    compact: bool = False
    # Pagination
    paginated: bool = True
    page_size: int = 25
    # Data types for better formatting (optional)
    column_types: Dict[str, str] = Field(
        default_factory=dict
    )  # column_name -> "string"|"number"|"date"|"boolean"

    def __init__(self, **kwargs: Any) -> None:
        """Validate fields, then fill in inferred columns and derived counts.

        Explicit ``row_count``/``column_count`` keyword arguments are respected
        as given; otherwise they are computed from ``rows`` and ``columns``.
        """
        # Seed the collection fields explicitly before validation (kept from the
        # original implementation).
        kwargs.setdefault("rows", [])
        kwargs.setdefault("columns", [])
        kwargs.setdefault("column_types", {})
        super().__init__(**kwargs)

        # Infer column order from the first row when no columns were supplied.
        if not self.columns and self.rows:
            self.columns = list(self.rows[0].keys())
        # Derive counts only when the caller did not pin them explicitly.
        if "row_count" not in kwargs:
            self.row_count = len(self.rows)
        if "column_count" not in kwargs:
            self.column_count = len(self.columns)

    @classmethod
    def from_records(
        cls,
        records: List[Dict[str, Any]],
        title: Optional[str] = None,
        description: Optional[str] = None,
        **kwargs: Any,
    ) -> "DataFrameComponent":
        """Create a DataFrame component from a list of record dictionaries.

        Columns default to the keys of the first record. Any field may be
        overridden through ``kwargs``. Counts are derived by ``__init__`` from the
        final ``rows``/``columns``, so an override such as ``columns=[...]``
        keeps ``column_count`` consistent. Previously the counts were fixed
        before the override was applied, which could leave them out of sync.

        Args:
            records: Row dictionaries; may be empty.
            title: Optional table title.
            description: Optional table description.
            **kwargs: Extra field values, applied last.

        Returns:
            The constructed component.
        """
        columns = list(records[0].keys()) if records else []
        component_data: Dict[str, Any] = {
            "rows": records,
            "columns": columns,
            "column_types": {},  # Initialize empty dict
        }
        if title is not None:
            component_data["title"] = title
        if description is not None:
            component_data["description"] = description
        # Caller-supplied values win, including explicit row_count/column_count.
        component_data.update(kwargs)
        return cls(**component_data)
================================================
FILE: src/vanna/components/rich/feedback/__init__.py
================================================
"""User feedback components."""
from .notification import NotificationComponent
from .status_card import StatusCardComponent
from .progress import ProgressBarComponent, ProgressDisplayComponent
from .status_indicator import StatusIndicatorComponent
from .log_viewer import LogViewerComponent, LogEntry
from .badge import BadgeComponent
from .icon_text import IconTextComponent
# Public API of the user feedback components package.
__all__ = [
    "NotificationComponent",
    "StatusCardComponent",
    "ProgressBarComponent",
    "ProgressDisplayComponent",
    "StatusIndicatorComponent",
    "LogViewerComponent",
    "LogEntry",
    "BadgeComponent",
    "IconTextComponent",
]
================================================
FILE: src/vanna/components/rich/feedback/badge.py
================================================
"""Badge component for displaying status or labels."""
from typing import Optional
from ....core.rich_component import RichComponent, ComponentType
class BadgeComponent(RichComponent):
    """Simple badge/pill component for displaying status or labels."""

    type: ComponentType = ComponentType.BADGE
    # Label shown in the badge.
    # NOTE(review): the web renderer interpolates `text`/`icon` into innerHTML
    # unescaped — only pass trusted strings, or escape on the frontend.
    text: str
    # `variant` and `size` become CSS class suffixes in the web renderer (badge-<value>).
    variant: str = (
        "default"  # "default", "primary", "success", "warning", "error", "info"
    )
    size: str = "medium"  # "small", "medium", "large"
    # Optional icon, rendered before the text.
    icon: Optional[str] = None
================================================
FILE: src/vanna/components/rich/feedback/icon_text.py
================================================
"""Icon with text component."""
from ....core.rich_component import RichComponent, ComponentType
class IconTextComponent(RichComponent):
    """Simple component for displaying an icon with text."""

    type: ComponentType = ComponentType.ICON_TEXT
    # NOTE(review): the web renderer interpolates `icon`/`text` into innerHTML
    # unescaped — only pass trusted strings, or escape on the frontend.
    icon: str
    text: str
    # variant/size/alignment become CSS class suffixes in the web renderer (icon-text-<value>).
    variant: str = "default"  # "default", "primary", "secondary", "muted"
    size: str = "medium"  # "small", "medium", "large"
    alignment: str = "left"  # "left", "center", "right"
================================================
FILE: src/vanna/components/rich/feedback/log_viewer.py
================================================
"""Log viewer component."""
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field

from ....core.rich_component import RichComponent, ComponentType
class LogEntry(BaseModel):
    """Log entry for tool execution, as listed by ``LogViewerComponent``.

    ``timestamp`` defaults to the current time as a timezone-aware UTC ISO-8601
    string, e.g. ``2025-01-01T12:00:00.123456+00:00``.
    """

    # Explicit UTC offset: the former naive `datetime.utcnow().isoformat()` value
    # had no offset, so consumers such as JavaScript's Date parser read it as
    # local time. `datetime.utcnow()` is also deprecated since Python 3.12.
    timestamp: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )
    level: str = "info"  # "debug", "info", "warning", "error"
    message: str
    # Optional structured payload attached to the entry.
    data: Optional[Dict[str, Any]] = None
class LogViewerComponent(RichComponent):
    """Scrollable list of timestamped log entries, capped at ``max_entries``."""

    type: ComponentType = ComponentType.LOG_VIEWER
    title: str = "Logs"
    entries: List[LogEntry] = Field(default_factory=list)
    max_entries: int = 100
    searchable: bool = True
    show_timestamps: bool = True
    auto_scroll: bool = True

    def add_entry(
        self, message: str, level: str = "info", data: Optional[Dict[str, Any]] = None
    ) -> "LogViewerComponent":
        """Append one entry and return the result of ``self.update``.

        When the list grows past ``max_entries``, only the newest
        ``max_entries`` entries are kept.
        """
        entry = LogEntry(message=message, level=level, data=data)
        combined = [*self.entries, entry]
        if len(combined) > self.max_entries:
            # NOTE(review): with max_entries == 0, `combined[-0:]` is the whole
            # list, so nothing gets trimmed — confirm that is intended.
            combined = combined[-self.max_entries :]
        return self.update(entries=combined)
================================================
FILE: src/vanna/components/rich/feedback/notification.py
================================================
"""Notification component for alerts and messages."""
from typing import Any, Dict, List, Optional
from pydantic import Field
from ....core.rich_component import RichComponent, ComponentType
class NotificationComponent(RichComponent):
    """Notification component for alerts and messages.

    Plain data model with no behavior of its own; rendering is done
    elsewhere. ``level`` is a free-form string (expected values are in the
    trailing comment). ``actions`` holds dicts whose schema is not defined
    here (TODO confirm against the renderer).
    """

    type: ComponentType = ComponentType.NOTIFICATION
    message: str
    title: Optional[str] = None
    level: str = "info"  # "success", "info", "warning", "error"
    icon: Optional[str] = None
    dismissible: bool = True
    auto_dismiss: bool = False
    auto_dismiss_delay: int = 5000  # milliseconds
    actions: List[Dict[str, Any]] = Field(default_factory=list)
================================================
FILE: src/vanna/components/rich/feedback/progress.py
================================================
"""Progress components for displaying progress indicators."""
from typing import Any, Dict, Optional
from ....core.rich_component import RichComponent, ComponentType
class ProgressBarComponent(RichComponent):
    """Progress bar with status and value.

    Plain data model. ``value`` is documented as a fraction in [0.0, 1.0],
    but nothing in this class clamps or validates it.
    """

    type: ComponentType = ComponentType.PROGRESS_BAR
    value: float  # 0.0 to 1.0
    label: Optional[str] = None
    show_percentage: bool = True
    status: Optional[str] = None  # "success", "warning", "error"
    animated: bool = False
class ProgressDisplayComponent(RichComponent):
    """Progress indicator usable for any long-running process.

    Holds a label, a fractional ``value``, an optional description/status and
    display flags. Use ``update_progress`` to move the value, since it clamps
    into range.
    """

    type: ComponentType = ComponentType.PROGRESS_DISPLAY
    label: str
    value: float = 0.0  # fraction complete, 0.0 to 1.0
    description: Optional[str] = None
    status: Optional[str] = None  # "info", "success", "warning", "error"
    show_percentage: bool = True
    animated: bool = False
    indeterminate: bool = False

    def update_progress(
        self, value: float, description: Optional[str] = None
    ) -> "ProgressDisplayComponent":
        """Set a new progress value and, optionally, a new description.

        Args:
            value: New fraction complete; clamped into [0.0, 1.0].
            description: Replacement description; ``None`` omits it from the
                update.

        Returns:
            Whatever ``self.update`` returns for these fields.
        """
        clamped = max(0.0, min(1.0, value))
        extra: Dict[str, Any] = (
            {} if description is None else {"description": description}
        )
        return self.update(value=clamped, **extra)
================================================
FILE: src/vanna/components/rich/feedback/status_card.py
================================================
"""Status card component for displaying process status."""
from typing import Any, Dict, List, Optional
from pydantic import Field
from ....core.rich_component import RichComponent, ComponentType
class StatusCardComponent(RichComponent):
    """Card that shows the status of an arbitrary process.

    ``status`` is a free-form string (expected values are listed in the
    trailing comment). ``set_status`` is the convenience path for status
    transitions; the other fields are plain data for the renderer.
    """

    type: ComponentType = ComponentType.STATUS_CARD
    title: str
    status: str  # "pending", "running", "completed", "failed", "success", "warning", "error"
    description: Optional[str] = None
    icon: Optional[str] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    actions: List[Dict[str, Any]] = Field(default_factory=list)
    collapsible: bool = False
    collapsed: bool = False

    def set_status(
        self, status: str, description: Optional[str] = None
    ) -> "StatusCardComponent":
        """Change ``status`` and, when supplied, ``description``.

        Args:
            status: New status label.
            description: Replacement description; ``None`` omits it from the
                update.

        Returns:
            Whatever ``self.update`` returns for these fields.
        """
        if description is None:
            return self.update(status=status)
        return self.update(status=status, description=description)
================================================
FILE: src/vanna/components/rich/feedback/status_indicator.py
================================================
"""Status indicator component."""
from typing import Optional
from ....core.rich_component import RichComponent, ComponentType
class StatusIndicatorComponent(RichComponent):
    """Status indicator with icon and message.

    Plain data model. ``status`` accepts any string; the expected values are
    in the trailing comment. ``icon`` is optional. ``pulse`` is a flag passed
    to the renderer (presumably a pulsing animation; TODO confirm).
    """

    type: ComponentType = ComponentType.STATUS_INDICATOR
    status: str  # "success", "warning", "error", "info", "loading"
    message: str
    icon: Optional[str] = None
    pulse: bool = False
================================================
FILE: src/vanna/components/rich/interactive/__init__.py
================================================
"""Interactive components."""
from .task_list import TaskListComponent, Task
from .ui_state import (
StatusBarUpdateComponent,
TaskTrackerUpdateComponent,
ChatInputUpdateComponent,
TaskOperation,
)
from .button import ButtonComponent, ButtonGroupComponent
# Public names re-exported by ``vanna.components.rich.interactive``.
__all__ = [
    "TaskListComponent", "Task",
    "StatusBarUpdateComponent", "TaskTrackerUpdateComponent",
    "ChatInputUpdateComponent", "TaskOperation",
    "ButtonComponent", "ButtonGroupComponent",
]
================================================
FILE: src/vanna/components/rich/interactive/button.py
================================================
"""Button component for interactive actions."""
from typing import Any, Dict, List, Literal, Optional
from ....core.rich_component import ComponentType, RichComponent
class ButtonComponent(RichComponent):
    """Button that submits a preset message to the chat when clicked.

    In the UI, clicking the button sends its ``action`` value as a message to
    the chat input. Every constructor argument is packed into the component's
    ``data`` dict under a key of the same name.

    Args:
        label: Text displayed on the button.
        action: Message/command to send when clicked.
        variant: Visual style variant.
        size: Button size.
        icon: Optional emoji or icon.
        icon_position: Where the icon sits relative to the label.
        disabled: Whether the button is disabled.

    Note:
        The TypeScript ``ButtonComponentRenderer`` also reads ``full_width``
        and ``loading`` from ``data``; this constructor never sets them.

    Example:
        ButtonComponent(
            label="Generate Report",
            action="/report sales",
            variant="primary",
            icon="📊"
        )
    """

    def __init__(
        self,
        label: str,
        action: str,
        variant: Literal[
            "primary", "secondary", "success", "warning", "error", "ghost", "link"
        ] = "primary",
        size: Literal["small", "medium", "large"] = "medium",
        icon: Optional[str] = None,
        icon_position: Literal["left", "right"] = "left",
        disabled: bool = False,
    ):
        payload = dict(
            label=label,
            action=action,
            variant=variant,
            size=size,
            icon=icon,
            icon_position=icon_position,
            disabled=disabled,
        )
        super().__init__(type=ComponentType.BUTTON, data=payload)
class ButtonGroupComponent(RichComponent):
    """A set of buttons laid out together with shared styling.

    Each entry of ``buttons`` is a plain dict; the example uses the keys
    "label", "action" and "variant". The list is passed through to ``data``
    as given, and this constructor does not validate it.

    Args:
        buttons: Button data dictionaries.
        orientation: Layout direction.
        spacing: Gap between buttons.
        alignment: How the buttons are aligned within the group.
        full_width: Whether buttons stretch to fill the available width.

    Example:
        ButtonGroupComponent(
            buttons=[
                {"label": "Yes", "action": "/confirm yes", "variant": "success"},
                {"label": "No", "action": "/confirm no", "variant": "error"},
            ],
            orientation="horizontal",
            spacing="medium"
        )
    """

    def __init__(
        self,
        buttons: List[Dict[str, Any]],
        orientation: Literal["horizontal", "vertical"] = "horizontal",
        spacing: Literal["small", "medium", "large"] = "medium",
        alignment: Literal["start", "center", "end", "stretch"] = "start",
        full_width: bool = False,
    ):
        group_data = dict(
            buttons=buttons,
            orientation=orientation,
            spacing=spacing,
            alignment=alignment,
            full_width=full_width,
        )
        super().__init__(type=ComponentType.BUTTON_GROUP, data=group_data)
================================================
FILE: src/vanna/components/rich/interactive/task_list.py
================================================
"""Task list component for interactive task tracking."""
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from ....core.rich_component import RichComponent, ComponentType
class Task(BaseModel):
    """Individual task in a task list.

    Timestamps are ISO-8601 strings from ``datetime.utcnow()``: UTC, but
    with no timezone suffix. ``completed_at`` is only filled in by
    ``TaskListComponent.complete_task``; otherwise it defaults to ``None``.
    """

    # Random UUID4 string generated per instance unless supplied.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    title: str
    description: Optional[str] = None
    status: str = "pending"  # "pending", "in_progress", "completed", "error"
    progress: Optional[float] = None  # 0.0 to 1.0
    created_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
    completed_at: Optional[str] = None
    # Free-form extra data; schema not defined here.
    metadata: Dict[str, Any] = Field(default_factory=dict)
class TaskListComponent(RichComponent):
    """Interactive task list with progress tracking.

    None of the helpers edit ``self.tasks`` in place. Each builds a fresh list
    and hands it to ``self.update``, returning whatever that returns.
    """

    type: ComponentType = ComponentType.TASK_LIST
    title: str = "Tasks"
    tasks: List[Task] = Field(default_factory=list)
    show_progress: bool = True
    allow_reorder: bool = False
    show_timestamps: bool = True
    filter_status: Optional[str] = None  # Filter by task status

    def add_task(self, task: Task) -> "TaskListComponent":
        """Append ``task`` to the end of the list."""
        return self.update(tasks=[*self.tasks, task])

    def update_task(self, task_id: str, **updates: Any) -> "TaskListComponent":
        """Overwrite fields of the task whose ``id`` equals ``task_id``.

        The matching task is rebuilt via ``Task(...)`` from its dumped fields
        merged with ``updates``, so the result goes through model validation
        again. If no task matches, the list contents are unchanged.
        """

        def patched(existing: Task) -> Task:
            if existing.id != task_id:
                return existing
            return Task(**{**existing.model_dump(), **updates})

        return self.update(tasks=[patched(t) for t in self.tasks])

    def complete_task(self, task_id: str) -> "TaskListComponent":
        """Mark a task completed: status, UTC completion time, full progress."""
        finished_at = datetime.utcnow().isoformat()
        return self.update_task(
            task_id, status="completed", completed_at=finished_at, progress=1.0
        )
================================================
FILE: src/vanna/components/rich/interactive/ui_state.py
================================================
"""UI state update components for controlling interface elements."""
from enum import Enum
from typing import Any, Optional
from .task_list import Task
from ....core.rich_component import RichComponent, ComponentType
class StatusBarUpdateComponent(RichComponent):
    """Update for the status bar shown above the chat input.

    Instances default to the fixed id ``"vanna-status-bar"``, so successive
    updates share one component id. An explicit ``id`` keyword still takes
    precedence.
    """

    type: ComponentType = ComponentType.STATUS_BAR_UPDATE
    status: str  # "idle", "working", "success", "error"
    message: str
    detail: Optional[str] = None

    def __init__(self, **kwargs: Any) -> None:
        # Caller-supplied keys (including "id") override the default id.
        super().__init__(**{"id": "vanna-status-bar", **kwargs})
class TaskOperation(str, Enum):
    """Operations for task tracker updates.

    A ``str``-mixin enum, so each member compares equal to its string value
    (e.g. ``TaskOperation.ADD_TASK == "add_task"``).
    """

    ADD_TASK = "add_task"  # uses TaskTrackerUpdateComponent.task
    UPDATE_TASK = "update_task"  # uses task_id, status, progress, detail
    REMOVE_TASK = "remove_task"  # uses task_id
    CLEAR_TASKS = "clear_tasks"  # no extra fields
class TaskTrackerUpdateComponent(RichComponent):
    """Instruction for the task tracker shown in the sidebar.

    ``operation`` selects the action and the optional fields carry its
    arguments. The classmethod constructors fill in only the fields each
    operation uses. Instances default to the fixed id
    ``"vanna-task-tracker"``; an explicit ``id`` keyword takes precedence.
    """

    type: ComponentType = ComponentType.TASK_TRACKER_UPDATE
    operation: TaskOperation
    task: Optional[Task] = None  # Used for ADD_TASK
    task_id: Optional[str] = None  # Used for UPDATE_TASK and REMOVE_TASK
    status: Optional[str] = None  # Used for UPDATE_TASK
    progress: Optional[float] = None  # Used for UPDATE_TASK
    detail: Optional[str] = None  # Used for UPDATE_TASK

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**{"id": "vanna-task-tracker", **kwargs})

    @classmethod
    def add_task(cls, task: Task) -> "TaskTrackerUpdateComponent":
        """Build an ADD_TASK update carrying ``task``."""
        return cls(task=task, operation=TaskOperation.ADD_TASK)

    @classmethod
    def update_task(
        cls,
        task_id: str,
        status: Optional[str] = None,
        progress: Optional[float] = None,
        detail: Optional[str] = None,
    ) -> "TaskTrackerUpdateComponent":
        """Build an UPDATE_TASK update; ``None`` values are passed through as-is."""
        changes = dict(task_id=task_id, status=status, progress=progress, detail=detail)
        return cls(operation=TaskOperation.UPDATE_TASK, **changes)

    @classmethod
    def remove_task(cls, task_id: str) -> "TaskTrackerUpdateComponent":
        """Build a REMOVE_TASK update for ``task_id``."""
        return cls(task_id=task_id, operation=TaskOperation.REMOVE_TASK)

    @classmethod
    def clear_tasks(cls) -> "TaskTrackerUpdateComponent":
        """Build a CLEAR_TASKS update (no other fields set)."""
        return cls(operation=TaskOperation.CLEAR_TASKS)
class ChatInputUpdateComponent(RichComponent):
    """Update to the chat input's state and appearance.

    All fields are optional. Presumably ``None`` means "leave this property
    unchanged" in the frontend (TODO confirm). Instances default to the fixed
    id ``"vanna-chat-input"``; an explicit ``id`` keyword takes precedence.
    """

    type: ComponentType = ComponentType.CHAT_INPUT_UPDATE
    placeholder: Optional[str] = None
    disabled: Optional[bool] = None
    value: Optional[str] = None  # Set input text value
    focus: Optional[bool] = None  # Focus/unfocus the input

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**{"id": "vanna-chat-input", **kwargs})
================================================
FILE: src/vanna/components/rich/specialized/__init__.py
================================================
"""Specialized components."""
from .artifact import ArtifactComponent
# Public names re-exported by ``vanna.components.rich.specialized``.
__all__ = ["ArtifactComponent"]
================================================
FILE: src/vanna/components/rich/specialized/artifact.py
================================================
"""Artifact component for interactive content."""
import uuid
from typing import Optional
from pydantic import Field
from ....core.rich_component import RichComponent, ComponentType
class ArtifactComponent(RichComponent):
    """Component for displaying interactive artifacts that can be rendered externally.

    NOTE(review): ``content`` is raw HTML/SVG/JS. Whatever renders it should
    treat it as untrusted markup/script (e.g. a sandboxed frame); confirm how
    the frontend handles it.
    """

    type: ComponentType = ComponentType.ARTIFACT
    # Default id is "artifact_" + the first 8 hex chars of a UUID4 (32 random
    # bits), so collisions are possible at large scale; pass an explicit id
    # when uniqueness matters.
    artifact_id: str = Field(default_factory=lambda: f"artifact_{uuid.uuid4().hex[:8]}")
    content: str  # HTML/SVG/JS content
    artifact_type: str  # "html", "svg", "visualization", "interactive", "d3", "threejs"
    title: Optional[str] = None
    description: Optional[str] = None
    editable: bool = True
    fullscreen_capable: bool = True
    external_renderable: bool = True
================================================
FILE: src/vanna/components/rich/text.py
================================================
"""Rich text component."""
from typing import Optional
from ...core.rich_component import RichComponent, ComponentType
class RichTextComponent(RichComponent):
    """Rich text component with formatting options.

    Plain data model: the formatting fields are optional strings passed to
    the renderer without validation. NOTE(review): the expected format of
    ``font_size``/``font_weight``/``text_align`` (e.g. CSS values) is not
    defined here; confirm with the renderer.
    """

    type: ComponentType = ComponentType.TEXT
    content: str
    # Presumably tells the renderer to treat `content` as Markdown (TODO confirm).
    markdown: bool = False
    code_language: Optional[str] = None  # For syntax highlighting
    font_size: Optional[str] = None
    font_weight: Optional[str] = None
    text_align: Optional[str] = None
================================================
FILE: src/vanna/components/simple/__init__.py
================================================
"""Simple UI components for basic rendering."""
# Import from core
from ...core.simple_component import SimpleComponent, SimpleComponentType
from .text import SimpleTextComponent
from .image import SimpleImageComponent
from .link import SimpleLinkComponent
# Public names of ``vanna.components.simple``; the two base types are
# re-exported from the core package.
__all__ = [
    "SimpleComponent", "SimpleComponentType",
    "SimpleTextComponent", "SimpleImageComponent", "SimpleLinkComponent",
]
================================================
FILE: src/vanna/components/simple/image.py
================================================
"""Simple image component."""
from typing import Optional
from pydantic import Field
from ...core.simple_component import SimpleComponent, SimpleComponentType
class SimpleImageComponent(SimpleComponent):
    """A simple image component.

    Carries only an image URL and optional alt text. ``url`` is a plain
    ``str``; no URL-format validation happens here.
    """

    type: SimpleComponentType = SimpleComponentType.IMAGE
    url: str = Field(..., description="The URL of the image to display.")
    alt_text: Optional[str] = Field(
        default=None, description="Alternative text for the image."
    )
================================================
FILE: src/vanna/components/simple/link.py
================================================
"""Simple link component."""
from typing import Optional
from pydantic import Field
from ...core.simple_component import SimpleComponent, SimpleComponentType
class SimpleLinkComponent(SimpleComponent):
    """A simple link component.

    ``url`` is a plain ``str`` (not validated here). ``text`` is optional;
    what the renderer shows when it is ``None`` is not defined in this class
    (presumably the URL itself; TODO confirm).
    """

    type: SimpleComponentType = SimpleComponentType.LINK
    url: str = Field(..., description="The URL the link points to.")
    text: Optional[str] = Field(
        default=None, description="The display text for the link."
    )
================================================
FILE: src/vanna/components/simple/text.py
================================================
"""Simple text component."""
from pydantic import Field
from ...core.simple_component import SimpleComponent, SimpleComponentType
class SimpleTextComponent(SimpleComponent):
    """A simple text component.

    The plain-text fallback used alongside rich components (for example, the
    agent pairs it with a status card when reporting errors).
    """

    type: SimpleComponentType = SimpleComponentType.TEXT
    text: str = Field(..., description="The text content to display.")
================================================
FILE: src/vanna/core/__init__.py
================================================
"""
Core components of the Vanna Agents framework.
This package contains the fundamental abstractions and implementations
that form the foundation of the agent framework.
"""
# Core domains - re-export from new structure
from .tool import T, Tool, ToolCall, ToolContext, ToolResult, ToolSchema
from .llm import LlmMessage, LlmRequest, LlmResponse, LlmService, LlmStreamChunk
from .storage import Conversation, ConversationStore, Message
from .user import User, UserService
from .agent import Agent, AgentConfig
from .system_prompt import DefaultSystemPromptBuilder, SystemPromptBuilder
from .lifecycle import LifecycleHook
from .middleware import LlmMiddleware
from .workflow import WorkflowHandler, WorkflowResult, DefaultWorkflowHandler
from .recovery import ErrorRecoveryStrategy, RecoveryAction, RecoveryActionType
from .enricher import ToolContextEnricher
from .enhancer import LlmContextEnhancer, DefaultLlmContextEnhancer
from .filter import ConversationFilter
from .observability import ObservabilityProvider, Span, Metric
from .audit import (
AuditLogger,
AuditEvent,
AuditEventType,
ToolAccessCheckEvent,
ToolInvocationEvent,
ToolResultEvent,
UiFeatureAccessCheckEvent,
AiResponseEvent,
)
# UI Components
from .components import UiComponent
from .rich_component import RichComponent
from ..components import (
SimpleComponent,
SimpleComponentType,
SimpleImageComponent,
SimpleLinkComponent,
SimpleTextComponent,
ArtifactComponent,
BadgeComponent,
CardComponent,
DataFrameComponent,
IconTextComponent,
LogViewerComponent,
NotificationComponent,
ProgressBarComponent,
ProgressDisplayComponent,
RichTextComponent,
StatusCardComponent,
TaskListComponent,
)
# Exceptions
from .errors import (
AgentError,
ConversationNotFoundError,
LlmServiceError,
PermissionError,
ToolExecutionError,
ToolNotFoundError,
ValidationError,
)
# Core implementations
from .registry import ToolRegistry
# Evaluation framework
from .evaluation import (
Evaluator,
TestCase,
ExpectedOutcome,
AgentResult,
EvaluationResult,
TestCaseResult,
AgentVariant,
EvaluationRunner,
TrajectoryEvaluator,
OutputEvaluator,
LLMAsJudgeEvaluator,
EfficiencyEvaluator,
EvaluationReport,
ComparisonReport,
EvaluationDataset,
)
# Rebuild models to resolve forward references after all imports.
# ToolContext/ToolResult presumably have annotations referring to types
# (e.g. UiComponent) that were not resolvable when .tool.models was first
# imported. Calling model_rebuild() (the pydantic v2 API) now, after the
# imports above have run, re-resolves them.
# NOTE(review): ToolContext/ToolResult and UiComponent were already imported
# earlier in this module (from .tool and .components). Re-importing is
# presumably harmless (same objects); confirm before removing.
from .tool.models import ToolContext, ToolResult
from .components import UiComponent  # Import UiComponent to ensure it's available
ToolContext.model_rebuild()
ToolResult.model_rebuild()
# Public API of ``vanna.core``. Each name must appear exactly once.
# NOTE(review): "PermissionError" shadows the Python builtin of the same
# name in any module that does ``from vanna.core import *``.
__all__ = [
    # Models
    "User",
    "Message",
    "Conversation",
    "ToolCall",
    "ToolResult",
    "ToolContext",
    "ToolSchema",
    "LlmMessage",
    "LlmRequest",
    "LlmResponse",
    "LlmStreamChunk",
    "RecoveryAction",
    "RecoveryActionType",
    "Span",
    "Metric",
    # Interfaces
    "Tool",
    "LlmService",
    "ConversationStore",
    "UserService",
    "SystemPromptBuilder",
    "LifecycleHook",
    "LlmMiddleware",
    "WorkflowHandler",
    "DefaultWorkflowHandler",
    "WorkflowResult",
    "ErrorRecoveryStrategy",
    "ToolContextEnricher",
    "LlmContextEnhancer",
    "DefaultLlmContextEnhancer",
    "ConversationFilter",
    "ObservabilityProvider",
    "AuditLogger",
    "T",
    # Audit
    "AuditEvent",
    "AuditEventType",
    "ToolAccessCheckEvent",
    "ToolInvocationEvent",
    "ToolResultEvent",
    "UiFeatureAccessCheckEvent",
    "AiResponseEvent",
    # UI Components
    "UiComponent",
    # Simple Components
    "SimpleComponent",
    "SimpleComponentType",
    "SimpleTextComponent",
    "SimpleImageComponent",
    "SimpleLinkComponent",
    # Rich Components
    "RichComponent",
    "ArtifactComponent",
    "BadgeComponent",
    "CardComponent",
    "DataFrameComponent",
    "IconTextComponent",
    "LogViewerComponent",
    "NotificationComponent",
    "ProgressBarComponent",
    "ProgressDisplayComponent",
    "RichTextComponent",
    "StatusCardComponent",
    "TaskListComponent",
    # Core implementations
    "ToolRegistry",
    "Agent",
    "AgentConfig",
    "DefaultSystemPromptBuilder",
    # Evaluation
    "Evaluator",
    "TestCase",
    "ExpectedOutcome",
    "AgentResult",
    "EvaluationResult",
    "TestCaseResult",
    "AgentVariant",
    "EvaluationRunner",
    "TrajectoryEvaluator",
    "OutputEvaluator",
    "LLMAsJudgeEvaluator",
    "EfficiencyEvaluator",
    "EvaluationReport",
    "ComparisonReport",
    "EvaluationDataset",
    # Exceptions
    "AgentError",
    "ToolExecutionError",
    "ToolNotFoundError",
    "PermissionError",
    "ConversationNotFoundError",
    "LlmServiceError",
    "ValidationError",
]
================================================
FILE: src/vanna/core/_compat.py
================================================
"""
Compatibility shims for different Python versions.
This module provides compatibility utilities for features that vary across
Python versions.
"""
try:
    from enum import StrEnum  # Py 3.11+
except ImportError:  # Py < 3.11
    from enum import Enum

    class StrEnum(str, Enum):  # type: ignore[no-redef]
        """Backport of ``enum.StrEnum`` for Python < 3.11.

        A bare ``class StrEnum(str, Enum)`` differs from the stdlib class in
        two visible ways: ``str(member)`` returns ``"ClassName.MEMBER"``, and
        ``auto()`` does not produce the lower-cased member name. This
        backport matches 3.11+ behaviour for both: ``str()``/``format()``
        yield the member's value, and ``auto()`` yields ``name.lower()``.
        """

        # Defined as real functions rather than aliases of str.__str__ /
        # str.__format__: on older Pythons the Enum metaclass replaces
        # mixin-identical __str__/__format__ with Enum's own versions.
        def __str__(self) -> str:
            return str.__str__(self)

        def __format__(self, format_spec: str) -> str:
            return str.__format__(self, format_spec)

        @staticmethod
        def _generate_next_value_(name, start, count, last_values):  # type: ignore[override]
            # Same rule as the stdlib StrEnum: auto() -> lower-cased member name.
            return name.lower()


__all__ = ["StrEnum"]
================================================
FILE: src/vanna/core/agent/__init__.py
================================================
"""
Agent module.
This module contains the core Agent implementation and configuration.
"""
from .agent import Agent
from .config import AgentConfig
# Public names of ``vanna.core.agent``.
__all__ = [
    "Agent",
    "AgentConfig",
]
================================================
FILE: src/vanna/core/agent/agent.py
================================================
"""
Agent implementation for the Vanna Agents framework.
This module provides the main Agent class that orchestrates the interaction
between LLM services, tools, and conversation storage.
"""
import traceback
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, List, Optional
from vanna.components import (
UiComponent,
SimpleTextComponent,
RichTextComponent,
StatusBarUpdateComponent,
TaskTrackerUpdateComponent,
ChatInputUpdateComponent,
StatusCardComponent,
Task,
)
from .config import AgentConfig
from vanna.core.storage import ConversationStore
from vanna.core.llm import LlmService
from vanna.core.system_prompt import SystemPromptBuilder
from vanna.core.storage import Conversation, Message
from vanna.core.llm import LlmMessage, LlmRequest, LlmResponse
from vanna.core.tool import ToolCall, ToolContext, ToolResult, ToolSchema
from vanna.core.user import User
from vanna.core.registry import ToolRegistry
from vanna.core.system_prompt import DefaultSystemPromptBuilder
from vanna.core.lifecycle import LifecycleHook
from vanna.core.middleware import LlmMiddleware
from vanna.core.workflow import WorkflowHandler, DefaultWorkflowHandler
from vanna.core.recovery import ErrorRecoveryStrategy, RecoveryActionType
from vanna.core.enricher import ToolContextEnricher
from vanna.core.enhancer import LlmContextEnhancer, DefaultLlmContextEnhancer
from vanna.core.filter import ConversationFilter
from vanna.core.observability import ObservabilityProvider
from vanna.core.user.resolver import UserResolver
from vanna.core.user.request_context import RequestContext
from vanna.core.agent.config import UiFeature
from vanna.core.audit import AuditLogger
from vanna.capabilities.agent_memory import AgentMemory
import logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Runs once, when this module is first imported.
logger.info("Loaded vanna.core.agent.agent module")
# NOTE(review): empty TYPE_CHECKING block, a placeholder for type-only
# imports. It currently has no effect.
if TYPE_CHECKING:
    pass
class Agent:
"""Main agent implementation.
The Agent class orchestrates LLM interactions, tool execution, and conversation
management. It provides 7 extensibility points for customization:
- lifecycle_hooks: Hook into message and tool execution lifecycle
- llm_middlewares: Intercept and transform LLM requests/responses
- error_recovery_strategy: Handle errors with retry logic
- context_enrichers: Add data to tool execution context
- llm_context_enhancer: Enhance LLM system prompts and messages with context
- conversation_filters: Filter conversation history before LLM calls
- observability_provider: Collect telemetry and monitoring data
Example:
agent = Agent(
llm_service=AnthropicLlmService(api_key="..."),
tool_registry=registry,
conversation_store=store,
lifecycle_hooks=[QuotaCheckHook()],
llm_middlewares=[CachingMiddleware()],
llm_context_enhancer=DefaultLlmContextEnhancer(agent_memory),
observability_provider=LoggingProvider()
)
"""
def __init__(
    self,
    llm_service: LlmService,
    tool_registry: ToolRegistry,
    user_resolver: UserResolver,
    agent_memory: AgentMemory,
    conversation_store: Optional[ConversationStore] = None,
    config: Optional[AgentConfig] = None,
    system_prompt_builder: Optional[SystemPromptBuilder] = None,
    lifecycle_hooks: Optional[List[LifecycleHook]] = None,
    llm_middlewares: Optional[List[LlmMiddleware]] = None,
    workflow_handler: Optional[WorkflowHandler] = None,
    error_recovery_strategy: Optional[ErrorRecoveryStrategy] = None,
    context_enrichers: Optional[List[ToolContextEnricher]] = None,
    llm_context_enhancer: Optional[LlmContextEnhancer] = None,
    conversation_filters: Optional[List[ConversationFilter]] = None,
    observability_provider: Optional[ObservabilityProvider] = None,
    audit_logger: Optional[AuditLogger] = None,
):
    """Wire up the agent's services and extension points.

    Python evaluates default-argument expressions once, when the ``def``
    runs. Defaults such as ``[]`` or ``AgentConfig()`` would therefore be
    shared by every Agent created without that argument. Omitted values are
    instead created fresh here, per instance. Objects the caller passes in
    are stored as-is (not copied).

    Args:
        llm_service: LLM backend.
        tool_registry: Registry of tools available to the agent. If
            ``audit_logger`` is given and ``config.audit_config.enabled`` is
            true, the logger and audit config are attached to this registry.
        user_resolver: Resolves the user for each request context.
        agent_memory: Agent memory; also used to build the default
            ``DefaultLlmContextEnhancer``.
        conversation_store: Defaults to ``MemoryConversationStore()`` from
            ``vanna.integrations.local``.
        config: Defaults to a fresh ``AgentConfig()``.
        system_prompt_builder: Defaults to a fresh ``DefaultSystemPromptBuilder()``.
        lifecycle_hooks: Hooks into the message and tool execution lifecycle.
        llm_middlewares: Intercept and transform LLM requests/responses.
        workflow_handler: Defaults to ``DefaultWorkflowHandler()``.
        error_recovery_strategy: Optional error-handling/retry strategy.
        context_enrichers: Add data to the tool execution context.
        llm_context_enhancer: Defaults to ``DefaultLlmContextEnhancer(agent_memory)``.
        conversation_filters: Filter conversation history before LLM calls.
        observability_provider: Optional telemetry provider.
        audit_logger: Optional audit logger.
    """
    self.llm_service = llm_service
    self.tool_registry = tool_registry
    self.user_resolver = user_resolver
    self.agent_memory = agent_memory

    # Import here to avoid circular dependency
    if conversation_store is None:
        from vanna.integrations.local import MemoryConversationStore

        conversation_store = MemoryConversationStore()
    self.conversation_store = conversation_store

    self.config = config if config is not None else AgentConfig()
    self.system_prompt_builder = (
        system_prompt_builder
        if system_prompt_builder is not None
        else DefaultSystemPromptBuilder()
    )
    self.lifecycle_hooks = lifecycle_hooks if lifecycle_hooks is not None else []
    self.llm_middlewares = llm_middlewares if llm_middlewares is not None else []

    # Use DefaultWorkflowHandler if none provided
    if workflow_handler is None:
        workflow_handler = DefaultWorkflowHandler()
    self.workflow_handler = workflow_handler

    self.error_recovery_strategy = error_recovery_strategy
    self.context_enrichers = (
        context_enrichers if context_enrichers is not None else []
    )

    # Use DefaultLlmContextEnhancer if none provided
    if llm_context_enhancer is None:
        llm_context_enhancer = DefaultLlmContextEnhancer(agent_memory)
    self.llm_context_enhancer = llm_context_enhancer

    self.conversation_filters = (
        conversation_filters if conversation_filters is not None else []
    )
    self.observability_provider = observability_provider
    self.audit_logger = audit_logger

    # Wire audit logger into tool registry
    if self.audit_logger and self.config.audit_config.enabled:
        self.tool_registry.audit_logger = self.audit_logger
        self.tool_registry.audit_config = self.config.audit_config

    logger.info("Initialized Agent")
async def send_message(
    self,
    request_context: RequestContext,
    message: str,
    *,
    conversation_id: Optional[str] = None,
) -> AsyncGenerator[UiComponent, None]:
    """
    Process a user message and yield UI components with error handling.

    Delegates to ``_send_message`` and re-yields its components. If that
    raises any ``Exception``, the error is not propagated. Instead it is:
    logged with its traceback, reported to the observability provider (if
    configured; failures there are themselves only logged), and turned into
    three UI components: an error status card, an error status-bar update,
    and a chat-input update that re-enables the input.

    Args:
        request_context: Request context for user resolution (includes metadata)
        message: User's message content
        conversation_id: Optional conversation ID; if None, creates new conversation

    Yields:
        UiComponent instances for UI updates

    NOTE: ``conversation_id`` in the error output is the caller's argument.
    An id generated inside ``_send_message`` for a new conversation is not
    visible here.
    """
    try:
        # Delegate to internal method
        async for component in self._send_message(
            request_context, message, conversation_id=conversation_id
        ):
            yield component
    except Exception as e:
        # exc_info=True already attaches the full traceback; formatting it
        # into the message as well would print the stack trace twice.
        logger.error(
            "Error in send_message (conversation_id=%s): %s",
            conversation_id,
            e,
            exc_info=True,
        )

        # Log to observability provider if available
        if self.observability_provider:
            try:
                error_span = await self.observability_provider.create_span(
                    "agent.send_message.error",
                    attributes={
                        "error_type": type(e).__name__,
                        "error_message": str(e),
                        "conversation_id": conversation_id or "none",
                    },
                )
                await self.observability_provider.end_span(error_span)
                await self.observability_provider.record_metric(
                    "agent.error.count",
                    1.0,
                    "count",
                    tags={"error_type": type(e).__name__},
                )
            except Exception as obs_error:
                # Best-effort: telemetry failures must not mask the original error.
                logger.error(
                    "Failed to log error to observability provider: %s",
                    obs_error,
                    exc_info=True,
                )

        # Yield error component to UI (simple, user-friendly message)
        error_description = "An unexpected error occurred while processing your message. Please try again."
        if conversation_id:
            error_description += f"\n\nConversation ID: {conversation_id}"

        yield UiComponent(
            rich_component=StatusCardComponent(
                title="Error Processing Message",
                status="error",
                description=error_description,
                icon="⚠️",
            ),
            simple_component=SimpleTextComponent(
                text=f"Error: An unexpected error occurred. Please try again.{f' (Conversation ID: {conversation_id})' if conversation_id else ''}"
            ),
        )

        # Update status bar to show error state
        yield UiComponent(  # type: ignore
            rich_component=StatusBarUpdateComponent(
                status="error",
                message="Error occurred",
                detail="An unexpected error occurred while processing your message",
            )
        )

        # Re-enable chat input so user can try again
        yield UiComponent(  # type: ignore
            rich_component=ChatInputUpdateComponent(
                placeholder="Try again...", disabled=False
            )
        )
async def _send_message(
self,
request_context: RequestContext,
message: str,
*,
conversation_id: Optional[str] = None,
) -> AsyncGenerator[UiComponent, None]:
"""
Internal method to process a user message and yield UI components.
Args:
request_context: Request context for user resolution (includes metadata)
message: User's message content
conversation_id: Optional conversation ID; if None, creates new conversation
Yields:
UiComponent instances for UI updates
"""
# Resolve user from request context with observability
user_resolution_span = None
if self.observability_provider:
user_resolution_span = await self.observability_provider.create_span(
"agent.user_resolution",
attributes={"has_context": request_context is not None},
)
user = await self.user_resolver.resolve_user(request_context)
if self.observability_provider and user_resolution_span:
user_resolution_span.set_attribute("user_id", user.id)
await self.observability_provider.end_span(user_resolution_span)
if user_resolution_span.duration_ms():
await self.observability_provider.record_metric(
"agent.user_resolution.duration",
user_resolution_span.duration_ms() or 0,
"ms",
)
# Check if this is a starter UI request (empty message or explicit metadata flag)
is_starter_request = (not message.strip()) or request_context.metadata.get(
"starter_ui_request", False
)
if is_starter_request and self.workflow_handler:
# Handle starter UI request with observability
starter_span = None
if self.observability_provider:
starter_span = await self.observability_provider.create_span(
"agent.workflow_handler.starter_ui", attributes={"user_id": user.id}
)
try:
# Load or create conversation for context
if conversation_id is None:
conversation_id = str(uuid.uuid4())
conversation = await self.conversation_store.get_conversation(
conversation_id, user
)
if not conversation:
# Create empty conversation (will be saved if workflow produces components)
conversation = Conversation(
id=conversation_id, user=user, messages=[]
)
# Get starter UI from workflow handler
components = await self.workflow_handler.get_starter_ui(
self, user, conversation
)
if self.observability_provider and starter_span:
starter_span.set_attribute("has_components", components is not None)
starter_span.set_attribute(
"component_count", len(components) if components else 0
)
if components:
# Yield the starter UI components
for component in components:
yield component
# Yield finalization components
yield UiComponent( # type: ignore
rich_component=StatusBarUpdateComponent(
status="idle",
message="Ready",
detail="Choose an option or type a message",
)
)
yield UiComponent( # type: ignore
rich_component=ChatInputUpdateComponent(
placeholder="Ask a question...", disabled=False
)
)
if self.observability_provider and starter_span:
await self.observability_provider.end_span(starter_span)
if starter_span.duration_ms():
await self.observability_provider.record_metric(
"agent.workflow_handler.starter_ui.duration",
starter_span.duration_ms() or 0,
"ms",
)
# Save the conversation if it was newly created
if self.config.auto_save_conversations:
await self.conversation_store.update_conversation(conversation)
return # Exit without calling LLM
except Exception as e:
logger.error(f"Error generating starter UI: {e}", exc_info=True)
if self.observability_provider and starter_span:
starter_span.set_attribute("error", str(e))
await self.observability_provider.end_span(starter_span)
# Fall through to normal processing on error
# Don't process actual empty messages (that aren't starter requests)
if not message.strip():
return
# Create observability span for entire message processing
message_span = None
if self.observability_provider:
message_span = await self.observability_provider.create_span(
"agent.send_message",
attributes={
"user_id": user.id,
"conversation_id": conversation_id or "new",
},
)
# Run before_message hooks with observability
modified_message = message
for hook in self.lifecycle_hooks:
hook_span = None
if self.observability_provider:
hook_span = await self.observability_provider.create_span(
"agent.hook.before_message",
attributes={"hook": hook.__class__.__name__},
)
hook_result = await hook.before_message(user, modified_message)
if hook_result is not None:
modified_message = hook_result
if self.observability_provider and hook_span:
hook_span.set_attribute("modified_message", hook_result is not None)
await self.observability_provider.end_span(hook_span)
if hook_span.duration_ms():
await self.observability_provider.record_metric(
"agent.hook.duration",
hook_span.duration_ms() or 0,
"ms",
tags={
"hook": hook.__class__.__name__,
"phase": "before_message",
},
)
# Use the potentially modified message
message = modified_message
# Generate conversation ID and request ID if not provided
if conversation_id is None:
conversation_id = str(uuid.uuid4())
request_id = str(uuid.uuid4())
# Update status to working
yield UiComponent( # type: ignore
rich_component=StatusBarUpdateComponent(
status="working",
message="Processing your request...",
detail="Analyzing query",
)
)
# Load or create conversation with observability (but don't add message yet)
conversation_span = None
if self.observability_provider:
conversation_span = await self.observability_provider.create_span(
"agent.conversation.load",
attributes={"conversation_id": conversation_id, "user_id": user.id},
)
conversation = await self.conversation_store.get_conversation(
conversation_id, user
)
is_new_conversation = conversation is None
if not conversation:
# Create empty conversation (will add message after workflow handler check)
conversation = Conversation(id=conversation_id, user=user, messages=[])
if self.observability_provider and conversation_span:
conversation_span.set_attribute("is_new", is_new_conversation)
conversation_span.set_attribute("message_count", len(conversation.messages))
await self.observability_provider.end_span(conversation_span)
if conversation_span.duration_ms():
await self.observability_provider.record_metric(
"agent.conversation.load.duration",
conversation_span.duration_ms() or 0,
"ms",
tags={"is_new": str(is_new_conversation)},
)
# Try workflow handler before adding message to conversation
if self.workflow_handler:
trigger_span = None
if self.observability_provider:
trigger_span = await self.observability_provider.create_span(
"agent.workflow_handler.try_handle",
attributes={"user_id": user.id, "conversation_id": conversation_id},
)
try:
workflow_result = await self.workflow_handler.try_handle(
self, user, conversation, message
)
if self.observability_provider and trigger_span:
trigger_span.set_attribute(
"should_skip_llm", workflow_result.should_skip_llm
)
if workflow_result.should_skip_llm:
# Workflow handled the message, short-circuit LLM
# Apply conversation mutation if provided
if workflow_result.conversation_mutation:
await workflow_result.conversation_mutation(conversation)
# Stream components
if workflow_result.components:
if isinstance(workflow_result.components, list):
for component in workflow_result.components:
yield component
else:
# AsyncGenerator
async for component in workflow_result.components:
yield component
# Finalize response (status bar + chat input)
yield UiComponent( # type: ignore
rich_component=StatusBarUpdateComponent(
status="idle",
message="Workflow complete",
detail="Ready for next message",
)
)
yield UiComponent( # type: ignore
rich_component=ChatInputUpdateComponent(
placeholder="Ask a question...", disabled=False
)
)
# Save conversation if auto-save enabled
if self.config.auto_save_conversations:
await self.conversation_store.update_conversation(conversation)
if self.observability_provider and trigger_span:
await self.observability_provider.end_span(trigger_span)
# Exit without calling LLM
return
except Exception as e:
logger.error(f"Error in workflow handler: {e}", exc_info=True)
if self.observability_provider and trigger_span:
trigger_span.set_attribute("error", str(e))
await self.observability_provider.end_span(trigger_span)
# Fall through to normal LLM processing on error
finally:
if self.observability_provider and trigger_span:
await self.observability_provider.end_span(trigger_span)
# Persist new conversation to store before adding message
if is_new_conversation:
await self.conversation_store.update_conversation(conversation)
# Not triggered, add user message to conversation now
conversation.add_message(Message(role="user", content=message))
# Add initial task
context_task = Task(
title="Load conversation context",
description="Reading message history and user context",
status="pending",
)
yield UiComponent( # type: ignore
rich_component=TaskTrackerUpdateComponent.add_task(context_task)
)
# Collect available UI features for auditing
ui_features_available = []
for feature_name in self.config.ui_features.feature_group_access.keys():
if self.config.ui_features.can_user_access_feature(feature_name, user):
ui_features_available.append(feature_name)
# Create context with observability provider and UI features
context = ToolContext(
user=user,
conversation_id=conversation_id,
request_id=request_id,
agent_memory=self.agent_memory,
observability_provider=self.observability_provider,
metadata={"ui_features_available": ui_features_available},
)
# Enrich context with additional data with observability
for enricher in self.context_enrichers:
enrichment_span = None
if self.observability_provider:
enrichment_span = await self.observability_provider.create_span(
"agent.context.enrichment",
attributes={"enricher": enricher.__class__.__name__},
)
context = await enricher.enrich_context(context)
if self.observability_provider and enrichment_span:
await self.observability_provider.end_span(enrichment_span)
if enrichment_span.duration_ms():
await self.observability_provider.record_metric(
"agent.enrichment.duration",
enrichment_span.duration_ms() or 0,
"ms",
tags={"enricher": enricher.__class__.__name__},
)
# Get available tools for user with observability
schema_span = None
if self.observability_provider:
schema_span = await self.observability_provider.create_span(
"agent.tool_schemas.fetch", attributes={"user_id": user.id}
)
tool_schemas = await self.tool_registry.get_schemas(user)
if self.observability_provider and schema_span:
schema_span.set_attribute("schema_count", len(tool_schemas))
await self.observability_provider.end_span(schema_span)
if schema_span.duration_ms():
await self.observability_provider.record_metric(
"agent.tool_schemas.duration",
schema_span.duration_ms() or 0,
"ms",
tags={"schema_count": str(len(tool_schemas))},
)
# Update task status to completed
yield UiComponent( # type: ignore
rich_component=TaskTrackerUpdateComponent.update_task(
context_task.id, status="completed"
)
)
# Build system prompt with observability
prompt_span = None
if self.observability_provider:
prompt_span = await self.observability_provider.create_span(
"agent.system_prompt.build",
attributes={"tool_count": len(tool_schemas)},
)
system_prompt = await self.system_prompt_builder.build_system_prompt(
user, tool_schemas
)
# Enhance system prompt with LLM context enhancer
if self.llm_context_enhancer and system_prompt is not None:
enhancement_span = None
if self.observability_provider:
enhancement_span = await self.observability_provider.create_span(
"agent.llm_context.enhance_system_prompt",
attributes={
"enhancer": self.llm_context_enhancer.__class__.__name__
},
)
system_prompt = await self.llm_context_enhancer.enhance_system_prompt(
system_prompt, message, user
)
if self.observability_provider and enhancement_span:
await self.observability_provider.end_span(enhancement_span)
if enhancement_span.duration_ms():
await self.observability_provider.record_metric(
"agent.llm_context.enhance_system_prompt.duration",
enhancement_span.duration_ms() or 0,
"ms",
tags={"enhancer": self.llm_context_enhancer.__class__.__name__},
)
if self.observability_provider and prompt_span:
prompt_span.set_attribute(
"prompt_length", len(system_prompt) if system_prompt else 0
)
await self.observability_provider.end_span(prompt_span)
if prompt_span.duration_ms():
await self.observability_provider.record_metric(
"agent.system_prompt.duration", prompt_span.duration_ms() or 0, "ms"
)
# Build LLM request
request = await self._build_llm_request(
conversation, tool_schemas, user, system_prompt
)
# Process with tool loop
tool_iterations = 0
while tool_iterations < self.config.max_tool_iterations:
if self.config.include_thinking_indicators and tool_iterations == 0:
# TODO: Yield thinking indicator
pass
# Get LLM response
if self.config.stream_responses:
response = await self._handle_streaming_response(request)
else:
response = await self._send_llm_request(request)
# Handle tool calls
if response.is_tool_call():
tool_iterations += 1
# First, add the assistant message with tool_calls to the conversation
# This is required for OpenAI API - tool messages must follow assistant messages with tool_calls
assistant_message = Message(
role="assistant",
content=response.content or "", # Ensure content is not None
tool_calls=response.tool_calls,
)
conversation.add_message(assistant_message)
if response.content is not None:
# Yield any partial content from the assistant before tool execution
has_tool_invocation_message_in_chat = (
self.config.ui_features.can_user_access_feature(
UiFeature.UI_FEATURE_SHOW_TOOL_INVOCATION_MESSAGE_IN_CHAT,
user,
)
)
if has_tool_invocation_message_in_chat:
yield UiComponent(
rich_component=RichTextComponent(
content=response.content, markdown=True
),
simple_component=SimpleTextComponent(text=response.content),
)
# Update status to executing tools
yield UiComponent( # type: ignore
rich_component=StatusBarUpdateComponent(
status="working",
message="Executing tools...",
detail=f"Running {len(response.tool_calls or [])} tools",
)
)
else:
# Yield as a status update instead
yield UiComponent( # type: ignore
rich_component=StatusBarUpdateComponent(
status="working", message=response.content, detail=""
)
)
# Collect all tool results first
tool_results = []
for i, tool_call in enumerate(response.tool_calls or []):
# Add task for this tool execution
tool_task = Task(
title=f"Execute {tool_call.name}",
description=f"Running tool with provided arguments",
status="in_progress",
)
has_tool_names_access = (
self.config.ui_features.can_user_access_feature(
UiFeature.UI_FEATURE_SHOW_TOOL_NAMES, user
)
)
# Audit UI feature access check
if (
self.audit_logger
and self.config.audit_config.enabled
and self.config.audit_config.log_ui_feature_checks
):
await self.audit_logger.log_ui_feature_access(
user=user,
feature_name=UiFeature.UI_FEATURE_SHOW_TOOL_NAMES,
access_granted=has_tool_names_access,
required_groups=self.config.ui_features.feature_group_access.get(
UiFeature.UI_FEATURE_SHOW_TOOL_NAMES, []
),
conversation_id=conversation.id,
request_id=request_id,
)
if has_tool_names_access:
yield UiComponent( # type: ignore
rich_component=TaskTrackerUpdateComponent.add_task(
tool_task
)
)
response_str = response.content
# Use primitive StatusCard instead of semantic ToolExecutionComponent
tool_status_card = StatusCardComponent(
title=f"Executing {tool_call.name}",
status="running",
description=f"Running tool with {len(tool_call.arguments)} arguments",
icon="⚙️",
metadata=tool_call.arguments,
)
has_tool_args_access = (
self.config.ui_features.can_user_access_feature(
UiFeature.UI_FEATURE_SHOW_TOOL_ARGUMENTS, user
)
)
# Audit UI feature access check
if (
self.audit_logger
and self.config.audit_config.enabled
and self.config.audit_config.log_ui_feature_checks
):
await self.audit_logger.log_ui_feature_access(
user=user,
feature_name=UiFeature.UI_FEATURE_SHOW_TOOL_ARGUMENTS,
access_granted=has_tool_args_access,
required_groups=self.config.ui_features.feature_group_access.get(
UiFeature.UI_FEATURE_SHOW_TOOL_ARGUMENTS, []
),
conversation_id=conversation.id,
request_id=request_id,
)
if has_tool_args_access:
yield UiComponent(
rich_component=tool_status_card,
simple_component=SimpleTextComponent(
text=response_str or ""
),
)
# Run before_tool hooks with observability
tool = await self.tool_registry.get_tool(tool_call.name)
if tool:
for hook in self.lifecycle_hooks:
hook_span = None
if self.observability_provider:
hook_span = (
await self.observability_provider.create_span(
"agent.hook.before_tool",
attributes={
"hook": hook.__class__.__name__,
"tool": tool_call.name,
},
)
)
await hook.before_tool(tool, context)
if self.observability_provider and hook_span:
await self.observability_provider.end_span(hook_span)
if hook_span.duration_ms():
await self.observability_provider.record_metric(
"agent.hook.duration",
hook_span.duration_ms() or 0,
"ms",
tags={
"hook": hook.__class__.__name__,
"phase": "before_tool",
"tool": tool_call.name,
},
)
# Execute tool with observability
tool_exec_span = None
if self.observability_provider:
tool_exec_span = await self.observability_provider.create_span(
"agent.tool.execute",
attributes={
"tool": tool_call.name,
"arg_count": len(tool_call.arguments),
},
)
result = await self.tool_registry.execute(tool_call, context)
if self.observability_provider and tool_exec_span:
tool_exec_span.set_attribute("success", result.success)
if not result.success:
tool_exec_span.set_attribute(
"error", result.error or "unknown"
)
await self.observability_provider.end_span(tool_exec_span)
if tool_exec_span.duration_ms():
await self.observability_provider.record_metric(
"agent.tool.duration",
tool_exec_span.duration_ms() or 0,
"ms",
tags={
"tool": tool_call.name,
"success": str(result.success),
},
)
# Run after_tool hooks with observability
for hook in self.lifecycle_hooks:
hook_span = None
if self.observability_provider:
hook_span = await self.observability_provider.create_span(
"agent.hook.after_tool",
attributes={
"hook": hook.__class__.__name__,
"tool": tool_call.name,
},
)
modified_result = await hook.after_tool(result)
if modified_result is not None:
result = modified_result
if self.observability_provider and hook_span:
hook_span.set_attribute(
"modified_result", modified_result is not None
)
await self.observability_provider.end_span(hook_span)
if hook_span.duration_ms():
await self.observability_provider.record_metric(
"agent.hook.duration",
hook_span.duration_ms() or 0,
"ms",
tags={
"hook": hook.__class__.__name__,
"phase": "after_tool",
"tool": tool_call.name,
},
)
# Update status card to show completion
final_status = "success" if result.success else "error"
final_description = (
f"Tool completed successfully"
if result.success
else f"Tool failed: {result.error or 'Unknown error'}"
)
has_tool_args_access_2 = (
self.config.ui_features.can_user_access_feature(
UiFeature.UI_FEATURE_SHOW_TOOL_ARGUMENTS, user
)
)
# Audit UI feature access check
if (
self.audit_logger
and self.config.audit_config.enabled
and self.config.audit_config.log_ui_feature_checks
):
await self.audit_logger.log_ui_feature_access(
user=user,
feature_name=UiFeature.UI_FEATURE_SHOW_TOOL_ARGUMENTS,
access_granted=has_tool_args_access_2,
required_groups=self.config.ui_features.feature_group_access.get(
UiFeature.UI_FEATURE_SHOW_TOOL_ARGUMENTS, []
),
conversation_id=conversation.id,
request_id=request_id,
)
if has_tool_args_access_2:
yield UiComponent(
rich_component=tool_status_card.set_status(
final_status, final_description
),
simple_component=SimpleTextComponent(
text=final_description
),
)
has_tool_names_access_2 = (
self.config.ui_features.can_user_access_feature(
UiFeature.UI_FEATURE_SHOW_TOOL_NAMES, user
)
)
# Audit UI feature access check
if (
self.audit_logger
and self.config.audit_config.enabled
and self.config.audit_config.log_ui_feature_checks
):
await self.audit_logger.log_ui_feature_access(
user=user,
feature_name=UiFeature.UI_FEATURE_SHOW_TOOL_NAMES,
access_granted=has_tool_names_access_2,
required_groups=self.config.ui_features.feature_group_access.get(
UiFeature.UI_FEATURE_SHOW_TOOL_NAMES, []
),
conversation_id=conversation.id,
request_id=request_id,
)
if has_tool_names_access_2:
# Update tool task to completed
yield UiComponent( # type: ignore
rich_component=TaskTrackerUpdateComponent.update_task(
tool_task.id,
status="completed",
detail=f"Tool {'completed successfully' if result.success else 'return an error'}",
)
)
# Yield tool result
if result.ui_component:
# For errors, check if user has access to see error details
if not result.success:
has_tool_error_access = (
self.config.ui_features.can_user_access_feature(
UiFeature.UI_FEATURE_SHOW_TOOL_ERROR, user
)
)
# Audit UI feature access check
if (
self.audit_logger
and self.config.audit_config.enabled
and self.config.audit_config.log_ui_feature_checks
):
await self.audit_logger.log_ui_feature_access(
user=user,
feature_name=UiFeature.UI_FEATURE_SHOW_TOOL_ERROR,
access_granted=has_tool_error_access,
required_groups=self.config.ui_features.feature_group_access.get(
UiFeature.UI_FEATURE_SHOW_TOOL_ERROR, []
),
conversation_id=conversation.id,
request_id=request_id,
)
if has_tool_error_access:
yield result.ui_component
else:
# Success results are always shown if they exist
yield result.ui_component
# Collect tool result data
tool_results.append(
{
"tool_call_id": tool_call.id,
"content": (
result.result_for_llm
if result.success
else result.error or "Tool execution failed"
),
}
)
# Add tool responses to conversation
# For APIs that need all tool results in one message, this helps
for tool_result in tool_results:
tool_response_message = Message(
role="tool",
content=tool_result["content"],
tool_call_id=tool_result["tool_call_id"],
)
conversation.add_message(tool_response_message)
# Rebuild request with tool responses
request = await self._build_llm_request(
conversation, tool_schemas, user, system_prompt
)
else:
# Update status to idle and set completion message
yield UiComponent( # type: ignore
rich_component=StatusBarUpdateComponent(
status="idle",
message="Response complete",
detail="Ready for next message",
)
)
# Update chat input placeholder
yield UiComponent( # type: ignore
rich_component=ChatInputUpdateComponent(
placeholder="Ask a follow-up question...", disabled=False
)
)
# Yield final text response
if response.content:
# Add assistant response to conversation
conversation.add_message(
Message(role="assistant", content=response.content)
)
yield UiComponent(
rich_component=RichTextComponent(
content=response.content, markdown=True
),
simple_component=SimpleTextComponent(text=response.content),
)
break
# Check if we hit the tool iteration limit
if tool_iterations >= self.config.max_tool_iterations:
# The loop exited due to hitting the limit, not due to a natural completion
logger.warning(
f"Tool iteration limit reached: {tool_iterations}/{self.config.max_tool_iterations}"
)
# Update status bar to show warning
yield UiComponent( # type: ignore
rich_component=StatusBarUpdateComponent(
status="warning",
message="Tool limit reached",
detail=f"Stopped after {tool_iterations} tool executions. The task may be incomplete.",
)
)
# Provide detailed warning message to user
warning_message = f"""⚠️ **Tool Execution Limit Reached**
The agent stopped after executing {tool_iterations} tools (the configured maximum). The task may not be fully complete.
You can:
- Ask me to continue where I left off
- Adjust the `max_tool_iterations` setting if you need more tool calls
- Break the task into smaller steps"""
yield UiComponent(
rich_component=RichTextComponent(
content=warning_message, markdown=True
),
simple_component=SimpleTextComponent(
text=f"Tool limit reached after {tool_iterations} executions. Task may be incomplete."
),
)
# Update chat input to suggest follow-up
yield UiComponent( # type: ignore
rich_component=ChatInputUpdateComponent(
placeholder="Continue the task or ask me something else...",
disabled=False,
)
)
# Save conversation if configured
if self.config.auto_save_conversations:
save_span = None
if self.observability_provider:
save_span = await self.observability_provider.create_span(
"agent.conversation.save",
attributes={
"conversation_id": conversation_id,
"message_count": len(conversation.messages),
},
)
await self.conversation_store.update_conversation(conversation)
if self.observability_provider and save_span:
await self.observability_provider.end_span(save_span)
if save_span.duration_ms():
await self.observability_provider.record_metric(
"agent.conversation.save.duration",
save_span.duration_ms() or 0,
"ms",
)
# Run after_message hooks with observability
for hook in self.lifecycle_hooks:
hook_span = None
if self.observability_provider:
hook_span = await self.observability_provider.create_span(
"agent.hook.after_message",
attributes={"hook": hook.__class__.__name__},
)
await hook.after_message(conversation)
if self.observability_provider and hook_span:
await self.observability_provider.end_span(hook_span)
if hook_span.duration_ms():
await self.observability_provider.record_metric(
"agent.hook.duration",
hook_span.duration_ms() or 0,
"ms",
tags={
"hook": hook.__class__.__name__,
"phase": "after_message",
},
)
# End observability span and record metrics
if self.observability_provider and message_span:
message_span.set_attribute("tool_iterations", tool_iterations)
# Track if we hit the tool iteration limit
hit_tool_limit = tool_iterations >= self.config.max_tool_iterations
message_span.set_attribute("hit_tool_limit", hit_tool_limit)
if hit_tool_limit:
message_span.set_attribute("incomplete_response", True)
logger.info(
f"Tool limit reached - marking response as potentially incomplete"
)
await self.observability_provider.end_span(message_span)
if message_span.duration_ms():
await self.observability_provider.record_metric(
"agent.message.duration",
message_span.duration_ms() or 0,
"ms",
tags={"user_id": user.id, "hit_tool_limit": str(hit_tool_limit)},
)
async def get_available_tools(self, user: User) -> List[ToolSchema]:
    """Return the tool schemas ``user`` is allowed to use.

    This is a thin delegation to ``self.tool_registry.get_schemas``.
    """
    schemas = await self.tool_registry.get_schemas(user)
    return schemas
async def _build_llm_request(
    self,
    conversation: Conversation,
    tool_schemas: List[ToolSchema],
    user: User,
    system_prompt: Optional[str] = None,
) -> LlmRequest:
    """Assemble an ``LlmRequest`` from the current conversation state.

    Pipeline:
      1. Pass the conversation's messages through each conversation filter,
         in registration order.
      2. Convert the surviving messages to ``LlmMessage`` objects.
      3. If an LLM context enhancer is configured, let it rewrite the
         message list via ``enhance_user_messages``.

    Each filter/enhancer call is wrapped in an observability span plus a
    duration metric when an observability provider is configured.

    Args:
        conversation: Conversation whose messages form the prompt.
        tool_schemas: Tool schemas to expose; an empty list sends ``tools=None``.
        user: User on whose behalf the request is made.
        system_prompt: Optional system prompt attached to the request.

    Returns:
        The request, using temperature, max_tokens and the streaming flag
        from ``self.config``.
    """
    history = conversation.messages

    # Step 1: conversation filters (named to avoid shadowing the builtin).
    for message_filter in self.conversation_filters:
        filter_name = message_filter.__class__.__name__
        span = None
        if self.observability_provider:
            span = await self.observability_provider.create_span(
                "agent.conversation.filter",
                attributes={
                    "filter": filter_name,
                    "message_count_before": len(history),
                },
            )

        history = await message_filter.filter_messages(history)

        if self.observability_provider and span:
            span.set_attribute("message_count_after", len(history))
            await self.observability_provider.end_span(span)
            if span.duration_ms():
                await self.observability_provider.record_metric(
                    "agent.filter.duration",
                    span.duration_ms() or 0,
                    "ms",
                    tags={"filter": filter_name},
                )

    # Step 2: convert domain messages to the LLM wire format.
    llm_messages = [
        LlmMessage(
            role=entry.role,
            content=entry.content,
            tool_calls=entry.tool_calls,
            tool_call_id=entry.tool_call_id,
        )
        for entry in history
    ]

    # Step 3: optional context enhancement of the message list.
    enhancer = self.llm_context_enhancer
    if enhancer:
        enhancer_name = enhancer.__class__.__name__
        span = None
        if self.observability_provider:
            span = await self.observability_provider.create_span(
                "agent.llm_context.enhance_user_messages",
                attributes={
                    "enhancer": enhancer_name,
                    "message_count": len(llm_messages),
                },
            )

        llm_messages = await enhancer.enhance_user_messages(llm_messages, user)

        if self.observability_provider and span:
            span.set_attribute("message_count_after", len(llm_messages))
            await self.observability_provider.end_span(span)
            if span.duration_ms():
                await self.observability_provider.record_metric(
                    "agent.llm_context.enhance_user_messages.duration",
                    span.duration_ms() or 0,
                    "ms",
                    tags={"enhancer": enhancer_name},
                )

    return LlmRequest(
        messages=llm_messages,
        tools=tool_schemas if tool_schemas else None,
        user=user,
        temperature=self.config.temperature,
        max_tokens=self.config.max_tokens,
        stream=self.config.stream_responses,
        system_prompt=system_prompt,
    )
async def _send_llm_request(self, request: LlmRequest) -> LlmResponse:
    """Send a non-streaming LLM request through middlewares, with observability.

    Runs every ``before_llm_request`` middleware in order, sends the request
    via ``self.llm_service.send_request``, then runs every
    ``after_llm_response`` middleware in order.

    Args:
        request: The request to send. Middlewares may replace it.

    Returns:
        The LLM response after all ``after_llm_response`` middlewares.

    Raises:
        Any exception from a middleware or from the LLM service propagates
        unchanged. The ``llm.request`` span is still ended, and it is tagged
        with an ``error`` attribute, if the LLM call itself fails.
    """
    # Apply before_llm_request middlewares with observability
    for middleware in self.llm_middlewares:
        mw_span = None
        if self.observability_provider:
            mw_span = await self.observability_provider.create_span(
                "agent.middleware.before_llm",
                attributes={"middleware": middleware.__class__.__name__},
            )
        request = await middleware.before_llm_request(request)
        if self.observability_provider and mw_span:
            await self.observability_provider.end_span(mw_span)
            if mw_span.duration_ms():
                await self.observability_provider.record_metric(
                    "agent.middleware.duration",
                    mw_span.duration_ms() or 0,
                    "ms",
                    tags={
                        "middleware": middleware.__class__.__name__,
                        "phase": "before_llm",
                    },
                )

    # Create observability span for LLM call
    llm_span = None
    if self.observability_provider:
        llm_span = await self.observability_provider.create_span(
            "llm.request",
            attributes={
                "model": getattr(self.llm_service, "model", "unknown"),
                "stream": request.stream,
            },
        )

    try:
        response = await self.llm_service.send_request(request)
    except Exception as e:
        # Mark the span as failed (same attribute name the agent uses for
        # workflow-handler failures) before it is closed below.
        if llm_span:
            llm_span.set_attribute("error", str(e))
        raise
    finally:
        # Always close the span; previously a failing LLM call leaked it.
        if self.observability_provider and llm_span:
            await self.observability_provider.end_span(llm_span)
            duration_ms = llm_span.duration_ms()
            if duration_ms:
                await self.observability_provider.record_metric(
                    "llm.request.duration", duration_ms, "ms"
                )

    # Apply after_llm_response middlewares with observability
    for middleware in self.llm_middlewares:
        mw_span = None
        if self.observability_provider:
            mw_span = await self.observability_provider.create_span(
                "agent.middleware.after_llm",
                attributes={"middleware": middleware.__class__.__name__},
            )
        response = await middleware.after_llm_response(request, response)
        if self.observability_provider and mw_span:
            await self.observability_provider.end_span(mw_span)
            if mw_span.duration_ms():
                await self.observability_provider.record_metric(
                    "agent.middleware.duration",
                    mw_span.duration_ms() or 0,
                    "ms",
                    tags={
                        "middleware": middleware.__class__.__name__,
                        "phase": "after_llm",
                    },
                )

    return response
async def _handle_streaming_response(self, request: LlmRequest) -> LlmResponse:
    """Consume a streaming LLM response and collapse it into one ``LlmResponse``.

    Runs every ``before_llm_request`` middleware in order, then iterates
    ``self.llm_service.stream_request``. It concatenates chunk text and
    extends the list of tool calls from each chunk. Finally it runs every
    ``after_llm_response`` middleware on the assembled response.

    Args:
        request: The request to stream. Middlewares may replace it.

    Returns:
        An ``LlmResponse`` whose ``content`` / ``tool_calls`` are ``None`` when
        nothing of that kind was streamed, after all
        ``after_llm_response`` middlewares.

    Raises:
        Any exception from a middleware or from the stream propagates
        unchanged. The ``llm.stream`` span is still ended, with the partial
        content length and tool-call count plus an ``error`` attribute, if
        the stream fails mid-way.
    """
    # Apply before_llm_request middlewares with observability
    for middleware in self.llm_middlewares:
        mw_span = None
        if self.observability_provider:
            mw_span = await self.observability_provider.create_span(
                "agent.middleware.before_llm",
                attributes={
                    "middleware": middleware.__class__.__name__,
                    "stream": True,
                },
            )
        request = await middleware.before_llm_request(request)
        if self.observability_provider and mw_span:
            await self.observability_provider.end_span(mw_span)
            if mw_span.duration_ms():
                await self.observability_provider.record_metric(
                    "agent.middleware.duration",
                    mw_span.duration_ms() or 0,
                    "ms",
                    tags={
                        "middleware": middleware.__class__.__name__,
                        "phase": "before_llm",
                        "stream": "true",
                    },
                )

    accumulated_content = ""
    accumulated_tool_calls = []

    # Create span for streaming
    stream_span = None
    if self.observability_provider:
        stream_span = await self.observability_provider.create_span(
            "llm.stream",
            attributes={"model": getattr(self.llm_service, "model", "unknown")},
        )

    try:
        async for chunk in self.llm_service.stream_request(request):
            if chunk.content:
                accumulated_content += chunk.content
                # Could yield intermediate TextChunk here
            if chunk.tool_calls:
                accumulated_tool_calls.extend(chunk.tool_calls)
    except Exception as e:
        # Mark the span as failed before it is closed below.
        if stream_span:
            stream_span.set_attribute("error", str(e))
        raise
    finally:
        # Always close the span; previously a failing stream leaked it. On
        # failure the counts reflect what was received before the error.
        if self.observability_provider and stream_span:
            stream_span.set_attribute("content_length", len(accumulated_content))
            stream_span.set_attribute("tool_call_count", len(accumulated_tool_calls))
            await self.observability_provider.end_span(stream_span)
            duration_ms = stream_span.duration_ms()
            if duration_ms:
                await self.observability_provider.record_metric(
                    "llm.stream.duration", duration_ms, "ms"
                )

    response = LlmResponse(
        content=accumulated_content if accumulated_content else None,
        tool_calls=accumulated_tool_calls if accumulated_tool_calls else None,
    )

    # Apply after_llm_response middlewares with observability
    for middleware in self.llm_middlewares:
        mw_span = None
        if self.observability_provider:
            mw_span = await self.observability_provider.create_span(
                "agent.middleware.after_llm",
                attributes={
                    "middleware": middleware.__class__.__name__,
                    "stream": True,
                },
            )
        response = await middleware.after_llm_response(request, response)
        if self.observability_provider and mw_span:
            await self.observability_provider.end_span(mw_span)
            if mw_span.duration_ms():
                await self.observability_provider.record_metric(
                    "agent.middleware.duration",
                    mw_span.duration_ms() or 0,
                    "ms",
                    tags={
                        "middleware": middleware.__class__.__name__,
                        "phase": "after_llm",
                        "stream": "true",
                    },
                )

    return response
================================================
FILE: src/vanna/core/agent/config.py
================================================
"""
Agent configuration.
This module contains configuration models that control agent behavior.
"""
from typing import TYPE_CHECKING, Dict, List, Optional
from pydantic import BaseModel, Field
from .._compat import StrEnum
if TYPE_CHECKING:
from ..user import User
class UiFeature(StrEnum):
    """Built-in UI features whose visibility can be restricted by user group.

    Members are used as keys of ``UiFeatures.feature_group_access``; the
    default groups for each member are listed in DEFAULT_UI_FEATURES.
    """

    # Gate for showing tool names (e.g. per-tool entries in the task tracker).
    UI_FEATURE_SHOW_TOOL_NAMES = "tool_names"
    # Gate for showing tool-call arguments / tool execution status cards.
    UI_FEATURE_SHOW_TOOL_ARGUMENTS = "tool_arguments"
    # Gate for showing the UI component returned by a failed tool call.
    UI_FEATURE_SHOW_TOOL_ERROR = "tool_error"
    # Gate for showing the assistant's text that accompanies a tool call in chat.
    UI_FEATURE_SHOW_TOOL_INVOCATION_MESSAGE_IN_CHAT = "tool_invocation_message_in_chat"
    # Gate for detailed memory results. NOTE(review): consumer not visible
    # here; confirm exact effect at its call site.
    UI_FEATURE_SHOW_MEMORY_DETAILED_RESULTS = "memory_detailed_results"
# Default group access for each built-in UI feature. UiFeatures starts from
# this mapping when no explicit ``feature_group_access`` is given. Here tool
# names are shown to "admin" and "user" groups; everything else is admin-only.
# An empty list would make a feature visible to all users.
DEFAULT_UI_FEATURES: Dict[str, List[str]] = {
    UiFeature.UI_FEATURE_SHOW_TOOL_NAMES: ["admin", "user"],
    UiFeature.UI_FEATURE_SHOW_TOOL_ARGUMENTS: ["admin"],
    UiFeature.UI_FEATURE_SHOW_TOOL_ERROR: ["admin"],
    UiFeature.UI_FEATURE_SHOW_TOOL_INVOCATION_MESSAGE_IN_CHAT: ["admin"],
    UiFeature.UI_FEATURE_SHOW_MEMORY_DETAILED_RESULTS: ["admin"],
}
class UiFeatures(BaseModel):
    """UI features with group-based access control, using the same rule as tools.

    ``feature_group_access`` maps a UI feature name to the groups allowed to
    use it:

    * an empty list means the feature is accessible to all users;
    * otherwise a user has access when at least one of their groups is in
      the list (the same intersection logic as tool access control);
    * a feature name that is not in the mapping is denied to everyone.
    """

    # Each instance gets its own copy of the defaults, including fresh group
    # lists. A plain ``dict.copy()`` is shallow: every instance would share the
    # list objects in DEFAULT_UI_FEATURES. Appending a group on one instance
    # would then silently change the module defaults and every other instance.
    feature_group_access: Dict[str, List[str]] = Field(
        default_factory=lambda: {
            feature: list(groups) for feature, groups in DEFAULT_UI_FEATURES.items()
        },
        description="Which groups can access UI features",
    )

    def can_user_access_feature(self, feature_name: str, user: "User") -> bool:
        """Check whether ``user`` may access a UI feature.

        Args:
            feature_name: Name of the UI feature to check.
            user: User whose ``group_memberships`` are compared against the
                feature's allowed groups.

        Returns:
            True if the feature is registered and either open to everyone
            (empty group list) or shares at least one group with the user.
            False otherwise, including for unregistered features.
        """
        if feature_name not in self.feature_group_access:
            # Unknown feature: deny by default.
            return False
        allowed_groups = self.feature_group_access[feature_name]

        # Empty list means all users can access (same as tools).
        if not allowed_groups:
            return True

        # Same intersection logic as tool access control.
        user_groups = set(user.group_memberships)
        feature_groups = set(allowed_groups)
        return bool(user_groups & feature_groups)

    def register_feature(self, name: str, access_groups: List[str]) -> None:
        """Register (or overwrite) a custom UI feature's allowed groups.

        Args:
            name: Name of the custom feature.
            access_groups: Groups that can access this feature. An empty list
                makes it available to all users.
        """
        self.feature_group_access[name] = access_groups
class AuditConfig(BaseModel):
    """Configuration for audit logging.

    ``enabled`` is the overall switch; the ``log_*`` / ``include_*`` /
    ``sanitize_*`` flags select what is recorded and how. For example, the
    agent only logs UI feature checks when both ``enabled`` and
    ``log_ui_feature_checks`` are true. That flag defaults to off because
    the agent checks several UI features on every tool call.
    """

    # Overall switch for audit logging.
    enabled: bool = Field(default=True, description="Enable audit logging")
    log_tool_access_checks: bool = Field(
        default=True, description="Log tool access permission checks"
    )
    log_tool_invocations: bool = Field(
        default=True, description="Log tool invocations with parameters"
    )
    log_tool_results: bool = Field(
        default=True, description="Log tool execution results"
    )
    log_ui_feature_checks: bool = Field(
        default=False, description="Log UI feature access checks (can be noisy)"
    )
    log_ai_responses: bool = Field(
        default=True, description="Log AI-generated responses"
    )
    # Off by default: full response text may contain user/private data.
    include_full_ai_responses: bool = Field(
        default=False,
        description="Include full AI response text in logs (privacy concern)",
    )
    sanitize_tool_parameters: bool = Field(
        default=True, description="Sanitize sensitive parameters (passwords, tokens)"
    )
class AgentConfig(BaseModel):
    """Configuration for agent behavior.

    Every field carries a ``description`` (as AuditConfig's fields do) so the
    generated JSON schema documents each option.
    """

    max_tool_iterations: int = Field(
        default=10,
        gt=0,
        description="Maximum number of tool-call rounds per message before the agent stops",
    )
    stream_responses: bool = Field(
        default=True,
        description="Use the LLM streaming API instead of a single request/response call",
    )
    auto_save_conversations: bool = Field(
        default=True,
        description="Persist the conversation to the store after each message",
    )
    # NOTE(review): the agent currently has only a TODO where this would take effect.
    include_thinking_indicators: bool = Field(
        default=True,
        description="Show a thinking indicator before the first LLM call",
    )
    temperature: float = Field(
        default=0.7,
        ge=0.0,
        le=2.0,
        description="Sampling temperature passed to the LLM",
    )
    max_tokens: Optional[int] = Field(
        default=None,
        gt=0,
        description="Maximum tokens per LLM response passed to the LLM (None leaves it unset)",
    )
    ui_features: UiFeatures = Field(
        default_factory=UiFeatures,
        description="Group-based access control for UI features",
    )
    audit_config: AuditConfig = Field(
        default_factory=AuditConfig,
        description="Audit logging configuration",
    )
================================================
FILE: src/vanna/core/audit/__init__.py
================================================
"""
Audit logging for the Vanna Agents framework.
This module provides interfaces and models for audit logging, enabling
tracking of user actions, tool invocations, and access control decisions.
"""
from .base import AuditLogger
from .models import (
AiResponseEvent,
AuditEvent,
AuditEventType,
ToolAccessCheckEvent,
ToolInvocationEvent,
ToolResultEvent,
UiFeatureAccessCheckEvent,
)
# Public names of the audit package: the logger base class plus all event models.
__all__ = [
    "AuditLogger",
    "AuditEvent",
    "AuditEventType",
    "ToolAccessCheckEvent",
    "ToolInvocationEvent",
    "ToolResultEvent",
    "UiFeatureAccessCheckEvent",
    "AiResponseEvent",
]
================================================
FILE: src/vanna/core/audit/base.py
================================================
"""
Base audit logger interface.
Audit loggers enable tracking user actions, tool invocations, and access control
decisions for security, compliance, and debugging.
"""
import hashlib
from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from .models import (
AiResponseEvent,
AuditEvent,
ToolAccessCheckEvent,
ToolInvocationEvent,
ToolResultEvent,
UiFeatureAccessCheckEvent,
)
if TYPE_CHECKING:
from ..tool.models import ToolCall, ToolContext, ToolResult
from ..user.models import User
class AuditLogger(ABC):
"""Abstract base class for audit logging implementations.
Implementations can:
- Write to files (JSON, CSV, etc.)
- Send to databases (Postgres, MongoDB, etc.)
- Stream to cloud services (CloudWatch, Datadog, etc.)
- Send to SIEM systems (Splunk, Elastic, etc.)
Example:
class PostgresAuditLogger(AuditLogger):
async def log_event(self, event: AuditEvent) -> None:
await self.db.execute(
"INSERT INTO audit_log (...) VALUES (...)",
event.model_dump()
)
agent = Agent(
llm_service=...,
audit_logger=PostgresAuditLogger(db_pool)
)
"""
@abstractmethod
async def log_event(self, event: AuditEvent) -> None:
    """Record a single audit event; concrete loggers must implement this.

    The ``log_*`` convenience methods on this class build an event object
    and delegate to this method.

    Args:
        event: The audit event to record.

    Raises:
        Exception: Implementations may raise if logging fails critically.
    """
    ...
async def log_tool_access_check(
self,
user: "User",
tool_name: str,
access_granted: bool,
required_groups: List[str],
context: "ToolContext",
reason: Optional[str] = None,
) -> None:
"""Convenience method for logging tool access checks.
Args:
user: User attempting to access the tool
tool_name: Name of the tool being accessed
access_granted: Whether access was granted
required_groups: Groups required to access the tool
context: Tool execution context
reason: Optional reason for denial
"""
event = ToolAccessCheckEvent(
user_id=user.id,
username=user.username,
user_email=user.email,
user_groups=user.group_memberships,
conversation_id=context.conversation_id,
request_id=context.request_id,
tool_name=tool_name,
access_granted=access_granted,
required_groups=required_groups,
reason=reason,
)
await self.log_event(event)
async def log_tool_invocation(
self,
user: "User",
tool_call: "ToolCall",
ui_features: List[str],
context: "ToolContext",
sanitize_parameters: bool = True,
) -> None:
"""Convenience method for logging tool invocations.
Args:
user: User invoking the tool
tool_call: Tool call information
ui_features: List of UI features available to the user
context: Tool execution context
sanitize_parameters: Whether to sanitize sensitive parameters
"""
parameters = tool_call.arguments.copy()
sanitized = False
if sanitize_parameters:
parameters, sanitized = self._sanitize_parameters(parameters)
event = ToolInvocationEvent(
user_id=user.id,
username=user.username,
user_email=user.email,
user_groups=user.group_memberships,
conversation_id=context.conversation_id,
request_id=context.request_id,
tool_call_id=tool_call.id,
tool_name=tool_call.name,
parameters=parameters,
parameters_sanitized=sanitized,
ui_features_available=ui_features,
)
await self.log_event(event)
async def log_tool_result(
self,
user: "User",
tool_call: "ToolCall",
result: "ToolResult",
context: "ToolContext",
) -> None:
"""Convenience method for logging tool results.
Args:
user: User who invoked the tool
tool_call: Tool call information
result: Tool execution result
context: Tool execution context
"""
event = ToolResultEvent(
user_id=user.id,
username=user.username,
user_email=user.email,
user_groups=user.group_memberships,
conversation_id=context.conversation_id,
request_id=context.request_id,
tool_call_id=tool_call.id,
tool_name=tool_call.name,
success=result.success,
error=result.error,
execution_time_ms=result.metadata.get("execution_time_ms", 0.0),
result_size_bytes=(
len(result.result_for_llm.encode("utf-8"))
if result.result_for_llm
else 0
),
ui_component_type=(
result.ui_component.__class__.__name__ if result.ui_component else None
),
)
await self.log_event(event)
async def log_ui_feature_access(
self,
user: "User",
feature_name: str,
access_granted: bool,
required_groups: List[str],
conversation_id: str,
request_id: str,
) -> None:
"""Convenience method for logging UI feature access checks.
Args:
user: User attempting to access the feature
feature_name: Name of the UI feature
access_granted: Whether access was granted
required_groups: Groups required to access the feature
conversation_id: Conversation identifier
request_id: Request identifier
"""
event = UiFeatureAccessCheckEvent(
user_id=user.id,
username=user.username,
user_email=user.email,
user_groups=user.group_memberships,
conversation_id=conversation_id,
request_id=request_id,
feature_name=feature_name,
access_granted=access_granted,
required_groups=required_groups,
)
await self.log_event(event)
async def log_ai_response(
self,
user: "User",
conversation_id: str,
request_id: str,
response_text: str,
tool_calls: List["ToolCall"],
model_info: Optional[Dict[str, Any]] = None,
include_full_text: bool = False,
) -> None:
"""Convenience method for logging AI responses.
Args:
user: User receiving the response
conversation_id: Conversation identifier
request_id: Request identifier
response_text: The AI-generated response text
tool_calls: List of tool calls in the response
model_info: Optional model configuration info
include_full_text: Whether to include full response text
"""
response_hash = hashlib.sha256(response_text.encode("utf-8")).hexdigest()
event = AiResponseEvent(
user_id=user.id,
username=user.username,
user_email=user.email,
user_groups=user.group_memberships,
conversation_id=conversation_id,
request_id=request_id,
response_length_chars=len(response_text),
response_text=response_text if include_full_text else None,
response_hash=response_hash,
model_name=model_info.get("model") if model_info else None,
temperature=model_info.get("temperature") if model_info else None,
tool_calls_count=len(tool_calls),
tool_names=[tc.name for tc in tool_calls],
)
await self.log_event(event)
async def query_events(
self,
filters: Optional[Dict[str, Any]] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 100,
) -> List[AuditEvent]:
"""Query audit events (optional, for implementations that support it).
Args:
filters: Filter criteria (user_id, event_type, etc.)
start_time: Filter events after this time
end_time: Filter events before this time
limit: Maximum number of events to return
Returns:
List of matching audit events
Raises:
NotImplementedError: If query not supported by implementation
"""
raise NotImplementedError("Query not supported by this implementation")
def _sanitize_parameters(
self, parameters: Dict[str, Any]
) -> tuple[Dict[str, Any], bool]:
"""Sanitize sensitive data from parameters.
Args:
parameters: Raw parameters dict
Returns:
Tuple of (sanitized_parameters, was_sanitized)
"""
sanitized = parameters.copy()
was_sanitized = False
# Common sensitive field patterns
sensitive_patterns = [
"password",
"secret",
"token",
"api_key",
"apikey",
"credential",
"auth",
"private_key",
"access_key",
]
for key in list(sanitized.keys()):
key_lower = key.lower()
if any(pattern in key_lower for pattern in sensitive_patterns):
sanitized[key] = "[REDACTED]"
was_sanitized = True
return sanitized, was_sanitized
================================================
FILE: src/vanna/core/audit/models.py
================================================
"""
Audit event models.
This module contains data models for audit logging events.
"""
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from .._compat import StrEnum
class AuditEventType(StrEnum):
    """Types of audit events.

    Member values are the string forms used for ``AuditEvent.event_type``.
    Only TOOL_ACCESS_CHECK, UI_FEATURE_ACCESS_CHECK, TOOL_INVOCATION,
    TOOL_RESULT and AI_RESPONSE_GENERATED have a dedicated event model in this
    module; the remaining types have no model subclass here.
    """
    # Access control events
    TOOL_ACCESS_CHECK = "tool_access_check"
    UI_FEATURE_ACCESS_CHECK = "ui_feature_access_check"
    # Tool execution events
    TOOL_INVOCATION = "tool_invocation"
    TOOL_RESULT = "tool_result"
    # Conversation events
    MESSAGE_RECEIVED = "message_received"
    AI_RESPONSE_GENERATED = "ai_response_generated"
    CONVERSATION_CREATED = "conversation_created"
    # Security events
    ACCESS_DENIED = "access_denied"
    AUTHENTICATION_ATTEMPT = "authentication_attempt"
class AuditEvent(BaseModel):
    """Base audit event with common fields.

    Concrete event models subclass this and give ``event_type`` a fixed
    default. Fields are grouped into user context, request context, free-form
    event details, and privacy/redaction markers.
    """
    # Fresh random UUID4 string for each event instance.
    event_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    event_type: AuditEventType
    # NOTE(review): datetime.utcnow() yields a naive datetime (no tzinfo) and is
    # deprecated since Python 3.12; consumers must treat the value as UTC.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    # User context
    user_id: str
    username: Optional[str] = None
    user_email: Optional[str] = None
    user_groups: List[str] = Field(default_factory=list)
    # Request context
    conversation_id: str
    request_id: str
    # Not filled in by AuditLogger's log_* convenience helpers.
    remote_addr: Optional[str] = None
    # Event-specific data
    details: Dict[str, Any] = Field(default_factory=dict)
    # Privacy/redaction markers
    contains_pii: bool = False
    redacted_fields: List[str] = Field(default_factory=list)
class ToolAccessCheckEvent(AuditEvent):
    """Audit event for tool access permission checks.

    Records whether a user was allowed to use ``tool_name``, the groups the
    tool requires, and an optional reason (typically for denials).
    """
    event_type: AuditEventType = AuditEventType.TOOL_ACCESS_CHECK
    tool_name: str
    access_granted: bool
    required_groups: List[str] = Field(default_factory=list)
    # Optional human-readable explanation, e.g. why access was denied.
    reason: Optional[str] = None
class ToolInvocationEvent(AuditEvent):
    """Audit event for actual tool executions.

    Captures which tool was called, with which (possibly redacted) arguments,
    and which UI features the user had available at the time.
    """
    event_type: AuditEventType = AuditEventType.TOOL_INVOCATION
    tool_call_id: str
    tool_name: str
    # Parameters with sanitization support
    parameters: Dict[str, Any] = Field(default_factory=dict)
    # True when at least one parameter value was replaced with "[REDACTED]".
    parameters_sanitized: bool = False
    # UI context at invocation time
    ui_features_available: List[str] = Field(default_factory=list)
class ToolResultEvent(AuditEvent):
    """Audit event for tool execution results.

    Stores outcome metadata only; the result content itself is not logged.
    """
    event_type: AuditEventType = AuditEventType.TOOL_RESULT
    tool_call_id: str
    tool_name: str
    success: bool
    error: Optional[str] = None
    # Wall-clock execution time in milliseconds, as reported by the tool result.
    execution_time_ms: float = 0.0
    # Result metadata (without full content for size)
    # AuditLogger.log_tool_result sets this to the UTF-8 byte length of result_for_llm.
    result_size_bytes: Optional[int] = None
    # Class name of the returned UI component, if any.
    ui_component_type: Optional[str] = None
class UiFeatureAccessCheckEvent(AuditEvent):
    """Audit event for UI feature access checks.

    Records whether a user may use ``feature_name`` and which groups it requires.
    """
    event_type: AuditEventType = AuditEventType.UI_FEATURE_ACCESS_CHECK
    feature_name: str
    access_granted: bool
    required_groups: List[str] = Field(default_factory=list)
class AiResponseEvent(AuditEvent):
    """Audit event for AI-generated responses.

    Always records size and a content hash; the full text is optional so that
    deployments can avoid persisting model output.
    """
    event_type: AuditEventType = AuditEventType.AI_RESPONSE_GENERATED
    # Response metadata
    response_length_chars: int
    # Not set by AuditLogger.log_ai_response.
    response_length_tokens: Optional[int] = None
    # Full text (optional, configurable)
    response_text: Optional[str] = None
    response_hash: str # SHA256 for integrity verification
    # Model info
    model_name: Optional[str] = None
    temperature: Optional[float] = None
    # Tool calls in response
    tool_calls_count: int = 0
    tool_names: List[str] = Field(default_factory=list)
================================================
FILE: src/vanna/core/component_manager.py
================================================
"""
Component state management and update protocol for rich components.
"""
import uuid
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union
from pydantic import BaseModel, Field
from ..components.rich import ComponentLifecycle, RichComponent
class UpdateOperation(str, Enum):
    """Types of component update operations.

    The string values are what ``ComponentUpdate.serialize_for_frontend``
    sends as ``operation``. ComponentTree/ComponentManager in this module only
    produce CREATE, UPDATE, REPLACE and REMOVE; REORDER and BULK_UPDATE are
    not emitted by them.
    """
    CREATE = "create"
    UPDATE = "update"
    REPLACE = "replace"
    REMOVE = "remove"
    REORDER = "reorder"
    BULK_UPDATE = "bulk_update"
class Position(BaseModel):
    """Position specification for component placement.

    NOTE(review): ComponentTree._find_parent only reads ``anchor_id`` and
    ``relation`` to pick a parent node; ``index`` is not consulted there, and
    "before", "after" and "replace" all resolve to the anchor's parent.
    """
    index: Optional[int] = None
    # Id of the component the relation is measured against.
    anchor_id: Optional[str] = None
    relation: str = "after" # "before", "after", "inside", "replace"
class ComponentUpdate(BaseModel):
    """One change to the component tree, in the shape streamed to the frontend.

    Fields:
        operation: Kind of change (see ``UpdateOperation``).
        target_id: Id of the component the change applies to.
        component: Full component payload (set for create/update/replace).
        updates: Only the changed fields, used for UPDATE operations.
        position: Placement hint supplied when the component was added.
        timestamp: Naive ISO-8601 string taken from ``datetime.utcnow()``.
        batch_id: Shared id for updates emitted within one batch.
    """

    operation: UpdateOperation
    target_id: str
    component: Optional[RichComponent] = None
    updates: Optional[Dict[str, Any]] = None
    position: Optional[Position] = None
    timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
    batch_id: Optional[str] = None

    def serialize_for_frontend(self) -> Dict[str, Any]:
        """Dump this update to a plain dict suitable for the frontend.

        The enum is flattened to its string value and, when present, the
        nested component is replaced by its own frontend serialization.
        """
        data = self.model_dump()
        data.update(operation=self.operation.value)
        nested = self.component
        if nested:
            data["component"] = nested.serialize_for_frontend()
        return data
class ComponentNode(BaseModel):
    """A component plus its ordered children inside a ``ComponentTree``.

    ``parent_id`` is the id of the enclosing node's component, or ``None``
    for the root.
    """

    component: RichComponent
    children: List["ComponentNode"] = Field(default_factory=list)
    parent_id: Optional[str] = None

    def find_child(self, component_id: str) -> Optional["ComponentNode"]:
        """Depth-first search of the descendants for ``component_id``.

        Each child is checked before its own subtree, and that subtree is
        searched before the next sibling. The node itself is not matched.
        """
        for node in self.children:
            hit = node if node.component.id == component_id else node.find_child(component_id)
            if hit is not None:
                return hit
        return None

    def remove_child(self, component_id: str) -> bool:
        """Detach the first descendant whose component id matches.

        Returns True when a node was removed, False when none matched.
        """
        for slot, node in enumerate(self.children):
            if node.component.id == component_id:
                del self.children[slot]
                return True
            if node.remove_child(component_id):
                return True
        return False

    def get_all_ids(self) -> Set[str]:
        """Return the ids of this node and of every descendant."""
        return {self.component.id}.union(
            *(node.get_all_ids() for node in self.children)
        )
class ComponentTree(BaseModel):
    """Hierarchical structure for managing component layout.

    Components are held twice: as ``ComponentNode`` objects under ``root``
    and in ``flat_index`` (component id -> node) for direct lookup. Mutating
    methods keep both views in sync and return the ``ComponentUpdate`` that
    describes the change, or ``None`` when the target id is unknown.
    """

    root: Optional[ComponentNode] = None
    flat_index: Dict[str, ComponentNode] = Field(default_factory=dict)

    def add_component(
        self, component: RichComponent, position: Optional[Position] = None
    ) -> ComponentUpdate:
        """Add a component to the tree.

        The first component added becomes the root. Later components are
        appended to the end of the children of the parent selected by
        ``_find_parent``.

        Args:
            component: Component to insert, indexed under ``component.id``.
            position: Optional placement hint; only ``anchor_id`` and
                ``relation`` affect which parent is chosen.

        Returns:
            A CREATE update carrying the component and the position hint.
        """
        node = ComponentNode(component=component)
        self.flat_index[component.id] = node
        if self.root is None:
            self.root = node
        else:
            parent_node = self._find_parent(position)
            if parent_node is not None:
                node.parent_id = parent_node.component.id
                parent_node.children.append(node)
        return ComponentUpdate(
            operation=UpdateOperation.CREATE,
            target_id=component.id,
            component=component,
            position=position,
        )

    def update_component(
        self, component_id: str, updates: Dict[str, Any]
    ) -> Optional[ComponentUpdate]:
        """Merge ``updates`` into an existing component.

        The merged data is re-validated through the component's own class;
        ``lifecycle`` is forced to UPDATE and ``timestamp`` is refreshed.

        Returns:
            An UPDATE carrying the rebuilt component and the raw ``updates``,
            or ``None`` if ``component_id`` is not in the tree.
        """
        node = self.flat_index.get(component_id)
        if not node:
            return None
        # Create updated component
        component_data = node.component.model_dump()
        component_data.update(updates)
        component_data["lifecycle"] = ComponentLifecycle.UPDATE
        component_data["timestamp"] = datetime.utcnow().isoformat()
        updated_component = node.component.__class__(**component_data)
        node.component = updated_component
        return ComponentUpdate(
            operation=UpdateOperation.UPDATE,
            target_id=component_id,
            component=updated_component,
            updates=updates,
        )

    def replace_component(
        self, old_id: str, new_component: RichComponent
    ) -> Optional[ComponentUpdate]:
        """Swap the component held by ``old_id``'s node for ``new_component``.

        The node keeps its place and its children. If the id changes, the
        index is re-keyed and every child's ``parent_id`` is pointed at the
        new id, so those children still resolve their parent as anchors.

        Returns:
            A REPLACE update targeting ``old_id``, or ``None`` if unknown.
        """
        old_node = self.flat_index.get(old_id)
        if not old_node:
            return None
        # Update the component in place
        old_node.component = new_component
        # Update index
        del self.flat_index[old_id]
        self.flat_index[new_component.id] = old_node
        # Children referenced the old id; without this, _find_parent would miss
        # the lookup and silently fall back to the root.
        for child in old_node.children:
            child.parent_id = new_component.id
        return ComponentUpdate(
            operation=UpdateOperation.REPLACE, target_id=old_id, component=new_component
        )

    def remove_component(self, component_id: str) -> Optional[ComponentUpdate]:
        """Remove a component together with its whole subtree.

        Returns:
            A REMOVE update, or ``None`` if ``component_id`` is unknown.
        """
        node = self.flat_index.get(component_id)
        if not node:
            return None
        # Remove from parent
        if self.root and self.root.component.id == component_id:
            self.root = None
        else:
            if self.root:
                self.root.remove_child(component_id)
        # Remove from flat index (including all children)
        removed_ids = node.get_all_ids()
        for removed_id in removed_ids:
            self.flat_index.pop(removed_id, None)
        return ComponentUpdate(operation=UpdateOperation.REMOVE, target_id=component_id)

    def get_component(self, component_id: str) -> Optional[RichComponent]:
        """Get a component by ID, or ``None`` if it is not in the tree."""
        node = self.flat_index.get(component_id)
        return node.component if node else None

    def _find_parent(self, position: Optional[Position]) -> Optional[ComponentNode]:
        """Choose the parent node for a new component.

        "inside" nests under the anchor; "before"/"after"/"replace" use the
        anchor's parent. A missing position/anchor, an unknown anchor, or an
        unknown relation all fall back to the root.
        """
        if not position or not position.anchor_id:
            return self.root
        anchor_node = self.flat_index.get(position.anchor_id)
        if not anchor_node:
            return self.root
        if position.relation == "inside":
            return anchor_node
        elif position.relation in ["before", "after", "replace"]:
            # Find the parent of the anchor
            if anchor_node.parent_id:
                parent_node = self.flat_index.get(anchor_node.parent_id)
                return parent_node if parent_node else self.root
            else:
                return self.root
        else:
            return self.root
class ComponentManager:
    """Manages component lifecycle and state updates.

    Keeps three pieces of state in step:
    - ``components``: id -> current component, for every component in the tree;
    - ``component_tree``: the layout hierarchy;
    - ``update_history``: every ``ComponentUpdate`` produced, in order.
    Updates produced while a batch is open are tagged with its id.
    """

    def __init__(self) -> None:
        self.components: Dict[str, RichComponent] = {}
        self.component_tree = ComponentTree()
        self.update_history: List[ComponentUpdate] = []
        self.active_batch: Optional[str] = None

    def _record(self, update: ComponentUpdate) -> None:
        """Append ``update`` to the history and tag it with the open batch."""
        self.update_history.append(update)
        if self.active_batch:
            update.batch_id = self.active_batch

    def emit(self, component: RichComponent) -> Optional[ComponentUpdate]:
        """Emit a component, choosing create, update or replace automatically.

        - Unknown id: the component is appended to the tree (CREATE).
        - Known id with lifecycle UPDATE: only fields that differ from the
          stored component are sent as updates (UPDATE).
        - Known id with any other lifecycle: the component replaces the
          stored one (REPLACE).

        Returns:
            The produced update, or ``None`` if the tree rejected it.
        """
        if component.id in self.components:
            # Existing component - determine if this is an update or replace
            existing = self.components[component.id]
            if component.lifecycle == ComponentLifecycle.UPDATE:
                # Extract changes
                old_data = existing.model_dump()
                new_data = component.model_dump()
                updates = {k: v for k, v in new_data.items() if old_data.get(k) != v}
                update = self.component_tree.update_component(component.id, updates)
            else:
                # Replace
                update = self.component_tree.replace_component(component.id, component)
        else:
            # New component - always append
            update = self.component_tree.add_component(component, None)
        if update:
            # Store the tree's version (for UPDATE it carries the refreshed
            # lifecycle/timestamp) so the map matches what was sent out.
            self.components[component.id] = (
                update.component if update.component is not None else component
            )
            self._record(update)
        return update

    def update_component(
        self, component_id: str, **updates: Any
    ) -> Optional[ComponentUpdate]:
        """Update specific fields of an existing component.

        Returns:
            The UPDATE produced, or ``None`` if ``component_id`` is unknown.
        """
        update = self.component_tree.update_component(component_id, updates)
        if update and update.component:
            self.components[component_id] = update.component
            self._record(update)
        return update

    def replace_component(
        self, old_id: str, new_component: RichComponent
    ) -> Optional[ComponentUpdate]:
        """Replace one component with another (the id may change).

        Returns:
            The REPLACE produced, or ``None`` if ``old_id`` is unknown.
        """
        update = self.component_tree.replace_component(old_id, new_component)
        if update:
            self.components.pop(old_id, None)
            self.components[new_component.id] = new_component
            self._record(update)
        return update

    def remove_component(self, component_id: str) -> Optional[ComponentUpdate]:
        """Remove a component and all of its descendants.

        The tree drops the whole subtree, so every descendant id is also
        dropped from ``components``; otherwise removed children would still be
        returned by ``get_component``/``get_all_components``.

        Returns:
            The REMOVE produced, or ``None`` if ``component_id`` is unknown.
        """
        node = self.component_tree.flat_index.get(component_id)
        removed_ids = node.get_all_ids() if node is not None else {component_id}
        update = self.component_tree.remove_component(component_id)
        if update:
            for removed_id in removed_ids:
                self.components.pop(removed_id, None)
            self._record(update)
        return update

    def get_component(self, component_id: str) -> Optional[RichComponent]:
        """Get a component by ID."""
        return self.components.get(component_id)

    def get_all_components(self) -> List[RichComponent]:
        """Get all components in the manager."""
        return list(self.components.values())

    def start_batch(self) -> str:
        """Open a batch; subsequent updates are tagged with the returned id."""
        self.active_batch = str(uuid.uuid4())
        return self.active_batch

    def end_batch(self) -> Optional[str]:
        """Close the current batch and return its id (``None`` if none was open)."""
        batch_id = self.active_batch
        self.active_batch = None
        return batch_id

    def get_updates_since(
        self, timestamp: Optional[str] = None
    ) -> List[ComponentUpdate]:
        """Get updates strictly newer than an ISO-8601 ``timestamp``.

        Naive timestamps (including the ``utcnow()``-based ones stored on each
        update) are assumed to be UTC, so they can be compared against aware
        inputs such as ``"...Z"`` without a naive/aware ``TypeError``.

        Returns:
            The matching updates; the full history when ``timestamp`` is empty
            or any timestamp fails to parse.
        """
        if not timestamp:
            return self.update_history.copy()

        from datetime import timezone

        def to_utc(value: str) -> datetime:
            parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
            if parsed.tzinfo is None:
                parsed = parsed.replace(tzinfo=timezone.utc)
            return parsed

        try:
            cutoff = to_utc(timestamp)
            return [
                update
                for update in self.update_history
                if to_utc(update.timestamp) > cutoff
            ]
        except ValueError:
            return self.update_history.copy()

    def clear_history(self) -> None:
        """Clear the update history (component state is kept)."""
        self.update_history.clear()
================================================
FILE: src/vanna/core/components.py
================================================
"""
UI component base class.
This module defines the UiComponent class which is the return type for tool executions.
It's placed in core/ because it's a fundamental type that tools return, not just a UI concern.
"""
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel, Field, model_validator
class UiComponent(BaseModel):
    """Envelope a tool returns to stream UI to the client.

    It always carries a rich component and may also carry a simple fallback.
    Both fields are typed ``Any`` to avoid import cycles; their real types
    are enforced by ``validate_components`` after model construction.
    """

    model_config = {"arbitrary_types_allowed": True}

    timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
    rich_component: Any = Field(
        ..., description="Rich component for advanced rendering"
    )
    simple_component: Optional[Any] = Field(
        None, description="Simple component for basic rendering"
    )

    @model_validator(mode="after")
    def validate_components(self) -> "UiComponent":
        """Reject payloads whose components have the wrong runtime type.

        Raises:
            ValueError: if ``rich_component`` is not a RichComponent, or if
                ``simple_component`` is neither None nor a SimpleComponent.
        """
        # Deferred imports keep this core module free of import cycles.
        from .rich_component import RichComponent
        from .simple_component import SimpleComponent

        rich = self.rich_component
        if not isinstance(rich, RichComponent):
            raise ValueError(
                f"rich_component must be a RichComponent, got {type(rich)}"
            )

        simple = self.simple_component
        if simple is None or isinstance(simple, SimpleComponent):
            return self
        raise ValueError(
            f"simple_component must be a SimpleComponent or None, got {type(simple)}"
        )
================================================
FILE: src/vanna/core/enhancer/__init__.py
================================================
"""
LLM context enhancement system for adding context to prompts and messages.
This module provides interfaces for enriching LLM system prompts and messages
with additional context before LLM calls (e.g., from memory, RAG, documentation).
"""
from .base import LlmContextEnhancer
from .default import DefaultLlmContextEnhancer
# Names re-exported by vanna.core.enhancer: the base enhancer interface and the
# AgentMemory-backed default implementation.
__all__ = ["LlmContextEnhancer", "DefaultLlmContextEnhancer"]
================================================
FILE: src/vanna/core/enhancer/base.py
================================================
"""
LLM context enhancer interface.
LLM context enhancers allow you to add additional context to the system prompt
and user messages before LLM calls.
"""
from abc import ABC
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from ..user.models import User
from ..llm.models import LlmMessage
class LlmContextEnhancer(ABC):
    """Hook for injecting extra context into LLM prompts and messages.

    Both hooks are pass-throughs by default, so subclasses override only what
    they need. Typical uses: add relevant memories, RAG snippets,
    documentation or examples, or temporal/environment details.

    Example:
        class MemoryBasedEnhancer(LlmContextEnhancer):
            def __init__(self, agent_memory):
                self.agent_memory = agent_memory

            async def enhance_system_prompt(
                self, system_prompt: str, user_message: str, user: User
            ) -> str:
                examples = await self.agent_memory.search_similar(user_message)
                return system_prompt + "\\n\\nRelevant examples:\\n" + examples

        agent = Agent(
            llm_service=...,
            llm_context_enhancer=MemoryBasedEnhancer(agent_memory)
        )
    """

    async def enhance_system_prompt(
        self, system_prompt: str, user_message: str, user: "User"
    ) -> str:
        """Return the system prompt to use for this turn.

        Invoked once per conversation turn with the initial user message,
        before the first LLM request and before any tool calls.

        Args:
            system_prompt: Prompt as configured on the agent.
            user_message: The user's initial message for this turn.
            user: The requesting user.

        Returns:
            The (possibly extended) system prompt. The default returns
            ``system_prompt`` unchanged.
        """
        return system_prompt

    async def enhance_user_messages(
        self, messages: list["LlmMessage"], user: "User"
    ) -> list["LlmMessage"]:
        """Return the message list to send to the LLM.

        Invoked before every LLM request, including those that follow tool
        calls, so implementations must avoid re-adding the same context on
        each iteration.

        Args:
            messages: Messages about to be sent.
            user: The requesting user.

        Returns:
            The (possibly modified) messages. The default returns the same
            list object it was given.
        """
        return messages
================================================
FILE: src/vanna/core/enhancer/default.py
================================================
"""
Default LLM context enhancer implementation using AgentMemory.
This implementation enriches the system prompt with relevant memories
based on the user's initial message.
"""
from typing import TYPE_CHECKING, List, Optional
from .base import LlmContextEnhancer
if TYPE_CHECKING:
from ..user.models import User
from ..llm.models import LlmMessage
from ...capabilities.agent_memory import AgentMemory, TextMemorySearchResult
class DefaultLlmContextEnhancer(LlmContextEnhancer):
    """Enhancer that appends relevant AgentMemory text memories to the system prompt.

    Without an agent memory both hooks are pass-throughs.

    Example:
        agent = Agent(
            llm_service=...,
            agent_memory=agent_memory,
            llm_context_enhancer=DefaultLlmContextEnhancer(agent_memory)
        )
    """

    def __init__(self, agent_memory: Optional["AgentMemory"] = None):
        """Remember which memory backend to search.

        Args:
            agent_memory: Memory to query; when falsy, prompt enhancement
                is skipped.
        """
        self.agent_memory = agent_memory

    async def enhance_system_prompt(
        self, system_prompt: str, user_message: str, user: "User"
    ) -> str:
        """Append up to five text memories matching ``user_message``.

        Args:
            system_prompt: The original system prompt
            user_message: The initial user message, used as the search query
            user: The user making the request

        Returns:
            ``system_prompt`` followed by a "Relevant Context from Memory"
            section with one bullet per memory. The prompt is returned
            unchanged when there is no memory, no match, or any exception
            (which is logged as a warning instead of failing the request).
        """
        memory_store = self.agent_memory
        if not memory_store:
            return system_prompt

        try:
            # Deferred imports avoid a circular dependency with ..tool.
            from ..tool import ToolContext
            import uuid

            # Throwaway context used only for this memory lookup.
            lookup_context = ToolContext(
                user=user,
                conversation_id="temp",
                request_id=str(uuid.uuid4()),
                agent_memory=memory_store,
            )
            hits = await memory_store.search_text_memories(
                query=user_message, context=lookup_context, limit=5
            )
            if not hits:
                return system_prompt

            header = (
                "\n\n## Relevant Context from Memory\n\n"
                "The following domain knowledge and context from prior interactions may be relevant:\n\n"
            )
            bullets = "".join(f"• {hit.memory.content}\n" for hit in hits)
            return system_prompt + header + bullets
        except Exception as e:
            # Best effort: memory problems must not fail the whole request.
            import logging

            logging.getLogger(__name__).warning(
                f"Failed to enhance system prompt with memories: {e}"
            )
            return system_prompt

    async def enhance_user_messages(
        self, messages: list["LlmMessage"], user: "User"
    ) -> list["LlmMessage"]:
        """Leave user messages untouched; override to add per-message context.

        Args:
            messages: The list of messages
            user: The user making the request

        Returns:
            The same list object that was passed in.
        """
        return messages
================================================
FILE: src/vanna/core/enricher/__init__.py
================================================
"""
Context enrichment system for adding data to tool execution context.
This module provides interfaces for enriching ToolContext with additional
data before tool execution.
"""
from .base import ToolContextEnricher
# The only public name of vanna.core.enricher: the base context-enricher hook.
__all__ = ["ToolContextEnricher"]
================================================
FILE: src/vanna/core/enricher/base.py
================================================
"""
Base context enricher interface.
Context enrichers allow you to add additional data to the ToolContext
before tools are executed.
"""
from abc import ABC
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..tool.models import ToolContext
class ToolContextEnricher(ABC):
    """Hook for attaching extra data to a ToolContext before tools run.

    Typical subclasses add user preferences, session state, temporal context
    (timezone, current date), profile/history data or environment-specific
    configuration, usually by writing into ``context.metadata``.

    Example:
        class UserPreferencesEnricher(ToolContextEnricher):
            def __init__(self, db):
                self.db = db

            async def enrich_context(self, context: ToolContext) -> ToolContext:
                prefs = await self.db.get_user_preferences(context.user.id)
                context.metadata["preferences"] = prefs
                context.metadata["timezone"] = prefs.get("timezone", "UTC")
                return context

        agent = AgentRunner(
            llm_service=...,
            context_enrichers=[UserPreferencesEnricher(db), SessionEnricher()]
        )
    """

    async def enrich_context(self, context: "ToolContext") -> "ToolContext":
        """Return the context tools should run with.

        Args:
            context: The tool context to enrich.

        Returns:
            The enriched context; implementations usually mutate and return
            the same object. The default returns ``context`` untouched.
        """
        return context
================================================
FILE: src/vanna/core/errors.py
================================================
"""
Exception classes for the Vanna Agents framework.
This module defines all custom exceptions used throughout the framework.
"""
class AgentError(Exception):
    """Root of the agent framework's exception hierarchy.

    Every framework-specific error derives from this, so catching it handles
    all of them at once.
    """
class ToolExecutionError(AgentError):
    """Raised when a tool fails while executing."""
class ToolNotFoundError(AgentError):
    """Raised when a requested tool is not present in the registry."""
class PermissionError(AgentError):
    """Raised when the user lacks the permissions a request needs.

    Note: this derives from AgentError, not from the built-in
    ``PermissionError``, and shadows the built-in in any module that imports
    it by name.
    """
class ConversationNotFoundError(AgentError):
    """Raised when a referenced conversation does not exist."""
class LlmServiceError(AgentError):
    """Raised when communication with the LLM service fails."""
class ValidationError(AgentError):
    """Raised when framework data fails validation.

    This is an AgentError subclass and is unrelated to pydantic's
    ``ValidationError``.
    """
================================================
FILE: src/vanna/core/evaluation/__init__.py
================================================
"""
Evaluation framework for Vanna Agents.
This module provides a complete evaluation system for testing and comparing
agent variants, with special focus on LLM comparison use cases.
Key Features:
- Parallel execution for efficient I/O-bound operations
- Multiple built-in evaluators (trajectory, output, LLM-as-judge, efficiency)
- Rich reporting (HTML, CSV, console)
- Dataset loaders (YAML, JSON)
- Agent variant comparison
Example:
>>> from vanna.evaluation import (
... EvaluationRunner,
... EvaluationDataset,
... AgentVariant,
... TrajectoryEvaluator,
... OutputEvaluator,
... )
>>>
>>> # Load test dataset
>>> dataset = EvaluationDataset.from_yaml("tests/sql_tasks.yaml")
>>>
>>> # Create agent variants
>>> variants = [
... AgentVariant("claude", claude_agent),
... AgentVariant("gpt", gpt_agent),
... ]
>>>
>>> # Run comparison
>>> runner = EvaluationRunner(
... evaluators=[TrajectoryEvaluator(), OutputEvaluator()],
... max_concurrency=20
... )
>>> comparison = await runner.compare_agents(variants, dataset.test_cases)
>>> comparison.print_summary()
"""
from .base import (
Evaluator,
TestCase,
ExpectedOutcome,
AgentResult,
EvaluationResult,
TestCaseResult,
AgentVariant,
)
from .runner import EvaluationRunner
from .evaluators import (
TrajectoryEvaluator,
OutputEvaluator,
LLMAsJudgeEvaluator,
EfficiencyEvaluator,
)
from .report import EvaluationReport, ComparisonReport
from .dataset import EvaluationDataset
# Public API of the evaluation package, grouped by role: core models and the
# Evaluator interface, the runner, built-in evaluators, reports, and datasets.
__all__ = [
    # Base classes
    "Evaluator",
    "TestCase",
    "ExpectedOutcome",
    "AgentResult",
    "EvaluationResult",
    "TestCaseResult",
    "AgentVariant",
    # Runner
    "EvaluationRunner",
    # Built-in evaluators
    "TrajectoryEvaluator",
    "OutputEvaluator",
    "LLMAsJudgeEvaluator",
    "EfficiencyEvaluator",
    # Reporting
    "EvaluationReport",
    "ComparisonReport",
    # Datasets
    "EvaluationDataset",
]
================================================
FILE: src/vanna/core/evaluation/base.py
================================================
"""
Core evaluation abstractions for the Vanna Agents framework.
This module provides the base classes and models for evaluating agent behavior,
including test cases, expected outcomes, and evaluation results.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Callable
from dataclasses import dataclass, field
from datetime import datetime
from pydantic import BaseModel
from vanna.core import User, UiComponent
class ExpectedOutcome(BaseModel):
    """Defines what we expect from the agent for a test case.

    All fields are optional; they default to ``None`` (``metadata`` defaults
    to an empty dict). The available expectations are:

    - tools_called: Tool names that should be called
    - tools_not_called: Tool names that should NOT be called
    - final_answer_contains: Keywords/phrases that should appear in the output
    - final_answer_not_contains: Keywords/phrases that should NOT appear
    - min_components: Minimum number of UI components expected
    - max_components: Maximum number of UI components expected
    - max_execution_time_ms: Maximum allowed execution time in milliseconds
    - metadata: Free-form extra data attached to this expectation

    NOTE(review): none of the built-in evaluators in evaluators.py read
    ``min_components`` / ``max_components``; a custom evaluator must
    enforce them.
    """

    tools_called: Optional[List[str]] = None
    tools_not_called: Optional[List[str]] = None
    final_answer_contains: Optional[List[str]] = None
    final_answer_not_contains: Optional[List[str]] = None
    min_components: Optional[int] = None
    max_components: Optional[int] = None
    max_execution_time_ms: Optional[float] = None
    # pydantic copies mutable defaults per instance, so ``{}`` is not shared.
    metadata: Dict[str, Any] = {}
class TestCase(BaseModel):
    """A single evaluation test case.

    Attributes:
        id: Unique identifier for the test case
        user: User context for the test
        message: The message to send to the agent
        conversation_id: Optional conversation ID for multi-turn tests
        expected_outcome: What we expect the agent to do/produce; when
            ``None``, the built-in trajectory and output evaluators pass
            by default
        metadata: Additional metadata for categorization/filtering
            (``EvaluationDataset.filter_by_metadata`` matches on it)
    """
    id: str
    user: User
    message: str
    conversation_id: Optional[str] = None
    expected_outcome: Optional[ExpectedOutcome] = None
    # pydantic copies this default per instance.
    metadata: Dict[str, Any] = {}
@dataclass
class AgentResult:
    """The result of running an agent on a test case.

    Captures everything that happened during agent execution
    for later evaluation.

    Attributes:
        test_case_id: ID of the test case that produced this result
        components: UI components emitted by the agent
        tool_calls: One dict per tool call; ``get_tool_names_called`` reads
            each dict's ``"tool_name"`` key
        llm_requests: Records of the LLM requests made during the run
        execution_time_ms: Execution time in milliseconds
        total_tokens: Total tokens consumed
        error: Error message if the run failed, otherwise ``None``
        metadata: Additional free-form data
    """

    test_case_id: str
    components: List[UiComponent]
    tool_calls: List[Dict[str, Any]] = field(default_factory=list)
    llm_requests: List[Dict[str, Any]] = field(default_factory=list)
    execution_time_ms: float = 0.0
    total_tokens: int = 0
    error: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def get_final_answer(self) -> str:
        """Extract the final answer text from the emitted components.

        Collects the content of every component whose rich component has
        type ``"text"`` and joins the pieces with newlines. Components
        without a rich component, with a plain-string or missing ``type``,
        or with a missing/``None`` ``data`` mapping are handled instead of
        raising ``AttributeError``.

        Returns:
            The newline-joined text content, or ``""`` if there is none.
        """
        texts: List[str] = []
        for component in self.components:
            rich_comp = getattr(component, "rich_component", None)
            if rich_comp is None:
                continue
            comp_type = getattr(rich_comp, "type", None)
            # ``type`` is normally an enum (``.value``); accept a raw string too.
            type_value = getattr(comp_type, "value", comp_type)
            if type_value != "text":
                continue
            # Prefer data["content"]; fall back to a ``content`` attribute.
            data = getattr(rich_comp, "data", None) or {}
            content = data.get("content") or getattr(rich_comp, "content", "")
            if content:
                texts.append(content)
        return "\n".join(texts)

    def get_tool_names_called(self) -> List[str]:
        """Get the tool names that were called, in call order.

        A call record without a ``"tool_name"`` key contributes ``""``.
        """
        return [call.get("tool_name", "") for call in self.tool_calls]
class EvaluationResult(BaseModel):
    """Result of evaluating a single test case.

    Attributes:
        test_case_id: ID of the test case evaluated
        evaluator_name: Name of the evaluator that produced this result
        passed: Whether the test case passed
        score: Score from 0.0 to 1.0
        reasoning: Explanation of the evaluation
        metrics: Additional metrics captured during evaluation
        timestamp: When the evaluation was performed (defaults to the time
            this result object is created)
    """

    test_case_id: str
    evaluator_name: str
    passed: bool
    score: float  # 0.0 to 1.0
    reasoning: str
    metrics: Dict[str, Any] = {}
    # default_factory runs per instance. A plain ``datetime.now()`` default
    # is evaluated once, at import time, and shared by every result.
    timestamp: datetime = Field(default_factory=datetime.now)
@dataclass
class TestCaseResult:
    """Everything recorded for one test case.

    Bundles the test case definition, the agent's raw result, the verdict
    of each evaluator, and the measured execution time.
    """

    test_case: TestCase
    agent_result: AgentResult
    evaluations: List[EvaluationResult]
    execution_time_ms: float

    def overall_passed(self) -> bool:
        """Return True unless some evaluation failed.

        Vacuously True when there are no evaluations.
        """
        return not any(not verdict.passed for verdict in self.evaluations)

    def overall_score(self) -> float:
        """Return the mean evaluation score, or 0.0 without evaluations."""
        count = len(self.evaluations)
        if count == 0:
            return 0.0
        total = sum(verdict.score for verdict in self.evaluations)
        return total / count
@dataclass
class AgentVariant:
    """A variant of an agent to evaluate (different LLM, config, etc).

    Used for comparing different agent configurations, especially
    different LLMs or model versions.

    Attributes:
        name: Human-readable name for this variant. The runner uses it as
            the key of ``ComparisonReport.reports``, so it should be unique
            within one comparison.
        agent: The agent instance to evaluate
        metadata: Additional info (model name, provider, config, etc). The
            runner copies it into the variant's observability span attributes
            and into its ``EvaluationReport.metadata``.
    """
    name: str
    agent: Any  # Agent type - avoiding circular import
    metadata: Dict[str, Any] = field(default_factory=dict)
class Evaluator(ABC):
    """Abstract base for components that judge one agent run.

    An evaluator receives a test case together with the agent's result and
    returns a single EvaluationResult. Several evaluators can be combined,
    each checking a different aspect of the run (trajectory, output
    content, efficiency, ...).

    Subclasses must implement ``name`` and ``evaluate``.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier of this evaluator.

        The built-in evaluators report it as ``EvaluationResult.evaluator_name``.
        """

    @abstractmethod
    async def evaluate(
        self,
        test_case: TestCase,
        agent_result: AgentResult,
    ) -> EvaluationResult:
        """Judge one executed test case.

        Args:
            test_case: The test case that was run
            agent_result: What the agent produced for it

        Returns:
            An EvaluationResult holding the pass/fail verdict, a score and
            the reasoning behind it.
        """
================================================
FILE: src/vanna/core/evaluation/dataset.py
================================================
"""
Dataset loaders for evaluation test cases.
This module provides utilities for loading test case datasets from
YAML and JSON files.
"""
import json
import yaml
from typing import Any, Dict, List
from pathlib import Path
from .base import TestCase, ExpectedOutcome
from vanna.core import User
class EvaluationDataset:
    """Collection of test cases with metadata.

    Example YAML format::

        dataset:
          name: "SQL Generation Tasks"
          description: "Test cases for SQL generation"
          test_cases:
            - id: "sql_001"
              user_id: "test_user"
              message: "Show me total sales by region"
              expected_outcome:
                tools_called: ["generate_sql", "execute_query"]
                final_answer_contains: ["SELECT", "GROUP BY", "region"]

    The top-level ``dataset`` key is optional: a mapping that directly holds
    ``name`` / ``description`` / ``test_cases`` is accepted as well. Files
    are read and written as UTF-8.
    """

    def __init__(self, name: str, test_cases: List[TestCase], description: str = ""):
        """Initialize evaluation dataset.

        Args:
            name: Name of the dataset
            test_cases: List of test cases
            description: Optional description
        """
        self.name = name
        self.test_cases = test_cases
        self.description = description

    @classmethod
    def from_yaml(cls, path: str) -> "EvaluationDataset":
        """Load dataset from a YAML file.

        Args:
            path: Path to YAML file

        Returns:
            EvaluationDataset instance (empty if the file has no document)

        Raises:
            ValueError: If the document is not a mapping.
        """
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # would mis-decode non-ASCII test messages.
        with open(path, "r", encoding="utf-8") as f:
            # safe_load builds only plain data types, never arbitrary objects.
            data = yaml.safe_load(f)
        return cls._from_dict(data)

    @classmethod
    def from_json(cls, path: str) -> "EvaluationDataset":
        """Load dataset from a JSON file.

        Args:
            path: Path to JSON file

        Returns:
            EvaluationDataset instance

        Raises:
            ValueError: If the document is not a JSON object.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls._from_dict(data)

    @classmethod
    def _from_dict(cls, data: Dict[str, Any]) -> "EvaluationDataset":
        """Create dataset from dictionary.

        Args:
            data: Dictionary with dataset structure; ``None`` (what an empty
                YAML file loads as) is treated as an empty dataset

        Returns:
            EvaluationDataset instance

        Raises:
            ValueError: If ``data`` or its ``dataset`` entry is not a mapping.
        """
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise ValueError(f"Dataset must be a mapping, got {type(data).__name__}")
        # ``or {}`` covers a bare "dataset:" key whose value is null.
        dataset_config = data.get("dataset", data) or {}
        if not isinstance(dataset_config, dict):
            raise ValueError(
                "The 'dataset' entry must be a mapping, got "
                f"{type(dataset_config).__name__}"
            )
        name = dataset_config.get("name", "Unnamed Dataset")
        description = dataset_config.get("description", "")
        # ``or []`` also covers "test_cases:" with a null value.
        test_cases = [
            cls._parse_test_case(tc_data)
            for tc_data in dataset_config.get("test_cases") or []
        ]
        return cls(name=name, test_cases=test_cases, description=description)

    @classmethod
    def _parse_test_case(cls, data: Dict[str, Any]) -> TestCase:
        """Parse a single test case from dictionary.

        Args:
            data: Test case dictionary; ``id`` and ``message`` are required

        Returns:
            TestCase instance

        Raises:
            KeyError: If ``id`` or ``message`` is missing.
        """
        # Build the test user; missing fields get placeholders derived from the ID.
        user_id = data.get("user_id", "test_user")
        user = User(
            id=user_id,
            username=data.get("username", user_id),
            email=data.get("email", f"{user_id}@example.com"),
            group_memberships=data.get("user_groups", []),
        )
        expected_outcome = None
        outcome_data = data.get("expected_outcome")
        # A null "expected_outcome:" entry means no expectations at all.
        if outcome_data is not None:
            expected_outcome = ExpectedOutcome(
                tools_called=outcome_data.get("tools_called"),
                tools_not_called=outcome_data.get("tools_not_called"),
                final_answer_contains=outcome_data.get("final_answer_contains"),
                final_answer_not_contains=outcome_data.get("final_answer_not_contains"),
                min_components=outcome_data.get("min_components"),
                max_components=outcome_data.get("max_components"),
                max_execution_time_ms=outcome_data.get("max_execution_time_ms"),
                # ``or {}``: a null metadata entry would fail model validation.
                metadata=outcome_data.get("metadata") or {},
            )
        return TestCase(
            id=data["id"],
            user=user,
            message=data["message"],
            conversation_id=data.get("conversation_id"),
            expected_outcome=expected_outcome,
            metadata=data.get("metadata") or {},
        )

    def save_yaml(self, path: str) -> None:
        """Save dataset to a YAML file (UTF-8).

        Args:
            path: Path to save YAML file
        """
        data = self._to_dict()
        with open(path, "w", encoding="utf-8") as f:
            yaml.dump(data, f, default_flow_style=False, sort_keys=False)

    def save_json(self, path: str) -> None:
        """Save dataset to a JSON file (UTF-8).

        Args:
            path: Path to save JSON file
        """
        data = self._to_dict()
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)

    def _to_dict(self) -> Dict[str, Any]:
        """Convert dataset to dictionary.

        Returns:
            Dictionary representation, in the format ``_from_dict`` reads
        """
        return {
            "dataset": {
                "name": self.name,
                "description": self.description,
                "test_cases": [self._test_case_to_dict(tc) for tc in self.test_cases],
            }
        }

    def _test_case_to_dict(self, test_case: TestCase) -> Dict[str, Any]:
        """Convert test case to dictionary.

        Empty/``None`` expectation fields are omitted.

        Args:
            test_case: TestCase to convert

        Returns:
            Dictionary representation
        """
        data: Dict[str, Any] = {
            "id": test_case.id,
            "user_id": test_case.user.id,
            "username": test_case.user.username,
            "email": test_case.user.email,
            "user_groups": test_case.user.group_memberships,
            "message": test_case.message,
        }
        if test_case.conversation_id:
            data["conversation_id"] = test_case.conversation_id
        if test_case.expected_outcome:
            outcome = test_case.expected_outcome
            outcome_dict: Dict[str, Any] = {}
            if outcome.tools_called:
                outcome_dict["tools_called"] = outcome.tools_called
            if outcome.tools_not_called:
                outcome_dict["tools_not_called"] = outcome.tools_not_called
            if outcome.final_answer_contains:
                outcome_dict["final_answer_contains"] = outcome.final_answer_contains
            if outcome.final_answer_not_contains:
                outcome_dict["final_answer_not_contains"] = (
                    outcome.final_answer_not_contains
                )
            if outcome.min_components is not None:
                outcome_dict["min_components"] = outcome.min_components
            if outcome.max_components is not None:
                outcome_dict["max_components"] = outcome.max_components
            if outcome.max_execution_time_ms is not None:
                outcome_dict["max_execution_time_ms"] = outcome.max_execution_time_ms
            if outcome.metadata:
                outcome_dict["metadata"] = outcome.metadata
            if outcome_dict:
                data["expected_outcome"] = outcome_dict
        if test_case.metadata:
            data["metadata"] = test_case.metadata
        return data

    def filter_by_metadata(self, **kwargs: Any) -> "EvaluationDataset":
        """Filter test cases by metadata fields.

        Args:
            **kwargs: Metadata fields to match (all must be equal)

        Returns:
            New EvaluationDataset with filtered test cases
        """
        filtered = [
            tc
            for tc in self.test_cases
            if all(tc.metadata.get(k) == v for k, v in kwargs.items())
        ]
        return EvaluationDataset(
            name=f"{self.name} (filtered)",
            test_cases=filtered,
            description=f"Filtered from: {self.description}",
        )

    def __len__(self) -> int:
        """Get number of test cases."""
        return len(self.test_cases)

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"EvaluationDataset(name='{self.name}', test_cases={len(self.test_cases)})"
        )
================================================
FILE: src/vanna/core/evaluation/evaluators.py
================================================
"""
Built-in evaluators for common evaluation tasks.
This module provides ready-to-use evaluators for:
- Trajectory evaluation (which tools were / were not called; order is not checked)
- Output evaluation (case-insensitive content matching)
- LLM-as-judge evaluation (custom criteria)
- Efficiency evaluation (execution time and token usage; a cost limit is
  accepted but not yet enforced)
"""
from typing import Dict, Any, Optional
from datetime import datetime
from .base import Evaluator, TestCase, AgentResult, EvaluationResult
from vanna.core import LlmService
class TrajectoryEvaluator(Evaluator):
    """Score the tool-call trajectory of an agent run.

    Verifies that every tool in ``expected_outcome.tools_called`` was
    invoked and that no tool in ``expected_outcome.tools_not_called`` was.
    Only presence/absence is checked, not call order or call counts. Each
    of the two lists can deduct up to 0.5 from a starting score of 1.0, and
    a score of at least 0.7 passes.
    """

    @property
    def name(self) -> str:
        return "trajectory"

    async def evaluate(
        self, test_case: TestCase, agent_result: AgentResult
    ) -> EvaluationResult:
        """Evaluate tool call trajectory.

        Fails immediately if the agent run errored; passes by default when
        the test case has no expected outcome.
        """
        if agent_result.error:
            return EvaluationResult(
                test_case_id=test_case.id,
                evaluator_name=self.name,
                passed=False,
                score=0.0,
                reasoning=f"Agent execution failed: {agent_result.error}",
            )

        expectation = test_case.expected_outcome
        if not expectation:
            return EvaluationResult(
                test_case_id=test_case.id,
                evaluator_name=self.name,
                passed=True,
                score=1.0,
                reasoning="No expected outcome specified, passing by default",
            )

        called = agent_result.get_tool_names_called()
        problems = []
        score = 1.0

        required = expectation.tools_called or []
        for tool in required:
            if tool in called:
                continue
            problems.append(f"Expected tool '{tool}' was not called")
            score -= 0.5 / len(required)

        forbidden = expectation.tools_not_called or []
        for tool in forbidden:
            if tool not in called:
                continue
            problems.append(f"Unexpected tool '{tool}' was called")
            score -= 0.5 / len(forbidden)

        score = max(0.0, min(1.0, score))
        if problems:
            summary = "; ".join(problems)
        else:
            summary = "All expected tools called, no unexpected tools"

        return EvaluationResult(
            test_case_id=test_case.id,
            evaluator_name=self.name,
            passed=score >= 0.7,  # 70% threshold
            score=score,
            reasoning="Trajectory evaluation: " + summary,
            metrics={
                "tools_called": called,
                "num_tools_called": len(called),
                "issues": problems,
            },
        )
class OutputEvaluator(Evaluator):
    """Evaluate the final output's content.

    Checks that the final answer contains every phrase listed in
    ``expected_outcome.final_answer_contains`` and none of the phrases in
    ``expected_outcome.final_answer_not_contains``. Matching is
    case-insensitive substring matching based on ``str.casefold``, which,
    unlike ``str.lower``, also folds characters such as "ß" to "ss". Each
    list can deduct up to 0.5 from a starting score of 1.0; a score of at
    least 0.7 passes.
    """

    @property
    def name(self) -> str:
        return "output"

    async def evaluate(
        self, test_case: TestCase, agent_result: AgentResult
    ) -> EvaluationResult:
        """Evaluate output content.

        Args:
            test_case: Test case whose ``expected_outcome`` defines the checks
            agent_result: Agent run whose final answer is inspected

        Returns:
            A failing result if the agent errored, a passing result if no
            expected outcome is set, and otherwise the content-match verdict
            with ``output_length`` (length of the unmodified answer) and
            ``issues`` metrics.
        """
        if agent_result.error:
            return EvaluationResult(
                test_case_id=test_case.id,
                evaluator_name=self.name,
                passed=False,
                score=0.0,
                reasoning=f"Agent execution failed: {agent_result.error}",
            )
        expected = test_case.expected_outcome
        if not expected:
            return EvaluationResult(
                test_case_id=test_case.id,
                evaluator_name=self.name,
                passed=True,
                score=1.0,
                reasoning="No expected outcome specified, passing by default",
            )
        final_answer = agent_result.get_final_answer()
        # Fold once; each phrase is folded the same way before comparing.
        folded_answer = final_answer.casefold()
        issues = []
        score = 1.0
        # Check expected content is present
        if expected.final_answer_contains:
            for expected_content in expected.final_answer_contains:
                if expected_content.casefold() not in folded_answer:
                    issues.append(
                        f"Expected content '{expected_content}' not found in output"
                    )
                    score -= 0.5 / len(expected.final_answer_contains)
        # Check forbidden content is absent
        if expected.final_answer_not_contains:
            for forbidden_content in expected.final_answer_not_contains:
                if forbidden_content.casefold() in folded_answer:
                    issues.append(
                        f"Forbidden content '{forbidden_content}' found in output"
                    )
                    score -= 0.5 / len(expected.final_answer_not_contains)
        score = max(0.0, min(1.0, score))
        passed = score >= 0.7  # 70% threshold
        reasoning = "Output evaluation: "
        if issues:
            reasoning += "; ".join(issues)
        else:
            reasoning += "All expected content present, no forbidden content"
        return EvaluationResult(
            test_case_id=test_case.id,
            evaluator_name=self.name,
            passed=passed,
            score=score,
            reasoning=reasoning,
            metrics={
                # Measured on the original text: case folding can change length.
                "output_length": len(final_answer),
                "issues": issues,
            },
        )
class LLMAsJudgeEvaluator(Evaluator):
    """Use an LLM to judge agent performance based on custom criteria.

    This evaluator uses a separate LLM to assess the quality of the
    agent's output based on natural language criteria. The judge is asked
    to answer with ``SCORE:``, ``PASSED:`` and ``REASONING:`` lines. These
    are parsed leniently:

    - labels are case-insensitive and may carry Markdown bullets or bold;
    - the reasoning may span several lines;
    - scores are clamped to [0.0, 1.0].
    """

    # Field labels the prompt asks the judge to emit.
    _FIELD_LABELS = ("SCORE", "PASSED", "REASONING")
    # First words of the PASSED field that count as a pass.
    _PASS_WORDS = frozenset({"yes", "true", "pass", "passed"})
    # Score used when no usable SCORE value can be parsed.
    _DEFAULT_SCORE = 0.5

    def __init__(self, judge_llm: LlmService, criteria: str):
        """Initialize LLM-as-judge evaluator.

        Args:
            judge_llm: The LLM service to use for judging
            criteria: Natural language description of what to evaluate
        """
        self.judge_llm = judge_llm
        self.criteria = criteria

    @property
    def name(self) -> str:
        return "llm_judge"

    async def evaluate(
        self, test_case: TestCase, agent_result: AgentResult
    ) -> EvaluationResult:
        """Evaluate using LLM as judge.

        Any failure while calling the judge or parsing its reply is reported
        as a failed evaluation (score 0.0) rather than raised.
        """
        if agent_result.error:
            return EvaluationResult(
                test_case_id=test_case.id,
                evaluator_name=self.name,
                passed=False,
                score=0.0,
                reasoning=f"Agent execution failed: {agent_result.error}",
            )
        final_answer = agent_result.get_final_answer()
        # Build prompt for judge
        judge_prompt = (
            "You are evaluating an AI agent's response to a user query.\n"
            f"User Query: {test_case.message}\n"
            "Agent's Response:\n"
            f"{final_answer}\n"
            "Evaluation Criteria:\n"
            f"{self.criteria}\n"
            "Please evaluate the response and provide:\n"
            "1. A score from 0.0 to 1.0 (where 1.0 is perfect)\n"
            "2. Whether it passes (score >= 0.7)\n"
            "3. Brief reasoning for your evaluation\n"
            "Respond in this format:\n"
            "SCORE:\n"
            "PASSED:\n"
            "REASONING:\n"
        )
        try:
            # Call judge LLM (import kept local, as in the original code)
            from vanna.core.llm import LlmRequest, LlmMessage

            request = LlmRequest(
                user=test_case.user,
                messages=[LlmMessage(role="user", content=judge_prompt)],
                temperature=0.0,  # Deterministic judging
            )
            response = await self.judge_llm.send_request(request)
            judgment = response.content or ""
            score = self._parse_score(judgment)
            passed = self._parse_passed(judgment)
            reasoning = self._parse_reasoning(judgment)
            return EvaluationResult(
                test_case_id=test_case.id,
                evaluator_name=self.name,
                passed=passed,
                score=score,
                reasoning=reasoning,
                metrics={"judge_response": judgment},
            )
        except Exception as e:
            # Best effort: a broken judge must not abort the whole evaluation run.
            return EvaluationResult(
                test_case_id=test_case.id,
                evaluator_name=self.name,
                passed=False,
                score=0.0,
                reasoning=f"LLM judge evaluation failed: {str(e)}",
            )

    @classmethod
    def _split_fields(cls, judgment: str) -> Dict[str, str]:
        """Group the judge's reply into its labelled fields.

        A line such as ``SCORE: 0.8``, ``score: 0.8``, ``- **SCORE:** 0.8``
        or ``## Score: 0.8`` starts a field. Following non-blank lines that
        do not start a new field are appended to the current one, so
        multi-line reasoning is kept. Only the first occurrence of each
        label counts.

        Returns:
            Mapping from upper-case label to its space-joined text.
        """
        fields: Dict[str, list] = {}
        current = None
        for raw_line in judgment.splitlines():
            stripped = raw_line.strip()
            # Ignore Markdown decoration (bold, bullets, headings, quotes)
            # when looking for a "LABEL:" prefix.
            candidate = stripped.replace("**", "").lstrip("#*->").strip()
            label, sep, value = candidate.partition(":")
            key = label.strip().upper()
            if sep and key in cls._FIELD_LABELS and key not in fields:
                current = key
                fields[key] = [value.strip()]
            elif current is not None and stripped:
                fields[current].append(stripped)
        return {
            key: " ".join(part for part in parts if part)
            for key, parts in fields.items()
        }

    def _parse_score(self, judgment: str) -> float:
        """Parse the score from the judge response.

        Takes the first number in the SCORE field. ``x/y`` is read as a
        fraction and ``x%`` as a percentage. The result is clamped to
        [0.0, 1.0], so an off-scale judge cannot produce scores outside the
        documented range.

        Returns:
            The parsed score, or 0.5 when no usable SCORE value is present.
        """
        import math
        import re

        value = self._split_fields(judgment).get("SCORE", "")
        match = re.search(
            r"([-+]?(?:\d+(?:\.\d*)?|\.\d+))\s*(?:(%)|/\s*(\d+(?:\.\d*)?))?",
            value,
        )
        if match is None:
            return self._DEFAULT_SCORE
        score = float(match.group(1))
        if match.group(2):
            score /= 100.0
        elif match.group(3):
            denominator = float(match.group(3))
            if denominator == 0:
                return self._DEFAULT_SCORE
            score /= denominator
        # Extremely long digit strings parse as inf; treat as unparseable.
        if not math.isfinite(score):
            return self._DEFAULT_SCORE
        return max(0.0, min(1.0, score))

    def _parse_passed(self, judgment: str) -> bool:
        """Parse pass/fail from the judge response.

        True only when the first word of the PASSED field is yes, true,
        pass or passed. The check is case-insensitive and ignores trailing
        punctuation, so ``Yes.`` counts. Returns False when the field is
        missing.
        """
        import re

        value = self._split_fields(judgment).get("PASSED", "")
        first_word = re.match(r"[a-z]+", value.lower())
        return first_word is not None and first_word.group() in self._PASS_WORDS

    def _parse_reasoning(self, judgment: str) -> str:
        """Parse reasoning from the judge response.

        Returns the REASONING field, including any continuation lines, or
        the full judgment when no REASONING line was emitted.
        """
        reasoning = self._split_fields(judgment).get("REASONING")
        return judgment if reasoning is None else reasoning
class EfficiencyEvaluator(Evaluator):
    """Check that an agent run stayed within its resource budgets.

    Up to three checks deduct from a starting score of 1.0, in this order:

    - the evaluator-wide time limit (0.33);
    - the evaluator-wide token limit (0.33);
    - the test case's own ``max_execution_time_ms`` (0.34).

    Because any single violation drops the score below the 0.7 pass mark,
    one exceeded limit is enough to fail. A limit that is ``None`` or ``0``
    is treated as disabled. Agent errors are not considered here.
    """

    def __init__(
        self,
        max_execution_time_ms: Optional[float] = None,
        max_tokens: Optional[int] = None,
        max_cost_usd: Optional[float] = None,
    ):
        """Configure the resource limits.

        Args:
            max_execution_time_ms: Maximum allowed execution time in milliseconds
            max_tokens: Maximum allowed token usage
            max_cost_usd: Maximum allowed cost in USD. Stored, but not yet
                enforced by ``evaluate``.
        """
        self.max_execution_time_ms = max_execution_time_ms
        self.max_tokens = max_tokens
        self.max_cost_usd = max_cost_usd

    @property
    def name(self) -> str:
        return "efficiency"

    async def evaluate(
        self, test_case: TestCase, agent_result: AgentResult
    ) -> EvaluationResult:
        """Compare the run's time and token usage against the limits."""
        elapsed = agent_result.execution_time_ms
        tokens = agent_result.total_tokens
        problems = []
        score = 1.0

        time_limit = self.max_execution_time_ms
        if time_limit and elapsed > time_limit:
            problems.append(
                f"Execution time {elapsed:.0f}ms exceeded limit {time_limit:.0f}ms"
            )
            score -= 0.33

        token_limit = self.max_tokens
        if token_limit and tokens > token_limit:
            problems.append(f"Token usage {tokens} exceeded limit {token_limit}")
            score -= 0.33

        # Cost is not evaluated yet (it would need cost data from metadata).

        outcome = test_case.expected_outcome
        case_limit = outcome.max_execution_time_ms if outcome else None
        if case_limit and elapsed > case_limit:
            problems.append(
                f"Execution time {elapsed:.0f}ms "
                f"exceeded test case limit {case_limit:.0f}ms"
            )
            score -= 0.34

        score = max(0.0, min(1.0, score))
        summary = "; ".join(problems) if problems else "Within resource limits"
        return EvaluationResult(
            test_case_id=test_case.id,
            evaluator_name=self.name,
            passed=score >= 0.7,
            score=score,
            reasoning="Efficiency evaluation: " + summary,
            metrics={
                "execution_time_ms": elapsed,
                "total_tokens": tokens,
                "issues": problems,
            },
        )
================================================
FILE: src/vanna/core/evaluation/report.py
================================================
"""
Evaluation reporting with HTML, CSV, and console output.
This module provides classes for generating evaluation reports,
including comparison reports for evaluating multiple agent variants.
"""
import csv
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field
from datetime import datetime
from .base import TestCaseResult, AgentVariant, Evaluator, TestCase
@dataclass
class EvaluationReport:
    """Aggregated evaluation results for one agent (or agent variant).

    Attributes:
        agent_name: Label of the evaluated agent
        results: One TestCaseResult per executed test case
        evaluators: Evaluators that produced the results
        metadata: Extra information about the agent or run
        timestamp: When the report was created (defaults to now)
    """

    agent_name: str
    results: List[TestCaseResult]
    evaluators: List[Evaluator]
    metadata: Dict[str, Any] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.now)

    def pass_rate(self) -> float:
        """Fraction of test cases whose evaluations all passed (0.0 if empty)."""
        count = len(self.results)
        if count == 0:
            return 0.0
        passed_count = len([res for res in self.results if res.overall_passed()])
        return passed_count / count

    def average_score(self) -> float:
        """Mean of the per-test-case overall scores (0.0 if empty)."""
        count = len(self.results)
        if count == 0:
            return 0.0
        return sum(res.overall_score() for res in self.results) / count

    def average_time(self) -> float:
        """Mean execution time in milliseconds (0.0 if empty)."""
        count = len(self.results)
        if count == 0:
            return 0.0
        return sum(res.execution_time_ms for res in self.results) / count

    def total_tokens(self) -> int:
        """Sum of tokens used over all test cases."""
        return sum(res.agent_result.total_tokens for res in self.results)

    def get_failures(self) -> List[TestCaseResult]:
        """Results in which at least one evaluation failed, in original order."""
        return list(filter(lambda res: not res.overall_passed(), self.results))

    def print_summary(self) -> None:
        """Print headline statistics, then details of every failure, to stdout."""
        rule = "=" * 80
        print("\n" + rule)
        print(f"EVALUATION REPORT: {self.agent_name}")
        print(rule)
        print(f"Timestamp: {self.timestamp.isoformat()}")
        print(f"Test Cases: {len(self.results)}")
        print(f"Pass Rate: {self.pass_rate():.1%}")
        print(f"Average Score: {self.average_score():.2f}")
        print(f"Average Time: {self.average_time():.0f}ms")
        print(f"Total Tokens: {self.total_tokens()}")
        print(rule + "\n")

        failures = self.get_failures()
        if not failures:
            return
        print(f"FAILURES ({len(failures)}):")
        for failed in failures:
            print(f"\n Test Case: {failed.test_case.id}")
            print(f" Message: {failed.test_case.message}")
            print(f" Score: {failed.overall_score():.2f}")
            for verdict in failed.evaluations:
                if verdict.passed:
                    continue
                print(f" [{verdict.evaluator_name}] {verdict.reasoning}")
@dataclass
class ComparisonReport:
    """Report comparing multiple agent variants.

    This is the primary report type for LLM comparison use cases.

    Attributes:
        variants: List of agent variants compared
        reports: Dict mapping variant name to EvaluationReport
        test_cases: Test cases used for comparison
        timestamp: When the comparison was run
    """

    variants: List[AgentVariant]
    reports: Dict[str, EvaluationReport]
    test_cases: List[TestCase]
    timestamp: datetime = field(default_factory=datetime.now)

    def print_summary(self) -> None:
        """Print comparison summary to console."""
        print("\n" + "=" * 80)
        print("AGENT COMPARISON SUMMARY")
        print("=" * 80)
        print(f"Timestamp: {self.timestamp.isoformat()}")
        print(f"Variants: {len(self.variants)}")
        print(f"Test Cases: {len(self.test_cases)}")
        # Table of results
        print(
            f"\n{'Agent':<25} {'Pass Rate':<12} {'Avg Score':<12} {'Avg Time':<12} {'Tokens':<12}"
        )
        print("-" * 80)
        for variant_name, report in self.reports.items():
            print(
                f"{variant_name:<25} "
                f"{report.pass_rate():<12.1%} "
                f"{report.average_score():<12.2f} "
                f"{report.average_time():<12.0f} "
                f"{report.total_tokens():<12,}"
            )
        print("=" * 80 + "\n")

    def get_best_variant(self, metric: str = "score") -> str:
        """Get the best performing variant by metric.

        Ties go to the variant that appears first in ``reports``.

        Args:
            metric: Metric to optimize ('score', 'speed', 'pass_rate')

        Returns:
            Name of the best variant

        Raises:
            ValueError: If ``metric`` is unknown or there are no reports.
        """
        if metric not in ("score", "speed", "pass_rate"):
            raise ValueError(f"Unknown metric: {metric}")
        if not self.reports:
            raise ValueError("Cannot pick a best variant: there are no variant reports")
        items = self.reports.items()
        if metric == "score":
            return max(items, key=lambda x: x[1].average_score())[0]
        if metric == "speed":
            return min(items, key=lambda x: x[1].average_time())[0]
        return max(items, key=lambda x: x[1].pass_rate())[0]

    def save_csv(self, path: str) -> None:
        """Save detailed CSV for further analysis (UTF-8).

        Each row represents one test case × one variant combination.
        """
        with open(path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            # Header
            writer.writerow(
                [
                    "variant",
                    "test_case_id",
                    "test_message",
                    "passed",
                    "score",
                    "execution_time_ms",
                    "tokens",
                    "error",
                    "evaluator_scores",
                ]
            )
            # Data rows
            for variant_name, report in self.reports.items():
                for result in report.results:
                    evaluator_scores = {
                        e.evaluator_name: e.score for e in result.evaluations
                    }
                    writer.writerow(
                        [
                            variant_name,
                            result.test_case.id,
                            result.test_case.message[:50],  # Truncate
                            result.overall_passed(),
                            result.overall_score(),
                            result.execution_time_ms,
                            result.agent_result.total_tokens,
                            result.agent_result.error or "",
                            str(evaluator_scores),
                        ]
                    )

    def save_html(self, path: str) -> None:
        """Save a standalone HTML comparison report (UTF-8).

        The report contains summary statistics per variant (best average
        score highlighted) and a per-test-case table of every variant's
        result.
        """
        html = self._generate_html()
        with open(path, "w", encoding="utf-8") as f:
            f.write(html)

    def _generate_html(self) -> str:
        """Generate the HTML document for the report.

        All user-supplied text (variant names, test case IDs and messages)
        is HTML-escaped, so it cannot break the markup or inject script.
        The best-by-score highlight is skipped when there are no reports.
        """
        from html import escape

        parts = [
            "<!DOCTYPE html>",
            "<html>",
            "<head>",
            '<meta charset="utf-8">',
            "<title>Agent Comparison Report</title>",
            "<style>",
            "body { font-family: sans-serif; margin: 2em; }",
            "table { border-collapse: collapse; margin-bottom: 1.5em; }",
            "th, td { border: 1px solid #ccc; padding: 4px 10px; text-align: left; }",
            "th { background: #f2f2f2; }",
            "tr.best { background: #e6ffe6; font-weight: bold; }",
            "td.passed { color: #1a7f37; }",
            "td.failed { color: #cf222e; }",
            "</style>",
            "</head>",
            "<body>",
            "<h1>Agent Comparison Report</h1>",
            f"<p>Generated: {escape(self.timestamp.isoformat())}</p>",
            f"<p>Variants: {len(self.variants)} | Test Cases: {len(self.test_cases)}</p>",
        ]

        # Summary table
        parts.append("<h2>Summary</h2>")
        parts.append("<table>")
        parts.append(
            "<tr><th>Agent</th><th>Pass Rate</th><th>Avg Score</th>"
            "<th>Avg Time (ms)</th><th>Total Tokens</th></tr>"
        )
        best_by_score = self.get_best_variant("score") if self.reports else None
        for variant_name, report in self.reports.items():
            row_class = "best" if variant_name == best_by_score else ""
            parts.append(
                f'<tr class="{row_class}">'
                f"<td>{escape(variant_name)}</td>"
                f"<td>{report.pass_rate():.1%}</td>"
                f"<td>{report.average_score():.2f}</td>"
                f"<td>{report.average_time():.0f}</td>"
                f"<td>{report.total_tokens():,}</td>"
                "</tr>"
            )
        parts.append("</table>")

        # Index each variant's results by test-case ID once (first match wins,
        # as with a linear search) instead of scanning per test case.
        results_by_variant = {}
        for variant_name, report in self.reports.items():
            lookup = {}
            for result in report.results:
                lookup.setdefault(result.test_case.id, result)
            results_by_variant[variant_name] = lookup

        # Test case details
        parts.append("<h2>Test Case Details</h2>")
        for i, test_case in enumerate(self.test_cases):
            parts.append(f"<h3>Test Case {i + 1}: {escape(test_case.id)}</h3>")
            parts.append(
                f"<p><strong>Message:</strong> {escape(test_case.message)}</p>"
            )
            parts.append("<table>")
            parts.append(
                "<tr><th>Variant</th><th>Result</th><th>Score</th>"
                "<th>Time (ms)</th></tr>"
            )
            for variant_name, lookup in results_by_variant.items():
                result = lookup.get(test_case.id)
                if result is None:
                    continue
                passed = result.overall_passed()
                passed_class = "passed" if passed else "failed"
                passed_text = "PASS" if passed else "FAIL"
                parts.append(
                    "<tr>"
                    f"<td>{escape(variant_name)}</td>"
                    f'<td class="{passed_class}">{passed_text}</td>'
                    f"<td>{result.overall_score():.2f}</td>"
                    f"<td>{result.execution_time_ms:.0f}</td>"
                    "</tr>"
                )
            parts.append("</table>")
        parts.append("</body>")
        parts.append("</html>")
        return "\n".join(parts)
================================================
FILE: src/vanna/core/evaluation/runner.py
================================================
"""
Evaluation runner with parallel execution support.
This module provides the EvaluationRunner class that executes test cases
against agents with configurable parallelism for efficient evaluation,
especially when comparing multiple LLMs or model versions.
"""
import asyncio
from typing import Any, List, Dict, Optional, AsyncGenerator, TYPE_CHECKING
from datetime import datetime
from .base import (
TestCase,
AgentResult,
TestCaseResult,
AgentVariant,
Evaluator,
)
from vanna.core import UiComponent
from vanna.core.user.request_context import RequestContext
from vanna.core.observability import ObservabilityProvider
if TYPE_CHECKING:
from vanna import Agent
from .report import EvaluationReport, ComparisonReport
class EvaluationRunner:
"""Run evaluations with parallel execution support.
The primary use case is comparing multiple agent variants (e.g., different LLMs)
on the same set of test cases. The runner executes test cases in parallel with
configurable concurrency to handle I/O-bound LLM operations efficiently.
Example:
>>> runner = EvaluationRunner(
... evaluators=[TrajectoryEvaluator(), OutputEvaluator()],
... max_concurrency=20
... )
>>> comparison = await runner.compare_agents(
... agent_variants=[claude_variant, gpt_variant],
... test_cases=dataset.test_cases
... )
"""
def __init__(
    self,
    evaluators: List[Evaluator],
    max_concurrency: int = 10,
    observability_provider: Optional[ObservabilityProvider] = None,
):
    """Initialize the evaluation runner.

    Args:
        evaluators: List of evaluators to apply to each test case
        max_concurrency: Maximum number of concurrent test case executions.
            A single semaphore is created here for the whole runner, so the
            limit presumably applies across all variants when
            ``compare_agents`` runs them together. Must be at least 1.
        observability_provider: Optional observability for tracking eval runs

    Raises:
        ValueError: If ``max_concurrency`` is less than 1.
    """
    # Semaphore(0) would block every test case forever; asyncio already
    # rejects negative values. Fail early with a clear message instead.
    if max_concurrency < 1:
        raise ValueError(f"max_concurrency must be at least 1, got {max_concurrency}")
    self.evaluators = evaluators
    self.max_concurrency = max_concurrency
    self.observability = observability_provider
    # NOTE(review): on Python < 3.10 an asyncio.Semaphore created outside a
    # running loop binds to the default loop; construct the runner inside
    # the event loop that will use it there.
    self._semaphore = asyncio.Semaphore(max_concurrency)
async def run_evaluation(
    self,
    agent: "Agent",
    test_cases: List[TestCase],
    agent_name: str = "agent",
) -> "EvaluationReport":
    """Run evaluation on a single agent.

    Args:
        agent: The agent to evaluate
        test_cases: List of test cases to run
        agent_name: Label stored as ``EvaluationReport.agent_name``
            (defaults to ``"agent"``, the previously hard-coded value)

    Returns:
        EvaluationReport with results for all test cases
    """
    from .report import EvaluationReport

    results = await self._run_test_cases_parallel(agent, test_cases)
    return EvaluationReport(
        agent_name=agent_name,
        results=results,
        evaluators=self.evaluators,
        timestamp=datetime.now(),
    )
async def compare_agents(
    self,
    agent_variants: List[AgentVariant],
    test_cases: List[TestCase],
) -> "ComparisonReport":
    """Compare multiple agent variants on same test cases.

    This is the PRIMARY use case for LLM comparison. Runs all variants
    in parallel for maximum efficiency with I/O-bound LLM calls.

    Args:
        agent_variants: List of agent variants to compare
        test_cases: Test cases to run on each variant

    Returns:
        ComparisonReport with results for all variants

    Raises:
        ValueError: If two variants share a name. Reports are keyed by name,
            so duplicates would silently overwrite each other.
    """
    from .report import ComparisonReport

    names = [variant.name for variant in agent_variants]
    duplicates = sorted({name for name in names if names.count(name) > 1})
    if duplicates:
        raise ValueError(f"Agent variant names must be unique; duplicated: {duplicates}")

    # Create span for overall comparison
    span = None
    if self.observability:
        span = await self.observability.create_span(
            "agent_comparison",
            attributes={
                "num_variants": len(agent_variants),
                "num_test_cases": len(test_cases),
            },
        )
    try:
        # Run all variants in parallel
        variant_reports = await asyncio.gather(
            *(self._run_agent_variant(variant, test_cases) for variant in agent_variants)
        )
    finally:
        # End the span even when a variant raised, so it is never leaked.
        if self.observability:
            await self.observability.end_span(span)

    return ComparisonReport(
        variants=agent_variants,
        reports=dict(zip(names, variant_reports)),
        test_cases=test_cases,
        timestamp=datetime.now(),
    )
async def compare_agents_streaming(
    self,
    agent_variants: List[AgentVariant],
    test_cases: List[TestCase],
) -> AsyncGenerator[tuple[str, TestCaseResult, int, int], None]:
    """Stream comparison results as they complete.

    Useful for long-running evaluations where you want to see
    progress updates in real-time (e.g., for UI display).

    Every (variant, test case) pair runs as its own task, so results are
    yielded one at a time in completion order rather than in per-variant
    batches. If a test case raises, the exception propagates out of the
    iterator instead of leaving it waiting forever. When iteration ends
    early (exception, ``break`` or ``aclose()``), unfinished tasks are
    cancelled.

    Args:
        agent_variants: Agent variants to compare
        test_cases: Test cases to run

    Yields:
        Tuples of (variant_name, result, completed_count, total_count)
    """

    async def run_pair(
        variant: AgentVariant, test_case: TestCase
    ) -> tuple[str, TestCaseResult]:
        """Run one test case for one variant and tag it with the variant name."""
        result = await self._run_single_test_case(variant.agent, test_case)
        return variant.name, result

    tasks = [
        asyncio.create_task(run_pair(variant, test_case))
        for variant in agent_variants
        for test_case in test_cases
    ]
    total = len(tasks)
    completed = 0
    try:
        for next_finished in asyncio.as_completed(tasks):
            variant_name, result = await next_finished
            completed += 1
            yield variant_name, result, completed, total
    finally:
        for task in tasks:
            task.cancel()  # no-op for tasks that already finished
        if tasks:
            # Let cancelled tasks unwind and retrieve their exceptions so none
            # is reported as "never retrieved".
            await asyncio.gather(*tasks, return_exceptions=True)
async def _run_agent_variant(
    self,
    variant: AgentVariant,
    test_cases: List[TestCase],
) -> "EvaluationReport":
    """Run a single agent variant on all test cases.

    When observability is configured, the run is wrapped in a span named
    ``variant_<name>``. The span is ended even if the run raises.

    Args:
        variant: The agent variant to evaluate
        test_cases: Test cases to run

    Returns:
        EvaluationReport for this variant, carrying ``variant.metadata``
    """
    from .report import EvaluationReport

    observability = self.observability
    span = None
    if observability:
        span = await observability.create_span(
            f"variant_{variant.name}",
            attributes={
                "variant": variant.name,
                "num_test_cases": len(test_cases),
                # Spread last, so metadata keys override the two keys above on collision.
                **variant.metadata,
            },
        )

    try:
        results = await self._run_test_cases_parallel(variant.agent, test_cases)
    finally:
        # Close the span on failure too, so traces are never left open.
        if observability:
            await observability.end_span(span)

    return EvaluationReport(
        agent_name=variant.name,
        results=results,
        evaluators=self.evaluators,
        metadata=variant.metadata,
        timestamp=datetime.now(),
    )
async def _run_test_cases_parallel(
    self,
    agent: "Agent",
    test_cases: List[TestCase],
) -> List[TestCaseResult]:
    """Run every test case against ``agent`` concurrently.

    This method applies no limit itself. Each case goes through
    ``_run_single_test_case``, which acquires ``self._semaphore``, and that
    semaphore caps concurrency.

    Args:
        agent: The agent to run test cases on
        test_cases: Test cases to execute

    Returns:
        One TestCaseResult per test case, in the same order as ``test_cases``
    """
    pending_runs = (
        self._run_single_test_case(agent, case) for case in test_cases
    )
    results = await asyncio.gather(*pending_runs)
    return results
async def _run_single_test_case(
    self,
    agent: "Agent",
    test_case: TestCase,
) -> TestCaseResult:
    """Run one test case and score it while holding a concurrency slot.

    The agent run and all evaluators execute while ``self._semaphore`` is
    held. The semaphore therefore bounds how many test cases are in flight.
    Evaluators run sequentially, in ``self.evaluators`` order.

    Args:
        agent: The agent to execute
        test_case: The test case to run

    Returns:
        TestCaseResult containing:
        - the agent result
        - one evaluation per evaluator
        - the agent's execution time in milliseconds (evaluator time is
          not included)
    """
    async with self._semaphore:
        # get_running_loop() is the recommended way to reach the loop from
        # inside a coroutine; get_event_loop() is discouraged there.
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        agent_result = await self._execute_agent(agent, test_case)
        execution_time = loop.time() - start_time

        # Sequential on purpose: each evaluator is awaited before the next starts.
        eval_results = [
            await evaluator.evaluate(test_case, agent_result)
            for evaluator in self.evaluators
        ]

        return TestCaseResult(
            test_case=test_case,
            agent_result=agent_result,
            evaluations=eval_results,
            execution_time_ms=execution_time * 1000,
        )
async def _execute_agent(
    self,
    agent: "Agent",
    test_case: TestCase,
) -> AgentResult:
    """Execute the agent on one test case and capture what it streams back.

    Any ``Exception`` raised while building the request or streaming
    components is caught and recorded in ``AgentResult.error``. Components
    received before the failure are kept.

    Args:
        agent: The agent to execute
        test_case: The test case to run

    Returns:
        AgentResult with the streamed components and any error. ``tool_calls``
        and ``llm_requests`` are currently always empty (see TODO below).
    """
    components: List[UiComponent] = []
    tool_calls: List[Dict[str, Any]] = []
    error: Optional[str] = None
    try:
        # Create request context with user info from test case
        # This allows the agent's UserResolver to resolve the correct user
        request_context = RequestContext(
            cookies={"user_id": test_case.user.id},
            headers={},
            metadata={"test_case_user": test_case.user},
        )
        async for component in agent.send_message(
            request_context=request_context,
            message=test_case.message,
            conversation_id=test_case.conversation_id,
        ):
            components.append(component)
    except Exception as e:
        # str(e) is "" for exceptions raised without a message (e.g. a bare
        # TimeoutError()). An empty error is falsy and would look like a
        # successful run, so fall back to the exception's type name.
        error = str(e) or type(e).__name__
    # TODO: Extract tool calls and LLM requests from observability
    # For now, these will be empty unless we hook into observability
    return AgentResult(
        test_case_id=test_case.id,
        components=components,
        tool_calls=tool_calls,
        llm_requests=[],
        error=error,
    )
================================================
FILE: src/vanna/core/filter/__init__.py
================================================
"""
Conversation filtering system for managing conversation history.
This module provides interfaces for filtering and transforming conversation
history before it's sent to the LLM.
"""
from .base import ConversationFilter
__all__ = ["ConversationFilter"]
================================================
FILE: src/vanna/core/filter/base.py
================================================
"""
Base conversation filter interface.
Conversation filters allow you to transform conversation history before
it's sent to the LLM for processing.
"""
from abc import ABC
from typing import TYPE_CHECKING, List
if TYPE_CHECKING:
from ..storage import Message
class ConversationFilter(ABC):
    """Filter for transforming conversation history.

    Subclass this and override :meth:`filter_messages` to create custom
    filters that can:
    - Remove sensitive information
    - Summarize long conversations
    - Manage context window limits
    - Deduplicate similar messages
    - Prioritize recent or relevant messages

    The base class declares no abstract methods, so it can be instantiated
    directly. Its ``filter_messages`` returns the input list unchanged.

    Example:
        class ContextWindowFilter(ConversationFilter):
            def __init__(self, max_tokens: int = 8000):
                self.max_tokens = max_tokens

            async def filter_messages(self, messages: List[Message]) -> List[Message]:
                # Estimate tokens (rough approximation: ~4 characters per token)
                total_tokens = 0
                filtered = []
                # Walk backwards and keep the most recent messages that fit the
                # budget. Everything older, including any system message, is
                # dropped once the budget is exceeded.
                for msg in reversed(messages):
                    msg_tokens = len(msg.content or "") // 4
                    if total_tokens + msg_tokens > self.max_tokens:
                        break
                    filtered.insert(0, msg)
                    total_tokens += msg_tokens
                return filtered

        # SensitiveDataFilter stands for another user-defined filter.
        agent = AgentRunner(
            llm_service=...,
            conversation_filters=[
                SensitiveDataFilter(),
                ContextWindowFilter(max_tokens=8000)
            ]
        )
    """

    async def filter_messages(self, messages: List["Message"]) -> List["Message"]:
        """Filter and transform conversation messages.

        Args:
            messages: List of conversation messages

        Returns:
            Filtered/transformed list of messages. The default implementation
            returns the same list object, unmodified.

        Note:
            Filters are applied in order, so messages passed to later
            filters may already be modified by earlier filters.
        """
        return messages
================================================
FILE: src/vanna/core/lifecycle/__init__.py
================================================
"""
Lifecycle hook system for agent execution.
This module provides hooks for intercepting and modifying agent behavior
at various points in the execution lifecycle.
"""
from .base import LifecycleHook
__all__ = ["LifecycleHook"]
================================================
FILE: src/vanna/core/lifecycle/base.py
================================================
"""
Base lifecycle hook interface.
Lifecycle hooks allow you to intercept and customize agent behavior
at key points in the execution flow.
"""
from abc import ABC
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from ..user.models import User
from ..tool import Tool
from ..tool.models import ToolContext, ToolResult
class LifecycleHook(ABC):
    """Base class for hooks into the agent execution lifecycle.

    Every method is a no-op default, so a subclass overrides only the points
    it cares about. Typical uses:
    - Rewrite user messages before they are processed
    - Emit logs or telemetry
    - Enforce quotas or rate limits
    - Post-process tool results
    - Add custom validation

    Example:
        class LoggingHook(LifecycleHook):
            async def before_message(self, user: User, message: str) -> Optional[str]:
                print(f"User {user.id} sent: {message}")
                return None  # keep the message as-is

        # QuotaCheckHook stands for another user-defined hook.
        agent = AgentRunner(
            llm_service=...,
            lifecycle_hooks=[LoggingHook(), QuotaCheckHook()]
        )
    """

    async def before_message(self, user: "User", message: str) -> Optional[str]:
        """Inspect or rewrite an incoming user message.

        Args:
            user: The user who sent the message
            message: The message text as received

        Returns:
            A replacement message, or None to leave the original untouched

        Raises:
            AgentError: To stop processing the message (e.g. quota exceeded)
        """
        return None

    async def after_message(self, result: Any) -> None:
        """Observe the outcome once a message has been fully processed.

        Args:
            result: Whatever the message processing produced
        """
        return None

    async def before_tool(self, tool: "Tool[Any]", context: "ToolContext") -> None:
        """Inspect a tool call just before it runs.

        Args:
            tool: The tool that is about to execute
            context: The context it will execute with

        Raises:
            AgentError: To block the tool from executing
        """
        return None

    async def after_tool(self, result: "ToolResult") -> Optional["ToolResult"]:
        """Inspect or replace a tool's result after it runs.

        Args:
            result: What the tool returned

        Returns:
            A replacement ToolResult, or None to keep the original
        """
        return None
================================================
FILE: src/vanna/core/llm/__init__.py
================================================
"""
LLM domain.
This module provides the core abstractions for LLM services in the Vanna Agents framework.
"""
from .base import LlmService
from .models import LlmMessage, LlmRequest, LlmResponse, LlmStreamChunk
__all__ = [
"LlmService",
"LlmMessage",
"LlmRequest",
"LlmResponse",
"LlmStreamChunk",
]
================================================
FILE: src/vanna/core/llm/base.py
================================================
"""
LLM domain interface.
This module contains the abstract base class for LLM services.
"""
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, List
from .models import LlmRequest, LlmResponse, LlmStreamChunk
class LlmService(ABC):
    """Abstract interface that every LLM backend implements.

    A concrete service provides three capabilities:
    - one-shot requests
    - streaming requests
    - tool schema validation
    """

    @abstractmethod
    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Send ``request`` to the LLM and return its complete response."""
        ...

    @abstractmethod
    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream ``request`` to the LLM.

        Args:
            request: The LLM request to stream

        Yields:
            LlmStreamChunk instances as they arrive
        """
        # The unreachable ``yield`` turns this into an async generator
        # function, matching the declared return type.
        raise NotImplementedError
        yield  # pragma: no cover - makes this an async generator

    @abstractmethod
    async def validate_tools(self, tools: List[Any]) -> List[str]:
        """Check the given tool schemas and return any validation errors."""
        ...
================================================
FILE: src/vanna/core/llm/models.py
================================================
"""
LLM domain models.
This module contains data models for LLM communication.
"""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from ..tool.models import ToolCall
from ..user.models import User
class LlmMessage(BaseModel):
    """Message format for LLM communication.

    One entry in the ``messages`` list of an ``LlmRequest``.
    """

    # Author role; any string is accepted (no enumeration is enforced here).
    role: str = Field(description="Message role")
    content: str = Field(description="Message content")
    # Tool invocations attached to this message, if any.
    tool_calls: Optional[List[ToolCall]] = Field(default=None)
    # NOTE(review): presumably links a tool-result message to the ToolCall it answers — confirm.
    tool_call_id: Optional[str] = Field(default=None)
class LlmRequest(BaseModel):
    """Request to LLM service.

    The payload accepted by ``LlmService.send_request`` and
    ``LlmService.stream_request``. It contains:
    - the conversation
    - optional tool schemas
    - the requesting user
    - generation options
    """

    messages: List[LlmMessage] = Field(description="Messages to send")
    tools: Optional[List[Any]] = Field(
        default=None, description="Available tools"
    ) # Will be ToolSchema but avoiding circular import
    user: User = Field(description="User making the request")
    # NOTE(review): only a flag; presumably callers pick stream_request vs send_request from it — confirm.
    stream: bool = Field(default=False, description="Whether to stream response")
    # Validated to the inclusive range 0.0-2.0.
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    # Defaults to None; when provided it must be strictly positive.
    max_tokens: Optional[int] = Field(default=None, gt=0)
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt for the LLM"
    )
    # Free-form extra data; each request gets its own dict (default_factory).
    metadata: Dict[str, Any] = Field(default_factory=dict)
class LlmResponse(BaseModel):
    """Complete response returned by ``LlmService.send_request``."""

    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
    finish_reason: Optional[str] = None
    usage: Optional[Dict[str, int]] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)

    def is_tool_call(self) -> bool:
        """Return True when the response carries at least one tool call."""
        # None and an empty list are both falsy, so one truthiness test covers both.
        return bool(self.tool_calls)
class LlmStreamChunk(BaseModel):
    """Streaming chunk from LLM.

    Yielded by ``LlmService.stream_request``. Every field has a default, so a
    chunk may populate any subset of them.
    """

    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
    # NOTE(review): presumably only set on the final chunk of a stream — confirm.
    finish_reason: Optional[str] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
================================================
FILE: src/vanna/core/middleware/__init__.py
================================================
"""
Middleware system for LLM request/response interception.
This module provides middleware interfaces for intercepting and transforming
LLM requests and responses.
"""
from .base import LlmMiddleware
__all__ = ["LlmMiddleware"]
================================================
FILE: src/vanna/core/middleware/base.py
================================================
"""
Base LLM middleware interface.
Middleware allows you to intercept and transform LLM requests and responses
for caching, monitoring, content filtering, and more.
"""
from abc import ABC
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..llm import LlmRequest, LlmResponse
class LlmMiddleware(ABC):
    """Base class for hooks around every LLM request/response round trip.

    Override either hook to, for example:
    - Cache LLM responses
    - Log requests and responses
    - Filter or rewrite content
    - Track costs and usage
    - Implement fallback strategies

    Both hooks default to passing their input straight through.

    Example:
        class CachingMiddleware(LlmMiddleware):
            def __init__(self):
                self.cache = {}

            async def before_llm_request(self, request: LlmRequest) -> LlmRequest:
                # Could check the cache here
                return request

            async def after_llm_response(
                self, request: LlmRequest, response: LlmResponse
            ) -> LlmResponse:
                # _compute_key is left to the implementer
                self.cache[self._compute_key(request)] = response
                return response

        agent = AgentRunner(
            llm_service=...,
            llm_middlewares=[CachingMiddleware(), LoggingMiddleware()]
        )
    """

    async def before_llm_request(self, request: "LlmRequest") -> "LlmRequest":
        """Hook run just before a request is sent to the LLM.

        Args:
            request: The request about to be sent

        Returns:
            The request to send: ``request`` itself or a modified replacement
        """
        return request

    async def after_llm_response(
        self, request: "LlmRequest", response: "LlmResponse"
    ) -> "LlmResponse":
        """Hook run just after the LLM responds.

        Args:
            request: The request that produced the response
            response: The response received from the LLM

        Returns:
            The response to hand back: ``response`` itself or a replacement
        """
        return response
================================================
FILE: src/vanna/core/observability/__init__.py
================================================
"""
Observability system for telemetry and monitoring.
This module provides interfaces for collecting metrics, traces, and
monitoring agent behavior.
"""
from .base import ObservabilityProvider
from .models import Span, Metric
__all__ = ["ObservabilityProvider", "Span", "Metric"]
================================================
FILE: src/vanna/core/observability/base.py
================================================
"""
Base observability provider interface.
Observability providers allow you to collect telemetry data about
agent execution for monitoring and debugging.
"""
from abc import ABC
from typing import Any, Dict, Optional
from .models import Span, Metric
class ObservabilityProvider(ABC):
    """Provider for collecting telemetry and observability data.

    Subclass this to create custom observability integrations that can:
    - Emit metrics to monitoring systems
    - Create distributed traces
    - Log performance data
    - Track costs and usage
    - Monitor error rates

    Every method has a working default: metrics are dropped, and spans are
    plain in-memory ``Span`` objects. Subclasses override only what they
    need. When overriding, keep the base-class signatures. In particular,
    ``record_metric`` takes ``unit`` before ``tags``.

    Example:
        class PrometheusProvider(ObservabilityProvider):
            def __init__(self, registry):
                self.registry = registry
                # presumably prometheus_client.Counter
                self.request_counter = Counter(
                    'agent_requests_total',
                    'Total agent requests',
                    registry=registry
                )

            async def record_metric(
                self,
                name: str,
                value: float,
                unit: str = "",
                tags: Optional[Dict[str, str]] = None,
            ) -> None:
                if name == "agent.request":
                    self.request_counter.inc()

            async def create_span(
                self, name: str, attributes: Optional[Dict[str, Any]] = None
            ) -> Span:
                return Span(name=name, attributes=attributes or {})

        agent = AgentRunner(
            llm_service=...,
            observability_provider=PrometheusProvider(registry)
        )
    """

    async def record_metric(
        self,
        name: str,
        value: float,
        unit: str = "",
        tags: Optional[Dict[str, str]] = None,
    ) -> None:
        """Record a metric measurement (the default implementation discards it).

        Args:
            name: Metric name (e.g., "agent.request.duration")
            value: Metric value
            unit: Unit of measurement (e.g., "ms", "tokens")
            tags: Additional tags/labels for the metric
        """
        pass

    async def create_span(
        self, name: str, attributes: Optional[Dict[str, Any]] = None
    ) -> Span:
        """Create a new span for tracing.

        Args:
            name: Span name/operation
            attributes: Initial span attributes (an empty dict when None)

        Returns:
            Span object to track the operation

        Note:
            Pass the returned span to :meth:`end_span` when the operation
            completes. The default ``end_span`` calls ``span.end()``.
        """
        return Span(name=name, attributes=attributes or {})

    async def end_span(self, span: Span) -> None:
        """End a span and record it.

        The default implementation only marks the span as ended. Subclasses
        can export it.

        Args:
            span: The span to end
        """
        span.end()
================================================
FILE: src/vanna/core/observability/models.py
================================================
"""
Observability models for spans and metrics.
"""
import time
from typing import Any, Dict, Optional
from uuid import uuid4
from pydantic import BaseModel, Field
class Span(BaseModel):
    """A timed unit of work used for distributed tracing.

    Times are ``time.time()`` values (seconds since the epoch). A span stays
    open until :meth:`end` records ``end_time``.
    """

    id: str = Field(default_factory=lambda: str(uuid4()), description="Span ID")
    name: str = Field(description="Span name/operation")
    start_time: float = Field(default_factory=time.time, description="Start timestamp")
    end_time: Optional[float] = Field(default=None, description="End timestamp")
    attributes: Dict[str, Any] = Field(
        default_factory=dict, description="Span attributes"
    )
    parent_id: Optional[str] = Field(default=None, description="Parent span ID")

    def end(self) -> None:
        """Close the span. Repeated calls keep the first end time."""
        if self.end_time is not None:
            return
        self.end_time = time.time()

    def duration_ms(self) -> Optional[float]:
        """Return the milliseconds between start and end, or None while the span is open."""
        finished_at = self.end_time
        return None if finished_at is None else (finished_at - self.start_time) * 1000

    def set_attribute(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``, replacing any previous value."""
        self.attributes[key] = value
class Metric(BaseModel):
    """Represents a metric measurement.

    Holds the same name/value/unit/tags as ``ObservabilityProvider.record_metric``,
    plus the time the measurement was taken.
    """

    name: str = Field(description="Metric name")
    value: float = Field(description="Metric value")
    unit: str = Field(default="", description="Unit of measurement")
    tags: Dict[str, str] = Field(default_factory=dict, description="Metric tags")
    # Seconds since the epoch, from time.time() at construction.
    timestamp: float = Field(default_factory=time.time, description="Measurement time")
================================================
FILE: src/vanna/core/recovery/__init__.py
================================================
"""
Error recovery system for handling failures gracefully.
This module provides interfaces for custom error handling, retry logic,
and fallback strategies.
"""
from .base import ErrorRecoveryStrategy
from .models import RecoveryAction, RecoveryActionType
__all__ = ["ErrorRecoveryStrategy", "RecoveryAction", "RecoveryActionType"]
================================================
FILE: src/vanna/core/recovery/base.py
================================================
"""
Base error recovery strategy interface.
Recovery strategies allow you to customize how the agent handles errors
during tool execution and LLM communication.
"""
from abc import ABC
from typing import TYPE_CHECKING
from .models import RecoveryAction, RecoveryActionType
if TYPE_CHECKING:
from ..tool.models import ToolContext
from ..llm import LlmRequest
class ErrorRecoveryStrategy(ABC):
    """Policy object that decides how the agent reacts to failures.

    Override the handlers to:
    - Retry failed operations (optionally with backoff)
    - Fall back to an alternative value or approach
    - Report errors to external systems
    - Degrade functionality gracefully

    Both handlers default to failing immediately.

    Example:
        class ExponentialBackoffStrategy(ErrorRecoveryStrategy):
            async def handle_tool_error(
                self, error: Exception, context: ToolContext, attempt: int = 1
            ) -> RecoveryAction:
                if attempt < 3:
                    delay = (2 ** attempt) * 1000  # 2000 ms, then 4000 ms
                    return RecoveryAction(
                        action=RecoveryActionType.RETRY,
                        retry_delay_ms=delay,
                        message=f"Retrying after {delay}ms"
                    )
                return RecoveryAction(
                    action=RecoveryActionType.FAIL,
                    message="Max retries exceeded"
                )

        agent = AgentRunner(
            llm_service=...,
            error_recovery_strategy=ExponentialBackoffStrategy()
        )
    """

    async def handle_tool_error(
        self, error: Exception, context: "ToolContext", attempt: int = 1
    ) -> RecoveryAction:
        """Decide what to do after a tool raised ``error``.

        Args:
            error: The exception raised by the tool
            context: The tool execution context
            attempt: Which attempt just failed (1-indexed)

        Returns:
            The RecoveryAction to apply
        """
        # Base policy never retries or falls back; it reports the failure.
        failure_message = f"Tool error: {str(error)}"
        return RecoveryAction(action=RecoveryActionType.FAIL, message=failure_message)

    async def handle_llm_error(
        self, error: Exception, request: "LlmRequest", attempt: int = 1
    ) -> RecoveryAction:
        """Decide what to do after an LLM call raised ``error``.

        Args:
            error: The exception raised during LLM communication
            request: The request that failed
            attempt: Which attempt just failed (1-indexed)

        Returns:
            The RecoveryAction to apply
        """
        # Base policy never retries or falls back; it reports the failure.
        failure_message = f"LLM error: {str(error)}"
        return RecoveryAction(action=RecoveryActionType.FAIL, message=failure_message)
================================================
FILE: src/vanna/core/recovery/models.py
================================================
"""
Recovery action models for error handling.
"""
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel, Field
class RecoveryActionType(str, Enum):
    """Types of recovery actions.

    Members subclass ``str``, so each compares equal to its lowercase value
    (e.g. ``RecoveryActionType.RETRY == "retry"``).
    """

    # Try the operation again, optionally after RecoveryAction.retry_delay_ms.
    RETRY = "retry"
    # Give up and report the error (ErrorRecoveryStrategy's default choice).
    FAIL = "fail"
    # Presumably substitute RecoveryAction.fallback_value for the failed result — confirm.
    FALLBACK = "fallback"
    # NOTE(review): presumably continue without the failed step — confirm with the consumer.
    SKIP = "skip"
class RecoveryAction(BaseModel):
    """Action to take when recovering from an error.

    Returned by ``ErrorRecoveryStrategy.handle_tool_error`` and
    ``handle_llm_error``. Only ``action`` is required. The optional fields
    relate to specific actions: ``retry_delay_ms`` to RETRY and
    ``fallback_value`` to FALLBACK. NOTE(review): presumably they are ignored
    for other action types — confirm.
    """

    action: RecoveryActionType = Field(description="Type of recovery action")
    retry_delay_ms: Optional[int] = Field(
        default=None, description="Delay before retry in milliseconds"
    )
    fallback_value: Optional[Any] = Field(
        default=None, description="Fallback value to use"
    )
    # Human-readable explanation (the default strategy puts the error text here).
    message: Optional[str] = Field(
        default=None, description="Message to include with action"
    )
================================================
FILE: src/vanna/core/registry.py
================================================
"""
Tool registry for the Vanna Agents framework.
This module provides the ToolRegistry class for managing and executing tools.
"""
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union
from .tool import Tool, ToolCall, ToolContext, ToolRejection, ToolResult, ToolSchema
from .user import User
if TYPE_CHECKING:
from .audit import AuditLogger
from .agent.config import AuditConfig
# Argument-model type parameter shared by the wrapper and registry signatures.
T = TypeVar("T")


class _LocalToolWrapper(Tool[T]):
    """Delegating proxy that gives a tool a registry-assigned set of access groups.

    The wrapper forwards ``name``, ``description``, ``get_args_schema`` and
    ``execute`` to the wrapped tool. Only ``access_groups`` comes from the
    wrapper itself, so the tool instance is not modified.
    """

    def __init__(self, wrapped_tool: Tool[T], access_groups: List[str]):
        self._wrapped_tool = wrapped_tool
        self._access_groups = access_groups

    @property
    def access_groups(self) -> List[str]:
        # The one attribute the wrapper overrides.
        return self._access_groups

    @property
    def name(self) -> str:
        inner = self._wrapped_tool
        return inner.name

    @property
    def description(self) -> str:
        inner = self._wrapped_tool
        return inner.description

    def get_args_schema(self) -> Type[T]:
        inner = self._wrapped_tool
        return inner.get_args_schema()

    async def execute(self, context: ToolContext, args: T) -> ToolResult:
        inner = self._wrapped_tool
        return await inner.execute(context, args)
class ToolRegistry:
    """Registry that stores tools by name and executes tool calls.

    :meth:`execute` runs each call through these steps before invoking the tool:
    - lookup
    - group-based access control
    - Pydantic argument validation
    - the :meth:`transform_args` hook

    Each stage is audited when an audit logger is configured and the
    matching ``AuditConfig`` flag is enabled.
    """

    def __init__(
        self,
        audit_logger: Optional["AuditLogger"] = None,
        audit_config: Optional["AuditConfig"] = None,
    ) -> None:
        """Create an empty registry.

        Args:
            audit_logger: Sink for audit events. Auditing is skipped when None.
            audit_config: Which events to audit. ``AuditConfig()`` is used
                when omitted.
        """
        self._tools: Dict[str, Tool[Any]] = {}
        self.audit_logger = audit_logger
        if audit_config is not None:
            self.audit_config = audit_config
        else:
            # Runtime import: at module level AuditConfig is only imported
            # under TYPE_CHECKING (presumably to avoid an import cycle).
            from .agent.config import AuditConfig

            self.audit_config = AuditConfig()

    def register_local_tool(
        self, tool: Tool[Any], access_groups: Optional[List[str]] = None
    ) -> None:
        """Register a local tool with optional access group restrictions.

        Args:
            tool: The tool to register
            access_groups: Groups allowed to use this tool. If None or empty,
                the tool is accessible to all users.

        Raises:
            ValueError: If a tool with the same name is already registered
        """
        if tool.name in self._tools:
            raise ValueError(f"Tool '{tool.name}' already registered")
        if access_groups:
            # Copy the list so later changes to the caller's list cannot alter access.
            self._tools[tool.name] = _LocalToolWrapper(tool, list(access_groups))
        else:
            # No access restrictions, register as-is
            self._tools[tool.name] = tool

    async def get_tool(self, name: str) -> Optional[Tool[Any]]:
        """Get a tool by name, or None if nothing is registered under it."""
        return self._tools.get(name)

    async def list_tools(self) -> List[str]:
        """List all registered tool names, in registration order."""
        return list(self._tools.keys())

    async def get_schemas(self, user: Optional[User] = None) -> List[ToolSchema]:
        """Get schemas for all tools accessible to ``user`` (all tools when None)."""
        schemas = []
        for tool in self._tools.values():
            if user is None or await self._validate_tool_permissions(tool, user):
                schemas.append(tool.get_schema())
        return schemas

    async def _validate_tool_permissions(self, tool: Tool[Any], user: User) -> bool:
        """Validate if user has access to tool based on group membership.

        Access is granted when the user's group memberships intersect the
        tool's access groups. A tool with no access groups is accessible to
        all users.
        """
        tool_access_groups = tool.access_groups
        if not tool_access_groups:
            return True
        user_groups = set(user.group_memberships)
        tool_groups = set(tool_access_groups)
        # Grant access if any group in user.group_memberships exists in tool.access_groups
        return bool(user_groups & tool_groups)

    async def transform_args(
        self,
        tool: Tool[T],
        args: T,
        user: User,
        context: ToolContext,
    ) -> Union[T, ToolRejection]:
        """Transform and validate tool arguments based on user context.

        This method allows per-user transformation of tool arguments, such as:
        - Applying row-level security (RLS) to SQL queries
        - Filtering available options based on user permissions
        - Validating required arguments are present
        - Redacting sensitive fields

        The default implementation performs no transformation (NoOp).
        Subclasses can override this method to implement custom transformation logic.

        Args:
            tool: The tool being executed
            args: Already Pydantic-validated arguments
            user: The user executing the tool
            context: Full execution context

        Returns:
            Either:
            - Transformed arguments (may be unchanged if no transformation needed)
            - ToolRejection with explanation of why args were rejected
        """
        return args  # Default: no transformation (NoOp)

    def _audit_logger_if(self, flag: str) -> Optional["AuditLogger"]:
        """Return the audit logger if auditing is configured and ``audit_config.<flag>`` is truthy, else None."""
        if self.audit_logger and self.audit_config and getattr(self.audit_config, flag):
            return self.audit_logger
        return None

    @staticmethod
    def _failure(message: str) -> ToolResult:
        """Build an unsuccessful ToolResult that reports ``message`` to the LLM and as the error."""
        return ToolResult(
            success=False,
            result_for_llm=message,
            ui_component=None,
            error=message,
        )

    async def execute(
        self,
        tool_call: ToolCall,
        context: ToolContext,
    ) -> ToolResult:
        """Execute a tool call after access checks, validation and transformation.

        Any of these steps can end the call early with an unsuccessful ToolResult:
        - tool lookup
        - group access check (a denial is audited when
          ``log_tool_access_checks`` is enabled)
        - Pydantic argument validation
        - a ``transform_args`` rejection

        If all checks pass:
        - the access grant and the invocation are audited (per config)
        - the tool runs
        - the elapsed time, measured with ``time.perf_counter``, is stored in
          ``result.metadata["execution_time_ms"]``
        - the result is audited when ``log_tool_results`` is enabled

        Exceptions raised by the tool become an unsuccessful ToolResult.
        Exceptions raised by the audit logger propagate to the caller.
        """
        tool = await self.get_tool(tool_call.name)
        if not tool:
            return self._failure(f"Tool '{tool_call.name}' not found")

        # Validate group access
        if not await self._validate_tool_permissions(tool, context.user):
            msg = f"Insufficient group access for tool '{tool_call.name}'"
            logger = self._audit_logger_if("log_tool_access_checks")
            if logger is not None:
                await logger.log_tool_access_check(
                    user=context.user,
                    tool_name=tool_call.name,
                    access_granted=False,
                    required_groups=tool.access_groups,
                    context=context,
                    reason=msg,
                )
            return self._failure(msg)

        # Validate and parse arguments
        try:
            args_model = tool.get_args_schema()
            validated_args = args_model.model_validate(tool_call.arguments)
        except Exception as e:
            return self._failure(f"Invalid arguments: {str(e)}")

        # Transform/validate arguments based on user context
        transform_result = await self.transform_args(
            tool=tool,
            args=validated_args,
            user=context.user,
            context=context,
        )
        if isinstance(transform_result, ToolRejection):
            return self._failure(transform_result.reason)
        final_args = transform_result

        # Audit successful access check
        logger = self._audit_logger_if("log_tool_access_checks")
        if logger is not None:
            await logger.log_tool_access_check(
                user=context.user,
                tool_name=tool_call.name,
                access_granted=True,
                required_groups=tool.access_groups,
                context=context,
            )

        # Audit tool invocation
        logger = self._audit_logger_if("log_tool_invocations")
        if logger is not None:
            # Get UI features if available from context
            ui_features = context.metadata.get("ui_features_available", [])
            await logger.log_tool_invocation(
                user=context.user,
                tool_call=tool_call,
                ui_features=ui_features,
                context=context,
                sanitize_parameters=self.audit_config.sanitize_tool_parameters,
            )

        # Only the tool call itself is guarded. An audit-logger failure below
        # must not turn a tool that already ran (and had side effects) into
        # a reported "Execution failed".
        try:
            start_time = time.perf_counter()
            result = await tool.execute(context, final_args)
            execution_time_ms = (time.perf_counter() - start_time) * 1000
            result.metadata["execution_time_ms"] = execution_time_ms
        except Exception as e:
            return self._failure(f"Execution failed: {str(e)}")

        # Audit tool result
        logger = self._audit_logger_if("log_tool_results")
        if logger is not None:
            await logger.log_tool_result(
                user=context.user,
                tool_call=tool_call,
                result=result,
                context=context,
            )
        return result
================================================
FILE: src/vanna/core/rich_component.py
================================================
"""
Base classes for rich UI components.
This module provides the base RichComponent class and supporting enums
for the component system.
"""
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, TypeVar

from pydantic import BaseModel, Field
# Type variable bound to RichComponent, so that ``update``/``hide``/``show``
# are typed as returning the caller's own subclass.
T = TypeVar("T", bound="RichComponent")
class ComponentType(str, Enum):
    """Types of rich UI components.

    ``RichComponent.serialize_for_frontend`` emits ``type.value`` (the
    lowercase string) in its payload. These values are therefore what the
    frontend renderer receives.
    """

    # Basic components
    TEXT = "text"
    CARD = "card"
    CONTAINER = "container"
    # Primitive UI components (domain-agnostic)
    STATUS_CARD = "status_card"
    PROGRESS_DISPLAY = "progress_display"
    LOG_VIEWER = "log_viewer"
    BADGE = "badge"
    ICON_TEXT = "icon_text"
    # Interactive components
    TASK_LIST = "task_list"
    PROGRESS_BAR = "progress_bar"
    BUTTON = "button"
    BUTTON_GROUP = "button_group"
    # Data components
    TABLE = "table"
    # serialize_for_frontend remaps a DATAFRAME component's ``rows`` field to ``data``.
    DATAFRAME = "dataframe"
    CHART = "chart"
    CODE_BLOCK = "code_block"
    # Status components
    STATUS_INDICATOR = "status_indicator"
    NOTIFICATION = "notification"
    ALERT = "alert"
    # Artifact components
    ARTIFACT = "artifact"
    # UI state components
    STATUS_BAR_UPDATE = "status_bar_update"
    TASK_TRACKER_UPDATE = "task_tracker_update"
    CHAT_INPUT_UPDATE = "chat_input_update"
    # Legacy (deprecated - use primitives instead)
    TOOL_EXECUTION = "tool_execution"
class ComponentLifecycle(str, Enum):
    """Component lifecycle operations.

    ``RichComponent`` defaults to CREATE, and ``RichComponent.update()``
    marks copies as UPDATE. NOTE(review): REPLACE and REMOVE are presumably
    set by callers, with semantics defined by the frontend — confirm.
    """

    CREATE = "create"
    UPDATE = "update"
    REPLACE = "replace"
    REMOVE = "remove"
class RichComponent(BaseModel):
    """Base class for all rich UI components.

    Subclasses declare their component-specific fields as ordinary model
    fields, and :meth:`serialize_for_frontend` moves them under ``data``.
    ``update``, ``hide`` and ``show`` return new copies (lifecycle UPDATE)
    rather than mutating in place.
    """

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    type: ComponentType
    lifecycle: ComponentLifecycle = ComponentLifecycle.CREATE
    data: Dict[str, Any] = Field(default_factory=dict)
    children: List[str] = Field(default_factory=list)  # Child component IDs
    # Naive ISO-8601 UTC string with no "+00:00" suffix, the same format
    # utcnow().isoformat() produced. It is built from an aware datetime
    # because datetime.utcnow() is deprecated since Python 3.12.
    timestamp: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    )
    visible: bool = True
    interactive: bool = False

    def update(self: T, **kwargs: Any) -> T:
        """Create an updated copy of this component.

        ``kwargs`` override the current field values, and the copy keeps the
        same ``id`` unless overridden. ``lifecycle`` is always set to UPDATE
        and ``timestamp`` is refreshed, even if either is passed in ``kwargs``.
        """
        updated_data = self.model_dump()
        updated_data.update(kwargs)
        updated_data["lifecycle"] = ComponentLifecycle.UPDATE
        updated_data["timestamp"] = (
            datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        )
        return self.__class__(**updated_data)

    def hide(self: T) -> T:
        """Create a hidden copy of this component."""
        return self.update(visible=False)

    def show(self: T) -> T:
        """Create a visible copy of this component."""
        return self.update(visible=True)

    def serialize_for_frontend(self) -> Dict[str, Any]:
        """Normalize component payload for the frontend renderer.

        The frontend expects component-specific fields to live under the
        ``data`` key, while shared metadata (``id``, ``type``, layout hints,
        etc.) stays at the top level. Pydantic's ``model_dump`` keeps
        component attributes at the top level, so they are remapped here
        before being sent across the wire.

        Precedence inside ``data``:
        - an explicit ``data`` dict is copied first
        - subclass fields then override keys with the same name
        - for DATAFRAME components, the ``rows`` field is emitted as ``data``
          within it

        Enums are emitted as their string values.
        """
        # Base fields that should remain at the top level of the payload.
        shared_fields = {
            "id",
            "type",
            "lifecycle",
            "children",
            "timestamp",
            "visible",
            "interactive",
        }
        raw = self.model_dump()
        payload: Dict[str, Any] = {}
        # Preserve any existing data payload so implementations can opt in to
        # advanced usage without losing information; ignore non-dict values.
        raw_data = raw.get("data")
        component_data: Dict[str, Any] = (
            raw_data.copy() if isinstance(raw_data, dict) else {}
        )
        is_dataframe = self.type.value == "dataframe"
        for key, value in raw.items():
            if key in shared_fields:
                payload[key] = value
            elif key == "data":
                # Already merged above
                continue
            elif key == "rows" and is_dataframe:
                # DataFrame rows are the table body, which the frontend reads from data["data"].
                component_data["data"] = value
            else:
                component_data[key] = value
        payload["data"] = component_data
        # Ensure enums are serialized as primitive values for the frontend.
        payload["type"] = self.type.value
        payload["lifecycle"] = self.lifecycle.value
        return payload
================================================
FILE: src/vanna/core/simple_component.py
================================================
"""Base classes for simple UI components."""
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from enum import Enum
# Kinds of simple UI component. Mixing in ``str`` makes each member compare
# equal to its string value (e.g. SimpleComponentType.TEXT == "text").
# (Comments rather than a docstring: pydantic copies enum docstrings into the
# JSON schema of models that use this enum.)
class SimpleComponentType(str, Enum):
    TEXT = "text"
    IMAGE = "image"
    LINK = "link"
class SimpleComponent(BaseModel):
    """A simple UI component with basic attributes."""

    # Required: which kind of simple component this is.
    type: SimpleComponentType = Field(default=..., description="Type of the component.")
    # Optional free-form semantic label.
    semantic_type: Optional[str] = Field(
        description="Semantic type for better categorization.", default=None
    )
    # Optional free-form metadata.
    metadata: Optional[Dict[str, Any]] = Field(
        description="Additional metadata for the component.", default=None
    )

    def serialize_for_frontend(self) -> Dict[str, Any]:
        """Return this component as a flat dict (plain ``model_dump``, no remapping)."""
        serialized = self.model_dump()
        return serialized
================================================
FILE: src/vanna/core/storage/__init__.py
================================================
"""
Storage domain.
This module provides the core abstractions for conversation storage in the Vanna Agents framework.
"""
from .base import ConversationStore
from .models import Conversation, Message
# Public names re-exported by the storage package.
__all__ = [
    "ConversationStore",
    "Conversation",
    "Message",
]
================================================
FILE: src/vanna/core/storage/base.py
================================================
"""
Storage domain interface.
This module contains the abstract base class for conversation storage.
"""
from abc import ABC, abstractmethod
from typing import List, Optional
from .models import Conversation
from ..user.models import User
class ConversationStore(ABC):
    """Abstract persistence interface for conversations.

    Every method is an ``abstractmethod`` coroutine, so a subclass that does
    not implement all of them cannot be instantiated.
    """

    @abstractmethod
    async def create_conversation(
        self, conversation_id: str, user: User, initial_message: str
    ) -> Conversation:
        """Create and return a new conversation stored under ``conversation_id``.

        Args:
            conversation_id: Identifier for the new conversation.
            user: Owner of the conversation.
            initial_message: Text of the opening message (how it is stored is
                up to the implementation).
        """

    @abstractmethod
    async def get_conversation(
        self, conversation_id: str, user: User
    ) -> Optional[Conversation]:
        """Return the conversation with ``conversation_id`` scoped to ``user``, or None."""

    @abstractmethod
    async def update_conversation(self, conversation: Conversation) -> None:
        """Persist changes (such as new messages) made to ``conversation``."""

    @abstractmethod
    async def delete_conversation(self, conversation_id: str, user: User) -> bool:
        """Delete a conversation.

        Returns:
            A bool; presumably whether a conversation was deleted — confirm
            against the implementations.
        """

    @abstractmethod
    async def list_conversations(
        self, user: User, limit: int = 50, offset: int = 0
    ) -> List[Conversation]:
        """List ``user``'s conversations, paginated by ``limit`` and ``offset``."""
================================================
FILE: src/vanna/core/storage/models.py
================================================
"""
Storage domain models.
This module contains data models for conversation storage.
"""
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from ..tool.models import ToolCall
from ..user.models import User
class Message(BaseModel):
    """Single message in a conversation."""

    # Author role; the field description lists user/assistant/system/tool.
    role: str = Field(description="Message role (user/assistant/system/tool)")
    # Message text.
    content: str = Field(description="Message content")
    # Creation time. datetime.utcnow yields a *naive* datetime holding UTC time
    # (deprecated since Python 3.12); beware comparing with aware datetimes.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    # Free-form extra data; default_factory gives each instance its own dict.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Tool calls attached to this message, if any (presumably set on
    # assistant turns that request tools — confirm with callers).
    tool_calls: Optional[List[ToolCall]] = Field(default=None)
    # For tool responses: the id of the tool call being answered.
    tool_call_id: Optional[str] = Field(
        default=None, description="ID if this is a tool response"
    )
class Conversation(BaseModel):
    """Conversation containing multiple messages."""

    # Conversation identifier.
    id: str = Field(description="Unique conversation identifier")
    # Owner of the conversation.
    user: User = Field(description="User this conversation belongs to")
    # Messages in insertion order (add_message appends).
    messages: List[Message] = Field(
        default_factory=list, description="Messages in conversation"
    )
    # Creation / last-modification times as naive UTC datetimes
    # (datetime.utcnow, deprecated since Python 3.12).
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    # Free-form extra data; default_factory gives each instance its own dict.
    metadata: Dict[str, Any] = Field(
        default_factory=dict, description="Additional conversation metadata"
    )

    def add_message(self, message: Message) -> None:
        """Append ``message`` and set ``updated_at`` to the current UTC time.

        Mutates this object in place only; persisting the change is left to
        the caller (presumably via a ConversationStore's update_conversation).
        """
        self.messages.append(message)
        self.updated_at = datetime.utcnow()
================================================
FILE: src/vanna/core/system_prompt/__init__.py
================================================
"""
System prompt domain.
This module provides the core abstractions for building system prompts in the Vanna Agents framework.
"""
from .base import SystemPromptBuilder
from .default import DefaultSystemPromptBuilder
# Public names re-exported by the system_prompt package.
__all__ = [
    "SystemPromptBuilder",
    "DefaultSystemPromptBuilder",
]
================================================
FILE: src/vanna/core/system_prompt/base.py
================================================
"""
System prompt builder interface.
This module contains the abstract base class for system prompt builders.
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
from ..tool.models import ToolSchema
from ..user.models import User
class SystemPromptBuilder(ABC):
    """Abstract base class for system prompt builders.

    Concrete builders implement ``build_system_prompt``; the class cannot be
    instantiated until they do.
    """

    @abstractmethod
    async def build_system_prompt(
        self, user: "User", tools: List["ToolSchema"]
    ) -> Optional[str]:
        """Produce the system prompt for a request.

        Args:
            user: The user making the request.
            tools: Schemas of the tools available to that user.

        Returns:
            The system prompt text, or None when no system prompt should be used.
        """
================================================
FILE: src/vanna/core/system_prompt/default.py
================================================
"""
Default system prompt builder implementation with memory workflow support.
This module provides a default implementation of the SystemPromptBuilder interface
that automatically includes memory workflow instructions when memory tools are available.
"""
from typing import TYPE_CHECKING, List, Optional
from datetime import datetime
from .base import SystemPromptBuilder
if TYPE_CHECKING:
from ..tool.models import ToolSchema
from ..user.models import User
class DefaultSystemPromptBuilder(SystemPromptBuilder):
    """Default system prompt builder with automatic memory workflow integration.

    Generates a prompt containing the assistant persona, today's date and
    general response guidelines, then appends memory-workflow instructions for
    whichever memory tools (search_saved_correct_tool_uses,
    save_question_tool_args, save_text_memory) are in the tool list.

    Each instruction that names a memory tool is only emitted when that tool
    is available, so the LLM is never told to call a tool it does not have.
    """

    def __init__(self, base_prompt: Optional[str] = None):
        """Initialize with an optional base prompt.

        Args:
            base_prompt: Optional fixed system prompt. When provided it is
                returned verbatim and no prompt is generated.
        """
        self.base_prompt = base_prompt

    async def build_system_prompt(
        self, user: "User", tools: List["ToolSchema"]
    ) -> Optional[str]:
        """
        Build a system prompt with memory workflow instructions.

        Args:
            user: The user making the request (not used by this builder).
            tools: Tools available to the user; their names are listed in the
                prompt and decide which memory instructions are included.

        Returns:
            ``base_prompt`` verbatim if one was configured, otherwise the
            generated prompt.
        """
        if self.base_prompt is not None:
            return self.base_prompt

        # Check which memory tools are available
        tool_names = [tool.name for tool in tools]
        has_search = "search_saved_correct_tool_uses" in tool_names
        has_save = "save_question_tool_args" in tool_names
        has_text_memory = "save_text_memory" in tool_names
        has_tool_memory = has_search or has_save

        today_date = datetime.now().strftime("%Y-%m-%d")

        prompt_parts = [
            f"You are Vanna, an AI data analyst assistant created to help users with data analysis tasks. Today's date is {today_date}.",
            "",
            "Response Guidelines:",
            "- Any summary of what you did or observations should be the final step.",
            "- Use the available tools to help the user accomplish their goals.",
            "- When you execute a query, that raw result is shown to the user outside of your response so YOU DO NOT need to include it in your response. Focus on summarizing and interpreting the results.",
        ]

        if tools:
            prompt_parts.append(
                f"\nYou have access to the following tools: {', '.join(tool_names)}"
            )

        if has_tool_memory or has_text_memory:
            prompt_parts.append("\n" + "=" * 60)
            prompt_parts.append("MEMORY SYSTEM:")
            prompt_parts.append("=" * 60)
            if has_tool_memory:
                prompt_parts.extend(self._tool_memory_section(has_search, has_save))
            if has_text_memory:
                # Text memory is section 2 only when a tool-usage section precedes it.
                prompt_parts.extend(
                    self._text_memory_section(2 if has_tool_memory else 1)
                )
            # Preserve existing formatting: once a memory section is present,
            # blank separator lines are dropped from the whole prompt.
            prompt_parts = [part for part in prompt_parts if part != ""]

        return "\n".join(prompt_parts)

    @staticmethod
    def _tool_memory_section(has_search: bool, has_save: bool) -> List[str]:
        """Lines for the tool-usage memory workflow, mentioning only available tools."""
        lines = ["\n1. TOOL USAGE MEMORY (Structured Workflow):", "-" * 50]
        if has_search:
            lines.extend(
                [
                    "",
                    "• BEFORE executing any tool (run_sql, visualize_data, or calculator), you MUST first call search_saved_correct_tool_uses with the user's question to check if there are existing successful patterns for similar questions.",
                    "",
                    "• Review the search results (if any) to inform your approach before proceeding with other tool calls.",
                ]
            )
        if has_save:
            lines.extend(
                [
                    "",
                    "• AFTER successfully executing a tool that produces correct and useful results, you MUST call save_question_tool_args to save the successful pattern for future use.",
                ]
            )

        lines.extend(["", "Example workflow:", " • User asks a question"])
        if has_search:
            lines.append(
                ' • First: Call search_saved_correct_tool_uses(question="user\'s question")'
            )
            lines.append(
                " • Then: Execute the appropriate tool(s) based on search results and the question"
            )
        else:
            # No search tool: don't refer to search results that can't exist.
            lines.append(" • Then: Execute the appropriate tool(s) based on the question")
        if has_save:
            lines.append(
                f' • Finally: If successful, call save_question_tool_args(question="user\'s question", tool_name="tool_used", args={{the args you used}})'
            )

        # One reminder sentence per available tool (joined exactly as before
        # when both are present).
        reminders = []
        if has_search:
            reminders.append(
                "Do NOT skip the search step, even if you think you know how to answer."
            )
        if has_save:
            reminders.append("Do NOT forget to save successful executions.")
        lines.extend(["", " ".join(reminders)])

        if has_search:
            lines.extend(
                [
                    "",
                    "The only exceptions to searching first are:",
                    ' • When the user is explicitly asking about the tools themselves (like "list the tools")',
                    " • When the user is testing or asking you to demonstrate the save/search functionality itself",
                ]
            )
        return lines

    @staticmethod
    def _text_memory_section(section_number: int) -> List[str]:
        """Lines describing save_text_memory, numbered as ``section_number``."""
        return [
            "",
            f"{section_number}. TEXT MEMORY (Domain Knowledge & Context):",
            "-" * 50,
            "",
            "• save_text_memory: Save important context about the database, schema, or domain",
            "",
            "Use text memory to save:",
            " • Database schema details (column meanings, data types, relationships)",
            " • Company-specific terminology and definitions",
            " • Query patterns or best practices for this database",
            " • Domain knowledge about the business or data",
            " • User preferences for queries or visualizations",
            "",
            "DO NOT save:",
            " • Information already captured in tool usage memory",
            " • One-time query results or temporary observations",
            "",
            "Examples:",
            ' • save_text_memory(content="The status column uses 1 for active, 0 for inactive")',
            ' • save_text_memory(content="MRR means Monthly Recurring Revenue in our schema")',
            " • save_text_memory(content=\"Always exclude test accounts where email contains 'test'\")",
        ]
================================================
FILE: src/vanna/core/tool/__init__.py
================================================
"""
Tool domain.
This module provides the core abstractions for tools in the Vanna Agents framework.
"""
from .base import T, Tool
from .models import ToolCall, ToolContext, ToolRejection, ToolResult, ToolSchema
# Public names re-exported by the tool package.
__all__ = [
    "Tool",
    "T",
    "ToolCall",
    "ToolContext",
    "ToolRejection",
    "ToolResult",
    "ToolSchema",
]
================================================
FILE: src/vanna/core/tool/base.py
================================================
"""
Tool domain interface.
This module contains the abstract base class for tools.
"""
from abc import ABC, abstractmethod
from typing import Generic, List, Type, TypeVar
from .models import ToolContext, ToolResult, ToolSchema
# Type variable for tool argument types: Tool[T] is parameterized by the
# argument model class returned from get_args_schema().
T = TypeVar("T")
class Tool(ABC, Generic[T]):
    """Abstract base class for tools.

    Subclasses provide a name, a description, an argument model and an async
    ``execute``; ``get_schema`` combines them into a ``ToolSchema`` for the LLM.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique name for this tool."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Description of what this tool does."""

    @property
    def access_groups(self) -> List[str]:
        """Groups permitted to access this tool (a new empty list by default)."""
        return list()

    @abstractmethod
    def get_args_schema(self) -> Type[T]:
        """Return the Pydantic model class describing this tool's arguments."""

    @abstractmethod
    async def execute(self, context: ToolContext, args: T) -> ToolResult:
        """Run the tool with already-validated arguments.

        Args:
            context: Execution context (user, conversation_id, request_id, ...).
            args: Validated tool arguments.

        Returns:
            ToolResult with success status, result for the LLM, and an optional UI component.
        """

    def get_schema(self) -> ToolSchema:
        """Build the ``ToolSchema`` that describes this tool to the LLM.

        Parameters come from ``model_json_schema()`` on the class returned by
        ``get_args_schema``; if that class lacks the method, ``{}`` is used.
        """
        from typing import Any

        args_model: Any = self.get_args_schema()
        if hasattr(args_model, "model_json_schema"):
            parameters = args_model.model_json_schema()
        else:
            parameters = {}
        return ToolSchema(
            name=self.name,
            description=self.description,
            parameters=parameters,
            access_groups=self.access_groups,
        )
================================================
FILE: src/vanna/core/tool/models.py
================================================
"""
Tool domain models.
This module contains data models for tool execution.
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field
# Import AgentMemory at runtime for Pydantic model resolution
from vanna.capabilities.agent_memory import AgentMemory
if TYPE_CHECKING:
from ..components import UiComponent
from ..user.models import User
from ..observability import ObservabilityProvider
class ToolCall(BaseModel):
    """Represents a tool call from the LLM."""

    # Identifier of this call (tool responses refer back to it).
    id: str = Field(description="Unique identifier for this tool call")
    # Name of the tool the LLM asked for.
    name: str = Field(description="Name of the tool to execute")
    # Arguments exactly as produced by the LLM, not yet validated.
    arguments: Dict[str, Any] = Field(description="Raw arguments from LLM")
class ToolContext(BaseModel):
    """Context passed to all tool executions."""

    # Pydantic v2 configuration. Replaces the deprecated class-based ``Config``
    # with the identical setting; arbitrary types are allowed because some
    # field types (presumably AgentMemory / ObservabilityProvider) are not
    # pydantic models — confirm before tightening.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    user: "User"  # Forward reference to avoid circular import
    conversation_id: str
    request_id: str = Field(description="Unique request identifier for tracing")
    agent_memory: AgentMemory = Field(
        description="Agent memory for tool usage learning"
    )
    # Free-form extra data; default_factory gives each instance its own dict.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    observability_provider: Optional["ObservabilityProvider"] = Field(
        default=None,
        description="Optional observability provider for metrics and spans",
    )
class ToolResult(BaseModel):
    """Result from tool execution.
    Changes:
    - `result_for_llm`: string that will be sent back to the LLM.
    - `ui_component`: optional UI payload for rendering in clients.
    """

    # (The class docstring above is kept verbatim because pydantic exposes it
    # as the JSON-schema description.)
    # Whether the tool ran successfully.
    success: bool = Field(description="Whether execution succeeded")
    # Text returned to the LLM as the tool's output.
    result_for_llm: str = Field(description="String content to send back to the LLM")
    # Optional component for clients to render. "UiComponent" is only imported
    # under TYPE_CHECKING, so the forward reference must be resolved
    # (e.g. model_rebuild) elsewhere before use — NOTE(review): confirm where.
    ui_component: Optional["UiComponent"] = Field(
        default=None, description="Optional UI component for rendering"
    )
    # Error message when success is False.
    error: Optional[str] = Field(default=None, description="Error message if failed")
    # Free-form extra data; default_factory gives each instance its own dict.
    metadata: Dict[str, Any] = Field(default_factory=dict)
class ToolSchema(BaseModel):
    """Schema describing a tool for LLM consumption."""

    # Tool name as presented to the LLM.
    name: str = Field(description="Tool name")
    # Human-readable description of the tool.
    description: str = Field(description="What this tool does")
    # JSON Schema for the arguments (Tool.get_schema fills this from the
    # argument model's model_json_schema()).
    parameters: Dict[str, Any] = Field(description="JSON Schema of parameters")
    # Access-control groups; an empty list by default.
    access_groups: List[str] = Field(
        default_factory=list, description="Groups permitted to access this tool"
    )
class ToolRejection(BaseModel):
    """Indicates tool execution should be rejected with a message.
    Used by transform_args to reject tool execution when arguments
    cannot be appropriately transformed for the user's context.
    """

    # Explanation of the rejection.
    reason: str = Field(
        description="Explanation of why the tool execution was rejected"
    )
================================================
FILE: src/vanna/core/user/__init__.py
================================================
"""
User domain.
This module provides the core abstractions for user management in the Vanna Agents framework.
"""
from .base import UserService
from .models import User
from .resolver import UserResolver
from .request_context import RequestContext
# Public names re-exported by the user package.
__all__ = [
    "UserService",
    "User",
    "UserResolver",
    "RequestContext",
]
================================================
FILE: src/vanna/core/user/base.py
================================================
"""
User domain interface.
This module contains the abstract base class for user services.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from .models import User
class UserService(ABC):
    """Service for user management and authentication.

    All three coroutines are abstract and must be implemented by subclasses.
    """

    @abstractmethod
    async def get_user(self, user_id: str) -> Optional[User]:
        """Return the user with ``user_id``, or None if there is none."""

    @abstractmethod
    async def authenticate(self, credentials: Dict[str, Any]) -> Optional[User]:
        """Check ``credentials``; return the authenticated User, or None on failure."""

    @abstractmethod
    async def has_permission(self, user: User, permission: str) -> bool:
        """Return True if ``user`` holds ``permission``."""
================================================
FILE: src/vanna/core/user/models.py
================================================
"""
User domain models.
This module contains data models for user management.
"""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field
class User(BaseModel):
    """User model for authentication and scoping."""

    # Stable user identifier (the only required field).
    id: str = Field(description="Unique user identifier")
    username: Optional[str] = Field(default=None, description="Username")
    email: Optional[str] = Field(default=None, description="User email")
    # Free-form extra data; default_factory gives each instance its own dict.
    metadata: Dict[str, Any] = Field(
        default_factory=dict, description="Additional user metadata"
    )
    # Group names; compared against tool access_groups elsewhere (presumably).
    group_memberships: List[str] = Field(
        default_factory=list, description="Groups the user belongs to"
    )
    # extra="allow": unknown keyword arguments are kept as attributes instead
    # of being rejected.
    model_config = ConfigDict(extra="allow")
================================================
FILE: src/vanna/core/user/request_context.py
================================================
"""
Request context for user resolution.
This module provides the RequestContext model for passing web request
information to UserResolver implementations.
"""
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
class RequestContext(BaseModel):
    """Context from a web request for user resolution.
    This structured object replaces raw dictionaries for passing request
    data to UserResolver implementations, making it easier to access
    cookies, headers, and other request metadata.
    Example:
        context = RequestContext(
            cookies={'vanna_email': 'alice@example.com'},
            headers={'Authorization': 'Bearer token'},
            remote_addr='127.0.0.1'
        )
        user = await resolver.resolve_user(context)
    """

    cookies: Dict[str, str] = Field(default_factory=dict, description="Request cookies")
    headers: Dict[str, str] = Field(default_factory=dict, description="Request headers")
    remote_addr: Optional[str] = Field(default=None, description="Remote IP address")
    query_params: Dict[str, str] = Field(
        default_factory=dict, description="Query parameters"
    )
    metadata: Dict[str, Any] = Field(
        default_factory=dict, description="Additional framework-specific metadata"
    )

    def get_cookie(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """Look up a cookie by its exact (case-sensitive) name.

        Args:
            name: Cookie name.
            default: Value returned when the cookie is absent.

        Returns:
            The cookie value, or ``default``.
        """
        if name in self.cookies:
            return self.cookies[name]
        return default

    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """Look up a header, ignoring case in the header name.

        Args:
            name: Header name, in any letter case.
            default: Value returned when no header matches.

        Returns:
            The value of the first header (in dict order) whose lower-cased
            name matches, or ``default``.
        """
        wanted = name.lower()
        return next(
            (value for key, value in self.headers.items() if key.lower() == wanted),
            default,
        )
================================================
FILE: src/vanna/core/user/resolver.py
================================================
"""
User resolver interface for web request authentication.
This module provides the abstract base class for resolving web requests
to authenticated User objects.
"""
from abc import ABC, abstractmethod
from .models import User
from .request_context import RequestContext
class UserResolver(ABC):
    """Turns an incoming web request into an authenticated User.

    Implementations decide how identity is extracted from the request
    context (cookies, headers, tokens, ...) and build the User object.

    Example:
        class JwtUserResolver(UserResolver):
            async def resolve_user(self, request_context: RequestContext) -> User:
                token = request_context.get_header('Authorization')
                # ... validate JWT and extract user info
                return User(id=user_id, username=username, email=email)
    """

    @abstractmethod
    async def resolve_user(self, request_context: RequestContext) -> User:
        """Resolve the user behind ``request_context``.

        Args:
            request_context: Structured request data (cookies, headers, etc.).

        Returns:
            The authenticated User.

        Raises:
            Implementations may raise to signal authentication failure.
        """
================================================
FILE: src/vanna/core/validation.py
================================================
"""
Development utilities for validating Pydantic models.
This module provides utilities that can be used during development
and testing to catch forward reference issues early.
"""
from typing import Any, Dict, List, Tuple, Type
from pydantic import BaseModel
import importlib
import inspect
def validate_pydantic_models_in_package(package_name: str) -> Dict[str, Any]:
    """
    Validate all Pydantic models in a package for completeness.

    This function can be used in tests or development scripts to catch
    forward reference issues before they cause runtime errors.

    For a package (a module with ``__path__``), each direct submodule found by
    ``pkgutil.iter_modules`` is imported and inspected; the package's own
    namespace is not. For a plain module, the module itself is inspected.
    Every ``BaseModel`` subclass visible in an inspected module's namespace is
    checked, including classes that module merely imports.

    Args:
        package_name: Name of the package to validate (e.g., 'vanna.core')

    Returns:
        Dictionary with validation results:
            total_models: number of (module, model) pairs inspected.
            incomplete_models: keys of models whose ``model_json_schema()`` raised.
            models: per-model details keyed by ``"<module>.<class>"``.
            import_errors: submodule name -> ImportError message for submodules
                that could not be imported (they are skipped, not fatal).
            summary: one-line human-readable summary. If importing or
                inspecting fails outright it starts with
                ``"Error validating package"``.
    """
    results: Dict[str, Any] = {
        "total_models": 0,
        "incomplete_models": [],
        "models": {},
        "import_errors": {},
        "summary": "",
    }
    try:
        package = importlib.import_module(package_name)

        submodules: List[Tuple[str, Any]] = []
        if hasattr(package, "__path__"):
            import pkgutil

            for _, name, _ in pkgutil.iter_modules(
                package.__path__, package_name + "."
            ):
                try:
                    submodule = importlib.import_module(name)
                except ImportError as exc:
                    # Keep validating the rest, but record the failure instead
                    # of silently dropping the submodule.
                    results["import_errors"][name] = str(exc)
                    continue
                submodules.append((name, submodule))
        else:
            submodules = [(package_name, package)]

        for module_name, module in submodules:
            for name, obj in inspect.getmembers(module):
                if not (
                    inspect.isclass(obj)
                    and issubclass(obj, BaseModel)
                    and obj is not BaseModel
                ):
                    continue
                model_key = f"{module_name}.{name}"
                results["total_models"] += 1

                # Record annotations that are still unresolved forward references.
                forward_refs: List[Tuple[str, str]] = []
                for field_name, field_info in obj.model_fields.items():
                    annotation = field_info.annotation
                    if annotation is not None and hasattr(
                        annotation, "__forward_arg__"
                    ):
                        forward_refs.append((field_name, annotation.__forward_arg__))

                # A model is complete if its JSON schema can be generated.
                try:
                    obj.model_json_schema()
                    is_complete = True
                    error = None
                except Exception as e:
                    is_complete = False
                    error = str(e)
                    results["incomplete_models"].append(model_key)

                results["models"][model_key] = {
                    "class": obj,
                    "forward_references": forward_refs,
                    "is_complete": is_complete,
                    "error": error,
                }

        incomplete_models = results["incomplete_models"]
        incomplete_count = len(incomplete_models)
        total_models = results["total_models"]
        if incomplete_count == 0:
            results["summary"] = (
                f"✓ All {total_models} Pydantic models are complete and valid!"
            )
        else:
            results["summary"] = (
                f"⚠ {incomplete_count} of {total_models} models are incomplete: "
                f"{', '.join(incomplete_models)}"
            )
        import_errors = results["import_errors"]
        if import_errors:
            results["summary"] += (
                f" ({len(import_errors)} submodule(s) failed to import: "
                f"{', '.join(import_errors)})"
            )
    except Exception as e:
        results["summary"] = f"Error validating package {package_name}: {e}"
    return results
def check_models_health() -> bool:
    """
    Quick health check for all core Pydantic models.

    A package is unhealthy when it has incomplete models, when
    ``validate_pydantic_models_in_package`` could not validate it at all
    (reported only through an ``"Error validating package ..."`` summary with
    an empty ``incomplete_models`` list), or when checking it raises.
    Progress is printed per package.

    Returns:
        True if all models are healthy, False otherwise
    """
    core_packages = [
        "vanna.core.tool.models",
        "vanna.core.user.models",
        "vanna.core.llm.models",
        "vanna.core.storage.models",
        "vanna.core.agent.models",
    ]
    all_healthy = True
    for package in core_packages:
        try:
            results = validate_pydantic_models_in_package(package)
            summary = results.get("summary", "")
            if summary.startswith("Error validating package"):
                # Without this branch an unimportable module has no incomplete
                # models and would wrongly be reported as "OK".
                print(f"❌ {summary}")
                all_healthy = False
            elif results["incomplete_models"]:
                print(f"❌ Issues in {package}: {results['incomplete_models']}")
                all_healthy = False
            else:
                print(f"✅ {package}: {results['total_models']} models OK")
        except Exception as e:
            print(f"❌ Error checking {package}: {e}")
            all_healthy = False
    return all_healthy
# Script entry point: runs check_models_health() and prints a verdict, plus
# remediation hints (assumed to be printed only on failure — see note below).
if __name__ == "__main__":
    print("Checking Pydantic model health across core packages...")
    print("=" * 60)
    healthy = check_models_health()
    print("=" * 60)
    if healthy:
        print("🎉 All Pydantic models are healthy!")
    else:
        print("⚠️ Some models need attention.")
        # NOTE(review): indentation was lost in this copy of the file; the
        # hint lines below are assumed to belong to this failure branch —
        # confirm against the original source.
        print("\nTo fix forward reference issues:")
        print("1. Ensure all referenced classes are imported")
        print("2. Call model_rebuild() after imports")
        print("3. Use proper TYPE_CHECKING imports for circular deps")
        print("\nNote: You can also catch these issues at development time using:")
        print(" - mypy static type checking")
        print(" - This validation script in your test suite")
        print(" - Pre-commit hooks")
================================================
FILE: src/vanna/core/workflow/__init__.py
================================================
"""
Workflow handler system for deterministic workflow execution.
This module provides the WorkflowHandler interface for intercepting user messages
and executing deterministic workflows before they reach the LLM. This is useful
for command handling, pattern-based routing, and state-based workflows.
"""
from .base import WorkflowHandler, WorkflowResult
from .default import DefaultWorkflowHandler
# Public names re-exported by the workflow package.
__all__ = ["WorkflowHandler", "WorkflowResult", "DefaultWorkflowHandler"]
================================================
FILE: src/vanna/core/workflow/base.py
================================================
"""
Base workflow handler interface.
Workflow triggers allow you to execute deterministic workflows in response to
user messages before they are sent to the LLM. This is useful for:
- Command handling (e.g., /help, /reset)
- Pattern-based routing (e.g., report generation)
- State-based workflows (e.g., onboarding flows)
- Quota enforcement with custom responses
"""
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Optional,
Union,
List,
AsyncGenerator,
Callable,
Awaitable,
)
from dataclasses import dataclass
if TYPE_CHECKING:
from ..user.models import User
from ..storage import Conversation
from ...components import UiComponent
from ..agent.agent import Agent
@dataclass
class WorkflowResult:
    """Result from a workflow handler attempt.
    When a workflow handles a message, it can optionally return UI components to stream
    to the user and/or mutate the conversation state.
    Attributes:
        should_skip_llm: If True, the workflow handled the message and LLM processing is skipped.
            If False, the message continues to the agent/LLM.
        components: Optional UI components to stream back to the user.
            Can be a list or async generator for streaming responses.
        conversation_mutation: Optional async callback to modify conversation state
            (e.g., clearing messages, adding system events).
    Example:
        # Simple command response
        WorkflowResult(
            should_skip_llm=True,
            components=[RichTextComponent(content="Help text here")]
        )
        # With conversation mutation
        async def clear_history(conv):
            conv.messages.clear()
        WorkflowResult(
            should_skip_llm=True,
            components=[StatusCardComponent(...)],
            conversation_mutation=clear_history
        )
        # Not handled, continue to agent
        WorkflowResult(should_skip_llm=False)
    """

    # Required; the only field without a default.
    should_skip_llm: bool
    # Either a finished list or an async generator that yields components lazily.
    components: Optional[
        Union[List["UiComponent"], AsyncGenerator["UiComponent", None]]
    ] = None
    # Awaitable callback that receives the Conversation to mutate.
    conversation_mutation: Optional[Callable[["Conversation"], Awaitable[None]]] = None
class WorkflowHandler(ABC):
    """Base class for handling deterministic workflows before LLM processing.

    Implement this interface to intercept user messages and run deterministic
    workflows instead of sending them to the LLM. This is the first
    extensibility point in the agent's message processing pipeline, running
    after user resolution and conversation loading but before the message is
    added to conversation history or sent to the LLM.

    Use cases:
        - Slash commands (/help, /reset, /report)
        - Pattern-based routing (regex matching)
        - State-based workflows (onboarding, surveys)
        - Custom quota enforcement with helpful messages
        - Deterministic report generation
        - Starter UI (buttons, welcome messages) when a conversation begins

    Example:
        class CommandWorkflow(WorkflowHandler):
            async def try_handle(self, agent, user, conversation, message):
                if message.startswith("/help"):
                    return WorkflowResult(
                        should_skip_llm=True,
                        components=[
                            RichTextComponent(
                                content="Available commands:\\n- /help\\n- /reset",
                                markdown=True
                            )
                        ]
                    )

                # Execute a tool for reports. ToolContext requires user,
                # conversation_id, request_id and agent_memory.
                if message.startswith("/report"):
                    tool = await agent.tool_registry.get_tool("generate_report")
                    context = ToolContext(
                        user=user,
                        conversation_id=conversation.id,
                        request_id=str(uuid.uuid4()),
                        agent_memory=...,  # your AgentMemory instance
                    )
                    args = ...  # an instance of tool.get_args_schema()
                    result = await tool.execute(context, args)
                    return WorkflowResult(should_skip_llm=True, components=[result.ui_component])

                # Not handled, continue to agent
                return WorkflowResult(should_skip_llm=False)

            async def get_starter_ui(self, agent, user, conversation):
                return [
                    RichTextComponent(content=f"Welcome {user.username}!"),
                    ButtonComponent(label="Generate Report", action="/report"),
                ]

        agent = Agent(
            llm_service=...,
            tool_registry=...,
            user_resolver=...,
            workflow_handler=CommandWorkflow()
        )

    Observability:
        The agent automatically creates an "agent.workflow_handler" span when
        a WorkflowHandler is configured, allowing you to monitor handler
        performance and outcomes.
    """

    @abstractmethod
    async def try_handle(
        self, agent: "Agent", user: "User", conversation: "Conversation", message: str
    ) -> WorkflowResult:
        """Decide whether to handle ``message`` deterministically.

        Called for every user message before it reaches the LLM. Inspect the
        message, the user and the conversation state to either run a workflow
        or let normal agent processing continue.

        Args:
            agent: The agent instance, giving access to tool_registry, config
                and observability_provider for tool execution and logging.
            user: The sender (id, group_memberships, metadata, ...); use it for
                permission checks or personalization.
            conversation: The current conversation, including message history;
                useful for state-based workflows.
            message: The user's raw message content.

        Returns:
            WorkflowResult with should_skip_llm=True to run a workflow and skip
            the LLM, or should_skip_llm=False to continue normal processing.

            When should_skip_llm=True:
                - The message is NOT added to conversation history automatically
                - The components are streamed to the user
                - The conversation_mutation callback (if provided) is executed
                - The agent returns without calling the LLM

            When should_skip_llm=False:
                - The message is added to conversation history
                - Normal agent processing continues (LLM call, tool execution, etc.)

        Example:
            async def try_handle(self, agent, user, conversation, message):
                # Pattern matching with tool execution
                if message.startswith("/report"):
                    tool = await agent.tool_registry.get_tool("generate_sales_report")
                    context = ToolContext(
                        user=user,
                        conversation_id=conversation.id,
                        request_id=str(uuid.uuid4()),
                        agent_memory=...,  # your AgentMemory instance
                    )
                    args = ...  # an instance of tool.get_args_schema()
                    result = await tool.execute(context, args)
                    return WorkflowResult(
                        should_skip_llm=True,
                        components=[...]
                    )

                # State-based workflow
                if user.metadata.get("needs_onboarding"):
                    return await self._onboarding_flow(agent, user, message)

                # Group-based access check
                if message.startswith("/admin") and "admin" not in user.group_memberships:
                    return WorkflowResult(
                        should_skip_llm=True,
                        components=[RichTextComponent(content="Access denied.")]
                    )

                # Continue to agent
                return WorkflowResult(should_skip_llm=False)
        """

    async def get_starter_ui(
        self, agent: "Agent", user: "User", conversation: "Conversation"
    ) -> Optional[List["UiComponent"]]:
        """Return UI components to show when a conversation starts.

        Override to show starter buttons, welcome messages or quick actions
        when the user opens a new chat. Called by the frontend/server when a
        new conversation is initialized, before any user message is sent.
        The default implementation returns None (no starter UI).

        Args:
            agent: The agent instance, giving access to tool_registry, config
                and observability_provider for dynamic UI generation.
            user: The user starting the conversation.
            conversation: The new conversation (typically empty).

        Returns:
            A list of UI components to display, or None for no starter UI.

        Example:
            async def get_starter_ui(self, agent, user, conversation):
                # Show group-based quick actions
                if "analyst" in user.group_memberships:
                    # Assumes list_tools() yields tool names — adapt to your registry.
                    report_tools = [
                        tool for tool in agent.tool_registry.list_tools()
                        if tool.startswith("report_")
                    ]
                    buttons = [
                        ButtonComponent(label=f"📊 {tool}", action=f"/{tool}")
                        for tool in report_tools
                    ]
                    return [
                        RichTextComponent(
                            content=f"Welcome back, {user.username}!",
                            markdown=True
                        ),
                        *buttons
                    ]

                # New user onboarding
                if user.metadata.get("is_new_user"):
                    return [
                        RichTextComponent(
                            content="# Welcome to Vanna!\\n\\nTry one of these to get started:",
                            markdown=True
                        ),
                        ButtonComponent(label="Show Example Query", action="/example"),
                        ButtonComponent(label="View Tutorial", action="/tutorial"),
                    ]

                return None
        """
        return
================================================
FILE: src/vanna/core/workflow/default.py
================================================
"""
Default workflow handler implementation with setup health checking.
This module provides a default implementation of the WorkflowHandler interface
that provides a smart starter UI based on available tools and setup status.
"""
from typing import TYPE_CHECKING, List, Optional, Dict, Any
import traceback
import uuid
from .base import WorkflowHandler, WorkflowResult
if TYPE_CHECKING:
from ..agent.agent import Agent
from ..user.models import User
from ..storage import Conversation
# Import components at module level to avoid circular imports
from vanna.components import (
UiComponent,
RichTextComponent,
StatusCardComponent,
ButtonComponent,
ButtonGroupComponent,
SimpleTextComponent,
CardComponent,
)
# Note: StatusCardComponent and ButtonGroupComponent are kept for /status command compatibility
class DefaultWorkflowHandler(WorkflowHandler):
    """Default workflow handler that provides setup health checking and starter UI.

    This handler provides a starter UI that:
    - Checks if run_sql tool is available (critical)
    - Checks if memory tools are available (warning if missing)
    - Checks if visualization tools are available
    - Provides appropriate setup guidance based on what's missing

    ``try_handle`` additionally answers a few built-in commands without calling
    the LLM: ``/help`` (everyone) and the admin-only ``/status``, ``/memories``
    and ``/delete <memory_id>``. A user is treated as an admin when ``"admin"``
    is in ``user.group_memberships``. Every other message goes to the LLM.
    """

    # Normalized (stripped + lower-cased) messages recognised as commands.
    _HELP_COMMANDS = ("/help", "help", "/h")
    _STATUS_COMMANDS = ("/status", "status")
    _MEMORIES_COMMANDS = (
        "/memories",
        "memories",
        "/recent_memories",
        "recent_memories",
    )
    _DELETE_COMMAND = "/delete"

    # Tool names that indicate each capability in _analyze_setup().
    _SQL_TOOL_NAMES = ("run_sql", "sql_query", "execute_sql", "query_sql")
    _VIZ_TOOL_NAMES = (
        "visualize_data",
        "create_chart",
        "plot_data",
        "generate_chart",
    )
    _CALCULATOR_TOOL_NAMES = ("calculator", "calc", "calculate")
    _MEMORY_SEARCH_TOOL = "search_saved_correct_tool_uses"
    _MEMORY_SAVE_TOOL = "save_question_tool_args"

    # Prefix shared by every admin starter card.
    _ADMIN_BANNER = (
        "**🔒 Admin View** - You have admin privileges and will see additional "
        "system information.\n\n"
    )

    def __init__(self, welcome_message: Optional[str] = None):
        """Initialize with optional custom welcome message.

        Args:
            welcome_message: Optional custom welcome message (rendered as
                markdown). If not provided, a starter card is generated from
                the available tools.
        """
        self.welcome_message = welcome_message

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_admin(user: "User") -> bool:
        """Return True when ``"admin"`` is one of the user's group memberships."""
        return "admin" in user.group_memberships

    @staticmethod
    def _markdown_component(content: str) -> UiComponent:
        """Wrap a markdown string in a rich-text-only UiComponent."""
        return UiComponent(
            rich_component=RichTextComponent(content=content, markdown=True),
            simple_component=None,
        )

    @classmethod
    def _markdown_result(cls, content: str) -> WorkflowResult:
        """Return a WorkflowResult that skips the LLM and shows one markdown block."""
        return WorkflowResult(
            should_skip_llm=True,
            components=[cls._markdown_component(content)],
        )

    @classmethod
    def _access_denied(cls, command: str, feature: str) -> WorkflowResult:
        """Build the "Access Denied" reply shown to non-admins for admin commands.

        Args:
            command: The command as displayed to the user, e.g. ``"/status"``.
            feature: Phrase completing "If you need access to ...".
        """
        return cls._markdown_result(
            "# 🔒 Access Denied\n\n"
            f"The `{command}` command is only available to administrators.\n\n"
            f"If you need access to {feature}, please contact your system administrator."
        )

    @staticmethod
    def _make_tool_context(
        agent: "Agent", user: "User", conversation: "Conversation"
    ) -> Any:
        """Create a ToolContext for this conversation bound to the agent's memory.

        A fresh random request_id is generated on every call.
        """
        # Imported locally rather than at module level; NOTE(review): presumably
        # to avoid an import cycle with vanna.core.tool — confirm.
        from vanna.core.tool import ToolContext

        return ToolContext(
            user=user,
            conversation_id=conversation.id,
            request_id=str(uuid.uuid4()),
            agent_memory=agent.agent_memory,
        )

    @staticmethod
    def _delete_action(memory_id: Any) -> List[Dict[str, Any]]:
        """Return a card action list offering a ``/delete <memory_id>`` button."""
        return [
            {
                "label": "🗑️ Delete",
                "action": f"/delete {memory_id}",
                "variant": "error",
            }
        ]

    @staticmethod
    def _build_help_content(is_admin: bool) -> str:
        """Build the markdown shown for ``/help``; admins also see admin commands."""
        help_content = (
            "## 🤖 Vanna AI Assistant\n\n"
            "I'm your AI data analyst! Here's what I can help you with:\n\n"
            "**💬 Natural Language Queries**\n"
            '- "Show me sales data for last quarter"\n'
            '- "Which customers have the highest orders?"\n'
            '- "Create a chart of revenue by month"\n\n'
            "**🔧 Commands**\n"
            "- `/help` - Show this help message\n"
        )
        if is_admin:
            help_content += (
                "\n**🔒 Admin Commands**\n"
                "- `/status` - Check setup status\n"
                "- `/memories` - View and manage recent memories\n"
                "- `/delete [id]` - Delete a memory by ID\n"
            )
        help_content += "\n\nJust ask me anything about your data in plain English!"
        return help_content

    # ------------------------------------------------------------------
    # WorkflowHandler interface
    # ------------------------------------------------------------------

    async def try_handle(
        self, agent: "Agent", user: "User", conversation: "Conversation", message: str
    ) -> WorkflowResult:
        """Handle basic commands, but mostly pass through to the LLM.

        Commands are matched case-insensitively after stripping whitespace.

        Args:
            agent: Agent whose tool registry / memory the commands inspect.
            user: Requesting user; admin-only commands check group membership.
            conversation: Current conversation (its id is used for memory calls).
            message: Raw user message.

        Returns:
            ``should_skip_llm=True`` with the command output for a recognised
            command, otherwise ``WorkflowResult(should_skip_llm=False)``.
        """
        stripped = message.strip()
        command = stripped.lower()

        if command in self._HELP_COMMANDS:
            return self._markdown_result(
                self._build_help_content(self._is_admin(user))
            )

        if command in self._STATUS_COMMANDS:
            if not self._is_admin(user):
                return self._access_denied("/status", "system status information")
            return await self._generate_status_check(agent, user)

        if command in self._MEMORIES_COMMANDS:
            if not self._is_admin(user):
                return self._access_denied("/memories", "memory management features")
            return await self._get_recent_memories(agent, user, conversation)

        # Match a bare "/delete" as well as "/delete <id>", so that
        # _delete_memory can show its usage message when the ID is missing.
        # split() also accepts tabs or repeated spaces as the separator.
        parts = stripped.split(maxsplit=1)
        if parts and parts[0].lower() == self._DELETE_COMMAND:
            if not self._is_admin(user):
                return self._access_denied("/delete", "memory management features")
            memory_id = parts[1].strip() if len(parts) > 1 else ""
            return await self._delete_memory(agent, user, conversation, memory_id)

        # Don't handle other messages, pass to LLM
        return WorkflowResult(should_skip_llm=False)

    async def get_starter_ui(
        self, agent: "Agent", user: "User", conversation: "Conversation"
    ) -> Optional[List[UiComponent]]:
        """Generate starter UI based on available tools and setup status.

        Returns the custom welcome message when one was configured. Otherwise
        it returns a single role-aware card: admins get setup status and
        memory actions, and other users get a short welcome.
        """
        if self.welcome_message:
            # A custom message replaces the generated card, so the tool
            # registry does not need to be queried at all.
            return [self._markdown_component(self.welcome_message)]

        tools = await agent.tool_registry.get_schemas(user)
        setup_analysis = self._analyze_setup([tool.name for tool in tools])
        return [self._generate_starter_card(setup_analysis, self._is_admin(user))]

    # ------------------------------------------------------------------
    # Starter cards
    # ------------------------------------------------------------------

    def _generate_starter_card(
        self, analysis: Dict[str, Any], is_admin: bool
    ) -> UiComponent:
        """Generate a single concise starter card based on role and setup status."""
        if is_admin:
            # Admin view: includes setup status and memory management
            return self._generate_admin_starter_card(analysis)
        # User view: simple welcome message
        return self._generate_user_starter_card(analysis)

    def _generate_admin_starter_card(self, analysis: Dict[str, Any]) -> UiComponent:
        """Generate admin starter card with setup info and memory management."""
        if not analysis["has_sql"]:
            title = "Admin: Setup Required"
            content = (
                self._ADMIN_BANNER
                + "**Vanna AI** requires a SQL connection to function.\n\n"
                "Please configure a SQL tool to get started."
            )
            status = "error"
            icon = "⚠️"
        elif analysis["is_complete"]:
            title = "Admin: System Ready"
            content = (
                self._ADMIN_BANNER
                + "**Vanna AI** is fully configured and ready.\n\n"
                "**Setup:** SQL ✓ | Memory ✓ | Visualization ✓"
            )
            status = "success"
            icon = "✅"
        else:
            title = "Admin: System Ready"
            setup_items = [
                "SQL ✓",
                "Memory ✓" if analysis["has_memory"] else "Memory ✗",
                "Viz ✓" if analysis["has_viz"] else "Viz ✗",
            ]
            content = (
                self._ADMIN_BANNER
                + "**Vanna AI** is ready to query your database.\n\n"
                + f"**Setup:** {' | '.join(setup_items)}"
            )
            # Missing memory is the only gap that downgrades the card to a warning.
            status = "warning" if not analysis["has_memory"] else "success"
            icon = "⚠️" if not analysis["has_memory"] else "✅"

        # Add memory management info for admins
        actions: List[Dict[str, Any]] = []
        if analysis["has_sql"]:
            actions.append(
                {
                    "label": "💡 Help",
                    "action": "/help",
                    "variant": "secondary",
                }
            )
        if analysis["has_memory"]:
            content += "\n\n**Memory Management:** Tool and text memories are available. As an admin, you can view and manage these memories to help me learn from successful queries."
            actions.append(
                {
                    "label": "🧠 View Memories",
                    "action": "/memories",
                    "variant": "secondary",
                }
            )

        return UiComponent(
            rich_component=CardComponent(
                title=title,
                content=content,
                icon=icon,
                status=status,
                actions=actions,
                markdown=True,
            ),
            simple_component=None,
        )

    def _generate_user_starter_card(self, analysis: Dict[str, Any]) -> UiComponent:
        """Generate simple user starter view using RichTextComponent."""
        if not analysis["has_sql"]:
            content = (
                "# ⚠️ Setup Required\n\n"
                "Vanna AI requires configuration before it can help you analyze data."
            )
        else:
            content = (
                "# 👋 Welcome to Vanna AI\n\n"
                "I'm your AI data analyst assistant. Ask me anything about your data in plain English!\n\n"
                "Type `/help` to see what I can do."
            )
        return self._markdown_component(content)

    # ------------------------------------------------------------------
    # Setup analysis and /status
    # ------------------------------------------------------------------

    def _analyze_setup(self, tool_names: List[str]) -> Dict[str, Any]:
        """Analyze the current tool setup and return status flags.

        Args:
            tool_names: Names of the tools visible to the current user.

        Returns:
            A dict of ``has_*`` capability flags, ``is_complete`` (SQL + memory +
            viz), ``is_functional`` (SQL only), ``tool_count`` and ``tool_names``.
        """
        available = set(tool_names)
        # Critical tools
        has_sql = any(name in available for name in self._SQL_TOOL_NAMES)
        # Memory tools (important but not critical): both halves are needed.
        has_search = self._MEMORY_SEARCH_TOOL in available
        has_save = self._MEMORY_SAVE_TOOL in available
        has_memory = has_search and has_save
        # Visualization tools (nice to have)
        has_viz = any(name in available for name in self._VIZ_TOOL_NAMES)
        # Other useful tools
        has_calculator = any(
            name in available for name in self._CALCULATOR_TOOL_NAMES
        )
        return {
            "has_sql": has_sql,
            "has_memory": has_memory,
            "has_search": has_search,
            "has_save": has_save,
            "has_viz": has_viz,
            "has_calculator": has_calculator,
            "is_complete": has_sql and has_memory and has_viz,
            "is_functional": has_sql,
            "tool_count": len(tool_names),
            "tool_names": tool_names,
        }

    def _generate_setup_status_cards(
        self, analysis: Dict[str, Any]
    ) -> List[UiComponent]:
        """Generate status cards showing setup health (used by /status command)."""
        # SQL Tool Status (Critical)
        if analysis["has_sql"]:
            sql_card = StatusCardComponent(
                title="SQL Connection",
                status="success",
                description="Database connection configured and ready",
                icon="✅",
            )
        else:
            sql_card = StatusCardComponent(
                title="SQL Connection",
                status="error",
                description="No SQL tool detected - this is required for data analysis",
                icon="❌",
            )

        # Memory Tools Status (Important)
        if analysis["has_memory"]:
            memory_card = StatusCardComponent(
                title="Memory System",
                status="success",
                description="Search and save tools configured - I can learn from successful queries",
                icon="🧠",
            )
        elif analysis["has_search"] or analysis["has_save"]:
            memory_card = StatusCardComponent(
                title="Memory System",
                status="warning",
                description="Partial memory setup - both search and save tools recommended",
                icon="⚠️",
            )
        else:
            memory_card = StatusCardComponent(
                title="Memory System",
                status="warning",
                description="Memory tools not configured - I won't remember successful patterns",
                icon="⚠️",
            )

        # Visualization Status (Nice to have)
        if analysis["has_viz"]:
            viz_card = StatusCardComponent(
                title="Visualization",
                status="success",
                description="Chart creation tools available",
                icon="📊",
            )
        else:
            viz_card = StatusCardComponent(
                title="Visualization",
                status="info",
                description="No visualization tools - results will be text/tables only",
                icon="📋",
            )

        return [
            UiComponent(rich_component=card, simple_component=None)
            for card in (sql_card, memory_card, viz_card)
        ]

    def _generate_setup_guidance(
        self, analysis: Dict[str, Any]
    ) -> Optional[UiComponent]:
        """Generate setup guidance based on what's missing (used by /status command).

        Returns None when SQL, memory and visualization are all present.
        """
        if not analysis["has_sql"]:
            # Critical guidance - need SQL
            content = (
                "## 🚨 Setup Required\n\n"
                "To get started with Vanna AI, you need to configure a SQL connection tool:\n\n"
                "```python\n"
                "from vanna.tools import RunSqlTool\n\n"
                "# Add SQL tool to your agent\n"
                "tool_registry.register(RunSqlTool(\n"
                ' connection_string="your-database-connection"\n'
                "))\n"
                "```\n\n"
                "**Next Steps:**\n"
                "1. Configure your database connection\n"
                "2. Add memory tools for learning\n"
                "3. Add visualization tools for charts"
            )
            return self._markdown_component(content)

        # Improvement suggestions
        suggestions = []
        if not analysis["has_memory"]:
            suggestions.append(
                "**🧠 Add Memory Tools** - Help me learn from successful queries:\n"
                "```python\n"
                "from vanna.tools import SearchSavedCorrectToolUses, SaveQuestionToolArgs\n"
                "tool_registry.register(SearchSavedCorrectToolUses())\n"
                "tool_registry.register(SaveQuestionToolArgs())\n"
                "```"
            )
        if not analysis["has_viz"]:
            suggestions.append(
                "**📊 Add Visualization** - Create charts and graphs:\n"
                "```python\n"
                "from vanna.tools import VisualizeDataTool\n"
                "tool_registry.register(VisualizeDataTool())\n"
                "```"
            )
        if not suggestions:
            return None  # No guidance needed
        return self._markdown_component(
            "## 💡 Suggested Improvements\n\n" + "\n\n".join(suggestions)
        )

    async def _generate_status_check(
        self, agent: "Agent", user: "User"
    ) -> WorkflowResult:
        """Generate a detailed status check response (markdown report + cards)."""
        tools = await agent.tool_registry.get_schemas(user)
        analysis = self._analyze_setup([tool.name for tool in tools])

        # Generate status report
        status_content = "# 🔍 Setup Status Report\n\n"
        if analysis["is_complete"]:
            status_content += (
                "🎉 **Excellent!** Your Vanna AI setup is complete and optimized.\n\n"
            )
        elif analysis["is_functional"]:
            status_content += (
                "✅ **Good!** Your setup is functional with room for improvement.\n\n"
            )
        else:
            status_content += (
                "⚠️ **Action Required** - Your setup needs configuration.\n\n"
            )
        status_content += f"**Tools Detected:** {analysis['tool_count']} total\n\n"

        # Tool breakdown
        if analysis["has_memory"]:
            memory_status = "✅ Complete"
        elif analysis["has_search"] or analysis["has_save"]:
            memory_status = "⚠️ Incomplete"
        else:
            memory_status = "❌ Missing"
        status_content += "## Tool Status\n\n"
        status_content += f"- **SQL Connection:** {'✅ Available' if analysis['has_sql'] else '❌ Missing (Required)'}\n"
        status_content += f"- **Memory System:** {memory_status}\n"
        status_content += f"- **Visualization:** {'✅ Available' if analysis['has_viz'] else '📋 Text/Tables Only'}\n"
        status_content += f"- **Calculator:** {'✅ Available' if analysis['has_calculator'] else '➖ Not Available'}\n\n"
        if analysis["tool_names"]:
            status_content += (
                f"**Available Tools:** {', '.join(sorted(analysis['tool_names']))}"
            )

        components = [self._markdown_component(status_content)]
        # Add status cards
        components.extend(self._generate_setup_status_cards(analysis))
        # Add guidance if needed
        guidance = self._generate_setup_guidance(analysis)
        if guidance:
            components.append(guidance)
        return WorkflowResult(should_skip_llm=True, components=components)

    # ------------------------------------------------------------------
    # Memory management (/memories, /delete)
    # ------------------------------------------------------------------

    async def _get_recent_memories(
        self, agent: "Agent", user: "User", conversation: "Conversation"
    ) -> WorkflowResult:
        """Get and display recent memories from agent memory.

        Shows up to 10 text memories and up to 10 tool memories as cards, each
        with a delete action. Backends whose text-memory call raises
        AttributeError or NotImplementedError are treated as having no text
        memories. Any other failure is printed to stderr (traceback) and
        reported in the chat.
        """
        try:
            # Check if agent has memory capability
            if getattr(agent, "agent_memory", None) is None:
                return self._markdown_result(
                    "# ⚠️ No Memory System\n\n"
                    "Agent memory is not configured. Recent memories are not available.\n\n"
                    "To enable memory, configure an AgentMemory implementation in your agent setup."
                )

            context = self._make_tool_context(agent, user, conversation)
            store = agent.agent_memory

            # `or []` guards against backends that return None for "nothing".
            tool_memories = (
                await store.get_recent_memories(context=context, limit=10)
            ) or []

            # Try to get text memories (may not be implemented in all memory backends)
            text_memories: List[Any] = []
            try:
                text_memories = (
                    await store.get_recent_text_memories(context=context, limit=10)
                ) or []
            except (AttributeError, NotImplementedError):
                # Text memories not supported by this implementation
                pass

            if not tool_memories and not text_memories:
                return self._markdown_result(
                    "# 🧠 Recent Memories\n\n"
                    "No recent memories found. As you use tools and ask questions, "
                    "successful patterns will be saved here for future reference."
                )

            total_count = len(tool_memories) + len(text_memories)
            components = [
                self._markdown_component(
                    f"# 🧠 Recent Memories\n\nFound {total_count} recent memor{'y' if total_count == 1 else 'ies'}"
                )
            ]

            if text_memories:
                components.append(
                    self._markdown_component(
                        f"## 📝 Text Memories ({len(text_memories)})"
                    )
                )
                for text_memory in text_memories:
                    card_content = f"**Content:** {text_memory.content}\n\n"
                    if text_memory.timestamp:
                        card_content += f"**Timestamp:** {text_memory.timestamp}\n\n"
                    card_content += f"**ID:** `{text_memory.memory_id}`"
                    card = CardComponent(
                        title="Text Memory",
                        content=card_content,
                        # The content uses markdown (bold labels, code-formatted
                        # ID), same as the tool-memory cards below.
                        markdown=True,
                        icon="📝",
                        actions=self._delete_action(text_memory.memory_id),
                    )
                    components.append(
                        UiComponent(rich_component=card, simple_component=None)
                    )

            if tool_memories:
                components.append(
                    self._markdown_component(
                        f"## 🔧 Tool Memories ({len(tool_memories)})"
                    )
                )
                for tool_memory in tool_memories:
                    card_content = f"**Question:** {tool_memory.question}\n\n"
                    card_content += f"**Tool:** {tool_memory.tool_name}\n\n"
                    card_content += f"**Arguments:** `{tool_memory.args}`\n\n"
                    card_content += f"**Success:** {'✅ Yes' if tool_memory.success else '❌ No'}\n\n"
                    if tool_memory.timestamp:
                        card_content += f"**Timestamp:** {tool_memory.timestamp}\n\n"
                    card_content += f"**ID:** `{tool_memory.memory_id}`"
                    card = CardComponent(
                        title=f"Tool: {tool_memory.tool_name}",
                        content=card_content,
                        markdown=True,
                        icon="🔧",
                        status="success" if tool_memory.success else "error",
                        actions=self._delete_action(tool_memory.memory_id),
                    )
                    components.append(
                        UiComponent(rich_component=card, simple_component=None)
                    )

            return WorkflowResult(should_skip_llm=True, components=components)

        except Exception as e:
            traceback.print_exc()
            return self._markdown_result(
                f"# ❌ Error Retrieving Memories\n\n"
                f"Failed to get recent memories: {str(e)}\n\n"
                f"This may indicate an issue with the agent memory configuration."
            )

    async def _delete_memory(
        self, agent: "Agent", user: "User", conversation: "Conversation", memory_id: str
    ) -> WorkflowResult:
        """Delete a memory by its ID.

        Tries ``delete_by_id`` (tool memories) first, then ``delete_text_memory``
        when the backend provides it. An empty ``memory_id`` yields a usage hint.
        """
        try:
            # Check if agent has memory capability
            if getattr(agent, "agent_memory", None) is None:
                return self._markdown_result(
                    "# ⚠️ No Memory System\n\n"
                    "Agent memory is not configured. Cannot delete memories."
                )

            if not memory_id:
                return self._markdown_result(
                    "# ⚠️ Invalid Command\n\n"
                    "Please provide a memory ID to delete.\n\n"
                    "Usage: `/delete [memory_id]`"
                )

            context = self._make_tool_context(agent, user, conversation)
            store = agent.agent_memory

            # Try to delete as a tool memory first
            deleted = await store.delete_by_id(context, memory_id)
            # If not found as tool memory, try as text memory
            if not deleted:
                try:
                    deleted = await store.delete_text_memory(context, memory_id)
                except (AttributeError, NotImplementedError):
                    # Text memory deletion not supported by this implementation
                    pass

            if deleted:
                return self._markdown_result(
                    f"# ✅ Memory Deleted\n\n"
                    f"Successfully deleted memory with ID: `{memory_id}`\n\n"
                    f"You can view remaining memories using `/memories`."
                )
            return self._markdown_result(
                f"# ❌ Memory Not Found\n\n"
                f"Could not find memory with ID: `{memory_id}`\n\n"
                f"Use `/memories` to see available memory IDs."
            )

        except Exception as e:
            traceback.print_exc()
            return self._markdown_result(
                f"# ❌ Error Deleting Memory\n\n"
                f"Failed to delete memory: {str(e)}\n\n"
                f"This may indicate an issue with the agent memory configuration."
            )
================================================
FILE: src/vanna/examples/__init__.py
================================================
"""Examples for using the Vanna Agents framework."""
================================================
FILE: src/vanna/examples/__main__.py
================================================
"""
Interactive example runner for Vanna Agents.
"""
import sys
import importlib
def main() -> None:
    """Run an example module chosen by the first command-line argument.

    With no argument, prints the known examples and a usage line. Otherwise it
    imports ``vanna.examples.<name>`` and calls its ``run_interactive()``
    if present, else its ``main()``. A coroutine-function ``main`` is run with
    ``asyncio.run``. Errors are printed rather than raised.
    """
    import inspect

    if len(sys.argv) < 2:
        print("Available examples:")
        print(" python -m vanna.examples mock_quickstart")
        print(" python -m vanna.examples mock_custom_tool")
        print(" python -m vanna.examples anthropic_quickstart")
        print(" python -m vanna.examples openai_quickstart")
        print(" python -m vanna.examples mock_quota_example")
        print(" python -m vanna.examples mock_rich_components_demo")
        print("")
        print("Usage: python -m vanna.examples <example_name>")
        return

    example_name = sys.argv[1]
    module_path = f"vanna.examples.{example_name}"

    # Import separately from running, so an ImportError raised *inside* an
    # example (e.g. a missing optional extra) is not misreported as
    # "example not found".
    try:
        module = importlib.import_module(module_path)
    except ModuleNotFoundError as e:
        missing = e.name or ""
        # "Not found" only when the example module itself (or one of its
        # parent packages) is what's missing.
        if not missing or module_path == missing or module_path.startswith(missing + "."):
            print(f"Example '{example_name}' not found")
        else:
            print(f"Error running example '{example_name}': {e}")
        return
    except Exception as e:
        print(f"Error running example '{example_name}': {e}")
        return

    try:
        if hasattr(module, "run_interactive"):
            module.run_interactive()
        elif hasattr(module, "main"):
            # inspect.iscoroutinefunction replaces the deprecated asyncio variant.
            if inspect.iscoroutinefunction(module.main):
                import asyncio

                asyncio.run(module.main())
            else:
                module.main()
        else:
            print(f"Example '{example_name}' does not have a main function")
    except Exception as e:
        print(f"Error running example '{example_name}': {e}")


if __name__ == "__main__":
    main()
================================================
FILE: src/vanna/examples/anthropic_quickstart.py
================================================
"""
Anthropic example using AnthropicLlmService.
Loads environment from .env (via python-dotenv), uses model 'claude-sonnet-4-20250514'
by default, and sends a simple message through an Agent.
Run:
PYTHONPATH=. python vanna/examples/anthropic_quickstart.py
"""
import asyncio
import importlib.util
import os
import sys
def ensure_env() -> None:
    """Load a local ``.env`` when python-dotenv is available, then require ANTHROPIC_API_KEY.

    Variables already present in the environment are never overridden by the
    ``.env`` file. A warning is printed when python-dotenv is missing. The
    process exits with status 1 when ANTHROPIC_API_KEY is still unset.
    """
    dotenv_available = importlib.util.find_spec("dotenv") is not None
    if not dotenv_available:
        print(
            "[warn] python-dotenv not installed; skipping .env load. Install with: pip install python-dotenv"
        )
    else:
        from dotenv import load_dotenv

        env_file = os.path.join(os.getcwd(), ".env")
        # override=False: values already in the environment take precedence.
        load_dotenv(dotenv_path=env_file, override=False)

    if os.getenv("ANTHROPIC_API_KEY"):
        return
    print(
        "[error] ANTHROPIC_API_KEY is not set. Add it to your environment or .env file."
    )
    sys.exit(1)
async def main() -> None:
    """Send one prompt to an Anthropic-backed Agent and print its text replies.

    The Agent gets a registry containing only ListFilesTool. The model comes
    from ANTHROPIC_MODEL (default 'claude-sonnet-4-20250514'). Re-raises
    ImportError, after a hint, when the anthropic extra is not installed.
    """
    ensure_env()

    try:
        from vanna.integrations.anthropic import AnthropicLlmService
    except ImportError:
        print(
            "[error] anthropic extra not installed. Install with: pip install -e .[anthropic]"
        )
        raise

    from vanna import AgentConfig, Agent, User
    from vanna.core.registry import ToolRegistry
    from vanna.tools import ListFilesTool

    model_name = os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-20250514")
    print(f"Using Anthropic model: {model_name}")
    llm_service = AnthropicLlmService(model=model_name)

    # Registry exposing just the list_files tool
    registry = ToolRegistry()
    registry.register(ListFilesTool())

    agent = Agent(
        llm_service=llm_service,
        config=AgentConfig(stream_responses=False),
        tool_registry=registry,
    )

    demo_user = User(id="demo-user", username="demo")
    prompt = "List the files in the current directory"
    print(f"Sending: '{prompt}'\n")
    async for component in agent.send_message(
        user=demo_user,
        message=prompt,
        conversation_id="anthropic-demo",
    ):
        text = getattr(component, "content", None)
        if text:
            print("Assistant:", text)


if __name__ == "__main__":
    asyncio.run(main())
================================================
FILE: src/vanna/examples/artifact_example.py
================================================
#!/usr/bin/env python3
"""
Example demonstrating the artifact system in Vanna Agents.
This script shows how agents can create interactive artifacts that can be
rendered externally by developers listening for the 'artifact-opened' event.
"""
import asyncio
from typing import AsyncGenerator, Optional
from vanna import Agent, UiComponent, User, AgentConfig
from vanna.core.rich_components import ArtifactComponent
from vanna.integrations.anthropic.mock import MockLlmService
from vanna.core.interfaces import Agent, LlmService
class ArtifactDemoAgent(Agent):
    """Demo agent that replies normally, then emits one keyword-selected artifact.

    NOTE(review): ``Agent`` resolves to whichever import binds it last at module
    level. ``vanna.core.interfaces.Agent`` is imported after ``vanna.Agent``, so
    the former is the base class here. Confirm that is intended.
    """

    def __init__(self, llm_service: Optional[LlmService] = None) -> None:
        """Configure streaming responses with thinking indicators.

        Args:
            llm_service: LLM backend to use. When omitted, a MockLlmService
                with a fixed introductory reply is created.
        """
        if llm_service is not None:
            service = llm_service
        else:
            service = MockLlmService(
                "I'll help you create interactive artifacts! Try asking me to create a chart, dashboard, or interactive HTML widget."
            )
        super().__init__(
            llm_service=service,
            config=AgentConfig(
                stream_responses=True,
                include_thinking_indicators=True,
            ),
        )

    async def send_message(
        self, user: User, message: str, *, conversation_id: Optional[str] = None
    ) -> AsyncGenerator[UiComponent, None]:
        """Stream the base agent's reply, then append an artifact picked by keyword.

        Keyword groups are checked in priority order against the lower-cased
        message. Only the first matching group yields an artifact.
        """
        async for part in super().send_message(
            user, message, conversation_id=conversation_id
        ):
            yield part

        lowered = message.lower()
        keyword_routes = (
            (("chart", "graph", "visualization", "d3"), self.create_d3_visualization),
            (("dashboard", "analytics", "metrics"), self.create_dashboard_artifact),
            (("html", "interactive", "widget", "demo"), self.create_html_artifact),
        )
        for keywords, make_artifact in keyword_routes:
            if any(keyword in lowered for keyword in keywords):
                async for part in make_artifact():
                    yield part
                break

    async def create_html_artifact(self) -> AsyncGenerator[UiComponent, None]:
        """Yield a single simple HTML artifact."""
        markup = """
Interactive HTML Artifact
This is a simple HTML artifact that can be opened externally.
Output will appear here...
"""
        yield UiComponent(
            rich_component=ArtifactComponent.create_html(
                content=markup,
                title="Interactive HTML Demo",
                description="A simple HTML artifact with interactive elements",
            )
        )

    async def create_d3_visualization(self) -> AsyncGenerator[UiComponent, None]:
        """Yield a single D3.js bar-chart artifact."""
        script = """
"""
        yield UiComponent(
            rich_component=ArtifactComponent.create_d3(
                content=script,
                title="D3.js Bar Chart",
                description="An interactive bar chart built with D3.js",
            )
        )

    async def create_dashboard_artifact(self) -> AsyncGenerator[UiComponent, None]:
        """Yield a single dashboard artifact that is externally renderable and fullscreen-capable."""
        markup = """
Analytics Dashboard
Total Users
12,456
↗ +5.2%
Revenue
$89,432
↗ +12.3%
Conversion Rate
3.4%
↘ -0.8%
Quick Actions
"""
        yield UiComponent(
            rich_component=ArtifactComponent(
                content=markup,
                artifact_type="dashboard",
                title="Analytics Dashboard",
                description="A sample analytics dashboard with metrics and controls",
                external_renderable=True,
                fullscreen_capable=True,
            )
        )
def create_demo_agent() -> ArtifactDemoAgent:
    """Build the agent used by the REPL and server entry points.

    Returns:
        A new ArtifactDemoAgent using its default MockLlmService.
    """
    demo = ArtifactDemoAgent()
    return demo
async def main() -> None:
    """Main demo function.

    Builds the demo agent and calls each artifact factory directly, without
    sending a message through the LLM. It prints a few attributes of every
    ArtifactComponent produced, then prints a JavaScript snippet showing how a
    web page can listen for 'artifact-opened' events.
    """
    print("🎨 Artifact Demo Agent")
    print("This demo shows how to create different types of artifacts.")
    print(
        "In a real web application, developers can listen for 'artifact-opened' events."
    )
    print()

    # The factories below take no user, so no demo User is needed here.
    demo_agent = create_demo_agent()

    # Demo 1: HTML Artifact
    print("1. Creating HTML Artifact...")
    async for component in demo_agent.create_html_artifact():
        artifact = component.rich_component
        if isinstance(artifact, ArtifactComponent):
            print(f" ✓ Created HTML artifact: {artifact.title}")
            print(f" ✓ Artifact ID: {artifact.artifact_id}")
            print(f" ✓ Type: {artifact.artifact_type}")
            print(f" ✓ External renderable: {artifact.external_renderable}")
    print()

    # Demo 2: D3 Visualization
    print("2. Creating D3.js Visualization...")
    async for component in demo_agent.create_d3_visualization():
        artifact = component.rich_component
        if isinstance(artifact, ArtifactComponent):
            print(f" ✓ Created D3 artifact: {artifact.title}")
            print(f" ✓ Dependencies: {artifact.dependencies}")
            print(" ✓ Standalone HTML available via get_standalone_html()")
    print()

    # Demo 3: Dashboard
    print("3. Creating Dashboard Artifact...")
    async for component in demo_agent.create_dashboard_artifact():
        artifact = component.rich_component
        if isinstance(artifact, ArtifactComponent):
            print(f" ✓ Created dashboard artifact: {artifact.title}")
            print(f" ✓ Fullscreen capable: {artifact.fullscreen_capable}")
    print()

    print("🚀 Web Integration Example:")
    print("""
In your web application, listen for the 'artifact-opened' event:
document.querySelector('vanna-chat').addEventListener('artifact-opened', (event) => {
const { artifactId, content, type, trigger } = event.detail;
if (trigger === 'created' && type === 'dashboard') {
// Auto-open dashboards in external window
const newWindow = window.open('', '_blank');
newWindow.document.write(event.detail.getStandaloneHTML());
newWindow.document.close();
// Prevent default rendering in chat
event.detail.preventDefault();
}
});
""")


if __name__ == "__main__":
    asyncio.run(main())
================================================
FILE: src/vanna/examples/claude_sqlite_example.py
================================================
"""
Claude example using the SQL query tool with the Chinook database.
This example demonstrates using the RunSqlTool with SqliteRunner and Claude's AI
to intelligently query and analyze the Chinook database, with automatic visualization support.
Requirements:
- ANTHROPIC_API_KEY environment variable or .env file
- anthropic package: pip install -e .[anthropic]
- plotly package: pip install -e .[visualization]
Usage:
PYTHONPATH=. python vanna/examples/claude_sqlite_example.py
"""
import asyncio
import importlib.util
import os
import sys
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vanna import Agent
def ensure_env() -> None:
    """Prepare the environment for the Claude + SQLite example.

    If python-dotenv is importable, it loads ``.env`` from the current working
    directory without overriding existing variables. Otherwise it prints a
    warning. It then exits with status 1 when ANTHROPIC_API_KEY is missing or
    empty.
    """
    if importlib.util.find_spec("dotenv") is None:
        print(
            "[warn] python-dotenv not installed; skipping .env load. Install with: pip install python-dotenv"
        )
    else:
        from dotenv import load_dotenv

        # Existing environment variables win over the .env file.
        load_dotenv(
            dotenv_path=os.path.join(os.getcwd(), ".env"),
            override=False,
        )

    api_key = os.getenv("ANTHROPIC_API_KEY")
    if not api_key:
        print(
            "[error] ANTHROPIC_API_KEY is not set. Add it to your environment or .env file."
        )
        sys.exit(1)
async def main() -> None:
    """Run the command-line Claude + Chinook SQLite demo.

    Validates the environment, wires an Anthropic-backed agent to a SQLite
    runner (plus the visualization tool when Plotly is importable), then sends
    a fixed list of sample questions in one conversation and prints every
    textual component the agent yields. Exits with status 1 when the Chinook
    database file cannot be found.
    """
    ensure_env()

    try:
        from vanna.integrations.anthropic import AnthropicLlmService
    except ImportError:
        print(
            "[error] anthropic extra not installed. Install with: pip install -e .[anthropic]"
        )
        raise

    from vanna import AgentConfig, Agent
    from vanna.core.registry import ToolRegistry
    from vanna.core.user import CookieEmailUserResolver, RequestContext
    from vanna.integrations.sqlite import SqliteRunner
    from vanna.tools import LocalFileSystem, RunSqlTool, VisualizeDataTool

    # The Chinook database is looked up two directories above this file.
    here = os.path.dirname(__file__)
    db_path = os.path.abspath(os.path.join(here, "..", "..", "Chinook.sqlite"))
    if not os.path.exists(db_path):
        print(f"[error] Chinook database not found at {db_path}")
        print(
            "Please download it with: curl -o Chinook.sqlite https://vanna.ai/Chinook.sqlite"
        )
        sys.exit(1)

    model_name = os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-20250514")
    print(f"Using Anthropic model: {model_name}")
    print(f"Using database: {db_path}")
    llm_service = AnthropicLlmService(model=model_name)

    # A single file system instance is shared by the SQL and visualization tools.
    shared_fs = LocalFileSystem(working_directory="./claude_data")

    registry = ToolRegistry()
    runner = SqliteRunner(database_path=db_path)
    registry.register(RunSqlTool(sql_runner=runner, file_system=shared_fs))

    try:
        registry.register(VisualizeDataTool(file_system=shared_fs))
        print("Visualization tool enabled")
    except ImportError:
        print(
            "[warn] Plotly not installed. Visualization tool disabled. Install with: pip install -e .[visualization]"
        )

    resolver = CookieEmailUserResolver()
    agent = Agent(
        llm_service=llm_service,
        config=AgentConfig(stream_responses=False),
        tool_registry=registry,
        user_resolver=resolver,
    )

    # Pretend a demo user is signed in by setting the resolver's cookie.
    ctx = RequestContext(
        cookies={resolver.cookie_name: "demo-user@example.com"},
        metadata={"demo": True},
        remote_addr="127.0.0.1",
    )
    session_id = "claude-sqlite-demo"

    questions = (
        "What tables are in this database?",
        "Show me the first 5 customers with their names",
        "What's the total number of tracks in the database?",
        "Find the top 5 artists by number of albums",
        "What's the average invoice total?",
        "Get data on the top 10 longest tracks and then visualize it",
    )

    rule = "=" * 60
    for line in (
        "\n" + rule,
        "Claude SQLite Database Assistant Demo",
        rule,
        "This demo shows Claude querying the Chinook music database.",
        "Claude will intelligently construct SQL queries to answer questions",
        "and can create visualizations of the results.",
    ):
        print(line)
    print()

    for number, question in enumerate(questions, start=1):
        print(f"\n--- Question {number}: {question} ---")
        async for component in agent.send_message(
            request_context=ctx,
            message=question,
            conversation_id=session_id,
        ):
            # Prefer the simple component, then the rich one, then a bare `content`.
            simple = getattr(component, "simple_component", None)
            rich = getattr(component, "rich_component", None)
            if simple:
                if hasattr(simple, "text"):
                    print("Assistant:", simple.text)
            elif rich:
                if getattr(rich, "content", None):
                    print("Assistant:", rich.content)
            elif getattr(component, "content", None):
                print("Assistant:", component.content)
        print()  # Add spacing between questions

    print("\n" + rule)
    print("Demo complete! Claude successfully queried the database.")
    print(rule)
def create_demo_agent() -> "Agent":
    """Build the Claude + SQLite agent that the vanna server framework loads.

    Response streaming is enabled for the web interface. The visualization
    tool is registered only when it can be constructed (its ImportError is
    ignored).

    Returns:
        Configured Agent with Claude LLM and SQLite tool.

    Raises:
        ImportError: if the anthropic extra is not installed.
        FileNotFoundError: if ``Chinook.sqlite`` is not found two directories
            above this file.
    """
    ensure_env()

    try:
        from vanna.integrations.anthropic import AnthropicLlmService
    except ImportError:
        print(
            "[error] anthropic extra not installed. Install with: pip install -e .[anthropic]"
        )
        raise

    from vanna import AgentConfig, Agent
    from vanna.core.registry import ToolRegistry
    from vanna.core.user import CookieEmailUserResolver
    from vanna.integrations.sqlite import SqliteRunner
    from vanna.tools import LocalFileSystem, RunSqlTool, VisualizeDataTool

    chinook = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "..", "Chinook.sqlite")
    )
    if not os.path.exists(chinook):
        raise FileNotFoundError(
            f"Chinook database not found at {chinook}. Please download it from https://vanna.ai/Chinook.sqlite"
        )

    llm = AnthropicLlmService(
        model=os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-20250514")
    )

    # Both tools share one working directory for their output files.
    fs = LocalFileSystem(working_directory="./claude_data")
    registry = ToolRegistry()
    runner = SqliteRunner(database_path=chinook)
    registry.register(RunSqlTool(sql_runner=runner, file_system=fs))

    # Visualization is optional: skip it when its dependencies are missing.
    try:
        registry.register(VisualizeDataTool(file_system=fs))
    except ImportError:
        pass

    resolver = CookieEmailUserResolver()
    return Agent(
        llm_service=llm,
        config=AgentConfig(stream_responses=True),  # web UI consumes a stream
        tool_registry=registry,
        user_resolver=resolver,
    )
if __name__ == "__main__":
    # Script entry point: drive the async CLI demo on a fresh event loop.
    asyncio.run(main())
================================================
FILE: src/vanna/examples/coding_agent_example.py
================================================
"""
Example coding agent using the vanna-agents framework.
This example demonstrates building an agent that can edit code files,
following the concepts from the "How to Build an Agent" article.
The agent includes tools for file operations and uses an LLM service
that can understand and modify code.
Usage:
PYTHONPATH=. python vanna/examples/coding_agent_example.py
"""
import asyncio
import uuid
from typing import AsyncGenerator, List, Optional
from vanna import (
AgentConfig,
Agent,
ToolRegistry,
User,
)
from vanna.core.interfaces import LlmService
from vanna.core.models import (
LlmRequest,
LlmResponse,
LlmStreamChunk,
ToolCall,
ToolSchema,
)
from vanna.tools.file_system import create_file_system_tools
from vanna.tools.python import create_python_tools
class CodingLlmService(LlmService):
    """
    LLM service that simulates a coding assistant.

    This demonstrates the minimal implementation needed for an agent as
    described in the article: it maps the latest message onto either a tool
    call or a plain-text reply using simple keyword matching.
    """

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Handle non-streaming requests."""
        await asyncio.sleep(0.1)  # Simulate thinking time
        return self._build_response(request)

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Handle streaming requests.

        Yields any tool calls first, then the text content one word at a time,
        and finally a chunk carrying only the finish reason.
        """
        await asyncio.sleep(0.1)
        response = self._build_response(request)
        if response.tool_calls:
            yield LlmStreamChunk(tool_calls=response.tool_calls)
        if response.content:
            # Simulate streaming; every word after the first keeps its leading space.
            words = response.content.split()
            for i, word in enumerate(words):
                chunk = word if i == 0 else f" {word}"
                await asyncio.sleep(0.05)  # Simulate streaming delay
                yield LlmStreamChunk(content=chunk)
        yield LlmStreamChunk(finish_reason=response.finish_reason)

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Validate tools - no errors for this simple implementation."""
        return []

    @staticmethod
    def _tool_call(name: str, arguments: dict) -> ToolCall:
        """Build a ToolCall for ``name`` with a fresh short random id."""
        return ToolCall(
            id=f"call_{uuid.uuid4().hex[:8]}",
            name=name,
            arguments=arguments,
        )

    def _build_response(self, request: LlmRequest) -> LlmResponse:
        """Build a response based on the conversation context.

        A tool result is acknowledged with a final text reply. A user message
        is matched against keywords, in order: list, read, create/write,
        run/execute (``.py`` only), edit/update/modify. Anything else, or a
        matched branch that cannot find a filename, gets the help text.
        """
        last_message = request.messages[-1] if request.messages else None

        # If we just got a tool result, respond to it
        if last_message and last_message.role == "tool":
            tool_result = last_message.content or "Tool executed"
            return LlmResponse(
                content=f"I've completed the operation. {tool_result}",
                finish_reason="stop",
            )

        # If user is asking for file operations, use tools
        if last_message and last_message.role == "user":
            # Guard against a None content instead of crashing on .lower().
            user_message = (last_message.content or "").lower()

            # Accept phrasings like "list the files" as well as "list files".
            wants_listing = (
                "list" in user_message or "show" in user_message
            ) and "files" in user_message

            if wants_listing:
                return LlmResponse(
                    content="I'll list the files for you.",
                    tool_calls=[self._tool_call("list_files", {})],
                    finish_reason="tool_calls",
                )
            elif "read" in user_message and (
                "file" in user_message
                or ".py" in user_message
                or ".txt" in user_message
            ):
                filename = _extract_filename(user_message)
                if filename:
                    return LlmResponse(
                        content=f"I'll read the file '{filename}' for you.",
                        tool_calls=[
                            self._tool_call("read_file", {"filename": filename})
                        ],
                        finish_reason="tool_calls",
                    )
            elif "create" in user_message or "write" in user_message:
                # Suggest creating a simple example file
                return LlmResponse(
                    content="I'll create an example Python file for you.",
                    tool_calls=[
                        self._tool_call(
                            "write_file",
                            {
                                "filename": "example.py",
                                "content": "# Example Python file\nprint('Hello from the coding agent!')\n\ndef greet(name):\n return f'Hello, {name}!'\n\nif __name__ == '__main__':\n print(greet('World'))\n",
                                "overwrite": True,
                            },
                        )
                    ],
                    finish_reason="tool_calls",
                )
            elif (
                "run" in user_message or "execute" in user_message
            ) and ".py" in user_message:
                filename = _extract_filename(user_message)
                if filename:
                    return LlmResponse(
                        content=f"I'll run the Python file '{filename}'.",
                        tool_calls=[
                            self._tool_call(
                                "run_python_file",
                                {"filename": filename, "arguments": []},
                            )
                        ],
                        finish_reason="tool_calls",
                    )
            elif (
                "edit" in user_message
                or "update" in user_message
                or "modify" in user_message
            ):
                # Lines 4-5 of the generated example.py hold the greet() function.
                return LlmResponse(
                    content="I'll update the greet function to make it more descriptive.",
                    tool_calls=[
                        self._tool_call(
                            "edit_file",
                            {
                                "filename": "example.py",
                                "edits": [
                                    {
                                        "start_line": 4,
                                        "end_line": 5,
                                        "new_content": (
                                            "def greet(name):\n"
                                            ' """Return a friendly greeting."""\n'
                                            ' return f"Hello, {name}! Welcome to the coding agent."\n'
                                        ),
                                    }
                                ],
                            },
                        )
                    ],
                    finish_reason="tool_calls",
                )

        # Default response
        return LlmResponse(
            content=(
                "I'm a coding assistant. I can help you list, read, write, edit, and run Python files. "
                "Try asking me to 'list files', 'read example.py', 'create a Python file', 'run example.py', or 'update example.py'."
            ),
            finish_reason="stop",
        )
def create_demo_agent() -> Agent:
    """
    Assemble the coding agent from file-system and Python-execution tools.

    Mirrors the article's "minimal agent" idea: a tool registry, a (mock)
    LLM service, and a small config. File-system tools come from the
    convenience factory, which injects its default file system.
    """
    registry = ToolRegistry()
    # Register every tool produced by each factory, file tools first.
    for factory in (create_file_system_tools, create_python_tools):
        for tool in factory():
            registry.register(tool)

    service = CodingLlmService()
    settings = AgentConfig(
        stream_responses=True,
        include_thinking_indicators=True,
        max_tool_iterations=3,
    )
    return Agent(llm_service=service, tool_registry=registry, config=settings)
async def main() -> None:
    """
    Run a scripted conversation with the coding agent and print its output.

    As the article mentions: "300 lines of code and three tools and now
    you're able to talk to an alien intelligence that edits your code."
    """
    for line in (
        "🤖 Starting Coding Agent Demo",
        "This demonstrates the concepts from 'How to Build an Agent'",
        "-" * 50,
    ):
        print(line)

    agent = create_demo_agent()
    user = User(id="coder123", username="developer", permissions=[])

    available = await agent.get_available_tools(user)
    tool_names = [t.name for t in available]
    print(f"Available tools: {tool_names}")
    print()

    session = "coding-session"
    prompts = (
        "Hello! Can you list the files in this directory?",
        "Can you create a simple Python file for me?",
        "Now read the example.py file you just created",
        "Please update the greet function to include a docstring and a friendlier message.",
        "Run example.py so I can see its output.",
        "Great, read example.py again to confirm the changes.",
    )

    for idx, prompt in enumerate(prompts, start=1):
        print(f"Demo {idx}: {prompt}")
        print("Agent response:")
        async for component in agent.send_message(
            user=user, message=prompt, conversation_id=session
        ):
            rich = component.rich_component
            # Rich text content wins, then a rich "message", then simple text.
            if getattr(rich, "content", None):
                print(f" 📝 {rich.content}")
            elif hasattr(rich, "message"):
                print(f" 💬 {rich.message}")
            elif component.simple_component and hasattr(
                component.simple_component, "text"
            ):
                print(f" 📄 {component.simple_component.text}")
        print("-" * 30)
def _extract_filename(message: str) -> Optional[str]:
"""Extract a likely filename token from a user message."""
for token in message.replace("\n", " ").split():
cleaned = token.strip("'\".,;!?")
if "." in cleaned and not cleaned.startswith("."):
return cleaned
return None
if __name__ == "__main__":
    # Script entry point: run the scripted coding-agent conversation.
    asyncio.run(main())
================================================
FILE: src/vanna/examples/custom_system_prompt_example.py
================================================
"""
Example demonstrating custom system prompt builder with dependency injection.
This example shows how to create a custom SystemPromptBuilder that dynamically
generates system prompts based on user context and available tools.
Usage:
python -m vanna.examples.custom_system_prompt_example
"""
from typing import List, Optional
from vanna.core.interfaces import SystemPromptBuilder
from vanna.core.models import ToolSchema, User
class CustomSystemPromptBuilder(SystemPromptBuilder):
    """System prompt builder that tailors the prompt to the requesting user.

    The prompt greets the user by name, adds one role sentence chosen from the
    user's permissions, and lists the tools available to them.
    """

    async def build_system_prompt(
        self, user: User, tools: List[ToolSchema]
    ) -> Optional[str]:
        """Compose a personalized system prompt.

        Args:
            user: The user making the request (``username`` falls back to ``id``).
            tools: Tool schemas available to that user.

        Returns:
            Newline-joined prompt text.
        """
        display_name = user.username or user.id
        lines = [f"Hello {display_name}! I'm your AI assistant."]

        # "admin" takes precedence over "analyst" when a user has both.
        perms = user.permissions
        if "admin" in perms:
            lines.append(
                "As an admin user, you have access to all tools and capabilities."
            )
        elif "analyst" in perms:
            lines.append(
                "You're working as an analyst. I'll help you query and visualize data."
            )
        else:
            lines.append("I'm here to help you with your tasks.")

        if tools:
            lines.append("\nAvailable tools:")
            lines.extend(f"- {schema.name}: {schema.description}" for schema in tools)

        return "\n".join(lines)
class SQLAssistantSystemPromptBuilder(SystemPromptBuilder):
    """System prompt builder specifically for SQL database assistants."""

    def __init__(self, database_name: str = "database"):
        """Initialize with database context.

        Args:
            database_name: Name of the database being queried; interpolated
                into the first line of the prompt.
        """
        self.database_name = database_name

    async def build_system_prompt(
        self, user: User, tools: List[ToolSchema]
    ) -> Optional[str]:
        """Build a SQL-focused system prompt.

        Args:
            user: The user making the request (not used by this builder).
            tools: Tools available to the user; a ``visualize_data`` tool adds
                a visualization guideline.

        Returns:
            SQL-focused system prompt ending with a newline.
        """
        guidelines = [
            "- Always validate SQL syntax before execution",
            "- Use appropriate JOINs and avoid Cartesian products",
            "- Limit result sets to reasonable sizes by default",
            "- Format numbers and dates for readability",
        ]
        # Append the viz bullet as part of the Guidelines list. Previously it
        # was concatenated with a leading "\n" onto a prompt that already ended
        # in "\n", leaving a blank line that split it from the list.
        if any(tool.name == "visualize_data" for tool in tools):
            guidelines.append(
                "- Create visualizations for numerical data when it helps understanding"
            )

        header = (
            f"You are an expert SQL database assistant for the {self.database_name} database.\n"
            "Your primary responsibilities:\n"
            "1. Write efficient, correct SQL queries\n"
            "2. Explain query results clearly\n"
            "3. Suggest optimizations when relevant\n"
            "4. Visualize data when appropriate\n"
            "Guidelines:\n"
        )
        return header + "\n".join(guidelines) + "\n"
async def demo() -> None:
    """Walk through three uses of custom system prompt builders.

    1. Print ``CustomSystemPromptBuilder`` output for an admin user.
    2. Print ``SQLAssistantSystemPromptBuilder`` output for an analyst user.
    3. Inject the SQL builder into an ``Agent`` via ``system_prompt_builder``
       (the agent is constructed but no messages are sent).
    """
    from vanna import Agent, User
    from vanna.core.registry import ToolRegistry
    from vanna.integrations.anthropic.mock import MockLlmService

    heavy_rule = "=" * 60
    light_rule = "-" * 60

    def section(title: str, first: bool = False) -> None:
        # Every section after the first is preceded by a blank line.
        print(heavy_rule if first else "\n" + heavy_rule)
        print(title)
        print(heavy_rule)

    def show(caption: str, text: Optional[str]) -> None:
        print(caption)
        print(light_rule)
        print(text)
        print(light_rule)

    section("Example 1: Custom Personalized System Prompt", first=True)
    personalized = CustomSystemPromptBuilder()
    admin = User(id="user-1", username="Alice", permissions=["admin"])
    # Hand-built schemas stand in for real registered tools.
    sample_tools = [
        ToolSchema(
            name="query_database", description="Query the SQL database", parameters={}
        ),
        ToolSchema(
            name="visualize_data",
            description="Create data visualizations",
            parameters={},
        ),
    ]
    show(
        "\nGenerated system prompt for admin user:",
        await personalized.build_system_prompt(admin, sample_tools),
    )

    section("Example 2: SQL Assistant System Prompt")
    sql_builder = SQLAssistantSystemPromptBuilder(database_name="Chinook")
    analyst = User(id="user-2", username="Bob", permissions=["analyst"])
    show(
        "\nGenerated system prompt for SQL assistant:",
        await sql_builder.build_system_prompt(analyst, sample_tools),
    )

    section("Example 3: Using Custom Builder with Agent")
    llm = MockLlmService()
    registry = ToolRegistry()
    agent = Agent(
        llm_service=llm,
        tool_registry=registry,
        system_prompt_builder=sql_builder,  # Inject custom builder here
    )
    print("\nAgent created with custom SQL system prompt builder!")
    print("The agent will now use the SQL-focused system prompt for all interactions.")
if __name__ == "__main__":
    # asyncio is imported lazily: it is only needed to drive demo() when this
    # module is executed as a script.
    import asyncio
    asyncio.run(demo())
================================================
FILE: src/vanna/examples/default_workflow_handler_example.py
================================================
"""
Example demonstrating the DefaultWorkflowHandler with setup health checking.
This example shows how the DefaultWorkflowHandler provides intelligent starter UI
that adapts based on available tools and helps users understand their setup status.
Run:
PYTHONPATH=. python vanna/examples/default_workflow_handler_example.py
"""
import asyncio
from vanna import (
AgentConfig,
Agent,
MemoryConversationStore,
MockLlmService,
User,
DefaultWorkflowHandler,
)
from vanna.core.registry import ToolRegistry
from vanna.core.user.resolver import SimpleUserResolver
from vanna.tools import ListFilesTool
def _print_banner(title: str, leading_newline: bool = True) -> None:
    """Print ``title`` framed by 60-character rules, optionally after a blank line."""
    rule = "=" * 60
    print(("\n" if leading_newline else "") + rule)
    print(title)
    print(rule)


def _register_mock_tool(registry: ToolRegistry, tool_name: str) -> None:
    """Register a ListFilesTool renamed to ``tool_name`` to simulate that tool."""
    tool = ListFilesTool()
    # NOTE(review): relies on ListFilesTool.name being assignable, exactly as
    # the original demo did; if `name` is a read-only property this raises.
    tool.name = tool_name
    registry.register(tool)


async def _print_starter_ui(
    agent: Agent, user_resolver: SimpleUserResolver, conversation_id: str
) -> None:
    """Request the starter UI from ``agent`` and print one summary line per component.

    Text is truncated to 100 characters. A component whose text or content is
    missing or empty is skipped, instead of raising on ``None[:100]``.
    """
    async for component in agent.send_message(
        request_context=user_resolver.create_request_context(
            metadata={"starter_ui_request": True}
        ),
        message="",
        conversation_id=conversation_id,
    ):
        simple = getattr(component, "simple_component", None)
        if simple:
            text = getattr(simple, "text", None)
            if text:
                print(f" 📄 {text[:100]}...")
        elif hasattr(component, "rich_component"):
            comp = component.rich_component
            if hasattr(comp, "title"):
                status = getattr(comp, "status", "")
                description = getattr(comp, "description", "")
                print(f" 📊 {comp.title}: {status} - {description}")
            elif hasattr(comp, "content"):
                content = comp.content
                if content:
                    print(f" 📝 {content[:100]}...")


async def demonstrate_setup_scenarios():
    """Demonstrate different setup scenarios with DefaultWorkflowHandler.

    Builds agents over progressively richer tool registries: none; SQL only;
    SQL + memory + visualization. It prints the starter UI each agent returns,
    then exercises the ``/help`` and ``/status`` commands. All agents share a
    single mock LLM, in-memory conversation store, and SimpleUserResolver.
    """
    print("🚀 Starting DefaultWorkflowHandler Setup Health Check Demo\n")

    # Create basic components
    llm_service = MockLlmService(response_content="I'm ready to help!")
    conversation_store = MemoryConversationStore()
    user_resolver = SimpleUserResolver()

    def make_agent(registry: ToolRegistry) -> Agent:
        # Scenarios differ only in which tools are registered.
        return Agent(
            llm_service=llm_service,
            tool_registry=registry,
            user_resolver=user_resolver,
            conversation_store=conversation_store,
            config=AgentConfig(stream_responses=False),
            workflow_handler=DefaultWorkflowHandler(),
        )

    _print_banner("SCENARIO 1: Empty Setup (No Tools)", leading_newline=False)
    agent_empty = make_agent(ToolRegistry())
    print("📋 Starter UI for empty setup:")
    await _print_starter_ui(agent_empty, user_resolver, "empty-setup")

    _print_banner("SCENARIO 2: Functional Setup (SQL + Basic Tools)")
    functional_registry = ToolRegistry()
    # Simulate a SQL tool by registering a tool under the name "run_sql".
    _register_mock_tool(functional_registry, "run_sql")
    agent_functional = make_agent(functional_registry)
    print("📋 Starter UI for functional setup:")
    await _print_starter_ui(agent_functional, user_resolver, "functional-setup")

    _print_banner("SCENARIO 3: Complete Setup (SQL + Memory + Visualization)")
    complete_registry = ToolRegistry()
    for tool_name in (
        "run_sql",  # SQL
        "search_saved_correct_tool_uses",  # memory (search)
        "save_question_tool_args",  # memory (save)
        "visualize_data",  # visualization
    ):
        _register_mock_tool(complete_registry, tool_name)
    agent_complete = make_agent(complete_registry)
    print("📋 Starter UI for complete setup:")
    await _print_starter_ui(agent_complete, user_resolver, "complete-setup")

    _print_banner("SCENARIO 4: Testing Commands")

    print("📋 Testing /help command:")
    async for component in agent_complete.send_message(
        request_context=user_resolver.create_request_context(),
        message="/help",
        conversation_id="help-test",
    ):
        rich = getattr(component, "rich_component", None)
        help_text = getattr(rich, "content", None)
        if help_text:
            print(f" 📝 Help: {help_text[:200]}...")

    print("\n📋 Testing /status command:")
    async for component in agent_complete.send_message(
        request_context=user_resolver.create_request_context(),
        message="/status",
        conversation_id="status-test",
    ):
        if not hasattr(component, "rich_component"):
            continue
        comp = component.rich_component
        if hasattr(comp, "title"):
            print(f" 📊 {comp.title}: {getattr(comp, 'status', '')}")
        elif getattr(comp, "content", None):
            print(f" 📝 Status: {comp.content[:150]}...")

    print("\n✅ Demo complete! The DefaultWorkflowHandler provides:")
    print(" • Smart setup health checking")
    print(" • Contextual starter UI based on available tools")
    print(" • Helpful error messages and setup guidance")
    print(" • Built-in command handling (/help, /status)")
    print(" • Automatic tool analysis and recommendations")
async def main() -> None:
    """Entry point: run every DefaultWorkflowHandler scenario in sequence."""
    return await demonstrate_setup_scenarios()
if __name__ == "__main__":
    # Script entry point: run the workflow-handler demo on a fresh event loop.
    asyncio.run(main())
================================================
FILE: src/vanna/examples/email_auth_example.py
================================================
"""
Email authentication example for the Vanna Agents framework.
This example demonstrates how to create an agent with email-based authentication
where users are prompted for their email address in chat and the system creates
a user profile based on that email.
## What This Example Shows
1. **UserService Implementation**: A demo `DemoEmailUserService` that:
- Stores users in memory
- Authenticates users by email validation
- Creates user profiles automatically
- Manages user permissions
2. **Authentication Tool**: An `AuthTool` that:
- Takes an email address as input
- Uses the UserService to authenticate/create users
- Returns rich UI components for success/error feedback
- Provides structured results for the LLM
3. **In-Chat Authentication Flow**: Shows how:
- Users can provide their email in natural conversation
- The agent can prompt for authentication when needed
- Authentication results are displayed with rich UI components
- The system maintains user context across conversations
## Key Components
- `DemoEmailUserService`: Implements the `UserService` interface
- `AuthTool`: Implements the `Tool` interface for authentication
- Rich UI components for authentication feedback
- Integration with the agent's tool registry and conversation store
## Usage
Interactive: python -m vanna.examples.email_auth_example
## Note
This example uses a simplified mock LLM that doesn't actually call tools.
In a real implementation with OpenAI or Anthropic, the LLM would automatically
detect email addresses in user messages and call the authenticate_user tool.
For production use, you would:
- Replace DemoEmailUserService with a database-backed implementation
- Add proper email validation and security measures
- Implement session management in the server layer
- Add proper error handling and rate limiting
"""
import asyncio
from typing import Any, Dict, Optional, Type
from pydantic import BaseModel, Field
from vanna import (
AgentConfig,
Agent,
MemoryConversationStore,
MockLlmService,
User,
)
from vanna.core import Tool, UserService
from vanna.core import ToolContext, ToolResult
from vanna.core.registry import ToolRegistry
from vanna.core.components import UiComponent
from vanna.core import RichComponent
# Demo User Service Implementation
class DemoEmailUserService(UserService):
    """Demo user service that authenticates users by email.

    Users are kept in memory for the lifetime of the instance. The first
    successful ``authenticate`` call for an address creates a user, and later
    calls with the same address return that same ``User`` object.
    """

    def __init__(self):
        """Initialize with an empty in-memory user store."""
        self._users: Dict[str, User] = {}  # user_id -> User
        self._email_to_id: Dict[str, str] = {}  # email -> user_id

    async def get_user(self, user_id: str) -> Optional[User]:
        """Return the user with ``user_id``, or ``None`` if unknown."""
        return self._users.get(user_id)

    async def authenticate(self, credentials: Dict[str, Any]) -> Optional[User]:
        """Authenticate (registering on first use) a user by email.

        Args:
            credentials: Mapping expected to contain an ``"email"`` string.

        Returns:
            The existing or newly created ``User``, or ``None`` when the email
            is missing, not a string, or fails validation.
        """
        email = credentials.get("email")
        # A non-string value previously raised TypeError inside validation.
        if not isinstance(email, str):
            return None
        # Surrounding whitespace from chat input is not part of the address.
        email = email.strip()
        if not self._is_valid_email(email):
            return None

        # Check if user exists
        user_id = self._email_to_id.get(email)
        if user_id:
            return self._users[user_id]

        # Create new user; ids stay unique because users are never removed.
        user_id = f"user_{len(self._users) + 1}"
        username = email.split("@")[0]
        user = User(
            id=user_id,
            username=username,
            email=email,
            permissions=["basic_user"],
            metadata={"auth_method": "email"},
        )
        self._users[user_id] = user
        self._email_to_id[email] = user_id
        return user

    async def has_permission(self, user: User, permission: str) -> bool:
        """Check if user has permission."""
        return permission in user.permissions

    def _is_valid_email(self, email: str) -> bool:
        """Lightweight structural email check (not full RFC 5322 validation).

        Requires exactly one ``@``, a non-empty local part, no whitespace, and
        a domain containing a dot that is neither its first nor last character.
        The old check accepted inputs such as "@example.com", "a@.com" and
        "a@b.c@d".
        """
        if any(ch.isspace() for ch in email) or email.count("@") != 1:
            return False
        local, domain = email.split("@")
        if not local or "." not in domain:
            return False
        return not (domain.startswith(".") or domain.endswith("."))
# Authentication Tool
class AuthArgs(BaseModel):
    """Input schema for the ``authenticate_user`` tool."""

    # Required field; the description is surfaced to the LLM in the tool schema.
    email: str = Field(..., description="User's email address")
class AuthTool(Tool[AuthArgs]):
    """Tool that signs a user in from an email address.

    Delegates to ``DemoEmailUserService.authenticate`` and reports the outcome
    twice: as text for the LLM and as a ``status_card`` rich component.
    """

    def __init__(self, user_service: DemoEmailUserService):
        self.user_service = user_service

    @property
    def name(self) -> str:
        return "authenticate_user"

    @property
    def description(self) -> str:
        return "Authenticate a user by their email address. Use this when the user provides an email."

    def get_args_schema(self) -> Type[AuthArgs]:
        return AuthArgs

    async def execute(self, context: ToolContext, args: AuthArgs) -> ToolResult:
        """Authenticate ``args.email`` and wrap the outcome for the LLM and UI."""
        user = await self.user_service.authenticate({"email": args.email})

        # The service only returns None for a missing/invalid email.
        if not user:
            error_msg = f"❌ Invalid email format: {args.email}"
            failure_card = RichComponent(
                type="status_card",
                data={
                    "title": "Authentication Failed",
                    "status": "error",
                    "description": error_msg,
                    "icon": "❌",
                    "metadata": {"email": args.email},
                },
            )
            return ToolResult(
                success=False,
                result_for_llm=f"Authentication failed for {args.email}. Please provide a valid email address.",
                ui_component=UiComponent(rich_component=failure_card),
                error=error_msg,
            )

        success_msg = (
            f"✅ Welcome {user.username}! You're now authenticated as {user.email}"
        )
        success_card = RichComponent(
            type="status_card",
            data={
                "title": "Authentication Success",
                "status": "success",
                "description": success_msg,
                "icon": "✅",
                "metadata": {
                    "user_id": user.id,
                    "username": user.username,
                    "email": user.email,
                },
            },
        )
        return ToolResult(
            success=True,
            result_for_llm=f"User successfully authenticated as {user.username} ({user.email}). They can now access personalized features.",
            ui_component=UiComponent(rich_component=success_card),
        )
def create_demo_agent() -> Agent:
    """Entry point used for REPL and server usage.

    Returns:
        The email-authentication agent produced by ``create_auth_agent``.
    """
    agent = create_auth_agent()
    return agent
def create_auth_agent() -> Agent:
    """Build an agent whose only tool authenticates users by email.

    The LLM is a ``MockLlmService`` that always returns the same greeting, so
    the system prompt here documents the intended behavior rather than
    driving a real model. Conversations are kept in a MemoryConversationStore.
    """
    users = DemoEmailUserService()

    # Fixed reply from the mock LLM.
    mock_llm = MockLlmService(
        response_content="Hello! I'm your AI assistant. To provide you with personalized help, I'll need your email address for authentication. Please share your email with me, and I'll use the authenticate_user tool to set up your profile."
    )

    registry = ToolRegistry()
    registry.register(AuthTool(users))

    auth_prompt = """You are a helpful AI assistant with an email-based authentication system.
AUTHENTICATION BEHAVIOR:
1. When a user provides an email address in their message, immediately use the 'authenticate_user' tool
2. Look for emails in patterns like "my email is...", "I'm john@example.com", or any text with @ symbols
3. If user isn't authenticated, politely ask for their email address to get started
4. After successful authentication, welcome them by name and offer personalized assistance
5. Be friendly and helpful throughout the process
Remember: Authentication is required for personalized features!"""

    settings = AgentConfig(
        stream_responses=True,
        include_thinking_indicators=False,  # Cleaner output for demo
        system_prompt=auth_prompt,
    )
    return Agent(
        llm_service=mock_llm,
        config=settings,
        tool_registry=registry,
        conversation_store=MemoryConversationStore(),
    )
async def _print_first_text_reply(stream) -> None:
    """Print the first non-empty text payload from an agent stream, then stop reading it.

    The payload comes from ``rich_component.data["content"]``, falling back to
    a ``content`` attribute on the rich component.
    """
    async for component in stream:
        if not hasattr(component, "rich_component"):
            continue
        rich = component.rich_component
        if rich.type.value != "text":
            continue
        text = rich.data.get("content") or getattr(rich, "content", "")
        if text:
            print(text)
            return


async def _print_auth_success_card(stream) -> None:
    """Print the description of the first successful ``status_card`` in a stream, then stop."""
    async for component in stream:
        if not hasattr(component, "rich_component"):
            continue
        card = component.rich_component
        if card.type.value != "status_card":
            continue
        if card.data.get("status", "") == "success":
            print(f"🔐 {card.data.get('description', '')}")
            return


async def demo_auth_flow():
    """Demonstrate the authentication flow with simple output.

    Runs three chat turns in one conversation:
    1. a greeting;
    2. a message containing an email address, which should trigger the
       ``authenticate_user`` tool;
    3. a follow-up question.

    Only the first relevant component of each reply is printed.
    NOTE(review): the same anonymous ``User`` is sent on every turn; nothing
    here swaps in the authenticated user.
    """
    agent = create_auth_agent()
    guest = User(id="anonymous", username="guest", email=None, permissions=[])
    conversation_id = "auth_demo_conv"
    divider = "\n" + "=" * 60

    print("=== Email Authentication Demo ===")
    print("This example shows how an agent can authenticate users via email in chat.")
    print("Note: This uses a simple mock LLM for demonstration purposes.\n")

    # Turn 1: plain greeting.
    print("🔹 Step 1: Initial greeting")
    print("User: Hello!")
    print("Agent: ", end="")
    await _print_first_text_reply(
        agent.send_message(user=guest, message="Hello!", conversation_id=conversation_id)
    )
    print(divider)

    # Turn 2: the email address here should trigger the auth tool.
    print("\n🔹 Step 2: User provides email for authentication")
    print("User: My email is alice@example.com")
    print("Agent: ", end="")
    await _print_auth_success_card(
        agent.send_message(
            user=guest,
            message="My email is alice@example.com",
            conversation_id=conversation_id,
        )
    )
    print(divider)

    # Turn 3: follow-up in the same conversation.
    print("\n🔹 Step 3: Post-authentication interaction")
    print("User: What can you help me with now?")
    print("Agent: ", end="")
    await _print_first_text_reply(
        agent.send_message(
            user=guest,
            message="What can you help me with now?",
            conversation_id=conversation_id,
        )
    )
    print(divider)

    print("\n✅ Authentication demo complete!")
    print("\nKey Features Demonstrated:")
    for feature in (
        "Email-based user authentication",
        "Tool-based authentication flow",
        "In-memory user storage and management",
        "Rich UI components for auth feedback",
    ):
        print(f"• {feature}")
async def main():
    """Entry coroutine: runs the email-authentication walkthrough."""
    await demo_auth_flow()


def run_interactive():
    """Synchronous launcher: prints a banner, then drives ``main`` with ``asyncio.run``."""
    print("Starting email authentication example...")
    coroutine = main()
    asyncio.run(coroutine)


if __name__ == "__main__":
    run_interactive()
================================================
FILE: src/vanna/examples/evaluation_example.py
================================================
"""
Evaluation System Example
This example demonstrates how to use the evaluation framework to test
and compare agents. Shows:
- Creating test cases programmatically
- Running evaluations with multiple evaluators
- Comparing agent variants (e.g., different LLMs)
- Generating reports
Usage:
PYTHONPATH=. python vanna/examples/evaluation_example.py
"""
import asyncio
from vanna import Agent, MockLlmService, MemoryConversationStore, User
from vanna.core.evaluation import (
EvaluationRunner,
EvaluationDataset,
TestCase,
ExpectedOutcome,
AgentVariant,
TrajectoryEvaluator,
OutputEvaluator,
EfficiencyEvaluator,
)
from vanna.core.registry import ToolRegistry
def create_sample_dataset() -> EvaluationDataset:
    """Build a fixed three-case dataset for the evaluation demos.

    All cases share one evaluator ``User``. Each case carries ``category`` and
    ``difficulty`` metadata so it can be filtered later.

    Returns:
        An ``EvaluationDataset`` named "Demo Test Cases".
    """
    evaluator = User(
        id="eval_user",
        username="evaluator",
        email="eval@example.com",
        permissions=[],
    )

    greeting_case = TestCase(
        id="test_001",
        user=evaluator,
        message="Hello, how are you?",
        expected_outcome=ExpectedOutcome(
            final_answer_contains=["hello", "hi"],
            max_execution_time_ms=3000,
        ),
        metadata={"category": "greeting", "difficulty": "easy"},
    )
    capabilities_case = TestCase(
        id="test_002",
        user=evaluator,
        message="What can you help me with?",
        expected_outcome=ExpectedOutcome(
            final_answer_contains=["help", "assist"],
            max_execution_time_ms=3000,
        ),
        metadata={"category": "capabilities", "difficulty": "easy"},
    )
    explanation_case = TestCase(
        id="test_003",
        user=evaluator,
        message="Explain quantum computing",
        expected_outcome=ExpectedOutcome(
            final_answer_contains=["quantum", "computing"],
            min_components=1,
            max_execution_time_ms=5000,
        ),
        metadata={"category": "explanation", "difficulty": "medium"},
    )

    return EvaluationDataset(
        name="Demo Test Cases",
        test_cases=[greeting_case, capabilities_case, explanation_case],
        description="Sample test cases for evaluation demo",
    )
def create_test_agent(name: str, response_content: str) -> Agent:
    """Create an agent backed by a ``MockLlmService`` configured with ``response_content``.

    Args:
        name: Label for the agent. It is not currently used when building the agent.
        response_content: Canned text handed to the mock LLM.

    Returns:
        A fresh ``Agent`` with an empty tool registry and an in-memory conversation store.
    """
    mock_llm = MockLlmService(response_content=response_content)
    return Agent(
        llm_service=mock_llm,
        tool_registry=ToolRegistry(),
        conversation_store=MemoryConversationStore(),
    )
async def demo_single_agent_evaluation():
    """Demo 1: evaluate one mock agent against the sample dataset and list any failures.

    Uses trajectory, output and efficiency (5000 ms budget) evaluators, with at
    most 5 test cases in flight.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("DEMO 1: Single Agent Evaluation")
    print(banner + "\n")

    dataset = create_sample_dataset()
    print(f"Loaded dataset: {dataset.name}")
    print(f"Test cases: {len(dataset.test_cases)}\n")

    agent = create_test_agent(
        "test-agent",
        "Hello! I'm here to help you with various tasks including answering questions about topics like quantum computing.",
    )
    runner = EvaluationRunner(
        evaluators=[
            TrajectoryEvaluator(),
            OutputEvaluator(),
            EfficiencyEvaluator(max_execution_time_ms=5000),
        ],
        max_concurrency=5,
    )

    print("Running evaluation...")
    report = await runner.run_evaluation(agent, dataset.test_cases)
    report.print_summary()

    failures = report.get_failures()
    if not failures:
        return
    print("\nFailed test cases:")
    for failed in failures:
        case = failed.test_case
        print(f" - {case.id}: {case.message}")
async def demo_agent_comparison():
    """Demo 2: run three mock agent variants on the same dataset and compare them.

    Each variant wraps a ``MockLlmService`` with a different canned reply and
    carries version/model metadata. After the summary, the best variant by
    score, speed and pass rate is printed.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("DEMO 2: Agent Comparison (LLM Comparison Use Case)")
    print(banner + "\n")

    dataset = create_sample_dataset()
    print(f"Loaded dataset: {dataset.name}")
    print(f"Test cases: {len(dataset.test_cases)}\n")

    # (variant name, agent label, canned reply, metadata)
    variant_specs = [
        (
            "agent-v1",
            "v1",
            "Hi there! I can help you with many things including explaining complex topics like quantum computing.",
            {"version": "1.0", "model": "mock-v1"},
        ),
        (
            "agent-v2",
            "v2",
            "Hello! I'm your helpful assistant. I can assist with various tasks and explain topics like quantum computing in detail.",
            {"version": "2.0", "model": "mock-v2"},
        ),
        (
            "agent-v3",
            "v3",
            "Greetings! I'm designed to help you with a wide range of tasks, from simple questions to complex explanations about quantum computing and more.",
            {"version": "3.0", "model": "mock-v3"},
        ),
    ]
    variants = [
        AgentVariant(
            name=variant_name,
            agent=create_test_agent(label, reply),
            metadata=meta,
        )
        for variant_name, label, reply, meta in variant_specs
    ]

    print(f"Created {len(variants)} agent variants:")
    for variant in variants:
        print(f" - {variant.name}")
    print()

    runner = EvaluationRunner(
        evaluators=[
            OutputEvaluator(),
            EfficiencyEvaluator(max_execution_time_ms=5000),
        ],
        max_concurrency=10,
    )
    print(
        f"Running comparison ({len(variants)} variants × {len(dataset.test_cases)} test cases)..."
    )
    print("All variants running in parallel for maximum efficiency...\n")
    comparison = await runner.compare_agents(variants, dataset.test_cases)
    comparison.print_summary()

    print("Best Performing Variants:")
    for prefix, metric in (
        (" 🏆 Best score", "score"),
        (" ⚡ Fastest", "speed"),
        (" ✅ Best pass rate", "pass_rate"),
    ):
        print(f"{prefix}: {comparison.get_best_variant(metric)}")
async def demo_dataset_operations():
    """Demo 3: inspect and filter the sample dataset, then round-trip it through YAML and JSON.

    The files are written to a ``tempfile.TemporaryDirectory``. They are read
    back before the directory is removed.
    """
    import os
    import tempfile

    banner = "=" * 80
    print("\n" + banner)
    print("DEMO 3: Dataset Operations")
    print(banner + "\n")

    dataset = create_sample_dataset()
    print(f"Dataset: {dataset.name}")
    print(f"Description: {dataset.description}")
    print(f"Total test cases: {len(dataset)}\n")

    counts = {
        level: len(dataset.filter_by_metadata(difficulty=level))
        for level in ("easy", "medium")
    }
    print(f"Easy test cases: {counts['easy']}")
    print(f"Medium test cases: {counts['medium']}\n")

    with tempfile.TemporaryDirectory() as workdir:
        yaml_file = os.path.join(workdir, "dataset.yaml")
        json_file = os.path.join(workdir, "dataset.json")
        dataset.save_yaml(yaml_file)
        dataset.save_json(json_file)
        print("Dataset saved to temporary files:")
        print(f" - YAML: {yaml_file}")
        print(f" - JSON: {json_file}\n")

        # Reload while the temporary directory still exists.
        from_yaml = EvaluationDataset.from_yaml(yaml_file)
        from_json = EvaluationDataset.from_json(json_file)
        print("Loaded datasets:")
        print(f" - From YAML: {len(from_yaml)} test cases")
        print(f" - From JSON: {len(from_json)} test cases")
async def main():
    """Run the three evaluation demos in order, then print a recap."""
    print("\n🚀 Vanna Agents Evaluation System Demo")
    print("=" * 80)

    # 1: single agent, 2: agent comparison (main use case), 3: dataset operations.
    for demo in (
        demo_single_agent_evaluation,
        demo_agent_comparison,
        demo_dataset_operations,
    ):
        await demo()

    print("\n" + "=" * 80)
    print("✅ All demos completed!")
    print("=" * 80)
    print("\nKey Takeaways:")
    takeaways = [
        "Evaluations are integral to the Vanna package",
        "Parallel execution handles I/O-bound LLM calls efficiently",
        "Agent comparison is a first-class use case",
        "Multiple evaluators can be composed for comprehensive testing",
        "Reports can be exported to HTML, CSV, or printed to console",
    ]
    for number, takeaway in enumerate(takeaways, start=1):
        print(f" {number}. {takeaway}")
    print("\nFor LLM comparison, see: evals/benchmarks/llm_comparison.py")
    print()


if __name__ == "__main__":
    asyncio.run(main())
================================================
FILE: src/vanna/examples/extensibility_example.py
================================================
"""
Comprehensive example demonstrating all extensibility interfaces.
This example shows how to use:
- LlmMiddleware for caching
- ErrorRecoveryStrategy for retry logic
- ToolContextEnricher for adding user preferences
- ConversationFilter for context window management
- ObservabilityProvider for monitoring
"""
import asyncio
import time
from typing import Any, Dict, List, Optional
from vanna.core import (
Agent,
LlmMiddleware,
ErrorRecoveryStrategy,
ToolContextEnricher,
ConversationFilter,
ObservabilityProvider,
User,
ToolContext,
Conversation,
Message,
LlmRequest,
LlmResponse,
Span,
Metric,
)
from vanna.core.recovery import RecoveryAction, RecoveryActionType
from vanna.core.registry import ToolRegistry
# 1. LlmMiddleware Example: Simple Caching
class CachingMiddleware(LlmMiddleware):
    """Cache LLM responses to reduce costs and latency.

    Responses are stored in ``after_llm_response``, keyed by the request's
    message (role, content) pairs plus its temperature.

    ``before_llm_request`` only records cache hits and returns the request
    unchanged. As written, the LLM is therefore still called on a hit.
    NOTE(review): actually serving the cached response would need a hook that
    can return a response instead of a request; confirm against the
    ``LlmMiddleware`` contract.

    Args:
        max_entries: Optional cap on the number of cached responses. Once the
            cap is reached, the oldest entry is evicted before a new one is
            stored. ``None`` (the default) keeps the cache unbounded.

    Raises:
        ValueError: If ``max_entries`` is given and is smaller than 1.
    """

    def __init__(self, max_entries: Optional[int] = None) -> None:
        if max_entries is not None and max_entries < 1:
            raise ValueError("max_entries must be None or a positive integer")
        self.cache: Dict[str, LlmResponse] = {}
        self.max_entries = max_entries
        self.hits = 0
        self.misses = 0

    def _compute_cache_key(self, request: LlmRequest) -> str:
        """Create cache key from request (message roles/contents and temperature)."""
        messages_str = str([(m.role, m.content) for m in request.messages])
        return f"{messages_str}:{request.temperature}"

    def _evict_if_full(self) -> None:
        """Drop the oldest entries until one more fits (no-op when the cache is unbounded)."""
        if self.max_entries is None:
            return
        while self.cache and len(self.cache) >= self.max_entries:
            # dicts preserve insertion order, so the first key is the oldest entry.
            del self.cache[next(iter(self.cache))]

    async def before_llm_request(self, request: LlmRequest) -> LlmRequest:
        """Count a cache hit if this request was seen before; the request passes through unchanged."""
        if self._compute_cache_key(request) in self.cache:
            self.hits += 1
            print(f"[CACHE HIT] Cache stats: {self.hits} hits, {self.misses} misses")
        return request

    async def after_llm_response(
        self, request: LlmRequest, response: LlmResponse
    ) -> LlmResponse:
        """Store the response on first sight of its key (a miss) and return it unchanged."""
        cache_key = self._compute_cache_key(request)
        if cache_key not in self.cache:
            self._evict_if_full()
            self.cache[cache_key] = response
            self.misses += 1
            print("[CACHE MISS] Caching response")
        return response
# 2. ErrorRecoveryStrategy Example: Exponential Backoff
class ExponentialBackoffStrategy(ErrorRecoveryStrategy):
    """Retry failed operations with exponential backoff.

    The delay before retrying after ``attempt`` is
    ``base_delay_ms * 2 ** (attempt - 1)``. With the default base that is
    1000 ms, 2000 ms, 4000 ms, and so on.

    A RETRY action is returned while ``attempt < max_retries``. After that a
    FAIL action is returned. The same policy applies to tool and LLM errors.

    Args:
        max_retries: Attempt number at which retrying stops.
        base_delay_ms: Delay before the first retry; doubled on each later attempt.
    """

    def __init__(self, max_retries: int = 3, base_delay_ms: int = 1000) -> None:
        self.max_retries = max_retries
        self.base_delay_ms = base_delay_ms

    def _retry_delay_ms(self, attempt: int) -> Optional[int]:
        """Return the backoff delay for ``attempt``, or ``None`` once retries are exhausted."""
        if attempt >= self.max_retries:
            return None
        # Clamp the exponent so a non-positive attempt number cannot produce a
        # fractional delay (2 ** -1 == 0.5).
        return self.base_delay_ms * (2 ** max(attempt - 1, 0))

    async def handle_tool_error(
        self, error: Exception, context: ToolContext, attempt: int = 1
    ) -> RecoveryAction:
        """Retry tool errors with exponential backoff; FAIL once retries are exhausted."""
        delay_ms = self._retry_delay_ms(attempt)
        if delay_ms is not None:
            print(
                f"[RETRY] Tool failed, retrying in {delay_ms}ms (attempt {attempt}/{self.max_retries})"
            )
            return RecoveryAction(
                action=RecoveryActionType.RETRY,
                retry_delay_ms=delay_ms,
                message=f"Retrying after {delay_ms}ms",
            )
        print(f"[FAIL] Max retries exceeded for tool error: {error}")
        return RecoveryAction(
            action=RecoveryActionType.FAIL,
            message=f"Tool error after {self.max_retries} attempts: {error!s}",
        )

    async def handle_llm_error(
        self, error: Exception, request: LlmRequest, attempt: int = 1
    ) -> RecoveryAction:
        """Retry LLM errors with exponential backoff; FAIL once retries are exhausted."""
        delay_ms = self._retry_delay_ms(attempt)
        if delay_ms is not None:
            print(
                f"[RETRY] LLM failed, retrying in {delay_ms}ms (attempt {attempt}/{self.max_retries})"
            )
            return RecoveryAction(
                action=RecoveryActionType.RETRY,
                retry_delay_ms=delay_ms,
                message=f"Retrying LLM after {delay_ms}ms",
            )
        print(f"[FAIL] Max retries exceeded for LLM error: {error}")
        return RecoveryAction(
            action=RecoveryActionType.FAIL,
            message=f"LLM error after {self.max_retries} attempts: {error!s}",
        )
# 3. ToolContextEnricher Example: Add User Preferences
class UserPreferencesEnricher(ToolContextEnricher):
    """Enrich context with user preferences.

    Preferences come from an in-memory mock table keyed by user id.
    """

    def __init__(self) -> None:
        # Mock user preferences database, keyed by user id.
        self.preferences: Dict[str, Dict[str, Any]] = {
            "user123": {
                "timezone": "America/New_York",
                "language": "en",
                "theme": "dark",
            }
        }

    async def enrich_context(self, context: ToolContext) -> ToolContext:
        """Add the user's preferences and timezone (default ``"UTC"``) to ``context.metadata``.

        A shallow copy of the stored preferences is attached. Code that later
        mutates ``context.metadata["user_preferences"]`` therefore cannot
        corrupt the shared preferences table.
        """
        prefs = dict(self.preferences.get(context.user.id, {}))
        context.metadata["user_preferences"] = prefs
        context.metadata["timezone"] = prefs.get("timezone", "UTC")
        print(f"[ENRICH] Added preferences for user {context.user.id}: {prefs}")
        return context
# 4. ConversationFilter Example: Context Window Management
class ContextWindowFilter(ConversationFilter):
    """Limit conversation to fit within context window.

    System messages are always kept and placed first. The rest of the
    ``max_messages`` budget is filled with the most recent non-system messages.

    If the system messages alone exceed ``max_messages``, they are all still
    kept and no other messages survive.
    """

    def __init__(self, max_messages: int = 20) -> None:
        self.max_messages = max_messages

    async def filter_messages(self, messages: List[Message]) -> List[Message]:
        """Keep only recent messages within limit.

        Returns ``messages`` unchanged when it already fits. Otherwise returns
        a new list: all system messages, followed by the newest non-system
        messages. The total does not exceed ``max_messages`` unless the
        system messages alone do.
        """
        if len(messages) <= self.max_messages:
            return messages

        # Keep system messages and recent messages
        system_messages = [m for m in messages if m.role == "system"]
        other_messages = [m for m in messages if m.role != "system"]

        # Whatever budget remains after the always-kept system messages goes to
        # the newest other messages. Compute an explicit start index: a negative
        # slice such as seq[-0:] would return the WHOLE list when the budget is 0.
        budget = max(self.max_messages - len(system_messages), 0)
        start = max(len(other_messages) - budget, 0)
        recent_messages = other_messages[start:]

        filtered = system_messages + recent_messages
        print(f"[FILTER] Reduced {len(messages)} messages to {len(filtered)}")
        return filtered
# 5. ObservabilityProvider Example: Simple Logging
class LoggingObservabilityProvider(ObservabilityProvider):
    """Keep metrics and finished spans in memory and echo each one to stdout."""

    def __init__(self) -> None:
        self.metrics: List[Metric] = []
        self.spans: List[Span] = []

    async def record_metric(
        self,
        name: str,
        value: float,
        unit: str = "",
        tags: Optional[Dict[str, str]] = None,
    ) -> None:
        """Store a ``Metric`` and print it as ``[METRIC] name: value<unit> k=v, ...``."""
        tag_map = tags or {}
        self.metrics.append(Metric(name=name, value=value, unit=unit, tags=tag_map))
        rendered_tags = ", ".join(f"{key}={val}" for key, val in tag_map.items())
        print(f"[METRIC] {name}: {value}{unit} {rendered_tags}")

    async def create_span(
        self, name: str, attributes: Optional[Dict[str, Any]] = None
    ) -> Span:
        """Start a new ``Span`` (attributes default to empty) and announce it."""
        new_span = Span(name=name, attributes=attributes or {})
        print(f"[SPAN START] {name}")
        return new_span

    async def end_span(self, span: Span) -> None:
        """Close ``span``, keep it, and print its duration (0 if unavailable)."""
        span.end()
        self.spans.append(span)
        elapsed_ms = span.duration_ms() or 0
        print(f"[SPAN END] {span.name}: {elapsed_ms:.2f}ms")
async def run_example() -> None:
    """
    Wire every extensibility hook into one ``Agent`` and print what was registered.

    The agent uses ``AnthropicLlmService`` with a placeholder API key and an
    in-line stub conversation store. No messages are sent.
    """
    from vanna.integrations.anthropic import AnthropicLlmService

    middleware = CachingMiddleware()
    recovery = ExponentialBackoffStrategy(max_retries=3)
    enricher = UserPreferencesEnricher()
    trimmer = ContextWindowFilter(max_messages=20)
    observer = LoggingObservabilityProvider()

    class MockStore:
        """Stand-in conversation store: nothing is persisted and lookups find nothing."""

        async def get_conversation(self, cid: str, uid: str) -> Optional[Conversation]:
            return None

        async def create_conversation(
            self, cid: str, uid: str, title: str
        ) -> Conversation:
            first_message = Message(role="user", content=title)
            return Conversation(id=cid, user_id=uid, messages=[first_message])

        async def update_conversation(self, conv: Conversation) -> None:
            return None

        async def delete_conversation(self, cid: str, uid: str) -> bool:
            return False

        async def list_conversations(
            self, uid: str, limit: int = 50, offset: int = 0
        ) -> List[Conversation]:
            return []

    agent = Agent(
        llm_service=AnthropicLlmService(api_key="test-key"),
        tool_registry=ToolRegistry(),
        conversation_store=MockStore(),  # type: ignore
        llm_middlewares=[middleware],
        error_recovery_strategy=recovery,
        context_enrichers=[enricher],
        conversation_filters=[trimmer],
        observability_provider=observer,
    )

    print("✓ Agent created with all extensibility components:")
    print(f" - LLM Middleware: {len(agent.llm_middlewares)} middlewares")
    print(f" - Error Recovery: {type(agent.error_recovery_strategy).__name__}")
    print(f" - Context Enrichers: {len(agent.context_enrichers)} enrichers")
    print(f" - Conversation Filters: {len(agent.conversation_filters)} filters")
    print(f" - Observability: {type(agent.observability_provider).__name__}")
    print("\n🎉 All extensibility interfaces integrated successfully!")


if __name__ == "__main__":
    asyncio.run(run_example())
================================================
FILE: src/vanna/examples/minimal_example.py
================================================
"""Minimal Claude + SQLite example ready for FastAPI."""
from __future__ import annotations
import os
from pathlib import Path
from vanna import AgentConfig, Agent
from vanna.core.registry import ToolRegistry
from vanna.integrations.anthropic import AnthropicLlmService
from vanna.integrations.sqlite import SqliteRunner
from vanna.integrations.local import LocalFileSystem
from vanna.tools import (
RunSqlTool,
# Visualization
VisualizeDataTool,
# Python execution
RunPythonFileTool,
PipInstallTool,
# File system (for coding agents)
SearchFilesTool,
ListFilesTool,
ReadFileTool,
WriteFileTool,
)
_DB = Path(__file__).resolve().parents[2] / "Chinook.sqlite"


def create_demo_agent() -> Agent:
    """Build a Claude-backed agent that can query the Chinook SQLite database.

    Environment variables are first loaded from a ``.env`` file. The model
    name comes from ``ANTHROPIC_MODEL`` and defaults to ``claude-sonnet-4-5``.
    SQL results and charts share one ``LocalFileSystem`` rooted at
    ``./claude_data``.

    Further tool tiers are listed as comments but left disabled: Python
    execution, and file-system tools for coding agents.

    Returns:
        An ``Agent`` with the SQL and visualization tools registered.
    """
    from dotenv import load_dotenv

    load_dotenv()

    model_name = os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-5")
    llm = AnthropicLlmService(model=model_name)

    # One file system shared by every tool.
    shared_fs = LocalFileSystem("./claude_data")
    registry = ToolRegistry()

    # Tier 1: query the database with SQL.
    sqlite_runner = SqliteRunner(database_path=str(_DB))
    registry.register(RunSqlTool(sql_runner=sqlite_runner, file_system=shared_fs))

    # Tier 2: turn query results into charts.
    registry.register(VisualizeDataTool(file_system=shared_fs))

    # Tier 3 (opt-in): Python execution for dashboards with artifacts.
    # registry.register(RunPythonFileTool(file_system=shared_fs))
    # registry.register(PipInstallTool(file_system=shared_fs))

    # Tier 4 (opt-in): full coding agent that reads, writes and searches files.
    # registry.register(SearchFilesTool(file_system=shared_fs))
    # registry.register(ListFilesTool(file_system=shared_fs))
    # registry.register(ReadFileTool(file_system=shared_fs))
    # registry.register(WriteFileTool(file_system=shared_fs))

    return Agent(llm_service=llm, tool_registry=registry)
================================================
FILE: src/vanna/examples/mock_auth_example.py
================================================
"""
Mock authentication example to verify user resolution is working.
This example demonstrates the new UserResolver architecture where:
1. UserResolver is a required parameter of Agent
2. Agent.send_message() accepts RequestContext (not User directly)
3. The Agent resolves the user internally using the UserResolver
The agent uses an LLM middleware to inject user info into the response,
so we can verify the authentication is working correctly.
Usage:
python -m vanna.examples.mock_auth_example
"""
from __future__ import annotations
import asyncio
from vanna import AgentConfig, Agent
from vanna.core.registry import ToolRegistry
from vanna.core.llm import LlmRequest, LlmResponse
from vanna.core.middleware import LlmMiddleware
from vanna.integrations.mock import MockLlmService
from vanna.core.user import CookieEmailUserResolver, RequestContext
class UserEchoMiddleware(LlmMiddleware):
    """Middleware that replaces the LLM's reply text with the requesting user's id.

    This lets the demo verify that the agent resolved the user: the reply text
    becomes ``"Hello! You are authenticated as: <request.user_id>"``.
    NOTE(review): with ``CookieEmailUserResolver`` the id presumably carries
    the email; confirm against the resolver.
    """

    async def after_llm_response(
        self, request: LlmRequest, response: LlmResponse
    ) -> LlmResponse:
        """Return a copy of ``response`` whose text echoes ``request.user_id``.

        ``finish_reason``, ``usage`` and any ``tool_calls`` from the original
        response are carried over. Otherwise tool invocations requested by the
        model would be silently dropped.
        """
        # request.user_id is set by the agent to the resolved user's id.
        new_content = f"Hello! You are authenticated as: {request.user_id}"
        return LlmResponse(
            content=new_content,
            tool_calls=response.tool_calls,
            finish_reason=response.finish_reason,
            usage=response.usage,
        )
def create_demo_agent() -> Agent:
    """Create a demo agent for server usage.

    The agent combines:
    - a mock LLM and an empty tool registry;
    - users resolved from the ``vanna_email`` cookie;
    - ``UserEchoMiddleware``, which rewrites the LLM reply text.

    Streaming is on and thinking indicators are off.

    Returns:
        Configured Agent instance with cookie-based authentication
    """
    return Agent(
        llm_service=MockLlmService(response_content="Mock response"),
        tool_registry=ToolRegistry(),
        user_resolver=CookieEmailUserResolver(cookie_name="vanna_email"),
        llm_middlewares=[UserEchoMiddleware()],
        config=AgentConfig(
            stream_responses=True,
            include_thinking_indicators=False,
        ),
    )
def _describe_request_context(request_context: RequestContext) -> dict:
    """Return the cookies, headers and remote address of a request context for display."""
    return {
        "cookies": request_context.cookies,
        "headers": request_context.headers,
        "remote_addr": request_context.remote_addr,
    }


async def _run_who_am_i(
    agent: Agent, request_context: RequestContext, conversation_id: str
) -> None:
    """Send "Who am I?" under ``request_context`` and print the reply and the resolved user.

    The last text component's ``content`` becomes the printed reply. The user
    is then resolved again from the same context through
    ``agent.user_resolver`` to show who the agent saw.
    """
    print("Request context:", _describe_request_context(request_context))

    agent_response = ""
    async for component in agent.send_message(
        request_context=request_context,
        message="Who am I?",
        conversation_id=conversation_id,
    ):
        rich = getattr(component, "rich_component", None)
        # Read ``content`` directly from text components (before serialization).
        if rich is not None and rich.type.value == "text" and hasattr(rich, "content"):
            agent_response = rich.content
    print(f"Agent response: {agent_response}")

    # Re-resolve the same context through the agent's resolver.
    resolved_user = await agent.user_resolver.resolve_user(request_context)
    print(
        f"✅ Resolved user: {resolved_user.email or 'None'} "
        f"(username: {resolved_user.username}, id: {resolved_user.id})"
    )
    print(f" Permissions: {resolved_user.permissions}")
    print(f" Metadata: {resolved_user.metadata}")


async def demo_authentication():
    """Demonstrate authentication with different request contexts.

    Three scenarios run against one agent:
    1. a cookie for alice@example.com;
    2. no cookie (anonymous);
    3. a cookie for bob@company.com with extra headers.
    """
    agent = create_demo_agent()
    print("=== Mock Authentication Demo ===")
    print("This example verifies that user resolution is working correctly.\n")

    # Test 1: Request with email cookie
    print("🔹 Test 1: Authenticated user (alice@example.com)")
    await _run_who_am_i(
        agent,
        RequestContext(
            cookies={"vanna_email": "alice@example.com"},
            headers={},
            remote_addr="127.0.0.1",
        ),
        "test_conv_1",
    )
    print("\n" + "=" * 60 + "\n")

    # Test 2: Request without email cookie (anonymous)
    print("🔹 Test 2: Anonymous user (no cookie)")
    await _run_who_am_i(
        agent,
        RequestContext(cookies={}, headers={}, remote_addr="127.0.0.1"),
        "test_conv_2",
    )
    print("\n" + "=" * 60 + "\n")

    # Test 3: Different user
    print("🔹 Test 3: Different authenticated user (bob@company.com)")
    await _run_who_am_i(
        agent,
        RequestContext(
            cookies={"vanna_email": "bob@company.com"},
            headers={"User-Agent": "Mozilla/5.0"},
            remote_addr="192.168.1.100",
        ),
        "test_conv_3",
    )
    print("\n" + "=" * 60)

    print("\n✅ Authentication demo complete!")
    print("\nKey Features Verified:")
    print("• UserResolver is part of Agent")
    print("• Agent.send_message() accepts RequestContext")
    print("• User resolution happens internally in Agent")
    print("• CookieEmailUserResolver extracts email from vanna_email cookie")
    print("• Anonymous users are created when no cookie is present")
    print("• Different users can be resolved from different request contexts")
async def main():
    """Entry coroutine: runs the mock authentication walkthrough."""
    await demo_authentication()


def run_interactive():
    """Synchronous launcher: prints a banner, then drives ``main`` with ``asyncio.run``."""
    print("Starting mock authentication example...")
    coroutine = main()
    asyncio.run(coroutine)


if __name__ == "__main__":
    run_interactive()
================================================
FILE: src/vanna/examples/mock_custom_tool.py
================================================
"""
Mock example showing how to create and use custom tools.
This example demonstrates creating a simple calculator tool
and registering it with an agent that uses a mock LLM service.
It now includes a `MockCalculatorLlmService` that automatically
invokes the calculator tool with random numbers before echoing
back the computed answer.
Usage:
Template: Copy this file and modify for your custom tools
Interactive: python -m vanna.examples.mock_custom_tool
REPL: from vanna.examples.mock_custom_tool import create_demo_agent
Server: python -m vanna.servers --example mock_custom_tool
"""
import asyncio
import random
import uuid
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Type
from pydantic import BaseModel, Field
from vanna import (
AgentConfig,
Agent,
Tool,
ToolContext,
ToolRegistry,
ToolResult,
User,
UiComponent,
)
from vanna.core.interfaces import LlmService
from vanna.core.models import (
LlmRequest,
LlmResponse,
LlmStreamChunk,
ToolCall,
ToolSchema,
)
from vanna.core.rich_components import (
CardComponent,
NotificationComponent,
ComponentType,
)
from vanna.core.simple_components import (
SimpleTextComponent,
)
class CalculatorArgs(BaseModel):
    """Arguments for the calculator tool."""

    # Operation name; CalculatorTool accepts add, subtract, multiply and divide.
    operation: str = Field(
        ...,
        description="The operation to perform: add, subtract, multiply, divide",
    )
    # Left-hand operand.
    a: float = Field(..., description="First number")
    # Right-hand operand.
    b: float = Field(..., description="Second number")
class CalculatorTool(Tool[CalculatorArgs]):
    """A simple calculator tool.

    Supports add, subtract, multiply and divide. Every outcome, success or
    failure, is returned after a 3-second ``asyncio.sleep``. NOTE(review): the
    delay presumably simulates latency for the demo UI; confirm before reusing.
    """

    @property
    def name(self) -> str:
        return "calculator"

    @property
    def description(self) -> str:
        return "Perform basic arithmetic operations (add, subtract, multiply, divide)"

    def get_args_schema(self) -> Type[CalculatorArgs]:
        return CalculatorArgs

    @staticmethod
    async def _failure(message: str, level: str) -> ToolResult:
        """Wait out the simulated delay, then build a failed ToolResult with a notification."""
        await asyncio.sleep(3)
        return ToolResult(
            success=False,
            result_for_llm=message,
            ui_component=UiComponent(
                rich_component=NotificationComponent(
                    type=ComponentType.NOTIFICATION,
                    level=level,
                    message=message,
                ),
                simple_component=SimpleTextComponent(text=message),
            ),
            error=message,
        )

    async def execute(self, context: ToolContext, args: CalculatorArgs) -> ToolResult:
        """Execute the calculator operation.

        On success, returns a ToolResult whose ``result_for_llm`` is the
        numeric result and whose card shows the expression. On failure,
        returns a failed ToolResult with a notification:
        - unknown operation: a ``warning``;
        - division by zero or any unexpected exception: an ``error``.
        """
        # operation name -> (display symbol, implementation)
        operations = {
            "add": ("+", lambda x, y: x + y),
            "subtract": ("-", lambda x, y: x - y),
            "multiply": ("×", lambda x, y: x * y),
            "divide": ("÷", lambda x, y: x / y),
        }
        try:
            entry = operations.get(args.operation)
            if entry is None:
                return await self._failure(
                    f"Unknown operation: {args.operation}", level="warning"
                )
            if args.operation == "divide" and args.b == 0:
                return await self._failure("Cannot divide by zero", level="error")

            symbol, apply = entry
            result = apply(args.a, args.b)
            await asyncio.sleep(3)
            expression = f"{args.a:g} {symbol} {args.b:g} = {result:g}"
            return ToolResult(
                success=True,
                result_for_llm=str(result),
                ui_component=UiComponent(
                    rich_component=CardComponent(
                        type=ComponentType.CARD,
                        title="Calculator Result",
                        content=expression,
                    ),
                    simple_component=SimpleTextComponent(text=expression),
                ),
                error=None,
            )
        except Exception as e:
            return await self._failure(str(e), level="error")
class MockCalculatorLlmService(LlmService):
    """Scripted LLM that first requests a calculator call, then echoes the tool's answer.

    If the last message in a request comes from a tool, its content is
    returned as the final answer. Otherwise a ``calculator`` tool call with
    randomly generated operands is requested.

    Args:
        seed: Optional seed for the private ``random.Random`` instance, making
            the generated operations reproducible.
    """

    def __init__(self, seed: Optional[int] = None):
        self._random = random.Random(seed)

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Return the scripted response after a short (0.05 s) simulated delay."""
        await asyncio.sleep(0.05)
        return self._build_response(request)

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream the scripted response.

        Tool calls, if any, are emitted in one chunk. A final chunk follows
        with the content (when present) and the finish reason.
        """
        await asyncio.sleep(0.05)
        response = self._build_response(request)
        if response.tool_calls:
            yield LlmStreamChunk(tool_calls=response.tool_calls)
        if response.content is None:
            final_chunk = LlmStreamChunk(finish_reason=response.finish_reason)
        else:
            final_chunk = LlmStreamChunk(
                content=response.content, finish_reason=response.finish_reason
            )
        yield final_chunk

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Accept every tool schema (no validation errors)."""
        return []

    def _build_response(self, request: LlmRequest) -> LlmResponse:
        """Echo a tool result if the last message is one; otherwise request a calculator call."""
        history = request.messages
        previous = history[-1] if history else None
        if previous and previous.role == "tool":
            return LlmResponse(
                content=previous.content or "No result provided",
                finish_reason="stop",
                usage={
                    "prompt_tokens": 30,
                    "completion_tokens": 10,
                    "total_tokens": 40,
                },
            )

        operation, a, b = self._random_operands()
        calculator_call = ToolCall(
            id=f"call_{uuid.uuid4().hex[:8]}",
            name="calculator",
            arguments={"operation": operation, "a": a, "b": b},
        )
        return LlmResponse(
            content="Let me ask my calculator friend for help...",
            tool_calls=[calculator_call],
            finish_reason="tool_calls",
            usage={"prompt_tokens": 30, "completion_tokens": 5, "total_tokens": 35},
        )

    def _random_operands(self) -> Tuple[str, float, float]:
        """Pick an operation and operands that always yield a clean result.

        - divide: the dividend is a whole multiple of the divisor.
        - subtract: the difference is never negative.
        - add/multiply: both operands lie in 1-12.
        """
        rng = self._random
        operation = rng.choice(["add", "subtract", "multiply", "divide"])
        if operation == "divide":
            divisor = float(rng.randint(1, 10))
            return operation, float(divisor * rng.randint(1, 10)), divisor
        if operation == "subtract":
            subtrahend = float(rng.randint(1, 10))
            return operation, subtrahend + float(rng.randint(0, 10)), subtrahend
        first = float(rng.randint(1, 12))
        second = float(rng.randint(1, 12))
        return operation, first, second
def create_demo_agent() -> Agent:
    """Build the calculator demo agent.

    Returns:
        An Agent whose registry holds ``CalculatorTool`` and whose LLM is
        ``MockCalculatorLlmService``; streaming and thinking indicators are off.
    """
    registry = ToolRegistry()
    registry.register(CalculatorTool())
    mock_llm = MockCalculatorLlmService()
    settings = AgentConfig(
        stream_responses=False,
        include_thinking_indicators=False,
    )
    return Agent(llm_service=mock_llm, tool_registry=registry, config=settings)
async def main() -> None:
    """Call the calculator tool directly, then through one mock agent turn."""
    agent = create_demo_agent()
    registry = agent.tool_registry
    user = User(id="user123", username="testuser", permissions=[])

    # 1) Direct tool invocation, bypassing the LLM entirely.
    print("Testing calculator tool directly:")
    direct_call = ToolCall(
        id="test123", name="calculator", arguments={"operation": "add", "a": 5, "b": 3}
    )
    ctx = ToolContext(user=user, conversation_id="test", request_id="test")
    outcome = await registry.execute(direct_call, ctx)
    summary = outcome.result_for_llm if outcome.success else outcome.error
    print(f"5 + 3 = {summary}")

    # 2) Tools this user is allowed to see.
    tool_names = [s.name for s in await registry.get_schemas(user)]
    print(f"\nAvailable tools for user: {tool_names}")

    # 3) Full loop: the mock LLM requests a tool call, then echoes its result.
    print("\nAgent conversation demo:")
    async for ui in agent.send_message(
        user=user,
        message="Can you compute something for me?",
        conversation_id="calc-demo",
    ):
        rich = ui.rich_component
        print(f"- Component type: {rich.type}")
        if hasattr(rich, "content") and rich.content:
            print(f"Assistant: {rich.content}")
        elif ui.simple_component and hasattr(ui.simple_component, "text"):
            print(f"Assistant: {ui.simple_component.text}")
        else:
            print(f"- Component data: {rich.data}")
def run_interactive() -> None:
    """Console entry point: print a banner, then run ``main`` to completion."""
    banner = "Starting mock custom tool example..."
    print(banner)
    asyncio.run(main())


if __name__ == "__main__":
    run_interactive()
================================================
FILE: src/vanna/examples/mock_quickstart.py
================================================
"""
Mock quickstart example for the Vanna Agents framework.
This example shows how to create a basic agent with a mock LLM service
and have a simple conversation.
Usage:
Template: Copy this file and modify for your needs
Interactive: python -m vanna.examples.mock_quickstart
REPL: from vanna.examples.mock_quickstart import create_demo_agent
Server: python -m vanna.servers --example mock_quickstart
"""
import asyncio
from vanna import (
AgentConfig,
Agent,
MemoryConversationStore,
MockLlmService,
User,
)
def create_demo_agent() -> Agent:
    """Build the quickstart agent used by the REPL, server and CLI entry points.

    Returns:
        An Agent backed by ``MockLlmService`` that answers with a fixed
        greeting, with streaming and thinking indicators enabled.
    """
    greeting = (
        "Hello! I'm a helpful AI assistant created using the Vanna Agents framework."
    )
    mock_llm = MockLlmService(response_content=greeting)
    settings = AgentConfig(
        stream_responses=True,  # streaming gives a nicer experience when served
        include_thinking_indicators=True,
    )
    return Agent(llm_service=mock_llm, config=settings)
async def main() -> None:
    """Run the mock quickstart example: send one message and print the reply.

    ``agent.send_message`` yields ``UiComponent`` wrappers. Their printable
    text is on the wrapped ``rich_component.content``, with
    ``simple_component.text`` as a fallback. This is the same access pattern
    the other examples use, so the reply is actually printed instead of being
    silently skipped.
    """
    # Create agent using factory function
    agent = create_demo_agent()
    # Create a test user
    user = User(
        id="user123", username="testuser", email="test@example.com", permissions=[]
    )
    # Start a conversation
    conversation_id = "conversation123"
    user_message = "Hello! Can you introduce yourself?"
    print(f"User: {user_message}")
    print("Agent: ", end="")
    # Send message and collect response
    async for component in agent.send_message(
        user=user, message=user_message, conversation_id=conversation_id
    ):
        # Still honour a direct ``content`` attribute if a component has one.
        text = getattr(component, "content", None)
        if not text:
            rich = getattr(component, "rich_component", None)
            text = getattr(rich, "content", None)
        if not text:
            simple = getattr(component, "simple_component", None)
            text = getattr(simple, "text", None)
        if text:
            print(text, end="")
    print()
def run_interactive() -> None:
    """Console entry point: announce the demo and drive ``main`` on a new loop."""
    banner = "Starting Vanna Agents mock quickstart demo..."
    print(banner)
    asyncio.run(main())


if __name__ == "__main__":
    run_interactive()
================================================
FILE: src/vanna/examples/mock_quota_example.py
================================================
"""
Mock quota-based agent example using Mock LLM service.
This example demonstrates how to create a custom agent runner that
enforces user-based message quotas. It shows:
- Custom agent runner subclass
- Quota management and enforcement
- Error handling for quota exceeded cases
- Multiple users with different quotas
Run:
PYTHONPATH=. python vanna/examples/mock_quota_example.py
"""
import asyncio
from vanna import (
AgentConfig,
MemoryConversationStore,
MockLlmService,
User,
)
from vanna.core.registry import ToolRegistry
from vanna.tools import ListFilesTool
from vanna.examples.quota_agent import QuotaAgentRunner, QuotaExceededError
def _component_text(component) -> str:
    """Return the printable text of one streamed agent component, or "".

    Checks, in order: a direct ``content`` attribute, the wrapped
    ``rich_component.content``, then ``simple_component.text``. This lets it
    handle the ``UiComponent`` wrappers yielded by ``send_message``.
    """
    candidates = (
        (component, "content"),
        (getattr(component, "rich_component", None), "content"),
        (getattr(component, "simple_component", None), "text"),
    )
    for holder, attr in candidates:
        value = getattr(holder, attr, None)
        if value:
            return str(value)
    return ""


async def _send_and_print(agent, user, message: str, conversation_id: str) -> None:
    """Send ``message`` as ``user`` and print each textual reply.

    ``QuotaExceededError`` (imported at the top of this module) is caught and
    reported, so a single blocked request does not abort the whole demo.
    """
    try:
        async for component in agent.send_message(
            user=user, message=message, conversation_id=conversation_id
        ):
            text = _component_text(component)
            if text:
                print(f" Agent: {text}")
    except QuotaExceededError as exc:
        print(f" Blocked: {exc}")


async def demonstrate_quota_system() -> None:
    """Demonstrate the quota-based agent system.

    Builds a ``QuotaAgentRunner`` over a mock LLM, gives two users different
    quotas, exhausts one user's quota, resets it, and prints usage before and
    after.
    """
    print("🚀 Starting Mock Quota-based Agent Example\n")
    # Create a mock LLM service
    llm_service = MockLlmService(
        response_content="Hello! I'm here to help you with your questions."
    )
    # Create tool registry with list_files tool
    tool_registry = ToolRegistry()
    list_files_tool = ListFilesTool()
    tool_registry.register(list_files_tool)
    # Create conversation store
    conversation_store = MemoryConversationStore()
    # Create the quota-based agent
    agent = QuotaAgentRunner(
        llm_service=llm_service,
        tool_registry=tool_registry,
        conversation_store=conversation_store,
        config=AgentConfig(
            stream_responses=False,
            include_thinking_indicators=False,
        ),
    )
    # Create users with different quota settings
    regular_user = User(
        id="user1", username="alice", email="alice@example.com", permissions=[]
    )
    premium_user = User(
        id="user2", username="bob", email="bob@example.com", permissions=["premium"]
    )
    # Set custom quotas
    agent.set_user_quota(regular_user.id, 3)  # Alice gets 3 messages
    agent.set_user_quota(premium_user.id, 5)  # Bob gets 5 messages (premium)
    print("📋 User Quotas:")
    print(
        f" • {regular_user.username}: {agent.get_user_quota(regular_user.id)} messages"
    )
    print(
        f" • {premium_user.username}: {agent.get_user_quota(premium_user.id)} messages"
    )
    print()
    # Test regular user within quota
    print("💬 Testing regular user (Alice) within quota:")
    for i in range(1, 4):  # Send 3 messages (within quota)
        print(f" Message {i}/3:")
        await _send_and_print(
            agent, regular_user, f"Hello, this is message {i}", "alice-conv"
        )
    print()
    # Test regular user exceeding quota
    print("⚠️ Testing regular user (Alice) exceeding quota:")
    await _send_and_print(
        agent, regular_user, "This message should be blocked", "alice-conv"
    )
    print()
    # Test premium user with higher quota
    print("⭐ Testing premium user (Bob) with higher quota:")
    for i in range(1, 4):  # Send 3 messages
        print(f" Message {i}/5:")
        await _send_and_print(
            agent, premium_user, f"Premium user message {i}", "bob-conv"
        )
    print()
    # Demonstrate quota reset
    print("🔄 Resetting Alice's usage:")
    agent.reset_user_usage(regular_user.id)
    print(f" Alice's remaining messages: {agent.get_user_remaining(regular_user.id)}")
    print()
    print("✅ After reset, Alice can send messages again:")
    await _send_and_print(
        agent, regular_user, "This should work after reset", "alice-conv2"
    )
    print("\n📊 Final Usage Summary:")
    print(
        f" • Alice: {agent.get_user_usage(regular_user.id)}/{agent.get_user_quota(regular_user.id)} used"
    )
    print(
        f" • Bob: {agent.get_user_usage(premium_user.id)}/{agent.get_user_quota(premium_user.id)} used"
    )
async def main() -> None:
    """Entry point for the mock quota example; runs the full demonstration."""
    demo = demonstrate_quota_system()
    await demo


if __name__ == "__main__":
    # Script entry: drive the async demo on a fresh event loop.
    asyncio.run(main())
================================================
FILE: src/vanna/examples/mock_rich_components_demo.py
================================================
"""
Mock rich components demonstration example.
This example shows how to create an agent that emits rich, stateful components
including cards, task lists, and tool execution displays using a mock LLM service.
Usage:
PYTHONPATH=. python vanna/examples/mock_rich_components_demo.py
"""
import asyncio
import time
from datetime import datetime
from typing import AsyncGenerator, Optional
from vanna import (
AgentConfig,
Agent,
MemoryConversationStore,
MockLlmService,
User,
)
from vanna.core.components import UiComponent
from vanna.core.rich_components import (
StatusCardComponent,
ProgressDisplayComponent,
LogViewerComponent,
BadgeComponent,
IconTextComponent,
RichTextComponent,
Task,
)
class RichComponentsAgent(Agent):
"""Agent that demonstrates rich component capabilities.
Overrides ``send_message`` with a scripted sequence: it never calls the LLM
service and instead yields ``UiComponent`` wrappers around icon text, status
cards, badges, a progress display and log viewers. Several of these are
re-emitted in updated form (via ``set_status`` / ``update_progress`` /
``add_entry``, or a new BadgeComponent reusing an existing id) to show
stateful components.
"""
async def send_message(
self,
user: User,
message: str,
*,
conversation_id: Optional[str] = None,
) -> AsyncGenerator[UiComponent, None]:
"""Yield the scripted primitive-components demo for ``user``.
Args:
user: Only ``user.username`` is used, in the greeting text.
message: Echoed back in the closing RichTextComponent.
conversation_id: Accepted for interface compatibility; not used here.
Yields:
UiComponent wrappers around each new or updated rich component, in a
fixed order with short ``asyncio.sleep`` pauses between stages.
"""
# Welcome message using IconText
yield UiComponent(
rich_component=IconTextComponent(
id="welcome-message",
icon="👋",
text=f"Hello {user.username}! I'll demonstrate primitive components.",
variant="primary",
size="large",
)
)
# Status card showing we're processing
status_card = StatusCardComponent(
id="processing-status",
title="Processing Request",
status="running",
description="Processing your request...",
icon="⚙️",
)
yield UiComponent(rich_component=status_card)
# Simulate some processing time
await asyncio.sleep(1)
# Update status to success
yield UiComponent(
rich_component=status_card.set_status(
"success", "Request processed successfully!"
)
)
# Create a status card for overall demo progress
demo_card = StatusCardComponent(
id="demo-progress",
title="Demo Progress",
status="running",
description="Starting primitive components demonstration...",
icon="🎯",
)
yield UiComponent(rich_component=demo_card)
# Create badges for different stages.
# Ids are "stage-<name>"; later BadgeComponents reuse these ids, presumably
# so the UI replaces the earlier badge (confirm with the renderer).
# "pending" has no badge variant of its own and is shown as "default".
stages = [
("Initialize", "success", "✅"),
("Components", "running", "⚙️"),
("Progress", "pending", "⏳"),
("Logs", "pending", "📋"),
("Complete", "pending", "🎉"),
]
for stage_name, stage_status, stage_icon in stages:
yield UiComponent(
rich_component=BadgeComponent(
id=f"stage-{stage_name.lower()}",
text=stage_name,
variant=stage_status if stage_status != "pending" else "default",
icon=stage_icon,
size="md",
)
)
# Progress display
progress_display = ProgressDisplayComponent(
id="demo-progress-bar",
label="Overall Progress",
value=0.2,
description="Initializing demonstration...",
status="info",
animated=True,
)
yield UiComponent(rich_component=progress_display)
# Create log viewer for detailed progress
log_viewer = LogViewerComponent(id="demo-logs", title="Demo Activity Log")
yield UiComponent(rich_component=log_viewer)
# Simulate work with updates.
# Three iterations move progress to ~0.4, ~0.6, ~0.8; 1.0 is set only after
# the tool-execution section below.
for i in range(3):
await asyncio.sleep(1)
# Update progress
progress_value = 0.2 + (i + 1) * 0.2
step_name = ["Creating components", "Updating progress", "Finalizing demo"][
i
]
yield UiComponent(
rich_component=progress_display.update_progress(
progress_value, f"Step {i + 2} of 5: {step_name}..."
)
)
# Update demo card
yield UiComponent(
rich_component=demo_card.set_status(
"running",
f"Step {i + 2} of 5 completed. Progress: {int(progress_value * 100)}%",
)
)
# Add log entry
yield UiComponent(
rich_component=log_viewer.add_entry(
f"Completed step: {step_name}", "info"
)
)
# Update stage badges (i == 2 changes none; "Logs" is marked successful
# after the tool section).
if i == 0:
yield UiComponent(
rich_component=BadgeComponent(
id="stage-components",
text="Components",
variant="success",
icon="✅",
size="md",
)
)
elif i == 1:
yield UiComponent(
rich_component=BadgeComponent(
id="stage-progress",
text="Progress",
variant="success",
icon="✅",
size="md",
)
)
yield UiComponent(
rich_component=BadgeComponent(
id="stage-logs",
text="Logs",
variant="running",
icon="📋",
size="md",
)
)
# Tool execution using primitive components
tool_status = StatusCardComponent(
id="demo-tool",
title="Analyze Data Tool",
status="running",
description="Running regression analysis on user_data.csv",
icon="🔬",
)
yield UiComponent(rich_component=tool_status)
# Tool progress
tool_progress = ProgressDisplayComponent(
id="tool-progress",
label="Tool Execution",
value=0.0,
description="Initializing tool...",
animated=True,
)
yield UiComponent(rich_component=tool_progress)
# Tool logs
tool_logs = LogViewerComponent(id="tool-logs", title="Tool Execution Log")
yield UiComponent(rich_component=tool_logs)
# Simulate tool execution steps: (progress fraction, log text, log level)
tool_steps = [
(0.2, "Loading dataset...", "info"),
(0.4, "Dataset loaded: 1000 rows, 5 columns", "info"),
(0.6, "Preprocessing data...", "info"),
(0.8, "Running regression analysis...", "info"),
(1.0, "Analysis complete!", "info"),
]
for progress_val, log_message, log_level in tool_steps:
await asyncio.sleep(0.5)
yield UiComponent(
rich_component=tool_progress.update_progress(
progress_val, f"Progress: {int(progress_val * 100)}%"
)
)
yield UiComponent(
rich_component=tool_logs.add_entry(log_message, log_level)
)
# Complete tool execution
yield UiComponent(
rich_component=tool_status.set_status(
"success",
"Tool completed successfully. R² = 0.85, strong correlation found.",
)
)
# Show results using IconText
yield UiComponent(
rich_component=IconTextComponent(
id="tool-results",
icon="📊",
text="Analysis Results: R² = 0.85 (Strong correlation)",
variant="success",
size="medium",
)
)
# Update final stage badge
yield UiComponent(
rich_component=BadgeComponent(
id="stage-logs", text="Logs", variant="success", icon="✅", size="md"
)
)
yield UiComponent(
rich_component=BadgeComponent(
id="stage-complete",
text="Complete",
variant="success",
icon="🎉",
size="md",
)
)
# Final updates
yield UiComponent(
rich_component=progress_display.update_progress(
1.0, "Demo completed successfully!"
)
)
yield UiComponent(
rich_component=demo_card.set_status(
"success", "Primitive components demonstration finished successfully!"
)
)
# Add final log entry.
# NOTE(review): this goes to tool_logs, not the demo-level log_viewer;
# confirm that is intended.
yield UiComponent(
rich_component=tool_logs.add_entry("Demo completed successfully!", "info")
)
# Add final text response (markdown summary echoing the user's message).
# NOTE(review): the "**After**:" line in this text has no matching "Before"
# line; consider rewording.
yield UiComponent(
rich_component=RichTextComponent(
content=f"""## Primitive Components Demo Complete!
I've demonstrated the new primitive component system:
- **Status Cards**: Domain-agnostic status displays that work for any process
- **Progress Displays**: Reusable progress indicators with animations
- **Log Viewers**: Structured log display for any activity
- **Badges**: Flexible status and category indicators
- **Icon Text**: Composable icon+text combinations
### Key Benefits of Primitive Components:
- **Separation of Concerns**: UI components are purely presentational
- **Reusability**: Components work across different domains and tools
- **Composability**: Tools build exactly the UI they need from primitives
- **Maintainability**: Business logic changes don't affect UI components
- **Extensibility**: New tools don't require new component types
**Primitive Components**: Compose UI from domain-agnostic building blocks
**After**: Tools compose UI from primitive `StatusCard` + `ProgressDisplay` + `LogViewer`
Your message was: "{message}"
""",
markdown=True,
)
)
# CLI compatibility alias
def create_demo_agent() -> RichComponentsAgent:
    """Build the demo agent under the conventional ``create_demo_agent`` name.

    Thin alias for ``create_rich_demo_agent``. It is a ``def`` rather than a
    lambda bound to a name (PEP 8 E731), so it has a real ``__name__``, a
    docstring and a return annotation. Lookup of ``create_rich_demo_agent`` is
    deferred to call time, so defining it below is fine.
    """
    return create_rich_demo_agent()
def create_rich_demo_agent() -> RichComponentsAgent:
    """Build the scripted rich-components demo agent.

    Returns:
        A ``RichComponentsAgent`` with a placeholder ``MockLlmService``,
        streaming enabled, and built-in thinking indicators disabled (the demo
        emits its own status cards instead).
    """
    placeholder_llm = MockLlmService(
        response_content="Primitive components demo response"
    )
    agent_config = AgentConfig(
        stream_responses=True,
        include_thinking_indicators=False,  # We'll use custom status cards
    )
    return RichComponentsAgent(llm_service=placeholder_llm, config=agent_config)
async def main() -> None:
    """Run the primitive components demo and print a summary of each component.

    For every ``UiComponent`` the agent yields, prints the wrapped rich
    component's type, a shortened id, its lifecycle, and any
    title/content/status/progress fields it has.
    """
    # Create agent
    agent = create_rich_demo_agent()
    # Create a test user
    user = User(
        id="user123", username="demo_user", email="demo@example.com", permissions=[]
    )
    # Start a conversation
    conversation_id = "primitive_demo_123"
    user_message = "Show me the primitive components demo!"
    print(f"User: {user_message}")
    print("Agent response (primitive components):")
    print("=" * 50)
    # This demo emits ProgressDisplayComponent ("progress_display"). Checking
    # only "progress_bar" never matched; keep it for older progress components.
    progress_types = {"progress_display", "progress_bar"}
    # Send message and display components
    component_count = 0
    async for component in agent.send_message(
        user=user, message=user_message, conversation_id=conversation_id
    ):
        component_count += 1
        rich_comp = component.rich_component
        component_type = getattr(rich_comp, "type", rich_comp.__class__.__name__)
        # str() so slicing is safe even when the id is not a string.
        component_id = str(getattr(rich_comp, "id", "N/A"))
        lifecycle = getattr(rich_comp, "lifecycle", "N/A")
        print(
            f"[{component_count:2d}] {component_type} (id: {component_id[:8]}, lifecycle: {lifecycle})"
        )
        # Show some component details
        if hasattr(rich_comp, "title"):
            print(f" Title: {rich_comp.title}")
        if hasattr(rich_comp, "content") and len(str(rich_comp.content)) < 100:
            print(f" Content: {rich_comp.content}")
        if hasattr(rich_comp, "status"):
            print(f" Status: {rich_comp.status}")
        type_value = getattr(getattr(rich_comp, "type", None), "value", None)
        if hasattr(rich_comp, "value") and type_value in progress_types:
            print(f" Progress: {rich_comp.value:.1%}")
        print()
    print("=" * 50)
    print(f"Total components emitted: {component_count}")
def run_interactive() -> None:
    """Console entry point: show a banner, then run the async demo to completion."""
    banner = "Starting Primitive Components Demo..."
    print(banner)
    asyncio.run(main())


if __name__ == "__main__":
    run_interactive()
================================================
FILE: src/vanna/examples/mock_sqlite_example.py
================================================
"""
Mock example showing how to use the SQL query tool with the Chinook database.
This example demonstrates using the RunSqlTool with SqliteRunner and a mock LLM service
that automatically executes sample SQL queries against the Chinook database.
Usage:
Template: Copy this file and modify for your custom database
Interactive: python -m vanna.examples.mock_sqlite_example
REPL: from vanna.examples.mock_sqlite_example import create_demo_agent
Server: python -m vanna.servers --example mock_sqlite_example
"""
import asyncio
import os
import random
import uuid
from typing import AsyncGenerator, Dict, List, Optional, Type
from pydantic import BaseModel, Field
from vanna import (
AgentConfig,
Agent,
Tool,
ToolContext,
ToolRegistry,
ToolResult,
User,
UiComponent,
)
from vanna.core.interfaces import LlmService
from vanna.core.models import (
LlmRequest,
LlmResponse,
LlmStreamChunk,
ToolCall,
ToolSchema,
)
from vanna.core.rich_components import (
CardComponent,
NotificationComponent,
ComponentType,
)
from vanna.core.simple_components import (
SimpleTextComponent,
)
from vanna.tools import RunSqlTool
from vanna.integrations.sqlite import SqliteRunner
class MockSqliteLlmService(LlmService):
    """Mock LLM that drives ``run_sql`` with canned Chinook queries.

    On a fresh user turn it requests a ``run_sql`` tool call with a query
    chosen at random (reproducible when ``seed`` is given). When the latest
    message is a tool result, it replies with a summary embedding that result.
    """

    def __init__(self, seed: Optional[int] = None):
        self._random = random.Random(seed)
        # Read-only sample queries against the Chinook schema.
        self._sample_queries = [
            "SELECT name FROM sqlite_master WHERE type='table'",
            "SELECT COUNT(*) as total_customers FROM Customer",
            "SELECT FirstName, LastName FROM Customer LIMIT 5",
            "SELECT Name, Composer FROM Track WHERE Composer IS NOT NULL LIMIT 5",
            "SELECT COUNT(*) as album_count FROM Album",
            "SELECT Name FROM Artist LIMIT 10",
            "SELECT AVG(Total) as avg_invoice_total FROM Invoice",
            "SELECT GenreId, COUNT(*) as track_count FROM Track GROUP BY GenreId LIMIT 5",
        ]

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Return the whole response at once after a short simulated latency."""
        await asyncio.sleep(0.1)
        return self._build_response(request)

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream the response as tool calls (if any) followed by a final chunk."""
        await asyncio.sleep(0.1)
        full = self._build_response(request)
        if full.tool_calls:
            yield LlmStreamChunk(tool_calls=full.tool_calls)
        if full.content is None:
            yield LlmStreamChunk(finish_reason=full.finish_reason)
        else:
            yield LlmStreamChunk(
                content=full.content, finish_reason=full.finish_reason
            )

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Accept every tool schema; this mock never reports validation errors."""
        return []

    def _build_response(self, request: LlmRequest) -> LlmResponse:
        """Summarise a just-received tool result, otherwise request a query."""
        latest = request.messages[-1] if request.messages else None
        if latest and latest.role == "tool":
            tool_output = latest.content or "No result provided"
            return LlmResponse(
                content=f"Here's what I found in the database:\n\n{tool_output}",
                finish_reason="stop",
                usage={
                    "prompt_tokens": 40,
                    "completion_tokens": 20,
                    "total_tokens": 60,
                },
            )

        query = self._random.choice(self._sample_queries)
        return LlmResponse(
            content="Let me query the Chinook database for you...",
            tool_calls=[
                ToolCall(
                    id=f"call_{uuid.uuid4().hex[:8]}",
                    name="run_sql",
                    arguments={"sql": query},
                )
            ],
            finish_reason="tool_calls",
            usage={"prompt_tokens": 30, "completion_tokens": 10, "total_tokens": 40},
        )
def create_demo_agent() -> Agent:
    """Build an Agent that runs SQL against the Chinook SQLite database.

    Returns:
        An Agent with ``RunSqlTool`` (backed by ``SqliteRunner``) registered and
        ``MockSqliteLlmService`` as its LLM; streaming and thinking indicators
        are off.

    Raises:
        FileNotFoundError: if ``Chinook.sqlite`` is not found two directory
            levels above this module's directory.
    """
    module_dir = os.path.dirname(__file__)
    database_path = os.path.abspath(
        os.path.join(module_dir, "..", "..", "Chinook.sqlite")
    )
    if not os.path.exists(database_path):
        raise FileNotFoundError(
            f"Chinook database not found at {database_path}. Please download it from https://vanna.ai/Chinook.sqlite"
        )

    registry = ToolRegistry()
    runner = SqliteRunner(database_path=database_path)
    registry.register(RunSqlTool(sql_runner=runner))
    return Agent(
        llm_service=MockSqliteLlmService(),
        tool_registry=registry,
        config=AgentConfig(stream_responses=False, include_thinking_indicators=False),
    )
async def main() -> None:
    """Query the Chinook database directly, then through three mock agent turns."""
    agent = create_demo_agent()
    registry = agent.tool_registry
    user = User(id="user123", username="testuser", permissions=[])

    # 1) Direct tool call: list the tables.
    print("Testing SQL tool directly:")
    direct_call = ToolCall(
        id="test123",
        name="run_sql",
        arguments={"sql": "SELECT name FROM sqlite_master WHERE type='table'"},
    )
    ctx = ToolContext(user=user, conversation_id="test", request_id="test")
    outcome = await registry.execute(direct_call, ctx)
    listing = outcome.result_for_llm if outcome.success else outcome.error
    print(f"Tables in database:\n{listing}")

    # 2) Tools visible to this user.
    schemas = await registry.get_schemas(user)
    print(f"\nAvailable tools for user: {[s.name for s in schemas]}")

    # 3) Agent turns; all share the "sqlite-demo" conversation id.
    divider = "=" * 50
    print("\n" + divider)
    print("Agent conversation demo:")
    print(divider)
    for turn in range(1, 4):
        print(f"\n--- Query {turn} ---")
        async for ui in agent.send_message(
            user=user,
            message=f"Show me some data from the database (query {turn})",
            conversation_id="sqlite-demo",
        ):
            rich = ui.rich_component
            if hasattr(rich, "content") and rich.content:
                print(f"Assistant: {rich.content}")
            elif ui.simple_component and hasattr(ui.simple_component, "text"):
                print(f"Assistant: {ui.simple_component.text}")
def run_interactive() -> None:
    """Console entry point for the SQLite demo: print an intro, then run ``main``."""
    intro = (
        "Starting mock SQLite example...",
        "This example uses the Chinook database to demonstrate SQL queries.",
    )
    for line in intro:
        print(line)
    asyncio.run(main())


if __name__ == "__main__":
    run_interactive()
================================================
FILE: src/vanna/examples/openai_quickstart.py
================================================
"""
OpenAI example using OpenAILlmService.
Loads environment from .env (via python-dotenv), uses model 'gpt-5' by default,
and sends a simple message through an Agent.
Run:
PYTHONPATH=. python vanna/examples/openai_quickstart.py
"""
import asyncio
import importlib.util
import os
import sys
def ensure_env() -> None:
    """Load ``./.env`` when python-dotenv is available, then require OPENAI_API_KEY.

    Variables already set in the process environment take precedence over
    ``.env`` values. Prints a warning when python-dotenv is missing. Exits
    with status 1 if ``OPENAI_API_KEY`` is unset or empty afterwards.
    """
    if importlib.util.find_spec("dotenv") is None:
        print(
            "[warn] python-dotenv not installed; skipping .env load. Install with: pip install python-dotenv"
        )
    else:
        from dotenv import load_dotenv

        env_file = os.path.join(os.getcwd(), ".env")
        # override=False: never clobber variables that are already exported.
        load_dotenv(dotenv_path=env_file, override=False)

    if os.getenv("OPENAI_API_KEY"):
        return
    print(
        "[error] OPENAI_API_KEY is not set. Add it to your environment or .env file."
    )
    sys.exit(1)
async def main() -> None:
    """Run the OpenAI quickstart: one request that should trigger ``list_files``.

    Validates the environment, builds an Agent around ``OpenAILlmService``
    (model from ``$OPENAI_MODEL``, default 'gpt-5') with the ``list_files``
    tool, sends a single message, and prints any textual replies.

    Raises:
        ImportError: if the OpenAI integration (the ``openai`` extra) is missing.
    """
    ensure_env()
    # Lazy import after env load to allow custom base_url/org via env.
    # OpenAILlmService belongs to the openai integration; the previous import
    # pointed at the anthropic integration by mistake.
    try:
        from vanna.integrations.openai import OpenAILlmService
    except ImportError:
        print(
            "[error] openai extra not installed. Install with: pip install -e .[openai]"
        )
        raise
    from vanna import AgentConfig, Agent, User
    from vanna.core.registry import ToolRegistry
    from vanna.tools import ListFilesTool

    # Default to 'gpt-5' for this demo; override via $OPENAI_MODEL if desired
    model = os.getenv("OPENAI_MODEL", "gpt-5")
    print(f"Using OpenAI model: {model}")
    llm = OpenAILlmService(model=model)
    # Create tool registry and register the list_files tool
    tool_registry = ToolRegistry()
    list_files_tool = ListFilesTool()
    tool_registry.register(list_files_tool)
    # Some models (e.g., reasoning/gpt-5) only support the default temperature=1.0
    agent = Agent(
        llm_service=llm,
        config=AgentConfig(stream_responses=False, temperature=1.0),
        tool_registry=tool_registry,
    )
    user = User(id="demo-user", username="demo")
    conversation_id = "openai-demo"
    print("Sending: 'List the files in the current directory'\n")
    async for component in agent.send_message(
        user=user,
        message="List the files in the current directory",
        conversation_id=conversation_id,
    ):
        # send_message yields UiComponent wrappers: the text lives on the
        # wrapped rich/simple component, not on the wrapper itself.
        text = getattr(component, "content", None)
        if not text:
            rich = getattr(component, "rich_component", None)
            text = getattr(rich, "content", None)
        if not text:
            simple = getattr(component, "simple_component", None)
            text = getattr(simple, "text", None)
        if text:
            print("Assistant:", text)


if __name__ == "__main__":
    asyncio.run(main())
================================================
FILE: src/vanna/examples/primitive_components_demo.py
================================================
"""
Demonstration of the new primitive component system.
This example shows how tools compose UI from primitive, domain-agnostic
components like StatusCardComponent, ProgressDisplayComponent, etc.
Usage:
PYTHONPATH=. python vanna/examples/primitive_components_demo.py
"""
import asyncio
import uuid
from datetime import datetime
from typing import AsyncGenerator, Optional
from vanna import (
AgentConfig,
Agent,
MemoryConversationStore,
MockLlmService,
User,
)
from vanna.core.components import UiComponent
from vanna.core.rich_components import (
StatusCardComponent,
ProgressDisplayComponent,
LogViewerComponent,
BadgeComponent,
IconTextComponent,
RichTextComponent,
)
class PrimitiveComponentsAgent(Agent):
"""Agent that demonstrates the new primitive component system.
Overrides ``send_message`` with a scripted sequence that never calls the LLM
service. It walks a five-stage fake "analysis pipeline" using a status card,
a progress display and a log viewer, then shows standalone badges and icon
text, and ends with a markdown summary.
"""
async def send_message(
self,
user: User,
message: str,
*,
conversation_id: Optional[str] = None,
) -> AsyncGenerator[UiComponent, None]:
"""Yield the scripted primitive-components demo.
Args:
user: Not read by this implementation.
message: Echoed back in the closing RichTextComponent.
conversation_id: Accepted for interface compatibility; not used here.
Yields:
UiComponent wrappers around each new or updated rich component.
"""
# Short random suffix keeps component ids unique per call.
session_id = str(uuid.uuid4())[:8]
# Demo 1: Tool execution using primitive components
yield UiComponent(
rich_component=RichTextComponent(
content="## Primitive Components Demo\n\nShowing how tools now compose UI from primitive components:",
markdown=True,
)
)
# Status card for overall operation
operation_status = StatusCardComponent(
id=f"operation-{session_id}",
title="Data Analysis Pipeline",
status="running",
description="Processing user data through multiple analysis stages",
icon="⚙️",
)
yield UiComponent(rich_component=operation_status)
# Progress display for overall progress
overall_progress = ProgressDisplayComponent(
id=f"progress-{session_id}",
label="Overall Progress",
value=0.0,
description="Starting analysis...",
animated=True,
)
yield UiComponent(rich_component=overall_progress)
# Log viewer for detailed output
log_viewer = LogViewerComponent(
id=f"logs-{session_id}",
title="Analysis Log",
entries=[],
show_timestamps=True,
auto_scroll=True,
)
yield UiComponent(rich_component=log_viewer)
# Simulate analysis stages: (name, icon, overall progress after the stage)
stages = [
("Data Loading", "📊", 0.2),
("Data Validation", "✅", 0.4),
("Statistical Analysis", "🧮", 0.6),
("Report Generation", "📄", 0.8),
("Finalization", "🎯", 1.0),
]
for i, (stage_name, stage_icon, progress_value) in enumerate(stages):
await asyncio.sleep(0.8)
# Update overall status.
# NOTE(review): on the last stage the card becomes "success" while its
# description still reads "Executing: Finalization", and it is not
# updated again afterwards.
status = "success" if progress_value == 1.0 else "running"
yield UiComponent(
rich_component=operation_status.set_status(
status, f"Executing: {stage_name}"
)
)
# Update progress
yield UiComponent(
rich_component=overall_progress.update_progress(
progress_value, f"Executing {stage_name}..."
)
)
# Add log entry
yield UiComponent(
rich_component=log_viewer.add_entry(f"Starting {stage_name}", "info")
)
# Create a status card for this specific stage
stage_status = StatusCardComponent(
id=f"stage-{i}-{session_id}",
title=stage_name,
status="running" if progress_value < 1.0 else "success",
description=f"Processing stage {i + 1} of {len(stages)}",
icon=stage_icon,
)
yield UiComponent(rich_component=stage_status)
await asyncio.sleep(0.5)
# Complete the stage.
# NOTE(review): the final stage uses "completed" while every other
# status card in this file uses "success"; confirm StatusCardComponent
# accepts "completed".
final_stage_status = "success" if progress_value < 1.0 else "completed"
yield UiComponent(
rich_component=stage_status.set_status(
final_stage_status, f"{stage_name} completed successfully"
)
)
yield UiComponent(
rich_component=log_viewer.add_entry(f"Completed {stage_name}", "info")
)
# Demo 2: Badge and IconText primitives
yield UiComponent(
rich_component=RichTextComponent(
content="\n### Primitive Component Examples\n\nShowing individual primitive components:",
markdown=True,
)
)
# Various badge examples (no explicit ids; presumably generated by the
# component model, so confirm if stable ids are needed)
badges = [
BadgeComponent(text="Processing", variant="primary", size="small"),
BadgeComponent(text="Complete", variant="success", size="medium"),
BadgeComponent(text="Warning", variant="warning", size="large", icon="⚠️"),
BadgeComponent(text="Error", variant="error", size="medium", icon="❌"),
]
for badge in badges:
yield UiComponent(rich_component=badge)
# IconText examples
icon_texts = [
IconTextComponent(
icon="📊",
text="Data Analysis Complete",
variant="primary",
size="large",
),
IconTextComponent(
icon="✅", text="All tests passed", variant="default", size="medium"
),
IconTextComponent(
icon="⏱️",
text="Processing time: 2.3s",
variant="secondary",
size="small",
),
]
for icon_text in icon_texts:
yield UiComponent(rich_component=icon_text)
# Demo 3: markdown summary of the primitive-component approach, echoing the
# user's message (the text only shows the new approach, not an old one)
yield UiComponent(
rich_component=RichTextComponent(
content=f"""
## Key Benefits of Primitive Components
**Primitive Component Approach:**
```python
# Tool composes UI from primitives
status_card = StatusCardComponent(
title="Data Analysis",
status="running", # Pure UI state
icon="📊"
)
progress = ProgressDisplayComponent(
label="Analysis Progress",
value=0.5
)
logs = LogViewerComponent(
title="Analysis Log",
entries=log_entries
)
```
### Benefits:
- **Separation of Concerns**: UI components are purely presentational
- **Reusability**: Status cards work for any process, not just tools
- **Composability**: Tools build exactly the UI they need
- **Maintainability**: Changes to business logic don't affect UI components
- **Extensibility**: New tools don't require new component types
Your message was: "{message}"
""",
markdown=True,
)
)
def create_primitive_demo_agent() -> PrimitiveComponentsAgent:
    """Build the agent used by the primitive-components demo.

    The agent is backed by a ``MockLlmService`` with a canned reply, streams
    its responses, and has thinking indicators switched off.

    Returns:
        A ready-to-use PrimitiveComponentsAgent.
    """
    mock_llm = MockLlmService(response_content="Primitive components demo response")
    agent_config = AgentConfig(
        stream_responses=True,
        include_thinking_indicators=False,
    )
    return PrimitiveComponentsAgent(llm_service=mock_llm, config=agent_config)
async def main() -> None:
    """Run the primitive components demo and print every emitted component.

    Sends one canned user message to the demo agent. For each UI component
    streamed back it prints a numbered line with the component type and a
    (truncated) id, followed by any title / status / description / progress
    value exposed by its rich component.
    """
    agent = create_primitive_demo_agent()
    user = User(
        id="user123", username="demo_user", email="demo@example.com", permissions=[]
    )
    conversation_id = "primitive_demo_123"
    user_message = "Show me how the new primitive component system works!"

    print(f"User: {user_message}")
    print("Agent response (primitive components):")
    print("=" * 60)

    component_count = 0
    async for component in agent.send_message(
        user=user, message=user_message, conversation_id=conversation_id
    ):
        component_count += 1
        component_type = getattr(component, "type", component.__class__.__name__)
        # Enum-like types expose ``.value``; strings / class names print as-is.
        type_label = getattr(component_type, "value", component_type)
        # Convert to str before slicing so non-string ids cannot raise TypeError.
        short_id = str(getattr(component, "id", "N/A"))[:12]
        print(f"[{component_count:2d}] {type_label} (id: {short_id})")

        rich_comp = component.rich_component
        if hasattr(rich_comp, "title"):
            print(f" Title: {rich_comp.title}")
        if hasattr(rich_comp, "status"):
            print(f" Status: {rich_comp.status}")
        if hasattr(rich_comp, "description") and rich_comp.description:
            desc = rich_comp.description
            if len(desc) > 60:
                desc = desc[:60] + "..."
            print(f" Description: {desc}")
        # Read ``type`` defensively: a component may have ``value`` but no ``type``.
        rich_type_value = getattr(getattr(rich_comp, "type", None), "value", None)
        if hasattr(rich_comp, "value") and rich_type_value == "progress_display":
            print(f" Progress: {rich_comp.value:.1%}")
        print()

    print("=" * 60)
    print(f"Total components emitted: {component_count}")
    print("\nThis demonstrates how tools can now compose rich UIs")
    print("from primitive, reusable components without semantic coupling!")
def run_interactive() -> None:
    """Console entry point: announce the demo, then drive ``main`` to completion."""
    banner = "Starting Primitive Components Demo..."
    print(banner)
    asyncio.run(main())
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_interactive()
================================================
FILE: src/vanna/examples/quota_lifecycle_example.py
================================================
"""
Example demonstrating lifecycle hooks for user quota management.
This example shows how to use lifecycle hooks to add custom functionality
like quota management without creating custom agent runner subclasses.
"""
from typing import Any, Dict, Optional
from vanna.core import Agent, LifecycleHook, User
from vanna.core.errors import AgentError
class QuotaExceededError(AgentError):
    """Signals that a user has used up their message quota.

    Raised by ``QuotaCheckHook.before_message`` before the message is processed.
    """
class QuotaCheckHook(LifecycleHook):
    """Lifecycle hook that enforces per-user message quotas.

    Usage counts are kept in memory on this hook instance. A message is
    charged against the quota in ``before_message``, i.e. before the agent
    processes it.
    """

    def __init__(self, default_quota: int = 10) -> None:
        """Initialize quota hook.

        Args:
            default_quota: Quota applied to users without a specific quota.
        """
        self._user_quotas: Dict[str, int] = {}
        self._user_usage: Dict[str, int] = {}
        self._default_quota = default_quota

    def set_user_quota(self, user_id: str, quota: int) -> None:
        """Set a specific quota for a user, overriding the default."""
        self._user_quotas[user_id] = quota

    def get_user_quota(self, user_id: str) -> int:
        """Return the user's quota (the default quota if none was set)."""
        return self._user_quotas.get(user_id, self._default_quota)

    def get_user_usage(self, user_id: str) -> int:
        """Return how many messages the user has been charged for so far."""
        return self._user_usage.get(user_id, 0)

    def get_user_remaining(self, user_id: str) -> int:
        """Return how many more messages the user may send (never negative).

        Usage can exceed the quota when the quota is lowered via
        ``set_user_quota`` after messages were sent; that case reports 0
        rather than a negative count.
        """
        return max(0, self.get_user_quota(user_id) - self.get_user_usage(user_id))

    def reset_user_usage(self, user_id: str) -> None:
        """Reset the user's usage count to zero."""
        self._user_usage[user_id] = 0

    async def before_message(self, user: User, message: str) -> Optional[str]:
        """Charge the message against the user's quota, rejecting it if exhausted.

        Args:
            user: Sender of the message; ``user.id`` keys the quota/usage maps.
            message: Incoming message text (not inspected).

        Returns:
            None, leaving the message unchanged.

        Raises:
            QuotaExceededError: If the user's usage has reached their quota.
        """
        usage = self.get_user_usage(user.id)
        quota = self.get_user_quota(user.id)
        if usage >= quota:
            raise QuotaExceededError(
                f"User {user.username} has exceeded their quota of {quota} messages. "
                f"Current usage: {usage}"
            )
        # Count this message now, before the agent processes it.
        self._user_usage[user.id] = usage + 1
        return None
class LoggingHook(LifecycleHook):
    """Example hook that prints a log line before and after each message."""

    async def before_message(self, user: User, message: str) -> Optional[str]:
        """Print who sent the message together with a short preview of it.

        The preview is the first 50 characters; "..." is appended only when
        the message was actually truncated.

        Returns:
            None, leaving the message unchanged.
        """
        preview_chars = 50
        preview = message[:preview_chars]
        if len(message) > preview_chars:
            preview += "..."
        print(f"[LOG] User {user.username} ({user.id}) sent message: {preview}")
        return None

    async def after_message(self, result: Any) -> None:
        """Print a completion line once the message has been processed."""
        print("[LOG] Message processing completed")
async def run_example() -> None:
    """
    Demonstrate composing Agent behaviour with lifecycle hooks.

    Instead of subclassing the agent, a logging hook and a quota hook are
    registered on a plain Agent. One message is sent and the remaining quota
    is printed; a QuotaExceededError, if raised, is reported instead.
    """
    from vanna.core.registry import ToolRegistry
    from vanna.integrations.anthropic import AnthropicLlmService
    from vanna.integrations.local import MemoryConversationStore

    # 10 messages for everyone by default, but only 5 for "user123".
    quota_hook = QuotaCheckHook(default_quota=10)
    quota_hook.set_user_quota("user123", 5)
    logging_hook = LoggingHook()

    # Hook order matters: logging runs first, then the quota check.
    hooks = [logging_hook, quota_hook]
    agent = Agent(
        llm_service=AnthropicLlmService(api_key="your-api-key"),
        tool_registry=ToolRegistry(),
        conversation_store=MemoryConversationStore(),
        lifecycle_hooks=hooks,
    )

    user = User(
        id="user123", username="test_user", email="test@example.com", permissions=[]
    )

    try:
        async for _ in agent.send_message(user=user, message="Hello, agent!"):
            pass  # UI components are not rendered in this example
        remaining = quota_hook.get_user_remaining(user.id)
        limit = quota_hook.get_user_quota(user.id)
        print(f"Remaining messages: {remaining}/{limit}")
    except QuotaExceededError as exc:
        print(f"Quota exceeded: {exc}")
# Script entry point. asyncio is imported locally because the module-level
# imports of this example do not include it.
if __name__ == "__main__":
    import asyncio
    asyncio.run(run_example())
================================================
FILE: src/vanna/examples/visualization_example.py
================================================
"""
Example demonstrating SQL query execution with automatic visualization.
This example shows the integration of RunSqlTool and VisualizeDataTool,
demonstrating how SQL results are saved to CSV files and can be visualized
using the visualization tool with dependency injection.
Usage:
PYTHONPATH=. python vanna/examples/visualization_example.py
"""
import asyncio
import os
import sys
import uuid
from typing import AsyncGenerator, List, Optional
from vanna import (
AgentConfig,
Agent,
ToolRegistry,
User,
)
from vanna.core import LlmService
from vanna.core import (
LlmRequest,
LlmResponse,
LlmStreamChunk,
ToolCall,
ToolSchema,
)
from vanna.integrations.sqlite import SqliteRunner
from vanna.tools import (
RunSqlTool,
VisualizeDataTool,
LocalFileSystem,
)
class VisualizationDemoLlmService(LlmService):
    """Scripted mock LLM that drives a run_sql -> visualize_data workflow.

    The first request (with no tool result pending) issues a ``run_sql`` tool
    call. A SQL tool result that mentions a saved ``.csv`` file triggers a
    ``visualize_data`` call on that file. Any other tool result is
    acknowledged in plain text.
    """

    def __init__(self) -> None:
        # 0 until the initial SQL tool call has been issued, then 1.
        self.step = 0
        # CSV path extracted from the most recent SQL tool result.
        self.csv_filename: Optional[str] = None

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Return the scripted response after a short simulated delay."""
        await asyncio.sleep(0.1)
        return self._build_response(request)

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream the scripted response.

        Tool calls (if any) come first, then one chunk carrying the text
        content, or just the finish reason when there is no text.
        """
        await asyncio.sleep(0.1)
        reply = self._build_response(request)
        if reply.tool_calls:
            yield LlmStreamChunk(tool_calls=reply.tool_calls)
        if not reply.content:
            yield LlmStreamChunk(finish_reason=reply.finish_reason)
            return
        yield LlmStreamChunk(content=reply.content, finish_reason=reply.finish_reason)

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Accept every tool schema (no validation errors)."""
        return []

    def _build_response(self, request: LlmRequest) -> LlmResponse:
        """Choose the next scripted reply from the latest message and internal state."""
        latest = request.messages[-1] if request.messages else None

        if latest and latest.role == "tool":
            result_text = latest.content or ""

            # A SQL result that saved a CSV: follow up with a visualization call.
            if "Results saved to" in result_text and ".csv" in result_text:
                import re

                csv_match = re.search(r"'([^']*\.csv)'", result_text)
                if csv_match:
                    self.csv_filename = csv_match.group(1)
                    return LlmResponse(
                        content="Great! I've saved the query results. Now let me create a visualization of the data.",
                        tool_calls=[
                            ToolCall(
                                id=f"call_{uuid.uuid4().hex[:8]}",
                                name="visualize_data",
                                arguments={"filename": self.csv_filename},
                            )
                        ],
                        finish_reason="tool_calls",
                    )

            if "Created visualization" in result_text:
                acknowledgement = (
                    f"Perfect! I've created a visualization of the data. {result_text}"
                )
            else:
                acknowledgement = f"I've completed the operation. {result_text}"
            return LlmResponse(content=acknowledgement, finish_reason="stop")

        # The SQL query is only issued once per service instance.
        if self.step != 0:
            return LlmResponse(
                content="I can help you query databases and visualize the results.",
                finish_reason="stop",
            )

        self.step += 1
        return LlmResponse(
            content="I'll query the database for you and then create a visualization.",
            tool_calls=[
                ToolCall(
                    id=f"call_{uuid.uuid4().hex[:8]}",
                    name="run_sql",
                    arguments={
                        "sql": "SELECT Name, Milliseconds, Bytes FROM Track LIMIT 20"
                    },
                )
            ],
            finish_reason="tool_calls",
        )
def create_demo_agent() -> Agent:
    """
    Build the SQL + visualization demo agent.

    Intended as the agent factory for the vanna server framework; ``main``
    in this module uses it too.

    Raises:
        FileNotFoundError: if ``Chinook.sqlite`` is not found two directories
            above this file.

    Returns:
        An Agent with a RunSqlTool and a VisualizeDataTool registered, both
        sharing one LocalFileSystem rooted at ``./data_storage``.
    """
    here = os.path.dirname(__file__)
    database_path = os.path.abspath(os.path.join(here, "..", "..", "Chinook.sqlite"))
    if not os.path.exists(database_path):
        raise FileNotFoundError(
            f"Chinook database not found at {database_path}. "
            "Please download it from https://vanna.ai/Chinook.sqlite"
        )

    # One FileSystem for both tools so the visualizer can read the CSVs the
    # SQL tool writes.
    shared_fs = LocalFileSystem(working_directory="./data_storage")
    tools = (
        RunSqlTool(
            sql_runner=SqliteRunner(database_path=database_path),
            file_system=shared_fs,
        ),
        VisualizeDataTool(file_system=shared_fs),
    )

    tool_registry = ToolRegistry()
    for tool in tools:
        tool_registry.register(tool)

    # Streaming is enabled for the web interface.
    return Agent(
        llm_service=VisualizationDemoLlmService(),
        tool_registry=tool_registry,
        config=AgentConfig(
            stream_responses=True,
            include_thinking_indicators=False,
        ),
    )
async def main() -> None:
    """Demonstrate SQL query execution with automatic visualization.

    Builds the demo agent, lists the tools available to a demo user, sends a
    single request, and prints each streamed component as a console line.
    """
    print("🎨 SQL + Visualization Demo")
    print("=" * 60)
    print("This example demonstrates:")
    print("1. Running SQL queries that save results to CSV files")
    print("2. Automatically visualizing the CSV data")
    print("3. User isolation for file storage")
    print("=" * 60)
    print()

    agent = create_demo_agent()
    user = User(id="demo-user", username="demo")

    tools = await agent.get_available_tools(user)
    print(f"Available tools: {[tool.name for tool in tools]}")
    print()

    conversation_id = "viz-demo"
    print("User: Show me some track data and visualize it")
    print()
    async for component in agent.send_message(
        user=user,
        message="Show me some track data and visualize it",
        conversation_id=conversation_id,
    ):
        simple = component.simple_component
        # Prefer the simple component's text when it is non-empty; otherwise
        # fall back to the rich component's content.
        if simple and getattr(simple, "text", None):
            print(f"Agent: {simple.text}")
        elif hasattr(component.rich_component, "content"):
            content = component.rich_component.content
            if isinstance(content, dict):
                # A dict payload is treated as the generated chart figure.
                print(
                    f"Agent: [Chart Generated - Plotly figure with {len(str(content))} chars]"
                )
            else:
                print(f"Agent: {content}")

    print()
    print("=" * 60)
    print("Demo complete!")
    print()
    print("Key features demonstrated:")
    print("✅ SQL queries save results to user-isolated CSV files")
    print("✅ Visualization tool reads CSV files using FileSystem")
    print("✅ Automatic chart type selection based on data shape")
    print("✅ Dependency injection allows customization")
    print()
# Run the demo when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())
================================================
FILE: src/vanna/integrations/__init__.py
================================================
"""
Integrations module.
This package contains concrete implementations of core abstractions and capabilities.
"""
from .local import MemoryConversationStore
from .mock import MockLlmService
from .plotly import PlotlyChartGenerator
from .sqlite import SqliteRunner
# Names exported by ``from vanna.integrations import *``.
__all__ = [
    "MockLlmService",
    "MemoryConversationStore",
    "SqliteRunner",
    "PlotlyChartGenerator",
]
================================================
FILE: src/vanna/integrations/anthropic/__init__.py
================================================
"""
Anthropic integration.
This module provides Anthropic LLM service implementation.
"""
from .llm import AnthropicLlmService
# Public API of the Anthropic integration package.
__all__ = ["AnthropicLlmService"]
================================================
FILE: src/vanna/integrations/anthropic/llm.py
================================================
"""
Anthropic LLM service implementation.
Implements the LlmService interface using Anthropic's Messages API
(anthropic>=0.8.0). Supports non-streaming and streaming text output.
Tool-calls (tool_use blocks) are surfaced at the end of a stream or after a
non-streaming call as ToolCall entries.
"""
from __future__ import annotations
import logging
import os
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
# Module-level logger used by AnthropicLlmService.
logger = logging.getLogger(__name__)
from vanna.core.llm import (
LlmService,
LlmRequest,
LlmResponse,
LlmStreamChunk,
)
from vanna.core.tool import ToolCall, ToolSchema
class AnthropicLlmService(LlmService):
    """Anthropic Messages-backed LLM service.

    Args:
        model: Anthropic model name (e.g., "claude-sonnet-4-5", "claude-opus-4").
            Defaults to "claude-sonnet-4-5". Can also be set via ANTHROPIC_MODEL env var.
        api_key: API key; falls back to env `ANTHROPIC_API_KEY`.
        base_url: Optional custom base URL; env `ANTHROPIC_BASE_URL` if unset.
        extra_client_kwargs: Extra kwargs forwarded to `anthropic.Anthropic()`.

    Note:
        The wrapped client is the synchronous ``anthropic.Anthropic``. The
        async methods below call it without awaiting, so each request blocks
        the running event loop until the API responds.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        **extra_client_kwargs: Any,
    ) -> None:
        try:
            import anthropic
        except Exception as e:  # pragma: no cover
            raise ImportError(
                "anthropic package is required. Install with: pip install 'vanna[anthropic]'"
            ) from e
        # Model selection - use environment variable or default
        self.model = model or os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-5")
        api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        base_url = base_url or os.getenv("ANTHROPIC_BASE_URL")
        client_kwargs: Dict[str, Any] = {**extra_client_kwargs}
        if api_key:
            client_kwargs["api_key"] = api_key
        if base_url:
            client_kwargs["base_url"] = base_url
        self._client = anthropic.Anthropic(**client_kwargs)

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Send a non-streaming request to Anthropic and return the response."""
        payload = self._build_payload(request)
        resp = self._client.messages.create(**payload)
        # Responses can contain user data and be large: log lazily at DEBUG.
        logger.debug("Anthropic response: %s", resp)
        text_content, tool_calls = self._parse_message_content(resp)
        usage: Dict[str, int] = {}
        if getattr(resp, "usage", None):
            try:
                usage = {
                    "input_tokens": int(resp.usage.input_tokens),
                    "output_tokens": int(resp.usage.output_tokens),
                }
            except Exception:
                # Usage is best-effort metadata; never fail the request over it.
                logger.debug("Could not read Anthropic usage data", exc_info=True)
        return LlmResponse(
            content=text_content or None,
            tool_calls=tool_calls or None,
            finish_reason=getattr(resp, "stop_reason", None),
            usage=usage or None,
        )

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream a request to Anthropic.

        Yields text chunks as they arrive. Emits tool-calls at the end by
        inspecting the final message; the last chunk always carries the
        finish reason ("stop" when the API reports none and there are no
        tool calls).
        """
        payload = self._build_payload(request)
        # Payloads include prompts and tool results: log lazily at DEBUG.
        logger.debug("Anthropic streaming payload: %s", payload)
        # SDK provides a streaming context manager with a text_stream iterator.
        with self._client.messages.stream(**payload) as stream:
            for text in stream.text_stream:
                if text:
                    yield LlmStreamChunk(content=text)
            final = stream.get_final_message()
        logger.debug("Anthropic stream response: %s", final)
        _, tool_calls = self._parse_message_content(final)
        stop_reason = getattr(final, "stop_reason", None)
        if tool_calls:
            yield LlmStreamChunk(tool_calls=tool_calls, finish_reason=stop_reason)
        else:
            yield LlmStreamChunk(finish_reason=stop_reason or "stop")

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Basic validation of tool schemas for Anthropic (names must be non-empty)."""
        errors: List[str] = []
        for t in tools:
            if not t.name:
                errors.append("Tool name is required")
        return errors

    # Internal helpers
    def _build_payload(self, request: LlmRequest) -> Dict[str, Any]:
        """Translate an LlmRequest into keyword arguments for the Messages API.

        - Runs of consecutive ``tool`` messages are merged into a single
          ``user`` message of ``tool_result`` blocks (tool messages without a
          ``tool_call_id`` are dropped).
        - Assistant tool calls become ``tool_use`` blocks.
        - Roles other than user/assistant are sent as ``user``.
        - ``max_tokens`` defaults to 512 because Anthropic requires it.
        """
        messages: List[Dict[str, Any]] = []
        i = 0
        while i < len(request.messages):
            m = request.messages[i]
            if m.role == "tool":
                # Group consecutive tool messages into one user message
                tool_content_blocks: List[Dict[str, Any]] = []
                while i < len(request.messages) and request.messages[i].role == "tool":
                    tool_msg = request.messages[i]
                    if tool_msg.tool_call_id:
                        tool_content_blocks.append(
                            {
                                "type": "tool_result",
                                "tool_use_id": tool_msg.tool_call_id,
                                "content": tool_msg.content,
                            }
                        )
                    i += 1
                if tool_content_blocks:
                    messages.append(
                        {
                            "role": "user",
                            "content": tool_content_blocks,
                        }
                    )
            else:
                content_blocks: List[Dict[str, Any]] = []
                # Only add a text block when there is non-whitespace text.
                if m.content and m.content.strip():
                    content_blocks.append({"type": "text", "text": m.content})
                # Assistant tool calls are replayed as tool_use blocks.
                if m.role == "assistant" and m.tool_calls:
                    for tc in m.tool_calls:
                        content_blocks.append(
                            {
                                "type": "tool_use",
                                "id": tc.id,
                                "name": tc.name,
                                "input": tc.arguments,  # type: ignore[dict-item]
                            }
                        )
                # Ensure user/assistant messages carry at least one content block.
                if not content_blocks and m.role in {"user", "assistant"}:
                    content_blocks.append({"type": "text", "text": m.content or ""})
                if content_blocks:
                    role = m.role if m.role in {"user", "assistant"} else "user"
                    messages.append(
                        {
                            "role": role,
                            "content": content_blocks,
                        }
                    )
                i += 1
        tools_payload: Optional[List[Dict[str, Any]]] = None
        if request.tools:
            tools_payload = [
                {
                    "name": t.name,
                    "description": t.description,
                    "input_schema": t.parameters,
                }
                for t in request.tools
            ]
        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            # Anthropic requires max_tokens; default if not provided
            "max_tokens": request.max_tokens if request.max_tokens is not None else 512,
            "temperature": request.temperature,
        }
        if tools_payload:
            payload["tools"] = tools_payload
            payload["tool_choice"] = {"type": "auto"}
        # Add system prompt if provided
        if request.system_prompt:
            payload["system"] = request.system_prompt
        return payload

    @staticmethod
    def _block_field(block: Any, key: str) -> Any:
        """Read ``key`` from a content block given either as a dict or an SDK object."""
        if isinstance(block, dict):
            return block.get(key)
        return getattr(block, key, None)

    def _parse_message_content(self, msg: Any) -> Tuple[str, List[ToolCall]]:
        """Split a Messages API response into concatenated text and tool calls.

        Content blocks may be SDK objects or plain dicts. ``text`` blocks are
        joined in order; ``tool_use`` blocks with a name become ToolCall
        entries; other block types are ignored.
        """
        text_parts: List[str] = []
        tool_calls: List[ToolCall] = []
        for block in getattr(msg, "content", None) or []:
            btype = self._block_field(block, "type")
            if btype == "text":
                text = self._block_field(block, "text")
                if text:
                    text_parts.append(str(text))
            elif btype == "tool_use":
                name = self._block_field(block, "name")
                if not name:
                    continue
                tc_id = self._block_field(block, "id")
                input_data = self._block_field(block, "input")
                # Test the type, not truthiness: an empty ``{}`` input (a tool
                # called without arguments) is valid and must stay ``{}``.
                if isinstance(input_data, dict):
                    args: Dict[str, Any] = input_data
                elif input_data is None:
                    args = {}
                else:
                    args = {"_raw": input_data}
                tool_calls.append(
                    ToolCall(
                        id=str(tc_id or "tool_call"), name=str(name), arguments=args
                    )
                )
        return "".join(text_parts), tool_calls
================================================
FILE: src/vanna/integrations/azureopenai/__init__.py
================================================
"""
Azure OpenAI integration.
This module provides Azure OpenAI LLM service implementations.
"""
from .llm import AzureOpenAILlmService
# Public API of the Azure OpenAI integration package.
__all__ = ["AzureOpenAILlmService"]
================================================
FILE: src/vanna/integrations/azureopenai/llm.py
================================================
"""
Azure OpenAI LLM service implementation.
Provides an `LlmService` backed by Azure OpenAI Chat Completions (openai>=1.0.0)
with support for streaming, deployment-scoped models, and Azure-specific
authentication flows.
"""
from __future__ import annotations
import json
import os
from typing import Any, AsyncGenerator, Dict, List, Optional, Set
from vanna.core.llm import (
LlmService,
LlmRequest,
LlmResponse,
LlmStreamChunk,
)
from vanna.core.tool import ToolCall, ToolSchema
# Name fragments of reasoning models. Deployments matching any of them do not
# support temperature or other sampling parameters.
REASONING_MODELS: Set[str] = {
    "o1",
    "o1-mini",
    "o1-preview",
    "o3-mini",
    "gpt-5",
    "gpt-5-mini",
    "gpt-5-nano",
    "gpt-5-pro",
    "gpt-5-codex",
}


def _is_reasoning_model(model: str) -> bool:
    """Report whether a deployment name refers to a reasoning model.

    Matching is case-insensitive and by substring: a name containing any
    REASONING_MODELS entry (e.g. "prod-GPT-5-mini") counts.
    """
    lowered = model.lower()
    for marker in REASONING_MODELS:
        if marker in lowered:
            return True
    return False
class AzureOpenAILlmService(LlmService):
    """Azure OpenAI Chat Completions-backed LLM service.

    Wraps `openai.AzureOpenAI` so Vanna can talk to deployment-scoped models
    and either API key or Microsoft Entra ID authentication.

    Note:
        ``openai.AzureOpenAI`` is the synchronous client. The async methods
        below call it without awaiting, so each request blocks the running
        event loop until Azure responds.

    Args:
        model: Deployment name in Azure OpenAI (required); falls back to
            `AZURE_OPENAI_MODEL`.
        api_key: API key; falls back to `AZURE_OPENAI_API_KEY`.
        azure_endpoint: Azure OpenAI endpoint URL; falls back to
            `AZURE_OPENAI_ENDPOINT`.
        api_version: API version; defaults to "2024-10-21" or
            `AZURE_OPENAI_API_VERSION`.
        azure_ad_token_provider: Optional bearer token provider for Entra ID.
        **extra_client_kwargs: Additional keyword arguments forwarded to the
            underlying client.

    Raises:
        ImportError: If the ``openai`` package is not installed.
        ValueError: If the deployment name, endpoint, or credentials are missing.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = None,
        azure_ad_token_provider: Optional[Any] = None,
        **extra_client_kwargs: Any,
    ) -> None:
        try:
            from openai import AzureOpenAI
        except Exception as e:  # pragma: no cover
            raise ImportError(
                "openai package is required. Install with: pip install 'vanna[azureopenai]' "
                "or 'pip install openai'"
            ) from e
        # Model/deployment name is required for Azure OpenAI
        self.model = model or os.getenv("AZURE_OPENAI_MODEL")
        if not self.model:
            raise ValueError(
                "model parameter (deployment name) is required for Azure OpenAI. "
                "Provide it as argument or set AZURE_OPENAI_MODEL environment variable."
            )
        # Azure endpoint is required
        azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError(
                "azure_endpoint is required for Azure OpenAI. "
                "Provide it as argument or set AZURE_OPENAI_ENDPOINT environment variable."
            )
        # API version - use latest stable GA version by default
        api_version = api_version or os.getenv("AZURE_OPENAI_API_VERSION", "2024-10-21")
        client_kwargs: Dict[str, Any] = {
            "azure_endpoint": azure_endpoint,
            "api_version": api_version,
            **extra_client_kwargs,
        }
        # Authentication: prefer Azure AD token provider, fallback to API key
        if azure_ad_token_provider is not None:
            client_kwargs["azure_ad_token_provider"] = azure_ad_token_provider
        else:
            api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
            if not api_key:
                raise ValueError(
                    "Authentication required: provide either api_key or azure_ad_token_provider. "
                    "API key can also be set via AZURE_OPENAI_API_KEY environment variable."
                )
            client_kwargs["api_key"] = api_key
        self._client = AzureOpenAI(**client_kwargs)
        self._is_reasoning_model = _is_reasoning_model(self.model)

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Send a non-streaming request to Azure OpenAI and return the response."""
        payload = self._build_payload(request)
        resp = self._client.chat.completions.create(**payload, stream=False)
        if not resp.choices:
            return LlmResponse(content=None, tool_calls=None, finish_reason=None)
        choice = resp.choices[0]
        content: Optional[str] = getattr(choice.message, "content", None)
        tool_calls = self._extract_tool_calls_from_message(choice.message)
        usage: Dict[str, int] = {}
        raw_usage = getattr(resp, "usage", None)
        if raw_usage:
            # ``or 0`` guards against fields reported as None, which int() rejects.
            usage = {
                key: int(getattr(raw_usage, key, 0) or 0)
                for key in ("prompt_tokens", "completion_tokens", "total_tokens")
            }
        return LlmResponse(
            content=content,
            tool_calls=tool_calls or None,
            finish_reason=getattr(choice, "finish_reason", None),
            usage=usage or None,
        )

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """
        Stream a request to Azure OpenAI.

        Emits `LlmStreamChunk` for textual deltas as they arrive. Tool-calls are
        accumulated and emitted in a final chunk when the stream ends.
        """
        payload = self._build_payload(request)
        stream = self._client.chat.completions.create(**payload, stream=True)
        # Builders for streamed tool-calls (index -> partial)
        tc_builders: Dict[int, Dict[str, Optional[str]]] = {}
        last_finish: Optional[str] = None
        for event in stream:
            # Events without choices (e.g. metadata packets) carry no deltas.
            if not getattr(event, "choices", None):
                continue
            choice = event.choices[0]
            delta = getattr(choice, "delta", None)
            if delta is None:
                # Some SDK versions use `event.choices[0].message` on the final packet
                last_finish = getattr(choice, "finish_reason", last_finish)
                continue
            content_piece: Optional[str] = getattr(delta, "content", None)
            if content_piece:
                yield LlmStreamChunk(content=content_piece)
            streamed_tool_calls = getattr(delta, "tool_calls", None)
            if streamed_tool_calls:
                for tc in streamed_tool_calls:
                    idx = getattr(tc, "index", 0) or 0
                    b = tc_builders.setdefault(
                        idx, {"id": None, "name": None, "arguments": ""}
                    )
                    if getattr(tc, "id", None):
                        b["id"] = tc.id
                    fn = getattr(tc, "function", None)
                    if fn is not None:
                        if getattr(fn, "name", None):
                            b["name"] = fn.name
                        if getattr(fn, "arguments", None):
                            b["arguments"] = (b["arguments"] or "") + fn.arguments
            last_finish = getattr(choice, "finish_reason", last_finish)
        # Emit final tool-calls chunk if any
        final_tool_calls: List[ToolCall] = []
        for b in tc_builders.values():
            if not b.get("name"):
                continue
            args_raw = b.get("arguments") or "{}"
            try:
                loaded = json.loads(args_raw)
                if isinstance(loaded, dict):
                    args_dict: Dict[str, Any] = loaded
                else:
                    args_dict = {"args": loaded}
            except Exception:
                args_dict = {"_raw": args_raw}
            final_tool_calls.append(
                ToolCall(
                    id=b.get("id") or "tool_call",
                    name=b["name"] or "tool",
                    arguments=args_dict,
                )
            )
        if final_tool_calls:
            yield LlmStreamChunk(tool_calls=final_tool_calls, finish_reason=last_finish)
        else:
            # Still emit a terminal chunk to signal completion
            yield LlmStreamChunk(finish_reason=last_finish or "stop")

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Validate tool schemas. Returns a list of error messages."""
        errors: List[str] = []
        # Basic checks; Azure OpenAI will enforce further validation server-side.
        for t in tools:
            if not t.name or len(t.name) > 64:
                errors.append(f"Invalid tool name: {t.name!r}")
        return errors

    # Internal helpers
    def _build_payload(self, request: LlmRequest) -> Dict[str, Any]:
        """Build the Chat Completions payload from an LlmRequest.

        The system prompt becomes the first message. Assistant tool calls are
        serialized to the OpenAI function-call format. Sampling/length
        parameters are adapted for reasoning deployments.
        """
        messages: List[Dict[str, Any]] = []
        if request.system_prompt:
            messages.append({"role": "system", "content": request.system_prompt})
        for m in request.messages:
            msg: Dict[str, Any] = {"role": m.role, "content": m.content}
            if m.role == "tool" and m.tool_call_id:
                msg["tool_call_id"] = m.tool_call_id
            elif m.role == "assistant" and m.tool_calls:
                msg["tool_calls"] = [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.name,
                            "arguments": json.dumps(tc.arguments),
                        },
                    }
                    for tc in m.tool_calls
                ]
            messages.append(msg)
        tools_payload: Optional[List[Dict[str, Any]]] = None
        if request.tools:
            tools_payload = [
                {
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters,
                    },
                }
                for t in request.tools
            ]
        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
        }
        # Reasoning models (GPT-5, o1, o3-mini) don't support temperature.
        if not self._is_reasoning_model:
            payload["temperature"] = request.temperature
        if request.max_tokens is not None:
            # Reasoning models reject ``max_tokens`` and require
            # ``max_completion_tokens`` instead.
            token_param = (
                "max_completion_tokens" if self._is_reasoning_model else "max_tokens"
            )
            payload[token_param] = request.max_tokens
        if tools_payload:
            payload["tools"] = tools_payload
            payload["tool_choice"] = "auto"
        return payload

    def _extract_tool_calls_from_message(self, message: Any) -> List[ToolCall]:
        """Extract tool calls from an OpenAI message object.

        Arguments are parsed as JSON. A non-object JSON value is wrapped as
        ``{"args": value}`` and unparseable text as ``{"_raw": text}``.
        Missing or empty arguments become ``{}``, matching ``stream_request``.
        """
        tool_calls: List[ToolCall] = []
        raw_tool_calls = getattr(message, "tool_calls", None) or []
        for tc in raw_tool_calls:
            fn = getattr(tc, "function", None)
            if not fn:
                continue
            args_raw = getattr(fn, "arguments", None) or "{}"
            try:
                loaded = json.loads(args_raw)
                if isinstance(loaded, dict):
                    args_dict: Dict[str, Any] = loaded
                else:
                    args_dict = {"args": loaded}
            except Exception:
                args_dict = {"_raw": args_raw}
            tool_calls.append(
                ToolCall(
                    id=getattr(tc, "id", None) or "tool_call",
                    name=getattr(fn, "name", None) or "tool",
                    arguments=args_dict,
                )
            )
        return tool_calls
================================================
FILE: src/vanna/integrations/azuresearch/__init__.py
================================================
"""
Azure AI Search integration for Vanna Agents.
"""
from .agent_memory import AzureAISearchAgentMemory
# Public API of the Azure AI Search integration package.
__all__ = ["AzureAISearchAgentMemory"]
================================================
FILE: src/vanna/integrations/azuresearch/agent_memory.py
================================================
"""
Azure AI Search implementation of AgentMemory.
This implementation uses Azure AI Search (formerly Azure Cognitive Search) as the vector store for tool-usage patterns.
"""
import json
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
import asyncio
from concurrent.futures import ThreadPoolExecutor
try:
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
SearchIndex,
SearchField,
SearchFieldDataType,
VectorSearch,
VectorSearchAlgorithmConfiguration,
)
from azure.core.credentials import AzureKeyCredential
AZURE_SEARCH_AVAILABLE = True
except ImportError:
AZURE_SEARCH_AVAILABLE = False
from vanna.capabilities.agent_memory import (
AgentMemory,
TextMemory,
TextMemorySearchResult,
ToolMemory,
ToolMemorySearchResult,
)
from vanna.core.tool import ToolContext
class AzureAISearchAgentMemory(AgentMemory):
    """Azure AI Search-based implementation of AgentMemory.

    Tool-usage memories and free-text memories share one search index:

    * Tool memories fill ``question``, ``tool_name``, ``args_json``,
      ``success`` and ``metadata_json``.
    * Text memories fill ``content`` and leave ``tool_name`` unset.

    Every query below uses ``tool_name`` being null or non-null to keep the two
    kinds apart.

    The Azure SDK clients used here are synchronous. Each public coroutine
    therefore runs its SDK calls on a small private thread pool.

    NOTE(review): ``VectorSearchAlgorithmConfiguration``,
    ``vector_search_configuration=`` and ``search(vector=..., top_k=...)``
    follow a pre-GA ``azure-search-documents`` vector API. The 11.4+ GA SDK
    uses ``HnswAlgorithmConfiguration``/``VectorSearchProfile`` and
    ``VectorizedQuery`` instead. Confirm against the SDK version you pin.

    Args:
        endpoint: Search service endpoint URL.
        api_key: API key, wrapped in an ``AzureKeyCredential``.
        index_name: Index to use; it is created on first use if missing.
        dimension: Length of the vectors stored in the ``embedding`` field.
    """

    def __init__(
        self,
        endpoint: str,
        api_key: str,
        index_name: str = "tool-memories",
        dimension: int = 384,
    ):
        if not AZURE_SEARCH_AVAILABLE:
            raise ImportError(
                "Azure Search is required for AzureAISearchAgentMemory. "
                "Install with: pip install azure-search-documents"
            )
        self.endpoint = endpoint
        self.api_key = api_key
        self.index_name = index_name
        self.dimension = dimension
        self._credential = AzureKeyCredential(api_key)
        # Clients are created lazily by _get_index_client/_get_search_client.
        self._search_client = None
        self._index_client = None
        # Blocking SDK calls are executed on this pool by _run().
        self._executor = ThreadPoolExecutor(max_workers=2)

    async def _run(self, func):
        """Run the blocking callable ``func`` on the private executor."""
        return await asyncio.get_running_loop().run_in_executor(self._executor, func)

    @staticmethod
    def _odata_literal(value: Any) -> str:
        """Render ``value`` as a quoted OData string literal for ``$filter``.

        OData escapes a single quote inside a string literal by doubling it.
        Without this, a value such as ``o'brien`` would end the literal early
        and could inject extra filter clauses.
        """
        return "'" + str(value).replace("'", "''") + "'"

    def _get_index_client(self):
        """Return the index client, creating it (and the index) on first call."""
        if self._index_client is None:
            self._index_client = SearchIndexClient(
                endpoint=self.endpoint, credential=self._credential
            )
            try:
                self._ensure_index_exists()
            except Exception:
                # Don't cache a client whose index setup failed; otherwise the
                # setup would never be retried.
                self._index_client = None
                raise
        return self._index_client

    def _get_search_client(self):
        """Return the document search client, creating it on first call."""
        if self._search_client is None:
            self._get_index_client()  # Ensure index exists
            self._search_client = SearchClient(
                endpoint=self.endpoint,
                index_name=self.index_name,
                credential=self._credential,
            )
        return self._search_client

    def _ensure_index_exists(self):
        """Create the index with this class's schema if it does not exist yet.

        Only a "not found" response triggers creation. Other failures (for
        example authentication errors) propagate unchanged.

        NOTE(review): an index created by an earlier version of this class has
        no ``content`` field. Uploads of text memories into such an index will
        be rejected until that field is added.
        """
        from azure.core.exceptions import ResourceNotFoundError

        try:
            self._index_client.get_index(self.index_name)
            return
        except ResourceNotFoundError:
            pass

        fields = [
            SearchField(name="memory_id", type=SearchFieldDataType.String, key=True),
            SearchField(
                name="question", type=SearchFieldDataType.String, searchable=True
            ),
            # Filterable so queries can select tool memories (non-null) versus
            # text memories (null).
            SearchField(
                name="tool_name", type=SearchFieldDataType.String, filterable=True
            ),
            SearchField(name="args_json", type=SearchFieldDataType.String),
            SearchField(
                name="timestamp",
                type=SearchFieldDataType.String,
                sortable=True,
                filterable=True,
            ),
            SearchField(
                name="success", type=SearchFieldDataType.Boolean, filterable=True
            ),
            SearchField(name="metadata_json", type=SearchFieldDataType.String),
            # Body of text memories (save_text_memory uploads this field).
            SearchField(
                name="content", type=SearchFieldDataType.String, searchable=True
            ),
            SearchField(
                name="embedding",
                type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
                searchable=True,
                vector_search_dimensions=self.dimension,
                vector_search_configuration="vector-config",
            ),
        ]
        vector_search = VectorSearch(
            algorithm_configurations=[
                VectorSearchAlgorithmConfiguration(name="vector-config")
            ]
        )
        index = SearchIndex(
            name=self.index_name, fields=fields, vector_search=vector_search
        )
        self._index_client.create_index(index)

    def _create_embedding(self, text: str) -> List[float]:
        """Create a deterministic placeholder vector for ``text``.

        This is NOT a semantic embedding. It derives ``self.dimension`` values
        in ``[0, 0.99]`` from the bits of the text's MD5 hash, so similar texts
        do not get similar vectors.
        """
        import hashlib

        hash_val = int(hashlib.md5(text.encode()).hexdigest(), 16)
        return [(hash_val >> i) % 100 / 100.0 for i in range(self.dimension)]

    @staticmethod
    def _doc_to_tool_memory(doc: Dict[str, Any]) -> ToolMemory:
        """Build a ToolMemory from a tool-memory document returned by search."""
        return ToolMemory(
            memory_id=doc["memory_id"],
            question=doc["question"],
            tool_name=doc["tool_name"],
            args=json.loads(doc.get("args_json") or "{}"),
            timestamp=doc.get("timestamp"),
            success=doc.get("success", True),
            metadata=json.loads(doc.get("metadata_json") or "{}"),
        )

    @staticmethod
    def _doc_to_text_memory(doc: Dict[str, Any]) -> TextMemory:
        """Build a TextMemory from a text-memory document returned by search."""
        return TextMemory(
            memory_id=doc["memory_id"],
            content=doc.get("content") or "",
            timestamp=doc.get("timestamp"),
        )

    def _delete_document(self, memory_id: str) -> bool:
        """Delete one document by key.

        Returns False if the document is absent or the delete fails.
        """
        client = self._get_search_client()
        try:
            # NOTE(review): Azure AI Search reports deletes of unknown keys as
            # successful. Look the key up first so unknown IDs return False, as
            # the other AgentMemory backends do.
            client.get_document(key=memory_id, selected_fields=["memory_id"])
            results = client.delete_documents(documents=[{"memory_id": memory_id}])
        except Exception:
            return False
        return all(result.succeeded for result in results)

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save a tool usage pattern; ``question`` is what gets embedded."""

        def _save():
            client = self._get_search_client()
            document = {
                "memory_id": str(uuid.uuid4()),
                "question": question,
                "tool_name": tool_name,
                "args_json": json.dumps(args),
                "timestamp": datetime.now().isoformat(),
                "success": success,
                "metadata_json": json.dumps(metadata or {}),
                "embedding": self._create_embedding(question),
            }
            client.upload_documents(documents=[document])

        await self._run(_save)

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Search successful tool usages whose question vector is close to ``question``.

        Hits scoring below ``similarity_threshold`` are dropped. ``rank`` is the
        1-based position in the service's result order, counted before that
        filtering.
        """

        def _search():
            client = self._get_search_client()
            embedding = self._create_embedding(question)
            # Text memories have a null ``success`` and so are excluded here too.
            filter_expr = "success eq true"
            if tool_name_filter:
                filter_expr += f" and tool_name eq {self._odata_literal(tool_name_filter)}"
            results = client.search(
                search_text=None, vector=embedding, top_k=limit, filter=filter_expr
            )
            search_results = []
            for i, doc in enumerate(results):
                # Azure returns the relevance score in @search.score
                similarity_score = doc.get("@search.score", 0)
                if similarity_score >= similarity_threshold:
                    search_results.append(
                        ToolMemorySearchResult(
                            memory=self._doc_to_tool_memory(doc),
                            similarity_score=similarity_score,
                            rank=i + 1,
                        )
                    )
            return search_results

        return await self._run(_search)

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Return up to ``limit`` tool memories, newest timestamp first."""

        def _get_recent():
            client = self._get_search_client()
            results = client.search(
                search_text="*",
                filter="tool_name ne null",  # tool memories only
                top=limit,
                order_by=["timestamp desc"],
            )
            return [self._doc_to_tool_memory(doc) for doc in results]

        return await self._run(_get_recent)

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a memory by ID. Returns False if it was not found or the delete failed."""
        return await self._run(lambda: self._delete_document(memory_id))

    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
        """Save a free-text memory (no ``tool_name``) and return it."""

        def _save():
            client = self._get_search_client()
            memory_id = str(uuid.uuid4())
            timestamp = datetime.now().isoformat()
            document = {
                "memory_id": memory_id,
                "content": content,
                "timestamp": timestamp,
                "embedding": self._create_embedding(content),
            }
            client.upload_documents(documents=[document])
            return TextMemory(memory_id=memory_id, content=content, timestamp=timestamp)

        return await self._run(_save)

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Search text memories whose vector is close to ``query``."""

        def _search():
            client = self._get_search_client()
            embedding = self._create_embedding(query)
            results = client.search(
                search_text=None,
                vector=embedding,
                top_k=limit,
                filter="tool_name eq null",  # text memories only
            )
            search_results = []
            for i, doc in enumerate(results):
                similarity_score = doc.get("@search.score", 0)
                if similarity_score >= similarity_threshold:
                    search_results.append(
                        TextMemorySearchResult(
                            memory=self._doc_to_text_memory(doc),
                            similarity_score=similarity_score,
                            rank=i + 1,
                        )
                    )
            return search_results

        return await self._run(_search)

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Return up to ``limit`` text memories, newest timestamp first."""

        def _get_recent():
            client = self._get_search_client()
            # Search results include every retrievable field, null ones too, so
            # a key-presence check cannot tell text and tool memories apart;
            # filter on the value instead.
            results = client.search(
                search_text="*",
                filter="tool_name eq null",
                top=limit,
                order_by=["timestamp desc"],
            )
            return [
                self._doc_to_text_memory(doc)
                for doc in results
                if not doc.get("tool_name")
            ][:limit]

        return await self._run(_get_recent)

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a text memory by ID. Returns False if it was not found or the delete failed."""
        return await self._run(lambda: self._delete_document(memory_id))

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Delete stored memories and return how many were deleted.

        Args:
            tool_name: Only delete memories of this tool.
            before_date: Only delete memories whose ISO ``timestamp`` string
                sorts before this value.

        With neither filter, every document in the index (including text
        memories) is deleted.
        """

        def _clear():
            client = self._get_search_client()
            filter_parts = []
            if tool_name:
                filter_parts.append(f"tool_name eq {self._odata_literal(tool_name)}")
            if before_date:
                filter_parts.append(f"timestamp lt {self._odata_literal(before_date)}")
            filter_expr = " and ".join(filter_parts) if filter_parts else None
            # NOTE(review): confirm that iterating the search results pages
            # through every match for large indexes.
            results = client.search(
                search_text="*", filter=filter_expr, select=["memory_id"]
            )
            docs_to_delete = [{"memory_id": doc["memory_id"]} for doc in results]
            if docs_to_delete:
                client.delete_documents(documents=docs_to_delete)
            return len(docs_to_delete)

        return await self._run(_clear)
================================================
FILE: src/vanna/integrations/bigquery/__init__.py
================================================
"""BigQuery integration for Vanna."""
from .sql_runner import BigQueryRunner
# Public API of this subpackage (the names exported by ``from ... import *``).
__all__ = ["BigQueryRunner"]
================================================
FILE: src/vanna/integrations/bigquery/sql_runner.py
================================================
"""BigQuery implementation of SqlRunner interface."""
from typing import Optional
import pandas as pd
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
from vanna.core.tool import ToolContext
class BigQueryRunner(SqlRunner):
    """Run SQL against Google BigQuery and return the results as DataFrames.

    The ``bigquery.Client`` is built lazily on the first query and reused for
    every later query made through this runner.
    """

    def __init__(self, project_id: str, cred_file_path: Optional[str] = None, **kwargs):
        """Store the BigQuery connection settings; no client is created yet.

        Args:
            project_id: Google Cloud project ID the client is bound to.
            cred_file_path: Optional path to a service-account JSON key file.
                When omitted, the client is built with default credentials.
            **kwargs: Extra keyword arguments forwarded to ``bigquery.Client``.

        Raises:
            ImportError: If ``google.cloud.bigquery`` or ``google.oauth2``
                cannot be imported.
        """
        try:
            from google.cloud import bigquery
            from google.oauth2 import service_account
        except ImportError as e:
            raise ImportError(
                "google-cloud-bigquery package is required. "
                "Install with: pip install 'vanna[bigquery]'"
            ) from e
        self.bigquery = bigquery
        self.service_account = service_account

        self.project_id = project_id
        self.cred_file_path = cred_file_path
        self.kwargs = kwargs
        self._client = None

    def _get_client(self):
        """Return the cached BigQuery client, building it on the first call."""
        if self._client is None:
            self._client = self._build_client()
        return self._client

    def _build_client(self):
        """Construct a new ``bigquery.Client`` from the stored settings."""
        if not self.cred_file_path:
            # No key file configured: rely on default credentials.
            return self.bigquery.Client(project=self.project_id, **self.kwargs)

        import json

        with open(self.cred_file_path, "r") as key_file:
            key_info = json.loads(key_file.read())
        credentials = self.service_account.Credentials.from_service_account_info(
            key_info,
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )
        return self.bigquery.Client(
            project=self.project_id, credentials=credentials, **self.kwargs
        )

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Run ``args.sql`` on BigQuery and return the full result set.

        Args:
            args: SQL query arguments; only ``args.sql`` is used.
            context: Tool execution context (not used by this runner).

        Returns:
            DataFrame with the query results.

        Raises:
            google.api_core.exceptions.GoogleAPIError: If query execution fails.
        """
        query_job = self._get_client().query(args.sql)
        return query_job.result().to_dataframe()
================================================
FILE: src/vanna/integrations/chromadb/__init__.py
================================================
"""
ChromaDB integration for Vanna Agents.
"""
from .agent_memory import ChromaAgentMemory
def get_device() -> str:
    """Return the preferred PyTorch device string for embedding models.

    Preference order:

    1. ``"cuda"`` when PyTorch reports a usable CUDA device.
    2. ``"mps"`` when the PyTorch build exposes ``torch.backends.mps`` and it
       reports itself available.
    3. ``"cpu"`` otherwise, including when PyTorch cannot be imported.

    Returns:
        str: One of ``"cuda"``, ``"mps"`` or ``"cpu"``.

    Examples:
        >>> device = get_device()
        >>> print(f"Using device: {device}")
        Using device: cuda
        # Use with ChromaDB SentenceTransformer embeddings
        >>> from chromadb.utils import embedding_functions
        >>> ef = embedding_functions.SentenceTransformerEmbeddingFunction(
        ...     model_name="sentence-transformers/all-MiniLM-L6-v2",
        ...     device=get_device()
        ... )
        >>> memory = ChromaAgentMemory(embedding_function=ef)
    """
    device = "cpu"
    try:
        import torch

        if torch.cuda.is_available():
            device = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            device = "mps"
    except ImportError:
        # PyTorch is optional; without it only the CPU is usable.
        device = "cpu"
    return device
def create_sentence_transformer_embedding_function(
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
    device: "str | None" = None,
):
    """Create a SentenceTransformer embedding function with automatic device detection.

    This convenience function builds ChromaDB's SentenceTransformer embedding
    function. If no device is given, it uses ``get_device()`` to pick CUDA,
    MPS or CPU.

    Note: This requires the 'sentence-transformers' package to be installed.
    Install with: pip install sentence-transformers

    Args:
        model_name: The name of the sentence-transformer model to use.
            Defaults to "sentence-transformers/all-MiniLM-L6-v2".
        device: Optional device string ('cuda', 'mps', or 'cpu'). If None,
            the best available device is detected.

    Returns:
        A ChromaDB SentenceTransformer embedding function configured for the
        specified/detected device.

    Raises:
        ImportError: If ChromaDB cannot be imported. The original import error
            is chained as ``__cause__``.

    Examples:
        # Automatic device detection (uses CUDA/MPS if available)
        >>> from vanna.integrations.chromadb import ChromaAgentMemory, create_sentence_transformer_embedding_function
        >>> ef = create_sentence_transformer_embedding_function()
        >>> memory = ChromaAgentMemory(embedding_function=ef)
        # Explicitly use CUDA
        >>> ef_cuda = create_sentence_transformer_embedding_function(device="cuda")
        >>> memory = ChromaAgentMemory(embedding_function=ef_cuda)
        # Use a different model
        >>> ef_large = create_sentence_transformer_embedding_function(
        ...     model_name="sentence-transformers/all-mpnet-base-v2"
        ... )
        >>> memory = ChromaAgentMemory(embedding_function=ef_large)
    """
    try:
        from chromadb.utils import embedding_functions
    except ImportError as e:
        # Chain the original error so the real missing module stays visible.
        raise ImportError("ChromaDB is required. Install with: pip install chromadb") from e
    if device is None:
        device = get_device()
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=model_name, device=device
    )
# Public API of this subpackage (the names exported by ``from ... import *``).
__all__ = [
    "ChromaAgentMemory",
    "get_device",
    "create_sentence_transformer_embedding_function",
]
================================================
FILE: src/vanna/integrations/chromadb/agent_memory.py
================================================
"""
Local vector database implementation of AgentMemory.
This implementation uses ChromaDB for local vector storage of tool usage patterns and text memories.
"""
import json
import hashlib
from datetime import datetime
from typing import Any, Dict, List, Optional
import asyncio
from concurrent.futures import ThreadPoolExecutor
try:
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions
try:
from chromadb.errors import NotFoundError
except ImportError:
# Fallback for older ChromaDB versions that don't have chromadb.errors
class NotFoundError(Exception):
"""Fallback NotFoundError for older ChromaDB versions."""
pass
CHROMADB_AVAILABLE = True
except ImportError:
CHROMADB_AVAILABLE = False
from vanna.capabilities.agent_memory import (
AgentMemory,
TextMemory,
TextMemorySearchResult,
ToolMemory,
ToolMemorySearchResult,
)
from vanna.core.tool import ToolContext
class ChromaAgentMemory(AgentMemory):
    """ChromaDB-based implementation of AgentMemory.

    This implementation uses ChromaDB's PersistentClient to store agent
    memories on disk, so they persist across application restarts. Tool-usage
    memories and text memories share one collection. Text memories are marked
    with ``is_text_memory=True`` in their metadata.

    Key Features:
    - Persistent storage: All memories are automatically saved to disk
    - Cheap reopening: when no custom embedding function is supplied, an
      existing collection is opened without passing one, so this class does
      not set up Chroma's default embedding function just to reopen it
    - Flexible embedding: Supports custom embedding functions or uses ChromaDB's
      default embedding function

    Args:
        persist_directory: Directory where ChromaDB will store its data.
            Defaults to "./chroma_memory". Use an absolute path for production
            deployments to ensure a consistent location across restarts.
        collection_name: Name of the ChromaDB collection to use. Multiple agents
            can share the same persist_directory with different collection names.
        embedding_function: Optional custom embedding function. A supplied
            function is used both when the collection is created and when an
            existing collection is reopened. If not provided, ChromaDB's
            DefaultEmbeddingFunction is used for new collections (it needs an
            internet connection on first use to download its model).

    Example:
        >>> from vanna.integrations.chromadb import ChromaAgentMemory
        >>> # Basic usage with defaults
        >>> memory = ChromaAgentMemory(
        ...     persist_directory="/app/data/chroma",
        ...     collection_name="my_agent_memory"
        ... )
        >>>
        >>> # With custom embedding function (e.g., for offline use)
        >>> from chromadb.utils import embedding_functions
        >>> ef = embedding_functions.SentenceTransformerEmbeddingFunction()
        >>> memory = ChromaAgentMemory(
        ...     persist_directory="/app/data/chroma",
        ...     embedding_function=ef
        ... )

    Note:
        The default embedding function downloads an ONNX model (~80MB) on first
        use. For air-gapped or offline environments, pre-download the model or
        provide a custom embedding function.

    Limitation:
        When no ``embedding_function`` is given, an existing collection is
        reopened without checking which embedding function created it. If you
        reuse a collection (same ``persist_directory`` and ``collection_name``)
        with a different embedding function than was originally used, queries
        may fail or produce incorrect similarity results. Keep each collection
        paired with one embedding function.
    """

    def __init__(
        self,
        persist_directory: str = "./chroma_memory",
        collection_name: str = "tool_memories",
        embedding_function=None,
    ):
        if not CHROMADB_AVAILABLE:
            raise ImportError(
                "ChromaDB is required for ChromaAgentMemory. Install with: pip install chromadb"
            )
        self.persist_directory = persist_directory
        self.collection_name = collection_name
        self._client = None
        self._collection = None
        # Chroma calls are blocking; _run() executes them on this pool.
        self._executor = ThreadPoolExecutor(max_workers=2)
        self._embedding_function = embedding_function

    async def _run(self, func):
        """Run the blocking callable ``func`` on the private executor."""
        return await asyncio.get_running_loop().run_in_executor(self._executor, func)

    def _get_client(self):
        """Get or create the persistent ChromaDB client."""
        if self._client is None:
            self._client = chromadb.PersistentClient(
                path=self.persist_directory,
                settings=Settings(anonymized_telemetry=False, allow_reset=True),
            )
        return self._client

    def _get_embedding_function(self):
        """Return the configured embedding function, defaulting to Chroma's.

        ChromaDB's DefaultEmbeddingFunction is used when none was supplied,
        which avoids requiring sentence-transformers as a hard dependency.
        """
        if self._embedding_function is None:
            self._embedding_function = embedding_functions.DefaultEmbeddingFunction()
        return self._embedding_function

    def _get_collection(self):
        """Open the configured collection, creating it on first use."""
        if self._collection is not None:
            return self._collection
        client = self._get_client()
        metadata = {"description": "Tool usage memories for learning"}

        if self._embedding_function is not None:
            # A caller-supplied embedding function must also be used when an
            # existing collection is reopened. Otherwise queries after a
            # restart would be embedded with Chroma's default model instead.
            self._collection = client.get_or_create_collection(
                name=self.collection_name,
                embedding_function=self._embedding_function,
                metadata=metadata,
            )
            return self._collection

        try:
            # No custom function: open an existing collection without passing
            # one, so the default embedding function is not set up here.
            self._collection = client.get_collection(name=self.collection_name)
        except (NotFoundError, ValueError):
            # Newer Chroma raises NotFoundError for a missing collection; older
            # releases raise ValueError instead (and may lack NotFoundError, in
            # which case the fallback class defined at import is never raised).
            self._collection = client.create_collection(
                name=self.collection_name,
                embedding_function=self._get_embedding_function(),
                metadata=metadata,
            )
        return self._collection

    def _create_memory_id(self) -> str:
        """Create a unique ID for a memory."""
        import uuid

        return str(uuid.uuid4())

    @staticmethod
    def _to_tool_memory(memory_id: str, metadata: Dict[str, Any]) -> ToolMemory:
        """Rebuild a ToolMemory from its Chroma ID and stored metadata."""
        return ToolMemory(
            memory_id=memory_id,
            question=metadata["question"],
            tool_name=metadata["tool_name"],
            args=json.loads(metadata.get("args_json", "{}")),
            timestamp=metadata.get("timestamp"),
            success=metadata.get("success", True),
            metadata=json.loads(metadata.get("metadata_json", "{}")),
        )

    @staticmethod
    def _to_text_memory(memory_id: str, metadata: Dict[str, Any]) -> TextMemory:
        """Rebuild a TextMemory from its Chroma ID and stored metadata."""
        return TextMemory(
            memory_id=memory_id,
            content=metadata.get("content", ""),
            timestamp=metadata.get("timestamp"),
        )

    @staticmethod
    def _scored_hits(results):
        """Yield ``(rank, id, similarity, metadata)`` for a single-query result.

        Similarity is ``max(0, 1 - distance)``. Chroma collections use L2
        distance unless configured otherwise, so this is a heuristic score
        rather than a cosine similarity. ``rank`` is the 1-based position in
        Chroma's order, counted before any threshold filtering.
        """
        if not results["ids"] or not results["ids"][0]:
            return
        rows = zip(
            results["ids"][0], results["distances"][0], results["metadatas"][0]
        )
        for rank, (id_, distance, metadata) in enumerate(rows, start=1):
            yield rank, id_, max(0, 1 - distance), metadata

    def _delete_if_present(self, memory_id: str) -> bool:
        """Delete ``memory_id``. Returns False if it is absent or any Chroma call fails."""
        collection = self._get_collection()
        try:
            found = collection.get(ids=[memory_id])
            if not found["ids"]:
                return False
            collection.delete(ids=[memory_id])
            return True
        except Exception:
            return False

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save a tool usage pattern; ``question`` is the text that gets embedded."""

        def _save():
            collection = self._get_collection()
            memory_id = self._create_memory_id()
            # ChromaDB only accepts primitive metadata values, so the
            # structured fields are stored as JSON strings.
            memory_data = {
                "question": question,
                "tool_name": tool_name,
                "args_json": json.dumps(args),
                "timestamp": datetime.now().isoformat(),
                "success": success,
                "metadata_json": json.dumps(metadata or {}),
            }
            collection.upsert(
                ids=[memory_id], documents=[question], metadatas=[memory_data]
            )

        await self._run(_save)

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Search successful tool usages similar to ``question``.

        Hits below ``similarity_threshold`` are dropped.
        """

        def _search():
            collection = self._get_collection()
            # ChromaDB requires an explicit $and to combine several conditions.
            if tool_name_filter:
                where_filter = {
                    "$and": [{"success": True}, {"tool_name": tool_name_filter}]
                }
            else:
                where_filter = {"success": True}
            results = collection.query(
                query_texts=[question], n_results=limit, where=where_filter
            )
            return [
                ToolMemorySearchResult(
                    memory=self._to_tool_memory(id_, meta),
                    similarity_score=score,
                    rank=rank,
                )
                for rank, id_, score, meta in self._scored_hits(results)
                if score >= similarity_threshold
            ]

        return await self._run(_search)

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Return up to ``limit`` tool memories, most recent timestamp first."""

        def _get_recent():
            results = self._get_collection().get()
            if not results["metadatas"] or not results["ids"]:
                return []
            dated = [
                (meta.get("timestamp", ""), self._to_tool_memory(doc_id, meta))
                for doc_id, meta in zip(results["ids"], results["metadatas"])
                if not meta.get("is_text_memory")  # skip text memories
            ]
            dated.sort(key=lambda pair: pair[0], reverse=True)
            return [memory for _, memory in dated[:limit]]

        return await self._run(_get_recent)

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a memory by its ID. Returns True if deleted, False if not found."""
        return await self._run(lambda: self._delete_if_present(memory_id))

    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
        """Save a text memory (flagged with ``is_text_memory``) and return it."""

        def _save():
            collection = self._get_collection()
            memory_id = self._create_memory_id()
            timestamp = datetime.now().isoformat()
            memory_data = {
                "content": content,
                "timestamp": timestamp,
                "is_text_memory": True,
            }
            collection.upsert(
                ids=[memory_id], documents=[content], metadatas=[memory_data]
            )
            return TextMemory(memory_id=memory_id, content=content, timestamp=timestamp)

        return await self._run(_save)

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Search text memories similar to ``query``.

        Hits below ``similarity_threshold`` are dropped.
        """

        def _search():
            collection = self._get_collection()
            results = collection.query(
                query_texts=[query], n_results=limit, where={"is_text_memory": True}
            )
            return [
                TextMemorySearchResult(
                    memory=self._to_text_memory(id_, meta),
                    similarity_score=score,
                    rank=rank,
                )
                for rank, id_, score, meta in self._scored_hits(results)
                if score >= similarity_threshold
            ]

        return await self._run(_search)

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Return up to ``limit`` text memories, most recent timestamp first."""

        def _get_recent():
            results = self._get_collection().get(where={"is_text_memory": True})
            if not results["metadatas"] or not results["ids"]:
                return []
            dated = [
                (meta.get("timestamp", ""), self._to_text_memory(doc_id, meta))
                for doc_id, meta in zip(results["ids"], results["metadatas"])
            ]
            dated.sort(key=lambda pair: pair[0], reverse=True)
            return [memory for _, memory in dated[:limit]]

        return await self._run(_get_recent)

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a text memory by its ID. Returns True if deleted, False if not found."""
        return await self._run(lambda: self._delete_if_present(memory_id))

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Delete stored memories and return how many were deleted.

        Args:
            tool_name: Only delete memories whose ``tool_name`` matches.
            before_date: Only delete memories whose ISO ``timestamp`` string
                sorts before this value.

        With neither filter, every entry in the collection (including text
        memories) is deleted.
        """

        def _clear():
            collection = self._get_collection()
            where_filter = {"tool_name": tool_name} if tool_name else None
            results = collection.get(where=where_filter)
            if not results["ids"]:
                return 0
            ids_to_delete = []
            for doc_id, metadata in zip(results["ids"], results["metadatas"]):
                if before_date:
                    memory_date = metadata.get("timestamp", "")
                    if memory_date and memory_date < before_date:
                        ids_to_delete.append(doc_id)
                else:
                    ids_to_delete.append(doc_id)
            if ids_to_delete:
                collection.delete(ids=ids_to_delete)
            return len(ids_to_delete)

        return await self._run(_clear)
================================================
FILE: src/vanna/integrations/clickhouse/__init__.py
================================================
"""ClickHouse integration for Vanna."""
from .sql_runner import ClickHouseRunner
# Public API of this subpackage (the names exported by ``from ... import *``).
__all__ = ["ClickHouseRunner"]
================================================
FILE: src/vanna/integrations/clickhouse/sql_runner.py
================================================
"""ClickHouse implementation of SqlRunner interface."""
from typing import Optional
import pandas as pd
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
from vanna.core.tool import ToolContext
class ClickHouseRunner(SqlRunner):
    """Run SQL against ClickHouse through ``clickhouse_connect``.

    A fresh client is opened for every query and closed afterwards, so the
    runner holds no connection between calls.
    """

    def __init__(
        self,
        host: str,
        database: str,
        user: str,
        password: str,
        port: int = 8123,
        **kwargs,
    ):
        """Store the ClickHouse connection settings; nothing connects yet.

        Args:
            host: Database host address
            database: Database name
            user: Database user
            password: Database password
            port: Database port (default: 8123)
            **kwargs: Additional clickhouse_connect connection parameters

        Raises:
            ImportError: If ``clickhouse_connect`` cannot be imported.
        """
        try:
            import clickhouse_connect
        except ImportError as e:
            raise ImportError(
                "clickhouse-connect package is required. "
                "Install with: pip install 'vanna[clickhouse]'"
            ) from e
        self.clickhouse_connect = clickhouse_connect

        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database
        self.kwargs = kwargs

    def _connect(self):
        """Open a new ClickHouse client from the stored settings."""
        return self.clickhouse_connect.get_client(
            host=self.host,
            port=self.port,
            username=self.user,
            password=self.password,
            database=self.database,
            **self.kwargs,
        )

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Run ``args.sql`` on ClickHouse and return the rows as a DataFrame.

        Args:
            args: SQL query arguments; only ``args.sql`` is used.
            context: Tool execution context (not used by this runner).

        Returns:
            DataFrame whose columns are the result's column names.

        Raises:
            Exception: If query execution fails
        """
        client = self._connect()
        try:
            query_result = client.query(args.sql)
            return pd.DataFrame(
                query_result.result_rows, columns=query_result.column_names
            )
        finally:
            # Always release the per-query client, even when the query fails.
            client.close()
================================================
FILE: src/vanna/integrations/duckdb/__init__.py
================================================
"""DuckDB integration for Vanna."""
from .sql_runner import DuckDBRunner
# Public API of this subpackage (the names exported by ``from ... import *``).
__all__ = ["DuckDBRunner"]
================================================
FILE: src/vanna/integrations/duckdb/sql_runner.py
================================================
"""DuckDB implementation of SqlRunner interface."""
from typing import Optional
import pandas as pd
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
from vanna.core.tool import ToolContext
class DuckDBRunner(SqlRunner):
    """DuckDB implementation of the SqlRunner interface.

    One DuckDB connection is opened lazily on the first query and reused for
    every later query made through this runner.
    """

    def __init__(
        self, database_path: str = ":memory:", init_sql: Optional[str] = None, **kwargs
    ):
        """Initialize with DuckDB connection parameters.

        Args:
            database_path: Path to the DuckDB database file.
                Use ":memory:" for in-memory database (default).
                Use "md:" or "motherduck:" for MotherDuck database.
            init_sql: Optional SQL to run once when the connection is opened.
            **kwargs: Additional duckdb connection parameters

        Raises:
            ImportError: If the ``duckdb`` package cannot be imported.
        """
        try:
            import duckdb

            self.duckdb = duckdb
        except ImportError as e:
            raise ImportError(
                "duckdb package is required. Install with: pip install 'vanna[duckdb]'"
            ) from e
        self.database_path = database_path
        self.init_sql = init_sql
        self.kwargs = kwargs
        self._conn = None

    def _get_connection(self):
        """Return the cached connection, opening and initialising it on first use.

        The connection is cached only after ``init_sql`` (if any) has run
        successfully. If initialisation fails, the connection is closed and the
        error is re-raised, so the next call starts over instead of reusing a
        half-initialised connection.
        """
        if self._conn is None:
            conn = self.duckdb.connect(self.database_path, **self.kwargs)
            if self.init_sql:
                try:
                    conn.query(self.init_sql)
                except Exception:
                    conn.close()
                    raise
            self._conn = conn
        return self._conn

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute SQL query against DuckDB database and return results as DataFrame.

        Args:
            args: SQL query arguments; only ``args.sql`` is used.
            context: Tool execution context (not used by this runner).

        Returns:
            DataFrame with the query results. Statements that produce no result
            set (e.g. DDL) yield an empty DataFrame.

        Raises:
            duckdb.Error: If query execution fails
        """
        conn = self._get_connection()
        relation = conn.query(args.sql)
        if relation is None:
            # DuckDB returns a relation only for result-producing statements;
            # other statements are executed immediately and return None.
            # Calling .to_df() on None would raise AttributeError.
            return pd.DataFrame()
        return relation.to_df()
================================================
FILE: src/vanna/integrations/faiss/__init__.py
================================================
"""
FAISS integration for Vanna Agents.
"""
from .agent_memory import FAISSAgentMemory
__all__ = ["FAISSAgentMemory"]
================================================
FILE: src/vanna/integrations/faiss/agent_memory.py
================================================
"""
FAISS vector database implementation of AgentMemory.
This implementation uses FAISS for local vector storage of tool usage patterns.
"""
import json
import uuid
import pickle
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
import asyncio
from concurrent.futures import ThreadPoolExecutor
import numpy as np
try:
import faiss
FAISS_AVAILABLE = True
except ImportError:
FAISS_AVAILABLE = False
from vanna.capabilities.agent_memory import (
AgentMemory,
TextMemory,
TextMemorySearchResult,
ToolMemory,
ToolMemorySearchResult,
)
from vanna.core.tool import ToolContext
class FAISSAgentMemory(AgentMemory):
    """FAISS-based implementation of AgentMemory.

    Tool-usage memories and text memories share one flat FAISS index. Each
    vector's metadata is stored in ``self._metadata``, keyed by the vector's
    position in the index; text memories carry ``"is_text_memory": True``.
    The index and the metadata are written to ``<index_path>/index.faiss`` and
    ``<index_path>/metadata.pkl`` after every mutation.

    Deleting a memory only removes its metadata entry. The vector remains in
    the index and is skipped during search because it has no metadata.

    Embeddings come from ``_create_embedding``, which is a hash-based
    placeholder rather than a semantic embedding model.
    """

    def __init__(
        self,
        index_path: Optional[str] = None,
        persist_path: Optional[str] = None,
        dimension: int = 384,
        metric: str = "cosine",
    ):
        """Create the memory store and load (or create) the on-disk index.

        Args:
            index_path: Directory holding ``index.faiss`` and ``metadata.pkl``.
            persist_path: Alias for ``index_path``; wins if both are given.
                Defaults to ``"./faiss_index"`` when neither is given.
            dimension: Dimensionality of the embeddings and the FAISS index.
            metric: ``"cosine"`` uses an inner-product index over normalized
                embeddings; any other value uses an L2-distance index.

        Raises:
            ImportError: If the ``faiss`` package is not installed.
        """
        import threading

        if not FAISS_AVAILABLE:
            raise ImportError(
                "FAISS is required for FAISSAgentMemory. Install with: pip install faiss-cpu"
            )
        # Accept either index_path or persist_path for backward compatibility
        self.index_path = persist_path or index_path or "./faiss_index"
        self.dimension = dimension
        self.metric = metric
        self._index = None
        self._metadata = {}
        self._executor = ThreadPoolExecutor(max_workers=2)
        # The executor may run two operations at once. Every operation that
        # touches the index or metadata holds this lock. Without it, two
        # concurrent saves could both read the same ``ntotal - 1`` slot, and
        # pickling could see the metadata dict mid-mutation.
        self._lock = threading.Lock()
        self._load_index()

    def _new_index(self):
        """Create an empty FAISS index matching ``self.metric``."""
        if self.metric == "cosine":
            return faiss.IndexFlatIP(self.dimension)
        return faiss.IndexFlatL2(self.dimension)

    def _load_index(self):
        """Load the persisted index and metadata, or start a fresh index."""
        index_file = os.path.join(self.index_path, "index.faiss")
        metadata_file = os.path.join(self.index_path, "metadata.pkl")
        if os.path.exists(index_file) and os.path.exists(metadata_file):
            # Load existing index
            self._index = faiss.read_index(index_file)
            # SECURITY: pickle.load can execute arbitrary code. Only point
            # index_path at a directory whose contents you trust.
            with open(metadata_file, "rb") as f:
                self._metadata = pickle.load(f)
        else:
            # Create new index
            os.makedirs(self.index_path, exist_ok=True)
            self._index = self._new_index()
            self._metadata = {}

    def _save_index(self):
        """Save the FAISS index and the metadata to disk."""
        index_file = os.path.join(self.index_path, "index.faiss")
        metadata_file = os.path.join(self.index_path, "metadata.pkl")
        faiss.write_index(self._index, index_file)
        with open(metadata_file, "wb") as f:
            pickle.dump(self._metadata, f)

    def _create_embedding(self, text: str) -> np.ndarray:
        """Create a simple embedding from text (placeholder)."""
        import hashlib

        hash_val = int(hashlib.md5(text.encode()).hexdigest(), 16)
        embedding = np.array(
            [(hash_val >> i) % 100 / 100.0 for i in range(self.dimension)],
            dtype=np.float32,
        )
        # Normalize for cosine similarity
        if self.metric == "cosine":
            norm = np.linalg.norm(embedding)
            if norm > 0:
                embedding = embedding / norm
        return embedding

    def _to_similarity(self, distance: float) -> float:
        """Map a raw FAISS score to a similarity where larger means more similar."""
        if self.metric == "cosine":
            # Inner product of normalized vectors is the cosine similarity.
            return float(distance)
        return 1.0 / (1.0 + float(distance))

    def _nearest(self, text: str, limit: int) -> List[tuple]:
        """Return ``(similarity, metadata)`` pairs for live neighbours of ``text``.

        Fetches up to ``limit * 3`` neighbours because callers filter the hits
        afterwards. Hits whose metadata has been deleted are dropped. Pairs are
        returned in FAISS rank order. Must be called with ``self._lock`` held.
        """
        k = min(limit * 3, self._index.ntotal)
        if k <= 0:
            return []
        embedding = self._create_embedding(text)
        distances, indices = self._index.search(np.array([embedding]), k)
        hits = []
        for dist, raw_idx in zip(distances[0], indices[0]):
            idx = int(raw_idx)
            if idx == -1 or idx not in self._metadata:
                continue
            hits.append((self._to_similarity(dist), self._metadata[idx]))
        return hits

    @staticmethod
    def _tool_memory_from(entry: Dict[str, Any]) -> ToolMemory:
        """Build a ToolMemory from a stored (non-text) metadata entry."""
        return ToolMemory(
            memory_id=entry["memory_id"],
            question=entry["question"],
            tool_name=entry["tool_name"],
            args=entry["args"],
            timestamp=entry.get("timestamp"),
            success=entry.get("success", True),
            metadata=entry.get("metadata", {}),
        )

    @staticmethod
    def _text_memory_from(entry: Dict[str, Any]) -> TextMemory:
        """Build a TextMemory from a stored text metadata entry."""
        return TextMemory(
            memory_id=entry["memory_id"],
            content=entry["content"],
            timestamp=entry.get("timestamp"),
        )

    def _append(self, text: str, entry: Dict[str, Any]) -> None:
        """Add ``text``'s embedding to the index and persist ``entry`` for it.

        Must be called with ``self._lock`` held.
        """
        self._index.add(np.array([self._create_embedding(text)]))
        # The vector just added occupies the last slot of the index.
        self._metadata[self._index.ntotal - 1] = entry
        self._save_index()

    def _delete_entry(self, memory_id: str) -> bool:
        """Drop the metadata entry with ``memory_id``; return whether it existed.

        Must be called with ``self._lock`` held.
        """
        for idx, entry in self._metadata.items():
            if entry["memory_id"] == memory_id:
                del self._metadata[idx]
                self._save_index()
                return True
        return False

    async def _run_locked(self, fn):
        """Run ``fn`` on the worker pool while holding ``self._lock``."""

        def _call():
            with self._lock:
                return fn()

        return await asyncio.get_running_loop().run_in_executor(self._executor, _call)

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save a tool usage pattern, indexed by the embedding of ``question``."""

        def _save():
            self._append(
                question,
                {
                    "memory_id": str(uuid.uuid4()),
                    "question": question,
                    "tool_name": tool_name,
                    "args": args,
                    "timestamp": datetime.now().isoformat(),
                    "success": success,
                    "metadata": metadata or {},
                },
            )

        await self._run_locked(_save)

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Search for successful tool usages whose question resembles ``question``.

        Text memories share the index and are skipped. Results are ranked
        from 1 in FAISS order, and at most ``limit`` are returned.
        """

        def _search():
            search_results: List[ToolMemorySearchResult] = []
            for similarity_score, entry in self._nearest(question, limit):
                # Text memories have no question/tool_name/args fields.
                if entry.get("is_text_memory"):
                    continue
                # Filter by success
                if not entry.get("success", True):
                    continue
                # Filter by tool name
                if tool_name_filter and entry.get("tool_name") != tool_name_filter:
                    continue
                if similarity_score < similarity_threshold:
                    continue
                search_results.append(
                    ToolMemorySearchResult(
                        memory=self._tool_memory_from(entry),
                        similarity_score=similarity_score,
                        rank=len(search_results) + 1,
                    )
                )
                if len(search_results) >= limit:
                    break
            return search_results

        return await self._run_locked(_search)

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Get up to ``limit`` most recently added tool memories, newest first."""

        def _get_recent():
            # Filter out text memories BEFORE applying the limit, so they do
            # not consume slots in the result.
            tool_entries = [
                entry
                for entry in self._metadata.values()
                if not entry.get("is_text_memory")
            ]
            tool_entries.sort(key=lambda m: m.get("timestamp", ""), reverse=True)
            return [self._tool_memory_from(entry) for entry in tool_entries[:limit]]

        return await self._run_locked(_get_recent)

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a memory by its ID. Returns True if an entry was removed."""
        return await self._run_locked(lambda: self._delete_entry(memory_id))

    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
        """Save a text memory, indexed by the embedding of ``content``."""

        def _save():
            memory_id = str(uuid.uuid4())
            timestamp = datetime.now().isoformat()
            self._append(
                content,
                {
                    "memory_id": memory_id,
                    "content": content,
                    "timestamp": timestamp,
                    "is_text_memory": True,
                },
            )
            return TextMemory(memory_id=memory_id, content=content, timestamp=timestamp)

        return await self._run_locked(_save)

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Search text memories similar to ``query``; at most ``limit`` results."""

        def _search():
            search_results: List[TextMemorySearchResult] = []
            for similarity_score, entry in self._nearest(query, limit):
                # Filter for text memories only
                if not entry.get("is_text_memory", False):
                    continue
                if similarity_score < similarity_threshold:
                    continue
                search_results.append(
                    TextMemorySearchResult(
                        memory=self._text_memory_from(entry),
                        similarity_score=similarity_score,
                        rank=len(search_results) + 1,
                    )
                )
                if len(search_results) >= limit:
                    break
            return search_results

        return await self._run_locked(_search)

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Get up to ``limit`` most recently added text memories, newest first."""

        def _get_recent():
            text_entries = [
                entry
                for entry in self._metadata.values()
                if entry.get("is_text_memory", False)
            ]
            text_entries.sort(key=lambda m: m.get("timestamp", ""), reverse=True)
            return [self._text_memory_from(entry) for entry in text_entries[:limit]]

        return await self._run_locked(_get_recent)

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a text memory by its ID. Returns True if an entry was removed."""
        return await self._run_locked(lambda: self._delete_entry(memory_id))

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Clear stored memories and return how many metadata entries were removed.

        An entry is removed when it matches ``tool_name`` (if given) and its
        timestamp is earlier than ``before_date`` (if given; compared as
        strings). With no filters, everything is removed and the index itself
        is recreated empty.
        """

        def _clear():
            indices_to_remove = []
            for idx, entry in self._metadata.items():
                if tool_name and entry.get("tool_name") != tool_name:
                    continue
                if before_date and entry.get("timestamp", "") >= before_date:
                    continue
                indices_to_remove.append(idx)
            # Remove from metadata
            for idx in indices_to_remove:
                del self._metadata[idx]
            # If clearing all, recreate index
            if not tool_name and not before_date:
                self._index = self._new_index()
                self._metadata = {}
            self._save_index()
            return len(indices_to_remove)

        return await self._run_locked(_clear)
================================================
FILE: src/vanna/integrations/google/__init__.py
================================================
"""
Google AI integrations.
This module provides Google AI service implementations.
"""
from .gemini import GeminiLlmService
__all__ = ["GeminiLlmService"]
================================================
FILE: src/vanna/integrations/google/gemini.py
================================================
"""
Google Gemini LLM service implementation.
Implements the LlmService interface using Google's Gen AI SDK
(google-genai). Supports non-streaming and streaming text output,
as well as function calling (tool use).
"""
from __future__ import annotations
import json
import logging
import os
from typing import Any, AsyncGenerator, Dict, List, Optional
logger = logging.getLogger(__name__)
from vanna.core.llm import (
LlmService,
LlmRequest,
LlmResponse,
LlmStreamChunk,
)
from vanna.core.tool import ToolCall, ToolSchema
class GeminiLlmService(LlmService):
    """Google Gemini-backed LLM service.

    Args:
        model: Gemini model name (e.g., "gemini-2.5-pro", "gemini-2.5-flash").
            Defaults to "gemini-2.5-pro". Can also be set via GEMINI_MODEL env var.
        api_key: API key; falls back to env `GOOGLE_API_KEY` or `GEMINI_API_KEY`.
            GOOGLE_API_KEY takes precedence if both are set.
        temperature: Fallback temperature for generation (0.0-2.0), used when a
            request's own ``temperature`` is None. Default 0.7.
        extra_config: Extra kwargs forwarded to GenerateContentConfig.
    """

    # Synthesized tool-call IDs are "<prefix><function name>", so tool results
    # (which carry only the ID) can be mapped back to a function name.
    _TOOL_CALL_ID_PREFIX = "call_"

    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: float = 0.7,
        **extra_config: Any,
    ) -> None:
        try:
            from google import genai
            from google.genai import types
        except Exception as e:  # pragma: no cover
            raise ImportError(
                "google-genai package is required. "
                "Install with: pip install 'vanna[gemini]'"
            ) from e
        self.model_name = model or os.getenv("GEMINI_MODEL", "gemini-2.5-pro")
        # Check GOOGLE_API_KEY first (takes precedence), then GEMINI_API_KEY
        api_key = api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError(
                "Google API key is required. Set GOOGLE_API_KEY or GEMINI_API_KEY "
                "environment variable, or pass api_key parameter."
            )
        # Store modules for use in methods
        self._genai = genai
        self._types = types
        # Create client
        self._client = genai.Client(api_key=api_key)
        # Store generation config
        self.temperature = temperature
        self.extra_config = extra_config

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Send a non-streaming request to Gemini and return the response.

        Raises:
            Exception: Any SDK error is logged and re-raised unchanged.
        """
        contents, config = self._build_payload(request)
        try:
            # Use the SDK's async client so the event loop is not blocked
            # while waiting on the network.
            response = await self._client.aio.models.generate_content(
                model=self.model_name,
                contents=contents,
                config=config,
            )
            # Full responses are large and may contain user data; keep them at DEBUG.
            logger.debug("Gemini response: %s", response)
            text_content, tool_calls = self._parse_response(response)
            usage = self._extract_usage(response)
            finish_reason = None
            if response.candidates:
                finish_reason = self._normalize_finish_reason(
                    response.candidates[0].finish_reason
                )
            return LlmResponse(
                content=text_content or None,
                tool_calls=tool_calls or None,
                finish_reason=finish_reason,
                usage=usage or None,
            )
        except Exception as e:
            logger.error(f"Error calling Gemini API: {e}")
            raise

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream a request to Gemini.

        Yields text chunks as they arrive. Function calls from any streamed
        chunk are collected and emitted together in one final chunk. If there
        are none, the final chunk carries only the finish reason (default
        ``"stop"``). Nothing is yielded if the stream produced no chunks.
        """
        contents, config = self._build_payload(request)
        logger.info(f"Gemini streaming request with model: {self.model_name}")
        try:
            stream = await self._client.aio.models.generate_content_stream(
                model=self.model_name,
                contents=contents,
                config=config,
            )
            saw_chunk = False
            tool_calls: List[ToolCall] = []
            finish_reason: Optional[str] = None
            async for chunk in stream:
                saw_chunk = True
                text, chunk_tool_calls = self._parse_response_chunk(chunk)
                if text:
                    yield LlmStreamChunk(content=text)
                # Function calls are not guaranteed to be in the last chunk.
                tool_calls.extend(chunk_tool_calls)
                if chunk.candidates:
                    reason = self._normalize_finish_reason(
                        chunk.candidates[0].finish_reason
                    )
                    if reason is not None:
                        finish_reason = reason
            if not saw_chunk:
                return
            if tool_calls:
                yield LlmStreamChunk(
                    tool_calls=tool_calls,
                    finish_reason=finish_reason,
                )
            else:
                yield LlmStreamChunk(finish_reason=finish_reason or "stop")
        except Exception as e:
            logger.error(f"Error streaming from Gemini API: {e}")
            raise

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Basic validation of tool schemas for Gemini."""
        errors: List[str] = []
        for t in tools:
            if not t.name:
                errors.append("Tool name is required")
            if not t.description:
                errors.append(f"Tool {t.name}: description is required")
        return errors

    # Internal helpers

    @staticmethod
    def _normalize_finish_reason(reason: Any) -> Optional[str]:
        """Convert a Gemini finish reason (enum, string or None) to lowercase text.

        Uses the enum member's ``value`` when present. ``str()`` of a
        str-mixin Enum member gives ``"ClassName.MEMBER"`` rather than the
        value, which would produce strings like ``"finishreason.stop"``.
        """
        if reason is None:
            return None
        return str(getattr(reason, "value", reason)).lower()

    @staticmethod
    def _extract_usage(response: Any) -> Dict[str, int]:
        """Extract token counts from ``response.usage_metadata``, or {} if absent."""
        usage_metadata = getattr(response, "usage_metadata", None)
        if usage_metadata is None:
            return {}
        try:
            return {
                "prompt_tokens": int(usage_metadata.prompt_token_count or 0),
                "completion_tokens": int(usage_metadata.candidates_token_count or 0),
                "total_tokens": int(usage_metadata.total_token_count or 0),
            }
        except (AttributeError, TypeError, ValueError):
            return {}

    @classmethod
    def _function_name_from_call_id(cls, tool_call_id: str) -> str:
        """Recover the function name from an ID built by ``_parse_response``.

        Only the leading prefix is stripped; a plain ``str.replace`` would also
        remove "call_" inside names such as ``recall_data``.
        """
        prefix = cls._TOOL_CALL_ID_PREFIX
        if tool_call_id.startswith(prefix):
            return tool_call_id[len(prefix):]
        return tool_call_id

    @staticmethod
    def _tool_result_payload(content: Any) -> Dict[str, Any]:
        """Turn a tool message's content into the dict FunctionResponse expects.

        JSON objects are passed through. Any other JSON value is wrapped as
        ``{"result": <value>}``, and non-JSON text (or None) as
        ``{"result": content}``.
        """
        try:
            parsed = json.loads(content)
        except (json.JSONDecodeError, TypeError):
            return {"result": content}
        if isinstance(parsed, dict):
            return parsed
        return {"result": parsed}

    def _build_payload(self, request: LlmRequest) -> tuple[List[Any], Any]:
        """Build the payload for Gemini API.

        Returns:
            Tuple of (contents, config)
        """
        # Build contents (messages) for Gemini
        contents = []
        # System prompt handling - Gemini supports system instructions in config
        system_instruction = None
        if request.system_prompt:
            system_instruction = request.system_prompt
        for m in request.messages:
            # Map roles: user -> user, assistant -> model, tool -> function
            if m.role == "user":
                contents.append(
                    self._types.Content(
                        role="user", parts=[self._types.Part(text=m.content)]
                    )
                )
            elif m.role == "assistant":
                parts = []
                # Add text content if present
                if m.content and m.content.strip():
                    parts.append(self._types.Part(text=m.content))
                # Add tool calls if present
                if m.tool_calls:
                    for tc in m.tool_calls:
                        parts.append(
                            self._types.Part(
                                function_call=self._types.FunctionCall(
                                    name=tc.name, args=tc.arguments
                                )
                            )
                        )
                if parts:
                    contents.append(self._types.Content(role="model", parts=parts))
            elif m.role == "tool":
                # Tool results in Gemini format; messages without an ID are dropped
                if m.tool_call_id:
                    response_content = self._tool_result_payload(m.content)
                    function_name = self._function_name_from_call_id(m.tool_call_id)
                    contents.append(
                        self._types.Content(
                            role="function",
                            parts=[
                                self._types.Part(
                                    function_response=self._types.FunctionResponse(
                                        name=function_name, response=response_content
                                    )
                                )
                            ],
                        )
                    )
        # Build tools configuration if tools are provided
        tools = None
        if request.tools:
            function_declarations = []
            for tool in request.tools:
                # Clean schema to remove unsupported fields
                cleaned_parameters = self._clean_schema_for_gemini(tool.parameters)
                function_declarations.append(
                    {
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": cleaned_parameters,
                    }
                )
            if function_declarations:
                tools = [self._types.Tool(function_declarations=function_declarations)]
        # Fall back to the service-level temperature when the request has none.
        temperature = (
            request.temperature
            if request.temperature is not None
            else self.temperature
        )
        # Build generation config; extra_config may override any field above.
        config_dict = {
            "temperature": temperature,
            **self.extra_config,
        }
        if request.max_tokens is not None:
            config_dict["max_output_tokens"] = request.max_tokens
        if tools:
            config_dict["tools"] = tools
        if system_instruction:
            config_dict["system_instruction"] = system_instruction
        config = self._types.GenerateContentConfig(**config_dict)
        return contents, config

    def _parse_response(self, response: Any) -> tuple[str, List[ToolCall]]:
        """Parse a Gemini response into text and tool calls (first candidate only)."""
        text_parts: List[str] = []
        tool_calls: List[ToolCall] = []
        if not response.candidates:
            return "", []
        candidate = response.candidates[0]
        if (
            hasattr(candidate, "content")
            and candidate.content
            and hasattr(candidate.content, "parts")
            and candidate.content.parts
        ):
            for part in candidate.content.parts:
                # Check for text content
                if hasattr(part, "text") and part.text:
                    text_parts.append(part.text)
                # Check for function calls
                if hasattr(part, "function_call") and part.function_call:
                    fc = part.function_call
                    # args can be None for functions called without arguments
                    fc_args = getattr(fc, "args", None)
                    tool_calls.append(
                        ToolCall(
                            id=f"{self._TOOL_CALL_ID_PREFIX}{fc.name}",
                            name=fc.name,
                            arguments=dict(fc_args) if fc_args else {},
                        )
                    )
        text_content = "".join(text_parts)
        return text_content, tool_calls

    def _parse_response_chunk(self, chunk: Any) -> tuple[str, List[ToolCall]]:
        """Parse a streaming chunk (same logic as _parse_response)."""
        return self._parse_response(chunk)

    def _clean_schema_for_gemini(self, schema: Dict[str, Any]) -> Dict[str, Any]:
        """Clean JSON Schema to only include fields supported by Gemini.

        Gemini only supports a subset of OpenAPI schema. This removes unsupported
        fields like 'title', 'default', '$schema', etc.

        Supported fields:
        - type, description, enum, format
        - properties, required, items (for objects/arrays)
        """
        if not isinstance(schema, dict):
            return schema
        # Fields that Gemini supports
        allowed_fields = {
            "type",
            "description",
            "enum",
            "properties",
            "required",
            "items",
            "format",
        }
        cleaned = {}
        for key, value in schema.items():
            if key in allowed_fields:
                # Recursively clean nested schemas
                if key == "properties" and isinstance(value, dict):
                    cleaned[key] = {
                        prop_name: self._clean_schema_for_gemini(prop_schema)
                        for prop_name, prop_schema in value.items()
                    }
                elif key == "items" and isinstance(value, dict):
                    cleaned[key] = self._clean_schema_for_gemini(value)
                else:
                    cleaned[key] = value
        return cleaned
================================================
FILE: src/vanna/integrations/hive/__init__.py
================================================
"""Hive integration for Vanna."""
from .sql_runner import HiveRunner
__all__ = ["HiveRunner"]
================================================
FILE: src/vanna/integrations/hive/sql_runner.py
================================================
"""Hive implementation of SqlRunner interface."""
from typing import Optional
import pandas as pd
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
from vanna.core.tool import ToolContext
class HiveRunner(SqlRunner):
    """Hive implementation of the SqlRunner interface.

    A new pyhive connection is opened for every query and closed afterwards.
    """

    def __init__(
        self,
        host: str,
        database: str = "default",
        user: Optional[str] = None,
        password: Optional[str] = None,
        port: int = 10000,
        auth: str = "CUSTOM",
        **kwargs,
    ):
        """Initialize with Hive connection parameters.

        Args:
            host: The host of the Hive database
            database: The name of the database to connect to (default: 'default')
            user: The username to use for authentication
            password: The password to use for authentication
            port: The port to use for the connection (default: 10000)
            auth: The authentication method to use (default: 'CUSTOM')
            **kwargs: Additional pyhive connection parameters

        Raises:
            ImportError: If the ``pyhive`` package is not installed.
        """
        try:
            from pyhive import hive

            self.hive = hive
        except ImportError as e:
            raise ImportError(
                "pyhive package is required. Install with: pip install pyhive"
            ) from e
        self.host = host
        self.database = database
        self.user = user
        self.password = password
        self.port = port
        self.auth = auth
        self.kwargs = kwargs

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute SQL query against Hive database and return results as DataFrame.

        Args:
            args: SQL query arguments
            context: Tool execution context

        Returns:
            DataFrame with query results. Statements that return no rows
            (``cursor.description`` is None per DB-API 2.0) yield an empty
            DataFrame.

        Raises:
            hive.Error: If query execution fails
        """
        # Connect to the database
        conn = self.hive.Connection(
            host=self.host,
            username=self.user,
            password=self.password,
            database=self.database,
            port=self.port,
            auth=self.auth,
            **self.kwargs,
        )
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(args.sql)
                # No result set (e.g. DDL): nothing to fetch, and iterating
                # a None description would raise TypeError.
                if cursor.description is None:
                    return pd.DataFrame()
                results = cursor.fetchall()
                columns = [desc[0] for desc in cursor.description]
                # Create a pandas dataframe from the results
                return pd.DataFrame(results, columns=columns)
            finally:
                # Close the cursor even when execute/fetch raises.
                cursor.close()
        finally:
            conn.close()
================================================
FILE: src/vanna/integrations/local/__init__.py
================================================
"""
Local integration.
This module provides built-in local implementations.
"""
from .audit import LoggingAuditLogger
from .file_system import LocalFileSystem
from .storage import MemoryConversationStore
from .file_system_conversation_store import FileSystemConversationStore
__all__ = [
"MemoryConversationStore",
"FileSystemConversationStore",
"LocalFileSystem",
"LoggingAuditLogger",
]
================================================
FILE: src/vanna/integrations/local/agent_memory/__init__.py
================================================
"""
Local agent memory implementations.
"""
from .in_memory import DemoAgentMemory
__all__ = ["DemoAgentMemory"]
================================================
FILE: src/vanna/integrations/local/agent_memory/in_memory.py
================================================
"""
Demo in-memory implementation of AgentMemory.
This implementation provides a zero-dependency, minimal storage solution that
keeps all memories in RAM. It uses simple similarity algorithms (Jaccard and
difflib) instead of vector embeddings. Perfect for demos and testing.
"""
from __future__ import annotations
import asyncio
import difflib
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
from vanna.capabilities.agent_memory import (
AgentMemory,
TextMemory,
TextMemorySearchResult,
ToolMemory,
ToolMemorySearchResult,
)
from vanna.core.tool import ToolContext
class DemoAgentMemory(AgentMemory):
    """
    Minimal, dependency-free in-memory storage for demos and testing.

    - O(n) search over an in-memory list
    - Simple similarity: max(Jaccard(token sets), difflib ratio)
    - Optional FIFO eviction via max_items
    - Async-safe with an asyncio.Lock
    """

    def __init__(self, *, max_items: int = 10_000):
        """
        Initialize the in-memory storage.

        Args:
            max_items: Maximum number of memories to keep (applied separately
                to tool memories and text memories). Oldest memories are
                evicted when this limit is exceeded (FIFO).
        """
        self._memories: List[ToolMemory] = []
        self._text_memories: List[TextMemory] = []
        self._lock = asyncio.Lock()
        self._max_items = max_items

    @staticmethod
    def _now_iso() -> str:
        """Get current timestamp in ISO format."""
        return datetime.now().isoformat()

    @staticmethod
    def _normalize(text: str) -> str:
        """Normalize text by lowercasing and collapsing whitespace."""
        return " ".join(text.lower().split())

    @staticmethod
    def _tokenize(text: str) -> set[str]:
        """Simple tokenizer that splits on whitespace."""
        return set(text.lower().split())

    @classmethod
    def _similarity(cls, a: str, b: str) -> float:
        """
        Calculate similarity between two strings using multiple methods.

        Returns the maximum of Jaccard similarity and difflib ratio.
        """
        a_norm, b_norm = cls._normalize(a), cls._normalize(b)
        # Jaccard over whitespace tokens
        ta, tb = cls._tokenize(a_norm), cls._tokenize(b_norm)
        if not ta and not tb:
            jaccard = 1.0
        elif not ta or not tb:
            jaccard = 0.0
        else:
            jaccard = len(ta & tb) / max(1, len(ta | tb))
        # difflib ratio
        ratio = difflib.SequenceMatcher(None, a_norm, b_norm).ratio()
        # Take the better of the two cheap measures
        return max(jaccard, ratio)

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save a tool usage pattern for future reference."""
        tm = ToolMemory(
            memory_id=str(uuid.uuid4()),
            question=question,
            tool_name=tool_name,
            args=args,
            timestamp=self._now_iso(),
            success=success,
            metadata=metadata or {},
        )
        async with self._lock:
            self._memories.append(tm)
            # Optional FIFO eviction
            if len(self._memories) > self._max_items:
                overflow = len(self._memories) - self._max_items
                del self._memories[:overflow]

    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
        """Store a text memory in RAM."""
        tm = TextMemory(
            memory_id=str(uuid.uuid4()), content=content, timestamp=self._now_iso()
        )
        async with self._lock:
            self._text_memories.append(tm)
            if len(self._text_memories) > self._max_items:
                overflow = len(self._text_memories) - self._max_items
                del self._text_memories[:overflow]
        return tm

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Search for similar tool usage patterns based on a question."""
        q = self._normalize(question)
        async with self._lock:
            # Filter candidates by tool name and success status
            candidates = [
                m
                for m in self._memories
                if m.success
                and (tool_name_filter is None or m.tool_name == tool_name_filter)
            ]
            # Score each candidate by question similarity
            results: List[tuple[ToolMemory, float]] = []
            for m in candidates:
                score = self._similarity(q, m.question)
                results.append((m, min(score, 1.0)))
            # Filter by threshold and sort by score
            results = [(m, s) for (m, s) in results if s >= similarity_threshold]
            results.sort(key=lambda x: x[1], reverse=True)
            # Build ranked response
            out: List[ToolMemorySearchResult] = []
            for idx, (m, s) in enumerate(results[:limit], start=1):
                out.append(
                    ToolMemorySearchResult(memory=m, similarity_score=s, rank=idx)
                )
            return out

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Search free-form text memories using the demo similarity metric."""
        normalized_query = self._normalize(query)
        async with self._lock:
            scored: List[tuple[TextMemory, float]] = []
            for memory in self._text_memories:
                score = self._similarity(normalized_query, memory.content)
                scored.append((memory, min(score, 1.0)))
            scored = [
                (memory, score)
                for memory, score in scored
                if score >= similarity_threshold
            ]
            scored.sort(key=lambda item: item[1], reverse=True)
            results: List[TextMemorySearchResult] = []
            for idx, (memory, score) in enumerate(scored[:limit], start=1):
                results.append(
                    TextMemorySearchResult(
                        memory=memory, similarity_score=score, rank=idx
                    )
                )
            return results

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Get up to ``limit`` recently added memories, most recent first.

        A non-positive ``limit`` returns an empty list.
        """
        # Guard needed: self._memories[-0:] is the WHOLE list, not an empty one.
        if limit <= 0:
            return []
        async with self._lock:
            return list(reversed(self._memories[-limit:]))

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Return up to ``limit`` recently added text memories, most recent first.

        A non-positive ``limit`` returns an empty list.
        """
        # Guard needed: self._text_memories[-0:] is the WHOLE list.
        if limit <= 0:
            return []
        async with self._lock:
            return list(reversed(self._text_memories[-limit:]))

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a stored text memory by ID."""
        async with self._lock:
            for index, memory in enumerate(self._text_memories):
                if memory.memory_id == memory_id:
                    del self._text_memories[index]
                    return True
            return False

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a memory by its ID. Returns True if deleted, False if not found."""
        async with self._lock:
            for i, m in enumerate(self._memories):
                if m.memory_id == memory_id:
                    del self._memories[i]
                    return True
            return False

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Clear stored memories. Returns number of memories deleted.

        Tool memories are deleted when they match ``tool_name`` (if given) and
        are not newer than ``before_date`` (if given; compared as ISO strings).
        Text memories are only deleted when ``tool_name`` is None, subject to
        the same ``before_date`` rule.
        """
        async with self._lock:
            original_tool_count = len(self._memories)
            original_text_count = len(self._text_memories)

            # Keep a tool memory if it targets another tool, or is at/after
            # the cutoff date.
            self._memories = [
                m
                for m in self._memories
                if (tool_name and m.tool_name != tool_name)
                or (before_date and m.timestamp and m.timestamp >= before_date)
            ]
            deleted_tool_count = original_tool_count - len(self._memories)

            # Text memories have no tool: keep them all when a tool is
            # targeted; otherwise keep only those at/after the cutoff date.
            self._text_memories = [
                memory
                for memory in self._text_memories
                if tool_name is not None
                or (
                    before_date
                    and memory.timestamp
                    and memory.timestamp >= before_date
                )
            ]
            deleted_text_count = original_text_count - len(self._text_memories)

            return deleted_tool_count + deleted_text_count
================================================
FILE: src/vanna/integrations/local/audit.py
================================================
"""
Local audit logger implementation using Python logging.
This module provides a simple audit logger that writes events using
the standard Python logging module, useful for development and testing.
"""
import json
import logging
from typing import Optional
from vanna.core.audit import AuditEvent, AuditLogger
logger = logging.getLogger(__name__)


class LoggingAuditLogger(AuditLogger):
    """Audit logger that emits each audit event as one compact JSON log line.

    Records go through this module's ``logger`` at a configurable level and
    are prefixed with ``[AUDIT] <event_type> |`` so log pipelines can pick
    them out and parse the JSON payload that follows.

    Example:
        audit_logger = LoggingAuditLogger()
        agent = Agent(
            llm_service=...,
            audit_logger=audit_logger
        )
    """

    def __init__(self, log_level: int = logging.INFO):
        """Create the audit logger.

        Args:
            log_level: Logging level used for every audit record
                (defaults to ``logging.INFO``).
        """
        self.log_level = log_level

    async def log_event(self, event: AuditEvent) -> None:
        """Serialize ``event`` to single-line JSON and log it.

        Fields whose value is ``None`` are omitted from the payload. Any error
        raised while serializing or logging is reported at ERROR level with a
        traceback instead of propagating, so auditing never breaks the caller.

        Args:
            event: The audit event to record.
        """
        try:
            compact_json = json.dumps(
                event.model_dump(mode="json", exclude_none=True),
                separators=(",", ":"),
            )
            record_text = f"[AUDIT] {event.event_type.value} | {compact_json}"
            logger.log(self.log_level, record_text)
        except Exception as exc:
            logger.error(f"Failed to log audit event: {exc}", exc_info=True)
================================================
FILE: src/vanna/integrations/local/file_system.py
================================================
"""
Local file system implementation.
This module provides a local file system implementation with per-user isolation.
"""
import asyncio
import hashlib
from pathlib import Path
from typing import List, Optional
from vanna.capabilities.file_system import CommandResult, FileSearchMatch, FileSystem
from vanna.core.tool import ToolContext
MAX_SEARCH_FILE_BYTES = 1_000_000


class LocalFileSystem(FileSystem):
    """Local file system implementation with per-user isolation.

    Every user gets a private directory under ``working_directory`` named
    after the first 16 hex characters of the SHA-256 hash of the user ID.
    All paths accepted by the public methods are relative to that directory,
    and paths that resolve outside of it are rejected with ``PermissionError``.
    """

    def __init__(self, working_directory: str = "."):
        """Initialize with a working directory.

        Args:
            working_directory: Base directory where user-specific folders will
                be created. Nothing is created here; user folders (and any
                missing parents) are created lazily on first access.
        """
        self.working_directory = Path(working_directory)

    def _get_user_directory(self, context: ToolContext) -> Path:
        """Return the directory reserved for ``context.user``, creating it if needed.

        Args:
            context: Tool context containing user information.

        Returns:
            Path to the user-specific directory.
        """
        # Hash the user ID so arbitrary IDs map to safe, fixed-length names.
        user_hash = hashlib.sha256(context.user.id.encode()).hexdigest()[:16]
        user_dir = self.working_directory / user_hash
        user_dir.mkdir(parents=True, exist_ok=True)
        return user_dir

    def _resolve_path(self, path: str, context: ToolContext) -> Path:
        """Join ``path`` onto the user's directory and check it stays inside.

        Args:
            path: Path relative to the user directory.
            context: Tool context containing user information.

        Returns:
            The joined (not resolved) path inside the user's directory.

        Raises:
            PermissionError: If the path, after resolving ``..`` components,
                absolute paths and symlinks, points outside the user directory.
        """
        user_dir = self._get_user_directory(context)
        candidate = user_dir / path
        try:
            candidate.resolve().relative_to(user_dir.resolve())
        except ValueError:
            # The internal ValueError adds nothing for callers; don't chain it.
            raise PermissionError(
                f"Access denied: path '{path}' is outside user directory"
            ) from None
        return candidate

    async def list_files(self, directory: str, context: ToolContext) -> List[str]:
        """List the names of regular files directly inside ``directory``.

        Returns:
            File names (not paths), sorted; sub-directories are omitted.

        Raises:
            PermissionError: If ``directory`` escapes the user's directory.
            FileNotFoundError: If ``directory`` does not exist.
            NotADirectoryError: If ``directory`` is not a directory.
        """
        directory_path = self._resolve_path(directory, context)
        if not directory_path.exists():
            raise FileNotFoundError(f"Directory '{directory}' does not exist")
        if not directory_path.is_dir():
            raise NotADirectoryError(f"'{directory}' is not a directory")
        return sorted(
            item.name for item in directory_path.iterdir() if item.is_file()
        )

    async def read_file(self, filename: str, context: ToolContext) -> str:
        """Return the UTF-8 text of ``filename`` within the user's directory.

        Raises:
            PermissionError: If ``filename`` escapes the user's directory.
            FileNotFoundError: If the file does not exist.
            IsADirectoryError: If the path exists but is not a regular file.
        """
        file_path = self._resolve_path(filename, context)
        if not file_path.exists():
            raise FileNotFoundError(f"File '{filename}' does not exist")
        if not file_path.is_file():
            raise IsADirectoryError(f"'{filename}' is a directory, not a file")
        return file_path.read_text(encoding="utf-8")

    async def write_file(
        self, filename: str, content: str, context: ToolContext, overwrite: bool = False
    ) -> None:
        """Write ``content`` as UTF-8 to ``filename`` within the user's directory.

        Missing parent directories are created.

        Raises:
            PermissionError: If ``filename`` escapes the user's directory.
            FileExistsError: If the file exists and ``overwrite`` is False.
        """
        file_path = self._resolve_path(filename, context)
        if file_path.exists() and not overwrite:
            raise FileExistsError(
                f"File '{filename}' already exists. Use overwrite=True to replace it."
            )
        # Only create parent directories once we know the write will proceed.
        file_path.parent.mkdir(parents=True, exist_ok=True)
        file_path.write_text(content, encoding="utf-8")

    async def exists(self, path: str, context: ToolContext) -> bool:
        """Return True if ``path`` exists inside the user's directory.

        Paths that escape the user's directory are reported as not existing.
        """
        try:
            return self._resolve_path(path, context).exists()
        except PermissionError:
            return False

    async def is_directory(self, path: str, context: ToolContext) -> bool:
        """Return True if ``path`` is an existing directory inside the user's space.

        Paths that escape the user's directory are reported as False.
        """
        try:
            resolved_path = self._resolve_path(path, context)
            return resolved_path.exists() and resolved_path.is_dir()
        except PermissionError:
            return False

    @staticmethod
    def _read_searchable_text(path: Path) -> Optional[str]:
        """Return the UTF-8 text of ``path``, or None if too large or unreadable."""
        try:
            if path.stat().st_size > MAX_SEARCH_FILE_BYTES:
                return None
            return path.read_text(encoding="utf-8")
        except (UnicodeDecodeError, OSError):
            return None

    @staticmethod
    def _make_snippet(
        content: str, index: int, match_length: int, context_window: int = 60
    ) -> str:
        """Return a one-line excerpt of ``content`` around a match.

        ``index``/``match_length`` locate the match; up to ``context_window``
        characters are kept on each side and an ellipsis marks truncation.
        NOTE: offsets come from the lower-cased text, so they can be slightly
        off for characters whose lowercase form has a different length.
        """
        start = max(0, index - context_window)
        end = min(len(content), index + match_length + context_window)
        snippet = content[start:end].replace("\n", " ").strip()
        if start > 0:
            snippet = f"…{snippet}"
        if end < len(content):
            snippet = f"{snippet}…"
        return snippet

    async def search_files(
        self,
        query: str,
        context: ToolContext,
        *,
        max_results: int = 20,
        include_content: bool = False,
    ) -> List[FileSearchMatch]:
        """Search the user's files by name and, optionally, by content.

        Matching is case-insensitive on the whitespace-trimmed query. A file
        matches if its name contains the query or, when ``include_content`` is
        true, its UTF-8 text does. Content matches get a snippet around the
        first occurrence; name-only matches get ``"[filename match]"``. Files
        larger than ``MAX_SEARCH_FILE_BYTES`` or not decodable as UTF-8 can
        only match by name.

        Returns:
            Up to ``max_results`` matches with POSIX-style relative paths.

        Raises:
            ValueError: If ``query`` is empty or whitespace-only.
        """
        trimmed_query = query.strip()
        if not trimmed_query:
            raise ValueError("Search query must not be empty")

        user_dir = self._get_user_directory(context)
        query_lower = trimmed_query.lower()
        matches: List[FileSearchMatch] = []

        for path in user_dir.rglob("*"):
            if len(matches) >= max_results:
                break
            if not path.is_file():
                continue

            include_entry = False
            snippet: Optional[str] = None
            if query_lower in path.name.lower():
                include_entry = True
                snippet = "[filename match]"

            if include_content:
                content = self._read_searchable_text(path)
                if content is not None:
                    index = content.lower().find(query_lower)
                    if index != -1:
                        # Size the snippet by the query actually searched for
                        # (trimmed), not the raw input with its whitespace.
                        snippet = self._make_snippet(content, index, len(query_lower))
                        include_entry = True

            if include_entry:
                matches.append(
                    FileSearchMatch(
                        path=path.relative_to(user_dir).as_posix(), snippet=snippet
                    )
                )

        return matches

    async def run_bash(
        self,
        command: str,
        context: ToolContext,
        *,
        timeout: Optional[float] = None,
    ) -> CommandResult:
        """Run ``command`` through the system shell with the user's directory as cwd.

        NOTE: only the working directory is scoped to the user; nothing here
        restricts which paths the shell command itself can touch.

        Returns:
            Decoded stdout/stderr (invalid UTF-8 replaced) and the exit code.

        Raises:
            ValueError: If ``command`` is empty or whitespace-only.
            TimeoutError: If the command exceeds ``timeout`` seconds; the
                process is killed and reaped first.
        """
        if not command.strip():
            raise ValueError("Command must not be empty")
        user_dir = self._get_user_directory(context)
        process = await asyncio.create_subprocess_shell(
            command,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=str(user_dir),
        )
        try:
            stdout_bytes, stderr_bytes = await asyncio.wait_for(
                process.communicate(), timeout=timeout
            )
        except asyncio.TimeoutError as exc:
            try:
                process.kill()
            except ProcessLookupError:
                # The process exited between the timeout and the kill.
                pass
            await process.wait()
            raise TimeoutError(f"Command timed out after {timeout} seconds") from exc
        return CommandResult(
            stdout=stdout_bytes.decode("utf-8", errors="replace"),
            stderr=stderr_bytes.decode("utf-8", errors="replace"),
            returncode=process.returncode or 0,
        )
================================================
FILE: src/vanna/integrations/local/file_system_conversation_store.py
================================================
"""
File system conversation store implementation.
This module provides a file-based implementation of the ConversationStore
interface that persists conversations to disk as a directory structure.
"""
import json
import os
from pathlib import Path
from typing import Dict, List, Optional
from datetime import datetime
import time
from vanna.core.storage import ConversationStore, Conversation, Message
from vanna.core.user import User
class FileSystemConversationStore(ConversationStore):
    """File system-based conversation store.

    Stores conversations as directories with individual message files:
        conversations/{conversation_id}/
            metadata.json - conversation metadata (id, user info, timestamps)
            messages/
                {timestamp}_{index}.json - individual message files

    ``conversation_id`` is used verbatim as a directory name. It must therefore
    be a single path component: IDs containing path separators (or equal to
    ``.``/``..``) are rejected so reads and writes stay inside ``base_dir``.
    """

    def __init__(self, base_dir: str = "conversations") -> None:
        """Initialize the store, creating ``base_dir`` if it does not exist.

        Args:
            base_dir: Base directory for storing conversations.
        """
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _is_safe_conversation_id(conversation_id: str) -> bool:
        """Return True if ``conversation_id`` is one non-special path component."""
        return (
            bool(conversation_id)
            and conversation_id not in (".", "..")
            and Path(conversation_id).name == conversation_id
        )

    def _get_conversation_dir(self, conversation_id: str) -> Path:
        """Get the directory path for a conversation.

        Raises:
            ValueError: If ``conversation_id`` could escape ``base_dir``.
        """
        if not self._is_safe_conversation_id(conversation_id):
            raise ValueError(f"Invalid conversation id: {conversation_id!r}")
        return self.base_dir / conversation_id

    def _get_metadata_path(self, conversation_id: str) -> Path:
        """Get the metadata file path for a conversation."""
        return self._get_conversation_dir(conversation_id) / "metadata.json"

    def _get_messages_dir(self, conversation_id: str) -> Path:
        """Get the messages directory for a conversation."""
        return self._get_conversation_dir(conversation_id) / "messages"

    @staticmethod
    def _message_sort_key(file_path: Path) -> tuple:
        """Order message files by their ``_{index}`` suffix, then by name.

        The index is the message's position in the conversation, so ordering
        by it is immune to the wall clock moving backwards between writes.
        Files that do not follow the naming pattern sort last, by name.
        """
        _, _, index = file_path.stem.rpartition("_")
        if index.isascii() and index.isdigit():
            return (0, int(index), file_path.name)
        return (1, 0, file_path.name)

    def _save_metadata(self, conversation: Conversation) -> None:
        """Write ``metadata.json`` (id, user, timestamps) for ``conversation``."""
        conv_dir = self._get_conversation_dir(conversation.id)
        conv_dir.mkdir(parents=True, exist_ok=True)
        metadata = {
            "id": conversation.id,
            "user": conversation.user.model_dump(mode="json"),
            "created_at": conversation.created_at.isoformat(),
            "updated_at": conversation.updated_at.isoformat(),
        }
        metadata_path = self._get_metadata_path(conversation.id)
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

    def _load_messages(self, conversation_id: str) -> List[Message]:
        """Load all messages for a conversation in conversation order.

        Message files that cannot be parsed or validated are reported and
        skipped rather than failing the whole load.
        """
        messages_dir = self._get_messages_dir(conversation_id)
        if not messages_dir.exists():
            return []
        messages = []
        for file_path in sorted(
            messages_dir.glob("*.json"), key=self._message_sort_key
        ):
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                messages.append(Message.model_validate(data))
            except (json.JSONDecodeError, ValueError) as e:
                print(f"Failed to load message from {file_path}: {e}")
                continue
        return messages

    def _append_message(
        self, conversation_id: str, message: Message, index: int
    ) -> None:
        """Write ``message`` as a new file tagged with its position ``index``."""
        messages_dir = self._get_messages_dir(conversation_id)
        messages_dir.mkdir(parents=True, exist_ok=True)
        # Timestamp prefix keeps filenames unique; the index fixes the order.
        timestamp = int(time.time() * 1000000)  # microseconds
        filename = f"{timestamp}_{index:06d}.json"
        with open(messages_dir / filename, "w", encoding="utf-8") as f:
            json.dump(message.model_dump(mode="json"), f, indent=2)

    def _build_conversation(
        self, metadata: Dict, conversation_id: str
    ) -> Conversation:
        """Rebuild a Conversation from parsed metadata plus its stored messages.

        Raises:
            KeyError: If a required metadata key is missing.
            ValueError: If the user fails validation or a timestamp is not ISO
                formatted.
        """
        messages = self._load_messages(conversation_id)
        return Conversation(
            id=metadata["id"],
            user=User.model_validate(metadata["user"]),
            messages=messages,
            created_at=datetime.fromisoformat(metadata["created_at"]),
            updated_at=datetime.fromisoformat(metadata["updated_at"]),
        )

    async def create_conversation(
        self, conversation_id: str, user: User, initial_message: str
    ) -> Conversation:
        """Create and persist a new conversation with the specified ID.

        Raises:
            ValueError: If ``conversation_id`` is not a safe directory name.
        """
        conversation = Conversation(
            id=conversation_id,
            user=user,
            messages=[Message(role="user", content=initial_message)],
        )
        self._save_metadata(conversation)
        self._append_message(conversation_id, conversation.messages[0], 0)
        return conversation

    async def get_conversation(
        self, conversation_id: str, user: User
    ) -> Optional[Conversation]:
        """Get a conversation by ID, or None if missing, unreadable or not owned by ``user``."""
        if not self._is_safe_conversation_id(conversation_id):
            return None
        metadata_path = self._get_metadata_path(conversation_id)
        if not metadata_path.exists():
            return None
        try:
            with open(metadata_path, "r", encoding="utf-8") as f:
                metadata = json.load(f)
            # Verify ownership before loading any messages.
            if metadata["user"]["id"] != user.id:
                return None
            return self._build_conversation(metadata, conversation_id)
        except (json.JSONDecodeError, ValueError, KeyError) as e:
            print(f"Failed to load conversation {conversation_id}: {e}")
            return None

    async def update_conversation(self, conversation: Conversation) -> None:
        """Persist new messages and refresh ``updated_at``.

        Messages are append-only: the number of messages currently loadable
        from disk is taken as the count already saved, and only messages past
        that position are written.

        Raises:
            ValueError: If ``conversation.id`` is not a safe directory name.
        """
        conversation.updated_at = datetime.now()
        self._save_metadata(conversation)
        existing_count = len(self._load_messages(conversation.id))
        for i, message in enumerate(
            conversation.messages[existing_count:], start=existing_count
        ):
            self._append_message(conversation.id, message, i)

    async def delete_conversation(self, conversation_id: str, user: User) -> bool:
        """Delete a conversation owned by ``user``.

        The whole conversation directory is removed, including any unexpected
        files in it, so a stray file cannot leave a half-deleted conversation.

        Returns:
            True if deleted; False if the ID is invalid, the conversation does
            not exist or is not owned by ``user``, or deletion failed.
        """
        import shutil

        if not self._is_safe_conversation_id(conversation_id):
            return False
        conv_dir = self._get_conversation_dir(conversation_id)
        if not conv_dir.exists():
            return False
        # Verify ownership before deleting.
        if not await self.get_conversation(conversation_id, user):
            return False
        try:
            shutil.rmtree(conv_dir)
            return True
        except OSError as e:
            print(f"Failed to delete conversation {conversation_id}: {e}")
            return False

    async def list_conversations(
        self, user: User, limit: int = 50, offset: int = 0
    ) -> List[Conversation]:
        """List ``user``'s conversations, most recently updated first, paginated."""
        if not self.base_dir.exists():
            return []
        conversations = []
        for conv_dir in self.base_dir.iterdir():
            if not conv_dir.is_dir():
                continue
            metadata_path = conv_dir / "metadata.json"
            if not metadata_path.exists():
                continue
            try:
                with open(metadata_path, "r", encoding="utf-8") as f:
                    metadata = json.load(f)
                # Skip conversations not owned by this user.
                if metadata["user"]["id"] != user.id:
                    continue
                conversations.append(self._build_conversation(metadata, conv_dir.name))
            except (json.JSONDecodeError, ValueError, KeyError) as e:
                print(f"Failed to load conversation from {conv_dir}: {e}")
                continue
        conversations.sort(key=lambda x: x.updated_at, reverse=True)
        return conversations[offset : offset + limit]
================================================
FILE: src/vanna/integrations/local/storage.py
================================================
"""
In-memory conversation store implementation.
This module provides a simple in-memory implementation of the ConversationStore
interface, useful for testing and development.
"""
from typing import Dict, List, Optional
from vanna.core.storage import ConversationStore, Conversation, Message
from vanna.core.user import User
class MemoryConversationStore(ConversationStore):
    """In-memory conversation store.

    Conversations are kept in a process-local dict keyed by conversation ID
    and are lost when the process exits. Stored objects are returned by
    reference, so callers holding a conversation share it with the store.
    """

    def __init__(self) -> None:
        self._conversations: Dict[str, Conversation] = {}

    async def create_conversation(
        self, conversation_id: str, user: User, initial_message: str
    ) -> Conversation:
        """Create a new conversation with the specified ID.

        An existing conversation with the same ID is replaced.
        """
        conversation = Conversation(
            id=conversation_id,
            user=user,
            messages=[Message(role="user", content=initial_message)],
        )
        self._conversations[conversation_id] = conversation
        return conversation

    async def get_conversation(
        self, conversation_id: str, user: User
    ) -> Optional[Conversation]:
        """Return the conversation if it exists and belongs to ``user``, else None."""
        conversation = self._conversations.get(conversation_id)
        if conversation and conversation.user.id == user.id:
            return conversation
        return None

    async def update_conversation(self, conversation: Conversation) -> None:
        """Store ``conversation`` and refresh its ``updated_at`` timestamp.

        Refreshing ``updated_at`` keeps the most-recently-updated-first order
        of ``list_conversations`` meaningful; the file-system store does the
        same on update.
        """
        from datetime import datetime

        conversation.updated_at = datetime.now()
        self._conversations[conversation.id] = conversation

    async def delete_conversation(self, conversation_id: str, user: User) -> bool:
        """Delete a conversation owned by ``user``; return whether one was removed."""
        if await self.get_conversation(conversation_id, user):
            del self._conversations[conversation_id]
            return True
        return False

    async def list_conversations(
        self, user: User, limit: int = 50, offset: int = 0
    ) -> List[Conversation]:
        """List ``user``'s conversations, most recently updated first, paginated."""
        user_conversations = [
            conv for conv in self._conversations.values() if conv.user.id == user.id
        ]
        user_conversations.sort(key=lambda x: x.updated_at, reverse=True)
        return user_conversations[offset : offset + limit]
================================================
FILE: src/vanna/integrations/marqo/__init__.py
================================================
"""
Marqo integration for Vanna Agents.
"""
from .agent_memory import MarqoAgentMemory
__all__ = ["MarqoAgentMemory"]
================================================
FILE: src/vanna/integrations/marqo/agent_memory.py
================================================
"""
Marqo vector database implementation of AgentMemory.
This implementation uses Marqo for vector storage of tool usage patterns.
"""
import json
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
import asyncio
from concurrent.futures import ThreadPoolExecutor
try:
import marqo
MARQO_AVAILABLE = True
except ImportError:
MARQO_AVAILABLE = False
from vanna.capabilities.agent_memory import (
AgentMemory,
TextMemory,
TextMemorySearchResult,
ToolMemory,
ToolMemorySearchResult,
)
from vanna.core.tool import ToolContext
class MarqoAgentMemory(AgentMemory):
    """Marqo-based implementation of AgentMemory.

    Tool-usage memories and text memories share one Marqo index:

    * tool memories use ``question`` as the tensor field and also store
      ``tool_name``, ``success``, ``timestamp`` and JSON-encoded
      ``args``/``metadata``;
    * text memories use ``content`` as the tensor field and carry
      ``is_text_memory: True``.

    The Marqo client calls are blocking, so every operation runs on a private
    two-worker thread pool.
    """

    def __init__(
        self,
        url: str = "http://localhost:8882",
        index_name: str = "tool-memories",
        api_key: Optional[str] = None,
    ):
        """Configure the Marqo connection; client and index are created lazily.

        Raises:
            ImportError: If the ``marqo`` package is not installed.
        """
        if not MARQO_AVAILABLE:
            raise ImportError(
                "Marqo is required for MarqoAgentMemory. Install with: pip install marqo"
            )
        self.url = url
        self.index_name = index_name
        self.api_key = api_key
        self._client = None
        self._executor = ThreadPoolExecutor(max_workers=2)

    def _get_client(self):
        """Return the cached Marqo client, creating it and the index on first use."""
        if self._client is None:
            client = marqo.Client(url=self.url, api_key=self.api_key)
            # Any failure to fetch the index is treated as "index missing".
            try:
                client.get_index(self.index_name)
            except Exception:
                client.create_index(self.index_name)
            # Cache only after the index is ensured so failures get retried.
            self._client = client
        return self._client

    async def _run_blocking(self, func):
        """Run the blocking callable ``func`` on this instance's thread pool."""
        return await asyncio.get_running_loop().run_in_executor(self._executor, func)

    @staticmethod
    def _is_tool_memory_hit(hit: Dict[str, Any]) -> bool:
        """Return True for hits shaped like tool memories (not text memories)."""
        return (
            not hit.get("is_text_memory", False)
            and "question" in hit
            and "tool_name" in hit
        )

    @staticmethod
    def _hit_to_tool_memory(hit: Dict[str, Any]) -> ToolMemory:
        """Convert a Marqo hit for a tool memory into a ``ToolMemory``."""
        args = json.loads(hit.get("args", "{}"))
        metadata_dict = json.loads(hit.get("metadata", "{}"))
        return ToolMemory(
            memory_id=hit["_id"],
            question=hit["question"],
            tool_name=hit["tool_name"],
            args=args,
            timestamp=hit.get("timestamp"),
            success=hit.get("success", True),
            metadata=metadata_dict,
        )

    @staticmethod
    def _hit_to_text_memory(hit: Dict[str, Any]) -> TextMemory:
        """Convert a Marqo hit for a text memory into a ``TextMemory``."""
        return TextMemory(
            memory_id=hit["_id"],
            content=hit.get("content", ""),
            timestamp=hit.get("timestamp"),
        )

    def _delete_document(self, memory_id: str) -> bool:
        """Delete one document by ID; return False if the Marqo call raised."""
        client = self._get_client()
        try:
            client.index(self.index_name).delete_documents(ids=[memory_id])
            return True
        except Exception:
            return False

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save a tool usage pattern, indexing ``question`` for similarity search."""

        def _save():
            client = self._get_client()
            document = {
                "_id": str(uuid.uuid4()),
                "question": question,
                "tool_name": tool_name,
                "args": json.dumps(args),
                "timestamp": datetime.now().isoformat(),
                "success": success,
                "metadata": json.dumps(metadata or {}),
            }
            client.index(self.index_name).add_documents(
                [document], tensor_fields=["question"]
            )

        await self._run_blocking(_save)

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Search successful tool usages similar to ``question``.

        Hits scoring below ``similarity_threshold`` are dropped; ``rank`` is
        the hit's 1-based position in Marqo's result list.
        """

        def _search():
            client = self._get_client()
            filter_string = "success:true"
            if tool_name_filter:
                filter_string += f" AND tool_name:{tool_name_filter}"
            results = client.index(self.index_name).search(
                q=question, limit=limit, filter_string=filter_string
            )
            search_results = []
            for i, hit in enumerate(results["hits"]):
                similarity_score = hit.get("_score", 0)
                if similarity_score >= similarity_threshold and self._is_tool_memory_hit(hit):
                    search_results.append(
                        ToolMemorySearchResult(
                            memory=self._hit_to_tool_memory(hit),
                            similarity_score=similarity_score,
                            rank=i + 1,
                        )
                    )
            return search_results

        return await self._run_blocking(_search)

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Get recently added tool memories.

        Text memories share the index; they are skipped rather than raising on
        their missing ``question``/``tool_name`` fields, so fewer than
        ``limit`` memories may be returned.
        """

        def _get_recent():
            client = self._get_client()
            # NOTE(review): confirm the marqo client honours q="*" and sort= as
            # a "list by recency" query.
            results = client.index(self.index_name).search(
                q="*", limit=limit, sort="timestamp:desc"
            )
            return [
                self._hit_to_tool_memory(hit)
                for hit in results.get("hits", [])
                if self._is_tool_memory_hit(hit)
            ]

        return await self._run_blocking(_get_recent)

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a memory by its ID; False if the Marqo call raised."""
        return await self._run_blocking(lambda: self._delete_document(memory_id))

    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
        """Save a text memory, indexing ``content`` for similarity search."""

        def _save():
            client = self._get_client()
            memory_id = str(uuid.uuid4())
            timestamp = datetime.now().isoformat()
            document = {
                "_id": memory_id,
                "content": content,
                "timestamp": timestamp,
                "is_text_memory": True,
            }
            client.index(self.index_name).add_documents(
                [document], tensor_fields=["content"]
            )
            return TextMemory(memory_id=memory_id, content=content, timestamp=timestamp)

        return await self._run_blocking(_save)

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Search text memories similar to ``query``, dropping low-scoring hits."""

        def _search():
            client = self._get_client()
            results = client.index(self.index_name).search(
                q=query, limit=limit, filter_string="is_text_memory:true"
            )
            search_results = []
            for i, hit in enumerate(results["hits"]):
                similarity_score = hit.get("_score", 0)
                if similarity_score >= similarity_threshold:
                    search_results.append(
                        TextMemorySearchResult(
                            memory=self._hit_to_text_memory(hit),
                            similarity_score=similarity_score,
                            rank=i + 1,
                        )
                    )
            return search_results

        return await self._run_blocking(_search)

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Get recently added text memories."""

        def _get_recent():
            client = self._get_client()
            results = client.index(self.index_name).search(
                q="*",
                limit=limit,
                filter_string="is_text_memory:true",
                sort="timestamp:desc",
            )
            return [self._hit_to_text_memory(hit) for hit in results.get("hits", [])]

        return await self._run_blocking(_get_recent)

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a text memory by its ID; False if the Marqo call raised."""
        return await self._run_blocking(lambda: self._delete_document(memory_id))

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Clear stored memories.

        * With a truthy ``tool_name`` and/or ``before_date``: up to 1000
          matching documents (one search page) are deleted and their count is
          returned.
        * With both arguments None: the whole index is dropped and recreated;
          returns 0 because the number of removed documents is not known.
        * Otherwise (e.g. empty-string filters): nothing is deleted; returns 0.
        """

        def _clear():
            client = self._get_client()
            filter_parts = []
            if tool_name:
                filter_parts.append(f"tool_name:{tool_name}")
            if before_date:
                filter_parts.append(f"timestamp:[* TO {before_date}]")

            if filter_parts:
                index = client.index(self.index_name)
                results = index.search(
                    q="*",
                    limit=1000,  # Max results
                    filter_string=" AND ".join(filter_parts),
                )
                ids_to_delete = [hit["_id"] for hit in results.get("hits", [])]
                if ids_to_delete:
                    index.delete_documents(ids=ids_to_delete)
                return len(ids_to_delete)

            if tool_name is None and before_date is None:
                try:
                    client.delete_index(self.index_name)
                    client.create_index(self.index_name)
                except Exception:
                    # Best effort. Drop the cached client so the next call
                    # re-checks and, if needed, recreates the index.
                    self._client = None
            return 0

        return await self._run_blocking(_clear)
================================================
FILE: src/vanna/integrations/milvus/__init__.py
================================================
"""
Milvus integration for Vanna Agents.
"""
from .agent_memory import MilvusAgentMemory
__all__ = ["MilvusAgentMemory"]
================================================
FILE: src/vanna/integrations/milvus/agent_memory.py
================================================
"""
Milvus vector database implementation of AgentMemory.
This implementation uses Milvus for distributed vector storage of tool usage patterns.
"""
import json
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
import asyncio
from concurrent.futures import ThreadPoolExecutor
try:
from pymilvus import (
connections,
Collection,
CollectionSchema,
FieldSchema,
DataType,
utility,
)
MILVUS_AVAILABLE = True
except ImportError:
MILVUS_AVAILABLE = False
from vanna.capabilities.agent_memory import (
AgentMemory,
TextMemory,
TextMemorySearchResult,
ToolMemory,
ToolMemorySearchResult,
)
from vanna.core.tool import ToolContext
class MilvusAgentMemory(AgentMemory):
"""Milvus-based implementation of AgentMemory."""
def __init__(
    self,
    collection_name: str = "tool_memories",
    host: str = "localhost",
    port: int = 19530,
    alias: str = "default",
    dimension: int = 384,
):
    """Store Milvus connection settings; the server is contacted lazily.

    Args:
        collection_name: Milvus collection holding the memories.
        host: Milvus server host.
        port: Milvus server port.
        alias: Connection alias passed to ``connections.connect``.
        dimension: Length of the embedding vectors in the collection.

    Raises:
        ImportError: If ``pymilvus`` could not be imported.
    """
    if not MILVUS_AVAILABLE:
        raise ImportError(
            "Milvus is required for MilvusAgentMemory. Install with: pip install pymilvus"
        )
    self.collection_name, self.host, self.port = collection_name, host, port
    self.alias = alias
    self.dimension = dimension
    # Populated on first use by _get_collection().
    self._collection = None
    self._executor = ThreadPoolExecutor(max_workers=2)
def _get_collection(self):
    """Return the loaded Milvus collection, creating it on first use.

    The first call connects with ``self.alias``/``host``/``port``. If
    ``self.collection_name`` does not exist yet, it creates the collection
    with the memory schema and an IVF_FLAT inner-product index. It then loads
    the collection. Later calls return the cached collection.
    """
    if self._collection is not None:
        return self._collection

    connections.connect(alias=self.alias, host=self.host, port=self.port)

    if not utility.has_collection(self.collection_name):
        # Field order matters: inserts pass their column lists in this order.
        schema = CollectionSchema(
            fields=[
                FieldSchema(
                    name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=100
                ),
                FieldSchema(
                    name="embedding", dtype=DataType.FLOAT_VECTOR, dim=self.dimension
                ),
                FieldSchema(name="question", dtype=DataType.VARCHAR, max_length=2000),
                FieldSchema(name="tool_name", dtype=DataType.VARCHAR, max_length=200),
                FieldSchema(name="args_json", dtype=DataType.VARCHAR, max_length=5000),
                FieldSchema(name="timestamp", dtype=DataType.VARCHAR, max_length=50),
                FieldSchema(name="success", dtype=DataType.BOOL),
                FieldSchema(
                    name="metadata_json", dtype=DataType.VARCHAR, max_length=5000
                ),
            ],
            description="Tool usage memories",
        )
        fresh = Collection(name=self.collection_name, schema=schema)
        fresh.create_index(
            field_name="embedding",
            index_params={
                "index_type": "IVF_FLAT",
                "metric_type": "IP",
                "params": {"nlist": 128},
            },
        )

    self._collection = Collection(self.collection_name)
    self._collection.load()
    return self._collection
def _create_embedding(self, text: str) -> List[float]:
"""Create a simple embedding from text (placeholder)."""
import hashlib
hash_val = int(hashlib.md5(text.encode()).hexdigest(), 16)
return [(hash_val >> i) % 100 / 100.0 for i in range(self.dimension)]
async def save_tool_usage(
    self,
    question: str,
    tool_name: str,
    args: Dict[str, Any],
    context: ToolContext,
    success: bool = True,
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Persist one tool invocation as a row in the Milvus collection.

    The question is embedded with ``_create_embedding``. ``args`` and
    ``metadata`` (default ``{}``) are stored as JSON strings. The insert is
    flushed before the call returns.
    """

    def _insert_row():
        collection = self._get_collection()
        row_id = str(uuid.uuid4())
        created = datetime.now().isoformat()
        vector = self._create_embedding(question)
        collection.insert(
            [
                [row_id],                       # id
                [vector],                       # embedding
                [question],                     # question
                [tool_name],                    # tool_name
                [json.dumps(args)],             # args_json
                [created],                      # timestamp
                [success],                      # success
                [json.dumps(metadata or {})],   # metadata_json
            ]
        )
        collection.flush()

    loop = asyncio.get_running_loop()
    await loop.run_in_executor(self._executor, _insert_row)
async def search_similar_usage(
    self,
    question: str,
    context: ToolContext,
    *,
    limit: int = 10,
    similarity_threshold: float = 0.7,
    tool_name_filter: Optional[str] = None,
) -> List[ToolMemorySearchResult]:
    """Find successful tool usages whose question embedding is close to ``question``.

    Text memories live in the same collection with an empty ``tool_name``
    and an empty ``args_json``. The filter excludes them so they are never
    decoded as tool usages (``json.loads("")`` would raise).

    Args:
        question: Query text, embedded with ``_create_embedding``.
        context: Tool context (not used here).
        limit: Maximum number of hits requested from Milvus.
        similarity_threshold: Minimum inner-product score a hit must reach.
        tool_name_filter: If given, restrict to this exact tool name.

    Returns:
        Matching memories in hit order with 1-based ``rank``.
    """

    def _search():
        collection = self._get_collection()
        embedding = self._create_embedding(question)

        expr = 'success == true && tool_name != ""'
        if tool_name_filter:
            # json.dumps yields a double-quoted literal with quotes and
            # backslashes escaped, so the name cannot alter the expression.
            expr += f" && tool_name == {json.dumps(tool_name_filter, ensure_ascii=False)}"

        results = collection.search(
            data=[embedding],
            anns_field="embedding",
            param={"metric_type": "IP", "params": {"nprobe": 10}},
            limit=limit,
            expr=expr,
            output_fields=[
                "id",
                "question",
                "tool_name",
                "args_json",
                "timestamp",
                "success",
                "metadata_json",
            ],
        )

        search_results = []
        for hits in results:
            for rank, hit in enumerate(hits, start=1):
                similarity_score = hit.distance
                if similarity_score >= similarity_threshold:
                    entity = hit.entity
                    memory = ToolMemory(
                        memory_id=entity.get("id"),
                        question=entity.get("question"),
                        tool_name=entity.get("tool_name"),
                        # `or "{}"` tolerates empty JSON columns.
                        args=json.loads(entity.get("args_json", "{}") or "{}"),
                        timestamp=entity.get("timestamp"),
                        success=entity.get("success", True),
                        metadata=json.loads(entity.get("metadata_json", "{}") or "{}"),
                    )
                    search_results.append(
                        ToolMemorySearchResult(
                            memory=memory,
                            similarity_score=similarity_score,
                            rank=rank,
                        )
                    )
        return search_results

    return await asyncio.get_running_loop().run_in_executor(self._executor, _search)
async def get_recent_memories(
    self, context: ToolContext, limit: int = 10
) -> List[ToolMemory]:
    """Return up to ``limit`` tool memories, newest timestamp first.

    Text memories share the collection but are stored with an empty
    ``tool_name`` and an empty ``args_json``. They are excluded by the query
    so their empty JSON is never parsed as tool arguments.
    """

    def _get_recent():
        collection = self._get_collection()
        rows = collection.query(
            expr='tool_name != ""',
            output_fields=[
                "id",
                "question",
                "tool_name",
                "args_json",
                "timestamp",
                "success",
                "metadata_json",
            ],
            limit=1000,
        )
        # NOTE(review): only the first 1000 rows returned by Milvus are
        # considered, so with more stored memories "recent" is approximate.
        newest_first = sorted(rows, key=lambda r: r.get("timestamp", ""), reverse=True)

        memories = []
        for row in newest_first[:limit]:
            memories.append(
                ToolMemory(
                    memory_id=row.get("id"),
                    question=row.get("question"),
                    tool_name=row.get("tool_name"),
                    args=json.loads(row.get("args_json", "{}") or "{}"),
                    timestamp=row.get("timestamp"),
                    success=row.get("success", True),
                    metadata=json.loads(row.get("metadata_json", "{}") or "{}"),
                )
            )
        return memories

    return await asyncio.get_running_loop().run_in_executor(
        self._executor, _get_recent
    )
async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
    """Delete the row whose primary key equals ``memory_id``.

    The ID is embedded in the Milvus filter as an escaped string literal. An
    ID containing quotes or backslashes therefore cannot rewrite the
    expression and widen the delete to other rows.

    Returns:
        True if Milvus accepted the delete, False if it raised.
        NOTE(review): True does not confirm that a matching row existed.
    """

    def _delete():
        collection = self._get_collection()
        try:
            # json.dumps gives a double-quoted literal with " and \ escaped;
            # for ordinary IDs it is identical to '"<id>"'.
            literal = json.dumps(str(memory_id), ensure_ascii=False)
            collection.delete(f"id == {literal}")
            return True
        except Exception:
            return False

    return await asyncio.get_running_loop().run_in_executor(self._executor, _delete)
async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
    """Persist a free-form text memory and return it.

    Text memories share the tool-memory collection: the text goes into the
    ``question`` column, ``tool_name`` and ``args_json`` are left empty,
    ``success`` is always ``True``, and ``metadata_json`` flags the row with
    ``is_text_memory``. The insert is flushed before returning.
    """

    def _save():
        coll = self._get_collection()
        new_id = str(uuid.uuid4())
        created_at = datetime.now().isoformat()
        vector = self._create_embedding(content)
        marker = json.dumps({"is_text_memory": True})
        # Column-oriented insert; one single-element list per field, in the
        # order id, embedding, question, tool_name, args_json, timestamp,
        # success, metadata_json.
        columns = [
            [new_id],
            [vector],
            [content],
            [""],
            [""],
            [created_at],
            [True],
            [marker],
        ]
        coll.insert(columns)
        coll.flush()
        return TextMemory(memory_id=new_id, content=content, timestamp=created_at)

    return await asyncio.get_event_loop().run_in_executor(self._executor, _save)
async def search_text_memories(
    self,
    query: str,
    context: ToolContext,
    *,
    limit: int = 10,
    similarity_threshold: float = 0.7,
) -> List[TextMemorySearchResult]:
    """Vector-search text memories that resemble ``query``.

    Only rows whose ``tool_name`` is empty (text memories) are searched, using
    the inner-product metric. Hits scoring below ``similarity_threshold`` are
    dropped; ``rank`` is a hit's 1-based position in the Milvus result list,
    counted before the threshold filter is applied.
    """

    def _search():
        coll = self._get_collection()
        query_vector = self._create_embedding(query)
        raw_results = coll.search(
            data=[query_vector],
            anns_field="embedding",
            param={"metric_type": "IP", "params": {"nprobe": 10}},
            limit=limit,
            expr='tool_name == ""',
            output_fields=["id", "question", "timestamp", "metadata_json"],
        )
        matches = []
        for hit_list in raw_results:
            for position, hit in enumerate(hit_list, start=1):
                score = hit.distance
                if score >= similarity_threshold:
                    text = hit.entity.get("question", "")
                    found = TextMemory(
                        memory_id=hit.entity.get("id"),
                        content=text,
                        timestamp=hit.entity.get("timestamp"),
                    )
                    matches.append(
                        TextMemorySearchResult(
                            memory=found,
                            similarity_score=score,
                            rank=position,
                        )
                    )
        return matches

    return await asyncio.get_event_loop().run_in_executor(self._executor, _search)
async def get_recent_text_memories(
    self, context: ToolContext, limit: int = 10
) -> List[TextMemory]:
    """Get the most recently saved text memories, newest first.

    Args:
        context: Tool execution context (not used by this implementation).
        limit: Maximum number of memories to return; ``limit <= 0`` returns
            an empty list.

    Returns:
        Up to ``limit`` ``TextMemory`` objects ordered by ``timestamp``
        descending.
    """
    if limit <= 0:
        # A negative slice bound would otherwise return "all but the last N".
        return []

    def _get_recent():
        collection = self._get_collection()
        # Text memories are the rows saved with an empty tool_name.
        results = collection.query(
            expr='tool_name == ""',
            output_fields=["id", "question", "timestamp"],
            # NOTE(review): only this many rows are fetched before sorting, so
            # with more stored text memories the newest ones may be missed.
            limit=1000,
        )
        # isoformat() timestamps (see save_text_memory) sort chronologically
        # as strings; a missing/None timestamp sorts last.
        newest_first = sorted(
            results, key=lambda r: r.get("timestamp") or "", reverse=True
        )
        return [
            TextMemory(
                memory_id=row.get("id"),
                content=row.get("question", ""),
                timestamp=row.get("timestamp"),
            )
            for row in newest_first[:limit]
        ]

    return await asyncio.get_running_loop().run_in_executor(
        self._executor, _get_recent
    )
async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
    """Delete a text memory by its ID.

    Args:
        context: Tool execution context (not used by this implementation).
        memory_id: Primary key of the text memory to delete.

    Returns:
        ``True`` if the delete call completed, ``False`` if it raised. ``True``
        does not by itself prove that a matching row existed.
    """

    def _delete():
        collection = self._get_collection()
        try:
            # Escape memory_id as a quoted literal so embedded quotes cannot
            # alter the filter expression and delete unrelated rows.
            literal = json.dumps(str(memory_id), ensure_ascii=False)
            collection.delete(f"id == {literal}")
            return True
        except Exception:
            # Best-effort contract: report failure instead of propagating.
            return False

    return await asyncio.get_running_loop().run_in_executor(self._executor, _delete)
async def clear_memories(
    self,
    context: ToolContext,
    tool_name: Optional[str] = None,
    before_date: Optional[str] = None,
) -> int:
    """Delete stored memories, optionally filtered.

    Args:
        context: Tool execution context (not used by this implementation).
        tool_name: If given, only rows with this exact ``tool_name`` are deleted.
        before_date: If given, only rows whose ``timestamp`` string compares
            lower than this value are deleted (string comparison, so pass the
            same ISO-8601 format the rows were stamped with).

    Returns:
        The number of deleted rows as reported by the client's delete result
        (``delete_count``), or 0 if the result does not expose one.

    With no filter, every row is deleted, including text memories.
    """

    def _clear():
        collection = self._get_collection()
        # Build the filter; values are emitted as escaped, double-quoted
        # literals so embedded quotes cannot change the expression.
        expr_parts = []
        if tool_name:
            literal = json.dumps(str(tool_name), ensure_ascii=False)
            expr_parts.append(f"tool_name == {literal}")
        if before_date:
            literal = json.dumps(str(before_date), ensure_ascii=False)
            expr_parts.append(f"timestamp < {literal}")
        expr = " && ".join(expr_parts) if expr_parts else "id != ''"
        result = collection.delete(expr)
        # Fall back to 0 if the delete result has no usable delete_count.
        return int(getattr(result, "delete_count", 0) or 0)

    return await asyncio.get_running_loop().run_in_executor(self._executor, _clear)
================================================
FILE: src/vanna/integrations/mock/__init__.py
================================================
"""
Mock integration.
This module provides mock implementations for testing.
"""
from .llm import MockLlmService
__all__ = ["MockLlmService"]
================================================
FILE: src/vanna/integrations/mock/llm.py
================================================
"""
Mock LLM service implementation for testing.
This module provides a simple mock implementation of the LlmService interface,
useful for testing and development without requiring actual LLM API calls.
"""
import asyncio
from typing import AsyncGenerator, List
from vanna.core.llm import LlmService, LlmRequest, LlmResponse, LlmStreamChunk
from vanna.core.tool import ToolSchema
class MockLlmService(LlmService):
    """Deterministic stand-in for a real LLM backend, for tests and demos.

    Each request bumps ``call_count``, and the reply text embeds that counter
    so successive calls can be told apart.
    """

    def __init__(self, response_content: str = "Hello! This is a mock response."):
        self.response_content = response_content
        self.call_count = 0

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Return the canned reply after a short simulated delay."""
        self.call_count += 1
        # Pretend to wait on a remote model.
        await asyncio.sleep(0.1)
        reply = f"{self.response_content} (Request #{self.call_count})"
        fixed_usage = {"prompt_tokens": 50, "completion_tokens": 20, "total_tokens": 70}
        return LlmResponse(content=reply, finish_reason="stop", usage=fixed_usage)

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Yield the canned reply one whitespace-separated word at a time.

        Every word except the last carries a trailing space; only the last
        chunk has ``finish_reason="stop"``.
        """
        self.call_count += 1
        tokens = f"{self.response_content} (Streamed #{self.call_count})".split()
        final_position = len(tokens) - 1
        for position, token in enumerate(tokens):
            # Simulate per-token streaming latency.
            await asyncio.sleep(0.05)
            is_last = position == final_position
            yield LlmStreamChunk(
                content=token if is_last else token + " ",
                finish_reason="stop" if is_last else None,
            )

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Accept every tool schema: the mock never reports errors."""
        return []

    def set_response(self, content: str) -> None:
        """Replace the canned reply text used by later requests."""
        self.response_content = content

    def reset_call_count(self) -> None:
        """Zero the request counter."""
        self.call_count = 0
================================================
FILE: src/vanna/integrations/mssql/__init__.py
================================================
"""Microsoft SQL Server integration for Vanna."""
from .sql_runner import MSSQLRunner
__all__ = ["MSSQLRunner"]
================================================
FILE: src/vanna/integrations/mssql/sql_runner.py
================================================
"""Microsoft SQL Server implementation of SqlRunner interface."""
from typing import Optional
import pandas as pd
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
from vanna.core.tool import ToolContext
class MSSQLRunner(SqlRunner):
    """SqlRunner for Microsoft SQL Server, via SQLAlchemy on top of pyodbc."""

    def __init__(self, odbc_conn_str: str, **kwargs):
        """Build a SQLAlchemy engine from an ODBC connection string.

        Args:
            odbc_conn_str: ODBC connection string, passed through as the
                ``odbc_connect`` query parameter of an ``mssql+pyodbc`` URL.
            **kwargs: Extra keyword arguments forwarded to ``create_engine``.

        Raises:
            ImportError: If ``pyodbc`` or ``sqlalchemy`` is not installed.
        """
        try:
            import pyodbc
        except ImportError as e:
            raise ImportError(
                "pyodbc package is required. Install with: pip install pyodbc"
            ) from e
        self.pyodbc = pyodbc

        try:
            import sqlalchemy as sa
            from sqlalchemy.engine import URL
            from sqlalchemy import create_engine
        except ImportError as e:
            raise ImportError(
                "sqlalchemy package is required. Install with: pip install sqlalchemy"
            ) from e
        self.sa = sa
        self.URL = URL
        self.create_engine = create_engine

        url = self.URL.create("mssql+pyodbc", query={"odbc_connect": odbc_conn_str})
        self.engine = self.create_engine(url, **kwargs)

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Run ``args.sql`` on SQL Server and return the rows as a DataFrame.

        Args:
            args: SQL query arguments; ``args.sql`` is executed verbatim.
            context: Tool execution context (not used here).

        Returns:
            DataFrame with the query results.

        Raises:
            sqlalchemy.exc.SQLAlchemyError: If query execution fails.
        """
        # engine.begin() scopes the statement in a transaction (per SQLAlchemy,
        # committed on normal exit and rolled back on error).
        with self.engine.begin() as conn:
            return pd.read_sql_query(self.sa.text(args.sql), conn)
================================================
FILE: src/vanna/integrations/mysql/__init__.py
================================================
"""MySQL integration for Vanna."""
from .sql_runner import MySQLRunner
__all__ = ["MySQLRunner"]
================================================
FILE: src/vanna/integrations/mysql/sql_runner.py
================================================
"""MySQL implementation of SqlRunner interface."""
from typing import Optional
import pandas as pd
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
from vanna.core.tool import ToolContext
class MySQLRunner(SqlRunner):
    """MySQL implementation of the SqlRunner interface, using PyMySQL.

    A fresh connection is opened for every ``run_sql`` call and always closed
    afterwards.
    """

    def __init__(
        self,
        host: str,
        database: str,
        user: str,
        password: str,
        port: int = 3306,
        **kwargs,
    ):
        """Store MySQL connection parameters (no connection is opened here).

        Args:
            host: Database host address
            database: Database name
            user: Database user
            password: Database password
            port: Database port (default: 3306)
            **kwargs: Additional PyMySQL connection parameters

        Raises:
            ImportError: If PyMySQL is not installed.
        """
        try:
            import pymysql.cursors

            self.pymysql = pymysql
        except ImportError as e:
            raise ImportError(
                "PyMySQL package is required. Install with: pip install 'vanna[mysql]'"
            ) from e
        self.host = host
        self.database = database
        self.user = user
        self.password = password
        self.port = port
        self.kwargs = kwargs

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute SQL query against MySQL database and return results as DataFrame.

        Args:
            args: SQL query arguments; ``args.sql`` is executed verbatim.
            context: Tool execution context (not used here).

        Returns:
            DataFrame with query results; empty (no columns) when the statement
            produces no result set.

        Raises:
            pymysql.Error: If query execution fails

        NOTE(review): no commit is issued before the connection is closed; with
        PyMySQL's default ``autocommit=False`` this presumably discards writes.
        Confirm that read-only behaviour is intended.
        """
        conn = self.pymysql.connect(
            host=self.host,
            user=self.user,
            password=self.password,
            database=self.database,
            port=self.port,
            cursorclass=self.pymysql.cursors.DictCursor,
            **self.kwargs,
        )
        try:
            # Ping to ensure connection is alive
            conn.ping(reconnect=True)
            # The context manager closes the cursor on every path, including
            # when execute()/fetchall() raises.
            with conn.cursor() as cursor:
                cursor.execute(args.sql)
                rows = cursor.fetchall()
                # DictCursor rows are dicts; explicit column names keep the
                # server's column order and give an empty result its columns.
                columns = (
                    [desc[0] for desc in cursor.description]
                    if cursor.description
                    else []
                )
            return pd.DataFrame(rows, columns=columns)
        finally:
            conn.close()
================================================
FILE: src/vanna/integrations/ollama/__init__.py
================================================
"""
Ollama integration for Vanna Agents.
"""
from .llm import OllamaLlmService
__all__ = ["OllamaLlmService"]
================================================
FILE: src/vanna/integrations/ollama/llm.py
================================================
"""
Ollama LLM service implementation.
This module provides an implementation of the LlmService interface backed by
Ollama's local LLM API. It supports non-streaming responses and streaming
of text content. Tool calling support depends on the Ollama model being used.
"""
from __future__ import annotations
import json
import os
from typing import Any, AsyncGenerator, Dict, List, Optional
from vanna.core.llm import (
LlmService,
LlmRequest,
LlmResponse,
LlmStreamChunk,
)
from vanna.core.tool import ToolCall, ToolSchema
class OllamaLlmService(LlmService):
    """Ollama-backed LLM service for local model inference.

    Args:
        model: Ollama model name (e.g., "gpt-oss:20b"). Required.
        host: Ollama server URL; falls back to env `OLLAMA_HOST`, then
            "http://localhost:11434".
        timeout: Request timeout in seconds; defaults to 240.
        num_ctx: Context window size; defaults to 8192.
        temperature: Sampling temperature; defaults to 0.7.
        extra_options: Additional options passed to Ollama (e.g., num_predict,
            top_k, top_p). They are merged after ``num_ctx``/``temperature``,
            so repeating either key here overrides it.

    Raises:
        ImportError: If the ``ollama`` package is not installed.
        ValueError: If ``model`` is empty.

    NOTE(review): ``ollama.Client`` is synchronous, so requests (and waiting
    for each streamed chunk) block the event loop; ``ollama.AsyncClient``
    would avoid that.
    """

    def __init__(
        self,
        model: str,
        host: Optional[str] = None,
        timeout: float = 240.0,
        num_ctx: int = 8192,
        temperature: float = 0.7,
        **extra_options: Any,
    ) -> None:
        try:
            import ollama
        except ImportError as e:
            raise ImportError(
                "ollama package is required. Install with: pip install 'vanna[ollama]' or pip install ollama"
            ) from e
        if not model:
            raise ValueError("model parameter is required for Ollama")
        self.model = model
        self.host = host or os.getenv("OLLAMA_HOST", "http://localhost:11434")
        self.timeout = timeout
        self.num_ctx = num_ctx
        self.temperature = temperature
        self.extra_options = extra_options
        # Create Ollama client
        self._client = ollama.Client(host=self.host, timeout=timeout)

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Send a non-streaming request to Ollama and return the response.

        Raises:
            RuntimeError: If the underlying ``chat`` call fails.
        """
        payload = self._build_payload(request)
        try:
            resp = self._client.chat(**payload)
        except Exception as e:
            raise RuntimeError(f"Ollama request failed: {str(e)}") from e

        message = resp.get("message") or {}
        content = message.get("content")
        tool_calls = self._extract_tool_calls_from_message(message)

        usage: Dict[str, int] = {}
        if "prompt_eval_count" in resp or "eval_count" in resp:
            # Guard against a counter that is present but None, which would
            # otherwise fail on None + int.
            prompt_tokens = resp.get("prompt_eval_count") or 0
            completion_tokens = resp.get("eval_count") or 0
            usage = {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            }

        return LlmResponse(
            content=content,
            tool_calls=tool_calls or None,
            finish_reason=resp.get("done_reason")
            or ("stop" if resp.get("done") else None),
            usage=usage or None,
        )

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream a request to Ollama.

        Emits `LlmStreamChunk` for textual deltas as they arrive. Tool calls are
        accumulated and emitted in a final chunk when the stream ends; when
        there are none, a terminal chunk carrying only the finish reason is
        emitted instead.

        Raises:
            RuntimeError: If opening the stream fails.
        """
        payload = self._build_payload(request)
        try:
            stream = self._client.chat(**payload, stream=True)
        except Exception as e:
            raise RuntimeError(f"Ollama streaming request failed: {str(e)}") from e

        accumulated_tool_calls: List[ToolCall] = []
        last_finish: Optional[str] = None
        for chunk in stream:
            message = chunk.get("message") or {}
            content = message.get("content")
            if content:
                yield LlmStreamChunk(content=content)
            # Continue default ids from the number already collected so calls
            # arriving in different chunks do not all become "tool_call_0".
            accumulated_tool_calls.extend(
                self._extract_tool_calls_from_message(
                    message, start_index=len(accumulated_tool_calls)
                )
            )
            if chunk.get("done"):
                last_finish = chunk.get("done_reason") or "stop"

        if accumulated_tool_calls:
            yield LlmStreamChunk(
                tool_calls=accumulated_tool_calls, finish_reason=last_finish or "stop"
            )
        else:
            # Emit terminal chunk to signal completion
            yield LlmStreamChunk(finish_reason=last_finish or "stop")

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Validate tool schemas. Returns a list of error messages."""
        errors: List[str] = []
        # Basic validation; Ollama model support for tools varies
        for t in tools:
            if not t.name:
                errors.append("Tool must have a name")
            if not t.description:
                errors.append(f"Tool '{t.name}' should have a description")
        return errors

    # Internal helpers

    def _build_payload(self, request: LlmRequest) -> Dict[str, Any]:
        """Build the Ollama chat payload from LlmRequest.

        The system prompt (if any) becomes the first message, ``None`` content
        is sent as ``""``, and assistant tool calls are forwarded as
        ``{"function": {"name", "arguments"}}`` entries.
        """
        messages: List[Dict[str, Any]] = []
        if request.system_prompt:
            messages.append({"role": "system", "content": request.system_prompt})

        for m in request.messages:
            msg: Dict[str, Any] = {"role": m.role, "content": m.content or ""}
            if m.role == "assistant" and m.tool_calls:
                # Some Ollama models support tool_calls in message
                msg["tool_calls"] = [
                    {"function": {"name": tc.name, "arguments": tc.arguments}}
                    for tc in m.tool_calls
                ]
            messages.append(msg)

        tools_payload: Optional[List[Dict[str, Any]]] = None
        if request.tools:
            tools_payload = [
                {
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters,
                    },
                }
                for t in request.tools
            ]

        # extra_options last so it can override the defaults above.
        options: Dict[str, Any] = {
            "num_ctx": self.num_ctx,
            "temperature": self.temperature,
            **self.extra_options,
        }

        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "options": options,
        }
        # Add tools if provided (note: not all Ollama models support tools)
        if tools_payload:
            payload["tools"] = tools_payload
        return payload

    def _extract_tool_calls_from_message(
        self, message: Dict[str, Any], start_index: int = 0
    ) -> List[ToolCall]:
        """Extract tool calls from an Ollama message.

        Args:
            message: The ``message`` part of an Ollama chat response/chunk.
            start_index: Offset for generated ids (``tool_call_<n>``) used when
                a call carries no ``id`` of its own.

        Returns:
            Tool calls with a name; string arguments are JSON-decoded
            (undecodable text is kept as ``{"_raw": text}``) and non-dict
            arguments are wrapped as ``{"args": value}``.
        """
        tool_calls: List[ToolCall] = []
        raw_tool_calls = message.get("tool_calls", [])
        if not raw_tool_calls:
            return tool_calls

        for idx, tc in enumerate(raw_tool_calls):
            fn = tc.get("function", {})
            name = fn.get("name")
            if not name:
                continue
            arguments = fn.get("arguments", {})
            if isinstance(arguments, str):
                try:
                    arguments = json.loads(arguments)
                except Exception:
                    arguments = {"_raw": arguments}
            if not isinstance(arguments, dict):
                arguments = {"args": arguments}
            tool_calls.append(
                ToolCall(
                    id=tc.get("id", f"tool_call_{start_index + idx}"),
                    name=name,
                    arguments=arguments,
                )
            )
        return tool_calls
================================================
FILE: src/vanna/integrations/openai/__init__.py
================================================
"""
OpenAI integration.
This module provides OpenAI LLM service implementations.
"""
from .llm import OpenAILlmService
from .responses import OpenAIResponsesService
__all__ = ["OpenAILlmService", "OpenAIResponsesService"]
================================================
FILE: src/vanna/integrations/openai/llm.py
================================================
"""
OpenAI LLM service implementation.
This module provides an implementation of the LlmService interface backed by
OpenAI's Chat Completions API (openai>=1.0.0). It supports non-streaming
responses and best-effort streaming of text content. Tool/function calling is
passed through when tools are provided, but full tool-call conversation
round-tripping may require adding assistant tool-call messages to the
conversation upstream.
"""
from __future__ import annotations
import json
import os
from typing import Any, AsyncGenerator, Dict, List, Optional, cast
from vanna.core.llm import (
LlmService,
LlmRequest,
LlmResponse,
LlmStreamChunk,
)
from vanna.core.tool import ToolCall, ToolSchema
class OpenAILlmService(LlmService):
    """OpenAI Chat Completions-backed LLM service.

    Args:
        model: OpenAI model name; falls back to env `OPENAI_MODEL`, else "gpt-5".
        api_key: API key; falls back to env `OPENAI_API_KEY`.
        organization: Optional org; env `OPENAI_ORG` if unset.
        base_url: Optional custom base URL; env `OPENAI_BASE_URL` if unset.
        extra_client_kwargs: Extra kwargs forwarded to `openai.OpenAI()`; the
            resolved api_key/organization/base_url values take precedence.

    Note:
        Uses the synchronous ``openai.OpenAI`` client, so each call blocks the
        event loop while the HTTP request is in flight.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        base_url: Optional[str] = None,
        **extra_client_kwargs: Any,
    ) -> None:
        try:
            from openai import OpenAI
        except Exception as e:  # pragma: no cover - import-time error surface
            raise ImportError(
                "openai package is required. Install with: pip install 'vanna[openai]'"
            ) from e

        self.model = model or os.getenv("OPENAI_MODEL", "gpt-5")
        api_key = api_key or os.getenv("OPENAI_API_KEY")
        organization = organization or os.getenv("OPENAI_ORG")
        base_url = base_url or os.getenv("OPENAI_BASE_URL")

        client_kwargs: Dict[str, Any] = {**extra_client_kwargs}
        if api_key:
            client_kwargs["api_key"] = api_key
        if organization:
            client_kwargs["organization"] = organization
        if base_url:
            client_kwargs["base_url"] = base_url

        self._client = OpenAI(**client_kwargs)

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Send a non-streaming request to OpenAI and return the response."""
        payload = self._build_payload(request)
        # Synchronous client call inside an async method (blocks until done).
        resp = self._client.chat.completions.create(**payload, stream=False)

        if not resp.choices:
            return LlmResponse(content=None, tool_calls=None, finish_reason=None)

        choice = resp.choices[0]
        content: Optional[str] = getattr(choice.message, "content", None)
        tool_calls = self._extract_tool_calls_from_message(choice.message)

        usage: Dict[str, int] = {}
        raw_usage = getattr(resp, "usage", None)
        if raw_usage:
            # A counter reported as None (seen with some OpenAI-compatible
            # servers) is counted as 0 instead of crashing in int(None).
            usage = {
                key: int(getattr(raw_usage, key, 0) or 0)
                for key in ("prompt_tokens", "completion_tokens", "total_tokens")
            }

        return LlmResponse(
            content=content,
            tool_calls=tool_calls or None,
            finish_reason=getattr(choice, "finish_reason", None),
            usage=usage or None,
        )

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream a request to OpenAI.

        Emits `LlmStreamChunk` for textual deltas as they arrive. Tool-calls are
        accumulated and emitted, in index order, in a final chunk when the
        stream ends; without tool calls a terminal chunk with only the finish
        reason is emitted.
        """
        payload = self._build_payload(request)
        # Synchronous streaming iterator; iterate within async context.
        stream = self._client.chat.completions.create(**payload, stream=True)

        # Builders for streamed tool-calls (index -> partial)
        tc_builders: Dict[int, Dict[str, Optional[str]]] = {}
        last_finish: Optional[str] = None

        for event in stream:
            if not getattr(event, "choices", None):
                continue
            choice = event.choices[0]
            # Keep the most recent non-empty finish_reason; a chunk without one
            # must not clobber a value already seen.
            last_finish = getattr(choice, "finish_reason", None) or last_finish

            delta = getattr(choice, "delta", None)
            if delta is None:
                continue

            content_piece: Optional[str] = getattr(delta, "content", None)
            if content_piece:
                yield LlmStreamChunk(content=content_piece)

            # Tool-call fragments: id/name arrive once, arguments arrive as
            # string pieces that are concatenated per index.
            for tc in getattr(delta, "tool_calls", None) or []:
                idx = getattr(tc, "index", 0) or 0
                b = tc_builders.setdefault(
                    idx, {"id": None, "name": None, "arguments": ""}
                )
                if getattr(tc, "id", None):
                    b["id"] = tc.id
                fn = getattr(tc, "function", None)
                if fn is not None:
                    if getattr(fn, "name", None):
                        b["name"] = fn.name
                    if getattr(fn, "arguments", None):
                        b["arguments"] = (b["arguments"] or "") + fn.arguments

        final_tool_calls: List[ToolCall] = []
        # Sort by index so output order does not depend on arrival order.
        for idx, b in sorted(tc_builders.items()):
            name = b.get("name")
            if not name:
                continue
            final_tool_calls.append(
                ToolCall(
                    id=b.get("id") or f"tool_call_{idx}",
                    name=name,
                    arguments=self._parse_arguments(b.get("arguments")),
                )
            )

        if final_tool_calls:
            yield LlmStreamChunk(
                tool_calls=final_tool_calls, finish_reason=last_finish or "tool_calls"
            )
        else:
            # Still emit a terminal chunk to signal completion
            yield LlmStreamChunk(finish_reason=last_finish or "stop")

    async def validate_tools(self, tools: List[ToolSchema]) -> List[str]:
        """Validate tool schemas. Returns a list of error messages."""
        errors: List[str] = []
        # Basic checks; OpenAI will enforce further validation server-side.
        for t in tools:
            if not t.name or len(t.name) > 64:
                errors.append(f"Invalid tool name: {t.name!r}")
        return errors

    # Internal helpers

    def _build_payload(self, request: LlmRequest) -> Dict[str, Any]:
        """Build Chat Completions keyword arguments from an LlmRequest.

        The system prompt becomes the first message; tool results keep their
        ``tool_call_id``; assistant tool calls are re-encoded with JSON string
        arguments. ``tool_choice="auto"`` is set whenever tools are sent.
        """
        messages: List[Dict[str, Any]] = []
        if request.system_prompt:
            messages.append({"role": "system", "content": request.system_prompt})

        for m in request.messages:
            msg: Dict[str, Any] = {"role": m.role, "content": m.content}
            if m.role == "tool" and m.tool_call_id:
                msg["tool_call_id"] = m.tool_call_id
            elif m.role == "assistant" and m.tool_calls:
                msg["tool_calls"] = [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.name,
                            "arguments": json.dumps(tc.arguments),
                        },
                    }
                    for tc in m.tool_calls
                ]
            messages.append(msg)

        tools_payload: Optional[List[Dict[str, Any]]] = None
        if request.tools:
            tools_payload = [
                {
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters,
                    },
                }
                for t in request.tools
            ]

        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
        }
        if request.max_tokens is not None:
            # NOTE(review): OpenAI documents `max_tokens` as deprecated in favour
            # of `max_completion_tokens`, and reasoning models reportedly reject
            # it. Confirm before relying on it with the default model.
            payload["max_tokens"] = request.max_tokens
        if tools_payload:
            payload["tools"] = tools_payload
            payload["tool_choice"] = "auto"
        return payload

    @staticmethod
    def _parse_arguments(args_raw: Any) -> Dict[str, Any]:
        """Decode tool-call arguments into a dict.

        A dict is returned as-is; missing/empty input becomes ``{}``; a JSON
        value that is not an object is wrapped as ``{"args": value}``;
        undecodable input is preserved as ``{"_raw": input}``.
        """
        if isinstance(args_raw, dict):
            return args_raw
        raw = args_raw or "{}"
        try:
            loaded = json.loads(raw)
        except Exception:
            return {"_raw": raw}
        return loaded if isinstance(loaded, dict) else {"args": loaded}

    def _extract_tool_calls_from_message(self, message: Any) -> List[ToolCall]:
        """Convert a Chat Completions message's ``tool_calls`` into ToolCalls.

        Entries without a ``function`` are skipped; a missing id becomes
        ``tool_call_<position>`` and a missing name becomes ``"tool"``.
        """
        tool_calls: List[ToolCall] = []
        for idx, tc in enumerate(getattr(message, "tool_calls", None) or []):
            fn = getattr(tc, "function", None)
            if not fn:
                continue
            tool_calls.append(
                ToolCall(
                    id=getattr(tc, "id", None) or f"tool_call_{idx}",
                    name=getattr(fn, "name", None) or "tool",
                    arguments=self._parse_arguments(getattr(fn, "arguments", None)),
                )
            )
        return tool_calls
================================================
FILE: src/vanna/integrations/openai/responses.py
================================================
from __future__ import annotations
import json
import os
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, TYPE_CHECKING
from vanna.core.llm import LlmService, LlmRequest, LlmResponse, LlmStreamChunk
from vanna.core.tool import ToolCall, ToolSchema
if TYPE_CHECKING:
from openai.types.responses import Response
class OpenAIResponsesService(LlmService):
    """LLM service backed by the OpenAI Responses API (``AsyncOpenAI``).

    Args:
        api_key: API key; falls back to env ``OPENAI_API_KEY``.
        model: Model name; falls back to env ``OPENAI_MODEL``, else "gpt-5".

    Raises:
        ImportError: If the ``openai`` package (with Responses types) is missing.
    """

    def __init__(
        self, api_key: Optional[str] = None, model: Optional[str] = None
    ) -> None:
        try:
            from openai import AsyncOpenAI
            from openai.types.responses import Response  # noqa: F401 (availability check)
        except Exception as e:  # pragma: no cover
            raise ImportError(
                "openai package is required. Install with: pip install 'vanna[openai]'"
            ) from e
        self.client = AsyncOpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY"))
        self.model = model or os.getenv("OPENAI_MODEL", "gpt-5")

    async def send_request(self, request: LlmRequest) -> LlmResponse:
        """Send a non-streaming request and convert the Responses API reply."""
        payload = self._payload(request)
        resp: Response = await self.client.responses.create(**payload)
        self._debug_print("response", resp)
        text, tools, status, usage = self._extract(resp)
        return LlmResponse(
            content=text,
            tool_calls=tools or None,
            finish_reason=status,
            usage=usage or None,
            metadata={"request_id": getattr(resp, "id", None)},
        )

    async def stream_request(
        self, request: LlmRequest
    ) -> AsyncGenerator[LlmStreamChunk, None]:
        """Stream text deltas, then one final chunk with tool calls and status."""
        payload = self._payload(request)
        async with self.client.responses.stream(**payload) as stream:
            async for event in stream:
                self._debug_print("stream_event", event)
                event_type = getattr(event, "type", None)
                if event_type == "response.output_text.delta":
                    delta = getattr(event, "delta", None)
                    if delta:
                        yield LlmStreamChunk(content=delta)
            final: Response = await stream.get_final_response()
            self._debug_print("final_response", final)
            _text, tools, status, _usage = self._extract(final)
            yield LlmStreamChunk(tool_calls=tools or None, finish_reason=status)

    async def validate_tools(self, tools: List[Any]) -> List[str]:
        return []  # minimal: accept whatever's passed through

    # ---- helpers ----

    def _payload(self, request: LlmRequest) -> Dict[str, Any]:
        """Build keyword arguments for ``responses.create``/``responses.stream``."""
        p: Dict[str, Any] = {"model": self.model, "input": self._input_items(request)}
        if request.system_prompt:
            p["instructions"] = request.system_prompt
        if request.max_tokens:
            p["max_output_tokens"] = request.max_tokens
        if request.tools:
            p["tools"] = [self._serialize_tool(t) for t in request.tools]
        return p

    def _input_items(self, request: LlmRequest) -> List[Dict[str, Any]]:
        """Convert conversation messages into Responses API input items.

        Plain messages become ``{"role", "content"}`` items. A tool result
        (role ``"tool"`` with a ``tool_call_id``) becomes a
        ``function_call_output`` item, and an assistant message's tool calls
        become ``function_call`` items, which is how the Responses API expects
        a function-calling round-trip to be replayed.
        """
        items: List[Dict[str, Any]] = []
        for m in request.messages:
            tool_call_id = getattr(m, "tool_call_id", None)
            tool_calls = getattr(m, "tool_calls", None)
            if m.role == "tool" and tool_call_id:
                items.append(
                    {
                        "type": "function_call_output",
                        "call_id": tool_call_id,
                        "output": m.content or "",
                    }
                )
            elif m.role == "assistant" and tool_calls:
                if m.content:
                    items.append({"role": "assistant", "content": m.content})
                for tc in tool_calls:
                    items.append(
                        {
                            "type": "function_call",
                            "call_id": tc.id,
                            "name": tc.name,
                            "arguments": json.dumps(tc.arguments),
                        }
                    )
            else:
                items.append({"role": m.role, "content": m.content})
        return items

    def _debug_print(self, label: str, obj: Any) -> None:
        """Log ``obj`` at DEBUG level on this module's logger.

        Serialization (``model_dump``/``dict``) only happens when DEBUG is
        enabled, so normal runs neither print API payloads nor pay for it.
        """
        import logging

        logger = logging.getLogger(__name__)
        if not logger.isEnabledFor(logging.DEBUG):
            return
        try:
            payload = obj.model_dump()
        except AttributeError:
            try:
                payload = obj.dict()
            except AttributeError:
                payload = obj
        logger.debug("[OpenAIResponsesService] %s: %s", label, payload)

    def _extract(
        self, resp: Response
    ) -> Tuple[
        Optional[str], Optional[List[ToolCall]], Optional[str], Optional[Dict[str, int]]
    ]:
        """Pull text, function tool calls, status and usage from a Response.

        Function calls are top-level ``output`` items of type
        ``"function_call"`` carrying ``name``, JSON-string ``arguments`` and a
        ``call_id``. The ``call_id`` becomes the ToolCall id so the later
        ``function_call_output`` can reference it.
        """
        text = getattr(resp, "output_text", None)
        tool_calls: List[ToolCall] = []
        for idx, item in enumerate(getattr(resp, "output", None) or []):
            if getattr(item, "type", None) != "function_call":
                continue
            name = getattr(item, "name", None)
            if not name:
                continue
            raw_args = getattr(item, "arguments", None)
            if isinstance(raw_args, dict):
                args: Dict[str, Any] = raw_args
            else:
                try:
                    loaded = json.loads(raw_args) if raw_args else {}
                except Exception:
                    loaded = {"_raw": raw_args}
                args = loaded if isinstance(loaded, dict) else {"args": loaded}
            call_id = (
                getattr(item, "call_id", None)
                or getattr(item, "id", None)
                or f"tool_call_{idx}"
            )
            tool_calls.append(ToolCall(id=call_id, name=name, arguments=args))

        usage = None
        if getattr(resp, "usage", None):
            input_tokens = getattr(resp.usage, "input_tokens", 0) or 0
            output_tokens = getattr(resp.usage, "output_tokens", 0) or 0
            usage = {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "total_tokens": getattr(resp.usage, "total_tokens", None)
                or (input_tokens + output_tokens),
                # Aliases matching the key names the other LlmServices report.
                "prompt_tokens": input_tokens,
                "completion_tokens": output_tokens,
            }
        status = getattr(resp, "status", None)  # e.g. "completed"
        return text, (tool_calls or None), status, usage

    def _serialize_tool(self, tool: Any) -> Dict[str, Any]:
        """Convert a tool schema into the dict format expected by OpenAI Responses."""
        if isinstance(tool, ToolSchema):
            return {
                "type": "function",
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.parameters,
                "strict": False,
            }
        # Support generic pydantic/BaseModel style objects without importing pydantic here.
        if hasattr(tool, "model_dump"):
            data = tool.model_dump()
            if all(key in data for key in ("name", "description", "parameters")):
                return {
                    "type": "function",
                    "name": data["name"],
                    "description": data["description"],
                    "parameters": data["parameters"],
                    "strict": data.get("strict", False),
                }
            return data
        if isinstance(tool, dict):
            if "type" in tool:
                return tool
            if all(k in tool for k in ("name", "description", "parameters")):
                return {
                    "type": "function",
                    "name": tool["name"],
                    "description": tool["description"],
                    "parameters": tool["parameters"],
                    "strict": tool.get("strict", False),
                }
            return tool
        raise TypeError(f"Unsupported tool schema type: {type(tool)!r}")
================================================
FILE: src/vanna/integrations/opensearch/__init__.py
================================================
"""
OpenSearch integration for Vanna Agents.
"""
from .agent_memory import OpenSearchAgentMemory
__all__ = ["OpenSearchAgentMemory"]
================================================
FILE: src/vanna/integrations/opensearch/agent_memory.py
================================================
"""
OpenSearch vector database implementation of AgentMemory.
This implementation uses OpenSearch for distributed search and storage of tool usage patterns.
"""
import json
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
import asyncio
from concurrent.futures import ThreadPoolExecutor
try:
from opensearchpy import OpenSearch, helpers
OPENSEARCH_AVAILABLE = True
except ImportError:
OPENSEARCH_AVAILABLE = False
from vanna.capabilities.agent_memory import (
AgentMemory,
TextMemory,
TextMemorySearchResult,
ToolMemory,
ToolMemorySearchResult,
)
from vanna.core.tool import ToolContext
class OpenSearchAgentMemory(AgentMemory):
"""OpenSearch-based implementation of AgentMemory."""
def __init__(
    self,
    index_name: str = "tool_memories",
    hosts: Optional[List[str]] = None,
    http_auth: Optional[tuple] = None,
    use_ssl: bool = False,
    verify_certs: bool = False,
    dimension: int = 384,
):
    """Record connection settings; no connection is opened here.

    The OpenSearch client, and the index if it is missing, are created
    lazily on first use by ``_get_client``. Blocking client calls are run on
    a two-worker thread pool by the async methods.

    Args:
        index_name: Index holding the memory documents.
        hosts: OpenSearch hosts; defaults to ``["localhost:9200"]``.
        http_auth: Optional auth tuple forwarded to the client.
        use_ssl: Whether to connect over TLS.
        verify_certs: Whether TLS certificates are verified.
        dimension: Length of the embedding vectors stored in the index.

    Raises:
        ImportError: If ``opensearch-py`` is not installed.
    """
    if not OPENSEARCH_AVAILABLE:
        raise ImportError(
            "OpenSearch is required for OpenSearchAgentMemory. Install with: pip install opensearch-py"
        )

    # Connection / index settings.
    self.index_name = index_name
    self.hosts = hosts if hosts else ["localhost:9200"]
    self.http_auth = http_auth
    self.use_ssl = use_ssl
    self.verify_certs = verify_certs
    self.dimension = dimension

    # Lazily-created client and the pool that runs blocking calls.
    self._client = None
    self._executor = ThreadPoolExecutor(max_workers=2)
def _get_client(self):
"""Get or create OpenSearch client."""
if self._client is None:
self._client = OpenSearch(
hosts=self.hosts,
http_auth=self.http_auth,
use_ssl=self.use_ssl,
verify_certs=self.verify_certs,
ssl_show_warn=False,
)
# Create index if it doesn't exist
if not self._client.indices.exists(index=self.index_name):
index_body = {
"settings": {
"index": {"knn": True, "knn.algo_param.ef_search": 100}
},
"mappings": {
"properties": {
"memory_id": {"type": "keyword"},
"question": {"type": "text"},
"tool_name": {"type": "keyword"},
"args": {"type": "object", "enabled": False},
"timestamp": {"type": "date"},
"success": {"type": "boolean"},
"metadata": {"type": "object", "enabled": False},
"embedding": {
"type": "knn_vector",
"dimension": self.dimension,
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": "nmslib",
},
},
}
},
}
self._client.indices.create(index=self.index_name, body=index_body)
return self._client
def _create_embedding(self, text: str) -> List[float]:
"""Create a simple embedding from text (placeholder)."""
import hashlib
hash_val = int(hashlib.md5(text.encode()).hexdigest(), 16)
return [(hash_val >> i) % 100 / 100.0 for i in range(self.dimension)]
async def save_tool_usage(
self,
question: str,
tool_name: str,
args: Dict[str, Any],
context: ToolContext,
success: bool = True,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Save a tool usage pattern."""
def _save():
client = self._get_client()
memory_id = str(uuid.uuid4())
timestamp = datetime.now().isoformat()
embedding = self._create_embedding(question)
document = {
"memory_id": memory_id,
"question": question,
"tool_name": tool_name,
"args": args,
"timestamp": timestamp,
"success": success,
"metadata": metadata or {},
"embedding": embedding,
}
client.index(
index=self.index_name, body=document, id=memory_id, refresh=True
)
await asyncio.get_event_loop().run_in_executor(self._executor, _save)
async def search_similar_usage(
self,
question: str,
context: ToolContext,
*,
limit: int = 10,
similarity_threshold: float = 0.7,
tool_name_filter: Optional[str] = None,
) -> List[ToolMemorySearchResult]:
"""Search for similar tool usage patterns."""
def _search():
client = self._get_client()
embedding = self._create_embedding(question)
# Build query
must_conditions = [{"term": {"success": True}}]
if tool_name_filter:
must_conditions.append({"term": {"tool_name": tool_name_filter}})
query = {
"size": limit,
"query": {
"bool": {
"must": must_conditions,
"filter": {
"knn": {"embedding": {"vector": embedding, "k": limit}}
},
}
},
}
response = client.search(index=self.index_name, body=query)
search_results = []
for i, hit in enumerate(response["hits"]["hits"]):
source = hit["_source"]
score = hit["_score"]
# Normalize score to 0-1 range (OpenSearch scores can vary)
similarity_score = min(score / 10.0, 1.0)
if similarity_score >= similarity_threshold:
memory = ToolMemory(
memory_id=source["memory_id"],
question=source["question"],
tool_name=source["tool_name"],
args=source["args"],
timestamp=source.get("timestamp"),
success=source.get("success", True),
metadata=source.get("metadata", {}),
)
search_results.append(
ToolMemorySearchResult(
memory=memory, similarity_score=similarity_score, rank=i + 1
)
)
return search_results
return await asyncio.get_event_loop().run_in_executor(self._executor, _search)
async def get_recent_memories(
self, context: ToolContext, limit: int = 10
) -> List[ToolMemory]:
"""Get recently added memories."""
def _get_recent():
client = self._get_client()
query = {
"size": limit,
"query": {"match_all": {}},
"sort": [{"timestamp": {"order": "desc"}}],
}
response = client.search(index=self.index_name, body=query)
memories = []
for hit in response["hits"]["hits"]:
source = hit["_source"]
memory = ToolMemory(
memory_id=source["memory_id"],
question=source["question"],
tool_name=source["tool_name"],
args=source["args"],
timestamp=source.get("timestamp"),
success=source.get("success", True),
metadata=source.get("metadata", {}),
)
memories.append(memory)
return memories
return await asyncio.get_event_loop().run_in_executor(
self._executor, _get_recent
)
async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
"""Delete a memory by its ID."""
def _delete():
client = self._get_client()
try:
client.delete(index=self.index_name, id=memory_id, refresh=True)
return True
except Exception:
return False
return await asyncio.get_event_loop().run_in_executor(self._executor, _delete)
async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
"""Save a text memory."""
def _save():
client = self._get_client()
memory_id = str(uuid.uuid4())
timestamp = datetime.now().isoformat()
embedding = self._create_embedding(content)
document = {
"memory_id": memory_id,
"content": content,
"timestamp": timestamp,
"is_text_memory": True,
"embedding": embedding,
}
client.index(
index=self.index_name, body=document, id=memory_id, refresh=True
)
return TextMemory(memory_id=memory_id, content=content, timestamp=timestamp)
return await asyncio.get_event_loop().run_in_executor(self._executor, _save)
async def search_text_memories(
self,
query: str,
context: ToolContext,
*,
limit: int = 10,
similarity_threshold: float = 0.7,
) -> List[TextMemorySearchResult]:
"""Search for similar text memories."""
def _search():
client = self._get_client()
embedding = self._create_embedding(query)
query_body = {
"size": limit,
"query": {
"bool": {
"must": [{"term": {"is_text_memory": True}}],
"filter": {
"knn": {"embedding": {"vector": embedding, "k": limit}}
},
}
},
}
response = client.search(index=self.index_name, body=query_body)
search_results = []
for i, hit in enumerate(response["hits"]["hits"]):
source = hit["_source"]
score = hit["_score"]
similarity_score = min(score / 10.0, 1.0)
if similarity_score >= similarity_threshold:
memory = TextMemory(
memory_id=source["memory_id"],
content=source.get("content", ""),
timestamp=source.get("timestamp"),
)
search_results.append(
TextMemorySearchResult(
memory=memory, similarity_score=similarity_score, rank=i + 1
)
)
return search_results
return await asyncio.get_event_loop().run_in_executor(self._executor, _search)
async def get_recent_text_memories(
self, context: ToolContext, limit: int = 10
) -> List[TextMemory]:
"""Get recently added text memories."""
def _get_recent():
client = self._get_client()
query = {
"size": limit,
"query": {"term": {"is_text_memory": True}},
"sort": [{"timestamp": {"order": "desc"}}],
}
response = client.search(index=self.index_name, body=query)
memories = []
for hit in response["hits"]["hits"]:
source = hit["_source"]
memory = TextMemory(
memory_id=source["memory_id"],
content=source.get("content", ""),
timestamp=source.get("timestamp"),
)
memories.append(memory)
return memories
return await asyncio.get_event_loop().run_in_executor(
self._executor, _get_recent
)
async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
"""Delete a text memory by its ID."""
def _delete():
client = self._get_client()
try:
client.delete(index=self.index_name, id=memory_id, refresh=True)
return True
except Exception:
return False
return await asyncio.get_event_loop().run_in_executor(self._executor, _delete)
async def clear_memories(
self,
context: ToolContext,
tool_name: Optional[str] = None,
before_date: Optional[str] = None,
) -> int:
"""Clear stored memories."""
def _clear():
client = self._get_client()
# Build query
must_conditions = []
if tool_name:
must_conditions.append({"term": {"tool_name": tool_name}})
if before_date:
must_conditions.append({"range": {"timestamp": {"lt": before_date}}})
if must_conditions:
query = {"query": {"bool": {"must": must_conditions}}}
else:
query = {"query": {"match_all": {}}}
response = client.delete_by_query(
index=self.index_name, body=query, refresh=True
)
return response.get("deleted", 0)
return await asyncio.get_event_loop().run_in_executor(self._executor, _clear)
================================================
FILE: src/vanna/integrations/oracle/__init__.py
================================================
"""Oracle integration for Vanna."""
from .sql_runner import OracleRunner
__all__ = ["OracleRunner"]
================================================
FILE: src/vanna/integrations/oracle/sql_runner.py
================================================
"""Oracle implementation of SqlRunner interface."""
from typing import Optional
import pandas as pd
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
from vanna.core.tool import ToolContext
class OracleRunner(SqlRunner):
    """Oracle implementation of the SqlRunner interface.

    Each call to ``run_sql`` opens a new connection and closes it afterwards.
    """

    def __init__(self, user: str, password: str, dsn: str, **kwargs):
        """Initialize with Oracle connection parameters.

        Args:
            user: Oracle database user name
            password: Oracle database user password
            dsn: Oracle database host - format: host:port/sid
            **kwargs: Additional oracledb connection parameters

        Raises:
            ImportError: If the ``oracledb`` package is not installed.
        """
        try:
            import oracledb

            self.oracledb = oracledb
        except ImportError as e:
            raise ImportError(
                "oracledb package is required. Install with: pip install 'vanna[oracle]'"
            ) from e
        self.user = user
        self.password = password
        self.dsn = dsn
        self.kwargs = kwargs

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute a SQL statement against Oracle and return the result as a DataFrame.

        Statements that produce a result set return their rows, with column
        names taken from ``cursor.description``. Statements that produce no
        result set (INSERT/UPDATE/DELETE/DDL) are committed, and the method
        returns a one-row DataFrame with a ``rows_affected`` column. This
        matches the behavior of ``PostgresRunner``.

        Args:
            args: SQL query arguments
            context: Tool execution context

        Returns:
            DataFrame with query results, or ``rows_affected`` for non-query statements.

        Raises:
            oracledb.Error: If query execution fails (the transaction is rolled back).
        """
        conn = self.oracledb.connect(
            user=self.user, password=self.password, dsn=self.dsn, **self.kwargs
        )
        try:
            # Created inside the outer try so the connection is closed even if
            # opening the cursor fails.
            cursor = conn.cursor()
            try:
                # Oracle rejects a trailing semicolon on plain SQL statements.
                sql = args.sql.rstrip()
                if sql.endswith(";"):
                    sql = sql[:-1]
                cursor.execute(sql)
                if cursor.description is None:
                    # No result set (DML/DDL): fetchall() would raise here.
                    # Commit explicitly, because closing the connection rolls
                    # back any uncommitted work.
                    conn.commit()
                    return pd.DataFrame({"rows_affected": [cursor.rowcount]})
                columns = [desc[0] for desc in cursor.description]
                return pd.DataFrame(cursor.fetchall(), columns=columns)
            except self.oracledb.Error:
                conn.rollback()
                raise
            finally:
                cursor.close()
        finally:
            conn.close()
================================================
FILE: src/vanna/integrations/pinecone/__init__.py
================================================
"""
Pinecone integration for Vanna Agents.
"""
from .agent_memory import PineconeAgentMemory
__all__ = ["PineconeAgentMemory"]
================================================
FILE: src/vanna/integrations/pinecone/agent_memory.py
================================================
"""
Pinecone vector database implementation of AgentMemory.
This implementation uses Pinecone for cloud-based vector storage of tool usage patterns.
"""
import json
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
import asyncio
from concurrent.futures import ThreadPoolExecutor
try:
from pinecone import Pinecone, ServerlessSpec
PINECONE_AVAILABLE = True
except ImportError:
PINECONE_AVAILABLE = False
from vanna.capabilities.agent_memory import (
AgentMemory,
TextMemory,
TextMemorySearchResult,
ToolMemory,
ToolMemorySearchResult,
)
from vanna.core.tool import ToolContext
class PineconeAgentMemory(AgentMemory):
    """Pinecone-based implementation of AgentMemory.

    Tool-usage memories and text memories are upserted into one Pinecone
    index. Pinecone metadata only holds simple types, so ``args`` and
    ``metadata`` are stored as JSON strings (``args_json``/``metadata_json``).
    Text memories are tagged with ``is_text_memory: True``. Every record also
    gets a numeric ``timestamp_unix`` (seconds since epoch) alongside the ISO
    ``timestamp``, because Pinecone range operators only compare numbers.

    Pinecone SDK calls are synchronous and run on a private two-worker thread
    pool so the async methods do not block the event loop.
    """

    def __init__(
        self,
        api_key: str,
        index_name: str = "tool-memories",
        environment: str = "us-east-1",
        dimension: int = 384,
        metric: str = "cosine",
    ):
        """Store connection settings; the client and index are created lazily.

        Args:
            api_key: Pinecone API key.
            index_name: Name of the index holding all memories.
            environment: AWS region used for the serverless index spec.
            dimension: Length of the embedding vectors.
            metric: Similarity metric used when creating the index.

        Raises:
            ImportError: If the Pinecone SDK is not installed.
        """
        if not PINECONE_AVAILABLE:
            raise ImportError(
                "Pinecone is required for PineconeAgentMemory. Install with: pip install pinecone-client"
            )
        import threading

        self.api_key = api_key
        self.index_name = index_name
        self.environment = environment
        self.dimension = dimension
        self.metric = metric
        self._client = None
        self._index = None
        # Both pool workers may reach _get_index() concurrently on first use;
        # the lock prevents a duplicate create_index call.
        self._index_lock = threading.Lock()
        self._executor = ThreadPoolExecutor(max_workers=2)

    async def _run_blocking(self, fn):
        """Run the blocking callable ``fn`` on the thread pool and await its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, fn)

    def _get_client(self):
        """Get or create the Pinecone client."""
        if self._client is None:
            self._client = Pinecone(api_key=self.api_key)
        return self._client

    def _get_index(self):
        """Get the index handle, creating the serverless index on first use if missing."""
        if self._index is not None:
            return self._index
        with self._index_lock:
            if self._index is None:
                client = self._get_client()
                if self.index_name not in client.list_indexes().names():
                    client.create_index(
                        name=self.index_name,
                        dimension=self.dimension,
                        metric=self.metric,
                        spec=ServerlessSpec(cloud="aws", region=self.environment),
                    )
                self._index = client.Index(self.index_name)
        return self._index

    def _create_embedding(self, text: str) -> List[float]:
        """Create a simple embedding from text (placeholder - should use actual embedding model).

        NOTE(review): MD5-derived stand-in; components at index >= 128 are always 0.0.
        """
        # TODO: Replace with actual embedding model
        import hashlib

        hash_val = int(hashlib.md5(text.encode()).hexdigest(), 16)
        return [(hash_val >> i) % 100 / 100.0 for i in range(self.dimension)]

    def _delete_vector(self, memory_id: str) -> bool:
        """Delete one record by id.

        Returns False instead of raising if the delete fails for any reason.
        """
        index = self._get_index()
        try:
            index.delete(ids=[memory_id])
            return True
        except Exception:
            return False

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Upsert a tool usage pattern keyed by a fresh UUID."""

        def _save():
            index = self._get_index()
            memory_id = str(uuid.uuid4())
            now = datetime.now()
            embedding = self._create_embedding(question)
            # Pinecone metadata must be simple types
            memory_metadata = {
                "question": question,
                "tool_name": tool_name,
                "args_json": json.dumps(args),
                "timestamp": now.isoformat(),
                "timestamp_unix": now.timestamp(),
                "success": success,
                "metadata_json": json.dumps(metadata or {}),
            }
            index.upsert(vectors=[(memory_id, embedding, memory_metadata)])

        await self._run_blocking(_save)

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Return successful tool usages similar to ``question``, using Pinecone's match score."""

        def _search():
            index = self._get_index()
            embedding = self._create_embedding(question)
            filter_dict: Dict[str, Any] = {"success": True}
            if tool_name_filter:
                filter_dict["tool_name"] = tool_name_filter
            results = index.query(
                vector=embedding, top_k=limit, filter=filter_dict, include_metadata=True
            )
            search_results = []
            for i, match in enumerate(results.matches):
                similarity_score = match.score
                if similarity_score >= similarity_threshold:
                    meta = match.metadata
                    memory = ToolMemory(
                        memory_id=match.id,
                        question=meta["question"],
                        tool_name=meta["tool_name"],
                        args=json.loads(meta.get("args_json", "{}")),
                        timestamp=meta.get("timestamp"),
                        success=meta.get("success", True),
                        metadata=json.loads(meta.get("metadata_json", "{}")),
                    )
                    search_results.append(
                        ToolMemorySearchResult(
                            memory=memory, similarity_score=similarity_score, rank=i + 1
                        )
                    )
            return search_results

        return await self._run_blocking(_search)

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Get recently added memories (not supported: always returns an empty list).

        Pinecone has no native "list sorted by timestamp" operation. The index
        is still resolved, which creates it if it is missing, as before.
        """

        def _get_recent():
            self._get_index()
            return []

        return await self._run_blocking(_get_recent)

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a memory by its ID; returns False if the delete fails."""
        return await self._run_blocking(lambda: self._delete_vector(memory_id))

    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
        """Upsert a free-text memory and return it."""

        def _save():
            index = self._get_index()
            memory_id = str(uuid.uuid4())
            now = datetime.now()
            timestamp = now.isoformat()
            memory_metadata = {
                "content": content,
                "timestamp": timestamp,
                "timestamp_unix": now.timestamp(),
                "is_text_memory": True,
            }
            index.upsert(
                vectors=[(memory_id, self._create_embedding(content), memory_metadata)]
            )
            return TextMemory(memory_id=memory_id, content=content, timestamp=timestamp)

        return await self._run_blocking(_save)

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Return text memories similar to ``query``, using Pinecone's match score."""

        def _search():
            index = self._get_index()
            embedding = self._create_embedding(query)
            results = index.query(
                vector=embedding,
                top_k=limit,
                filter={"is_text_memory": True},
                include_metadata=True,
            )
            search_results = []
            for i, match in enumerate(results.matches):
                similarity_score = match.score
                if similarity_score >= similarity_threshold:
                    meta = match.metadata
                    memory = TextMemory(
                        memory_id=match.id,
                        content=meta.get("content", ""),
                        timestamp=meta.get("timestamp"),
                    )
                    search_results.append(
                        TextMemorySearchResult(
                            memory=memory, similarity_score=similarity_score, rank=i + 1
                        )
                    )
            return search_results

        return await self._run_blocking(_search)

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Get recently added text memories (not supported: always returns an empty list)."""

        def _get_recent():
            return []

        return await self._run_blocking(_get_recent)

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a text memory by its ID; returns False if the delete fails."""
        return await self._run_blocking(lambda: self._delete_vector(memory_id))

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Delete memories matching the optional filters.

        With no filters, every record is deleted (``delete_all``).
        ``before_date`` must be an ISO-8601 date/datetime string. Matching
        compares it with the numeric ``timestamp_unix`` field, because
        Pinecone's ``$lt`` operator only accepts numbers. Records written
        before this field existed will not match. NOTE(review): some Pinecone
        index types may not support delete-by-metadata-filter; confirm
        against the deployed index.

        Returns:
            Always 0; Pinecone does not report how many records were deleted.

        Raises:
            ValueError: If ``before_date`` is not a valid ISO-8601 string.
        """

        def _clear():
            index = self._get_index()
            filter_dict: Dict[str, Any] = {}
            if tool_name:
                filter_dict["tool_name"] = tool_name
            if before_date:
                cutoff = datetime.fromisoformat(before_date).timestamp()
                filter_dict["timestamp_unix"] = {"$lt": cutoff}
            if filter_dict:
                index.delete(filter=filter_dict)
            else:
                index.delete(delete_all=True)
            return 0

        return await self._run_blocking(_clear)
================================================
FILE: src/vanna/integrations/plotly/__init__.py
================================================
"""Plotly integration for chart generation."""
from .chart_generator import PlotlyChartGenerator
__all__ = ["PlotlyChartGenerator"]
================================================
FILE: src/vanna/integrations/plotly/chart_generator.py
================================================
"""Plotly-based chart generator with automatic chart type selection."""
from typing import Dict, Any, List, cast
import json
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
class PlotlyChartGenerator:
    """Generate Plotly charts using heuristics based on DataFrame characteristics."""

    # Vanna brand colors from landing page
    THEME_COLORS = {
        "navy": "#023d60",
        "cream": "#e7e1cf",
        "teal": "#15a8a8",
        "orange": "#fe5d26",
        "magenta": "#bf1363",
    }
    # Color palette for charts (excluding cream as it's too light for data)
    COLOR_PALETTE = ["#15a8a8", "#fe5d26", "#bf1363", "#023d60"]

    def generate_chart(self, df: pd.DataFrame, title: str = "Chart") -> Dict[str, Any]:
        """Generate a Plotly chart based on DataFrame shape and types.

        Heuristics, checked in this order:
        - 4+ columns: table
        - a datetime column (naive or tz-aware) plus numeric columns: line chart
          with up to 5 series, sorted by time
        - 1 numeric column and no categorical columns: histogram
        - 1 categorical and 1 numeric column: bar chart (values summed per category)
        - 2 numeric columns: scatter plot
        - 3+ numeric columns: correlation heatmap
        - 2+ categorical columns: grouped bar chart of counts
        - any other 2+ column frame: generic chart of the first two columns
        - a single categorical column: bar chart of value counts

        Args:
            df: DataFrame to visualize
            title: Title for the chart

        Returns:
            Plotly figure as dictionary

        Raises:
            ValueError: If DataFrame is empty or cannot be visualized
        """
        if df.empty:
            raise ValueError("Cannot visualize empty DataFrame")
        # Heuristic: If 4 or more columns, render as a table
        if len(df.columns) >= 4:
            fig = self._create_table(df, title)
            result: Dict[str, Any] = json.loads(pio.to_json(fig))
            return result
        # Identify column types
        numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
        categorical_cols = df.select_dtypes(
            include=["object", "category"]
        ).columns.tolist()
        # "datetimetz" picks up timezone-aware columns, which "datetime64" alone misses.
        datetime_cols = df.select_dtypes(
            include=["datetime64", "datetimetz"]
        ).columns.tolist()
        is_timeseries = len(datetime_cols) > 0
        # Apply heuristics
        if is_timeseries and len(numeric_cols) > 0:
            fig = self._create_time_series_chart(
                df, datetime_cols[0], numeric_cols, title
            )
        elif len(numeric_cols) == 1 and len(categorical_cols) == 0:
            fig = self._create_histogram(df, numeric_cols[0], title)
        elif len(numeric_cols) == 1 and len(categorical_cols) == 1:
            fig = self._create_bar_chart(
                df, categorical_cols[0], numeric_cols[0], title
            )
        elif len(numeric_cols) == 2:
            fig = self._create_scatter_plot(df, numeric_cols[0], numeric_cols[1], title)
        elif len(numeric_cols) >= 3:
            fig = self._create_correlation_heatmap(df, numeric_cols, title)
        elif len(categorical_cols) >= 2:
            fig = self._create_grouped_bar_chart(df, categorical_cols, title)
        elif len(df.columns) >= 2:
            # Fallback: show first two columns as scatter/bar
            fig = self._create_generic_chart(df, df.columns[0], df.columns[1], title)
        elif len(categorical_cols) == 1:
            # A lone categorical column: the grouped-bar helper renders value counts.
            fig = self._create_grouped_bar_chart(df, categorical_cols, title)
        else:
            raise ValueError(
                "Cannot determine appropriate visualization for this DataFrame"
            )
        # Convert to JSON-serializable dict using plotly's JSON encoder
        result = json.loads(pio.to_json(fig))
        return result

    def _apply_standard_layout(self, fig: go.Figure) -> go.Figure:
        """Apply consistent Vanna brand styling (text color, colorway, autosize) in place.

        Args:
            fig: Plotly figure to update

        Returns:
            The same figure, for chaining.
        """
        fig.update_layout(
            # paper_bgcolor='white',
            # plot_bgcolor='white',
            font={"color": self.THEME_COLORS["navy"]},  # Navy for text
            autosize=True,  # Allow chart to resize responsively
            colorway=self.COLOR_PALETTE,  # Use Vanna brand colors for data
            # Don't set width/height - let frontend handle sizing
        )
        return fig

    def _create_histogram(self, df: pd.DataFrame, column: str, title: str) -> go.Figure:
        """Create a histogram for a single numeric column."""
        fig = px.histogram(
            df,
            x=column,
            title=title,
            color_discrete_sequence=[self.THEME_COLORS["teal"]],
        )
        fig.update_layout(xaxis_title=column, yaxis_title="Count", showlegend=False)
        self._apply_standard_layout(fig)
        return fig

    def _create_bar_chart(
        self, df: pd.DataFrame, x_col: str, y_col: str, title: str
    ) -> go.Figure:
        """Create a bar chart of ``y_col`` summed per ``x_col`` category."""
        agg_df = df.groupby(x_col)[y_col].sum().reset_index()
        fig = px.bar(
            agg_df,
            x=x_col,
            y=y_col,
            title=title,
            color_discrete_sequence=[self.THEME_COLORS["orange"]],
        )
        fig.update_layout(xaxis_title=x_col, yaxis_title=y_col)
        self._apply_standard_layout(fig)
        return fig

    def _create_scatter_plot(
        self, df: pd.DataFrame, x_col: str, y_col: str, title: str
    ) -> go.Figure:
        """Create a scatter plot for two numeric columns."""
        fig = px.scatter(
            df,
            x=x_col,
            y=y_col,
            title=title,
            color_discrete_sequence=[self.THEME_COLORS["magenta"]],
        )
        fig.update_layout(xaxis_title=x_col, yaxis_title=y_col)
        self._apply_standard_layout(fig)
        return fig

    def _create_correlation_heatmap(
        self, df: pd.DataFrame, columns: List[str], title: str
    ) -> go.Figure:
        """Create a correlation heatmap (fixed -1..1 color range) for numeric columns."""
        corr_matrix = df[columns].corr()
        # Custom Vanna color scale: navy (negative) -> cream (neutral) -> teal (positive)
        vanna_colorscale = [
            [0.0, self.THEME_COLORS["navy"]],
            [0.5, self.THEME_COLORS["cream"]],
            [1.0, self.THEME_COLORS["teal"]],
        ]
        fig = cast(
            go.Figure,
            px.imshow(
                corr_matrix,
                title=title,
                labels=dict(color="Correlation"),
                x=columns,
                y=columns,
                color_continuous_scale=vanna_colorscale,
                zmin=-1,
                zmax=1,
            ),
        )
        self._apply_standard_layout(fig)
        return fig

    def _create_time_series_chart(
        self, df: pd.DataFrame, time_col: str, value_cols: List[str], title: str
    ) -> go.Figure:
        """Create a multi-line time series chart (at most 5 series).

        Rows are sorted by ``time_col`` first. Line traces connect points in
        row order, so unsorted query results would otherwise zig-zag.
        """
        ordered = df.sort_values(time_col)
        fig = go.Figure()
        for i, col in enumerate(value_cols[:5]):  # Limit to 5 lines for readability
            color = self.COLOR_PALETTE[i % len(self.COLOR_PALETTE)]
            fig.add_trace(
                go.Scatter(
                    x=ordered[time_col],
                    y=ordered[col],
                    mode="lines",
                    name=col,
                    line=dict(color=color),
                )
            )
        fig.update_layout(
            title=title,
            xaxis_title=time_col,
            yaxis_title="Value",
            hovermode="x unified",
        )
        self._apply_standard_layout(fig)
        return fig

    def _create_grouped_bar_chart(
        self, df: pd.DataFrame, categorical_cols: List[str], title: str
    ) -> go.Figure:
        """Create a bar chart of row counts for categorical columns.

        With 2+ columns, counts rows per (first, second) category pair as
        grouped bars. With a single column, plots its value counts.
        """
        if len(categorical_cols) >= 2:
            grouped = df.groupby(categorical_cols[:2]).size().reset_index(name="count")
            fig = px.bar(
                grouped,
                x=categorical_cols[0],
                y="count",
                color=categorical_cols[1],
                title=title,
                barmode="group",
                color_discrete_sequence=self.COLOR_PALETTE,
            )
            self._apply_standard_layout(fig)
            return fig
        counts = df[categorical_cols[0]].value_counts().reset_index()
        # Set names explicitly: reset_index() column naming differs across pandas versions.
        counts.columns = [categorical_cols[0], "count"]
        fig = px.bar(
            counts,
            x=categorical_cols[0],
            y="count",
            title=title,
            color_discrete_sequence=[self.THEME_COLORS["teal"]],
        )
        self._apply_standard_layout(fig)
        return fig

    def _create_generic_chart(
        self, df: pd.DataFrame, col1: str, col2: str, title: str
    ) -> go.Figure:
        """Scatter if both columns are numeric, otherwise a bar chart of col2 by col1."""
        if pd.api.types.is_numeric_dtype(df[col1]) and pd.api.types.is_numeric_dtype(
            df[col2]
        ):
            return self._create_scatter_plot(df, col1, col2, title)
        fig = px.bar(
            df,
            x=col1,
            y=col2,
            title=title,
            color_discrete_sequence=[self.THEME_COLORS["orange"]],
        )
        self._apply_standard_layout(fig)
        return fig

    def _create_table(self, df: pd.DataFrame, title: str) -> go.Figure:
        """Create a Plotly table (navy header, alternating cream/white rows)."""
        header_values = list(df.columns)
        # go.Table takes cell values column-by-column.
        cell_values = [df[col].tolist() for col in df.columns]
        fig = go.Figure(
            data=[
                go.Table(
                    header=dict(
                        values=header_values,
                        fill_color=self.THEME_COLORS["navy"],
                        font=dict(color="white", size=12),
                        align="left",
                    ),
                    cells=dict(
                        values=cell_values,
                        fill_color=[
                            [
                                self.THEME_COLORS["cream"] if i % 2 == 0 else "white"
                                for i in range(len(df))
                            ]
                        ],
                        font=dict(color=self.THEME_COLORS["navy"], size=11),
                        align="left",
                    ),
                )
            ]
        )
        fig.update_layout(title=title, font={"color": self.THEME_COLORS["navy"]})
        return fig
================================================
FILE: src/vanna/integrations/postgres/__init__.py
================================================
"""
PostgreSQL integration.
This module provides PostgreSQL runner implementation.
"""
from .sql_runner import PostgresRunner
__all__ = ["PostgresRunner"]
================================================
FILE: src/vanna/integrations/postgres/sql_runner.py
================================================
"""PostgreSQL implementation of SqlRunner interface."""
from typing import Optional
import pandas as pd
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
from vanna.core.tool import ToolContext
class PostgresRunner(SqlRunner):
    """PostgreSQL implementation of the SqlRunner interface.

    Each call to ``run_sql`` opens a new psycopg2 connection and closes it afterwards.
    """

    def __init__(
        self,
        connection_string: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[int] = 5432,
        database: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        **kwargs,
    ):
        """Initialize with PostgreSQL connection parameters.

        You can either provide a connection_string OR individual parameters (host, database, etc.).
        If connection_string is provided, it takes precedence.

        Args:
            connection_string: PostgreSQL connection string (e.g., "postgresql://user:password@host:port/database")
            host: Database host address
            port: Database port (default: 5432)
            database: Database name
            user: Database user
            password: Database password
            **kwargs: Additional psycopg2 connection parameters (sslmode, connect_timeout, etc.)

        Raises:
            ImportError: If psycopg2 is not installed.
            ValueError: If neither a connection string nor host/database/user is given.
        """
        try:
            import psycopg2
            import psycopg2.extras

            self.psycopg2 = psycopg2
        except ImportError as e:
            raise ImportError(
                "psycopg2 package is required. Install with: pip install 'vanna[postgres]'"
            ) from e
        if connection_string:
            self.connection_string = connection_string
            self.connection_params = None
        elif host and database and user:
            self.connection_string = None
            self.connection_params = {
                "host": host,
                "port": port,
                "database": database,
                "user": user,
                "password": password,
                **kwargs,
            }
        else:
            raise ValueError(
                "Either provide connection_string OR (host, database, and user) parameters"
            )

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute a SQL statement against PostgreSQL and return the result as a DataFrame.

        Whether the statement produced rows is decided from
        ``cursor.description`` rather than the first SQL keyword. This handles
        ``WITH ... SELECT``, leading comments, ``EXPLAIN``/``SHOW``/``VALUES``
        and ``INSERT ... RETURNING`` correctly. The transaction is committed
        after any successful statement, so writes persist even when they also
        return rows.

        Args:
            args: SQL query arguments
            context: Tool execution context

        Returns:
            DataFrame of result rows (with column names even when empty), or a
            one-row ``rows_affected`` DataFrame for statements without a result set.

        Raises:
            psycopg2.Error: If query execution fails
        """
        if self.connection_string:
            conn = self.psycopg2.connect(self.connection_string)
        else:
            conn = self.psycopg2.connect(**self.connection_params)
        try:
            # Created inside the outer try so the connection is closed even if
            # opening the cursor fails.
            cursor = conn.cursor(cursor_factory=self.psycopg2.extras.RealDictCursor)
            try:
                cursor.execute(args.sql)
                if cursor.description is None:
                    # No result set (INSERT/UPDATE/DELETE/DDL without RETURNING).
                    conn.commit()
                    return pd.DataFrame({"rows_affected": [cursor.rowcount]})
                columns = [desc[0] for desc in cursor.description]
                rows = cursor.fetchall()
                # Data-modifying statements can return rows too; commit so
                # they persist. For plain reads this just ends the transaction.
                conn.commit()
                if not rows:
                    return pd.DataFrame(columns=columns)
                return pd.DataFrame([dict(row) for row in rows])
            finally:
                cursor.close()
        finally:
            # Uncommitted work (e.g. after an error) is discarded on close.
            conn.close()
================================================
FILE: src/vanna/integrations/premium/agent_memory/__init__.py
================================================
"""
Cloud-based agent memory implementations.
"""
from .premium import CloudAgentMemory
__all__ = ["CloudAgentMemory"]
================================================
FILE: src/vanna/integrations/premium/agent_memory/premium.py
================================================
"""
Cloud-based implementation of AgentMemory.
This implementation uses Vanna's premium cloud service for storing and searching
tool usage patterns with advanced similarity search and analytics.
"""
import json
from datetime import datetime
from typing import Any, Dict, List, Optional
import httpx
from vanna.capabilities.agent_memory import (
AgentMemory,
TextMemory,
TextMemorySearchResult,
ToolMemory,
ToolMemorySearchResult,
)
from vanna.core.tool import ToolContext
class CloudAgentMemory(AgentMemory):
    """Cloud-based implementation of AgentMemory.

    Tool-usage memories are stored, searched and deleted through Vanna's
    premium HTTP API. Text memories are not supported yet:
    ``save_text_memory`` raises ``NotImplementedError`` and the other
    text-memory methods return empty results.

    Args:
        api_base_url: Base URL of the memory service (a trailing ``/`` is
            stripped).
        api_key: Optional bearer token sent in the ``Authorization`` header.
        organization_id: Optional value sent in the ``X-Organization-ID``
            header.
    """

    def __init__(
        self,
        api_base_url: str = "https://api.vanna.ai",
        api_key: Optional[str] = None,
        organization_id: Optional[str] = None,
    ):
        self.api_base_url = api_base_url.rstrip("/")
        self.api_key = api_key
        self.organization_id = organization_id
        # One AsyncClient is shared by every request; aclose() releases it.
        self._client = httpx.AsyncClient(base_url=self.api_base_url, timeout=30.0)

    async def aclose(self) -> None:
        """Close the underlying HTTP client and release its pooled connections."""
        await self._client.aclose()

    def _get_headers(self) -> Dict[str, str]:
        """Build request headers, adding auth/organization headers only when configured."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        if self.organization_id:
            headers["X-Organization-ID"] = self.organization_id
        return headers

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save a tool usage pattern to premium cloud storage.

        A fresh UUID and a local ``datetime.now()`` ISO timestamp are generated
        client-side and POSTed to ``/memory/tool-usage``.

        Raises:
            httpx.HTTPStatusError: If the service responds with an error status.
        """
        import uuid

        payload = {
            "id": str(uuid.uuid4()),
            "question": question,
            "tool_name": tool_name,
            "args": args,
            "success": success,
            "metadata": metadata or {},
            "timestamp": datetime.now().isoformat(),
        }
        response = await self._client.post(
            "/memory/tool-usage", json=payload, headers=self._get_headers()
        )
        response.raise_for_status()

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Search for similar tool usage patterns in premium cloud storage.

        Returns:
            The service's results, in the order it returned them, with score
            and rank taken verbatim from the response.

        Raises:
            httpx.HTTPStatusError: If the service responds with an error status.
        """
        params: Dict[str, Any] = {
            "question": question,
            "limit": limit,
            "similarity_threshold": similarity_threshold,
        }
        if tool_name_filter:
            params["tool_name_filter"] = tool_name_filter
        response = await self._client.get(
            "/memory/search-similar", params=params, headers=self._get_headers()
        )
        response.raise_for_status()
        data = response.json()
        return [
            ToolMemorySearchResult(
                memory=ToolMemory(**item["memory"]),
                similarity_score=item["similarity_score"],
                rank=item["rank"],
            )
            for item in data.get("results", [])
        ]

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Get recently added memories from premium cloud storage.

        Raises:
            httpx.HTTPStatusError: If the service responds with an error status.
        """
        response = await self._client.get(
            "/memory/recent", params={"limit": limit}, headers=self._get_headers()
        )
        response.raise_for_status()
        data = response.json()
        return [ToolMemory(**item) for item in data.get("memories", [])]

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a memory by its ID from premium cloud storage.

        Returns:
            False if the service answers 404, True on any other success status.

        Raises:
            httpx.HTTPStatusError: For error statuses other than 404.
        """
        from urllib.parse import quote

        # Percent-encode the id so characters such as "/" or "?" cannot
        # redirect the request to a different endpoint.
        response = await self._client.delete(
            f"/memory/{quote(memory_id, safe='')}", headers=self._get_headers()
        )
        if response.status_code == 404:
            return False
        response.raise_for_status()
        return True

    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
        """Cloud implementation does not yet support text memories."""
        raise NotImplementedError("CloudAgentMemory does not support text memories.")

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Cloud implementation does not yet support text memories."""
        return []

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Cloud implementation does not yet support text memories."""
        return []

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Cloud implementation does not yet support text memories."""
        return False

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Clear stored memories from premium cloud storage.

        Args:
            tool_name: If truthy, sent as a ``tool_name`` filter.
            before_date: If truthy, sent as a ``before_date`` filter.

        Returns:
            The ``deleted_count`` reported by the service (0 if absent).

        Raises:
            httpx.HTTPStatusError: If the service responds with an error status.
        """
        payload: Dict[str, Any] = {}
        if tool_name:
            payload["tool_name"] = tool_name
        if before_date:
            payload["before_date"] = before_date
        # httpx's delete() helper accepts no request body (passing json= raises
        # TypeError), so the generic request() method carries the JSON filter.
        response = await self._client.request(
            "DELETE", "/memory/clear", json=payload, headers=self._get_headers()
        )
        response.raise_for_status()
        data = response.json()
        return data.get("deleted_count", 0)
================================================
FILE: src/vanna/integrations/presto/__init__.py
================================================
"""Presto integration for Vanna."""
from .sql_runner import PrestoRunner
__all__ = ["PrestoRunner"]
================================================
FILE: src/vanna/integrations/presto/sql_runner.py
================================================
"""Presto implementation of SqlRunner interface."""
from typing import Optional
import pandas as pd
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
from vanna.core.tool import ToolContext
class PrestoRunner(SqlRunner):
    """Presto implementation of the SqlRunner interface."""

    def __init__(
        self,
        host: str,
        catalog: str = "hive",
        schema: str = "default",
        user: Optional[str] = None,
        password: Optional[str] = None,
        port: int = 443,
        combined_pem_path: Optional[str] = None,
        protocol: str = "https",
        requests_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        """Initialize with Presto connection parameters.
        Args:
            host: The host address of the Presto database
            catalog: The catalog to use in the Presto environment (default: 'hive')
            schema: The schema to use in the Presto environment (default: 'default')
            user: The username for authentication
            password: The password for authentication
            port: The port number for the Presto connection (default: 443)
            combined_pem_path: The path to the combined pem file for SSL connection
            protocol: The protocol to use for the connection (default: 'https')
            requests_kwargs: Additional keyword arguments for requests
            **kwargs: Additional pyhive connection parameters
        Raises:
            ImportError: If pyhive is not installed.
        """
        try:
            from pyhive import presto

            self.presto = presto
        except ImportError as e:
            raise ImportError(
                "pyhive package is required. Install with: pip install pyhive"
            ) from e
        self.host = host
        self.catalog = catalog
        self.schema = schema
        self.user = user
        self.password = password
        self.port = port
        self.protocol = protocol
        self.kwargs = kwargs
        # combined_pem_path only takes effect when no explicit requests_kwargs
        # are given; it becomes the requests "verify" (CA bundle) setting.
        if requests_kwargs is None and combined_pem_path is not None:
            self.requests_kwargs = {"verify": combined_pem_path}
        else:
            self.requests_kwargs = requests_kwargs

    @staticmethod
    def _strip_trailing_semicolons(sql: str) -> str:
        """Remove trailing whitespace and every trailing ';' (Presto rejects them)."""
        cleaned = sql.rstrip()
        while cleaned.endswith(";"):
            cleaned = cleaned[:-1].rstrip()
        return cleaned

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute SQL query against Presto database and return results as DataFrame.
        A new connection is opened per call and closed afterwards; the cursor
        is closed even when execution fails.
        Args:
            args: SQL query arguments
            context: Tool execution context
        Returns:
            DataFrame with query results. Column names come from
            ``cursor.description``; when the driver reports no description,
            pandas' default integer column labels are used.
        Raises:
            presto.Error: If query execution fails
        """
        conn = self.presto.Connection(
            host=self.host,
            username=self.user,
            password=self.password,
            catalog=self.catalog,
            schema=self.schema,
            port=self.port,
            protocol=self.protocol,
            requests_kwargs=self.requests_kwargs,
            **self.kwargs,
        )
        try:
            sql = self._strip_trailing_semicolons(args.sql)
            cursor = conn.cursor()
            try:
                cursor.execute(sql)
                results = cursor.fetchall()
                description = cursor.description
                columns = [desc[0] for desc in description] if description else None
                return pd.DataFrame(results, columns=columns)
            finally:
                cursor.close()
        finally:
            conn.close()
================================================
FILE: src/vanna/integrations/qdrant/__init__.py
================================================
"""
Qdrant integration for Vanna Agents.
"""
from .agent_memory import QdrantAgentMemory
__all__ = ["QdrantAgentMemory"]
================================================
FILE: src/vanna/integrations/qdrant/agent_memory.py
================================================
"""
Qdrant vector database implementation of AgentMemory.
This implementation uses Qdrant for vector storage of tool usage patterns.
"""
import json
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
import asyncio
from concurrent.futures import ThreadPoolExecutor
try:
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
VectorParams,
PointStruct,
Filter,
FieldCondition,
MatchValue,
)
QDRANT_AVAILABLE = True
except ImportError:
QDRANT_AVAILABLE = False
from vanna.capabilities.agent_memory import (
AgentMemory,
TextMemory,
TextMemorySearchResult,
ToolMemory,
ToolMemorySearchResult,
)
from vanna.core.tool import ToolContext
class QdrantAgentMemory(AgentMemory):
    """Qdrant-based implementation of AgentMemory.

    Tool memories and text memories share one collection. Text memories are
    marked with an ``is_text_memory`` payload flag. Tool memories store
    ``question``, ``tool_name``, ``args``, ``timestamp``, ``success`` and
    ``metadata`` in their payload.

    Blocking qdrant-client calls run on a two-worker thread pool so the async
    methods do not block the event loop.

    Args:
        collection_name: Collection to use; created on first access if missing.
        url: Qdrant server URL. When not given, a local client is used.
        path: Storage path for the local client (``":memory:"`` if omitted).
        api_key: API key for the server client (only used together with ``url``).
        dimension: Vector size of the collection and of generated embeddings.

    Raises:
        ImportError: If ``qdrant-client`` is not installed.
    """

    # Points requested per page when scrolling through the collection.
    _SCROLL_PAGE_SIZE = 256

    def __init__(
        self,
        collection_name: str = "tool_memories",
        url: Optional[str] = None,
        path: Optional[str] = None,
        api_key: Optional[str] = None,
        dimension: int = 384,
    ):
        if not QDRANT_AVAILABLE:
            raise ImportError(
                "Qdrant is required for QdrantAgentMemory. Install with: pip install qdrant-client"
            )
        self.collection_name = collection_name
        self.url = url
        self.path = path
        self.api_key = api_key
        self.dimension = dimension
        self._client = None
        self._executor = ThreadPoolExecutor(max_workers=2)

    # -- blocking helpers (call them from async code through _run) ----------

    def _get_client(self):
        """Return the cached client, creating it and the collection on first use."""
        if self._client is None:
            if self.url:
                self._client = QdrantClient(url=self.url, api_key=self.api_key)
            else:
                self._client = QdrantClient(path=self.path or ":memory:")
            # Create collection if it doesn't exist
            collections = self._client.get_collections().collections
            if not any(c.name == self.collection_name for c in collections):
                self._client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(
                        size=self.dimension, distance=Distance.COSINE
                    ),
                )
        return self._client

    def _create_embedding(self, text: str) -> List[float]:
        """Create a deterministic placeholder embedding from an MD5 hash of ``text``.

        This is not a semantic embedding. The MD5 digest is 128 bits wide, so
        every component at index >= 128 is 0.0.
        """
        import hashlib

        hash_val = int(hashlib.md5(text.encode()).hexdigest(), 16)
        return [(hash_val >> i) % 100 / 100.0 for i in range(self.dimension)]

    async def _run(self, func):
        """Run a blocking callable on the executor and await its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, func)

    def _scroll_all(self, scroll_filter=None) -> list:
        """Return every point matching ``scroll_filter`` (payloads, no vectors).

        Follows scroll pagination until no next page is reported, so the
        whole matching part of the collection is read.
        """
        client = self._get_client()
        points: list = []
        offset = None
        while True:
            page, offset = client.scroll(
                collection_name=self.collection_name,
                scroll_filter=scroll_filter,
                limit=self._SCROLL_PAGE_SIZE,
                offset=offset,
                with_payload=True,
                with_vectors=False,
            )
            points.extend(page)
            if offset is None:
                return points

    def _query(self, client, embedding, limit, query_filter, score_threshold):
        """Run a filtered nearest-neighbour search across qdrant-client versions."""
        # query_points exists in newer qdrant-client (1.8.0+); older versions
        # only provide search().
        if hasattr(client, "query_points"):
            return client.query_points(
                collection_name=self.collection_name,
                query=embedding,
                limit=limit,
                query_filter=query_filter,
                score_threshold=score_threshold,
            ).points
        return client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            limit=limit,
            query_filter=query_filter,
            score_threshold=score_threshold,
        )

    def _delete_point(self, memory_id: str) -> bool:
        """Delete one point by id; True if it existed, False if missing or on error."""
        client = self._get_client()
        try:
            # Check existence first so a missing id can be reported as False.
            found = client.retrieve(
                collection_name=self.collection_name,
                ids=[memory_id],
                with_payload=False,
                with_vectors=False,
            )
            if not found:
                return False
            client.delete(
                collection_name=self.collection_name,
                points_selector=[memory_id],
            )
            return True
        except Exception:
            # Deliberately best-effort: malformed ids and client errors are
            # reported as "not deleted" rather than raised.
            return False

    @staticmethod
    def _timestamp_key(point) -> str:
        """Return the point's payload timestamp, or "" when it is missing."""
        return (point.payload or {}).get("timestamp") or ""

    @staticmethod
    def _to_tool_memory(point_id, payload) -> ToolMemory:
        """Build a ToolMemory from a tool-memory payload."""
        return ToolMemory(
            memory_id=str(point_id),
            question=payload["question"],
            tool_name=payload["tool_name"],
            args=payload["args"],
            timestamp=payload.get("timestamp"),
            success=payload.get("success", True),
            metadata=payload.get("metadata", {}),
        )

    @staticmethod
    def _to_text_memory(point_id, payload) -> TextMemory:
        """Build a TextMemory from a text-memory payload."""
        payload = payload or {}
        return TextMemory(
            memory_id=str(point_id),
            content=payload.get("content", ""),
            timestamp=payload.get("timestamp"),
        )

    @staticmethod
    def _text_memory_filter():
        """Filter that selects only points flagged ``is_text_memory``."""
        return Filter(
            must=[FieldCondition(key="is_text_memory", match=MatchValue(value=True))]
        )

    # -- tool memories --------------------------------------------------------

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save a tool usage pattern as a new point with a random UUID id."""

        def _save():
            client = self._get_client()
            payload = {
                "question": question,
                "tool_name": tool_name,
                "args": args,
                "timestamp": datetime.now().isoformat(),
                "success": success,
                "metadata": metadata or {},
            }
            point = PointStruct(
                id=str(uuid.uuid4()),
                vector=self._create_embedding(question),
                payload=payload,
            )
            client.upsert(collection_name=self.collection_name, points=[point])

        await self._run(_save)

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Search successful tool memories similar to ``question``.

        Only points with ``success == True`` are considered; text memories
        have no ``success`` field and so never match. Ranks start at 1 in the
        order Qdrant returns hits.
        """

        def _search():
            client = self._get_client()
            conditions = [FieldCondition(key="success", match=MatchValue(value=True))]
            if tool_name_filter:
                conditions.append(
                    FieldCondition(
                        key="tool_name", match=MatchValue(value=tool_name_filter)
                    )
                )
            hits = self._query(
                client,
                self._create_embedding(question),
                limit,
                Filter(must=conditions),
                similarity_threshold,
            )
            return [
                ToolMemorySearchResult(
                    memory=self._to_tool_memory(hit.id, hit.payload),
                    similarity_score=hit.score,
                    rank=rank,
                )
                for rank, hit in enumerate(hits, start=1)
            ]

        return await self._run(_search)

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Return up to ``limit`` tool memories, newest first.

        Text memories are excluded before slicing, so they no longer use up
        slots of ``limit``. The whole matching collection is scrolled and
        sorted client-side by the ISO ``timestamp`` string.
        """

        def _get_recent():
            # Points without the is_text_memory key satisfy must_not, so every
            # tool memory is kept.
            tool_points = self._scroll_all(
                Filter(
                    must_not=[
                        FieldCondition(
                            key="is_text_memory", match=MatchValue(value=True)
                        )
                    ]
                )
            )
            tool_points.sort(key=self._timestamp_key, reverse=True)
            return [
                self._to_tool_memory(point.id, point.payload)
                for point in tool_points[:limit]
            ]

        return await self._run(_get_recent)

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a memory by its ID. Returns True if deleted, False if not found."""
        return await self._run(lambda: self._delete_point(memory_id))

    # -- text memories --------------------------------------------------------

    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
        """Save a text memory (payload flagged ``is_text_memory``) and return it."""

        def _save():
            client = self._get_client()
            memory_id = str(uuid.uuid4())
            timestamp = datetime.now().isoformat()
            payload = {
                "content": content,
                "timestamp": timestamp,
                "is_text_memory": True,
            }
            point = PointStruct(
                id=memory_id, vector=self._create_embedding(content), payload=payload
            )
            client.upsert(collection_name=self.collection_name, points=[point])
            return TextMemory(memory_id=memory_id, content=content, timestamp=timestamp)

        return await self._run(_save)

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Search text memories similar to ``query``; ranks start at 1."""

        def _search():
            client = self._get_client()
            hits = self._query(
                client,
                self._create_embedding(query),
                limit,
                self._text_memory_filter(),
                similarity_threshold,
            )
            return [
                TextMemorySearchResult(
                    memory=self._to_text_memory(hit.id, hit.payload),
                    similarity_score=hit.score,
                    rank=rank,
                )
                for rank, hit in enumerate(hits, start=1)
            ]

        return await self._run(_search)

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Return up to ``limit`` text memories, newest first (full scroll + sort)."""

        def _get_recent():
            points = self._scroll_all(self._text_memory_filter())
            points.sort(key=self._timestamp_key, reverse=True)
            return [self._to_text_memory(point.id, point.payload) for point in points[:limit]]

        return await self._run(_get_recent)

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a text memory by its ID. Returns True if deleted, False if not found."""
        return await self._run(lambda: self._delete_point(memory_id))

    # -- bulk deletion ----------------------------------------------------------

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Clear stored memories and return how many points were deleted.

        Args:
            tool_name: If truthy, only points with this ``tool_name`` are deleted.
            before_date: If truthy, only points whose ``timestamp`` is earlier
                than this value are deleted. The comparison is on the
                ISO-8601 strings, which matches the ``datetime.isoformat()``
                timestamps this class writes. Points without a timestamp are kept.

        When both arguments are ``None`` the collection is dropped and
        recreated empty, removing text memories too.
        """

        def _clear() -> int:
            client = self._get_client()
            if tool_name is None and before_date is None:
                deleted = client.count(
                    collection_name=self.collection_name, exact=True
                ).count
                client.delete_collection(collection_name=self.collection_name)
                client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(
                        size=self.dimension, distance=Distance.COSINE
                    ),
                )
                return deleted

            conditions = []
            if tool_name:
                conditions.append(
                    FieldCondition(key="tool_name", match=MatchValue(value=tool_name))
                )
            if not conditions and not before_date:
                # Only empty-string filters were given: nothing is selected.
                return 0

            candidates = self._scroll_all(Filter(must=conditions) if conditions else None)
            doomed = []
            for point in candidates:
                timestamp = self._timestamp_key(point)
                if before_date and (not timestamp or timestamp >= before_date):
                    continue
                doomed.append(point.id)
            if doomed:
                client.delete(
                    collection_name=self.collection_name,
                    points_selector=doomed,
                )
            return len(doomed)

        return await self._run(_clear)
================================================
FILE: src/vanna/integrations/snowflake/__init__.py
================================================
"""Snowflake integration for Vanna."""
from .sql_runner import SnowflakeRunner
__all__ = ["SnowflakeRunner"]
================================================
FILE: src/vanna/integrations/snowflake/sql_runner.py
================================================
"""Snowflake implementation of SqlRunner interface."""
from typing import Optional, Union
import os
import pandas as pd
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
from vanna.core.tool import ToolContext
class SnowflakeRunner(SqlRunner):
    """Snowflake implementation of the SqlRunner interface."""

    def __init__(
        self,
        account: str,
        username: str,
        password: Optional[str] = None,
        database: str = "",
        role: Optional[str] = None,
        warehouse: Optional[str] = None,
        private_key_path: Optional[str] = None,
        private_key_passphrase: Optional[str] = None,
        private_key_content: Optional[bytes] = None,
        **kwargs,
    ):
        """Initialize with Snowflake connection parameters.
        Args:
            account: Snowflake account identifier
            username: Database user
            password: Database password (optional if using key-pair auth)
            database: Database name
            role: Snowflake role to use (optional)
            warehouse: Snowflake warehouse to use (optional)
            private_key_path: Path to private key file for RSA key-pair
                authentication (optional). Passed to the connector as
                ``private_key_file``.
            private_key_passphrase: Passphrase for an encrypted private key
                (optional). Passed to the connector as ``private_key_file_pwd``.
            private_key_content: Private key content as bytes (optional,
                alternative to private_key_path). Passed to the connector as
                ``private_key``; see the connector documentation for the
                accepted encodings.
            **kwargs: Additional snowflake.connector connection parameters.
                These override the values built from the arguments above.
        Note:
            Either password OR private_key_path/private_key_content must be provided.
            RSA key-pair authentication is recommended for production systems as Snowflake
            is deprecating user/password authentication.
        Raises:
            ImportError: If snowflake-connector-python is not installed.
            ValueError: If no authentication method is provided.
            FileNotFoundError: If private_key_path does not point to a file.
        """
        try:
            import snowflake.connector

            self.snowflake = snowflake.connector
        except ImportError as e:
            raise ImportError(
                "snowflake-connector-python package is required. "
                "Install with: pip install 'vanna[snowflake]'"
            ) from e
        # Validate that at least one authentication method is provided
        if not password and not private_key_path and not private_key_content:
            raise ValueError(
                "Either password or private_key_path/private_key_content must be provided for authentication"
            )
        # Validate private key path exists if provided
        if private_key_path and not os.path.isfile(private_key_path):
            raise FileNotFoundError(f"Private key file not found: {private_key_path}")
        self.account = account
        self.username = username
        self.password = password
        self.database = database
        self.role = role
        self.warehouse = warehouse
        self.private_key_path = private_key_path
        self.private_key_passphrase = private_key_passphrase
        self.private_key_content = private_key_content
        self.kwargs = kwargs

    def _connection_params(self) -> dict:
        """Build the keyword arguments for ``snowflake.connector.connect``."""
        params = {
            "user": self.username,
            "account": self.account,
            "client_session_keep_alive": True,
        }
        if self.database:
            params["database"] = self.database
        if self.private_key_path or self.private_key_content:
            # Key-pair authentication, using the connector's parameter names
            # (private_key_file / private_key / private_key_file_pwd).
            if self.private_key_path:
                params["private_key_file"] = self.private_key_path
            else:
                params["private_key"] = self.private_key_content
            if self.private_key_passphrase:
                params["private_key_file_pwd"] = self.private_key_passphrase
        else:
            # Use password authentication (fallback)
            params["password"] = self.password
        params.update(self.kwargs)
        return params

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute SQL query against Snowflake database and return results as DataFrame.
        A new connection is opened per call. Role, warehouse and database are
        activated with USE statements before the query runs.
        Args:
            args: SQL query arguments
            context: Tool execution context
        Returns:
            DataFrame with query results. If the statement produces no result
            set, a one-row DataFrame with a ``rows_affected`` column is returned
            instead.
        Raises:
            snowflake.connector.Error: If query execution fails
        """
        conn = self.snowflake.connect(**self._connection_params())
        try:
            cursor = conn.cursor()
            try:
                # role/warehouse/database come from the runner's own
                # configuration and are interpolated verbatim as identifiers.
                if self.role:
                    cursor.execute(f"USE ROLE {self.role}")
                if self.warehouse:
                    cursor.execute(f"USE WAREHOUSE {self.warehouse}")
                if self.database:
                    cursor.execute(f"USE DATABASE {self.database}")
                cursor.execute(args.sql)
                if cursor.description is None:
                    return pd.DataFrame({"rows_affected": [cursor.rowcount]})
                results = cursor.fetchall()
                columns = [desc[0] for desc in cursor.description]
                return pd.DataFrame(results, columns=columns)
            finally:
                cursor.close()
        finally:
            conn.close()
================================================
FILE: src/vanna/integrations/sqlite/__init__.py
================================================
"""
SQLite integration.
This module provides SQLite runner implementation.
"""
from .sql_runner import SqliteRunner
__all__ = ["SqliteRunner"]
================================================
FILE: src/vanna/integrations/sqlite/sql_runner.py
================================================
"""SQLite implementation of SqlRunner interface."""
import sqlite3
import pandas as pd
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
from vanna.core.tool import ToolContext
class SqliteRunner(SqlRunner):
    """SQLite implementation of the SqlRunner interface."""

    def __init__(self, database_path: str):
        """Initialize with a SQLite database path.
        Args:
            database_path: Path to the SQLite database file
        """
        self.database_path = database_path

    async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
        """Execute SQL query against SQLite database and return results as DataFrame.
        Whether the statement produced rows is decided from ``cursor.description``
        (DB-API) rather than from its first keyword. As a result ``WITH ...``,
        ``PRAGMA``, ``EXPLAIN`` and ``... RETURNING`` statements return their rows.
        Args:
            args: SQL query arguments
            context: Tool execution context
        Returns:
            For row-returning statements, a DataFrame with one column per
            result column. When no rows match, it is empty but keeps the column
            names. For other statements (INSERT, UPDATE, DELETE, DDL, ...),
            the change is committed and a one-row DataFrame with a
            ``rows_affected`` column holding ``cursor.rowcount`` is returned.
        Raises:
            sqlite3.Error: If query execution fails
        """
        conn = sqlite3.connect(self.database_path)
        cursor = conn.cursor()
        try:
            cursor.execute(args.sql)
            if cursor.description is None:
                # No result set: commit the modification and report rowcount.
                conn.commit()
                return pd.DataFrame({"rows_affected": [cursor.rowcount]})
            columns = [desc[0] for desc in cursor.description]
            # Plain tuples (not dicts) keep duplicate column names, e.g. two
            # joined "id" columns.
            rows = cursor.fetchall()
            if conn.in_transaction:
                # Row-returning writes (e.g. INSERT ... RETURNING) opened a
                # transaction; commit it so close() does not roll it back.
                conn.commit()
            return pd.DataFrame(rows, columns=columns)
        finally:
            cursor.close()
            conn.close()
================================================
FILE: src/vanna/integrations/weaviate/__init__.py
================================================
"""
Weaviate integration for Vanna Agents.
"""
from .agent_memory import WeaviateAgentMemory
__all__ = ["WeaviateAgentMemory"]
================================================
FILE: src/vanna/integrations/weaviate/agent_memory.py
================================================
"""
Weaviate vector database implementation of AgentMemory.
This implementation uses Weaviate for semantic search and storage of tool usage patterns.
"""
import json
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
import asyncio
from concurrent.futures import ThreadPoolExecutor
try:
import weaviate
from weaviate.classes.config import (
Configure,
Property,
DataType as WeaviateDataType,
)
WEAVIATE_AVAILABLE = True
except ImportError:
WEAVIATE_AVAILABLE = False
from vanna.capabilities.agent_memory import (
AgentMemory,
TextMemory,
TextMemorySearchResult,
ToolMemory,
ToolMemorySearchResult,
)
from vanna.core.tool import ToolContext
class WeaviateAgentMemory(AgentMemory):
"""Weaviate-based implementation of AgentMemory."""
def __init__(
    self,
    collection_name: str = "ToolMemory",
    url: str = "http://localhost:8080",
    api_key: Optional[str] = None,
    dimension: int = 384,
):
    """Record connection settings; the Weaviate client itself is created lazily.

    Args:
        collection_name: Weaviate collection that holds the memories.
        url: Weaviate URL (used as the cluster URL when ``api_key`` is set).
        api_key: Weaviate Cloud API key; without it a local connection is made.
        dimension: Length of the vectors produced by ``_create_embedding``.

    Raises:
        ImportError: If the ``weaviate-client`` package could not be imported.
    """
    if not WEAVIATE_AVAILABLE:
        missing = "Weaviate is required for WeaviateAgentMemory. Install with: pip install weaviate-client"
        raise ImportError(missing)
    # Connection settings.
    self.collection_name, self.url, self.api_key = collection_name, url, api_key
    self.dimension = dimension
    # Created on first use by _get_client().
    self._client = None
    # Blocking client calls are dispatched to this pool from the async methods.
    self._executor = ThreadPoolExecutor(max_workers=2)
def _get_client(self):
    """Return the cached Weaviate client, connecting on first use.

    With an API key, connects to Weaviate Cloud using ``self.url`` as the
    cluster URL. Otherwise it connects to a local instance. Host and port are
    parsed from ``self.url``, and the port defaults to 8080 when absent. The
    collection is created with the tool-memory schema if it does not exist.
    """
    if self._client is None:
        if self.api_key:
            self._client = weaviate.connect_to_weaviate_cloud(
                cluster_url=self.url,
                auth_credentials=weaviate.auth.AuthApiKey(self.api_key),
            )
        else:
            from urllib.parse import urlsplit

            # connect_to_local() takes host and port separately; passing
            # "localhost:8080" as the host would combine with the port argument
            # into an invalid address.
            target = self.url if "://" in self.url else f"http://{self.url}"
            parts = urlsplit(target)
            self._client = weaviate.connect_to_local(
                host=parts.hostname or "localhost",
                port=parts.port or 8080,
            )
        # Create collection if it doesn't exist
        if not self._client.collections.exists(self.collection_name):
            self._client.collections.create(
                name=self.collection_name,
                vectorizer_config=Configure.Vectorizer.none(),
                properties=[
                    Property(name="question", data_type=WeaviateDataType.TEXT),
                    Property(name="tool_name", data_type=WeaviateDataType.TEXT),
                    Property(name="args_json", data_type=WeaviateDataType.TEXT),
                    Property(name="timestamp", data_type=WeaviateDataType.TEXT),
                    Property(name="success", data_type=WeaviateDataType.BOOL),
                    Property(name="metadata_json", data_type=WeaviateDataType.TEXT),
                ],
            )
    return self._client
def _create_embedding(self, text: str) -> List[float]:
    """Return a deterministic placeholder vector derived from the MD5 of ``text``.

    Component i is ``((digest >> i) % 100) / 100``. This is not a semantic
    embedding. Because the MD5 digest is 128 bits, components at index
    >= 128 are all 0.0.
    """
    from hashlib import md5

    digest = int(md5(text.encode()).hexdigest(), 16)
    vector = []
    for shift in range(self.dimension):
        vector.append((digest >> shift) % 100 / 100.0)
    return vector
async def save_tool_usage(
    self,
    question: str,
    tool_name: str,
    args: Dict[str, Any],
    context: ToolContext,
    success: bool = True,
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Persist one tool invocation as a Weaviate object.

    ``args`` and ``metadata`` are stored as JSON strings, and the vector comes
    from ``_create_embedding(question)``. The object id is a random UUID and
    the timestamp is the local ``datetime.now()`` in ISO format.
    """

    def _insert() -> None:
        collection = self._get_client().collections.get(self.collection_name)
        object_id = str(uuid.uuid4())
        stored = {
            "question": question,
            "tool_name": tool_name,
            "args_json": json.dumps(args),
            "timestamp": datetime.now().isoformat(),
            "success": success,
            "metadata_json": json.dumps(metadata or {}),
        }
        collection.data.insert(
            properties=stored,
            vector=self._create_embedding(question),
            uuid=object_id,
        )

    loop = asyncio.get_running_loop()
    await loop.run_in_executor(self._executor, _insert)
async def search_similar_usage(
    self,
    question: str,
    context: ToolContext,
    *,
    limit: int = 10,
    similarity_threshold: float = 0.7,
    tool_name_filter: Optional[str] = None,
) -> List[ToolMemorySearchResult]:
    """Search successful tool memories similar to ``question``.

    Text memories share the collection with ``success=True`` and an empty
    ``tool_name``. They are excluded here, because their empty ``args_json``
    is not valid JSON. Similarity is ``1 - distance`` from Weaviate's
    reported distance, and hits below ``similarity_threshold`` are dropped.

    Returns:
        Up to ``limit`` results, ranked from 1 in the order Weaviate returns them.
    """

    def _search():
        client = self._get_client()
        collection = client.collections.get(self.collection_name)
        embedding = self._create_embedding(question)
        query_filter = weaviate.classes.query.Filter
        filters = query_filter.by_property("success").equal(True) & query_filter.by_property(
            "tool_name"
        ).not_equal("")
        if tool_name_filter:
            filters = filters & query_filter.by_property("tool_name").equal(
                tool_name_filter
            )
        response = collection.query.near_vector(
            near_vector=embedding,
            limit=limit,
            filters=filters,
            return_metadata=weaviate.classes.query.MetadataQuery(distance=True),
        )
        search_results = []
        for obj in response.objects:
            # Weaviate reports distance; treat a missing value as distance 1.0.
            distance = obj.metadata.distance if obj.metadata else None
            similarity_score = 1 - (1.0 if distance is None else distance)
            if similarity_score >= similarity_threshold:
                properties = obj.properties
                memory = ToolMemory(
                    memory_id=str(obj.uuid),
                    question=properties.get("question"),
                    tool_name=properties.get("tool_name"),
                    # Empty/missing JSON strings fall back to {} instead of raising.
                    args=json.loads(properties.get("args_json") or "{}"),
                    timestamp=properties.get("timestamp"),
                    success=properties.get("success", True),
                    metadata=json.loads(properties.get("metadata_json") or "{}"),
                )
                search_results.append(
                    ToolMemorySearchResult(
                        memory=memory,
                        similarity_score=similarity_score,
                        rank=len(search_results) + 1,
                    )
                )
        return search_results

    return await asyncio.get_running_loop().run_in_executor(self._executor, _search)
async def get_recent_memories(
    self, context: ToolContext, limit: int = 10
) -> List[ToolMemory]:
    """Return up to ``limit`` tool memories, newest first.

    Text memories (stored with an empty ``tool_name``) are filtered out on
    the server, so their empty ``args_json`` is never parsed. At most 1000
    objects are fetched and then sorted client-side by the ISO ``timestamp``
    string.
    """

    def _get_recent():
        client = self._get_client()
        collection = client.collections.get(self.collection_name)
        response = collection.query.fetch_objects(
            filters=weaviate.classes.query.Filter.by_property("tool_name").not_equal(""),
            limit=1000,
        )
        # NOTE(review): with more than 1000 tool memories the fetched subset is
        # not guaranteed to contain the newest ones.
        sorted_objects = sorted(
            response.objects,
            key=lambda o: o.properties.get("timestamp") or "",
            reverse=True,
        )
        memories = []
        for obj in sorted_objects[:limit]:
            properties = obj.properties
            memories.append(
                ToolMemory(
                    memory_id=str(obj.uuid),
                    question=properties.get("question"),
                    tool_name=properties.get("tool_name"),
                    # Empty/missing JSON strings fall back to {} instead of raising.
                    args=json.loads(properties.get("args_json") or "{}"),
                    timestamp=properties.get("timestamp"),
                    success=properties.get("success", True),
                    metadata=json.loads(properties.get("metadata_json") or "{}"),
                )
            )
        return memories

    return await asyncio.get_running_loop().run_in_executor(
        self._executor, _get_recent
    )
async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
    """Delete the object with ``memory_id`` from the collection.

    Returns:
        False if the client's ``delete_by_id`` raises, otherwise True.
        NOTE(review): the client's return value is ignored, so deleting an id
        that does not exist presumably still reports True — confirm.
    """

    def _remove() -> bool:
        # Client/collection lookup failures propagate; only the delete itself
        # is treated as best-effort.
        collection = self._get_client().collections.get(self.collection_name)
        try:
            collection.data.delete_by_id(uuid=memory_id)
        except Exception:
            return False
        return True

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(self._executor, _remove)
async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
    """Store free-form text as a memory object and return it.

    Text memories reuse the tool-memory schema. ``question`` holds the text,
    ``tool_name`` and ``args_json`` are empty strings, ``success`` is True,
    and ``metadata_json`` records ``{"is_text_memory": true}``.
    """

    def _insert_text() -> TextMemory:
        collection = self._get_client().collections.get(self.collection_name)
        new_id = str(uuid.uuid4())
        created_at = datetime.now().isoformat()
        vector = self._create_embedding(content)
        collection.data.insert(
            properties={
                "question": content,
                "tool_name": "",
                "args_json": "",
                "timestamp": created_at,
                "success": True,
                "metadata_json": json.dumps({"is_text_memory": True}),
            },
            vector=vector,
            uuid=new_id,
        )
        return TextMemory(memory_id=new_id, content=content, timestamp=created_at)

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(self._executor, _insert_text)
async def search_text_memories(
    self,
    query: str,
    context: ToolContext,
    *,
    limit: int = 10,
    similarity_threshold: float = 0.7,
) -> List[TextMemorySearchResult]:
    """Find text memories similar to ``query``.

    Text memories are the objects stored with an empty ``tool_name``, and
    their text is kept in ``question``. Similarity is ``1 - distance``, and
    hits below ``similarity_threshold`` are dropped. Ranks are the 1-based
    positions in Weaviate's response.
    """

    def _search_text():
        collection = self._get_client().collections.get(self.collection_name)
        vector = self._create_embedding(query)
        text_only = weaviate.classes.query.Filter.by_property("tool_name").equal("")
        response = collection.query.near_vector(
            near_vector=vector,
            limit=limit,
            filters=text_only,
            return_metadata=weaviate.classes.query.MetadataQuery(distance=True),
        )
        matches = []
        for position, obj in enumerate(response.objects, start=1):
            score = 1 - (obj.metadata.distance if obj.metadata else 1.0)
            if score >= similarity_threshold:
                props = obj.properties
                found = TextMemory(
                    memory_id=str(obj.uuid),
                    content=props.get("question", ""),
                    timestamp=props.get("timestamp"),
                )
                matches.append(
                    TextMemorySearchResult(
                        memory=found, similarity_score=score, rank=position
                    )
                )
        return matches

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(self._executor, _search_text)
async def get_recent_text_memories(
self, context: ToolContext, limit: int = 10
) -> List[TextMemory]:
"""Get recently added text memories."""
def _get_recent():
client = self._get_client()
collection = client.collections.get(self.collection_name)
# Query text memories (empty tool_name) and sort by timestamp
response = collection.query.fetch_objects(
filters=weaviate.classes.query.Filter.by_property("tool_name").equal(
""
),
limit=1000,
)
# Convert to list and sort
objects_list = list(response.objects)
sorted_objects = sorted(
objects_list,
key=lambda o: o.properties.get("timestamp", ""),
reverse=True,
)
memories = []
for obj in sorted_objects[:limit]:
properties = obj.properties
content = properties.get("question", "")
memory = TextMemory(
memory_id=str(obj.uuid),
content=content,
timestamp=properties.get("timestamp"),
)
memories.append(memory)
return memories
return await asyncio.get_event_loop().run_in_executor(
self._executor, _get_recent
)
async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
"""Delete a text memory by its ID."""
def _delete():
client = self._get_client()
collection = client.collections.get(self.collection_name)
try:
collection.data.delete_by_id(uuid=memory_id)
return True
except Exception:
return False
return await asyncio.get_event_loop().run_in_executor(self._executor, _delete)
async def clear_memories(
self,
context: ToolContext,
tool_name: Optional[str] = None,
before_date: Optional[str] = None,
) -> int:
"""Clear stored memories."""
def _clear():
client = self._get_client()
collection = client.collections.get(self.collection_name)
# Build filter
if tool_name and before_date:
filters = weaviate.classes.query.Filter.by_property("tool_name").equal(
tool_name
) & weaviate.classes.query.Filter.by_property("timestamp").less_than(
before_date
)
elif tool_name:
filters = weaviate.classes.query.Filter.by_property("tool_name").equal(
tool_name
)
elif before_date:
filters = weaviate.classes.query.Filter.by_property(
"timestamp"
).less_than(before_date)
else:
filters = None
if filters:
collection.data.delete_many(where=filters)
else:
# Delete all
collection.data.delete_many(
where=weaviate.classes.query.Filter.by_property(
"success"
).contains_any([True, False])
)
return 0
return await asyncio.get_event_loop().run_in_executor(self._executor, _clear)
================================================
FILE: src/vanna/legacy/ZhipuAI/ZhipuAI_Chat.py
================================================
import re
from typing import List
import pandas as pd
from zhipuai import ZhipuAI
from ..base import VannaBase
class ZhipuAI_Chat(VannaBase):
def __init__(self, config=None):
VannaBase.__init__(self, config=config)
if config is None:
return
if "api_key" not in config:
raise Exception("Missing api_key in config")
self.api_key = config["api_key"]
self.model = config["model"] if "model" in config else "glm-4"
self.api_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
# Static methods similar to those in ZhipuAI_Chat for message formatting and utility
@staticmethod
def system_message(message: str) -> dict:
return {"role": "system", "content": message}
@staticmethod
def user_message(message: str) -> dict:
return {"role": "user", "content": message}
@staticmethod
def assistant_message(message: str) -> dict:
return {"role": "assistant", "content": message}
@staticmethod
def str_to_approx_token_count(string: str) -> int:
return len(string) / 4
@staticmethod
def add_ddl_to_prompt(
initial_prompt: str, ddl_list: List[str], max_tokens: int = 14000
) -> str:
if len(ddl_list) > 0:
initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
for ddl in ddl_list:
if (
ZhipuAI_Chat.str_to_approx_token_count(initial_prompt)
+ ZhipuAI_Chat.str_to_approx_token_count(ddl)
< max_tokens
):
initial_prompt += f"{ddl}\n\n"
return initial_prompt
@staticmethod
def add_documentation_to_prompt(
initial_prompt: str, documentation_List: List[str], max_tokens: int = 14000
) -> str:
if len(documentation_List) > 0:
initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
for documentation in documentation_List:
if (
ZhipuAI_Chat.str_to_approx_token_count(initial_prompt)
+ ZhipuAI_Chat.str_to_approx_token_count(documentation)
< max_tokens
):
initial_prompt += f"{documentation}\n\n"
return initial_prompt
@staticmethod
def add_sql_to_prompt(
initial_prompt: str, sql_List: List[str], max_tokens: int = 14000
) -> str:
if len(sql_List) > 0:
initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
for question in sql_List:
if (
ZhipuAI_Chat.str_to_approx_token_count(initial_prompt)
+ ZhipuAI_Chat.str_to_approx_token_count(question["sql"])
< max_tokens
):
initial_prompt += f"{question['question']}\n{question['sql']}\n\n"
return initial_prompt
def get_sql_prompt(
self,
question: str,
question_sql_list: List,
ddl_list: List,
doc_list: List,
**kwargs,
):
initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n"
initial_prompt = ZhipuAI_Chat.add_ddl_to_prompt(
initial_prompt, ddl_list, max_tokens=14000
)
initial_prompt = ZhipuAI_Chat.add_documentation_to_prompt(
initial_prompt, doc_list, max_tokens=14000
)
message_log = [ZhipuAI_Chat.system_message(initial_prompt)]
for example in question_sql_list:
if example is None:
print("example is None")
else:
if example is not None and "question" in example and "sql" in example:
message_log.append(ZhipuAI_Chat.user_message(example["question"]))
message_log.append(ZhipuAI_Chat.assistant_message(example["sql"]))
message_log.append({"role": "user", "content": question})
return message_log
def get_followup_questions_prompt(
self,
question: str,
df: pd.DataFrame,
question_sql_list: List,
ddl_list: List,
doc_list: List,
**kwargs,
):
initial_prompt = f"The user initially asked the question: '{question}': \n\n"
initial_prompt = ZhipuAI_Chat.add_ddl_to_prompt(
initial_prompt, ddl_list, max_tokens=14000
)
initial_prompt = ZhipuAI_Chat.add_documentation_to_prompt(
initial_prompt, doc_list, max_tokens=14000
)
initial_prompt = ZhipuAI_Chat.add_sql_to_prompt(
initial_prompt, question_sql_list, max_tokens=14000
)
message_log = [ZhipuAI_Chat.system_message(initial_prompt)]
message_log.append(
ZhipuAI_Chat.user_message(
"Generate a List of followup questions that the user might ask about this data. Respond with a List of questions, one per line. Do not answer with any explanations -- just the questions."
)
)
return message_log
def generate_question(self, sql: str, **kwargs) -> str:
response = self.submit_prompt(
[
self.system_message(
"The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."
),
self.user_message(sql),
],
**kwargs,
)
return response
def _extract_python_code(self, markdown_string: str) -> str:
# Regex pattern to match Python code blocks
pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```"
# Find all matches in the markdown string
matches = re.findall(pattern, markdown_string, re.IGNORECASE)
# Extract the Python code from the matches
python_code = []
for match in matches:
python = match[0] if match[0] else match[1]
python_code.append(python.strip())
if len(python_code) == 0:
return markdown_string
return python_code[0]
def _sanitize_plotly_code(self, raw_plotly_code: str) -> str:
# Remove the fig.show() statement from the plotly code
plotly_code = raw_plotly_code.replace("fig.show()", "")
return plotly_code
def generate_plotly_code(
self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs
) -> str:
if question is not None:
system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{question}'"
else:
system_msg = "The following is a pandas DataFrame "
if sql is not None:
system_msg += f"\n\nThe DataFrame was produced using this query: {sql}\n\n"
system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{df_metadata}"
message_log = [
self.system_message(system_msg),
self.user_message(
"Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code."
),
]
plotly_code = self.submit_prompt(message_log, kwargs=kwargs)
return self._sanitize_plotly_code(self._extract_python_code(plotly_code))
def submit_prompt(
self, prompt, max_tokens=500, temperature=0.7, top_p=0.7, stop=None, **kwargs
):
if prompt is None:
raise Exception("Prompt is None")
if len(prompt) == 0:
raise Exception("Prompt is empty")
client = ZhipuAI(api_key=self.api_key)
response = client.chat.completions.create(
model="glm-4",
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=stop,
messages=prompt,
)
return response.choices[0].message.content
================================================
FILE: src/vanna/legacy/ZhipuAI/ZhipuAI_embeddings.py
================================================
from typing import List
from zhipuai import ZhipuAI
from chromadb import Documents, EmbeddingFunction, Embeddings
from ..base import VannaBase
class ZhipuAI_Embeddings(VannaBase):
"""
[future functionality] This function is used to generate embeddings from ZhipuAI.
Args:
VannaBase (_type_): _description_
"""
def __init__(self, config=None):
VannaBase.__init__(self, config=config)
if "api_key" not in config:
raise Exception("Missing api_key in config")
self.api_key = config["api_key"]
self.client = ZhipuAI(api_key=self.api_key)
def generate_embedding(self, data: str, **kwargs) -> List[float]:
embedding = self.client.embeddings.create(
model="embedding-2",
input=data,
)
return embedding.data[0].embedding
class ZhipuAIEmbeddingFunction(EmbeddingFunction[Documents]):
"""
A embeddingFunction that uses ZhipuAI to generate embeddings which can use in chromadb.
usage:
class MyVanna(ChromaDB_VectorStore, ZhipuAI_Chat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
ZhipuAI_Chat.__init__(self, config=config)
config={'api_key': 'xxx'}
zhipu_embedding_function = ZhipuAIEmbeddingFunction(config=config)
config = {"api_key": "xxx", "model": "glm-4","path":"xy","embedding_function":zhipu_embedding_function}
vn = MyVanna(config)
"""
def __init__(self, config=None):
if config is None or "api_key" not in config:
raise ValueError("Missing 'api_key' in config")
self.api_key = config["api_key"]
self.model_name = config.get("model_name", "embedding-2")
try:
self.client = ZhipuAI(api_key=self.api_key)
except Exception as e:
raise ValueError(f"Error initializing ZhipuAI client: {e}")
def __call__(self, input: Documents) -> Embeddings:
# Replace newlines, which can negatively affect performance.
input = [t.replace("\n", " ") for t in input]
all_embeddings = []
print(f"Generating embeddings for {len(input)} documents")
# Iterating over each document for individual API calls
for document in input:
try:
response = self.client.embeddings.create(
model=self.model_name, input=document
)
# print(response)
embedding = response.data[0].embedding
all_embeddings.append(embedding)
# print(f"Cost required: {response.usage.total_tokens}")
except Exception as e:
raise ValueError(f"Error generating embedding for document: {e}")
return all_embeddings
================================================
FILE: src/vanna/legacy/ZhipuAI/__init__.py
================================================
from .ZhipuAI_Chat import ZhipuAI_Chat
from .ZhipuAI_embeddings import ZhipuAI_Embeddings, ZhipuAIEmbeddingFunction
================================================
FILE: src/vanna/legacy/__init__.py
================================================
import dataclasses
import json
import os
from dataclasses import dataclass
from typing import Callable, List, Tuple, Union
import pandas as pd
import requests
import plotly.graph_objs
from .exceptions import (
OTPCodeError,
ValidationError,
)
from .types import (
ApiKey,
Status,
TrainingData,
UserEmail,
UserOTP,
)
from .utils import sanitize_model_name, validate_config_path
api_key: Union[str, None] = None # API key for Vanna.AI
fig_as_img: bool = False # Whether or not to return Plotly figures as images
run_sql: Union[Callable[[str], pd.DataFrame], None] = (
None # Function to convert SQL to a Pandas DataFrame
)
"""
**Example**
```python
vn.run_sql = lambda sql: pd.read_sql(sql, engine)
```
Set the SQL to DataFrame function for Vanna.AI. This is used in the [`vn.ask(...)`][vanna.ask] function.
Instead of setting this directly you can also use [`vn.connect_to_snowflake(...)`][vanna.connect_to_snowflake] to set this.
"""
__org: Union[str, None] = None # Organization name for Vanna.AI
_unauthenticated_endpoint = "https://ask.vanna.ai/unauthenticated_rpc"
def error_deprecation():
raise Exception("""
Please switch to the following method for initializing Vanna:
from vanna.remote import VannaDefault
api_key = # Your API key from https://vanna.ai/account/profile
vanna_model_name = # Your model name from https://vanna.ai/account/profile
vn = VannaDefault(model=vanna_model_name, api_key=api_key)
""")
def __unauthenticated_rpc_call(method, params):
headers = {
"Content-Type": "application/json",
}
data = {"method": method, "params": [__dataclass_to_dict(obj) for obj in params]}
response = requests.post(
_unauthenticated_endpoint, headers=headers, data=json.dumps(data)
)
return response.json()
def __dataclass_to_dict(obj):
return dataclasses.asdict(obj)
def get_api_key(email: str, otp_code: Union[str, None] = None) -> str:
"""
**Example:**
```python
vn.get_api_key(email="my-email@example.com")
```
Login to the Vanna.AI API.
Args:
email (str): The email address to login with.
otp_code (Union[str, None]): The OTP code to login with. If None, an OTP code will be sent to the email address.
Returns:
str: The API key.
"""
vanna_api_key = os.environ.get("VANNA_API_KEY", None)
if vanna_api_key is not None:
return vanna_api_key
if email == "my-email@example.com":
raise ValidationError(
"Please replace 'my-email@example.com' with your email address."
)
if otp_code is None:
params = [UserEmail(email=email)]
d = __unauthenticated_rpc_call(method="send_otp", params=params)
if "result" not in d:
raise OTPCodeError("Error sending OTP code.")
status = Status(**d["result"])
if not status.success:
raise OTPCodeError(f"Error sending OTP code: {status.message}")
otp_code = input("Check your email for the code and enter it here: ")
params = [UserOTP(email=email, otp=otp_code)]
d = __unauthenticated_rpc_call(method="verify_otp", params=params)
if "result" not in d:
raise OTPCodeError("Error verifying OTP code.")
key = ApiKey(**d["result"])
if key is None:
raise OTPCodeError("Error verifying OTP code.")
api_key = key.key
return api_key
def set_api_key(key: str) -> None:
error_deprecation()
def get_models() -> List[str]:
error_deprecation()
def create_model(model: str, db_type: str) -> bool:
error_deprecation()
def add_user_to_model(model: str, email: str, is_admin: bool) -> bool:
error_deprecation()
def update_model_visibility(public: bool) -> bool:
error_deprecation()
def set_model(model: str):
error_deprecation()
def add_sql(
question: str, sql: str, tag: Union[str, None] = "Manually Trained"
) -> bool:
error_deprecation()
def add_ddl(ddl: str) -> bool:
error_deprecation()
def add_documentation(documentation: str) -> bool:
error_deprecation()
@dataclass
class TrainingPlanItem:
item_type: str
item_group: str
item_name: str
item_value: str
def __str__(self):
if self.item_type == self.ITEM_TYPE_SQL:
return f"Train on SQL: {self.item_group} {self.item_name}"
elif self.item_type == self.ITEM_TYPE_DDL:
return f"Train on DDL: {self.item_group} {self.item_name}"
elif self.item_type == self.ITEM_TYPE_IS:
return f"Train on Information Schema: {self.item_group} {self.item_name}"
ITEM_TYPE_SQL = "sql"
ITEM_TYPE_DDL = "ddl"
ITEM_TYPE_IS = "is"
class TrainingPlan:
"""
A class representing a training plan. You can see what's in it, and remove items from it that you don't want trained.
**Example:**
```python
plan = vn.get_training_plan()
plan.get_summary()
```
"""
_plan: List[TrainingPlanItem]
def __init__(self, plan: List[TrainingPlanItem]):
self._plan = plan
def __str__(self):
return "\n".join(self.get_summary())
def __repr__(self):
return self.__str__()
def get_summary(self) -> List[str]:
"""
**Example:**
```python
plan = vn.get_training_plan()
plan.get_summary()
```
Get a summary of the training plan.
Returns:
List[str]: A list of strings describing the training plan.
"""
return [f"{item}" for item in self._plan]
def remove_item(self, item: str):
"""
**Example:**
```python
plan = vn.get_training_plan()
plan.remove_item("Train on SQL: What is the average salary of employees?")
```
Remove an item from the training plan.
Args:
item (str): The item to remove.
"""
for plan_item in self._plan:
if str(plan_item) == item:
self._plan.remove(plan_item)
break
def get_training_plan_postgres(
filter_databases: Union[List[str], None] = None,
filter_schemas: Union[List[str], None] = None,
include_information_schema: bool = False,
use_historical_queries: bool = True,
) -> TrainingPlan:
error_deprecation()
def get_training_plan_generic(df) -> TrainingPlan:
error_deprecation()
def get_training_plan_experimental(
filter_databases: Union[List[str], None] = None,
filter_schemas: Union[List[str], None] = None,
include_information_schema: bool = False,
use_historical_queries: bool = True,
) -> TrainingPlan:
error_deprecation()
def train(
question: str = None,
sql: str = None,
ddl: str = None,
documentation: str = None,
json_file: str = None,
sql_file: str = None,
plan: TrainingPlan = None,
) -> bool:
error_deprecation()
def flag_sql_for_review(
question: str, sql: Union[str, None] = None, error_msg: Union[str, None] = None
) -> bool:
error_deprecation()
def remove_sql(question: str) -> bool:
error_deprecation()
def remove_training_data(id: str) -> bool:
error_deprecation()
def generate_sql(question: str) -> str:
error_deprecation()
def get_related_training_data(question: str) -> TrainingData:
error_deprecation()
def generate_meta(question: str) -> str:
error_deprecation()
def generate_followup_questions(question: str, df: pd.DataFrame) -> List[str]:
error_deprecation()
def generate_questions() -> List[str]:
error_deprecation()
def ask(
question: Union[str, None] = None,
print_results: bool = True,
auto_train: bool = True,
generate_followups: bool = True,
) -> Union[
Tuple[
Union[str, None],
Union[pd.DataFrame, None],
Union[plotly.graph_objs.Figure, None],
Union[List[str], None],
],
None,
]:
error_deprecation()
def generate_plotly_code(
question: Union[str, None],
sql: Union[str, None],
df: pd.DataFrame,
chart_instructions: Union[str, None] = None,
) -> str:
error_deprecation()
def get_plotly_figure(
plotly_code: str, df: pd.DataFrame, dark_mode: bool = True
) -> plotly.graph_objs.Figure:
error_deprecation()
def get_results(cs, default_database: str, sql: str) -> pd.DataFrame:
error_deprecation()
def generate_explanation(sql: str) -> str:
error_deprecation()
def generate_question(sql: str) -> str:
error_deprecation()
def get_all_questions() -> pd.DataFrame:
error_deprecation()
def get_training_data() -> pd.DataFrame:
error_deprecation()
def connect_to_sqlite(url: str):
error_deprecation()
def connect_to_snowflake(
account: str,
username: str,
password: str,
database: str,
schema: Union[str, None] = None,
role: Union[str, None] = None,
):
error_deprecation()
def connect_to_postgres(
host: str = None,
dbname: str = None,
user: str = None,
password: str = None,
port: int = None,
):
error_deprecation()
def connect_to_bigquery(cred_file_path: str = None, project_id: str = None):
error_deprecation()
def connect_to_duckdb(url: str = "memory", init_sql: str = None):
error_deprecation()
================================================
FILE: src/vanna/legacy/adapter.py
================================================
"""
Legacy VannaBase adapter for the Vanna Agents framework.
This module provides a LegacyVannaAdapter that bridges legacy VannaBase objects
with the new ToolRegistry system by auto-registering legacy methods as tools
with appropriate group-based access control.
"""
from typing import Any, Dict, List, Optional
import pandas as pd
from ..capabilities.agent_memory import (
AgentMemory,
TextMemory,
TextMemorySearchResult,
ToolMemory,
ToolMemorySearchResult,
)
from ..capabilities.sql_runner import SqlRunner, RunSqlToolArgs
from ..core.registry import ToolRegistry
from ..core.tool import Tool, ToolContext, ToolResult
from ..core.user import User
from ..tools.agent_memory import (
SaveQuestionToolArgsTool,
SearchSavedCorrectToolUsesTool,
)
from ..tools.run_sql import RunSqlTool
from .base.base import VannaBase
class LegacySqlRunner(SqlRunner):
"""SqlRunner implementation that wraps a legacy VannaBase instance.
This class bridges the new SqlRunner interface with legacy VannaBase
run_sql methods, allowing legacy database connections to work with
the new tool-based architecture.
"""
def __init__(self, vn: VannaBase):
"""Initialize with a legacy VannaBase instance.
Args:
vn: The legacy VannaBase instance with an initialized run_sql method
"""
self.vn = vn
async def run_sql(self, args: RunSqlToolArgs, context: ToolContext) -> pd.DataFrame:
"""Execute SQL query using the legacy VannaBase run_sql method.
Args:
args: SQL query arguments containing the SQL string
context: Tool execution context (not used by legacy implementation)
Returns:
DataFrame with query results
Raises:
Exception: If query execution fails
"""
# Call the legacy VannaBase run_sql method
# The legacy method is synchronous, so we call it directly
return self.vn.run_sql(args.sql)
class LegacyVannaAdapter(ToolRegistry, AgentMemory):
"""Adapter that wraps a legacy VannaBase object and exposes its methods as tools.
This adapter automatically registers specific VannaBase methods as tools in the
registry with configurable group-based access control. This allows legacy Vanna
instances to work seamlessly with the new Agents framework.
Features:
- Auto-registers legacy methods as tools
- Configurable group-based permissions ('user', 'admin', etc.)
- Seamless integration with ToolRegistry
- Implements AgentMemory interface
- Preserves legacy VannaBase functionality
Example:
```python
from vanna.legacy.base import VannaBase
from vanna.legacy.adapter import LegacyVannaAdapter
# Initialize your legacy Vanna instance
vn = VannaBase(config={"model": "gpt-4"})
vn.connect_to_postgres(...)
# Create adapter and auto-register tools
adapter = LegacyVannaAdapter(vn)
# Tools are now available through the registry
schemas = await adapter.get_schemas(user)
```
"""
def __init__(
self,
vn: VannaBase,
audit_logger: Optional[Any] = None,
audit_config: Optional[Any] = None,
) -> None:
"""Initialize the adapter with a legacy VannaBase instance.
Args:
vanna: The legacy VannaBase instance to wrap
audit_logger: Optional audit logger for tool execution tracking
audit_config: Optional audit configuration
"""
ToolRegistry.__init__(
self, audit_logger=audit_logger, audit_config=audit_config
)
self.vn = vn
self._register_tools()
def _register_tools(self) -> None:
"""Register legacy VannaBase methods as tools with appropriate permissions.
Registers the following tools:
- RunSqlTool: Wraps the legacy run_sql method via LegacySqlRunner
- SaveQuestionToolArgsTool: Wraps add_question_sql via LegacyAgentMemory
- SearchSavedCorrectToolUsesTool: Wraps get_similar_question_sql via LegacyAgentMemory
"""
# Create a LegacySqlRunner to wrap the VannaBase run_sql method
sql_runner = LegacySqlRunner(self.vn)
# Register the RunSqlTool with user and admin access
run_sql_tool = RunSqlTool(sql_runner)
self.register_local_tool(run_sql_tool, access_groups=["user", "admin"])
# Register memory tools using the internal _agent_memory instance
# SaveQuestionToolArgsTool - for saving question-tool-args patterns (admin only)
save_memory_tool = SaveQuestionToolArgsTool()
self.register_local_tool(save_memory_tool, access_groups=["admin"])
# SearchSavedCorrectToolUsesTool - for searching similar patterns (user and admin)
search_memory_tool = SearchSavedCorrectToolUsesTool()
self.register_local_tool(search_memory_tool, access_groups=["user", "admin"])
# AgentMemory interface implementation
async def save_tool_usage(
self,
question: str,
tool_name: str,
args: Dict[str, Any],
context: ToolContext,
success: bool = True,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Save a tool usage pattern by storing it as a question-sql pair.
Args:
question: The user question
tool_name: Name of the tool that was used
args: Arguments passed to the tool
context: Tool execution context (not used by legacy implementation)
success: Whether the tool execution was successful
metadata: Additional metadata (not used by legacy implementation)
"""
# For legacy compatibility, we primarily care about SQL queries
# Extract SQL from args if this was a run_sql tool
if tool_name == "run_sql" and "sql" in args:
sql = args["sql"]
# Call the legacy add_question_sql method
# The legacy method is synchronous, so we call it directly
self.vn.add_question_sql(question=question, sql=sql)
async def search_similar_usage(
self,
question: str,
context: ToolContext,
*,
limit: int = 10,
similarity_threshold: float = 0.7,
tool_name_filter: Optional[str] = None,
) -> List[ToolMemorySearchResult]:
"""Search for similar tool usage patterns using legacy question-sql lookup.
Args:
question: The question to search for
context: Tool execution context (not used by legacy implementation)
limit: Maximum number of results (not directly supported by legacy)
similarity_threshold: Minimum similarity score (not directly supported by legacy)
tool_name_filter: Filter by tool name (not directly supported by legacy)
Returns:
List of memory search results with similar question-sql pairs
"""
# Call the legacy get_similar_question_sql method
similar_results = self.vn.get_similar_question_sql(question=question)
# Convert legacy results to ToolMemorySearchResult format
memory_results = []
for idx, result in enumerate(similar_results):
# Legacy results are typically dicts with 'question' and 'sql' keys
if isinstance(result, dict) and "question" in result and "sql" in result:
tool_memory = ToolMemory(
memory_id=None, # Legacy doesn't provide IDs
question=result["question"],
tool_name="run_sql",
args={"sql": result["sql"]},
success=True,
)
# Assign a simple rank-based similarity score
# Legacy system doesn't provide actual similarity scores
similarity_score = 1.0 - (idx * 0.1) # Decreasing score by rank
similarity_score = max(similarity_score, 0.0)
memory_results.append(
ToolMemorySearchResult(
memory=tool_memory,
similarity_score=similarity_score,
rank=idx + 1,
)
)
return memory_results[:limit]
async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
"""Save text memory using legacy add_documentation method.
Args:
content: The documentation content to save
context: Tool execution context (not used by legacy implementation)
Returns:
TextMemory object with the saved content
"""
# Call the legacy add_documentation method
# The legacy method is synchronous, so we call it directly
doc_id = self.vn.add_documentation(documentation=content)
return TextMemory(
memory_id=doc_id,
content=content,
timestamp=None, # Legacy doesn't provide timestamps
)
async def search_text_memories(
self,
query: str,
context: ToolContext,
*,
limit: int = 10,
similarity_threshold: float = 0.7,
) -> List[TextMemorySearchResult]:
"""Search text memories using legacy get_related_documentation method.
Args:
query: The query to search for
context: Tool execution context (not used by legacy implementation)
limit: Maximum number of results (not directly supported by legacy)
similarity_threshold: Minimum similarity score (not directly supported by legacy)
Returns:
List of text memory search results
"""
# Call the legacy get_related_documentation method
related_docs = self.vn.get_related_documentation(question=query)
# Convert legacy results to TextMemorySearchResult format
memory_results = []
for idx, doc in enumerate(related_docs):
# Legacy results are typically strings or dicts
if isinstance(doc, str):
content = doc
doc_id = None
elif isinstance(doc, dict):
content = str(doc.get("documentation", doc.get("content", str(doc))))
doc_id = doc.get("id")
else:
content = str(doc)
doc_id = None
# Create TextMemory object
text_memory = TextMemory(
memory_id=doc_id,
content=content,
timestamp=None, # Legacy doesn't provide timestamps
)
# Assign a simple rank-based similarity score
# Legacy system doesn't provide actual similarity scores
similarity_score = 1.0 - (idx * 0.1) # Decreasing score by rank
similarity_score = max(similarity_score, 0.0)
if similarity_score >= similarity_threshold:
memory_results.append(
TextMemorySearchResult(
memory=text_memory,
similarity_score=similarity_score,
rank=idx + 1,
)
)
return memory_results[:limit]
async def get_recent_memories(
self, context: ToolContext, limit: int = 10
) -> List[ToolMemory]:
"""Get recently added memories.
Note: Legacy VannaBase does not provide a direct way to get recent memories,
so we retrieve using a blank string which typically returns the most relevant
or recent items from the vector store.
Args:
context: Tool execution context
limit: Maximum number of memories to return
Returns:
List of recently added tool memories
"""
# Use blank string retrieval to get recent/relevant memories
similar_results = self.vn.get_similar_question_sql(question="")
# Convert legacy results to ToolMemory format
memories = []
for idx, result in enumerate(similar_results[:limit]):
# Legacy results are typically dicts with 'question' and 'sql' keys
if isinstance(result, dict) and "question" in result and "sql" in result:
tool_memory = ToolMemory(
memory_id=None, # Legacy doesn't provide IDs
question=result["question"],
tool_name="run_sql",
args={"sql": result["sql"]},
success=True,
)
memories.append(tool_memory)
return memories
async def get_recent_text_memories(
    self, context: ToolContext, limit: int = 10
) -> List[TextMemory]:
    """Return up to ``limit`` documentation entries from the legacy store.

    VannaBase has no "recent items" query, so this asks
    ``get_related_documentation`` for an empty question and converts the first
    ``limit`` results. Ordering is whatever the legacy store returns.

    Conversion of each legacy result:
      * dict  -> content from "documentation", else "content", else str(dict),
                 stringified; ID taken from the "id" key (None when absent)
      * str   -> used as the content as-is, no ID
      * other -> str(result), no ID
    Timestamps are always None because the legacy store does not provide them.

    Args:
        context: Tool execution context (not used by the legacy backend).
        limit: Maximum number of memories to return (applied as a slice).

    Returns:
        List of TextMemory objects.
    """
    docs = self.vn.get_related_documentation(question="")
    recent = []
    for item in docs[:limit]:
        if isinstance(item, dict):
            fallback = item.get("content", str(item))
            body = str(item.get("documentation", fallback))
            item_id = item.get("id")
        elif isinstance(item, str):
            body, item_id = item, None
        else:
            body, item_id = str(item), None
        recent.append(TextMemory(memory_id=item_id, content=body, timestamp=None))
    return recent
async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
    """Remove one legacy training-data entry by its ID.

    Delegates to the legacy, synchronous ``remove_training_data`` (called
    directly, not via an executor) and returns its result unchanged.

    Args:
        context: Tool execution context (not used by the legacy backend).
        memory_id: ID of the memory to delete.

    Returns:
        True if the memory was deleted, False otherwise.
    """
    removed = self.vn.remove_training_data(id=memory_id)
    return removed
async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
    """Remove one legacy text (documentation) entry by its ID.

    Uses the same legacy ``remove_training_data`` call as tool-memory
    deletion; the synchronous call is made directly and its result returned.

    Args:
        context: Tool execution context (not used by the legacy backend).
        memory_id: ID of the text memory to delete.

    Returns:
        True if the text memory was deleted, False otherwise.
    """
    removed = self.vn.remove_training_data(id=memory_id)
    return removed
async def clear_memories(
    self,
    context: ToolContext,
    tool_name: Optional[str] = None,
    before_date: Optional[str] = None,
) -> int:
    """Clear stored memories (not supported by the legacy backend).

    Legacy VannaBase exposes no bulk-clear operation, so this is a no-op that
    deletes nothing regardless of the filters passed.

    Args:
        context: Tool execution context.
        tool_name: Optional tool name filter (ignored).
        before_date: Optional date filter (ignored).

    Returns:
        Always 0, the number of memories cleared.
    """
    cleared = 0
    return cleared
# Example stub for a tool wrapper (to be expanded).
# The triple-quoted string below is inert: it is a bare string expression, so
# none of the class inside it is executed or defined. Copy the class out and
# customize it for each legacy VannaBase method you want to expose as a tool.
# NOTE(review): ``ExampleToolArgs``, ``example_method``, ``param1``/``param2``
# appear to be placeholders to be replaced; they are not defined in this chunk.
"""
class ExampleTool(Tool[ExampleToolArgs]):
def __init__(self, vanna: VannaBase):
self.vanna = vanna
@property
def name(self) -> str:
return "example_tool"
@property
def description(self) -> str:
return "Example tool description"
@property
def access_groups(self) -> List[str]:
# This is optional - will be overridden by register_local_tool
return []
def get_args_schema(self) -> type[ExampleToolArgs]:
return ExampleToolArgs
async def execute(
self,
context: ToolContext,
args: ExampleToolArgs
) -> ToolResult:
# Call the legacy VannaBase method
result = self.vanna.example_method(args.param1, args.param2)
return ToolResult(
success=True,
result_for_llm=result,
ui_component=None,
)
"""
================================================
FILE: src/vanna/legacy/advanced/__init__.py
================================================
from abc import ABC, abstractmethod
class VannaAdvanced(ABC):
    """Abstract interface for advanced Vanna backends that manage SQL "functions".

    This base class only stores ``config``; every operation is left to
    subclasses, and the class cannot be instantiated until all abstract
    methods are implemented.
    """

    def __init__(self, config=None):
        # Stored as given (may be None); subclasses interpret it.
        self.config = config

    @abstractmethod
    def get_function(self, question: str, additional_data: dict = {}) -> dict:
        """Return a function, as a dict, for ``question``.

        ``additional_data`` carries extra caller-supplied context.
        NOTE(review): the ``{}`` default is a single shared object; overrides
        should not mutate it.
        """

    @abstractmethod
    def create_function(
        self, question: str, sql: str, plotly_code: str, **kwargs
    ) -> dict:
        """Create a function from a question, its SQL and Plotly code; return it as a dict."""

    @abstractmethod
    def update_function(self, old_function_name: str, updated_function: dict) -> bool:
        """Replace the function named ``old_function_name``; return a success flag."""

    @abstractmethod
    def delete_function(self, function_name: str) -> bool:
        """Delete the function named ``function_name``; return a success flag."""

    @abstractmethod
    def get_all_functions(self) -> list:
        """Return every stored function."""
================================================
FILE: src/vanna/legacy/anthropic/__init__.py
================================================
from .anthropic_chat import Anthropic_Chat
================================================
FILE: src/vanna/legacy/anthropic/anthropic_chat.py
================================================
import os
import anthropic
from ..base import VannaBase
class Anthropic_Chat(VannaBase):
    """Legacy Vanna chat backend that sends prompts to Anthropic's Messages API.

    Args:
        client: Optional ready-made Anthropic client. When given it is used
            as-is and no client is constructed.
        config: Optional dict. Keys read here: ``temperature`` (default 0.7),
            ``max_tokens`` (default 500) and ``api_key``. ``submit_prompt``
            additionally requires ``model``. When ``api_key`` is absent, the
            ``ANTHROPIC_API_KEY`` environment variable is used.
    """

    def __init__(self, client=None, config=None):
        VannaBase.__init__(self, config=config)
        # VannaBase.__init__ replaces a None config with {}; read through
        # self.config so ``Anthropic_Chat()`` with no config does not raise.
        config = self.config
        # Default parameters - can be overridden via config.
        self.temperature = config.get("temperature", 0.7)
        self.max_tokens = config.get("max_tokens", 500)
        if client is not None:
            self.client = client
            return
        # Fall back to the environment so self.client is always set, even when
        # a config without "api_key" is supplied.
        if "api_key" in config:
            api_key = config["api_key"]
        else:
            api_key = os.getenv("ANTHROPIC_API_KEY")
        self.client = anthropic.Anthropic(api_key=api_key)

    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send ``prompt`` (a list of role/content dicts) and return the reply text.

        Raises:
            Exception: If ``prompt`` is None or empty.
            KeyError: If ``self.config`` has no ``model`` entry.
        """
        if prompt is None:
            raise Exception("Prompt is None")
        if len(prompt) == 0:
            raise Exception("Prompt is empty")
        # Approximate token count (~4 characters per token), used only for logging.
        num_tokens = sum(len(message["content"]) / 4 for message in prompt)
        if self.config is not None and "model" in self.config:
            print(
                f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
            )
        # Claude takes the system prompt as a separate top-level ``system`` field
        # rather than as a message:
        # https://docs.anthropic.com/claude/reference/messages_post
        # If several system messages are present, only the last one is kept.
        system_message = ""
        conversation = []
        for prompt_message in prompt:
            role = prompt_message["role"]
            if role == "system":
                system_message = prompt_message["content"]
            else:
                conversation.append(
                    {"role": role, "content": prompt_message["content"]}
                )
        response = self.client.messages.create(
            model=self.config["model"],
            messages=conversation,
            system=system_message,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
        )
        return response.content[0].text
================================================
FILE: src/vanna/legacy/azuresearch/__init__.py
================================================
from .azuresearch_vector import AzureAISearch_VectorStore
================================================
FILE: src/vanna/legacy/azuresearch/azuresearch_vector.py
================================================
import ast
import json
from typing import List
import pandas as pd
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
ExhaustiveKnnAlgorithmConfiguration,
ExhaustiveKnnParameters,
SearchableField,
SearchField,
SearchFieldDataType,
SearchIndex,
VectorSearch,
VectorSearchAlgorithmKind,
VectorSearchAlgorithmMetric,
VectorSearchProfile,
)
from azure.search.documents.models import VectorFilterMode, VectorizedQuery
from fastembed import TextEmbedding
from ..base import VannaBase
from ..utils import deterministic_uuid
class AzureAISearch_VectorStore(VannaBase):
    """
    AzureAISearch_VectorStore is a class that provides a vector store for Azure AI Search.

    DDL, documentation and question/SQL pairs share one index and are told apart
    by their ``type`` field ("ddl", "doc", "sql"). Embeddings are computed locally
    with fastembed; the embedding model is loaded on first use and reused by later
    calls on the same instance.

    Args:
        config (dict): Configuration dictionary. Defaults to {}. You must provide an API key in the config.
            - azure_search_endpoint (str, optional): Azure Search endpoint. Defaults to "https://azcognetive.search.windows.net".
            - azure_search_api_key (str): Azure Search API key.
            - dimensions (int, optional): Dimensions of the embeddings. Defaults to 384 which corresponds to the dimensions of BAAI/bge-small-en-v1.5.
            - fastembed_model (str, optional): Fastembed model to use. Defaults to "BAAI/bge-small-en-v1.5".
            - index_name (str, optional): Name of the index. Defaults to "vanna-index".
            - n_results (int, optional): Number of results to return. Defaults to 10.
            - n_results_ddl (int, optional): Number of results to return for DDL queries. Defaults to the value of n_results.
            - n_results_sql (int, optional): Number of results to return for SQL queries. Defaults to the value of n_results.
            - n_results_documentation (int, optional): Number of results to return for documentation queries. Defaults to the value of n_results.

    Raises:
        ValueError: If config is None, or if 'azure_search_api_key' is not provided in the config.
    """

    def __init__(self, config=None):
        # Fail fast on a missing config before any base-class setup.
        if config is None:
            raise ValueError(
                "config is required, pass an API key, 'azure_search_api_key', in the config."
            )
        VannaBase.__init__(self, config=config)
        azure_search_endpoint = config.get(
            "azure_search_endpoint", "https://azcognetive.search.windows.net"
        )
        azure_search_api_key = config.get("azure_search_api_key")
        self.dimensions = config.get("dimensions", 384)
        self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
        self.index_name = config.get("index_name", "vanna-index")
        default_n_results = config.get("n_results", 10)
        self.n_results_ddl = config.get("n_results_ddl", default_n_results)
        self.n_results_sql = config.get("n_results_sql", default_n_results)
        self.n_results_documentation = config.get(
            "n_results_documentation", default_n_results
        )
        # (model_name, TextEmbedding) pair, filled lazily by generate_embedding.
        self._embedding_model = None
        if not azure_search_api_key:
            raise ValueError(
                "'azure_search_api_key' is required in config to use AzureAISearch_VectorStore"
            )
        credential = AzureKeyCredential(azure_search_api_key)
        self.index_client = SearchIndexClient(
            endpoint=azure_search_endpoint,
            credential=credential,
        )
        self.search_client = SearchClient(
            endpoint=azure_search_endpoint,
            index_name=self.index_name,
            credential=credential,
        )
        if self.index_name not in self._get_indexes():
            self._create_index()

    def _create_index(self) -> bool:
        """Create (or update) the index with id/document/type fields and a cosine kNN vector field."""
        fields = [
            SearchableField(
                name="id", type=SearchFieldDataType.String, key=True, filterable=True
            ),
            SearchableField(
                name="document",
                type=SearchFieldDataType.String,
                searchable=True,
                filterable=True,
            ),
            SearchField(
                name="type",
                type=SearchFieldDataType.String,
                filterable=True,
                searchable=True,
            ),
            SearchField(
                name="document_vector",
                type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
                searchable=True,
                vector_search_dimensions=self.dimensions,
                vector_search_profile_name="ExhaustiveKnnProfile",
            ),
        ]
        vector_search = VectorSearch(
            algorithms=[
                ExhaustiveKnnAlgorithmConfiguration(
                    name="ExhaustiveKnn",
                    kind=VectorSearchAlgorithmKind.EXHAUSTIVE_KNN,
                    parameters=ExhaustiveKnnParameters(
                        metric=VectorSearchAlgorithmMetric.COSINE
                    ),
                )
            ],
            profiles=[
                VectorSearchProfile(
                    name="ExhaustiveKnnProfile",
                    algorithm_configuration_name="ExhaustiveKnn",
                )
            ],
        )
        index = SearchIndex(
            name=self.index_name, fields=fields, vector_search=vector_search
        )
        result = self.index_client.create_or_update_index(index)
        print(f"{result.name} created")
        return True

    def _get_indexes(self) -> list:
        """Return the names of all indexes visible to the index client."""
        return list(self.index_client.list_index_names())

    @staticmethod
    def _resolve_arg(value, kwargs: dict, alt_name: str, method: str):
        """Return ``value``, or pop ``alt_name`` from ``kwargs`` when ``value`` is None.

        Lets methods accept both the VannaBase keyword and this class's older
        keyword name. Raises TypeError when neither was supplied.
        """
        if value is not None:
            return value
        if alt_name in kwargs:
            return kwargs.pop(alt_name)
        raise TypeError(f"{method}() missing required text argument")

    def _upload_document(self, doc_id: str, content: str, doc_type: str) -> str:
        """Embed ``content`` and upload it under ``doc_id`` with the given type; return the ID."""
        document = {
            "id": doc_id,
            "document": content,
            "type": doc_type,
            "document_vector": self.generate_embedding(content),
        }
        self.search_client.upload_documents(documents=[document])
        return doc_id

    def _search_documents(self, query: str, doc_type: str, top: int) -> List[str]:
        """Vector-search documents of ``doc_type`` nearest to ``query``; return their text."""
        vector_query = VectorizedQuery(
            vector=self.generate_embedding(query), fields="document_vector"
        )
        df = pd.DataFrame(
            self.search_client.search(
                top=top,
                vector_queries=[vector_query],
                select=["id", "document", "type"],
                filter=f"type eq '{doc_type}'",
                # Filter before the kNN search so other document types cannot
                # crowd matching ones out of the top results.
                vector_filter_mode=VectorFilterMode.PRE_FILTER,
            )
        )
        if len(df) == 0:
            return []
        return df["document"].tolist()

    def add_ddl(self, ddl: str, **kwargs) -> str:
        return self._upload_document(deterministic_uuid(ddl) + "-ddl", ddl, "ddl")

    def add_documentation(self, doc: str = None, **kwargs) -> str:
        # Accept the VannaBase keyword ``documentation`` as well as ``doc``.
        doc = self._resolve_arg(doc, kwargs, "documentation", "add_documentation")
        return self._upload_document(deterministic_uuid(doc) + "-doc", doc, "doc")

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        question_sql_json = json.dumps(
            {"question": question, "sql": sql}, ensure_ascii=False
        )
        return self._upload_document(
            deterministic_uuid(question_sql_json) + "-sql", question_sql_json, "sql"
        )

    def get_related_ddl(self, question: str = None, **kwargs) -> List[str]:
        # ``question`` matches VannaBase; the older ``text=`` keyword still works.
        question = self._resolve_arg(question, kwargs, "text", "get_related_ddl")
        return self._search_documents(question, "ddl", self.n_results_ddl)

    def get_related_documentation(self, question: str = None, **kwargs) -> List[str]:
        # ``question`` matches VannaBase; the older ``text=`` keyword still works.
        question = self._resolve_arg(
            question, kwargs, "text", "get_related_documentation"
        )
        return self._search_documents(question, "doc", self.n_results_documentation)

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return ``{"question": ..., "sql": ...}`` dicts for the nearest stored pairs."""
        # Documents were written with json.dumps (add_question_sql), so parse as JSON.
        return [
            json.loads(document)
            for document in self._search_documents(question, "sql", self.n_results_sql)
        ]

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return all documents as a DataFrame with id/question/content/type columns.

        ``question`` is filled only for "sql" rows; an empty DataFrame is
        returned when the index holds no documents.
        """
        pages = self.search_client.search(
            search_text="*",
            select=["id", "document", "type"],
            filter="(type eq 'sql') or (type eq 'ddl') or (type eq 'doc')",
        ).by_page()
        df = pd.DataFrame([item for page in pages for item in page])
        if len(df) == 0:
            return pd.DataFrame()
        is_sql = df["type"] == "sql"
        parsed = df.loc[is_sql, "document"].apply(json.loads)
        # Create both columns up front so they exist even with no "sql" rows.
        df["question"] = None
        df["content"] = df["document"]
        df.loc[is_sql, "question"] = parsed.apply(lambda pair: pair["question"])
        df.loc[is_sql, "content"] = parsed.apply(lambda pair: pair["sql"])
        return df[["id", "question", "content", "type"]]

    def remove_training_data(self, id: str, **kwargs) -> bool:
        result = self.search_client.delete_documents(documents=[{"id": id}])
        return result[0].succeeded

    def remove_index(self):
        self.index_client.delete_index(self.index_name)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Embed ``data`` with the configured fastembed model, loading the model once."""
        cached = getattr(self, "_embedding_model", None)
        # Reload only if no model is cached or fastembed_model was changed.
        if cached is None or cached[0] != self.fastembed_model:
            cached = (
                self.fastembed_model,
                TextEmbedding(model_name=self.fastembed_model),
            )
            self._embedding_model = cached
        embedding = next(cached[1].embed(data))
        return embedding.tolist()
================================================
FILE: src/vanna/legacy/base/__init__.py
================================================
from .base import VannaBase
================================================
FILE: src/vanna/legacy/base/base.py
================================================
r"""
# Nomenclature
| Prefix | Definition | Examples |
| --- | --- | --- |
| `vn.get_` | Fetch some data | [`vn.get_related_ddl(...)`][vanna.base.base.VannaBase.get_related_ddl] |
| `vn.add_` | Adds something to the retrieval layer | [`vn.add_question_sql(...)`][vanna.base.base.VannaBase.add_question_sql] [`vn.add_ddl(...)`][vanna.base.base.VannaBase.add_ddl] |
| `vn.generate_` | Generates something using AI based on the information in the model | [`vn.generate_sql(...)`][vanna.base.base.VannaBase.generate_sql] [`vn.generate_explanation()`][vanna.base.base.VannaBase.generate_explanation] |
| `vn.run_` | Runs code (SQL) | [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] |
| `vn.remove_` | Removes something from the retrieval layer | [`vn.remove_training_data`][vanna.base.base.VannaBase.remove_training_data] |
| `vn.connect_` | Connects to a database | [`vn.connect_to_snowflake(...)`][vanna.base.base.VannaBase.connect_to_snowflake] |
| `vn.update_` | Updates something | N/A -- unused |
| `vn.set_` | Sets something | N/A -- unused |
# Open-Source and Extending
Vanna.AI is open-source and extensible. If you'd like to use Vanna without the servers, see an example [here](https://vanna.ai/docs/postgres-ollama-chromadb/).
The following is an example of where various functions are implemented in the codebase when using the default "local" version of Vanna. `vanna.base.VannaBase` is the base class which provides a `vanna.base.VannaBase.ask` and `vanna.base.VannaBase.train` function. Those rely on abstract methods which are implemented in the subclasses `vanna.openai_chat.OpenAI_Chat` and `vanna.chromadb_vector.ChromaDB_VectorStore`. `vanna.openai_chat.OpenAI_Chat` uses the OpenAI API to generate SQL and Plotly code. `vanna.chromadb_vector.ChromaDB_VectorStore` uses ChromaDB to store training data and generate embeddings.
If you want to use Vanna with other LLMs or databases, you can create your own subclass of `vanna.base.VannaBase` and implement the abstract methods.
```mermaid
flowchart
subgraph VannaBase
ask
train
end
subgraph OpenAI_Chat
get_sql_prompt
submit_prompt
generate_question
generate_plotly_code
end
subgraph ChromaDB_VectorStore
generate_embedding
add_question_sql
add_ddl
add_documentation
get_similar_question_sql
get_related_ddl
get_related_documentation
end
```
"""
import json
import os
import re
import sqlite3
import traceback
from abc import ABC, abstractmethod
from typing import List, Tuple, Union
from urllib.parse import urlparse
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objects as go
import requests
import sqlparse
from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
from ..types import TrainingPlan, TrainingPlanItem
from ..utils import validate_config_path
class VannaBase(ABC):
def __init__(self, config=None):
if config is None:
config = {}
self.config = config
self.run_sql_is_set = False
self.static_documentation = ""
self.dialect = self.config.get("dialect", "SQL")
self.language = self.config.get("language", None)
self.max_tokens = self.config.get("max_tokens", 14000)
def log(self, message: str, title: str = "Info"):
    """Print ``message`` to stdout as ``"<title>: <message>"``."""
    line = "{}: {}".format(title, message)
    print(line)
def _response_language(self) -> str:
if self.language is None:
return ""
return f"Respond in the {self.language} language."
def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) -> str:
    """
    Example:
    ```python
    vn.generate_sql("What are the top 10 customers by sales?")
    ```
    Uses the LLM to generate a SQL query that answers a question. It runs the following methods:
    - [`get_similar_question_sql`][vanna.base.base.VannaBase.get_similar_question_sql]
    - [`get_related_ddl`][vanna.base.base.VannaBase.get_related_ddl]
    - [`get_related_documentation`][vanna.base.base.VannaBase.get_related_documentation]
    - [`get_sql_prompt`][vanna.base.base.VannaBase.get_sql_prompt]
    - [`submit_prompt`][vanna.base.base.VannaBase.submit_prompt]

    If the first LLM response contains ``intermediate_sql``, the extracted
    intermediate query is run (only when ``allow_llm_to_see_data`` is True) and
    its result is appended to the documentation context of a second prompt.

    Args:
        question (str): The question to generate a SQL query for.
        allow_llm_to_see_data (bool): Whether to allow the LLM to see the data (for the purposes of introspecting the data to generate the final SQL).

    Returns:
        str: The SQL query that answers the question, or an explanatory message
        when intermediate SQL is needed but not allowed, or fails to run.
    """
    initial_prompt = (
        self.config.get("initial_prompt", None) if self.config is not None else None
    )
    question_sql_list = self.get_similar_question_sql(question, **kwargs)
    ddl_list = self.get_related_ddl(question, **kwargs)
    doc_list = self.get_related_documentation(question, **kwargs)
    # get_sql_prompt may append to doc_list in place (static documentation);
    # pass a copy so the second prompt below does not receive it twice and the
    # retrieval layer's list is left untouched.
    prompt = self.get_sql_prompt(
        initial_prompt=initial_prompt,
        question=question,
        question_sql_list=question_sql_list,
        ddl_list=ddl_list,
        doc_list=list(doc_list),
        **kwargs,
    )
    self.log(title="SQL Prompt", message=prompt)
    llm_response = self.submit_prompt(prompt, **kwargs)
    self.log(title="LLM Response", message=llm_response)

    if "intermediate_sql" not in llm_response:
        return self.extract_sql(llm_response)

    if not allow_llm_to_see_data:
        return "The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this."

    intermediate_sql = self.extract_sql(llm_response)
    try:
        self.log(title="Running Intermediate SQL", message=intermediate_sql)
        df = self.run_sql(intermediate_sql)
        prompt = self.get_sql_prompt(
            initial_prompt=initial_prompt,
            question=question,
            question_sql_list=question_sql_list,
            ddl_list=ddl_list,
            doc_list=doc_list
            + [
                f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n"
                + df.to_markdown()
            ],
            **kwargs,
        )
        self.log(title="Final SQL Prompt", message=prompt)
        llm_response = self.submit_prompt(prompt, **kwargs)
        self.log(title="LLM Response", message=llm_response)
    except Exception as e:
        return f"Error running intermediate SQL: {e}"
    return self.extract_sql(llm_response)
def extract_sql(self, llm_response: str) -> str:
    """
    Example:
    ```python
    vn.extract_sql("Here's the SQL query in a code block: ```sql\nSELECT * FROM customers\n```")
    ```
    Extracts the SQL query from the LLM response. This is useful in case the LLM response contains other information besides the SQL query.
    Override this function if your LLM responses need custom extraction logic.

    Patterns are tried in this order, and the LAST match of the first pattern
    that matches anything is returned:
    1. ``CREATE TABLE ... AS ... ;``
    2. ``WITH ... ;`` (CTEs)
    3. ``SELECT ... ;``
    4. a ```sql fenced block (stripped)
    5. any ``` fenced block (stripped)
    Keywords may be followed by any whitespace (space, tab or newline).

    Args:
        llm_response (str): The LLM response.

    Returns:
        str: The extracted SQL query, or ``llm_response`` unchanged when no
        pattern matches.
    """
    # (pattern, whether to strip the captured text)
    patterns = (
        (r"\bCREATE\s+TABLE\b.*?\bAS\b.*?;", False),
        # \s (not a literal space) so "WITH\n" is recognized; otherwise a CTE
        # falls through to the SELECT pattern and loses its WITH prefix.
        (r"\bWITH\b\s.*?;", False),
        (r"\bSELECT\b\s.*?;", False),
        (r"```sql\s*\n(.*?)```", True),
        (r"```(.*?)```", True),
    )
    for pattern, strip in patterns:
        matches = re.findall(pattern, llm_response, re.DOTALL | re.IGNORECASE)
        if matches:
            sql = matches[-1].strip() if strip else matches[-1]
            self.log(title="Extracted SQL", message=f"{sql}")
            return sql
    return llm_response
def is_sql_valid(self, sql: str) -> bool:
    """
    Example:
    ```python
    vn.is_sql_valid("SELECT * FROM customers")
    ```
    Decide whether a SQL string may be run. By default, it is considered valid
    when sqlparse classifies at least one of its statements as a SELECT.
    Override this method to allow other statement types.

    Args:
        sql (str): The SQL query to check.

    Returns:
        bool: True if any parsed statement is a SELECT, False otherwise.
    """
    statements = sqlparse.parse(sql)
    return any(statement.get_type() == "SELECT" for statement in statements)
def should_generate_chart(self, df: pd.DataFrame) -> bool:
    """
    Example:
    ```python
    vn.should_generate_chart(df)
    ```
    Decide whether a chart is worth drawing for ``df``: by default it needs at
    least two rows and at least one numeric column. Override to customize.

    Args:
        df (pd.DataFrame): The DataFrame to check.

    Returns:
        bool: True if a chart should be generated, False otherwise.
    """
    if len(df) <= 1:
        return False
    numeric_column_count = df.select_dtypes(include=["number"]).shape[1]
    return numeric_column_count > 0
def generate_rewritten_question(
    self, last_question: str, new_question: str, **kwargs
) -> str:
    """
    **Example:**
    ```python
    rewritten_question = vn.generate_rewritten_question("Who are the top 5 customers by sales?", "Show me their email addresses")
    ```
    Ask the LLM to merge the previous and new question into a single question
    when they are related. With no previous question, ``new_question`` is
    returned directly without calling the LLM.

    Args:
        last_question (str): The previous question that was asked, or None.
        new_question (str): The new question to be combined with the last question.
        **kwargs: Forwarded to ``submit_prompt``.

    Returns:
        str: The LLM's combined question, or ``new_question`` when there is no previous question.
    """
    if last_question is None:
        return new_question
    instructions = "Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."
    both_questions = (
        "First question: " + last_question + "\nSecond question: " + new_question
    )
    system_part = self.system_message(instructions)
    user_part = self.user_message(both_questions)
    return self.submit_prompt(prompt=[system_part, user_part], **kwargs)
def generate_followup_questions(
    self, question: str, sql: str, df: pd.DataFrame, n_questions: int = 5, **kwargs
) -> list:
    """
    **Example:**
    ```python
    vn.generate_followup_questions("What are the top 10 customers by sales?", sql, df)
    ```
    Generate a list of followup questions that you can ask Vanna.AI.

    Only the first 25 rows of ``df`` are shown to the LLM. The reply is split
    into lines; each line is stripped, a leading list number such as ``1.`` or
    ``1)`` is removed, and lines left empty are dropped.

    Args:
        question (str): The question that was asked.
        sql (str): The LLM-generated SQL query.
        df (pd.DataFrame): The results of the SQL query.
        n_questions (int): Number of follow-up questions to generate.

    Returns:
        list: Non-empty followup question strings that you can ask Vanna.AI.
    """
    instruction = (
        f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
    )
    language = self._response_language()
    if language:
        # Keep a sentence break between the instructions and the language hint.
        instruction += " " + language
    message_log = [
        self.system_message(
            f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.head(25).to_markdown()}\n\n"
        ),
        self.user_message(instruction),
    ]
    llm_response = self.submit_prompt(message_log, **kwargs)
    questions = []
    for line in llm_response.split("\n"):
        cleaned = re.sub(r"^\d+[.)]\s*", "", line.strip())
        # Blank lines would otherwise become empty follow-up buttons.
        if cleaned:
            questions.append(cleaned)
    return questions
def generate_questions(self, **kwargs) -> List[str]:
    """
    **Example:**
    ```python
    vn.generate_questions()
    ```
    Generate a list of questions that you can ask Vanna.AI.

    Looks up stored question/SQL pairs with an empty question and returns the
    ``"question"`` value of each.
    """
    pairs = self.get_similar_question_sql(question="", **kwargs)
    questions = []
    for pair in pairs:
        questions.append(pair["question"])
    return questions
def generate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str:
    """
    **Example:**
    ```python
    vn.generate_summary("What are the top 10 customers by sales?", df)
    ```
    Generate a summary of the results of a SQL query.

    The whole DataFrame is rendered with ``df.to_markdown()`` into the system
    message. When ``self._response_language()`` is non-empty, it is appended to
    the instruction after a space.

    Args:
        question (str): The question that was asked.
        df (pd.DataFrame): The results of the SQL query.

    Returns:
        str: The summary of the results of the SQL query.
    """
    instruction = "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."
    language = self._response_language()
    if language:
        # Keep a sentence break between the instruction and the language hint.
        instruction += " " + language
    message_log = [
        self.system_message(
            f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
        ),
        self.user_message(instruction),
    ]
    summary = self.submit_prompt(message_log, **kwargs)
    return summary
# ----------------- Use Any Embeddings API ----------------- #
@abstractmethod
def generate_embedding(self, data: str, **kwargs) -> List[float]:
    """Return an embedding vector for ``data``; subclasses must implement this."""
# ----------------- Use Any Database to Store and Retrieve Context ----------------- #
@abstractmethod
def get_similar_question_sql(self, question: str, **kwargs) -> list:
    """
    Return stored question/SQL pairs similar to ``question``.

    Args:
        question (str): The question to find similar question/SQL pairs for.

    Returns:
        list: The similar questions and their SQL statements.
    """

@abstractmethod
def get_related_ddl(self, question: str, **kwargs) -> list:
    """
    Return DDL statements related to ``question``.

    Args:
        question (str): The question to find related DDL for.

    Returns:
        list: The related DDL statements.
    """

@abstractmethod
def get_related_documentation(self, question: str, **kwargs) -> list:
    """
    Return documentation related to ``question``.

    Args:
        question (str): The question to find related documentation for.

    Returns:
        list: The related documentation entries.
    """

@abstractmethod
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
    """
    Store a question together with the SQL that answers it.

    Args:
        question (str): The question to store.
        sql (str): The SQL query for the question.

    Returns:
        str: ID of the stored training-data entry.
    """

@abstractmethod
def add_ddl(self, ddl: str, **kwargs) -> str:
    """
    Store a DDL statement as training data.

    Args:
        ddl (str): The DDL statement to store.

    Returns:
        str: ID of the stored training-data entry.
    """

@abstractmethod
def add_documentation(self, documentation: str, **kwargs) -> str:
    """
    Store documentation as training data.

    Args:
        documentation (str): The documentation to store.

    Returns:
        str: ID of the stored training-data entry.
    """

@abstractmethod
def get_training_data(self, **kwargs) -> pd.DataFrame:
    """
    Example:
    ```python
    vn.get_training_data()
    ```
    Return all training data held by the retrieval layer.

    Returns:
        pd.DataFrame: The training data.
    """

@abstractmethod
def remove_training_data(self, id: str, **kwargs) -> bool:
    """
    Example:
    ```python
    vn.remove_training_data(id="123-ddl")
    ```
    Delete one training-data entry from the retrieval layer.

    Args:
        id (str): ID of the training-data entry to remove.

    Returns:
        bool: True if the entry was removed, False otherwise.
    """
# ----------------- Use Any Language Model API ----------------- #
@abstractmethod
def system_message(self, message: str) -> any:
    """Wrap ``message`` as a system message in the LLM backend's format."""

@abstractmethod
def user_message(self, message: str) -> any:
    """Wrap ``message`` as a user message in the LLM backend's format."""

@abstractmethod
def assistant_message(self, message: str) -> any:
    """Wrap ``message`` as an assistant message in the LLM backend's format."""
def str_to_approx_token_count(self, string: str) -> float:
    """Estimate the token count of ``string`` at ~4 characters per token.

    Returns ``len(string) / 4`` (true division, so the result is a float,
    e.g. 0.75 for a 3-character string).
    """
    return len(string) / 4
def add_ddl_to_prompt(
    self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000
) -> str:
    """Append a "===Tables" section with as many DDL statements as fit the budget.

    Each statement is added only if the prompt's approximate token count plus
    the statement's stays below ``max_tokens``; statements that do not fit are
    skipped, and later ones are still considered. An empty ``ddl_list`` leaves
    the prompt unchanged (no header is added).
    """
    if len(ddl_list) == 0:
        return initial_prompt
    prompt = initial_prompt + "\n===Tables \n"
    for ddl in ddl_list:
        projected = self.str_to_approx_token_count(prompt) + self.str_to_approx_token_count(ddl)
        if projected < max_tokens:
            prompt += f"{ddl}\n\n"
    return prompt
def add_documentation_to_prompt(
    self,
    initial_prompt: str,
    documentation_list: list[str],
    max_tokens: int = 14000,
) -> str:
    """Append an "===Additional Context" section with documentation that fits the budget.

    Each entry is added only if the prompt's approximate token count plus the
    entry's stays below ``max_tokens``; entries that do not fit are skipped,
    and later ones are still considered. An empty list leaves the prompt
    unchanged (no header is added).
    """
    if len(documentation_list) == 0:
        return initial_prompt
    prompt = initial_prompt + "\n===Additional Context \n\n"
    for entry in documentation_list:
        projected = self.str_to_approx_token_count(prompt) + self.str_to_approx_token_count(entry)
        if projected < max_tokens:
            prompt += f"{entry}\n\n"
    return prompt
def add_sql_to_prompt(
    self, initial_prompt: str, sql_list: list[str], max_tokens: int = 14000
) -> str:
    """Append a "===Question-SQL Pairs" section with the pairs that fit the budget.

    ``sql_list`` holds dicts with ``"question"`` and ``"sql"`` keys. Each pair is
    appended as ``question\\nsql\\n\\n`` only if the prompt's approximate token
    count plus the question's and the SQL's stays below ``max_tokens``. Both
    parts are counted because both are appended. Pairs that do not fit are
    skipped, and later ones are still considered. An empty list leaves the
    prompt unchanged (no header is added).
    """
    if len(sql_list) > 0:
        initial_prompt += "\n===Question-SQL Pairs\n\n"
        for pair in sql_list:
            pair_tokens = self.str_to_approx_token_count(
                pair["question"]
            ) + self.str_to_approx_token_count(pair["sql"])
            if self.str_to_approx_token_count(initial_prompt) + pair_tokens < max_tokens:
                initial_prompt += f"{pair['question']}\n{pair['sql']}\n\n"
    return initial_prompt
def get_sql_prompt(
    self,
    initial_prompt: str,
    question: str,
    question_sql_list: list,
    ddl_list: list,
    doc_list: list,
    **kwargs,
):
    """
    Example:
    ```python
    vn.get_sql_prompt(
        question="What are the top 10 customers by sales?",
        question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
        ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
        doc_list=["The customers table contains information about customers and their sales."],
    )
    ```

    Build the message list that asks the LLM to generate SQL.

    The system message holds the role statement, the DDL, the documentation
    (plus ``self.static_documentation`` when set) and the response
    guidelines. Each valid question/SQL pair then becomes a user/assistant
    few-shot exchange, and the actual question is the final user message.

    Args:
        initial_prompt (str | None): Opening system text. ``None`` uses a
            default that names ``self.dialect``.
        question (str): The question to generate SQL for.
        question_sql_list (list): Example pairs with ``"question"`` and
            ``"sql"`` keys. ``None`` or incomplete entries are skipped.
        ddl_list (list): DDL statements.
        doc_list (list): Documentation snippets. This list is not modified.

    Returns:
        list: The messages, in the format produced by ``system_message`` /
        ``user_message`` / ``assistant_message``.
    """
    if initial_prompt is None:
        initial_prompt = (
            f"You are a {self.dialect} expert. "
            + "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
        )

    initial_prompt = self.add_ddl_to_prompt(
        initial_prompt, ddl_list, max_tokens=self.max_tokens
    )

    # Work on a copy: appending to the caller's list would leak
    # static_documentation into it and duplicate it on every call.
    doc_list = list(doc_list or [])
    if self.static_documentation:
        doc_list.append(self.static_documentation)

    initial_prompt = self.add_documentation_to_prompt(
        initial_prompt, doc_list, max_tokens=self.max_tokens
    )

    initial_prompt += (
        "===Response Guidelines \n"
        "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
        "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
        "3. If the provided context is insufficient, please explain why it can't be generated. \n"
        "4. Please use the most relevant table(s). \n"
        "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
        f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
    )

    message_log = [self.system_message(initial_prompt)]
    for example in question_sql_list or []:
        # Skip missing or incomplete few-shot examples.
        if example is not None and "question" in example and "sql" in example:
            message_log.append(self.user_message(example["question"]))
            message_log.append(self.assistant_message(example["sql"]))

    message_log.append(self.user_message(question))
    return message_log
def get_followup_questions_prompt(
    self,
    question: str,
    question_sql_list: list,
    ddl_list: list,
    doc_list: list,
    **kwargs,
) -> list:
    """
    Build the messages asking the LLM for likely follow-up questions.

    The system message restates the user's original question. It then passes
    through the DDL, documentation and question/SQL helpers, in that order,
    each called with ``max_tokens=self.max_tokens``. A single user message
    requests the follow-ups, one per line.

    Args:
        question (str): The question the user originally asked.
        question_sql_list (list): Example question/SQL pairs.
        ddl_list (list): DDL statements.
        doc_list (list): Documentation snippets.

    Returns:
        list: ``[system_message, user_message]``.
    """
    context = f"The user initially asked the question: '{question}': \n\n"
    sections = (
        (self.add_ddl_to_prompt, ddl_list),
        (self.add_documentation_to_prompt, doc_list),
        (self.add_sql_to_prompt, question_sql_list),
    )
    for add_section, items in sections:
        context = add_section(context, items, max_tokens=self.max_tokens)

    instruction = "Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions."
    return [self.system_message(context), self.user_message(instruction)]
@abstractmethod
def submit_prompt(self, prompt, **kwargs) -> str:
    """
    Send a prompt to the LLM and return its reply.

    Example:
    ```python
    vn.submit_prompt(
        [
            vn.system_message("The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."),
            vn.user_message("What are the top 10 customers by sales?"),
        ]
    )
    ```

    Args:
        prompt (any): Typically a list built from ``system_message`` /
            ``user_message`` / ``assistant_message`` results.

    Returns:
        str: The LLM's response text.
    """
def generate_question(self, sql: str, **kwargs) -> str:
    """
    Ask the LLM which business question a SQL query answers.

    Args:
        sql (str): The query to describe.
        **kwargs: Forwarded unchanged to ``submit_prompt``.

    Returns:
        str: Whatever ``submit_prompt`` returns (expected to be the question).
    """
    instructions = "The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."
    messages = [self.system_message(instructions), self.user_message(sql)]
    return self.submit_prompt(messages, **kwargs)
def _extract_python_code(self, markdown_string: str) -> str:
# Strip whitespace to avoid indentation errors in LLM-generated code
markdown_string = markdown_string.strip()
# Regex pattern to match Python code blocks
pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```"
# Find all matches in the markdown string
matches = re.findall(pattern, markdown_string, re.IGNORECASE)
# Extract the Python code from the matches
python_code = []
for match in matches:
python = match[0] if match[0] else match[1]
python_code.append(python.strip())
if len(python_code) == 0:
return markdown_string
return python_code[0]
def _sanitize_plotly_code(self, raw_plotly_code: str) -> str:
# Remove the fig.show() statement from the plotly code
plotly_code = raw_plotly_code.replace("fig.show()", "")
return plotly_code
def generate_plotly_code(
    self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs
) -> str:
    """
    Ask the LLM for plotly code that charts a DataFrame named ``df``.

    Args:
        question (str, optional): The user's question, included for context.
        sql (str, optional): The query that produced the DataFrame.
        df_metadata (str, optional): Description of the DataFrame
            (e.g. its dtypes).
        **kwargs: Forwarded to ``submit_prompt`` as keyword arguments.

    Returns:
        str: The extracted code with ``fig.show()`` calls removed.
    """
    if question is not None:
        system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{question}'"
    else:
        system_msg = "The following is a pandas DataFrame "

    if sql is not None:
        system_msg += f"\n\nThe DataFrame was produced using this query: {sql}\n\n"

    system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{df_metadata}"

    message_log = [
        self.system_message(system_msg),
        self.user_message(
            "Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code."
        ),
    ]

    # Unpack kwargs (as generate_question does); passing `kwargs=kwargs`
    # would hand submit_prompt one opaque "kwargs" argument instead.
    plotly_code = self.submit_prompt(message_log, **kwargs)

    return self._sanitize_plotly_code(self._extract_python_code(plotly_code))
# ----------------- Connect to Any Database to run the Generated SQL ----------------- #
def connect_to_snowflake(
    self,
    account: str,
    username: str,
    password: str,
    database: str,
    role: Union[str, None] = None,
    warehouse: Union[str, None] = None,
    **kwargs,
):
    """
    Connect to Snowflake and install a ``run_sql`` that queries it.

    The placeholder values ``"my-username"``, ``"mypassword"``,
    ``"my-account"`` and ``"my-database"`` are replaced by the
    ``SNOWFLAKE_USERNAME``, ``SNOWFLAKE_PASSWORD``, ``SNOWFLAKE_ACCOUNT`` and
    ``SNOWFLAKE_DATABASE`` environment variables respectively.

    Args:
        account (str): Snowflake account identifier.
        username (str): Login user.
        password (str): Login password.
        database (str): Database selected (``USE DATABASE``) before every query.
        role (str | None): If given, ``USE ROLE`` is issued before every query.
        warehouse (str | None): If given, ``USE WAREHOUSE`` is issued before
            every query.
        **kwargs: Forwarded to ``snowflake.connector.connect``.

    Raises:
        DependencyError: If ``snowflake.connector`` cannot be imported.
        ImproperlyConfigured: If a placeholder is used and the matching
            environment variable is unset.
    """
    try:
        snowflake = __import__("snowflake.connector")
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method, run command:"
            " \npip install vanna[snowflake]"
        )

    if username == "my-username":
        username_env = os.getenv("SNOWFLAKE_USERNAME")
        if username_env is not None:
            username = username_env
        else:
            raise ImproperlyConfigured("Please set your Snowflake username.")

    if password == "mypassword":
        password_env = os.getenv("SNOWFLAKE_PASSWORD")
        if password_env is not None:
            password = password_env
        else:
            raise ImproperlyConfigured("Please set your Snowflake password.")

    if account == "my-account":
        account_env = os.getenv("SNOWFLAKE_ACCOUNT")
        if account_env is not None:
            account = account_env
        else:
            raise ImproperlyConfigured("Please set your Snowflake account.")

    if database == "my-database":
        database_env = os.getenv("SNOWFLAKE_DATABASE")
        if database_env is not None:
            database = database_env
        else:
            raise ImproperlyConfigured("Please set your Snowflake database.")

    conn = snowflake.connector.connect(
        user=username,
        password=password,
        account=account,
        database=database,
        client_session_keep_alive=True,
        **kwargs,
    )

    def run_sql_snowflake(sql: str) -> pd.DataFrame:
        cs = conn.cursor()
        try:
            if role is not None:
                cs.execute(f"USE ROLE {role}")
            if warehouse is not None:
                cs.execute(f"USE WAREHOUSE {warehouse}")
            cs.execute(f"USE DATABASE {database}")

            cur = cs.execute(sql)
            results = cur.fetchall()
            # Create a pandas dataframe from the results
            return pd.DataFrame(results, columns=[desc[0] for desc in cur.description])
        finally:
            # A new cursor is opened per query; release it even on failure.
            cs.close()

    self.dialect = "Snowflake SQL"
    self.run_sql = run_sql_snowflake
    self.run_sql_is_set = True
def connect_to_sqlite(self, url: str, check_same_thread: bool = False, **kwargs):
    """
    Connect to a SQLite database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]

    Args:
        url (str): ``":memory:"`` for an in-memory database, an existing local
            path, or a URL. A URL is downloaded with ``requests`` into the
            current working directory under the URL path's basename; an
            existing file with that name is reused rather than re-downloaded.
        check_same_thread (bool): Passed to ``sqlite3.connect``. The default
            ``False`` allows the connection to be used from multiple threads.
        **kwargs: Forwarded to ``sqlite3.connect``.

    Returns:
        None
    """
    # ":memory:" is a special SQLite name, never a file or URL to fetch.
    if url != ":memory:" and not os.path.exists(url):
        # Path to save the downloaded database
        path = os.path.basename(urlparse(url).path)
        # Download the database only if it doesn't exist locally yet
        if not os.path.exists(path):
            response = requests.get(url)
            response.raise_for_status()  # Check that the request was successful
            with open(path, "wb") as f:
                f.write(response.content)
        url = path

    # Connect to the database
    conn = sqlite3.connect(url, check_same_thread=check_same_thread, **kwargs)

    def run_sql_sqlite(sql: str):
        return pd.read_sql_query(sql, conn)

    self.dialect = "SQLite"
    self.run_sql = run_sql_sqlite
    self.run_sql_is_set = True
def connect_to_postgres(
    self,
    host: str = None,
    dbname: str = None,
    user: str = None,
    password: str = None,
    port: int = None,
    **kwargs,
):
    """
    Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]

    **Example:**
    ```python
    vn.connect_to_postgres(
        host="myhost",
        dbname="mydatabase",
        user="myuser",
        password="mypassword",
        port=5432
    )
    ```

    Missing arguments fall back to the ``HOST``, ``DATABASE``, ``PG_USER``,
    ``PASSWORD`` and ``PORT`` environment variables. Every query opens its
    own connection and closes it afterwards.

    Args:
        host (str): The postgres host.
        dbname (str): The postgres database name.
        user (str): The postgres user.
        password (str): The postgres password.
        port (int): The postgres Port.
        **kwargs: Forwarded to ``psycopg2.connect``.

    Raises:
        DependencyError: If psycopg2 is not installed.
        ImproperlyConfigured: If a setting is missing from both arguments and
            the environment.
        ValidationError: If the initial test connection fails.
    """
    try:
        import psycopg2
        import psycopg2.extras
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method,"
            " run command: \npip install vanna[postgres]"
        )

    if not host:
        host = os.getenv("HOST")
    if not host:
        raise ImproperlyConfigured("Please set your postgres host")

    if not dbname:
        dbname = os.getenv("DATABASE")
    if not dbname:
        raise ImproperlyConfigured("Please set your postgres database")

    if not user:
        user = os.getenv("PG_USER")
    if not user:
        raise ImproperlyConfigured("Please set your postgres user")

    if not password:
        password = os.getenv("PASSWORD")
    if not password:
        raise ImproperlyConfigured("Please set your postgres password")

    if not port:
        port = os.getenv("PORT")
    if not port:
        raise ImproperlyConfigured("Please set your postgres port")

    def connect_to_db():
        return psycopg2.connect(
            host=host,
            dbname=dbname,
            user=user,
            password=password,
            port=port,
            **kwargs,
        )

    # Connect once only to fail fast on bad settings, then release the
    # connection: queries below each open their own.
    try:
        connect_to_db().close()
    except psycopg2.Error as e:
        raise ValidationError(e)

    def _fetch_df(conn, sql: str) -> pd.DataFrame:
        # Execute `sql` on `conn` and return all rows as a DataFrame.
        cs = conn.cursor()
        cs.execute(sql)
        results = cs.fetchall()
        return pd.DataFrame(results, columns=[desc[0] for desc in cs.description])

    def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
        conn = None
        try:
            conn = connect_to_db()  # Initial connection attempt
            return _fetch_df(conn, sql)
        except psycopg2.InterfaceError:
            # Attempt to reconnect and retry the operation once
            if conn:
                conn.close()  # Ensure any existing connection is closed
            conn = connect_to_db()
            return _fetch_df(conn, sql)
        except psycopg2.Error as e:
            if conn:
                conn.rollback()
            raise ValidationError(e)
        except Exception:
            if conn:
                conn.rollback()
            raise
        finally:
            # Each query owns its connection; close it so connections don't leak.
            if conn:
                conn.close()

    self.dialect = "PostgreSQL"
    self.run_sql_is_set = True
    self.run_sql = run_sql_postgres
def connect_to_mysql(
    self,
    host: str = None,
    dbname: str = None,
    user: str = None,
    password: str = None,
    port: int = None,
    **kwargs,
):
    """
    Connect to MySQL with PyMySQL and install a ``run_sql`` that queries it.

    Missing arguments fall back to the ``HOST``, ``DATABASE``, ``USER``,
    ``PASSWORD`` and ``PORT`` environment variables.
    NOTE(review): ``USER`` is commonly set by the shell to the OS login name,
    so it may be picked up unintentionally. Pass ``user`` explicitly.

    Args:
        host (str): MySQL host.
        dbname (str): Database name.
        user (str): MySQL user.
        password (str): MySQL password.
        port (int): MySQL port.
        **kwargs: Forwarded to ``pymysql.connect``.

    Raises:
        DependencyError: If PyMySQL is not installed.
        ImproperlyConfigured: If a setting is missing.
        ValidationError: If connecting fails, or a query raises ``pymysql.Error``.
    """
    try:
        import pymysql.cursors
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method,"
            " run command: \npip install PyMySQL"
        )

    if not host:
        host = os.getenv("HOST")
    if not host:
        raise ImproperlyConfigured("Please set your MySQL host")

    if not dbname:
        dbname = os.getenv("DATABASE")
    if not dbname:
        raise ImproperlyConfigured("Please set your MySQL database")

    if not user:
        user = os.getenv("USER")
    if not user:
        raise ImproperlyConfigured("Please set your MySQL user")

    if not password:
        password = os.getenv("PASSWORD")
    if not password:
        raise ImproperlyConfigured("Please set your MySQL password")

    if not port:
        port = os.getenv("PORT")
    if not port:
        raise ImproperlyConfigured("Please set your MySQL port")

    conn = None
    try:
        conn = pymysql.connect(
            host=host,
            user=user,
            password=password,
            database=dbname,
            port=port,
            cursorclass=pymysql.cursors.DictCursor,
            **kwargs,
        )
    except pymysql.Error as e:
        raise ValidationError(e)

    def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]:
        if conn:
            try:
                # Re-establish the connection if the server dropped it.
                conn.ping(reconnect=True)
                # The cursor context manager closes the cursor even on error.
                with conn.cursor() as cs:
                    cs.execute(sql)
                    results = cs.fetchall()
                    # Create a pandas dataframe from the results
                    return pd.DataFrame(
                        results, columns=[desc[0] for desc in cs.description]
                    )
            except pymysql.Error as e:
                conn.rollback()
                raise ValidationError(e)
            except Exception:
                conn.rollback()
                raise

    # Set the dialect like the other connectors so SQL prompts target MySQL.
    self.dialect = "MySQL"
    self.run_sql_is_set = True
    self.run_sql = run_sql_mysql
def connect_to_clickhouse(
    self,
    host: str = None,
    dbname: str = None,
    user: str = None,
    password: str = None,
    port: int = None,
    **kwargs,
):
    """
    Connect to ClickHouse with ``clickhouse_connect`` and install a ``run_sql``.

    Missing arguments fall back to the ``HOST``, ``DATABASE``, ``USER``,
    ``PASSWORD`` and ``PORT`` environment variables.
    NOTE(review): ``USER`` is commonly the OS login name. Pass ``user``
    explicitly.

    Args:
        host (str): ClickHouse host.
        dbname (str): Database name.
        user (str): ClickHouse user.
        password (str): ClickHouse password.
        port (int): ClickHouse port.
        **kwargs: Forwarded to ``clickhouse_connect.get_client``.

    Raises:
        DependencyError: If ``clickhouse_connect`` is not installed.
        ImproperlyConfigured: If a setting is missing.
        ValidationError: If creating the client fails.
    """
    try:
        import clickhouse_connect
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method,"
            " run command: \npip install clickhouse_connect"
        )

    if not host:
        host = os.getenv("HOST")
    if not host:
        raise ImproperlyConfigured("Please set your ClickHouse host")

    if not dbname:
        dbname = os.getenv("DATABASE")
    if not dbname:
        raise ImproperlyConfigured("Please set your ClickHouse database")

    if not user:
        user = os.getenv("USER")
    if not user:
        raise ImproperlyConfigured("Please set your ClickHouse user")

    if not password:
        password = os.getenv("PASSWORD")
    if not password:
        raise ImproperlyConfigured("Please set your ClickHouse password")

    if not port:
        port = os.getenv("PORT")
    if not port:
        raise ImproperlyConfigured("Please set your ClickHouse port")

    conn = None
    try:
        conn = clickhouse_connect.get_client(
            host=host,
            port=port,
            username=user,
            password=password,
            database=dbname,
            **kwargs,
        )
    except Exception as e:
        raise ValidationError(e)

    def run_sql_clickhouse(sql: str) -> Union[pd.DataFrame, None]:
        if conn:
            result = conn.query(sql)
            # Create a pandas dataframe from the results
            return pd.DataFrame(result.result_rows, columns=result.column_names)

    # Set the dialect like the other connectors so SQL prompts target ClickHouse.
    self.dialect = "ClickHouse SQL"
    self.run_sql_is_set = True
    self.run_sql = run_sql_clickhouse
def connect_to_oracle(
    self, user: str = None, password: str = None, dsn: str = None, **kwargs
):
    """
    Connect to an Oracle db using oracledb package. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]

    **Example:**
    ```python
    vn.connect_to_oracle(
        user="username",
        password="password",
        dsn="host:port/sid",
    )
    ```

    Missing arguments fall back to the ``USER``, ``PASSWORD`` and ``DSN``
    environment variables. A trailing ``;`` is stripped from each query
    before execution.

    Args:
        user (str): Oracle db user name.
        password (str): Oracle db user password.
        dsn (str): Oracle connect string, e.g. ``host:port/sid``.
        **kwargs: Forwarded to ``oracledb.connect``.

    Raises:
        DependencyError: If oracledb is not installed.
        ImproperlyConfigured: If a setting is missing.
        ValidationError: If connecting fails, or a query raises ``oracledb.Error``.
    """
    try:
        import oracledb
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method,"
            " run command: \npip install oracledb"
        )

    if not dsn:
        dsn = os.getenv("DSN")
    if not dsn:
        raise ImproperlyConfigured(
            "Please set your Oracle dsn which should include host:port/sid"
        )

    if not user:
        user = os.getenv("USER")
    if not user:
        raise ImproperlyConfigured("Please set your Oracle db user")

    if not password:
        password = os.getenv("PASSWORD")
    if not password:
        raise ImproperlyConfigured("Please set your Oracle db password")

    conn = None
    try:
        conn = oracledb.connect(user=user, password=password, dsn=dsn, **kwargs)
    except oracledb.Error as e:
        raise ValidationError(e)

    def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]:
        if conn:
            try:
                sql = sql.rstrip()
                # Fix for a known problem with Oracle db where an extra ; will cause an error.
                if sql.endswith(";"):
                    sql = sql[:-1]
                # The cursor context manager closes the cursor even on error.
                with conn.cursor() as cs:
                    cs.execute(sql)
                    results = cs.fetchall()
                    # Create a pandas dataframe from the results
                    return pd.DataFrame(
                        results, columns=[desc[0] for desc in cs.description]
                    )
            except oracledb.Error as e:
                conn.rollback()
                raise ValidationError(e)
            except Exception:
                conn.rollback()
                raise

    # Set the dialect like the other connectors so SQL prompts target Oracle.
    self.dialect = "Oracle SQL"
    self.run_sql_is_set = True
    self.run_sql = run_sql_oracle
def connect_to_bigquery(
    self, cred_file_path: str = None, project_id: str = None, **kwargs
):
    """
    Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]

    **Example:**
    ```python
    vn.connect_to_bigquery(
        project_id="myprojectid",
        cred_file_path="path/to/credentials.json",
    )
    ```

    Without ``cred_file_path``, implicit (environment) credentials are
    tried. If that fails, only a message is printed, and the installed
    ``run_sql`` returns ``None``. Inside Google Colab the user is
    authenticated first.

    Args:
        project_id (str): The gcs project id; falls back to ``PROJECT_ID``.
        cred_file_path (str): Path to a service-account JSON key file.
        **kwargs: Forwarded to ``bigquery.Client`` on both credential paths.

    Raises:
        DependencyError: If the Google Cloud libraries are missing.
        ImproperlyConfigured: If no project id is available, Colab
            authentication fails, or the explicit credentials are rejected.
    """
    try:
        from google.api_core.exceptions import GoogleAPIError
        from google.cloud import bigquery
        from google.oauth2 import service_account
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method, run command:"
            " \npip install vanna[bigquery]"
        )

    if not project_id:
        project_id = os.getenv("PROJECT_ID")
    if not project_id:
        raise ImproperlyConfigured("Please set your Google Cloud Project ID.")

    import sys

    if "google.colab" in sys.modules:
        try:
            from google.colab import auth

            auth.authenticate_user()
        except Exception as e:
            raise ImproperlyConfigured(e)
    else:
        print("Not using Google Colab.")

    conn = None
    if not cred_file_path:
        try:
            # Forward kwargs here too so both credential paths configure the
            # client the same way.
            conn = bigquery.Client(project=project_id, **kwargs)
        except Exception:
            print("Could not found any google cloud implicit credentials")
    else:
        # Validate file path and permissions
        validate_config_path(cred_file_path)
        with open(cred_file_path, "r") as f:
            credentials = service_account.Credentials.from_service_account_info(
                json.load(f),
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        try:
            conn = bigquery.Client(
                project=project_id, credentials=credentials, **kwargs
            )
        except Exception:
            raise ImproperlyConfigured(
                "Could not connect to bigquery please correct credentials"
            )

    def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:
        if conn:
            job = conn.query(sql)
            return job.result().to_dataframe()
        return None

    self.dialect = "BigQuery SQL"
    self.run_sql_is_set = True
    self.run_sql = run_sql_bigquery
def connect_to_duckdb(self, url: str, init_sql: str = None, **kwargs):
    """
    Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]

    Args:
        url (str): Which database to open:
            - ``":memory:"`` or ``""``: an in-memory database.
            - An existing local path: used as-is.
            - A string starting with ``md`` or ``motherduck`` (e.g. ``md:``):
              handed straight to DuckDB for MotherDuck.
            - Anything else: treated as a URL and downloaded into the current
              directory under its basename, unless that file already exists.
        init_sql (str, optional): SQL to run once right after connecting. Defaults to None.
        **kwargs: Forwarded to ``duckdb.connect``.

    Returns:
        None

    Raises:
        DependencyError: If duckdb is not installed.
    """
    try:
        import duckdb
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method,"
            " run command: \npip install vanna[duckdb]"
        )

    if url == ":memory:" or url == "":
        path = ":memory:"
    elif os.path.exists(url) or url.startswith("md") or url.startswith("motherduck"):
        # Local files and MotherDuck connection strings are used unchanged.
        path = url
    else:
        # Path to save the downloaded database
        path = os.path.basename(urlparse(url).path)
        # Download the database if it doesn't exist
        if not os.path.exists(path):
            response = requests.get(url)
            response.raise_for_status()  # Check that the request was successful
            with open(path, "wb") as f:
                f.write(response.content)

    # Connect to the database
    conn = duckdb.connect(path, **kwargs)
    if init_sql:
        conn.query(init_sql)

    def run_sql_duckdb(sql: str):
        return conn.query(sql).to_df()

    self.dialect = "DuckDB SQL"
    self.run_sql = run_sql_duckdb
    self.run_sql_is_set = True
def connect_to_mssql(self, odbc_conn_str: str, **kwargs):
    """
    Connect to a Microsoft SQL Server database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]

    Each query runs inside ``engine.begin()``. That commits when the block
    exits without error and rolls back if it raises.

    Args:
        odbc_conn_str (str): The ODBC connection string.
        **kwargs: Forwarded to ``sqlalchemy.create_engine``.

    Returns:
        None

    Raises:
        DependencyError: If pyodbc or SQLAlchemy is not installed.
    """
    try:
        import pyodbc  # noqa: F401 -- imported only to check the dependency is installed
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method,"
            " run command: pip install pyodbc"
        )

    try:
        import sqlalchemy as sa
        from sqlalchemy.engine import URL
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method,"
            " run command: pip install sqlalchemy"
        )

    connection_url = URL.create(
        "mssql+pyodbc", query={"odbc_connect": odbc_conn_str}
    )

    from sqlalchemy import create_engine

    engine = create_engine(connection_url, **kwargs)

    def run_sql_mssql(sql: str):
        # Execute the SQL statement and return the result as a pandas DataFrame.
        # Let the context manager close the connection: closing it manually
        # inside the block would roll the transaction back instead of committing.
        with engine.begin() as conn:
            return pd.read_sql_query(sa.text(sql), conn)

    self.dialect = "T-SQL / Microsoft SQL Server"
    self.run_sql = run_sql_mssql
    self.run_sql_is_set = True
def connect_to_presto(
    self,
    host: str,
    catalog: str = "hive",
    schema: str = "default",
    user: str = None,
    password: str = None,
    port: int = None,
    combined_pem_path: str = None,
    protocol: str = "https",
    requests_kwargs: dict = None,
    **kwargs,
):
    """
    Connect to a Presto database using the specified parameters.

    Falsy ``host``, ``catalog``, ``user``, ``password`` and ``port`` fall back
    to the ``PRESTO_HOST``, ``PRESTO_CATALOG``, ``PRESTO_USER``,
    ``PRESTO_PASSWORD`` and ``PRESTO_PORT`` environment variables. All but
    the password are then required. A trailing ``;`` is stripped from each
    query.

    Args:
        host (str): The host address of the Presto database.
        catalog (str): The catalog to use in the Presto environment.
        schema (str): The schema to use in the Presto environment.
        user (str): The username for authentication.
        password (str): The password for authentication.
        port (int): The port number for the Presto connection.
        combined_pem_path (str): The path to the combined pem file for SSL
            connection; used as ``verify`` when ``requests_kwargs`` is None.
        protocol (str): The protocol to use for the connection (default is 'https').
        requests_kwargs (dict): Additional keyword arguments for requests.
        **kwargs: Forwarded to ``presto.Connection``.

    Raises:
        DependencyError: If required dependencies are not installed.
        ImproperlyConfigured: If essential configuration settings are missing.
        ValidationError: If connecting or a query raises ``presto.Error``.

    Returns:
        None
    """
    try:
        from pyhive import presto
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method,"
            " run command: \npip install pyhive"
        )

    if not host:
        host = os.getenv("PRESTO_HOST")
    if not host:
        raise ImproperlyConfigured("Please set your presto host")

    if not catalog:
        catalog = os.getenv("PRESTO_CATALOG")
    if not catalog:
        raise ImproperlyConfigured("Please set your presto catalog")

    if not user:
        user = os.getenv("PRESTO_USER")
    if not user:
        raise ImproperlyConfigured("Please set your presto user")

    if not password:
        password = os.getenv("PRESTO_PASSWORD")

    if not port:
        port = os.getenv("PRESTO_PORT")
    if not port:
        raise ImproperlyConfigured("Please set your presto port")

    conn = None
    try:
        if requests_kwargs is None and combined_pem_path is not None:
            # Use the combined PEM file to verify the SSL connection.
            requests_kwargs = {
                "verify": combined_pem_path,
            }
        conn = presto.Connection(
            host=host,
            username=user,
            password=password,
            catalog=catalog,
            schema=schema,
            port=port,
            protocol=protocol,
            requests_kwargs=requests_kwargs,
            **kwargs,
        )
    except presto.Error as e:
        raise ValidationError(e)

    def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
        if conn:
            cs = None
            try:
                sql = sql.rstrip()
                # Fix for a known problem with presto db where an extra ; will cause an error.
                if sql.endswith(";"):
                    sql = sql[:-1]
                cs = conn.cursor()
                cs.execute(sql)
                results = cs.fetchall()
                # Create a pandas dataframe from the results
                return pd.DataFrame(
                    results, columns=[desc[0] for desc in cs.description]
                )
            except presto.Error as e:
                print(e)
                raise ValidationError(e)
            except Exception as e:
                print(e)
                raise e
            finally:
                # Release the per-query cursor even when execution fails.
                if cs is not None:
                    cs.close()

    # Set the dialect like the other connectors so SQL prompts target Presto.
    self.dialect = "Presto SQL"
    self.run_sql_is_set = True
    self.run_sql = run_sql_presto
def connect_to_hive(
    self,
    host: str = None,
    dbname: str = "default",
    user: str = None,
    password: str = None,
    port: int = None,
    auth: str = "CUSTOM",
    **kwargs,
):
    """
    Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]

    Falsy ``host``, ``dbname``, ``user``, ``password`` and ``port`` fall back
    to the ``HIVE_HOST``, ``HIVE_DATABASE``, ``HIVE_USER``, ``HIVE_PASSWORD``
    and ``HIVE_PORT`` environment variables. All but the password are then
    required.

    Args:
        host (str): The host of the Hive database.
        dbname (str): The name of the database to connect to.
        user (str): The username to use for authentication.
        password (str): The password to use for authentication.
        port (int): The port to use for the connection.
        auth (str): The authentication method to use.
        **kwargs: Forwarded to ``hive.Connection``.

    Returns:
        None

    Raises:
        DependencyError: If pyhive is not installed.
        ImproperlyConfigured: If a required setting is missing.
        ValidationError: If connecting or a query raises ``hive.Error``.
    """
    try:
        from pyhive import hive
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method,"
            " run command: \npip install pyhive"
        )

    if not host:
        host = os.getenv("HIVE_HOST")
    if not host:
        raise ImproperlyConfigured("Please set your hive host")

    if not dbname:
        dbname = os.getenv("HIVE_DATABASE")
    if not dbname:
        raise ImproperlyConfigured("Please set your hive database")

    if not user:
        user = os.getenv("HIVE_USER")
    if not user:
        raise ImproperlyConfigured("Please set your hive user")

    if not password:
        password = os.getenv("HIVE_PASSWORD")

    if not port:
        port = os.getenv("HIVE_PORT")
    if not port:
        raise ImproperlyConfigured("Please set your hive port")

    conn = None
    try:
        conn = hive.Connection(
            host=host,
            username=user,
            password=password,
            database=dbname,
            port=port,
            auth=auth,
            # Forward extra options instead of silently dropping them
            # (connect_to_presto does the same).
            **kwargs,
        )
    except hive.Error as e:
        raise ValidationError(e)

    def run_sql_hive(sql: str) -> Union[pd.DataFrame, None]:
        if conn:
            cs = None
            try:
                cs = conn.cursor()
                cs.execute(sql)
                results = cs.fetchall()
                # Create a pandas dataframe from the results
                return pd.DataFrame(
                    results, columns=[desc[0] for desc in cs.description]
                )
            except hive.Error as e:
                print(e)
                raise ValidationError(e)
            except Exception as e:
                print(e)
                raise e
            finally:
                # Release the per-query cursor even when execution fails.
                if cs is not None:
                    cs.close()

    # Set the dialect like the other connectors so SQL prompts target Hive.
    self.dialect = "Hive SQL"
    self.run_sql_is_set = True
    self.run_sql = run_sql_hive
def run_sql(self, sql: str, **kwargs) -> pd.DataFrame:
    """
    Default ``run_sql`` used until a database connection is configured.

    Example:
    ```python
    vn.run_sql("SELECT * FROM my_table")
    ```

    The ``connect_to_*`` helpers replace this with a backend-specific
    function on the instance. Calling this default always fails.

    Args:
        sql (str): The SQL query to run.

    Returns:
        pd.DataFrame: Never returns normally.

    Raises:
        Exception: Always, explaining that a database must be connected first.
    """
    message = (
        "You need to connect to a database first by running vn.connect_to_snowflake(), "
        "vn.connect_to_postgres(), similar function, or manually set vn.run_sql"
    )
    raise Exception(message)
def ask(
    self,
    question: Union[str, None] = None,
    print_results: bool = True,
    auto_train: bool = True,
    visualize: bool = True,  # if False, will not generate plotly code
    allow_llm_to_see_data: bool = False,
) -> Union[
    Tuple[
        Union[str, None],
        Union[pd.DataFrame, None],
        Union[plotly.graph_objs.Figure, None],
    ],
    None,
]:
    """
    **Example:**
    ```python
    vn.ask("What are the top 10 customers by sales?")
    ```

    Ask Vanna.AI a question and get the SQL query that answers it.

    Args:
        question (str): The question to ask. Prompted for on stdin if None.
        print_results (bool): Whether to display the SQL, its results and the
            figure (through IPython when available, else ``print``).
        auto_train (bool): Whether to store the question/SQL pair when the
            query returns at least one row.
        visualize (bool): Whether to generate plotly code and display the plotly figure.
        allow_llm_to_see_data (bool): Forwarded to ``generate_sql``.

    Returns:
        ``(None, None, None)`` if SQL generation fails. Otherwise, when
        ``print_results`` is True, ``None`` unless everything (including the
        figure) succeeds. When False, ``(sql, df_or_None, fig_or_None)``. On
        full success: ``(sql, df, fig)``.
    """
    if question is None:
        question = input("Enter a question: ")

    try:
        sql = self.generate_sql(
            question=question, allow_llm_to_see_data=allow_llm_to_see_data
        )
    except Exception as e:
        print(e)
        return None, None, None

    if print_results:
        try:
            from IPython.display import Code, display

            display(Code(sql))
        except Exception:
            print(sql)

    if self.run_sql_is_set is False:
        print("If you want to run the SQL query, connect to a database first.")

        if print_results:
            return None
        else:
            return sql, None, None

    try:
        df = self.run_sql(sql)

        if print_results:
            try:
                # Rich display when IPython is available, plain print otherwise.
                from IPython.display import display

                display(df)
            except Exception:
                print(df)

        if len(df) > 0 and auto_train:
            self.add_question_sql(question=question, sql=sql)

        # Only generate plotly code if visualize is True
        if visualize:
            try:
                plotly_code = self.generate_plotly_code(
                    question=question,
                    sql=sql,
                    df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
                )
                fig = self.get_plotly_figure(plotly_code=plotly_code, df=df)
                if print_results:
                    try:
                        from IPython.display import Image, display

                        img_bytes = fig.to_image(format="png", scale=2)
                        display(Image(img_bytes))
                    except Exception:
                        fig.show()
            except Exception as e:
                # Print stack trace
                traceback.print_stack()
                traceback.print_exc()
                print("Couldn't run plotly code: ", e)
                if print_results:
                    return None
                else:
                    return sql, df, None
        else:
            return sql, df, None

    except Exception as e:
        print("Couldn't run sql: ", e)
        if print_results:
            return None
        else:
            return sql, None, None

    return sql, df, fig
def train(
    self,
    question: str = None,
    sql: str = None,
    ddl: str = None,
    documentation: str = None,
    plan: TrainingPlan = None,
) -> str:
    """
    **Example:**
    ```python
    vn.train()
    ```

    Add training data to Vanna.AI. Exactly one kind of input is handled per
    call, checked in this order:

    1. ``documentation``: delegates to
       [`vn.add_documentation()`][vanna.base.base.VannaBase.add_documentation].
    2. ``sql`` (with an optional ``question``): delegates to
       [`vn.add_question_sql()`][vanna.base.base.VannaBase.add_question_sql].
       When ``question`` is None, one is generated from the SQL first.
    3. ``ddl``: delegates to [`vn.add_ddl()`][vanna.base.base.VannaBase.add_ddl].
    4. ``plan``: a [`TrainingPlan`][vanna.types.TrainingPlan] (see
       [`vn.get_training_plan_generic()`][vanna.base.base.VannaBase.get_training_plan_generic]).
       Each item is added according to its type, and nothing is returned.

    Args:
        question (str): The question to train on (requires ``sql``).
        sql (str): The SQL query to train on.
        ddl (str): The DDL statement.
        documentation (str): The documentation to train on.
        plan (TrainingPlan): The training plan to train on.

    Raises:
        ValidationError: If ``question`` is given without ``sql``.
    """
    if question and not sql:
        raise ValidationError("Please also provide a SQL query")

    if documentation:
        print("Adding documentation....")
        return self.add_documentation(documentation)

    if sql:
        if question is None:
            question = self.generate_question(sql)
            print("Question generated with sql:", question, "\nAdding SQL...")
        return self.add_question_sql(question=question, sql=sql)

    if ddl:
        print("Adding ddl:", ddl)
        return self.add_ddl(ddl)

    if not plan:
        return None

    for item in plan._plan:
        kind = item.item_type
        if kind == TrainingPlanItem.ITEM_TYPE_DDL:
            self.add_ddl(item.item_value)
        elif kind == TrainingPlanItem.ITEM_TYPE_IS:
            self.add_documentation(item.item_value)
        elif kind == TrainingPlanItem.ITEM_TYPE_SQL:
            self.add_question_sql(question=item.item_name, sql=item.item_value)
def _get_databases(self) -> List[str]:
try:
print("Trying INFORMATION_SCHEMA.DATABASES")
df_databases = self.run_sql("SELECT * FROM INFORMATION_SCHEMA.DATABASES")
except Exception as e:
print(e)
try:
print("Trying SHOW DATABASES")
df_databases = self.run_sql("SHOW DATABASES")
except Exception as e:
print(e)
return []
return df_databases["DATABASE_NAME"].unique().tolist()
def _get_information_schema_tables(self, database: str) -> pd.DataFrame:
df_tables = self.run_sql(f"SELECT * FROM {database}.INFORMATION_SCHEMA.TABLES")
return df_tables
def get_training_plan_generic(self, df) -> TrainingPlan:
    """
    This method is used to generate a training plan from an information schema dataframe.

    It breaks INFORMATION_SCHEMA.COLUMNS up into one documentation item per
    table. Each item holds a markdown table of that table's column
    descriptions, ready to be passed to the LLM.

    The database, schema and table columns are located by case-insensitive
    substring match on ``df``'s column names ("database" or "table_catalog",
    "table_schema", "table_name"). Any column whose name contains
    "column_name", "data_type" or "comment" is also included in the markdown.

    Args:
        df (pd.DataFrame): The dataframe to generate the training plan from.

    Returns:
        TrainingPlan: The training plan.

    Raises:
        IndexError: If no database, schema or table column can be found in ``df``.
    """
    lowered_columns = df.columns.str.lower()

    # For each of the following, we look at the df columns to see if there's a match:
    database_column = df.columns[
        lowered_columns.str.contains("database")
        | lowered_columns.str.contains("table_catalog")
    ].to_list()[0]
    schema_column = df.columns[lowered_columns.str.contains("table_schema")].to_list()[0]
    table_column = df.columns[lowered_columns.str.contains("table_name")].to_list()[0]

    columns = [database_column, schema_column, table_column]
    candidates = ["column_name", "data_type", "comment"]
    matches = lowered_columns.str.contains("|".join(candidates), regex=True)
    columns += df.columns[matches].to_list()

    plan = TrainingPlan([])

    # Filter with boolean masks rather than DataFrame.query(f'... == "{value}"').
    # Interpolating values into a query string breaks on values containing
    # quotes, on non-string values, and on column names that need backticks.
    # Narrowing level by level also avoids rescanning the full frame each time.
    for database in df[database_column].unique().tolist():
        df_database = df[df[database_column] == database]
        for schema in df_database[schema_column].unique().tolist():
            df_schema = df_database[df_database[schema_column] == schema]
            for table in df_schema[table_column].unique().tolist():
                df_columns_filtered_to_table = df_schema[df_schema[table_column] == table]
                doc = f"The following columns are in the {table} table in the {database} database:\n\n"
                doc += df_columns_filtered_to_table[columns].to_markdown()

                plan._plan.append(
                    TrainingPlanItem(
                        item_type=TrainingPlanItem.ITEM_TYPE_IS,
                        item_group=f"{database}.{schema}",
                        item_name=table,
                        item_value=doc,
                    )
                )

    return plan
def get_training_plan_snowflake(
    self,
    filter_databases: Union[List[str], None] = None,
    filter_schemas: Union[List[str], None] = None,
    include_information_schema: bool = False,
    use_historical_queries: bool = True,
) -> TrainingPlan:
    """
    Build a TrainingPlan from the connected Snowflake account.

    Two sources are used:

    1. Query history (when ``use_historical_queries``):
       - Up to 5000 queries are read from ``information_schema.query_history``.
       - Only queries with ``ROWS_PRODUCED > 1`` are kept.
       - When filters are given, a query must also mention one of the filter
         names (case-insensitive substring match).
       - The survivors are sampled down to at most 10. Each becomes a SQL item
         whose question comes from ``generate_question``.
    2. Column metadata: for every database from ``_get_databases`` that is
       allowed by ``filter_databases``, one documentation item per table. The
       item contains a markdown table built from
       ``<database>.INFORMATION_SCHEMA.COLUMNS``.

    Failures while reading query history, a database, or a schema are printed
    and skipped, so one failure does not abort the whole plan.

    Args:
        filter_databases: If given, only these databases (exact name) are
            documented, and history queries must mention one of them.
        filter_schemas: If given, only these schemas (exact name) are
            documented, and history queries must mention one of them.
        include_information_schema: Whether to document INFORMATION_SCHEMA itself.
        use_historical_queries: Whether to mine query history for SQL examples.

    Returns:
        TrainingPlan: The generated plan.

    Raises:
        ImproperlyConfigured: If no database connection has been set up.
    """
    plan = TrainingPlan([])

    if self.run_sql_is_set is False:
        raise ImproperlyConfigured("Please connect to a database first.")

    def _mentions_any(query_texts, needles):
        # Lower-case the needles once instead of once per row. Missing query
        # texts become "" so they simply fail to match; previously NaN made
        # `in` raise and the whole history step was discarded.
        lowered_needles = [needle.lower() for needle in needles]
        return (
            query_texts.fillna("")
            .astype(str)
            .str.lower()
            .apply(lambda text: any(needle in text for needle in lowered_needles))
            .astype(bool)
        )

    if use_historical_queries:
        try:
            print("Trying query history")
            df_history = self.run_sql(
                """ select * from table(information_schema.query_history(result_limit => 5000)) order by start_time"""
            )

            df_history_filtered = df_history.query("ROWS_PRODUCED > 1")
            if filter_databases is not None:
                df_history_filtered = df_history_filtered[
                    _mentions_any(df_history_filtered["QUERY_TEXT"], filter_databases)
                ]

            if filter_schemas is not None:
                df_history_filtered = df_history_filtered[
                    _mentions_any(df_history_filtered["QUERY_TEXT"], filter_schemas)
                ]

            if len(df_history_filtered) > 10:
                df_history_filtered = df_history_filtered.sample(10)

            for query in df_history_filtered["QUERY_TEXT"].dropna().unique().tolist():
                plan._plan.append(
                    TrainingPlanItem(
                        item_type=TrainingPlanItem.ITEM_TYPE_SQL,
                        item_group="",
                        item_name=self.generate_question(query),
                        item_value=query,
                    )
                )

        except Exception as e:
            print(e)

    databases = self._get_databases()

    for database in databases:
        if filter_databases is not None and database not in filter_databases:
            continue

        try:
            df_tables = self._get_information_schema_tables(database=database)

            print(f"Trying INFORMATION_SCHEMA.COLUMNS for {database}")
            df_columns = self.run_sql(
                f"SELECT * FROM {database}.INFORMATION_SCHEMA.COLUMNS"
            )

            for schema in df_tables["TABLE_SCHEMA"].unique().tolist():
                if filter_schemas is not None and schema not in filter_schemas:
                    continue

                if not include_information_schema and schema == "INFORMATION_SCHEMA":
                    continue

                # Boolean masks instead of DataFrame.query(f"... == '{schema}'"):
                # identifiers containing quotes would break the query string.
                df_columns_filtered_to_schema = df_columns[
                    df_columns["TABLE_SCHEMA"] == schema
                ]

                try:
                    tables = (
                        df_columns_filtered_to_schema["TABLE_NAME"].unique().tolist()
                    )

                    for table in tables:
                        df_columns_filtered_to_table = df_columns_filtered_to_schema[
                            df_columns_filtered_to_schema["TABLE_NAME"] == table
                        ]
                        doc = f"The following columns are in the {table} table in the {database} database:\n\n"
                        doc += df_columns_filtered_to_table[
                            [
                                "TABLE_CATALOG",
                                "TABLE_SCHEMA",
                                "TABLE_NAME",
                                "COLUMN_NAME",
                                "DATA_TYPE",
                                "COMMENT",
                            ]
                        ].to_markdown()

                        plan._plan.append(
                            TrainingPlanItem(
                                item_type=TrainingPlanItem.ITEM_TYPE_IS,
                                item_group=f"{database}.{schema}",
                                item_name=table,
                                item_value=doc,
                            )
                        )

                except Exception as e:
                    # Best effort: a bad schema is reported and skipped.
                    print(e)

        except Exception as e:
            # Best effort: an unreadable database is reported and skipped.
            print(e)

    return plan
def get_plotly_figure(
    self, plotly_code: str, df: pd.DataFrame, dark_mode: bool = True
) -> plotly.graph_objs.Figure:
    """
    Build a Plotly figure for ``df`` by running generated Plotly code.

    **Example:**
    ```python
    fig = vn.get_plotly_figure(
        plotly_code="fig = px.bar(df, x='name', y='salary')",
        df=df
    )
    fig.show()
    ```

    ``plotly_code`` runs with ``df``, ``px`` and ``go`` in scope and should
    bind its result to a variable named ``fig``. If the code raises, a chart
    is picked from the column dtypes instead, trying these in order:

    1. A scatter of the first two numeric columns.
    2. A bar chart of the first categorical column against the single numeric
       column.
    3. A pie chart of the first categorical column, if it has fewer than 10
       distinct values.
    4. A line plot of the whole frame.

    Args:
        df (pd.DataFrame): The data to plot.
        plotly_code (str): Python source that assigns a figure to ``fig``.
        dark_mode (bool): Apply the "plotly_dark" template when True.

    Returns:
        plotly.graph_objs.Figure: The figure. Returns None if the code ran
        without raising but did not define ``fig``.
    """

    def _fallback_figure():
        # Choose a chart type from the dtypes of df's columns.
        numeric = df.select_dtypes(include=["number"]).columns.tolist()
        categorical = df.select_dtypes(include=["object", "category"]).columns.tolist()
        if len(numeric) >= 2:
            return px.scatter(df, x=numeric[0], y=numeric[1])
        if len(numeric) == 1 and len(categorical) >= 1:
            return px.bar(df, x=categorical[0], y=numeric[0])
        if len(categorical) >= 1 and df[categorical[0]].nunique() < 10:
            return px.pie(df, names=categorical[0])
        return px.line(df)

    # NOTE(review): exec() runs plotly_code with this module's globals. If the
    # code comes from an LLM or an end user it is untrusted input; sandbox or
    # vet it before it reaches this call.
    namespace = {"df": df, "px": px, "go": go}
    try:
        exec(plotly_code, globals(), namespace)
        fig = namespace.get("fig", None)
    except Exception:
        fig = _fallback_figure()

    if fig is not None and dark_mode:
        fig.update_layout(template="plotly_dark")
    return fig
================================================
FILE: src/vanna/legacy/bedrock/__init__.py
================================================
from .bedrock_converse import Bedrock_Converse
================================================
FILE: src/vanna/legacy/bedrock/bedrock_converse.py
================================================
from ..base import VannaBase
try:
import boto3
from botocore.exceptions import ClientError
except ImportError:
raise ImportError("Please install boto3 and botocore to use Amazon Bedrock models")
class Bedrock_Converse(VannaBase):
    """
    Vanna LLM backend that calls Amazon Bedrock through the Converse API.

    Args:
        client: A Bedrock runtime client exposing ``converse`` (required).
        config (dict): Must contain "modelId". It may also set "temperature"
            (default 0.0) and "max_tokens" (default 500).

    Raises:
        ValueError: If ``client`` or ``config`` is missing, or ``config`` has
            no "modelId".
    """

    def __init__(self, client=None, config=None):
        VannaBase.__init__(self, config=config)

        # default parameters
        self.temperature = 0.0
        self.max_tokens = 500

        if client is None:
            raise ValueError(
                "A valid Bedrock runtime client must be provided to invoke Bedrock models"
            )
        else:
            self.client = client

        if config is None:
            raise ValueError(
                "Config is required with model_id and inference parameters"
            )

        if "modelId" not in config:
            raise ValueError("config must contain a modelId to invoke")
        else:
            self.model = config["modelId"]

        if "temperature" in config:
            self.temperature = config["temperature"]

        if "max_tokens" in config:
            self.max_tokens = config["max_tokens"]

    def system_message(self, message: str) -> dict:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """
        Send a chat-style prompt to Bedrock and return the first text block.

        Args:
            prompt: List of ``{"role", "content"}`` dicts. "system" entries
                are moved into the Converse ``system`` parameter, in order.
                All other entries become Converse messages.

        Returns:
            str: Text of the first content block of the model's reply.

        Raises:
            Exception: Wrapping any botocore ``ClientError`` (chained as cause).
        """
        inference_config = {
            "temperature": self.temperature,
            "maxTokens": self.max_tokens,
        }
        additional_model_fields = {
            "top_p": 1,  # setting top_p value for nucleus sampling
        }

        # Converse takes system prompts as a list of content blocks, so keep
        # every non-empty system message instead of only the last one.
        system_blocks = []
        conversation = []
        for prompt_message in prompt:
            role = prompt_message["role"]
            content = prompt_message["content"]
            if role == "system":
                if content:
                    system_blocks.append({"text": content})
            else:
                conversation.append({"role": role, "content": [{"text": content}]})

        converse_api_params = {
            "modelId": self.model,
            "messages": conversation,
            "inferenceConfig": inference_config,
            "additionalModelRequestFields": additional_model_fields,
        }

        if system_blocks:
            converse_api_params["system"] = system_blocks

        try:
            response = self.client.converse(**converse_api_params)
            text_content = response["output"]["message"]["content"][0]["text"]
            return text_content
        except ClientError as err:
            message = err.response["Error"]["Message"]
            raise Exception(f"A Bedrock client error occurred: {message}") from err
================================================
FILE: src/vanna/legacy/chromadb/__init__.py
================================================
from .chromadb_vector import ChromaDB_VectorStore
================================================
FILE: src/vanna/legacy/chromadb/chromadb_vector.py
================================================
import json
from typing import List
import chromadb
import pandas as pd
from chromadb.config import Settings
from chromadb.utils import embedding_functions
from ..base import VannaBase
from ..utils import deterministic_uuid
# Module-level default embedding function, created once at import time.
# ChromaDB_VectorStore uses it whenever its config does not supply an
# "embedding_function".
default_ef = embedding_functions.DefaultEmbeddingFunction()
class ChromaDB_VectorStore(VannaBase):
    """
    Vanna vector store backed by ChromaDB.

    Training data lives in three Chroma collections: "sql", "ddl" and
    "documentation". Ids carry a "-sql", "-ddl" or "-doc" suffix so that
    ``remove_training_data`` can route a delete to the right collection.

    Config keys read:
        - "path"
        - "embedding_function"
        - "client": "persistent", "in-memory", or a chromadb Client instance
        - "collection_metadata"
        - "n_results"
        - "n_results_sql", "n_results_documentation", "n_results_ddl"
    """

    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)
        if config is None:
            config = {}

        path = config.get("path", ".")
        self.embedding_function = config.get("embedding_function", default_ef)
        curr_client = config.get("client", "persistent")
        # Stored on the instance so remove_collection can recreate a collection
        # with the same metadata (e.g. distance settings) it was created with.
        self.collection_metadata = config.get("collection_metadata", None)
        self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
        self.n_results_documentation = config.get(
            "n_results_documentation", config.get("n_results", 10)
        )
        self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))

        if curr_client == "persistent":
            self.chroma_client = chromadb.PersistentClient(
                path=path, settings=Settings(anonymized_telemetry=False)
            )
        elif curr_client == "in-memory":
            self.chroma_client = chromadb.EphemeralClient(
                settings=Settings(anonymized_telemetry=False)
            )
        elif isinstance(curr_client, chromadb.api.client.Client):
            # allow providing client directly
            self.chroma_client = curr_client
        else:
            raise ValueError(f"Unsupported client was set in config: {curr_client}")

        self.documentation_collection = self._get_or_create_collection("documentation")
        self.ddl_collection = self._get_or_create_collection("ddl")
        self.sql_collection = self._get_or_create_collection("sql")

    def _get_or_create_collection(self, name: str):
        """Open (creating if needed) collection ``name`` with this store's embedding function and metadata."""
        return self.chroma_client.get_or_create_collection(
            name=name,
            embedding_function=self.embedding_function,
            metadata=self.collection_metadata,
        )

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        embedding = self.embedding_function([data])
        if len(embedding) == 1:
            return embedding[0]
        return embedding

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
            },
            ensure_ascii=False,
        )
        id = deterministic_uuid(question_sql_json) + "-sql"
        self.sql_collection.add(
            documents=question_sql_json,
            embeddings=self.generate_embedding(question_sql_json),
            ids=id,
        )

        return id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        id = deterministic_uuid(ddl) + "-ddl"
        self.ddl_collection.add(
            documents=ddl,
            embeddings=self.generate_embedding(ddl),
            ids=id,
        )
        return id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        id = deterministic_uuid(documentation) + "-doc"
        self.documentation_collection.add(
            documents=documentation,
            embeddings=self.generate_embedding(documentation),
            ids=id,
        )
        return id

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """
        Return all stored training data as one DataFrame.

        The frame has columns "id", "question", "content" and
        "training_data_type" ("sql", "ddl" or "documentation"). DDL and
        documentation rows have a None question.
        """
        sql_data = self.sql_collection.get()

        df = pd.DataFrame()

        if sql_data is not None:
            # SQL documents are JSON objects holding the question and the SQL.
            documents = [json.loads(doc) for doc in sql_data["documents"]]
            ids = sql_data["ids"]

            df_sql = pd.DataFrame(
                {
                    "id": ids,
                    "question": [doc["question"] for doc in documents],
                    "content": [doc["sql"] for doc in documents],
                }
            )

            df_sql["training_data_type"] = "sql"

            df = pd.concat([df, df_sql])

        ddl_data = self.ddl_collection.get()

        if ddl_data is not None:
            documents = list(ddl_data["documents"])
            ids = ddl_data["ids"]

            df_ddl = pd.DataFrame(
                {
                    "id": ids,
                    "question": [None for doc in documents],
                    "content": list(documents),
                }
            )

            df_ddl["training_data_type"] = "ddl"

            df = pd.concat([df, df_ddl])

        doc_data = self.documentation_collection.get()

        if doc_data is not None:
            documents = list(doc_data["documents"])
            ids = doc_data["ids"]

            df_doc = pd.DataFrame(
                {
                    "id": ids,
                    "question": [None for doc in documents],
                    "content": list(documents),
                }
            )

            df_doc["training_data_type"] = "documentation"

            df = pd.concat([df, df_doc])

        return df

    def remove_training_data(self, id: str, **kwargs) -> bool:
        if id.endswith("-sql"):
            self.sql_collection.delete(ids=id)
            return True
        elif id.endswith("-ddl"):
            self.ddl_collection.delete(ids=id)
            return True
        elif id.endswith("-doc"):
            self.documentation_collection.delete(ids=id)
            return True
        else:
            return False

    def remove_collection(self, collection_name: str) -> bool:
        """
        This function can reset the collection to empty state.

        The collection is deleted and recreated with the same embedding
        function and collection metadata it was originally created with.

        Args:
            collection_name (str): sql or ddl or documentation

        Returns:
            bool: True if collection is deleted, False otherwise
        """
        attribute_by_collection = {
            "sql": "sql_collection",
            "ddl": "ddl_collection",
            "documentation": "documentation_collection",
        }
        attribute = attribute_by_collection.get(collection_name)
        if attribute is None:
            return False

        self.chroma_client.delete_collection(name=collection_name)
        setattr(self, attribute, self._get_or_create_collection(collection_name))
        return True

    @staticmethod
    def _extract_documents(query_results) -> list:
        """
        Static method to extract the documents from the results of a query.

        Args:
            query_results: A Chroma query result mapping (or None).

        Returns:
            List[str] or None: The extracted documents, or an empty list or
            single document if an error occurred.
        """
        if query_results is None:
            return []

        if "documents" in query_results:
            documents = query_results["documents"]

            if len(documents) == 1 and isinstance(documents[0], list):
                try:
                    documents = [json.loads(doc) for doc in documents[0]]
                except Exception:
                    # Non-JSON documents (DDL/documentation) are returned as-is.
                    return documents[0]

            return documents

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        return ChromaDB_VectorStore._extract_documents(
            self.sql_collection.query(
                query_texts=[question],
                n_results=self.n_results_sql,
            )
        )

    def get_related_ddl(self, question: str, **kwargs) -> list:
        return ChromaDB_VectorStore._extract_documents(
            self.ddl_collection.query(
                query_texts=[question],
                n_results=self.n_results_ddl,
            )
        )

    def get_related_documentation(self, question: str, **kwargs) -> list:
        return ChromaDB_VectorStore._extract_documents(
            self.documentation_collection.query(
                query_texts=[question],
                n_results=self.n_results_documentation,
            )
        )
================================================
FILE: src/vanna/legacy/cohere/__init__.py
================================================
from .cohere_chat import Cohere_Chat
from .cohere_embeddings import Cohere_Embeddings
================================================
FILE: src/vanna/legacy/cohere/cohere_chat.py
================================================
import os
from openai import OpenAI
from ..base import VannaBase
class Cohere_Chat(VannaBase):
    """
    Vanna LLM backend that talks to Cohere through its OpenAI-compatible endpoint.

    Config keys read: "temperature" (default 0.2), "model" (default
    "command-a-03-2025") and "api_key".

    If ``client`` is given it is used as-is. Otherwise an ``openai.OpenAI``
    client is built against https://api.cohere.ai/compatibility/v1. The key
    comes from config "api_key" if present, else from the COHERE_API_KEY
    environment variable.

    Raises:
        ValueError: If no client is given and no API key can be found.
    """

    def __init__(self, client=None, config=None):
        VannaBase.__init__(self, config=config)

        # default parameters - can be overridden using config
        self.temperature = 0.2  # Lower temperature for more precise SQL generation
        self.model = "command-a-03-2025"  # Cohere's default model

        if config is not None:
            if "temperature" in config:
                self.temperature = config["temperature"]
            if "model" in config:
                self.model = config["model"]

        if client is not None:
            self.client = client
            return

        # Check for API key in environment variable
        api_key = os.getenv("COHERE_API_KEY")

        # Check for API key in config (takes precedence over the environment)
        if config is not None and "api_key" in config:
            api_key = config["api_key"]

        # Validate API key
        if not api_key:
            raise ValueError(
                "Cohere API key is required. Please provide it via config or set the COHERE_API_KEY environment variable."
            )

        # Initialize client with validated API key
        self.client = OpenAI(
            base_url="https://api.cohere.ai/compatibility/v1",
            api_key=api_key,
        )

    def system_message(self, message: str) -> dict:
        return {
            "role": "developer",
            "content": message,
        }  # Cohere uses 'developer' for system role

    def user_message(self, message: str) -> dict:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """
        Send ``prompt`` to Cohere's chat completions endpoint.

        Args:
            prompt: Non-empty list of ``{"role", "content"}`` message dicts.
            **kwargs: "model" may override the configured model.

        Returns:
            str: Content of the first choice's message.

        Raises:
            Exception: If the prompt is None/empty, or if the request fails or
                the response is malformed (the original error is chained).
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Count the number of tokens in the message log
        # Use 4 as an approximation for the number of characters per token
        # (messages with None content count as empty).
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"] or "") / 4

        # Use model from kwargs, config, or default
        model = kwargs.get("model", self.model)
        if self.config is not None and "model" in self.config and model == self.model:
            model = self.config["model"]

        print(f"Using model {model} for {num_tokens} tokens (approx)")
        try:
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                temperature=self.temperature,
            )

            # Check if response has expected structure
            if not response or not hasattr(response, "choices") or not response.choices:
                raise ValueError("Received empty or malformed response from API")

            if not response.choices[0] or not hasattr(response.choices[0], "message"):
                raise ValueError("Response is missing expected 'message' field")

            if not hasattr(response.choices[0].message, "content"):
                raise ValueError("Response message is missing expected 'content' field")

            return response.choices[0].message.content

        except Exception as e:
            # Log the error and raise a more informative exception, keeping the
            # original as the cause so its traceback is not lost.
            error_msg = f"Error processing Cohere chat response: {str(e)}"
            print(error_msg)
            raise Exception(error_msg) from e
================================================
FILE: src/vanna/legacy/cohere/cohere_embeddings.py
================================================
import os
from openai import OpenAI
from ..base import VannaBase
class Cohere_Embeddings(VannaBase):
    """
    Embedding backend using Cohere's OpenAI-compatible embeddings endpoint.

    Config keys read: "model" (default "embed-multilingual-v3.0") and
    "api_key".

    If ``client`` is given it is used as-is. Otherwise an ``openai.OpenAI``
    client is built against https://api.cohere.ai/compatibility/v1. The key
    comes from config "api_key" if present, else from the COHERE_API_KEY
    environment variable.

    Raises:
        ValueError: If no client is given and no API key can be found.
    """

    def __init__(self, client=None, config=None):
        VannaBase.__init__(self, config=config)

        # Default embedding model
        self.model = "embed-multilingual-v3.0"

        if config is not None and "model" in config:
            self.model = config["model"]

        if client is not None:
            self.client = client
            return

        # Check for API key in environment variable
        api_key = os.getenv("COHERE_API_KEY")

        # Check for API key in config (takes precedence over the environment)
        if config is not None and "api_key" in config:
            api_key = config["api_key"]

        # Validate API key
        if not api_key:
            raise ValueError(
                "Cohere API key is required. Please provide it via config or set the COHERE_API_KEY environment variable."
            )

        # Initialize client with validated API key
        self.client = OpenAI(
            base_url="https://api.cohere.ai/compatibility/v1",
            api_key=api_key,
        )

    def generate_embedding(self, data: str, **kwargs) -> list[float]:
        """
        Embed ``data`` with the configured Cohere model.

        Args:
            data: Non-empty text to embed.
            **kwargs: "model" may override the configured model.

        Returns:
            list[float]: The embedding vector.

        Raises:
            ValueError: If ``data`` is empty.
            Exception: If the request fails or the response is malformed
                (the original error is chained).
        """
        if not data:
            raise ValueError("Cannot generate embedding for empty input data")

        # Use model from kwargs, config, or default
        model = kwargs.get("model", self.model)
        if self.config is not None and "model" in self.config and model == self.model:
            model = self.config["model"]

        try:
            embedding = self.client.embeddings.create(
                model=model,
                input=data,
                encoding_format="float",  # Ensure we get float values
            )

            # Check if response has expected structure
            if not embedding or not hasattr(embedding, "data") or not embedding.data:
                raise ValueError(
                    "Received empty or malformed embedding response from API"
                )

            if not embedding.data[0] or not hasattr(embedding.data[0], "embedding"):
                raise ValueError(
                    "Embedding response is missing expected 'embedding' field"
                )

            if not embedding.data[0].embedding:
                raise ValueError("Received empty embedding vector")

            return embedding.data[0].embedding

        except Exception as e:
            # Log the error and raise a more informative exception, keeping the
            # original as the cause so its traceback is not lost.
            error_msg = f"Error generating embedding with Cohere: {str(e)}"
            print(error_msg)
            raise Exception(error_msg) from e
================================================
FILE: src/vanna/legacy/deepseek/__init__.py
================================================
from .deepseek_chat import DeepSeekChat
================================================
FILE: src/vanna/legacy/deepseek/deepseek_chat.py
================================================
import os
from openai import OpenAI
from ..base import VannaBase
# from vanna.chromadb import ChromaDB_VectorStore
# class DeepSeekVanna(ChromaDB_VectorStore, DeepSeekChat):
# def __init__(self, config=None):
# ChromaDB_VectorStore.__init__(self, config=config)
# DeepSeekChat.__init__(self, config=config)
# vn = DeepSeekVanna(config={"api_key": "sk-************", "model": "deepseek-chat"})
class DeepSeekChat(VannaBase):
    """
    Vanna LLM backend for DeepSeek via its OpenAI-compatible API.

    Args:
        config (dict): Must contain "api_key" and "model".

    Raises:
        ValueError: If ``config`` is missing or lacks "api_key" / "model".
    """

    def __init__(self, config=None):
        if config is None:
            raise ValueError(
                "For DeepSeek, config must be provided with an api_key and model"
            )

        if "api_key" not in config:
            raise ValueError("config must contain a DeepSeek api_key")

        if "model" not in config:
            raise ValueError("config must contain a DeepSeek model")

        # Initialise the shared VannaBase state (config, dialect, ...) like the
        # other LLM backends do, so DeepSeekChat also works when it is not
        # mixed with a vector store that already called VannaBase.__init__.
        VannaBase.__init__(self, config=config)

        api_key = config["api_key"]
        model = config["model"]
        self.model = model
        self.client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com/v1")

    def system_message(self, message: str) -> dict:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        return {"role": "assistant", "content": message}

    def generate_sql(self, question: str, **kwargs) -> str:
        """Generate SQL via the parent implementation, then un-escape underscores."""
        # Use the parent class's generate_sql
        sql = super().generate_sql(question, **kwargs)

        # Replace "\_" with "_" (presumably the model escapes underscores
        # markdown-style, which is invalid in SQL identifiers).
        if isinstance(sql, str):
            sql = sql.replace("\\_", "_")

        return sql

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send ``prompt`` to the configured DeepSeek model and return the first reply's content."""
        chat_response = self.client.chat.completions.create(
            model=self.model,
            messages=prompt,
        )

        return chat_response.choices[0].message.content
================================================
FILE: src/vanna/legacy/exceptions/__init__.py
================================================
class ImproperlyConfigured(Exception):
    """Raised when Vanna is set up incorrectly (e.g. no database connected)."""


class DependencyError(Exception):
    """Raised when a required optional dependency is not installed."""


# NOTE: this name shadows Python's builtin ConnectionError in any module that
# imports it; unlike the builtin it derives directly from Exception.
class ConnectionError(Exception):
    """Raised for connection-related failures."""


class OTPCodeError(Exception):
    """Raised when a one-time password is invalid or could not be sent."""


class SQLRemoveError(Exception):
    """Raised when SQL could not be removed."""


class ExecutionError(Exception):
    """Raised when code could not be executed."""


class ValidationError(Exception):
    """Raised when input fails validation."""


class APIError(Exception):
    """Raised for API errors."""
================================================
FILE: src/vanna/legacy/faiss/__init__.py
================================================
from .faiss import FAISS
================================================
FILE: src/vanna/legacy/faiss/faiss.py
================================================
import os
import json
import uuid
from typing import List, Dict, Any
import faiss
import numpy as np
import pandas as pd
from ..base import VannaBase
from ..exceptions import DependencyError
class FAISS(VannaBase):
    """
    Vanna vector store backed by three FAISS indexes (SQL, DDL, documentation).

    Each index is paired with a metadata list: the i-th vector in an index
    belongs to the i-th entry of its metadata list, so the two are always
    updated together.

    Storage modes, selected by ``config["client"]``:
        - "persistent" (default): indexes and metadata are read from and
          written to ``config["path"]``.
        - "in-memory": nothing is read from or written to disk.
        - A list of three ``faiss.Index`` objects ``[sql, ddl, doc]``: those
          indexes are used as-is.
    """

    # Public collection names (used by remove_collection) -> the prefix used
    # for the matching attributes (``<prefix>_index`` / ``<prefix>_metadata``)
    # and files.
    _COLLECTION_PREFIXES = {"sql": "sql", "ddl": "ddl", "documentation": "doc"}

    def __init__(self, config=None):
        if config is None:
            config = {}

        VannaBase.__init__(self, config=config)

        try:
            import faiss
        except ImportError:
            raise DependencyError(
                "FAISS is not installed. Please install it with 'pip install faiss-cpu' or 'pip install faiss-gpu'"
            )

        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise DependencyError(
                "SentenceTransformer is not installed. Please install it with 'pip install sentence-transformers'."
            )

        self.path = config.get("path", ".")
        self.embedding_dim = config.get("embedding_dim", 384)
        self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
        self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
        self.n_results_documentation = config.get(
            "n_results_documentation", config.get("n_results", 10)
        )
        self.curr_client = config.get("client", "persistent")

        if self.curr_client == "persistent":
            self.sql_index = self._load_or_create_index("sql_index.faiss")
            self.ddl_index = self._load_or_create_index("ddl_index.faiss")
            self.doc_index = self._load_or_create_index("doc_index.faiss")
        elif self.curr_client == "in-memory":
            self.sql_index = faiss.IndexFlatL2(self.embedding_dim)
            self.ddl_index = faiss.IndexFlatL2(self.embedding_dim)
            self.doc_index = faiss.IndexFlatL2(self.embedding_dim)
        elif (
            isinstance(self.curr_client, list)
            and len(self.curr_client) == 3
            and all(isinstance(idx, faiss.Index) for idx in self.curr_client)
        ):
            self.sql_index = self.curr_client[0]
            self.ddl_index = self.curr_client[1]
            self.doc_index = self.curr_client[2]
        else:
            raise ValueError(
                f"Unsupported storage type was set in config: {self.curr_client}"
            )

        self.sql_metadata: List[Dict[str, Any]] = self._load_or_create_metadata(
            "sql_metadata.json"
        )
        self.ddl_metadata: List[Dict[str, str]] = self._load_or_create_metadata(
            "ddl_metadata.json"
        )
        self.doc_metadata: List[Dict[str, str]] = self._load_or_create_metadata(
            "doc_metadata.json"
        )

        model_name = config.get("embedding_model", "all-MiniLM-L6-v2")
        self.embedding_model = SentenceTransformer(model_name)

    def _load_or_create_index(self, filename):
        """Read a FAISS index from ``path`` if present, else create an empty IndexFlatL2."""
        filepath = os.path.join(self.path, filename)
        if os.path.exists(filepath):
            return faiss.read_index(filepath)
        return faiss.IndexFlatL2(self.embedding_dim)

    def _load_or_create_metadata(self, filename):
        """Read a metadata list from ``path`` if present, else return []."""
        # An in-memory store always starts with empty indexes. Loading metadata
        # left on disk by an earlier persistent run would attach stale entries
        # to the wrong (or missing) vectors.
        if self.curr_client == "in-memory":
            return []
        filepath = os.path.join(self.path, filename)
        if os.path.exists(filepath):
            with open(filepath, "r") as f:
                return json.load(f)
        return []

    def _save_index(self, index, filename):
        if self.curr_client == "persistent":
            filepath = os.path.join(self.path, filename)
            faiss.write_index(index, filepath)

    def _save_metadata(self, metadata, filename):
        if self.curr_client == "persistent":
            filepath = os.path.join(self.path, filename)
            with open(filepath, "w") as f:
                json.dump(metadata, f)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        embedding = self.embedding_model.encode(data)
        assert embedding.shape[0] == self.embedding_dim, (
            f"Embedding dimension mismatch: expected {self.embedding_dim}, got {embedding.shape[0]}"
        )
        return embedding.tolist()

    def _add_to_index(self, index, metadata_list, text, extra_metadata=None) -> str:
        embedding = self.generate_embedding(text)
        index.add(np.array([embedding], dtype=np.float32))
        entry_id = str(uuid.uuid4())
        metadata_list.append({"id": entry_id, **(extra_metadata or {})})
        return entry_id

    @staticmethod
    def _entry_text(prefix: str, metadata: Dict[str, Any]) -> str:
        """Rebuild the exact text that the add_* methods embedded for an entry."""
        if prefix == "sql":
            return metadata["question"] + " " + metadata["sql"]
        if prefix == "ddl":
            return metadata["ddl"]
        return metadata["documentation"]

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        entry_id = self._add_to_index(
            self.sql_index,
            self.sql_metadata,
            question + " " + sql,
            {"question": question, "sql": sql},
        )
        self._save_index(self.sql_index, "sql_index.faiss")
        self._save_metadata(self.sql_metadata, "sql_metadata.json")
        return entry_id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        entry_id = self._add_to_index(
            self.ddl_index, self.ddl_metadata, ddl, {"ddl": ddl}
        )
        self._save_index(self.ddl_index, "ddl_index.faiss")
        self._save_metadata(self.ddl_metadata, "ddl_metadata.json")
        return entry_id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        entry_id = self._add_to_index(
            self.doc_index,
            self.doc_metadata,
            documentation,
            {"documentation": documentation},
        )
        self._save_index(self.doc_index, "doc_index.faiss")
        self._save_metadata(self.doc_metadata, "doc_metadata.json")
        return entry_id

    def _get_similar(self, index, metadata_list, text, n_results) -> list:
        """Return metadata for up to ``n_results`` nearest neighbours of ``text``."""
        embedding = self.generate_embedding(text)
        _, neighbours = index.search(
            np.array([embedding], dtype=np.float32), k=n_results
        )
        # FAISS pads the result with -1 when the index holds fewer than k
        # vectors. Those must be dropped: metadata_list[-1] would silently
        # return the last entry again.
        return [
            metadata_list[i] for i in neighbours[0] if 0 <= i < len(metadata_list)
        ]

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        return self._get_similar(
            self.sql_index, self.sql_metadata, question, self.n_results_sql
        )

    def get_related_ddl(self, question: str, **kwargs) -> list:
        return [
            metadata["ddl"]
            for metadata in self._get_similar(
                self.ddl_index, self.ddl_metadata, question, self.n_results_ddl
            )
        ]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        return [
            metadata["documentation"]
            for metadata in self._get_similar(
                self.doc_index,
                self.doc_metadata,
                question,
                self.n_results_documentation,
            )
        ]

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        sql_data = pd.DataFrame(self.sql_metadata)
        sql_data["training_data_type"] = "sql"

        ddl_data = pd.DataFrame(self.ddl_metadata)
        ddl_data["training_data_type"] = "ddl"

        doc_data = pd.DataFrame(self.doc_metadata)
        doc_data["training_data_type"] = "documentation"

        return pd.concat([sql_data, ddl_data, doc_data], ignore_index=True)

    def _rebuild_index(self, prefix: str) -> None:
        """Re-embed every remaining entry of one collection into a fresh IndexFlatL2 and persist it."""
        metadata_list = getattr(self, f"{prefix}_metadata")
        new_index = faiss.IndexFlatL2(self.embedding_dim)
        # Re-embed the same text add_* embedded, so search results stay
        # consistent with entries that were never touched.
        embeddings = [
            self.generate_embedding(self._entry_text(prefix, metadata))
            for metadata in metadata_list
        ]
        if embeddings:
            new_index.add(np.array(embeddings, dtype=np.float32))
        setattr(self, f"{prefix}_index", new_index)
        self._save_index(new_index, f"{prefix}_index.faiss")
        self._save_metadata(metadata_list, f"{prefix}_metadata.json")

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete the entry with ``id`` from whichever collection holds it."""
        for prefix in ("sql", "ddl", "doc"):
            metadata_list = getattr(self, f"{prefix}_metadata")
            for position, item in enumerate(metadata_list):
                if item["id"] == id:
                    del metadata_list[position]
                    self._rebuild_index(prefix)
                    return True
        return False

    def remove_collection(self, collection_name: str) -> bool:
        """
        Reset one collection ("sql", "ddl" or "documentation") to empty.

        Returns:
            bool: True if the collection was reset, False for an unknown name.
        """
        prefix = self._COLLECTION_PREFIXES.get(collection_name)
        if prefix is None:
            return False
        new_index = faiss.IndexFlatL2(self.embedding_dim)
        setattr(self, f"{prefix}_index", new_index)
        setattr(self, f"{prefix}_metadata", [])
        self._save_index(new_index, f"{prefix}_index.faiss")
        self._save_metadata([], f"{prefix}_metadata.json")
        return True
================================================
FILE: src/vanna/legacy/flask/__init__.py
================================================
import json
import logging
import os
import sys
import uuid
from abc import ABC, abstractmethod
from functools import wraps
import importlib.metadata
import flask
import requests
from flasgger import Swagger
from flask import Flask, Response, jsonify, request, send_from_directory
from flask_sock import Sock
from ..base import VannaBase
from .assets import css_content, html_content, js_content
from .auth import AuthInterface, NoAuth
class Cache(ABC):
    """
    Abstract key/field store used by the Flask app to keep per-request state.

    Entries are addressed by an id (from ``generate_id``) and a field name.
    Subclasses must implement all five methods.
    """

    @abstractmethod
    def generate_id(self, *args, **kwargs):
        """Return a new unique id for a cache entry."""

    @abstractmethod
    def get(self, id, field):
        """Return the value stored for ``field`` of entry ``id``."""

    @abstractmethod
    def get_all(self, field_list) -> list:
        """Return every entry, restricted to the fields in ``field_list``."""

    @abstractmethod
    def set(self, id, field, value):
        """Store ``value`` under ``field`` of entry ``id``."""

    @abstractmethod
    def delete(self, id):
        """Remove entry ``id``."""
class MemoryCache(Cache):
    """
    In-process Cache backed by a nested dict of the form ``{id: {field: value}}``.

    Data lives only as long as the instance; nothing is persisted.
    """

    def __init__(self):
        self.cache = {}

    def generate_id(self, *args, **kwargs):
        # Arguments are ignored; each call returns a fresh random UUID4 string.
        return str(uuid.uuid4())

    def set(self, id, field, value):
        self.cache.setdefault(id, {})[field] = value

    def get(self, id, field):
        # Missing ids and missing fields both yield None.
        return self.cache.get(id, {}).get(field)

    def get_all(self, field_list) -> list:
        rows = []
        for entry_id in self.cache:
            row = {"id": entry_id}
            for field in field_list:
                row[field] = self.get(id=entry_id, field=field)
            rows.append(row)
        return rows

    def delete(self, id):
        self.cache.pop(id, None)
class VannaFlaskAPI:
    """Flask JSON API that exposes a Vanna instance over HTTP.

    Every route is registered on ``self.flask_app`` inside ``__init__``. The
    state of one question (question text, SQL, result DataFrame, chart code,
    summary, ...) is kept in ``self.cache`` under the id that
    ``generate_sql`` / ``get_function`` return to the client; later endpoints
    look that state up again through ``requires_cache``.
    """

    # Replaced by a per-instance Flask application in __init__.
    flask_app = None

    @staticmethod
    def _json_body() -> dict:
        """Return the current request's JSON body as a dict.

        Returns an empty dict when the body is missing, is not JSON, or is not
        a JSON object. Handlers can then report their own "No ... provided"
        error instead of Flask aborting with 415 (or ``None.get`` raising)
        first.
        """
        body = flask.request.get_json(silent=True)
        return body if isinstance(body, dict) else {}

    def requires_cache(self, required_fields, optional_fields=None):
        """Decorator factory that injects cached fields into a view function.

        The cache id is read from the ``id`` query parameter, falling back to
        an ``id`` key in the JSON body. Every field in ``required_fields`` must
        be present (non-None) in the cache, otherwise a JSON error is
        returned. Fields in ``optional_fields`` are passed as None when
        missing. The wrapped view receives all field values plus ``id`` as
        keyword arguments.

        Args:
            required_fields: Cache fields that must exist for the request.
            optional_fields: Cache fields passed through even when missing.
        """
        if optional_fields is None:
            optional_fields = []

        def decorator(f):
            @wraps(f)
            def decorated(*args, **kwargs):
                id = request.args.get("id")
                if id is None:
                    id = self._json_body().get("id")
                if id is None:
                    return jsonify({"type": "error", "error": "No id provided"})

                field_values = {}
                for field in required_fields:
                    value = self.cache.get(id=id, field=field)
                    if value is None:
                        return jsonify({"type": "error", "error": f"No {field} found"})
                    field_values[field] = value
                for field in optional_fields:
                    field_values[field] = self.cache.get(id=id, field=field)

                field_values["id"] = id
                return f(*args, **field_values, **kwargs)

            return decorated

        return decorator

    def requires_auth(self, f):
        """Decorator that rejects requests from users who are not logged in.

        The user is resolved with ``self.auth.get_user``. If
        ``self.auth.is_logged_in`` is false, the login form is returned as a
        ``not_logged_in`` JSON response. Otherwise the view is called with the
        user as the ``user`` keyword argument.
        """

        @wraps(f)
        def decorated(*args, **kwargs):
            user = self.auth.get_user(flask.request)
            if not self.auth.is_logged_in(user):
                return jsonify(
                    {"type": "not_logged_in", "html": self.auth.login_form()}
                )
            return f(*args, user=user, **kwargs)

        return decorated

    def __init__(
        self,
        vn: VannaBase,
        cache: "Cache | None" = None,
        auth: AuthInterface = NoAuth(),
        debug=True,
        allow_llm_to_see_data=False,
        chart=True,
    ):
        """
        Expose a Flask API that can be used to interact with a Vanna instance.

        Args:
            vn: The Vanna instance to interact with.
            cache: The cache to use. Defaults to None, which creates a new
                in-memory MemoryCache for this instance. You can also pass in a
                custom cache that implements the Cache interface.
            auth: The authentication method to use. Defaults to NoAuth, which doesn't require authentication. You can also pass in a custom authentication method that implements the AuthInterface interface.
            debug: Show the debug console. Defaults to True.
            allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False.
            chart: Whether to show the chart output in the UI. Defaults to True.

        Returns:
            None
        """
        self.flask_app = Flask(__name__)
        self.swagger = Swagger(
            self.flask_app, template={"info": {"title": "Vanna API"}}
        )
        self.sock = Sock(self.flask_app)
        self.ws_clients = []
        self.vn = vn
        self.auth = auth
        # Build the default cache per instance: a MemoryCache() default
        # argument is evaluated once and would be shared by every API object.
        self.cache = cache if cache is not None else MemoryCache()
        self.debug = debug
        self.allow_llm_to_see_data = allow_llm_to_see_data
        self.chart = chart
        self.config = {
            "debug": debug,
            "allow_llm_to_see_data": allow_llm_to_see_data,
            "chart": chart,
        }

        werkzeug_log = logging.getLogger("werkzeug")
        werkzeug_log.setLevel(logging.ERROR)

        if "google.colab" in sys.modules:
            self.debug = False
            print(
                "Google Colab doesn't support running websocket servers. Disabling debug mode."
            )

        if self.debug:

            def log(message, title="Info"):
                # Snapshot the list: sock_log adds/removes clients from other
                # request threads while this loop may be running.
                clients = list(self.ws_clients)
                if not clients:
                    return
                payload = json.dumps({"message": message, "title": title})
                for ws in clients:
                    try:
                        ws.send(payload)
                    except Exception:
                        # Debug logging is best-effort: a socket that closed
                        # mid-request must not fail the API call being logged.
                        pass

            self.vn.log = log

        @self.flask_app.route("/api/v0/get_config", methods=["GET"])
        @self.requires_auth
        def get_config(user: any):
            """
            Get the configuration for a user
            ---
            parameters:
              - name: user
                in: query
            responses:
              200:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                      default: config
                    config:
                      type: object
            """
            config = self.auth.override_config_for_user(user, self.config)
            return jsonify({"type": "config", "config": config})

        @self.flask_app.route("/api/v0/generate_questions", methods=["GET"])
        @self.requires_auth
        def generate_questions(user: any):
            """
            Generate questions
            ---
            parameters:
              - name: user
                in: query
            responses:
              200:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                      default: question_list
                    questions:
                      type: array
                      items:
                        type: string
                    header:
                      type: string
                      default: Here are some questions you can ask
            """
            # The bundled "chinook" demo model gets a fixed list of questions.
            if hasattr(self.vn, "_model") and self.vn._model == "chinook":
                return jsonify(
                    {
                        "type": "question_list",
                        "questions": [
                            "What are the top 10 artists by sales?",
                            "What are the total sales per year by country?",
                            "Who is the top selling artist in each genre? Show the sales numbers.",
                            "How do the employees rank in terms of sales performance?",
                            "Which 5 cities have the most customers?",
                        ],
                        "header": "Here are some questions you can ask:",
                    }
                )

            training_data = self.vn.get_training_data()

            if training_data is None or len(training_data) == 0:
                return jsonify(
                    {
                        "type": "error",
                        "error": "No training data found. Please add some training data first.",
                    }
                )

            try:
                # Only training rows that actually carry a question are useful.
                with_questions = training_data[training_data["question"].notnull()]
                if len(with_questions) > 0:
                    # DataFrame.sample raises when asked for more rows than
                    # exist, so cap the sample at the available count.
                    questions = (
                        with_questions.sample(min(5, len(with_questions)))["question"]
                        .tolist()
                    )
                    return jsonify(
                        {
                            "type": "question_list",
                            "questions": questions,
                            "header": "Here are some questions you can ask",
                        }
                    )
            except Exception:
                # Best effort (e.g. no "question" column): fall through to the
                # generic prompt below.
                pass

            return jsonify(
                {
                    "type": "question_list",
                    "questions": [],
                    "header": "Go ahead and ask a question",
                }
            )

        @self.flask_app.route("/api/v0/generate_sql", methods=["GET"])
        @self.requires_auth
        def generate_sql(user: any):
            """
            Generate SQL from a question
            ---
            parameters:
              - name: user
                in: query
              - name: question
                in: query
                type: string
                required: true
            responses:
              200:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                      default: sql
                    id:
                      type: string
                    text:
                      type: string
            """
            question = flask.request.args.get("question")

            if question is None:
                return jsonify({"type": "error", "error": "No question provided"})

            id = self.cache.generate_id(question=question)
            sql = self.vn.generate_sql(
                question=question, allow_llm_to_see_data=self.allow_llm_to_see_data
            )

            self.cache.set(id=id, field="question", value=question)
            self.cache.set(id=id, field="sql", value=sql)

            # Output that is not valid SQL (e.g. an LLM explanation) is sent as text.
            response_type = "sql" if self.vn.is_sql_valid(sql=sql) else "text"
            return jsonify(
                {
                    "type": response_type,
                    "id": id,
                    "text": sql,
                }
            )

        @self.flask_app.route("/api/v0/generate_rewritten_question", methods=["GET"])
        @self.requires_auth
        def generate_rewritten_question(user: any):
            """
            Generate a rewritten question
            ---
            parameters:
              - name: last_question
                in: query
                type: string
                required: true
              - name: new_question
                in: query
                type: string
                required: true
            """
            last_question = flask.request.args.get("last_question")
            new_question = flask.request.args.get("new_question")

            rewritten_question = self.vn.generate_rewritten_question(
                last_question, new_question
            )

            return jsonify(
                {"type": "rewritten_question", "question": rewritten_question}
            )

        @self.flask_app.route("/api/v0/get_function", methods=["GET"])
        @self.requires_auth
        def get_function(user: any):
            """
            Get a function from a question
            ---
            parameters:
              - name: user
                in: query
              - name: question
                in: query
                type: string
                required: true
            responses:
              200:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                      default: function
                    id:
                      type: object
                    function:
                      type: string
            """
            question = flask.request.args.get("question")

            if question is None:
                return jsonify({"type": "error", "error": "No question provided"})

            if not hasattr(self.vn, "get_function"):
                return jsonify(
                    {
                        "type": "error",
                        "error": "This setup does not support function generation.",
                    }
                )

            id = self.cache.generate_id(question=question)
            function = self.vn.get_function(question=question)

            if function is None:
                return jsonify({"type": "error", "error": "No function found"})

            if "instantiated_sql" not in function:
                self.vn.log(f"No instantiated SQL found for {question} in {function}")
                return jsonify({"type": "error", "error": "No instantiated SQL found"})

            self.cache.set(id=id, field="question", value=question)
            self.cache.set(id=id, field="sql", value=function["instantiated_sql"])

            # Pre-seed the chart code so generate_plotly_figure can reuse it.
            post_processing_code = function.get("instantiated_post_processing_code")
            if post_processing_code is not None and len(post_processing_code) > 0:
                self.cache.set(
                    id=id,
                    field="plotly_code",
                    value=post_processing_code,
                )

            return jsonify(
                {
                    "type": "function",
                    "id": id,
                    "function": function,
                }
            )

        @self.flask_app.route("/api/v0/get_all_functions", methods=["GET"])
        @self.requires_auth
        def get_all_functions(user: any):
            """
            Get all the functions
            ---
            parameters:
              - name: user
                in: query
            responses:
              200:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                      default: functions
                    functions:
                      type: array
            """
            if not hasattr(self.vn, "get_all_functions"):
                return jsonify(
                    {
                        "type": "error",
                        "error": "This setup does not support function generation.",
                    }
                )

            functions = self.vn.get_all_functions()

            return jsonify(
                {
                    "type": "functions",
                    "functions": functions,
                }
            )

        @self.flask_app.route("/api/v0/run_sql", methods=["GET"])
        @self.requires_auth
        @self.requires_cache(["sql"])
        def run_sql(user: any, id: str, sql: str):
            """
            Run SQL
            ---
            parameters:
              - name: user
                in: query
              - name: id
                in: query|body
                type: string
                required: true
            responses:
              200:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                      default: df
                    id:
                      type: string
                    df:
                      type: object
                    should_generate_chart:
                      type: boolean
            """
            try:
                if not self.vn.run_sql_is_set:
                    return jsonify(
                        {
                            "type": "error",
                            "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.",
                        }
                    )

                df = self.vn.run_sql(sql=sql)

                self.cache.set(id=id, field="df", value=df)

                # Only a preview (first 10 rows) is returned; the full frame
                # stays in the cache for CSV download, charts and summaries.
                return jsonify(
                    {
                        "type": "df",
                        "id": id,
                        "df": df.head(10).to_json(orient="records", date_format="iso"),
                        "should_generate_chart": self.chart
                        and self.vn.should_generate_chart(df),
                    }
                )

            except Exception as e:
                return jsonify({"type": "sql_error", "error": str(e)})

        @self.flask_app.route("/api/v0/fix_sql", methods=["POST"])
        @self.requires_auth
        @self.requires_cache(["question", "sql"])
        def fix_sql(user: any, id: str, question: str, sql: str):
            """
            Fix SQL
            ---
            parameters:
              - name: user
                in: query
              - name: id
                in: query|body
                type: string
                required: true
              - name: error
                in: body
                type: string
                required: true
            responses:
              200:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                      default: sql
                    id:
                      type: string
                    text:
                      type: string
            """
            error = self._json_body().get("error")

            if error is None:
                return jsonify({"type": "error", "error": "No error provided"})

            fix_question = f"I have an error: {error}\n\nHere is the SQL I tried to run: {sql}\n\nThis is the question I was trying to answer: {question}\n\nCan you rewrite the SQL to fix the error?"

            # Honour the same data-visibility setting as /generate_sql.
            fixed_sql = self.vn.generate_sql(
                question=fix_question,
                allow_llm_to_see_data=self.allow_llm_to_see_data,
            )

            self.cache.set(id=id, field="sql", value=fixed_sql)

            return jsonify(
                {
                    "type": "sql",
                    "id": id,
                    "text": fixed_sql,
                }
            )

        @self.flask_app.route("/api/v0/update_sql", methods=["POST"])
        @self.requires_auth
        @self.requires_cache([])
        def update_sql(user: any, id: str):
            """
            Update SQL
            ---
            parameters:
              - name: user
                in: query
              - name: id
                in: query|body
                type: string
                required: true
              - name: sql
                in: body
                type: string
                required: true
            responses:
              200:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                      default: sql
                    id:
                      type: string
                    text:
                      type: string
            """
            sql = self._json_body().get("sql")

            if sql is None:
                return jsonify({"type": "error", "error": "No sql provided"})

            self.cache.set(id=id, field="sql", value=sql)

            return jsonify(
                {
                    "type": "sql",
                    "id": id,
                    "text": sql,
                }
            )

        @self.flask_app.route("/api/v0/download_csv", methods=["GET"])
        @self.requires_auth
        @self.requires_cache(["df"])
        def download_csv(user: any, id: str, df):
            """
            Download CSV
            ---
            parameters:
              - name: user
                in: query
              - name: id
                in: query|body
                type: string
                required: true
            responses:
              200:
                description: download CSV
            """
            csv = df.to_csv()

            return Response(
                csv,
                mimetype="text/csv",
                headers={"Content-disposition": f"attachment; filename={id}.csv"},
            )

        @self.flask_app.route("/api/v0/generate_plotly_figure", methods=["GET"])
        @self.requires_auth
        @self.requires_cache(["df", "question", "sql"])
        def generate_plotly_figure(user: any, id: str, df, question, sql):
            """
            Generate plotly figure
            ---
            parameters:
              - name: user
                in: query
              - name: id
                in: query|body
                type: string
                required: true
              - name: chart_instructions
                in: query
                type: string
            responses:
              200:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                      default: plotly_figure
                    id:
                      type: string
                    fig:
                      type: object
            """
            chart_instructions = flask.request.args.get("chart_instructions")

            try:
                if chart_instructions:
                    # Custom instructions always require fresh chart code.
                    question = f"{question}. When generating the chart, use these special instructions: {chart_instructions}"
                    code = None
                else:
                    # Reuse cached chart code (e.g. seeded by get_function).
                    code = self.cache.get(id=id, field="plotly_code")

                if code is None:
                    code = self.vn.generate_plotly_code(
                        question=question,
                        sql=sql,
                        df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
                    )
                    self.cache.set(id=id, field="plotly_code", value=code)

                fig = self.vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
                fig_json = fig.to_json()

                self.cache.set(id=id, field="fig_json", value=fig_json)

                return jsonify(
                    {
                        "type": "plotly_figure",
                        "id": id,
                        "fig": fig_json,
                    }
                )
            except Exception as e:
                # Print the stack trace
                import traceback

                traceback.print_stack()
                traceback.print_exc()

                return jsonify({"type": "error", "error": str(e)})

        @self.flask_app.route("/api/v0/get_training_data", methods=["GET"])
        @self.requires_auth
        def get_training_data(user: any):
            """
            Get all training data
            ---
            parameters:
              - name: user
                in: query
            responses:
              200:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                      default: df
                    id:
                      type: string
                      default: training_data
                    df:
                      type: object
            """
            df = self.vn.get_training_data()

            if df is None or len(df) == 0:
                return jsonify(
                    {
                        "type": "error",
                        "error": "No training data found. Please add some training data first.",
                    }
                )

            return jsonify(
                {
                    "type": "df",
                    "id": "training_data",
                    "df": df.to_json(orient="records"),
                }
            )

        @self.flask_app.route("/api/v0/remove_training_data", methods=["POST"])
        @self.requires_auth
        def remove_training_data(user: any):
            """
            Remove training data
            ---
            parameters:
              - name: user
                in: query
              - name: id
                in: body
                type: string
                required: true
            responses:
              200:
                schema:
                  type: object
                  properties:
                    success:
                      type: boolean
            """
            # Get id from the JSON body
            id = self._json_body().get("id")

            if id is None:
                return jsonify({"type": "error", "error": "No id provided"})

            if self.vn.remove_training_data(id=id):
                return jsonify({"success": True})
            return jsonify(
                {"type": "error", "error": "Couldn't remove training data"}
            )

        @self.flask_app.route("/api/v0/train", methods=["POST"])
        @self.requires_auth
        def add_training_data(user: any):
            """
            Add training data
            ---
            parameters:
              - name: user
                in: query
              - name: question
                in: body
                type: string
              - name: sql
                in: body
                type: string
              - name: ddl
                in: body
                type: string
              - name: documentation
                in: body
                type: string
            responses:
              200:
                schema:
                  type: object
                  properties:
                    id:
                      type: string
            """
            body = self._json_body()
            question = body.get("question")
            sql = body.get("sql")
            ddl = body.get("ddl")
            documentation = body.get("documentation")

            try:
                id = self.vn.train(
                    question=question, sql=sql, ddl=ddl, documentation=documentation
                )

                return jsonify({"id": id})
            except Exception as e:
                print("TRAINING ERROR", e)
                return jsonify({"type": "error", "error": str(e)})

        @self.flask_app.route("/api/v0/create_function", methods=["GET"])
        @self.requires_auth
        @self.requires_cache(["question", "sql"])
        def create_function(user: any, id: str, question: str, sql: str):
            """
            Create function
            ---
            parameters:
              - name: user
                in: query
              - name: id
                in: query|body
                type: string
                required: true
            responses:
              200:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                      default: function_template
                    id:
                      type: string
                    function_template:
                      type: object
            """
            plotly_code = self.cache.get(id=id, field="plotly_code")

            if plotly_code is None:
                plotly_code = ""

            function_data = self.vn.create_function(
                question=question, sql=sql, plotly_code=plotly_code
            )

            return jsonify(
                {
                    "type": "function_template",
                    "id": id,
                    "function_template": function_data,
                }
            )

        @self.flask_app.route("/api/v0/update_function", methods=["POST"])
        @self.requires_auth
        def update_function(user: any):
            """
            Update function
            ---
            parameters:
              - name: user
                in: query
              - name: old_function_name
                in: body
                type: string
                required: true
              - name: updated_function
                in: body
                type: object
                required: true
            responses:
              200:
                schema:
                  type: object
                  properties:
                    success:
                      type: boolean
            """
            body = self._json_body()
            old_function_name = body.get("old_function_name")
            updated_function = body.get("updated_function")

            print("old_function_name", old_function_name)
            print("updated_function", updated_function)

            updated = self.vn.update_function(
                old_function_name=old_function_name, updated_function=updated_function
            )

            return jsonify({"success": updated})

        @self.flask_app.route("/api/v0/delete_function", methods=["POST"])
        @self.requires_auth
        def delete_function(user: any):
            """
            Delete function
            ---
            parameters:
              - name: user
                in: query
              - name: function_name
                in: body
                type: string
                required: true
            responses:
              200:
                schema:
                  type: object
                  properties:
                    success:
                      type: boolean
            """
            function_name = self._json_body().get("function_name")

            return jsonify(
                {"success": self.vn.delete_function(function_name=function_name)}
            )

        @self.flask_app.route("/api/v0/generate_followup_questions", methods=["GET"])
        @self.requires_auth
        @self.requires_cache(["df", "question", "sql"])
        def generate_followup_questions(user: any, id: str, df, question, sql):
            """
            Generate followup questions
            ---
            parameters:
              - name: user
                in: query
              - name: id
                in: query|body
                type: string
                required: true
            responses:
              200:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                      default: question_list
                    questions:
                      type: array
                      items:
                        type: string
                    header:
                      type: string
            """
            if not self.allow_llm_to_see_data:
                self.cache.set(id=id, field="followup_questions", value=[])
                return jsonify(
                    {
                        "type": "question_list",
                        "id": id,
                        "questions": [],
                        "header": "Followup Questions can be enabled if you set allow_llm_to_see_data=True",
                    }
                )

            followup_questions = self.vn.generate_followup_questions(
                question=question, sql=sql, df=df
            )
            # Show at most five suggestions.
            if followup_questions is not None and len(followup_questions) > 5:
                followup_questions = followup_questions[:5]

            self.cache.set(id=id, field="followup_questions", value=followup_questions)

            return jsonify(
                {
                    "type": "question_list",
                    "id": id,
                    "questions": followup_questions,
                    "header": "Here are some potential followup questions:",
                }
            )

        @self.flask_app.route("/api/v0/generate_summary", methods=["GET"])
        @self.requires_auth
        @self.requires_cache(["df", "question"])
        def generate_summary(user: any, id: str, df, question):
            """
            Generate summary
            ---
            parameters:
              - name: user
                in: query
              - name: id
                in: query|body
                type: string
                required: true
            responses:
              200:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                      default: text
                    id:
                      type: string
                    text:
                      type: string
            """
            if not self.allow_llm_to_see_data:
                return jsonify(
                    {
                        "type": "text",
                        "id": id,
                        "text": "Summarization can be enabled if you set allow_llm_to_see_data=True",
                    }
                )

            summary = self.vn.generate_summary(question=question, df=df)

            self.cache.set(id=id, field="summary", value=summary)

            return jsonify(
                {
                    "type": "text",
                    "id": id,
                    "text": summary,
                }
            )

        @self.flask_app.route("/api/v0/load_question", methods=["GET"])
        @self.requires_auth
        @self.requires_cache(
            ["question", "sql", "df"], optional_fields=["summary", "fig_json"]
        )
        def load_question(user: any, id: str, question, sql, df, fig_json, summary):
            """
            Load question
            ---
            parameters:
              - name: user
                in: query
              - name: id
                in: query|body
                type: string
                required: true
            responses:
              200:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                      default: question_cache
                    id:
                      type: string
                    question:
                      type: string
                    sql:
                      type: string
                    df:
                      type: object
                    fig:
                      type: object
                    summary:
                      type: string
            """
            try:
                return jsonify(
                    {
                        "type": "question_cache",
                        "id": id,
                        "question": question,
                        "sql": sql,
                        "df": df.head(10).to_json(orient="records", date_format="iso"),
                        "fig": fig_json,
                        "summary": summary,
                    }
                )

            except Exception as e:
                return jsonify({"type": "error", "error": str(e)})

        @self.flask_app.route("/api/v0/get_question_history", methods=["GET"])
        @self.requires_auth
        def get_question_history(user: any):
            """
            Get question history
            ---
            parameters:
              - name: user
                in: query
            responses:
              200:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                      default: question_history
                    questions:
                      type: array
                      items:
                        type: string
            """
            # Use self.cache (not the constructor argument, which may be None).
            return jsonify(
                {
                    "type": "question_history",
                    "questions": self.cache.get_all(field_list=["question"]),
                }
            )

        # The <path:...> converter supplies the catch_all argument; without it
        # Flask would call the view with no arguments and raise a TypeError.
        @self.flask_app.route("/api/v0/<path:catch_all>", methods=["GET", "POST"])
        def catch_all(catch_all):
            return jsonify(
                {"type": "error", "error": "The rest of the API is not ported yet."}
            )

        if self.debug:

            @self.sock.route("/api/v0/log")
            def sock_log(ws):
                # Register the socket so the debug log() above broadcasts to it.
                self.ws_clients.append(ws)

                try:
                    while True:
                        # Incoming messages are ignored; reading keeps the socket open.
                        ws.receive()
                finally:
                    self.ws_clients.remove(ws)

    def run(self, *args, **kwargs):
        """
        Run the Flask app.

        Explicit arguments are forwarded unchanged to Flask's ``run``. Without
        arguments, the app is served on 0.0.0.0:8084. Inside Google Colab the
        port is also exposed through a Colab proxy window.

        Args:
            *args: Arguments to pass to Flask's run method.
            **kwargs: Keyword arguments to pass to Flask's run method.

        Returns:
            None
        """
        if args or kwargs:
            self.flask_app.run(*args, **kwargs)
            return

        try:
            from google.colab import output

            output.serve_kernel_port_as_window(8084)
            from google.colab.output import eval_js

            print("Your app is running at:")
            print(eval_js("google.colab.kernel.proxyPort(8084)"))
        except Exception:
            # Not in Colab (or the Colab helpers failed): assume a local run.
            print("Your app is running at:")
            print("http://localhost:8084")

        # NOTE(review): debug=self.debug also turns on Flask/Werkzeug debug
        # mode while listening on all interfaces — confirm this is intended
        # outside local development.
        self.flask_app.run(
            host="0.0.0.0", port=8084, debug=self.debug, use_reloader=False
        )
class VannaFlaskApp(VannaFlaskAPI):
    """VannaFlaskAPI plus the bundled web UI, auth routes and static assets."""

    def __init__(
        self,
        vn: VannaBase,
        cache: "Cache | None" = None,
        auth: AuthInterface = NoAuth(),
        debug=True,
        allow_llm_to_see_data=False,
        logo="https://img.vanna.ai/vanna-flask.svg",
        title="Welcome to Vanna.AI",
        subtitle="Your AI-powered copilot for SQL queries.",
        show_training_data=True,
        suggested_questions=True,
        sql=True,
        table=True,
        csv_download=True,
        chart=True,
        redraw_chart=True,
        auto_fix_sql=True,
        ask_results_correct=True,
        followup_questions=True,
        summarization=True,
        function_generation=True,
        index_html_path=None,
        assets_folder=None,
    ):
        """
        Expose a Flask app that can be used to interact with a Vanna instance.

        Args:
            vn: The Vanna instance to interact with.
            cache: The cache to use. Defaults to None, which creates a new in-memory MemoryCache for this app. You can also pass in a custom cache that implements the Cache interface.
            auth: The authentication method to use. Defaults to NoAuth, which doesn't require authentication. You can also pass in a custom authentication method that implements the AuthInterface interface.
            debug: Show the debug console. Defaults to True.
            allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False.
            logo: The logo to display in the UI. Defaults to the Vanna logo.
            title: The title to display in the UI. Defaults to "Welcome to Vanna.AI".
            subtitle: The subtitle to display in the UI. Defaults to "Your AI-powered copilot for SQL queries.".
            show_training_data: Whether to show the training data in the UI. Defaults to True.
            suggested_questions: Whether to show suggested questions in the UI. Defaults to True.
            sql: Whether to show the SQL input in the UI. Defaults to True.
            table: Whether to show the table output in the UI. Defaults to True.
            csv_download: Whether to allow downloading the table output as a CSV file. Defaults to True.
            chart: Whether to show the chart output in the UI. Defaults to True.
            redraw_chart: Whether to allow redrawing the chart. Defaults to True.
            auto_fix_sql: Whether to allow auto-fixing SQL errors. Defaults to True.
            ask_results_correct: Whether to ask the user if the results are correct. Defaults to True.
            followup_questions: Whether to show followup questions. Defaults to True.
            summarization: Whether to show summarization. Defaults to True.
            function_generation: Whether to enable function generation. Only takes effect when vn has a get_function method. Defaults to True.
            index_html_path: Path to the index.html. Defaults to None, which will use the default index.html
            assets_folder: The location where you'd like to serve the static assets from. Defaults to None, which will use hardcoded Python variables.

        Returns:
            None
        """
        if cache is None:
            # A fresh cache per app: a MemoryCache() default argument would be
            # evaluated once and shared by every app instance.
            cache = MemoryCache()
        super().__init__(vn, cache, auth, debug, allow_llm_to_see_data, chart)

        try:
            version = importlib.metadata.version("vanna")
        except importlib.metadata.PackageNotFoundError:
            # e.g. running from a source checkout that was never installed.
            version = "unknown"

        self.config.update(
            {
                "logo": logo,
                "title": title,
                "subtitle": subtitle,
                "show_training_data": show_training_data,
                "suggested_questions": suggested_questions,
                "sql": sql,
                "table": table,
                "csv_download": csv_download,
                "chart": chart,
                "redraw_chart": redraw_chart,
                "auto_fix_sql": auto_fix_sql,
                "ask_results_correct": ask_results_correct,
                "followup_questions": followup_questions,
                "summarization": summarization,
                "function_generation": function_generation
                and hasattr(vn, "get_function"),
                "version": version,
            }
        )
        self.index_html_path = index_html_path
        self.assets_folder = assets_folder

        @self.flask_app.route("/auth/login", methods=["POST"])
        def login():
            return self.auth.login_handler(flask.request)

        @self.flask_app.route("/auth/callback", methods=["GET"])
        def callback():
            return self.auth.callback_handler(flask.request)

        @self.flask_app.route("/auth/logout", methods=["GET"])
        def logout():
            return self.auth.logout_handler(flask.request)

        # The <path:filename> converter supplies proxy_assets' argument.
        @self.flask_app.route("/assets/<path:filename>")
        def proxy_assets(filename):
            if self.assets_folder:
                return send_from_directory(self.assets_folder, filename)

            # Without an assets folder, serve the bundled CSS/JS strings.
            if filename.endswith(".css"):
                return Response(css_content, mimetype="text/css")

            if filename.endswith(".js"):
                return Response(js_content, mimetype="text/javascript")

            # Return 404
            return "File not found", 404

        # Proxy the /vanna.svg file to the remote server
        @self.flask_app.route("/vanna.svg")
        def proxy_vanna_svg():
            remote_url = "https://vanna.ai/img/vanna.svg"
            try:
                # Bounded wait so an unreachable host cannot hang the worker.
                response = requests.get(remote_url, stream=True, timeout=10)
            except requests.RequestException:
                return "Error fetching file from remote server", 502

            # Check if the request to the remote URL was successful
            if response.status_code != 200:
                return "Error fetching file from remote server", response.status_code

            # Hop-by-hop / encoding headers no longer match the re-served body.
            excluded_headers = [
                "content-encoding",
                "content-length",
                "transfer-encoding",
                "connection",
            ]
            headers = [
                (name, value)
                for (name, value) in response.raw.headers.items()
                if name.lower() not in excluded_headers
            ]
            return Response(response.content, response.status_code, headers)

        # Any other path serves the single-page UI so client-side routes work.
        @self.flask_app.route("/", defaults={"path": ""})
        @self.flask_app.route("/<path:path>")
        def hello(path: str):
            if self.index_html_path:
                directory = os.path.dirname(self.index_html_path)
                filename = os.path.basename(self.index_html_path)
                return send_from_directory(directory=directory, path=filename)
            return html_content
================================================
FILE: src/vanna/legacy/flask/assets.py
================================================
# Fallback HTML page body: the legacy Flask app's index route (`hello`)
# returns this string when no custom `index_html_path` was configured.
html_content = "\nVanna.AI\n"
css_content = """.nav-title{font-family:Roboto Slab,serif}*,:before,:after{box-sizing:border-box;border-width:0;border-style:solid;border-color:#e5e7eb}:before,:after{--tw-content: ""}html{line-height:1.5;-webkit-text-size-adjust:100%;-moz-tab-size:4;-o-tab-size:4;tab-size:4;font-family:ui-sans-serif,system-ui,-apple-system,BlinkMacSystemFont,Segoe UI,Roboto,Helvetica Neue,Arial,Noto Sans,sans-serif,"Apple Color Emoji","Segoe UI Emoji",Segoe UI Symbol,"Noto Color Emoji";font-feature-settings:normal;font-variation-settings:normal}body{margin:0;line-height:inherit}hr{height:0;color:inherit;border-top-width:1px}abbr:where([title]){-webkit-text-decoration:underline dotted;text-decoration:underline dotted}h1,h2,h3,h4,h5,h6{font-size:inherit;font-weight:inherit}a{color:inherit;text-decoration:inherit}b,strong{font-weight:bolder}code,kbd,samp,pre{font-family:ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,monospace;font-size:1em}small{font-size:80%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sub{bottom:-.25em}sup{top:-.5em}table{text-indent:0;border-color:inherit;border-collapse:collapse}button,input,optgroup,select,textarea{font-family:inherit;font-feature-settings:inherit;font-variation-settings:inherit;font-size:100%;font-weight:inherit;line-height:inherit;color:inherit;margin:0;padding:0}button,select{text-transform:none}button,[type=button],[type=reset],[type=submit]{-webkit-appearance:button;background-color:transparent;background-image:none}:-moz-focusring{outline:auto}:-moz-ui-invalid{box-shadow:none}progress{vertical-align:baseline}::-webkit-inner-spin-button,::-webkit-outer-spin-button{height:auto}[type=search]{-webkit-appearance:textfield;outline-offset:-2px}::-webkit-search-decoration{-webkit-appearance:none}::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}summary{display:list-item}blockquote,dl,dd,h1,h2,h3,h4,h5,h6,hr,figure,p,pre{margin:0}fieldset{margin:0;padding:0}legend{padd
ing:0}ol,ul,menu{list-style:none;margin:0;padding:0}dialog{padding:0}textarea{resize:vertical}input::-moz-placeholder,textarea::-moz-placeholder{opacity:1;color:#9ca3af}input::placeholder,textarea::placeholder{opacity:1;color:#9ca3af}button,[role=button]{cursor:pointer}:disabled{cursor:default}img,svg,video,canvas,audio,iframe,embed,object{display:block;vertical-align:middle}img,video{max-width:100%;height:auto}[hidden]{display:none}*,:before,:after{--tw-border-spacing-x: 0;--tw-border-spacing-y: 0;--tw-translate-x: 0;--tw-translate-y: 0;--tw-rotate: 0;--tw-skew-x: 0;--tw-skew-y: 0;--tw-scale-x: 1;--tw-scale-y: 1;--tw-pan-x: ;--tw-pan-y: ;--tw-pinch-zoom: ;--tw-scroll-snap-strictness: proximity;--tw-gradient-from-position: ;--tw-gradient-via-position: ;--tw-gradient-to-position: ;--tw-ordinal: ;--tw-slashed-zero: ;--tw-numeric-figure: ;--tw-numeric-spacing: ;--tw-numeric-fraction: ;--tw-ring-inset: ;--tw-ring-offset-width: 0px;--tw-ring-offset-color: #fff;--tw-ring-color: rgb(59 130 246 / .5);--tw-ring-offset-shadow: 0 0 #0000;--tw-ring-shadow: 0 0 #0000;--tw-shadow: 0 0 #0000;--tw-shadow-colored: 0 0 #0000;--tw-blur: ;--tw-brightness: ;--tw-contrast: ;--tw-grayscale: ;--tw-hue-rotate: ;--tw-invert: ;--tw-saturate: ;--tw-sepia: ;--tw-drop-shadow: ;--tw-backdrop-blur: ;--tw-backdrop-brightness: ;--tw-backdrop-contrast: ;--tw-backdrop-grayscale: ;--tw-backdrop-hue-rotate: ;--tw-backdrop-invert: ;--tw-backdrop-opacity: ;--tw-backdrop-saturate: ;--tw-backdrop-sepia: }::backdrop{--tw-border-spacing-x: 0;--tw-border-spacing-y: 0;--tw-translate-x: 0;--tw-translate-y: 0;--tw-rotate: 0;--tw-skew-x: 0;--tw-skew-y: 0;--tw-scale-x: 1;--tw-scale-y: 1;--tw-pan-x: ;--tw-pan-y: ;--tw-pinch-zoom: ;--tw-scroll-snap-strictness: proximity;--tw-gradient-from-position: ;--tw-gradient-via-position: ;--tw-gradient-to-position: ;--tw-ordinal: ;--tw-slashed-zero: ;--tw-numeric-figure: ;--tw-numeric-spacing: ;--tw-numeric-fraction: ;--tw-ring-inset: ;--tw-ring-offset-width: 
0px;--tw-ring-offset-color: #fff;--tw-ring-color: rgb(59 130 246 / .5);--tw-ring-offset-shadow: 0 0 #0000;--tw-ring-shadow: 0 0 #0000;--tw-shadow: 0 0 #0000;--tw-shadow-colored: 0 0 #0000;--tw-blur: ;--tw-brightness: ;--tw-contrast: ;--tw-grayscale: ;--tw-hue-rotate: ;--tw-invert: ;--tw-saturate: ;--tw-sepia: ;--tw-drop-shadow: ;--tw-backdrop-blur: ;--tw-backdrop-brightness: ;--tw-backdrop-contrast: ;--tw-backdrop-grayscale: ;--tw-backdrop-hue-rotate: ;--tw-backdrop-invert: ;--tw-backdrop-opacity: ;--tw-backdrop-saturate: ;--tw-backdrop-sepia: }.sr-only{position:absolute;width:1px;height:1px;padding:0;margin:-1px;overflow:hidden;clip:rect(0,0,0,0);white-space:nowrap;border-width:0}.pointer-events-none{pointer-events:none}.collapse{visibility:collapse}.static{position:static}.fixed{position:fixed}.absolute{position:absolute}.relative{position:relative}.sticky{position:sticky}.inset-0{top:0;right:0;bottom:0;left:0}.inset-x-px{left:1px;right:1px}.inset-y-0{top:0;bottom:0}.bottom-0{bottom:0}.bottom-px{bottom:1px}.end-0{inset-inline-end:0px}.left-0{left:0}.right-0{right:0}.top-0{top:0}.z-10{z-index:10}.z-50{z-index:50}.z-\\[60\\]{z-index:60}.z-\\[80\\]{z-index:80}.-m-1{margin:-.25rem}.-m-1\\.5{margin:-.375rem}.m-1{margin:.25rem}.m-3{margin:.75rem}.mx-auto{margin-left:auto;margin-right:auto}.mb-1{margin-bottom:.25rem}.mb-2{margin-bottom:.5rem}.mb-2\\.5{margin-bottom:.625rem}.mb-3{margin-bottom:.75rem}.mb-auto{margin-bottom:auto}.ml-3{margin-left:.75rem}.ml-4{margin-left:1rem}.mr-1{margin-right:.25rem}.mr-1\\.5{margin-right:.375rem}.mr-3{margin-right:.75rem}.ms-0{margin-inline-start:0px}.ms-3{margin-inline-start:.75rem}.mt-0{margin-top:0}.mt-0\\.5{margin-top:.125rem}.mt-1{margin-top:.25rem}.mt-16{margin-top:4rem}.mt-2{margin-top:.5rem}.mt-2\\.5{margin-top:.625rem}.mt-3{margin-top:.75rem}.mt-4{margin-top:1rem}.mt-5{margin-top:1.25rem}.mt-6{margin-top:1.5rem}.mt-7{margin-top:1.75rem}.mt-auto{margin-top:auto}.block{display:block}.inline-block{display:inline-block}.inline{disp
lay:inline}.flex{display:flex}.inline-flex{display:inline-flex}.table{display:table}.grid{display:grid}.hidden{display:none}.h-1{height:.25rem}.h-1\\.5{height:.375rem}.h-2{height:.5rem}.h-3{height:.75rem}.h-3\\.5{height:.875rem}.h-4{height:1rem}.h-5{height:1.25rem}.h-52{height:13rem}.h-6{height:1.5rem}.h-7{height:1.75rem}.h-8{height:2rem}.h-\\[2\\.375rem\\]{height:2.375rem}.h-auto{height:auto}.h-full{height:100%}.h-px{height:1px}.h-screen{height:100vh}.min-h-\\[15rem\\]{min-height:15rem}.min-h-\\[calc\\(100\\%-3\\.5rem\\)\\]{min-height:calc(100% - 3.5rem)}.w-0{width:0px}.w-1{width:.25rem}.w-1\\.5{width:.375rem}.w-2{width:.5rem}.w-28{width:7rem}.w-3{width:.75rem}.w-3\\.5{width:.875rem}.w-4{width:1rem}.w-6{width:1.5rem}.w-64{width:16rem}.w-8{width:2rem}.w-\\[2\\.375rem\\]{width:2.375rem}.w-\\[3\\.25rem\\]{width:3.25rem}.w-full{width:100%}.w-px{width:1px}.min-w-full{min-width:100%}.max-w-2xl{max-width:42rem}.max-w-4xl{max-width:56rem}.max-w-7xl{max-width:80rem}.max-w-\\[50rem\\]{max-width:50rem}.max-w-\\[85rem\\]{max-width:85rem}.max-w-fit{max-width:-moz-fit-content;max-width:fit-content}.max-w-md{max-width:28rem}.max-w-sm{max-width:24rem}.max-w-xs{max-width:20rem}.flex-1{flex:1 1 0%}.flex-auto{flex:1 1 auto}.flex-none{flex:none}.flex-shrink-0,.shrink-0{flex-shrink:0}.flex-grow,.grow{flex-grow:1}.-translate-x-full{--tw-translate-x: -100%;transform:translate(var(--tw-translate-x),var(--tw-translate-y)) rotate(var(--tw-rotate)) skew(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y))}.translate-x-0{--tw-translate-x: 0px;transform:translate(var(--tw-translate-x),var(--tw-translate-y)) rotate(var(--tw-rotate)) skew(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y))}.translate-x-full{--tw-translate-x: 100%;transform:translate(var(--tw-translate-x),var(--tw-translate-y)) rotate(var(--tw-rotate)) skew(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) 
scaleY(var(--tw-scale-y))}.transform{transform:translate(var(--tw-translate-x),var(--tw-translate-y)) rotate(var(--tw-rotate)) skew(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y))}@keyframes bounce{0%,to{transform:translateY(-25%);animation-timing-function:cubic-bezier(.8,0,1,1)}50%{transform:none;animation-timing-function:cubic-bezier(0,0,.2,1)}}.animate-bounce{animation:bounce 1s infinite}@keyframes ping{75%,to{transform:scale(2);opacity:0}}.animate-ping{animation:ping 1s cubic-bezier(0,0,.2,1) infinite}@keyframes pulse{50%{opacity:.5}}.animate-pulse{animation:pulse 2s cubic-bezier(.4,0,.6,1) infinite}@keyframes spin{to{transform:rotate(360deg)}}.animate-spin{animation:spin 1s linear infinite}.cursor-pointer{cursor:pointer}.resize{resize:both}.list-disc{list-style-type:disc}.appearance-none{-webkit-appearance:none;-moz-appearance:none;appearance:none}.grid-cols-4{grid-template-columns:repeat(4,minmax(0,1fr))}.flex-col{flex-direction:column}.items-start{align-items:flex-start}.items-center{align-items:center}.justify-end{justify-content:flex-end}.justify-center{justify-content:center}.justify-between{justify-content:space-between}.gap-1{gap:.25rem}.gap-1\\.5{gap:.375rem}.gap-2{gap:.5rem}.gap-3{gap:.75rem}.gap-4{gap:1rem}.gap-x-1{-moz-column-gap:.25rem;column-gap:.25rem}.gap-x-2{-moz-column-gap:.5rem;column-gap:.5rem}.gap-x-3{-moz-column-gap:.75rem;column-gap:.75rem}.gap-y-4{row-gap:1rem}.-space-y-px>:not([hidden])~:not([hidden]){--tw-space-y-reverse: 0;margin-top:calc(-1px * calc(1 - var(--tw-space-y-reverse)));margin-bottom:calc(-1px * var(--tw-space-y-reverse))}.space-x-2>:not([hidden])~:not([hidden]){--tw-space-x-reverse: 0;margin-right:calc(.5rem * var(--tw-space-x-reverse));margin-left:calc(.5rem * calc(1 - var(--tw-space-x-reverse)))}.space-y-1>:not([hidden])~:not([hidden]){--tw-space-y-reverse: 0;margin-top:calc(.25rem * calc(1 - var(--tw-space-y-reverse)));margin-bottom:calc(.25rem * 
var(--tw-space-y-reverse))}.space-y-1\\.5>:not([hidden])~:not([hidden]){--tw-space-y-reverse: 0;margin-top:calc(.375rem * calc(1 - var(--tw-space-y-reverse)));margin-bottom:calc(.375rem * var(--tw-space-y-reverse))}.space-y-2>:not([hidden])~:not([hidden]){--tw-space-y-reverse: 0;margin-top:calc(.5rem * calc(1 - var(--tw-space-y-reverse)));margin-bottom:calc(.5rem * var(--tw-space-y-reverse))}.space-y-3>:not([hidden])~:not([hidden]){--tw-space-y-reverse: 0;margin-top:calc(.75rem * calc(1 - var(--tw-space-y-reverse)));margin-bottom:calc(.75rem * var(--tw-space-y-reverse))}.space-y-4>:not([hidden])~:not([hidden]){--tw-space-y-reverse: 0;margin-top:calc(1rem * calc(1 - var(--tw-space-y-reverse)));margin-bottom:calc(1rem * var(--tw-space-y-reverse))}.space-y-5>:not([hidden])~:not([hidden]){--tw-space-y-reverse: 0;margin-top:calc(1.25rem * calc(1 - var(--tw-space-y-reverse)));margin-bottom:calc(1.25rem * var(--tw-space-y-reverse))}.divide-x>:not([hidden])~:not([hidden]){--tw-divide-x-reverse: 0;border-right-width:calc(1px * var(--tw-divide-x-reverse));border-left-width:calc(1px * calc(1 - var(--tw-divide-x-reverse)))}.divide-y>:not([hidden])~:not([hidden]){--tw-divide-y-reverse: 0;border-top-width:calc(1px * calc(1 - var(--tw-divide-y-reverse)));border-bottom-width:calc(1px * var(--tw-divide-y-reverse))}.divide-gray-200>:not([hidden])~:not([hidden]){--tw-divide-opacity: 1;border-color:rgb(229 231 235 / 
var(--tw-divide-opacity))}.overflow-hidden{overflow:hidden}.overflow-x-auto{overflow-x:auto}.overflow-y-auto{overflow-y:auto}.overflow-x-hidden{overflow-x:hidden}.overflow-y-hidden{overflow-y:hidden}.truncate{overflow:hidden;text-overflow:ellipsis;white-space:nowrap}.whitespace-nowrap{white-space:nowrap}.whitespace-pre-wrap{white-space:pre-wrap}.whitespace-break-spaces{white-space:break-spaces}.rounded{border-radius:.25rem}.rounded-full{border-radius:9999px}.rounded-lg{border-radius:.5rem}.rounded-md{border-radius:.375rem}.rounded-xl{border-radius:.75rem}.rounded-b-md{border-bottom-right-radius:.375rem;border-bottom-left-radius:.375rem}.rounded-t-xl{border-top-left-radius:.75rem;border-top-right-radius:.75rem}.rounded-ee-xl{border-end-end-radius:.75rem}.rounded-es-xl{border-end-start-radius:.75rem}.border{border-width:1px}.border-2{border-width:2px}.border-4{border-width:4px}.border-\\[3px\\]{border-width:3px}.border-b{border-bottom-width:1px}.border-b-2{border-bottom-width:2px}.border-r{border-right-width:1px}.border-s{border-inline-start-width:1px}.border-t{border-top-width:1px}.border-t-2{border-top-width:2px}.border-t-4{border-top-width:4px}.border-blue-500{--tw-border-opacity: 1;border-color:rgb(59 130 246 / var(--tw-border-opacity))}.border-blue-600{--tw-border-opacity: 1;border-color:rgb(37 99 235 / var(--tw-border-opacity))}.border-current{border-color:currentColor}.border-gray-200{--tw-border-opacity: 1;border-color:rgb(229 231 235 / var(--tw-border-opacity))}.border-gray-300{--tw-border-opacity: 1;border-color:rgb(209 213 219 / var(--tw-border-opacity))}.border-gray-600{--tw-border-opacity: 1;border-color:rgb(75 85 99 / var(--tw-border-opacity))}.border-gray-700{--tw-border-opacity: 1;border-color:rgb(55 65 81 / var(--tw-border-opacity))}.border-green-200{--tw-border-opacity: 1;border-color:rgb(187 247 208 / var(--tw-border-opacity))}.border-neutral-700{--tw-border-opacity: 1;border-color:rgb(64 64 64 / 
var(--tw-border-opacity))}.border-red-200{--tw-border-opacity: 1;border-color:rgb(254 202 202 / var(--tw-border-opacity))}.border-red-500{--tw-border-opacity: 1;border-color:rgb(239 68 68 / var(--tw-border-opacity))}.border-red-600{--tw-border-opacity: 1;border-color:rgb(220 38 38 / var(--tw-border-opacity))}.border-teal-100{--tw-border-opacity: 1;border-color:rgb(204 251 241 / var(--tw-border-opacity))}.border-teal-500{--tw-border-opacity: 1;border-color:rgb(20 184 166 / var(--tw-border-opacity))}.border-teal-900{--tw-border-opacity: 1;border-color:rgb(19 78 74 / var(--tw-border-opacity))}.border-transparent{border-color:transparent}.border-yellow-200{--tw-border-opacity: 1;border-color:rgb(254 240 138 / var(--tw-border-opacity))}.border-t-blue-500{--tw-border-opacity: 1;border-top-color:rgb(59 130 246 / var(--tw-border-opacity))}.border-t-blue-600{--tw-border-opacity: 1;border-top-color:rgb(37 99 235 / var(--tw-border-opacity))}.border-t-green-500{--tw-border-opacity: 1;border-top-color:rgb(34 197 94 / var(--tw-border-opacity))}.border-t-green-600{--tw-border-opacity: 1;border-top-color:rgb(22 163 74 / var(--tw-border-opacity))}.border-t-red-500{--tw-border-opacity: 1;border-top-color:rgb(239 68 68 / var(--tw-border-opacity))}.border-t-red-600{--tw-border-opacity: 1;border-top-color:rgb(220 38 38 / var(--tw-border-opacity))}.border-t-transparent{border-top-color:transparent}.bg-blue-500{--tw-bg-opacity: 1;background-color:rgb(59 130 246 / var(--tw-bg-opacity))}.bg-blue-600{--tw-bg-opacity: 1;background-color:rgb(37 99 235 / var(--tw-bg-opacity))}.bg-gray-100{--tw-bg-opacity: 1;background-color:rgb(243 244 246 / var(--tw-bg-opacity))}.bg-gray-50{--tw-bg-opacity: 1;background-color:rgb(249 250 251 / var(--tw-bg-opacity))}.bg-gray-600{--tw-bg-opacity: 1;background-color:rgb(75 85 99 / var(--tw-bg-opacity))}.bg-gray-700{--tw-bg-opacity: 1;background-color:rgb(55 65 81 / var(--tw-bg-opacity))}.bg-gray-800{--tw-bg-opacity: 1;background-color:rgb(31 41 55 / 
var(--tw-bg-opacity))}.bg-gray-900{--tw-bg-opacity: 1;background-color:rgb(17 24 39 / var(--tw-bg-opacity))}.bg-green-600{--tw-bg-opacity: 1;background-color:rgb(22 163 74 / var(--tw-bg-opacity))}.bg-neutral-800{--tw-bg-opacity: 1;background-color:rgb(38 38 38 / var(--tw-bg-opacity))}.bg-neutral-900{--tw-bg-opacity: 1;background-color:rgb(23 23 23 / var(--tw-bg-opacity))}.bg-red-500{--tw-bg-opacity: 1;background-color:rgb(239 68 68 / var(--tw-bg-opacity))}.bg-slate-800{--tw-bg-opacity: 1;background-color:rgb(30 41 59 / var(--tw-bg-opacity))}.bg-slate-900{--tw-bg-opacity: 1;background-color:rgb(15 23 42 / var(--tw-bg-opacity))}.bg-teal-200{--tw-bg-opacity: 1;background-color:rgb(153 246 228 / var(--tw-bg-opacity))}.bg-teal-50{--tw-bg-opacity: 1;background-color:rgb(240 253 250 / var(--tw-bg-opacity))}.bg-teal-800{--tw-bg-opacity: 1;background-color:rgb(17 94 89 / var(--tw-bg-opacity))}.bg-white{--tw-bg-opacity: 1;background-color:rgb(255 255 255 / var(--tw-bg-opacity))}.bg-yellow-50{--tw-bg-opacity: 1;background-color:rgb(254 252 232 / var(--tw-bg-opacity))}.bg-opacity-50{--tw-bg-opacity: .5}.bg-opacity-80{--tw-bg-opacity: 
.8}.p-1{padding:.25rem}.p-1\\.5{padding:.375rem}.p-2{padding:.5rem}.p-2\\.5{padding:.625rem}.p-3{padding:.75rem}.p-4{padding:1rem}.p-6{padding:1.5rem}.px-1{padding-left:.25rem;padding-right:.25rem}.px-3{padding-left:.75rem;padding-right:.75rem}.px-4{padding-left:1rem;padding-right:1rem}.px-6{padding-left:1.5rem;padding-right:1.5rem}.px-7{padding-left:1.75rem;padding-right:1.75rem}.py-1{padding-top:.25rem;padding-bottom:.25rem}.py-10{padding-top:2.5rem;padding-bottom:2.5rem}.py-16{padding-top:4rem;padding-bottom:4rem}.py-2{padding-top:.5rem;padding-bottom:.5rem}.py-2\\.5{padding-top:.625rem;padding-bottom:.625rem}.py-3{padding-top:.75rem;padding-bottom:.75rem}.py-4{padding-top:1rem;padding-bottom:1rem}.py-5{padding-top:1.25rem;padding-bottom:1.25rem}.pb-12{padding-bottom:3rem}.pe-11{padding-inline-end:2.75rem}.pe-3{padding-inline-end:.75rem}.pl-3{padding-left:.75rem}.pl-7{padding-left:1.75rem}.pr-10{padding-right:2.5rem}.pr-4{padding-right:1rem}.pr-9{padding-right:2.25rem}.ps-5{padding-inline-start:1.25rem}.text-left{text-align:left}.text-center{text-align:center}.text-start{text-align:start}.align-middle{vertical-align:middle}.font-mono{font-family:ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,monospace}.text-2xl{font-size:1.5rem;line-height:2rem}.text-3xl{font-size:1.875rem;line-height:2.25rem}.text-7xl{font-size:4.5rem;line-height:1}.text-base{font-size:1rem;line-height:1.5rem}.text-lg{font-size:1.125rem;line-height:1.75rem}.text-sm{font-size:.875rem;line-height:1.25rem}.text-xl{font-size:1.25rem;line-height:1.75rem}.text-xs{font-size:.75rem;line-height:1rem}.font-bold{font-weight:700}.font-medium{font-weight:500}.font-semibold{font-weight:600}.uppercase{text-transform:uppercase}.leading-none{line-height:1}.tracking-wide{letter-spacing:.025em}.text-blue-500{--tw-text-opacity: 1;color:rgb(59 130 246 / var(--tw-text-opacity))}.text-blue-600{--tw-text-opacity: 1;color:rgb(37 99 235 / 
var(--tw-text-opacity))}.text-gray-200{--tw-text-opacity: 1;color:rgb(229 231 235 / var(--tw-text-opacity))}.text-gray-300{--tw-text-opacity: 1;color:rgb(209 213 219 / var(--tw-text-opacity))}.text-gray-400{--tw-text-opacity: 1;color:rgb(156 163 175 / var(--tw-text-opacity))}.text-gray-500{--tw-text-opacity: 1;color:rgb(107 114 128 / var(--tw-text-opacity))}.text-gray-600{--tw-text-opacity: 1;color:rgb(75 85 99 / var(--tw-text-opacity))}.text-gray-700{--tw-text-opacity: 1;color:rgb(55 65 81 / var(--tw-text-opacity))}.text-gray-800{--tw-text-opacity: 1;color:rgb(31 41 55 / var(--tw-text-opacity))}.text-green-500{--tw-text-opacity: 1;color:rgb(34 197 94 / var(--tw-text-opacity))}.text-green-600{--tw-text-opacity: 1;color:rgb(22 163 74 / var(--tw-text-opacity))}.text-neutral-200{--tw-text-opacity: 1;color:rgb(229 229 229 / var(--tw-text-opacity))}.text-neutral-300{--tw-text-opacity: 1;color:rgb(212 212 212 / var(--tw-text-opacity))}.text-neutral-400{--tw-text-opacity: 1;color:rgb(163 163 163 / var(--tw-text-opacity))}.text-neutral-500{--tw-text-opacity: 1;color:rgb(115 115 115 / var(--tw-text-opacity))}.text-red-500{--tw-text-opacity: 1;color:rgb(239 68 68 / var(--tw-text-opacity))}.text-red-600{--tw-text-opacity: 1;color:rgb(220 38 38 / var(--tw-text-opacity))}.text-slate-400{--tw-text-opacity: 1;color:rgb(148 163 184 / var(--tw-text-opacity))}.text-slate-700{--tw-text-opacity: 1;color:rgb(51 65 85 / var(--tw-text-opacity))}.text-teal-400{--tw-text-opacity: 1;color:rgb(45 212 191 / var(--tw-text-opacity))}.text-teal-800{--tw-text-opacity: 1;color:rgb(17 94 89 / var(--tw-text-opacity))}.text-white{--tw-text-opacity: 1;color:rgb(255 255 255 / var(--tw-text-opacity))}.text-yellow-400{--tw-text-opacity: 1;color:rgb(250 204 21 / var(--tw-text-opacity))}.text-yellow-700{--tw-text-opacity: 1;color:rgb(161 98 7 / var(--tw-text-opacity))}.text-yellow-800{--tw-text-opacity: 1;color:rgb(133 77 14 / 
var(--tw-text-opacity))}.decoration-2{text-decoration-thickness:2px}.opacity-0{opacity:0}.shadow{--tw-shadow: 0 1px 3px 0 rgb(0 0 0 / .1), 0 1px 2px -1px rgb(0 0 0 / .1);--tw-shadow-colored: 0 1px 3px 0 var(--tw-shadow-color), 0 1px 2px -1px var(--tw-shadow-color);box-shadow:var(--tw-ring-offset-shadow, 0 0 #0000),var(--tw-ring-shadow, 0 0 #0000),var(--tw-shadow)}.shadow-md{--tw-shadow: 0 4px 6px -1px rgb(0 0 0 / .1), 0 2px 4px -2px rgb(0 0 0 / .1);--tw-shadow-colored: 0 4px 6px -1px var(--tw-shadow-color), 0 2px 4px -2px var(--tw-shadow-color);box-shadow:var(--tw-ring-offset-shadow, 0 0 #0000),var(--tw-ring-shadow, 0 0 #0000),var(--tw-shadow)}.shadow-sm{--tw-shadow: 0 1px 2px 0 rgb(0 0 0 / .05);--tw-shadow-colored: 0 1px 2px 0 var(--tw-shadow-color);box-shadow:var(--tw-ring-offset-shadow, 0 0 #0000),var(--tw-ring-shadow, 0 0 #0000),var(--tw-shadow)}.shadow-neutral-700{--tw-shadow-color: #404040;--tw-shadow: var(--tw-shadow-colored)}.shadow-slate-700{--tw-shadow-color: #334155;--tw-shadow: var(--tw-shadow-colored)}.ring-1{--tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);--tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(1px + var(--tw-ring-offset-width)) var(--tw-ring-color);box-shadow:var(--tw-ring-offset-shadow),var(--tw-ring-shadow),var(--tw-shadow, 0 0 #0000)}.ring-transparent{--tw-ring-color: transparent}.ring-offset-white{--tw-ring-offset-color: #fff}.blur{--tw-blur: blur(8px);filter:var(--tw-blur) var(--tw-brightness) var(--tw-contrast) var(--tw-grayscale) var(--tw-hue-rotate) var(--tw-invert) var(--tw-saturate) var(--tw-sepia) var(--tw-drop-shadow)}.filter{filter:var(--tw-blur) var(--tw-brightness) var(--tw-contrast) var(--tw-grayscale) var(--tw-hue-rotate) var(--tw-invert) var(--tw-saturate) var(--tw-sepia) 
var(--tw-drop-shadow)}.transition{transition-property:color,background-color,border-color,text-decoration-color,fill,stroke,opacity,box-shadow,transform,filter,-webkit-backdrop-filter;transition-property:color,background-color,border-color,text-decoration-color,fill,stroke,opacity,box-shadow,transform,filter,backdrop-filter;transition-property:color,background-color,border-color,text-decoration-color,fill,stroke,opacity,box-shadow,transform,filter,backdrop-filter,-webkit-backdrop-filter;transition-timing-function:cubic-bezier(.4,0,.2,1);transition-duration:.15s}.transition-all{transition-property:all;transition-timing-function:cubic-bezier(.4,0,.2,1);transition-duration:.15s}.transition-colors{transition-property:color,background-color,border-color,text-decoration-color,fill,stroke;transition-timing-function:cubic-bezier(.4,0,.2,1);transition-duration:.15s}.duration-200{transition-duration:.2s}.duration-300{transition-duration:.3s}.duration-500{transition-duration:.5s}.ease-in-out{transition-timing-function:cubic-bezier(.4,0,.2,1)}.ease-out{transition-timing-function:cubic-bezier(0,0,.2,1)}.\\[--body-scroll\\:true\\]{--body-scroll: true}.marker\\:text-blue-600 *::marker{color:#2563eb}.marker\\:text-blue-600::marker{color:#2563eb}.before\\:inline-block:before{content:var(--tw-content);display:inline-block}.before\\:h-6:before{content:var(--tw-content);height:1.5rem}.before\\:w-6:before{content:var(--tw-content);width:1.5rem}.before\\:translate-x-0:before{content:var(--tw-content);--tw-translate-x: 0px;transform:translate(var(--tw-translate-x),var(--tw-translate-y)) rotate(var(--tw-rotate)) skew(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y))}.before\\:transform:before{content:var(--tw-content);transform:translate(var(--tw-translate-x),var(--tw-translate-y)) rotate(var(--tw-rotate)) skew(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) 
scaleY(var(--tw-scale-y))}.before\\:rounded-full:before{content:var(--tw-content);border-radius:9999px}.before\\:bg-white:before{content:var(--tw-content);--tw-bg-opacity: 1;background-color:rgb(255 255 255 / var(--tw-bg-opacity))}.before\\:shadow:before{content:var(--tw-content);--tw-shadow: 0 1px 3px 0 rgb(0 0 0 / .1), 0 1px 2px -1px rgb(0 0 0 / .1);--tw-shadow-colored: 0 1px 3px 0 var(--tw-shadow-color), 0 1px 2px -1px var(--tw-shadow-color);box-shadow:var(--tw-ring-offset-shadow, 0 0 #0000),var(--tw-ring-shadow, 0 0 #0000),var(--tw-shadow)}.before\\:ring-0:before{content:var(--tw-content);--tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);--tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(0px + var(--tw-ring-offset-width)) var(--tw-ring-color);box-shadow:var(--tw-ring-offset-shadow),var(--tw-ring-shadow),var(--tw-shadow, 0 0 #0000)}.before\\:transition:before{content:var(--tw-content);transition-property:color,background-color,border-color,text-decoration-color,fill,stroke,opacity,box-shadow,transform,filter,-webkit-backdrop-filter;transition-property:color,background-color,border-color,text-decoration-color,fill,stroke,opacity,box-shadow,transform,filter,backdrop-filter;transition-property:color,background-color,border-color,text-decoration-color,fill,stroke,opacity,box-shadow,transform,filter,backdrop-filter,-webkit-backdrop-filter;transition-timing-function:cubic-bezier(.4,0,.2,1);transition-duration:.15s}.before\\:duration-200:before{content:var(--tw-content);transition-duration:.2s}.before\\:ease-in-out:before{content:var(--tw-content);transition-timing-function:cubic-bezier(.4,0,.2,1)}.first\\:mt-0:first-child{margin-top:0}.first\\:rounded-t-lg:first-child{border-top-left-radius:.5rem;border-top-right-radius:.5rem}.last\\:rounded-b-lg:last-child{border-bottom-right-radius:.5rem;border-bottom-left-radius:.5rem}.checked\\:bg-blue-600:checked{--tw-bg-opacity: 1;background-color:rgb(37 99 235 / 
var(--tw-bg-opacity))}.checked\\:bg-none:checked{background-image:none}.checked\\:before\\:translate-x-full:checked:before{content:var(--tw-content);--tw-translate-x: 100%;transform:translate(var(--tw-translate-x),var(--tw-translate-y)) rotate(var(--tw-rotate)) skew(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y))}.checked\\:before\\:bg-blue-200:checked:before{content:var(--tw-content);--tw-bg-opacity: 1;background-color:rgb(191 219 254 / var(--tw-bg-opacity))}.hover\\:border-blue-500:hover{--tw-border-opacity: 1;border-color:rgb(59 130 246 / var(--tw-border-opacity))}.hover\\:border-green-500:hover{--tw-border-opacity: 1;border-color:rgb(34 197 94 / var(--tw-border-opacity))}.hover\\:border-red-400:hover{--tw-border-opacity: 1;border-color:rgb(248 113 113 / var(--tw-border-opacity))}.hover\\:border-red-500:hover{--tw-border-opacity: 1;border-color:rgb(239 68 68 / var(--tw-border-opacity))}.hover\\:bg-blue-50:hover{--tw-bg-opacity: 1;background-color:rgb(239 246 255 / var(--tw-bg-opacity))}.hover\\:bg-blue-500:hover{--tw-bg-opacity: 1;background-color:rgb(59 130 246 / var(--tw-bg-opacity))}.hover\\:bg-blue-600:hover{--tw-bg-opacity: 1;background-color:rgb(37 99 235 / var(--tw-bg-opacity))}.hover\\:bg-blue-700:hover{--tw-bg-opacity: 1;background-color:rgb(29 78 216 / var(--tw-bg-opacity))}.hover\\:bg-gray-100:hover{--tw-bg-opacity: 1;background-color:rgb(243 244 246 / var(--tw-bg-opacity))}.hover\\:bg-gray-50:hover{--tw-bg-opacity: 1;background-color:rgb(249 250 251 / var(--tw-bg-opacity))}.hover\\:bg-green-500:hover{--tw-bg-opacity: 1;background-color:rgb(34 197 94 / var(--tw-bg-opacity))}.hover\\:bg-red-500:hover{--tw-bg-opacity: 1;background-color:rgb(239 68 68 / var(--tw-bg-opacity))}.hover\\:text-blue-500:hover{--tw-text-opacity: 1;color:rgb(59 130 246 / var(--tw-text-opacity))}.hover\\:text-blue-600:hover{--tw-text-opacity: 1;color:rgb(37 99 235 / 
var(--tw-text-opacity))}.hover\\:text-blue-800:hover{--tw-text-opacity: 1;color:rgb(30 64 175 / var(--tw-text-opacity))}.hover\\:text-gray-400:hover{--tw-text-opacity: 1;color:rgb(156 163 175 / var(--tw-text-opacity))}.hover\\:text-green-800:hover{--tw-text-opacity: 1;color:rgb(22 101 52 / var(--tw-text-opacity))}.hover\\:text-red-400:hover{--tw-text-opacity: 1;color:rgb(248 113 113 / var(--tw-text-opacity))}.hover\\:text-red-500:hover{--tw-text-opacity: 1;color:rgb(239 68 68 / var(--tw-text-opacity))}.hover\\:text-red-600:hover{--tw-text-opacity: 1;color:rgb(220 38 38 / var(--tw-text-opacity))}.hover\\:text-white:hover{--tw-text-opacity: 1;color:rgb(255 255 255 / var(--tw-text-opacity))}.focus\\:z-10:focus{z-index:10}.focus\\:border-blue-500:focus{--tw-border-opacity: 1;border-color:rgb(59 130 246 / var(--tw-border-opacity))}.focus\\:border-blue-600:focus{--tw-border-opacity: 1;border-color:rgb(37 99 235 / var(--tw-border-opacity))}.focus\\:border-indigo-500:focus{--tw-border-opacity: 1;border-color:rgb(99 102 241 / var(--tw-border-opacity))}.focus\\:outline-none:focus{outline:2px solid transparent;outline-offset:2px}.focus\\:ring-2:focus{--tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);--tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(2px + var(--tw-ring-offset-width)) var(--tw-ring-color);box-shadow:var(--tw-ring-offset-shadow),var(--tw-ring-shadow),var(--tw-shadow, 0 0 #0000)}.focus\\:ring-blue-500:focus{--tw-ring-opacity: 1;--tw-ring-color: rgb(59 130 246 / var(--tw-ring-opacity))}.focus\\:ring-blue-600:focus{--tw-ring-opacity: 1;--tw-ring-color: rgb(37 99 235 / var(--tw-ring-opacity))}.focus\\:ring-gray-400:focus{--tw-ring-opacity: 1;--tw-ring-color: rgb(156 163 175 / var(--tw-ring-opacity))}.focus\\:ring-green-200:focus{--tw-ring-opacity: 1;--tw-ring-color: rgb(187 247 208 / var(--tw-ring-opacity))}.focus\\:ring-indigo-500:focus{--tw-ring-opacity: 1;--tw-ring-color: rgb(99 102 241 / 
var(--tw-ring-opacity))}.focus\\:ring-red-200:focus{--tw-ring-opacity: 1;--tw-ring-color: rgb(254 202 202 / var(--tw-ring-opacity))}.focus\\:ring-red-500:focus{--tw-ring-opacity: 1;--tw-ring-color: rgb(239 68 68 / var(--tw-ring-opacity))}.focus\\:ring-offset-2:focus{--tw-ring-offset-width: 2px}.focus\\:ring-offset-white:focus{--tw-ring-offset-color: #fff}.disabled\\:pointer-events-none:disabled{pointer-events:none}.disabled\\:opacity-50:disabled{opacity:.5}[data-hs-tab].active.hs-tab-active\\:border-blue-600{--tw-border-opacity: 1;border-color:rgb(37 99 235 / var(--tw-border-opacity))}[data-hs-tab].active.hs-tab-active\\:font-semibold{font-weight:600}[data-hs-tab].active.hs-tab-active\\:text-blue-600{--tw-text-opacity: 1;color:rgb(37 99 235 / var(--tw-text-opacity))}[data-hs-tab].active .hs-tab-active\\:border-blue-600{--tw-border-opacity: 1;border-color:rgb(37 99 235 / var(--tw-border-opacity))}[data-hs-tab].active .hs-tab-active\\:font-semibold{font-weight:600}[data-hs-tab].active .hs-tab-active\\:text-blue-600{--tw-text-opacity: 1;color:rgb(37 99 235 / var(--tw-text-opacity))}.open.hs-overlay-open\\:mt-7{margin-top:1.75rem}.open.hs-overlay-open\\:translate-x-0{--tw-translate-x: 0px;transform:translate(var(--tw-translate-x),var(--tw-translate-y)) rotate(var(--tw-rotate)) skew(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y))}.open.hs-overlay-open\\:opacity-100{opacity:1}.open.hs-overlay-open\\:duration-500{transition-duration:.5s}.open .hs-overlay-open\\:mt-7{margin-top:1.75rem}.open .hs-overlay-open\\:translate-x-0{--tw-translate-x: 0px;transform:translate(var(--tw-translate-x),var(--tw-translate-y)) rotate(var(--tw-rotate)) skew(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y))}.open .hs-overlay-open\\:opacity-100{opacity:1}.open .hs-overlay-open\\:duration-500{transition-duration:.5s}@media (prefers-color-scheme: 
dark){.dark\\:divide-gray-700>:not([hidden])~:not([hidden]){--tw-divide-opacity: 1;border-color:rgb(55 65 81 / var(--tw-divide-opacity))}.dark\\:border-blue-500{--tw-border-opacity: 1;border-color:rgb(59 130 246 / var(--tw-border-opacity))}.dark\\:border-gray-600{--tw-border-opacity: 1;border-color:rgb(75 85 99 / var(--tw-border-opacity))}.dark\\:border-gray-700{--tw-border-opacity: 1;border-color:rgb(55 65 81 / var(--tw-border-opacity))}.dark\\:border-neutral-700{--tw-border-opacity: 1;border-color:rgb(64 64 64 / var(--tw-border-opacity))}.dark\\:border-red-500{--tw-border-opacity: 1;border-color:rgb(239 68 68 / var(--tw-border-opacity))}.dark\\:border-teal-900{--tw-border-opacity: 1;border-color:rgb(19 78 74 / var(--tw-border-opacity))}.dark\\:border-t-blue-500{--tw-border-opacity: 1;border-top-color:rgb(59 130 246 / var(--tw-border-opacity))}.dark\\:border-t-green-500{--tw-border-opacity: 1;border-top-color:rgb(34 197 94 / var(--tw-border-opacity))}.dark\\:border-t-red-500{--tw-border-opacity: 1;border-top-color:rgb(239 68 68 / var(--tw-border-opacity))}.dark\\:bg-gray-700{--tw-bg-opacity: 1;background-color:rgb(55 65 81 / var(--tw-bg-opacity))}.dark\\:bg-gray-800{--tw-bg-opacity: 1;background-color:rgb(31 41 55 / var(--tw-bg-opacity))}.dark\\:bg-neutral-800{--tw-bg-opacity: 1;background-color:rgb(38 38 38 / var(--tw-bg-opacity))}.dark\\:bg-neutral-900{--tw-bg-opacity: 1;background-color:rgb(23 23 23 / var(--tw-bg-opacity))}.dark\\:bg-slate-800{--tw-bg-opacity: 1;background-color:rgb(30 41 59 / var(--tw-bg-opacity))}.dark\\:bg-slate-900{--tw-bg-opacity: 1;background-color:rgb(15 23 42 / var(--tw-bg-opacity))}.dark\\:bg-teal-800{--tw-bg-opacity: 1;background-color:rgb(17 94 89 / var(--tw-bg-opacity))}.dark\\:bg-teal-800\\/30{background-color:#115e594d}.dark\\:bg-opacity-80{--tw-bg-opacity: .8}.dark\\:text-blue-500{--tw-text-opacity: 1;color:rgb(59 130 246 / var(--tw-text-opacity))}.dark\\:text-gray-200{--tw-text-opacity: 1;color:rgb(229 231 235 / 
var(--tw-text-opacity))}.dark\\:text-gray-300{--tw-text-opacity: 1;color:rgb(209 213 219 / var(--tw-text-opacity))}.dark\\:text-gray-400{--tw-text-opacity: 1;color:rgb(156 163 175 / var(--tw-text-opacity))}.dark\\:text-gray-500{--tw-text-opacity: 1;color:rgb(107 114 128 / var(--tw-text-opacity))}.dark\\:text-green-500{--tw-text-opacity: 1;color:rgb(34 197 94 / var(--tw-text-opacity))}.dark\\:text-neutral-200{--tw-text-opacity: 1;color:rgb(229 229 229 / var(--tw-text-opacity))}.dark\\:text-neutral-400{--tw-text-opacity: 1;color:rgb(163 163 163 / var(--tw-text-opacity))}.dark\\:text-neutral-500{--tw-text-opacity: 1;color:rgb(115 115 115 / var(--tw-text-opacity))}.dark\\:text-red-500{--tw-text-opacity: 1;color:rgb(239 68 68 / var(--tw-text-opacity))}.dark\\:text-slate-400{--tw-text-opacity: 1;color:rgb(148 163 184 / var(--tw-text-opacity))}.dark\\:text-teal-400{--tw-text-opacity: 1;color:rgb(45 212 191 / var(--tw-text-opacity))}.dark\\:text-white{--tw-text-opacity: 1;color:rgb(255 255 255 / var(--tw-text-opacity))}.dark\\:placeholder-gray-400::-moz-placeholder{--tw-placeholder-opacity: 1;color:rgb(156 163 175 / var(--tw-placeholder-opacity))}.dark\\:placeholder-gray-400::placeholder{--tw-placeholder-opacity: 1;color:rgb(156 163 175 / var(--tw-placeholder-opacity))}.dark\\:placeholder-neutral-500::-moz-placeholder{--tw-placeholder-opacity: 1;color:rgb(115 115 115 / var(--tw-placeholder-opacity))}.dark\\:placeholder-neutral-500::placeholder{--tw-placeholder-opacity: 1;color:rgb(115 115 115 / var(--tw-placeholder-opacity))}.dark\\:shadow-neutral-700\\/70{--tw-shadow-color: rgb(64 64 64 / .7);--tw-shadow: var(--tw-shadow-colored)}.dark\\:shadow-slate-700\\/\\[\\.7\\]{--tw-shadow-color: rgb(51 65 85 / .7);--tw-shadow: var(--tw-shadow-colored)}.dark\\:before\\:bg-gray-400:before{content:var(--tw-content);--tw-bg-opacity: 1;background-color:rgb(156 163 175 / var(--tw-bg-opacity))}.dark\\:checked\\:border-blue-500:checked{--tw-border-opacity: 1;border-color:rgb(59 130 246 / 
var(--tw-border-opacity))}.dark\\:checked\\:bg-blue-500:checked{--tw-bg-opacity: 1;background-color:rgb(59 130 246 / var(--tw-bg-opacity))}.dark\\:checked\\:bg-blue-600:checked{--tw-bg-opacity: 1;background-color:rgb(37 99 235 / var(--tw-bg-opacity))}.dark\\:checked\\:before\\:bg-blue-200:checked:before{content:var(--tw-content);--tw-bg-opacity: 1;background-color:rgb(191 219 254 / var(--tw-bg-opacity))}.dark\\:hover\\:border-blue-400:hover{--tw-border-opacity: 1;border-color:rgb(96 165 250 / var(--tw-border-opacity))}.dark\\:hover\\:border-red-400:hover{--tw-border-opacity: 1;border-color:rgb(248 113 113 / var(--tw-border-opacity))}.dark\\:hover\\:bg-gray-900:hover{--tw-bg-opacity: 1;background-color:rgb(17 24 39 / var(--tw-bg-opacity))}.dark\\:hover\\:bg-neutral-700:hover{--tw-bg-opacity: 1;background-color:rgb(64 64 64 / var(--tw-bg-opacity))}.dark\\:hover\\:bg-slate-800:hover{--tw-bg-opacity: 1;background-color:rgb(30 41 59 / var(--tw-bg-opacity))}.dark\\:hover\\:text-blue-400:hover{--tw-text-opacity: 1;color:rgb(96 165 250 / var(--tw-text-opacity))}.dark\\:hover\\:text-blue-500:hover{--tw-text-opacity: 1;color:rgb(59 130 246 / var(--tw-text-opacity))}.dark\\:hover\\:text-green-400:hover{--tw-text-opacity: 1;color:rgb(74 222 128 / var(--tw-text-opacity))}.dark\\:hover\\:text-red-400:hover{--tw-text-opacity: 1;color:rgb(248 113 113 / var(--tw-text-opacity))}.dark\\:hover\\:text-red-500:hover{--tw-text-opacity: 1;color:rgb(239 68 68 / var(--tw-text-opacity))}.dark\\:hover\\:text-slate-300:hover{--tw-text-opacity: 1;color:rgb(203 213 225 / var(--tw-text-opacity))}.dark\\:hover\\:text-white:hover{--tw-text-opacity: 1;color:rgb(255 255 255 / var(--tw-text-opacity))}.dark\\:focus\\:border-blue-500:focus{--tw-border-opacity: 1;border-color:rgb(59 130 246 / var(--tw-border-opacity))}.dark\\:focus\\:outline-none:focus{outline:2px solid transparent;outline-offset:2px}.dark\\:focus\\:ring-1:focus{--tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 
var(--tw-ring-offset-width) var(--tw-ring-offset-color);--tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(1px + var(--tw-ring-offset-width)) var(--tw-ring-color);box-shadow:var(--tw-ring-offset-shadow),var(--tw-ring-shadow),var(--tw-shadow, 0 0 #0000)}.dark\\:focus\\:ring-blue-500:focus{--tw-ring-opacity: 1;--tw-ring-color: rgb(59 130 246 / var(--tw-ring-opacity))}.dark\\:focus\\:ring-gray-600:focus{--tw-ring-opacity: 1;--tw-ring-color: rgb(75 85 99 / var(--tw-ring-opacity))}.dark\\:focus\\:ring-gray-700:focus{--tw-ring-opacity: 1;--tw-ring-color: rgb(55 65 81 / var(--tw-ring-opacity))}.dark\\:focus\\:ring-neutral-600:focus{--tw-ring-opacity: 1;--tw-ring-color: rgb(82 82 82 / var(--tw-ring-opacity))}.dark\\:focus\\:ring-red-600:focus{--tw-ring-opacity: 1;--tw-ring-color: rgb(220 38 38 / var(--tw-ring-opacity))}.dark\\:focus\\:ring-offset-gray-800:focus{--tw-ring-offset-color: #1f2937}}@media (min-width: 640px){.sm\\:mx-auto{margin-left:auto;margin-right:auto}.sm\\:mb-3{margin-bottom:.75rem}.sm\\:mt-10{margin-top:2.5rem}.sm\\:w-auto{width:auto}.sm\\:w-full{width:100%}.sm\\:max-w-lg{max-width:32rem}.sm\\:flex-row{flex-direction:row}.sm\\:gap-3{gap:.75rem}.sm\\:gap-x-4{-moz-column-gap:1rem;column-gap:1rem}.sm\\:space-y-6>:not([hidden])~:not([hidden]){--tw-space-y-reverse: 0;margin-top:calc(1.5rem * calc(1 - var(--tw-space-y-reverse)));margin-bottom:calc(1.5rem * var(--tw-space-y-reverse))}.sm\\:p-4{padding:1rem}.sm\\:px-6{padding-left:1.5rem;padding-right:1.5rem}.sm\\:py-4{padding-top:1rem;padding-bottom:1rem}.sm\\:py-6{padding-top:1.5rem;padding-bottom:1.5rem}.sm\\:text-3xl{font-size:1.875rem;line-height:2.25rem}.sm\\:text-4xl{font-size:2.25rem;line-height:2.5rem}.sm\\:text-9xl{font-size:8rem;line-height:1}.sm\\:text-sm{font-size:.875rem;line-height:1.25rem}}@media (min-width: 768px){.md\\:flex{display:flex}.md\\:items-center{align-items:center}.md\\:justify-between{justify-content:space-between}.md\\:p-10{padding:2.5rem}.md\\:p-5{padding:1.25rem}}@media (min-width: 
1024px){.lg\\:bottom-0{bottom:0}.lg\\:right-auto{right:auto}.lg\\:block{display:block}.lg\\:hidden{display:none}.lg\\:translate-x-0{--tw-translate-x: 0px;transform:translate(var(--tw-translate-x),var(--tw-translate-y)) rotate(var(--tw-rotate)) skew(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y))}.lg\\:px-8{padding-left:2rem;padding-right:2rem}.lg\\:py-14{padding-top:3.5rem;padding-bottom:3.5rem}.lg\\:pl-64{padding-left:16rem}}
"""
js_content = '''var Rn=Object.defineProperty;var nn=(E,e,T)=>e in E?Rn(E,e,{enumerable:!0,configurable:!0,writable:!0,value:T}):E[e]=T;var wt=(E,e,T)=>(nn(E,typeof e!="symbol"?e+"":e,T),T);(function(){const e=document.createElement("link").relList;if(e&&e.supports&&e.supports("modulepreload"))return;for(const r of document.querySelectorAll('link[rel="modulepreload"]'))t(r);new MutationObserver(r=>{for(const R of r)if(R.type==="childList")for(const n of R.addedNodes)n.tagName==="LINK"&&n.rel==="modulepreload"&&t(n)}).observe(document,{childList:!0,subtree:!0});function T(r){const R={};return r.integrity&&(R.integrity=r.integrity),r.referrerPolicy&&(R.referrerPolicy=r.referrerPolicy),r.crossOrigin==="use-credentials"?R.credentials="include":r.crossOrigin==="anonymous"?R.credentials="omit":R.credentials="same-origin",R}function t(r){if(r.ep)return;r.ep=!0;const R=T(r);fetch(r.href,R)}})();function j(){}function An(E,e){for(const T in e)E[T]=e[T];return E}function cR(E){return E()}function YT(){return Object.create(null)}function NE(E){E.forEach(cR)}function zE(E){return typeof E=="function"}function _e(E,e){return E!=E?e==e:E!==e||E&&typeof E=="object"||typeof E=="function"}let ut;function VT(E,e){return E===e?!0:(ut||(ut=document.createElement("a")),ut.href=e,E===ut.href)}function sn(E){return Object.keys(E).length===0}function fR(E,...e){if(E==null){for(const t of e)t(void 0);return j}const T=E.subscribe(...e);return T.unsubscribe?()=>T.unsubscribe():T}function uE(E){let e;return fR(E,T=>e=T)(),e}function eE(E,e,T){E.$$.on_destroy.push(fR(e,T))}function Ut(E,e,T,t){if(E){const r=PR(E,e,T,t);return E[0](r)}}function PR(E,e,T,t){return E[1]&&t?An(T.ctx.slice(),E[1](t(e))):T.ctx}function mt(E,e,T,t){if(E[2]&&t){const r=E[2](t(T));if(e.dirty===void 0)return r;if(typeof r=="object"){const R=[],n=Math.max(e.dirty.length,r.length);for(let s=0;s32){const e=[],T=E.ctx.length/32;for(let t=0;tE.removeEventListener(e,T,t)}function 
a(E,e,T){T==null?E.removeAttribute(e):E.getAttribute(e)!==T&&E.setAttribute(e,T)}function Sn(E){let e;return{p(...T){e=T,e.forEach(t=>E.push(t))},r(){e.forEach(T=>E.splice(E.indexOf(T),1))}}}function on(E){return Array.from(E.childNodes)}function Le(E,e){e=""+e,E.data!==e&&(E.data=e)}function Ye(E,e){E.value=e??""}function ct(E,e,T,t){T==null?E.style.removeProperty(e):E.style.setProperty(e,T,t?"important":"")}function WT(E,e,T){for(let t=0;tE.indexOf(t)===-1?e.push(t):T.push(t)),T.forEach(t=>t()),et=e}const ft=new Set;let vE;function Ge(){vE={r:0,c:[],p:vE}}function ge(){vE.r||NE(vE.c),vE=vE.p}function m(E,e){E&&E.i&&(ft.delete(E),E.i(e))}function y(E,e,T,t){if(E&&E.o){if(ft.has(E))return;ft.add(E),vE.c.push(()=>{ft.delete(E),t&&(T&&E.d(1),t())}),E.o(e)}else t&&t()}function De(E){return(E==null?void 0:E.length)!==void 0?E:Array.from(E)}function Cn(E,e){E.d(1),e.delete(E.key)}function un(E,e,T,t,r,R,n,s,S,A,o,i){let _=E.length,c=R.length,P=_;const p={};for(;P--;)p[E[P].key]=P;const C=[],L=new Map,I=new Map,u=[];for(P=c;P--;){const O=i(r,R,P),N=T(O);let D=n.get(N);D?t&&u.push(()=>D.p(O,e)):(D=A(N,O),D.c()),L.set(N,C[P]=D),N in p&&I.set(N,Math.abs(P-p[N]))}const H=new Set,b=new Set;function M(O){m(O,1),O.m(s,o),n.set(O.key,O),o=O.first,c--}for(;_&&c;){const O=C[c-1],N=E[_-1],D=O.key,B=N.key;O===N?(o=O.first,_--,c--):L.has(B)?!n.has(D)||H.has(D)?M(O):b.has(B)?_--:I.get(D)>I.get(B)?(b.add(D),M(O)):(H.add(B),_--):(S(N,n),_--)}for(;_--;){const O=E[_];L.has(O.key)||S(O,n)}for(;c;)M(C[c-1]);return NE(u),C}function cn(E,e,T){const t=E.$$.props[e];t!==void 0&&(E.$$.bound[t]=T,T(E.$$.ctx[t]))}function K(E){E&&E.c()}function X(E,e,T){const{fragment:t,after_update:r}=E.$$;t&&t.m(e,T),dt(()=>{const R=E.$$.on_mount.map(cR).filter(zE);E.$$.on_destroy?E.$$.on_destroy.push(...R):NE(R),E.$$.on_mount=[]}),r.forEach(dt)}function k(E,e){const T=E.$$;T.fragment!==null&&(Ln(T.after_update),NE(T.on_destroy),T.fragment&&T.fragment.d(e),T.on_destroy=T.fragment=null,T.ctx=[])}function 
fn(E,e){E.$$.dirty[0]===-1&&(jE.push(E),Nn(),E.$$.dirty.fill(0)),E.$$.dirty[e/31|0]|=1<{const P=c.length?c[0]:_;return A.ctx&&r(A.ctx[i],A.ctx[i]=P)&&(!A.skip_bound&&A.bound[i]&&A.bound[i](P),o&&fn(E,i)),_}):[],A.update(),o=!0,NE(A.before_update),A.fragment=t?t(A.ctx):!1,e.target){if(e.hydrate){const i=on(e.target);A.fragment&&A.fragment.l(i),i.forEach(Y)}else A.fragment&&A.fragment.c();e.intro&&m(E.$$.fragment),X(E,e.target,e.anchor),dR()}st(S)}class ue{constructor(){wt(this,"$$");wt(this,"$$set")}$destroy(){k(this,1),this.$destroy=j}$on(e,T){if(!zE(T))return j;const t=this.$$.callbacks[e]||(this.$$.callbacks[e]=[]);return t.push(T),()=>{const r=t.indexOf(T);r!==-1&&t.splice(r,1)}}$set(e){this.$$set&&!sn(e)&&(this.$$.skip_bound=!0,this.$$set(e),this.$$.skip_bound=!1)}}const Pn="4";typeof window<"u"&&(window.__svelte||(window.__svelte={v:new Set})).v.add(Pn);const qE=[];function iE(E,e=j){let T;const t=new Set;function r(s){if(_e(E,s)&&(E=s,T)){const S=!qE.length;for(const A of t)A[1](),qE.push(A,E);if(S){for(let A=0;A{t.delete(A),t.size===0&&T&&(T(),T=null)}}return{set:r,update:R,subscribe:n}}let GE=iE(""),YE=iE([]),LT=iE(null),gt=iE(null),Ht=iE(!1),St=iE(!1),DE=iE("chat"),CT=iE([]),Et=iE(""),pR=iE(!1),BE=iE(""),VE=iE({debug:!0,logo:"",title:"Welcome to Vanna.AI",subtitle:"Loading...",show_training_data:!0,suggested_questions:!0,sql:!0,table:!0,csv_download:!0,chart:!0,redraw_chart:!0,auto_fix_sql:!0,ask_results_correct:!0,followup_questions:!0,summarization:!0,function_generation:!0,version:""}),bt=iE(null),MR=iE([]);function UR(){YE.set([]),Ht.set(!1),St.set(!1)}async function uT(E){let e=uE(VE),T=yn();if(Se({type:"user_question",question:E}),Ht.set(!0),T){const n=await Pe("generate_rewritten_question","GET",{last_question:T,new_question:E});n.type==="rewritten_question"&&n.question!==E&&(Se(n),E=n.question)}const t=await Pe("get_function","GET",{question:E});let 
r;if(e.function_generation&&t.type==="function")Se(t),r=t.id,GE.set(t.id),Et.set(t.function.instantiated_sql);else{const n=await Pe("generate_sql","GET",{question:E});if(Se(n),n.type!=="sql")return;window.location.hash=n.id,GE.set(n.id),Et.set(n.text),r=n.id}const R=await Pe("run_sql","GET",{id:r});if(Se(R),R.type==="df"){if(R.should_generate_chart){const n=await Pe("generate_plotly_figure","GET",{id:R.id});if(Se(n),n.type!=="plotly_figure")return;CT.update(s=>[...s,{question:E,id:n.id}])}if(e.summarization){const n=await Pe("generate_summary","GET",{id:r});Se(n)}Se({type:"feedback_question"}),Se({type:"feedback_buttons"})}}async function Dn(E){let e=uE(VE);if(Se(E),E.type!=="sql")return;window.location.hash=E.id,GE.set(E.id),Et.set(E.text);const T=await Pe("run_sql","GET",{id:E.id});if(Se(T),T.type!=="df")return;const t=await Pe("generate_plotly_figure","GET",{id:T.id});if(Se(t),t.type==="plotly_figure"){if(e.summarization){const r=await Pe("generate_summary","GET",{id:t.id});Se(r)}Se({type:"feedback_question"}),Se({type:"feedback_buttons"})}}function dn(E){Se({type:"user_question",question:"Re-run the SQL"}),Pe("run_sql","GET",{id:E}).then(Se).then(e=>{e.type==="df"&&Pe("generate_plotly_figure","GET",{id:e.id}).then(Se).then(T=>{T.type==="plotly_figure"&&Pe("generate_followup_questions","GET",{id:T.id}).then(Se)})})}function mR(){Pe("get_question_history","GET",[]).then(gn)}function pn(){Pe("get_config","GET",[]).then(Gn)}function cT(){window.location.hash="functions",DE.set("functions"),Pe("get_all_functions","GET",[]).then(mn)}function hR(){window.location.hash="training-data",DE.set("training-data"),Pe("get_training_data","GET",[]).then(pt)}function it(){window.location.hash="",DE.set("chat"),UR(),uE(LT)===null&&Pe("generate_questions","GET",[]).then(hn),mR()}function Mn(E){window.location.hash=E,DE.set("chat"),UR(),Ht.set(!0),Pe("load_question","GET",{id:E}).then(Se)}function 
Un(E){gt.set(null),Pe("remove_training_data","POST",{id:E}).then(e=>{Pe("get_training_data","GET",[]).then(pt)})}function Se(E){return E.type==="not_logged_in"?(bt.set(E.html),DE.set("login"),E):(YE.update(e=>[...e,E]),bn(),E)}function pt(E){return gt.set(E),E.type==="df"?JSON.parse(E.df).length===0&&DE.set("no-training-data"):E.type==="not_logged_in"&&(bt.set(E.html),DE.set("login")),E}function mn(E){return E.type==="functions"&&MR.set(E.functions),E}function hn(E){return LT.set(E),E}function Gn(E){return E.type==="config"?(VE.set(E.config),E.config.debug&&xn()):E.type==="not_logged_in"&&(bt.set(E.html),DE.set("login")),E}function gn(E){return E.type==="question_history"&&CT.set(E.questions),E}function Hn(E,e){gt.set(null);let T={};T[e]=E,Pe("train","POST",T).then(pt).then(t=>{t.type!=="error"&&Pe("get_training_data","GET",[]).then(pt)})}async function Pe(E,e,T){try{St.set(!0);let t="",r;if(e==="GET")t=Object.entries(T).filter(([n,s])=>n!=="endpoint"&&n!=="addMessage").map(([n,s])=>`${encodeURIComponent(n)}=${encodeURIComponent(s)}`).join("&"),r=await fetch(`/api/v0/${E}?${t}`);else{let n=JSON.stringify(T);r=await fetch(`/api/v0/${E}`,{method:"POST",headers:{"Content-Type":"application/json"},body:n})}if(!r.ok)throw new Error("The server returned an error. See the server logs for more details. If you are running in Colab, this function is probably not supported. 
Please try running in a local environment.");const R=await r.json();return St.set(!1),R}catch(t){return St.set(!1),{type:"error",error:String(t)}}}function bn(){setTimeout(()=>{window.scrollTo({top:document.body.scrollHeight,behavior:"smooth"})},100)}function yn(){let E=uE(YE),e=E.findLast(T=>T.type==="rewritten_question");return e||(e=E.findLast(T=>T.type==="user_question")),e&&(e.type==="rewritten_question"||e.type==="user_question")?e.question:null}function fT(){let E=uE(YE),e=E.find(T=>T.type==="user_question");if(e&&e.type==="user_question"){let T=E.findLast(t=>t.type==="sql");if(T&&T.type==="sql")return{question:e.question,sql:T.text}}return null}function at(E){YE.update(e=>e.filter(T=>T.type!==E))}function Bn(E){Pe("fix_sql","POST",{id:uE(GE),error:E}).then(Dn)}function vn(E){let T=uE(YE).find(t=>t.type==="user_question");T&&T.type==="user_question"&&(Pe("update_sql","POST",{id:uE(GE),sql:E}).then(Se).then(t=>{t.type==="sql"&&(Et.set(t.text),Pe("run_sql","GET",{id:t.id}).then(Se).then(r=>{r.type==="df"?JSON.parse(r.df).length>1?Pe("generate_plotly_figure","GET",{id:r.id}).then(Se).then(n=>{Se({type:"feedback_question"}),Se({type:"feedback_buttons"})}):(Se({type:"feedback_question"}),Se({type:"feedback_buttons"})):(Se({type:"feedback_question"}),Se({type:"feedback_buttons"}))}))}),at("user_sql"))}function Fn(){Se({type:"chart_modification"})}function Yn(){at("feedback_buttons"),Se({type:"feedback_correct"}),fT()?Pe("create_function","GET",{id:uE(GE)}).then(Se):console.log("No Question-SQL Found")}function Vn(E,e){Pe("update_function","POST",{old_function_name:E,updated_function:e})}function Wn(E){Pe("delete_function","POST",{function_name:E}).finally(()=>{cT()})}function wn(){at("feedback_buttons"),Se({type:"feedback_correct"});let E=fT();E&&(Pe("train","POST",E),Pe("generate_followup_questions","GET",{id:uE(GE)}).then(Se))}function wT(){at("feedback_buttons"),Se({type:"feedback_incorrect"}),Se({type:"user_sql"})}function 
$n(E){at("chart_modification"),Se({type:"user_question",question:"Update the chart with these instructions: "+E}),Pe("generate_plotly_figure","GET",{id:uE(GE),chart_instructions:E}).then(Se)}function xn(){var E=new WebSocket("ws://"+window.location.host+"/api/v0/log");E.onopen=function(){console.log("Connected to WebSocket server at /log.")},E.onmessage=function(e){console.log("Received message:",e.data);try{var T=JSON.parse(e.data)}catch(r){console.error("Error parsing JSON:",r);return}var t=document.getElementById("log-contents");t&&(t.innerHTML+=""+T.title+" "+JSON.stringify(T.message)+" ")},E.onclose=function(e){console.log("WebSocket connection closed:",e)},E.onerror=function(e){console.error("WebSocket error:",e)}}function $T(E,e,T){const t=E.slice();return t[3]=e[T],t}function xT(E){let e,T,t,r;return{c(){e=f("li"),T=f("button"),T.innerHTML=`
Functions`,a(T,"class","flex items-center gap-x-3 py-2 px-3 text-sm text-slate-700 rounded-md hover:bg-gray-100 dark:hover:bg-gray-900 dark:text-slate-400 dark:hover:text-slate-300 border border-gray-200 dark:border-gray-700 w-full")},m(R,n){V(R,e,n),l(e,T),t||(r=Ne(T,"click",cT),t=!0)},d(R){R&&Y(e),t=!1,r()}}}function XT(E){let e,T,t,r;return{c(){e=f("li"),T=f("button"),T.innerHTML=`
Training Data`,a(T,"class","flex items-center gap-x-3 py-2 px-3 text-sm text-slate-700 rounded-md hover:bg-gray-100 dark:hover:bg-gray-900 dark:text-slate-400 dark:hover:text-slate-300 border border-gray-200 dark:border-gray-700 w-full")},m(R,n){V(R,e,n),l(e,T),t||(r=Ne(T,"click",hR),t=!0)},d(R){R&&Y(e),t=!1,r()}}}function kT(E){let e,T,t,r,R,n=E[3].question+"",s,S,A,o;function i(){return E[2](E[3])}return{c(){e=f("li"),T=f("button"),t=OE("svg"),r=OE("path"),R=$(),s=te(n),S=$(),a(r,"stroke-linecap","round"),a(r,"stroke-linejoin","round"),a(r,"d","M7.5 8.25h9m-9 3H12m-9.75 1.51c0 1.6 1.123 2.994 2.707 3.227 1.129.166 2.27.293 3.423.379.35.026.67.21.865.501L12 21l2.755-4.133a1.14 1.14 0 01.865-.501 48.172 48.172 0 003.423-.379c1.584-.233 2.707-1.626 2.707-3.228V6.741c0-1.602-1.123-2.995-2.707-3.228A48.394 48.394 0 0012 3c-2.392 0-4.744.175-7.043.513C3.373 3.746 2.25 5.14 2.25 6.741v6.018z"),a(t,"class","w-3.5 h-3.5"),a(t,"fill","none"),a(t,"stroke","currentColor"),a(t,"stroke-width","1.5"),a(t,"viewBox","0 0 24 24"),a(t,"xmlns","http://www.w3.org/2000/svg"),a(t,"aria-hidden","true"),a(T,"class","flex items-center text-left gap-x-3 py-2 px-3 text-sm text-slate-700 rounded-md hover:bg-gray-100 dark:hover:bg-gray-900 dark:text-slate-400 dark:hover:text-slate-300")},m(_,c){V(_,e,c),l(e,T),l(T,t),l(t,r),l(T,R),l(T,s),l(e,S),A||(o=Ne(T,"click",i),A=!0)},p(_,c){E=_,c&2&&n!==(n=E[3].question+"")&&Le(s,n)},d(_){_&&Y(e),A=!1,o()}}}function Xn(E){let e,T,t,r,R,n,s,S,A,o,i,_,c,P,p,C,L,I,u,H,b,M=E[0].version+"",O,N,D,B,h,G=E[0].function_generation&&xT(),F=E[0].show_training_data&&XT(),W=De(E[1]),x=[];for(let J=0;JSidebar',S=$(),A=f("div"),o=f("ul"),G&&G.c(),i=$(),F&&F.c(),_=$(),c=f("li"),P=f("button"),P.innerHTML=`
New question`,p=$();for(let J=0;JSign out
`,a(r,"class","w-28 h-auto"),VT(r.src,R=E[0].logo)||a(r,"src",R),a(r,"alt","Vanna Logo"),a(s,"class","lg:hidden"),a(t,"class","flex items-center justify-between py-4 pr-4 pl-7"),a(P,"class","w-full py-2 px-4 inline-flex items-center gap-x-2 text-sm font-semibold rounded-lg border border-transparent bg-blue-600 text-white hover:bg-blue-700 disabled:opacity-50 disabled:pointer-events-none"),a(o,"class","space-y-1.5 p-4"),a(A,"class","h-full"),a(H,"class","block w-1.5 h-1.5 rounded-full bg-green-600"),a(u,"class","inline-flex items-center gap-x-2 text-xs text-green-600"),a(I,"class","py-2.5 px-7"),a(D,"class","p-4 border-t border-gray-200 dark:border-gray-700"),a(L,"class","mt-auto"),a(T,"class","hs-accordion-group w-full h-full flex flex-col"),a(T,"data-hs-accordion-always-open",""),a(e,"id","application-sidebar"),a(e,"class","hs-overlay hs-overlay-open:translate-x-0 -translate-x-full transition-all duration-300 transform hidden fixed top-0 left-0 bottom-0 z-[60] w-64 bg-white border-r border-gray-200 overflow-y-auto scrollbar-y lg:block lg:translate-x-0 lg:right-auto lg:bottom-0 dark:scrollbar-y dark:bg-slate-900 dark:border-gray-700")},m(J,oe){V(J,e,oe),l(e,T),l(T,t),l(t,r),l(t,n),l(t,s),l(T,S),l(T,A),l(A,o),G&&G.m(o,null),l(o,i),F&&F.m(o,null),l(o,_),l(o,c),l(c,P),l(o,p);for(let z=0;zT(0,t=n)),eE(E,CT,n=>T(1,r=n)),[t,r,n=>{Mn(n.id)}]}class Kn extends ue{constructor(e){super(),Ce(this,e,kn,Xn,_e,{})}}var Jn=typeof globalThis<"u"?globalThis:typeof window<"u"?window:typeof global<"u"?global:typeof self<"u"?self:{};function qn(E){return E&&E.__esModule&&Object.prototype.hasOwnProperty.call(E,"default")?E.default:E}var Qn={exports:{}};/*! 
For license information please see preline.js.LICENSE.txt */(function(E,e){(function(T,t){E.exports=t()})(self,function(){return(()=>{var T={661:(n,s,S)=>{function A(p){return A=typeof Symbol=="function"&&typeof Symbol.iterator=="symbol"?function(C){return typeof C}:function(C){return C&&typeof Symbol=="function"&&C.constructor===Symbol&&C!==Symbol.prototype?"symbol":typeof C},A(p)}function o(p,C){for(var L=0;L"u"||!Reflect.construct||Reflect.construct.sham)return!1;if(typeof Proxy=="function")return!0;try{return Boolean.prototype.valueOf.call(Reflect.construct(Boolean,[],function(){})),!0}catch{return!1}}(),function(){var M,O=c(I);if(u){var N=c(this).constructor;M=Reflect.construct(O,arguments,N)}else M=O.apply(this,arguments);return _(this,M)});function b(){return function(M,O){if(!(M instanceof O))throw new TypeError("Cannot call a class as a function")}(this,b),H.call(this,".hs-accordion")}return C=b,(L=[{key:"init",value:function(){var M=this;document.addEventListener("click",function(O){var N=O.target,D=N.closest(M.selector),B=N.closest(".hs-accordion-toggle"),h=N.closest(".hs-accordion-group");D&&h&&B&&(M._hideAll(D),M.show(D))})}},{key:"show",value:function(M){var O=this;if(M.classList.contains("active"))return this.hide(M);M.classList.add("active");var N=M.querySelector(".hs-accordion-content");N.style.display="block",N.style.height=0,setTimeout(function(){N.style.height="".concat(N.scrollHeight,"px")}),this.afterTransition(N,function(){M.classList.contains("active")&&(N.style.height="",O._fireEvent("open",M),O._dispatch("open.hs.accordion",M,M))})}},{key:"hide",value:function(M){var O=this,N=M.querySelector(".hs-accordion-content");N.style.height="".concat(N.scrollHeight,"px"),setTimeout(function(){N.style.height=0}),this.afterTransition(N,function(){M.classList.contains("active")||(N.style.display="",O._fireEvent("hide",M),O._dispatch("hide.hs.accordion",M,M))}),M.classList.remove("active")}},{key:"_hideAll",value:function(M){var 
O=this,N=M.closest(".hs-accordion-group");N.hasAttribute("data-hs-accordion-always-open")||N.querySelectorAll(this.selector).forEach(function(D){M!==D&&O.hide(D)})}}])&&o(C.prototype,L),Object.defineProperty(C,"prototype",{writable:!1}),b}(S(765).Z);window.HSAccordion=new P,document.addEventListener("load",window.HSAccordion.init())},795:(n,s,S)=>{function A(C){return A=typeof Symbol=="function"&&typeof Symbol.iterator=="symbol"?function(L){return typeof L}:function(L){return L&&typeof Symbol=="function"&&L.constructor===Symbol&&L!==Symbol.prototype?"symbol":typeof L},A(C)}function o(C,L){(L==null||L>C.length)&&(L=C.length);for(var I=0,u=new Array(L);I"u"||!Reflect.construct||Reflect.construct.sham)return!1;if(typeof Proxy=="function")return!0;try{return Boolean.prototype.valueOf.call(Reflect.construct(Boolean,[],function(){})),!0}catch{return!1}}(),function(){var O,N=P(u);if(H){var D=P(this).constructor;O=Reflect.construct(N,arguments,D)}else O=N.apply(this,arguments);return c(this,O)});function M(){return function(O,N){if(!(O instanceof N))throw new TypeError("Cannot call a class as a function")}(this,M),b.call(this,"[data-hs-collapse]")}return L=M,(I=[{key:"init",value:function(){var O=this;document.addEventListener("click",function(N){var D=N.target.closest(O.selector);if(D){var B=document.querySelectorAll(D.getAttribute("data-hs-collapse"));O.toggle(B)}})}},{key:"toggle",value:function(O){var N,D=this;O.length&&(N=O,function(B){if(Array.isArray(B))return o(B)}(N)||function(B){if(typeof Symbol<"u"&&B[Symbol.iterator]!=null||B["@@iterator"]!=null)return Array.from(B)}(N)||function(B,h){if(B){if(typeof B=="string")return o(B,h);var G=Object.prototype.toString.call(B).slice(8,-1);return G==="Object"&&B.constructor&&(G=B.constructor.name),G==="Map"||G==="Set"?Array.from(B):G==="Arguments"||/^(?:Ui|I)nt(?:8|16|32)(?:Clamped)?Array$/.test(G)?o(B,h):void 0}}(N)||function(){throw new TypeError(`Invalid attempt to spread non-iterable instance.
In order to be iterable, non-array objects must have a [Symbol.iterator]() method.`)}()).forEach(function(B){B.classList.contains("hidden")?D.show(B):D.hide(B)})}},{key:"show",value:function(O){var N=this;O.classList.add("open"),O.classList.remove("hidden"),O.style.height=0,document.querySelectorAll(this.selector).forEach(function(D){O.closest(D.getAttribute("data-hs-collapse"))&&D.classList.add("open")}),O.style.height="".concat(O.scrollHeight,"px"),this.afterTransition(O,function(){O.classList.contains("open")&&(O.style.height="",N._fireEvent("open",O),N._dispatch("open.hs.collapse",O,O))})}},{key:"hide",value:function(O){var N=this;O.style.height="".concat(O.scrollHeight,"px"),setTimeout(function(){O.style.height=0}),O.classList.remove("open"),this.afterTransition(O,function(){O.classList.contains("open")||(O.classList.add("hidden"),O.style.height=null,N._fireEvent("hide",O),N._dispatch("hide.hs.collapse",O,O),O.querySelectorAll(".hs-mega-menu-content.block").forEach(function(D){D.classList.remove("block"),D.classList.add("hidden")}))}),document.querySelectorAll(this.selector).forEach(function(D){O.closest(D.getAttribute("data-hs-collapse"))&&D.classList.remove("open")})}}])&&i(L.prototype,I),Object.defineProperty(L,"prototype",{writable:!1}),M}(S(765).Z);window.HSCollapse=new p,document.addEventListener("load",window.HSCollapse.init())},682:(n,s,S)=>{var A=S(714),o=S(765);const i={historyIndex:-1,addHistory:function(H){this.historyIndex=H},existsInHistory:function(H){return H>this.historyIndex},clearHistory:function(){this.historyIndex=-1}};function _(H){return _=typeof Symbol=="function"&&typeof Symbol.iterator=="symbol"?function(b){return typeof b}:function(b){return b&&typeof Symbol=="function"&&b.constructor===Symbol&&b!==Symbol.prototype?"symbol":typeof b},_(H)}function c(H){return function(b){if(Array.isArray(b))return P(b)}(H)||function(b){if(typeof Symbol<"u"&&b[Symbol.iterator]!=null||b["@@iterator"]!=null)return 
Array.from(b)}(H)||function(b,M){if(b){if(typeof b=="string")return P(b,M);var O=Object.prototype.toString.call(b).slice(8,-1);return O==="Object"&&b.constructor&&(O=b.constructor.name),O==="Map"||O==="Set"?Array.from(b):O==="Arguments"||/^(?:Ui|I)nt(?:8|16|32)(?:Clamped)?Array$/.test(O)?P(b,M):void 0}}(H)||function(){throw new TypeError(`Invalid attempt to spread non-iterable instance.
In order to be iterable, non-array objects must have a [Symbol.iterator]() method.`)}()}function P(H,b){(b==null||b>H.length)&&(b=H.length);for(var M=0,O=new Array(b);M"u"||!Reflect.construct||Reflect.construct.sham)return!1;if(typeof Proxy=="function")return!0;try{return Boolean.prototype.valueOf.call(Reflect.construct(Boolean,[],function(){})),!0}catch{return!1}}(),function(){var h,G=I(O);if(N){var F=I(this).constructor;h=Reflect.construct(G,arguments,F)}else h=G.apply(this,arguments);return L(this,h)});function B(){var h;return function(G,F){if(!(G instanceof F))throw new TypeError("Cannot call a class as a function")}(this,B),(h=D.call(this,".hs-dropdown")).positions={top:"top","top-left":"top-start","top-right":"top-end",bottom:"bottom","bottom-left":"bottom-start","bottom-right":"bottom-end",right:"right","right-top":"right-start","right-bottom":"right-end",left:"left","left-top":"left-start","left-bottom":"left-end"},h.absoluteStrategyModifiers=function(G){return[{name:"applyStyles",fn:function(F){var W=(window.getComputedStyle(G).getPropertyValue("--strategy")||"absolute").replace(" ",""),x=(window.getComputedStyle(G).getPropertyValue("--adaptive")||"adaptive").replace(" ","");F.state.elements.popper.style.position=W,F.state.elements.popper.style.transform=x==="adaptive"?F.state.styles.popper.transform:null,F.state.elements.popper.style.top=null,F.state.elements.popper.style.bottom=null,F.state.elements.popper.style.left=null,F.state.elements.popper.style.right=null,F.state.elements.popper.style.margin=0}},{name:"computeStyles",options:{adaptive:!1}}]},h._history=i,h}return b=B,M=[{key:"init",value:function(){var h=this;document.addEventListener("click",function(G){var F=G.target,W=F.closest(h.selector),x=F.closest(".hs-dropdown-menu");if(W&&W.classList.contains("open")||h._closeOthers(W),x){var J=(window.getComputedStyle(W).getPropertyValue("--auto-close")||"").replace(" 
","");if((J=="false"||J=="inside")&&!W.parentElement.closest(h.selector))return}W&&(W.classList.contains("open")?h.close(W):h.open(W))}),document.addEventListener("mousemove",function(G){var F=G.target,W=F.closest(h.selector);if(F.closest(".hs-dropdown-menu"),W){var x=(window.getComputedStyle(W).getPropertyValue("--trigger")||"click").replace(" ","");if(x!=="hover")return;W&&W.classList.contains("open")||h._closeOthers(W),x!=="hover"||W.classList.contains("open")||/iPad|iPhone|iPod/.test(navigator.platform)||navigator.maxTouchPoints&&navigator.maxTouchPoints>2&&/MacIntel/.test(navigator.platform)||navigator.maxTouchPoints&&navigator.maxTouchPoints>2&&/MacIntel/.test(navigator.platform)||h._hover(F)}}),document.addEventListener("keydown",this._keyboardSupport.bind(this)),window.addEventListener("resize",function(){document.querySelectorAll(".hs-dropdown.open").forEach(function(G){h.close(G,!0)})})}},{key:"_closeOthers",value:function(){var h=this,G=arguments.length>0&&arguments[0]!==void 0?arguments[0]:null,F=document.querySelectorAll("".concat(this.selector,".open"));F.forEach(function(W){if(!G||G.closest(".hs-dropdown.open")!==W){var x=(window.getComputedStyle(W).getPropertyValue("--auto-close")||"").replace(" ","");x!="false"&&x!="outside"&&h.close(W)}})}},{key:"_hover",value:function(h){var G=this,F=h.closest(this.selector);this.open(F),document.addEventListener("mousemove",function W(x){x.target.closest(G.selector)&&x.target.closest(G.selector)!==F.parentElement.closest(G.selector)||(G.close(F),document.removeEventListener("mousemove",W,!0))},!0)}},{key:"close",value:function(h){var G=this,F=arguments.length>1&&arguments[1]!==void 
0&&arguments[1],W=h.querySelector(".hs-dropdown-menu"),x=function(){h.classList.contains("open")||(W.classList.remove("block"),W.classList.add("hidden"),W.style.inset=null,W.style.position=null,h._popper&&h._popper.destroy())};F||this.afterTransition(h.querySelector("[data-hs-dropdown-transition]")||W,function(){x()}),W.style.margin=null,h.classList.remove("open"),F&&x(),this._fireEvent("close",h),this._dispatch("close.hs.dropdown",h,h);var J=W.querySelectorAll(".hs-dropdown.open");J.forEach(function(oe){G.close(oe,!0)})}},{key:"open",value:function(h){var G=h.querySelector(".hs-dropdown-menu"),F=(window.getComputedStyle(h).getPropertyValue("--placement")||"").replace(" ",""),W=(window.getComputedStyle(h).getPropertyValue("--strategy")||"fixed").replace(" ",""),x=((window.getComputedStyle(h).getPropertyValue("--adaptive")||"adaptive").replace(" ",""),parseInt((window.getComputedStyle(h).getPropertyValue("--offset")||"10").replace(" ","")));if(W!=="static"){h._popper&&h._popper.destroy();var J=(0,A.fi)(h,G,{placement:this.positions[F]||"bottom-start",strategy:W,modifiers:[].concat(c(W!=="fixed"?this.absoluteStrategyModifiers(h):[]),[{name:"offset",options:{offset:[0,x]}}])});h._popper=J}G.style.margin=null,G.classList.add("block"),G.classList.remove("hidden"),setTimeout(function(){h.classList.add("open")}),this._fireEvent("open",h),this._dispatch("open.hs.dropdown",h,h)}},{key:"_keyboardSupport",value:function(h){var G=document.querySelector(".hs-dropdown.open");if(G)return h.keyCode===27?(h.preventDefault(),this._esc(G)):h.keyCode===40?(h.preventDefault(),this._down(G)):h.keyCode===38?(h.preventDefault(),this._up(G)):h.keyCode===36?(h.preventDefault(),this._start(G)):h.keyCode===35?(h.preventDefault(),this._end(G)):void this._byChar(G,h.key)}},{key:"_esc",value:function(h){this.close(h)}},{key:"_up",value:function(h){var 
G=h.querySelector(".hs-dropdown-menu"),F=c(G.querySelectorAll("a")).reverse().filter(function(J){return!J.disabled}),W=G.querySelector("a:focus"),x=F.findIndex(function(J){return J===W});x+1{function A(C){return A=typeof Symbol=="function"&&typeof Symbol.iterator=="symbol"?function(L){return typeof L}:function(L){return L&&typeof Symbol=="function"&&L.constructor===Symbol&&L!==Symbol.prototype?"symbol":typeof L},A(C)}function o(C,L){(L==null||L>C.length)&&(L=C.length);for(var I=0,u=new Array(L);I"u"||!Reflect.construct||Reflect.construct.sham)return!1;if(typeof Proxy=="function")return!0;try{return Boolean.prototype.valueOf.call(Reflect.construct(Boolean,[],function(){})),!0}catch{return!1}}(),function(){var O,N=P(u);if(H){var D=P(this).constructor;O=Reflect.construct(N,arguments,D)}else O=N.apply(this,arguments);return c(this,O)});function M(){var O;return function(N,D){if(!(N instanceof D))throw new TypeError("Cannot call a class as a function")}(this,M),(O=b.call(this,"[data-hs-overlay]")).openNextOverlay=!1,O}return L=M,(I=[{key:"init",value:function(){var O=this;document.addEventListener("click",function(N){var D=N.target.closest(O.selector),B=N.target.closest("[data-hs-overlay-close]"),h=N.target.getAttribute("aria-overlay")==="true";return B?O.close(B.closest(".hs-overlay.open")):D?O.toggle(document.querySelector(D.getAttribute("data-hs-overlay"))):void(h&&O._onBackdropClick(N.target))}),document.addEventListener("keydown",function(N){if(N.keyCode===27){var D=document.querySelector(".hs-overlay.open");if(!D)return;setTimeout(function(){D.getAttribute("data-hs-overlay-keyboard")!=="false"&&O.close(D)})}})}},{key:"toggle",value:function(O){O&&(O.classList.contains("hidden")?this.open(O):this.close(O))}},{key:"open",value:function(O){var N=this;if(O){var D=document.querySelector(".hs-overlay.open"),B=this.getClassProperty(O,"--body-scroll","false")!=="true";if(D)return 
this.openNextOverlay=!0,this.close(D).then(function(){N.open(O),N.openNextOverlay=!1});B&&(document.body.style.overflow="hidden"),this._buildBackdrop(O),this._checkTimer(O),this._autoHide(O),O.classList.remove("hidden"),O.setAttribute("aria-overlay","true"),O.setAttribute("tabindex","-1"),setTimeout(function(){O.classList.contains("hidden")||(O.classList.add("open"),N._fireEvent("open",O),N._dispatch("open.hs.overlay",O,O),N._focusInput(O))},50)}}},{key:"close",value:function(O){var N=this;return new Promise(function(D){O&&(O.classList.remove("open"),O.removeAttribute("aria-overlay"),O.removeAttribute("tabindex","-1"),N.afterTransition(O,function(){O.classList.contains("open")||(O.classList.add("hidden"),N._destroyBackdrop(),N._fireEvent("close",O),N._dispatch("close.hs.overlay",O,O),document.body.style.overflow="",D(O))}))})}},{key:"_autoHide",value:function(O){var N=this,D=parseInt(this.getClassProperty(O,"--auto-hide","0"));D&&(O.autoHide=setTimeout(function(){N.close(O)},D))}},{key:"_checkTimer",value:function(O){O.autoHide&&(clearTimeout(O.autoHide),delete O.autoHide)}},{key:"_onBackdropClick",value:function(O){this.getClassProperty(O,"--overlay-backdrop","true")!=="static"&&this.close(O)}},{key:"_buildBackdrop",value:function(O){var N,D=this,B=O.getAttribute("data-hs-overlay-backdrop-container")||!1,h=document.createElement("div"),G="transition duration fixed inset-0 z-50 bg-gray-900 bg-opacity-50 dark:bg-opacity-80 hs-overlay-backdrop",F=function(J,oe){var z=typeof Symbol<"u"&&J[Symbol.iterator]||J["@@iterator"];if(!z){if(Array.isArray(J)||(z=function(me,fE){if(me){if(typeof me=="string")return o(me,fE);var rE=Object.prototype.toString.call(me).slice(8,-1);return rE==="Object"&&me.constructor&&(rE=me.constructor.name),rE==="Map"||rE==="Set"?Array.from(me):rE==="Arguments"||/^(?:Ui|I)nt(?:8|16|32)(?:Clamped)?Array$/.test(rE)?o(me,fE):void 0}}(J))||oe&&J&&typeof J.length=="number"){z&&(J=z);var Oe=0,Ue=function(){};return{s:Ue,n:function(){return 
Oe>=J.length?{done:!0}:{done:!1,value:J[Oe++]}},e:function(me){throw me},f:Ue}}throw new TypeError(`Invalid attempt to iterate non-iterable instance.
In order to be iterable, non-array objects must have a [Symbol.iterator]() method.`)}var ye,TE=!0,dE=!1;return{s:function(){z=z.call(J)},n:function(){var me=z.next();return TE=me.done,me},e:function(me){dE=!0,ye=me},f:function(){try{TE||z.return==null||z.return()}finally{if(dE)throw ye}}}}(O.classList.values());try{for(F.s();!(N=F.n()).done;){var W=N.value;W.startsWith("hs-overlay-backdrop-open:")&&(G+=" ".concat(W))}}catch(J){F.e(J)}finally{F.f()}var x=this.getClassProperty(O,"--overlay-backdrop","true")!=="static";this.getClassProperty(O,"--overlay-backdrop","true")==="false"||(B&&((h=document.querySelector(B).cloneNode(!0)).classList.remove("hidden"),G=h.classList,h.classList=""),x&&h.addEventListener("click",function(){return D.close(O)},!0),h.setAttribute("data-hs-overlay-backdrop-template",""),document.body.appendChild(h),setTimeout(function(){h.classList=G}))}},{key:"_destroyBackdrop",value:function(){var O=document.querySelector("[data-hs-overlay-backdrop-template]");O&&(this.openNextOverlay&&(O.style.transitionDuration="".concat(1.8*parseFloat(window.getComputedStyle(O).transitionDuration.replace(/[^\\d.-]/g,"")),"s")),O.classList.add("opacity-0"),this.afterTransition(O,function(){O.remove()}))}},{key:"_focusInput",value:function(O){var N=O.querySelector("[autofocus]");N&&N.focus()}}])&&i(L.prototype,I),Object.defineProperty(L,"prototype",{writable:!1}),M}(S(765).Z);window.HSOverlay=new p,document.addEventListener("load",window.HSOverlay.init())},181:(n,s,S)=>{function A(p){return A=typeof Symbol=="function"&&typeof Symbol.iterator=="symbol"?function(C){return typeof C}:function(C){return C&&typeof Symbol=="function"&&C.constructor===Symbol&&C!==Symbol.prototype?"symbol":typeof C},A(p)}function o(p,C){for(var L=0;L"u"||!Reflect.construct||Reflect.construct.sham)return!1;if(typeof Proxy=="function")return!0;try{return Boolean.prototype.valueOf.call(Reflect.construct(Boolean,[],function(){})),!0}catch{return!1}}(),function(){var M,O=c(I);if(u){var 
N=c(this).constructor;M=Reflect.construct(O,arguments,N)}else M=O.apply(this,arguments);return _(this,M)});function b(){return function(M,O){if(!(M instanceof O))throw new TypeError("Cannot call a class as a function")}(this,b),H.call(this,"[data-hs-remove-element]")}return C=b,(L=[{key:"init",value:function(){var M=this;document.addEventListener("click",function(O){var N=O.target.closest(M.selector);if(N){var D=document.querySelector(N.getAttribute("data-hs-remove-element"));D&&(D.classList.add("hs-removing"),M.afterTransition(D,function(){D.remove()}))}})}}])&&o(C.prototype,L),Object.defineProperty(C,"prototype",{writable:!1}),b}(S(765).Z);window.HSRemoveElement=new P,document.addEventListener("load",window.HSRemoveElement.init())},778:(n,s,S)=>{function A(p){return A=typeof Symbol=="function"&&typeof Symbol.iterator=="symbol"?function(C){return typeof C}:function(C){return C&&typeof Symbol=="function"&&C.constructor===Symbol&&C!==Symbol.prototype?"symbol":typeof C},A(p)}function o(p,C){for(var L=0;L"u"||!Reflect.construct||Reflect.construct.sham)return!1;if(typeof Proxy=="function")return!0;try{return Boolean.prototype.valueOf.call(Reflect.construct(Boolean,[],function(){})),!0}catch{return!1}}(),function(){var M,O=c(I);if(u){var N=c(this).constructor;M=Reflect.construct(O,arguments,N)}else M=O.apply(this,arguments);return _(this,M)});function b(){var M;return function(O,N){if(!(O instanceof N))throw new TypeError("Cannot call a class as a function")}(this,b),(M=H.call(this,"[data-hs-scrollspy] ")).activeSection=null,M}return C=b,(L=[{key:"init",value:function(){var M=this;document.querySelectorAll(this.selector).forEach(function(O){var 
N=document.querySelector(O.getAttribute("data-hs-scrollspy")),D=O.querySelectorAll("[href]"),B=N.children,h=O.getAttribute("data-hs-scrollspy-scrollable-parent")?document.querySelector(O.getAttribute("data-hs-scrollspy-scrollable-parent")):document;Array.from(B).forEach(function(G){G.getAttribute("id")&&h.addEventListener("scroll",function(F){return M._update({$scrollspyEl:O,$scrollspyContentEl:N,links:D,$sectionEl:G,sections:B,ev:F})})}),D.forEach(function(G){G.addEventListener("click",function(F){F.preventDefault(),G.getAttribute("href")!=="javascript:;"&&M._scrollTo({$scrollspyEl:O,$scrollableEl:h,$link:G})})})})}},{key:"_update",value:function(M){var O=M.ev,N=M.$scrollspyEl,D=(M.sections,M.links),B=M.$sectionEl,h=parseInt(this.getClassProperty(N,"--scrollspy-offset","0")),G=this.getClassProperty(B,"--scrollspy-offset")||h,F=O.target===document?0:parseInt(O.target.getBoundingClientRect().top),W=parseInt(B.getBoundingClientRect().top)-G-F,x=B.offsetHeight;if(W<=0&&W+x>0){if(this.activeSection===B)return;D.forEach(function(Oe){Oe.classList.remove("active")});var J=N.querySelector('[href="#'.concat(B.getAttribute("id"),'"]'));if(J){J.classList.add("active");var oe=J.closest("[data-hs-scrollspy-group]");if(oe){var z=oe.querySelector("[href]");z&&z.classList.add("active")}}this.activeSection=B}}},{key:"_scrollTo",value:function(M){var O=M.$scrollspyEl,N=M.$scrollableEl,D=M.$link,B=document.querySelector(D.getAttribute("href")),h=parseInt(this.getClassProperty(O,"--scrollspy-offset","0")),G=this.getClassProperty(B,"--scrollspy-offset")||h,F=N===document?0:N.offsetTop,W=B.offsetTop-G-F,x=N===document?window:N;this._fireEvent("scroll",O),this._dispatch("scroll.hs.scrollspy",O,O),window.history.replaceState(null,null,D.getAttribute("href")),x.scrollTo({top:W,left:0,behavior:"smooth"})}}])&&o(C.prototype,L),Object.defineProperty(C,"prototype",{writable:!1}),b}(S(765).Z);window.HSScrollspy=new 
P,document.addEventListener("load",window.HSScrollspy.init())},51:(n,s,S)=>{function A(L){return A=typeof Symbol=="function"&&typeof Symbol.iterator=="symbol"?function(I){return typeof I}:function(I){return I&&typeof Symbol=="function"&&I.constructor===Symbol&&I!==Symbol.prototype?"symbol":typeof I},A(L)}function o(L){return function(I){if(Array.isArray(I))return i(I)}(L)||function(I){if(typeof Symbol<"u"&&I[Symbol.iterator]!=null||I["@@iterator"]!=null)return Array.from(I)}(L)||function(I,u){if(I){if(typeof I=="string")return i(I,u);var H=Object.prototype.toString.call(I).slice(8,-1);return H==="Object"&&I.constructor&&(H=I.constructor.name),H==="Map"||H==="Set"?Array.from(I):H==="Arguments"||/^(?:Ui|I)nt(?:8|16|32)(?:Clamped)?Array$/.test(H)?i(I,u):void 0}}(L)||function(){throw new TypeError(`Invalid attempt to spread non-iterable instance.
In order to be iterable, non-array objects must have a [Symbol.iterator]() method.`)}()}function i(L,I){(I==null||I>L.length)&&(I=L.length);for(var u=0,H=new Array(I);u"u"||!Reflect.construct||Reflect.construct.sham)return!1;if(typeof Proxy=="function")return!0;try{return Boolean.prototype.valueOf.call(Reflect.construct(Boolean,[],function(){})),!0}catch{return!1}}(),function(){var N,D=p(H);if(b){var B=p(this).constructor;N=Reflect.construct(D,arguments,B)}else N=D.apply(this,arguments);return P(this,N)});function O(){return function(N,D){if(!(N instanceof D))throw new TypeError("Cannot call a class as a function")}(this,O),M.call(this,"[data-hs-tab]")}return I=O,(u=[{key:"init",value:function(){var N=this;document.addEventListener("keydown",this._keyboardSupport.bind(this)),document.addEventListener("click",function(D){var B=D.target.closest(N.selector);B&&N.open(B)}),document.querySelectorAll("[hs-data-tab-select]").forEach(function(D){var B=document.querySelector(D.getAttribute("hs-data-tab-select"));B&&B.addEventListener("change",function(h){var G=document.querySelector('[data-hs-tab="'.concat(h.target.value,'"]'));G&&N.open(G)})})}},{key:"open",value:function(N){var D=document.querySelector(N.getAttribute("data-hs-tab")),B=o(N.parentElement.children),h=o(D.parentElement.children),G=N.closest("[hs-data-tab-select]"),F=G?document.querySelector(G.getAttribute("data-hs-tab")):null;B.forEach(function(W){return W.classList.remove("active")}),h.forEach(function(W){return W.classList.add("hidden")}),N.classList.add("active"),D.classList.remove("hidden"),this._fireEvent("change",N),this._dispatch("change.hs.tab",N,N),F&&(F.value=N.getAttribute("data-hs-tab"))}},{key:"_keyboardSupport",value:function(N){var D=N.target.closest(this.selector);if(D){var 
B=D.closest('[role="tablist"]').getAttribute("data-hs-tabs-vertical")==="true";return(B?N.keyCode===38:N.keyCode===37)?(N.preventDefault(),this._left(D)):(B?N.keyCode===40:N.keyCode===39)?(N.preventDefault(),this._right(D)):N.keyCode===36?(N.preventDefault(),this._start(D)):N.keyCode===35?(N.preventDefault(),this._end(D)):void 0}}},{key:"_right",value:function(N){var D=N.closest('[role="tablist"]');if(D){var B=o(D.querySelectorAll(this.selector)).filter(function(F){return!F.disabled}),h=D.querySelector("button:focus"),G=B.findIndex(function(F){return F===h});G+1{var A=S(765),o=S(714);function i(L){return i=typeof Symbol=="function"&&typeof Symbol.iterator=="symbol"?function(I){return typeof I}:function(I){return I&&typeof Symbol=="function"&&I.constructor===Symbol&&I!==Symbol.prototype?"symbol":typeof I},i(L)}function _(L,I){for(var u=0;u"u"||!Reflect.construct||Reflect.construct.sham)return!1;if(typeof Proxy=="function")return!0;try{return Boolean.prototype.valueOf.call(Reflect.construct(Boolean,[],function(){})),!0}catch{return!1}}(),function(){var N,D=p(H);if(b){var B=p(this).constructor;N=Reflect.construct(D,arguments,B)}else N=D.apply(this,arguments);return P(this,N)});function O(){return function(N,D){if(!(N instanceof D))throw new TypeError("Cannot call a class as a function")}(this,O),M.call(this,".hs-tooltip")}return I=O,(u=[{key:"init",value:function(){var N=this;document.addEventListener("click",function(D){var B=D.target.closest(N.selector);B&&N.getClassProperty(B,"--trigger")==="focus"&&N._focus(B),B&&N.getClassProperty(B,"--trigger")==="click"&&N._click(B)}),document.addEventListener("mousemove",function(D){var B=D.target.closest(N.selector);B&&N.getClassProperty(B,"--trigger")!=="focus"&&N.getClassProperty(B,"--trigger")!=="click"&&N._hover(B)})}},{key:"_hover",value:function(N){var D=this;if(!N.classList.contains("show")){var 
B=N.querySelector(".hs-tooltip-toggle"),h=N.querySelector(".hs-tooltip-content"),G=this.getClassProperty(N,"--placement");(0,o.fi)(B,h,{placement:G||"top",strategy:"fixed",modifiers:[{name:"offset",options:{offset:[0,5]}}]}),this.show(N),N.addEventListener("mouseleave",function F(W){W.relatedTarget.closest(D.selector)&&W.relatedTarget.closest(D.selector)==N||(D.hide(N),N.removeEventListener("mouseleave",F,!0))},!0)}}},{key:"_focus",value:function(N){var D=this,B=N.querySelector(".hs-tooltip-toggle"),h=N.querySelector(".hs-tooltip-content"),G=this.getClassProperty(N,"--placement"),F=this.getClassProperty(N,"--strategy");(0,o.fi)(B,h,{placement:G||"top",strategy:F||"fixed",modifiers:[{name:"offset",options:{offset:[0,5]}}]}),this.show(N),N.addEventListener("blur",function W(){D.hide(N),N.removeEventListener("blur",W,!0)},!0)}},{key:"_click",value:function(N){var D=this;if(!N.classList.contains("show")){var B=N.querySelector(".hs-tooltip-toggle"),h=N.querySelector(".hs-tooltip-content"),G=this.getClassProperty(N,"--placement"),F=this.getClassProperty(N,"--strategy");(0,o.fi)(B,h,{placement:G||"top",strategy:F||"fixed",modifiers:[{name:"offset",options:{offset:[0,5]}}]}),this.show(N);var W=function x(J){setTimeout(function(){D.hide(N),N.removeEventListener("click",x,!0),N.removeEventListener("blur",x,!0)})};N.addEventListener("blur",W,!0),N.addEventListener("click",W,!0)}}},{key:"show",value:function(N){var D=this;N.querySelector(".hs-tooltip-content").classList.remove("hidden"),setTimeout(function(){N.classList.add("show"),D._fireEvent("show",N),D._dispatch("show.hs.tooltip",N,N)})}},{key:"hide",value:function(N){var D=N.querySelector(".hs-tooltip-content");N.classList.remove("show"),this._fireEvent("hide",N),this._dispatch("hide.hs.tooltip",N,N),this.afterTransition(D,function(){N.classList.contains("show")||D.classList.add("hidden")})}}])&&_(I.prototype,u),Object.defineProperty(I,"prototype",{writable:!1}),O}(A.Z);window.HSTooltip=new 
C,document.addEventListener("load",window.HSTooltip.init())},765:(n,s,S)=>{function A(i,_){for(var c=0;c<_.length;c++){var P=_[c];P.enumerable=P.enumerable||!1,P.configurable=!0,"value"in P&&(P.writable=!0),Object.defineProperty(i,P.key,P)}}S.d(s,{Z:()=>o});var o=function(){function i(P,p){(function(C,L){if(!(C instanceof L))throw new TypeError("Cannot call a class as a function")})(this,i),this.$collection=[],this.selector=P,this.config=p,this.events={}}var _,c;return _=i,c=[{key:"_fireEvent",value:function(P){var p=arguments.length>1&&arguments[1]!==void 0?arguments[1]:null;this.events.hasOwnProperty(P)&&this.events[P](p)}},{key:"_dispatch",value:function(P,p){var C=arguments.length>2&&arguments[2]!==void 0?arguments[2]:null,L=new CustomEvent(P,{detail:{payload:C},bubbles:!0,cancelable:!0,composed:!1});p.dispatchEvent(L)}},{key:"on",value:function(P,p){this.events[P]=p}},{key:"afterTransition",value:function(P,p){window.getComputedStyle(P,null).getPropertyValue("transition")!=="all 0s ease 0s"?P.addEventListener("transitionend",function C(){p(),P.removeEventListener("transitionend",C,!0)},!0):p()}},{key:"getClassProperty",value:function(P,p){var C=arguments.length>2&&arguments[2]!==void 0?arguments[2]:"",L=(window.getComputedStyle(P).getPropertyValue(p)||C).replace(" ","");return L}}],c&&A(_.prototype,c),Object.defineProperty(_,"prototype",{writable:!1}),i}()},714:(n,s,S)=>{function A(U){if(U==null)return window;if(U.toString()!=="[object Window]"){var d=U.ownerDocument;return d&&d.defaultView||window}return U}function o(U){return U instanceof A(U).Element||U instanceof Element}function i(U){return U instanceof A(U).HTMLElement||U instanceof HTMLElement}function _(U){return typeof ShadowRoot<"u"&&(U instanceof A(U).ShadowRoot||U instanceof ShadowRoot)}S.d(s,{fi:()=>tn});var c=Math.max,P=Math.min,p=Math.round;function C(U,d){d===void 0&&(d=!1);var g=U.getBoundingClientRect(),w=1,Z=1;if(i(U)&&d){var 
q=U.offsetHeight,Q=U.offsetWidth;Q>0&&(w=p(g.width)/Q||1),q>0&&(Z=p(g.height)/q||1)}return{width:g.width/w,height:g.height/Z,top:g.top/Z,right:g.right/w,bottom:g.bottom/Z,left:g.left/w,x:g.left/w,y:g.top/Z}}function L(U){var d=A(U);return{scrollLeft:d.pageXOffset,scrollTop:d.pageYOffset}}function I(U){return U?(U.nodeName||"").toLowerCase():null}function u(U){return((o(U)?U.ownerDocument:U.document)||window.document).documentElement}function H(U){return C(u(U)).left+L(U).scrollLeft}function b(U){return A(U).getComputedStyle(U)}function M(U){var d=b(U),g=d.overflow,w=d.overflowX,Z=d.overflowY;return/auto|scroll|overlay|hidden/.test(g+Z+w)}function O(U,d,g){g===void 0&&(g=!1);var w,Z,q=i(d),Q=i(d)&&function(re){var We=re.getBoundingClientRect(),ae=p(We.width)/re.offsetWidth||1,Me=p(We.height)/re.offsetHeight||1;return ae!==1||Me!==1}(d),ee=u(d),Ee=C(U,Q),ne={scrollLeft:0,scrollTop:0},Ae={x:0,y:0};return(q||!q&&!g)&&((I(d)!=="body"||M(ee))&&(ne=(w=d)!==A(w)&&i(w)?{scrollLeft:(Z=w).scrollLeft,scrollTop:Z.scrollTop}:L(w)),i(d)?((Ae=C(d,!0)).x+=d.clientLeft,Ae.y+=d.clientTop):ee&&(Ae.x=H(ee))),{x:Ee.left+ne.scrollLeft-Ae.x,y:Ee.top+ne.scrollTop-Ae.y,width:Ee.width,height:Ee.height}}function N(U){var d=C(U),g=U.offsetWidth,w=U.offsetHeight;return Math.abs(d.width-g)<=1&&(g=d.width),Math.abs(d.height-w)<=1&&(w=d.height),{x:U.offsetLeft,y:U.offsetTop,width:g,height:w}}function D(U){return I(U)==="html"?U:U.assignedSlot||U.parentNode||(_(U)?U.host:null)||u(U)}function B(U){return["html","body","#document"].indexOf(I(U))>=0?U.ownerDocument.body:i(U)&&M(U)?U:B(D(U))}function h(U,d){var g;d===void 0&&(d=[]);var w=B(U),Z=w===((g=U.ownerDocument)==null?void 0:g.body),q=A(w),Q=Z?[q].concat(q.visualViewport||[],M(w)?w:[]):w,ee=d.concat(Q);return Z?ee:ee.concat(h(D(Q)))}function G(U){return["table","td","th"].indexOf(I(U))>=0}function F(U){return i(U)&&b(U).position!=="fixed"?U.offsetParent:null}function W(U){for(var d=A(U),g=F(U);g&&G(g)&&b(g).position==="static";)g=F(g);return 
g&&(I(g)==="html"||I(g)==="body"&&b(g).position==="static")?d:g||function(w){var Z=navigator.userAgent.toLowerCase().indexOf("firefox")!==-1;if(navigator.userAgent.indexOf("Trident")!==-1&&i(w)&&b(w).position==="fixed")return null;for(var q=D(w);i(q)&&["html","body"].indexOf(I(q))<0;){var Q=b(q);if(Q.transform!=="none"||Q.perspective!=="none"||Q.contain==="paint"||["transform","perspective"].indexOf(Q.willChange)!==-1||Z&&Q.willChange==="filter"||Z&&Q.filter&&Q.filter!=="none")return q;q=q.parentNode}return null}(U)||d}var x="top",J="bottom",oe="right",z="left",Oe="auto",Ue=[x,J,oe,z],ye="start",TE="end",dE="viewport",me="popper",fE=Ue.reduce(function(U,d){return U.concat([d+"-"+ye,d+"-"+TE])},[]),rE=[].concat(Ue,[Oe]).reduce(function(U,d){return U.concat([d,d+"-"+ye,d+"-"+TE])},[]),pE=["beforeRead","read","afterRead","beforeMain","main","afterMain","beforeWrite","write","afterWrite"];function Tt(U){var d=new Map,g=new Set,w=[];function Z(q){g.add(q.name),[].concat(q.requires||[],q.requiresIfExists||[]).forEach(function(Q){if(!g.has(Q)){var ee=d.get(Q);ee&&Z(ee)}}),w.push(q)}return U.forEach(function(q){d.set(q.name,q)}),U.forEach(function(q){g.has(q.name)||Z(q)}),w}var rt={placement:"bottom",modifiers:[],strategy:"absolute"};function Ve(){for(var U=arguments.length,d=new Array(U),g=0;g=0?"x":"y"}function Te(U){var d,g=U.reference,w=U.element,Z=U.placement,q=Z?ze(Z):null,Q=Z?ve(Z):null,ee=g.x+g.width/2-w.width/2,Ee=g.y+g.height/2-w.height/2;switch(q){case x:d={x:ee,y:g.y-w.height};break;case J:d={x:ee,y:g.y+g.height};break;case oe:d={x:g.x+g.width,y:Ee};break;case z:d={x:g.x-w.width,y:Ee};break;default:d={x:g.x,y:g.y}}var ne=q?Fe(q):null;if(ne!=null){var Ae=ne==="y"?"height":"width";switch(Q){case ye:d[ne]=d[ne]-(g[Ae]/2-w[Ae]/2);break;case TE:d[ne]=d[ne]+(g[Ae]/2-w[Ae]/2)}}return d}var Xe={top:"auto",right:"auto",bottom:"auto",left:"auto"};function se(U){var 
d,g=U.popper,w=U.popperRect,Z=U.placement,q=U.variation,Q=U.offsets,ee=U.position,Ee=U.gpuAcceleration,ne=U.adaptive,Ae=U.roundOffsets,re=U.isFixed,We=Q.x,ae=We===void 0?0:We,Me=Q.y,de=Me===void 0?0:Me,ie=typeof Ae=="function"?Ae({x:ae,y:de}):{x:ae,y:de};ae=ie.x,de=ie.y;var we=Q.hasOwnProperty("x"),be=Q.hasOwnProperty("y"),ce=z,he=x,Ie=window;if(ne){var le=W(g),fe="clientHeight",pe="clientWidth";le===A(g)&&b(le=u(g)).position!=="static"&&ee==="absolute"&&(fe="scrollHeight",pe="scrollWidth"),le=le,(Z===x||(Z===z||Z===oe)&&q===TE)&&(he=J,de-=(re&&Ie.visualViewport?Ie.visualViewport.height:le[fe])-w.height,de*=Ee?1:-1),Z!==z&&(Z!==x&&Z!==J||q!==TE)||(ce=oe,ae-=(re&&Ie.visualViewport?Ie.visualViewport.width:le[pe])-w.width,ae*=Ee?1:-1)}var He,ke=Object.assign({position:ee},ne&&Xe),Ke=Ae===!0?function(sE){var lE=sE.x,PE=sE.y,SE=window.devicePixelRatio||1;return{x:p(lE*SE)/SE||0,y:p(PE*SE)/SE||0}}({x:ae,y:de}):{x:ae,y:de};return ae=Ke.x,de=Ke.y,Ee?Object.assign({},ke,((He={})[he]=be?"0":"",He[ce]=we?"0":"",He.transform=(Ie.devicePixelRatio||1)<=1?"translate("+ae+"px, "+de+"px)":"translate3d("+ae+"px, "+de+"px, 0)",He)):Object.assign({},ke,((d={})[he]=be?de+"px":"",d[ce]=we?ae+"px":"",d.transform="",d))}var gE={left:"right",right:"left",bottom:"top",top:"bottom"};function Nt(U){return U.replace(/left|right|bottom|top/g,function(d){return gE[d]})}var En={start:"end",end:"start"};function dT(U){return U.replace(/start|end/g,function(d){return En[d]})}function pT(U,d){var g=d.getRootNode&&d.getRootNode();if(U.contains(d))return!0;if(g&&_(g)){var w=d;do{if(w&&U.isSameNode(w))return!0;w=w.parentNode||w.host}while(w)}return!1}function Bt(U){return Object.assign({},U,{left:U.x,top:U.y,right:U.x+U.width,bottom:U.y+U.height})}function MT(U,d){return d===dE?Bt(function(g){var w=A(g),Z=u(g),q=w.visualViewport,Q=Z.clientWidth,ee=Z.clientHeight,Ee=0,ne=0;return 
q&&(Q=q.width,ee=q.height,/^((?!chrome|android).)*safari/i.test(navigator.userAgent)||(Ee=q.offsetLeft,ne=q.offsetTop)),{width:Q,height:ee,x:Ee+H(g),y:ne}}(U)):o(d)?function(g){var w=C(g);return w.top=w.top+g.clientTop,w.left=w.left+g.clientLeft,w.bottom=w.top+g.clientHeight,w.right=w.left+g.clientWidth,w.width=g.clientWidth,w.height=g.clientHeight,w.x=w.left,w.y=w.top,w}(d):Bt(function(g){var w,Z=u(g),q=L(g),Q=(w=g.ownerDocument)==null?void 0:w.body,ee=c(Z.scrollWidth,Z.clientWidth,Q?Q.scrollWidth:0,Q?Q.clientWidth:0),Ee=c(Z.scrollHeight,Z.clientHeight,Q?Q.scrollHeight:0,Q?Q.clientHeight:0),ne=-q.scrollLeft+H(g),Ae=-q.scrollTop;return b(Q||Z).direction==="rtl"&&(ne+=c(Z.clientWidth,Q?Q.clientWidth:0)-ee),{width:ee,height:Ee,x:ne,y:Ae}}(u(U)))}function UT(U){return Object.assign({},{top:0,right:0,bottom:0,left:0},U)}function mT(U,d){return d.reduce(function(g,w){return g[w]=U,g},{})}function Rt(U,d){d===void 0&&(d={});var g=d,w=g.placement,Z=w===void 0?U.placement:w,q=g.boundary,Q=q===void 0?"clippingParents":q,ee=g.rootBoundary,Ee=ee===void 0?dE:ee,ne=g.elementContext,Ae=ne===void 0?me:ne,re=g.altBoundary,We=re!==void 0&&re,ae=g.padding,Me=ae===void 0?0:ae,de=UT(typeof Me!="number"?Me:mT(Me,Ue)),ie=Ae===me?"reference":me,we=U.rects.popper,be=U.elements[We?ie:Ae],ce=function(Ke,sE,lE){var PE=sE==="clippingParents"?function(Be){var UE=h(D(Be)),oE=["absolute","fixed"].indexOf(b(Be).position)>=0&&i(Be)?W(Be):Be;return o(oE)?UE.filter(function(RE){return o(RE)&&pT(RE,oE)&&I(RE)!=="body"}):[]}(Ke):[].concat(sE),SE=[].concat(PE,[lE]),tE=SE[0],xe=SE.reduce(function(Be,UE){var oE=MT(Ke,UE);return Be.top=c(oE.top,Be.top),Be.right=P(oE.right,Be.right),Be.bottom=P(oE.bottom,Be.bottom),Be.left=c(oE.left,Be.left),Be},MT(Ke,tE));return 
xe.width=xe.right-xe.left,xe.height=xe.bottom-xe.top,xe.x=xe.left,xe.y=xe.top,xe}(o(be)?be:be.contextElement||u(U.elements.popper),Q,Ee),he=C(U.elements.reference),Ie=Te({reference:he,element:we,strategy:"absolute",placement:Z}),le=Bt(Object.assign({},we,Ie)),fe=Ae===me?le:he,pe={top:ce.top-fe.top+de.top,bottom:fe.bottom-ce.bottom+de.bottom,left:ce.left-fe.left+de.left,right:fe.right-ce.right+de.right},He=U.modifiersData.offset;if(Ae===me&&He){var ke=He[Z];Object.keys(pe).forEach(function(Ke){var sE=[oe,J].indexOf(Ke)>=0?1:-1,lE=[x,J].indexOf(Ke)>=0?"y":"x";pe[Ke]+=ke[lE]*sE})}return pe}function nt(U,d,g){return c(U,P(d,g))}function hT(U,d,g){return g===void 0&&(g={x:0,y:0}),{top:U.top-d.height-g.y,right:U.right-d.width+g.x,bottom:U.bottom-d.height+g.y,left:U.left-d.width-g.x}}function GT(U){return[x,oe,J,z].some(function(d){return U[d]>=0})}var tn=ME({defaultModifiers:[{name:"eventListeners",enabled:!0,phase:"write",fn:function(){},effect:function(U){var d=U.state,g=U.instance,w=U.options,Z=w.scroll,q=Z===void 0||Z,Q=w.resize,ee=Q===void 0||Q,Ee=A(d.elements.popper),ne=[].concat(d.scrollParents.reference,d.scrollParents.popper);return q&&ne.forEach(function(Ae){Ae.addEventListener("scroll",g.update,$e)}),ee&&Ee.addEventListener("resize",g.update,$e),function(){q&&ne.forEach(function(Ae){Ae.removeEventListener("scroll",g.update,$e)}),ee&&Ee.removeEventListener("resize",g.update,$e)}},data:{}},{name:"popperOffsets",enabled:!0,phase:"read",fn:function(U){var d=U.state,g=U.name;d.modifiersData[g]=Te({reference:d.rects.reference,element:d.rects.popper,strategy:"absolute",placement:d.placement})},data:{}},{name:"computeStyles",enabled:!0,phase:"beforeWrite",fn:function(U){var d=U.state,g=U.options,w=g.gpuAcceleration,Z=w===void 0||w,q=g.adaptive,Q=q===void 0||q,ee=g.roundOffsets,Ee=ee===void 
0||ee,ne={placement:ze(d.placement),variation:ve(d.placement),popper:d.elements.popper,popperRect:d.rects.popper,gpuAcceleration:Z,isFixed:d.options.strategy==="fixed"};d.modifiersData.popperOffsets!=null&&(d.styles.popper=Object.assign({},d.styles.popper,se(Object.assign({},ne,{offsets:d.modifiersData.popperOffsets,position:d.options.strategy,adaptive:Q,roundOffsets:Ee})))),d.modifiersData.arrow!=null&&(d.styles.arrow=Object.assign({},d.styles.arrow,se(Object.assign({},ne,{offsets:d.modifiersData.arrow,position:"absolute",adaptive:!1,roundOffsets:Ee})))),d.attributes.popper=Object.assign({},d.attributes.popper,{"data-popper-placement":d.placement})},data:{}},{name:"applyStyles",enabled:!0,phase:"write",fn:function(U){var d=U.state;Object.keys(d.elements).forEach(function(g){var w=d.styles[g]||{},Z=d.attributes[g]||{},q=d.elements[g];i(q)&&I(q)&&(Object.assign(q.style,w),Object.keys(Z).forEach(function(Q){var ee=Z[Q];ee===!1?q.removeAttribute(Q):q.setAttribute(Q,ee===!0?"":ee)}))})},effect:function(U){var d=U.state,g={popper:{position:d.options.strategy,left:"0",top:"0",margin:"0"},arrow:{position:"absolute"},reference:{}};return Object.assign(d.elements.popper.style,g.popper),d.styles=g,d.elements.arrow&&Object.assign(d.elements.arrow.style,g.arrow),function(){Object.keys(d.elements).forEach(function(w){var Z=d.elements[w],q=d.attributes[w]||{},Q=Object.keys(d.styles.hasOwnProperty(w)?d.styles[w]:g[w]).reduce(function(ee,Ee){return ee[Ee]="",ee},{});i(Z)&&I(Z)&&(Object.assign(Z.style,Q),Object.keys(q).forEach(function(ee){Z.removeAttribute(ee)}))})}},requires:["computeStyles"]},{name:"offset",enabled:!0,phase:"main",requires:["popperOffsets"],fn:function(U){var d=U.state,g=U.options,w=U.name,Z=g.offset,q=Z===void 0?[0,0]:Z,Q=rE.reduce(function(Ae,re){return Ae[re]=function(We,ae,Me){var de=ze(We),ie=[z,x].indexOf(de)>=0?-1:1,we=typeof Me=="function"?Me(Object.assign({},ae,{placement:We})):Me,be=we[0],ce=we[1];return 
be=be||0,ce=(ce||0)*ie,[z,oe].indexOf(de)>=0?{x:ce,y:be}:{x:be,y:ce}}(re,d.rects,q),Ae},{}),ee=Q[d.placement],Ee=ee.x,ne=ee.y;d.modifiersData.popperOffsets!=null&&(d.modifiersData.popperOffsets.x+=Ee,d.modifiersData.popperOffsets.y+=ne),d.modifiersData[w]=Q}},{name:"flip",enabled:!0,phase:"main",fn:function(U){var d=U.state,g=U.options,w=U.name;if(!d.modifiersData[w]._skip){for(var Z=g.mainAxis,q=Z===void 0||Z,Q=g.altAxis,ee=Q===void 0||Q,Ee=g.fallbackPlacements,ne=g.padding,Ae=g.boundary,re=g.rootBoundary,We=g.altBoundary,ae=g.flipVariations,Me=ae===void 0||ae,de=g.allowedAutoPlacements,ie=d.options.placement,we=ze(ie),be=Ee||(we!==ie&&Me?function(RE){if(ze(RE)===Oe)return[];var _E=Nt(RE);return[dT(RE),_E,dT(_E)]}(ie):[Nt(ie)]),ce=[ie].concat(be).reduce(function(RE,_E){return RE.concat(ze(_E)===Oe?function($E,mE){mE===void 0&&(mE={});var LE=mE,lt=LE.placement,_t=LE.boundary,xE=LE.rootBoundary,vt=LE.padding,Ft=LE.flipVariations,XE=LE.allowedAutoPlacements,Yt=XE===void 0?rE:XE,At=ve(lt),Lt=At?Ft?fE:fE.filter(function(cE){return ve(cE)===At}):Ue,kE=Lt.filter(function(cE){return Yt.indexOf(cE)>=0});kE.length===0&&(kE=Lt);var KE=kE.reduce(function(cE,HE){return cE[HE]=Rt($E,{placement:HE,boundary:_t,rootBoundary:xE,padding:vt})[ze(HE)],cE},{});return Object.keys(KE).sort(function(cE,HE){return KE[cE]-KE[HE]})}(d,{placement:_E,boundary:Ae,rootBoundary:re,padding:ne,flipVariations:Me,allowedAutoPlacements:de}):_E)},[]),he=d.rects.reference,Ie=d.rects.popper,le=new Map,fe=!0,pe=ce[0],He=0;He=0,PE=lE?"width":"height",SE=Rt(d,{placement:ke,boundary:Ae,rootBoundary:re,altBoundary:We,padding:ne}),tE=lE?sE?oe:z:sE?J:x;he[PE]>Ie[PE]&&(tE=Nt(tE));var xe=Nt(tE),Be=[];if(q&&Be.push(SE[Ke]<=0),ee&&Be.push(SE[tE]<=0,SE[xe]<=0),Be.every(function(RE){return RE})){pe=ke,fe=!1;break}le.set(ke,Be)}if(fe)for(var UE=function(RE){var _E=ce.find(function($E){var mE=le.get($E);if(mE)return mE.slice(0,RE).every(function(LE){return LE})});if(_E)return 
pe=_E,"break"},oE=Me?3:1;oE>0&&UE(oE)!=="break";oE--);d.placement!==pe&&(d.modifiersData[w]._skip=!0,d.placement=pe,d.reset=!0)}},requiresIfExists:["offset"],data:{_skip:!1}},{name:"preventOverflow",enabled:!0,phase:"main",fn:function(U){var d=U.state,g=U.options,w=U.name,Z=g.mainAxis,q=Z===void 0||Z,Q=g.altAxis,ee=Q!==void 0&&Q,Ee=g.boundary,ne=g.rootBoundary,Ae=g.altBoundary,re=g.padding,We=g.tether,ae=We===void 0||We,Me=g.tetherOffset,de=Me===void 0?0:Me,ie=Rt(d,{boundary:Ee,rootBoundary:ne,padding:re,altBoundary:Ae}),we=ze(d.placement),be=ve(d.placement),ce=!be,he=Fe(we),Ie=he==="x"?"y":"x",le=d.modifiersData.popperOffsets,fe=d.rects.reference,pe=d.rects.popper,He=typeof de=="function"?de(Object.assign({},d.rects,{placement:d.placement})):de,ke=typeof He=="number"?{mainAxis:He,altAxis:He}:Object.assign({mainAxis:0,altAxis:0},He),Ke=d.modifiersData.offset?d.modifiersData.offset[d.placement]:null,sE={x:0,y:0};if(le){if(q){var lE,PE=he==="y"?x:z,SE=he==="y"?J:oe,tE=he==="y"?"height":"width",xe=le[he],Be=xe+ie[PE],UE=xe-ie[SE],oE=ae?-pe[tE]/2:0,RE=be===ye?fe[tE]:pe[tE],_E=be===ye?-pe[tE]:-fe[tE],$E=d.elements.arrow,mE=ae&&$E?N($E):{width:0,height:0},LE=d.modifiersData["arrow#persistent"]?d.modifiersData["arrow#persistent"].padding:{top:0,right:0,bottom:0,left:0},lt=LE[PE],_t=LE[SE],xE=nt(0,fe[tE],mE[tE]),vt=ce?fe[tE]/2-oE-xE-lt-ke.mainAxis:RE-xE-lt-ke.mainAxis,Ft=ce?-fe[tE]/2+oE+xE+_t+ke.mainAxis:_E+xE+_t+ke.mainAxis,XE=d.elements.arrow&&W(d.elements.arrow),Yt=XE?he==="y"?XE.clientTop||0:XE.clientLeft||0:0,At=(lE=Ke==null?void 0:Ke[he])!=null?lE:0,Lt=xe+Ft-At,kE=nt(ae?P(Be,xe+vt-At-Yt):Be,xe,ae?c(UE,Lt):UE);le[he]=kE,sE[he]=kE-xe}if(ee){var KE,cE=he==="x"?x:z,HE=he==="x"?J:oe,bE=le[Ie],Ct=Ie==="y"?"height":"width",gT=bE+ie[cE],HT=bE-ie[HE],Vt=[x,z].indexOf(we)!==-1,bT=(KE=Ke==null?void 0:Ke[Ie])!=null?KE:0,yT=Vt?gT:bE-fe[Ct]-pe[Ct]-bT+ke.altAxis,BT=Vt?bE+fe[Ct]+pe[Ct]-bT-ke.altAxis:HT,vT=ae&&Vt?function(Tn,rn,Wt){var FT=nt(Tn,rn,Wt);return 
FT>Wt?Wt:FT}(yT,bE,BT):nt(ae?yT:gT,bE,ae?BT:HT);le[Ie]=vT,sE[Ie]=vT-bE}d.modifiersData[w]=sE}},requiresIfExists:["offset"]},{name:"arrow",enabled:!0,phase:"main",fn:function(U){var d,g=U.state,w=U.name,Z=U.options,q=g.elements.arrow,Q=g.modifiersData.popperOffsets,ee=ze(g.placement),Ee=Fe(ee),ne=[z,oe].indexOf(ee)>=0?"height":"width";if(q&&Q){var Ae=function(pe,He){return UT(typeof(pe=typeof pe=="function"?pe(Object.assign({},He.rects,{placement:He.placement})):pe)!="number"?pe:mT(pe,Ue))}(Z.padding,g),re=N(q),We=Ee==="y"?x:z,ae=Ee==="y"?J:oe,Me=g.rects.reference[ne]+g.rects.reference[Ee]-Q[Ee]-g.rects.popper[ne],de=Q[Ee]-g.rects.reference[Ee],ie=W(q),we=ie?Ee==="y"?ie.clientHeight||0:ie.clientWidth||0:0,be=Me/2-de/2,ce=Ae[We],he=we-re[ne]-Ae[ae],Ie=we/2-re[ne]/2+be,le=nt(ce,Ie,he),fe=Ee;g.modifiersData[w]=((d={})[fe]=le,d.centerOffset=le-Ie,d)}},effect:function(U){var d=U.state,g=U.options.element,w=g===void 0?"[data-popper-arrow]":g;w!=null&&(typeof w!="string"||(w=d.elements.popper.querySelector(w)))&&pT(d.elements.popper,w)&&(d.elements.arrow=w)},requires:["popperOffsets"],requiresIfExists:["preventOverflow"]},{name:"hide",enabled:!0,phase:"main",requiresIfExists:["preventOverflow"],fn:function(U){var d=U.state,g=U.name,w=d.rects.reference,Z=d.rects.popper,q=d.modifiersData.preventOverflow,Q=Rt(d,{elementContext:"reference"}),ee=Rt(d,{altBoundary:!0}),Ee=hT(Q,w),ne=hT(ee,Z,q),Ae=GT(Ee),re=GT(ne);d.modifiersData[g]={referenceClippingOffsets:Ee,popperEscapeOffsets:ne,isReferenceHidden:Ae,hasPopperEscaped:re},d.attributes.popper=Object.assign({},d.attributes.popper,{"data-popper-reference-hidden":Ae,"data-popper-escaped":re})}}]})}},t={};function r(n){var s=t[n];if(s!==void 0)return s.exports;var S=t[n]={exports:{}};return T[n](S,S.exports,r),S.exports}r.d=(n,s)=>{for(var S in s)r.o(s,S)&&!r.o(n,S)&&Object.defineProperty(n,S,{enumerable:!0,get:s[S]})},r.o=(n,s)=>Object.prototype.hasOwnProperty.call(n,s),r.r=n=>{typeof 
Symbol<"u"&&Symbol.toStringTag&&Object.defineProperty(n,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(n,"__esModule",{value:!0})};var R={};return r.r(R),r(661),r(795),r(682),r(284),r(181),r(778),r(51),r(185),R})()})})(Qn);function Zn(E){let e=E[0].title+"",T;return{c(){T=te(e)},m(t,r){V(t,T,r)},p(t,r){r&1&&e!==(e=t[0].title+"")&&Le(T,e)},d(t){t&&Y(T)}}}function jn(E){let e,T;return{c(){e=te("Welcome to "),T=f("span"),T.textContent="Vanna.AI",a(T,"class","nav-title")},m(t,r){V(t,e,r),V(t,T,r)},p:j,d(t){t&&(Y(e),Y(T))}}}function zn(E){let e,T,t,r,R=E[0].subtitle+"",n;function s(o,i){return o[0].title=="Welcome to Vanna.AI"?jn:Zn}let S=s(E),A=S(E);return{c(){e=f("div"),T=f("h1"),A.c(),t=$(),r=f("p"),n=te(R),a(T,"class","text-3xl font-bold text-gray-800 sm:text-4xl dark:text-white"),a(r,"class","mt-3 text-gray-600 dark:text-gray-400"),a(e,"class","max-w-4xl px-4 sm:px-6 lg:px-8 mx-auto text-center")},m(o,i){V(o,e,i),l(e,T),A.m(T,null),l(e,t),l(e,r),l(r,n)},p(o,[i]){S===(S=s(o))&&A?A.p(o,i):(A.d(1),A=S(o),A&&(A.c(),A.m(T,null))),i&1&&R!==(R=o[0].subtitle+"")&&Le(n,R)},i:j,o:j,d(o){o&&Y(e),A.d()}}}function eA(E,e,T){let t;return eE(E,VE,r=>T(0,t=r)),[t]}class EA extends ue{constructor(e){super(),Ce(this,e,eA,zn,_e,{})}}function tA(E){let e,T;const t=E[1].default,r=Ut(t,E,E[0],null);return{c(){e=f("p"),r&&r.c(),a(e,"class","text-gray-800 dark:text-gray-200")},m(R,n){V(R,e,n),r&&r.m(e,null),T=!0},p(R,[n]){r&&r.p&&(!T||n&1)&&ht(r,t,R,R[0],T?mt(t,R[0],n,null):Gt(R[0]),null)},i(R){T||(m(r,R),T=!0)},o(R){y(r,R),T=!1},d(R){R&&Y(e),r&&r.d(R)}}}function TA(E,e,T){let{$$slots:t={},$$scope:r}=e;return E.$$set=R=>{"$$scope"in R&&T(0,r=R.$$scope)},[r,t]}class aE extends ue{constructor(e){super(),Ce(this,e,TA,tA,_e,{})}}function rA(E){let e;return{c(){e=te(E[0])},m(T,t){V(T,e,t)},p(T,t){t&1&&Le(e,T[0])},d(T){T&&Y(e)}}}function RA(E){let e,T,t,r,R,n,s,S,A;s=new aE({props:{$$slots:{default:[rA]},$$scope:{ctx:E}}});const 
o=E[1].default,i=Ut(o,E,E[2],null);return{c(){e=f("li"),T=f("div"),t=f("div"),r=f("span"),r.innerHTML='You',R=$(),n=f("div"),K(s.$$.fragment),S=$(),i&&i.c(),a(r,"class","flex-shrink-0 inline-flex items-center justify-center h-[2.375rem] w-[2.375rem] rounded-full bg-gray-600"),a(n,"class","grow mt-2 space-y-3"),a(t,"class","max-w-2xl flex gap-x-2 sm:gap-x-4"),a(T,"class","max-w-4xl px-4 sm:px-6 lg:px-8 mx-auto"),a(e,"class","py-2 sm:py-4")},m(_,c){V(_,e,c),l(e,T),l(T,t),l(t,r),l(t,R),l(t,n),X(s,n,null),l(n,S),i&&i.m(n,null),A=!0},p(_,[c]){const P={};c&5&&(P.$$scope={dirty:c,ctx:_}),s.$set(P),i&&i.p&&(!A||c&4)&&ht(i,o,_,_[2],A?mt(o,_[2],c,null):Gt(_[2]),null)},i(_){A||(m(s.$$.fragment,_),m(i,_),A=!0)},o(_){y(s.$$.fragment,_),y(i,_),A=!1},d(_){_&&Y(e),k(s),i&&i.d(_)}}}function nA(E,e,T){let{$$slots:t={},$$scope:r}=e,{message:R}=e;return E.$$set=n=>{"message"in n&&T(0,R=n.message),"$$scope"in n&&T(2,r=n.$$scope)},[R,t,r]}class WE extends ue{constructor(e){super(),Ce(this,e,nA,RA,_e,{message:0})}}function AA(E){let e,T,t;return{c(){e=f("button"),e.innerHTML='',a(e,"type","button"),a(e,"class","inline-flex flex-shrink-0 justify-center items-center size-8 rounded-lg text-gray-500 hover:text-blue-600 focus:z-10 focus:outline-none focus:ring-2 focus:ring-blue-500 dark:hover:text-blue-500 dark:focus:outline-none dark:focus:ring-1 dark:focus:ring-gray-600")},m(r,R){V(r,e,R),T||(t=Ne(e,"click",E[1]),T=!0)},p:j,d(r){r&&Y(e),T=!1,t()}}}function sA(E){let e;return{c(){e=f("button"),e.innerHTML='',a(e,"type","button"),a(e,"class","animate-ping animate-pulse inline-flex flex-shrink-0 justify-center items-center size-8 rounded-lg text-red-500 hover:text-red-600 focus:z-10 focus:outline-none focus:ring-2 focus:ring-red-500 dark:hover:text-red-500 dark:focus:outline-none dark:focus:ring-1 dark:focus:ring-red-600")},m(T,t){V(T,e,t)},p:j,d(T){T&&Y(e)}}}function SA(E){let e;function T(R,n){return R[0]?sA:AA}let 
t=T(E),r=t(E);return{c(){r.c(),e=je()},m(R,n){r.m(R,n),V(R,e,n)},p(R,[n]){t===(t=T(R))&&r?r.p(R,n):(r.d(1),r=t(R),r&&(r.c(),r.m(e.parentNode,e)))},i:j,o:j,d(R){R&&Y(e),r.d(R)}}}function oA(E,e,T){let{newMessage:t}=e,r=!1;function R(){if(T(0,r=!0),pR.set(!0),"webkitSpeechRecognition"in window)var n=new window.webkitSpeechRecognition;else var n=new window.SpeechRecognition;n.lang="en-US",n.start(),n.onresult=s=>{const S=s.results[0][0].transcript;console.log(S),T(2,t=S),T(0,r=!1)},n.onend=()=>{T(0,r=!1)},n.onerror=()=>{T(0,r=!1)}}return E.$$set=n=>{"newMessage"in n&&T(2,t=n.newMessage)},[r,R,t]}class OA extends ue{constructor(e){super(),Ce(this,e,oA,SA,_e,{newMessage:2})}}function iA(E){let e,T,t,r,R,n,s,S,A,o,i,_,c,P,p;function C(I){E[5](I)}let L={};return E[0]!==void 0&&(L.newMessage=E[0]),A=new OA({props:L}),iT.push(()=>cn(A,"newMessage",C)),{c(){e=f("div"),T=f("input"),t=$(),r=f("div"),R=f("div"),n=f("div"),n.innerHTML="",s=$(),S=f("div"),K(A.$$.fragment),i=$(),_=f("button"),_.innerHTML='',a(T,"type","text"),a(T,"class","p-4 pb-12 block w-full bg-gray-100 border-gray-200 rounded-md text-sm focus:border-blue-500 focus:ring-blue-500 dark:bg-slate-800 dark:border-gray-700 dark:text-gray-400"),a(T,"placeholder","Ask me a question about your data that I can turn into SQL."),a(n,"class","flex items-center"),a(_,"type","button"),a(_,"class","inline-flex flex-shrink-0 justify-center items-center h-8 w-8 rounded-md text-white bg-blue-600 hover:bg-blue-500 focus:z-10 focus:outline-none focus:ring-2 focus:ring-blue-500 transition-all"),a(S,"class","flex items-center gap-x-1"),a(R,"class","flex justify-between items-center"),a(r,"class","absolute bottom-px inset-x-px p-2 rounded-b-md bg-gray-100 dark:bg-slate-800"),a(e,"class","relative")},m(I,u){V(I,e,u),l(e,T),Ye(T,E[0]),l(e,t),l(e,r),l(r,R),l(R,n),l(R,s),l(R,S),X(A,S,null),l(S,i),l(S,_),c=!0,P||(p=[Ne(T,"input",E[4]),Ne(T,"keydown",E[1]),Ne(_,"click",E[2])],P=!0)},p(I,[u]){u&1&&T.value!==I[0]&&Ye(T,I[0]);const 
H={};!o&&u&1&&(o=!0,H.newMessage=I[0],ln(()=>o=!1)),A.$set(H)},i(I){c||(m(A.$$.fragment,I),c=!0)},o(I){y(A.$$.fragment,I),c=!1},d(I){I&&Y(e),k(A),P=!1,NE(p)}}}function aA(E,e,T){let t;eE(E,BE,A=>T(0,t=A));let{onSubmit:r}=e;function R(A){A.key==="Enter"&&(r(t),OT(BE,t="",t),A.preventDefault())}function n(){r(t),OT(BE,t="",t)}function s(){t=this.value,BE.set(t)}function S(A){t=A,BE.set(t)}return E.$$set=A=>{"onSubmit"in A&&T(3,r=A.onSubmit)},[t,R,n,r,s,S]}class IA extends ue{constructor(e){super(),Ce(this,e,aA,iA,_e,{onSubmit:3})}}function NA(E){let e;return{c(){e=f("div"),e.innerHTML='',a(e,"class","lg:hidden flex justify-end mb-2 sm:mb-3")},m(T,t){V(T,e,t)},p:j,i:j,o:j,d(T){T&&Y(e)}}}class lA extends ue{constructor(e){super(),Ce(this,e,null,NA,_e,{})}}function _A(E){let e,T,t,r;return{c(){e=f("button"),T=te(E[0]),a(e,"type","button"),a(e,"class","mb-2.5 mr-1.5 py-2 px-3 inline-flex justify-center items-center gap-x-2 rounded-md border border-blue-600 bg-white text-blue-600 align-middle hover:bg-blue-50 text-sm dark:bg-slate-900 dark:text-blue-500 dark:border-blue-500 dark:hover:text-blue-400 dark:hover:border-blue-400")},m(R,n){V(R,e,n),l(e,T),t||(r=Ne(e,"click",E[1]),t=!0)},p(R,[n]){n&1&&Le(T,R[0])},i:j,o:j,d(R){R&&Y(e),t=!1,r()}}}function LA(E,e,T){let{message:t}=e,{onSubmit:r}=e;function R(){r(t)}return E.$$set=n=>{"message"in n&&T(0,t=n.message),"onSubmit"in n&&T(2,r=n.onSubmit)},[t,R,r]}class IE extends ue{constructor(e){super(),Ce(this,e,LA,_A,_e,{message:0,onSubmit:2})}}function CA(E){let e,T,t,r,R,n,s,S,A,o,i;return{c(){e=f("span"),T=OE("svg"),t=OE("defs"),r=OE("linearGradient"),R=OE("stop"),n=OE("stop"),s=OE("g"),S=OE("g"),A=OE("path"),o=OE("path"),a(R,"offset","0"),a(R,"stop-color","#009efd"),a(n,"offset","1"),a(n,"stop-color","#2af598"),a(r,"gradientTransform","matrix(1.09331 0 0 1.09331 -47.1838 
-88.8946)"),a(r,"gradientUnits","userSpaceOnUse"),a(r,"id","LinearGradient"),a(r,"x1","237.82"),a(r,"x2","785.097"),a(r,"y1","549.609"),a(r,"y2","549.609"),a(A,"d","M117.718 228.798C117.718 119.455 206.358 30.8151 315.701 30.8151L708.299 30.8151C817.642 30.8151 906.282 119.455 906.282 228.798L906.282 795.202C906.282 904.545 817.642 993.185 708.299 993.185L315.701 993.185C206.358 993.185 117.718 904.545 117.718 795.202L117.718 228.798Z"),a(A,"fill","#0f172a"),a(A,"fill-rule","nonzero"),a(A,"opacity","1"),a(A,"stroke","#374151"),a(A,"stroke-linecap","butt"),a(A,"stroke-linejoin","round"),a(A,"stroke-width","20"),a(o,"d","M212.828 215.239C213.095 281.169 213.629 413.028 213.629 413.028C213.629 413.028 511.51 808.257 513.993 809.681C612.915 677.809 810.759 414.065 810.759 414.065C810.759 414.065 811.034 280.901 811.172 214.319C662.105 362.973 662.105 362.973 513.038 511.627C362.933 363.433 362.933 363.433 212.828 215.239Z"),a(o,"fill","url(#LinearGradient)"),a(o,"fill-rule","nonzero"),a(o,"opacity","1"),a(o,"stroke","none"),a(S,"opacity","1"),a(s,"id","Layer-1"),a(T,"height","100%"),a(T,"stroke-miterlimit","10"),ct(T,"fill-rule","nonzero"),ct(T,"clip-rule","evenodd"),ct(T,"stroke-linecap","round"),ct(T,"stroke-linejoin","round"),a(T,"version","1.1"),a(T,"viewBox","0 0 1024 1024"),a(T,"width","100%"),a(T,"xml:space","preserve"),a(T,"xmlns","http://www.w3.org/2000/svg"),a(e,"class",i="flex-shrink-0 w-[2.375rem] h-[2.375rem] "+E[0])},m(_,c){V(_,e,c),l(e,T),l(T,t),l(t,r),l(r,R),l(r,n),l(T,s),l(s,S),l(S,A),l(S,o)},p(_,[c]){c&1&&i!==(i="flex-shrink-0 w-[2.375rem] h-[2.375rem] "+_[0])&&a(e,"class",i)},i:j,o:j,d(_){_&&Y(e)}}}function uA(E,e,T){let t,{animate:r=!1}=e;return E.$$set=R=>{"animate"in R&&T(1,r=R.animate)},E.$$.update=()=>{E.$$.dirty&2&&T(0,t=r?"animate-bounce":"")},[t,r]}class GR extends ue{constructor(e){super(),Ce(this,e,uA,CA,_e,{animate:1})}}function cA(E){let e,T,t,r,R;T=new GR({});const 
n=E[1].default,s=Ut(n,E,E[0],null);return{c(){e=f("li"),K(T.$$.fragment),t=$(),r=f("div"),s&&s.c(),a(r,"class","space-y-3 overflow-x-auto overflow-y-hidden whitespace-break-spaces w-full"),a(e,"class","max-w-4xl py-2 px-4 sm:px-6 lg:px-8 mx-auto flex gap-x-2 sm:gap-x-4")},m(S,A){V(S,e,A),X(T,e,null),l(e,t),l(e,r),s&&s.m(r,null),R=!0},p(S,[A]){s&&s.p&&(!R||A&1)&&ht(s,n,S,S[0],R?mt(n,S[0],A,null):Gt(S[0]),null)},i(S){R||(m(T.$$.fragment,S),m(s,S),R=!0)},o(S){y(T.$$.fragment,S),y(s,S),R=!1},d(S){S&&Y(e),k(T),s&&s.d(S)}}}function fA(E,e,T){let{$$slots:t={},$$scope:r}=e;return E.$$set=R=>{"$$scope"in R&&T(0,r=R.$$scope)},[r,t]}class Ze extends ue{constructor(e){super(),Ce(this,e,fA,cA,_e,{})}}function PA(E){let e;return{c(){e=te("Thinking...")},m(T,t){V(T,e,t)},d(T){T&&Y(e)}}}function DA(E){let e,T,t,r,R,n;return T=new GR({props:{animate:!0}}),R=new aE({props:{$$slots:{default:[PA]},$$scope:{ctx:E}}}),{c(){e=f("li"),K(T.$$.fragment),t=$(),r=f("div"),K(R.$$.fragment),a(r,"class","space-y-3"),a(e,"class","max-w-4xl py-2 px-4 sm:px-6 lg:px-8 mx-auto flex gap-x-2 sm:gap-x-4")},m(s,S){V(s,e,S),X(T,e,null),l(e,t),l(e,r),X(R,r,null),n=!0},p(s,[S]){const A={};S&1&&(A.$$scope={dirty:S,ctx:s}),R.$set(A)},i(s){n||(m(T.$$.fragment,s),m(R.$$.fragment,s),n=!0)},o(s){y(T.$$.fragment,s),y(R.$$.fragment,s),n=!1},d(s){s&&Y(e),k(T),k(R)}}}class dA extends ue{constructor(e){super(),Ce(this,e,null,DA,_e,{})}}function pA(E){let e,T,t,r,R,n,s,S,A,o,i;return{c(){e=f("ul"),T=f("li"),t=f("div"),r=f("span"),r.textContent="CSV",R=$(),n=f("a"),s=OE("svg"),S=OE("path"),A=OE("path"),o=te(`
Download`),a(r,"class","mr-3 flex-1 w-0 truncate"),a(S,"d","M.5 9.9a.5.5 0 0 1 .5.5v2.5a1 1 0 0 0 1 1h12a1 1 0 0 0 1-1v-2.5a.5.5 0 0 1 1 0v2.5a2 2 0 0 1-2 2H2a2 2 0 0 1-2-2v-2.5a.5.5 0 0 1 .5-.5z"),a(A,"d","M7.646 11.854a.5.5 0 0 0 .708 0l3-3a.5.5 0 0 0-.708-.708L8.5 10.293V1.5a.5.5 0 0 0-1 0v8.793L5.354 8.146a.5.5 0 1 0-.708.708l3 3z"),a(s,"class","flex-shrink-0 w-3 h-3"),a(s,"width","16"),a(s,"height","16"),a(s,"viewBox","0 0 16 16"),a(s,"fill","currentColor"),a(n,"class","flex items-center gap-x-2 text-gray-500 hover:text-blue-500 whitespace-nowrap"),a(n,"href",i="/api/v0/download_csv?id="+E[0]),a(t,"class","w-full flex justify-between truncate"),a(T,"class","flex items-center gap-x-2 p-3 text-sm bg-white border text-gray-800 first:rounded-t-lg first:mt-0 last:rounded-b-lg dark:bg-slate-900 dark:border-gray-700 dark:text-gray-200"),a(e,"class","flex flex-col justify-end text-start -space-y-px")},m(_,c){V(_,e,c),l(e,T),l(T,t),l(t,r),l(t,R),l(t,n),l(n,s),l(s,S),l(s,A),l(n,o)},p(_,[c]){c&1&&i!==(i="/api/v0/download_csv?id="+_[0])&&a(n,"href",i)},i:j,o:j,d(_){_&&Y(e)}}}function MA(E,e,T){let{id:t}=e;return E.$$set=r=>{"id"in r&&T(0,t=r.id)},[t]}class UA extends ue{constructor(e){super(),Ce(this,e,MA,pA,_e,{id:0})}}function KT(E,e,T){const t=E.slice();return t[5]=e[T],t}function JT(E,e,T){const t=E.slice();return t[8]=e[T],t}function qT(E,e,T){const t=E.slice();return t[8]=e[T],t}function QT(E){let e,T,t,r;return{c(){e=f("th"),T=f("div"),t=f("span"),t.textContent=`${E[8]}`,r=$(),a(t,"class","text-xs font-semibold uppercase tracking-wide text-gray-800 dark:text-gray-200"),a(T,"class","flex items-center gap-x-2"),a(e,"scope","col"),a(e,"class","px-6 py-3 text-left")},m(R,n){V(R,e,n),l(e,T),l(T,t),l(e,r)},p:j,d(R){R&&Y(e)}}}function ZT(E){let e,T,t;return{c(){e=f("td"),T=f("div"),t=f("span"),t.textContent=`${E[5][E[8]]}`,a(t,"class","text-gray-800 dark:text-gray-200"),a(T,"class","px-6 py-3"),a(e,"class","h-px w-px 
whitespace-nowrap")},m(r,R){V(r,e,R),l(e,T),l(T,t)},p:j,d(r){r&&Y(e)}}}function jT(E){let e,T,t=De(E[3]),r=[];for(let R=0;R{C=null}),ge())},i(L){i||(m(C),i=!0)},o(L){y(C),i=!1},d(L){L&&(Y(e),Y(A),Y(o)),nE(c,L),nE(p,L),C&&C.d(L)}}}function hA(E,e,T){let t;eE(E,VE,S=>T(1,t=S));let{id:r}=e,{df:R}=e,n=JSON.parse(R),s=n.length>0?Object.keys(n[0]):[];return E.$$set=S=>{"id"in S&&T(0,r=S.id),"df"in S&&T(4,R=S.df)},[r,t,n,s,R]}class gR extends ue{constructor(e){super(),Ce(this,e,hA,mA,_e,{id:0,df:4})}}function GA(E){let e;return{c(){e=f("div"),a(e,"id",E[0])},m(T,t){V(T,e,t)},p:j,i:j,o:j,d(T){T&&Y(e)}}}function gA(E,e,T){let{fig:t}=e,r=JSON.parse(t),R=Math.random().toString(36).substring(2,15)+Math.random().toString(36).substring(2,15);return DR(()=>{Plotly.newPlot(document.getElementById(R),r,{responsive:!0})}),E.$$set=n=>{"fig"in n&&T(1,t=n.fig)},[R,t]}class HR extends ue{constructor(e){super(),Ce(this,e,gA,GA,_e,{fig:1})}}function HA(E){let e,T,t,r;return{c(){e=f("button"),T=te(E[0]),a(e,"type","button"),a(e,"class","mb-2.5 mr-1.5 py-3 px-4 inline-flex justify-center items-center gap-2 rounded-md border-2 border-green-200 font-semibold text-green-500 hover:text-white hover:bg-green-500 hover:border-green-500 focus:outline-none focus:ring-2 focus:ring-green-200 focus:ring-offset-2 transition-all text-sm dark:focus:ring-offset-gray-800")},m(R,n){V(R,e,n),l(e,T),t||(r=Ne(e,"click",E[1]),t=!0)},p(R,[n]){n&1&&Le(T,R[0])},i:j,o:j,d(R){R&&Y(e),t=!1,r()}}}function bA(E,e,T){let{message:t}=e,{onSubmit:r}=e;function R(){r(t)}return E.$$set=n=>{"message"in n&&T(0,t=n.message),"onSubmit"in n&&T(2,r=n.onSubmit)},[t,R,r]}class yA extends ue{constructor(e){super(),Ce(this,e,bA,HA,_e,{message:0,onSubmit:2})}}function BA(E){let e,T,t,r,R,n,s,S,A;return{c(){e=f("div"),T=f("div"),t=f("div"),t.innerHTML='',r=$(),R=f("div"),n=f("h3"),n.textContent="Error",s=$(),S=f("div"),A=te(E[0]),a(t,"class","flex-shrink-0"),a(n,"class","text-sm text-yellow-800 font-semibold"),a(S,"class","mt-1 text-sm 
text-yellow-700"),a(R,"class","ml-4"),a(T,"class","flex"),a(e,"class","bg-yellow-50 border border-yellow-200 rounded-md p-4"),a(e,"role","alert")},m(o,i){V(o,e,i),l(e,T),l(T,t),l(T,r),l(T,R),l(R,n),l(R,s),l(R,S),l(S,A)},p(o,[i]){i&1&&Le(A,o[0])},i:j,o:j,d(o){o&&Y(e)}}}function vA(E,e,T){let{message:t}=e;return E.$$set=r=>{"message"in r&&T(0,t=r.message)},[t]}let PT=class extends ue{constructor(e){super(),Ce(this,e,vA,BA,_e,{message:0})}};function FA(E){let e,T;const t=E[1].default,r=Ut(t,E,E[0],null);return{c(){e=f("div"),r&&r.c(),a(e,"class","font-mono whitespace-pre-wrap")},m(R,n){V(R,e,n),r&&r.m(e,null),T=!0},p(R,[n]){r&&r.p&&(!T||n&1)&&ht(r,t,R,R[0],T?mt(t,R[0],n,null):Gt(R[0]),null)},i(R){T||(m(r,R),T=!0)},o(R){y(r,R),T=!1},d(R){R&&Y(e),r&&r.d(R)}}}function YA(E,e,T){let{$$slots:t={},$$scope:r}=e;return E.$$set=R=>{"$$scope"in R&&T(0,r=R.$$scope)},[r,t]}class bR extends ue{constructor(e){super(),Ce(this,e,YA,FA,_e,{})}}function VA(E){let e;return{c(){e=te(E[1])},m(T,t){V(T,e,t)},p(T,t){t&2&&Le(e,T[1])},d(T){T&&Y(e)}}}function WA(E){let e,T,t,r,R,n,s,S;return t=new IE({props:{message:"Run SQL",onSubmit:E[3]}}),R=new aE({props:{$$slots:{default:[VA]},$$scope:{ctx:E}}}),{c(){e=f("textarea"),T=$(),K(t.$$.fragment),r=$(),K(R.$$.fragment),a(e,"rows","6"),a(e,"class","block p-2.5 w-full text-blue-600 hover:text-blue-500 dark:text-blue-500 dark:hover:text-blue-400 bg-gray-50 rounded-lg border border-gray-300 focus:ring-blue-500 focus:border-blue-500 dark:bg-gray-700 dark:border-gray-600 dark:placeholder-gray-400 dark:focus:ring-blue-500 dark:focus:border-blue-500 font-mono"),a(e,"placeholder","SELECT col1, col2, col3 FROM ...")},m(A,o){V(A,e,o),Ye(e,E[1]),V(A,T,o),X(t,A,o),V(A,r,o),X(R,A,o),n=!0,s||(S=Ne(e,"input",E[2]),s=!0)},p(A,[o]){o&2&&Ye(e,A[1]);const i={};o&3&&(i.onSubmit=A[3]),t.$set(i);const 
_={};o&18&&(_.$$scope={dirty:o,ctx:A}),R.$set(_)},i(A){n||(m(t.$$.fragment,A),m(R.$$.fragment,A),n=!0)},o(A){y(t.$$.fragment,A),y(R.$$.fragment,A),n=!1},d(A){A&&(Y(e),Y(T),Y(r)),k(t,A),k(R,A),s=!1,S()}}}function wA(E,e,T){let t;eE(E,Et,s=>T(1,t=s));let{onSubmit:r}=e;function R(){t=this.value,Et.set(t)}const n=()=>r(t);return E.$$set=s=>{"onSubmit"in s&&T(0,r=s.onSubmit)},[r,t,R,n]}class $A extends ue{constructor(e){super(),Ce(this,e,wA,WA,_e,{onSubmit:0})}}function xA(E){let e,T,t,r,R,n;return t=new IE({props:{message:E[3],onSubmit:E[5]}}),{c(){e=f("textarea"),T=$(),K(t.$$.fragment),a(e,"rows","6"),a(e,"class","block p-2.5 w-full text-blue-600 hover:text-blue-500 dark:text-blue-500 dark:hover:text-blue-400 bg-gray-50 rounded-lg border border-gray-300 focus:ring-blue-500 focus:border-blue-500 dark:bg-gray-700 dark:border-gray-600 dark:placeholder-gray-400 dark:focus:ring-blue-500 dark:focus:border-blue-500 font-mono"),a(e,"placeholder",E[2])},m(s,S){V(s,e,S),Ye(e,E[0]),V(s,T,S),X(t,s,S),r=!0,R||(n=Ne(e,"input",E[4]),R=!0)},p(s,[S]){(!r||S&4)&&a(e,"placeholder",s[2]),S&1&&Ye(e,s[0]);const A={};S&8&&(A.message=s[3]),S&3&&(A.onSubmit=s[5]),t.$set(A)},i(s){r||(m(t.$$.fragment,s),r=!0)},o(s){y(t.$$.fragment,s),r=!1},d(s){s&&(Y(e),Y(T)),k(t,s),R=!1,n()}}}function XA(E,e,T){let{onSubmit:t}=e,{currentValue:r}=e,{placeholder:R}=e,{buttonText:n}=e;function s(){r=this.value,T(0,r)}const S=()=>t(r);return E.$$set=A=>{"onSubmit"in A&&T(1,t=A.onSubmit),"currentValue"in A&&T(0,r=A.currentValue),"placeholder"in A&&T(2,R=A.placeholder),"buttonText"in A&&T(3,n=A.buttonText)},[r,t,R,n,s,S]}class kA extends ue{constructor(e){super(),Ce(this,e,XA,xA,_e,{onSubmit:1,currentValue:0,placeholder:2,buttonText:3})}}function KA(E){let e,T;return e=new IE({props:{message:"Play",onSubmit:E[2]}}),{c(){K(e.$$.fragment)},m(t,r){X(e,t,r),T=!0},p(t,[r]){const R={};r&1&&(R.onSubmit=t[2]),e.$set(R)},i(t){T||(m(e.$$.fragment,t),T=!0)},o(t){y(e.$$.fragment,t),T=!1},d(t){k(e,t)}}}function 
er(E){if("speechSynthesis"in window){const e=new SpeechSynthesisUtterance(E);e.lang="en-US",e.volume=1,e.rate=1,e.pitch=1,window.speechSynthesis.speak(e)}else console.error("SpeechSynthesis API is not supported in this browser.")}function JA(E,e,T){let t;eE(E,pR,n=>T(1,t=n));let{message:r}=e;const R=()=>er(r);return E.$$set=n=>{"message"in n&&T(0,r=n.message)},E.$$.update=()=>{E.$$.dirty&3&&t&&er(r)},[r,t,R]}class qA extends ue{constructor(e){super(),Ce(this,e,JA,KA,_e,{message:0})}}function QA(E){let e,T,t;return{c(){e=f("button"),e.textContent="Open Debugger",T=$(),t=f("div"),t.innerHTML='